Example #1
    def _call(self, input_signal, input_locs, output_locs, is_training):
        if not self.is_built:
            self.value_func = self.build_mlp(scope="value_func")
            self.after_func = self.build_mlp(scope="after")

            if self.do_object_wise:
                self.object_wise_func = self.build_object_wise(
                    scope="object_wise")

            self.is_built = True

        batch_size, n_inp, _ = tf_shape(input_signal)
        loc_dim = tf_shape(input_locs)[-1]
        n_outp = tf_shape(output_locs)[-2]
        input_locs = tf.broadcast_to(input_locs, (batch_size, n_inp, loc_dim))
        output_locs = tf.broadcast_to(output_locs,
                                      (batch_size, n_outp, loc_dim))

        dist = output_locs[:, :, None, :] - input_locs[:, None, :, :]
        proximity = tf.exp(-0.5 * tf.reduce_sum(
            (dist / self.kernel_std)**2, axis=3))
        proximity = proximity / (2 * np.pi)**(0.5 * loc_dim) / self.kernel_std**loc_dim

        V = apply_object_wise(
            self.value_func,
            input_signal,
            output_size=self.n_hidden,
            is_training=is_training)  # (batch_size, n_inp, value_dim)

        result = tf.matmul(proximity, V)  # (batch_size, n_outp, value_dim)

        # `after_func` is applied to the aggregated result, followed by dropout and layer norm. Then, if
        # `do_object_wise` is True, `object_wise_func` is applied object-wise in a ResNet-style (residual)
        # manner.

        output = apply_object_wise(self.after_func,
                                   result,
                                   output_size=self.n_hidden,
                                   is_training=is_training)
        output = tf.layers.dropout(output,
                                   self.p_dropout,
                                   training=is_training)
        signal = tf.contrib.layers.layer_norm(output)

        if self.do_object_wise:
            output = apply_object_wise(self.object_wise_func,
                                       signal,
                                       output_size=self.n_hidden,
                                       is_training=is_training)
            output = tf.layers.dropout(output,
                                       self.p_dropout,
                                       training=is_training)
            signal = tf.contrib.layers.layer_norm(signal + output)

        return signal
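
The proximity weighting above is an isotropic Gaussian kernel evaluated on the displacement between each output location and each input location, and the matmul aggregates the per-input value vectors under those weights. A minimal numpy sketch of the same computation (illustrative only; the shapes, `kernel_std` and the random `V` are stand-ins for the real value network output):

import numpy as np

batch_size, n_inp, n_outp, loc_dim, n_hidden = 2, 5, 3, 2, 4
kernel_std = 0.1

input_locs = np.random.rand(batch_size, n_inp, loc_dim)
output_locs = np.random.rand(batch_size, n_outp, loc_dim)
V = np.random.randn(batch_size, n_inp, n_hidden)   # stand-in for value_func output

# pairwise displacements and Gaussian kernel density, as in the snippet above
dist = output_locs[:, :, None, :] - input_locs[:, None, :, :]        # (B, n_outp, n_inp, loc_dim)
proximity = np.exp(-0.5 * np.sum((dist / kernel_std) ** 2, axis=3))
proximity = proximity / (2 * np.pi) ** (0.5 * loc_dim) / kernel_std ** loc_dim

result = proximity @ V                                               # (B, n_outp, n_hidden)
print(result.shape)
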
Example #2
    def _call(self, input_locs, input_features, reference_locs,
              reference_features, is_training):
        """ Assumes input_features and reference_features are identical. """
        assert self.do_object_wise

        if not self.is_built:
            self.object_wise_func = self.build_mlp(scope="object_wise_func")
            self.is_built = True

        object_wise = apply_object_wise(
            self.object_wise_func,
            reference_features,
            output_size=self.n_hidden,
            is_training=is_training)  # (B, n_ref, n_hidden)

        return object_wise
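
Several of these examples call a helper `apply_object_wise` that is not shown here. Judging from how it is used, it presumably flattens all leading "object" dimensions, applies the given function to the flattened batch, and restores the leading shape. A minimal sketch under that assumption (the toy linear map stands in for a real MLP):

import numpy as np

def apply_object_wise_sketch(func, signal, output_size):
    # collapse every leading "object" axis into one batch axis, apply func once,
    # then restore the original leading shape with the new trailing size
    *leading, feature_dim = signal.shape
    flat = signal.reshape(-1, feature_dim)
    out = func(flat, output_size)                  # (total_objects, output_size)
    return out.reshape(*leading, output_size)

def toy_mlp(x, output_size):
    # stand-in for a learned network: a fixed linear projection plus nonlinearity
    rng = np.random.RandomState(0)
    W = rng.randn(x.shape[-1], output_size)
    return np.tanh(x @ W)

signal = np.random.randn(2, 5, 7)                  # (batch, n_objects, features)
print(apply_object_wise_sketch(toy_mlp, signal, output_size=3).shape)   # (2, 5, 3)
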
Example #3
    def _loop_body(self, f, conf, hidden_states, *tensor_arrays):
        batch_size = self.batch_size

        memory = self.encoded_frames[:, f]  # (batch_size, H*W, S)

        delta = 0.0001 * tf.range(self.n_trackers, dtype=tf.float32)[:, None, None]
        sort_criteria = tf.round(conf) - delta

        sorted_order = tf.contrib.framework.argsort(sort_criteria,
                                                    axis=0,
                                                    direction='DESCENDING')
        sorted_order = tf.reshape(sorted_order,
                                  (self.n_trackers, batch_size, 1))

        order = tf.cond(
            tf.logical_or(tf.equal(f, 0), not self.prioritize),
            lambda: tf.tile(
                tf.range(self.n_trackers)[:, None, None], (1, batch_size, 1)),
            lambda: sorted_order,
        )
        order = tf.reshape(order, (self.n_trackers, batch_size, 1))

        inverse_order = tf.contrib.framework.argsort(order,
                                                     axis=0,
                                                     direction='ASCENDING')

        tensors = defaultdict(list)

        for i in range(self.n_trackers):
            tensors["memory_activation"].append(
                tf.reduce_mean(tf.abs(memory), axis=2))

            # --- apply ordering if applicable ---

            indices = order[i]  # (batch_size, 1)
            indexor = tf.concat(
                [indices, tf.range(batch_size)[:, None]],
                axis=1)  # (batch_size, 2)
            _hidden_states = tf.gather_nd(hidden_states,
                                          indexor)  # (batch_size, n_hidden)

            # --- access the memory using spatial attention ---

            keys = self.key_network(_hidden_states, self.S,
                                    self.is_training)  # (batch_size, self.S)
            beta_logit = self.beta_network(_hidden_states, 1,
                                           self.is_training)  # (batch_size, 1)

            # beta = 1 + tf.math.softplus(beta_logit)

            beta_pos = tf.maximum(0.0, beta_logit)
            beta_neg = tf.minimum(0.0, beta_logit)
            beta = (tf.log1p(tf.exp(beta_neg)) + beta_pos
                    + tf.log1p(tf.exp(-beta_pos)) + (1 - np.log(2)))

            _memory = tf.identity(memory)
            _memory = limit_grad_norm(_memory, 1.)

            key_activation = beta * tf_cosine_similarity(
                _memory, keys[:, None, :])  # (batch_size, H*W)
            attention_weights = tf.nn.softmax(
                key_activation, axis=1)[:, :, None]  # (batch_size, H*W, 1)

            _attention_weights = tf.identity(attention_weights)
            _attention_weights = limit_grad_norm(_attention_weights, 1.)

            attention_result = tf.reduce_sum(_attention_weights * memory,
                                             axis=1)  # (batch_size, S)

            # --- update tracker hidden state and output ---

            tracker_output, new_hidden = self.cell(attention_result,
                                                   _hidden_states)

            # --- update the memory for the next trackers ---

            write = self.write_network(tracker_output, self.S,
                                       self.is_training)
            erase = self.erase_network(tracker_output, self.S,
                                       self.is_training)
            erase = tf.nn.sigmoid(erase)

            memory = ((1 - attention_weights * erase[:, None, :]) * memory +
                      attention_weights * write[:, None, :])

            tensors["hidden_states"].append(new_hidden)
            tensors["tracker_output"].append(tracker_output)
            tensors["attention_result"].append(attention_result)
            tensors["attention_weights"].append(attention_weights[..., 0])

        tensors = {k: tf.stack(v, axis=0) for k, v in tensors.items()}

        # --- invert the ordering ---

        batch_indices = tf.tile(
            tf.range(batch_size)[None, :, None], (self.n_trackers, 1, 1))
        inverse_indexor = tf.concat([inverse_order, batch_indices],
                                    axis=2)  # (n_trackers, batch_size, 2)
        tensors = {
            k: tf.gather_nd(v, inverse_indexor)
            for k, v in tensors.items()
        }

        # --- compute the output values ---

        output = apply_object_wise(self.output_network,
                                   tensors["tracker_output"],
                                   output_size=self.output_size_per_object,
                                   is_training=self.is_training)

        conf, layer, pose, mask, appearance = tf.split(
            output,
            [1, self.n_layers, 4,
             np.prod(self.object_shape),
             self.image_depth * np.prod(self.object_shape)],
            axis=-1)

        conf = tf.abs(tf.nn.tanh(conf))

        conf = (self.float_is_training * conf +
                (1 - self.float_is_training) * tf.round(conf))

        layer = tf.nn.softmax(layer, axis=-1)
        layer = tf.transpose(layer, (1, 0, 2))
        layer = limit_grad_norm(layer, 10.)
        layer = tf.transpose(layer, (1, 0, 2))
        layer = tfp.distributions.RelaxedOneHotCategorical(
            self.layer_temperature, probs=layer).sample()

        pose = tf.nn.tanh(pose)

        mask = tfp.distributions.RelaxedBernoulli(self.mask_temperature,
                                                  logits=mask).sample()

        if self.fixed_mask:
            mask = tf.ones_like(mask)

        appearance = tf.nn.sigmoid(appearance)

        output = dict(conf=conf,
                      layer=layer,
                      pose=pose,
                      mask=mask,
                      appearance=appearance,
                      order=order,
                      **tensors)

        tensor_arrays = append_to_tensor_arrays(f, output, tensor_arrays)

        f += 1

        return [f, conf, tensors["hidden_states"], *tensor_arrays]
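
The `beta` computation in the loop above is a numerically stable rewrite of the commented-out `beta = 1 + tf.math.softplus(beta_logit)`: splitting the logit into positive and negative parts avoids exponentiating large positive values. A quick numpy check of that identity (illustrative only):

import numpy as np

def one_plus_softplus_stable(x):
    # same algebra as the beta computation above: equals 1 + log(1 + exp(x)),
    # but never exponentiates a large positive number
    x_pos = np.maximum(0.0, x)
    x_neg = np.minimum(0.0, x)
    return np.log1p(np.exp(x_neg)) + x_pos + np.log1p(np.exp(-x_pos)) + (1 - np.log(2))

x = np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
print(one_plus_softplus_stable(x))     # [1.0, ~1.31, ~1.69, ~2.31, 1001.0]
print(1.0 + np.log1p(np.exp(x)))       # naive form: exp overflows at x = 1000 and gives inf
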
Example #4
    def _body(self, inp, features, objects, is_posterior):
        """
        Summary of how updates are done for the different variables:

        glimpse': glimpse_params = where + 0.1 * predicted_logit

        where_y/x: new_where_y/x = where_y/x + where_t_scale * tanh(predicted_logit)
        where_h/w: new_where_h/w_logit = where_h/w_logit + predicted_logit
        what: Hard to summarize here, taken from SQAIR. Kind of like an LSTM.
        depth: new_depth_logit = depth_logit + predicted_logit
        obj: new_obj = obj * sigmoid(predicted_logit)

        """
        batch_size, n_objects, _ = tf_shape(features)

        new_objects = AttrDict()

        is_posterior_tf = tf.ones_like(features[..., 0:2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        base_features = tf.concat([features, is_posterior_tf], axis=-1)

        cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

        # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
        glimpse_dim = self.object_shape[0] * self.object_shape[1]
        glimpse_prime_params = apply_object_wise(
            self.glimpse_prime_network,
            base_features,
            output_size=4 + 2 * glimpse_dim,
            is_training=self.is_training)

        glimpse_prime_params, glimpse_prime_mask_logit, glimpse_mask_logit = \
            tf.split(glimpse_prime_params, [4, glimpse_dim, glimpse_dim], axis=-1)

        if is_posterior:
            # --- obtain final parameters for glimpse prime by modifying current pose ---

            _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

            g_yt = cyt + 0.1 * _yt
            g_xt = cxt + 0.1 * _xt
            g_ys = ys + 0.1 * _ys
            g_xs = xs + 0.1 * _xs

            # --- extract glimpse prime ---

            _, image_height, image_width, _ = tf_shape(inp)
            g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
                g_yt,
                g_xt,
                g_ys,
                g_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)
            glimpse_prime = extract_affine_glimpse(inp, self.object_shape,
                                                   g_yt, g_xt, g_ys, g_xs,
                                                   self.edge_resampler)
        else:
            g_yt = tf.zeros_like(cyt)
            g_xt = tf.zeros_like(cxt)
            g_ys = tf.zeros_like(ys)
            g_xs = tf.zeros_like(xs)
            glimpse_prime = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        glimpse_prime_mask = tf.nn.sigmoid(glimpse_prime_mask_logit + 1.)
        leading_mask_shape = tf_shape(glimpse_prime)[:-1]
        glimpse_prime_mask = tf.reshape(glimpse_prime_mask,
                                        (*leading_mask_shape, 1))

        new_objects.update(
            glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1),
            glimpse_prime=glimpse_prime,
            glimpse_prime_mask=glimpse_prime_mask,
        )

        glimpse_prime *= glimpse_prime_mask

        # --- encode glimpse ---

        encoded_glimpse_prime = apply_object_wise(self.glimpse_prime_encoder,
                                                  glimpse_prime,
                                                  n_trailing_dims=3,
                                                  output_size=self.A,
                                                  is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A),
                                             dtype=tf.float32)

        # --- position and scale ---

        # roughly:
        # base_features == temporal_state, encoded_glimpse_prime == hidden_output
        # hidden_output conditions on encoded_glimpse, and that's the only place encoded_glimpse_prime is used.

        # Here SQAIR conditions on the actual location values from the previous timestep, but we leave that out for now.
        d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
        d_box_params = apply_object_wise(self.d_box_network,
                                         d_box_inp,
                                         output_size=8,
                                         is_training=self.is_training)

        d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

        d_box_std = self.std_nonlinearity(d_box_log_std)

        d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (
            1 - self.training_wheels) * d_box_mean
        d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (
            1 - self.training_wheels) * d_box_std

        d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
        d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

        # --- position ---

        # We predict position a bit differently from scale. For scale we want to put a prior on the actual value of
        # the scale, whereas for position we want to put a prior on the difference in position over timesteps.

        d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
        d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

        d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
        d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

        new_cyt = cyt + d_yt
        new_cxt = cxt + d_xt

        new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

        # --- scale ---

        new_ys_mean = objects.ys_logit + d_ys
        new_xs_mean = objects.xs_logit + d_xs

        new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
        new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

        new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
        new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

        if self.use_abs_posn:
            box_params = tf.concat([
                new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit,
                new_xs_logit
            ],
                                   axis=-1)
        else:
            box_params = tf.concat(
                [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

        new_objects.update(
            abs_posn=new_abs_posn,
            yt=new_cyt,
            xt=new_cxt,
            ys=new_ys,
            xs=new_xs,
            normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs],
                                     axis=-1),
            d_yt_logit=d_yt_logit,
            d_xt_logit=d_xt_logit,
            ys_logit=new_ys_logit,
            xs_logit=new_xs_logit,
            d_yt_logit_mean=d_yt_mean,
            d_xt_logit_mean=d_xt_mean,
            ys_logit_mean=new_ys_mean,
            xs_logit_mean=new_xs_mean,
            d_yt_logit_std=d_yt_std,
            d_xt_logit_std=d_xt_std,
            ys_logit_std=ys_std,
            xs_logit_std=xs_std,
        )

        # --- attributes ---

        # --- extract a glimpse using new box ---

        if is_posterior:
            _, image_height, image_width, _ = tf_shape(inp)
            _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
                new_cyt,
                new_cxt,
                new_ys,
                new_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)

            glimpse = extract_affine_glimpse(inp, self.object_shape, _new_cyt,
                                             _new_cxt, _new_ys, _new_xs,
                                             self.edge_resampler)

        else:
            glimpse = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        glimpse_mask = tf.nn.sigmoid(glimpse_mask_logit + 1.)
        leading_mask_shape = tf_shape(glimpse)[:-1]
        glimpse_mask = tf.reshape(glimpse_mask, (*leading_mask_shape, 1))

        glimpse *= glimpse_mask

        encoded_glimpse = apply_object_wise(self.glimpse_encoder,
                                            glimpse,
                                            n_trailing_dims=3,
                                            output_size=self.A,
                                            is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros((batch_size, n_objects, self.A),
                                       dtype=tf.float32)

        # --- predict change in attributes ---

        # Under SQAIR, we mix between three different values for the attributes:
        # 1. value from previous timestep
        # 2. value predicted directly from glimpse
        # 3. value predicted based on update of temporal cell...this update conditions on hidden_output,
        # the prediction in #2., and the where values.

        # How to do this given that we are predicting the change in attr? We could just directly predict
        # the attr instead, but call it d_attr. After all, it is in this function that we control
        # whether d_attr is added to attr.

        # So, make a prediction based on just the input:

        attr_from_inp = apply_object_wise(self.predict_attr_inp,
                                          encoded_glimpse,
                                          output_size=2 * self.A,
                                          is_training=self.is_training)
        attr_from_inp_mean, attr_from_inp_log_std = tf.split(attr_from_inp,
                                                             [self.A, self.A],
                                                             axis=-1)

        attr_from_inp_std = self.std_nonlinearity(attr_from_inp_log_std)

        # And then a prediction which takes the past into account (predicting gate values at the same time):

        attr_from_temp_inp = tf.concat(
            [base_features, box_params, encoded_glimpse], axis=-1)
        attr_from_temp = apply_object_wise(self.predict_attr_temp,
                                           attr_from_temp_inp,
                                           output_size=5 * self.A,
                                           is_training=self.is_training)

        (attr_from_temp_mean, attr_from_temp_log_std, f_gate_logit,
         i_gate_logit, t_gate_logit) = tf.split(attr_from_temp, 5, axis=-1)

        attr_from_temp_std = self.std_nonlinearity(attr_from_temp_log_std)

        # bias the gates
        f_gate = tf.nn.sigmoid(f_gate_logit + 1) * .9999
        i_gate = tf.nn.sigmoid(i_gate_logit + 1) * .9999
        t_gate = tf.nn.sigmoid(t_gate_logit + 1) * .9999

        attr_mean = (f_gate * objects.attr
                     + (1 - i_gate) * attr_from_inp_mean
                     + (1 - t_gate) * attr_from_temp_mean)
        attr_std = ((1 - i_gate) * attr_from_inp_std
                    + (1 - t_gate) * attr_from_temp_std)

        new_attr = Normal(loc=attr_mean, scale=attr_std).sample()

        # --- apply change in attributes ---

        new_objects.update(
            attr=new_attr,
            d_attr=new_attr - objects.attr,
            d_attr_mean=attr_mean - objects.attr,
            d_attr_std=attr_std,
            f_gate=f_gate,
            i_gate=i_gate,
            t_gate=t_gate,
            glimpse=glimpse,
            glimpse_mask=glimpse_mask,
        )

        # --- z ---

        d_z_inp = tf.concat(
            [base_features, box_params, new_attr, encoded_glimpse], axis=-1)
        d_z_params = apply_object_wise(self.d_z_network,
                                       d_z_inp,
                                       output_size=2,
                                       is_training=self.is_training)

        d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
        d_z_std = self.std_nonlinearity(d_z_log_std)

        d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (
            1 - self.training_wheels) * d_z_mean
        d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (
            1 - self.training_wheels) * d_z_std

        d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

        new_z_logit = objects.z_logit + d_z_logit
        new_z = self.z_nonlinearity(new_z_logit)

        new_objects.update(
            z=new_z,
            z_logit=new_z_logit,
            d_z_logit=d_z_logit,
            d_z_logit_mean=d_z_mean,
            d_z_logit_std=d_z_std,
        )

        # --- obj ---

        d_obj_inp = tf.concat(
            [base_features, box_params, new_attr, new_z, encoded_glimpse],
            axis=-1)
        d_obj_logit = apply_object_wise(self.d_obj_network,
                                        d_obj_inp,
                                        output_size=1,
                                        is_training=self.is_training)

        d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (
            1 - self.training_wheels) * d_obj_logit
        d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10.,
                                          10.)

        d_obj_pre_sigmoid = (self._noisy * concrete_binary_pre_sigmoid_sample(
            d_obj_log_odds, self.obj_concrete_temp) +
                             (1 - self._noisy) * d_obj_log_odds)

        d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

        new_obj = objects.obj * d_obj

        new_objects.update(
            d_obj_log_odds=d_obj_log_odds,
            d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
            d_obj_pre_sigmoid=d_obj_pre_sigmoid,
            d_obj=d_obj,
            obj=new_obj,
        )

        # --- update each object's hidden state --

        cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

        if is_posterior:
            _, new_objects.prop_state = apply_object_wise(
                self.cell, cell_input, objects.prop_state)
            new_objects.prior_prop_state = new_objects.prop_state
        else:
            _, new_objects.prior_prop_state = apply_object_wise(
                self.cell, cell_input, objects.prior_prop_state)
            new_objects.prop_state = new_objects.prior_prop_state

        return new_objects
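
The recurring pattern `training_wheels * tf.stop_gradient(x) + (1 - training_wheels) * x` (used above for the box, z and obj predictions) leaves the forward value of `x` unchanged while scaling the gradient that flows back into `x` by `(1 - training_wheels)`. A small sketch of the idea (illustrative; `np.copy` stands in for `tf.stop_gradient`, whose effect only shows up in the backward pass):

import numpy as np

def training_wheels_mix(x, w):
    # forward: w * x + (1 - w) * x == x, for any w in [0, 1]
    # backward: the first term is treated as a constant, so dL/dx is scaled by (1 - w)
    x_detached = np.copy(x)   # stands in for tf.stop_gradient(x)
    return w * x_detached + (1 - w) * x

x = np.array([0.3, -1.2])
for w in [0.0, 0.5, 1.0]:
    print(w, training_wheels_mix(x, w))   # same forward output for every w
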
Example #5
    def _body(self, inp, features, objects, is_posterior):
        batch_size, n_objects, _ = tf_shape(features)

        new_objects = AttrDict()

        is_posterior_tf = tf.ones_like(features[..., 0:2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        base_features = tf.concat([features, is_posterior_tf], axis=-1)

        cyt, cxt, ys, xs = tf.split(objects.normalized_box, 4, axis=-1)

        if self.learn_glimpse_prime:
            # Do this regardless of is_posterior, otherwise ScopedFunction gets messed up
            glimpse_prime_params = apply_object_wise(
                self.glimpse_prime_network,
                base_features,
                output_size=4,
                is_training=self.is_training)
        else:
            glimpse_prime_params = tf.zeros_like(base_features[..., :4])

        if is_posterior:

            if self.learn_glimpse_prime:
                # --- obtain final parameters for glimpse prime by modifying current pose ---
                _yt, _xt, _ys, _xs = tf.split(glimpse_prime_params, 4, axis=-1)

                # This is how it is done in SQAIR
                g_yt = cyt + 0.1 * _yt
                g_xt = cxt + 0.1 * _xt
                g_ys = ys + 0.1 * _ys
                g_xs = xs + 0.1 * _xs

                # g_yt = cyt + self.glimpse_prime_scale * tf.nn.tanh(_yt)
                # g_xt = cxt + self.glimpse_prime_scale * tf.nn.tanh(_xt)
                # g_ys = ys + self.glimpse_prime_scale * tf.nn.tanh(_ys)
                # g_xs = xs + self.glimpse_prime_scale * tf.nn.tanh(_xs)
            else:
                g_yt = cyt
                g_xt = cxt
                g_ys = self.glimpse_prime_scale * ys
                g_xs = self.glimpse_prime_scale * xs

            # --- extract glimpse prime ---

            _, image_height, image_width, _ = tf_shape(inp)
            g_yt, g_xt, g_ys, g_xs = coords_to_image_space(
                g_yt,
                g_xt,
                g_ys,
                g_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)
            glimpse_prime = extract_affine_glimpse(inp, self.object_shape,
                                                   g_yt, g_xt, g_ys, g_xs,
                                                   self.edge_resampler)
        else:
            g_yt = tf.zeros_like(cyt)
            g_xt = tf.zeros_like(cxt)
            g_ys = tf.zeros_like(ys)
            g_xs = tf.zeros_like(xs)
            glimpse_prime = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        new_objects.update(
            glimpse_prime_box=tf.concat([g_yt, g_xt, g_ys, g_xs], axis=-1),
        )

        # --- encode glimpse ---

        encoded_glimpse_prime = apply_object_wise(self.glimpse_prime_encoder,
                                                  glimpse_prime,
                                                  n_trailing_dims=3,
                                                  output_size=self.A,
                                                  is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse_prime = tf.zeros((batch_size, n_objects, self.A),
                                             dtype=tf.float32)

        # --- position and scale ---

        d_box_inp = tf.concat([base_features, encoded_glimpse_prime], axis=-1)
        d_box_params = apply_object_wise(self.d_box_network,
                                         d_box_inp,
                                         output_size=8,
                                         is_training=self.is_training)

        d_box_mean, d_box_log_std = tf.split(d_box_params, 2, axis=-1)

        d_box_std = self.std_nonlinearity(d_box_log_std)

        d_box_mean = self.training_wheels * tf.stop_gradient(d_box_mean) + (
            1 - self.training_wheels) * d_box_mean
        d_box_std = self.training_wheels * tf.stop_gradient(d_box_std) + (
            1 - self.training_wheels) * d_box_std

        d_yt_mean, d_xt_mean, d_ys, d_xs = tf.split(d_box_mean, 4, axis=-1)
        d_yt_std, d_xt_std, ys_std, xs_std = tf.split(d_box_std, 4, axis=-1)

        # --- position ---

        # We predict position a bit differently from scale. For scale we want to put a prior on the actual value of
        # the scale, whereas for position we want to put a prior on the difference in position over timesteps.

        d_yt_logit = Normal(loc=d_yt_mean, scale=d_yt_std).sample()
        d_xt_logit = Normal(loc=d_xt_mean, scale=d_xt_std).sample()

        d_yt = self.where_t_scale * tf.nn.tanh(d_yt_logit)
        d_xt = self.where_t_scale * tf.nn.tanh(d_xt_logit)

        new_cyt = cyt + d_yt
        new_cxt = cxt + d_xt

        new_abs_posn = objects.abs_posn + tf.concat([d_yt, d_xt], axis=-1)

        # --- scale ---

        new_ys_mean = objects.ys_logit + d_ys
        new_xs_mean = objects.xs_logit + d_xs

        new_ys_logit = Normal(loc=new_ys_mean, scale=ys_std).sample()
        new_xs_logit = Normal(loc=new_xs_mean, scale=xs_std).sample()

        new_ys = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_ys_logit, -10, 10)) + self.min_hw
        new_xs = float(self.max_hw - self.min_hw) * tf.nn.sigmoid(
            tf.clip_by_value(new_xs_logit, -10, 10)) + self.min_hw

        # Used for conditioning
        if self.use_abs_posn:
            box_params = tf.concat([
                new_abs_posn, d_yt_logit, d_xt_logit, new_ys_logit,
                new_xs_logit
            ],
                                   axis=-1)
        else:
            box_params = tf.concat(
                [d_yt_logit, d_xt_logit, new_ys_logit, new_xs_logit], axis=-1)

        new_objects.update(
            abs_posn=new_abs_posn,
            yt=new_cyt,
            xt=new_cxt,
            ys=new_ys,
            xs=new_xs,
            normalized_box=tf.concat([new_cyt, new_cxt, new_ys, new_xs],
                                     axis=-1),
            d_yt_logit=d_yt_logit,
            d_xt_logit=d_xt_logit,
            ys_logit=new_ys_logit,
            xs_logit=new_xs_logit,
            d_yt_logit_mean=d_yt_mean,
            d_xt_logit_mean=d_xt_mean,
            ys_logit_mean=new_ys_mean,
            xs_logit_mean=new_xs_mean,
            d_yt_logit_std=d_yt_std,
            d_xt_logit_std=d_xt_std,
            ys_logit_std=ys_std,
            xs_logit_std=xs_std,
            glimpse_prime=glimpse_prime,
        )

        # --- attributes ---

        # --- extract a glimpse using new box ---

        if is_posterior:
            _, image_height, image_width, _ = tf_shape(inp)
            _new_cyt, _new_cxt, _new_ys, _new_xs = coords_to_image_space(
                new_cyt,
                new_cxt,
                new_ys,
                new_xs, (image_height, image_width),
                self.anchor_box,
                top_left=False)

            glimpse = extract_affine_glimpse(inp, self.object_shape, _new_cyt,
                                             _new_cxt, _new_ys, _new_xs,
                                             self.edge_resampler)

        else:
            glimpse = tf.zeros(
                (batch_size, n_objects, *self.object_shape, self.image_depth))

        encoded_glimpse = apply_object_wise(self.glimpse_encoder,
                                            glimpse,
                                            n_trailing_dims=3,
                                            output_size=self.A,
                                            is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros((batch_size, n_objects, self.A),
                                       dtype=tf.float32)

        # --- predict change in attributes ---

        d_attr_inp = tf.concat([base_features, box_params, encoded_glimpse],
                               axis=-1)
        d_attr_params = apply_object_wise(self.d_attr_network,
                                          d_attr_inp,
                                          output_size=2 * self.A + 1,
                                          is_training=self.is_training)

        d_attr_mean, d_attr_log_std, gate_logit = tf.split(d_attr_params,
                                                           [self.A, self.A, 1],
                                                           axis=-1)
        d_attr_std = self.std_nonlinearity(d_attr_log_std)

        gate = tf.nn.sigmoid(gate_logit)

        if self.gate_d_attr:
            d_attr_mean *= gate

        d_attr = Normal(loc=d_attr_mean, scale=d_attr_std).sample()

        # --- apply change in attributes ---

        new_attr = objects.attr + d_attr

        new_objects.update(
            attr=new_attr,
            d_attr=d_attr,
            d_attr_mean=d_attr_mean,
            d_attr_std=d_attr_std,
            glimpse=glimpse,
            d_attr_gate=gate,
        )

        # --- z ---

        d_z_inp = tf.concat(
            [base_features, box_params, new_attr, encoded_glimpse], axis=-1)
        d_z_params = apply_object_wise(self.d_z_network,
                                       d_z_inp,
                                       output_size=2,
                                       is_training=self.is_training)

        d_z_mean, d_z_log_std = tf.split(d_z_params, 2, axis=-1)
        d_z_std = self.std_nonlinearity(d_z_log_std)

        d_z_mean = self.training_wheels * tf.stop_gradient(d_z_mean) + (
            1 - self.training_wheels) * d_z_mean
        d_z_std = self.training_wheels * tf.stop_gradient(d_z_std) + (
            1 - self.training_wheels) * d_z_std

        d_z_logit = Normal(loc=d_z_mean, scale=d_z_std).sample()

        new_z_logit = objects.z_logit + d_z_logit
        new_z = self.z_nonlinearity(new_z_logit)

        new_objects.update(
            z=new_z,
            z_logit=new_z_logit,
            d_z_logit=d_z_logit,
            d_z_logit_mean=d_z_mean,
            d_z_logit_std=d_z_std,
        )

        # --- obj ---

        d_obj_inp = tf.concat(
            [base_features, box_params, new_attr, new_z, encoded_glimpse],
            axis=-1)
        d_obj_logit = apply_object_wise(self.d_obj_network,
                                        d_obj_inp,
                                        output_size=1,
                                        is_training=self.is_training)

        d_obj_logit = self.training_wheels * tf.stop_gradient(d_obj_logit) + (
            1 - self.training_wheels) * d_obj_logit
        d_obj_log_odds = tf.clip_by_value(d_obj_logit / self.obj_temp, -10.,
                                          10.)

        d_obj_pre_sigmoid = (self._noisy * concrete_binary_pre_sigmoid_sample(
            d_obj_log_odds, self.obj_concrete_temp) +
                             (1 - self._noisy) * d_obj_log_odds)

        d_obj = tf.nn.sigmoid(d_obj_pre_sigmoid)

        new_obj = objects.obj * d_obj

        new_objects.update(
            d_obj_log_odds=d_obj_log_odds,
            d_obj_prob=tf.nn.sigmoid(d_obj_log_odds),
            d_obj_pre_sigmoid=d_obj_pre_sigmoid,
            d_obj=d_obj,
            obj=new_obj,
        )

        # --- update each object's hidden state --

        cell_input = tf.concat([box_params, new_attr, new_z, new_obj], axis=-1)

        if is_posterior:
            _, new_objects.prop_state = apply_object_wise(
                self.cell, cell_input, objects.prop_state)
            new_objects.prior_prop_state = new_objects.prop_state
        else:
            _, new_objects.prior_prop_state = apply_object_wise(
                self.cell, cell_input, objects.prior_prop_state)
            new_objects.prop_state = new_objects.prior_prop_state

        return new_objects
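
The scale update above squashes an unconstrained logit into the interval (min_hw, max_hw) via a clipped sigmoid; the clip presumably just keeps the logit in a numerically comfortable range. A small numpy sketch (the min_hw/max_hw values are made up):

import numpy as np

def scale_from_logit(logit, min_hw, max_hw):
    # squash an unconstrained logit into (min_hw, max_hw), as for new_ys / new_xs above
    s = 1.0 / (1.0 + np.exp(-np.clip(logit, -10, 10)))
    return (max_hw - min_hw) * s + min_hw

logits = np.array([-50.0, 0.0, 50.0])
print(scale_from_logit(logits, min_hw=0.1, max_hw=0.9))   # ~[0.1, 0.5, 0.9]
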
Example #6
    def _call(self, inp, inp_features, is_training, is_posterior=True, prop_state=None):
        print("\n" + "-" * 10 + " ConvGridObjectLayer(is_posterior={}) ".format(is_posterior) + "-" * 10)

        # --- set up sub networks and attributes ---

        self.maybe_build_subnet("box_network", builder=cfg.build_conv_lateral, key="box")
        self.maybe_build_subnet("attr_network", builder=cfg.build_conv_lateral, key="attr")
        self.maybe_build_subnet("z_network", builder=cfg.build_conv_lateral, key="z")
        self.maybe_build_subnet("obj_network", builder=cfg.build_conv_lateral, key="obj")

        self.maybe_build_subnet("object_encoder")

        _, H, W, _, n_channels = tf_shape(inp_features)

        if self.B != 1:
            raise Exception("NotImplemented")

        if not self.initialized:
            # Note this limits the re-usability of this module to images
            # with a fixed shape (the shape of the first image it is used on)
            self.batch_size, self.image_height, self.image_width, self.image_depth = tf_shape(inp)
            self.H = H
            self.W = W
            self.HWB = H*W
            self.batch_size = tf.shape(inp)[0]
            self.is_training = is_training
            self.float_is_training = tf.to_float(is_training)

        inp_features = tf.reshape(inp_features, (self.batch_size, H, W, n_channels))
        is_posterior_tf = tf.ones_like(inp_features[..., :2])
        if is_posterior:
            is_posterior_tf = is_posterior_tf * [1, 0]
        else:
            is_posterior_tf = is_posterior_tf * [0, 1]

        objects = AttrDict()

        base_features = tf.concat([inp_features, is_posterior_tf], axis=-1)

        # --- box ---

        layer_inp = base_features
        n_features = self.n_passthrough_features
        output_size = 8

        network_output = self.box_network(layer_inp, output_size + n_features, self.is_training)
        rep_input, features = tf.split(network_output, (output_size, n_features), axis=-1)

        _objects = self._build_box(rep_input, self.is_training)
        objects.update(_objects)

        # --- attr ---

        if is_posterior:
            # --- Get object attributes using object encoder ---

            yt, xt, ys, xs = tf.split(objects['normalized_box'], 4, axis=-1)

            yt, xt, ys, xs = coords_to_image_space(
                yt, xt, ys, xs, (self.image_height, self.image_width), self.anchor_box, top_left=False)

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
            warper = snt.AffineGridWarper(
                (self.image_height, self.image_width), self.object_shape, transform_constraints)

            _boxes = tf.concat([xs, 2*xt - 1, ys, 2*yt - 1], axis=-1)
            _boxes = tf.reshape(_boxes, (self.batch_size*H*W, 4))
            grid_coords = warper(_boxes)
            grid_coords = tf.reshape(grid_coords, (self.batch_size, H, W, *self.object_shape, 2,))

            if self.edge_resampler:
                glimpse = resampler_edge.resampler_edge(inp, grid_coords)
            else:
                glimpse = tf.contrib.resampler.resampler(inp, grid_coords)
        else:
            glimpse = tf.zeros((self.batch_size, H, W, *self.object_shape, self.image_depth))

        # Create the object encoder network regardless of is_posterior, otherwise messes with ScopedFunction
        encoded_glimpse = apply_object_wise(
            self.object_encoder, glimpse, n_trailing_dims=3, output_size=self.A, is_training=self.is_training)

        if not is_posterior:
            encoded_glimpse = tf.zeros_like(encoded_glimpse)

        layer_inp = tf.concat([base_features, features, encoded_glimpse, objects['local_box']], axis=-1)
        network_output = self.attr_network(layer_inp, 2 * self.A + n_features, self.is_training)
        attr_mean, attr_log_std, features = tf.split(network_output, (self.A, self.A, n_features), axis=-1)

        attr_std = self.std_nonlinearity(attr_log_std)

        attr = Normal(loc=attr_mean, scale=attr_std).sample()

        objects.update(attr_mean=attr_mean, attr_std=attr_std, attr=attr, glimpse=glimpse)

        # --- z ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr']], axis=-1)
        n_features = self.n_passthrough_features

        network_output = self.z_network(layer_inp, 2 + n_features, self.is_training)
        z_mean, z_log_std, features = tf.split(network_output, (1, 1, n_features), axis=-1)
        z_std = self.std_nonlinearity(z_log_std)

        z_mean = self.training_wheels * tf.stop_gradient(z_mean) + (1-self.training_wheels) * z_mean
        z_std = self.training_wheels * tf.stop_gradient(z_std) + (1-self.training_wheels) * z_std
        z_logit = Normal(loc=z_mean, scale=z_std).sample()
        z = self.z_nonlinearity(z_logit)

        objects.update(z_logit_mean=z_mean, z_logit_std=z_std, z_logit=z_logit, z=z)

        # --- obj ---

        layer_inp = tf.concat([base_features, features, objects['local_box'], objects['attr'], objects['z']], axis=-1)
        rep_input = self.obj_network(layer_inp, 1, self.is_training)

        _objects = self._build_obj(rep_input, self.is_training)
        objects.update(_objects)

        # --- final ---

        _objects = AttrDict()
        for k, v in objects.items():
            _, _, _, *trailing_dims = tf_shape(v)
            _objects[k] = tf.reshape(v, (self.batch_size, self.HWB, *trailing_dims))
        objects = _objects

        if prop_state is not None:
            objects.prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))
            objects.prior_prop_state = tf.tile(prop_state[0:1, None], (self.batch_size, self.HWB, 1))

        # --- misc ---

        objects.n_objects = tf.fill((self.batch_size,), self.HWB)
        objects.pred_n_objects = tf.reduce_sum(objects.obj, axis=(1, 2))
        objects.pred_n_objects_hard = tf.reduce_sum(tf.round(objects.obj), axis=(1, 2))

        return objects
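
The final block above reshapes every per-grid-cell tensor from (batch, H, W, *trailing) to (batch, H*W, *trailing) so that downstream code can treat the grid cells as a flat list of objects. A minimal numpy sketch of that reshaping (shapes are made up):

import numpy as np

def flatten_grid_objects(objects, batch_size, n_cells):
    # mirror the final loop above: (batch, H, W, *trailing) -> (batch, H*W, *trailing)
    flat = {}
    for k, v in objects.items():
        trailing = v.shape[3:]
        flat[k] = v.reshape(batch_size, n_cells, *trailing)
    return flat

objects = {"attr": np.zeros((2, 4, 5, 8)), "obj": np.zeros((2, 4, 5, 1))}
flat = flatten_grid_objects(objects, batch_size=2, n_cells=4 * 5)
print({k: v.shape for k, v in flat.items()})   # attr: (2, 20, 8), obj: (2, 20, 1)
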
Example #7
    def _call(self, objects, background, is_training, appearance_only=False):
        if not self.initialized:
            self.image_depth = tf_shape(background)[-1]

        self.maybe_build_subnet("object_decoder")

        # --- compute sprite appearance from attr using object decoder ---

        appearance_logit = apply_object_wise(
            self.object_decoder, objects.attr,
            output_size=self.object_shape + (self.image_depth+1,),
            is_training=is_training)

        appearance_logit = appearance_logit * ([self.color_logit_scale] * 3 + [self.alpha_logit_scale])
        appearance_logit = appearance_logit + ([0.] * 3 + [self.alpha_logit_bias])

        appearance = tf.nn.sigmoid(tf.clip_by_value(appearance_logit, -10., 10.))

        if appearance_only:
            return dict(appearance=appearance)

        appearance_for_output = appearance

        batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
        n_objects = np.prod(obj_leading_shape)
        appearance = tf.reshape(
            appearance, (batch_size, n_objects, *self.object_shape, self.image_depth+1))

        obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1], axis=-1)

        obj_alpha *= tf.reshape(objects.render_obj, (batch_size, n_objects, 1, 1, 1))

        z = tf.reshape(objects.z, (batch_size, n_objects, 1, 1, 1))
        obj_importance = tf.maximum(obj_alpha * z / self.importance_temp, 0.01)

        object_maps = tf.concat([obj_colors, obj_alpha, obj_importance], axis=-1)

        *_, image_height, image_width, _ = tf_shape(background)

        yt, xt, ys, xs = coords_to_image_space(
            objects.yt, objects.xt, objects.ys, objects.xs,
            (image_height, image_width), self.anchor_box, top_left=True)

        scales = tf.concat([ys, xs], axis=-1)
        scales = tf.reshape(scales, (batch_size, n_objects, 2))

        offsets = tf.concat([yt, xt], axis=-1)
        offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

        # --- Compose images ---

        n_objects_per_image = tf.fill((batch_size,), int(n_objects))

        output = render_sprites.render_sprites(
            object_maps,
            n_objects_per_image,
            scales,
            offsets,
            background
        )

        return dict(
            appearance=appearance_for_output,
            output=output)
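
`object_maps` packs, per object pixel, the colour channels, the alpha channel and an "importance" channel (alpha weighted by depth z and floored at 0.01) before handing everything to `render_sprites`. A small numpy sketch of that channel layout (shapes and `importance_temp` are made up):

import numpy as np

batch_size, n_objects, h, w, depth = 2, 3, 4, 4, 3
importance_temp = 0.5

obj_colors = np.random.rand(batch_size, n_objects, h, w, depth)
obj_alpha = np.random.rand(batch_size, n_objects, h, w, 1)
z = np.random.rand(batch_size, n_objects, 1, 1, 1)

# per-pixel importance: alpha weighted by depth, floored so that fully
# transparent pixels still receive a small positive weight
obj_importance = np.maximum(obj_alpha * z / importance_temp, 0.01)

object_maps = np.concatenate([obj_colors, obj_alpha, obj_importance], axis=-1)
print(object_maps.shape)   # (2, 3, 4, 4, depth + 2)
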
Example #8
    def _call(self, input_locs, input_features, reference_locs,
              reference_features, is_training):
        """
        input_features: (B, n_inp, n_hidden)
        input_locs: (B, n_inp, loc_dim)
        reference_locs: (B, n_ref, loc_dim)

        """
        assert (reference_features is not None) == self.do_object_wise

        if not self.is_built:
            self.relation_func = self.build_mlp(scope="relation_func")

            if self.do_object_wise:
                self.object_wise_func = self.build_mlp(
                    scope="object_wise_func")

            self.is_built = True

        loc_dim = tf_shape(input_locs)[-1]
        n_ref = tf_shape(reference_locs)[-2]
        batch_size, n_inp, _ = tf_shape(input_features)

        input_locs = tf.broadcast_to(input_locs, (batch_size, n_inp, loc_dim))
        reference_locs = tf.broadcast_to(reference_locs,
                                         (batch_size, n_ref, loc_dim))

        # (B, n_ref, n_inp, loc_dim)
        adjusted_locs = input_locs[:, None, :, :] - reference_locs[:, :, None, :]
        adjusted_features = tf.tile(
            input_features[:, None],
            (1, n_ref, 1, 1))  # (B, n_ref, n_inp, features_dim)
        relation_input = tf.concat([adjusted_features, adjusted_locs], axis=-1)

        if self.do_object_wise:
            object_wise = apply_object_wise(
                self.object_wise_func,
                reference_features,
                output_size=self.n_hidden,
                is_training=is_training)  # (B, n_ref, n_hidden)

            _object_wise = tf.tile(object_wise[:, :, None], (1, 1, n_inp, 1))
            relation_input = tf.concat([relation_input, _object_wise], axis=-1)
        else:
            object_wise = None

        V = apply_object_wise(
            self.relation_func,
            relation_input,
            output_size=self.n_hidden,
            is_training=is_training)  # (B, n_ref, n_inp, n_hidden)

        attention_weights = tf.exp(-0.5 * tf.reduce_sum(
            (adjusted_locs / self.kernel_std)**2, axis=3))
        attention_weights = (attention_weights / (2 * np.pi)**(loc_dim / 2) /
                             self.kernel_std**loc_dim)  # (B, n_ref, n_inp)

        result = tf.reduce_sum(V * attention_weights[..., None],
                               axis=2)  # (B, n_ref, n_hidden)

        if self.do_object_wise:
            result += object_wise

        # result = tf.contrib.layers.layer_norm(result)

        return result
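
The relation input above pairs every input with every reference: locations become pairwise displacements via broadcasting, and input features are tiled across the reference axis before concatenation. A small numpy sketch of the resulting shapes (dimensions are made up):

import numpy as np

B, n_inp, n_ref, loc_dim, feat_dim = 2, 5, 3, 2, 4
input_locs = np.random.rand(B, n_inp, loc_dim)
reference_locs = np.random.rand(B, n_ref, loc_dim)
input_features = np.random.randn(B, n_inp, feat_dim)

# displacement of every input location from every reference location
adjusted_locs = input_locs[:, None, :, :] - reference_locs[:, :, None, :]   # (B, n_ref, n_inp, loc_dim)

# tile the features so that each reference sees the full set of inputs
adjusted_features = np.tile(input_features[:, None], (1, n_ref, 1, 1))      # (B, n_ref, n_inp, feat_dim)

relation_input = np.concatenate([adjusted_features, adjusted_locs], axis=-1)
print(relation_input.shape)   # (2, 3, 5, feat_dim + loc_dim)
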
Example #9
    def _call(self, signal, is_training, memory=None):
        if not self.is_built:
            self.query_funcs = [
                self.build_mlp(scope="query_head_{}".format(j))
                for j in range(self.n_heads)
            ]
            self.key_funcs = [
                self.build_mlp(scope="key_head_{}".format(j))
                for j in range(self.n_heads)
            ]
            self.value_funcs = [
                self.build_mlp(scope="value_head_{}".format(j))
                for j in range(self.n_heads)
            ]
            self.after_func = self.build_mlp(scope="after")

            if self.do_object_wise:
                self.object_wise_func = self.build_object_wise(
                    scope="object_wise")

            if self.memory is not None:
                self.K = [
                    apply_object_wise(self.key_funcs[j],
                                      memory,
                                      output_size=self.key_dim,
                                      is_training=is_training)
                    for j in range(self.n_heads)
                ]
                self.V = [
                    apply_object_wise(self.value_funcs[j],
                                      memory,
                                      output_size=self.value_dim,
                                      is_training=is_training)
                    for j in range(self.n_heads)
                ]

            self.is_built = True

        n_signal_dim = len(signal.shape)
        assert n_signal_dim in [2, 3]

        if isinstance(memory, tuple):
            # keys and values passed in directly
            K, V = memory
        elif memory is not None:
            # memory is a tensor to which we apply key_funcs and value_funcs to obtain keys and values
            K = [
                apply_object_wise(self.key_funcs[j],
                                  memory,
                                  output_size=self.key_dim,
                                  is_training=is_training)
                for j in range(self.n_heads)
            ]
            V = [
                apply_object_wise(self.value_funcs[j],
                                  memory,
                                  output_size=self.value_dim,
                                  is_training=is_training)
                for j in range(self.n_heads)
            ]
        elif self.K is not None:
            K = self.K
            V = self.V
        else:
            # self-attention - `signal` used for queries, keys and values.
            K = [
                apply_object_wise(self.key_funcs[j],
                                  signal,
                                  output_size=self.key_dim,
                                  is_training=is_training)
                for j in range(self.n_heads)
            ]
            V = [
                apply_object_wise(self.value_funcs[j],
                                  signal,
                                  output_size=self.value_dim,
                                  is_training=is_training)
                for j in range(self.n_heads)
            ]

        head_outputs = []
        for j in range(self.n_heads):
            Q = apply_object_wise(self.query_funcs[j],
                                  signal,
                                  output_size=self.key_dim,
                                  is_training=is_training)

            if n_signal_dim == 2:
                Q = Q[:, None, :]

            attention_logits = tf.matmul(Q, K[j], transpose_b=True) / tf.sqrt(
                tf.to_float(self.key_dim))
            attention = tf.nn.softmax(attention_logits)
            attended = tf.matmul(attention,
                                 V[j])  # (..., n_queries, value_dim)

            if n_signal_dim == 2:
                attended = attended[:, 0, :]

            head_outputs.append(attended)

        head_outputs = tf.concat(head_outputs, axis=-1)

        # `after_func` is applied to the concatenation of the head outputs, and the result is added to the original
        # signal. Next, if `object_wise_func` is not None and `do_object_wise` is True, object_wise_func is
        # applied object wise and in a ResNet-style manner.

        output = apply_object_wise(self.after_func,
                                   head_outputs,
                                   output_size=self.n_hidden,
                                   is_training=is_training)
        output = tf.layers.dropout(output,
                                   self.p_dropout,
                                   training=is_training)
        signal = tf.contrib.layers.layer_norm(signal + output)

        if self.do_object_wise:
            output = apply_object_wise(self.object_wise_func,
                                       signal,
                                       output_size=self.n_hidden,
                                       is_training=is_training)
            output = tf.layers.dropout(output,
                                       self.p_dropout,
                                       training=is_training)
            signal = tf.contrib.layers.layer_norm(signal + output)

        return signal
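
Each head in the loop above computes standard scaled dot-product attention: query/key similarity scaled by sqrt(key_dim), a softmax over the keys, and a weighted sum of the values. A minimal single-head numpy sketch (shapes are made up):

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def single_head_attention(Q, K, V, key_dim):
    # scaled dot-product attention, as computed per head inside the loop above
    logits = Q @ K.transpose(0, 2, 1) / np.sqrt(key_dim)   # (B, n_queries, n_keys)
    weights = softmax(logits, axis=-1)
    return weights @ V                                      # (B, n_queries, value_dim)

B, n_q, n_k, key_dim, value_dim = 2, 4, 6, 8, 5
Q = np.random.randn(B, n_q, key_dim)
K = np.random.randn(B, n_k, key_dim)
V = np.random.randn(B, n_k, value_dim)
print(single_head_attention(Q, K, V, key_dim).shape)   # (2, 4, 5)
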
Example #10
    def _call(self, objects, background, is_training, appearance_only=False):
        if not self.initialized:
            self.image_depth = tf_shape(background)[-1]

        self.maybe_build_subnet("object_decoder")

        # --- compute sprite appearance from attr using object decoder ---

        appearance_logits = apply_object_wise(
            self.object_decoder, objects.attr,
            self.object_shape + (self.image_depth + 1, ), is_training)

        appearance_logits = appearance_logits * ([self.color_logit_scale] * 3 +
                                                 [self.alpha_logit_scale])
        appearance_logits = appearance_logits + ([0.] * 3 +
                                                 [self.alpha_logit_bias])

        appearance = tf.nn.sigmoid(
            tf.clip_by_value(appearance_logits, -10., 10.))

        if appearance_only:
            return dict(appearance=appearance)

        appearance_for_output = appearance

        batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
        n_objects = np.prod(obj_leading_shape)
        appearance = tf.reshape(
            appearance,
            (batch_size, n_objects, *self.object_shape, self.image_depth + 1))

        obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1],
                                         axis=-1)

        if "alpha" in self.no_gradient:
            obj_alpha = tf.stop_gradient(obj_alpha)

        if "alpha" in self.fixed_values:
            obj_alpha = float(
                self.fixed_values["alpha"]) * tf.ones_like(obj_alpha)

        obj_alpha *= tf.reshape(objects.obj, (batch_size, n_objects, 1, 1, 1))

        z = tf.reshape(objects.z, (batch_size, n_objects, 1, 1, 1))
        obj_importance = tf.maximum(obj_alpha * z, 0.01)

        object_maps = tf.concat([obj_colors, obj_alpha, obj_importance],
                                axis=-1)

        ys, xs, yt, xt = objects.ys, objects.xs, objects.yt, objects.xt

        scales = tf.concat([ys, xs], axis=-1)
        scales = tf.reshape(scales, (batch_size, n_objects, 2))

        offsets = tf.concat([yt, xt], axis=-1)
        offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

        # --- Compose images ---

        n_objects_per_image = tf.fill((batch_size, ), int(n_objects))

        output = render_sprites.render_sprites(object_maps,
                                               n_objects_per_image, scales,
                                               offsets, background)

        return dict(appearance=appearance_for_output, output=output)
Example #11
    def _call(self, objects, background, is_training, appearance_only=False, mask_only=False):
        """ If mask_only==True, then we ignore the provided background, using a black blackground instead,
            and also ignore the computed appearance, using all-white appearances instead.

        """
        if not self.initialized:
            self.image_depth = tf_shape(background)[-1]

        single = False
        if isinstance(objects, dict):
            single = True
            objects = [objects]

        _object_maps = []
        _scales = []
        _offsets = []
        _appearance = []

        for i, obj in enumerate(objects):
            anchor_box = self.anchor_boxes[i]
            object_shape = self.object_shapes[i]

            object_decoder = self.maybe_build_subnet(
                "object_decoder_for_flight_{}".format(i), builder_name='build_object_decoder')

            # --- compute sprite appearance from attr using object decoder ---

            appearance_logit = apply_object_wise(
                object_decoder, obj.attr,
                output_size=object_shape + (self.image_depth+1,),
                is_training=is_training)

            appearance_logit = appearance_logit * ([self.color_logit_scale] * self.image_depth + [self.alpha_logit_scale])
            appearance_logit = appearance_logit + ([0.] * self.image_depth + [self.alpha_logit_bias])

            appearance = tf.nn.sigmoid(tf.clip_by_value(appearance_logit, -10., 10.))
            _appearance.append(appearance)

            if appearance_only:
                continue

            batch_size, *obj_leading_shape, _, _, _ = tf_shape(appearance)
            n_objects = np.prod(obj_leading_shape)
            appearance = tf.reshape(
                appearance, (batch_size, n_objects, *object_shape, self.image_depth+1))

            obj_colors, obj_alpha = tf.split(appearance, [self.image_depth, 1], axis=-1)

            if mask_only:
                obj_colors = tf.ones_like(obj_colors)

            obj_alpha *= tf.reshape(obj.obj, (batch_size, n_objects, 1, 1, 1))

            z = tf.reshape(obj.z, (batch_size, n_objects, 1, 1, 1))
            obj_importance = tf.maximum(obj_alpha * z / self.importance_temp, 0.01)

            object_maps = tf.concat([obj_colors, obj_alpha, obj_importance], axis=-1)

            *_, image_height, image_width, _ = tf_shape(background)

            yt, xt, ys, xs = coords_to_image_space(
                obj.yt, obj.xt, obj.ys, obj.xs,
                (image_height, image_width), anchor_box, top_left=True)

            scales = tf.concat([ys, xs], axis=-1)
            scales = tf.reshape(scales, (batch_size, n_objects, 2))

            offsets = tf.concat([yt, xt], axis=-1)
            offsets = tf.reshape(offsets, (batch_size, n_objects, 2))

            _object_maps.append(object_maps)
            _scales.append(scales)
            _offsets.append(offsets)

        if single:
            _appearance = _appearance[0]

        if appearance_only:
            return dict(appearance=_appearance)

        if mask_only:
            background = tf.zeros_like(background)

        # --- Compose images ---

        output = render_sprites.render_sprites(
            _object_maps,
            _scales,
            _offsets,
            background
        )

        return dict(
            appearance=_appearance,
            output=output)