def body(self, features):
  """Body of the model.

  Args:
    features: a dictionary with the tensors.

  Returns:
    A pair (predictions, losses) where predictions is the generated image
    and losses is a dictionary of losses (that get added for the final loss).
  """
  features["targets"] = features["inputs"]
  is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN

  # Input images.
  inputs = tf.to_float(features["targets_raw"])

  # Noise vector.
  z = tf.random_uniform([self.hparams.batch_size,
                         self.hparams.bottleneck_bits],
                        minval=-1, maxval=1, name="z")

  # Generator output: fake images.
  out_shape = common_layers.shape_list(inputs)[1:4]
  g = self.generator(z, is_training, out_shape)

  losses = self.losses(inputs, g)  # pylint: disable=not-callable

  summary_g_image = tf.reshape(
      g[0, :], [1] + common_layers.shape_list(inputs)[1:])
  tf.summary.image("generated", summary_g_image, max_outputs=1)

  if is_training:  # Returns a dummy output and the losses dictionary.
    return tf.zeros_like(inputs), losses
  return tf.reshape(g, tf.shape(inputs)), losses
def infer(self, features, *args, **kwargs):
  """Produce predictions from the model by running it."""
  del args, kwargs
  if "targets" not in features:
    if "infer_targets" in features:
      targets_shape = common_layers.shape_list(features["infer_targets"])
    elif "inputs" in features:
      targets_shape = common_layers.shape_list(features["inputs"])
      targets_shape[1] = self.hparams.video_num_target_frames
    else:
      raise ValueError("no inputs are given.")
    features["targets"] = tf.zeros(targets_shape, dtype=tf.float32)

  output, _ = self(features)  # pylint: disable=not-callable

  if not isinstance(output, dict):
    output = {"targets": output}

  x = output["targets"]
  if self.is_per_pixel_softmax:
    x_shape = common_layers.shape_list(x)
    x = tf.reshape(x, [-1, x_shape[-1]])
    x = tf.argmax(x, axis=-1)
    x = tf.reshape(x, x_shape[:-1])
  else:
    x = tf.squeeze(x, axis=-1)
    x = tf.to_int64(tf.round(x))
  output["targets"] = x
  if self.hparams.reward_prediction:
    output["target_reward"] = tf.argmax(output["target_reward"], axis=-1)

  # Only required for decoding.
  output["outputs"] = output["targets"]
  output["scores"] = output["targets"]
  return output
def infer(self,
          features=None,
          decode_length=50,
          beam_size=1,
          top_beams=1,
          alpha=0.0,
          use_tpu=False):
  """Produce predictions from the model."""
  if not features:
    features = {}
  inputs_old = None
  if "inputs" in features and len(features["inputs"].shape) < 4:
    inputs_old = features["inputs"]
    features["inputs"] = tf.expand_dims(features["inputs"], 2)

  # Create an initial targets tensor.
  if "partial_targets" in features:
    initial_output = tf.convert_to_tensor(features["partial_targets"])
  else:
    batch_size = common_layers.shape_list(features["inputs"])[0]
    length = common_layers.shape_list(features["inputs"])[1]
    target_length = tf.to_int32(2.0 * tf.to_float(length))
    initial_output = tf.zeros((batch_size, target_length, 1, 1),
                              dtype=tf.int64)

  features["targets"] = initial_output
  logits, _ = self(features)  # pylint: disable=not-callable
  samples = tf.argmax(logits, axis=-1)

  if inputs_old is not None:  # Restore to not confuse Estimator.
    features["inputs"] = inputs_old
  return samples
def bottleneck(self, x):
  hparams = self.hparams
  b, _ = super(AutoencoderDualDiscrete, self).bottleneck(x)
  if hparams.mode == tf.estimator.ModeKeys.EVAL:
    return b, 0.0
  bt, bi = tf.split(b, 2, axis=0)
  if self.hparams.mode != tf.estimator.ModeKeys.TRAIN:
    return tf.concat([bi, bi], axis=0), 0.0
  # Share the first hparams.bottleneck_shared_bits.
  shared = (bt + bi) / 2  # -1 if both -1, 1 if both were 1, 0 if disagree.
  rand = tf.random_uniform(common_layers.shape_list(bt))
  br = tf.where(rand < 0.5, bt, bi)  # Break ties at random.
  bs = tf.where(tf.equal(shared, 0.0), br, shared)
  bs = tf.concat([bs, bs], axis=0)
  n = hparams.bottleneck_shared_bits
  step = tf.train.get_global_step()
  zero = tf.constant(0, dtype=tf.int64)
  if step is None:
    step = zero
  step = tf.maximum(zero, step - hparams.bottleneck_shared_bits_start_warmup)
  f = common_layers.inverse_lin_decay(
      hparams.bottleneck_shared_bits_stop_warmup, min_value=0.1, step=step)
  n = tf.where(step > 1, n * f, n)
  n = tf.cast(n, tf.int64)
  b_shape = common_layers.shape_list(b)
  b = tf.concat([bs[..., :n], b[..., n:]], axis=-1)
  b = tf.reshape(b, b_shape)
  return b, 0.0
def padded_sequence_accuracy(predictions,
                             labels,
                             weights_fn=common_layers.weights_nonzero):
  """Percentage of times that predictions matches labels everywhere (non-0)."""
  # If the last dimension is 1 then we're using L1/L2 loss.
  if common_layers.shape_list(predictions)[-1] == 1:
    return rounding_sequence_accuracy(
        predictions, labels, weights_fn=weights_fn)
  with tf.variable_scope(
      "padded_sequence_accuracy", values=[predictions, labels]):
    padded_predictions, padded_labels = common_layers.pad_with_zeros(
        predictions, labels)
    weights = weights_fn(padded_labels)

    # Flatten, keeping batch dim (and num_classes dim for predictions).
    # TPU argmax can only deal with a limited number of dimensions.
    predictions_shape = common_layers.shape_list(padded_predictions)
    batch_size = predictions_shape[0]
    num_classes = predictions_shape[-1]
    flat_size = common_layers.list_product(
        common_layers.shape_list(padded_labels)[1:])
    padded_predictions = tf.reshape(
        padded_predictions,
        [batch_size, common_layers.list_product(predictions_shape[1:-1]),
         num_classes])
    padded_labels = tf.reshape(padded_labels, [batch_size, flat_size])
    weights = tf.reshape(weights, [batch_size, flat_size])

    outputs = tf.to_int32(tf.argmax(padded_predictions, axis=-1))
    padded_labels = tf.to_int32(padded_labels)
    not_correct = tf.to_float(tf.not_equal(outputs, padded_labels)) * weights
    axis = list(range(1, len(outputs.get_shape())))
    correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis))
    return correct_seq, tf.constant(1.0)
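# Usage sketch (illustrative, not part of the original library): the metric
# scores 1.0 for a sequence only when every non-padding position is predicted
# correctly. Shapes follow the T2T convention of 5-D logits and 4-D labels.
def _example_padded_sequence_accuracy():
  logits = tf.one_hot([[3, 1, 4, 2]], depth=10)  # [1, 4, 10]
  logits = tf.reshape(logits, [1, 4, 1, 1, 10])  # [batch, length, 1, 1, C]
  labels = tf.reshape(tf.constant([[3, 1, 4, 2]]), [1, 4, 1, 1])
  correct_seq, _ = padded_sequence_accuracy(logits, labels)
  return correct_seq  # -> [1.0] for this example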
def dae(x, hparams, name):
  with tf.variable_scope(name):
    m = tf.layers.dense(x, hparams.v_size, name="mask")
    if hparams.softmax_k > 0:
      m, kl = top_k_softmax(m, hparams.softmax_k)
      return m, m, 1.0 - tf.reduce_mean(kl)
    logsm = tf.nn.log_softmax(m)

    # Gumbel-softmax sample.
    gumbel_samples = gumbel_sample(common_layers.shape_list(m))
    steps = hparams.kl_warmup_steps
    gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5
    temperature = 1.2 - common_layers.inverse_lin_decay(steps)

    # 10% of the time keep reasonably high temperature to keep learning.
    temperature = tf.cond(tf.less(tf.random_uniform([]), 0.9),
                          lambda: temperature,
                          lambda: tf.random_uniform([], minval=0.5,
                                                    maxval=1.0))
    s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
    m = tf.nn.softmax(m)
    kl = - tf.reduce_max(logsm, axis=-1)
    if _DO_SUMMARIES:
      tf.summary.histogram("max-log", tf.reshape(kl, [-1]))

    # Calculate the argmax and construct hot vectors.
    maxvec = tf.reshape(tf.argmax(m, axis=-1), [-1])
    maxvhot = tf.stop_gradient(tf.one_hot(maxvec, hparams.v_size))

    # Add losses that prevent too few being used.
    distrib = tf.reshape(logsm, [-1, hparams.v_size]) * maxvhot
    d_mean = tf.reduce_mean(distrib, axis=[0], keep_dims=True)
    d_variance = tf.reduce_mean(tf.square(distrib - d_mean), axis=[0])
    d_dev = - tf.reduce_mean(d_variance)
    ret = s

    if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN:
      ret = tf.reshape(maxvhot, common_layers.shape_list(s))  # Just hot @eval.
    return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002
def embed(self, x):
  """Embedding function that takes discrete latent and returns embedding.

  Args:
    x: Input to the discretization bottleneck.

  Returns:
    Continuous embedding to be passed on to the decoder.

  Raises:
    ValueError: For unknown or missing arguments.
  """
  shape_x = common_layers.shape_list(x)
  x_flat = tf.reshape(x, [-1, 1])
  c = self.int_to_bit(x_flat, num_bits=self.hparams.z_size, base=2)
  shape = common_layers.shape_list(c)
  new_shape = shape
  new_shape.append(self.hparams.num_blocks)
  new_shape.append(int(self.hparams.z_size / self.hparams.num_blocks))
  c = tf.to_int32(tf.reshape(c, shape=new_shape))
  h1_shape = shape_x
  h1_shape.append(self.hparams.hidden_size)
  h1 = tf.zeros(dtype=tf.float32, shape=h1_shape)
  c_int = self.bit_to_int(
      c, num_bits=int(self.hparams.z_size / self.hparams.num_blocks), base=2)
  c_hot = tf.one_hot(c_int, depth=self.hparams.block_v_size, axis=-1)
  c_hot_flat = tf.reshape(
      c_hot, shape=[-1, self.hparams.num_blocks, self.hparams.block_v_size])
  h1 = tf.matmul(tf.transpose(c_hot_flat, perm=[1, 0, 2]), self.means)
  h1 = tf.transpose(h1, perm=[1, 0, 2])
  h1 = tf.reshape(h1, shape=h1_shape)
  h1_shape[0] = self.hparams.batch_size
  h2 = tf.layers.dense(tf.nn.relu(h1), self.hparams.filter_size, name="vch2")
  res = tf.layers.dense(
      tf.nn.relu(h2), self.hparams.hidden_size, name="vcfin")
  return res
def symbols_to_logits_fn(ids):
  """Go from ids to logits."""
  ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
  ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]])
  if "partial_targets" in features:
    pt = features["partial_targets"]
    pt_length = common_layers.shape_list(pt)[1]
    pt = tf.tile(pt, [1, beam_size])
    pt = tf.reshape(pt, [batch_size * beam_size, pt_length, 1, 1])
    ids = tf.concat([pt, ids], axis=1)

  features["targets"] = ids
  self._coverage = None
  logits, _ = self(features)  # pylint: disable=not-callable
  # Now self._coverage is a coverage tensor for the first datashard.
  # It has shape [batch_size] and contains floats between 0 and
  # source_length.
  if self._problem_hparams:
    modality = self._problem_hparams.target_modality
    if modality.top_is_pointwise:
      return tf.squeeze(logits, axis=[1, 2, 3])
  # -1 due to the pad above.
  current_output_position = common_layers.shape_list(ids)[1] - 1
  logits = logits[:, current_output_position, :, :]
  return tf.squeeze(logits, axis=[1, 2])
def postprocess_image(x, rows, cols, hparams):
  """Postprocessing after decoding."""
  batch = common_layers.shape_list(x)[0]
  channels = 256
  x = tf.reshape(x, [batch, rows, cols, hparams.hidden_size])
  # targets = common_layers.conv(x, 256, (1, 1), name="output_conv")
  targets = tf.layers.dense(x, 256, use_bias=True, activation=None,
                            name="output_conv")
  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
    y = targets
    y = tf.reshape(y, [batch, -1, hparams.img_len * 3, channels])
    yshape = common_layers.shape_list(y)
    block_length = hparams.query_shape[0]
    block_width = hparams.query_shape[1]

    # Break into blocks row-wise.
    y = tf.reshape(y,
                   [batch, yshape[1] // block_length, block_length,
                    yshape[2], channels])
    yshape = common_layers.shape_list(y)

    # Break into blocks width-wise.
    y_blocks = tf.reshape(y,
                          [batch, yshape[1], yshape[2],
                           yshape[3] // block_width, block_width, channels])

    # Reshape targets as [batch_size, num_blocks_rows, num_block_cols,
    # block_length, block_width, channels].
    targets = tf.transpose(y_blocks, [0, 1, 3, 2, 4, 5])

  return targets
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters.
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack.
    decoder_self_attention_bias: a bias tensor for use in decoder
      self-attention.
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
def create_output(decoder_output, rows, cols, targets, hparams):
  """Creates output from decoder output and vars.

  Args:
    decoder_output: Tensor of shape [batch, ...], where ... can be any rank
      such that the number of elements is batch * rows * cols *
      hparams.hidden_size.
    rows: Integer representing number of rows in a 2-D data point.
    cols: Integer representing number of columns in a 2-D data point.
    targets: Tensor of shape [batch, hparams.img_len, hparams.img_len,
      hparams.num_channels].
    hparams: tf.contrib.training.HParams set.

  Returns:
    Tensor of shape [batch, hparams.img_len, hparams.img_len,
    hparams.num_mixtures * 10] if hparams.likelihood is DMOL, otherwise
    [batch, hparams.img_len, hparams.img_len, hparams.num_channels, 256].
    In the special case of predict mode, it is a Tensor of rank 5.
  """
  decoded_image = postprocess_image(decoder_output, rows, cols, hparams)
  depth = common_layers.shape_list(decoded_image)[-1]
  batch, height, width, channels = common_layers.shape_list(targets)
  likelihood = getattr(hparams, "likelihood", DistributionType.CAT)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    y = tf.reshape(decoded_image, [batch, -1, 1, 1, depth])
    output = y[:, :height, :, :, :]
  elif likelihood == DistributionType.CAT:
    # Unpack the cols dimension of the Categorical.
    output = tf.reshape(decoded_image,
                        [batch, height, width, channels, depth])
  else:
    output = decoded_image
  return output
def vq_discrete_unbottleneck(x, hidden_size):
  """Simple undiscretization from vector quantized representation."""
  x_shape = common_layers.shape_list(x)
  x = tf.to_float(x)
  bottleneck_size = common_layers.shape_list(x)[-1]
  means, _, _ = get_vq_bottleneck(bottleneck_size, hidden_size)
  result = tf.matmul(tf.reshape(x, [-1, x_shape[-1]]), means)
  return tf.reshape(result, x_shape[:-1] + [hidden_size])
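# Usage sketch (illustrative): a one-hot code selects one row of the shared
# codebook means, recovering a dense [..., hidden_size] vector. Variable
# creation/reuse inside get_vq_bottleneck is assumed to be handled by the
# enclosing variable scope.
def _example_vq_discrete_unbottleneck():
  codes = tf.one_hot([[5], [9]], depth=16)  # [2, 1, 16] one-hot over 16 codes
  with tf.variable_scope("vq_example", reuse=tf.AUTO_REUSE):
    return vq_discrete_unbottleneck(codes, hidden_size=32)  # -> [2, 1, 32]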
def prepare_decoder(targets, hparams):
  """Prepare decoder for images."""
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  curr_infer_length = None

  # During training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1].
  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
    curr_infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      assert hparams.img_len * channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      total_block_width = hparams.img_len * channels
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan.
      # To make that compatible with our inference pipeline, pad the target
      # so that rows is a multiple of query_shape and columns is a multiple
      # of hparams.img_len * channels.
      curr_infer_length = targets_shape[1]
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])

      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks.
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # Add padding to make sure the size of targets is a multiple of
      # img_height times the number of channels. This is needed for
      # positional encodings and for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])

  # Preprocess image.
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  if (hparams.dec_attention_type == AttentionType.LOCAL_2D or
      hparams.dec_attention_type == AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Add position signals.
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1] * x_shape[2], hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1], x_shape[2], hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  x = common_layers.cast_like(x, targets)
  return x, x_shape[1], x_shape[2]
def logits_to_samples(logits):
  """Get samples from logits."""
  # If the last dimension is 1 then we're using L1/L2 loss.
  if common_layers.shape_list(logits)[-1] == 1:
    return tf.to_int32(tf.squeeze(logits, axis=-1))
  # Argmax in TF doesn't handle more than 5 dimensions yet.
  logits_shape = common_layers.shape_list(logits)
  argmax = tf.argmax(tf.reshape(logits, [-1, logits_shape[-1]]), axis=-1)
  return tf.reshape(argmax, logits_shape[:-1])
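# Usage sketch (illustrative): logits of any rank are flattened to 2-D before
# the argmax (TF's argmax historically supported at most 5 dimensions), then
# reshaped back without the class axis.
def _example_logits_to_samples():
  logits = tf.random_normal([2, 3, 8, 8, 3, 256])  # [b, t, h, w, c, classes]
  return logits_to_samples(logits)  # -> shape [2, 3, 8, 8, 3], dtype int64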
def loss(self, top_out, targets):
  predictions = top_out
  if (len(common_layers.shape_list(top_out)) !=
      len(common_layers.shape_list(targets))):
    predictions = tf.squeeze(top_out, axis=[-1])
  with tf.name_scope("l2"):
    weights = self.targets_weights_fn(targets)
    l2 = tf.pow(predictions - targets, 2)
    return tf.reduce_sum(l2 * weights), tf.reduce_sum(weights)
def loss(self, top_out, targets):
  predictions = top_out
  if (len(common_layers.shape_list(top_out)) !=
      len(common_layers.shape_list(targets))):
    predictions = tf.squeeze(top_out, axis=[-1])
  with tf.name_scope("log_possion"):
    weights = self.targets_weights_fn(targets)
    lp_loss = tf.nn.log_poisson_loss(targets, predictions)
    return tf.reduce_sum(lp_loss * weights), tf.reduce_sum(weights)
def construct_model(self, images, actions, rewards):
  images = tf.unstack(images, axis=0)
  actions = tf.unstack(actions, axis=0)
  rewards = tf.unstack(rewards, axis=0)

  batch_size = common_layers.shape_list(images[0])[0]
  context_frames = self.hparams.video_num_input_frames

  # Predicted images and rewards.
  gen_rewards, gen_images, latent_means, latent_stds = [], [], [], []

  # LSTM states.
  lstm_state = [None] * 7

  # Create scheduled sampling function.
  ss_func = self.get_scheduled_sample_func(batch_size)

  pred_image = tf.zeros_like(images[0])
  pred_reward = tf.zeros_like(rewards[0])
  latent = None
  for timestep, image, action, reward in zip(
      range(len(images) - 1), images[:-1], actions[:-1], rewards[:-1]):
    # Scheduled sampling.
    done_warm_start = timestep > context_frames - 1
    groundtruth_items = [image, reward]
    generated_items = [pred_image, pred_reward]
    input_image, input_reward = self.get_scheduled_sample_inputs(
        done_warm_start, groundtruth_items, generated_items, ss_func)

    # Latent.
    # TODO(mbz): should we use input_image instead of image?
    latent_images = tf.stack([image, images[timestep + 1]], axis=0)
    latent_mean, latent_std = self.construct_latent_tower(
        latent_images, time_axis=0)
    latent = common_video.get_gaussian_tensor(latent_mean, latent_std)
    latent_means.append(latent_mean)
    latent_stds.append(latent_std)

    # Prediction.
    pred_image, lstm_state, _ = self.construct_predictive_tower(
        input_image, input_reward, action, lstm_state, latent)

    if self.hparams.reward_prediction:
      pred_reward = self.reward_prediction(
          pred_image, input_reward, action, latent)
      pred_reward = common_video.decode_to_shape(
          pred_reward, common_layers.shape_list(input_reward), "reward_dec")
    else:
      pred_reward = input_reward

    gen_images.append(pred_image)
    gen_rewards.append(pred_reward)

  gen_images = tf.stack(gen_images, axis=0)
  gen_rewards = tf.stack(gen_rewards, axis=0)

  return gen_images, gen_rewards, latent_means, latent_stds
def postprocess_image(x, rows, cols, hparams):
  """Postprocessing after decoding.

  Args:
    x: Tensor of shape [batch, ...], where ... can be any rank such that the
      number of elements in x is batch * rows * cols * hparams.hidden_size.
    rows: Integer representing number of rows in a 2-D data point.
    cols: Integer representing number of columns in a 2-D data point.
    hparams: tf.contrib.training.HParams set.

  Returns:
    Tensor of shape [batch, rows, cols, depth], where depth is
    hparams.num_mixtures * 10 if hparams.likelihood is DMOL, otherwise 256.
    In the special case of inference and block raster scan order, it is a
    Tensor of shape [batch, num_blocks_rows, num_block_cols, block_length,
    block_width, depth].
  """
  batch = common_layers.shape_list(x)[0]
  x = tf.reshape(x, [batch, rows, cols, hparams.hidden_size])
  likelihood = getattr(hparams, "likelihood", DistributionType.CAT)
  if likelihood == DistributionType.DMOL:
    depth = hparams.num_mixtures * 10
    targets = tf.layers.dense(x, depth, use_bias=False, activation=None,
                              name="output_conv")
  else:
    depth = 256
    targets = tf.layers.dense(x, depth, use_bias=True, activation=None,
                              name="output_conv")
  if (hparams.mode == tf.contrib.learn.ModeKeys.INFER and
      hparams.block_raster_scan):
    y = targets
    yshape = common_layers.shape_list(y)
    block_length = hparams.query_shape[0]
    block_width = hparams.query_shape[1]

    # Break into blocks row-wise.
    y = tf.reshape(y,
                   [batch, yshape[1] // block_length, block_length,
                    yshape[2], depth])
    yshape = common_layers.shape_list(y)

    # Break into blocks width-wise.
    y_blocks = tf.reshape(y,
                          [batch, yshape[1], yshape[2],
                           yshape[3] // block_width, block_width, depth])

    # Reshape targets as [batch, num_blocks_rows, num_block_cols,
    # block_length, block_width, depth].
    targets = tf.transpose(y_blocks, [0, 1, 3, 2, 4, 5])

  return targets
def top_k_experts(x, k, hparams):
  x_shape = common_layers.shape_list(x)
  x_flat = tf.reshape(x, [-1, common_layers.shape_list(x)[-1]])
  is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
  gates, load = expert_utils.noisy_top_k_gating(
      x_flat, 2 ** hparams.z_size, is_training, k)
  gates_shape = [x_shape[0], x_shape[1], x_shape[2], 2 ** hparams.z_size]
  gates = tf.reshape(gates, gates_shape)
  load_loss = expert_utils.cv_squared(load)
  return gates, load_loss
def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
          alpha=0.0, use_tpu=False):
  """Produce predictions from the model."""
  if not self._hparams.do_mask:
    infer_out = super(TransformerAE, self).infer(
        features, decode_length, beam_size, top_beams, alpha,
        use_tpu=use_tpu)
    return infer_out["outputs"]
  if not features:
    features = {}
  inputs_old = None
  if "inputs" in features and len(features["inputs"].shape) < 4:
    inputs_old = features["inputs"]
    features["inputs"] = tf.expand_dims(features["inputs"], 2)

  # Create an initial targets tensor.
  if "partial_targets" in features:
    initial_output = tf.convert_to_tensor(features["partial_targets"])
  else:
    # Inputs might not be present in features (e.g. for language modeling),
    # in which case we fall back to "infer_targets" for calculating the
    # initial input shape, type, etc.
    inputs_or_targets = features.get("inputs", features.get("infer_targets"))
    batch_size = common_layers.shape_list(inputs_or_targets)[0]
    length = common_layers.shape_list(inputs_or_targets)[1]
    hidden_dim = common_layers.shape_list(inputs_or_targets)[-1]
    target_length = tf.to_int32(2.0 * tf.to_float(length))
    initial_output = tf.zeros((batch_size, target_length, 1, hidden_dim),
                              dtype=inputs_or_targets.dtype)

  features["targets"] = initial_output
  logits, _ = self(features)  # pylint: disable=not-callable
  # If the target modality is real-valued, the last axis of the logits does
  # not represent classes, so the logits are already the samples.
  if inputs_or_targets.dtype == tf.float32:
    samples = logits
  else:
    samples = tf.argmax(logits, axis=-1)

  # More steps.
  self.predict_mask = 0.0  # Use the provided targets this time.
  how_many_more_steps = 0  # Set to 1 or more for Gibbs-like sampling.
  for _ in range(how_many_more_steps):
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      features["targets"] = samples
      logits, _ = self(features)  # pylint: disable=not-callable
      if inputs_or_targets.dtype == tf.float32:
        # When the target modality is real, the last axis does not represent
        # classes, so it should not be argmax'ed.
        samples = logits
      else:
        samples = tf.argmax(logits, axis=-1)

  self.predict_mask = 1.0
  if inputs_old is not None:  # Restore to not confuse Estimator.
    features["inputs"] = inputs_old
  return samples
def decode_transformer(encoder_output, encoder_decoder_attention_bias,
                       targets, hparams, name, task=None):
  """Original Transformer decoder."""
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)

      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))

      decoder_input = tf.nn.dropout(decoder_input,
                                    1.0 - hparams.layer_prepostprocess_dropout)

      decoder_output = transformer.transformer_decoder(
          decoder_input, encoder_output, decoder_self_bias,
          encoder_decoder_attention_bias, hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      inputs = None
      # We have to reshape targets as [batch, 32, 32, 3 * hidden_size]
      # because otherwise prepare_image will choke.
      targets = tf.reshape(targets, [tf.shape(targets)[0], hparams.img_len,
                                     hparams.img_len,
                                     hparams.num_channels * hparams.hidden_size])

      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)

      # Add class label to decoder input.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(
            inputs,
            [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
      decoder_output_shape = common_layers.shape_list(decoder_output)
      decoder_output = tf.reshape(decoder_output, [decoder_output_shape[0], -1,
                                                   1, hparams.hidden_size])
      # Expand since t2t expects 4d tensors.
  return decoder_output
def loss(self, top_out, targets):
  """Compute loss numerator and denominator for one shard of output."""
  logits = top_out
  logits = tf.reshape(logits, [-1] + common_layers.shape_list(logits)[2:])
  targets = tf.reshape(targets, [-1] + common_layers.shape_list(targets)[2:])
  cutoff = getattr(self._model_hparams, "video_modality_loss_cutoff", 0.01)
  return common_layers.padded_cross_entropy(
      logits,
      targets,
      self._model_hparams.label_smoothing,
      cutoff=cutoff,
      weights_fn=self.targets_weights_fn)
def reduce_dimensions(predictions, labels):
  """Reduce dimensions for high-dimensional predictions and labels."""
  # We will treat the first dimensions as batch. One example is video frames.
  if len(predictions.get_shape()) > 5:
    predictions_shape = common_layers.shape_list(predictions)
    predictions = tf.reshape(
        predictions, [predictions_shape[0], predictions_shape[1], -1,
                      predictions_shape[-1]])
    labels_shape = common_layers.shape_list(labels)
    labels = tf.reshape(
        labels, [labels_shape[0], labels_shape[1], -1])
  return predictions, labels
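# Usage sketch (illustrative): 6-D video logits and 5-D video labels are
# collapsed so downstream metrics only handle 4-D/3-D tensors, keeping batch
# and time as the leading dimensions.
def _example_reduce_dimensions():
  predictions = tf.zeros([4, 16, 32, 32, 3, 256])  # [b, t, h, w, c, classes]
  labels = tf.zeros([4, 16, 32, 32, 3])
  # -> predictions [4, 16, 32*32*3, 256], labels [4, 16, 32*32*3]
  return reduce_dimensions(predictions, labels)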
def logits_to_samples(logits, key):
  """Get samples from logits."""
  # If the last dimension is 1 then we're using L1/L2 loss.
  if common_layers.shape_list(logits)[-1] == 1:
    return tf.to_int32(tf.squeeze(logits, axis=-1))
  if key == "targets":
    return pixels_from_softmax(
        logits, gumbel_noise_factor=0.0,
        temperature=hparams.pixel_sampling_temperature)
  # Argmax in TF doesn't handle more than 5 dimensions yet.
  logits_shape = common_layers.shape_list(logits)
  argmax = tf.argmax(tf.reshape(logits, [-1, logits_shape[-1]]), axis=-1)
  return tf.reshape(argmax, logits_shape[:-1])
def loss(self, top_out, targets):
  """Compute loss numerator and denominator for one shard of output."""
  logits = top_out
  logits = tf.reshape(logits, [-1] + common_layers.shape_list(logits)[2:-1])
  targets = tf.reshape(targets, [-1] + common_layers.shape_list(targets)[2:])
  weights = self.targets_weights_fn(targets)
  # Shift targets by 0.5 so that later just casting to int gives the
  # prediction: for int targets, say 0 and 7, we actually train to predict
  # 0.5 and 7.5. Later (in metrics or infer) this is cast to int anyway.
  # Also, we have no loss beyond self.cutoff = 0.2 as these are already
  # correct predictions.
  targets = tf.to_float(targets) + 0.5
  loss = self.internal_loss(logits, targets)
  return tf.reduce_sum(loss * weights), tf.reduce_sum(weights)
def discriminator(self, x, is_training):
  """Discriminator architecture based on InfoGAN.

  Args:
    x: input images, shape [bs, h, w, channels].
    is_training: boolean, whether we are in train or eval mode.

  Returns:
    out_logit: the output logits (before sigmoid).
  """
  hparams = self.hparams
  with tf.variable_scope(
      "discriminator",
      initializer=tf.random_normal_initializer(stddev=0.02)):
    batch_size, height, width = common_layers.shape_list(x)[:3]
    # Mapping x from [bs, h, w, c] to [bs, 1].
    net = tf.layers.conv2d(x, 64, (4, 4), strides=(2, 2),
                           padding="SAME", name="d_conv1")
    # [bs, h/2, w/2, 64]
    net = lrelu(net)
    net = tf.layers.conv2d(net, 128, (4, 4), strides=(2, 2),
                           padding="SAME", name="d_conv2")
    # [bs, h/4, w/4, 128]
    if hparams.discriminator_batchnorm:
      net = tf.layers.batch_normalization(net, training=is_training,
                                          momentum=0.999, name="d_bn2")
    net = lrelu(net)
    size = height * width
    net = tf.reshape(net, [batch_size, size * 8])  # [bs, h * w * 8]
    net = tf.layers.dense(net, 1024, name="d_fc3")  # [bs, 1024]
    if hparams.discriminator_batchnorm:
      net = tf.layers.batch_normalization(net, training=is_training,
                                          momentum=0.999, name="d_bn3")
    net = lrelu(net)
    return net
def process_single_frame(prev_outputs, inputs):
  """Process a single frame of the video."""
  cur_image, input_reward, action = inputs
  time_step, prev_image, prev_reward, frame_buf, lstm_states = prev_outputs

  # Sample from softmax (by argmax); this is a no-op for non-softmax loss.
  prev_image = self.get_sampled_frame(prev_image)

  generated_items = [prev_image]
  groundtruth_items = [cur_image]
  done_warm_start = tf.greater(time_step, context_frames - 1)
  input_image, = self.get_scheduled_sample_inputs(
      done_warm_start, groundtruth_items, generated_items, ss_func)

  # Prediction.
  pred_image, lstm_states, _ = self.construct_predictive_tower(
      input_image, None, action, lstm_states, latent)

  if self.hparams.reward_prediction:
    reward_input_image = self.get_sampled_frame(pred_image)
    if self.hparams.reward_prediction_stop_gradient:
      reward_input_image = tf.stop_gradient(reward_input_image)
    with tf.control_dependencies([time_step]):
      frame_buf = [reward_input_image] + frame_buf[:-1]
    pred_reward = self.reward_prediction(frame_buf, None, action, latent)
    pred_reward = common_video.decode_to_shape(
        pred_reward, common_layers.shape_list(input_reward), "reward_dec")
  else:
    pred_reward = prev_reward

  time_step += 1
  outputs = (time_step, pred_image, pred_reward, frame_buf, lstm_states)

  return outputs
def bottom_compress(self, inputs, name="bottom"):
  """Transform input from data space to model space.

  Perform conversion of RGB pixel values to a real number and combine the
  values for each pixel to form a representation of
  image_length x image_length dims.

  Args:
    inputs: A Tensor with shape [batch, ...].
    name: string, scope.

  Returns:
    body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
  """
  with tf.variable_scope(name):
    inputs = common_layers.convert_rgb_to_real(inputs)
    ishape = common_layers.shape_list(inputs)
    inputs = tf.reshape(inputs, [-1, ishape[1], ishape[2] * ishape[3], 1])
    inputs.set_shape([None, None, None, 1])
    # We compress RGB intensities for each pixel using a conv.
    x = common_layers.conv_block(
        inputs,
        self._body_input_depth, [((1, 1), (1, 3))],
        first_relu=False,
        padding="VALID",
        strides=(1, 3),
        force2d=True,
        name="conv_input")
    return x
def lstm_cell(inputs,
              state,
              num_units,
              use_peepholes=False,
              cell_clip=0.0,
              initializer=None,
              num_proj=None,
              num_unit_shards=None,
              num_proj_shards=None,
              reuse=None,
              name=None):
  """Full LSTM cell."""
  input_shape = common_layers.shape_list(inputs)
  cell = tf.contrib.rnn.LSTMCell(num_units,
                                 use_peepholes=use_peepholes,
                                 cell_clip=cell_clip,
                                 initializer=initializer,
                                 num_proj=num_proj,
                                 num_unit_shards=num_unit_shards,
                                 num_proj_shards=num_proj_shards,
                                 reuse=reuse,
                                 name=name,
                                 state_is_tuple=False)
  if state is None:
    state = cell.zero_state(input_shape[0], tf.float32)
  outputs, new_state = cell(inputs, state)
  return outputs, new_state
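# Usage sketch (illustrative): because state_is_tuple=False, the state is a
# single [batch, 2 * num_units] tensor, which is convenient to thread through
# scan-style loops; passing state=None zero-initializes it on the first step.
def _example_lstm_cell():
  inputs = tf.random_normal([8, 16])
  out, state = lstm_cell(inputs, None, num_units=32, name="cell")
  out, state = lstm_cell(inputs, state, num_units=32, reuse=True, name="cell")
  return out, state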
def scheduled_sample_count(ground_truth_x,
                           generated_x,
                           batch_size,
                           scheduled_sample_var):
  """Sample batch with specified mix of groundtruth and generated data points.

  Args:
    ground_truth_x: tensor of ground-truth data points.
    generated_x: tensor of generated data points.
    batch_size: batch size.
    scheduled_sample_var: number of ground-truth examples to include in batch.

  Returns:
    New batch with num_ground_truth sampled from ground_truth_x and the rest
    from generated_x.
  """
  num_ground_truth = scheduled_sample_var
  idx = tf.random_shuffle(tf.range(batch_size))
  ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
  generated_idx = tf.gather(idx, tf.range(num_ground_truth, batch_size))

  ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
  generated_examps = tf.gather(generated_x, generated_idx)

  output = tf.dynamic_stitch([ground_truth_idx, generated_idx],
                             [ground_truth_examps, generated_examps])
  # If the batch size is statically known, set it on the output.
  if isinstance(batch_size, int):
    output.set_shape([batch_size] + common_layers.shape_list(output)[1:])
  return output
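# Usage sketch (illustrative): with scheduled_sample_var=3 out of a batch of
# 4, three randomly chosen batch positions keep their ground-truth frames and
# the remaining position is filled from the model's own prediction.
def _example_scheduled_sample_count():
  ground_truth = tf.ones([4, 8, 8, 3])
  generated = tf.zeros([4, 8, 8, 3])
  return scheduled_sample_count(
      ground_truth, generated, batch_size=4,
      scheduled_sample_var=tf.constant(3))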
def body_single(self, features):
  hparams = self.hparams
  filters = hparams.hidden_size
  kernel1, kernel2 = (3, 3), (4, 4)

  # Embed the inputs.
  inputs_shape = common_layers.shape_list(features["inputs"])
  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      features["inputs"], filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = common_layers.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = common_attention.add_timing_signal_nd(x)
      x = tf.layers.conv2d(x, filters, kernel2,
                           activation=common_layers.belu,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  # Add embedded action if present.
  if "input_action" in features:
    action = features["input_action"][:, -1, :]
    x = self.inject_additional_input(
        x, action, "action_enc", hparams.action_injection)

  x, extra_loss = self.inject_latent(x, features, filters)

  # Run a stack of convolutions.
  for i in range(hparams.num_hidden_layers):
    with tf.variable_scope("layer%d" % i):
      y = tf.nn.dropout(x, 1.0 - hparams.dropout)
      y = tf.layers.conv2d(y, filters, kernel1,
                           activation=common_layers.belu,
                           strides=(1, 1), padding="SAME")
      if i == 0:
        x = y
      else:
        x = common_layers.layer_norm(x + y)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      if "input_action" in features:
        x = self.inject_additional_input(
            x, action, "action_enc", hparams.action_injection)
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(
          x, filters, kernel2, activation=common_layers.belu,
          strides=(2, 2), padding="SAME")
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]
  if self.is_per_pixel_softmax:
    x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits")
  else:
    x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

  # Reward prediction if needed.
  if "target_reward" not in features:
    return x
  reward_pred = tf.expand_dims(  # Add a fake channels dim.
      tf.reduce_mean(x, axis=[1, 2], keepdims=True), axis=3)
  return {"targets": x, "target_reward": reward_pred}, extra_loss
def body(self, features):
  hparams = self.hparams
  input_shape = common_layers.shape_list(features['inputs'])
  batch_size, _, frame_width, frame_height, frame_channels = input_shape  # pylint: disable=unused-variable

  # Swap time and batch axes.
  input_frames = common_video.swap_time_and_batch_axes(
      tf.to_float(features['inputs']))
  target_frames = common_video.swap_time_and_batch_axes(features['targets'])

  # Get actions if they exist, otherwise use zeros.
  input_actions = self.get_input_if_exists(
      features, 'input_action', batch_size, hparams.video_num_input_frames)
  target_actions = self.get_input_if_exists(
      features, 'target_action', batch_size, hparams.video_num_target_frames)

  # Get rewards if they exist, otherwise use zeros.
  # TODO(blazej) enable rewards.
  # input_rewards = self.get_input_if_exists(
  #     features, 'input_reward', batch_size, hparams.video_num_input_frames)
  # target_rewards = self.get_input_if_exists(
  #     features, 'target_reward', batch_size,
  #     hparams.video_num_target_frames)
  # all_rewards = tf.concat([input_rewards, target_rewards], axis=0)

  all_actions = tf.concat([input_actions, target_actions], axis=0)
  all_frames = tf.concat([input_frames, target_frames], axis=0)

  all_frames = tf.unstack(all_frames, axis=0)
  all_actions = tf.unstack(all_actions, axis=0)
  all_actions = [tf.squeeze(a, 1) for a in all_actions]

  # TODO(blazej) - most likely this downsize is too strong.
  all_frames = [
      tf.image.resize_images(
          image, (IMG_HEIGHT, IMG_WIDTH),
          method=tf.image.ResizeMethod.BICUBIC) for image in all_frames
  ]

  enc_out_all, pred_out_all, _, van_on_enc_all = construct_model(
      all_frames,
      all_actions,
      context_frames=hparams.context_frames,
      hparams=hparams,
      is_training=self.is_training)

  enc_pred_loss, _ = calc_loss_psnr(
      enc_out_all[1:],
      pred_out_all,
      'enc_pred_loss',
      hparams=hparams,
      use_l1_loss=hparams.enc_pred_use_l1_loss)

  van_on_enc_loss, _ = calc_loss_psnr(
      van_on_enc_all, all_frames[1:], 'van_on_enc_loss', hparams=hparams)

  enc_pred_loss_scale_delay = max(hparams.enc_pred_loss_scale_delay, 1)
  enc_pred_loss_scale = tf.nn.sigmoid(
      (tf.to_float(tf.train.get_or_create_global_step()) -
       enc_pred_loss_scale_delay) /
      (enc_pred_loss_scale_delay * .1)) * hparams.enc_pred_loss_scale
  tf.summary.scalar('enc_pred_loss_scale', enc_pred_loss_scale)
  epva_loss = enc_pred_loss * enc_pred_loss_scale + van_on_enc_loss
  tf.summary.scalar('epva_loss', epva_loss)

  predictions = tf.stack(van_on_enc_all)

  # TODO(mbz): clean this up!
  def fix_video_dims_and_concat_on_x_axis(x):
    x = tf.transpose(x, [1, 3, 4, 0, 2])
    x = tf.reshape(x, [batch_size, frame_height, frame_channels, -1])
    x = tf.transpose(x, [0, 3, 1, 2])
    return x

  frames_gd = fix_video_dims_and_concat_on_x_axis(target_frames)
  frames_pd = fix_video_dims_and_concat_on_x_axis(predictions)
  side_by_side_video = tf.concat([frames_gd, frames_pd], axis=1)
  tf.summary.image('full_video', side_by_side_video)

  predictions = common_video.swap_time_and_batch_axes(predictions)
  predictions = tf.slice(predictions,
                         [0, hparams.video_num_input_frames - 1, 0, 0, 0],
                         [-1] * 5)

  return predictions, {'extra': epva_loss}
def decode_transformer(encoder_output, encoder_decoder_attention_bias,
                       targets, hparams, name, task=None, causal=True):
  """Original Transformer decoder."""
  orig_hparams = hparams
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)

      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))

      decoder_input = tf.nn.dropout(decoder_input,
                                    1.0 - hparams.layer_prepostprocess_dropout)

      if not causal:
        decoder_self_bias *= 0.

      decoder_output = transformer.transformer_decoder(
          decoder_input, encoder_output, decoder_self_bias,
          encoder_decoder_attention_bias, hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      inputs = None
      # We have to reshape targets as [batch, 32, 32, 3 * hidden_size]
      # because otherwise prepare_image will choke.
      targets = tf.reshape(targets, [tf.shape(targets)[0], hparams.img_len,
                                     hparams.img_len,
                                     hparams.num_channels * hparams.hidden_size])

      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)

      # Add class label to decoder input.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(
            inputs,
            [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
      decoder_output_shape = common_layers.shape_list(decoder_output)
      decoder_output = tf.reshape(decoder_output, [decoder_output_shape[0], -1,
                                                   1, hparams.hidden_size])
      # Expand since t2t expects 4d tensors.
  hparams = orig_hparams
  return decoder_output
def body(self, features):
  hparams = self.hparams
  # Run the basic autoencoder part first.
  basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
  if hparams.autoregressive_mode == "none":
    assert not hparams.autoregressive_forget_base
    return basic_result, losses
  if "training" in losses:
    plain_training_loss = losses.pop("training")
    losses["plain"] = plain_training_loss
  res_shape = common_layers.shape_list(basic_result)
  vocab_size = self._problem_hparams.target_modality.top_dimensionality
  targets = tf.one_hot(features["targets_raw"], vocab_size)

  # Prepare inputs for autoregressive modes.
  if common_layers.shape_list(features["targets"])[1] == 1:
    # This happens on the first step of predictions.
    assert hparams.mode == tf.estimator.ModeKeys.PREDICT
    targets = tf.zeros_like(basic_result)
  targets = self.embed(targets)
  if hparams.autoregressive_gumbel_sample:
    basic_hot = self.gumbel_sample(basic_result)
  else:
    basic_hot = basic_result
  basic_result = self.embed(basic_hot)
  shape = common_layers.shape_list(basic_result)
  basic1d = tf.reshape(basic_result, [shape[0], -1, shape[-1]])
  targets = tf.reshape(targets, common_layers.shape_list(basic_result))

  # During autoregressive inference, don't resample.
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    if hasattr(hparams, "sampled_basic1d_tensor"):
      basic1d = hparams.sampled_basic1d_tensor
    else:
      hparams.sampled_basic1d_tensor = basic1d

  # Sometimes it's useful to look at non-autoregressive evals.
  targets_dropout = targets
  if (hparams.mode == tf.estimator.ModeKeys.EVAL and
      hparams.autoregressive_eval_pure_autoencoder):
    targets_dropout = tf.zeros_like(basic_result)

  # Now combine the basic reconstruction with shifted targets.
  targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[-1]])
  targets_shifted = common_layers.shift_right_3d(targets1d)
  concat1d = tf.concat([basic1d, targets_shifted], axis=-1)

  # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
  if hparams.autoregressive_forget_base:
    concat1d = tf.reshape(targets, [shape[0], -1, shape[-1]])
    concat1d = common_layers.shift_right_3d(concat1d)

  # The autoregressive part depends on the mode.
  if hparams.autoregressive_mode == "conv3":
    res = common_layers.conv1d(
        concat1d, hparams.hidden_size, 3, padding="LEFT",
        activation=common_layers.belu, name="autoregressive_conv3")
    res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
    return tf.reshape(res, res_shape), losses
  if hparams.autoregressive_mode == "conv5":
    res = common_layers.conv1d(
        concat1d, hparams.hidden_size, 5, padding="LEFT",
        activation=common_layers.belu, name="autoregressive_conv5")
    res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
    return tf.reshape(res, res_shape), losses
  if hparams.autoregressive_mode == "sru":
    res = common_layers.conv1d(
        concat1d, hparams.hidden_size, 3, padding="LEFT",
        activation=common_layers.belu, name="autoregressive_sru_conv3")
    res = common_layers.sru(res)
    res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
    return tf.reshape(res, res_shape), losses

  raise ValueError(
      "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
def level_cond_prior(prior_dist, z, latent, hparams, state):
  """Returns a conditional prior for each level.

  Args:
    prior_dist: Distribution conditioned on the previous levels.
    z: Tensor, output of the previous levels.
    latent: Tensor or a list of tensors to condition the latent_distribution.
    hparams: next_frame_glow hparams.
    state: Current LSTM state. Used only if hparams.latent_dist_encoder is
      a lstm.

  Returns:
    Mean and scale of the conditional prior, plus the (possibly updated)
    LSTM state.

  Raises:
    ValueError: If hparams.latent_dist_encoder is "pointwise" and if the
      shape of latent is different from z.
  """
  latent_dist_encoder = hparams.get("latent_dist_encoder", None)
  latent_skip = hparams.get("latent_skip", False)
  if latent_dist_encoder == "pointwise":
    last_latent = latent
    merge_std = hparams.level_scale
    latent_shape = common_layers.shape_list(latent)
    z_shape = common_layers.shape_list(z)
    if latent_shape != z_shape:
      raise ValueError("Expected latent_shape to be %s, got %s" %
                       (latent_shape, z_shape))
    latent_dist = scale_gaussian_prior(
        "latent_prior", latent, logscale_factor=3.0)
    cond_dist = merge_level_and_latent_dist(prior_dist, latent_dist,
                                            merge_std=merge_std)
  elif latent_dist_encoder == "conv_net":
    output_channels = common_layers.shape_list(z)[-1]
    last_latent = latent[-1]
    latent_stack = tf.concat([prior_dist.loc] + latent, axis=-1)
    latent_stack = noise_op(latent_stack, hparams)
    cond_dist = latent_to_dist(
        "latent_stack", latent_stack, hparams=hparams,
        output_channels=output_channels)
  elif latent_dist_encoder == "conv3d_net":
    last_latent = latent[-1]
    output_channels = common_layers.shape_list(last_latent)[-1]
    num_steps = len(latent)
    # Stack across time.
    cond_latents = tf.stack(latent, axis=1)
    # Concat latents from previous levels across channels.
    prev_latents = tf.tile(tf.expand_dims(prior_dist.loc, axis=1),
                           [1, num_steps, 1, 1, 1])
    cond_latents = tf.concat((cond_latents, prev_latents), axis=-1)
    cond_latents = noise_op(cond_latents, hparams)
    cond_dist = temporal_latent_to_dist(
        "latent_stack", cond_latents, hparams,
        output_channels=output_channels)
  elif latent_dist_encoder == "conv_lstm":
    last_latent = latent
    output_channels = common_layers.shape_list(z)[-1]
    latent_stack = tf.concat((prior_dist.loc, latent), axis=-1)
    latent_stack = noise_op(latent_stack, hparams)
    _, state = common_video.conv_lstm_2d(
        latent_stack, state, hparams.latent_encoder_width, kernel_size=3,
        name="conv_lstm")
    cond_dist = single_conv_dist(
        "state_to_dist", state.h, output_channels=output_channels)
  if latent_skip:
    new_mean = cond_dist.loc + last_latent
    cond_dist = tfp.distributions.Normal(new_mean, cond_dist.scale)
  return cond_dist.loc, cond_dist.scale, state
def construct_predictive_tower(
    self, input_image, input_reward, action, lstm_state, latent,
    concat_latent=False):
  # Main tower.
  lstm_func = common_video.conv_lstm_2d
  frame_shape = common_layers.shape_list(input_image)
  batch_size, img_height, img_width, color_channels = frame_shape
  # The number of different pixel motion predictions
  # and the number of masks for each of those predictions.
  num_masks = self.hparams.num_masks
  upsample_method = self.hparams.upsample_method
  tile_and_concat = common_video.tile_and_concat

  lstm_size = self.tinyify([32, 32, 64, 64, 128, 64, 32])
  conv_size = self.tinyify([32])

  with tf.variable_scope("main", reuse=tf.AUTO_REUSE):
    hidden5, skips, layer_id = self.bottom_part_tower(
        input_image, input_reward, action, latent,
        lstm_state, lstm_size, conv_size, concat_latent=concat_latent)
    enc0, enc1 = skips

    with tf.variable_scope("upsample1", reuse=tf.AUTO_REUSE):
      enc4 = common_layers.cyclegan_upsample(
          hidden5, num_outputs=hidden5.shape.as_list()[-1],
          stride=[2, 2], method=upsample_method)
      enc1_shape = common_layers.shape_list(enc1)
      enc4 = enc4[:, :enc1_shape[1], :enc1_shape[2], :]  # Cut to shape.
      enc4 = tile_and_concat(enc4, latent, concat_latent=concat_latent)

    hidden6, lstm_state[layer_id] = lstm_func(
        enc4, lstm_state[layer_id], lstm_size[5], name="state6",
        spatial_dims=enc1_shape[1:-1])  # 16x16
    hidden6 = tile_and_concat(hidden6, latent, concat_latent=concat_latent)
    hidden6 = tfcl.layer_norm(hidden6, scope="layer_norm7")
    # Skip connection.
    hidden6 = tf.concat(axis=3, values=[hidden6, enc1])  # both 16x16
    layer_id += 1

    with tf.variable_scope("upsample2", reuse=tf.AUTO_REUSE):
      enc5 = common_layers.cyclegan_upsample(
          hidden6, num_outputs=hidden6.shape.as_list()[-1],
          stride=[2, 2], method=upsample_method)
      enc0_shape = common_layers.shape_list(enc0)
      enc5 = enc5[:, :enc0_shape[1], :enc0_shape[2], :]  # Cut to shape.
      enc5 = tile_and_concat(enc5, latent, concat_latent=concat_latent)

    hidden7, lstm_state[layer_id] = lstm_func(
        enc5, lstm_state[layer_id], lstm_size[6], name="state7",
        spatial_dims=enc0_shape[1:-1])  # 32x32
    hidden7 = tfcl.layer_norm(hidden7, scope="layer_norm8")
    layer_id += 1
    # Skip connection.
    hidden7 = tf.concat(axis=3, values=[hidden7, enc0])  # both 32x32

    with tf.variable_scope("upsample3", reuse=tf.AUTO_REUSE):
      enc6 = common_layers.cyclegan_upsample(
          hidden7, num_outputs=hidden7.shape.as_list()[-1],
          stride=[2, 2], method=upsample_method)
      enc6 = tfcl.layer_norm(enc6, scope="layer_norm9")
      enc6 = tile_and_concat(enc6, latent, concat_latent=concat_latent)

    if self.hparams.model_options == "DNA":
      # Using largest hidden state for predicting untied conv kernels.
      enc7 = tfl.conv2d_transpose(
          enc6, self.hparams.dna_kernel_size**2, [1, 1], strides=(1, 1),
          padding="SAME", name="convt4", activation=None)
    else:
      # Using largest hidden state for predicting a new image layer.
      enc7 = tfl.conv2d_transpose(
          enc6, color_channels, [1, 1], strides=(1, 1),
          padding="SAME", name="convt4", activation=None)
      # This allows the network to also generate one image from scratch,
      # which is useful when regions of the image become unoccluded.
      transformed = [tf.nn.sigmoid(enc7)]

    if self.hparams.model_options == "CDNA":
      # cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
      cdna_input = tfl.flatten(hidden5)
      transformed += common_video.cdna_transformation(
          input_image, cdna_input, num_masks, int(color_channels),
          self.hparams.dna_kernel_size, self.hparams.relu_shift)
    elif self.hparams.model_options == "DNA":
      # Only one mask is supported (more should be unnecessary).
      if num_masks != 1:
        raise ValueError("Only one mask is supported for DNA model.")
      transformed = [
          common_video.dna_transformation(
              input_image, enc7, self.hparams.dna_kernel_size,
              self.hparams.relu_shift)]

    masks = tfl.conv2d(
        enc6, filters=num_masks + 1, kernel_size=[1, 1],
        strides=(1, 1), name="convt7", padding="SAME")
    masks = masks[:, :img_height, :img_width, ...]
    masks = tf.reshape(
        tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
        [batch_size, int(img_height), int(img_width), num_masks + 1])
    mask_list = tf.split(
        axis=3, num_or_size_splits=num_masks + 1, value=masks)
    output = mask_list[0] * input_image
    for layer, mask in zip(transformed, mask_list[1:]):
      # TODO(mbz): take another look at this logic and verify.
      output = output[:, :img_height, :img_width, :]
      layer = layer[:, :img_height, :img_width, :]
      output += layer * mask

    # Map to softmax digits.
    if self.is_per_pixel_softmax:
      output = tf.layers.dense(
          output, self.hparams.problem.num_channels * 256, name="logits")

    mid_outputs = [enc0, enc1, enc4, enc5, enc6]
    return output, lstm_state, mid_outputs
def conv2d_fixed_padding(inputs,
                         filters,
                         kernel_size,
                         strides,
                         data_format="channels_first",
                         use_td=False,
                         targeting_rate=None,
                         keep_prob=None,
                         is_training=None):
  """Strided 2-D convolution with explicit padding.

  The padding is consistent and is based only on `kernel_size`, not on the
  dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).

  Args:
    inputs: `Tensor` of size `[batch, channels, height_in, width_in]`.
    filters: `int` number of filters in the convolution.
    kernel_size: `int` size of the kernel to be used in the convolution.
    strides: `int` strides of the convolution.
    data_format: `str` either "channels_first" for `[batch, channels, height,
      width]` or "channels_last" for `[batch, height, width, channels]`.
    use_td: `str` one of "weight" or "unit". Set to False or "" to disable
      targeted dropout.
    targeting_rate: `float` proportion of weights to target with targeted
      dropout.
    keep_prob: `float` keep probability for targeted dropout.
    is_training: `bool` for whether the model is in training.

  Returns:
    A `Tensor` of shape `[batch, filters, height_out, width_out]`.

  Raises:
    Exception: if use_td is not valid.
  """
  if strides > 1:
    inputs = fixed_padding(inputs, kernel_size, data_format=data_format)

  if use_td:
    inputs_shape = common_layers.shape_list(inputs)
    if use_td == "weight":
      if data_format == "channels_last":
        size = kernel_size * kernel_size * inputs_shape[-1]
      else:
        size = kernel_size * kernel_size * inputs_shape[1]
      targeting_count = targeting_rate * tf.to_float(size)
      targeting_fn = common_layers.weight_targeting
    elif use_td == "unit":
      targeting_count = targeting_rate * filters
      targeting_fn = common_layers.unit_targeting
    else:
      raise Exception("Unrecognized targeted dropout type: %s" % use_td)

    y = common_layers.td_conv(
        inputs,
        filters,
        kernel_size,
        targeting_count,
        targeting_fn,
        keep_prob,
        is_training,
        do_prune=True,
        strides=strides,
        padding=("SAME" if strides == 1 else "VALID"),
        data_format=data_format,
        use_bias=False,
        kernel_initializer=tf.variance_scaling_initializer())
  else:
    y = tf.layers.conv2d(
        inputs=inputs,
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=("SAME" if strides == 1 else "VALID"),
        use_bias=False,
        kernel_initializer=tf.variance_scaling_initializer(),
        data_format=data_format)

  return y
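# Usage sketch (illustrative): with strides > 1 the input is explicitly
# padded first, so the output spatial size depends only on kernel_size and
# strides, not on whether "SAME" or "VALID" padding is chosen internally.
def _example_conv2d_fixed_padding():
  images = tf.random_normal([2, 3, 32, 32])  # NCHW for "channels_first"
  return conv2d_fixed_padding(images, filters=64, kernel_size=3, strides=2)
  # -> [2, 64, 16, 16]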
def targets_bottom(self, x):
  with tf.variable_scope(self.name):
    return tf.zeros([
        common_layers.shape_list(x)[0], 1, 1, self._model_hparams.hidden_size
    ])
def bottom(self, x):
  """Use batchnorm instead of CMVN and shorten the stft with strided convs.

  Args:
    x: float32 tensor with shape [batch_size, len, 1, freqs * channels]

  Returns:
    float32 tensor with shape [batch_size, shorter_len, 1, hidden_size]
  """
  inputs = x
  p = self._model_hparams

  num_mel_bins = p.audio_num_mel_bins
  num_channels = 3 if p.audio_add_delta_deltas else 1

  with tf.variable_scope(self.name):
    if p.audio_preproc_in_bottom:
      # Compute filterbanks.
      with tf.variable_scope("fbanks"):
        waveforms = tf.squeeze(inputs, [2, 3])
        mel_fbanks = common_audio.compute_mel_filterbank_features(
            waveforms,
            sample_rate=p.audio_sample_rate,
            dither=p.audio_dither,
            preemphasis=p.audio_preemphasis,
            frame_length=p.audio_frame_length,
            frame_step=p.audio_frame_step,
            lower_edge_hertz=p.audio_lower_edge_hertz,
            upper_edge_hertz=p.audio_upper_edge_hertz,
            num_mel_bins=p.audio_num_mel_bins,
            apply_mask=True)
        if p.audio_add_delta_deltas:
          mel_fbanks = common_audio.add_delta_deltas(mel_fbanks)
        x = tf.reshape(mel_fbanks,
                       common_layers.shape_list(mel_fbanks)[:2] +
                       [num_mel_bins, num_channels])

        nonpadding_mask = 1. - common_attention.embedding_to_padding(x)
        num_of_nonpadding_elements = tf.reduce_sum(
            nonpadding_mask) * num_mel_bins * num_channels

        # This replaces CMVN estimation on data.
        var_epsilon = 1e-09
        mean = tf.reduce_sum(
            x, axis=[1], keepdims=True) / num_of_nonpadding_elements
        variance = (num_of_nonpadding_elements * mean**2. -
                    2. * mean * tf.reduce_sum(x, axis=[1], keepdims=True) +
                    tf.reduce_sum(x**2, axis=[1], keepdims=True)
                   ) / num_of_nonpadding_elements
        x = (x - mean) * tf.rsqrt(variance + var_epsilon) * tf.expand_dims(
            nonpadding_mask, -1)
    else:
      x = inputs

    # The convention is that the models are flattened along the spatial
    # dimensions, thus the speech preprocessor treats frequencies and
    # channels as image colors (last axis).
    x.set_shape([None, None, num_mel_bins, num_channels])

    # TODO(chorowski): how to specify bottom's hparams and avoid hardcoding?
    x = tf.pad(x, [[0, 0], [0, 8], [0, 0], [0, 0]])
    for _ in range(2):
      x = tf.layers.conv2d(x, 128, (3, 3), (2, 2), use_bias=False)
      x = common_layers.layer_norm(x)
      x = tf.nn.relu(x)

    xshape = common_layers.shape_list(x)
    # Apply a conv that will remove all frequencies and at the same time
    # project the output into the desired hidden_size.
    x = tf.pad(x, [[0, 0], [0, 2], [0, 0], [0, 0]])
    x = tf.layers.conv2d(x, p.hidden_size, (3, xshape[2]), use_bias=False)

    assert common_layers.shape_list(x)[2] == 1
    x = common_layers.layer_norm(x)
    x = tf.nn.relu(x)
  return x
def __process(self, all_frames, all_actions, all_rewards, all_raw_frames): """Main video processing function.""" hparams = self.hparams all_frames_copy = [tf.identity(frame) for frame in all_frames] orig_frame_shape = common_layers.shape_list(all_frames[0]) batch_size = orig_frame_shape[0] ss_func = self.get_scheduled_sample_func(batch_size) target_frames = [] extra_loss = 0.0 # Any extra info required by the model goes into here. video_features = self.video_features(all_frames, all_actions, all_rewards, all_raw_frames) num_frames = len(all_frames) if self.is_recurrent_model: input_index_range = range(num_frames - 1) else: input_index_range = range(hparams.video_num_target_frames) # Setup the internal states as well as an auxiliary tf op # to enforce synchronization between prediction steps. if self.internal_states is None: internal_states = None sync_op = tf.no_op() else: internal_states = self.load_internal_states_ops() with tf.control_dependencies(flat_lists(internal_states)): sync_op = tf.no_op() res_frames, sampled_frames, res_rewards = [], [], [] for i in input_index_range: with tf.control_dependencies([sync_op]): frames, actions, rewards, target_index = self.__get_next_inputs( i, all_frames, all_actions, all_rewards) target_frame = all_frames[target_index] target_frames.append(tf.identity(target_frame)) with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): func_in = (frames, actions, rewards, target_frame, internal_states, video_features) func_out = self.next_frame(*func_in) res_frame, res_reward, res_extra_loss, internal_states = func_out res_frames.append(res_frame) res_rewards.append(res_reward) extra_loss += res_extra_loss / float( len(input_index_range)) # Synchronizing the internal states. # Some TensorFlow magic to make sure everything happens as it should. with tf.control_dependencies([res_frame]): sync_op = tf.no_op() if self.is_predicting and self.is_recurrent_model and i == 0: # The internal state save happens at the end of the 1st iteration, # which essentially allows recurrent models to continue # running after one prediction. # Necessary for planning/rl applications. save_ops = self.save_internal_states_ops( internal_states) with tf.control_dependencies(flat_lists(save_ops)): sync_op = tf.no_op() # Only for Softmax loss: sample frame so we can keep iterating. sampled_frame = self.get_sampled_frame(res_frame) sampled_frames.append(sampled_frame) # Check whether we are done with context frames or not. if self.is_recurrent_model: done_warm_start = (i >= hparams.video_num_input_frames - 1) else: done_warm_start = True # Always true for non-recurrent networks. if self.is_predicting and done_warm_start: all_frames[target_index] = sampled_frame # Scheduled sampling during training. if self.is_training: groundtruth_items = [target_frame] generated_items = [sampled_frame] ss_frame, = self.get_scheduled_sample_inputs( done_warm_start, groundtruth_items, generated_items, ss_func) all_frames[target_index] = ss_frame video_extra_loss = self.video_extra_loss(sampled_frames, target_frames, internal_states, video_features) tf.summary.scalar("video_extra_loss", video_extra_loss) extra_loss += video_extra_loss if self.is_recurrent_model: has_input_predictions = hparams.video_num_input_frames > 1 if self.is_training and hparams.internal_loss and has_input_predictions: # Add the loss for input frames as well. extra_gts = all_frames_copy[1:hparams.video_num_input_frames] extra_raw_gts = all_raw_frames[1:hparams.video_num_input_frames] extra_pds = res_frames[:hparams.video_num_input_frames - 1] recon_loss = self.get_extra_internal_loss( extra_raw_gts, extra_gts, extra_pds) extra_loss += recon_loss # Cut the predicted input frames. res_frames = res_frames[hparams.video_num_input_frames - 1:] res_rewards = res_rewards[hparams.video_num_input_frames - 1:] sampled_frames = sampled_frames[hparams.video_num_input_frames - 1:] target_frames = target_frames[hparams.video_num_input_frames - 1:] self.visualize_predictions(sampled_frames, target_frames) output_frames = tf.stack(res_frames, axis=1) targets = output_frames if self.has_rewards: output_rewards = tf.stack(res_rewards, axis=1) targets = { "targets": output_frames, "target_reward": output_rewards } return targets, extra_loss
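# Editor's sketch (hypothetical helper, not the model's actual
# get_scheduled_sample_inputs) of the scheduled-sampling idea used above:
# with probability sample_prob, feed the model's own sampled frame back in,
# otherwise feed the ground-truth frame.
import tensorflow as tf

def scheduled_sample(ground_truth, generated, sample_prob):
  """Per-example coin flip between ground-truth and generated frames."""
  batch_size = tf.shape(ground_truth)[0]
  use_generated = tf.random_uniform([batch_size]) < sample_prob
  return tf.where(use_generated, generated, ground_truth)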
def infer(self, features, *args, **kwargs): # pylint: disable=arguments-differ """Produce predictions from the model by running it.""" del args, kwargs # Inputs and features preparation needed to handle edge cases. if not features: features = {} hparams = self.hparams inputs_old = None if "inputs" in features and len(features["inputs"].shape) < 4: inputs_old = features["inputs"] features["inputs"] = tf.expand_dims(features["inputs"], 2) def logits_to_samples(logits): """Get samples from logits.""" # If the last dimension is 1 then we're using L1/L2 loss. if common_layers.shape_list(logits)[-1] == 1: return tf.to_int32(tf.squeeze(logits, axis=-1)) # Argmax in TF doesn't handle more than 5 dimensions yet. logits_shape = common_layers.shape_list(logits) argmax = tf.argmax(tf.reshape(logits, [-1, logits_shape[-1]]), axis=-1) return tf.reshape(argmax, logits_shape[:-1]) # Get predictions. try: num_channels = hparams.problem.num_channels except AttributeError: num_channels = 1 if "inputs" in features: inputs_shape = common_layers.shape_list(features["inputs"]) targets_shape = [ inputs_shape[0], hparams.video_num_target_frames, inputs_shape[2], inputs_shape[3], num_channels ] else: tf.logging.warn("Guessing targets shape as no inputs are given.") targets_shape = [ hparams.batch_size, hparams.video_num_target_frames, 1, 1, num_channels ] features["targets"] = tf.zeros(targets_shape, dtype=tf.int32) reward_in_mod = "target_reward" in hparams.problem_hparams.modality action_in_mod = "target_action" in hparams.problem_hparams.modality if reward_in_mod: # TODO(lukaszkaiser): this is a hack. get the actual reward history. if "input_reward" not in features: features["input_reward"] = tf.zeros( [inputs_shape[0], inputs_shape[1], 1], dtype=tf.int32) features["target_reward"] = tf.zeros( [targets_shape[0], targets_shape[1], 1], dtype=tf.int32) if action_in_mod and "target_action" not in features: features["target_action"] = tf.zeros( [targets_shape[0], targets_shape[1], 1], dtype=tf.int32) logits, _ = self(features) # pylint: disable=not-callable if isinstance(logits, dict): results = {} for k, v in six.iteritems(logits): results[k] = logits_to_samples(v) results["%s_logits" % k] = v # HACK: bypassing decoding issues. results["outputs"] = results["targets"] results["scores"] = results["targets"] else: results = logits_to_samples(logits) # Restore inputs to not confuse Estimator in edge cases. if inputs_old is not None: features["inputs"] = inputs_old # Return results. return results
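# Editor's illustration of the argmax workaround in logits_to_samples above:
# tf.argmax historically rejected tensors with more than 5 dimensions, so the
# logits are flattened to 2-D, argmax'd, and reshaped back.
import tensorflow as tf

logits = tf.zeros([2, 3, 4, 5, 6, 256])  # 6-D: too many dims for argmax.
flat = tf.reshape(logits, [-1, 256])
samples = tf.reshape(tf.argmax(flat, axis=-1), [2, 3, 4, 5, 6])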
def targets_bottom(self, x): with tf.variable_scope(self.name): return tf.zeros( [common_layers.shape_list(x)[0], 1, 1, self._body_input_depth])
def construct_model(self, images, actions, rewards): """Builds the stochastic model. The model first encodes all the images (x_t) in the sequence using the encoder. Let's call the output e_t. Then it predicts the latent state of the next frame using a recurrent posterior network z ~ q(z|e_{0:t}) = N(mu(e_{0:t}), sigma(e_{0:t})). Another recurrent network predicts the embedding of the next frame using the approximated posterior e_{t+1} = p(e_{t+1}|e_{0:t}, z) Finally, the decoder decodes e_{t+1} into x_{t+1}. Skip connections from encoder to decoder help with reconstruction. Args: images: tensor of ground truth image sequences actions: NOT used; list of action tensors rewards: NOT used; list of reward tensors Returns: gen_images: generated images fake_rewards: input rewards as reward prediction! pred_mu: predicted means of posterior pred_logvar: predicted log(var) of posterior """ # The model supports neither action conditioning nor reward prediction. fake_reward_prediction = rewards del actions, rewards z_dim = self.hparams.z_dim g_dim = self.hparams.g_dim rnn_size = self.hparams.rnn_size posterior_rnn_layers = self.hparams.posterior_rnn_layers predictor_rnn_layers = self.hparams.predictor_rnn_layers context_frames = self.hparams.video_num_input_frames seq_len, batch_size, _, _, color_channels = common_layers.shape_list( images) # LSTM initial states. predictor_states = [None] * predictor_rnn_layers posterior_states = [None] * posterior_rnn_layers tf.logging.info(">>>> Encoding") # Encoding: enc_images, enc_skips = [], [] images = tf.unstack(images, axis=0) for i, image in enumerate(images): with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE): enc, skips = self.encoder(image, rnn_size) enc = tfcl.flatten(enc) enc_images.append(enc) enc_skips.append(skips) tf.logging.info(">>>> Prediction") # Prediction pred_enc, pred_mu, pred_logvar = [], [], [] for i in range(1, seq_len): with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE): # Current encoding h_current = enc_images[i - 1] # Target encoding h_target = enc_images[i] z = tf.random_normal([batch_size, z_dim], 0, 1, dtype=tf.float32) mu, logvar = tf.zeros_like(z), tf.zeros_like(z) # Only use the posterior at training time. if self.hparams.mode == tf.estimator.ModeKeys.TRAIN: mu, logvar, posterior_states = self.lstm_gaussian( h_target, posterior_states, rnn_size, z_dim, posterior_rnn_layers) # The original implementation has a multiplier of 0.5. # Removed here for simplicity, i.e. replacing std with var. z = z * tf.exp(logvar) + mu # Predict output encoding h_pred, predictor_states = self.stacked_lstm( tf.concat([h_current, z], axis=1), predictor_states, rnn_size, g_dim, predictor_rnn_layers) pred_enc.append(h_pred) pred_mu.append(mu) pred_logvar.append(logvar) tf.logging.info(">>>> Decoding") # Decoding gen_images = [] for i in range(seq_len - 1): with tf.variable_scope("decoding", reuse=tf.AUTO_REUSE): # Use skip values of the last available frame. skip_index = min(context_frames - 1, i) h_pred = tf.reshape(pred_enc[i], [batch_size, 1, 1, g_dim]) x_pred = self.decoder(h_pred, enc_skips[skip_index], color_channels) gen_images.append(x_pred) tf.logging.info(">>>> Done") gen_images = tf.stack(gen_images, axis=0) return gen_images, fake_reward_prediction, pred_mu, pred_logvar
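# Editor's note: the code above deliberately uses exp(logvar), i.e. the
# variance, as the multiplier. The standard reparameterization trick draws
# z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(0.5 * logvar); a
# conventional version would look like this sketch.
import tensorflow as tf

def reparameterize(mu, logvar):
  eps = tf.random_normal(tf.shape(mu))
  return mu + tf.exp(0.5 * logvar) * eps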
def body(self, features): hparams = self.hparams batch_size = common_layers.shape_list(features["inputs"])[0] # Swap time and batch axes. input_frames = common_video.swap_time_and_batch_axes(features["inputs"]) target_frames = common_video.swap_time_and_batch_axes(features["targets"]) # Get actions if they exist, otherwise use zeros. input_actions = self.get_input_if_exists( features, "input_action", batch_size, hparams.video_num_input_frames) target_actions = self.get_input_if_exists( features, "target_action", batch_size, hparams.video_num_target_frames) # Get rewards if they exist, otherwise use zeros. input_rewards = self.get_input_if_exists( features, "input_reward", batch_size, hparams.video_num_input_frames) target_rewards = self.get_input_if_exists( features, "target_reward", batch_size, hparams.video_num_target_frames) all_actions = tf.concat([input_actions, target_actions], axis=0) all_rewards = tf.concat([input_rewards, target_rewards], axis=0) all_frames = tf.concat([input_frames, target_frames], axis=0) # Each image is being used twice, in latent tower and main tower. # This is to make sure we are using the *same* image for both, ... # ... given how TF queues work. # NOT sure if this is required at all. Doesn't hurt though! :) all_frames = tf.identity(all_frames) gen_images, gen_rewards, latent_means, latent_stds = self.construct_model( images=all_frames, actions=all_actions, rewards=all_rewards, ) extra_loss = self.get_extra_loss( latent_means=latent_means, latent_stds=latent_stds, true_frames=all_frames, gen_frames=gen_images) # Visualize predictions in Tensorboard if self.is_training: self.visualize_predictions(all_frames[1:], gen_images) # Ignore the predictions from the input frames. # This is NOT the same as original paper/implementation. predictions = gen_images[hparams.video_num_input_frames-1:] reward_pred = gen_rewards[hparams.video_num_input_frames-1:] reward_pred = tf.squeeze(reward_pred, axis=2) # Remove extra dimension. # Swap back time and batch axes. predictions = common_video.swap_time_and_batch_axes(predictions) reward_pred = common_video.swap_time_and_batch_axes(reward_pred) if self.is_training and hparams.internal_loss: # Add the loss for input frames as well. extra_gts = all_frames[1:hparams.video_num_input_frames] extra_gts = common_video.swap_time_and_batch_axes(extra_gts) extra_pds = gen_images[:hparams.video_num_input_frames-1] extra_pds = common_video.swap_time_and_batch_axes(extra_pds) extra_raw_gts = features["inputs_raw"][:, 1:] recon_loss = self.get_extra_internal_loss( extra_raw_gts, extra_gts, extra_pds) extra_loss += recon_loss return_targets = predictions if hparams.reward_prediction: return_targets = {"targets": predictions, "target_reward": reward_pred} return return_targets, extra_loss
def construct_model(self, images, actions, rewards): """Build convolutional lstm video predictor using CDNA or DNA. Args: images: list of tensors of ground truth image sequences there should be a 4D image ?xWxHxC for each timestep actions: list of action tensors each action should be in the shape ?x1xZ rewards: list of reward tensors each reward should be in the shape ?x1xZ Returns: gen_images: predicted future image frames gen_rewards: predicted future rewards latent_mean: mean of approximated posterior latent_std: std of approximated posterior Raises: ValueError: if more than 1 mask specified for DNA model. """ context_frames = self.hparams.video_num_input_frames buffer_size = self.hparams.reward_prediction_buffer_size if buffer_size == 0: buffer_size = context_frames if buffer_size > context_frames: raise ValueError("Buffer size is bigger than context frames %d %d." % (buffer_size, context_frames)) batch_size = common_layers.shape_list(images[0])[0] ss_func = self.get_scheduled_sample_func(batch_size) def process_single_frame(prev_outputs, inputs): """Process a single frame of the video.""" cur_image, input_reward, action = inputs time_step, prev_image, prev_reward, frame_buf, lstm_states = prev_outputs # Sample from softmax (by argmax); this is a no-op for non-softmax loss. prev_image = self.get_sampled_frame(prev_image) generated_items = [prev_image] groundtruth_items = [cur_image] done_warm_start = tf.greater(time_step, context_frames - 1) input_image, = self.get_scheduled_sample_inputs( done_warm_start, groundtruth_items, generated_items, ss_func) # Prediction pred_image, lstm_states, _ = self.construct_predictive_tower( input_image, None, action, lstm_states, latent) if self.hparams.reward_prediction: reward_input_image = self.get_sampled_frame(pred_image) if self.hparams.reward_prediction_stop_gradient: reward_input_image = tf.stop_gradient(reward_input_image) with tf.control_dependencies([time_step]): frame_buf = [reward_input_image] + frame_buf[:-1] pred_reward = self.reward_prediction(frame_buf, None, action, latent) pred_reward = common_video.decode_to_shape( pred_reward, common_layers.shape_list(input_reward), "reward_dec") else: pred_reward = prev_reward time_step += 1 outputs = (time_step, pred_image, pred_reward, frame_buf, lstm_states) return outputs # Latent tower latent = None if self.hparams.stochastic_model: latent_mean, latent_std = self.construct_latent_tower(images, time_axis=0) latent = common_video.get_gaussian_tensor(latent_mean, latent_std) # HACK: Do first step outside to initialize all the variables lstm_states = [None] * (5 if self.hparams.small_mode else 7) frame_buffer = [tf.zeros_like(images[0])] * buffer_size inputs = images[0], rewards[0], actions[0] init_image_shape = common_layers.shape_list(images[0]) if self.is_per_pixel_softmax: init_image_shape[-1] *= 256 init_image = tf.zeros(init_image_shape, dtype=images.dtype) prev_outputs = (tf.constant(0), init_image, tf.zeros_like(rewards[0]), frame_buffer, lstm_states) initializers = process_single_frame(prev_outputs, inputs) first_gen_images = tf.expand_dims(initializers[1], axis=0) first_gen_rewards = tf.expand_dims(initializers[2], axis=0) inputs = (images[1:-1], rewards[1:-1], actions[1:-1]) outputs = tf.scan(process_single_frame, inputs, initializers) gen_images, gen_rewards = outputs[1:3] gen_images = tf.concat((first_gen_images, gen_images), axis=0) gen_rewards = tf.concat((first_gen_rewards, gen_rewards), axis=0) if self.hparams.stochastic_model: return gen_images, gen_rewards, [latent_mean], [latent_std] else: return gen_images, gen_rewards, None, None
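# Editor's sketch of the tf.scan pattern used above: the initializer seeds the
# carried state and each step consumes one slice of the leading (time) axis,
# which is why the first frame is processed once outside the scan.
import tensorflow as tf

elems = tf.constant([1.0, 2.0, 3.0, 4.0])
running_sum = tf.scan(lambda acc, x: acc + x, elems,
                      initializer=tf.constant(0.0))
# running_sum evaluates to [1.0, 3.0, 6.0, 10.0].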
def next_frame(self, frames, actions, rewards, target_frame, internal_states, video_extra): del rewards, video_extra hparams = self.hparams filters = hparams.hidden_size kernel2 = (4, 4) action = actions[-1] # Stack the inputs. if internal_states is not None and hparams.concat_internal_states: # Use the first part of the first internal state if asked to concatenate. batch_size = common_layers.shape_list(frames[0])[0] internal_state = internal_states[0][0][:batch_size, :, :, :] stacked_frames = tf.concat(frames + [internal_state], axis=-1) else: stacked_frames = tf.concat(frames, axis=-1) inputs_shape = common_layers.shape_list(stacked_frames) # Update internal states early if requested. if hparams.concat_internal_states: internal_states = self.update_internal_states_early( internal_states, frames) # Using non-zero bias initializer below for edge cases of uniform inputs. x = tf.layers.dense( stacked_frames, filters, name="inputs_embed", bias_initializer=tf.random_normal_initializer(stddev=0.01)) x = common_attention.add_timing_signal_nd(x) # Down-stride. layer_inputs = [x] for i in range(hparams.num_compress_steps): with tf.variable_scope("downstride%d" % i): layer_inputs.append(x) x = tf.nn.dropout(x, 1.0 - self.hparams.dropout) x = common_layers.make_even_size(x) if i < hparams.filter_double_steps: filters *= 2 x = common_attention.add_timing_signal_nd(x) x = tf.layers.conv2d(x, filters, kernel2, activation=common_layers.belu, strides=(2, 2), padding="SAME") x = common_layers.layer_norm(x) if self.has_actions: with tf.variable_scope("policy"): x_flat = tf.layers.flatten(x) policy_pred = tf.layers.dense(x_flat, self.hparams.problem.num_actions) value_pred = tf.layers.dense(x_flat, 1) value_pred = tf.squeeze(value_pred, axis=-1) else: policy_pred, value_pred = None, None # Add embedded action if present. if self.has_actions: x = common_video.inject_additional_input( x, action, "action_enc", hparams.action_injection) # Inject latent if present. Only for stochastic models. x, extra_loss = self.inject_latent(x, frames, target_frame, action) x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True) x, internal_states = self.middle_network(x, internal_states) # Up-convolve. layer_inputs = list(reversed(layer_inputs)) for i in range(hparams.num_compress_steps): with tf.variable_scope("upstride%d" % i): x = tf.nn.dropout(x, 1.0 - self.hparams.dropout) if self.has_actions: x = common_video.inject_additional_input( x, action, "action_enc", hparams.action_injection) if i >= hparams.num_compress_steps - hparams.filter_double_steps: filters //= 2 x = tf.layers.conv2d_transpose( x, filters, kernel2, activation=common_layers.belu, strides=(2, 2), padding="SAME") y = layer_inputs[i] shape = common_layers.shape_list(y) x = x[:, :shape[1], :shape[2], :] x = common_layers.layer_norm(x + y) x = common_attention.add_timing_signal_nd(x) # Cut down to original size. x = x[:, :inputs_shape[1], :inputs_shape[2], :] x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True) if self.is_per_pixel_softmax: x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits") else: x = tf.layers.dense(x, hparams.problem.num_channels, name="logits") reward_pred = None if self.has_rewards: # Reward prediction based on middle and final logits. reward_pred = tf.concat([x_mid, x_fin], axis=-1) reward_pred = tf.nn.relu(tf.layers.dense( reward_pred, 128, name="reward_pred")) reward_pred = tf.squeeze(reward_pred, axis=1) # Remove extra dims reward_pred = tf.squeeze(reward_pred, axis=1) # Remove extra dims return x, reward_pred, policy_pred, value_pred, extra_loss, internal_states
def lstm_attention_search_based_decoder(inputs, hparams, train, name, initial_state, encoder_outputs, build_storage, storage, n): """Run LSTM cell with attention on inputs of shape [batch x time x size].""" def dropout_lstm_cell(): return tf.contrib.rnn.DropoutWrapper( LSTMShallowFusionCell(hparams.hidden_size, build_storage, storage), input_keep_prob=1.0 - hparams.dropout * tf.to_float(train)) layers = [dropout_lstm_cell() for _ in range(hparams.num_hidden_layers)] if hparams.attention_mechanism == "luong": attention_mechanism_class = tf.contrib.seq2seq.LuongAttention elif hparams.attention_mechanism == "bahdanau": attention_mechanism_class = tf.contrib.seq2seq.BahdanauAttention else: raise ValueError("Unknown hparams.attention_mechanism = %s, must be " "luong or bahdanau." % hparams.attention_mechanism) attention_mechanism = attention_mechanism_class(hparams.hidden_size, encoder_outputs) if not build_storage: p_copy = [ tf.TensorArray(tf.float32, size=tf.shape(inputs)[1], dynamic_size=True, name='dzeta_dot_q'), tf.TensorArray(tf.float32, size=tf.shape(inputs)[1], dynamic_size=True, name='1_dzeta') ] else: p_copy = None # TODO: add fusion_type in hparams cell = AttentionWrapperSearchBased( tf.nn.rnn_cell.MultiRNNCell(layers), [attention_mechanism] * hparams.num_heads, storage=storage, build_storage=build_storage, p_copy=p_copy, start_index=n, attention_layer_size=[hparams.attention_layer_size] * hparams.num_heads, output_attention=(hparams.output_attention == 1)) batch_size = common_layers.shape_list(inputs)[0] initial_state = cell.zero_state(batch_size, tf.float32).clone(cell_state=initial_state) with tf.variable_scope(name): output, state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state, dtype=tf.float32, time_major=False) # For multi-head attention project output back to hidden size if hparams.output_attention == 1 and hparams.num_heads > 1: output = tf.layers.dense(output, hparams.hidden_size) return output, p_copy
def body(self, features): hp = self.hparams # pylint: disable=eval-used if hp.image_input_type == "image": image_feat = vqa_layers.image_embedding( features["inputs"], model_fn=eval(hp.image_model_fn), trainable=hp.train_resnet, is_training=hp.mode == tf.estimator.ModeKeys.TRAIN) else: image_feat = features["inputs"] image_feat = common_layers.flatten4d3d(image_feat) image_feat = common_layers.dense(image_feat, hp.hidden_size) utils.collect_named_outputs("norms", "image_feat_after_proj", tf.norm(image_feat, axis=-1)) question = common_layers.flatten4d3d(features["question"]) utils.collect_named_outputs("norms", "question_embedding", tf.norm(question, axis=-1)) (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias) = prepare_image_question_encoder( image_feat, question, hp) encoder_input = tf.nn.dropout(encoder_input, keep_prob=1. - hp.layer_prepostprocess_dropout) encoder_output, _ = recurrent_transformer_decoder( encoder_input, None, encoder_self_attention_bias, None, hp, name="encoder") utils.collect_named_outputs("norms", "encoder_output", tf.norm(encoder_output, axis=-1)) # Scale query by sqrt(hidden_size). query = tf.get_variable("query", [hp.hidden_size]) * hp.hidden_size**0.5 query = tf.expand_dims(tf.expand_dims(query, axis=0), axis=0) batch_size = common_layers.shape_list(encoder_input)[0] query = tf.tile(query, [batch_size, 1, 1]) query = tf.nn.dropout(query, keep_prob=1. - hp.layer_prepostprocess_dropout) decoder_output, _ = recurrent_transformer_decoder( query, encoder_output, None, encoder_decoder_attention_bias, hp, name="decoder") utils.collect_named_outputs("norms", "decoder_output", tf.norm(decoder_output, axis=-1)) norm_tensors = utils.convert_collection_to_dict("norms") vqa_layers.summarize_tensors(norm_tensors, tag="norms/") # Expand to a 4-D tensor: [batch, 1, 1, hidden_size]. return tf.expand_dims(decoder_output, axis=1)
def conv(name, x, output_channels, filter_size=None, stride=None, logscale_factor=3.0, apply_actnorm=True, conv_init="default", dilations=None): """Convolutional layer with edge bias padding and optional actnorm. If x is 5-dimensional, actnorm is applied independently across every time-step. Args: name: variable scope. x: 4-D Tensor or 5-D Tensor of shape NHWC or NTHWC output_channels: Number of output channels. filter_size: list of ints, if None [3, 3] and [2, 3, 3] are defaults for 4-D and 5-D input tensors respectively. stride: list of ints, default stride: 1 logscale_factor: see actnorm for parameter meaning. apply_actnorm: if apply_actnorm the activations of the first minibatch have zero mean and unit variance. Else, there is no scaling applied. conv_init: default or zeros. default is a normal distribution with 0.05 std. dilations: List of integers, apply dilations. Returns: x: actnorm(conv2d(x)) Raises: ValueError: if init is set to "zeros" and apply_actnorm is set to True. """ if conv_init == "zeros" and apply_actnorm: raise ValueError("apply_actnorm is unstable when init is set to zeros.") x_shape = common_layers.shape_list(x) is_2d = len(x_shape) == 4 num_steps = x_shape[1] # set filter_size, stride and in_channels if is_2d: if filter_size is None: filter_size = [3, 3] if stride is None: stride = [1, 1] if dilations is None: dilations = [1, 1, 1, 1] actnorm_func = actnorm x = add_edge_bias(x, filter_size=filter_size) conv_filter = tf.nn.conv2d else: if filter_size is None: if num_steps == 1: filter_size = [1, 3, 3] else: filter_size = [2, 3, 3] if stride is None: stride = [1, 1, 1] if dilations is None: dilations = [1, 1, 1, 1, 1] actnorm_func = actnorm_3d x = time_pad(x, filter_size=filter_size, dilations=dilations) conv_filter = tf.nn.conv3d in_channels = common_layers.shape_list(x)[-1] filter_shape = filter_size + [in_channels, output_channels] stride_shape = [1] + stride + [1] with tf.variable_scope(name, reuse=tf.AUTO_REUSE): if conv_init == "default": initializer = default_initializer() elif conv_init == "zeros": initializer = tf.zeros_initializer() w = tf.get_variable("W", filter_shape, tf.float32, initializer=initializer) x = conv_filter(x, w, stride_shape, padding="VALID", dilations=dilations) if apply_actnorm: x, _ = actnorm_func("actnorm", x, logscale_factor=logscale_factor) else: x += tf.get_variable("b", [1, 1, 1, output_channels], initializer=tf.zeros_initializer()) logs = tf.get_variable("logs", [1, output_channels], initializer=tf.zeros_initializer()) x *= tf.exp(logs * logscale_factor) return x
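# Editor's usage sketch for `conv` above (names illustrative): a 4-D input
# takes the 2-D branch with the default 3x3 filter, and edge-bias padding
# keeps the spatial dims; with conv_init="zeros", actnorm must be disabled,
# matching the ValueError above.
import tensorflow as tf

x = tf.random_normal([16, 32, 32, 3])
y = conv("glow_conv", x, output_channels=64)  # -> [16, 32, 32, 64]
z = conv("zero_conv", x, output_channels=64,
         conv_init="zeros", apply_actnorm=False)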
def ae_transformer_internal(inputs, targets, target_space, hparams, cache=None, predict_mask=1.0): """AE Transformer, main step used for training.""" # Summaries break with the do_refine cond, turn them off in that case. global _DO_SUMMARIES if hparams.do_refine: _DO_SUMMARIES = False # Prepare. if inputs is not None: batch_size = common_layers.shape_list(inputs)[0] else: batch_size = common_layers.shape_list(targets)[0] targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size]) # Encoder. if inputs is not None: inputs = common_layers.flatten4d3d(inputs) inputs, ed = encode(inputs, target_space, hparams, "input_enc") inputs_ex, ed_ex = inputs, ed else: ed, inputs_ex, ed_ex = None, None, None # Autoencoding. losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)} if hparams.do_ae: # flatten here original_targets_shape = tf.shape(targets) if hparams.task == "image": cia.maybe_reshape_4d_to_3d(targets) if hparams.task == "translate": max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1) else: assert hparams.task == "image" max_targets_len_from_inputs = targets targets, _ = common_layers.pad_to_same_length( targets, max_targets_len_from_inputs, final_length_divisible_by=2**hparams.num_compress_steps) targets_c = compress(targets, inputs, False, hparams, "compress") if hparams.mode != tf.estimator.ModeKeys.PREDICT: # Compress and bottleneck. latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck( x=targets_c, filter_size=hparams.compress_filter_size, name="vc", mode=hparams.mode) if _DO_SUMMARIES: tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1])) pc = common_layers.inverse_exp_decay(hparams.startup_steps) pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0 cond = tf.less(tf.random_uniform([batch_size]), pc) latents_dense = tf.where(cond, latents_dense, targets_c) # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean. losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond)) # Extra loss predicting latent code from input. Discrete only. if hparams.bottleneck_kind not in ["dense", "vae"]: latents_pred = decode_transformer( inputs_ex, ed_ex, embed(latents_discrete), hparams, "extra", task="translate") _, latent_pred_loss = ae_latent_softmax( latents_pred, tf.stop_gradient(latents_discrete), hparams) losses["latent_pred"] = tf.reduce_mean( latent_pred_loss * tf.to_float(cond)) else: inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c") losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20 def bn_inputs(): with tf.variable_scope(tf.get_variable_scope(), reuse=True): bn, _, _, _ = hparams.bottleneck( x=inputs_c, filter_size=hparams.compress_filter_size, name="vc", mode=hparams.mode) return bn inputs_c = bn_inputs ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5 ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0 latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc), latents_dense, inputs_c) else: if hparams.bottleneck_kind in ["dense", "vae"]: inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c") latents_dense, _, _, _ = hparams.bottleneck( x=inputs_c, filter_size=hparams.compress_filter_size, name="vc", mode=hparams.mode) else: latent_len = common_layers.shape_list(targets_c)[1] _, _, _, embed = hparams.bottleneck( x=targets_c, filter_size=hparams.compress_filter_size, name="vc") latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :]) if cache is None: cache = ae_latent_sample( latents_dense, inputs_ex, ed_ex, embed, 16, hparams) latents_dense = embed(cache) # Postprocess. d = latents_dense pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size]) pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :] latents_dense = tf.pad(latents_dense, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos # Decompressing the dense latents. for i in range(hparams.num_compress_steps): j = hparams.num_compress_steps - i - 1 d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j) if hparams.do_attend_decompress: d = attend(d, inputs, hparams, "decompress_attend_%d" % j) d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j) # Masking. if hparams.do_mask: masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps) masking *= common_layers.inverse_exp_decay( hparams.mask_startup_steps // 4) # Not much at start. if not hparams.do_refine: masking -= tf.random_uniform([]) * hparams.unmasked_percentage masking = tf.minimum(tf.maximum(masking, 0.0), 1.0) if hparams.use_predict_mask: masking = predict_mask if hparams.mode == tf.estimator.ModeKeys.PREDICT: masking = predict_mask mask = tf.less(masking, tf.random_uniform( common_layers.shape_list(targets)[:-1])) mask = tf.expand_dims(tf.to_float(mask), 3) # targets is always [batch, length, 1, depth] targets = mask * targets + (1.0 - mask) * d # Reshape back to 4d here. if hparams.task == "image": targets = tf.reshape(targets, original_targets_shape) if hparams.task == "translate": targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1) res = decode_transformer(inputs, ed, targets, hparams, "decoder", causal=hparams.causal) if hparams.do_ae: if hparams.task == "translate": res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :] if hparams.do_mask and hparams.do_refine: def refine_res(): # return residual_conv(res, 1, (5, 1), hparams, "refine") r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams, "refine_enc") return tf.expand_dims(r, axis=2) masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3]) all_masked = tf.less(masked_batches, 0.1) res = tf.where(all_masked, refine_res(), res) # We'll start training the extra model of latents after mask_startup_steps. nonlatent_steps = hparams.mask_startup_steps latent_time = tf.less(nonlatent_steps, tf.to_int32(tf.train.get_global_step())) losses["latent_pred"] *= tf.to_float(latent_time) return res, losses
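# Editor's sketch of the mask blending above (shapes illustrative): positions
# where a uniform draw exceeds `masking` keep the ground-truth target, the
# rest are replaced by the decompressed latents d.
import tensorflow as tf

targets = tf.zeros([8, 128, 1, 512])
d = tf.ones_like(targets)
masking = 0.7
mask = tf.to_float(tf.less(masking, tf.random_uniform([8, 128, 1])))
mask = tf.expand_dims(mask, 3)
mixed = mask * targets + (1.0 - mask) * d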
def render2cmd_v3_internal(self, features, hparams, train): # inputs and targets are both sequences with # shape = [batch, seq_len, 1, hparams.problem.feature_dim] targets = features['targets'] source = features['source'] losses = {} sampled_bottleneck = self.pretrained_visual_encoder(features, hparams) if hparams.sg_bottleneck: sampled_bottleneck = tf.stop_gradient(sampled_bottleneck) with tf.variable_scope('render2cmd_v3_internal'): # Override the bottleneck, or return it, if requested. if 'bottleneck' in features: if common_layers.shape_list(features['bottleneck'])[0] == 0: # Return the sampled bottleneck; set losses['training'] = 0.0 so # self.top() doesn't get called on it. return sampled_bottleneck, {'training': 0.0} else: # We want to use the given bottleneck. sampled_bottleneck = features['bottleneck'] # Finalize the bottleneck. unbottleneck_dim = hparams.hidden_size * 2 # twice because using LSTM if hparams.twice_decoder: unbottleneck_dim = unbottleneck_dim * 2 dec_initial_state = [] # LSTM encoder _, encoder_output_states = self.lstm_encoder( common_layers.flatten4d3d(source), hparams) for hi in range(hparams.num_hidden_layers): unbottleneck = self.unbottleneck(sampled_bottleneck, unbottleneck_dim, name_append='_{}'.format(hi)) c, h = encoder_output_states[hi] dec_initial_state.append( tf.nn.rnn_cell.LSTMStateTuple( c=tf.concat( [unbottleneck[:, :unbottleneck_dim // 2], c], 1), h=tf.concat([unbottleneck[:, unbottleneck_dim // 2:], h], 1))) dec_initial_state = tuple(dec_initial_state) shifted_targets = common_layers.shift_right(targets) # Add 1 to account for the padding added to the left from shift_right targets_length = common_layers.length_from_embedding( shifted_targets) + 1 # LSTM decoder hparams_decoder = copy.copy(hparams) if hparams.twice_decoder: hparams_decoder.hidden_size = 2 * hparams.hidden_size if hparams.mode == tf.estimator.ModeKeys.PREDICT: decoder_outputs, _ = self.lstm_decoder_infer( common_layers.flatten4d3d(shifted_targets), targets_length, hparams_decoder, features['targets_cls'], train, initial_state=dec_initial_state, bottleneck=sampled_bottleneck) else: decoder_outputs, _ = self.lstm_decoder( common_layers.flatten4d3d(shifted_targets), targets_length, hparams_decoder, features['targets_cls'], train, initial_state=dec_initial_state, bottleneck=sampled_bottleneck) ret = tf.expand_dims(decoder_outputs, axis=2) return ret, losses
def body(self, features): hparams = self.hparams is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN vocab_size = self._problem_hparams.target_modality.top_dimensionality encoder_layers = None self.is1d = hparams.sample_width == 1 if hparams.mode != tf.estimator.ModeKeys.PREDICT: labels = features["targets_raw"] labels_shape = common_layers.shape_list(labels) # Handle videos. if len(labels.shape) == 5: labels = common_layers.time_to_channels(labels) shape = common_layers.shape_list(labels) x = tf.one_hot(labels, vocab_size) x = self.embed(x) target_codes = x if shape[2] == 1: self.is1d = True # Run encoder. x, encoder_layers = self.encoder(x) # Bottleneck. b, b_loss = self.bottleneck(x) xb_loss = 0.0 b_shape = common_layers.shape_list(b) self._cur_bottleneck_tensor = b b = self.unbottleneck(b, common_layers.shape_list(x)[-1]) if not is_training: x = b else: l = 2**hparams.num_hidden_layers warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l) nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01 if common_layers.should_generate_summaries(): tf.summary.scalar("nomix_p_bottleneck", nomix_p) rand = tf.random_uniform(common_layers.shape_list(x)) # This is the distance between b and x. Having this as loss helps learn # the bottleneck function, but if we back-propagated to x it would be # minimized by just setting x=0 and b=0 -- so we don't want too much # of the influence of this, and we stop-gradient to not zero-out x. x_stop = tf.stop_gradient(x) xb_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x_stop - b), axis=-1)) # To prevent this loss from exploding we clip at 1, but anneal clipping. clip_max = 1.0 / common_layers.inverse_exp_decay( warm_step, min_value=0.001) xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max) xb_loss *= clip_max / xb_clip x = tf.where(tf.less(rand, nomix_p), b, x) if hparams.gan_loss_factor != 0.0: # Add a purely sampled batch on which we'll compute the GAN loss. g = self.unbottleneck( self.sample(shape=b_shape), common_layers.shape_list(x)[-1], reuse=True) x = tf.concat([g, x], axis=0) encoder_layers = [tf.concat([l, l], axis=0) for l in encoder_layers] else: if self._cur_bottleneck_tensor is None: b = self.sample() else: b = self._cur_bottleneck_tensor self._cur_bottleneck_tensor = b res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers res_size = min(res_size, hparams.max_hidden_size) x = self.unbottleneck(b, res_size) # Run decoder. x = self.decoder(x, encoder_layers) # Cut to the right size and mix before returning. res = x if hparams.mode != tf.estimator.ModeKeys.PREDICT: res = x[:, :shape[1], :shape[2], :] # Final dense layer. res = tf.layers.dense( res, self.num_channels * hparams.hidden_size, name="res_dense") output_shape = common_layers.shape_list(res)[:-1] + [ self.num_channels, self.hparams.hidden_size ] res = tf.reshape(res, output_shape) if hparams.mode == tf.estimator.ModeKeys.PREDICT: if hparams.use_vq_loss: (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size) else: reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final") return reconstr, {"bottleneck_loss": 0.0} if hparams.gan_loss_factor != 0.0: res_gan, res = tf.split(res, 2, axis=0) # Losses. losses = { "bottleneck_extra": b_loss, "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss } if hparams.use_vq_loss: vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay( hparams.gan_codes_warmup_steps * 1.2, min_value=hparams.vq_temperature * 2) if hparams.mode != tf.estimator.ModeKeys.TRAIN: vq_temperature = None with tf.variable_scope("vq_loss"): (reconstr, _, target_codes, code_loss, targets_loss) = discretization.vq_loss( res, labels, vocab_size, temperature=vq_temperature) losses["code_loss"] = code_loss * hparams.code_loss_factor losses["training"] = targets_loss else: reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final") targets_loss = tf.losses.sparse_softmax_cross_entropy( logits=tf.reshape(reconstr, labels_shape + [vocab_size]), labels=tf.reshape(labels, labels_shape)) losses["training"] = targets_loss # GAN losses. if hparams.gan_loss_factor != 0.0: update_means_factor = common_layers.inverse_exp_decay( hparams.gan_codes_warmup_steps, min_value=0.0001) if hparams.use_vq_loss: with tf.variable_scope("vq_loss", reuse=True): update_means = tf.less(tf.random_uniform([]), update_means_factor) reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss( res_gan, labels, vocab_size, do_update=update_means, temperature=vq_temperature) reconstr_gan_nonoise = reconstr_gan code_loss_gan *= hparams.code_loss_factor * update_means_factor losses["code_loss_gan"] = code_loss_gan else: reconstr_gan = tf.layers.dense( res_gan, vocab_size, name="autoencoder_final", reuse=True) reconstr_gan_nonoise = reconstr_gan reconstr_gan = self.gumbel_sample(reconstr_gan) # Embed to codes. gan_codes = self.embed(reconstr_gan) # Add GAN loss if requested. gan_loss = 0.0 if hparams.gan_loss_factor != 0.0: self.image_summary("gan", reconstr_gan_nonoise) def discriminate(x): """Run a discriminator depending on the hparams.""" if hparams.discriminator == "default": return common_layers.deep_discriminator( x, hparams.discriminator_batchnorm, is_training) elif hparams.discriminator == "patched": return common_layers.patch_discriminator(x) elif hparams.discriminator == "single": return common_layers.single_discriminator( x, hparams.discriminator_size, hparams.discriminator_kernel_size, hparams.discriminator_strides, pure_mean=hparams.discriminator_pure_mean) elif hparams.discriminator == "double": return common_layers.double_discriminator( x, hparams.discriminator_size, hparams.discriminator_kernel_size, hparams.discriminator_strides, pure_mean=hparams.discriminator_pure_mean) else: raise Exception("Unknown discriminator %s" % hparams.discriminator) tc_shape = common_layers.shape_list(target_codes) if len(tc_shape) > 4: target_codes = tf.reshape(target_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]]) gan_codes = tf.reshape(gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]]) gan_lr = common_layers.inverse_exp_decay( hparams.gan_codes_warmup_steps * 1.5) rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr) gan_loss = common_layers.sliced_gan_loss( target_codes, rev_grad_gan_codes, discriminate, self.hparams.num_sliced_vecs, do_tanh=hparams.sliced_do_tanh) gan_loss *= hparams.gan_loss_factor * update_means_factor losses["gan_loss"] = -gan_loss self.image_summary("ae", reconstr) logits = tf.reshape(reconstr, labels_shape + [vocab_size]) return logits, losses
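# Editor's sketch of the gradient-reversal trick behind `reverse_gradient`
# used above (a common implementation; the actual helper may differ):
# identity in the forward pass, gradient scaled by -lr in the backward pass,
# so the generator is pushed against the discriminator.
import tensorflow as tf

def reverse_gradient_sketch(x, lr=1.0):
  return -lr * x + tf.stop_gradient((1 + lr) * x)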
def invertible_1x1_conv(name, x, reverse=False): """1x1 convolution on x. The 1x1 convolution is parametrized as P*L*(U + sign(s)*exp(log(s))) where 1. P is a permutation matrix. 2. L is a lower triangular matrix with diagonal entries unity. 3. U is an upper triangular matrix whose diagonal entries are zero. 4. s is a vector. sign(s) and P are fixed and the remaining parameters are optimized. P, L, U and s are initialized by the PLU decomposition of a random rotation matrix. Args: name: scope x: Input Tensor. reverse: whether the pass is from z -> x or x -> z. Returns: x_conv: x after a 1x1 convolution is applied on x. objective: sum(log(s)) """ _, height, width, channels = common_layers.shape_list(x) w_shape = [channels, channels] # Random rotation-matrix Q random_matrix = np.random.rand(channels, channels) np_w = scipy.linalg.qr(random_matrix)[0].astype("float32") # Initialize P, L, U and s from the PLU decomposition of a random rotation matrix. np_p, np_l, np_u = scipy.linalg.lu(np_w) np_s = np.diag(np_u) np_sign_s = np.sign(np_s) np_log_s = np.log(np.abs(np_s)) np_u = np.triu(np_u, k=1) with tf.variable_scope(name, reuse=tf.AUTO_REUSE): p = tf.get_variable("P", initializer=np_p, trainable=False) l = tf.get_variable("L", initializer=np_l) sign_s = tf.get_variable( "sign_S", initializer=np_sign_s, trainable=False) log_s = tf.get_variable("log_S", initializer=np_log_s) u = tf.get_variable("U", initializer=np_u) # W = P * L * (U + sign_s * exp(log_s)) l_mask = np.tril(np.ones([channels, channels], dtype=np.float32), -1) l = l * l_mask + tf.eye(channels, channels) u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s)) w = tf.matmul(p, tf.matmul(l, u)) # If height or width cannot be statically determined then they end up as # tf.int32 tensors, which cannot be directly multiplied with a floating # point tensor without a cast. objective = tf.reduce_sum(log_s) * tf.cast(height * width, log_s.dtype) if not reverse: w = tf.reshape(w, [1, 1] + w_shape) x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME", data_format="NHWC") else: # TODO(b/111271662): Remove when supported. def tpu_inv(m): """tf.linalg.inv workaround until it is supported on TPU.""" q, r = tf.linalg.qr(m) return tf.linalg.triangular_solve(r, tf.transpose(q), lower=False) w_inv = tf.reshape(tpu_inv(w), [1, 1]+w_shape) x = tf.nn.conv2d( x, w_inv, [1, 1, 1, 1], "SAME", data_format="NHWC") objective *= -1 return x, objective
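# Editor's NumPy check of the PLU parametrization above: an orthonormal W from
# QR decomposes as P @ L @ U, and rebuilding U's diagonal from
# sign(s) * exp(log|s|) recovers W exactly.
import numpy as np
import scipy.linalg

w = scipy.linalg.qr(np.random.rand(4, 4))[0]
p, l, u = scipy.linalg.lu(w)
s = np.diag(u)
u_rebuilt = np.triu(u, k=1) + np.diag(np.sign(s) * np.exp(np.log(np.abs(s))))
assert np.allclose(p.dot(l).dot(u_rebuilt), w)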
def lstm_decoder_infer(self, inputs, sequence_length, hparams, clss, train, initial_state=None, bottleneck=None): # In PREDICT mode, run an RNN via tf.while_loop. max_decode_length = 51 batch_size = common_layers.shape_list(inputs)[0] zero_pad, logits_so_far = self.create_initial_input_for_decode( batch_size) layers = contrib_rnn.MultiRNNCell([ self.lstm_cell(hparams, train) for _ in range(hparams.num_hidden_layers) ]) if initial_state is None: raise Exception('initial state should be initialized from the bottleneck!') # Append one-hot class to the bottleneck, which will be given per step. clss = tf.reshape(clss, [-1]) if not hparams.use_cls: clss = tf.zeros_like(clss) if hparams.condition_on_sln: sln = tf.reshape(sequence_length, [-1]) bottleneck = tf.concat((bottleneck, tf.one_hot(clss, hparams.num_categories), tf.one_hot(sln, max_decode_length)), -1) else: bottleneck = tf.concat((bottleneck, tf.one_hot(clss, hparams.num_categories)), -1) def infer_step(logits_so_far, current_hidden): """Inference step of LSTM while loop.""" # Unflatten hidden state: current_hidden = tuple(tf.nn.rnn_cell.LSTMStateTuple(c=s[0], h=s[1]) for s in current_hidden) # Put logits_so_far through top. tm = self._problem_hparams.modality['targets'] # Need to reuse top params. reset_scope = tf.variable_scope(tf.VariableScope(tf.AUTO_REUSE, ''), reuse=tf.AUTO_REUSE, auxiliary_name_scope=False) top_scope = tf.variable_scope('svg_decoder/{}_modality'.format(tm), reuse=tf.AUTO_REUSE) with reset_scope, top_scope: samples_so_far = self.hparams.top['targets']( logits_so_far, None, self.hparams, self.problem_hparams.vocab_size) # Append a zero pad to the samples. This effectively shifts the samples # right but, unlike shift_right, by not removing the last element we # allow an empty samples_so_far to not be empty after padding. samples_so_far = tf.concat([zero_pad, samples_so_far], axis=1) shifted_targets = common_layers.flatten4d3d(samples_so_far) # Now take the very last one; it will be the actual input to the RNN. shifted_targets = shifted_targets[:, -1:, :] # Tile and append the bottleneck to inputs. sln_offset = 0 if hparams.condition_on_sln: sln_offset = 51 pre_tile_y = tf.reshape( bottleneck, [common_layers.shape_list(bottleneck)[0], 1, hparams.bottleneck_bits + hparams.num_categories + sln_offset]) overlay_x = tf.tile(pre_tile_y, [1, common_layers.shape_list(shifted_targets)[1], 1]) inputs = tf.concat([shifted_targets, overlay_x], -1) seq_len_batch = tf.ones([common_layers.shape_list(inputs)[0]]) # Run pre-LSTM layer. with tf.variable_scope('pre_decoder', reuse=tf.AUTO_REUSE): inputs = tf.layers.dense( inputs, hparams.hidden_size, name='bottom') inputs = tf.nn.tanh(inputs) # Run LSTM. with tf.variable_scope('lstm_decoder', reuse=tf.AUTO_REUSE): next_step, next_state = tf.nn.dynamic_rnn( layers, inputs, seq_len_batch, initial_state=current_hidden, dtype=tf.float32, time_major=False) next_step = tf.expand_dims(next_step, [1]) logits_so_far = tf.concat([logits_so_far, next_step], 1) # Flatten state. next_state = tuple((s.c, s.h) for s in next_state) return logits_so_far, next_state def while_exit_cond(logits_so_far, unused_current_hidden): length = common_layers.shape_list(logits_so_far)[1] return length < max_decode_length # The passed state must be flattened: initial_state = tuple([(s.c, s.h) for s in initial_state]) # Actually run tf.while_loop: logits, final_state = tf.while_loop( while_exit_cond, infer_step, [logits_so_far, initial_state], shape_invariants=[ tf.TensorShape([None, None, 1, hparams.hidden_size]), tuple([(s[0].get_shape(), s[1].get_shape()) for s in initial_state]), ], back_prop=False, parallel_iterations=1 ) # Logits should be returned in 3-D: logits = common_layers.flatten4d3d(logits) return logits, final_state
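# Editor's minimal sketch of the shape-invariant pattern used above: when a
# loop variable grows along an axis (here the time axis of the logits), that
# axis must be declared None so tf.while_loop accepts the changing shape.
import tensorflow as tf

def grow_step(acc):
  return tf.concat([acc, tf.zeros([2, 1, 3])], axis=1)

acc = tf.zeros([2, 0, 3])
result = tf.while_loop(
    lambda a: tf.shape(a)[1] < 5, grow_step, [acc],
    shape_invariants=[tf.TensorShape([2, None, 3])])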
def transformer_prepare_encoder(inputs, target_space, hparams, features=None): """Prepare one shard of the model for the encoder. Args: inputs: a Tensor. target_space: a Tensor. hparams: run hyperparameters features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. Returns: encoder_input: a Tensor, bottom of encoder stack encoder_self_attention_bias: a bias tensor for use in encoder self-attention encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder attention """ ishape_static = inputs.shape.as_list() encoder_input = inputs if features and "inputs_segmentation" in features: # Packed dataset. Keep the examples from seeing each other. inputs_segmentation = features["inputs_segmentation"] inputs_position = features["inputs_position"] targets_segmentation = features["targets_segmentation"] encoder_self_attention_bias = common_attention.attention_bias_same_segment( inputs_segmentation, inputs_segmentation) encoder_decoder_attention_bias = ( common_attention.attention_bias_same_segment(targets_segmentation, inputs_segmentation)) else: # Usual case - not a packed dataset. encoder_padding = common_attention.embedding_to_padding(encoder_input) ignore_padding = common_attention.attention_bias_ignore_padding( encoder_padding) encoder_self_attention_bias = ignore_padding encoder_decoder_attention_bias = ignore_padding inputs_position = None if hparams.proximity_bias: encoder_self_attention_bias += common_attention.attention_bias_proximal( common_layers.shape_list(inputs)[1]) if hparams.get("use_target_space_embedding", True): # Append target_space_id embedding to inputs. emb_target_space = common_layers.embedding( target_space, 32, ishape_static[-1], name="target_space_embedding", dtype=tf.bfloat16 if hparams.activation_dtype == "bfloat16" else tf.float32) emb_target_space = tf.reshape(emb_target_space, [1, 1, -1]) encoder_input += emb_target_space if hparams.pos == "timing": if inputs_position is not None: encoder_input = common_attention.add_timing_signal_1d_given_position( encoder_input, inputs_position) else: encoder_input = common_attention.add_timing_signal_1d(encoder_input) elif hparams.pos == "emb": encoder_input = common_attention.add_positional_embedding( encoder_input, hparams.max_length, "inputs_positional_embedding", inputs_position) if hparams.activation_dtype == "bfloat16": encoder_self_attention_bias = tf.cast(encoder_self_attention_bias, tf.bfloat16) encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias, tf.bfloat16) return (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias)
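# Editor's simplified stand-in for
# common_attention.attention_bias_ignore_padding used above: positions whose
# embeddings are all zero count as padding and get a large negative bias so
# attention ignores them.
import tensorflow as tf

emb = tf.constant([[[0.5, 0.1], [0.0, 0.0]]])  # Second position is padding.
padding = tf.to_float(tf.equal(tf.reduce_sum(tf.abs(emb), axis=-1), 0.0))
bias = tf.expand_dims(tf.expand_dims(padding * -1e9, 1), 1)
# bias broadcasts over heads and query positions: shape [1, 1, 1, 2].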
def body(self, features): hparams = self.hparams batch_size = common_layers.shape_list(features["inputs"])[0] # Swap time and batch axes. input_frames = common_video.swap_time_and_batch_axes( features["inputs"]) target_frames = common_video.swap_time_and_batch_axes( features["targets"]) # Get actions if they exist, otherwise use zeros. input_actions = self.get_input_if_exists( features, "input_action", batch_size, hparams.video_num_input_frames) target_actions = self.get_input_if_exists( features, "target_action", batch_size, hparams.video_num_target_frames) # Get rewards if they exist, otherwise use zeros. input_rewards = self.get_input_if_exists( features, "input_reward", batch_size, hparams.video_num_input_frames) target_rewards = self.get_input_if_exists( features, "target_reward", batch_size, hparams.video_num_target_frames) all_actions = tf.concat([input_actions, target_actions], axis=0) all_rewards = tf.concat([input_rewards, target_rewards], axis=0) all_frames = tf.concat([input_frames, target_frames], axis=0) # Each image is being used twice, in latent tower and main tower. # This is to make sure we are using the *same* image for both, ... # ... given how TF queues work. # NOT sure if this is required at all. Doesn't hurt though! :) all_frames = tf.identity(all_frames) gen_images, gen_rewards, latent_means, latent_stds = self.construct_model( images=all_frames, actions=all_actions, rewards=all_rewards, ) extra_loss = self.get_extra_loss(latent_means=latent_means, latent_stds=latent_stds, true_frames=all_frames, gen_frames=gen_images) # Visualize predictions in Tensorboard if self.is_training and not self.is_per_pixel_softmax: self.visualize_predictions(all_frames[1:], gen_images) # Ignore the predictions from the input frames. # This is NOT the same as original paper/implementation. predictions = gen_images[hparams.video_num_input_frames - 1:] reward_pred = gen_rewards[hparams.video_num_input_frames - 1:] reward_pred = tf.squeeze(reward_pred, axis=2) # Remove extra dimension. # Swap back time and batch axes. predictions = common_video.swap_time_and_batch_axes(predictions) reward_pred = common_video.swap_time_and_batch_axes(reward_pred) if hparams.internal_loss: # Add the MSE loss for the input frames as well. We assume the modality # is L2; otherwise the loss would be inconsistent across the frames. if self._target_modality != "VideoModalityL2Raw": raise ValueError("internal loss only works with L2.") recon_loss = tf.losses.mean_squared_error( all_frames[1:hparams.video_num_input_frames + 1], gen_images[:hparams.video_num_input_frames]) tf.summary.scalar("mse_extra", recon_loss) extra_loss += recon_loss return_targets = predictions if hparams.reward_prediction: return_targets = { "targets": predictions, "target_reward": reward_pred } return return_targets, extra_loss
def transformer_ffn_layer(x, hparams, pad_remover=None, conv_padding="LEFT", nonpadding_mask=None, losses=None, cache=None, decode_loop_step=None, readout_filter_size=0): """Feed-forward layer in the transformer. Args: x: a Tensor of shape [batch_size, length, hparams.hidden_size] hparams: hyperparameters for model pad_remover: an expert_utils.PadRemover object tracking the padding positions. If provided, when using convolutional settings, the padding is removed before applying the convolution, and restored afterward. This can give a significant speedup. conv_padding: a string - either "LEFT" or "SAME". nonpadding_mask: an optional Tensor with shape [batch_size, length]. needed for convolutional layers with "SAME" padding. Contains 1.0 in positions corresponding to nonpadding. losses: optional list onto which to append extra training losses cache: dict, containing tensors which are the results of previous attentions, used for fast decoding. decode_loop_step: An integer, step number of the decoding loop. Only used for inference on TPU. readout_filter_size: if it's greater than 0, then it will be used instead of filter_size Returns: a Tensor of shape [batch_size, length, hparams.hidden_size] Raises: ValueError: If losses arg is None, but layer generates extra losses. """ ffn_layer = hparams.ffn_layer relu_dropout_broadcast_dims = ( common_layers.comma_separated_string_to_integer_list( getattr(hparams, "relu_dropout_broadcast_dims", ""))) if ffn_layer == "conv_hidden_relu": # Backwards compatibility ffn_layer = "dense_relu_dense" if ffn_layer == "dense_relu_dense": # In simple convolution mode, use `pad_remover` to speed up processing. mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_FFN_FILTER_DENSE, value={ "filter_size": hparams.filter_size, "use_bias": "True", "activation": mlperf_log.RELU }) mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_FFN_OUTPUT_DENSE, value={ "hidden_size": hparams.hidden_size, "use_bias": "True", }) mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_RELU_DROPOUT, value=hparams.relu_dropout) if pad_remover: original_shape = common_layers.shape_list(x) # Collapse `x` across examples, and remove padding positions. x = tf.reshape(x, tf.concat([[-1], original_shape[2:]], axis=0)) x = tf.expand_dims(pad_remover.remove(x), axis=0) conv_output = quaternion_dense_relu_dense( x, hparams.filter_size, hparams.hidden_size, dropout=hparams.relu_dropout, dropout_broadcast_dims=relu_dropout_broadcast_dims) if pad_remover: # Restore `conv_output` to the original shape of `x`, including padding. conv_output = tf.reshape( pad_remover.restore(tf.squeeze(conv_output, axis=0)), original_shape) return conv_output elif ffn_layer == "raw_dense_relu_dense": # In simple convolution mode, use `pad_remover` to speed up processing. mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_FFN_FILTER_DENSE, value={ "filter_size": hparams.filter_size, "use_bias": "True", "activation": mlperf_log.RELU }) mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_FFN_OUTPUT_DENSE, value={ "hidden_size": hparams.hidden_size, "use_bias": "True", }) mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_RELU_DROPOUT, value=hparams.relu_dropout) if pad_remover: original_shape = common_layers.shape_list(x) # Collapse `x` across examples, and remove padding positions. x = tf.reshape(x, tf.concat([[-1], original_shape[2:]], axis=0)) x = tf.expand_dims(pad_remover.remove(x), axis=0) conv_output = common_layers.dense_relu_dense( x, hparams.filter_size, hparams.hidden_size, dropout=hparams.relu_dropout, dropout_broadcast_dims=relu_dropout_broadcast_dims) if pad_remover: # Restore `conv_output` to the original shape of `x`, including padding. conv_output = tf.reshape( pad_remover.restore(tf.squeeze(conv_output, axis=0)), original_shape) return conv_output elif ffn_layer == "conv_relu_conv": return common_layers.conv_relu_conv( x, readout_filter_size or hparams.filter_size, hparams.hidden_size, first_kernel_size=hparams.conv_first_kernel, second_kernel_size=1, padding=conv_padding, nonpadding_mask=nonpadding_mask, dropout=hparams.relu_dropout, cache=cache, decode_loop_step=decode_loop_step) elif ffn_layer == "parameter_attention": return common_attention.parameter_attention( x, hparams.parameter_attention_key_channels or hparams.hidden_size, hparams.parameter_attention_value_channels or hparams.hidden_size, hparams.hidden_size, readout_filter_size or hparams.filter_size, hparams.num_heads, hparams.attention_dropout) elif ffn_layer == "conv_hidden_relu_with_sepconv": return common_layers.conv_hidden_relu( x, readout_filter_size or hparams.filter_size, hparams.hidden_size, kernel_size=(3, 1), second_kernel_size=(31, 1), padding="LEFT", dropout=hparams.relu_dropout) elif ffn_layer == "sru": return common_layers.sru(x) elif ffn_layer == "local_moe_tpu": overhead = ( hparams.moe_overhead_train if hparams.mode == tf.estimator.ModeKeys.TRAIN else hparams.moe_overhead_eval) ret, loss = expert_utils.local_moe_tpu( x, hparams.filter_size // 2, hparams.hidden_size, hparams.moe_num_experts, overhead=overhead, loss_coef=hparams.moe_loss_coef) elif ffn_layer == "local_moe": overhead = ( hparams.moe_overhead_train if hparams.mode == tf.estimator.ModeKeys.TRAIN else hparams.moe_overhead_eval) ret, loss = expert_utils.local_moe( x, True, expert_utils.ffn_expert_fn(hparams.hidden_size, [hparams.filter_size], hparams.hidden_size), hparams.moe_num_experts, k=hparams.moe_k, hparams=hparams) losses.append(loss) return ret else: assert ffn_layer == "none" return x
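# Editor's sketch of the PadRemover speedup used in the dense branches above
# (illustrative; the real class lives in expert_utils): gather only the
# nonpadding rows, run the position-wise FFN on the packed tensor, then
# scatter the results back.
import tensorflow as tf

hidden_size, filter_size = 4, 8
keep = tf.constant([[1.0], [1.0], [0.0], [1.0], [0.0], [1.0]])
x = tf.random_normal([6, hidden_size]) * keep  # Rows 2 and 4 are padding.
nonpad_idx = tf.where(tf.reduce_sum(tf.abs(x), axis=-1) > 0)
packed = tf.gather_nd(x, nonpad_idx)  # [4, hidden_size]
h = tf.layers.dense(packed, filter_size, activation=tf.nn.relu)
y = tf.layers.dense(h, hidden_size)
restored = tf.scatter_nd(nonpad_idx, y, tf.shape(x, out_type=tf.int64))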
def variance_loss(self, b): # Draw a random selection rate, pick that fraction of positions at random, and penalize the absolute mean of b over the selected positions; this pushes the bottleneck activations towards being balanced around zero. The +1 in the denominator guards against an empty selection. part = tf.random_uniform(common_layers.shape_list(b)) selection = tf.to_float(tf.less(part, tf.random_uniform([]))) selection_size = tf.reduce_sum(selection) part_avg = tf.abs(tf.reduce_sum(b * selection)) / (selection_size + 1) return part_avg
def body(self, features): hparams = self.hparams is_predicting = hparams.mode == tf.estimator.ModeKeys.PREDICT # TODO(lukaszkaiser): the split axes and the argmax below heavily depend on # using the default (a bit strange) video modality - we should change that. # Split inputs and targets into lists. input_frames = tf.unstack(features["inputs"], axis=1) target_frames = tf.unstack(features["targets"], axis=1) all_frames = input_frames + target_frames if "input_action" in features: input_actions = list( tf.split(features["input_action"], hparams.video_num_input_frames, axis=1)) target_actions = list( tf.split(features["target_action"], hparams.video_num_target_frames, axis=1)) all_actions = input_actions + target_actions orig_frame_shape = common_layers.shape_list(all_frames[0]) # Run a number of steps. res_frames, sampled_frames, sampled_frames_raw = [], [], [] if "target_reward" in features: res_rewards, extra_loss = [], 0.0 sample_prob = common_layers.inverse_exp_decay( hparams.scheduled_sampling_warmup_steps) sample_prob *= hparams.scheduled_sampling_prob for i in range(hparams.video_num_target_frames): cur_frames = all_frames[i:i + hparams.video_num_input_frames] features["inputs"] = tf.concat(cur_frames, axis=-1) features["cur_target_frame"] = all_frames[ i + hparams.video_num_input_frames] if "input_action" in features: cur_actions = all_actions[i:i + hparams.video_num_input_frames] features["input_action"] = tf.concat(cur_actions, axis=1) # Run model. with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0): if "target_reward" not in features: res_frame = self.body_single(features) else: res_dict, res_extra_loss = self.body_single(features) extra_loss += res_extra_loss res_frame = res_dict["targets"] res_reward = res_dict["target_reward"] res_rewards.append(res_reward) res_frames.append(res_frame) # Only for Softmax loss: sample frame so we can keep iterating. sampled_frame_raw = self.get_sampled_frame(res_frame) sampled_frames_raw.append(sampled_frame_raw) # TODO(lukaszkaiser): this should be consistent with modality.bottom() sampled_frame = common_layers.standardize_images(sampled_frame_raw) sampled_frames.append(sampled_frame) if is_predicting: all_frames[i + hparams.video_num_input_frames] = sampled_frame # Scheduled sampling during training. if (hparams.scheduled_sampling_prob > 0.0 and self.is_training): do_sample = tf.less(tf.random_uniform([orig_frame_shape[0]]), sample_prob) orig_frame = all_frames[i + hparams.video_num_input_frames] sampled_frame = tf.where(do_sample, sampled_frame, orig_frame) all_frames[i + hparams.video_num_input_frames] = sampled_frame # Concatenate results and return them. frames = tf.stack(res_frames, axis=1) if "target_reward" not in features: return frames rewards = tf.concat(res_rewards, axis=1) return {"targets": frames, "target_reward": rewards}, extra_loss