Example #1
    def _build_obj(self, obj_logits, is_training, **kwargs):
        obj_logits = self.training_wheels * tf.stop_gradient(obj_logits) + (
            1 - self.training_wheels) * obj_logits
        obj_logits = obj_logits / self.obj_temp

        obj_log_odds = tf.clip_by_value(obj_logits, -10., 10.)

        obj_pre_sigmoid = concrete_binary_pre_sigmoid_sample(
            obj_log_odds, self.obj_concrete_temp)
        raw_obj = tf.nn.sigmoid(obj_pre_sigmoid)

        if self.noisy:
            obj = (self.float_is_training * raw_obj +
                   (1 - self.float_is_training) * tf.round(raw_obj))
        else:
            obj = tf.round(raw_obj)

        if "obj" in self.no_gradient:
            obj = tf.stop_gradient(obj)

        if "obj" in self.fixed_values:
            obj = self.fixed_values['obj'] * tf.ones_like(obj)

        return dict(
            obj=obj,
            raw_obj=raw_obj,
            obj_pre_sigmoid=obj_pre_sigmoid,
            obj_log_odds=obj_log_odds,
            obj_prob=tf.nn.sigmoid(obj_log_odds),
        )
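
The helper concrete_binary_pre_sigmoid_sample used throughout these examples is not shown in the source. A minimal sketch of one plausible implementation, assuming the standard Binary Concrete (Gumbel-sigmoid) relaxation of Maddison et al., where a relaxed sample is sigmoid((log_odds + logistic_noise) / temperature):

    def concrete_binary_pre_sigmoid_sample(log_odds, temperature, eps=1e-8):
        # Logistic noise: L = log(U) - log(1 - U), with U ~ Uniform(0, 1).
        u = tf.random_uniform(tf.shape(log_odds), minval=0., maxval=1.)
        noise = tf.log(u + eps) - tf.log(1. - u + eps)
        # Pre-sigmoid sample; applying tf.nn.sigmoid to this (as the callers
        # above do) yields the relaxed binary sample in (0, 1).
        return (log_odds + noise) / temperature
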
Example #2
    def _build_obj(self, obj_logit, is_training, **kwargs):
        obj_logit = self.training_wheels * tf.stop_gradient(obj_logit) + (1-self.training_wheels) * obj_logit
        obj_log_odds = tf.clip_by_value(obj_logit / self.obj_temp, -10., 10.)

        obj_pre_sigmoid = (
            self._noisy * concrete_binary_pre_sigmoid_sample(obj_log_odds, self.obj_concrete_temp)
            + (1 - self._noisy) * obj_log_odds
        )

        obj = tf.nn.sigmoid(obj_pre_sigmoid)

        return dict(
            obj_log_odds=obj_log_odds,
            obj_prob=tf.nn.sigmoid(obj_log_odds),
            obj_pre_sigmoid=obj_pre_sigmoid,
            obj=obj,
        )
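
Unlike Example #1, which switches between sampled and rounded values with a Python-level branch on self.noisy, this version blends the concrete sample with the deterministic log-odds inside the graph via the float self._noisy. How _noisy is defined is not shown in the source; one plausible assumption is a float cast of the training flag:

    # Hypothetical definition: 1. while training (inject concrete noise),
    # 0. at evaluation time.
    self._noisy = tf.to_float(self.is_training)
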
Example #3
    def _build_obj(self, obj_logit, is_training, **kwargs):
        obj_logit = self.training_wheels * tf.stop_gradient(obj_logit) + (1-self.training_wheels) * obj_logit
        obj_log_odds = tf.clip_by_value(obj_logit / self.obj_temp, -10., 10.)

        if self.noisy:
            obj_pre_sigmoid = concrete_binary_pre_sigmoid_sample(obj_log_odds, self.obj_concrete_temp)
        else:
            obj_pre_sigmoid = obj_log_odds

        obj = tf.nn.sigmoid(obj_pre_sigmoid)

        render_obj = (
            self.float_is_training * obj
            + (1 - self.float_is_training) * tf.round(obj)
        )

        return dict(
            obj_log_odds=obj_log_odds,
            obj_prob=tf.nn.sigmoid(obj_log_odds),
            obj_pre_sigmoid=obj_pre_sigmoid,
            obj=obj,
            render_obj=render_obj,
        )
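
Examples #1 and #3 both use the same train/eval mixing trick: keep the relaxed value while training, so gradients flow, but round it to a hard {0, 1} value at evaluation time. A small standalone sketch of the pattern, assuming float_is_training is a float tensor that is 1. during training and 0. otherwise:

    def mix_train_eval(value, float_is_training):
        # Relaxed value during training, hard-rounded value at evaluation time.
        return (float_is_training * value
                + (1 - float_is_training) * tf.round(value))
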
Example #4
    def _body(self, inp, features, objects, is_posterior):
        """
        Summary of how updates are done for the different variables:

        glimpse': glimpse_params = where + 0.1 * predicted_logit

        where_y/x: new_where_y/x = where_y/x + where_t_scale * tanh(predicted_logit)
        where_h/w: new_where_h/w_logit = where_h/w_logit + predicted_logit
        what: Hard to summarize here, taken from SQAIR. Kind of like an LSTM.
        depth: new_depth_logit = depth_logit + predicted_logit
        obj: new_obj = obj * sigmoid(predicted_logit)

        """
        batch_size, n_objects, _ = tf_shape(features)

        new_objects = AttrDict()

        is_posterior_tf = tf.ones_like(features[..., 0:2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        base_features = tf.concat([features, is_posterior_tf], axis=-1)

        cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

        # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
        glimpse_dim = self.object_shape[0] * self.object_shape[1]
        glimpse_prime_params = apply_object_wise(
            self.glimpse_prime_network, base_features,
            output_size=4 + 2 * glimpse_dim, is_training=self.is_training)

        glimpse_prime_params, glimpse_prime_mask_logit, glimpse_mask_logit = \
            tf.split(glimpse_prime_params, [4, glimpse_dim, glimpse_dim], axis=-1)

        if is_posterior:
            # --- obtain final parameters for glimpse prime by modifying current pose ---

            _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

            g_yt = cyt + 0.1 * _yt
            g_xt = cxt + 0.1 * _xt
            g_ys = ys + 0.1 * _ys
            g_xs = xs + 0.1 * _xs

            # --- extract glimpse prime ---

            _, image_height, image_width, _ = tf_shape(inp)
            g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
                g_yt,
                g_xt,
                g_ys,
                g_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)
            glimpse_prime = extract_affine_glimpse(inp, self.object_shape,
                                                   g_yt, g_xt, g_ys, g_xs,
                                                   self.edge_resampler)
        else:
            g_yt = tf.zeros_like(cyt)
            g_xt = tf.zeros_like(cxt)
            g_ys = tf.zeros_like(ys)
            g_xs = tf.zeros_like(xs)
            glimpse_prime = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        glimpse_prime_mask = tf.nn.sigmoid(glimpse_prime_mask_logit + 1.)
        leading_mask_shape = tf_shape(glimpse_prime)[:-1]
        glimpse_prime_mask = tf.reshape(glimpse_prime_mask,
                                        (*leading_mask_shape, 1))

        new_objects.update(
            glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1),
            glimpse_prime=glimpse_prime,
            glimpse_prime_mask=glimpse_prime_mask,
        )

        glimpse_prime *= glimpse_prime_mask

        # --- encode glimpse ---

        encoded_glimpse_prime = apply_object_wise(self.glimpse_prime_encoder,
                                                  glimpse_prime,
                                                  n_trailing_dims=3,
                                                  output_size=self.A,
                                                  is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A),
                                             dtype=tf.float32)

        # --- position and scale ---

        # Roughly, in SQAIR terms: base_features == temporal_state,
        # encoded_glimpse_prime == hidden_output. hidden_output conditions on the
        # encoded glimpse, and this is the only place encoded_glimpse_prime is used.

        # SQAIR also conditions on the actual location values from the previous
        # timestep, but we leave that out for now.
        d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
        d_box_params = apply_object_wise(self.d_box_network,
                                         d_box_inp,
                                         output_size=8,
                                         is_training=self.is_training)

        d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

        d_box_std = self.std_nonlinearity(d_box_log_std)

        d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (
            1 - self.training_wheels) * d_box_mean
        d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (
            1 - self.training_wheels) * d_box_std

        d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
        d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

        # --- position ---

        # We predict position a bit differently from scale. For scale we want to put a prior on the actual value of
        # the scale, whereas for position we want to put a prior on the difference in position over timesteps.

        d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
        d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

        d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
        d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

        new_cyt = cyt + d_yt
        new_cxt = cxt + d_xt

        new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

        # --- scale ---

        new_ys_mean = objects.ys_logit + d_ys
        new_xs_mean = objects.xs_logit + d_xs

        new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
        new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

        new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
        new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

        if self.use_abs_posn:
            box_params = tf.concat(
                [new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit],
                axis=-1)
        else:
            box_params = tf.concat(
                [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

        new_objects.update(
            abs_posn=new_abs_posn,
            yt=new_cyt,
            xt=new_cxt,
            ys=new_ys,
            xs=new_xs,
            normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs],
                                     axis=-1),
            d_yt_logit=d_yt_logit,
            d_xt_logit=d_xt_logit,
            ys_logit=new_ys_logit,
            xs_logit=new_xs_logit,
            d_yt_logit_mean=d_yt_mean,
            d_xt_logit_mean=d_xt_mean,
            ys_logit_mean=new_ys_mean,
            xs_logit_mean=new_xs_mean,
            d_yt_logit_std=d_yt_std,
            d_xt_logit_std=d_xt_std,
            ys_logit_std=ys_std,
            xs_logit_std=xs_std,
        )

        # --- attributes ---

        # --- extract a glimpse using new box ---

        if is_posterior:
            _, image_height, image_width, _ = tf_shape(inp)
            _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
                new_cyt,
                new_cxt,
                new_ys,
                new_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)

            glimpse = extract_affine_glimpse(inp, self.object_shape, _new_cyt,
                                             _new_cxt, _new_ys, _new_xs,
                                             self.edge_resampler)

        else:
            glimpse = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        glimpse_mask = tf.nn.sigmoid(glimpse_mask_logit + 1.)
        leading_mask_shape = tf_shape(glimpse)[:-1]
        glimpse_mask = tf.reshape(glimpse_mask, (*leading_mask_shape, 1))

        glimpse *= glimpse_mask

        encoded_glimpse = apply_object_wise(self.glimpse_encoder,
                                            glimpse,
                                            n_trailing_dims=3,
                                            output_size=self.A,
                                            is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros((batch_size, n_objects, self.A),
                                       dtype=tf.float32)

        # --- predict change in attributes ---

        # Under SQAIR we mix between three different values for the attributes:
        # 1. the value from the previous timestep
        # 2. a value predicted directly from the glimpse
        # 3. a value predicted from the update of the temporal cell; this update
        #    conditions on hidden_output, the prediction in #2, and the where values.

        # How do we do this, given that we are predicting the *change* in attr? We
        # could just predict the attr directly instead, but call it d_attr; after
        # all, it is in this function that we control whether d_attr is added to attr.

        # First, a prediction based on just the input:

        attr_from_inp = apply_object_wise(self.predict_attr_inp,
                                          encoded_glimpse,
                                          output_size=2 * self.A,
                                          is_training=self.is_training)
        attr_from_inp_mean, attr_from_inp_log_std = tf.split(attr_from_inp,
                                                             [self.A, self.A],
                                                             axis=-1)

        attr_from_inp_std = self.std_nonlinearity(attr_from_inp_log_std)

        # And then a prediction which takes the past into account (predicting gate values at the same time):

        attr_from_temp_inp = tf.concat(
            [base_features, box_params, encoded_glimpse], axis=-1)
        attr_from_temp = apply_object_wise(self.predict_attr_temp,
                                           attr_from_temp_inp,
                                           output_size=5 * self.A,
                                           is_training=self.is_training)

        (attr_from_temp_mean, attr_from_temp_log_std, f_gate_logit,
         i_gate_logit, t_gate_logit) = tf.split(attr_from_temp, 5, axis=-1)

        attr_from_temp_std = self.std_nonlinearity(attr_from_temp_log_std)

        # bias the gates
        f_gate = tf.nn.sigmoid(f_gate_logit + 1) * .9999
        i_gate = tf.nn.sigmoid(i_gate_logit + 1) * .9999
        t_gate = tf.nn.sigmoid(t_gate_logit + 1) * .9999

        attr_mean = (f_gate * objects.attr
                     + (1 - i_gate) * attr_from_inp_mean
                     + (1 - t_gate) * attr_from_temp_mean)
        attr_std = ((1 - i_gate) * attr_from_inp_std
                    + (1 - t_gate) * attr_from_temp_std)

        new_attr = Normal(loc=attr_mean, scale=attr_std).sample()

        # --- apply change in attributes ---

        new_objects.update(
            attr=new_attr,
            d_attr=new_attr - objects.attr,
            d_attr_mean=attr_mean - objects.attr,
            d_attr_std=attr_std,
            f_gate=f_gate,
            i_gate=i_gate,
            t_gate=t_gate,
            glimpse=glimpse,
            glimpse_mask=glimpse_mask,
        )

        # --- z ---

        d_z_inp = tf.concat(
            [base_features, box_params, new_attr, encoded_glimpse], axis=-1)
        d_z_params = apply_object_wise(self.d_z_network,
                                       d_z_inp,
                                       output_size=2,
                                       is_training=self.is_training)

        d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
        d_z_std = self.std_nonlinearity(d_z_log_std)

        d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (
            1 - self.training_wheels) * d_z_mean
        d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (
            1 - self.training_wheels) * d_z_std

        d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

        new_z_logit = objects.z_logit + d_z_logit
        new_z = self.z_nonlinearity(new_z_logit)

        new_objects.update(
            z=new_z,
            z_logit=new_z_logit,
            d_z_logit=d_z_logit,
            d_z_logit_mean=d_z_mean,
            d_z_logit_std=d_z_std,
        )

        # --- obj ---

        d_obj_inp = tf.concat(
            [base_features, box_params, new_attr, new_z, encoded_glimpse],
            axis=-1)
        d_obj_logit = apply_object_wise(self.d_obj_network,
                                        d_obj_inp,
                                        output_size=1,
                                        is_training=self.is_training)

        d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (
            1 - self.training_wheels) * d_obj_logit
        d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10.,
                                          10.)

        d_obj_pre_sigmoid = (
            self._noisy * concrete_binary_pre_sigmoid_sample(
                d_obj_log_odds, self.obj_concrete_temp)
            + (1 - self._noisy) * d_obj_log_odds)

        d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

        new_obj = objects.obj * d_obj

        new_objects.update(
            d_obj_log_odds=d_obj_log_odds,
            d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
            d_obj_pre_sigmoid=d_obj_pre_sigmoid,
            d_obj=d_obj,
            obj=new_obj,
        )

        # --- update each object's hidden state ---

        cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

        if is_posterior:
            _, new_objects.prop_state = apply_object_wise(
                self.cell, cell_input, objects.prop_state)
            new_objects.prior_prop_state = new_objects.prop_state
        else:
            _, new_objects.prior_prop_state = apply_object_wise(
                self.cell, cell_input, objects.prior_prop_state)
            new_objects.prop_state = new_objects.prior_prop_state

        return new_objects
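
apply_object_wise, used heavily above, is another assumed helper. A simplified sketch of the tensor-in/tensor-out case (it ignores the multi-argument recurrent-cell call near the end of this example, and assumes the trailing event dims are statically known):

    def apply_object_wise(func, signal, n_trailing_dims=1, **func_kwargs):
        # Merge the leading (batch, n_objects, ...) dims into one, apply `func`
        # to every object independently, then split the leading dims back out.
        leading_shape = tf.shape(signal)[:-n_trailing_dims]
        event_shape = signal.shape.as_list()[-n_trailing_dims:]  # assumed static
        flat = tf.reshape(signal, [-1] + event_shape)
        out = func(flat, **func_kwargs)
        out_event_shape = out.shape.as_list()[1:]                # assumed static
        return tf.reshape(out, tf.concat([leading_shape, out_event_shape], axis=0))
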
Example #5
    def _body(self, inp, features, objects, is_posterior):
        batch_size, n_objects, _ = tf_shape(features)

        new_objects = AttrDict()

        is_posterior_tf = tf.ones_like(features[..., 0:2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        base_features = tf.concat([features, is_posterior_tf], axis=-1)

        cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

        if self.learn_glimpse_prime:
            # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
            glimpse_prime_params = apply_object_wise(
                self.glimpse_prime_network,
                base_features,
                output_size=4,
                is_training=self.is_training)
        else:
            glimpse_prime_params = tf.zeros_like(base_features[..., :4])

        if is_posterior:

            if self.learn_glimpse_prime:
                # --- obtain final parameters for glimpse prime by modifying current pose ---
                _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

                # This is how it is done in SQAIR
                g_yt = cyt + 0.1 * _yt
                g_xt = cxt + 0.1 * _xt
                g_ys = ys + 0.1 * _ys
                g_xs = xs + 0.1 * _xs

                # g_yt = cyt + self.glimpse_prime_scale * tf.nn.tanh(_yt)
                # g_xt = cxt + self.glimpse_prime_scale * tf.nn.tanh(_xt)
                # g_ys = ys + self.glimpse_prime_scale * tf.nn.tanh(_ys)
                # g_xs = xs + self.glimpse_prime_scale * tf.nn.tanh(_xs)
            else:
                g_yt = cyt
                g_xt = cxt
                g_ys = self.glimpse_prime_scale * ys
                g_xs = self.glimpse_prime_scale * xs

            # --- extract glimpse prime ---

            _, image_height, image_width, _ = tf_shape(inp)
            g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
                g_yt,
                g_xt,
                g_ys,
                g_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)
            glimpse_prime = extract_affine_glimpse(inp, self.object_shape,
                                                   g_yt, g_xt, g_ys, g_xs,
                                                   self.edge_resampler)
        else:
            g_yt = tf.zeros_like(cyt)
            g_xt = tf.zeros_like(cxt)
            g_ys = tf.zeros_like(ys)
            g_xs = tf.zeros_like(xs)
            glimpse_prime = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        new_objects.update(
            glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1))

        # --- encode glimpse ---

        encoded_glimpse_prime = apply_object_wise(self.glimpse_prime_encoder,
                                                  glimpse_prime,
                                                  n_trailing_dims=3,
                                                  output_size=self.A,
                                                  is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A),
                                             dtype=tf.float32)

        # --- position and scale ---

        d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
        d_box_params = apply_object_wise(self.d_box_network,
                                         d_box_inp,
                                         output_size=8,
                                         is_training=self.is_training)

        d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

        d_box_std = self.std_nonlinearity(d_box_log_std)

        d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (
            1 - self.training_wheels) * d_box_mean
        d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (
            1 - self.training_wheels) * d_box_std

        d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
        d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

        # --- position ---

        # We predict position a bit differently from scale. For scale we want to put a prior on the actual value of
        # the scale, whereas for position we want to put a prior on the difference in position over timesteps.

        d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
        d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

        d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
        d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

        new_cyt = cyt + d_yt
        new_cxt = cxt + d_xt

        new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

        # --- scale ---

        new_ys_mean = objects.ys_logit + d_ys
        new_xs_mean = objects.xs_logit + d_xs

        new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
        new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

        new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
        new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

        # Used for conditioning
        if self.use_abs_posn:
            box_params = tf.concat(
                [new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit],
                axis=-1)
        else:
            box_params = tf.concat(
                [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

        new_objects.update(
            abs_posn=new_abs_posn,
            yt=new_cyt,
            xt=new_cxt,
            ys=new_ys,
            xs=new_xs,
            normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs],
                                     axis=-1),
            d_yt_logit=d_yt_logit,
            d_xt_logit=d_xt_logit,
            ys_logit=new_ys_logit,
            xs_logit=new_xs_logit,
            d_yt_logit_mean=d_yt_mean,
            d_xt_logit_mean=d_xt_mean,
            ys_logit_mean=new_ys_mean,
            xs_logit_mean=new_xs_mean,
            d_yt_logit_std=d_yt_std,
            d_xt_logit_std=d_xt_std,
            ys_logit_std=ys_std,
            xs_logit_std=xs_std,
            glimpse_prime=glimpse_prime,
        )

        # --- attributes ---

        # --- extract a glimpse using new box ---

        if is_posterior:
            _, image_height, image_width, _ = tf_shape(inp)
            _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
                new_cyt,
                new_cxt,
                new_ys,
                new_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)

            glimpse = extract_affine_glimpse(inp, self.object_shape, _new_cyt,
                                             _new_cxt, _new_ys, _new_xs,
                                             self.edge_resampler)

        else:
            glimpse = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        encoded_glimpse = apply_object_wise(self.glimpse_encoder,
                                            glimpse,
                                            n_trailing_dims=3,
                                            output_size=self.A,
                                            is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros((batch_size, n_objects, self.A),
                                       dtype=tf.float32)

        # --- predict change in attributes ---

        d_attr_inp = tf.concat([base_features, box_params, encoded_glimpse],
                               axis=-1)
        d_attr_params = apply_object_wise(self.d_attr_network,
                                          d_attr_inp,
                                          output_size=2 * self.A + 1,
                                          is_training=self.is_training)

        d_attr_mean, d_attr_log_std, gate_logit = tf.split(d_attr_params,
                                                           [self.A, self.A, 1],
                                                           axis=-1)
        d_attr_std = self.std_nonlinearity(d_attr_log_std)

        gate = tf.nn.sigmoid(gate_logit)

        if self.gate_d_attr:
            d_attr_mean *= gate

        d_attr = Normal(loc=d_attr_mean, scale=d_attr_std).sample()

        # --- apply change in attributes ---

        new_attr = objects.attr + d_attr

        new_objects.update(
            attr=new_attr,
            d_attr=d_attr,
            d_attr_mean=d_attr_mean,
            d_attr_std=d_attr_std,
            glimpse=glimpse,
            d_attr_gate=gate,
        )

        # --- z ---

        d_z_inp = tf.concat(
            [base_features, box_params, new_attr, encoded_glimpse], axis=-1)
        d_z_params = apply_object_wise(self.d_z_network,
                                       d_z_inp,
                                       output_size=2,
                                       is_training=self.is_training)

        d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
        d_z_std = self.std_nonlinearity(d_z_log_std)

        d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (
            1 - self.training_wheels) * d_z_mean
        d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (
            1 - self.training_wheels) * d_z_std

        d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

        new_z_logit = objects.z_logit + d_z_logit
        new_z = self.z_nonlinearity(new_z_logit)

        new_objects.update(
            z=new_z,
            z_logit=new_z_logit,
            d_z_logit=d_z_logit,
            d_z_logit_mean=d_z_mean,
            d_z_logit_std=d_z_std,
        )

        # --- obj ---

        d_obj_inp = tf.concat(
            [base_features, box_params, new_attr, new_z, encoded_glimpse],
            axis=-1)
        d_obj_logit = apply_object_wise(self.d_obj_network,
                                        d_obj_inp,
                                        output_size=1,
                                        is_training=self.is_training)

        d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (
            1 - self.training_wheels) * d_obj_logit
        d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10.,
                                          10.)

        d_obj_pre_sigmoid = (
            self._noisy * concrete_binary_pre_sigmoid_sample(
                d_obj_log_odds, self.obj_concrete_temp)
            + (1 - self._noisy) * d_obj_log_odds)

        d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

        new_obj = objects.obj * d_obj

        new_objects.update(
            d_obj_log_odds=d_obj_log_odds,
            d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
            d_obj_pre_sigmoid=d_obj_pre_sigmoid,
            d_obj=d_obj,
            obj=new_obj,
        )

        # --- update each object's hidden state ---

        cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

        if is_posterior:
            _, new_objects.prop_state = apply_object_wise(
                self.cell, cell_input, objects.prop_state)
            new_objects.prior_prop_state = new_objects.prop_state
        else:
            _, new_objects.prior_prop_state = apply_object_wise(
                self.cell, cell_input, objects.prior_prop_state)
            new_objects.prop_state = new_objects.prior_prop_state

        return new_objects
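
tf_shape, used at the top of both propagation bodies, is assumed to be a small utility that prefers static shapes and falls back to dynamic ones; a typical sketch:

    def tf_shape(tensor):
        # Return the shape as a list, using the static size of each dimension
        # where it is known and the dynamic tf.shape() value where it is not.
        static = tensor.shape.as_list()
        dynamic = tf.shape(tensor)
        return [dynamic[i] if d is None else d for i, d in enumerate(static)]
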
Example #6
        def body(step, stopping_sum, prev_state, running_recon, kl_loss,
                 running_digits, scale_ta, scale_kl_ta, scale_std_ta, shift_ta,
                 shift_kl_ta, shift_std_ta, attr_ta, attr_kl_ta, attr_std_ta,
                 z_pres_ta, z_pres_probs_ta, z_pres_kl_ta, vae_input_ta,
                 vae_output_ta, scale, shift, attr, z_pres):

            if self.difference_air:
                inp = (self._tensors["inp"] -
                       tf.reshape(running_recon,
                                  (self.batch_size, *self.obs_shape)))
                encoded_inp = self.image_encoder(inp, 0, self.is_training)
                encoded_inp = tf.layers.flatten(encoded_inp)
            else:
                encoded_inp = self.encoded_inp

            if self.complete_rnn_input:
                rnn_input = tf.concat(
                    [encoded_inp, scale, shift, attr, z_pres], axis=1)
            else:
                rnn_input = encoded_inp

            hidden_rep, next_state = self.cell(rnn_input, prev_state)

            outputs = self.output_network(hidden_rep, 9, self.is_training)

            (scale_mean, scale_log_std, shift_mean, shift_log_std,
             z_pres_log_odds) = tf.split(outputs, [2, 2, 2, 2, 1], axis=1)

            # --- scale ---

            scale_std = tf.exp(scale_log_std)

            scale_mean = self.apply_fixed_value("scale_mean", scale_mean)
            scale_std = self.apply_fixed_value("scale_std", scale_std)

            scale_logits, scale_kl = normal_vae(scale_mean, scale_std,
                                                self.scale_prior_mean,
                                                self.scale_prior_std)
            scale_kl = tf.reduce_sum(scale_kl, axis=1, keepdims=True)
            scale = tf.nn.sigmoid(tf.clip_by_value(scale_logits, -10, 10))

            # --- shift ---

            shift_std = tf.exp(shift_log_std)

            shift_mean = self.apply_fixed_value("shift_mean", shift_mean)
            shift_std = self.apply_fixed_value("shift_std", shift_std)

            shift_logits, shift_kl = normal_vae(shift_mean, shift_std,
                                                self.shift_prior_mean,
                                                self.shift_prior_std)
            shift_kl = tf.reduce_sum(shift_kl, axis=1, keepdims=True)
            shift = tf.nn.tanh(tf.clip_by_value(shift_logits, -10, 10))

            # --- Extract windows from scene ---

            w, h = scale[:, 0:1], scale[:, 1:2]
            x, y = shift[:, 0:1], shift[:, 1:2]

            theta = tf.concat(
                [w, tf.zeros_like(w), x,
                 tf.zeros_like(h), h, y], axis=1)
            theta = tf.reshape(theta, (-1, 2, 3))

            vae_input = transformer(self._tensors["inp"], theta,
                                    self.object_shape)

            # This is a necessary reshape, as the output of transformer will have unknown dims
            vae_input = tf.reshape(
                vae_input,
                (self.batch_size, *self.object_shape, self.image_depth))

            # --- Apply Object-level VAE (object encoder/object decoder) to windows ---

            attr = self.object_encoder(vae_input, 2 * self.A, self.is_training)
            attr_mean, attr_log_std = tf.split(attr, 2, axis=1)
            attr_std = tf.exp(attr_log_std)
            attr, attr_kl = normal_vae(attr_mean, attr_std,
                                       self.attr_prior_mean,
                                       self.attr_prior_std)
            attr_kl = tf.reduce_sum(attr_kl, axis=1, keepdims=True)

            vae_output = self.object_decoder(
                attr,
                self.object_shape[0] * self.object_shape[1] * self.image_depth,
                self.is_training)
            vae_output = tf.nn.sigmoid(tf.clip_by_value(vae_output, -10, 10))

            # --- Place reconstructed objects in image ---

            theta_inverse = tf.concat(
                [1. / w, tf.zeros_like(w), -x / w,
                 tf.zeros_like(h), 1. / h, -y / h],
                axis=1)
            theta_inverse = tf.reshape(theta_inverse, (-1, 2, 3))

            vae_output_transformed = transformer(
                tf.reshape(vae_output, (
                    self.batch_size,
                    *self.object_shape,
                    self.image_depth,
                )), theta_inverse, self.obs_shape[:2])
            vae_output_transformed = tf.reshape(vae_output_transformed, [
                self.batch_size,
                self.image_height * self.image_width * self.image_depth
            ])

            # --- z_pres ---

            if self.run_all_time_steps:
                z_pres = tf.ones_like(z_pres_log_odds)
                z_pres_prob = tf.ones_like(z_pres_log_odds)
                z_pres_kl = tf.zeros_like(z_pres_log_odds)
            else:
                z_pres_log_odds = tf.clip_by_value(z_pres_log_odds, -10, 10)

                z_pres_pre_sigmoid = concrete_binary_pre_sigmoid_sample(
                    z_pres_log_odds, self.z_pres_temperature)
                z_pres = tf.nn.sigmoid(z_pres_pre_sigmoid)
                z_pres = (self.float_is_training * z_pres +
                          (1 - self.float_is_training) * tf.round(z_pres))
                z_pres_prob = tf.nn.sigmoid(z_pres_log_odds)
                z_pres_kl = concrete_binary_sample_kl(
                    z_pres_pre_sigmoid,
                    z_pres_log_odds,
                    self.z_pres_temperature,
                    self.z_pres_prior_log_odds,
                    self.z_pres_temperature,
                )

            stopping_sum += (1.0 - z_pres)
            alive = tf.less(stopping_sum, self.stopping_threshold)
            running_digits += tf.to_int32(alive)

            # --- adjust reconstruction ---

            running_recon += tf.where(
                tf.tile(alive, (1, vae_output_transformed.shape[1])),
                z_pres * vae_output_transformed, tf.zeros_like(running_recon))

            # --- add kl to loss ---

            kl_loss += tf.where(alive, scale_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, shift_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, attr_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, z_pres_kl, tf.zeros_like(kl_loss))

            # --- record values ---

            scale_ta = scale_ta.write(scale_ta.size(), scale)
            scale_kl_ta = scale_kl_ta.write(scale_kl_ta.size(), scale_kl)
            scale_std_ta = scale_std_ta.write(scale_std_ta.size(), scale_std)

            shift_ta = shift_ta.write(shift_ta.size(), shift)
            shift_kl_ta = shift_kl_ta.write(shift_kl_ta.size(), shift_kl)
            shift_std_ta = shift_std_ta.write(shift_std_ta.size(), shift_std)

            attr_ta = attr_ta.write(attr_ta.size(), attr)
            attr_kl_ta = attr_kl_ta.write(attr_kl_ta.size(), attr_kl)
            attr_std_ta = attr_std_ta.write(attr_std_ta.size(), attr_std)

            vae_input_ta = vae_input_ta.write(vae_input_ta.size(),
                                              tf.layers.flatten(vae_input))
            vae_output_ta = vae_output_ta.write(vae_output_ta.size(),
                                                vae_output)

            z_pres_ta = z_pres_ta.write(z_pres_ta.size(), z_pres)
            z_pres_probs_ta = z_pres_probs_ta.write(z_pres_probs_ta.size(),
                                                    z_pres_prob)
            z_pres_kl_ta = z_pres_kl_ta.write(z_pres_kl_ta.size(), z_pres_kl)

            return (
                step + 1,
                stopping_sum,
                next_state,
                running_recon,
                kl_loss,
                running_digits,
                scale_ta,
                scale_kl_ta,
                scale_std_ta,
                shift_ta,
                shift_kl_ta,
                shift_std_ta,
                attr_ta,
                attr_kl_ta,
                attr_std_ta,
                z_pres_ta,
                z_pres_probs_ta,
                z_pres_kl_ta,
                vae_input_ta,
                vae_output_ta,
                scale,
                shift,
                attr,
                z_pres,
            )
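
Example #6 additionally relies on normal_vae and concrete_binary_sample_kl. Neither is shown in the source; the sketches below assume the standard constructions: a reparameterized Gaussian sample with analytic diagonal-Gaussian KL, and a Binary Concrete KL estimated as a log-density ratio at the pre-sigmoid sample (Maddison et al.):

    def normal_vae(mean, std, prior_mean, prior_std):
        # Reparameterized sample from N(mean, std**2), plus the elementwise
        # analytic KL(N(mean, std**2) || N(prior_mean, prior_std**2)).
        sample = mean + std * tf.random_normal(tf.shape(mean))
        kl = (tf.log(prior_std) - tf.log(std)
              + (std ** 2 + (mean - prior_mean) ** 2) / (2. * prior_std ** 2)
              - 0.5)
        return sample, kl

    def concrete_binary_sample_kl(pre_sigmoid_sample,
                                  posterior_log_odds, posterior_temp,
                                  prior_log_odds, prior_temp):
        # Single-sample KL estimate: log q(y) - log p(y) evaluated at the
        # pre-sigmoid sample y, using the Binary Concrete log density
        # log p(y) = log(temp) + log_odds - temp * y
        #            - 2 * softplus(log_odds - temp * y).
        y = pre_sigmoid_sample

        def log_density(log_odds, temp):
            return (tf.log(temp) + log_odds - temp * y
                    - 2. * tf.nn.softplus(log_odds - temp * y))

        return (log_density(posterior_log_odds, posterior_temp)
                - log_density(prior_log_odds, prior_temp))
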