def mask_sample_probs_with_length(action_probs, sample_length):
    """Fold per-step sampled log-probabilities into one score per sample.

    :param action_probs: 6-tuple of log-prob tensors
        ``(p1, p2, is_copy, copy_output, sample_output, sample_output_ids)``;
        the per-step entries have a shape like ``[batch, seq]``. Only
        ``p1``, ``p2``, ``is_copy`` and ``sample_output_ids`` contribute.
    :param sample_length: per-sample lengths, either a ``torch.Tensor`` or
        anything ``torch.LongTensor`` accepts (moved to ``is_copy``'s device).
    :return: ``p1 + p2 + sum_over_valid_steps(is_copy + sample_output_ids)
        / sample_length``.
    """
    (p1_log_probs, p2_log_probs, is_copy_log_probs, _copy_output_log_probs,
     _sample_output_log_probs, sample_output_ids_log_probs) = action_probs

    if isinstance(sample_length, torch.Tensor):
        lengths = sample_length
    else:
        lengths = torch.LongTensor(sample_length).to(is_copy_log_probs.device)

    step_mask = create_sequence_length_mask(
        lengths, max_len=is_copy_log_probs.shape[1]).float()

    # Zero out padded steps before summing along the sequence axis.
    per_step = is_copy_log_probs * step_mask + sample_output_ids_log_probs * step_mask
    # NOTE(review): a zero entry in sample_length makes this 0/0 = NaN.
    mean_log_probs = per_step.sum(dim=-1) / lengths.float()
    return p1_log_probs + p2_log_probs + mean_log_probs
def _forward(self, input_seq, input_length, decoder_input, grammar_index,
             grammar_index_length, target_index):
    """Teacher-forced decode, keeping only the decoder steps listed in target_index.

    :param input_seq: encoder input ids.
    :param input_length: encoder input lengths (used to build the encoder mask).
    :param decoder_input: decoder input ids, embedded with ``self.embedding``.
    :param grammar_index: grammar candidate indices; ``shape[1]`` bounds the mask.
    :param grammar_index_length: number of valid grammar candidates per row.
    :param target_index: for each batch row, the decoder step indices to keep.
    :return: ``(is_copy.squeeze(-1), grammar-masked selected decoder outputs)``.
    """
    encoder_hidden, encoder_mask, encoder_output, is_copy = \
        self._encoder_and_calculate_is_copy(input_length, input_seq)

    # The decoder receives the inverted encoder mask; teacher forcing is always on.
    step_outputs, _, _ = self.decoder(
        inputs=self.embedding(decoder_input),
        encoder_hidden=encoder_hidden,
        encoder_outputs=encoder_output,
        encoder_mask=~encoder_mask,
        teacher_forcing_ratio=1)
    stacked = torch.stack(step_outputs, dim=1)
    steps_per_row = stacked.shape[1]
    flat = stacked.view(-1, stacked.shape[-1])

    # After flattening rows and steps together, step t of batch row i sits at
    # row index i * steps_per_row + t (assumes stacked is [batch, steps, hidden]
    # — TODO confirm).
    flat_rows = [
        steps_per_row * row + step
        for row, wanted_steps in enumerate(target_index)
        for step in wanted_steps
    ]
    selected = torch.index_select(
        flat, 0, torch.LongTensor(flat_rows).to(input_seq.device))

    grammar_mask = create_sequence_length_mask(
        grammar_index_length, max_len=grammar_index.shape[1])
    masked_output = self.grammar_mask_output(selected, grammar_index, grammar_mask)
    return is_copy.squeeze(-1), masked_output
def forward(self, adjacent_matrix, inp_seq, inp_seq_len):
    """Encode with the graph encoder and project through ``self.output``.

    With ``self.check_error_task`` set, the projection's last axis is squeezed.
    Otherwise the encoder output first passes through masked self-attention,
    with the length mask expanded on axis -2.

    :return: single-element list holding the output logits.
    """
    _, _, encoded = self.graph_encoder(adjacent_matrix, inp_seq, inp_seq_len)
    if not self.check_error_task:
        pad_mask = create_sequence_length_mask(inp_seq_len).unsqueeze(dim=-2)
        attended, _ = self.attention(encoded, encoded, encoded, pad_mask)
        return [self.output(attended)]
    return [self.output(encoded).squeeze(-1)]
def _encoder_and_calculate_is_copy(self, input_length, input_seq):
    """Embed and encode input_seq, then score every encoder position with is_copy_output.

    :param input_length: lengths used to build the encoder mask.
    :param input_seq: input token ids.
    :return: ``(encoder_hidden, encoder_mask, encoder_output, is_copy)`` where
        encoder_hidden is a list with each final encoder state reshaped to
        ``(self.num_layers, state.shape[1], -1)``.
    """
    embedded = self.embedding(input_seq)
    encoder_output, final_states = self.encoder(embedded)
    is_copy = self.is_copy_output(encoder_output)
    encoder_mask = create_sequence_length_mask(input_length)
    # NOTE(review): presumably folds the direction axis of a bidirectional
    # encoder into the hidden size — confirm against the encoder definition.
    encoder_hidden = [
        state.view(self.num_layers, state.shape[1], -1)
        for state in final_states
    ]
    return encoder_hidden, encoder_mask, encoder_output, is_copy
def do_input_position(self, inputs, input_lengths):
    """Apply the optional positional encoding to inputs and build their length mask.

    :param inputs: input embeddings.
    :param input_lengths: tensor of per-row lengths; its max sets the mask width.
    :return: ``(input_mask, position_input)``; position_input is
        ``dropout(position_encode_fn(inputs))`` when an encoder fn is set,
        otherwise ``inputs`` unchanged.
    """
    position_input = inputs
    if self.position_encode_fn is not None:
        position_input = self.dropout(self.position_encode_fn(inputs))
    longest = torch.max(input_lengths).item()
    # gpu_index is looked up from module scope.
    input_mask = create_sequence_length_mask(
        input_lengths, longest, gpu_index=gpu_index)
    return input_mask, position_input
def forward(self, inputs, input_length, targets, target_length, do_sample=False):
    """Run the wrapped seq2seq model and stack its per-step outputs.

    :param inputs: encoder input ids.
    :param input_length: encoder input lengths.
    :param targets: decoder target ids (used for teacher forcing).
    :param target_length: kept for interface compatibility; not used.
        The previous version built a decoder mask from it and then discarded it.
    :param do_sample: when True, decode without teacher forcing (ratio 0);
        otherwise full teacher forcing (ratio 1).
    :return: single-element list holding the step outputs stacked on dim 1.
    """
    teacher_forcing_ratio = 0 if do_sample else 1
    result = self.seq_model(input_variable=inputs,
                            input_lengths=input_length,
                            target_variable=targets,
                            teacher_forcing_ratio=teacher_forcing_ratio)
    # result[0] is the list of per-step decoder outputs.
    outputs = torch.stack(result[0], dim=1)
    return [outputs]
def forward(self, input_variable, input_lengths=None, target_variable=None, teacher_forcing_ratio=0):
    """Encode input_variable and decode against the encoder outputs.

    :param input_variable: encoder inputs.
    :param input_lengths: encoder input lengths; also used to build the encoder mask.
    :param target_variable: decoder inputs.
    :param teacher_forcing_ratio: forwarded to the decoder.
    :return: whatever ``self.decoder`` returns, unchanged.
    """
    encoder_outputs, encoder_hidden = self.encoder(input_variable, input_lengths)
    # NOTE(review): the length mask is passed as-is (not inverted); confirm
    # this decoder expects that polarity.
    decoder_kwargs = dict(
        inputs=target_variable,
        encoder_hidden=encoder_hidden,
        encoder_outputs=encoder_outputs,
        function=self.decode_function,
        teacher_forcing_ratio=teacher_forcing_ratio,
        encoder_mask=create_sequence_length_mask(input_lengths),
    )
    return self.decoder(**decoder_kwargs)
def forward(self, inputs_embed, input_lengths, outputs, output_lengths):
    """Encode already-embedded inputs, then decode the embedded output ids.

    :param inputs_embed: input embeddings (passed to ``do_input_position``).
    :param input_lengths: tensor of input lengths.
    :param outputs: output token ids, embedded with ``self.output_embedding``.
    :param output_lengths: tensor of output lengths; its max sets the mask width.
    :return: ``self.output`` applied to the decoder result.
    """
    input_mask, position_input_embed = self.do_input_position(
        inputs_embed, input_lengths)
    encode_value = self.do_encode(input_mask, position_input_embed)

    # Output embedding. The positional encoding is optional, exactly as in
    # do_input_position. Previously a None position_encode_fn crashed here.
    outputs_embed = self.dropout(self.output_embedding(outputs))
    if self.position_encode_fn is not None:
        position_output = self.dropout(self.position_encode_fn(outputs_embed))
    else:
        position_output = outputs_embed

    # gpu_index is looked up from module scope.
    output_mask = create_sequence_length_mask(
        output_lengths, torch.max(output_lengths).item(), gpu_index=gpu_index)
    decoder_value = self.do_decode(position_output, output_mask, encode_value,
                                   input_mask)
    return self.output(decoder_value)
def parse_graph_output_from_mask_lm_output(input_seq, ac_seq, ac_seq_len,
                                           ignore_id=-1,
                                           check_error_task=False,
                                           is_effect=None):
    """Build training targets from ac_seq, filling masked-out positions with ignore_id.

    :param input_seq: input ids; positions ``1 .. 1 + ac_seq.size(1)`` are
        compared against ac_seq when ``check_error_task`` is True.
    :param ac_seq: target token ids, indexed as ``[batch, seq]``.
    :param ac_seq_len: valid length of each row of ac_seq.
    :param ignore_id: fill value for positions past the length or in rows
        whose ``is_effect`` flag is false.
    :param check_error_task: if True, the target is a float per position that
        is 1.0 where ``input_seq[:, 1 + j] == ac_seq[:, j]`` and 0.0 otherwise.
        If False, the target is ac_seq itself.
    :param is_effect: optional per-row flags (sequence or tensor). A falsy flag
        replaces the entire row with ignore_id. ``None`` (the default) means
        every row is effective. The old ``[]`` default was a shared mutable
        value whose empty mask could not broadcast against the batch.
    :return: target tensor shaped like ac_seq.
    """
    mask = create_sequence_length_mask(ac_seq_len)
    if is_effect is not None:
        # Build the row mask in the length mask's own dtype so the combined
        # mask keeps that dtype and stays a valid torch.where condition.
        # Multiplying by a ByteTensor used to turn a bool mask into uint8.
        is_effect_mask = torch.as_tensor(
            is_effect, dtype=mask.dtype, device=mask.device).unsqueeze(-1)
        mask = mask & is_effect_mask

    if check_error_task:
        target = torch.eq(input_seq[:, 1:1 + ac_seq.size(1)], ac_seq).float()
        fill = torch.FloatTensor([ignore_id]).to(target.device)
        return torch.where(mask, target, fill)
    fill = torch.LongTensor([ignore_id]).to(ac_seq.device)
    return torch.where(mask, ac_seq, fill)
def merge_sequence_accroding_chunk_list(self, line_tensor: torch.Tensor, chunk_list, dim=0):
    '''
    Merge multiple padded chunks into one sequence, according to chunk_list.

    Axes ``dim`` (chunk index) and ``dim + 1`` (position within a chunk) are
    flattened together. Positions beyond each chunk's size are then dropped.

    :param line_tensor: [..., len(chunk_list), max(chunk_list), ...]
    :param chunk_list: a list of each chunk size
    :param dim: axis holding the chunk index
    :return: [..., sequence, ...]
    '''
    line_shape = list(line_tensor.shape)
    new_shape = line_shape[:dim] + [-1] + line_shape[dim + 2:]
    line_tensor = line_tensor.contiguous().view(*new_shape)

    # NOTE(review): assumes the mask width equals line_shape[dim + 1],
    # i.e. create_sequence_length_mask pads to max(chunk_list) — confirm.
    chunk_index_mask = create_sequence_length_mask(chunk_list).view(-1)
    # Use singleton sizes on every axis except `dim` so masked_select
    # broadcasts the mask over the other axes. The previous code used -1
    # there, and view() rejects more than one inferred axis, so any merged
    # tensor with 3 or more dims raised a RuntimeError.
    index_shape = [1] * len(new_shape)
    index_shape[dim] = chunk_index_mask.shape[0]
    chunk_index_mask = chunk_index_mask.view(*index_shape)
    line_tensor = torch.masked_select(
        line_tensor, mask=chunk_index_mask).view(*new_shape)
    return line_tensor
def forward(self, x, adj, copy_length):
    """Run ``self.graph_itr`` rounds of residual RNN and residual graph updates on x.

    Each round works as follows:
      1. Apply the residual RNN update ``self.rnn[i]``.
      2. Apply dropout.
      3. Apply the residual graph update ``self.graph``.
      4. Apply dropout, except after the final round.

    When ``self.mask_ast_node_in_rnn`` is set, the RNN works only on
    positions inside ``copy_length``. Positions outside it are zeroed in the
    RNN input and keep their previous value afterwards.
    """
    masked = self.mask_ast_node_in_rnn
    if masked:
        keep = create_sequence_length_mask(copy_length, x.shape[1]).unsqueeze(-1)
        zeros = torch.zeros_like(x)

    last_round = self.graph_itr - 1
    for round_idx in range(self.graph_itr):
        rnn_layer = self.rnn[round_idx]
        if masked:
            rnn_in = torch.where(keep, x, zeros)
            rnn_out = rnn_in + rnn_layer(rnn_in, adj, copy_length)
            x = torch.where(keep, rnn_out, x)
        else:
            x = x + rnn_layer(x, adj, copy_length)
        x = self.dropout(x)
        x = x + self.graph(x, adj)
        if round_idx < last_round:
            x = self.dropout(x)
    return x
def add_result(self, output_ids, model_output, model_target, model_input, ignore_token=None, batch_data=None):
    """
    Convert predicted token ids back to code for result recording.

    NOTE(review): as visible here, the body computes ``seq_mask`` and
    ``outputs`` but uses neither, and implicitly returns None. The method may
    be truncated in this view; confirm against the full file.

    :param output_ids: predicted token ids; must support ``.tolist()``
        (presumably a [batch, seq] tensor — TODO confirm)
    :param model_output: not used in the visible body
    :param model_target: not used in the visible body
    :param model_input: indexable; ``model_input[1]`` is presumably the input
        sequence lengths (it feeds the length mask)
    :param ignore_token: not used in the visible body
    :param batch_data: not used in the visible body
    :return: None, as far as the visible body shows
    """
    input_seq_len = model_input[1]
    # Length mask over the input sequences (not used further in the visible body).
    seq_mask = create_sequence_length_mask(input_seq_len)
    output_ids = output_ids.tolist()
    # Decode each row of ids to code using the vocabulary's id_to_word mapping.
    outputs = [
        self.convert_one_token_ids_to_code(o, self.vocab.id_to_word)
        for o in output_ids
    ]
def combine_train(p_model, s_model, seq_model, dataset, batch_size, loss_fn,
                  p_optimizer, s_optimizer, delay_reward_fn, baseline_fn,
                  delay_loss_fn, vocab, train_type=None, predict_type='first',
                  include_error_reward=-10000, pretrain=False,
                  random_action=None):
    """Run one training epoch of the combined policy / structured / sequence models.

    :param p_model: policy model; updated with p_optimizer when
        ``train_type == 'p_model'``.
    :param s_model: structured model; updated together with seq_model via
        s_optimizer when ``train_type == 's_model'``.
    :param seq_model: sequence model producing ``log_probs`` over the vocabulary.
    :param dataset: passed to ``data_loader`` (shuffled, last partial batch dropped).
    :param batch_size: batch size for ``data_loader``.
    :param loss_fn: applied to flattened log_probs / targets to give the sequence loss.
    :param p_optimizer: optimizer stepped for the policy model.
    :param s_optimizer: optimizer stepped for the structured + sequence models.
    :param delay_reward_fn: computes the per-sample delayed reward.
    :param baseline_fn: optional; its result is subtracted from the reward.
    :param delay_loss_fn: computes the policy loss from action probs and rewards.
    :param vocab: supplies begin/end/gap token ids and ``vocabulary_size``.
    :param train_type: 'p_model', 's_model', or anything else (no parameter update).
    :param predict_type: 'start' or 'first'. Selects how the decoder
        input/target are sliced out of each part.
    :param include_error_reward: reward applied (per sample) when the selected
        actions do not include every error position.
    :param pretrain: if True, actions are sampled with ``random_action=[0.8, 0.2]``
        and the error positions are always added to the selection.
    :param random_action: currently unused.
    :return: ``(sequence loss per decoder input token, policy loss per error
        token, selected tokens per error token, top-k accuracy dict normalised
        by decoder input token count)``.
    :raises ValueError: if ``predict_type`` is neither 'start' nor 'first'.
    """
    # Previously an unknown predict_type surfaced as an UnboundLocalError.
    if predict_type not in ('start', 'first'):
        raise ValueError(
            'predict_type must be "start" or "first", got {!r}'.format(predict_type))

    if train_type == 'p_model':
        change_model_state([p_model], [s_model, seq_model])
        policy_train = True
    elif train_type == 's_model':
        change_model_state([s_model, seq_model], [p_model])
        policy_train = False
    else:
        change_model_state([], [p_model, s_model, seq_model])
        policy_train = False

    begin_tensor = s_model.begin_token
    end_tensor = s_model.end_token
    gap_tensor = s_model.gap_token

    begin_len = 1
    begin_token = vocab.word_to_id(vocab.begin_tokens[0])
    end_token = vocab.word_to_id(vocab.end_tokens[0])
    gap_token = vocab.word_to_id(vocab.addition_tokens[0])

    step = 0
    select_count = torch.LongTensor([0])
    seq_count = torch.LongTensor([0])
    decoder_input_count = torch.LongTensor([0])
    total_seq_loss = torch.Tensor([0])
    total_p_loss = torch.Tensor([0])
    total_s_accuracy_top_k = {}
    for data in data_loader(dataset, batch_size=batch_size, is_shuffle=True,
                            drop_last=True):
        p_model.zero_grad()
        s_model.zero_grad()
        seq_model.zero_grad()

        error_tokens = transform_to_cuda(
            torch.LongTensor(PaddedList(data['error_tokens'])))
        error_length = transform_to_cuda(torch.LongTensor(data['error_length']))
        error_action_masks = transform_to_cuda(
            torch.ByteTensor(PaddedList(data['error_mask'], fill_value=0)))

        max_len = torch.max(error_length)
        error_token_masks = create_sequence_length_mask(
            error_length, max_len=max_len.data.item(), gpu_index=gpu_index)

        # Add full code-context information to each token position using a BiRNN.
        context_input, _ = s_model.do_context_rnn(error_tokens)
        # Sample actions through the interaction of the policy model (p_model)
        # and the structured model (s_model).
        if not pretrain:
            action_probs_records_list, action_records_list, output_records_list, _ = \
                create_policy_action_batch(p_model, s_model, context_input,
                                           policy_train=policy_train)
        else:
            action_probs_records_list, action_records_list, output_records_list, _ = \
                create_policy_action_batch(p_model, s_model, context_input,
                                           policy_train=True,
                                           random_action=[0.8, 0.2])
        action_probs_records = torch.stack(action_probs_records_list, dim=1)
        action_records = torch.stack(action_records_list, dim=1)
        output_records = torch.cat(output_records_list, dim=1)

        masked_action_records = action_records.data.masked_fill_(
            ~error_token_masks, 0)
        if pretrain:
            masked_action_records = error_action_masks.byte() | masked_action_records.byte()

        include_all_error = check_action_include_all_error(
            masked_action_records, error_action_masks)
        contain_all_error_count = torch.sum(include_all_error)

        tokens_tensor, token_length, part_ac_tokens_list, _ = \
            combine_spilt_tokens_batch_with_tensor(
                output_records, data['ac_tokens'], masked_action_records,
                data['token_map'], gap_tensor, begin_tensor, end_tensor,
                gap_token, begin_token, end_token, gpu_index=gpu_index)

        if predict_type == 'start':
            decoder_input = [tokens[:-1] for tokens in part_ac_tokens_list]
            target_output = [tokens[1:] for tokens in part_ac_tokens_list]
        else:  # 'first' (validated above)
            decoder_input = [tokens[begin_len:-1] for tokens in part_ac_tokens_list]
            target_output = [tokens[begin_len + 1:] for tokens in part_ac_tokens_list]
        decoder_length = [len(inp) for inp in decoder_input]

        token_length_tensor = transform_to_cuda(torch.LongTensor(token_length))
        ac_token_tensor = transform_to_cuda(
            torch.LongTensor(PaddedList(decoder_input, fill_value=0)))
        ac_token_length_tensor = transform_to_cuda(torch.LongTensor(decoder_length))
        log_probs = seq_model.forward(tokens_tensor, token_length_tensor,
                                      ac_token_tensor, ac_token_length_tensor)

        target_output_tensor = transform_to_cuda(
            torch.LongTensor(PaddedList(target_output, fill_value=TARGET_PAD_TOKEN)))
        s_loss = loss_fn(log_probs.view(-1, vocab.vocabulary_size),
                         target_output_tensor.view(-1))

        # Number of selected tokens per sample (at least 1, to avoid dividing by zero).
        remain_batch = torch.sum(masked_action_records, dim=1)
        add_batch = torch.eq(remain_batch, 0).long()
        remain_batch = remain_batch + add_batch
        total_batch = torch.sum(error_token_masks, dim=1)
        force_error_rewards = (~include_all_error).float() * include_error_reward
        delay_reward = delay_reward_fn(log_probs, target_output_tensor,
                                       total_batch, remain_batch,
                                       force_error_rewards)
        delay_reward = torch.unsqueeze(delay_reward, dim=1).expand(-1, max_len)
        delay_reward = delay_reward * error_token_masks.float()

        if baseline_fn is not None:
            baseline_reward = baseline_fn(delay_reward, error_token_masks)
            total_reward = delay_reward - baseline_reward
        else:
            total_reward = delay_reward

        # The per-token force-error reward is deliberately disabled (scaled by
        # 0); the per-sample version above already feeds delay_reward_fn.
        force_error_rewards = torch.unsqueeze(
            ~include_all_error, dim=1).float() * error_token_masks.float() * 0
        p_loss = delay_loss_fn(action_probs_records, total_reward,
                               error_token_masks, force_error_rewards)

        if math.isnan(p_loss):
            print('p_loss is nan')
            continue

        # Update the running statistics.
        step += 1
        one_decoder_input_count = torch.sum(ac_token_length_tensor)
        decoder_input_count += one_decoder_input_count.data.cpu()
        total_seq_loss += s_loss.cpu().data.item() * one_decoder_input_count.float().cpu()

        one_seq_count = torch.sum(error_length)
        seq_count += one_seq_count.cpu()
        total_p_loss += p_loss.cpu().data.item() * one_seq_count.float().cpu()

        s_accuracy_top_k = calculate_accuracy_of_code_completion(
            log_probs, target_output_tensor, ignore_token=TARGET_PAD_TOKEN,
            topk_range=(1, 5), gpu_index=gpu_index)
        for key, value in s_accuracy_top_k.items():
            # Accumulate into the running total. The old code read the
            # per-batch dict here, so each batch overwrote the total.
            total_s_accuracy_top_k[key] = total_s_accuracy_top_k.get(key, 0) + value

        select_count_each_batch = torch.sum(masked_action_records, dim=1)
        select_count = select_count + torch.sum(select_count_each_batch).data.cpu()

        print(
            'train_type: {} step {} sequence model loss: {}, policy model loss: {}, contain all error count: {}, select of each batch: {}, total of each batch: {}, total decoder_input_cout: {}, topk: {}, '
            .format(train_type, step, s_loss, p_loss, contain_all_error_count,
                    select_count_each_batch.data.tolist(),
                    error_length.data.tolist(),
                    one_decoder_input_count.data.item(), s_accuracy_top_k))
        sys.stdout.flush()
        sys.stderr.flush()

        if train_type != 'p_model':
            p_model.zero_grad()
        if train_type != 's_model':
            s_model.zero_grad()
            seq_model.zero_grad()

        # Clip only after backward(). Clipping before it touched the already
        # zeroed gradients and never limited the actual update.
        if train_type == 'p_model':
            p_loss.backward()
            torch.nn.utils.clip_grad_norm_(p_model.parameters(), 0.5)
            p_optimizer.step()
        elif train_type == 's_model':
            s_loss.backward()
            torch.nn.utils.clip_grad_norm_(s_model.parameters(), 8)
            torch.nn.utils.clip_grad_norm_(seq_model.parameters(), 8)
            s_optimizer.step()

    total_decoder_inputs = decoder_input_count.data.item()
    total_s_accuracy_top_k = {
        key: value / total_decoder_inputs
        for key, value in total_s_accuracy_top_k.items()
    }
    return (total_seq_loss / decoder_input_count.float()).data.item(), (
        total_p_loss / seq_count.float()).data.item(), (
        select_count.float() / seq_count.float()).data.item(), total_s_accuracy_top_k
def forward(self, input_seq, input_line_length: torch.Tensor, input_line_token_length: torch.Tensor, input_length: torch.Tensor, adj_matrix, target_error_position, target_seq, target_length, do_sample=False):
    '''
    Locate the error line, then decode a sequence conditioned on it.

    :param input_seq: input sequence, torch.Tensor, full with token id in dictionary
    :param input_line_length: the number of lines per input sequence, [batch]
    :param input_line_token_length: the length of each lines, [batch, line]
    :param input_length: input token length include ast node, [batch]
    :param adj_matrix: adjacency matrix passed to the graph encoder
    :param target_error_position: gold error line index per sample (used when not sampling)
    :param target_seq: decoder inputs
    :param target_length: not used by this method
    :param do_sample: if True, use the predicted error line and decode
        without teacher forcing; otherwise use the gold line and full teacher forcing
    :return: ``(error_position, decoder_outputs)`` where decoder_outputs are
        the per-step decoder outputs stacked on dim 1
    '''
    embedded = self.input_dropout(self.embedding(input_seq))

    if self.graph_embedding is None:
        graph_embedded = embedded
    else:
        graph_embedded = self.graph_encoder.forward(
            adjacent_matrix=adj_matrix,
            copy_length=torch.sum(input_line_token_length, dim=-1),
            input_seq=embedded)

    _token_sequence, batch_line_sequence, batch_line_hidden = \
        self.line_encoder.forward(graph_embedded, input_line_token_length)
    # code_state: [num_layers* bi_num, batch_size, hidden_size]
    line_output_state, code_state = self.code_encoder(
        batch_line_sequence, input_line_length)
    error_position = self.position_pointer(
        line_output_state, create_sequence_length_mask(input_line_length))

    # Choose the error line: the model's own best guess when sampling,
    # otherwise the gold position.
    pos = torch.max(error_position, dim=-1)[1] if do_sample else target_error_position
    error_line_hidden = torch.stack(
        tuple(batch_line_hidden[:, row, line] for row, line in enumerate(pos)),
        dim=1)

    combine_line_state = self.encoder_linear(
        torch.cat((error_line_hidden, code_state), dim=-1))
    encoder_mask = create_sequence_length_mask(
        input_length, max_len=graph_embedded.shape[1])

    # combine_line_state: [num_layers, batch, hidden_size]
    decoder_outputs, _decoder_hidden, _ = self.decoder(
        inputs=target_seq,
        encoder_hidden=combine_line_state,
        encoder_outputs=graph_embedded,
        function=F.log_softmax,
        teacher_forcing_ratio=0 if do_sample else 1,
        encoder_mask=~encoder_mask)
    return error_position, torch.stack(decoder_outputs, dim=1)