Example #1
def valid_model(args, model, dev, dev_metrics=None, distillation=False, print_out=False, U=None, beam=1, alpha=0.6):
    print_seqs = ['[sources]', '[targets]', '[decoded]', '[fertili]', '[origind]']
    src_outputs, trg_outputs, dec_outputs = [], [], []
    outputs = {}

    model.eval()
    progressbar = tqdm(total=sum(1 for _ in dev), desc='start decoding for validation...')

    for j, dev_batch in enumerate(dev):
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks, \
        encoding, batch_size = model.quick_prepare(dev_batch, distillation, U=U)

        decoder_inputs, input_reorder, fertility_cost = inputs, None, None
        if type(model) is FastTransformer:
            decoder_inputs, input_reorder, decoder_masks, fertility_cost, pred_fertility = \
                model.prepare_initial(encoding, sources, source_masks, input_masks, None, mode='argmax')
        else:
            decoder_masks = input_masks

        decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks, 
                                     decoding=True, return_probs=True, beam=beam, alpha=alpha)
        dev_outputs = [('src', sources), ('trg', targets), ('trg', decoding)]
        dev_outputs = [model.output_decoding(d) for d in dev_outputs]
    
        if print_out and j < 5:
            for k, d in enumerate(dev_outputs):
                args.logger.info("{}: {}".format(print_seqs[k], d[0]))
            args.logger.info('------------------------------------------------------------------')

        src_outputs += dev_outputs[0]
        trg_outputs += dev_outputs[1]
        dec_outputs += dev_outputs[2]

        if dev_metrics is not None:
            values = [0, 0]
            dev_metrics.accumulate(batch_size, *values)

        info = 'Validation: decoding step={}'.format(j + 1)
        progressbar.update(1)
        progressbar.set_description(info)
    
    progressbar.close()

    corpus_bleu = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
    outputs['corpus_bleu'] = corpus_bleu
    outputs['dev_output'] = (src_outputs, trg_outputs, dec_outputs)
    if dev_metrics is not None: 
        args.logger.info(dev_metrics)

    args.logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu))
    return outputs
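
# Note: `computeBLEU` and `tokenizer` above are project-specific helpers that
# are not shown here. A minimal, self-contained stand-in for the corpus=True
# path (an assumption: the helper returns a standard corpus-level BLEU over
# whitespace-tokenized text; the smoothing choice below is illustrative):
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def compute_corpus_bleu(hypotheses, references):
    """hypotheses: list[str]; references: list[str], one reference per hypothesis."""
    hyp_tokens = [h.split() for h in hypotheses]
    ref_tokens = [[r.split()] for r in references]  # NLTK expects a list of references per hypothesis
    return 100 * corpus_bleu(ref_tokens, hyp_tokens,
                             smoothing_function=SmoothingFunction().method1)

# e.g. compute_corpus_bleu(dec_outputs, trg_outputs) at the end of valid_model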
Example #2
    def validation_epoch_end(self, val_step_outputs):
        print("Translation Sample =================")

        #"An old man trying to get up from a broken chair
        #A man wearing red shirt sitting under a tree

        for sentence in config.sentences:
            if not config.USE_BPE:
                translated_sentence = translate_sentence(self,
                                                         sentence,
                                                         self.german_vocab,
                                                         self.english_vocab,
                                                         self.deviceLegacy,
                                                         max_length=50)
                # print("Output", translated_sentence)
                # print(sentence)
                # global myGlobal
                # myGlobal = False
                # exit()
                # if self.nepochs == config.MAX_EPOCHS:
                #     myGlobal.change(False)
                #     print("Input", sentence)
                #     print("Output", translated_sentence)
                #     exit()
            else:
                translated_sentence = translate_sentence_bpe(
                    self,
                    sentence,
                    self.german_vocab,
                    self.english_vocab,
                    self.deviceLegacy,
                    max_length=50)

            print("Output", translated_sentence)

        if config.COMPUTE_BLEU and self.nepochs > 0:
            bleu_score = computeBLEU(self.test_data, self, self.german_vocab,
                                     self.english_vocab, self.deviceLegacy)
            self.bleu_scores.append(bleu_score)
            print("BLEU score: ", bleu_score)
            writeArrToCSV(self.bleu_scores)  # persist the running BLEU history every epoch
        return
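
# Note: `translate_sentence` / `translate_sentence_bpe` are defined elsewhere
# in this project. A minimal greedy-decoding sketch of what such a helper
# typically does, assuming (hypothetically) torchtext-style vocabs with
# <sos>/<eos> specials and a seq2seq `model(src, trg)` returning per-step logits:
import torch

def greedy_translate(model, sentence, src_vocab, trg_vocab, device, max_length=50):
    tokens = ['<sos>'] + sentence.lower().split() + ['<eos>']
    src = torch.tensor([src_vocab.stoi[t] for t in tokens], device=device).unsqueeze(1)
    outputs = [trg_vocab.stoi['<sos>']]
    for _ in range(max_length):
        trg = torch.tensor(outputs, device=device).unsqueeze(1)
        with torch.no_grad():
            logits = model(src, trg)            # (trg_len, batch=1, vocab)
        next_token = logits[-1, 0].argmax().item()
        outputs.append(next_token)
        if next_token == trg_vocab.stoi['<eos>']:
            break
    return [trg_vocab.itos[i] for i in outputs[1:]]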
Example #3
def valid_model(args,
                model,
                dev,
                dev_metrics=None,
                distillation=False,
                print_out=False,
                teacher_model=None):
    print_seqs = [
        '[sources]', '[targets]', '[decoded]', '[fertili]', '[origind]'
    ]
    trg_outputs, dec_outputs = [], []
    outputs = {}

    model.eval()
    if teacher_model is not None:
        teacher_model.eval()

    for j, dev_batch in enumerate(dev):
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks, \
        encoding, batch_size = model.quick_prepare(dev_batch, distillation)

        decoder_inputs, input_reorder, fertility_cost = inputs, None, None
        if type(model) is FastTransformer:
            decoder_inputs, input_reorder, decoder_masks, fertility_cost, pred_fertility = \
                model.prepare_initial(encoding, sources, source_masks, input_masks, None, mode='argmax')
        else:
            decoder_masks = input_masks

        decoding, out, probs = model(encoding,
                                     source_masks,
                                     decoder_inputs,
                                     decoder_masks,
                                     decoding=True,
                                     return_probs=True)
        dev_outputs = [('src', sources), ('trg', targets), ('trg', decoding)]
        if type(model) is FastTransformer:
            dev_outputs += [('src', input_reorder)]
        dev_outputs = [model.output_decoding(d) for d in dev_outputs]
        gleu = computeGLEU(dev_outputs[2],
                           dev_outputs[1],
                           corpus=False,
                           tokenizer=tokenizer)

        if print_out:
            for k, d in enumerate(dev_outputs):
                args.logger.info("{}: {}".format(print_seqs[k], d[0]))
            args.logger.info(
                '------------------------------------------------------------------'
            )

        if teacher_model is not None:  # teacher is Transformer, student is FastTransformer
            inputs_student, _, targets_student, _, _, _, encoding_teacher, _ = \
                teacher_model.quick_prepare(dev_batch, False, decoding, decoding,
                                            input_masks, target_masks, source_masks)
            teacher_real_loss = teacher_model.cost(
                targets,
                target_masks,
                out=teacher_model(encoding_teacher, source_masks, inputs,
                                  input_masks))
            teacher_fake_out = teacher_model(encoding_teacher, source_masks,
                                             inputs_student, input_masks)
            teacher_fake_loss = teacher_model.cost(targets_student,
                                                   target_masks,
                                                   out=teacher_fake_out)
            teacher_alter_loss = teacher_model.cost(targets,
                                                    target_masks,
                                                    out=teacher_fake_out)

        trg_outputs += dev_outputs[1]
        dec_outputs += dev_outputs[2]

        if dev_metrics is not None:
            values = [0, gleu]
            if teacher_model is not None:
                values += [
                    teacher_real_loss, teacher_fake_loss,
                    teacher_real_loss - teacher_fake_loss, teacher_alter_loss,
                    teacher_alter_loss - teacher_fake_loss
                ]
            if fertility_cost is not None:
                values += [fertility_cost]

            dev_metrics.accumulate(batch_size, *values)

    corpus_gleu = computeGLEU(dec_outputs,
                              trg_outputs,
                              corpus=True,
                              tokenizer=tokenizer)
    corpus_bleu = computeBLEU(dec_outputs,
                              trg_outputs,
                              corpus=True,
                              tokenizer=tokenizer)
    outputs['corpus_gleu'] = corpus_gleu
    outputs['corpus_bleu'] = corpus_bleu
    if dev_metrics is not None:
        args.logger.info(dev_metrics)

    args.logger.info("The dev-set corpus GLEU = {}".format(corpus_gleu))
    args.logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu))
    return outputs
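
# Note: `model.cost` / `teacher_model.cost` above are project-specific. A
# minimal sketch of the masked cross-entropy such a cost usually computes
# (assumptions: `out` holds per-position logits and `target_masks` is 1.0 on
# real tokens, 0.0 on padding):
import torch
import torch.nn.functional as F

def masked_xent(logits, targets, target_masks):
    # logits: (batch, trg_len, vocab); targets, target_masks: (batch, trg_len)
    loss = F.cross_entropy(logits.transpose(1, 2), targets, reduction='none')
    return (loss * target_masks).sum() / target_masks.sum()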
Example #4
def valid_model(args,
                model,
                dev,
                dev_metrics=None,
                dev_metrics_trg=None,
                dev_metrics_average=None,
                print_out=False,
                teacher_model=None,
                trg_len_dic=None):
    print_seq = (['REF '] if args.dataset == "mscoco" else [
        'SRC ', 'REF '
    ]) + ['HYP{}'.format(ii + 1) for ii in range(args.valid_repeat_dec)]

    trg_outputs = []
    real_all_outputs = [[] for ii in range(args.valid_repeat_dec)]
    short_all_outputs = [[] for ii in range(args.valid_repeat_dec)]
    outputs_data = {}

    model.eval()
    for j, dev_batch in enumerate(dev):
        if args.dataset == "mscoco":
            # only use first caption for calculating log likelihood
            all_captions = dev_batch[1]
            dev_batch[1] = dev_batch[1][0]
            decoder_inputs, decoder_masks,\
            targets, target_masks,\
            _, source_masks,\
            encoding, batch_size, rest = model.quick_prepare_mscoco(
                dev_batch, all_captions=all_captions,
                fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec,
                trg_len_option=args.trg_len_option, max_len=args.max_offset,
                trg_len_dic=trg_len_dic, bp=args.bp)

        else:
            decoder_inputs, decoder_masks,\
            targets, target_masks,\
            sources, source_masks,\
            encoding, batch_size, rest = model.quick_prepare(
                dev_batch, fast=(type(model) is FastTransformer),
                trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio,
                trg_len_dic=trg_len_dic, bp=args.bp)

        losses, all_decodings = [], []
        if type(model) is Transformer:
            decoding, out, probs = model(encoding,
                                         source_masks,
                                         decoder_inputs,
                                         decoder_masks,
                                         beam=1,
                                         decoding=True,
                                         return_probs=True)
            loss = model.cost(targets, target_masks, out=out)
            losses.append(loss)
            all_decodings.append(decoding)

        elif type(model) is FastTransformer:
            for iter_ in range(args.valid_repeat_dec):
                curr_iter = min(iter_, args.num_decs - 1)
                next_iter = min(curr_iter + 1, args.num_decs - 1)

                decoding, out, probs = model(encoding,
                                             source_masks,
                                             decoder_inputs,
                                             decoder_masks,
                                             decoding=True,
                                             return_probs=True,
                                             iter_=curr_iter)

                loss = model.cost(targets,
                                  target_masks,
                                  out=out,
                                  iter_=curr_iter)
                losses.append(loss)
                all_decodings.append(decoding)

                decoder_inputs = 0
                if args.next_dec_input in ["both", "emb"]:
                    _, argmax = torch.max(probs, dim=-1)
                    emb = F.embedding(
                        argmax, model.decoder[next_iter].out.weight *
                        math.sqrt(args.d_model))
                    decoder_inputs += emb

                if args.next_dec_input in ["both", "out"]:
                    decoder_inputs += out

        if args.dataset == "mscoco":
            # make sure there is the same number of captions for every example
            num_captions = len(all_captions[0])
            for c in range(1, len(all_captions)):
                assert (num_captions == len(all_captions[c]))

            # untokenize reference captions
            for n_ref in range(len(all_captions)):
                n_caps = len(all_captions[0])
                for c in range(n_caps):
                    all_captions[n_ref][c] = all_captions[n_ref][c].replace(
                        "@@ ", "")

            src_ref = [list(map(list, zip(*all_captions)))]
        else:
            src_ref = [
                model.output_decoding(d)
                for d in [('src', sources), ('trg', targets)]
            ]

        real_outputs = [
            model.output_decoding(d)
            for d in [('trg', xx) for xx in all_decodings]
        ]

        if print_out:
            if args.dataset != "mscoco":
                for k, d in enumerate(src_ref + real_outputs):
                    args.logger.info("{} ({}): {}".format(
                        print_seq[k], len(d[0].split(" ")), d[0]))
            else:
                for k in range(len(all_captions[0])):
                    for c in range(len(all_captions)):
                        args.logger.info("REF ({}): {}".format(
                            len(all_captions[c][k].split(" ")),
                            all_captions[c][k]))

                    for c in range(len(real_outputs)):
                        args.logger.info("HYP {} ({}): {}".format(
                            c + 1, len(real_outputs[c][k].split(" ")),
                            real_outputs[c][k]))
            args.logger.info(
                '------------------------------------------------------------------'
            )

        trg_outputs += src_ref[-1]
        for ii, d_outputs in enumerate(real_outputs):
            real_all_outputs[ii] += d_outputs

        if dev_metrics is not None:
            dev_metrics.accumulate(batch_size, *losses)
        if dev_metrics_trg is not None:
            dev_metrics_trg.accumulate(batch_size,
                                       *[rest[0], rest[1], rest[2]])
        if dev_metrics_average is not None:
            dev_metrics_average.accumulate(batch_size, *[rest[3], rest[4]])

    if args.dataset != "mscoco":
        real_bleu = [
            computeBLEU(ith_output,
                        trg_outputs,
                        corpus=True,
                        tokenizer=tokenizer) for ith_output in real_all_outputs
        ]
    else:
        real_bleu = [
            computeBLEUMSCOCO(ith_output,
                              trg_outputs,
                              corpus=True,
                              tokenizer=tokenizer)
            for ith_output in real_all_outputs
        ]

    outputs_data['real'] = real_bleu

    if "predict" in args.trg_len_option:
        outputs_data['pred_target_len_loss'] = getattr(dev_metrics_trg,
                                                       'pred_target_len_loss')
        outputs_data['pred_target_len_correct'] = getattr(
            dev_metrics_trg, 'pred_target_len_correct')
        outputs_data['pred_target_len_approx'] = getattr(
            dev_metrics_trg, 'pred_target_len_approx')
        outputs_data['average_target_len_correct'] = getattr(
            dev_metrics_average, 'average_target_len_correct')
        outputs_data['average_target_len_approx'] = getattr(
            dev_metrics_average, 'average_target_len_approx')

    if dev_metrics is not None:
        args.logger.info(dev_metrics)
    if dev_metrics_trg is not None:
        args.logger.info(dev_metrics_trg)
    if dev_metrics_average is not None:
        args.logger.info(dev_metrics_average)

    for idx in range(args.valid_repeat_dec):
        print_str = "iter {} | {}".format(
            idx + 1, print_bleu(real_bleu[idx], verbose=False))
        args.logger.info(print_str)

    return outputs_data
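
# The `decoder_inputs` update above re-embeds the previous iteration's argmax
# tokens through the *next* decoder's tied output matrix, scaled by
# sqrt(d_model). A tiny standalone demonstration of that step (random values,
# shapes only):
import math
import torch
import torch.nn.functional as F

vocab, d_model, batch, trg_len = 1000, 512, 2, 7
out_weight = torch.randn(vocab, d_model)      # stands in for decoder[next_iter].out.weight
probs = torch.randn(batch, trg_len, vocab)
_, argmax = torch.max(probs, dim=-1)          # (batch, trg_len) token ids
emb = F.embedding(argmax, out_weight * math.sqrt(d_model))
assert emb.shape == (batch, trg_len, d_model)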
Example #5
def decode_model(args,
                 model,
                 dev,
                 evaluate=True,
                 decoding_path=None,
                 names=None,
                 maxsteps=None):

    args.logger.info("decoding, f_size={}, beam_size={}, alpha={}".format(
        args.f_size, args.beam_size, args.alpha))
    dev.train = False  # make iterator volatile=True

    if maxsteps is None:
        progressbar = tqdm(total=sum([1 for _ in dev]), desc='start decoding')
    else:
        progressbar = tqdm(total=maxsteps, desc='start decoding')

    model.eval()
    if decoding_path is not None:
        handles = [
            open(os.path.join(decoding_path, name), 'w') for name in names
        ]

    corpus_size = 0
    src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
    decoded_words, target_words, decoded_info = 0, 0, 0

    attentions = None
    pad_id = model.decoder[0].field.vocab.stoi['<pad>']
    eos_id = model.decoder[0].field.vocab.stoi['<eos>']

    curr_time = 0
    cum_bs = 0
    for iters, dev_batch in enumerate(dev):
        if maxsteps is not None and iters > maxsteps:
            args.logger.info('complete {} steps of decoding'.format(maxsteps))
            break

        start_t = time.time()

        # encoding
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks, \
        encoding, batch_size = model.quick_prepare(dev_batch)
        cum_bs += batch_size

        if type(model) is Transformer:
            all_decodings = []
            decoder_inputs, decoder_masks = inputs, input_masks
            decoding = model(encoding, source_masks, decoder_inputs, decoder_masks,
                             beam=args.beam_size, alpha=args.alpha,
                             decoding=True, feedback=attentions)
            all_decodings.append(decoding)

        elif type(model) is FastTransformer:
            decoder_inputs, _, decoder_masks = \
                    model.prepare_initial(encoding, sources, source_masks, input_masks,\
                                          N=args.f_size)
            batch_size, src_len, hsize = encoding[0].size()
            all_decodings = []
            prev_dec_output = None
            iter_ = 0

            while True:
                iter_num = min(iter_, args.num_shared_dec - 1)
                next_iter = min(iter_ + 1, args.num_shared_dec - 1)

                decoding, out, probs = model(encoding,
                                             source_masks,
                                             decoder_inputs,
                                             decoder_masks,
                                             decoding=True,
                                             return_probs=True,
                                             iter_=iter_num)

                all_decodings.append(decoding)

                thedecoder = model.decoder[iter_num]

                logits = thedecoder.out(out)
                _, argmax = torch.max(logits, dim=-1)

                decoder_inputs = F.embedding(
                    argmax, model.decoder[next_iter].out.weight *
                    math.sqrt(args.d_model))
                if args.sum_out_and_emb:
                    decoder_inputs += out

                iter_ += 1
                if iter_ == args.valid_repeat_dec:
                    break

        used_t = time.time() - start_t
        curr_time += used_t

        real_mask = 1 - ((decoding.data == eos_id) +
                         (decoding.data == pad_id)).float()
        outputs = [
            model.output_decoding(d)
            for d in [('src', sources), ('trg', targets), ('trg', decoding)]
        ]
        all_dec_outputs = [
            model.output_decoding(d)
            for d in [('trg', all_decodings[ii])
                      for ii in range(len(all_decodings))]
        ]

        corpus_size += batch_size
        src_outputs += outputs[0]
        trg_outputs += outputs[1]
        dec_outputs += outputs[-1]
        """
        for sent_i in range(len(outputs[0])):
            print ('SRC')
            print (outputs[0][sent_i])
            print ('TRG')
            print (outputs[1][sent_i])
            for ii in range(len(all_decodings)):
                print ('DEC iter {}'.format(ii))
                print (all_dec_outputs[ii][sent_i])
            print ('---------------------------')
        """

        timings += [used_t]

        if decoding_path is not None:
            for s, t, d in zip(outputs[0], outputs[1], outputs[2]):
                s, t, d = s.replace('@@ ',
                                    ''), t.replace('@@ ',
                                                   ''), d.replace('@@ ', '')
                print(s, file=handles[0], flush=True)
                print(t, file=handles[1], flush=True)
                print(d, file=handles[2], flush=True)

        progressbar.update(1)

    progressbar.close()
    print("{:.3f} ms / sentence".format(curr_time / float(cum_bs) * 1000))

    if evaluate:
        corpus_bleu = computeBLEU(dec_outputs,
                                  trg_outputs,
                                  corpus=True,
                                  tokenizer=tokenizer)
        args.logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu))
Example #6
def run_fast_transformer(decoder_inputs, decoder_masks,
                         sources, source_masks,
                         targets,
                         encoding,
                         model, args, use_argmax=True):

    trg_unidx = model.output_decoding(('trg', targets))

    batch_size, src_len, hsize = encoding[0].size()

    all_decodings = []
    all_probs = []
    iter_ = 0
    bleu_hist = [[] for _ in range(batch_size)]
    output_hist = [[] for _ in range(batch_size)]
    multiset_hist = [[] for _ in range(batch_size)]
    num_iters = [0 for _ in range(batch_size)]
    done_ = [False for _ in range(batch_size)]
    final_decoding = [None for _ in range(batch_size)]

    while True:
        curr_iter = min(iter_, args.num_decs - 1)
        next_iter = min(iter_ + 1, args.num_decs - 1)

        decoding, out, probs = model(encoding,
                                     source_masks,
                                     decoder_inputs,
                                     decoder_masks,
                                     decoding=True,
                                     return_probs=True,
                                     iter_=curr_iter)

        dec_output = decoding.data.cpu().numpy().tolist()
        """
        if args.trg_len_option != "reference":
            decoder_masks = 0. * decoder_masks
            for bidx in range(batch_size):
                try:
                    decoder_masks[bidx,:(dec_output[bidx].index(3))+1] = 1.
                except:
                    decoder_masks[bidx,:] = 1.
        """

        if args.adaptive_decoding == "oracle":
            out_unidx = model.output_decoding(('trg', decoding))
            sentence_bleus = computeBLEU(out_unidx,
                                         trg_unidx,
                                         corpus=False,
                                         tokenizer=tokenizer)

            for bidx in range(batch_size):
                output_hist[bidx].append(dec_output[bidx])
                bleu_hist[bidx].append(sentence_bleus[bidx])

            converged = oracle_converged(bleu_hist,
                                         num_items=args.adaptive_window)
            for bidx in range(batch_size):
                if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0:
                    num_iters[bidx] = iter_ + 1 - (args.adaptive_window - 1)
                    done_[bidx] = True
                    final_decoding[bidx] = output_hist[bidx][-args.adaptive_window]

        elif args.adaptive_decoding == "equality":
            for bidx in range(batch_size):
                # (optionally, outputs could be truncated at <eos> before comparison)
                output_hist[bidx].append(dec_output[bidx])

            converged = equality_converged(output_hist,
                                           num_items=args.adaptive_window)

            for bidx in range(batch_size):
                if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0:
                    num_iters[bidx] = iter_ + 1
                    done_[bidx] = True
                    final_decoding[bidx] = output_hist[bidx][-1]

        elif args.adaptive_decoding == "jaccard":
            for bidx in range(batch_size):
                # (optionally, outputs could be truncated at <eos> before comparison)
                output_hist[bidx].append(dec_output[bidx])
                multiset_hist[bidx].append(Multiset(dec_output[bidx]))

            converged = jaccard_converged(multiset_hist,
                                          num_items=args.adaptive_window)

            for bidx in range(batch_size):
                if not done_[bidx] and converged[bidx] and num_iters[bidx] == 0:
                    num_iters[bidx] = iter_ + 1
                    done_[bidx] = True
                    final_decoding[bidx] = output_hist[bidx][-1]

        all_decodings.append(decoding)
        all_probs.append(probs)

        decoder_inputs = 0
        if args.next_dec_input in ["both", "emb"]:
            if use_argmax:
                _, argmax = torch.max(probs, dim=-1)
            else:
                probs_sz = probs.size()
                probs_ = Variable(probs.data, requires_grad=False)
                argmax = torch.multinomial(
                    probs_.contiguous().view(-1, probs_sz[-1]),
                    1).view(*probs_sz[:-1])
            emb = F.embedding(
                argmax,
                model.decoder[next_iter].out.weight * math.sqrt(args.d_model))
            decoder_inputs += emb

        if args.next_dec_input in ["both", "out"]:
            decoder_inputs += out

        iter_ += 1
        if iter_ == args.valid_repeat_dec or (False not in done_):
            break

    if args.adaptive_decoding is not None:
        for bidx in range(batch_size):
            if num_iters[bidx] == 0:
                num_iters[bidx] = 20  # never converged within the loop; fall back to a fixed count
            if final_decoding[bidx] is None:
                if args.adaptive_decoding == "oracle":
                    final_decoding[bidx] = output_hist[bidx][np.argmax(
                        bleu_hist[bidx])]
                else:
                    final_decoding[bidx] = output_hist[bidx][-1]

        decoding = Variable(torch.LongTensor(np.array(final_decoding)))
        if decoder_masks.is_cuda:
            decoding = decoding.cuda()

    return decoding, all_decodings, num_iters, all_probs
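
# Note: `oracle_converged`, `equality_converged`, and `jaccard_converged` are
# project helpers not shown here. A plausible minimal version of the Jaccard
# test, assuming the `multiset` package (Multiset supports & and | with
# multiplicities): a sentence has "converged" when its last `num_items`
# decodings are near-identical under multiset Jaccard similarity.
from multiset import Multiset

def jaccard_converged_sketch(multiset_hist, num_items=2, threshold=1.0):
    converged = []
    for hist in multiset_hist:
        if len(hist) < num_items:
            converged.append(False)
            continue
        recent = hist[-num_items:]
        sims = [len(a & b) / max(len(a | b), 1)
                for a, b in zip(recent, recent[1:])]
        converged.append(all(s >= threshold for s in sims))
    return converged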
Example #7
def decode_model(args,
                 model,
                 dev,
                 evaluate=True,
                 trg_len_dic=None,
                 decoding_path=None,
                 names=None,
                 maxsteps=None):

    args.logger.info("decoding, f_size={}, beam_size={}, alpha={}".format(
        args.f_size, args.beam_size, args.alpha))
    dev.train = False  # make iterator volatile=True

    if not args.no_tqdm:
        progressbar = tqdm(total=200, desc='start decoding')

    model.eval()
    if not args.debug:
        decoding_path.mkdir(parents=True, exist_ok=True)
        handles = [(decoding_path / name).open('w') for name in names]

    corpus_size = 0
    src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
    all_decs = [[] for idx in range(args.valid_repeat_dec)]
    decoded_words, target_words, decoded_info = 0, 0, 0

    attentions = None
    decoder = model.decoder[0] if args.model is FastTransformer else model.decoder
    pad_id = decoder.field.vocab.stoi['<pad>']
    eos_id = decoder.field.vocab.stoi['<eos>']

    curr_time = 0
    cum_sentences = 0
    cum_tokens = 0
    cum_images = 0  # used for mscoco
    num_iters_total = []

    for iters, dev_batch in enumerate(dev):
        start_t = time.time()

        if args.dataset != "mscoco":
            decoder_inputs, decoder_masks,\
            targets, target_masks,\
            sources, source_masks,\
            encoding, batch_size, rest = model.quick_prepare(
                dev_batch, fast=(type(model) is FastTransformer),
                trg_len_option=args.trg_len_option, trg_len_ratio=args.trg_len_ratio,
                trg_len_dic=trg_len_dic, bp=args.bp)
        else:
            # only use first caption for calculating log likelihood
            all_captions = dev_batch[1]
            dev_batch[1] = dev_batch[1][0]
            decoder_inputs, decoder_masks,\
            targets, target_masks,\
            _, source_masks,\
            encoding, batch_size, rest = model.quick_prepare_mscoco(
                dev_batch, all_captions=all_captions,
                fast=(type(model) is FastTransformer), inputs_dec=args.inputs_dec,
                trg_len_option=args.trg_len_option, max_len=args.max_len,
                trg_len_dic=trg_len_dic, bp=args.bp, gpu=args.gpu > -1)
            sources = None

        cum_sentences += batch_size

        batch_size, src_len, hsize = encoding[0].size()

        if type(model) is Transformer:
            all_decodings = []
            decoding = model(encoding, source_masks, decoder_inputs, decoder_masks,
                             beam=args.beam_size, alpha=args.alpha,
                             decoding=True, feedback=attentions)
            all_decodings.append(decoding)
            num_iters = [0]

        elif type(model) is FastTransformer:
            decoding, all_decodings, num_iters, argmax_all_probs = run_fast_transformer(decoder_inputs, decoder_masks, \
                                        sources, source_masks, targets, encoding, model, args, use_argmax=True)
            num_iters_total.extend(num_iters)

            if not args.use_argmax:
                for _ in range(args.num_samples):
                    _, _, _, sampled_all_probs = run_fast_transformer(decoder_inputs, decoder_masks, \
                                                sources, source_masks, targets, encoding, model, args, use_argmax=False)
                    for iter_ in range(args.valid_repeat_dec):
                        argmax_all_probs[iter_] = argmax_all_probs[
                            iter_] + sampled_all_probs[iter_]

                all_decodings = []
                for iter_ in range(args.valid_repeat_dec):
                    argmax_all_probs[
                        iter_] = argmax_all_probs[iter_] / args.num_samples
                    all_decodings.append(
                        torch.max(argmax_all_probs[iter_], dim=-1)[-1])
                decoding = all_decodings[-1]

        used_t = time.time() - start_t
        curr_time += used_t

        if args.dataset != "mscoco":
            if args.remove_repeats:
                outputs_unidx = [
                    model.output_decoding(d)
                    for d in [('src', sources), (
                        'trg',
                        targets), ('trg', remove_repeats_tensor(decoding))]
                ]
            else:
                outputs_unidx = [
                    model.output_decoding(d)
                    for d in [('src', sources), ('trg',
                                                 targets), ('trg', decoding)]
                ]

        else:
            # make sure there is the same number of captions for every example
            num_captions = len(all_captions[0])
            for c in range(1, len(all_captions)):
                assert (num_captions == len(all_captions[c]))

            # untokenize reference captions
            for n_ref in range(len(all_captions)):
                n_caps = len(all_captions[0])
                for c in range(n_caps):
                    all_captions[n_ref][c] = all_captions[n_ref][c].replace(
                        "@@ ", "")

            outputs_unidx = [list(map(list, zip(*all_captions)))]

        if args.remove_repeats:
            all_dec_outputs = [
                model.output_decoding(d)
                for d in [('trg', remove_repeats_tensor(all_decodings[ii]))
                          for ii in range(len(all_decodings))]
            ]
        else:
            all_dec_outputs = [
                model.output_decoding(d)
                for d in [('trg', all_decodings[ii])
                          for ii in range(len(all_decodings))]
            ]

        corpus_size += batch_size
        if args.dataset != "mscoco":
            cum_tokens += sum([len(xx.split(" ")) for xx in outputs_unidx[0]
                               ])  # NOTE source tokens, not target

        if args.dataset != "mscoco":
            src_outputs += outputs_unidx[0]
            trg_outputs += outputs_unidx[1]
            if args.remove_repeats:
                dec_outputs += remove_repeats(outputs_unidx[-1])
            else:
                dec_outputs += outputs_unidx[-1]

        else:
            trg_outputs += outputs_unidx[0]

        for idx, each_output in enumerate(all_dec_outputs):
            if args.remove_repeats:
                all_decs[idx] += remove_repeats(each_output)
            else:
                all_decs[idx] += each_output

        timings += [used_t]

        if not args.debug:
            for s, t, d in zip(outputs_unidx[0], outputs_unidx[1],
                               outputs_unidx[2]):
                s, t, d = s.replace('@@ ',
                                    ''), t.replace('@@ ',
                                                   ''), d.replace('@@ ', '')
                print(s, file=handles[0], flush=True)
                print(t, file=handles[1], flush=True)
                print(d, file=handles[2], flush=True)

        if not args.no_tqdm:
            progressbar.update(1)
            progressbar.set_description(
                'finishing sentences={}/batches={}, length={}/average iter={}, speed={} sec/batch'.format(
                    corpus_size, iters, src_len, np.mean(np.array(num_iters)), curr_time / (1 + iters)))

    if evaluate:
        for idx, each_dec in enumerate(all_decs):
            if len(all_decs[idx]) != len(trg_outputs):
                break
            if args.dataset != "mscoco":
                bleu_output = computeBLEU(each_dec,
                                          trg_outputs,
                                          corpus=True,
                                          tokenizer=tokenizer)
            else:
                bleu_output = computeBLEUMSCOCO(each_dec,
                                                trg_outputs,
                                                corpus=True,
                                                tokenizer=tokenizer)
            args.logger.info("iter {} | {}".format(idx + 1,
                                                   print_bleu(bleu_output)))

    if args.adaptive_decoding is not None:
        args.logger.info("----------------------------------------------")
        args.logger.info("Average # iters {}".format(np.mean(num_iters_total)))
        bleu_output = computeBLEU(dec_outputs,
                                  trg_outputs,
                                  corpus=True,
                                  tokenizer=tokenizer)
        args.logger.info("Adaptive BLEU | {}".format(print_bleu(bleu_output)))

    args.logger.info("----------------------------------------------")
    args.logger.info("Decoding speed analysis :")
    args.logger.info("{} sentences".format(cum_sentences))
    if args.dataset != "mscoco":
        args.logger.info("{} tokens".format(cum_tokens))
    args.logger.info("{:.3f} seconds".format(curr_time))

    args.logger.info("{:.3f} ms / sentence".format(
        (curr_time / float(cum_sentences) * 1000)))
    if args.dataset != "mscoco":
        args.logger.info("{:.3f} ms / token".format(
            (curr_time / float(cum_tokens) * 1000)))

    args.logger.info("{:.3f} sentences / s".format(
        float(cum_sentences) / curr_time))
    if args.dataset != "mscoco":
        args.logger.info("{:.3f} tokens / s".format(
            float(cum_tokens) / curr_time))
    args.logger.info("----------------------------------------------")

    if args.decode_which > 0:
        args.logger.info("Writing to special file")
        parent = decoding_path / "speed" / "b_{}{}".format(
            args.beam_size if args.model is Transformer else
            args.valid_repeat_dec, "" if args.model is Transformer else
            "_{}".format(args.adaptive_decoding != None))
        args.logger.info(str(parent))
        parent.mkdir(parents=True, exist_ok=True)
        speed_handle = (parent /
                        "results.{}".format(args.decode_which)).open('w')

        print("----------------------------------------------",
              file=speed_handle,
              flush=True)
        print("Decoding speed analysis :", file=speed_handle, flush=True)
        print("{} sentences".format(cum_sentences),
              file=speed_handle,
              flush=True)
        if args.dataset != "mscoco":
            print("{} tokens".format(cum_tokens),
                  file=speed_handle,
                  flush=True)
        print("{:.3f} seconds".format(curr_time),
              file=speed_handle,
              flush=True)

        print("{:.3f} ms / sentence".format(
            (curr_time / float(cum_sentences) * 1000)),
              file=speed_handle,
              flush=True)
        if args.dataset != "mscoco":
            print("{:.3f} ms / token".format(
                (curr_time / float(cum_tokens) * 1000)),
                  file=speed_handle,
                  flush=True)

        print("{:.3f} sentences / s".format(float(cum_sentences) / curr_time),
              file=speed_handle,
              flush=True)
        if args.dataset != "mscoco":
            print("{:.3f} tokens / s".format(float(cum_tokens) / curr_time),
                  file=speed_handle,
                  flush=True)
        print("----------------------------------------------",
              file=speed_handle,
              flush=True)
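
# Note: `remove_repeats` / `remove_repeats_tensor` above are project helpers
# not shown here. A plausible minimal string version: collapse immediately
# repeated tokens, a well-known failure mode of non-autoregressive decoders.
def remove_repeats_sketch(sentences):
    cleaned = []
    for sent in sentences:
        tokens = sent.split()
        kept = [tok for i, tok in enumerate(tokens) if i == 0 or tok != tokens[i - 1]]
        cleaned.append(" ".join(kept))
    return cleaned

# remove_repeats_sketch(["the the cat sat sat on the mat"]) -> ["the cat sat on the mat"]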
Example #8
def decode_model(args, model, dev, teacher_model=None, evaluate=True,
                decoding_path=None, names=None, maxsteps=None):

    args.logger.info("decoding with {}, f_size={}, beam_size={}, alpha={}".format(args.decode_mode, args.f_size, args.beam_size, args.alpha))
    dev.train = False  # make iterator volatile=True

    if maxsteps is None:
        progressbar = tqdm(total=sum([1 for _ in dev]), desc='start decoding')
    else:
        progressbar = tqdm(total=maxsteps, desc='start decoding')

    model.eval()
    if teacher_model is not None:
        assert (args.f_size * args.beam_size > 1), 'teacher re-ranking requires more than one candidate.'
        teacher_model.eval()

    if decoding_path is not None:
        handles = [open(os.path.join(decoding_path, name), 'w') for name in names]

    corpus_size = 0
    src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
    decoded_words, target_words, decoded_info = 0, 0, 0

    attentions = None
    pad_id = model.decoder.field.vocab.stoi['<pad>']
    eos_id = model.decoder.field.vocab.stoi['<eos>']

    curr_time = 0
    for iters, dev_batch in enumerate(dev):

        if maxsteps is not None and iters > maxsteps:
            args.logger.info('complete {} steps of decoding'.format(maxsteps))
            break

        start_t = time.time()
        # encoding
        inputs, input_masks, targets, target_masks, sources, source_masks, encoding, batch_size = model.quick_prepare(dev_batch)

        if args.model is Transformer:
            # decoding from the Transformer

            decoder_inputs, decoder_masks = inputs, input_masks
            decoding = model(encoding, source_masks, decoder_inputs, decoder_masks,
                            beam=args.beam_size, alpha=args.alpha, decoding=True, feedback=attentions)
        else:
            # decoding from the FastTransformer

            if teacher_model is not None:
                encoding_teacher = teacher_model.encoding(sources, source_masks)

            decoder_inputs, input_reorder, decoder_masks, _, fertility = \
                    model.prepare_initial(encoding, sources, source_masks, input_masks, None, mode=args.decode_mode, N=args.f_size)
            batch_size, src_len, hsize = encoding[0].size()
            trg_len = targets.size(1)

            if args.f_size > 1:
                source_masks = source_masks[:, None, :].expand(batch_size, args.f_size, src_len)
                source_masks = source_masks.contiguous().view(batch_size * args.f_size, src_len)
                for i in range(len(encoding)):
                    encoding[i] = encoding[i][:, None, :].expand(
                        batch_size, args.f_size, src_len, hsize).contiguous().view(
                        batch_size * args.f_size, src_len, hsize)
            decoding = model(encoding, source_masks, decoder_inputs, decoder_masks, beam=args.beam_size, decoding=True, feedback=attentions)
            total_size = args.beam_size * args.f_size

            if total_size > 1:
                if args.beam_size > 1:
                    source_masks = source_masks[:, None, :].expand(batch_size * args.f_size,
                        args.beam_size, src_len).contiguous().view(batch_size * total_size, src_len)
                    fertility = fertility[:, None, :].expand(batch_size * args.f_size,
                        args.beam_size, src_len).contiguous().view(batch_size * total_size, src_len)
                    # fertility = model.apply_mask(fertility, source_masks, -1)

                if teacher_model is not None:  # use teacher model to re-rank the translation
                    decoder_masks = teacher_model.prepare_masks(decoding)

                    for i in range(len(encoding_teacher)):
                        encoding_teacher[i] = encoding_teacher[i][:, None, :].expand(
                            batch_size,  total_size, src_len, hsize).contiguous().view(
                            batch_size * total_size, src_len, hsize)

                    student_inputs,  _ = teacher_model.prepare_inputs( dev_batch, decoding, decoder_masks)
                    student_targets, _ = teacher_model.prepare_targets(dev_batch, decoding, decoder_masks)
                    out, probs = teacher_model(encoding_teacher, source_masks, student_inputs, decoder_masks, return_probs=True, decoding=False)
                    _, teacher_loss = model.batched_cost(student_targets, decoder_masks, probs, batched=True)  # student-loss (MLE)

                    # reranking the translation
                    teacher_loss = teacher_loss.view(batch_size, total_size)
                    decoding = decoding.view(batch_size, total_size, -1)
                    fertility = fertility.view(batch_size, total_size, -1)
                    lp = decoder_masks.sum(1).view(batch_size, total_size) ** (1 - args.alpha)
                    teacher_loss = teacher_loss * Variable(lp)

                    # selected index
                    selected_idx = (-teacher_loss).topk(1, 1)[1]   # batch x 1
                    decoding = decoding.gather(1, selected_idx[:, :, None].expand(batch_size, 1, decoding.size(-1)))[:, 0, :]
                    fertility = fertility.gather(1, selected_idx[:, :, None].expand(batch_size, 1, fertility.size(-1)))[:, 0, :]

                else:   # (cheating, re-rank by sentence-BLEU score)

                    # compute GLEU score to select the best translation
                    trg_output = model.output_decoding(('trg', targets[:, None, :].expand(batch_size,
                                                        total_size, trg_len).contiguous().view(batch_size * total_size, trg_len)))
                    dec_output = model.output_decoding(('trg', decoding))
                    bleu_score = computeBLEU(dec_output, trg_output, corpus=False, tokenizer=tokenizer).contiguous().view(batch_size, total_size)
                    bleu_score = bleu_score.cuda(args.gpu)
                    selected_idx = bleu_score.max(1)[1]

                    decoding = decoding.view(batch_size, total_size, -1)
                    fertility = fertility.view(batch_size, total_size, -1)
                    decoding = decoding.gather(1, selected_idx[:, None, None].expand(batch_size, 1, decoding.size(-1)))[:, 0, :]
                    fertility = fertility.gather(1, selected_idx[:, None, None].expand(batch_size, 1, fertility.size(-1)))[:, 0, :]

                    assert (fertility.data.sum() - (decoding.data != 1).long().sum() == 0), \
                        'total fertility must match the number of non-pad tokens decoded'


        used_t = time.time() - start_t
        curr_time += used_t

        real_mask = 1 - ((decoding.data == eos_id) + (decoding.data == pad_id)).float()
        outputs = [model.output_decoding(d) for d in [('src', sources), ('trg', targets), ('trg', decoding)]]

        corpus_size += batch_size
        src_outputs += outputs[0]
        trg_outputs += outputs[1]
        dec_outputs += outputs[2]
        timings += [used_t]

        if decoding_path is not None:
            for s, t, d in zip(outputs[0], outputs[1], outputs[2]):
                if args.no_bpe:
                    s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '')
                print(s, file=handles[0], flush=True)
                print(t, file=handles[1], flush=True)
                print(d, file=handles[2], flush=True)

            if args.model is FastTransformer:
                with torch.cuda.device_of(fertility):
                    fertility = fertility.data.tolist()
                    for f in fertility:
                        f = ' '.join([str(fi) for fi in cutoff(f, 0)])
                        print(f, file=handles[3], flush=True)

        progressbar.update(1)
        progressbar.set_description('finishing sentences={}/batches={}, speed={} sec/batch'.format(corpus_size, iters, curr_time / (1 + iters)))

    if evaluate:
        corpus_gleu = computeGLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
        corpus_bleu = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
        args.logger.info("The dev-set corpus GLEU = {}".format(corpus_gleu))
        args.logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu))
Example #9
def valid_model(args, model, dev, dev_metrics=None,
                print_out=False, teacher_model=None):
    print_seqs = ['SRC ', 'REF '] + ['HYP{}'.format(ii + 1) for ii in range(args.valid_repeat_dec)]
    trg_outputs = []
    all_outputs = [[] for _ in range(args.valid_repeat_dec)]
    outputs_data = {}

    model.eval()
    if teacher_model is not None:
        teacher_model.eval()

    for j, dev_batch in enumerate(dev):
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks, \
        encoding, batch_size = model.quick_prepare(dev_batch)

        if type(model) is Transformer:
            decoder_inputs, decoder_masks = inputs, input_masks
        elif type(model) is FastTransformer:
            decoder_inputs, _, decoder_masks = \
                model.prepare_initial(encoding, sources, source_masks, input_masks)
            initial_inputs = decoder_inputs

        if type(model) is Transformer:
            decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks,
                                         decoding=True, return_probs=True)
        elif type(model) is FastTransformer:
            losses, all_decodings = [], []
            for iter_ in range(args.valid_repeat_dec):
                curr_iter = min(iter_, args.num_shared_dec-1)
                next_iter = min(curr_iter + 1, args.num_shared_dec-1)

                decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks,
                                             decoding=True, return_probs=True, iter_=curr_iter)
                losses.append(model.cost(targets, target_masks, out=out, iter_=curr_iter))
                all_decodings.append(decoding)

                logits = model.decoder[curr_iter].out(out)
                _, argmax = torch.max(logits, dim=-1)

                decoder_inputs = F.embedding(argmax, model.decoder[next_iter].out.weight *
                                                     math.sqrt(args.d_model))
                if args.sum_out_and_emb:
                    decoder_inputs += out

        dev_outputs = [('src', sources), ('trg', targets)]
        if type(model) is Transformer:
            dev_outputs += [('trg', decoding)]
        elif type(model) is FastTransformer:
            dev_outputs += [('trg', xx) for xx in all_decodings]

        dev_outputs = [model.output_decoding(d) for d in dev_outputs]

        if print_out:
            for k, d in enumerate(dev_outputs):
                args.logger.info("{}: {}".format(print_seqs[k], d[0]))
            args.logger.info('------------------------------------------------------------------')

        trg_outputs += dev_outputs[1]
        for ii, d_outputs in enumerate(dev_outputs[2:]):
            all_outputs[ii] += d_outputs

        if dev_metrics is not None:
            dev_metrics.accumulate(batch_size, *losses)

    bleu = [100 * computeBLEU(ith_output, trg_outputs, corpus=True, tokenizer=tokenizer) for ith_output in all_outputs]

    outputs_data['bleu'] = bleu
    if dev_metrics is not None:
        args.logger.info(dev_metrics)

    args.logger.info("dev BLEU: {}".format(bleu))
    return outputs_data
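
# In the refinement loop above, once `iter_` reaches `num_shared_dec - 1` every
# further pass reuses the last shared decoder. A tiny check of that clamping
# schedule (values are illustrative):
num_shared_dec, valid_repeat_dec = 2, 4
schedule = [min(i, num_shared_dec - 1) for i in range(valid_repeat_dec)]
assert schedule == [0, 1, 1, 1]  # decoder index used at each iteration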
Example #10
def decode_model(args,
                 watcher,
                 model,
                 dev,
                 evaluate=True,
                 decoding_path=None,
                 names=None,
                 maxsteps=None):

    print_seqs = ['[sources]', '[targets]', '[decoded]']

    args.logger.info("decoding beam-search: beam_size={}, alpha={}".format(
        args.beam_size, args.alpha))
    dev.train = False  # make iterator volatile=True

    if maxsteps is None:
        maxsteps = sum([1 for _ in dev])
    progressbar = tqdm(total=maxsteps, desc='start decoding')

    model.eval()
    if decoding_path is not None:
        handles = [
            open(os.path.join(decoding_path, name), 'w') for name in names
        ]

    corpus_size = 0
    src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
    decoded_words, target_words, decoded_info = 0, 0, 0

    attentions = None
    pad_id = model.decoder.field.vocab.stoi['<pad>']
    eos_id = model.decoder.field.vocab.stoi['<eos>']

    curr_time = 0
    for iters, dev_batch in enumerate(dev):

        if iters > maxsteps:
            args.logger.info('complete {} steps of decoding'.format(maxsteps))
            break

        start_t = time.time()

        # prepare the data
        source_inputs, source_outputs, source_masks, \
        target_inputs, target_outputs, target_masks = model.prepare_data(dev_batch)

        if not args.real_time:
            # encoding
            encoding_outputs = model.encoding(source_inputs, source_masks)

            # decoding
            decoding_outputs = model.decoding(encoding_outputs,
                                              source_masks,
                                              target_inputs,
                                              target_masks,
                                              beam=args.beam_size,
                                              alpha=args.alpha,
                                              decoding=True,
                                              return_probs=False)
        else:
            # currently only supports simultaneous greedy decoding
            decoding_outputs = model.simultaneous_decoding(
                source_inputs, source_masks)

        # reverse to string-sequence
        dev_outputs = [
            model.io_enc.reverse(source_outputs),
            model.io_dec.reverse(target_outputs),
            model.io_dec.reverse(decoding_outputs)
        ]

        used_t = time.time() - start_t
        curr_time += used_t

        real_mask = 1 - ((decoding_outputs == eos_id) +
                         (decoding_outputs == pad_id)).float()

        corpus_size += source_inputs.size(0)
        src_outputs += dev_outputs[0]
        trg_outputs += dev_outputs[1]
        dec_outputs += dev_outputs[2]
        timings += [used_t]

        if decoding_path is not None:
            for s, t, d in zip(dev_outputs[0], dev_outputs[1], dev_outputs[2]):
                if args.no_bpe:
                    s, t, d = s.replace('@@ ', ''), t.replace('@@ ',
                                                              ''), d.replace(
                                                                  '@@ ', '')
                print(s, file=handles[0], flush=True)
                print(t, file=handles[1], flush=True)
                print(d, file=handles[2], flush=True)

        progressbar.update(1)
        progressbar.set_description(
            'finishing sentences={}/batches={}, speed={:.2f} sentences / sec'.
            format(corpus_size, iters, corpus_size / curr_time))

    if evaluate:
        corpus_bleu = computeBLEU(dec_outputs,
                                  trg_outputs,
                                  corpus=True,
                                  tokenizer=debpe)
        args.logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu))