def create_legacy_discriminator(discrim_targets, discrim_inputs=None, ndf=64,
                                norm_layer='instance', downsample_layer='conv_pool2d'):
    norm_layer = ops.get_norm_layer(norm_layer)
    downsample_layer = ops.get_downsample_layer(downsample_layer)

    layers = []
    inputs = [discrim_targets]
    if discrim_inputs is not None:
        inputs.append(discrim_inputs)
    inputs = tf.concat(inputs, axis=-1)

    scale_size = min(*inputs.shape.as_list()[1:3])
    if scale_size == 256:
        layer_specs = [
            (ndf, 2),      # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
            (ndf * 2, 2),  # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
            (ndf * 4, 2),  # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
            (ndf * 8, 1),  # layer_4: [batch, 32, 32, ndf * 4] => [batch, 32, 32, ndf * 8]
            (1, 1),        # layer_5: [batch, 32, 32, ndf * 8] => [batch, 32, 32, 1]
        ]
    elif scale_size == 128:
        layer_specs = [
            (ndf, 2),
            (ndf * 2, 2),
            (ndf * 4, 1),
            (ndf * 8, 1),
            (1, 1),
        ]
    elif scale_size == 64:
        layer_specs = [
            (ndf, 2),
            (ndf * 2, 1),
            (ndf * 4, 1),
            (ndf * 8, 1),
            (1, 1),
        ]
    else:
        raise NotImplementedError

    with tf.variable_scope("layer_1"):
        out_channels, strides = layer_specs[0]
        convolved = downsample_layer(inputs, out_channels, kernel_size=4, strides=strides)
        rectified = lrelu(convolved, 0.2)
        layers.append(rectified)

    for out_channels, strides in layer_specs[1:-1]:
        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
            if strides == 1:
                convolved = conv2d(layers[-1], out_channels, kernel_size=4)
            else:
                convolved = downsample_layer(layers[-1], out_channels, kernel_size=4, strides=strides)
            normalized = norm_layer(convolved)
            rectified = lrelu(normalized, 0.2)
            layers.append(rectified)

    with tf.variable_scope("layer_%d" % (len(layers) + 1)):
        out_channels, strides = layer_specs[-1]
        if strides == 1:
            logits = conv2d(rectified, out_channels, kernel_size=4)
        else:
            logits = downsample_layer(rectified, out_channels, kernel_size=4, strides=strides)
        layers.append(logits)  # don't apply sigmoid to the logits in case we want to use LSGAN

    return layers
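
# Hedged usage sketch (not part of the original module): shows one way the legacy
# discriminator might be wired into a TF1 graph. The placeholder names and the
# 256x256x3 shapes below are illustrative assumptions, not taken from this repo.
def _example_legacy_discriminator_graph():
    targets = tf.placeholder(tf.float32, [None, 256, 256, 3], name='targets')
    conditions = tf.placeholder(tf.float32, [None, 256, 256, 3], name='conditions')
    with tf.variable_scope('discriminator'):
        layers = create_legacy_discriminator(targets, discrim_inputs=conditions, ndf=64)
    # the last element holds the patch logits; no sigmoid is applied, so the
    # caller can choose sigmoid cross-entropy (GAN) or a least-squares (LSGAN) loss
    return layers[-1]  # [batch, 32, 32, 1]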
def call(self, inputs, states):
    norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
    downsample_layer = ops.get_downsample_layer(self.hparams.downsample_layer)
    upsample_layer = ops.get_upsample_layer(self.hparams.upsample_layer)

    image_shape = inputs['images'].get_shape().as_list()
    batch_size, height, width, color_channels = image_shape
    time = states['time']
    with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
        t = tf.to_int32(tf.identity(time[0]))

    if 'states' in inputs:
        state = tf.where(self.ground_truth[t], inputs['states'], states['gen_state'])

    state_action = []
    state_action_z = []
    if 'actions' in inputs:
        state_action.append(inputs['actions'])
        state_action_z.append(inputs['actions'])
    if 'states' in inputs:
        state_action.append(state)
        # don't backpropagate the convnet through the state dynamics
        state_action_z.append(tf.stop_gradient(state))

    if 'zs' in inputs:
        if self.hparams.use_rnn_z:
            with tf.variable_scope('%s_z' % self.hparams.rnn):
                rnn_z, rnn_z_state = self._rnn_func(
                    inputs['zs'], states['rnn_z_state'], self.hparams.nz)
            state_action_z.append(rnn_z)
        else:
            state_action_z.append(inputs['zs'])

    def concat(tensors, axis):
        if len(tensors) == 0:
            return tf.zeros([batch_size, 0])
        elif len(tensors) == 1:
            return tensors[0]
        else:
            return tf.concat(tensors, axis=axis)

    state_action = concat(state_action, axis=-1)
    state_action_z = concat(state_action_z, axis=-1)

    image_views = []
    first_image_views = []
    if 'pix_distribs' in inputs:
        pix_distrib_views = []
    for i in range(self.hparams.num_views):
        suffix = '%d' % i if i > 0 else ''
        # schedule sampling (if any)
        image_view = tf.where(self.ground_truth[t],
                              inputs['images' + suffix],
                              states['gen_image' + suffix])
        image_views.append(image_view)
        first_image_views.append(self.inputs['images' + suffix][0])
        if 'pix_distribs' in inputs:
            pix_distrib_view = tf.where(self.ground_truth[t],
                                        inputs['pix_distribs' + suffix],
                                        states['gen_pix_distrib' + suffix])
            pix_distrib_views.append(pix_distrib_view)

    outputs = {}
    new_states = {}
    all_layers = []
    for i in range(self.hparams.num_views):
        suffix = '%d' % i if i > 0 else ''
        conv_rnn_states = states['conv_rnn_states' + suffix]
        layers = []
        new_conv_rnn_states = []
        for i, (out_channels, use_conv_rnn) in enumerate(self.encoder_layer_specs):
            with tf.variable_scope('h%d' % i + suffix):
                if i == 0:
                    # all image views and the first image corresponding to this view only
                    h = tf.concat(image_views + first_image_views, axis=-1)
                    kernel_size = (5, 5)
                else:
                    h = layers[-1][-1]
                    kernel_size = (3, 3)
                if self.hparams.where_add == 'all' or (
                        self.hparams.where_add == 'input' and i == 0):
                    h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
                h = downsample_layer(h, out_channels, kernel_size=kernel_size, strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, i) + suffix):
                    if self.hparams.where_add == 'all':
                        conv_rnn_h = tile_concat(
                            [h, state_action_z[:, None, None, :]], axis=-1)
                    else:
                        conv_rnn_h = h
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        conv_rnn_h, conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h,))

        num_encoder_layers = len(layers)
        for i, (out_channels, use_conv_rnn) in enumerate(self.decoder_layer_specs):
            with tf.variable_scope('h%d' % len(layers) + suffix):
                if i == 0:
                    h = layers[-1][-1]
                else:
                    h = tf.concat([layers[-1][-1],
                                   layers[num_encoder_layers - i - 1][-1]], axis=-1)
                if self.hparams.where_add == 'all' or (
                        self.hparams.where_add == 'middle' and i == 0):
                    h = tile_concat([h, state_action_z[:, None, None, :]], axis=-1)
                h = upsample_layer(h, out_channels, kernel_size=(3, 3), strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope(
                        '%s_h%d' % (self.hparams.conv_rnn, len(layers)) + suffix):
                    if self.hparams.where_add == 'all':
                        conv_rnn_h = tile_concat(
                            [h, state_action_z[:, None, None, :]], axis=-1)
                    else:
                        conv_rnn_h = h
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        conv_rnn_h, conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h,))

        assert len(new_conv_rnn_states) == len(conv_rnn_states)
        new_states['conv_rnn_states' + suffix] = new_conv_rnn_states
        all_layers.append(layers)
        if self.hparams.shared_views:
            break

    for i in range(self.hparams.num_views):
        suffix = '%d' % i if i > 0 else ''
        if self.hparams.shared_views:
            layers, = all_layers
        else:
            layers = all_layers[i]

        image = image_views[i]
        last_images = states['last_images' + suffix][1:] + [image]
        if 'pix_distribs' in inputs:
            pix_distrib = pix_distrib_views[i]
            last_pix_distribs = states['last_pix_distribs' + suffix][1:] + [pix_distrib]

        if self.hparams.last_frames and self.hparams.num_transformed_images:
            if self.hparams.transformation == 'flow':
                with tf.variable_scope('h%d_flow' % len(layers) + suffix):
                    h_flow = conv2d(layers[-1][-1], self.hparams.ngf,
                                    kernel_size=(3, 3), strides=(1, 1))
                    h_flow = norm_layer(h_flow)
                    h_flow = tf.nn.relu(h_flow)

                with tf.variable_scope('flows' + suffix):
                    flows = conv2d(
                        h_flow,
                        2 * self.hparams.last_frames * self.hparams.num_transformed_images,
                        kernel_size=(3, 3), strides=(1, 1))
                    flows = tf.reshape(
                        flows,
                        [batch_size, height, width, 2,
                         self.hparams.last_frames * self.hparams.num_transformed_images])
            else:
                assert len(self.hparams.kernel_size) == 2
                kernel_shape = list(self.hparams.kernel_size) + [
                    self.hparams.last_frames * self.hparams.num_transformed_images]
                if self.hparams.transformation == 'dna':
                    with tf.variable_scope('h%d_dna_kernel' % len(layers) + suffix):
                        h_dna_kernel = conv2d(layers[-1][-1], self.hparams.ngf,
                                              kernel_size=(3, 3), strides=(1, 1))
                        h_dna_kernel = norm_layer(h_dna_kernel)
                        h_dna_kernel = tf.nn.relu(h_dna_kernel)

                    # Using largest hidden state for predicting untied conv kernels.
                    with tf.variable_scope('dna_kernels' + suffix):
                        kernels = conv2d(h_dna_kernel, np.prod(kernel_shape),
                                         kernel_size=(3, 3), strides=(1, 1))
                        kernels = tf.reshape(
                            kernels, [batch_size, height, width] + kernel_shape)
                        kernels = kernels + identity_kernel(
                            self.hparams.kernel_size)[None, None, None, :, :, None]
                    kernel_spatial_axes = [3, 4]
                elif self.hparams.transformation == 'cdna':
                    with tf.variable_scope('cdna_kernels' + suffix):
                        smallest_layer = layers[num_encoder_layers - 1][-1]
                        kernels = dense(flatten(smallest_layer), np.prod(kernel_shape))
                        kernels = tf.reshape(kernels, [batch_size] + kernel_shape)
                        kernels = kernels + identity_kernel(
                            self.hparams.kernel_size)[None, :, :, None]
                    kernel_spatial_axes = [1, 2]
                else:
                    raise ValueError('Invalid transformation %s' % self.hparams.transformation)

            if self.hparams.transformation != 'flow':
                with tf.name_scope('kernel_normalization' + suffix):
                    kernels = tf.nn.relu(kernels - RELU_SHIFT) + RELU_SHIFT
                    kernels /= tf.reduce_sum(kernels, axis=kernel_spatial_axes, keepdims=True)

        if self.hparams.generate_scratch_image:
            with tf.variable_scope('h%d_scratch' % len(layers) + suffix):
                h_scratch = conv2d(layers[-1][-1], self.hparams.ngf,
                                   kernel_size=(3, 3), strides=(1, 1))
                h_scratch = norm_layer(h_scratch)
                h_scratch = tf.nn.relu(h_scratch)

            # Using largest hidden state for predicting a new image layer.
            # This allows the network to also generate one image from scratch,
            # which is useful when regions of the image become unoccluded.
            with tf.variable_scope('scratch_image' + suffix):
                scratch_image = conv2d(h_scratch, color_channels,
                                       kernel_size=(3, 3), strides=(1, 1))
                scratch_image = tf.nn.sigmoid(scratch_image)

        with tf.name_scope('transformed_images' + suffix):
            transformed_images = []
            if self.hparams.last_frames and self.hparams.num_transformed_images:
                if self.hparams.transformation == 'flow':
                    transformed_images.extend(apply_flows(last_images, flows))
                else:
                    transformed_images.extend(
                        apply_kernels(last_images, kernels, self.hparams.dilation_rate))
            if self.hparams.prev_image_background:
                transformed_images.append(image)
            if self.hparams.first_image_background and not self.hparams.context_images_background:
                transformed_images.append(self.inputs['images' + suffix][0])
            if self.hparams.context_images_background:
                transformed_images.extend(tf.unstack(
                    self.inputs['images' + suffix][:self.hparams.context_frames]))
            if self.hparams.generate_scratch_image:
                transformed_images.append(scratch_image)

        if 'pix_distribs' in inputs:
            with tf.name_scope('transformed_pix_distribs' + suffix):
                transformed_pix_distribs = []
                if self.hparams.last_frames and self.hparams.num_transformed_images:
                    if self.hparams.transformation == 'flow':
                        transformed_pix_distribs.extend(apply_flows(last_pix_distribs, flows))
                    else:
                        transformed_pix_distribs.extend(
                            apply_kernels(last_pix_distribs, kernels,
                                          self.hparams.dilation_rate))
                if self.hparams.prev_image_background:
                    transformed_pix_distribs.append(pix_distrib)
                if self.hparams.first_image_background and not self.hparams.context_images_background:
                    transformed_pix_distribs.append(self.inputs['pix_distribs' + suffix][0])
                if self.hparams.context_images_background:
                    transformed_pix_distribs.extend(tf.unstack(
                        self.inputs['pix_distribs' + suffix][:self.hparams.context_frames]))
                if self.hparams.generate_scratch_image:
                    transformed_pix_distribs.append(pix_distrib)

        with tf.name_scope('masks' + suffix):
            if len(transformed_images) > 1:
                with tf.variable_scope('h%d_masks' % len(layers) + suffix):
                    h_masks = conv2d(layers[-1][-1], self.hparams.ngf,
                                     kernel_size=(3, 3), strides=(1, 1))
                    h_masks = norm_layer(h_masks)
                    h_masks = tf.nn.relu(h_masks)

                with tf.variable_scope('masks' + suffix):
                    if self.hparams.dependent_mask:
                        h_masks = tf.concat([h_masks] + transformed_images, axis=-1)
                    masks = conv2d(h_masks, len(transformed_images),
                                   kernel_size=(3, 3), strides=(1, 1))
                    masks = tf.nn.softmax(masks)
                    masks = tf.split(masks, len(transformed_images), axis=-1)
            elif len(transformed_images) == 1:
                masks = [tf.ones([batch_size, height, width, 1])]
            else:
                raise ValueError(
                    "Either one of the following should be true: "
                    "last_frames and num_transformed_images, first_image_background, "
                    "prev_image_background, generate_scratch_image")

        with tf.name_scope('gen_images' + suffix):
            assert len(transformed_images) == len(masks)
            gen_image = tf.add_n([transformed_image * mask
                                  for transformed_image, mask in zip(transformed_images, masks)])

        if 'pix_distribs' in inputs:
            with tf.name_scope('gen_pix_distribs' + suffix):
                assert len(transformed_pix_distribs) == len(masks)
                gen_pix_distrib = tf.add_n(
                    [transformed_pix_distrib * mask
                     for transformed_pix_distrib, mask in zip(transformed_pix_distribs, masks)])
                if self.hparams.renormalize_pixdistrib:
                    gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib, axis=(1, 2), keepdims=True)

        outputs['gen_images' + suffix] = gen_image
        outputs['transformed_images' + suffix] = tf.stack(transformed_images, axis=-1)
        outputs['masks' + suffix] = tf.stack(masks, axis=-1)
        if 'pix_distribs' in inputs:
            outputs['gen_pix_distribs' + suffix] = gen_pix_distrib
            outputs['transformed_pix_distribs' + suffix] = tf.stack(
                transformed_pix_distribs, axis=-1)
        if self.hparams.transformation == 'flow':
            outputs['gen_flows' + suffix] = flows
            flows_transposed = tf.transpose(flows, [0, 1, 2, 4, 3])
            flows_rgb_transposed = tf_utils.flow_to_rgb(flows_transposed)
            flows_rgb = tf.transpose(flows_rgb_transposed, [0, 1, 2, 4, 3])
            outputs['gen_flows_rgb' + suffix] = flows_rgb

        new_states['gen_image' + suffix] = gen_image
        new_states['last_images' + suffix] = last_images
        if 'pix_distribs' in inputs:
            new_states['gen_pix_distrib' + suffix] = gen_pix_distrib
            new_states['last_pix_distribs' + suffix] = last_pix_distribs

    if 'states' in inputs:
        with tf.name_scope('gen_states'):
            with tf.variable_scope('state_pred'):
                gen_state = dense(state_action, inputs['states'].shape[-1].value)

    if 'states' in inputs:
        outputs['gen_states'] = gen_state

    new_states['time'] = time + 1
    if 'zs' in inputs and self.hparams.use_rnn_z:
        new_states['rnn_z_state'] = rnn_z_state
    if 'states' in inputs:
        new_states['gen_state'] = gen_state
    return outputs, new_states
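
# Hedged sketch (not part of the original module): a standalone illustration of
# the masked compositing step at the end of `call`, where per-pixel softmax masks
# blend the transformed candidate images into the predicted frame. The argument
# names and shapes below are illustrative assumptions.
def _example_masked_compositing(transformed_images, mask_logits):
    # transformed_images: list of [batch, height, width, channels] candidate images
    # mask_logits: [batch, height, width, num_candidates] raw mask activations
    masks = tf.nn.softmax(mask_logits)  # normalize across candidates at every pixel
    masks = tf.split(masks, len(transformed_images), axis=-1)
    # weighted sum of candidates; the masks sum to one per pixel
    return tf.add_n([image * mask for image, mask in zip(transformed_images, masks)])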
def create_generator(generator_inputs, output_nc=3, ngf=64, norm_layer='instance',
                     downsample_layer='conv_pool2d', upsample_layer='upsample_conv2d'):
    norm_layer = ops.get_norm_layer(norm_layer)
    downsample_layer = ops.get_downsample_layer(downsample_layer)
    upsample_layer = ops.get_upsample_layer(upsample_layer)

    layers = []
    inputs = generator_inputs

    scale_size = min(*inputs.shape.as_list()[1:3])
    if scale_size == 256:
        layer_specs = [
            (ngf, 2),      # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
            (ngf * 2, 2),  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
            (ngf * 4, 2),  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
            (ngf * 8, 2),  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
            (ngf * 8, 2),  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
            (ngf * 8, 2),  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
            (ngf * 8, 2),  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
            (ngf * 8, 2),  # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
        ]
    elif scale_size == 128:
        layer_specs = [
            (ngf, 2),
            (ngf * 2, 2),
            (ngf * 4, 2),
            (ngf * 8, 2),
            (ngf * 8, 2),
            (ngf * 8, 2),
            (ngf * 8, 2),
        ]
    elif scale_size == 64:
        layer_specs = [
            (ngf, 2),
            (ngf * 2, 2),
            (ngf * 4, 2),
            (ngf * 8, 2),
            (ngf * 8, 2),
            (ngf * 8, 2),
        ]
    else:
        raise NotImplementedError

    with tf.variable_scope("encoder_1"):
        out_channels, strides = layer_specs[0]
        if strides == 1:
            output = conv2d(inputs, out_channels, kernel_size=4)
        else:
            output = downsample_layer(inputs, out_channels, kernel_size=4, strides=strides)
        layers.append(output)

    for out_channels, strides in layer_specs[1:]:
        with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
            rectified = lrelu(layers[-1], 0.2)
            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
            if strides == 1:
                convolved = conv2d(rectified, out_channels, kernel_size=4)
            else:
                convolved = downsample_layer(rectified, out_channels, kernel_size=4, strides=strides)
            output = norm_layer(convolved)
            layers.append(output)

    if scale_size == 256:
        layer_specs = [
            (ngf * 8, 2, 0.5),    # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
            (ngf * 8, 2, 0.5),    # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
            (ngf * 8, 2, 0.5),    # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
            (ngf * 8, 2, 0.0),    # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
            (ngf * 4, 2, 0.0),    # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
            (ngf * 2, 2, 0.0),    # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
            (ngf, 2, 0.0),        # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
            (output_nc, 2, 0.0),  # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
        ]
    elif scale_size == 128:
        layer_specs = [
            (ngf * 8, 2, 0.5),
            (ngf * 8, 2, 0.5),
            (ngf * 8, 2, 0.5),
            (ngf * 4, 2, 0.0),
            (ngf * 2, 2, 0.0),
            (ngf, 2, 0.0),
            (output_nc, 2, 0.0),
        ]
    elif scale_size == 64:
        layer_specs = [
            (ngf * 8, 2, 0.5),
            (ngf * 8, 2, 0.5),
            (ngf * 4, 2, 0.0),
            (ngf * 2, 2, 0.0),
            (ngf, 2, 0.0),
            (output_nc, 2, 0.0),
        ]
    else:
        raise NotImplementedError

    num_encoder_layers = len(layers)
    for decoder_layer, (out_channels, strides, dropout) in enumerate(layer_specs[:-1]):
        skip_layer = num_encoder_layers - decoder_layer - 1
        with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
            if decoder_layer == 0:
                # first decoder layer doesn't have skip connections
                # since it is directly connected to the skip_layer
                input = layers[-1]
            else:
                input = tf.concat([layers[-1], layers[skip_layer]], axis=3)
            rectified = tf.nn.relu(input)
            # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
            if strides == 1:
                output = conv2d(rectified, out_channels, kernel_size=4)
            else:
                output = upsample_layer(rectified, out_channels, kernel_size=4, strides=strides)
            output = norm_layer(output)
            if dropout > 0.0:
                output = tf.nn.dropout(output, keep_prob=1 - dropout)
            layers.append(output)

    with tf.variable_scope("decoder_1"):
        out_channels, strides, dropout = layer_specs[-1]
        assert dropout == 0.0  # no dropout at the last layer
        input = tf.concat([layers[-1], layers[0]], axis=3)
        rectified = tf.nn.relu(input)
        if strides == 1:
            output = conv2d(rectified, out_channels, kernel_size=4)
        else:
            output = upsample_layer(rectified, out_channels, kernel_size=4, strides=strides)
        output = tf.tanh(output)
        output = (output + 1) / 2
        layers.append(output)

    return layers[-1]
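
# Hedged usage sketch (not part of the original module): builds the U-Net generator
# on a 256x256 input. Assumes TensorFlow 1.x; the placeholder name and 3-channel
# shape are illustrative assumptions only.
def _example_generator_graph():
    generator_inputs = tf.placeholder(tf.float32, [None, 256, 256, 3], name='generator_inputs')
    with tf.variable_scope('generator'):
        gen_images = create_generator(generator_inputs, output_nc=3, ngf=64)
    # tanh output is rescaled to [0, 1] inside create_generator
    return gen_images  # [batch, 256, 256, output_nc]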