Example #1
0
    def prepare_decoder(self, targets):
        """Builds the decoder input tensor for image targets.

        The pixels are shifted with `utils.right_shift_blockwise_nd`,
        embedded per channel with `utils.get_channel_embeddings`, optionally
        given positional embeddings, and finally laid out in the decoder's
        shape.

        Args:
          targets: rank-4 image tensor ([batch, height, width, channels]).

        Returns:
          Tensor of shape `[-1] + self.get_shape_for_decoder() +
          [self.hidden_size]`.
        """
        rank = len(utils.shape_list(targets))
        # image should be [batch, height, width, channels]
        assert rank == 4, "Image tensors should be 4-dimensional"

        decoder_shape = self.get_shape_for_decoder()
        image_shape = [
            -1, self.frame_height, self.frame_width, self.num_channels
        ]

        # Shift positions (block-wise, by query_shape) in decoder layout.
        shifted = utils.right_shift_blockwise_nd(
            tf.reshape(targets, [-1] + decoder_shape + [1]),
            self.hparams.query_shape)

        # Embed each channel value, working in the image layout.
        embedded = utils.get_channel_embeddings(
            io_depth=self.num_channels,
            targets=tf.reshape(shifted, image_shape),
            hidden_size=self.hidden_size)

        if self.add_positional_emb:
            # One table large enough for the longest of the three axes.
            longest_axis = max(self.frame_height, self.frame_width,
                               self.num_channels)
            embedded = utils.add_positional_embedding_nd(
                embedded, max_length=longest_axis, name="pos_emb")

        return tf.reshape(embedded,
                          [-1] + decoder_shape + [self.hidden_size])
    def prepare_decoder(self, targets):
        """Builds the decoder input tensor for token-sequence targets.

        The steps are:
        - optionally zero out random target positions;
        - shift with `utils.right_shift_blockwise_nd`;
        - embed the tokens;
        - optionally apply dropout;
        - project to `self.hidden_size` with a dense layer;
        - optionally add a 1-D timing signal.

        Args:
          targets: rank-2 token tensor ([batch, seq_length]).

        Returns:
          The embedded decoder input tensor.
        """
        shape = utils.shape_list(targets)
        # sequence should be [batch, seq_length]
        assert len(shape) == 2, "Sequence tensors should be 2-dimensional"
        query_shape = self.hparams.query_shape
        assert len(query_shape) == 1, "query shape should be 1-dimensional"

        mask_rate = self.hparams.target_dropout
        if mask_rate:
            # Replace roughly a `mask_rate` fraction of positions with 0.
            drop = tf.random.uniform(shape) < mask_rate
            targets = tf.where(drop, tf.zeros_like(targets), targets)

        # The shift helper is called on a tensor with an extra trailing
        # singleton axis, which is removed again right after.
        shifted = tf.squeeze(
            utils.right_shift_blockwise_nd(tf.expand_dims(targets, axis=-1),
                                           query_shape),
            axis=-1)

        embedded = utils.get_embeddings(targets=shifted,
                                        hidden_size=self.hparams.embedding_dims,
                                        vocab_size=self.vocab_size)
        dropout_rate = self.hparams.dropout
        if dropout_rate:
            # TF1 `tf.nn.dropout` takes keep_prob as its 2nd positional arg.
            keep_prob = 1 - dropout_rate
            embedded = tf.nn.dropout(embedded, keep_prob)
        projected = tf.layers.dense(embedded,
                                    self.hidden_size,
                                    activation=None,
                                    name="emb_dense")
        if self.hparams.add_timing_signal:
            projected += utils.get_timing_signal_1d(
                self.hparams.max_target_length, self.hidden_size)
        return projected
 def multinomial_squeeze(self, logits, temperature=1.0):
   """Samples one index per position from `logits` along the last axis.

   Args:
     logits: float tensor whose last axis holds the per-class logits.
     temperature: the logits are divided by this before sampling. A Python
       number equal to 0 selects greedy (argmax) decoding instead, because
       dividing by zero would yield inf/nan logits and undefined samples.

   Returns:
     int32 tensor with shape `utils.shape_list(logits)[:-1]`.
   """
   logits_shape = utils.shape_list(logits)
   if isinstance(temperature, (int, float)) and temperature == 0:
     # Greedy decoding: argmax over the last axis already has the
     # squeezed shape logits_shape[:-1].
     return tf.to_int32(tf.argmax(logits, axis=-1))
   # tf.multinomial expects 2-D [N, num_classes] logits.
   flat_logits = tf.reshape(logits, [-1, logits_shape[-1]]) / temperature
   choices = tf.multinomial(flat_logits, 1)
   choices = tf.reshape(choices, logits_shape[:-1])
   return tf.to_int32(choices)
 def infer_step(i, recent_output, recent_logits, cache, decoding_stats):
     """Runs one decoding step and writes its sample/logit at position `i`.

     Uses `features` and `self` from an enclosing scope (not visible here).
     The new values are added (not assigned) into `recent_output` and
     `recent_logits` via `tf.scatter_nd`. This presumably relies on position
     `i` still being zero.

     Returns:
       `(i + 1, samples, logits, cache, decoding_stats)`.
     """
     step_features = features.copy()
     step_features["targets"] = recent_output
     new_sample, new_logit = self.sample(step_features,
                                         decode_step=i,
                                         cache=cache,
                                         decoding_stats=decoding_stats)
     batch_ids = range(self.batch_size)
     sample_update = tf.scatter_nd(
         indices=[[b, i, 0, 0] for b in batch_ids],
         updates=new_sample,
         shape=utils.shape_list(recent_output))
     logit_update = tf.scatter_nd(
         indices=[[b, i] for b in batch_ids],
         updates=new_logit,
         shape=utils.shape_list(recent_logits))
     return (i + 1, recent_output + sample_update,
             recent_logits + logit_update, cache, decoding_stats)
  def process_partial_targets_decoding(self, targets):
    """Rearranges partially decoded targets back to a `[-1, seq_length]` view.

    The tensor is regrouped into `np.prod(query_shape)`-sized chunks, then
    passed through `utils.unflatten_blocks_nd` and
    `utils.put_back_blocks_nd`. Presumably this undoes query-block ordering
    (confirm against utils).

    Args:
      targets: tensor whose dim 0 is batch and dim 1 is the sequence length.

    Returns:
      Tensor reshaped to `[-1, seq_length]`.
    """
    query_shape = self.hparams.query_shape
    dims = utils.shape_list(targets)
    batch_size, seq_length = dims[0], dims[1]
    blocks_per_dim = [
        length // block for length, block in zip([seq_length], query_shape)
    ]
    blocked = tf.reshape(targets,
                         [batch_size, -1, np.prod(query_shape), 1])
    unflattened = utils.unflatten_blocks_nd(blocked, blocks_per_dim)
    restored = utils.put_back_blocks_nd(unflattened, query_shape)
    return tf.reshape(restored, [-1, seq_length])
  def process_partial_targets_decoding(self, targets):
    """Rearranges partially decoded targets back into image layout.

    The number of blocks per axis is `self.get_shape_for_decoder()` divided
    element-wise by `self.hparams.query_shape`. The tensor is regrouped and
    passed through `utils.unflatten_blocks_nd` / `utils.put_back_blocks_nd`.

    Returns:
      Tensor of shape `[-1, frame_height, frame_width, num_channels]`.
    """
    query_shape = self.hparams.query_shape
    blocks_per_dim = [
        extent // block
        for extent, block in zip(self.get_shape_for_decoder(), query_shape)
    ]
    batch_size = utils.shape_list(targets)[0]
    blocked = tf.reshape(targets,
                         [batch_size, -1, np.prod(query_shape), 1])
    restored = utils.put_back_blocks_nd(
        utils.unflatten_blocks_nd(blocked, blocks_per_dim), query_shape)
    return tf.reshape(
        restored,
        [-1, self.frame_height, self.frame_width, self.num_channels])
    def infer(self, features, **kwargs):
        """Autoregressively samples `self.hparams.max_target_length` tokens.

        The method works in three stages:
        1. Call `self.body` once (with `decode_step=None`) so it can
           populate `cache`.
        2. Run a `tf.while_loop` that samples one position per iteration
           with `self.sample`.
        3. Pass the sampled ids and logits through
           `utils.unflatten_blocks_nd` / `utils.put_back_blocks_nd`.
           Presumably this goes from query-block order back to sequence
           order (confirm).

        Args:
          features: dict of feature tensors, first passed through
            `self.bottom`. The resulting dict's "targets" entry is replaced
            by a zeros placeholder while decoding. It is restored only if a
            value was present beforehand; otherwise the placeholder is left
            in place.
          **kwargs: accepted but unused here.

        Returns:
          dict with "outputs" (sampled ids), "logits", and "scores" /
          "losses" set to None.
        """
        with tf.variable_scope("sparse_transformer", reuse=tf.AUTO_REUSE):
            features = self.bottom(features)
        decode_length = self.hparams.max_target_length
        # Filled in by self.body / self.sample and threaded through the loop.
        cache = {}
        decoding_stats = {}
        targets_old = features.get("targets")
        start_step = 0
        # Zero-initialised buffers. Each loop step adds one position's sample
        # and logits into them via tf.scatter_nd.
        initial_output = tf.zeros((self.batch_size, decode_length, 1, 1),
                                  dtype=tf.int32)
        initial_logits = tf.zeros(
            (self.batch_size, decode_length, self.vocab_size))

        # Call body once so it can populate `cache` before the loop (body
        # stores the encoder output under cache[-1] when "inputs" exist).
        features["targets"] = initial_output
        # Set shape of inputs. NOTE(review): this uses hparams.hidden_size
        # while body() reshapes with self.hidden_size; presumably equal.
        if "inputs" in features:
            features["inputs"].set_shape([
                self.batch_size, self.hparams.max_length, 1,
                self.hparams.hidden_size
            ])
        with tf.variable_scope("sparse_transformer/body", reuse=tf.AUTO_REUSE):
            self.body(features,
                      decode_step=None,
                      cache=cache,
                      decoding_stats=decoding_stats)

        def infer_step(i, recent_output, recent_logits, cache, decoding_stats):
            """Samples position `i` and adds it into the output/logit buffers.

            The scattered values are added to the existing buffers, so this
            relies on position `i` still being zero.
            """
            features_copy = features.copy()
            features_copy["targets"] = recent_output
            cur_sample, cur_logit = self.sample(features_copy,
                                                decode_step=i,
                                                cache=cache,
                                                decoding_stats=decoding_stats)
            pos = i
            # recent_output is rank 4: (batch, decode_length, 1, 1).
            samples = recent_output + tf.scatter_nd(
                indices=[[b, pos, 0, 0] for b in range(self.batch_size)],
                updates=cur_sample,
                shape=utils.shape_list(recent_output))
            logits = recent_logits + tf.scatter_nd(
                indices=[[b, pos] for b in range(self.batch_size)],
                updates=cur_logit,
                shape=utils.shape_list(recent_logits))
            return i + 1, samples, logits, cache, decoding_stats

        def while_exit_cond(i, result, logits, cache, decoding_stats):  # pylint: disable=unused-argument
            """Exit the loop if it reaches decode_length."""
            not_overflow = i < decode_length
            return not_overflow

        _, final_result, final_logits, _, decoding_stats = tf.while_loop(
            while_exit_cond,
            infer_step, [
                start_step, initial_output, initial_logits, cache,
                decoding_stats
            ],
            back_prop=False,
            parallel_iterations=1)

        # Sequence models have a single (length) axis.
        original_shape = [decode_length]

        blocks_per_dim = [
            s // q for s, q in zip(original_shape, self.hparams.query_shape)
        ]
        # Group results into chunks of np.prod(query_shape) positions before
        # the block helpers rearrange them.
        final_result_shape = utils.shape_list(final_result)
        final_result = tf.reshape(
            final_result,
            [final_result_shape[0], -1,
             np.prod(self.hparams.query_shape), 1])
        final_logits_shape = utils.shape_list(final_logits)
        final_logits = tf.reshape(final_logits, [
            final_logits_shape[0], -1,
            np.prod(self.hparams.query_shape), final_logits_shape[-1]
        ])
        final_result = utils.unflatten_blocks_nd(final_result, blocks_per_dim)
        final_result = utils.put_back_blocks_nd(final_result,
                                                self.hparams.query_shape)
        final_logits = utils.unflatten_blocks_nd(final_logits, blocks_per_dim)
        final_logits = utils.put_back_blocks_nd(final_logits,
                                                self.hparams.query_shape)

        # Report each accumulated decoding statistic averaged per position.
        for name, value in decoding_stats.items():
            tf.summary.scalar("decodes/%s" % name, value / decode_length)

        # Reassign targets back to the previous value.
        if targets_old is not None:
            features["targets"] = targets_old

        return {
            "outputs": final_result,
            "scores": None,
            "logits": final_logits,
            "losses": None,
        }
    def body(self,
             features,
             decode_step=None,
             cache=None,
             decoding_stats=None,
             add_summary=True):
        """Runs the encoder (when "inputs" are present) and the decoder.

        Args:
          features: dict with a "targets" tensor and optionally "inputs",
            "token_bias_inputs" and "token_bias_targets".
          decode_step: current decoding position. It is forced to None
            unless `self.hparams.fast_decode` is set.
          cache: optional dict shared across decoding steps. The encoder
            output is stored under key -1 and reused when already present.
          decoding_stats: optional dict forwarded to the decoder layers.
          add_summary: whether to emit loss scalar summaries.

        Returns:
          In decode mode (`self.is_decode`), the decoder logits. Otherwise
          a tuple `(logits, loss_dict)`:
          - `logits` has two singleton axes inserted before its last axis.
          - `loss_dict` holds the "training" cross-entropy loss and, if any
            were collected, the summed "extra_loss".
        """
        encoder_output = None
        extra_losses = []
        padding_bias = None
        if not self.hparams.fast_decode:
            decode_step = None
        if "inputs" in features:
            inputs = features["inputs"]
            # Collapse to rank 3: keep the first two dims, last becomes
            # hidden_size (drops the singleton axis of [b, len, 1, hidden]).
            inputs = tf.reshape(
                inputs,
                utils.shape_list(inputs)[:2] + [self.hidden_size])
            # Padding bias only used for seq2seq models.
            padding_bias = utils.embedding_to_padding(inputs)
            # Mask random positions
            if self.hparams.input_dropout:
                shape = utils.shape_list(inputs)
                inputs = tf.where(
                    tf.random.uniform(shape) < self.hparams.input_dropout,
                    tf.zeros_like(inputs), inputs)
            if self.hparams.add_timing_signal:
                inputs += utils.get_timing_signal_1d(self.hparams.max_length,
                                                     self.hidden_size)
            if cache is not None and -1 in cache:
                # Reuse the encoder output computed on an earlier call.
                encoder_output = cache[-1]
            else:
                encoder_output = utils.transformer_encoder_layers(
                    inputs=inputs,
                    num_layers=self.num_encoder_layers,
                    hparams=self.hparams,
                    losses=extra_losses,
                    name="encoder",
                    token_bias=features.get("token_bias_inputs"),
                    padding_bias=padding_bias)
                if cache is not None:
                    cache[-1] = encoder_output
        targets = tf.to_int32(features["targets"])
        # remove the last two dimensions that are always 1.
        targets = tf.reshape(targets, utils.shape_list(targets)[:2])
        # Clamp targets to max_target_length
        targets = targets[:, :self.hparams.max_target_length]
        if self.is_decode:
            targets = self.process_partial_targets_decoding(targets)
        decoder_input = self.prepare_decoder(targets)

        decoder_output = utils.transformer_decoder_layers(
            inputs=decoder_input,
            num_layers=self.num_decoder_layers,
            hparams=self.hparams,
            encoder_output=encoder_output,
            decode_step=decode_step,
            losses=extra_losses,
            cache=cache,
            name="decoder",
            decoding_stats=decoding_stats,
            token_bias_inputs=features.get("token_bias_inputs"),
            token_bias_targets=features.get("token_bias_targets"),
            padding_bias=padding_bias)
        logits = self.produce_output(decoder_output)

        # Return logits as-is in decoding mode
        if self.is_decode:
            return logits

        # Add cross entropy loss
        one_hot_targets = tf.one_hot(tf.cast(targets, dtype=tf.int32),
                                     self.vocab_size)
        x_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_targets, logits=logits)
        # Target id 0 gets weight 0 and is excluded from the loss.
        weights = tf.to_float(tf.not_equal(targets, 0))
        weighted_x_entropy = tf.reduce_sum(x_entropy * weights)
        total_weight = tf.reduce_sum(weights)
        # Weights are 0/1, so this only differs from a plain division when
        # every target is 0, where 0 / 0 would otherwise make the loss NaN.
        loss = weighted_x_entropy / tf.maximum(total_weight, 1.0)
        if add_summary:
            tf.summary.scalar("losses/weight", total_weight)
            tf.summary.scalar("losses/x_entropy", weighted_x_entropy)

        loss_dict = {"training": loss}
        if extra_losses:
            loss_dict["extra_loss"] = tf.add_n(extra_losses)
        # hack for T2T metrics: insert two singleton axes before the last.
        logits_dims = utils.shape_list(logits)
        logits = tf.reshape(logits,
                            logits_dims[:2] + [1, 1] + logits_dims[-1:])
        return logits, loss_dict
Example #9
0
    def infer(self, features, **kwargs):
        """Autoregressively samples every pixel/channel value of an image.

        The decode length is `frame_height * frame_width * num_channels`.
        The method works in three stages:
        1. Call `self.body` once (with `decode_step=None`) so it can
           populate `cache`.
        2. Run a `tf.while_loop` that samples one position per iteration
           with `self.sample`.
        3. Pass the flat results through `utils.unflatten_blocks_nd` /
           `utils.put_back_blocks_nd` and reshape them to image layout.

        Args:
          features: dict of feature tensors. Its "targets" entry is
            replaced by a zeros placeholder while decoding. It is restored
            only if a value was present beforehand; otherwise the
            placeholder is left in place.
          **kwargs: accepted but unused here.

        Returns:
          dict with "outputs" of shape
          [-1, frame_height, frame_width, num_channels], "logits" with an
          extra trailing targets_vocab_size axis, and "scores" / "losses"
          set to None.
        """
        decode_length = (self.frame_height * self.frame_width *
                         self.num_channels)
        # Filled in by self.body / self.sample and threaded through the loop.
        cache = {}
        decoding_stats = {}
        targets_old = features.get("targets", None)
        # Zero-initialised flat buffers. Each loop step adds one position's
        # sample and logits into them via tf.scatter_nd.
        initial_output = tf.zeros((self.batch_size, decode_length),
                                  dtype=tf.int32)
        initial_logits = tf.zeros(
            (self.batch_size, decode_length, self.targets_vocab_size))
        # call body once to initialize cache with representations of input frames.
        features["targets"] = initial_output
        with tf.variable_scope("sparse_imagetransformer/body",
                               reuse=tf.AUTO_REUSE,
                               use_resource=True):
            self.body(features,
                      decode_step=None,
                      cache=cache,
                      decoding_stats=decoding_stats)

        def infer_step(i, recent_output, recent_logits, cache, decoding_stats):
            """Samples position `i` and adds it into the output/logit buffers.

            The scattered values are added to the existing buffers, so this
            relies on position `i` still being zero.
            """
            features_copy = features.copy()
            features_copy["targets"] = recent_output
            cur_sample, cur_logit = self.sample(features_copy,
                                                decode_step=i,
                                                cache=cache,
                                                decoding_stats=decoding_stats)
            pos = i
            # recent_output is rank 2 here: (batch, decode_length).
            samples = recent_output + tf.scatter_nd(
                indices=[[b, pos] for b in range(self.batch_size)],
                updates=cur_sample,
                shape=utils.shape_list(recent_output))
            logits = recent_logits + tf.scatter_nd(
                indices=[[b, pos] for b in range(self.batch_size)],
                updates=cur_logit,
                shape=utils.shape_list(recent_logits))
            return i + 1, samples, logits, cache, decoding_stats

        def while_exit_cond(i, result, logits, cache, decoding_stats):  # pylint: disable=unused-argument
            """Exit the loop if it reaches decode_length."""
            not_overflow = i < decode_length
            return not_overflow

        _, final_result, final_logits, _, decoding_stats = tf.while_loop(
            while_exit_cond,
            infer_step, [
                tf.constant(0), initial_output, initial_logits, cache,
                decoding_stats
            ],
            back_prop=False,
            parallel_iterations=1)

        original_shape = self.get_shape_for_decoder()

        blocks_per_dim = [
            s // q for s, q in zip(original_shape, self.hparams.query_shape)
        ]
        # Group results into chunks of np.prod(query_shape) positions before
        # the block helpers rearrange them.
        final_result_shape = utils.shape_list(final_result)
        final_result = tf.reshape(
            final_result,
            [final_result_shape[0], -1,
             np.prod(self.hparams.query_shape), 1])
        final_logits_shape = utils.shape_list(final_logits)
        final_logits = tf.reshape(final_logits, [
            final_logits_shape[0], -1,
            np.prod(self.hparams.query_shape), final_logits_shape[-1]
        ])
        final_result = utils.unflatten_blocks_nd(final_result, blocks_per_dim)
        final_result = utils.put_back_blocks_nd(final_result,
                                                self.hparams.query_shape)
        final_logits = utils.unflatten_blocks_nd(final_logits, blocks_per_dim)
        final_logits = utils.put_back_blocks_nd(final_logits,
                                                self.hparams.query_shape)

        # Final image layout for outputs and per-value logits.
        final_result = tf.reshape(
            final_result,
            [-1, self.frame_height, self.frame_width, self.num_channels])
        final_logits = tf.reshape(final_logits, [
            -1, self.frame_height, self.frame_width, self.num_channels,
            self.targets_vocab_size
        ])

        # Under XLA, also stash the decodes in the module-level _IMGS dict
        # (defined outside this view).
        if utils.is_xla_compiled():
            _IMGS["decodes"] = final_result

        # Report each accumulated decoding statistic averaged per position.
        for name, value in decoding_stats.items():
            tf.summary.scalar("decodes/%s" % name, value / decode_length)

        # Reassign targets back to the previous value.
        if targets_old is not None:
            features["targets"] = targets_old

        return {
            "outputs": final_result,
            "scores": None,
            "logits": final_logits,
            "losses": None,
        }