Example #1
    def __call__(self, updater):
        self.fetches = "inp output"

        fetched = self._fetch(updater)
        fetched = Config(fetched)

        inp = fetched['inp']
        output = fetched['output']
        T = inp.shape[1]
        mean_image = np.tile(inp.mean(axis=1, keepdims=True), (1, T, 1, 1, 1))

        B = inp.shape[0]

        fig_unit_size = 3

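        # One unit-size row per batch element; the width fits the 7 panels
        # passed to animate() below (input, output, abs diff, xent, and the
        # mean-image baselines).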
        fig_height = B * fig_unit_size
        fig_width = 7 * fig_unit_size

        diff = self.normalize_images(
            np.abs(inp - output).sum(axis=-1, keepdims=True))
        xent = self.normalize_images(
            xent_loss(pred=output, label=inp, tf=False).sum(axis=-1,
                                                            keepdims=True))

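        # Analogous error maps using the per-sequence temporal mean image as
        # a baseline.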
        diff_mean = self.normalize_images(
            np.abs(mean_image - output).sum(axis=-1, keepdims=True))
        xent_mean = self.normalize_images(
            xent_loss(pred=mean_image, label=inp, tf=False).sum(axis=-1,
                                                                keepdims=True))

        path = self.path_for("animation", updater, ext=None)

        fig, axes, anim, path = animate(inp,
                                        output,
                                        diff.astype('f'),
                                        xent.astype('f'),
                                        mean_image,
                                        diff_mean.astype('f'),
                                        xent_mean.astype('f'),
                                        figsize=(fig_width, fig_height),
                                        path=path,
                                        square_grid=False)
        plt.close()
Example #2
    def _prepare_fetched(self, fetched):
        inp = fetched['inp']
        output = fetched['output']
        prediction = fetched.get("prediction", None)
        targets = fetched.get("targets", None)

        N, T, image_height, image_width, _ = inp.shape

        flat_obj = fetched['obj'].reshape(N, T, -1)
        background = fetched['background']

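        # 'normalized_box' coordinates lie in [0, 1]; scale them to pixels
        # (ordering assumed to be (y, x, h, w), matching the height/width factors).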
        box = (
            fetched['normalized_box']
            * [image_height, image_width, image_height, image_width]
        )
        flat_box = box.reshape(N, T, -1, 4)

        n_annotations = fetched.get("n_annotations", np.zeros(N, dtype='i'))
        annotations = fetched.get("annotations", None)
        # actions = fetched.get("actions", None)

        diff = self.normalize_images(np.abs(inp - output).mean(axis=-1, keepdims=True))
        xent = self.normalize_images(
            xent_loss(pred=output, label=inp, tf=False).mean(axis=-1, keepdims=True))

        learned_bg = "bg_y" in fetched
        bg_y = fetched.get("bg_y", None)
        bg_x = fetched.get("bg_x", None)
        bg_h = fetched.get("bg_h", None)
        bg_w = fetched.get("bg_w", None)
        bg_raw = fetched.get("bg_raw", None)

        fetched.update(
            prediction=prediction,
            targets=targets,
            flat_obj=flat_obj,
            background=background,
            box=box,
            flat_box=flat_box,
            n_annotations=n_annotations,
            annotations=annotations,
            diff=diff,
            xent=xent,
            learned_bg=learned_bg,
            bg_y=bg_y,
            bg_x=bg_x,
            bg_h=bg_h,
            bg_w=bg_w,
            bg_raw=bg_raw,
        )
Example #3
    def build_representation(self):

        # --- build graph ---

        self._build_program_generator()

        if self.object_encoder is None:
            self.object_encoder = cfg.build_object_encoder(
                scope="object_encoder")
            if "object_encoder" in self.fixed_weights:
                self.object_encoder.fix_variables()

        if self.object_decoder is None:
            self.object_decoder = cfg.build_object_decoder(
                scope="object_decoder")
            if "object_decoder" in self.fixed_weights:
                self.object_decoder.fix_variables()

        self._build_program_interpreter()

        # --- specify values to record ---

        self.record_tensors(n_objects=self._tensors["n_objects"],
                            attr=self._tensors["attr"])

        # --- losses ---

        if self.train_reconstruction:
            output = self._tensors['output']
            inp = self._tensors['inp']
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(
                pred=output, label=inp)
            self.losses['reconstruction'] = (
                self.reconstruction_weight *
                tf_mean_sum(self._tensors['per_pixel_reconstruction_loss']))

        if self.train_kl:
            obj = self._tensors["obj"]
            self.losses['attr_kl'] = self.kl_weight * tf_mean_sum(
                obj * self._tensors["attr_kl"])

        # --- other evaluation metrics ---

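        # Count metrics: count_1norm is the L1 distance between predicted and
        # annotated object counts; count_error flags any nonzero difference.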
        if "n_annotations" in self._tensors:
            count_1norm = tf.to_float(
                tf.abs(
                    tf.to_int32(self._tensors["n_objects"]) -
                    self._tensors["n_annotations"]))
            self.record_tensors(count_1norm=count_1norm,
                                count_error=count_1norm > 0.5)
Example #4
    def _plot_reconstruction(self, updater, fetched):
        inp = fetched['inp']
        output = fetched['output']

        fig_height = 20
        fig_width = 4.5 * fig_height

        diff = self.normalize_images(np.abs(inp - output).sum(axis=-1, keepdims=True) / output.shape[-1])
        xent = self.normalize_images(xent_loss(pred=output, label=inp, tf=False).sum(axis=-1, keepdims=True))

        path = self.path_for("animation", updater, ext=None)
        fig, axes, anim, path = animate(
            inp, output, diff.astype('f'), xent.astype('f'),
            figsize=(fig_width, fig_height), path=path)
        plt.close()
Example #5
    def build_representation(self):
        # --- init modules ---

        if self.encoder is None:
            self.encoder = cfg.build_encoder(scope="encoder")
            if "encoder" in self.fixed_weights:
                self.encoder.fix_variables()

        if self.decoder is None:
            self.decoder = cfg.build_decoder(scope="decoder")
            if "decoder" in self.fixed_weights:
                self.decoder.fix_variables()

        # --- encode ---

        attr = self.encoder(self.inp, 2 * self.A, self.is_training)
        attr_mean, attr_log_std = tf.split(attr, 2, axis=-1)
        attr_std = tf.exp(attr_log_std)

        if not self.noisy:
            attr_std = tf.zeros_like(attr_std)

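        # normal_vae presumably draws a reparameterized sample and returns it
        # along with the per-dimension KL to the N(prior_mean, prior_std) prior.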
        attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)

        obj_shape = tf.concat([tf.shape(attr)[:-1], [1]], axis=0)
        self._tensors["obj"] = tf.ones(obj_shape)

        self._tensors.update(attr_mean=attr_mean, attr_std=attr_std, attr_kl=attr_kl, attr=attr)

        # --- decode ---

        reconstruction = self.decoder(attr, 3, self.is_training)
        reconstruction = reconstruction[:, :self.inp.shape[1], :self.inp.shape[2], :]

        reconstruction = tf.nn.sigmoid(tf.clip_by_value(reconstruction, -10, 10))
        self._tensors["output"] = reconstruction

        # --- losses ---

        if self.train_kl:
            self.losses['attr_kl'] = tf_mean_sum(self._tensors["attr_kl"])

        if self.train_reconstruction:
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=reconstruction, label=self.inp)
            self.losses['reconstruction'] = tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])
Example #6
    def build_representation(self):
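        # Background-only model: the reconstruction is just the inferred background.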
        self._tensors["output"] = reconstruction = self._tensors["background"]

        # --- losses ---

        if self.train_reconstruction:
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(
                pred=reconstruction, label=self.inp)
            self.losses['reconstruction'] = tf_mean_sum(
                self._tensors['per_pixel_reconstruction_loss'])

        if "bg_attr_kl" in self._tensors:
            self.losses.update(
                bg_attr_kl=self.kl_weight *
                tf_mean_sum(self._tensors["bg_attr_kl"]),
                bg_transform_kl=self.kl_weight *
                tf_mean_sum(self._tensors["bg_transform_kl"]),
            )
Example #7
    def build_representation(self):

        self.maybe_build_subnet("object_encoder")
        self.maybe_build_subnet("object_decoder")

        program_tensors = self._build_program_generator(self._tensors)
        self._tensors.update(program_tensors)

        interpreter_tensors = self._build_program_interpreter(self._tensors)
        self._tensors.update(interpreter_tensors)

        # --- specify values to record ---

        self.record_tensors(n_objects=self._tensors["n_objects"],
                            attr=self._tensors["attr"])

        # --- losses ---

        if self.train_reconstruction:
            output = self._tensors['output']
            inp = self._tensors['inp']
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(
                pred=output, label=inp)
            self.losses['reconstruction'] = (
                self.reconstruction_weight *
                tf_mean_sum(self._tensors['per_pixel_reconstruction_loss']))

        if self.train_kl:
            obj = self._tensors["obj"]
            self.losses['attr_kl'] = self.kl_weight * tf_mean_sum(
                obj * self._tensors["attr_kl"])

        # --- other evaluation metrics ---

        if "n_annotations" in self._tensors:
            count_1norm = tf.to_float(
                tf.abs(
                    tf.to_int32(self._tensors["n_objects"]) -
                    self._tensors["n_annotations"]))
            self.record_tensors(count_1norm=count_1norm,
                                count_error=count_1norm > 0.5)
Example #8
    def build_representation(self):
        # --- init modules ---
        self.B = len(self.anchor_boxes)

        if self.backbone is None:
            self.backbone = self.build_backbone(scope="backbone")
            if "backbone" in self.fixed_weights:
                self.backbone.fix_variables()

        if self.feature_fuser is None:
            self.feature_fuser = self.build_feature_fuser(scope="feature_fuser")
            if "feature_fuser" in self.fixed_weights:
                self.feature_fuser.fix_variables()

        if self.obj_feature_extractor is None and self.build_obj_feature_extractor is not None:
            self.obj_feature_extractor = self.build_obj_feature_extractor(scope="obj_feature_extractor")
            if "obj_feature_extractor" in self.fixed_weights:
                self.obj_feature_extractor.fix_variables()

        backbone_output, n_grid_cells, grid_cell_size = self.backbone(
            self.inp, self.B*self.n_backbone_features, self.is_training)

        self.H, self.W = [int(i) for i in n_grid_cells]
        self.HWB = self.H * self.W * self.B
        self.pixels_per_cell = tuple(int(i) for i in grid_cell_size)
        H, W, B = self.H, self.W, self.B

        if self.object_layer is None:
            self.object_layer = ObjectLayer(self.pixels_per_cell, scope="objects")

        self.object_rep_tensors = []
        object_rep_tensors = None
        _tensors = defaultdict(list)

        for f in range(self.n_frames):
            print("Bulding network for frame {}".format(f))
            early_frame_features = backbone_output[:, f]

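            # From the second frame on, feed the previous frame's object
            # representations back into the backbone features (a simple
            # recurrent link across frames).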
            if f > 0 and self.obj_feature_extractor is not None:
                object_features = object_rep_tensors["all"]
                object_features = tf.reshape(
                    object_features, (self.batch_size, H, W, B*tf_shape(object_features)[-1]))
                early_frame_features += self.obj_feature_extractor(
                    object_features, B*self.n_backbone_features, self.is_training)

            frame_features = self.feature_fuser(
                early_frame_features, B*self.n_backbone_features, self.is_training)

            frame_features = tf.reshape(
                frame_features, (self.batch_size, H, W, B, self.n_backbone_features))

            object_rep_tensors = self.object_layer(
                self.inp[:, f], frame_features, self._tensors["background"][:, f], self.is_training)

            self.object_rep_tensors.append(object_rep_tensors)

            for k, v in object_rep_tensors.items():
                _tensors[k].append(v)

        self._tensors.update(**{k: tf.stack(v, axis=1) for k, v in _tensors.items()})

        # --- specify values to record ---

        obj = self._tensors["obj"]
        pred_n_objects = self._tensors["pred_n_objects"]

        self.record_tensors(
            batch_size=self.batch_size,
            float_is_training=self.float_is_training,

            cell_y=self._tensors["cell_y"],
            cell_x=self._tensors["cell_x"],
            h=self._tensors["h"],
            w=self._tensors["w"],
            z=self._tensors["z"],
            area=self._tensors["area"],

            cell_y_std=self._tensors["cell_y_std"],
            cell_x_std=self._tensors["cell_x_std"],
            h_std=self._tensors["h_std"],
            w_std=self._tensors["w_std"],
            z_std=self._tensors["z_std"],

            n_objects=pred_n_objects,
            obj=obj,

            latent_area=self._tensors["latent_area"],
            latent_hw=self._tensors["latent_hw"],

            attr=self._tensors["attr"],
        )

        # --- losses ---

        if self.train_reconstruction:
            output = self._tensors['output']
            inp = self._tensors['inp']
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=output, label=inp)
            self.losses['reconstruction'] = (
                self.reconstruction_weight * tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])
            )

        if self.train_kl:
            self.losses.update(
                obj_kl=self.kl_weight * tf_mean_sum(self._tensors["obj_kl"]),
                cell_y_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["cell_y_kl"]),
                cell_x_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["cell_x_kl"]),
                h_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["h_kl"]),
                w_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["w_kl"]),
                z_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["z_kl"]),
                attr_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["attr_kl"]),
            )

            if cfg.background_cfg.mode == "learn_and_transform":
                self.losses.update(
                    bg_attr_kl=self.kl_weight * tf_mean_sum(self._tensors["bg_attr_kl"]),
                    bg_transform_kl=self.kl_weight * tf_mean_sum(self._tensors["bg_transform_kl"]),
                )

        # --- other evaluation metrics ---

        if "n_annotations" in self._tensors:
            count_1norm = tf.to_float(
                tf.abs(tf.to_int32(self._tensors["pred_n_objects_hard"]) - self._tensors["n_valid_annotations"]))

            self.record_tensors(
                count_1norm=count_1norm,
                count_error=count_1norm > 0.5,
            )
Example #9
    def build_representation(self):
        # --- build graph ---

        self.maybe_build_subnet("backbone")
        assert isinstance(self.backbone, GridConvNet)

        inp = self._tensors["inp"]
        backbone_output = self.backbone(inp, self.n_backbone_features,
                                        self.is_training)
        n_grid_cells = self.backbone.layer_info[-1]['n_grid_cells']
        grid_cell_size = self.backbone.layer_info[-1]['grid_cell_size']

        self.H, self.W = [int(i) for i in n_grid_cells]
        self.HWB = self.H * self.W * self.B
        self.pixels_per_cell = tuple(int(i) for i in grid_cell_size)

        if self.object_layer is None:
            if self.conv_object_layer:
                self.object_layer = ConvGridObjectLayer(
                    pixels_per_cell=self.pixels_per_cell, scope="objects")
            else:
                self.object_layer = GridObjectLayer(
                    pixels_per_cell=self.pixels_per_cell, scope="objects")

        if self.object_renderer is None:
            self.object_renderer = ObjectRenderer(self.anchor_box,
                                                  self.object_shape,
                                                  scope="renderer")

        objects = self.object_layer(self.inp, backbone_output,
                                    self.is_training)
        self._tensors.update(objects)

        kl_tensors = self.object_layer.compute_kl(objects)
        self._tensors.update(kl_tensors)

        if self.obj_kl is None:
            self.obj_kl = self.build_obj_kl()

        self._tensors['obj_kl'] = self.obj_kl(self._tensors)

        render_tensors = self.object_renderer(objects,
                                              self._tensors["background"],
                                              self.is_training)
        self._tensors.update(render_tensors)

        # --- specify values to record ---

        obj = self._tensors["obj"]

        self.record_tensors(
            batch_size=self.batch_size,
            float_is_training=self.float_is_training,
            cell_y=self._tensors["cell_y"],
            cell_x=self._tensors["cell_x"],
            height=self._tensors["height"],
            width=self._tensors["width"],
            z=self._tensors["z"],
            cell_y_std=self._tensors["cell_y_logit_std"],
            cell_x_std=self._tensors["cell_x_logit_std"],
            height_std=self._tensors["height_logit_std"],
            width_std=self._tensors["width_logit_std"],
            z_std=self._tensors["z_logit_std"],
            obj=obj,
            attr=self._tensors["attr"],
            pred_n_objects=self._tensors["pred_n_objects"],
        )

        # --- losses ---

        if self.train_reconstruction:
            output = self._tensors['output']
            inp = self._tensors['inp']
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(
                pred=output, label=inp)
            self.losses['reconstruction'] = (
                self.reconstruction_weight *
                tf_mean_sum(self._tensors['per_pixel_reconstruction_loss']))

        if self.train_kl:
            self.losses.update(
                obj_kl=self.kl_weight * tf_mean_sum(self._tensors["obj_kl"]),
                cell_y_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["cell_y_kl"]),
                cell_x_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["cell_x_kl"]),
                height_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["height_kl"]),
                width_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["width_kl"]),
                z_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["z_kl"]),
                attr_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["attr_kl"]),
            )

        # --- other evaluation metrics ---

        if "n_annotations" in self._tensors:
            count_1norm = tf.to_float(
                tf.abs(
                    tf.to_int32(self._tensors["pred_n_objects_hard"]) -
                    self._tensors["n_valid_annotations"]))

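            # Relative count error: normalized by the number of valid
            # annotations (the epsilon guards against division by zero).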
            count_1norm_relative = (count_1norm / tf.maximum(
                tf.cast(self._tensors["n_valid_annotations"], tf.float32),
                1e-6))

            self.record_tensors(
                count_1norm_relative=count_1norm_relative,
                count_1norm=count_1norm,
                count_error=count_1norm > 0.5,
            )
Example #10
    def _plot_reconstruction(self, updater, fetched):
        inp = fetched['inp']
        output = fetched['output']

        nb = np.split(fetched['normalized_box'], 4, axis=-1)
        pixel_space_box = tba_coords_to_pixel_space(
            *nb, (updater.image_height, updater.image_width),
            updater.network.anchor_box,
            top_left=True)
        pixel_space_box = np.concatenate(pixel_space_box, axis=-1)

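        # pixel_space_box rows are (top, left, height, width), matching the
        # unpacking in the per-tracker loop below.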
        conf = fetched['conf']
        layer = fetched['layer']
        order = fetched['order']

        H, W = updater.network.H, updater.network.W
        attention_weights = fetched['attention_weights']
        attention_weights = attention_weights.reshape(
            *attention_weights.shape[:-1], H, W)
        memory_activation = fetched['memory_activation']
        memory_activation = memory_activation.reshape(
            *memory_activation.shape[:-1], H, W)

        mask = fetched['mask']
        appearance = fetched['appearance']

        annotations = fetched.get('annotations', None)
        n_annotations = fetched.get('n_annotations',
                                    np.zeros(inp.shape[0], dtype='i'))

        diff = self.normalize_images(
            np.abs(inp - output).sum(axis=-1, keepdims=True) /
            output.shape[-1])
        xent = self.normalize_images(
            xent_loss(pred=output, label=inp, tf=False).sum(axis=-1,
                                                            keepdims=True))

        B, T = inp.shape[:2]
        print("Plotting for {} data points...".format(B))
        n_base_images = 8
        n_images_per_obj = 4
        n_images = n_base_images + n_images_per_obj * updater.network.n_trackers

        fig_unit_size = 4
        fig_height = T * fig_unit_size
        fig_width = n_images * fig_unit_size

        colours = plt.get_cmap('Set3').colors[:updater.network.n_trackers]
        rgb_colours = np.array([to_rgb(c) for c in colours])

        for b in range(B):
            fig, axes = plt.subplots(T,
                                     n_images,
                                     figsize=(fig_width, fig_height))
            for ax in axes.flatten():
                ax.set_axis_off()

            for t in range(T):
                ax = axes[t, 0]
                self.imshow(ax, inp[b, t])
                if t == 0:
                    ax.set_title('input')

                ax = axes[t, 1]
                self.imshow(ax, output[b, t])
                if t == 0:
                    ax.set_title('reconstruction')

                ax = axes[t, 2]
                self.imshow(ax, diff[b, t])
                if t == 0:
                    ax.set_title('abs error')

                ax = axes[t, 3]
                self.imshow(ax, xent[b, t])
                if t == 0:
                    ax.set_title('xent')

                ax = axes[t, 4]
                array = np.concatenate(
                    [rgb_colours[:, None, :],
                     conf[b, t, :, :, None] * (1., 1., 1.),
                     layer[b, t, :, :, None] * (1., 1., 1.)],
                    axis=1)
                self.imshow(ax, array)
                if t == 0:
                    ax.set_title('conf & layer')

                ax = axes[t, 5]
                array = rgb_colours[order[b, t, :, 0]][None, :, :]
                self.imshow(ax, array)
                if t == 0:
                    ax.set_title('order')

                gt_ax = axes[t, 6]
                self.imshow(gt_ax, inp[b, t])
                if t == 0:
                    gt_ax.set_title('gt with boxes')

                rec_ax = axes[t, 7]
                self.imshow(rec_ax, output[b, t])
                if t == 0:
                    rec_ax.set_title('rec with boxes')

                # Plot true bounding boxes
                for k in range(n_annotations[b]):
                    valid, _, _, top, bottom, left, right = annotations[b, t, k]

                    if not valid:
                        continue

                    height = bottom - top
                    width = right - left

                    # make a striped rectangle by superimposing two rectangles with different linestyles

                    rect = patches.Rectangle((left, top),
                                             width,
                                             height,
                                             linewidth=self.linewidth,
                                             edgecolor=self.gt_color,
                                             facecolor='none',
                                             linestyle="-")
                    gt_ax.add_patch(rect)
                    rect = patches.Rectangle((left, top),
                                             width,
                                             height,
                                             linewidth=self.linewidth,
                                             edgecolor=self.gt_color2,
                                             facecolor='none',
                                             linestyle=":")
                    gt_ax.add_patch(rect)

                    rect = patches.Rectangle((left, top),
                                             width,
                                             height,
                                             linewidth=self.linewidth,
                                             edgecolor=self.gt_color,
                                             facecolor='none',
                                             linestyle="-")
                    rec_ax.add_patch(rect)
                    rect = patches.Rectangle((left, top),
                                             width,
                                             height,
                                             linewidth=self.linewidth,
                                             edgecolor=self.gt_color2,
                                             facecolor='none',
                                             linestyle=":")
                    rec_ax.add_patch(rect)

                for i in range(updater.network.n_trackers):
                    top, left, height, width = pixel_space_box[b, t, i]
                    rect = patches.Rectangle((left, top),
                                             width,
                                             height,
                                             linewidth=6,
                                             edgecolor=colours[i],
                                             facecolor='none')
                    gt_ax.add_patch(rect)
                    rect = patches.Rectangle((left, top),
                                             width,
                                             height,
                                             linewidth=6,
                                             edgecolor=colours[i],
                                             facecolor='none')
                    rec_ax.add_patch(rect)

                    ax_appearance = axes[t,
                                         n_base_images + n_images_per_obj * i]
                    self.imshow(ax_appearance, appearance[b, t, i])
                    rect = patches.Rectangle((-0.05, -0.05),
                                             1.1,
                                             1.1,
                                             clip_on=False,
                                             linewidth=20,
                                             transform=ax_appearance.transAxes,
                                             edgecolor=colours[i],
                                             facecolor='none')
                    ax_appearance.add_patch(rect)
                    if t == 0:
                        ax_appearance.set_title('appearance {}'.format(i))

                    ax_mask = axes[t, n_base_images + n_images_per_obj * i + 1]
                    self.imshow(ax_mask, mask[b, t, i])
                    rect = patches.Rectangle((-0.05, -0.05),
                                             1.1,
                                             1.1,
                                             clip_on=False,
                                             linewidth=20,
                                             transform=ax_mask.transAxes,
                                             edgecolor=colours[i],
                                             facecolor='none')
                    ax_mask.add_patch(rect)
                    if t == 0:
                        ax_mask.set_title('mask {}'.format(i))

                    ax_mem = axes[t, n_base_images + n_images_per_obj * i + 2]
                    self.imshow(ax_mem, memory_activation[b, t, i])
                    rect = patches.Rectangle((-0.05, -0.05),
                                             1.1,
                                             1.1,
                                             clip_on=False,
                                             linewidth=20,
                                             transform=ax_mem.transAxes,
                                             edgecolor=colours[i],
                                             facecolor='none')
                    ax_mem.add_patch(rect)
                    if t == 0:
                        ax_mem.set_title('memory_activation {}'.format(i))

                    ax_att = axes[t, n_base_images + n_images_per_obj * i + 3]
                    self.imshow(ax_att, attention_weights[b, t, i])
                    rect = patches.Rectangle((-0.05, -0.05),
                                             1.1,
                                             1.1,
                                             clip_on=False,
                                             linewidth=20,
                                             transform=ax_att.transAxes,
                                             edgecolor=colours[i],
                                             facecolor='none')
                    ax_att.add_patch(rect)
                    if t == 0:
                        ax_att.set_title('attention_weights {}'.format(i))

            plt.subplots_adjust(left=0.02,
                                right=.98,
                                top=.98,
                                bottom=0.02,
                                wspace=0.1,
                                hspace=0.1)
            self.savefig("tba/" + str(b), fig, updater)
Example #11
    def build_representation(self):
        # --- build graph ---

        self.maybe_build_subnet("backbone")
        assert isinstance(self.backbone, GridConvNet)

        inp = self._tensors["inp"]
        backbone_output, n_grid_cells, grid_cell_size = self.backbone(
            inp, self.B * self.n_backbone_features, self.is_training)

        self.H, self.W = [int(i) for i in n_grid_cells]
        self.HWB = self.H * self.W * self.B
        self.pixels_per_cell = tuple(int(i) for i in grid_cell_size)

        backbone_output = tf.reshape(
            backbone_output,
            (-1, self.H, self.W, self.B, self.n_backbone_features))

        if self.object_layer is None:
            self.object_layer = GridObjectLayer(self.pixels_per_cell,
                                                scope="objects")

        if self.object_renderer is None:
            self.object_renderer = ObjectRenderer(scope="renderer")

        objects = self.object_layer(self.inp, backbone_output,
                                    self.is_training)
        self._tensors.update(objects)

        kl_tensors = self.object_layer.compute_kl(objects)
        self._tensors.update(kl_tensors)

        render_tensors = self.object_renderer(objects,
                                              self._tensors["background"],
                                              self.is_training)
        self._tensors.update(render_tensors)

        # --- specify values to record ---

        obj = self._tensors["obj"]
        pred_n_objects = self._tensors["pred_n_objects"]

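        # The on_*_avg statistics average each per-object value over objects
        # that are actually "on": weight by obj and divide by the expected
        # object count.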
        self.record_tensors(
            batch_size=self.batch_size,
            float_is_training=self.float_is_training,
            cell_y=self._tensors["cell_y"],
            cell_x=self._tensors["cell_x"],
            height=self._tensors["height"],
            width=self._tensors["width"],
            z=self._tensors["z"],
            cell_y_std=self._tensors["cell_y_logit_dist"].scale,
            cell_x_std=self._tensors["cell_x_logit_dist"].scale,
            height_std=self._tensors["height_logit_dist"].scale,
            width_std=self._tensors["width_logit_dist"].scale,
            z_std=self._tensors["z_logit_dist"].scale,
            n_objects=pred_n_objects,
            obj=obj,
            on_cell_y_avg=tf.reduce_sum(self._tensors["cell_y"] * obj,
                                        axis=(1, 2, 3, 4)) / pred_n_objects,
            on_cell_x_avg=tf.reduce_sum(self._tensors["cell_x"] * obj,
                                        axis=(1, 2, 3, 4)) / pred_n_objects,
            on_height_avg=tf.reduce_sum(self._tensors["height"] * obj,
                                        axis=(1, 2, 3, 4)) / pred_n_objects,
            on_width_avg=tf.reduce_sum(self._tensors["width"] * obj,
                                       axis=(1, 2, 3, 4)) / pred_n_objects,
            on_z_avg=tf.reduce_sum(self._tensors["z"] * obj, axis=(1, 2, 3, 4))
            / pred_n_objects,
            attr=self._tensors["attr"],
        )

        # --- losses ---

        if self.train_reconstruction:
            output = self._tensors['output']
            inp = self._tensors['inp']
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(
                pred=output, label=inp)
            self.losses['reconstruction'] = (
                self.reconstruction_weight *
                tf_mean_sum(self._tensors['per_pixel_reconstruction_loss']))

        if self.train_kl:
            self.losses.update(
                obj_kl=self.kl_weight * tf_mean_sum(self._tensors["obj_kl"]),
                cell_y_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["cell_y_kl"]),
                cell_x_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["cell_x_kl"]),
                height_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["height_kl"]),
                width_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["width_kl"]),
                z_kl=self.kl_weight * tf_mean_sum(obj * self._tensors["z_kl"]),
                attr_kl=self.kl_weight *
                tf_mean_sum(obj * self._tensors["attr_kl"]),
            )

        # --- other evaluation metrics ---

        if "n_annotations" in self._tensors:
            count_1norm = tf.to_float(
                tf.abs(
                    tf.to_int32(self._tensors["pred_n_objects_hard"]) -
                    self._tensors["n_valid_annotations"]))

            self.record_tensors(
                count_1norm=count_1norm,
                count_error=count_1norm > 0.5,
            )
Example #12
    def build_representation(self):
        # --- process input ---

        if self.image_encoder is None:
            self.image_encoder = cfg.build_image_encoder(scope="image_encoder")
            if "image_encoder" in self.fixed_weights:
                self.image_encoder.fix_variables()

        if self.cell is None:
            self.cell = cfg.build_cell(scope="cell")
            if "cell" in self.fixed_weights:
                self.cell.fix_variables()

        if self.output_network is None:
            self.output_network = cfg.build_output_network(
                scope="output_network")
            if "output" in self.fixed_weights:
                self.output_network.fix_variables()

        if self.object_encoder is None:
            self.object_encoder = cfg.build_object_encoder(
                scope="object_encoder")
            if "object_encoder" in self.fixed_weights:
                self.object_encoder.fix_variables()

        if self.object_decoder is None:
            self.object_decoder = cfg.build_object_decoder(
                scope="object_decoder")
            if "object_decoder" in self.fixed_weights:
                self.object_decoder.fix_variables()

        self.target_n_digits = self._tensors["n_valid_annotations"]

        if not self.difference_air:
            encoded_inp = self.image_encoder(self._tensors["inp"], 0,
                                             self.is_training)
            self.encoded_inp = tf.layers.flatten(encoded_inp)

        # --- condition of while-loop ---

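        # Run while any batch element is still "alive": its accumulated
        # (1 - z_pres) mass is below stopping_threshold and fewer than
        # max_time_steps objects have been emitted.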
        def cond(step, stopping_sum, *_):
            return tf.logical_and(
                tf.less(step, self.max_time_steps),
                tf.reduce_any(tf.less(stopping_sum, self.stopping_threshold)))

        # --- body of while-loop ---

        def body(step, stopping_sum, prev_state, running_recon, kl_loss,
                 running_digits, scale_ta, scale_kl_ta, scale_std_ta, shift_ta,
                 shift_kl_ta, shift_std_ta, attr_ta, attr_kl_ta, attr_std_ta,
                 z_pres_ta, z_pres_probs_ta, z_pres_kl_ta, vae_input_ta,
                 vae_output_ta, scale, shift, attr, z_pres):

            if self.difference_air:
                inp = (self._tensors["inp"] -
                       tf.reshape(running_recon,
                                  (self.batch_size, *self.obs_shape)))
                encoded_inp = self.image_encoder(inp, 0, self.is_training)
                encoded_inp = tf.layers.flatten(encoded_inp)
            else:
                encoded_inp = self.encoded_inp

            if self.complete_rnn_input:
                rnn_input = tf.concat(
                    [encoded_inp, scale, shift, attr, z_pres], axis=1)
            else:
                rnn_input = encoded_inp

            hidden_rep, next_state = self.cell(rnn_input, prev_state)

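            # 9 outputs per example: scale mean (2), scale log-std (2),
            # shift mean (2), shift log-std (2), z_pres log-odds (1).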
            outputs = self.output_network(hidden_rep, 9, self.is_training)

            (scale_mean, scale_log_std, shift_mean, shift_log_std,
             z_pres_log_odds) = tf.split(outputs, [2, 2, 2, 2, 1], axis=1)

            # --- scale ---

            scale_std = tf.exp(scale_log_std)

            scale_mean = self.apply_fixed_value("scale_mean", scale_mean)
            scale_std = self.apply_fixed_value("scale_std", scale_std)

            scale_logits, scale_kl = normal_vae(scale_mean, scale_std,
                                                self.scale_prior_mean,
                                                self.scale_prior_std)
            scale_kl = tf.reduce_sum(scale_kl, axis=1, keepdims=True)
            scale = tf.nn.sigmoid(tf.clip_by_value(scale_logits, -10, 10))

            # --- shift ---

            shift_std = tf.exp(shift_log_std)

            shift_mean = self.apply_fixed_value("shift_mean", shift_mean)
            shift_std = self.apply_fixed_value("shift_std", shift_std)

            shift_logits, shift_kl = normal_vae(shift_mean, shift_std,
                                                self.shift_prior_mean,
                                                self.shift_prior_std)
            shift_kl = tf.reduce_sum(shift_kl, axis=1, keepdims=True)
            shift = tf.nn.tanh(tf.clip_by_value(shift_logits, -10, 10))

            # --- Extract windows from scene ---

            w, h = scale[:, 0:1], scale[:, 1:2]
            x, y = shift[:, 0:1], shift[:, 1:2]

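            # Affine parameters [[w, 0, x], [0, h, y]] for the spatial
            # transformer: crop a window at (x, y) with size (w, h).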
            theta = tf.concat(
                [w, tf.zeros_like(w), x,
                 tf.zeros_like(h), h, y], axis=1)
            theta = tf.reshape(theta, (-1, 2, 3))

            vae_input = transformer(self._tensors["inp"], theta,
                                    self.object_shape)

            # This reshape is necessary: the spatial transformer's output has unknown static shape
            vae_input = tf.reshape(
                vae_input,
                (self.batch_size, *self.object_shape, self.image_depth))

            # --- Apply Object-level VAE (object encoder/object decoder) to windows ---

            attr = self.object_encoder(vae_input, 2 * self.A, self.is_training)
            attr_mean, attr_log_std = tf.split(attr, 2, axis=1)
            attr_std = tf.exp(attr_log_std)
            attr, attr_kl = normal_vae(attr_mean, attr_std,
                                       self.attr_prior_mean,
                                       self.attr_prior_std)
            attr_kl = tf.reduce_sum(attr_kl, axis=1, keepdims=True)

            vae_output = self.object_decoder(
                attr,
                self.object_shape[0] * self.object_shape[1] * self.image_depth,
                self.is_training)
            vae_output = tf.nn.sigmoid(tf.clip_by_value(vae_output, -10, 10))

            # --- Place reconstructed objects in image ---

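            # Inverse affine transform [[1/w, 0, -x/w], [0, 1/h, -y/h]]: paste
            # the decoded window back at its location in the full image.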
            theta_inverse = tf.concat(
                [1. / w, tf.zeros_like(w), -x / w,
                 tf.zeros_like(h), 1. / h, -y / h],
                axis=1)
            theta_inverse = tf.reshape(theta_inverse, (-1, 2, 3))

            vae_output_transformed = transformer(
                tf.reshape(vae_output, (
                    self.batch_size,
                    *self.object_shape,
                    self.image_depth,
                )), theta_inverse, self.obs_shape[:2])
            vae_output_transformed = tf.reshape(vae_output_transformed, [
                self.batch_size,
                self.image_height * self.image_width * self.image_depth
            ])

            # --- z_pres ---

            if self.run_all_time_steps:
                z_pres = tf.ones_like(z_pres_log_odds)
                z_pres_prob = tf.ones_like(z_pres_log_odds)
                z_pres_kl = tf.zeros_like(z_pres_log_odds)
            else:
                z_pres_log_odds = tf.clip_by_value(z_pres_log_odds, -10, 10)

                z_pres_pre_sigmoid = concrete_binary_pre_sigmoid_sample(
                    z_pres_log_odds, self.z_pres_temperature)
                z_pres = tf.nn.sigmoid(z_pres_pre_sigmoid)
                z_pres = (self.float_is_training * z_pres +
                          (1 - self.float_is_training) * tf.round(z_pres))
                z_pres_prob = tf.nn.sigmoid(z_pres_log_odds)
                z_pres_kl = concrete_binary_sample_kl(
                    z_pres_pre_sigmoid,
                    z_pres_log_odds,
                    self.z_pres_temperature,
                    self.z_pres_prior_log_odds,
                    self.z_pres_temperature,
                )

            stopping_sum += (1.0 - z_pres)
            alive = tf.less(stopping_sum, self.stopping_threshold)
            running_digits += tf.to_int32(alive)

            # --- adjust reconstruction ---

            running_recon += tf.where(
                tf.tile(alive, (1, vae_output_transformed.shape[1])),
                z_pres * vae_output_transformed, tf.zeros_like(running_recon))

            # --- add kl to loss ---

            kl_loss += tf.where(alive, scale_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, shift_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, attr_kl, tf.zeros_like(kl_loss))
            kl_loss += tf.where(alive, z_pres_kl, tf.zeros_like(kl_loss))

            # --- record values ---

            scale_ta = scale_ta.write(scale_ta.size(), scale)
            scale_kl_ta = scale_kl_ta.write(scale_kl_ta.size(), scale_kl)
            scale_std_ta = scale_std_ta.write(scale_std_ta.size(), scale_std)

            shift_ta = shift_ta.write(shift_ta.size(), shift)
            shift_kl_ta = shift_kl_ta.write(shift_kl_ta.size(), shift_kl)
            shift_std_ta = shift_std_ta.write(shift_std_ta.size(), shift_std)

            attr_ta = attr_ta.write(attr_ta.size(), attr)
            attr_kl_ta = attr_kl_ta.write(attr_kl_ta.size(), attr_kl)
            attr_std_ta = attr_std_ta.write(attr_std_ta.size(), attr_std)

            vae_input_ta = vae_input_ta.write(vae_input_ta.size(),
                                              tf.layers.flatten(vae_input))
            vae_output_ta = vae_output_ta.write(vae_output_ta.size(),
                                                vae_output)

            z_pres_ta = z_pres_ta.write(z_pres_ta.size(), z_pres)
            z_pres_probs_ta = z_pres_probs_ta.write(z_pres_probs_ta.size(),
                                                    z_pres_prob)
            z_pres_kl_ta = z_pres_kl_ta.write(z_pres_kl_ta.size(), z_pres_kl)

            return (
                step + 1,
                stopping_sum,
                next_state,
                running_recon,
                kl_loss,
                running_digits,
                scale_ta,
                scale_kl_ta,
                scale_std_ta,
                shift_ta,
                shift_kl_ta,
                shift_std_ta,
                attr_ta,
                attr_kl_ta,
                attr_std_ta,
                z_pres_ta,
                z_pres_probs_ta,
                z_pres_kl_ta,
                vae_input_ta,
                vae_output_ta,
                scale,
                shift,
                attr,
                z_pres,
            )

        # --- end of while-loop body ---

        rnn_init_state = self.cell.zero_state(self.batch_size, tf.float32)

        (_, _, _, reconstruction, kl_loss, self.predicted_n_digits, scale,
         scale_kl, scale_std, shift, shift_kl, shift_std, attr, attr_kl,
         attr_std, z_pres, z_pres_probs, z_pres_kl, vae_input, vae_output, _,
         _, _, _) = tf.while_loop(
             cond,
             body,
             [
                 tf.constant(0),  # RNN time step, initially zero
                 tf.zeros(
                     (self.batch_size, 1)),  # running sum of z_pres samples
                 rnn_init_state,  # initial RNN state
                 tf.zeros((self.batch_size, np.prod(self.obs_shape))),  # reconstruction canvas, initially empty
                 tf.zeros((self.batch_size,
                           1)),  # running value of the loss function
                 tf.zeros((self.batch_size, 1),
                          dtype=tf.int32),  # running inferred number of digits
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True),
                 tf.zeros((self.batch_size, 2)),  # scale
                 tf.zeros((self.batch_size, 2)),  # shift
                 tf.zeros((self.batch_size, self.A)),  # attr
                 tf.zeros((self.batch_size, 1)),  # z_pres
             ])

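        # Stacked TensorArrays come out as (time, batch, ...); transpose to
        # (batch, time, ...) and zero-pad the time axis to max_time_steps.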
        def process_tensor_array(tensor_array, name, shape=None):
            tensor = tf.transpose(tensor_array.stack(), (1, 0, 2))

            time_pad = self.max_time_steps - tf.shape(tensor)[1]
            padding = [[0, 0], [0, time_pad]]
            padding += [[0, 0]] * (len(tensor.shape) - 2)

            tensor = tf.pad(tensor, padding, name=name)

            if shape is not None:
                tensor = tf.reshape(tensor, shape)

            return tensor

        self.predicted_n_digits = self.predicted_n_digits[:, 0]
        self._tensors["predicted_n_digits"] = self.predicted_n_digits

        self._tensors['scale'] = process_tensor_array(scale, 'scale')
        self._tensors['scale_kl'] = process_tensor_array(scale_kl, 'scale_kl')
        self._tensors['scale_std'] = process_tensor_array(
            scale_std, 'scale_std')

        self._tensors['shift'] = process_tensor_array(shift, 'shift')
        self._tensors['shift_kl'] = process_tensor_array(shift_kl, 'shift_kl')
        self._tensors['shift_std'] = process_tensor_array(
            shift_std, 'shift_std')

        self._tensors['attr'] = process_tensor_array(
            attr, 'attr', (self.batch_size, self.max_time_steps, self.A))
        self._tensors['attr_kl'] = process_tensor_array(attr_kl, 'attr_kl')
        self._tensors['attr_std'] = process_tensor_array(attr_std, 'attr_std')

        self._tensors['z_pres'] = process_tensor_array(
            z_pres, 'z_pres', (self.batch_size, self.max_time_steps, 1))
        self._tensors['obj'] = tf.round(
            self._tensors['z_pres'])  # for `build_math_representation`
        self._tensors['z_pres_probs'] = process_tensor_array(
            z_pres_probs, 'z_pres_probs')
        self._tensors['z_pres_kl'] = process_tensor_array(
            z_pres_kl, 'z_pres_kl')

        self._tensors['vae_input'] = process_tensor_array(
            vae_input, 'vae_input')
        self._tensors['vae_output'] = process_tensor_array(
            vae_output, 'vae_output')

        reconstruction = tf.clip_by_value(reconstruction, 0.0, 1.0)

        flat_inp = tf.layers.flatten(self._tensors["inp"])

        self._tensors['per_pixel_reconstruction_loss'] = xent_loss(
            pred=reconstruction, label=flat_inp)
        self.losses.update(
            reconstruction=tf_mean_sum(
                self._tensors['per_pixel_reconstruction_loss']),
            running=self.kl_weight * tf.reduce_mean(kl_loss),
        )

        self._tensors['output'] = tf.reshape(
            reconstruction, (self.batch_size, ) + self.obs_shape)

        count_error = 1 - tf.to_float(
            tf.equal(self.target_n_digits, self.predicted_n_digits))
        count_1norm = tf.abs(self.target_n_digits - self.predicted_n_digits)

        self.record_tensors(
            predicted_n_digits=self.predicted_n_digits,
            count_error=count_error,
            count_1norm=count_1norm,
            scale=self._tensors["scale"],
            x=self._tensors["shift"][:, :, 0],
            y=self._tensors["shift"][:, :, 1],
            z_pres_prob=self._tensors["z_pres_probs"],
            z_pres_kl=self._tensors["z_pres_kl"],
            scale_kl=self._tensors["scale_kl"],
            shift_kl=self._tensors["shift_kl"],
            attr_kl=self._tensors["attr_kl"],
            scale_std=self._tensors["scale_std"],
            shift_std=self._tensors["shift_std"],
            attr_std=self._tensors["attr_std"],
        )
Example #13
    def build_representation(self):
        # --- init modules ---

        if self.encoder is None:
            self.encoder = self.build_encoder(scope="encoder")
            if "encoder" in self.fixed_weights:
                self.encoder.fix_variables()

        if self.cell is None and self.build_cell is not None:
            self.cell = self.build_cell(scope="cell")
            if "cell" in self.fixed_weights:
                self.cell.fix_variables()

        if self.decoder is None:
            self.decoder = cfg.build_decoder(scope="decoder")
            if "decoder" in self.fixed_weights:
                self.decoder.fix_variables()

        # --- encode ---

        inp_trailing_shape = tf_shape(self.inp)[2:]
        video = tf.reshape(self.inp, (self.batch_size * self.dynamic_n_frames, *inp_trailing_shape))
        encoder_output = self.encoder(video, 2 * self.A, self.is_training)

        eo_trailing_shape = tf_shape(encoder_output)[1:]
        encoder_output = tf.reshape(
            encoder_output, (self.batch_size, self.dynamic_n_frames, *eo_trailing_shape))

        if self.cell is None:
            attr = encoder_output
        else:

            if self.flat_latent:
                n_trailing_dims = int(np.prod(eo_trailing_shape))
                encoder_output = tf.reshape(
                    encoder_output, (self.batch_size, self.dynamic_n_frames, n_trailing_dims))
            else:
                raise NotImplementedError("structured (non-flat) latents are not supported")

                # Unreachable sketch of the intended structured reshape:
                n_objects = int(np.prod(eo_trailing_shape[:-1]))
                D = eo_trailing_shape[-1]
                encoder_output = tf.reshape(
                    encoder_output, (self.batch_size, self.dynamic_n_frames, n_objects, D))

            encoder_output = tf.layers.flatten(encoder_output)

            attr, final_state = dynamic_rnn(
                self.cell, encoder_output, initial_state=self.cell.zero_state(self.batch_size, tf.float32),
                parallel_iterations=1, swap_memory=False, time_major=False)

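        # Note: here the std is obtained with softplus, whereas the
        # single-image model above applies exp to the same split.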
        attr_mean, attr_log_std = tf.split(attr, 2, axis=-1)
        attr_std = tf.math.softplus(attr_log_std)

        if not self.noisy:
            attr_std = tf.zeros_like(attr_std)

        attr, attr_kl = normal_vae(attr_mean, attr_std, self.attr_prior_mean, self.attr_prior_std)

        self._tensors.update(attr_mean=attr_mean, attr_std=attr_std, attr_kl=attr_kl, attr=attr)

        # --- decode ---

        decoder_input = tf.reshape(attr, (self.batch_size*self.dynamic_n_frames, *tf_shape(attr)[2:]))

        reconstruction = self.decoder(decoder_input, tf_shape(self.inp)[2:], self.is_training)
        reconstruction = reconstruction[:, :self.obs_shape[1], :self.obs_shape[2], :]
        reconstruction = tf.reshape(reconstruction, (self.batch_size, self.dynamic_n_frames, *self.obs_shape[1:]))

        reconstruction = tf.nn.sigmoid(tf.clip_by_value(reconstruction, -10, 10))
        self._tensors["output"] = reconstruction

        # --- losses ---

        if self.train_kl:
            self.losses['attr_kl'] = tf_mean_sum(self._tensors["attr_kl"])

        if self.train_reconstruction:
            self._tensors['per_pixel_reconstruction_loss'] = xent_loss(pred=reconstruction, label=self.inp)
            self.losses['reconstruction'] = tf_mean_sum(self._tensors['per_pixel_reconstruction_loss'])