def create_acvideo_discriminator(clips, actions, ndf=64, norm_layer='instance', use_noise=False, noise_sigma=None):
    """Build an action-conditioned 3-D conv video discriminator.

    Consecutive frames are paired along channels, tiled with the actions,
    and passed through four strided conv3d layers. Returns the per-layer
    activations (batch/time transposed back), with the logits last.

    Note: assumes `clips` is time-major with values in [0, 1] — the inputs
    are rescaled to [-1, 1] before pairing; TODO confirm against callers.
    """
    normalize = ops.get_norm_layer(norm_layer)
    # pad only the two spatial dimensions; batch/time/channel are untouched
    spatial_pad = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0]]
    scaled_clips = clips * 2 - 1
    # channel-concatenate each frame with its successor, then tile actions
    # over the spatial dimensions
    net = tf.concat([scaled_clips[:-1], scaled_clips[1:]], axis=-1)
    net = tile_concat([net, actions[..., None, None, :]], axis=-1)
    net = tf_utils.transpose_batch_time(net)

    feature_maps = []
    # (channels, apply normalization): layer 1 skips the norm, 2 and 3 use it
    layer_specs = [(ndf, False), (ndf * 2, True), (ndf * 4, True)]
    for layer_num, (channels, with_norm) in enumerate(layer_specs, start=1):
        with tf.variable_scope("acvideo_layer_%d" % layer_num):
            net = noise(net, use_noise, noise_sigma)
            net = conv3d(tf.pad(net, spatial_pad), channels,
                         kernel_size=(3, 4, 4), strides=(1, 2, 2),
                         padding='VALID', use_bias=False)
            if with_norm:
                net = normalize(net)
            net = lrelu(net, 0.2)
            feature_maps.append(net)
    # final layer produces the single-channel logits (no noise/norm/lrelu)
    with tf.variable_scope("acvideo_layer_4"):
        logits = conv3d(tf.pad(net, spatial_pad), 1,
                        kernel_size=(3, 4, 4), strides=(1, 2, 2),
                        padding='VALID', use_bias=False)
        feature_maps.append(logits)
    return nest.map_structure(tf_utils.transpose_batch_time, feature_maps)
def encoder_fn(inputs, hparams=None):
    """Encode consecutive frame pairs (optionally action-conditioned).

    Each frame is concatenated with its successor along channels; if
    actions are provided they are tiled over the spatial dimensions and
    appended. Returns whatever `create_encoder` produces.
    """
    images = inputs['images']
    seq_len = hparams.sequence_length
    image_pairs = tf.concat([images[:seq_len - 1], images[1:seq_len]], axis=-1)
    if 'actions' in inputs:
        # insert two singleton spatial axes before the action dimension
        spatial_actions = inputs['actions'][..., None, None, :]
        image_pairs = tile_concat([image_pairs, spatial_actions], axis=-1)
    return create_encoder(image_pairs,
                          e_net=hparams.e_net,
                          use_e_rnn=hparams.use_e_rnn,
                          rnn=hparams.rnn,
                          nz=hparams.nz,
                          nef=hparams.nef,
                          n_layers=hparams.n_layers,
                          norm_layer=hparams.norm_layer)
def discriminator_fn(targets, inputs=None, hparams=None):
    """Run the discriminator on `targets`, optionally conditioned on inputs.

    Returns `(logits, outputs)` where `outputs` maps 'discrim_logits' and
    'discrim_feature%d' keys to the discriminator's intermediate features.
    """
    if inputs is None:
        discrim_args = (targets,)
    else:
        gen_inputs = None
        if hparams.d_conditional:
            if hparams.d_use_gt_inputs:
                # condition on the ground-truth frames (and actions) that
                # align with the target frames
                num_steps = targets.shape[0].value
                image_inputs = inputs['images'][hparams.context_frames - 1:][:num_steps]
                if 'actions' in inputs:
                    action_inputs = inputs['actions'][hparams.context_frames - 1:][:num_steps]
                    gen_inputs = ops.tile_concat(
                        [image_inputs, action_inputs[:, :, None, None, :]], axis=-1)
                else:
                    gen_inputs = image_inputs
            else:
                # exactly one of them should be true
                assert ('gen_inputs' in inputs) != ('gen_inputs_enc' in inputs)
                key = 'gen_inputs' if 'gen_inputs' in inputs else 'gen_inputs_enc'
                # the conditioning inputs should not receive gradients here
                gen_inputs = tf.stop_gradient(inputs[key])
        discrim_args = (targets, gen_inputs)
    features = create_discriminator(*discrim_args,
                                    d_net=hparams.d_net,
                                    n_layers=hparams.n_layers,
                                    ndf=hparams.ndf,
                                    norm_layer=hparams.norm_layer,
                                    downsample_layer=hparams.d_downsample_layer)
    features, logits = features[:-1], features[-1]
    outputs = {'discrim_logits': logits}
    outputs.update({'discrim_feature%d' % i: feature
                    for i, feature in enumerate(features)})
    return logits, outputs
def encoder_fn(inputs, hparams=None):
    """Multi-view variant: encode consecutive frame pairs from every view.

    For each view the frame/successor pair is collected; all pairs are
    concatenated along channels, optionally tiled with the actions, and
    passed to `create_encoder`.
    """
    pair_tensors = []
    for view in range(hparams.num_views):
        # view 0 uses the un-suffixed key, later views use 'images1', ...
        suffix = '' if view == 0 else '%d' % view
        view_images = inputs['images' + suffix]
        pair_tensors.append(view_images[:hparams.sequence_length - 1])
        pair_tensors.append(view_images[1:hparams.sequence_length])
    image_pairs = tf.concat(pair_tensors, axis=-1)
    if 'actions' in inputs:
        # tile the actions over the two spatial dimensions
        spatial_actions = inputs['actions'][..., None, None, :]
        image_pairs = tile_concat([image_pairs, spatial_actions], axis=-1)
    return create_encoder(image_pairs,
                          e_net=hparams.e_net,
                          use_e_rnn=hparams.use_e_rnn,
                          rnn=hparams.rnn,
                          nz=hparams.nz,
                          nef=hparams.nef,
                          n_layers=hparams.n_layers,
                          norm_layer=hparams.norm_layer)
def call(self, inputs, states):
    """One recurrent step: scheduled-sample a frame, then generate the next.

    Returns `((gen_image, gen_input), (time + 1, gen_image))`.
    """
    time, prev_gen_image = states
    # every entry of `time` holds the same step index; assert and reduce to
    # a scalar usable for indexing the schedule-sampling mask
    with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
        t = tf.to_int32(tf.identity(time[0]))
    # scheduled sampling: per example, choose the ground-truth frame or the
    # previously generated one
    image = tf.where(self.ground_truth[t], inputs['images'], prev_gen_image)
    if 'actions' in inputs:
        gen_input = ops.tile_concat(
            [image, inputs['actions'][:, None, None, :]], axis=-1)
    else:
        gen_input = image
    hp = self.hparams
    gen_image = create_generator(gen_input,
                                 output_nc=hp.output_nc,
                                 ngf=hp.ngf,
                                 norm_layer=hp.norm_layer,
                                 downsample_layer=hp.downsample_layer,
                                 upsample_layer=hp.upsample_layer)
    return (gen_image, gen_input), (time + 1, gen_image)
def call(self, inputs, states):
    """One recurrent step of the multi-view transformation-based generator.

    Scheduled-samples each view's frame, runs a shared or per-view
    encoder/decoder conv(-RNN) stack, predicts transformations (flow, DNA,
    or CDNA kernels), applies them to the recent frame history, and
    composites the candidates with predicted masks.

    Returns `(outputs, new_states)` dicts keyed per view via an integer
    suffix ('' for view 0, '1', '2', ... for the rest).
    """
    norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
    downsample_layer = ops.get_downsample_layer(
        self.hparams.downsample_layer)
    upsample_layer = ops.get_upsample_layer(self.hparams.upsample_layer)
    image_shape = inputs['images'].get_shape().as_list()
    # assumes per-step images are (batch, height, width, channels) — the
    # unpack below fixes the rank at 4
    batch_size, height, width, color_channels = image_shape

    # all entries of `time` hold the same step index; reduce to a scalar
    time = states['time']
    with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
        t = tf.to_int32(tf.identity(time[0]))

    if 'states' in inputs:
        # scheduled sampling for the low-dimensional state as well
        state = tf.where(self.ground_truth[t], inputs['states'],
                         states['gen_state'])

    # `state_action` feeds the state predictor; `state_action_z` (which may
    # additionally carry the latent z) feeds the convnet
    state_action = []
    state_action_z = []
    if 'actions' in inputs:
        state_action.append(inputs['actions'])
        state_action_z.append(inputs['actions'])
    if 'states' in inputs:
        state_action.append(state)
        # don't backpropagate the convnet through the state dynamics
        state_action_z.append(tf.stop_gradient(state))

    if 'zs' in inputs:
        if self.hparams.use_rnn_z:
            # pass the latent through a (non-conv) RNN first
            with tf.variable_scope('%s_z' % self.hparams.rnn):
                rnn_z, rnn_z_state = self._rnn_func(
                    inputs['zs'], states['rnn_z_state'], self.hparams.nz)
            state_action_z.append(rnn_z)
        else:
            state_action_z.append(inputs['zs'])

    def concat(tensors, axis):
        # concat that tolerates zero or one tensors (empty -> width-0 tensor)
        if len(tensors) == 0:
            return tf.zeros([batch_size, 0])
        elif len(tensors) == 1:
            return tensors[0]
        else:
            return tf.concat(tensors, axis=axis)

    state_action = concat(state_action, axis=-1)
    state_action_z = concat(state_action_z, axis=-1)

    # scheduled-sample the current frame (and pixel distribution) per view
    image_views = []
    first_image_views = []
    if 'pix_distribs' in inputs:
        pix_distrib_views = []
    for i in range(self.hparams.num_views):
        suffix = '%d' % i if i > 0 else ''
        image_view = tf.where(
            self.ground_truth[t], inputs['images' + suffix],
            states['gen_image' + suffix])  # schedule sampling (if any)
        image_views.append(image_view)
        first_image_views.append(self.inputs['images' + suffix][0])
        if 'pix_distribs' in inputs:
            pix_distrib_view = tf.where(self.ground_truth[t],
                                        inputs['pix_distribs' + suffix],
                                        states['gen_pix_distrib' + suffix])
            pix_distrib_views.append(pix_distrib_view)

    outputs = {}
    new_states = {}

    # ---- encoder/decoder stack (per view, or once if shared_views) ----
    all_layers = []
    for i in range(self.hparams.num_views):
        suffix = '%d' % i if i > 0 else ''
        conv_rnn_states = states['conv_rnn_states' + suffix]
        layers = []
        new_conv_rnn_states = []
        # NOTE(review): the inner enumerate shadows the outer view index `i`;
        # `suffix` is captured above so behavior is unaffected
        for i, (out_channels,
                use_conv_rnn) in enumerate(self.encoder_layer_specs):
            with tf.variable_scope('h%d' % i + suffix):
                if i == 0:
                    # all image views and the first image corresponding to this view only
                    h = tf.concat(image_views + first_image_views, axis=-1)
                    kernel_size = (5, 5)
                else:
                    h = layers[-1][-1]
                    kernel_size = (3, 3)
                if self.hparams.where_add == 'all' or (
                        self.hparams.where_add == 'input' and i == 0):
                    # inject state/action/z by tiling over spatial dims
                    h = tile_concat([h, state_action_z[:, None, None, :]],
                                    axis=-1)
                h = downsample_layer(h, out_channels,
                                     kernel_size=kernel_size, strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' %
                                       (self.hparams.conv_rnn, i) + suffix):
                    if self.hparams.where_add == 'all':
                        conv_rnn_h = tile_concat(
                            [h, state_action_z[:, None, None, :]], axis=-1)
                    else:
                        conv_rnn_h = h
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        conv_rnn_h, conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            # keep both conv and conv-RNN outputs; [-1] is the layer output
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))

        num_encoder_layers = len(layers)
        for i, (out_channels,
                use_conv_rnn) in enumerate(self.decoder_layer_specs):
            with tf.variable_scope('h%d' % len(layers) + suffix):
                if i == 0:
                    h = layers[-1][-1]
                else:
                    # skip connection from the mirrored encoder layer
                    h = tf.concat([
                        layers[-1][-1], layers[num_encoder_layers - i - 1][-1]
                    ], axis=-1)
                if self.hparams.where_add == 'all' or (
                        self.hparams.where_add == 'middle' and i == 0):
                    h = tile_concat([h, state_action_z[:, None, None, :]],
                                    axis=-1)
                h = upsample_layer(h, out_channels, kernel_size=(3, 3),
                                   strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope(
                        '%s_h%d' % (self.hparams.conv_rnn, len(layers)) +
                        suffix):
                    if self.hparams.where_add == 'all':
                        conv_rnn_h = tile_concat(
                            [h, state_action_z[:, None, None, :]], axis=-1)
                    else:
                        conv_rnn_h = h
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        conv_rnn_h, conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))
        assert len(new_conv_rnn_states) == len(conv_rnn_states)
        new_states['conv_rnn_states' + suffix] = new_conv_rnn_states
        all_layers.append(layers)
        if self.hparams.shared_views:
            # a single stack serves all views
            break

    # ---- per-view transformation prediction and compositing ----
    for i in range(self.hparams.num_views):
        suffix = '%d' % i if i > 0 else ''
        if self.hparams.shared_views:
            layers, = all_layers
        else:
            layers = all_layers[i]
        image = image_views[i]
        # roll the frame-history buffer forward by one step
        last_images = states['last_images' + suffix][1:] + [image]
        if 'pix_distribs' in inputs:
            pix_distrib = pix_distrib_views[i]
            last_pix_distribs = states['last_pix_distribs' +
                                       suffix][1:] + [pix_distrib]

        if self.hparams.last_frames and self.hparams.num_transformed_images:
            if self.hparams.transformation == 'flow':
                with tf.variable_scope('h%d_flow' % len(layers) + suffix):
                    h_flow = conv2d(layers[-1][-1], self.hparams.ngf,
                                    kernel_size=(3, 3), strides=(1, 1))
                    h_flow = norm_layer(h_flow)
                    h_flow = tf.nn.relu(h_flow)
                with tf.variable_scope('flows' + suffix):
                    flows = conv2d(
                        h_flow, 2 * self.hparams.last_frames *
                        self.hparams.num_transformed_images,
                        kernel_size=(3, 3), strides=(1, 1))
                    flows = tf.reshape(flows, [
                        batch_size, height, width, 2,
                        self.hparams.last_frames *
                        self.hparams.num_transformed_images
                    ])
            else:
                assert len(self.hparams.kernel_size) == 2
                kernel_shape = list(self.hparams.kernel_size) + [
                    self.hparams.last_frames *
                    self.hparams.num_transformed_images
                ]
                if self.hparams.transformation == 'dna':
                    with tf.variable_scope('h%d_dna_kernel' % len(layers) +
                                           suffix):
                        h_dna_kernel = conv2d(layers[-1][-1],
                                              self.hparams.ngf,
                                              kernel_size=(3, 3),
                                              strides=(1, 1))
                        h_dna_kernel = norm_layer(h_dna_kernel)
                        h_dna_kernel = tf.nn.relu(h_dna_kernel)
                    # Using largest hidden state for predicting untied conv kernels.
                    with tf.variable_scope('dna_kernels' + suffix):
                        kernels = conv2d(h_dna_kernel, np.prod(kernel_shape),
                                         kernel_size=(3, 3), strides=(1, 1))
                        kernels = tf.reshape(kernels,
                                             [batch_size, height, width] +
                                             kernel_shape)
                        # bias toward the identity (copy-pixel) kernel
                        kernels = kernels + identity_kernel(
                            self.hparams.kernel_size)[None, None,
                                                      None, :, :, None]
                    kernel_spatial_axes = [3, 4]
                elif self.hparams.transformation == 'cdna':
                    with tf.variable_scope('cdna_kernels' + suffix):
                        smallest_layer = layers[num_encoder_layers - 1][-1]
                        kernels = dense(flatten(smallest_layer),
                                        np.prod(kernel_shape))
                        kernels = tf.reshape(kernels,
                                             [batch_size] + kernel_shape)
                        kernels = kernels + identity_kernel(
                            self.hparams.kernel_size)[None, :, :, None]
                    kernel_spatial_axes = [1, 2]
                else:
                    raise ValueError('Invalid transformation %s' %
                                     self.hparams.transformation)
            if self.hparams.transformation != 'flow':
                with tf.name_scope('kernel_normalization' + suffix):
                    # shift-then-normalize so kernels are positive and sum to 1
                    kernels = tf.nn.relu(kernels - RELU_SHIFT) + RELU_SHIFT
                    kernels /= tf.reduce_sum(kernels,
                                             axis=kernel_spatial_axes,
                                             keepdims=True)

        if self.hparams.generate_scratch_image:
            with tf.variable_scope('h%d_scratch' % len(layers) + suffix):
                h_scratch = conv2d(layers[-1][-1], self.hparams.ngf,
                                   kernel_size=(3, 3), strides=(1, 1))
                h_scratch = norm_layer(h_scratch)
                h_scratch = tf.nn.relu(h_scratch)
            # Using largest hidden state for predicting a new image layer.
            # This allows the network to also generate one image from scratch,
            # which is useful when regions of the image become unoccluded.
            with tf.variable_scope('scratch_image' + suffix):
                scratch_image = conv2d(h_scratch, color_channels,
                                       kernel_size=(3, 3), strides=(1, 1))
                scratch_image = tf.nn.sigmoid(scratch_image)

        # gather all candidate images the mask network will composite
        with tf.name_scope('transformed_images' + suffix):
            transformed_images = []
            if self.hparams.last_frames and self.hparams.num_transformed_images:
                if self.hparams.transformation == 'flow':
                    transformed_images.extend(
                        apply_flows(last_images, flows))
                else:
                    transformed_images.extend(
                        apply_kernels(last_images, kernels,
                                      self.hparams.dilation_rate))
            if self.hparams.prev_image_background:
                transformed_images.append(image)
            if self.hparams.first_image_background and not self.hparams.context_images_background:
                transformed_images.append(self.inputs['images' + suffix][0])
            if self.hparams.context_images_background:
                transformed_images.extend(
                    tf.unstack(
                        self.inputs['images' +
                                    suffix][:self.hparams.context_frames]))
            if self.hparams.generate_scratch_image:
                transformed_images.append(scratch_image)

        if 'pix_distribs' in inputs:
            # mirror the candidate list for the pixel distributions
            with tf.name_scope('transformed_pix_distribs' + suffix):
                transformed_pix_distribs = []
                if self.hparams.last_frames and self.hparams.num_transformed_images:
                    if self.hparams.transformation == 'flow':
                        transformed_pix_distribs.extend(
                            apply_flows(last_pix_distribs, flows))
                    else:
                        transformed_pix_distribs.extend(
                            apply_kernels(last_pix_distribs, kernels,
                                          self.hparams.dilation_rate))
                if self.hparams.prev_image_background:
                    transformed_pix_distribs.append(pix_distrib)
                if self.hparams.first_image_background and not self.hparams.context_images_background:
                    transformed_pix_distribs.append(
                        self.inputs['pix_distribs' + suffix][0])
                if self.hparams.context_images_background:
                    transformed_pix_distribs.extend(
                        tf.unstack(self.inputs['pix_distribs' + suffix]
                                   [:self.hparams.context_frames]))
                if self.hparams.generate_scratch_image:
                    transformed_pix_distribs.append(pix_distrib)

        with tf.name_scope('masks' + suffix):
            if len(transformed_images) > 1:
                with tf.variable_scope('h%d_masks' % len(layers) + suffix):
                    h_masks = conv2d(layers[-1][-1], self.hparams.ngf,
                                     kernel_size=(3, 3), strides=(1, 1))
                    h_masks = norm_layer(h_masks)
                    h_masks = tf.nn.relu(h_masks)
                with tf.variable_scope('masks' + suffix):
                    if self.hparams.dependent_mask:
                        # let the masks see the candidates they will weight
                        h_masks = tf.concat([h_masks] + transformed_images,
                                            axis=-1)
                    masks = conv2d(h_masks, len(transformed_images),
                                   kernel_size=(3, 3), strides=(1, 1))
                    masks = tf.nn.softmax(masks)
                    masks = tf.split(masks, len(transformed_images), axis=-1)
            elif len(transformed_images) == 1:
                masks = [tf.ones([batch_size, height, width, 1])]
            else:
                raise ValueError(
                    "Either one of the following should be true: "
                    "last_frames and num_transformed_images, first_image_background, "
                    "prev_image_background, generate_scratch_image")

        # composite the candidates with the softmax masks
        with tf.name_scope('gen_images' + suffix):
            assert len(transformed_images) == len(masks)
            gen_image = tf.add_n([
                transformed_image * mask for transformed_image, mask in zip(
                    transformed_images, masks)
            ])

        if 'pix_distribs' in inputs:
            with tf.name_scope('gen_pix_distribs' + suffix):
                assert len(transformed_pix_distribs) == len(masks)
                gen_pix_distrib = tf.add_n([
                    transformed_pix_distrib * mask
                    for transformed_pix_distrib, mask in zip(
                        transformed_pix_distribs, masks)
                ])
                if self.hparams.renormalize_pixdistrib:
                    gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib,
                                                     axis=(1, 2),
                                                     keepdims=True)

        outputs['gen_images' + suffix] = gen_image
        outputs['transformed_images' + suffix] = tf.stack(
            transformed_images, axis=-1)
        outputs['masks' + suffix] = tf.stack(masks, axis=-1)
        if 'pix_distribs' in inputs:
            outputs['gen_pix_distribs' + suffix] = gen_pix_distrib
            outputs['transformed_pix_distribs' + suffix] = tf.stack(
                transformed_pix_distribs, axis=-1)
        if self.hparams.transformation == 'flow':
            outputs['gen_flows' + suffix] = flows
            # flow_to_rgb expects the 2-vector last; transpose around it
            flows_transposed = tf.transpose(flows, [0, 1, 2, 4, 3])
            flows_rgb_transposed = tf_utils.flow_to_rgb(flows_transposed)
            flows_rgb = tf.transpose(flows_rgb_transposed, [0, 1, 2, 4, 3])
            outputs['gen_flows_rgb' + suffix] = flows_rgb

        new_states['gen_image' + suffix] = gen_image
        new_states['last_images' + suffix] = last_images
        if 'pix_distribs' in inputs:
            new_states['gen_pix_distrib' + suffix] = gen_pix_distrib
            new_states['last_pix_distribs' + suffix] = last_pix_distribs

    if 'states' in inputs:
        with tf.name_scope('gen_states'):
            with tf.variable_scope('state_pred'):
                gen_state = dense(state_action,
                                  inputs['states'].shape[-1].value)
    if 'states' in inputs:
        outputs['gen_states'] = gen_state

    new_states['time'] = time + 1
    if 'zs' in inputs and self.hparams.use_rnn_z:
        new_states['rnn_z_state'] = rnn_z_state
    if 'states' in inputs:
        new_states['gen_state'] = gen_state
    return outputs, new_states
def call(self, inputs, states):
    """One recurrent step of the flow-accumulating generator cell.

    Scheduled-samples the current frame, rolls the image/flow history
    buffers forward, runs an encoder/decoder conv(-RNN) stack, predicts one
    flow field per remembered frame, composes it with the accumulated flows,
    warps the base images, and composites them with predicted masks.

    Args:
        inputs: dict of per-step tensors ('images', optionally 'actions',
            'states', 'zs', 'pix_distribs').
        states: dict of recurrent state tensors/buffers from the previous
            step.

    Returns:
        (outputs, new_states) dicts.
    """
    norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
    image_shape = inputs['images'].get_shape().as_list()
    # assumes per-step images are (batch, height, width, channels)
    batch_size, height, width, color_channels = image_shape
    conv_rnn_states = states['conv_rnn_states']

    # all entries of `time` hold the same step index; reduce to a scalar
    time = states['time']
    with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
        t = tf.to_int32(tf.identity(time[0]))

    # roll the history buffers forward by one step
    last_gt_images = states['last_gt_images'][1:] + [inputs['images']]
    last_pred_flows = states['last_pred_flows'][1:] + [
        tf.zeros_like(states['last_pred_flows'][-1])
    ]
    # scheduled sampling: ground-truth frame vs previously generated one
    image = tf.where(self.ground_truth[t], inputs['images'],
                     states['gen_image'])
    last_images = states['last_images'][1:] + [image]
    # when sampling ground truth, reset the warp bases to the gt images and
    # the accumulated flows to the (zero-initialized) predicted flows
    last_base_images = [
        tf.where(self.ground_truth[t], last_gt_image, last_base_image)
        for last_gt_image, last_base_image in zip(
            last_gt_images, states['last_base_images'])
    ]
    last_base_flows = [
        tf.where(self.ground_truth[t], last_pred_flow, last_base_flow)
        for last_pred_flow, last_base_flow in zip(
            last_pred_flows, states['last_base_flows'])
    ]
    if 'pix_distribs' in inputs:
        last_gt_pix_distribs = states['last_gt_pix_distribs'][1:] + [
            inputs['pix_distribs']
        ]
        last_base_pix_distribs = [
            tf.where(self.ground_truth[t], last_gt_pix_distrib,
                     last_base_pix_distrib)
            for last_gt_pix_distrib, last_base_pix_distrib in zip(
                last_gt_pix_distribs, states['last_base_pix_distribs'])
        ]
    if 'states' in inputs:
        state = tf.where(self.ground_truth[t], inputs['states'],
                         states['gen_state'])

    # `state_action` feeds the state predictor; `state_action_z` feeds the
    # convnet (it may also carry the latent z)
    state_action = []
    state_action_z = []
    if 'actions' in inputs:
        state_action.append(inputs['actions'])
        state_action_z.append(inputs['actions'])
    if 'states' in inputs:
        state_action.append(state)
        # don't backpropagate the convnet through the state dynamics
        state_action_z.append(tf.stop_gradient(state))
    if 'zs' in inputs:
        if self.hparams.use_rnn_z:
            with tf.variable_scope('%s_z' % self.hparams.rnn):
                rnn_z, rnn_z_state = self._rnn_func(
                    inputs['zs'], states['rnn_z_state'], self.hparams.nz)
            state_action_z.append(rnn_z)
        else:
            state_action_z.append(inputs['zs'])

    def concat(tensors, axis):
        # concat that tolerates zero or one tensors (empty -> width-0 tensor)
        if len(tensors) == 0:
            return tf.zeros([batch_size, 0])
        elif len(tensors) == 1:
            return tensors[0]
        else:
            return tf.concat(tensors, axis=axis)

    state_action = concat(state_action, axis=-1)
    state_action_z = concat(state_action_z, axis=-1)

    if 'actions' in inputs:
        gen_input = tile_concat(
            last_images + [inputs['actions'][:, None, None, :]], axis=-1)
    else:
        # BUG FIX: tf.concat requires an explicit axis (the original called
        # tf.concat(last_images), which raises at graph-construction time);
        # concatenate the frame history along channels, mirroring the
        # action-conditioned branch above.
        gen_input = tf.concat(last_images, axis=-1)

    # ---- encoder ----
    layers = []
    new_conv_rnn_states = []
    for i, (out_channels,
            use_conv_rnn) in enumerate(self.encoder_layer_specs):
        with tf.variable_scope('h%d' % i):
            if i == 0:
                h = tf.concat(last_images, axis=-1)
                kernel_size = (5, 5)
            else:
                h = layers[-1][-1]
                kernel_size = (3, 3)
            # inject state/action/z by tiling over the spatial dims
            h = conv_pool2d(tile_concat(
                [h, state_action_z[:, None, None, :]], axis=-1),
                            out_channels,
                            kernel_size=kernel_size,
                            strides=(2, 2))
            h = norm_layer(h)
            h = tf.nn.relu(h)
        if use_conv_rnn:
            conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
            with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, i)):
                conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                    tile_concat([h, state_action_z[:, None, None, :]],
                                axis=-1), conv_rnn_state, out_channels)
            new_conv_rnn_states.append(conv_rnn_state)
        # keep both conv and conv-RNN outputs; [-1] is the layer output
        layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))

    # ---- decoder with skip connections ----
    num_encoder_layers = len(layers)
    for i, (out_channels,
            use_conv_rnn) in enumerate(self.decoder_layer_specs):
        with tf.variable_scope('h%d' % len(layers)):
            if i == 0:
                h = layers[-1][-1]
            else:
                # skip connection from the mirrored encoder layer
                h = tf.concat([
                    layers[-1][-1], layers[num_encoder_layers - i - 1][-1]
                ], axis=-1)
            h = upsample_conv2d(tile_concat(
                [h, state_action_z[:, None, None, :]], axis=-1),
                                out_channels,
                                kernel_size=(3, 3),
                                strides=(2, 2))
            h = norm_layer(h)
            h = tf.nn.relu(h)
        if use_conv_rnn:
            conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
            with tf.variable_scope('%s_h%d' %
                                   (self.hparams.conv_rnn, len(layers))):
                conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                    tile_concat([h, state_action_z[:, None, None, :]],
                                axis=-1), conv_rnn_state, out_channels)
            new_conv_rnn_states.append(conv_rnn_state)
        layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))
    assert len(new_conv_rnn_states) == len(conv_rnn_states)

    # ---- predict one flow field per remembered frame ----
    with tf.variable_scope('h%d_flow' % len(layers)):
        h_flow = conv2d(layers[-1][-1], self.hparams.ngf,
                        kernel_size=(3, 3), strides=(1, 1))
        h_flow = norm_layer(h_flow)
        h_flow = tf.nn.relu(h_flow)
    with tf.variable_scope('flows'):
        flows = conv2d(h_flow, 2 * self.hparams.last_frames,
                       kernel_size=(3, 3), strides=(1, 1))
        flows = tf.reshape(
            flows, [batch_size, height, width, 2, self.hparams.last_frames])

    with tf.name_scope('transformed_images'):
        transformed_images = []
        # compose each newly predicted flow with the accumulated flow by
        # warping the old flow through the new one and adding
        last_pred_flows = [
            flow + flow_ops.image_warp(last_pred_flow, flow)
            for last_pred_flow, flow in zip(last_pred_flows,
                                            tf.unstack(flows, axis=-1))
        ]
        last_base_flows = [
            flow + flow_ops.image_warp(last_base_flow, flow)
            for last_base_flow, flow in zip(last_base_flows,
                                            tf.unstack(flows, axis=-1))
        ]
        for last_base_image, last_base_flow in zip(last_base_images,
                                                   last_base_flows):
            transformed_images.append(
                flow_ops.image_warp(last_base_image, last_base_flow))

    if 'pix_distribs' in inputs:
        with tf.name_scope('transformed_pix_distribs'):
            transformed_pix_distribs = []
            for last_base_pix_distrib, last_base_flow in zip(
                    last_base_pix_distribs, last_base_flows):
                transformed_pix_distribs.append(
                    flow_ops.image_warp(last_base_pix_distrib,
                                        last_base_flow))

    with tf.name_scope('masks'):
        if len(transformed_images) > 1:
            with tf.variable_scope('h%d_masks' % len(layers)):
                h_masks = conv2d(layers[-1][-1], self.hparams.ngf,
                                 kernel_size=(3, 3), strides=(1, 1))
                h_masks = norm_layer(h_masks)
                h_masks = tf.nn.relu(h_masks)
            with tf.variable_scope('masks'):
                if self.hparams.dependent_mask:
                    # let the masks see the candidates they will weight
                    h_masks = tf.concat([h_masks] + transformed_images,
                                        axis=-1)
                masks = conv2d(h_masks, len(transformed_images),
                               kernel_size=(3, 3), strides=(1, 1))
                masks = tf.nn.softmax(masks)
                masks = tf.split(masks, len(transformed_images), axis=-1)
        elif len(transformed_images) == 1:
            masks = [tf.ones([batch_size, height, width, 1])]
        else:
            raise ValueError(
                "Either one of the following should be true: "
                "last_frames and num_transformed_images, first_image_background, "
                "prev_image_background, generate_scratch_image")

    # composite the warped candidates with the softmax masks
    with tf.name_scope('gen_images'):
        assert len(transformed_images) == len(masks)
        gen_image = tf.add_n([
            transformed_image * mask
            for transformed_image, mask in zip(transformed_images, masks)
        ])

    if 'pix_distribs' in inputs:
        with tf.name_scope('gen_pix_distribs'):
            assert len(transformed_pix_distribs) == len(masks)
            gen_pix_distrib = tf.add_n([
                transformed_pix_distrib * mask
                for transformed_pix_distrib, mask in zip(
                    transformed_pix_distribs, masks)
            ])
            # TODO: is this needed?
            # gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib, axis=(1, 2), keepdims=True)

    if 'states' in inputs:
        with tf.name_scope('gen_states'):
            with tf.variable_scope('state_pred'):
                gen_state = dense(state_action,
                                  inputs['states'].shape[-1].value)

    outputs = {
        'gen_images': gen_image,
        'gen_inputs': gen_input,
        'transformed_images': tf.stack(transformed_images, axis=-1),
        'masks': tf.stack(masks, axis=-1),
        'gen_flow': flows
    }
    if 'pix_distribs' in inputs:
        outputs['gen_pix_distribs'] = gen_pix_distrib
        outputs['transformed_pix_distribs'] = tf.stack(
            transformed_pix_distribs, axis=-1)
    if 'states' in inputs:
        outputs['gen_states'] = gen_state

    new_states = {
        'time': time + 1,
        'last_gt_images': last_gt_images,
        'last_pred_flows': last_pred_flows,
        'gen_image': gen_image,
        'last_images': last_images,
        'last_base_images': last_base_images,
        'last_base_flows': last_base_flows,
        'conv_rnn_states': new_conv_rnn_states
    }
    if 'zs' in inputs and self.hparams.use_rnn_z:
        new_states['rnn_z_state'] = rnn_z_state
    if 'pix_distribs' in inputs:
        new_states['last_gt_pix_distribs'] = last_gt_pix_distribs
        new_states['last_base_pix_distribs'] = last_base_pix_distribs
    if 'states' in inputs:
        new_states['gen_state'] = gen_state
    return outputs, new_states
def generator_fn(inputs, mode, hparams=None):
    """Flow-interpolation generator: predict future frames as masked blends
    of the first two context frames warped by two predicted flow fields.

    Builds a slim-based U-Net over the stacked context frames (conditioned
    on states/actions), predicts `5 * extrap_length` channels (two 2-D
    flows plus one mask per extrapolated step), and bilinearly samples the
    first two frames at the flowed coordinates.

    Returns `(gen_images, outputs)` where outputs also carries the flows,
    masks, and (if present) predicted states.
    """
    # pad/slice every input to a fixed time length
    inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
              for name, input in inputs.items()}
    images = inputs['images']
    # stack the context frames along channels (assumes time-major images —
    # sliced/unstacked along axis 0; TODO confirm against callers)
    input_images = tf.concat(tf.unstack(images[:hparams.context_frames], axis=0), axis=-1)

    # autoregressively predict the low-dimensional state for each future step
    gen_states = []
    for t in range(hparams.context_frames, hparams.sequence_length):
        state_action = []
        if 'actions' in inputs:
            state_action.append(inputs['actions'][t - 1])
        if 'states' in inputs:
            # feed back the previous prediction once past the first step
            state_action.append(gen_states[-1] if gen_states else inputs['states'][t - 1])
        state_action = tf.concat(state_action, axis=-1)
        with tf.name_scope('gen_states'):
            # NOTE(review): this reads inputs['states'] unconditionally, so
            # the function appears to require 'states' — confirm with callers
            with tf.variable_scope('state_pred%d' % t):
                gen_state = dense(state_action, inputs['states'].shape[-1])
        gen_states.append(gen_state)

    # flat conditioning vector from all actions, context states, and the
    # predicted future states
    states_actions = []
    if 'actions' in inputs:
        states_actions += tf.unstack(inputs['actions'][:hparams.sequence_length - 1], axis=0)
    if 'states' in inputs:
        states_actions += tf.unstack(inputs['states'][:hparams.context_frames], axis=0)
        states_actions += gen_states
    if states_actions:
        states_actions = tf.concat(states_actions, axis=-1)
        # don't backpropagate the convnet through the state dynamics
        states_actions = tf.stop_gradient(states_actions)
    else:
        # no conditioning available: width-0 placeholder per batch element
        states_actions = tf.zeros([images.shape[1], 0])

    with slim.arg_scope([slim.conv2d],
                        activation_fn=tf.nn.relu,
                        weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                        weights_regularizer=slim.l2_regularizer(0.0001)):
        batch_norm_params = {
            'decay': 0.9997,
            'epsilon': 0.001,
            'is_training': mode == 'train',
        }
        with slim.arg_scope([slim.batch_norm], is_training=mode == 'train',
                            updates_collections=None):
            with slim.arg_scope([slim.conv2d],
                                normalizer_fn=slim.batch_norm,
                                normalizer_params=batch_norm_params):
                # ---- encoder: conv + max-pool, remembering sizes for the
                # bilinear upsampling on the way back up ----
                h0 = slim.conv2d(input_images, 64, [5, 5], stride=1, scope='conv1')
                size0 = tf.shape(input_images)[-3:-1]
                h1 = slim.max_pool2d(h0, [2, 2], scope='pool1')
                h1 = slim.conv2d(h1, 128, [5, 5], stride=1, scope='conv2')
                size1 = tf.shape(h1)[-3:-1]
                h2 = slim.max_pool2d(h1, [2, 2], scope='pool2')
                h2 = slim.conv2d(h2, 256, [3, 3], stride=1, scope='conv3')
                size2 = tf.shape(h2)[-3:-1]
                h3 = slim.max_pool2d(h2, [2, 2], scope='pool3')
                # inject the state/action conditioning at the bottleneck
                h3 = tile_concat([h3, states_actions[:, None, None, :]], axis=-1)
                h3 = slim.conv2d(h3, 256, [3, 3], stride=1, scope='conv4')
                # ---- decoder: bilinear upsample + skip connections ----
                h4 = tf.image.resize_bilinear(h3, size2)
                h4 = tf.concat([h4, h2], axis=-1)
                h4 = slim.conv2d(h4, 256, [3, 3], stride=1, scope='conv5')
                h5 = tf.image.resize_bilinear(h4, size1)
                h5 = tf.concat([h5, h1], axis=-1)
                h5 = slim.conv2d(h5, 128, [5, 5], stride=1, scope='conv6')
                h6 = tf.image.resize_bilinear(h5, size0)
                h6 = tf.concat([h6, h0], axis=-1)
                h6 = slim.conv2d(h6, 64, [5, 5], stride=1, scope='conv7')
                extrap_length = hparams.sequence_length - hparams.context_frames
                # 5 channels per future step: 2 + 2 flow components + 1 mask;
                # tanh keeps flows in [-1, 1] (normalized image coordinates)
                flows_masks = slim.conv2d(h6, 5 * extrap_length, [5, 5], stride=1,
                                          activation_fn=tf.tanh,
                                          normalizer_fn=None, scope='conv8')
    flows_masks = tf.split(flows_masks, extrap_length, axis=-1)

    gen_images = []
    gen_flows_1 = []
    gen_flows_2 = []
    masks = []
    for flows_mask in flows_masks:
        flow_1, flow_2, mask = tf.split(flows_mask, [2, 2, 1], axis=-1)
        gen_flows_1.append(flow_1)
        gen_flows_2.append(flow_2)
        # map tanh output from [-1, 1] to a [0, 1] blending mask
        mask = 0.5 * (1.0 + mask)
        masks.append(mask)
        # sampling grid in normalized [-1, 1] coordinates
        linspace_x = tf.linspace(-1.0, 1.0, size0[1])
        linspace_x.set_shape(input_images.shape[-2])
        linspace_y = tf.linspace(-1.0, 1.0, size0[0])
        linspace_y.set_shape(input_images.shape[-3])
        grid_x, grid_y = tf.meshgrid(linspace_x, linspace_y)
        coor_x_1 = grid_x[None, :, :] + flow_1[:, :, :, 0]
        coor_y_1 = grid_y[None, :, :] + flow_1[:, :, :, 1]
        coor_x_2 = grid_x[None, :, :] + flow_2[:, :, :, 0]
        coor_y_2 = grid_y[None, :, :] + flow_2[:, :, :, 1]
        # warp the first two context frames and blend them with the mask
        output_1 = bilinear_interp(images[0], coor_x_1, coor_y_1, 'interpolate')
        output_2 = bilinear_interp(images[1], coor_x_2, coor_y_2, 'interpolate')
        gen_image = mask * output_1 + (1.0 - mask) * output_2
        gen_images.append(gen_image)

    gen_images = tf.stack(gen_images, axis=0)
    gen_flows_1 = tf.stack(gen_flows_1, axis=0)
    gen_flows_2 = tf.stack(gen_flows_2, axis=0)
    masks = tf.stack(masks, axis=0)
    outputs = {
        'gen_images': gen_images,
        'gen_flows_1': gen_flows_1,
        'gen_flows_2': gen_flows_2,
        'masks': masks,
    }
    if 'states' in inputs:
        gen_states = tf.stack(gen_states, axis=0)
        outputs['gen_states'] = gen_states
    return gen_images, outputs
def call(self, inputs, states):
    """One recurrent step of the feature-space prediction cell.

    Scheduled-samples the current feature map, runs an encoder/decoder
    conv(-RNN) stack, and produces the next feature map either directly or
    by applying a predicted per-channel transformation ('flow', 'local'
    untied kernels, or 'conv' tied kernels).

    Returns `(outputs, new_states)` dicts.
    """
    norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
    feature_shape = inputs['features'].get_shape().as_list()
    # assumes per-step features are (batch, height, width, channels)
    batch_size, height, width, feature_channels = feature_shape
    conv_rnn_states = states['conv_rnn_states']

    # all entries of `time` hold the same step index; reduce to a scalar
    time = states['time']
    with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
        t = tf.to_int32(tf.identity(time[0]))
    feature = tf.where(self.ground_truth[t], inputs['features'],
                       states['gen_feature'])  # schedule sampling (if any)
    if 'states' in inputs:
        state = tf.where(self.ground_truth[t], inputs['states'],
                         states['gen_state'])

    # `state_action` feeds the state predictor; `state_action_z` (which may
    # additionally carry the latent z) feeds the convnet
    state_action = []
    state_action_z = []
    if 'actions' in inputs:
        state_action.append(inputs['actions'])
        state_action_z.append(inputs['actions'])
    if 'states' in inputs:
        state_action.append(state)
        # don't backpropagate the convnet through the state dynamics
        state_action_z.append(tf.stop_gradient(state))
    if 'zs' in inputs:
        if self.hparams.use_rnn_z:
            with tf.variable_scope('%s_z' % self.hparams.rnn):
                rnn_z, rnn_z_state = self._rnn_func(
                    inputs['zs'], states['rnn_z_state'], self.hparams.nz)
            state_action_z.append(rnn_z)
        else:
            state_action_z.append(inputs['zs'])

    def concat(tensors, axis):
        # concat that tolerates zero or one tensors (empty -> width-0 tensor)
        if len(tensors) == 0:
            return tf.zeros([batch_size, 0])
        elif len(tensors) == 1:
            return tensors[0]
        else:
            return tf.concat(tensors, axis=axis)

    state_action = concat(state_action, axis=-1)
    state_action_z = concat(state_action_z, axis=-1)

    if 'actions' in inputs:
        gen_input = tile_concat(
            [feature, inputs['actions'][:, None, None, :]], axis=-1)
    else:
        gen_input = feature

    # ---- encoder ----
    layers = []
    new_conv_rnn_states = []
    for i, (out_channels,
            use_conv_rnn) in enumerate(self.encoder_layer_specs):
        with tf.variable_scope('h%d' % i):
            if i == 0:
                # h = tf.concat([feature, self.inputs['features'][0]], axis=-1)
                # TODO: use first feature?
                h = feature
            else:
                h = layers[-1][-1]
            # inject state/action/z by tiling over the spatial dims
            h = conv_pool2d(tile_concat(
                [h, state_action_z[:, None, None, :]], axis=-1),
                            out_channels, kernel_size=(3, 3), strides=(2, 2))
            h = norm_layer(h)
            h = tf.nn.relu(h)
        if use_conv_rnn:
            conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
            with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, i)):
                conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                    tile_concat([h, state_action_z[:, None, None, :]],
                                axis=-1), conv_rnn_state, out_channels)
            new_conv_rnn_states.append(conv_rnn_state)
        # keep both conv and conv-RNN outputs; [-1] is the layer output
        layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))

    # ---- decoder with skip connections ----
    num_encoder_layers = len(layers)
    for i, (out_channels,
            use_conv_rnn) in enumerate(self.decoder_layer_specs):
        with tf.variable_scope('h%d' % len(layers)):
            if i == 0:
                h = layers[-1][-1]
            else:
                # skip connection from the mirrored encoder layer
                h = tf.concat([
                    layers[-1][-1], layers[num_encoder_layers - i - 1][-1]
                ], axis=-1)
            h = upsample_conv2d(tile_concat(
                [h, state_action_z[:, None, None, :]], axis=-1),
                                out_channels, kernel_size=(3, 3),
                                strides=(2, 2))
            h = norm_layer(h)
            h = tf.nn.relu(h)
        if use_conv_rnn:
            conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
            with tf.variable_scope('%s_h%d' %
                                   (self.hparams.conv_rnn, len(layers))):
                conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                    tile_concat([h, state_action_z[:, None, None, :]],
                                axis=-1), conv_rnn_state, out_channels)
            new_conv_rnn_states.append(conv_rnn_state)
        layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))
    assert len(new_conv_rnn_states) == len(conv_rnn_states)

    if self.hparams.transformation == 'direct':
        # predict the next feature map directly from the decoder output
        with tf.variable_scope('h%d_direct' % len(layers)):
            h_direct = conv2d(layers[-1][-1], self.hparams.ngf,
                              kernel_size=(3, 3), strides=(1, 1))
            h_direct = norm_layer(h_direct)
            h_direct = tf.nn.relu(h_direct)
        with tf.variable_scope('direct'):
            gen_feature = conv2d(h_direct, feature_channels,
                                 kernel_size=(3, 3), strides=(1, 1))
    else:
        if self.hparams.transformation == 'flow':
            # one 2-D flow field per feature channel
            with tf.variable_scope('h%d_flow' % len(layers)):
                h_flow = conv2d(layers[-1][-1], self.hparams.ngf,
                                kernel_size=(3, 3), strides=(1, 1))
                h_flow = norm_layer(h_flow)
                h_flow = tf.nn.relu(h_flow)
            with tf.variable_scope('flows'):
                flows = conv2d(h_flow, 2 * feature_channels,
                               kernel_size=(3, 3), strides=(1, 1))
                flows = tf.reshape(
                    flows, [batch_size, height, width, 2, feature_channels])
            transformations = flows
        else:
            assert len(self.hparams.kernel_size) == 2
            # one conv kernel per feature channel
            kernel_shape = list(
                self.hparams.kernel_size) + [feature_channels]
            if self.hparams.transformation == 'local':
                with tf.variable_scope('h%d_local_kernel' % len(layers)):
                    h_local_kernel = conv2d(layers[-1][-1],
                                            self.hparams.ngf,
                                            kernel_size=(3, 3),
                                            strides=(1, 1))
                    h_local_kernel = norm_layer(h_local_kernel)
                    h_local_kernel = tf.nn.relu(h_local_kernel)
                # Using largest hidden state for predicting untied conv kernels.
                with tf.variable_scope('local_kernels'):
                    kernels = conv2d(h_local_kernel, np.prod(kernel_shape),
                                     kernel_size=(3, 3), strides=(1, 1))
                    kernels = tf.reshape(kernels,
                                         [batch_size, height, width] +
                                         kernel_shape)
                    # bias toward the identity (copy) kernel
                    kernels = kernels + identity_kernel(
                        self.hparams.kernel_size)[None, None,
                                                  None, :, :, None]
            elif self.hparams.transformation == 'conv':
                with tf.variable_scope('conv_kernels'):
                    smallest_layer = layers[num_encoder_layers - 1][-1]
                    kernels = dense(flatten(smallest_layer),
                                    np.prod(kernel_shape))
                    kernels = tf.reshape(kernels,
                                         [batch_size] + kernel_shape)
                    kernels = kernels + identity_kernel(
                        self.hparams.kernel_size)[None, :, :, None]
            else:
                raise ValueError('Invalid transformation %s' %
                                 self.hparams.transformation)
            transformations = kernels

        with tf.name_scope('gen_features'):
            # apply the per-channel transformation channel by channel via
            # map_fn over the channel-stacked feature/transformation pairs
            if self.hparams.transformation == 'flow':

                def apply_transformation(feature_and_flow):
                    feature, flow = feature_and_flow
                    return flow_ops.image_warp(feature[..., None], flow)
            else:

                def apply_transformation(feature_and_kernel):
                    feature, kernel = feature_and_kernel
                    output, = apply_kernels(feature[..., None],
                                            kernel[..., None])
                    return tf.squeeze(output, axis=-1)

            gen_feature_transposed = tf.map_fn(
                apply_transformation,
                (tf.stack(tf.unstack(feature, axis=-1)),
                 tf.stack(tf.unstack(transformations, axis=-1))),
                dtype=tf.float32)
            gen_feature = tf.stack(tf.unstack(gen_feature_transposed),
                                   axis=-1)

    # TODO: use norm and relu for generated features?
    gen_feature = norm_layer(gen_feature)
    gen_feature = tf.nn.relu(gen_feature)

    if 'states' in inputs:
        with tf.name_scope('gen_states'):
            with tf.variable_scope('state_pred'):
                gen_state = dense(state_action,
                                  inputs['states'].shape[-1].value)

    outputs = {
        'gen_features': gen_feature,
        'gen_inputs': gen_input,
    }
    if 'states' in inputs:
        outputs['gen_states'] = gen_state
    if self.hparams.transformation == 'flow':
        outputs['gen_flows'] = flows

    new_states = {
        'time': time + 1,
        'gen_feature': gen_feature,
        'conv_rnn_states': new_conv_rnn_states,
    }
    if 'zs' in inputs and self.hparams.use_rnn_z:
        new_states['rnn_z_state'] = rnn_z_state
    if 'states' in inputs:
        new_states['gen_state'] = gen_state
    return outputs, new_states