Example #1
0
    def generate_trajs(self, data, boxes):
        """Run the model in inference mode and return its rollouts as NumPy arrays.

        A fully connected, directed relation index (every ordered pair of
        distinct objects, each tagged with a trailing ``1``) is built and
        repeated for every sample in the batch before calling ``self.model``.

        Args:
            data: Batched model input. Only ``data.shape[0]`` (used as the
                batch size) and ``data.shape[1]`` (used as the time-step
                count) are read here.
            boxes: Boxes forwarded to ``xyxy_to_rois``; ``boxes.shape[2]`` is
                taken as the number of objects.

        Returns:
            A ``(boxes, masks, score)`` tuple of NumPy arrays. ``boxes`` has
            its even/odd coordinates scaled by ``self.input_width`` /
            ``self.input_height``, is converted with ``xywh2xyxy`` and
            reshaped to ``(batch, -1, num_objs, 4)``. ``masks`` and ``score``
            are returned as produced by the model.
        """
        batch_size = data.shape[0]
        num_objs = boxes.shape[2]

        with torch.no_grad():
            pairs = [[i, j, 1] for i in range(num_objs)
                     for j in range(num_objs) if j != i]
            # With a single object ``pairs`` is empty; without an explicit
            # dtype and reshape NumPy would yield a float64 array of shape
            # (0,), dropping the trailing 3-column axis. For two or more
            # objects this produces exactly the same array as before.
            g_idx = np.array(pairs, dtype=int).reshape(-1, 3)
            g_idx = torch.from_numpy(g_idx[None].repeat(batch_size, 0))

            rois = xyxy_to_rois(boxes,
                                batch=batch_size,
                                time_step=data.shape[1],
                                num_devices=self.num_gpus)
            outputs = self.model(data,
                                 rois,
                                 num_rollouts=self.pred_rollout,
                                 g_idx=g_idx)
            outputs = {
                'boxes': outputs['boxes'].cpu().numpy(),
                'masks': outputs['masks'].cpu().numpy(),
                'score': outputs['score'].cpu().numpy(),
            }

            # Scale the model's box coordinates by the input width/height
            # before converting from (x, y, w, h) to corner format.
            outputs['boxes'][..., 0::2] *= self.input_width
            outputs['boxes'][..., 1::2] *= self.input_height
            outputs['boxes'] = xywh2xyxy(outputs['boxes'].reshape(
                -1, 4)).reshape((batch_size, -1, num_objs, 4))

        return outputs['boxes'], outputs['masks'], outputs['score']
Example #2
0
    def train_epoch(self):
        """Train over ``self.train_loader`` once, logging progress every step.

        For every batch: adjust the learning rate, run a forward/backward
        pass, step the optimizer and print a one-line status (mean per-step
        box loss, each named loss, speed and ETA). Whenever
        ``self.iterations`` is a multiple of ``self.val_interval`` a snapshot
        is saved, validation runs, the loss accumulators are reset and the
        model is put back in training mode. The loop stops early once
        ``self.iterations`` reaches ``self.max_iters``.
        """
        for (data, data_t, rois, gt_boxes, gt_masks, valid,
             g_idx) in self.train_loader:
            self._adjust_learning_rate()
            data, data_t = data.to(self.device), data_t.to(self.device)
            rois = xyxy_to_rois(rois,
                                batch=data.shape[0],
                                time_step=data.shape[1],
                                num_devices=self.num_gpus)
            self.optim.zero_grad()

            outputs = self.model(data,
                                 rois,
                                 num_rollouts=self.ptrain_size,
                                 g_idx=g_idx,
                                 x_t=data_t,
                                 phase='train')
            labels = {
                'boxes': gt_boxes.to(self.device),
                'masks': gt_masks.to(self.device),
                'valid': valid.to(self.device),
            }
            loss = self.loss(outputs, labels, 'train')
            loss.backward()
            self.optim.step()
            # this is an approximation for printing; the dataset size may not
            # divide the batch size
            self.iterations += self.batch_size

            # Status line; losses are shown scaled by 1e3 and averaged over
            # ``self.loss_cnt``.
            mean_loss = np.mean(
                np.array(self.box_p_step_losses[:self.ptrain_size]) /
                self.loss_cnt) * 1e3
            print_msg = f"{self.epochs:03}/{self.iterations // 1000:04}k"
            print_msg += " | "
            print_msg += f"{mean_loss:.3f} | "
            print_msg += " | ".join([
                "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
                for name in self.loss_name
            ])
            speed = self.loss_cnt / (timer() - self.time)
            eta = (self.max_iters - self.iterations) / speed / 3600
            print_msg += f" | speed: {speed:.1f} | eta: {eta:.2f} h"

            # Pad to the terminal width so a shorter line overwrites a longer
            # previous one. os.get_terminal_size() raises OSError when stdout
            # is not attached to a terminal (redirected output, batch jobs);
            # fall back to a conventional width instead of aborting training.
            try:
                term_columns = os.get_terminal_size().columns
            except OSError:
                term_columns = 80
            print_msg += " " * (term_columns - len(print_msg) - 10)
            tprint(print_msg)

            if self.iterations % self.val_interval == 0:
                self.snapshot()
                self.val()
                self._init_loss()
                # val() switches the model to eval mode; restore train mode.
                self.model.train()

            if self.iterations >= self.max_iters:
                print('\r', end='')
                print(f'{self.best_mean:.3f}')
                break
Example #3
0
    def val(self):
        """Evaluate on ``self.val_loader``, track the best score and log it.

        The model is put in eval mode and the loss accumulators are reset.
        Every batch is pushed through the model under ``torch.no_grad()`` and
        scored with ``self.loss(..., 'test')``. When ``C.RPIN.VAE`` is set,
        each batch is sampled 9 additional times and only the accumulators
        of the lowest-scoring pass are added to the running totals, which
        then replace ``self.losses`` / ``self.box_p_step_losses`` (with
        ``self.loss_cnt`` set to the number of batches).

        If the resulting mean box loss beats ``self.best_mean`` a
        ``ckpt_best.path.tar`` snapshot is written. The summary line goes to
        ``self.logger``.
        """
        self.model.eval()
        self._init_loss()

        def scaled_box_loss():
            # Mean of the first ``ptest_size`` per-step box losses,
            # normalised by ``loss_cnt`` and scaled by 1e3.
            per_step = np.array(self.box_p_step_losses[:self.ptest_size])
            return np.mean(per_step / self.loss_cnt) * 1e3

        def forward_and_score(frames, regions, graph_idx, targets):
            # One test-phase forward pass followed by a loss update.
            preds = self.model(frames,
                               regions,
                               num_rollouts=self.ptest_size,
                               g_idx=graph_idx,
                               phase='test')
            self.loss(preds, targets, 'test')

        def accumulator_copies():
            return (self.losses.copy(), self.box_p_step_losses.copy(),
                    self.masks_step_losses.copy())

        if C.RPIN.VAE:
            total_losses = {name: 0.0 for name in self.loss_name}
            total_box_steps = [0.0] * self.ptest_size
            total_mask_steps = [0.0] * self.ptest_size

        for batch_idx, batch in enumerate(self.val_loader):
            data, _, rois, gt_boxes, gt_masks, valid, g_idx, seq_l = batch
            tprint(f'eval: {batch_idx}/{len(self.val_loader)}')
            with torch.no_grad():

                if C.RPIN.ROI_MASKING or C.RPIN.ROI_CROPPING:
                    # data should be (b x t x o x c x h x w)
                    data = data.permute((0, 2, 1, 3, 4, 5))  # (b, o, t, c, h, w)
                    folded = data.shape[0] * data.shape[1]
                    data = data.reshape((folded, ) + data.shape[2:])  # (b*o, t, c, h, w)

                data = data.to(self.device)
                rois = xyxy_to_rois(rois,
                                    batch=data.shape[0],
                                    time_step=data.shape[1],
                                    num_devices=self.num_gpus)
                labels = {
                    key: tensor.to(self.device)
                    for key, tensor in (('boxes', gt_boxes),
                                        ('masks', gt_masks),
                                        ('valid', valid),
                                        ('seq_l', seq_l))
                }

                forward_and_score(data, rois, g_idx, labels)

                # VAE multiple runs
                if C.RPIN.VAE:
                    vae_best_mean = scaled_box_loss()
                    best_pass = accumulator_copies()
                    # NOTE(review): the accumulators are not reset between the
                    # initial pass and the first extra sample, so that first
                    # candidate appears to be scored on top of the initial
                    # pass - confirm this is intended.
                    for _ in range(9):
                        forward_and_score(data, rois, g_idx, labels)
                        candidate = scaled_box_loss()
                        if candidate < vae_best_mean:
                            best_pass = accumulator_copies()
                            vae_best_mean = candidate
                        self._init_loss()

                    best_losses, best_box_steps, best_mask_steps = best_pass
                    for name in total_losses:
                        total_losses[name] += best_losses[name]
                    for step in range(len(total_box_steps)):
                        total_box_steps[step] += best_box_steps[step]
                        total_mask_steps[step] += best_mask_steps[step]

        if C.RPIN.VAE:
            self.losses = total_losses.copy()
            self.box_p_step_losses = total_box_steps.copy()
            self.loss_cnt = len(self.val_loader)

        print('\r', end='')
        print_msg = f"{self.epochs:03}/{self.iterations // 1000:04}k" + " | "
        mean_loss = scaled_box_loss()
        print_msg += f"{mean_loss:.3f} | "

        if mean_loss < self.best_mean:
            self.snapshot('ckpt_best.path.tar')
            self.best_mean = mean_loss

        per_loss = [
            "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
            for name in self.loss_name
        ]
        print_msg += " | ".join(per_loss)
        if C.RPIN.SEQ_CLS_LOSS_WEIGHT:
            # Small epsilon guards against a zero denominator.
            fg_acc = self.fg_correct / (self.fg_num + 1e-9)
            bg_acc = self.bg_correct / (self.bg_num + 1e-9)
            print_msg += f" | {fg_acc:.3f} | {bg_acc:.3f}"
        self.logger.info(print_msg)
Example #4
0
    def test(self):
        """Evaluate on ``self.val_loader``, optionally plot rollouts, print losses.

        Every batch is run through the model under ``torch.no_grad()`` and
        scored with ``self.loss(..., 'test')``. When ``C.RPIN.VAE`` is set,
        each batch is sampled 99 additional times and only the accumulators
        of the lowest-scoring pass are added to the running totals, which
        finally replace ``self.losses`` / ``self.box_p_step_losses`` (with
        ``self.loss_cnt`` set to the number of batches).

        When ``self.plot_image > 0`` the first ``self.plot_image`` samples
        are rendered with ``plot_rollouts``. Plotting uses the outputs of the
        most recent forward pass for the batch, which in VAE mode is the last
        sample drawn rather than the best-scoring one.

        NOTE(review): unlike ``val``, this method does not reset the loss
        accumulators before the loop - presumably the caller does; confirm.
        """
        self.model.eval()

        if C.RPIN.VAE:
            losses = dict.fromkeys(self.loss_name, 0.0)
            box_p_step_losses = [0.0 for _ in range(self.ptest_size)]
            masks_step_losses = [0.0 for _ in range(self.ptest_size)]

        for batch_idx, (data, _, rois, gt_boxes, gt_masks, valid,
                        g_idx) in enumerate(self.val_loader):
            with torch.no_grad():
                data = data.to(self.device)
                rois = xyxy_to_rois(rois,
                                    batch=data.shape[0],
                                    time_step=data.shape[1],
                                    num_devices=self.num_gpus)
                labels = {
                    'boxes': gt_boxes.to(self.device),
                    'masks': gt_masks.to(self.device),
                    'valid': valid.to(self.device),
                }
                outputs = self.model(data,
                                     rois,
                                     num_rollouts=self.ptest_size,
                                     g_idx=g_idx,
                                     phase='test')
                self.loss(outputs, labels, 'test')
                # VAE multiple runs
                if C.RPIN.VAE:
                    vae_best_mean = np.mean(
                        np.array(self.box_p_step_losses[:self.ptest_size]) /
                        self.loss_cnt) * 1e3
                    losses_t = self.losses.copy()
                    box_p_step_losses_t = self.box_p_step_losses.copy()
                    masks_step_losses_t = self.masks_step_losses.copy()
                    for _ in range(99):
                        outputs = self.model(data,
                                             rois,
                                             num_rollouts=self.ptest_size,
                                             g_idx=g_idx,
                                             phase='test')
                        self.loss(outputs, labels, 'test')
                        mean_loss = np.mean(
                            np.array(self.box_p_step_losses[:self.ptest_size])
                            / self.loss_cnt) * 1e3
                        # Keep the accumulators of the best pass seen so far.
                        if mean_loss < vae_best_mean:
                            losses_t = self.losses.copy()
                            box_p_step_losses_t = self.box_p_step_losses.copy()
                            masks_step_losses_t = self.masks_step_losses.copy()
                            vae_best_mean = mean_loss
                        self._init_loss()

                    # Fold the best pass for this batch into the totals.
                    for k in losses:
                        losses[k] += losses_t[k]
                    for i in range(len(box_p_step_losses)):
                        box_p_step_losses[i] += box_p_step_losses_t[i]
                        masks_step_losses[i] += masks_step_losses_t[i]

                tprint(f'eval: {batch_idx}/{len(self.val_loader)}:' + ' ' * 20)

            if self.plot_image > 0:
                outputs = {
                    'boxes':
                    outputs['boxes'].cpu().numpy(),
                    'masks':
                    outputs['masks'].cpu().numpy()
                    if C.RPIN.MASK_LOSS_WEIGHT else None,
                }
                # Scale predictions by the input size, then convert
                # (x, y, w, h) boxes to corner format.
                outputs['boxes'][..., 0::2] *= self.input_width
                outputs['boxes'][..., 1::2] *= self.input_height
                outputs['boxes'] = xywh2xyxy(outputs['boxes'].reshape(
                    -1, 4)).reshape(
                        (data.shape[0], -1, C.RPIN.MAX_NUM_OBJS, 4))

                labels = {
                    'boxes': labels['boxes'].cpu().numpy(),
                    'masks': labels['masks'].cpu().numpy(),
                }
                labels['boxes'][..., 0::2] *= self.input_width
                labels['boxes'][..., 1::2] *= self.input_height
                labels['boxes'] = xywh2xyxy(labels['boxes'].reshape(
                    -1, 4)).reshape(
                        (data.shape[0], -1, C.RPIN.MAX_NUM_OBJS, 4))

                # Loop-invariant: stride used to turn (batch_idx, i) into a
                # dataset-wide sample index.
                batch_size = C.SOLVER.BATCH_SIZE if not C.RPIN.VAE else 1
                for i in range(rois.shape[0]):
                    plot_image_idx = batch_size * batch_idx + i
                    if plot_image_idx < self.plot_image:
                        tprint(f'plotting: {plot_image_idx}' + ' ' * 20)
                        video_idx, img_idx = self.val_loader.dataset.video_info[
                            plot_image_idx]
                        video_name = self.val_loader.dataset.video_list[
                            video_idx]

                        # ``np.bool`` was removed in NumPy 1.24; the builtin
                        # ``bool`` yields the same boolean mask dtype.
                        v = valid[i].numpy().astype(bool)
                        # Boolean-mask indexing returns copies, so the
                        # in-place rescaling below does not touch the
                        # batch-level arrays.
                        pred_boxes_i = outputs['boxes'][i][:, v]
                        gt_boxes_i = labels['boxes'][i][:, v]

                        if 'PHYRE' in C.DATA_ROOT:
                            im_data = phyre.observations_to_float_rgb(
                                np.load(video_name).astype(
                                    np.uint8))[..., ::-1]
                            a, b, c = video_name.split('/')[5:8]
                            output_name = f'{a}_{b}_{c.replace(".npy", "")}'

                            # Background image: same observation with the
                            # listed ids zeroed out.
                            bg_image = np.load(video_name).astype(np.uint8)
                            for fg_id in [1, 2, 3, 5]:
                                bg_image[bg_image == fg_id] = 0
                            bg_image = phyre.observations_to_float_rgb(
                                bg_image)
                        else:
                            bg_image = None
                            image_list = sorted(
                                glob(
                                    f'{video_name}/*{self.val_loader.dataset.image_ext}'
                                ))
                            im_name = image_list[img_idx]
                            output_name = '_'.join(
                                im_name.split('.')[0].split('/')[-2:])
                            # deal with image data here
                            im_data = get_im_data(im_name, gt_boxes_i[None,
                                                                      0:1],
                                                  C.DATA_ROOT,
                                                  self.high_resolution_plot)

                        if self.high_resolution_plot:
                            # Rescale boxes from model-input resolution to
                            # the resolution of the image being drawn on.
                            scale_w = im_data.shape[1] / self.input_width
                            scale_h = im_data.shape[0] / self.input_height
                            pred_boxes_i[..., [0, 2]] *= scale_w
                            pred_boxes_i[..., [1, 3]] *= scale_h
                            gt_boxes_i[..., [0, 2]] *= scale_w
                            gt_boxes_i[..., [1, 3]] *= scale_h

                        pred_masks_i = None
                        if C.RPIN.MASK_LOSS_WEIGHT:
                            pred_masks_i = outputs['masks'][i][:, v]

                        plot_rollouts(im_data,
                                      pred_boxes_i,
                                      gt_boxes_i,
                                      pred_masks_i,
                                      labels['masks'][i][:, v],
                                      output_dir=self.output_dir,
                                      output_name=output_name,
                                      bg_image=bg_image)

        if C.RPIN.VAE:
            self.losses = losses.copy()
            self.box_p_step_losses = box_p_step_losses.copy()
            self.loss_cnt = len(self.val_loader)

        print('\r', end='')
        print_msg = ""
        mean_loss = np.mean(
            np.array(self.box_p_step_losses[:self.ptest_size]) /
            self.loss_cnt) * 1e3
        print_msg += f"{mean_loss:.3f} | "
        print_msg += " | ".join([
            "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
            for name in self.loss_name
        ])
        pprint(print_msg)