Beispiel #1
0
def main():
    """Entry point: train and evaluate a Superpoint Graph segmentation model.

    Parses command-line options, selects the dataset backend, builds (or
    resumes) the model/optimizer, runs the training loop with periodic
    testing/checkpointing, and finally performs a multi-sample evaluation
    whose predictions and scores are written under ``args.odir``.
    """
    parser = argparse.ArgumentParser(
        description=
        'Large-scale Point Cloud Semantic Segmentation with Superpoint Graphs')

    # Optimization arguments
    parser.add_argument('--wd', default=0, type=float, help='Weight decay')
    parser.add_argument('--lr',
                        default=1e-2,
                        type=float,
                        help='Initial learning rate')
    parser.add_argument(
        '--lr_decay',
        default=0.7,
        type=float,
        help='Multiplicative factor used on learning rate at `lr_steps`')
    parser.add_argument(
        '--lr_steps',
        default='[]',
        help='List of epochs where the learning rate is decreased by `lr_decay`'
    )
    parser.add_argument('--momentum', default=0.9, type=float, help='Momentum')
    parser.add_argument(
        '--epochs',
        default=10,
        type=int,
        help='Number of epochs to train. If <=0, only testing will be done.')
    parser.add_argument('--batch_size', default=2, type=int, help='Batch size')
    parser.add_argument('--optim', default='adam', help='Optimizer: sgd|adam')
    parser.add_argument(
        '--grad_clip',
        default=1,
        type=float,
        help='Element-wise clipping of gradient. If 0, does not clip')

    # Learning process arguments
    parser.add_argument('--cuda', default=1, type=int, help='Bool, use cuda')
    parser.add_argument(
        '--nworkers',
        default=0,
        type=int,
        help=
        'Num subprocesses to use for data loading. 0 means that the data will be loaded in the main process'
    )
    parser.add_argument('--test_nth_epoch',
                        default=1,
                        type=int,
                        help='Test each n-th epoch during training')
    parser.add_argument('--save_nth_epoch',
                        default=1,
                        type=int,
                        help='Save model each n-th epoch during training')
    parser.add_argument(
        '--test_multisamp_n',
        default=10,
        type=int,
        help='Average logits obtained over runs with different seeds')

    # Dataset
    parser.add_argument('--dataset',
                        default='sema3d',
                        help='Dataset name: sema3d|s3dis')
    parser.add_argument(
        '--cvfold',
        default=0,
        type=int,
        help='Fold left-out for testing in leave-one-out setting (S3DIS)')
    parser.add_argument('--odir',
                        default='results',
                        help='Directory to store results')
    parser.add_argument('--resume',
                        default='',
                        help='Loads a previously saved model.')
    parser.add_argument('--db_train_name', default='train')
    parser.add_argument('--db_test_name', default='val')
    parser.add_argument('--SEMA3D_PATH', default='datasets/semantic3d')
    parser.add_argument('--S3DIS_PATH', default='datasets/s3dis')
    parser.add_argument('--CUSTOM_SET_PATH', default='datasets/custom_set')

    # Model
    parser.add_argument(
        '--model_config',
        default='gru_10,f_8',
        help=
        'Defines the model as a sequence of layers, see graphnet.py for definitions of respective layers and acceptable arguments. In short: rectype_repeats_mv_layernorm_ingate_concat, with rectype the type of recurrent unit [gru/crf/lstm], repeats the number of message passing iterations, mv (default True) the use of matrix-vector (mv) instead vector-vector (vv) edge filters, layernorm (default True) the use of layernorms in the recurrent units, ingate (default True) the use of input gating, concat (default True) the use of state concatenation'
    )
    parser.add_argument('--seed',
                        default=1,
                        type=int,
                        help='Seed for random initialisation')
    parser.add_argument(
        '--edge_attribs',
        default=
        'delta_avg,delta_std,nlength/ld,surface/ld,volume/ld,size/ld,xyz/d',
        help=
        'Edge attribute definition, see spg_edge_features() in spg.py for definitions.'
    )

    # Point cloud processing
    parser.add_argument(
        '--pc_attribs',
        default='',
        help='Point attributes fed to PointNets, if empty then all possible.')
    parser.add_argument(
        '--pc_augm_scale',
        default=0,
        type=float,
        help=
        'Training augmentation: Uniformly random scaling in [1/scale, scale]')
    parser.add_argument(
        '--pc_augm_rot',
        default=1,
        type=int,
        help='Training augmentation: Bool, random rotation around z-axis')
    parser.add_argument(
        '--pc_augm_mirror_prob',
        default=0,
        type=float,
        help='Training augmentation: Probability of mirroring about x or y axes'
    )
    parser.add_argument(
        '--pc_augm_jitter',
        default=1,
        type=int,
        help='Training augmentation: Bool, Gaussian jittering of all attributes'
    )
    parser.add_argument(
        '--pc_xyznormalize',
        default=1,
        type=int,
        help='Bool, normalize xyz into unit ball, i.e. in [-0.5,0.5]')

    # Filter generating network
    parser.add_argument(
        '--fnet_widths',
        default='[32,128,64]',
        help=
        'List of width of hidden filter gen net layers (excluding the input and output ones, they are automatic)'
    )
    parser.add_argument(
        '--fnet_llbias',
        default=0,
        type=int,
        help='Bool, use bias in the last layer in filter gen net')
    parser.add_argument(
        '--fnet_orthoinit',
        default=1,
        type=int,
        help='Bool, use orthogonal weight initialization for filter gen net.')
    parser.add_argument(
        '--fnet_bnidx',
        default=2,
        type=int,
        help='Layer index to insert batchnorm to. -1=do not insert.')
    parser.add_argument(
        '--edge_mem_limit',
        default=30000,
        type=int,
        help=
        'Number of edges to process in parallel during computation, a low number can reduce memory peaks.'
    )

    # Superpoint graph
    parser.add_argument(
        '--spg_attribs01',
        default=1,
        type=int,
        help='Bool, normalize edge features to 0 mean 1 deviation')
    parser.add_argument('--spg_augm_nneigh',
                        default=100,
                        type=int,
                        help='Number of neighborhoods to sample in SPG')
    parser.add_argument('--spg_augm_order',
                        default=3,
                        type=int,
                        help='Order of neighborhoods to sample in SPG')
    parser.add_argument(
        '--spg_augm_hardcutoff',
        default=512,
        type=int,
        help=
        'Maximum number of superpoints larger than args.ptn_minpts to sample in SPG'
    )
    parser.add_argument(
        '--spg_superedge_cutoff',
        default=-1,
        type=float,
        help=
        'Artificially constrained maximum length of superedge, -1=do not constrain'
    )

    # Point net
    parser.add_argument(
        '--ptn_minpts',
        default=40,
        type=int,
        help=
        'Minimum number of points in a superpoint for computing its embedding.'
    )
    parser.add_argument('--ptn_npts',
                        default=128,
                        type=int,
                        help='Number of input points for PointNet.')
    parser.add_argument('--ptn_widths',
                        default='[[64,64,128,128,256], [256,64,32]]',
                        help='PointNet widths')
    parser.add_argument('--ptn_widths_stn',
                        default='[[64,64,128], [128,64]]',
                        help='PointNet\'s Transformer widths')
    parser.add_argument(
        '--ptn_nfeat_stn',
        default=11,
        type=int,
        help='PointNet\'s Transformer number of input features')
    parser.add_argument('--ptn_prelast_do', default=0, type=float)
    parser.add_argument(
        '--ptn_mem_monger',
        default=1,
        type=int,
        help=
        'Bool, save GPU memory by recomputing PointNets in back propagation.')

    args = parser.parse_args()
    args.start_epoch = 0
    # List-valued options arrive as strings; parse them safely (literal_eval
    # accepts only Python literals, unlike eval).
    args.lr_steps = ast.literal_eval(args.lr_steps)
    args.fnet_widths = ast.literal_eval(args.fnet_widths)
    args.ptn_widths = ast.literal_eval(args.ptn_widths)
    args.ptn_widths_stn = ast.literal_eval(args.ptn_widths_stn)

    print('Will save to ' + args.odir)
    if not os.path.exists(args.odir):
        os.makedirs(args.odir)
    # Record the exact command line for reproducibility; non-flag tokens are
    # quoted so the file can be replayed in a shell.
    with open(os.path.join(args.odir, 'cmdline.txt'), 'w') as f:
        f.write(" ".join([
            "'" + a + "'" if (len(a) == 0 or a[0] != '-') else a
            for a in sys.argv
        ]))

    set_seed(args.seed, args.cuda)
    logging.getLogger().setLevel(
        logging.INFO)  #set to logging.DEBUG to allow for more prints
    if (args.dataset == 'sema3d' and args.db_test_name.startswith('test')) or (
            args.dataset.startswith('s3dis_02') and args.cvfold == 2):
        # needed in pytorch 0.2 for super-large graphs with batchnorm in fnet  (https://github.com/pytorch/pytorch/pull/2919)
        torch.backends.cudnn.enabled = False

    # Decide on the dataset
    if args.dataset == 'sema3d':
        import sema3d_dataset
        dbinfo = sema3d_dataset.get_info(args)
        create_dataset = sema3d_dataset.get_datasets
    elif args.dataset == 's3dis':
        import s3dis_dataset
        dbinfo = s3dis_dataset.get_info(args)
        create_dataset = s3dis_dataset.get_datasets
    elif args.dataset == 'custom_dataset':
        import custom_dataset  #<- to write!
        dbinfo = custom_dataset.get_info(args)
        create_dataset = custom_dataset.get_datasets
    else:
        raise NotImplementedError('Unknown dataset ' + args.dataset)

    # Create model and optimizer
    if args.resume != '':
        # 'RESUME' is a shorthand for the default checkpoint in the output dir.
        if args.resume == 'RESUME': args.resume = args.odir + '/model.pth.tar'
        model, optimizer, stats = resume(args, dbinfo)
    else:
        model = create_model(args, dbinfo)
        optimizer = create_optimizer(args, model)
        stats = []

    train_dataset, test_dataset = create_dataset(args)
    ptnCloudEmbedder = pointnet.CloudEmbedder(args)
    # last_epoch=start_epoch-1 keeps the LR schedule aligned when resuming.
    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_steps,
                            gamma=args.lr_decay,
                            last_epoch=args.start_epoch - 1)

    ############
    def train():
        """ Trains for one epoch """
        model.train()

        loader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             collate_fn=spg.eccpc_collate,
                                             num_workers=args.nworkers,
                                             shuffle=True,
                                             drop_last=True)
        if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
            loader = tqdm(loader, ncols=100)

        loss_meter = tnt.meter.AverageValueMeter()
        acc_meter = tnt.meter.ClassErrorMeter(accuracy=True)
        confusion_matrix = metrics.ConfusionMatrix(dbinfo['classes'])
        t0 = time.time()

        # iterate over dataset in batches
        for bidx, (targets, GIs, clouds_data) in enumerate(loader):
            t_loader = 1000 * (time.time() - t0)

            model.ecc.set_info(GIs, args.cuda)
            # targets columns: [0] is used as the per-superpoint class label,
            # [2:] as a per-class label vector, and [1:].sum(1) as the
            # superpoint size -- presumably counts from the collate fn;
            # TODO(review): confirm against spg.eccpc_collate.
            label_mode_cpu, label_vec_cpu, segm_size_cpu = targets[:,
                                                                   0], targets[:,
                                                                               2:], targets[:, 1:].sum(
                                                                                   1
                                                                               )
            if args.cuda:
                label_mode, label_vec, segm_size = label_mode_cpu.cuda(
                ), label_vec_cpu.float().cuda(), segm_size_cpu.float().cuda()

            else:
                label_mode, label_vec, segm_size = label_mode_cpu, label_vec_cpu.float(
                ), segm_size_cpu.float()

            optimizer.zero_grad()
            t0 = time.time()

            embeddings = ptnCloudEmbedder.run(model, *clouds_data)
            outputs = model.ecc(embeddings)

            # Variable(...) is the pytorch<=0.3 autograd wrapper API.
            loss = nn.functional.cross_entropy(outputs, Variable(label_mode))
            loss.backward()
            ptnCloudEmbedder.bw_hook()

            if args.grad_clip > 0:
                for p in model.parameters():
                    p.grad.data.clamp_(-args.grad_clip, args.grad_clip)
            optimizer.step()

            t_trainer = 1000 * (time.time() - t0)
            loss_meter.add(loss.data[0])  # .data[0]: scalar access on pytorch<=0.3 (use .item() on >=0.4)

            o_cpu, t_cpu, tvec_cpu = filter_valid(outputs.data.cpu().numpy(),
                                                  label_mode_cpu.numpy(),
                                                  label_vec_cpu.numpy())
            acc_meter.add(o_cpu, t_cpu)
            confusion_matrix.count_predicted_batch(tvec_cpu,
                                                   np.argmax(o_cpu, 1))

            logging.debug(
                'Batch loss %f, Loader time %f ms, Trainer time %f ms.',
                loss.data[0], t_loader, t_trainer)
            t0 = time.time()

        return acc_meter.value()[0], loss_meter.value(
        )[0], confusion_matrix.get_overall_accuracy(
        ), confusion_matrix.get_average_intersection_union()

    ############
    def eval():
        """ Evaluated model on test set """
        model.eval()

        loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=1,
                                             collate_fn=spg.eccpc_collate,
                                             num_workers=args.nworkers)
        if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
            loader = tqdm(loader, ncols=100)

        acc_meter = tnt.meter.ClassErrorMeter(accuracy=True)
        confusion_matrix = metrics.ConfusionMatrix(dbinfo['classes'])

        # iterate over dataset in batches
        for bidx, (targets, GIs, clouds_data) in enumerate(loader):
            model.ecc.set_info(GIs, args.cuda)
            # Same target decomposition as in train(), but size is float here.
            label_mode_cpu, label_vec_cpu, segm_size_cpu = targets[:,
                                                                   0], targets[:,
                                                                               2:], targets[:, 1:].sum(
                                                                                   1
                                                                               ).float(
                                                                               )

            embeddings = ptnCloudEmbedder.run(model, *clouds_data)
            outputs = model.ecc(embeddings)

            o_cpu, t_cpu, tvec_cpu = filter_valid(outputs.data.cpu().numpy(),
                                                  label_mode_cpu.numpy(),
                                                  label_vec_cpu.numpy())
            # Skip batches with no valid (labelled) superpoints.
            if t_cpu.size > 0:
                acc_meter.add(o_cpu, t_cpu)
                confusion_matrix.count_predicted_batch(tvec_cpu,
                                                       np.argmax(o_cpu, 1))

        return meter_value(acc_meter), confusion_matrix.get_overall_accuracy(
        ), confusion_matrix.get_average_intersection_union(
        ), confusion_matrix.get_mean_class_accuracy()

    ############
    def eval_final():
        """ Evaluated model on test set in an extended way: computes estimates over multiple samples of point clouds and stores predictions """
        model.eval()

        acc_meter = tnt.meter.ClassErrorMeter(accuracy=True)
        confusion_matrix = metrics.ConfusionMatrix(dbinfo['classes'])
        collected, predictions = defaultdict(list), {}

        # collect predictions over multiple sampling seeds
        for ss in range(args.test_multisamp_n):
            # index [1] selects the test split; ss seeds the cloud sampling.
            test_dataset_ss = create_dataset(args, ss)[1]
            loader = torch.utils.data.DataLoader(test_dataset_ss,
                                                 batch_size=1,
                                                 collate_fn=spg.eccpc_collate,
                                                 num_workers=args.nworkers)
            if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
                loader = tqdm(loader, ncols=100)

            # iterate over dataset in batches
            for bidx, (targets, GIs, clouds_data) in enumerate(loader):
                model.ecc.set_info(GIs, args.cuda)
                label_mode_cpu, label_vec_cpu, segm_size_cpu = targets[:,
                                                                       0], targets[:, 2:], targets[:, 1:].sum(
                                                                           1
                                                                       ).float(
                                                                       )

                embeddings = ptnCloudEmbedder.run(model, *clouds_data)
                outputs = model.ecc(embeddings)

                # Key results per cloud by its file name without extension.
                fname = clouds_data[0][0][:clouds_data[0][0].rfind('.')]
                collected[fname].append(
                    (outputs.data.cpu().numpy(), label_mode_cpu.numpy(),
                     label_vec_cpu.numpy()))

        # aggregate predictions (mean)
        for fname, lst in collected.items():
            o_cpu, t_cpu, tvec_cpu = list(zip(*lst))
            if args.test_multisamp_n > 1:
                o_cpu = np.mean(np.stack(o_cpu, 0), 0)
            else:
                o_cpu = o_cpu[0]
            # Labels are identical across samples; keep the first copy.
            t_cpu, tvec_cpu = t_cpu[0], tvec_cpu[0]
            predictions[fname] = np.argmax(o_cpu, 1)
            o_cpu, t_cpu, tvec_cpu = filter_valid(o_cpu, t_cpu, tvec_cpu)
            if t_cpu.size > 0:
                acc_meter.add(o_cpu, t_cpu)
                confusion_matrix.count_predicted_batch(tvec_cpu,
                                                       np.argmax(o_cpu, 1))

        per_class_iou = {}
        perclsiou = confusion_matrix.get_intersection_union_per_class()
        for c, name in dbinfo['inv_class_map'].items():
            per_class_iou[name] = perclsiou[c]

        return meter_value(acc_meter), confusion_matrix.get_overall_accuracy(
        ), confusion_matrix.get_average_intersection_union(
        ), per_class_iou, predictions, confusion_matrix.get_mean_class_accuracy(
        ), confusion_matrix.confusion_matrix

    ############
    # Training loop
    for epoch in range(args.start_epoch, args.epochs):
        print('Epoch {}/{} ({}):'.format(epoch, args.epochs, args.odir))
        # NOTE: stepping the scheduler before training is the pre-1.1 pytorch
        # LR-scheduler convention; newer versions step after optimizer.step().
        scheduler.step()

        acc, loss, oacc, avg_iou = train()

        if (epoch + 1) % args.test_nth_epoch == 0 or epoch + 1 == args.epochs:
            acc_test, oacc_test, avg_iou_test, avg_acc_test = eval()
            print(
                '-> Train accuracy: {}, \tLoss: {}, \tTest accuracy: {}, \tTest oAcc: {}, \tTest avgIoU: {}'
                .format(acc, loss, acc_test, oacc_test, avg_iou_test))
        else:
            acc_test, oacc_test, avg_iou_test, avg_acc_test = 0, 0, 0, 0
            print('-> Train accuracy: {}, \tLoss: {}'.format(acc, loss))

        stats.append({
            'epoch': epoch,
            'acc': acc,
            'loss': loss,
            'oacc': oacc,
            'avg_iou': avg_iou,
            'acc_test': acc_test,
            'oacc_test': oacc_test,
            'avg_iou_test': avg_iou_test,
            'avg_acc_test': avg_acc_test
        })

        if epoch % args.save_nth_epoch == 0 or epoch == args.epochs - 1:
            with open(os.path.join(args.odir, 'trainlog.txt'), 'w') as outfile:
                json.dump(stats, outfile)
            torch.save(
                {
                    'epoch': epoch + 1,
                    'args': args,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, os.path.join(args.odir, 'model.pth.tar'))

        # Abort training if the loss diverged.
        if math.isnan(loss): break

    if len(stats) > 0:
        with open(os.path.join(args.odir, 'trainlog.txt'), 'w') as outfile:
            json.dump(stats, outfile)

    # Final evaluation
    if args.test_multisamp_n > 0:
        acc_test, oacc_test, avg_iou_test, per_class_iou_test, predictions_test, avg_acc_test, confusion_matrix = eval_final(
        )
        print(
            '-> Multisample {}: Test accuracy: {}, \tTest oAcc: {}, \tTest avgIoU: {}, \tTest mAcc: {}'
            .format(args.test_multisamp_n, acc_test, oacc_test, avg_iou_test,
                    avg_acc_test))
        with h5py.File(
                os.path.join(args.odir,
                             'predictions_' + args.db_test_name + '.h5'),
                'w') as hf:
            for fname, o_cpu in predictions_test.items():
                hf.create_dataset(name=fname, data=o_cpu)  #(0-based classes)
        with open(
                os.path.join(args.odir,
                             'scores_' + args.db_test_name + '.txt'),
                'w') as outfile:
            json.dump([{
                'epoch': args.start_epoch,
                'acc_test': acc_test,
                'oacc_test': oacc_test,
                'avg_iou_test': avg_iou_test,
                'per_class_iou_test': per_class_iou_test,
                'avg_acc_test': avg_acc_test
            }], outfile)
        np.save(os.path.join(args.odir, 'pointwise_cm.npy'), confusion_matrix)
Beispiel #2
0
def main():
    args = parse_args()
    print('Will save to ' + args.odir)
    if not os.path.exists(args.odir):
        os.makedirs(args.odir)
    with open(os.path.join(args.odir, 'cmdline.txt'), 'w') as f:
        f.write(" ".join([
            "'" + a + "'" if (len(a) == 0 or a[0] != '-') else a
            for a in sys.argv
        ]))

    set_seed(args.seed, args.cuda)
    logging.getLogger().setLevel(
        logging.INFO)  # set to logging.DEBUG to allow for more prints
    if (args.dataset == 'sema3d' and args.db_test_name.startswith('test')) or (
            args.dataset.startswith('s3dis_02') and args.cvfold == 2):
        # needed in pytorch 0.2 for super-large graphs with batchnorm in fnet  (https://github.com/pytorch/pytorch/pull/2919)
        torch.backends.cudnn.enabled = False

    if args.use_pyg:
        torch.backends.cudnn.enabled = False

    # Decide on the dataset
    if args.dataset == 'sema3d':
        import sema3d_dataset
        dbinfo = sema3d_dataset.get_info(args)
        create_dataset = sema3d_dataset.get_datasets
    elif args.dataset == 's3dis':
        import s3dis_dataset
        dbinfo = s3dis_dataset.get_info(args)
        create_dataset = s3dis_dataset.get_datasets
    elif args.dataset == 'vkitti':
        import vkitti_dataset
        dbinfo = vkitti_dataset.get_info(args)
        create_dataset = vkitti_dataset.get_datasets
    elif args.dataset == 'custom_dataset':
        import custom_dataset  # <- to write!
        dbinfo = custom_dataset.get_info(args)
        create_dataset = custom_dataset.get_datasets
    else:
        raise NotImplementedError('Unknown dataset ' + args.dataset)

    # Create model and optimizer
    if args.resume != '':
        if args.resume == 'RESUME':
            args.resume = args.odir + '/model.pth.tar'
        model, optimizer, stats = resume(args, dbinfo)
    else:
        model = create_model(args, dbinfo)
        optimizer = create_optimizer(args, model)
        stats = []

    train_dataset, test_dataset, valid_dataset, scaler = create_dataset(args)

    print(
        f"Train dataset: {len(train_dataset)} elements - Test dataset: {len(test_dataset)} elements - "
        f"Validation dataset: {len(valid_dataset)} elements")
    ptnCloudEmbedder = pointnet.CloudEmbedder(args)
    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_steps,
                            gamma=args.lr_decay,
                            last_epoch=args.start_epoch - 1)

    def train():
        """Run one training epoch; return a dict of aggregate metrics."""
        model.train()

        # A fresh loader every epoch so shuffling differs between epochs.
        loader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             collate_fn=spg.eccpc_collate,
                                             num_workers=args.nworkers,
                                             shuffle=True,
                                             drop_last=True)
        if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
            loader = tqdm(loader, ncols=65)

        meter_loss = tnt.meter.AverageValueMeter()
        meter_acc = tnt.meter.ClassErrorMeter(accuracy=True)
        cm = metrics.ConfusionMatrix(dbinfo['classes'])
        tic = time.time()

        for targets, GIs, clouds_data in loader:
            ms_loader = 1000 * (time.time() - tic)

            model.ecc.set_info(GIs, args.cuda)
            # targets: col 0 = superpoint label, cols 2.. = per-class label
            # vector, cols 1.. summed = superpoint size.
            label_mode_cpu = targets[:, 0]
            label_vec_cpu = targets[:, 2:]
            segm_size_cpu = targets[:, 1:].sum(1)
            if args.cuda:
                label_mode = label_mode_cpu.cuda()
                label_vec = label_vec_cpu.float().cuda()
                segm_size = segm_size_cpu.float().cuda()
            else:
                label_mode = label_mode_cpu
                label_vec = label_vec_cpu.float()
                segm_size = segm_size_cpu.float()

            optimizer.zero_grad()
            tic = time.time()

            embeddings = ptnCloudEmbedder.run(model, *clouds_data)
            outputs = model.ecc(embeddings)

            # Class-weighted cross-entropy on the per-superpoint labels.
            loss = nn.functional.cross_entropy(outputs,
                                               Variable(label_mode),
                                               weight=dbinfo["class_weights"])

            loss.backward()
            ptnCloudEmbedder.bw_hook()

            if args.grad_clip > 0:
                # Element-wise gradient clipping.
                for param in model.parameters():
                    param.grad.data.clamp_(-args.grad_clip, args.grad_clip)
            optimizer.step()

            ms_trainer = 1000 * (time.time() - tic)
            meter_loss.add(loss.item())  # pytorch 0.4 scalar access

            o_cpu, t_cpu, tvec_cpu = filter_valid(outputs.data.cpu().numpy(),
                                                  label_mode_cpu.numpy(),
                                                  label_vec_cpu.numpy())
            meter_acc.add(o_cpu, t_cpu)
            cm.count_predicted_batch(tvec_cpu, np.argmax(o_cpu, 1))

            logging.debug(
                'Batch loss %f, Loader time %f ms, Trainer time %f ms.',
                loss.data.item(), ms_loader, ms_trainer)
            tic = time.time()

        return {
            'acc': meter_acc.value()[0],
            'loss': meter_loss.value()[0],
            'oacc': cm.get_overall_accuracy(),
            'avg_iou': cm.get_average_intersection_union()
        }

    def eval(is_valid=False):
        """Evaluate the model on the validation (is_valid) or test split."""
        model.eval()

        dataset = valid_dataset if is_valid else test_dataset
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             collate_fn=spg.eccpc_collate,
                                             num_workers=args.nworkers)

        if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
            loader = tqdm(loader, ncols=65)

        meter_acc = tnt.meter.ClassErrorMeter(accuracy=True)
        meter_loss = tnt.meter.AverageValueMeter()
        cm = metrics.ConfusionMatrix(dbinfo['classes'])

        for targets, GIs, clouds_data in loader:
            model.ecc.set_info(GIs, args.cuda)
            # Same target decomposition as in train().
            label_mode_cpu = targets[:, 0]
            label_vec_cpu = targets[:, 2:]
            segm_size_cpu = targets[:, 1:].sum(1).float()
            if args.cuda:
                label_mode = label_mode_cpu.cuda()
                label_vec = label_vec_cpu.float().cuda()
                segm_size = segm_size_cpu.float().cuda()
            else:
                label_mode = label_mode_cpu
                label_vec = label_vec_cpu.float()
                segm_size = segm_size_cpu.float()

            embeddings = ptnCloudEmbedder.run(model, *clouds_data)
            outputs = model.ecc(embeddings)

            loss = nn.functional.cross_entropy(outputs,
                                               Variable(label_mode),
                                               weight=dbinfo["class_weights"])
            meter_loss.add(loss.item())

            o_cpu, t_cpu, tvec_cpu = filter_valid(outputs.data.cpu().numpy(),
                                                  label_mode_cpu.numpy(),
                                                  label_vec_cpu.numpy())
            # Skip clouds that contain no valid (labelled) superpoints.
            if t_cpu.size > 0:
                meter_acc.add(o_cpu, t_cpu)
                cm.count_predicted_batch(tvec_cpu, np.argmax(o_cpu, 1))

        return {
            'acc': meter_value(meter_acc),
            'loss': meter_loss.value()[0],
            'oacc': cm.get_overall_accuracy(),
            'avg_iou': cm.get_average_intersection_union(),
            'avg_acc': cm.get_mean_class_accuracy()
        }

    def eval_final():
        """Extended test-set evaluation: average logits over several sampled
        versions of each cloud and keep the per-cloud predictions."""
        model.eval()

        meter_acc = tnt.meter.ClassErrorMeter(accuracy=True)
        cm = metrics.ConfusionMatrix(dbinfo['classes'])
        collected = defaultdict(list)
        predictions = {}

        # Collect raw outputs over multiple sampling seeds.
        for seed in range(args.test_multisamp_n):
            # Index [1] is the test split; seed varies the point sampling.
            sampled_test_set = create_dataset(args, seed)[1]
            loader = torch.utils.data.DataLoader(sampled_test_set,
                                                 batch_size=1,
                                                 collate_fn=spg.eccpc_collate,
                                                 num_workers=args.nworkers)
            if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
                loader = tqdm(loader, ncols=65)

            for targets, GIs, clouds_data in loader:
                model.ecc.set_info(GIs, args.cuda)
                label_mode_cpu = targets[:, 0]
                label_vec_cpu = targets[:, 2:]
                segm_size_cpu = targets[:, 1:].sum(1).float()

                embeddings = ptnCloudEmbedder.run(model, *clouds_data)
                outputs = model.ecc(embeddings)

                # Strip the file extension to get a stable per-cloud key.
                cloud_name = clouds_data[0][0][:clouds_data[0][0].rfind('.')]
                collected[cloud_name].append(
                    (outputs.data.cpu().numpy(), label_mode_cpu.numpy(),
                     label_vec_cpu.numpy()))

        # Aggregate the collected logits per cloud (mean over seeds).
        for cloud_name, runs in collected.items():
            o_cpu, t_cpu, tvec_cpu = list(zip(*runs))
            if args.test_multisamp_n > 1:
                o_cpu = np.mean(np.stack(o_cpu, 0), 0)
            else:
                o_cpu = o_cpu[0]
            # Labels are identical across seeds; keep the first copy.
            t_cpu, tvec_cpu = t_cpu[0], tvec_cpu[0]
            predictions[cloud_name] = np.argmax(o_cpu, 1)
            o_cpu, t_cpu, tvec_cpu = filter_valid(o_cpu, t_cpu, tvec_cpu)
            if t_cpu.size > 0:
                meter_acc.add(o_cpu, t_cpu)
                cm.count_predicted_batch(tvec_cpu, np.argmax(o_cpu, 1))

        perclsiou = cm.get_intersection_union_per_class()
        per_class_iou = {
            name: perclsiou[c]
            for c, name in dbinfo['inv_class_map'].items()
        }

        return {
            'acc': meter_value(meter_acc),
            'oacc': cm.get_overall_accuracy(),
            'avg_iou': cm.get_average_intersection_union(),
            'per_class_iou': per_class_iou,
            'predictions': predictions,
            'avg_acc': cm.get_mean_class_accuracy(),
            'confusion_matrix': cm.confusion_matrix
        }

    # ------------------------------------------------------------ Training loop
    # Resume the best validation IoU from a previous run's stats, if present.
    try:
        best_iou = stats[-1]['best_iou']
    except (IndexError, KeyError):  # empty stats, or a log without 'best_iou'
        best_iou = 0
    # ANSI color codes used to distinguish train / val / test console output.
    TRAIN_COLOR = '\033[0m'
    VAL_COLOR = '\033[0;94m'
    TEST_COLOR = '\033[0;93m'
    BEST_COLOR = '\033[0;92m'
    epoch = args.start_epoch  # keep `epoch` defined even if the loop runs 0 times

    def _save_checkpoint():
        """Persist model + optimizer state so training can resume at epoch+1."""
        torch.save(
            {
                'epoch': epoch + 1,
                'args': args,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, os.path.join(args.odir, 'model.pth.tar'))

    for epoch in range(args.start_epoch, args.epochs):
        print(f"Epoch {epoch}/{args.epochs} ({args.odir})")
        # NOTE(review): scheduler is stepped before the epoch's optimizer
        # steps (pre-PyTorch-1.1 convention) — confirm before reordering.
        scheduler.step()

        metrics_trn = train()

        print(
            TRAIN_COLOR +
            f"-> Train Loss: {metrics_trn['loss']:.4f}  Train accuracy: {metrics_trn['acc']:.2f}%"
        )

        new_best_model = False
        if args.use_val_set:
            metrics_eval = eval(True)
            print(
                VAL_COLOR +
                f"-> Val Loss: {metrics_eval['loss']:.4f}  Val accuracy: {metrics_eval['acc']:.2f}%  "
                f"Val oAcc: {100*metrics_eval['oacc']:.2f}%  Val IoU: {100*metrics_eval['avg_iou']:.2f}%  "
                f"best IoU: {100*max(best_iou,metrics_eval['avg_iou']):.2f}%" +
                TRAIN_COLOR)
            if metrics_eval[
                    'avg_iou'] > best_iou:  # best score yet on the validation set
                print(BEST_COLOR + '-> New best model achieved!' + TRAIN_COLOR)
                best_iou = metrics_eval['avg_iou']
                new_best_model = True
                _save_checkpoint()
        elif epoch % args.save_nth_epoch == 0 or epoch == args.epochs - 1:
            _save_checkpoint()

        # Test every test_nth_epoch epochs,
        # or after each new best model (but skip the first 5 for efficiency).
        if (not args.use_val_set and (epoch + 1) % args.test_nth_epoch
                == 0) or (args.use_val_set and new_best_model and epoch > 5):
            metrics_test = eval(False)
            print(
                TEST_COLOR +
                f"-> Test Loss: {metrics_test['loss']:.4f}  Test accuracy: {metrics_test['acc']:.2f}%  "
                f"Test oAcc: {100*metrics_test['oacc']:.2f}%  Test avgIoU: {100*metrics_test['avg_iou']:.2f}%"
                + TRAIN_COLOR)
        else:
            # Placeholder zeros keep the stats log schema uniform across epochs.
            metrics_test = {
                'acc': 0,
                'loss': 0,
                'oacc': 0,
                'avg_iou': 0,
                'avg_acc': 0
            }

        stats.append({
            'epoch': epoch,
            'acc_trn': metrics_trn['acc'],
            'loss_trn': metrics_trn['loss'],
            'oacc_trn': metrics_trn['oacc'],
            'avg_iou_trn': metrics_trn['avg_iou'],
            'acc_test': metrics_test['acc'],
            'oacc_test': metrics_test['oacc'],
            'avg_iou_test': metrics_test['avg_iou'],
            'avg_acc_test': metrics_test['avg_acc'],
            'best_iou': best_iou
        })

        if math.isnan(metrics_trn['loss']):
            break  # training diverged; stop early

    if len(stats) > 0:
        with open(os.path.join(args.odir, 'trainlog.json'), 'w') as outfile:
            json.dump(stats, outfile, indent=4)

    # Reload the best validation checkpoint (and re-save it) before final test.
    if args.use_val_set:
        args.resume = args.odir + '/model.pth.tar'
        model, optimizer, stats = resume(args, dbinfo)
        _save_checkpoint()

    # Final evaluation: multi-sample the test set, store predictions and scores.
    if args.test_multisamp_n > 0 and 'test' in args.db_test_name:
        metrics_test = eval_final()
        print(
            f"-> Multisample {args.test_multisamp_n}: Test accuracy: {metrics_test['acc']:.2f}, \tTest oAcc: {metrics_test['oacc']:.2f}, "
            f"\tTest avgIoU: {metrics_test['avg_iou']:.2f}, \tTest mAcc: {metrics_test['avg_acc']:.2f}"
        )
        with h5py.File(
                os.path.join(args.odir,
                             'predictions_' + args.db_test_name + '.h5'),
                'w') as hf:
            for fname, o_cpu in metrics_test['predictions'].items():
                hf.create_dataset(name=fname, data=o_cpu)  # (0-based classes)
        with open(
                os.path.join(args.odir,
                             'scores_' + args.db_test_name + '.json'),
                'w') as outfile:
            json.dump([{
                'epoch': args.start_epoch,
                'acc_test': metrics_test['acc'],
                'oacc_test': metrics_test['oacc'],
                'avg_iou_test': metrics_test['avg_iou'],
                'per_class_iou_test': metrics_test['per_class_iou'],
                'avg_acc_test': metrics_test['avg_acc']
            }], outfile)
        np.save(os.path.join(args.odir, 'pointwise_cm.npy'),
                metrics_test['confusion_matrix'])