import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import (
    PackedSequence, pack_padded_sequence, pad_packed_sequence, pad_sequence,
)


def cooked_collade_fn(cls, list_samples):
    # Collate a list of (i, o, p, int_i, int_o, states) samples into padded tensors.
    list_i = []
    list_o = []
    list_p = []
    seq_input_ints = []
    seq_outputs_ints = []
    scatter_idx = []
    # Sort samples by descending program length so padding stays tight.
    list_samples.sort(key=lambda x: -len(x[2]))
    list_states = []
    for idx, (i, o, p, int_i, int_o, states) in enumerate(list_samples):
        list_i.append(i)
        list_o.append(o)
        list_p.append(p)
        list_states.append(states)
        # Each sample contributes len(int_i) IO pairs; scatter_idx records which
        # sample every flattened IO pair came from.
        scatter_idx += [idx] * len(int_i)
        seq_input_ints += int_i
        seq_outputs_ints += int_o
    if cls.need_mask:
        list_states = [torch.LongTensor(x) for x in list_states]
        padded_states = pad_sequence(list_states, batch_first=True,
                                     padding_value=STATE_MAP['halt'])
    else:
        padded_states = None
    padded_i = cls.fn_pad_in(seq_input_ints)
    padded_o = cls.fn_pad_out(seq_outputs_ints)
    scatter_idx = torch.LongTensor(scatter_idx)
    return list_i, list_o, list_p, CookedData(padded_i, padded_o, scatter_idx, padded_states)
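# Illustrative sketch (not part of the original module): a toy demonstration of the
# scatter-index bookkeeping used by cooked_collade_fn above. The function name and
# all data below are made up for illustration.
def _demo_scatter_idx():
    samples = [([[1, 2], [3]], 'prog_a'), ([[4, 5, 6]], 'prog_b')]  # (io_pairs, program)
    flat_io, scatter_idx = [], []
    for idx, (io_pairs, _) in enumerate(samples):
        flat_io += io_pairs                    # flatten all IO pairs across samples
        scatter_idx += [idx] * len(io_pairs)   # remember the owning sample of each pair
    padded = pad_sequence([torch.LongTensor(x) for x in flat_io], batch_first=True)
    return padded, torch.LongTensor(scatter_idx)  # shapes: (3, 3) and tensor([0, 0, 1])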
def forward(self, list_progs, context_embeds, ll=None, target_list=None,
            gen_method='sample', sizes=None, has_stopped=None):
    n_prog = len(list_progs)
    # Encode each program as token ids, terminated by the stop token.
    prog_int_seqs = [torch.LongTensor([self.vocab[c] for c in expr] + [self.tok_stop]).to(context_embeds.device)
                     for expr in list_progs]
    lengths = [v.size(0) for v in prog_int_seqs]
    padded_int_seqs = pad_sequence(prog_int_seqs, batch_first=False, padding_value=self.tok_pad)
    packed_seq = pack_padded_sequence(padded_int_seqs, lengths=lengths,
                                      batch_first=False, enforce_sorted=False)
    # Embed only the packed (non-pad) tokens, then rebuild a PackedSequence for the LSTM.
    tok_embed = self.tok_embed(packed_seq.data)
    packed_input = PackedSequence(data=tok_embed, batch_sizes=packed_seq.batch_sizes,
                                  sorted_indices=packed_seq.sorted_indices,
                                  unsorted_indices=packed_seq.unsorted_indices)
    # Initialize the LSTM hidden/cell states from the context embeddings.
    h = self.ctx2h(context_embeds).view(n_prog, 2 * self.rnn_layers, -1).transpose(0, 1)
    c = self.ctx2c(context_embeds).view(n_prog, 2 * self.rnn_layers, -1).transpose(0, 1)
    packed_out, _ = self.lstm(packed_input, (h, c))
    unpacked_out, _ = pad_packed_sequence(packed_out)

    # Positions eligible for modify/delete edits.
    expr_poses = (padded_int_seqs == self.tok_constexpr) | (padded_int_seqs == self.tok_subexpr)
    embed_expr = unpacked_out[expr_poses]
    if embed_expr.shape[0]:
        mod_scores = self.modify_score(embed_expr)
        del_scores = self.del_score(embed_expr)
    else:
        mod_scores = del_scores = None

    # Positions eligible for insertion.
    ins_poses = padded_int_seqs == self.tok_start
    insert_scores = self.insert_score(unpacked_out[ins_poses])

    # Positions eligible for stopping.
    stop_poses = padded_int_seqs == self.tok_stop
    stop_scores = self.stop_score(unpacked_out[stop_poses])

    logits = loc_score(mod_scores, del_scores, insert_scores, stop_scores,
                       expr_poses, ins_poses, stop_poses, has_stopped)
    # Normalize over candidate locations, then transpose to (n_prog, n_locations).
    log_prob = F.log_softmax(logits, dim=0).t().contiguous()

    ll_target = None
    predecessors = None
    if target_list is None:
        if gen_method == 'sample':
            target = torch.multinomial(torch.exp(log_prob), 1)
        elif gen_method == 'argmax':
            target = torch.argmax(log_prob, dim=1)
        elif gen_method.startswith('beam'):
            beam_size = int(gen_method.split('-')[-1])
            raw_scores = (log_prob + ll) if ll is not None else log_prob
            predecessors, target, ll_target, sizes = beam_step(raw_scores, sizes, beam_size)
            update_embed = unpacked_out[target, predecessors]
        else:
            raise NotImplementedError
    else:
        # Teacher forcing: use the provided target locations.
        target = torch.LongTensor(target_list).to(log_prob.device)
    target = target.view(-1)
    if predecessors is None:
        ll_step = log_prob[range(n_prog), target]
        ll_target = (ll_step.view(ll.shape) + ll) if ll is not None else ll_step
        update_embed = unpacked_out[target, range(n_prog)]
    return ll_target.view(-1, 1), target, update_embed, predecessors, sizes
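# Illustrative sketch (assumed, not from the original code): the pack / embed / repack
# pattern used in forward above, on a toy embedding layer. The function name, sizes,
# and data are made up for illustration.
def _demo_pack_embed():
    vocab_size, embed_dim = 10, 4
    tok_embed = torch.nn.Embedding(vocab_size, embed_dim)
    seqs = [torch.LongTensor([1, 2, 3]), torch.LongTensor([4, 5])]
    lengths = [s.size(0) for s in seqs]
    padded = pad_sequence(seqs, batch_first=False, padding_value=0)
    packed = pack_padded_sequence(padded, lengths=lengths,
                                  batch_first=False, enforce_sorted=False)
    # Embed only the packed (non-pad) tokens, then rebuild a PackedSequence so the
    # LSTM never touches padding positions.
    packed_embedded = PackedSequence(data=tok_embed(packed.data),
                                     batch_sizes=packed.batch_sizes,
                                     sorted_indices=packed.sorted_indices,
                                     unsorted_indices=packed.unsorted_indices)
    lstm = torch.nn.LSTM(embed_dim, 8)
    out, _ = lstm(packed_embedded)
    return pad_packed_sequence(out)[0]  # shape (max_len, batch, hidden) = (3, 2, 8)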
def _padded_io(self, int_seqs, max_len):
    int_seqs = [torch.LongTensor(x) for x in int_seqs]
    # torch's pad_sequence has no max_len argument: pad to the batch max first,
    # then right-pad with zeros up to max_len.
    padded = pad_sequence(int_seqs, batch_first=True, padding_value=0)
    if padded.size(1) < max_len:
        padded = F.pad(padded, (0, max_len - padded.size(1)), value=0)
    return padded
def _pad_io(self, int_seqs):
    int_seqs = [torch.LongTensor(x) for x in int_seqs]
    lengths = [v.size(0) for v in int_seqs]
    # pad_sequence defaults to batch_first=False, i.e. a (max_len, batch) layout.
    return pad_sequence(int_seqs), lengths
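# Illustrative sketch (made-up data): the two padding conventions used by the helpers
# above, batch-first padding to a fixed max_len vs. time-major padding plus lengths.
def _demo_padding_conventions():
    tensors = [torch.LongTensor([7, 8, 9]), torch.LongTensor([7])]
    batch_first = F.pad(pad_sequence(tensors, batch_first=True), (0, 5 - 3))     # (2, 5)
    time_major, lengths = pad_sequence(tensors), [t.size(0) for t in tensors]    # (3, 2), [3, 1]
    return batch_first, time_major, lengths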