import os
import sys

import numpy as np
import torch
from torch import nn

# NOTE: `data` (batching/unsort utilities), `decode_minibatch`, and the
# CUDA flag are assumed to be provided elsewhere in this codebase.


def decode_dataset(model, src, tgt, config):
    """Decode an entire dataset with greedy decoding, batch by batch."""
    inputs = []
    preds = []
    auxs = []
    ground_truths = []
    for j in range(0, len(src['data']), config['data']['batch_size']):
        sys.stdout.write("\r%s/%s..." % (j, len(src['data'])))
        sys.stdout.flush()

        # get batch
        input_content, input_aux, output = data.minibatch(
            src, tgt, j, config['data']['batch_size'],
            config['data']['max_len'], config['model']['model_type'],
            is_test=True)
        input_lines_src, output_lines_src, srclens, srcmask, indices = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output

        # TODO -- beam search
        tgt_pred = decode_minibatch(
            config['data']['max_len'], tgt['tok2id']['<s>'],
            model, input_lines_src, srclens, srcmask,
            input_ids_aux, auxlens, auxmask)

        # convert id sequences to tokens
        def ids_to_toks(tok_seqs, id2tok):
            out = []
            # take off the gpu
            tok_seqs = tok_seqs.cpu().numpy()
            # convert to toks, cut off at </s>, and delete any start tokens
            # (the preds were kickstarted with them)
            for line in tok_seqs:
                toks = [id2tok[x] for x in line]
                if '<s>' in toks:
                    toks.remove('<s>')
                cut_idx = toks.index('</s>') if '</s>' in toks else len(toks)
                out.append(toks[:cut_idx])
            # unsort, since the batch was length-sorted for packing
            out = data.unsort(out, indices)
            return out

        # convert inputs/preds/targets/aux to human-readable form
        inputs += ids_to_toks(output_lines_src, src['id2tok'])
        preds += ids_to_toks(tgt_pred, tgt['id2tok'])
        ground_truths += ids_to_toks(output_lines_tgt, tgt['id2tok'])

        if config['model']['model_type'] == 'delete':
            # wrap each attribute id in a list because of the list comp
            # in inference_metrics()
            auxs += [[str(x)] for x in input_ids_aux.data.cpu().numpy()]
        elif config['model']['model_type'] == 'delete_retrieve':
            auxs += ids_to_toks(input_ids_aux, tgt['id2tok'])
        elif config['model']['model_type'] == 'seq2seq':
            auxs += ['None' for _ in range(len(tgt_pred))]

    return inputs, preds, ground_truths, auxs
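
# A minimal, self-contained sketch of the ids_to_toks() post-processing
# above (strip the kickstart <s>, truncate at </s>). The vocabulary and
# sequences are toy values invented for this illustration.
def _demo_ids_to_toks():
    id2tok = {0: '<pad>', 1: '<s>', 2: '</s>', 3: 'the', 4: 'cat'}
    tok_seqs = torch.tensor([[1, 3, 4, 2, 0],    # <s> the cat </s> <pad>
                             [1, 4, 3, 2, 0]])   # <s> cat the </s> <pad>
    out = []
    for line in tok_seqs.cpu().numpy():
        toks = [id2tok[x] for x in line]
        if '<s>' in toks:
            toks.remove('<s>')  # drop the start token used to kickstart decoding
        cut_idx = toks.index('</s>') if '</s>' in toks else len(toks)
        out.append(toks[:cut_idx])  # everything after </s> (padding) is discarded
    assert out == [['the', 'cat'], ['cat', 'the']]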
def evaluate_lpp(model, src, tgt, config):
    """Evaluate log perplexity WITHOUT decoding (i.e., with teacher forcing)."""
    weight_mask = torch.ones(len(tgt['tok2id']))
    if CUDA:
        weight_mask = weight_mask.cuda()
    # don't let <pad> positions contribute to the loss
    weight_mask[tgt['tok2id']['<pad>']] = 0
    loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)
    if CUDA:
        loss_criterion = loss_criterion.cuda()

    losses = []
    for j in range(0, len(src['data']), config['data']['batch_size']):
        sys.stdout.write("\r%s/%s..." % (j, len(src['data'])))
        sys.stdout.flush()

        # get batch
        input_content, input_aux, output = data.minibatch(
            src, tgt, j, config['data']['batch_size'],
            config['data']['max_len'], config['model']['model_type'],
            is_test=True)
        input_lines_src, _, srclens, srcmask, _ = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output

        decoder_logit, decoder_probs = model(
            input_lines_src, input_lines_tgt, srcmask, srclens,
            input_ids_aux, auxlens, auxmask)

        loss = loss_criterion(
            decoder_logit.contiguous().view(-1, len(tgt['tok2id'])),
            output_lines_tgt.view(-1))
        # loss.data[0] was removed in PyTorch 0.4+; use .item() instead
        losses.append(loss.item())

    return np.mean(losses)
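
# evaluate_lpp() returns the mean token-level cross-entropy, i.e. log
# perplexity; exponentiating recovers perplexity. The loss values below are
# toy numbers for illustration. Note that averaging per-batch means is only
# exact when every batch contributes the same number of non-pad tokens.
def _demo_log_ppl_to_ppl():
    losses = [3.2, 2.9, 3.0]    # example per-batch cross-entropies
    log_ppl = np.mean(losses)   # what evaluate_lpp() returns
    return np.exp(log_ppl)      # perplexity, ~20.8 for these toy values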
        # rm the old checkpoint and replace it with the new one
        os.system("rm %s" % ckpt_path)
        torch.save(model.state_dict(), working_dir + '/model.%s.ckpt' % epoch)

        best_metric = cur_metric
        best_epoch = epoch - 1

    losses = []
    for i in range(0, len(src['content']), batch_size):
        if args.overfit:
            i = 50

        # floor division: under Python 3, / would yield a float batch index
        batch_idx = i // batch_size

        input_content, input_aux, output = data.minibatch(
            src, tgt, i, batch_size, max_length,
            config['model']['model_type'])
        input_lines_src, _, srclens, srcmask, _ = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output

        decoder_logit, decoder_probs = model(
            input_lines_src, input_lines_tgt, srcmask, srclens,
            input_ids_aux, auxlens, auxmask)

        optimizer.zero_grad()
        loss = loss_criterion(
            decoder_logit.contiguous().view(-1, tgt_vocab_size),
            output_lines_tgt.view(-1))
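        # A sketch of how such a training step typically continues (backward
        # pass, optional gradient clipping, parameter update); the clipping
        # threshold is a hypothetical value, not taken from this code:
        #
        #   loss.backward()
        #   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        #   optimizer.step()
        #   losses.append(loss.item())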