Example #1
def main():
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    print("\n**************************")
    for k, v in config['common'].items():
        setattr(args, k, v)
        print('\n[%s]:'%(k), v)
    print("\n**************************\n")
    
    try:
        os.makedirs(args.save_path)
    except OSError:
        pass
    
    train_transforms = transforms.Compose([
        d_utils.PointcloudToTensor()
    ])
    test_transforms = transforms.Compose([
        d_utils.PointcloudToTensor()
    ])
    
    train_dataset = ShapeNetPart(root = args.data_root, num_points = args.num_points, split = 'trainval', normalize = True, transforms = train_transforms)
    train_dataloader = DataLoader(
        train_dataset, 
        batch_size=args.batch_size,
        shuffle=True, 
        num_workers=int(args.workers), 
        pin_memory=True
    )
    
    global test_dataset
    test_dataset = ShapeNetPart(root = args.data_root, num_points = args.num_points, split = 'test', normalize = True, transforms = test_transforms)
    test_dataloader = DataLoader(
        test_dataset, 
        batch_size=args.batch_size,
        shuffle=False, 
        num_workers=int(args.workers), 
        pin_memory=True
    )
    
    model = RSCNN_MSN(num_classes = args.num_classes, input_channels = args.input_channels, relation_prior = args.relation_prior, use_xyz = True)
    model.cuda()
    optimizer = optim.Adam(
        model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay)

    lr_lbmd = lambda e: max(args.lr_decay**(e // args.decay_step), args.lr_clip / args.base_lr)
    bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay**(e // args.decay_step), args.bnm_clip)
    lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd)
    bnm_scheduler = pt_utils.BNMomentumScheduler(model, bnm_lmbd)
    
    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    criterion = nn.CrossEntropyLoss()
    num_batch = len(train_dataset)/args.batch_size
    
    # training
    train(train_dataloader, test_dataloader, model, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch)
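
The lr_lbmd / bnm_lmbd lambdas above implement step-wise exponential decay with a floor: LambdaLR multiplies the base learning rate by lr_lbmd(epoch), and the batch-norm momentum decays the same way. A standalone sketch of what they produce (the constants are illustrative assumptions, not values from the repo's config):

# Sketch of the decay schedules above; the constants are illustrative assumptions.
base_lr, lr_decay, decay_step, lr_clip = 0.001, 0.7, 20, 1e-5
bn_momentum, bn_decay, bnm_clip = 0.9, 0.5, 0.01

lr_lbmd = lambda e: max(lr_decay ** (e // decay_step), lr_clip / base_lr)
bnm_lmbd = lambda e: max(bn_momentum * bn_decay ** (e // decay_step), bnm_clip)

for epoch in (0, 20, 60, 200):
    # LambdaLR scales base_lr by the lambda's value; both schedules bottom out at their clip
    print(epoch, base_lr * lr_lbmd(epoch), bnm_lmbd(epoch))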
Example #2

def main():
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    for k, v in config['common'].items():
        setattr(args, k, v)
    
    test_transforms = transforms.Compose([
        d_utils.PointcloudToTensor()
    ])
    
    test_dataset = ShapeNetPart(root = args.data_root, num_points = args.num_points, split = 'test', normalize = True, transforms = test_transforms)
    test_dataloader = DataLoader(
        test_dataset, 
        batch_size=args.batch_size,
        shuffle=False, 
        num_workers=int(args.workers), 
        pin_memory=True
    )
    
    model = RSCNN_MSN(num_classes = args.num_classes, input_channels = args.input_channels, relation_prior = args.relation_prior, use_xyz = True)
    model.cuda()

    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    # evaluate
    PointcloudScale = d_utils.PointcloudScale(scale_low=0.87, scale_high=1.15)   # initialize random scaling
    model.eval()
    global_Class_mIoU, global_Inst_mIoU = 0, 0
    seg_classes = test_dataset.seg_classes
    seg_label_to_cat = {}           # {0:Airplane, 1:Airplane, ...49:Table}
    for cat in seg_classes.keys():
        for label in seg_classes[cat]:
            seg_label_to_cat[label] = cat
    
    for i in range(NUM_REPEAT):
        shape_ious = {cat:[] for cat in seg_classes.keys()}
        for _, data in enumerate(test_dataloader, 0):
            points, target, cls = data
            points, target = Variable(points, volatile=True), Variable(target, volatile=True)  # volatile Variables disable autograd (legacy PyTorch <= 0.3 API)
            points, target = points.cuda(), target.cuda()

            batch_one_hot_cls = np.zeros((len(cls), 16))   # 16 object classes
            for b in range(len(cls)):
                batch_one_hot_cls[b, int(cls[b])] = 1
            batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls)
            batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda())

            pred = 0
            new_points = Variable(torch.zeros(points.size()[0], points.size()[1], points.size()[2]).cuda(), volatile=True)
            for v in range(NUM_VOTE):
                # first vote uses the original points; later votes use randomly rescaled copies
                if v == 0:
                    new_points.data = points.data
                else:
                    new_points.data = PointcloudScale(points.data)
                pred += F.softmax(model(new_points, batch_one_hot_cls), dim = 2)
            pred /= NUM_VOTE
            
            pred = pred.data.cpu()
            target = target.data.cpu()
            pred_val = torch.zeros(len(cls), args.num_points).type(torch.LongTensor)
            # pred to the groundtruth classes (selected by seg_classes[cat])
            for b in range(len(cls)):
                cat = seg_label_to_cat[target[b, 0]]
                logits = pred[b, :, :]   # (num_points, num_classes)
                pred_val[b, :] = logits[:, seg_classes[cat]].max(1)[1] + seg_classes[cat][0]
            
            for b in range(len(cls)):
                segp = pred_val[b, :]
                segl = target[b, :]
                cat = seg_label_to_cat[segl[0]]
                part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
                for l in seg_classes[cat]:
                    if torch.sum((segl == l) | (segp == l)) == 0:
                        # part is not present in this shape
                        part_ious[l - seg_classes[cat][0]] = 1.0
                    else:
                        part_ious[l - seg_classes[cat][0]] = torch.sum((segl == l) & (segp == l)) / float(torch.sum((segl == l) | (segp == l)))
                shape_ious[cat].append(np.mean(part_ious))
        
        instance_ious = []
        for cat in shape_ious.keys():
            for iou in shape_ious[cat]:
                instance_ious.append(iou)
            shape_ious[cat] = np.mean(shape_ious[cat])
        mean_class_ious = np.mean(list(shape_ious.values()))
        
        print('\n------ Repeat %3d ------' % (i + 1))
        for cat in sorted(shape_ious.keys()):
            print('%s: %0.6f'%(cat, shape_ious[cat]))
        print('Class_mIoU: %0.6f' % (mean_class_ious))
        print('Instance_mIoU: %0.6f' % (np.mean(instance_ious)))

        if mean_class_ious > global_Class_mIoU:
            global_Class_mIoU = mean_class_ious
            global_Inst_mIoU = np.mean(instance_ious)
                
    print('\nBest voting Class_mIoU = %0.6f, Instance_mIoU = %0.6f' % (global_Class_mIoU, global_Inst_mIoU))
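
The inner loop above follows the standard ShapeNet part-IoU convention: IoU is averaged over the part labels belonging to the shape's category, and a part absent from both prediction and ground truth counts as 1.0. A self-contained NumPy version of that per-shape computation (the function name and toy labels are mine):

import numpy as np

def shape_part_iou(segp, segl, part_labels):
    """Mean IoU over one shape's parts; an absent part counts as 1.0, as above."""
    ious = []
    for l in part_labels:
        union = np.sum((segl == l) | (segp == l))
        if union == 0:
            ious.append(1.0)  # part not present in prediction or ground truth
        else:
            ious.append(np.sum((segl == l) & (segp == l)) / float(union))
    return np.mean(ious)

# toy shape of 6 points whose category owns part labels {0, 1, 2}
print(shape_part_iou(np.array([0, 0, 1, 1, 1, 0]),
                     np.array([0, 0, 1, 1, 0, 0]),
                     [0, 1, 2]))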
Example #3
def train(args, io):
    train_dataset = ShapeNetPart(partition='trainval',
                                 num_points=args.num_points,
                                 class_choice=args.class_choice)
    # drop the last incomplete batch only when the dataset is large enough
    drop_last = len(train_dataset) >= 100
    train_loader = DataLoader(train_dataset,
                              num_workers=8,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=drop_last)
    test_loader = DataLoader(ShapeNetPart(partition='test',
                                          num_points=args.num_points,
                                          class_choice=args.class_choice),
                             num_workers=8,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             drop_last=False)

    device = torch.device("cuda" if args.cuda else "cpu")
    io.cprint("Let's use" + str(torch.cuda.device_count()) + "GPUs!")

    seg_num_all = train_loader.dataset.seg_num_all
    seg_start_index = train_loader.dataset.seg_start_index

    # create model
    model = CurveNet().to(device)
    model = nn.DataParallel(model)

    if args.use_sgd:
        print("Use SGD")
        opt = optim.SGD(model.parameters(),
                        lr=args.lr * 100,
                        momentum=args.momentum,
                        weight_decay=1e-4)
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)

    if args.scheduler == 'cos':
        scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=1e-3)
    elif args.scheduler == 'step':
        scheduler = MultiStepLR(opt, [140, 180], gamma=0.1)
    criterion = cal_loss

    best_test_iou = 0
    for epoch in range(args.epochs):
        ####################
        # Train
        ####################
        train_loss = 0.0
        count = 0.0
        model.train()
        train_true_cls = []
        train_pred_cls = []
        train_true_seg = []
        train_pred_seg = []
        train_label_seg = []
        for data, label, seg in train_loader:
            seg = seg - seg_start_index
            label_one_hot = np.zeros((label.shape[0], 16))
            for idx in range(label.shape[0]):
                label_one_hot[idx, label[idx]] = 1
            label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
            data, label_one_hot, seg = data.to(device), label_one_hot.to(
                device), seg.to(device)
            data = data.permute(0, 2, 1)
            batch_size = data.size()[0]
            opt.zero_grad()
            seg_pred = model(data, label_one_hot)
            seg_pred = seg_pred.permute(0, 2, 1).contiguous()
            loss = criterion(seg_pred.view(-1, seg_num_all),
                             seg.view(-1, 1).squeeze())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            opt.step()
            pred = seg_pred.max(dim=2)[1]  # (batch_size, num_points)
            count += batch_size
            train_loss += loss.item() * batch_size
            seg_np = seg.cpu().numpy()  # (batch_size, num_points)
            pred_np = pred.detach().cpu().numpy()  # (batch_size, num_points)
            train_true_cls.append(
                seg_np.reshape(-1))  # (batch_size * num_points)
            train_pred_cls.append(
                pred_np.reshape(-1))  # (batch_size * num_points)
            train_true_seg.append(seg_np)
            train_pred_seg.append(pred_np)
            train_label_seg.append(label.reshape(-1))
        if args.scheduler == 'cos':
            scheduler.step()
        elif args.scheduler == 'step':
            if opt.param_groups[0]['lr'] > 1e-5:
                scheduler.step()
            if opt.param_groups[0]['lr'] < 1e-5:
                for param_group in opt.param_groups:
                    param_group['lr'] = 1e-5
        train_true_cls = np.concatenate(train_true_cls)
        train_pred_cls = np.concatenate(train_pred_cls)
        train_acc = metrics.accuracy_score(train_true_cls, train_pred_cls)
        avg_per_class_acc = metrics.balanced_accuracy_score(
            train_true_cls, train_pred_cls)
        train_true_seg = np.concatenate(train_true_seg, axis=0)
        train_pred_seg = np.concatenate(train_pred_seg, axis=0)
        train_label_seg = np.concatenate(train_label_seg)
        train_ious = calculate_shape_IoU(train_pred_seg, train_true_seg,
                                         train_label_seg, args.class_choice)
        outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f, train iou: %.6f' % (
            epoch, train_loss * 1.0 / count, train_acc, avg_per_class_acc,
            np.mean(train_ious))
        io.cprint(outstr)

        ####################
        # Test
        ####################
        test_loss = 0.0
        count = 0.0
        model.eval()
        test_true_cls = []
        test_pred_cls = []
        test_true_seg = []
        test_pred_seg = []
        test_label_seg = []
        for data, label, seg in test_loader:
            seg = seg - seg_start_index
            label_one_hot = np.zeros((label.shape[0], 16))
            for idx in range(label.shape[0]):
                label_one_hot[idx, label[idx]] = 1
            label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
            data, label_one_hot, seg = data.to(device), label_one_hot.to(
                device), seg.to(device)
            data = data.permute(0, 2, 1)
            batch_size = data.size()[0]
            seg_pred = model(data, label_one_hot)
            seg_pred = seg_pred.permute(0, 2, 1).contiguous()
            loss = criterion(seg_pred.view(-1, seg_num_all),
                             seg.view(-1, 1).squeeze())
            pred = seg_pred.max(dim=2)[1]
            count += batch_size
            test_loss += loss.item() * batch_size
            seg_np = seg.cpu().numpy()
            pred_np = pred.detach().cpu().numpy()
            test_true_cls.append(seg_np.reshape(-1))
            test_pred_cls.append(pred_np.reshape(-1))
            test_true_seg.append(seg_np)
            test_pred_seg.append(pred_np)
            test_label_seg.append(label.reshape(-1))
        test_true_cls = np.concatenate(test_true_cls)
        test_pred_cls = np.concatenate(test_pred_cls)
        test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls)
        avg_per_class_acc = metrics.balanced_accuracy_score(
            test_true_cls, test_pred_cls)
        test_true_seg = np.concatenate(test_true_seg, axis=0)
        test_pred_seg = np.concatenate(test_pred_seg, axis=0)
        test_label_seg = np.concatenate(test_label_seg)
        test_ious = calculate_shape_IoU(test_pred_seg, test_true_seg,
                                        test_label_seg, args.class_choice)
        outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f, test iou: %.6f, best iou %.6f' % (
            epoch, test_loss * 1.0 / count, test_acc, avg_per_class_acc,
            np.mean(test_ious), best_test_iou)
        io.cprint(outstr)
        if np.mean(test_ious) >= best_test_iou:
            best_test_iou = np.mean(test_ious)
            torch.save(model.state_dict(),
                       '../checkpoints/%s/models/model.t7' % args.exp_name)
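
calculate_shape_IoU is imported from elsewhere in this project and not shown. A simplified, hypothetical stand-in consistent with how it is called above (one IoU per shape; the per-category part counts are ShapeNetPart's standard 16-category / 50-part split; class_choice handling and the eva flag used in Example #4 are omitted):

import numpy as np

SEG_NUM = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]  # parts per category
INDEX_START = np.cumsum([0] + SEG_NUM[:-1]).tolist()        # first part label of each category

def calculate_shape_IoU(pred_seg, true_seg, labels, class_choice=None):
    """Sketch: mean part IoU per shape, given (B, N) label arrays and (B,) category ids."""
    shape_ious = []
    for b in range(true_seg.shape[0]):
        cat = int(labels[b])
        parts = range(INDEX_START[cat], INDEX_START[cat] + SEG_NUM[cat])
        part_ious = []
        for part in parts:
            union = np.sum((true_seg[b] == part) | (pred_seg[b] == part))
            inter = np.sum((true_seg[b] == part) & (pred_seg[b] == part))
            part_ious.append(1.0 if union == 0 else inter / float(union))
        shape_ious.append(np.mean(part_ious))
    return shape_ious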
Example #4
def test(args, io):
    test_loader = DataLoader(ShapeNetPart(partition='test',
                                          num_points=args.num_points,
                                          class_choice=args.class_choice),
                             batch_size=args.test_batch_size,
                             shuffle=True,
                             drop_last=False)

    device = torch.device("cuda" if args.cuda else "cpu")

    #Try to load models
    seg_start_index = test_loader.dataset.seg_start_index
    model = CurveNet().to(device)
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model_path))

    model = model.eval()
    test_acc = 0.0
    test_true_cls = []
    test_pred_cls = []
    test_true_seg = []
    test_pred_seg = []
    test_label_seg = []
    category = {}
    for data, label, seg in test_loader:
        seg = seg - seg_start_index
        label_one_hot = np.zeros((label.shape[0], 16))
        for idx in range(label.shape[0]):
            label_one_hot[idx, label[idx]] = 1
        label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
        data, label_one_hot, seg = data.to(device), label_one_hot.to(
            device), seg.to(device)
        data = data.permute(0, 2, 1)
        seg_pred = model(data, label_one_hot)
        seg_pred = seg_pred.permute(0, 2, 1).contiguous()
        pred = seg_pred.max(dim=2)[1]
        seg_np = seg.cpu().numpy()
        pred_np = pred.detach().cpu().numpy()
        test_true_cls.append(seg_np.reshape(-1))
        test_pred_cls.append(pred_np.reshape(-1))
        test_true_seg.append(seg_np)
        test_pred_seg.append(pred_np)
        test_label_seg.append(label.reshape(-1))

    test_true_cls = np.concatenate(test_true_cls)
    test_pred_cls = np.concatenate(test_pred_cls)
    test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls)
    avg_per_class_acc = metrics.balanced_accuracy_score(
        test_true_cls, test_pred_cls)
    test_true_seg = np.concatenate(test_true_seg, axis=0)
    test_pred_seg = np.concatenate(test_pred_seg, axis=0)
    test_label_seg = np.concatenate(test_label_seg)
    test_ious, category = calculate_shape_IoU(test_pred_seg,
                                              test_true_seg,
                                              test_label_seg,
                                              args.class_choice,
                                              eva=True)
    outstr = 'Test :: test acc: %.6f, test avg acc: %.6f, test iou: %.6f' % (
        test_acc, avg_per_class_acc, np.mean(test_ious))
    io.cprint(outstr)
    results = []
    for key in category.keys():
        results.append((int(key), np.mean(category[key]), len(category[key])))
    results.sort(key=lambda x: x[0])
    for re in results:
        io.cprint('idx: %d mIoU: %.3f num: %d' % (re[0], re[1], re[2]))
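
None of these snippets define io; DGCNN-style scripts conventionally pass a small wrapper that prints to the console and appends to a log file. A minimal sketch (the class name and constructor argument are assumptions):

class IOStream:
    """Console + log-file printer of the kind passed around as `io` above."""
    def __init__(self, path):
        self.f = open(path, 'a')

    def cprint(self, text):
        print(text)
        self.f.write(text + '\n')
        self.f.flush()

    def close(self):
        self.f.close()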
Example #5
File: train.py Project: cyysc1998/pt3d
def train(args, configpath):
    io = init(args, configpath)
    train_dataset = ShapeNetPart(partition='trainval',
                                 num_points=args.num_points)
    # drop the last incomplete batch only when the dataset is large enough
    drop_last = len(train_dataset) >= 100
    train_loader = DataLoader(train_dataset,
                              num_workers=8,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=drop_last)
    test_loader = DataLoader(ShapeNetPart(partition='test',
                                          num_points=args.num_points),
                             num_workers=8,
                             batch_size=args.test_batch_size,
                             shuffle=True,
                             drop_last=False)

    seg_num_all = train_loader.dataset.seg_num_all
    seg_start_index = train_loader.dataset.seg_start_index

    device = torch.device("cuda" if args.cuda else "cpu")

    if args.model == 'consnet':
        model = ConsNet(args, seg_num_all).to(device)
    elif args.model == 'pretrain':
        model = ConsNet(args, seg_num_all).to(device)
        model.load_state_dict(torch.load(args.pretrain_path))
    else:
        raise Exception("Not implemented")

    if args.parallel:
        model = nn.DataParallel(model)

    print(str(model))

    if args.use_sgd:
        print("Use SGD")
        opt = optim.SGD(model.parameters(),
                        lr=args.lr * 100,
                        momentum=args.momentum,
                        weight_decay=1e-4)
        cur_lr = args.lr * 100
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
        cur_lr = args.lr

    if args.scheduler == 'cos':
        scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=1e-3)
        print('Use COS')
    elif args.scheduler == 'step':
        scheduler = StepLR(opt, step_size=20, gamma=0.7)
        print('Use Step')

    if args.loss == 'l1loss':
        print('Use L1 Loss')
    elif args.loss == 'chamfer':
        print('Use Chamfer Distance')
    else:
        print('Not implemented')

    io.cprint('Experiment: %s' % args.exp_name)

    # Train
    min_loss = 100
    io.cprint('Begin to train...')
    for epoch in range(args.epochs):
        io.cprint(
            '=====================================Epoch %d========================================'
            % epoch)
        io.cprint('*****Train*****')
        # Train
        model.train()
        train_loss = 0
        for i, point in enumerate(train_loader):
            data, label, seg = point
            if epoch < 5:
                lr = 0.18 * cur_lr * epoch + 0.1 * cur_lr
                adjust_learning_rate(opt, lr)

            if args.task == '1obj_rotate':
                data1, data2, label1, label2 = obj_rotate_perm(
                    data, label)  # (B, N, 3)
            elif args.task == '2obj':
                data1, data2, label1, label2 = obj_2_perm(data,
                                                          label)  # (B, N, 3)
            elif args.task == 'alter':
                if epoch % 2 == 0:
                    data1, data2, label1, label2 = obj_rotate_perm(
                        data, label)  # (B, N, 3)
                else:
                    data1, data2, label1, label2 = obj_2_perm(
                        data, label)  # (B, N, 3)
            else:
                print('Task not implemented!')
                exit(0)

            if args.mixup == 'emd':
                mixup_data = emd_mixup(data1, data2)  # (B, N, 3)
            elif args.mixup == 'add':
                mixup_data = add_mixup(data1, data2)  # (B, N, 3)
            else:
                raise NotImplementedError  # guard: mixup_data would otherwise be undefined

            mixup_data = mixup_data.permute(0, 2, 1)  # (B, 3, N)
            batch_size = mixup_data.size()[0]

            seg = seg - seg_start_index

            if args.use_one_hot:
                label_one_hot1 = np.zeros((batch_size, 16))
                label_one_hot2 = np.zeros((batch_size, 16))
                for idx in range(batch_size):
                    label_one_hot1[idx, label1[idx]] = 1
                    label_one_hot2[idx, label2[idx]] = 1

                label_one_hot1 = torch.from_numpy(
                    label_one_hot1.astype(np.float32))
                label_one_hot2 = torch.from_numpy(
                    label_one_hot2.astype(np.float32))
            else:
                label_one_hot1 = torch.rand(batch_size, 16)
                label_one_hot2 = torch.rand(batch_size, 16)

            data, label_one_hot1, label_one_hot2, seg = data.to(
                device), label_one_hot1.to(device), label_one_hot2.to(
                    device), seg.to(device)

            # Project
            proj1 = rand_proj(data1)
            proj2 = rand_proj(data2)

            # Train
            opt.zero_grad()

            pred1 = model(mixup_data, proj1,
                          label_one_hot1).permute(0, 2, 1)  # (B, N, 3)
            pred2 = model(mixup_data, proj2,
                          label_one_hot2).permute(0, 2, 1)  # (B, N, 3)

            if args.loss == 'l1loss':
                loss = L1_loss(pred1, data1) + L1_loss(pred2, data2)
            elif args.loss == 'chamfer':
                loss1 = chamfer_distance(pred1, data1) + chamfer_distance(
                    data1, pred1)
                loss2 = chamfer_distance(pred2, data2) + chamfer_distance(
                    data2, pred2)
                loss = loss1 + loss2
            elif args.loss == 'emd':
                loss = emd_loss(pred1, data1) + emd_loss(pred2, data2)
            elif args.loss == 'emd2':
                loss = emd_loss_2(pred1, data1) + emd_loss_2(pred2, data2)
            else:
                raise NotImplementedError

            if args.l2loss:
                l2_loss = nn.MSELoss()(pred1, data1) + nn.MSELoss()(pred2,
                                                                    data2)
                loss += args.l2_param * l2_loss

            loss.backward()

            train_loss = train_loss + loss.item()
            opt.step()

            if (i + 1) % 100 == 0:
                io.cprint('iters %d, train loss: %.6f' % (i + 1, train_loss / (i + 1)))

        io.cprint('Learning rate: %.6f' % (opt.param_groups[0]['lr']))

        if args.scheduler == 'cos':
            scheduler.step()
        elif args.scheduler == 'step':
            if opt.param_groups[0]['lr'] > 1e-5:
                scheduler.step()
            if opt.param_groups[0]['lr'] < 1e-5:
                for param_group in opt.param_groups:
                    param_group['lr'] = 1e-5

        # Test
        if args.valid:
            io.cprint('*****Test*****')
            test_loss = 0
            model.eval()
            for data, label, seg in test_loader:
                with torch.no_grad():
                    if args.task == '1obj_rotate':
                        data1, data2, label1, label2 = obj_rotate_perm(
                            data, label)  # (B, N, 3)
                    elif args.task == '2obj':
                        data1, data2, label1, label2 = obj_2_perm(
                            data, label)  # (B, N, 3)
                    elif args.task == 'alter':
                        if epoch % 2 == 0:
                            data1, data2, label1, label2 = obj_rotate_perm(
                                data, label)  # (B, N, 3)
                        else:
                            data1, data2, label1, label2 = obj_2_perm(
                                data, label)  # (B, N, 3)
                    else:
                        print('Task not implemented!')
                        exit(0)

                    if args.mixup == 'emd':
                        mixup_data = emd_mixup(data1, data2)  # (B, N, 3)
                    elif args.mixup == 'add':
                        mixup_data = add_mixup(data1, data2)  # (B, N, 3)
                    else:
                        raise NotImplementedError  # guard: mixup_data would otherwise be undefined

                    mixup_data = mixup_data.permute(0, 2, 1)  # (B, 3, N)
                    batch_size = mixup_data.size()[0]

                    seg = seg - seg_start_index
                    label_one_hot1 = np.zeros((batch_size, 16))
                    label_one_hot2 = np.zeros((batch_size, 16))
                    for idx in range(batch_size):
                        label_one_hot1[idx, label1[idx]] = 1
                        label_one_hot2[idx, label2[idx]] = 1

                    label_one_hot1 = torch.from_numpy(
                        label_one_hot1.astype(np.float32))
                    label_one_hot2 = torch.from_numpy(
                        label_one_hot2.astype(np.float32))
                    data, label_one_hot1, label_one_hot2, seg = data.to(
                        device), label_one_hot1.to(device), label_one_hot2.to(
                            device), seg.to(device)

                    proj1 = rand_proj(data1)
                    proj2 = rand_proj(data2)

                    pred1 = model(mixup_data, proj1,
                                  label_one_hot1).permute(0, 2, 1)  # (B, N, 3)
                    pred2 = model(mixup_data, proj2,
                                  label_one_hot2).permute(0, 2, 1)  # (B, N, 3)

                    if args.loss == 'l1loss':
                        loss = L1_loss(pred1, data1) + L1_loss(pred2, data2)
                    elif args.loss == 'chamfer':
                        loss1 = chamfer_distance(
                            pred1, data1) + chamfer_distance(data1, pred1)
                        loss2 = chamfer_distance(
                            pred2, data2) + chamfer_distance(data2, pred2)
                        loss = loss1 + loss2
                    elif args.loss == 'emd':
                        loss = emd_loss(pred1, data1) + emd_loss(pred2, data2)
                    elif args.loss == 'emd2':
                        loss = emd_loss_2(pred1, data1) + emd_loss_2(
                            pred2, data2)
                    else:
                        raise NotImplementedError

                    test_loss = test_loss + loss.item()
            io.cprint(
                'Train loss: %.6f, Test loss: %.6f' %
                (train_loss / len(train_loader), test_loss / len(test_loader)))
            cur_loss = test_loss / len(test_loader)
            if cur_loss < min_loss:
                min_loss = cur_loss
                torch.save(
                    model.state_dict(), 'checkpoints/%s/best_%s.pkl' %
                    (args.exp_name, args.exp_name))
        if (epoch + 1) % 10 == 0:
            torch.save(
                model.state_dict(), 'checkpoints/%s/%s_epoch_%s.pkl' %
                (args.exp_name, args.exp_name, str(epoch)))
    torch.save(model.state_dict(),
               'checkpoints/%s/%s.pkl' % (args.exp_name, args.exp_name))
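
adjust_learning_rate, used for the warm-up during the first five epochs above, is not shown. Assuming it does the conventional thing, it simply overwrites the learning rate of every parameter group:

def adjust_learning_rate(optimizer, lr):
    # set the same learning rate on every parameter group (sketch of an assumed helper)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr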
Example #6
def main():
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    print("\n**************************")
    for k, v in config['common'].items():
        setattr(args, k, v)
        print('\n[%s]:' % (k), v)
    print("\n**************************\n")

    try:
        os.makedirs(args.save_path)
    except OSError:
        pass

    train_dataset = ShapeNetPart(root=args.data_root,
                                 num_points=args.num_points,
                                 split='trainval',
                                 normalize=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=int(args.workers),
                                  pin_memory=True)

    global test_dataset_z
    test_dataset_z = ShapeNetPart(root=args.data_root,
                                  num_points=args.num_points,
                                  split='test',
                                  normalize=True)
    test_dataloader_z = DataLoader(test_dataset_z,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=int(args.workers),
                                   pin_memory=True)
    global test_dataset_so3
    test_dataset_so3 = ShapeNetPart(root=args.data_root,
                                    num_points=args.num_points,
                                    split='test',
                                    normalize=True)
    test_dataloader_so3 = DataLoader(test_dataset_so3,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=int(args.workers),
                                     pin_memory=True)

    if args.model == "pointnet2_ssn":
        model = PointNet2_SSN(num_classes=args.num_classes)
        model.cuda()
    elif args.model == "rscnn_msn":
        model = RSCNN_MSN(num_classes=args.num_classes)
        model.cuda()
        model = torch.nn.DataParallel(model)
    else:
        print("Doesn't support this model")
        return 0

    optimizer = optim.Adam(model.parameters(),
                           lr=args.base_lr,
                           weight_decay=args.weight_decay)
    lr_lbmd = lambda e: max(args.lr_decay**(e // args.decay_step), args.lr_clip / args.base_lr)
    bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay**(e // args.decay_step), args.bnm_clip)
    lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd)
    bnm_scheduler = pt_utils.BNMomentumScheduler(model, bnm_lmbd)

    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    criterion = nn.CrossEntropyLoss()
    num_batch = len(train_dataset) / args.batch_size

    # training
    # train(train_dataloader, test_dataloader_z, test_dataloader_so3, model, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch)
    validate(test_dataloader_so3, model, criterion, args, 1, 'so3')
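
pt_utils.BNMomentumScheduler (used in Examples #1 and #6) comes from the pointnet2 utilities and is not shown. A minimal equivalent, assuming it mirrors LambdaLR's epoch-indexed interface and simply pushes the lambda's value into every batch-norm layer:

import torch.nn as nn

class BNMomentumScheduler:
    """Sketch of a BN-momentum scheduler with an assumed LambdaLR-like API."""
    def __init__(self, model, bn_lambda, last_epoch=-1):
        self.model = model
        self.bn_lambda = bn_lambda
        self.last_epoch = last_epoch
        self.step(last_epoch + 1)

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        momentum = self.bn_lambda(epoch)
        for m in self.model.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                m.momentum = momentum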
Example #7
def test(args, io):
    test_loader = DataLoader(ShapeNetPart(partition='test',
                                          num_points=args.num_points,
                                          class_choice=args.class_choice),
                             batch_size=args.test_batch_size,
                             shuffle=True,
                             drop_last=False)

    device = torch.device("cuda" if args.cuda else "cpu")

    #Try to load models
    seg_num_all = test_loader.dataset.seg_num_all
    seg_start_index = test_loader.dataset.seg_start_index
    if args.model == 'dgcnn':
        model = DGCNN_partseg(args, seg_num_all).to(device)
    else:
        raise Exception("Not implemented")

    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model_path))
    model = model.eval()
    test_acc = 0.0
    count = 0.0
    test_true_cls = []
    test_pred_cls = []
    test_true_seg = []
    test_pred_seg = []
    test_label_seg = []
    for data, label, seg in test_loader:
        seg = seg - seg_start_index
        label_one_hot = np.zeros((label.shape[0], 16))
        for idx in range(label.shape[0]):
            label_one_hot[idx, label[idx]] = 1
        label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
        data, label_one_hot, seg = data.to(device), label_one_hot.to(
            device), seg.to(device)
        data = data.permute(0, 2, 1)
        batch_size = data.size()[0]
        seg_pred = model(data, label_one_hot)
        seg_pred = seg_pred.permute(0, 2, 1).contiguous()
        pred = seg_pred.max(dim=2)[1]
        seg_np = seg.cpu().numpy()
        pred_np = pred.detach().cpu().numpy()
        test_true_cls.append(seg_np.reshape(-1))
        test_pred_cls.append(pred_np.reshape(-1))
        test_true_seg.append(seg_np)
        test_pred_seg.append(pred_np)
        test_label_seg.append(label.reshape(-1))
    test_true_cls = np.concatenate(test_true_cls)
    test_pred_cls = np.concatenate(test_pred_cls)
    test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls)
    avg_per_class_acc = metrics.balanced_accuracy_score(
        test_true_cls, test_pred_cls)
    test_true_seg = np.concatenate(test_true_seg, axis=0)
    test_pred_seg = np.concatenate(test_pred_seg, axis=0)
    test_label_seg = np.concatenate(test_label_seg)
    test_ious = calculate_shape_IoU(test_pred_seg, test_true_seg,
                                    test_label_seg, args.class_choice)
    outstr = 'Test :: test acc: %.6f, test avg acc: %.6f, test iou: %.6f' % (
        test_acc, avg_per_class_acc, np.mean(test_ious))
    io.cprint(outstr)
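
A minimal driver for the test entry point above; the Namespace fields are exactly the attributes the function reads, the paths are illustrative, and IOStream is the sketch given after Example #4:

from argparse import Namespace
import torch

if __name__ == '__main__':
    args = Namespace(num_points=2048, class_choice=None, test_batch_size=16,
                     cuda=torch.cuda.is_available(), model='dgcnn',
                     model_path='checkpoints/exp/models/model.t7')  # illustrative checkpoint path
    io = IOStream('checkpoints/exp/run.log')
    test(args, io)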