def process_partial_targets_decoding(self, targets):
  """Processes partially generated targets in decoding mode."""
  query_shape = self.hparams.query_shape
  shape = utils.shape_list(targets)
  batch, seq_length = shape[0], shape[1]
  # Number of query blocks along each (here: the single) dimension.
  block_counts = [dim // size for dim, size in zip([seq_length], query_shape)]
  # Group the flat sequence into fixed-size query blocks, then undo the
  # block ordering to recover the original flat layout.
  blocked = tf.reshape(targets, [batch, -1, np.prod(query_shape), 1])
  blocked = utils.unflatten_blocks_nd(blocked, block_counts)
  restored = utils.put_back_blocks_nd(blocked, query_shape)
  return tf.reshape(restored, [-1, seq_length])
def process_partial_targets_decoding(self, targets):
  """Processes partially generated targets in decoding mode."""
  query_shape = self.hparams.query_shape
  decoder_shape = self.get_shape_for_decoder()
  # Blocks per decoder dimension, given the query block sizes.
  block_counts = [dim // size for dim, size in zip(decoder_shape, query_shape)]
  batch = utils.shape_list(targets)[0]
  # Group the flat targets into query blocks, undo the block ordering,
  # and restore the (height, width, channels) image layout.
  blocked = tf.reshape(targets, [batch, -1, np.prod(query_shape), 1])
  blocked = utils.unflatten_blocks_nd(blocked, block_counts)
  restored = utils.put_back_blocks_nd(blocked, query_shape)
  return tf.reshape(
      restored,
      [-1, self.frame_height, self.frame_width, self.num_channels])
def infer(self, features, **kwargs):
  """Autoregressively samples `decode_length` target positions one at a time.

  Primes the decoding cache with one call to `body`, then runs a strictly
  sequential `tf.while_loop` that samples a single position per step and
  scatters it into preallocated output buffers. Afterwards the block-ordered
  (query_shape) outputs are rearranged back to flat sequence order.

  Args:
    features: feature dict; "targets" is temporarily overwritten during
      decoding and restored before returning.
    **kwargs: unused; kept for interface compatibility.

  Returns:
    Dict with "outputs" (sampled ids), "logits", and None for
    "scores"/"losses".
  """
  # Run the bottom (embedding) transform under the model's variable scope.
  with tf.variable_scope("sparse_transformer", reuse=tf.AUTO_REUSE):
    features = self.bottom(features)
  decode_length = self.hparams.max_target_length
  cache = {}
  decoding_stats = {}
  # Remember the caller's targets so they can be restored at the end.
  targets_old = features.get("targets")
  start_step = 0
  # Preallocated zero buffers that each loop step scatter-adds into.
  initial_output = tf.zeros((self.batch_size, decode_length, 1, 1),
                            dtype=tf.int32)
  initial_logits = tf.zeros(
      (self.batch_size, decode_length, self.vocab_size))

  # call body once to initialize cache with representations of input frames.
  features["targets"] = initial_output
  # Set shape of inputs
  if "inputs" in features:
    features["inputs"].set_shape([
        self.batch_size, self.hparams.max_length, 1, self.hparams.hidden_size
    ])
  with tf.variable_scope("sparse_transformer/body", reuse=tf.AUTO_REUSE):
    self.body(features,
              decode_step=None,
              cache=cache,
              decoding_stats=decoding_stats)

  def infer_step(i, recent_output, recent_logits, cache, decoding_stats):
    """Inference step: sample position `i` and write it into the buffers."""
    features_copy = features.copy()
    features_copy["targets"] = recent_output
    cur_sample, cur_logit = self.sample(features_copy,
                                        decode_step=i,
                                        cache=cache,
                                        decoding_stats=decoding_stats)
    pos = i
    # tf.scatter_nd is zero everywhere except position `pos`, so adding it
    # writes this step's sample/logit without touching earlier positions.
    samples = recent_output + tf.scatter_nd(
        indices=[[b, pos, 0, 0] for b in range(self.batch_size)],
        updates=cur_sample,
        shape=utils.shape_list(recent_output))
    logits = recent_logits + tf.scatter_nd(
        indices=[[b, pos] for b in range(self.batch_size)],
        updates=cur_logit,
        shape=utils.shape_list(recent_logits))
    return i + 1, samples, logits, cache, decoding_stats

  def while_exit_cond(i, result, logits, cache, decoding_stats):  # pylint: disable=unused-argument
    """Exit the loop if it reaches decode_length."""
    not_overflow = i < decode_length
    return not_overflow

  # parallel_iterations=1 keeps steps strictly sequential — presumably each
  # step depends on cache state mutated by the previous one.
  _, final_result, final_logits, _, decoding_stats = tf.while_loop(
      while_exit_cond, infer_step, [
          start_step, initial_output, initial_logits, cache, decoding_stats
      ],
      back_prop=False,
      parallel_iterations=1)

  # Decoding produced tokens in block (query_shape) order; regroup into
  # blocks and put them back to recover the original flat ordering.
  original_shape = [decode_length]
  blocks_per_dim = [
      s // q for s, q in zip(original_shape, self.hparams.query_shape)
  ]
  final_result_shape = utils.shape_list(final_result)
  final_result = tf.reshape(
      final_result,
      [final_result_shape[0], -1, np.prod(self.hparams.query_shape), 1])
  final_logits_shape = utils.shape_list(final_logits)
  final_logits = tf.reshape(final_logits, [
      final_logits_shape[0], -1, np.prod(self.hparams.query_shape),
      final_logits_shape[-1]
  ])
  final_result = utils.unflatten_blocks_nd(final_result, blocks_per_dim)
  final_result = utils.put_back_blocks_nd(final_result,
                                          self.hparams.query_shape)
  final_logits = utils.unflatten_blocks_nd(final_logits, blocks_per_dim)
  final_logits = utils.put_back_blocks_nd(final_logits,
                                          self.hparams.query_shape)
  # Per-step statistics are averaged over the full decode length.
  for name, value in decoding_stats.items():
    tf.summary.scalar("decodes/%s" % name, value / decode_length)
  # Reassign targets back to the previous value.
  if targets_old is not None:
    features["targets"] = targets_old
  return {
      "outputs": final_result,
      "scores": None,
      "logits": final_logits,
      "losses": None,
  }
def infer(self, features, **kwargs):
  """Autoregressively samples every pixel/channel of the target image.

  Primes the decoding cache with one call to `body`, then runs a strictly
  sequential `tf.while_loop` sampling one position per step into
  preallocated buffers. The block-ordered (query_shape) outputs are then
  rearranged back and reshaped to image layout.

  Args:
    features: feature dict; "targets" is temporarily overwritten during
      decoding and restored before returning.
    **kwargs: unused; kept for interface compatibility.

  Returns:
    Dict with "outputs" (decoded images), "logits", and None for
    "scores"/"losses".
  """
  # One decode step per (height, width, channel) element.
  decode_length = (self.frame_height * self.frame_width * self.num_channels)
  cache = {}
  decoding_stats = {}
  # Remember the caller's targets so they can be restored at the end.
  targets_old = features.get("targets", None)
  # Preallocated zero buffers that each loop step scatter-adds into.
  initial_output = tf.zeros((self.batch_size, decode_length), dtype=tf.int32)
  initial_logits = tf.zeros(
      (self.batch_size, decode_length, self.targets_vocab_size))

  # call body once to initialize cache with representations of input frames.
  features["targets"] = initial_output
  with tf.variable_scope("sparse_imagetransformer/body",
                         reuse=tf.AUTO_REUSE,
                         use_resource=True):
    self.body(features,
              decode_step=None,
              cache=cache,
              decoding_stats=decoding_stats)

  def infer_step(i, recent_output, recent_logits, cache, decoding_stats):
    """Inference step: sample position `i` and write it into the buffers."""
    features_copy = features.copy()
    features_copy["targets"] = recent_output
    cur_sample, cur_logit = self.sample(features_copy,
                                        decode_step=i,
                                        cache=cache,
                                        decoding_stats=decoding_stats)
    pos = i
    # tf.scatter_nd is zero everywhere except position `pos`, so adding it
    # writes this step's sample/logit without touching earlier positions.
    samples = recent_output + tf.scatter_nd(
        indices=[[b, pos] for b in range(self.batch_size)],
        updates=cur_sample,
        shape=utils.shape_list(recent_output))
    logits = recent_logits + tf.scatter_nd(
        indices=[[b, pos] for b in range(self.batch_size)],
        updates=cur_logit,
        shape=utils.shape_list(recent_logits))
    return i + 1, samples, logits, cache, decoding_stats

  def while_exit_cond(i, result, logits, cache, decoding_stats):  # pylint: disable=unused-argument
    """Exit the loop if it reaches decode_length."""
    not_overflow = i < decode_length
    return not_overflow

  # parallel_iterations=1 keeps steps strictly sequential — presumably each
  # step depends on cache state mutated by the previous one.
  _, final_result, final_logits, _, decoding_stats = tf.while_loop(
      while_exit_cond, infer_step, [
          tf.constant(0), initial_output, initial_logits, cache,
          decoding_stats
      ],
      back_prop=False,
      parallel_iterations=1)

  # Decoding produced values in block (query_shape) order; regroup into
  # blocks and put them back to recover the original ordering.
  original_shape = self.get_shape_for_decoder()
  blocks_per_dim = [
      s // q for s, q in zip(original_shape, self.hparams.query_shape)
  ]
  final_result_shape = utils.shape_list(final_result)
  final_result = tf.reshape(
      final_result,
      [final_result_shape[0], -1, np.prod(self.hparams.query_shape), 1])
  final_logits_shape = utils.shape_list(final_logits)
  final_logits = tf.reshape(final_logits, [
      final_logits_shape[0], -1, np.prod(self.hparams.query_shape),
      final_logits_shape[-1]
  ])
  final_result = utils.unflatten_blocks_nd(final_result, blocks_per_dim)
  final_result = utils.put_back_blocks_nd(final_result,
                                          self.hparams.query_shape)
  final_logits = utils.unflatten_blocks_nd(final_logits, blocks_per_dim)
  final_logits = utils.put_back_blocks_nd(final_logits,
                                          self.hparams.query_shape)
  # Restore image layout: (batch, height, width, channels[, vocab]).
  final_result = tf.reshape(
      final_result,
      [-1, self.frame_height, self.frame_width, self.num_channels])
  final_logits = tf.reshape(final_logits, [
      -1, self.frame_height, self.frame_width, self.num_channels,
      self.targets_vocab_size
  ])
  # Stash decoded images for summaries under XLA — _IMGS is presumably a
  # module-level dict used by the summary machinery; not visible here.
  if utils.is_xla_compiled():
    _IMGS["decodes"] = final_result
  # Per-step statistics are averaged over the full decode length.
  for name, value in decoding_stats.items():
    tf.summary.scalar("decodes/%s" % name, value / decode_length)
  # Reassign targets back to the previous value.
  if targets_old is not None:
    features["targets"] = targets_old
  return {
      "outputs": final_result,
      "scores": None,
      "logits": final_logits,
      "losses": None,
  }