# Imports assumed by the generator functions below (each generator_fn is an
# alternative model definition). The remaining helpers (tf_utils, flatten,
# tile_concat, dense, create_encoder, create_generator, create_decoder,
# create_pspnet50_encoder, bilinear_interp, DNACell, DynamicsCell,
# Pix2PixCell) come from the surrounding codebase.
from collections import OrderedDict

import tensorflow as tf
import tensorflow.contrib.slim as slim


def generator_fn(inputs, hparams=None):
    batch_size = inputs['images'].shape[1].value
    inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
              for name, input in inputs.items()}
    with tf.variable_scope('gru'):
        gru_cell = tf.nn.rnn_cell.GRUCell(hparams.dim_z_motion)
    if hparams.context_frames:
        with tf.variable_scope('content_encoder'):
            z_c = create_encoder(
                inputs['images'][0],  # first context image for content encoder
                nef=hparams.nef,
                norm_layer=hparams.norm_layer,
                dim_z=hparams.dim_z_content)
        with tf.variable_scope('initial_motion_encoder'):
            h_0 = create_encoder(
                inputs['images'][hparams.context_frames - 1],  # last context image for motion encoder
                nef=hparams.nef,
                norm_layer=hparams.norm_layer,
                dim_z=hparams.dim_z_motion)
    else:  # unconditional case
        z_c = tf.random_normal([batch_size, hparams.dim_z_content])
        h_0 = gru_cell.zero_state(batch_size, tf.float32)
    h_t = [h_0]
    for t in range(hparams.context_frames - 1, hparams.sequence_length - 1):
        with tf.variable_scope('gru', reuse=t > hparams.context_frames - 1):
            e_t = tf.random_normal([batch_size, hparams.dim_z_motion])
            if 'actions' in inputs:
                # condition the motion noise on the action at this time step
                e_t = tf.concat([e_t, inputs['actions'][t]], axis=-1)
            h_t.append(gru_cell(e_t, h_t[-1])[1])  # the output and state are the same in GRUs
    z_m = tf.stack(h_t[1:], axis=0)
    z = tf.concat([
        tf.tile(z_c[None, :, :], [hparams.sequence_length - hparams.context_frames, 1, 1]),
        z_m,
    ], axis=-1)
    z_flatten = flatten(z[:, :, None, None, :], 0, 1)
    gen_images_flatten = create_generator(z_flatten,
                                          ngf=hparams.ngf,
                                          norm_layer=hparams.norm_layer)
    gen_images = tf.reshape(
        gen_images_flatten, [-1, batch_size] + gen_images_flatten.shape.as_list()[1:])
    outputs = {'gen_images': gen_images}
    return gen_images, outputs
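# Illustrative sketch (not part of the model code) of the content/motion split
# used above: one content code is tiled across time and concatenated with a
# per-step motion code, so all frames share appearance while motion varies.
# The helper name and shapes here are our own assumptions.
def assemble_content_motion_latents(z_c, z_m):
    # z_c: [batch_size, dim_z_content]; z_m: [time, batch_size, dim_z_motion]
    time = tf.shape(z_m)[0]
    z_c_tiled = tf.tile(z_c[None, :, :], [time, 1, 1])
    return tf.concat([z_c_tiled, z_m], axis=-1)  # [time, batch, content + motion dims]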
def generator_fn(inputs, outputs_enc=None, hparams=None):
    batch_size = inputs['images'].shape[1].value
    inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
              for name, input in inputs.items()}
    if hparams.nz:
        def sample_zs():
            if outputs_enc is None:
                # prior: sample the latents from a standard normal
                zs = tf.random_normal([hparams.sequence_length - 1, batch_size, hparams.nz], 0, 1)
            else:
                # posterior: reparameterized sample from the encoder's distribution
                enc_zs_mu = outputs_enc['enc_zs_mu']
                enc_zs_log_sigma_sq = outputs_enc['enc_zs_log_sigma_sq']
                eps = tf.random_normal([hparams.sequence_length - 1, batch_size, hparams.nz], 0, 1)
                zs = enc_zs_mu + tf.sqrt(tf.exp(enc_zs_log_sigma_sq)) * eps
            return zs
        inputs['zs'] = sample_zs()
    else:
        if outputs_enc is not None:
            raise ValueError('outputs_enc has to be None when nz is 0.')
    cell = DNACell(inputs, hparams)
    outputs, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32,
                                   swap_memory=False, time_major=True)
    if hparams.nz:
        # replicate the inputs num_samples times along the batch dimension and
        # pair each replica with an independently sampled latent sequence
        inputs_samples = {
            name: flatten(tf.tile(input[:, None], [1, hparams.num_samples] + [1] * (input.shape.ndims - 1)), 1, 2)
            for name, input in inputs.items() if name != 'zs'}
        inputs_samples['zs'] = tf.concat([sample_zs() for _ in range(hparams.num_samples)], axis=1)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            cell_samples = DNACell(inputs_samples, hparams)
            outputs_samples, _ = tf.nn.dynamic_rnn(cell_samples, inputs_samples, dtype=tf.float32,
                                                   swap_memory=False, time_major=True)
        gen_images_samples = outputs_samples['gen_images']
        gen_images_samples = tf.stack(tf.split(gen_images_samples, hparams.num_samples, axis=1), axis=-1)
        gen_images_samples_avg = tf.reduce_mean(gen_images_samples, axis=-1)
        outputs['gen_images_samples'] = gen_images_samples
        outputs['gen_images_samples_avg'] = gen_images_samples_avg
    # the RNN outputs generated images from time step 1 to sequence_length,
    # but generator_fn should only return images past context_frames
    outputs = {name: output[hparams.context_frames - 1:] for name, output in outputs.items()}
    gen_images = outputs['gen_images']
    outputs['ground_truth_sampling_mean'] = tf.reduce_mean(
        tf.to_float(cell.ground_truth[hparams.context_frames:]))
    return gen_images, outputs
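# Illustrative sketch of the reparameterization used in sample_zs above: a
# posterior latent is drawn as mu + sigma * eps with eps ~ N(0, I), which keeps
# sampling differentiable w.r.t. the encoder outputs. The helper name is ours.
def reparameterize(mu, log_sigma_sq):
    eps = tf.random_normal(tf.shape(mu), 0, 1)
    return mu + tf.sqrt(tf.exp(log_sigma_sq)) * eps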
def generator_fn(inputs, hparams=None):
    images = inputs['images']
    with tf.variable_scope('encoder'):
        features = tf.map_fn(create_pspnet50_encoder, images)
        features = tf.stop_gradient(features)
    inputs = dict(inputs)
    inputs['features'] = features
    inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
              for name, input in inputs.items()}
    cell = DynamicsCell(inputs, hparams)
    outputs, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32,
                                   swap_memory=False, time_major=True)
    # the RNN outputs generated images from time step 1 to sequence_length,
    # but generator_fn should only return images past context_frames
    outputs = {name: output[hparams.context_frames - 1:] for name, output in outputs.items()}
    outputs['ground_truth_sampling_mean'] = tf.reduce_mean(
        tf.to_float(cell.ground_truth[hparams.context_frames:]))
    gen_features = outputs['gen_features']
    with tf.variable_scope('decoder') as decoder_scope:
        gen_images = tf.map_fn(create_decoder,
                               tf.stop_gradient(gen_features))  # TODO: stop gradient for decoder?
    with tf.variable_scope(decoder_scope, reuse=True):
        gen_images_dec = tf.map_fn(create_decoder, features)
    outputs['gen_images'] = gen_images
    outputs['gen_images_dec'] = gen_images_dec
    outputs['features'] = features
    return gen_images, outputs
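# Illustrative sketch of the scope-reuse pattern above: the same decoder
# weights are applied to the predicted features and to the encoder features by
# re-entering the variable scope with reuse=True. decode_fn and the scope name
# stand in for create_decoder and 'decoder' here.
def decode_with_shared_weights(decode_fn, features_a, features_b):
    with tf.variable_scope('shared_decoder') as scope:
        out_a = decode_fn(features_a)
    with tf.variable_scope(scope, reuse=True):
        out_b = decode_fn(features_b)  # same variables as out_a
    return out_a, out_b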
def generator_fn(inputs, hparams=None):
    batch_size = inputs['images'].shape[1].value
    cell = Pix2PixCell(inputs, hparams)
    inputs = OrderedDict([(name, tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1))
                          for name, input in inputs.items() if name in ('images', 'actions')])
    (gen_images, gen_inputs), _ = \
        tf.nn.dynamic_rnn(cell, inputs,
                          sequence_length=[hparams.sequence_length - 1] * batch_size,
                          dtype=tf.float32, swap_memory=False, time_major=True)
    # the RNN outputs generated images from time step 1 to sequence_length,
    # but generator_fn should only return images past context_frames
    outputs = {
        'gen_images': gen_images[hparams.context_frames - 1:],
        'gen_inputs': gen_inputs[hparams.context_frames - 1:],
        'ground_truth_sampling_mean': tf.reduce_mean(
            tf.to_float(cell.ground_truth[hparams.context_frames:])),
    }
    gen_images = outputs['gen_images']
    return gen_images, outputs
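# Shape bookkeeping for the slicing above (hypothetical values): the RNN emits
# predictions for frames 1..sequence_length-1, so dropping the first
# context_frames-1 outputs leaves exactly the predicted frames past the context.
sequence_length, context_frames = 12, 2
num_rnn_outputs = sequence_length - 1                   # predictions for frames 1..11
num_returned = num_rnn_outputs - (context_frames - 1)   # predictions for frames 2..11
assert num_returned == sequence_length - context_frames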
def generator_fn(inputs, mode, hparams=None):
    inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
              for name, input in inputs.items()}
    images = inputs['images']
    input_images = tf.concat(tf.unstack(images[:hparams.context_frames], axis=0), axis=-1)

    gen_states = []
    for t in range(hparams.context_frames, hparams.sequence_length):
        state_action = []
        if 'actions' in inputs:
            state_action.append(inputs['actions'][t - 1])
        if 'states' in inputs:
            state_action.append(gen_states[-1] if gen_states else inputs['states'][t - 1])
        state_action = tf.concat(state_action, axis=-1)
        with tf.name_scope('gen_states'):
            with tf.variable_scope('state_pred%d' % t):
                gen_state = dense(state_action, inputs['states'].shape[-1])
        gen_states.append(gen_state)

    states_actions = []
    if 'actions' in inputs:
        states_actions += tf.unstack(inputs['actions'][:hparams.sequence_length - 1], axis=0)
    if 'states' in inputs:
        states_actions += tf.unstack(inputs['states'][:hparams.context_frames], axis=0)
        states_actions += gen_states
    if states_actions:
        states_actions = tf.concat(states_actions, axis=-1)
        # don't backpropagate the convnet through the state dynamics
        states_actions = tf.stop_gradient(states_actions)
    else:
        states_actions = tf.zeros([images.shape[1], 0])

    with slim.arg_scope([slim.conv2d],
                        activation_fn=tf.nn.relu,
                        weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                        weights_regularizer=slim.l2_regularizer(0.0001)):
        batch_norm_params = {
            'decay': 0.9997,
            'epsilon': 0.001,
            'is_training': mode == 'train',
        }
        with slim.arg_scope([slim.batch_norm], is_training=mode == 'train',
                            updates_collections=None):
            with slim.arg_scope([slim.conv2d],
                                normalizer_fn=slim.batch_norm,
                                normalizer_params=batch_norm_params):
                h0 = slim.conv2d(input_images, 64, [5, 5], stride=1, scope='conv1')
                size0 = tf.shape(input_images)[-3:-1]
                h1 = slim.max_pool2d(h0, [2, 2], scope='pool1')
                h1 = slim.conv2d(h1, 128, [5, 5], stride=1, scope='conv2')
                size1 = tf.shape(h1)[-3:-1]
                h2 = slim.max_pool2d(h1, [2, 2], scope='pool2')
                h2 = slim.conv2d(h2, 256, [3, 3], stride=1, scope='conv3')
                size2 = tf.shape(h2)[-3:-1]
                h3 = slim.max_pool2d(h2, [2, 2], scope='pool3')
                h3 = tile_concat([h3, states_actions[:, None, None, :]], axis=-1)
                h3 = slim.conv2d(h3, 256, [3, 3], stride=1, scope='conv4')
                h4 = tf.image.resize_bilinear(h3, size2)
                h4 = tf.concat([h4, h2], axis=-1)
                h4 = slim.conv2d(h4, 256, [3, 3], stride=1, scope='conv5')
                h5 = tf.image.resize_bilinear(h4, size1)
                h5 = tf.concat([h5, h1], axis=-1)
                h5 = slim.conv2d(h5, 128, [5, 5], stride=1, scope='conv6')
                h6 = tf.image.resize_bilinear(h5, size0)
                h6 = tf.concat([h6, h0], axis=-1)
                h6 = slim.conv2d(h6, 64, [5, 5], stride=1, scope='conv7')
                extrap_length = hparams.sequence_length - hparams.context_frames
                flows_masks = slim.conv2d(h6, 5 * extrap_length, [5, 5], stride=1,
                                          activation_fn=tf.tanh, normalizer_fn=None,
                                          scope='conv8')
    flows_masks = tf.split(flows_masks, extrap_length, axis=-1)

    gen_images = []
    gen_flows_1 = []
    gen_flows_2 = []
    masks = []
    for flows_mask in flows_masks:
        flow_1, flow_2, mask = tf.split(flows_mask, [2, 2, 1], axis=-1)
        gen_flows_1.append(flow_1)
        gen_flows_2.append(flow_2)
        mask = 0.5 * (1.0 + mask)  # rescale the tanh output from [-1, 1] to [0, 1]
        masks.append(mask)
        linspace_x = tf.linspace(-1.0, 1.0, size0[1])
        linspace_x.set_shape(input_images.shape[-2])
        linspace_y = tf.linspace(-1.0, 1.0, size0[0])
        linspace_y.set_shape(input_images.shape[-3])
        grid_x, grid_y = tf.meshgrid(linspace_x, linspace_y)
        coor_x_1 = grid_x[None, :, :] + flow_1[:, :, :, 0]
        coor_y_1 = grid_y[None, :, :] + flow_1[:, :, :, 1]
        coor_x_2 = grid_x[None, :, :] + flow_2[:, :, :, 0]
        coor_y_2 = grid_y[None, :, :] + flow_2[:, :, :, 1]
        output_1 = bilinear_interp(images[0], coor_x_1, coor_y_1, 'interpolate')
        output_2 = bilinear_interp(images[1], coor_x_2, coor_y_2, 'interpolate')
        gen_image = mask * output_1 + (1.0 - mask) * output_2
        gen_images.append(gen_image)

    gen_images = tf.stack(gen_images, axis=0)
    gen_flows_1 = tf.stack(gen_flows_1, axis=0)
    gen_flows_2 = tf.stack(gen_flows_2, axis=0)
    masks = tf.stack(masks, axis=0)
    outputs = {
        'gen_images': gen_images,
        'gen_flows_1': gen_flows_1,
        'gen_flows_2': gen_flows_2,
        'masks': masks,
    }
    if 'states' in inputs:
        gen_states = tf.stack(gen_states, axis=0)
        outputs['gen_states'] = gen_states
    return gen_images, outputs
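# Illustrative sketch of the flow-and-mask compositing above: each predicted
# frame blends two backward-warped context frames with a per-pixel mask in
# [0, 1] (rescaled from the tanh-activated conv8 output). The helper name is
# ours.
def blend_warped_frames(warped_1, warped_2, raw_mask):
    mask = 0.5 * (1.0 + raw_mask)  # [-1, 1] -> [0, 1]
    return mask * warped_1 + (1.0 - mask) * warped_2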