def discriminator_fn(targets, inputs=None, hparams=None):
    batch_size = targets.shape[1].value
    # a sort of hack to ensure that the same t_sample is used for all the
    # discriminators that are given the same inputs
    if 't_sample' in inputs:
        t_sample = inputs['t_sample']
    else:
        t_sample = tf.random_uniform([batch_size],
                                     minval=0,
                                     maxval=targets.shape[0].value,
                                     dtype=tf.int32)
        inputs['t_sample'] = t_sample
    # one randomly sampled frame per batch element, used by the image discriminator
    image_sample = tf.gather_nd(
        targets, tf.stack([t_sample, tf.range(batch_size)], axis=1))

    if 't_start' in inputs:
        t_start = inputs['t_start']
    else:
        t_start = tf.random_uniform([batch_size],
                                    minval=0,
                                    maxval=targets.shape[0].value -
                                    hparams.clip_length + 1,
                                    dtype=tf.int32)
        inputs['t_start'] = t_start
    t_start_indices = tf.stack([t_start, tf.range(batch_size)], axis=1)
    t_offset_indices = tf.stack([
        tf.range(hparams.clip_length),
        tf.zeros(hparams.clip_length, dtype=tf.int32)
    ],
                                axis=1)
    indices = tf.expand_dims(t_start_indices, axis=0) + tf.expand_dims(
        t_offset_indices, axis=1)
    # a clip of clip_length consecutive frames per batch element, used by the video discriminator
    clip_sample = tf.reshape(tf.gather_nd(targets, flatten(
        indices, 0, 1)), [hparams.clip_length] + targets.shape.as_list()[1:])

    outputs = {}
    if hparams.image_sn_gan_weight or hparams.image_sn_vae_gan_weight:
        image_features = create_image_sn_discriminator(image_sample,
                                                       ndf=hparams.ndf)
        image_features, image_logits = image_features[:-1], image_features[-1]
        outputs['discrim_image_sn_logits'] = tf.expand_dims(
            image_logits, axis=0)  # expand dims for the time dimension
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            images_features = create_image_sn_discriminator(flatten(
                targets, 0, 1),
                                                            ndf=hparams.ndf)
        images_features = images_features[:-1]
        for i, images_feature in enumerate(images_features):
            images_feature = tf.reshape(
                images_feature, targets.shape[:2].as_list() +
                images_feature.shape[1:].as_list())
            outputs['discrim_image_sn_feature%d' % i] = images_feature
    if hparams.video_sn_gan_weight or hparams.video_sn_vae_gan_weight:
        video_features = create_video_sn_discriminator(clip_sample,
                                                       ndf=hparams.ndf)
        video_features, video_logits = video_features[:-1], video_features[-1]
        outputs['discrim_video_sn_logits'] = video_logits
        for i, video_feature in enumerate(video_features):
            outputs['discrim_video_sn_feature%d' % i] = video_feature
    return None, outputs
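
# A minimal NumPy sketch (illustration only, values hypothetical) of the index
# construction used for clip_sample above: for each batch element b, gather_nd
# selects frames t_start[b] .. t_start[b] + clip_length - 1 of that element.
import numpy as np

clip_length, batch_size = 3, 2
t_start = np.array([1, 4])                                               # per-example start frame
t_start_indices = np.stack([t_start, np.arange(batch_size)], axis=1)     # [B, 2] pairs of (t, b)
t_offset_indices = np.stack([np.arange(clip_length),
                             np.zeros(clip_length, dtype=int)], axis=1)  # [clip_length, 2]
indices = t_start_indices[None] + t_offset_indices[:, None]              # [clip_length, B, 2]
print(indices[:, 0])  # [[1 0], [2 0], [3 0]]: frames 1 to 3 of batch element 0
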
def create_encoder(inputs,
                   e_net='legacy',
                   use_e_rnn=False,
                   rnn='lstm',
                   **kwargs):
    assert inputs.shape.ndims == 5
    batch_shape = inputs.shape[:-3].as_list()
    inputs = flatten(inputs, 0, len(batch_shape) - 1)
    unflatten = lambda x: tf.reshape(x, batch_shape + x.shape.as_list()[1:])

    if use_e_rnn:
        if e_net == 'legacy':
            kwargs.pop('n_layers', None)  # unused
            h = create_legacy_encoder(inputs, include_top=False, **kwargs)
            with tf.variable_scope('h4'):
                h = dense(h, kwargs['nef'] * 4)
        elif e_net == 'n_layer':
            h = create_n_layer_encoder(inputs, include_top=False, **kwargs)
            with tf.variable_scope('layer_%d' % (kwargs['n_layers'] + 1)):
                h = dense(h, kwargs['nef'] * 4)
        else:
            raise ValueError('Invalid encoder net %s' % e_net)

        if rnn == 'lstm':
            RNNCell = tf.contrib.rnn.BasicLSTMCell
        elif rnn == 'gru':
            RNNCell = tf.contrib.rnn.GRUCell
        else:
            raise NotImplementedError

        h = nest.map_structure(unflatten, h)
        for i in range(2):
            with tf.variable_scope('%s_h%d' % (rnn, i)):
                rnn_cell = RNNCell(kwargs['nef'] * 4)
                h, _ = tf.nn.dynamic_rnn(rnn_cell,
                                         h,
                                         dtype=tf.float32,
                                         time_major=True)
        h = flatten(h, 0, len(batch_shape) - 1)

        with tf.variable_scope('z_mu'):
            z_mu = dense(h, kwargs['nz'])
        with tf.variable_scope('z_log_sigma_sq'):
            z_log_sigma_sq = dense(h, kwargs['nz'])
            z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10)
        outputs = {'enc_zs_mu': z_mu, 'enc_zs_log_sigma_sq': z_log_sigma_sq}
    else:
        if e_net == 'legacy':
            kwargs.pop('n_layers', None)  # unused
            outputs = create_legacy_encoder(inputs, include_top=True, **kwargs)
        elif e_net == 'n_layer':
            outputs = create_n_layer_encoder(inputs,
                                             include_top=True,
                                             **kwargs)
        else:
            raise ValueError('Invalid encoder net %s' % e_net)

    outputs = nest.map_structure(unflatten, outputs)
    return outputs
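
# Sketch (illustration only) of the flatten/unflatten pattern used above and in
# several of the functions below, assuming flatten(x, 0, k) merges axes 0..k of x
# into a single axis: a [T, B, H, W, C] batch of frame sequences is flattened to
# [T*B, H, W, C] so it can go through a per-image network, then reshaped back.
import numpy as np

T, B, H, W, C = 5, 4, 8, 8, 3
x = np.random.randn(T, B, H, W, C)
batch_shape = list(x.shape[:-3])
x_flat = x.reshape([-1] + list(x.shape[len(batch_shape):]))  # [T*B, H, W, C]
x_back = x_flat.reshape(batch_shape + list(x_flat.shape[1:]))
assert x_back.shape == x.shape
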
def apply_dna_kernels(image, kernels, dilation_rate=(1, 1)):
    """
    Args:
        image: A 4-D tensor of shape
            `[batch, in_height, in_width, in_channels]`.
        kernels: A 6-D tensor of shape
            `[batch, in_height, in_width, kernel_size[0], kernel_size[1], num_transformed_images]`.
        dilation_rate: An int or pair of ints, the dilation rate used when
            extracting the image patches.

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
            `[batch, in_height, in_width, in_channels]`.
    """
    dilation_rate = list(dilation_rate) if isinstance(dilation_rate, (tuple, list)) else [dilation_rate] * 2
    batch_size, height, width, color_channels = image.get_shape().as_list()
    batch_size, height, width, kernel_height, kernel_width, num_transformed_images = kernels.get_shape().as_list()
    kernel_size = [kernel_height, kernel_width]

    # Flatten the spatial dimensions.
    kernels_reshaped = tf.reshape(kernels, [batch_size, height, width,
                                            kernel_size[0] * kernel_size[1], num_transformed_images])
    image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC')
    # Combine channel and batch dimensions into the first dimension.
    image_transposed = tf.transpose(image_padded, [3, 0, 1, 2])
    image_reshaped = flatten(image_transposed, 0, 1)[..., None]
    patches_reshaped = tf.extract_image_patches(image_reshaped, ksizes=[1] + kernel_size + [1],
                                                strides=[1] * 4, rates=[1] + dilation_rate + [1], padding='VALID')
    # Separate channel and batch dimensions, and move channel dimension.
    patches_transposed = tf.reshape(patches_reshaped, [color_channels, batch_size, height, width, kernel_size[0] * kernel_size[1]])
    patches = tf.transpose(patches_transposed, [1, 2, 3, 0, 4])
    # Reduce along the spatial dimensions of the kernel.
    outputs = tf.matmul(patches, kernels_reshaped)
    outputs = tf.unstack(outputs, axis=-1)
    return outputs
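
# A NumPy sketch (illustration only) of what apply_dna_kernels computes per pixel,
# assuming an odd kernel size, a dilation_rate of 1, and symmetric 'SAME' padding:
# each output pixel is a spatially-varying convolution of the padded input with
# that pixel's kernel, done once for each of the num_transformed_images kernels.
import numpy as np

def apply_dna_kernels_np(image, kernels):
    # image: [B, H, W, C]; kernels: [B, H, W, kh, kw, N]
    b, h, w, c = image.shape
    _, _, _, kh, kw, n = kernels.shape
    padded = np.pad(image, ((0, 0), (kh // 2, kh // 2), (kw // 2, kw // 2), (0, 0)),
                    mode='symmetric')
    outputs = np.zeros((b, h, w, c, n))
    for i in range(h):
        for j in range(w):
            patch = padded[:, i:i + kh, j:j + kw, :]                  # [B, kh, kw, C]
            patch = patch.transpose(0, 3, 1, 2).reshape(b, c, kh * kw)
            kernel = kernels[:, i, j].reshape(b, kh * kw, n)
            outputs[:, i, j] = patch @ kernel                         # [B, C, N]
    return [outputs[..., t] for t in range(n)]
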
def create_legacy_encoder(inputs,
                          nz=8,
                          nef=64,
                          norm_layer='instance',
                          include_top=True):
    norm_layer = ops.get_norm_layer(norm_layer)

    with tf.variable_scope('h1'):
        h1 = conv_pool2d(inputs, nef, kernel_size=5, strides=2)
        h1 = norm_layer(h1)
        h1 = tf.nn.relu(h1)

    with tf.variable_scope('h2'):
        h2 = conv_pool2d(h1, nef * 2, kernel_size=5, strides=2)
        h2 = norm_layer(h2)
        h2 = tf.nn.relu(h2)

    with tf.variable_scope('h3'):
        h3 = conv_pool2d(h2, nef * 4, kernel_size=5, strides=2)
        h3 = norm_layer(h3)
        h3 = tf.nn.relu(h3)
        h3_flatten = flatten(h3)

    if include_top:
        with tf.variable_scope('z_mu'):
            z_mu = dense(h3_flatten, nz)
        with tf.variable_scope('z_log_sigma_sq'):
            z_log_sigma_sq = dense(h3_flatten, nz)
            z_log_sigma_sq = tf.clip_by_value(z_log_sigma_sq, -10, 10)
        outputs = {'enc_zs_mu': z_mu, 'enc_zs_log_sigma_sq': z_log_sigma_sq}
    else:
        outputs = h3_flatten
    return outputs
def create_pspnet50_encoder(inputs):
    should_flatten = inputs.shape.ndims > 4
    if should_flatten:
        batch_shape = inputs.shape[:-3].as_list()
        inputs = flatten(inputs, 0, len(batch_shape) - 1)

    outputs = pspnet_network.pspnet(inputs, resnet_layers=50)

    if should_flatten:
        outputs = tf.reshape(outputs,
                             batch_shape + outputs.shape.as_list()[1:])
    return outputs
def generator_fn(inputs, hparams=None):
    batch_size = inputs['images'].shape[1].value
    inputs = {
        name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
        for name, input in inputs.items()
    }

    with tf.variable_scope('gru'):
        gru_cell = tf.nn.rnn_cell.GRUCell(hparams.dim_z_motion)

    if hparams.context_frames:
        with tf.variable_scope('content_encoder'):
            z_c = create_encoder(
                inputs['images'][0],  # first context image for content encoder
                nef=hparams.nef,
                norm_layer=hparams.norm_layer,
                dim_z=hparams.dim_z_content)
        with tf.variable_scope('initial_motion_encoder'):
            h_0 = create_encoder(
                inputs['images'][hparams.context_frames -
                                 1],  # last context image for motion encoder
                nef=hparams.nef,
                norm_layer=hparams.norm_layer,
                dim_z=hparams.dim_z_motion)
    else:  # unconditional case
        z_c = tf.random_normal([batch_size, hparams.dim_z_content])
        h_0 = gru_cell.zero_state(batch_size, tf.float32)

    h_t = [h_0]
    for t in range(hparams.context_frames - 1, hparams.sequence_length - 1):
        with tf.variable_scope('gru', reuse=t > hparams.context_frames - 1):
            e_t = tf.random_normal([batch_size, hparams.dim_z_motion])
            if 'actions' in inputs:
                # append the action at this step to the motion noise
                e_t = tf.concat([e_t, inputs['actions'][t]], axis=-1)
            h_t.append(gru_cell(
                e_t, h_t[-1])[1])  # the output and the state are the same in GRUs
    z_m = tf.stack(h_t[1:], axis=0)
    z = tf.concat([
        tf.tile(z_c[None, :, :],
                [hparams.sequence_length - hparams.context_frames, 1, 1]), z_m
    ],
                  axis=-1)

    # add singleton spatial dims and merge the time and batch axes for the image generator
    z_flatten = flatten(z[:, :, None, None, :], 0, 1)
    gen_images_flatten = create_generator(z_flatten,
                                          ngf=hparams.ngf,
                                          norm_layer=hparams.norm_layer)
    gen_images = tf.reshape(gen_images_flatten, [-1, batch_size] +
                            gen_images_flatten.shape.as_list()[1:])

    outputs = {'gen_images': gen_images}
    return gen_images, outputs
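
# Shape sketch (illustration only, sizes hypothetical) of the latent built above:
# the content code z_c is shared across all generated frames, the motion codes
# z_m (the GRU states) differ per frame, and the two are concatenated per step.
import numpy as np

sequence_length, context_frames, batch_size = 10, 2, 4
dim_z_content, dim_z_motion = 50, 10
z_c = np.random.randn(batch_size, dim_z_content)
z_m = np.random.randn(sequence_length - context_frames, batch_size, dim_z_motion)
z = np.concatenate(
    [np.tile(z_c[None], (sequence_length - context_frames, 1, 1)), z_m], axis=-1)
print(z.shape)  # (8, 4, 60): [time, batch, dim_z_content + dim_z_motion]
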
def create_discriminator(discrim_targets,
                         discrim_inputs=None,
                         d_net='legacy',
                         **kwargs):
    should_flatten = discrim_targets.shape.ndims > 4
    if should_flatten:
        ndims = discrim_targets.shape.ndims
        batch_shape = discrim_targets.shape[:-3].as_list()
        discrim_targets = flatten(discrim_targets, 0, len(batch_shape) - 1)
        if discrim_inputs is not None:
            assert discrim_inputs.shape.ndims == ndims
            assert discrim_inputs.shape[:-3].as_list() == batch_shape
            discrim_inputs = flatten(discrim_inputs, 0, len(batch_shape) - 1)

    if d_net == 'legacy':
        kwargs.pop('n_layers', None)  # unused
        features = create_legacy_discriminator(discrim_targets, discrim_inputs,
                                               **kwargs)
    elif d_net == 'n_layer':
        kwargs.pop('downsample_layer', None)  # unused
        n_layers = kwargs.pop('n_layers', None)
        if not n_layers:
            scale_size = min(*discrim_targets.shape.as_list()[1:3])
            n_layers = int(np.log2(scale_size // 32))
        features = create_n_layer_discriminator(discrim_targets,
                                                discrim_inputs,
                                                n_layers=n_layers,
                                                **kwargs)
    else:
        raise ValueError('Invalid discriminator net %s' % d_net)

    if should_flatten:
        features = nest.map_structure(
            lambda x: tf.reshape(x, batch_shape + x.shape.as_list()[1:]),
            features)
    return features
def generator_fn(inputs, outputs_enc=None, hparams=None):
    batch_size = inputs['images'].shape[1].value
    inputs = {name: tf_utils.maybe_pad_or_slice(input, hparams.sequence_length - 1)
              for name, input in inputs.items()}
    if hparams.nz:
        def sample_zs():
            if outputs_enc is None:
                zs = tf.random_normal([hparams.sequence_length - 1, batch_size, hparams.nz], 0, 1)
            else:
                enc_zs_mu = outputs_enc['enc_zs_mu']
                enc_zs_log_sigma_sq = outputs_enc['enc_zs_log_sigma_sq']
                eps = tf.random_normal([hparams.sequence_length - 1, batch_size, hparams.nz], 0, 1)
                zs = enc_zs_mu + tf.sqrt(tf.exp(enc_zs_log_sigma_sq)) * eps
            return zs
        inputs['zs'] = sample_zs()
    else:
        if outputs_enc is not None:
            raise ValueError('outputs_enc has to be None when nz is 0.')
    cell = DNACell(inputs, hparams)
    outputs, _ = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32,
                                   swap_memory=False, time_major=True)
    if hparams.nz:
        inputs_samples = {name: flatten(tf.tile(input[:, None], [1, hparams.num_samples] + [1] * (input.shape.ndims - 1)), 1, 2)
                          for name, input in inputs.items() if name != 'zs'}
        inputs_samples['zs'] = tf.concat([sample_zs() for _ in range(hparams.num_samples)], axis=1)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            cell_samples = DNACell(inputs_samples, hparams)
            outputs_samples, _ = tf.nn.dynamic_rnn(cell_samples, inputs_samples, dtype=tf.float32,
                                                   swap_memory=False, time_major=True)
        gen_images_samples = outputs_samples['gen_images']
        gen_images_samples = tf.stack(tf.split(gen_images_samples, hparams.num_samples, axis=1), axis=-1)
        gen_images_samples_avg = tf.reduce_mean(gen_images_samples, axis=-1)
        outputs['gen_images_samples'] = gen_images_samples
        outputs['gen_images_samples_avg'] = gen_images_samples_avg
    # the RNN outputs generated images from time step 1 to sequence_length,
    # but generator_fn should only return images past context_frames
    outputs = {name: output[hparams.context_frames - 1:] for name, output in outputs.items()}
    gen_images = outputs['gen_images']
    outputs['ground_truth_sampling_mean'] = tf.reduce_mean(tf.to_float(cell.ground_truth[hparams.context_frames:]))
    return gen_images, outputs
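
# Minimal sketch (illustration only) of the reparameterization used for
# inputs['zs'] above: a sample from N(mu, sigma^2) is written as
# mu + sigma * eps with eps ~ N(0, 1), so gradients can flow back through
# enc_zs_mu and enc_zs_log_sigma_sq while the randomness stays in eps.
import numpy as np

mu = np.array([0.5, -1.0])
log_sigma_sq = np.array([-2.0, 0.0])
eps = np.random.randn(2)
zs = mu + np.sqrt(np.exp(log_sigma_sq)) * eps
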
 def generator_loss_fn(self, inputs, outputs, targets):
     hparams = self.hparams
     gen_losses = OrderedDict()
     if hparams.l1_weight or hparams.l2_weight or hparams.vgg_cdist_weight:
         gen_images = outputs.get('gen_images_enc', outputs['gen_images'])
         target_images = targets
     if hparams.l1_weight:
         gen_l1_loss = vp.losses.l1_loss(gen_images, target_images)
         gen_losses["gen_l1_loss"] = (gen_l1_loss, hparams.l1_weight)
     if hparams.l2_weight:
         gen_l2_loss = vp.losses.l2_loss(gen_images, target_images)
         gen_losses["gen_l2_loss"] = (gen_l2_loss, hparams.l2_weight)
     if hparams.vgg_cdist_weight:
         gen_vgg_cdist_loss = vp.metrics.vgg_cosine_distance(
             gen_images, target_images)
         gen_losses['gen_vgg_cdist_loss'] = (gen_vgg_cdist_loss,
                                             hparams.vgg_cdist_weight)
     if hparams.feature_l2_weight:
         gen_features = outputs.get('gen_features_enc',
                                    outputs['gen_features'])
         target_features = outputs['features'][hparams.context_frames:]
         gen_feature_l2_loss = vp.losses.l2_loss(gen_features,
                                                 target_features)
         gen_losses["gen_feature_l2_loss"] = (gen_feature_l2_loss,
                                              hparams.feature_l2_weight)
     if hparams.ae_l2_weight:
         gen_images_dec = outputs.get(
             'gen_images_dec_enc',
             outputs['gen_images_dec'])  # they both should be the same
         target_images = inputs['images']
         gen_ae_l2_loss = vp.losses.l2_loss(gen_images_dec, target_images)
         gen_losses["gen_ae_l2_loss"] = (gen_ae_l2_loss,
                                         hparams.ae_l2_weight)
     if hparams.state_weight:
         gen_states = outputs.get('gen_states_enc', outputs['gen_states'])
         target_states = inputs['states'][hparams.context_frames:]
         gen_state_loss = vp.losses.l2_loss(gen_states, target_states)
         gen_losses["gen_state_loss"] = (gen_state_loss,
                                         hparams.state_weight)
     if hparams.tv_weight:
         gen_flows = outputs.get('gen_flows_enc', outputs['gen_flows'])
         gen_flows_reshaped = flatten(flatten(gen_flows, 0, 1), -2)
         gen_tv_loss = tf.reduce_mean(
             tf.image.total_variation(gen_flows_reshaped))
         gen_losses['gen_tv_loss'] = (gen_tv_loss, hparams.tv_weight)
     gan_weights = {
         '': hparams.gan_weight,
         '_tuple': hparams.tuple_gan_weight,
         '_image': hparams.image_gan_weight,
         '_video': hparams.video_gan_weight,
         '_acvideo': hparams.acvideo_gan_weight,
         '_image_sn': hparams.image_sn_gan_weight,
         '_images_sn': hparams.images_sn_gan_weight,
         '_video_sn': hparams.video_sn_gan_weight
     }
     for infix, gan_weight in gan_weights.items():
         if gan_weight:
             gen_gan_loss = vp.losses.gan_loss(
                 outputs['discrim%s_logits_fake' % infix], 1.0,
                 hparams.gan_loss_type)
             gen_losses["gen%s_gan_loss" % infix] = (gen_gan_loss,
                                                     gan_weight)
         if gan_weight and (hparams.gan_feature_l2_weight
                            or hparams.gan_feature_cdist_weight):
             i_feature = 0
             discrim_features_fake = []
             discrim_features_real = []
             while True:
                 discrim_feature_fake = outputs.get(
                     'discrim%s_feature%d_fake' % (infix, i_feature))
                 discrim_feature_real = outputs.get(
                     'discrim%s_feature%d_real' % (infix, i_feature))
                 if discrim_feature_fake is None or discrim_feature_real is None:
                     break
                 discrim_features_fake.append(discrim_feature_fake)
                 discrim_features_real.append(discrim_feature_real)
                 i_feature += 1
             if hparams.gan_feature_l2_weight:
                 gen_gan_feature_l2_loss = sum([
                     vp.losses.l2_loss(discrim_feature_fake,
                                       discrim_feature_real)
                     for discrim_feature_fake, discrim_feature_real in zip(
                         discrim_features_fake, discrim_features_real)
                 ])
                 gen_losses["gen%s_gan_feature_l2_loss" %
                            infix] = (gen_gan_feature_l2_loss,
                                      hparams.gan_feature_l2_weight)
             if hparams.gan_feature_cdist_weight:
                 gen_gan_feature_cdist_loss = sum([
                     vp.metrics.cosine_distance(discrim_feature_fake,
                                                discrim_feature_real)
                     for discrim_feature_fake, discrim_feature_real in zip(
                         discrim_features_fake, discrim_features_real)
                 ])
                 gen_losses["gen%s_gan_feature_cdist_loss" %
                            infix] = (gen_gan_feature_cdist_loss,
                                      hparams.gan_feature_cdist_weight)
     vae_gan_weights = {
         '': hparams.vae_gan_weight,
         '_tuple': hparams.tuple_vae_gan_weight,
         '_image': hparams.image_vae_gan_weight,
         '_video': hparams.video_vae_gan_weight,
         '_acvideo': hparams.acvideo_vae_gan_weight,
         '_image_sn': hparams.image_sn_vae_gan_weight,
         '_images_sn': hparams.images_sn_vae_gan_weight,
         '_video_sn': hparams.video_sn_vae_gan_weight
     }
     for infix, vae_gan_weight in vae_gan_weights.items():
         if vae_gan_weight:
             gen_vae_gan_loss = vp.losses.gan_loss(
                 outputs['discrim%s_logits_enc_fake' % infix], 1.0,
                 hparams.gan_loss_type)
             gen_losses["gen%s_vae_gan_loss" % infix] = (gen_vae_gan_loss,
                                                         vae_gan_weight)
         if vae_gan_weight and (hparams.gan_feature_l2_weight
                                or hparams.gan_feature_cdist_weight):
             i_feature = 0
             discrim_features_enc_fake = []
             discrim_features_enc_real = []
             while True:
                 discrim_feature_enc_fake = outputs.get(
                     'discrim%s_feature%d_enc_fake' % (infix, i_feature))
                 discrim_feature_enc_real = outputs.get(
                     'discrim%s_feature%d_enc_real' % (infix, i_feature))
                 if discrim_feature_enc_fake is None or discrim_feature_enc_real is None:
                     break
                 discrim_features_enc_fake.append(discrim_feature_enc_fake)
                 discrim_features_enc_real.append(discrim_feature_enc_real)
                 i_feature += 1
             if hparams.gan_feature_l2_weight:
                 gen_vae_gan_feature_l2_loss = sum([
                     vp.losses.l2_loss(discrim_feature_enc_fake,
                                       discrim_feature_enc_real)
                     for discrim_feature_enc_fake, discrim_feature_enc_real
                     in zip(discrim_features_enc_fake,
                            discrim_features_enc_real)
                 ])
                 gen_losses["gen%s_vae_gan_feature_l2_loss" %
                            infix] = (gen_vae_gan_feature_l2_loss,
                                      hparams.gan_feature_l2_weight)
             if hparams.gan_feature_cdist_weight:
                 gen_vae_gan_feature_cdist_loss = sum([
                     vp.metrics.cosine_distance(discrim_feature_enc_fake,
                                                discrim_feature_enc_real)
                     for discrim_feature_enc_fake, discrim_feature_enc_real
                     in zip(discrim_features_enc_fake,
                            discrim_features_enc_real)
                 ])
                 gen_losses["gen%s_vae_gan_feature_cdist_loss" %
                            infix] = (gen_vae_gan_feature_cdist_loss,
                                      hparams.gan_feature_cdist_weight)
     if hparams.kl_weight:
         gen_kl_loss = vp.losses.kl_loss(outputs['enc_zs_mu'],
                                         outputs['enc_zs_log_sigma_sq'])
         gen_losses["gen_kl_loss"] = (gen_kl_loss, self.kl_weight
                                      )  # possibly annealed kl_weight
     if hparams.z_l1_weight:
         gen_z_l1_loss = vp.losses.l1_loss(outputs['gen_enc_zs_mu'],
                                           outputs['gen_zs_random'])
         gen_losses["gen_z_l1_loss"] = (gen_z_l1_loss, hparams.z_l1_weight)
     return gen_losses
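
# Hedged sketch: generator_loss_fn returns a dict mapping each loss name to a
# (loss, weight) pair; the training code presumably combines them elsewhere as a
# weighted sum, along the lines of this hypothetical helper.
def total_generator_loss(gen_losses):
    return sum(loss * weight for loss, weight in gen_losses.values())
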
    def call(self, inputs, states):
        norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
        downsample_layer = ops.get_downsample_layer(
            self.hparams.downsample_layer)
        upsample_layer = ops.get_upsample_layer(self.hparams.upsample_layer)
        image_shape = inputs['images'].get_shape().as_list()
        batch_size, height, width, color_channels = image_shape

        time = states['time']
        with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
            t = tf.to_int32(tf.identity(time[0]))

        if 'states' in inputs:
            state = tf.where(self.ground_truth[t], inputs['states'],
                             states['gen_state'])

        state_action = []
        state_action_z = []
        if 'actions' in inputs:
            state_action.append(inputs['actions'])
            state_action_z.append(inputs['actions'])
        if 'states' in inputs:
            state_action.append(state)
            # don't backpropagate the convnet through the state dynamics
            state_action_z.append(tf.stop_gradient(state))

        if 'zs' in inputs:
            if self.hparams.use_rnn_z:
                with tf.variable_scope('%s_z' % self.hparams.rnn):
                    rnn_z, rnn_z_state = self._rnn_func(
                        inputs['zs'], states['rnn_z_state'], self.hparams.nz)
                state_action_z.append(rnn_z)
            else:
                state_action_z.append(inputs['zs'])

        def concat(tensors, axis):
            if len(tensors) == 0:
                return tf.zeros([batch_size, 0])
            elif len(tensors) == 1:
                return tensors[0]
            else:
                return tf.concat(tensors, axis=axis)

        state_action = concat(state_action, axis=-1)
        state_action_z = concat(state_action_z, axis=-1)

        image_views = []
        first_image_views = []
        if 'pix_distribs' in inputs:
            pix_distrib_views = []
        for i in range(self.hparams.num_views):
            suffix = '%d' % i if i > 0 else ''
            image_view = tf.where(
                self.ground_truth[t], inputs['images' + suffix],
                states['gen_image' + suffix])  # schedule sampling (if any)
            image_views.append(image_view)
            first_image_views.append(self.inputs['images' + suffix][0])
            if 'pix_distribs' in inputs:
                pix_distrib_view = tf.where(self.ground_truth[t],
                                            inputs['pix_distribs' + suffix],
                                            states['gen_pix_distrib' + suffix])
                pix_distrib_views.append(pix_distrib_view)

        outputs = {}
        new_states = {}
        all_layers = []
        for i in range(self.hparams.num_views):
            suffix = '%d' % i if i > 0 else ''
            conv_rnn_states = states['conv_rnn_states' + suffix]
            layers = []
            new_conv_rnn_states = []
            for i, (out_channels,
                    use_conv_rnn) in enumerate(self.encoder_layer_specs):
                with tf.variable_scope('h%d' % i + suffix):
                    if i == 0:
                        # concatenate all current image views with the first image of every view
                        h = tf.concat(image_views + first_image_views, axis=-1)
                        kernel_size = (5, 5)
                    else:
                        h = layers[-1][-1]
                        kernel_size = (3, 3)
                    if self.hparams.where_add == 'all' or (
                            self.hparams.where_add == 'input' and i == 0):
                        h = tile_concat([h, state_action_z[:, None, None, :]],
                                        axis=-1)
                    h = downsample_layer(h,
                                         out_channels,
                                         kernel_size=kernel_size,
                                         strides=(2, 2))
                    h = norm_layer(h)
                    h = tf.nn.relu(h)
                if use_conv_rnn:
                    conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                    with tf.variable_scope('%s_h%d' %
                                           (self.hparams.conv_rnn, i) +
                                           suffix):
                        if self.hparams.where_add == 'all':
                            conv_rnn_h = tile_concat(
                                [h, state_action_z[:, None, None, :]], axis=-1)
                        else:
                            conv_rnn_h = h
                        conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                            conv_rnn_h, conv_rnn_state, out_channels)
                    new_conv_rnn_states.append(conv_rnn_state)
                layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))

            num_encoder_layers = len(layers)
            for i, (out_channels,
                    use_conv_rnn) in enumerate(self.decoder_layer_specs):
                with tf.variable_scope('h%d' % len(layers) + suffix):
                    if i == 0:
                        h = layers[-1][-1]
                    else:
                        h = tf.concat([
                            layers[-1][-1],
                            layers[num_encoder_layers - i - 1][-1]
                        ],
                                      axis=-1)
                    if self.hparams.where_add == 'all' or (
                            self.hparams.where_add == 'middle' and i == 0):
                        h = tile_concat([h, state_action_z[:, None, None, :]],
                                        axis=-1)
                    h = upsample_layer(h,
                                       out_channels,
                                       kernel_size=(3, 3),
                                       strides=(2, 2))
                    h = norm_layer(h)
                    h = tf.nn.relu(h)
                if use_conv_rnn:
                    conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                    with tf.variable_scope(
                            '%s_h%d' % (self.hparams.conv_rnn, len(layers)) +
                            suffix):
                        if self.hparams.where_add == 'all':
                            conv_rnn_h = tile_concat(
                                [h, state_action_z[:, None, None, :]], axis=-1)
                        else:
                            conv_rnn_h = h
                        conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                            conv_rnn_h, conv_rnn_state, out_channels)
                    new_conv_rnn_states.append(conv_rnn_state)
                layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))
            assert len(new_conv_rnn_states) == len(conv_rnn_states)

            new_states['conv_rnn_states' + suffix] = new_conv_rnn_states

            all_layers.append(layers)
            if self.hparams.shared_views:
                break

        for i in range(self.hparams.num_views):
            suffix = '%d' % i if i > 0 else ''
            if self.hparams.shared_views:
                layers, = all_layers
            else:
                layers = all_layers[i]

            image = image_views[i]
            last_images = states['last_images' + suffix][1:] + [image]
            if 'pix_distribs' in inputs:
                pix_distrib = pix_distrib_views[i]
                last_pix_distribs = states['last_pix_distribs' +
                                           suffix][1:] + [pix_distrib]

            if self.hparams.last_frames and self.hparams.num_transformed_images:
                if self.hparams.transformation == 'flow':
                    with tf.variable_scope('h%d_flow' % len(layers) + suffix):
                        h_flow = conv2d(layers[-1][-1],
                                        self.hparams.ngf,
                                        kernel_size=(3, 3),
                                        strides=(1, 1))
                        h_flow = norm_layer(h_flow)
                        h_flow = tf.nn.relu(h_flow)

                    with tf.variable_scope('flows' + suffix):
                        flows = conv2d(h_flow,
                                       2 * self.hparams.last_frames *
                                       self.hparams.num_transformed_images,
                                       kernel_size=(3, 3),
                                       strides=(1, 1))
                        flows = tf.reshape(flows, [
                            batch_size, height, width, 2,
                            self.hparams.last_frames *
                            self.hparams.num_transformed_images
                        ])
                else:
                    assert len(self.hparams.kernel_size) == 2
                    kernel_shape = list(self.hparams.kernel_size) + [
                        self.hparams.last_frames *
                        self.hparams.num_transformed_images
                    ]
                    if self.hparams.transformation == 'dna':
                        with tf.variable_scope('h%d_dna_kernel' % len(layers) +
                                               suffix):
                            h_dna_kernel = conv2d(layers[-1][-1],
                                                  self.hparams.ngf,
                                                  kernel_size=(3, 3),
                                                  strides=(1, 1))
                            h_dna_kernel = norm_layer(h_dna_kernel)
                            h_dna_kernel = tf.nn.relu(h_dna_kernel)

                        # Using largest hidden state for predicting untied conv kernels.
                        with tf.variable_scope('dna_kernels' + suffix):
                            kernels = conv2d(h_dna_kernel,
                                             np.prod(kernel_shape),
                                             kernel_size=(3, 3),
                                             strides=(1, 1))
                            kernels = tf.reshape(kernels,
                                                 [batch_size, height, width] +
                                                 kernel_shape)
                            kernels = kernels + identity_kernel(
                                self.hparams.kernel_size)[None, None,
                                                          None, :, :, None]
                        kernel_spatial_axes = [3, 4]
                    elif self.hparams.transformation == 'cdna':
                        with tf.variable_scope('cdna_kernels' + suffix):
                            smallest_layer = layers[num_encoder_layers - 1][-1]
                            kernels = dense(flatten(smallest_layer),
                                            np.prod(kernel_shape))
                            kernels = tf.reshape(kernels,
                                                 [batch_size] + kernel_shape)
                            kernels = kernels + identity_kernel(
                                self.hparams.kernel_size)[None, :, :, None]
                        kernel_spatial_axes = [1, 2]
                    else:
                        raise ValueError('Invalid transformation %s' %
                                         self.hparams.transformation)

                if self.hparams.transformation != 'flow':
                    with tf.name_scope('kernel_normalization' + suffix):
                        kernels = tf.nn.relu(kernels - RELU_SHIFT) + RELU_SHIFT
                        kernels /= tf.reduce_sum(kernels,
                                                 axis=kernel_spatial_axes,
                                                 keepdims=True)

            if self.hparams.generate_scratch_image:
                with tf.variable_scope('h%d_scratch' % len(layers) + suffix):
                    h_scratch = conv2d(layers[-1][-1],
                                       self.hparams.ngf,
                                       kernel_size=(3, 3),
                                       strides=(1, 1))
                    h_scratch = norm_layer(h_scratch)
                    h_scratch = tf.nn.relu(h_scratch)

                # Using largest hidden state for predicting a new image layer.
                # This allows the network to also generate one image from scratch,
                # which is useful when regions of the image become unoccluded.
                with tf.variable_scope('scratch_image' + suffix):
                    scratch_image = conv2d(h_scratch,
                                           color_channels,
                                           kernel_size=(3, 3),
                                           strides=(1, 1))
                    scratch_image = tf.nn.sigmoid(scratch_image)

            with tf.name_scope('transformed_images' + suffix):
                transformed_images = []
                if self.hparams.last_frames and self.hparams.num_transformed_images:
                    if self.hparams.transformation == 'flow':
                        transformed_images.extend(
                            apply_flows(last_images, flows))
                    else:
                        transformed_images.extend(
                            apply_kernels(last_images, kernels,
                                          self.hparams.dilation_rate))
                if self.hparams.prev_image_background:
                    transformed_images.append(image)
                if self.hparams.first_image_background and not self.hparams.context_images_background:
                    transformed_images.append(self.inputs['images' +
                                                          suffix][0])
                if self.hparams.context_images_background:
                    transformed_images.extend(
                        tf.unstack(
                            self.inputs['images' +
                                        suffix][:self.hparams.context_frames]))
                if self.hparams.generate_scratch_image:
                    transformed_images.append(scratch_image)

            if 'pix_distribs' in inputs:
                with tf.name_scope('transformed_pix_distribs' + suffix):
                    transformed_pix_distribs = []
                    if self.hparams.last_frames and self.hparams.num_transformed_images:
                        if self.hparams.transformation == 'flow':
                            transformed_pix_distribs.extend(
                                apply_flows(last_pix_distribs, flows))
                        else:
                            transformed_pix_distribs.extend(
                                apply_kernels(last_pix_distribs, kernels,
                                              self.hparams.dilation_rate))
                    if self.hparams.prev_image_background:
                        transformed_pix_distribs.append(pix_distrib)
                    if self.hparams.first_image_background and not self.hparams.context_images_background:
                        transformed_pix_distribs.append(
                            self.inputs['pix_distribs' + suffix][0])
                    if self.hparams.context_images_background:
                        transformed_pix_distribs.extend(
                            tf.unstack(self.inputs['pix_distribs' + suffix]
                                       [:self.hparams.context_frames]))
                    if self.hparams.generate_scratch_image:
                        transformed_pix_distribs.append(pix_distrib)

            with tf.name_scope('masks' + suffix):
                if len(transformed_images) > 1:
                    with tf.variable_scope('h%d_masks' % len(layers) + suffix):
                        h_masks = conv2d(layers[-1][-1],
                                         self.hparams.ngf,
                                         kernel_size=(3, 3),
                                         strides=(1, 1))
                        h_masks = norm_layer(h_masks)
                        h_masks = tf.nn.relu(h_masks)

                    with tf.variable_scope('masks' + suffix):
                        if self.hparams.dependent_mask:
                            h_masks = tf.concat([h_masks] + transformed_images,
                                                axis=-1)
                        masks = conv2d(h_masks,
                                       len(transformed_images),
                                       kernel_size=(3, 3),
                                       strides=(1, 1))
                        masks = tf.nn.softmax(masks)
                        masks = tf.split(masks,
                                         len(transformed_images),
                                         axis=-1)
                elif len(transformed_images) == 1:
                    masks = [tf.ones([batch_size, height, width, 1])]
                else:
                    raise ValueError(
                        "Either one of the following should be true: "
                        "last_frames and num_transformed_images, first_image_background, "
                        "prev_image_background, generate_scratch_image")

            with tf.name_scope('gen_images' + suffix):
                assert len(transformed_images) == len(masks)
                gen_image = tf.add_n([
                    transformed_image * mask for transformed_image, mask in
                    zip(transformed_images, masks)
                ])

            if 'pix_distribs' in inputs:
                with tf.name_scope('gen_pix_distribs' + suffix):
                    assert len(transformed_pix_distribs) == len(masks)
                    gen_pix_distrib = tf.add_n([
                        transformed_pix_distrib * mask
                        for transformed_pix_distrib, mask in zip(
                            transformed_pix_distribs, masks)
                    ])

                    if self.hparams.renormalize_pixdistrib:
                        gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib,
                                                         axis=(1, 2),
                                                         keepdims=True)

            outputs['gen_images' + suffix] = gen_image
            outputs['transformed_images' + suffix] = tf.stack(
                transformed_images, axis=-1)
            outputs['masks' + suffix] = tf.stack(masks, axis=-1)
            if 'pix_distribs' in inputs:
                outputs['gen_pix_distribs' + suffix] = gen_pix_distrib
                outputs['transformed_pix_distribs' + suffix] = tf.stack(
                    transformed_pix_distribs, axis=-1)
            if self.hparams.transformation == 'flow':
                outputs['gen_flows' + suffix] = flows
                flows_transposed = tf.transpose(flows, [0, 1, 2, 4, 3])
                flows_rgb_transposed = tf_utils.flow_to_rgb(flows_transposed)
                flows_rgb = tf.transpose(flows_rgb_transposed, [0, 1, 2, 4, 3])
                outputs['gen_flows_rgb' + suffix] = flows_rgb

            new_states['gen_image' + suffix] = gen_image
            new_states['last_images' + suffix] = last_images
            if 'pix_distribs' in inputs:
                new_states['gen_pix_distrib' + suffix] = gen_pix_distrib
                new_states['last_pix_distribs' + suffix] = last_pix_distribs

        if 'states' in inputs:
            with tf.name_scope('gen_states'):
                with tf.variable_scope('state_pred'):
                    gen_state = dense(state_action,
                                      inputs['states'].shape[-1].value)

        if 'states' in inputs:
            outputs['gen_states'] = gen_state

        new_states['time'] = time + 1
        if 'zs' in inputs and self.hparams.use_rnn_z:
            new_states['rnn_z_state'] = rnn_z_state
        if 'states' in inputs:
            new_states['gen_state'] = gen_state
        return outputs, new_states
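
# Sketch (illustration only, sizes hypothetical) of the mask compositing in the
# 'gen_images' name scope above: the softmax masks sum to one at every pixel, so
# the generated image is a per-pixel convex combination of the candidate images
# (transformed frames, background frames, and the optional scratch image).
import numpy as np

B, H, W, C, N = 1, 2, 2, 3, 4
candidates = [np.random.rand(B, H, W, C) for _ in range(N)]
logits = np.random.randn(B, H, W, N)
masks = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax over candidates
gen_image = sum(img * masks[..., k:k + 1] for k, img in enumerate(candidates))
assert np.allclose(masks.sum(axis=-1), 1.0)
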
    def call(self, inputs, states):
        norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
        feature_shape = inputs['features'].get_shape().as_list()
        batch_size, height, width, feature_channels = feature_shape
        conv_rnn_states = states['conv_rnn_states']

        time = states['time']
        with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
            t = tf.to_int32(tf.identity(time[0]))

        feature = tf.where(self.ground_truth[t], inputs['features'],
                           states['gen_feature'])  # schedule sampling (if any)
        if 'states' in inputs:
            state = tf.where(self.ground_truth[t], inputs['states'],
                             states['gen_state'])

        state_action = []
        state_action_z = []
        if 'actions' in inputs:
            state_action.append(inputs['actions'])
            state_action_z.append(inputs['actions'])
        if 'states' in inputs:
            state_action.append(state)
            # don't backpropagate the convnet through the state dynamics
            state_action_z.append(tf.stop_gradient(state))

        if 'zs' in inputs:
            if self.hparams.use_rnn_z:
                with tf.variable_scope('%s_z' % self.hparams.rnn):
                    rnn_z, rnn_z_state = self._rnn_func(
                        inputs['zs'], states['rnn_z_state'], self.hparams.nz)
                state_action_z.append(rnn_z)
            else:
                state_action_z.append(inputs['zs'])

        def concat(tensors, axis):
            if len(tensors) == 0:
                return tf.zeros([batch_size, 0])
            elif len(tensors) == 1:
                return tensors[0]
            else:
                return tf.concat(tensors, axis=axis)

        state_action = concat(state_action, axis=-1)
        state_action_z = concat(state_action_z, axis=-1)
        if 'actions' in inputs:
            gen_input = tile_concat(
                [feature, inputs['actions'][:, None, None, :]], axis=-1)
        else:
            gen_input = feature

        layers = []
        new_conv_rnn_states = []
        for i, (out_channels,
                use_conv_rnn) in enumerate(self.encoder_layer_specs):
            with tf.variable_scope('h%d' % i):
                if i == 0:
                    # h = tf.concat([feature, self.inputs['features'][0]], axis=-1)  # TODO: use first feature?
                    h = feature
                else:
                    h = layers[-1][-1]
                h = conv_pool2d(tile_concat(
                    [h, state_action_z[:, None, None, :]], axis=-1),
                                out_channels,
                                kernel_size=(3, 3),
                                strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, i)):
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        tile_concat([h, state_action_z[:, None, None, :]],
                                    axis=-1), conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))

        num_encoder_layers = len(layers)
        for i, (out_channels,
                use_conv_rnn) in enumerate(self.decoder_layer_specs):
            with tf.variable_scope('h%d' % len(layers)):
                if i == 0:
                    h = layers[-1][-1]
                else:
                    h = tf.concat([
                        layers[-1][-1], layers[num_encoder_layers - i - 1][-1]
                    ],
                                  axis=-1)
                h = upsample_conv2d(tile_concat(
                    [h, state_action_z[:, None, None, :]], axis=-1),
                                    out_channels,
                                    kernel_size=(3, 3),
                                    strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' %
                                       (self.hparams.conv_rnn, len(layers))):
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        tile_concat([h, state_action_z[:, None, None, :]],
                                    axis=-1), conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))
        assert len(new_conv_rnn_states) == len(conv_rnn_states)

        if self.hparams.transformation == 'direct':
            with tf.variable_scope('h%d_direct' % len(layers)):
                h_direct = conv2d(layers[-1][-1],
                                  self.hparams.ngf,
                                  kernel_size=(3, 3),
                                  strides=(1, 1))
                h_direct = norm_layer(h_direct)
                h_direct = tf.nn.relu(h_direct)

            with tf.variable_scope('direct'):
                gen_feature = conv2d(h_direct,
                                     feature_channels,
                                     kernel_size=(3, 3),
                                     strides=(1, 1))
        else:
            if self.hparams.transformation == 'flow':
                with tf.variable_scope('h%d_flow' % len(layers)):
                    h_flow = conv2d(layers[-1][-1],
                                    self.hparams.ngf,
                                    kernel_size=(3, 3),
                                    strides=(1, 1))
                    h_flow = norm_layer(h_flow)
                    h_flow = tf.nn.relu(h_flow)

                with tf.variable_scope('flows'):
                    flows = conv2d(h_flow,
                                   2 * feature_channels,
                                   kernel_size=(3, 3),
                                   strides=(1, 1))
                    flows = tf.reshape(
                        flows,
                        [batch_size, height, width, 2, feature_channels])
                transformations = flows
            else:
                assert len(self.hparams.kernel_size) == 2
                kernel_shape = list(
                    self.hparams.kernel_size) + [feature_channels]
                if self.hparams.transformation == 'local':
                    with tf.variable_scope('h%d_local_kernel' % len(layers)):
                        h_local_kernel = conv2d(layers[-1][-1],
                                                self.hparams.ngf,
                                                kernel_size=(3, 3),
                                                strides=(1, 1))
                        h_local_kernel = norm_layer(h_local_kernel)
                        h_local_kernel = tf.nn.relu(h_local_kernel)

                    # Using largest hidden state for predicting untied conv kernels.
                    with tf.variable_scope('local_kernels'):
                        kernels = conv2d(h_local_kernel,
                                         np.prod(kernel_shape),
                                         kernel_size=(3, 3),
                                         strides=(1, 1))
                        kernels = tf.reshape(kernels,
                                             [batch_size, height, width] +
                                             kernel_shape)
                        kernels = kernels + identity_kernel(
                            self.hparams.kernel_size)[None, None, None, :, :,
                                                      None]
                elif self.hparams.transformation == 'conv':
                    with tf.variable_scope('conv_kernels'):
                        smallest_layer = layers[num_encoder_layers - 1][-1]
                        kernels = dense(flatten(smallest_layer),
                                        np.prod(kernel_shape))
                        kernels = tf.reshape(kernels,
                                             [batch_size] + kernel_shape)
                        kernels = kernels + identity_kernel(
                            self.hparams.kernel_size)[None, :, :, None]
                else:
                    raise ValueError('Invalid transformation %s' %
                                     self.hparams.transformation)
                transformations = kernels

            with tf.name_scope('gen_features'):
                if self.hparams.transformation == 'flow':

                    def apply_transformation(feature_and_flow):
                        feature, flow = feature_and_flow
                        return flow_ops.image_warp(feature[..., None], flow)
                else:

                    def apply_transformation(feature_and_kernel):
                        feature, kernel = feature_and_kernel
                        output, = apply_kernels(feature[..., None],
                                                kernel[..., None])
                        return tf.squeeze(output, axis=-1)

                gen_feature_transposed = tf.map_fn(
                    apply_transformation,
                    (tf.stack(tf.unstack(feature, axis=-1)),
                     tf.stack(tf.unstack(transformations, axis=-1))),
                    dtype=tf.float32)
                gen_feature = tf.stack(tf.unstack(gen_feature_transposed),
                                       axis=-1)

        # TODO: use norm and relu for generated features?
        gen_feature = norm_layer(gen_feature)
        gen_feature = tf.nn.relu(gen_feature)

        if 'states' in inputs:
            with tf.name_scope('gen_states'):
                with tf.variable_scope('state_pred'):
                    gen_state = dense(state_action,
                                      inputs['states'].shape[-1].value)

        outputs = {
            'gen_features': gen_feature,
            'gen_inputs': gen_input,
        }
        if 'states' in inputs:
            outputs['gen_states'] = gen_state
        if self.hparams.transformation == 'flow':
            outputs['gen_flows'] = flows

        new_states = {
            'time': time + 1,
            'gen_feature': gen_feature,
            'conv_rnn_states': new_conv_rnn_states,
        }
        if 'zs' in inputs and self.hparams.use_rnn_z:
            new_states['rnn_z_state'] = rnn_z_state
        if 'states' in inputs:
            new_states['gen_state'] = gen_state
        return outputs, new_states
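
# Sketch (illustration only) of the identity-kernel bias added to the predicted
# kernels in both cells above: adding a one-hot kernel centered on the middle tap
# biases each per-pixel convolution toward copying its input, so the predicted
# transformation starts out near the identity. For odd kernel sizes, the repo's
# identity_kernel presumably returns a kernel like this hypothetical version.
import numpy as np

def identity_kernel_np(kernel_size):
    kh, kw = kernel_size
    kernel = np.zeros((kh, kw))
    kernel[kh // 2, kw // 2] = 1.0
    return kernel

print(identity_kernel_np((3, 3)))  # 1.0 at the center, 0.0 elsewhere
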