Code example #1
def generator_fn(inputs, hparams=None):
    batch_size = inputs['images'].shape[1].value
    inputs = {
        name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
        for name, input in inputs.items()
    }

    with tf.variable_scope('gru'):
        gru_cell = tf.nn.rnn_cell.GRUCell(hparams.dim_z_motion)

    if hparams.context_frames:
        with tf.variable_scope('content_encoder'):
            z_c = create_encoder(
                inputs['images'][0],  # first context image for content encoder
                nef=hparams.nef,
                norm_layer=hparams.norm_layer,
                dim_z=hparams.dim_z_content)
        with tf.variable_scope('initial_motion_encoder'):
            h_0 = create_encoder(
                inputs['images'][hparams.context_frames -
                                 1],  # last context image for motion encoder
                nef=hparams.nef,
                norm_layer=hparams.norm_layer,
                dim_z=hparams.dim_z_motion)
    else:  # unconditional case
        z_c = tf.random_normal([batch_size, hparams.dim_z_content])
        h_0 = gru_cell.zero_state(batch_size, tf.float32)

    h_t = [h_0]
    for t in range(hparams.context_frames - 1, hparams.sequence_length - 1):
        with tf.variable_scope('gru', reuse=t > hparams.context_frames - 1):
            e_t = tf.random_normal([batch_size, hparams.dim_z_motion])
            if 'actions' in inputs:
                e_t = tf.concat([e_t, inputs['actions'][t]], axis=-1)
            h_t.append(gru_cell(
                e_t, h_t[-1])[1])  # the output and state are the same in GRUs
    z_m = tf.stack(h_t[1:], axis=0)
    z = tf.concat([
        tf.tile(z_c[None, :, :],
                [hparams.sequence_length - hparams.context_frames, 1, 1]), z_m
    ],
                  axis=-1)

    z_flatten = flatten(z[:, :, None, None, :], 0, 1)
    gen_images_flatten = create_generator(z_flatten,
                                          ngf=hparams.ngf,
                                          norm_layer=hparams.norm_layer)
    gen_images = tf.reshape(gen_images_flatten, [-1, batch_size] +
                            gen_images_flatten.shape.as_list()[1:])

    outputs = {'gen_images': gen_images}
    return gen_images, outputs
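
The key shape manipulation above is tiling the single content code z_c across the predicted time steps before concatenating it with the per-step motion codes z_m. Below is a minimal standalone sketch of just that step, with hypothetical sizes and plain NumPy so it runs without the encoder/generator helpers:

import numpy as np

T, B = 10, 4                               # hypothetical: predicted steps, batch size
dim_z_content, dim_z_motion = 50, 10       # hypothetical latent sizes
z_c = np.random.randn(B, dim_z_content)    # one content code per sequence
z_m = np.random.randn(T, B, dim_z_motion)  # one motion code per predicted frame
# tile the content code over time, then concatenate with the motion codes
z = np.concatenate([np.tile(z_c[None, :, :], (T, 1, 1)), z_m], axis=-1)
print(z.shape)  # (10, 4, 60)
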
Code example #2
def generator_fn(inputs, outputs_enc=None, hparams=None):
    batch_size = inputs['images'].shape[1].value
    inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
              for name, input in inputs.items()}
    if hparams.nz:
        def sample_zs():
            if outputs_enc is None:
                zs = tf.random_normal([hparams.sequence_length - 1, batch_size, hparams.nz], 0, 1)
            else:
                enc_zs_mu = outputs_enc['enc_zs_mu']
                enc_zs_log_sigma_sq = outputs_enc['enc_zs_log_sigma_sq']
                eps = tf.random_normal([hparams.sequence_length - 1, batch_size, hparams.nz], 0, 1)
                zs = enc_zs_mu + tf.sqrt(tf.exp(enc_zs_log_sigma_sq)) * eps
            return zs
        inputs['zs'] = sample_zs()
    else:
        if outputs_enc is not None:
            raise ValueError('outputs_enc has to be None when nz is 0.')
    cell = DNACell(inputs, hparams)
    outputs, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32,
                                   swap_memory=False, time_major=True)
    if hparams.nz:
        inputs_samples = {name: flatten(tf.tile(input[:, None], [1, hparams.num_samples] + [1] * (input.shape.ndims - 1)), 1, 2)
                          for name, input in inputs.items() if name != 'zs'}
        inputs_samples['zs'] = tf.concat([sample_zs() for _ in range(hparams.num_samples)], axis=1)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            cell_samples = DNACell(inputs_samples, hparams)
            outputs_samples, _ = tf.nn.dynamic_rnn(cell_samples, inputs_samples, dtype=tf.float32,
                                                   swap_memory=False, time_major=True)
        gen_images_samples = outputs_samples['gen_images']
        gen_images_samples = tf.stack(tf.split(gen_images_samples, hparams.num_samples, axis=1), axis=-1)
        gen_images_samples_avg = tf.reduce_mean(gen_images_samples, axis=-1)
        outputs['gen_images_samples'] = gen_images_samples
        outputs['gen_images_samples_avg'] = gen_images_samples_avg
    # the RNN outputs generated images from time step 1 to sequence_length,
    # but generator_fn should only return images past context_frames
    outputs = {name: output[hparams.context_frames - 1:] for name, output in outputs.items()}
    gen_images = outputs['gen_images']
    outputs['ground_truth_sampling_mean'] = tf.reduce_mean(tf.to_float(cell.ground_truth[hparams.context_frames:]))
    return gen_images, outputs
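
The sample_zs closure above draws latents either from the prior or, when encoder outputs are supplied, via the usual reparameterization of a diagonal Gaussian. A standalone NumPy sketch of that reparameterization, with hypothetical sizes and placeholder encoder statistics:

import numpy as np

T, B, nz = 11, 4, 8                          # hypothetical sizes
enc_zs_mu = np.zeros((T, B, nz))             # placeholder posterior mean
enc_zs_log_sigma_sq = np.zeros((T, B, nz))   # log(sigma^2) = 0, i.e. sigma = 1
eps = np.random.randn(T, B, nz)
# sample = mu + sigma * eps, with sigma recovered from the predicted log variance
zs = enc_zs_mu + np.sqrt(np.exp(enc_zs_log_sigma_sq)) * eps
print(zs.shape)  # (11, 4, 8)
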
Code example #3
def generator_fn(inputs, hparams=None):
    images = inputs['images']
    with tf.variable_scope('encoder'):
        features = tf.map_fn(create_pspnet50_encoder, images)
    features = tf.stop_gradient(features)

    inputs = dict(inputs)
    inputs['features'] = features
    inputs = {
        name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
        for name, input in inputs.items()
    }

    cell = DynamicsCell(inputs, hparams)
    outputs, _ = tf.nn.dynamic_rnn(cell,
                                   inputs,
                                   dtype=tf.float32,
                                   swap_memory=False,
                                   time_major=True)
    # the RNN outputs generated images from time step 1 to sequence_length,
    # but generator_fn should only return images past context_frames
    outputs = {
        name: output[hparams.context_frames - 1:]
        for name, output in outputs.items()
    }
    outputs['ground_truth_sampling_mean'] = tf.reduce_mean(
        tf.to_float(cell.ground_truth[hparams.context_frames:]))

    gen_features = outputs['gen_features']
    with tf.variable_scope('decoder') as decoder_scope:
        gen_images = tf.map_fn(
            create_decoder,
            tf.stop_gradient(gen_features))  # TODO: stop gradient for decoder?
    with tf.variable_scope(decoder_scope, reuse=True):
        gen_images_dec = tf.map_fn(create_decoder, features)
    outputs['gen_images'] = gen_images
    outputs['gen_images_dec'] = gen_images_dec
    outputs['features'] = features
    return gen_images, outputs
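
The per-frame encoding and decoding above rely on tf.map_fn mapping a single-frame function over the leading (time) axis of a time-major video tensor. A minimal TF1-style sketch of that pattern with a stateless stand-in for the encoder, so it builds without the pspnet50/decoder helpers:

import tensorflow as tf

def per_frame_features(frames):                 # frames: [batch, height, width, channels]
    # stand-in for create_pspnet50_encoder: global average pooling per frame
    return tf.reduce_mean(frames, axis=[1, 2])  # -> [batch, channels]

videos = tf.placeholder(tf.float32, [10, 4, 32, 32, 3])  # time-major: [T, B, H, W, C]
features = tf.map_fn(per_frame_features, videos)         # [T, B, 3]
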
Code example #4
def generator_fn(inputs, hparams=None):
    batch_size = inputs['images'].shape[1].value
    cell = Pix2PixCell(inputs, hparams)
    inputs = OrderedDict([
        (name, tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1))
        for name, input in inputs.items() if name in ('images', 'actions')
    ])
    (gen_images, gen_inputs), _ = \
        tf.nn.dynamic_rnn(cell, inputs, sequence_length=[hparams.sequence_length - 1] * batch_size,
                          dtype=tf.float32, swap_memory=False, time_major=True)
    # the RNN outputs generated images from time step 1 to sequence_length,
    # but generator_fn should only return images past context_frames
    outputs = {
        'gen_images':
        gen_images[hparams.context_frames - 1:],
        'gen_inputs':
        gen_inputs[hparams.context_frames - 1:],
        'ground_truth_sampling_mean':
        tf.reduce_mean(tf.to_float(cell.ground_truth[hparams.context_frames:]))
    }
    gen_images = outputs['gen_images']
    return gen_images, outputs
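
As the comment notes, the recurrent generators emit one prediction per input step (sequence_length - 1 of them), and only the frames past the conditioning context are returned. A tiny NumPy sketch of that slicing, with hypothetical lengths:

import numpy as np

sequence_length, context_frames, B = 12, 2, 4     # hypothetical lengths
rnn_outputs = np.zeros((sequence_length - 1, B))  # one output per RNN step
gen_images = rnn_outputs[context_frames - 1:]     # keep only frames past the context
print(gen_images.shape[0])  # sequence_length - context_frames = 10
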
Code example #5
def generator_fn(inputs, mode, hparams=None):
    inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
              for name, input in inputs.items()}
    images = inputs['images']
    input_images = tf.concat(tf.unstack(images[:hparams.context_frames], axis=0), axis=-1)

    gen_states = []
    for t in range(hparams.context_frames, hparams.sequence_length):
        state_action = []
        if 'actions' in inputs:
            state_action.append(inputs['actions'][t - 1])
        if 'states' in inputs:
            state_action.append(gen_states[-1] if gen_states else inputs['states'][t - 1])
            state_action = tf.concat(state_action, axis=-1)
            with tf.name_scope('gen_states'):
                with tf.variable_scope('state_pred%d' % t):
                    gen_state = dense(state_action, inputs['states'].shape[-1])
            gen_states.append(gen_state)

    states_actions = []
    if 'actions' in inputs:
        states_actions += tf.unstack(inputs['actions'][:hparams.sequence_length - 1], axis=0)
    if 'states' in inputs:
        states_actions += tf.unstack(inputs['states'][:hparams.context_frames], axis=0)
        states_actions += gen_states
    if states_actions:
        states_actions = tf.concat(states_actions, axis=-1)
        # don't backpropagate the convnet through the state dynamics
        states_actions = tf.stop_gradient(states_actions)
    else:
        states_actions = tf.zeros([images.shape[1], 0])

    with slim.arg_scope([slim.conv2d],
                        activation_fn=tf.nn.relu,
                        weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                        weights_regularizer=slim.l2_regularizer(0.0001)):
        batch_norm_params = {
            'decay': 0.9997,
            'epsilon': 0.001,
            'is_training': mode == 'train',
        }
        with slim.arg_scope([slim.batch_norm], is_training=mode == 'train', updates_collections=None):
            with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm,
                                normalizer_params=batch_norm_params):
                h0 = slim.conv2d(input_images, 64, [5, 5], stride=1, scope='conv1')
                size0 = tf.shape(input_images)[-3:-1]

                h1 = slim.max_pool2d(h0, [2, 2], scope='pool1')
                h1 = slim.conv2d(h1, 128, [5, 5], stride=1, scope='conv2')
                size1 = tf.shape(h1)[-3:-1]

                h2 = slim.max_pool2d(h1, [2, 2], scope='pool2')
                h2 = slim.conv2d(h2, 256, [3, 3], stride=1, scope='conv3')
                size2 = tf.shape(h2)[-3:-1]

                h3 = slim.max_pool2d(h2, [2, 2], scope='pool3')
                h3 = tile_concat([h3, states_actions[:, None, None, :]], axis=-1)
                h3 = slim.conv2d(h3, 256, [3, 3], stride=1, scope='conv4')

                h4 = tf.image.resize_bilinear(h3, size2)
                h4 = tf.concat([h4, h2], axis=-1)
                h4 = slim.conv2d(h4, 256, [3, 3], stride=1, scope='conv5')

                h5 = tf.image.resize_bilinear(h4, size1)
                h5 = tf.concat([h5, h1], axis=-1)
                h5 = slim.conv2d(h5, 128, [5, 5], stride=1, scope='conv6')

                h6 = tf.image.resize_bilinear(h5, size0)
                h6 = tf.concat([h6, h0], axis=-1)
                h6 = slim.conv2d(h6, 64, [5, 5], stride=1, scope='conv7')

        extrap_length = hparams.sequence_length - hparams.context_frames
        flows_masks = slim.conv2d(h6, 5 * extrap_length, [5, 5], stride=1, activation_fn=tf.tanh,
                                  normalizer_fn=None, scope='conv8')

    flows_masks = tf.split(flows_masks, extrap_length, axis=-1)

    gen_images = []
    gen_flows_1 = []
    gen_flows_2 = []
    masks = []
    for flows_mask in flows_masks:
        flow_1, flow_2, mask = tf.split(flows_mask, [2, 2, 1], axis=-1)
        gen_flows_1.append(flow_1)
        gen_flows_2.append(flow_2)
        mask = 0.5 * (1.0 + mask)
        masks.append(mask)

        linspace_x = tf.linspace(-1.0, 1.0, size0[1])
        linspace_x.set_shape(input_images.shape[-2])
        linspace_y = tf.linspace(-1.0, 1.0, size0[0])
        linspace_y.set_shape(input_images.shape[-3])
        grid_x, grid_y = tf.meshgrid(linspace_x, linspace_y)

        coor_x_1 = grid_x[None, :, :] + flow_1[:, :, :, 0]
        coor_y_1 = grid_y[None, :, :] + flow_1[:, :, :, 1]

        coor_x_2 = grid_x[None, :, :] + flow_2[:, :, :, 0]
        coor_y_2 = grid_y[None, :, :] + flow_2[:, :, :, 1]

        output_1 = bilinear_interp(images[0], coor_x_1, coor_y_1, 'interpolate')
        output_2 = bilinear_interp(images[1], coor_x_2, coor_y_2, 'interpolate')

        gen_image = mask * output_1 + (1.0 - mask) * output_2
        gen_images.append(gen_image)
    gen_images = tf.stack(gen_images, axis=0)
    gen_flows_1 = tf.stack(gen_flows_1, axis=0)
    gen_flows_2 = tf.stack(gen_flows_2, axis=0)
    masks = tf.stack(masks, axis=0)

    outputs = {
        'gen_images': gen_images,
        'gen_flows_1': gen_flows_1,
        'gen_flows_2': gen_flows_2,
        'masks': masks,
    }
    if 'states' in inputs:
        gen_states = tf.stack(gen_states, axis=0)
        outputs['gen_states'] = gen_states
    return gen_images, outputs
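
The warping step above builds a normalized [-1, 1] sampling grid and offsets it by the predicted flows before calling bilinear_interp. A standalone NumPy sketch of the coordinate construction, with a hypothetical resolution and zero flow:

import numpy as np

H, W = 4, 6                                    # hypothetical resolution
grid_x, grid_y = np.meshgrid(np.linspace(-1.0, 1.0, W), np.linspace(-1.0, 1.0, H))
flow = np.zeros((1, H, W, 2))                  # zero flow samples the identity grid
coor_x = grid_x[None, :, :] + flow[:, :, :, 0] # per-pixel sampling coordinates, [batch, H, W]
coor_y = grid_y[None, :, :] + flow[:, :, :, 1]
print(coor_x.shape, coor_y.shape)  # (1, 4, 6) (1, 4, 6)
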