Example #1
    def __init__(self, data_root, split, image_ext='.jpg'):
        super().__init__(data_root, split, image_ext)

        self.video_list = sorted(glob(f'{self.data_root}/{self.split}/*/'))
        self.anno_list = [v[:-1] + '.pkl' for v in self.video_list]

        if C.INPUT.PRELOAD_TO_MEMORY:
            print('loading data from hickle file...')
            data = hickle.load(f'{self.data_root}/{self.split}.hkl')
            self.total_img = np.transpose(data['X'], (0, 1, 4, 2, 3))
            self.total_box = np.zeros((data['y'].shape[:3] + (5, )))
            for anno_idx, anno_name in enumerate(self.anno_list):
                tprint(f'loading progress: {anno_idx}/{len(self.anno_list)}')
                with open(anno_name, 'rb') as f:
                    boxes = pickle.load(f)
                self.total_box[anno_idx] = boxes

        self.video_info = np.zeros((0, 2), dtype=np.int32)
        for idx, video_name in enumerate(
                self.total_box if C.INPUT.PRELOAD_TO_MEMORY else self.video_list):
            tprint(f'loading progress: {idx}/{len(self.video_list)}')
            if C.INPUT.PRELOAD_TO_MEMORY:
                num_sw = self.total_box[idx].shape[0] - self.seq_size + 1
            else:
                num_im = len(glob(f'{video_name}/*{image_ext}'))
                num_sw = num_im - self.seq_size + 1  # number of sliding windows

            if num_sw <= 0:
                continue
            video_info_t = np.zeros((num_sw, 2), dtype=np.int32)
            video_info_t[:, 0] = idx  # video index
            video_info_t[:, 1] = np.arange(num_sw)  # sliding window index
            self.video_info = np.vstack((self.video_info, video_info_t))
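A minimal sketch of how the (video index, sliding-window index) table built above is typically consumed by __getitem__, assuming the preload-to-memory path. The method below is an illustrative assumption, not part of the listing:

    def __getitem__(self, index):
        # look up which video and which window offset this flat dataset index refers to
        video_idx, sw_idx = self.video_info[index]
        # slice seq_size consecutive frames and their box annotations
        images = self.total_img[video_idx, sw_idx:sw_idx + self.seq_size]
        boxes = self.total_box[video_idx, sw_idx:sw_idx + self.seq_size]
        return images, boxes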
Example #2
    def train_epoch(self):
        for batch_idx, (data, data_t, rois, gt_boxes, gt_masks, valid,
                        g_idx) in enumerate(self.train_loader):
            self._adjust_learning_rate()
            data, data_t = data.to(self.device), data_t.to(self.device)
            rois = xyxy_to_rois(rois,
                                batch=data.shape[0],
                                time_step=data.shape[1],
                                num_devices=self.num_gpus)
            self.optim.zero_grad()

            outputs = self.model(data,
                                 rois,
                                 num_rollouts=self.ptrain_size,
                                 g_idx=g_idx,
                                 x_t=data_t,
                                 phase='train')
            labels = {
                'boxes': gt_boxes.to(self.device),
                'masks': gt_masks.to(self.device),
                'valid': valid.to(self.device),
            }
            loss = self.loss(outputs, labels, 'train')
            loss.backward()
            self.optim.step()
            # this is an approximation for printing; the dataset size may not divide the batch size
            self.iterations += self.batch_size

            print_msg = ""
            print_msg += f"{self.epochs:03}/{self.iterations // 1000:04}k"
            print_msg += f" | "
            mean_loss = np.mean(
                np.array(self.box_p_step_losses[:self.ptrain_size]) /
                self.loss_cnt) * 1e3
            print_msg += f"{mean_loss:.3f} | "
            print_msg += f" | ".join([
                "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
                for name in self.loss_name
            ])
            speed = self.loss_cnt / (timer() - self.time)
            eta = (self.max_iters - self.iterations) / speed / 3600
            print_msg += f" | speed: {speed:.1f} | eta: {eta:.2f} h"
            print_msg += (
                " " * (os.get_terminal_size().columns - len(print_msg) - 10))
            tprint(print_msg)

            if self.iterations % self.val_interval == 0:
                self.snapshot()
                self.val()
                self._init_loss()
                self.model.train()

            if self.iterations >= self.max_iters:
                print('\r', end='')
                print(f'{self.best_mean:.3f}')
                break
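train_epoch performs a single pass over train_loader; the surrounding driver is not shown here. A minimal sketch of such a driver, assuming the attribute names used above (an illustrative assumption, not the project's actual code):

    def train(self):
        self._init_loss()
        self.model.train()
        # repeat epochs until the iteration budget is exhausted;
        # train_epoch itself breaks out early once max_iters is reached
        while self.iterations < self.max_iters:
            self.train_epoch()
            self.epochs += 1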
Example #3
    def __init__(self, data_root, split, image_ext='.jpg'):
        super().__init__(data_root, split, image_ext)

        protocal = C.PHYRE_PROTOCAL
        fold = C.PHYRE_FOLD
        # read the split file with a context manager and drop a possible trailing empty line
        with open(f'{data_root}/splits/{protocal}_{split}_fold_{fold}.txt') as f:
            env_list = f.read().strip().split('\n')
        self.video_list = sum([
            sorted(glob(f'{data_root}/images/{env.replace(":", "/")}/*.npy'))
            for env in env_list
        ], [])
        self.anno_list = [(v[:-4] + '_boxes.hkl').replace('images', 'labels')
                          for v in self.video_list]

        # just for plot images: 'plot' is a debug flag presumably defined at module
        # level in the original file; when set, keep only a small subsample of videos
        if plot:
            self.video_list = [
                k for k in self.video_list
                if int(k.split('/')[-1].split('.')[0]) < 40
            ]
            self.anno_list = [
                k for k in self.anno_list
                if int(k.split('/')[-1].split('_')[0]) < 40
            ]
            assert len(self.video_list) == len(self.anno_list)
            self.video_list = self.video_list[::80]
            self.anno_list = self.anno_list[::80]

        # video_info_name = f'for_plot.npy'
        video_info_name = f'{data_root}/{protocal}_{split}_{self.input_size}_{self.pred_size}_fold_{fold}_info.npy'
        if os.path.exists(video_info_name):
            print(f'loading info from: {video_info_name}')
            self.video_info = np.load(video_info_name)
        else:
            self.video_info = np.zeros((0, 2), dtype=np.int32)
            for idx, video_name in enumerate(self.video_list):
                tprint(f'loading progress: {idx}/{len(self.video_list)}')
                num_im = hickle.load(
                    video_name.replace('images', 'labels').replace(
                        '.npy', '_boxes.hkl')).shape[0]
                if plot:
                    # we will pad sequence so no check
                    num_sw = 1
                else:
                    assert self.input_size == 1
                    num_sw = min(1, num_im - self.seq_size + 1)

                if num_sw <= 0:
                    continue
                video_info_t = np.zeros((num_sw, 2), dtype=np.int32)
                video_info_t[:, 0] = idx  # video index
                video_info_t[:, 1] = np.arange(num_sw)  # sliding window index
                self.video_info = np.vstack((self.video_info, video_info_t))

            np.save(video_info_name, self.video_info)
Example #4
File: ss.py Project: dingmyu/RPIN
    def __init__(self, data_root, split, image_ext='.jpg'):
        super().__init__(data_root, split, image_ext)

        self.video_list = sorted(glob(f'{self.data_root}/{self.split}/*/'))
        self.anno_list = [v[:-1] + '_boxes.pkl' for v in self.video_list]

        self.video_info = np.zeros((0, 2), dtype=np.int32)
        for idx, video_name in enumerate(self.video_list):
            tprint(f'loading progress: {idx}/{len(self.video_list)}')
            num_im = len(glob(f'{video_name}/*{image_ext}'))
            # In ShapeStacks, we only use the sequence starting from the first frame
            num_sw = min(1, num_im - self.seq_size + 1)
            if num_sw <= 0:
                continue
            video_info_t = np.zeros((num_sw, 2), dtype=np.int32)
            video_info_t[:, 0] = idx  # video index
            video_info_t[:, 1] = np.arange(num_sw)  # sliding window index
            self.video_info = np.vstack((self.video_info, video_info_t))
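A minimal usage sketch for a dataset with this constructor signature; the class name, data path, and batch size below are placeholders, not names taken from the project:

from torch.utils.data import DataLoader

# hypothetical names; the real class lives in ss.py and its __getitem__ defines
# the per-sample tuple (see Example #2 for how batches are consumed)
dataset = ShapeStacksDataset(data_root='data/shapestacks', split='train', image_ext='.jpg')
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)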
Example #5
    def val(self):
        self.model.eval()
        self._init_loss()

        if C.RPIN.VAE:
            losses = dict.fromkeys(self.loss_name, 0.0)
            box_p_step_losses = [0.0 for _ in range(self.ptest_size)]
            masks_step_losses = [0.0 for _ in range(self.ptest_size)]

        for batch_idx, (data, _, rois, gt_boxes, gt_masks, valid, g_idx,
                        seq_l) in enumerate(self.val_loader):
            tprint(f'eval: {batch_idx}/{len(self.val_loader)}')
            with torch.no_grad():

                if C.RPIN.ROI_MASKING or C.RPIN.ROI_CROPPING:
                    # data should be (b x t x o x c x h x w)
                    data = data.permute(
                        (0, 2, 1, 3, 4, 5))  # (b, o, t, c, h, w)
                    data = data.reshape((data.shape[0] * data.shape[1], ) +
                                        data.shape[2:])  # (b*o, t, c, h, w)

                data = data.to(self.device)
                rois = xyxy_to_rois(rois,
                                    batch=data.shape[0],
                                    time_step=data.shape[1],
                                    num_devices=self.num_gpus)
                labels = {
                    'boxes': gt_boxes.to(self.device),
                    'masks': gt_masks.to(self.device),
                    'valid': valid.to(self.device),
                    'seq_l': seq_l.to(self.device),
                }

                outputs = self.model(data,
                                     rois,
                                     num_rollouts=self.ptest_size,
                                     g_idx=g_idx,
                                     phase='test')
                self.loss(outputs, labels, 'test')
                # VAE multiple runs
                if C.RPIN.VAE:
                    vae_best_mean = np.mean(
                        np.array(self.box_p_step_losses[:self.ptest_size]) /
                        self.loss_cnt) * 1e3
                    losses_t = self.losses.copy()
                    box_p_step_losses_t = self.box_p_step_losses.copy()
                    masks_step_losses_t = self.masks_step_losses.copy()
                    for i in range(9):
                        outputs = self.model(data,
                                             rois,
                                             num_rollouts=self.ptest_size,
                                             g_idx=g_idx,
                                             phase='test')
                        self.loss(outputs, labels, 'test')
                        mean_loss = np.mean(
                            np.array(self.box_p_step_losses[:self.ptest_size])
                            / self.loss_cnt) * 1e3
                        if mean_loss < vae_best_mean:
                            losses_t = self.losses.copy()
                            box_p_step_losses_t = self.box_p_step_losses.copy()
                            masks_step_losses_t = self.masks_step_losses.copy()
                            vae_best_mean = mean_loss
                        self._init_loss()

                    for k, v in losses.items():
                        losses[k] += losses_t[k]
                    for i in range(len(box_p_step_losses)):
                        box_p_step_losses[i] += box_p_step_losses_t[i]
                        masks_step_losses[i] += masks_step_losses_t[i]

        if C.RPIN.VAE:
            self.losses = losses.copy()
            self.box_p_step_losses = box_p_step_losses.copy()
            self.loss_cnt = len(self.val_loader)

        print('\r', end='')
        print_msg = ""
        print_msg += f"{self.epochs:03}/{self.iterations // 1000:04}k"
        print_msg += f" | "
        mean_loss = np.mean(
            np.array(self.box_p_step_losses[:self.ptest_size]) /
            self.loss_cnt) * 1e3
        print_msg += f"{mean_loss:.3f} | "

        if mean_loss < self.best_mean:
            self.snapshot('ckpt_best.path.tar')
            self.best_mean = mean_loss

        print_msg += f" | ".join([
            "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
            for name in self.loss_name
        ])
        if C.RPIN.SEQ_CLS_LOSS_WEIGHT:
            print_msg += f" | {self.fg_correct / (self.fg_num + 1e-9):.3f} | {self.bg_correct / (self.bg_num + 1e-9):.3f}"
        # print_msg += (" " * (os.get_terminal_size().columns - len(print_msg) - 10))
        self.logger.info(print_msg)
Example #6
    def test(self):
        self.model.eval()

        if C.RPIN.VAE:
            losses = dict.fromkeys(self.loss_name, 0.0)
            box_p_step_losses = [0.0 for _ in range(self.ptest_size)]
            masks_step_losses = [0.0 for _ in range(self.ptest_size)]

        for batch_idx, (data, _, rois, gt_boxes, gt_masks, valid,
                        g_idx) in enumerate(self.val_loader):
            with torch.no_grad():
                data = data.to(self.device)
                rois = xyxy_to_rois(rois,
                                    batch=data.shape[0],
                                    time_step=data.shape[1],
                                    num_devices=self.num_gpus)
                labels = {
                    'boxes': gt_boxes.to(self.device),
                    'masks': gt_masks.to(self.device),
                    'valid': valid.to(self.device),
                }
                outputs = self.model(data,
                                     rois,
                                     num_rollouts=self.ptest_size,
                                     g_idx=g_idx,
                                     phase='test')
                self.loss(outputs, labels, 'test')
                # VAE multiple runs
                if C.RPIN.VAE:
                    vae_best_mean = np.mean(
                        np.array(self.box_p_step_losses[:self.ptest_size]) /
                        self.loss_cnt) * 1e3
                    losses_t = self.losses.copy()
                    box_p_step_losses_t = self.box_p_step_losses.copy()
                    masks_step_losses_t = self.masks_step_losses.copy()
                    for i in range(99):
                        outputs = self.model(data,
                                             rois,
                                             num_rollouts=self.ptest_size,
                                             g_idx=g_idx,
                                             phase='test')
                        self.loss(outputs, labels, 'test')
                        mean_loss = np.mean(
                            np.array(self.box_p_step_losses[:self.ptest_size])
                            / self.loss_cnt) * 1e3
                        if mean_loss < vae_best_mean:
                            losses_t = self.losses.copy()
                            box_p_step_losses_t = self.box_p_step_losses.copy()
                            masks_step_losses_t = self.masks_step_losses.copy()
                            vae_best_mean = mean_loss
                        self._init_loss()

                    for k, v in losses.items():
                        losses[k] += losses_t[k]
                    for i in range(len(box_p_step_losses)):
                        box_p_step_losses[i] += box_p_step_losses_t[i]
                        masks_step_losses[i] += masks_step_losses_t[i]

                tprint(f'eval: {batch_idx}/{len(self.val_loader)}:' + ' ' * 20)

            if self.plot_image > 0:
                outputs = {
                    'boxes': outputs['boxes'].cpu().numpy(),
                    'masks': outputs['masks'].cpu().numpy() if C.RPIN.MASK_LOSS_WEIGHT else None,
                }
                outputs['boxes'][..., 0::2] *= self.input_width
                outputs['boxes'][..., 1::2] *= self.input_height
                outputs['boxes'] = xywh2xyxy(outputs['boxes'].reshape(
                    -1, 4)).reshape(
                        (data.shape[0], -1, C.RPIN.MAX_NUM_OBJS, 4))

                labels = {
                    'boxes': labels['boxes'].cpu().numpy(),
                    'masks': labels['masks'].cpu().numpy(),
                }
                labels['boxes'][..., 0::2] *= self.input_width
                labels['boxes'][..., 1::2] *= self.input_height
                labels['boxes'] = xywh2xyxy(labels['boxes'].reshape(
                    -1, 4)).reshape(
                        (data.shape[0], -1, C.RPIN.MAX_NUM_OBJS, 4))

                for i in range(rois.shape[0]):
                    batch_size = C.SOLVER.BATCH_SIZE if not C.RPIN.VAE else 1
                    plot_image_idx = batch_size * batch_idx + i
                    if plot_image_idx < self.plot_image:
                        tprint(f'plotting: {plot_image_idx}' + ' ' * 20)
                        video_idx, img_idx = self.val_loader.dataset.video_info[
                            plot_image_idx]
                        video_name = self.val_loader.dataset.video_list[
                            video_idx]

                        v = valid[i].numpy().astype(bool)
                        pred_boxes_i = outputs['boxes'][i][:, v]
                        gt_boxes_i = labels['boxes'][i][:, v]

                        if 'PHYRE' in C.DATA_ROOT:
                            im_data = phyre.observations_to_float_rgb(
                                np.load(video_name).astype(
                                    np.uint8))[..., ::-1]
                            a, b, c = video_name.split('/')[5:8]
                            output_name = f'{a}_{b}_{c.replace(".npy", "")}'

                            bg_image = np.load(video_name).astype(np.uint8)
                            for fg_id in [1, 2, 3, 5]:
                                bg_image[bg_image == fg_id] = 0
                            bg_image = phyre.observations_to_float_rgb(
                                bg_image)
                        else:
                            bg_image = None
                            image_list = sorted(
                                glob(
                                    f'{video_name}/*{self.val_loader.dataset.image_ext}'
                                ))
                            im_name = image_list[img_idx]
                            output_name = '_'.join(
                                im_name.split('.')[0].split('/')[-2:])
                            # deal with image data here
                            im_data = get_im_data(im_name, gt_boxes_i[None, 0:1],
                                                  C.DATA_ROOT, self.high_resolution_plot)

                        if self.high_resolution_plot:
                            scale_w = im_data.shape[1] / self.input_width
                            scale_h = im_data.shape[0] / self.input_height
                            pred_boxes_i[..., [0, 2]] *= scale_w
                            pred_boxes_i[..., [1, 3]] *= scale_h
                            gt_boxes_i[..., [0, 2]] *= scale_w
                            gt_boxes_i[..., [1, 3]] *= scale_h

                        pred_masks_i = None
                        if C.RPIN.MASK_LOSS_WEIGHT:
                            pred_masks_i = outputs['masks'][i][:, v]

                        plot_rollouts(im_data,
                                      pred_boxes_i,
                                      gt_boxes_i,
                                      pred_masks_i,
                                      labels['masks'][i][:, v],
                                      output_dir=self.output_dir,
                                      output_name=output_name,
                                      bg_image=bg_image)

        if C.RPIN.VAE:
            self.losses = losses.copy()
            self.box_p_step_losses = box_p_step_losses.copy()
            self.loss_cnt = len(self.val_loader)

        print('\r', end='')
        print_msg = ""
        mean_loss = np.mean(
            np.array(self.box_p_step_losses[:self.ptest_size]) /
            self.loss_cnt) * 1e3
        print_msg += f"{mean_loss:.3f} | "
        print_msg += f" | ".join([
            "{:.3f}".format(self.losses[name] * 1e3 / self.loss_cnt)
            for name in self.loss_name
        ])
        pprint(print_msg)
Example #7
def train(train_loader, test_loader, model, optim, scheduler, logger, output_dir):
    max_iters = C.SOLVER.MAX_ITERS
    model.train()

    losses = []
    acc = [0, 0, 0, 0]  # [correct positives, correct negatives, #positives, #negatives]
    test_accs = []
    last_time = time.time()
    cur_update = 0
    while cur_update < max_iters:
        for batch_idx, data_tuple in enumerate(train_loader):
            if cur_update >= max_iters:
                break
            model.train()

            p_gt, p_pred, n_gt, n_pred = data_tuple
            labels = torch.cat([torch.ones(p_gt.shape[0]), torch.zeros(n_gt.shape[0])]).to('cuda')
            p_data = []
            n_data = []
            # take_idx / is_gt are expected to be module-level settings in the
            # original script: they choose, per index, whether the feature comes
            # from the ground-truth or the predicted sequence
            for i, idx in enumerate(take_idx):
                p_data.append(p_gt[:, [idx]] if is_gt[i] else p_pred[:, [idx]])
                n_data.append(n_gt[:, [idx]] if is_gt[i] else n_pred[:, [idx]])
            p_data = torch.cat(p_data, dim=1)
            n_data = torch.cat(n_data, dim=1)
            data = torch.cat([p_data, n_data])

            data = data.long().to('cuda')
            optim.zero_grad()

            pred = model(data)
            loss = model.ce_loss(pred, labels)

            pred = pred.sigmoid() >= 0.5
            acc[0] += ((pred == labels)[labels == 1]).sum().item()
            acc[1] += ((pred == labels)[labels == 0]).sum().item()
            acc[2] += (labels == 1).sum().item()
            acc[3] += (labels == 0).sum().item()

            loss.backward()
            optim.step()
            scheduler.step()
            losses.append(loss.mean().item())

            cur_update += 1
            speed = (time.time() - last_time) / cur_update
            eta = (max_iters - cur_update) * speed / 3600
            info = f'Iter: {cur_update} / {max_iters}, eta: {eta:.2f}h ' \
                   f'p acc: {acc[0] / acc[2]:.4f} n acc: {acc[1] / acc[3]:.4f}'
            tprint(info)

            if (cur_update + 1) % C.SOLVER.VAL_INTERVAL == 0:
                pprint(info)
                fpath = os.path.join(output_dir, 'last.ckpt')
                torch.save(
                    dict(
                        model=model.state_dict(), optim=optim.state_dict(), done_batches=cur_update + 1,
                        scheduler=scheduler and scheduler.state_dict(),
                    ), fpath
                )

                p_acc, n_acc = test(test_loader, model)
                test_accs.append([p_acc, n_acc])
                model.train()
                acc = [0, 0, 0, 0]
                for k in range(2):
                    info = ''
                    for test_acc in test_accs:
                        info += f'{test_acc[k] * 100:.1f} / '
                    logger.info(info)
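A hypothetical sketch of the test() helper called above; it is assumed (not taken from the project) to mirror the training-time bookkeeping, reusing the module-level take_idx / is_gt settings and returning the positive and negative accuracies separately:

@torch.no_grad()
def test(test_loader, model):
    model.eval()
    acc = [0, 0, 0, 0]  # [correct positives, correct negatives, #positives, #negatives]
    for p_gt, p_pred, n_gt, n_pred in test_loader:
        labels = torch.cat([torch.ones(p_gt.shape[0]), torch.zeros(n_gt.shape[0])]).to('cuda')
        p_data = torch.cat([p_gt[:, [idx]] if is_gt[i] else p_pred[:, [idx]]
                            for i, idx in enumerate(take_idx)], dim=1)
        n_data = torch.cat([n_gt[:, [idx]] if is_gt[i] else n_pred[:, [idx]]
                            for i, idx in enumerate(take_idx)], dim=1)
        data = torch.cat([p_data, n_data]).long().to('cuda')
        pred = model(data).sigmoid() >= 0.5
        acc[0] += ((pred == labels)[labels == 1]).sum().item()
        acc[1] += ((pred == labels)[labels == 0]).sum().item()
        acc[2] += (labels == 1).sum().item()
        acc[3] += (labels == 0).sum().item()
    return acc[0] / max(acc[2], 1), acc[1] / max(acc[3], 1)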