def discriminator_fn(targets, inputs=None, hparams=None):
    batch_size = targets.shape[1].value
    # sort of hack to ensure that the same t_sample is used for all the
    # discriminators that are given the same inputs
    if 't_sample' in inputs:
        t_sample = inputs['t_sample']
    else:
        t_sample = tf.random_uniform([batch_size], minval=0, maxval=targets.shape[0].value, dtype=tf.int32)
        inputs['t_sample'] = t_sample
    image_sample = tf.gather_nd(targets, tf.stack([t_sample, tf.range(batch_size)], axis=1))
    if 't_start' in inputs:
        t_start = inputs['t_start']
    else:
        t_start = tf.random_uniform([batch_size], minval=0,
                                    maxval=targets.shape[0].value - hparams.clip_length + 1, dtype=tf.int32)
        inputs['t_start'] = t_start
    t_start_indices = tf.stack([t_start, tf.range(batch_size)], axis=1)
    t_offset_indices = tf.stack([tf.range(hparams.clip_length),
                                 tf.zeros(hparams.clip_length, dtype=tf.int32)], axis=1)
    indices = tf.expand_dims(t_start_indices, axis=0) + tf.expand_dims(t_offset_indices, axis=1)
    clip_sample = tf.reshape(tf.gather_nd(targets, flatten(indices, 0, 1)),
                             [hparams.clip_length] + targets.shape.as_list()[1:])

    outputs = {}
    if hparams.image_sn_gan_weight or hparams.image_sn_vae_gan_weight:
        image_features = create_image_sn_discriminator(image_sample, ndf=hparams.ndf)
        image_features, image_logits = image_features[:-1], image_features[-1]
        outputs['discrim_image_sn_logits'] = tf.expand_dims(image_logits, axis=0)  # expand dims for the time dimension
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            images_features = create_image_sn_discriminator(flatten(targets, 0, 1), ndf=hparams.ndf)
        images_features = images_features[:-1]
        for i, images_feature in enumerate(images_features):
            images_feature = tf.reshape(images_feature,
                                        targets.shape[:2].as_list() + images_feature.shape[1:].as_list())
            outputs['discrim_image_sn_feature%d' % i] = images_feature
    if hparams.video_sn_gan_weight or hparams.video_sn_vae_gan_weight:
        video_features = create_video_sn_discriminator(clip_sample, ndf=hparams.ndf)
        video_features, video_logits = video_features[:-1], video_features[-1]
        outputs['discrim_video_sn_logits'] = video_logits
        for i, video_feature in enumerate(video_features):
            outputs['discrim_video_sn_feature%d' % i] = video_feature
    return None, outputs

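# A minimal NumPy sketch (not part of the model) of how the (t_sample, batch_index)
# index pairs above pick one frame per example from a time-major tensor of shape
# [time, batch, ...]; tf.gather_nd with tf.stack([t_sample, tf.range(batch_size)], axis=1)
# does the same thing on the graph side. All names and shapes here are illustrative only.
import numpy as np

time_steps, batch_size = 4, 3
targets_np = np.arange(time_steps * batch_size).reshape(time_steps, batch_size)  # stand-in for [time, batch, H, W, C]
t_sample_np = np.array([2, 0, 3])  # one sampled time step per example
image_sample_np = targets_np[t_sample_np, np.arange(batch_size)]  # shape [batch]
assert np.array_equal(image_sample_np, [targets_np[2, 0], targets_np[0, 1], targets_np[3, 2]])
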
def create_encoder(inputs, e_net='legacy', use_e_rnn=False, rnn='lstm', **kwargs):
    assert inputs.shape.ndims == 5
    batch_shape = inputs.shape[:-3].as_list()
    inputs = flatten(inputs, 0, len(batch_shape) - 1)
    unflatten = lambda x: tf.reshape(x, batch_shape + x.shape.as_list()[1:])
    if use_e_rnn:
        if e_net == 'legacy':
            kwargs.pop('n_layers', None)  # unused
            h = create_legacy_encoder(inputs, include_top=False, **kwargs)
            with tf.variable_scope('h4'):
                h = dense(h, kwargs['nef'] * 4)
        elif e_net == 'n_layer':
            h = create_n_layer_encoder(inputs, include_top=False, **kwargs)
            with tf.variable_scope('layer_%d' % (kwargs['n_layers'] + 1)):
                h = dense(h, kwargs['nef'] * 4)
        else:
            raise ValueError('Invalid encoder net %s' % e_net)
        if rnn == 'lstm':
            RNNCell = tf.contrib.rnn.BasicLSTMCell
        elif rnn == 'gru':
            RNNCell = tf.contrib.rnn.GRUCell
        else:
            raise NotImplementedError
        h = nest.map_structure(unflatten, h)
        for i in range(2):
            with tf.variable_scope('%s_h%d' % (rnn, i)):
                rnn_cell = RNNCell(kwargs['nef'] * 4)
                h, _ = tf.nn.dynamic_rnn(rnn_cell, h, dtype=tf.float32, time_major=True)
        h = flatten(h, 0, len(batch_shape) - 1)
        with tf.variable_scope('z_mu'):
            z_mu = dense(h, kwargs['nz'])
        with tf.variable_scope('z_log_sigma_sq'):
            z_log_sigma_sq = dense(h, kwargs['nz'])
            z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10)
        outputs = {'enc_zs_mu': z_mu, 'enc_zs_log_sigma_sq': z_log_sigma_sq}
    else:
        if e_net == 'legacy':
            kwargs.pop('n_layers', None)  # unused
            outputs = create_legacy_encoder(inputs, include_top=True, **kwargs)
        elif e_net == 'n_layer':
            outputs = create_n_layer_encoder(inputs, include_top=True, **kwargs)
        else:
            raise ValueError('Invalid encoder net %s' % e_net)
    outputs = nest.map_structure(unflatten, outputs)
    return outputs

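# NumPy sketch of the flatten/unflatten convention used by create_encoder above:
# the leading [time, batch] dimensions are merged before the per-frame encoder and
# restored afterwards, so outputs such as 'enc_zs_mu' keep the [time, batch, nz]
# layout. Shapes are illustrative only.
import numpy as np

time_steps, batch, nz = 10, 8, 8
h = np.random.rand(time_steps * batch, nz)        # encoder output on flattened frames
unflattened = h.reshape([time_steps, batch, nz])  # what nest.map_structure(unflatten, ...) restores
assert unflattened.shape == (time_steps, batch, nz)
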
def apply_dna_kernels(image, kernels, dilation_rate=(1, 1)):
    """
    Args:
        image: A 4-D tensor of shape
            `[batch, in_height, in_width, in_channels]`.
        kernels: A 6-D tensor of shape
            `[batch, in_height, in_width, kernel_size[0], kernel_size[1], num_transformed_images]`.

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
            `[batch, in_height, in_width, in_channels]`.
    """
    dilation_rate = list(dilation_rate) if isinstance(dilation_rate, (tuple, list)) else [dilation_rate] * 2
    batch_size, height, width, color_channels = image.get_shape().as_list()
    batch_size, height, width, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list()
    kernel_size = [kernel_height, kernel_width]

    # Flatten the spatial dimensions.
    kernels_reshaped = tf.reshape(kernels, [batch_size, height, width,
                                            kernel_size[0] * kernel_size[1], num_transformed_images])
    image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC')
    # Combine channel and batch dimensions into the first dimension.
    image_transposed = tf.transpose(image_padded, [3, 0, 1, 2])
    image_reshaped = flatten(image_transposed, 0, 1)[..., None]
    patches_reshaped = tf.extract_image_patches(image_reshaped, ksizes=[1] + kernel_size + [1],
                                                strides=[1] * 4, rates=[1] + dilation_rate + [1], padding='VALID')
    # Separate channel and batch dimensions, and move channel dimension.
    patches_transposed = tf.reshape(patches_reshaped, [color_channels, batch_size, height, width,
                                                       kernel_size[0] * kernel_size[1]])
    patches = tf.transpose(patches_transposed, [1, 2, 3, 0, 4])
    # Reduce along the spatial dimensions of the kernel.
    outputs = tf.matmul(patches, kernels_reshaped)
    outputs = tf.unstack(outputs, axis=-1)
    return outputs

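# A small NumPy sketch (assumptions: single-channel image, one transformed image,
# odd kernel size, no dilation) of what apply_dna_kernels computes: every output
# pixel is a weighted sum of the kernel-sized neighborhood of the symmetrically
# padded input, with a different kernel at every spatial location. The helper name
# below is local to this sketch, not part of the codebase.
import numpy as np

def apply_dna_kernels_np(image, kernels):
    # image: [height, width]; kernels: [height, width, kh, kw], assumed normalized
    height, width = image.shape
    kh, kw = kernels.shape[2:]
    padded = np.pad(image, ((kh // 2, kh // 2), (kw // 2, kw // 2)), mode='symmetric')
    output = np.zeros_like(image)
    for y in range(height):
        for x in range(width):
            patch = padded[y:y + kh, x:x + kw]
            output[y, x] = np.sum(patch * kernels[y, x])
    return output

image = np.random.rand(4, 4)
identity = np.zeros((4, 4, 3, 3))
identity[..., 1, 1] = 1.0  # per-pixel identity kernels
np.testing.assert_allclose(apply_dna_kernels_np(image, identity), image)
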
def create_legacy_encoder(inputs, nz=8, nef=64, norm_layer='instance', include_top=True):
    norm_layer = ops.get_norm_layer(norm_layer)

    with tf.variable_scope('h1'):
        h1 = conv_pool2d(inputs, nef, kernel_size=5, strides=2)
        h1 = norm_layer(h1)
        h1 = tf.nn.relu(h1)

    with tf.variable_scope('h2'):
        h2 = conv_pool2d(h1, nef * 2, kernel_size=5, strides=2)
        h2 = norm_layer(h2)
        h2 = tf.nn.relu(h2)

    with tf.variable_scope('h3'):
        h3 = conv_pool2d(h2, nef * 4, kernel_size=5, strides=2)
        h3 = norm_layer(h3)
        h3 = tf.nn.relu(h3)
    h3_flatten = flatten(h3)

    if include_top:
        with tf.variable_scope('z_mu'):
            z_mu = dense(h3_flatten, nz)
        with tf.variable_scope('z_log_sigma_sq'):
            z_log_sigma_sq = dense(h3_flatten, nz)
            z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10)
        outputs = {'enc_zs_mu': z_mu, 'enc_zs_log_sigma_sq': z_log_sigma_sq}
    else:
        outputs = h3_flatten
    return outputs

def create_pspnet50_encoder(inputs):
    should_flatten = inputs.shape.ndims > 4
    if should_flatten:
        batch_shape = inputs.shape[:-3].as_list()
        inputs = flatten(inputs, 0, len(batch_shape) - 1)
    outputs = pspnet_network.pspnet(inputs, resnet_layers=50)
    if should_flatten:
        outputs = tf.reshape(outputs, batch_shape + outputs.shape.as_list()[1:])
    return outputs

def generator_fn(inputs, hparams=None):
    batch_size = inputs['images'].shape[1].value
    inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
              for name, input in inputs.items()}

    with tf.variable_scope('gru'):
        gru_cell = tf.nn.rnn_cell.GRUCell(hparams.dim_z_motion)
    if hparams.context_frames:
        with tf.variable_scope('content_encoder'):
            z_c = create_encoder(inputs['images'][0],  # first context image for content encoder
                                 nef=hparams.nef,
                                 norm_layer=hparams.norm_layer,
                                 dim_z=hparams.dim_z_content)
        with tf.variable_scope('initial_motion_encoder'):
            h_0 = create_encoder(inputs['images'][hparams.context_frames - 1],  # last context image for motion encoder
                                 nef=hparams.nef,
                                 norm_layer=hparams.norm_layer,
                                 dim_z=hparams.dim_z_motion)
    else:  # unconditional case
        z_c = tf.random_normal([batch_size, hparams.dim_z_content])
        h_0 = gru_cell.zero_state(batch_size, tf.float32)
    h_t = [h_0]
    for t in range(hparams.context_frames - 1, hparams.sequence_length - 1):
        with tf.variable_scope('gru', reuse=t > hparams.context_frames - 1):
            e_t = tf.random_normal([batch_size, hparams.dim_z_motion])
            if 'actions' in inputs:
                # append the action at this time step to the motion noise
                e_t = tf.concat([e_t, inputs['actions'][t]], axis=-1)
            h_t.append(gru_cell(e_t, h_t[-1])[1])  # the output and state is the same in GRUs
    z_m = tf.stack(h_t[1:], axis=0)
    z = tf.concat([tf.tile(z_c[None, :, :], [hparams.sequence_length - hparams.context_frames, 1, 1]), z_m], axis=-1)
    z_flatten = flatten(z[:, :, None, None, :], 0, 1)
    gen_images_flatten = create_generator(z_flatten, ngf=hparams.ngf, norm_layer=hparams.norm_layer)
    gen_images = tf.reshape(gen_images_flatten, [-1, batch_size] + gen_images_flatten.shape.as_list()[1:])

    outputs = {'gen_images': gen_images}
    return gen_images, outputs

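# NumPy sketch (illustrative shapes only, not part of the model) of how the latent
# code above is assembled: a single content code per video is tiled across the
# predicted time steps and concatenated with the per-step motion codes produced by
# the GRU, giving one [dim_z_content + dim_z_motion] vector per frame.
import numpy as np

T, batch, dim_c, dim_m = 5, 2, 4, 3  # predicted steps, batch size, content dim, motion dim
z_c = np.random.rand(batch, dim_c)
z_m = np.random.rand(T, batch, dim_m)
z = np.concatenate([np.tile(z_c[None], (T, 1, 1)), z_m], axis=-1)
assert z.shape == (T, batch, dim_c + dim_m)
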
def create_discriminator(discrim_targets, discrim_inputs=None, d_net='legacy', **kwargs):
    should_flatten = discrim_targets.shape.ndims > 4
    if should_flatten:
        ndims = discrim_targets.shape.ndims
        batch_shape = discrim_targets.shape[:-3].as_list()
        discrim_targets = flatten(discrim_targets, 0, len(batch_shape) - 1)
        if discrim_inputs is not None:
            assert discrim_inputs.shape.ndims == ndims
            assert discrim_inputs.shape[:-3].as_list() == batch_shape
            discrim_inputs = flatten(discrim_inputs, 0, len(batch_shape) - 1)

    if d_net == 'legacy':
        kwargs.pop('n_layers', None)  # unused
        features = create_legacy_discriminator(discrim_targets, discrim_inputs, **kwargs)
    elif d_net == 'n_layer':
        kwargs.pop('downsample_layer', None)  # unused
        n_layers = kwargs.pop('n_layers', None)
        if not n_layers:
            scale_size = min(*discrim_targets.shape.as_list()[1:3])
            n_layers = int(np.log2(scale_size // 32))
        features = create_n_layer_discriminator(discrim_targets, discrim_inputs, n_layers=n_layers, **kwargs)
    else:
        raise ValueError('Invalid discriminator net %s' % d_net)

    if should_flatten:
        features = nest.map_structure(
            lambda x: tf.reshape(x, batch_shape + x.shape.as_list()[1:]), features)
    return features

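# Quick check of the default n_layers heuristic above: when n_layers is not given,
# the n_layer discriminator gets one downsampling layer for every factor of two by
# which the smaller image side exceeds 32 pixels. Example sizes are illustrative.
import numpy as np

for scale_size, expected_n_layers in [(64, 1), (128, 2), (256, 3)]:
    assert int(np.log2(scale_size // 32)) == expected_n_layers
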
def generator_fn(inputs, outputs_enc=None, hparams=None):
    batch_size = inputs['images'].shape[1].value
    inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
              for name, input in inputs.items()}
    if hparams.nz:
        def sample_zs():
            if outputs_enc is None:
                zs = tf.random_normal([hparams.sequence_length - 1, batch_size, hparams.nz], 0, 1)
            else:
                enc_zs_mu = outputs_enc['enc_zs_mu']
                enc_zs_log_sigma_sq = outputs_enc['enc_zs_log_sigma_sq']
                eps = tf.random_normal([hparams.sequence_length - 1, batch_size, hparams.nz], 0, 1)
                zs = enc_zs_mu + tf.sqrt(tf.exp(enc_zs_log_sigma_sq)) * eps
            return zs
        inputs['zs'] = sample_zs()
    else:
        if outputs_enc is not None:
            raise ValueError('outputs_enc has to be None when nz is 0.')
    cell = DNACell(inputs, hparams)
    outputs, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32,
                                   swap_memory=False, time_major=True)
    if hparams.nz:
        inputs_samples = {
            name: flatten(tf.tile(input[:, None], [1, hparams.num_samples] + [1] * (input.shape.ndims - 1)), 1, 2)
            for name, input in inputs.items() if name != 'zs'}
        inputs_samples['zs'] = tf.concat([sample_zs() for _ in range(hparams.num_samples)], axis=1)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            cell_samples = DNACell(inputs_samples, hparams)
            outputs_samples, _ = tf.nn.dynamic_rnn(cell_samples, inputs_samples, dtype=tf.float32,
                                                   swap_memory=False, time_major=True)
        gen_images_samples = outputs_samples['gen_images']
        gen_images_samples = tf.stack(tf.split(gen_images_samples, hparams.num_samples, axis=1), axis=-1)
        gen_images_samples_avg = tf.reduce_mean(gen_images_samples, axis=-1)
        outputs['gen_images_samples'] = gen_images_samples
        outputs['gen_images_samples_avg'] = gen_images_samples_avg
    # the RNN outputs generated images from time step 1 to sequence_length,
    # but generator_fn should only return images past context_frames
    outputs = {name: output[hparams.context_frames - 1:] for name, output in outputs.items()}
    gen_images = outputs['gen_images']
    outputs['ground_truth_sampling_mean'] = tf.reduce_mean(tf.to_float(cell.ground_truth[hparams.context_frames:]))
    return gen_images, outputs

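# NumPy sketch of the reparameterization used in sample_zs above: when the encoder
# outputs are available, a latent sample is drawn as mu + sigma * eps with
# eps ~ N(0, 1), so gradients can flow through mu and log(sigma^2). Shapes below
# are illustrative only.
import numpy as np

mu = np.zeros((2, 8))                 # stand-in for enc_zs_mu
log_sigma_sq = np.full((2, 8), -2.0)  # stand-in for enc_zs_log_sigma_sq
eps = np.random.randn(2, 8)
zs = mu + np.sqrt(np.exp(log_sigma_sq)) * eps
assert zs.shape == (2, 8)
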
def generator_loss_fn(self, inputs, outputs, targets):
    hparams = self.hparams
    gen_losses = OrderedDict()
    if hparams.l1_weight or hparams.l2_weight or hparams.vgg_cdist_weight:
        gen_images = outputs.get('gen_images_enc', outputs['gen_images'])
        target_images = targets
    if hparams.l1_weight:
        gen_l1_loss = vp.losses.l1_loss(gen_images, target_images)
        gen_losses["gen_l1_loss"] = (gen_l1_loss, hparams.l1_weight)
    if hparams.l2_weight:
        gen_l2_loss = vp.losses.l2_loss(gen_images, target_images)
        gen_losses["gen_l2_loss"] = (gen_l2_loss, hparams.l2_weight)
    if hparams.vgg_cdist_weight:
        gen_vgg_cdist_loss = vp.metrics.vgg_cosine_distance(gen_images, target_images)
        gen_losses['gen_vgg_cdist_loss'] = (gen_vgg_cdist_loss, hparams.vgg_cdist_weight)
    if hparams.feature_l2_weight:
        gen_features = outputs.get('gen_features_enc', outputs['gen_features'])
        target_features = outputs['features'][hparams.context_frames:]
        gen_feature_l2_loss = vp.losses.l2_loss(gen_features, target_features)
        gen_losses["gen_feature_l2_loss"] = (gen_feature_l2_loss, hparams.feature_l2_weight)
    if hparams.ae_l2_weight:
        gen_images_dec = outputs.get('gen_images_dec_enc', outputs['gen_images_dec'])  # they both should be the same
        target_images = inputs['images']
        gen_ae_l2_loss = vp.losses.l2_loss(gen_images_dec, target_images)
        gen_losses["gen_ae_l2_loss"] = (gen_ae_l2_loss, hparams.ae_l2_weight)
    if hparams.state_weight:
        gen_states = outputs.get('gen_states_enc', outputs['gen_states'])
        target_states = inputs['states'][hparams.context_frames:]
        gen_state_loss = vp.losses.l2_loss(gen_states, target_states)
        gen_losses["gen_state_loss"] = (gen_state_loss, hparams.state_weight)
    if hparams.tv_weight:
        gen_flows = outputs.get('gen_flows_enc', outputs['gen_flows'])
        gen_flows_reshaped = flatten(flatten(gen_flows, 0, 1), -2)
        gen_tv_loss = tf.reduce_mean(tf.image.total_variation(gen_flows_reshaped))
        gen_losses['gen_tv_loss'] = (gen_tv_loss, hparams.tv_weight)

    gan_weights = {'': hparams.gan_weight,
                   '_tuple': hparams.tuple_gan_weight,
                   '_image': hparams.image_gan_weight,
                   '_video': hparams.video_gan_weight,
                   '_acvideo': hparams.acvideo_gan_weight,
                   '_image_sn': hparams.image_sn_gan_weight,
                   '_images_sn': hparams.images_sn_gan_weight,
                   '_video_sn': hparams.video_sn_gan_weight}
    for infix, gan_weight in gan_weights.items():
        if gan_weight:
            gen_gan_loss = vp.losses.gan_loss(outputs['discrim%s_logits_fake' % infix], 1.0, hparams.gan_loss_type)
            gen_losses["gen%s_gan_loss" % infix] = (gen_gan_loss, gan_weight)
        if gan_weight and (hparams.gan_feature_l2_weight or hparams.gan_feature_cdist_weight):
            i_feature = 0
            discrim_features_fake = []
            discrim_features_real = []
            while True:
                discrim_feature_fake = outputs.get('discrim%s_feature%d_fake' % (infix, i_feature))
                discrim_feature_real = outputs.get('discrim%s_feature%d_real' % (infix, i_feature))
                if discrim_feature_fake is None or discrim_feature_real is None:
                    break
                discrim_features_fake.append(discrim_feature_fake)
                discrim_features_real.append(discrim_feature_real)
                i_feature += 1
            if hparams.gan_feature_l2_weight:
                gen_gan_feature_l2_loss = sum([
                    vp.losses.l2_loss(discrim_feature_fake, discrim_feature_real)
                    for discrim_feature_fake, discrim_feature_real in zip(discrim_features_fake, discrim_features_real)])
                gen_losses["gen%s_gan_feature_l2_loss" % infix] = (gen_gan_feature_l2_loss, hparams.gan_feature_l2_weight)
            if hparams.gan_feature_cdist_weight:
                gen_gan_feature_cdist_loss = sum([
                    vp.metrics.cosine_distance(discrim_feature_fake, discrim_feature_real)
                    for discrim_feature_fake, discrim_feature_real in zip(discrim_features_fake, discrim_features_real)])
                gen_losses["gen%s_gan_feature_cdist_loss" % infix] = (gen_gan_feature_cdist_loss,
                                                                      hparams.gan_feature_cdist_weight)

    vae_gan_weights = {'': hparams.vae_gan_weight,
                       '_tuple': hparams.tuple_vae_gan_weight,
                       '_image': hparams.image_vae_gan_weight,
                       '_video': hparams.video_vae_gan_weight,
                       '_acvideo': hparams.acvideo_vae_gan_weight,
                       '_image_sn': hparams.image_sn_vae_gan_weight,
                       '_images_sn': hparams.images_sn_vae_gan_weight,
                       '_video_sn': hparams.video_sn_vae_gan_weight}
    for infix, vae_gan_weight in vae_gan_weights.items():
        if vae_gan_weight:
            gen_vae_gan_loss = vp.losses.gan_loss(outputs['discrim%s_logits_enc_fake' % infix], 1.0,
                                                  hparams.gan_loss_type)
            gen_losses["gen%s_vae_gan_loss" % infix] = (gen_vae_gan_loss, vae_gan_weight)
        if vae_gan_weight and (hparams.gan_feature_l2_weight or hparams.gan_feature_cdist_weight):
            i_feature = 0
            discrim_features_enc_fake = []
            discrim_features_enc_real = []
            while True:
                discrim_feature_enc_fake = outputs.get('discrim%s_feature%d_enc_fake' % (infix, i_feature))
                discrim_feature_enc_real = outputs.get('discrim%s_feature%d_enc_real' % (infix, i_feature))
                if discrim_feature_enc_fake is None or discrim_feature_enc_real is None:
                    break
                discrim_features_enc_fake.append(discrim_feature_enc_fake)
                discrim_features_enc_real.append(discrim_feature_enc_real)
                i_feature += 1
            if hparams.gan_feature_l2_weight:
                gen_vae_gan_feature_l2_loss = sum([
                    vp.losses.l2_loss(discrim_feature_enc_fake, discrim_feature_enc_real)
                    for discrim_feature_enc_fake, discrim_feature_enc_real in zip(discrim_features_enc_fake,
                                                                                  discrim_features_enc_real)])
                gen_losses["gen%s_vae_gan_feature_l2_loss" % infix] = (gen_vae_gan_feature_l2_loss,
                                                                       hparams.gan_feature_l2_weight)
            if hparams.gan_feature_cdist_weight:
                gen_vae_gan_feature_cdist_loss = sum([
                    vp.metrics.cosine_distance(discrim_feature_enc_fake, discrim_feature_enc_real)
                    for discrim_feature_enc_fake, discrim_feature_enc_real in zip(discrim_features_enc_fake,
                                                                                  discrim_features_enc_real)])
                gen_losses["gen%s_vae_gan_feature_cdist_loss" % infix] = (gen_vae_gan_feature_cdist_loss,
                                                                          hparams.gan_feature_cdist_weight)

    if hparams.kl_weight:
        gen_kl_loss = vp.losses.kl_loss(outputs['enc_zs_mu'], outputs['enc_zs_log_sigma_sq'])
        gen_losses["gen_kl_loss"] = (gen_kl_loss, self.kl_weight)  # possibly annealed kl_weight
    if hparams.z_l1_weight:
        gen_z_l1_loss = vp.losses.l1_loss(outputs['gen_enc_zs_mu'], outputs['gen_zs_random'])
        gen_losses["gen_z_l1_loss"] = (gen_z_l1_loss, hparams.z_l1_weight)
    return gen_losses

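# generator_loss_fn returns an OrderedDict of name -> (loss, weight) pairs. Below is
# a minimal plain-Python sketch of the weighted reduction a training loop would
# typically apply to such a dictionary; the codebase's actual aggregation may differ.
def total_weighted_loss(gen_losses):
    # sum of weight * loss over all entries
    return sum(loss * weight for loss, weight in gen_losses.values())

assert total_weighted_loss({'a': (2.0, 0.5), 'b': (3.0, 1.0)}) == 4.0
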
def call(self, inputs, states):
    norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
    downsample_layer = ops.get_downsample_layer(self.hparams.downsample_layer)
    upsample_layer = ops.get_upsample_layer(self.hparams.upsample_layer)
    image_shape = inputs['images'].get_shape().as_list()
    batch_size, height, width, color_channels = image_shape
    time = states['time']
    with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
        t = tf.to_int32(tf.identity(time[0]))
    if 'states' in inputs:
        state = tf.where(self.ground_truth[t], inputs['states'], states['gen_state'])

    state_action = []
    state_action_z = []
    if 'actions' in inputs:
        state_action.append(inputs['actions'])
        state_action_z.append(inputs['actions'])
    if 'states' in inputs:
        state_action.append(state)
        # don't backpropagate the convnet through the state dynamics
        state_action_z.append(tf.stop_gradient(state))
    if 'zs' in inputs:
        if self.hparams.use_rnn_z:
            with tf.variable_scope('%s_z' % self.hparams.rnn):
                rnn_z, rnn_z_state = self._rnn_func(inputs['zs'], states['rnn_z_state'], self.hparams.nz)
            state_action_z.append(rnn_z)
        else:
            state_action_z.append(inputs['zs'])

    def concat(tensors, axis):
        if len(tensors) == 0:
            return tf.zeros([batch_size, 0])
        elif len(tensors) == 1:
            return tensors[0]
        else:
            return tf.concat(tensors, axis=axis)
    state_action = concat(state_action, axis=-1)
    state_action_z = concat(state_action_z, axis=-1)

    image_views = []
    first_image_views = []
    if 'pix_distribs' in inputs:
        pix_distrib_views = []
    for i in range(self.hparams.num_views):
        suffix = '%d' % i if i > 0 else ''
        image_view = tf.where(self.ground_truth[t], inputs['images' + suffix],
                              states['gen_image' + suffix])  # schedule sampling (if any)
        image_views.append(image_view)
        first_image_views.append(self.inputs['images' + suffix][0])
        if 'pix_distribs' in inputs:
            pix_distrib_view = tf.where(self.ground_truth[t], inputs['pix_distribs' + suffix],
                                        states['gen_pix_distrib' + suffix])
            pix_distrib_views.append(pix_distrib_view)

    outputs = {}
    new_states = {}
    all_layers = []
    for i in range(self.hparams.num_views):
        suffix = '%d' % i if i > 0 else ''
        conv_rnn_states = states['conv_rnn_states' + suffix]
        layers = []
        new_conv_rnn_states = []
        for i, (out_channels, use_conv_rnn) in enumerate(self.encoder_layer_specs):
            with tf.variable_scope('h%d' % i + suffix):
                if i == 0:
                    # all image views and the first image corresponding to this view only
                    h = tf.concat(image_views + first_image_views, axis=-1)
                    kernel_size = (5, 5)
                else:
                    h = layers[-1][-1]
                    kernel_size = (3, 3)
                if self.hparams.where_add == 'all' or (self.hparams.where_add == 'input' and i == 0):
                    h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
                h = downsample_layer(h, out_channels, kernel_size=kernel_size, strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, i) + suffix):
                    if self.hparams.where_add == 'all':
                        conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
                    else:
                        conv_rnn_h = h
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(conv_rnn_h, conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h,))

        num_encoder_layers = len(layers)
        for i, (out_channels, use_conv_rnn) in enumerate(self.decoder_layer_specs):
            with tf.variable_scope('h%d' % len(layers) + suffix):
                if i == 0:
                    h = layers[-1][-1]
                else:
                    h = tf.concat([layers[-1][-1], layers[num_encoder_layers - i - 1][-1]], axis=-1)
                if self.hparams.where_add == 'all' or (self.hparams.where_add == 'middle' and i == 0):
                    h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
                h = upsample_layer(h, out_channels, kernel_size=(3, 3), strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, len(layers)) + suffix):
                    if self.hparams.where_add == 'all':
                        conv_rnn_h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
                    else:
                        conv_rnn_h = h
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(conv_rnn_h, conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h,))
        assert len(new_conv_rnn_states) == len(conv_rnn_states)
        new_states['conv_rnn_states' + suffix] = new_conv_rnn_states
        all_layers.append(layers)
        if self.hparams.shared_views:
            break

    for i in range(self.hparams.num_views):
        suffix = '%d' % i if i > 0 else ''
        if self.hparams.shared_views:
            layers, = all_layers
        else:
            layers = all_layers[i]
        image = image_views[i]
        last_images = states['last_images' + suffix][1:] + [image]
        if 'pix_distribs' in inputs:
            pix_distrib = pix_distrib_views[i]
            last_pix_distribs = states['last_pix_distribs' + suffix][1:] + [pix_distrib]

        if self.hparams.last_frames and self.hparams.num_transformed_images:
            if self.hparams.transformation == 'flow':
                with tf.variable_scope('h%d_flow' % len(layers) + suffix):
                    h_flow = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1))
                    h_flow = norm_layer(h_flow)
                    h_flow = tf.nn.relu(h_flow)
                with tf.variable_scope('flows' + suffix):
                    flows = conv2d(h_flow, 2 * self.hparams.last_frames * self.hparams.num_transformed_images,
                                   kernel_size=(3, 3), strides=(1, 1))
                    flows = tf.reshape(flows, [batch_size, height, width, 2,
                                               self.hparams.last_frames * self.hparams.num_transformed_images])
            else:
                assert len(self.hparams.kernel_size) == 2
                kernel_shape = list(self.hparams.kernel_size) + \
                    [self.hparams.last_frames * self.hparams.num_transformed_images]
                if self.hparams.transformation == 'dna':
                    with tf.variable_scope('h%d_dna_kernel' % len(layers) + suffix):
                        h_dna_kernel = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1))
                        h_dna_kernel = norm_layer(h_dna_kernel)
                        h_dna_kernel = tf.nn.relu(h_dna_kernel)
                    # Using largest hidden state for predicting untied conv kernels.
                    with tf.variable_scope('dna_kernels' + suffix):
                        kernels = conv2d(h_dna_kernel, np.prod(kernel_shape), kernel_size=(3, 3), strides=(1, 1))
                        kernels = tf.reshape(kernels, [batch_size, height, width] + kernel_shape)
                        kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, None, None, :, :, None]
                    kernel_spatial_axes = [3, 4]
                elif self.hparams.transformation == 'cdna':
                    with tf.variable_scope('cdna_kernels' + suffix):
                        smallest_layer = layers[num_encoder_layers - 1][-1]
                        kernels = dense(flatten(smallest_layer), np.prod(kernel_shape))
                        kernels = tf.reshape(kernels, [batch_size] + kernel_shape)
                        kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, :, :, None]
                    kernel_spatial_axes = [1, 2]
                else:
                    raise ValueError('Invalid transformation %s' % self.hparams.transformation)
            if self.hparams.transformation != 'flow':
                with tf.name_scope('kernel_normalization' + suffix):
                    kernels = tf.nn.relu(kernels - RELU_SHIFT) + RELU_SHIFT
                    kernels /= tf.reduce_sum(kernels, axis=kernel_spatial_axes, keepdims=True)

        if self.hparams.generate_scratch_image:
            with tf.variable_scope('h%d_scratch' % len(layers) + suffix):
                h_scratch = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1))
                h_scratch = norm_layer(h_scratch)
                h_scratch = tf.nn.relu(h_scratch)
            # Using largest hidden state for predicting a new image layer.
            # This allows the network to also generate one image from scratch,
            # which is useful when regions of the image become unoccluded.
            with tf.variable_scope('scratch_image' + suffix):
                scratch_image = conv2d(h_scratch, color_channels, kernel_size=(3, 3), strides=(1, 1))
                scratch_image = tf.nn.sigmoid(scratch_image)

        with tf.name_scope('transformed_images' + suffix):
            transformed_images = []
            if self.hparams.last_frames and self.hparams.num_transformed_images:
                if self.hparams.transformation == 'flow':
                    transformed_images.extend(apply_flows(last_images, flows))
                else:
                    transformed_images.extend(apply_kernels(last_images, kernels, self.hparams.dilation_rate))
            if self.hparams.prev_image_background:
                transformed_images.append(image)
            if self.hparams.first_image_background and not self.hparams.context_images_background:
                transformed_images.append(self.inputs['images' + suffix][0])
            if self.hparams.context_images_background:
                transformed_images.extend(tf.unstack(self.inputs['images' + suffix][:self.hparams.context_frames]))
            if self.hparams.generate_scratch_image:
                transformed_images.append(scratch_image)

        if 'pix_distribs' in inputs:
            with tf.name_scope('transformed_pix_distribs' + suffix):
                transformed_pix_distribs = []
                if self.hparams.last_frames and self.hparams.num_transformed_images:
                    if self.hparams.transformation == 'flow':
                        transformed_pix_distribs.extend(apply_flows(last_pix_distribs, flows))
                    else:
                        transformed_pix_distribs.extend(
                            apply_kernels(last_pix_distribs, kernels, self.hparams.dilation_rate))
                if self.hparams.prev_image_background:
                    transformed_pix_distribs.append(pix_distrib)
                if self.hparams.first_image_background and not self.hparams.context_images_background:
                    transformed_pix_distribs.append(self.inputs['pix_distribs' + suffix][0])
                if self.hparams.context_images_background:
                    transformed_pix_distribs.extend(
                        tf.unstack(self.inputs['pix_distribs' + suffix][:self.hparams.context_frames]))
                if self.hparams.generate_scratch_image:
                    transformed_pix_distribs.append(pix_distrib)

        with tf.name_scope('masks' + suffix):
            if len(transformed_images) > 1:
                with tf.variable_scope('h%d_masks' % len(layers) + suffix):
                    h_masks = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1))
                    h_masks = norm_layer(h_masks)
                    h_masks = tf.nn.relu(h_masks)
                with tf.variable_scope('masks' + suffix):
                    if self.hparams.dependent_mask:
                        h_masks = tf.concat([h_masks] + transformed_images, axis=-1)
                    masks = conv2d(h_masks, len(transformed_images), kernel_size=(3, 3), strides=(1, 1))
                    masks = tf.nn.softmax(masks)
                    masks = tf.split(masks, len(transformed_images), axis=-1)
            elif len(transformed_images) == 1:
                masks = [tf.ones([batch_size, height, width, 1])]
            else:
                raise ValueError("Either one of the following should be true: "
                                 "last_frames and num_transformed_images, first_image_background, "
                                 "prev_image_background, generate_scratch_image")

        with tf.name_scope('gen_images' + suffix):
            assert len(transformed_images) == len(masks)
            gen_image = tf.add_n([transformed_image * mask
                                  for transformed_image, mask in zip(transformed_images, masks)])

        if 'pix_distribs' in inputs:
            with tf.name_scope('gen_pix_distribs' + suffix):
                assert len(transformed_pix_distribs) == len(masks)
                gen_pix_distrib = tf.add_n([transformed_pix_distrib * mask
                                            for transformed_pix_distrib, mask in zip(transformed_pix_distribs, masks)])
                if self.hparams.renormalize_pixdistrib:
                    gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib, axis=(1, 2), keepdims=True)

        outputs['gen_images' + suffix] = gen_image
        outputs['transformed_images' + suffix] = tf.stack(transformed_images, axis=-1)
        outputs['masks' + suffix] = tf.stack(masks, axis=-1)
        if 'pix_distribs' in inputs:
            outputs['gen_pix_distribs' + suffix] = gen_pix_distrib
            outputs['transformed_pix_distribs' + suffix] = tf.stack(transformed_pix_distribs, axis=-1)
        if self.hparams.transformation == 'flow':
            outputs['gen_flows' + suffix] = flows
            flows_transposed = tf.transpose(flows, [0, 1, 2, 4, 3])
            flows_rgb_transposed = tf_utils.flow_to_rgb(flows_transposed)
            flows_rgb = tf.transpose(flows_rgb_transposed, [0, 1, 2, 4, 3])
            outputs['gen_flows_rgb' + suffix] = flows_rgb

        new_states['gen_image' + suffix] = gen_image
        new_states['last_images' + suffix] = last_images
        if 'pix_distribs' in inputs:
            new_states['gen_pix_distrib' + suffix] = gen_pix_distrib
            new_states['last_pix_distribs' + suffix] = last_pix_distribs

    if 'states' in inputs:
        with tf.name_scope('gen_states'):
            with tf.variable_scope('state_pred'):
                gen_state = dense(state_action, inputs['states'].shape[-1].value)
    if 'states' in inputs:
        outputs['gen_states'] = gen_state

    new_states['time'] = time + 1
    if 'zs' in inputs and self.hparams.use_rnn_z:
        new_states['rnn_z_state'] = rnn_z_state
    if 'states' in inputs:
        new_states['gen_state'] = gen_state
    return outputs, new_states

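# NumPy sketch of the kernel normalization step above: the predicted kernels are
# passed through a relu shifted by a small positive constant (RELU_SHIFT) so every
# entry stays strictly positive, then renormalized to sum to one over the kernel's
# spatial axes, making each transformed pixel a convex combination of its neighbors.
# RELU_SHIFT_NP is a local stand-in for the module-level RELU_SHIFT constant.
import numpy as np

RELU_SHIFT_NP = 1e-12
kernels = np.random.randn(2, 5, 5, 4)  # [batch, kh, kw, num_transformed_images], e.g. CDNA kernels
kernels = np.maximum(kernels - RELU_SHIFT_NP, 0.0) + RELU_SHIFT_NP
kernels /= np.sum(kernels, axis=(1, 2), keepdims=True)
assert np.allclose(np.sum(kernels, axis=(1, 2)), 1.0)
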
def call(self, inputs, states):
    norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
    feature_shape = inputs['features'].get_shape().as_list()
    batch_size, height, width, feature_channels = feature_shape
    conv_rnn_states = states['conv_rnn_states']
    time = states['time']
    with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
        t = tf.to_int32(tf.identity(time[0]))
    feature = tf.where(self.ground_truth[t], inputs['features'], states['gen_feature'])  # schedule sampling (if any)
    if 'states' in inputs:
        state = tf.where(self.ground_truth[t], inputs['states'], states['gen_state'])

    state_action = []
    state_action_z = []
    if 'actions' in inputs:
        state_action.append(inputs['actions'])
        state_action_z.append(inputs['actions'])
    if 'states' in inputs:
        state_action.append(state)
        # don't backpropagate the convnet through the state dynamics
        state_action_z.append(tf.stop_gradient(state))
    if 'zs' in inputs:
        if self.hparams.use_rnn_z:
            with tf.variable_scope('%s_z' % self.hparams.rnn):
                rnn_z, rnn_z_state = self._rnn_func(inputs['zs'], states['rnn_z_state'], self.hparams.nz)
            state_action_z.append(rnn_z)
        else:
            state_action_z.append(inputs['zs'])

    def concat(tensors, axis):
        if len(tensors) == 0:
            return tf.zeros([batch_size, 0])
        elif len(tensors) == 1:
            return tensors[0]
        else:
            return tf.concat(tensors, axis=axis)
    state_action = concat(state_action, axis=-1)
    state_action_z = concat(state_action_z, axis=-1)

    if 'actions' in inputs:
        gen_input = tile_concat([feature, inputs['actions'][:, None, None, :]], axis=-1)
    else:
        gen_input = feature

    layers = []
    new_conv_rnn_states = []
    for i, (out_channels, use_conv_rnn) in enumerate(self.encoder_layer_specs):
        with tf.variable_scope('h%d' % i):
            if i == 0:
                # h = tf.concat([feature, self.inputs['features'][0]], axis=-1)  # TODO: use first feature?
                h = feature
            else:
                h = layers[-1][-1]
            h = conv_pool2d(tile_concat([h, state_action_z[:, None, None, :]], axis=-1),
                            out_channels, kernel_size=(3, 3), strides=(2, 2))
            h = norm_layer(h)
            h = tf.nn.relu(h)
        if use_conv_rnn:
            conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
            with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, i)):
                conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                    tile_concat([h, state_action_z[:, None, None, :]], axis=-1), conv_rnn_state, out_channels)
            new_conv_rnn_states.append(conv_rnn_state)
        layers.append((h, conv_rnn_h) if use_conv_rnn else (h,))

    num_encoder_layers = len(layers)
    for i, (out_channels, use_conv_rnn) in enumerate(self.decoder_layer_specs):
        with tf.variable_scope('h%d' % len(layers)):
            if i == 0:
                h = layers[-1][-1]
            else:
                h = tf.concat([layers[-1][-1], layers[num_encoder_layers - i - 1][-1]], axis=-1)
            h = upsample_conv2d(tile_concat([h, state_action_z[:, None, None, :]], axis=-1),
                                out_channels, kernel_size=(3, 3), strides=(2, 2))
            h = norm_layer(h)
            h = tf.nn.relu(h)
        if use_conv_rnn:
            conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
            with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, len(layers))):
                conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                    tile_concat([h, state_action_z[:, None, None, :]], axis=-1), conv_rnn_state, out_channels)
            new_conv_rnn_states.append(conv_rnn_state)
        layers.append((h, conv_rnn_h) if use_conv_rnn else (h,))
    assert len(new_conv_rnn_states) == len(conv_rnn_states)

    if self.hparams.transformation == 'direct':
        with tf.variable_scope('h%d_direct' % len(layers)):
            h_direct = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1))
            h_direct = norm_layer(h_direct)
            h_direct = tf.nn.relu(h_direct)
        with tf.variable_scope('direct'):
            gen_feature = conv2d(h_direct, feature_channels, kernel_size=(3, 3), strides=(1, 1))
    else:
        if self.hparams.transformation == 'flow':
            with tf.variable_scope('h%d_flow' % len(layers)):
                h_flow = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1))
                h_flow = norm_layer(h_flow)
                h_flow = tf.nn.relu(h_flow)
            with tf.variable_scope('flows'):
                flows = conv2d(h_flow, 2 * feature_channels, kernel_size=(3, 3), strides=(1, 1))
                flows = tf.reshape(flows, [batch_size, height, width, 2, feature_channels])
            transformations = flows
        else:
            assert len(self.hparams.kernel_size) == 2
            kernel_shape = list(self.hparams.kernel_size) + [feature_channels]
            if self.hparams.transformation == 'local':
                with tf.variable_scope('h%d_local_kernel' % len(layers)):
                    h_local_kernel = conv2d(layers[-1][-1], self.hparams.ngf, kernel_size=(3, 3), strides=(1, 1))
                    h_local_kernel = norm_layer(h_local_kernel)
                    h_local_kernel = tf.nn.relu(h_local_kernel)
                # Using largest hidden state for predicting untied conv kernels.
                with tf.variable_scope('local_kernels'):
                    kernels = conv2d(h_local_kernel, np.prod(kernel_shape), kernel_size=(3, 3), strides=(1, 1))
                    kernels = tf.reshape(kernels, [batch_size, height, width] + kernel_shape)
                    kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, None, None, :, :, None]
            elif self.hparams.transformation == 'conv':
                with tf.variable_scope('conv_kernels'):
                    smallest_layer = layers[num_encoder_layers - 1][-1]
                    kernels = dense(flatten(smallest_layer), np.prod(kernel_shape))
                    kernels = tf.reshape(kernels, [batch_size] + kernel_shape)
                    kernels = kernels + identity_kernel(self.hparams.kernel_size)[None, :, :, None]
            else:
                raise ValueError('Invalid transformation %s' % self.hparams.transformation)
            transformations = kernels

        with tf.name_scope('gen_features'):
            if self.hparams.transformation == 'flow':
                def apply_transformation(feature_and_flow):
                    feature, flow = feature_and_flow
                    return flow_ops.image_warp(feature[..., None], flow)
            else:
                def apply_transformation(feature_and_kernel):
                    feature, kernel = feature_and_kernel
                    output, = apply_kernels(feature[..., None], kernel[..., None])
                    return tf.squeeze(output, axis=-1)
            gen_feature_transposed = tf.map_fn(apply_transformation,
                                               (tf.stack(tf.unstack(feature, axis=-1)),
                                                tf.stack(tf.unstack(transformations, axis=-1))),
                                               dtype=tf.float32)
            gen_feature = tf.stack(tf.unstack(gen_feature_transposed), axis=-1)

    # TODO: use norm and relu for generated features?
    gen_feature = norm_layer(gen_feature)
    gen_feature = tf.nn.relu(gen_feature)

    if 'states' in inputs:
        with tf.name_scope('gen_states'):
            with tf.variable_scope('state_pred'):
                gen_state = dense(state_action, inputs['states'].shape[-1].value)

    outputs = {
        'gen_features': gen_feature,
        'gen_inputs': gen_input,
    }
    if 'states' in inputs:
        outputs['gen_states'] = gen_state
    if self.hparams.transformation == 'flow':
        outputs['gen_flows'] = flows

    new_states = {
        'time': time + 1,
        'gen_feature': gen_feature,
        'conv_rnn_states': new_conv_rnn_states,
    }
    if 'zs' in inputs and self.hparams.use_rnn_z:
        new_states['rnn_z_state'] = rnn_z_state
    if 'states' in inputs:
        new_states['gen_state'] = gen_state
    return outputs, new_states

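# NumPy sketch of the identity_kernel bias added to the predicted kernels above:
# a kernel whose only nonzero entry is a one at its center leaves the warped feature
# map unchanged, so adding it biases the predicted transformations toward the
# identity early in training. identity_kernel_np is a local stand-in (odd kernel
# sizes only) for the codebase's identity_kernel helper.
import numpy as np

def identity_kernel_np(kernel_size):
    kh, kw = kernel_size
    kernel = np.zeros((kh, kw))
    kernel[kh // 2, kw // 2] = 1.0
    return kernel

k = identity_kernel_np((3, 3))
assert k[1, 1] == 1.0 and k.sum() == 1.0
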