Beispiel #1
0
    # Copy the loss layer
    # Replicate the criterion once per device so each GPU can compute its
    # shard of the loss inside parallel_apply below.
    criterion = nn.CrossEntropyLoss()
    criterions = parallel.replicate(criterion, devices)
    min_time = np.inf

    # Benchmark loop: repeat the replicated forward pass and keep the
    # fastest wall-clock time observed.
    for iteration in range(10):
        optimizer.zero_grad()

        # Get new data
        # Build one sparse input tensor (and its labels) per device.
        inputs, all_labels = [], []
        for i in range(num_devices):
            coordinates, features, labels = generate_input(config.file_name,
                                                           voxel_size=0.05)
            with torch.cuda.device(devices[i]):
                inputs.append(
                    ME.SparseTensor(features - 0.5,
                                    coords=coordinates).to(devices[i]))
            all_labels.append(labels.long().to(devices[i]))

        # The raw version of the parallel_apply
        st = time()
        replicas = parallel.replicate(net, devices)
        outputs = parallel.parallel_apply(replicas, inputs, devices=devices)

        # Extract features from the sparse tensors to use a pytorch criterion
        out_features = [output.F for output in outputs]
        losses = parallel.parallel_apply(criterions,
                                         tuple(zip(out_features, all_labels)),
                                         devices=devices)
        # Gather the per-device scalar losses on the target device, average,
        # and record the elapsed time for this iteration.
        loss = parallel.gather(losses, target_device, dim=0).mean()
        t = time() - st
        min_time = min(t, min_time)
Beispiel #2
0
def train(net, loaders, device, logger, config):
    """Train ``net`` for scene classification, validating periodically.

    When a pretrained checkpoint exists (and ``config['restore']`` is set),
    its weights are loaded and every top-level submodule whose name does
    not start with ``'clf'`` is frozen, so only the classification head
    is trained.  Metrics are written to TensorBoard and the model with
    the best validation accuracy is checkpointed.

    Args:
        net: model returning ``(_, clf_out)`` for a sparse input tensor.
        loaders: dict with an iterable ``'train'`` loader and a ``'val'``
            loader.
        device: torch device for inputs and labels.
        logger: logger for progress messages.
        config: dict of hyper-parameters and paths (keys used below).

    Raises:
        ValueError: if ``config['optimizer']`` is neither 'SGD' nor 'Adam'.
    """
    ######################
    # optimizer
    ######################
    if config['optimizer'] == 'SGD':
        optimizer = optim.SGD(net.parameters(),
                              lr=config['sgd_lr'],
                              momentum=config['sgd_momentum'],
                              dampening=config['sgd_dampening'],
                              weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'Adam':
        optimizer = optim.Adam(net.parameters(),
                               lr=config['adam_lr'],
                               betas=(config['adam_beta1'],
                                      config['adam_beta2']),
                               weight_decay=config['weight_decay'])
    else:
        # Fail fast instead of hitting a NameError on the first
        # optimizer.zero_grad() below.
        raise ValueError('Unsupported optimizer: %s' % config['optimizer'])

    ######################
    # loss
    ######################
    clf_criterion = torch.nn.CrossEntropyLoss()

    ######################
    # restore model
    ######################
    writer = SummaryWriter(log_dir=config['dump_dir'])
    restore_iter = 1
    curr_best_metric = 0
    path_checkpoint = config['pretrained_weights']
    path_new_model = os.path.join(config['dump_dir'], config['checkpoint'])
    if os.path.exists(path_checkpoint) and config['restore']:
        checkpoint = torch.load(path_checkpoint)  # checkpoint
        pretrained_dict = checkpoint['state_dict']

        # Update the current state dict with the pretrained layers so
        # missing keys keep their freshly-initialized values.
        model_dict = net.state_dict()
        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)

        # Freeze everything except the classification head ('clf*').
        # `layer` already is the submodule, so no eval() lookup is needed.
        for name, layer in net._modules.items():
            if not name.startswith('clf'):
                for param in layer.parameters():
                    param.requires_grad = False

    #######################
    # train and val loader
    #######################
    train_loader = iter(loaders['train'])
    val_loader = loaders['val']

    ###########################
    ## training
    ###########################
    net.train()
    start = time.time()
    stats_train = init_stats()
    for curr_iter in range(restore_iter, config['max_iter']):
        # Train on one batch and optimize.  Use the Python-3 iterator
        # protocol; loader iterators do not expose a .next() method.
        data_dict = next(train_loader)
        optimizer.zero_grad()
        sin = ME.SparseTensor(data_dict['feats'],
                              data_dict['coords'].int()).to(device)
        _, clf_out = net(sin)
        writer.add_scalar('lr', get_lr(optimizer), curr_iter)

        ###########################
        ## scene classification part
        ###########################
        loss = clf_criterion(clf_out.F, data_dict['labels'].to(device))
        loss.backward()
        optimizer.step()

        writer.add_scalar('train/iter_loss', loss.item(), curr_iter)

        stats_train['clf_loss'] += loss.item()
        is_correct = data_dict['labels'] == torch.argmax(clf_out.F, 1).cpu()
        stats_train['clf_correct'] += is_correct.sum().item()
        stats_train['num_samples'] += data_dict['labels'].size()[0]

        ###########################
        ## validation
        ###########################
        if curr_iter % config['val_freq'] == 0:
            end = time.time()
            ### evaluate
            net.eval()
            stats_val = init_stats()
            n_iters = 0
            with torch.no_grad():  # avoid out of memory problem
                for data_dict in val_loader:
                    sin = ME.SparseTensor(data_dict['feats'],
                                          data_dict['coords'].int()).to(device)
                    _, clf_out = net(sin)

                    ###########################
                    ## scene classification part
                    ###########################
                    loss = clf_criterion(clf_out.F,
                                         data_dict['labels'].to(device))
                    stats_val['clf_loss'] += loss.item()

                    is_correct = data_dict['labels'] == torch.argmax(
                        clf_out.F, 1).cpu()
                    stats_val['clf_correct'] += is_correct.sum().item()
                    stats_val['num_samples'] += data_dict['labels'].size()[0]

                    n_iters += 1

            ###########################
            ## scene stats
            ###########################
            # Training loss is averaged over the reporting window,
            # validation loss over the number of validation batches.
            writer.add_scalar('train/clf_loss',
                              stats_train['clf_loss'] / config['val_freq'],
                              curr_iter)
            writer.add_scalar(
                'train/clf_acc',
                stats_train['clf_correct'] / stats_train['num_samples'],
                curr_iter)
            writer.add_scalar('validate/clf_loss',
                              stats_val['clf_loss'] / n_iters, curr_iter)
            writer.add_scalar(
                'validate/clf_acc',
                stats_val['clf_correct'] / stats_val['num_samples'], curr_iter)

            logger.info('Iter: %d, time: %d s' % (curr_iter, end - start))

            logger.info(
                'Train: clf_acc: %.3f, clf_loss: %.3f' %
                (stats_train['clf_correct'] / stats_train['num_samples'],
                 stats_train['clf_loss'] / config['val_freq']))

            logger.info('Val  : clf_acc: %.3f, clf_loss: %.3f' %
                        (stats_val['clf_correct'] / stats_val['num_samples'],
                         stats_val['clf_loss'] / n_iters))

            ### update checkpoint
            val_acc = stats_val['clf_correct'] / stats_val['num_samples']
            c_metric = val_acc
            if c_metric > curr_best_metric:
                curr_best_metric = c_metric
                torch.save({'state_dict': net.state_dict()}, path_new_model)
                logger.info(
                    '---------- model updated, best metric: %.3f ----------' %
                    curr_best_metric)

            # Reset window statistics and resume training mode.
            stats_train = init_stats()
            start = time.time()
            net.train()
Beispiel #3
0
def main(args):
    """Train FusionNet for semantic segmentation on SemanticKITTI.

    Sets up logging/checkpoint directories, builds train/test loaders,
    restores the best checkpoint when one exists, then alternates one
    training epoch (manual multi-GPU replication via parallel.replicate /
    parallel_apply) with a validation pass that reports accuracy and
    per-class IoU, saving the best model by mIoU.
    """
    def log_string(msg):
        # Mirror every message to the file logger and to stdout.
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('part_seg')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    '''LOG'''
    # NOTE(review): re-parsing here overwrites the `args` parameter passed
    # in by the caller — confirm this is intentional.
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    # Hard-coded dataset locations; adjust for your machine.
    root = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/'
    file_list = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/train2.list'
    val_list = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/val2.list'
    TRAIN_DATASET = KittiDataset(root=root,
                                 file_list=file_list,
                                 npoints=args.npoint,
                                 training=True,
                                 augment=True)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET,
                                                  batch_size=args.batch_size,
                                                  shuffle=True,
                                                  drop_last=True,
                                                  num_workers=2)
    TEST_DATASET = KittiDataset(root=root,
                                file_list=val_list,
                                npoints=args.npoint,
                                training=False,
                                augment=False)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 drop_last=True,
                                                 num_workers=2)
    log_string("The number of training data is: %d" % len(TRAIN_DATASET))
    log_string("The number of test data is: %d" % len(TEST_DATASET))
    '''MODEL LOADING'''

    # Snapshot the model sources next to the experiment for reproducibility.
    shutil.copy('models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('models/pointnet_util.py', str(experiment_dir))

    num_devices = args.num_gpus  #torch.cuda.device_count()
    devices = list(range(num_devices))
    target_device = devices[0]

    net = FusionNet(args.npoint, 4, 20, nPlanes)
    net = net.to(target_device)

    def weights_init(m):
        # Xavier-normal weights / zero biases for Conv2d and Linear layers.
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if m.weight is not None:
                torch.nn.init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('Linear') != -1:
            if m.weight is not None:
                torch.nn.init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)

    # Best-effort restore: any failure (missing file, incompatible state
    # dict) falls back to training from scratch.
    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0
        net = net.apply(weights_init)

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=1e-1,
                                    momentum=0.9,
                                    weight_decay=1e-4,
                                    nesterov=True)

    def bn_momentum_adjust(m, momentum):
        # Applied via net.apply() to update every BatchNorm layer in place.
        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(
                m, torch.nn.BatchNorm1d):
            m.momentum = momentum

    LEARNING_RATE_CLIP = 1e-5
    MOMENTUM_ORIGINAL = 0.1
    MOMENTUM_DECCAY = 0.5
    MOMENTUM_DECCAY_STEP = 20 / 2  # args.step_size

    best_acc = 0
    global_epoch = 0
    best_class_avg_iou = 0
    best_inctance_avg_iou = 0

    # One criterion replica per device so losses are computed in parallel.
    criterion = nn.CrossEntropyLoss()
    criterions = parallel.replicate(criterion, devices)

    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' %
                   (global_epoch + 1, epoch + 1, args.epoch))
        '''Adjust learning rate and BN momentum'''

        mean_correct = []
        # Decay the BN momentum on a fixed epoch schedule, floored at 0.01.
        momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY
                                        **(epoch // MOMENTUM_DECCAY_STEP))
        if momentum < 0.01:
            momentum = 0.01
        print('BN momentum updated to: %f' % momentum)
        net = net.apply(lambda x: bn_momentum_adjust(x, momentum))
        '''learning one epoch'''
        net.train()

        for iteration, data in enumerate(trainDataLoader):
            # Exponentially decay the learning rate every 320 iterations.
            if (iteration) % 320 == 0:
                lr_count = epoch * 6 + (iteration) / 320
                lr = args.learning_rate * math.exp(
                    (1 - lr_count) * args.lr_decay)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                log_string('Learning rate:%f' % lr)

            optimizer.zero_grad()
            # Cap the number of training iterations per epoch.
            if iteration > 1920:
                break
            points, target, ins, mask = data

            valid = mask > 0
            total_points = valid.sum()
            orgs = points
            points = points.data.numpy()
            # Split the batch evenly across devices and build one sparse
            # input (plus original coords/features and mask) per device.
            inputs, targets, masks = [], [], []
            coords = []
            for i in range(num_devices):
                start = int(i * (args.batch_size / num_devices))
                end = int((i + 1) * (args.batch_size / num_devices))
                batch = provider.transform_for_sparse(
                    points[start:end, :, :3], points[start:end, :, 3:],
                    target[start:end, :].data.numpy(),
                    mask[start:end, :].data.numpy(), scale, spatialSize)
                batch['x'][1] = batch['x'][1].type(torch.FloatTensor)
                batch['x'][0] = batch['x'][0].type(torch.IntTensor)
                batch['y'] = batch['y'].type(torch.LongTensor)

                org_xyz = orgs[start:end, :, :3].transpose(1, 2).contiguous()
                org_feas = orgs[start:end, :, 3:].transpose(1, 2).contiguous()

                label = Variable(batch['y'], requires_grad=False)
                maski = batch['mask'].type(torch.IntTensor)
                locs, feas = input_layer(batch['x'][0].cuda(),
                                         batch['x'][1].cuda())

                with torch.cuda.device(devices[i]):
                    org_coords = batch['x'][0].to(devices[i])
                    inputi = ME.SparseTensor(feas.cpu(), locs).to(
                        devices[i])  #input_coding(batch['x'])
                    org_xyz = org_xyz.to(devices[i])
                    org_feas = org_feas.to(devices[i])
                    maski = maski.to(devices[i])
                    inputs.append(
                        [inputi, org_coords, org_xyz, org_feas, maski])
                    targets.append(label.to(devices[i]))

            # Replicate the model and run one forward pass per device.
            replicas = parallel.replicate(net, devices)
            predictions = parallel.parallel_apply(replicas,
                                                  inputs,
                                                  devices=devices)

            count = 0
            results = []
            labels = []
            match = 0

            # Keep only valid (label > 0) points and accumulate accuracy.
            for i in range(num_devices):
                temp = predictions[i]
                temp = temp[targets[i] > 0, :]
                results.append(temp)

                temp = targets[i]
                temp = temp[targets[i] > 0]
                labels.append(temp)
                outputi = results[i]
                num_points = labels[i].size(0)
                count += num_points

                _, pred_choice = outputi.data.max(1)
                correct = pred_choice.eq(labels[i].data).cpu().sum()
                match += correct.item()
                mean_correct.append(correct.item() / num_points)

            # Per-device losses, gathered and averaged on the target device.
            losses = parallel.parallel_apply(criterions,
                                             tuple(zip(results, labels)),
                                             devices=devices)
            loss = parallel.gather(losses, target_device, dim=0).mean()
            loss.backward()
            optimizer.step()
            log_string(
                "===> Epoch[{}]({}/{}) Valid points:{}/{} Loss: {:.4f} Accuracy: {:.4f}"
                .format(epoch, iteration, len(trainDataLoader), count,
                        total_points, loss.item(), match / count))
        train_instance_acc = np.mean(mean_correct)
        log_string('Train accuracy is: %.5f' % train_instance_acc)

        with torch.no_grad():
            net.eval()
            evaluator = iouEval(num_classes, ignore)

            evaluator.reset()
            for iteration, (points, target, ins,
                            mask) in tqdm(enumerate(testDataLoader),
                                          total=len(testDataLoader),
                                          smoothing=0.9):
                cur_batch_size, NUM_POINT, _ = points.size()
                # Cap the number of validation iterations.
                if iteration > 192:
                    break
                orgs = points
                points = points.data.numpy()
                inputs, targets, masks = [], [], []
                coords = []
                for i in range(num_devices):
                    start = int(i * (cur_batch_size / num_devices))
                    end = int((i + 1) * (cur_batch_size / num_devices))
                    batch = provider.transform_for_test(
                        points[start:end, :, :3], points[start:end, :, 3:],
                        target[start:end, :].data.numpy(),
                        mask[start:end, :].data.numpy(), scale, spatialSize)
                    batch['x'][1] = batch['x'][1].type(torch.FloatTensor)
                    batch['x'][0] = batch['x'][0].type(torch.IntTensor)
                    batch['y'] = batch['y'].type(torch.LongTensor)

                    org_xyz = orgs[start:end, :, :3].transpose(1,
                                                               2).contiguous()
                    org_feas = orgs[start:end, :,
                                    3:].transpose(1, 2).contiguous()

                    label = Variable(batch['y'], requires_grad=False)
                    maski = batch['mask'].type(torch.IntTensor)
                    locs, feas = input_layer(batch['x'][0].cuda(),
                                             batch['x'][1].cuda())

                    with torch.cuda.device(devices[i]):
                        org_coords = batch['x'][0].to(devices[i])
                        inputi = ME.SparseTensor(feas.cpu(), locs).to(
                            devices[i])  #input_coding(batch['x'])
                        org_xyz = org_xyz.to(devices[i])
                        org_feas = org_feas.to(devices[i])
                        maski = maski.to(devices[i])
                        inputs.append(
                            [inputi, org_coords, org_xyz, org_feas, maski])
                        targets.append(label.to(devices[i]))

                replicas = parallel.replicate(net, devices)
                outputs = parallel.parallel_apply(replicas,
                                                  inputs,
                                                  devices=devices)

                # Concatenate per-device predictions/targets back on CPU.
                seg_pred = outputs[0].cpu()
                target = targets[0].cpu()
                for i in range(1, num_devices):
                    seg_pred = torch.cat((seg_pred, outputs[i].cpu()), 0)
                    target = torch.cat((target, targets[i].cpu()), 0)

                seg_pred = seg_pred[target > 0, :]
                target = target[target > 0]
                _, seg_pred = seg_pred.data.max(1)

                target = target.data.numpy()

                evaluator.addBatch(seg_pred, target)

            # When the epoch is done, print the evaluation.
            m_accuracy = evaluator.getacc()
            m_jaccard, class_jaccard = evaluator.getIoU()

            log_string('Validation set:\n'
                       'Acc avg {m_accuracy:.3f}\n'
                       'IoU avg {m_jaccard:.3f}'.format(m_accuracy=m_accuracy,
                                                        m_jaccard=m_jaccard))
            # print also classwise
            for i, jacc in enumerate(class_jaccard):
                if i not in ignore:
                    log_string(
                        'IoU class {i:} [{class_str:}] = {jacc:.3f}'.format(
                            i=i,
                            class_str=class_strings[class_inv_remap[i]],
                            jacc=jacc))

        log_string('Epoch %d test Accuracy: %f  mean avg mIOU: %f' %
                   (epoch + 1, m_accuracy, m_jaccard))
        # Checkpoint whenever mIoU matches or beats the best seen so far.
        if m_jaccard >= best_class_avg_iou:
            log_string('Saveing model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            log_string('Saving at %s' % savepath)
            state = {
                'epoch': epoch,
                'train_acc': train_instance_acc,
                'test_acc': m_accuracy,
                'class_avg_iou': m_jaccard,
                'model_state_dict': net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)

        if m_accuracy > best_acc:
            best_acc = m_accuracy
        if m_jaccard > best_class_avg_iou:
            best_class_avg_iou = m_jaccard

        log_string('Best accuracy is: %.5f' % best_acc)
        log_string('Best class avg mIOU is: %.5f' % best_class_avg_iou)

        global_epoch += 1
Beispiel #4
0
def check_data(model, train_data_loader, val_data_loader, config):
    """Compute and dump per-stride sparsity masks for sampled scenes.

    For each pooling stride, pools ``sample_size`` scenes (batch size must
    be 1), queries the 3x3x3 kernel map of the pooled tensor and records,
    for every kernel offset, a 0/1 mask marking which voxels have a
    neighbor at that offset.  The stacked masks are saved to disk, one
    file per stride.

    Note: ``model``, ``val_data_loader`` and ``config`` are currently
    unused; they are kept for interface compatibility with callers.
    """
    data_iter = iter(train_data_loader)

    sample_size = 1
    strides = [2, 4, 8, 16]

    for stride in strides:

        coordinate_map_key = [stride] * 3
        pool0 = ME.MinkowskiSumPooling(kernel_size=stride,
                                       stride=stride,
                                       dilation=1,
                                       dimension=3)
        all_neis_scenes = []
        for _ in range(sample_size):
            # Python-3 iterator protocol; loader iterators have no .next().
            coords, feats, target, _, _ = next(data_iter)

            assert coords[:, 0].max() == 0  # assert bs=1
            x = ME.SparseTensor(feats, coords, device='cuda')
            x = pool0(x)

            d = {}
            # Kernel map of the pooled tensor onto itself: for each of the
            # 27 offsets of a 3x3x3 kernel, pairs of (in, out) indices.
            neis_d = x.coordinate_manager.get_kernel_map(
                x.coordinate_map_key,
                x.coordinate_map_key,
                kernel_size=3,
                stride=1,
                dilation=1,
            )
            N = x.C.shape[0]
            k = 27
            all_neis = []
            for k_ in range(k):

                # Offsets with no neighbor pairs are simply absent.
                if k_ not in neis_d:
                    continue

                # Gather the coordinates of the input-side neighbors, then
                # scatter them to the output positions; any nonzero row
                # means a neighbor exists at this kernel offset.
                neis_ = torch.gather(x.C[:, 1:].float(),
                                     dim=0,
                                     index=neis_d[k_][0].reshape(-1, 1).repeat(
                                         1, 3).long())
                neis = torch.zeros(N, 3, device=x.F.device)
                neis.scatter_(dim=0,
                              index=neis_d[k_][1].reshape(-1,
                                                          1).repeat(1,
                                                                    3).long(),
                              src=neis_)
                neis = (neis.sum(-1) > 0).int()
                all_neis.append(neis)
            all_neis = torch.stack(all_neis)
            all_neis_scenes.append(all_neis)

        d['sparse_mask'] = torch.cat(all_neis_scenes, dim=-1)
        print('Stride:{} Shape:'.format(stride), d['sparse_mask'].shape)
        name = "sparse_mask_s{}_nuscenes.pth".format(stride)
        torch.save(
            d,
            '/home/zhaotianchen/project/point-transformer/SpatioTemporalSegmentation-ScanNet/plot/final/{}'
            .format(name))
Beispiel #5
0
def train(net, device, config):
    """Train ``net`` with SGD + exponential LR decay, checkpointing and
    periodic validation.

    Args:
        net: network to train; must accept an ``ME.TensorField`` input and
            return a sparse tensor whose ``.F`` are per-point logits.
        device: torch device the labels (and the TensorField) are moved to.
        config: namespace providing ``lr``, ``momentum``, ``weight_decay``,
            ``batch_size``, ``num_workers``, ``weights`` (checkpoint path),
            ``load_optimizer`` ("true"/"false" string), ``max_iter``,
            ``empty_freq``, ``stat_freq`` and ``val_freq``.

    Side effects: writes checkpoints to ``config.weights`` and logs via the
    root logger; calls ``test(...)`` for validation every ``val_freq`` iters.
    """
    optimizer = optim.SGD(
        net.parameters(),
        lr=config.lr,
        momentum=config.momentum,
        weight_decay=config.weight_decay,
    )
    scheduler = optim.lr_scheduler.ExponentialLR(
        optimizer,
        0.999,
    )

    crit = torch.nn.CrossEntropyLoss()

    train_dataloader = make_data_loader(
        "train",
        augment_data=True,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        repeat=True,
        config=config,
    )
    val_dataloader = make_data_loader(
        "val",
        augment_data=False,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        repeat=True,
        config=config,
    )

    # Resume from checkpoint if one exists; optimizer/scheduler state is
    # restored only when explicitly requested via config.load_optimizer.
    curr_iter = 0
    if os.path.exists(config.weights):
        checkpoint = torch.load(config.weights)
        net.load_state_dict(checkpoint["state_dict"])
        if config.load_optimizer.lower() == "true":
            curr_iter = checkpoint["curr_iter"] + 1
            optimizer.load_state_dict(checkpoint["optimizer"])
            scheduler.load_state_dict(checkpoint["scheduler"])

    net.train()
    train_iter = iter(train_dataloader)
    val_iter = iter(val_dataloader)
    # get_last_lr() replaces the deprecated get_lr(), which warns and can
    # report a wrong value when called outside scheduler.step().
    logging.info(f"LR: {scheduler.get_last_lr()}")
    for i in range(curr_iter, config.max_iter):

        s = time()
        # next(it) instead of it.next(): the .next() method is Python 2 only
        # and was removed from PyTorch DataLoader iterators.
        data_dict = next(train_iter)
        d = time() - s

        optimizer.zero_grad()
        sin = ME.TensorField(data_dict["feats"],
                             data_dict["coords"],
                             device=device)
        sout = net(sin)
        loss = crit(sout.F, data_dict["labels"].to(device))
        loss.backward()
        optimizer.step()
        t = time() - s

        # Periodically release cached GPU memory to limit fragmentation.
        if i % config.empty_freq == 0:
            torch.cuda.empty_cache()

        if i % config.stat_freq == 0:
            logging.info(
                f"Iter: {i}, Loss: {loss.item():.3e}, Data Loading Time: {d:.3e}, Tot Time: {t:.3e}"
            )

        if i % config.val_freq == 0 and i > 0:
            torch.save(
                {
                    "state_dict": net.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "curr_iter": i,
                },
                config.weights,
            )

            # Validation
            logging.info("Validation")
            test(net, val_iter, config, "val")

            logging.info(f"LR: {scheduler.get_last_lr()}")

            net.train()

            # one epoch
            scheduler.step()
  def _train_epoch(self, epoch):
    """Run one epoch of contrastive training with random negative pairs.

    Gradients are accumulated over ``self.iter_size`` mini-batches
    (Caffe-style iter size) before each optimizer step. Running losses are
    written to TensorBoard and logged every ``self.config.stat_freq`` outer
    iterations.

    Args:
      epoch: 1-based epoch index; used for logging and to compute the global
        iteration offset for TensorBoard scalars.
    """
    gc.collect()
    self.model.train()
    # Epoch starts from 1
    total_loss = 0
    total_num = 0.0

    data_loader = self.data_loader
    data_loader_iter = self.data_loader.__iter__()

    iter_size = self.iter_size
    start_iter = (epoch - 1) * (len(data_loader) // iter_size)

    data_meter, data_timer, total_timer = AverageMeter(), Timer(), Timer()

    # Main training
    for curr_iter in range(len(data_loader) // iter_size):
      self.optimizer.zero_grad()
      batch_pos_loss, batch_neg_loss, batch_loss = 0, 0, 0

      data_time = 0
      total_timer.tic()
      for iter_idx in range(iter_size):
        # Caffe iter size
        data_timer.tic()
        # next(it) instead of it.next(): the .next() method is Python 2 only
        # and was removed from PyTorch DataLoader iterators.
        input_dict = next(data_loader_iter)
        data_time += data_timer.toc(average=False)

        # pairs consist of (xyz1 index, xyz0 index)
        sinput0 = ME.SparseTensor(
            input_dict['sinput0_F'], coords=input_dict['sinput0_C']).to(self.device)
        F0 = self.model(sinput0).F

        sinput1 = ME.SparseTensor(
            input_dict['sinput1_F'], coords=input_dict['sinput1_C']).to(self.device)
        F1 = self.model(sinput1).F

        N0, N1 = len(sinput0), len(sinput1)

        pos_pairs = input_dict['correspondences']
        neg_pairs = self.generate_rand_negative_pairs(pos_pairs, max(N0, N1), N0, N1)
        pos_pairs = pos_pairs.long().to(self.device)
        neg_pairs = torch.from_numpy(neg_pairs).long().to(self.device)

        neg0 = F0.index_select(0, neg_pairs[:, 0])
        neg1 = F1.index_select(0, neg_pairs[:, 1])
        pos0 = F0.index_select(0, pos_pairs[:, 0])
        pos1 = F1.index_select(0, pos_pairs[:, 1])

        # Positive loss: squared distance between matched descriptors.
        pos_loss = (pos0 - pos1).pow(2).sum(1)

        # Negative loss: hinge on the distance of random negatives; the 1e-4
        # epsilon keeps sqrt() differentiable at zero.
        neg_loss = F.relu(self.neg_thresh -
                          ((neg0 - neg1).pow(2).sum(1) + 1e-4).sqrt()).pow(2)

        # Divide by iter_size so the accumulated gradient matches a single
        # large batch.
        pos_loss_mean = pos_loss.mean() / iter_size
        neg_loss_mean = neg_loss.mean() / iter_size

        # Weighted loss
        loss = pos_loss_mean + self.neg_weight * neg_loss_mean
        loss.backward(
        )  # To accumulate gradient, zero gradients only at the begining of iter_size
        batch_loss += loss.item()
        batch_pos_loss += pos_loss_mean.item()
        batch_neg_loss += neg_loss_mean.item()

      self.optimizer.step()

      torch.cuda.empty_cache()

      total_loss += batch_loss
      total_num += 1.0
      total_timer.toc()
      data_meter.update(data_time)

      # Print logs
      if curr_iter % self.config.stat_freq == 0:
        self.writer.add_scalar('train/loss', batch_loss, start_iter + curr_iter)
        self.writer.add_scalar('train/pos_loss', batch_pos_loss, start_iter + curr_iter)
        self.writer.add_scalar('train/neg_loss', batch_neg_loss, start_iter + curr_iter)
        logging.info(
            "Train Epoch: {} [{}/{}], Current Loss: {:.3e} Pos: {:.3f} Neg: {:.3f}"
            .format(epoch, curr_iter,
                    len(self.data_loader) //
                    iter_size, batch_loss, batch_pos_loss, batch_neg_loss) +
            "\tData time: {:.4f}, Train time: {:.4f}, Iter time: {:.4f}".format(
                data_meter.avg, total_timer.avg - data_meter.avg, total_timer.avg))
        data_meter.reset()
        total_timer.reset()
  def _train_epoch(self, epoch):
    """Run one epoch of contrastive training with hardest-negative mining.

    Like the random-negative variant, gradients are accumulated over
    ``self.iter_size`` mini-batches before each optimizer step; losses are
    written to TensorBoard and logged every ``self.config.stat_freq`` outer
    iterations.

    Args:
      epoch: 1-based epoch index; used for logging and to compute the global
        iteration offset for TensorBoard scalars.
    """
    gc.collect()
    self.model.train()
    # Epoch starts from 1
    total_loss = 0
    total_num = 0.0
    data_loader = self.data_loader
    data_loader_iter = self.data_loader.__iter__()
    iter_size = self.iter_size
    data_meter, data_timer, total_timer = AverageMeter(), Timer(), Timer()
    start_iter = (epoch - 1) * (len(data_loader) // iter_size)
    for curr_iter in range(len(data_loader) // iter_size):
      self.optimizer.zero_grad()
      batch_pos_loss, batch_neg_loss, batch_loss = 0, 0, 0

      data_time = 0
      total_timer.tic()
      for iter_idx in range(iter_size):
        data_timer.tic()
        # next(it) instead of it.next(): the .next() method is Python 2 only
        # and was removed from PyTorch DataLoader iterators.
        input_dict = next(data_loader_iter)
        data_time += data_timer.toc(average=False)

        sinput0 = ME.SparseTensor(
            input_dict['sinput0_F'], coords=input_dict['sinput0_C']).to(self.device)
        F0 = self.model(sinput0).F

        sinput1 = ME.SparseTensor(
            input_dict['sinput1_F'], coords=input_dict['sinput1_C']).to(self.device)

        F1 = self.model(sinput1).F

        pos_pairs = input_dict['correspondences']
        pos_loss, neg_loss = self.contrastive_hardest_negative_loss(
            F0,
            F1,
            pos_pairs,
            num_pos=self.config.num_pos_per_batch * self.config.batch_size,
            num_hn_samples=self.config.num_hn_samples_per_batch *
            self.config.batch_size)

        # Divide by iter_size so the accumulated gradient matches a single
        # large batch.
        pos_loss /= iter_size
        neg_loss /= iter_size
        loss = pos_loss + self.neg_weight * neg_loss
        loss.backward()

        batch_loss += loss.item()
        batch_pos_loss += pos_loss.item()
        batch_neg_loss += neg_loss.item()

      self.optimizer.step()
      gc.collect()

      torch.cuda.empty_cache()

      total_loss += batch_loss
      total_num += 1.0
      total_timer.toc()
      data_meter.update(data_time)

      if curr_iter % self.config.stat_freq == 0:
        self.writer.add_scalar('train/loss', batch_loss, start_iter + curr_iter)
        self.writer.add_scalar('train/pos_loss', batch_pos_loss, start_iter + curr_iter)
        self.writer.add_scalar('train/neg_loss', batch_neg_loss, start_iter + curr_iter)
        logging.info(
            "Train Epoch: {} [{}/{}], Current Loss: {:.3e} Pos: {:.3f} Neg: {:.3f}"
            .format(epoch, curr_iter,
                    len(self.data_loader) //
                    iter_size, batch_loss, batch_pos_loss, batch_neg_loss) +
            "\tData time: {:.4f}, Train time: {:.4f}, Iter time: {:.4f}".format(
                data_meter.avg, total_timer.avg - data_meter.avg, total_timer.avg))
        data_meter.reset()
        total_timer.reset()
 def validation_step(self, batch, batch_idx):
     """Compute the validation loss for one batch of sparse point data."""
     coords = batch["coordinates"]
     feats = batch["features"]
     sparse_input = ME.SparseTensor(coordinates=coords, features=feats)
     logits = self(sparse_input).F
     targets = batch["labels"].long()
     return self.criterion(logits, targets)
Beispiel #9
0
    def __init__(self):
        """Build a seven-stage sparse-convolutional encoder.

        Each stage halves the spatial resolution (first conv has stride 2)
        and is followed by global pooling plus two linear heads producing
        ``linear_mean`` and ``linear_log_var`` features (VAE-style encoder
        head — presumably; confirm against the caller).
        """
        nn.Module.__init__(self)

        # Input sparse tensor must have tensor stride 128.
        ch = self.CHANNELS

        def make_stage(c_in, c_out):
            # One encoder stage: strided conv (downsample x2) then a
            # resolution-preserving conv, each followed by BatchNorm + ELU.
            return nn.Sequential(
                ME.MinkowskiConvolution(c_in, c_out, kernel_size=3, stride=2, dimension=3),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
                ME.MinkowskiConvolution(c_out, c_out, kernel_size=3, dimension=3),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
            )

        # All stages share the same layout; only channel widths differ.
        self.block1 = make_stage(1, ch[0])
        self.block2 = make_stage(ch[0], ch[1])
        self.block3 = make_stage(ch[1], ch[2])
        self.block4 = make_stage(ch[2], ch[3])
        self.block5 = make_stage(ch[3], ch[4])
        self.block6 = make_stage(ch[4], ch[5])
        self.block7 = make_stage(ch[5], ch[6])

        self.global_pool = ME.MinkowskiGlobalPooling()

        self.linear_mean = ME.MinkowskiLinear(ch[6], ch[6], bias=True)
        self.linear_log_var = ME.MinkowskiLinear(ch[6], ch[6], bias=True)
        self.weight_initialization()
Beispiel #10
0
  def __init__(self,
               in_channels=3,
               out_channels=32,
               bn_momentum=0.1,
               conv1_kernel_size=3,
               normalize_feature=False,
               D=3):
    """Build a U-shaped sparse fully-convolutional feature network.

    Encoder: conv -> norm -> residual block at four levels, with sum-pooling
    between levels. Decoder: transposed convs whose inputs concatenate the
    upsampled features with the matching encoder level (see the
    ``CHANNELS[i] + TR_CHANNELS[i+1]`` in_channels below).

    Args:
      in_channels: channels of the input sparse tensor features.
      out_channels: channels of the final per-point descriptor.
      bn_momentum: momentum passed to every normalization layer.
      conv1_kernel_size: kernel size of the very first convolution.
      normalize_feature: stored on self; presumably consumed by forward()
        to L2-normalize output descriptors — confirm against forward().
      D: spatial dimension of the sparse tensors.
    """
    ME.MinkowskiNetwork.__init__(self, D)
    # Architecture hyper-parameters are class attributes of the subclass.
    NORM_TYPE = self.NORM_TYPE
    BLOCK_NORM_TYPE = self.BLOCK_NORM_TYPE
    CHANNELS = self.CHANNELS
    TR_CHANNELS = self.TR_CHANNELS
    REGION_TYPE = self.REGION_TYPE
    self.normalize_feature = normalize_feature
    # ---- Encoder level 1 ----
    self.conv1 = conv(
        in_channels=in_channels,
        out_channels=CHANNELS[1],
        kernel_size=conv1_kernel_size,
        stride=1,
        dilation=1,
        has_bias=False,
        region_type=ME.RegionType.HYPERCUBE,
        dimension=D)
    self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, dimension=D)

    self.block1 = get_block(
        BLOCK_NORM_TYPE,
        CHANNELS[1],
        CHANNELS[1],
        bn_momentum=bn_momentum,
        region_type=REGION_TYPE,
        dimension=D)

    # ---- Encoder level 2 (downsample x2 via sum pooling) ----
    self.pool2 = ME.MinkowskiSumPooling(kernel_size=2, stride=2, dimension=D)
    self.conv2 = conv(
        in_channels=CHANNELS[1],
        out_channels=CHANNELS[2],
        kernel_size=3,
        stride=1,
        dilation=1,
        has_bias=False,
        region_type=REGION_TYPE,
        dimension=D)
    self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, dimension=D)

    self.block2 = get_block(
        BLOCK_NORM_TYPE,
        CHANNELS[2],
        CHANNELS[2],
        bn_momentum=bn_momentum,
        region_type=REGION_TYPE,
        dimension=D)

    # ---- Encoder level 3 ----
    self.pool3 = ME.MinkowskiSumPooling(kernel_size=2, stride=2, dimension=D)
    self.conv3 = conv(
        in_channels=CHANNELS[2],
        out_channels=CHANNELS[3],
        kernel_size=3,
        stride=1,
        dilation=1,
        has_bias=False,
        region_type=REGION_TYPE,
        dimension=D)
    self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, dimension=D)

    self.block3 = get_block(
        BLOCK_NORM_TYPE,
        CHANNELS[3],
        CHANNELS[3],
        bn_momentum=bn_momentum,
        region_type=REGION_TYPE,
        dimension=D)

    # ---- Encoder level 4 (bottleneck; hypercube region here) ----
    self.pool4 = ME.MinkowskiSumPooling(kernel_size=2, stride=2, dimension=D)
    self.conv4 = conv(
        in_channels=CHANNELS[3],
        out_channels=CHANNELS[4],
        kernel_size=3,
        stride=1,
        dilation=1,
        has_bias=False,
        region_type=ME.RegionType.HYPERCUBE,
        dimension=D)
    self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, dimension=D)

    self.block4 = get_block(
        BLOCK_NORM_TYPE,
        CHANNELS[4],
        CHANNELS[4],
        bn_momentum=bn_momentum,
        region_type=ME.RegionType.HYPERCUBE,
        dimension=D)

    # ---- Decoder level 4 -> 3 (transposed conv upsamples x2) ----
    self.conv4_tr = conv_tr(
        in_channels=CHANNELS[4],
        out_channels=TR_CHANNELS[4],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        region_type=ME.RegionType.HYPERCUBE,
        dimension=D)
    self.norm4_tr = get_norm(
        NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, dimension=D)

    self.block4_tr = get_block(
        BLOCK_NORM_TYPE,
        TR_CHANNELS[4],
        TR_CHANNELS[4],
        bn_momentum=bn_momentum,
        region_type=REGION_TYPE,
        dimension=D)

    # ---- Decoder level 3 -> 2 (input = encoder-3 skip + upsampled) ----
    self.conv3_tr = conv_tr(
        in_channels=CHANNELS[3] + TR_CHANNELS[4],
        out_channels=TR_CHANNELS[3],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        region_type=REGION_TYPE,
        dimension=D)
    self.norm3_tr = get_norm(
        NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, dimension=D)

    self.block3_tr = get_block(
        BLOCK_NORM_TYPE,
        TR_CHANNELS[3],
        TR_CHANNELS[3],
        bn_momentum=bn_momentum,
        region_type=REGION_TYPE,
        dimension=D)

    # ---- Decoder level 2 -> 1 ----
    self.conv2_tr = conv_tr(
        in_channels=CHANNELS[2] + TR_CHANNELS[3],
        out_channels=TR_CHANNELS[2],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        region_type=REGION_TYPE,
        dimension=D)
    self.norm2_tr = get_norm(
        NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, dimension=D)

    self.block2_tr = get_block(
        BLOCK_NORM_TYPE,
        TR_CHANNELS[2],
        TR_CHANNELS[2],
        bn_momentum=bn_momentum,
        region_type=REGION_TYPE,
        dimension=D)

    # ---- Output head: 1x1 convs at full resolution ----
    self.conv1_tr = conv(
        in_channels=CHANNELS[1] + TR_CHANNELS[2],
        out_channels=TR_CHANNELS[1],
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=False,
        dimension=D)

    # self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, D=D)

    self.final = ME.MinkowskiConvolution(
        in_channels=TR_CHANNELS[1],
        out_channels=out_channels,
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=True,
        dimension=D)
 def get_mlp_block(self, in_channel, out_channel):
     """Return a sparse Linear -> BatchNorm -> ReLU block."""
     layers = [
         ME.MinkowskiLinear(in_channel, out_channel, bias=False),
         ME.MinkowskiBatchNorm(out_channel),
         ME.MinkowskiReLU(),
     ]
     return nn.Sequential(*layers)
Beispiel #12
0
  def __init__(self,
               in_channels=3,
               out_channels=32,
               bn_momentum=0.1,
               conv1_kernel_size=3,
               normalize_feature=False,
               D=3):
    """Build a U-shaped sparse network with configurable per-level depth.

    Unlike the transposed-conv variant, resolution changes use
    ``MinkowskiSumPooling`` (down) and ``MinkowskiPoolingTranspose`` (up),
    and each level stacks ``DEPTHS[i]`` residual blocks.

    Args:
      in_channels: channels of the input sparse tensor features.
      out_channels: channels of the final per-point descriptor.
      bn_momentum: momentum passed to every normalization layer.
      conv1_kernel_size: kernel size of the very first convolution.
      normalize_feature: stored on self; presumably consumed by forward()
        to L2-normalize output descriptors — confirm against forward().
      D: spatial dimension of the sparse tensors.
    """
    ME.MinkowskiNetwork.__init__(self, D)
    # Architecture hyper-parameters are class attributes of the subclass.
    NORM_TYPE = self.NORM_TYPE
    BLOCK_NORM_TYPE = self.BLOCK_NORM_TYPE
    CHANNELS = self.CHANNELS
    TR_CHANNELS = self.TR_CHANNELS
    DEPTHS = self.DEPTHS
    REGION_TYPE = self.REGION_TYPE
    self.normalize_feature = normalize_feature
    # ---- Encoder level 1 ----
    self.conv1 = conv(
        in_channels=in_channels,
        out_channels=CHANNELS[1],
        kernel_size=conv1_kernel_size,
        stride=1,
        dilation=1,
        has_bias=False,
        region_type=REGION_TYPE,
        dimension=D)
    self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, dimension=D)

    self.block1 = nn.Sequential(*[
        get_block(
            BLOCK_NORM_TYPE,
            CHANNELS[1],
            CHANNELS[1],
            bn_momentum=bn_momentum,
            region_type=REGION_TYPE,
            dimension=D) for d in range(DEPTHS[1])
    ])

    # ---- Encoder level 2 (downsample x2 via sum pooling; 1x1 channel mix) ----
    self.pool2 = ME.MinkowskiSumPooling(kernel_size=2, stride=2, dimension=D)
    self.conv2 = conv(
        in_channels=CHANNELS[1],
        out_channels=CHANNELS[2],
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, dimension=D)

    self.block2 = nn.Sequential(*[
        get_block(
            BLOCK_NORM_TYPE,
            CHANNELS[2],
            CHANNELS[2],
            bn_momentum=bn_momentum,
            region_type=REGION_TYPE,
            dimension=D) for d in range(DEPTHS[2])
    ])

    # ---- Encoder level 3 (bottleneck) ----
    self.pool3 = ME.MinkowskiSumPooling(kernel_size=2, stride=2, dimension=D)
    self.conv3 = conv(
        in_channels=CHANNELS[2],
        out_channels=CHANNELS[3],
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, dimension=D)

    self.block3 = nn.Sequential(*[
        get_block(
            BLOCK_NORM_TYPE,
            CHANNELS[3],
            CHANNELS[3],
            bn_momentum=bn_momentum,
            region_type=REGION_TYPE,
            dimension=D) for d in range(DEPTHS[3])
    ])

    # ---- Decoder level 3 -> 2 (pooling transpose upsamples x2) ----
    self.pool3_tr = ME.MinkowskiPoolingTranspose(kernel_size=2, stride=2, dimension=D)
    self.conv3_tr = conv_tr(
        in_channels=CHANNELS[3],
        out_channels=TR_CHANNELS[3],
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm3_tr = get_norm(
        NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, dimension=D)

    # NOTE(review): decoder depths index from the end (DEPTHS[-3], DEPTHS[-2]);
    # confirm DEPTHS is laid out so this picks the intended per-level depth.
    self.block3_tr = nn.Sequential(*[
        get_block(
            BLOCK_NORM_TYPE,
            TR_CHANNELS[3],
            TR_CHANNELS[3],
            bn_momentum=bn_momentum,
            region_type=REGION_TYPE,
            dimension=D) for d in range(DEPTHS[-3])
    ])

    # ---- Decoder level 2 -> 1 (input = encoder-2 skip + upsampled) ----
    self.pool2_tr = ME.MinkowskiPoolingTranspose(kernel_size=2, stride=2, dimension=D)
    self.conv2_tr = conv_tr(
        in_channels=CHANNELS[2] + TR_CHANNELS[3],
        out_channels=TR_CHANNELS[2],
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=False,
        region_type=REGION_TYPE,
        dimension=D)
    self.norm2_tr = get_norm(
        NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, dimension=D)

    self.block2_tr = nn.Sequential(*[
        get_block(
            BLOCK_NORM_TYPE,
            TR_CHANNELS[2],
            TR_CHANNELS[2],
            bn_momentum=bn_momentum,
            region_type=REGION_TYPE,
            dimension=D) for d in range(DEPTHS[-2])
    ])

    # ---- Output head ----
    self.conv1_tr = conv_tr(
        in_channels=CHANNELS[1] + TR_CHANNELS[2],
        out_channels=TR_CHANNELS[1],
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=False,
        region_type=REGION_TYPE,
        dimension=D)

    # self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, dimension=D)

    self.final = conv(
        in_channels=TR_CHANNELS[1],
        out_channels=out_channels,
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=True,
        dimension=D)
Beispiel #13
0
def sparse_corr_6d(scalespace,
                   feature_A,
                   feature_B,
                   k=10,
                   coords_A=None,
                   coords_B=None,
                   reverse=False,
                   ratio=False,
                   sparse_type='torch',
                   return_indx=False,
                   fsize=None,
                   bidx=None):
    """Build a sparse 6D correlation volume between two feature maps.

    For every position in ``feature_A``, the ``k`` nearest neighbors in
    ``feature_B`` are found via ``knn_faiss`` and their correlation scores
    are stored at coordinates (batch, scale1, scale2, yA, xA, yB, xB).

    Args:
        scalespace: pair of scale indices written into every coordinate.
        feature_A, feature_B: feature maps of shape (b, ch, h, w) — or
            (b, ch, n) when ``fsize`` supplies (h, w) explicitly.
        k: number of nearest neighbors kept per A-position.
        coords_A, coords_B: optional precomputed (y, x) coordinate tensors;
            when None a dense meshgrid over (h, w) is used.
        reverse: swap the A and B coordinate roles in the output.
        ratio: also return scores divided by the best (first) neighbor score.
        sparse_type: 'me' (MinkowskiEngine), 'torch' (torch.sparse), or
            'raw' ((values, coords) tuple).
        return_indx: return the raw (corr, indx) pair instead of a sparse
            tensor. NOTE(review): when both ratio and return_indx are True,
            the ``ratio`` return takes precedence — confirm that is intended.
        fsize: optional (h, w) overriding the shapes of both feature maps.
        bidx: optional batch-index tensor; defaults to arange(b).

    Returns:
        A sparse correlation tensor (format per ``sparse_type``); a
        (scorr, scorr_ratio) pair when ``ratio``; or (corr, indx) when
        ``return_indx`` (and not ``ratio``).
    """

    b, ch = feature_B.shape[:2]

    if fsize is None:
        hA, wA = feature_A.shape[2:]
        hB, wB = feature_B.shape[2:]
    else:
        hA, wA = fsize
        hB, wB = fsize

    # Flatten spatial dims: (b, ch, h*w).
    feature_A = feature_A.view(b, ch, -1)
    feature_B = feature_B.view(b, ch, -1)

    nA = feature_A.shape[2]
    nB = feature_B.shape[2]

    with torch.no_grad():
        dist_squared, indx = knn_faiss(feature_B, feature_A, k)

    if bidx is None: bidx = torch.arange(b).view(b, 1, 1)
    bidx = bidx.expand_as(indx).contiguous()

    # Constant scale coordinates, one per (b, k, nA) entry.
    sidx1 = torch.empty(indx.shape).fill_(scalespace[0]).view(-1, 1)
    sidx2 = torch.empty(indx.shape).fill_(scalespace[1]).view(-1, 1)

    if feature_A.requires_grad:
        # Differentiable path: recompute the dot products so gradients flow;
        # the faiss distances are used only for neighbor selection.
        corr = (feature_A.permute(1,0,2).unsqueeze(2) * \
                feature_B.permute(1,0,2)[:,bidx.view(-1),indx.view(-1)].view(ch,b,k,nA)).sum(dim=0).contiguous()
    else:
        # For unit-norm features: ||a-b||^2 = 2 - 2<a,b>  =>  <a,b> = 1 - d/2.
        # (Assumes features are L2-normalized — TODO confirm upstream.)
        corr = 1 - dist_squared / 2  # [b,k,nA]

    if ratio:
        corr_ratio = corr / corr[:, :1, :]

    # A-side coordinates: broadcast each (y, x) across the k neighbors.
    if coords_A is None:
        YA, XA = torch.meshgrid(torch.arange(hA), torch.arange(wA))
        YA = YA.contiguous()
        XA = XA.contiguous()
        yA = YA.view(-1).unsqueeze(0).unsqueeze(0).expand(
            b, k, nA).contiguous().view(-1, 1)
        xA = XA.view(-1).unsqueeze(0).unsqueeze(0).expand(
            b, k, nA).contiguous().view(-1, 1)
    else:
        yA, xA = coords_A
        yA = yA.view(-1).unsqueeze(0).unsqueeze(0).expand(
            b, k, nA).contiguous().view(-1, 1)
        xA = xA.view(-1).unsqueeze(0).unsqueeze(0).expand(
            b, k, nA).contiguous().view(-1, 1)

    # B-side coordinates: gather by neighbor index (on CPU).
    if coords_B is None:
        YB, XB = torch.meshgrid(torch.arange(hB), torch.arange(wB))
        YB = YB.contiguous()
        XB = XB.contiguous()
        yB = YB.view(-1)[indx.view(-1).cpu()].view(-1, 1)
        xB = XB.view(-1)[indx.view(-1).cpu()].view(-1, 1)
    else:
        yB, xB = coords_B
        yB = yB.view(-1)[indx.view(-1).cpu()].view(-1, 1)
        xB = xB.view(-1)[indx.view(-1).cpu()].view(-1, 1)

    bidx = bidx.view(-1, 1)
    corr = corr.view(-1, 1)
    if ratio: corr_ratio = corr_ratio.view(-1, 1)

    if reverse:
        yA, xA, yB, xB = yB, xB, yA, xA
        hA, wA, hB, wB = hB, wB, hA, wA

    if sparse_type == 'me':
        coords = torch.cat((bidx, sidx1, sidx2, yA, xA, yB, xB), dim=1).int()
        scorr = ME.SparseTensor(corr, coords)

        if ratio: scorr_ratio = ME.SparseTensor(corr_ratio, coords)

    elif sparse_type == 'torch':
        # NOTE(review): torch.sparse.FloatTensor is deprecated in favor of
        # torch.sparse_coo_tensor; kept as-is for behavior parity.
        coords = torch.cat((bidx, sidx1, sidx2, yA, xA, yB, xB),
                           dim=1).long().to(corr.device).t()
        scorr = torch.sparse.FloatTensor(coords, corr,
                                         torch.Size([b, hA, wA, hB, wB, 1]))

        if ratio:
            scorr_ratio = torch.sparse.FloatTensor(
                coords, corr_ratio, torch.Size([b, hA, wA, hB, wB, 1]))

    elif sparse_type == 'raw':
        coords = torch.cat((bidx, sidx1, sidx2, yA, xA, yB, xB), dim=1).int()
        scorr = (corr, coords)

        if ratio: scorr_ratio = (corr_ratio, coords)

    else:
        raise ValueError('sparse type {} not recognized'.format(sparse_type))

    if ratio: return scorr, scorr_ratio
    if return_indx: return corr, indx
    return scorr
Beispiel #14
0
def transpose_me_6d(sten):
    """Return a copy of a 6D sparse tensor with the (yA, xA) and (yB, xB)
    coordinate pairs swapped (batch and scale columns stay in place)."""
    swap_cols = [0, 1, 2, 5, 6, 3, 4]
    new_feats = sten.features.clone()
    new_coords = sten.coordinates[:, swap_cols].clone()
    return ME.SparseTensor(new_feats, new_coords)
    def forward(self,
                in_field: ME.TensorField,
                aux=None,
                save_anchor=False,
                iter_=None):
        """Run the point-transformer U-Net forward pass.

        Two mutually exclusive modes:
          * ``aux`` given: the aux signal is converted to a SparseTensor and
            subsampled alongside the features, and every PTBlock receives the
            matching-resolution aux tensor.
          * ``aux`` is None: the standard path, optionally alpha-blending
            each PTBlock's output with a parallel "branch" block whose weight
            ramps from 0 to 1 over training (``iter_`` must then be given).

        Args:
            in_field: input tensor field.
            aux: optional per-point auxiliary signal (reshaped to one
                feature channel).
            save_anchor: if True, collect intermediate PTBlock outputs in
                ``self.anchors`` and return them alongside the output.
            iter_: training progress normalized by config.iter_size; drives
                the alpha-blending schedule.

        Returns:
            The output sparse tensor, or ``(output, anchors)`` when
            ``save_anchor`` is True.
        """

        x = in_field
        if save_anchor:
            self.anchors = []

        if aux is not None:
            # Aux path: carry the auxiliary signal down the encoder at each
            # resolution via subsample_aux.
            aux = ME.SparseTensor(coordinates=x.C,
                                  features=aux.float().reshape([-1, 1]).cuda())

            x0 = self.stem1(x)
            x = self.stem2(x0)

            aux1 = subsample_aux(x0, x, aux)
            x1 = self.PTBlock1(x, aux=aux1)

            x = self.TDLayer1(x1)
            aux2 = subsample_aux(x1, x, aux1)
            x2 = self.PTBlock2(x, aux=aux2)

            x = self.TDLayer2(x2)
            aux3 = subsample_aux(x2, x, aux2)
            x3 = self.PTBlock3(x, aux=aux3)

            x = self.TDLayer3(x3)
            aux4 = subsample_aux(x3, x, aux3)
            x = self.PTBlock4(x, aux=aux4)

            # Decoder: upsample and fuse with the matching encoder output,
            # reusing the aux tensor of that resolution.
            x = self.TULayer5(x, x3)
            x = self.PTBlock5(x, aux=aux3)

            x = self.TULayer6(x, x2)
            x = self.PTBlock6(x, aux=aux2)

            x = self.TULayer7(x, x1)
            x = self.PTBlock7(x, aux=aux1)

        else:

            # alpha from 0 ~ 1.0
            # gradually from pointnet(skip_attn: TR) to tr-block
            if self.ALPHA_BLENDING_MID_TR or self.ALPHA_BLENDING_FIRST_TR:
                assert iter_ is not None
                # iter_ is normalized by config.iter_size
                alpha = 1 - 1.01**(-24000 * iter_)
                # Warm-up: use only the branch block for the first 10%.
                if iter_ < 0.1:
                    alpha = 0

            x0 = self.stem1(x)
            x = self.stem2(x0)

            # x1 = self.PTBlock1(x)
            if self.ALPHA_BLENDING_FIRST_TR:
                # Blend the transformer block with its branch; the result is
                # rewrapped on the same coordinate map so downstream layers
                # see an ordinary SparseTensor.
                x1 = self.PTBlock1(x, iter_=iter_)
                x1_ = self.PTBlock1_branch(x, iter_=iter_)
                new_x1 = alpha * x1.F + (1 - alpha) * x1_.F
                x1 = ME.SparseTensor(features=new_x1,
                                     coordinate_map_key=x.coordinate_map_key,
                                     coordinate_manager=x.coordinate_manager)
            else:
                x1 = self.PTBlock1(x, iter_=iter_)
            if save_anchor:
                self.anchors.append(x1)

            x = self.TDLayer1(x1)
            # x2 = self.PTBlock2(x)
            x2 = self.PTBlock2(x, iter_=iter_)
            if save_anchor:
                self.anchors.append(x2)

            x = self.TDLayer2(x2)
            x3 = self.PTBlock3(x, iter_=iter_)
            # if save_anchor:
            # self.anchors.append(x3)

            x = self.TDLayer3(x3)
            if self.ALPHA_BLENDING_MID_TR:
                x_ = self.PTBlock4_branch(x, iter_=iter_)
            x = self.PTBlock4(x, iter_=iter_)
            if self.ALPHA_BLENDING_MID_TR:
                new_x = alpha * x.F + (1 - alpha) * x_.F
                x = ME.SparseTensor(features=new_x,
                                    coordinate_map_key=x.coordinate_map_key,
                                    coordinate_manager=x.coordinate_manager)

            # if save_anchor:
            # self.anchors.append(x)

            x = self.TULayer5(x, x3)
            if self.ALPHA_BLENDING_MID_TR:
                x_ = self.PTBlock5_branch(x)
            x = self.PTBlock5(x, iter_=iter_)
            if self.ALPHA_BLENDING_MID_TR:
                new_x = alpha * x.F + (1 - alpha) * x_.F
                x = ME.SparseTensor(features=new_x,
                                    coordinate_map_key=x.coordinate_map_key,
                                    coordinate_manager=x.coordinate_manager)
            # if save_anchor:
            # self.anchors.append(x)

            x = self.TULayer6(x, x2)
            # x = self.PTBlock6(x)
            x = self.PTBlock6(x, iter_=iter_)
            if save_anchor:
                self.anchors.append(x)

            x = self.TULayer7(x, x1)
            x = self.PTBlock7(x)
            # x = self.PTBlock7(x, iter_=iter_)
            if save_anchor:
                self.anchors.append(x)

            x = self.TULayer8(x, x0)
            x = self.fc(x)

            # x = self.final_conv(x)
            # x = self.fc(me.cat(x0,x))

        if save_anchor:
            return x, self.anchors
        else:
            return x
Beispiel #16
0
    def __init__(self):
        """Build a six-stage generative sparse decoder.

        Each stage doubles the spatial resolution with a generative
        transposed convolution and has a 1x1 "cls" conv producing per-voxel
        occupancy logits, used together with ``MinkowskiPruning`` to drop
        empty voxels (presumably in forward() — confirm against the caller).
        """
        nn.Module.__init__(self)

        # Input sparse tensor must have tensor stride 128.
        ch = self.CHANNELS

        def up_stage(c_in, c_out):
            # One decoder stage: generative transposed conv (upsample x2)
            # then a refining conv, each followed by BatchNorm + ELU.
            return nn.Sequential(
                ME.MinkowskiGenerativeConvolutionTranspose(
                    c_in, c_out, kernel_size=2, stride=2, dimension=3
                ),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
                ME.MinkowskiConvolution(c_out, c_out, kernel_size=3, dimension=3),
                ME.MinkowskiBatchNorm(c_out),
                ME.MinkowskiELU(),
            )

        def cls_head(c):
            # 1x1 conv predicting one occupancy logit per voxel.
            return ME.MinkowskiConvolution(
                c, 1, kernel_size=1, bias=True, dimension=3
            )

        # Block 1 upsamples twice; unpacking keeps the flat child indices
        # (0..11) identical to a hand-written 12-module Sequential, so
        # state_dict keys are unchanged.
        self.block1 = nn.Sequential(
            *up_stage(ch[0], ch[0]),
            *up_stage(ch[0], ch[1]),
        )
        self.block1_cls = cls_head(ch[1])

        # Blocks 2-6 each upsample once and widen/narrow channels.
        self.block2 = up_stage(ch[1], ch[2])
        self.block2_cls = cls_head(ch[2])

        self.block3 = up_stage(ch[2], ch[3])
        self.block3_cls = cls_head(ch[3])

        self.block4 = up_stage(ch[3], ch[4])
        self.block4_cls = cls_head(ch[4])

        self.block5 = up_stage(ch[4], ch[5])
        self.block5_cls = cls_head(ch[5])

        self.block6 = up_stage(ch[5], ch[6])
        self.block6_cls = cls_head(ch[6])

        # pruning
        self.pruning = ME.MinkowskiPruning()
Beispiel #17
0
    # Define a model and load the weights
    model = MinkUNet34C(3, 20).to(device)
    model_dict = torch.load(config.weights)
    model.load_state_dict(model_dict)
    model.eval()

    coords, colors, pcd = load_file(config.file_name)
    # Measure time
    with torch.no_grad():
        voxel_size = 0.02

        # Feed-forward pass and get the prediction
        bp()
        sinput = ME.SparseTensor(
            feats=torch.from_numpy(colors).float(),
            coords=ME.utils.batched_coordinates([coords / voxel_size]),
            quantization_mode=ME.SparseTensorQuantizationMode.
            UNWEIGHTED_AVERAGE).to(device)
        logits = model(sinput).slice(sinput)

    _, pred = logits.max(1)
    pred = pred.cpu().numpy()

    # Create a point cloud file
    pred_pcd = o3d.geometry.PointCloud()
    # Map color
    colors = np.array([SCANNET_COLOR_MAP[VALID_CLASS_IDS[l]] for l in pred])
    pred_pcd.points = o3d.utility.Vector3dVector(coords)
    pred_pcd.colors = o3d.utility.Vector3dVector(colors / 255)

    # Move the original point cloud
Beispiel #18
0
    def forward(self, z_glob, target_key):
        """Hierarchically decode a global latent sparse tensor into a shape.

        Runs six generative upsampling levels.  After each level a 1-channel
        classifier predicts which voxels to keep; surviving voxels are pruned
        before the next level.  During training the ground-truth occupancy is
        OR-ed into the keep mask so target voxels always propagate (use
        ``net.eval()`` to disable).

        Args:
            z_glob: latent ``ME.SparseTensor`` at the coarsest resolution.
            target_key: coordinate-map key of the ground-truth shape,
                consumed by ``self.get_target`` to build per-level targets.

        Returns:
            Tuple ``(out_cls, targets, out)``: per-level classifier outputs,
            per-level occupancy targets, and the final pruned sparse tensor.
        """
        out_cls, targets = [], []

        # Re-wrap the latent code at the decoder's base tensor stride so the
        # upsampling blocks start from the correct resolution.
        out = ME.SparseTensor(
            features=z_glob.F,
            coordinates=z_glob.C,
            tensor_stride=self.resolution,
            coordinate_manager=z_glob.coordinate_manager,
        )

        # All six levels share identical logic; iterate instead of unrolling
        # the same ten lines per level (the original also mislabeled block 6
        # as "Block 5").
        levels = [
            (self.block1, self.block1_cls),
            (self.block2, self.block2_cls),
            (self.block3, self.block3_cls),
            (self.block4, self.block4_cls),
            (self.block5, self.block5_cls),
            (self.block6, self.block6_cls),
        ]
        last = len(levels) - 1
        for idx, (block, block_cls) in enumerate(levels):
            out = block(out)
            cls_pred = block_cls(out)
            target = self.get_target(out, target_key)
            targets.append(target)
            out_cls.append(cls_pred)
            keep = (cls_pred.F > 0).squeeze()

            # If training, force target shape generation; the last level
            # never mixes the target into its keep mask.
            if self.training and idx != last:
                keep += target

            if idx != last:
                # Remove voxels the classifier rejected.
                out = self.pruning(out, keep)
            elif keep.sum() > 0:
                # Final level: prune only when at least one voxel survives,
                # otherwise return the unpruned tensor (pruning an empty
                # mask would drop everything).
                out = self.pruning(out, keep)

        return out_cls, targets, out
  def _valid_epoch(self):
    """Evaluate registration quality on the validation set.

    For each validation pair: extract per-point features with the model,
    find feature correspondences, estimate a rigid transform, and
    accumulate loss / RTE / RRE / hit-ratio statistics.

    Returns:
        dict with averaged ``loss``, ``rre``, ``rte``,
        ``feat_match_ratio`` and ``hit_ratio``.
    """
    # Change the network to evaluation mode
    self.model.eval()
    self.val_data_loader.dataset.reset_seed(0)
    num_data = 0
    hit_ratio_meter, feat_match_ratio, loss_meter, rte_meter, rre_meter = AverageMeter(
    ), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    data_timer, feat_timer, matching_timer = Timer(), Timer(), Timer()
    tot_num_data = len(self.val_data_loader.dataset)
    if self.val_max_iter > 0:
      tot_num_data = min(self.val_max_iter, tot_num_data)
    # Use the builtin iterator protocol: ``.__iter__()``/``.next()`` is a
    # Python-2 idiom and modern DataLoader iterators have no ``.next()``.
    data_loader_iter = iter(self.val_data_loader)

    for batch_idx in range(tot_num_data):
      data_timer.tic()
      input_dict = next(data_loader_iter)
      data_timer.toc()

      # pairs consist of (xyz1 index, xyz0 index)
      feat_timer.tic()
      sinput0 = ME.SparseTensor(
          input_dict['sinput0_F'], coords=input_dict['sinput0_C']).to(self.device)
      F0 = self.model(sinput0).F

      sinput1 = ME.SparseTensor(
          input_dict['sinput1_F'], coords=input_dict['sinput1_C']).to(self.device)
      F1 = self.model(sinput1).F
      feat_timer.toc()

      matching_timer.tic()
      xyz0, xyz1, T_gt = input_dict['pcd0'], input_dict['pcd1'], input_dict['T_gt']
      xyz0_corr, xyz1_corr = self.find_corr(xyz0, xyz1, F0, F1, subsample_size=5000)
      T_est = te.est_quad_linear_robust(xyz0_corr, xyz1_corr)

      loss = corr_dist(T_est, T_gt, xyz0, xyz1, weight=None)
      loss_meter.update(loss)

      # Relative translation error: distance between translation vectors.
      rte = np.linalg.norm(T_est[:3, 3] - T_gt[:3, 3])
      rte_meter.update(rte)
      # Relative rotation error via the trace formula; ``.t()`` implies
      # T_est / T_gt are torch tensors here — TODO confirm upstream types.
      rre = np.arccos((np.trace(T_est[:3, :3].t() @ T_gt[:3, :3]) - 1) / 2)
      if not np.isnan(rre):
        rre_meter.update(rre)

      hit_ratio = self.evaluate_hit_ratio(
          xyz0_corr, xyz1_corr, T_gt, thresh=self.config.hit_ratio_thresh)
      hit_ratio_meter.update(hit_ratio)
      # A pair counts as a "feature match" when >5% of correspondences hit.
      feat_match_ratio.update(hit_ratio > 0.05)
      matching_timer.toc()

      num_data += 1
      torch.cuda.empty_cache()

      if batch_idx % 100 == 0 and batch_idx > 0:
        logging.info(' '.join([
            f"Validation iter {num_data} / {tot_num_data} : Data Loading Time: {data_timer.avg:.3f},",
            f"Feature Extraction Time: {feat_timer.avg:.3f}, Matching Time: {matching_timer.avg:.3f},",
            f"Loss: {loss_meter.avg:.3f}, RTE: {rte_meter.avg:.3f}, RRE: {rre_meter.avg:.3f},",
            f"Hit Ratio: {hit_ratio_meter.avg:.3f}, Feat Match Ratio: {feat_match_ratio.avg:.3f}"
        ]))
        data_timer.reset()

    logging.info(' '.join([
        f"Final Loss: {loss_meter.avg:.3f}, RTE: {rte_meter.avg:.3f}, RRE: {rre_meter.avg:.3f},",
        f"Hit Ratio: {hit_ratio_meter.avg:.3f}, Feat Match Ratio: {feat_match_ratio.avg:.3f}"
    ]))
    return {
        "loss": loss_meter.avg,
        "rre": rre_meter.avg,
        "rte": rte_meter.avg,
        'feat_match_ratio': feat_match_ratio.avg,
        'hit_ratio': hit_ratio_meter.avg
    }
Beispiel #20
0
def train(net, dataloader, device, config):
    """Train a generative sparse VAE with BCE occupancy + KL losses.

    Args:
        net: model returning ``(out_cls, targets, sout, means, log_vars, zs)``
            for a sparse input and a target coordinate-map key.
        dataloader: yields dicts with a ``"coords"`` tensor.
        device: torch device for the sparse input tensor.
        config: namespace with ``lr``, ``momentum``, ``weight_decay``,
            ``resume``, ``max_iter``, ``stat_freq``, ``val_freq``, ``weights``.
    """
    optimizer = optim.SGD(
        net.parameters(),
        lr=config.lr,
        momentum=config.momentum,
        weight_decay=config.weight_decay,
    )
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.95)

    crit = nn.BCEWithLogitsLoss()

    start_iter = 0
    if config.resume is not None:
        checkpoint = torch.load(config.resume)
        print("Resuming weights")
        net.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        start_iter = checkpoint["curr_iter"]

    net.train()
    train_iter = iter(dataloader)
    # val_iter = iter(val_dataloader)
    logging.info(f"LR: {scheduler.get_lr()}")
    for i in range(start_iter, config.max_iter):

        s = time()
        # ``next()`` works with any Python-3 iterator; ``.next()`` does not
        # exist on modern DataLoader iterators.
        data_dict = next(train_iter)
        d = time() - s

        optimizer.zero_grad()
        # Occupancy-only input: constant 1-dim features at the voxel coords.
        sin = ME.SparseTensor(
            features=torch.ones(len(data_dict["coords"]), 1),
            coordinates=data_dict["coords"].int(),
            device=device,
        )

        # Generate target sparse tensor
        target_key = sin.coordinate_map_key

        out_cls, targets, sout, means, log_vars, zs = net(sin, target_key)
        num_layers, BCE = len(out_cls), 0
        # Average the per-level BCE losses over all decoder depths.
        for out_cl, target in zip(out_cls, targets):
            curr_loss = crit(out_cl.F.squeeze(), target.type(out_cl.F.dtype).to(device))
            BCE += curr_loss / num_layers

        # Standard VAE KL divergence between q(z|x) and N(0, I).
        KLD = -0.5 * torch.mean(
            torch.mean(1 + log_vars.F - means.F.pow(2) - log_vars.F.exp(), 1)
        )
        loss = KLD + BCE

        loss.backward()
        optimizer.step()
        t = time() - s

        if i % config.stat_freq == 0:
            logging.info(
                f"Iter: {i}, Loss: {loss.item():.3e}, Depths: {len(out_cls)} Data Loading Time: {d:.3e}, Tot Time: {t:.3e}"
            )

        if i % config.val_freq == 0 and i > 0:
            torch.save(
                {
                    "state_dict": net.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "curr_iter": i,
                },
                config.weights,
            )

            # LR decays once per checkpoint interval, not per iteration.
            scheduler.step()
            logging.info(f"LR: {scheduler.get_lr()}")

            net.train()
  def _train_epoch(self, epoch):
    """Run one training epoch of the triplet-loss feature trainer.

    Gradients are accumulated over ``self.iter_size`` mini-batches before
    each optimizer step (effective batch = iter_size * batch_size).

    Args:
        epoch: 1-based epoch index (used for the tensorboard step offset).
    """
    config = self.config

    gc.collect()
    self.model.train()

    # Epoch starts from 1
    total_loss = 0
    total_num = 0.0
    data_loader = self.data_loader
    # Builtin iterator protocol: ``.__iter__()``/``.next()`` is Python 2
    # only and modern DataLoader iterators have no ``.next()``.
    data_loader_iter = iter(self.data_loader)
    iter_size = self.iter_size
    data_meter, data_timer, total_timer = AverageMeter(), Timer(), Timer()
    pos_dist_meter, neg_dist_meter = AverageMeter(), AverageMeter()
    start_iter = (epoch - 1) * (len(data_loader) // iter_size)
    for curr_iter in range(len(data_loader) // iter_size):
      self.optimizer.zero_grad()
      batch_loss = 0
      data_time = 0
      total_timer.tic()
      for iter_idx in range(iter_size):
        data_timer.tic()
        input_dict = next(data_loader_iter)
        data_time += data_timer.toc(average=False)

        # pairs consist of (xyz1 index, xyz0 index)
        sinput0 = ME.SparseTensor(
            input_dict['sinput0_F'], coords=input_dict['sinput0_C']).to(self.device)
        F0 = self.model(sinput0).F

        sinput1 = ME.SparseTensor(
            input_dict['sinput1_F'], coords=input_dict['sinput1_C']).to(self.device)
        F1 = self.model(sinput1).F

        pos_pairs = input_dict['correspondences']
        loss, pos_dist, neg_dist = self.triplet_loss(
            F0,
            F1,
            pos_pairs,
            num_pos=config.triplet_num_pos * config.batch_size,
            num_hn_samples=config.triplet_num_hn * config.batch_size,
            num_rand_triplet=config.triplet_num_rand * config.batch_size)
        # Scale so the accumulated gradient matches a single large batch.
        loss /= iter_size
        loss.backward()
        batch_loss += loss.item()
        pos_dist_meter.update(pos_dist)
        neg_dist_meter.update(neg_dist)

      self.optimizer.step()
      gc.collect()

      torch.cuda.empty_cache()

      total_loss += batch_loss
      total_num += 1.0
      total_timer.toc()
      data_meter.update(data_time)

      if curr_iter % self.config.stat_freq == 0:
        self.writer.add_scalar('train/loss', batch_loss, start_iter + curr_iter)
        logging.info(
            "Train Epoch: {} [{}/{}], Current Loss: {:.3e}, Pos dist: {:.3e}, Neg dist: {:.3e}"
            .format(epoch, curr_iter,
                    len(self.data_loader) //
                    iter_size, batch_loss, pos_dist_meter.avg, neg_dist_meter.avg) +
            "\tData time: {:.4f}, Train time: {:.4f}, Iter time: {:.4f}".format(
                data_meter.avg, total_timer.avg - data_meter.avg, total_timer.avg))
        pos_dist_meter.reset()
        neg_dist_meter.reset()
        data_meter.reset()
        total_timer.reset()
Beispiel #22
0
def do_train(cfg, model, data_loader, optimizer, scheduler,
             criterion, checkpointer, device, arguments,
             tblogger, data_loader_val, distributed):
    """Main training loop for a sparse segmentation model.

    Iterates once over ``data_loader`` starting at ``arguments['iteration']``,
    logging meters to ``tblogger``, checkpointing periodically, and running
    validation every ``cfg.SOLVER.EVAL_PERIOD`` iterations.  ``arguments`` is
    mutated in place (``iteration``, ``best_iou``) so checkpoints capture
    resume state.
    """
    logger = logging.getLogger('eve.' + __name__)
    meters = MetricLogger(delimiter="  ")
    max_iter = len(data_loader)
    start_iter = arguments['iteration']
    model.train()
    start_training_time = time.time()
    end = time.time()
    logger.info("Start training")
    logger.info("Arguments: {}".format(arguments))

    for iteration, batch in enumerate(data_loader, start_iter):
        model.train()
        # Time spent waiting on the dataloader since the previous step.
        data_time = time.time() - end
        iteration = iteration + 1
        arguments['iteration'] = iteration

        # FIXME: for eve, modify dataloader
        locs, feats, targets, _ = batch
        inputs = ME.SparseTensor(feats, coords=locs).to(device)
        targets = targets.to(device, non_blocking=True).long()
        out = model(inputs, y=targets)

        # Some model variants also return matching diagnostics.
        if len(out) == 2:  # minkunet_eve
            outputs, match = out
        else:
            outputs = out
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if len(out) == 2:  # FIXME
            loss_dict = dict(loss=loss, match_acc=match[0], match_time=match[1])
        else:
            loss_dict = dict(loss=loss)
        # Average metrics across distributed workers before logging.
        loss_dict_reduced = reduce_dict(loss_dict)
        meters.update(**loss_dict_reduced)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data_time=data_time)
        # ETA from the running average step time.
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if tblogger is not None:
            # Timing meters go under "other/", losses under "train/".
            for name, meter in meters.meters.items():
                if 'time' in name:
                    tblogger.add_scalar(
                        'other/' + name, meter.median, iteration)
                else:
                    tblogger.add_scalar(
                        'train/' + name, meter.median, iteration)
            tblogger.add_scalar(
                'other/lr', optimizer.param_groups[0]['lr'], iteration)

        if iteration % cfg.SOLVER.LOG_PERIOD == 0 \
                or iteration == max_iter \
                or iteration == 0:
            logger.info(
                meters.delimiter.join(
                    [
                        "train eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]['lr'],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )

        # NOTE(review): stepping every iteration — scheduler is presumably
        # configured with per-iteration milestones; confirm against caller.
        scheduler.step()

        if iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0:
            checkpointer.save('model_{:06d}'.format(iteration), **arguments)

        # Rolling "last" checkpoint for crash recovery.
        if iteration % 100 == 0:
            checkpointer.save('model_last', **arguments)

        if iteration == max_iter:
            checkpointer.save('model_final', **arguments)

        if iteration % cfg.SOLVER.EVAL_PERIOD == 0 \
                or iteration == max_iter:
            metrics = val_in_train(
                model,
                criterion,
                cfg.DATASETS.VAL,
                data_loader_val,
                tblogger,
                iteration,
                checkpointer,
                distributed)

            # Track the best IoU and snapshot the best model.
            if metrics is not None:
                if arguments['best_iou'] < metrics['iou']:
                    arguments['best_iou'] = metrics['iou']
                    logger.info('best_iou: {}'.format(arguments['best_iou']))
                    checkpointer.save('model_best', **arguments)
                else:
                    logger.info('best_iou: {}'.format(arguments['best_iou']))

            if tblogger is not None:
                tblogger.add_scalar(
                    'val/best_iou', arguments['best_iou'], iteration)

            model.train()

            # Reset the step timer so validation time is not counted as
            # data-loading time of the next iteration.
            end = time.time()

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / (max_iter)
        )
    )
Beispiel #23
0
    criterion = nn.CrossEntropyLoss()
    net = ExampleNetwork(in_feat=3, out_feat=5, D=2)
    print(net)

    # a data loader must return a tuple of coords, features, and labels.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = net.to(device)
    optimizer = SGD(net.parameters(), lr=1e-1)

    for i in range(10):
        optimizer.zero_grad()

        # Get new data
        coords, feat, label = data_loader()
        input = ME.SparseTensor(feat, coords=coords).to(device)
        label = label.to(device)

        # Forward
        output = net(input)

        # Loss
        loss = criterion(output.F, label)
        print('Iteration: ', i, ', Loss: ', loss.item())

        # Gradient
        loss.backward()
        optimizer.step()

    # Saving and loading a network
    torch.save(net.state_dict(), 'test.pth')
Beispiel #24
0
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 activation='relu',
                 inplace=True,
                 order=('conv', 'norm', 'act')):
        """Conv + norm + activation bundle, supporting Minkowski sparse convs.

        Args:
            in_channels / out_channels / kernel_size / stride / padding /
            dilation / groups: forwarded to ``build_conv_layer``.
            bias: 'auto' disables bias when a norm layer follows the conv.
            conv_cfg: dict describing the conv layer; ``type='MinkConv'``
                selects a MinkowskiEngine convolution.
            norm_cfg: dict describing the norm layer, or None.
            activation: 'relu', 'MinkRelu', or None.
            inplace: whether the dense ReLU operates in place.
            order: permutation of ('conv', 'norm', 'act').
        """
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        # Minkowski convs lack padding/output_padding/transposed/groups attrs,
        # so their re-export below must be skipped.
        self.is_mink = conv_cfg is not None and conv_cfg['type'] == 'MinkConv'
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.activation = activation
        self.inplace = inplace
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(['conv', 'norm', 'act'])

        self.with_norm = norm_cfg is not None
        self.with_activatation = activation is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == 'auto':
            bias = False if self.with_norm else True
        self.with_bias = bias

        if self.with_norm and self.with_bias:
            warnings.warn('ConvModule has norm and bias at the same time')

        # build convolution layer
        self.conv = build_conv_layer(conv_cfg,
                                     in_channels,
                                     out_channels,
                                     kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     dilation=dilation,
                                     groups=groups,
                                     bias=bias)
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        if not self.is_mink:
            self.padding = self.conv.padding
            self.output_padding = self.conv.output_padding
            self.transposed = self.conv.transposed
            self.groups = self.conv.groups
        self.dilation = self.conv.dilation

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
            self.add_module(self.norm_name, norm)

        # build activation layer
        if self.with_activatation:
            # TODO: introduce `act_cfg` and supports more activation layers
            if self.activation not in ['relu', 'MinkRelu']:
                raise ValueError('{} is currently not supported.'.format(
                    self.activation))
            if self.activation == 'relu':
                self.activate = nn.ReLU(inplace=inplace)
            elif self.activation == 'MinkRelu':
                import MinkowskiEngine as ME
                # BUG FIX: the activation was created but never assigned, so
                # forward() would raise AttributeError for MinkRelu configs.
                self.activate = ME.MinkowskiReLU(inplace=True)

        # Use msra init by default
        self.init_weights()
Beispiel #25
0
    # Define a model and load the weights
    model = MinkUNet34C(3, 20).to(device)
    model_dict = torch.load(config.weights)
    model.load_state_dict(model_dict)
    model.eval()

    coords, colors, pcd = load_file(config.file_name)
    # Measure time
    with torch.no_grad():
        voxel_size = 0.02
        # Feed-forward pass and get the prediction
        in_field = ME.TensorField(
            features=normalize_color(torch.from_numpy(colors)),
            coordinates=ME.utils.batched_coordinates([coords / voxel_size],
                                                     dtype=torch.float32),
            quantization_mode=ME.SparseTensorQuantizationMode.
            UNWEIGHTED_AVERAGE,
            minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
            device=device,
        )
        # Convert to a sparse tensor
        sinput = in_field.sparse()
        # Output sparse tensor
        soutput = model(sinput)
        # get the prediction on the input tensor field
        out_field = soutput.slice(in_field)
        logits = out_field.F

    _, pred = logits.max(1)
    pred = pred.cpu().numpy()
Beispiel #26
0
    def _forward_and_loss(self, input_dict):
        """Forward both clouds through the model and compute descriptor stats.

        Shared by the train and eval paths of ``inference_one_batch``; the
        caller decides whether gradients are tracked.
        """
        sinput_src = ME.SparseTensor(input_dict['src_F'].to(self.device),
                                     coordinates=input_dict['src_C'].to(
                                         self.device))
        sinput_tgt = ME.SparseTensor(input_dict['tgt_F'].to(self.device),
                                     coordinates=input_dict['tgt_C'].to(
                                         self.device))

        src_feats, tgt_feats, scores_overlap, scores_saliency = self.model(
            sinput_src, sinput_tgt)
        src_pcd = input_dict['pcd_src'].to(self.device)
        tgt_pcd = input_dict['pcd_tgt'].to(self.device)
        c_rot = input_dict['rot'].to(self.device)
        c_trans = input_dict['trans'].to(self.device)
        correspondence = input_dict['correspondences'].long().to(self.device)

        return self.desc_loss(src_pcd, tgt_pcd, src_feats, tgt_feats,
                              correspondence, c_rot, c_trans,
                              scores_overlap, scores_saliency,
                              input_dict['scale'])

    def inference_one_batch(self, input_dict, phase):
        """Run one batch through the registration model.

        In 'train' phase the weighted loss is back-propagated (the caller is
        expected to step the optimizer); in 'val'/'test' the forward pass runs
        under ``torch.no_grad()``.

        Args:
            input_dict: batch dict with src/tgt coords, features, point
                clouds, ground-truth rot/trans, correspondences and scale.
            phase: one of 'train', 'val', 'test'.

        Returns:
            dict of loss statistics with the loss terms detached to floats.
        """
        assert phase in ['train', 'val', 'test']
        if phase == 'train':
            self.model.train()
            # Forward pass with gradients enabled.
            stats = self._forward_and_loss(input_dict)

            c_loss = stats['circle_loss'] * self.w_circle_loss + stats[
                'overlap_loss'] * self.w_overlap_loss + stats[
                    'saliency_loss'] * self.w_saliency_loss

            c_loss.backward()
        else:
            self.model.eval()
            with torch.no_grad():
                stats = self._forward_and_loss(input_dict)

        # detach the gradients for loss terms
        stats['circle_loss'] = float(stats['circle_loss'].detach())
        stats['overlap_loss'] = float(stats['overlap_loss'].detach())
        stats['saliency_loss'] = float(stats['saliency_loss'].detach())

        return stats
Beispiel #27
0
def train(net, device, config):
    """Train ``net`` for sparse-voxel classification.

    Runs ``config.max_iter`` SGD steps over a repeating train loader with
    exponential LR decay, logging every ``config.stat_freq`` iterations and
    checkpointing/validating every ``config.val_freq`` iterations.  Resumes
    from a checkpoint at ``config.weights`` when one exists.

    Args:
        net: the Minkowski network to train (modified in place).
        device: torch device the sparse tensors and labels are moved to.
        config: experiment configuration (lr, momentum, weight_decay,
            batch_size, num_workers, voxel_size, max_iter, stat_freq,
            val_freq, weights path, load_optimizer flag, ...).
    """
    optimizer = optim.SGD(net.parameters(),
                          lr=config.lr,
                          momentum=config.momentum,
                          weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.95)

    crit = torch.nn.CrossEntropyLoss()

    train_dataloader = make_data_loader('train',
                                        augment_data=True,
                                        batch_size=config.batch_size,
                                        shuffle=True,
                                        num_workers=config.num_workers,
                                        repeat=True,
                                        config=config)
    val_dataloader = make_data_loader('val',
                                      augment_data=False,
                                      batch_size=config.batch_size,
                                      shuffle=True,
                                      num_workers=config.num_workers,
                                      repeat=True,
                                      config=config)

    # Resume training state from an existing checkpoint, if any.
    curr_iter = 0
    if os.path.exists(config.weights):
        checkpoint = torch.load(config.weights)
        net.load_state_dict(checkpoint['state_dict'])
        if config.load_optimizer.lower() == 'true':
            curr_iter = checkpoint['curr_iter'] + 1
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])

    net.train()
    train_iter = iter(train_dataloader)
    val_iter = iter(val_dataloader)
    # get_last_lr() is the supported accessor; get_lr() is deprecated when
    # called outside scheduler.step().
    logging.info(f'LR: {scheduler.get_last_lr()}')
    for i in range(curr_iter, config.max_iter):

        s = time()
        # FIX: Python 3 iterators (incl. modern DataLoader iterators) have no
        # .next() method; the builtin next() works everywhere.
        data_dict = next(train_iter)
        d = time() - s

        optimizer.zero_grad()
        # First three coordinate columns are the voxel xyz; scale them back to
        # metric units to use as features, keep the int coords for hashing.
        sin = ME.SparseTensor(
            data_dict['coords'][:, :3] * config.voxel_size,
            data_dict['coords'].int(),
            allow_duplicate_coords=True,  # for classification, it doesn't matter
        ).to(device)
        sout = net(sin)
        loss = crit(sout.F, data_dict['labels'].to(device))
        loss.backward()
        optimizer.step()
        t = time() - s

        if i % config.stat_freq == 0:
            logging.info(
                f'Iter: {i}, Loss: {loss.item():.3e}, Data Loading Time: {d:.3e}, Tot Time: {t:.3e}'
            )

        if i % config.val_freq == 0 and i > 0:
            torch.save(
                {
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'curr_iter': i,
                }, config.weights)

            # Validation
            logging.info('Validation')
            test(net, val_iter, config, 'val')

            scheduler.step()
            logging.info(f'LR: {scheduler.get_last_lr()}')

            # test() puts the net in eval mode; restore training mode.
            net.train()
    def __init__(self,
                 config,
                 in_channel,
                 out_channel,
                 final_dim=96,
                 dimension=3):
        """Build a U-Net style sparse backbone mixing Minkowski convolutions
        with point-transformer (PTBlock) attention stages.

        Encoder: stem convs, then TDLayer downsamplings interleaved with
        PTBlocks at pixel sizes 1/2/4/8/16.  Decoder: ResNetLikeTU upsampling
        with skip connections, PTBlocks 5-7, and a final linear head.

        Args:
            config: experiment config; uses ``config.ks`` (conv kernel size)
                and ``config.xyz_input`` (concatenate xyz channels to input).
            in_channel: input feature channels before optional xyz channels.
            out_channel: output width of the final linear layer.
            final_dim: width of the last decoder stage (default 96).
            dimension: spatial dimension of the sparse tensors (3).
        """
        ME.MinkowskiNetwork.__init__(self, dimension)
        # The normal channel for Modelnet is 3, for scannet is 6, for scanobjnn is 0
        normal_channel = 3

        self.CONV_TYPE = ConvType.SPATIAL_HYPERCUBE

        # Per-level feature widths of the encoder.
        self.dims = np.array([32, 64, 128, 256])
        # Attention neighbourhood size per level.
        # (Other settings previously tried: 4, 8, 12, 24, 32 per level.)
        self.neighbor_ks = np.array([16, 16, 16, 16])

        self.final_dim = final_dim

        stem_dim = self.dims[0]

        # Optionally prepend xyz (normal) channels to the raw input features.
        # (The former no-op else branch was removed.)
        if config.xyz_input:
            in_channel = normal_channel + in_channel

        # pixel size 1
        self.stem1 = nn.Sequential(
            ME.MinkowskiConvolution(in_channel,
                                    stem_dim,
                                    kernel_size=config.ks,
                                    dimension=3),
            ME.MinkowskiBatchNorm(stem_dim),
            ME.MinkowskiReLU(),
        )

        # Strided conv: downsample to pixel size 2.
        self.stem2 = nn.Sequential(
            ME.MinkowskiConvolution(stem_dim,
                                    stem_dim,
                                    kernel_size=3,
                                    dimension=3,
                                    stride=2),
            ME.MinkowskiBatchNorm(stem_dim),
            ME.MinkowskiReLU(),
        )

        # window_beta=None disables attention windowing in the PTBlocks
        # (a value of 5 was previously tried).
        window_beta = None
        base_r = 5  # base attention radius, scaled per level below
        self.ALPHA_BLENDING_MID_TR = False
        self.ALPHA_BLENDING_FIRST_TR = True

        # Encoder transformer blocks for levels 2-4 (PTBlock1 is replaced by
        # a conv residual layer further down).
        self.PTBlock2 = PTBlock(in_dim=self.dims[1],
                                hidden_dim=self.dims[1],
                                n_sample=self.neighbor_ks[1],
                                skip_attn=False,
                                r=2 * base_r,
                                kernel_size=config.ks,
                                window_beta=window_beta)
        self.PTBlock3 = PTBlock(in_dim=self.dims[2],
                                hidden_dim=self.dims[2],
                                n_sample=self.neighbor_ks[2],
                                skip_attn=False,
                                r=2 * base_r,
                                kernel_size=config.ks,
                                window_beta=window_beta)
        self.PTBlock4 = PTBlock(in_dim=self.dims[3],
                                hidden_dim=self.dims[3],
                                n_sample=self.neighbor_ks[3],
                                skip_attn=False,
                                r=4 * base_r,
                                kernel_size=config.ks,
                                window_beta=window_beta)

        # Optional alpha-blending branches (skip-attention variants).
        if self.ALPHA_BLENDING_MID_TR:
            self.PTBlock4_branch = PTBlock(in_dim=self.dims[3],
                                           hidden_dim=self.dims[3],
                                           n_sample=self.neighbor_ks[3],
                                           skip_attn=True,
                                           r=4 * base_r,
                                           kernel_size=config.ks)
            self.PTBlock5_branch = PTBlock(in_dim=128,
                                           hidden_dim=128,
                                           n_sample=self.neighbor_ks[3],
                                           skip_attn=True,
                                           r=2 * base_r,
                                           kernel_size=config.ks)  # out: 256

        if self.ALPHA_BLENDING_FIRST_TR:
            self.PTBlock1_branch = PTBlock(in_dim=self.dims[0],
                                           hidden_dim=self.dims[0],
                                           n_sample=self.neighbor_ks[0],
                                           skip_attn=False,
                                           r=base_r,
                                           kernel_size=config.ks,
                                           window_beta=window_beta)

        # Decoder transformer blocks.
        self.PTBlock5 = PTBlock(in_dim=128,
                                hidden_dim=128,
                                n_sample=self.neighbor_ks[3],
                                skip_attn=False,
                                r=2 * base_r,
                                kernel_size=config.ks,
                                window_beta=window_beta)  # out: 256
        self.PTBlock6 = PTBlock(in_dim=128,
                                hidden_dim=128,
                                n_sample=self.neighbor_ks[2],
                                skip_attn=False,
                                r=2 * base_r,
                                kernel_size=config.ks,
                                window_beta=window_beta)  # out: 128
        self.PTBlock7 = PTBlock(in_dim=96,
                                hidden_dim=96,
                                n_sample=self.neighbor_ks[1],
                                skip_attn=False,
                                r=base_r,
                                kernel_size=config.ks,
                                window_beta=window_beta)  # out: 64

        # First stage uses a plain conv residual layer instead of attention
        # (SingleConv was previously tried as BLOCK_TYPE).
        BLOCK_TYPE = BasicBlock

        self.PTBlock1 = self._make_layer(block=BLOCK_TYPE,
                                         inplanes=self.dims[0],
                                         planes=self.dims[0],
                                         num_blocks=1)

        # pixel size 2
        self.TDLayer1 = TDLayer(input_dim=self.dims[0],
                                out_dim=self.dims[1])  # strided conv

        # pixel size 4
        self.TDLayer2 = TDLayer(input_dim=self.dims[1], out_dim=self.dims[2])

        # pixel size 8
        self.TDLayer3 = TDLayer(input_dim=self.dims[2], out_dim=self.dims[3])

        # pixel size 16: PTBlock4

        # Decoder upsampling with skip connections (ResNet-like transition-up).
        # pixel size 8
        self.TULayer5 = ResNetLikeTU(input_a_dim=256,
                                     input_b_dim=128,
                                     out_dim=128)  # out: 256//2 + 128 = 256

        # pixel size 4
        self.TULayer6 = ResNetLikeTU(input_a_dim=128,
                                     input_b_dim=64,
                                     out_dim=128)  # out: 256//2 + 64 = 192

        # pixel size 2
        self.TULayer7 = ResNetLikeTU(input_a_dim=128,
                                     input_b_dim=32,
                                     out_dim=96)  # 128 // 2 + 32 = 96

        # pixel size 1
        self.TULayer8 = ResNetLikeTU(input_a_dim=96,
                                     input_b_dim=32,
                                     out_dim=96)  # 64 // 2 + 32
        self.fc = ME.MinkowskiLinear(96, out_channel)
  def __init__(self,
               in_channels=3,
               out_channels=32,
               bn_momentum=0.1,
               normalize_feature=None,
               conv1_kernel_size=None,
               D=3):
    """Build a ResUNet-style sparse fully-convolutional feature extractor.

    Encoder: conv1 (stride 1) then conv2-conv4 (stride 2), each followed by
    normalization and a residual block, halving resolution three times.
    Decoder: transposed convs conv4_tr-conv2_tr upsample back; the
    ``CHANNELS[i] + TR_CHANNELS[i+1]`` input widths indicate the decoder
    concatenates encoder skip features (presumably in forward(), which is
    outside this view).  A 1x1 conv1_tr and a biased 1x1 ``final`` conv emit
    ``out_channels`` descriptor features.

    NOTE(review): channel plans (CHANNELS / TR_CHANNELS) and norm types come
    from class attributes defined outside this view — subclasses configure
    them.  ``has_bias=`` is the pre-0.5 MinkowskiEngine keyword (newer
    releases renamed it ``bias=``); confirm the pinned ME version.
    """
    ME.MinkowskiNetwork.__init__(self, D)
    # Pull the per-subclass configuration into locals for brevity.
    NORM_TYPE = self.NORM_TYPE
    BLOCK_NORM_TYPE = self.BLOCK_NORM_TYPE
    CHANNELS = self.CHANNELS
    TR_CHANNELS = self.TR_CHANNELS
    # Whether output features are L2-normalized (used elsewhere, not here).
    self.normalize_feature = normalize_feature
    # Full-resolution stem; kernel size is caller-provided (required arg).
    self.conv1 = ME.MinkowskiConvolution(
        in_channels=in_channels,
        out_channels=CHANNELS[1],
        kernel_size=conv1_kernel_size,
        stride=1,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D)

    self.block1 = get_block(
        BLOCK_NORM_TYPE, CHANNELS[1], CHANNELS[1], bn_momentum=bn_momentum, D=D)

    # Encoder level 2: downsample by 2.
    self.conv2 = ME.MinkowskiConvolution(
        in_channels=CHANNELS[1],
        out_channels=CHANNELS[2],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D)

    self.block2 = get_block(
        BLOCK_NORM_TYPE, CHANNELS[2], CHANNELS[2], bn_momentum=bn_momentum, D=D)

    # Encoder level 3: downsample by 2.
    self.conv3 = ME.MinkowskiConvolution(
        in_channels=CHANNELS[2],
        out_channels=CHANNELS[3],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D)

    self.block3 = get_block(
        BLOCK_NORM_TYPE, CHANNELS[3], CHANNELS[3], bn_momentum=bn_momentum, D=D)

    # Encoder level 4 (bottleneck): downsample by 2.
    self.conv4 = ME.MinkowskiConvolution(
        in_channels=CHANNELS[3],
        out_channels=CHANNELS[4],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D)

    self.block4 = get_block(
        BLOCK_NORM_TYPE, CHANNELS[4], CHANNELS[4], bn_momentum=bn_momentum, D=D)

    # Decoder level 4: upsample by 2.
    self.conv4_tr = ME.MinkowskiConvolutionTranspose(
        in_channels=CHANNELS[4],
        out_channels=TR_CHANNELS[4],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D)

    self.block4_tr = get_block(
        BLOCK_NORM_TYPE, TR_CHANNELS[4], TR_CHANNELS[4], bn_momentum=bn_momentum, D=D)

    # Decoder level 3: input width includes the encoder skip features.
    self.conv3_tr = ME.MinkowskiConvolutionTranspose(
        in_channels=CHANNELS[3] + TR_CHANNELS[4],
        out_channels=TR_CHANNELS[3],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D)

    self.block3_tr = get_block(
        BLOCK_NORM_TYPE, TR_CHANNELS[3], TR_CHANNELS[3], bn_momentum=bn_momentum, D=D)

    # Decoder level 2: input width includes the encoder skip features.
    self.conv2_tr = ME.MinkowskiConvolutionTranspose(
        in_channels=CHANNELS[2] + TR_CHANNELS[3],
        out_channels=TR_CHANNELS[2],
        kernel_size=3,
        stride=2,
        dilation=1,
        has_bias=False,
        dimension=D)
    self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D)

    self.block2_tr = get_block(
        BLOCK_NORM_TYPE, TR_CHANNELS[2], TR_CHANNELS[2], bn_momentum=bn_momentum, D=D)

    # Full-resolution 1x1 fusion conv over skip + decoder features.
    self.conv1_tr = ME.MinkowskiConvolution(
        in_channels=CHANNELS[1] + TR_CHANNELS[2],
        out_channels=TR_CHANNELS[1],
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=False,
        dimension=D)

    # self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, D=D)

    # Output head: 1x1 conv, the only biased layer in the network.
    self.final = ME.MinkowskiConvolution(
        in_channels=TR_CHANNELS[1],
        out_channels=out_channels,
        kernel_size=1,
        stride=1,
        dilation=1,
        has_bias=True,
        dimension=D)
# Beispiel #30 (Example #30)
# 0
def train(net, dataloader, device, config):
    """Train a sparse generative/completion network with per-layer BCE loss.

    Each iteration builds a one-channel (all ones) sparse input from the
    batch coordinates, derives the target coordinate key from the
    ground-truth point sets, and averages the BCE-with-logits losses of the
    network's per-layer classification outputs (deep supervision).

    Args:
        net: network taking (sparse input, target coords key) and returning
            (per-layer logits, per-layer targets, final sparse output).
        dataloader: repeating loader yielding dicts with 'coords' and 'xyzs'.
        device: torch device for loss targets / sparse tensors.
        config: experiment configuration (lr, momentum, weight_decay,
            max_iter, stat_freq, val_freq, weights path, ...).
    """
    optimizer = optim.SGD(net.parameters(),
                          lr=config.lr,
                          momentum=config.momentum,
                          weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.95)

    crit = nn.BCEWithLogitsLoss()

    net.train()
    train_iter = iter(dataloader)
    # val_iter = iter(val_dataloader)
    # get_last_lr() is the supported accessor; get_lr() is deprecated when
    # called outside scheduler.step().
    logging.info(f'LR: {scheduler.get_last_lr()}')
    for i in range(config.max_iter):

        s = time()
        # FIX: Python 3 iterators (incl. modern DataLoader iterators) have no
        # .next() method; the builtin next() works everywhere.
        data_dict = next(train_iter)
        d = time() - s

        optimizer.zero_grad()

        # Occupancy-only input: one constant feature channel per coordinate.
        in_feat = torch.ones((len(data_dict['coords']), 1))

        sin = ME.SparseTensor(
            feats=in_feat,
            coords=data_dict['coords'],
        ).to(device)

        # Generate target sparse tensor key from the ground-truth coordinates.
        cm = sin.coords_man
        target_key = cm.create_coords_key(ME.utils.batched_coordinates(
            data_dict['xyzs']),
                                          force_creation=True,
                                          allow_duplicate_coords=True)

        # Deep supervision: average BCE loss across all pruning layers.
        out_cls, targets, sout = net(sin, target_key)
        num_layers, loss = len(out_cls), 0
        losses = []
        for out_cl, target in zip(out_cls, targets):
            curr_loss = crit(out_cl.F.squeeze(),
                             target.type(out_cl.F.dtype).to(device))
            losses.append(curr_loss.item())
            loss += curr_loss / num_layers

        loss.backward()
        optimizer.step()
        t = time() - s

        if i % config.stat_freq == 0:
            logging.info(
                f'Iter: {i}, Loss: {loss.item():.3e}, Data Loading Time: {d:.3e}, Tot Time: {t:.3e}'
            )

        if i % config.val_freq == 0 and i > 0:
            torch.save(
                {
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'curr_iter': i,
                }, config.weights)

            scheduler.step()
            logging.info(f'LR: {scheduler.get_last_lr()}')

            net.train()