Exemplo n.º 1
0
def main():
    """Train a model from a config file, with optional periodic validation.

    Flow (as implemented below):
      1. Parse CLI args, merge the config file and ``args.opts`` into the global
         ``cfg``, purge it and freeze it.
      2. Resolve the output directory: '@' in ``cfg.OUTPUT_DIR`` is replaced by the
         config path with 'configs' -> 'outputs'; in dev mode a per-run
         sub-directory named ``<timestamp>.<hostname>`` is appended.
      3. Build the model (moved to CUDA; single-GPU only), optimizer, LR scheduler
         and checkpointer, and load checkpoint/resume state.
      4. Loop over the training dataloader starting at the resumed iteration:
         forward/backward/step, console logging every ``cfg.TRAIN.LOG_PERIOD``,
         TensorBoard scalars every ``cfg.TRAIN.SUMMARY_PERIOD``, validation every
         ``cfg.VAL.PERIOD`` (and at ``cfg.TRAIN.MAX_ITER``) with "best model"
         saving by ``cfg.VAL.METRIC``, and checkpoints every
         ``cfg.TRAIN.CHECKPOINT_PERIOD`` (and at ``cfg.TRAIN.MAX_ITER``).
    """
    # ---------------------------------------------------------------------------- #
    # Setup the experiment
    # ---------------------------------------------------------------------------- #
    args = parse_args()

    # load the configuration
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    purge_cfg(cfg)
    cfg.freeze()

    # run name: '<month-day_hour-minute-second>.<hostname>'
    timestamp = time.strftime('%m-%d_%H-%M-%S')
    hostname = socket.gethostname()
    run_name = '{:s}.{:s}'.format(timestamp, hostname)

    output_dir = cfg.OUTPUT_DIR
    # replace '@' with the config path (with 'configs' mapped to 'outputs')
    if output_dir:
        config_path = osp.splitext(args.config_file)[0]
        output_dir = output_dir.replace(
            '@', config_path.replace('configs', 'outputs'))
        if args.dev:
            # each dev run is written into its own run-named sub-directory
            output_dir = osp.join(output_dir, run_name)
            warnings.warn('Dev mode enabled.')
        if osp.isdir(output_dir):
            warnings.warn('Output directory exists.')
        os.makedirs(output_dir, exist_ok=True)

    logger = setup_logger('train',
                          output_dir,
                          filename='log.train.{:s}.txt'.format(run_name))
    logger.info('{:d} GPUs available'.format(torch.cuda.device_count()))
    logger.info(args)

    from common.utils.collect_env import collect_env_info
    logger.info('Collecting env info (might take some time)\n' +
                collect_env_info())

    logger.info('Loaded configuration file {:s}'.format(args.config_file))
    logger.info('Running with config:\n{}'.format(cfg))

    # ---------------------------------------------------------------------------- #
    # Build models, optimizer, scheduler, checkpointer, etc.
    # ---------------------------------------------------------------------------- #
    # build model
    set_random_seed(cfg.RNG_SEED)
    model = build_model(cfg)
    logger.info('Build model:\n{}'.format(str(model)))

    # Currently only support single-gpu mode
    model = model.cuda()

    # build optimizer
    optimizer = build_optimizer(cfg, model)

    # build lr scheduler
    lr_scheduler = build_lr_scheduler(cfg, optimizer)

    # build checkpointer
    # Note that checkpointer will load state_dict of model, optimizer and scheduler.
    checkpointer = CheckpointerV2(model,
                                  optimizer=optimizer,
                                  scheduler=lr_scheduler,
                                  save_dir=output_dir,
                                  logger=logger,
                                  max_to_keep=cfg.TRAIN.MAX_TO_KEEP)
    checkpoint_data = checkpointer.load(cfg.RESUME_PATH,
                                        resume=cfg.AUTO_RESUME,
                                        resume_states=cfg.RESUME_STATES,
                                        strict=cfg.RESUME_STRICT)
    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD
    # iteration stored by a previous save (0 for a fresh run)
    start_iter = checkpoint_data.get('iteration', 0)

    # build data loader
    # Reset the random seed again in case the initialization of models changes the random state.
    set_random_seed(cfg.RNG_SEED)
    train_dataloader = build_gnn_dataloader(cfg, True, start_iter)
    logger.info(train_dataloader.dataset)

    # build metrics
    train_meters = MetricLogger(delimiter='  ')

    def setup_train():
        """Switch the model to training mode and clear the training meters."""
        model.train()
        train_meters.reset()

    # Build tensorboard logger
    summary_writer = None
    if output_dir:
        tb_dir = output_dir
        summary_writer = SummaryWriter(tb_dir, max_queue=64, flush_secs=30)

    # ---------------------------------------------------------------------------- #
    # Setup validation
    # ---------------------------------------------------------------------------- #
    val_period = cfg.VAL.PERIOD
    do_validation = val_period > 0
    if do_validation:
        val_dataloader = build_gnn_dataloader(cfg, training=False)
        logger.info(val_dataloader.dataset)
        val_meters = MetricLogger(delimiter='  ')

        # best value of cfg.VAL.METRIC seen so far (restored from the checkpoint if present)
        best_metric_name = 'best_{}'.format(cfg.VAL.METRIC)
        best_metric = checkpoint_data.get(best_metric_name, None)

        def setup_validate():
            """Switch the model to eval mode and clear the validation meters."""
            model.eval()
            val_meters.reset()

    # ---------------------------------------------------------------------------- #
    # Training begins.
    # ---------------------------------------------------------------------------- #
    setup_train()
    max_iter = cfg.TRAIN.MAX_ITER
    logger.info('Start training from iteration {}'.format(start_iter))
    tic = time.time()

    for iteration, data_batch in enumerate(train_dataloader, start_iter):
        cur_iter = iteration + 1
        data_time = time.time() - tic

        # copy data from cpu to gpu
        data_batch = data_batch.to('cuda')

        # forward
        pd_dict = model(data_batch)

        # update losses
        loss_dict = model.compute_losses(
            pd_dict,
            data_batch,
        )
        total_loss = sum(loss_dict.values())

        # It is slightly faster to update metrics and meters before backward
        with torch.no_grad():
            train_meters.update(total_loss=total_loss, **loss_dict)
            model.update_metrics(pd_dict, data_batch, train_meters.metrics)

        # backward
        optimizer.zero_grad()
        total_loss.backward()
        if cfg.OPTIMIZER.MAX_GRAD_NORM > 0:
            # CAUTION: built-in clip_grad_norm_ clips the total norm.
            total_norm = clip_grad_norm_(model.parameters(),
                                         max_norm=cfg.OPTIMIZER.MAX_GRAD_NORM)
        else:
            total_norm = None
        optimizer.step()

        batch_time = time.time() - tic
        train_meters.update(time=batch_time, data=data_time)

        # log
        log_period = cfg.TRAIN.LOG_PERIOD
        if log_period > 0 and (cur_iter % log_period == 0 or cur_iter == 1):
            logger.info(
                train_meters.delimiter.join([
                    'iter: {iter:4d}',
                    '{meters}',
                    'lr: {lr:.2e}',
                    'max mem: {memory:.0f}',
                ]).format(
                    iter=cur_iter,
                    meters=str(train_meters),
                    lr=optimizer.param_groups[0]['lr'],
                    memory=torch.cuda.max_memory_allocated() / (1024.0**2),
                ))

        # summary: only metrics whose name contains one of the keywords are written
        summary_period = cfg.TRAIN.SUMMARY_PERIOD
        if summary_writer is not None and (summary_period > 0
                                           and cur_iter % summary_period == 0):
            keywords = (
                'loss',
                'acc',
            )
            for name, metric in train_meters.metrics.items():
                if all(k not in name for k in keywords):
                    continue
                summary_writer.add_scalar('train/' + name,
                                          metric.result,
                                          global_step=cur_iter)

            # summarize gradient norm (only available when clipping is enabled)
            if total_norm is not None:
                summary_writer.add_scalar('grad_norm',
                                          total_norm,
                                          global_step=cur_iter)

        # ---------------------------------------------------------------------------- #
        # validate for one epoch
        # ---------------------------------------------------------------------------- #
        if do_validation and (cur_iter % val_period == 0
                              or cur_iter == max_iter):
            setup_validate()
            logger.info('Validation begins at iteration {}.'.format(cur_iter))

            start_time_val = time.time()
            tic = time.time()
            for iteration_val, data_batch in enumerate(val_dataloader):
                data_time = time.time() - tic

                # copy data from cpu to gpu
                data_batch = data_batch.to('cuda')

                # Forward pass, losses and metrics need no autograd graph during
                # validation, so all of them run under no_grad.
                with torch.no_grad():
                    pd_dict = model(data_batch)

                    # update losses and metrics
                    loss_dict = model.compute_losses(pd_dict, data_batch)
                    total_loss = sum(loss_dict.values())

                    # update metrics and meters
                    val_meters.update(loss=total_loss, **loss_dict)
                    model.update_metrics(pd_dict, data_batch, val_meters.metrics)

                batch_time = time.time() - tic
                val_meters.update(time=batch_time, data=data_time)
                tic = time.time()

                if cfg.VAL.LOG_PERIOD > 0 and iteration_val % cfg.VAL.LOG_PERIOD == 0:
                    # report progress through the validation set (batch index)
                    logger.info(
                        val_meters.delimiter.join([
                            'iter: {iter:4d}',
                            '{meters}',
                            'max mem: {memory:.0f}',
                        ]).format(
                            iter=iteration_val,
                            meters=str(val_meters),
                            memory=torch.cuda.max_memory_allocated() /
                            (1024.0**2),
                        ))

            # END: validation loop
            epoch_time_val = time.time() - start_time_val
            logger.info('Iteration[{}]-Val {}  total_time: {:.2f}s'.format(
                cur_iter, val_meters.summary_str, epoch_time_val))

            # summary
            if summary_writer is not None:
                keywords = ('loss', 'acc', 'ap', 'recall')
                for name, metric in val_meters.metrics.items():
                    if all(k not in name for k in keywords):
                        continue
                    summary_writer.add_scalar('val/' + name,
                                              metric.result,
                                              global_step=cur_iter)

            # best validation: METRIC_ASCEND selects whether larger is better
            if cfg.VAL.METRIC in val_meters.metrics:
                cur_metric = val_meters.metrics[cfg.VAL.METRIC].result
                if best_metric is None \
                        or (cfg.VAL.METRIC_ASCEND and cur_metric > best_metric) \
                        or (not cfg.VAL.METRIC_ASCEND and cur_metric < best_metric):
                    best_metric = cur_metric
                    checkpoint_data['iteration'] = cur_iter
                    checkpoint_data[best_metric_name] = best_metric
                    checkpointer.save('model_best',
                                      tag=False,
                                      **checkpoint_data)

            # restore training
            setup_train()

        # ---------------------------------------------------------------------------- #
        # After validation
        # ---------------------------------------------------------------------------- #
        # checkpoint (written after validation so the stored best metric is current)
        if (ckpt_period > 0
                and cur_iter % ckpt_period == 0) or cur_iter == max_iter:
            checkpoint_data['iteration'] = cur_iter
            if do_validation and best_metric is not None:
                checkpoint_data[best_metric_name] = best_metric
            checkpointer.save('model_{:06d}'.format(cur_iter),
                              **checkpoint_data)

        # ---------------------------------------------------------------------------- #
        # Finalize one step
        # ---------------------------------------------------------------------------- #
        # since pytorch v1.1.0, lr_scheduler is called after optimization.
        if lr_scheduler is not None:
            lr_scheduler.step()
        tic = time.time()

    # END: training loop
    if do_validation and cfg.VAL.METRIC:
        logger.info('Best val-{} = {}'.format(cfg.VAL.METRIC, best_metric))
Exemplo n.º 2
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate the MVPNet 3D model on ScanNet scans, chunk by chunk.

    Each scan yielded by ``ScanNet2D3DChunksTest`` is a list of chunks. Per-chunk
    logits are accumulated into a whole-scene logit array and divided by the
    number of chunks that covered each point. Points covered by no chunk get the
    label ``num_classes``. When ``args.split != 'test'`` the predictions are
    scored with ``Evaluator``. When ``args.save`` is set, remapped labels
    (``<scan_id>.txt``) and averaged logits (``logits/<scan_id>.npy``) are written
    under ``<output_dir>/submit/<run_name>``.

    Args:
        cfg: config node providing model, dataset and RNG settings.
        args: parsed CLI namespace. Fields read here: ckpt_path, k, num_views,
            cache_dir, image_dir, split, chunk_size, chunk_stride, chunk_thresh,
            min_nb_pts and save.
        output_dir: checkpoint / output directory. Any '@' in ``args.ckpt_path``
            is replaced by it.
        run_name: name used for the submission sub-directory and the eval table.
    """
    logger = logging.getLogger('mvpnet.test')

    # build mvpnet model
    model = build_model_mvpnet_3d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)
    if args.ckpt_path:
        # load weight if specified
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset (CLI values, when given, override the config)
    k = args.k or cfg.DATASET.ScanNet2D3DChunks.k
    num_views = args.num_views or cfg.DATASET.ScanNet2D3DChunks.num_rgbd_frames
    test_dataset = ScanNet2D3DChunksTest(cache_dir=args.cache_dir,
                                         image_dir=args.image_dir,
                                         split=args.split,
                                         chunk_size=(args.chunk_size, args.chunk_size),
                                         chunk_stride=args.chunk_stride,
                                         chunk_thresh=args.chunk_thresh,
                                         num_rgbd_frames=num_views,
                                         resize=cfg.DATASET.ScanNet2D3DChunks.resize,
                                         image_normalizer=cfg.DATASET.ScanNet2D3DChunks.image_normalizer,
                                         k=k,
                                         to_tensor=True,
                                         )
    # batch_size=1 and an unwrapping collate_fn: each item is one scan's chunk list
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=lambda x: x[0])

    # evaluator
    class_names = test_dataset.class_names
    # NOTE(review): the original author wondered whether passing
    # labels=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]
    # here would raise an error; unverified.
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        os.makedirs(submit_dir)
        logits_dir = osp.join(submit_dir, 'logits')
        os.makedirs(logits_dir)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        start_time_scan = time.time()
        for scan_idx, data_dict_list in enumerate(test_dataloader):
            # fetch data
            scan_id = test_dataset.scan_ids[scan_idx]
            points = test_dataset.data[scan_idx]['points'].astype(np.float32)
            seg_label = None
            if args.split != 'test':
                seg_label = test_dataset.data[scan_idx]['seg_label']
                # Remap stored labels through the dataset's nyu40 -> scannet lookup
                # table (the stored values are used directly as table indices).
                seg_label = test_dataset.nyu40_to_scannet[seg_label]
            data_time = time.time() - start_time_scan

            # prepare outputs
            num_points = len(points)
            pred_logit_whole_scene = np.zeros([num_points, num_classes], dtype=np.float32)
            # Per-point count of covering chunks. A wide integer type is used
            # because with uint8 a point covered by >255 overlapping chunks wraps
            # to 0 and is then wrongly reported as having no prediction.
            num_pred_per_point = np.zeros(num_points, dtype=np.int32)

            # iterate over chunks
            tic = time.time()
            for data_dict in data_dict_list:
                # indices of this chunk's points in the whole scene
                chunk_ind = data_dict.pop('chunk_ind')

                # padding for chunks with points less than min-nb-pts
                # It is required since farthest point sampling requires more points than centroids.
                chunk_points = data_dict['points']
                chunk_nb_pts = chunk_points.shape[1]  # note that already transposed
                if chunk_nb_pts < args.min_nb_pts:
                    print('Too sparse chunk in {} with {} points.'.format(scan_id, chunk_nb_pts))
                    pad = np.random.randint(chunk_nb_pts, size=args.min_nb_pts - chunk_nb_pts)
                    choice = np.hstack([np.arange(chunk_nb_pts), pad])
                    data_dict['points'] = data_dict['points'][:, choice]
                    data_dict['knn_indices'] = data_dict['knn_indices'][choice]

                # add a batch dimension of 1 and move to GPU
                data_batch = {k: torch.tensor([v]) for k, v in data_dict.items()}
                data_batch = {k: v.cuda(non_blocking=True) for k, v in data_batch.items()}
                # forward
                preds = model(data_batch)
                seg_logit = preds['seg_logit'].squeeze(0).cpu().numpy().T
                # drop logits of the padded (duplicated) points appended above
                seg_logit = seg_logit[:len(chunk_ind)]
                # update
                pred_logit_whole_scene[chunk_ind] += seg_logit
                num_pred_per_point[chunk_ind] += 1
            forward_time = time.time() - tic

            # average over covering chunks (max(., 1) avoids division by zero)
            pred_logit_whole_scene = pred_logit_whole_scene / np.maximum(num_pred_per_point[:, np.newaxis], 1)
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)

            no_pred_mask = num_pred_per_point == 0
            no_pred_indices = np.nonzero(no_pred_mask)[0]
            if no_pred_indices.size > 0:
                logger.warning('{:s}: There are {:d} points without prediction.'.format(scan_id, no_pred_mask.sum()))
                pred_label_whole_scene[no_pred_indices] = num_classes

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene)
                visualize_labels(points, seg_label)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(data=data_time,
                               forward_time=forward_time,
                               metric_time=metric_time,
                               )

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_nyu40[pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'), remapped_pred_labels, '%d')
                # np.save has no format argument; the former third positional
                # '%d' was silently taken as allow_pickle.
                np.save(osp.join(logits_dir, scan_id + '.npy'), pred_logit_whole_scene)

            logger.info(
                test_meters.delimiter.join(
                    [
                        '{:d}/{:d}({:s})',
                        'acc: {acc:.2f}',
                        'IoU: {iou:.2f}',
                        '{meters}',
                    ]
                ).format(
                    scan_idx, len(test_dataset), scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                )
            )
            start_time_scan = time.time()

        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 * evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
Exemplo n.º 3
0
def train(cfg, output_dir='', run_name=''):
    """Train the 2D semantic-segmentation network with optional periodic validation.

    Builds the model, loss and metrics via ``build_model_sem_seg_2d``. The model
    is wrapped in ``nn.DataParallel`` when more than one GPU is visible; CPU
    training is rejected. The function then builds the optimizer, scheduler,
    checkpointer and optional parameter freezer, and runs an iteration-based loop
    up to ``cfg.SCHEDULER.MAX_ITERATION``, resuming from the iteration stored in
    the loaded checkpoint.

    Args:
        cfg: config node with MODEL/OPTIMIZER/SCHEDULER/TRAIN/VAL settings.
        output_dir: checkpoint directory. When non-empty, TensorBoard logs go to
            ``<output_dir>/tb.<run_name>``.
        run_name: suffix for the TensorBoard log directory.

    Returns:
        The trained model (possibly wrapped in ``nn.DataParallel``).

    Raises:
        NotImplementedError: if no CUDA device is available.
    """
    # ---------------------------------------------------------------------------- #
    # Build models, optimizer, scheduler, checkpointer, etc.
    # It is recommended not to modify this section.
    # ---------------------------------------------------------------------------- #
    logger = logging.getLogger('mvpnet.train')

    # build model
    set_random_seed(cfg.RNG_SEED)
    model, loss_fn, train_metric, val_metric = build_model_sem_seg_2d(cfg)
    logger.info('Build model:\n{}'.format(str(model)))
    num_params = sum(param.numel() for param in model.parameters())
    print('#Parameters: {:.2e}'.format(num_params))

    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        model = nn.DataParallel(model).cuda()
    elif num_gpus == 1:
        model = model.cuda()
    else:
        raise NotImplementedError('Not support cpu training now.')

    # build optimizer
    optimizer = build_optimizer(cfg, model)

    # build lr scheduler
    scheduler = build_scheduler(cfg, optimizer)

    # build checkpointer
    # Note that checkpointer will load state_dict of model, optimizer and scheduler.
    checkpointer = CheckpointerV2(model,
                                  optimizer=optimizer,
                                  scheduler=scheduler,
                                  save_dir=output_dir,
                                  logger=logger,
                                  max_to_keep=cfg.TRAIN.MAX_TO_KEEP)
    checkpoint_data = checkpointer.load(cfg.RESUME_PATH, resume=cfg.AUTO_RESUME, resume_states=cfg.RESUME_STATES)
    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD

    # build freezer
    if cfg.TRAIN.FROZEN_PATTERNS:
        freezer = Freezer(model, cfg.TRAIN.FROZEN_PATTERNS)
        freezer.freeze(verbose=True)  # sanity check
    else:
        freezer = None

    # build data loader
    # Reset the random seed again in case the initialization of models changes the random state.
    set_random_seed(cfg.RNG_SEED)
    train_dataloader = build_dataloader(cfg, mode='train')
    val_period = cfg.VAL.PERIOD
    val_dataloader = build_dataloader(cfg, mode='val') if val_period > 0 else None

    # build tensorboard logger (optionally by comment)
    if output_dir:
        tb_dir = osp.join(output_dir, 'tb.{:s}'.format(run_name))
        summary_writer = SummaryWriter(tb_dir)
    else:
        summary_writer = None

    def add_image_summaries(prefix, batch, predictions, step):
        """Write input / predicted / ground-truth images of the first sample of ``batch``."""
        with torch.no_grad():
            image = batch['image'][0].cpu().numpy()  # (3, h, w)
            image = np.transpose(image, [1, 2, 0])
            # undo the input normalization (std, then mean) for display
            image = image * (0.229, 0.224, 0.225) + (0.485, 0.456, 0.406)
            seg_label = batch['seg_label'][0].cpu().numpy()  # (h, w)
            # negative (ignored) labels are drawn with the last palette color
            seg_label[seg_label < 0] = len(color_palette) - 1
            seg_logit = predictions['seg_logit'][0].cpu().numpy()  # (c, h, w)
            pred_label = seg_logit.argmax(0)
            pred_image = color_palette[pred_label, :]  # (h, w, 3)
            gt_image = color_palette[seg_label, :]  # (h, w, 3)
            summary_writer.add_image(prefix + 'input', image, global_step=step, dataformats='HWC')
            summary_writer.add_image(prefix + 'pred', pred_image, global_step=step, dataformats='HWC')
            summary_writer.add_image(prefix + 'gt', gt_image, global_step=step, dataformats='HWC')

    # ---------------------------------------------------------------------------- #
    # Train
    # Customization begins here.
    # ---------------------------------------------------------------------------- #
    max_iteration = cfg.SCHEDULER.MAX_ITERATION
    start_iteration = checkpoint_data.get('iteration', 0)
    best_metric_name = 'best_{}'.format(cfg.VAL.METRIC)
    best_metric = checkpoint_data.get(best_metric_name, None)
    logger.info('Start training from iteration {}'.format(start_iteration))

    # add metrics
    if not isinstance(train_metric, (list, tuple)):
        train_metric = [train_metric]
    if not isinstance(val_metric, (list, tuple)):
        val_metric = [val_metric]
    train_metric_logger = MetricLogger(delimiter='  ')
    train_metric_logger.add_meters(train_metric)
    val_metric_logger = MetricLogger(delimiter='  ')
    val_metric_logger.add_meters(val_metric)

    # wrap the dataloader so it yields batches from start_iteration up to max_iteration
    batch_sampler = train_dataloader.batch_sampler
    train_dataloader.batch_sampler = IterationBasedBatchSampler(batch_sampler, max_iteration, start_iteration)

    def setup_train():
        """Set train mode, re-apply frozen patterns and reset training meters."""
        # set training mode
        model.train()
        loss_fn.train()
        # freeze parameters/modules optionally
        if freezer is not None:
            freezer.freeze()
        # reset metric
        train_metric_logger.reset()

    def setup_validate():
        """Set eval mode and reset validation meters."""
        # set evaluate mode
        model.eval()
        loss_fn.eval()
        # reset metric
        val_metric_logger.reset()

    setup_train()
    end = time.time()
    for iteration, data_batch in enumerate(train_dataloader, start_iteration):
        data_time = time.time() - end
        # copy data from cpu to gpu
        data_batch = {k: v.cuda(non_blocking=True) for k, v in data_batch.items()}
        # forward
        preds = model(data_batch)
        # update losses
        optimizer.zero_grad()
        loss_dict = loss_fn(preds, data_batch)
        total_loss = sum(loss_dict.values())

        # It is slightly faster to update metrics and meters before backward
        with torch.no_grad():
            train_metric_logger.update(loss=total_loss, **loss_dict)
            for metric in train_metric:
                metric.update_dict(preds, data_batch)

        # backward
        total_loss.backward()
        if cfg.OPTIMIZER.MAX_GRAD_NORM > 0:
            # CAUTION: built-in clip_grad_norm_ clips the total norm.
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg.OPTIMIZER.MAX_GRAD_NORM)
        optimizer.step()

        batch_time = time.time() - end
        train_metric_logger.update(time=batch_time, data=data_time)
        cur_iter = iteration + 1

        # log (always on the first iteration; LOG_PERIOD <= 0 disables periodic logging)
        log_period = cfg.TRAIN.LOG_PERIOD
        if cur_iter == 1 or (log_period > 0 and cur_iter % log_period == 0):
            logger.info(
                train_metric_logger.delimiter.join(
                    [
                        'iter: {iter:4d}',
                        '{meters}',
                        'lr: {lr:.2e}',
                        'max mem: {memory:.0f}',
                    ]
                ).format(
                    iter=cur_iter,
                    meters=str(train_metric_logger),
                    lr=optimizer.param_groups[0]['lr'],
                    memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2),
                )
            )

        # summary
        if summary_writer is not None and cfg.TRAIN.SUMMARY_PERIOD > 0 and cur_iter % cfg.TRAIN.SUMMARY_PERIOD == 0:
            keywords = ('loss', 'acc', 'iou')
            for name, meter in train_metric_logger.meters.items():
                if all(k not in name for k in keywords):
                    continue
                summary_writer.add_scalar('train/' + name, meter.global_avg, global_step=cur_iter)
            add_image_summaries('train/', data_batch, preds, cur_iter)

        # ---------------------------------------------------------------------------- #
        # validate for one epoch
        # ---------------------------------------------------------------------------- #
        if val_period > 0 and (cur_iter % val_period == 0 or cur_iter == max_iteration):
            start_time_val = time.time()
            setup_validate()

            end = time.time()
            with torch.no_grad():
                for iteration_val, data_batch in enumerate(val_dataloader):
                    data_time = time.time() - end
                    # copy data from cpu to gpu
                    data_batch = {k: v.cuda(non_blocking=True) for k, v in data_batch.items()}
                    # forward
                    preds = model(data_batch)
                    # update losses and metrics
                    loss_dict = loss_fn(preds, data_batch)
                    total_loss = sum(loss_dict.values())
                    # update metrics and meters
                    val_metric_logger.update(loss=total_loss, **loss_dict)
                    for metric in val_metric:
                        metric.update_dict(preds, data_batch)

                    batch_time = time.time() - end
                    val_metric_logger.update(time=batch_time, data=data_time)
                    end = time.time()

                    if cfg.VAL.LOG_PERIOD > 0 and iteration_val % cfg.VAL.LOG_PERIOD == 0:
                        # report progress through the validation set (batch index)
                        logger.info(
                            val_metric_logger.delimiter.join(
                                [
                                    'iter: {iter:4d}',
                                    '{meters}',
                                    'max mem: {memory:.0f}',
                                ]
                            ).format(
                                iter=iteration_val,
                                meters=str(val_metric_logger),
                                memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2),
                            )
                        )

            epoch_time_val = time.time() - start_time_val
            logger.info('Iteration[{}]-Val {}  total_time: {:.2f}s'.format(
                cur_iter, val_metric_logger.summary_str, epoch_time_val))

            # summary (images come from the last validation batch)
            if summary_writer is not None:
                keywords = ('loss', 'acc', 'iou')
                for name, meter in val_metric_logger.meters.items():
                    if all(k not in name for k in keywords):
                        continue
                    summary_writer.add_scalar('val/' + name, meter.global_avg, global_step=cur_iter)
                add_image_summaries('val/', data_batch, preds, cur_iter)

            # best validation: metrics containing 'loss' are minimized, others maximized
            if cfg.VAL.METRIC in val_metric_logger.meters:
                cur_metric = val_metric_logger.meters[cfg.VAL.METRIC].global_avg
                if best_metric is None \
                        or ('loss' not in cfg.VAL.METRIC and cur_metric > best_metric) \
                        or ('loss' in cfg.VAL.METRIC and cur_metric < best_metric):
                    best_metric = cur_metric
                    checkpoint_data['iteration'] = cur_iter
                    checkpoint_data[best_metric_name] = best_metric
                    checkpointer.save('model_best', tag=False, **checkpoint_data)

            # restore training
            setup_train()

        # checkpoint: written after validation so the stored best metric includes
        # this iteration's result; otherwise a resume could overwrite model_best
        # with a worse model.
        if (ckpt_period > 0 and cur_iter % ckpt_period == 0) or cur_iter == max_iteration:
            checkpoint_data['iteration'] = cur_iter
            checkpoint_data[best_metric_name] = best_metric
            checkpointer.save('model_{:06d}'.format(cur_iter), **checkpoint_data)

        # since pytorch v1.1.0, lr_scheduler is called after optimization.
        if scheduler is not None:
            scheduler.step()
        end = time.time()

    logger.info('Best val-{} = {}'.format(cfg.VAL.METRIC, best_metric))
    return model
Exemplo n.º 4
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate the 2D network on whole ScanNet scans through 2D-3D chunks.

    The 2D segmentation network is wrapped in ``MVPNet2D`` and run on every
    chunk that ``ScanNet2D3DChunksTest`` yields for a scan. Chunk logits are
    summed onto the scan's points, divided by the number of chunks that
    covered each point, and the per-point argmax is the prediction. Points
    covered by no chunk get the label ``num_classes`` and a warning is logged.

    Args:
        cfg: Experiment config. ``RNG_SEED`` and ``DATASET.ScanNet2D``
            (``resize``, ``normalizer``) are read here.
        args: Parsed command-line arguments: checkpoint path, dataset and
            chunking options, and ``save``.
        output_dir: Checkpoint directory. ``'@'`` in ``args.ckpt_path`` is
            replaced by it, the evaluation table is written into it, and with
            ``args.save`` predictions go to ``<output_dir>/submit/<run_name>``.
        run_name: Tag used in the output file names.
    """
    import os  # function-scope import, like the file's other local imports

    logger = logging.getLogger('mvpnet.test')

    # build 2d model
    net_2d = build_model_sem_seg_2d(cfg)[0]

    # build checkpointer
    checkpointer = CheckpointerV2(net_2d, save_dir=output_dir, logger=logger)
    if args.ckpt_path:
        # load weight if specified ('@' stands for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # wrapper for 2d model
    model = MVPNet2D(net_2d)
    model = model.cuda()

    # build dataset
    test_dataset = ScanNet2D3DChunksTest(
        cache_dir=args.cache_dir,
        image_dir=args.image_dir,
        split=args.split,
        chunk_size=(args.chunk_size, args.chunk_size),
        chunk_stride=args.chunk_stride,
        chunk_thresh=args.chunk_thresh,
        num_rgbd_frames=args.num_views,
        resize=cfg.DATASET.ScanNet2D.resize,
        image_normalizer=cfg.DATASET.ScanNet2D.normalizer,
        k=args.k,
        to_tensor=True,
    )
    # each item is the list of chunk dicts of one scan; unwrap the
    # singleton batch instead of collating
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=lambda x: x[0])

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt below does not create missing directories
        os.makedirs(submit_dir, exist_ok=True)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        start_time_scan = time.time()
        for scan_idx, data_dict_list in enumerate(test_dataloader):
            # fetch data
            scan_id = test_dataset.scan_ids[scan_idx]
            scan_data = test_dataset.data[scan_idx]
            points = scan_data['points'].astype(np.float32)
            # labels may be absent (the evaluation below already allows for
            # that); remap them to the ScanNet label set only when present
            seg_label = scan_data.get('seg_label', None)
            if seg_label is not None:
                seg_label = test_dataset.nyu40_to_scannet[seg_label]
            data_time = time.time() - start_time_scan

            # prepare outputs
            num_points = len(points)
            pred_logit_whole_scene = np.zeros([num_points, num_classes],
                                              dtype=np.float32)
            # int32 rather than uint8: with a small chunk stride a point can be
            # covered by more than 255 chunks, which would wrap the counter
            num_pred_per_point = np.zeros(num_points, dtype=np.int32)

            # iterate over chunks
            tic = time.time()
            for data_dict in data_dict_list:
                chunk_ind = data_dict.pop('chunk_ind')
                # add a batch dimension of 1 and move to GPU
                data_batch = {
                    k: torch.tensor([v]).cuda(non_blocking=True)
                    for k, v in data_dict.items()
                }
                # forward
                preds = model(data_batch)
                # transpose to one row per point, keep one row per chunk index
                seg_logit = preds['seg_logit'].squeeze(0).cpu().numpy().T
                seg_logit = seg_logit[:len(chunk_ind)]
                # update
                pred_logit_whole_scene[chunk_ind] += seg_logit
                num_pred_per_point[chunk_ind] += 1
            forward_time = time.time() - tic

            # average over covering chunks (max(., 1) avoids division by zero)
            pred_logit_whole_scene = pred_logit_whole_scene / np.maximum(
                num_pred_per_point[:, np.newaxis], 1)
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)

            no_pred_mask = num_pred_per_point == 0
            no_pred_indices = np.nonzero(no_pred_mask)[0]
            if no_pred_indices.size > 0:
                logger.warning(
                    '{:s}: There are {:d} points without prediction.'.format(
                        scan_id, no_pred_mask.sum()))
                # out-of-range label marks uncovered points
                pred_label_whole_scene[no_pred_indices] = num_classes

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene)
                if seg_label is not None:
                    visualize_labels(points, seg_label)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(
                data=data_time,
                forward_time=forward_time,
                metric_time=metric_time,
            )

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_nyu40[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
            start_time_scan = time.time()

        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
Exemplo n.º 5
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate a 2D semantic segmentation model on ScanNet 2D images.

    Loads the checkpoint (``args.ckpt_path`` with ``'@'`` replaced by
    ``output_dir``, otherwise the last checkpoint), runs the model over the
    ``args.split`` images, and feeds per-pixel argmax predictions to an
    ``Evaluator`` whenever a batch carries ``seg_label``.

    Args:
        cfg: Experiment config. ``DATASET``, ``VAL.BATCH_SIZE``,
            ``DATALOADER.NUM_WORKERS`` and ``RNG_SEED`` are read here.
        args: Parsed command-line arguments. ``batch_size`` / ``num_workers``
            override the config values when truthy; a falsy ``log_period``
            disables progress logging.
        output_dir: Checkpoint directory; the evaluation table is written
            into it.
        run_name: Tag used in the evaluation table's file name.
    """
    logger = logging.getLogger('mvpnet.test')

    # build model
    model = build_model_sem_seg_2d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)

    if args.ckpt_path:
        # load weight if specified ('@' stands for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset
    test_dataset = ScanNet2D(
        cfg.DATASET.ROOT_DIR,
        split=args.split,
        subsample=None,
        to_tensor=True,
        resize=cfg.DATASET.ScanNet2D.resize,
        normalizer=cfg.DATASET.ScanNet2D.normalizer,
    )
    batch_size = args.batch_size or cfg.VAL.BATCH_SIZE
    num_workers = args.num_workers or cfg.DATALOADER.NUM_WORKERS
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=num_workers,
                                 drop_last=False)

    # evaluator
    evaluator = Evaluator(test_dataset.class_names)
    if args.save:
        # this evaluation writes no prediction files; say so instead of
        # silently ignoring the request
        logger.warning('Saving predictions is not supported for 2D image '
                       'evaluation; ignoring args.save.')

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        end = time.time()
        for iteration, data_batch in enumerate(test_dataloader):
            data_time = time.time() - end
            # keep the CPU labels before the batch is moved to the GPU
            gt_label = data_batch.get('seg_label', None)
            data_batch = {
                k: v.cuda(non_blocking=True)
                for k, v in data_batch.items()
            }
            # forward
            preds = model(data_batch)
            pred_label = preds['seg_logit'].argmax(
                1).cpu().numpy()  # (b, h, w)
            # evaluate
            if gt_label is not None:
                evaluator.batch_update(pred_label, gt_label.cpu().numpy())
            batch_time = time.time() - end
            test_meters.update(time=batch_time, data=data_time)
            end = time.time()
            # logging
            if args.log_period and iteration % args.log_period == 0:
                logger.info(
                    test_meters.delimiter.join([
                        '{:d}/{:d}',
                        'acc: {acc:.2f}',
                        'IoU: {iou:.2f}',
                        '{meters}',
                    ]).format(
                        iteration,
                        len(test_dataloader),
                        acc=evaluator.overall_acc * 100.0,
                        iou=evaluator.overall_iou * 100.0,
                        meters=str(test_meters),
                    ))
        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
Exemplo n.º 6
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate a 3D segmentation model on ScanNet scans by multi-sample voting.

    For each scan, ``args.num_votes`` point subsets of size ``args.nb_pts``
    are drawn. Scans with fewer points are padded by repeating point 0. The
    subsets run through the model as one batch. Each vote's logits are
    propagated to every scan point via its nearest sampled neighbour, then
    averaged over votes; the argmax is the prediction.

    Args:
        cfg: Experiment config; ``DATASET.ROOT_DIR`` is read here.
        args: Parsed command-line arguments: checkpoint path, ``split``,
            ``seed``, ``num_votes``, ``nb_pts``, ``use_color``, ``no_rot`` and
            ``save``.
        output_dir: Checkpoint directory. ``'@'`` in ``args.ckpt_path`` is
            replaced by it, the evaluation table is written into it, and with
            ``args.save`` predictions go to ``<output_dir>/submit/<run_name>``.
        run_name: Tag used in the output file names.
    """
    import os  # function-scope import, like the file's other local imports

    logger = logging.getLogger('mvpnet.test')

    # build model
    model = build_model_sem_seg_3d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)

    if args.ckpt_path:
        # load weight if specified ('@' stands for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset
    test_dataset = ScanNet3D(cfg.DATASET.ROOT_DIR, split=args.split)
    test_dataset.set_mapping('scannet')

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt below does not create missing directories
        os.makedirs(submit_dir, exist_ok=True)

    # others
    # NOTE(review): aug_list is built from args.no_rot but is not part of
    # `transform`, so no rotation is applied at test time -- confirm intent.
    aug_list = []
    if not args.no_rot:
        aug_list.append(T.RandomRotateZ())
    transform = T.Compose([T.ToTensor(), T.Transpose()])
    use_color = args.use_color or (model.in_channels == 3)
    num_votes = args.num_votes

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(args.seed)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        for scan_idx in range(len(test_dataset)):
            start_time_scan = time.time()
            # fetch data
            tic = time.time()
            data_dict = test_dataset[scan_idx]
            scan_id = data_dict['scan_id']
            points = data_dict['points']  # (n, 3)
            colors = data_dict['colors']  # (n, 3)
            seg_label = data_dict.get('seg_label', None)  # (n,)
            data_time = time.time() - tic

            # prepare inputs
            tic = time.time()
            points_list = []
            colors_list = []
            for _ in range(num_votes):
                if len(points) >= args.nb_pts:
                    ind = np.random.choice(len(points),
                                           size=args.nb_pts,
                                           replace=False)
                else:
                    # pad by repeating point 0. The padding must be an integer
                    # array: float zeros would turn `ind` into a float array,
                    # which numpy refuses as an index.
                    ind = np.hstack([
                        np.arange(len(points)),
                        np.zeros(args.nb_pts - len(points), dtype=np.int64),
                    ])
                points_list.append(points[ind])
                colors_list.append(colors[ind])
            data_list = []
            for vote_ind in range(num_votes):
                data_single = {
                    'points': points_list[vote_ind],
                }
                if use_color:
                    data_single['feature'] = colors_list[vote_ind]
                data_list.append(transform(**data_single))
            # stack the votes into one batch (keys taken from the last vote)
            data_batch = {
                k: torch.stack([x[k] for x in data_list])
                for k in data_single
            }
            data_batch = {
                k: v.cuda(non_blocking=True)
                for k, v in data_batch.items()
            }
            preprocess_time = time.time() - tic

            # forward
            tic = time.time()
            preds = model(data_batch)
            seg_logit_batch = preds['seg_logit'].cpu().numpy()
            forward_time = time.time() - tic

            # propagate predictions and ensemble
            tic = time.time()
            pred_logit_whole_scene = np.zeros([len(points), num_classes],
                                              dtype=np.float32)
            for vote_ind in range(num_votes):
                points_per_vote = points_list[vote_ind]
                seg_logit_per_vote = seg_logit_batch[vote_ind].T
                # Propagate to nearest neighbours
                nbrs = NearestNeighbors(
                    n_neighbors=1, algorithm='ball_tree').fit(points_per_vote)
                _, nn_indices = nbrs.kneighbors(points[:, 0:3])
                pred_logit_whole_scene += seg_logit_per_vote[nn_indices[:, 0]]
            # if we use softmax, it is necessary to normalize logits
            pred_logit_whole_scene = pred_logit_whole_scene / num_votes
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)
            postprocess_time = time.time() - tic

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene, colors=colors)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(data=data_time,
                               preprocess_time=preprocess_time,
                               forward_time=forward_time,
                               postprocess_time=postprocess_time,
                               metric_time=metric_time)

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_raw[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))
Exemplo n.º 7
0
def test(cfg, args, output_dir='', run_name=''):
    """Evaluate a 3D segmentation model on ScanNet scans chunk by chunk.

    Each scan is split into overlapping chunks by ``scene2chunks_legacy``.
    Every chunk is padded to ``args.min_nb_pts`` by the transform and run
    through the model. Chunk logits are summed onto the scan's points, divided
    by the number of chunks covering each point, and the argmax is the
    prediction. Points covered by no chunk get the label ``num_classes`` and a
    warning is logged.

    Args:
        cfg: Experiment config. ``DATASET.ROOT_DIR`` and ``RNG_SEED`` are read
            here.
        args: Parsed command-line arguments: checkpoint path, ``split``,
            chunking options, ``min_nb_pts``, ``use_color`` and ``save``.
        output_dir: Checkpoint directory. ``'@'`` in ``args.ckpt_path`` is
            replaced by it, the evaluation table is written into it, and with
            ``args.save`` predictions go to ``<output_dir>/submit/<run_name>``.
        run_name: Tag used in the output file names.
    """
    import os  # function-scope import, like the file's other local imports

    logger = logging.getLogger('mvpnet.test')

    # build model
    model = build_model_sem_seg_3d(cfg)[0]
    model = model.cuda()

    # build checkpointer
    checkpointer = CheckpointerV2(model, save_dir=output_dir, logger=logger)

    if args.ckpt_path:
        # load weight if specified ('@' stands for output_dir)
        weight_path = args.ckpt_path.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build dataset
    test_dataset = ScanNet3D(cfg.DATASET.ROOT_DIR, split=args.split)
    test_dataset.set_mapping('scannet')

    # evaluator
    class_names = test_dataset.class_names
    evaluator = Evaluator(class_names)
    num_classes = len(class_names)
    submit_dir = None
    if args.save:
        submit_dir = osp.join(output_dir, 'submit', run_name)
        # np.savetxt below does not create missing directories
        os.makedirs(submit_dir, exist_ok=True)

    # others
    transform = T.Compose(
        [T.ToTensor(), T.Pad(args.min_nb_pts),
         T.Transpose()])
    use_color = args.use_color or (model.in_channels == 3)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    set_random_seed(cfg.RNG_SEED)
    test_meters = MetricLogger(delimiter='  ')

    with torch.no_grad():
        start_time = time.time()
        for scan_idx in range(len(test_dataset)):
            start_time_scan = time.time()
            # fetch data
            tic = time.time()
            data_dict = test_dataset[scan_idx]
            scan_id = data_dict['scan_id']
            points = data_dict['points']  # (n, 3)
            colors = data_dict['colors']  # (n, 3)
            seg_label = data_dict.get('seg_label', None)  # (n,)
            data_time = time.time() - tic

            # generate chunks
            tic = time.time()
            chunk_indices = scene2chunks_legacy(points,
                                                chunk_size=(args.chunk_size,
                                                            args.chunk_size),
                                                stride=args.chunk_stride,
                                                thresh=args.chunk_thresh)
            preprocess_time = time.time() - tic

            # prepare outputs
            num_points = len(points)
            pred_logit_whole_scene = np.zeros([num_points, num_classes],
                                              dtype=np.float32)
            # int32 rather than uint8: with a small chunk stride a point can be
            # covered by more than 255 chunks, which would wrap the counter
            num_pred_per_point = np.zeros(num_points, dtype=np.int32)

            # iterate over chunks
            tic = time.time()
            for indices in chunk_indices:
                chunk_points = points[indices]
                chunk_feature = colors[indices]
                chunk_num_points = len(chunk_points)
                if chunk_num_points < args.min_nb_pts:
                    # report through the logger like the other warnings
                    logger.warning('Too few points({}) in a chunk of {}'.format(
                        chunk_num_points, scan_id))
                # prepare inputs
                data_batch = {'points': chunk_points}
                if use_color:
                    data_batch['feature'] = chunk_feature
                # preprocess, then add a batch dimension of 1 and move to GPU
                data_batch = transform(**data_batch)
                data_batch = {
                    k: torch.stack([v]).cuda(non_blocking=True)
                    for k, v in data_batch.items()
                }
                # forward
                preds = model(data_batch)
                # transpose to one row per point, drop the padded rows
                seg_logit = preds['seg_logit'].squeeze(0).cpu().numpy().T
                seg_logit = seg_logit[:chunk_num_points]
                # update
                pred_logit_whole_scene[indices] += seg_logit
                num_pred_per_point[indices] += 1
            forward_time = time.time() - tic

            # average over covering chunks (max(., 1) avoids division by zero)
            pred_logit_whole_scene = pred_logit_whole_scene / np.maximum(
                num_pred_per_point[:, np.newaxis], 1)
            pred_label_whole_scene = np.argmax(pred_logit_whole_scene, axis=1)

            no_pred_mask = num_pred_per_point == 0
            no_pred_indices = np.nonzero(no_pred_mask)[0]
            if no_pred_indices.size > 0:
                logger.warning(
                    '{:s}: There are {:d} points without prediction.'.format(
                        scan_id, no_pred_mask.sum()))
                # out-of-range label marks uncovered points
                pred_label_whole_scene[no_pred_indices] = num_classes

            if _DEBUG:
                # DEBUG: visualize scene
                from mvpnet.utils.visualize import visualize_labels
                visualize_labels(points, pred_label_whole_scene, colors=colors)

            # evaluate
            tic = time.time()
            if seg_label is not None:
                evaluator.update(pred_label_whole_scene, seg_label)
            metric_time = time.time() - tic

            batch_time = time.time() - start_time_scan
            test_meters.update(time=batch_time)
            test_meters.update(data=data_time,
                               preprocess_time=preprocess_time,
                               forward_time=forward_time,
                               metric_time=metric_time)

            # save prediction
            if submit_dir:
                remapped_pred_labels = test_dataset.scannet_to_raw[
                    pred_label_whole_scene]
                np.savetxt(osp.join(submit_dir, scan_id + '.txt'),
                           remapped_pred_labels, '%d')

            logger.info(
                test_meters.delimiter.join([
                    '{:d}/{:d}({:s})',
                    'acc: {acc:.2f}',
                    'IoU: {iou:.2f}',
                    '{meters}',
                ]).format(
                    scan_idx,
                    len(test_dataset),
                    scan_id,
                    acc=evaluator.overall_acc * 100.0,
                    iou=evaluator.overall_iou * 100.0,
                    meters=str(test_meters),
                ))
        test_time = time.time() - start_time
        logger.info('Test {}  test time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_acc))
    logger.info('overall IOU={:.2f}'.format(100.0 * evaluator.overall_iou))
    logger.info('class-wise accuracy and IoU.\n{}'.format(
        evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.{}.tsv'.format(run_name)))