随风而逝_似我飘零:
The article is well written. I have a few optimization suggestions for the code, for reference only: [code=python]
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer

class MyDataset(Dataset):
    def __init__(self, texts, labels, max_length):
        self.all_text = texts
        self.all_label = labels
        self.max_len = max_length
        # parsers() is the article's own argument parser; bert_pred holds the pretrained model path
        self.tokenizer = BertTokenizer.from_pretrained(parsers().bert_pred)

    def __getitem__(self, index):
        # Tokenize the text, producing token ids and the attention mask,
        # padded/truncated to max_len
        result = self.tokenizer.encode_plus(text=self.all_text[index], max_length=self.max_len,
                                            padding='max_length', truncation=True, return_tensors='pt')
        # Label for this sample
        label = int(self.all_label[index])
        # Convert everything to tensors; squeeze(0) drops the extra batch
        # dimension that return_tensors='pt' adds, so the DataLoader's
        # default collate produces (batch, max_len) rather than (batch, 1, max_len)
        token_ids = result.input_ids.squeeze(0)
        mask = result.attention_mask.squeeze(0)
        label = torch.tensor(label)
        return (token_ids, mask), label

    def __len__(self):
        # Number of samples in the dataset
        return len(self.all_text)
[/code]
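Since `__getitem__` returns fixed-shape tensors, the dataset can be fed straight into a standard `DataLoader`. A minimal usage sketch, assuming `texts` and `labels` are parallel lists; the sample data and batch size are arbitrary illustrative choices: [code=python]
# Minimal usage sketch; the texts/labels below are made-up examples.
from torch.utils.data import DataLoader

texts = ["很好用", "质量太差"]
labels = [1, 0]
dataset = MyDataset(texts, labels, max_length=128)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

for (token_ids, mask), label in loader:
    # token_ids/mask: (batch, max_len); label: (batch,)
    print(token_ids.shape, mask.shape, label.shape)
    break
[/code]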
---
[code=python]
def train():
    # Load hyperparameters and paths from the article's argument parser
    args = parsers()
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
[/code]
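The snippet above is cut off after the device selection. For completeness, a hedged sketch of how such a training loop commonly continues; `MyModel`, `train_texts`/`train_labels`, and the `args` fields (`epochs`, `batch_size`, `learn_rate`, `max_len`) are hypothetical stand-ins for the article's own definitions, not the original code: [code=python]
    # Hypothetical continuation of the train() body above; MyModel,
    # train_texts/train_labels and the args fields are assumptions,
    # not the article's original code.
    model = MyModel().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learn_rate)
    loss_fn = torch.nn.CrossEntropyLoss()

    train_loader = DataLoader(MyDataset(train_texts, train_labels, args.max_len),
                              batch_size=args.batch_size, shuffle=True)

    for epoch in range(args.epochs):
        model.train()
        for (token_ids, mask), label in train_loader:
            token_ids = token_ids.to(device)
            mask = mask.to(device)
            label = label.to(device)

            logits = model(token_ids, mask)  # (batch, num_classes)
            loss = loss_fn(logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
[/code]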