Example #1
    def _call(self, inp, mask, is_training):
        self.maybe_build_subnet('background_encoder')
        self.maybe_build_subnet('background_decoder')

        combined = tf.concat([inp, mask], axis=-1)
        latent = self.background_encoder(combined,
                                         2 * self.n_latents_per_channel,
                                         is_training)
        mean, std = tf.split(latent, 2, axis=-1)
        sample, kl = normal_vae(mean, std, 0, 1)
        background = self.background_decoder(sample, None, is_training)
        return background, kl
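All six examples call normal_vae(mean, std, prior_mean, prior_std) and expect a latent sample plus a KL term back. The helper itself is not reproduced on this page; the sketch below is only an assumption about its behaviour (a reparameterised Gaussian sample and the analytic KL against the supplied prior), written in the same TF1 style as the examples, with the _sketch suffix marking it as ours.

    import tensorflow as tf

    def normal_vae_sketch(mean, std, prior_mean, prior_std):
        # Reparameterised sample from N(mean, std^2).
        sample = mean + std * tf.random_normal(tf.shape(mean))

        # Element-wise analytic KL(N(mean, std^2) || N(prior_mean, prior_std^2));
        # assumes std > 0 (the examples zero it out when `noisy` is False).
        kl = (
            tf.log(prior_std / std)
            + (std ** 2 + (mean - prior_mean) ** 2) / (2. * prior_std ** 2)
            - 0.5
        )
        return sample, kl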
Example #2
    def build_representation(self):
        # --- init modules ---

        if self.encoder is None:
            self.encoder = cfg.build_encoder(scope="encoder")
            if "encoder" in self.fixed_weights:
                self.encoder.fix_variables()

        if self.decoder is None:
            self.decoder = cfg.build_decoder(scope="decoder")
            if "decoder" in self.fixed_weights:
                self.decoder.fix_variables()

        # --- encode ---

        attr = self.encoder(self.inp, 2 * self.A, self.is_training)
        attr_mean, attr_log_std = tf.split(attr, 2, axis=-1)
        attr_std = tf.exp(attr_log_std)

        if not self.noisy:
            attr_std = tf.zeros_like(attr_std)

        attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)

        obj_shape = tf.concat([tf.shape(attr)[:-1], [1]], axis=0)
        self._tensors["obj"] = tf.ones(obj_shape)

        self._tensors.update(attr_mean=attr_mean, attr_std=attr_std, attr_kl=attr_kl, attr=attr)

        # --- decode ---

        reconstruction = self.decoder(attr, 3, self.is_training)
        reconstruction = reconstruction[:, :self.inp.shape[1], :self.inp.shape[2], :]

        reconstruction = tf.nn.sigmoid(tf.clip_by_value(reconstruction, -10, 10))
        self._tensors["output"] = reconstruction

        # --- losses ---

        if self.train_kl:
            self.losses['attr_kl'] = tf_mean_sum(self._tensors["attr_kl"])

        if self.train_reconstruction:
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=reconstruction, label=self.inp)
            self.losses['reconstruction'] = tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])
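The loss terms above use two small helpers that are not shown on this page. Assuming the conventional definitions (a per-pixel binary cross-entropy against the sigmoid reconstruction, and a sum over non-batch axes followed by a batch mean), they would look roughly like the following; the _sketch names flag these as guesses rather than the project's code.

    import tensorflow as tf

    def xent_loss_sketch(pred, label, eps=1e-6):
        # Per-pixel binary cross-entropy; `pred` is already squashed into (0, 1).
        return -(label * tf.log(pred + eps) + (1 - label) * tf.log(1 - pred + eps))

    def tf_mean_sum_sketch(t):
        # Sum over every non-batch axis, then average over the batch.
        return tf.reduce_mean(tf.reduce_sum(tf.layers.flatten(t), axis=1))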
Example #3
    def _build_program_interpreter(self, tensors):
        # --- Get object attributes using object encoder ---

        max_objects = tensors["max_objects"]

        yt, xt, ys, xs = tf.split(tensors["normalized_box"], 4, axis=-1)

        transform_constraints = snt.AffineWarpConstraints.no_shear_2d()
        warper = snt.AffineGridWarper((self.image_height, self.image_width),
                                      self.object_shape, transform_constraints)

        _boxes = tf.concat(
            [xs, 2 * (xt + xs / 2) - 1, ys, 2 * (yt + ys / 2) - 1], axis=-1)
        _boxes = tf.reshape(_boxes, (self.batch_size * max_objects, 4))
        grid_coords = warper(_boxes)
        grid_coords = tf.reshape(grid_coords, (
            self.batch_size,
            max_objects,
            *self.object_shape,
            2,
        ))
        glimpse = tf.contrib.resampler.resampler(tensors["inp"], grid_coords)

        object_encoder_in = tf.reshape(glimpse,
                                       (self.batch_size * max_objects,
                                        *self.object_shape, self.image_depth))

        attr = self.object_encoder(object_encoder_in, (1, 1, 2 * self.A),
                                   self.is_training)
        attr = tf.reshape(attr, (self.batch_size, max_objects, 2 * self.A))
        attr_mean, attr_log_std = tf.split(attr, [self.A, self.A], axis=-1)
        attr_std = tf.exp(attr_log_std)

        if not self.noisy:
            attr_std = tf.zeros_like(attr_std)

        attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean,
                                   self.attr_prior_std)

        object_decoder_in = tf.reshape(
            attr, (self.batch_size * max_objects, 1, 1, self.A))

        # --- Compute sprites from attr using object decoder ---

        object_logits = self.object_decoder(
            object_decoder_in, self.object_shape + (self.image_depth, ),
            self.is_training)

        objects = tf.nn.sigmoid(tf.clip_by_value(object_logits, -10., 10.))

        objects = tf.reshape(objects, (
            self.batch_size,
            max_objects,
            *self.object_shape,
            self.image_depth,
        ))
        alpha = tensors["obj"][:, :, :, None, None] * tf.ones_like(
            objects[:, :, :, :, :1])
        importance = tf.ones_like(objects[:, :, :, :, :1])
        objects = tf.concat([objects, alpha, importance], axis=-1)

        # -- Reconstruct image ---

        scales = tf.concat([ys, xs], axis=-1)
        scales = tf.reshape(scales, (self.batch_size, max_objects, 2))

        offsets = tf.concat([yt, xt], axis=-1)
        offsets = tf.reshape(offsets, (self.batch_size, max_objects, 2))

        output = render_sprites.render_sprites(objects, tensors["n_objects"],
                                               scales, offsets,
                                               tensors["background"])

        return dict(output=output,
                    glimpse=tf.reshape(glimpse,
                                       (self.batch_size, max_objects,
                                        *self.object_shape, self.image_depth)),
                    attr=tf.reshape(attr,
                                    (self.batch_size, max_objects, self.A)),
                    attr_kl=tf.reshape(attr_kl,
                                       (self.batch_size, max_objects, self.A)),
                    objects=tf.reshape(objects, (
                        self.batch_size,
                        max_objects,
                        *self.object_shape,
                        self.image_depth,
                    )))
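The _boxes tensor above converts each (yt, xt, ys, xs) box, i.e. top-left corner plus size in [0, 1], into the four no-shear affine parameters consumed by the AffineGridWarper, with the box centre mapped to [-1, 1]. A plain NumPy restatement of that arithmetic, for illustration only:

    import numpy as np

    def box_to_warper_params(yt, xt, ys, xs):
        # Box centre in [0, 1] coordinates, rescaled to [-1, 1].
        cx = 2 * (xt + xs / 2) - 1
        cy = 2 * (yt + ys / 2) - 1
        # Same ordering as the `_boxes` concat above.
        return np.array([xs, cx, ys, cy])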
Example #4
        def body(step, stopping_sum, prev_state, running_recon, kl_loss,
                 running_digits, scale_ta, scale_kl_ta, scale_std_ta, shift_ta,
                 shift_kl_ta, shift_std_ta, attr_ta, attr_kl_ta, attr_std_ta,
                 z_pres_ta, z_pres_probs_ta, z_pres_kl_ta, vae_input_ta,
                 vae_output_ta, scale, shift, attr, z_pres):

            if self.difference_air:
                inp = (self._tensors["inp"] -
                       tf.reshape(running_recon,
                                  (self.batch_size, *self.obs_shape)))
                encoded_inp = self.image_encoder(inp, 0, self.is_training)
                encoded_inp = tf.layers.flatten(encoded_inp)
            else:
                encoded_inp = self.encoded_inp

            if self.complete_rnn_input:
                rnn_input = tf.concat(
                    [encoded_inp, scale, shift, attr, z_pres], axis=1)
            else:
                rnn_input = encoded_inp

            hidden_rep, next_state = self.cell(rnn_input, prev_state)

            outputs = self.output_network(hidden_rep, 9, self.is_training)

            (scale_mean, scale_log_std, shift_mean, shift_log_std,
             z_pres_log_odds) = tf.split(outputs, [2, 2, 2, 2, 1], axis=1)

            # --- scale ---

            scale_std = tf.exp(scale_log_std)

            scale_mean = self.apply_fixed_value("scale_mean", scale_mean)
            scale_std = self.apply_fixed_value("scale_std", scale_std)

            scale_logits, scale_kl = normal_vae(scale_mean, scale_std,
                                                self.scale_prior_mean,
                                                self.scale_prior_std)
            scale_kl = tf.reduce_sum(scale_kl, axis=1, keepdims=True)
            scale = tf.nn.sigmoid(tf.clip_by_value(scale_logits, -10, 10))

            # --- shift ---

            shift_std = tf.exp(shift_log_std)

            shift_mean = self.apply_fixed_value("shift_mean", shift_mean)
            shift_std = self.apply_fixed_value("shift_std", shift_std)

            shift_logits, shift_kl = normal_vae(shift_mean, shift_std,
                                                self.shift_prior_mean,
                                                self.shift_prior_std)
            shift_kl = tf.reduce_sum(shift_kl, axis=1, keepdims=True)
            shift = tf.nn.tanh(tf.clip_by_value(shift_logits, -10, 10))

            # --- Extract windows from scene ---

            w, h = scale[:, 0:1], scale[:, 1:2]
            x, y = shift[:, 0:1], shift[:, 1:2]

            theta = tf.concat(
                [w, tf.zeros_like(w), x,
                 tf.zeros_like(h), h, y], axis=1)
            theta = tf.reshape(theta, (-1, 2, 3))

            vae_input = transformer(self._tensors["inp"], theta,
                                    self.object_shape)

            # This reshape is necessary because the output of `transformer` has unknown static dims
            vae_input = tf.reshape(
                vae_input,
                (self.batch_size, *self.object_shape, self.image_depth))

            # --- Apply Object-level VAE (object encoder/object decoder) to windows ---

            attr = self.object_encoder(vae_input, 2 * self.A, self.is_training)
            attr_mean, attr_log_std = tf.split(attr, 2, axis=1)
            attr_std = tf.exp(attr_log_std)
            attr, attr_kl = normal_vae(attr_mean, attr_std,
                                       self.attr_prior_mean,
                                       self.attr_prior_std)
            attr_kl = tf.reduce_sum(attr_kl, axis=1, keepdims=True)

            vae_output = self.object_decoder(
                attr,
                self.object_shape[0] * self.object_shape[1] * self.image_depth,
                self.is_training)
            vae_output = tf.nn.sigmoid(tf.clip_by_value(vae_output, -10, 10))

            # --- Place reconstructed objects in image ---

            theta_inverse = tf.concat(
                [1. / w, tf.zeros_like(w), -x / w,
                 tf.zeros_like(h), 1. / h, -y / h], axis=1)
            theta_inverse = tf.reshape(theta_inverse, (-1, 2, 3))

            vae_output_transformed = transformer(
                tf.reshape(vae_output, (
                    self.batch_size,
                    *self.object_shape,
                    self.image_depth,
                )), theta_inverse, self.obs_shape[:2])
            vae_output_transformed = tf.reshape(vae_output_transformed, [
                self.batch_size,
                self.image_height * self.image_width * self.image_depth
            ])

            # --- z_pres ---

            if self.run_all_time_steps:
                z_pres = tf.ones_like(z_pres_log_odds)
                z_pres_prob = tf.ones_like(z_pres_log_odds)
                z_pres_kl = tf.zeros_like(z_pres_log_odds)
            else:
                z_pres_log_odds = tf.clip_by_value(z_pres_log_odds, -10, 10)

                z_pres_pre_sigmoid = concrete_binary_pre_sigmoid_sample(
                    z_pres_log_odds, self.z_pres_temperature)
                z_pres = tf.nn.sigmoid(z_pres_pre_sigmoid)
                z_pres = (self.float_is_training * z_pres +
                          (1 - self.float_is_training) * tf.round(z_pres))
                z_pres_prob = tf.nn.sigmoid(z_pres_log_odds)
                z_pres_kl = concrete_binary_sample_kl(
                    z_pres_pre_sigmoid,
                    z_pres_log_odds,
                    self.z_pres_temperature,
                    self.z_pres_prior_log_odds,
                    self.z_pres_temperature,
                )

            stopping_sum += (1.0 - z_pres)
            alive = tf.less(stopping_sum, self.stopping_threshold)
            running_digits += tf.to_int32(alive)

            # --- adjust reconstruction ---

            running_recon += tf.where(
                tf.tile(alive, (1, vae_output_transformed.shape[1])),
                z_pres * vae_output_transformed, tf.zeros_like(running_recon))

            # --- add kl to loss ---

            kl_loss += tf.where(alive, scale_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, shift_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, attr_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, z_pres_kl, tf.zeros_like(kl_loss))

            # --- record values ---

            scale_ta = scale_ta.write(scale_ta.size(), scale)
            scale_kl_ta = scale_kl_ta.write(scale_kl_ta.size(), scale_kl)
            scale_std_ta = scale_std_ta.write(scale_std_ta.size(), scale_std)

            shift_ta = shift_ta.write(shift_ta.size(), shift)
            shift_kl_ta = shift_kl_ta.write(shift_kl_ta.size(), shift_kl)
            shift_std_ta = shift_std_ta.write(shift_std_ta.size(), shift_std)

            attr_ta = attr_ta.write(attr_ta.size(), attr)
            attr_kl_ta = attr_kl_ta.write(attr_kl_ta.size(), attr_kl)
            attr_std_ta = attr_std_ta.write(attr_std_ta.size(), attr_std)

            vae_input_ta = vae_input_ta.write(vae_input_ta.size(),
                                              tf.layers.flatten(vae_input))
            vae_output_ta = vae_output_ta.write(vae_output_ta.size(),
                                                vae_output)

            z_pres_ta = z_pres_ta.write(z_pres_ta.size(), z_pres)
            z_pres_probs_ta = z_pres_probs_ta.write(z_pres_probs_ta.size(),
                                                    z_pres_prob)
            z_pres_kl_ta = z_pres_kl_ta.write(z_pres_kl_ta.size(), z_pres_kl)

            return (
                step + 1,
                stopping_sum,
                next_state,
                running_recon,
                kl_loss,
                running_digits,
                scale_ta,
                scale_kl_ta,
                scale_std_ta,
                shift_ta,
                shift_kl_ta,
                shift_std_ta,
                attr_ta,
                attr_kl_ta,
                attr_std_ta,
                z_pres_ta,
                z_pres_probs_ta,
                z_pres_kl_ta,
                vae_input_ta,
                vae_output_ta,
                scale,
                shift,
                attr,
                z_pres,
            )
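This body is written for tf.while_loop: the TensorArrays are threaded through the loop variables and grow by one entry per step. A self-contained toy of that pattern (names, shapes and the stopping condition are purely illustrative, not taken from the surrounding model):

    import tensorflow as tf

    def toy_cond(step, values_ta):
        return step < 3

    def toy_body(step, values_ta):
        # Append one entry per iteration, mirroring the *_ta.write(...) calls above.
        values_ta = values_ta.write(values_ta.size(), tf.fill([2], tf.cast(step, tf.float32)))
        return step + 1, values_ta

    _, values_ta = tf.while_loop(
        toy_cond, toy_body,
        [tf.constant(0), tf.TensorArray(tf.float32, size=0, dynamic_size=True)])
    values = values_ta.stack()  # shape (3, 2), one row per loop iteration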
Example #5
    def build_representation(self):
        # --- init modules ---

        if self.encoder is None:
            self.encoder = self.build_encoder(scope="encoder")
            if "encoder" in self.fixed_weights:
                self.encoder.fix_variables()

        if self.cell is None and self.build_cell is not None:
            self.cell = cfg.build_cell(scope="cell")
            if "cell" in self.fixed_weights:
                self.cell.fix_variables()

        if self.decoder is None:
            self.decoder = cfg.build_decoder(scope="decoder")
            if "decoder" in self.fixed_weights:
                self.decoder.fix_variables()

        # --- encode ---

        inp_trailing_shape = tf_shape(self.inp)[2:]
        video = tf.reshape(self.inp, (self.batch_size * self.dynamic_n_frames, *inp_trailing_shape))
        encoder_output = self.encoder(video, 2 * self.A, self.is_training)

        eo_trailing_shape = tf_shape(encoder_output)[1:]
        encoder_output = tf.reshape(
            encoder_output, (self.batch_size, self.dynamic_n_frames, *eo_trailing_shape))

        if self.cell is None:
            attr = encoder_output
        else:

            if self.flat_latent:
                n_trailing_dims = int(np.prod(eo_trailing_shape))
                encoder_output = tf.reshape(
                    encoder_output, (self.batch_size, self.dynamic_n_frames, n_trailing_dims))
            else:
                raise Exception("NotImplemented")

                n_objects = int(np.prod(eo_trailing_shape[:-1]))
                D = eo_trailing_shape[-1]
                encoder_output = tf.reshape(
                    encoder_output, (self.batch_size, self.dynamic_n_frames, n_objects, D))

            encoder_output = tf.layers.flatten(encoder_output)

            attr, final_state = dynamic_rnn(
                self.cell, encoder_output, initial_state=self.cell.zero_state(self.batch_size, tf.float32),
                parallel_iterations=1, swap_memory=False, time_major=False)

        attr_mean, attr_log_std = tf.split(attr, 2, axis=-1)
        attr_std = tf.math.softplus(attr_log_std)

        if not self.noisy:
            attr_std = tf.zeros_like(attr_std)

        attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)

        self._tensors.update(attr_mean=attr_mean, attr_std=attr_std, attr_kl=attr_kl, attr=attr)

        # --- decode ---

        decoder_input = tf.reshape(attr, (self.batch_size*self.dynamic_n_frames, *tf_shape(attr)[2:]))

        reconstruction = self.decoder(decoder_input, tf_shape(self.inp)[2:], self.is_training)
        reconstruction = reconstruction[:, :self.obs_shape[1], :self.obs_shape[2], :]
        reconstruction = tf.reshape(reconstruction, (self.batch_size, self.dynamic_n_frames, *self.obs_shape[1:]))

        reconstruction = tf.nn.sigmoid(tf.clip_by_value(reconstruction, -10, 10))
        self._tensors["output"] = reconstruction

        # --- losses ---

        if self.train_kl:
            self.losses['attr_kl'] = tf_mean_sum(self._tensors["attr_kl"])

        if self.train_reconstruction:
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=reconstruction, label=self.inp)
            self.losses['reconstruction'] = tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])
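tf_shape, used here and in the next example, is another project helper that is not reproduced on this page. Presumably it returns the static dimension where it is known and the dynamic one otherwise, which is the common pattern sketched below (the _sketch name is ours):

    import tensorflow as tf

    def tf_shape_sketch(t):
        # Static dim where available, otherwise the corresponding dynamic dim.
        static, dynamic = t.shape.as_list(), tf.shape(t)
        return [s if s is not None else dynamic[i] for i, s in enumerate(static)]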
Example #6
    def build_background(self):
        if cfg.background_cfg.mode == "colour":
            rgb = np.array(to_rgb(cfg.background_cfg.colour))[None, None, None, :]
            background = rgb * tf.ones_like(self.inp)

        elif cfg.background_cfg.mode == "learn_solid":
            # Learn a solid colour for the background
            self.solid_background_logits = tf.get_variable("solid_background", initializer=[0.0, 0.0, 0.0])
            if "background" in self.fixed_weights:
                tf.add_to_collection(FIXED_COLLECTION, self.solid_background_logits)
            solid_background = tf.nn.sigmoid(10 * self.solid_background_logits)
            background = solid_background[None, None, None, :] * tf.ones_like(self.inp)

        elif cfg.background_cfg.mode == "scalor":
            pass

        elif cfg.background_cfg.mode == "learn":
            self.maybe_build_subnet("background_encoder")
            self.maybe_build_subnet("background_decoder")

            # Here I'm just encoding the first frame...
            bg_attr = self.background_encoder(self.inp[:, 0], 2 * cfg.background_cfg.A, self.is_training)
            bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
            bg_attr_std = tf.exp(bg_attr_log_std)
            if not self.noisy:
                bg_attr_std = tf.zeros_like(bg_attr_std)

            bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

            self._tensors.update(
                bg_attr_mean=bg_attr_mean,
                bg_attr_std=bg_attr_std,
                bg_attr_kl=bg_attr_kl,
                bg_attr=bg_attr)

            # --- decode ---

            _, T, H, W, _ = tf_shape(self.inp)

            background = self.background_decoder(bg_attr, 3, self.is_training)
            assert len(background.shape) == 2 and background.shape[1] == 3
            background = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))
            background = tf.tile(background[:, None, None, None, :], (1, T, H, W, 1))

        elif cfg.background_cfg.mode == "learn_and_transform":
            self.maybe_build_subnet("background_encoder")
            self.maybe_build_subnet("background_decoder")

            # --- encode ---

            n_transform_latents = 4
            n_latents = (2 * cfg.background_cfg.A, 2 * n_transform_latents)

            bg_attr, bg_transform_params = self.background_encoder(self.inp, n_latents, self.is_training)

            # --- bg attributes ---

            bg_attr_mean, bg_attr_log_std = tf.split(bg_attr, 2, axis=-1)
            bg_attr_std = self.std_nonlinearity(bg_attr_log_std)

            bg_attr, bg_attr_kl = normal_vae(bg_attr_mean, bg_attr_std, self.attr_prior_mean, self.attr_prior_std)

            # --- bg location ---

            bg_transform_params = tf.reshape(
                bg_transform_params,
                (self.batch_size, self.dynamic_n_frames, 2*n_transform_latents))

            mean, log_std = tf.split(bg_transform_params, 2, axis=2)
            std = self.std_nonlinearity(log_std)

            logits, kl = normal_vae(mean, std, 0.0, 1.0)

            # integrate across timesteps
            logits = tf.cumsum(logits, axis=1)
            logits = tf.reshape(logits, (self.batch_size*self.dynamic_n_frames, n_transform_latents))

            y, x, h, w = tf.split(logits, n_transform_latents, axis=1)
            h = (0.9 - 0.5) * tf.nn.sigmoid(h) + 0.5
            w = (0.9 - 0.5) * tf.nn.sigmoid(w) + 0.5
            y = (1 - h) * tf.nn.tanh(y)
            x = (1 - w) * tf.nn.tanh(x)

            # --- decode ---

            background = self.background_decoder(bg_attr, self.image_depth, self.is_training)
            bg_shape = cfg.background_cfg.bg_shape
            background = background[:, :bg_shape[0], :bg_shape[1], :]
            assert background.shape[1:3] == bg_shape
            background_raw = tf.nn.sigmoid(tf.clip_by_value(background, -10, 10))

            transform_constraints = snt.AffineWarpConstraints.no_shear_2d()

            warper = snt.AffineGridWarper(
                bg_shape, (self.image_height, self.image_width), transform_constraints)

            transforms = tf.concat([w, x, h, y], axis=-1)
            grid_coords = warper(transforms)

            grid_coords = tf.reshape(
                grid_coords,
                (self.batch_size, self.dynamic_n_frames, *tf_shape(grid_coords)[1:]))

            background = tf.contrib.resampler.resampler(background_raw, grid_coords)

            self._tensors.update(
                bg_attr_mean=bg_attr_mean,
                bg_attr_std=bg_attr_std,
                bg_attr_kl=bg_attr_kl,
                bg_attr=bg_attr,
                bg_y=tf.reshape(y, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_x=tf.reshape(x, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_h=tf.reshape(h, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_w=tf.reshape(w, (self.batch_size, self.dynamic_n_frames, 1)),
                bg_transform_kl=kl,
                bg_raw=background_raw,
            )

        elif cfg.background_cfg.mode == "data":
            background = self._tensors["ground_truth_background"]

        else:
            raise Exception("Unrecognized background mode: {}.".format(cfg.background_cfg.mode))

        self._tensors["background"] = background[:, :self.dynamic_n_frames]