Ejemplo n.º 1
0
    def build_representation(self):
        """Build the model graph, then register recorded values, losses and metrics.

        The program generator is built first, the object encoder and decoder
        are created from ``cfg`` when they have not been supplied already, and
        the program interpreter is built last. Results are communicated through
        ``self._tensors`` and ``self.losses``; nothing is returned.
        """

        # --- build graph ---

        self._build_program_generator()

        # Create each missing object-level sub-network. A freshly created
        # network that is listed in `fixed_weights` gets `fix_variables()`
        # called on it; networks supplied from outside are left untouched.
        for subnet in ("object_encoder", "object_decoder"):
            if getattr(self, subnet) is not None:
                continue
            builder = getattr(cfg, "build_" + subnet)
            setattr(self, subnet, builder(scope=subnet))
            if subnet in self.fixed_weights:
                getattr(self, subnet).fix_variables()

        self._build_program_interpreter()

        # --- specify values to record ---

        recorded = {key: self._tensors[key] for key in ("n_objects", "attr")}
        self.record_tensors(**recorded)

        # --- losses ---

        if self.train_reconstruction:
            pixel_loss = xent_loss(
                pred=self._tensors['output'], label=self._tensors['inp'])
            self._tensors['per_pixel_reconstruction_loss'] = pixel_loss
            self.losses['reconstruction'] = (
                self.reconstruction_weight * tf_mean_sum(pixel_loss))

        if self.train_kl:
            # The attribute KL is multiplied by `obj` before it is reduced.
            masked_kl = self._tensors["obj"] * self._tensors["attr_kl"]
            self.losses['attr_kl'] = self.kl_weight * tf_mean_sum(masked_kl)

        # --- other evaluation metrics

        if "n_annotations" not in self._tensors:
            return

        # Absolute difference between the predicted object count (cast to
        # int32) and the annotation count, plus a flag for differences > 0.5.
        predicted_count = tf.to_int32(self._tensors["n_objects"])
        count_gap = tf.abs(predicted_count - self._tensors["n_annotations"])
        count_1norm = tf.to_float(count_gap)
        self.record_tensors(count_1norm=count_1norm,
                            count_error=count_1norm > 0.5)
Ejemplo n.º 2
0
    def build_representation(self):
        """Encode the input, sample attributes, decode them and set up the losses.

        The encoder output is split in half along the last axis into a mean and
        a log-std, a sample and its KL are obtained from ``normal_vae``, and the
        decoder output is cropped to the input's axis-1/axis-2 extents and
        squashed with a sigmoid. Results are stored in ``self._tensors`` and
        ``self.losses``; nothing is returned.
        """
        # --- init modules ---

        # Create whichever module is missing; a freshly created module that is
        # listed in `fixed_weights` gets `fix_variables()` called on it.
        for module in ("encoder", "decoder"):
            if getattr(self, module) is None:
                setattr(self, module, getattr(cfg, "build_" + module)(scope=module))
                if module in self.fixed_weights:
                    getattr(self, module).fix_variables()

        # --- encode ---

        encoded = self.encoder(self.inp, 2 * self.A, self.is_training)
        posterior_mean, posterior_log_std = tf.split(encoded, 2, axis=-1)
        posterior_std = tf.exp(posterior_log_std)

        if not self.noisy:
            # With noise disabled the std is replaced by zeros of the same shape.
            # NOTE(review): how `normal_vae` treats a zero std is not visible
            # here -- confirm the KL stays finite when `train_kl` is also set.
            posterior_std = tf.zeros_like(posterior_std)

        sample, sample_kl = normal_vae(
            posterior_mean, posterior_std, self.attr_prior_mean, self.attr_prior_std)

        # "obj" is all ones, shaped like the sample with its last axis set to 1.
        leading_dims = tf.shape(sample)[:-1]
        self._tensors["obj"] = tf.ones(tf.concat([leading_dims, [1]], axis=0))

        self._tensors.update(
            attr_mean=posterior_mean,
            attr_std=posterior_std,
            attr_kl=sample_kl,
            attr=sample,
        )

        # --- decode ---

        # NOTE(review): the literal 3 is presumably the number of output
        # channels requested from the decoder -- confirm against the decoder.
        decoded = self.decoder(sample, 3, self.is_training)
        height, width = self.inp.shape[1], self.inp.shape[2]
        cropped = decoded[:, :height, :width, :]

        # Values are clipped to [-10, 10] before the sigmoid is applied.
        reconstruction = tf.nn.sigmoid(tf.clip_by_value(cropped, -10, 10))
        self._tensors["output"] = reconstruction

        # --- losses ---

        if self.train_kl:
            self.losses['attr_kl'] = tf_mean_sum(self._tensors["attr_kl"])

        if self.train_reconstruction:
            pixel_loss = xent_loss(pred=reconstruction, label=self.inp)
            self._tensors['per_pixel_reconstruction_loss'] = pixel_loss
            self.losses['reconstruction'] = tf_mean_sum(pixel_loss)
Ejemplo n.º 3
0
    def build_representation(self):
        """Use the background tensor as the output and register the losses.

        ``self._tensors["output"]`` is set to ``self._tensors["background"]``.
        A reconstruction loss is added when ``train_reconstruction`` is set,
        and the two background KL losses are added whenever ``"bg_attr_kl"``
        is present in ``self._tensors``. Nothing is returned.
        """
        tensors = self._tensors
        background = tensors["background"]
        tensors["output"] = background

        # --- losses ---

        if self.train_reconstruction:
            pixel_loss = xent_loss(pred=background, label=self.inp)
            tensors['per_pixel_reconstruction_loss'] = pixel_loss
            self.losses['reconstruction'] = tf_mean_sum(pixel_loss)

        if "bg_attr_kl" in tensors:
            # Both terms are computed before either is stored, so a missing
            # "bg_transform_kl" leaves `self.losses` untouched.
            background_kl = {
                key: self.kl_weight * tf_mean_sum(tensors[key])
                for key in ("bg_attr_kl", "bg_transform_kl")
            }
            self.losses.update(background_kl)
Ejemplo n.º 4
0
    def build_representation(self):
        """Build the sub-networks and program stages, then register outputs.

        The object encoder and decoder are requested through
        ``maybe_build_subnet``; the program generator and interpreter are then
        run in that order, each contributing tensors to ``self._tensors``.
        Recorded values, losses and count metrics are registered afterwards.
        Nothing is returned.
        """

        for subnet in ("object_encoder", "object_decoder"):
            self.maybe_build_subnet(subnet)

        # Each stage receives the tensors gathered so far and returns new ones
        # to merge in, so the interpreter sees the generator's outputs.
        for build_stage in (self._build_program_generator,
                            self._build_program_interpreter):
            stage_tensors = build_stage(self._tensors)
            self._tensors.update(stage_tensors)

        # --- specify values to record ---

        recorded = {key: self._tensors[key] for key in ("n_objects", "attr")}
        self.record_tensors(**recorded)

        # --- losses ---

        if self.train_reconstruction:
            pixel_loss = xent_loss(
                pred=self._tensors['output'], label=self._tensors['inp'])
            self._tensors['per_pixel_reconstruction_loss'] = pixel_loss
            self.losses['reconstruction'] = (
                self.reconstruction_weight * tf_mean_sum(pixel_loss))

        if self.train_kl:
            # The attribute KL is multiplied by `obj` before it is reduced.
            masked_kl = self._tensors["obj"] * self._tensors["attr_kl"]
            self.losses['attr_kl'] = self.kl_weight * tf_mean_sum(masked_kl)

        # --- other evaluation metrics

        if "n_annotations" not in self._tensors:
            return

        # Absolute difference between the predicted object count (cast to
        # int32) and the annotation count, plus a flag for differences > 0.5.
        predicted_count = tf.to_int32(self._tensors["n_objects"])
        count_gap = tf.abs(predicted_count - self._tensors["n_annotations"])
        count_1norm = tf.to_float(count_gap)
        self.record_tensors(count_1norm=count_1norm,
                            count_error=count_1norm > 0.5)
Ejemplo n.º 5
0
    def build_representation(self):
        """Build the per-frame object network and register records, losses and metrics.

        Sub-networks that are still ``None`` are created through the matching
        ``build_*`` attribute. The backbone is applied once to ``self.inp``; the
        feature fuser and the object layer are then applied frame by frame, and
        the per-frame outputs are stacked along axis 1 into ``self._tensors``.
        Losses are written to ``self.losses``. Nothing is returned.
        """
        # --- init modules ---
        # B is the number of anchor boxes.
        self.B = len(self.anchor_boxes)

        # Each sub-network is only created when it has not been supplied; a
        # freshly created one that is listed in `fixed_weights` gets
        # `fix_variables()` called on it.
        if self.backbone is None:
            self.backbone = self.build_backbone(scope="backbone")
            if "backbone" in self.fixed_weights:
                self.backbone.fix_variables()

        if self.feature_fuser is None:
            self.feature_fuser = self.build_feature_fuser(scope="feature_fuser")
            if "feature_fuser" in self.fixed_weights:
                self.feature_fuser.fix_variables()

        # The object feature extractor is optional: it stays None when no
        # builder is configured, and the per-frame loop below then skips it.
        if self.obj_feature_extractor is None and self.build_obj_feature_extractor is not None:
            self.obj_feature_extractor = self.build_obj_feature_extractor(scope="obj_feature_extractor")
            if "obj_feature_extractor" in self.fixed_weights:
                self.obj_feature_extractor.fix_variables()

        # The backbone returns its output together with the grid dimensions
        # and the size of one grid cell.
        backbone_output, n_grid_cells, grid_cell_size = self.backbone(
            self.inp, self.B*self.n_backbone_features, self.is_training)

        self.H, self.W = [int(i) for i in n_grid_cells]
        self.HWB = self.H * self.W * self.B
        self.pixels_per_cell = tuple(int(i) for i in grid_cell_size)
        H, W, B = self.H, self.W, self.B

        if self.object_layer is None:
            self.object_layer = ObjectLayer(self.pixels_per_cell, scope="objects")

        # `self.object_rep_tensors` keeps the object layer's output for every
        # frame; `object_rep_tensors` holds the most recent one so the next
        # frame can use it; `_tensors` gathers per-key lists for stacking.
        self.object_rep_tensors = []
        object_rep_tensors = None
        _tensors = defaultdict(list)

        for f in range(self.n_frames):
            # Progress message emitted once per frame while the graph is built.
            print("Bulding network for frame {}".format(f))
            # Axis 1 of the backbone output is indexed by frame.
            early_frame_features = backbone_output[:, f]

            # From the second frame on (when an extractor exists), features
            # derived from the previous frame's "all" tensor are added to the
            # current frame's backbone features.
            if f > 0 and self.obj_feature_extractor is not None:
                object_features = object_rep_tensors["all"]
                # Fold the last axis into (batch_size, H, W, B * last_dim).
                object_features = tf.reshape(
                    object_features, (self.batch_size, H, W, B*tf_shape(object_features)[-1]))
                early_frame_features += self.obj_feature_extractor(
                    object_features, B*self.n_backbone_features, self.is_training)

            frame_features = self.feature_fuser(
                early_frame_features, B*self.n_backbone_features, self.is_training)

            # Split the fused features into one feature vector per anchor box.
            frame_features = tf.reshape(
                frame_features, (self.batch_size, H, W, B, self.n_backbone_features))

            object_rep_tensors = self.object_layer(
                self.inp[:, f], frame_features, self._tensors["background"][:, f], self.is_training)

            self.object_rep_tensors.append(object_rep_tensors)

            for k, v in object_rep_tensors.items():
                _tensors[k].append(v)

        # Stack every per-frame list along axis 1 and publish it under its key.
        self._tensors.update(**{k: tf.stack(v, axis=1) for k, v in _tensors.items()})

        # --- specify values to record ---

        obj = self._tensors["obj"]
        pred_n_objects = self._tensors["pred_n_objects"]

        self.record_tensors(
            batch_size=self.batch_size,
            float_is_training=self.float_is_training,

            cell_y=self._tensors["cell_y"],
            cell_x=self._tensors["cell_x"],
            h=self._tensors["h"],
            w=self._tensors["w"],
            z=self._tensors["z"],
            area=self._tensors["area"],

            cell_y_std=self._tensors["cell_y_std"],
            cell_x_std=self._tensors["cell_x_std"],
            h_std=self._tensors["h_std"],
            w_std=self._tensors["w_std"],
            z_std=self._tensors["z_std"],

            n_objects=pred_n_objects,
            obj=obj,

            latent_area=self._tensors["latent_area"],
            latent_hw=self._tensors["latent_hw"],

            attr=self._tensors["attr"],
        )

        # --- losses ---

        if self.train_reconstruction:
            output = self._tensors['output']
            inp = self._tensors['inp']
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=output, label=inp)
            self.losses['reconstruction'] = (
                self.reconstruction_weight * tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])
            )

        if self.train_kl:
            # Every KL term except `obj_kl` is multiplied by `obj` before it
            # is reduced; all of them are scaled by `kl_weight`.
            self.losses.update(
                obj_kl=self.kl_weight * tf_mean_sum(self._tensors["obj_kl"]),
                cell_y_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["cell_y_kl"]),
                cell_x_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["cell_x_kl"]),
                h_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["h_kl"]),
                w_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["w_kl"]),
                z_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["z_kl"]),
                attr_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["attr_kl"]),
            )

            # Background KL terms are only added in "learn_and_transform" mode.
            if cfg.background_cfg.mode == "learn_and_transform":
                self.losses.update(
                    bg_attr_kl=self.kl_weight * tf_mean_sum(self._tensors["bg_attr_kl"]),
                    bg_transform_kl=self.kl_weight * tf_mean_sum(self._tensors["bg_transform_kl"]),
                )

        # --- other evaluation metrics ---

        # NOTE(review): the guard tests for "n_annotations" while the body
        # reads "n_valid_annotations" -- presumably both keys are always set
        # together; verify against the code that populates `_tensors`.
        if "n_annotations" in self._tensors:
            # Absolute difference between the hard predicted count (cast to
            # int32) and the number of valid annotations.
            count_1norm = tf.to_float(
                tf.abs(tf.to_int32(self._tensors["pred_n_objects_hard"]) - self._tensors["n_valid_annotations"]))

            self.record_tensors(
                count_1norm=count_1norm,
                count_error=count_1norm > 0.5,
            )
Ejemplo n.º 6
0
    def build_representation(self):
        """Build the grid object network and register records, losses and metrics.

        The backbone (which must be a ``GridConvNet``) is applied to the input,
        the grid geometry is read from its last ``layer_info`` entry, and the
        object layer, KL computation and renderer are run in that order, each
        merging its tensors into ``self._tensors``. Losses are written to
        ``self.losses``. Nothing is returned.
        """
        # --- build graph ---

        self.maybe_build_subnet("backbone")
        assert isinstance(self.backbone, GridConvNet)

        image = self._tensors["inp"]
        backbone_output = self.backbone(
            image, self.n_backbone_features, self.is_training)

        # Grid geometry is reported by the final backbone layer.
        final_layer = self.backbone.layer_info[-1]
        n_grid_cells = final_layer['n_grid_cells']
        grid_cell_size = final_layer['grid_cell_size']

        n_rows, n_cols = (int(n) for n in n_grid_cells)
        self.H, self.W = n_rows, n_cols
        self.HWB = self.H * self.W * self.B
        self.pixels_per_cell = tuple(int(size) for size in grid_cell_size)

        if self.object_layer is None:
            # Only the class selected by `conv_object_layer` is looked up.
            layer_class = (
                ConvGridObjectLayer if self.conv_object_layer else GridObjectLayer)
            self.object_layer = layer_class(
                pixels_per_cell=self.pixels_per_cell, scope="objects")

        if self.object_renderer is None:
            self.object_renderer = ObjectRenderer(
                self.anchor_box, self.object_shape, scope="renderer")

        objects = self.object_layer(
            self.inp, backbone_output, self.is_training)
        self._tensors.update(objects)

        object_kl = self.object_layer.compute_kl(objects)
        self._tensors.update(object_kl)

        if self.obj_kl is None:
            self.obj_kl = self.build_obj_kl()

        # The presence KL is computed from the full tensor collection.
        self._tensors['obj_kl'] = self.obj_kl(self._tensors)

        rendered = self.object_renderer(
            objects, self._tensors["background"], self.is_training)
        self._tensors.update(rendered)

        # --- specify values to record ---

        obj = self._tensors["obj"]
        box_params = ("cell_y", "cell_x", "height", "width", "z")

        # Assembled in the order: bookkeeping values, box parameters, their
        # "<name>_logit_std" tensors (recorded as "<name>_std"), then the rest.
        recorded = dict(
            batch_size=self.batch_size,
            float_is_training=self.float_is_training,
        )
        for param in box_params:
            recorded[param] = self._tensors[param]
        for param in box_params:
            recorded[param + "_std"] = self._tensors[param + "_logit_std"]
        recorded.update(
            obj=obj,
            attr=self._tensors["attr"],
            pred_n_objects=self._tensors["pred_n_objects"],
        )
        self.record_tensors(**recorded)

        # --- losses ---

        if self.train_reconstruction:
            pixel_loss = xent_loss(
                pred=self._tensors['output'], label=self._tensors['inp'])
            self._tensors['per_pixel_reconstruction_loss'] = pixel_loss
            self.losses['reconstruction'] = (
                self.reconstruction_weight * tf_mean_sum(pixel_loss))

        if self.train_kl:
            # `obj_kl` is reduced as is; every other KL term is multiplied by
            # `obj` first. All terms are scaled by `kl_weight`.
            kl_losses = {
                "obj_kl": self.kl_weight * tf_mean_sum(self._tensors["obj_kl"]),
            }
            for param in box_params + ("attr",):
                key = param + "_kl"
                kl_losses[key] = (
                    self.kl_weight * tf_mean_sum(obj * self._tensors[key]))
            self.losses.update(kl_losses)

        # --- other evaluation metrics ---

        if "n_annotations" in self._tensors:
            hard_count = tf.to_int32(self._tensors["pred_n_objects_hard"])
            n_valid = self._tensors["n_valid_annotations"]
            count_1norm = tf.to_float(tf.abs(hard_count - n_valid))

            # The denominator is floored at 1e-6 to avoid dividing by zero.
            relative_denominator = tf.maximum(
                tf.cast(self._tensors["n_valid_annotations"], tf.float32), 1e-6)
            count_1norm_relative = count_1norm / relative_denominator

            self.record_tensors(
                count_1norm_relative=count_1norm_relative,
                count_1norm=count_1norm,
                count_error=count_1norm > 0.5,
            )
Ejemplo n.º 7
0
    def build_representation(self):
        """Build the grid object network and register records, losses and metrics.

        The backbone (which must be a ``GridConvNet``) is applied to the input
        and its output is reshaped to ``(-1, H, W, B, n_backbone_features)``.
        The object layer, KL computation and renderer are then run in that
        order, each merging its tensors into ``self._tensors``. Losses are
        written to ``self.losses``. Nothing is returned.

        The recorded ``on_*_avg`` values divide by the predicted object count
        floored at 1e-6, so an image with no predicted objects records a
        finite value instead of NaN/inf.
        """
        # --- build graph ---

        self.maybe_build_subnet("backbone")
        assert isinstance(self.backbone, GridConvNet)

        inp = self._tensors["inp"]
        backbone_output, n_grid_cells, grid_cell_size = self.backbone(
            inp, self.B * self.n_backbone_features, self.is_training)

        self.H, self.W = [int(i) for i in n_grid_cells]
        self.HWB = self.H * self.W * self.B
        self.pixels_per_cell = tuple(int(i) for i in grid_cell_size)

        # Give every anchor box in every grid cell its own feature vector.
        backbone_output = tf.reshape(
            backbone_output,
            (-1, self.H, self.W, self.B, self.n_backbone_features))

        if self.object_layer is None:
            self.object_layer = GridObjectLayer(self.pixels_per_cell,
                                                scope="objects")

        if self.object_renderer is None:
            self.object_renderer = ObjectRenderer(scope="renderer")

        objects = self.object_layer(self.inp, backbone_output,
                                    self.is_training)
        self._tensors.update(objects)

        kl_tensors = self.object_layer.compute_kl(objects)
        self._tensors.update(kl_tensors)

        render_tensors = self.object_renderer(objects,
                                              self._tensors["background"],
                                              self.is_training)
        self._tensors.update(render_tensors)

        # --- specify values to record ---

        obj = self._tensors["obj"]
        pred_n_objects = self._tensors["pred_n_objects"]

        # Dividing by a zero object count would record NaN/inf, so the
        # denominator of the averages below is floored at a small constant.
        safe_n_objects = tf.maximum(pred_n_objects, 1e-6)

        def on_average(name):
            """Sum `name` times `obj` over axes 1-4, divided by the object count."""
            weighted_sum = tf.reduce_sum(
                self._tensors[name] * obj, axis=(1, 2, 3, 4))
            return weighted_sum / safe_n_objects

        self.record_tensors(
            batch_size=self.batch_size,
            float_is_training=self.float_is_training,
            cell_y=self._tensors["cell_y"],
            cell_x=self._tensors["cell_x"],
            height=self._tensors["height"],
            width=self._tensors["width"],
            z=self._tensors["z"],
            cell_y_std=self._tensors["cell_y_logit_dist"].scale,
            cell_x_std=self._tensors["cell_x_logit_dist"].scale,
            height_std=self._tensors["height_logit_dist"].scale,
            width_std=self._tensors["width_logit_dist"].scale,
            z_std=self._tensors["z_logit_dist"].scale,
            n_objects=pred_n_objects,
            obj=obj,
            on_cell_y_avg=on_average("cell_y"),
            on_cell_x_avg=on_average("cell_x"),
            on_height_avg=on_average("height"),
            on_width_avg=on_average("width"),
            on_z_avg=on_average("z"),
            attr=self._tensors["attr"],
        )

        # --- losses ---

        if self.train_reconstruction:
            output = self._tensors['output']
            inp = self._tensors['inp']
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(
                pred=output, label=inp)
            self.losses['reconstruction'] = (
                self.reconstruction_weight *
                tf_mean_sum(self._tensors['per_pixel_reconstruction_loss']))

        if self.train_kl:
            # `obj_kl` is reduced as is; every other KL term is multiplied by
            # `obj` first. All terms are scaled by `kl_weight`.
            self.losses.update(
                obj_kl=self.kl_weight * tf_mean_sum(self._tensors["obj_kl"]),
                cell_y_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["cell_y_kl"]),
                cell_x_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["cell_x_kl"]),
                height_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["height_kl"]),
                width_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["width_kl"]),
                z_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["z_kl"]),
                attr_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["attr_kl"]),
            )

        # --- other evaluation metrics ---

        # NOTE(review): the guard tests for "n_annotations" while the body
        # reads "n_valid_annotations" -- presumably both keys are always set
        # together; verify against the code that populates `_tensors`.
        if "n_annotations" in self._tensors:
            # `tf.cast` replaces the deprecated `tf.to_int32` / `tf.to_float`
            # helpers; the resulting dtypes are unchanged.
            hard_count = tf.cast(self._tensors["pred_n_objects_hard"], tf.int32)
            count_1norm = tf.cast(
                tf.abs(hard_count - self._tensors["n_valid_annotations"]),
                tf.float32)

            self.record_tensors(
                count_1norm=count_1norm,
                count_error=count_1norm > 0.5,
            )
Ejemplo n.º 8
0
    def build_representation(self):
        # --- process input ---

        if self.image_encoder is None:
            self.image_encoder = cfg.build_image_encoder(scope="image_encoder")
            if "image_encoder" in self.fixed_weights:
                self.image_encoder.fix_variables()

        if self.cell is None:
            self.cell = cfg.build_cell(scope="cell")
            if "cell" in self.fixed_weights:
                self.cell.fix_variables()

        if self.output_network is None:
            self.output_network = cfg.build_output_network(
                scope="output_network")
            if "output" in self.fixed_weights:
                self.output_network.fix_variables()

        if self.object_encoder is None:
            self.object_encoder = cfg.build_object_encoder(
                scope="object_encoder")
            if "object_encoder" in self.fixed_weights:
                self.object_encoder.fix_variables()

        if self.object_decoder is None:
            self.object_decoder = cfg.build_object_decoder(
                scope="object_decoder")
            if "object_decoder" in self.fixed_weights:
                self.object_decoder.fix_variables()

        self.target_n_digits = self._tensors["n_valid_annotations"]

        if not self.difference_air:
            encoded_inp = self.image_encoder(self._tensors["inp"], 0,
                                             self.is_training)
            self.encoded_inp = tf.layers.flatten(encoded_inp)

        # --- condition of while-loop ---

        def cond(step, stopping_sum, *_):
            return tf.logical_and(
                tf.less(step, self.max_time_steps),
                tf.reduce_any(tf.less(stopping_sum, self.stopping_threshold)))

        # --- body of while-loop ---

        def body(step, stopping_sum, prev_state, running_recon, kl_loss,
                 running_digits, scale_ta, scale_kl_ta, scale_std_ta, shift_ta,
                 shift_kl_ta, shift_std_ta, attr_ta, attr_kl_ta, attr_std_ta,
                 z_pres_ta, z_pres_probs_ta, z_pres_kl_ta, vae_input_ta,
                 vae_output_ta, scale, shift, attr, z_pres):

            if self.difference_air:
                inp = (self._tensors["inp"] -
                       tf.reshape(running_recon,
                                  (self.batch_size, *self.obs_shape)))
                encoded_inp = self.image_encoder(inp, 0, self.is_training)
                encoded_inp = tf.layers.flatten(encoded_inp)
            else:
                encoded_inp = self.encoded_inp

            if self.complete_rnn_input:
                rnn_input = tf.concat(
                    [encoded_inp, scale, shift, attr, z_pres], axis=1)
            else:
                rnn_input = encoded_inp

            hidden_rep, next_state = self.cell(rnn_input, prev_state)

            outputs = self.output_network(hidden_rep, 9, self.is_training)

            (scale_mean, scale_log_std, shift_mean, shift_log_std,
             z_pres_log_odds) = tf.split(outputs, [2, 2, 2, 2, 1], axis=1)

            # --- scale ---

            scale_std = tf.exp(scale_log_std)

            scale_mean = self.apply_fixed_value("scale_mean", scale_mean)
            scale_std = self.apply_fixed_value("scale_std", scale_std)

            scale_logits, scale_kl = normal_vae(scale_mean, scale_std,
                                                self.scale_prior_mean,
                                                self.scale_prior_std)
            scale_kl = tf.reduce_sum(scale_kl, axis=1, keepdims=True)
            scale = tf.nn.sigmoid(tf.clip_by_value(scale_logits, -10, 10))

            # --- shift ---

            shift_std = tf.exp(shift_log_std)

            shift_mean = self.apply_fixed_value("shift_mean", shift_mean)
            shift_std = self.apply_fixed_value("shift_std", shift_std)

            shift_logits, shift_kl = normal_vae(shift_mean, shift_std,
                                                self.shift_prior_mean,
                                                self.shift_prior_std)
            shift_kl = tf.reduce_sum(shift_kl, axis=1, keepdims=True)
            shift = tf.nn.tanh(tf.clip_by_value(shift_logits, -10, 10))

            # --- Extract windows from scene ---

            w, h = scale[:, 0:1], scale[:, 1:2]
            x, y = shift[:, 0:1], shift[:, 1:2]

            theta = tf.concat(
                [w, tf.zeros_like(w), x,
                 tf.zeros_like(h), h, y], axis=1)
            theta = tf.reshape(theta, (-1, 2, 3))

            vae_input = transformer(self._tensors["inp"], theta,
                                    self.object_shape)

            # This is a necessary reshape, as the output of transformer will have unknown dims
            vae_input = tf.reshape(
                vae_input,
                (self.batch_size, *self.object_shape, self.image_depth))

            # --- Apply Object-level VAE (object encoder/object decoder) to windows ---

            attr = self.object_encoder(vae_input, 2 * self.A, self.is_training)
            attr_mean, attr_log_std = tf.split(attr, 2, axis=1)
            attr_std = tf.exp(attr_log_std)
            attr, attr_kl = normal_vae(attr_mean, attr_std,
                                       self.attr_prior_mean,
                                       self.attr_prior_std)
            attr_kl = tf.reduce_sum(attr_kl, axis=1, keepdims=True)

            vae_output = self.object_decoder(
                attr,
                self.object_shape[0] * self.object_shape[1] * self.image_depth,
                self.is_training)
            vae_output = tf.nn.sigmoid(tf.clip_by_value(vae_output, -10, 10))

            # --- Place reconstructed objects in image ---

            theta_inverse = tf.concat([
                1. / w,
                tf.zeros_like(w), -x / w,
                tf.zeros_like(h), 1. / h, -y / h
            ],
                                      axis=1)
            theta_inverse = tf.reshape(theta_inverse, (-1, 2, 3))

            vae_output_transformed = transformer(
                tf.reshape(vae_output, (
                    self.batch_size,
                    *self.object_shape,
                    self.image_depth,
                )), theta_inverse, self.obs_shape[:2])
            vae_output_transformed = tf.reshape(vae_output_transformed, [
                self.batch_size,
                self.image_height * self.image_width * self.image_depth
            ])

            # --- z_pres ---

            if self.run_all_time_steps:
                z_pres = tf.ones_like(z_pres_log_odds)
                z_pres_prob = tf.ones_like(z_pres_log_odds)
                z_pres_kl = tf.zeros_like(z_pres_log_odds)
            else:
                z_pres_log_odds = tf.clip_by_value(z_pres_log_odds, -10, 10)

                z_pres_pre_sigmoid = concrete_binary_pre_sigmoid_sample(
                    z_pres_log_odds, self.z_pres_temperature)
                z_pres = tf.nn.sigmoid(z_pres_pre_sigmoid)
                z_pres = (self.float_is_training * z_pres +
                          (1 - self.float_is_training) * tf.round(z_pres))
                z_pres_prob = tf.nn.sigmoid(z_pres_log_odds)
                z_pres_kl = concrete_binary_sample_kl(
                    z_pres_pre_sigmoid,
                    z_pres_log_odds,
                    self.z_pres_temperature,
                    self.z_pres_prior_log_odds,
                    self.z_pres_temperature,
                )

            stopping_sum += (1.0 - z_pres)
            alive = tf.less(stopping_sum, self.stopping_threshold)
            running_digits += tf.to_int32(alive)

            # --- adjust reconstruction ---

            running_recon += tf.where(
                tf.tile(alive, (1, vae_output_transformed.shape[1])),
                z_pres * vae_output_transformed, tf.zeros_like(running_recon))

            # --- add kl to loss ---

            kl_loss += tf.where(alive, scale_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, shift_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, attr_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, z_pres_kl, tf.zeros_like(kl_loss))

            # --- record values ---

            scale_ta = scale_ta.write(scale_ta.size(), scale)
            scale_kl_ta = scale_kl_ta.write(scale_kl_ta.size(), scale_kl)
            scale_std_ta = scale_std_ta.write(scale_std_ta.size(), scale_std)

            shift_ta = shift_ta.write(shift_ta.size(), shift)
            shift_kl_ta = shift_kl_ta.write(shift_kl_ta.size(), shift_kl)
            shift_std_ta = shift_std_ta.write(shift_std_ta.size(), shift_std)

            attr_ta = attr_ta.write(attr_ta.size(), attr)
            attr_kl_ta = attr_kl_ta.write(attr_kl_ta.size(), attr_kl)
            attr_std_ta = attr_std_ta.write(attr_std_ta.size(), attr_std)

            vae_input_ta = vae_input_ta.write(vae_input_ta.size(),
                                              tf.layers.flatten(vae_input))
            vae_output_ta = vae_output_ta.write(vae_output_ta.size(),
                                                vae_output)

            z_pres_ta = z_pres_ta.write(z_pres_ta.size(), z_pres)
            z_pres_probs_ta = z_pres_probs_ta.write(z_pres_probs_ta.size(),
                                                    z_pres_prob)
            z_pres_kl_ta = z_pres_kl_ta.write(z_pres_kl_ta.size(), z_pres_kl)

            return (
                step + 1,
                stopping_sum,
                next_state,
                running_recon,
                kl_loss,
                running_digits,
                scale_ta,
                scale_kl_ta,
                scale_std_ta,
                shift_ta,
                shift_kl_ta,
                shift_std_ta,
                attr_ta,
                attr_kl_ta,
                attr_std_ta,
                z_pres_ta,
                z_pres_probs_ta,
                z_pres_kl_ta,
                vae_input_ta,
                vae_output_ta,
                scale,
                shift,
                attr,
                z_pres,
            )

        # --- end of while-loop body ---

        rnn_init_state = self.cell.zero_state(self.batch_size, tf.float32)

        (_, _, _, reconstruction, kl_loss, self.predicted_n_digits, scale,
         scale_kl, scale_std, shift, shift_kl, shift_std, attr, attr_kl,
         attr_std, z_pres, z_pres_probs, z_pres_kl, vae_input, vae_output, _,
         _, _, _) = tf.while_loop(
             cond,
             body,
             [
                 tf.constant(0),  # RNN time step, initially zero
                 tf.zeros(
                     (self.batch_size, 1)),  # running sum of z_pres samples
                 rnn_init_state,  # initial RNN state
                 tf.zeros((self.batch_size, np.product(self.obs_shape)
                           )),  # reconstruction canvas, initially empty
                 tf.zeros((self.batch_size,
                           1)),  # running value of the loss function
                 tf.zeros((self.batch_size, 1),
                          dtype=tf.int32),  # running inferred number of digits
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.zeros((self.batch_size, 2)),  # scale
                 tf.zeros((self.batch_size, 2)),  # shift
                 tf.zeros((self.batch_size, self.A)),  # attr
                 tf.zeros((self.batch_size, 1)),  # z_pres
             ])

        def process_tensor_array(tensor_array, name, shape=None):
            """Stack a TensorArray, pad its second axis, and optionally reshape.

            The stacked array has its first two axes swapped, and the axis
            that ends up second is zero-padded at the end until it has
            length ``self.max_time_steps``.

            Args:
                tensor_array: TensorArray whose stacked value is rank 3 (the
                    transpose permutation ``(1, 0, 2)`` requires it).
                name: Name given to the padding op.
                shape: Optional target shape; when given, the padded tensor
                    is reshaped to it.

            Returns:
                The padded (and possibly reshaped) tensor.
            """
            swapped = tf.transpose(tensor_array.stack(), (1, 0, 2))

            # Number of trailing steps that have to be filled with zeros.
            n_missing_steps = self.max_time_steps - tf.shape(swapped)[1]
            n_trailing_axes = len(swapped.shape) - 2
            # Only axis 1 is padded; every other axis gets [0, 0].
            paddings = [[0, 0], [0, n_missing_steps], *([[0, 0]] * n_trailing_axes)]

            padded = tf.pad(swapped, paddings, name=name)

            return padded if shape is None else tf.reshape(padded, shape)

        self.predicted_n_digits = self.predicted_n_digits[:, 0]
        self._tensors["predicted_n_digits"] = self.predicted_n_digits

        self._tensors['scale'] = process_tensor_array(scale, 'scale')
        self._tensors['scale_kl'] = process_tensor_array(scale_kl, 'scale_kl')
        self._tensors['scale_std'] = process_tensor_array(
            scale_std, 'scale_std')

        self._tensors['shift'] = process_tensor_array(shift, 'shift')
        self._tensors['shift_kl'] = process_tensor_array(shift_kl, 'shift_kl')
        self._tensors['shift_std'] = process_tensor_array(
            shift_std, 'shift_std')

        self._tensors['attr'] = process_tensor_array(
            attr, 'attr', (self.batch_size, self.max_time_steps, self.A))
        self._tensors['attr_kl'] = process_tensor_array(attr_kl, 'attr_kl')
        self._tensors['attr_std'] = process_tensor_array(attr_std, 'attr_std')

        self._tensors['z_pres'] = process_tensor_array(
            z_pres, 'z_pres', (self.batch_size, self.max_time_steps, 1))
        self._tensors['obj'] = tf.round(
            self._tensors['z_pres'])  # for `build_math_representation`
        self._tensors['z_pres_probs'] = process_tensor_array(
            z_pres_probs, 'z_pres_probs')
        self._tensors['z_pres_kl'] = process_tensor_array(
            z_pres_kl, 'z_pres_kl')

        self._tensors['vae_input'] = process_tensor_array(
            vae_input, 'vae_input')
        self._tensors['vae_output'] = process_tensor_array(
            vae_output, 'vae_output')

        reconstruction = tf.clip_by_value(reconstruction, 0.0, 1.0)

        flat_inp = tf.layers.flatten(self._tensors["inp"])

        self._tensors['per_pixel_reconstruction_loss'] = xent_loss(
            pred=reconstruction, label=flat_inp)
        self.losses.update(
            reconstruction=tf_mean_sum(
                self._tensors['per_pixel_reconstruction_loss']),
            running=self.kl_weight * tf.reduce_mean(kl_loss),
        )

        self._tensors['output'] = tf.reshape(
            reconstruction, (self.batch_size, ) + self.obs_shape)

        count_error = 1 - tf.to_float(
            tf.equal(self.target_n_digits, self.predicted_n_digits))
        count_1norm = tf.abs(self.target_n_digits - self.predicted_n_digits)

        self.record_tensors(
            predicted_n_digits=self.predicted_n_digits,
            count_error=count_error,
            count_1norm=count_1norm,
            scale=self._tensors["scale"],
            x=self._tensors["shift"][:, :, 0],
            y=self._tensors["shift"][:, :, 1],
            z_pres_prob=self._tensors["z_pres_probs"],
            z_pres_kl=self._tensors["z_pres_kl"],
            scale_kl=self._tensors["scale_kl"],
            shift_kl=self._tensors["shift_kl"],
            attr_kl=self._tensors["attr_kl"],
            scale_std=self._tensors["scale_std"],
            shift_std=self._tensors["shift_std"],
            attr_std=self._tensors["attr_std"],
        )
Ejemplo n.º 9
0
    def build_representation(self):
        """Build the encode -> (optional recurrent cell) -> sample -> decode graph.

        Steps, in order:

        1. Create the encoder, cell and decoder if they are still ``None``,
           and freeze the variables of any module named in
           ``self.fixed_weights``.
        2. Merge the first two axes of ``self.inp`` (sized
           ``batch_size * dynamic_n_frames``), run the encoder on the
           result, and split the leading axis back into two.
        3. If a cell is present, run the encoder output through
           ``dynamic_rnn``; otherwise use the encoder output directly.
        4. Split the result in two halves along the last axis, giving the
           mean and (after a softplus) the standard deviation passed to
           ``normal_vae``. The standard deviation is zeroed when
           ``self.noisy`` is false.
        5. Decode, crop to ``self.obs_shape``, and squash through a sigmoid.
        6. Register the KL and reconstruction losses, each gated by its
           ``train_*`` flag.

        Results are stored in ``self._tensors`` under ``attr_mean``,
        ``attr_std``, ``attr_kl``, ``attr``, ``output`` and (when training
        reconstruction) ``per_pixel_reconstruction_loss``; losses go into
        ``self.losses``.

        Raises:
            NotImplementedError: If a cell is in use and ``self.flat_latent``
                is false; only the flat-latent path is implemented.
        """
        # --- init modules ---

        # NOTE(review): the encoder builder is looked up on `self`, while the
        # cell and decoder builders are looked up on `cfg` -- confirm this
        # asymmetry is intentional.
        if self.encoder is None:
            self.encoder = self.build_encoder(scope="encoder")
            if "encoder" in self.fixed_weights:
                self.encoder.fix_variables()

        # NOTE(review): the guard tests `self.build_cell` but the call goes to
        # `cfg.build_cell` -- presumably they refer to the same builder; verify.
        if self.cell is None and self.build_cell is not None:
            self.cell = cfg.build_cell(scope="cell")
            if "cell" in self.fixed_weights:
                self.cell.fix_variables()

        if self.decoder is None:
            self.decoder = cfg.build_decoder(scope="decoder")
            if "decoder" in self.fixed_weights:
                self.decoder.fix_variables()

        # --- encode ---

        # Fold the first two axes together so the encoder sees one big batch.
        inp_trailing_shape = tf_shape(self.inp)[2:]
        video = tf.reshape(self.inp, (self.batch_size * self.dynamic_n_frames, *inp_trailing_shape))
        # 2 * A outputs: one half becomes the mean, the other the std (see split below).
        encoder_output = self.encoder(video, 2 * self.A, self.is_training)

        # Undo the fold: back to separate leading axes.
        eo_trailing_shape = tf_shape(encoder_output)[1:]
        encoder_output = tf.reshape(
            encoder_output, (self.batch_size, self.dynamic_n_frames, *eo_trailing_shape))

        if self.cell is None:
            attr = encoder_output
        else:
            if not self.flat_latent:
                # The structured (per-object) latent path was never implemented.
                raise NotImplementedError(
                    "A recurrent cell is only supported with flat_latent=True.")

            # Collapse all trailing encoder axes into one feature axis.
            n_trailing_dims = int(np.prod(eo_trailing_shape))
            encoder_output = tf.reshape(
                encoder_output, (self.batch_size, self.dynamic_n_frames, n_trailing_dims))

            # NOTE(review): tf.layers.flatten collapses every non-leading axis,
            # which appears to merge the frame axis into the feature axis
            # before dynamic_rnn -- confirm this is what the cell expects.
            encoder_output = tf.layers.flatten(encoder_output)

            attr, final_state = dynamic_rnn(
                self.cell, encoder_output, initial_state=self.cell.zero_state(self.batch_size, tf.float32),
                parallel_iterations=1, swap_memory=False, time_major=False)

        # Despite its name, `attr_log_std` is mapped through softplus, not exp.
        attr_mean, attr_log_std = tf.split(attr, 2, axis=-1)
        attr_std = tf.math.softplus(attr_log_std)

        if not self.noisy:
            # Zero std is passed to normal_vae when noise is disabled.
            attr_std = tf.zeros_like(attr_std)

        attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)

        self._tensors.update(attr_mean=attr_mean, attr_std=attr_std, attr_kl=attr_kl, attr=attr)

        # --- decode ---

        # Fold the first two axes together again for the decoder.
        decoder_input = tf.reshape(attr, (self.batch_size*self.dynamic_n_frames, *tf_shape(attr)[2:]))

        reconstruction = self.decoder(decoder_input, tf_shape(self.inp)[2:], self.is_training)
        # Crop axes 1 and 2 of the decoder output to obs_shape[1] and obs_shape[2].
        reconstruction = reconstruction[:, :self.obs_shape[1], :self.obs_shape[2], :]
        reconstruction = tf.reshape(reconstruction, (self.batch_size, self.dynamic_n_frames, *self.obs_shape[1:]))

        # Clip the decoder output to [-10, 10] before applying the sigmoid.
        reconstruction = tf.nn.sigmoid(tf.clip_by_value(reconstruction, -10, 10))
        self._tensors["output"] = reconstruction

        # --- losses ---

        if self.train_kl:
            self.losses['attr_kl'] = tf_mean_sum(self._tensors["attr_kl"])

        if self.train_reconstruction:
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=reconstruction, label=self.inp)
            self.losses['reconstruction'] = tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])