def process_partial_targets_decoding(self, targets):
    """Rearranges partially decoded targets using the block utilities.

    The targets are reshaped to `[batch, -1, prod(hparams.query_shape), 1]`,
    passed through `utils.unflatten_blocks_nd` and then
    `utils.put_back_blocks_nd`, and finally reshaped to `[-1, seq_length]`,
    where `seq_length` is dimension 1 of the incoming tensor.

    Args:
      targets: tensor of partially generated targets. Presumably stored in
        block order along dimension 1 -- TODO confirm against the caller.

    Returns:
      The rearranged targets, reshaped to `[-1, seq_length]`.
    """
    dims = utils.shape_list(targets)
    length = dims[1]
    query_shape = self.hparams.query_shape
    # Number of query blocks along the single sequence dimension.
    num_blocks = [size // block for size, block in zip([length], query_shape)]

    blocked = tf.reshape(targets, [dims[0], -1, np.prod(query_shape), 1])
    blocked = utils.unflatten_blocks_nd(blocked, num_blocks)
    restored = utils.put_back_blocks_nd(blocked, query_shape)
    return tf.reshape(restored, [-1, length])
  def process_partial_targets_decoding(self, targets):
    """Rearranges partially decoded targets into an image-shaped tensor.

    The targets are reshaped to `[batch, -1, prod(hparams.query_shape), 1]`,
    passed through `utils.unflatten_blocks_nd` and then
    `utils.put_back_blocks_nd`, and finally reshaped to
    `[-1, frame_height, frame_width, num_channels]`.

    Args:
      targets: tensor of partially generated targets. Presumably stored in
        block order -- TODO confirm against the caller.

    Returns:
      The rearranged targets, reshaped to
      `[-1, frame_height, frame_width, num_channels]`.
    """
    decoder_shape = self.get_shape_for_decoder()
    query_shape = self.hparams.query_shape
    # Number of query blocks along each dimension of the decoder shape.
    num_blocks = [
        size // block for size, block in zip(decoder_shape, query_shape)
    ]
    dims = utils.shape_list(targets)

    blocked = tf.reshape(targets, [dims[0], -1, np.prod(query_shape), 1])
    blocked = utils.unflatten_blocks_nd(blocked, num_blocks)
    restored = utils.put_back_blocks_nd(blocked, query_shape)
    return tf.reshape(
        restored,
        [-1, self.frame_height, self.frame_width, self.num_channels])
    def infer(self, features, **kwargs):
        """Decodes `hparams.max_target_length` positions one step at a time.

        The features are first passed through `self.bottom`, then `self.body`
        is called once with all-zero targets and `decode_step=None` so the
        shared `cache` is populated. A `tf.while_loop` then calls
        `self.sample` once per position and scatters each sample and its
        logits into running output/logit tensors. The finished tensors are
        re-laid-out with `utils.unflatten_blocks_nd` and
        `utils.put_back_blocks_nd`.

        Args:
          features: dict of input features. Its "targets" entry is
            overwritten while decoding and restored afterwards if it was
            present on entry.
          **kwargs: accepted for interface compatibility; not read here.

        Returns:
          A dict with "outputs" (decoded samples), "logits" (their logits),
          and "scores" / "losses", which are both None.
        """
        with tf.variable_scope("sparse_transformer", reuse=tf.AUTO_REUSE):
            features = self.bottom(features)
        num_steps = self.hparams.max_target_length
        cache = {}
        decoding_stats = {}
        saved_targets = features.get("targets")
        first_step = 0
        zero_output = tf.zeros(
            (self.batch_size, num_steps, 1, 1), dtype=tf.int32)
        zero_logits = tf.zeros(
            (self.batch_size, num_steps, self.vocab_size))

        # Run the body once on the all-zero targets so that the cache is
        # initialized before step-wise decoding starts.
        features["targets"] = zero_output
        # Pin the static shape of the inputs when they are provided.
        if "inputs" in features:
            features["inputs"].set_shape([
                self.batch_size,
                self.hparams.max_length,
                1,
                self.hparams.hidden_size,
            ])
        with tf.variable_scope("sparse_transformer/body", reuse=tf.AUTO_REUSE):
            self.body(
                features,
                decode_step=None,
                cache=cache,
                decoding_stats=decoding_stats)

        def decode_one(step, outputs_so_far, logits_so_far, cache,
                       decoding_stats):
            """Samples position `step` and writes it into the running tensors."""
            step_features = features.copy()
            step_features["targets"] = outputs_so_far
            sample, logit = self.sample(
                step_features,
                decode_step=step,
                cache=cache,
                decoding_stats=decoding_stats)
            # The running tensors are updated by adding a tensor that is
            # zero everywhere except at position `step`.
            sample_update = tf.scatter_nd(
                indices=[[b, step, 0, 0] for b in range(self.batch_size)],
                updates=sample,
                shape=utils.shape_list(outputs_so_far))
            new_outputs = outputs_so_far + sample_update
            logit_update = tf.scatter_nd(
                indices=[[b, step] for b in range(self.batch_size)],
                updates=logit,
                shape=utils.shape_list(logits_so_far))
            new_logits = logits_so_far + logit_update
            return step + 1, new_outputs, new_logits, cache, decoding_stats

        def still_decoding(step, unused_outputs, unused_logits, unused_cache,
                           unused_stats):
            """Keeps looping while `step` is below the decode length."""
            return step < num_steps

        loop_vars = [
            first_step, zero_output, zero_logits, cache, decoding_stats
        ]
        _, final_result, final_logits, _, decoding_stats = tf.while_loop(
            still_decoding,
            decode_one,
            loop_vars,
            back_prop=False,
            parallel_iterations=1)

        query_shape = self.hparams.query_shape
        # Number of query blocks along the single sequence dimension.
        num_blocks = [
            size // block for size, block in zip([num_steps], query_shape)
        ]
        result_dims = utils.shape_list(final_result)
        final_result = tf.reshape(
            final_result, [result_dims[0], -1, np.prod(query_shape), 1])
        logit_dims = utils.shape_list(final_logits)
        final_logits = tf.reshape(
            final_logits,
            [logit_dims[0], -1, np.prod(query_shape), logit_dims[-1]])
        final_result = utils.unflatten_blocks_nd(final_result, num_blocks)
        final_result = utils.put_back_blocks_nd(final_result, query_shape)
        final_logits = utils.unflatten_blocks_nd(final_logits, num_blocks)
        final_logits = utils.put_back_blocks_nd(final_logits, query_shape)

        # Report each decoding statistic divided by the number of steps.
        for stat_name, stat_value in decoding_stats.items():
            tf.summary.scalar("decodes/%s" % stat_name, stat_value / num_steps)

        # Put the caller's targets back if there were any on entry.
        if saved_targets is not None:
            features["targets"] = saved_targets

        return {
            "outputs": final_result,
            "scores": None,
            "logits": final_logits,
            "losses": None,
        }
# Example #4
# 0
    def infer(self, features, **kwargs):
        """Decodes `frame_height * frame_width * num_channels` positions.

        `self.body` is called once with all-zero targets and
        `decode_step=None` so the shared `cache` is populated. A
        `tf.while_loop` then calls `self.sample` once per position and
        scatters each sample and its logits into running output/logit
        tensors. The finished tensors are re-laid-out with
        `utils.unflatten_blocks_nd` and `utils.put_back_blocks_nd` and
        reshaped to image form.

        Args:
          features: dict of input features. Its "targets" entry is
            overwritten while decoding and restored afterwards if it was
            present on entry.
          **kwargs: accepted for interface compatibility; not read here.

        Returns:
          A dict with "outputs" reshaped to
          `[-1, frame_height, frame_width, num_channels]`, "logits" reshaped
          to the same leading dimensions plus `targets_vocab_size`, and
          "scores" / "losses", which are both None.
        """
        num_steps = self.frame_height * self.frame_width * self.num_channels
        cache = {}
        decoding_stats = {}
        saved_targets = features.get("targets")
        zero_output = tf.zeros((self.batch_size, num_steps), dtype=tf.int32)
        zero_logits = tf.zeros(
            (self.batch_size, num_steps, self.targets_vocab_size))
        # Run the body once on the all-zero targets so that the cache is
        # initialized before step-wise decoding starts.
        features["targets"] = zero_output
        with tf.variable_scope(
                "sparse_imagetransformer/body",
                reuse=tf.AUTO_REUSE,
                use_resource=True):
            self.body(
                features,
                decode_step=None,
                cache=cache,
                decoding_stats=decoding_stats)

        def decode_one(step, outputs_so_far, logits_so_far, cache,
                       decoding_stats):
            """Samples position `step` and writes it into the running tensors."""
            step_features = features.copy()
            step_features["targets"] = outputs_so_far
            sample, logit = self.sample(
                step_features,
                decode_step=step,
                cache=cache,
                decoding_stats=decoding_stats)
            # The running tensors are updated by adding a tensor that is
            # zero everywhere except at position `step`.
            sample_update = tf.scatter_nd(
                indices=[[b, step] for b in range(self.batch_size)],
                updates=sample,
                shape=utils.shape_list(outputs_so_far))
            new_outputs = outputs_so_far + sample_update
            logit_update = tf.scatter_nd(
                indices=[[b, step] for b in range(self.batch_size)],
                updates=logit,
                shape=utils.shape_list(logits_so_far))
            new_logits = logits_so_far + logit_update
            return step + 1, new_outputs, new_logits, cache, decoding_stats

        def still_decoding(step, unused_outputs, unused_logits, unused_cache,
                           unused_stats):
            """Keeps looping while `step` is below the decode length."""
            return step < num_steps

        loop_vars = [
            tf.constant(0), zero_output, zero_logits, cache, decoding_stats
        ]
        _, final_result, final_logits, _, decoding_stats = tf.while_loop(
            still_decoding,
            decode_one,
            loop_vars,
            back_prop=False,
            parallel_iterations=1)

        decoder_shape = self.get_shape_for_decoder()
        query_shape = self.hparams.query_shape
        # Number of query blocks along each dimension of the decoder shape.
        num_blocks = [
            size // block for size, block in zip(decoder_shape, query_shape)
        ]
        result_dims = utils.shape_list(final_result)
        final_result = tf.reshape(
            final_result, [result_dims[0], -1, np.prod(query_shape), 1])
        logit_dims = utils.shape_list(final_logits)
        final_logits = tf.reshape(
            final_logits,
            [logit_dims[0], -1, np.prod(query_shape), logit_dims[-1]])
        final_result = utils.unflatten_blocks_nd(final_result, num_blocks)
        final_result = utils.put_back_blocks_nd(final_result, query_shape)
        final_logits = utils.unflatten_blocks_nd(final_logits, num_blocks)
        final_logits = utils.put_back_blocks_nd(final_logits, query_shape)

        # Bring both tensors into image layout.
        final_result = tf.reshape(
            final_result,
            [-1, self.frame_height, self.frame_width, self.num_channels])
        final_logits = tf.reshape(
            final_logits,
            [
                -1,
                self.frame_height,
                self.frame_width,
                self.num_channels,
                self.targets_vocab_size,
            ])

        # Under XLA compilation the decodes are stashed in the module-level
        # _IMGS dict.
        if utils.is_xla_compiled():
            _IMGS["decodes"] = final_result

        # Report each decoding statistic divided by the number of steps.
        for stat_name, stat_value in decoding_stats.items():
            tf.summary.scalar("decodes/%s" % stat_name, stat_value / num_steps)

        # Put the caller's targets back if there were any on entry.
        if saved_targets is not None:
            features["targets"] = saved_targets

        return {
            "outputs": final_result,
            "scores": None,
            "logits": final_logits,
            "losses": None,
        }