def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
        print("\n" + "-" * 10 + " ConvGridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network", builder=cfg.build_conv_lateral, key="box")
        self.maybe_build_subnet("attr_network", builder=cfg.build_conv_lateral, key="attr")
        self.maybe_build_subnet("z_network", builder=cfg.build_conv_lateral, key="z")
        self.maybe_build_subnet("obj_network", builder=cfg.build_conv_lateral, key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _, n_channels = tf_shape(inp_features)

        if self.B != 1:
            raise Exception("NotImplemented")

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
            self.H = H
            self.W = W
            self.HWB = H*W
            self.batch_size = tf.shape(inp)[0]
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        inp_features = tf.reshape(inp_features, (self.batch_size, H, W, n_channels))
        is_posterior_tf = tf.ones_like(inp_features[..., :2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        objects = AttrDict()

        base_features = tf.concat([inp_features, is_posterior_tf], axis=-1)

        # --- box ---

        layer_inp = base_features
        n_features = self.n_passthrough_features
        output_size = 8

        network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
        rep_input, features = tf.split(network_output, (output_size, n_features), axis=-1)

        _objects = self._build_box(rep_input, self.is_training)
        objects.update(_objects)

        # --- attr ---

        if is_posterior:
            # --- Get object attributes using object encoder ---

            yt, xt, ys, xs = tf.split(objects['normalized_box'], 4, axis=-1)

            yt, xt, ys, xs = coords_to_image_space(
                yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
            warper = snt.AffineGridWarper(
                (self.image_height, self.image_width), self.object_shape, transform_constraints)

            _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)
            _boxes = tf.reshape(_boxes, (self.batch_size*H*W, 4))
            grid_coords = warper(_boxes)
            grid_coords = tf.reshape(grid_coords, (self.batch_size, H, W, *self.object_shape, 2,))

            if self.edge_resampler:
                glimpse = resampler_edge.resampler_edge(inp, grid_coords)
            else:
                glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
        else:
            glimpse = tf.zeros((self.batch_size, H, W, *self.object_shape, self.image_depth))

        # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
        encoded_glimpse = apply_object_wise(
            self.object_encoder, glimpse, n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros_like(encoded_glimpse)

        layer_inp = tf.concat([base_features, features, encoded_glimpse, objects['local_box']], axis=-1)
        network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
        attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=-1)

        attr_std = self.std_nonlinearity(attr_log_std)

        attr = Normal(loc=attr_mean, scale=attr_std).sample()

        objects.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)

        # --- z ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr']], axis=-1)
        n_features = self.n_passthrough_features

        network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
        z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=-1)
        z_std = self.std_nonlinearity(z_log_std)

        z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
        z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
        z_logit = Normal(loc=z_mean, scale=z_std).sample()
        z = self.z_nonlinearity(z_logit)

        objects.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)

        # --- obj ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr'], objects['z']], axis=-1)
        rep_input = self.obj_network(layer_inp, 1, self.is_training)

        _objects = self._build_obj(rep_input, self.is_training)
        objects.update(_objects)

        # --- final ---

        _objects = AttrDict()
        for k, v in objects.items():
            _, _, _, *trailing_dims = tf_shape(v)
            _objects[k] = tf.reshape(v, (self.batch_size, self.HWB, *trailing_dims))
        objects = _objects

        if prop_state is not None:
            objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
            objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

        # --- misc ---

        objects.n_objects = tf.fill((self.batch_size,), self.HWB)
        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

        return objects
    def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
        print("\n" + "-" * 10 + " GridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network", builder=cfg.build_lateral, key="box")
        self.maybe_build_subnet("attr_network", builder=cfg.build_lateral, key="attr")
        self.maybe_build_subnet("z_network", builder=cfg.build_lateral, key="z")
        self.maybe_build_subnet("obj_network", builder=cfg.build_lateral, key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _, _ = tf_shape(inp_features)
        H = int(H)
        W = int(W)

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
            self.H = H
            self.W = W
            self.HWB = H*W*self.B
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        # --- set up the edge element ---

        sizes = [4, self.A, 1, 1]
        sigmoids = [True, False, False, True]
        total_sample_size = sum(sizes)

        if self.edge_weights is None:
            self.edge_weights = tf.get_variable("edge_weights", shape=total_sample_size, dtype=tf.float32)
            if "edge" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.edge_weights)

        _edge_weights = tf.split(self.edge_weights, sizes, axis=0)
        _edge_weights = [
            (tf.nn.sigmoid(ew) if sigmoid else ew)
            for ew, sigmoid in zip(_edge_weights, sigmoids)]
        edge_element = tf.concat(_edge_weights, axis=0)
        edge_element = tf.tile(edge_element[None, :], (self.batch_size, 1))

        # --- containers for storing built program ---

        program = np.empty((H, W, self.B), dtype=np.object)

        # --- build the program ---

        is_posterior_tf = tf.ones((self.batch_size, 2))
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        results = []
        for h, w, b in itertools.product(range(H), range(W), range(self.B)):
            built = dict()

            partial_program, features = None, None
            context = self._get_sequential_context(program, h, w, b, edge_element)
            base_features = tf.concat([inp_features[:, h, w, b, :], context, is_posterior_tf], axis=1)

            # --- box ---

            layer_inp = base_features
            n_features = self.n_passthrough_features
            output_size = 8

            network_output = self.box_network(layer_inp, output_size + n_features, self. is_training)
            rep_input, features = tf.split(network_output, (output_size, n_features), axis=1)

            _built = self._build_box(rep_input, self.is_training, hw=(h, w))
            built.update(_built)
            partial_program = built['local_box']

            # --- attr ---

            if is_posterior:
                # --- Get object attributes using object encoder ---

                yt, xt, ys, xs = tf.split(built['normalized_box'], 4, axis=-1)

                yt, xt, ys, xs = coords_to_image_space(
                    yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

                transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
                warper = snt.AffineGridWarper(
                    (self.image_height, self.image_width), self.object_shape, transform_constraints)

                _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)

                grid_coords = warper(_boxes)
                grid_coords = tf.reshape(grid_coords, (self.batch_size, 1, *self.object_shape, 2,))
                if self.edge_resampler:
                    glimpse = resampler_edge.resampler_edge(inp, grid_coords)
                else:
                    glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
                glimpse = tf.reshape(glimpse, (self.batch_size, *self.object_shape, self.image_depth))
            else:
                glimpse = tf.zeros((self.batch_size, *self.object_shape, self.image_depth))

            # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
            encoded_glimpse = self.object_encoder(glimpse, (1, 1, self.A), self.is_training)
            encoded_glimpse = tf.reshape(encoded_glimpse, (self.batch_size, self.A))

            if not is_posterior:
                encoded_glimpse = tf.zeros_like(encoded_glimpse)

            layer_inp = tf.concat(
                [base_features, features, encoded_glimpse, partial_program], axis=1)
            network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self. is_training)
            attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=1)

            attr_std = self.std_nonlinearity(attr_log_std)

            attr = Normal(loc=attr_mean, scale=attr_std).sample()

            built.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)
            partial_program = tf.concat([partial_program, built['attr']], axis=1)

            # --- z ---

            layer_inp = tf.concat([base_features, features, partial_program], axis=1)
            n_features = self.n_passthrough_features

            network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
            z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=1)
            z_std = self.std_nonlinearity(z_log_std)

            z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
            z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
            z_logit = Normal(loc=z_mean, scale=z_std).sample()
            z = self.z_nonlinearity(z_logit)

            built.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)
            partial_program = tf.concat([partial_program, built['z']], axis=1)

            # --- obj ---

            layer_inp = tf.concat([base_features, features, partial_program], axis=1)
            rep_input = self.obj_network(layer_inp, 1, self.is_training)

            _built = self._build_obj(rep_input, self.is_training)
            built.update(_built)

            partial_program = tf.concat([partial_program, built['obj']], axis=1)

            # --- final ---

            results.append(built)

            program[h, w, b] = partial_program
            assert program[h, w, b].shape[1] == total_sample_size

        objects = AttrDict()
        for k in results[0]:
            objects[k] = tf.stack([r[k] for r in results], axis=1)

        if prop_state is not None:
            objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
            objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

        # --- misc ---

        objects.n_objects = tf.fill((self.batch_size,), self.HWB)
        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

        return objects