Fine-Tuning XLNet

Posted by YiKe on August 28, 2020

This post uses an English-text relation-extraction competition to show how to fine-tune a Huggingface pretrained model; it can also be read as a simple worked example of relation extraction.

Data Preview

The training data has two columns. The first column is the text, in which the span wrapped in <e1></e1> is the first entity and the span wrapped in <e2></e2> is the second entity. The second column is the relation. Apart from "Other", every relation is directed. Take the first sample, which says roughly that the harm was caused by the invitation system: "harm" is the effect in the causal relation and "system" is the cause, so the final label is "Cause-Effect(e2,e1)".
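
To make the format concrete, here is an illustrative row in the shape of the training file (a made-up example in the spirit of the dataset, not a verbatim sample):

sentence,relation
"The <e1>harm</e1> came from the invitation <e2>system</e2>.",Cause-Effect(e2,e1)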

Approach

Broadly speaking, every approach here fine-tunes a Huggingface pretrained model; the differences lie in how the classification task is set up.

Method 1

The labels are unusual: we must predict not only the relation between the two entities but also their order. One option is therefore to predict three values, as shown in the figure below (using the first sentence as an example).

The output at the [CLS] position goes through a 10-way classification, while the outputs at the two entity positions each go through a 3-way classification. That 3-way classification has a subtlety: class 0 occurs in exactly one situation, namely when the [CLS] output is predicted as "Other", because "Other" does not involve an order between the two entities. Otherwise, "harm" predicted as 1 and "system" predicted as 2 means that in this relation "harm" comes before "system", and the reverse works the same way. Predicting these three values uniquely determines a label.
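
Below is a minimal sketch of this three-head idea. It assumes the token positions of the two entities (e1_pos, e2_pos) have already been located; the class structure and names are my own choices for illustration, not code from this post.

import torch
import torch.nn as nn
from transformers import XLNetModel

class ThreeHeadModel(nn.Module):
    def __init__(self, num_relations=10):
        super().__init__()
        self.xlnet = XLNetModel.from_pretrained('xlnet-base-cased')
        hidden = self.xlnet.config.d_model                 # 768 for xlnet-base-cased
        self.rel_head = nn.Linear(hidden, num_relations)   # 10-way relation type
        self.order_head = nn.Linear(hidden, 3)             # order tag: 0 / 1 / 2

    def forward(self, input_ids, attention_mask, e1_pos, e2_pos):
        h = self.xlnet(input_ids=input_ids, attention_mask=attention_mask)[0]  # [B, L, H]
        cls_out = h[:, -1]                     # XLNet appends its <cls> token at the end
        idx = torch.arange(h.size(0))
        e1_out, e2_out = h[idx, e1_pos], h[idx, e2_pos]    # hidden states at the entities
        return self.rel_head(cls_out), self.order_head(e1_out), self.order_head(e2_out)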

Method 2

See the figure below first: the idea is to concatenate the outputs at the two entity positions and run a single 18-way classification.
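
Again a minimal sketch under the same assumptions (entity positions located separately; the names are mine):

import torch
import torch.nn as nn
from transformers import XLNetModel

class ConcatEntityModel(nn.Module):
    def __init__(self, num_classes=18):
        super().__init__()
        self.xlnet = XLNetModel.from_pretrained('xlnet-base-cased')
        self.classifier = nn.Linear(self.xlnet.config.d_model * 2, num_classes)

    def forward(self, input_ids, attention_mask, e1_pos, e2_pos):
        h = self.xlnet(input_ids=input_ids, attention_mask=attention_mask)[0]
        idx = torch.arange(h.size(0))
        pair = torch.cat([h[idx, e1_pos], h[idx, e2_pos]], dim=-1)  # [B, 2H]
        return self.classifier(pair)                                # 18-way logits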

Method 3

This is the simplest method and the one this post uses: run the [CLS] output directly through an 18-way classification. (Strictly speaking, XLNet appends its <cls> token at the end of the sequence rather than the front; XLNetForSequenceClassification handles that pooling internally, so we can treat it as a black box.)

There are plenty of other approaches beyond these three; feel free to experiment on your own.

Data Preprocessing

def process_data(filename):
    with open(filename) as f:
        rows = [row for row in csv.reader(f)]
        rows = np.array(rows[1:]) # all data, 2D
        label_list = [label for _, label in rows] # label list
        global classes_list
        classes_list = list(set(label_list)) # non-repeated label list
        num_classes = len(classes_list) # num of classes
        for i in range(len(label_list)):
            label_list[i] = classes_list.index(label_list[i]) # index of label

        name_list, sentence_list = [], []
        for sentence, _ in rows:
            begin = sentence.find('<e1>')
            end = sentence.find('</e1>')
            e1 = sentence[begin:end + 5]  # '+ 5' keeps the closing '</e1>' tag

            begin = sentence.find('<e2>')
            end = sentence.find('</e2>')
            e2 = sentence[begin:end + 5]  # '+ 5' keeps the closing '</e2>' tag
            
            name_list.append(e1 + " " + e2)
            sentence_list.append(sentence)
    print(num_classes)
    return name_list, sentence_list, label_list, classes_list, num_classes

name_list is a one-dimensional list holding, for each row, the two tagged entity spans separated by a space. sentence_list is a one-dimensional list holding each row's text. label_list is a one-dimensional list of ints, obtained by mapping each original string label to its index. classes_list is the deduplicated label list, and num_classes is len(classes_list).

Since this post follows the third approach, name_list is never used later, but I extract it anyway so that readers trying the other methods can reuse it.

XLNetTokenizer

Next, we run the extracted sentence_list through XLNetTokenizer, one sentence at a time, to obtain each sentence's token ids, attention mask, and related tensors.

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

def convert(names, sentences, target): # name_list, sentence_list, label_list
    input_ids, token_type_ids, attention_mask = [], [], []
    for i in range(len(sentences)):
        encoded_dict = tokenizer.encode_plus(
            sentences[i],                  # input text
            add_special_tokens = True,     # append '<sep>' and '<cls>' (XLNet places them at the end)
            max_length = 96,               # pad shorter sequences to 96...
            padding = 'max_length',
            truncation = True,             # ...and truncate longer ones
            return_tensors = 'pt',         # return PyTorch tensors
        )
        input_ids.append(encoded_dict['input_ids'])
        token_type_ids.append(encoded_dict['token_type_ids'])
        attention_mask.append(encoded_dict['attention_mask'])

    # each element is a [1, 96] tensor; concatenate into [num_samples, 96]
    input_ids = torch.cat(input_ids, dim=0)
    token_type_ids = torch.cat(token_type_ids, dim=0)
    attention_mask = torch.cat(attention_mask, dim=0)
    target = torch.LongTensor(target)

    return input_ids, token_type_ids, attention_mask, target
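
Wiring the two steps together looks like this (assuming train.csv sits next to the script):

name_list, sentence_list, label_list, classes_list, num_classes = process_data('train.csv')
input_ids, token_type_ids, attention_mask, labels = convert(name_list, sentence_list, label_list)
print(input_ids.shape)  # torch.Size([num_samples, 96])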

Train Test Split

To measure model quality, we carve a validation set out of the training data; this is just a call to sklearn's train_test_split.

train_inputs, val_inputs, train_labels, val_labels = train_test_split(input_ids, labels, random_state=1, test_size=0.1)
train_token, val_token, _, _ = train_test_split(token_type_ids, labels, random_state=1, test_size=0.1)
train_mask, val_mask, _, _ = train_test_split(attention_mask, labels, random_state=1, test_size=0.1)

train_data = Data.TensorDataset(train_inputs, train_token, train_mask, train_labels)
train_dataloader = Data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

validation_data = Data.TensorDataset(val_inputs, val_token, val_mask, val_labels)
validation_dataloader = Data.DataLoader(validation_data, batch_size=batch_size, shuffle=True)

Note that each sample's ids, token_type_ids, and mask must stay aligned, so the three train_test_split calls must all use the same random_state; otherwise each call would shuffle the data differently and the splits would no longer correspond.
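
If you prefer, an equivalent way to sidestep that pitfall is to split the indices once and index every tensor with the same split (a small alternative sketch, not the code used in this post):

import numpy as np
from sklearn.model_selection import train_test_split

idx = np.arange(len(labels))
train_idx, val_idx = train_test_split(idx, random_state=1, test_size=0.1)
train_inputs, val_inputs = input_ids[train_idx], input_ids[val_idx]
train_token, val_token = token_type_ids[train_idx], token_type_ids[val_idx]
train_mask, val_mask = attention_mask[train_idx], attention_mask[val_idx]
train_labels, val_labels = labels[train_idx], labels[val_idx]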

XLNetForSequenceClassification

Since I use the plain sentence-classification approach, Huggingface's ready-made API does the job (remember to set the number of classes). The code below is adapted from "Training and fine-tuning" in the Huggingface docs.

model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased', num_labels=num_classes).to(device)

param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']  # note: XLNet's LayerNorm weights are named 'layer_norm.*'; 'gamma'/'beta' are legacy BERT names
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
    'weight_decay': 0.01},  # AdamW reads the key 'weight_decay'; 'weight_decay_rate' would be silently ignored
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
    'weight_decay': 0.0}]

optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
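
The post keeps the learning rate constant. If you want warmup plus linear decay, transformers ships a helper you can bolt on (optional; not used in the rest of this post):

from transformers import get_linear_schedule_with_warmup

total_steps = len(train_dataloader) * 2  # 2 epochs, matching the training loop below
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                            num_training_steps=total_steps)
# then call scheduler.step() right after each optimizer.step()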

Train & Val

During training, simply pass the expected arguments to the model.

for _ in range(2):
    for i, batch in enumerate(train_dataloader):
        batch = tuple(t.to(device) for t in batch)
        loss = model(batch[0], token_type_ids=batch[1], attention_mask=batch[2], labels=batch[3])[0]
        print(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            eval(model, validation_dataloader)

The model's return values are documented in the XLNetForSequenceClassification API reference. When labels are passed, the first value is the loss to backpropagate and the second is the logits; the per-layer, per-token hidden states are the fourth value, but only when the model is asked to output them (output_hidden_states=True).
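
In tuple form (the exact layout depends on your transformers version and the output_* flags, so treat this as a sketch):

# with labels passed: (loss, logits, mems, ...)
outputs = model(input_ids, token_type_ids=token_type_ids,
                attention_mask=attention_mask, labels=labels)
loss, logits = outputs[0], outputs[1]

# without labels the tuple starts at logits instead
logits = model(input_ids, token_type_ids=token_type_ids,
               attention_mask=attention_mask)[0]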

During training, we evaluate on the validation set every 10 batches (the i % 10 check above counts batches, not epochs). At validation and test time we do not give the model the true labels, so it returns no loss; the first value it returns is then the logits rather than the loss.

def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten() # [3, 5, 8, 1, 2, ....]
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

def eval(model, validation_dataloader):
    model.eval()
    eval_accuracy, nb_eval_steps = 0, 0  # running accuracy and batch count
    for batch in validation_dataloader:
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            logits = model(batch[0], token_type_ids=batch[1], attention_mask=batch[2])[0]
            logits = logits.detach().cpu().numpy()
            label_ids = batch[3].cpu().numpy()
            tmp_eval_accuracy = flat_accuracy(logits, label_ids)
            eval_accuracy += tmp_eval_accuracy
            nb_eval_steps += 1
    print("Validation Accuracy: {}".format(eval_accuracy / nb_eval_steps))
    global best_score
    if best_score < eval_accuracy / nb_eval_steps:
        best_score = eval_accuracy / nb_eval_steps
        save(model)

While computing validation accuracy, we also save the parameters of the best model seen so far on the validation set; this is what we load for the real test run. The code below handles saving the model.

output_dir = './models/'
os.makedirs(output_dir, exist_ok=True)  # create the save directory up front so torch.save() does not fail
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)

def save(model):
    # save
    torch.save(model.state_dict(), output_model_file)
    model.config.to_json_file(output_config_file)
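
In current transformers versions the same thing can be done in one call, which writes both the weights and the config for you (an equivalent alternative, not what this post uses):

model.save_pretrained(output_dir)  # writes pytorch_model.bin and config.json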

Loading the model back takes a single line, since from_pretrained picks up the config.json and weights saved above.

# load model
model = XLNetForSequenceClassification.from_pretrained(output_dir).to(device)

Test

def pred():
    # load model
    model = XLNetForSequenceClassification.from_pretrained(output_dir).to(device)

    sentence_list = []
    with open('test.csv') as f:
        rows = [row for row in csv.reader(f)]
        rows = np.array(rows[1:])
        sentence_list = [text for idx, text in rows]

    input_ids, token_type_ids, attention_mask, _ = convert(['test'], sentence_list, [1]) # placeholder name_list / label_list, just to reuse convert()
    dataset = Data.TensorDataset(input_ids, token_type_ids, attention_mask)
    loader = Data.DataLoader(dataset, 32, False)

    pred_label = []
    model.eval()
    for i, batch in enumerate(loader):
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            logits = model(batch[0], token_type_ids=batch[1], attention_mask=batch[2])[0]
            logits = logits.detach().cpu().numpy()
            preds = np.argmax(logits, axis=1).flatten()
            pred_label.extend(preds)
    
    for i in range(len(pred_label)):
        pred_label[i] = classes_list[pred_label[i]]

    pd.DataFrame(data=pred_label, index=range(len(pred_label))).to_csv('pred.csv')

name_list and the true labels are not actually used here, but to reuse the earlier convert() function with its original signature, I pass two throwaway arguments, ['test'] and [1].
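
If the dummy arguments bother you, a thin wrapper keeps the call site clean (a refactor suggestion of mine, not the post's code):

def convert_for_test(sentences):
    # reuse convert() from above and discard the placeholder labels
    input_ids, token_type_ids, attention_mask, _ = convert(['unused'], sentences, [0])
    return input_ids, token_type_ids, attention_mask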

Code

import os
import csv
import torch
import transformers
import numpy as np
import pandas as pd
from transformers import *
import torch.utils.data as Data
from sklearn.model_selection import train_test_split

transformers.logging.set_verbosity_error()
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

best_score = 0
batch_size = 32
classes_list = []

output_dir = './models/'
os.makedirs(output_dir, exist_ok=True)  # create the save directory up front so torch.save() does not fail
output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(output_dir, CONFIG_NAME)

def process_data(filename):
    with open(filename) as f:
        rows = [row for row in csv.reader(f)]
        rows = np.array(rows[1:]) # all data, 2D
        label_list = [label for _, label in rows] # label list
        global classes_list
        classes_list = list(set(label_list)) # non-repeated label list
        num_classes = len(classes_list) # num of classes
        for i in range(len(label_list)):
            label_list[i] = classes_list.index(label_list[i]) # index of label

        name_list, sentence_list = [], []
        for sentence, _ in rows:
            begin = sentence.find('<e1>')
            end = sentence.find('</e1>')
            e1 = sentence[begin:end + 5]  # '+ 5' keeps the closing '</e1>' tag

            begin = sentence.find('<e2>')
            end = sentence.find('</e2>')
            e2 = sentence[begin:end + 5]  # '+ 5' keeps the closing '</e2>' tag
            
            name_list.append(e1 + " " + e2)
            sentence_list.append(sentence)
    print(num_classes)
    return name_list, sentence_list, label_list, classes_list, num_classes

def convert(names, sentences, target): # name_list, sentence_list, label_list
    input_ids, token_type_ids, attention_mask = [], [], []
    for i in range(len(sentences)):
        encoded_dict = tokenizer.encode_plus(
            sentences[i],                  # input text
            add_special_tokens = True,     # append '<sep>' and '<cls>' (XLNet places them at the end)
            max_length = 96,               # pad shorter sequences to 96...
            padding = 'max_length',
            truncation = True,             # ...and truncate longer ones
            return_tensors = 'pt',         # return PyTorch tensors
        )
        input_ids.append(encoded_dict['input_ids'])
        token_type_ids.append(encoded_dict['token_type_ids'])
        attention_mask.append(encoded_dict['attention_mask'])

    # each element is a [1, 96] tensor; concatenate into [num_samples, 96]
    input_ids = torch.cat(input_ids, dim=0)
    token_type_ids = torch.cat(token_type_ids, dim=0)
    attention_mask = torch.cat(attention_mask, dim=0)
    target = torch.LongTensor(target)

    return input_ids, token_type_ids, attention_mask, target

def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten() # [3, 5, 8, 1, 2, ....]
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

def save(model):
    # save
    torch.save(model.state_dict(), output_model_file)
    model.config.to_json_file(output_config_file)


def eval(model, validation_dataloader):
    model.eval()
    eval_accuracy, nb_eval_steps = 0, 0  # running accuracy and batch count
    for batch in validation_dataloader:
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            logits = model(batch[0], token_type_ids=batch[1], attention_mask=batch[2])[0]
            logits = logits.detach().cpu().numpy()
            label_ids = batch[3].cpu().numpy()
            tmp_eval_accuracy = flat_accuracy(logits, label_ids)
            eval_accuracy += tmp_eval_accuracy
            nb_eval_steps += 1
    print("Validation Accuracy: {}".format(eval_accuracy / nb_eval_steps))
    global best_score
    if best_score < eval_accuracy / nb_eval_steps:
        best_score = eval_accuracy / nb_eval_steps
        save(model)

def train_eval():
    name_list, sentence_list, label_list, _, num_classes = process_data('train.csv')
    input_ids, token_type_ids, attention_mask, labels = convert(name_list, sentence_list, label_list)

    train_inputs, val_inputs, train_labels, val_labels = train_test_split(input_ids, labels, random_state=1, test_size=0.1)
    train_token, val_token, _, _ = train_test_split(token_type_ids, labels, random_state=1, test_size=0.1)
    train_mask, val_mask, _, _ = train_test_split(attention_mask, labels, random_state=1, test_size=0.1)
    
    train_data = Data.TensorDataset(train_inputs, train_token, train_mask, train_labels)
    train_dataloader = Data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

    validation_data = Data.TensorDataset(val_inputs, val_token, val_mask, val_labels)
    validation_dataloader = Data.DataLoader(validation_data, batch_size=batch_size, shuffle=True)

    model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased', num_labels=num_classes).to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']  # note: XLNet's LayerNorm weights are named 'layer_norm.*'; 'gamma'/'beta' are legacy BERT names
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01},  # AdamW reads the key 'weight_decay'; 'weight_decay_rate' would be silently ignored
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0}]

    optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)

    for _ in range(2):
        for i, batch in enumerate(train_dataloader):
            batch = tuple(t.to(device) for t in batch)
            loss = model(batch[0], token_type_ids=batch[1], attention_mask=batch[2], labels=batch[3])[0]
            print(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if i % 10 == 0:
                eval(model, validation_dataloader)

def pred():
    # load model
    model = XLNetForSequenceClassification.from_pretrained(output_dir).to(device)

    sentence_list = []
    with open('test.csv') as f:
        rows = [row for row in csv.reader(f)]
        rows = np.array(rows[1:])
        sentence_list = [text for idx, text in rows]

    input_ids, token_type_ids, attention_mask, _ = convert(['test'], sentence_list, [1]) # placeholder name_list / label_list, just to reuse convert()
    dataset = Data.TensorDataset(input_ids, token_type_ids, attention_mask)
    loader = Data.DataLoader(dataset, 32, False)

    pred_label = []
    model.eval()
    for i, batch in enumerate(loader):
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            logits = model(batch[0], token_type_ids=batch[1], attention_mask=batch[2])[0]
            logits = logits.detach().cpu().numpy()
            preds = np.argmax(logits, axis=1).flatten()
            pred_label.extend(preds)
    
    for i in range(len(pred_label)):
        pred_label[i] = classes_list[pred_label[i]]

    pd.DataFrame(data=pred_label, index=range(len(pred_label))).to_csv('pred.csv')

if __name__ == '__main__':
    train_eval()
    pred()