Example #1
    def testShapeInferenceAndChecks(self):
        output_shape2d = (2, 3)
        source_shape2d = (6, 9)
        constraints = scale_2d(y=1) & translation_2d(x=-2, y=7)
        agw2d = snt.AffineGridWarper(source_shape=source_shape2d,
                                     output_shape=output_shape2d,
                                     constraints=constraints)

        input_params2d = tf.placeholder(tf.float32,
                                        [None, constraints.num_free_params])
        warped_grid2d = agw2d(input_params2d)
        self.assertEqual(warped_grid2d.get_shape().as_list()[1:], [2, 3, 2])

        output_shape2d = (2, 3)
        source_shape3d = (100, 200, 50)
        agw3d = snt.AffineGridWarper(source_shape=source_shape3d,
                                     output_shape=output_shape2d,
                                     constraints=[[None, 0, None, None],
                                                  [0, 1, 0, None],
                                                  [0, None, 0, None]])

        input_params3d = tf.placeholder(
            tf.float32, [None, agw3d.constraints.num_free_params])
        warped_grid3d = agw3d(input_params3d)
        self.assertEqual(warped_grid3d.get_shape().as_list()[1:], [2, 3, 3])

        output_shape3d = (2, 3, 4)
        source_shape3d = (100, 200, 50)
        agw3d = snt.AffineGridWarper(source_shape=source_shape3d,
                                     output_shape=output_shape3d,
                                     constraints=[[None, 0, None, None],
                                                  [0, 1, 0, None],
                                                  [0, None, 0, None]])

        input_params3d = tf.placeholder(
            tf.float32, [None, agw3d.constraints.num_free_params])
        warped_grid3d = agw3d(input_params3d)
        self.assertEqual(warped_grid3d.get_shape().as_list()[1:], [2, 3, 4, 3])

        with self.assertRaisesRegexp(
                snt.Error, "Incompatible set of constraints provided.*"):
            snt.AffineGridWarper(source_shape=source_shape3d,
                                 output_shape=output_shape3d,
                                 constraints=no_constraints(2))

        with self.assertRaisesRegexp(snt.Error,
                                     "Output domain dimensionality.*"):
            snt.AffineGridWarper(source_shape=source_shape2d,
                                 output_shape=output_shape3d,
                                 constraints=no_constraints(2))
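
A note on the constraints used above: a constraint matrix lists the 2x3 (or 3x4) affine matrix row by row, with None marking a free parameter and a number pinning it. A minimal sketch, assuming the dm-sonnet v1 API used throughout these examples:

# Equivalent to snt.AffineWarpConstraints.no_shear_2d(): the two shear
# entries of [[sx, shear_x, tx], [shear_y, sy, ty]] are pinned to 0.
constraints = snt.AffineWarpConstraints([[None, 0, None],
                                         [0, None, None]])
assert constraints.num_dim == 2          # a 2D warp
assert constraints.num_free_params == 4  # one per remaining None entry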
Example #2
def extract_affine_glimpse(image, object_shape, cyt, cxt, ys, xs,
                           edge_resampler):
    """ (cyt, cxt) are rectangle center. (ys, xs) are rectangle height/width """
    _, *image_shape, image_depth = tf_shape(image)
    transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
    warper = snt.AffineGridWarper(image_shape, object_shape,
                                  transform_constraints)

    # change coordinate system
    cyt = 2 * cyt - 1
    cxt = 2 * cxt - 1

    leading_shape = tf_shape(cyt)[:-1]

    _boxes = tf.concat([xs, cxt, ys, cyt], axis=-1)
    _boxes = tf.reshape(_boxes, (-1, 4))

    grid_coords = warper(_boxes)

    grid_coords = tf.reshape(grid_coords, (*leading_shape, *object_shape, 2))

    if edge_resampler:
        glimpses = resampler_edge.resampler_edge(image, grid_coords)
    else:
        glimpses = tf.contrib.resampler.resampler(image, grid_coords)

    glimpses = tf.reshape(glimpses,
                          (*leading_shape, *object_shape, image_depth))

    return glimpses
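
With no_shear_2d() the four free parameters are read row by row from the 2x3 affine matrix, which is why _boxes is packed as (xs, cxt, ys, cyt). A small sanity sketch, under the same API assumption:

# Free entries of [[sx, 0, tx], [0, sy, ty]] in row-major order:
c = snt.AffineWarpConstraints.no_shear_2d()
assert c.num_free_params == 4
# so the warper input is interpreted as (x_scale, x_shift, y_scale, y_shift),
# with shifts expressed in the [-1, 1] grid coordinates used above.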
Example #3
    def __init__(self, img_size, crop_size, constraints=None, inverse=False):
        super(SpatialTransformer, self).__init__(self.__class__.__name__)

        with self._enter_variable_scope():
            self._warper = snt.AffineGridWarper(img_size, crop_size,
                                                constraints)
            if inverse:
                self._warper = self._warper.inverse()
Example #4
    def __init__(self, img_size, crop_size, inverse=False):
        super(SpatialTransformer, self).__init__()

        with self._enter_variable_scope():
            constraints = snt.AffineWarpConstraints.no_shear_2d()
            self._warper = snt.AffineGridWarper(img_size, crop_size,
                                                constraints)
            if inverse:
                self._warper = self._warper.inverse()
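
Examples #3 and #4 only show the constructor; the module's _build is not included. A hypothetical sketch of the forward pass, assuming snt.resampler for the sampling step:

    def _build(self, img, transform_params):
        """Hypothetical sketch: warp `img` according to `transform_params`."""
        grid_coords = self._warper(transform_params)
        return snt.resampler(img, grid_coords)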
Example #5
    def testInvSameAsNumPyRef(self, output_shape, source_shape, constraints):
        def chain(x):
            return itertools.chain(*x)

        def predict(output_shape, source_shape, inputs):
            ranges = [
                np.linspace(-1, 1, x, dtype=np.float32)
                for x in reversed(source_shape)
            ]
            n = len(output_shape)
            grid = np.meshgrid(*ranges, indexing="xy")
            for _ in range(len(source_shape), len(output_shape)):
                grid.append(np.zeros_like(grid[0]))
            grid.append(np.ones_like(grid[0]))
            grid = np.array([x.reshape(1, -1) for x in grid]).squeeze()
            predicted_output = []
            for i in range(0, batch_size):
                affine_matrix = inputs[i, :].reshape(n, n + 1)
                inv_matrix = np.linalg.inv(affine_matrix[:2, :2])
                inv_transform = np.concatenate([
                    inv_matrix,
                    -np.dot(inv_matrix, affine_matrix[:, 2].reshape(2, 1))
                ], 1)
                x = np.dot(inv_transform, grid)
                for k, s in enumerate(reversed(output_shape)):
                    s = (s - 1) * 0.5
                    x[k, :] = x[k, :] * s + s
                x = np.concatenate([v.reshape(v.shape + (1, )) for v in x], -1)
                predicted_output.append(x.reshape(tuple(source_shape) + (n, )))
            return predicted_output

        batch_size = 20
        agw = snt.AffineGridWarper(source_shape=source_shape,
                                   output_shape=output_shape,
                                   constraints=constraints).inverse()
        inputs = tf.placeholder(tf.float32,
                                [None, constraints.num_free_params])
        warped_grid = agw(inputs)
        full_size = constraints.num_dim * (constraints.num_dim + 1)
        # Add a small positive offset to avoid singular affine matrices.
        full_input_np = np.random.rand(batch_size, full_size) + 0.1

        con_i = [i for i, x in enumerate(chain(constraints.mask)) if not x]
        con_val = [x for x in chain(constraints.constraints) if x is not None]
        for i, v in zip(con_i, con_val):
            full_input_np[:, i] = v
        uncon_i = [i for i, x in enumerate(chain(constraints.mask)) if x]
        with self.test_session() as sess:
            output = sess.run(warped_grid,
                              feed_dict={inputs: full_input_np[:, uncon_i]})

        self.assertAllClose(output,
                            predict(output_shape, source_shape, full_input_np),
                            rtol=1e-05,
                            atol=1e-05)
Example #6
def spatial_transformer(img_tensor, transform_params, crop_size):
    """
    :param img_tensor: tf.Tensor of size (batch_size, Height, Width, channels)
    :param transform_params: tf.Tensor of size (batch_size, 4), where params are  (scale_y, shift_y, scale_x, shift_x)
    :param crop_size): tuple of 2 ints, size of the resulting crop
    """
    constraints = snt.AffineWarpConstraints.no_shear_2d()
    img_size = img_tensor.shape.as_list()[1:]
    warper = snt.AffineGridWarper(img_size, crop_size, constraints)
    grid_coords = warper(transform_params)
    glimpse = snt.resampler(img_tensor[..., tf.newaxis], grid_coords)
    return glimpse
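
A hypothetical call, with illustrative shapes (48x48 single-channel input, 24x24 crop):

img = tf.placeholder(tf.float32, [None, 48, 48])
params = tf.placeholder(tf.float32, [None, 4])  # (scale_x, shift_x, scale_y, shift_y)
crop = spatial_transformer(img, params, crop_size=(24, 24))  # -> (batch, 24, 24, 1)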
Example #7
    def __init__(self, img_size, crop_size, inverse=False):
        """Initialises the module.

        :param img_size: Tuple of ints, size of the input image.
        :param crop_size: Tuple of ints, size of the resampled image.
        :param inverse: Boolean; inverts the given transformation if True and then maps crops into full-sized images.
        """

        super(SpatialTransformer, self).__init__()

        with self._enter_variable_scope():
            constraints = snt.AffineWarpConstraints.no_shear_2d()
            self._warper = snt.AffineGridWarper(img_size[:2], crop_size, constraints)
            if inverse:
                self._warper = self._warper.inverse()
Example #8
  def testIdentity(self):
    constraints = snt.AffineWarpConstraints.no_constraints()
    warper = snt.AffineGridWarper([3, 3], [3, 3], constraints=constraints)
    p = tf.placeholder(tf.float64, (None, constraints.num_free_params))
    grid = warper(p)
    with self.test_session() as sess:
      warp_p = np.array([1, 0, 0,
                         0, 1, 0]).reshape([1, constraints.num_free_params])
      output = sess.run(grid, feed_dict={p: warp_p})

    # Check that output matches expected result for a known transformation.
    self.assertAllClose(output,
                        np.array([[[[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]],
                                   [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0]],
                                   [[0.0, 2.0], [1.0, 2.0], [2.0, 2.0]]]]))
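
The expected grid is just the source pixel coordinates in (x, y) order; the same reference can be built with numpy:

import numpy as np
xx, yy = np.meshgrid(np.arange(3.0), np.arange(3.0))  # indexing="xy"
expected = np.stack([xx, yy], axis=-1)[None]          # (1, 3, 3, 2), x first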
Example #9
image_corners = np.concatenate(
    [corners, np.ones((1, corners.shape[1]))], axis=0)

left = image_corners[0, 0]
top = image_corners[1, 0]
right = image_corners[0, 1]
bottom = image_corners[1, 1]
width = right - left
height = bottom - top

boxes = boxes[:, [0, 2, 4, 5]]
boxes = np.tile(boxes, (n_examples, 1))
boxes = tf.constant(boxes, tf.float32)

transform_constraints = snt.AffineWarpConstraints.no_shear_2d()

warper = snt.AffineGridWarper(image_shape[:2], crop_shape[:2],
                              transform_constraints)
grid_coords = warper(boxes)
output = tf.contrib.resampler.resampler(images, grid_coords)

sess = tf.Session()
crops, _grid_coords = sess.run([output, grid_coords])
print(_grid_coords)

# import matplotlib
# matplotlib.use('pdf')
import matplotlib.pyplot as plt
import matplotlib.patches as patches

fig, axes = plt.subplots(n_examples, 2)

for image, crop, ax in zip(images, crops, axes):
Example #10
    def _build(self,
               pose,
               presence=None,
               template_feature=None,
               bg_image=None,
               img_embedding=None):
        """Builds the module.

    Args:
      pose: [B, n_templates, 6] tensor.
      presence: [B, n_templates] tensor.
      template_feature: [B, n_templates, n_features] tensor; these features are
        used to change templates based on the input, if present.
      bg_image: [B, *output_size] tensor representing the background.
      img_embedding: [B, d] tensor containing image embeddings.

    Returns:
      [B, n_templates, *output_size, n_channels] tensor.
    """
        batch_size, n_templates = pose.shape[:2].as_list()
        templates = self.make_templates(n_templates, template_feature)

        if templates.shape[0] == 1:
            templates = snt.TileByDim([0], [batch_size])(templates)

        # it's easier for me to think in inverse coordinates
        warper = snt.AffineGridWarper(self._output_size, self._template_size)
        warper = warper.inverse()

        grid_coords = snt.BatchApply(warper)(pose)
        resampler = snt.BatchApply(contrib_resampler.resampler)
        transformed_templates = resampler(templates, grid_coords)

        if bg_image is not None:
            bg_image = tf.expand_dims(bg_image, axis=1)
        else:
            bg_image = tf.nn.sigmoid(tf.get_variable('bg_value', shape=[1]))
            bg_image = tf.zeros_like(transformed_templates[:, :1]) + bg_image

        transformed_templates = tf.concat([transformed_templates, bg_image],
                                          axis=1)

        if presence is not None:
            presence = tf.concat([presence, tf.ones([batch_size, 1])], axis=1)

        if True:  # pylint: disable=using-constant-test

            if self._use_alpha_channel:
                template_mixing_logits = snt.TileByDim([0], [batch_size])(
                    self._templates_alpha)
                template_mixing_logits = resampler(template_mixing_logits,
                                                   grid_coords)

                bg_mixing_logit = tf.nn.softplus(
                    tf.get_variable('bg_mixing_logit', initializer=[0.]))

                bg_mixing_logit = (
                    tf.zeros_like(template_mixing_logits[:, :1]) +
                    bg_mixing_logit)

                template_mixing_logits = tf.concat(
                    [template_mixing_logits, bg_mixing_logit], 1)

            else:
                temperature_logit = tf.get_variable('temperature_logit',
                                                    shape=[1])
                temperature = tf.nn.softplus(temperature_logit + .5) + 1e-4
                template_mixing_logits = transformed_templates / temperature

        scale = 1.
        if self._learn_output_scale:
            scale = tf.get_variable('scale', shape=[1])
            scale = tf.nn.softplus(scale) + 1e-4

        if self._output_pdf_type == 'mixture':
            template_mixing_logits += make_brodcastable(
                math_ops.safe_log(presence), template_mixing_logits)

            rec_pdf = prob.MixtureDistribution(template_mixing_logits,
                                               [transformed_templates, scale],
                                               tfd.Normal)

        else:
            raise ValueError('Unknown pdf type: "{}".'.format(
                self._output_pdf_type))

        return AttrDict(raw_templates=tf.squeeze(self._templates, 0),
                        transformed_templates=transformed_templates[:, :-1],
                        mixing_logits=template_mixing_logits[:, :-1],
                        pdf=rec_pdf)
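
warper.inverse() returns a warper for the inverse transformation (its grid ranges over the original source_shape), and snt.BatchApply folds the leading [B, n_templates] dims into one batch dim before calling it. A shape sketch with illustrative sizes (40x40 output, 11x11 templates, 4 templates):

warper = snt.AffineGridWarper([40, 40], [11, 11]).inverse()
pose = tf.placeholder(tf.float32, [None, 4, 6])  # full 2x3 affine per template
grid = snt.BatchApply(warper)(pose)              # -> [B, 4, 40, 40, 2]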
Example #11
    def _build_program_interpreter(self, tensors):
        # --- Get object attributes using object encoder ---

        max_objects = tensors["max_objects"]

        yt, xt, ys, xs = tf.split(tensors["normalized_box"], 4, axis=-1)

        transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
        warper = snt.AffineGridWarper((self.image_height, self.image_width),
                                      self.object_shape, transform_constraints)

        _boxes = tf.concat(
            [xs, 2 * (xt + xs / 2) - 1, ys, 2 * (yt + ys / 2) - 1], axis=-1)
        _boxes = tf.reshape(_boxes, (self.batch_size * max_objects, 4))
        grid_coords = warper(_boxes)
        grid_coords = tf.reshape(grid_coords, (
            self.batch_size,
            max_objects,
            *self.object_shape,
            2,
        ))
        glimpse = tf.contrib.resampler.resampler(tensors["inp"], grid_coords)

        object_encoder_in = tf.reshape(glimpse,
                                       (self.batch_size * max_objects,
                                        *self.object_shape, self.image_depth))

        attr = self.object_encoder(object_encoder_in, (1, 1, 2 * self.A),
                                   self.is_training)
        attr = tf.reshape(attr, (self.batch_size, max_objects, 2 * self.A))
        attr_mean, attr_log_std = tf.split(attr, [self.A, self.A], axis=-1)
        attr_std = tf.exp(attr_log_std)

        if not self.noisy:
            attr_std = tf.zeros_like(attr_std)

        attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean,
                                   self.attr_prior_std)

        object_decoder_in = tf.reshape(
            attr, (self.batch_size * max_objects, 1, 1, self.A))

        # --- Compute sprites from attr using object decoder ---

        object_logits = self.object_decoder(
            object_decoder_in, self.object_shape + (self.image_depth, ),
            self.is_training)

        objects = tf.nn.sigmoid(tf.clip_by_value(object_logits, -10., 10.))

        objects = tf.reshape(objects, (
            self.batch_size,
            max_objects,
            *self.object_shape,
            self.image_depth,
        ))
        alpha = tensors["obj"][:, :, :, None, None] * tf.ones_like(
            objects[:, :, :, :, :1])
        importance = tf.ones_like(objects[:, :, :, :, :1])
        objects = tf.concat([objects, alpha, importance], axis=-1)

        # -- Reconstruct image ---

        scales = tf.concat([ys, xs], axis=-1)
        scales = tf.reshape(scales, (self.batch_size, max_objects, 2))

        offsets = tf.concat([yt, xt], axis=-1)
        offsets = tf.reshape(offsets, (self.batch_size, max_objects, 2))

        output = render_sprites.render_sprites(objects, tensors["n_objects"],
                                               scales, offsets,
                                               tensors["background"])

        return dict(output=output,
                    glimpse=tf.reshape(glimpse,
                                       (self.batch_size, max_objects,
                                        *self.object_shape, self.image_depth)),
                    attr=tf.reshape(attr,
                                    (self.batch_size, max_objects, self.A)),
                    attr_kl=tf.reshape(attr_kl,
                                       (self.batch_size, max_objects, self.A)),
                    objects=tf.reshape(objects, (
                        self.batch_size,
                        max_objects,
                        *self.object_shape,
                        self.image_depth,
                    )))
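
The packing [xs, 2 * (xt + xs / 2) - 1, ys, 2 * (yt + ys / 2) - 1] converts a top-left anchored box in [0, 1] image coordinates into the warper's scale/shift parameters on the [-1, 1] grid; a numpy sanity sketch under that reading:

import numpy as np
xt, xs = 0.25, 0.5               # left edge and width in [0, 1]
x_shift = 2 * (xt + xs / 2) - 1  # box center mapped to [-1, 1]
assert np.isclose(x_shift, 0.0)  # a centered box gets zero shift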
Example #12
    def build_representation(self):
        # --- init modules ---

        if self.backbone is None:
            self.backbone = self.build_backbone(scope="backbone")

        if self.cell is None:
            self.cell = cfg.build_cell(self.n_hidden, name="cell")

            # self.cell must be a Sonnet RNNCore

            if self.learn_initial_state:
                self.initial_hidden_state = snt.trainable_initial_state(
                    1,
                    self.cell.state_size,
                    tf.float32,
                    name="initial_hidden_state")

        if self.key_network is None:
            self.key_network = cfg.build_key_network(scope="key_network")
        if self.beta_network is None:
            self.beta_network = cfg.build_beta_network(scope="beta_network")
        if self.write_network is None:
            self.write_network = cfg.build_write_network(scope="write_network")
        if self.erase_network is None:
            self.erase_network = cfg.build_erase_network(scope="erase_network")

        if self.output_network is None:
            self.output_network = cfg.build_output_network(
                scope="output_network")

        d_n_frames, n_trackers, batch_size = self.dynamic_n_frames, self.n_trackers, self.batch_size

        # --- encode ---

        video = tf.reshape(self.inp,
                           (batch_size * d_n_frames, *self.obs_shape[1:]))

        zero_to_one_Y = tf.linspace(0., 1., self.image_height)
        zero_to_one_X = tf.linspace(0., 1., self.image_width)
        X, Y = tf.meshgrid(zero_to_one_X, zero_to_one_Y)
        X = tf.tile(X[None, :, :, None], (batch_size * d_n_frames, 1, 1, 1))
        Y = tf.tile(Y[None, :, :, None], (batch_size * d_n_frames, 1, 1, 1))
        video = tf.concat([video, Y, X], axis=-1)

        encoded_frames, _, _ = self.backbone(video, self.S, self.is_training)

        _, H, W, _ = tf_shape(encoded_frames)
        self.H = H
        self.W = W
        encoded_frames = tf.reshape(encoded_frames,
                                    (batch_size, d_n_frames, H * W, self.S))
        self.encoded_frames = encoded_frames

        cts = tf.minimum(
            1., self.float_is_training + (1 - float(self.discrete_eval)))
        self.mask_temperature = cts * 1.0 + (1 - cts) * 1e-5
        self.layer_temperature = cts * 1.0 + (1 - cts) * 1e-5

        f = tf.constant(0, dtype=tf.int32)

        if self.learn_initial_state:
            hidden_state = self.initial_hidden_state[None, ...]
        else:
            hidden_state = self.cell.zero_state(1, tf.float32)[None, ...]

        hidden_states = tf.tile(hidden_state, (
            self.n_trackers,
            self.batch_size,
        ) + (1, ) * (len(hidden_state.shape) - 2))

        conf = tf.zeros((n_trackers, batch_size, 1))

        structure = dict(
            hidden_states=hidden_states,
            tracker_output=tf.zeros(
                (self.n_trackers, self.batch_size, self.n_hidden)),
            attention_result=tf.zeros(
                (self.n_trackers, self.batch_size, self.S)),
            attention_weights=tf.zeros(
                (self.n_trackers, self.batch_size, H * W)),
            memory_activation=tf.zeros(
                (self.n_trackers, self.batch_size, H * W)),
            conf=conf,
            layer=tf.zeros((self.n_trackers, self.batch_size, self.n_layers)),
            pose=tf.zeros((self.n_trackers, self.batch_size, 4)),
            mask=tf.zeros((self.n_trackers, self.batch_size,
                           np.prod(self.object_shape))),
            appearance=tf.zeros((self.n_trackers, self.batch_size,
                                 3 * np.prod(self.object_shape))),
            order=tf.zeros((self.n_trackers, self.batch_size, 1),
                           dtype=tf.int32),
        )
        tensor_arrays = make_tensor_arrays(structure, self.dynamic_n_frames)

        loop_vars = [f, conf, hidden_states, *tensor_arrays]

        result = tf.while_loop(self._loop_cond, self._loop_body, loop_vars)

        first_ta_idx = min(i for i, ta in enumerate(result)
                           if isinstance(ta, tf.TensorArray))
        tensor_arrays = result[first_ta_idx:]

        def finalize_ta(ta):
            t = ta.stack()
            # reshape from (n_frames, n_trackers, batch_size, *other) to (batch_size, n_frames, n_trackers, *other)
            return tf.transpose(t, (2, 0, 1, *range(3, len(t.shape))))

        tensors = map_structure(
            finalize_ta,
            tensor_arrays,
            is_leaf=lambda t: isinstance(t, tf.TensorArray))
        tensors = apply_keys(structure, tensors)

        self._tensors.update(tensors)

        pprint.pprint(self._tensors)

        # --- render/decode ---

        ys, xs, yt, xt = tf.split(tensors["pose"], 4, axis=-1)
        self.record_tensors(
            conf=tensors["conf"],
            ys=ys,
            xs=xs,
            yt=yt,
            xt=xt,
        )

        yt_normed = (yt + 1) / 2
        xt_normed = (xt + 1) / 2
        ys_normed = 1 + self.eta[0] * ys
        xs_normed = 1 + self.eta[1] * xs

        normalized_box = tf.concat(
            [yt_normed, xt_normed, ys_normed, xs_normed], axis=-1)

        # expose values for plotting

        self._tensors.update(
            normalized_box=normalized_box,
            mask=tf.reshape(
                tensors["mask"],
                (batch_size, d_n_frames, n_trackers, *self.object_shape)),
            appearance=tf.reshape(tensors["appearance"],
                                  (batch_size, d_n_frames, n_trackers,
                                   *self.object_shape, self.image_depth)),
            conf=tensors["conf"],
            layer=tensors["layer"],
            order=tensors["order"],
        )

        # --- reshape values ---

        N = batch_size * d_n_frames * n_trackers
        ys_normed = tf.reshape(ys_normed, (N, 1))
        xs_normed = tf.reshape(xs_normed, (N, 1))
        yt_normed = tf.reshape(yt_normed, (N, 1))
        xt_normed = tf.reshape(xt_normed, (N, 1))

        _yt, _xt, _ys, _xs = tba_coords_to_image_space(
            yt_normed,
            xt_normed,
            ys_normed,
            xs_normed, (self.image_height, self.image_width),
            self.anchor_box,
            top_left=False)

        mask = tf.reshape(tensors["mask"], (N, *self.object_shape, 1))
        appearance = tf.reshape(tensors["appearance"],
                                (N, *self.object_shape, self.image_depth))

        transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
        warper = snt.AffineGridWarper((self.image_height, self.image_width),
                                      self.object_shape, transform_constraints)
        inverse_warper = warper.inverse()
        transforms = tf.concat([_xs, _xt, _ys, _yt], axis=-1)
        grid_coords = inverse_warper(transforms)

        transformed_masks = tf.contrib.resampler.resampler(mask, grid_coords)
        transformed_masks = tf.reshape(
            transformed_masks, (batch_size, d_n_frames, n_trackers,
                                self.image_height, self.image_width, 1))

        transformed_appearances = tf.contrib.resampler.resampler(
            appearance, grid_coords)
        transformed_appearances = tf.reshape(
            transformed_appearances,
            (batch_size, d_n_frames, n_trackers, self.image_height,
             self.image_width, self.image_depth))

        layer_masks = []
        layer_appearances = []

        conf = tensors["conf"][:, :, :, :, None, None]

        # TODO: currently assuming a black background

        final_frames = tf.zeros((batch_size, d_n_frames, self.image_height,
                                 self.image_width, self.image_depth))

        # For each layer, create a mask image and an appearance image
        for layer_idx in range(self.n_layers):
            layer_weight = tensors["layer"][:, :, :, layer_idx, None, None,
                                            None]

            # (batch_size, n_frames, self.image_height, self.image_width, 1)
            layer_mask = tf.reduce_sum(conf * layer_weight * transformed_masks,
                                       axis=2)
            layer_mask = tf.minimum(1.0, layer_mask)

            # (batch_size, n_frames, self.image_height, self.image_width, 3)
            layer_appearance = tf.reduce_sum(conf * layer_weight *
                                             transformed_masks *
                                             transformed_appearances,
                                             axis=2)

            if self.clamp_appearance:
                layer_appearance = tf.minimum(1.0, layer_appearance)

            final_frames = (1 - layer_mask) * final_frames + layer_appearance

            layer_masks.append(layer_mask)
            layer_appearances.append(layer_appearance)

        self._tensors["output"] = final_frames

        # --- losses ---

        self._tensors['per_pixel_reconstruction_loss'] = (self.inp -
                                                          final_frames)**2

        self.losses['reconstruction'] = (
            tf.reduce_sum(self._tensors["per_pixel_reconstruction_loss"]) /
            tf.cast(d_n_frames * self.batch_size, tf.float32))
        # self.losses['reconstruction'] = tf.reduce_mean(self._tensors["per_pixel_reconstruction_loss"])

        self.losses['area'] = self.lmbda * tf.reduce_mean(
            ys_normed * xs_normed)
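
The per-layer update final_frames = (1 - layer_mask) * final_frames + layer_appearance is standard over-compositing with premultiplied alpha (layer_appearance already includes the mask factor), later layers drawn over earlier ones. A scalar numpy sketch:

import numpy as np
frame = np.zeros(3)  # assumed black background, as in the code above
layers = [(0.8, 0.8 * np.array([1., 0., 0.])),  # (mask, premultiplied colour)
          (0.5, 0.5 * np.array([0., 1., 0.]))]
for mask, appearance in layers:
    frame = (1 - mask) * frame + appearance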
Example #13
    def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
        print("\n" + "-" * 10 + " ConvGridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network", builder=cfg.build_conv_lateral, key="box")
        self.maybe_build_subnet("attr_network", builder=cfg.build_conv_lateral, key="attr")
        self.maybe_build_subnet("z_network", builder=cfg.build_conv_lateral, key="z")
        self.maybe_build_subnet("obj_network", builder=cfg.build_conv_lateral, key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _, n_channels = tf_shape(inp_features)

        if self.B != 1:
            raise NotImplementedError()

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
            self.H = H
            self.W = W
            self.HWB = H*W
            self.batch_size = tf.shape(inp)[0]
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        inp_features = tf.reshape(inp_features, (self.batch_size, H, W, n_channels))
        is_posterior_tf = tf.ones_like(inp_features[..., :2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        objects = AttrDict()

        base_features = tf.concat([inp_features, is_posterior_tf], axis=-1)

        # --- box ---

        layer_inp = base_features
        n_features = self.n_passthrough_features
        output_size = 8

        network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
        rep_input, features = tf.split(network_output, (output_size, n_features), axis=-1)

        _objects = self._build_box(rep_input, self.is_training)
        objects.update(_objects)

        # --- attr ---

        if is_posterior:
            # --- Get object attributes using object encoder ---

            yt, xt, ys, xs = tf.split(objects['normalized_box'], 4, axis=-1)

            yt, xt, ys, xs = coords_to_image_space(
                yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
            warper = snt.AffineGridWarper(
                (self.image_height, self.image_width), self.object_shape, transform_constraints)

            _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)
            _boxes = tf.reshape(_boxes, (self.batch_size*H*W, 4))
            grid_coords = warper(_boxes)
            grid_coords = tf.reshape(grid_coords, (self.batch_size, H, W, *self.object_shape, 2,))

            if self.edge_resampler:
                glimpse = resampler_edge.resampler_edge(inp, grid_coords)
            else:
                glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
        else:
            glimpse = tf.zeros((self.batch_size, H, W, *self.object_shape, self.image_depth))

        # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
        encoded_glimpse = apply_object_wise(
            self.object_encoder, glimpse, n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros_like(encoded_glimpse)

        layer_inp = tf.concat([base_features, features, encoded_glimpse, objects['local_box']], axis=-1)
        network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
        attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=-1)

        attr_std = self.std_nonlinearity(attr_log_std)

        attr = Normal(loc=attr_mean, scale=attr_std).sample()

        objects.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)

        # --- z ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr']], axis=-1)
        n_features = self.n_passthrough_features

        network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
        z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=-1)
        z_std = self.std_nonlinearity(z_log_std)

        z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
        z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
        z_logit = Normal(loc=z_mean, scale=z_std).sample()
        z = self.z_nonlinearity(z_logit)

        objects.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)

        # --- obj ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr'], objects['z']], axis=-1)
        rep_input = self.obj_network(layer_inp, 1, self.is_training)

        _objects = self._build_obj(rep_input, self.is_training)
        objects.update(_objects)

        # --- final ---

        _objects = AttrDict()
        for k, v in objects.items():
            _, _, _, *trailing_dims = tf_shape(v)
            _objects[k] = tf.reshape(v, (self.batch_size, self.HWB, *trailing_dims))
        objects = _objects

        if prop_state is not None:
            objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
            objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

        # --- misc ---

        objects.n_objects = tf.fill((self.batch_size,), self.HWB)
        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

        return objects
Example #14
    def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
        print("\n" + "-" * 10 + " GridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network", builder=cfg.build_lateral, key="box")
        self.maybe_build_subnet("attr_network", builder=cfg.build_lateral, key="attr")
        self.maybe_build_subnet("z_network", builder=cfg.build_lateral, key="z")
        self.maybe_build_subnet("obj_network", builder=cfg.build_lateral, key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _, _ = tf_shape(inp_features)
        H = int(H)
        W = int(W)

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
            self.H = H
            self.W = W
            self.HWB = H*W*self.B
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        # --- set up the edge element ---

        sizes = [4, self.A, 1, 1]
        sigmoids = [True, False, False, True]
        total_sample_size = sum(sizes)

        if self.edge_weights is None:
            self.edge_weights = tf.get_variable("edge_weights", shape=total_sample_size, dtype=tf.float32)
            if "edge" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.edge_weights)

        _edge_weights = tf.split(self.edge_weights, sizes, axis=0)
        _edge_weights = [
            (tf.nn.sigmoid(ew) if sigmoid else ew)
            for ew, sigmoid in zip(_edge_weights, sigmoids)]
        edge_element = tf.concat(_edge_weights, axis=0)
        edge_element = tf.tile(edge_element[None, :], (self.batch_size, 1))

        # --- containers for storing built program ---

        program = np.empty((H, W, self.B), dtype=np.object)

        # --- build the program ---

        is_posterior_tf = tf.ones((self.batch_size, 2))
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        results = []
        for h, w, b in itertools.product(range(H), range(W), range(self.B)):
            built = dict()

            partial_program, features = None, None
            context = self._get_sequential_context(program, h, w, b, edge_element)
            base_features = tf.concat([inp_features[:, h, w, b, :], context, is_posterior_tf], axis=1)

            # --- box ---

            layer_inp = base_features
            n_features = self.n_passthrough_features
            output_size = 8

            network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
            rep_input, features = tf.split(network_output, (output_size, n_features), axis=1)

            _built = self._build_box(rep_input, self.is_training, hw=(h, w))
            built.update(_built)
            partial_program = built['local_box']

            # --- attr ---

            if is_posterior:
                # --- Get object attributes using object encoder ---

                yt, xt, ys, xs = tf.split(built['normalized_box'], 4, axis=-1)

                yt, xt, ys, xs = coords_to_image_space(
                    yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

                transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
                warper = snt.AffineGridWarper(
                    (self.image_height, self.image_width), self.object_shape, transform_constraints)

                _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)

                grid_coords = warper(_boxes)
                grid_coords = tf.reshape(grid_coords, (self.batch_size, 1, *self.object_shape, 2,))
                if self.edge_resampler:
                    glimpse = resampler_edge.resampler_edge(inp, grid_coords)
                else:
                    glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
                glimpse = tf.reshape(glimpse, (self.batch_size, *self.object_shape, self.image_depth))
            else:
                glimpse = tf.zeros((self.batch_size, *self.object_shape, self.image_depth))

            # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
            encoded_glimpse = self.object_encoder(glimpse, (1, 1, self.A), self.is_training)
            encoded_glimpse = tf.reshape(encoded_glimpse, (self.batch_size, self.A))

            if not is_posterior:
                encoded_glimpse = tf.zeros_like(encoded_glimpse)

            layer_inp = tf.concat(
                [base_features, features, encoded_glimpse, partial_program], axis=1)
            network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
            attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=1)

            attr_std = self.std_nonlinearity(attr_log_std)

            attr = Normal(loc=attr_mean, scale=attr_std).sample()

            built.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)
            partial_program = tf.concat([partial_program, built['attr']], axis=1)

            # --- z ---

            layer_inp = tf.concat([base_features, features, partial_program], axis=1)
            n_features = self.n_passthrough_features

            network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
            z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=1)
            z_std = self.std_nonlinearity(z_log_std)

            z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
            z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
            z_logit = Normal(loc=z_mean, scale=z_std).sample()
            z = self.z_nonlinearity(z_logit)

            built.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)
            partial_program = tf.concat([partial_program, built['z']], axis=1)

            # --- obj ---

            layer_inp = tf.concat([base_features, features, partial_program], axis=1)
            rep_input = self.obj_network(layer_inp, 1, self.is_training)

            _built = self._build_obj(rep_input, self.is_training)
            built.update(_built)

            partial_program = tf.concat([partial_program, built['obj']], axis=1)

            # --- final ---

            results.append(built)

            program[h, w, b] = partial_program
            assert program[h, w, b].shape[1] == total_sample_size

        objects = AttrDict()
        for k in results[0]:
            objects[k] = tf.stack([r[k] for r in results], axis=1)

        if prop_state is not None:
            objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
            objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

        # --- misc ---

        objects.n_objects = tf.fill((self.batch_size,), self.HWB)
        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

        return objects
Example #15
    def _call(self, inp, inp_features, is_training, is_posterior=True):
        print("\n" + "-" * 10 +
              " GridObjectLayer(is_posterior={}) ".format(is_posterior) +
              "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network",
                                builder=cfg.build_lateral,
                                key="box")
        self.maybe_build_subnet("attr_network",
                                builder=cfg.build_lateral,
                                key="attr")
        self.maybe_build_subnet("z_network",
                                builder=cfg.build_lateral,
                                key="z")
        self.maybe_build_subnet("obj_network",
                                builder=cfg.build_lateral,
                                key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _, _ = inp_features.shape
        H = int(H)
        W = int(W)

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.image_height = int(inp.shape[-3])
            self.image_width = int(inp.shape[-2])
            self.image_depth = int(inp.shape[-1])
            self.H = H
            self.W = W
            self.HWB = H * W * self.B
            self.batch_size = tf.shape(inp)[0]
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        # --- set up the edge element ---

        sizes = [4, self.A, 1, 1]
        sigmoids = [True, False, False, True]
        total_sample_size = sum(sizes)

        if self.edge_weights is None:
            self.edge_weights = tf.get_variable("edge_weights",
                                                shape=total_sample_size,
                                                dtype=tf.float32)
            if "edge" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.edge_weights)

        _edge_weights = tf.split(self.edge_weights, sizes, axis=0)
        _edge_weights = [(tf.nn.sigmoid(ew) if sigmoid else ew)
                         for ew, sigmoid in zip(_edge_weights, sigmoids)]
        edge_element = tf.concat(_edge_weights, axis=0)
        edge_element = tf.tile(edge_element[None, :], (self.batch_size, 1))

        # --- containers for storing built program ---

        _tensors = collections.defaultdict(self._make_empty)
        program = np.empty((H, W, self.B), dtype=np.object)

        # --- build the program ---

        is_posterior_tf = tf.ones((self.batch_size, 2))
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        for h, w, b in itertools.product(range(H), range(W), range(self.B)):
            partial_program, features = None, None
            context = self._get_sequential_context(program, h, w, b,
                                                   edge_element)
            base_features = tf.concat(
                [inp_features[:, h, w, b, :], context, is_posterior_tf],
                axis=1)

            # --- box ---

            layer_inp = base_features
            n_features = self.n_passthrough_features
            output_size = 8

            network_output = self.box_network(layer_inp,
                                              output_size + n_features,
                                              self.is_training)
            rep_input, features = tf.split(network_output,
                                           (output_size, n_features),
                                           axis=1)

            built = self._build_box(rep_input, h, w, b, self.is_training)

            for key, value in built.items():
                _tensors[key][h, w, b] = value
            partial_program = built['box']

            # --- attr ---

            if is_posterior:
                # --- Get object attributes using object encoder ---

                yt, xt, ys, xs = tf.split(built['normalized_box'], 4, axis=-1)

                # yt/xt give top/left but here we need center
                yt += ys / 2
                xt += xs / 2

                transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
                warper = snt.AffineGridWarper(
                    (self.image_height, self.image_width), self.object_shape,
                    transform_constraints)

                _boxes = tf.concat([xs, 2 * xt - 1, ys, 2 * yt - 1], axis=-1)

                grid_coords = warper(_boxes)
                grid_coords = tf.reshape(grid_coords, (
                    self.batch_size,
                    1,
                    *self.object_shape,
                    2,
                ))
                glimpse = resampler_edge.resampler_edge(inp, grid_coords)
                glimpse = tf.reshape(
                    glimpse,
                    (self.batch_size, *self.object_shape, self.image_depth))
            else:
                glimpse = tf.zeros(
                    (self.batch_size, *self.object_shape, self.image_depth))

            # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
            encoded_glimpse = self.object_encoder(glimpse, (1, 1, self.A),
                                                  self.is_training)
            encoded_glimpse = tf.reshape(encoded_glimpse,
                                         (self.batch_size, self.A))

            if not is_posterior:
                encoded_glimpse = tf.zeros_like(encoded_glimpse)

            layer_inp = tf.concat(
                [base_features, features, encoded_glimpse, partial_program],
                axis=1)
            network_output = self.attr_network(layer_inp,
                                               2 * self.A + n_features,
                                               self.is_training)
            attr_mean, attr_log_std, features = tf.split(
                network_output, (self.A, self.A, n_features), axis=1)

            attr_std = self.std_nonlinearity(attr_log_std)

            attr_dist = Normal(loc=attr_mean, scale=attr_std)
            attr = attr_dist.sample()

            if "attr" in self.no_gradient:
                attr = tf.stop_gradient(attr)

            built = dict(attr_dist=attr_dist, attr=attr, glimpse=glimpse)

            for key, value in built.items():
                _tensors[key][h, w, b] = value
            partial_program = tf.concat([partial_program, built['attr']],
                                        axis=1)

            # --- z ---

            layer_inp = tf.concat([base_features, features, partial_program],
                                  axis=1)
            n_features = self.n_passthrough_features

            network_output = self.z_network(layer_inp, 2 + n_features,
                                            self.is_training)
            z_mean, z_log_std, features = tf.split(network_output,
                                                   (1, 1, n_features),
                                                   axis=1)
            z_std = self.std_nonlinearity(z_log_std)

            z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (
                1 - self.training_wheels) * z_mean
            z_std = self.training_wheels * tf.stop_gradient(z_std) + (
                1 - self.training_wheels) * z_std
            z_logit_dist = Normal(loc=z_mean, scale=z_std)
            z_logits = z_logit_dist.sample()
            z = self.z_nonlinearity(z_logits)

            if "z" in self.no_gradient:
                z = tf.stop_gradient(z)

            if "z" in self.fixed_values:
                z = self.fixed_values['z'] * tf.ones_like(z)

            built = dict(z_logit_dist=z_logit_dist, z=z)

            for key, value in built.items():
                _tensors[key][h, w, b] = value
            partial_program = tf.concat([partial_program, built['z']], axis=1)

            # --- obj ---

            layer_inp = tf.concat([base_features, features, partial_program],
                                  axis=1)
            rep_input = self.obj_network(layer_inp, 1, self.is_training)

            built = self._build_obj(rep_input, self.is_training)

            for key, value in built.items():
                _tensors[key][h, w, b] = value
            partial_program = tf.concat([partial_program, built['obj']],
                                        axis=1)

            # --- final ---

            program[h, w, b] = partial_program
            assert program[h, w, b].shape[1] == total_sample_size

        # --- merge tensors from different grid cells ---

        objects = AttrDict()
        for k, v in _tensors.items():
            if k.endswith('_dist'):
                dist = v[0, 0, 0]
                dist_class = type(dist)
                params = dist.parameters.copy()
                tensor_keys = sorted(key for key, t in params.items()
                                     if isinstance(t, tf.Tensor))
                tensor_params = {}

                for key in tensor_keys:
                    t1 = []
                    for h in range(H):
                        t2 = []
                        for w in range(W):
                            t2.append(
                                tf.stack([
                                    v[h, w, b].parameters[key]
                                    for b in range(self.B)
                                ],
                                         axis=1))
                        t1.append(tf.stack(t2, axis=1))
                    tensor_params[key] = tf.stack(t1, axis=1)

                params.update(tensor_params)
                objects[k] = dist_class(**params)
            else:
                t1 = []
                for h in range(H):
                    t2 = []
                    for w in range(W):
                        t2.append(
                            tf.stack([v[h, w, b] for b in range(self.B)],
                                     axis=1))
                    t1.append(tf.stack(t2, axis=1))
                objects[k] = tf.stack(t1, axis=1)

        objects.all = tf.concat(
            [objects.box, objects.attr, objects.z, objects.obj], axis=-1)

        # --- misc ---

        objects.n_objects = tf.fill((self.batch_size, ), self.HWB)
        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2, 3, 4))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj),
                                                    axis=(1, 2, 3, 4))

        return objects
Example #16
    def build_background(self):
        if cfg.background_cfg.mode == "colour":
            rgb = np.array(to_rgb(cfg.background_cfg.colour))[None, None, None, :]
            background = rgb * tf.ones_like(self.inp)

        elif cfg.background_cfg.mode == "learn_solid":
            # Learn a solid colour for the background
            self.solid_background_logits = tf.get_variable("solid_background", initializer=[0.0, 0.0, 0.0])
            if "background" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.solid_background_logits)
            solid_background = tf.nn.sigmoid(10 * self.solid_background_logits)
            background = solid_background[None, None, None, :] * tf.ones_like(self.inp)

        elif cfg.background_cfg.mode == "scalor":
            pass

        elif cfg.background_cfg.mode == "learn":
            self.maybe_build_subnet("background_encoder")
            self.maybe_build_subnet("background_decoder")

            # Here I'm just encoding the first frame...
            bg_attr = self.background_encoder(self.inp[:, 0], 2 * cfg.background_cfg.A, self.is_training)
            bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
            bg_attr_std = tf.exp(bg_attr_log_std)
            if not self.noisy:
                bg_attr_std = tf.zeros_like(bg_attr_std)

            bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

            self._tensors.update(
                bg_attr_mean=bg_attr_mean,
                bg_attr_std=bg_attr_std,
                bg_attr_kl=bg_attr_kl,
                bg_attr=bg_attr)

            # --- decode ---

            _, T, H, W, _ = tf_shape(self.inp)

            background = self.background_decoder(bg_attr, 3, self.is_training)
            assert len(background.shape) == 2 and background.shape[1] == 3
            background = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))
            background = tf.tile(background[:, None, None, None, :], (1, T, H, W, 1))

        elif cfg.background_cfg.mode == "learn_and_transform":
            self.maybe_build_subnet("background_encoder")
            self.maybe_build_subnet("background_decoder")

            # --- encode ---

            n_transform_latents = 4
            n_latents = (2 * cfg.background_cfg.A, 2 * n_transform_latents)

            bg_attr, bg_transform_params = self.background_encoder(self.inp, n_latents, self.is_training)

            # --- bg attributes ---

            bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
            bg_attr_std = self.std_nonlinearity(bg_attr_log_std)

            bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

            # --- bg location ---

            bg_transform_params = tf.reshape(
                bg_transform_params,
                (self.batch_size, self.dynamic_n_frames, 2*n_transform_latents))

            mean, log_std = tf.split(bg_transform_params, 2, axis=2)
            std = self.std_nonlinearity(log_std)

            logits, kl = normal_vae(mean, std, 0.0, 1.0)

            # integrate across timesteps
            logits = tf.cumsum(logits, axis=1)
            logits = tf.reshape(logits, (self.batch_size*self.dynamic_n_frames, n_transform_latents))

            y, x, h, w = tf.split(logits, n_transform_latents, axis=1)
            h = (0.9 - 0.5) * tf.nn.sigmoid(h) + 0.5
            w = (0.9 - 0.5) * tf.nn.sigmoid(w) + 0.5
            y = (1 - h) * tf.nn.tanh(y)
            x = (1 - w) * tf.nn.tanh(x)

            # --- decode ---

            background = self.background_decoder(bg_attr, self.image_depth, self.is_training)
            bg_shape = cfg.background_cfg.bg_shape
            background = background[:, :bg_shape[0], :bg_shape[1], :]
            assert background.shape[1:3] == bg_shape
            background_raw = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()

            warper = snt.AffineGridWarper(
                bg_shape, (self.image_height, self.image_width), transform_constraints)

            transforms = tf.concat([w, x, h, y], axis=-1)
            grid_coords = warper(transforms)

            grid_coords = tf.reshape(
                grid_coords,
                (self.batch_size, self.dynamic_n_frames, *tf_shape(grid_coords)[1:]))

            background = tf.contrib.resampler.resampler(background_raw, grid_coords)

            self._tensors.update(
                bg_attr_mean=bg_attr_mean,
                bg_attr_std=bg_attr_std,
                bg_attr_kl=bg_attr_kl,
                bg_attr=bg_attr,
                bg_y=tf.reshape(y, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_x=tf.reshape(x, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_h=tf.reshape(h, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_w=tf.reshape(w, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_transform_kl=kl,
                bg_raw=background_raw,
            )

        elif cfg.background_cfg.mode == "data":
            background = self._tensors["ground_truth_background"]

        else:
            raise Exception("Unrecognized background mode: {}.".format(cfg.background_cfg.mode))

        self._tensors["background"] = background[:, :self.dynamic_n_frames]