def valid_model(args, model, dev, dev_metrics=None, distillation=False,
                print_out=False, U=None, beam=1, alpha=0.6):
    print_seqs = ['[sources]', '[targets]', '[decoded]', '[fertili]', '[origind]']
    src_outputs, trg_outputs, dec_outputs = [], [], []
    outputs = {}

    model.eval()
    progressbar = tqdm(total=sum(1 for _ in dev), desc='start decoding for validation...')

    for j, dev_batch in enumerate(dev):
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks, \
        encoding, batch_size = model.quick_prepare(dev_batch, distillation, U=U)

        decoder_inputs, input_reorder, fertility_cost = inputs, None, None
        if type(model) is FastTransformer:
            decoder_inputs, input_reorder, decoder_masks, fertility_cost, pred_fertility = \
                model.prepare_initial(encoding, sources, source_masks, input_masks,
                                      None, mode='argmax')
        else:
            decoder_masks = input_masks

        decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks,
                                     decoding=True, return_probs=True, beam=beam, alpha=alpha)

        dev_outputs = [('src', sources), ('trg', targets), ('trg', decoding)]
        dev_outputs = [model.output_decoding(d) for d in dev_outputs]

        if print_out and (j < 5):
            for k, d in enumerate(dev_outputs):
                args.logger.info("{}: {}".format(print_seqs[k], d[0]))
            args.logger.info('------------------------------------------------------------------')

        src_outputs += dev_outputs[0]
        trg_outputs += dev_outputs[1]
        dec_outputs += dev_outputs[2]

        if dev_metrics is not None:
            values = [0, 0]
            dev_metrics.accumulate(batch_size, *values)

        info = 'Validation: decoding step={}'.format(j + 1)
        progressbar.update(1)
        progressbar.set_description(info)
    progressbar.close()

    corpus_bleu = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
    outputs['corpus_bleu'] = corpus_bleu
    # tuple(a, b, c) raises TypeError; build the 3-tuple directly
    outputs['dev_output'] = (src_outputs, trg_outputs, dec_outputs)
    if dev_metrics is not None:
        args.logger.info(dev_metrics)
    args.logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu))
    return outputs
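# `computeBLEU` (and `tokenizer`) are imported from elsewhere in the repo and not
# shown in this file. A minimal sketch of the assumed interface, backed by NLTK;
# the repo's own implementation may differ in smoothing and return type:
from nltk.translate.bleu_score import corpus_bleu as nltk_corpus_bleu
from nltk.translate.bleu_score import sentence_bleu as nltk_sentence_bleu

def computeBLEU_sketch(hypotheses, references, corpus=True, tokenizer=str.split):
    hyps = [tokenizer(h) for h in hypotheses]
    refs = [[tokenizer(r)] for r in references]  # single reference per sentence
    if corpus:
        return nltk_corpus_bleu(refs, hyps)
    # per-sentence scores (used by the oracle adaptive-decoding criterion below)
    return [nltk_sentence_bleu(r, h) for r, h in zip(refs, hyps)]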
def validation_epoch_end(self, val_step_outputs):
    # e.g. config.sentences = ["An old man trying to get up from a broken chair",
    #                          "A man wearing red shirt sitting under a tree"]
    print("Translation Sample =================")
    for sentence in config.sentences:
        if not config.USE_BPE:
            translated_sentence = translate_sentence(
                self, sentence, self.german_vocab, self.english_vocab,
                self.deviceLegacy, max_length=50)
        else:
            translated_sentence = translate_sentence_bpe(
                self, sentence, self.german_vocab, self.english_vocab,
                self.deviceLegacy, max_length=50)
        print("Output", translated_sentence)

    if config.COMPUTE_BLEU and self.nepochs > 0:
        bleu_score = computeBLEU(self.test_data, self, self.german_vocab,
                                 self.english_vocab, self.deviceLegacy)
        self.bleu_scores.append(bleu_score)
        print("BLEU score: ", bleu_score)
        writeArrToCSV(self.bleu_scores)  # dump the BLEU history every epoch
    return
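# `writeArrToCSV` is called above but defined elsewhere. A sketch assuming it
# dumps the per-epoch BLEU history to a CSV file (the filename is hypothetical):
import csv

def writeArrToCSV_sketch(arr, path="bleu_scores.csv"):
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "bleu"])
        for epoch, score in enumerate(arr):
            writer.writerow([epoch, score])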
def valid_model(args, model, dev, dev_metrics=None, distillation=False,
                print_out=False, teacher_model=None):
    print_seqs = ['[sources]', '[targets]', '[decoded]', '[fertili]', '[origind]']
    trg_outputs, dec_outputs = [], []
    outputs = {}

    model.eval()
    if teacher_model is not None:
        teacher_model.eval()

    for j, dev_batch in enumerate(dev):
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks, \
        encoding, batch_size = model.quick_prepare(dev_batch, distillation)

        decoder_inputs, input_reorder, fertility_cost = inputs, None, None
        if type(model) is FastTransformer:
            decoder_inputs, input_reorder, decoder_masks, fertility_cost, pred_fertility = \
                model.prepare_initial(encoding, sources, source_masks, input_masks,
                                      None, mode='argmax')
        else:
            decoder_masks = input_masks

        decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks,
                                     decoding=True, return_probs=True)

        dev_outputs = [('src', sources), ('trg', targets), ('trg', decoding)]
        if type(model) is FastTransformer:
            dev_outputs += [('src', input_reorder)]
        dev_outputs = [model.output_decoding(d) for d in dev_outputs]

        gleu = computeGLEU(dev_outputs[2], dev_outputs[1], corpus=False, tokenizer=tokenizer)

        if print_out:
            for k, d in enumerate(dev_outputs):
                args.logger.info("{}: {}".format(print_seqs[k], d[0]))
            args.logger.info('------------------------------------------------------------------')

        if teacher_model is not None:  # teacher is Transformer, student is FastTransformer
            inputs_student, _, targets_student, _, _, _, encoding_teacher, _ = \
                teacher_model.quick_prepare(dev_batch, False, decoding, decoding,
                                            input_masks, target_masks, source_masks)
            teacher_real_loss = teacher_model.cost(
                targets, target_masks,
                out=teacher_model(encoding_teacher, source_masks, inputs, input_masks))
            teacher_fake_out = teacher_model(encoding_teacher, source_masks,
                                             inputs_student, input_masks)
            teacher_fake_loss = teacher_model.cost(targets_student, target_masks,
                                                   out=teacher_fake_out)
            teacher_alter_loss = teacher_model.cost(targets, target_masks,
                                                    out=teacher_fake_out)

        trg_outputs += dev_outputs[1]
        dec_outputs += dev_outputs[2]

        if dev_metrics is not None:
            values = [0, gleu]
            if teacher_model is not None:
                values += [teacher_real_loss, teacher_fake_loss,
                           teacher_real_loss - teacher_fake_loss,
                           teacher_alter_loss,
                           teacher_alter_loss - teacher_fake_loss]
            if fertility_cost is not None:
                values += [fertility_cost]
            dev_metrics.accumulate(batch_size, *values)

    corpus_gleu = computeGLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
    corpus_bleu = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
    outputs['corpus_gleu'] = corpus_gleu
    outputs['corpus_bleu'] = corpus_bleu
    if dev_metrics is not None:
        args.logger.info(dev_metrics)
    args.logger.info("The dev-set corpus GLEU = {}".format(corpus_gleu))
    args.logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu))
    return outputs
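# `computeGLEU` mirrors `computeBLEU` above; a sketch assuming NLTK's GLEU
# scorer (the repo's own helper may differ):
from nltk.translate.gleu_score import corpus_gleu, sentence_gleu

def computeGLEU_sketch(hypotheses, references, corpus=True, tokenizer=str.split):
    hyps = [tokenizer(h) for h in hypotheses]
    refs = [[tokenizer(r)] for r in references]
    if corpus:
        return corpus_gleu(refs, hyps)
    return [sentence_gleu(r, h) for r, h in zip(refs, hyps)]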
def valid_model(args, model, dev, dev_metrics=None, dev_metrics_trg=None,
                dev_metrics_average=None, print_out=False, teacher_model=None,
                trg_len_dic=None):
    print_seq = (['REF '] if args.dataset == "mscoco" else ['SRC ', 'REF ']) \
        + ['HYP{}'.format(ii + 1) for ii in range(args.valid_repeat_dec)]

    trg_outputs = []
    real_all_outputs = [[] for ii in range(args.valid_repeat_dec)]
    short_all_outputs = [[] for ii in range(args.valid_repeat_dec)]
    outputs_data = {}

    model.eval()
    for j, dev_batch in enumerate(dev):
        if args.dataset == "mscoco":
            # only use the first caption for calculating log-likelihood
            all_captions = dev_batch[1]
            dev_batch[1] = dev_batch[1][0]
            decoder_inputs, decoder_masks, \
            targets, target_masks, \
            _, source_masks, \
            encoding, batch_size, rest = model.quick_prepare_mscoco(
                dev_batch, all_captions=all_captions,
                fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec,
                trg_len_option=args.trg_len_option, max_len=args.max_offset,
                trg_len_dic=trg_len_dic, bp=args.bp)
        else:
            decoder_inputs, decoder_masks, \
            targets, target_masks, \
            sources, source_masks, \
            encoding, batch_size, rest = model.quick_prepare(
                dev_batch, fast=(type(model) is FastTransformer),
                trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio,
                trg_len_dic=trg_len_dic, bp=args.bp)

        losses, all_decodings = [], []
        if type(model) is Transformer:
            decoding, out, probs = model(encoding, source_masks, decoder_inputs,
                                         decoder_masks, beam=1, decoding=True,
                                         return_probs=True)
            loss = model.cost(targets, target_masks, out=out)
            losses.append(loss)
            all_decodings.append(decoding)
        elif type(model) is FastTransformer:
            for iter_ in range(args.valid_repeat_dec):
                curr_iter = min(iter_, args.num_decs - 1)
                next_iter = min(curr_iter + 1, args.num_decs - 1)

                decoding, out, probs = model(encoding, source_masks, decoder_inputs,
                                             decoder_masks, decoding=True,
                                             return_probs=True, iter_=curr_iter)
                loss = model.cost(targets, target_masks, out=out, iter_=curr_iter)
                losses.append(loss)
                all_decodings.append(decoding)

                decoder_inputs = 0
                if args.next_dec_input in ["both", "emb"]:
                    _, argmax = torch.max(probs, dim=-1)
                    emb = F.embedding(argmax, model.decoder[next_iter].out.weight *
                                      math.sqrt(args.d_model))
                    decoder_inputs += emb
                if args.next_dec_input in ["both", "out"]:
                    decoder_inputs += out

        if args.dataset == "mscoco":
            # make sure there are 5 captions per example
            num_captions = len(all_captions[0])
            for c in range(1, len(all_captions)):
                assert num_captions == len(all_captions[c])
            # untokenize the reference captions
            for n_ref in range(len(all_captions)):
                n_caps = len(all_captions[0])
                for c in range(n_caps):
                    all_captions[n_ref][c] = all_captions[n_ref][c].replace("@@ ", "")
            src_ref = [list(map(list, zip(*all_captions)))]
        else:
            src_ref = [model.output_decoding(d)
                       for d in [('src', sources), ('trg', targets)]]

        real_outputs = [model.output_decoding(d)
                        for d in [('trg', xx) for xx in all_decodings]]

        if print_out:
            if args.dataset != "mscoco":
                for k, d in enumerate(src_ref + real_outputs):
                    args.logger.info("{} ({}): {}".format(
                        print_seq[k], len(d[0].split(" ")), d[0]))
            else:
                for k in range(len(all_captions[0])):
                    for c in range(len(all_captions)):
                        args.logger.info("REF ({}): {}".format(
                            len(all_captions[c][k].split(" ")), all_captions[c][k]))
                    for c in range(len(real_outputs)):
                        args.logger.info("HYP {} ({}): {}".format(
                            c + 1, len(real_outputs[c][k].split(" ")),
                            real_outputs[c][k]))
            args.logger.info('------------------------------------------------------------------')

        trg_outputs += src_ref[-1]
        for ii, d_outputs in enumerate(real_outputs):
            real_all_outputs[ii] += d_outputs

        if dev_metrics is not None:
            dev_metrics.accumulate(batch_size, *losses)
        if dev_metrics_trg is not None:
            dev_metrics_trg.accumulate(batch_size, *[rest[0], rest[1], rest[2]])
        if dev_metrics_average is not None:
            dev_metrics_average.accumulate(batch_size, *[rest[3], rest[4]])

    if args.dataset != "mscoco":
        real_bleu = [computeBLEU(ith_output, trg_outputs, corpus=True, tokenizer=tokenizer)
                     for ith_output in real_all_outputs]
    else:
        real_bleu = [computeBLEUMSCOCO(ith_output, trg_outputs, corpus=True, tokenizer=tokenizer)
                     for ith_output in real_all_outputs]
    outputs_data['real'] = real_bleu

    if "predict" in args.trg_len_option:
        outputs_data['pred_target_len_loss'] = getattr(dev_metrics_trg, 'pred_target_len_loss')
        outputs_data['pred_target_len_correct'] = getattr(dev_metrics_trg, 'pred_target_len_correct')
        outputs_data['pred_target_len_approx'] = getattr(dev_metrics_trg, 'pred_target_len_approx')
        outputs_data['average_target_len_correct'] = getattr(dev_metrics_average, 'average_target_len_correct')
        outputs_data['average_target_len_approx'] = getattr(dev_metrics_average, 'average_target_len_approx')

    if dev_metrics is not None:
        args.logger.info(dev_metrics)
    if dev_metrics_trg is not None:
        args.logger.info(dev_metrics_trg)
    if dev_metrics_average is not None:
        args.logger.info(dev_metrics_average)

    for idx in range(args.valid_repeat_dec):
        print_str = "iter {} | {}".format(idx + 1,
                                          print_bleu(real_bleu[idx], verbose=False))
        args.logger.info(print_str)

    return outputs_data
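# `print_bleu` formats a BLEU result for logging. Its exact input shape depends
# on what `computeBLEU` returns; a sketch assuming a (score, precisions,
# brevity_penalty) tuple, all hypothetical:
def print_bleu_sketch(bleu_output, verbose=True):
    if isinstance(bleu_output, (int, float)):
        msg = "BLEU = {:.2f}".format(100 * bleu_output)
    else:
        score, precisions, bp = bleu_output
        msg = "BLEU = {:.2f} ({}) BP = {:.3f}".format(
            100 * score,
            "/".join("{:.1f}".format(100 * p) for p in precisions), bp)
    if verbose:
        print(msg)
    return msg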
def decode_model(args, model, dev, evaluate=True, decoding_path=None, names=None, maxsteps=None):
    args.logger.info("decoding, f_size={}, beam_size={}, alpha={}".format(
        args.f_size, args.beam_size, args.alpha))
    dev.train = False  # make iterator volatile=True

    if maxsteps is None:
        progressbar = tqdm(total=sum(1 for _ in dev), desc='start decoding')
    else:
        progressbar = tqdm(total=maxsteps, desc='start decoding')

    model.eval()
    if decoding_path is not None:
        handles = [open(os.path.join(decoding_path, name), 'w') for name in names]

    corpus_size = 0
    src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
    decoded_words, target_words, decoded_info = 0, 0, 0

    attentions = None
    pad_id = model.decoder[0].field.vocab.stoi['<pad>']
    eos_id = model.decoder[0].field.vocab.stoi['<eos>']

    curr_time = 0
    cum_bs = 0
    for iters, dev_batch in enumerate(dev):
        # guard: `iters > None` raises TypeError in Python 3
        if maxsteps is not None and iters > maxsteps:
            args.logger.info('complete {} steps of decoding'.format(maxsteps))
            break

        start_t = time.time()

        # encoding
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks, \
        encoding, batch_size = model.quick_prepare(dev_batch)
        cum_bs += batch_size

        # for now
        if type(model) is Transformer:
            all_decodings = []
            decoder_inputs, decoder_masks = inputs, input_masks
            decoding = model(encoding, source_masks, decoder_inputs, decoder_masks,
                             beam=args.beam_size, alpha=args.alpha,
                             decoding=True, feedback=attentions)
            all_decodings.append(decoding)
        elif type(model) is FastTransformer:
            decoder_inputs, _, decoder_masks = \
                model.prepare_initial(encoding, sources, source_masks, input_masks,
                                      N=args.f_size)
            batch_size, src_len, hsize = encoding[0].size()

            all_decodings = []
            prev_dec_output = None
            iter_ = 0
            while True:
                iter_num = min(iter_, args.num_shared_dec - 1)
                next_iter = min(iter_ + 1, args.num_shared_dec - 1)

                decoding, out, probs = model(encoding, source_masks, decoder_inputs,
                                             decoder_masks, decoding=True,
                                             return_probs=True, iter_=iter_num)
                all_decodings.append(decoding)

                thedecoder = model.decoder[iter_num]
                logits = thedecoder.out(out)
                _, argmax = torch.max(logits, dim=-1)

                decoder_inputs = F.embedding(argmax, model.decoder[next_iter].out.weight *
                                             math.sqrt(args.d_model))
                if args.sum_out_and_emb:
                    decoder_inputs += out

                iter_ += 1
                if iter_ == args.valid_repeat_dec:
                    break

        used_t = time.time() - start_t
        curr_time += used_t

        real_mask = 1 - ((decoding.data == eos_id) + (decoding.data == pad_id)).float()
        outputs = [model.output_decoding(d)
                   for d in [('src', sources), ('trg', targets), ('trg', decoding)]]
        all_dec_outputs = [model.output_decoding(d)
                           for d in [('trg', all_decodings[ii])
                                     for ii in range(len(all_decodings))]]

        corpus_size += batch_size
        src_outputs += outputs[0]
        trg_outputs += outputs[1]
        dec_outputs += outputs[-1]
        timings += [used_t]

        if decoding_path is not None:
            for s, t, d in zip(outputs[0], outputs[1], outputs[2]):
                s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '')
                print(s, file=handles[0], flush=True)
                print(t, file=handles[1], flush=True)
                print(d, file=handles[2], flush=True)

        print(curr_time / float(cum_bs) * 1000)  # running average: ms per sentence
        # progressbar.update(1)
        # progressbar.set_description('finishing sentences={}/batches={}, speed={} sec/batch'
        #                             .format(corpus_size, iters, curr_time / (1 + iters)))

    if evaluate:
        corpus_bleu = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
        print("The dev-set corpus BLEU = {}".format(corpus_bleu))
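# The `.replace('@@ ', '')` calls above undo BPE segmentation before writing
# decoded text to disk; a tiny illustration:
def undo_bpe(s):
    # 'a new@@ spaper' -> 'a newspaper'
    return s.replace('@@ ', '')

assert undo_bpe('a new@@ spaper') == 'a newspaper'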
def run_fast_transformer(decoder_inputs, decoder_masks,
                         sources, source_masks,
                         targets,
                         encoding,
                         model, args, use_argmax=True):
    trg_unidx = model.output_decoding(('trg', targets))

    batch_size, src_len, hsize = encoding[0].size()

    all_decodings = []
    all_probs = []
    iter_ = 0

    bleu_hist = [[] for xx in range(batch_size)]
    output_hist = [[] for xx in range(batch_size)]
    multiset_hist = [[] for xx in range(batch_size)]
    num_iters = [0 for xx in range(batch_size)]
    done_ = [False for xx in range(batch_size)]
    final_decoding = [None for xx in range(batch_size)]

    while True:
        curr_iter = min(iter_, args.num_decs - 1)
        next_iter = min(iter_ + 1, args.num_decs - 1)

        decoding, out, probs = model(encoding, source_masks, decoder_inputs,
                                     decoder_masks, decoding=True,
                                     return_probs=True, iter_=curr_iter)

        dec_output = decoding.data.cpu().numpy().tolist()

        if args.adaptive_decoding == "oracle":
            out_unidx = model.output_decoding(('trg', decoding))
            sentence_bleus = computeBLEU(out_unidx, trg_unidx, corpus=False,
                                         tokenizer=tokenizer)

            for bidx in range(batch_size):
                output_hist[bidx].append(dec_output[bidx])
                bleu_hist[bidx].append(sentence_bleus[bidx])

            converged = oracle_converged(bleu_hist, num_items=args.adaptive_window)
            for bidx in range(batch_size):
                if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0:
                    num_iters[bidx] = iter_ + 1 - (args.adaptive_window - 1)
                    done_[bidx] = True
                    final_decoding[bidx] = output_hist[bidx][-args.adaptive_window]

        elif args.adaptive_decoding == "equality":
            for bidx in range(batch_size):
                output_hist[bidx].append(dec_output[bidx])

            converged = equality_converged(output_hist, num_items=args.adaptive_window)
            for bidx in range(batch_size):
                if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0:
                    num_iters[bidx] = iter_ + 1
                    done_[bidx] = True
                    final_decoding[bidx] = output_hist[bidx][-1]

        elif args.adaptive_decoding == "jaccard":
            for bidx in range(batch_size):
                output_hist[bidx].append(dec_output[bidx])
                multiset_hist[bidx].append(Multiset(dec_output[bidx]))

            converged = jaccard_converged(multiset_hist, num_items=args.adaptive_window)
            for bidx in range(batch_size):
                if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0:
                    num_iters[bidx] = iter_ + 1
                    done_[bidx] = True
                    final_decoding[bidx] = output_hist[bidx][-1]

        all_decodings.append(decoding)
        all_probs.append(probs)

        decoder_inputs = 0
        if args.next_dec_input in ["both", "emb"]:
            if use_argmax:
                _, argmax = torch.max(probs, dim=-1)
            else:
                probs_sz = probs.size()
                probs_ = Variable(probs.data, requires_grad=False)
                argmax = torch.multinomial(probs_.contiguous().view(-1, probs_sz[-1]),
                                           1).view(*probs_sz[:-1])
            emb = F.embedding(argmax, model.decoder[next_iter].out.weight *
                              math.sqrt(args.d_model))
            decoder_inputs += emb

        if args.next_dec_input in ["both", "out"]:
            decoder_inputs += out

        iter_ += 1
        if iter_ == args.valid_repeat_dec or (False not in done_):
            break

    if args.adaptive_decoding is not None:
        for bidx in range(batch_size):
            if num_iters[bidx] == 0:
                num_iters[bidx] = 20  # never converged: charge the full iteration budget
            if final_decoding[bidx] is None:
                if args.adaptive_decoding == "oracle":
                    final_decoding[bidx] = output_hist[bidx][np.argmax(bleu_hist[bidx])]
                else:
                    final_decoding[bidx] = output_hist[bidx][-1]

        decoding = Variable(torch.LongTensor(np.array(final_decoding)))
        if decoder_masks.is_cuda:
            decoding = decoding.cuda()

    return decoding, all_decodings, num_iters, all_probs
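# `oracle_converged`, `equality_converged`, and `jaccard_converged` implement
# the three adaptive stopping criteria used above but live elsewhere in the
# repo. Minimal sketches consistent with the call sites (each takes per-sentence
# histories plus a window size and returns one bool per sentence); the real
# implementations may differ in details:
def equality_converged_sketch(output_hist, num_items=2):
    # stop when the last `num_items` decodings are exactly identical
    return [len(h) >= num_items and all(x == h[-1] for x in h[-num_items:])
            for h in output_hist]

def oracle_converged_sketch(bleu_hist, num_items=2):
    # stop when sentence-BLEU has not improved over the last window
    return [len(h) >= num_items and all(b <= h[-num_items] for b in h[-num_items + 1:])
            for h in bleu_hist]

def jaccard_converged_sketch(multiset_hist, num_items=2, threshold=1.0):
    # stop when consecutive decodings are near-identical as token multisets
    def jaccard(a, b):
        union = len(a | b)
        return len(a & b) / union if union else 1.0
    return [len(h) >= num_items and
            all(jaccard(h[-i - 1], h[-i - 2]) >= threshold
                for i in range(num_items - 1))
            for h in multiset_hist]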
def decode_model(args, model, dev, evaluate=True, trg_len_dic=None,
                 decoding_path=None, names=None, maxsteps=None):
    args.logger.info("decoding, f_size={}, beam_size={}, alpha={}".format(
        args.f_size, args.beam_size, args.alpha))
    dev.train = False  # make iterator volatile=True

    if not args.no_tqdm:
        progressbar = tqdm(total=200, desc='start decoding')

    model.eval()
    if not args.debug:
        decoding_path.mkdir(parents=True, exist_ok=True)
        handles = [(decoding_path / name).open('w') for name in names]

    corpus_size = 0
    src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
    all_decs = [[] for idx in range(args.valid_repeat_dec)]
    decoded_words, target_words, decoded_info = 0, 0, 0

    attentions = None
    decoder = model.decoder[0] if args.model is FastTransformer else model.decoder
    pad_id = decoder.field.vocab.stoi['<pad>']
    eos_id = decoder.field.vocab.stoi['<eos>']

    curr_time = 0
    cum_sentences = 0
    cum_tokens = 0
    cum_images = 0  # used for mscoco
    num_iters_total = []

    for iters, dev_batch in enumerate(dev):
        start_t = time.time()

        if args.dataset != "mscoco":
            decoder_inputs, decoder_masks, \
            targets, target_masks, \
            sources, source_masks, \
            encoding, batch_size, rest = model.quick_prepare(
                dev_batch, fast=(type(model) is FastTransformer),
                trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio,
                trg_len_dic=trg_len_dic, bp=args.bp)
        else:
            # only use the first caption for calculating log-likelihood
            all_captions = dev_batch[1]
            dev_batch[1] = dev_batch[1][0]
            decoder_inputs, decoder_masks, \
            targets, target_masks, \
            _, source_masks, \
            encoding, batch_size, rest = model.quick_prepare_mscoco(
                dev_batch, all_captions=all_captions,
                fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec,
                trg_len_option=args.trg_len_option, max_len=args.max_len,
                trg_len_dic=trg_len_dic, bp=args.bp, gpu=args.gpu > -1)
            sources = None

        cum_sentences += batch_size
        batch_size, src_len, hsize = encoding[0].size()

        # for now
        if type(model) is Transformer:
            all_decodings = []
            decoding = model(encoding, source_masks, decoder_inputs, decoder_masks,
                             beam=args.beam_size, alpha=args.alpha,
                             decoding=True, feedback=attentions)
            all_decodings.append(decoding)
            num_iters = [0]
        elif type(model) is FastTransformer:
            decoding, all_decodings, num_iters, argmax_all_probs = \
                run_fast_transformer(decoder_inputs, decoder_masks,
                                     sources, source_masks, targets,
                                     encoding, model, args, use_argmax=True)
            num_iters_total.extend(num_iters)

            if not args.use_argmax:
                for _ in range(args.num_samples):
                    # `targets` was missing from this call; run_fast_transformer requires it
                    _, _, _, sampled_all_probs = \
                        run_fast_transformer(decoder_inputs, decoder_masks,
                                             sources, source_masks, targets,
                                             encoding, model, args, use_argmax=False)
                    for iter_ in range(args.valid_repeat_dec):
                        argmax_all_probs[iter_] = argmax_all_probs[iter_] + sampled_all_probs[iter_]

                all_decodings = []
                for iter_ in range(args.valid_repeat_dec):
                    argmax_all_probs[iter_] = argmax_all_probs[iter_] / args.num_samples
                    all_decodings.append(torch.max(argmax_all_probs[iter_], dim=-1)[-1])
                decoding = all_decodings[-1]

        used_t = time.time() - start_t
        curr_time += used_t

        if args.dataset != "mscoco":
            if args.remove_repeats:
                outputs_unidx = [model.output_decoding(d)
                                 for d in [('src', sources), ('trg', targets),
                                           ('trg', remove_repeats_tensor(decoding))]]
            else:
                outputs_unidx = [model.output_decoding(d)
                                 for d in [('src', sources), ('trg', targets),
                                           ('trg', decoding)]]
        else:
            # make sure there are 5 captions per example
            num_captions = len(all_captions[0])
            for c in range(1, len(all_captions)):
                assert num_captions == len(all_captions[c])
            # untokenize the reference captions
            for n_ref in range(len(all_captions)):
                n_caps = len(all_captions[0])
                for c in range(n_caps):
                    all_captions[n_ref][c] = all_captions[n_ref][c].replace("@@ ", "")
            outputs_unidx = [list(map(list, zip(*all_captions)))]

        if args.remove_repeats:
            all_dec_outputs = [model.output_decoding(d)
                               for d in [('trg', remove_repeats_tensor(all_decodings[ii]))
                                         for ii in range(len(all_decodings))]]
        else:
            all_dec_outputs = [model.output_decoding(d)
                               for d in [('trg', all_decodings[ii])
                                         for ii in range(len(all_decodings))]]

        corpus_size += batch_size
        if args.dataset != "mscoco":
            # NOTE: source tokens, not target
            cum_tokens += sum([len(xx.split(" ")) for xx in outputs_unidx[0]])

        if args.dataset != "mscoco":
            src_outputs += outputs_unidx[0]
            trg_outputs += outputs_unidx[1]
            if args.remove_repeats:
                dec_outputs += remove_repeats(outputs_unidx[-1])
            else:
                dec_outputs += outputs_unidx[-1]
        else:
            trg_outputs += outputs_unidx[0]

        for idx, each_output in enumerate(all_dec_outputs):
            if args.remove_repeats:
                all_decs[idx] += remove_repeats(each_output)
            else:
                all_decs[idx] += each_output

        timings += [used_t]

        if not args.debug:
            for s, t, d in zip(outputs_unidx[0], outputs_unidx[1], outputs_unidx[2]):
                s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '')
                print(s, file=handles[0], flush=True)
                print(t, file=handles[1], flush=True)
                print(d, file=handles[2], flush=True)

        if not args.no_tqdm:
            progressbar.update(iters)
            progressbar.set_description(
                'finishing sentences={}/batches={}, length={}/average iter={}, '
                'speed={} sec/batch'.format(corpus_size, iters, src_len,
                                            np.mean(np.array(num_iters)),
                                            curr_time / (1 + iters)))

    if evaluate:
        for idx, each_dec in enumerate(all_decs):
            if len(all_decs[idx]) != len(trg_outputs):
                break
            if args.dataset != "mscoco":
                bleu_output = computeBLEU(each_dec, trg_outputs, corpus=True,
                                          tokenizer=tokenizer)
            else:
                bleu_output = computeBLEUMSCOCO(each_dec, trg_outputs, corpus=True,
                                                tokenizer=tokenizer)
            args.logger.info("iter {} | {}".format(idx + 1, print_bleu(bleu_output)))

        if args.adaptive_decoding is not None:
            args.logger.info("----------------------------------------------")
            args.logger.info("Average # iters {}".format(np.mean(num_iters_total)))
            bleu_output = computeBLEU(dec_outputs, trg_outputs, corpus=True,
                                      tokenizer=tokenizer)
            args.logger.info("Adaptive BLEU | {}".format(print_bleu(bleu_output)))

    args.logger.info("----------------------------------------------")
    args.logger.info("Decoding speed analysis :")
    args.logger.info("{} sentences".format(cum_sentences))
    if args.dataset != "mscoco":
        args.logger.info("{} tokens".format(cum_tokens))
    args.logger.info("{:.3f} seconds".format(curr_time))
    args.logger.info("{:.3f} ms / sentence".format(curr_time / float(cum_sentences) * 1000))
    if args.dataset != "mscoco":
        args.logger.info("{:.3f} ms / token".format(curr_time / float(cum_tokens) * 1000))
    args.logger.info("{:.3f} sentences / s".format(float(cum_sentences) / curr_time))
    if args.dataset != "mscoco":
        args.logger.info("{:.3f} tokens / s".format(float(cum_tokens) / curr_time))
    args.logger.info("----------------------------------------------")

    if args.decode_which > 0:
        args.logger.info("Writing to special file")
        parent = decoding_path / "speed" / "b_{}{}".format(
            args.beam_size if args.model is Transformer else args.valid_repeat_dec,
            "" if args.model is Transformer else "_{}".format(args.adaptive_decoding is not None))
        args.logger.info(str(parent))
        parent.mkdir(parents=True, exist_ok=True)
        speed_handle = (parent / "results.{}".format(args.decode_which)).open('w')

        print("----------------------------------------------", file=speed_handle, flush=True)
        print("Decoding speed analysis :", file=speed_handle, flush=True)
        print("{} sentences".format(cum_sentences), file=speed_handle, flush=True)
        if args.dataset != "mscoco":
            print("{} tokens".format(cum_tokens), file=speed_handle, flush=True)
        print("{:.3f} seconds".format(curr_time), file=speed_handle, flush=True)
        print("{:.3f} ms / sentence".format(curr_time / float(cum_sentences) * 1000),
              file=speed_handle, flush=True)
        if args.dataset != "mscoco":
            print("{:.3f} ms / token".format(curr_time / float(cum_tokens) * 1000),
                  file=speed_handle, flush=True)
        print("{:.3f} sentences / s".format(float(cum_sentences) / curr_time),
              file=speed_handle, flush=True)
        if args.dataset != "mscoco":
            print("{:.3f} tokens / s".format(float(cum_tokens) / curr_time),
                  file=speed_handle, flush=True)
        print("----------------------------------------------", file=speed_handle, flush=True)
def decode_model(args, model, dev, teacher_model=None, evaluate=True,
                 decoding_path=None, names=None, maxsteps=None):
    args.logger.info("decoding with {}, f_size={}, beam_size={}, alpha={}".format(
        args.decode_mode, args.f_size, args.beam_size, args.alpha))
    dev.train = False  # make iterator volatile=True

    if maxsteps is None:
        progressbar = tqdm(total=sum(1 for _ in dev), desc='start decoding')
    else:
        progressbar = tqdm(total=maxsteps, desc='start decoding')

    model.eval()
    if teacher_model is not None:
        assert args.f_size * args.beam_size > 1, 'multiple samples are essential.'
        teacher_model.eval()
    if decoding_path is not None:
        handles = [open(os.path.join(decoding_path, name), 'w') for name in names]

    corpus_size = 0
    src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
    decoded_words, target_words, decoded_info = 0, 0, 0

    attentions = None
    pad_id = model.decoder.field.vocab.stoi['<pad>']
    eos_id = model.decoder.field.vocab.stoi['<eos>']

    curr_time = 0
    for iters, dev_batch in enumerate(dev):
        # guard: `iters > None` raises TypeError in Python 3
        if maxsteps is not None and iters > maxsteps:
            args.logger.info('complete {} steps of decoding'.format(maxsteps))
            break

        start_t = time.time()

        # encoding
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks, \
        encoding, batch_size = model.quick_prepare(dev_batch)

        if args.model is Transformer:
            # decoding from the Transformer
            decoder_inputs, decoder_masks = inputs, input_masks
            decoding = model(encoding, source_masks, decoder_inputs, decoder_masks,
                             beam=args.beam_size, alpha=args.alpha,
                             decoding=True, feedback=attentions)
        else:
            # decoding from the FastTransformer
            if teacher_model is not None:
                encoding_teacher = teacher_model.encoding(sources, source_masks)

            decoder_inputs, input_reorder, decoder_masks, _, fertility = \
                model.prepare_initial(encoding, sources, source_masks, input_masks,
                                      None, mode=args.decode_mode, N=args.f_size)

            batch_size, src_len, hsize = encoding[0].size()
            trg_len = targets.size(1)

            if args.f_size > 1:
                source_masks = source_masks[:, None, :].expand(batch_size, args.f_size, src_len)
                source_masks = source_masks.contiguous().view(batch_size * args.f_size, src_len)
                for i in range(len(encoding)):
                    encoding[i] = encoding[i][:, None, :].expand(
                        batch_size, args.f_size, src_len, hsize).contiguous().view(
                        batch_size * args.f_size, src_len, hsize)

            decoding = model(encoding, source_masks, decoder_inputs, decoder_masks,
                             beam=args.beam_size, decoding=True, feedback=attentions)
            total_size = args.beam_size * args.f_size

            if total_size > 1:
                if args.beam_size > 1:
                    source_masks = source_masks[:, None, :].expand(
                        batch_size * args.f_size, args.beam_size, src_len).contiguous().view(
                        batch_size * total_size, src_len)
                    fertility = fertility[:, None, :].expand(
                        batch_size * args.f_size, args.beam_size, src_len).contiguous().view(
                        batch_size * total_size, src_len)

                if teacher_model is not None:
                    # use the teacher model to re-rank the translations
                    decoder_masks = teacher_model.prepare_masks(decoding)
                    for i in range(len(encoding_teacher)):
                        encoding_teacher[i] = encoding_teacher[i][:, None, :].expand(
                            batch_size, total_size, src_len, hsize).contiguous().view(
                            batch_size * total_size, src_len, hsize)

                    student_inputs, _ = teacher_model.prepare_inputs(
                        dev_batch, decoding, decoder_masks)
                    student_targets, _ = teacher_model.prepare_targets(
                        dev_batch, decoding, decoder_masks)
                    out, probs = teacher_model(encoding_teacher, source_masks,
                                               student_inputs, decoder_masks,
                                               return_probs=True, decoding=False)
                    _, teacher_loss = model.batched_cost(
                        student_targets, decoder_masks, probs, batched=True)  # student loss (MLE)

                    # re-rank the translations by length-penalized teacher loss
                    teacher_loss = teacher_loss.view(batch_size, total_size)
                    decoding = decoding.view(batch_size, total_size, -1)
                    fertility = fertility.view(batch_size, total_size, -1)

                    lp = decoder_masks.sum(1).view(batch_size, total_size) ** (1 - args.alpha)
                    teacher_loss = teacher_loss * Variable(lp)

                    # selected index
                    selected_idx = (-teacher_loss).topk(1, 1)[1]  # batch x 1
                    decoding = decoding.gather(
                        1, selected_idx[:, :, None].expand(batch_size, 1, decoding.size(-1)))[:, 0, :]
                    fertility = fertility.gather(
                        1, selected_idx[:, :, None].expand(batch_size, 1, fertility.size(-1)))[:, 0, :]
                else:
                    # (cheating: re-rank by sentence-BLEU against the reference)
                    trg_output = model.output_decoding(
                        ('trg', targets[:, None, :].expand(
                            batch_size, total_size, trg_len).contiguous().view(
                            batch_size * total_size, trg_len)))
                    dec_output = model.output_decoding(('trg', decoding))
                    bleu_score = computeBLEU(dec_output, trg_output, corpus=False,
                                             tokenizer=tokenizer).contiguous().view(
                                             batch_size, total_size)
                    bleu_score = bleu_score.cuda(args.gpu)

                    selected_idx = bleu_score.max(1)[1]
                    decoding = decoding.view(batch_size, total_size, -1)
                    fertility = fertility.view(batch_size, total_size, -1)
                    decoding = decoding.gather(
                        1, selected_idx[:, None, None].expand(batch_size, 1, decoding.size(-1)))[:, 0, :]
                    fertility = fertility.gather(
                        1, selected_idx[:, None, None].expand(batch_size, 1, fertility.size(-1)))[:, 0, :]

            assert fertility.data.sum() - (decoding.data != 1).long().sum() == 0, \
                'total fertility must match the decoded length'

        used_t = time.time() - start_t
        curr_time += used_t

        real_mask = 1 - ((decoding.data == eos_id) + (decoding.data == pad_id)).float()
        outputs = [model.output_decoding(d)
                   for d in [('src', sources), ('trg', targets), ('trg', decoding)]]

        corpus_size += batch_size
        src_outputs += outputs[0]
        trg_outputs += outputs[1]
        dec_outputs += outputs[2]
        timings += [used_t]

        if decoding_path is not None:
            for s, t, d in zip(outputs[0], outputs[1], outputs[2]):
                if args.no_bpe:
                    s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '')
                print(s, file=handles[0], flush=True)
                print(t, file=handles[1], flush=True)
                print(d, file=handles[2], flush=True)
            if args.model is FastTransformer:
                with torch.cuda.device_of(fertility):
                    fertility = fertility.data.tolist()
                for f in fertility:
                    f = ' '.join([str(fi) for fi in cutoff(f, 0)])
                    print(f, file=handles[3], flush=True)

        progressbar.update(1)
        progressbar.set_description('finishing sentences={}/batches={}, speed={} sec/batch'.format(
            corpus_size, iters, curr_time / (1 + iters)))

    if evaluate:
        corpus_gleu = computeGLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
        corpus_bleu = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
        args.logger.info("The dev-set corpus GLEU = {}".format(corpus_gleu))
        args.logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu))
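# `cutoff(f, 0)` above trims a fertility list before printing; a sketch under
# the assumption that it strips trailing padding values:
def cutoff_sketch(seq, pad_value=0):
    # cutoff_sketch([2, 1, 0, 3, 0, 0], 0) -> [2, 1, 0, 3]
    end = len(seq)
    while end > 0 and seq[end - 1] == pad_value:
        end -= 1
    return seq[:end]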
def valid_model(args, model, dev, dev_metrics=None, print_out=False, teacher_model=None):
    print_seqs = ['SRC ', 'REF '] + ['HYP{}'.format(ii + 1)
                                     for ii in range(args.valid_repeat_dec)]
    trg_outputs = []
    all_outputs = [[] for ii in range(args.valid_repeat_dec)]
    outputs_data = {}

    model.eval()
    if teacher_model is not None:
        teacher_model.eval()

    for j, dev_batch in enumerate(dev):
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks, \
        encoding, batch_size = model.quick_prepare(dev_batch)

        if type(model) is Transformer:
            decoder_inputs, decoder_masks = inputs, input_masks
        elif type(model) is FastTransformer:
            decoder_inputs, _, decoder_masks = \
                model.prepare_initial(encoding, sources, source_masks, input_masks)
            initial_inputs = decoder_inputs

        losses = []  # stays empty for the Transformer branch below
        if type(model) is Transformer:
            decoding, out, probs = model(encoding, source_masks, decoder_inputs,
                                         decoder_masks, decoding=True, return_probs=True)
        elif type(model) is FastTransformer:
            all_decodings = []
            for iter_ in range(args.valid_repeat_dec):
                curr_iter = min(iter_, args.num_shared_dec - 1)
                next_iter = min(curr_iter + 1, args.num_shared_dec - 1)

                decoding, out, probs = model(encoding, source_masks, decoder_inputs,
                                             decoder_masks, decoding=True,
                                             return_probs=True, iter_=curr_iter)

                losses.append(model.cost(targets, target_masks, out=out, iter_=curr_iter))
                all_decodings.append(decoding)

                logits = model.decoder[curr_iter].out(out)
                _, argmax = torch.max(logits, dim=-1)

                decoder_inputs = F.embedding(argmax, model.decoder[next_iter].out.weight *
                                             math.sqrt(args.d_model))
                if args.sum_out_and_emb:
                    decoder_inputs += out

        dev_outputs = [('src', sources), ('trg', targets)]
        if type(model) is Transformer:
            dev_outputs += [('trg', decoding)]
        elif type(model) is FastTransformer:
            dev_outputs += [('trg', xx) for xx in all_decodings]
        dev_outputs = [model.output_decoding(d) for d in dev_outputs]

        if print_out:
            for k, d in enumerate(dev_outputs):
                args.logger.info("{}: {}".format(print_seqs[k], d[0]))
            args.logger.info('------------------------------------------------------------------')

        trg_outputs += dev_outputs[1]
        for ii, d_outputs in enumerate(dev_outputs[2:]):
            all_outputs[ii] += d_outputs

        if dev_metrics is not None:
            dev_metrics.accumulate(batch_size, *losses)

    bleu = [100 * computeBLEU(ith_output, trg_outputs, corpus=True, tokenizer=tokenizer)
            for ith_output in all_outputs]
    outputs_data['bleu'] = bleu
    if dev_metrics is not None:
        args.logger.info(dev_metrics)
    args.logger.info("dev BLEU: {}".format(bleu))
    return outputs_data
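# Hypothetical call site for the validator above (`metrics` and `dev_iterator`
# are placeholders for objects built by the training setup):
# outputs_data = valid_model(args, model, dev_iterator, dev_metrics=metrics, print_out=True)
# final_bleu = outputs_data['bleu'][-1]  # BLEU after the last refinement iteration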
def decode_model(args, watcher, model, dev, evaluate=True,
                 decoding_path=None, names=None, maxsteps=None):
    print_seqs = ['[sources]', '[targets]', '[decoded]']
    args.logger.info("decoding beam-search: beam_size={}, alpha={}".format(
        args.beam_size, args.alpha))
    dev.train = False  # make iterator volatile=True

    if maxsteps is None:
        maxsteps = sum(1 for _ in dev)
    progressbar = tqdm(total=maxsteps, desc='start decoding')

    model.eval()
    if decoding_path is not None:
        handles = [open(os.path.join(decoding_path, name), 'w') for name in names]

    corpus_size = 0
    src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
    decoded_words, target_words, decoded_info = 0, 0, 0

    attentions = None
    pad_id = model.decoder.field.vocab.stoi['<pad>']
    eos_id = model.decoder.field.vocab.stoi['<eos>']

    curr_time = 0
    for iters, dev_batch in enumerate(dev):
        if iters > maxsteps:
            args.logger.info('complete {} steps of decoding'.format(maxsteps))
            break

        start_t = time.time()

        # prepare the data
        source_inputs, source_outputs, source_masks, \
        target_inputs, target_outputs, target_masks = model.prepare_data(dev_batch)

        if not args.real_time:
            # encoding
            encoding_outputs = model.encoding(source_inputs, source_masks)
            # decoding
            decoding_outputs = model.decoding(encoding_outputs, source_masks,
                                              target_inputs, target_masks,
                                              beam=args.beam_size, alpha=args.alpha,
                                              decoding=True, return_probs=False)
        else:
            # currently only supports simultaneous greedy decoding
            decoding_outputs = model.simultaneous_decoding(source_inputs, source_masks)

        # reverse to string sequences
        dev_outputs = [model.io_enc.reverse(source_outputs),
                       model.io_dec.reverse(target_outputs),
                       model.io_dec.reverse(decoding_outputs)]

        used_t = time.time() - start_t
        curr_time += used_t

        real_mask = 1 - ((decoding_outputs == eos_id) + (decoding_outputs == pad_id)).float()
        corpus_size += source_inputs.size(0)
        src_outputs += dev_outputs[0]
        trg_outputs += dev_outputs[1]
        dec_outputs += dev_outputs[2]
        timings += [used_t]

        if decoding_path is not None:
            for s, t, d in zip(dev_outputs[0], dev_outputs[1], dev_outputs[2]):
                if args.no_bpe:
                    s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '')
                print(s, file=handles[0], flush=True)
                print(t, file=handles[1], flush=True)
                print(d, file=handles[2], flush=True)

        progressbar.update(1)
        progressbar.set_description(
            'finishing sentences={}/batches={}, speed={:.2f} sentences / sec'.format(
                corpus_size, iters, corpus_size / curr_time))

    if evaluate:
        corpus_bleu = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=debpe)
        args.logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu))
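# `debpe` is passed as the BLEU tokenizer above; a sketch assuming it undoes
# BPE segmentation and then splits on whitespace:
def debpe_sketch(s):
    # 'a new@@ spaper' -> ['a', 'newspaper']
    return s.replace('@@ ', '').split()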