def testShapeInferenceAndChecks(self):
    """Checks static shape inference and constructor validation of the warper.

    Builds warpers for 2D->2D, 3D->2D and 3D->3D (source->output) domains and
    asserts the non-batch part of the inferred output shape, then verifies
    that two invalid source/output/constraint combinations raise `snt.Error`.
    """
    src_2d = (6, 9)
    src_3d = (100, 200, 50)
    out_2d = (2, 3)
    out_3d = (2, 3, 4)

    # --- 2D source, 2D output, constraints composed from helpers ---
    constraints_2d = scale_2d(y=1) & translation_2d(x=-2, y=7)
    warper_2d = snt.AffineGridWarper(source_shape=src_2d,
                                     output_shape=out_2d,
                                     constraints=constraints_2d)
    params_2d = tf.placeholder(tf.float32,
                               [None, constraints_2d.num_free_params])
    grid_2d = warper_2d(params_2d)
    self.assertEqual(grid_2d.get_shape().as_list()[1:], [2, 3, 2])

    # --- 3D source, first with a 2D then with a 3D output grid ---
    cases_3d = (
        (out_2d, [2, 3, 3]),
        (out_3d, [2, 3, 4, 3]),
    )
    for out_shape, expected_shape in cases_3d:
        warper_3d = snt.AffineGridWarper(
            source_shape=src_3d,
            output_shape=out_shape,
            constraints=[[None, 0, None, None],
                         [0, 1, 0, None],
                         [0, None, 0, None]])
        params_3d = tf.placeholder(
            tf.float32, [None, warper_3d.constraints.num_free_params])
        grid_3d = warper_3d(params_3d)
        self.assertEqual(grid_3d.get_shape().as_list()[1:], expected_shape)

    # --- invalid configurations must be rejected at construction time ---
    with self.assertRaisesRegexp(
            snt.Error, "Incompatible set of constraints provided.*"):
        snt.AffineGridWarper(source_shape=src_3d,
                             output_shape=out_3d,
                             constraints=no_constraints(2))

    with self.assertRaisesRegexp(snt.Error, "Output domain dimensionality.*"):
        snt.AffineGridWarper(source_shape=src_2d,
                             output_shape=out_3d,
                             constraints=no_constraints(2))
def extract_affine_glimpse(image, object_shape, cyt, cxt, ys, xs, edge_resampler):
    """Resample rectangular glimpses out of `image` with a no-shear affine warp.

    (cyt, cxt) are the rectangle centers and (ys, xs) the rectangle
    height/width. Any leading dimensions of `cyt` are preserved in the output.

    Args:
        image: image tensor; the first and last entries of its shape are
            treated as batch and depth, everything in between as the spatial
            shape handed to the warper.
        object_shape: spatial shape of each extracted glimpse.
        cyt, cxt: center coordinates (last dimension is concatenated on).
        ys, xs: rectangle height / width.
        edge_resampler: if truthy use `resampler_edge.resampler_edge`,
            otherwise `tf.contrib.resampler.resampler`.

    Returns:
        Tensor of shape (*leading_shape, *object_shape, image_depth).
    """
    _, *spatial_shape, depth = tf_shape(image)

    no_shear = snt.AffineWarpConstraints.no_shear_2d()
    grid_warper = snt.AffineGridWarper(spatial_shape, object_shape, no_shear)

    # Change coordinate system: c -> 2c - 1
    # (presumably [0, 1] -> [-1, 1]; confirm against callers).
    center_y = 2 * cyt - 1
    center_x = 2 * cxt - 1

    batch_dims = tf_shape(center_y)[:-1]

    # Parameter order expected by the no-shear warper as used here:
    # (x-scale, x-translation, y-scale, y-translation).
    flat_boxes = tf.reshape(
        tf.concat([xs, center_x, ys, center_y], axis=-1), (-1, 4))

    sample_grid = grid_warper(flat_boxes)
    sample_grid = tf.reshape(sample_grid, (*batch_dims, *object_shape, 2))

    resample = (resampler_edge.resampler_edge
                if edge_resampler
                else tf.contrib.resampler.resampler)
    extracted = resample(image, sample_grid)

    return tf.reshape(extracted, (*batch_dims, *object_shape, depth))
def __init__(self, img_size, crop_size, constraints=None, inverse=False):
    """Creates the affine grid warper used by this module.

    :param img_size: source shape passed to `snt.AffineGridWarper`.
    :param crop_size: output shape passed to `snt.AffineGridWarper`.
    :param constraints: warp constraints forwarded to the warper (may be None).
    :param inverse: if True, the inverse of the warper is stored instead.
    """
    module_name = self.__class__.__name__
    super(SpatialTransformer, self).__init__(module_name)

    with self._enter_variable_scope():
        forward_warper = snt.AffineGridWarper(img_size, crop_size, constraints)
        self._warper = forward_warper.inverse() if inverse else forward_warper
def __init__(self, img_size, crop_size, inverse=False):
    """Creates a no-shear 2D affine grid warper for this module.

    :param img_size: source shape passed to `snt.AffineGridWarper`.
    :param crop_size: output shape passed to `snt.AffineGridWarper`.
    :param inverse: if True, the inverse of the warper is stored instead.
    """
    super(SpatialTransformer, self).__init__()

    with self._enter_variable_scope():
        no_shear = snt.AffineWarpConstraints.no_shear_2d()
        forward_warper = snt.AffineGridWarper(img_size, crop_size, no_shear)
        self._warper = forward_warper.inverse() if inverse else forward_warper
def testInvSameAsNumPyRef(self, output_shape, source_shape, constraints):
    """Compares the inverse warper's output against a NumPy reference.

    Random full affine matrices are generated, constrained entries are
    overwritten with their fixed values, and only the unconstrained entries
    are fed to the inverse `snt.AffineGridWarper`. The result must match the
    NumPy prediction within 1e-5.

    Args:
        output_shape: output shape of the forward warper.
        source_shape: source shape of the forward warper.
        constraints: warp constraints object exposing `num_dim`,
            `num_free_params`, `mask` and `constraints`.
    """

    def chain(x):
        # Flatten one level of nesting (rows of the mask / constraint matrix).
        return itertools.chain(*x)

    def predict(output_shape, source_shape, inputs):
        """NumPy reference for the inverse warp; reads `batch_size` by closure."""
        ranges = [
            np.linspace(-1, 1, x, dtype=np.float32)
            for x in reversed(source_shape)
        ]
        n = len(output_shape)
        # np.meshgrid returns a tuple on NumPy >= 2.0 (a list before that);
        # materialise a list so the .append() calls below work on both.
        grid = list(np.meshgrid(*ranges, indexing="xy"))
        for _ in range(len(source_shape), len(output_shape)):
            grid.append(np.zeros_like(grid[0]))
        # Homogeneous coordinate.
        grid.append(np.ones_like(grid[0]))
        grid = np.array([x.reshape(1, -1) for x in grid]).squeeze()

        predicted_output = []
        for i in range(0, batch_size):
            affine_matrix = inputs[i, :].reshape(n, n + 1)
            # NOTE(review): the [:2, :2] / [:, 2] slicing hard-codes a 2D
            # transform; confirm the test is only parameterised with 2D cases.
            inv_matrix = np.linalg.inv(affine_matrix[:2, :2])
            inv_transform = np.concatenate([
                inv_matrix,
                -np.dot(inv_matrix, affine_matrix[:, 2].reshape(2, 1))
            ], 1)
            x = np.dot(inv_transform, grid)
            # Rescale each coordinate via c -> c * s + s with s = (size-1)/2.
            for k, s in enumerate(reversed(output_shape)):
                s = (s - 1) * 0.5
                x[k, :] = x[k, :] * s + s
            x = np.concatenate([v.reshape(v.shape + (1, )) for v in x], -1)
            predicted_output.append(x.reshape(tuple(source_shape) + (n, )))
        return predicted_output

    batch_size = 20
    agw = snt.AffineGridWarper(source_shape=source_shape,
                               output_shape=output_shape,
                               constraints=constraints).inverse()
    inputs = tf.placeholder(tf.float32, [None, constraints.num_free_params])
    warped_grid = agw(inputs)

    full_size = constraints.num_dim * (constraints.num_dim + 1)
    # Adding a bit of mass to the matrix to avoid singular matrices
    full_input_np = np.random.rand(batch_size, full_size) + 0.1

    # Overwrite constrained entries with their fixed values.
    con_i = [i for i, x in enumerate(chain(constraints.mask)) if not x]
    con_val = [x for x in chain(constraints.constraints) if x is not None]
    for i, v in zip(con_i, con_val):
        full_input_np[:, i] = v

    # Only the free (unconstrained) entries are fed to the warper.
    uncon_i = [i for i, x in enumerate(chain(constraints.mask)) if x]
    with self.test_session() as sess:
        output = sess.run(warped_grid,
                          feed_dict={inputs: full_input_np[:, uncon_i]})

    self.assertAllClose(output,
                        predict(output_shape, source_shape, full_input_np),
                        rtol=1e-05,
                        atol=1e-05)
def spatial_transformer(img_tensor, transform_params, crop_size):
    """Extract a crop from each image with a no-shear 2D affine transform.

    :param img_tensor: tf.Tensor of size (batch_size, Height, Width) or
        (batch_size, Height, Width, channels). A channel axis is appended
        to rank-3 input before resampling; rank-4 input is used as is.
    :param transform_params: tf.Tensor of size (batch_size, 4) holding the
        free parameters of `snt.AffineWarpConstraints.no_shear_2d()`.
        NOTE(review): this was previously documented as
        (scale_y, shift_y, scale_x, shift_x); other users of the same
        constraints pack (scale_x, shift_x, scale_y, shift_y) - confirm.
    :param crop_size: tuple of 2 ints, size of the resulting crop
    :return: the resampled crop, as returned by `snt.resampler`
    """
    constraints = snt.AffineWarpConstraints.no_shear_2d()

    img_shape = img_tensor.shape.as_list()
    # Only the two spatial dimensions define the source grid. For rank-3
    # input this equals the former `shape[1:]`; for rank-4 input it drops the
    # channel dimension, which the 2D constraints would otherwise reject.
    img_size = img_shape[1:3]

    warper = snt.AffineGridWarper(img_size, crop_size, constraints)
    grid_coords = warper(transform_params)

    if len(img_shape) == 3:
        # The resampler needs an explicit channel axis.
        img_tensor = img_tensor[..., tf.newaxis]

    glimpse = snt.resampler(img_tensor, grid_coords)
    return glimpse
def __init__(self, img_size, crop_size, inverse=False):
    """Initialises the module.

    :param img_size: Tuple of ints, size of the input image; only the first
        two entries are used as the warper's source shape.
    :param crop_size: Tuple of ints, size of the resampled image.
    :param inverse: Boolean; inverts the given transformation if True and
        then maps crops into full-sized images.
    """
    super(SpatialTransformer, self).__init__()

    with self._enter_variable_scope():
        no_shear = snt.AffineWarpConstraints.no_shear_2d()
        spatial_size = img_size[:2]
        forward_warper = snt.AffineGridWarper(spatial_size, crop_size, no_shear)
        self._warper = forward_warper.inverse() if inverse else forward_warper
def testIdentity(self):
    """Feeds the parameters [1, 0, 0, 0, 1, 0] to a 3x3 -> 3x3 warper.

    Check that output matches expected result for a known transformation.
    """
    warp_constraints = snt.AffineWarpConstraints.no_constraints()
    n_params = warp_constraints.num_free_params

    grid_warper = snt.AffineGridWarper([3, 3], [3, 3],
                                       constraints=warp_constraints)
    params_ph = tf.placeholder(tf.float64, (None, n_params))
    warped = grid_warper(params_ph)

    identity_params = np.array([1, 0, 0, 0, 1, 0]).reshape([1, n_params])
    expected_grid = np.array([[
        [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]],
        [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]],
        [[0.0, 2.0], [1.0, 2.0], [2.0, 2.0]],
    ]])

    with self.test_session() as sess:
        output = sess.run(warped, feed_dict={params_ph: identity_params})

    self.assertAllClose(output, expected_grid)
[corners, np.ones((1, corners.shape[1]))], axis=0) left = image_corners[0, 0] top = image_corners[1, 0] right = image_corners[0, 1] bottom = image_corners[1, 1] width = right - left height = bottom - top boxes = boxes[:, [0, 2, 4, 5]] boxes = np.tile(boxes, (n_examples, 1)) boxes = tf.constant(boxes, tf.float32) transform_constraints = snt.AffineWarpConstraints.no_shear_2d() warper = snt.AffineGridWarper(image_shape[:2], crop_shape[:2], transform_constraints) grid_coords = warper(boxes) output = tf.contrib.resampler.resampler(images, grid_coords) sess = tf.Session() crops, _grid_coords = sess.run([output, grid_coords]) print(_grid_coords) # import matplotlib # matplotlib.use('pdf') import matplotlib.pyplot as plt import matplotlib.patches as patches fig, axes = plt.subplots(n_examples, 2) for image, crop, ax in zip(images, crops, axes):
def _build(self, pose, presence=None, template_feature=None, bg_image=None, img_embedding=None):
    """Builds the module.

    Args:
      pose: [B, n_templates, 6] tensor.
      presence: [B, n_templates] tensor, or None. When None, no presence
        term is added to the mixing logits.
      template_feature: [B, n_templates, n_features] tensor; these features
        are used to change templates based on the input, if present.
      bg_image: [B, *output_size] tensor representing the background.
      img_embedding: [B, d] tensor containing image embeddings (not used in
        this method).

    Returns:
      AttrDict with entries `raw_templates`, `transformed_templates`,
      `mixing_logits` (both without the appended background slot) and `pdf`.

    Raises:
      ValueError: if `self._output_pdf_type` is not 'mixture'.
    """
    batch_size, n_templates = pose.shape[:2].as_list()
    templates = self.make_templates(n_templates, template_feature)

    # A single shared template set is tiled across the batch.
    if templates.shape[0] == 1:
        templates = snt.TileByDim([0], [batch_size])(templates)

    # it's easier for me to think in inverse coordinates
    warper = snt.AffineGridWarper(self._output_size, self._template_size)
    warper = warper.inverse()

    grid_coords = snt.BatchApply(warper)(pose)
    resampler = snt.BatchApply(contrib_resampler.resampler)
    transformed_templates = resampler(templates, grid_coords)

    if bg_image is not None:
        bg_image = tf.expand_dims(bg_image, axis=1)
    else:
        # Learned scalar background, broadcast to one template-sized slot.
        bg_image = tf.nn.sigmoid(tf.get_variable('bg_value', shape=[1]))
        bg_image = tf.zeros_like(transformed_templates[:, :1]) + bg_image

    # The background is appended as an extra "template" along axis 1.
    transformed_templates = tf.concat([transformed_templates, bg_image], axis=1)

    if presence is not None:
        # The background slot gets presence 1.
        presence = tf.concat([presence, tf.ones([batch_size, 1])], axis=1)

    if self._use_alpha_channel:
        template_mixing_logits = snt.TileByDim([0], [batch_size])(
            self._templates_alpha)
        template_mixing_logits = resampler(template_mixing_logits, grid_coords)

        bg_mixing_logit = tf.nn.softplus(
            tf.get_variable('bg_mixing_logit', initializer=[0.]))
        bg_mixing_logit = (
            tf.zeros_like(template_mixing_logits[:, :1]) + bg_mixing_logit)

        template_mixing_logits = tf.concat(
            [template_mixing_logits, bg_mixing_logit], 1)
    else:
        temperature_logit = tf.get_variable('temperature_logit', shape=[1])
        temperature = tf.nn.softplus(temperature_logit + .5) + 1e-4
        template_mixing_logits = transformed_templates / temperature

    scale = 1.
    if self._learn_output_scale:
        scale = tf.get_variable('scale', shape=[1])
        scale = tf.nn.softplus(scale) + 1e-4

    if self._output_pdf_type == 'mixture':
        # `presence` defaults to None; skipping the term is equivalent to a
        # presence of 1 everywhere (log 1 == 0), and avoids handing None to
        # safe_log (presumably an error - the helper is defined elsewhere).
        if presence is not None:
            template_mixing_logits += make_brodcastable(
                math_ops.safe_log(presence), template_mixing_logits)

        rec_pdf = prob.MixtureDistribution(template_mixing_logits,
                                           [transformed_templates, scale],
                                           tfd.Normal)
    else:
        raise ValueError('Unknown pdf type: "{}".'.format(
            self._output_pdf_type))

    # The trailing slot along axis 1 is the background; strip it from the
    # per-template outputs.
    return AttrDict(raw_templates=tf.squeeze(self._templates, 0),
                    transformed_templates=transformed_templates[:, :-1],
                    mixing_logits=template_mixing_logits[:, :-1],
                    pdf=rec_pdf)
def _build_program_interpreter(self, tensors):
    """Encode glimpses into attributes, decode them into sprites and render.

    Reads "max_objects", "normalized_box", "inp", "obj", "n_objects" and
    "background" from `tensors`.

    Returns:
        dict with keys "output", "glimpse", "attr", "attr_kl" and "objects".
    """
    # --- Get object attributes using object encoder ---
    max_objects = tensors["max_objects"]
    box_top, box_left, box_h, box_w = tf.split(
        tensors["normalized_box"], 4, axis=-1)

    no_shear = snt.AffineWarpConstraints.no_shear_2d()
    grid_warper = snt.AffineGridWarper(
        (self.image_height, self.image_width), self.object_shape, no_shear)

    # Warp parameters are (x-scale, x-centre, y-scale, y-centre), with the
    # centre obtained from top/left + half size and mapped via c -> 2c - 1.
    center_x = 2 * (box_left + box_w / 2) - 1
    center_y = 2 * (box_top + box_h / 2) - 1
    warp_params = tf.concat([box_w, center_x, box_h, center_y], axis=-1)
    warp_params = tf.reshape(warp_params, (self.batch_size * max_objects, 4))

    sample_grid = grid_warper(warp_params)
    sample_grid = tf.reshape(
        sample_grid, (self.batch_size, max_objects, *self.object_shape, 2))
    glimpse = tf.contrib.resampler.resampler(tensors["inp"], sample_grid)

    encoder_input = tf.reshape(
        glimpse,
        (self.batch_size * max_objects, *self.object_shape, self.image_depth))
    encoded = self.object_encoder(
        encoder_input, (1, 1, 2 * self.A), self.is_training)
    encoded = tf.reshape(encoded, (self.batch_size, max_objects, 2 * self.A))

    attr_mean, attr_log_std = tf.split(encoded, [self.A, self.A], axis=-1)
    attr_std = tf.exp(attr_log_std)
    if not self.noisy:
        attr_std = tf.zeros_like(attr_std)

    attr, attr_kl = normal_vae(
        attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)

    decoder_input = tf.reshape(
        attr, (self.batch_size * max_objects, 1, 1, self.A))

    # --- Compute sprites from attr using object decoder ---
    sprite_logits = self.object_decoder(
        decoder_input,
        self.object_shape + (self.image_depth, ),
        self.is_training)
    sprites = tf.nn.sigmoid(tf.clip_by_value(sprite_logits, -10., 10.))
    sprites = tf.reshape(
        sprites,
        (self.batch_size, max_objects, *self.object_shape, self.image_depth))

    # Two extra single-channel planes are appended to each sprite: alpha
    # (tensors["obj"] broadcast over the sprite) and importance (all ones).
    alpha = tensors["obj"][:, :, :, None, None] * tf.ones_like(
        sprites[:, :, :, :, :1])
    importance = tf.ones_like(sprites[:, :, :, :, :1])
    sprites = tf.concat([sprites, alpha, importance], axis=-1)

    # -- Reconstruct image ---
    scales = tf.reshape(
        tf.concat([box_h, box_w], axis=-1),
        (self.batch_size, max_objects, 2))
    offsets = tf.reshape(
        tf.concat([box_top, box_left], axis=-1),
        (self.batch_size, max_objects, 2))

    output = render_sprites.render_sprites(
        sprites, tensors["n_objects"], scales, offsets, tensors["background"])

    return {
        "output": output,
        "glimpse": tf.reshape(
            glimpse,
            (self.batch_size, max_objects, *self.object_shape,
             self.image_depth)),
        "attr": tf.reshape(attr, (self.batch_size, max_objects, self.A)),
        "attr_kl": tf.reshape(
            attr_kl, (self.batch_size, max_objects, self.A)),
        "objects": tf.reshape(
            sprites,
            (self.batch_size, max_objects, *self.object_shape,
             self.image_depth)),
    }
def build_representation(self):
    """Build the tracking graph: encode frames, run trackers, render, add losses.

    Steps, in order:
      1. Lazily create the sub-networks that are still None.
      2. Encode every frame (with two extra coordinate channels) through the
         backbone.
      3. Run `self._loop_body` in a `tf.while_loop`, collecting per-frame
         tracker outputs in TensorArrays.
      4. Warp each tracker's mask/appearance into image space and composite
         them layer by layer onto a black background.
      5. Register the 'reconstruction' and 'area' losses.

    Side effects: sets several attributes on `self` (sub-networks, `H`, `W`,
    `encoded_frames`, temperatures), updates `self._tensors` and
    `self.losses`, and pretty-prints `self._tensors`.

    NOTE(review): block extents were reconstructed from a
    whitespace-collapsed source - confirm the indentation.
    """
    # --- init modules ---
    if self.backbone is None:
        self.backbone = self.build_backbone(scope="backbone")

    if self.cell is None:
        self.cell = cfg.build_cell(self.n_hidden, name="cell")

    # self.cell must be a Sonnet RNNCore
    if self.learn_initial_state:
        self.initial_hidden_state = snt.trainable_initial_state(
            1, self.cell.state_size, tf.float32, name="initial_hidden_state")

    if self.key_network is None:
        self.key_network = cfg.build_key_network(scope="key_network")

    if self.beta_network is None:
        self.beta_network = cfg.build_beta_network(scope="beta_network")

    if self.write_network is None:
        self.write_network = cfg.build_write_network(scope="write_network")

    if self.erase_network is None:
        self.erase_network = cfg.build_erase_network(scope="erase_network")

    if self.output_network is None:
        self.output_network = cfg.build_output_network(
            scope="output_network")

    d_n_frames, n_trackers, batch_size = self.dynamic_n_frames, self.n_trackers, self.batch_size

    # --- encode ---
    # Fold the frame axis into the batch axis so the backbone sees one image
    # per row.
    video = tf.reshape(self.inp, (batch_size * d_n_frames, *self.obs_shape[1:]))

    # Append two channels holding each pixel's Y and X coordinate in [0, 1].
    zero_to_one_Y = tf.linspace(0., 1., self.image_height)
    zero_to_one_X = tf.linspace(0., 1., self.image_width)
    X, Y = tf.meshgrid(zero_to_one_X, zero_to_one_Y)
    X = tf.tile(X[None, :, :, None], (batch_size * d_n_frames, 1, 1, 1))
    Y = tf.tile(Y[None, :, :, None], (batch_size * d_n_frames, 1, 1, 1))
    video = tf.concat([video, Y, X], axis=-1)

    encoded_frames, _, _ = self.backbone(video, self.S, self.is_training)

    _, H, W, _ = tf_shape(encoded_frames)
    self.H = H
    self.W = W

    # Flatten the spatial grid of the encoding into a single axis of H * W.
    encoded_frames = tf.reshape(encoded_frames, (batch_size, d_n_frames, H * W, self.S))
    self.encoded_frames = encoded_frames

    # cts is capped at 1. Assuming float_is_training is 0/1 (TODO confirm),
    # it is 0 only when evaluating with discrete_eval set, which selects the
    # 1e-5 temperature below; otherwise the temperature is 1.0.
    cts = tf.minimum(
        1., self.float_is_training + (1 - float(self.discrete_eval)))
    self.mask_temperature = cts * 1.0 + (1 - cts) * 1e-5
    self.layer_temperature = cts * 1.0 + (1 - cts) * 1e-5

    # Frame counter, first loop variable of the while_loop.
    f = tf.constant(0, dtype=tf.int32)

    if self.learn_initial_state:
        hidden_state = self.initial_hidden_state[None, ...]
    else:
        hidden_state = self.cell.zero_state(1, tf.float32)[None, ...]

    # Replicate the single initial state for every (tracker, batch) pair.
    hidden_states = tf.tile(hidden_state, (
        self.n_trackers,
        self.batch_size,
    ) + (1, ) * (len(hidden_state.shape) - 2))

    conf = tf.zeros((n_trackers, batch_size, 1))

    # Template of what is recorded per frame; only used to create the
    # TensorArrays and to re-attach keys after the loop.
    structure = dict(
        hidden_states=hidden_states,
        tracker_output=tf.zeros(
            (self.n_trackers, self.batch_size, self.n_hidden)),
        attention_result=tf.zeros(
            (self.n_trackers, self.batch_size, self.S)),
        attention_weights=tf.zeros(
            (self.n_trackers, self.batch_size, H * W)),
        memory_activation=tf.zeros(
            (self.n_trackers, self.batch_size, H * W)),
        conf=conf,
        layer=tf.zeros((self.n_trackers, self.batch_size, self.n_layers)),
        pose=tf.zeros((self.n_trackers, self.batch_size, 4)),
        mask=tf.zeros((self.n_trackers, self.batch_size,
                       np.prod(self.object_shape))),
        appearance=tf.zeros((self.n_trackers, self.batch_size,
                             3 * np.prod(self.object_shape))),
        order=tf.zeros((self.n_trackers, self.batch_size, 1),
                       dtype=tf.int32),
    )

    tensor_arrays = make_tensor_arrays(structure, self.dynamic_n_frames)

    loop_vars = [f, conf, hidden_states, *tensor_arrays]

    result = tf.while_loop(self._loop_cond, self._loop_body, loop_vars)

    # Everything from the first TensorArray onwards is per-frame output.
    first_ta_idx = min(i for i, ta in enumerate(result)
                       if isinstance(ta, tf.TensorArray))
    tensor_arrays = result[first_ta_idx:]

    def finalize_ta(ta):
        t = ta.stack()
        # reshape from (n_frames, n_trackers, batch_size, *other) to (batch_size, n_frames, n_trackers, *other)
        return tf.transpose(t, (2, 0, 1, *range(3, len(t.shape))))

    tensors = map_structure(
        finalize_ta,
        tensor_arrays,
        is_leaf=lambda t: isinstance(t, tf.TensorArray))
    tensors = apply_keys(structure, tensors)

    self._tensors.update(tensors)

    pprint.pprint(self._tensors)

    # --- render/decode ---
    ys, xs, yt, xt = tf.split(tensors["pose"], 4, axis=-1)

    self.record_tensors(
        conf=tensors["conf"],
        ys=ys,
        xs=xs,
        yt=yt,
        xt=xt,
    )

    # Translations: t -> (t + 1) / 2. Scales: 1 + eta * s.
    yt_normed = (yt + 1) / 2
    xt_normed = (xt + 1) / 2
    ys_normed = 1 + self.eta[0] * ys
    xs_normed = 1 + self.eta[1] * xs

    normalized_box = tf.concat(
        [yt_normed, xt_normed, ys_normed, xs_normed], axis=-1)

    # expose values for plotting
    self._tensors.update(
        normalized_box=normalized_box,
        mask=tf.reshape(
            tensors["mask"],
            (batch_size, d_n_frames, n_trackers, *self.object_shape)),
        appearance=tf.reshape(tensors["appearance"],
                              (batch_size, d_n_frames, n_trackers,
                               *self.object_shape, self.image_depth)),
        conf=tensors["conf"],
        layer=tensors["layer"],
        order=tensors["order"],
    )

    # --- reshape values ---
    # One row per (batch, frame, tracker) triple.
    N = batch_size * d_n_frames * n_trackers
    ys_normed = tf.reshape(ys_normed, (N, 1))
    xs_normed = tf.reshape(xs_normed, (N, 1))
    yt_normed = tf.reshape(yt_normed, (N, 1))
    xt_normed = tf.reshape(xt_normed, (N, 1))

    _yt, _xt, _ys, _xs = tba_coords_to_image_space(
        yt_normed,
        xt_normed,
        ys_normed,
        xs_normed, (self.image_height, self.image_width),
        self.anchor_box,
        top_left=False)

    mask = tf.reshape(tensors["mask"], (N, *self.object_shape, 1))
    appearance = tf.reshape(tensors["appearance"],
                            (N, *self.object_shape, self.image_depth))

    transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
    warper = snt.AffineGridWarper((self.image_height, self.image_width),
                                  self.object_shape, transform_constraints)
    # The inverse warper is used to place object-sized masks/appearances
    # into full-size images (see the reshapes to image_height x image_width
    # below).
    inverse_warper = warper.inverse()
    transforms = tf.concat([_xs, _xt, _ys, _yt], axis=-1)
    grid_coords = inverse_warper(transforms)

    transformed_masks = tf.contrib.resampler.resampler(mask, grid_coords)
    transformed_masks = tf.reshape(
        transformed_masks,
        (batch_size, d_n_frames, n_trackers, self.image_height,
         self.image_width, 1))

    transformed_appearances = tf.contrib.resampler.resampler(
        appearance, grid_coords)
    transformed_appearances = tf.reshape(
        transformed_appearances,
        (batch_size, d_n_frames, n_trackers, self.image_height,
         self.image_width, self.image_depth))

    layer_masks = []
    layer_appearances = []

    conf = tensors["conf"][:, :, :, :, None, None]

    # TODO: currently assuming a black background
    final_frames = tf.zeros((batch_size, d_n_frames, self.image_height,
                             self.image_width, self.image_depth))

    # For each layer, create a mask image and an appearance image
    for layer_idx in range(self.n_layers):
        layer_weight = tensors["layer"][:, :, :, layer_idx, None, None, None]

        # (batch_size, n_frames, self.image_height, self.image_width, 1)
        layer_mask = tf.reduce_sum(conf * layer_weight * transformed_masks,
                                   axis=2)
        layer_mask = tf.minimum(1.0, layer_mask)

        # (batch_size, n_frames, self.image_height, self.image_width, 3)
        layer_appearance = tf.reduce_sum(conf * layer_weight *
                                         transformed_masks *
                                         transformed_appearances,
                                         axis=2)

        if self.clamp_appearance:
            layer_appearance = tf.minimum(1.0, layer_appearance)

        # Composite this layer over everything rendered so far.
        final_frames = (1 - layer_mask) * final_frames + layer_appearance

        layer_masks.append(layer_mask)
        layer_appearances.append(layer_appearance)

    self._tensors["output"] = final_frames

    # --- losses ---
    self._tensors['per_pixel_reconstruction_loss'] = (self.inp -
                                                      final_frames)**2

    # Summed squared error, normalised by the number of frames in the batch.
    self.losses['reconstruction'] = (
        tf.reduce_sum(self._tensors["per_pixel_reconstruction_loss"]) /
        tf.cast(d_n_frames * self.batch_size, tf.float32))
    # self.losses['reconstruction'] = tf.reduce_mean(self._tensors["per_pixel_reconstruction_loss"])

    # Penalise the mean product of the normalised box scales.
    self.losses['area'] = self.lmbda * tf.reduce_mean(
        ys_normed * xs_normed)
def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
    """Build box, attr, z and obj for every grid cell in one convolutional pass.

    Each stage feeds a lateral network with the base features, the
    pass-through features of the previous stage and the values built so far.

    Args:
        inp: input images; their shape is read once to set image_height /
            image_width / image_depth on the first call.
        inp_features: per-cell features, unpacked as a rank-5 shape and
            reshaped to (batch_size, H, W, n_channels).
        is_training: training flag forwarded to all sub-networks.
        is_posterior: when True, glimpses are extracted from `inp` and
            encoded; when False, glimpses and their encoding are zeros.
        prop_state: optional state; if given, its first row is tiled to every
            object as `prop_state` and `prior_prop_state`.

    Returns:
        AttrDict of tensors reshaped to (batch_size, H*W, ...), plus
        `n_objects`, `pred_n_objects` and `pred_n_objects_hard`.

    Raises:
        Exception: if `self.B != 1`.

    NOTE(review): block extents were reconstructed from a
    whitespace-collapsed source - confirm the indentation.
    """
    print("\n" + "-" * 10 + " ConvGridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

    # --- set up sub networks and attributes ---
    self.maybe_build_subnet("box_network", builder=cfg.build_conv_lateral, key="box")
    self.maybe_build_subnet("attr_network", builder=cfg.build_conv_lateral, key="attr")
    self.maybe_build_subnet("z_network", builder=cfg.build_conv_lateral, key="z")
    self.maybe_build_subnet("obj_network", builder=cfg.build_conv_lateral, key="obj")
    self.maybe_build_subnet("object_encoder")

    _, H, W, _, n_channels = tf_shape(inp_features)

    if self.B != 1:
        raise Exception("NotImplemented")

    if not self.initialized:
        # Note this limits the re-usability of this module to images
        # with a fixed shape (the shape of the first image it is used on)
        self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
        self.H = H
        self.W = W
        self.HWB = H*W
        # Overrides the batch size unpacked above with the dynamic one.
        self.batch_size = tf.shape(inp)[0]
        self.is_training = is_training
        self.float_is_training = tf.to_float(is_training)

    inp_features = tf.reshape(inp_features, (self.batch_size, H, W, n_channels))

    # Two-channel one-hot flag per cell: [1, 0] for posterior, [0, 1] for prior.
    is_posterior_tf = tf.ones_like(inp_features[..., :2])
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    objects = AttrDict()

    base_features = tf.concat([inp_features, is_posterior_tf], axis=-1)

    # --- box ---
    layer_inp = base_features
    n_features = self.n_passthrough_features
    output_size = 8

    # The network emits the box representation plus pass-through features
    # consumed by the next stage.
    network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
    rep_input, features = tf.split(network_output, (output_size, n_features), axis=-1)

    _objects = self._build_box(rep_input, self.is_training)
    objects.update(_objects)

    # --- attr ---
    if is_posterior:
        # --- Get object attributes using object encoder ---
        yt, xt, ys, xs = tf.split(objects['normalized_box'], 4, axis=-1)

        yt, xt, ys, xs = coords_to_image_space(
            yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

        transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
        warper = snt.AffineGridWarper(
            (self.image_height, self.image_width), self.object_shape, transform_constraints)

        # Warp parameters: (x-scale, x-translation, y-scale, y-translation),
        # with translations mapped via t -> 2t - 1.
        _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)
        _boxes = tf.reshape(_boxes, (self.batch_size*H*W, 4))
        grid_coords = warper(_boxes)
        grid_coords = tf.reshape(grid_coords, (self.batch_size, H, W, *self.object_shape, 2,))

        if self.edge_resampler:
            glimpse = resampler_edge.resampler_edge(inp, grid_coords)
        else:
            glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
    else:
        glimpse = tf.zeros((self.batch_size, H, W, *self.object_shape, self.image_depth))

    # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
    encoded_glimpse = apply_object_wise(
        self.object_encoder, glimpse, n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

    if not is_posterior:
        encoded_glimpse = tf.zeros_like(encoded_glimpse)

    layer_inp = tf.concat([base_features, features, encoded_glimpse, objects['local_box']], axis=-1)
    network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
    attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=-1)

    attr_std = self.std_nonlinearity(attr_log_std)

    attr = Normal(loc=attr_mean, scale=attr_std).sample()

    objects.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)

    # --- z ---
    layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr']], axis=-1)
    n_features = self.n_passthrough_features

    network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
    z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=-1)
    z_std = self.std_nonlinearity(z_log_std)

    # Blend between the stop-gradient and the differentiable value;
    # training_wheels == 1 blocks all gradient through z_mean / z_std.
    z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
    z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
    z_logit = Normal(loc=z_mean, scale=z_std).sample()
    z = self.z_nonlinearity(z_logit)

    objects.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)

    # --- obj ---
    layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr'], objects['z']], axis=-1)
    rep_input = self.obj_network(layer_inp, 1, self.is_training)

    _objects = self._build_obj(rep_input, self.is_training)
    objects.update(_objects)

    # --- final ---
    # Collapse the (H, W) grid axes of every tensor into a single object axis.
    _objects = AttrDict()
    for k, v in objects.items():
        _, _, _, *trailing_dims = tf_shape(v)
        _objects[k] = tf.reshape(v, (self.batch_size, self.HWB, *trailing_dims))
    objects = _objects

    if prop_state is not None:
        objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
        objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

    # --- misc ---
    objects.n_objects = tf.fill((self.batch_size,), self.HWB)
    objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
    objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

    return objects
def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
    """Build box, attr, z and obj sequentially for every (h, w, b) grid slot.

    Slots are visited in `itertools.product(range(H), range(W), range(B))`
    order. Each slot is conditioned on a context gathered from the slots
    built so far (`self._get_sequential_context`), with a learned "edge
    element" passed in alongside the partially built program.

    Args:
        inp: input images; their shape is read on the first call to set
            batch_size / image_height / image_width / image_depth.
        inp_features: per-slot features, indexed as
            `inp_features[:, h, w, b, :]`.
        is_training: training flag forwarded to all sub-networks.
        is_posterior: when True, a glimpse is extracted from `inp` and
            encoded for every slot; when False, zeros are used instead.
        prop_state: optional state; if given, its first row is tiled to every
            object as `prop_state` and `prior_prop_state`.

    Returns:
        AttrDict whose entries stack the per-slot tensors along axis 1, plus
        `n_objects`, `pred_n_objects` and `pred_n_objects_hard`.

    NOTE(review): block extents were reconstructed from a
    whitespace-collapsed source - confirm the indentation.
    """
    print("\n" + "-" * 10 + " GridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

    # --- set up sub networks and attributes ---
    self.maybe_build_subnet("box_network", builder=cfg.build_lateral, key="box")
    self.maybe_build_subnet("attr_network", builder=cfg.build_lateral, key="attr")
    self.maybe_build_subnet("z_network", builder=cfg.build_lateral, key="z")
    self.maybe_build_subnet("obj_network", builder=cfg.build_lateral, key="obj")
    self.maybe_build_subnet("object_encoder")

    _, H, W, _, _ = tf_shape(inp_features)
    # H and W drive Python-level loops below, so they must be concrete ints.
    H = int(H)
    W = int(W)

    if not self.initialized:
        # Note this limits the re-usability of this module to images
        # with a fixed shape (the shape of the first image it is used on)
        self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
        self.H = H
        self.W = W
        self.HWB = H*W*self.B
        self.is_training = is_training
        self.float_is_training = tf.to_float(is_training)

    # --- set up the edge element ---
    # Per-slot sample layout: box (4), attr (A), z (1), obj (1). The flags
    # mark which segments of the learned edge weights go through a sigmoid.
    sizes = [4, self.A, 1, 1]
    sigmoids = [True, False, False, True]
    total_sample_size = sum(sizes)

    if self.edge_weights is None:
        self.edge_weights = tf.get_variable("edge_weights", shape=total_sample_size, dtype=tf.float32)
        if "edge" in self.fixed_weights:
            tf.add_to_collection(FIXED_COLLECTION, self.edge_weights)

    _edge_weights = tf.split(self.edge_weights, sizes, axis=0)
    _edge_weights = [
        (tf.nn.sigmoid(ew) if sigmoid else ew)
        for ew, sigmoid in zip(_edge_weights, sigmoids)]
    edge_element = tf.concat(_edge_weights, axis=0)
    edge_element = tf.tile(edge_element[None, :], (self.batch_size, 1))

    # --- containers for storing built program ---
    # Builtin `object` instead of the `np.object` alias, which was deprecated
    # in NumPy 1.20 and removed in 1.24 (AttributeError there).
    program = np.empty((H, W, self.B), dtype=object)

    # --- build the program ---
    # Two-element one-hot flag: [1, 0] for posterior, [0, 1] for prior.
    is_posterior_tf = tf.ones((self.batch_size, 2))
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    results = []
    for h, w, b in itertools.product(range(H), range(W), range(self.B)):
        built = dict()
        partial_program, features = None, None

        context = self._get_sequential_context(program, h, w, b, edge_element)
        base_features = tf.concat([inp_features[:, h, w, b, :], context, is_posterior_tf], axis=1)

        # --- box ---
        layer_inp = base_features
        n_features = self.n_passthrough_features
        output_size = 8

        # The network emits the box representation plus pass-through
        # features consumed by the next stage.
        network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
        rep_input, features = tf.split(network_output, (output_size, n_features), axis=1)

        _built = self._build_box(rep_input, self.is_training, hw=(h, w))
        built.update(_built)
        partial_program = built['local_box']

        # --- attr ---
        if is_posterior:
            # --- Get object attributes using object encoder ---
            yt, xt, ys, xs = tf.split(built['normalized_box'], 4, axis=-1)

            yt, xt, ys, xs = coords_to_image_space(
                yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
            warper = snt.AffineGridWarper(
                (self.image_height, self.image_width), self.object_shape, transform_constraints)

            # Warp parameters: (x-scale, x-translation, y-scale,
            # y-translation), with translations mapped via t -> 2t - 1.
            _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)

            grid_coords = warper(_boxes)
            grid_coords = tf.reshape(grid_coords, (self.batch_size, 1, *self.object_shape, 2,))

            if self.edge_resampler:
                glimpse = resampler_edge.resampler_edge(inp, grid_coords)
            else:
                glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
            glimpse = tf.reshape(glimpse, (self.batch_size, *self.object_shape, self.image_depth))
        else:
            glimpse = tf.zeros((self.batch_size, *self.object_shape, self.image_depth))

        # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
        encoded_glimpse = self.object_encoder(glimpse, (1, 1, self.A), self.is_training)
        encoded_glimpse = tf.reshape(encoded_glimpse, (self.batch_size, self.A))

        if not is_posterior:
            encoded_glimpse = tf.zeros_like(encoded_glimpse)

        layer_inp = tf.concat(
            [base_features, features, encoded_glimpse, partial_program], axis=1)
        network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
        attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=1)

        attr_std = self.std_nonlinearity(attr_log_std)

        attr = Normal(loc=attr_mean, scale=attr_std).sample()

        built.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)
        partial_program = tf.concat([partial_program, built['attr']], axis=1)

        # --- z ---
        layer_inp = tf.concat([base_features, features, partial_program], axis=1)
        n_features = self.n_passthrough_features

        network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
        z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=1)
        z_std = self.std_nonlinearity(z_log_std)

        # Blend between the stop-gradient and the differentiable value;
        # training_wheels == 1 blocks all gradient through z_mean / z_std.
        z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
        z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
        z_logit = Normal(loc=z_mean, scale=z_std).sample()
        z = self.z_nonlinearity(z_logit)

        built.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)
        partial_program = tf.concat([partial_program, built['z']], axis=1)

        # --- obj ---
        layer_inp = tf.concat([base_features, features, partial_program], axis=1)
        rep_input = self.obj_network(layer_inp, 1, self.is_training)

        _built = self._build_obj(rep_input, self.is_training)
        built.update(_built)
        partial_program = tf.concat([partial_program, built['obj']], axis=1)

        # --- final ---
        results.append(built)

        # Stored so later slots can read this one as context.
        program[h, w, b] = partial_program
        assert program[h, w, b].shape[1] == total_sample_size

    # Stack the per-slot results along a new object axis.
    objects = AttrDict()
    for k in results[0]:
        objects[k] = tf.stack([r[k] for r in results], axis=1)

    if prop_state is not None:
        objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
        objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

    # --- misc ---
    objects.n_objects = tf.fill((self.batch_size,), self.HWB)
    objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
    objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

    return objects
def _call(self, inp, inp_features, is_training, is_posterior=True):
    """Build the object "program" for every cell of the H x W x B grid.

    Cells are visited in ``itertools.product(range(H), range(W), range(B))``
    order. For each cell four stages run in sequence -- box, attr, z, obj --
    and every stage is conditioned on the cell's base features (input
    features, sequential context from previously built cells, and a
    posterior/prior indicator), on the pass-through features emitted by the
    previous stage, and on the part of the program built so far.

    Args:
        inp: Image tensor. Its last three dimensions are read as
            (height, width, depth); the leading dimension is the batch.
        inp_features: 5-D feature tensor, indexed as
            ``inp_features[:, h, w, b, :]``.
        is_training: Training flag forwarded to every sub-network.
        is_posterior: Python bool. When true, a glimpse is cropped from
            ``inp`` at the predicted box and encoded to condition the
            attribute stage; when false, a zero glimpse is used and the
            encoded glimpse is zeroed out.

    Returns:
        An ``AttrDict``. Plain tensors from the individual cells are stacked
        into a single tensor with leading dimensions (batch, H, W, B);
        entries whose key ends in ``_dist`` are rebuilt as one distribution
        whose tensor parameters are stacked the same way. Additionally
        contains ``all`` (box, attr, z, obj concatenated on the last axis),
        ``n_objects``, ``pred_n_objects`` and ``pred_n_objects_hard``.

    Note:
        Image/grid dimensions, ``batch_size`` and the training flag are
        stored on ``self`` only while ``self.initialized`` is false, so the
        module is tied to the shape of the first image it is used on.
    """
    print("\n" + "-" * 10 + " GridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

    # --- set up sub networks and attributes ---

    self.maybe_build_subnet("box_network", builder=cfg.build_lateral, key="box")
    self.maybe_build_subnet("attr_network", builder=cfg.build_lateral, key="attr")
    self.maybe_build_subnet("z_network", builder=cfg.build_lateral, key="z")
    self.maybe_build_subnet("obj_network", builder=cfg.build_lateral, key="obj")

    self.maybe_build_subnet("object_encoder")

    _, H, W, _, _ = inp_features.shape
    H = int(H)
    W = int(W)

    if not self.initialized:
        # Note this limits the re-usability of this module to images
        # with a fixed shape (the shape of the first image it is used on)
        self.image_height = int(inp.shape[-3])
        self.image_width = int(inp.shape[-2])
        self.image_depth = int(inp.shape[-1])
        self.H = H
        self.W = W
        self.HWB = H * W * self.B
        self.batch_size = tf.shape(inp)[0]
        self.is_training = is_training
        self.float_is_training = tf.to_float(is_training)

    # --- set up the edge element ---

    # Layout of one cell's program: box (4), attr (A), z (1), obj (1).
    # The box and obj slices of the learned edge element are squashed
    # through a sigmoid; attr and z are used as-is.
    sizes = [4, self.A, 1, 1]
    sigmoids = [True, False, False, True]
    total_sample_size = sum(sizes)

    if self.edge_weights is None:
        self.edge_weights = tf.get_variable("edge_weights", shape=total_sample_size, dtype=tf.float32)
        if "edge" in self.fixed_weights:
            tf.add_to_collection(FIXED_COLLECTION, self.edge_weights)

    _edge_weights = tf.split(self.edge_weights, sizes, axis=0)
    _edge_weights = [
        (tf.nn.sigmoid(ew) if sigmoid else ew)
        for ew, sigmoid in zip(_edge_weights, sigmoids)
    ]
    edge_element = tf.concat(_edge_weights, axis=0)
    edge_element = tf.tile(edge_element[None, :], (self.batch_size, 1))

    # --- containers for storing built program ---

    _tensors = collections.defaultdict(self._make_empty)
    # `np.object` was deprecated in NumPy 1.20 and removed in 1.24; the
    # builtin `object` is the equivalent dtype on every NumPy version.
    program = np.empty((H, W, self.B), dtype=object)

    def _record(cell_tensors, h, w, b):
        # Store every tensor/distribution built for cell (h, w, b).
        for key, value in cell_tensors.items():
            _tensors[key][h, w, b] = value

    def _stack_grid(cells, param_key=None):
        # Merge the per-cell entries of `cells` into one tensor by stacking
        # over b, then w, then h (each time on axis 1). When `param_key` is
        # given, the entries are distributions and that parameter is stacked.
        rows = []
        for h in range(H):
            cols = []
            for w in range(W):
                per_box = []
                for b in range(self.B):
                    cell = cells[h, w, b]
                    per_box.append(cell if param_key is None else cell.parameters[param_key])
                cols.append(tf.stack(per_box, axis=1))
            rows.append(tf.stack(cols, axis=1))
        return tf.stack(rows, axis=1)

    # --- build the program ---

    # Two-element indicator appended to every cell's features:
    # [1, 0] for the posterior, [0, 1] for the prior.
    is_posterior_tf = tf.ones((self.batch_size, 2))
    if is_posterior:
        is_posterior_tf = is_posterior_tf * [1, 0]
    else:
        is_posterior_tf = is_posterior_tf * [0, 1]

    for h, w, b in itertools.product(range(H), range(W), range(self.B)):
        context = self._get_sequential_context(program, h, w, b, edge_element)

        base_features = tf.concat(
            [inp_features[:, h, w, b, :], context, is_posterior_tf], axis=1)

        # --- box ---

        layer_inp = base_features
        n_features = self.n_passthrough_features
        output_size = 8

        network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
        rep_input, features = tf.split(network_output, (output_size, n_features), axis=1)

        built = self._build_box(rep_input, h, w, b, self.is_training)

        _record(built, h, w, b)
        partial_program = built['box']

        # --- attr ---

        if is_posterior:
            # --- Get object attributes using object encoder ---

            yt, xt, ys, xs = tf.split(built['normalized_box'], 4, axis=-1)

            # yt/xt give top/left but here we need center
            yt += ys / 2
            xt += xs / 2

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
            warper = snt.AffineGridWarper(
                (self.image_height, self.image_width), self.object_shape, transform_constraints)

            # Centres are mapped from [0, 1] to [-1, 1] via 2 * c - 1 before warping.
            _boxes = tf.concat([xs, 2 * xt - 1, ys, 2 * yt - 1], axis=-1)

            grid_coords = warper(_boxes)
            grid_coords = tf.reshape(grid_coords, (self.batch_size, 1, *self.object_shape, 2,))
            glimpse = resampler_edge.resampler_edge(inp, grid_coords)
            glimpse = tf.reshape(glimpse, (self.batch_size, *self.object_shape, self.image_depth))
        else:
            glimpse = tf.zeros((self.batch_size, *self.object_shape, self.image_depth))

        # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
        encoded_glimpse = self.object_encoder(glimpse, (1, 1, self.A), self.is_training)
        encoded_glimpse = tf.reshape(encoded_glimpse, (self.batch_size, self.A))

        if not is_posterior:
            encoded_glimpse = tf.zeros_like(encoded_glimpse)

        layer_inp = tf.concat(
            [base_features, features, encoded_glimpse, partial_program], axis=1)
        network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
        attr_mean, attr_log_std, features = tf.split(
            network_output, (self.A, self.A, n_features), axis=1)

        attr_std = self.std_nonlinearity(attr_log_std)

        attr_dist = Normal(loc=attr_mean, scale=attr_std)
        attr = attr_dist.sample()

        if "attr" in self.no_gradient:
            attr = tf.stop_gradient(attr)

        built = dict(attr_dist=attr_dist, attr=attr, glimpse=glimpse)

        _record(built, h, w, b)
        partial_program = tf.concat([partial_program, built['attr']], axis=1)

        # --- z ---

        layer_inp = tf.concat([base_features, features, partial_program], axis=1)
        n_features = self.n_passthrough_features

        network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
        z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=1)
        z_std = self.std_nonlinearity(z_log_std)

        # Blend between a gradient-blocked and a regular copy of the
        # parameters; with training_wheels == 1 no gradient reaches them.
        z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1 - self.training_wheels) * z_mean
        z_std = self.training_wheels * tf.stop_gradient(z_std) + (1 - self.training_wheels) * z_std

        z_logit_dist = Normal(loc=z_mean, scale=z_std)
        z_logits = z_logit_dist.sample()
        z = self.z_nonlinearity(z_logits)

        if "z" in self.no_gradient:
            z = tf.stop_gradient(z)

        if "z" in self.fixed_values:
            z = self.fixed_values['z'] * tf.ones_like(z)

        built = dict(z_logit_dist=z_logit_dist, z=z)

        _record(built, h, w, b)
        partial_program = tf.concat([partial_program, built['z']], axis=1)

        # --- obj ---

        layer_inp = tf.concat([base_features, features, partial_program], axis=1)
        rep_input = self.obj_network(layer_inp, 1, self.is_training)

        built = self._build_obj(rep_input, self.is_training)

        _record(built, h, w, b)
        partial_program = tf.concat([partial_program, built['obj']], axis=1)

        # --- final ---

        program[h, w, b] = partial_program
        assert program[h, w, b].shape[1] == total_sample_size

    # --- merge tensors from different grid cells ---

    objects = AttrDict()
    for k, v in _tensors.items():
        if k.endswith('_dist'):
            # Rebuild one distribution of the same class as the per-cell
            # ones, replacing each tensor-valued parameter by its stacked
            # version and keeping the non-tensor parameters of cell (0, 0, 0).
            dist = v[0, 0, 0]
            dist_class = type(dist)
            params = dist.parameters.copy()
            tensor_keys = sorted(
                key for key, t in params.items() if isinstance(t, tf.Tensor))

            tensor_params = {}
            for key in tensor_keys:
                tensor_params[key] = _stack_grid(v, param_key=key)

            params.update(tensor_params)
            objects[k] = dist_class(**params)
        else:
            objects[k] = _stack_grid(v)

    objects.all = tf.concat(
        [objects.box, objects.attr, objects.z, objects.obj], axis=-1)

    # --- misc ---

    objects.n_objects = tf.fill((self.batch_size,), self.HWB)
    objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2, 3, 4))
    objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2, 3, 4))

    return objects
def build_background(self):
    """Build the background tensor and store it in ``self._tensors["background"]``.

    The behaviour is selected by ``cfg.background_cfg.mode``:

    * ``"colour"``: a constant colour (``cfg.background_cfg.colour``)
      broadcast to the shape of ``self.inp``.
    * ``"learn_solid"``: a single learned colour (three logits passed
      through a sigmoid) broadcast to the shape of ``self.inp``.
    * ``"learn"``: the first frame is encoded into a latent, which is
      decoded into one colour per batch element and tiled over time and
      space.
    * ``"learn_and_transform"``: a background image is decoded from a
      latent, and a per-frame affine window (no shear) into that image is
      resampled to the input image size.
    * ``"data"``: the tensor stored under ``"ground_truth_background"``.
    * ``"scalor"``: not implemented here.

    The result is truncated to the first ``self.dynamic_n_frames`` entries
    along axis 1 before being stored. The learned modes additionally store
    their latents and KL terms in ``self._tensors``.

    Raises:
        NotImplementedError: If the mode is ``"scalor"``.
        ValueError: If the mode is not recognized.
    """
    mode = cfg.background_cfg.mode

    if mode == "colour":
        rgb = np.array(to_rgb(cfg.background_cfg.colour))[None, None, None, :]
        background = rgb * tf.ones_like(self.inp)

    elif mode == "learn_solid":
        # Learn a solid colour for the background
        self.solid_background_logits = tf.get_variable("solid_background", initializer=[0.0, 0.0, 0.0])
        if "background" in self.fixed_weights:
            tf.add_to_collection(FIXED_COLLECTION, self.solid_background_logits)
        solid_background = tf.nn.sigmoid(10 * self.solid_background_logits)
        background = solid_background[None, None, None, :] * tf.ones_like(self.inp)

    elif mode == "scalor":
        # This branch used to be a bare `pass`, which left `background`
        # unbound and made the final assignment fail with an opaque
        # UnboundLocalError. Fail with an explicit message instead.
        raise NotImplementedError(
            "Background mode 'scalor' is not implemented by build_background.")

    elif mode == "learn":
        self.maybe_build_subnet("background_encoder")
        self.maybe_build_subnet("background_decoder")

        # Here I'm just encoding the first frame...
        bg_attr = self.background_encoder(self.inp[:, 0], 2 * cfg.background_cfg.A, self.is_training)
        bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
        bg_attr_std = tf.exp(bg_attr_log_std)
        if not self.noisy:
            bg_attr_std = tf.zeros_like(bg_attr_std)

        bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

        self._tensors.update(
            bg_attr_mean=bg_attr_mean,
            bg_attr_std=bg_attr_std,
            bg_attr_kl=bg_attr_kl,
            bg_attr=bg_attr)

        # --- decode ---

        _, T, H, W, _ = tf_shape(self.inp)

        background = self.background_decoder(bg_attr, 3, self.is_training)
        assert len(background.shape) == 2 and background.shape[1] == 3
        background = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))
        background = tf.tile(background[:, None, None, None, :], (1, T, H, W, 1))

    elif mode == "learn_and_transform":
        self.maybe_build_subnet("background_encoder")
        self.maybe_build_subnet("background_decoder")

        # --- encode ---

        n_transform_latents = 4
        n_latents = (2 * cfg.background_cfg.A, 2 * n_transform_latents)
        bg_attr, bg_transform_params = self.background_encoder(self.inp, n_latents, self.is_training)

        # --- bg attributes ---

        bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
        bg_attr_std = self.std_nonlinearity(bg_attr_log_std)

        bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

        # --- bg location ---

        bg_transform_params = tf.reshape(
            bg_transform_params,
            (self.batch_size, self.dynamic_n_frames, 2*n_transform_latents))

        mean, log_std = tf.split(bg_transform_params, 2, axis=2)
        std = self.std_nonlinearity(log_std)

        logits, kl = normal_vae(mean, std, 0.0, 1.0)

        # integrate across timesteps
        logits = tf.cumsum(logits, axis=1)
        logits = tf.reshape(logits, (self.batch_size*self.dynamic_n_frames, n_transform_latents))

        y, x, h, w = tf.split(logits, n_transform_latents, axis=1)
        # Window scale is squashed into (0.5, 0.9); the offset is then
        # bounded in magnitude by (1 - scale).
        h = (0.9 - 0.5) * tf.nn.sigmoid(h) + 0.5
        w = (0.9 - 0.5) * tf.nn.sigmoid(w) + 0.5
        y = (1 - h) * tf.nn.tanh(y)
        x = (1 - w) * tf.nn.tanh(x)

        # --- decode ---

        background = self.background_decoder(bg_attr, self.image_depth, self.is_training)
        bg_shape = cfg.background_cfg.bg_shape
        background = background[:, :bg_shape[0], :bg_shape[1], :]
        assert background.shape[1:3] == bg_shape
        background_raw = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))

        transform_constraints = snt.AffineWarpConstraints.no_shear_2d()

        warper = snt.AffineGridWarper(
            bg_shape, (self.image_height, self.image_width), transform_constraints)

        transforms = tf.concat([w, x, h, y], axis=-1)
        grid_coords = warper(transforms)

        grid_coords = tf.reshape(
            grid_coords,
            (self.batch_size, self.dynamic_n_frames, *tf_shape(grid_coords)[1:]))

        background = tf.contrib.resampler.resampler(background_raw, grid_coords)

        self._tensors.update(
            bg_attr_mean=bg_attr_mean,
            bg_attr_std=bg_attr_std,
            bg_attr_kl=bg_attr_kl,
            bg_attr=bg_attr,
            bg_y=tf.reshape(y, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_x=tf.reshape(x, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_h=tf.reshape(h, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_w=tf.reshape(w, (self.batch_size, self.dynamic_n_frames, 1)),
            bg_transform_kl=kl,
            bg_raw=background_raw,
        )

    elif mode == "data":
        background = self._tensors["ground_truth_background"]

    else:
        # ValueError is a subclass of Exception, so existing
        # `except Exception` handlers keep working.
        raise ValueError("Unrecognized background mode: {}.".format(cfg.background_cfg.mode))

    self._tensors["background"] = background[:, :self.dynamic_n_frames]