Example #1
def extract_affine_glimpse(image, object_shape, cyt, cxt, ys, xs,
                           edge_resampler):
    """ (cyt, cxt) are the rectangle's center coordinates; (ys, xs) are its height and width. """
    _, *image_shape, image_depth = tf_shape(image)
    transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
    warper = snt.AffineGridWarper(image_shape, object_shape,
                                  transform_constraints)

    # change coordinate system
    cyt = 2 * cyt - 1
    cxt = 2 * cxt - 1

    leading_shape = tf_shape(cyt)[:-1]

    _boxes = tf.concat([xs, cxt, ys, cyt], axis=-1)
    _boxes = tf.reshape(_boxes, (-1, 4))

    grid_coords = warper(_boxes)

    grid_coords = tf.reshape(grid_coords, (*leading_shape, *object_shape, 2))

    if edge_resampler:
        glimpses = resampler_edge.resampler_edge(image, grid_coords)
    else:
        glimpses = tf.contrib.resampler.resampler(image, grid_coords)

    glimpses = tf.reshape(glimpses,
                          (*leading_shape, *object_shape, image_depth))

    return glimpses
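

# A minimal NumPy sketch (hypothetical, not part of the original module) of the
# coordinate convention assumed above. The centre (cyt, cxt) is given in [0, 1]
# image coordinates and mapped to the [-1, 1] range used by snt.AffineGridWarper;
# the packed transform [xs, cxt, ys, cyt] then scales and translates the glimpse
# grid, roughly x_src = xs * x_dst + cxt and y_src = ys * y_dst + cyt.
import numpy as np


def glimpse_grid_sketch(cyt, cxt, ys, xs, object_shape):
    """Hypothetical stand-in for the grid the warper computes; illustration only."""
    cyt, cxt = 2 * cyt - 1, 2 * cxt - 1  # same change of coordinate system as above
    h, w = object_shape
    y_dst, x_dst = np.meshgrid(np.linspace(-1, 1, h), np.linspace(-1, 1, w), indexing='ij')
    return ys * y_dst + cyt, xs * x_dst + cxt


# A glimpse centred in the middle of the image, covering half its extent:
y_src, x_src = glimpse_grid_sketch(cyt=0.5, cxt=0.5, ys=0.5, xs=0.5, object_shape=(3, 3))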
Example #2
    def _call(self, input_signal, input_locs, output_locs, is_training):
        if not self.is_built:
            self.value_func = self.build_mlp(scope="value_func")
            self.after_func = self.build_mlp(scope="after")

            if self.do_object_wise:
                self.object_wise_func = self.build_object_wise(
                    scope="object_wise")

            self.is_built = True

        batch_size, n_inp, _ = tf_shape(input_signal)
        loc_dim = tf_shape(input_locs)[-1]
        n_outp = tf_shape(output_locs)[-2]
        input_locs = tf.broadcast_to(input_locs, (batch_size, n_inp, loc_dim))
        output_locs = tf.broadcast_to(output_locs,
                                      (batch_size, n_outp, loc_dim))

        dist = output_locs[:, :, None, :] - input_locs[:, None, :, :]
        proximity = tf.exp(-0.5 * tf.reduce_sum(
            (dist / self.kernel_std)**2, axis=3))
        proximity = proximity / (2 * np.pi)**(
            0.5 * loc_dim) / self.kernel_std**loc_dim

        V = apply_object_wise(
            self.value_func,
            input_signal,
            output_size=self.n_hidden,
            is_training=is_training)  # (batch_size, n_inp, value_dim)

        result = tf.matmul(proximity, V)  # (batch_size, n_outp, value_dim)

        # `after_func` is applied to the attention result, followed by dropout and layer norm. Then, if
        # `do_object_wise` is True, `object_wise_func` is applied object-wise in a ResNet-style (residual)
        # manner, again followed by dropout and layer norm.

        output = apply_object_wise(self.after_func,
                                   result,
                                   output_size=self.n_hidden,
                                   is_training=is_training)
        output = tf.layers.dropout(output,
                                   self.p_dropout,
                                   training=is_training)
        signal = tf.contrib.layers.layer_norm(output)

        if self.do_object_wise:
            output = apply_object_wise(self.object_wise_func,
                                       signal,
                                       output_size=self.n_hidden,
                                       is_training=is_training)
            output = tf.layers.dropout(output,
                                       self.p_dropout,
                                       training=is_training)
            signal = tf.contrib.layers.layer_norm(signal + output)

        return signal
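

# A minimal NumPy sketch (hypothetical; assumes 2-D locations and a scalar
# kernel_std) of the proximity weighting used above: each output location
# attends to every input location with an isotropic Gaussian density evaluated
# at their displacement, and the values V are mixed with those weights.
import numpy as np


def proximity_kernel(output_locs, input_locs, kernel_std):
    # output_locs: (n_outp, d), input_locs: (n_inp, d) -> weights: (n_outp, n_inp)
    dist = output_locs[:, None, :] - input_locs[None, :, :]
    d = output_locs.shape[-1]
    prox = np.exp(-0.5 * np.sum((dist / kernel_std) ** 2, axis=-1))
    return prox / (2 * np.pi) ** (0.5 * d) / kernel_std ** d


weights = proximity_kernel(np.zeros((1, 2)), np.random.rand(5, 2), kernel_std=0.1)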
Example #3
    def _call(self, inp, output_size, is_training):
        if self.bg_head is None:
            self.bg_head = ConvNet(
                layers=[
                    dict(filters=None, kernel_size=1, strides=1, padding="SAME"),
                    dict(filters=None, kernel_size=1, strides=1, padding="SAME"),
                ],
                scope="bg_head"
            )

        if self.transform_head is None:
            self.transform_head = MLP(n_units=[64, 64], scope="transform_head")

        n_attr_channels, n_transform_values = output_size
        processed = super()._call(inp, n_attr_channels, is_training)
        B, F, H, W, C = tf_shape(processed)

        # Map `processed` to background attributes of shape (B, H, W, C)
        # and per-frame transform values of shape (B, F, n_transform_values)

        bg_attrs = self.bg_head(tf.reduce_mean(processed, axis=1), None, is_training)

        transform_values = self.transform_head(
            tf.reshape(processed, (B*F, H*W*C)),
            n_transform_values, is_training)

        transform_values = tf.reshape(transform_values, (B, F, n_transform_values))

        return bg_attrs, transform_values
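

# A small NumPy sketch (hypothetical) of the two shape mappings performed above:
# the background head sees the frame-averaged feature map of shape (B, H, W, C),
# while the transform head sees one flattened feature vector per frame, of shape
# (B*F, H*W*C), before being reshaped back to (B, F, n_transform_values).
import numpy as np

B, F, H, W, C = 2, 3, 4, 4, 5
processed_example = np.zeros((B, F, H, W, C))
bg_head_inp = processed_example.mean(axis=1)                        # (B, H, W, C)
transform_head_inp = processed_example.reshape(B * F, H * W * C)    # (B*F, H*W*C)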
Example #4
    def __call__(self, tensors):
        batch_size = tf_shape(tensors["obj"])[0]

        exp_rate = self.exp_rate
        assert_exp_rate_gt_zero = tf.Assert(exp_rate >= 0, [exp_rate],
                                            name='assert_exp_rate_gt_zero')

        with tf.control_dependencies([assert_exp_rate_gt_zero]):
            posterior_log_pdf = logistic_log_pdf(tensors["obj_log_odds"],
                                                 tensors["obj_pre_sigmoid"],
                                                 self.obj_concrete_temp)
            posterior_log_pdf = tf.reduce_sum(tf.reshape(
                posterior_log_pdf, (batch_size, -1)),
                                              axis=1)

        # This is different from the true log prior pdf by a constant factor,
        # namely the log of the normalization constant for the prior.
        concrete_sum = tf.reduce_sum(tf.reshape(tensors["obj"],
                                                (batch_size, -1)),
                                     axis=1)

        # prior_pdf = exp_rate * tf.exp(-exp_rate * concrete_sum)
        prior_log_pdf = -exp_rate * concrete_sum

        return posterior_log_pdf - prior_log_pdf
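

# A short numeric sketch (hypothetical, NumPy only) of the comment above: for an
# Exponential(exp_rate) prior, log p(x) = log(exp_rate) - exp_rate * x, so dropping
# the constant log(exp_rate) leaves -exp_rate * x, which is what `prior_log_pdf`
# uses; the difference from the true log pdf is exactly that constant.
import numpy as np

exp_rate, concrete_sum = 2.0, 1.5
true_log_pdf = np.log(exp_rate) - exp_rate * concrete_sum
unnormalized_log_pdf = -exp_rate * concrete_sum
assert np.isclose(true_log_pdf - unnormalized_log_pdf, np.log(exp_rate))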
Example #5
    def __call__(self, tensors):
        kl = concrete_binary_sample_kl(tensors["obj_pre_sigmoid"],
                                       tensors["obj_log_odds"],
                                       self.obj_concrete_temp,
                                       self.prior_log_odds,
                                       self.obj_concrete_temp)

        batch_size = tf_shape(tensors["obj_pre_sigmoid"])[0]
        return tf.reduce_sum(tf.reshape(kl, (batch_size, -1)), 1)
Example #6
    def build_representation(self):
        assert cfg.background_cfg.mode == 'colour'
        self.build_background()

        # dummy variable to satisfy dps
        tf.get_variable("dummy", shape=(1, ), dtype=tf.float32)

        B, T, *rest = tf_shape(self._tensors["background"])

        inp = tf.reshape(self._tensors["inp"], (T * B, *rest))
        bg = tf.reshape(self._tensors["background"], (T * B, *rest))

        program_tensors = tf_find_connected_components(inp, bg,
                                                       self.cc_threshold,
                                                       self.colours,
                                                       self.cosine_threshold)

        self._tensors.update({
            k: tf.reshape(v, (B, T, *tf_shape(v)[1:]))
            for k, v in program_tensors.items() if k != 'max_objects'
        })

        if "n_annotations" in self._tensors:
            count_1norm = tf.to_float(
                tf.abs(
                    tf.to_int32(self._tensors["n_objects"]) -
                    self._tensors["n_valid_annotations"]))
            count_1norm_relative = (count_1norm / tf.maximum(
                tf.cast(self._tensors["n_valid_annotations"], tf.float32),
                1e-6))
            self.record_tensors(
                count_1norm_relative=count_1norm_relative,
                count_1norm=count_1norm,
                count_error=count_1norm > 0.5,
                n_objects_per_frame=self._tensors["n_objects"],
            )
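

# A small NumPy sketch (hypothetical) of the counting metrics recorded above:
# count_1norm is the absolute difference between predicted and annotated object
# counts, count_1norm_relative normalizes it by the annotated count, and
# count_error flags any frame where the count is wrong.
import numpy as np

n_objects = np.array([3, 5])
n_valid_annotations = np.array([4, 5])
count_1norm = np.abs(n_objects - n_valid_annotations)                    # [1, 0]
count_1norm_relative = count_1norm / np.maximum(n_valid_annotations, 1e-6)
count_error = count_1norm > 0.5                                          # [True, False]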
Example #7
def apply_object_wise(func, signal, output_size, is_training, restore_shape=True, n_trailing_dims=1):
    """ Treat `signal` as a batch of objects. Apply function `func` separately to each object.
        The final `n_trailing_dims`-many dimensions are treated as "within-object" dimensions.
        By default, objects are assumed to be vectors, but this can be changed by increasing
        `n_trailing_dims`; e.g. n_trailing_dims==2 means each object is a matrix, i.e. the
        last 2 dimensions of `signal` are dimensions of the object.

    """
    shape = tf_shape(signal)
    leading_dim = tf.reduce_prod(shape[:-n_trailing_dims])
    signal = tf.reshape(signal, (leading_dim, *shape[-n_trailing_dims:]))
    output = func(signal, output_size, is_training)

    if restore_shape:
        if not isinstance(output_size, tuple):
            output_size = [output_size]
        output = tf.reshape(output, (*shape[:-n_trailing_dims], *output_size))

    return output
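

# A hypothetical usage sketch, assuming `apply_object_wise` above and the
# `tf_shape` helper used throughout these examples are in scope: a batch of
# object vectors of shape (batch_size, n_objects, obj_dim) is flattened to
# (batch_size * n_objects, obj_dim), `func` is applied once, and the leading
# dimensions are restored on the output.
import tensorflow as tf


def dummy_func(x, output_size, is_training):
    # assumption: `func` takes (input, output_size, is_training), like the
    # networks used elsewhere in these examples
    return tf.layers.dense(x, output_size)


signal_example = tf.zeros((4, 7, 6))     # (batch_size, n_objects, obj_dim)
out = apply_object_wise(dummy_func, signal_example, output_size=32, is_training=True)
# out has shape (4, 7, 32)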
Example #8
def tile_input_for_iwae(tensor, iw_samples, with_time=False):
    """Tiles tensor `tensor` in such a way that tiled samples are contiguous in memory;
    i.e. it tiles along the axis after the batch axis and reshapes to have the same rank as
    the original tensor

    :param tensor: tf.Tensor to be tiled
    :param iw_samples: int, number of importance-weighted samples
    :param with_time: boolean, if true then an additional time axis at the beginning is assumed
    :return:
    """
    shape = list(tf_shape(tensor))
    shape[with_time] *= iw_samples

    tiles = [1, iw_samples] + [1] * (tensor.shape.ndims - (1 + with_time))
    if with_time:
        tiles = [1] + tiles

    tensor = tf.expand_dims(tensor, 1 + with_time)
    tensor = tf.tile(tensor, tiles)
    tensor = tf.reshape(tensor, shape)
    return tensor
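

# A small NumPy sketch (hypothetical, not from the original module) of the tiling
# pattern described in the docstring: with iw_samples=2 each batch element is
# repeated twice, and the repeats are contiguous along the batch axis rather than
# interleaved with other batch elements.
import numpy as np

x = np.array([[1.], [2.], [3.]])                          # (batch=3, feature=1)
tiled = np.reshape(np.tile(x[:, None], (1, 2, 1)), (6, 1))
# tiled[:, 0] == [1, 1, 2, 2, 3, 3]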
Example #9
    def null_object_set(self, batch_size):
        n_prop_objects = self.n_prop_objects

        new_objects = AttrDict(
            normalized_box=tf.zeros((batch_size, n_prop_objects, 4)),
            attr=tf.zeros((batch_size, n_prop_objects, self.A)),
            z=tf.zeros((batch_size, n_prop_objects, 1)),
            obj=tf.zeros((batch_size, n_prop_objects, 1)),
        )

        yt, xt, ys, xs = tf.split(new_objects.normalized_box, 4, axis=-1)

        new_objects.update(
            abs_posn=new_objects.normalized_box[..., :2] + 0.0,
            yt=yt,
            xt=xt,
            ys=ys,
            xs=xs,
            ys_logit=ys + 0.0,
            xs_logit=xs + 0.0,

            # d_yt=yt + 0.0,
            # d_xt=xt + 0.0,
            # d_ys=ys + 0.0,
            # d_xs=xs + 0.0,
            # d_attr=new_objects.attr + 0.0,
            # d_z=new_objects.z + 0.0,
            z_logit=new_objects.z + 0.0,
        )

        prop_state = self.cell.initial_state(batch_size * n_prop_objects,
                                             tf.float32)
        trailing_shape = tf_shape(prop_state)[1:]
        new_objects.prop_state = tf.reshape(
            prop_state, (batch_size, n_prop_objects, *trailing_shape))
        new_objects.prior_prop_state = new_objects.prop_state

        return new_objects
Example #10
    def build_representation(self):
        # --- init modules ---

        if self.backbone is None:
            self.backbone = self.build_backbone(scope="backbone")

        if self.cell is None:
            self.cell = cfg.build_cell(self.n_hidden, name="cell")

            # self.cell must be a Sonnet RNNCore

            if self.learn_initial_state:
                self.initial_hidden_state = snt.trainable_initial_state(
                    1,
                    self.cell.state_size,
                    tf.float32,
                    name="initial_hidden_state")

        if self.key_network is None:
            self.key_network = cfg.build_key_network(scope="key_network")
        if self.beta_network is None:
            self.beta_network = cfg.build_beta_network(scope="beta_network")
        if self.write_network is None:
            self.write_network = cfg.build_write_network(scope="write_network")
        if self.erase_network is None:
            self.erase_network = cfg.build_erase_network(scope="erase_network")

        if self.output_network is None:
            self.output_network = cfg.build_output_network(
                scope="output_network")

        d_n_frames, n_trackers, batch_size = self.dynamic_n_frames, self.n_trackers, self.batch_size

        # --- encode ---

        video = tf.reshape(self.inp,
                           (batch_size * d_n_frames, *self.obs_shape[1:]))

        zero_to_one_Y = tf.linspace(0., 1., self.image_height)
        zero_to_one_X = tf.linspace(0., 1., self.image_width)
        X, Y = tf.meshgrid(zero_to_one_X, zero_to_one_Y)
        X = tf.tile(X[None, :, :, None], (batch_size * d_n_frames, 1, 1, 1))
        Y = tf.tile(Y[None, :, :, None], (batch_size * d_n_frames, 1, 1, 1))
        video = tf.concat([video, Y, X], axis=-1)

        encoded_frames, _, _ = self.backbone(video, self.S, self.is_training)

        _, H, W, _ = tf_shape(encoded_frames)
        self.H = H
        self.W = W
        encoded_frames = tf.reshape(encoded_frames,
                                    (batch_size, d_n_frames, H * W, self.S))
        self.encoded_frames = encoded_frames

        cts = tf.minimum(
            1., self.float_is_training + (1 - float(self.discrete_eval)))
        self.mask_temperature = cts * 1.0 + (1 - cts) * 1e-5
        self.layer_temperature = cts * 1.0 + (1 - cts) * 1e-5

        f = tf.constant(0, dtype=tf.int32)

        if self.learn_initial_state:
            hidden_state = self.initial_hidden_state[None, ...]
        else:
            hidden_state = self.cell.zero_state(1, tf.float32)[None, ...]

        hidden_states = tf.tile(hidden_state, (
            self.n_trackers,
            self.batch_size,
        ) + (1, ) * (len(hidden_state.shape) - 2))

        conf = tf.zeros((n_trackers, batch_size, 1))

        structure = dict(
            hidden_states=hidden_states,
            tracker_output=tf.zeros(
                (self.n_trackers, self.batch_size, self.n_hidden)),
            attention_result=tf.zeros(
                (self.n_trackers, self.batch_size, self.S)),
            attention_weights=tf.zeros(
                (self.n_trackers, self.batch_size, H * W)),
            memory_activation=tf.zeros(
                (self.n_trackers, self.batch_size, H * W)),
            conf=conf,
            layer=tf.zeros((self.n_trackers, self.batch_size, self.n_layers)),
            pose=tf.zeros((self.n_trackers, self.batch_size, 4)),
            mask=tf.zeros((self.n_trackers, self.batch_size,
                           np.prod(self.object_shape))),
            appearance=tf.zeros((self.n_trackers, self.batch_size,
                                 3 * np.prod(self.object_shape))),
            order=tf.zeros((self.n_trackers, self.batch_size, 1),
                           dtype=tf.int32),
        )
        tensor_arrays = make_tensor_arrays(structure, self.dynamic_n_frames)

        loop_vars = [f, conf, hidden_states, *tensor_arrays]

        result = tf.while_loop(self._loop_cond, self._loop_body, loop_vars)

        first_ta_idx = min(i for i, ta in enumerate(result)
                           if isinstance(ta, tf.TensorArray))
        tensor_arrays = result[first_ta_idx:]

        def finalize_ta(ta):
            t = ta.stack()
            # transpose from (n_frames, n_trackers, batch_size, *other) to (batch_size, n_frames, n_trackers, *other)
            return tf.transpose(t, (2, 0, 1, *range(3, len(t.shape))))

        tensors = map_structure(
            finalize_ta,
            tensor_arrays,
            is_leaf=lambda t: isinstance(t, tf.TensorArray))
        tensors = apply_keys(structure, tensors)

        self._tensors.update(tensors)

        pprint.pprint(self._tensors)

        # --- render/decode ---

        ys, xs, yt, xt = tf.split(tensors["pose"], 4, axis=-1)
        self.record_tensors(
            conf=tensors["conf"],
            ys=ys,
            xs=xs,
            yt=yt,
            xt=xt,
        )

        yt_normed = (yt + 1) / 2
        xt_normed = (xt + 1) / 2
        ys_normed = 1 + self.eta[0] * ys
        xs_normed = 1 + self.eta[1] * xs

        normalized_box = tf.concat(
            [yt_normed, xt_normed, ys_normed, xs_normed], axis=-1)

        # expose values for plotting

        self._tensors.update(
            normalized_box=normalized_box,
            mask=tf.reshape(
                tensors["mask"],
                (batch_size, d_n_frames, n_trackers, *self.object_shape)),
            appearance=tf.reshape(tensors["appearance"],
                                  (batch_size, d_n_frames, n_trackers,
                                   *self.object_shape, self.image_depth)),
            conf=tensors["conf"],
            layer=tensors["layer"],
            order=tensors["order"],
        )

        # --- reshape values ---

        N = batch_size * d_n_frames * n_trackers
        ys_normed = tf.reshape(ys_normed, (N, 1))
        xs_normed = tf.reshape(xs_normed, (N, 1))
        yt_normed = tf.reshape(yt_normed, (N, 1))
        xt_normed = tf.reshape(xt_normed, (N, 1))

        _yt, _xt, _ys, _xs = tba_coords_to_image_space(
            yt_normed,
            xt_normed,
            ys_normed,
            xs_normed, (self.image_height, self.image_width),
            self.anchor_box,
            top_left=False)

        mask = tf.reshape(tensors["mask"], (N, *self.object_shape, 1))
        appearance = tf.reshape(tensors["appearance"],
                                (N, *self.object_shape, self.image_depth))

        transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
        warper = snt.AffineGridWarper((self.image_height, self.image_width),
                                      self.object_shape, transform_constraints)
        inverse_warper = warper.inverse()
        transforms = tf.concat([_xs, _xt, _ys, _yt], axis=-1)
        grid_coords = inverse_warper(transforms)

        transformed_masks = tf.contrib.resampler.resampler(mask, grid_coords)
        transformed_masks = tf.reshape(
            transformed_masks, (batch_size, d_n_frames, n_trackers,
                                self.image_height, self.image_width, 1))

        transformed_appearances = tf.contrib.resampler.resampler(
            appearance, grid_coords)
        transformed_appearances = tf.reshape(
            transformed_appearances,
            (batch_size, d_n_frames, n_trackers, self.image_height,
             self.image_width, self.image_depth))

        layer_masks = []
        layer_appearances = []

        conf = tensors["conf"][:, :, :, :, None, None]

        # TODO: currently assuming a black background

        final_frames = tf.zeros((batch_size, d_n_frames, self.image_height,
                                 self.image_width, self.image_depth))

        # For each layer, create a mask image and an appearance image
        for layer_idx in range(self.n_layers):
            layer_weight = tensors["layer"][:, :, :, layer_idx, None, None,
                                            None]

            # (batch_size, n_frames, self.image_height, self.image_width, 1)
            layer_mask = tf.reduce_sum(conf * layer_weight * transformed_masks,
                                       axis=2)
            layer_mask = tf.minimum(1.0, layer_mask)

            # (batch_size, n_frames, self.image_height, self.image_width, 3)
            layer_appearance = tf.reduce_sum(conf * layer_weight *
                                             transformed_masks *
                                             transformed_appearances,
                                             axis=2)

            if self.clamp_appearance:
                layer_appearance = tf.minimum(1.0, layer_appearance)

            final_frames = (1 - layer_mask) * final_frames + layer_appearance

            layer_masks.append(layer_mask)
            layer_appearances.append(layer_appearance)

        self._tensors["output"] = final_frames

        # --- losses ---

        self._tensors['per_pixel_reconstruction_loss'] = (self.inp -
                                                          final_frames)**2

        self.losses['reconstruction'] = (
            tf.reduce_sum(self._tensors["per_pixel_reconstruction_loss"]) /
            tf.cast(d_n_frames * self.batch_size, tf.float32))
        # self.losses['reconstruction'] = tf.reduce_mean(self._tensors["per_pixel_reconstruction_loss"])

        self.losses['area'] = self.lmbda * tf.reduce_mean(
            ys_normed * xs_normed)
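

# A minimal NumPy sketch (hypothetical) of the per-layer compositing rule used in
# the loop above: layers are combined back-to-front over a black background, each
# layer's mask occludes what is already in the frame and its mask-weighted
# appearance is added on top.
import numpy as np

frame = np.zeros((2, 2, 3))                                   # black background
layers = [
    (np.full((2, 2, 1), 0.5), np.full((2, 2, 3), 0.2)),       # (layer_mask, mask-weighted appearance)
    (np.full((2, 2, 1), 1.0), np.full((2, 2, 3), 0.7)),
]
for layer_mask, layer_appearance in layers:
    frame = (1 - layer_mask) * frame + layer_appearance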
Example #11
    def _body(self, inp, features, objects, is_posterior):
        """
        Summary of how updates are done for the different variables:

        glimpse_prime: glimpse_prime_params = where + 0.1 * predicted_logit

        where_y/x: new_where_y/x = where_y/x + where_t_scale * tanh(predicted_logit)
        where_h/w: new_where_h/w_logit = where_h/w_logit + predicted_logit
        what: Hard to summarize here, taken from SQAIR. Kind of like an LSTM.
        depth: new_depth_logit = depth_logit + predicted_logit
        obj: new_obj = obj * sigmoid(predicted_logit)

        """
        batch_size, n_objects, _ = tf_shape(features)

        new_objects = AttrDict()

        is_posterior_tf = tf.ones_like(features[..., 0:2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        base_features = tf.concat([features, is_posterior_tf], axis=-1)

        cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

        # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
        glimpse_dim = self.object_shape[0] * self.object_shape[1]
        glimpse_prime_params = apply_object_wise(self.glimpse_prime_network,
                                                 base_features,
                                                 output_size=4 +
                                                 2 * glimpse_dim,
                                                 is_training=self.is_training)

        glimpse_prime_params, glimpse_prime_mask_logit, glimpse_mask_logit = \
            tf.split(glimpse_prime_params, [4, glimpse_dim, glimpse_dim], axis=-1)

        if is_posterior:
            # --- obtain final parameters for glimpse prime by modifying current pose ---

            _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

            g_yt = cyt + 0.1 * _yt
            g_xt = cxt + 0.1 * _xt
            g_ys = ys + 0.1 * _ys
            g_xs = xs + 0.1 * _xs

            # --- extract glimpse prime ---

            _, image_height, image_width, _ = tf_shape(inp)
            g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
                g_yt,
                g_xt,
                g_ys,
                g_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)
            glimpse_prime = extract_affine_glimpse(inp, self.object_shape,
                                                   g_yt, g_xt, g_ys, g_xs,
                                                   self.edge_resampler)
        else:
            g_yt = tf.zeros_like(cyt)
            g_xt = tf.zeros_like(cxt)
            g_ys = tf.zeros_like(ys)
            g_xs = tf.zeros_like(xs)
            glimpse_prime = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        glimpse_prime_mask = tf.nn.sigmoid(glimpse_prime_mask_logit + 1.)
        leading_mask_shape = tf_shape(glimpse_prime)[:-1]
        glimpse_prime_mask = tf.reshape(glimpse_prime_mask,
                                        (*leading_mask_shape, 1))

        new_objects.update(
            glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1),
            glimpse_prime=glimpse_prime,
            glimpse_prime_mask=glimpse_prime_mask,
        )

        glimpse_prime *= glimpse_prime_mask

        # --- encode glimpse ---

        encoded_glimpse_prime = apply_object_wise(self.glimpse_prime_encoder,
                                                  glimpse_prime,
                                                  n_trailing_dims=3,
                                                  output_size=self.A,
                                                  is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A),
                                             dtype=tf.float32)

        # --- position and scale ---

        # roughly:
        # base_features == temporal_state, encoded_glimpse_prime == hidden_output
        # hidden_output conditions on encoded_glimpse, and that's the only place encoded_glimpse_prime is used.

        # Here SQAIR conditions on the actual location values from the previous timestep, but we leave that out for now.
        d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
        d_box_params = apply_object_wise(self.d_box_network,
                                         d_box_inp,
                                         output_size=8,
                                         is_training=self.is_training)

        d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

        d_box_std = self.std_nonlinearity(d_box_log_std)

        d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (
            1 - self.training_wheels) * d_box_mean
        d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (
            1 - self.training_wheels) * d_box_std

        d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
        d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

        # --- position ---

        # We predict position a bit differently from scale. For scale we want to put a prior on the actual value of
        # the scale, whereas for position we want to put a prior on the difference in position over timesteps.

        d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
        d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

        d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
        d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

        new_cyt = cyt + d_yt
        new_cxt = cxt + d_xt

        new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

        # --- scale ---

        new_ys_mean = objects.ys_logit + d_ys
        new_xs_mean = objects.xs_logit + d_xs

        new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
        new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

        new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
        new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

        if self.use_abs_posn:
            box_params = tf.concat([
                new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit,
                new_xs_logit
            ],
                                   axis=-1)
        else:
            box_params = tf.concat(
                [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

        new_objects.update(
            abs_posn=new_abs_posn,
            yt=new_cyt,
            xt=new_cxt,
            ys=new_ys,
            xs=new_xs,
            normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs],
                                     axis=-1),
            d_yt_logit=d_yt_logit,
            d_xt_logit=d_xt_logit,
            ys_logit=new_ys_logit,
            xs_logit=new_xs_logit,
            d_yt_logit_mean=d_yt_mean,
            d_xt_logit_mean=d_xt_mean,
            ys_logit_mean=new_ys_mean,
            xs_logit_mean=new_xs_mean,
            d_yt_logit_std=d_yt_std,
            d_xt_logit_std=d_xt_std,
            ys_logit_std=ys_std,
            xs_logit_std=xs_std,
        )

        # --- attributes ---

        # --- extract a glimpse using new box ---

        if is_posterior:
            _, image_height, image_width, _ = tf_shape(inp)
            _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
                new_cyt,
                new_cxt,
                new_ys,
                new_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)

            glimpse = extract_affine_glimpse(inp, self.object_shape, _new_cyt,
                                             _new_cxt, _new_ys, _new_xs,
                                             self.edge_resampler)

        else:
            glimpse = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        glimpse_mask = tf.nn.sigmoid(glimpse_mask_logit + 1.)
        leading_mask_shape = tf_shape(glimpse)[:-1]
        glimpse_mask = tf.reshape(glimpse_mask, (*leading_mask_shape, 1))

        glimpse *= glimpse_mask

        encoded_glimpse = apply_object_wise(self.glimpse_encoder,
                                            glimpse,
                                            n_trailing_dims=3,
                                            output_size=self.A,
                                            is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros((batch_size, n_objects, self.A),
                                       dtype=tf.float32)

        # --- predict change in attributes ---

        # Under SQAIR we mix between three different values for the attributes:
        # 1. value from previous timestep
        # 2. value predicted directly from glimpse
        # 3. value predicted based on update of temporal cell...this update conditions on hidden_output,
        # the prediction in #2., and the where values.

        # How to do this given that we are predicting the change in attr? We could just directly predict
        # the attr instead, but call it d_attr. After all, it is in this function that we control
        # whether d_attr is added to attr.

        # So, make a prediction based on just the input:

        attr_from_inp = apply_object_wise(self.predict_attr_inp,
                                          encoded_glimpse,
                                          output_size=2 * self.A,
                                          is_training=self.is_training)
        attr_from_inp_mean, attr_from_inp_log_std = tf.split(attr_from_inp,
                                                             [self.A, self.A],
                                                             axis=-1)

        attr_from_inp_std = self.std_nonlinearity(attr_from_inp_log_std)

        # And then a prediction which takes the past into account (predicting gate values at the same time):

        attr_from_temp_inp = tf.concat(
            [base_features, box_params, encoded_glimpse], axis=-1)
        attr_from_temp = apply_object_wise(self.predict_attr_temp,
                                           attr_from_temp_inp,
                                           output_size=5 * self.A,
                                           is_training=self.is_training)

        (attr_from_temp_mean, attr_from_temp_log_std, f_gate_logit,
         i_gate_logit, t_gate_logit) = tf.split(attr_from_temp, 5, axis=-1)

        attr_from_temp_std = self.std_nonlinearity(attr_from_temp_log_std)

        # bias the gates
        f_gate = tf.nn.sigmoid(f_gate_logit + 1) * .9999
        i_gate = tf.nn.sigmoid(i_gate_logit + 1) * .9999
        t_gate = tf.nn.sigmoid(t_gate_logit + 1) * .9999

        attr_mean = f_gate * objects.attr + (
            1 - i_gate) * attr_from_inp_mean + (1 -
                                                t_gate) * attr_from_temp_mean
        attr_std = (1 - i_gate) * attr_from_inp_std + (
            1 - t_gate) * attr_from_temp_std

        new_attr = Normal(loc=attr_mean, scale=attr_std).sample()

        # --- apply change in attributes ---

        new_objects.update(
            attr=new_attr,
            d_attr=new_attr - objects.attr,
            d_attr_mean=attr_mean - objects.attr,
            d_attr_std=attr_std,
            f_gate=f_gate,
            i_gate=i_gate,
            t_gate=t_gate,
            glimpse=glimpse,
            glimpse_mask=glimpse_mask,
        )

        # --- z ---

        d_z_inp = tf.concat(
            [base_features, box_params, new_attr, encoded_glimpse], axis=-1)
        d_z_params = apply_object_wise(self.d_z_network,
                                       d_z_inp,
                                       output_size=2,
                                       is_training=self.is_training)

        d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
        d_z_std = self.std_nonlinearity(d_z_log_std)

        d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (
            1 - self.training_wheels) * d_z_mean
        d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (
            1 - self.training_wheels) * d_z_std

        d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

        new_z_logit = objects.z_logit + d_z_logit
        new_z = self.z_nonlinearity(new_z_logit)

        new_objects.update(
            z=new_z,
            z_logit=new_z_logit,
            d_z_logit=d_z_logit,
            d_z_logit_mean=d_z_mean,
            d_z_logit_std=d_z_std,
        )

        # --- obj ---

        d_obj_inp = tf.concat(
            [base_features, box_params, new_attr, new_z, encoded_glimpse],
            axis=-1)
        d_obj_logit = apply_object_wise(self.d_obj_network,
                                        d_obj_inp,
                                        output_size=1,
                                        is_training=self.is_training)

        d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (
            1 - self.training_wheels) * d_obj_logit
        d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10.,
                                          10.)

        d_obj_pre_sigmoid = (self._noisy * concrete_binary_pre_sigmoid_sample(
            d_obj_log_odds, self.obj_concrete_temp) +
                             (1 - self._noisy) * d_obj_log_odds)

        d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

        new_obj = objects.obj * d_obj

        new_objects.update(
            d_obj_log_odds=d_obj_log_odds,
            d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
            d_obj_pre_sigmoid=d_obj_pre_sigmoid,
            d_obj=d_obj,
            obj=new_obj,
        )

        # --- update each object's hidden state --

        cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

        if is_posterior:
            _, new_objects.prop_state = apply_object_wise(
                self.cell, cell_input, objects.prop_state)
            new_objects.prior_prop_state = new_objects.prop_state
        else:
            _, new_objects.prior_prop_state = apply_object_wise(
                self.cell, cell_input, objects.prior_prop_state)
            new_objects.prop_state = new_objects.prior_prop_state

        return new_objects
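

# A small plain-Python sketch (hypothetical) of the gated attribute update above:
# the previous attribute is carried through the forget gate, while the prediction
# from the glimpse and the prediction from the temporal state are mixed in via
# their (inverted) input and temporal gates.
def mix_attr_sketch(prev_attr, from_inp, from_temp, f_gate, i_gate, t_gate):
    return f_gate * prev_attr + (1 - i_gate) * from_inp + (1 - t_gate) * from_temp


# With all gates near 1 the previous attribute dominates and the new predictions
# are mostly ignored:
print(mix_attr_sketch(1.0, 5.0, -5.0, f_gate=0.99, i_gate=0.99, t_gate=0.99))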
Example #12
    def _call(self, inp, features, objects, is_training, is_posterior):
        print("\n" + "-" * 10 +
              " PropagationLayer(is_posterior={}) ".format(is_posterior) +
              "-" * 10)

        self._build_networks()

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.image_height = int(inp.shape[-3])
            self.image_width = int(inp.shape[-2])
            self.image_depth = int(inp.shape[-1])
            self.batch_size = tf.shape(inp)[0]
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        if self.do_lateral:
            # hasn't been updated to make use of abs_posn
            raise NotImplementedError()

            batch_size, n_objects, _ = tf_shape(features)

            new_objects = []

            for i in range(n_objects):
                # apply lateral to running objects with the feature vector for
                # the current object

                _features = features[:, i:i + 1, :]

                if i > 0:
                    normalized_box = tf.concat(
                        [o.normalized_box for o in new_objects], axis=1)
                    attr = tf.concat([o.attr for o in new_objects], axis=1)
                    z = tf.concat([o.z for o in new_objects], axis=1)
                    obj = tf.concat([o.obj for o in new_objects], axis=1)
                    completed_features = tf.concat(
                        [normalized_box[:, :, 2:], attr, z, obj], axis=2)
                    completed_locs = normalized_box[:, :, :2]

                    current_features = tf.concat([
                        objects.normalized_box[:, i:i + 1,
                                               2:], objects.attr[:, i:i + 1],
                        objects.z[:, i:i + 1], objects.obj[:, i:i + 1]
                    ],
                                                 axis=2)
                    current_locs = objects.normalized_box[:, i:i + 1, :2]

                    # if i > max_completed_objects:
                    #     # top_k_indices
                    #     # squared_distances = tf.reduce_sum((completed_locs - current_locs)**2, axis=2)
                    #     # _, top_k_indices = tf.nn.top_k(squared_distances, k=max_completed_objects, sorted=False)

                    _features = self.lateral_network(completed_locs,
                                                     completed_features,
                                                     current_locs,
                                                     current_features,
                                                     is_training)

                _objects = AttrDict(
                    normalized_box=objects.normalized_box[:, i:i + 1],
                    attr=objects.attr[:, i:i + 1],
                    z=objects.z[:, i:i + 1],
                    obj=objects.obj[:, i:i + 1],
                )

                _new_objects = self._body(inp, _features, _objects,
                                          is_posterior)
                new_objects.append(_new_objects)

            _new_objects = AttrDict()
            for k in new_objects[0]:
                _new_objects[k] = tf.concat([no[k] for no in new_objects],
                                            axis=1)
            return _new_objects

        else:
            return self._body(inp, features, objects, is_posterior)
Example #13
    def _body(self, inp, features, objects, is_posterior):
        batch_size, n_objects, _ = tf_shape(features)

        new_objects = AttrDict()

        is_posterior_tf = tf.ones_like(features[..., 0:2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        base_features = tf.concat([features, is_posterior_tf], axis=-1)

        cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

        if self.learn_glimpse_prime:
            # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
            glimpse_prime_params = apply_object_wise(
                self.glimpse_prime_network,
                base_features,
                output_size=4,
                is_training=self.is_training)
        else:
            glimpse_prime_params = tf.zeros_like(base_features[..., :4])

        if is_posterior:

            if self.learn_glimpse_prime:
                # --- obtain final parameters for glimpse prime by modifying current pose ---
                _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

                # This is how it is done in SQAIR
                g_yt = cyt + 0.1 * _yt
                g_xt = cxt + 0.1 * _xt
                g_ys = ys + 0.1 * _ys
                g_xs = xs + 0.1 * _xs

                # g_yt = cyt + self.glimpse_prime_scale * tf.nn.tanh(_yt)
                # g_xt = cxt + self.glimpse_prime_scale * tf.nn.tanh(_xt)
                # g_ys = ys + self.glimpse_prime_scale * tf.nn.tanh(_ys)
                # g_xs = xs + self.glimpse_prime_scale * tf.nn.tanh(_xs)
            else:
                g_yt = cyt
                g_xt = cxt
                g_ys = self.glimpse_prime_scale * ys
                g_xs = self.glimpse_prime_scale * xs

            # --- extract glimpse prime ---

            _, image_height, image_width, _ = tf_shape(inp)
            g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
                g_yt,
                g_xt,
                g_ys,
                g_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)
            glimpse_prime = extract_affine_glimpse(inp, self.object_shape,
                                                   g_yt, g_xt, g_ys, g_xs,
                                                   self.edge_resampler)
        else:
            g_yt = tf.zeros_like(cyt)
            g_xt = tf.zeros_like(cxt)
            g_ys = tf.zeros_like(ys)
            g_xs = tf.zeros_like(xs)
            glimpse_prime = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        new_objects.update(glimpse_prime_box=tf.concat(
            [g_yt, g_xt, g_ys, g_xs], axis=-1), )

        # --- encode glimpse ---

        encoded_glimpse_prime = apply_object_wise(self.glimpse_prime_encoder,
                                                  glimpse_prime,
                                                  n_trailing_dims=3,
                                                  output_size=self.A,
                                                  is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A),
                                             dtype=tf.float32)

        # --- position and scale ---

        d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
        d_box_params = apply_object_wise(self.d_box_network,
                                         d_box_inp,
                                         output_size=8,
                                         is_training=self.is_training)

        d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

        d_box_std = self.std_nonlinearity(d_box_log_std)

        d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (
            1 - self.training_wheels) * d_box_mean
        d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (
            1 - self.training_wheels) * d_box_std

        d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
        d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

        # --- position ---

        # We predict position a bit differently from scale. For scale we want to put a prior on the actual value of
        # the scale, whereas for position we want to put a prior on the difference in position over timesteps.

        d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
        d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

        d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
        d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

        new_cyt = cyt + d_yt
        new_cxt = cxt + d_xt

        new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

        # --- scale ---

        new_ys_mean = objects.ys_logit + d_ys
        new_xs_mean = objects.xs_logit + d_xs

        new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
        new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

        new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
        new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

        # Used for conditioning
        if self.use_abs_posn:
            box_params = tf.concat([
                new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit,
                new_xs_logit
            ],
                                   axis=-1)
        else:
            box_params = tf.concat(
                [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

        new_objects.update(
            abs_posn=new_abs_posn,
            yt=new_cyt,
            xt=new_cxt,
            ys=new_ys,
            xs=new_xs,
            normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs],
                                     axis=-1),
            d_yt_logit=d_yt_logit,
            d_xt_logit=d_xt_logit,
            ys_logit=new_ys_logit,
            xs_logit=new_xs_logit,
            d_yt_logit_mean=d_yt_mean,
            d_xt_logit_mean=d_xt_mean,
            ys_logit_mean=new_ys_mean,
            xs_logit_mean=new_xs_mean,
            d_yt_logit_std=d_yt_std,
            d_xt_logit_std=d_xt_std,
            ys_logit_std=ys_std,
            xs_logit_std=xs_std,
            glimpse_prime=glimpse_prime,
        )

        # --- attributes ---

        # --- extract a glimpse using new box ---

        if is_posterior:
            _, image_height, image_width, _ = tf_shape(inp)
            _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
                new_cyt,
                new_cxt,
                new_ys,
                new_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)

            glimpse = extract_affine_glimpse(inp, self.object_shape, _new_cyt,
                                             _new_cxt, _new_ys, _new_xs,
                                             self.edge_resampler)

        else:
            glimpse = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        encoded_glimpse = apply_object_wise(self.glimpse_encoder,
                                            glimpse,
                                            n_trailing_dims=3,
                                            output_size=self.A,
                                            is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros((batch_size, n_objects, self.A),
                                       dtype=tf.float32)

        # --- predict change in attributes ---

        d_attr_inp = tf.concat([base_features, box_params, encoded_glimpse],
                               axis=-1)
        d_attr_params = apply_object_wise(self.d_attr_network,
                                          d_attr_inp,
                                          output_size=2 * self.A + 1,
                                          is_training=self.is_training)

        d_attr_mean, d_attr_log_std, gate_logit = tf.split(d_attr_params,
                                                           [self.A, self.A, 1],
                                                           axis=-1)
        d_attr_std = self.std_nonlinearity(d_attr_log_std)

        gate = tf.nn.sigmoid(gate_logit)

        if self.gate_d_attr:
            d_attr_mean *= gate

        d_attr = Normal(loc=d_attr_mean, scale=d_attr_std).sample()

        # --- apply change in attributes ---

        new_attr = objects.attr + d_attr

        new_objects.update(
            attr=new_attr,
            d_attr=d_attr,
            d_attr_mean=d_attr_mean,
            d_attr_std=d_attr_std,
            glimpse=glimpse,
            d_attr_gate=gate,
        )

        # --- z ---

        d_z_inp = tf.concat(
            [base_features, box_params, new_attr, encoded_glimpse], axis=-1)
        d_z_params = apply_object_wise(self.d_z_network,
                                       d_z_inp,
                                       output_size=2,
                                       is_training=self.is_training)

        d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
        d_z_std = self.std_nonlinearity(d_z_log_std)

        d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (
            1 - self.training_wheels) * d_z_mean
        d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (
            1 - self.training_wheels) * d_z_std

        d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

        new_z_logit = objects.z_logit + d_z_logit
        new_z = self.z_nonlinearity(new_z_logit)

        new_objects.update(
            z=new_z,
            z_logit=new_z_logit,
            d_z_logit=d_z_logit,
            d_z_logit_mean=d_z_mean,
            d_z_logit_std=d_z_std,
        )

        # --- obj ---

        d_obj_inp = tf.concat(
            [base_features, box_params, new_attr, new_z, encoded_glimpse],
            axis=-1)
        d_obj_logit = apply_object_wise(self.d_obj_network,
                                        d_obj_inp,
                                        output_size=1,
                                        is_training=self.is_training)

        d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (
            1 - self.training_wheels) * d_obj_logit
        d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10.,
                                          10.)

        d_obj_pre_sigmoid = (self._noisy * concrete_binary_pre_sigmoid_sample(
            d_obj_log_odds, self.obj_concrete_temp) +
                             (1 - self._noisy) * d_obj_log_odds)

        d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

        new_obj = objects.obj * d_obj

        new_objects.update(
            d_obj_log_odds=d_obj_log_odds,
            d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
            d_obj_pre_sigmoid=d_obj_pre_sigmoid,
            d_obj=d_obj,
            obj=new_obj,
        )

        # --- update each object's hidden state --

        cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

        if is_posterior:
            _, new_objects.prop_state = apply_object_wise(
                self.cell, cell_input, objects.prop_state)
            new_objects.prior_prop_state = new_objects.prop_state
        else:
            _, new_objects.prior_prop_state = apply_object_wise(
                self.cell, cell_input, objects.prior_prop_state)
            new_objects.prop_state = new_objects.prior_prop_state

        return new_objects
Example #14
    def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
        print("\n" + "-" * 10 + " ConvGridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network", builder=cfg.build_conv_lateral, key="box")
        self.maybe_build_subnet("attr_network", builder=cfg.build_conv_lateral, key="attr")
        self.maybe_build_subnet("z_network", builder=cfg.build_conv_lateral, key="z")
        self.maybe_build_subnet("obj_network", builder=cfg.build_conv_lateral, key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _, n_channels = tf_shape(inp_features)

        if self.B != 1:
            raise NotImplementedError()

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
            self.H = H
            self.W = W
            self.HWB = H*W
            self.batch_size = tf.shape(inp)[0]
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        inp_features = tf.reshape(inp_features, (self.batch_size, H, W, n_channels))
        is_posterior_tf = tf.ones_like(inp_features[..., :2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        objects = AttrDict()

        base_features = tf.concat([inp_features, is_posterior_tf], axis=-1)

        # --- box ---

        layer_inp = base_features
        n_features = self.n_passthrough_features
        output_size = 8

        network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
        rep_input, features = tf.split(network_output, (output_size, n_features), axis=-1)

        _objects = self._build_box(rep_input, self.is_training)
        objects.update(_objects)

        # --- attr ---

        if is_posterior:
            # --- Get object attributes using object encoder ---

            yt, xt, ys, xs = tf.split(objects['normalized_box'], 4, axis=-1)

            yt, xt, ys, xs = coords_to_image_space(
                yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
            warper = snt.AffineGridWarper(
                (self.image_height, self.image_width), self.object_shape, transform_constraints)

            _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)
            _boxes = tf.reshape(_boxes, (self.batch_size*H*W, 4))
            grid_coords = warper(_boxes)
            grid_coords = tf.reshape(grid_coords, (self.batch_size, H, W, *self.object_shape, 2,))

            if self.edge_resampler:
                glimpse = resampler_edge.resampler_edge(inp, grid_coords)
            else:
                glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
        else:
            glimpse = tf.zeros((self.batch_size, H, W, *self.object_shape, self.image_depth))

        # Build the object encoder network regardless of is_posterior; otherwise ScopedFunction gets messed up
        encoded_glimpse = apply_object_wise(
            self.object_encoder, glimpse, n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros_like(encoded_glimpse)

        layer_inp = tf.concat([base_features, features, encoded_glimpse, objects['local_box']], axis=-1)
        network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
        attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=-1)

        attr_std = self.std_nonlinearity(attr_log_std)

        attr = Normal(loc=attr_mean, scale=attr_std).sample()

        objects.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)

        # --- z ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr']], axis=-1)
        n_features = self.n_passthrough_features

        network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
        z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=-1)
        z_std = self.std_nonlinearity(z_log_std)

        z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
        z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
        z_logit = Normal(loc=z_mean, scale=z_std).sample()
        z = self.z_nonlinearity(z_logit)

        objects.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)

        # --- obj ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr'], objects['z']], axis=-1)
        rep_input = self.obj_network(layer_inp, 1, self.is_training)

        _objects = self._build_obj(rep_input, self.is_training)
        objects.update(_objects)

        # --- final ---

        _objects = AttrDict()
        for k, v in objects.items():
            _, _, _, *trailing_dims = tf_shape(v)
            _objects[k] = tf.reshape(v, (self.batch_size, self.HWB, *trailing_dims))
        objects = _objects

        if prop_state is not None:
            objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
            objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

        # --- misc ---

        objects.n_objects = tf.fill((self.batch_size,), self.HWB)
        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

        return objects
Example #15
    def _call(self, data, is_training):
        self.data = data

        inp = data["image"]
        self._tensors = AttrDict(
            inp=inp,
            is_training=is_training,
            float_is_training=tf.to_float(is_training),
            batch_size=tf.shape(inp)[0],
        )

        if "annotations" in data:
            self._tensors.update(
                annotations=data["annotations"]["data"],
                n_annotations=data["annotations"]["shapes"][:, 1],
                n_valid_annotations=tf.to_int32(
                    tf.reduce_sum(
                        data["annotations"]["data"][:, :, :, 0]
                        * tf.to_float(data["annotations"]["mask"][:, :, :, 0]),
                        axis=2
                    )
                )
            )

        if "label" in data:
            self._tensors.update(
                targets=data["label"],
            )

        if "background" in data:
            self._tensors.update(
                ground_truth_background=data["background"],
            )

        if "offset" in data:
            self._tensors.update(
                offset=data["offset"],
            )

        max_n_frames = tf_shape(inp)[1]

        if self.stage_steps is None:
            self.current_stage = tf.constant(0, tf.int32)
            dynamic_n_frames = max_n_frames
        else:
            self.current_stage = tf.cast(tf.train.get_or_create_global_step(), tf.int32) // self.stage_steps
            dynamic_n_frames = tf.minimum(
                self.initial_n_frames + self.n_frames_scale * self.current_stage, max_n_frames)

        dynamic_n_frames = (
            self.float_is_training * tf.cast(dynamic_n_frames, tf.float32)
            + (1 - self.float_is_training) * tf.cast(max_n_frames, tf.float32)
        )
        self.dynamic_n_frames = tf.cast(dynamic_n_frames, tf.int32)

        self._tensors.current_stage = self.current_stage
        self._tensors.dynamic_n_frames = self.dynamic_n_frames

        self._tensors.inp = self._tensors.inp[:, :self.dynamic_n_frames]

        if 'annotations' in self._tensors:
            self._tensors.annotations = self._tensors.annotations[:, :self.dynamic_n_frames]
            # self._tensors.n_annotations = self._tensors.n_annotations[:, :self.dynamic_n_frames]
            self._tensors.n_valid_annotations = self._tensors.n_valid_annotations[:, :self.dynamic_n_frames]

        self.record_tensors(
            batch_size=tf.to_float(self.batch_size),
            float_is_training=self.float_is_training,
            current_stage=self.current_stage,
            dynamic_n_frames=self.dynamic_n_frames,
        )

        self.losses = dict()

        with tf.variable_scope("representation", reuse=self.initialized):
            if self.needs_background:
                self.build_background()

            self.build_representation()

        return dict(
            tensors=self._tensors,
            recorded_tensors=self.recorded_tensors,
            losses=self.losses,
        )
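# A minimal, dependency-free sketch of the frame-count curriculum implemented above,
# assuming the same stage_steps / initial_n_frames / n_frames_scale attributes; the
# default values below are illustrative only, not taken from any config.
def frames_for_step(global_step, max_n_frames, stage_steps=1000,
                    initial_n_frames=2, n_frames_scale=2, is_training=True):
    # At evaluation time (or when staging is disabled) the full sequence is used.
    if not is_training or stage_steps is None:
        return max_n_frames
    current_stage = global_step // stage_steps
    return min(initial_n_frames + n_frames_scale * current_stage, max_n_frames)

# e.g. with the defaults above: steps 0-999 train on 2 frames, 1000-1999 on 4,
# and so on until max_n_frames is reached.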
示例#16
0
    def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
        print("\n" + "-" * 10 + " GridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network", builder=cfg.build_lateral, key="box")
        self.maybe_build_subnet("attr_network", builder=cfg.build_lateral, key="attr")
        self.maybe_build_subnet("z_network", builder=cfg.build_lateral, key="z")
        self.maybe_build_subnet("obj_network", builder=cfg.build_lateral, key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _, _ = tf_shape(inp_features)
        H = int(H)
        W = int(W)

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
            self.H = H
            self.W = W
            self.HWB = H*W*self.B
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        # --- set up the edge element ---

        sizes = [4, self.A, 1, 1]
        sigmoids = [True, False, False, True]
        total_sample_size = sum(sizes)

        if self.edge_weights is None:
            self.edge_weights = tf.get_variable("edge_weights", shape=total_sample_size, dtype=tf.float32)
            if "edge" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.edge_weights)

        _edge_weights = tf.split(self.edge_weights, sizes, axis=0)
        _edge_weights = [
            (tf.nn.sigmoid(ew) if sigmoid else ew)
            for ew, sigmoid in zip(_edge_weights, sigmoids)]
        edge_element = tf.concat(_edge_weights, axis=0)
        edge_element = tf.tile(edge_element[None, :], (self.batch_size, 1))

        # --- containers for storing built program ---

        program = np.empty((H, W, self.B), dtype=object)

        # --- build the program ---

        is_posterior_tf = tf.ones((self.batch_size, 2))
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        results = []
        for h, w, b in itertools.product(range(H), range(W), range(self.B)):
            built = dict()

            partial_program, features = None, None
            context = self._get_sequential_context(program, h, w, b, edge_element)
            base_features = tf.concat([inp_features[:, h, w, b, :], context, is_posterior_tf], axis=1)

            # --- box ---

            layer_inp = base_features
            n_features = self.n_passthrough_features
            output_size = 8

            network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
            rep_input, features = tf.split(network_output, (output_size, n_features), axis=1)

            _built = self._build_box(rep_input, self.is_training, hw=(h, w))
            built.update(_built)
            partial_program = built['local_box']

            # --- attr ---

            if is_posterior:
                # --- Get object attributes using object encoder ---

                yt, xt, ys, xs = tf.split(built['normalized_box'], 4, axis=-1)

                yt, xt, ys, xs = coords_to_image_space(
                    yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

                transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
                warper = snt.AffineGridWarper(
                    (self.image_height, self.image_width), self.object_shape, transform_constraints)

                _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)

                grid_coords = warper(_boxes)
                grid_coords = tf.reshape(grid_coords, (self.batch_size, 1, *self.object_shape, 2,))
                if self.edge_resampler:
                    glimpse = resampler_edge.resampler_edge(inp, grid_coords)
                else:
                    glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
                glimpse = tf.reshape(glimpse, (self.batch_size, *self.object_shape, self.image_depth))
            else:
                glimpse = tf.zeros((self.batch_size, *self.object_shape, self.image_depth))

            # Build and apply the object encoder regardless of is_posterior; skipping it would interfere with ScopedFunction's variable creation.
            encoded_glimpse = self.object_encoder(glimpse, (1, 1, self.A), self.is_training)
            encoded_glimpse = tf.reshape(encoded_glimpse, (self.batch_size, self.A))

            if not is_posterior:
                encoded_glimpse = tf.zeros_like(encoded_glimpse)

            layer_inp = tf.concat(
                [base_features, features, encoded_glimpse, partial_program], axis=1)
            network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
            attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=1)

            attr_std = self.std_nonlinearity(attr_log_std)

            attr = Normal(loc=attr_mean, scale=attr_std).sample()

            built.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)
            partial_program = tf.concat([partial_program, built['attr']], axis=1)

            # --- z ---

            layer_inp = tf.concat([base_features, features, partial_program], axis=1)
            n_features = self.n_passthrough_features

            network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
            z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=1)
            z_std = self.std_nonlinearity(z_log_std)

            z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
            z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
            z_logit = Normal(loc=z_mean, scale=z_std).sample()
            z = self.z_nonlinearity(z_logit)

            built.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)
            partial_program = tf.concat([partial_program, built['z']], axis=1)

            # --- obj ---

            layer_inp = tf.concat([base_features, features, partial_program], axis=1)
            rep_input = self.obj_network(layer_inp, 1, self.is_training)

            _built = self._build_obj(rep_input, self.is_training)
            built.update(_built)

            partial_program = tf.concat([partial_program, built['obj']], axis=1)

            # --- final ---

            results.append(built)

            program[h, w, b] = partial_program
            assert program[h, w, b].shape[1] == total_sample_size

        objects = AttrDict()
        for k in results[0]:
            objects[k] = tf.stack([r[k] for r in results], axis=1)

        if prop_state is not None:
            objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
            objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

        # --- misc ---

        objects.n_objects = tf.fill((self.batch_size,), self.HWB)
        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

        return objects
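# Hedged illustration of the box re-parameterization used for glimpse extraction above,
# not library code. It assumes (yt, xt) are the glimpse centre in [0, 1] image
# coordinates (after coords_to_image_space with top_left=False) and (ys, xs) are the
# glimpse height/width as fractions of the image.
def box_for_warper(yt, xt, ys, xs):
    # snt.AffineGridWarper appears to expect translations in a [-1, 1] convention,
    # hence the 2*t - 1 shift applied to the centres; the packing order matches the
    # _boxes concatenation above (x-parameters first, then y-parameters).
    return [xs, 2 * xt - 1, ys, 2 * yt - 1]

def box_corners_pixels(yt, xt, ys, xs, image_height, image_width):
    # The same box expressed as top-left / bottom-right pixel corners, for intuition.
    top, left = (yt - ys / 2) * image_height, (xt - xs / 2) * image_width
    bottom, right = (yt + ys / 2) * image_height, (xt + xs / 2) * image_width
    return top, left, bottom, right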
示例#17
0
    def _compute_obj_kl(self, tensors, existing_objects=None):
        # --- compute obj_kl ---

        obj_pre_sigmoid = tensors["obj_pre_sigmoid"]
        obj_log_odds = tensors["obj_log_odds"]
        obj_prob = tensors["obj_prob"]
        obj = tensors["obj"]
        batch_size, n_objects, _ = tf_shape(obj)

        max_n_objects = n_objects

        if existing_objects is not None:
            _, n_existing_objects, _ = tf_shape(existing_objects)
            existing_objects = tf.reshape(existing_objects, (batch_size, n_existing_objects))
            max_n_objects += n_existing_objects

        count_support = tf.range(max_n_objects+1, dtype=tf.float32)

        if self.count_prior_dist is not None:
            assert len(self.count_prior_dist) == (max_n_objects + 1)
            count_distribution = tf.constant(self.count_prior_dist, dtype=tf.float32)
        else:
            count_prior_prob = tf.nn.sigmoid(self.count_prior_log_odds)
            count_distribution = (1 - count_prior_prob) * (count_prior_prob ** count_support)

        normalizer = tf.reduce_sum(count_distribution)
        count_distribution = count_distribution / tf.maximum(normalizer, 1e-6)
        count_distribution = tf.tile(count_distribution[None, :], (batch_size, 1))

        if existing_objects is not None:
            count_so_far = tf.reduce_sum(tf.round(existing_objects), axis=1, keepdims=True)

            count_distribution = (
                count_distribution
                * tf_binomial_coefficient(count_support, count_so_far)
                * tf_binomial_coefficient(max_n_objects - count_support, n_existing_objects - count_so_far)
            )

            normalizer = tf.reduce_sum(count_distribution, axis=1, keepdims=True)
            count_distribution = count_distribution / tf.maximum(normalizer, 1e-6)
        else:
            count_so_far = tf.zeros((batch_size, 1), dtype=tf.float32)

        obj_kl = []
        for i in range(n_objects):
            p_z_given_Cz_raw = (count_support[None, :] - count_so_far) / (max_n_objects - i)
            p_z_given_Cz = tf.clip_by_value(p_z_given_Cz_raw, 0.0, 1.0)

            # Doing this instead of 1 - p_z_given_Cz seems to be more numerically stable.
            inv_p_z_given_Cz_raw = (max_n_objects - i - count_support[None, :] + count_so_far) / (max_n_objects - i)
            inv_p_z_given_Cz = tf.clip_by_value(inv_p_z_given_Cz_raw, 0.0, 1.0)

            p_z = tf.reduce_sum(count_distribution * p_z_given_Cz, axis=1, keepdims=True)

            if self.use_concrete_kl:
                prior_log_odds = tf_safe_log(p_z) - tf_safe_log(1-p_z)
                _obj_kl = concrete_binary_sample_kl(
                    obj_pre_sigmoid[:, i, :],
                    obj_log_odds[:, i, :], self.obj_concrete_temp,
                    prior_log_odds, self.obj_concrete_temp,
                )
            else:
                prob = obj_prob[:, i, :]

                _obj_kl = (
                    prob * (tf_safe_log(prob) - tf_safe_log(p_z))
                    + (1-prob) * (tf_safe_log(1-prob) - tf_safe_log(1-p_z))
                )

            obj_kl.append(_obj_kl)

            sample = tf.to_float(obj[:, i, :] > 0.5)
            mult = sample * p_z_given_Cz + (1-sample) * inv_p_z_given_Cz
            raw_count_distribution = mult * count_distribution
            normalizer = tf.reduce_sum(raw_count_distribution, axis=1, keepdims=True)
            normalizer = tf.maximum(normalizer, 1e-6)

            # invalid = tf.logical_and(p_z_given_Cz_raw > 1, count_distribution > 1e-8)
            # float_invalid = tf.cast(invalid, tf.float32)
            # diagnostic = tf.stack(
            #     [float_invalid, p_z_given_Cz, count_distribution, mult, raw_count_distribution], axis=-1)

            # assert_op = tf.Assert(
            #     tf.reduce_all(tf.logical_not(invalid)),
            #     [invalid, diagnostic, count_so_far, sample, tf.constant(i, dtype=tf.float32)],
            #     summarize=100000)

            count_distribution = raw_count_distribution / normalizer
            count_so_far += sample

            # this avoids buildup of inaccuracies that can cause problems in computing p_z_given_Cz_raw
            count_so_far = tf.round(count_so_far)

        obj_kl = tf.reshape(tf.concat(obj_kl, axis=1), (batch_size, n_objects, 1))

        return obj_kl
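# Hedged NumPy sketch of the count-based prior machinery used in _compute_obj_kl above:
# (a) a geometric prior over the number of objects, truncated to {0, ..., max_n_objects}
# and renormalized, and (b) the per-object prior p(z_i = 1) obtained by marginalizing
# over the current count distribution. Function names here are illustrative.
import numpy as np

def truncated_geometric_prior(count_prior_log_odds, max_n_objects):
    p = 1.0 / (1.0 + np.exp(-count_prior_log_odds))
    support = np.arange(max_n_objects + 1, dtype=np.float64)
    dist = (1 - p) * p ** support
    return dist / max(dist.sum(), 1e-6)

def prior_presence_prob(count_distribution, count_so_far, i, max_n_objects):
    # Given a total count C, the chance that the i-th of the remaining
    # (max_n_objects - i) slots is "on" is (C - count_so_far) / (max_n_objects - i),
    # clipped to [0, 1]; marginalizing over C gives the prior used for object i.
    support = np.arange(len(count_distribution), dtype=np.float64)
    p_z_given_C = np.clip((support - count_so_far) / (max_n_objects - i), 0.0, 1.0)
    return float(np.sum(count_distribution * p_z_given_C))

# After each object's presence is sampled, the count distribution is re-weighted by
# p_z_given_C (or its complement) and renormalized, exactly as in the loop above.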
示例#18
0
    def build_representation(self):
        processed_image = index.tile_input_for_iwae(
            tf.transpose(self.inp, (1, 0, 2, 3, 4)), self.k_particles, with_time=True)
        shape = list(tf_shape(processed_image))
        shape[1] = cfg.batch_size * self.k_particles
        processed_image = tf.reshape(processed_image, shape)

        self._tensors.update(
            processed_image=processed_image,
            mean_img=self.data['mean_img'],
        )

        _, _, *img_size = processed_image.shape.as_list()

        layers = [self.n_hidden] * self.n_layers

        def glimpse_encoder():
            return AIREncoder(img_size, self.object_shape, self.n_what, Encoder(layers),
                              masked_glimpse=self.masked_glimpse, debug=self.debug)

        steps_pred_hidden = self.n_hidden // 2

        transform_estimator = partial(StochasticTransformParam, layers, self.transform_var_bias)

        if self.fixed_presence:
            disc_steps_predictor = partial(FixedStepsPredictor, discovery=True)
        else:
            disc_steps_predictor = partial(StepsPredictor, steps_pred_hidden, self.disc_step_bias)

        if cfg.build_input_encoder is None:
            input_encoder = partial(Encoder, layers)
        else:
            input_encoder = cfg.build_input_encoder

        _input_encoder = input_encoder()

        T, B, *rest = tf_shape(processed_image)
        images = tf.reshape(processed_image, (T*B, *rest))
        encoded_input = _input_encoder(images)
        encoded_input = tf.reshape(encoded_input, (T, B, *tf_shape(encoded_input)[1:]))

        with tf.variable_scope('discovery'):

            discover_cell = DiscoveryCore(
                processed_image, encoded_input, self.object_shape, self.n_what, self.rnn_class(self.n_hidden),
                glimpse_encoder, transform_estimator, disc_steps_predictor, debug=self.debug)

            if self.fast_discovery:
                object_state_predictor = MLP([256, 256, self.n_hidden])
                discover = FastDiscover(
                    object_state_predictor, self.n_objects, discover_cell,
                    step_success_prob=self.step_success_prob, where_mean=[*self.scale_prior, 0, 0],
                    disc_prior_type=self.disc_prior_type, rec_where_prior=self.rec_where_prior)

            else:
                discover = Discover(
                    self.n_objects, discover_cell, step_success_prob=self.step_success_prob,
                    where_mean=[*self.scale_prior, 0, 0], disc_prior_type=self.disc_prior_type,
                    rec_where_prior=self.rec_where_prior)

        with tf.variable_scope('propagation'):
            # Prop cell should have a different rnn cell but should share all other estimators
            glimpse_encoder = lambda: discover_cell._glimpse_encoder
            transform_estimator = partial(StochasticTransformParam, layers, self.transform_var_bias)

            if self.fixed_presence:
                prop_steps_predictor = partial(FixedStepsPredictor, discovery=False)
            else:
                prop_steps_predictor = partial(StepsPredictor, steps_pred_hidden, self.prop_step_bias)

            prior_rnn = self.prior_rnn_class(self.n_hidden)
            propagation_prior = make_prior(self.prop_prior_type, self.n_what, prior_rnn, self.prop_prior_step_bias)

            propagate_rnn_cell = self.rnn_class(self.n_hidden)
            temporal_rnn_cell = self.time_rnn_class(self.n_hidden)

            propagation_cell = PropagationCore(processed_image, encoded_input, self.object_shape, self.n_what,
                                               propagate_rnn_cell, glimpse_encoder, transform_estimator,
                                               prop_steps_predictor, temporal_rnn_cell, debug=self.debug)

            if self.fast_propagation:
                propagate = FastPropagate(propagation_cell, propagation_prior)
            else:
                propagate = Propagate(propagation_cell, propagation_prior)

        with tf.variable_scope('decoder'):
            glimpse_decoder = partial(Decoder, layers, output_scale=self.output_scale)
            decoder = AIRDecoder(img_size, self.object_shape, glimpse_decoder,
                                 batch_dims=2,
                                 mean_img=self._tensors.mean_img,
                                 output_std=self.output_std,
                                 scale_bounds=self.scale_bounds)

        with tf.variable_scope('sequence'):
            time_cell = self.time_rnn_class(self.n_hidden)

            sequence_apdr = SequentialAIR(
                self.n_objects, self.object_shape, discover, propagate,
                time_cell, decoder, prior_start_step=self._prior_start_step)

        outputs = sequence_apdr(processed_image)
        outputs['where_coords'] = decoder._transformer.to_coords(outputs['where'])

        self._tensors.update(outputs)
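# Hedged sketch of the k-particle tiling at the top of build_representation: each batch
# element is replicated k_particles times along the batch axis so that several posterior
# samples are evaluated per example (IWAE-style). The exact interleaving is determined by
# index.tile_input_for_iwae, which is not shown here; np.repeat is used purely for
# illustration.
import numpy as np

def tile_for_iwae(batch, k_particles):
    # batch: (T, B, H, W, C) -> (T, B * k_particles, H, W, C)
    return np.repeat(batch, k_particles, axis=1)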
示例#19
0
    def build_background(self):
        if cfg.background_cfg.mode == "colour":
            rgb = np.array(to_rgb(cfg.background_cfg.colour))[None, None, None, :]
            background = rgb * tf.ones_like(self.inp)

        elif cfg.background_cfg.mode == "learn_solid":
            # Learn a solid colour for the background
            self.solid_background_logits = tf.get_variable("solid_background", initializer=[0.0, 0.0, 0.0])
            if "background" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.solid_background_logits)
            solid_background = tf.nn.sigmoid(10 * self.solid_background_logits)
            background = solid_background[None, None, None, :] * tf.ones_like(self.inp)

        elif cfg.background_cfg.mode == "scalor":
            pass

        elif cfg.background_cfg.mode == "learn":
            self.maybe_build_subnet("background_encoder")
            self.maybe_build_subnet("background_decoder")

            # Here I'm just encoding the first frame...
            bg_attr = self.background_encoder(self.inp[:, 0], 2 * cfg.background_cfg.A, self.is_training)
            bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
            bg_attr_std = tf.exp(bg_attr_log_std)
            if not self.noisy:
                bg_attr_std = tf.zeros_like(bg_attr_std)

            bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

            self._tensors.update(
                bg_attr_mean=bg_attr_mean,
                bg_attr_std=bg_attr_std,
                bg_attr_kl=bg_attr_kl,
                bg_attr=bg_attr)

            # --- decode ---

            _, T, H, W, _ = tf_shape(self.inp)

            background = self.background_decoder(bg_attr, 3, self.is_training)
            assert len(background.shape) == 2 and background.shape[1] == 3
            background = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))
            background = tf.tile(background[:, None, None, None, :], (1, T, H, W, 1))

        elif cfg.background_cfg.mode == "learn_and_transform":
            self.maybe_build_subnet("background_encoder")
            self.maybe_build_subnet("background_decoder")

            # --- encode ---

            n_transform_latents = 4
            n_latents = (2 * cfg.background_cfg.A, 2 * n_transform_latents)

            bg_attr, bg_transform_params = self.background_encoder(self.inp, n_latents, self.is_training)

            # --- bg attributes ---

            bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
            bg_attr_std = self.std_nonlinearity(bg_attr_log_std)

            bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

            # --- bg location ---

            bg_transform_params = tf.reshape(
                bg_transform_params,
                (self.batch_size, self.dynamic_n_frames, 2*n_transform_latents))

            mean, log_std = tf.split(bg_transform_params, 2, axis=2)
            std = self.std_nonlinearity(log_std)

            logits, kl = normal_vae(mean, std, 0.0, 1.0)

            # integrate across timesteps
            logits = tf.cumsum(logits, axis=1)
            logits = tf.reshape(logits, (self.batch_size*self.dynamic_n_frames, n_transform_latents))

            y, x, h, w = tf.split(logits, n_transform_latents, axis=1)
            h = (0.9 - 0.5) * tf.nn.sigmoid(h) + 0.5
            w = (0.9 - 0.5) * tf.nn.sigmoid(w) + 0.5
            y = (1 - h) * tf.nn.tanh(y)
            x = (1 - w) * tf.nn.tanh(x)

            # --- decode ---

            background = self.background_decoder(bg_attr, self.image_depth, self.is_training)
            bg_shape = cfg.background_cfg.bg_shape
            background = background[:, :bg_shape[0], :bg_shape[1], :]
            assert background.shape[1:3] == bg_shape
            background_raw = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()

            warper = snt.AffineGridWarper(
                bg_shape, (self.image_height, self.image_width), transform_constraints)

            transforms = tf.concat([w, x, h, y], axis=-1)
            grid_coords = warper(transforms)

            grid_coords = tf.reshape(
                grid_coords,
                (self.batch_size, self.dynamic_n_frames, *tf_shape(grid_coords)[1:]))

            background = tf.contrib.resampler.resampler(background_raw, grid_coords)

            self._tensors.update(
                bg_attr_mean=bg_attr_mean,
                bg_attr_std=bg_attr_std,
                bg_attr_kl=bg_attr_kl,
                bg_attr=bg_attr,
                bg_y=tf.reshape(y, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_x=tf.reshape(x, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_h=tf.reshape(h, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_w=tf.reshape(w, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_transform_kl=kl,
                bg_raw=background_raw,
            )

        elif cfg.background_cfg.mode == "data":
            background = self._tensors["ground_truth_background"]

        else:
            raise Exception("Unrecognized background mode: {}.".format(cfg.background_cfg.mode))

        self._tensors["background"] = background[:, :self.dynamic_n_frames]
示例#20
0
def _find_connected_componenents_body(mask):
    components = tf.contrib.image.connected_components(mask)

    total_n_objects = tf.to_int32(tf.reduce_max(components))
    indices = tf.range(1, total_n_objects + 1)

    maxs = tf.reduce_max(components, axis=(1, 2))

    # So that we don't pick up zeros.
    for_mins = tf.where(mask, components,
                        (total_n_objects + 1) * tf.ones_like(components))
    mins = tf.reduce_min(for_mins, axis=(1, 2))

    n_objects = tf.to_int32(tf.maximum((maxs - mins) + 1, 0))

    under = indices[None, :] <= maxs[:, None]
    over = indices[None, :] >= mins[:, None]

    both = tf.to_int32(tf.logical_and(under, over))
    batch_indices_for_objects = tf.argmax(both, axis=0)

    assert_valid_batch_indices = tf.Assert(
        tf.reduce_all(tf.equal(tf.reduce_sum(both, axis=0), 1)),
        [both], name="assert_valid_batch_indices")

    with tf.control_dependencies([assert_valid_batch_indices]):
        batch_indices_for_objects = tf.identity(batch_indices_for_objects)

    _, image_height, image_width, *_ = tf_shape(mask)
    cell = BboxCell(components, batch_indices_for_objects, image_height,
                    image_width)

    # For each object, get its bounding box by using `indices` to figure out which element of
    # `components` the object appears in, and then check that element
    object_bboxes, _ = dynamic_rnn(cell,
                                   indices[:, None, None],
                                   initial_state=cell.zero_state(
                                       1, tf.float32),
                                   parallel_iterations=10,
                                   swap_memory=False,
                                   time_major=True)

    # Couldn't I have just iterated through all object indices and used tf.where on `components` to simultaneously
    # get both the bounding box and the batch index? Yes, but I think I thought that would be expensive
    # (have to look through the entirety of `components` once for each object).

    # Get rid of dummy batch dim created for dynamic_rnn
    object_bboxes = object_bboxes[:, 0, :]

    obj = tf.sequence_mask(n_objects)
    routing = tf.reshape(tf.to_int32(obj), (-1, ))
    routing = tf.cumsum(routing, exclusive=True)
    routing = tf.reshape(routing, tf.shape(obj))
    obj = tf.to_float(obj[:, :, None])

    return dict(
        normalized_box=tf.gather(object_bboxes, routing, axis=0),
        obj=obj,
        n_objects=n_objects,
        max_objects=tf.reduce_max(n_objects),
    )
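# Hedged NumPy sketch of what the dynamic_rnn/BboxCell loop above computes: for each
# connected-component label, the batch element it belongs to and a normalized bounding
# box. tf.contrib.image.connected_components assigns labels that are unique across the
# batch, which is what the mins/maxs bookkeeping exploits; the exact normalization used
# by BboxCell is not shown, so the one below is only a plausible choice.
import numpy as np

def bboxes_from_components(components):
    # components: (batch, H, W) int array, 0 = background, labels unique across the batch
    _, H, W = components.shape
    boxes, batch_indices = [], []
    for label in range(1, components.max() + 1):
        b, ys, xs = np.nonzero(components == label)
        boxes.append([ys.min() / H, xs.min() / W, (ys.max() + 1) / H, (xs.max() + 1) / W])
        batch_indices.append(b[0])
    return np.array(boxes), np.array(batch_indices)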
示例#21
0
    def _call(self, objects, background, is_training, appearance_only=False):
        if not self.initialized:
            self.image_depth = tf_shape(background)[-1]

        self.maybe_build_subnet("object_decoder")

        # --- compute sprite appearance from attr using object decoder ---

        appearance_logits = apply_object_wise(
            self.object_decoder, objects.attr,
            self.object_shape + (self.image_depth + 1, ), is_training)

        appearance_logits = appearance_logits * ([self.color_logit_scale] * 3 +
                                                 [self.alpha_logit_scale])
        appearance_logits = appearance_logits + ([0.] * 3 +
                                                 [self.alpha_logit_bias])

        appearance = tf.nn.sigmoid(
            tf.clip_by_value(appearance_logits, -10., 10.))

        if appearance_only:
            return dict(appearance=appearance)

        appearance_for_output = appearance

        batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
        n_objects = np.prod(obj_leading_shape)
        appearance = tf.reshape(
            appearance,
            (batch_size, n_objects, *self.object_shape, self.image_depth + 1))

        obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1],
                                         axis=-1)

        if "alpha" in self.no_gradient:
            obj_alpha = tf.stop_gradient(obj_alpha)

        if "alpha" in self.fixed_values:
            obj_alpha = float(
                self.fixed_values["alpha"]) * tf.ones_like(obj_alpha)

        obj_alpha *= tf.reshape(objects.obj, (batch_size, n_objects, 1, 1, 1))

        z = tf.reshape(objects.z, (batch_size, n_objects, 1, 1, 1))
        obj_importance = tf.maximum(obj_alpha * z, 0.01)

        object_maps = tf.concat([obj_colors, obj_alpha, obj_importance],
                                axis=-1)

        ys, xs, yt, xt = objects.ys, objects.xs, objects.yt, objects.xt

        scales = tf.concat([ys, xs], axis=-1)
        scales = tf.reshape(scales, (batch_size, n_objects, 2))

        offsets = tf.concat([yt, xt], axis=-1)
        offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

        # --- Compose images ---

        n_objects_per_image = tf.fill((batch_size, ), int(n_objects))

        output = render_sprites.render_sprites(object_maps,
                                               n_objects_per_image, scales,
                                               offsets, background)

        return dict(appearance=appearance_for_output, output=output)
示例#22
0
    def build_representation(self):
        # --- init modules ---

        if self.encoder is None:
            self.encoder = self.build_encoder(scope="encoder")
            if "encoder" in self.fixed_weights:
                self.encoder.fix_variables()

        if self.cell is None and self.build_cell is not None:
            self.cell = cfg.build_cell(scope="cell")
            if "cell" in self.fixed_weights:
                self.cell.fix_variables()

        if self.decoder is None:
            self.decoder = cfg.build_decoder(scope="decoder")
            if "decoder" in self.fixed_weights:
                self.decoder.fix_variables()

        # --- encode ---

        inp_trailing_shape = tf_shape(self.inp)[2:]
        video = tf.reshape(self.inp, (self.batch_size * self.dynamic_n_frames, *inp_trailing_shape))
        encoder_output = self.encoder(video, 2 * self.A, self.is_training)

        eo_trailing_shape = tf_shape(encoder_output)[1:]
        encoder_output = tf.reshape(
            encoder_output, (self.batch_size, self.dynamic_n_frames, *eo_trailing_shape))

        if self.cell is None:
            attr = encoder_output
        else:

            if self.flat_latent:
                n_trailing_dims = int(np.prod(eo_trailing_shape))
                encoder_output = tf.reshape(
                    encoder_output, (self.batch_size, self.dynamic_n_frames, n_trailing_dims))
            else:
                raise Exception("NotImplemented")

                n_objects = int(np.prod(eo_trailing_shape[:-1]))
                D = eo_trailing_shape[-1]
                encoder_output = tf.reshape(
                    encoder_output, (self.batch_size, self.dynamic_n_frames, n_objects, D))

            encoder_output = tf.layers.flatten(encoder_output)

            attr, final_state = dynamic_rnn(
                self.cell, encoder_output, initial_state=self.cell.zero_state(self.batch_size, tf.float32),
                parallel_iterations=1, swap_memory=False, time_major=False)

        attr_mean, attr_log_std = tf.split(attr, 2, axis=-1)
        attr_std = tf.math.softplus(attr_log_std)

        if not self.noisy:
            attr_std = tf.zeros_like(attr_std)

        attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)

        self._tensors.update(attr_mean=attr_mean, attr_std=attr_std, attr_kl=attr_kl, attr=attr)

        # --- decode ---

        decoder_input = tf.reshape(attr, (self.batch_size*self.dynamic_n_frames, *tf_shape(attr)[2:]))

        reconstruction = self.decoder(decoder_input, tf_shape(self.inp)[2:], self.is_training)
        reconstruction = reconstruction[:, :self.obs_shape[1], :self.obs_shape[2], :]
        reconstruction = tf.reshape(reconstruction, (self.batch_size, self.dynamic_n_frames, *self.obs_shape[1:]))

        reconstruction = tf.nn.sigmoid(tf.clip_by_value(reconstruction, -10, 10))
        self._tensors["output"] = reconstruction

        # --- losses ---

        if self.train_kl:
            self.losses['attr_kl'] = tf_mean_sum(self._tensors["attr_kl"])

        if self.train_reconstruction:
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=reconstruction, label=self.inp)
            self.losses['reconstruction'] = tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])
示例#23
0
    def _call(self, objects, background, is_training, appearance_only=False, mask_only=False):
        """ If mask_only==True, then we ignore the provided background, using a black blackground instead,
            and also ignore the computed appearance, using all-white appearances instead.

        """
        if not self.initialized:
            self.image_depth = tf_shape(background)[-1]

        single = False
        if isinstance(objects, dict):
            single = True
            objects = [objects]

        _object_maps = []
        _scales = []
        _offsets = []
        _appearance = []

        for i, obj in enumerate(objects):
            anchor_box = self.anchor_boxes[i]
            object_shape = self.object_shapes[i]

            object_decoder = self.maybe_build_subnet(
                "object_decoder_for_flight_{}".format(i), builder_name='build_object_decoder')

            # --- compute sprite appearance from attr using object decoder ---

            appearance_logit = apply_object_wise(
                object_decoder, obj.attr,
                output_size=object_shape + (self.image_depth+1,),
                is_training=is_training)

            appearance_logit = appearance_logit * ([self.color_logit_scale] * self.image_depth + [self.alpha_logit_scale])
            appearance_logit = appearance_logit + ([0.] * self.image_depth + [self.alpha_logit_bias])

            appearance = tf.nn.sigmoid(tf.clip_by_value(appearance_logit, -10., 10.))
            _appearance.append(appearance)

            if appearance_only:
                continue

            batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
            n_objects = np.prod(obj_leading_shape)
            appearance = tf.reshape(
                appearance, (batch_size, n_objects, *object_shape, self.image_depth+1))

            obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1], axis=-1)

            if mask_only:
                obj_colors = tf.ones_like(obj_colors)

            obj_alpha *= tf.reshape(obj.obj, (batch_size, n_objects, 1, 1, 1))

            z = tf.reshape(obj.z, (batch_size, n_objects, 1, 1, 1))
            obj_importance = tf.maximum(obj_alpha * z / self.importance_temp, 0.01)

            object_maps = tf.concat([obj_colors, obj_alpha, obj_importance], axis=-1)

            *_, image_height, image_width, _ = tf_shape(background)

            yt, xt, ys, xs = coords_to_image_space(
                obj.yt, obj.xt, obj.ys, obj.xs,
                (image_height, image_width), anchor_box, top_left=True)

            scales = tf.concat([ys, xs], axis=-1)
            scales = tf.reshape(scales, (batch_size, n_objects, 2))

            offsets = tf.concat([yt, xt], axis=-1)
            offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

            _object_maps.append(object_maps)
            _scales.append(scales)
            _offsets.append(offsets)

        if single:
            _appearance = _appearance[0]

        if appearance_only:
            return dict(appearance=_appearance)

        if mask_only:
            background = tf.zeros_like(background)

        # --- Compose images ---

        output = render_sprites.render_sprites(
            _object_maps,
            _scales,
            _offsets,
            background
        )

        return dict(
            appearance=_appearance,
            output=output)
示例#24
0
    def _compute_obj_kl(self, tensors, existing_objects=None):
        # --- compute obj_kl ---

        obj_pre_sigmoid = tensors["obj_pre_sigmoid"]
        obj_log_odds = tensors["obj_log_odds"]
        obj_prob = tensors["obj_prob"]
        obj = tensors["obj"]
        batch_size, n_objects, _ = tf_shape(obj)

        max_n_objects = n_objects

        if existing_objects is not None:
            _, n_existing_objects, _ = tf_shape(existing_objects)
            existing_objects = tf.reshape(existing_objects,
                                          (batch_size, n_existing_objects))
            max_n_objects += n_existing_objects

        count_support = tf.range(max_n_objects + 1, dtype=tf.float32)

        if self.count_prior_dist is not None:
            assert len(self.count_prior_dist) == (max_n_objects + 1)
            count_distribution = tf.constant(self.count_prior_dist, dtype=tf.float32)
        else:
            count_prior_prob = tf.nn.sigmoid(self.count_prior_log_odds)
            count_distribution = (1 - count_prior_prob) * (count_prior_prob**
                                                           count_support)

        normalizer = tf.reduce_sum(count_distribution)
        count_distribution = count_distribution / tf.maximum(normalizer, 1e-6)
        count_distribution = tf.tile(count_distribution[None, :],
                                     (batch_size, 1))

        if existing_objects is not None:
            count_so_far = tf.reduce_sum(tf.round(existing_objects),
                                         axis=1,
                                         keepdims=True)

            count_distribution = (
                count_distribution *
                tf_binomial_coefficient(count_support, count_so_far) *
                tf_binomial_coefficient(max_n_objects - count_support,
                                        n_existing_objects - count_so_far))

            normalizer = tf.reduce_sum(count_distribution,
                                       axis=1,
                                       keepdims=True)
            count_distribution = count_distribution / tf.maximum(
                normalizer, 1e-6)
        else:
            count_so_far = tf.zeros((batch_size, 1), dtype=tf.float32)

        obj_kl = []
        for i in range(n_objects):
            p_z_given_Cz = tf.maximum(count_support[None, :] - count_so_far,
                                      0) / (max_n_objects - i)

            # Reshape for batch matmul
            _count_distribution = count_distribution[:, None, :]
            _p_z_given_Cz = p_z_given_Cz[:, :, None]

            p_z = tf.matmul(_count_distribution, _p_z_given_Cz)[:, :, 0]

            if self.use_concrete_kl:
                prior_log_odds = tf_safe_log(p_z) - tf_safe_log(1 - p_z)
                _obj_kl = concrete_binary_sample_kl(
                    obj_pre_sigmoid[:, i, :],
                    obj_log_odds[:, i, :],
                    self.obj_concrete_temp,
                    prior_log_odds,
                    self.obj_concrete_temp,
                )
            else:
                prob = obj_prob[:, i, :]

                _obj_kl = (prob * (tf_safe_log(prob) - tf_safe_log(p_z)) +
                           (1 - prob) *
                           (tf_safe_log(1 - prob) - tf_safe_log(1 - p_z)))

            obj_kl.append(_obj_kl)

            sample = tf.to_float(obj[:, i, :] > 0.5)
            mult = sample * p_z_given_Cz + (1 - sample) * (1 - p_z_given_Cz)
            count_distribution = mult * count_distribution
            normalizer = tf.reduce_sum(count_distribution,
                                       axis=1,
                                       keepdims=True)
            normalizer = tf.maximum(normalizer, 1e-6)
            count_distribution = count_distribution / normalizer

            count_so_far += sample

        obj_kl = tf.reshape(tf.concat(obj_kl, axis=1),
                            (batch_size, n_objects, 1))

        return obj_kl
示例#25
0
    def _call(self, input_locs, input_features, reference_locs,
              reference_features, is_training):
        """
        input_features: (B, n_inp, n_hidden)
        input_locs: (B, n_inp, loc_dim)
        reference_locs: (B, n_ref, loc_dim)

        """
        assert (reference_features is not None) == self.do_object_wise

        if not self.is_built:
            self.relation_func = self.build_mlp(scope="relation_func")

            if self.do_object_wise:
                self.object_wise_func = self.build_mlp(
                    scope="object_wise_func")

            self.is_built = True

        loc_dim = tf_shape(input_locs)[-1]
        n_ref = tf_shape(reference_locs)[-2]
        batch_size, n_inp, _ = tf_shape(input_features)

        input_locs = tf.broadcast_to(input_locs, (batch_size, n_inp, loc_dim))
        reference_locs = tf.broadcast_to(reference_locs,
                                         (batch_size, n_ref, loc_dim))

        # (B, n_ref, n_inp, loc_dim)
        adjusted_locs = input_locs[:, None, :, :] - reference_locs[:, :, None, :]
        adjusted_features = tf.tile(
            input_features[:, None],
            (1, n_ref, 1, 1))  # (B, n_ref, n_inp, features_dim)
        relation_input = tf.concat([adjusted_features, adjusted_locs], axis=-1)

        if self.do_object_wise:
            object_wise = apply_object_wise(
                self.object_wise_func,
                reference_features,
                output_size=self.n_hidden,
                is_training=is_training)  # (B, n_ref, n_hidden)

            _object_wise = tf.tile(object_wise[:, :, None], (1, 1, n_inp, 1))
            relation_input = tf.concat([relation_input, _object_wise], axis=-1)
        else:
            object_wise = None

        V = apply_object_wise(
            self.relation_func,
            relation_input,
            output_size=self.n_hidden,
            is_training=is_training)  # (B, n_ref, n_inp, n_hidden)

        attention_weights = tf.exp(-0.5 * tf.reduce_sum(
            (adjusted_locs / self.kernel_std)**2, axis=3))
        attention_weights = (attention_weights / (2 * np.pi)**(loc_dim / 2) /
                             self.kernel_std**loc_dim)  # (B, n_ref, n_inp)

        result = tf.reduce_sum(V * attention_weights[..., None],
                               axis=2)  # (B, n_ref, n_hidden)

        if self.do_object_wise:
            result += object_wise

        # result = tf.contrib.layers.layer_norm(result)

        return result
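# Hedged NumPy sketch of the Gaussian spatial weighting used above: every reference
# location attends to the inputs with an isotropic Gaussian kernel over the displacement
# between locations (the same density normalization as in the TensorFlow code), and the
# per-pair relation features are then summed with these weights.
import numpy as np

def gaussian_attention_weights(input_locs, reference_locs, kernel_std):
    # input_locs: (n_inp, loc_dim), reference_locs: (n_ref, loc_dim)
    loc_dim = input_locs.shape[-1]
    disp = reference_locs[:, None, :] - input_locs[None, :, :]   # (n_ref, n_inp, loc_dim)
    w = np.exp(-0.5 * np.sum((disp / kernel_std) ** 2, axis=-1))
    return w / (2 * np.pi) ** (loc_dim / 2) / kernel_std ** loc_dim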
示例#26
0
    def _call(self, objects, background, is_training, appearance_only=False):
        if not self.initialized:
            self.image_depth = tf_shape(background)[-1]

        self.maybe_build_subnet("object_decoder")

        # --- compute sprite appearance from attr using object decoder ---

        appearance_logit = apply_object_wise(
            self.object_decoder, objects.attr,
            output_size=self.object_shape + (self.image_depth+1,),
            is_training=is_training)

        appearance_logit = appearance_logit * ([self.color_logit_scale] * 3 + [self.alpha_logit_scale])
        appearance_logit = appearance_logit + ([0.] * 3 + [self.alpha_logit_bias])

        appearance = tf.nn.sigmoid(tf.clip_by_value(appearance_logit, -10., 10.))

        if appearance_only:
            return dict(appearance=appearance)

        appearance_for_output = appearance

        batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
        n_objects = np.prod(obj_leading_shape)
        appearance = tf.reshape(
            appearance, (batch_size, n_objects, *self.object_shape, self.image_depth+1))

        obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1], axis=-1)

        obj_alpha *= tf.reshape(objects.render_obj, (batch_size, n_objects, 1, 1, 1))

        z = tf.reshape(objects.z, (batch_size, n_objects, 1, 1, 1))
        obj_importance = tf.maximum(obj_alpha * z / self.importance_temp, 0.01)

        object_maps = tf.concat([obj_colors, obj_alpha, obj_importance], axis=-1)

        *_, image_height, image_width, _ = tf_shape(background)

        yt, xt, ys, xs = coords_to_image_space(
            objects.yt, objects.xt, objects.ys, objects.xs,
            (image_height, image_width), self.anchor_box, top_left=True)

        scales = tf.concat([ys, xs], axis=-1)
        scales = tf.reshape(scales, (batch_size, n_objects, 2))

        offsets = tf.concat([yt, xt], axis=-1)
        offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

        # --- Compose images ---

        n_objects_per_image = tf.fill((batch_size,), int(n_objects))

        output = render_sprites.render_sprites(
            object_maps,
            n_objects_per_image,
            scales,
            offsets,
            background
        )

        return dict(
            appearance=appearance_for_output,
            output=output)
示例#27
0
    def build_representation(self):
        # --- init modules ---
        self.B = len(self.anchor_boxes)

        if self.backbone is None:
            self.backbone = self.build_backbone(scope="backbone")
            if "backbone" in self.fixed_weights:
                self.backbone.fix_variables()

        if self.feature_fuser is None:
            self.feature_fuser = self.build_feature_fuser(scope="feature_fuser")
            if "feature_fuser" in self.fixed_weights:
                self.feature_fuser.fix_variables()

        if self.obj_feature_extractor is None and self.build_obj_feature_extractor is not None:
            self.obj_feature_extractor = self.build_obj_feature_extractor(scope="obj_feature_extractor")
            if "obj_feature_extractor" in self.fixed_weights:
                self.obj_feature_extractor.fix_variables()

        backbone_output, n_grid_cells, grid_cell_size = self.backbone(
            self.inp, self.B*self.n_backbone_features, self.is_training)

        self.H, self.W = [int(i) for i in n_grid_cells]
        self.HWB = self.H * self.W * self.B
        self.pixels_per_cell = tuple(int(i) for i in grid_cell_size)
        H, W, B = self.H, self.W, self.B

        if self.object_layer is None:
            self.object_layer = ObjectLayer(self.pixels_per_cell, scope="objects")

        self.object_rep_tensors = []
        object_rep_tensors = None
        _tensors = defaultdict(list)

        for f in range(self.n_frames):
            print("Bulding network for frame {}".format(f))
            early_frame_features = backbone_output[:, f]

            if f > 0 and self.obj_feature_extractor is not None:
                object_features = object_rep_tensors["all"]
                object_features = tf.reshape(
                    object_features, (self.batch_size, H, W, B*tf_shape(object_features)[-1]))
                early_frame_features += self.obj_feature_extractor(
                    object_features, B*self.n_backbone_features, self.is_training)

            frame_features = self.feature_fuser(
                early_frame_features, B*self.n_backbone_features, self.is_training)

            frame_features = tf.reshape(
                frame_features, (self.batch_size, H, W, B, self.n_backbone_features))

            object_rep_tensors = self.object_layer(
                self.inp[:, f], frame_features, self._tensors["background"][:, f], self.is_training)

            self.object_rep_tensors.append(object_rep_tensors)

            for k, v in object_rep_tensors.items():
                _tensors[k].append(v)

        self._tensors.update(**{k: tf.stack(v, axis=1) for k, v in _tensors.items()})

        # --- specify values to record ---

        obj = self._tensors["obj"]
        pred_n_objects = self._tensors["pred_n_objects"]

        self.record_tensors(
            batch_size=self.batch_size,
            float_is_training=self.float_is_training,

            cell_y=self._tensors["cell_y"],
            cell_x=self._tensors["cell_x"],
            h=self._tensors["h"],
            w=self._tensors["w"],
            z=self._tensors["z"],
            area=self._tensors["area"],

            cell_y_std=self._tensors["cell_y_std"],
            cell_x_std=self._tensors["cell_x_std"],
            h_std=self._tensors["h_std"],
            w_std=self._tensors["w_std"],
            z_std=self._tensors["z_std"],

            n_objects=pred_n_objects,
            obj=obj,

            latent_area=self._tensors["latent_area"],
            latent_hw=self._tensors["latent_hw"],

            attr=self._tensors["attr"],
        )

        # --- losses ---

        if self.train_reconstruction:
            output = self._tensors['output']
            inp = self._tensors['inp']
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=output, label=inp)
            self.losses['reconstruction'] = (
                self.reconstruction_weight * tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])
            )

        if self.train_kl:
            self.losses.update(
                obj_kl=self.kl_weight * tf_mean_sum(self._tensors["obj_kl"]),
                cell_y_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["cell_y_kl"]),
                cell_x_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["cell_x_kl"]),
                h_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["h_kl"]),
                w_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["w_kl"]),
                z_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["z_kl"]),
                attr_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["attr_kl"]),
            )

            if cfg.background_cfg.mode == "learn_and_transform":
                self.losses.update(
                    bg_attr_kl=self.kl_weight * tf_mean_sum(self._tensors["bg_attr_kl"]),
                    bg_transform_kl=self.kl_weight * tf_mean_sum(self._tensors["bg_transform_kl"]),
                )

        # --- other evaluation metrics ---

        if "n_annotations" in self._tensors:
            count_1norm = tf.to_float(
                tf.abs(tf.to_int32(self._tensors["pred_n_objects_hard"]) - self._tensors["n_valid_annotations"]))

            self.record_tensors(
                count_1norm=count_1norm,
                count_error=count_1norm > 0.5,
            )
示例#28
0
    def __call__(self, inp):
        output = self.wrapped(inp, None, True)[0]
        batch_size = tf_shape(output)[0]
        n_trailing = np.prod(tf_shape(output)[1:])
        return tf.reshape(output, (batch_size, n_trailing))