Example 1
    def __call__(self, tensors):
        kl = concrete_binary_sample_kl(tensors["obj_pre_sigmoid"],
                                       tensors["obj_log_odds"],
                                       self.obj_concrete_temp,
                                       self.prior_log_odds,
                                       self.obj_concrete_temp)

        batch_size = tf_shape(tensors["obj_pre_sigmoid"])[0]
        return tf.reduce_sum(tf.reshape(kl, (batch_size, -1)), 1)
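Note: `concrete_binary_sample_kl` is not defined in these examples. Presumably it returns a single-sample Monte Carlo estimate of KL(posterior || prior) between two Binary Concrete (relaxed Bernoulli) distributions, evaluated at the pre-sigmoid sample. A minimal sketch, assuming the Binary Concrete log-density of Maddison et al. (2016), the argument order used above, and `import tensorflow as tf` (TF1):

def concrete_binary_sample_kl(pre_sigmoid, posterior_log_odds, posterior_temp,
                              prior_log_odds, prior_temp, eps=1e-10):
    # Log-density of a Binary Concrete variable in pre-sigmoid space:
    # log p(y) = log(temp) - temp*y + log_odds - 2*log(1 + exp(-temp*y + log_odds))
    def log_density(y, log_odds, temp):
        return (tf.log(temp + eps) - temp * y + log_odds
                - 2.0 * tf.log(1.0 + tf.exp(-temp * y + log_odds) + eps))

    # Single-sample MC estimate of KL(posterior || prior), per element.
    return (log_density(pre_sigmoid, posterior_log_odds, posterior_temp)
            - log_density(pre_sigmoid, prior_log_odds, prior_temp))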
Example 2
        def body(step, stopping_sum, prev_state, running_recon, kl_loss,
                 running_digits, scale_ta, scale_kl_ta, scale_std_ta, shift_ta,
                 shift_kl_ta, shift_std_ta, attr_ta, attr_kl_ta, attr_std_ta,
                 z_pres_ta, z_pres_probs_ta, z_pres_kl_ta, vae_input_ta,
                 vae_output_ta, scale, shift, attr, z_pres):

            if self.difference_air:
                inp = (self._tensors["inp"] -
                       tf.reshape(running_recon,
                                  (self.batch_size, *self.obs_shape)))
                encoded_inp = self.image_encoder(inp, 0, self.is_training)
                encoded_inp = tf.layers.flatten(encoded_inp)
            else:
                encoded_inp = self.encoded_inp

            if self.complete_rnn_input:
                rnn_input = tf.concat(
                    [encoded_inp, scale, shift, attr, z_pres], axis=1)
            else:
                rnn_input = encoded_inp

            hidden_rep, next_state = self.cell(rnn_input, prev_state)

            # Nine outputs: scale mean/log-std (2+2), shift mean/log-std (2+2),
            # and z_pres log-odds (1).
            outputs = self.output_network(hidden_rep, 9, self.is_training)

            (scale_mean, scale_log_std, shift_mean, shift_log_std,
             z_pres_log_odds) = tf.split(outputs, [2, 2, 2, 2, 1], axis=1)

            # --- scale ---

            scale_std = tf.exp(scale_log_std)

            scale_mean = self.apply_fixed_value("scale_mean", scale_mean)
            scale_std = self.apply_fixed_value("scale_std", scale_std)

            scale_logits, scale_kl = normal_vae(scale_mean, scale_std,
                                                self.scale_prior_mean,
                                                self.scale_prior_std)
            scale_kl = tf.reduce_sum(scale_kl, axis=1, keepdims=True)
            scale = tf.nn.sigmoid(tf.clip_by_value(scale_logits, -10, 10))

            # --- shift ---

            shift_std = tf.exp(shift_log_std)

            shift_mean = self.apply_fixed_value("shift_mean", shift_mean)
            shift_std = self.apply_fixed_value("shift_std", shift_std)

            shift_logits, shift_kl = normal_vae(shift_mean, shift_std,
                                                self.shift_prior_mean,
                                                self.shift_prior_std)
            shift_kl = tf.reduce_sum(shift_kl, axis=1, keepdims=True)
            shift = tf.nn.tanh(tf.clip_by_value(shift_logits, -10, 10))

            # --- Extract windows from scene ---

            w, h = scale[:, 0:1], scale[:, 1:2]
            x, y = shift[:, 0:1], shift[:, 1:2]

            theta = tf.concat(
                [w, tf.zeros_like(w), x,
                 tf.zeros_like(h), h, y], axis=1)
            theta = tf.reshape(theta, (-1, 2, 3))

            vae_input = transformer(self._tensors["inp"], theta,
                                    self.object_shape)

            # This reshape is necessary because the output of the spatial
            # transformer has unknown static dimensions.
            vae_input = tf.reshape(
                vae_input,
                (self.batch_size, *self.object_shape, self.image_depth))

            # --- Apply Object-level VAE (object encoder/object decoder) to windows ---

            attr = self.object_encoder(vae_input, 2 * self.A, self.is_training)
            attr_mean, attr_log_std = tf.split(attr, 2, axis=1)
            attr_std = tf.exp(attr_log_std)
            attr, attr_kl = normal_vae(attr_mean, attr_std,
                                       self.attr_prior_mean,
                                       self.attr_prior_std)
            attr_kl = tf.reduce_sum(attr_kl, axis=1, keepdims=True)

            vae_output = self.object_decoder(
                attr,
                self.object_shape[0] * self.object_shape[1] * self.image_depth,
                self.is_training)
            vae_output = tf.nn.sigmoid(tf.clip_by_value(vae_output, -10, 10))

            # --- Place reconstructed objects in image ---

            theta_inverse = tf.concat(
                [1. / w, tf.zeros_like(w), -x / w,
                 tf.zeros_like(h), 1. / h, -y / h], axis=1)
            theta_inverse = tf.reshape(theta_inverse, (-1, 2, 3))

            vae_output_transformed = transformer(
                tf.reshape(vae_output, (
                    self.batch_size,
                    *self.object_shape,
                    self.image_depth,
                )), theta_inverse, self.obs_shape[:2])
            vae_output_transformed = tf.reshape(vae_output_transformed, [
                self.batch_size,
                self.image_height * self.image_width * self.image_depth
            ])

            # --- z_pres ---

            if self.run_all_time_steps:
                z_pres = tf.ones_like(z_pres_log_odds)
                z_pres_prob = tf.ones_like(z_pres_log_odds)
                z_pres_kl = tf.zeros_like(z_pres_log_odds)
            else:
                z_pres_log_odds = tf.clip_by_value(z_pres_log_odds, -10, 10)

                # Relaxed Bernoulli (Concrete) sample; during training z_pres
                # stays relaxed, at test time it is rounded to a hard 0/1.
                z_pres_pre_sigmoid = concrete_binary_pre_sigmoid_sample(
                    z_pres_log_odds, self.z_pres_temperature)
                z_pres = tf.nn.sigmoid(z_pres_pre_sigmoid)
                z_pres = (self.float_is_training * z_pres +
                          (1 - self.float_is_training) * tf.round(z_pres))
                z_pres_prob = tf.nn.sigmoid(z_pres_log_odds)
                z_pres_kl = concrete_binary_sample_kl(
                    z_pres_pre_sigmoid,
                    z_pres_log_odds,
                    self.z_pres_temperature,
                    self.z_pres_prior_log_odds,
                    self.z_pres_temperature,
                )

            stopping_sum += (1.0 - z_pres)
            alive = tf.less(stopping_sum, self.stopping_threshold)
            running_digits += tf.to_int32(alive)

            # --- adjust reconstruction ---

            running_recon += tf.where(
                tf.tile(alive, (1, vae_output_transformed.shape[1])),
                z_pres * vae_output_transformed, tf.zeros_like(running_recon))

            # --- add kl to loss ---

            kl_loss += tf.where(alive, scale_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, shift_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, attr_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, z_pres_kl, tf.zeros_like(kl_loss))

            # --- record values ---

            scale_ta = scale_ta.write(scale_ta.size(), scale)
            scale_kl_ta = scale_kl_ta.write(scale_kl_ta.size(), scale_kl)
            scale_std_ta = scale_std_ta.write(scale_std_ta.size(), scale_std)

            shift_ta = shift_ta.write(shift_ta.size(), shift)
            shift_kl_ta = shift_kl_ta.write(shift_kl_ta.size(), shift_kl)
            shift_std_ta = shift_std_ta.write(shift_std_ta.size(), shift_std)

            attr_ta = attr_ta.write(attr_ta.size(), attr)
            attr_kl_ta = attr_kl_ta.write(attr_kl_ta.size(), attr_kl)
            attr_std_ta = attr_std_ta.write(attr_std_ta.size(), attr_std)

            vae_input_ta = vae_input_ta.write(vae_input_ta.size(),
                                              tf.layers.flatten(vae_input))
            vae_output_ta = vae_output_ta.write(vae_output_ta.size(),
                                                vae_output)

            z_pres_ta = z_pres_ta.write(z_pres_ta.size(), z_pres)
            z_pres_probs_ta = z_pres_probs_ta.write(z_pres_probs_ta.size(),
                                                    z_pres_prob)
            z_pres_kl_ta = z_pres_kl_ta.write(z_pres_kl_ta.size(), z_pres_kl)

            return (
                step + 1,
                stopping_sum,
                next_state,
                running_recon,
                kl_loss,
                running_digits,
                scale_ta,
                scale_kl_ta,
                scale_std_ta,
                shift_ta,
                shift_kl_ta,
                shift_std_ta,
                attr_ta,
                attr_kl_ta,
                attr_std_ta,
                z_pres_ta,
                z_pres_probs_ta,
                z_pres_kl_ta,
                vae_input_ta,
                vae_output_ta,
                scale,
                shift,
                attr,
                z_pres,
            )
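Note: `normal_vae` is used throughout but not shown. Presumably it draws a reparameterized sample from the posterior Gaussian and returns it alongside the analytic per-dimension KL against the Gaussian prior; a minimal sketch under that assumption (TF1-style, `import tensorflow as tf`):

def normal_vae(mean, std, prior_mean, prior_std):
    # Reparameterization trick: sample = mean + std * eps, with eps ~ N(0, I).
    sample = mean + std * tf.random_normal(tf.shape(mean))

    # Analytic KL( N(mean, std^2) || N(prior_mean, prior_std^2) ), per dimension:
    # log(s/sigma) + (sigma^2 + (mu - m)^2) / (2 s^2) - 1/2
    kl = (tf.log(prior_std) - tf.log(std)
          + (std**2 + (mean - prior_mean)**2) / (2.0 * prior_std**2)
          - 0.5)
    return sample, kl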
Example 3
    def _compute_obj_kl(self, tensors, existing_objects=None):
        # --- compute obj_kl ---

        obj_pre_sigmoid = tensors["obj_pre_sigmoid"]
        obj_log_odds = tensors["obj_log_odds"]
        obj_prob = tensors["obj_prob"]
        obj = tensors["obj"]
        batch_size, n_objects, _ = tf_shape(obj)

        max_n_objects = n_objects

        if existing_objects is not None:
            _, n_existing_objects, _ = tf_shape(existing_objects)
            existing_objects = tf.reshape(existing_objects, (batch_size, n_existing_objects))
            max_n_objects += n_existing_objects

        count_support = tf.range(max_n_objects+1, dtype=tf.float32)

        if self.count_prior_dist is not None:
            assert len(self.count_prior_dist) == (max_n_objects + 1)
            count_distribution = tf.constant(self.count_prior_dist, dtype=tf.float32)
        else:
            count_prior_prob = tf.nn.sigmoid(self.count_prior_log_odds)
            count_distribution = (1 - count_prior_prob) * (count_prior_prob ** count_support)

        normalizer = tf.reduce_sum(count_distribution)
        count_distribution = count_distribution / tf.maximum(normalizer, 1e-6)
        count_distribution = tf.tile(count_distribution[None, :], (batch_size, 1))

        if existing_objects is not None:
            count_so_far = tf.reduce_sum(tf.round(existing_objects), axis=1, keepdims=True)

            count_distribution = (
                count_distribution
                * tf_binomial_coefficient(count_support, count_so_far)
                * tf_binomial_coefficient(max_n_objects - count_support, n_existing_objects - count_so_far)
            )

            normalizer = tf.reduce_sum(count_distribution, axis=1, keepdims=True)
            count_distribution = count_distribution / tf.maximum(normalizer, 1e-6)
        else:
            count_so_far = tf.zeros((batch_size, 1), dtype=tf.float32)

        obj_kl = []
        for i in range(n_objects):
            p_z_given_Cz_raw = (count_support[None, :] - count_so_far) / (max_n_objects - i)
            p_z_given_Cz = tf.clip_by_value(p_z_given_Cz_raw, 0.0, 1.0)

            # Doing this instead of 1 - p_z_given_Cz seems to be more numerically stable.
            inv_p_z_given_Cz_raw = (max_n_objects - i - count_support[None, :] + count_so_far) / (max_n_objects - i)
            inv_p_z_given_Cz = tf.clip_by_value(inv_p_z_given_Cz_raw, 0.0, 1.0)

            p_z = tf.reduce_sum(count_distribution * p_z_given_Cz, axis=1, keepdims=True)

            if self.use_concrete_kl:
                prior_log_odds = tf_safe_log(p_z) - tf_safe_log(1-p_z)
                _obj_kl = concrete_binary_sample_kl(
                    obj_pre_sigmoid[:, i, :],
                    obj_log_odds[:, i, :], self.obj_concrete_temp,
                    prior_log_odds, self.obj_concrete_temp,
                )
            else:
                prob = obj_prob[:, i, :]

                _obj_kl = (
                    prob * (tf_safe_log(prob) - tf_safe_log(p_z))
                    + (1-prob) * (tf_safe_log(1-prob) - tf_safe_log(1-p_z))
                )

            obj_kl.append(_obj_kl)

            sample = tf.to_float(obj[:, i, :] > 0.5)
            mult = sample * p_z_given_Cz + (1-sample) * inv_p_z_given_Cz
            raw_count_distribution = mult * count_distribution
            normalizer = tf.reduce_sum(raw_count_distribution, axis=1, keepdims=True)
            normalizer = tf.maximum(normalizer, 1e-6)

            count_distribution = raw_count_distribution / normalizer
            count_so_far += sample

            # Rounding avoids build-up of numerical error that would otherwise
            # corrupt the computation of p_z_given_Cz_raw.
            count_so_far = tf.round(count_so_far)

        obj_kl = tf.reshape(tf.concat(obj_kl, axis=1), (batch_size, n_objects, 1))

        return obj_kl
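The loop above performs a sequential Bayesian update over the latent object count: for each slot, `p_z_given_Cz` is the probability that the slot is on given the total count (remaining objects spread uniformly over remaining slots), `p_z` marginalizes this over the current count distribution, and after observing the slot's sample the count distribution is reweighted and renormalized. A toy NumPy illustration of one such step, with hypothetical numbers:

import numpy as np

max_n_objects, i = 3, 0
count_support = np.arange(max_n_objects + 1, dtype=np.float64)  # [0, 1, 2, 3]
count_dist = np.array([0.4, 0.3, 0.2, 0.1])                     # prior over counts
count_so_far = 0.0

# P(slot i on | total count c): remaining objects over remaining slots.
p_z_given_Cz = np.clip((count_support - count_so_far) / (max_n_objects - i), 0.0, 1.0)
p_z = np.sum(count_dist * p_z_given_Cz)  # marginal on-probability, here 1/3

sample = 1.0                             # suppose slot i came out "on"
mult = sample * p_z_given_Cz + (1 - sample) * (1 - p_z_given_Cz)
count_dist = mult * count_dist
count_dist /= count_dist.sum()           # posterior over counts after slot i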
Example 4
    def _compute_obj_kl(self, tensors):
        # --- compute obj_kl ---

        count_support = tf.range(self.HWB + 1, dtype=tf.float32)

        if self.count_prior_dist is not None:
            assert len(self.count_prior_dist) == (self.HWB + 1)
            count_distribution = tf.constant(self.count_prior_dist,
                                             dtype=tf.float32)
        else:
            count_prior_prob = tf.nn.sigmoid(self.count_prior_log_odds)
            count_distribution = (1 - count_prior_prob) * (count_prior_prob**
                                                           count_support)

        normalizer = tf.reduce_sum(count_distribution)
        count_distribution = count_distribution / normalizer
        count_distribution = tf.tile(count_distribution[None, :],
                                     (self.batch_size, 1))
        count_so_far = tf.zeros((self.batch_size, 1), dtype=tf.float32)

        obj_kl = []

        for i, (h, w, b) in enumerate(
                itertools.product(range(self.H), range(self.W),
                                  range(self.B))):
            p_z_given_Cz = tf.maximum(count_support[None, :] - count_so_far,
                                      0) / (self.HWB - i)

            # Reshape for batch matmul
            _count_distribution = count_distribution[:, None, :]
            _p_z_given_Cz = p_z_given_Cz[:, :, None]

            p_z = tf.matmul(_count_distribution, _p_z_given_Cz)[:, :, 0]

            if self.use_concrete_kl:
                prior_log_odds = tf_safe_log(p_z) - tf_safe_log(1 - p_z)
                _obj_kl = concrete_binary_sample_kl(
                    tensors["obj_pre_sigmoid"][:, h, w, b, :],
                    tensors["obj_log_odds"][:, h, w, b, :],
                    self.obj_concrete_temp,
                    prior_log_odds,
                    self.obj_concrete_temp,
                )
            else:
                prob = tensors["obj_prob"][:, h, w, b, :]

                _obj_kl = (prob * (tf_safe_log(prob) - tf_safe_log(p_z)) +
                           (1 - prob) *
                           (tf_safe_log(1 - prob) - tf_safe_log(1 - p_z)))

            obj_kl.append(_obj_kl)

            sample = tf.to_float(tensors["obj"][:, h, w, b, :] > 0.5)
            mult = sample * p_z_given_Cz + (1 - sample) * (1 - p_z_given_Cz)
            count_distribution = mult * count_distribution
            normalizer = tf.reduce_sum(count_distribution,
                                       axis=1,
                                       keepdims=True)
            normalizer = tf.maximum(normalizer, 1e-6)
            count_distribution = count_distribution / normalizer

            count_so_far += sample

        obj_kl = tf.reshape(tf.concat(obj_kl, axis=1),
                            (self.batch_size, self.H, self.W, self.B, 1))

        return obj_kl
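The batched matmul used to form `p_z` here is just an inner product over the count support; with `count_distribution` and `p_z_given_Cz` both of shape `(batch_size, max_count + 1)`, the following two forms agree (the second is the form used in Example 3):

# Inner product over the count support, written two equivalent ways.
p_z_matmul = tf.matmul(count_distribution[:, None, :],
                       p_z_given_Cz[:, :, None])[:, :, 0]   # (batch_size, 1)
p_z_sum = tf.reduce_sum(count_distribution * p_z_given_Cz,
                        axis=1, keepdims=True)              # (batch_size, 1)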
Example 5
    def compute_kl(self, tensors, prior=None):
        simple_obj_kl = prior is not None

        if prior is None:
            prior = self._independent_prior()

        # --- box ---

        cell_y_kl = tensors["cell_y_logit_dist"].kl_divergence(
            prior["cell_y_logit_dist"])
        cell_x_kl = tensors["cell_x_logit_dist"].kl_divergence(
            prior["cell_x_logit_dist"])
        height_kl = tensors["height_logit_dist"].kl_divergence(
            prior["height_logit_dist"])
        width_kl = tensors["width_logit_dist"].kl_divergence(
            prior["width_logit_dist"])

        if "cell_y" in self.no_gradient:
            cell_y_kl = tf.stop_gradient(cell_y_kl)

        if "cell_x" in self.no_gradient:
            cell_x_kl = tf.stop_gradient(cell_x_kl)

        if "height" in self.no_gradient:
            height_kl = tf.stop_gradient(height_kl)

        if "width" in self.no_gradient:
            width_kl = tf.stop_gradient(width_kl)

        box_kl = tf.concat([cell_y_kl, cell_x_kl, height_kl, width_kl],
                           axis=-1)

        # --- attr ---

        attr_kl = tensors["attr_dist"].kl_divergence(prior["attr_dist"])

        if "attr" in self.no_gradient:
            attr_kl = tf.stop_gradient(attr_kl)

        # --- z ---

        z_kl = tensors["z_logit_dist"].kl_divergence(prior["z_logit_dist"])

        if "z" in self.no_gradient:
            z_kl = tf.stop_gradient(z_kl)

        if "z" in self.fixed_values:
            z_kl = tf.zeros_like(z_kl)

        # --- obj ---

        if simple_obj_kl:
            obj_kl = concrete_binary_sample_kl(
                tensors["obj_pre_sigmoid"],
                tensors["obj_log_odds"],
                self.obj_concrete_temp,
                prior["obj_log_odds"],
                self.obj_concrete_temp,
            )
        else:
            obj_kl = self._compute_obj_kl(tensors)

        if "obj" in self.no_gradient:
            obj_kl = tf.stop_gradient(obj_kl)

        return dict(
            cell_y_kl=cell_y_kl,
            cell_x_kl=cell_x_kl,
            height_kl=height_kl,
            width_kl=width_kl,
            box_kl=box_kl,
            z_kl=z_kl,
            attr_kl=attr_kl,
            obj_kl=obj_kl,
        )
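Note: the `*_dist` entries in `tensors` and `prior` are assumed to be TF1 `tf.distributions` objects, whose `kl_divergence` method gives the analytic elementwise KL. A minimal sketch of how one posterior/prior pair might be built, with hypothetical `cell_y_logit_mean` / `cell_y_logit_std` tensors:

tfd = tf.distributions

# Hypothetical posterior over cell_y logits and a standard-normal prior.
cell_y_logit_dist = tfd.Normal(loc=cell_y_logit_mean, scale=cell_y_logit_std)
cell_y_logit_prior = tfd.Normal(loc=0.0, scale=1.0)

cell_y_kl = cell_y_logit_dist.kl_divergence(cell_y_logit_prior)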
Example 6
    def _compute_obj_kl(self, tensors, existing_objects=None):
        # --- compute obj_kl ---

        obj_pre_sigmoid = tensors["obj_pre_sigmoid"]
        obj_log_odds = tensors["obj_log_odds"]
        obj_prob = tensors["obj_prob"]
        obj = tensors["obj"]
        batch_size, n_objects, _ = tf_shape(obj)

        max_n_objects = n_objects

        if existing_objects is not None:
            _, n_existing_objects, _ = tf_shape(existing_objects)
            existing_objects = tf.reshape(existing_objects,
                                          (batch_size, n_existing_objects))
            max_n_objects += n_existing_objects

        count_support = tf.range(max_n_objects + 1, dtype=tf.float32)

        if self.count_prior_dist is not None:
            assert len(self.count_prior_dist) == (max_n_objects + 1)
            count_distribution = tf.constant(self.count_prior_dist,
                                             dtype=tf.float32)
        else:
            count_prior_prob = tf.nn.sigmoid(self.count_prior_log_odds)
            count_distribution = (1 - count_prior_prob) * (count_prior_prob**
                                                           count_support)

        normalizer = tf.reduce_sum(count_distribution)
        count_distribution = count_distribution / tf.maximum(normalizer, 1e-6)
        count_distribution = tf.tile(count_distribution[None, :],
                                     (batch_size, 1))

        if existing_objects is not None:
            count_so_far = tf.reduce_sum(tf.round(existing_objects),
                                         axis=1,
                                         keepdims=True)

            count_distribution = (
                count_distribution *
                tf_binomial_coefficient(count_support, count_so_far) *
                tf_binomial_coefficient(max_n_objects - count_support,
                                        n_existing_objects - count_so_far))

            normalizer = tf.reduce_sum(count_distribution,
                                       axis=1,
                                       keepdims=True)
            count_distribution = count_distribution / tf.maximum(
                normalizer, 1e-6)
        else:
            count_so_far = tf.zeros((batch_size, 1), dtype=tf.float32)

        obj_kl = []
        for i in range(n_objects):
            p_z_given_Cz = tf.maximum(count_support[None, :] - count_so_far,
                                      0) / (max_n_objects - i)

            # Reshape for batch matmul
            _count_distribution = count_distribution[:, None, :]
            _p_z_given_Cz = p_z_given_Cz[:, :, None]

            p_z = tf.matmul(_count_distribution, _p_z_given_Cz)[:, :, 0]

            if self.use_concrete_kl:
                prior_log_odds = tf_safe_log(p_z) - tf_safe_log(1 - p_z)
                _obj_kl = concrete_binary_sample_kl(
                    obj_pre_sigmoid[:, i, :],
                    obj_log_odds[:, i, :],
                    self.obj_concrete_temp,
                    prior_log_odds,
                    self.obj_concrete_temp,
                )
            else:
                prob = obj_prob[:, i, :]

                _obj_kl = (prob * (tf_safe_log(prob) - tf_safe_log(p_z)) +
                           (1 - prob) *
                           (tf_safe_log(1 - prob) - tf_safe_log(1 - p_z)))

            obj_kl.append(_obj_kl)

            sample = tf.to_float(obj[:, i, :] > 0.5)
            mult = sample * p_z_given_Cz + (1 - sample) * (1 - p_z_given_Cz)
            count_distribution = mult * count_distribution
            normalizer = tf.reduce_sum(count_distribution,
                                       axis=1,
                                       keepdims=True)
            normalizer = tf.maximum(normalizer, 1e-6)
            count_distribution = count_distribution / normalizer

            count_so_far += sample

        obj_kl = tf.reshape(tf.concat(obj_kl, axis=1),
                            (batch_size, n_objects, 1))

        return obj_kl
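When `count_prior_dist` is not supplied, the `else` branch builds a geometric prior over the object count: with `p = sigmoid(count_prior_log_odds)`, the unnormalized weights `(1 - p) * p**k` follow a Geometric(1 - p) distribution, truncated at `max_n_objects` and renormalized. A small NumPy check with a hypothetical log-odds value:

import numpy as np

count_prior_log_odds = -1.0                      # hypothetical value
p = 1.0 / (1.0 + np.exp(-count_prior_log_odds))  # sigmoid
support = np.arange(4, dtype=np.float64)         # max_n_objects = 3
dist = (1 - p) * p ** support                    # truncated Geometric(1 - p)
dist /= dist.sum()                               # renormalize over the support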