import tensorflow as tf

# Repo-local modules (assumed import paths, as in the video_prediction repo):
from video_prediction import flow_ops, ops
from video_prediction.ops import (conv2d, conv_pool2d, dense, tile_concat,
                                  upsample_conv2d)


def apply_flows(image, flows):
    # Warp `image` with each flow stacked along the last axis of `flows`.
    # `image` may also be a list, in which case `flows` is split evenly among
    # its entries and the warped outputs are collected into one flat list.
    if isinstance(image, list):
        image_list = image
        flows_list = tf.split(flows, len(image_list), axis=-1)
        outputs = []
        for sub_image, sub_flows in zip(image_list, flows_list):
            outputs.extend(apply_flows(sub_image, sub_flows))
    else:
        flows = tf.unstack(flows, axis=-1)
        outputs = [flow_ops.image_warp(image, flow) for flow in flows]
    return outputs
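# Usage sketch (hypothetical shapes): a [B, H, W, 2, N] stack of N flows
# yields a list of N warped copies of a [B, H, W, C] image, e.g.
#
#   image = tf.zeros([4, 64, 64, 3])
#   flows = tf.zeros([4, 64, 64, 2, 2])
#   warped = apply_flows(image, flows)  # list of two [4, 64, 64, 3] tensors


# The `call` method below belongs to the model's recurrent generator cell
# (its class definition is omitted in this listing): it runs one prediction
# step, mapping the current inputs and previous state to outputs and the
# next state.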
    def call(self, inputs, states):
        norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
        image_shape = inputs['images'].get_shape().as_list()
        batch_size, height, width, color_channels = image_shape
        conv_rnn_states = states['conv_rnn_states']

        time = states['time']
        with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
            t = tf.to_int32(tf.identity(time[0]))

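        # Slide the frame/flow histories one step: drop the oldest entry and
        # append the newest. `ground_truth[t]` selects, per batch element,
        # between teacher-forced ground truth and the model's own prediction.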
        last_gt_images = states['last_gt_images'][1:] + [inputs['images']]
        last_pred_flows = states['last_pred_flows'][1:] + [
            tf.zeros_like(states['last_pred_flows'][-1])
        ]
        image = tf.where(self.ground_truth[t], inputs['images'],
                         states['gen_image'])
        last_images = states['last_images'][1:] + [image]
        last_base_images = [
            tf.where(self.ground_truth[t], last_gt_image, last_base_image)
            for last_gt_image, last_base_image in zip(
                last_gt_images, states['last_base_images'])
        ]
        last_base_flows = [
            tf.where(self.ground_truth[t], last_pred_flow, last_base_flow)
            for last_pred_flow, last_base_flow in zip(
                last_pred_flows, states['last_base_flows'])
        ]
        if 'pix_distribs' in inputs:
            last_gt_pix_distribs = states['last_gt_pix_distribs'][1:] + [
                inputs['pix_distribs']
            ]
            last_base_pix_distribs = [
                tf.where(self.ground_truth[t], last_gt_pix_distrib,
                         last_base_pix_distrib)
                for last_gt_pix_distrib, last_base_pix_distrib in zip(
                    last_gt_pix_distribs, states['last_base_pix_distribs'])
            ]
        if 'states' in inputs:
            state = tf.where(self.ground_truth[t], inputs['states'],
                             states['gen_state'])

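        # Assemble the conditioning vectors: `state_action` feeds the state
        # predictor, while `state_action_z` (with gradients stopped through
        # the state) is broadcast into the convolutional layers.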
        state_action = []
        state_action_z = []
        if 'actions' in inputs:
            state_action.append(inputs['actions'])
            state_action_z.append(inputs['actions'])
        if 'states' in inputs:
            state_action.append(state)
            # don't backpropagate the convnet through the state dynamics
            state_action_z.append(tf.stop_gradient(state))

        if 'zs' in inputs:
            if self.hparams.use_rnn_z:
                with tf.variable_scope('%s_z' % self.hparams.rnn):
                    rnn_z, rnn_z_state = self._rnn_func(
                        inputs['zs'], states['rnn_z_state'], self.hparams.nz)
                state_action_z.append(rnn_z)
            else:
                state_action_z.append(inputs['zs'])

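        # Degenerate cases: no tensors yields an empty [batch_size, 0] tensor,
        # and a single tensor is passed through without a concat op.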
        def concat(tensors, axis):
            if len(tensors) == 0:
                return tf.zeros([batch_size, 0])
            elif len(tensors) == 1:
                return tensors[0]
            else:
                return tf.concat(tensors, axis=axis)

        state_action = concat(state_action, axis=-1)
        state_action_z = concat(state_action_z, axis=-1)
        if 'actions' in inputs:
            gen_input = tile_concat(last_images +
                                    [inputs['actions'][:, None, None, :]],
                                    axis=-1)
        else:
            gen_input = tf.concat(last_images, axis=-1)

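        # Encoder: each spec is (out_channels, use_conv_rnn). Every scale
        # downsamples 2x with a strided conv, and `state_action_z` is
        # broadcast-concatenated (tile_concat) into each layer.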
        layers = []
        new_conv_rnn_states = []
        for i, (out_channels,
                use_conv_rnn) in enumerate(self.encoder_layer_specs):
            with tf.variable_scope('h%d' % i):
                if i == 0:
                    h = tf.concat(last_images, axis=-1)
                    kernel_size = (5, 5)
                else:
                    h = layers[-1][-1]
                    kernel_size = (3, 3)
                h = conv_pool2d(
                    tile_concat([h, state_action_z[:, None, None, :]], axis=-1),
                    out_channels, kernel_size=kernel_size, strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, i)):
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        tile_concat([h, state_action_z[:, None, None, :]],
                                    axis=-1), conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))

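        # Decoder: mirrors the encoder with 2x upsampling convs and
        # U-Net-style skip connections from the matching encoder scale.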
        num_encoder_layers = len(layers)
        for i, (out_channels,
                use_conv_rnn) in enumerate(self.decoder_layer_specs):
            with tf.variable_scope('h%d' % len(layers)):
                if i == 0:
                    h = layers[-1][-1]
                else:
                    h = tf.concat(
                        [layers[-1][-1], layers[num_encoder_layers - i - 1][-1]],
                        axis=-1)
                h = upsample_conv2d(
                    tile_concat([h, state_action_z[:, None, None, :]], axis=-1),
                    out_channels, kernel_size=(3, 3), strides=(2, 2))
                h = norm_layer(h)
                h = tf.nn.relu(h)
            if use_conv_rnn:
                conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
                with tf.variable_scope('%s_h%d' %
                                       (self.hparams.conv_rnn, len(layers))):
                    conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                        tile_concat([h, state_action_z[:, None, None, :]],
                                    axis=-1), conv_rnn_state, out_channels)
                new_conv_rnn_states.append(conv_rnn_state)
            layers.append((h, conv_rnn_h) if use_conv_rnn else (h, ))
        assert len(new_conv_rnn_states) == len(conv_rnn_states)

        with tf.variable_scope('h%d_flow' % len(layers)):
            h_flow = conv2d(layers[-1][-1],
                            self.hparams.ngf,
                            kernel_size=(3, 3),
                            strides=(1, 1))
            h_flow = norm_layer(h_flow)
            h_flow = tf.nn.relu(h_flow)

        with tf.variable_scope('flows'):
            flows = conv2d(h_flow,
                           2 * self.hparams.last_frames,
                           kernel_size=(3, 3),
                           strides=(1, 1))
            flows = tf.reshape(
                flows,
                [batch_size, height, width, 2, self.hparams.last_frames])

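        # Compose each newly predicted flow with the accumulated flows:
        # new_flow = flow + warp(old_flow, flow), so every past base image can
        # be warped to the current step with a single resampling.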
        with tf.name_scope('transformed_images'):
            transformed_images = []
            last_pred_flows = [
                flow + flow_ops.image_warp(last_pred_flow, flow)
                for last_pred_flow, flow in zip(last_pred_flows,
                                                tf.unstack(flows, axis=-1))
            ]
            last_base_flows = [
                flow + flow_ops.image_warp(last_base_flow, flow)
                for last_base_flow, flow in zip(last_base_flows,
                                                tf.unstack(flows, axis=-1))
            ]
            for last_base_image, last_base_flow in zip(last_base_images,
                                                       last_base_flows):
                transformed_images.append(
                    flow_ops.image_warp(last_base_image, last_base_flow))

        if 'pix_distribs' in inputs:
            with tf.name_scope('transformed_pix_distribs'):
                transformed_pix_distribs = []
                for last_base_pix_distrib, last_base_flow in zip(
                        last_base_pix_distribs, last_base_flows):
                    transformed_pix_distribs.append(
                        flow_ops.image_warp(last_base_pix_distrib,
                                            last_base_flow))

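        # Compositing masks: a softmax over the transformed images gives
        # per-pixel weights that sum to one across candidates.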
        with tf.name_scope('masks'):
            if len(transformed_images) > 1:
                with tf.variable_scope('h%d_masks' % len(layers)):
                    h_masks = conv2d(layers[-1][-1],
                                     self.hparams.ngf,
                                     kernel_size=(3, 3),
                                     strides=(1, 1))
                    h_masks = norm_layer(h_masks)
                    h_masks = tf.nn.relu(h_masks)

                with tf.variable_scope('masks'):
                    if self.hparams.dependent_mask:
                        h_masks = tf.concat([h_masks] + transformed_images,
                                            axis=-1)
                    masks = conv2d(h_masks,
                                   len(transformed_images),
                                   kernel_size=(3, 3),
                                   strides=(1, 1))
                    masks = tf.nn.softmax(masks)
                    masks = tf.split(masks, len(transformed_images), axis=-1)
            elif len(transformed_images) == 1:
                masks = [tf.ones([batch_size, height, width, 1])]
            else:
                raise ValueError(
                    "Either one of the following should be true: "
                    "last_frames and num_transformed_images, first_image_background, "
                    "prev_image_background, generate_scratch_image")

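        # The generated frame is the mask-weighted sum of the warped
        # candidate images.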
        with tf.name_scope('gen_images'):
            assert len(transformed_images) == len(masks)
            gen_image = tf.add_n([
                transformed_image * mask
                for transformed_image, mask in zip(transformed_images, masks)
            ])

        if 'pix_distribs' in inputs:
            with tf.name_scope('gen_pix_distribs'):
                assert len(transformed_pix_distribs) == len(masks)
                gen_pix_distrib = tf.add_n([
                    transformed_pix_distrib * mask
                    for transformed_pix_distrib, mask in zip(
                        transformed_pix_distribs, masks)
                ])
                # TODO: is this needed?
                # gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib, axis=(1, 2), keepdims=True)

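        # Predict the next low-dimensional state with a single dense layer on
        # the concatenated [action, state] vector.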
        if 'states' in inputs:
            with tf.name_scope('gen_states'):
                with tf.variable_scope('state_pred'):
                    gen_state = dense(state_action,
                                      inputs['states'].shape[-1].value)

        outputs = {
            'gen_images': gen_image,
            'gen_inputs': gen_input,
            'transformed_images': tf.stack(transformed_images, axis=-1),
            'masks': tf.stack(masks, axis=-1),
            'gen_flow': flows
        }
        if 'pix_distribs' in inputs:
            outputs['gen_pix_distribs'] = gen_pix_distrib
            outputs['transformed_pix_distribs'] = tf.stack(
                transformed_pix_distribs, axis=-1)
        if 'states' in inputs:
            outputs['gen_states'] = gen_state

        new_states = {
            'time': time + 1,
            'last_gt_images': last_gt_images,
            'last_pred_flows': last_pred_flows,
            'gen_image': gen_image,
            'last_images': last_images,
            'last_base_images': last_base_images,
            'last_base_flows': last_base_flows,
            'conv_rnn_states': new_conv_rnn_states
        }
        if 'zs' in inputs and self.hparams.use_rnn_z:
            new_states['rnn_z_state'] = rnn_z_state
        if 'pix_distribs' in inputs:
            new_states['last_gt_pix_distribs'] = last_gt_pix_distribs
            new_states['last_base_pix_distribs'] = last_base_pix_distribs
        if 'states' in inputs:
            new_states['gen_state'] = gen_state
        return outputs, new_states
Example #3
def apply_transformation(feature_and_flow):
    # Unpack a (feature, flow) pair; shaped for use as a tf.map_fn callable.
    feature, flow = feature_and_flow
    return flow_ops.image_warp(feature[..., None], flow)
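# Hypothetical usage sketch (illustrative shapes only): `apply_transformation`
# takes a (feature, flow) pair, so it slots directly into tf.map_fn over
# paired tensors; image_warp expects NHWC input, hence the channel axis
# added inside the function.
features = tf.zeros([3, 8, 64, 64])        # e.g. 3 steps of [B, H, W] maps
step_flows = tf.zeros([3, 8, 64, 64, 2])   # one 2-D flow field per step
warped = tf.map_fn(apply_transformation, (features, step_flows),
                   dtype=tf.float32)       # shape [3, 8, 64, 64, 1]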