Example no. 1
import torch
from torch.nn.utils.rnn import pad_sequence

# STATE_MAP and CookedData come from the surrounding module; `cls` carries the
# dataset class's padding helpers (need_mask, fn_pad_in, fn_pad_out).
def cooked_collade_fn(cls, list_samples):
    list_i = []
    list_o = []
    list_p = []
    seq_input_ints = []
    seq_outputs_ints = []
    scatter_idx = []
    # Sort samples by descending length of their third field (longest first).
    list_samples.sort(key=lambda x: -len(x[2]))
    list_states = []
    for idx, (i, o, p, int_i, int_o, states) in enumerate(list_samples):
        list_i.append(i)
        list_o.append(o)
        list_p.append(p)
        list_states.append(states)
        # Record which sample each flattened I/O sequence came from.
        scatter_idx += [idx] * len(int_i)
        seq_input_ints += int_i
        seq_outputs_ints += int_o
    if cls.need_mask:
        list_states = [torch.LongTensor(x) for x in list_states]
        padded_states = pad_sequence(list_states, batch_first=True, padding_value=STATE_MAP['halt'])
    else:
        padded_states = None
    padded_i = cls.fn_pad_in(seq_input_ints)
    padded_o = cls.fn_pad_out(seq_outputs_ints)
    scatter_idx = torch.LongTensor(scatter_idx)

    return list_i, list_o, list_p, CookedData(padded_i, padded_o, scatter_idx, padded_states)
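The project-specific pieces above (CookedData, STATE_MAP, fn_pad_in, fn_pad_out) are not shown on this page, so the snippet is not runnable on its own. As a rough, self-contained sketch of the same pattern, here is a minimal collate function built on torch.nn.utils.rnn.pad_sequence; the names pad_collate and PAD_ID are assumptions for illustration, not taken from the example.

import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0  # hypothetical padding token id

def pad_collate(batch):
    # `batch` is a list of (token_id_list, label) pairs with varying lengths.
    seqs, labels = zip(*batch)
    tensors = [torch.LongTensor(s) for s in seqs]
    lengths = torch.LongTensor([t.size(0) for t in tensors])
    # Pad every sequence to the length of the longest one in this batch.
    padded = pad_sequence(tensors, batch_first=True, padding_value=PAD_ID)
    return padded, lengths, torch.LongTensor(labels)

# Typical use: torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=pad_collate)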
Example no. 2
    def forward(self, list_progs, context_embeds, ll=None, target_list=None, gen_method='sample', sizes=None, has_stopped=None):
        n_prog = len(list_progs)
        # Map each program to token ids, append the stop token, and pad into a (max_len, n_prog) batch.
        prog_int_seqs = [torch.LongTensor([self.vocab[c] for c in expr] + [self.tok_stop]).to(context_embeds.device) for expr in list_progs]
        lengths = [v.size(0) for v in prog_int_seqs]
        padded_int_seqs = pad_sequence(prog_int_seqs, batch_first=False, padding_value=self.tok_pad)

        # Pack the padded batch, embed the flattened token ids, and re-wrap them as a PackedSequence.
        packed_seq = pack_padded_sequence(padded_int_seqs, lengths=lengths, batch_first=False, enforce_sorted=False)
        tok_embed = self.tok_embed(packed_seq.data)
        packed_input = PackedSequence(data=tok_embed, batch_sizes=packed_seq.batch_sizes,
                        sorted_indices=packed_seq.sorted_indices, unsorted_indices=packed_seq.unsorted_indices)

        # Project the context embedding into the LSTM's initial hidden and cell states.
        h = self.ctx2h(context_embeds).view(n_prog, 2 * self.rnn_layers, -1).transpose(0, 1)
        c = self.ctx2c(context_embeds).view(n_prog, 2 * self.rnn_layers, -1).transpose(0, 1)
        packed_out, _ = self.lstm(packed_input, (h, c))
        unpacked_out, _ = pad_packed_sequence(packed_out)

        # positions to mod/del
        expr_poses = (padded_int_seqs == self.tok_constexpr) | (padded_int_seqs == self.tok_subexpr)
        embed_expr = unpacked_out[expr_poses]
        if embed_expr.shape[0]:
            mod_scores = self.modify_score(embed_expr)
            del_scores = self.del_score(embed_expr)
        else:
            mod_scores = del_scores = None
        # positions to insert
        ins_poses = padded_int_seqs == self.tok_start
        insert_scores = self.insert_score(unpacked_out[ins_poses])

        # positions to stop
        stop_poses = padded_int_seqs == self.tok_stop
        stop_scores = self.stop_score(unpacked_out[stop_poses])
        # loc_score (defined elsewhere) merges the per-position scores into one logit tensor,
        # which is then log-softmaxed over positions to give a distribution over edit locations.
        logits = loc_score(mod_scores, del_scores, insert_scores, stop_scores, expr_poses, ins_poses, stop_poses, has_stopped)
        log_prob = F.log_softmax(logits, dim=0).t().contiguous()
        ll_target = None
        predecessors = None
        if target_list is None:
            if gen_method == 'sample':
                target = torch.multinomial(torch.exp(log_prob), 1)
            elif gen_method == 'argmax':
                target = torch.argmax(log_prob, dim=1)
            elif gen_method.startswith('beam'):
                beam_size = int(gen_method.split('-')[-1])
                raw_scores = log_prob + ll if ll is not None else log_prob
                predecessors, target, ll_target, sizes = beam_step(raw_scores, sizes, beam_size)
                update_embed = unpacked_out[target, predecessors]
            else:
                raise NotImplementedError
        else:
            target = torch.LongTensor(target_list).to(log_prob.device)
        target = target.view(-1)
        if predecessors is None:
            ll_step = log_prob[range(n_prog), target]
            ll_target = ll_step.view(ll.shape) + ll if ll is not None else ll_step
            update_embed = unpacked_out[target, range(n_prog)]
        return ll_target.view(-1, 1), target, update_embed, predecessors, sizes
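The packing sequence above (pad_sequence, pack_padded_sequence, embedding the packed data, re-wrapping it as a PackedSequence, running the LSTM, then pad_packed_sequence) is easy to lose in the surrounding scoring logic. A minimal, self-contained sketch of just that pattern, with made-up sizes (vocab_size, embed_dim, hidden_dim and pad_id are illustrative, not from the example):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import (pad_sequence, pack_padded_sequence,
                                pad_packed_sequence, PackedSequence)

vocab_size, embed_dim, hidden_dim, pad_id = 20, 8, 16, 0
embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_id)
lstm = nn.LSTM(embed_dim, hidden_dim)

seqs = [torch.randint(1, vocab_size, (n,)) for n in (5, 3, 4)]
lengths = [s.size(0) for s in seqs]
padded = pad_sequence(seqs, batch_first=False, padding_value=pad_id)      # (max_len, batch)

packed = pack_padded_sequence(padded, lengths=lengths, batch_first=False, enforce_sorted=False)
# Embed the flattened token ids and re-wrap them, exactly as the snippet does.
packed_emb = PackedSequence(data=embed(packed.data), batch_sizes=packed.batch_sizes,
                            sorted_indices=packed.sorted_indices, unsorted_indices=packed.unsorted_indices)
packed_out, _ = lstm(packed_emb)
unpacked_out, out_lengths = pad_packed_sequence(packed_out)               # (max_len, batch, hidden_dim)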
Example no. 3
    def _padded_io(self, int_seqs, max_len):
        # Note: torch.nn.utils.rnn.pad_sequence does not accept a max_len argument,
        # so this assumes a project-specific pad_sequence variant that does.
        int_seqs = [torch.LongTensor(x) for x in int_seqs]
        padded = pad_sequence(int_seqs, max_len=max_len, batch_first=True)
        return padded
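Stock PyTorch only pads up to the longest sequence in the batch, so padding to a fixed max_len needs one extra step. A minimal, self-contained sketch of one way to do it (the helper name padded_io_fixed and the padding value 0 are assumptions, not taken from the example):

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

def padded_io_fixed(int_seqs, max_len, pad_value=0):
    # Pad a list of integer sequences to exactly max_len columns.
    tensors = [torch.LongTensor(x) for x in int_seqs]
    padded = pad_sequence(tensors, batch_first=True, padding_value=pad_value)  # (batch, longest)
    if padded.size(1) < max_len:
        # Right-pad the time dimension up to the requested length.
        padded = F.pad(padded, (0, max_len - padded.size(1)), value=pad_value)
    return padded[:, :max_len]

# padded_io_fixed([[1, 2, 3], [4]], max_len=5) -> LongTensor of shape (2, 5)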
Example no. 4
    def _pad_io(self, int_seqs):
        int_seqs = [torch.LongTensor(x) for x in int_seqs]
        lengths = [v.size(0) for v in int_seqs]
        # pad_sequence defaults to batch_first=False, so the result has shape (max_len, batch).
        return pad_sequence(int_seqs), lengths
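The lengths returned alongside the padded tensor are usually what downstream code needs, for example to build a padding mask or to call pack_padded_sequence. A minimal sketch of the mask construction, assuming the (max_len, batch) layout produced above; the toy data and names are illustrative only:

import torch
from torch.nn.utils.rnn import pad_sequence

int_seqs = [[1, 2, 3, 4], [5, 6], [7, 8, 9]]
tensors = [torch.LongTensor(x) for x in int_seqs]
lengths = torch.LongTensor([t.size(0) for t in tensors])

padded = pad_sequence(tensors)                        # (max_len, batch), batch_first defaults to False
# True where a position holds a real token, False where it is padding.
mask = torch.arange(padded.size(0)).unsqueeze(1) < lengths.unsqueeze(0)   # (max_len, batch)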