示例#1
0
    def forward_hook(self, embeds, batch_size, seq_length, h):
        """Compute the RL baseline prediction and store it in ``self.baseline_outp``.

        Does nothing outside training mode or when ``self.rl_baseline`` is
        neither ``"value"`` nor ``"shared"``.

        Args:
            embeds: token embeddings, reshaped here to
                ``(batch_size, seq_length, -1)``.
            batch_size: number of sequences.  With ``self.use_sentence_pair``
                the first and second halves are treated as the two sentences
                of each pair.
            seq_length: padded sequence length.
            h: encoder state used by the ``"shared"`` baseline (only ``h[0]``
                is read).
        """
        if not self.training:
            return

        if self.rl_baseline == "value":
            # Break the computational graph: ``embeds.data`` carries no autograd
            # history, so gradients of the value network do not reach embeds.
            x = Variable(embeds.data, volatile=not self.training).view(
                batch_size, seq_length, -1)
            h0 = Variable(to_gpu(torch.zeros(1, batch_size, self.v_rnn_dim)),
                          volatile=not self.training)
            c0 = Variable(to_gpu(torch.zeros(1, batch_size, self.v_rnn_dim)),
                          volatile=not self.training)
            _, (hn, _) = self.v_rnn(x, (h0, c0))
            # hn has the same leading size as h0 (1 layer/direction).  Squeeze
            # only that axis so a batch of size one keeps its batch dimension.
            self.baseline_outp = self._baseline_from_final_state(
                hn.squeeze(0), batch_size)
        elif self.rl_baseline == "shared":
            # NOTE(review): unlike the "value" branch, h is NOT detached here,
            # so the value loss can backpropagate into whatever produced h.
            # The layout of h[0] is not visible here (an earlier note said
            # "model_dim//2, batch_size"), so the generic squeeze() is kept.
            hn = h[0].squeeze()
            # To-do: the sentence-pair path is not currently supported for the
            # shared baseline.
            self.baseline_outp = self._baseline_from_final_state(hn, batch_size)

    def _baseline_from_final_state(self, hn, batch_size):
        """Apply ``self.v_mlp`` to per-example final states ``hn``.

        With ``self.use_sentence_pair``, rows ``[:batch_size // 2]`` and
        ``[batch_size // 2:]`` are concatenated along dim 1 first.
        """
        if self.use_sentence_pair:
            half = batch_size // 2
            hn = torch.cat([hn[:half], hn[half:]], 1)
        return self.v_mlp(hn)
示例#2
0
文件: plain_rnn.py 项目: tsvm/spinn
    def run_rnn(self, x):
        """Run ``self.rnn`` over ``x`` starting from all-zero hidden/cell states.

        ``x`` is unpacked as ``(batch_size, seq_len, model_dim)``.  Returns the
        final hidden state ``hn``, or ``(hn, cn, output)`` when
        ``self.data_type == "mt"``.
        """
        batch_size, _, _ = x.data.size()

        num_layers = 1
        n_directions = 2 if self.is_bidirectional else 1
        state_shape = (num_layers * n_directions, batch_size, self.model_dim)

        def zero_state():
            return Variable(to_gpu(torch.zeros(*state_shape)),
                            volatile=not self.training)

        # self.rnn expects (input, (h_0, c_0)):
        #   input    => batch_size x seq_len x model_dim
        #   h_0, c_0 => (num_layers x num_directions[1,2]) x batch_size x model_dim
        output, (hn, cn) = self.rnn(x, (zero_state(), zero_state()))

        if self.data_type == "mt":
            return hn, cn, output
        return hn
示例#3
0
    def unwrap_tree(self, lefts, rights, writes):
        """Flatten paired tree-index arrays and remap their sentinel values.

        ``lefts``, ``rights`` and ``writes`` are each sliced ``[:, :, 0]`` and
        ``[:, :, 1]``; the two slices are stacked along axis 0 (slice 0 first).
        ``max_len`` is ``lefts.shape[1]``.

        Returns:
            ``(left, right, write, write_mask)`` as long tensors, where
            ``write_mask`` is 1 wherever the shifted write index is >= 0.
        """
        max_len = lefts.shape[1]

        def stack_pair(arr):
            return np.concatenate([arr[:, :, 0], arr[:, :, 1]], axis=0)

        left_idx, right_idx, write_idx = [
            to_gpu(Variable(torch.from_numpy(stack_pair(arr)),
                            volatile=not self.training))
            for arr in (lefts, rights, writes)
        ]

        # Values >= 200 (>= 201 for writes) are shifted so that 200 (resp. 201)
        # maps to max_len.
        # NOTE(review): the meaning of the 200/201 sentinels is not visible
        # here; confirm against the code that builds these arrays.
        left_idx = left_idx - (left_idx.ge(200).int() * (200 - max_len))
        right_idx = right_idx - (right_idx.ge(200).int() * (200 - max_len))
        write_idx = write_idx - (write_idx.ge(201).int() * (201 - max_len))
        write_mask = write_idx.ge(0).long()
        # Non-positive write indices are moved up by 2 * max_len.
        # NOTE(review): le(0) also moves index 0, which write_mask marks as a
        # valid write; confirm whether lt(0) was intended.
        write_idx = write_idx + (write_idx.le(0).int() * (2 * max_len))
        return left_idx.long(), right_idx.long(), write_idx.long(), write_mask
示例#4
0
文件: rl_spinn.py 项目: seyiqi/spinn
 def forward_hook(self, embeds, batch_size, seq_length):
     """Compute the value-network baseline into ``self.baseline_outp``.

     Only runs in training mode with ``self.rl_baseline == "value"``.
     ``embeds`` is reshaped to ``(batch_size, seq_length, -1)``.
     """
     if not (self.rl_baseline == "value" and self.training):
         return
     # Break the computational graph: ``embeds.data`` has no autograd history,
     # so gradients of the value network do not flow back into embeds.
     x = Variable(embeds.data, volatile=not self.training).view(
         batch_size, seq_length, -1)
     h0 = Variable(to_gpu(torch.zeros(1, batch_size, self.v_dim)),
                   volatile=not self.training)
     c0 = Variable(to_gpu(torch.zeros(1, batch_size, self.v_dim)),
                   volatile=not self.training)
     _, (hn, _) = self.v_rnn(x, (h0, c0))
     # hn's leading axis matches h0's (size 1).  Squeeze only that axis so a
     # batch of size one keeps its batch dimension.
     self.baseline_outp = self.v_mlp(hn.squeeze(0))
示例#5
0
文件: mt_model.py 项目: tsvm/spinn
    def build_reward(self, output, target, mask, rl_reward="mean"):
        """Compute one scalar reward per example from decoder scores.

        Args:
            output: decoder log-probabilities indexed
                ``(seq_len + 1, batch, vocab)``; the last time step (<end>)
                is dropped before scoring.
            target: gold token ids of shape ``(batch, seq_len)``.
            mask: unused; kept for interface compatibility.
            rl_reward: ``"xent"`` sums ``-log(1 - p(target))`` over time;
                ``"mean"``/``"sum"`` use the per-example NLL with that
                reduction.

        Returns:
            A float tensor of shape ``(batch,)``.

        Raises:
            NotImplementedError: for an unknown ``rl_reward``.
        """
        if rl_reward == "xent":
            batch_size = target.size(0)
            seq_length = target.size(1)
            _target = target.permute(1, 0).long()  # (seq_len, batch)
            output = output[:-1, :, :]  # drop <end> token
            probs = F.softmax(output, dim=2).data.cpu()
            log_inv_prob = torch.log(1 - probs)

            # Loop over time steps to sum the reward across the full sequence.
            # Element-wise mean is not supported yet.
            rewards = torch.zeros(batch_size)
            for i in range(seq_length):
                rewards += -1 * torch.gather(
                    log_inv_prob[i], 1, _target[i].unsqueeze(1)).squeeze()
            return rewards

        # Resolve the criterion first: previously an unknown mode fell through
        # to an UnboundLocalError on ``criterion``.
        if rl_reward == "mean":
            criterion = nn.NLLLoss(reduction="elementwise_mean")
        elif rl_reward == "sum":
            criterion = nn.NLLLoss(reduction="sum")
        else:
            raise NotImplementedError(
                "Unsupported rl_reward: {}".format(rl_reward))

        output = output.permute(1, 0, 2)  # (batch, seq_len + 1, vocab)
        target = to_gpu(Variable(target))
        batch_size = output.shape[0]

        # NLLLoss is put to an unusual use here: instead of a batch of single
        # tokens, each call scores one full example as a "batch" of its time
        # steps, so "sum"/"mean" reduce over that example's predictions.
        rewards = [
            float(criterion(output[i][:-1, :], target[i].long()))
            for i in range(batch_size)
        ]
        return torch.tensor(rewards)
示例#6
0
    def run_spinn(self,
                  example,
                  embeds,
                  use_internal_parser,
                  validate_transitions=True):
        """Run the SPINN encoder and pad its attended states into one tensor.

        Returns ``(h, h_list, transition_acc, transition_loss, attended,
        memory_lengths)``:

        * ``attended`` holds each example's states right-padded with
          ``1 x model_dim`` zero rows to the longest example, concatenated
          per example, and stacked along dim 1.
        * ``memory_lengths`` is always ``None``.
        * ``h`` is ``torch.cat(h_list).unsqueeze(0)`` for ``data_type == "mt"``,
          otherwise ``self.wrap(h_list)``.

        ``embeds`` is unused.
        """
        self.spinn.reset_state()
        h_list, transition_acc, transition_loss, attended = self.spinn(
            example,
            use_internal_parser=use_internal_parser,
            validate_transitions=validate_transitions)

        # Not using during attention debugging; lengths left disabled:
        # to_gpu(Variable(torch.Tensor([len(x) for x in attended])))
        longest = max(len(states) for states in attended)
        memory_lengths = None

        zero_row = to_gpu(Variable(torch.zeros(1, self.model_dim)))
        per_example = [
            torch.cat(states + [zero_row] * (longest - len(states))).unsqueeze(1)
            for states in attended
        ]
        attended = torch.cat(per_example, 1)

        if self.data_type != "mt":
            h = self.wrap(h_list)
        else:
            h = torch.cat(h_list).unsqueeze(0)

        return h, h_list, transition_acc, transition_loss, attended, memory_lengths
示例#7
0
    def unwrap_sentence_pair(self, sentences, transitions):
        """Stack both halves of a sentence-pair batch into one token tensor.

        Rows from ``sentences[:, :, 0]`` come first, then rows from
        ``sentences[:, :, 1]``.  ``transitions`` is accepted but unused.
        """
        halves = [sentences[:, :, k] for k in (0, 1)]
        tokens = np.concatenate(halves, axis=0)
        return to_gpu(Variable(torch.from_numpy(tokens),
                               volatile=not self.training))
示例#8
0
    def predict_actions(self, transition_output):
        """Turn transition scores into log-probabilities and chosen actions.

        Column 0 of the (temperature-scaled) softmax is treated as the SHIFT
        probability.  In training mode an action is sampled: 1 when a uniform
        draw exceeds that probability, else 0.  In eval mode the greedy
        ``round(1 - p_shift)`` is used.

        Args:
            transition_output: unnormalised scores; softmax is taken over
                dim 1 (presumably ``(batch, 2)``; the Catalan prior has two
                columns).

        Returns:
            ``(transition_logdist, transition_preds)``: log-probabilities and
            an int32 numpy array of actions.
        """
        # TINY guards against a zero temperature.
        transition_output_t = transition_output / max(self.temperature, TINY)
        transition_dist = F.softmax(transition_output_t, dim=1)

        if self.catalan:
            # Use the catalan distribution as a prior.
            p_shift_catalan = [
                self.shift_probabilities.prob(n_red, n_step, n_tok)
                for n_red, n_step, n_tok in zip(self.n_reduces, self.n_steps,
                                                self.n_tokens)
            ]
            p_shift_catalan = torch.FloatTensor(p_shift_catalan).view(-1, 1)
            p_catalan = torch.cat([p_shift_catalan, 1. - p_shift_catalan], 1)
            p_catalan = to_gpu(Variable(p_catalan))

            _p_new = transition_dist * p_catalan
            # Normalise each row.  keepdim=True keeps the row sums as
            # (batch, 1) so they broadcast across the action columns; a
            # (batch,) sum would broadcast along the wrong axis.
            transition_dist = _p_new / (_p_new.sum(1, keepdim=True) + TINY)

        if self.catalan and self.catalan_backprop:
            transition_logdist = torch.log(transition_dist + TINY)
        else:
            transition_logdist = F.log_softmax(transition_output_t, dim=1)
        shift_probs = transition_dist.data[:, 0]

        if self.training:
            # Sample: action 1 when the uniform draw exceeds p_shift.
            np_shift_probs = shift_probs.cpu().numpy()
            transition_preds = (np.random.rand(*np_shift_probs.shape) >
                                np_shift_probs).astype('int32')
        else:
            # Greedy prediction
            transition_preds = torch.round(
                1 - shift_probs).cpu().numpy().astype('int32')
        return transition_logdist, transition_preds
示例#9
0
    def forward(self, example, use_internal_parser=False, validate_transitions=True):
        """Initialise buffers and stacks from ``example`` and run the parser.

        Args:
            example: must provide ``tokens``, ``bufs`` and ``transitions``.
            use_internal_parser: forwarded to ``self.run``.
            validate_transitions: forwarded to ``self.run``.

        Returns:
            Whatever ``self.run`` returns.

        Raises:
            ValueError: if ``example`` has no ``transitions`` (checked before
                any state is modified).
        """
        if not hasattr(example, 'transitions'):
            # TODO: Support no transitions. In the meantime, must at least pass dummy transitions.
            raise ValueError('Transitions must be included.')

        # Count of non-zero token ids per example (zeros presumably padding).
        self.buffers_n = (example.tokens.data != 0).long().sum(1).view(-1).tolist()

        if self.debug:
            seq_length = example.tokens.size(1)
            assert all(buf_n <= (seq_length + 1) // 2 for buf_n in self.buffers_n), \
                "All sentences (including cropped) must be the appropriate length."

        self.bufs = example.bufs

        # Notes on adding zeros to bufs/stacks.
        # - After the buffer is consumed, we need one zero on the buffer
        #   used as input to the tracker.
        # - For the first two steps, the stack would be empty, but we add
        #   zeros so that the tracker still gets input.
        zeros = to_gpu(Variable(torch.from_numpy(
            np.zeros(self.bufs[0][0].size(), dtype=np.float32)),
            volatile=self.bufs[0][0].volatile))

        # Trim unused tokens, keeping the last b_n entries of each buffer.
        # Slice from len(b) - b_n instead of -b_n: b[-0:] would keep the whole
        # buffer for an example with no non-zero tokens.
        self.bufs = [[zeros] + b[max(len(b) - b_n, 0):]
                     for b, b_n in zip(self.bufs, self.buffers_n)]

        self.stacks = [[zeros, zeros] for _ in self.bufs]

        if hasattr(self, 'tracker'):
            self.tracker.reset_state()
        self.forward_hook()
        return self.run(example.transitions,
                        run_internal_parser=True,
                        use_internal_parser=use_internal_parser,
                        validate_transitions=validate_transitions)
示例#10
0
文件: r_spinn.py 项目: anhad13/spinn
    def forward(self, top_buf, top_stack_1, top_stack_2):
        """Compute the tracker output from the tops of the buffer and stack.

        With ``self.tracking_ln`` the three inputs are first passed through
        their layer-norm modules.  With ``self.lateral_tracking`` the summed
        projections (plus ``self.lateral(self.h)`` when a previous state
        exists) drive ``lstm`` and ``(h, c)`` is returned.  Otherwise the raw
        inputs are concatenated along dim 1 and returned with ``None``.
        """
        if self.tracking_ln:
            top_buf, top_stack_1, top_stack_2 = (
                self.buf_ln(top_buf),
                self.stack1_ln(top_stack_1),
                self.stack2_ln(top_stack_2),
            )

        if not self.lateral_tracking:
            return torch.cat([top_buf, top_stack_1, top_stack_2], 1), None

        tracker_inp = self.buf(top_buf)
        tracker_inp += self.stack1(top_stack_1)
        tracker_inp += self.stack2(top_stack_2)
        if self.h is not None:
            tracker_inp += self.lateral(self.h)

        if self.c is None:
            # Lazily start the cell state as zeros of shape (rows, state_size).
            n_rows = tracker_inp.size(0)
            zero_cell = np.zeros((n_rows, self.state_size), dtype=np.float32)
            self.c = to_gpu(Variable(torch.from_numpy(zero_cell),
                                     volatile=tracker_inp.volatile))

        # Run tracking lstm.
        self.c, self.h = lstm(self.c, tracker_inp)
        return self.h, self.c
示例#11
0
    def forward(self,
                sentences,
                _,
                __=None,
                example_lengths=None,
                store_parse_masks=False,
                pyramid_temperature_multiplier=1.0,
                **kwargs):
        """Encode ``sentences`` with the binary-tree LSTM and classify them.

        The second and third positional arguments and ``**kwargs`` are
        ignored.  In training mode the tree LSTM's temperature is stored on
        ``self.temperature_to_display``.  With ``store_parse_masks`` its masks
        are kept as numpy arrays on ``self.mask_memory``.

        Returns:
            ``self.mlp`` applied to the features of the wrapped encodings.
        """
        # Useful when investigating dynamic batching:
        # self.seq_lengths = sentences.shape[1] - (sentences == 0).sum(1)

        tokens, example_lengths = self.unwrap(sentences, example_lengths)
        embedded = self.run_embed(tokens)

        # (batch, seq_len, model_dim); the unpacking requires a rank-3 tensor.
        _, _, _ = embedded.data.size()
        lengths_var = to_gpu(Variable(torch.from_numpy(example_lengths))).long()

        hidden, _, masks, temperature = self.binary_tree_lstm(
            embedded,
            lengths_var,
            temperature_multiplier=pyramid_temperature_multiplier)

        if self.training:
            self.temperature_to_display = temperature
        if store_parse_masks:
            self.mask_memory = [m.data.cpu().numpy() for m in masks]

        features = self.build_features(self.wrap(hidden))
        return self.mlp(features)
示例#12
0
    def forward(self,
                sentences,
                _,
                __,
                dist=None,
                pyramid_temperature_multiplier=1.0,
                example_lengths=None,
                store_parse_masks=False,
                **kwargs):
        """Encode a pair batch with the binary-tree LSTM and classify it.

        The second and third positional arguments and ``**kwargs`` are
        ignored.  ``dist``, when given, is unwrapped like ``sentences`` and
        forwarded to the tree LSTM.

        In training mode ``self.temperature_to_display`` is set, and so are
        ``self.sbs_loss`` and ``self.sbs_acc`` (the tree LSTM's totals
        divided by its ``n_total``).  With ``store_parse_masks`` the masks are
        kept as numpy arrays on ``self.mask_memory``.
        """
        # Before unwrap, sentences and dist are <batch x maxlen x 2>
        # (2 = |{prm, hyp}|); after, they are <numSent x maxlen> with
        # numSent = batch x 2.

        # Useful when investigating dynamic batching:
        # self.seq_lengths = sentences.shape[1] - (sentences == 0).sum(1)

        orig_example_lengths = example_lengths
        tokens, example_lengths = self.unwrap(sentences, orig_example_lengths)
        if dist is not None:
            dist, _ = self.unwrap(dist, orig_example_lengths)  # moved to gpu

        embedded = self.run_embed(tokens)  # <numSent, maxlen, dim>
        _, _, _ = embedded.data.size()  # the unpacking requires rank 3
        lengths_var = to_gpu(
            Variable(torch.from_numpy(example_lengths))).long()  # <numSent>

        hidden, _, masks, temperature = self.binary_tree_lstm(
            embedded,
            lengths_var,
            dist=dist,
            temperature_multiplier=pyramid_temperature_multiplier)

        if self.training:
            self.temperature_to_display = temperature
            tree = self.binary_tree_lstm
            n_total = tree.n_total.float()
            # TODO: sbs_acc may not be divided by num at this moment.
            self.sbs_loss = tree.sbs_loss / n_total
            self.sbs_acc = tree.sbs_acc / n_total

        if store_parse_masks:
            self.mask_memory = [m.data.cpu().numpy() for m in masks]

        features = self.build_features(self.wrap(hidden))
        return self.mlp(features)
示例#13
0
文件: rl_spinn.py 项目: seyiqi/spinn
    def build_baseline(self,
                       rewards,
                       sentences,
                       transitions,
                       y_batch=None,
                       embeds=None):
        """Return the REINFORCE baseline selected by ``self.rl_baseline``.

        * ``"ema"``: the running average *before* this batch is folded in.
          ``self.baseline[0]`` is then updated with rate ``self.rl_mu``.
        * ``"pass"``: ``0.``.
        * ``"greedy"``: ``self.build_reward`` applied to the softmax of
          ``self.run_greedy(sentences, transitions)``.
        * ``"value"``: ``self.baseline_outp`` (through a sigmoid for
          ``"standard"`` rewards, raw for ``"xent"``).  Also sets
          ``self.value_loss``.

        ``embeds`` is unused.

        Raises:
            NotImplementedError: unknown baseline, or an unsupported reward
                type for the value baseline.
        """
        if self.rl_baseline == "ema":
            mu = self.rl_mu
            baseline = self.baseline[0]
            self.baseline[0] = self.baseline[0] * (1 - mu) + rewards.mean() * mu
        elif self.rl_baseline == "pass":
            baseline = 0.
        elif self.rl_baseline == "greedy":
            # Pass inputs to Greedy Max
            output = self.run_greedy(sentences, transitions)

            # Estimate Reward.  dim=1 is the class axis, which is also what
            # the deprecated implicit-dim softmax chose for 2-D input.
            probs = F.softmax(output, dim=1).data.cpu()
            target = torch.from_numpy(y_batch).long()
            baseline = self.build_reward(probs,
                                         target,
                                         rl_reward=self.rl_reward)
        elif self.rl_baseline == "value":
            output = self.baseline_outp

            if self.rl_reward == "standard":
                baseline = torch.sigmoid(output)
                self.value_loss = nn.BCELoss()(
                    baseline,
                    to_gpu(Variable(rewards, volatile=not self.training)))
            elif self.rl_reward == "xent":
                baseline = output
                self.value_loss = nn.MSELoss()(
                    baseline,
                    to_gpu(Variable(rewards, volatile=not self.training)))
            else:
                raise NotImplementedError

            baseline = baseline.data.cpu()
        else:
            raise NotImplementedError

        return baseline
示例#14
0
文件: r_spinn.py 项目: anhad13/spinn
    def mc_reinforce(self, rewards, baseline):
        """Compute the REINFORCE policy loss for the parser's sampled transitions.

        Sampled actions (``t_preds``), the transition mask (``t_mask``) and
        log-probabilities (``t_logprobs``) are read from
        ``self.spinn.memories``.  ``t_mask`` is used as a numpy index,
        presumably a boolean mask over all transitions; confirm against the
        producer.  Each kept transition is weighted by the advantage
        ``-(rewards - baseline)`` of its example.

        Args:
            rewards: per-example rewards.
            baseline: per-example baseline, same shape as ``rewards``.

        Returns:
            ``mean(advantage * log p(action))`` over kept transitions,
            multiplied by a fixed scale factor.
        """
        memories = self.spinn.memories
        t_preds = np.concatenate(
            [m['t_preds'] for m in memories if 't_preds' in m])
        t_mask = np.concatenate(
            [m['t_mask'] for m in memories if 't_mask' in m])
        t_logprobs = torch.cat(
            [m['t_logprobs'] for m in memories if 't_logprobs' in m], 0)

        if self.use_sentence_pair:
            # Handles the case of SNLI where each reward is used for two
            # sentences.
            rewards = torch.cat([rewards, rewards], 0)
            baseline = torch.cat([baseline, baseline], 0)

        advantage = -1 * (rewards - baseline)
        batch_size = advantage.size(0)
        # Integer division: numpy's repeat() needs an int repeat count.
        seq_length = t_preds.shape[0] // batch_size
        # Example index of every transition (laid out as seq_length blocks of
        # 0..batch_size-1), restricted to the kept transitions.
        a_index = np.arange(batch_size).reshape(1, -1).repeat(
            seq_length, axis=0).flatten()
        a_index = torch.from_numpy(a_index[t_mask]).long()

        t_index = to_gpu(
            Variable(torch.from_numpy(np.arange(
                t_mask.shape[0])[t_mask])).long())
        t_logprobs = torch.index_select(t_logprobs, 0, t_index)
        actions = to_gpu(
            Variable(torch.from_numpy(t_preds[t_mask]).long().view(-1, 1),
                     volatile=not self.training))
        log_p_action = torch.gather(t_logprobs, 1, actions).view(-1)

        advantage = torch.index_select(advantage, 0, a_index)
        # Keep both factors floating point: casting to long truncated the
        # log-probabilities and cut the gradient path to the policy.
        advantage = to_gpu(Variable(advantage.float().view(-1)))
        policy_loss = torch.sum(advantage * log_p_action) / log_p_action.size(0)
        # NOTE(review): unexplained fixed scale kept from the original; it
        # may need retuning now that the loss is no longer truncated.
        return policy_loss * 0.000121392198451
示例#15
0
文件: cbow.py 项目: xiaonanzzz/spinn
    def build_example(self, sentences, transitions):
        """Return one token tensor holding both sentences of each pair.

        Rows from ``sentences[:, :, 0]`` precede rows from
        ``sentences[:, :, 1]``.  ``transitions`` is ignored.
        """
        premises, hypotheses = sentences[:, :, 0], sentences[:, :, 1]
        stacked = np.concatenate((premises, hypotheses), axis=0)
        as_tensor = torch.from_numpy(stacked)
        return to_gpu(Variable(as_tensor, volatile=not self.training))
示例#16
0
    def unwrap_sentence_pair(self, sentences, lengths=None):
        """Flatten a pair batch into a single batch of sentences.

        Returns ``(tokens, lengths)``.  ``tokens`` stacks
        ``sentences[:, :, 0]`` above ``sentences[:, :, 1]``.  ``lengths``,
        when given, is stacked the same way from its two columns; otherwise
        ``None`` is returned for it.
        """
        tokens = np.concatenate([sentences[:, :, k] for k in range(2)], axis=0)
        if lengths is not None:
            lengths = np.concatenate([lengths[:, k] for k in range(2)], axis=0)

        token_var = to_gpu(Variable(torch.from_numpy(tokens),
                                    volatile=not self.training))
        return token_var, lengths
示例#17
0
    def unwrap_sentence(self, sentences, transitions):
        """Wrap a single-sentence batch in an ``Example``.

        ``tokens`` holds ``sentences`` as a tensor.  ``transitions`` is stored
        unchanged.
        """
        example = Example()
        example.tokens = to_gpu(
            Variable(torch.from_numpy(sentences), volatile=not self.training))
        example.transitions = transitions
        return example
示例#18
0
    def run_rnn(self, x):
        """Run ``self.rnn`` from zero states and return its final hidden states.

        ``x`` is unpacked as ``(batch_size, seq_len, model_dim)``.  Each
        direction's state size is ``self.model_dim // num_directions``.

        Returns:
            ``hn`` transposed to batch-major and flattened to
            ``(batch_size, -1)``, so the per-direction states are
            concatenated.
        """
        batch_size, _, _ = x.data.size()

        num_layers = 1
        bi = 2 if self.bidirectional else 1
        # Integer division: "/" yields a float on Python 3, and torch.zeros
        # rejects float sizes.
        hidden_dim = self.model_dim // bi
        h0 = Variable(to_gpu(
            torch.zeros(num_layers * bi, batch_size, hidden_dim)),
                      volatile=not self.training)
        c0 = Variable(to_gpu(
            torch.zeros(num_layers * bi, batch_size, hidden_dim)),
                      volatile=not self.training)

        # Expects (input, (h_0, c_0)):
        #   input    => batch_size x seq_len x model_dim
        #   h_0, c_0 => (num_layers x num_directions[1,2]) x batch_size x hidden_dim
        _, (hn, _) = self.rnn(x, (h0, c0))

        return hn.transpose(0, 1).contiguous().view(batch_size, -1)
示例#19
0
文件: loss.py 项目: TaoMiner/eesc
def auxiliary_loss(model):
    """Sum the optional RL losses attached to ``model``.

    ``model.policy_loss`` and ``model.value_loss`` are added, in that order,
    to a one-element zero tensor.  They are only added when ``model`` also
    has a ``spinn`` attribute.  Returns the zero tensor unchanged when
    neither applies.
    """
    total_loss = to_gpu(Variable(torch.Tensor([0.0])))
    if not hasattr(model, 'spinn'):
        return total_loss
    for loss_name in ('policy_loss', 'value_loss'):
        if hasattr(model, loss_name):
            total_loss += getattr(model, loss_name)
    return total_loss
示例#20
0
    def build_example(self, sentences, transitions):
        """Package a single-sentence batch as an ``Example``.

        ``tokens`` is ``sentences`` converted to a tensor.  ``transitions`` is
        stored as given.
        """
        tokens = torch.from_numpy(sentences)
        result = Example()
        result.tokens = to_gpu(Variable(tokens, volatile=not self.training))
        result.transitions = transitions
        return result
示例#21
0
    def unwrap_sentence_pair(self, sentences, transitions):
        """Build an ``Example`` with both halves of each pair as separate rows.

        For both ``sentences`` and ``transitions`` the ``[:, :, 0]`` slice is
        stacked above the ``[:, :, 1]`` slice.  ``tokens`` becomes a tensor;
        ``transitions`` stays a numpy array.
        """
        def split_and_stack(arr):
            return np.concatenate([arr[:, :, 0], arr[:, :, 1]], axis=0)

        tokens = split_and_stack(sentences)
        stacked_transitions = split_and_stack(transitions)

        example = Example()
        example.tokens = to_gpu(Variable(torch.from_numpy(tokens), volatile=not self.training))
        example.transitions = stacked_transitions
        return example
示例#22
0
    def reset_decoder(self, example):
        """Run decoder on input to initialize rnn states.

        Each entry of ``example.bufs`` is concatenated along dim 0 and the
        results are stacked into one batch.  ``self.dec_h`` and
        ``self.dec_c`` are reset to per-example zero chunks of shape
        ``(1, 1, decoder_dim)`` before ``self.run_decoder_rnn`` runs.
        """
        n_examples = len(example.bufs)

        # TODO: Would prefer to run decoder forwards or backwards?
        rows = [torch.cat(buf, 0).unsqueeze(0) for buf in example.bufs]
        batch = torch.cat(rows, 0)

        zero_state = to_gpu(
            Variable(torch.zeros(1, n_examples, self.decoder_dim),
                     volatile=not self.training))
        # NOTE(review): dec_h and dec_c are chunked from the same tensor, so
        # their entries share storage; fine unless they are updated in place.
        self.dec_h = list(torch.chunk(zero_state, n_examples, 1))
        self.dec_c = list(torch.chunk(zero_state, n_examples, 1))

        # TODO: Right now the decoder runs over the entire sentence, which is a bit like cheating!
        self.run_decoder_rnn(range(n_examples), batch)
示例#23
0
    def build_example(self, sentences, transitions):
        """Build an ``Example`` from a sentence-pair batch.

        ``sentences`` and ``transitions`` are indexed ``(batch, features, 2)``.
        For each, the ``[:, :, 0]`` half is stacked above the ``[:, :, 1]``
        half along axis 0.
        """
        tokens, pair_transitions = (
            np.concatenate([arr[:, :, 0], arr[:, :, 1]], axis=0)
            for arr in (sentences, transitions))

        example = Example()
        example.tokens = to_gpu(Variable(torch.from_numpy(tokens),
                                         volatile=not self.training))
        example.transitions = pair_transitions
        return example
示例#24
0
    def t_reduce(self, buf, stack, tracking, lefts, rights, trackings):
        """REDUCE: Should compose top two items of the stack into new item.

        The top of ``stack`` is appended to ``rights`` and the next item to
        ``lefts``.  ``tracking`` is appended to ``trackings``.  When the stack
        runs out, a zero tensor shaped like ``buf[0]`` is used instead; in
        debug mode IndexError is raised.
        """
        # The right-most input will be popped first.
        for destination in (rights, lefts):
            if stack:
                destination.append(stack.pop())
                continue
            if self.debug:
                raise IndexError
            # Fewer than 2 items were on the stack (only happens on cropped
            # data): any available item already went to the right input, so
            # pad the missing input(s) with zeros.
            filler = to_gpu(Variable(
                torch.from_numpy(np.zeros(buf[0].size(), dtype=np.float32)),
                volatile=buf[0].volatile))
            destination.append(filler)

        trackings.append(tracking)
示例#25
0
    def build_baseline(self,
                       output,
                       rewards,
                       sentences,
                       transitions,
                       y_batch=None):
        """Return the REINFORCE baseline chosen by ``self.rl_baseline``.

        * ``"ema"``: update ``self.baseline[0]`` with rate ``self.rl_mu`` and
          return the updated value.
        * ``"policy"``: the prediction of ``self.policy``.  Its MSE against
          ``rewards`` is stored in ``self.policy_loss``.
        * ``"greedy"``: ``self.build_reward`` applied to the softmax of the
          greedy model's output from ``self.run_greedy``.

        Args:
            output: the sampled model's output.  Not read any more (the greedy
                branch previously scored it instead of the greedy output); kept
                for interface compatibility.
            rewards: per-example rewards.
            sentences, transitions: inputs for the policy / greedy models.
            y_batch: numpy array of gold labels (greedy baseline only).

        Raises:
            NotImplementedError: for an unknown ``self.rl_baseline``.
        """
        if self.rl_baseline == "ema":
            mu = self.rl_mu
            self.baseline[0] = self.baseline[0] * (1 -
                                                   mu) + rewards.mean() * mu
            baseline = self.baseline[0]
        elif self.rl_baseline == "policy":
            # Pass inputs to Policy Net; its output estimates the reward.
            policy_prob = self.policy(sentences, transitions)

            # Save MSE Loss using Reward as target
            self.policy_loss = nn.MSELoss()(
                policy_prob,
                to_gpu(Variable(rewards, volatile=not self.training)))

            baseline = policy_prob.data.cpu()
        elif self.rl_baseline == "greedy":
            # Pass inputs to Greedy Max and score *its* prediction.
            greedy_outp = self.run_greedy(sentences, transitions)

            # Estimate Reward (softmax over the class axis).
            probs = F.softmax(greedy_outp, dim=1).data.cpu()
            target = torch.from_numpy(y_batch).long()
            baseline = self.build_reward(probs, target)
        else:
            raise NotImplementedError

        return baseline
示例#26
0
    def reinforce(self, rewards):
        """Compute the REINFORCE loss from the parser's recorded transitions.

        ``self.spinn.get_statistics()`` supplies the chosen actions
        (``t_preds``), per-transition scores (``t_logits``, read here as
        log-probabilities) and ``t_mask``.  ``t_mask`` selects, via
        ``index_select``, the reward row used for each transition.

        Returns:
            ``-rl_weight * sum(log p(action) * reward) / num_transitions``.

        Raises:
            NotImplementedError: when ``self.spinn.use_skips`` is set.
        """
        t_preds, t_logits, t_given, t_mask = self.spinn.get_statistics()

        # TODO: Many of these ops are on the cpu. Might be worth shifting to GPU.
        if self.use_sentence_pair:
            # Handles the case of SNLI where each reward is used for two sentences.
            rewards = torch.cat([rewards, rewards], 0)

        # Expand rewards to one entry per transition.
        if self.spinn.use_skips:
            raise NotImplementedError
        rewards = rewards.index_select(0, torch.from_numpy(t_mask).long())

        # Score of the chosen action per transition.  Two-integer indexing
        # yields 0-d tensors on PyTorch >= 0.4, which torch.cat rejects; stack
        # and flatten so the result is 1-D on old and new versions alike.
        log_p_action = torch.stack(
            [t_logits[i, p] for i, p in enumerate(t_preds)], 0).view(-1)

        rl_loss = -1. * torch.sum(log_p_action * to_gpu(
            Variable(rewards, volatile=log_p_action.volatile)))
        rl_loss /= log_p_action.size(0)
        rl_loss *= self.rl_weight

        return rl_loss
示例#27
0
    def forward(self, top_buf, top_stack_1, top_stack_2):
        """Update the tracker from the ``.h`` states at the top of buffer/stack.

        The three projected inputs are summed.  With ``self.lateral_tracking``
        the sum (plus ``self.lateral(self.h)`` once a previous state exists)
        drives ``lstm`` and ``(h, c)`` is returned.  Otherwise
        ``(self.transform(sum), None)`` is returned.
        """
        tracker_inp = self.buf(top_buf.h)
        tracker_inp += self.stack1(top_stack_1.h)
        tracker_inp += self.stack2(top_stack_2.h)

        if not self.lateral_tracking:
            return self.transform(tracker_inp), None

        n_rows = tracker_inp.size(0)
        if self.h is not None:
            tracker_inp += self.lateral(self.h)
        if self.c is None:
            # Cell state starts as zeros of shape (rows, state_size).
            zero_cell = np.zeros((n_rows, self.state_size), dtype=np.float32)
            self.c = to_gpu(Variable(torch.from_numpy(zero_cell),
                                     volatile=tracker_inp.volatile))

        # Run tracking lstm.
        self.c, self.h = lstm(self.c, tracker_inp)
        return self.h, self.c
示例#28
0
def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
               logger):
    """Run training until ``FLAGS.training_steps`` or early stopping.

    Each step does the following:

    * draws a batch from ``training_data_iter``;
    * adjusts the RL settings (confidence-penalty temperature, wake/sleep
      weight);
    * runs the model and backpropagates the cross-entropy plus the optional
      transition and auxiliary losses;
    * clips gradients (all parameters except ``embed.embed.weight``) and
      steps the optimiser.

    Statistics, parse samples, evaluation and checkpoints run at their
    respective ``FLAGS.*_interval_steps``.  Any of them makes the step's
    ``log_entry`` get logged.  Training stops early once
    ``trainer.step - trainer.best_dev_step`` exceeds
    ``FLAGS.early_stopping_steps_to_wait``.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.SpinnEntry()
    for _ in range(trainer.step, FLAGS.training_steps):
        if (trainer.step -
                trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        model.train()
        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()

        batch = get_batch(next(training_data_iter))
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        # Token count, inverting num_transitions = 2 * num_tokens - 1.
        total_tokens = sum([(nt + 1) / 2
                            for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        # Periodic schedule in [0, 1]: (cos(2*pi*step/interval) + 1) / 2.
        temperature = math.sin(
            math.pi / 2 +
            trainer.step / float(FLAGS.rl_confidence_interval) * 2 * math.pi)
        temperature = (temperature + 1) / 2

        # Confidence Penalty for Transition Predictions.
        if FLAGS.rl_confidence_penalty:
            epsilon = FLAGS.rl_epsilon * \
                math.exp(-trainer.step / float(FLAGS.rl_epsilon_decay))
            temp = 1 + \
                (temperature - .5) * FLAGS.rl_confidence_penalty * epsilon
            model.spinn.temperature = max(1e-3, temp)

        # Soft Wake/Sleep based on temperature.
        if FLAGS.rl_wake_sleep:
            model.rl_weight = temperature * FLAGS.rl_weight

        # Run model.
        output = model(X_batch,
                       transitions_batch,
                       y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()

        # get the index of the max log-probability
        pred = output.data.max(1, keepdim=False)[1].cpu()

        # float() first: on newer PyTorch .sum() returns an integer tensor,
        # and dividing it must not truncate.
        class_acc = float(pred.eq(target).sum()) / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.CrossEntropyLoss()(output,
                                          to_gpu(
                                              Variable(target,
                                                       volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = model.transition_loss if hasattr(
            model, 'transition_loss') else None

        # Accumulate Total Loss Variable
        total_loss = 0.0
        total_loss += xent_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        aux_loss = auxiliary_loss(model)
        total_loss += aux_loss

        # Backward pass.
        total_loss.backward()

        # Hard Gradient Clipping (the embedding matrix is excluded).
        nn.utils.clip_grad_norm([
            param for name, param in model.named_parameters()
            if name not in ["embed.embed.weight"]
        ], FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        train_rl_accumulate(model, A, batch)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps,
                              total=FLAGS.statistics_interval_steps)
            progress_bar.finish()

            # view(-1)[0] works for both 1-element and 0-d losses; float()
            # yields a plain number on old and new PyTorch alike.
            A.add('xent_cost', float(xent_loss.data.view(-1)[0]))
            stats(model, trainer, A, log_entry)
            should_log = True

        if trainer.step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            should_log = True
            # Record transitions in training (sampled) mode ...
            model.train()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example(
            )

            # ... and in eval (greedy) mode.
            model.eval()
            model(X_batch,
                  transitions_batch,
                  y_batch,
                  use_internal_parser=FLAGS.use_internal_parser,
                  validate_transitions=FLAGS.validate_transitions)
            ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example(
            )

            if model.use_sentence_pair and len(transitions_batch.shape) == 3:
                transitions_batch = np.concatenate(
                    [transitions_batch[:, :, 0], transitions_batch[:, :, 1]],
                    axis=0)

            # Sample up to num_samples distinct examples of this batch, sorted
            # for readable logs; min() keeps every index inside the batch.
            # This could be done prior to running the batch for a tiny speed
            # boost.
            n_examples = len(transitions_batch)
            t_idxs = sorted(random.sample(range(n_examples),
                                          min(FLAGS.num_samples, n_examples)))
            for t_idx in t_idxs:
                log = log_entry.rl_sampling.add()
                gold = transitions_batch[t_idx]
                pred_tr = tr_transitions_per_example[t_idx]
                pred_ev = ev_transitions_per_example[t_idx]
                strength_tr = sparks([1] + tr_strength[t_idx].tolist(),
                                     dec_str)
                strength_ev = sparks([1] + ev_strength[t_idx].tolist(),
                                     dec_str)
                # Compare gold transitions with the eval-mode parse (`pred`
                # above holds class predictions, not transitions).
                _, crossing = evalb.crossing(gold, pred_ev)
                log.t_idx = t_idx
                log.crossing = crossing
                log.gold_lb = "".join(map(str, gold))
                log.pred_tr = "".join(map(str, pred_tr))
                log.pred_ev = "".join(map(str, pred_ev))
                log.strg_tr = strength_tr[1:]
                log.strg_ev = strength_ev[1:]

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            for index, eval_set in enumerate(eval_iterators):
                acc, _ = evaluate(FLAGS,
                                  model,
                                  eval_set,
                                  log_entry,
                                  logger,
                                  trainer,
                                  eval_index=index)
                # The first eval set drives dev-accuracy tracking.
                if index == 0:
                    trainer.new_dev_accuracy(acc)

            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) +
                          1,
                          total=FLAGS.statistics_interval_steps)
示例#29
0
 def unwrap_sentence(self, sentences, lengths=None):
     """Convert ``sentences`` to a tensor; ``lengths`` is passed through unchanged."""
     token_tensor = torch.from_numpy(sentences)
     wrapped = to_gpu(Variable(token_tensor, volatile=not self.training))
     return wrapped, lengths
示例#30
0
def _log_transition_samples(FLAGS, model, X_batch, transitions_batch, y_batch, logger):
    """Log gold vs. predicted transition sequences for a few examples of a batch.

    Runs the model on the batch twice, once in train mode and once in eval
    mode. For up to ``FLAGS.num_samples`` randomly chosen examples it logs:
    the gold transitions, both predicted sequences, the strength strings built
    by ``sparks``, and the crossing count from ``evalb.crossing``.

    Leaves ``model`` in eval mode.
    """
    model.train()
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions)
    tr_transitions_per_example, tr_strength = model.spinn.get_transitions_per_example()

    model.eval()
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions)
    ev_transitions_per_example, ev_strength = model.spinn.get_transitions_per_example()

    if model.use_sentence_pair and len(transitions_batch.shape) == 3:
        # Stack the two sentences' transitions along axis 0. Presumably this
        # lines rows up with the per-example predictions indexed below
        # (TODO confirm).
        transitions_batch = np.concatenate([
            transitions_batch[:, :, 0], transitions_batch[:, :, 1]], axis=0)

    # Pick up to num_samples distinct examples at random from the batch and
    # show them in index order. This could be done prior to running the batch
    # for a tiny speed boost.
    num_examples = min(len(transitions_batch),
                       len(tr_transitions_per_example),
                       len(ev_transitions_per_example))
    t_idxs = sorted(random.sample(range(num_examples),
                                  min(FLAGS.num_samples, num_examples)))

    transition_str = "Samples:"
    for t_idx in t_idxs:
        gold = transitions_batch[t_idx]
        pred_tr = tr_transitions_per_example[t_idx]
        pred_ev = ev_transitions_per_example[t_idx]
        strength_tr = sparks([1] + tr_strength[t_idx].tolist())
        strength_ev = sparks([1] + ev_strength[t_idx].tolist())
        # Crossing count between the gold and the eval-mode predicted transitions.
        _, crossing = evalb.crossing(gold, pred_ev)
        transition_str += "\n{}. crossing={}".format(t_idx, crossing)
        transition_str += "\n     g{}".format("".join(map(str, gold)))
        transition_str += "\n      {}".format(strength_tr[1:].encode('utf-8'))
        transition_str += "\n    pt{}".format("".join(map(str, pred_tr)))
        transition_str += "\n      {}".format(strength_ev[1:].encode('utf-8'))
        transition_str += "\n    pe{}".format("".join(map(str, pred_ev)))
    logger.Log(transition_str)


def train_loop(FLAGS, data_manager, model, optimizer, trainer, training_data_iter, eval_iterators, logger, step, best_dev_error):
    """Run the supervised training loop from ``step`` up to ``FLAGS.training_steps``.

    Each iteration does the following:
      - Draws a batch from ``training_data_iter`` and runs the model.
      - Backpropagates the class (NLL) loss plus the optional L2, transition
        and auxiliary losses.
      - Clamps gradients element-wise, optionally decays the learning rate,
        and takes an optimizer step.

    Statistics logging, transition samples, evaluation on ``eval_iterators``
    and checkpointing run at the intervals configured in ``FLAGS``.

    Args:
        FLAGS: configuration object holding the flags read below.
        data_manager: passed through to ``train_accumulate`` and ``evaluate``.
        model: called as ``model(X, transitions, y, use_internal_parser=...,
            validate_transitions=...)``.
        optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
        trainer: object whose ``save(path, step, best_dev_error)`` writes
            checkpoints.
        training_data_iter: iterator of raw batches consumed by ``get_batch``.
        eval_iterators: evaluation sets. Index 0 drives best-dev checkpointing.
        logger: object with a ``Log(str)`` method.
        step: step to resume from.
        best_dev_error: best dev error (``1 - accuracy``) seen so far.
    """
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)
    M = MetricsWriter(os.path.join(FLAGS.metrics_path, FLAGS.experiment_name))

    # Checkpoint paths.
    standard_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name)
    best_checkpoint_path = get_checkpoint_path(FLAGS.ckpt_path, FLAGS.experiment_name, best=True)

    # Run one forward pass before building the log format strings.
    # NOTE(review): presumably the *_format helpers inspect state that exists
    # only after a forward pass. Confirm before removing this.
    model.train()
    X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = get_batch(training_data_iter.next())
    model(X_batch, transitions_batch, y_batch,
          use_internal_parser=FLAGS.use_internal_parser,
          validate_transitions=FLAGS.validate_transitions)

    logger.Log("")
    logger.Log("# ----- BEGIN: Log Configuration ----- #")

    # Preview train string template.
    train_str = train_format(model)
    logger.Log("Train-Format: {}".format(train_str))
    train_extra_str = train_extra_format(model)
    logger.Log("Train-Extra-Format: {}".format(train_extra_str))

    # Preview eval string template.
    eval_str = eval_format(model)
    logger.Log("Eval-Format: {}".format(eval_str))
    eval_extra_str = eval_extra_format(model)
    logger.Log("Eval-Extra-Format: {}".format(eval_extra_str))

    logger.Log("# ----- END: Log Configuration ----- #")
    logger.Log("")

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training", bar_length=60, enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    for step in range(step, FLAGS.training_steps):
        model.train()

        start = time.time()

        batch = get_batch(training_data_iter.next())
        X_batch, transitions_batch, y_batch, num_transitions_batch, train_ids = batch

        # Presumably nt shift/reduce transitions cover (nt + 1) / 2 tokens
        # (TODO confirm against the transition encoding).
        total_tokens = sum([(nt + 1) / 2 for nt in num_transitions_batch.reshape(-1)])

        # Reset cached gradients.
        optimizer.zero_grad()

        # Run model.
        output = model(X_batch, transitions_batch, y_batch,
                       use_internal_parser=FLAGS.use_internal_parser,
                       validate_transitions=FLAGS.validate_transitions)

        # Normalize output.
        logits = F.log_softmax(output)

        # Calculate class accuracy.
        target = torch.from_numpy(y_batch).long()
        pred = logits.data.max(1)[1].cpu()  # index of the max log-probability
        class_acc = pred.eq(target).sum() / float(target.size(0))

        # Calculate class loss.
        xent_loss = nn.NLLLoss()(logits, to_gpu(Variable(target, volatile=False)))

        # Optionally calculate transition loss.
        transition_loss = getattr(model, 'transition_loss', None)

        # Extract L2 cost (None when disabled).
        l2_loss = l2_cost(model, FLAGS.l2_lambda) if FLAGS.use_l2_cost else None

        # Accumulate the total loss Variable.
        total_loss = 0.0
        total_loss += xent_loss
        if l2_loss is not None:
            total_loss += l2_loss
        if transition_loss is not None and model.optimize_transition_loss:
            total_loss += transition_loss
        total_loss += auxiliary_loss(model)

        # Backward pass.
        total_loss.backward()

        # Hard gradient clipping.
        clip = FLAGS.clipping_max_value
        for p in model.parameters():
            # A parameter that took no part in the backward pass can still have
            # grad None, so skip it.
            if p.requires_grad and p.grad is not None:
                p.grad.data.clamp_(min=-clip, max=clip)

        # Learning rate decay.
        if FLAGS.actively_decay_learning_rate:
            lr = FLAGS.learning_rate * (FLAGS.learning_rate_decay_per_10k_steps ** (step / 10000.0))
            # torch.optim optimizers read the rate from param_groups; assigning
            # an `lr` attribute alone does not affect the update. The attribute
            # is still set for any code that reads it.
            optimizer.lr = lr
            for group in getattr(optimizer, 'param_groups', ()):
                group['lr'] = lr

        # Gradient descent step.
        optimizer.step()

        end = time.time()

        total_time = end - start

        train_accumulate(model, data_manager, A, batch)
        A.add('class_acc', class_acc)
        A.add('total_tokens', total_tokens)
        A.add('total_time', total_time)

        if step % FLAGS.statistics_interval_steps == 0:
            progress_bar.step(i=FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)
            progress_bar.finish()

            A.add('xent_cost', xent_loss.data[0])
            # Record 0.0 when L2 regularization is disabled (l2_loss is None).
            A.add('l2_cost', l2_loss.data[0] if l2_loss is not None else 0.0)
            stats_args = train_stats(model, optimizer, A, step)

            train_metrics(M, stats_args, step)

            logger.Log(train_str.format(**stats_args))
            logger.Log(train_extra_str.format(**stats_args))

        if step % FLAGS.sample_interval_steps == 0 and FLAGS.num_samples > 0:
            _log_transition_samples(FLAGS, model, X_batch, transitions_batch, y_batch, logger)

        if step > 0 and step % FLAGS.eval_interval_steps == 0:
            for index, eval_set in enumerate(eval_iterators):
                acc, tacc = evaluate(FLAGS, model, data_manager, eval_set, index, logger, step)
                # Save a "best" checkpoint when the first eval set's error beats
                # the best so far by at least 1% (relative).
                if FLAGS.ckpt_on_best_dev_error and index == 0 and (1 - acc) < 0.99 * best_dev_error and step > FLAGS.ckpt_step:
                    best_dev_error = 1 - acc
                    logger.Log("Checkpointing with new best dev accuracy of %f" % acc)
                    trainer.save(best_checkpoint_path, step, best_dev_error)
            progress_bar.reset()

        if step > FLAGS.ckpt_step and step % FLAGS.ckpt_interval_steps == 0:
            logger.Log("Checkpointing.")
            trainer.save(standard_checkpoint_path, step, best_dev_error)

        progress_bar.step(i=step % FLAGS.statistics_interval_steps, total=FLAGS.statistics_interval_steps)