Beispiel #1
0
def build_rnn_model(vocab_size, embed_dim, rnn_size, num_layers, dropout_p, bidirectional):
    """Build one GRU encoder and two independent GRU decoders.

    Each of the three components receives its own ``Embeddings`` table
    (same vocabulary size and vector width, padding index 0); no embedding
    object is shared between them.

    Args:
        vocab_size: Vocabulary size used for every embedding table.
        embed_dim: Width of the word vectors.
        rnn_size: Hidden size handed to the encoder and to both decoders.
        num_layers: Number of layers handed to the encoder and decoders.
        dropout_p: Dropout value forwarded to the encoder and decoders.
        bidirectional: Passed to the encoder as its second positional
            argument and to both decoders as their second positional
            argument.

    Returns:
        Tuple ``(encoder, decoder0, decoder1)``.
    """

    def new_embeddings():
        # One fresh table per caller. The original built the table for
        # decoder1 twice and threw the first one away.
        return Embeddings(
            word_vec_size=embed_dim,
            word_vocab_size=vocab_size,
            word_padding_idx=0
        )

    # Build encoder
    encoder = RNNEncoder("GRU", bidirectional, num_layers, rnn_size,
                         dropout=dropout_p, embeddings=new_embeddings())

    # Build the two decoders, each with a private target embedding table.
    decoder0 = StdRNNDecoder("GRU", bidirectional, num_layers, rnn_size,
                             dropout=dropout_p, embeddings=new_embeddings())
    decoder1 = StdRNNDecoder("GRU", bidirectional, num_layers, rnn_size,
                             dropout=dropout_p, embeddings=new_embeddings())

    return encoder, decoder0, decoder1
Beispiel #2
0
def build_decoder(opt, embeddings):
    """
    Construct the decoder selected by the options.

    ``opt.decoder_type`` picks the transformer or CNN decoder; any other
    value yields an RNN decoder, with ``opt.input_feed`` choosing between
    the input-feeding and the standard variant.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The newly constructed decoder instance.
    """
    kind = opt.decoder_type

    if kind == "transformer":
        return TransformerDecoder(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.global_attention,
            opt.copy_attn,
            opt.self_attn_type,
            opt.dropout,
            embeddings,
        )

    if kind == "cnn":
        return CNNDecoder(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.global_attention,
            opt.copy_attn,
            opt.cnn_kernel_width,
            opt.dropout,
            embeddings,
        )

    # Both RNN decoders take exactly the same positional arguments, so only
    # the class differs.
    rnn_decoder_cls = InputFeedRNNDecoder if opt.input_feed else StdRNNDecoder
    return rnn_decoder_cls(
        opt.rnn_type,
        opt.brnn,
        opt.dec_layers,
        opt.dec_rnn_size,
        opt.global_attention,
        opt.global_attention_function,
        opt.coverage_attn,
        opt.context_gate,
        opt.copy_attn,
        opt.dropout,
        embeddings,
        opt.reuse_copy_attn,
    )
Beispiel #3
0
def build_decoder(opt, embeddings):
    """
    Construct the decoder selected by the options.

    ``opt.decoder_type`` picks the transformer or CNN decoder; any other
    value yields an RNN decoder. Without input feeding that is
    ``StdRNNDecoder``; with input feeding ``opt.key_model`` chooses between
    ``InputFeedRNNDecoder`` ("key_generator") and ``MyInputFeedRNNDecoder``
    ("key_end2end").

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The newly constructed decoder instance.

    Raises:
        AssertionError: input feeding is enabled and ``opt.key_model`` is
            neither "key_generator" nor "key_end2end".
    """
    kind = opt.decoder_type

    if kind == "transformer":
        return TransformerDecoder(
            opt.dec_layers,
            opt.rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.global_attention,
            opt.copy_attn,
            opt.self_attn_type,
            opt.dropout,
            embeddings,
        )

    if kind == "cnn":
        return CNNDecoder(
            opt.dec_layers,
            opt.rnn_size,
            opt.global_attention,
            opt.copy_attn,
            opt.cnn_kernel_width,
            opt.dropout,
            embeddings,
        )

    def rnn_positional_args():
        # Shared by all three RNN decoders; built lazily so the options are
        # only read once the decoder class has been decided.
        return (
            opt.rnn_type,
            opt.brnn,
            opt.dec_layers,
            opt.rnn_size,
            opt.global_attention,
            opt.coverage_attn,
            opt.context_gate,
            opt.copy_attn,
            opt.dropout,
            embeddings,
            opt.reuse_copy_attn,
        )

    if not opt.input_feed:
        return StdRNNDecoder(*rnn_positional_args())

    assert opt.key_model in ["key_generator", "key_end2end"]

    if opt.key_model == "key_generator":
        return InputFeedRNNDecoder(
            *rnn_positional_args(),
            no_sftmax_bf_rescale=opt.no_sftmax_bf_rescale,
            use_retrieved_keys=opt.use_retrieved_keys,
        )

    return MyInputFeedRNNDecoder(
        *rnn_positional_args(),
        not_use_sel_probs=opt.not_use_sel_probs,
        no_sftmax_bf_rescale=opt.no_sftmax_bf_rescale,
        use_retrieved_keys=opt.use_retrieved_keys,
        only_rescale_copy=opt.only_rescale_copy,
    )
Beispiel #4
0
def build_decoder(opt, embeddings):
    """
    Construct the RNN decoder selected by the options.

    Only ``opt.decoder_type == "rnn"`` is supported; ``opt.input_feed``
    chooses between the input-feeding and the standard RNN decoder.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The newly constructed decoder instance.

    Raises:
        ModuleNotFoundError: ``opt.decoder_type`` is anything but "rnn".
    """
    if opt.decoder_type != "rnn":
        raise ModuleNotFoundError("Decoder type not found")

    # Both decoders take the same positional arguments; only the class
    # depends on the input-feeding switch.
    rnn_decoder_cls = InputFeedRNNDecoder if opt.input_feed else StdRNNDecoder
    return rnn_decoder_cls(
        opt.rnn_type,
        opt.brnn,
        opt.dec_layers,
        opt.rnn_size,
        opt.global_attention,
        opt.coverage_attn,
        opt.context_gate,
        opt.copy_attn,
        opt.dropout,
        embeddings,
        opt.reuse_copy_attn,
    )
Beispiel #5
0
    def __init__(self,
                 model_dim=None,
                 model_type=None,
                 word_embedding_dim=None,
                 vocab_size=None,
                 initial_embeddings=None,
                 fine_tune_loaded_embeddings=None,
                 num_classes=None,
                 embedding_keep_rate=None,
                 tracking_lstm_hidden_dim=4,
                 transition_weight=None,
                 encode_reverse=None,
                 encode_bidirectional=None,
                 encode_num_layers=None,
                 lateral_tracking=None,
                 tracking_ln=None,
                 use_tracking_in_composition=None,
                 predict_use_cell=None,
                 use_sentence_pair=False,
                 use_difference_feature=False,
                 use_product_feature=False,
                 mlp_dim=None,
                 num_mlp_layers=None,
                 mlp_ln=None,
                 classifier_keep_rate=None,
                 context_args=None,
                 composition_args=None,
                 with_attention=False,
                 data_type=None,
                 target_vocabulary=None,
                 onmt_module=None,
                 FLAGS=None,
                 data_manager=None,
                 **kwargs):
        """Assemble the encoder, output embeddings, decoder and generator.

        Most keyword arguments are forwarded unchanged to the encoder
        builder; only the ones this constructor uses directly are described
        here.

        Args:
            model_dim: Size used for the output embeddings, the decoder and
                the generator input (squared first when ``model_type`` is
                "LMS").
            model_type: "SPINN", "LMS" and "RNN" select the matching builder,
                which is called with the full keyword set. Every other value
                (e.g. "RLSPINN") is built with ``rl_builder``.
            data_type: Stored on the instance and forwarded to the builder.
            target_vocabulary: Sized container; ``len(target_vocabulary) + 1``
                is the number of output classes.
            onmt_module: Path appended to ``sys.path`` so the ``onmt``
                package can be imported.
            encode_reverse, encode_bidirectional, encode_num_layers:
                Accepted but not used by this constructor.
            **kwargs: Kept on ``self.kwargs``.

        Raises:
            AssertionError: ``use_tracking_in_composition`` is set while
                ``lateral_tracking`` is not.
        """
        super(NMTModel, self).__init__()

        assert not (
            use_tracking_in_composition and not lateral_tracking
        ), "Lateral tracking must be on to use tracking in composition."

        self.kwargs = kwargs

        self.model_dim = model_dim
        self.model_type = model_type
        self.data_type = data_type
        self.target_vocabulary = target_vocabulary

        # Builders that take the full keyword set. Anything else falls
        # through to ``rl_builder`` below.
        if self.model_type == "SPINN":
            encoder_builder = spinn_builder
        elif self.model_type == "LMS":
            encoder_builder = lms_builder
        elif self.model_type == "RNN":
            encoder_builder = rnn_builder
        else:
            encoder_builder = None

        # The original test was `model_type == "SPINN" or "RNN" or "LMS"`,
        # which is always truthy (a non-empty string), so the rl_builder
        # branch could never run.
        if encoder_builder is not None:
            self.encoder = encoder_builder(
                model_dim=model_dim,
                word_embedding_dim=word_embedding_dim,
                vocab_size=vocab_size,
                initial_embeddings=initial_embeddings,
                fine_tune_loaded_embeddings=fine_tune_loaded_embeddings,
                num_classes=num_classes,
                embedding_keep_rate=embedding_keep_rate,
                tracking_lstm_hidden_dim=tracking_lstm_hidden_dim,
                transition_weight=transition_weight,
                use_sentence_pair=use_sentence_pair,
                lateral_tracking=lateral_tracking,
                tracking_ln=tracking_ln,
                use_tracking_in_composition=use_tracking_in_composition,
                predict_use_cell=predict_use_cell,
                use_difference_feature=use_difference_feature,
                use_product_feature=use_product_feature,
                classifier_keep_rate=classifier_keep_rate,
                mlp_dim=mlp_dim,
                num_mlp_layers=num_mlp_layers,
                mlp_ln=mlp_ln,
                context_args=context_args,
                composition_args=composition_args,
                with_attention=with_attention,
                data_type=data_type,
                onmt_module=onmt_module,
                FLAGS=FLAGS,
                data_manager=data_manager)
        else:
            self.encoder = rl_builder(data_manager=data_manager,
                                      initial_embeddings=initial_embeddings,
                                      vocab_size=vocab_size,
                                      num_classes=num_classes,
                                      FLAGS=FLAGS,
                                      context_args=context_args,
                                      composition_args=composition_args)
        if self.model_type == "LMS":
            # NOTE(review): this squares model_dim; confirm that `*= 2` was
            # not intended.
            self.model_dim **= 2
        # To-do: move this head of script. onmt_module path needs to be imported to do so.
        # Skip None and avoid piling up duplicates on repeated construction.
        if onmt_module and onmt_module not in sys.path:
            sys.path.append(onmt_module)
        from onmt.decoders.decoder import InputFeedRNNDecoder, StdRNNDecoder, RNNDecoderBase
        from onmt.encoders.rnn_encoder import RNNEncoder
        from onmt.modules import Embeddings

        self.output_embeddings = Embeddings(self.model_dim,
                                            len(target_vocabulary) + 1, 0)

        # For the RNN encoder the down projections below take inputs of
        # width 2 * model_dim and map them back to model_dim.
        # Look at TreeRNN for details (there is a down projection).
        if self.model_type == "RNN":
            self.is_bidirectional = True
            self.down_project = Linear()(2 * self.model_dim,
                                         self.model_dim,
                                         bias=True)
            self.down_project_context = Linear()(2 * self.model_dim,
                                                 self.model_dim,
                                                 bias=True)
        else:
            if self.model_type == "LMS":
                self.spinn = self.encoder.lms
            else:
                self.spinn = self.encoder.spinn
            self.is_bidirectional = False

        self.decoder = StdRNNDecoder("LSTM",
                                     self.is_bidirectional,
                                     1,
                                     self.model_dim,
                                     embeddings=self.output_embeddings)
        # dim=-1 normalises over the class axis. For the 2-D inputs fed in
        # forward() this is the axis the implicit (deprecated) default used.
        self.generator = nn.Sequential(
            nn.Linear(self.model_dim,
                      len(self.target_vocabulary) + 1),
            nn.LogSoftmax(dim=-1))
Beispiel #6
0
class NMTModel(nn.Module):
    """Sequence-to-sequence model: a SPINN-family or RNN encoder feeding an
    OpenNMT ``StdRNNDecoder``, with optional REINFORCE training of the
    encoder's transition policy when ``model_type`` is "RLSPINN"."""

    def __init__(self,
                 model_dim=None,
                 model_type=None,
                 word_embedding_dim=None,
                 vocab_size=None,
                 initial_embeddings=None,
                 fine_tune_loaded_embeddings=None,
                 num_classes=None,
                 embedding_keep_rate=None,
                 tracking_lstm_hidden_dim=4,
                 transition_weight=None,
                 encode_reverse=None,
                 encode_bidirectional=None,
                 encode_num_layers=None,
                 lateral_tracking=None,
                 tracking_ln=None,
                 use_tracking_in_composition=None,
                 predict_use_cell=None,
                 use_sentence_pair=False,
                 use_difference_feature=False,
                 use_product_feature=False,
                 mlp_dim=None,
                 num_mlp_layers=None,
                 mlp_ln=None,
                 classifier_keep_rate=None,
                 context_args=None,
                 composition_args=None,
                 with_attention=False,
                 data_type=None,
                 target_vocabulary=None,
                 onmt_module=None,
                 FLAGS=None,
                 data_manager=None,
                 **kwargs):
        """Assemble the encoder, output embeddings, decoder and generator.

        Most keyword arguments are forwarded unchanged to the encoder
        builder; only the ones this constructor uses directly are described
        here.

        Args:
            model_dim: Size used for the output embeddings, the decoder and
                the generator input (squared first when ``model_type`` is
                "LMS").
            model_type: "SPINN", "LMS" and "RNN" select the matching builder,
                which is called with the full keyword set. Every other value
                (e.g. "RLSPINN") is built with ``rl_builder``.
            data_type: Stored on the instance and forwarded to the builder.
            target_vocabulary: Sized container; ``len(target_vocabulary) + 1``
                is the number of output classes.
            onmt_module: Path appended to ``sys.path`` so the ``onmt``
                package can be imported.
            encode_reverse, encode_bidirectional, encode_num_layers:
                Accepted but not used by this constructor.
            **kwargs: Kept on ``self.kwargs``.

        Raises:
            AssertionError: ``use_tracking_in_composition`` is set while
                ``lateral_tracking`` is not.
        """
        super(NMTModel, self).__init__()

        assert not (
            use_tracking_in_composition and not lateral_tracking
        ), "Lateral tracking must be on to use tracking in composition."

        self.kwargs = kwargs

        self.model_dim = model_dim
        self.model_type = model_type
        self.data_type = data_type
        self.target_vocabulary = target_vocabulary

        # Builders that take the full keyword set. Anything else falls
        # through to ``rl_builder`` below.
        if self.model_type == "SPINN":
            encoder_builder = spinn_builder
        elif self.model_type == "LMS":
            encoder_builder = lms_builder
        elif self.model_type == "RNN":
            encoder_builder = rnn_builder
        else:
            encoder_builder = None

        # The original test was `model_type == "SPINN" or "RNN" or "LMS"`,
        # which is always truthy (a non-empty string), so the rl_builder
        # branch could never run.
        if encoder_builder is not None:
            self.encoder = encoder_builder(
                model_dim=model_dim,
                word_embedding_dim=word_embedding_dim,
                vocab_size=vocab_size,
                initial_embeddings=initial_embeddings,
                fine_tune_loaded_embeddings=fine_tune_loaded_embeddings,
                num_classes=num_classes,
                embedding_keep_rate=embedding_keep_rate,
                tracking_lstm_hidden_dim=tracking_lstm_hidden_dim,
                transition_weight=transition_weight,
                use_sentence_pair=use_sentence_pair,
                lateral_tracking=lateral_tracking,
                tracking_ln=tracking_ln,
                use_tracking_in_composition=use_tracking_in_composition,
                predict_use_cell=predict_use_cell,
                use_difference_feature=use_difference_feature,
                use_product_feature=use_product_feature,
                classifier_keep_rate=classifier_keep_rate,
                mlp_dim=mlp_dim,
                num_mlp_layers=num_mlp_layers,
                mlp_ln=mlp_ln,
                context_args=context_args,
                composition_args=composition_args,
                with_attention=with_attention,
                data_type=data_type,
                onmt_module=onmt_module,
                FLAGS=FLAGS,
                data_manager=data_manager)
        else:
            self.encoder = rl_builder(data_manager=data_manager,
                                      initial_embeddings=initial_embeddings,
                                      vocab_size=vocab_size,
                                      num_classes=num_classes,
                                      FLAGS=FLAGS,
                                      context_args=context_args,
                                      composition_args=composition_args)
        if self.model_type == "LMS":
            # NOTE(review): this squares model_dim; confirm that `*= 2` was
            # not intended.
            self.model_dim **= 2
        # To-do: move this head of script. onmt_module path needs to be imported to do so.
        # Skip None and avoid piling up duplicates on repeated construction.
        if onmt_module and onmt_module not in sys.path:
            sys.path.append(onmt_module)
        from onmt.decoders.decoder import InputFeedRNNDecoder, StdRNNDecoder, RNNDecoderBase
        from onmt.encoders.rnn_encoder import RNNEncoder
        from onmt.modules import Embeddings

        self.output_embeddings = Embeddings(self.model_dim,
                                            len(target_vocabulary) + 1, 0)

        # For the RNN encoder the down projections below take inputs of
        # width 2 * model_dim and map them back to model_dim.
        # Look at TreeRNN for details (there is a down projection).
        if self.model_type == "RNN":
            self.is_bidirectional = True
            self.down_project = Linear()(2 * self.model_dim,
                                         self.model_dim,
                                         bias=True)
            self.down_project_context = Linear()(2 * self.model_dim,
                                                 self.model_dim,
                                                 bias=True)
        else:
            if self.model_type == "LMS":
                self.spinn = self.encoder.lms
            else:
                self.spinn = self.encoder.spinn
            self.is_bidirectional = False

        self.decoder = StdRNNDecoder("LSTM",
                                     self.is_bidirectional,
                                     1,
                                     self.model_dim,
                                     embeddings=self.output_embeddings)
        # dim=-1 normalises over the class axis. For the 2-D inputs fed in
        # forward() this is the axis the implicit (deprecated) default used.
        self.generator = nn.Sequential(
            nn.Linear(self.model_dim,
                      len(self.target_vocabulary) + 1),
            nn.LogSoftmax(dim=-1))

    def forward(self,
                sentences,
                transitions,
                y_batch=None,
                use_internal_parser=False,
                validate_transitions=True,
                **kwargs):
        """Encode the batch and decode the target sequences.

        Args:
            sentences: Source batch, handed to the encoder unchanged.
            transitions: Transition batch, handed to the encoder unchanged.
            y_batch: Iterable of target token sequences. Required in
                practice: its sequence lengths are read even at inference.
            use_internal_parser: Forwarded to the encoder.
            validate_transitions: Forwarded to the encoder.
            **kwargs: Ignored.

        Returns:
            In training mode the tuple ``(output, target, None, mask)``:
            ``output`` stacks the generator output of ``target_maxlen + 1``
            decoding steps, ``target`` is the time-major padded target
            tensor, and ``mask`` is ``self.t_tmask_target`` as a tensor.
            In eval mode a list with the arg-max prediction of each of the
            100 greedy decoding steps.
        """
        (example, spinn_outp, attended, transition_loss, transitions_acc,
         memory_lengths) = self.encoder(
             sentences,
             transitions,
             y_batch,
             use_internal_parser=use_internal_parser,
             validate_transitions=validate_transitions)
        self.sentences = sentences
        self.transitions = transitions
        self.y_batch = y_batch

        nfeat = 1  # 5984 #self.output_embeddings.embedding_size
        target_maxlen = max([len(x) for x in y_batch])
        self.target_maxlen = target_maxlen
        batch_size = example.tokens.size()[0]

        # To-do: Replace below with nn.utils.rnn.pad_sequence
        # Pad every target with 1s up to the longest one. The mask row is
        # one entry longer than the padded target (len(x) + 1 ones).
        padded_targets = []
        target_masks = []
        for x in y_batch:
            n_pad = target_maxlen - len(x)
            padded_targets.append(list(x) + [1] * n_pad)
            target_masks.append([1] * (len(x) + 1) + [0] * n_pad)

        # Transpose from batch-major lists to time-major lists. Only the
        # first target_maxlen mask columns are kept.
        target = [[padded_targets[j][i] for j in range(batch_size)]
                  for i in range(target_maxlen)]
        self.t_tmask_target = [[target_masks[j][i] for j in range(batch_size)]
                               for i in range(target_maxlen)]

        padded_enc_output = spinn_outp

        target = torch.tensor(np.array(target)).view(
            (target_maxlen, batch_size, nfeat)).long()
        target = to_gpu(Variable(target, requires_grad=False))

        if self.model_type in ("SPINN", "RLSPINN", "LMS"):
            src = torch.cat([
                torch.cat(x[::-1]).unsqueeze(0) for x in example.bufs
            ]).transpose(0, 1)
        else:
            src = example.bufs
            attended = attended.transpose(0, 1)
            padded_enc_output = padded_enc_output.view(1, 2 * batch_size,
                                                       self.model_dim)
            attended = self.down_project(attended)
            padded_enc_output = self.down_project_context(padded_enc_output)
        enc_state = self.decoder.init_decoder_state(
            src, attended, (padded_enc_output, padded_enc_output))

        teacher_force = False
        # Drop references that are no longer needed before decoding.
        padded_enc_output = None
        target_masks = None
        padded_targets = None

        if self.training:
            if teacher_force:
                decoder = self.decoder(target, attended, enc_state)
                output = self.generator(decoder[0])
            else:
                start_token = to_gpu(
                    Variable(torch.zeros((1, batch_size, 1)),
                             requires_grad=False)).long()
                dec_state = enc_state
                output = []
                # Looping through max+1 to also predict end_token
                for i in range(target_maxlen + 1):
                    # Step 0 is fed the start token; later steps are fed the
                    # previous gold target token.
                    if i == 0:
                        inp = start_token
                    else:
                        inp = target[i - 1].unsqueeze(0)
                    dec_out, dec_state, attn = self.decoder(
                        inp,
                        attended,
                        dec_state,
                        memory_lengths=memory_lengths,
                        step=i)
                    output.append(
                        self.generator(dec_out.squeeze(0)).unsqueeze(0))
                output = torch.cat(output)

            if self.model_type == "RLSPINN":
                self.encoder.transition_loss = None
                self.output_hook(output, sentences, transitions, y_batch,
                                 self.t_tmask_target)

        # Now just predict during inference mode.
        else:
            start_token = to_gpu(
                Variable(torch.zeros((1, batch_size, 1)),
                         requires_grad=False)).long()
            inp = start_token
            maxpossible = 100
            dec_state = enc_state
            predicted = []

            # TODO: replace with k-beam search
            # inp = target[0].unsqueeze(0)
            # NOTE(review): unlike the training loop, memory_lengths is not
            # passed to the decoder here; confirm that is intended.
            debug = False
            score_matrix = []
            for i in range(maxpossible):
                dec_out, dec_state, attn = self.decoder(inp,
                                                        attended,
                                                        dec_state,
                                                        step=i)
                out = self.generator(dec_out.squeeze(0))
                argmaxed = torch.max(out, 1)[1]
                # Feed the greedy choice back in as the next input.
                inp = argmaxed.unsqueeze(1).unsqueeze(0)
                predicted.append(argmaxed)
                if debug:
                    score_matrix.append(attn['std'].cpu().detach().numpy())
            if debug:
                filename = "attn__" + str(int(time.time()))
                # Context manager so the dump file is always closed.
                with open(filename, "wb") as dump_file:
                    pickle.dump(score_matrix, dump_file)
            return predicted
        return output, target, None, torch.tensor(self.t_tmask_target)

    def build_reward(self, output, target, mask=None, rl_reward="mean"):
        """Compute one reward per example from the decoder output.

        Args:
            output: 3-D decoder output, time-major; the last time step (the
                end-token prediction) is dropped before scoring.
            target: 2-D batch-major tensor of target indices.
            mask: Unused; optional so callers without a mask can omit it.
            rl_reward: "xent" sums ``-log(1 - p(target))`` over time using a
                softmax of ``output``; "mean" / "sum" apply ``nn.NLLLoss``
                with that reduction to each example's sequence.

        Returns:
            1-D CPU tensor with one reward per example.

        Raises:
            NotImplementedError: ``rl_reward`` is not one of the above.
        """
        if rl_reward not in ("xent", "mean", "sum"):
            # Previously this surfaced as an unbound `criterion` variable.
            raise NotImplementedError(
                "Unknown rl_reward: {!r}".format(rl_reward))

        if rl_reward == "xent":
            batch_size = target.size(0)
            seq_length = target.size(1)
            _target = target.permute(1, 0).long()
            output = output[:-1, :, :]  # drop <end> token
            probs = F.softmax(output, dim=2).data.cpu()
            log_inv_prob = torch.log(1 - probs)

            # Looping over seq_length to get a sum of rewards across the full sequence
            # Element-wise mean not supported yet.
            rewards = torch.zeros(batch_size)
            for i in range(seq_length):
                rewards += -1 * torch.gather(
                    log_inv_prob[i], 1, _target[i].unsqueeze(1)).squeeze()
        else:
            output = output.permute(1, 0, 2)
            target = to_gpu(Variable(target))
            if rl_reward == "mean":
                criterion = nn.NLLLoss(reduction="elementwise_mean")
            else:
                criterion = nn.NLLLoss(reduction="sum")
            batch_size = output.shape[0]
            rewards = [0.0] * batch_size

            # Note that we're putting NLLLoss to an unusual use below
            # Instead of passing a full batch of single token, we're passing a single full example of some sequence length
            # If summing, we're summing over all prediction, similarly for elementwise-mean
            for i in range(batch_size):
                rewards[i] = criterion(output[i][:-1, :], target[i].long())
            rewards = torch.tensor([float(x) for x in rewards])

        return rewards

    def build_baseline(self, rewards, sentences, transitions, y_batch=None):
        """Return the baseline to subtract from ``rewards``.

        The strategy is chosen by ``self.encoder.rl_baseline``:
        "ema" (running average kept in ``self.baseline[0]``), "pass" (zero),
        "greedy" (reward of a greedy run) or "value" (the encoder's
        ``baseline_outp``; also sets ``self.value_loss``).

        Raises:
            NotImplementedError: unknown ``rl_baseline`` or, for "value",
                unknown ``rl_value_reward``.
        """
        if self.encoder.rl_baseline == "ema":
            # NOTE(review): self.baseline is not initialised anywhere in this
            # class; presumably set up elsewhere, so verify.
            mu = self.encoder.rl_mu
            baseline = self.baseline[0]
            self.baseline[0] = self.baseline[0] * \
                (1 - mu) + rewards.mean() * mu

        elif self.encoder.rl_baseline == "pass":
            baseline = 0.

        elif self.encoder.rl_baseline == "greedy":
            # Pass inputs to Greedy Max
            # NOTE(review): run_greedy is not defined in this class; confirm
            # where it is expected to come from.
            output = self.run_greedy(sentences, transitions)

            # Estimate Reward
            probs = F.softmax(output, dim=1).data.cpu()
            target = torch.from_numpy(y_batch).long()
            # No mask is available here, so None is passed explicitly (the
            # argument was previously missing altogether).
            approx_rewards = self.build_reward(
                probs, target, None, rl_reward=self.encoder.rl_reward)

            baseline = approx_rewards.view(-1)

        elif self.encoder.rl_baseline == "value":
            output = self.encoder.baseline_outp
            if self.encoder.rl_value_reward == "bce":
                baseline = torch.sigmoid(output).view(-1)
                # Was `s`, an undefined name; the sigmoid output is the
                # prediction being scored.
                self.value_loss = nn.BCELoss()(
                    baseline,
                    to_gpu(Variable(rewards, volatile=not self.training)))
            elif self.encoder.rl_value_reward == "mse":
                baseline = output.view(-1)
                value_loss = nn.MSELoss()(
                    baseline,
                    to_gpu(Variable(rewards, volatile=not self.training)))
                self.value_loss = value_loss.mean()

            else:
                raise NotImplementedError

            baseline = baseline.data.cpu()

        else:
            raise NotImplementedError

        return baseline

    def reinforce(self, advantage):
        """Compute the REINFORCE policy loss for the sampled transitions.

        t_preds  = 200...111 (flattened predictions from sub_batches 1...N)
        t_mask   = 011...111 (binary mask, selecting non-skips only)
        t_logprobs = (B*N)xC (tensor of sub_batch_size * sub_num_batches x transition classes)
        a_index  = 011...(N-1)(N-1)(N-1) (masked sub_batch_indices for each transition)
        t_index  = 013...(B*N-3)(B*N-2)(B*N-1) (masked indices across all sub_batches)

        Args:
            advantage: 1-D tensor with one advantage value per example.

        Returns:
            The policy loss scaled by ``self.encoder.rl_weight``, or a
            one-element tensor holding -1 when computing it fails.
        """

        # TODO: Many of these ops are on the cpu. Might be worth shifting to
        # GPU.
        t_preds = np.concatenate(
            [m['t_preds'] for m in self.spinn.memories if 't_preds' in m])
        t_mask = np.concatenate(
            [m['t_mask'] for m in self.spinn.memories if 't_mask' in m])
        # NOTE(review): filtered on 't_mask', not 't_valid_mask'; confirm
        # both keys are always present together.
        t_valid_mask = np.concatenate(
            [m['t_valid_mask'] for m in self.spinn.memories if 't_mask' in m])
        t_logprobs = torch.cat([
            m['t_logprobs'] for m in self.spinn.memories if 't_logprobs' in m
        ], 0)

        if self.encoder.rl_valid:
            t_mask = np.logical_and(t_mask, t_valid_mask)

        batch_size = advantage.size(0)

        seq_length = t_preds.shape[0] // batch_size
        a_index = np.arange(batch_size)
        a_index = a_index.reshape(1, -1).repeat(seq_length, axis=0).flatten()

        # Patch to handle no valid generated parses
        try:
            a_index = torch.from_numpy(a_index[t_mask]).long()
            t_index = to_gpu(
                Variable(torch.from_numpy(np.arange(
                    t_mask.shape[0])[t_mask])).long())

            self.stats = dict(mean=advantage.mean(),
                              mean_magnitude=advantage.abs().mean(),
                              var=advantage.var(),
                              var_magnitude=advantage.abs().var())

            # Expand advantage.
            advantage = torch.index_select(advantage, 0, a_index)

            # Filter logits.
            t_logprobs = torch.index_select(t_logprobs, 0, t_index)

            actions = to_gpu(
                Variable(torch.from_numpy(t_preds[t_mask]).long().view(-1, 1)))

            log_p_action = torch.gather(t_logprobs, 1, actions)

            # NOTE: Not sure I understand why entropy is inside this
            # multiplication. Investigate?
            policy_losses = log_p_action.view(-1) * \
                to_gpu(Variable(advantage))
            policy_loss = -1. * torch.sum(policy_losses)
            policy_loss /= log_p_action.size(0)
            policy_loss *= self.encoder.rl_weight
        except Exception:
            # Still best-effort, but no longer swallows KeyboardInterrupt or
            # SystemExit as the bare `except:` did.
            print("No valid parses. Policy loss of -1 passed.")
            policy_loss = to_gpu(Variable(torch.ones(1) * -1))

        return policy_loss

    def output_hook(self,
                    output,
                    sentences,
                    transitions,
                    y_batch=None,
                    t_tmask_target=None):
        """Compute the REINFORCE loss for this batch and store it.

        Does nothing outside training mode. Otherwise builds the rewards
        (transition accuracy or ``build_reward``), subtracts the baseline,
        optionally whitens the advantage, and stores the result of
        ``reinforce`` on ``self.policy_loss``.

        Args:
            output: Decoder output as produced in ``forward``.
            sentences: Source batch, forwarded to ``build_baseline``.
            transitions: Gold transitions; used for the accuracy reward and
                forwarded to ``build_baseline``.
            y_batch: Indexable collection of target token sequences.
            t_tmask_target: Target mask, forwarded to ``build_reward``.
        """
        if not self.training:
            return

        # Todo:
        # Pad y_batch to creat a single tensor (also convert from np array of lists to a tensor)
        # ha
        #target = torch.from_numpy(y_batch).long()
        target_rows = [torch.Tensor(seq) for seq in y_batch]
        # build_reward indexes the target batch-first. The original passed
        # self.target_maxlen positionally, which landed in `batch_first`.
        target = nn.utils.rnn.pad_sequence(target_rows,
                                           batch_first=True,
                                           padding_value=1)

        # Get Reward.
        if self.encoder.rl_transition_acc_as_reward:
            ground = np.transpose(transitions)
            pred = np.array([
                m['t_preds'] for m in self.encoder.spinn.memories
                if 't_preds' in m
            ])
            correct = (ground == pred).astype(np.float32)
            trans_acc = np.sum(correct, axis=0) / correct.shape[0]
            rewards = torch.from_numpy(trans_acc)
        else:
            rewards = self.build_reward(output,
                                        target,
                                        t_tmask_target,
                                        rl_reward=self.encoder.rl_reward)

        # Get Baseline.
        baseline = self.build_baseline(rewards, sentences, transitions,
                                       y_batch)

        # Calculate advantage.
        advantage = rewards - baseline

        # Whiten advantage. This is also called Variance Normalization.
        if self.encoder.rl_whiten:
            advantage = (advantage - advantage.mean()) / \
                (advantage.std() + 1e-8)

        # Assign REINFORCE output.
        self.policy_loss = self.reinforce(advantage)
Beispiel #7
0
def train(opt, logger=None):
    """Train an RNN encoder-decoder model and periodically save it.

    Builds the dataset iterators and an ``RNNEncoder`` / ``StdRNNDecoder``
    pair wrapped in an ``NMTModel``, then optimises it with plain SGD.
    After every epoch the model is evaluated on the validation iterator,
    and every ``opt.every_n_epoch_save`` epochs the whole model object is
    written to ``opt.save``.

    Args:
        opt: Options object. Reads ``device``, ``rnn_type``,
            ``bidirectional``, ``num_layers``, ``src_wd_dim``,
            ``hidden_size``, ``lr``, ``epoch``, ``every_n_epoch_save`` and
            ``save``; it is also handed to ``create_soccer_dataset``.
        logger: Optional logger. When given, per-epoch losses and save
            events are reported through ``logger.info``.
    """

    # Create dataset iterator
    SRC, TGT, train_iter, test_iter, val_iter = create_soccer_dataset(opt)
    device = torch.device(opt.device)

    encoder = RNNEncoder(rnn_type=opt.rnn_type,
                         bidirectional=opt.bidirectional,
                         num_layers=opt.num_layers,
                         vocab_size=len(SRC.vocab.itos),
                         word_dim=opt.src_wd_dim,
                         hidden_size=opt.hidden_size).to(device)

    decoder_emb = nn.Embedding(len(TGT.vocab.itos), opt.src_wd_dim)
    decoder = StdRNNDecoder(rnn_type=opt.rnn_type,
                            bidirectional_encoder=opt.bidirectional,
                            num_layers=opt.num_layers,
                            hidden_size=opt.hidden_size,
                            embeddings=decoder_emb).to(device)

    model = NMTModel(encoder=encoder, decoder=decoder,
                     multigpu=False).to(device)

    optimizer = optim.SGD(model.parameters(), lr=float(opt.lr))

    def batch_loss(batch):
        """Run one forward pass and return the masked loss for ``batch``.

        Shared by training and evaluation so that both move the same
        tensors (including the length tensors) to ``device``.
        """
        src, src_lengths = batch.src[0], batch.src[1]
        tgt, tgt_lengths = batch.tgt[0], batch.tgt[1]
        src = src.to(device)
        tgt = tgt.to(device)
        src_lengths = src_lengths.to(device)
        tgt_lengths = tgt_lengths.to(device)
        decoder_outputs, _attns, _dec_state = model(src, tgt, src_lengths)
        # tgt[1:] drops the start token, i.e. the targets are the decoder
        # inputs shifted by one position.
        return masked_cross_entropy(decoder_outputs, tgt[1:], tgt_lengths)

    def average(total, count):
        """Return ``total / count``, or NaN when no batch was seen."""
        return total / count if count else float("nan")

    def evaluation(data_iter):
        """Return the loss averaged over the batches of ``data_iter``.

        Gradients are disabled and the model is switched to eval mode.
        Returns NaN when ``data_iter`` yields no batches.
        """
        model.eval()
        eval_total_loss = 0
        batch_count = 0
        with torch.no_grad():
            for batch_count, batch in enumerate(data_iter, 1):
                eval_total_loss += batch_loss(batch).item()
        return average(eval_total_loss, batch_count)

    # Start training
    for epoch in range(1, int(opt.epoch) + 1):
        start_time = time.time()
        # Turn on training mode which enables dropout.
        model.train()
        total_loss = 0
        batch_count = 0
        for batch_count, batch in enumerate(train_iter, 1):
            optimizer.zero_grad()
            loss = batch_loss(batch)
            loss.backward()
            total_loss += loss.item()
            optimizer.step()

        # Losses are averaged over batches; whether that is a per-word loss
        # depends on how masked_cross_entropy normalises -- TODO confirm.
        train_loss = average(total_loss, batch_count)
        # Doing validation
        val_loss = evaluation(val_iter)

        elapsed = time.time() - start_time

        if logger:
            logger.info('| epoch {:3d} | train_loss {:5.2f} '
                        '| val_loss {:8.2f} | time {:5.1f}s'.format(
                            epoch, train_loss, val_loss, elapsed))

        # Saving model
        if epoch % opt.every_n_epoch_save == 0:
            if logger:
                logger.info("start to save model on {}".format(opt.save))
            with open(opt.save, 'wb') as save_fh:
                torch.save(model, save_fh)
Beispiel #8
0
def build_decoder(opt, embeddings):
    """
    Construct the decoder selected by the options.

    Dispatch order: ``opt.decoder_type == "transformer"``, then ``"cnn"``,
    then an input-feeding RNN decoder when ``opt.input_feed`` is truthy,
    and a standard RNN decoder otherwise. The transformer and both RNN
    variants report their configuration through the module-level logger
    before being built.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The newly constructed decoder instance.
    """
    decoder_type = opt.decoder_type

    if decoder_type == "transformer":
        # dec_rnn_size   = dimension of keys/values/queries (input to FF)
        # transformer_ff = dimension of fat relu
        transformer_summary = (
            opt.dec_layers, opt.dec_rnn_size, opt.transformer_ff,
            opt.heads, opt.global_attention, opt.copy_attn,
            opt.self_attn_type, opt.dropout,
        )
        logger.info(
            'TransformerDecoder: layers %d, input dim %d, '
            'fat relu hidden dim %d, num heads %d, %s global attn, '
            'copy attn %d, self attn type %s, dropout %.2f'
            % transformer_summary)
        return TransformerDecoder(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.heads,
            opt.transformer_ff,
            opt.global_attention,
            opt.copy_attn,
            opt.self_attn_type,
            opt.dropout,
            embeddings,
        )

    if decoder_type == "cnn":
        return CNNDecoder(
            opt.dec_layers,
            opt.dec_rnn_size,
            opt.global_attention,
            opt.copy_attn,
            opt.cnn_kernel_width,
            opt.dropout,
            embeddings,
        )

    # Both RNN decoders log the same fields and take the same constructor
    # arguments; only the message prefix and the class differ.
    input_feed = bool(opt.input_feed)
    if input_feed:
        template = ('InputFeedRNNDecoder: type %s, bidir %d, layers %d, '
                    'hidden size %d, %s global attn (%s), '
                    'coverage attn %d, copy attn %d, dropout %.2f')
    else:
        template = ('StdRNNDecoder: type %s, bidir %d, layers %d, '
                    'hidden size %d, %s global attn (%s), '
                    'coverage attn %d, copy attn %d, dropout %.2f')

    logger.info(template % (
        opt.rnn_type, opt.brnn, opt.dec_layers,
        opt.dec_rnn_size, opt.global_attention,
        opt.global_attention_function, opt.coverage_attn,
        opt.copy_attn, opt.dropout,
    ))

    decoder_cls = InputFeedRNNDecoder if input_feed else StdRNNDecoder
    return decoder_cls(
        opt.rnn_type,
        opt.brnn,
        opt.dec_layers,
        opt.dec_rnn_size,
        opt.global_attention,
        opt.global_attention_function,
        opt.coverage_attn,
        opt.context_gate,
        opt.copy_attn,
        opt.dropout,
        embeddings,
        opt.reuse_copy_attn,
    )