Source code for EduNLP.Pretrain.elmo_vec

import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data as tud
import torch.optim as optim
import numpy as np
import json
import os
import time
from EduNLP.SIF import Symbol, FORMULA_SYMBOL, FIGURE_SYMBOL, QUES_MARK_SYMBOL, TAG_SYMBOL, SEP_SYMBOL
from EduNLP.Tokenizer import PureTextTokenizer
from EduNLP.ModelZoo.rnn import ElmoLM
from EduNLP.ModelZoo import set_device

# Placeholder token used by the tokenizer for any token missing from the vocabulary.
UNK_SYMBOL = '[UNK]'
# Filler token appended to index sequences when they are padded to a fixed length.
PAD_SYMBOL = '[PAD]'


class ElmoTokenizer(object):
    """
    Tokenizer that splits items with ``PureTextTokenizer`` and maps the
    resulting tokens to integer ids through a growable vocabulary (``t2id``).

    The vocabulary always starts with the special symbols
    (pad, unknown, formula, figure, question mark, tag, separator) at ids 0-6.

    Examples
    --------
    >>> t=ElmoTokenizer()
    >>> items = ["有公式$\\FormFigureID{wrong1?}$,如图$\\FigureID{088f15ea-xxx}$,\\
    ... 若$x,y$满足约束条件公式$\\FormFigureBase64{wrong2?}$,$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$"]
    >>> t.tokenize(items[0])
    ['公式', '如图', '[FIGURE]', 'x', ',', 'y', '约束条件', '公式', '[SEP]', 'z', '=', 'x', '+', '7', 'y', '最大值', '[MARK]']
    >>> len(t)
    18
    """

    def __init__(self, path: str = None):
        """
        Parameters
        ----------
        path: str, optional
            the path of saved ElmoTokenizer, e.g. "../elmo_pub_math/vocab.json"
        """
        self.pure_tokenizer = PureTextTokenizer()
        self.t2id = {PAD_SYMBOL: 0, UNK_SYMBOL: 1, FORMULA_SYMBOL: 2, FIGURE_SYMBOL: 3,
                     QUES_MARK_SYMBOL: 4, TAG_SYMBOL: 5, SEP_SYMBOL: 6}
        if path is not None:
            self.load_vocab(path)

    def __call__(self, item: (str, list), freeze_vocab=False, pad_to_max_length=False, *args, **kwargs):
        """
        Tokenize ``item`` and convert the tokens to vocabulary ids.

        Parameters
        ----------
        item: str or list
            A single raw item, or a list of raw items.
        freeze_vocab: bool, optional
            When False (default), unseen tokens are added to the vocabulary.
        pad_to_max_length: bool, optional
            Forwarded to :meth:`to_index` (which uses its default ``max_length``).

        Returns
        -------
        (ids, lengths)
            For a ``str`` input: a list of ids and an int.
            For a ``list`` input: a list of id lists and a list of ints.
            ``lengths`` are the token counts produced by :meth:`tokenize`,
            i.e. measured before any padding or truncation.
        """
        tokens, lengths = self.tokenize(item=item, freeze_vocab=freeze_vocab, return_length=True)
        if isinstance(item, str):
            return self.to_index(item=tokens, pad_to_max_length=pad_to_max_length), lengths
        indices = [self.to_index(item=ts, pad_to_max_length=pad_to_max_length) for ts in tokens]
        return indices, lengths

    def __len__(self):
        """Return the current vocabulary size."""
        return len(self.t2id)

    def tokenize(self, item: (str, list), freeze_vocab=False, return_length=False):
        """
        Split one item or a list of items into tokens.

        Parameters
        ----------
        item: str or list
            A single raw item, or a list of raw items.
        freeze_vocab: bool, optional
            When False (default), every produced token is registered in the vocabulary.
        return_length: bool, optional
            When True, also return the token count(s).

        Returns
        -------
        tokens, or (tokens, lengths) when ``return_length`` is True.
        A ``str`` input yields a flat token list / int length; a ``list`` input
        yields one entry per item.
        """
        single = isinstance(item, str)
        items = [item] if single else item
        tokens = []
        lengths = []
        for item_tokens in self.pure_tokenizer(items):
            tokens.append(item_tokens)
            lengths.append(len(item_tokens))
            if not freeze_vocab:
                for token in item_tokens:
                    self.append(token)
        if single:
            tokens = tokens[0]
            lengths = lengths[0]
        if return_length:
            return tokens, lengths
        return tokens

    def to_index(self, item: list, max_length=128, pad_to_max_length=False):
        """
        Convert a token list to a list of vocabulary ids.

        Parameters
        ----------
        item: list
            Tokens to convert; tokens missing from the vocabulary map to the unknown id.
        max_length: int, optional, default=128
            Target length, only used when ``pad_to_max_length`` is True.
        pad_to_max_length: bool, optional
            When True the result always has exactly ``max_length`` entries:
            shorter sequences are right-padded with the pad id and longer
            ones are truncated.

        Returns
        -------
        list of int
        """
        t2id = self.t2id
        ret = [t2id[token] if token in t2id else t2id[UNK_SYMBOL] for token in item]
        if pad_to_max_length:
            if len(ret) < max_length:
                ret = ret + (max_length - len(ret)) * [t2id[PAD_SYMBOL]]
            else:
                # Truncate to exactly max_length so padded and truncated
                # sequences share one length and can be stacked into a batch.
                ret = ret[:max_length]
        return ret

    def append(self, item):
        """Register ``item`` in the vocabulary with the next free id; no-op if already present."""
        if item not in self.t2id:
            self.t2id[item] = len(self.t2id)

    def save_vocab(self, path):
        """Write the vocabulary as JSON to ``path`` and return ``path``."""
        # Explicit encoding keeps the file portable across platforms.
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.t2id, f)
        return path

    def load_vocab(self, path):
        """Replace the vocabulary with the JSON mapping stored at ``path`` and return ``path``."""
        # Vocabularies may contain non-ASCII tokens; do not rely on the
        # platform default encoding when reading them back.
        with open(path, 'r', encoding='utf-8') as f:
            self.t2id = json.load(f)
        return path
class ElmoDataset(tud.Dataset):
    """Dataset of pre-tokenized texts that yields padded id sequences for ELMo training."""

    def __init__(self, texts: list, tokenizer: ElmoTokenizer, max_length=128):
        """
        Parameters
        ----------
        texts: list
            Token lists; any text with ``max_length`` or more tokens is cut
            down to its first ``max_length - 1`` tokens.
        tokenizer: ElmoTokenizer
            Used to convert tokens to ids in ``__getitem__``.
        max_length: int, optional, default=128
            Length every id sequence is padded to.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        clipped_texts = []
        for text in texts:
            if len(text) >= max_length:
                text = text[0:max_length - 1]
            clipped_texts.append(text)
        self.texts = clipped_texts

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, index):
        """Return a dict with the token count (``'length'``) and padded ids (``'idx'``) of one text."""
        tokens = self.texts[index]
        return {
            'length': len(tokens),
            'idx': self.tokenizer.to_index(tokens, pad_to_max_length=True, max_length=self.max_length),
        }
def elmo_collate_fn(batch_data):
    """
    Collate dataset samples into a batch of tensors.

    Parameters
    ----------
    batch_data: list of dict
        Each sample provides ``'length'`` (int) and ``'idx'`` (list of ids).

    Returns
    -------
    dict
        ``'pred_mask'``: boolean mask as wide as the largest ``'length'`` in the batch;
        ``'idx_mask'``: boolean mask as wide as each sample's ``'idx'``;
        ``'length'``: the sample lengths;
        ``'idx'``: the stacked id sequences.
        In both masks the first ``length`` positions of a row are True.
    """
    lengths = [sample['length'] for sample in batch_data]
    widest = max(lengths)
    pred_mask = [[True] * n + [False] * (widest - n) for n in lengths]
    idx_mask = [
        [True] * sample['length'] + [False] * (len(sample['idx']) - sample['length'])
        for sample in batch_data
    ]
    return {
        'pred_mask': torch.tensor(pred_mask),
        'idx_mask': torch.tensor(idx_mask),
        'length': torch.tensor(lengths),
        'idx': torch.tensor([sample['idx'] for sample in batch_data]),
    }
def train_elmo(texts: list, output_dir: str, pretrained_dir: str = None, emb_dim=512, hid_dim=512, batch_size=2,
               epochs=3, lr: float = 5e-4, device=None):
    """
    Train (or continue training) an ELMo language model and save it.

    Parameters
    ----------
    texts: list, required
        The training corpus of shape (text_num, token_num), a text must be tokenized into tokens
    output_dir: str, required
        The directory to save trained model files
    pretrained_dir: str, optional
        The pretrained model files' directory. When given, the vocabulary,
        model dimensions and weights are loaded from it, and ``emb_dim`` /
        ``hid_dim`` passed to this function are ignored.
    emb_dim: int, optional, default=512
        The embedding dim
    hid_dim: int, optional, default=512
        The hidden dim
    batch_size: int, optional, default=2
        The training batch size
    epochs: int, optional, default=3
        The training epochs
    lr: float, optional, default=5e-4
        The learning rate
    device: str, optional
        Default is 'cuda' if available, otherwise 'cpu'

    Returns
    -------
    output_dir: str
        The directory that trained model files are saved
    """
    tokenizer = ElmoTokenizer()
    if pretrained_dir:
        tokenizer.load_vocab(os.path.join(pretrained_dir, 'vocab.json'))
    else:
        for text in texts:
            for token in text:
                tokenizer.append(token)
    vocab_size = len(tokenizer)
    train_dataset = ElmoDataset(texts, tokenizer)

    if pretrained_dir:
        with open(os.path.join(pretrained_dir, 'config.json'), 'r') as f:
            pretrained_config = json.load(f)
        # The checkpoint dictates the model dimensions; keep them so that the
        # config written at the end describes the weights actually saved.
        emb_dim = pretrained_config['emb_dim']
        hid_dim = pretrained_config['hid_dim']
    # The collate function yields (batch, seq) tensors, so both branches build
    # a batch-first model, matching the 'batch_first' flag stored in the config.
    model = ElmoLM(vocab_size=vocab_size, embedding_dim=emb_dim, hidden_size=hid_dim, batch_first=True)
    if pretrained_dir:
        # Load onto CPU first; the model is moved to the target device below.
        state_dict = torch.load(os.path.join(pretrained_dir, 'weight.pt'), map_location='cpu')
        model.load_state_dict(state_dict)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)
    model.LM_layer.rnn.flatten_parameters()
    model.to(device)
    model.train()

    global_step = 0
    adam = optim.Adam(model.parameters(), lr=lr)
    loss_func = nn.BCELoss()
    loss_func.to(device)
    dataloader = tud.DataLoader(train_dataset, collate_fn=elmo_collate_fn, batch_size=batch_size, shuffle=True)

    begin = time.time()
    for epoch in range(epochs):
        for step, sample in enumerate(dataloader):
            pred_mask = sample['pred_mask'].to(device)
            idx_mask = sample['idx_mask'].to(device)
            idx = sample['idx'].to(device)
            length = sample['length'].to(device)
            # Clear gradients up front so a step aborted by the RuntimeError
            # handler below cannot leak partial gradients into the next step.
            adam.zero_grad()
            try:
                y = F.one_hot(idx, num_classes=vocab_size).to(device)
                pred_forward, pred_backward, _, _ = model(idx, length)
                pred_forward = pred_forward[pred_mask]
                pred_backward = pred_backward[torch.flip(pred_mask, [1])]
                y_rev = torch.flip(y, [1])[torch.flip(idx_mask, [1])]
                y = y[idx_mask]
                forward_loss = loss_func(pred_forward[:, :-1].double(), y[:, 1:].double())
                backward_loss = loss_func(pred_backward[:, :-1].double(), y_rev[:, 1:].double())
                # One backward pass over the summed loss accumulates the same
                # gradients as two separate passes, without retaining the graph.
                loss = forward_loss + backward_loss
                loss.backward()
                adam.step()
                global_step += 1
                if global_step % 10 == 0:
                    print("[Global step %d, epoch %d, batch %d] Loss: %.10f" % (
                        global_step, epoch, step, loss.item()))
            except RuntimeError as e:
                # Best effort: report the failing batch and continue training.
                print("RuntimeError:", e)
                print("[DEBUG]Sample idx:", idx)
    end = time.time()
    print("Train time: ", (end - begin))

    model.cpu()
    config = {
        'emb_dim': emb_dim,
        'hid_dim': hid_dim,
        'batch_first': True,
        'vocab_size': vocab_size
    }
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config, f)
    torch.save(model.state_dict(), os.path.join(output_dir, 'weight.pt'))
    tokenizer.save_vocab(os.path.join(output_dir, 'vocab.json'))
    return output_dir