# Imports assumed by the model variants below. Project-local modules
# (DNetwork, DNetwork_Seq2seq, SANDecoder, myNetwork, ContextualEmbed,
# AverageMeter, EMA, eval_bleu, eval_nist, STA_ID, logger) come from the
# surrounding repository and are not shown here.
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR, MultiStepLR


class DocReaderModel(object):
    def create_embed(self, vocab_size, embed_dim, padding_idx=0):
        return nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)

    def create_word_embed(self, embedding=None, opt={}, prefix='wemb'):
        vocab_size = opt.get('vocab_size', 1)
        embed_dim = opt.get('{}_dim'.format(prefix), 300)
        self.embedding = self.create_embed(vocab_size, embed_dim)
        if embedding is not None:
            self.embedding.weight.data = embedding
            if opt['fix_embeddings'] or opt['tune_partial'] == 0:
                opt['fix_embeddings'] = True
                opt['tune_partial'] = 0
                for p in self.embedding.parameters():
                    p.requires_grad = False
            else:
                assert opt['tune_partial'] < embedding.size(0)
                fixed_embedding = embedding[opt['tune_partial']:]
                self.register_buffer('fixed_embedding', fixed_embedding)
                self.fixed_embedding = fixed_embedding
        return embed_dim

    def __init__(self, opt, embedding=None, state_dict=None):
        self.opt = opt
        self.updates = state_dict['updates'] if state_dict and 'updates' in state_dict else 0
        self.eval_embed_transfer = True
        self.train_loss = AverageMeter()
        if state_dict and 'train_loss' in state_dict:
            self.train_loss.load_state_dict(state_dict['train_loss'])
        self.network = DNetwork(opt, embedding)
        self.forward_network = nn.DataParallel(self.network) if opt['multi_gpu'] else self.network
        self.state_dict = state_dict

        # Build the optimizer over trainable parameters only.
        parameters = [p for p in self.network.parameters() if p.requires_grad]
        if opt['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(parameters, opt['learning_rate'],
                                       momentum=opt['momentum'],
                                       weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adamax':
            self.optimizer = optim.Adamax(parameters, opt['learning_rate'],
                                          weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adam':
            self.optimizer = optim.Adam(parameters, opt['learning_rate'],
                                        weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adadelta':
            self.optimizer = optim.Adadelta(parameters, opt['learning_rate'], rho=0.95)
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])

        if opt['fix_embeddings']:
            wvec_size = 0
        else:
            wvec_size = (opt['vocab_size'] - opt['tune_partial']) * opt['embedding_dim']

        if opt.get('have_lr_scheduler', False):
            if opt.get('scheduler_type', 'rop') == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max',
                                                   factor=opt['lr_gamma'], patience=3)
            elif opt.get('scheduler_type', 'rop') == 'exp':
                self.scheduler = ExponentialLR(self.optimizer, gamma=opt.get('lr_gamma', 0.5))
            else:
                milestones = [int(step) for step in
                              opt.get('multi_step_lr', '10,20,30').split(',')]
                self.scheduler = MultiStepLR(self.optimizer, milestones=milestones,
                                             gamma=opt.get('lr_gamma'))
        else:
            self.scheduler = None
        self.total_param = sum([p.nelement() for p in parameters]) - wvec_size

    def update(self, batch, name_map, dataset_name):
        """One training step over a multi-task batch."""
        self.network.train()
        pred = self.forward_network(*batch[:name_map['input_len']], dataset_name=dataset_name)
        if dataset_name == 'wdw':
            if self.opt['cuda']:
                y = Variable(batch[name_map['truth']].cuda(non_blocking=True))
                score = Variable(batch[name_map['score']].cuda(non_blocking=True))
            else:
                y = Variable(batch[name_map['truth']])
                score = Variable(batch[name_map['score']])
            loss = F.nll_loss(pred, y, reduction='none')
        else:
            if self.opt['cuda']:
                y = Variable(batch[name_map['start']].cuda(non_blocking=True)), \
                    Variable(batch[name_map['end']].cuda(non_blocking=True))
                score = Variable(batch[name_map['score']].cuda(non_blocking=True))
            else:
                y = Variable(batch[name_map['start']]), Variable(batch[name_map['end']])
                score = Variable(batch[name_map['score']])
            start, end = pred
            loss = F.cross_entropy(start, y[0], reduction='none') \
                + F.cross_entropy(end, y[1], reduction='none')

        if self.opt['uncertainty_loss']:
            loss = loss * torch.exp(-self.network.log_uncertainty[dataset_name]) / 2 \
                + self.network.log_uncertainty[dataset_name] / 2
        loss = torch.mean(loss * score)
        if self.opt['elmo_l2'] > 0:
            loss += self.network.elmo_l2norm() * self.opt['elmo_l2']

        self.train_loss.update(loss.item(), len(score))
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.opt['grad_clipping'])
        self.optimizer.step()
        if self.opt['ema']:
            self.ema.update()
        self.updates += 1
        self.reset_embeddings()
        self.eval_embed_transfer = True
        self.para_swapped = False

    def eval(self):
        if self.opt['ema']:
            self.ema.swap_parameters()
            self.para_swapped = True

    def train(self):
        if self.para_swapped:
            self.ema.swap_parameters()
            self.para_swapped = False

    def predict(self, batch, name_map, top_k=1, dataset_name='squad'):
        self.network.eval()
        if self.eval_embed_transfer:
            self.update_eval_embed()
            self.eval_embed_transfer = False
        self.network.drop_emb = False
        pred = self.forward_network(*batch[:name_map['input_len']], dataset_name=dataset_name)
        if dataset_name == 'wdw':
            probs = pred.cpu()
            predictions = torch.max(probs, dim=1)[1].tolist()
            return (predictions, probs.tolist())
        else:
            start, end = pred
            if name_map['valid_size'] != -1:
                valid_size = name_map['valid_size']
                start = start[:valid_size, :]
                end = end[:valid_size, :]
            else:
                valid_size = len(batch[name_map['text']])
            start = F.softmax(start, dim=1)
            end = F.softmax(end, dim=1)
            start = start.data.cpu()
            end = end.data.cpu()

            text = batch[name_map['text']]
            spans = batch[name_map['span']]
            predictions = []
            best_scores = []
            if 'marco' in dataset_name:
                max_len = self.opt['marco_max_len'] or start.size(1)
            else:
                max_len = self.opt['max_len'] or start.size(1)
            doc_len = start.size(1)
            pos_enc = self.position_encoding(doc_len, max_len)
            for i in range(start.size(0)):
                # outer product of start/end distributions, restricted to valid spans
                scores = torch.ger(start[i], end[i])
                scores = scores * pos_enc
                scores.triu_()
                scores = scores.numpy()
                best_idx = np.argpartition(scores, -top_k, axis=None)[-top_k]
                best_score = np.partition(scores, -top_k, axis=None)[-top_k]
                s_idx, e_idx = np.unravel_index(best_idx, scores.shape)
                s_offset, e_offset = spans[i][s_idx][0], spans[i][e_idx][1]
                predictions.append(text[i][s_offset:e_offset])
                best_scores.append(best_score)
            start_scores_list = start.tolist()
            end_scores_list = end.tolist()
            return (predictions, best_scores, start_scores_list, end_scores_list)

    def setup_eval_embed(self, eval_embed, padding_idx=0):
        self.network.lexicon_encoder.eval_embed = nn.Embedding(eval_embed.size(0),
                                                               eval_embed.size(1),
                                                               padding_idx=padding_idx)
        self.network.lexicon_encoder.eval_embed.weight.data = eval_embed
        for p in self.network.lexicon_encoder.eval_embed.parameters():
            p.requires_grad = False
        self.eval_embed_transfer = True
        if self.opt['covec_on']:
            self.network.lexicon_encoder.ContextualEmbed.setup_eval_embed(eval_embed)

    def update_eval_embed(self):
        if self.opt['tune_partial'] > 0:
            offset = self.opt['tune_partial']
            self.network.lexicon_encoder.eval_embed.weight.data[0:offset, :] \
                = self.network.lexicon_encoder.embedding.weight.data[0:offset, :]

    def reset_embeddings(self):
        if self.opt['tune_partial'] > 0:
            offset = self.opt['tune_partial']
            if offset < self.network.lexicon_encoder.embedding.weight.data.size(0):
                self.network.lexicon_encoder.embedding.weight.data[offset:, :] \
                    = self.network.lexicon_encoder.fixed_embedding

    def save(self, filename, epoch, best_em_score, best_f1_score):
        # strip CoVe and ELMo LSTM weights before saving
        network_state = dict([(k, v) for k, v in self.network.state_dict().items()
                              if k[0:4] != 'CoVe' and '_elmo_lstm' not in k])
        if 'eval_embed.weight' in network_state:
            del network_state['eval_embed.weight']
        if 'lexicon_encoder.fixed_embedding' in network_state:
            del network_state['lexicon_encoder.fixed_embedding']
        params = {
            'state_dict': {
                'network': network_state,
                'optimizer': self.optimizer.state_dict(),
                'train_loss': self.train_loss.state_dict(),
                'updates': self.updates,
                'ema': self.ema.state_dict()
            },
            'config': self.opt,
            'random_state': random.getstate(),
            'torch_state': torch.random.get_rng_state(),
            'torch_cuda_state': torch.cuda.get_rng_state(),
            'epoch': epoch,
            'best_em_score': best_em_score,
            'best_f1_score': best_f1_score
        }
        if self.scheduler:
            params['scheduler_state'] = self.scheduler.state_dict()
        for try_id in range(10):
            try:
                torch.save(params, filename)
                break
            except Exception as e:
                print('save failed. error:', e)
        logger.info('model saved to {}'.format(filename))

    def cuda(self):
        self.network.cuda()
        ema_state = None
        if self.state_dict:
            # drop stale keys and fill in missing ones before loading
            new_state = set(self.network.state_dict().keys())
            for k in list(self.state_dict['network'].keys()):
                if k not in new_state:
                    print('key dropped:', k)
                    del self.state_dict['network'][k]
            for k, v in list(self.network.state_dict().items()):
                if k not in self.state_dict['network']:
                    self.state_dict['network'][k] = v
            self.network.load_state_dict(self.state_dict['network'])
            if self.opt['tune_partial'] > 0:
                offset = self.opt['tune_partial']
                self.network.lexicon_encoder.embedding.weight.data[0:offset, :] \
                    = self.state_dict['network']['lexicon_encoder.embedding.weight'][0:offset, :]
            if 'optimizer' in self.state_dict and not self.opt['not_resume_optimizer']:
                self.optimizer.load_state_dict(self.state_dict['optimizer'])
            ema_state = self.state_dict['ema']
        if self.opt['ema']:
            self.ema = EMA(self.opt['ema_gamma'], self.network, ema_state)

    def position_encoding(self, m, threshold=4):
        # down-weights spans longer than `threshold` tokens by 1 / log(length)
        encoding = np.ones((m, m), dtype=np.float32)
        for i in range(m):
            for j in range(i, m):
                if j - i > threshold:
                    encoding[i][j] = float(1.0 / math.log(j - i + 1))
        return torch.from_numpy(encoding)
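
# Illustrative sketch (not part of the original code): the span selection inside
# `predict` reduces to an outer product of the start/end distributions, masked to
# upper-triangular spans and down-weighted beyond `max_len`. The names
# `_demo_position_encoding` / `_demo_best_span` and the random inputs below are
# assumptions for demonstration only.
def _demo_position_encoding(m, threshold=4):
    encoding = np.ones((m, m), dtype=np.float32)
    for i in range(m):
        for j in range(i, m):
            if j - i > threshold:
                encoding[i][j] = float(1.0 / math.log(j - i + 1))
    return torch.from_numpy(encoding)


def _demo_best_span(start_probs, end_probs, top_k=1, max_len=4):
    # scores[i, j] = P(start = i) * P(end = j), keeping only spans with end >= start
    scores = torch.ger(start_probs, end_probs)
    scores = scores * _demo_position_encoding(start_probs.size(0), max_len)
    scores.triu_()
    scores = scores.numpy()
    best_idx = np.argpartition(scores, -top_k, axis=None)[-top_k]
    best_score = np.partition(scores, -top_k, axis=None)[-top_k]
    s_idx, e_idx = np.unravel_index(best_idx, scores.shape)
    return s_idx, e_idx, best_score


if __name__ == '__main__':
    doc_len = 8
    start_probs = torch.softmax(torch.randn(doc_len), dim=0)
    end_probs = torch.softmax(torch.randn(doc_len), dim=0)
    print(_demo_best_span(start_probs, end_probs))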
class DocReaderModel(object):
    def __init__(self, opt, embedding=None, state_dict=None):
        self.opt = opt
        self.updates = state_dict['updates'] if state_dict and 'updates' in state_dict else 0
        self.eval_embed_transfer = True
        self.train_loss = AverageMeter()
        self.network = DNetwork(opt, embedding)
        if state_dict:
            new_state = set(self.network.state_dict().keys())
            for k in list(state_dict['network'].keys()):
                if k not in new_state:
                    del state_dict['network'][k]
            for k, v in list(self.network.state_dict().items()):
                if k not in state_dict['network']:
                    state_dict['network'][k] = v
            self.network.load_state_dict(state_dict['network'])

        parameters = [p for p in self.network.parameters() if p.requires_grad]
        if opt['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(parameters, opt['learning_rate'],
                                       momentum=opt['momentum'],
                                       weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adamax':
            self.optimizer = optim.Adamax(parameters, opt['learning_rate'],
                                          weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adam':
            self.optimizer = optim.Adam(parameters, opt['learning_rate'],
                                        weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adadelta':
            self.optimizer = optim.Adadelta(parameters, opt['learning_rate'], rho=0.95)
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])
        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if opt['fix_embeddings']:
            wvec_size = 0
        else:
            wvec_size = (opt['vocab_size'] - opt['tune_partial']) * opt['embedding_dim']

        if opt.get('have_lr_scheduler', False):
            if opt.get('scheduler_type', 'rop') == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max',
                                                   factor=opt['lr_gamma'], patience=3)
            elif opt.get('scheduler_type', 'rop') == 'exp':
                self.scheduler = ExponentialLR(self.optimizer, gamma=opt.get('lr_gamma', 0.5))
            else:
                milestones = [int(step) for step in
                              opt.get('multi_step_lr', '10,20,30').split(',')]
                self.scheduler = MultiStepLR(self.optimizer, milestones=milestones,
                                             gamma=opt.get('lr_gamma'))
        else:
            self.scheduler = None
        self.total_param = sum([p.nelement() for p in parameters]) - wvec_size

    def update(self, batch):
        self.network.train()
        if self.opt['cuda']:
            y = Variable(batch['start'].cuda(non_blocking=True)), \
                Variable(batch['end'].cuda(non_blocking=True))
            if self.opt.get('v2_on', False):
                label = Variable(batch['label'].cuda(non_blocking=True), requires_grad=False)
        else:
            y = Variable(batch['start']), Variable(batch['end'])
            if self.opt.get('v2_on', False):
                label = Variable(batch['label'], requires_grad=False)
        start, end, pred = self.network(batch)
        loss = F.cross_entropy(start, y[0]) + F.cross_entropy(end, y[1])
        if self.opt.get('v2_on', False):
            # SQuAD v2: add the answerability classifier loss, scaled by classifier_gamma
            loss = loss + F.binary_cross_entropy(pred, torch.unsqueeze(label, 1)) \
                * self.opt.get('classifier_gamma', 1)
        self.train_loss.update(loss.item(), len(start))
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.opt['grad_clipping'])
        self.optimizer.step()
        self.updates += 1
        self.reset_embeddings()
        self.eval_embed_transfer = True

    def predict(self, batch, top_k=1):
        self.network.eval()
        self.network.drop_emb = False
        # Transfer trained embedding to evaluation embedding
        if self.eval_embed_transfer:
            self.update_eval_embed()
            self.eval_embed_transfer = False
        start, end, lab = self.network(batch)
        start = F.softmax(start, 1)
        end = F.softmax(end, 1)
        start = start.data.cpu()
        end = end.data.cpu()
        if lab is not None:
            lab = lab.data.cpu()

        text = batch['text']
        spans = batch['span']
        predictions = []
        best_scores = []
        label_predictions = []

        max_len = self.opt['max_len'] or start.size(1)
        doc_len = start.size(1)
        pos_enc = self.position_encoding(doc_len, max_len)
        for i in range(start.size(0)):
            scores = torch.ger(start[i], end[i])
            scores = scores * pos_enc
            scores.triu_()
            scores = scores.numpy()
            best_idx = np.argpartition(scores, -top_k, axis=None)[-top_k]
            best_score = np.partition(scores, -top_k, axis=None)[-top_k]
            s_idx, e_idx = np.unravel_index(best_idx, scores.shape)
            if self.opt.get('v2_on', False):
                label_score = float(lab[i])
                s_offset, e_offset = spans[i][s_idx][0], spans[i][e_idx][1]
                answer = text[i][s_offset:e_offset]
                if s_idx == len(spans[i]) - 1:
                    answer = ''
                predictions.append(answer)
                best_scores.append(best_score)
                label_predictions.append(label_score)
            else:
                s_offset, e_offset = spans[i][s_idx][0], spans[i][e_idx][1]
                predictions.append(text[i][s_offset:e_offset])
                best_scores.append(best_score)
        return (predictions, best_scores, label_predictions)

    def setup_eval_embed(self, eval_embed, padding_idx=0):
        self.network.lexicon_encoder.eval_embed = nn.Embedding(eval_embed.size(0),
                                                               eval_embed.size(1),
                                                               padding_idx=padding_idx)
        self.network.lexicon_encoder.eval_embed.weight.data = eval_embed
        for p in self.network.lexicon_encoder.eval_embed.parameters():
            p.requires_grad = False
        self.eval_embed_transfer = True
        if self.opt['covec_on']:
            self.network.lexicon_encoder.ContextualEmbed.setup_eval_embed(eval_embed)

    def update_eval_embed(self):
        if self.opt['tune_partial'] > 0:
            offset = self.opt['tune_partial']
            self.network.lexicon_encoder.eval_embed.weight.data[0:offset] \
                = self.network.lexicon_encoder.embedding.weight.data[0:offset]

    def reset_embeddings(self):
        if self.opt['tune_partial'] > 0:
            offset = self.opt['tune_partial']
            if offset < self.network.lexicon_encoder.embedding.weight.data.size(0):
                self.network.lexicon_encoder.embedding.weight.data[offset:] \
                    = self.network.lexicon_encoder.fixed_embedding

    def save(self, filename, epoch):
        # strip CoVe weights before saving
        network_state = dict([(k, v) for k, v in self.network.state_dict().items()
                              if k[0:4] != 'CoVe'])
        if 'eval_embed.weight' in network_state:
            del network_state['eval_embed.weight']
        if 'fixed_embedding' in network_state:
            del network_state['fixed_embedding']
        params = {
            'state_dict': {'network': network_state},
            'config': self.opt,
        }
        torch.save(params, filename)
        logger.info('model saved to {}'.format(filename))

    def cuda(self):
        self.network.cuda()

    def position_encoding(self, m, threshold=5):
        encoding = np.ones((m, m), dtype=np.float32)
        for i in range(m):
            for j in range(i, m):
                if j - i > threshold:
                    encoding[i][j] = float(1.0 / math.log(j - i + 1))
        return torch.from_numpy(encoding)
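
# Illustrative sketch (not part of the original code): the v2 branch of `update`
# combines the usual span cross-entropy with a binary "has answer" classifier,
# weighted by `classifier_gamma`. Shapes, names, and the random inputs below are
# assumptions for demonstration only.
if __name__ == '__main__':
    batch_size, doc_len = 4, 50
    start_scores = torch.randn(batch_size, doc_len)          # unnormalized start logits
    end_scores = torch.randn(batch_size, doc_len)             # unnormalized end logits
    answerable = torch.sigmoid(torch.randn(batch_size, 1))    # classifier output in (0, 1)

    start_target = torch.randint(0, doc_len, (batch_size,))
    end_target = torch.randint(0, doc_len, (batch_size,))
    label = torch.randint(0, 2, (batch_size,)).float()        # 1 = answerable, 0 = no answer

    classifier_gamma = 1.0
    span_loss = F.cross_entropy(start_scores, start_target) + F.cross_entropy(end_scores, end_target)
    clf_loss = F.binary_cross_entropy(answerable, label.unsqueeze(1))
    loss = span_loss + classifier_gamma * clf_loss
    print(loss.item())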
class DocReaderModel(object):
    def __init__(self, opt, embedding=None, state_dict=None):
        self.opt = opt
        self.updates = state_dict['updates'] if state_dict and 'updates' in state_dict else 0
        self.eval_embed_transfer = True
        self.train_loss = AverageMeter()

        if self.opt['weight_type'] == 'bleu':
            print('Use BLEU for weighing')
            self.sentence_metric = eval_bleu.sentence_bleu
        elif self.opt['weight_type'] == 'nist':
            print('Use NIST for weighing')
            self.sentence_metric = eval_nist.sentence_nist
        else:
            raise ValueError('Unknown weight type {}'.format(self.opt['weight_type']))

        if self.opt['model_type'] == 'san':
            encoder = DNetwork(opt, embedding)
        elif self.opt['model_type'] in {'seq2seq', 'memnet'}:
            encoder = DNetwork_Seq2seq(opt, embedding)
        else:
            raise ValueError('Unknown model type: {}'.format(self.opt['model_type']))

        if self.opt['model_type'] in {'seq2seq', 'memnet'}:
            self.cove_embedder = ContextualEmbed(opt['covec_path'], opt['vocab_size'],
                                                 embedding=embedding)
        else:
            self.cove_embedder = None

        decoder_hidden_size = opt['decoder_hidden_size']
        enc_dec_bridge = nn.Linear(encoder.hidden_size, decoder_hidden_size)

        if opt['self_attention_on']:
            doc_mem_hidden_size = encoder.doc_mem_gen.output_size
        else:
            doc_mem_hidden_size = encoder.doc_understand.output_size

        decoder = SANDecoder(doc_mem_hidden_size, decoder_hidden_size, opt,
                             prefix='decoder', dropout=encoder.dropout)
        ans_embedding = nn.Embedding(opt['vocab_size'], doc_mem_hidden_size, padding_idx=0)

        print('decoder hidden size: %d' % decoder_hidden_size)
        print('ans emb size: %d' % doc_mem_hidden_size)

        generator = nn.Sequential(nn.Linear(decoder_hidden_size, opt['vocab_size']),
                                  nn.LogSoftmax(dim=1))
        loss_compute = nn.NLLLoss(ignore_index=0)
        self.network = myNetwork(encoder, decoder, ans_embedding, generator,
                                 loss_compute, enc_dec_bridge)
        if state_dict:
            print('loading checkpoint model...')
            new_state = set(self.network.state_dict().keys())
            for k in list(state_dict['network'].keys()):
                if k not in new_state:
                    del state_dict['network'][k]
            for k, v in list(self.network.state_dict().items()):
                if k not in state_dict['network']:
                    state_dict['network'][k] = v
            self.network.load_state_dict(state_dict['network'])

        # Building optimizer.
        parameters = [p for p in self.network.parameters() if p.requires_grad]
        if opt['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(parameters, opt['learning_rate'],
                                       momentum=opt['momentum'],
                                       weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adamax':
            self.optimizer = optim.Adamax(parameters, opt['learning_rate'],
                                          weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adam':
            self.optimizer = optim.Adam(parameters, opt['learning_rate'],
                                        weight_decay=opt['weight_decay'])
        elif opt['optimizer'] == 'adadelta':
            self.optimizer = optim.Adadelta(parameters, opt['learning_rate'], rho=0.95)
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])
        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if opt['fix_embeddings']:
            wvec_size = 0
        else:
            wvec_size = 0

        if opt.get('have_lr_scheduler', False):
            if opt.get('scheduler_type', 'rop') == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min',
                                                   factor=opt['lr_gamma'],
                                                   patience=2, verbose=True)
            elif opt.get('scheduler_type', 'rop') == 'exp':
                self.scheduler = ExponentialLR(self.optimizer, gamma=opt.get('lr_gamma', 0.5))
            else:
                milestones = [int(step) for step in
                              opt.get('multi_step_lr', '10,20,30').split(',')]
                self.scheduler = MultiStepLR(self.optimizer, milestones=milestones,
                                             gamma=opt.get('lr_gamma'))
        else:
            self.scheduler = None
        self.total_param = sum([p.nelement() for p in parameters]) - wvec_size

    ## RNN encoder + memory
    def encode_memnet(self, query, batch):
        if self.opt['cuda']:
            query = query.cuda()
        query_emb = self.network.encoder.embedding(query)
        encoder_hidden = self.network.encoder.initHidden(query.size(1))
        if self.opt['cuda']:
            encoder_hidden = encoder_hidden.cuda()
            query_emb = query_emb.cuda()
        encoder_hidden = Variable(encoder_hidden)
        for word in torch.split(query_emb, 1):
            word = word.squeeze(0)
            encoder_hidden = self.network.encoder(word, encoder_hidden)
        mem_hidden = self.network.encoder.add_fact_memory(encoder_hidden, batch)
        mem_hidden += encoder_hidden
        return mem_hidden

    def patch(self, v):
        if self.opt['cuda']:
            v = Variable(v.cuda(non_blocking=True))
        else:
            v = Variable(v)
        return v

    ## RNN encoder
    def encode(self, query, batch):
        if self.opt['cuda']:
            query = query.cuda()
        query_emb = self.network.encoder.embedding(query)
        query_cove_low, query_cove_high = self.cove_embedder(Variable(batch['query_tok']),
                                                             Variable(batch['query_mask']))
        if self.opt['cuda']:
            query_cove_low = query_cove_low.cuda()
            query_cove_high = query_cove_high.cuda()
        # move CoVe outputs to (seq_len, batch, dim) so they align with query_emb
        query_cove_low = query_cove_low.permute(1, 0, 2)
        query_cove_high = query_cove_high.permute(1, 0, 2)
        query_emb = torch.cat([query_emb, query_cove_low, query_cove_high], 2)

        encoder_hidden = self.network.encoder.initHidden(query.size(1))
        if self.opt['cuda']:
            encoder_hidden = encoder_hidden.cuda()
        encoder_hidden = Variable(encoder_hidden)
        for word in torch.split(query_emb, 1):
            word = word.squeeze(0)
            encoder_hidden = self.network.encoder(word, encoder_hidden)
        return encoder_hidden

    def compute_w(self, fact, res, smooth=0, batch_size=32):
        def _strip_pad(lst):
            lst = [str(_) for _ in lst]
            lst = ' '.join(lst)
            lst = lst.strip(' 0')
            lst = lst.split()
            return lst

        w = []
        for f, r in zip(fact, res):
            f = _strip_pad(f)
            r = _strip_pad(r)
            fact_bleu = self.sentence_metric([f], r, smooth=True)
            fact_bleu += smooth
            w.append(fact_bleu)
        w = np.array(w)
        w = w / sum(w)
        w = w * batch_size
        return w

    def update(self, batch, smooth=-1, rep_train=0.5):
        self.network.train()
        if rep_train > 0:
            # mix a fraction of gold answers with document prefixes as decoder targets
            rep_train = 1 - rep_train
            rep_len = int(len(batch['doc_tok']) * rep_train)
            answer_token = batch['answer_token'][:rep_len]
            doc_tok = batch['doc_tok'][rep_len:]
            ans_len = len(batch['answer_token'][1])
            doc_tok = doc_tok[:, :ans_len]
            doc_ans = torch.cat((answer_token, doc_tok), 0)
            doc_ans = Variable(doc_ans.transpose(0, 1), requires_grad=False)
        else:
            doc_ans = Variable(batch['answer_token'].transpose(0, 1), requires_grad=False)
        if self.opt['cuda']:
            doc_ans = doc_ans.cuda()
        doc_ans_emb = self.network.ans_embedding(doc_ans)

        if self.opt['model_type'] == 'san':
            doc_mem, query_mem, doc_mask = self.network.encoder(batch)
        elif self.opt['model_type'] in {'seq2seq', 'memnet'}:
            query = Variable(batch['query_tok'].transpose(0, 1))
            if self.opt['model_type'] == 'seq2seq':
                encoder_hidden = self.encode(query, batch)
            else:
                encoder_hidden = self.encode_memnet(query, batch)
            query_mem = encoder_hidden
            doc_mem, doc_mask = None, None
        else:
            raise ValueError('Unknown model type: {}'.format(self.opt['model_type']))

        batch_size = query_mem.size(0)
        hidden = self.network.enc_dec_bridge(query_mem)
        hiddens = []
        for word in torch.split(doc_ans_emb, 1)[:-1]:
            word = word.squeeze(0)
            hidden = self.network.decoder(word, hidden, doc_mem, doc_mask)
            hiddens.append(hidden)
        hiddens = torch.stack(hiddens)
        log_probs = self.network.generator(hiddens.view(-1, hiddens.size(2)))

        if smooth >= 0:
            # rescale each example's log-probs by its BLEU/NIST weight
            weight = self.compute_w(batch['doc_tok'], batch['answer_token'][:, 1:-1],
                                    smooth, batch_size)
            weight = np.reshape(weight, [-1, 1, 1])
            weight = torch.FloatTensor(weight).cuda()
            weight = Variable(weight, requires_grad=False)
            new_log_probs = log_probs.view(batch_size, -1, self.opt['vocab_size'])
            new_log_probs = weight * new_log_probs
            log_probs = new_log_probs.view(-1, self.opt['vocab_size'])

        target = doc_ans[1:].view(-1).data
        target = Variable(target, requires_grad=False)
        loss = self.network.loss_compute(log_probs, target)
        self.train_loss.update(loss.data.item(), doc_ans.size(1))

        ## update loss
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.opt['grad_clipping'])
        self.optimizer.step()
        self.updates += 1
        self.eval_embed_transfer = True

    def predict(self, batch, top_k=2):
        max_len = self.opt['max_len']
        BOS_token = STA_ID
        self.network.eval()
        self.network.drop_emb = False

        if self.opt['model_type'] == 'san':
            doc_mem, query_mem, doc_mask = self.network.encoder(batch)
        elif self.opt['model_type'] in {'seq2seq', 'memnet'}:
            query = Variable(batch['query_tok'].transpose(0, 1))
            if self.opt['model_type'] == 'seq2seq':
                encoder_hidden = self.encode(query, batch)
            else:
                encoder_hidden = self.encode_memnet(query, batch)
            query_mem = encoder_hidden
            doc_mem, doc_mask = None, None
        else:
            raise ValueError('Unknown model type: {}'.format(self.opt['model_type']))

        hidden = self.network.enc_dec_bridge(query_mem)
        batch_size = query_mem.size(0)
        next_token = Variable(torch.LongTensor([BOS_token] * batch_size),
                              requires_grad=False).cuda()

        def _get_topk_tokens(log_prob, topk):
            """all except `log_prob` must be numpy"""
            log_prob_py = log_prob.data.cpu().numpy()
            topk_tokens = log_prob_py.argsort()[:, -topk:]
            return topk_tokens

        fact_py = batch['doc_tok'].numpy().tolist()

        def _delta_bleu(exist_subseq, fact, log_prob, topk):
            """all except `log_prob` must be numpy"""
            log_prob_py = log_prob.data.cpu().numpy()
            if exist_subseq is None:
                exist_bleu = np.zeros([batch_size])
            else:
                exist_bleu = [self.sentence_metric([r], f, smooth=True)
                              for r, f in zip(exist_subseq, fact)]
            delta_bleu = np.zeros([batch_size, self.opt['vocab_size']])
            topk_tokens = log_prob_py.argsort()[:, -topk:]
            if self.opt['decoding_bleu_lambda'] > 0:
                for topk_i in range(topk):
                    candidate_token = topk_tokens[:, topk_i]
                    delta_bleu_i = _delta_bleu_core(candidate_token, exist_subseq,
                                                    fact, exist_bleu)
                    delta_bleu[range(batch_size), candidate_token] = delta_bleu_i
            if self.opt['decoding_bleu_normalize']:
                delta_bleu_sum = np.sum(delta_bleu, axis=1, keepdims=True)
                delta_bleu /= (delta_bleu_sum + 1e-7)
            return delta_bleu, topk_tokens

        def _delta_bleu_core(candidate_token, exist_subseq, fact, exist_bleu):
            """all inputs must be numpy or python, not pytorch"""
            candidate_token = np.reshape(candidate_token, [-1, 1])
            if exist_subseq is None:
                new_subseq = candidate_token
            else:
                new_subseq = np.concatenate([exist_subseq, candidate_token], 1)
            new_bleu = [self.sentence_metric([r], f, smooth=True)
                        for r, f in zip(new_subseq, fact)]
            return np.array(new_bleu) - np.array(exist_bleu)

        def _remove_tokens(tokens):
            rm_skips = torch.zeros([batch_size, self.opt['vocab_size']])
            rm_skips[:, tokens] = -1000000
            return Variable(rm_skips, requires_grad=False).cuda()

        preds = []
        pred_topks = []
        preds_np = None
        for step in range(max_len):
            word = self.network.ans_embedding(next_token)
            hidden = self.network.decoder(word, hidden, doc_mem, doc_mask)
            log_prob = self.network.generator(hidden)

            # mask out the UNK token
            unk_id = self.opt['unk_id']
            rm_UNK = torch.cat([torch.zeros([batch_size, unk_id]),
                                torch.ones([batch_size, 1]) * -1000000,
                                torch.zeros([batch_size, self.opt['vocab_size'] - unk_id - 1])],
                               dim=1).float()
            log_prob += Variable(rm_UNK, requires_grad=False).cuda()

            if self.opt['skip_tokens']:
                log_prob += _remove_tokens(self.opt['skip_tokens'])
            if self.opt['skip_tokens_first'] and step == 0:
                log_prob += _remove_tokens(self.opt['skip_tokens_first'])

            if self.opt['decoding'] == 'greedy':
                _, next_token = torch.max(log_prob, 1)
            elif self.opt['decoding'] == 'sample':
                t = self.opt['temperature']
                next_token = torch.multinomial(torch.exp(log_prob / t), 1).squeeze(-1)
            elif self.opt['decoding'] == 'weight':
                # re-rank the top-k tokens by how much they would improve BLEU
                delta_bleu, log_prob_topk_tokens = _delta_bleu(preds_np, fact_py, log_prob,
                                                               self.opt['decoding_topk'])
                effective_log_prob_sum = Variable(torch.zeros([batch_size]),
                                                  requires_grad=False).cuda()
                dumb_log_prob = np.ones(log_prob.size()) * -10000000
                for topk_i in range(self.opt['decoding_topk']):
                    log_prob_topk_i = log_prob[
                        torch.LongTensor(range(batch_size)).cuda(),
                        torch.LongTensor(log_prob_topk_tokens[:, topk_i]).cuda()]
                    dumb_log_prob[range(batch_size), log_prob_topk_tokens[:, topk_i]] = \
                        log_prob_topk_i.data.cpu().numpy()
                    effective_log_prob_sum += log_prob_topk_i
                dumb_log_prob = Variable(torch.FloatTensor(dumb_log_prob),
                                         requires_grad=False).cuda()
                delta_bleu_w = effective_log_prob_sum / self.opt['decoding_topk']
                delta_bleu_w = delta_bleu_w.view(-1, 1)
                bleu_reweight = delta_bleu_w * Variable(torch.FloatTensor(delta_bleu),
                                                        requires_grad=False).cuda()
                w_log_prob = dumb_log_prob + self.opt['decoding_bleu_lambda'] * bleu_reweight
                t = self.opt['temperature']
                next_token = torch.multinomial(torch.exp(w_log_prob / t), 1).squeeze(-1)
            else:
                raise ValueError('Unknown decoding: %s' % self.opt['decoding'])

            preds.append(next_token.data.cpu().numpy())
            next_token_np = next_token.data.cpu().numpy()
            next_token_np = np.reshape(next_token_np, [-1, 1])
            if preds_np is None:
                preds_np = next_token_np
            else:
                preds_np = np.concatenate([preds_np, next_token_np], 1)
            _, topk_list = torch.topk(log_prob, top_k)
            pred_topks.append(topk_list.data.cpu().numpy())

        prediction_topks = [[p[i] for p in pred_topks] for i in range(batch_size)]
        predictions = [[p[i] for p in preds] for i in range(batch_size)]
        return (predictions, prediction_topks)

    def eval_test_loss(self, batch):
        self.network.eval()
        self.network.drop_emb = False
        with torch.no_grad():
            doc_ans = Variable(batch['answer_token'].transpose(0, 1))
            if self.opt['cuda']:
                doc_ans = doc_ans.cuda()
            doc_ans_emb = self.network.ans_embedding(doc_ans)

            if self.opt['model_type'] == 'san':
                doc_mem, query_mem, doc_mask = self.network.encoder(batch)
            elif self.opt['model_type'] in {'seq2seq', 'memnet'}:
                query = Variable(batch['query_tok'].transpose(0, 1))
                if self.opt['model_type'] == 'seq2seq':
                    encoder_hidden = self.encode(query, batch)
                else:
                    encoder_hidden = self.encode_memnet(query, batch)
                query_mem = encoder_hidden
                doc_mem, doc_mask = None, None
            else:
                raise ValueError('Unknown model type: {}'.format(self.opt['model_type']))

            hidden = self.network.enc_dec_bridge(query_mem)
            hiddens = []
            for word in torch.split(doc_ans_emb, 1)[:-1]:
                word = word.squeeze(0)
                hidden = self.network.decoder(word, hidden, doc_mem, doc_mask)
                hiddens.append(hidden)
            hiddens = torch.stack(hiddens)
            log_probs = self.network.generator(hiddens.view(-1, hiddens.size(2)))
            target = Variable(doc_ans[1:].contiguous().view(-1).data)
            loss = self.network.loss_compute(log_probs, target)
        return loss

    def setup_eval_embed(self, eval_embed, padding_idx=0):
        self.network.encoder.lexicon_encoder.eval_embed = nn.Embedding(
            eval_embed.size(0), eval_embed.size(1), padding_idx=padding_idx)
        self.network.encoder.lexicon_encoder.eval_embed.weight.data = eval_embed
        for p in self.network.encoder.lexicon_encoder.eval_embed.parameters():
            p.requires_grad = False
        self.eval_embed_transfer = True
        if self.opt['covec_on']:
            self.network.encoder.lexicon_encoder.ContextualEmbed.setup_eval_embed(eval_embed)

    def update_eval_embed(self):
        if self.opt['tune_partial'] > 0:
            offset = self.opt['tune_partial']
            self.network.encoder.lexicon_encoder.eval_embed.weight.data[0:offset] \
                = self.network.encoder.lexicon_encoder.embedding.weight.data[0:offset]

    def reset_embeddings(self):
        if self.opt['tune_partial'] > 0:
            offset = self.opt['tune_partial']
            if offset < self.network.encoder.lexicon_encoder.embedding.weight.data.size(0):
                self.network.encoder.lexicon_encoder.embedding.weight.data[offset:] \
                    = self.network.encoder.lexicon_encoder.fixed_embedding

    def save(self, filename, epoch):
        # strip cove
        network_state = dict([(k, v) for k, v in self.network.state_dict().items()
                              if k[0:4] != 'CoVe'])
        if 'eval_embed.weight' in network_state:
            del network_state['eval_embed.weight']
        if 'fixed_embedding' in network_state:
            del network_state['fixed_embedding']
        params = {
            'state_dict': {'network': network_state},
            'config': self.opt,
        }
        torch.save(params, filename)
        logger.info('model saved to {}'.format(filename))

    def cuda(self):
        self.network.cuda()

    def position_encoding(self, m, threshold=4):
        encoding = np.ones((m, m), dtype=np.float32)
        for i in range(m):
            for j in range(i, m):
                if j - i > threshold:
                    encoding[i][j] = float(1.0 / math.log(j - i + 1))
        return torch.from_numpy(encoding)
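
# Illustrative sketch (not part of the original code): the `smooth >= 0` branch of
# `update` rescales each example's log-probabilities by how much its document
# overlaps the reference answer, using the repo's eval_bleu / eval_nist metrics.
# The unigram-overlap metric and names below are stand-ins for demonstration only.
def _demo_overlap_score(fact_tokens, answer_tokens):
    # crude stand-in for eval_bleu.sentence_bleu / eval_nist.sentence_nist
    fact, ans = set(fact_tokens), set(answer_tokens)
    return len(fact & ans) / max(len(ans), 1)


def _demo_compute_weights(facts, answers, smooth=0.1):
    # one weight per example, normalized so the weights sum to the batch size
    w = np.array([_demo_overlap_score(f, a) + smooth for f, a in zip(facts, answers)])
    return w / w.sum() * len(w)


if __name__ == '__main__':
    facts = [[3, 7, 9, 2, 0, 0], [5, 6, 1, 0, 0, 0]]
    answers = [[7, 9, 4], [8, 2]]
    weights = _demo_compute_weights(facts, answers)            # shape [batch]

    batch_size, seq_len, vocab_size = 2, 5, 20
    log_probs = torch.log_softmax(torch.randn(batch_size, seq_len, vocab_size), dim=-1)
    weight = torch.tensor(weights, dtype=torch.float32).view(-1, 1, 1)
    weighted = (weight * log_probs).view(-1, vocab_size)       # fed to NLLLoss in update()
    print(weighted.shape)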