def get_loss(self, model, batch_data, pred_dict, train=True, flag=0):
    n_support_train = self.args.n_shot_train
    n_support_test = self.args.n_shot_test
    n_query = self.args.n_query

    if not train:
        # Support loss: tile the support labels once per query task so they
        # line up with the flattened support logits.
        losses_adapt = self.criterion(
            pred_dict['s_logits'].reshape(
                (2 * n_support_test * n_query, 2)),
            paddle.expand(batch_data['s_label'],
                          [n_query, n_support_test * 2]).reshape(
                              (1, 2 * n_support_test * n_query)).squeeze(0))
    else:
        if flag:
            losses_adapt = self.criterion(
                pred_dict['s_logits'].reshape(
                    (2 * n_support_train * n_query, 2)),
                paddle.expand(batch_data['s_label'],
                              [n_query, n_support_train * 2]).reshape(
                                  (1, 2 * n_support_train * n_query)
                              ).squeeze(0))
        else:
            losses_adapt = self.criterion(pred_dict['q_logits'],
                                          batch_data['q_label'])

    if paddle.isnan(losses_adapt).any() or paddle.isinf(losses_adapt).any():
        # Zero out a NaN/Inf loss so one bad batch cannot derail training.
        print('!!!!!!!!!!!!!!!!!!! Nan value for supervised CE loss',
              losses_adapt)
        print(pred_dict['s_logits'])
        losses_adapt = paddle.zeros_like(losses_adapt)

    if self.args.reg_adj > 0:
        # Regularize the predicted adjacency toward the label-induced edges.
        n_support = batch_data['s_label'].shape[0]
        adj = pred_dict['adj'][-1]
        if train:
            if flag:
                s_label = paddle.expand(
                    batch_data['s_label'],
                    [n_query, batch_data['s_label'].shape[0]])
                n_d = n_query * n_support
                label_edge = model.layers.label2edge(s_label).reshape(
                    (n_d, -1))
                pred_edge = adj[:, :, :-1, :-1].reshape((n_d, -1))
            else:
                s_label = paddle.expand(
                    batch_data['s_label'],
                    [n_query, batch_data['s_label'].shape[0]])
                q_label = batch_data['q_label'].unsqueeze(1)
                total_label = paddle.concat([s_label, q_label], 1)
                label_edge = model.layers.label2edge(
                    total_label)[:, :, -1, :-1]
                pred_edge = adj[:, :, -1, :-1]
        else:
            s_label = batch_data['s_label'].unsqueeze(0)
            n_d = n_support * self.args.rel_edge
            label_edge = model.layers.label2edge(s_label).reshape(
                (n_d, -1))
            pred_edge = adj[:, :, :n_support, :n_support].mean(0).reshape(
                (n_d, -1))
        adj_loss_val = F.mse_loss(pred_edge, label_edge)
        if paddle.isnan(adj_loss_val).any() or \
                paddle.isinf(adj_loss_val).any():
            print('!!!!!!!!!!!!!!!!!!! Nan value for adjacency loss',
                  adj_loss_val)
            adj_loss_val = paddle.zeros_like(adj_loss_val)
        losses_adapt += self.args.reg_adj * adj_loss_val
    return losses_adapt

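# Hedged usage sketch (illustrative, not from the original repo): how the
# expand/reshape above tiles the 2 * n_support support labels once per query
# task so they align with the flattened support logits. Values are made up.
import paddle

def _tile_support_labels_demo():
    n_query, n_support = 3, 2
    s_label = paddle.to_tensor([0, 1, 0, 1])  # shape [2 * n_support]
    flat = paddle.expand(s_label, [n_query, 2 * n_support]).reshape(
        (1, 2 * n_support * n_query)).squeeze(0)
    # shape [12]: [0, 1, 0, 1] repeated once per query task
    return flat
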
def get_loss(self, model, batch_data, pred_dict, train=True):
    if not train and self.update_s_q:
        losses_adapt = self.criterion(pred_dict['s_logits'],
                                      batch_data['s_label'])
    else:
        losses_adapt = self.criterion(pred_dict['logits'],
                                      batch_data['label'])

    if paddle.isnan(losses_adapt).any() or paddle.isinf(losses_adapt).any():
        print('!!!!!!!!!!!!!!!!!!! Nan value for supervised CE loss',
              losses_adapt)
        print(pred_dict['s_logits'])
        losses_adapt = paddle.zeros_like(losses_adapt)

    if self.args.reg_adj > 0:
        n_support = batch_data['s_label'].shape[0]
        adj = pred_dict['adj'][-1]
        if train:
            n_query = batch_data['q_label'].shape[0]
            s_label = paddle.expand(
                batch_data['s_label'],
                [n_query, batch_data['s_label'].shape[0]])
            q_label = batch_data['q_label'].unsqueeze(1)
            total_label = paddle.concat([s_label, q_label], 1)
            n_d = n_query * self.args.rel_edge * (n_support + 1)
            label_edge = model.layers.label2edge(total_label).reshape(
                (n_d, -1))
            pred_edge = adj.reshape((n_d, -1))
        else:
            s_label = batch_data['s_label'].unsqueeze(0)
            n_d = n_support * self.args.rel_edge
            label_edge = model.layers.label2edge(s_label).reshape(
                (n_d, -1))
            pred_edge = adj[:, :, :n_support, :n_support].mean(0).reshape(
                (n_d, -1))
        adj_loss_val = F.mse_loss(pred_edge, label_edge)
        if paddle.isnan(adj_loss_val).any() or \
                paddle.isinf(adj_loss_val).any():
            print('!!!!!!!!!!!!!!!!!!! Nan value for adjacency loss',
                  adj_loss_val)
            adj_loss_val = paddle.zeros_like(adj_loss_val)
        losses_adapt += self.args.reg_adj * adj_loss_val
    return losses_adapt

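# Hedged sketch of what a label2edge helper typically computes (assumption:
# the real model.layers.label2edge may also emit one channel per edge type,
# matching args.rel_edge). Two nodes are connected iff they share a label.
import paddle

def label2edge_sketch(label):
    # label: [batch, n_nodes] -> edge: [batch, n_nodes, n_nodes]
    same = label.unsqueeze(-1) == label.unsqueeze(-2)
    return paddle.cast(same, 'float32')
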
def set_grad(params, params_with_grad, scale=1.0):
    """Copy gradients from `params_with_grad` onto `params`, dividing by
    `scale` (e.g. a loss-scaling factor). Returns True as soon as a NaN/Inf
    gradient is found so the caller can skip the optimizer step."""
    for param, param_w_grad in zip(params, params_with_grad):
        grad = param_w_grad.grad
        if scale is not None:
            grad = grad / scale  # out-of-place, leaves the source grad intact
        if paddle.isnan(grad).any() or paddle.isinf(grad).any():
            return True  # invalid grad
        if param.grad is None:
            # Allocate a gradient buffer of the parameter's shape and dtype
            # (assumes param.grad is assignable in this Paddle version).
            param.grad = paddle.zeros_like(param)
        paddle.assign(grad, param.grad)
    return False

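# Hypothetical caller of set_grad (all names below are illustrative, not from
# the original repo): an fp16 loss-scaling loop that copies scaled-down grads
# from the working model onto fp32 master params and skips the step on
# overflow.
def fp16_update_demo(master_params, model_params, optimizer,
                     loss_scale=128.0):
    overflow = set_grad(master_params, model_params, scale=loss_scale)
    if not overflow:  # gradients were finite: safe to apply the step
        optimizer.step()
    optimizer.clear_grad()
    return overflow
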
def body(i, ls_func_calls, a1, a2, phi_1, derf_1, done):
    phi_2, derf_2, derphi_2 = phi_and_derphi(a2)
    paddle.assign(ls_func_calls + 1, ls_func_calls)
    paddle.assign(done | paddle.any(paddle.isinf(phi_2)), done)

    def true_fn1():
        j = zoom(a1, phi_1, derphi_1, derf_1, a2, phi_2, derphi_2, phi_0,
                 derphi_0)
        paddle.assign(a1, a_star)
        paddle.assign(phi_1, phi_star)
        paddle.assign(derf_1, derf_star)
        paddle.assign(ls_func_calls + j, ls_func_calls)

    # Sufficient-decrease (Armijo) condition violated, or phi failed to
    # decrease after the first iterate: zoom between a1 and a2.
    pred1 = ~done & ((phi_2 > phi_0 + c1 * a2 * derphi_0) |
                     ((phi_2 >= phi_0) & (i > 1)))
    paddle.assign(done | pred1, done)
    paddle.static.nn.cond(pred1, true_fn1, None)

    def true_fn2():
        paddle.assign(a2, a_star)
        paddle.assign(phi_2, phi_star)
        paddle.assign(derf_2, derf_star)

    # Strong Wolfe curvature condition satisfied: accept a2.
    pred2 = ~done & (paddle.abs(derphi_2) <= -c2 * derphi_0)
    paddle.assign(done | pred2, done)
    paddle.static.nn.cond(pred2, true_fn2, None)

    def true_fn3():
        j = zoom(a2, phi_2, derphi_2, derf_2, a1, phi_1, derphi_1, phi_0,
                 derphi_0)
        paddle.assign(a2, a_star)
        paddle.assign(phi_2, phi_star)
        paddle.assign(derf_2, derf_star)
        paddle.assign(ls_func_calls + j, ls_func_calls)

    # Derivative turned non-negative: zoom with the interval reversed.
    pred3 = ~done & (derphi_2 >= 0)
    paddle.assign(done | pred3, done)
    paddle.static.nn.cond(pred3, true_fn3, None)

    def false_fn():
        # No condition triggered: widen the bracket and try a larger step.
        paddle.assign(a2, a1)
        paddle.assign(phi_2, phi_1)
        paddle.assign(derf_2, derf_1)
        paddle.assign(paddle.minimum(2 * a2, alpha_max), a2)
        paddle.assign(i + 1, i)

    paddle.static.nn.cond(done, None, false_fn)
    return [i, ls_func_calls, a1, a2, phi_1, derf_1, done]

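# Minimal runnable sketch of the assign-plus-cond pattern used in the body
# above: in static-graph mode a branch mutates loop state via paddle.assign
# instead of returning values, and the untaken branch may simply be None.
# All names below are illustrative.
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.full([1], 1.0)
    pred = paddle.full([1], True, dtype='bool')

    def bump():
        paddle.assign(x + 1, x)  # visible after the cond, no return needed

    paddle.static.nn.cond(pred, bump, None)

exe = paddle.static.Executor()
exe.run(startup_prog)
print(exe.run(main_prog, fetch_list=[x]))  # [array([2.], dtype=float32)]
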
def log_prob(self, value):
    """Probability mass function evaluated at ``value``.

    Args:
        value (Tensor): The value to be evaluated.

    Returns:
        Tensor: The log probability of ``value``.
    """
    if paddle.is_integer(value):
        value = paddle.cast(value, self.probs.dtype)

    logits, value = paddle.broadcast_tensors(
        [paddle.log(self.probs), value])
    # Zero out -inf * 0 terms: a zero count contributes nothing even when
    # the corresponding probability is zero.
    logits[(value == 0) & (paddle.isinf(logits))] = 0

    return (paddle.lgamma(value.sum(-1) + 1)
            - paddle.lgamma(value + 1).sum(-1)
            + (value * logits).sum(-1))

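# Hedged numeric check (standalone, mirrors the method's math rather than
# calling the class): Multinomial log-PMF for a single count vector.
import paddle

probs = paddle.to_tensor([0.2, 0.3, 0.5])
value = paddle.to_tensor([1.0, 0.0, 3.0])
logits = paddle.log(probs)
# Same -inf guard as above: a zero count must contribute exactly zero.
logits = paddle.where((value == 0) & paddle.isinf(logits),
                      paddle.zeros_like(logits), logits)
log_p = (paddle.lgamma(value.sum(-1) + 1)
         - paddle.lgamma(value + 1).sum(-1)
         + (value * logits).sum(-1))
# log_p == log(4!/(1!*0!*3!) * 0.2**1 * 0.5**3) = log(4 * 0.2 * 0.125)
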
def forward(self, inputs):
    return paddle.cast(paddle.isinf(inputs), "int32")

def forward(self, inputs):
    """forward"""
    x = paddle.isinf(inputs)
    return x.astype('float32')

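# Quick behavioural check for the two isinf layers above: the result is a
# 0/1 mask in the requested dtype.
import paddle

x = paddle.to_tensor([1.0, float('inf'), float('-inf')])
print(paddle.cast(paddle.isinf(x), 'int32'))  # [0, 1, 1]
print(paddle.isinf(x).astype('float32'))      # [0., 1., 1.]
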
def train_step(self,
               interaction,
               max_generation_length,
               snippet_alignment_probability=1.,
               db2id=None,
               id2db=None,
               step=None):
    """Trains the interaction-level model on a single interaction.

    Args:
        interaction (Interaction): The interaction to train on.
        max_generation_length (int): Maximum length of a query the decoder
            is allowed to generate or be fed.
        snippet_alignment_probability (float): The probability that a snippet
            will be used in constructing the gold sequence.
        step (int): Global training step, used for learning-rate warm-up.
    """
    losses = []
    total_gold_tokens = 0
    input_hidden_states = []
    input_sequences = []
    final_utterance_states_c = []
    final_utterance_states_h = []
    previous_query_states = []
    previous_queries = []
    decoder_states = []
    discourse_state = None
    if self.params.discourse_level_lstm:
        discourse_state, discourse_lstm_states = \
            self._initialize_discourse_states()
    discourse_states = []

    # Schema and schema embeddings.
    input_schema = interaction.get_schema()
    schema_states = []
    if input_schema and not self.params.use_bert:
        schema_states = self.encode_schema_bow_simple(input_schema)

    # Build the intra-turn and cross-turn graphs: two columns are linked
    # when they belong to the same table (same prefix before the '.').
    inner = []
    column_names = interaction.interaction.schema.column_names_surface_form
    for i, ele in enumerate(column_names):
        for j in range(i + 1, len(column_names)):
            if ele.split('.')[0] == column_names[j].split('.')[0]:
                inner.append([i, j])
    adjacent_matrix = self.get_adj_matrix(
        inner, input_schema.table_schema['foreign_keys'],
        input_schema.num_col)
    adjacent_matrix_cross = self.get_adj_utterance_matrix(
        inner, input_schema.table_schema['foreign_keys'],
        input_schema.num_col)
    adjacent_matrix = paddle.to_tensor(adjacent_matrix)
    adjacent_matrix_cross = paddle.to_tensor(adjacent_matrix_cross)
    previous_schema_states = paddle.zeros(
        [input_schema.num_col, self.params.encoder_state_size])

    for utterance_index, utterance in enumerate(
            interaction.gold_utterances()):
        if (interaction.identifier in LIMITED_INTERACTIONS and
                utterance_index >
                LIMITED_INTERACTIONS[interaction.identifier]):
            break

        input_sequence = utterance.input_sequence()
        available_snippets = utterance.snippets()
        previous_query = utterance.previous_query()

        # Get the gold query: reconstruct it with snippets if the alignment
        # probability is less than one.
        if snippet_alignment_probability < 1.:
            gold_query = sql_util.add_snippets_to_query(
                available_snippets,
                utterance.contained_entities(),
                utterance.anonymized_gold_query(),
                prob_align=snippet_alignment_probability) + [vocab.EOS_TOK]
        else:
            gold_query = utterance.gold_query()

        final_utterance_state, utterance_states, schema_states = \
            self.get_bert_encoding(input_sequence, input_schema,
                                   discourse_state, dropout=True)

        # Propagate schema states through the GNN layers, conditioning on
        # the previous turn's schema states via the cross-turn graph.
        schema_states = paddle.stack(schema_states, axis=0)
        for i in range(self.params.gnn_layer_number):
            schema_states = self.gnn_history[2 * i](
                schema_states, adjacent_matrix_cross, previous_schema_states)
            schema_states = self.gnn_history[2 * i + 1](
                schema_states, adjacent_matrix_cross, previous_schema_states)
            schema_states = self.gnn[i](schema_states, adjacent_matrix)
        previous_schema_states = schema_states
        schema_states_ls = paddle.split(schema_states,
                                        schema_states.shape[0], axis=0)
        schema_states = [ele.squeeze(0) for ele in schema_states_ls]

        input_hidden_states.extend(utterance_states)
        input_sequences.append(input_sequence)
        num_utterances_to_keep = min(self.params.maximum_utterances,
                                     len(input_sequences))

        if self.params.discourse_level_lstm:
            discourse_state, discourse_lstm_states = self.discourse_lstms(
                final_utterance_state[0].unsqueeze(0), discourse_lstm_states)
            discourse_state = discourse_state.squeeze()

        if self.params.use_utterance_attention:
            final_utterance_states_c, final_utterance_states_h, \
                final_utterance_state = self.get_utterance_attention(
                    final_utterance_states_c, final_utterance_states_h,
                    final_utterance_state, num_utterances_to_keep)

        if self.params.state_positional_embeddings:
            utterance_states, flat_sequence = \
                self._add_positional_embeddings(input_hidden_states,
                                                input_sequences)

        snippets = None

        if self.params.use_previous_query and len(previous_query) > 0:
            previous_queries, previous_query_states = \
                self.get_previous_queries(previous_queries,
                                          previous_query_states,
                                          previous_query, input_schema)

        if (len(gold_query) <= max_generation_length and
                len(previous_query) <= max_generation_length):
            prediction = self.predict_turn(
                final_utterance_state,
                utterance_states,
                schema_states,
                max_generation_length,
                gold_query=gold_query,
                snippets=snippets,
                input_sequence=flat_sequence,
                previous_queries=previous_queries,
                previous_query_states=previous_query_states,
                input_schema=input_schema,
                feed_gold_tokens=True,
                training=True)
            loss = prediction[1]
            decoder_states = prediction[3]
            total_gold_tokens += len(gold_query)
            losses.append(loss)
        else:
            # Break on previous decoder snippet encoding -- the previous
            # sequence was too long to run the decoder.
            if self.params.previous_decoder_snippet_encoding:
                break
            continue

    if losses:
        average_loss = paddle.sum(paddle.stack(losses)) / total_gold_tokens
        print(f"total_gold_tokens:{total_gold_tokens}, step:{step}")
        print(f"LOSS:{float(average_loss.numpy())}")
        if paddle.isinf(average_loss).any():
            # Save a checkpoint for post-mortem debugging if the loss blew up.
            self.save("./inf_checkpoint")

        # Renormalize so the effect is normalized by the batch size.
        normalized_loss = average_loss
        if self.params.reweight_batch:
            normalized_loss = len(losses) * average_loss / float(
                self.params.batch_size)
        normalized_loss.backward()

        # Linear learning-rate warm-up over the first warmup_step steps.
        if step <= self.params.warmup_step:
            self.set_learning_rate(step / self.params.warmup_step *
                                   self.params.initial_learning_rate)
        step += 1

        self.trainer.step()
        if self.params.fine_tune_bert:
            self.bert_trainer.step()
            self.bert_trainer.clear_grad()
        self.trainer.clear_grad()

        loss_scalar = float(normalized_loss.numpy())
        no_nan = not bool(paddle.isnan(normalized_loss).any())
        if not no_nan:
            print("nan error in normalized loss")
        assert no_nan
    else:
        loss_scalar = 0.

    return loss_scalar, step

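# Minimal sketch of the linear warm-up schedule applied in train_step above,
# factored out for clarity (the function name is illustrative):
def linear_warmup_lr(step, warmup_step, initial_learning_rate):
    """Ramp the LR from ~0 to its target over the first warmup_step steps."""
    if step <= warmup_step:
        return step / warmup_step * initial_learning_rate
    return initial_learning_rate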