Example #1
def train(train_dataloader, test_dataloader, model, criterion, optimizer,
          lr_scheduler, bnm_scheduler, args, num_batch):
    PointcloudScaleAndTranslate = d_utils.PointcloudScaleAndTranslate()  # initialize augmentation
    global g_acc
    g_acc = 0.91  # only save the model whose acc > 0.91
    batch_count = 0
    model.train()
    for epoch in range(args.epochs):
        for i, data in enumerate(train_dataloader, 0):
            # schedulers are stepped once per batch with the epoch index
            # (the pre-1.1 torch scheduler API that accepted an epoch argument)
            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if bnm_scheduler is not None:
                bnm_scheduler.step(epoch - 1)
            points, target = data
            points, target = points.cuda(), target.cuda()
            # Variable() is a no-op on torch >= 0.4; kept for compatibility
            points, target = Variable(points), Variable(target)

            # farthest point sampling
            # fps_idx = pointnet2_utils.furthest_point_sample(points, 1200)  # (B, npoint)

            # random sampling
            # the high bound of randint is exclusive, so use points.shape[1]
            # (the original "- 1" needlessly excluded the last point)
            fps_idx = np.random.randint(0,
                                        points.shape[1],
                                        size=[points.shape[0], 1200])
            fps_idx = torch.from_numpy(fps_idx).type(torch.IntTensor).cuda()

            # keep a random subset of args.num_points indices (without replacement)
            fps_idx = fps_idx[:, np.random.choice(1200, args.num_points, False)]
            points = pointnet2_utils.gather_operation(
                points.transpose(1, 2).contiguous(),
                fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)

            # augmentation
            points.data = PointcloudScaleAndTranslate(points.data)

            optimizer.zero_grad()

            pred = model(points)
            target = target.view(-1)
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()
            if i % args.print_freq_iter == 0:
                print(
                    '[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' %
                    (epoch + 1, i, num_batch, loss.item(),
                     lr_scheduler.get_lr()[0]))
            batch_count += 1

            # validation in between an epoch
            if args.evaluate and batch_count % int(
                    args.val_freq_epoch * num_batch) == 0:
                validate(test_dataloader, model, criterion, args, batch_count)
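
The snippet above relies on the compiled pointnet2_utils.gather_operation CUDA op. For reference, here is a minimal sketch of the same random-subsampling step in pure PyTorch, assuming points is a (B, N, 3) CUDA float tensor; the helper name random_subsample is hypothetical:

import torch

def random_subsample(points, num_points):
    # points: (B, N, 3). Mirrors the net effect of the randint + gather above:
    # indices drawn with replacement, then gathered along the point dimension.
    B, N, _ = points.shape
    idx = torch.randint(0, N, (B, num_points), device=points.device)  # (B, num_points)
    idx = idx.unsqueeze(-1).expand(-1, -1, 3)                         # (B, num_points, 3)
    return torch.gather(points, 1, idx)                               # (B, num_points, 3)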
Example #2
def train(train_dataloader, test_dataloader, model, criterion, optimizer,
          lr_scheduler, bnm_scheduler, args, num_batch):
    PointcloudScaleAndTranslate = d_utils.PointcloudScaleAndTranslate()  # initialize augmentation
    global Class_mIoU, Inst_mIoU
    Class_mIoU, Inst_mIoU = 0.83, 0.85
    batch_count = 0
    for epoch in range(args.epochs):
        model.train()
        losses = []
        start_time = time.time()
        for i, data in enumerate(train_dataloader):

            points, target, cls = data
            target = target.cuda()
            points = points.cuda()
            one_hot_target = to_categorical(target, 50)
            # augmentation
            points = PointcloudScaleAndTranslate(points)
            optimizer.zero_grad()

            batch_one_hot_cls = np.zeros((len(cls), 16))  # 16 object classes
            for b in range(len(cls)):
                batch_one_hot_cls[b, int(cls[b])] = 1
            batch_one_hot_cls = torch.from_numpy(
                batch_one_hot_cls).float().cuda()
            # batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda())

            pred, context = model(points, batch_one_hot_cls)
            pred = pred.view(-1, args.num_classes)
            target = target.view(-1, 1)[:, 0]
            loss = criterion(pred, target, None, context, one_hot_target)

            loss.backward()

            optimizer.step()
            losses.append(loss.item())
            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if bnm_scheduler is not None:
                bnm_scheduler.step(epoch - 1)
            #if i % args.print_freq_iter == 0:
            #    print('[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' %(epoch+1, i, num_batch, loss.data.clone(), lr_scheduler.get_lr()[0]))
            batch_count += 1

            # validation in between an epoch
            #if (epoch < 3 or epoch > 40) and args.evaluate and batch_count % int(args.val_freq_epoch * num_batch) == 0:
        end_time = time.time()
        print('[epoch %3d time=%d s] \t train loss: %0.6f \t lr: %0.5f' %
              (epoch + 1, end_time - start_time, np.array(losses).mean(),
               lr_scheduler.get_lr()[0]))
        validate(test_dataloader, model, criterion, args, batch_count)
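
The per-sample loop that builds batch_one_hot_cls can be vectorized. A minimal sketch, assuming cls is a tensor (or array) of integer class ids in [0, 16); the helper name one_hot_cls is hypothetical:

import torch
import torch.nn.functional as F

def one_hot_cls(cls, num_cls=16):
    cls = torch.as_tensor(cls).view(-1).long()       # (B,)
    return F.one_hot(cls, num_cls).float().cuda()    # (B, 16) one-hot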
Example #3
def train(train_dataloader, test_dataloader_z, test_dataloader_so3, model,
          criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch):
    global Class_mIoU, Inst_mIoU
    Class_mIoU, Inst_mIoU = 0.75, 0.75
    batch_count = 0
    aug = d_utils.SO3Rotate()
    model.train()
    for epoch in range(args.epochs):
        for i, data in enumerate(train_dataloader, 0):
            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if bnm_scheduler is not None:
                bnm_scheduler.step(epoch - 1)

            points, norm, target, cls = data
            points, norm, target = points.cuda(), norm.cuda(), target.cuda()

            # RS-CNN performs a translation to the input first
            if args.model == "rscnn_msn":
                points.data = d_utils.PointcloudScaleAndTranslate()(points.data)

            points.data, norm.data = aug(points.data, norm.data)

            optimizer.zero_grad()

            batch_one_hot_cls = np.zeros((len(cls), 16))  # 16 object classes
            for b in range(len(cls)):
                batch_one_hot_cls[b, int(cls[b])] = 1
            batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls)
            batch_one_hot_cls = batch_one_hot_cls.float().cuda()
            pred = model(points, norm, batch_one_hot_cls)
            pred = pred.view(-1, args.num_classes)
            target = target.view(-1, 1)[:, 0]
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()

            if i % args.print_freq_iter == 0:
                print(
                    '[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' %
                    (epoch + 1, i, num_batch, loss.item(),
                     lr_scheduler.get_lr()[0]))
            batch_count += 1

            if args.evaluate and batch_count % int(
                    args.val_freq_epoch * num_batch) == 0:
                # validate(test_dataloader_z, model, criterion, args, batch_count, 'z')
                validate(test_dataloader_so3, model, criterion, args,
                         batch_count, 'so3')
Example #4
def train(train_dataloader, test_dataloader, model, criterion, optimizer,
          lr_scheduler, bnm_scheduler, args, num_batch):
    PointcloudScaleAndTranslate = d_utils.PointcloudScaleAndTranslate()  # initialize augmentation
    global Class_mIoU, Inst_mIoU
    Class_mIoU, Inst_mIoU = 0.83, 0.85
    batch_count = 0
    model.train()
    for epoch in range(args.epochs):
        for i, data in enumerate(train_dataloader, 0):
            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if bnm_scheduler is not None:
                bnm_scheduler.step(epoch - 1)
            points, target, cls = data
            points, target = points.cuda(), target.cuda()
            points, target = Variable(points), Variable(target)
            # augmentation
            points.data = PointcloudScaleAndTranslate(points.data)

            optimizer.zero_grad()

            batch_one_hot_cls = np.zeros((len(cls), 16))  # 16 object classes
            for b in range(len(cls)):
                batch_one_hot_cls[b, int(cls[b])] = 1
            batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls)
            batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda())

            pred = model(points, batch_one_hot_cls)
            pred = pred.view(-1, args.num_classes)
            target = target.view(-1, 1)[:, 0]
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()

            if i % args.print_freq_iter == 0:
                print(
                    '[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' %
                    (epoch + 1, i, num_batch, loss.item(),
                     lr_scheduler.get_lr()[0]))
            batch_count += 1

            # validation in between an epoch
            if (epoch < 3 or epoch > 40) and args.evaluate and \
                    batch_count % int(args.val_freq_epoch * num_batch) == 0:
                validate(test_dataloader, model, criterion, args, batch_count)
Example #5
def train(train_dataloader, test_dataloader_z, test_dataloader_so3, model,
          criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch):
    aug = d_utils.ZRotate()
    global g_acc
    g_acc = 0.88  # only save the model whose acc > g_acc
    batch_count = 0
    model.train()
    for epoch in range(args.epochs):
        for i, data in enumerate(train_dataloader, 0):
            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if bnm_scheduler is not None:
                bnm_scheduler.step(epoch - 1)
            points, normals, target = data
            points, normals, target = points.cuda(), normals.cuda(), target.cuda()
            if args.model == "pointnet2":
                fps_idx = pointnet2_utils.furthest_point_sample(
                    points, 1024)  # (B, npoint)
                points = pointnet2_utils.gather_operation(
                    points.transpose(1, 2).contiguous(),
                    fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)
                normals = pointnet2_utils.gather_operation(
                    normals.transpose(1, 2).contiguous(),
                    fps_idx).transpose(1, 2).contiguous()
            else:
                fps_idx = pointnet2_utils.furthest_point_sample(
                    points, 1200)  # (B, npoint)
                fps_idx = fps_idx[:, np.random.choice(1200, args.num_points, False)]
                points = pointnet2_utils.gather_operation(
                    points.transpose(1, 2).contiguous(),
                    fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)
                normals = pointnet2_utils.gather_operation(
                    normals.transpose(1, 2).contiguous(),
                    fps_idx).transpose(1, 2).contiguous()
                # RS-CNN performs a translation to the input first
                points.data = d_utils.PointcloudScaleAndTranslate()(points.data)

            points.data, normals.data = aug(points.data, normals.data)

            optimizer.zero_grad()
            pred = model(points, normals)
            target = target.view(-1)
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()
            if i % args.print_freq_iter == 0:
                print(
                    '[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' %
                    (epoch + 1, i, num_batch, loss.item(),
                     lr_scheduler.get_lr()[0]))
            batch_count += 1

            # validation in between an epoch
            if args.evaluate and batch_count % int(
                    args.val_freq_epoch * num_batch) == 0:
                # validate(test_dataloader_z, model, criterion, args, batch_count, 'z')
                validate(test_dataloader_so3, model, criterion, args,
                         batch_count, 'so3')
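
pointnet2_utils.furthest_point_sample is a compiled CUDA kernel. As a reference for what it computes, here is a minimal (unoptimized) pure-PyTorch sketch of the same greedy algorithm, assuming points is a (B, N, 3) tensor; the function name is hypothetical:

import torch

def furthest_point_sample_py(points, npoint):
    # points: (B, N, 3). Returns (B, npoint) long indices.
    B, N, _ = points.shape
    idx = torch.zeros(B, npoint, dtype=torch.long, device=points.device)
    dist = torch.full((B, N), float('inf'), device=points.device)
    farthest = torch.randint(0, N, (B,), device=points.device)  # random seed point
    batch = torch.arange(B, device=points.device)
    for i in range(npoint):
        idx[:, i] = farthest
        centroid = points[batch, farthest, :].unsqueeze(1)   # (B, 1, 3)
        d = ((points - centroid) ** 2).sum(-1)               # (B, N)
        dist = torch.min(dist, d)                            # distance to the sampled set
        farthest = dist.argmax(-1)                           # pick the farthest point next
    return idx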
Example #6
def main():
    args = parser.parse_args()
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    print("\n**************************")
    for k, v in config['common'].items():
        setattr(args, k, v)
        print('\n[%s]:' % (k), v)
    print("\n**************************\n")

    try:
        os.makedirs(args.save_path)
    except OSError:
        pass

    train_transforms = transforms.Compose([
        d_utils.PointcloudToTensor(),
        d_utils.PointcloudScaleAndTranslate(),
        d_utils.PointcloudRandomInputDropout()
    ])
    test_transforms = transforms.Compose([
        d_utils.PointcloudToTensor(),
        #d_utils.PointcloudScaleAndTranslate()
    ])

    train_dataset = ModelNet40Cls(num_points=args.num_points,
                                  root=args.data_root,
                                  transforms=train_transforms)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=int(args.workers))

    test_dataset = ModelNet40Cls(num_points=args.num_points,
                                 root=args.data_root,
                                 transforms=test_transforms,
                                 train=False)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=int(args.workers))

    model = RSCNN_SSN(num_classes=args.num_classes,
                      input_channels=args.input_channels,
                      relation_prior=args.relation_prior,
                      use_xyz=True)
    model.cuda()
    optimizer = optim.Adam(model.parameters(),
                           lr=args.base_lr,
                           weight_decay=args.weight_decay)

    lr_lbmd = lambda e: max(args.lr_decay ** (e // args.decay_step),
                            args.lr_clip / args.base_lr)
    bnm_lmbd = lambda e: max(
        args.bn_momentum * args.bn_decay ** (e // args.decay_step),
        args.bnm_clip)
    lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd)
    bnm_scheduler = pt_utils.BNMomentumScheduler(model, bnm_lmbd)

    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    criterion = nn.CrossEntropyLoss()
    num_batch = len(train_dataset) // args.batch_size  # batches per epoch

    # training
    train(train_dataloader, test_dataloader, model, criterion, optimizer,
          lr_scheduler, bnm_scheduler, args, num_batch)
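
The two lambdas above implement clipped exponential decay for the learning rate and the batch-norm momentum. A minimal numeric sketch of the LR curve they produce, using assumed example values for the config fields (base_lr, lr_decay, decay_step, lr_clip are assumptions here):

base_lr, lr_decay, decay_step, lr_clip = 0.001, 0.7, 21, 1e-5

lr_lbmd = lambda e: max(lr_decay ** (e // decay_step), lr_clip / base_lr)
for e in (0, 21, 42, 300):
    print(e, base_lr * lr_lbmd(e))
# epoch 0   -> 0.001    (no decay yet)
# epoch 21  -> 0.0007   (one decay step)
# epoch 42  -> 0.00049  (two decay steps)
# epoch 300 -> 1e-05    (clipped at lr_clip)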
Example #7
def train(ss_dataloader, train_dataloader, test_dataloader, encoder, decoder,
          optimizer, lr_scheduler, bnm_scheduler, args, num_batch,
          begin_epoch):
    PointcloudScaleAndTranslate = d_utils.PointcloudScaleAndTranslate()  # initialize augmentation
    PointcloudRotate = d_utils.PointcloudRotate()
    metric_criterion = MetricLoss()
    chamfer_criterion = ChamferLoss()
    global svm_best_acc40
    batch_count = 0
    encoder.train()
    decoder.train()

    for epoch in range(begin_epoch, args.epochs):
        np.random.seed()
        for i, data in enumerate(ss_dataloader, 0):
            if lr_scheduler is not None:
                lr_scheduler.step(epoch)
            if bnm_scheduler is not None:
                bnm_scheduler.step(epoch - 1)
            points = data
            points = Variable(points.cuda())

            # data augmentation
            sampled_points = 1200
            has_normal = (points.size(2) > 3)

            if has_normal:
                normals = points[:, :, 3:6].contiguous()
            points = points[:, :, 0:3].contiguous()

            fps_idx = pointnet2_utils.furthest_point_sample(
                points, sampled_points)  # (B, npoint)
            fps_idx = fps_idx[:, np.random.choice(sampled_points,
                                                  args.num_points, False)]
            points_gt = pointnet2_utils.gather_operation(
                points.transpose(1, 2).contiguous(),
                fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)
            if has_normal:
                normals = pointnet2_utils.gather_operation(
                    normals.transpose(1, 2).contiguous(), fps_idx)
            points = PointcloudScaleAndTranslate(points_gt.data)

            # optimize
            optimizer.zero_grad()

            features1, fuse_global, normals_pred = encoder(points)
            global_feature1 = features1[2].squeeze(2)
            refs1 = features1[0:2]
            recon1 = decoder(fuse_global).transpose(1, 2)  # bs, np, 3

            loss_metric = metric_criterion(global_feature1, refs1)
            loss_recon = chamfer_criterion(recon1, points_gt)
            if has_normal:
                loss_normals = NormalLoss(normals_pred, normals)
            else:
                loss_normals = normals_pred.new(1).fill_(0)
            loss = loss_recon + loss_metric + loss_normals
            loss.backward()
            optimizer.step()
            if i % args.print_freq_iter == 0:
                print(
                    '[epoch %3d: %3d/%3d] \t metric/chamfer/normal loss: %0.6f/%0.6f/%0.6f \t lr: %0.5f'
                    % (epoch + 1, i, num_batch, loss_metric.item(),
                       loss_recon.item(), loss_normals.item(),
                       lr_scheduler.get_lr()[0]))
            batch_count += 1

            # validation
            if args.evaluate and batch_count % int(
                    args.val_freq_epoch * num_batch) == 0:
                svm_acc40 = validate(train_dataloader, test_dataloader,
                                     encoder, args)

                save_dict = {
                    # after training one epoch, the start_epoch should be epoch + 1
                    'epoch': epoch + 1,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'encoder_state_dict': encoder.state_dict(),
                    'decoder_state_dict': decoder.state_dict(),
                    'svm_best_acc40': svm_best_acc40,
                }
                checkpoint_name = './ckpts/' + args.name + '.pth'
                torch.save(save_dict, checkpoint_name)
                if svm_acc40 == svm_best_acc40:
                    checkpoint_name = './ckpts/' + args.name + '_best.pth'
                    torch.save(save_dict, checkpoint_name)
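
A minimal resume sketch matching the save_dict layout above; the checkpoint path, map_location, and the already-constructed encoder/decoder/optimizer are assumptions:

import torch

ckpt = torch.load('./ckpts/name.pth', map_location='cuda')  # hypothetical path
encoder.load_state_dict(ckpt['encoder_state_dict'])
decoder.load_state_dict(ckpt['decoder_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
begin_epoch = ckpt['epoch']              # training resumes from this epoch
svm_best_acc40 = ckpt['svm_best_acc40']  # restore the best-accuracy tracker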
Example #8
def validate(test_dataloader, model, criterion, args, iter):
    global Class_mIoU, Inst_mIoU, test_dataset
    model.eval()
    PointcloudScaleAndTranslate = d_utils.PointcloudScaleAndTranslate()  # unused: test-time augmentation is disabled below
    seg_classes = test_dataset.seg_classes
    shape_ious = {cat: [] for cat in seg_classes.keys()}
    seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}
    for cat in seg_classes.keys():
        for label in seg_classes[cat]:
            seg_label_to_cat[label] = cat

    losses = 0.0
    lens = len(test_dataloader)
    with torch.no_grad():
        for _, data in enumerate(test_dataloader):
            points, target, cls = data
            target = target.cuda()
            points = points.cuda()
            one_hot_target = to_categorical(target, 50)
            # augmentation
            #points = PointcloudScaleAndTranslate(points)

            batch_one_hot_cls = np.zeros((len(cls), 16))  # 16 object classes
            for b in range(len(cls)):
                batch_one_hot_cls[b, int(cls[b])] = 1
            batch_one_hot_cls = torch.from_numpy(
                batch_one_hot_cls).float().cuda()
            # batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda())

            pred, context = model(points, batch_one_hot_cls)
            pred_t = pred.view(-1, args.num_classes)
            target_t = target.view(-1, 1)[:, 0]
            loss = criterion(pred_t, target_t, None, context, one_hot_target)
            losses += loss.item()
            """
            points, target, cls = data
            #points, target = Variable(points, volatile=True), Variable(target, volatile=True)
            points, target = points.cuda(), target.cuda()
            one_hot_target=to_categorical(target,50)
            batch_one_hot_cls = np.zeros((len(cls), 16))   # 16 object classes
            for b in range(len(cls)):
                batch_one_hot_cls[b, int(cls[b])] = 1
            batch_one_hot_cls = torch.from_numpy(batch_one_hot_cls).float().cuda()
            # batch_one_hot_cls = Variable(batch_one_hot_cls.float().cuda())
            pred,context= model(points, batch_one_hot_cls)
            loss = criterion(pred.view(-1, args.num_classes), target.view(-1,1)[:,0],None,context,one_hot_target)
            losses+=loss.item()
            pred = pred.cpu()
            target = target.cpu()
            """
            pred_val = torch.zeros(len(cls),
                                   args.num_points).type(torch.LongTensor)
            # pred to the groundtruth classes (selected by seg_classes[cat])
            for b in range(len(cls)):
                cat = seg_label_to_cat[int(target[b, 0].cpu().numpy())]
                logits = pred[b, :, :]  # (num_points, num_classes)
                # argmax restricted to this category's part labels, shifted
                # back to the global label range
                pred_val[b, :] = (logits[:, seg_classes[cat]].max(1)[1] +
                                  seg_classes[cat][0])

            for b in range(len(cls)):
                segp = pred_val[b, :].cpu().numpy()
                segl = target[b, :].cpu().numpy()
                cat = seg_label_to_cat[int(segl[0])]
                part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
                for l in seg_classes[cat]:
                    if np.sum((segl == l) | (segp == l)) == 0:
                        # part is not present in this shape
                        part_ious[l - seg_classes[cat][0]] = 1.0
                    else:
                        part_ious[l - seg_classes[cat][0]] = np.sum(
                            (segl == l) & (segp == l)) / float(
                                np.sum((segl == l) | (segp == l)))
                shape_ious[cat].append(np.mean(part_ious))

        instance_ious = []
        for cat in shape_ious.keys():
            for iou in shape_ious[cat]:
                instance_ious.append(iou)
            shape_ious[cat] = np.mean(shape_ious[cat])
        mean_class_ious = np.mean(list(shape_ious.values()))

        for cat in sorted(shape_ious.keys()):
            print('****** %s: %0.6f' % (cat, shape_ious[cat]))
        print('************ Test Loss: %0.6f' % (losses / lens))
        print('************ Class_mIoU: %0.6f' % (mean_class_ious))
        print('************ Instance_mIoU: %0.6f' % (np.mean(instance_ious)))

        if mean_class_ious > Class_mIoU or np.mean(instance_ious) > Inst_mIoU:
            if mean_class_ious > Class_mIoU:
                Class_mIoU = mean_class_ious
            if np.mean(instance_ious) > Inst_mIoU:
                Inst_mIoU = np.mean(instance_ious)
            torch.save(
                model.state_dict(),
                '%s/seg_msn_iter_%d_ins_%0.6f_cls_%0.6f.pth' %
                (args.save_path, iter, np.mean(instance_ious),
                 mean_class_ious))
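
to_categorical is not defined in these snippets. A minimal sketch of what it plausibly does, assuming it one-hot encodes integer part labels (50 part classes here); treat the exact signature as an assumption:

import torch

def to_categorical(y, num_classes):
    # index rows of an identity matrix with the labels;
    # preserves y's leading dims, e.g. (B, N) -> (B, N, num_classes)
    return torch.eye(num_classes, device=y.device)[y.long()]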