def create_legacy_encoder(inputs,
                          nz=8,
                          nef=64,
                          norm_layer='instance',
                          include_top=True):
    """Encode ``inputs`` with three conv-pool stages and optional latent heads.

    Each stage (variable scopes 'h1', 'h2', 'h3') runs ``conv_pool2d`` with a
    5x5 kernel and stride 2, then the configured normalization, then a ReLU.
    The stages use ``nef``, ``2 * nef`` and ``4 * nef`` filters respectively,
    and the output of the last stage is flattened.

    Args:
        inputs: tensor fed to the first ``conv_pool2d``.
        nz: output width of each dense latent head.
        nef: base number of encoder filters.
        norm_layer: normalization name resolved via ``ops.get_norm_layer``.
        include_top: whether to add the two dense latent heads.

    Returns:
        If ``include_top`` is true, a dict with keys ``'enc_zs_mu'`` and
        ``'enc_zs_log_sigma_sq'`` (the latter clipped to [-10, 10]);
        otherwise the flattened features of stage 'h3'.
    """
    normalize = ops.get_norm_layer(norm_layer)

    def _stage(x, filters):
        # conv_pool2d -> normalization -> ReLU, built in the caller's scope.
        x = conv_pool2d(x, filters, kernel_size=5, strides=2)
        return tf.nn.relu(normalize(x))

    with tf.variable_scope('h1'):
        stage1 = _stage(inputs, nef)
    with tf.variable_scope('h2'):
        stage2 = _stage(stage1, nef * 2)
    with tf.variable_scope('h3'):
        features = flatten(_stage(stage2, nef * 4))

    if not include_top:
        return features

    with tf.variable_scope('z_mu'):
        mean = dense(features, nz)
    with tf.variable_scope('z_log_sigma_sq'):
        log_var = tf.clip_by_value(dense(features, nz), -10, 10)
    return {'enc_zs_mu': mean, 'enc_zs_log_sigma_sq': log_var}
 def _conv_rnn_func(self, inputs, state, filters):
     """Apply one step of a convolutional RNN cell to ``inputs``.

     The cell class is selected by ``self.hparams.conv_rnn`` ('lstm' uses
     ``BasicConv2DLSTMCell``, 'gru' uses ``Conv2DGRUCell``). It is built with
     a 5x5 kernel, ``filters`` output channels, and the reuse flag of the
     current variable scope. Normalization comes from
     ``ops.get_norm_layer(self.hparams.norm_layer)`` unless that hparam is
     'none', in which case no normalization is applied.

     When ``self.hparams.ablation_conv_rnn_norm`` is set, the cell is built
     without a normalizer and the normalization (if any) is applied to the
     cell's output instead. Otherwise the normalizer is passed into the cell,
     with ``separate_norms`` enabled only for 'layer' normalization.

     Args:
         inputs: input tensor; every dimension after the first is passed to
             the cell as its input shape.
         state: previous recurrent state of the cell.
         filters: number of output channels of the cell.

     Returns:
         The ``(output, new_state)`` result of the cell step.

     Raises:
         NotImplementedError: if ``self.hparams.conv_rnn`` is neither
             'lstm' nor 'gru'.
     """
     input_shape = inputs.get_shape().as_list()[1:]
     if self.hparams.norm_layer == 'none':
         normalizer_fn = None
     else:
         normalizer_fn = ops.get_norm_layer(self.hparams.norm_layer)

     if self.hparams.conv_rnn == 'lstm':
         Conv2DRNNCell = BasicConv2DLSTMCell
     elif self.hparams.conv_rnn == 'gru':
         Conv2DRNNCell = Conv2DGRUCell
     else:
         raise NotImplementedError

     if self.hparams.ablation_conv_rnn_norm:
         conv_rnn_cell = Conv2DRNNCell(input_shape,
                                       filters,
                                       kernel_size=(5, 5),
                                       reuse=tf.get_variable_scope().reuse)
         h, state = conv_rnn_cell(inputs, state)
         # With norm_layer == 'none' there is no normalizer; calling None
         # here would raise TypeError, so pass the cell output through as-is.
         if normalizer_fn is not None:
             h = normalizer_fn(h)
         return h, state

     conv_rnn_cell = Conv2DRNNCell(
         input_shape,
         filters,
         kernel_size=(5, 5),
         normalizer_fn=normalizer_fn,
         separate_norms=self.hparams.norm_layer == 'layer',
         reuse=tf.get_variable_scope().reuse)
     outputs = conv_rnn_cell(inputs, state)
     return outputs
예제 #3
0
def encoder(inputs, nef=64, n_layers=3, norm_layer='instance'):
    """Strided-conv encoder that average-pools its last feature map.

    Scope 'layer_1' applies a 4x4 stride-2 VALID conv (after 1-pixel spatial
    padding) with ``nef`` filters and a leaky ReLU (slope 0.2), with no
    normalization. Each following scope 'layer_<i+1>', for i in
    [1, n_layers), does the same conv with ``nef * min(2**i, 4)`` filters,
    then the configured normalization, then the leaky ReLU.

    The last feature map is average-pooled over its full spatial extent
    (dims 1:3), and axes 1 and 2 are squeezed away.

    Returns:
        The squeezed, pooled tensor.
    """
    normalize = ops.get_norm_layer(norm_layer)
    pad_hw = [[0, 0], [1, 1], [1, 1], [0, 0]]

    with tf.variable_scope("layer_1"):
        net = conv2d(tf.pad(inputs, pad_hw),
                     nef,
                     kernel_size=4,
                     strides=2,
                     padding='VALID')
        net = lrelu(net, 0.2)

    for depth in range(1, n_layers):
        with tf.variable_scope("layer_%d" % (depth + 1)):
            net = conv2d(tf.pad(net, pad_hw),
                         nef * min(2**depth, 4),
                         kernel_size=4,
                         strides=2,
                         padding='VALID')
            net = lrelu(normalize(net), 0.2)

    window = net.shape.as_list()[1:3]
    pooled = pool2d(net, window, padding='VALID', pool_mode='avg')
    return tf.squeeze(pooled, [1, 2])
예제 #4
0
def create_generator(z, ngf=64, norm_layer='instance', n_channels=3):
    """Decode ``z`` with five transposed-conv layers into a tanh output.

    Scope 'layer_1' is a 4x4 stride-1 VALID ``deconv2d`` with ``ngf * 8``
    filters. Scopes 'layer_2' to 'layer_4' are 4x4 stride-2 ``deconv2d``
    layers with ``ngf * 4``, ``ngf * 2`` and ``ngf`` filters. Each of these
    four layers is followed by the configured normalization and a ReLU.
    Scope 'layer_5' is a 4x4 stride-2 ``deconv2d`` to ``n_channels``
    followed by tanh.

    Returns:
        The tanh-activated output tensor of 'layer_5'.
    """
    normalize = ops.get_norm_layer(norm_layer)

    with tf.variable_scope("layer_1"):
        net = deconv2d(z, ngf * 8, kernel_size=4, strides=1, padding='VALID')
        net = tf.nn.relu(normalize(net))

    for idx, filters in enumerate((ngf * 4, ngf * 2, ngf), start=2):
        with tf.variable_scope("layer_%d" % idx):
            net = deconv2d(net, filters, kernel_size=4, strides=2)
            net = tf.nn.relu(normalize(net))

    with tf.variable_scope("layer_5"):
        net = deconv2d(net, n_channels, kernel_size=4, strides=2)
        net = tf.nn.tanh(net)
    return net
예제 #5
0
def create_acvideo_discriminator(clips,
                                 actions,
                                 ndf=64,
                                 norm_layer='instance',
                                 use_noise=False,
                                 noise_sigma=None):
    """Action-conditioned 3-D conv discriminator over consecutive clip pairs.

    ``clips`` is rescaled by ``x * 2 - 1``. Each element along the leading
    axis is concatenated (channel-wise) with the next one, and the actions,
    expanded with two singleton axes before the last, are tiled in via
    ``tile_concat``. The result goes through ``tf_utils.transpose_batch_time``.

    Scopes 'acvideo_layer_1' to 'acvideo_layer_3' each apply ``noise``, a
    bias-free (3, 4, 4) conv3d with strides (1, 2, 2) and VALID padding
    (after 1-pixel padding on the two spatial axes), then leaky ReLU (0.2).
    The second and third also normalize before the activation.
    'acvideo_layer_4' is the same conv with one output channel and no noise.

    Returns:
        The list of the three hidden activations and the logits, each passed
        back through ``tf_utils.transpose_batch_time``.
    """
    normalize = ops.get_norm_layer(norm_layer)
    pad_spatial = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0]]
    conv_kwargs = dict(kernel_size=(3, 4, 4),
                       strides=(1, 2, 2),
                       padding='VALID',
                       use_bias=False)

    clips = clips * 2 - 1
    pairs = tf.concat([clips[:-1], clips[1:]], axis=-1)
    pairs = tile_concat([pairs, actions[..., None, None, :]], axis=-1)
    net = tf_utils.transpose_batch_time(pairs)

    activations = []
    for idx, multiplier in enumerate((1, 2, 4), start=1):
        with tf.variable_scope("acvideo_layer_%d" % idx):
            net = noise(net, use_noise, noise_sigma)
            net = conv3d(tf.pad(net, pad_spatial), ndf * multiplier,
                         **conv_kwargs)
            if idx > 1:  # the first layer is left unnormalized
                net = normalize(net)
            net = lrelu(net, 0.2)
        activations.append(net)

    with tf.variable_scope("acvideo_layer_4"):
        logits = conv3d(tf.pad(net, pad_spatial), 1, **conv_kwargs)
    activations.append(logits)
    return nest.map_structure(tf_utils.transpose_batch_time, activations)
예제 #6
0
def create_encoder(image, nef=64, norm_layer='instance', dim_z=10):
    """Encode ``image`` with five strided convs, then average-pool.

    Every layer is a 4x4 stride-2 VALID ``conv2d`` applied after 1-pixel
    spatial padding. Scopes 'layer_1' to 'layer_4' use ``nef``, ``2*nef``,
    ``4*nef`` and ``8*nef`` filters, each followed by the configured
    normalization and a leaky ReLU (0.2). 'layer_5' projects to ``dim_z``
    channels with no activation.

    The output of 'layer_5' is average-pooled over its full spatial extent
    (dims 1:3), and axes 1 and 2 are squeezed away.

    Returns:
        The squeezed, pooled tensor.
    """
    normalize = ops.get_norm_layer(norm_layer)
    pad_hw = [[0, 0], [1, 1], [1, 1], [0, 0]]

    def _down(x, filters):
        # Pad 1 pixel spatially, then a 4x4 stride-2 VALID conv.
        return conv2d(tf.pad(x, pad_hw),
                      filters,
                      kernel_size=4,
                      strides=2,
                      padding='VALID')

    net = image
    for idx, multiplier in enumerate((1, 2, 4, 8), start=1):
        with tf.variable_scope("layer_%d" % idx):
            net = lrelu(normalize(_down(net, nef * multiplier)), 0.2)

    with tf.variable_scope("layer_5"):
        net = _down(net, dim_z)

    window = net.shape[1:3].as_list()
    pooled = pool2d(net, window, padding='VALID', pool_mode='avg')
    return tf.squeeze(pooled, [1, 2])
예제 #7
0
def create_image_discriminator(images,
                               ndf=64,
                               norm_layer='instance',
                               use_noise=False,
                               noise_sigma=None):
    """Four-layer 2-D conv discriminator with optional input noise.

    ``images`` is rescaled by ``x * 2 - 1``. Scopes 'image_layer_1' to
    'image_layer_3' each apply ``noise``, then a bias-free 4x4 stride-2
    VALID ``conv2d`` (after 1-pixel spatial padding) with ``ndf``, ``2*ndf``
    and ``4*ndf`` filters, then leaky ReLU (0.2). The second and third also
    normalize before the activation. 'image_layer_4' applies ``noise`` and
    the same conv with a single output channel.

    Returns:
        A list of the three hidden activations followed by the logits.
    """
    normalize = ops.get_norm_layer(norm_layer)
    pad_hw = [[0, 0], [1, 1], [1, 1], [0, 0]]
    conv_kwargs = dict(kernel_size=4,
                       strides=2,
                       padding='VALID',
                       use_bias=False)

    net = images * 2 - 1

    outputs = []
    for idx, multiplier in enumerate((1, 2, 4), start=1):
        with tf.variable_scope("image_layer_%d" % idx):
            net = noise(net, use_noise, noise_sigma)
            net = conv2d(tf.pad(net, pad_hw), ndf * multiplier, **conv_kwargs)
            if idx > 1:  # the first layer is left unnormalized
                net = normalize(net)
            net = lrelu(net, 0.2)
        outputs.append(net)

    with tf.variable_scope("image_layer_4"):
        net = noise(net, use_noise, noise_sigma)
        logits = conv2d(tf.pad(net, pad_hw), 1, **conv_kwargs)
    outputs.append(logits)
    return outputs
예제 #8
0
def create_n_layer_discriminator(discrim_targets,
                                 discrim_inputs=None,
                                 ndf=64,
                                 n_layers=3,
                                 norm_layer='instance'):
    """Conv discriminator that outputs a spatial map of logits.

    ``discrim_targets`` (and ``discrim_inputs``, if given) are concatenated
    on the last axis. Every conv is 4x4 VALID after 1-pixel spatial padding.
    'layer_1' has ``ndf`` filters, stride 2, leaky ReLU (0.2), and no
    normalization. The next ``n_layers`` layers use
    ``ndf * min(2**(i+1), 8)`` filters, are normalized, and use leaky ReLU.
    All of them use stride 2 except the last, which uses stride 1. A final
    stride-1 conv produces one channel with no activation, so callers may
    apply their own (e.g. for an LSGAN loss).

    Returns:
        A list of all hidden activations followed by the logits.
    """
    normalize = ops.get_norm_layer(norm_layer)

    if discrim_inputs is None:
        sources = [discrim_targets]
    else:
        sources = [discrim_targets, discrim_inputs]
    net = tf.concat(sources, axis=-1)

    pad_hw = [[0, 0], [1, 1], [1, 1], [0, 0]]
    outputs = []

    # e.g. layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
    with tf.variable_scope("layer_1"):
        net = conv2d(tf.pad(net, pad_hw),
                     ndf,
                     kernel_size=4,
                     strides=2,
                     padding='VALID')
        net = lrelu(net, 0.2)
    outputs.append(net)

    # With n_layers=3, e.g.:
    # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
    # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
    # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
    for i in range(n_layers):
        is_last = i == n_layers - 1
        with tf.variable_scope("layer_%d" % (i + 2)):
            net = conv2d(tf.pad(net, pad_hw),
                         ndf * min(2**(i + 1), 8),
                         kernel_size=4,
                         strides=1 if is_last else 2,
                         padding='VALID')
            net = lrelu(normalize(net), 0.2)
        outputs.append(net)

    # e.g. final layer: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
    with tf.variable_scope("layer_%d" % (len(outputs) + 1)):
        logits = conv2d(tf.pad(net, pad_hw),
                        1,
                        kernel_size=4,
                        strides=1,
                        padding='VALID')
    # No sigmoid on the logits, in case the caller wants an LSGAN loss.
    outputs.append(logits)
    return outputs
예제 #9
0
def create_video_discriminator(clips, ndf=64, norm_layer='instance'):
    """Four-layer 3-D conv discriminator over a clip.

    ``clips`` is first passed through ``tf_utils.transpose_batch_time``.
    Scopes 'video_layer_1' to 'video_layer_3' each apply a 4x4x4 conv3d with
    strides (1, 2, 2) and VALID padding (after 1-pixel padding on the two
    spatial axes), with ``ndf``, ``2*ndf`` and ``4*ndf`` filters, followed by
    leaky ReLU (0.2). The second and third also normalize first.

    'video_layer_4' is an unpadded stride-1 VALID conv3d with one output
    channel. If dim 1 of its input is shorter than 4, the kernel's first
    extent is shrunk to that size so the VALID conv still fits.

    Returns:
        The list of the three hidden activations and the logits, each passed
        back through ``tf_utils.transpose_batch_time``.
    """
    normalize = ops.get_norm_layer(norm_layer)
    pad_spatial = [[0, 0], [0, 0], [1, 1], [1, 1], [0, 0]]

    net = tf_utils.transpose_batch_time(clips)

    outputs = []
    for idx, multiplier in enumerate((1, 2, 4), start=1):
        with tf.variable_scope("video_layer_%d" % idx):
            net = conv3d(tf.pad(net, pad_spatial),
                         ndf * multiplier,
                         kernel_size=4,
                         strides=(1, 2, 2),
                         padding='VALID')
            if idx > 1:  # the first layer is left unnormalized
                net = normalize(net)
            net = lrelu(net, 0.2)
        outputs.append(net)

    with tf.variable_scope("video_layer_4"):
        # Presumably dim 1 is time after transpose_batch_time -- confirm.
        depth = net.shape[1].value
        kernel_size = (depth, 4, 4) if depth < 4 else 4
        logits = conv3d(net,
                        1,
                        kernel_size=kernel_size,
                        strides=1,
                        padding='VALID')
    outputs.append(logits)
    return nest.map_structure(tf_utils.transpose_batch_time, outputs)
예제 #10
0
def create_n_layer_encoder(inputs,
                           nz=8,
                           nef=64,
                           n_layers=3,
                           norm_layer='instance',
                           include_top=True):
    """Strided-conv encoder with pooled features and optional latent heads.

    Scope 'layer_1' is a 4x4 stride-2 VALID conv (after 1-pixel spatial
    padding) with ``nef`` filters and leaky ReLU (0.2), without
    normalization. Each scope 'layer_<i+1>', for i in [1, n_layers), uses
    ``nef * min(2**i, 4)`` filters and normalizes before the leaky ReLU.
    The last feature map is average-pooled over dims 1:3, and axes 1 and 2
    are squeezed away.

    Args:
        inputs: input tensor for 'layer_1'.
        nz: output width of each dense latent head.
        nef: base number of encoder filters.
        n_layers: total number of conv layers including 'layer_1'.
        norm_layer: normalization name resolved via ``ops.get_norm_layer``.
        include_top: whether to add the two dense latent heads.

    Returns:
        If ``include_top`` is true, a dict with keys ``'enc_zs_mu'`` and
        ``'enc_zs_log_sigma_sq'`` (the latter clipped to [-10, 10]);
        otherwise the squeezed pooled features.
    """
    normalize = ops.get_norm_layer(norm_layer)
    pad_hw = [[0, 0], [1, 1], [1, 1], [0, 0]]

    with tf.variable_scope("layer_1"):
        net = conv2d(tf.pad(inputs, pad_hw),
                     nef,
                     kernel_size=4,
                     strides=2,
                     padding='VALID')
        net = lrelu(net, 0.2)

    for depth in range(1, n_layers):
        with tf.variable_scope("layer_%d" % (depth + 1)):
            net = conv2d(tf.pad(net, pad_hw),
                         nef * min(2**depth, 4),
                         kernel_size=4,
                         strides=2,
                         padding='VALID')
            net = lrelu(normalize(net), 0.2)

    window = net.shape[1:3].as_list()
    pooled = pool2d(net, window, padding='VALID', pool_mode='avg')
    features = tf.squeeze(pooled, [1, 2])

    if not include_top:
        return features

    with tf.variable_scope('z_mu'):
        mean = dense(features, nz)
    with tf.variable_scope('z_log_sigma_sq'):
        log_var = tf.clip_by_value(dense(features, nz), -10, 10)
    return {'enc_zs_mu': mean, 'enc_zs_log_sigma_sq': log_var}
예제 #11
0
def create_image_discriminator(images, ndf=64, norm_layer='instance'):
    """Four-layer 2-D conv discriminator.

    Scopes 'image_layer_1' to 'image_layer_3' each apply a 4x4 stride-2
    VALID ``conv2d`` (after 1-pixel spatial padding) with ``ndf``, ``2*ndf``
    and ``4*ndf`` filters, then leaky ReLU (0.2). The second and third also
    normalize first. 'image_layer_4' is an unpadded 4x4 stride-1 VALID conv
    with one output channel.

    Returns:
        A list of the three hidden activations followed by the logits.
    """
    normalize = ops.get_norm_layer(norm_layer)
    pad_hw = [[0, 0], [1, 1], [1, 1], [0, 0]]

    net = images
    outputs = []
    for idx, multiplier in enumerate((1, 2, 4), start=1):
        with tf.variable_scope("image_layer_%d" % idx):
            net = conv2d(tf.pad(net, pad_hw),
                         ndf * multiplier,
                         kernel_size=4,
                         strides=2,
                         padding='VALID')
            if idx > 1:  # the first layer is left unnormalized
                net = normalize(net)
            net = lrelu(net, 0.2)
        outputs.append(net)

    with tf.variable_scope("image_layer_4"):
        logits = conv2d(net, 1, kernel_size=4, strides=1, padding='VALID')
    outputs.append(logits)
    return outputs
    def call(self, inputs, states):
        norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
        downsample_layer = ops.get_downsample_layer(
            self.hparams.downsample_layer)
        upsample_layer = ops.get_upsample_layer(self.hparams.upsample_layer)
        image_shape = inputs['images'].get_shape().as_list()
        batch_size, height, width, color_channels = image_shape

        time = states['time']
        with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
            t = tf.to_int32(tf.identity(time[0]))

        if 'states' in inputs:
            state = tf.where(self.ground_truth[t], inputs['states'],
                             states['gen_state'])

        state_action = []
        state_action_z = []
        if 'actions' in inputs:
            state_action.append(inputs['actions'])
            state_action_z.append(inputs['actions'])
        if 'states' in inputs:
            state_action.append(state)
            # don't backpropagate the convnet through the state dynamics
            state_action_z.append(tf.stop_gradient(state))

        if 'zs' in inputs:
            if self.hparams.use_rnn_z:
                with tf.variable_scope('%s_z' % self.hparams.rnn):
                    rnn_z, rnn_z_state = self._rnn_func(
                        inputs['zs'], states['rnn_z_state'], self.hparams.nz)
                state_action_z.append(rnn_z)
            else:
                state_action_z.append(inputs['zs'])

        def concat(tensors, axis):
            if len(tensors) == 0:
                return tf.zeros([batch_size, 0])
            elif len(tensors) == 1:
                return tensors[0]
            else:
                return tf.concat(tensors, axis=axis)

        state_action = concat(state_action, axis=-1)
        state_action_z = concat(state_action_z, axis=-1)

        image_views = []
        first_image_views = []
        if 'pix_distribs' in inputs:
            pix_distrib_views = []
        for i in range(self.hparams.num_views):
            suffix = '%d' % i if i > 0 else ''
            image_view = tf.where(
                self.ground_truth[t], inputs['images' + suffix],
                states['gen_image' + suffix])  # schedule sampling (if any)
            image_views.append(image_view)
            first_image_views.append(self.inputs['images' + suffix][0])
            if 'pix_distribs' in inputs:
                pix_distrib_view = tf.where(self.ground_truth[t],
                                            inputs['pix_distribs' + suffix],
                                            states['gen_pix_distrib' + suffix])
                pix_distrib_views.append(pix_distrib_view)

        outputs = {}
        new_states = {}
        all_layers = []
        for i in range(self.hparams.num_views):
            suffix = '%d' % i if i > 0 else ''
            conv_rnn_states = states['conv_rnn_states' + suffix]
            layers = []
            new_conv_rnn_states = []
            for i, (out_channels,
                    use_conv_rnn) in enumerate(self.encoder_layer_specs):
                with tf.variable_scope('h%d' % i + suffix):
                    if i == 0:
                        # all image views and the first image corresponding to this view only
                        h = tf.concat(image_views + first_image_views, axis=-1)
                        kernel_size = (5, 5)
                    else:
                        h = layers[-1][-1]
                        kernel_size = (3, 3)
                    if self.hparams.where_add == 'all' or (
                            self.hparams.where_add == 'input' and i == 0):
                        h = tile_concat([h, state_action_z[:, None, None, :]],
                                        axis=-1)
                    h = downsample_layer(h,
                                         out_channels,
                                         kernel_size=kernel_size,
                                         strides=(2, 2))
                    h = norm_layer(h)
                    h = tf.nn.relu(h)
                if use_conv_rnn:
                    conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                    with tf.variable_scope('%s_h%d' %
                                           (self.hparams.conv_rnn, i) +
                                           suffix):
                        if self.hparams.where_add == 'all':
                            conv_rnn_h = tile_concat(
                                [h, state_action_z[:, None, None, :]], axis=-1)
                        else:
                            conv_rnn_h = h
                        conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                            conv_rnn_h, conv_rnn_state, out_channels)
                    new_conv_rnn_states.append(conv_rnn_state)
                layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))

            num_encoder_layers = len(layers)
            for i, (out_channels,
                    use_conv_rnn) in enumerate(self.decoder_layer_specs):
                with tf.variable_scope('h%d' % len(layers) + suffix):
                    if i == 0:
                        h = layers[-1][-1]
                    else:
                        h = tf.concat([
                            layers[-1][-1],
                            layers[num_encoder_layers - i - 1][-1]
                        ],
                                      axis=-1)
                    if self.hparams.where_add == 'all' or (
                            self.hparams.where_add == 'middle' and i == 0):
                        h = tile_concat([h, state_action_z[:, None, None, :]],
                                        axis=-1)
                    h = upsample_layer(h,
                                       out_channels,
                                       kernel_size=(3, 3),
                                       strides=(2, 2))
                    h = norm_layer(h)
                    h = tf.nn.relu(h)
                if use_conv_rnn:
                    conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                    with tf.variable_scope(
                            '%s_h%d' % (self.hparams.conv_rnn, len(layers)) +
                            suffix):
                        if self.hparams.where_add == 'all':
                            conv_rnn_h = tile_concat(
                                [h, state_action_z[:, None, None, :]], axis=-1)
                        else:
                            conv_rnn_h = h
                        conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                            conv_rnn_h, conv_rnn_state, out_channels)
                    new_conv_rnn_states.append(conv_rnn_state)
                layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))
            assert len(new_conv_rnn_states) == len(conv_rnn_states)

            new_states['conv_rnn_states' + suffix] = new_conv_rnn_states

            all_layers.append(layers)
            if self.hparams.shared_views:
                break

        for i in range(self.hparams.num_views):
            suffix = '%d' % i if i > 0 else ''
            if self.hparams.shared_views:
                layers, = all_layers
            else:
                layers = all_layers[i]

            image = image_views[i]
            last_images = states['last_images' + suffix][1:] + [image]
            if 'pix_distribs' in inputs:
                pix_distrib = pix_distrib_views[i]
                last_pix_distribs = states['last_pix_distribs' +
                                           suffix][1:] + [pix_distrib]

            if self.hparams.last_frames and self.hparams.num_transformed_images:
                if self.hparams.transformation == 'flow':
                    with tf.variable_scope('h%d_flow' % len(layers) + suffix):
                        h_flow = conv2d(layers[-1][-1],
                                        self.hparams.ngf,
                                        kernel_size=(3, 3),
                                        strides=(1, 1))
                        h_flow = norm_layer(h_flow)
                        h_flow = tf.nn.relu(h_flow)

                    with tf.variable_scope('flows' + suffix):
                        flows = conv2d(h_flow,
                                       2 * self.hparams.last_frames *
                                       self.hparams.num_transformed_images,
                                       kernel_size=(3, 3),
                                       strides=(1, 1))
                        flows = tf.reshape(flows, [
                            batch_size, height, width, 2,
                            self.hparams.last_frames *
                            self.hparams.num_transformed_images
                        ])
                else:
                    assert len(self.hparams.kernel_size) == 2
                    kernel_shape = list(self.hparams.kernel_size) + [
                        self.hparams.last_frames *
                        self.hparams.num_transformed_images
                    ]
                    if self.hparams.transformation == 'dna':
                        with tf.variable_scope('h%d_dna_kernel' % len(layers) +
                                               suffix):
                            h_dna_kernel = conv2d(layers[-1][-1],
                                                  self.hparams.ngf,
                                                  kernel_size=(3, 3),
                                                  strides=(1, 1))
                            h_dna_kernel = norm_layer(h_dna_kernel)
                            h_dna_kernel = tf.nn.relu(h_dna_kernel)

                        # Using largest hidden state for predicting untied conv kernels.
                        with tf.variable_scope('dna_kernels' + suffix):
                            kernels = conv2d(h_dna_kernel,
                                             np.prod(kernel_shape),
                                             kernel_size=(3, 3),
                                             strides=(1, 1))
                            kernels = tf.reshape(kernels,
                                                 [batch_size, height, width] +
                                                 kernel_shape)
                            kernels = kernels + identity_kernel(
                                self.hparams.kernel_size)[None, None,
                                                          None, :, :, None]
                        kernel_spatial_axes = [3, 4]
                    elif self.hparams.transformation == 'cdna':
                        with tf.variable_scope('cdna_kernels' + suffix):
                            smallest_layer = layers[num_encoder_layers - 1][-1]
                            kernels = dense(flatten(smallest_layer),
                                            np.prod(kernel_shape))
                            kernels = tf.reshape(kernels,
                                                 [batch_size] + kernel_shape)
                            kernels = kernels + identity_kernel(
                                self.hparams.kernel_size)[None, :, :, None]
                        kernel_spatial_axes = [1, 2]
                    else:
                        raise ValueError('Invalid transformation %s' %
                                         self.hparams.transformation)

                if self.hparams.transformation != 'flow':
                    with tf.name_scope('kernel_normalization' + suffix):
                        kernels = tf.nn.relu(kernels - RELU_SHIFT) + RELU_SHIFT
                        kernels /= tf.reduce_sum(kernels,
                                                 axis=kernel_spatial_axes,
                                                 keepdims=True)

            if self.hparams.generate_scratch_image:
                with tf.variable_scope('h%d_scratch' % len(layers) + suffix):
                    h_scratch = conv2d(layers[-1][-1],
                                       self.hparams.ngf,
                                       kernel_size=(3, 3),
                                       strides=(1, 1))
                    h_scratch = norm_layer(h_scratch)
                    h_scratch = tf.nn.relu(h_scratch)

                # Using largest hidden state for predicting a new image layer.
                # This allows the network to also generate one image from scratch,
                # which is useful when regions of the image become unoccluded.
                with tf.variable_scope('scratch_image' + suffix):
                    scratch_image = conv2d(h_scratch,
                                           color_channels,
                                           kernel_size=(3, 3),
                                           strides=(1, 1))
                    scratch_image = tf.nn.sigmoid(scratch_image)

            with tf.name_scope('transformed_images' + suffix):
                transformed_images = []
                if self.hparams.last_frames and self.hparams.num_transformed_images:
                    if self.hparams.transformation == 'flow':
                        transformed_images.extend(
                            apply_flows(last_images, flows))
                    else:
                        transformed_images.extend(
                            apply_kernels(last_images, kernels,
                                          self.hparams.dilation_rate))
                if self.hparams.prev_image_background:
                    transformed_images.append(image)
                if self.hparams.first_image_background and not self.hparams.context_images_background:
                    transformed_images.append(self.inputs['images' +
                                                          suffix][0])
                if self.hparams.context_images_background:
                    transformed_images.extend(
                        tf.unstack(
                            self.inputs['images' +
                                        suffix][:self.hparams.context_frames]))
                if self.hparams.generate_scratch_image:
                    transformed_images.append(scratch_image)

            if 'pix_distribs' in inputs:
                with tf.name_scope('transformed_pix_distribs' + suffix):
                    transformed_pix_distribs = []
                    if self.hparams.last_frames and self.hparams.num_transformed_images:
                        if self.hparams.transformation == 'flow':
                            transformed_pix_distribs.extend(
                                apply_flows(last_pix_distribs, flows))
                        else:
                            transformed_pix_distribs.extend(
                                apply_kernels(last_pix_distribs, kernels,
                                              self.hparams.dilation_rate))
                    if self.hparams.prev_image_background:
                        transformed_pix_distribs.append(pix_distrib)
                    if self.hparams.first_image_background and not self.hparams.context_images_background:
                        transformed_pix_distribs.append(
                            self.inputs['pix_distribs' + suffix][0])
                    if self.hparams.context_images_background:
                        transformed_pix_distribs.extend(
                            tf.unstack(self.inputs['pix_distribs' + suffix]
                                       [:self.hparams.context_frames]))
                    if self.hparams.generate_scratch_image:
                        transformed_pix_distribs.append(pix_distrib)

            with tf.name_scope('masks' + suffix):
                if len(transformed_images) > 1:
                    with tf.variable_scope('h%d_masks' % len(layers) + suffix):
                        h_masks = conv2d(layers[-1][-1],
                                         self.hparams.ngf,
                                         kernel_size=(3, 3),
                                         strides=(1, 1))
                        h_masks = norm_layer(h_masks)
                        h_masks = tf.nn.relu(h_masks)

                    with tf.variable_scope('masks' + suffix):
                        if self.hparams.dependent_mask:
                            h_masks = tf.concat([h_masks] + transformed_images,
                                                axis=-1)
                        masks = conv2d(h_masks,
                                       len(transformed_images),
                                       kernel_size=(3, 3),
                                       strides=(1, 1))
                        masks = tf.nn.softmax(masks)
                        masks = tf.split(masks,
                                         len(transformed_images),
                                         axis=-1)
                elif len(transformed_images) == 1:
                    masks = [tf.ones([batch_size, height, width, 1])]
                else:
                    raise ValueError(
                        "Either one of the following should be true: "
                        "last_frames and num_transformed_images, first_image_background, "
                        "prev_image_background, generate_scratch_image")

            with tf.name_scope('gen_images' + suffix):
                assert len(transformed_images) == len(masks)
                gen_image = tf.add_n([
                    transformed_image * mask for transformed_image, mask in
                    zip(transformed_images, masks)
                ])

            if 'pix_distribs' in inputs:
                with tf.name_scope('gen_pix_distribs' + suffix):
                    assert len(transformed_pix_distribs) == len(masks)
                    gen_pix_distrib = tf.add_n([
                        transformed_pix_distrib * mask
                        for transformed_pix_distrib, mask in zip(
                            transformed_pix_distribs, masks)
                    ])

                    if self.hparams.renormalize_pixdistrib:
                        gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib,
                                                         axis=(1, 2),
                                                         keepdims=True)

            outputs['gen_images' + suffix] = gen_image
            outputs['transformed_images' + suffix] = tf.stack(
                transformed_images, axis=-1)
            outputs['masks' + suffix] = tf.stack(masks, axis=-1)
            if 'pix_distribs' in inputs:
                outputs['gen_pix_distribs' + suffix] = gen_pix_distrib
                outputs['transformed_pix_distribs' + suffix] = tf.stack(
                    transformed_pix_distribs, axis=-1)
            if self.hparams.transformation == 'flow':
                outputs['gen_flows' + suffix] = flows
                flows_transposed = tf.transpose(flows, [0, 1, 2, 4, 3])
                flows_rgb_transposed = tf_utils.flow_to_rgb(flows_transposed)
                flows_rgb = tf.transpose(flows_rgb_transposed, [0, 1, 2, 4, 3])
                outputs['gen_flows_rgb' + suffix] = flows_rgb

            new_states['gen_image' + suffix] = gen_image
            new_states['last_images' + suffix] = last_images
            if 'pix_distribs' in inputs:
                new_states['gen_pix_distrib' + suffix] = gen_pix_distrib
                new_states['last_pix_distribs' + suffix] = last_pix_distribs

        if 'states' in inputs:
            with tf.name_scope('gen_states'):
                with tf.variable_scope('state_pred'):
                    gen_state = dense(state_action,
                                      inputs['states'].shape[-1].value)

        if 'states' in inputs:
            outputs['gen_states'] = gen_state

        new_states['time'] = time + 1
        if 'zs' in inputs and self.hparams.use_rnn_z:
            new_states['rnn_z_state'] = rnn_z_state
        if 'states' in inputs:
            new_states['gen_state'] = gen_state
        return outputs, new_states
예제 #13
0
    def call(self, inputs, states):
        """Run one recurrent step of the flow-warping image prediction cell.

        Per step, the cell selects (via ``self.ground_truth[t]``) between the
        ground-truth image and the previously generated one, encodes the
        window of recent frames together with actions/states/latents, predicts
        one 2-channel flow per remembered frame, accumulates those flows onto
        the "base" frames, warps the base frames with the accumulated flows and
        blends the warped frames with softmax masks.

        Args:
            inputs: dict of per-step tensors. Must contain ``'images'`` (its
                static shape is unpacked as batch, height, width, channels);
                may contain ``'actions'``, ``'states'``, ``'zs'`` and
                ``'pix_distribs'``.
            states: dict of recurrent state from the previous step with keys
                ``'time'``, ``'last_gt_images'``, ``'last_pred_flows'``,
                ``'gen_image'``, ``'last_images'``, ``'last_base_images'``,
                ``'last_base_flows'``, ``'conv_rnn_states'`` and, depending on
                the inputs/hparams, ``'rnn_z_state'``,
                ``'last_gt_pix_distribs'``, ``'last_base_pix_distribs'`` and
                ``'gen_state'``.

        Returns:
            ``(outputs, new_states)``. ``outputs`` holds ``'gen_images'``,
            ``'gen_inputs'``, ``'transformed_images'``, ``'masks'``,
            ``'gen_flow'`` and, when present in the inputs, the pix-distrib
            and state predictions. ``new_states`` mirrors ``states``.

        Raises:
            ValueError: if no transformed images were produced.
        """
        norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
        image_shape = inputs['images'].get_shape().as_list()
        batch_size, height, width, color_channels = image_shape
        conv_rnn_states = states['conv_rnn_states']

        time = states['time']
        # Every entry of `time` must be equal; the first one is used as the
        # scalar step index into self.ground_truth.
        with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
            t = tf.to_int32(tf.identity(time[0]))

        # Slide the windows by one step: drop the oldest entry, append the
        # newest (a zero flow is appended for the newest frame).
        last_gt_images = states['last_gt_images'][1:] + [inputs['images']]
        last_pred_flows = states['last_pred_flows'][1:] + [
            tf.zeros_like(states['last_pred_flows'][-1])
        ]
        # Scheduled sampling: ground-truth image where self.ground_truth[t]
        # is true, otherwise the image generated at the previous step.
        image = tf.where(self.ground_truth[t], inputs['images'],
                         states['gen_image'])
        last_images = states['last_images'][1:] + [image]
        # When ground truth is fed, take the ground-truth frames and the
        # current predicted flows as the new bases; otherwise keep the bases.
        last_base_images = [
            tf.where(self.ground_truth[t], last_gt_image, last_base_image)
            for last_gt_image, last_base_image in zip(
                last_gt_images, states['last_base_images'])
        ]
        last_base_flows = [
            tf.where(self.ground_truth[t], last_pred_flow, last_base_flow)
            for last_pred_flow, last_base_flow in zip(
                last_pred_flows, states['last_base_flows'])
        ]
        if 'pix_distribs' in inputs:
            last_gt_pix_distribs = states['last_gt_pix_distribs'][1:] + [
                inputs['pix_distribs']
            ]
            last_base_pix_distribs = [
                tf.where(self.ground_truth[t], last_gt_pix_distrib,
                         last_base_pix_distrib)
                for last_gt_pix_distrib, last_base_pix_distrib in zip(
                    last_gt_pix_distribs, states['last_base_pix_distribs'])
            ]
        if 'states' in inputs:
            state = tf.where(self.ground_truth[t], inputs['states'],
                             states['gen_state'])

        state_action = []
        state_action_z = []
        if 'actions' in inputs:
            state_action.append(inputs['actions'])
            state_action_z.append(inputs['actions'])
        if 'states' in inputs:
            state_action.append(state)
            # don't backpropagate the convnet through the state dynamics
            state_action_z.append(tf.stop_gradient(state))

        if 'zs' in inputs:
            if self.hparams.use_rnn_z:
                with tf.variable_scope('%s_z' % self.hparams.rnn):
                    rnn_z, rnn_z_state = self._rnn_func(
                        inputs['zs'], states['rnn_z_state'], self.hparams.nz)
                state_action_z.append(rnn_z)
            else:
                state_action_z.append(inputs['zs'])

        def concat(tensors, axis):
            # Empty list -> zero-width [batch_size, 0] tensor so it can still
            # be tiled/concatenated below; a single tensor is returned as is.
            if len(tensors) == 0:
                return tf.zeros([batch_size, 0])
            elif len(tensors) == 1:
                return tensors[0]
            else:
                return tf.concat(tensors, axis=axis)

        state_action = concat(state_action, axis=-1)
        state_action_z = concat(state_action_z, axis=-1)
        if 'actions' in inputs:
            gen_input = tile_concat(last_images +
                                    [inputs['actions'][:, None, None, :]],
                                    axis=-1)
        else:
            # tf.concat has no default axis; concatenate along the last axis
            # to match the tile_concat branch above.
            gen_input = tf.concat(last_images, axis=-1)

        # Encoder: strided conv_pool2d layers, each conditioned on
        # state_action_z tiled over space, optionally followed by a conv RNN.
        layers = []
        new_conv_rnn_states = []
        for i, (out_channels,
                use_conv_rnn) in enumerate(self.encoder_layer_specs):
            with tf.variable_scope('h%d' % i):
                if i == 0:
                    h = tf.concat(last_images, axis=-1)
                    kernel_size = (5, 5)
                else:
                    h = layers[-1][-1]
                    kernel_size = (3, 3)
                h = conv_pool2d(tile_concat(
                    [h, state_action_z[:, None, None, :]], axis=-1),
                                out_channels,
                                kernel_size=kernel_size,
                                strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, i)):
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        tile_concat([h, state_action_z[:, None, None, :]],
                                    axis=-1), conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))

        # Decoder: upsampling layers with skip connections to the mirrored
        # encoder layer (all but the first decoder layer).
        num_encoder_layers = len(layers)
        for i, (out_channels,
                use_conv_rnn) in enumerate(self.decoder_layer_specs):
            with tf.variable_scope('h%d' % len(layers)):
                if i == 0:
                    h = layers[-1][-1]
                else:
                    h = tf.concat([
                        layers[-1][-1], layers[num_encoder_layers - i - 1][-1]
                    ],
                                  axis=-1)
                h = upsample_conv2d(tile_concat(
                    [h, state_action_z[:, None, None, :]], axis=-1),
                                    out_channels,
                                    kernel_size=(3, 3),
                                    strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' %
                                       (self.hparams.conv_rnn, len(layers))):
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        tile_concat([h, state_action_z[:, None, None, :]],
                                    axis=-1), conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))
        assert len(new_conv_rnn_states) == len(conv_rnn_states)

        with tf.variable_scope('h%d_flow' % len(layers)):
            h_flow = conv2d(layers[-1][-1],
                            self.hparams.ngf,
                            kernel_size=(3, 3),
                            strides=(1, 1))
            h_flow = norm_layer(h_flow)
            h_flow = tf.nn.relu(h_flow)

        with tf.variable_scope('flows'):
            # One 2-channel flow per remembered frame, laid out as
            # [batch, height, width, 2, last_frames].
            flows = conv2d(h_flow,
                           2 * self.hparams.last_frames,
                           kernel_size=(3, 3),
                           strides=(1, 1))
            flows = tf.reshape(
                flows,
                [batch_size, height, width, 2, self.hparams.last_frames])

        with tf.name_scope('transformed_images'):
            transformed_images = []
            # Compose the new flow with the warped previous flow so each
            # entry maps the current frame back to its base frame.
            last_pred_flows = [
                flow + flow_ops.image_warp(last_pred_flow, flow)
                for last_pred_flow, flow in zip(last_pred_flows,
                                                tf.unstack(flows, axis=-1))
            ]
            last_base_flows = [
                flow + flow_ops.image_warp(last_base_flow, flow)
                for last_base_flow, flow in zip(last_base_flows,
                                                tf.unstack(flows, axis=-1))
            ]
            for last_base_image, last_base_flow in zip(last_base_images,
                                                       last_base_flows):
                transformed_images.append(
                    flow_ops.image_warp(last_base_image, last_base_flow))

        if 'pix_distribs' in inputs:
            with tf.name_scope('transformed_pix_distribs'):
                transformed_pix_distribs = []
                for last_base_pix_distrib, last_base_flow in zip(
                        last_base_pix_distribs, last_base_flows):
                    transformed_pix_distribs.append(
                        flow_ops.image_warp(last_base_pix_distrib,
                                            last_base_flow))

        with tf.name_scope('masks'):
            if len(transformed_images) > 1:
                with tf.variable_scope('h%d_masks' % len(layers)):
                    h_masks = conv2d(layers[-1][-1],
                                     self.hparams.ngf,
                                     kernel_size=(3, 3),
                                     strides=(1, 1))
                    h_masks = norm_layer(h_masks)
                    h_masks = tf.nn.relu(h_masks)

                with tf.variable_scope('masks'):
                    if self.hparams.dependent_mask:
                        h_masks = tf.concat([h_masks] + transformed_images,
                                            axis=-1)
                    # Softmax over the mask channels, then one single-channel
                    # mask per transformed image.
                    masks = conv2d(h_masks,
                                   len(transformed_images),
                                   kernel_size=(3, 3),
                                   strides=(1, 1))
                    masks = tf.nn.softmax(masks)
                    masks = tf.split(masks, len(transformed_images), axis=-1)
            elif len(transformed_images) == 1:
                masks = [tf.ones([batch_size, height, width, 1])]
            else:
                raise ValueError(
                    "Either one of the following should be true: "
                    "last_frames and num_transformed_images, first_image_background, "
                    "prev_image_background, generate_scratch_image")

        with tf.name_scope('gen_images'):
            assert len(transformed_images) == len(masks)
            gen_image = tf.add_n([
                transformed_image * mask
                for transformed_image, mask in zip(transformed_images, masks)
            ])

        if 'pix_distribs' in inputs:
            with tf.name_scope('gen_pix_distribs'):
                assert len(transformed_pix_distribs) == len(masks)
                gen_pix_distrib = tf.add_n([
                    transformed_pix_distrib * mask
                    for transformed_pix_distrib, mask in zip(
                        transformed_pix_distribs, masks)
                ])
                # TODO: is this needed?
                # gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib, axis=(1, 2), keepdims=True)

        if 'states' in inputs:
            with tf.name_scope('gen_states'):
                with tf.variable_scope('state_pred'):
                    gen_state = dense(state_action,
                                      inputs['states'].shape[-1].value)

        outputs = {
            'gen_images': gen_image,
            'gen_inputs': gen_input,
            'transformed_images': tf.stack(transformed_images, axis=-1),
            'masks': tf.stack(masks, axis=-1),
            'gen_flow': flows
        }
        if 'pix_distribs' in inputs:
            outputs['gen_pix_distribs'] = gen_pix_distrib
            outputs['transformed_pix_distribs'] = tf.stack(
                transformed_pix_distribs, axis=-1)
        if 'states' in inputs:
            outputs['gen_states'] = gen_state

        new_states = {
            'time': time + 1,
            'last_gt_images': last_gt_images,
            'last_pred_flows': last_pred_flows,
            'gen_image': gen_image,
            'last_images': last_images,
            'last_base_images': last_base_images,
            'last_base_flows': last_base_flows,
            'conv_rnn_states': new_conv_rnn_states
        }
        if 'zs' in inputs and self.hparams.use_rnn_z:
            new_states['rnn_z_state'] = rnn_z_state
        if 'pix_distribs' in inputs:
            new_states['last_gt_pix_distribs'] = last_gt_pix_distribs
            new_states['last_base_pix_distribs'] = last_base_pix_distribs
        if 'states' in inputs:
            new_states['gen_state'] = gen_state
        return outputs, new_states
예제 #14
0
    def call(self, inputs, states):
        """Run one recurrent step of the feature-space prediction cell.

        The cell selects (via ``self.ground_truth[t]``) between the
        ground-truth feature map and the previously generated one. It encodes
        and decodes that map, conditioned on actions/states/latents, and
        produces the next feature map. The next map comes either directly from
        a conv layer (``transformation == 'direct'``) or by transforming each
        feature channel independently with a predicted flow (``'flow'``), a
        per-pixel kernel (``'local'``) or a global kernel (``'conv'``).

        Args:
            inputs: dict of per-step tensors. Must contain ``'features'`` (its
                static shape is unpacked as batch, height, width, channels);
                may contain ``'actions'``, ``'states'`` and ``'zs'``.
            states: dict with ``'time'``, ``'gen_feature'``,
                ``'conv_rnn_states'`` and, depending on inputs/hparams,
                ``'rnn_z_state'`` and ``'gen_state'``.

        Returns:
            ``(outputs, new_states)``. ``outputs`` holds ``'gen_features'``,
            ``'gen_inputs'`` and, when applicable, ``'gen_states'`` and
            ``'gen_flows'``. ``new_states`` mirrors ``states``.

        Raises:
            ValueError: if ``hparams.transformation`` is not one of
                ``'direct'``, ``'flow'``, ``'local'`` or ``'conv'``.
        """
        norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
        feature_shape = inputs['features'].get_shape().as_list()
        batch_size, height, width, feature_channels = feature_shape
        conv_rnn_states = states['conv_rnn_states']

        time = states['time']
        # Every entry of `time` must be equal; the first one is used as the
        # scalar step index into self.ground_truth.
        with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
            t = tf.to_int32(tf.identity(time[0]))

        feature = tf.where(self.ground_truth[t], inputs['features'],
                           states['gen_feature'])  # schedule sampling (if any)
        if 'states' in inputs:
            state = tf.where(self.ground_truth[t], inputs['states'],
                             states['gen_state'])

        state_action = []
        state_action_z = []
        if 'actions' in inputs:
            state_action.append(inputs['actions'])
            state_action_z.append(inputs['actions'])
        if 'states' in inputs:
            state_action.append(state)
            # don't backpropagate the convnet through the state dynamics
            state_action_z.append(tf.stop_gradient(state))

        if 'zs' in inputs:
            if self.hparams.use_rnn_z:
                with tf.variable_scope('%s_z' % self.hparams.rnn):
                    rnn_z, rnn_z_state = self._rnn_func(
                        inputs['zs'], states['rnn_z_state'], self.hparams.nz)
                state_action_z.append(rnn_z)
            else:
                state_action_z.append(inputs['zs'])

        def concat(tensors, axis):
            # Empty list -> zero-width [batch_size, 0] tensor so it can still
            # be tiled/concatenated below; a single tensor is returned as is.
            if len(tensors) == 0:
                return tf.zeros([batch_size, 0])
            elif len(tensors) == 1:
                return tensors[0]
            else:
                return tf.concat(tensors, axis=axis)

        state_action = concat(state_action, axis=-1)
        state_action_z = concat(state_action_z, axis=-1)
        if 'actions' in inputs:
            gen_input = tile_concat(
                [feature, inputs['actions'][:, None, None, :]], axis=-1)
        else:
            gen_input = feature

        # Encoder: strided conv_pool2d layers conditioned on state_action_z,
        # optionally followed by a conv RNN.
        layers = []
        new_conv_rnn_states = []
        for i, (out_channels,
                use_conv_rnn) in enumerate(self.encoder_layer_specs):
            with tf.variable_scope('h%d' % i):
                if i == 0:
                    # h = tf.concat([feature, self.inputs['features'][0]], axis=-1)  # TODO: use first feature?
                    h = feature
                else:
                    h = layers[-1][-1]
                h = conv_pool2d(tile_concat(
                    [h, state_action_z[:, None, None, :]], axis=-1),
                                out_channels,
                                kernel_size=(3, 3),
                                strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, i)):
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        tile_concat([h, state_action_z[:, None, None, :]],
                                    axis=-1), conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))

        # Decoder: upsampling layers with skip connections to the mirrored
        # encoder layer (all but the first decoder layer).
        num_encoder_layers = len(layers)
        for i, (out_channels,
                use_conv_rnn) in enumerate(self.decoder_layer_specs):
            with tf.variable_scope('h%d' % len(layers)):
                if i == 0:
                    h = layers[-1][-1]
                else:
                    h = tf.concat([
                        layers[-1][-1], layers[num_encoder_layers - i - 1][-1]
                    ],
                                  axis=-1)
                h = upsample_conv2d(tile_concat(
                    [h, state_action_z[:, None, None, :]], axis=-1),
                                    out_channels,
                                    kernel_size=(3, 3),
                                    strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' %
                                       (self.hparams.conv_rnn, len(layers))):
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        tile_concat([h, state_action_z[:, None, None, :]],
                                    axis=-1), conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))
        assert len(new_conv_rnn_states) == len(conv_rnn_states)

        if self.hparams.transformation == 'direct':
            with tf.variable_scope('h%d_direct' % len(layers)):
                h_direct = conv2d(layers[-1][-1],
                                  self.hparams.ngf,
                                  kernel_size=(3, 3),
                                  strides=(1, 1))
                h_direct = norm_layer(h_direct)
                h_direct = tf.nn.relu(h_direct)

            with tf.variable_scope('direct'):
                gen_feature = conv2d(h_direct,
                                     feature_channels,
                                     kernel_size=(3, 3),
                                     strides=(1, 1))
        else:
            if self.hparams.transformation == 'flow':
                with tf.variable_scope('h%d_flow' % len(layers)):
                    h_flow = conv2d(layers[-1][-1],
                                    self.hparams.ngf,
                                    kernel_size=(3, 3),
                                    strides=(1, 1))
                    h_flow = norm_layer(h_flow)
                    h_flow = tf.nn.relu(h_flow)

                with tf.variable_scope('flows'):
                    # One 2-channel flow per feature channel:
                    # [batch, height, width, 2, feature_channels].
                    flows = conv2d(h_flow,
                                   2 * feature_channels,
                                   kernel_size=(3, 3),
                                   strides=(1, 1))
                    flows = tf.reshape(
                        flows,
                        [batch_size, height, width, 2, feature_channels])
                transformations = flows
            else:
                assert len(self.hparams.kernel_size) == 2
                kernel_shape = list(
                    self.hparams.kernel_size) + [feature_channels]
                if self.hparams.transformation == 'local':
                    with tf.variable_scope('h%d_local_kernel' % len(layers)):
                        h_local_kernel = conv2d(layers[-1][-1],
                                                self.hparams.ngf,
                                                kernel_size=(3, 3),
                                                strides=(1, 1))
                        h_local_kernel = norm_layer(h_local_kernel)
                        h_local_kernel = tf.nn.relu(h_local_kernel)

                    # Using largest hidden state for predicting untied conv kernels.
                    with tf.variable_scope('local_kernels'):
                        kernels = conv2d(h_local_kernel,
                                         np.prod(kernel_shape),
                                         kernel_size=(3, 3),
                                         strides=(1, 1))
                        kernels = tf.reshape(kernels,
                                             [batch_size, height, width] +
                                             kernel_shape)
                        # Predicted kernels are residuals on the identity kernel.
                        kernels = kernels + identity_kernel(
                            self.hparams.kernel_size)[None, None, None, :, :,
                                                      None]
                elif self.hparams.transformation == 'conv':
                    with tf.variable_scope('conv_kernels'):
                        smallest_layer = layers[num_encoder_layers - 1][-1]
                        kernels = dense(flatten(smallest_layer),
                                        np.prod(kernel_shape))
                        kernels = tf.reshape(kernels,
                                             [batch_size] + kernel_shape)
                        # Predicted kernels are residuals on the identity kernel.
                        kernels = kernels + identity_kernel(
                            self.hparams.kernel_size)[None, :, :, None]
                else:
                    raise ValueError('Invalid transformation %s' %
                                     self.hparams.transformation)
                transformations = kernels

            with tf.name_scope('gen_features'):
                # Both branches map one feature channel [batch, h, w] and its
                # transformation to a transformed channel [batch, h, w], so the
                # restacked result has the same shape as `feature`.
                if self.hparams.transformation == 'flow':

                    def apply_transformation(feature_and_flow):
                        feature_channel, flow = feature_and_flow
                        warped = flow_ops.image_warp(feature_channel[..., None],
                                                     flow)
                        # image_warp keeps the singleton channel axis added
                        # above; drop it to match the kernel branch.
                        return tf.squeeze(warped, axis=-1)
                else:

                    def apply_transformation(feature_and_kernel):
                        feature_channel, kernel = feature_and_kernel
                        output, = apply_kernels(feature_channel[..., None],
                                                kernel[..., None])
                        return tf.squeeze(output, axis=-1)

                # Move the channel axis to the front so map_fn iterates over
                # channels, then move it back to the end.
                gen_feature_transposed = tf.map_fn(
                    apply_transformation,
                    (tf.stack(tf.unstack(feature, axis=-1)),
                     tf.stack(tf.unstack(transformations, axis=-1))),
                    dtype=tf.float32)
                gen_feature = tf.stack(tf.unstack(gen_feature_transposed),
                                       axis=-1)

        # TODO: use norm and relu for generated features?
        gen_feature = norm_layer(gen_feature)
        gen_feature = tf.nn.relu(gen_feature)

        if 'states' in inputs:
            with tf.name_scope('gen_states'):
                with tf.variable_scope('state_pred'):
                    gen_state = dense(state_action,
                                      inputs['states'].shape[-1].value)

        outputs = {
            'gen_features': gen_feature,
            'gen_inputs': gen_input,
        }
        if 'states' in inputs:
            outputs['gen_states'] = gen_state
        if self.hparams.transformation == 'flow':
            outputs['gen_flows'] = flows

        new_states = {
            'time': time + 1,
            'gen_feature': gen_feature,
            'conv_rnn_states': new_conv_rnn_states,
        }
        if 'zs' in inputs and self.hparams.use_rnn_z:
            new_states['rnn_z_state'] = rnn_z_state
        if 'states' in inputs:
            new_states['gen_state'] = gen_state
        return outputs, new_states
예제 #15
0
def create_legacy_discriminator(discrim_targets,
                                discrim_inputs=None,
                                ndf=64,
                                norm_layer='instance',
                                downsample_layer='conv_pool2d'):
    """Build a fully convolutional discriminator with a 1-channel logit map.

    ``discrim_targets`` (and ``discrim_inputs`` if given) are concatenated on
    the last axis. A stack of 4x4 convolutions with leaky ReLU (slope 0.2) is
    applied, and every layer except the first and the last is normalized. The
    layer strides depend on the smaller of the two spatial dimensions (axes 1
    and 2) of the concatenated input.

    Args:
        discrim_targets: image tensor to classify.
        discrim_inputs: optional conditioning tensor concatenated to
            ``discrim_targets`` along the last axis.
        ndf: number of filters in the first layer; later layers use multiples.
        norm_layer: name passed to ``ops.get_norm_layer``.
        downsample_layer: name passed to ``ops.get_downsample_layer``; used for
            every layer with stride != 1.

    Returns:
        List of every layer's output tensor. The last element is the raw logits
        (no sigmoid is applied).

    Raises:
        NotImplementedError: if the smaller spatial size is not 256, 128 or 64.
    """
    norm_layer = ops.get_norm_layer(norm_layer)
    downsample_layer = ops.get_downsample_layer(downsample_layer)

    layers = []
    inputs = [discrim_targets]
    if discrim_inputs is not None:
        inputs.append(discrim_inputs)
    inputs = tf.concat(inputs, axis=-1)

    # Each spec is (out_channels, strides).
    scale_size = min(*inputs.shape.as_list()[1:3])
    if scale_size == 256:
        layer_specs = [
            (ndf, 2),  # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
            (ndf * 2, 2),  # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
            (ndf * 4, 2),  # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
            (ndf * 8, 1),  # layer_4: [batch, 32, 32, ndf * 4] => [batch, 32, 32, ndf * 8]
            (1, 1),  # layer_5: [batch, 32, 32, ndf * 8] => [batch, 32, 32, 1]
        ]
    elif scale_size == 128:
        layer_specs = [
            (ndf, 2),
            (ndf * 2, 2),
            (ndf * 4, 1),
            (ndf * 8, 1),
            (1, 1),
        ]
    elif scale_size == 64:
        layer_specs = [
            (ndf, 2),
            (ndf * 2, 1),
            (ndf * 4, 1),
            (ndf * 8, 1),
            (1, 1),
        ]
    else:
        raise NotImplementedError(
            'Unsupported discriminator input size %r: the smaller spatial '
            'dimension must be 256, 128 or 64' % (scale_size, ))

    # First layer: no normalization.
    with tf.variable_scope("layer_1"):
        out_channels, strides = layer_specs[0]
        convolved = downsample_layer(inputs,
                                     out_channels,
                                     kernel_size=4,
                                     strides=strides)
        rectified = lrelu(convolved, 0.2)
        layers.append(rectified)

    # Middle layers: conv (or downsample) -> norm -> leaky ReLU.
    for out_channels, strides in layer_specs[1:-1]:
        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
            if strides == 1:
                convolved = conv2d(layers[-1], out_channels, kernel_size=4)
            else:
                convolved = downsample_layer(layers[-1],
                                             out_channels,
                                             kernel_size=4,
                                             strides=strides)
            normalized = norm_layer(convolved)
            rectified = lrelu(normalized, 0.2)
            layers.append(rectified)

    # Last layer: raw logits from the previous layer's output.
    with tf.variable_scope("layer_%d" % (len(layers) + 1)):
        out_channels, strides = layer_specs[-1]
        if strides == 1:
            logits = conv2d(layers[-1], out_channels, kernel_size=4)
        else:
            logits = downsample_layer(layers[-1],
                                      out_channels,
                                      kernel_size=4,
                                      strides=strides)
        # don't apply sigmoid to the logits in case we want to use LSGAN
        layers.append(logits)

    return layers
예제 #16
0
def create_generator(generator_inputs,
                     output_nc=3,
                     ngf=64,
                     norm_layer='instance',
                     downsample_layer='conv_pool2d',
                     upsample_layer='upsample_conv2d'):
    """Build a U-Net style encoder/decoder generator with skip connections.

    The number of encoder/decoder stages is chosen from the smaller of
    dimensions 1 and 2 of ``generator_inputs``, which must be 64, 128 or 256.

    Args:
        generator_inputs: 4-D input tensor. Dimensions 1 and 2 are read as the
            spatial size; the skip connections concatenate on axis 3, so this
            presumably expects NHWC layout ([batch, height, width, channels]).
        output_nc: number of channels of the generated output.
        ngf: base number of filters; deeper layers use multiples of it.
        norm_layer: name resolved through ``ops.get_norm_layer``.
        downsample_layer: name resolved through ``ops.get_downsample_layer``;
            used by encoder layers whose stride is not 1.
        upsample_layer: name resolved through ``ops.get_upsample_layer``;
            used by decoder layers whose stride is not 1.

    Returns:
        The generated tensor with ``output_nc`` channels, passed through
        ``tanh`` and rescaled from [-1, 1] to [0, 1].

    Raises:
        NotImplementedError: if the spatial scale is not 64, 128 or 256.
    """
    norm_layer = ops.get_norm_layer(norm_layer)
    downsample_layer = ops.get_downsample_layer(downsample_layer)
    upsample_layer = ops.get_upsample_layer(upsample_layer)

    def _conv_or_resample(x, out_channels, strides, resample_layer):
        # Stride-1 layers keep the resolution with a plain 4x4 conv; any other
        # stride goes through the configured down/upsampling layer.
        if strides == 1:
            return conv2d(x, out_channels, kernel_size=4)
        return resample_layer(x,
                              out_channels,
                              kernel_size=4,
                              strides=strides)

    layers = []
    inputs = generator_inputs

    scale_size = min(*inputs.shape.as_list()[1:3])
    # Encoder specs: (out_channels, strides).
    if scale_size == 256:
        encoder_specs = [
            (ngf, 2),      # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
            (ngf * 2, 2),  # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
            (ngf * 4, 2),  # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
            (ngf * 8, 2),  # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
            (ngf * 8, 2),  # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
            (ngf * 8, 2),  # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
            (ngf * 8, 2),  # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
            (ngf * 8, 2),  # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
        ]
    elif scale_size == 128:
        encoder_specs = [
            (ngf, 2),
            (ngf * 2, 2),
            (ngf * 4, 2),
            (ngf * 8, 2),
            (ngf * 8, 2),
            (ngf * 8, 2),
            (ngf * 8, 2),
        ]
    elif scale_size == 64:
        encoder_specs = [
            (ngf, 2),
            (ngf * 2, 2),
            (ngf * 4, 2),
            (ngf * 8, 2),
            (ngf * 8, 2),
            (ngf * 8, 2),
        ]
    else:
        raise NotImplementedError

    # The first encoder layer has neither activation nor normalization.
    with tf.variable_scope("encoder_1"):
        out_channels, strides = encoder_specs[0]
        output = _conv_or_resample(inputs, out_channels, strides,
                                   downsample_layer)
        layers.append(output)

    # Remaining encoder layers: lrelu -> conv/downsample -> norm.
    for out_channels, strides in encoder_specs[1:]:
        with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
            rectified = lrelu(layers[-1], 0.2)
            convolved = _conv_or_resample(rectified, out_channels, strides,
                                          downsample_layer)
            output = norm_layer(convolved)
            layers.append(output)

    # Decoder specs: (out_channels, strides, dropout_rate). Channel counts in
    # the shape comments include the skip-connection concat.
    if scale_size == 256:
        decoder_specs = [
            (ngf * 8, 2, 0.5),    # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
            (ngf * 8, 2, 0.5),    # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
            (ngf * 8, 2, 0.5),    # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
            (ngf * 8, 2, 0.0),    # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
            (ngf * 4, 2, 0.0),    # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
            (ngf * 2, 2, 0.0),    # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
            (ngf, 2, 0.0),        # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
            (output_nc, 2, 0.0),  # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
        ]
    elif scale_size == 128:
        decoder_specs = [
            (ngf * 8, 2, 0.5),
            (ngf * 8, 2, 0.5),
            (ngf * 8, 2, 0.5),
            (ngf * 4, 2, 0.0),
            (ngf * 2, 2, 0.0),
            (ngf, 2, 0.0),
            (output_nc, 2, 0.0),
        ]
    elif scale_size == 64:
        decoder_specs = [
            (ngf * 8, 2, 0.5),
            (ngf * 8, 2, 0.5),
            (ngf * 4, 2, 0.0),
            (ngf * 2, 2, 0.0),
            (ngf, 2, 0.0),
            (output_nc, 2, 0.0),
        ]
    else:
        raise NotImplementedError

    num_encoder_layers = len(layers)
    for decoder_layer, (out_channels, strides,
                        dropout) in enumerate(decoder_specs[:-1]):
        # Decoder scopes are numbered to mirror the encoder layer they skip to.
        skip_layer = num_encoder_layers - decoder_layer - 1
        with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
            if decoder_layer == 0:
                # The first decoder layer consumes the innermost encoder
                # output directly, so there is nothing to concatenate.
                layer_input = layers[-1]
            else:
                layer_input = tf.concat([layers[-1], layers[skip_layer]],
                                        axis=3)

            rectified = tf.nn.relu(layer_input)
            # Use this decoder layer's own stride (previously a stale encoder
            # loop variable was passed here by mistake).
            output = _conv_or_resample(rectified, out_channels, strides,
                                       upsample_layer)
            output = norm_layer(output)

            # Dropout is applied unconditionally: this function has no
            # training/inference switch.
            if dropout > 0.0:
                output = tf.nn.dropout(output, keep_prob=1 - dropout)

            layers.append(output)

    with tf.variable_scope("decoder_1"):
        out_channels, strides, dropout = decoder_specs[-1]
        assert dropout == 0.0  # internal invariant: no dropout at the last layer
        layer_input = tf.concat([layers[-1], layers[0]], axis=3)
        rectified = tf.nn.relu(layer_input)
        output = _conv_or_resample(rectified, out_channels, strides,
                                   upsample_layer)
        # tanh yields [-1, 1]; shift and scale to [0, 1].
        output = tf.tanh(output)
        output = (output + 1) / 2
        layers.append(output)

    return layers[-1]