Example #1
 def make_batch(self, seq, gpu):
     tokens = [self.vocab.start_token] + seq + [self.vocab.stop_token]
     idxs = torch.LongTensor([[self.vocab.index(t)] for t in tokens])
     lengths = torch.LongTensor([len(tokens)])
     encoder_inputs = Variable(idxs,
                               lengths=lengths,
                               batch_dim=1,
                               length_dim=0,
                               pad_value=0)
     return {'encoder_input': encoder_inputs.cuda(gpu)}
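
Every example on this page wraps a plain torch.LongTensor in the same length-aware Variable class, passing the tensor together with per-item lengths, the dimension layout, and a pad value. The wrapper's implementation is not shown here; the sketch below is only an assumed minimal reconstruction of the constructor contract implied by these call sites, not the library's actual code.

    import torch

    class Variable:
        # Assumed minimal reconstruction, inferred from the call sites on this page.
        def __init__(self, tensor, lengths, length_dim, batch_dim, pad_value):
            self.data = tensor          # usually (length x batch) in these examples
            self.lengths = lengths      # one true length per batch item
            self.length_dim = length_dim
            self.batch_dim = batch_dim  # most examples use batch_dim=1
            self.pad_value = pad_value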
Example #2
    def init_state(self, batch_size, encoder_state):
        output = Variable(
            torch.LongTensor([self.vocab.start_index] * batch_size)\
                .view(1, -1),
            lengths=torch.LongTensor([1] * batch_size),
            length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index)

        if str(encoder_state.device) != "cpu":
            output = output.cuda(encoder_state.device)

        return {"output": output, "decoder_state": encoder_state}
Example #3
 def make_generator_inputs(self, data):
     tokens = [self.source_vocab.start_token] + data["source"] \
         + [self.source_vocab.stop_token]
     inputs = Variable(torch.LongTensor(
         [[self.source_vocab[t] for t in tokens]]).t(),
                       lengths=torch.LongTensor([len(tokens)]),
                       length_dim=0,
                       batch_dim=1,
                       pad_value=self.source_vocab.pad_index)
     if self._gpu > -1:
         inputs = inputs.cuda(self._gpu)
     return {"source_inputs": inputs}
Example #4
 def _batch_labels(self, labels_batch):
     input_tokens = torch.LongTensor(
         [[self.source_vocab[tok] for tok in self._labels2input(labels)]
          for labels in labels_batch])
     length = input_tokens.size(1)
     inputs = Variable(input_tokens.t(),
                       lengths=torch.LongTensor([length] *
                                                len(labels_batch)),
                       length_dim=0,
                       batch_dim=1,
                       pad_value=-1)
     if self._gpu > -1:
         inputs = inputs.cuda(self._gpu)
     return {"source_inputs": inputs}
Example #5
    def next_state(self, decoder, prev_state, context, active_items, controls):

        # Get next state from the decoder.
        next_state = decoder.next_state(prev_state, context, controls=controls)

        log_probs = next_state["log_probs"].tensor
        samples = torch.distributions.Categorical(
            logits=next_state["target_logits"].tensor).sample()

        sample_output = Variable(samples,
                                 lengths=samples.new_full((samples.size(1),), 1),
                                 length_dim=0,
                                 batch_dim=1,
                                 pad_value=self.vocab.pad_index)

        output_log_probs = log_probs.gather(2, samples.unsqueeze(-1))\
            .squeeze(2)

        next_state["output"] = sample_output
        next_state["output_log_probs"] = output_log_probs

        # Mask outputs if we have already completed that batch item.
        next_state["output"].data.view(-1).masked_fill_(
            ~active_items, self.vocab.pad_index)

        return next_state
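
The Categorical sample and the gather at the end together pick one token per batch item and record its log probability. The same pattern in isolation, with toy shapes and no decoder:

    import torch

    # (steps x batch x vocab) logits, normalized to log probabilities.
    log_probs = torch.log_softmax(torch.randn(1, 4, 10), dim=2)
    # One sampled token index per (step, batch) position: shape (1 x 4).
    samples = torch.distributions.Categorical(logits=log_probs).sample()
    # Log probability of each sampled token, mirroring the gather/squeeze above.
    sample_lps = log_probs.gather(2, samples.unsqueeze(-1)).squeeze(2)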
Example #6
    def _variable_forward(self, query, key, value):
        key = self.key_net(key)
        query = self.query_net(query)

        key = key.permute_as_sequence_batch_features()
        query = query.permute_as_sequence_batch_features()

        assert key.dim() == query.dim() == 3
        # TODO use named tensors to allow arbitrary seq dims

        with torch.no_grad():
            mask = ~torch.einsum("qbh,kbh->qkb", [(~query.mask).float(),
                                                  (~key.mask).float()]).byte()

        query_uns = query.data.unsqueeze(query.length_dim + 1)
        key_uns = key.data.unsqueeze(query.length_dim)

        hidden = torch.tanh(key_uns + query_uns)
        scores = hidden.matmul(self.weight)

        scores = scores.masked_fill(mask, float("-inf"))

        attention = torch.softmax(scores.transpose(1, 2), dim=2)
        # Fully masked rows softmax to NaN (x != x); zero them out.
        attention = attention.masked_fill(attention != attention, 0.)

        comp = torch.einsum("ijk,kjh->ijh",
                            [attention, self.value_net(value).data])
        comp = Variable(comp, lengths=query.lengths, length_dim=0, batch_dim=1)

        return {"attention": attention, "output": comp}
Example #7
    def init_state(self, batch_size, encoder_state):

        beam_state = {}

        n, bs, ds = encoder_state.size()
        assert batch_size == bs

        beam_state["decoder_state"] = encoder_state.unsqueeze(2)\
            .repeat(1, 1, self.beam_size, 1)\
            .view(n, batch_size * self.beam_size, ds)

        beam_state["output"] = Variable(
            torch.LongTensor(
                [self.vocab.start_index] * batch_size * self.beam_size)\
                .view(1, -1),
            lengths=torch.LongTensor([1] * batch_size * self.beam_size),
            length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index)

        if str(encoder_state.device) != "cpu":
            beam_state["output"] = beam_state["output"].cuda(
                encoder_state.device)

        # Start the first beam of each batch with 0 log prob, and all others
        # with -inf.
        beam_state["accum_log_prob"] = (self._init_accum_log_probs(
            batch_size, encoder_state.device))

        # At the first time step no sequences have been terminated so this mask
        # is all 0s.
        beam_state["terminal_mask"] = (encoder_state.new().byte().new(
            1, batch_size * self.beam_size, 1).fill_(0))

        return beam_state
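
_init_accum_log_probs is referenced but not shown on this page. From the comment above it (first beam 0, all others -inf) and the (1 x batch*beam x 1) shape consumed in Example #21, a plausible sketch is:

    import torch

    def _init_accum_log_probs(self, batch_size, device):
        # Hypothetical reconstruction: only the first beam of each batch item
        # starts live; the -inf entries keep step one from expanding duplicates.
        lp = torch.full((batch_size, self.beam_size), float("-inf"), device=device)
        lp[:, 0] = 0.
        return lp.view(1, batch_size * self.beam_size, 1)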
Example #8
    def output(self, as_indices=False, n_best=-1):
        if n_best < 1:
            o = self._output[:,0]
            if as_indices:
                olen = self._lengths[:, 0]
                return Variable(o, lengths=olen, batch_dim=0, length_dim=1,
                                pad_value=int(self.vocab.pad_index))
            tokens = []
            for row in o:
                tokens.append([self.vocab.token(t) for t in row
                               if t != self.vocab.pad_index])
            return tokens

        elif n_best < self.beam_size:
            o = self._output[:,:n_best]
        else:
            o = self._output

        if as_indices:
            return o

        beams = []
        for beam in o:
            tokens = []
            for row in beam:
                tokens.append([self.vocab.token(t) for t in row
                               if t != self.vocab.pad_index])
            beams.append(tokens)
        return beams
Example #9
    def init_state(self, batch_size, device, source_indices, 
                   encoder_state=None):

        output = Variable(
            torch.LongTensor([self.vocab.start_index] * batch_size)\
                .view(1, -1),
            lengths=torch.LongTensor([1] * batch_size),
            length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index)
        

        if str(device) != "cpu":
            output = output.cuda(device)
            
        mask = source_indices.mask.t()
        
        return {"target": output, "decoder_state": encoder_state,
                "inputs": output, 'pointer_mask': mask}
Example #10
def seq2tsr(seq, vocab):
    length = torch.LongTensor([len(seq)])
    tensor = torch.LongTensor([[vocab[x] for x in seq]]).t()
    return Variable(tensor,
                    lengths=length,
                    length_dim=0,
                    batch_dim=1,
                    pad_value=vocab.pad_index)

def search2inputs(search_outputs, vocab, gpu):
    input_lens = [len(x) + 1 for x in search_outputs]
    max_len = max(input_lens)
    inputs = []
    for out in search_outputs:
        out = [vocab.start_token] + out
        if len(out) < max_len:
            out = out + [vocab.pad_token] * (max_len - len(out))
        inputs.append([vocab[t] for t in out])
    inputs = torch.LongTensor(inputs)
    input_lens = torch.LongTensor(input_lens)
    inputs = Variable(
        inputs.t(),
        lengths=input_lens,
        batch_dim=1, length_dim=0, pad_value=vocab.pad_index)
    if gpu > -1:
        inputs = inputs.cuda(gpu)
    return {"inputs": inputs}
Example #12
 def make_batch(self, input_tokens, gpu):
     input_indices = [self.src_vocab.index(t) for t in input_tokens]
     input_indices = [self.src_vocab.start_index] \
         + input_indices + [self.src_vocab.stop_index]
     input_indices = torch.LongTensor([input_indices])
     lengths = torch.LongTensor([len(input_tokens) + 2])
     variable = Variable(input_indices.t(), lengths=lengths,
                         length_dim=0, batch_dim=1,
                         pad_value=self.src_vocab.pad_index).cuda(gpu)
     return {'encoder_input': variable}
Example #13
    def _next_candidates(self, batch_size, log_probs, candidates):
        # TODO seq_lps should really be called cumulative log probs.

        # flat_beam_lps (batch size x (beam size ** 2))
        flat_beam_lps = log_probs.view(batch_size, -1)

        flat_beam_scores = flat_beam_lps / (self.steps + 1)

        # beam_seq_scores (batch size x beam size)
        # relative_indices (batch_size x beam size)
        beam_seq_scores, relative_indices = torch.topk(
            flat_beam_scores, k=self.beam_size, dim=1)
        relative_indices_mask = (
            (relative_indices >= flat_beam_scores.size(1))
            |
            (relative_indices < 0)
        ) 
        relative_indices = relative_indices.masked_fill(
            relative_indices_mask, 0)

        # beam_seq_lps (batch size x beam size)
        beam_seq_lps = flat_beam_lps.gather(1, relative_indices)

        # TODO make these ahead of time. 
        offset1 = (
            torch.arange(batch_size, device=beam_seq_lps.device) 
                * self.beam_size
        ).view(batch_size, 1)
        
        offset2 = offset1 * self.beam_size
       
        beam_indexing = ((relative_indices // self.beam_size) + offset1) \
            .view(-1)

        # beam_seq_lps (1 x (batch_size * beam_size) x 1)
        beam_seq_lps = beam_seq_lps \
            .view(1, batch_size * self.beam_size, 1)
        
        # beam_seq_scores (1 x (batch_size * beam_size) x 1)
        beam_seq_scores = beam_seq_scores \
            .view(1, batch_size * self.beam_size, 1)

        # next_output (1 x (batch size * beam size))
        next_candidate_indices = (relative_indices + offset2).view(-1)
        next_output = Variable(
            candidates.view(-1)[next_candidate_indices].view(1, -1),
            lengths=torch.ones(batch_size * self.beam_size, dtype=torch.long,
                               device=candidates.device),
            length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index)


        return beam_seq_lps, beam_seq_scores, next_output, beam_indexing
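
The two offsets are easiest to check with concrete numbers. With batch_size=2 and beam_size=3 there are beam_size ** 2 = 9 candidates per batch item; a relative index r maps back to source beam r // 3, and the offsets re-base those positions into the flat (batch*beam) and (batch*beam**2) layouts:

    import torch

    batch_size, beam_size = 2, 3
    offset1 = (torch.arange(batch_size) * beam_size).view(batch_size, 1)  # [[0], [3]]
    offset2 = offset1 * beam_size                                         # [[0], [9]]
    relative = torch.tensor([[4, 0, 8], [1, 7, 2]])   # top-3 of 9 per batch item
    beam_indexing = (relative // beam_size + offset1).view(-1)
    # tensor([1, 0, 2, 3, 5, 3]): which flat beam each winner came from.
    flat_candidates = (relative + offset2).view(-1)
    # tensor([ 4,  0,  8, 10, 16, 11]): positions in the flat candidate tensor.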
Example #14
    def init_state_context(self, encoder_state):
        batch_size = encoder_state["state"].size(1) * self.samples

        output = Variable(torch.LongTensor([self.vocab.start_index] *
                                           batch_size).view(1, -1),
                          lengths=torch.LongTensor([1] * batch_size),
                          length_dim=0,
                          batch_dim=1,
                          pad_value=self.vocab.pad_index)

        if str(encoder_state["state"].device) != "cpu":
            output = output.cuda(encoder_state["state"].device)

        layers = encoder_state["state"].size(0)
        decoder_state = encoder_state["state"].unsqueeze(2)\
            .repeat(1, 1, self.samples, 1).view(layers, batch_size, -1)
        search_state = {"output": output, "decoder_state": decoder_state}

        encoder_output = encoder_state["output"].repeat_batch_dim(self.samples)
        context = {"encoder_output": encoder_output}

        return search_state, context
Example #15
    def init_state(self, batch_size, encoder_state, encoder_input, device):

        beam_state = {}

        if encoder_state:
            n, bs, ds = encoder_state.size()
            assert batch_size == bs
            beam_state["decoder_state"] = encoder_state.unsqueeze(2)\
                .repeat(1, 1, self.beam_size, 1)\
                .view(n, batch_size * self.beam_size, ds)
        ei = encoder_input.permute_as_batch_sequence_features().data
        cons = torch.ones(batch_size, len(self.vocab),
                          dtype=torch.bool, device=ei.device)
        cons = cons.scatter_(1, ei, 0)
        cons = cons.unsqueeze(1).repeat(1, self.beam_size,
                                        1).view(batch_size * self.beam_size,
                                                -1)
        cons[:, self.vocab.stop_index] = 1
        for v in [
                "inform", "give_opinion", "verify_attribute", "request",
                "recommend", "request_explanation", "confirm", "suggest",
                "request_attribute", "<pad>", "<sos>", 'rating=N/A',
                'rating=average', 'rating=excellent', 'rating=good',
                'rating=poor'
        ]:
            if v in self.vocab:
                cons[:, self.vocab.index(v)] = 1
        beam_state['cons'] = cons

        beam_state["target"] = Variable(
            torch.LongTensor(
                [self.vocab.start_index] * batch_size * self.beam_size)\
                .view(1, -1),
            lengths=torch.LongTensor([1] * batch_size * self.beam_size),
            length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index)

        beam_state["target"] = beam_state["target"].to(device)
        beam_state['inputs'] = beam_state['target']

        # Start the first beam of each batch with 0 log prob, and all others
        # with -inf.
        beam_state["accum_log_prob"] = (self._init_accum_log_probs(
            batch_size, device))

        # At the first time step no sequences have been terminated so this mask
        # is all 0s.
        beam_state["terminal_mask"] = (torch.BoolTensor(
            1, batch_size * self.beam_size, 1).fill_(0).to(device))

        return beam_state
Example #16
    def _make_classifier_inputs(self, batch):
        lens = []
        clf_inputs = []
        for out in batch:
            clf_inputs.append([self.input_vocab.start_index] +
                              [self.input_vocab[t] for t in out[:-1]] +
                              [self.input_vocab.stop_index])
            lens.append(len(out) + 1)
        lens = torch.LongTensor(lens)
        max_len = lens.max().item()
        clf_inputs = torch.LongTensor([
            inp + [self.input_vocab.pad_index] * (max_len - len(inp))
            for inp in clf_inputs
        ]).t()

        clf_inputs = Variable(clf_inputs,
                              lengths=lens,
                              batch_dim=1,
                              length_dim=0,
                              pad_value=self.input_vocab.pad_index)
        if self.gpu > -1:
            return clf_inputs.cuda(self.gpu)
        else:
            return clf_inputs
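
The manual pad-to-max-length comprehension can equivalently be written with torch.nn.utils.rnn.pad_sequence, which pads a list of 1-D tensors to a common length; batch_first=False yields the same (length x batch) layout as the .t() above. A sketch of just the padding step:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    rows = [torch.LongTensor(inp) for inp in clf_inputs]
    padded = pad_sequence(rows, batch_first=False,
                          padding_value=self.input_vocab.pad_index)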
Example #17
    def output(self, as_indices=False):
        
        if as_indices:
            lengths = (~self._mask_T).sum(0)
            return Variable(self._outputs, lengths=lengths, length_dim=0,
                            batch_dim=1, pad_value=self.vocab.pad_index)\
                            .apply_sequence_mask()

        tokens = []
        for output in self._outputs.t():
            tokens_i = []
            for index in output:
                if index in [self.vocab.pad_index, self.vocab.stop_index]:
                    break
                tokens_i.append(self.vocab.token(index))
            tokens.append(tokens_i)

        return tokens
Example #18
    def init_state(self, batch_size, device, source_indices, encoder_state):

        beam_state = {}

        if encoder_state:
            n, bs, ds = encoder_state.size()
            assert batch_size == bs
            beam_state["decoder_state"] = encoder_state.unsqueeze(2)\
                .repeat(1, 1, self.beam_size, 1)\
                .view(n, batch_size * self.beam_size, ds)

        beam_state["target"] = Variable(
            torch.LongTensor(
                [self.vocab.start_index] * batch_size * self.beam_size)\
                .view(1, -1),
            lengths=torch.LongTensor([1] * batch_size * self.beam_size),
            length_dim=0, batch_dim=1, pad_value=self.vocab.pad_index)

        beam_state["target"] = beam_state["target"].to(device)
        beam_state['inputs'] = beam_state['target']

        source_indices = source_indices.permute_as_batch_sequence_features()
        mask = source_indices.mask.unsqueeze(1).repeat(1, self.beam_size, 1)\
            .contiguous().view(-1, source_indices.size(1))
        beam_state['pointer_mask'] = mask

        # Start the first beam of each batch with 0 log prob, and all others 
        # with -inf.
        beam_state["accum_log_prob"] = (
            self._init_accum_log_probs(batch_size, device)
        )

        # At the first time step no sequences have been terminated so this mask
        # is all 0s. 
        beam_state["terminal_mask"] = (
            torch.BoolTensor(
                1, batch_size * self.beam_size, 1).fill_(0).to(device)
        )   

        return beam_state
Example #19
    def next_state(self, decoder, prev_state, context, active_items,
                   controls):

        # Get next state from the decoder.
        next_state = decoder.next_state(prev_state, context, controls=controls)

        next_pointer = next_state['log_probs']\
            .permute_as_sequence_batch_features().max(2)[1]

        target = context['source_indices'].data.t().gather(
            1, next_pointer.data.t())
        target = prev_state['target'].new_with_meta(target.t())
        next_state['target'] = target

        pm = prev_state['pointer_mask']
        pm.scatter_(1, next_pointer.data.t(), 1)
        next_state['pointer_mask'] = pm

        # Mask outputs if we have already completed that batch item. 
        next_state["target"].data.view(-1).masked_fill_(
            ~active_items, int(self.vocab.pad_index))

        next_inputs = torch.cat(
            [
                prev_state['inputs'].data, 
                next_state['target'].data,
            ],
            0)
        next_inputs = Variable(
            next_inputs,
            prev_state['inputs'].lengths + 1,
            length_dim=0, batch_dim=1,
            pad_value=prev_state['inputs'].pad_value)
        next_state['inputs'] = next_inputs
        return next_state
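
The gather/scatter pair above converts a pointer (a position in the source sequence) into a vocabulary index and then marks that source position as consumed. In isolation, with toy values:

    import torch

    source_indices = torch.tensor([[11, 22, 33]])    # (batch x src_len)
    next_pointer = torch.tensor([[2]])               # points at the third position
    target = source_indices.gather(1, next_pointer)  # tensor([[33]])
    pointer_mask = torch.zeros(1, 3, dtype=torch.bool)
    pointer_mask.scatter_(1, next_pointer, True)
    # pointer_mask: tensor([[False, False,  True]])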
Example #20
    def __call__(self, decoder, encoder_state, controls=None):
        self.reset()
        batch_size = encoder_state["batch_size"]
        device = encoder_state["device"]
        
        search_state = self.init_state(batch_size, device,
                                       encoder_state['source_indices'],
                                       encoder_state.get("state", None))
        
        context = {
            "encoder_output": encoder_state["output"],
            "source_indices": encoder_state['source_indices'],
        }
        active_items = torch.BoolTensor(batch_size).fill_(1)
        output_length = torch.LongTensor(batch_size).fill_(0)
        if str(encoder_state["device"]) != "cpu":
            active_items = active_items.cuda(encoder_state["device"])
            output_length = output_length.cuda(encoder_state["device"])

        prev_decoder_inputs = search_state["target"] 
        step_masks = []
        # Perform search until we either trigger a termination condition for
        # each batch item or we reach the maximum number of search steps.
        while self.steps < self.max_steps and not self.is_finished:
            
            inactive_items = ~active_items

            # Mask any inputs that are finished, so that greedy decoding
            # is identical to a forward pass.
            search_state["target"].data.view(-1).masked_fill_(
                inactive_items, int(self.vocab.pad_index))

            step_masks.append(inactive_items)
            self.steps += 1
            search_state = self.next_state(
                decoder, search_state, context, active_items, controls)        
            
            self._states.append(search_state)
            self._outputs.append(search_state["target"].clone()\
                .permute_as_sequence_batch_features())
            output_length = output_length + active_items.long()
            temp_outputs = torch.cat(
                    [prev_decoder_inputs.tensor] + \
                            [x.tensor for x in self._outputs], 0)
            temp_outputs = Variable(temp_outputs, lengths=output_length + 1,
                                    batch_dim=1, length_dim=0,
                                    pad_value=self.vocab.pad_index)
            search_state["prev_output"] = temp_outputs
            

            active_items = self.check_termination(search_state, active_items)
            self.is_finished = torch.all(~active_items)

        # Finish the search by collecting final sequences, and other 
        # stats. 
        self._collect_search_states(active_items)
        self._incomplete_items = active_items
        self._is_finished = True

        self._mask_T = torch.stack(step_masks)
        self._mask = self._mask_T.t().contiguous()
        return self
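
The masked_fill_ inside the loop is what keeps finished batch items emitting pad for the remaining steps. Its effect in isolation, with a hypothetical pad index of 0:

    import torch

    pad_index = 0
    active_items = torch.tensor([True, False, True])
    step_output = torch.tensor([[5, 7, 9]])          # (1 x batch)
    step_output.view(-1).masked_fill_(~active_items, pad_index)
    # step_output is now tensor([[5, 0, 9]]).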
Example #21
    def next_state(self, decoder, batch_size, prev_state, context, 
                   active_items, controls):

        # Get next state from the decoder.
        next_state = decoder.next_state(prev_state, context, controls=controls)

        # Compute the top beam_size next outputs for each beam item.
        # topk_lps (1 x batch size x beam size x beam size)
        # candidate_outputs (1 x batch size x beam size x beam size)
        topk_lps, candidate_outputs = torch.topk(
            next_state["log_probs"].data \
                .view(1, batch_size, self.beam_size, -1),
            k=self.beam_size, dim=3)

        # If any sequence was completed last step, we should mask its log
        # prob so that we don't generate from the terminal token.
        # slp (1 x batch_size x beam size x 1) 
        slp = prev_state["accum_log_prob"] \
            .masked_fill(prev_state["terminal_mask"], float("-inf")) \
            .view(1, batch_size, self.beam_size, 1)

        # Combine the next-step log probs with each sequence's cumulative
        # log prob, i.e.
        #     log P(y_<=t) = log P(y_<t) + log P(y_t | y_<t)
        # candidate_log_probs (1 x batch size x beam size x beam size)
        candidate_log_probs = slp + topk_lps

        # Rerank and select the beam_size best options from the available 
        # beam_size ** 2 candidates.
        # b_seq_lps (1 x (batch size * beam size) x 1)
        # b_scores (1 x (batch size * beam size) x 1)
        # b_pointer (1 x (batch size * beam size))
        # b_indices ((batch size * beam size))
        b_seq_lps, b_scores, b_pointer, b_indices = self._next_candidates(
            batch_size, candidate_log_probs, candidate_outputs)

        b_pointer_mask = (
            (b_pointer.data >= context['source_indices'].size(1))
            |
            (b_pointer.data < 0)
        )

        b_pointer = b_pointer.masked_fill(b_pointer_mask, 0)
        b_output = context['source_indices'].gather(1, b_pointer.data.t())
        b_output = b_pointer.new_with_meta(b_output.t())

        pm = prev_state['pointer_mask']
        pm = pm.index_select(0, b_indices)
        pm.scatter_(1, b_pointer.data.t(), 1)

        # TODO re-implement this behavior
        #next_state.stage_indexing("batch", b_indices)

        next_inputs = torch.cat(
            [
                prev_state['inputs'].index_select(1, b_indices).data, 
                b_output.data,
            ],
            0)
        next_inputs = Variable(
            next_inputs,
            prev_state['inputs'].lengths + 1,
            length_dim=0, batch_dim=1,
            pad_value=prev_state['inputs'].pad_value)
#        print(next_inputs)

        next_state = {
            "decoder_state": next_state["decoder_state"]\
                .index_select(1, b_indices),
            "target": b_output,
            "accum_log_prob": b_seq_lps,
            "beam_score": b_scores,
            "beam_indices": b_indices,
            "inputs": next_inputs,
            'pointer_mask': pm,
        }
        return next_state
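
The opening topk above asks each live beam for its beam_size best continuations, giving beam_size ** 2 candidates per batch item that _next_candidates (Example #13) then flattens and reranks back down to beam_size. A shape-only illustration:

    import torch

    # num_choices: vocabulary size, or source length for a pointer decoder.
    batch_size, beam_size, num_choices = 2, 3, 7
    log_probs = torch.log_softmax(
        torch.randn(1, batch_size * beam_size, num_choices), dim=2)
    topk_lps, candidates = torch.topk(
        log_probs.view(1, batch_size, beam_size, -1), k=beam_size, dim=3)
    # Both outputs are (1 x batch x beam x beam): beam_size ** 2 candidates
    # per batch item before the global rerank.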