Esempio n. 1
0
def apply_dna_kernels_non_dilated(image, kernels):
    """Apply a separate (per-pixel) kernel at every image location.

    Args:
        image: A 4-D tensor with static shape
            `[batch, height, width, channels]`.
        kernels: A 6-D tensor with static shape
            `[batch, height, width, kernel_height, kernel_width,
            num_transformed_images]`.

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
        `[batch, height, width, channels]`.
    """
    batch_size, height, width, color_channels = image.get_shape().as_list()
    # The two kernel axes are unpacked separately and re-grouped into a list:
    # `kernel_size` is indexed and list-concatenated below, which a single
    # integer dimension cannot support.
    (batch_size, height, width, kernel_height, kernel_width,
     num_transformed_images) = kernels.get_shape().as_list()
    kernel_size = [kernel_height, kernel_width]
    # Flatten the spatial dimensions of each kernel.
    kernels_reshaped = tf.reshape(kernels, [
        batch_size, height, width, kernel_size[0] * kernel_size[1],
        num_transformed_images
    ])
    # NOTE(review): assumes pad2d pads so that the 'VALID' patch extraction
    # below yields exactly height x width positions (required by the reshape
    # of the patches) -- confirm against pad2d.
    image_padded = pad2d(image, kernel_size, padding='SAME', mode='SYMMETRIC')
    # Combine channel and batch dimensions into the first dimension.
    # NOTE(review): presumably flatten(x, 0, 1) merges axes 0 and 1 -- confirm.
    image_transposed = tf.transpose(image_padded, [3, 0, 1, 2])
    image_reshaped = flatten(image_transposed, 0, 1)[..., None]
    patches_reshaped = tf.extract_image_patches(image_reshaped,
                                                ksizes=[1] + kernel_size + [1],
                                                strides=[1] * 4,
                                                rates=[1] * 4,
                                                padding='VALID')
    # Separate channel and batch dimensions.
    patches = tf.reshape(patches_reshaped, [
        color_channels, batch_size, height, width,
        kernel_size[0] * kernel_size[1]
    ])
    # Reduce along the spatial dimensions of the kernel. The added axes let
    # the patches broadcast over transformations and the kernels over channels.
    outputs = tf.reduce_sum(patches[..., None] * kernels_reshaped[None, ...],
                            axis=-2)
    # Swap channel and transformation dimensions.
    outputs = tf.transpose(outputs, [4, 1, 2, 3, 0])
    # One output tensor per transformation.
    outputs = tf.unstack(outputs, axis=0)
    return outputs
def apply_dna_kernels_dilated(image, kernels, dilation_rate=(1, 1)):
    """Apply per-pixel kernels to an image with an optional dilation rate.

    The dilated operation is carried out by splitting the padded image (and
    the kernels) into `dilation_rate[0] * dilation_rate[1]` interleaved
    sub-grids, applying the non-dilated operation to each sub-grid, and
    interleaving the results back together.

    Args:
        image: A 4-D tensor with static shape
            `[batch, height, width, channels]`.
        kernels: A 6-D tensor with static shape
            `[batch, height, width, kernel_height, kernel_width,
            num_transformed_images]`.
        dilation_rate: An int, or a tuple/list of two ints. `height` and
            `width` are integer-divided by the respective rate below, so they
            are expected to be multiples of it.

    Returns:
        A list of `num_transformed_images` 4-D tensors, each of shape
        `[batch, height, width, channels]`.
    """
    dilation_rate = list(dilation_rate) if isinstance(dilation_rate, (tuple, list)) else [dilation_rate] * 2
    batch_size, height, width, color_channels = image.get_shape().as_list()
    # The two kernel axes are unpacked separately and re-grouped into a list:
    # `kernel_size` is indexed and list-concatenated below, which a single
    # integer dimension cannot support.
    (batch_size, height, width, kernel_height, kernel_width,
     num_transformed_images) = kernels.get_shape().as_list()
    kernel_size = [kernel_height, kernel_width]
    num_kernel_taps = kernel_size[0] * kernel_size[1]
    num_phases = dilation_rate[0] * dilation_rate[1]
    small_height = height // dilation_rate[0]
    small_width = width // dilation_rate[1]

    # Flatten the spatial dimensions of each kernel.
    kernels_reshaped = tf.reshape(kernels, [batch_size, height, width,
                                            num_kernel_taps, num_transformed_images])
    # NOTE(review): assumes pad2d pads so that each sub-image yields exactly
    # small_height x small_width 'VALID' patches -- confirm against pad2d.
    image_padded = pad2d(image, kernel_size, rate=dilation_rate, padding='SAME', mode='SYMMETRIC')
    padded_height, padded_width = image_padded.get_shape().as_list()[1:3]
    # for dilation = [2, 2], this is equivalent to this:
    # small_images = [image[:, 0::2, 0::2, :], image[:, 0::2, 1::2, :], image[:, 1::2, 0::2, :], image[:, 1::2, 1::2, :]]
    small_images = tf.space_to_batch_nd(image_padded, dilation_rate, paddings=[[0, 0]] * 2)
    small_images = tf.reshape(small_images, [num_phases, batch_size,
                                             padded_height // dilation_rate[0],
                                             padded_width // dilation_rate[1],
                                             color_channels])
    small_images = tf.unstack(small_images, axis=0)
    # Split the kernels into the same interleaved sub-grids as the image so
    # that every sub-image is combined with the kernels predicted for its own
    # pixels. Without this, the full-resolution kernels cannot broadcast
    # against the reduced-resolution patches when the rate exceeds 1.
    small_kernels = tf.space_to_batch_nd(kernels_reshaped, dilation_rate, paddings=[[0, 0]] * 2)
    small_kernels = tf.reshape(small_kernels, [num_phases, batch_size,
                                               small_height, small_width,
                                               num_kernel_taps, num_transformed_images])
    small_kernels = tf.unstack(small_kernels, axis=0)

    small_outputs = []
    for small_image, small_kernel in zip(small_images, small_kernels):
        # Combine channel and batch dimensions into the first dimension.
        image_transposed = tf.transpose(small_image, [3, 0, 1, 2])
        image_reshaped = flatten(image_transposed, 0, 1)[..., None]
        patches_reshaped = tf.extract_image_patches(image_reshaped, ksizes=[1] + kernel_size + [1],
                                                    strides=[1] * 4, rates=[1] * 4, padding='VALID')
        # Separate channel and batch dimensions.
        patches = tf.reshape(patches_reshaped, [color_channels, batch_size,
                                                small_height, small_width,
                                                num_kernel_taps])
        # Reduce along the spatial dimensions of the kernel.
        outputs = tf.reduce_sum(patches[..., None] * small_kernel[None, ...], axis=-2)
        # Swap channel and transformation dimensions.
        outputs = tf.transpose(outputs, [4, 1, 2, 3, 0])
        outputs = tf.unstack(outputs, axis=0)
        small_outputs.append(outputs)

    # Regroup from "per sub-grid, per transformation" to
    # "per transformation, per sub-grid".
    small_outputs = list(zip(*small_outputs))
    # Stack the sub-grids of each transformation back into the batch axis, in
    # the layout produced by space_to_batch_nd above.
    small_outputs = [tf.reshape(tf.stack(list(small_output), axis=0),
                                [num_phases * batch_size, small_height, small_width, color_channels])
                     for small_output in small_outputs]
    # Interleave the sub-grids back into full-resolution images.
    outputs = [tf.batch_to_space_nd(small_output, dilation_rate, crops=[[0, 0]] * 2) for small_output in small_outputs]
    return outputs
    def call(self, inputs, states):
        """Run one time step of the prediction cell.

        Args:
            inputs: Sequence whose first three entries are
                `(image, action, state)`. If a fourth entry is present it is
                used as the pixel distribution (`pix_distrib`).
            states: Sequence whose first four entries are
                `(lstm_states, time, gen_image, gen_state)`, where
                `lstm_states` unpacks into five per-layer LSTM states.
                Optional trailing entries, in this order: the previously
                generated pixel distribution (read when
                `self.first_pix_distrib` is not None) and the previous chained
                flow (read when `self.model == 'appflow_chained'`).

        Returns:
            A pair `(outputs, new_states)` of tuples.

            `outputs` is `(gen_image, gen_state, masks, transformed_images)`,
            followed by `flow_map` when `'compute_flow_map'` is a key of
            `self.conf`, followed by `gen_pix_distrib` and
            `transformed_pix_distribs` when `self.first_pix_distrib` is not
            None.

            `new_states` is `(new_lstm_states, time + 1, gen_image,
            gen_state)`, followed by `gen_pix_distrib` when
            `self.first_pix_distrib` is not None, followed by the chained flow
            when `self.model == 'appflow_chained'`.

        Raises:
            NotImplementedError: if `self.feedself` is false while
                `self.first_pix_distrib` is not None.
            ValueError: if `self.model` is not one of `'dna'`, `'cdna'`,
                `'appflow'` or `'appflow_chained'`.
        """

        # inputs
        image, action, state = inputs[:3]
        # Copy into a list so that pop() also works when `inputs` is a tuple.
        other_inputs = list(inputs[3:])
        if other_inputs:
            pix_distrib = other_inputs.pop(0)

        # states
        (lstm_states, time, gen_image, gen_state), other_states = states[:4], states[4:]
        lstm_state0, lstm_state1, lstm_state2, lstm_state3, lstm_state4 = lstm_states
        if other_states:
            other_states = list(other_states)
            if self.first_pix_distrib is not None:
                gen_pix_distrib = other_states.pop(0)

            # Same predicate as the one that appends this state at the end of
            # the method, so that the extra states stay aligned.
            if self.model == 'appflow_chained':
                prev_flow_t_0 = other_states.pop(0)

        image_shape = image.get_shape().as_list()
        batch_size, height, width, color_channels = image_shape

        _, state_dim = state.get_shape().as_list()
        kernel_size = self.kernel_size
        dilation_rate = self.dilation_rate
        num_transformed_images = self.num_transformed_images
        vgf_dim = self.vgf_dim

        # True once the time index has passed the context (warm-start) frames.
        done_warm_start = time > self.context_frames - 1
        if self.feedself:
            image = tf.cond(tf.reduce_all(done_warm_start),
                            lambda: gen_image,  # feed in generated image
                            lambda: image)  # feed in ground_truth
            if self.first_pix_distrib is not None:
                pix_distrib = tf.cond(tf.reduce_all(done_warm_start),
                                      lambda: gen_pix_distrib,  # feed in generated pixel distribution
                                      lambda: pix_distrib)  # feed in ground_truth
        else:
            image = tf.cond(tf.reduce_all(done_warm_start),
                            lambda: scheduled_sample(image, gen_image, batch_size, self.num_ground_truth),
                            # schedule sampling
                            lambda: image)  # feed in ground_truth
            if self.first_pix_distrib is not None:
                raise NotImplementedError
        state = tf.cond(tf.reduce_all(tf.equal(time, 0)),
                        lambda: state,  # feed in ground_truth state only for first time step
                        lambda: gen_state)  # feed in predicted state

        if 'ignore_state' in self.conf:
            state_action = tf.concat([action], axis=-1)
        else:
            state_action = tf.concat([action, state], axis=-1)

        # Encoder: three strided conv stages interleaved with LSTM layers.
        with tf.variable_scope('h0'):
            h0 = conv_pool2d(image, vgf_dim, kernel_size=(5, 5), strides=(2, 2))
            h0 = self.normalizer_fn(h0)
            h0 = tf.nn.relu(h0)

        with tf.variable_scope('lstm_h0'):
            lstm_h0, lstm_state0 = self._lstm_func(h0, lstm_state0, vgf_dim)

        with tf.variable_scope('h1'):
            h1 = conv_pool2d(lstm_h0, vgf_dim * 2, kernel_size=(3, 3), strides=(2, 2))
            h1 = self.normalizer_fn(h1)
            h1 = tf.nn.relu(h1)

        with tf.variable_scope('lstm_h1'):
            lstm_h1, lstm_state1 = self._lstm_func(h1, lstm_state1, vgf_dim * 2)

        with tf.variable_scope('h2'):
            h2 = conv_pool2d(lstm_h1, vgf_dim * 4, kernel_size=(3, 3), strides=(2, 2))
            h2 = self.normalizer_fn(h2)
            h2 = tf.nn.relu(h2)

        # Pass in state and action.
        if self.use_state:
            with tf.variable_scope('state_action_h2'):
                # Tile the state/action vector over the spatial grid of h2.
                state_action_smear = tf.tile(state_action[:, None, None, :],
                                             [1, h2.get_shape().as_list()[1], h2.get_shape().as_list()[2], 1])
                state_action_h2 = tf.concat([h2, state_action_smear], axis=-1)
                state_action_h2 = conv2d(state_action_h2, vgf_dim * 4, kernel_size=(1, 1), strides=(1, 1))
                # TODO: consider adding normalizer and relu here
        else:
            state_action_h2 = h2

        with tf.variable_scope('lstm_h2'):
            lstm_h2, lstm_state2 = self._lstm_func(state_action_h2, lstm_state2, vgf_dim * 4)

        # Decoder: upsampling stages with skip connections from h1 and h0.
        with tf.variable_scope('h3'):
            h3 = upsample_conv2d(lstm_h2, vgf_dim * 2, kernel_size=(3, 3), strides=(2, 2))
            h3 = self.normalizer_fn(h3)
            h3 = tf.nn.relu(h3)

        with tf.variable_scope('lstm_h3'):
            lstm_h3, lstm_state3 = self._lstm_func(h3, lstm_state3, vgf_dim * 2)

        with tf.variable_scope('h4'):
            h4 = upsample_conv2d(tf.concat([lstm_h3, h1], axis=-1), vgf_dim, kernel_size=(3, 3), strides=(2, 2))
            h4 = self.normalizer_fn(h4)
            h4 = tf.nn.relu(h4)

        with tf.variable_scope('lstm_h4'):
            lstm_h4, lstm_state4 = self._lstm_func(h4, lstm_state4, vgf_dim)

        with tf.variable_scope('h5'):
            h5 = upsample_conv2d(tf.concat([lstm_h4, h0], axis=-1), vgf_dim, kernel_size=(3, 3), strides=(2, 2))
            h5 = self.normalizer_fn(h5)
            h5 = tf.nn.relu(h5)

        # Separate heads on top of h5, created only when they are needed.
        if self.model == 'dna':
            with tf.variable_scope('h6_dna_kernel'):
                h6_dna_kernel = conv2d(h5, vgf_dim, kernel_size=(3, 3), strides=(1, 1))
                h6_dna_kernel = self.normalizer_fn(h6_dna_kernel)
                h6_dna_kernel = tf.nn.relu(h6_dna_kernel)

        if self.model == 'appflow' or self.model == 'appflow_chained':
            with tf.variable_scope('pre_appflow'):
                h6_appflow = conv2d(h5, vgf_dim, kernel_size=(3, 3), strides=(1, 1))
                h6_appflow = self.normalizer_fn(h6_appflow)
                h6_appflow = tf.nn.relu(h6_appflow)
            with tf.variable_scope('appflow'):
                # Two output channels per pixel for the flow vectors.
                flowvecs_tp1_t = conv2d(h6_appflow, 2, kernel_size=(3, 3), strides=(1, 1))

        if self.generate_scratch_image:
            with tf.variable_scope('h6_scratch'):
                h6_scratch = conv2d(h5, vgf_dim, kernel_size=(3, 3), strides=(1, 1))
                h6_scratch = self.normalizer_fn(h6_scratch)
                h6_scratch = tf.nn.relu(h6_scratch)

        with tf.variable_scope('h6_masks'):
            h6_masks = conv2d(h5, vgf_dim, kernel_size=(3, 3), strides=(1, 1))
            h6_masks = self.normalizer_fn(h6_masks)
            h6_masks = tf.nn.relu(h6_masks)

        if self.model == 'dna':
            # Using largest hidden state for predicting untied conv kernels.
            with tf.variable_scope('dna_kernels'):
                kernels = conv2d(h6_dna_kernel, kernel_size[0] * kernel_size[1] * num_transformed_images,
                                 kernel_size=(3, 3), strides=(1, 1))
                kernels = tf.reshape(kernels, [batch_size, height, width,
                                               kernel_size[0], kernel_size[1], num_transformed_images])
            kernel_spatial_axes = [3, 4]
        elif self.model == 'cdna':
            with tf.variable_scope('cdna_kernels'):
                kernels = dense(flatten(lstm_h2), kernel_size[0] * kernel_size[1] * num_transformed_images)
                kernels = tf.reshape(kernels, [batch_size, kernel_size[0], kernel_size[1], num_transformed_images])
            kernel_spatial_axes = [1, 2]
        elif self.model == 'appflow' or self.model == 'appflow_chained':
            pass  # appearance flow doesn't have kernels
        else:
            raise ValueError('Invalid model %s' % self.model)

        transformed_images = []
        with tf.name_scope('transformed_images'):
            if self.model == 'appflow':
                transformed_images += [apply_warp(image, flowvecs_tp1_t)]
            elif self.model == 'appflow_chained':
                # Compose the new flow with the flow carried in the state and
                # warp the first image with the result.
                flow_tp1_0 = flowvecs_tp1_t + apply_warp(prev_flow_t_0, flowvecs_tp1_t)
                transformed_images += [apply_warp(self.first_image, flow_tp1_0)]
            else:
                with tf.name_scope('kernel_normalization'):
                    # Clamp the kernels from below at RELU_SHIFT and normalize
                    # them to sum to one over the kernel's spatial axes.
                    kernels = tf.nn.relu(kernels - RELU_SHIFT) + RELU_SHIFT
                    kernels /= tf.reduce_sum(kernels, axis=kernel_spatial_axes, keep_dims=True)
                transformed_images += apply_kernels(image, kernels, dilation_rate=dilation_rate)

        if self.first_image_background:
            transformed_images.append(self.first_image)
        if self.prev_image_background:
            transformed_images.append(image)
        if self.generate_scratch_image:
            # Using largest hidden state for predicting a new image layer.
            # This allows the network to also generate one image from scratch,
            # which is useful when regions of the image become unoccluded.
            with tf.variable_scope('scratch_image'):
                scratch_image = conv2d(h6_scratch, image_shape[-1], kernel_size=(3, 3), strides=(1, 1))
                scratch_image = tf.nn.sigmoid(scratch_image)
                transformed_images.append(scratch_image)

        if self.first_pix_distrib is not None:
            transformed_pix_distribs = []
            with tf.name_scope('transformed_pix_distrib'):

                if self.model == 'appflow' or self.model == 'appflow_chained':
                    if self.model == 'appflow_chained':
                        trafoflow = flow_tp1_0
                    else:
                        trafoflow = flowvecs_tp1_t
                    # NOTE(review): each entry is a single-element list, so the
                    # indexing by n below only works when
                    # self.num_transformed_images == 1 for the appflow models
                    # -- confirm.
                    transf_pix = [[apply_warp(pix_distrib[:, p], trafoflow)] for p in range(self.ndesig)]
                else:
                    transf_pix = [apply_kernels(pix_distrib[:, p], kernels, dilation_rate=dilation_rate) for p in
                                  range(self.ndesig)]
                # Regroup from "per designated pixel, per transformation" to
                # "per transformation", with the designated pixels on axis 1.
                transf_pix_l = []
                for n in range(self.num_transformed_images):
                    transf_pix_n = tf.stack([transf_pix[p][n] for p in range(self.ndesig)], axis=1)
                    transf_pix_l.append(transf_pix_n)
                transformed_pix_distribs += transf_pix_l
            if self.first_image_background:
                transformed_pix_distribs.append(self.first_pix_distrib)
            if self.prev_image_background:
                transformed_pix_distribs.append(pix_distrib)
            if self.generate_scratch_image:
                transformed_pix_distribs.append(tf.zeros_like(pix_distrib))

        with tf.variable_scope('masks'):
            if self.dependent_mask:
                mask_inputs = tf.concat([h6_masks] + transformed_images, axis=-1)
            else:
                mask_inputs = h6_masks
            # One mask channel per candidate image; softmax over the channels.
            masks = conv2d(mask_inputs, len(transformed_images), kernel_size=(3, 3), strides=(1, 1))
            masks = tf.nn.softmax(masks)
            masks = tf.split(masks, len(transformed_images), axis=-1)

        with tf.name_scope('gen_image'):
            assert len(transformed_images) == len(masks)
            # The generated image is the mask-weighted sum of the candidates.
            gen_image = tf.add_n([transformed_image * mask
                                  for transformed_image, mask in zip(transformed_images, masks)])

        with tf.name_scope('flow_map'):
            if self.model != 'appflow' and self.model != 'appflow_chained':
                flow_map = compute_flow_map(kernels, masks[:num_transformed_images])

        if self.first_pix_distrib is not None:
            with tf.name_scope('gen_pix_distrib'):
                assert len(transformed_pix_distribs) <= len(masks) <= len(
                    transformed_pix_distribs) + 1  # there might be an extra mask because of the scratch image
                gen_pix_distrib = []
                for p in range(self.ndesig):
                    transformed_pix_distribs_p = [transformed_pix_distribs[n][:, p] for n in
                                                  range(len(transformed_pix_distribs))]
                    # zip() stops at the shorter sequence, so a surplus mask
                    # is simply left unused here.
                    gen_pix_distrib.append(tf.add_n([transformed_pix_distrib * mask
                                                     for transformed_pix_distrib, mask in
                                                     zip(transformed_pix_distribs_p, masks)]))
                gen_pix_distrib = tf.stack(gen_pix_distrib, axis=1)

        with tf.variable_scope('state_pred'):
            gen_state = dense(state_action, state_dim)

        # outputs
        outputs = [gen_image, gen_state, masks, transformed_images]
        if 'compute_flow_map' in self.conf:
            # NOTE(review): flow_map is only assigned for the kernel-based
            # models above; with an appflow model and 'compute_flow_map' in
            # self.conf this line raises NameError -- confirm that combination
            # is never configured.
            outputs.append(flow_map)

        if self.first_pix_distrib is not None:
            outputs.append(gen_pix_distrib)
            outputs.append(transformed_pix_distribs)

        outputs = tuple(outputs)
        # states
        new_lstm_states = lstm_state0, lstm_state1, lstm_state2, lstm_state3, lstm_state4
        new_states = [new_lstm_states, time + 1, gen_image, gen_state, ]
        if self.first_pix_distrib is not None:
            new_states.append(gen_pix_distrib)
        if self.model == 'appflow_chained':
            new_states.append(flow_tp1_0)

        new_states = tuple(new_states)
        return outputs, new_states