def valid_model(args, model, dev, dev_metrics=None, dev_metrics_trg=None, dev_metrics_average=None, print_out=False, teacher_model=None, trg_len_dic=None):
    """Validate `model` on the `dev` iterator and return per-iteration BLEU stats.

    A plain Transformer decodes once (greedy, beam=1). A FastTransformer is
    unrolled for `args.valid_repeat_dec` refinement iterations, feeding each
    iteration's prediction (argmax embedding and/or raw decoder output,
    selected by `args.next_dec_input`) back in as the next decoder input.

    :param args: experiment config (reads dataset, valid_repeat_dec, num_decs,
        next_dec_input, d_model, trg_len_option, inputs_dec, max_offset,
        trg_len_ratio, bp, logger).
    :param model: Transformer or FastTransformer instance (checked by exact type).
    :param dev: iterable over validation batches.
    :param dev_metrics: optional accumulator for the per-iteration losses.
    :param dev_metrics_trg: optional accumulator fed rest[0], rest[1], rest[2]
        (presumably target-length-prediction stats — verify against quick_prepare).
    :param dev_metrics_average: optional accumulator fed rest[3], rest[4].
    :param print_out: when True, log SRC/REF/HYP strings for every batch.
    :param teacher_model: unused in this function; kept for caller compatibility.
    :param trg_len_dic: target-length lookup forwarded to quick_prepare*.
    :return: dict with 'real' -> list of corpus BLEU (one per refinement
        iteration), plus length-prediction stats when "predict" is in
        args.trg_len_option.
    """
    # Column labels for the per-batch log lines below (mscoco has no SRC text).
    print_seq = (['REF '] if args.dataset == "mscoco" else ['SRC ', 'REF ']) + ['HYP{}'.format(ii + 1) for ii in range(args.valid_repeat_dec)]
    trg_outputs = []
    # One decoded-output list per refinement iteration, filled batch by batch.
    real_all_outputs = [[] for ii in range(args.valid_repeat_dec)]
    short_all_outputs = [[] for ii in range(args.valid_repeat_dec)]  # NOTE(review): assigned but never used below
    outputs_data = {}
    model.eval()
    for j, dev_batch in enumerate(dev):
        if args.dataset == "mscoco":
            # only use first caption for calculating log likelihood
            all_captions = dev_batch[1]
            dev_batch[1] = dev_batch[1][0]
            # NOTE(review): passes max_len=args.max_offset here, while
            # decode_model passes args.max_len — confirm this is intentional.
            decoder_inputs, decoder_masks,\
            targets, target_masks,\
            _, source_masks,\
            encoding, batch_size, rest = model.quick_prepare_mscoco(dev_batch, all_captions=all_captions, fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option, max_len=args.max_offset, trg_len_dic=trg_len_dic, bp=args.bp)
        else:
            decoder_inputs, decoder_masks,\
            targets, target_masks,\
            sources, source_masks,\
            encoding, batch_size, rest = model.quick_prepare(dev_batch, fast=(type(model) is FastTransformer), trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio, trg_len_dic=trg_len_dic, bp=args.bp)
        losses, all_decodings = [], []
        if type(model) is Transformer:
            # Single autoregressive decoding pass, greedy (beam=1).
            decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks, beam=1, decoding=True, return_probs=True)
            loss = model.cost(targets, target_masks, out=out)
            losses.append(loss)
            all_decodings.append(decoding)
        elif type(model) is FastTransformer:
            # Iterative refinement: run the decoder valid_repeat_dec times.
            # The decoder index is capped at num_decs - 1, so later iterations
            # reuse the last decoder.
            for iter_ in range(args.valid_repeat_dec):
                curr_iter = min(iter_, args.num_decs - 1)
                next_iter = min(curr_iter + 1, args.num_decs - 1)
                decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks, decoding=True, return_probs=True, iter_=curr_iter)
                loss = model.cost(targets, target_masks, out=out, iter_=curr_iter)
                losses.append(loss)
                all_decodings.append(decoding)
                # Build the next iteration's decoder input from this
                # iteration's prediction: scaled argmax embedding ("emb"),
                # raw decoder output ("out"), or their sum ("both").
                decoder_inputs = 0
                if args.next_dec_input in ["both", "emb"]:
                    _, argmax = torch.max(probs, dim=-1)
                    emb = F.embedding(argmax, model.decoder[next_iter].out.weight * math.sqrt(args.d_model))
                    decoder_inputs += emb
                if args.next_dec_input in ["both", "out"]:
                    decoder_inputs += out
        if args.dataset == "mscoco":
            # make sure that 5 captions per each example
            num_captions = len(all_captions[0])
            for c in range(1, len(all_captions)):
                assert (num_captions == len(all_captions[c]))
            # untokenize reference captions
            for n_ref in range(len(all_captions)):
                n_caps = len(all_captions[0])
                for c in range(n_caps):
                    # Strip BPE continuation markers in place (mutates the
                    # caller's caption lists).
                    all_captions[n_ref][c] = all_captions[n_ref][c].replace("@@ ", "")
            # Transpose per-reference lists into per-image reference sets.
            src_ref = [list(map(list, zip(*all_captions)))]
        else:
            src_ref = [model.output_decoding(d) for d in [('src', sources), ('trg', targets)]]
        # Convert every refinement iteration's token ids back to strings.
        real_outputs = [model.output_decoding(d) for d in [('trg', xx) for xx in all_decodings]]
        if print_out:
            if args.dataset != "mscoco":
                for k, d in enumerate(src_ref + real_outputs):
                    args.logger.info("{} ({}): {}".format(print_seq[k], len(d[0].split(" ")), d[0]))
            else:
                for k in range(len(all_captions[0])):
                    for c in range(len(all_captions)):
                        args.logger.info("REF ({}): {}".format(len(all_captions[c][k].split(" ")), all_captions[c][k]))
                    for c in range(len(real_outputs)):
                        args.logger.info("HYP {} ({}): {}".format(c + 1, len(real_outputs[c][k].split(" ")), real_outputs[c][k]))
            args.logger.info('------------------------------------------------------------------')
        # src_ref[-1] is the reference side (targets, or transposed captions).
        trg_outputs += src_ref[-1]
        for ii, d_outputs in enumerate(real_outputs):
            real_all_outputs[ii] += d_outputs
        if dev_metrics is not None:
            dev_metrics.accumulate(batch_size, *losses)
        if dev_metrics_trg is not None:
            dev_metrics_trg.accumulate(batch_size, *[rest[0], rest[1], rest[2]])
        if dev_metrics_average is not None:
            dev_metrics_average.accumulate(batch_size, *[rest[3], rest[4]])
    # Corpus-level BLEU per refinement iteration over the whole dev set.
    if args.dataset != "mscoco":
        real_bleu = [computeBLEU(ith_output, trg_outputs, corpus=True, tokenizer=tokenizer) for ith_output in real_all_outputs]
    else:
        real_bleu = [computeBLEUMSCOCO(ith_output, trg_outputs, corpus=True, tokenizer=tokenizer) for ith_output in real_all_outputs]
    outputs_data['real'] = real_bleu
    if "predict" in args.trg_len_option:
        # NOTE(review): assumes dev_metrics_trg / dev_metrics_average are not
        # None whenever trg_len_option contains "predict"; getattr(None, ...)
        # would raise AttributeError here — confirm callers guarantee this.
        outputs_data['pred_target_len_loss'] = getattr(dev_metrics_trg, 'pred_target_len_loss')
        outputs_data['pred_target_len_correct'] = getattr(dev_metrics_trg, 'pred_target_len_correct')
        outputs_data['pred_target_len_approx'] = getattr(dev_metrics_trg, 'pred_target_len_approx')
        outputs_data['average_target_len_correct'] = getattr(dev_metrics_average, 'average_target_len_correct')
        outputs_data['average_target_len_approx'] = getattr(dev_metrics_average, 'average_target_len_approx')
    if dev_metrics is not None:
        args.logger.info(dev_metrics)
    if dev_metrics_trg is not None:
        args.logger.info(dev_metrics_trg)
    if dev_metrics_average is not None:
        args.logger.info(dev_metrics_average)
    for idx in range(args.valid_repeat_dec):
        print_str = "iter {} | {}".format(idx + 1, print_bleu(real_bleu[idx], verbose=False))
        args.logger.info(print_str)
    return outputs_data
def decode_model(args, model, dev, evaluate=True, trg_len_dic=None, decoding_path=None, names=None, maxsteps=None):
    """Decode the dev set with `model`, optionally scoring BLEU and timing speed.

    A Transformer decodes once with beam search; a FastTransformer decodes by
    iterative refinement via `run_fast_transformer`, optionally averaging
    probabilities over `args.num_samples` sampled runs when `args.use_argmax`
    is False. Unless `args.debug`, source/target/decoded strings are written
    to `decoding_path / names[0..2]`.

    Fixes vs. previous revision:
      * the sampled `run_fast_transformer` call omitted the `targets`
        positional argument present in the argmax call, shifting every
        following argument — now consistent;
      * output file handles and the tqdm bar are closed on exit;
      * the progressbar description no longer embeds raw line-continuation
        whitespace inside the string literal;
      * `is not None` instead of `!= None`; removed dead locals
        (decoded_words/target_words/decoded_info, cum_images, decoder/pad_id/
        eos_id — none were read anywhere in this function).

    :param args: experiment config (f_size, beam_size, alpha, debug, no_tqdm,
        dataset, valid_repeat_dec, use_argmax, num_samples, remove_repeats,
        adaptive_decoding, decode_which, model, gpu, logger, ...).
    :param model: Transformer or FastTransformer instance.
    :param dev: iterable over dev batches (its `.train` flag is cleared).
    :param evaluate: when True, log per-iteration corpus BLEU after decoding.
    :param trg_len_dic: target-length lookup forwarded to quick_prepare*.
    :param decoding_path: pathlib.Path output directory (required unless
        args.debug, and for the args.decode_which > 0 speed report).
    :param names: three file names for src / trg / dec output streams.
    :param maxsteps: unused; kept for caller compatibility.
    :return: None (results are logged and written to files).
    """
    args.logger.info("decoding, f_size={}, beam_size={}, alpha={}".format(
        args.f_size, args.beam_size, args.alpha))
    dev.train = False  # make iterator volatile=True

    progressbar = None
    if not args.no_tqdm:
        progressbar = tqdm(total=200, desc='start decoding')

    model.eval()

    handles = []
    if not args.debug:
        decoding_path.mkdir(parents=True, exist_ok=True)
        handles = [(decoding_path / name).open('w') for name in names]

    corpus_size = 0
    src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
    # One decoded-output list per refinement iteration.
    all_decs = [[] for idx in range(args.valid_repeat_dec)]
    attentions = None
    curr_time = 0
    cum_sentences = 0
    cum_tokens = 0
    num_iters_total = []

    for iters, dev_batch in enumerate(dev):
        start_t = time.time()

        if args.dataset != "mscoco":
            decoder_inputs, decoder_masks,\
            targets, target_masks,\
            sources, source_masks,\
            encoding, batch_size, rest = model.quick_prepare(dev_batch, fast=(type(model) is FastTransformer), trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio, trg_len_dic=trg_len_dic, bp=args.bp)
        else:
            # only use first caption for calculating log likelihood
            all_captions = dev_batch[1]
            dev_batch[1] = dev_batch[1][0]
            decoder_inputs, decoder_masks,\
            targets, target_masks,\
            _, source_masks,\
            encoding, batch_size, rest = model.quick_prepare_mscoco(dev_batch, all_captions=all_captions, fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec, trg_len_option=args.trg_len_option, max_len=args.max_len, trg_len_dic=trg_len_dic, bp=args.bp, gpu=args.gpu > -1)
            sources = None

        cum_sentences += batch_size
        batch_size, src_len, hsize = encoding[0].size()

        if type(model) is Transformer:
            all_decodings = []
            decoding = model(encoding, source_masks, decoder_inputs, decoder_masks,
                             beam=args.beam_size, alpha=args.alpha,
                             decoding=True, feedback=attentions)
            all_decodings.append(decoding)
            num_iters = [0]
        elif type(model) is FastTransformer:
            decoding, all_decodings, num_iters, argmax_all_probs = run_fast_transformer(
                decoder_inputs, decoder_masks, sources, source_masks, targets,
                encoding, model, args, use_argmax=True)
            num_iters_total.extend(num_iters)

            if not args.use_argmax:
                # Average probabilities over num_samples sampled decodings,
                # then take the argmax of the mean per iteration.
                for _ in range(args.num_samples):
                    # BUG FIX: this call previously omitted `targets`, which
                    # shifted every following positional argument; it now
                    # mirrors the argmax call above.
                    _, _, _, sampled_all_probs = run_fast_transformer(
                        decoder_inputs, decoder_masks, sources, source_masks, targets,
                        encoding, model, args, use_argmax=False)
                    for iter_ in range(args.valid_repeat_dec):
                        argmax_all_probs[iter_] = argmax_all_probs[iter_] + sampled_all_probs[iter_]

                all_decodings = []
                for iter_ in range(args.valid_repeat_dec):
                    argmax_all_probs[iter_] = argmax_all_probs[iter_] / args.num_samples
                    # torch.max returns (values, indices); [-1] takes indices.
                    all_decodings.append(torch.max(argmax_all_probs[iter_], dim=-1)[-1])
                decoding = all_decodings[-1]

        used_t = time.time() - start_t
        curr_time += used_t

        if args.dataset != "mscoco":
            # outputs_unidx = [sources, references, final decoding] as strings.
            if args.remove_repeats:
                outputs_unidx = [model.output_decoding(d) for d in [('src', sources), ('trg', targets), ('trg', remove_repeats_tensor(decoding))]]
            else:
                outputs_unidx = [model.output_decoding(d) for d in [('src', sources), ('trg', targets), ('trg', decoding)]]
        else:
            # make sure that 5 captions per each example
            num_captions = len(all_captions[0])
            for c in range(1, len(all_captions)):
                assert (num_captions == len(all_captions[c]))
            # untokenize reference captions (in place)
            for n_ref in range(len(all_captions)):
                n_caps = len(all_captions[0])
                for c in range(n_caps):
                    all_captions[n_ref][c] = all_captions[n_ref][c].replace("@@ ", "")
            outputs_unidx = [list(map(list, zip(*all_captions)))]

        # String outputs for every refinement iteration.
        if args.remove_repeats:
            all_dec_outputs = [model.output_decoding(d) for d in [('trg', remove_repeats_tensor(all_decodings[ii])) for ii in range(len(all_decodings))]]
        else:
            all_dec_outputs = [model.output_decoding(d) for d in [('trg', all_decodings[ii]) for ii in range(len(all_decodings))]]

        corpus_size += batch_size
        if args.dataset != "mscoco":
            cum_tokens += sum([len(xx.split(" ")) for xx in outputs_unidx[0]])  # NOTE source tokens, not target

        if args.dataset != "mscoco":
            src_outputs += outputs_unidx[0]
            trg_outputs += outputs_unidx[1]
            if args.remove_repeats:
                dec_outputs += remove_repeats(outputs_unidx[-1])
            else:
                dec_outputs += outputs_unidx[-1]
        else:
            trg_outputs += outputs_unidx[0]

        for idx, each_output in enumerate(all_dec_outputs):
            if args.remove_repeats:
                all_decs[idx] += remove_repeats(each_output)
            else:
                all_decs[idx] += each_output

        # Debug dump of per-sentence decodings; deliberately disabled
        # (flip the condition to re-enable).
        #if True:
        if False and decoding_path is not None:
            for sent_i in range(len(outputs_unidx[0])):
                if args.dataset != "mscoco":
                    print('SRC')
                    print(outputs_unidx[0][sent_i])
                    for ii in range(len(all_decodings)):
                        print('DEC iter {}'.format(ii))
                        print(all_dec_outputs[ii][sent_i])
                    print('TRG')
                    print(outputs_unidx[1][sent_i])
                else:
                    print('TRG')
                    trg = outputs_unidx[0]
                    for subsent_i in range(len(trg[sent_i])):
                        print('TRG {}'.format(subsent_i))
                        print(trg[sent_i][subsent_i])
                    for ii in range(len(all_decodings)):
                        print('DEC iter {}'.format(ii))
                        print(all_dec_outputs[ii][sent_i])
                print('---------------------------')

        timings += [used_t]

        if not args.debug:
            # NOTE(review): in mscoco mode outputs_unidx has length 1, so this
            # indexing assumes args.debug is set for mscoco runs — confirm.
            for s, t, d in zip(outputs_unidx[0], outputs_unidx[1], outputs_unidx[2]):
                s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '')
                print(s, file=handles[0], flush=True)
                print(t, file=handles[1], flush=True)
                print(d, file=handles[2], flush=True)

        if progressbar is not None:
            progressbar.update(iters)
            # FIX: single-line literal; the old one embedded the raw
            # indentation of a backslash-continued source line in the message.
            progressbar.set_description(
                'finishing sentences={}/batches={}, length={}/average iter={}, speed={} sec/batch'.format(
                    corpus_size, iters, src_len, np.mean(np.array(num_iters)), curr_time / (1 + iters)))

    if evaluate:
        for idx, each_dec in enumerate(all_decs):
            # Stop at the first iteration whose outputs are incomplete.
            if len(all_decs[idx]) != len(trg_outputs):
                break
            if args.dataset != "mscoco":
                bleu_output = computeBLEU(each_dec, trg_outputs, corpus=True, tokenizer=tokenizer)
            else:
                bleu_output = computeBLEUMSCOCO(each_dec, trg_outputs, corpus=True, tokenizer=tokenizer)
            args.logger.info("iter {} | {}".format(idx + 1, print_bleu(bleu_output)))

    if args.adaptive_decoding is not None:
        args.logger.info("----------------------------------------------")
        args.logger.info("Average # iters {}".format(np.mean(num_iters_total)))
        bleu_output = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
        args.logger.info("Adaptive BLEU | {}".format(print_bleu(bleu_output)))

    args.logger.info("----------------------------------------------")
    args.logger.info("Decoding speed analysis :")
    args.logger.info("{} sentences".format(cum_sentences))
    if args.dataset != "mscoco":
        args.logger.info("{} tokens".format(cum_tokens))
    args.logger.info("{:.3f} seconds".format(curr_time))
    args.logger.info("{:.3f} ms / sentence".format((curr_time / float(cum_sentences) * 1000)))
    if args.dataset != "mscoco":
        args.logger.info("{:.3f} ms / token".format((curr_time / float(cum_tokens) * 1000)))
    args.logger.info("{:.3f} sentences / s".format(float(cum_sentences) / curr_time))
    if args.dataset != "mscoco":
        args.logger.info("{:.3f} tokens / s".format(float(cum_tokens) / curr_time))
    args.logger.info("----------------------------------------------")

    if args.decode_which > 0:
        # Mirror the speed report into a per-shard results file.
        args.logger.info("Writing to special file")
        parent = decoding_path / "speed" / "b_{}{}".format(
            args.beam_size if args.model is Transformer else args.valid_repeat_dec,
            "" if args.model is Transformer else "_{}".format(args.adaptive_decoding is not None))
        args.logger.info(str(parent))
        parent.mkdir(parents=True, exist_ok=True)
        with (parent / "results.{}".format(args.decode_which)).open('w') as speed_handle:
            print("----------------------------------------------", file=speed_handle, flush=True)
            print("Decoding speed analysis :", file=speed_handle, flush=True)
            print("{} sentences".format(cum_sentences), file=speed_handle, flush=True)
            if args.dataset != "mscoco":
                print("{} tokens".format(cum_tokens), file=speed_handle, flush=True)
            print("{:.3f} seconds".format(curr_time), file=speed_handle, flush=True)
            print("{:.3f} ms / sentence".format((curr_time / float(cum_sentences) * 1000)), file=speed_handle, flush=True)
            if args.dataset != "mscoco":
                print("{:.3f} ms / token".format((curr_time / float(cum_tokens) * 1000)), file=speed_handle, flush=True)
            print("{:.3f} sentences / s".format(float(cum_sentences) / curr_time), file=speed_handle, flush=True)
            if args.dataset != "mscoco":
                print("{:.3f} tokens / s".format(float(cum_tokens) / curr_time), file=speed_handle, flush=True)
            print("----------------------------------------------", file=speed_handle, flush=True)

    # FIX: release resources that were previously leaked.
    for handle in handles:
        handle.close()
    if progressbar is not None:
        progressbar.close()