def train_step(self,
                   interaction,
                   max_generation_length,
                   snippet_alignment_probability=1.,
                   db2id=None,
                   id2db=None,
                   step=None):
        """Train the interaction-level model on a single interaction.

        Each gold utterance is encoded with ``get_bert_encoding``, the schema
        states are refined by the history and intra-turn GNN layers, and the
        decoder is run with gold tokens fed in. The per-turn losses are
        summed, divided by the number of gold tokens, and used for a single
        optimizer update.

        Args:
            interaction (Interaction): The interaction to train on.
            max_generation_length (int): Forwarded to ``predict_turn``. A turn
                whose gold query or previous query is longer than this is not
                decoded.
            snippet_alignment_probability (float): The probability that a
                snippet will be used in constructing the gold sequence. The
                gold query is only reconstructed when this is below 1.
            db2id: Not used by this method.
            id2db: Not used by this method.
            step (int, optional): Training step counter used for the
                learning-rate warmup. It is incremented whenever an update is
                performed. When ``None`` the warmup is skipped.

        Returns:
            tuple: ``(loss_scalar, step)`` where ``loss_scalar`` is the
            normalized loss as a float (``0.`` when no turn was decoded) and
            ``step`` is the possibly incremented step counter.
        """
        # assert self.params.discourse_level_lstm

        losses = []
        total_gold_tokens = 0

        input_hidden_states = []
        input_sequences = []

        final_utterance_states_c = []
        final_utterance_states_h = []

        previous_query_states = []
        previous_queries = []

        discourse_state = None
        if self.params.discourse_level_lstm:
            discourse_state, discourse_lstm_states = self._initialize_discourse_states(
            )

        # Schema and schema embeddings
        input_schema = interaction.get_schema()
        schema_states = []

        if input_schema and not self.params.use_bert:
            schema_states = self.encode_schema_bow_simple(input_schema)

        # Get the intra-turn graph and cross-turn graph.
        # Two columns are linked when the text before the first '.' in their
        # surface form matches (presumably the table name -- confirm against
        # the schema class). The prefixes are computed once instead of being
        # re-split for every pair.
        column_names = interaction.interaction.schema.column_names_surface_form
        prefixes = [name.split('.')[0] for name in column_names]
        inner = [[i, j]
                 for i in range(len(prefixes))
                 for j in range(i + 1, len(prefixes))
                 if prefixes[i] == prefixes[j]]

        foreign_keys = input_schema.table_schema['foreign_keys']
        adjacent_matrix = self.get_adj_matrix(inner, foreign_keys,
                                              input_schema.num_col)
        adjacent_matrix_cross = self.get_adj_utterance_matrix(
            inner, foreign_keys, input_schema.num_col)
        adjacent_matrix = paddle.to_tensor(adjacent_matrix)
        adjacent_matrix_cross = paddle.to_tensor(adjacent_matrix_cross)

        # Schema states of the previous turn; zeros before the first turn.
        previous_schema_states = paddle.zeros(
            [input_schema.num_col, self.params.encoder_state_size])

        for utterance_index, utterance in enumerate(
                interaction.gold_utterances()):

            if (interaction.identifier in LIMITED_INTERACTIONS
                    and utterance_index >
                    LIMITED_INTERACTIONS[interaction.identifier]):
                break

            input_sequence = utterance.input_sequence()

            available_snippets = utterance.snippets()
            previous_query = utterance.previous_query()

            # Get the gold query: reconstruct if the alignment probability is less than one
            if snippet_alignment_probability < 1.:
                gold_query = sql_util.add_snippets_to_query(
                    available_snippets,
                    utterance.contained_entities(),
                    utterance.anonymized_gold_query(),
                    prob_align=snippet_alignment_probability) + [
                        vocab.EOS_TOK
                    ]
            else:
                gold_query = utterance.gold_query()

            final_utterance_state, utterance_states, schema_states = self.get_bert_encoding(
                input_sequence, input_schema, discourse_state, dropout=True)

            # Run the GNN stack on the stacked schema states: two history
            # layers (which also receive the previous turn's states) followed
            # by one intra-turn layer, per GNN layer.
            schema_states = paddle.stack(schema_states, axis=0)
            for i in range(self.params.gnn_layer_number):
                schema_states = self.gnn_history[2 * i](schema_states,
                                                        adjacent_matrix_cross,
                                                        previous_schema_states)
                schema_states = self.gnn_history[2 * i + 1](
                    schema_states, adjacent_matrix_cross,
                    previous_schema_states)
                schema_states = self.gnn[i](schema_states, adjacent_matrix)
            previous_schema_states = schema_states
            # Back to a list with one entry per row of the stacked tensor.
            schema_states_ls = paddle.split(schema_states,
                                            schema_states.shape[0],
                                            axis=0)
            schema_states = [ele.squeeze(0) for ele in schema_states_ls]

            input_hidden_states.extend(utterance_states)
            input_sequences.append(input_sequence)

            num_utterances_to_keep = min(self.params.maximum_utterances,
                                         len(input_sequences))

            if self.params.discourse_level_lstm:
                discourse_state, discourse_lstm_states = self.discourse_lstms(
                    final_utterance_state[0].unsqueeze(0),
                    discourse_lstm_states)
                discourse_state = discourse_state.squeeze()

            if self.params.use_utterance_attention:

                final_utterance_states_c, final_utterance_states_h, final_utterance_state = self.get_utterance_attention(
                    final_utterance_states_c, final_utterance_states_h,
                    final_utterance_state, num_utterances_to_keep)

            if self.params.state_positional_embeddings:
                utterance_states, flat_sequence = self._add_positional_embeddings(
                    input_hidden_states, input_sequences)
            else:
                # Without positional embeddings the decoder still needs the
                # flattened tokens of the retained utterances. Previously
                # flat_sequence was left undefined on this path, which raised
                # NameError at the predict_turn call below.
                flat_sequence = []
                for utt in input_sequences[-num_utterances_to_keep:]:
                    flat_sequence.extend(utt)

            if self.params.use_previous_query:
                if len(previous_query) > 0:
                    previous_queries, previous_query_states = self.get_previous_queries(
                        previous_queries, previous_query_states,
                        previous_query, input_schema)

            if len(gold_query) <= max_generation_length and len(
                    previous_query) <= max_generation_length:
                # No snippet encoding is computed in this method, so the
                # decoder is called with snippets=None.
                prediction = self.predict_turn(
                    final_utterance_state,
                    utterance_states,
                    schema_states,
                    max_generation_length,
                    gold_query=gold_query,
                    snippets=None,
                    input_sequence=flat_sequence,
                    previous_queries=previous_queries,
                    previous_query_states=previous_query_states,
                    input_schema=input_schema,
                    feed_gold_tokens=True,
                    training=True)
                loss = prediction[1]
                total_gold_tokens += len(gold_query)
                losses.append(loss)
            else:
                # Break if previous decoder snippet encoding -- because the previous
                # sequence was too long to run the decoder.
                if self.params.previous_decoder_snippet_encoding:
                    break
                continue

        if losses:
            average_loss = paddle.sum(paddle.stack(losses)) / total_gold_tokens
            print(f"total_gold_tokens:{total_gold_tokens}, step:{step}")
            print(f"LOSS:{float(average_loss.numpy())}")
            # Keep a checkpoint for inspection when the loss becomes infinite.
            if paddle.sum(paddle.cast(paddle.isinf(average_loss),
                                      'int32')) == paddle.ones([1]):
                self.save("./inf_checkpoint")

            # Renormalize so the effect is normalized by the batch size.
            normalized_loss = average_loss
            if self.params.reweight_batch:
                normalized_loss = len(losses) * average_loss / float(
                    self.params.batch_size)

            normalized_loss.backward()

            # Linear learning-rate warmup. `step` defaults to None, and
            # comparing None with an int raises TypeError, so the warmup and
            # the counter update only run when a step was supplied.
            if step is not None:
                if step <= self.params.warmup_step:
                    self.set_learning_rate(step / self.params.warmup_step *
                                           self.params.initial_learning_rate)
                step += 1

            self.trainer.step()
            if self.params.fine_tune_bert:
                self.bert_trainer.step()
                self.bert_trainer.clear_grad()
            self.trainer.clear_grad()

            loss_scalar = float(normalized_loss.numpy())
            # True when the loss is NOT NaN (the count of NaN entries is zero).
            loss_is_not_nan = sum(
                paddle.cast(paddle.isnan(normalized_loss),
                            'float32').numpy().tolist()) == 0
            if paddle.isnan(normalized_loss):
                print("nan error but keep running")
            # NOTE(review): despite the message above, this assertion aborts
            # training on a NaN loss (unless asserts are disabled) -- confirm
            # which behavior is intended.
            assert loss_is_not_nan

        else:
            loss_scalar = 0.

        return loss_scalar, step
Exemple #2
0
    def train_step(self, interaction, max_generation_length, snippet_alignment_probability=1.):
        """ Runs one training update over the turns of a single interaction.

        Inputs:
            interaction (Interaction): The interaction to train on.
            max_generation_length (int): Forwarded to predict_turn; a turn whose
                gold query or previous query is longer than this is not decoded.
            snippet_alignment_probability (float): The probability that a snippet will
                be used in constructing the gold sequence.

        Returns:
            The normalized loss as a Python float, or 0. when no turn was decoded.
        """
        params = self.params

        turn_losses = []
        gold_token_count = 0

        # Encoder states and token sequences accumulated over the turns so far.
        all_utterance_states = []
        seen_sequences = []

        attention_states_c = []
        attention_states_h = []

        query_history = []
        query_history_states = []

        decoder_states = []

        discourse_state = None
        if params.discourse_level_lstm:
            discourse_state, discourse_lstm_states = self._initialize_discourse_states()

        # Schema and its bag-of-words encoding (only when BERT is not used).
        input_schema = interaction.get_schema()
        if input_schema and not params.use_bert:
            schema_states = self.encode_schema_bow_simple(input_schema)
        else:
            schema_states = []

        for turn_index, utterance in enumerate(interaction.gold_utterances()):
            # Stop once the per-interaction turn limit is exceeded.
            if interaction.identifier in LIMITED_INTERACTIONS:
                if turn_index > LIMITED_INTERACTIONS[interaction.identifier]:
                    break

            input_sequence = utterance.input_sequence()
            available_snippets = utterance.snippets()
            previous_query = utterance.previous_query()

            # The gold query is rebuilt with snippets only when the alignment
            # probability is below one.
            if snippet_alignment_probability < 1.:
                realigned_query = sql_util.add_snippets_to_query(
                    available_snippets,
                    utterance.contained_entities(),
                    utterance.anonymized_gold_query(),
                    prob_align=snippet_alignment_probability)
                gold_query = realigned_query + [vocab.EOS_TOK]
            else:
                gold_query = utterance.gold_query()

            # Encode the current utterance.
            if params.use_bert:
                final_utterance_state, utterance_states, schema_states = self.get_bert_encoding(
                    input_sequence, input_schema, discourse_state, dropout=True)
            else:
                if params.discourse_level_lstm:
                    def token_embedder(token):
                        # Append the current discourse state to every token embedding.
                        return torch.cat([self.input_embedder(token), discourse_state], dim=0)
                else:
                    token_embedder = self.input_embedder
                final_utterance_state, utterance_states = self.utterance_encoder(
                    input_sequence,
                    token_embedder,
                    dropout_amount=self.dropout)

            all_utterance_states.extend(utterance_states)
            seen_sequences.append(input_sequence)

            keep_count = min(params.maximum_utterances, len(seen_sequences))

            # final_utterance_state[1][0] is the first layer's hidden states at the last time step (concat forward lstm and backward lstm)
            if params.discourse_level_lstm:
                _, discourse_state, discourse_lstm_states = torch_utils.forward_one_multilayer(
                    self.discourse_lstms,
                    final_utterance_state[1][0],
                    discourse_lstm_states,
                    self.dropout)

            if params.use_utterance_attention:
                attention_states_c, attention_states_h, final_utterance_state = self.get_utterance_attention(
                    attention_states_c,
                    attention_states_h,
                    final_utterance_state,
                    keep_count)

            if params.state_positional_embeddings:
                utterance_states, flat_sequence = self._add_positional_embeddings(
                    all_utterance_states, seen_sequences)
            else:
                # Concatenate the tokens of the retained utterances.
                flat_sequence = [
                    token
                    for kept_sequence in seen_sequences[-keep_count:]
                    for token in kept_sequence
                ]

            if not params.use_snippets:
                snippets = None
            elif params.previous_decoder_snippet_encoding:
                snippets = encode_snippets_with_states(available_snippets, decoder_states)
            else:
                snippets = self._encode_snippets(previous_query, available_snippets, input_schema)

            if params.use_previous_query:
                if len(previous_query) > 0:
                    query_history, query_history_states = self.get_previous_queries(
                        query_history, query_history_states, previous_query, input_schema)

            exceeds_limit = (len(gold_query) > max_generation_length
                             or len(previous_query) > max_generation_length)
            if exceeds_limit:
                # Break if previous decoder snippet encoding -- because the previous
                # sequence was too long to run the decoder.
                if params.previous_decoder_snippet_encoding:
                    break
                continue

            prediction = self.predict_turn(
                final_utterance_state,
                utterance_states,
                schema_states,
                max_generation_length,
                gold_query=gold_query,
                snippets=snippets,
                input_sequence=flat_sequence,
                previous_queries=query_history,
                previous_query_states=query_history_states,
                input_schema=input_schema,
                feed_gold_tokens=True,
                training=True)
            turn_losses.append(prediction[1])
            decoder_states = prediction[3]
            gold_token_count += len(gold_query)

            torch.cuda.empty_cache()

        if not turn_losses:
            return 0.

        average_loss = torch.sum(torch.stack(turn_losses)) / gold_token_count

        # Renormalize so the effect is normalized by the batch size.
        if params.reweight_batch:
            normalized_loss = len(turn_losses) * average_loss / float(params.batch_size)
        else:
            normalized_loss = average_loss

        normalized_loss.backward()
        self.trainer.step()
        if params.fine_tune_bert:
            self.bert_trainer.step()
        self.zero_grad()

        return normalized_loss.item()
    def train_step(self,
                   interaction,
                   max_generation_length,
                   snippet_alignment_probability=1.,
                   db2id=None,
                   id2db=None,
                   step=None):
        """ Trains the interaction-level model on a single interaction.

        Each gold utterance is encoded, the schema states are refined by the
        history and intra-turn GNN layers, and the decoder is run with gold
        tokens fed in. The per-turn losses are summed, divided by the number
        of gold tokens, and used for one clipped-gradient optimizer update.

        Inputs:
            interaction (Interaction): The interaction to train on.
            max_generation_length (int): Forwarded to predict_turn; a turn whose
                gold query or previous query is longer than this is not decoded.
            snippet_alignment_probability (float): The probability that a snippet will
                be used in constructing the gold sequence.
            db2id: Not used by this method.
            id2db: Not used by this method.
            step (int, optional): Training step counter used for the learning-rate
                warmup. It is incremented whenever an update is performed. When
                None the warmup is skipped.

        Returns:
            A tuple (loss_scalar, step): the normalized loss as a float (0. when
            no turn was decoded) and the possibly incremented step counter.
        """
        # assert self.params.discourse_level_lstm

        losses = []
        total_gold_tokens = 0

        input_hidden_states = []
        input_sequences = []

        final_utterance_states_c = []
        final_utterance_states_h = []

        previous_query_states = []
        previous_queries = []

        decoder_states = []

        discourse_state = None
        if self.params.discourse_level_lstm:
            discourse_state, discourse_lstm_states = self._initialize_discourse_states(
            )

        # Schema and schema embeddings
        input_schema = interaction.get_schema()
        schema_states = []

        if input_schema and not self.params.use_bert:
            schema_states = self.encode_schema_bow_simple(input_schema)

        # Get the intra-turn graph and cross-turn graph.
        # Two columns are linked when the text before the first '.' in their
        # surface form matches (presumably the table name -- confirm against
        # the schema class). The prefixes are computed once instead of being
        # re-split for every pair.
        column_names = interaction.interaction.schema.column_names_surface_form
        prefixes = [name.split('.')[0] for name in column_names]
        inner = [[i, j]
                 for i in range(len(prefixes))
                 for j in range(i + 1, len(prefixes))
                 if prefixes[i] == prefixes[j]]

        foreign_keys = input_schema.table_schema['foreign_keys']
        adjacent_matrix = self.get_adj_matrix(inner, foreign_keys,
                                              input_schema.num_col)
        adjacent_matrix_cross = self.get_adj_utterance_matrix(
            inner, foreign_keys, input_schema.num_col)
        adjacent_matrix = torch.Tensor(adjacent_matrix).cuda()
        adjacent_matrix_cross = torch.Tensor(adjacent_matrix_cross).cuda()

        # Schema states of the previous turn; zeros before the first turn.
        previous_schema_states = torch.zeros(
            input_schema.num_col, self.params.encoder_state_size).cuda()

        for utterance_index, utterance in enumerate(
                interaction.gold_utterances()):
            if (interaction.identifier in LIMITED_INTERACTIONS
                    and utterance_index >
                    LIMITED_INTERACTIONS[interaction.identifier]):
                break

            input_sequence = utterance.input_sequence()

            available_snippets = utterance.snippets()
            previous_query = utterance.previous_query()

            # Get the gold query: reconstruct if the alignment probability is less than one
            if snippet_alignment_probability < 1.:
                gold_query = sql_util.add_snippets_to_query(
                    available_snippets,
                    utterance.contained_entities(),
                    utterance.anonymized_gold_query(),
                    prob_align=snippet_alignment_probability) + [
                        vocab.EOS_TOK
                    ]
            else:
                gold_query = utterance.gold_query()

            # Encode the utterance, and update the discourse-level states
            if not self.params.use_bert:
                if self.params.discourse_level_lstm:
                    utterance_token_embedder = lambda token: torch.cat(
                        [self.input_embedder(token), discourse_state], dim=0)
                else:
                    utterance_token_embedder = self.input_embedder
                final_utterance_state, utterance_states = self.utterance_encoder(
                    input_sequence,
                    utterance_token_embedder,
                    dropout_amount=self.dropout)
            else:
                final_utterance_state, utterance_states, schema_states = self.get_bert_encoding(
                    input_sequence,
                    input_schema,
                    discourse_state,
                    dropout=True)

            # Run the GNN stack on the stacked schema states: two history
            # layers (which also receive the previous turn's states) followed
            # by one intra-turn layer, per GNN layer.
            # NOTE(review): on the non-BERT path schema_states is not
            # re-encoded per turn, so from the second turn on this stacks the
            # GNN output of the previous turn -- confirm this is intended.
            schema_states = torch.stack(schema_states, dim=0)
            for i in range(self.params.gnn_layer_number):
                schema_states = self.gnn_history[2 * i](schema_states,
                                                        adjacent_matrix_cross,
                                                        previous_schema_states)
                schema_states = self.gnn_history[2 * i + 1](
                    schema_states, adjacent_matrix_cross,
                    previous_schema_states)
                schema_states = self.gnn[i](schema_states, adjacent_matrix)
            previous_schema_states = schema_states
            # Back to a list with one entry per row of the stacked tensor.
            schema_states_ls = torch.split(schema_states, 1, dim=0)
            schema_states = [ele.squeeze(0) for ele in schema_states_ls]

            input_hidden_states.extend(utterance_states)
            input_sequences.append(input_sequence)

            num_utterances_to_keep = min(self.params.maximum_utterances,
                                         len(input_sequences))

            # final_utterance_state[1][0] is the first layer's hidden states at the last time step (concat forward lstm and backward lstm)
            if self.params.discourse_level_lstm:
                _, discourse_state, discourse_lstm_states = torch_utils.forward_one_multilayer(
                    self.discourse_lstms, final_utterance_state[1][0],
                    discourse_lstm_states, self.dropout)

            if self.params.use_utterance_attention:
                final_utterance_states_c, final_utterance_states_h, final_utterance_state = self.get_utterance_attention(
                    final_utterance_states_c, final_utterance_states_h,
                    final_utterance_state, num_utterances_to_keep)

            if self.params.state_positional_embeddings:
                utterance_states, flat_sequence = self._add_positional_embeddings(
                    input_hidden_states, input_sequences)
            else:
                # Concatenate the tokens of the retained utterances.
                flat_sequence = []
                for utt in input_sequences[-num_utterances_to_keep:]:
                    flat_sequence.extend(utt)

            snippets = None
            if self.params.use_snippets:
                if self.params.previous_decoder_snippet_encoding:
                    snippets = encode_snippets_with_states(
                        available_snippets, decoder_states)
                else:
                    snippets = self._encode_snippets(previous_query,
                                                     available_snippets,
                                                     input_schema)

            if self.params.use_previous_query and len(previous_query) > 0:
                previous_queries, previous_query_states = self.get_previous_queries(
                    previous_queries, previous_query_states, previous_query,
                    input_schema)

            if len(gold_query) <= max_generation_length and len(
                    previous_query) <= max_generation_length:
                prediction = self.predict_turn(
                    final_utterance_state,
                    utterance_states,
                    schema_states,
                    max_generation_length,
                    gold_query=gold_query,
                    snippets=snippets,
                    input_sequence=flat_sequence,
                    previous_queries=previous_queries,
                    previous_query_states=previous_query_states,
                    input_schema=input_schema,
                    feed_gold_tokens=True,
                    training=True)
                loss = prediction[1]
                decoder_states = prediction[3]
                total_gold_tokens += len(gold_query)
                losses.append(loss)
            else:
                # Break if previous decoder snippet encoding -- because the previous
                # sequence was too long to run the decoder.
                if self.params.previous_decoder_snippet_encoding:
                    break
                continue

        if losses:
            average_loss = torch.sum(torch.stack(losses)) / total_gold_tokens
            print(average_loss.item(), total_gold_tokens, step)
            # Keep a checkpoint for inspection when the loss becomes infinite.
            if torch.sum(torch.isinf(average_loss)).item() == 1:
                self.save("./inf_checkpoint")

            # Renormalize so the effect is normalized by the batch size.
            normalized_loss = average_loss
            if self.params.reweight_batch:
                normalized_loss = len(losses) * average_loss / float(
                    self.params.batch_size)

            normalized_loss.backward()

            torch.nn.utils.clip_grad_norm_(self.parameters(), self.params.clip)

            # Linear learning-rate warmup. `step` defaults to None, and
            # comparing None with an int raises TypeError, so the warmup and
            # the counter update only run when a step was supplied.
            if step is not None:
                if step <= self.params.warmup_step:
                    self.set_learning_rate(step / self.params.warmup_step *
                                           self.params.initial_learning_rate)
                step += 1

            self.trainer.step()
            if self.params.fine_tune_bert:
                self.bert_trainer.step()
            self.zero_grad()

            loss_scalar = normalized_loss.item()
            # A NaN loss is reported but does not stop training.
            if torch.isnan(normalized_loss).item() != 0:
                print("nan error but keep running")
        else:
            loss_scalar = 0.

        return loss_scalar, step