def construct_predictive_tower(
    self, input_image, input_reward, action, lstm_state, latent,
    concat_latent=False):
  # Main tower
  lstm_func = common_video.conv_lstm_2d
  frame_shape = common_layers.shape_list(input_image)
  batch_size, img_height, img_width, color_channels = frame_shape
  # the number of different pixel motion predictions
  # and the number of masks for each of those predictions
  num_masks = self.hparams.num_masks
  upsample_method = self.hparams.upsample_method
  tile_and_concat = common_video.tile_and_concat

  lstm_size = self.tinyify([32, 32, 64, 64, 128, 64, 32])
  conv_size = self.tinyify([32])

  with tf.variable_scope("main", reuse=tf.AUTO_REUSE):
    hidden5, skips, layer_id = self.bottom_part_tower(
        input_image, input_reward, action, latent,
        lstm_state, lstm_size, conv_size,
        concat_latent=concat_latent)
    enc0, enc1 = skips

    with tf.variable_scope("upsample1", reuse=tf.AUTO_REUSE):
      enc4 = common_layers.cyclegan_upsample(
          hidden5, num_outputs=hidden5.shape.as_list()[-1],
          stride=[2, 2], method=upsample_method)
      enc1_shape = common_layers.shape_list(enc1)
      enc4 = enc4[:, :enc1_shape[1], :enc1_shape[2], :]  # Cut to shape.
      enc4 = tile_and_concat(enc4, latent, concat_latent=concat_latent)

      hidden6, lstm_state[layer_id] = lstm_func(
          enc4, lstm_state[layer_id], lstm_size[5], name="state6",
          spatial_dims=enc1_shape[1:-1])  # 16x16
      hidden6 = tile_and_concat(hidden6, latent, concat_latent=concat_latent)
      hidden6 = tfcl.layer_norm(hidden6, scope="layer_norm7")
      # Skip connection.
      hidden6 = tf.concat(axis=3, values=[hidden6, enc1])  # both 16x16
      layer_id += 1

    with tf.variable_scope("upsample2", reuse=tf.AUTO_REUSE):
      enc5 = common_layers.cyclegan_upsample(
          hidden6, num_outputs=hidden6.shape.as_list()[-1],
          stride=[2, 2], method=upsample_method)
      enc0_shape = common_layers.shape_list(enc0)
      enc5 = enc5[:, :enc0_shape[1], :enc0_shape[2], :]  # Cut to shape.
      enc5 = tile_and_concat(enc5, latent, concat_latent=concat_latent)

      hidden7, lstm_state[layer_id] = lstm_func(
          enc5, lstm_state[layer_id], lstm_size[6], name="state7",
          spatial_dims=enc0_shape[1:-1])  # 32x32
      hidden7 = tfcl.layer_norm(hidden7, scope="layer_norm8")
      layer_id += 1

      # Skip connection.
      hidden7 = tf.concat(axis=3, values=[hidden7, enc0])  # both 32x32

    with tf.variable_scope("upsample3", reuse=tf.AUTO_REUSE):
      enc6 = common_layers.cyclegan_upsample(
          hidden7, num_outputs=hidden7.shape.as_list()[-1],
          stride=[2, 2], method=upsample_method)
      enc6 = tfcl.layer_norm(enc6, scope="layer_norm9")
      enc6 = tile_and_concat(enc6, latent, concat_latent=concat_latent)

    if self.hparams.model_options == "DNA":
      # Using largest hidden state for predicting untied conv kernels.
      enc7 = tfl.conv2d_transpose(
          enc6, self.hparams.dna_kernel_size**2, [1, 1], strides=(1, 1),
          padding="SAME", name="convt4", activation=None)
    else:
      # Using largest hidden state for predicting a new image layer.
      enc7 = tfl.conv2d_transpose(
          enc6, color_channels, [1, 1], strides=(1, 1),
          padding="SAME", name="convt4", activation=None)
      # This allows the network to also generate one image from scratch,
      # which is useful when regions of the image become unoccluded.
      transformed = [tf.nn.sigmoid(enc7)]

    if self.hparams.model_options == "CDNA":
      # cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
      cdna_input = tfcl.flatten(hidden5)
      transformed += common_video.cdna_transformation(
          input_image, cdna_input, num_masks, int(color_channels),
          self.hparams.dna_kernel_size, self.hparams.relu_shift)
    elif self.hparams.model_options == "DNA":
      # Only one mask is supported (more should be unnecessary).
      if num_masks != 1:
        raise ValueError("Only one mask is supported for DNA model.")
      transformed = [
          common_video.dna_transformation(
              input_image, enc7,
              self.hparams.dna_kernel_size, self.hparams.relu_shift)]

    masks = tfl.conv2d(
        enc6, filters=num_masks + 1, kernel_size=[1, 1],
        strides=(1, 1), name="convt7", padding="SAME")
    masks = masks[:, :img_height, :img_width, ...]
    masks = tf.reshape(
        tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
        [batch_size, int(img_height), int(img_width), num_masks + 1])
    mask_list = tf.split(
        axis=3, num_or_size_splits=num_masks + 1, value=masks)
    output = mask_list[0] * input_image
    for layer, mask in zip(transformed, mask_list[1:]):
      # TODO(mbz): take another look at this logic and verify.
      output = output[:, :img_height, :img_width, :]
      layer = layer[:, :img_height, :img_width, :]
      output += layer * mask

    # Map to softmax logits.
    if self.is_per_pixel_softmax:
      output = tf.layers.dense(
          output, self.hparams.problem.num_channels * 256, name="logits")

    mid_outputs = [enc0, enc1, enc4, enc5, enc6]
    return output, lstm_state, mid_outputs
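# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how this tower is typically unrolled over time:
# one LSTM state slot is threaded through every call, and each predicted
# frame can be fed back as the next input. The helper name
# `rollout_sketch`, the teacher-forced conditioning, and the state list
# length of 7 (matching the 7 entries of `lstm_size` above) are
# assumptions for illustration, not the model's actual training loop.
def rollout_sketch(self, images, rewards, actions, latent):
  lstm_state = [None] * 7  # assumed: one slot per conv-LSTM layer
  predictions = []
  for timestep, action in enumerate(actions):
    # Teacher forcing: condition on the ground-truth frame at each step.
    pred_image, lstm_state, _ = self.construct_predictive_tower(
        images[timestep], rewards[timestep], action, lstm_state, latent)
    predictions.append(pred_image)
  return tf.stack(predictions, axis=0)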