def evaluate_lpp(model, src, tgt, config):
    """ evaluate log perplexity WITHOUT decoding (i.e., with teacher forcing) """
    weight_mask = torch.ones(len(tgt['tok2id']))
    if CUDA:
        weight_mask = weight_mask.cuda()
    weight_mask[tgt['tok2id']['<pad>']] = 0
    loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)
    if CUDA:
        loss_criterion = loss_criterion.cuda()

    losses = []
    for j in range(0, len(src['data']), config['data']['batch_size']):
        # get batch
        input_content, input_aux, output = data.minibatch(
            src, tgt, j,
            config['data']['batch_size'], config['data']['max_len'],
            config['model']['model_type'], is_test=True)
        input_content_src, _, srclens, srcmask, _ = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_data_tgt, output_data_tgt, _, _, _ = output

        decoder_logit, decoder_probs = model(
            input_content_src, srcmask, srclens,
            input_ids_aux, auxmask, auxlens,
            input_data_tgt, mode='train')

        loss = loss_criterion(
            decoder_logit.contiguous().view(-1, len(tgt['tok2id'])),
            output_data_tgt.view(-1))
        losses.append(loss.item())

    return np.mean(losses)
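# evaluate_lpp returns the mean masked cross-entropy, i.e. the log of the
# perplexity. A minimal usage sketch (assuming model, src_dev, tgt_dev, and
# config are built as in train() below); exponentiate to report perplexity:
import math

dev_lpp = evaluate_lpp(model, src_dev, tgt_dev, config)
print('dev perplexity: %.2f' % math.exp(dev_lpp))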
def decode_dataset(model, src, tgt, config):
    """Evaluate model."""
    inputs = []
    preds = []
    auxs = []
    ground_truths = []
    for j in range(0, len(src['data']), config['data']['batch_size']):
        sys.stdout.write("\r%s/%s..." % (j, len(src['data'])))
        sys.stdout.flush()

        # get batch
        input_content, input_aux, output = data.minibatch(
            src, tgt, j,
            config['data']['batch_size'], config['data']['max_len'],
            config['model']['model_type'], is_test=True)
        input_lines_src, output_lines_src, srclens, srcmask, indices = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output

        # TODO -- beam search
        tgt_pred = decode_minibatch(
            config['data']['max_len'], tgt['tok2id']['<s>'],
            model, input_lines_src, srclens, srcmask,
            input_ids_aux, auxlens, auxmask)

        # convert seqs to tokens
        def ids_to_toks(tok_seqs, id2tok):
            out = []
            # take off the gpu
            tok_seqs = tok_seqs.cpu().numpy()
            # convert to toks, cut off at </s>, delete any start tokens
            # (preds were kickstarted with them)
            for line in tok_seqs:
                toks = [id2tok[x] for x in line]
                if '<s>' in toks:
                    toks.remove('<s>')
                cut_idx = toks.index('</s>') if '</s>' in toks else len(toks)
                out.append(toks[:cut_idx])
            # unsort
            out = data.unsort(out, indices)
            return out

        # convert inputs/preds/targets/aux to human-readable form
        inputs += ids_to_toks(output_lines_src, src['id2tok'])
        preds += ids_to_toks(tgt_pred, tgt['id2tok'])
        ground_truths += ids_to_toks(output_lines_tgt, tgt['id2tok'])

        if config['model']['model_type'] == 'delete':
            auxs += [[str(x)] for x in input_ids_aux.data.cpu().numpy()]  # because of list comp in inference_metrics()
        elif config['model']['model_type'] == 'delete_retrieve':
            auxs += ids_to_toks(input_ids_aux, tgt['id2tok'])
        elif config['model']['model_type'] == 'seq2seq':
            auxs += ['None' for _ in range(len(tgt_pred))]

    return inputs, preds, ground_truths, auxs
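# decode_dataset relies on a decode_minibatch helper that is not shown in this
# section. The sketch below is an assumption: a greedy decoder that kickstarts
# each sequence with <s> and repeatedly re-runs the model, taking the argmax of
# the final timestep (the actual implementation may differ, e.g. sampling).
def decode_minibatch(max_len, start_id, model,
                     input_lines_src, srclens, srcmask,
                     input_ids_aux, auxlens, auxmask):
    """Greedy decoding sketch, assuming the forward signature used above."""
    batch_size = input_lines_src.size(0)
    # kickstart every sequence with the <s> token
    tgt_input = Variable(torch.LongTensor([[start_id]] * batch_size))
    if CUDA:
        tgt_input = tgt_input.cuda()
    for _ in range(max_len):
        # re-run the decoder on everything generated so far
        decoder_logit, word_probs = model(
            input_lines_src, srcmask, srclens,
            input_ids_aux, auxmask, auxlens,
            tgt_input, mode='train')
        # append the argmax of the final timestep
        next_preds = torch.max(word_probs[:, -1, :], dim=1)[1]
        tgt_input = torch.cat((tgt_input, next_preds.unsqueeze(1)), dim=1)
    return tgt_input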
def evaluate_lpp(model, src, tgt, config):
    """ evaluate log perplexity WITHOUT decoding (i.e., with teacher forcing) """
    weight_mask = torch.ones(len(tgt['tok2id']))
    if CUDA:
        weight_mask = weight_mask.cuda()
    weight_mask[tgt['tok2id']['<pad>']] = 0
    loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)
    if CUDA:
        loss_criterion = loss_criterion.cuda()

    losses = []
    for j in range(0, len(src['data']), config['data']['batch_size']):
        sys.stdout.write("\r%s/%s..." % (j, len(src['data'])))
        sys.stdout.flush()

        # get batch
        input_content, input_aux, output, side_info, _ = data.minibatch(
            src, tgt, j,
            config['data']['batch_size'], config['data']['max_len'],
            config, is_test=True)
        input_lines_src, _, srclens, srcmask, _ = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output
        side_info, _, _, _, _ = side_info

        decoder_logit, decoder_probs, _, _ = model(
            input_lines_src, input_lines_tgt, srcmask, srclens,
            input_ids_aux, auxlens, auxmask, side_info)

        loss = loss_criterion(
            decoder_logit.contiguous().view(-1, len(tgt['tok2id'])),
            output_lines_tgt.view(-1))
        losses.append(loss.item())  # loss.data[0] is deprecated indexing

    return np.mean(losses)
def sinusoid():
    net_init, net_fn = mlp(n_output=1, n_hidden_layer=2, bias_coef=1.0,
                           n_hidden_unit=40, activation='relu',
                           norm='batch_norm')

    rng = random.PRNGKey(42)
    in_shape = (-1, 1)
    out_shape, net_params = net_init(rng, in_shape)

    def loss(params, batch):
        inputs, targets = batch
        predictions = net_fn(params, inputs)
        return np.mean((predictions - targets) ** 2)

    opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-2, mass=0.9)
    opt_update = jit(opt_update)

    @jit
    def step(i, opt_state, batch):
        params = get_params(opt_state)
        g = grad(loss)(params, batch)
        return opt_update(i, g, opt_state)

    task = sinusoid_task(n_support=1000, n_query=100)

    opt_state = opt_init(net_params)
    for i, (x, y) in enumerate(
            minibatch(task['x_train'], task['y_train'],
                      batch_size=256, train_epochs=1000)):
        opt_state = step(i, opt_state, batch=(x, y))
        if i == 0 or (i + 1) % 100 == 0:
            print(
                f"train loss: {loss(get_params(opt_state), (task['x_train'], task['y_train']))},"
                f"\ttest loss: {loss(get_params(opt_state), (task['x_test'], task['y_test']))}"
            )
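# sinusoid() assumes sinusoid_task and minibatch helpers that are not defined
# in this section. The sketches below are assumptions (names, signatures, the
# amplitude/phase defaults, and the [-5, 5] input range are all hypothetical):
import numpy as onp  # plain numpy for data generation

def sinusoid_task(n_support, n_query, amp=1.0, phase=0.0):
    """Hypothetical helper: sample support/query splits of y = amp*sin(x+phase)."""
    xs = onp.random.uniform(-5.0, 5.0, size=(n_support + n_query, 1))
    ys = amp * onp.sin(xs + phase)
    return {'x_train': xs[:n_support], 'y_train': ys[:n_support],
            'x_test': xs[n_support:], 'y_test': ys[n_support:]}

def minibatch(x, y, batch_size, train_epochs):
    """Hypothetical helper: yield shuffled minibatches for a fixed epoch count."""
    n = x.shape[0]
    for _ in range(train_epochs):
        perm = onp.random.permutation(n)
        for k in range(0, n, batch_size):
            idx = perm[k:k + batch_size]
            yield x[idx], y[idx]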
def evaluate_rouge(model, src, tgt, config):
    """ evaluate ROUGE-2 WITH decoding (greedy search)

    args:
        src: src data object (i.e. data 0, learnt by the model)
        tgt: tgt data object
    """
    weight_mask = torch.ones(len(tgt['tok2id']))
    if CUDA:
        weight_mask = weight_mask.cuda()
    weight_mask[tgt['tok2id']['<pad>']] = 0

    searcher = models.GreedySearchDecoder(model)

    rouge_list = []
    decoded_results = []
    for j in range(0, len(src['data'])):
        # batch_size = 1
        input_content, input_aux, output = data.minibatch(
            src, src, j, 1,
            config['data']['max_len'], config['model']['model_type'])
        input_content_src, _, srclens, srcmask, _ = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_data_tgt, output_data_tgt, _, _, _ = output

        decoder_logit, decoded_data_tgt = searcher(
            input_content_src, srcmask, srclens,
            input_ids_aux, auxmask, auxlens,
            20, tgt['tok2id']['<s>'])

        decoded_sent = id2word(decoded_data_tgt, tgt)
        gold_sent = id2word(output_data_tgt, tgt)
        rouge = rouge_2(gold_sent, decoded_sent)
        rouge_list.append(rouge)
        decoded_results.append(decoded_sent)
        # print('Source content sentence: ' + gold_sent)
        # print('Decoded data sentence: ' + decoded_sent)

    return np.mean(rouge_list), decoded_results
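# rouge_2 is called above but not defined in this section. A minimal
# bigram-recall sketch (the real metric may use F1 or smoothing instead):
def rouge_2(gold_sent, decoded_sent):
    """ROUGE-2 recall sketch: overlapping bigrams / gold bigrams."""
    def bigrams(sent):
        toks = sent.split()
        return [tuple(toks[i:i + 2]) for i in range(len(toks) - 1)]

    gold = bigrams(gold_sent)
    pred = set(bigrams(decoded_sent))
    if not gold:
        return 0.0
    return sum(1 for bg in gold if bg in pred) / len(gold)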
def decode_dataset(model, src, tgt, config, k=20):
    """Evaluate model."""
    inputs = []
    preds = []
    top_k_preds = []
    auxs = []
    ground_truths = []
    raw_srcs = []
    for j in range(0, len(src['data']), config['data']['batch_size']):
        sys.stdout.write("\r%s/%s..." % (j, len(src['data'])))
        sys.stdout.flush()

        # get batch
        input_content, input_aux, output, side_info, raw_src = data.minibatch(
            src, tgt, j,
            config['data']['batch_size'], config['data']['max_len'],
            config, is_test=True)
        input_lines_src, output_lines_src, srclens, srcmask, indices = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output
        _, raw_src, _, _, _ = raw_src
        side_info, _, _, _, _ = side_info

        # TODO -- beam search
        tgt_pred_top_k = decode_minibatch(
            config['data']['max_len'], tgt['tok2id']['<s>'],
            model, input_lines_src, srclens, srcmask,
            input_ids_aux, auxlens, auxmask, side_info, k=k)

        # convert inputs/preds/targets/aux to human-readable form
        inputs += ids_to_toks(output_lines_src, src['id2tok'], indices)
        ground_truths += ids_to_toks(output_lines_tgt, tgt['id2tok'], indices)
        raw_srcs += ids_to_toks(raw_src, src['id2tok'], indices)

        # TODO -- refactor this block
        # get the "official" predictions from the model
        pred_toks, pred_lens = ids_to_toks(
            tgt_pred_top_k[:, :, 0], tgt['id2tok'], indices, save_cuts=True)
        preds += pred_toks

        # now get all the other top-k prediction levels
        top_k_pred = [pred_toks]
        for i in range(k - 1):
            top_k_pred.append(ids_to_toks(
                tgt_pred_top_k[:, :, i + 1], tgt['id2tok'], indices,
                cuts=pred_lens))

        # top_k_pred is [k, batch, length] where length is ragged,
        # but we want it in [batch, length, k]. Manual transpose b/c ragged :(
        batch_size = len(top_k_pred[0])  # could be variable at test time
        pred_lens = data.unsort(pred_lens, indices)
        top_k_pred_transposed = [[] for _ in range(batch_size)]
        for bi in range(batch_size):
            for ti in range(pred_lens[bi]):
                top_k_pred_transposed[bi] += [[
                    top_k_pred[ki][bi][ti] for ki in range(k)
                ]]
        top_k_preds += top_k_pred_transposed

        if config['model']['model_type'] == 'delete':
            auxs += [[str(x)] for x in input_ids_aux.data.cpu().numpy()]  # because of list comp in inference_metrics()
        elif config['model']['model_type'] == 'delete_retrieve':
            auxs += ids_to_toks(input_ids_aux, tgt['id2tok'], indices)
        elif config['model']['model_type'] == 'seq2seq':
            auxs += ['None' for _ in range(batch_size)]

    return inputs, preds, top_k_preds, ground_truths, auxs, raw_srcs
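# This variant of decode_dataset calls a module-level ids_to_toks with
# save_cuts/cuts arguments. The sketch below is an assumption, reconstructed
# from the call sites above (truncate at </s>, optionally reuse the saved cut
# lengths so every top-k level is cut to the same length as the top-1 level):
def ids_to_toks(tok_seqs, id2tok, indices, save_cuts=False, cuts=None):
    """Sketch: convert id sequences to tokens, truncating at </s>."""
    out, out_cuts = [], []
    tok_seqs = tok_seqs.cpu().numpy()
    for li, line in enumerate(tok_seqs):
        toks = [id2tok[x] for x in line]
        if '<s>' in toks:
            toks.remove('<s>')
        if cuts is not None:
            cut_idx = cuts[li]  # reuse the top-1 cut lengths
        else:
            cut_idx = toks.index('</s>') if '</s>' in toks else len(toks)
        out.append(toks[:cut_idx])
        out_cuts.append(cut_idx)
    out = data.unsort(out, indices)
    if save_cuts:
        return out, out_cuts  # cuts stay in pre-unsort order, as used above
    return out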
def train(config, working_dir):
    # load data
    src, tok_weights_dict = data.gen_train_data(
        src=config['data']['src'], tgt=config['data']['tgt'], config=config)
    src_dev, tgt_dev = data.gen_dev_data(
        src=config['data']['src_dev'], tgt=config['data']['tgt_dev'],
        tok_weights_dict=tok_weights_dict, config=config)
    logging.info('Reading data done!')

    # build model
    model = build_model(src, config)
    logging.info('MODEL HAS %s params' % model.count_params())

    # get most recent checkpoint
    model, start_epoch = attempt_load_model(model=model, checkpoint_dir=working_dir)

    # initialize loss criterion
    weight_mask = torch.ones(len(src['tok2id']))
    weight_mask[src['tok2id']['<pad>']] = 0
    loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)

    if CUDA:
        model = model.cuda()
        weight_mask = weight_mask.cuda()
        loss_criterion = loss_criterion.cuda()

    # initialize optimizer
    if config['training']['optimizer'] == 'adam':
        lr = config['training']['learning_rate']
        optimizer = optim.Adam(model.parameters(), lr=lr)
    elif config['training']['optimizer'] == 'sgd':
        lr = config['training']['learning_rate']
        optimizer = optim.SGD(model.parameters(), lr=lr)
    elif config['training']['optimizer'] == 'adadelta':
        lr = config['training']['learning_rate']
        optimizer = optim.Adadelta(model.parameters(), lr=lr)
    else:
        raise NotImplementedError("Learning method not recommended for task")

    # start training
    start_since_last_report = time.time()
    losses_since_last_report = []
    best_metric = 0.0
    cur_metric = 0.0  # log perplexity or ROUGE
    dev_loss = 0.0
    dev_rouge = 0.0
    num_batches = len(src['content']) // config['data']['batch_size']
    for epoch in range(start_epoch, config['training']['epochs']):
        if cur_metric > best_metric:
            # rm old checkpoint
            for ckpt_path in glob.glob(working_dir + '/model.*'):
                os.system("rm %s" % ckpt_path)
            # replace with new checkpoint
            torch.save(model.state_dict(), working_dir + '/model.%s.ckpt' % epoch)
            best_metric = cur_metric

        for i in range(0, len(src['content']), config['data']['batch_size']):
            batch_idx = i // config['data']['batch_size']

            # generate current training data batch
            input_content, input_aux, output = data.minibatch(
                src, src, i,
                config['data']['batch_size'], config['data']['max_len'],
                config['model']['model_type'])
            input_content_src, _, srclens, srcmask, _ = input_content
            input_ids_aux, _, auxlens, auxmask, _ = input_aux
            input_data_tgt, output_data_tgt, _, _, _ = output

            # train the model with current training data batch
            decoder_logit, decoder_probs = model(
                input_content_src, srcmask, srclens,
                input_ids_aux, auxmask, auxlens,
                input_data_tgt, mode='train')

            # setup the optimizer
            optimizer.zero_grad()
            loss = loss_criterion(
                decoder_logit.contiguous().view(-1, len(src['tok2id'])),
                output_data_tgt.view(-1))
            losses_since_last_report.append(loss.item())

            # perform backpropagation
            loss.backward()

            # clip gradients
            _ = nn.utils.clip_grad_norm_(model.parameters(), config['training']['max_norm'])

            # update model params
            optimizer.step()

            # print out the training information
            if batch_idx % config['training']['batches_per_report'] == 0:
                s = float(time.time() - start_since_last_report)
                wps = (config['data']['batch_size'] * config['training']['batches_per_report']) / s
                avg_loss = np.mean(losses_since_last_report)
                info = (epoch, batch_idx, num_batches, wps, avg_loss, dev_loss, dev_rouge)
                cur_metric = dev_rouge
                logging.info('EPOCH: %s ITER: %s/%s WPS: %.2f LOSS: %.4f DEV_LOSS: %.4f DEV_ROUGE: %.4f' % info)
                start_since_last_report = time.time()
                losses_since_last_report = []

        # evaluate the model on the entire dev set
        logging.info('EPOCH %s COMPLETE. VALIDATING...' % epoch)
        model.eval()

        # compute validation loss
        logging.info('Computing dev_loss on validation data ...')
        dev_loss = evaluation.evaluate_lpp(
            model=model, src=tgt_dev, tgt=tgt_dev, config=config)
        dev_rouge, decoded_sents = evaluation.evaluate_rouge(
            model=model, src=src_dev, tgt=tgt_dev, config=config)
        logging.info('...done!')

        # switch back to train mode
        model.train()
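# attempt_load_model is referenced in train() but not defined in this section.
# A sketch of its likely behavior (resume from the newest model.<epoch>.ckpt,
# else start at epoch 0; the naming scheme matches torch.save above):
def attempt_load_model(model, checkpoint_dir):
    """Sketch: load the most recent checkpoint and return (model, start_epoch)."""
    ckpts = glob.glob(checkpoint_dir + '/model.*.ckpt')
    if not ckpts:
        return model, 0
    # checkpoint names embed the epoch: model.<epoch>.ckpt
    latest_epoch = max(int(p.split('.')[-2]) for p in ckpts)
    model.load_state_dict(torch.load(
        checkpoint_dir + '/model.%s.ckpt' % latest_epoch))
    logging.info('Loaded checkpoint from epoch %s' % latest_epoch)
    return model, latest_epoch + 1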
# replace with new checkpoint
torch.save(model.state_dict(), working_dir + '/model.%s.ckpt' % epoch)
best_metric = cur_metric
best_epoch = epoch - 1

losses = []
for i in range(0, len(src['data']), batch_size):
    if args.test:
        continue
    if args.overfit:
        i = 0
    batch_idx = i // batch_size  # integer division; i / batch_size gave a float

    input_content, input_aux, output, side_info, _ = data.minibatch(
        src, tgt, i, batch_size, max_length, config)
    input_lines_src, _, srclens, srcmask, _ = input_content
    input_ids_aux, _, auxlens, auxmask, _ = input_aux
    input_lines_tgt, output_lines_tgt, _, _, _ = output
    side_info, _, _, _, _ = side_info

    decoder_logit, decoder_probs, side_logit, side_loss = model(
        input_lines_src, input_lines_tgt, srcmask, srclens,
        input_ids_aux, auxlens, auxmask, side_info)

    optimizer.zero_grad()
    generator_loss = loss_criterion(
        decoder_logit.contiguous().view(-1, tgt_vocab_size),
        output_lines_tgt.view(-1))
os.system("rm %s" % ckpt_path)

# replace with new checkpoint
torch.save(model.state_dict(), working_dir + '/model.%s.ckpt' % epoch)
best_metric = cur_metric
best_epoch = epoch - 1

losses = []
for i in range(0, len(src['content']), batch_size):
    if args.overfit:
        i = 50
    batch_idx = i // batch_size  # integer division; i / batch_size gave a float

    input_content, input_aux, output = data.minibatch(
        src, tgt, i, batch_size, max_length, config['model']['model_type'])
    input_lines_src, _, srclens, srcmask, _ = input_content
    input_ids_aux, _, auxlens, auxmask, _ = input_aux
    input_lines_tgt, output_lines_tgt, _, _, _ = output

    decoder_logit, decoder_probs = model(
        input_lines_src, input_lines_tgt, srcmask, srclens,
        input_ids_aux, auxlens, auxmask)

    optimizer.zero_grad()
    loss = loss_criterion(
        decoder_logit.contiguous().view(-1, tgt_vocab_size),
        output_lines_tgt.view(-1))
    losses.append(loss.item())
    losses_since_last_report.append(loss.item())
def my_decode_dataset(model, src, tgt, config):
    searcher = models.GreedySearchDecoder(model)

    rouge_list = []
    initial_inputs = []
    preds = []
    ground_truths = []
    auxs = []
    for j in range(0, len(src['data'])):
        if j % 100 == 0:
            logging.info('Finished decoding data: %d/%d ...' % (j, len(src['data'])))

        # batch_size = 1
        inputs, _, outputs = data.minibatch(
            src, tgt, j, 1, config['data']['max_len'],
            config['model']['model_type'], is_test=True)
        input_content_src, _, srclens, srcmask, _ = inputs
        _, output_data_tgt, tgtlens, tgtmask, _ = outputs

        tgt_dist_measurer = tgt['dist_measurer']
        related_content_tgt = tgt_dist_measurer.most_similar(j, n=3)  # list of n seq_str
        # related_content_tgt = source_content_str, target_content_str, target_att_str, idx, score

        # put all the retrieved attributes together
        retrieved_attrs_set = set()
        for single_data_tgt in related_content_tgt:
            sp = single_data_tgt[2].split()
            for attr in sp:
                retrieved_attrs_set.add(attr)
        retrieved_attrs = ' '.join(retrieved_attrs_set)

        input_ids_aux, auxlens, auxmask = word2id(
            retrieved_attrs, None, tgt, config['data']['max_len'])

        n_decoded_sents = []
        input_ids_aux = Variable(torch.LongTensor(input_ids_aux))
        auxlens = Variable(torch.LongTensor(auxlens))
        auxmask = Variable(torch.LongTensor(auxmask))
        if CUDA:
            input_ids_aux = input_ids_aux.cuda()
            auxlens = auxlens.cuda()
            auxmask = auxmask.cuda()

        _, decoded_data_tgt = searcher(
            input_content_src, srcmask, srclens,
            input_ids_aux, auxmask, auxlens,
            20, tgt['tok2id']['<s>'])
        decode_sent = id2word(decoded_data_tgt, tgt)
        n_decoded_sents.append(decode_sent)
        # print('Source content sentence: ' + ''.join(related_content_tgt[0][1]))
        # print('Decoded data sentence: ' + n_decoded_sents[0])

        input_sent = id2word(input_content_src, src)
        initial_inputs.append(input_sent.split())

        pred_sent = n_decoded_sents[0]
        preds.append(pred_sent.split())

        truth_sent = id2word(output_data_tgt, tgt)
        ground_truths.append(truth_sent.split())

        aux_sent = id2word(input_ids_aux, src)
        auxs.append(aux_sent.split())

        rouge_cur = rouge_2(truth_sent, pred_sent)
        rouge_list.append(rouge_cur)

    return searcher, rouge_list, initial_inputs, preds, ground_truths, auxs
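# word2id and id2word are assumed throughout this section. The sketches below
# are reconstructions from the call sites (padding to max_len and the <unk>
# token are assumptions; the mask convention may be inverted in the real code):
def word2id(sent, unused, vocab_obj, max_len):
    """Sketch: tokenize a string into padded ids plus length and mask."""
    tok2id = vocab_obj['tok2id']
    ids = [tok2id.get(w, tok2id['<unk>']) for w in sent.split()][:max_len]
    length = len(ids)
    ids = ids + [tok2id['<pad>']] * (max_len - length)
    mask = [0] * length + [1] * (max_len - length)
    # call sites wrap these in torch.LongTensor, so return nested lists
    return [ids], [length], [mask]

def id2word(id_tensor, vocab_obj):
    """Sketch: map an id tensor back to a whitespace-joined string."""
    id2tok = vocab_obj['id2tok']
    toks = [id2tok[int(x)] for x in id_tensor.view(-1).cpu().numpy()]
    toks = [t for t in toks if t not in ('<s>', '<pad>')]
    if '</s>' in toks:
        toks = toks[:toks.index('</s>')]
    return ' '.join(toks)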