def bottom(self, inputs):
  """Embeds a video batch by standardizing frames and folding time into channels.

  Args:
    inputs: a rank-5 Tensor laid out as [batch, time, height, width, channels].

  Returns:
    A rank-4 Tensor of shape [batch, height, width, time * channels].

  Raises:
    ValueError: if `inputs` is not rank 5.
  """
  with tf.variable_scope(self.name):
    shape = common_layers.shape_list(inputs)
    if len(shape) != 5:
      raise ValueError(
          "Assuming videos given as tensors in the format "
          "[batch, time, height, width, channels] but got one "
          "of shape: %s" % str(shape))

    if not tf.contrib.eager.in_eager_mode():
      # Image summaries are only emitted in graph mode.
      static_time = inputs.get_shape().as_list()[1]
      if static_time is None:
        # Number of frames unknown at graph-build time: summarize the last one.
        tf.summary.image("inputs_last_frame",
                         tf.cast(inputs[:, -1, :, :, :], tf.uint8),
                         max_outputs=1)
      else:
        for frame_idx in range(shape[1]):
          tf.summary.image("inputs_frame_%d" % frame_idx,
                           tf.cast(inputs[:, frame_idx, :, :, :], tf.uint8),
                           max_outputs=1)

    # Standardize every frame on its own by merging batch and time, then
    # restore the original video layout.
    flat_frames = tf.reshape(inputs, [-1] + shape[2:])
    flat_frames = common_layers.standardize_images(flat_frames)
    video = tf.reshape(flat_frames, shape)

    # Put time next to channels so image models see one stacked image.
    time_last = tf.transpose(video, [0, 2, 3, 1, 4])
    batch, time, height, width, channels = shape
    return tf.reshape(time_last, [batch, height, width, time * channels])
def get_sampled_frame(self, pred_frame):
  """Samples the frame based on modality.

  If the modality is L2/L1 then the next predicted frame is the next frame
  and there is no sampling (values are only snapped to integer RGB levels),
  but in case of Softmax loss the next actual frame should be sampled from
  the predicted frame. This enables multi-frame target prediction with
  Softmax loss.

  Args:
    pred_frame: predicted frame.

  Returns:
    sampled frame.
  """
  # TODO(lukaszkaiser): the logic below heavily depend on the current
  # (a bit strange) video modalities - we should change that.
  if not self.is_per_pixel_softmax:
    # Straight-through rounding: the forward value is round(x) while the
    # gradient w.r.t. x is the identity. Note the sign: the former
    # `x - stop_gradient(x + round(x))` evaluated to -round(x), i.e. it
    # returned a negated frame.
    x = common_layers.convert_real_to_rgb(pred_frame)
    x = x + tf.stop_gradient(tf.round(x) - x)
    return common_layers.convert_rgb_to_real(x)

  frame_shape = common_layers.shape_list(pred_frame)
  target_shape = frame_shape[:-1] + [self.hparams.problem.num_channels]
  # The last axis of the prediction holds num_channels * 256 logits.
  sampled_frame = tf.reshape(pred_frame, target_shape + [256])
  sampled_frame = pixels_from_softmax(
      sampled_frame, temperature=self.hparams.pixel_sampling_temperature)
  # TODO(lukaszkaiser): this should be consistent with modality.bottom()
  return common_layers.standardize_images(sampled_frame)
def get_sampled_frame(self, pred_frame):
  """Samples the frame based on modality.

  If the modality is L2/L1 then the next predicted frame is the next frame
  and there is no sampling (values are only snapped to integer RGB levels),
  but in case of Softmax loss the next actual frame should be sampled from
  the predicted frame. This enables multi-frame target prediction with
  Softmax loss.

  Args:
    pred_frame: predicted frame.

  Returns:
    sampled frame.
  """
  # TODO(lukaszkaiser): the logic below heavily depend on the current
  # (a bit strange) video modalities - we should change that.
  if not self.is_per_pixel_softmax:
    # Straight-through rounding: the forward value is round(x) while the
    # gradient w.r.t. x is the identity. Note the sign: the former
    # `x - stop_gradient(x + round(x))` evaluated to -round(x), i.e. it
    # returned a negated frame.
    x = common_layers.convert_real_to_rgb(pred_frame)
    x = x + tf.stop_gradient(tf.round(x) - x)
    return common_layers.convert_rgb_to_real(x)

  frame_shape = common_layers.shape_list(pred_frame)
  target_shape = frame_shape[:-1] + [self.hparams.problem.num_channels]
  # The last axis of the prediction holds num_channels * 256 logits.
  sampled_frame = tf.reshape(pred_frame, target_shape + [256])
  sampled_frame = pixels_from_softmax(
      sampled_frame, temperature=self.hparams.pixel_sampling_temperature)
  # TODO(lukaszkaiser): this should be consistent with modality.bottom()
  return common_layers.standardize_images(sampled_frame)
def get_sampled_frame(self, pred_frame):
  """Turns a model prediction into the frame fed to the next step.

  For non-softmax (L1/L2-style) modalities the prediction already is a frame
  and is returned unchanged. For per-pixel Softmax loss the prediction holds
  256 logits per channel value; the most likely value is taken and the
  resulting frame is standardized. This enables multi-frame target prediction
  with Softmax loss.

  Args:
    pred_frame: predicted frame.

  Returns:
    sampled frame.
  """
  if self.is_per_pixel_softmax:
    leading_dims = common_layers.shape_list(pred_frame)[:-1]
    num_channels = self.hparams.problem.num_channels
    logits = tf.reshape(pred_frame, leading_dims + [num_channels, 256])
    # TODO(lukaszkaiser): should this be argmax or real sampling.
    frame = tf.to_float(tf.argmax(logits, axis=-1))
    # TODO(lukaszkaiser): this should be consistent with modality.bottom()
    return common_layers.standardize_images(frame)
  return pred_frame
def bottom(self, inputs):
  """Transform input from data space to model space.

  Runs the Xception "entry flow": two convolutions (the first one strided)
  followed by three residually connected separable-convolution blocks, each
  of which halves the spatial resolution.

  Args:
    inputs: A Tensor with shape [batch, ...]

  Returns:
    body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
  """
  with tf.variable_scope(self.name):
    depth = self._body_input_depth

    def residual_block(net, num_filters, relu_on_shortcut, scope_name):
      # Separable-conv branch downsampled by max-pooling, summed with a
      # stride-2 1x1 convolution shortcut.
      with tf.variable_scope(scope_name):
        branch = common_layers.separable_conv_block(
            net, num_filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
            first_relu=True, padding="SAME", force2d=True,
            name="sep_conv_block")
        branch = common_layers.pool(
            branch, (3, 3), "MAX", "SAME", strides=(2, 2))
        shortcut = common_layers.conv_block(
            net, num_filters, [((1, 1), (1, 1))], padding="SAME",
            strides=(2, 2), first_relu=relu_on_shortcut, force2d=True,
            name="res_conv0")
        return branch + shortcut

    net = common_layers.standardize_images(inputs)
    # TODO(lukaszkaiser): summaries here don't work in multi-problem case yet.
    # tf.summary.image("inputs", inputs, max_outputs=2)
    net = common_layers.conv_block(
        net, 32, [((1, 1), (3, 3))], first_relu=False, padding="SAME",
        strides=(2, 2), force2d=True, name="conv0")
    net = common_layers.conv_block(
        net, 64, [((1, 1), (3, 3))], padding="SAME", force2d=True,
        name="conv1")
    net = residual_block(net, min(128, depth), True, "block0")
    net = residual_block(net, min(256, depth), False, "block1")
    return residual_block(net, depth, False, "block2")
def body(self, features):
  """Predicts next-frame logits from the current and previous frames.

  Both frames are standardized and concatenated on channels, then reshaped to
  [-1, 210, 160, 6]. Four stride-2 convolutions (210x160 -> 14x10) and a
  dense bottleneck encode them; four sub-pixel upsampling steps (with crops
  back to 27, 53 and 105 rows) decode to a 210x160 grid with 3 * 256
  channels.

  Args:
    features: dict with "inputs_0" (current frame) and "inputs_1" (previous
      frame).

  Returns:
    dict with "targets" (the decoded logits) and "reward" (those same logits
    flattened per example).
    NOTE(review): "reward" is just the flattened frame logits, which looks
    like a placeholder - confirm intent.
  """

  def upsample(net, index, kernel, out_channels, activation=tf.nn.relu):
    # Sub-pixel "deconvolution": convolve to 4x the channels, then trade them
    # for a 2x larger spatial grid.
    wide = common_layers.conv(
        net, out_channels * 4, kernel, padding="SAME",
        activation=activation, name="deconv2d" + str(index))
    return tf.depth_to_space(wide, 2)

  current = common_layers.standardize_images(features["inputs_0"])
  previous = common_layers.standardize_images(features["inputs_1"])
  net = tf.reshape(tf.concat([current, previous], axis=3), [-1, 210, 160, 6])

  # Encoder; layers are created in the same order so auto-generated names
  # match (conv2d, conv2d_1, ...).
  encoder_spec = ((64, (8, 8)), (128, (6, 6)), (128, (6, 6)), (128, (4, 4)))
  for num_filters, kernel in encoder_spec:
    net = tf.layers.conv2d(net, filters=num_filters, strides=2,
                           kernel_size=kernel, padding="SAME",
                           activation=tf.nn.relu)

  # Dense bottleneck.
  net = tf.reshape(net, [-1, 14 * 10 * 128])
  for units in (2048, 2048, 14 * 10 * 128):
    net = tf.layers.dense(net, units, activation=tf.nn.relu)
  net = tf.reshape(net, [-1, 14, 10, 128])

  # Decoder; crops undo the rounding-up of the SAME-padded strided encoder.
  net = upsample(net, 1, (4, 4), 128, activation=tf.nn.relu)[:, :27, :, :]
  net = upsample(net, 2, (6, 6), 128, activation=tf.nn.relu)[:, :53, :, :]
  net = upsample(net, 3, (6, 6), 128, activation=tf.nn.relu)[:, :105, :, :]
  logits = upsample(net, 4, (8, 8), 3 * 256, activation=tf.identity)
  return {"targets": logits, "reward": tf.layers.flatten(logits)}
def bottom(self, inputs):
  """Standardizes images and embeds them with a single 3x3 convolution.

  An image summary of the standardized inputs is also written.
  """
  with tf.variable_scope(self.name):
    standardized = common_layers.standardize_images(inputs)
    tf.summary.image("inputs", standardized, max_outputs=2)
    kernel_spec = [((1, 1), (3, 3))]
    return common_layers.conv_block(
        standardized,
        self._body_input_depth,
        kernel_spec,
        first_relu=False,
        padding="SAME",
        force2d=True,
        name="small_image_conv")
def bottom(self, inputs):
  """Standardizes each video frame and stacks the time axis into channels.

  Args:
    inputs: video Tensor indexed as [batch, time, height, width, channels]
      (the rank-5 transpose below requires this layout).

  Returns:
    Tensor of shape [batch, height, width, time * channels].
  """
  with tf.variable_scope(self.name):
    common_layers.summarize_video(inputs, "inputs")
    dims = common_layers.shape_list(inputs)
    # Standardize frame by frame: merge batch and time, then restore.
    frames = tf.reshape(inputs, [-1] + dims[2:])
    frames = common_layers.standardize_images(frames)
    video = tf.reshape(frames, dims)
    # Make time adjacent to channels, then merge the two axes.
    video = tf.transpose(video, [0, 2, 3, 1, 4])
    merged_depth = dims[1] * dims[4]
    return tf.reshape(video, [dims[0], dims[2], dims[3], merged_depth])
def bottom(self, x):
  """Standardizes each video frame and stacks the time axis into channels.

  Args:
    x: video Tensor indexed as [batch, time, height, width, channels]
      (the rank-5 transpose below requires this layout).

  Returns:
    Tensor of shape [batch, height, width, time * channels].
  """
  with tf.variable_scope(self.name):
    common_layers.summarize_video(x, "inputs")
    shape = common_layers.shape_list(x)
    per_frame = common_layers.standardize_images(
        tf.reshape(x, [-1] + shape[2:]))
    video = tf.reshape(per_frame, shape)
    # [batch, height, width, time, channels] so time can merge with channels.
    time_minor = tf.transpose(video, [0, 2, 3, 1, 4])
    out_shape = [shape[0], shape[2], shape[3], shape[1] * shape[4]]
    return tf.reshape(time_minor, out_shape)
def bottom(self, inputs):
  """Standardizes images and embeds them with one 3x3 convolution.

  The convolution uses stride 2 when `compress_steps` is positive and
  stride 1 otherwise.
  """
  with tf.variable_scope(self.name):
    images = common_layers.standardize_images(inputs)
    # TODO(lukaszkaiser): summaries here don't work in multi-problem case yet.
    # tf.summary.image("inputs", inputs, max_outputs=2)
    step = 2 if self._model_hparams.compress_steps > 0 else 1
    return common_layers.conv_block(
        images, self._body_input_depth, [((1, 1), (3, 3))],
        first_relu=False, strides=(step, step), padding="SAME",
        force2d=True, name="small_image_conv")
def xception_entry(inputs, hidden_dim):
  """Xception-style entry flow mapping images to `hidden_dim` channels.

  Two convolutions (32 channels with stride 2, then 64 channels) are followed
  by three residual separable-conv blocks, each downsampling by 2.

  Args:
    inputs: image Tensor; standardized before the first convolution.
    hidden_dim: channels of the last block; the first two blocks use
      min(128, hidden_dim) and min(256, hidden_dim).

  Returns:
    The output Tensor of the third residual block.
  """

  def _resblock(net, depth, shortcut_relu, scope):
    # Separable-conv branch downsampled by max-pooling, plus a stride-2 1x1
    # convolution shortcut.
    with tf.variable_scope(scope):
      main = common_layers.separable_conv_block(
          net, depth, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
          first_relu=True, padding="SAME", force2d=True,
          name="sep_conv_block")
      main = common_layers.pool(main, (3, 3), "MAX", "SAME", strides=(2, 2))
      shortcut = common_layers.conv_block(
          net, depth, [((1, 1), (1, 1))], padding="SAME", strides=(2, 2),
          first_relu=shortcut_relu, force2d=True, name="res_conv0")
      return main + shortcut

  with tf.variable_scope("xception_entry"):
    net = common_layers.standardize_images(inputs)
    # TODO(lukaszkaiser): summaries here don't work in multi-problem case yet.
    # tf.summary.image("inputs", inputs, max_outputs=2)
    net = common_layers.conv_block(
        net, 32, [((1, 1), (3, 3))], first_relu=False, padding="SAME",
        strides=(2, 2), force2d=True, name="conv0")
    net = common_layers.conv_block(
        net, 64, [((1, 1), (3, 3))], padding="SAME", force2d=True,
        name="conv1")
    widths = (min(128, hidden_dim), min(256, hidden_dim), hidden_dim)
    shortcut_relus = (True, False, False)
    for idx, (width, relu) in enumerate(zip(widths, shortcut_relus)):
      net = _resblock(net, width, relu, "block%d" % idx)
    return net
def bottom(self, inputs):
  """Embeds a video by standardizing frames and stacking time into channels.

  Args:
    inputs: Tensor laid out as [batch, time, height, width, channels].

  Returns:
    Tensor of shape [batch, height, width, time * channels].

  Raises:
    ValueError: if `inputs` is not rank 5 (the message reports the shape).
  """
  with tf.variable_scope(self.name):
    inputs_shape = common_layers.shape_list(inputs)
    if len(inputs_shape) != 5:
      # Include the offending shape so the failure is actionable.
      raise ValueError(
          "Assuming videos given as tensors in the format "
          "[batch, time, height, width, channels] but got one "
          "of shape: %s" % str(inputs_shape))
    if not context.in_eager_mode():
      # Summaries are graph-mode only; show the most recent frame.
      tf.summary.image("inputs", tf.cast(inputs[:, -1, :, :, :], tf.uint8),
                       max_outputs=1)
    # Standardize frames.
    inputs = tf.reshape(inputs, [-1] + inputs_shape[2:])
    inputs = common_layers.standardize_images(inputs)
    inputs = tf.reshape(inputs, inputs_shape)
    # Concatenate the time dimension on channels for image models to work.
    transposed = tf.transpose(inputs, [0, 2, 3, 1, 4])
    return tf.reshape(transposed, [
        inputs_shape[0], inputs_shape[2], inputs_shape[3],
        inputs_shape[1] * inputs_shape[4]
    ])
def xception_entry(inputs, hidden_dim):
  """Xception-style entry flow mapping images to `hidden_dim` channels.

  Args:
    inputs: image Tensor; standardized before the first convolution.
    hidden_dim: channels of the last residual block; the first two blocks use
      min(128, hidden_dim) and min(256, hidden_dim).

  Returns:
    The output Tensor of the third residual block.
  """

  def residual(net, filters, relu_first_on_shortcut, scope):
    # Two separable convs + stride-2 max-pool, added to a stride-2 1x1
    # convolution of the block input.
    with tf.variable_scope(scope):
      separable = common_layers.separable_conv_block(
          net,
          filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
          first_relu=True,
          padding="SAME",
          force2d=True,
          name="sep_conv_block")
      pooled = common_layers.pool(
          separable, (3, 3), "MAX", "SAME", strides=(2, 2))
      projected = common_layers.conv_block(
          net,
          filters, [((1, 1), (1, 1))],
          padding="SAME",
          strides=(2, 2),
          first_relu=relu_first_on_shortcut,
          force2d=True,
          name="res_conv0")
      return pooled + projected

  with tf.variable_scope("xception_entry"):
    standardized = common_layers.standardize_images(inputs)
    # TODO(lukaszkaiser): summaries here don't work in multi-problem case yet.
    # tf.summary.image("inputs", inputs, max_outputs=2)
    stem = common_layers.conv_block(
        standardized, 32, [((1, 1), (3, 3))],
        first_relu=False, padding="SAME", strides=(2, 2), force2d=True,
        name="conv0")
    stem = common_layers.conv_block(
        stem, 64, [((1, 1), (3, 3))], padding="SAME", force2d=True,
        name="conv1")
    block0 = residual(stem, min(128, hidden_dim), True, "block0")
    block1 = residual(block0, min(256, hidden_dim), False, "block1")
    return residual(block1, hidden_dim, False, "block2")
def testStandardizeImages(self):
  """standardize_images preserves the [batch, height, width, channels] shape."""
  expected_shape = (5, 7, 7, 3)
  images = np.random.rand(*expected_shape)
  with self.test_session() as sess:
    standardized = common_layers.standardize_images(tf.constant(images))
    result = sess.run(standardized)
    self.assertEqual(result.shape, expected_shape)
def body(self, features):
  """Autoregressively predicts `video_num_target_frames` frames.

  At step i the model sees the `video_num_input_frames` frames starting at
  index i (concatenated along the last axis) and predicts the next frame.
  Variables are shared across steps (reuse=i > 0). In PREDICT mode the
  sampled prediction replaces the following frame for later steps; in
  training it may do so per example with the scheduled-sampling probability.

  Args:
    features: dict with "inputs" and "targets" video Tensors (unstacked
      along axis 1 into frames), optional "input_action"/"target_action"
      (split along axis 1) and optional "target_reward". NOTE: "inputs",
      "cur_target_frame" and "input_action" are overwritten in place.

  Returns:
    Without "target_reward": the per-step predictions stacked along axis 1.
    With "target_reward": ({"targets": frames, "target_reward": rewards},
    extra_loss), rewards concatenated along axis 1 and extra_loss the sum of
    the per-step extra losses.
  """
  hparams = self.hparams
  is_predicting = hparams.mode == tf.estimator.ModeKeys.PREDICT

  # TODO(lukaszkaiser): the split axes and the argmax below heavily depend on
  # using the default (a bit strange) video modality - we should change that.

  # Split inputs and targets into lists.
  input_frames = tf.unstack(features["inputs"], axis=1)
  target_frames = tf.unstack(features["targets"], axis=1)
  all_frames = input_frames + target_frames
  if "input_action" in features:
    input_actions = list(
        tf.split(features["input_action"], hparams.video_num_input_frames,
                 axis=1))
    target_actions = list(
        tf.split(features["target_action"], hparams.video_num_target_frames,
                 axis=1))
    all_actions = input_actions + target_actions

  orig_frame_shape = common_layers.shape_list(all_frames[0])

  # Run a number of steps.
  # NOTE(review): sampled_frames / sampled_frames_raw are filled but never
  # read in this function.
  res_frames, sampled_frames, sampled_frames_raw = [], [], []
  if "target_reward" in features:
    # extra_loss only exists (and is only returned) on the reward path.
    res_rewards, extra_loss = [], 0.0

  # Probability of feeding back the model's own frame instead of the ground
  # truth (presumably ramps up over the warmup steps - see inverse_exp_decay).
  sample_prob = common_layers.inverse_exp_decay(
      hparams.scheduled_sampling_warmup_steps)
  sample_prob *= hparams.scheduled_sampling_prob
  for i in range(hparams.video_num_target_frames):
    # Sliding window of the most recent input frames, stacked on channels.
    cur_frames = all_frames[i:i + hparams.video_num_input_frames]
    features["inputs"] = tf.concat(cur_frames, axis=-1)
    features["cur_target_frame"] = all_frames[
        i + hparams.video_num_input_frames]
    if "input_action" in features:
      cur_actions = all_actions[i:i + hparams.video_num_input_frames]
      features["input_action"] = tf.concat(cur_actions, axis=1)

    # Run model.
    with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
      if "target_reward" not in features:
        res_frame = self.body_single(features)
      else:
        res_dict, res_extra_loss = self.body_single(features)
        extra_loss += res_extra_loss
        res_frame = res_dict["targets"]
        res_reward = res_dict["target_reward"]
        res_rewards.append(res_reward)
    res_frames.append(res_frame)

    # Only for Softmax loss: sample frame so we can keep iterating.
    sampled_frame_raw = self.get_sampled_frame(res_frame)
    sampled_frames_raw.append(sampled_frame_raw)
    # TODO(lukaszkaiser): this should be consistent with modality.bottom()
    sampled_frame = common_layers.standardize_images(sampled_frame_raw)
    sampled_frames.append(sampled_frame)

    if is_predicting:
      # Feed the prediction back as the next step's newest input frame.
      all_frames[i + hparams.video_num_input_frames] = sampled_frame

    # Scheduled sampling during training.
    if (hparams.scheduled_sampling_prob > 0.0 and self.is_training):
      # One coin flip per example (size of the batch dimension).
      do_sample = tf.less(tf.random_uniform([orig_frame_shape[0]]),
                          sample_prob)
      orig_frame = all_frames[i + hparams.video_num_input_frames]
      sampled_frame = tf.where(do_sample, sampled_frame, orig_frame)
      all_frames[i + hparams.video_num_input_frames] = sampled_frame

  # Concatenate results and return them.
  frames = tf.stack(res_frames, axis=1)

  if "target_reward" not in features:
    return frames
  rewards = tf.concat(res_rewards, axis=1)
  return {"targets": frames, "target_reward": rewards}, extra_loss
def targets_bottom(self, x):
  """Writes a GIF summary of the target video and returns it standardized."""
  common_video.gif_summary("targets", x, max_outputs=1)
  return common_layers.standardize_images(x)
def bottom(self, inputs):
  """Standardizes images and returns them cast with tf.to_float.

  An image summary is written only outside eager mode.
  """
  with tf.variable_scope(self.name):
    standardized = common_layers.standardize_images(inputs)
    if context.in_eager_mode():
      return tf.to_float(standardized)
    tf.summary.image("inputs", standardized, max_outputs=2)
    return tf.to_float(standardized)
def next_frame(self, frames, actions, rewards, target_frame,
               internal_states, video_extra):
  """Predicts the next frame plus optional reward, policy and value heads.

  The stacked input frames are embedded, downsampled by
  `hparams.num_compress_steps` strided convolutions, passed through
  `self.middle_network`, and upsampled back with skip connections. The
  final features become per-pixel outputs, optionally refined by an
  autoregressive RNN over sub-pixel bits (`hparams.do_autoregressive_rnn`).

  Args:
    frames: list of input frame Tensors (standardized here, then
      concatenated along the last axis).
    actions: list of action Tensors; only the last one is used.
    rewards: unused.
    target_frame: the frame being predicted; passed (standardized) to
      `inject_latent` and, in autoregressive mode, used as the RNN target.
    internal_states: recurrent state for the middle network (may be None).
    video_extra: unused.

  Returns:
    Tuple (x, reward_pred, policy_pred, value_pred, extra_loss,
    internal_states). reward_pred is None unless `self.has_rewards`;
    policy_pred and value_pred are None unless `self.has_actions`.
  """
  del rewards, video_extra

  hparams = self.hparams
  filters = hparams.hidden_size
  kernel2 = (4, 4)
  action = actions[-1]
  activation_fn = common_layers.belu
  if self.hparams.activation_fn == "relu":
    activation_fn = tf.nn.relu

  # Normalize frames.
  frames = [common_layers.standardize_images(f) for f in frames]

  # Stack the inputs.
  if internal_states is not None and hparams.concat_internal_states:
    # Use the first part of the first internal state if asked to concatenate.
    batch_size = common_layers.shape_list(frames[0])[0]
    internal_state = internal_states[0][0][:batch_size, :, :, :]
    stacked_frames = tf.concat(frames + [internal_state], axis=-1)
  else:
    stacked_frames = tf.concat(frames, axis=-1)
  inputs_shape = common_layers.shape_list(stacked_frames)

  # Update internal states early if requested.
  if hparams.concat_internal_states:
    internal_states = self.update_internal_states_early(
        internal_states, frames)

  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      stacked_frames, filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      x = common_layers.make_even_size(x)
      if i < hparams.filter_double_steps:
        filters *= 2
      x = common_attention.add_timing_signal_nd(x)
      x = tf.layers.conv2d(x, filters, kernel2, activation=activation_fn,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  if self.has_actions:
    with tf.variable_scope("policy"):
      x_flat = tf.layers.flatten(x)
      policy_pred = tf.layers.dense(x_flat, self.hparams.problem.num_actions)
      value_pred = tf.layers.dense(x_flat, 1)
      value_pred = tf.squeeze(value_pred, axis=-1)
  else:
    policy_pred, value_pred = None, None

  # Add embedded action if present.
  if self.has_actions:
    x = common_video.inject_additional_input(
        x, action, "action_enc", hparams.action_injection)

  # Inject latent if present. Only for stochastic models.
  norm_target_frame = common_layers.standardize_images(target_frame)
  x, extra_loss = self.inject_latent(x, frames, norm_target_frame, action)

  x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x, internal_states = self.middle_network(x, internal_states)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(hparams.num_compress_steps):
    with tf.variable_scope("upstride%d" % i):
      x = tf.nn.dropout(x, 1.0 - self.hparams.dropout)
      if self.has_actions:
        x = common_video.inject_additional_input(
            x, action, "action_enc", hparams.action_injection)
      if i >= hparams.num_compress_steps - hparams.filter_double_steps:
        filters //= 2
      x = tf.layers.conv2d_transpose(
          x, filters, kernel2, activation=activation_fn,
          strides=(2, 2), padding="SAME")
      # Crop to the matching down-stride input (make_even_size may have
      # padded), then add it as a skip connection.
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]
  x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  if hparams.do_autoregressive_rnn:
    # If enabled, we predict the target frame autoregressively using RNNs.
    # To this end, the current prediction is flattened into one long sequence
    # of sub-pixels, and so is the target frame. Each sub-pixel (RGB value,
    # from 0 to 255) is predicted with an RNN. To avoid doing as many steps
    # as width * height * channels, we only use a number of pixels back,
    # as many as hparams.autoregressive_rnn_lookback.
    with tf.variable_scope("autoregressive_rnn"):
      batch_size = common_layers.shape_list(frames[0])[0]
      # Height, width, channels and lookback are the constants we need.
      h, w = inputs_shape[1], inputs_shape[2]  # 105, 80 on Atari games
      c = hparams.problem.num_channels
      lookback = hparams.autoregressive_rnn_lookback
      # The h * w pixels are cut into chunks of `lookback` pixels, so the
      # lookback must divide the pixel count (the old message had it
      # backwards).
      assert (h * w) % lookback == 0, (
          "Lookback must divide the number of pixels.")
      m = (h * w) // lookback  # Batch size multiplier for the RNN.
      # These are logits that will be used as inputs to the RNN.
      rnn_inputs = tf.layers.dense(x, c * 64, name="rnn_inputs")
      # They are of shape [batch_size, h, w, c * 64]; regroup them into m
      # sequences of `lookback` pixels per example.
      rnn_inputs = tf.reshape(rnn_inputs, [batch_size * m, lookback * c, 64])
      # Same for the target frame.
      rnn_target = tf.reshape(target_frame, [batch_size * m, lookback * c])
      # Construct rnn starting state: flatten rnn_inputs, apply a relu layer.
      rnn_start_state = tf.nn.relu(tf.layers.dense(
          tf.nn.relu(tf.layers.flatten(rnn_inputs)), 256,
          name="rnn_start_state"))
      # Our RNN function API is on bits, each subpixel has 8 bits.
      total_num_bits = lookback * c * 8
      # We need to provide RNN targets as bits (due to the API).
      rnn_target_bits = discretization.int_to_bit(rnn_target, 8)
      rnn_target_bits = tf.reshape(
          rnn_target_bits, [batch_size * m, total_num_bits])
      if self.is_training:
        # Run the RNN in training mode, add its loss to the losses.
        rnn_predict, rnn_loss = discretization.predict_bits_with_lstm(
            rnn_start_state, 128, total_num_bits, target_bits=rnn_target_bits,
            extra_inputs=rnn_inputs)
        extra_loss += rnn_loss
        # We still use non-RNN predictions too in order to guide the network.
        x = tf.layers.dense(x, c * 256, name="logits")
        x = tf.reshape(x, [batch_size, h, w, c, 256])
        rnn_predict = tf.reshape(rnn_predict, [batch_size, h, w, c, 256])
        # Mix non-RNN and RNN predictions so that after warmup the RNN is 90%.
        x = tf.reshape(tf.nn.log_softmax(x), [batch_size, h, w, c * 256])
        rnn_predict = tf.nn.log_softmax(rnn_predict)
        rnn_predict = tf.reshape(rnn_predict, [batch_size, h, w, c * 256])
        alpha = 0.9 * common_layers.inverse_lin_decay(
            hparams.autoregressive_rnn_warmup_steps)
        x = alpha * rnn_predict + (1.0 - alpha) * x
      else:
        # In prediction mode, run the RNN without any targets.
        bits, _ = discretization.predict_bits_with_lstm(
            rnn_start_state, 128, total_num_bits, extra_inputs=rnn_inputs,
            temperature=0.0)  # No sampling from this RNN, just greedy.
        # The output is in bits, get back the predicted pixels.
        bits = tf.reshape(bits, [batch_size * m, lookback * c, 8])
        ints = discretization.bit_to_int(tf.maximum(bits, 0), 8)
        ints = tf.reshape(ints, [batch_size, h, w, c])
        x = tf.reshape(tf.one_hot(ints, 256), [batch_size, h, w, c * 256])
  elif self.is_per_pixel_softmax:
    x = tf.layers.dense(x, hparams.problem.num_channels * 256, name="logits")
  else:
    x = tf.layers.dense(x, hparams.problem.num_channels, name="logits")

  reward_pred = None
  if self.has_rewards:
    # Reward prediction based on middle and final logits.
    reward_pred = tf.concat([x_mid, x_fin], axis=-1)
    reward_pred = tf.nn.relu(tf.layers.dense(
        reward_pred, 128, name="reward_pred"))
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

  return x, reward_pred, policy_pred, value_pred, extra_loss, internal_states
def body(self, features):
  """Autoregressively predicts `video_num_target_frames` frames.

  At step i the model sees the `video_num_input_frames` frames starting at
  index i (concatenated along the last axis) and predicts the next frame.
  Variables are shared across steps (reuse=i > 0). In PREDICT mode the
  sampled prediction replaces the following frame for later steps; in
  training, `get_scheduled_sample_inputs` chooses between it and the ground
  truth.

  Args:
    features: dict with "inputs" and "targets" video Tensors (unstacked
      along axis 1 into frames), optional "input_action"/"target_action"
      (split along axis 1) and optional "target_reward". NOTE: "inputs",
      "cur_target_frame" and "input_action" are overwritten in place.

  Returns:
    Without "target_reward": (frames, extra_loss).
    With "target_reward": ({"targets": frames, "target_reward": rewards},
    extra_loss).
    frames stacks the per-step predictions along axis 1, rewards are
    concatenated along axis 1, and extra_loss is the per-step extra loss
    averaged over the target frames.
  """
  hparams = self.hparams
  is_predicting = hparams.mode == tf.estimator.ModeKeys.PREDICT

  # TODO(lukaszkaiser): the split axes and the argmax below heavily depend on
  # using the default (a bit strange) video modality - we should change that.

  # Split inputs and targets into lists.
  input_frames = tf.unstack(features["inputs"], axis=1)
  target_frames = tf.unstack(features["targets"], axis=1)
  all_frames = input_frames + target_frames
  if "input_action" in features:
    input_actions = list(
        tf.split(features["input_action"], hparams.video_num_input_frames,
                 axis=1))
    target_actions = list(
        tf.split(features["target_action"], hparams.video_num_target_frames,
                 axis=1))
    all_actions = input_actions + target_actions

  orig_frame_shape = common_layers.shape_list(all_frames[0])
  batch_size = orig_frame_shape[0]
  ss_func = self.get_scheduled_sample_func(batch_size)

  # Run a number of steps.
  res_frames, sampled_frames, sampled_frames_raw = [], [], []
  extra_loss = 0.0
  if "target_reward" in features:
    res_rewards = []
  for i in range(hparams.video_num_target_frames):
    # Sliding window of the most recent input frames, stacked on channels.
    cur_frames = all_frames[i:i + hparams.video_num_input_frames]
    features["inputs"] = tf.concat(cur_frames, axis=-1)
    features["cur_target_frame"] = all_frames[
        i + hparams.video_num_input_frames]
    if "input_action" in features:
      cur_actions = all_actions[i:i + hparams.video_num_input_frames]
      features["input_action"] = tf.concat(cur_actions, axis=1)

    # Run model.
    with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
      if "target_reward" not in features:
        res_frame, res_extra_loss = self.body_single(features)
      else:
        res_dict, res_extra_loss = self.body_single(features)
        res_frame = res_dict["targets"]
        res_reward = res_dict["target_reward"]
        res_rewards.append(res_reward)
      extra_loss += res_extra_loss / float(hparams.video_num_target_frames)
    res_frames.append(res_frame)

    # Only for Softmax loss: sample frame so we can keep iterating.
    sampled_frame_raw = self.get_sampled_frame(res_frame)
    sampled_frames_raw.append(sampled_frame_raw)
    # TODO(lukaszkaiser): this should be consistent with modality.bottom()
    sampled_frame = common_layers.standardize_images(sampled_frame_raw)
    sampled_frames.append(sampled_frame)

    if is_predicting:
      all_frames[i + hparams.video_num_input_frames] = sampled_frame

    # Scheduled sampling during training.
    if self.is_training:
      done_warm_start = True  # Always true for non-recurrent networks.
      groundtruth_items = [all_frames[i + hparams.video_num_input_frames]]
      generated_items = [sampled_frame]
      ss_frame, = self.get_scheduled_sample_inputs(
          done_warm_start, groundtruth_items, generated_items, ss_func)
      all_frames[i + hparams.video_num_input_frames] = ss_frame

  # Concatenate results and return them.
  frames = tf.stack(res_frames, axis=1)

  if "target_reward" not in features:
    # The per-step extra losses were accumulated above; return them on this
    # path too instead of silently dropping them.
    return frames, extra_loss
  rewards = tf.concat(res_rewards, axis=1)
  return {"targets": frames, "target_reward": rewards}, extra_loss
def network(self):
  """Builds the encoder/decoder graph over `self.states_ph`.

  States are standardized, embedded, downsampled by `self.layers` strided
  convolutions, conditioned on `self.actions_oph`, passed through a small
  residual conv stack, and upsampled back with skip connections to
  per-pixel outputs of depth `self.depth`.

  Returns:
    Tuple (x, reward_pred, policy_pred, value_pred). reward_pred is None
    unless `self.has_rewards`; policy_pred and value_pred are None unless
    `self.is_policy`.
  """

  def middle_network(layer):
    """Two 3x3 conv layers; the second is residual with layer norm."""
    x = layer
    kernel1 = (3, 3)
    filters = common_layers.shape_list(x)[-1]
    for i in range(2):
      with tf.variable_scope("layer%d" % i):
        # NOTE(review): dropout here is hard-coded to 0.5, unlike the
        # down-stride path which uses self.dropout_p.
        y = tf.nn.dropout(x, 1.0 - 0.5)
        y = tf.layers.conv2d(y, filters, kernel1,
                             activation=self.activation_fn,
                             strides=(1, 1), padding="SAME")
        if i == 0:
          x = y
        else:
          x = common_layers.layer_norm(x + y)
    return x

  filters = self.hidden_size
  kernel2 = (4, 4)
  action = self.actions_oph

  # Normalize states (per environment when several are batched).
  if self.n_envs > 1:
    states = [
        common_layers.standardize_images(self.states_ph[i, :, :, :])
        for i in range(self.n_envs)
    ]
    stacked_states = tf.stack(states)
  else:
    stacked_states = common_layers.standardize_images(self.states_ph)
  inputs_shape = common_layers.shape_list(stacked_states)

  # Using non-zero bias initializer below for edge cases of uniform inputs.
  x = tf.layers.dense(
      stacked_states, filters, name="inputs_embed",
      bias_initializer=tf.random_normal_initializer(stddev=0.01))
  x = common_attention.add_timing_signal_nd(x)

  # Down-stride.
  layer_inputs = [x]
  for i in range(self.layers):
    with tf.variable_scope("downstride%d" % i):
      layer_inputs.append(x)
      x = tf.nn.dropout(x, 1.0 - self.dropout_p)
      x = common_layers.make_even_size(x)
      if i < 2:
        filters *= 2
      x = common_attention.add_timing_signal_nd(x)
      x = tf.layers.conv2d(x, filters, kernel2, activation=self.activation_fn,
                           strides=(2, 2), padding="SAME")
      x = common_layers.layer_norm(x)

  if self.is_policy:
    with tf.variable_scope("policy"):
      x_flat = tf.layers.flatten(x)
      policy_pred = tf.layers.dense(x_flat, self.action_dim)
      value_pred = tf.layers.dense(x_flat, 1)
      value_pred = tf.squeeze(value_pred, axis=-1)
  else:
    policy_pred, value_pred = None, None

  # Condition on the action.
  x = inject_additional_input(x, action, "action_enc", "multi_additive")

  x_mid = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x = middle_network(x)

  # Up-convolve.
  layer_inputs = list(reversed(layer_inputs))
  for i in range(self.layers):
    with tf.variable_scope("upstride%d" % i):
      # NOTE(review): dropout here is hard-coded to 0.1 - confirm intent.
      x = tf.nn.dropout(x, 1.0 - 0.1)
      if i >= self.layers - 2:
        filters //= 2
      x = tf.layers.conv2d_transpose(
          x, filters, kernel2, activation=self.activation_fn,
          strides=(2, 2), padding="SAME")
      # Crop to the matching down-stride input, then add it as a skip.
      y = layer_inputs[i]
      shape = common_layers.shape_list(y)
      x = x[:, :shape[1], :shape[2], :]
      x = common_layers.layer_norm(x + y)
      x = common_attention.add_timing_signal_nd(x)

  # Cut down to original size.
  x = x[:, :inputs_shape[1], :inputs_shape[2], :]
  x_fin = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x = tf.layers.dense(x, self.depth, name="logits")

  reward_pred = None
  if self.has_rewards:
    # Reward prediction based on middle and final logits.
    reward_pred = tf.concat([x_mid, x_fin], axis=-1)
    reward_pred = tf.nn.relu(tf.layers.dense(
        reward_pred, 128, name="reward_pred"))
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims
    reward_pred = tf.squeeze(reward_pred, axis=1)  # Remove extra dims

  return x, reward_pred, policy_pred, value_pred
def bottom(self, x):
  """Standardizes video frames and folds the time axis into channels.

  Runs under this modality's variable scope with AUTO_REUSE and writes a
  video summary of the raw inputs.
  """
  with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
    common_layers.summarize_video(x, "inputs")
    standardized = common_layers.standardize_images(x)
    return common_layers.time_to_channels(standardized)
def body(self, features):
  """Unrolls the recurrent frame predictor over every frame transition.

  Step i consumes frame i (and action i, if present) and predicts frame i+1,
  carrying LSTM state across steps; variables are shared via AUTO_REUSE.
  In PREDICT mode, predictions of target frames are fed back as inputs.

  Args:
    features: dict with "inputs" and "targets" video Tensors (unstacked
      along axis 1), optional "input_action"/"target_action", optional
      "target_reward", and "inputs_raw" when the internal loss is enabled.
      NOTE: "inputs" and "cur_target_frame" are overwritten in place.

  Returns:
    (frames, extra_loss) without rewards, otherwise
    ({"targets": frames, "target_reward": rewards}, extra_loss). frames
    stacks the predictions of the target frames along axis 1.
  """
  hparams = self.hparams
  self.has_action = "input_action" in features
  self.has_reward = "target_reward" in features
  # dirty hack to enable the latent tower
  self.features = features

  # Split inputs and targets into lists.
  input_frames = tf.unstack(features["inputs"], axis=1)
  target_frames = tf.unstack(features["targets"], axis=1)
  all_frames = input_frames + target_frames
  if self.has_action:
    input_actions = tf.unstack(features["input_action"], axis=1)
    target_actions = tf.unstack(features["target_action"], axis=1)
    all_actions = input_actions + target_actions

  res_frames, sampled_frames, sampled_frames_raw, res_rewards = [], [], [], []
  lstm_states = [None] * hparams.num_lstm_layers
  extra_loss = 0.0

  num_frames = len(all_frames)
  for i in range(num_frames - 1):
    frame = all_frames[i]
    action = all_actions[i] if self.has_action else None

    # more hack to enable latent_tower
    # TODO(mbz): clean this up.
    self.features["inputs"] = all_frames[i]
    self.features["cur_target_frame"] = all_frames[i+1]

    # Run model.
    with tf.variable_scope("recurrent_model", reuse=tf.AUTO_REUSE):
      func_out = self.predict_next_frame(frame, action, lstm_states)
      res_frame, res_reward, res_extra_loss, lstm_states = func_out
      res_frames.append(res_frame)
      res_rewards.append(res_reward)
      extra_loss += res_extra_loss

    sampled_frame_raw = self.get_sampled_frame(res_frame)
    sampled_frames_raw.append(sampled_frame_raw)
    # TODO(lukaszkaiser): this should be consistent with modality.bottom()
    sampled_frame = common_layers.standardize_images(sampled_frame_raw)
    sampled_frames.append(sampled_frame)

    # Only for Softmax loss: sample next frame so we can keep iterating.
    # res_frame predicts frame i+1, so the first target frame (index
    # video_num_input_frames) is produced at step video_num_input_frames - 1
    # and must be fed back from that step on; otherwise the next step would
    # consume the placeholder target frame.
    if self.is_predicting and i >= hparams.video_num_input_frames - 1:
      all_frames[i+1] = sampled_frame

  # Concatenate results and return them.
  output_frames = res_frames[hparams.video_num_input_frames-1:]
  frames = tf.stack(output_frames, axis=1)

  has_input_predictions = hparams.video_num_input_frames > 1
  if self.is_training and hparams.internal_loss and has_input_predictions:
    # add the loss for input frames as well.
    extra_gts = input_frames[1:]
    extra_pds = res_frames[:hparams.video_num_input_frames-1]
    extra_raw_gts = features["inputs_raw"][:, 1:]
    recon_loss = self.get_extra_internal_loss(
        extra_raw_gts, extra_gts, extra_pds)
    extra_loss += recon_loss

  if not self.has_reward:
    return frames, extra_loss
  rewards = tf.concat(res_rewards[hparams.video_num_input_frames-1:], axis=1)
  return {"targets": frames, "target_reward": rewards}, extra_loss
def bottom(self, inputs):
  """Standardizes images and returns them cast with tf.to_float.

  An image summary is written only outside eager mode.
  """
  with tf.variable_scope(self.name):
    standardized = common_layers.standardize_images(inputs)
    eager = context.in_eager_mode()
    if not eager:
      tf.summary.image("inputs", standardized, max_outputs=2)
    return tf.to_float(standardized)
def body(self, features):
  """Predicts `video_num_target_frames` frames one step at a time.

  Frames are packed along the last axis: "inputs" is split into
  `video_num_input_frames` chunks and "targets" into
  `video_num_target_frames` chunks along axis -1. At step i the window of
  `video_num_input_frames` frames starting at i is concatenated back and
  fed to `body_single`, with variables shared across steps (reuse=i > 0).
  The prediction is decoded by argmax over 256 values per channel and
  standardized; it replaces the next frame in PREDICT mode, and per
  example with the scheduled-sampling probability during training.

  Args:
    features: dict with "inputs" and "targets" (frames packed on the last
      axis), optional "input_action"/"target_action" (split along axis 1)
      and optional "target_reward". NOTE: "inputs" and "input_action" are
      overwritten in place.

  Returns:
    If video_num_target_frames < 2: whatever `body_single` returns.
    Without "target_reward": predictions concatenated along the last axis.
    With "target_reward": ({"targets": frames, "target_reward": rewards},
    extra_loss), rewards concatenated along axis 1.
  """
  hparams = self.hparams
  is_predicting = hparams.mode == tf.estimator.ModeKeys.PREDICT
  if hparams.video_num_target_frames < 2:
    # Single target frame: no unrolling needed.
    res = self.body_single(features)
    return res

  # TODO(lukaszkaiser): the split axes and the argmax below heavily depend on
  # using the default (a bit strange) video modality - we should change that.

  # Split inputs and targets into lists.
  input_frames = list(
      tf.split(features["inputs"], hparams.video_num_input_frames, axis=-1))
  target_frames = list(
      tf.split(features["targets"], hparams.video_num_target_frames, axis=-1))
  all_frames = input_frames + target_frames
  if "input_action" in features:
    input_actions = list(
        tf.split(features["input_action"], hparams.video_num_input_frames,
                 axis=1))
    target_actions = list(
        tf.split(features["target_action"], hparams.video_num_target_frames,
                 axis=1))
    all_actions = input_actions + target_actions

  # Run a number of steps.
  res_frames = []
  if "target_reward" in features:
    # extra_loss only exists (and is only returned) on the reward path.
    res_rewards, extra_loss = [], 0.0

  # Probability of feeding back the model's own frame instead of the ground
  # truth (presumably ramps up over the warmup steps - see inverse_exp_decay).
  sample_prob = common_layers.inverse_exp_decay(
      hparams.scheduled_sampling_warmup_steps)
  sample_prob *= hparams.scheduled_sampling_prob
  for i in range(hparams.video_num_target_frames):
    # Sliding window of the most recent input frames, stacked on channels.
    cur_frames = all_frames[i:i + hparams.video_num_input_frames]
    features["inputs"] = tf.concat(cur_frames, axis=-1)
    if "input_action" in features:
      cur_actions = all_actions[i:i + hparams.video_num_input_frames]
      features["input_action"] = tf.concat(cur_actions, axis=1)

    # Run model.
    with tf.variable_scope(tf.get_variable_scope(), reuse=i > 0):
      if "target_reward" not in features:
        res_frames.append(self.body_single(features))
      else:
        res_dict, res_extra_loss = self.body_single(features)
        extra_loss += res_extra_loss
        res_frames.append(res_dict["targets"])
        res_rewards.append(res_dict["target_reward"])

    # When predicting, use the generated frame.
    orig_frame = all_frames[i + hparams.video_num_input_frames]
    shape = common_layers.shape_list(orig_frame)
    # Assumes the prediction's last axis holds num_channels * 256 logits.
    sampled_frame = tf.reshape(
        res_frames[-1], shape[:-1] + [hparams.problem.num_channels, 256])
    sampled_frame = tf.to_float(tf.argmax(sampled_frame, axis=-1))
    sampled_frame = common_layers.standardize_images(sampled_frame)
    if is_predicting:
      all_frames[i + hparams.video_num_input_frames] = sampled_frame

    # Scheduled sampling during training.
    if (hparams.scheduled_sampling_prob > 0.0 and self.is_training):
      # One coin flip per example (size of the batch dimension).
      do_sample = tf.less(tf.random_uniform([shape[0]]), sample_prob)
      sampled_frame = tf.where(do_sample, sampled_frame, orig_frame)
      all_frames[i + hparams.video_num_input_frames] = sampled_frame

  # Concatenate results and return them.
  frames = tf.concat(res_frames, axis=-1)

  if "target_reward" not in features:
    return frames
  rewards = tf.concat(res_rewards, axis=1)
  return {"targets": frames, "target_reward": rewards}, extra_loss