Ejemplo n.º 1
0
def call_onmt(words, model_name, n_best=10, return_scores=False):
    """Translate *words* with the OpenNMT model *model_name*.

    Adapted from OpenNMT's translate.py.

    Args:
        words: source input forwarded to the option builder.
        model_name: key used by _give_model to load fields/model/options.
        n_best: number of hypotheses kept per input.
        return_scores: when truthy, pair each hypothesis with its score.

    Returns:
        A list of n-best chunks, or of (chunk, score) pair lists when
        return_scores is truthy.
    """
    stream = fake_stream()
    fields, model, model_opt = _give_model(model_name)
    opt = opennmt_opts("", **_default_kwargs(words, n_best))
    scorer = GNMTGlobalScorer.from_opt(opt)
    translator = Translator.from_opt(model,
                                     fields,
                                     opt,
                                     model_opt,
                                     global_scorer=scorer,
                                     out_file=stream,
                                     report_score=False)
    src_shards = split_corpus(opt.src, opt.shard_size)
    # With no target corpus, pair every source shard with None.
    tgt_shards = (split_corpus(opt.tgt, opt.shard_size)
                  if opt.tgt is not None else repeat(None))
    for src_shard, tgt_shard in zip(src_shards, tgt_shards):
        translated = translator.translate(
            src=src_shard,
            tgt=tgt_shard,
            batch_size=opt.batch_size,
            attn_debug=opt.attn_debug)
        stream.load_scores(translated)
    chunks = _parse_fake_stream(stream, n_best)
    scores = stream.get_scores()
    if return_scores:
        return [list(zip(c, s)) for c, s in zip(chunks, scores)]
    return chunks
Ejemplo n.º 2
0
def evaluation(model, ds: Dataset, vocab):
    """Translate the first validation batch of *ds* and print the results."""
    text_reader = inputters.str2reader["text"]
    readers, data = inputters.Dataset.config([
        ("src", {"reader": text_reader(), "data": ds.val.source}),
        ("tgt", {"reader": text_reader(), "data": ds.val.target}),
    ])

    dataset = inputters.Dataset(
        vocab, readers, data, sort_key=inputters.str2sortkey["text"])
    batches = inputters.OrderedIterator(
        dataset=dataset,
        batch_size=10,
        train=False,
        sort=False,
        sort_within_batch=True,
        shuffle=False,
    )

    scorer = GNMTGlobalScorer(
        alpha=0.7, beta=0., length_penalty="avg", coverage_penalty="none")
    builder = TranslationBuilder(data=dataset, fields=vocab)

    translator = Translator(
        model=model,
        fields=vocab,
        src_reader=text_reader,
        tgt_reader=text_reader,
        global_scorer=scorer,
        gpu=0 if cuda.is_available() else -1,
    )

    # Only the first batch is translated and logged.
    for batch in batches:
        trans_batch = translator.translate_batch(
            batch, [vocab["src"].base_field.vocab], attn_debug=False)
        for translation in builder.from_batch(trans_batch):
            print(translation.log(0))
        break

    return
Ejemplo n.º 3
0
    def __init__(self,
                 config,
                 num_layers = 1,
                 bidirectional=True):
        """Build the encoder/decoder/generator networks for the DQN.

        Args:
            config: provides emb_size, src/tgt vocab sizes and padding
                indices, rnn_size, rnn_type, DISTRIBUTIONAL and QUANTILES.
            num_layers: RNN layer count for both encoder and decoder.
            bidirectional: whether the encoder RNN is bidirectional.
        """
        super(DQN, self).__init__()

        self.config = c = config
        # Encoder: source embeddings + RNN.
        self.encoder_embeddings = onmt.modules.Embeddings(c.emb_size, c.src_vocab_size, word_padding_idx=c.src_padding, dropout=0)
        self.encoder = onmt.encoders.RNNEncoder(
            hidden_size=c.rnn_size,
            num_layers=num_layers,
            rnn_type=c.rnn_type,
            bidirectional=bidirectional,
            embeddings=self.encoder_embeddings,
            dropout=0.0,
        )

        # Decoder: target embeddings + input-feed RNN.
        self.decoder_embeddings = onmt.modules.Embeddings(c.emb_size, c.tgt_vocab_size, word_padding_idx=c.tgt_padding, dropout=0)
        self.decoder = onmt.decoders.decoder.InputFeedRNNDecoder(
            hidden_size=c.rnn_size,
            num_layers=num_layers,
            bidirectional_encoder=bidirectional,
            rnn_type=c.rnn_type,
            embeddings=self.decoder_embeddings,
            dropout=0.0,
        )

        self.generator = Generator(c)

        # Equal weight per quantile for distributional (quantile) RL.
        if config.DISTRIBUTIONAL:
            self.quantile_weight = 1.0 / config.QUANTILES

        # Supervised Learning
        self.pretrain_generator = nn.Sequential(
            nn.Linear(self.config.rnn_size, self.config.tgt_vocab_size),
            nn.LogSoftmax(dim=-1)
        )

        # for supervised inference / beam search
        self.scorer = GNMTGlobalScorer(alpha=0.7,
                                    beta=0.,
                                    length_penalty="avg",
                                    coverage_penalty="none")
Ejemplo n.º 4
0
def run_example():
    """Run a tiny Markov-chain beam search, printing the beam each step."""
    BEAM_SIZE = 2
    N_BEST = 1
    BATCH_SZ = 1
    SEQ_LEN = 3

    initial = [0.35, 0.25, 0.4]
    transition_matrix = [[0.3, 0.6, 0.1], [0.4, 0.2, 0.4], [0.3, 0.4, 0.4]]

    beam = BeamSearch(BEAM_SIZE, BATCH_SZ, 0, 1, 2, N_BEST,
                      GNMTGlobalScorer(0.7, 0., "avg", "none"), 0, 30, False,
                      0, set(), False, 0.)
    beam.initialize(torch.zeros(1, 1), torch.randint(0, 30, (BATCH_SZ, )))

    def print_best_paths(b: BeamSearch, step: int):
        # Show each live hypothesis with its probability.
        print(f'\nstep {step} beam results:')
        for k in range(BEAM_SIZE):
            best_path = b.alive_seq[k].squeeze().tolist()[1:]
            prob = exp(b.topk_log_probs[0][k])
            print(f'prob {prob:.3f} with path {best_path}')

    # Seed the beam with the initial state distribution (log space).
    init_scores = torch.log(torch.tensor([initial], dtype=torch.float))
    init_scores = deepcopy(init_scores.repeat(BATCH_SZ * BEAM_SIZE, 1))
    beam.advance(init_scores, None)
    print_best_paths(beam, 0)

    for step in range(SEQ_LEN - 1):
        # Advance every hypothesis by its row of the transition matrix.
        rows = [transition_matrix[idx]
                for idx in beam.topk_ids.squeeze().tolist()]
        beam.advance(torch.log(torch.tensor(rows)), None)
        beam.update_finished()

        print_best_paths(beam, step + 1)
                         ratio=-0.0,
                         coverage_penalty='none',
                         alpha=0.0,
                         beta=-0.0,
                         block_ngram_repeat=0,
                         ignore_when_blocking=[],
                         replace_unk=False,
                         phrase_table='',
                         verbose=True,
                         dump_beam='',
                         n_best=1,
                         batch_type='sents',
                         gpu=0)

fields, model, model_opt = load_test_model(opt, args)
scorer = GNMTGlobalScorer.from_opt(opt)
out_file = codecs.open(opt.output, 'w+', 'utf-8')
translator = Translator.from_opt(model,
                                 fields,
                                 opt,
                                 model_opt,
                                 args,
                                 global_scorer=scorer,
                                 out_file=out_file,
                                 report_align=opt.report_align,
                                 report_score=False,
                                 logger=None)

res = []
n = 1
with open(args.input_file, 'r') as f:
Ejemplo n.º 6
0
    def __init__(
        self,
        encoder,
        decoder_hidden,
        embeddings,
        max_layer=12,
        src_pad_idx=0,
        encoder_hidden=None,
        latent_size=None,
        scalar_mix=False,
        aggregator="mean",
        teacher_forcing_p=0.3,
        classification=None,
        attentional=False,
        definition_encoder=None,
        word_dropout_p=None,
        decoder_num_layers=2,
    ):
        """Build the LSTM definition-probing model: context/definition
        encoders, an LSTM decoder and the variational latent layers.

        NOTE(review): several attributes are read below but never assigned
        in this __init__ (see inline notes); unless a subclass or earlier
        hook sets them, construction raises AttributeError — compare with
        the sibling class that does assign them.
        """
        super(DefinitionProbingLSTM, self).__init__()

        self.embeddings = embeddings
        self.encoder_hidden = encoder_hidden
        self.decoder_hidden = decoder_hidden
        self.decoder_num_layers = decoder_num_layers
        self.encoder = encoder
        self.latent_size = latent_size
        self.src_pad_idx = src_pad_idx
        self.aggregator = aggregator
        self.context_feed_forward = nn.Linear(self.encoder_hidden, self.encoder_hidden)
        self.scalar_mix = None
        if scalar_mix:
            # NOTE(review): self.max_layer is never assigned here and the
            # max_layer parameter is otherwise unused — a missing
            # `self.max_layer = max_layer` looks likely; confirm.
            self.scalar_mix = ScalarMix(self.max_layer + 1)
        # Beam-search scorer used at decode time.
        self.global_scorer = GNMTGlobalScorer(
            alpha=2, beta=None, length_penalty="avg", coverage_penalty=None
        )
        # NOTE(review): self.encoder_num_layers and self.dropout_dict are
        # read below but never assigned in this method — confirm they are
        # provided elsewhere.
        self.definition_encoder = LSTM_Encoder(
            self.embeddings._def,
            self.encoder_hidden,
            self.encoder_num_layers,
            self.dropout_dict.src,
            self.dropout_dict.src,
        )
        self.context_encoder = LSTM_Encoder(
            self.embeddings.src,
            self.encoder_hidden,
            self.encoder_num_layers,
            self.dropout_dict.src,
            self.dropout_dict.src,
        )

        # NOTE(review): self.variational is also never assigned in this
        # __init__ (there is no `variational` parameter) — confirm.
        self.decoder = LSTM_Decoder(
            embeddings.tgt,
            hidden=self.decoder_hidden,
            encoder_hidden=self.encoder_hidden,
            num_layers=self.decoder_num_layers,
            word_dropout=word_dropout_p,
            teacher_forcing_p=teacher_forcing_p,
            attention="general" if attentional else None,
            dropout=DotMap({"input": 0.5, "output": 0.5}),
            decoder="VDM" if self.variational else "LSTM",
            variational=self.variational,
            latent_size=self.latent_size,
        )

        # Target KL used by the variational objective.
        self.target_kl = 1.0
        self.definition_feed_forward = nn.Linear(
            self.encoder_hidden, self.encoder_hidden
        )
        # Posterior projections into the latent space.
        self.mean_layer = nn.Linear(self.latent_size, self.latent_size)
        self.logvar_layer = nn.Linear(self.latent_size, self.latent_size)
        self.w_z_post = nn.Sequential(
            nn.Linear(self.encoder_hidden * 2, self.latent_size), nn.Tanh()
        )
        # Prior projections into the latent space.
        self.mean_prime_layer = nn.Linear(self.latent_size, self.latent_size)
        self.logvar_prime_layer = nn.Linear(self.latent_size, self.latent_size)
        self.w_z_prior = nn.Sequential(
            nn.Linear(self.encoder_hidden, self.latent_size), nn.Tanh()
        )
        # Map a latent sample into the decoder's hidden size.
        self.z_project = nn.Sequential(
            nn.Linear(self.latent_size, self.decoder_hidden), nn.Tanh()
        )
Ejemplo n.º 7
0
    def __init__(
        self,
        encoder,
        encoder_pretrained,
        encoder_frozen,
        decoder_hidden,
        embeddings,
        max_layer=12,
        src_pad_idx=0,
        encoder_hidden=None,
        variational=None,
        latent_size=None,
        scalar_mix=False,
        aggregator="mean",
        teacher_forcing_p=0.3,
        classification=None,
        attentional=False,
        definition_encoder=None,
        word_dropout_p=None,
        decoder_num_layers=None,
    ):
        """Build the definition-probing model around a (possibly
        pretrained, possibly frozen) context encoder.

        Args:
            encoder: context encoder module; when encoder_pretrained is
                truthy its config.hidden_size overrides encoder_hidden.
            encoder_frozen: when truthy, all encoder parameters get
                requires_grad=False.
            variational: when truthy, a VDM decoder and the latent
                (mean/logvar, prior/posterior) layers are built.
        """
        super(DefinitionProbing, self).__init__()

        self.embeddings = embeddings
        self.variational = variational
        self.encoder_hidden = encoder_hidden
        self.decoder_hidden = decoder_hidden
        self.decoder_num_layers = decoder_num_layers
        self.encoder = encoder
        self.latent_size = latent_size
        self.src_pad_idx = src_pad_idx
        if encoder_pretrained:
            # Pretrained encoders dictate their own hidden size.
            self.encoder_hidden = self.encoder.config.hidden_size
        if encoder_frozen:
            for param in self.encoder.parameters():
                param.requires_grad = False
        self.max_layer = max_layer
        self.aggregator = aggregator
        if self.aggregator == "span":
            self.span_extractor = SelfAttentiveSpanExtractor(self.encoder_hidden)
        self.context_feed_forward = nn.Linear(self.encoder_hidden, self.encoder_hidden)
        self.scalar_mix = None
        if scalar_mix:
            # Mix all encoder layers up to and including max_layer.
            self.scalar_mix = ScalarMix(self.max_layer + 1)
        # Beam-search scorer used at decode time.
        self.global_scorer = GNMTGlobalScorer(
            alpha=2, beta=None, length_penalty="avg", coverage_penalty=None
        )

        # NOTE(review): latent_size may be None when not variational but is
        # still forwarded to the decoder — presumably LSTM_Decoder tolerates
        # that; confirm.
        self.decoder = LSTM_Decoder(
            embeddings.tgt,
            hidden=self.decoder_hidden,
            encoder_hidden=self.encoder_hidden,
            num_layers=self.decoder_num_layers,
            word_dropout=word_dropout_p,
            teacher_forcing_p=teacher_forcing_p,
            attention="general" if attentional else None,
            dropout=DotMap({"input": 0.5, "output": 0.5}),
            decoder="VDM" if self.variational else "LSTM",
            variational=self.variational,
            latent_size=self.latent_size,
        )

        # Target KL used by the variational objective.
        self.target_kl = 1.0
        if self.variational:
            self.definition_encoder = definition_encoder
            self.definition_feed_forward = nn.Linear(
                self.encoder_hidden, self.encoder_hidden
            )
            # Posterior projections into the latent space.
            self.mean_layer = nn.Linear(self.latent_size, self.latent_size)
            self.logvar_layer = nn.Linear(self.latent_size, self.latent_size)
            self.w_z_post = nn.Sequential(
                nn.Linear(self.encoder_hidden * 2, self.latent_size), nn.Tanh()
            )
            # Prior projections into the latent space.
            self.mean_prime_layer = nn.Linear(self.latent_size, self.latent_size)
            self.logvar_prime_layer = nn.Linear(self.latent_size, self.latent_size)
            self.w_z_prior = nn.Sequential(
                nn.Linear(self.encoder_hidden, self.latent_size), nn.Tanh()
            )
            # Map a latent sample into the decoder's hidden size.
            self.z_project = nn.Sequential(
                nn.Linear(self.latent_size, self.decoder_hidden), nn.Tanh()
            )
Ejemplo n.º 8
0
def batch_beam_search_trs(model, inputs, batch_size, device="cpu", max_len=128, beam_size=20, n_best=1, alpha=1.):
    """Beam search with batched input for a Transformer model.

    Arguments:
        model {torch.nn.Module} -- must implement .encode() and .decode();
            the output generator is looked up as model.proj or model.gen
        inputs {list} -- list of torch.Tensor inputs for encode()
        batch_size {int} -- beam-search batch size; a smaller input batch
            is padded by repeating its first example

    Keyword Arguments:
        device {str} -- device to eval model (default: {"cpu"})
        max_len {int} -- maximum decoding length
        beam_size {int} -- beam width
        n_best {int} -- hypotheses kept per example
        alpha {float} -- length-penalty weight for GNMTGlobalScorer

    Returns:
        result -- 2D list (B, N-best), each element is an (seq, score) pair
    """
    beam = BeamSearch(beam_size, batch_size,
                    pad=C.PAD, bos=C.BOS, eos=C.EOS,
                    n_best=n_best,
                    mb_device=device,
                    global_scorer=GNMTGlobalScorer(alpha, 0.1, "avg", "none"),
                    min_length=0,
                    max_length=max_len,
                    ratio=0.0,
                    memory_lengths=None,
                    block_ngram_repeat=False,
                    exclusion_tokens=None,
                    stepwise_penalty=True,
                    return_attention=False,
                    )
    model.eval()
    is_finished = [False] * beam.batch_size
    with torch.no_grad():
        src_ids, _, src_pos, _, _, src_key_padding_mask, _, original_memory_key_padding_mask = list(
            map(lambda x: x.to(device), inputs))

        # Pad a short final batch by repeating its first example so the
        # beam's fixed batch_size is satisfied.
        if src_ids.shape[1] != batch_size:
            diff = batch_size - src_ids.shape[1]
            src_ids = torch.cat([src_ids] + [src_ids[:, :1]] * diff, dim=1)
            src_pos = torch.cat([src_pos] + [src_pos[:, :1]] * diff, dim=1)
            src_key_padding_mask = torch.cat(
                [src_key_padding_mask] + [src_key_padding_mask[:1]] * diff, dim=0)
            original_memory_key_padding_mask = torch.cat(
                [original_memory_key_padding_mask]
                + [original_memory_key_padding_mask[:1]] * diff, dim=0)

        model.to(device)
        original_memory = model.encode(src_ids, src_pos, src_key_padding_mask=src_key_padding_mask)

        memory = original_memory
        memory_key_padding_mask = original_memory_key_padding_mask
        while not beam.done:
            len_decoder_inputs = beam.alive_seq.shape[1]
            dec_pos = torch.arange(1, len_decoder_inputs + 1).repeat(
                beam.alive_seq.shape[0], 1).permute(1, 0).to(device)

            # Repeat memory / mask along the batch dim so they match the
            # beam's flattened (batch * beam) layout.
            repeated_memory = memory.repeat(1, 1, beam.beam_size).reshape(
                memory.shape[0], -1, memory.shape[-1])
            repeated_memory_key_padding_mask = memory_key_padding_mask.repeat(
                1, beam.beam_size).reshape(-1, memory_key_padding_mask.shape[1])

            # NOTE(review): `_` is whatever the unpacking above last bound;
            # passing it as decode()'s third argument looks accidental — confirm.
            decoder_outputs = model.decode(beam.alive_seq.permute(1, 0), dec_pos, _, repeated_memory, memory_key_padding_mask=repeated_memory_key_padding_mask)[-1]
            if hasattr(model, "proj"):
                logits = model.proj(decoder_outputs)
            elif hasattr(model, "gen"):
                logits = model.gen(decoder_outputs)
            else:
                raise ValueError("Unknown generator!")

            # NOTE(review): dim=1 assumes the vocab axis is 1 — confirm shape.
            log_probs = torch.nn.functional.log_softmax(logits, dim=1)
            beam.advance(log_probs, None)
            if beam.is_finished.any():
                beam.update_finished()

                # Mark examples whose n-best list is now complete.
                # (Fixed: the loop variable used to shadow the n_best
                # parameter.)
                for i, preds in enumerate(beam.predictions):
                    if not is_finished[i] and len(preds) == beam.n_best:
                        is_finished[i] = True

                alive_example_idx = [i for i in range(len(is_finished))
                                     if not is_finished[i]]
                if alive_example_idx:
                    memory = original_memory[:, alive_example_idx, :]
                    memory_key_padding_mask = original_memory_key_padding_mask[alive_example_idx]

    # packing data for easy accessing
    results = []
    for batch_preds, batch_scores in zip(beam.predictions, beam.scores):
        n_best_result = []
        for pred, score in zip(batch_preds, batch_scores):
            assert isinstance(pred, torch.Tensor)
            assert isinstance(score, torch.Tensor)
            n_best_result.append((pred.tolist(), score.item()))
        results.append(n_best_result)

    return results
Ejemplo n.º 9
0
    def __init__(self, model_dir):
        """Load the OpenNMT checkpoint in *model_dir* and build a
        Translator, plus a Trainer when online learning is configured.

        Raises:
            ValueError: if *model_dir* is not a directory.
        """
        # Model dir
        self._model_dir = os.path.abspath(model_dir)
        if not os.path.isdir(self._model_dir):
            raise ValueError(f"{model_dir} doesn't exist")

        # Extended model
        self._extended_model = ExtendedModel(model_dir)

        # Config
        self._config = self._extended_model.config

        # Options
        self._opts = self._config.opts

        # Load the checkpoint on CPU and recover its training options.
        model_path = self._opts.models[0]
        checkpoint = torch.load(
            model_path, map_location=lambda storage, loc: storage
        )
        self._model_opts = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(self._model_opts)
        ArgumentParser.validate_model_opts(self._model_opts)

        # Extract vocabulary (handling the pre-0.9 vocab format).
        vocab = checkpoint["vocab"]
        if inputters.old_style_vocab(vocab):
            self._fields = inputters.load_old_vocab(
                vocab, "text", dynamic_dict=False
            )
        else:
            self._fields = vocab

        # Train_steps
        self._train_steps = self._model_opts.train_steps

        # Build OpenNMT model
        self._opennmt_model = build_base_model(
            self._model_opts,
            self._fields,
            use_gpu(self._opts),
            checkpoint,
            self._opts.gpu,
        )

        # Translator options may be missing from older configs; fall back
        # to defaults (replaces the previous bare try/except blocks, which
        # also swallowed unrelated errors).
        min_length = getattr(self._opts, "min_length", 0)
        max_length = getattr(self._opts, "max_length", 100)
        beam_size = getattr(self._opts, "beam_size", 5)
        replace_unk = getattr(self._opts, "replace_unk", 0)

        self._translator = Translator(
            self._opennmt_model,
            self._fields,
            TextDataReader(),
            TextDataReader(),
            gpu=self._opts.gpu,
            min_length=min_length,
            max_length=max_length,
            beam_size=beam_size,
            replace_unk=replace_unk,
            copy_attn=self._model_opts.copy_attn,
            global_scorer=GNMTGlobalScorer(0.0, -0.0, "none", "none"),
            seed=self.SEED,
        )

        online_learning = self._config.online_learning
        if online_learning:
            # Optimizer settings for online (incremental) updates.
            optimizer_opt = type("", (), {})()
            optimizer_opt.optim = "sgd"
            optimizer_opt.learning_rate = self._opts.learning_rate
            optimizer_opt.train_from = ""
            optimizer_opt.adam_beta1 = 0
            optimizer_opt.adam_beta2 = 0
            optimizer_opt.model_dtype = "fp32"
            optimizer_opt.decay_method = "none"
            optimizer_opt.start_decay_steps = 100000
            optimizer_opt.learning_rate_decay = 1.0
            optimizer_opt.decay_steps = 100000
            optimizer_opt.max_grad_norm = 5
            self._optim = Optimizer.from_opt(
                self._opennmt_model, optimizer_opt, checkpoint=None
            )

            # Trainer settings (single process, no GPU ranks unless set).
            trainer_opt = type("", (), {})()
            trainer_opt.lambda_coverage = 0.0
            trainer_opt.copy_attn = False
            trainer_opt.label_smoothing = 0.0
            trainer_opt.truncated_decoder = 0
            trainer_opt.model_dtype = "fp32"
            trainer_opt.max_generator_batches = 32
            trainer_opt.normalization = "sents"
            trainer_opt.accum_count = [1]
            trainer_opt.accum_steps = [0]
            trainer_opt.world_size = 1
            trainer_opt.average_decay = 0
            trainer_opt.average_every = 1
            trainer_opt.dropout = 0
            trainer_opt.dropout_steps = (0,)
            trainer_opt.gpu_verbose_level = 0
            trainer_opt.early_stopping = 0
            trainer_opt.early_stopping_criteria = (None,)
            trainer_opt.tensorboard = False
            trainer_opt.report_every = 50
            trainer_opt.gpu_ranks = []
            if self._opts.gpu != -1:
                trainer_opt.gpu_ranks = [self._opts.gpu]

            self._trainer = build_trainer(
                trainer_opt,
                self._opts.gpu,
                self._opennmt_model,
                self._fields,
                self._optim,
            )
        else:
            self._trainer = None
Ejemplo n.º 10
0
    def __init__(self, model_dir):
        """Load the checkpoint in *model_dir*, rebuild its fields/model,
        and create a translator plus a trainer.

        Raises:
            ValueError: if *model_dir* is not a directory.
        """
        # Model dir
        self._model_dir = os.path.abspath(model_dir)
        if not os.path.isdir(self._model_dir):
            # Fixed message typo ("doesn't exists'").
            raise ValueError(f"{model_dir} doesn't exist")

        # Extended model
        self._extended_model = ExtendedModel(model_dir)

        # Config
        self._config = self._extended_model.config

        # Options
        self._opts = self._config.opts

        # Load the checkpoint on CPU and recover its training options.
        model_path = self._opts.models[0]
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        self._model_opts = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
        ArgumentParser.update_model_opts(self._model_opts)
        ArgumentParser.validate_model_opts(self._model_opts)

        # Train_steps
        self._train_steps = self._model_opts.train_steps

        # Extract vocabulary (handling the pre-0.9 vocab format).
        vocab = checkpoint['vocab']
        if inputters.old_style_vocab(vocab):
            self._fields = inputters.load_old_vocab(
                vocab,
                self._opts.data_type,
                dynamic_dict=self._model_opts.copy_attn)
        else:
            self._fields = vocab

        # Build model
        self._model = build_base_model(self._model_opts, self._fields,
                                       use_gpu(self._opts), checkpoint,
                                       self._opts.gpu)

        if self._opts.fp32:
            self._model.float()

        # Translator
        scorer = GNMTGlobalScorer.from_opt(self._opts)

        self.translator = OnmtxTranslator.from_opt(
            self._model,
            self._fields,
            self._opts,
            self._model_opts,
            global_scorer=scorer,
            out_file=None,
            report_score=False,
            logger=None,
        )

        # Create trainer, resuming optimizer state from the checkpoint.
        self._optim = Optimizer.from_opt(self._model,
                                         self._opts,
                                         checkpoint=checkpoint)

        device_id = -1  # TODO Handle GPU
        self.trainer = build_trainer(self._opts, device_id, self._model,
                                     self._fields, self._optim)
0
    def generate(self, src, lengths, dec_idx, max_length=20, beam_size=5, n_best=1):
        """Beam-search decode *src* with decoder *dec_idx* (0 or 1).

        Args:
            src: source tensor with batch on dim 1.
            lengths: source lengths tensor.
            dec_idx: which decoder to use (must be 0 or 1).
            max_length: maximum number of decoding steps.
            beam_size: beam width.
            n_best: hypotheses kept per example.

        Returns:
            Dict with "predictions", "scores" and "attention", each a list
            of per-example n-best results.
        """
        assert dec_idx == 0 or dec_idx == 1
        batch_size = src.size(1)

        def var(a):
            return torch.tensor(a, requires_grad=False)

        def rvar(a):
            # Tile along the beam dimension.
            return var(a.repeat(1, beam_size, 1))

        def unbottle(m):
            return m.view(beam_size, batch_size, -1)

        def from_beam(beam):
            # Collect the n-best hypotheses from every per-example beam.
            ret = {"predictions": [], "scores": [], "attention": []}
            for b in beam:
                scores, ks = b.sort_finished(minimum=n_best)
                hyps, attn = [], []
                for times, k in ks[:n_best]:
                    hyp, att = b.get_hyp(times, k)
                    hyps.append(hyp)
                    attn.append(att)
                ret["predictions"].append(hyps)
                ret["scores"].append(scores)
                ret["attention"].append(attn)
            return ret

        scorer = GNMTGlobalScorer(0, 0, "none", "none")

        # NOTE(review): self.cuda() moves the whole module and returns it
        # (always truthy); a device *check* was probably intended — confirm.
        # Fixed: eos/bos were swapped (decoding starts from SOS, so bos
        # must be SOS_IDX and eos EOS_IDX).
        beam = [Beam(beam_size, n_best=n_best,
                     cuda=self.cuda(),
                     global_scorer=scorer,
                     pad=PAD_IDX,
                     eos=EOS_IDX,
                     bos=SOS_IDX,
                     min_length=0,
                     stepwise_penalty=False,
                     block_ngram_repeat=0)
                for __ in range(batch_size)]

        enc_final, memory_bank = self.encoder(src, lengths)

        dec_state = self.choose_decoder(dec_idx).init_decoder_state(
            src, memory_bank, enc_final)

        memory_bank = rvar(memory_bank.data)
        memory_lengths = lengths.repeat(beam_size)
        dec_state.repeat_beam_size_times(beam_size)

        # Unroll the decoder one step at a time, advancing every beam.
        for i in range(max_length):
            if all(b.done() for b in beam):
                break

            inp = var(torch.stack([b.get_current_state() for b in beam])
                      .t().contiguous().view(1, -1))
            inp = inp.unsqueeze(2)

            decoder_output, dec_state, attn = self.choose_decoder(dec_idx)(
                inp, memory_bank, dec_state,
                memory_lengths=memory_lengths, step=i)

            decoder_output = decoder_output.squeeze(0)

            out = unbottle(self.generator(decoder_output).data)

            # beam x tgt_vocab
            beam_attn = unbottle(attn["std"])

            for j, b in enumerate(beam):
                b.advance(out[:, j], beam_attn.data[:, j, :memory_lengths[j]])
                dec_state.beam_update(j, b.get_current_origin(), beam_size)

        return from_beam(beam)