Example #1
    def dynamic_decode_and_search(self,
                                  embedding,
                                  start_tokens,
                                  end_token,
                                  vocab_size=None,
                                  initial_state=None,
                                  beam_width=5,
                                  length_penalty=0.0,
                                  maximum_iterations=250,
                                  memory=None,
                                  memory_sequence_len=None,
                                  memory_out_ids=None,
                                  mode=tf.estimator.ModeKeys.PREDICT):

        batch_size = tf.shape(start_tokens)[0]

        if initial_state is not None:
            initial_state = tf.contrib.seq2seq.tile_batch(
                initial_state, multiplier=beam_width)
        if memory is not None:
            memory = tf.contrib.seq2seq.tile_batch(
                memory, multiplier=beam_width)
        if memory_sequence_len is not None:
            memory_sequence_len = tf.contrib.seq2seq.tile_batch(
                memory_sequence_len, multiplier=beam_width)

        cell, initial_state = build_cell(
            self.num_units, self.num_layers,
            initial_state=initial_state,
            copy_state=self.copy_state,
            cell_fn=self.cell_fn,
            batch_size=batch_size * beam_width,
            output_dropout_rate=self.output_dropout_rate,
            attention_mechanism_fn=self.attention_mechanism_fn,
            memory=memory,
            memory_sequence_len=memory_sequence_len,
            mode=mode)

        if vocab_size is not None:
            projection_layer = tf.layers.Dense(vocab_size, use_bias=False)
        else:
            projection_layer = None

        decoder = tf.contrib.seq2seq.BeamSearchDecoder(
            cell,
            embedding,
            start_tokens,
            end_token,
            initial_state,
            beam_width,
            output_layer=projection_layer,
            length_penalty_weight=length_penalty)

        outputs, beam_state, length = tf.contrib.seq2seq.dynamic_decode(
            decoder, maximum_iterations=maximum_iterations)

        predicted_ids = outputs.predicted_ids[:, :, 0]
        log_probs = beam_state.log_probs[:, 0]

        return predicted_ids, log_probs
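
For orientation, here is a minimal standalone sketch of the same beam-search pattern using only stock tf.contrib.seq2seq pieces (TensorFlow 1.x), without the project-specific build_cell helper. The LSTM/Luong-attention cell, the unit and vocabulary sizes, and the assumption that encoder_state matches the wrapped cell's state structure are all illustrative, not taken from this code base.

import tensorflow as tf

def beam_search_decode_sketch(embedding, start_tokens, end_token,
                              encoder_state, memory, memory_sequence_len,
                              num_units=256, vocab_size=10000, beam_width=5):
    # Tile encoder outputs, lengths, and state so each beam sees its own copy.
    tiled_memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
    tiled_memory_len = tf.contrib.seq2seq.tile_batch(
        memory_sequence_len, multiplier=beam_width)
    tiled_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)

    attention = tf.contrib.seq2seq.LuongAttention(
        num_units, tiled_memory, memory_sequence_length=tiled_memory_len)
    cell = tf.contrib.seq2seq.AttentionWrapper(
        tf.nn.rnn_cell.LSTMCell(num_units), attention,
        attention_layer_size=num_units)

    batch_size = tf.shape(start_tokens)[0]
    # tiled_state must match the wrapped cell's state structure (here an LSTMStateTuple).
    initial_state = cell.zero_state(
        batch_size * beam_width, tf.float32).clone(cell_state=tiled_state)

    decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell, embedding, start_tokens, end_token, initial_state, beam_width,
        output_layer=tf.layers.Dense(vocab_size, use_bias=False),
        length_penalty_weight=0.0)

    outputs, beam_state, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder, maximum_iterations=250)

    # predicted_ids is [batch, time, beam_width]; keep the best beam (index 0).
    return outputs.predicted_ids[:, :, 0], beam_state.log_probs[:, 0]

The key points mirror the example above: memory, its lengths, and the encoder state are tiled beam_width times before the decoder is built, and the best hypothesis is read from beam index 0.
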
Example #2
    def dynamic_decode(self,
                       embedding,
                       start_tokens,
                       end_token,
                       vocab_size=None,
                       initial_state=None,
                       output_layer=None,
                       maximum_iterations=250,
                       memory=None,
                       memory_sequence_len=None,
                       memory_out_ids=None,
                       mode=tf.estimator.ModeKeys.PREDICT):

        cell, initial_state = build_cell(
            self.num_units,
            self.num_layers,
            initial_state=initial_state,
            copy_state=self.copy_state,
            cell_fn=self.cell_fn,
            output_dropout_rate=self.output_dropout_rate,
            attention_mechanism_fn=self.attention_mechanism_fn,
            memory=memory,
            memory_sequence_len=memory_sequence_len,
            mode=mode)

        if vocab_size is not None:
            projection_layer = tf.layers.Dense(vocab_size, use_bias=False)
        else:
            projection_layer = None
            vocab_size = self.num_units

        # helper and decoder
        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding, start_tokens, end_token)

        extended_vocab_size = tf.maximum(
            tf.reduce_max(memory_out_ids) + 1, vocab_size)

        copying_mechanism = unnormalized_luong_attention(
            self.num_units, memory, memory_sequence_len)

        cell = CopyingWrapper(cell=cell,
                              copying_mechanism=copying_mechanism,
                              memory_out_ids=memory_out_ids,
                              extended_vocab_size=extended_vocab_size,
                              output_layer=projection_layer)

        initial_state = cell.zero_state(
            tf.shape(memory)[0], tf.float32).clone(cell_state=initial_state)

        decoder = tf.contrib.seq2seq.BasicDecoder(cell, helper, initial_state)

        # decode and extract logits and predictions
        outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder, maximum_iterations=maximum_iterations, swap_memory=True)
        logits = outputs.rnn_output
        predicted_ids = outputs.sample_id

        return predicted_ids, logits
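
CopyingWrapper and unnormalized_luong_attention are project-specific, so their exact behaviour is not shown here. As a rough, hedged illustration of what a copy mechanism over an extended vocabulary typically computes (pointer-generator style mixing of a generation distribution with attention mass scattered onto the source token ids), one might write something like the sketch below; the function name, the p_gen gate, and the tensor layout are assumptions, not this repository's API.

import tensorflow as tf

def copy_distribution_sketch(vocab_logits, attention, memory_out_ids,
                             p_gen, extended_vocab_size):
    """Mix a generation distribution with a copy distribution (pointer style).

    vocab_logits:        [batch, vocab_size] decoder logits over the base vocabulary
    attention:           [batch, src_len] attention weights over source positions
    memory_out_ids:      [batch, src_len] source token ids in the extended vocabulary
    p_gen:               [batch, 1] probability of generating rather than copying
    extended_vocab_size: scalar, base vocabulary plus copy-only ids
    """
    extended_vocab_size = tf.cast(extended_vocab_size, tf.int32)
    vocab_size = tf.shape(vocab_logits)[1]

    # Generator distribution, padded with zeros for the copy-only ids.
    vocab_probs = tf.nn.softmax(vocab_logits)
    vocab_probs = tf.pad(
        vocab_probs, [[0, 0], [0, extended_vocab_size - vocab_size]])

    # Attention mass scattered onto the extended-vocabulary ids of the source
    # tokens; the one_hot tensor is [batch, src_len, extended_vocab_size].
    copy_probs = tf.einsum(
        'bs,bsv->bv', attention,
        tf.one_hot(memory_out_ids, extended_vocab_size))

    return p_gen * vocab_probs + (1.0 - p_gen) * copy_probs

The extended_vocab_size computed above (the maximum source id plus one, or the base vocabulary size, whichever is larger) is exactly the depth such a scattered copy distribution needs.
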
Example #3
    def build(self, node_features_size: int, num_edge_types: int,
              mode=tf.estimator.ModeKeys.TRAIN) -> None:
        if self.encoder_type == "bidirectional_rnn":
            self.fwd_cell = build_cell(
                self.num_units // 2, self.num_layers,
                cell_fn=self.cell_fn,
                output_dropout_rate=self.dropout_rate,
                # input_shape=tf.TensorShape([None, node_features_size]),
                mode=mode,
                name="fwd_cell")
            self.bwd_cell = build_cell(
                self.num_units // 2, self.num_layers,
                cell_fn=self.cell_fn,
                output_dropout_rate=self.dropout_rate,
                # input_shape=tf.TensorShape([None, node_features_size]),
                mode=mode,
                name="bwd_cell")
        elif self.encoder_type == "rnn":
            self.rnn_cell = build_cell(
                self.num_units, self.num_layers,
                cell_fn=self.cell_fn,
                # input_shape=tf.TensorShape([None, node_features_size]),
                output_dropout_rate=self.dropout_rate,
                mode=mode)

        self.merge_layer = tf.layers.Dense(
            self.gnn_input_size if self.gnn_input_size is not None else self.num_units,
            use_bias=False)
        self.merge_layer.build(
            (None, None, node_features_size + self.num_units))
        self.base_graph_encoder.build(
            self.gnn_input_size if self.gnn_input_size is not None else self.num_units,
            num_edge_types)

        self.output_map = tf.layers.Dense(
            name="output_map",
            units=self.num_units,
            use_bias=False,
            kernel_initializer=eye_glorot)

        # same for the state map
        # TODO: for now it is easier to do this in the actual call, where it gets cropped;
        # it should eventually move back into this build method

        self.built = True
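
The halved unit counts for fwd_cell and bwd_cell make sense once the two directions are concatenated: each direction contributes num_units // 2 features, so the merged encoding again has num_units dimensions. A minimal sketch of that pattern with plain tf.nn.rnn_cell cells (dropout, the configurable cell_fn, and the rest of build_cell are omitted; the GRU cells and shapes are assumptions):

import tensorflow as tf

def bidirectional_encode_sketch(inputs, sequence_length, num_units=256):
    # Each direction gets half the units so the concatenation has num_units.
    fwd_cell = tf.nn.rnn_cell.GRUCell(num_units // 2, name="fwd_cell")
    bwd_cell = tf.nn.rnn_cell.GRUCell(num_units // 2, name="bwd_cell")

    (fwd_out, bwd_out), (fwd_state, bwd_state) = tf.nn.bidirectional_dynamic_rnn(
        fwd_cell, bwd_cell, inputs,
        sequence_length=sequence_length, dtype=tf.float32)

    outputs = tf.concat([fwd_out, bwd_out], axis=-1)    # [batch, time, num_units]
    state = tf.concat([fwd_state, bwd_state], axis=-1)  # [batch, num_units]
    return outputs, state
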
Example #4
    def build(
            self,
            initial_node_features_size: int,
            num_edge_types: int,
            mode: tf.estimator.ModeKeys = tf.estimator.ModeKeys.TRAIN) -> None:
        """
        Args:
            initial_node_features_size: Dimensionality of initial node features;
              will be padded to size of node features used in GNN.
            num_edge_types: Number of edge types.
        """
        super().build(self.node_features_size, num_edge_types)
        self.initial_node_features_size = initial_node_features_size

        if self.initial_node_features_size > self.node_features_size:
            raise ValueError(
                "GGNN currently only supports an initial feature size that is "
                "no larger than the state size")

        self._message_weights = []  # type: List[tf.Variable]
        self._update_grus = []  # type: List[tf.nn.rnn_cell.GRUCell]
        self._edge_bias = []  # type: List[tf.Variable]
        for layer_id, _ in enumerate(self.timesteps_per_layer):
            # weights for the message_fun
            self._message_weights.append(
                tf.get_variable(name="message_weights_%d" % layer_id,
                                shape=(self.num_edge_types,
                                       self.node_features_size,
                                       self.node_features_size),
                                initializer=tf.glorot_normal_initializer()))

            # bias for the message_fun
            if self.use_edge_bias:
                self._edge_bias.append(
                    tf.get_variable(
                        "edge_bias_%d" % layer_id,
                        (self.num_edge_types, self.node_features_size)))

            # gru for the update_fun
            cell = build_cell(num_units=self.node_features_size,
                              num_layers=1,
                              cell_fn=tf.nn.rnn_cell.GRUCell,
                              output_dropout_rate=self.gru_dropout_rate,
                              input_shape=tf.TensorShape(
                                  [None, self.node_features_size]),
                              name="update_gru_%d" % layer_id,
                              mode=mode)
            self._update_grus.append(cell)

        # final linear layer
        self.final_layer = tf.layers.Dense(self.node_features_size,
                                           use_bias=False)
        self.final_layer.build((None, self.node_features_size))
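
The variables created here (one [D, D] message matrix per edge type, an optional per-edge-type bias, and a GRU as the update function) are the standard ingredients of a gated graph neural network propagation step. The sketch below shows roughly how such variables are usually combined on a dense per-edge-type adjacency; it illustrates the technique only and is not this repository's propagation code, whose graph representation may differ.

import tensorflow as tf

def ggnn_step_sketch(node_states, adjacency, message_weights, edge_bias, gru_cell):
    """One GGNN propagation step on a dense graph.

    node_states:     [num_nodes, D] current node features
    adjacency:       [num_edge_types, num_nodes, num_nodes] 0/1 adjacency per edge type
    message_weights: [num_edge_types, D, D] per-edge-type message transforms
    edge_bias:       [num_edge_types, D] per-edge-type message bias (zeros if unused)
    gru_cell:        a GRU cell with D units used as the update function
    """
    # Transform node states once per edge type: [num_edge_types, num_nodes, D].
    transformed = tf.einsum('nd,edf->enf', node_states, message_weights) \
        + edge_bias[:, tf.newaxis, :]

    # Propagate along each edge type and sum the incoming messages per node.
    messages = tf.reduce_sum(tf.matmul(adjacency, transformed), axis=0)

    # GRU update: incoming messages are the input, node states the hidden state.
    new_states, _ = gru_cell(messages, node_states)
    return new_states
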
Example #5
    def dynamic_decode(self,
                       embedding,
                       start_tokens,
                       end_token,
                       vocab_size=None,
                       initial_state=None,
                       maximum_iterations=250,
                       memory=None,
                       memory_sequence_len=None,
                       memory_out_ids=None,
                       mode=tf.estimator.ModeKeys.PREDICT):

        cell, initial_state = build_cell(
            self.num_units, self.num_layers,
            initial_state=initial_state,
            cell_fn=self.cell_fn,
            copy_state=self.copy_state,
            output_dropout_rate=self.output_dropout_rate,
            attention_mechanism_fn=self.attention_mechanism_fn,
            memory=memory,
            memory_sequence_len=memory_sequence_len,
            mode=mode)

        if vocab_size is not None:
            projection_layer = tf.layers.Dense(vocab_size, use_bias=False)
        else:
            projection_layer = None

        # helper and decoder
        helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
            embedding, start_tokens, end_token)
        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell, helper, initial_state,
            output_layer=projection_layer)

        # decode and extract logits and predictions
        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder,
            maximum_iterations=maximum_iterations,
            swap_memory=True)
        logits = outputs.rnn_output
        ids = outputs.sample_id

        return ids, logits
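
Stripped of the project-specific build_cell, the greedy pattern above reduces to a few stock tf.contrib.seq2seq calls. A standalone sketch, with an LSTM cell and illustrative sizes as assumptions (initial_state must match the cell's state structure, here an LSTMStateTuple):

import tensorflow as tf

def greedy_decode_sketch(embedding, start_tokens, end_token, initial_state,
                         num_units=256, vocab_size=10000):
    cell = tf.nn.rnn_cell.LSTMCell(num_units)
    projection_layer = tf.layers.Dense(vocab_size, use_bias=False)

    # The helper feeds the embedding of each step's argmax token back in as
    # the next input, starting from start_tokens and stopping at end_token.
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding, start_tokens, end_token)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell, helper, initial_state, output_layer=projection_layer)

    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder, maximum_iterations=250, swap_memory=True)
    return outputs.sample_id, outputs.rnn_output
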
Example #6
    def decode(self,
               inputs: tf.Tensor,
               sequence_length: tf.Tensor,
               vocab_size: int = None,
               initial_state: tf.Tensor = None,
               sampling_probability=None,
               embedding=None,
               memory=None,
               memory_sequence_len=None,
               memory_out_ids=None,
               mode=tf.estimator.ModeKeys.TRAIN):
        print("decode input:", inputs)
        if (sampling_probability is not None and
                (isinstance(sampling_probability, tf.Tensor) or sampling_probability > 0.0)):
            if embedding is None:
                raise ValueError(
                    "embedding argument must be set when using scheduled sampling")

            tf.summary.scalar("sampling_probability", sampling_probability)
            helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                inputs,
                sequence_length,
                embedding,
                sampling_probability)
        else:
            helper = tf.contrib.seq2seq.TrainingHelper(inputs, sequence_length)

        cell, initial_state = build_cell(
            self.num_units, self.num_layers,
            initial_state=initial_state,
            copy_state=self.copy_state,
            cell_fn=self.cell_fn,
            output_dropout_rate=self.output_dropout_rate,
            attention_mechanism_fn=self.attention_mechanism_fn,
            memory=memory,
            memory_sequence_len=memory_sequence_len,
            mode=mode,
            alignment_history=self.coverage_loss_lambda > 0)

        if vocab_size is not None:
            projection_layer = tf.layers.Dense(vocab_size, use_bias=False)
        else:
            projection_layer = None

        decoder = tf.contrib.seq2seq.BasicDecoder(
            cell, helper, initial_state,
            output_layer=projection_layer)

        # decode and extract logits and predictions
        outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(
            decoder,
            swap_memory=True)
        logits = outputs.rnn_output
        ids = outputs.sample_id

        if hasattr(state, 'alignment_history') and \
           not isinstance(state.alignment_history, tuple):
            attention = tf.transpose(
                state.alignment_history.stack(), (1, 0, 2))
            decoder_loss = self.coverage_loss(attention, memory_sequence_len)
        else:
            decoder_loss = None

        return ids, logits, decoder_loss
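
self.coverage_loss is defined elsewhere in the project. One common formulation of such a penalty (as used in pointer-generator summarization models) compares each step's attention against the running sum of attention from earlier steps; the sketch below implements that formulation as an assumption about what the method likely resembles, not as its actual definition.

import tensorflow as tf

def coverage_loss_sketch(attention, memory_sequence_len):
    """Coverage penalty over a stacked alignment history.

    attention:           [batch, decode_steps, src_len] attention weights per step
    memory_sequence_len: [batch] valid source lengths
    """
    # Coverage at step t is the sum of attention from all previous steps.
    coverage = tf.cumsum(attention, axis=1, exclusive=True)

    # Penalize attending again to already-covered source positions.
    per_position = tf.minimum(attention, coverage)

    # Mask out padded source positions before reducing.
    mask = tf.sequence_mask(
        memory_sequence_len, maxlen=tf.shape(attention)[2], dtype=tf.float32)
    per_position = per_position * mask[:, tf.newaxis, :]

    return tf.reduce_mean(tf.reduce_sum(per_position, axis=[1, 2]))
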
Example #7
    def decode(self,
               inputs: tf.Tensor,
               sequence_length: tf.Tensor,
               vocab_size: int = None,
               initial_state: tf.Tensor = None,
               sampling_probability=None,
               embedding=None,
               memory=None,
               memory_sequence_len=None,
               memory_out_ids=None,
               mode=tf.estimator.ModeKeys.TRAIN):
        if (sampling_probability is not None
                and (isinstance(sampling_probability, tf.Tensor)
                     or sampling_probability > 0.0)):
            if embedding is None:
                raise ValueError(
                    "embedding argument must be set when using scheduled sampling"
                )

            tf.summary.scalar("sampling_probability", sampling_probability)
            helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
                inputs, sequence_length, embedding, sampling_probability)
        else:
            helper = tf.contrib.seq2seq.TrainingHelper(inputs, sequence_length)

        cell, initial_state = build_cell(
            self.num_units,
            self.num_layers,
            initial_state=initial_state,
            copy_state=self.copy_state,
            cell_fn=self.cell_fn,
            output_dropout_rate=self.output_dropout_rate,
            attention_mechanism_fn=self.attention_mechanism_fn,
            memory=memory,
            memory_sequence_len=memory_sequence_len,
            mode=mode,
            alignment_history=self.coverage_loss_lambda > 0)

        if vocab_size is not None:
            projection_layer = tf.layers.Dense(vocab_size, use_bias=False)
        else:
            projection_layer = None
            vocab_size = self.num_units

        # helper and decoder (this TrainingHelper replaces the helper constructed above)
        helper = tf.contrib.seq2seq.TrainingHelper(inputs,
                                                   sequence_length,
                                                   time_major=False)

        extended_vocab_size = tf.maximum(
            tf.reduce_max(memory_out_ids) + 1, tf.cast(vocab_size, tf.int64))

        copying_mechanism = unnormalized_luong_attention(
            self.num_units, memory, memory_sequence_len)

        cell = CopyingWrapper(cell=cell,
                              copying_mechanism=copying_mechanism,
                              memory_out_ids=memory_out_ids,
                              extended_vocab_size=extended_vocab_size,
                              output_layer=projection_layer)

        initial_state = cell.zero_state(
            tf.shape(memory)[0], tf.float32).clone(cell_state=initial_state)

        decoder = tf.contrib.seq2seq.BasicDecoder(cell, helper, initial_state)

        outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
                                                              swap_memory=True)

        if hasattr(state.cell_state, 'alignment_history') and \
                not isinstance(state.cell_state.alignment_history, tuple):
            attention = tf.transpose(
                state.cell_state.alignment_history.stack(), (1, 0, 2))
            decoder_loss = self.coverage_loss(attention, memory_sequence_len)
        else:
            decoder_loss = None

        logits = outputs.rnn_output
        ids = outputs.sample_id

        return ids, logits, decoder_loss