def apply_flows(image, flows):
    if isinstance(image, list):
        image_list = image
        flows_list = tf.split(flows, len(image_list), axis=-1)
        outputs = []
        for image, flows in zip(image_list, flows_list):
            outputs.extend(apply_flows(image, flows))
    else:
        flows = tf.unstack(flows, axis=-1)
        outputs = [flow_ops.image_warp(image, flow) for flow in flows]
    return outputs
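
# A minimal NumPy sketch of the backward-warping semantics assumed of
# flow_ops.image_warp above: every output pixel samples the input image at
# its own location offset by the flow vector. This stand-in uses
# nearest-neighbor sampling and an assumed (dy, dx) channel order; the real
# op presumably interpolates bilinearly, so this is for intuition only and
# is not part of the graph.
def _np_image_warp_sketch(image, flow):
    """image: [height, width, channels] array; flow: [height, width, 2] array."""
    import numpy as np
    height, width = image.shape[:2]
    warped = np.zeros_like(image)
    for y in range(height):
        for x in range(width):
            # Round the sampling location and clip it to the image bounds.
            sy = int(np.clip(round(y + flow[y, x, 0]), 0, height - 1))
            sx = int(np.clip(round(x + flow[y, x, 1]), 0, width - 1))
            warped[y, x] = image[sy, sx]
    return warped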
def call(self, inputs, states):
    norm_layer = ops.get_norm_layer(self.hparams.norm_layer)
    image_shape = inputs['images'].get_shape().as_list()
    batch_size, height, width, color_channels = image_shape
    conv_rnn_states = states['conv_rnn_states']

    # All entries of the time state should agree; use the first as the
    # scalar step index t.
    time = states['time']
    with tf.control_dependencies([tf.assert_equal(time[1:], time[0])]):
        t = tf.to_int32(tf.identity(time[0]))

    # Shift the history buffers by one step and append the current frame.
    last_gt_images = states['last_gt_images'][1:] + [inputs['images']]
    last_pred_flows = states['last_pred_flows'][1:] + [
        tf.zeros_like(states['last_pred_flows'][-1])]
    # Scheduled sampling: condition on the ground-truth frame where
    # self.ground_truth[t] is set, otherwise feed back the last prediction.
    image = tf.where(self.ground_truth[t], inputs['images'], states['gen_image'])
    last_images = states['last_images'][1:] + [image]
    last_base_images = [
        tf.where(self.ground_truth[t], last_gt_image, last_base_image)
        for last_gt_image, last_base_image in zip(
            last_gt_images, states['last_base_images'])]
    last_base_flows = [
        tf.where(self.ground_truth[t], last_pred_flow, last_base_flow)
        for last_pred_flow, last_base_flow in zip(
            last_pred_flows, states['last_base_flows'])]
    if 'pix_distribs' in inputs:
        last_gt_pix_distribs = states['last_gt_pix_distribs'][1:] + [
            inputs['pix_distribs']]
        last_base_pix_distribs = [
            tf.where(self.ground_truth[t], last_gt_pix_distrib, last_base_pix_distrib)
            for last_gt_pix_distrib, last_base_pix_distrib in zip(
                last_gt_pix_distribs, states['last_base_pix_distribs'])]
    if 'states' in inputs:
        state = tf.where(self.ground_truth[t], inputs['states'], states['gen_state'])

    # Assemble the conditioning vectors for the state predictor
    # (state_action) and for the convnet (state_action_z).
    state_action = []
    state_action_z = []
    if 'actions' in inputs:
        state_action.append(inputs['actions'])
        state_action_z.append(inputs['actions'])
    if 'states' in inputs:
        state_action.append(state)
        # don't backpropagate the convnet through the state dynamics
        state_action_z.append(tf.stop_gradient(state))
    if 'zs' in inputs:
        if self.hparams.use_rnn_z:
            with tf.variable_scope('%s_z' % self.hparams.rnn):
                rnn_z, rnn_z_state = self._rnn_func(
                    inputs['zs'], states['rnn_z_state'], self.hparams.nz)
            state_action_z.append(rnn_z)
        else:
            state_action_z.append(inputs['zs'])

    def concat(tensors, axis):
        if len(tensors) == 0:
            return tf.zeros([batch_size, 0])
        elif len(tensors) == 1:
            return tensors[0]
        else:
            return tf.concat(tensors, axis=axis)

    state_action = concat(state_action, axis=-1)
    state_action_z = concat(state_action_z, axis=-1)

    if 'actions' in inputs:
        gen_input = tile_concat(
            last_images + [inputs['actions'][:, None, None, :]], axis=-1)
    else:
        gen_input = tf.concat(last_images, axis=-1)

    # Encoder: a pyramid of strided convolutions, each optionally followed
    # by a convolutional RNN, with the conditioning vector tiled and
    # concatenated at every layer.
    layers = []
    new_conv_rnn_states = []
    for i, (out_channels, use_conv_rnn) in enumerate(self.encoder_layer_specs):
        with tf.variable_scope('h%d' % i):
            if i == 0:
                h = tf.concat(last_images, axis=-1)
                kernel_size = (5, 5)
            else:
                h = layers[-1][-1]
                kernel_size = (3, 3)
            h = conv_pool2d(
                tile_concat([h, state_action_z[:, None, None, :]], axis=-1),
                out_channels, kernel_size=kernel_size, strides=(2, 2))
            h = norm_layer(h)
            h = tf.nn.relu(h)
        if use_conv_rnn:
            conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
            with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, i)):
                conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                    tile_concat([h, state_action_z[:, None, None, :]], axis=-1),
                    conv_rnn_state, out_channels)
            new_conv_rnn_states.append(conv_rnn_state)
        layers.append((h, conv_rnn_h) if use_conv_rnn else (h,))
    num_encoder_layers = len(layers)
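    # The decoder below mirrors the encoder: each step upsamples by a factor
    # of 2 and, after the first step, concatenates a skip connection from the
    # encoder layer at the matching resolution (U-Net style).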
    for i, (out_channels, use_conv_rnn) in enumerate(self.decoder_layer_specs):
        with tf.variable_scope('h%d' % len(layers)):
            if i == 0:
                h = layers[-1][-1]
            else:
                h = tf.concat(
                    [layers[-1][-1], layers[num_encoder_layers - i - 1][-1]],
                    axis=-1)
            h = upsample_conv2d(
                tile_concat([h, state_action_z[:, None, None, :]], axis=-1),
                out_channels, kernel_size=(3, 3), strides=(2, 2))
            h = norm_layer(h)
            h = tf.nn.relu(h)
        if use_conv_rnn:
            conv_rnn_state = conv_rnn_states[len(new_conv_rnn_states)]
            with tf.variable_scope('%s_h%d' % (self.hparams.conv_rnn, len(layers))):
                conv_rnn_h, conv_rnn_state = self._conv_rnn_func(
                    tile_concat([h, state_action_z[:, None, None, :]], axis=-1),
                    conv_rnn_state, out_channels)
            new_conv_rnn_states.append(conv_rnn_state)
        layers.append((h, conv_rnn_h) if use_conv_rnn else (h,))
    assert len(new_conv_rnn_states) == len(conv_rnn_states)

    # Predict one flow field per conditioning frame (2 channels each).
    with tf.variable_scope('h%d_flow' % len(layers)):
        h_flow = conv2d(layers[-1][-1], self.hparams.ngf,
                        kernel_size=(3, 3), strides=(1, 1))
        h_flow = norm_layer(h_flow)
        h_flow = tf.nn.relu(h_flow)

    with tf.variable_scope('flows'):
        flows = conv2d(h_flow, 2 * self.hparams.last_frames,
                       kernel_size=(3, 3), strides=(1, 1))
        flows = tf.reshape(
            flows, [batch_size, height, width, 2, self.hparams.last_frames])

    with tf.name_scope('transformed_images'):
        transformed_images = []
        # Compose each accumulated flow with the newly predicted one: warp
        # the old flow by the new flow, then add the new flow.
        last_pred_flows = [
            flow + flow_ops.image_warp(last_pred_flow, flow)
            for last_pred_flow, flow in zip(
                last_pred_flows, tf.unstack(flows, axis=-1))]
        last_base_flows = [
            flow + flow_ops.image_warp(last_base_flow, flow)
            for last_base_flow, flow in zip(
                last_base_flows, tf.unstack(flows, axis=-1))]
        for last_base_image, last_base_flow in zip(last_base_images, last_base_flows):
            transformed_images.append(
                flow_ops.image_warp(last_base_image, last_base_flow))

    if 'pix_distribs' in inputs:
        with tf.name_scope('transformed_pix_distribs'):
            transformed_pix_distribs = []
            for last_base_pix_distrib, last_base_flow in zip(
                    last_base_pix_distribs, last_base_flows):
                transformed_pix_distribs.append(
                    flow_ops.image_warp(last_base_pix_distrib, last_base_flow))

    # Composite the warped candidate images with per-pixel softmax masks.
    with tf.name_scope('masks'):
        if len(transformed_images) > 1:
            with tf.variable_scope('h%d_masks' % len(layers)):
                h_masks = conv2d(layers[-1][-1], self.hparams.ngf,
                                 kernel_size=(3, 3), strides=(1, 1))
                h_masks = norm_layer(h_masks)
                h_masks = tf.nn.relu(h_masks)
            with tf.variable_scope('masks'):
                if self.hparams.dependent_mask:
                    h_masks = tf.concat([h_masks] + transformed_images, axis=-1)
                masks = conv2d(h_masks, len(transformed_images),
                               kernel_size=(3, 3), strides=(1, 1))
                masks = tf.nn.softmax(masks)
                masks = tf.split(masks, len(transformed_images), axis=-1)
        elif len(transformed_images) == 1:
            masks = [tf.ones([batch_size, height, width, 1])]
        else:
            raise ValueError(
                "Either one of the following should be true: "
                "last_frames and num_transformed_images, first_image_background, "
                "prev_image_background, generate_scratch_image")

    with tf.name_scope('gen_images'):
        assert len(transformed_images) == len(masks)
        gen_image = tf.add_n([
            transformed_image * mask
            for transformed_image, mask in zip(transformed_images, masks)])
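    # When pixel distributions are given (e.g. designated pixels for
    # visual-MPC-style planning), the warped distributions are composited
    # with the same masks used for the images.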
    if 'pix_distribs' in inputs:
        with tf.name_scope('gen_pix_distribs'):
            assert len(transformed_pix_distribs) == len(masks)
            gen_pix_distrib = tf.add_n([
                transformed_pix_distrib * mask
                for transformed_pix_distrib, mask in zip(
                    transformed_pix_distribs, masks)])
            # TODO: is this needed?
            # gen_pix_distrib /= tf.reduce_sum(gen_pix_distrib, axis=(1, 2), keepdims=True)

    if 'states' in inputs:
        with tf.name_scope('gen_states'):
            with tf.variable_scope('state_pred'):
                gen_state = dense(state_action, inputs['states'].shape[-1].value)

    outputs = {
        'gen_images': gen_image,
        'gen_inputs': gen_input,
        'transformed_images': tf.stack(transformed_images, axis=-1),
        'masks': tf.stack(masks, axis=-1),
        'gen_flow': flows,
    }
    if 'pix_distribs' in inputs:
        outputs['gen_pix_distribs'] = gen_pix_distrib
        outputs['transformed_pix_distribs'] = tf.stack(
            transformed_pix_distribs, axis=-1)
    if 'states' in inputs:
        outputs['gen_states'] = gen_state

    new_states = {
        'time': time + 1,
        'last_gt_images': last_gt_images,
        'last_pred_flows': last_pred_flows,
        'gen_image': gen_image,
        'last_images': last_images,
        'last_base_images': last_base_images,
        'last_base_flows': last_base_flows,
        'conv_rnn_states': new_conv_rnn_states,
    }
    if 'zs' in inputs and self.hparams.use_rnn_z:
        new_states['rnn_z_state'] = rnn_z_state
    if 'pix_distribs' in inputs:
        new_states['last_gt_pix_distribs'] = last_gt_pix_distribs
        new_states['last_base_pix_distribs'] = last_base_pix_distribs
    if 'states' in inputs:
        new_states['gen_state'] = gen_state
    return outputs, new_states
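
# A minimal sketch of a scheduled-sampling mask like the self.ground_truth
# used in call() above: a boolean [sequence_length, batch_size] tensor that
# is True where the cell should condition on the ground-truth frame and
# False where it should feed back its own prediction. The name and the
# simple context/predict split below are assumptions, not this model's
# actual schedule.
def _make_ground_truth_mask_sketch(sequence_length, batch_size, context_frames):
    ground_truth = tf.ones([context_frames, batch_size], dtype=tf.bool)
    feed_back = tf.zeros([sequence_length - context_frames, batch_size],
                         dtype=tf.bool)
    return tf.concat([ground_truth, feed_back], axis=0)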
def apply_transformation(feature_and_flow):
    feature, flow = feature_and_flow
    return flow_ops.image_warp(feature[..., None], flow)
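
# A minimal sketch of mapping apply_transformation over a leading axis with
# tf.map_fn, which accepts a tuple of tensors and maps over their first
# dimension. The shapes are assumptions for illustration: each element pairs
# a single-channel feature map [batch, height, width] with a flow field
# [batch, height, width, 2], and feature[..., None] restores the batched
# [batch, height, width, 1] layout that image_warp presumably expects.
def _warp_features_sketch(features, flows):
    """features: [n, batch, height, width]; flows: [n, batch, height, width, 2]."""
    return tf.map_fn(apply_transformation, (features, flows),
                     dtype=features.dtype)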