def build_batch(self):
    """Create the batch builders and the batches for train, dev and test.

    For each split a ``Build_Batch`` is stored on ``self.<split>_build_batch``.
    Afterwards the batches are generated according to
    ``self.opts.<split>_batch_type`` (``'normal'`` or ``'same'``) and stored on
    ``self.<split>_batch_features`` / ``self.<split>_data_batchs``.

    Raises:
        RuntimeError: if a split's batch type is neither ``'normal'`` nor
            ``'same'``.
    """
    pad_idx = self.vocab.from_string(padding_key)
    char_pad_idx = self.char_vocab.from_string(padding_key)
    # The relation vocabulary is optional; without it no relation pad id exists.
    rel_pad_idx = None
    if self.rels_vocab is not None:
        rel_pad_idx = self.rels_vocab.from_string(padding_key)

    splits = ('train', 'dev', 'test')

    # First pass: construct every builder before any batch is generated.
    for split in splits:
        builder = Build_Batch(
            features=getattr(self, split + '_features_list'),
            batch_size=getattr(self.opts, split + '_batch_size'),
            opts=self.opts,
            pad_idx=pad_idx,
            char_padding_id=char_pad_idx,
            rel_padding_id=rel_pad_idx)
        setattr(self, split + '_build_batch', builder)

    # Second pass: generate the batches with the strategy chosen per split.
    for split in splits:
        builder = getattr(self, split + '_build_batch')
        batch_type = getattr(self.opts, split + '_batch_type')
        if batch_type == 'normal':
            make_batches = builder.create_sorted_normal_batch
        elif batch_type == 'same':
            make_batches = builder.create_same_sents_length_one_batch
        else:
            raise RuntimeError('not normal or same')
        batch_features, data_batchs = make_batches()
        setattr(self, split + '_batch_features', batch_features)
        setattr(self, split + '_data_batchs', data_batchs)
def decoder(self):
    """Run the model over ``self.features_list`` and dump the misclassified items.

    Batches are built with ``Build_Batch.create_sorted_normal_batch``. Each
    batch is fed to ``self.model`` in eval mode, the arg-max prediction is
    compared with the gold label, and every mismatch is collected as a
    ``(sentence, length, predicted_label, gold_label)`` tuple. The collected
    list is finally handed to ``self.write``.

    Raises:
        RuntimeError: if ``self.opts.model`` contains ``'tree'`` but is not one
            of the tree models this method knows how to call.
    """
    padding_id = self.vocab.from_string(padding_key)
    char_padding_id = self.char_vocab.from_string(padding_key)
    rel_padding_id = None
    if self.rel_vocab is not None:
        rel_padding_id = self.rel_vocab.from_string(padding_key)

    self.build_batch = Build_Batch(features=self.features_list,
                                   batch_size=self.batch_size,
                                   opts=self.opts,
                                   pad_idx=padding_id,
                                   char_padding_id=char_padding_id,
                                   rel_padding_id=rel_padding_id)
    self.batch_features, self.data_batchs = \
        self.build_batch.create_sorted_normal_batch()

    model_name = self.opts.model
    plain_tree_models = ['treelstm', 'bitreelstm']
    rel_tree_models = ['lstm_treelstm_rel', 'treelstm_rel', 'bitreelstm_rel']

    wrongs = []
    # Eval mode only needs to be set once, not once per batch.
    self.model.eval()
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for batch in self.data_batchs:
            if 'tree' in model_name:
                sents = Variable(torch.LongTensor(batch[0]),
                                 requires_grad=False)
                label = Variable(torch.LongTensor(batch[1]),
                                 requires_grad=False)
                heads = batch[4]
                xlength = batch[6]
                tag_rels = Variable(torch.LongTensor(batch[7]),
                                    requires_grad=False)
                if self.opts.use_cuda:
                    sents = sents.cuda()
                    label = label.cuda()
                    tag_rels = tag_rels.cuda()
                if model_name in plain_tree_models:
                    pred = self.model(sents, heads, xlength)
                elif model_name in rel_tree_models:
                    pred = self.model(sents, tag_rels, heads, xlength)
                else:
                    # Previously this fell through and failed later with an
                    # unbound ``pred``; fail with a clear message instead.
                    raise RuntimeError('unsupported tree model: ' +
                                       str(model_name))
            else:
                sents = Variable(torch.LongTensor(batch[0]))
                label = Variable(torch.LongTensor(batch[1]))
                char_data = []
                if 'Char' in model_name:
                    char_data = [
                        Variable(torch.LongTensor(char_list))
                        for char_list in batch[2]
                    ]
                if self.opts.use_cuda:
                    sents = sents.cuda()
                    label = label.cuda()
                    char_data = [data.cuda() for data in char_data]
                if 'Char' in model_name:
                    pred = self.model(sents, char_data)
                else:
                    pred = self.model(sents)

            # Arg-max over dimension 1, reshaped to the label tensor's shape.
            pred_index = torch.max(pred, 1)[1].view(label.size()).tolist()

            # Compare against the raw (un-tensorised) batch contents.
            sent_ids = batch[0]
            gold_labels = batch[1]
            for index, (gold, predicted) in enumerate(
                    zip(gold_labels, pred_index)):
                if gold != predicted:
                    wrong_sent, length = self.get_sent(sent_ids[index])
                    right_label = self.get_label(gold)
                    wrong_label = self.get_label(predicted)
                    wrongs.append(
                        (wrong_sent, length, wrong_label, right_label))
    self.write(wrongs)
class Decoder:
    """Load a saved model and data, then write the misclassified items to a file.

    Constructing an instance does all the work: it loads the model from
    ``opts.model_path``, loads features and vocabularies from ``opts.dir``,
    runs the model and writes the errors to ``opts.save_path``.
    """

    def __init__(self, opts):
        """Load the model and data described by ``opts`` and run ``decoder``.

        Args:
            opts: options object; this class reads ``model_path``,
                ``batch_size``, ``save_path``, ``dir``, ``type``, ``model``
                and ``use_cuda`` from it.
        """
        self.opts = opts
        # NOTE(security): torch.load unpickles arbitrary objects; only point
        # ``model_path`` / ``dir`` at files from a trusted source.
        self.model = torch.load(self.opts.model_path)
        self.features_list = None
        self.vocab = None
        self.char_vocab = None
        self.label_vocab = None
        self.rel_vocab = None
        self.batch_size = self.opts.batch_size
        self.save_path = self.opts.save_path
        self.load_data(self.opts.dir, self.opts.type)
        self.decoder()

    def load_data(self, data_dir, type):
        """Load the feature list and the vocabularies stored under ``data_dir``.

        Args:
            data_dir: directory holding the ``*.sst`` files.
            type: base name of the feature file, i.e. ``<type>.sst``.
        """
        # NOTE(security): see __init__ -- torch.load is pickle-based.
        self.features_list = torch.load(data_dir + '/' + type + '.sst')
        self.vocab = torch.load(data_dir + '/vocab.sst')
        self.char_vocab = torch.load(data_dir + '/char_vocab.sst')
        self.label_vocab = torch.load(data_dir + '/label_vocab.sst')
        self.rel_vocab = torch.load(data_dir + '/rel_vocab.sst')

    def decoder(self):
        """Run the model over the loaded features and dump the wrong predictions.

        Batches are built with ``Build_Batch.create_sorted_normal_batch``.
        Each batch is fed to ``self.model`` in eval mode, the arg-max
        prediction is compared with the gold label, and every mismatch is
        collected as ``(sentence, length, predicted_label, gold_label)``.
        The collected list is finally handed to ``self.write``.

        Raises:
            RuntimeError: if ``self.opts.model`` contains ``'tree'`` but is not
                one of the tree models this method knows how to call.
        """
        padding_id = self.vocab.from_string(padding_key)
        char_padding_id = self.char_vocab.from_string(padding_key)
        rel_padding_id = None
        if self.rel_vocab is not None:
            rel_padding_id = self.rel_vocab.from_string(padding_key)

        self.build_batch = Build_Batch(features=self.features_list,
                                       batch_size=self.batch_size,
                                       opts=self.opts,
                                       pad_idx=padding_id,
                                       char_padding_id=char_padding_id,
                                       rel_padding_id=rel_padding_id)
        self.batch_features, self.data_batchs = \
            self.build_batch.create_sorted_normal_batch()

        model_name = self.opts.model
        plain_tree_models = ['treelstm', 'bitreelstm']
        rel_tree_models = [
            'lstm_treelstm_rel', 'treelstm_rel', 'bitreelstm_rel'
        ]

        wrongs = []
        # Eval mode only needs to be set once, not once per batch.
        self.model.eval()
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for batch in self.data_batchs:
                if 'tree' in model_name:
                    sents = Variable(torch.LongTensor(batch[0]),
                                     requires_grad=False)
                    label = Variable(torch.LongTensor(batch[1]),
                                     requires_grad=False)
                    heads = batch[4]
                    xlength = batch[6]
                    tag_rels = Variable(torch.LongTensor(batch[7]),
                                        requires_grad=False)
                    if self.opts.use_cuda:
                        sents = sents.cuda()
                        label = label.cuda()
                        tag_rels = tag_rels.cuda()
                    if model_name in plain_tree_models:
                        pred = self.model(sents, heads, xlength)
                    elif model_name in rel_tree_models:
                        pred = self.model(sents, tag_rels, heads, xlength)
                    else:
                        # Previously this fell through and failed later with
                        # an unbound ``pred``; fail with a clear message.
                        raise RuntimeError('unsupported tree model: ' +
                                           str(model_name))
                else:
                    sents = Variable(torch.LongTensor(batch[0]))
                    label = Variable(torch.LongTensor(batch[1]))
                    char_data = []
                    if 'Char' in model_name:
                        char_data = [
                            Variable(torch.LongTensor(char_list))
                            for char_list in batch[2]
                        ]
                    if self.opts.use_cuda:
                        sents = sents.cuda()
                        label = label.cuda()
                        char_data = [data.cuda() for data in char_data]
                    if 'Char' in model_name:
                        pred = self.model(sents, char_data)
                    else:
                        pred = self.model(sents)

                # Arg-max over dimension 1, reshaped like the label tensor.
                pred_index = torch.max(pred, 1)[1].view(label.size()).tolist()

                # Compare against the raw (un-tensorised) batch contents.
                sent_ids = batch[0]
                gold_labels = batch[1]
                for index, (gold, predicted) in enumerate(
                        zip(gold_labels, pred_index)):
                    if gold != predicted:
                        wrong_sent, length = self.get_sent(sent_ids[index])
                        right_label = self.get_label(gold)
                        wrong_label = self.get_label(predicted)
                        wrongs.append(
                            (wrong_sent, length, wrong_label, right_label))
        self.write(wrongs)

    def write(self, wrongs):
        """Write one line per misclassified item to ``self.save_path``.

        Args:
            wrongs: iterable of ``(sentence, length, predicted_label,
                gold_label)`` tuples.
        """
        with open(self.save_path, 'w', encoding='utf8') as f:
            for sent, length, pred_label, gold_label in wrongs:
                f.write('pred: ' + str(pred_label) + ' right: ' +
                        str(gold_label) + ' length: ' + str(length) +
                        ' sent: ' + sent + '\n')

    def get_sent(self, idx):
        """Map word ids back to words, dropping padding.

        Args:
            idx: iterable of word ids.

        Returns:
            ``(sentence, length)`` where ``sentence`` is the space-joined
            non-padding words and ``length`` is how many there are.
        """
        words = []
        for word_id in idx:
            word = self.vocab.id2string[word_id]
            if word != padding_key:
                words.append(word)
        return ' '.join(words), len(words)

    def get_label(self, id):
        """Return the label string stored for label id ``id``."""
        return self.label_vocab.id2string[id]
class Trainer:
    """Train a sentence classifier and track the best dev/test accuracy.

    The constructor seeds the RNGs, selects the CUDA device, builds the
    train/dev/test batches, then creates the model and the optimizer.
    ``train`` runs the epoch loop with early stopping and learning-rate decay.
    """

    def __init__(self, train_dev_test, opts, vocab, label_vocab):
        """Set up batches, model and optimizer.

        Args:
            train_dev_test: indexable holding the train, dev and test feature
                lists at positions 0, 1 and 2.
            opts: options object (batch sizes, model/optimizer choice, seed,
                CUDA settings, logging and checkpoint settings, ...).
            vocab: indexable holding the word vocabulary at position 0 and the
                character vocabulary at position 1.
            label_vocab: label vocabulary passed on to the model.
        """
        self.train_features_list = train_dev_test[0]
        self.dev_features_list = train_dev_test[1]
        self.test_features_list = train_dev_test[2]
        self.opts = opts
        self.vocab = vocab[0]
        self.char_vocab = vocab[1]
        self.label_vocab = label_vocab
        self.epoch = opts.epoch
        self.shuffle = opts.shuffle
        self.model = None
        self.best_dev = 0
        self.best_dev_test = 0
        self.best_dev_epoch = 0
        self.char = False
        self.lr = self.opts.lr
        # Becomes True once the epoch reaches opts.save_model_start_from.
        self.save_model_switch = False
        self.print_log = Log(opts)

        # Seed the RNGs and select the GPU *before* anything is built:
        # otherwise the parameter initialisation is unseeded and model.cuda()
        # places the model on the default device instead of opts.gpu_device.
        random.seed(opts.seed)
        torch.manual_seed(opts.seed)
        if self.opts.use_cuda:
            torch.backends.cudnn.deterministic = True
            torch.cuda.set_device(self.opts.gpu_device)
            torch.cuda.manual_seed(self.opts.seed)
            self._log('use CUDA!')

        self.build_batch()
        self.init_model()
        self.init_optim()

    def _log(self, message):
        """Write ``message`` to the log object and echo it on stdout."""
        self.print_log.print_log(message)
        print(message)

    def _timestamp(self, template):
        """Fill ``template`` with year, month, day, hour, minute, second."""
        now = self.get_time()
        return template.format(now['year'], now['month'], now['day'],
                               now['hour'], now['min'], now['sec'])

    def build_batch(self):
        """Create the batch builders and the batches for train, dev and test.

        Raises:
            RuntimeError: if a split's ``opts.<split>_batch_type`` is neither
                ``'normal'`` nor ``'same'``.
        """
        padding_id = self.vocab.from_string(padding_key)
        char_padding_id = self.char_vocab.from_string(padding_key)
        splits = ('train', 'dev', 'test')

        # Build all three builders first, then generate the batches.
        for split in splits:
            builder = Build_Batch(
                features=getattr(self, split + '_features_list'),
                batch_size=getattr(self.opts, split + '_batch_size'),
                opts=self.opts,
                pad_idx=padding_id,
                char_padding_id=char_padding_id)
            setattr(self, split + '_build_batch', builder)

        for split in splits:
            builder = getattr(self, split + '_build_batch')
            batch_type = getattr(self.opts, split + '_batch_type')
            if batch_type == 'normal':
                make_batches = builder.create_sorted_normal_batch
            elif batch_type == 'same':
                make_batches = builder.create_same_sents_length_one_batch
            else:
                raise RuntimeError('not normal or same')
            batch_features, data_batchs = make_batches()
            setattr(self, split + '_batch_features', batch_features)
            setattr(self, split + '_data_batchs', data_batchs)

    def init_model(self):
        """Instantiate the model named by ``opts.model``.

        Supported names: pooling, cnn, multi_channel_cnn, multi_layer_cnn,
        char_cnn, lstm, gru, lstm_cnn. The model is moved to the GPU when
        ``opts.use_cuda`` is set.

        Raises:
            RuntimeError: if ``opts.model`` is not one of the supported names.
        """
        name = self.opts.model
        if name == 'pooling':
            self.model = Pooling(opts=self.opts,
                                 vocab=self.vocab,
                                 label_vocab=self.label_vocab)
        elif name == 'cnn':
            self.model = CNN(opts=self.opts,
                             vocab=self.vocab,
                             label_vocab=self.label_vocab)
        elif name == 'multi_channel_cnn':
            self.model = Multi_Channel_CNN(opts=self.opts,
                                           vocab=self.vocab,
                                           label_vocab=self.label_vocab)
        elif name == 'multi_layer_cnn':
            self.model = Multi_Layer_CNN(opts=self.opts,
                                         vocab=self.vocab,
                                         label_vocab=self.label_vocab)
        elif name == 'char_cnn':
            # Only this model consumes character-level input.
            self.char = True
            self.model = Char_CNN(opts=self.opts,
                                  vocab=self.vocab,
                                  char_vocab=self.char_vocab,
                                  label_vocab=self.label_vocab)
        elif name == 'lstm':
            self.model = LSTM(opts=self.opts,
                              vocab=self.vocab,
                              label_vocab=self.label_vocab)
        elif name == 'gru':
            self.model = GRU(opts=self.opts,
                             vocab=self.vocab,
                             label_vocab=self.label_vocab)
        elif name == 'lstm_cnn':
            self.model = LSTM_CNN(opts=self.opts,
                                  vocab=self.vocab,
                                  label_vocab=self.label_vocab)
        else:
            raise RuntimeError('please choose your model first!')

        if self.opts.use_cuda:
            self.model = self.model.cuda()

    def init_optim(self):
        """Create the optimizer named by ``opts.optim`` (``'sgd'`` or ``'adam'``).

        Raises:
            RuntimeError: if ``opts.optim`` is neither ``'sgd'`` nor ``'adam'``
                (previously this left ``self.optimizer`` undefined).
        """
        if self.opts.optim == 'sgd':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.opts.lr,
                                       weight_decay=self.opts.weight_decay,
                                       momentum=self.opts.momentum)
        elif self.opts.optim == 'adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.opts.lr,
                                        weight_decay=self.opts.weight_decay)
        else:
            raise RuntimeError('unknown optimizer: ' + str(self.opts.optim))

    def _batch_to_tensors(self, batch):
        """Turn one batch into (sents, label, char_data) on the right device.

        ``batch[0]`` and ``batch[1]`` become LongTensors; ``batch[2]`` is only
        read for character-level models. ``char_data`` is an empty list
        otherwise.
        """
        sents = Variable(torch.LongTensor(batch[0]))
        label = Variable(torch.LongTensor(batch[1]))
        char_data = []
        if self.char:
            char_data = [
                Variable(torch.LongTensor(char_list))
                for char_list in batch[2]
            ]
        if self.opts.use_cuda:
            sents = sents.cuda()
            label = label.cuda()
            char_data = [data.cuda() for data in char_data]
        return sents, label, char_data

    def _forward(self, sents, char_data):
        """Call the model, passing character data only when it is used."""
        if self.char:
            return self.model(sents, char_data)
        return self.model(sents)

    @staticmethod
    def _count_correct(pred, label):
        """Return how many arg-max predictions in ``pred`` equal ``label``."""
        predicted = torch.max(pred, 1)[1].view(label.size())
        return int((predicted.data == label.data).sum().item())

    def train(self):
        """Run the training loop.

        Every ``opts.print_every`` steps the running accuracy and per-instance
        loss are logged and the running counters are reset. After each epoch
        the dev and test sets are scored. A better dev score updates the best
        results and triggers ``save_model``. Otherwise the no-improvement
        counters advance, which can trigger early stopping
        (``opts.early_stop``) or a learning-rate decay
        (``opts.lr_decay_every``).
        """
        early_stop_count = 1
        lr_decay_count = 1
        for epoch in range(self.epoch):
            loss_sum = 0.0
            correct_num = 0
            inst_num = 0
            step = 0
            if self.shuffle:
                random.shuffle(self.train_data_batchs)
                self._log('data has shuffled!')

            # accurcy() switches to eval mode at the end of every epoch.
            self.model.train()
            for batch in self.train_data_batchs:
                self.optimizer.zero_grad()
                batch_size = len(batch[1])
                inst_num += batch_size

                sents, label, char_data = self._batch_to_tensors(batch)
                pred = self._forward(sents, char_data)
                loss = F.cross_entropy(pred, label)
                loss.backward()
                if self.opts.init_clip_max_norm is not None:
                    utils.clip_grad_norm_(
                        self.model.parameters(),
                        max_norm=self.opts.init_clip_max_norm)
                self.optimizer.step()

                # cross_entropy returns the batch *mean*; weight it by the
                # batch size so that loss_sum / inst_num is a true
                # per-instance average.
                loss_sum += loss.item() * batch_size
                step += 1
                correct_num += self._count_correct(pred, label)

                if step % self.opts.print_every == 0:
                    avg_loss = loss_sum / inst_num
                    acc = float(correct_num) / inst_num * 100
                    time_str = self._timestamp(
                        "[{}-{:0>2d}-{:0>2d} {:0>2d}:{:0>2d}:{:0>2d}]")
                    log = time_str + " Epoch {} step {} lr={:.8f} acc: {:.2f}% loss: {:.6f}".format(
                        epoch, step, self.lr, acc, avg_loss)
                    self._log(log)
                    loss_sum = 0.0
                    inst_num = 0
                    correct_num = 0

            dev_score = self.accurcy(type='dev')
            test_score = self.accurcy(type='test')
            if dev_score > self.best_dev:
                early_stop_count = 0
                lr_decay_count = 0
                self.best_dev = dev_score
                self.best_dev_epoch = epoch
                self.best_dev_test = test_score
                log = "Update! best test acc: {:.2f}%".format(
                    self.best_dev_test)
                # Written to the log file as well, like every other message.
                self._log(log)
                self.save_model(epoch)
            else:
                early_stop_count += 1
                lr_decay_count += 1
                log = "not improved, best test acc: {:.2f}%, in epoch {}".format(
                    self.best_dev_test, self.best_dev_epoch)
                self._log(log)

            if early_stop_count == self.opts.early_stop:
                log = "{} epoch have not improved, so early stop the train!".format(
                    early_stop_count)
                self._log(log)
                return

            if lr_decay_count == self.opts.lr_decay_every:
                lr_decay_count = 0
                self.adjust_learning_rate(self.optimizer,
                                          self.opts.lr_decay_rate)
                log = "{} epoch have not improved, so adjust lr to {}".format(
                    early_stop_count, self.lr)
                self._log(log)

    def save_model(self, cur_epoch):
        """Save the whole model once ``cur_epoch`` reached the start epoch.

        The file goes to ``opts.save_model_dir`` (created if missing, including
        parent directories) and is named from the current timestamp, the model
        name and the epoch.

        Args:
            cur_epoch: epoch index used for the start check and the file name.
        """
        os.makedirs(self.opts.save_model_dir, exist_ok=True)
        if self.opts.save_model_start_from <= cur_epoch:
            self.save_model_switch = True
        if self.save_model_switch:
            time_str = self._timestamp(
                "{}-{:0>2d}-{:0>2d}-{:0>2d}-{:0>2d}-{:0>2d}-")
            fname = self.opts.save_model_dir + '/' + time_str + self.opts.model + '-model_epoch_' + str(
                cur_epoch) + '.pt'
            torch.save(self.model, fname)
            self._log('model saved succeed in ' + fname)

    def accurcy(self, type):
        """Score the model on the dev or test batches and log the result.

        Args:
            type: ``'dev'`` or ``'test'``.

        Returns:
            Accuracy in percent.

        Raises:
            RuntimeError: if ``type`` is neither ``'dev'`` nor ``'test'``.
        """
        if type == 'dev':
            data_batchs = self.dev_data_batchs
        elif type == 'test':
            data_batchs = self.test_data_batchs
        else:
            raise RuntimeError('type wrong!')

        loss_sum = 0.0
        correct_num = 0
        inst_num = 0
        self.model.eval()
        # Evaluation only: skip building the autograd graph.
        with torch.no_grad():
            for batch in data_batchs:
                batch_size = len(batch[1])
                inst_num += batch_size
                sents, label, char_data = self._batch_to_tensors(batch)
                pred = self._forward(sents, char_data)
                loss = F.cross_entropy(pred, label)
                # Weight the batch-mean loss by the batch size (see train()).
                loss_sum += loss.item() * batch_size
                correct_num += self._count_correct(pred, label)

        avg_loss = loss_sum / inst_num
        acc = float(correct_num) / inst_num * 100
        log = type + " acc: {:.2f}% loss: {:.6f}".format(acc, avg_loss)
        self._log(log)
        return acc

    def get_time(self):
        """Return the current local time as a dict.

        Keys: ``year``, ``month``, ``day``, ``hour``, ``min``, ``sec``.
        """
        cur_time = time.localtime(time.time())
        return {
            'year': cur_time.tm_year,
            'month': cur_time.tm_mon,
            'day': cur_time.tm_mday,
            'hour': cur_time.tm_hour,
            'min': cur_time.tm_min,
            'sec': cur_time.tm_sec,
        }

    def adjust_learning_rate(self, optim, lr_decay_rate):
        """Multiply every param group's lr by ``1 - lr_decay_rate``.

        ``self.lr`` ends up holding the last param group's new learning rate.

        Args:
            optim: optimizer whose ``param_groups`` are updated in place.
            lr_decay_rate: fraction by which each learning rate is reduced.
        """
        for param_group in optim.param_groups:
            param_group['lr'] = param_group['lr'] * (1 - lr_decay_rate)
            self.lr = param_group['lr']