def mask_sample_probs_with_length(action_probs, sample_length):
    """

    :param action_probs: probs with special shape like [batch, seq]
    :param sample_length:
    :return:
    """
    p1_log_probs, p2_log_probs, is_copy_log_probs, copy_output_log_probs, sample_output_log_probs, \
    sample_output_ids_log_probs = action_probs
    # record_is_nan(is_copy_log_probs, 'is_copy_log_probs in mask_sample')
    # record_is_nan(copy_output_log_probs, 'copy_output_log_probs in mask_sample')
    # record_is_nan(sample_output_log_probs, 'sample_output_log_probs in mask_sample')
    if not isinstance(sample_length, torch.Tensor):
        sample_length_tensor = torch.LongTensor(sample_length).to(
            is_copy_log_probs.device)
    else:
        sample_length_tensor = sample_length
    length_mask_float = create_sequence_length_mask(
        sample_length_tensor, max_len=is_copy_log_probs.shape[1]).float()

    is_copy_log_probs = is_copy_log_probs * length_mask_float
    sample_output_ids_log_probs = sample_output_ids_log_probs * length_mask_float

    sample_total_log_probs = torch.sum(
        is_copy_log_probs + sample_output_ids_log_probs,
        dim=-1) / sample_length_tensor.float()

    # record_is_nan(sample_total_log_probs, 'sample_total_log_probs in mask_sample')
    final_probs = p1_log_probs + p2_log_probs + sample_total_log_probs
    # record_is_nan(final_probs, 'final_probs in mask_sample')
    return final_probs
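
Every example on this page leans on create_sequence_length_mask. Below is a minimal sketch of what it likely does, inferred from the call sites here (a lengths tensor in, a [batch, max_len] mask out, True at valid positions); the project's real helper may differ, e.g. in how gpu_index places the result:

import torch

def create_sequence_length_mask(lengths, max_len=None, gpu_index=None):
    # lengths: LongTensor [batch]; returns a bool mask [batch, max_len],
    # True wherever position < length.
    if max_len is None:
        max_len = int(lengths.max().item())
    positions = torch.arange(max_len, device=lengths.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    if gpu_index is not None:  # assumption: gpu_index selects the target device
        mask = mask.cuda(gpu_index)
    return mask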
Example #2
 def _forward(self, input_seq, input_length, decoder_input, grammar_index,
              grammar_index_length, target_index):
     encoder_hidden, encoder_mask, encoder_output, is_copy = self._encoder_and_calculate_is_copy(
         input_length, input_seq)
     decoder_output, _, _ = self.decoder(
         inputs=self.embedding(decoder_input),
         encoder_hidden=encoder_hidden,
         encoder_outputs=encoder_output,
         encoder_mask=~encoder_mask,
         teacher_forcing_ratio=1)
     decoder_output = torch.stack(decoder_output, dim=1)
     max_length = decoder_output.shape[1]
     decoder_output = decoder_output.view(-1, decoder_output.shape[-1])
     to_select_index = []
     for i, index_list in enumerate(target_index):
         for t in index_list:
             to_select_index.append(max_length * i + t)
     decoder_output = torch.index_select(
         decoder_output, 0,
         torch.LongTensor(to_select_index).to(input_seq.device))
     grammar_mask = create_sequence_length_mask(
         grammar_index_length, max_len=grammar_index.shape[1])
     decoder_output = self.grammar_mask_output(decoder_output,
                                               grammar_index, grammar_mask)
     return is_copy.squeeze(-1), decoder_output
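
The to_select_index loop above flattens [batch, max_length] positions into row indices of the flattened decoder output. If target_index were a padded LongTensor [batch, k] instead of ragged Python lists, the same selection could be a single indexed lookup; a hypothetical vectorized sketch:

import torch

def select_target_positions(decoder_output, target_index):
    # decoder_output: [batch, max_length, hidden]
    # target_index: padded LongTensor [batch, k] (hypothetical; the original
    # code iterates over ragged Python lists instead)
    batch, max_length, hidden = decoder_output.shape
    offsets = torch.arange(batch, device=decoder_output.device) * max_length
    rows = (target_index + offsets.unsqueeze(1)).view(-1)  # flat row indices
    return decoder_output.reshape(-1, hidden).index_select(0, rows)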
Example #3
 def forward(self, adjacent_matrix, inp_seq, inp_seq_len):
     _, _, encoder_logit = self.graph_encoder(adjacent_matrix, inp_seq, inp_seq_len)
     if self.check_error_task:
         output_logit = self.output(encoder_logit).squeeze(-1)
     else:
         mask = create_sequence_length_mask(inp_seq_len).unsqueeze(dim=-2)
         encoder_logit, _ = self.attention(encoder_logit, encoder_logit, encoder_logit, mask)
         output_logit = self.output(encoder_logit)
     return [output_logit]
Example #4
 def _encoder_and_calculate_is_copy(self, input_length, input_seq):
     input_seq = self.embedding(input_seq)
     encoder_output, encoder_hidden_state = self.encoder(input_seq)
     is_copy = self.is_copy_output(encoder_output)
     encoder_mask = create_sequence_length_mask(input_length)
     encoder_hidden = [
         hid.view(self.num_layers, hid.shape[1], -1)
         for hid in encoder_hidden_state
     ]
     return encoder_hidden, encoder_mask, encoder_output, is_copy
 def do_input_position(self, inputs, input_lengths):
     # input embedding
     if self.position_encode_fn is not None:
         position_input = self.dropout(self.position_encode_fn(inputs))
     else:
         position_input = inputs
     input_mask = create_sequence_length_mask(
         input_lengths,
         torch.max(input_lengths).item(),
         gpu_index=gpu_index)
     return input_mask, position_input
 def forward(self,
             inputs,
             input_length,
             targets,
             target_length,
             do_sample=False):
     teacher_forcing_ratio = 0 if do_sample else 1
     result = self.seq_model(input_variable=inputs,
                             input_lengths=input_length,
                             target_variable=targets,
                             teacher_forcing_ratio=teacher_forcing_ratio)
     outputs = torch.stack(result[0], dim=1)
     decoder_mask = create_sequence_length_mask(target_length - 1)
     # outputs = outputs.data.masked_fill_(~decoder_mask.unsqueeze(dim=2), 0)
     return [outputs]
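
The decoder_mask above is computed but the masking itself is commented out. If re-enabled, a safer out-of-place form (masked_fill on a bool mask instead of mutating .data in place) would be:

outputs = outputs.masked_fill(~decoder_mask.unsqueeze(dim=2), 0)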
Example #7
 def forward(self,
             input_variable,
             input_lengths=None,
             target_variable=None,
             teacher_forcing_ratio=0):
     encoder_outputs, encoder_hidden = self.encoder(input_variable,
                                                    input_lengths)
     encoder_mask = create_sequence_length_mask(input_lengths)
     result = self.decoder(inputs=target_variable,
                           encoder_hidden=encoder_hidden,
                           encoder_outputs=encoder_outputs,
                           function=self.decode_function,
                           teacher_forcing_ratio=teacher_forcing_ratio,
                           encoder_mask=encoder_mask)
     return result
    def forward(self, inputs_embed, input_lengths, outputs, output_lengths):
        input_mask, position_input_embed = self.do_input_position(
            inputs_embed, input_lengths)

        encode_value = self.do_encode(input_mask, position_input_embed)

        # output embedding
        outputs_embed = self.dropout(self.output_embedding(outputs))
        position_output = self.dropout(self.position_encode_fn(outputs_embed))
        output_mask = create_sequence_length_mask(
            output_lengths,
            torch.max(output_lengths).item(),
            gpu_index=gpu_index)

        decoder_value = self.do_decode(position_output, output_mask,
                                       encode_value, input_mask)
        output_value = self.output(decoder_value)
        return output_value
Example #9
def parse_graph_output_from_mask_lm_output(input_seq,
                                           ac_seq,
                                           ac_seq_len,
                                           ignore_id=-1,
                                           check_error_task=False,
                                           is_effect=[]):
    is_effect_mask = torch.ByteTensor(is_effect).unsqueeze(-1).to(
        ac_seq_len.device)
    mask = create_sequence_length_mask(ac_seq_len)
    mask = mask * is_effect_mask
    if check_error_task:
        target = torch.eq(input_seq[:, 1:1 + ac_seq.size(1)], ac_seq).float()
        target = torch.where(mask, target,
                             torch.FloatTensor([ignore_id]).to(target.device))
    else:
        target = torch.where(mask, ac_seq,
                             torch.LongTensor([ignore_id]).to(ac_seq.device))
    return target
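
The ignore_id-filled target is presumably paired with a loss that skips those positions. An illustrative pairing (not taken from this project) for the token-prediction branch:

import torch.nn.functional as F

# logits: [batch, seq, vocab] (hypothetical name); target from the function above
loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                       target.view(-1),
                       ignore_index=-1)  # must match ignore_id

If the model already emits log probabilities, F.nll_loss takes the same ignore_index. Note also that recent PyTorch requires bool masks in torch.where; the ByteTensor usage above dates the snippet to older releases.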
Example #10
    def merge_sequence_accroding_chunk_list(self,
                                            line_tensor: torch.Tensor,
                                            chunk_list,
                                            dim=0):
        '''
        Merge multiple chunks back into one sequence according to the chunk list.
        :param line_tensor: [..., len(chunk_list), max(chunk_list), ...]
        :param chunk_list: lengths of the individual chunks, [len(chunk_list)]
        :param dim: the dimension holding the chunk axis
        :return: [..., sequence, ...]
        '''
        line_shape = list(line_tensor.shape)
        new_shape = line_shape[:dim] + [-1] + line_shape[dim + 2:]
        line_tensor = line_tensor.contiguous().view(*new_shape)

        chunk_index_mask = create_sequence_length_mask(chunk_list).view(-1)
        index_shape = [-1 for _ in new_shape]
        index_shape[dim] = chunk_index_mask.shape[0]
        chunk_index_mask = chunk_index_mask.view(*index_shape)
        line_tensor = torch.masked_select(
            line_tensor, mask=chunk_index_mask).view(*new_shape)
        return line_tensor
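
A worked example of the merge with hypothetical shapes, replicating the method's logic standalone: two chunks of sizes 2 and 3, padded to max(chunk_list)=3, collapse into one 5-step sequence:

import torch

line_tensor = torch.randn(2, 3, 4)           # [len(chunk_list), max(chunk_list), hidden]
chunk_list = torch.LongTensor([2, 3])

flat = line_tensor.contiguous().view(-1, 4)  # [6, 4]
positions = torch.arange(3).unsqueeze(0)     # stand-in for create_sequence_length_mask
mask = (positions < chunk_list.unsqueeze(1)).view(-1)  # [T, T, F, T, T, T]
merged = flat[mask]                          # [5, 4]: the padding slot is dropped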
Example #11
 def forward(self, x, adj, copy_length):
     if self.mask_ast_node_in_rnn:
         copy_length_mask = create_sequence_length_mask(
             copy_length, x.shape[1]).unsqueeze(-1)
         zero_fill = torch.zeros_like(x)
         for i in range(self.graph_itr):
             tx = torch.where(copy_length_mask, x, zero_fill)
             tx = tx + self.rnn[i](tx, adj, copy_length)
             x = torch.where(copy_length_mask, tx, x)
             x = self.dropout(x)
             # for _ in range(self.inner_graph_itr):
             x = x + self.graph(x, adj)
             if i < self.graph_itr - 1:
                 # pass
                 x = self.dropout(x)
     else:
         for i in range(self.graph_itr):
             x = x + self.rnn[i](x, adj, copy_length)
             x = self.dropout(x)
             x = x + self.graph(x, adj)
             if i < self.graph_itr - 1:
                 x = self.dropout(x)
     return x
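
In the masked branch above, torch.where gates the update: positions beyond copy_length (the AST nodes) keep their previous value while token positions receive the RNN residual. A tiny illustration of that gating with made-up shapes:

import torch

x = torch.randn(1, 4, 8)                 # [batch, seq, hidden]
copy_length = torch.LongTensor([2])      # only the first 2 positions are tokens
mask = (torch.arange(4).unsqueeze(0) < copy_length.unsqueeze(1)).unsqueeze(-1)
updated = torch.randn_like(x)            # stand-in for the rnn/residual output
x = torch.where(mask, updated, x)        # rows 2 and 3 stay untouched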
    def add_result(self,
                   output_ids,
                   model_output,
                   model_target,
                   model_input,
                   ignore_token=None,
                   batch_data=None):
        """

        :param log_probs: [batch, seq, vocab_size]
        :param target: [batch, seq]
        :param ignore_token:
        :param gpu_index:
        :param batch_data:
        :return:
        """
        input_seq_len = model_input[1]
        seq_mask = create_sequence_length_mask(input_seq_len)
        output_ids = output_ids.tolist()

        outputs = [
            self.convert_one_token_ids_to_code(o, self.vocab.id_to_word)
            for o in output_ids
        ]
Example #13
def combine_train(p_model,
                  s_model,
                  seq_model,
                  dataset,
                  batch_size,
                  loss_fn,
                  p_optimizer,
                  s_optimizer,
                  delay_reward_fn,
                  baseline_fn,
                  delay_loss_fn,
                  vocab,
                  train_type=None,
                  predict_type='first',
                  include_error_reward=-10000,
                  pretrain=False,
                  random_action=None):
    if train_type == 'p_model':
        change_model_state([p_model], [s_model, seq_model])
        policy_train = True
    elif train_type == 's_model':
        change_model_state([s_model, seq_model], [p_model])
        policy_train = False
    else:
        change_model_state([], [p_model, s_model, seq_model])
        policy_train = False

    begin_tensor = s_model.begin_token
    end_tensor = s_model.end_token
    gap_tensor = s_model.gap_token

    begin_len = 1
    begin_token = vocab.word_to_id(vocab.begin_tokens[0])
    end_token = vocab.word_to_id(vocab.end_tokens[0])
    gap_token = vocab.word_to_id(vocab.addition_tokens[0])
    step = 0
    select_count = torch.LongTensor([0])
    seq_count = torch.LongTensor([0])
    decoder_input_count = torch.LongTensor([0])
    total_seq_loss = torch.Tensor([0])
    total_p_loss = torch.Tensor([0])
    total_s_accuracy_top_k = {}
    for data in data_loader(dataset,
                            batch_size=batch_size,
                            is_shuffle=True,
                            drop_last=True):
        p_model.zero_grad()
        s_model.zero_grad()
        seq_model.zero_grad()

        error_tokens = transform_to_cuda(
            torch.LongTensor(PaddedList(data['error_tokens'])))
        error_length = transform_to_cuda(torch.LongTensor(
            data['error_length']))
        error_action_masks = transform_to_cuda(
            torch.ByteTensor(PaddedList(data['error_mask'], fill_value=0)))

        max_len = torch.max(error_length)
        error_token_masks = create_sequence_length_mask(
            error_length, max_len=max_len.data.item(), gpu_index=gpu_index)

        # add full-code context information to each token position using a BiRNN.
        context_input, context_hidden = s_model.do_context_rnn(error_tokens)
        # sample the actions through interaction between the policy model (p_model) and the structured model (s_model)
        if not pretrain:
            action_probs_records_list, action_records_list, output_records_list, hidden = create_policy_action_batch(
                p_model, s_model, context_input, policy_train=policy_train)
        else:
            action_probs_records_list, action_records_list, output_records_list, hidden = create_policy_action_batch(
                p_model,
                s_model,
                context_input,
                policy_train=True,
                random_action=[0.8, 0.2])
        action_probs_records = torch.stack(action_probs_records_list, dim=1)
        action_records = torch.stack(action_records_list, dim=1)
        output_records = torch.cat(output_records_list, dim=1)
        masked_action_records = action_records.data.masked_fill_(
            ~error_token_masks, 0)
        if pretrain:
            masked_action_records = error_action_masks.byte(
            ) | masked_action_records.byte()

        include_all_error = check_action_include_all_error(
            masked_action_records, error_action_masks)
        contain_all_error_count = torch.sum(include_all_error)

        tokens_tensor, token_length, part_ac_tokens_list, ac_token_length = combine_spilt_tokens_batch_with_tensor(
            output_records,
            data['ac_tokens'],
            masked_action_records,
            data['token_map'],
            gap_tensor,
            begin_tensor,
            end_tensor,
            gap_token,
            begin_token,
            end_token,
            gpu_index=gpu_index)

        if predict_type == 'start':
            decoder_input = [tokens[:-1] for tokens in part_ac_tokens_list]
            decoder_length = [len(inp) for inp in decoder_input]
            target_output = [tokens[1:] for tokens in part_ac_tokens_list]
        elif predict_type == 'first':
            decoder_input = [
                tokens[begin_len:-1] for tokens in part_ac_tokens_list
            ]
            decoder_length = [len(inp) for inp in decoder_input]
            target_output = [
                tokens[begin_len + 1:] for tokens in part_ac_tokens_list
            ]

        token_length_tensor = transform_to_cuda(torch.LongTensor(token_length))
        ac_token_tensor = transform_to_cuda(
            torch.LongTensor(PaddedList(decoder_input, fill_value=0)))
        ac_token_length_tensor = transform_to_cuda(
            torch.LongTensor(decoder_length))
        log_probs = seq_model.forward(tokens_tensor, token_length_tensor,
                                      ac_token_tensor, ac_token_length_tensor)

        target_output_tensor = transform_to_cuda(
            torch.LongTensor(
                PaddedList(target_output, fill_value=TARGET_PAD_TOKEN)))
        s_loss = loss_fn(log_probs.view(-1, vocab.vocabulary_size),
                         target_output_tensor.view(-1))

        remain_batch = torch.sum(masked_action_records, dim=1)
        add_batch = torch.eq(remain_batch, 0).long()
        remain_batch = remain_batch + add_batch
        total_batch = torch.sum(error_token_masks, dim=1)
        force_error_rewards = (
            ~include_all_error).float() * include_error_reward
        delay_reward = delay_reward_fn(log_probs, target_output_tensor,
                                       total_batch, remain_batch,
                                       force_error_rewards)
        delay_reward = torch.unsqueeze(delay_reward, dim=1).expand(-1, max_len)
        delay_reward = delay_reward * error_token_masks.float()

        if baseline_fn is not None:
            baseline_reward = baseline_fn(delay_reward, error_token_masks)
            total_reward = delay_reward - baseline_reward
        else:
            total_reward = delay_reward

        # force_error_rewards = torch.unsqueeze(~include_all_error, dim=1).float() * error_token_masks.float() * include_error_reward
        force_error_rewards = torch.unsqueeze(
            ~include_all_error, dim=1).float() * error_token_masks.float() * 0
        p_loss = delay_loss_fn(action_probs_records, total_reward,
                               error_token_masks, force_error_rewards)

        if math.isnan(p_loss):
            print('p_loss is nan')
            continue
        # iterate record variable
        step += 1
        one_decoder_input_count = torch.sum(ac_token_length_tensor)
        decoder_input_count += one_decoder_input_count.data.cpu()
        total_seq_loss += s_loss.cpu().data.item(
        ) * one_decoder_input_count.float().cpu()

        one_seq_count = torch.sum(error_length)
        seq_count += one_seq_count.cpu()
        total_p_loss += p_loss.cpu().data.item() * one_seq_count.float().cpu()

        s_accuracy_top_k = calculate_accuracy_of_code_completion(
            log_probs,
            target_output_tensor,
            ignore_token=TARGET_PAD_TOKEN,
            topk_range=(1, 5),
            gpu_index=gpu_index)
        for key, value in s_accuracy_top_k.items():
            total_s_accuracy_top_k[key] = s_accuracy_top_k.get(key, 0) + value

        select_count_each_batch = torch.sum(masked_action_records, dim=1)
        select_count = select_count + torch.sum(
            select_count_each_batch).data.cpu()

        print(
            'train_type: {} step {} sequence model loss: {}, policy model loss: {}, contain all error count: {}, select of each batch: {}, total of each batch: {}, total decoder_input_count: {}, topk: {}'
            .format(train_type, step, s_loss, p_loss, contain_all_error_count,
                    select_count_each_batch.data.tolist(),
                    error_length.data.tolist(),
                    one_decoder_input_count.data.item(), s_accuracy_top_k))
        sys.stdout.flush()
        sys.stderr.flush()

        if train_type != 'p_model':
            p_model.zero_grad()
        if train_type != 's_model':
            s_model.zero_grad()
            seq_model.zero_grad()

        if train_type == 'p_model':
            torch.nn.utils.clip_grad_norm_(p_model.parameters(), 0.5)
            p_loss.backward()
            p_optimizer.step()
        elif train_type == 's_model':
            torch.nn.utils.clip_grad_norm_(s_model.parameters(), 8)
            torch.nn.utils.clip_grad_norm_(seq_model.parameters(), 8)
            s_loss.backward()
            s_optimizer.step()

    for key, value in total_s_accuracy_top_k.items():
        total_s_accuracy_top_k[key] = total_s_accuracy_top_k.get(
            key, 0) / decoder_input_count.data.item()

    return (total_seq_loss / decoder_input_count.float()).data.item(), (
        total_p_loss / seq_count.float()).data.item(), (
            select_count.float() /
            seq_count.float()).data.item(), total_s_accuracy_top_k
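
combine_train is a REINFORCE-style loop: the delayed sequence reward minus a baseline weights the sampled action log probabilities. The baseline_fn argument is project-specific; a minimal stand-in, assuming a masked mean baseline over valid positions:

import torch

def mean_baseline_fn(delay_reward, token_masks):
    # hypothetical baseline, not the project's actual baseline_fn:
    # mean reward over valid positions, broadcast back to [batch, seq]
    valid = token_masks.float()
    mean = (delay_reward * valid).sum() / valid.sum().clamp(min=1)
    return mean * valid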
Example #14
    def forward(self,
                input_seq,
                input_line_length: torch.Tensor,
                input_line_token_length: torch.Tensor,
                input_length: torch.Tensor,
                adj_matrix,
                target_error_position,
                target_seq,
                target_length,
                do_sample=False):
        '''
        :param input_seq: input sequence, torch.Tensor, filled with token ids from the vocabulary
        :param input_line_length: the number of lines per input sequence, [batch]
        :param input_line_token_length: the token length of each line, [batch, line]
        :param input_length: input token length including ast nodes, [batch]
        :param adj_matrix: adjacency matrix of the code graph
        :param target_error_position: gold error line index per sample, used unless do_sample is set
        :param target_seq: target token sequence fed to the decoder
        :param target_length: length of each target sequence
        :param do_sample: if True, use the predicted error position and disable teacher forcing
        :return: (error_position, decoder_outputs)
        '''
        teacher_forcing_ratio = 0 if do_sample else 1
        embedded = self.embedding(input_seq)
        embedded = self.input_dropout(embedded)

        if self.graph_embedding is not None:
            copy_length = torch.sum(input_line_token_length, dim=-1)
            graph_embedded = self.graph_encoder.forward(
                adjacent_matrix=adj_matrix,
                copy_length=copy_length,
                input_seq=embedded)
        else:
            graph_embedded = embedded

        batch_token_sequence, batch_line_sequence, batch_line_hidden = self.line_encoder.forward(
            graph_embedded, input_line_token_length)
        # code_state: [num_layers * bi_num, batch_size, hidden_size]
        line_output_state, code_state = self.code_encoder(
            batch_line_sequence, input_line_length)

        line_mask = create_sequence_length_mask(input_line_length)
        error_position = self.position_pointer(line_output_state, line_mask)
        if do_sample:
            pos = torch.max(error_position, dim=-1)[1]
        else:
            pos = target_error_position

        error_line_hidden_list = [
            batch_line_hidden[:, i, p] for i, p in enumerate(pos)
        ]
        error_line_hidden = torch.stack(tuple(error_line_hidden_list), dim=1)
        # error_line_hidden = batch_line_hidden[:, :, pos]
        combine_line_state = self.encoder_linear(
            torch.cat((error_line_hidden, code_state), dim=-1))

        encoder_mask = create_sequence_length_mask(
            input_length, max_len=graph_embedded.shape[1])
        # combine_line_state: [num_layers, batch, hidden_size]
        # batch_token_sequence: [batch, tokens, hidden_size]
        decoder_outputs, decoder_hidden, _ = self.decoder(
            inputs=target_seq,
            encoder_hidden=combine_line_state,
            encoder_outputs=graph_embedded,
            function=F.log_softmax,
            teacher_forcing_ratio=teacher_forcing_ratio,
            encoder_mask=~encoder_mask)
        decoder_outputs = torch.stack(decoder_outputs, dim=1)
        return error_position, decoder_outputs
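
The per-sample loop that extracts the error line's hidden state can also be written as one gather. A sketch, assuming batch_line_hidden is [num_layers, batch, lines, hidden] and pos is a LongTensor [batch] (both inferred from the indexing above):

idx = pos.view(1, -1, 1, 1).expand(batch_line_hidden.size(0), -1, 1,
                                   batch_line_hidden.size(-1))
error_line_hidden = batch_line_hidden.gather(2, idx).squeeze(2)  # [num_layers, batch, hidden]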