在自然语言处理(NLP)任务中,处理超长文本(通常指长度超过模型最大支持长度的文本)是一个常见的挑战。BERT等预训练模型通常具有固定的最大序列长度限制(例如,BERT-base的最大序列长度为512个标记)。当需要处理超过这个长度的文本时,需要采取特定的策略来确保模型能够有效地处理这些数据。
BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(21128, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-11): 12 x BertLayer(
(attention): BertAttention(
(self): BertSdpaSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
BertConfig {
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"classifier_dropout": null,
"directionality": "bidi",
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"position_embedding_type": "absolute",
"transformers_version": "4.47.1",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 21128
}
以上是bert-base-chinese模型信息,BertEmbeddings中(position_embeddings): Embedding(512, 768)以及模型配置信息中"max_position_embeddings": 512表示模型最大序列长度为512个标记。
('早报讯 中国地球深部探测计划已全面展开,这是中国有史以来最大的地质勘探计划,也被认为是“中国挺进地心第一步”。“深部探测专项”负责人、中国地质科学院副院长董树文说,中国所面临的资源挑战,是启动深部探测工程的最大动因。打个万米孔需投数亿元据中央电视台《新闻联播》报道,在海拔4400米的喜马拉雅山罗布莎地区,国家深部探测专项——罗布莎科学钻探实验正在紧张进行。除罗布莎之外,山东莱阳、云南腾冲等地的6个钻探项目也在进行中,国家将从这7个钻探点中选择一处进行超越1万米的科学钻探。科学钻是人类获取地球内部信息最有效、最直观的方法。目前全球仅有苏联的“科拉超级钻”达到过1万米以下的深度。中国曾在2001年于江苏东海县启动了自己的“超级钻”,它在2005年达到了5100多米的深度。据《?望东方周刊》报道,这个深度钻井设备存放在四川宏华集团。它高45米,占地1万平方米,重量超过1000吨。在运输时,这套设备拆分后需要30到50辆大卡车。如果达到1万米的钻井深度要求,制造费用预计超过1亿元。而一个科学钻“打一个1万米深的孔就要数亿元投入”,董树文说。与此同时,深部探测计划的另一个实验项目——地震反射剖面探测也在西藏阿里进行。这种探测是用地下爆破的方法,通过追踪反射信号,探明数十公里地下的结构,用科学家的话说就是“给地球深层做一个CT”。据了解,科学钻探和深地震剖面探测只是中国地球深部探测计划的两个组成部分,这一计划集合了12位院士、200多名研究员以及上千名科研人员,共实施大地电磁探测、地壳全元素探测等九个实验项目,将在2012年底前完成。该计划预算达到30亿元,希望从深层次了解一系列中国人关心的重大问题:从油气蕴藏、矿产分布,一直到青藏高原的扩展还将在四川产生何种地质变动。最大动因:资源能源缺口中国所面临的资源挑战,是启动深部探测工程的最大动因。它也将成为未来实施地壳工程的主要推力。“资源能源缺口是立项的第一出发点。”“深部探测专项”负责人、中国地质科学院副院长董树文此前向《?望东方周刊》记者说。事实上,早在2004年9月中国就启动了全国重要矿产危机矿山找矿专项,在原有矿山300米至500米的开采深度上,向更深部进行勘探开采。到2008年经过评估,这一工程已为中国带来了价值超过1万亿元人民币的矿产资源。目前深部探测项目取得的最大成果之一,就是对长江中下游成矿带的深度阐述,首次实现了大型矿区三维透明化。在深部探测计划的基础上,中国科学家正在筹划详细揭示中国地壳结构的“地壳探测工程”,为保障资源供应、防灾减灾和发展地球科学提供全面的深部数据和信息。地下4000米变“透明”据悉,国土资源部系统的科研机构2002年即提出启动“地壳探测工程”,但由于预算高达30亿元,项目申报被长时间搁置。此后,资源缺口凸显、地质灾害频发与地球科学相对落后并存的局面,使开展深部探测工程的诉求日益强烈。2006年,国务院发布《关于加强地质工作的决定》,其中明确提出“实施地壳探测工程,提高地球认知、资源勘查和灾害预警水平”,堪称中国地壳探测工程的转折点,次年,中国地质科学院再度开始申报进行地壳探测工程,最终2008年5月发生在四川的汶川地震给专项带来了转机。2008年“深部技术探测与实验研究”项目启动当年,即得到了7000多万元启动资金。涉及地质学、地球物理学等多个基础学科的庞大工程就此启动,中国人也加入了向地心挺进的行列。2009年11月3日,国务院总理温家宝在首都科技界大会上谈到,在地球深部资源探测方面,中国已有固体矿产勘探开采的深度大都小于500米,而世界一些矿业大国已经达到2500米到4000米,南非计划开采的深度达到6000米。他当时还举例说,澳大利亚在本世纪初率先提出“玻璃地球”计划,也就是要使地下1000米变得“透明”;加拿大人近期提出的类似计划,要搞到3000米。据报道,欧美国家已纷纷开展了各自的深部探测计划,其中美国2003年启动的“地球透镜计划”将在15年内投入200亿美元;到2005年,俄罗斯在欧洲部分和盛产能源的西西伯利亚实施了十多处超级钻项目,目前该国已是全球最大的资源蕴藏地和出口国;而深部探测还使加拿大和澳大利亚在最近20年来始终保持着世界资源勘探大国的地位。2010年,全国人大副委员长、中国科学院原院长路甬祥曾强调:深部矿产资源勘探与开发是影响中国可持续发展能力的战略性科技问题。他说,应使中国主要区域地下4000米变得“透明”,以解决中国资源短缺的瓶颈。', 9)
('虎年旺财,每天22时的钟声响彻封神,散财童子在线准时为玩家发放“吉星高照礼包”,宝石原石、雕琢符、开封卷轴等众多宝贝免费白拿!每天一个小时“童子献宝”活动,玩家上线即可拥抱好礼的狂欢时刻!每天22时至23时,散财童子恭候玩家为全服在线勇士们随机发放物品。他在一个小时内,每十五分钟分四次放送礼包,囊括各种宝石原石、开封卷轴、雕琢符、修补石、女娲玉精致物品。好礼多多拿到手软,今天你是《封神ol》在线获奇宝的幸运儿么?每天欢天喜地准时上线来《封神ol》尽情免费拿大礼吧,封神放出的惊喜连绵不断,让你乐不停HIGH翻天!首款大型3G神话手机网游《封神online》:精致细腻的立体画面*真实唯美的百种场景;海量新奇的剧情任务*独创万种的极品换装;激发能量的百变宝石*神秘莫测的洞装开封;尊贵告诉的坐骑马车*玩法多样的职业种族;刺激诱惑的副本探险*淋漓尽致的跨服战场;统领一方的霸气诸侯*爽快过瘾的PK击杀;甜蜜特权的结婚系统*亲切畅谈的聊天交友;开发商:空中网旗下知名手机游戏开发商天津猛犸下载方式:手机登录fs.kong.net下载或发送手机短信"FS"到106633554455免费下载。客服电话:022-58113111客服信箱:gamekefu@kongzhong.com', 4)
使用bert-base-chinese对以上文本进行训练时
RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 1
超长文本训练的挑战
序列长度限制:BERT等预训练模型通常具有固定的最大序列长度限制。例如,BERT-base的最大序列长度为512个标记。当文本长度超过这个限制时,直接输入模型会导致截断,从而丢失重要信息。
信息丢失:超长文本截断可能导致重要信息丢失,尤其是在文本的开头或结尾部分。这些部分可能包含关键的上下文信息,截断后会影响模型的理解能力。
性能下降:由于信息丢失和截断,模型在处理超长文本时的性能可能会显著下降。特别是在需要理解全局语义的任务中,如文档分类、摘要生成等,超长文本的处理尤为重要。
解决方案
文本分段
滑动窗口
滑动窗口是一种常见的方法,通过将超长文本分割成多个重叠的子序列进行处理。具体步骤如下:
分割文本:将超长文本分割成多个长度为max_length的子序列,子序列之间有重叠部分。
处理子序列:对每个子序列分别进行编码,提取特征。
合并特征:将多个子序列的特征进行合并,可以使用简单的平均、最大池化或更复杂的融合方法。
def sliding_window(text, tokenizer, max_length=512, overlap=128):
inputs = tokenizer.encode_plus(
text,
add_special_tokens=True,
max_length=max_length,
truncation=True,
stride=overlap,
return_tensors="pt",
return_overflowing_tokens=True
)
return inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"]
# 示例使用
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
text = "超长文本内容..."
input_ids, attention_mask, token_type_ids = sliding_window(text, tokenizer)
文档切片
文档切片是另一种方法,通过将超长文本分割成多个独立的子文档进行处理。具体步骤如下:
分割文本:将超长文本分割成多个独立的子文档。
处理子文档:对每个子文档分别进行编码,提取特征。
合并特征:将多个子文档的特征进行合并,可以使用简单的平均、最大池化或更复杂的融合方法。
def document_chunking(text, tokenizer, max_length=512):
tokens = tokenizer.tokenize(text)
chunks = [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
input_ids = []
attention_mask = []
token_type_ids = []
for chunk in chunks:
inputs = tokenizer.encode_plus(
chunk,
add_special_tokens=True,
max_length=max_length,
truncation=True,
return_tensors="pt"
)
input_ids.append(inputs["input_ids"])
attention_mask.append(inputs["attention_mask"])
token_type_ids.append(inputs["token_type_ids"])
return input_ids, attention_mask, token_type_ids
# 示例使用
input_ids, attention_mask, token_type_ids = document_chunking(text, tokenizer)
固定长度分段
将文本按照固定长度进行分段,并在每个段落之间添加特殊标记(如[SEP])以表示分段。
def segment_text(text, max_length=512, overlap=50):
segments = []
for i in range(0, len(text), max_length - overlap):
segment = text[i:i + max_length]
segments.append(segment)
return segments
# 示例
segments = segment_text("这是一个非常长的文本,需要进行分段以适应模型的最大序列长度。")
句子边界分段
将文本按照句子边界进行分段,确保每个段落包含完整的句子。
import re
def segment_text_by_sentences(text, max_length=512):
sentences = re.split(r'(?<=[。!?])', text)
segments = []
current_segment = ""
for sentence in sentences:
if len(current_segment) + len(sentence) <= max_length:
current_segment += sentence
else:
segments.append(current_segment)
current_segment = sentence
if current_segment:
segments.append(current_segment)
return segments
# 示例
segments = segment_text_by_sentences("这是一个非常长的文本。需要进行分段以适应模型的最大序列长度!")
使用支持超长序列的模型
一些预训练模型已经支持超长序列。例如,Longformer、Reformer和BigBird等模型能够处理超过512个标记的序列。
Longformer
Longformer通过稀疏注意力机制(Sparse Attention Mechanism)支持超长序列。具体使用方法如下:
from transformers import LongformerModel, LongformerTokenizer
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
text = "超长文本内容..."
inputs = tokenizer(text, return_tensors="pt", truncation=False)
outputs = model(**inputs)
Reformer
Reformer通过局部敏感哈希(Locality-Sensitive Hashing)和块稀疏注意力(Chunked Sparse Attention)支持超长序列。具体使用方法如下:
from transformers import ReformerModel, ReformerTokenizer
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
model = ReformerModel.from_pretrained("google/reformer-crime-and-punishment")
text = "超长文本内容..."
inputs = tokenizer(text, return_tensors="pt", truncation=False)
outputs = model(**inputs)
BigBird
BigBird通过块稀疏注意力(Block Sparse Attention)支持超长序列。具体使用方法如下:
from transformers import BigBirdModel, BigBirdTokenizer
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base")
text = "超长文本内容..."
inputs = tokenizer(text, return_tensors="pt", truncation=False)
outputs = model(**inputs)
特征融合
在处理超长文本时,可以将多个子序列或子文档的特征进行融合,以获得更全面的文本表示。常见的融合方法包括:
平均池化
将多个子序列或子文档的特征进行平均池化,得到一个综合特征向量。
def average_pooling(features):
return torch.mean(features, dim=0)
# 示例使用
features = [model(input_ids[i], attention_mask[i], token_type_ids[i]).last_hidden_state[:, 0] for i in range(len(input_ids))]
average_feature = average_pooling(torch.stack(features))
最大池化
将多个子序列或子文档的特征进行最大池化,得到一个综合特征向量。
def max_pooling(features):
return torch.max(features, dim=0).values
# 示例使用
max_feature = max_pooling(torch.stack(features))
注意力融合
使用注意力机制对多个子序列或子文档的特征进行加权融合,得到一个综合特征向量。
import torch.nn.functional as F
class AttentionFusion(torch.nn.Module):
def __init__(self, input_dim):
super(AttentionFusion, self).__init__()
self.attention = torch.nn.Linear(input_dim, 1)
def forward(self, features):
attention_weights = F.softmax(self.attention(features), dim=0)
fused_feature = torch.sum(attention_weights * features, dim=0)
return fused_feature
# 示例使用
attention_fusion = AttentionFusion(768)
fused_feature = attention_fusion(torch.stack(features))
自定义模型
在某些情况下,可能需要自定义模型以更好地处理超长文本。可以通过以下方法进行自定义:
层次化模型
层次化模型通过多层处理逐步提取特征。例如,先对子序列进行编码,再对子序列的特征进行进一步编码。
class HierarchicalModel(torch.nn.Module):
def __init__(self, base_model, num_classes):
super(HierarchicalModel, self).__init__()
self.base_model = base_model
self.fc = torch.nn.Linear(base_model.config.hidden_size, num_classes)
def forward(self, input_ids, attention_mask, token_type_ids):
features = []
for i in range(len(input_ids)):
output = self.base_model(input_ids[i], attention_mask[i], token_type_ids[i])
features.append(output.last_hidden_state[:, 0])
features = torch.stack(features)
pooled_feature = torch.mean(features, dim=0)
logits = self.fc(pooled_feature)
return logits
# 示例使用
base_model = BertModel.from_pretrained("bert-base-chinese")
hierarchical_model = HierarchicalModel(base_model, num_classes=10)
logits = hierarchical_model(input_ids, attention_mask, token_type_ids)
自定义注意力机制
自定义注意力机制可以更好地捕捉超长文本中的全局信息。例如,使用全局注意力机制对整个文本进行编码。
class GlobalAttentionModel(torch.nn.Module):
def __init__(self, base_model, num_classes):
super(GlobalAttentionModel, self).__init__()
self.base_model = base_model
self.attention = torch.nn.Linear(base_model.config.hidden_size, 1)
self.fc = torch.nn.Linear(base_model.config.hidden_size, num_classes)
def forward(self, input_ids, attention_mask, token_type_ids):
features = []
for i in range(len(input_ids)):
output = self.base_model(input_ids[i], attention_mask[i], token_type_ids[i])
features.append(output.last_hidden_state)
features = torch.stack(features)
attention_weights = F.softmax(self.attention(features), dim=1)
pooled_feature = torch.sum(attention_weights * features, dim=1)
logits = self.fc(pooled_feature[:, 0])
return logits
# 示例使用
global_attention_model = GlobalAttentionModel(base_model, num_classes=10)
logits = global_attention_model(input_ids, attention_mask, token_type_ids)
修改模型模型最大序列长度
如果现有的模型无法满足需求,可以自定义模型以支持超长序列。例如,可以通过堆叠多个Transformer层来增加模型的上下文理解能力。
from transformers import BertModel, BertConfig
import torch
class CustomLongModel(torch.nn.Module):
def __init__(self, config):
super().__init__()
self.bert = BertModel(config)
self.fc = torch.nn.Linear(768, 10)
def forward(self, input_ids, attention_mask, token_type_ids):
out = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
out = self.fc(out.last_hidden_state[:, 0])
return out
#加载预训练模型
# pretrained = BertModel.from_pretrained(r"E:\PycharmProjects\demo_7\model\bert-base-chinese\models--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f").to(DEVICE)
# pretrained.embeddings.position_embeddings = torch.nn.Embedding(1024,768).to(DEVICE)
# 自定义配置
config = BertConfig.from_pretrained("bert-base-chinese")
config.max_position_embeddings = 1024
# 初始化模型
model = CustomLongModel(config).to(DEVICE)
print(model)
使用分段编码
将文本分段后,对每个段落进行编码,并将这些编码结果进行聚合。聚合方法可以是简单的平均池化或更复杂的注意力机制。
def segment_and_encode(text, tokenizer, model, max_length=512, overlap=50):
segments = segment_text(text, max_length, overlap)
segment_encodings = []
for segment in segments:
inputs = tokenizer(segment, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = model(**inputs)
segment_encodings.append(outputs.last_hidden_state[:, 0])
return torch.stack(segment_encodings)
# 示例
segment_encodings = segment_and_encode("这是一个非常长的文本,需要进行分段编码。", tokenizer, model)
实践案例
Longformer模型案例
from transformers import LongformerModel, LongformerTokenizer
# 加载预训练模型和分词器
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096").to(DEVICE)
# 示例文本
text = "这是一个非常长的文本,需要使用支持超长序列的模型来处理。" * 1000
# 编码文本
inputs = tokenizer(text, return_tensors="pt", truncation=False)
outputs = model(**inputs)
# 打印输出
print(outputs.last_hidden_state.shape)
通过使用Longformer模型,我们能够处理长度超过512个标记的文本,而无需进行截断或分段。
自定义模型来处理超长文本
使用自定义模型来处理超长文本,并对每个段落进行编码和聚合:
from transformers import BertModel, BertConfig
import torch
# 自定义配置
config = BertConfig.from_pretrained("bert-base-chinese")
config.max_position_embeddings = 1024
# 初始化模型
model = CustomLongModel(config).to(DEVICE)
# 示例文本
text = "这是一个非常长的文本,需要进行分段编码。" * 1000
# 分段编码
segment_encodings = segment_and_encode(text, tokenizer, model)
# 聚合编码结果
pooled_output = torch.mean(segment_encodings, dim=0)
# 打印聚合结果
print(pooled_output.shape)
修改模型编码长度
from torch.utils.data import Dataset
from datasets import load_dataset
class MyDataset(Dataset):
def __init__(self,split):
#从磁盘加载数据
self.dataset = load_dataset(path="csv",data_files=f"/Volumes/Date/huggingface/dataset/news/{split}.csv",split="train")
def __len__(self):
return len(self.dataset)
def __getitem__(self, item):
text = self.dataset[item]["text"]
label = self.dataset[item]["label"]
return text,label
from transformers import BertModel, BertConfig
import torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载预训练模型
#pretrained = BertModel.from_pretrained("/Volumes/Date/huggingface/model/bert-base-chinese/models--bert-base-chinese/snapshots/c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f").to(DEVICE)
# pretrained.embeddings.position_embeddings = torch.nn.Embedding(1024,768).to(DEVICE)
config = BertConfig.from_pretrained("/Volumes/Date/huggingface/model/bert-base-chinese/models--bert-base-chinese/snapshots/c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f")
config.max_position_embeddings = 1024
# print(config)
# 使用配置文件初始化模型
pretrained = BertModel(config).to(DEVICE)
#print(pretrained)
# 定义下游任务
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = torch.nn.Linear(768, 10)
def forward(self, input_ids, attention_mask, token_type_ids):
# 冻结预训练模型权重
#with torch.no_grad():
# out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
# 增量模型参与训练
out = self.fc(out.last_hidden_state[:, 0])
return out
#模型训练
import torch
from MyData import MyDataset
from torch.utils.data import DataLoader
from net import Model
from transformers import BertTokenizer,AdamW
#定义设备信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#定义训练的轮次
EPOCH= 30000
token = BertTokenizer.from_pretrained("/Volumes/Date/huggingface/model/bert-base-chinese/models--bert-base-chinese/snapshots/c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f")
def collate_fn(data):
sents = [i[0]for i in data]
label = [i[1] for i in data]
#编码
data = token.batch_encode_plus(
batch_text_or_text_pairs=sents,
truncation=True,
max_length=1024,
padding="max_length",
return_tensors="pt",
return_length=True
)
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
labels = torch.LongTensor(label)
return input_ids,attention_mask,token_type_ids,labels
#创建数据集
train_dataset = MyDataset("train")
train_loader = DataLoader(
dataset=train_dataset,
batch_size=2,
shuffle=True,
#舍弃最后一个批次的数据,防止形状出错
drop_last=True,
#对加载进来的数据进行编码
collate_fn=collate_fn
)
val_dataset = MyDataset("validation")
val_loader = DataLoader(
dataset=val_dataset,
batch_size=2,
shuffle=True,
#舍弃最后一个批次的数据,防止形状出错
drop_last=True,
#对加载进来的数据进行编码
collate_fn=collate_fn
)
if __name__ == '__main__':
#开始训练
print(DEVICE)
model = Model().to(DEVICE)
#定义优化器
optimizer = AdamW(model.parameters())
#定义损失函数
loss_func = torch.nn.CrossEntropyLoss()
#初始化最佳验证准确率
best_val_acc = 0.0
for epoch in range(EPOCH):
for i,(input_ids,attention_mask,token_type_ids,labels) in enumerate(train_loader):
#将数据存放到DEVICE上
input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE),labels.to(DEVICE)
#前向计算(将数据输入模型,得到输出)
out = model(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
#根据输出,计算损失
loss = loss_func(out,labels)
#根据损失,优化参数
optimizer.zero_grad()
loss.backward()
optimizer.step()
#每隔5个批次输出训练信息
if i%5==0:
out = out.argmax(dim=1)
acc = (out==labels).sum().item()/len(labels)
print(f"epoch:{epoch},i:{i},loss:{loss.item()},acc:{acc}")
#验证模型(判断是否过拟合)
#设置为评估模式
model.eval()
#不需要模型参与训练
with torch.no_grad():
val_acc = 0.0
val_loss = 0.0
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(val_loader):
# 将数据存放到DEVICE上
input_ids, attention_mask, token_type_ids, labels = input_ids.to(DEVICE), attention_mask.to(
DEVICE), token_type_ids.to(DEVICE), labels.to(DEVICE)
# 前向计算(将数据输入模型,得到输出)
out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
# 根据输出,计算损失
val_loss += loss_func(out, labels)
out = out.argmax(dim=1)
val_acc+=(out==labels).sum().item()
val_loss /= len(val_loader)
val_acc /= len(val_loader)
print(f"验证集:loss:{val_loss},acc:{val_acc}")
#根据验证准确率保存最优参数
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(),"params/best_bert.pth")
print(f"Epoch:{epoch}:保存最优参数:acc:{best_val_acc}")
#保存最后一轮参数
torch.save(model.state_dict(),f"params/last_bert.pth")
print(epoch,f"Epcot:{epoch}最后一轮参数保存成功!")