for epoch in range(opt.nepoch):
    scheduler.step()
    for i, data in enumerate(dataloader, 0):
        points, target = data
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()
        classifier = classifier.train()
        pred, trans, trans_feat = classifier(points)
        pred = pred.view(-1, num_classes)
        target = target.view(-1, 1)[:, 0] - 1
        #print(pred.size(), target.size())
        loss = F.nll_loss(pred, target)
        if opt.feature_transform:
            loss += feature_transform_regularizer(trans_feat) * 0.001
        loss.backward()
        optimizer.step()
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()
        print('[%d: %d/%d] train loss: %f accuracy: %f' %
              (epoch, i, num_batch, loss.item(),
               correct.item() / float(opt.batchSize * 2500)))

        if i % 10 == 0:
            j, data = next(enumerate(testdataloader, 0))
            points, target = data
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            classifier = classifier.eval()
            pred, _, _ = classifier(points)
            pred = pred.view(-1, num_classes)
            target = target.view(-1, 1)[:, 0] - 1
            loss = F.nll_loss(pred, target)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            print('[%d: %d/%d] test loss: %f accuracy: %f' %
                  (epoch, i, num_batch, loss.item(),
                   correct.item() / float(opt.batchSize * 2500)))
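
The `feature_transform_regularizer` term added to the training loss above is the orthogonality penalty on the predicted feature-transform matrix from the PointNet paper. A minimal sketch of what it computes, assuming `trans_feat` is a batch of d x d transform matrices (the exact implementation in your copy of the repo may differ slightly):

import torch

def feature_transform_regularizer(trans):
    # trans: (B, d, d) feature-transform matrices predicted by the T-Net.
    # Penalize deviation from orthogonality: mean Frobenius norm of (A A^T - I).
    d = trans.size(1)
    I = torch.eye(d, device=trans.device).unsqueeze(0)
    return torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))
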
Example 2
def train(lr=0.001):
    parser = argparse.ArgumentParser()
    opt = parser.parse_args()
    opt.nepoch = 1
    opt.batchsize = 18
    opt.workers = 0
    opt.outf = 'completion'
    opt.dataset = '/home/cdi0/data/shape_net_core_uniform_samples_2048_split/'
    opt.feature_transform = False
    opt.model = ''
    opt.device = 'cuda:1'
    opt.lr = lr

    opt.manualSeed = random.randint(1, 10000)  # fix seed
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    dataset = ShapeNetDataset(dir=opt.dataset)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batchsize,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    test_dataset = ShapeNetDataset(
        dir=opt.dataset,
        train='test',
    )
    testdataloader = torch.utils.data.DataLoader(test_dataset,
                                                 batch_size=opt.batchsize,
                                                 shuffle=True,
                                                 num_workers=int(opt.workers))

    print(len(dataset), len(test_dataset))

    #try:
    #    os.makedirs(opt.outf)
    #except OSError:
    #    pass

    blue = lambda x: '\033[94m' + x + '\033[0m'
    device = opt.device

    netG = PointNetDenseCls(device=device,
                            feature_transform=opt.feature_transform)
    localD = LocalDiscriminator(k=2, device=device)
    globalD = GlobalDiscriminator(k=2, device=device)

    if opt.model != '':
        netG.load_state_dict(torch.load(opt.model))

    optimizerG = optim.Adam(netG.parameters(), lr=0.001, betas=(0.9, 0.999))
    optimizerD = optim.Adam(list(globalD.parameters()) +
                            list(localD.parameters()),
                            lr=0.001,
                            betas=(0.9, 0.999))

    schedulerG = optim.lr_scheduler.StepLR(optimizerG, step_size=20, gamma=0.5)
    schedulerD = optim.lr_scheduler.StepLR(optimizerD, step_size=20, gamma=0.5)

    netG.to(device)
    localD.to(device)
    globalD.to(device)

    criterion = distChamfer
    Dcriterion = nn.BCELoss()
    #Dcriterion = F.nll_loss

    real_label = 1
    fake_label = 0

    num_batch = len(dataset) / opt.batchsize
    writer = SummaryWriter()
    for epoch in range(opt.nepoch):
        for i, data in (enumerate(dataloader, 0)):
            #k = 614
            points, target, mask = data  # Nx4 or Nx3
            points = points.transpose(2, 1)  # 4xN
            points, target = points.to(device, dtype=torch.float), target.to(
                device, dtype=torch.float)
            b_size = points.shape[0]

            mask_ = mask.unsqueeze(2).repeat(1, 1, 3)
            #print(mask_.any(dim = 2).sum(dim=1))
            mask__ = ~mask_
            #print(mask__.any(dim = 2).sum(dim=1))
            mask__ = mask__.to(device, dtype=torch.float32)
            mask_ = mask_.to(device, dtype=torch.float32)

            optimizerD.zero_grad()

            localD = localD.train()
            globalD = globalD.train()

            ###### train D ######

            #label_real =  torch.stack((torch.zeros(b_size),torch.ones(b_size)), dim = 1).to(device, dtype = torch.long)
            #label_fake =  torch.stack((torch.ones(b_size),torch.zeros(b_size)), dim = 1).to(device, dtype = torch.long)

            label = torch.full((b_size, ), real_label, dtype=torch.float, device=device)

            #print(mask__)
            #print(mask__[mask__.sum(dim=2) != 0].shape)
            target_mask = mask__ * target
            target_mask = target_mask[torch.abs(target_mask).sum(
                dim=2) != 0].view(b_size, -1, 3)

            target, target_mask = target.transpose(
                2, 1).contiguous(), target_mask.transpose(2, 1).contiguous()

            output_g = globalD(target)
            output_l = localD(target_mask)

            errD_real_g = Dcriterion(output_g, label)
            errD_real_l = Dcriterion(output_l, label)

            errD_real = errD_real_g + errD_real_l
            errD_real.backward()

            target = target.transpose(2, 1).contiguous()

            pred = netG(points)

            pred = (pred * mask__) + (target * mask_)

            pred_mask = pred * mask__
            pred_mask = pred_mask[torch.abs(pred_mask).sum(dim=2) != 0].view(
                b_size, -1, 3)

            pred, pred_mask = pred.transpose(
                2, 1).contiguous(), pred_mask.transpose(2, 1).contiguous()

            output_g = globalD(pred.detach())
            output_l = localD(pred_mask.detach())

            label.fill_(fake_label)

            errD_fake_g = Dcriterion(output_g, label)
            errD_fake_l = Dcriterion(output_l, label)

            errD_fake = errD_fake_g + errD_fake_l
            errD_fake.backward()

            errD = errD_real + errD_fake

            if errD.item() > 0.1:
                optimizerD.step()

            ###### train G ######

            optimizerG.zero_grad()
            optimizerD.zero_grad()

            netG = netG.train()

            output_g = globalD(pred)
            output_l = localD(pred_mask)

            label.fill_(real_label)

            errG_g = Dcriterion(output_g, label)
            errG_l = Dcriterion(output_l, label)

            errG = errG_g + errG_l

            pred = pred.transpose(2, 1).contiguous()

            dist1, dist2 = criterion(pred, target)
            chamferloss = (torch.mean(dist1)) + (torch.mean(dist2))
            loss = chamferloss + errG

            if opt.feature_transform:
                # requires netG to also return trans_feat; disabled here via opt.feature_transform = False
                loss += feature_transform_regularizer(trans_feat) * 0.001

            loss.backward()

            optimizerG.step()

            print('[%d: %d/%d] D_loss: %f, G_loss: %f, Chamfer_loss: %f ' %
                  (epoch, i, num_batch, errD.item(), errG.item(),
                   chamferloss.item()))

            if i % 10 == 0:
                j, data = next(enumerate(testdataloader, 0))
                points, target, mask = data
                points = points.transpose(2, 1)
                points, target = points.to(
                    device, dtype=torch.float), target.to(device,
                                                          dtype=torch.float)

                b_size = points.shape[0]

                localD = localD.eval()
                globalD = globalD.eval()

                ###### eval D ######
                label = torch.full((b_size, ), real_label, dtype=torch.float, device=device)
                #label_real =  torch.stack((torch.zeros(b_size),torch.ones(b_size)), dim = 1).to(device)
                #label_fake =  torch.stack((torch.ones(b_size),torch.zeros(b_size)), dim = 1).to(device)

                mask_ = mask.unsqueeze(2).repeat(1, 1, 3)
                mask__ = ~mask_
                mask__ = mask__.to(device, dtype=torch.float32)
                mask_ = mask_.to(device, dtype=torch.float32)

                target_mask = mask__ * target
                target_mask = target_mask[torch.abs(target_mask).sum(
                    dim=2) != 0].view(b_size, -1, 3)

                target, target_mask = target.transpose(
                    2, 1).contiguous(), target_mask.transpose(2,
                                                              1).contiguous()

                output_g = globalD(target)
                output_l = localD(target_mask)

                errD_real_g_eval = Dcriterion(output_g, label)
                errD_real_l_eval = Dcriterion(output_l, label)

                errD_real_eval = errD_real_g_eval + errD_real_l_eval

                target = target.transpose(2, 1).contiguous()

                pred = netG(points)
                pred = (pred * mask__) + (target * mask_)

                pred_mask = pred * mask__
                pred_mask = pred_mask[torch.abs(pred_mask).sum(
                    dim=2) != 0].view(b_size, -1, 3)

                pred, pred_mask = pred.transpose(
                    2, 1).contiguous(), pred_mask.transpose(2, 1).contiguous()

                output_g_eval = globalD(pred.detach())
                output_l_eval = localD(pred_mask.detach())

                label.fill_(fake_label)

                errD_fake_g_eval = Dcriterion(output_g_eval, label)
                errD_fake_l_eval = Dcriterion(output_l_eval, label)

                errD_fake_eval = errD_fake_g_eval + errD_fake_l_eval

                errD_eval = errD_real_eval + errD_fake_eval

                ###### eval G ######

                netG = netG.eval()

                output_g = globalD(pred)
                output_l = localD(pred_mask)

                label.fill_(real_label)

                errG_g_eval = Dcriterion(output_g, label)
                errG_l_eval = Dcriterion(output_l, label)

                errG_eval = errG_g_eval + errG_l_eval

                pred = pred.transpose(2, 1).contiguous()

                dist1, dist2 = criterion(pred, target)
                chamferloss_eval = (torch.mean(dist1)) + (torch.mean(dist2))
                loss_eval = chamferloss_eval + errG_eval

                print('[%d: %d/%d] %s D_loss: %f, G_loss: %f ' %
                      (epoch, i, num_batch, blue('test'), errD_eval.item(),
                       loss_eval.item()))

            if i % 100 == 0:
                n = int(i / 100)
                writer.add_scalar('errD_real', errD_real.item(),
                                  27 * epoch + n)
                writer.add_scalar('errD_fake', errD_fake.item(),
                                  27 * epoch + n)
                writer.add_scalar('errD_loss', errD.item(), 27 * epoch + n)

                writer.add_scalar('validation errD_real',
                                  errD_real_eval.item(), 27 * epoch + n)
                writer.add_scalar('validation errD_fake',
                                  errD_fake_eval.item(), 27 * epoch + n)
                writer.add_scalar('validation errD_loss', errD_eval.item(),
                                  27 * epoch + n)

                writer.add_scalar('errG_global', errG_g.item(), 27 * epoch + n)
                writer.add_scalar('errG_local', errG_l.item(), 27 * epoch + n)
                writer.add_scalar('chamfer_loss', chamferloss.item(),
                                  27 * epoch + n)
                writer.add_scalar('errG_loss', loss.item(), 27 * epoch + n)

                writer.add_scalar('validation errG_global', errG_g_eval.item(),
                                  27 * epoch + n)
                writer.add_scalar('validation errG_local', errG_l_eval.item(),
                                  27 * epoch + n)
                writer.add_scalar('validation chamfer_loss',
                                  chamferloss_eval.item(), 27 * epoch + n)
                writer.add_scalar('validation errG_loss', loss_eval.item(),
                                  27 * epoch + n)

                for name, param in globalD.named_parameters():
                    writer.add_histogram(name,
                                         param.clone().cpu().data.numpy(),
                                         27 * epoch + n)
                for name, param in localD.named_parameters():
                    writer.add_histogram(name,
                                         param.clone().cpu().data.numpy(),
                                         27 * epoch + n)
                for name, param in netG.named_parameters():
                    writer.add_histogram(name,
                                         param.clone().cpu().data.numpy(),
                                         27 * epoch + n)

        schedulerG.step()
        schedulerD.step()
        #torch.save(netG.state_dict(), '%s/com_model_G_%f_%d.pth' % (opt.outf, loss.item(), epoch))
        #torch.save(localD.state_dict(), '%s/com_model_localD_%f_%d.pth' % (opt.outf, errD.item(), epoch))
        #torch.save(globalD.state_dict(), '%s/com_model_globalD_%f_%d.pth' % (opt.outf, errD.item(), epoch))
    return errD.item(), errG.item(), chamferloss.item()
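
The completion loop above relies on `criterion = distChamfer` and unpacks its result as `dist1, dist2`, i.e. per-point nearest-neighbour distances in both directions. A naive dense sketch under that assumed interface (real implementations usually use a custom CUDA kernel for speed):

import torch

def dist_chamfer_naive(a, b):
    # a: (B, N, 3), b: (B, M, 3) point clouds.
    diff = a.unsqueeze(2) - b.unsqueeze(1)   # (B, N, M, 3) pairwise differences
    d = (diff ** 2).sum(dim=-1)              # (B, N, M) squared distances
    dist1 = d.min(dim=2).values              # (B, N): nearest point in b for each point in a
    dist2 = d.min(dim=1).values              # (B, M): nearest point in a for each point in b
    return dist1, dist2

The Chamfer loss used in the loop is then simply `torch.mean(dist1) + torch.mean(dist2)`.
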
Example 3
def our_main():
    from utils.show3d_balls import showpoints
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--batchSize', type=int, default=32, help='input batch size')
    parser.add_argument(
        '--num_points', type=int, default=2000, help='input batch size')
    parser.add_argument(
        '--workers', type=int, help='number of data loading workers', default=4)
    parser.add_argument(
        '--nepoch', type=int, default=250, help='number of epochs to train for')
    parser.add_argument('--outf', type=str, default='cls', help='output folder')
    parser.add_argument('--model', type=str, default='', help='model path')
    parser.add_argument('--dataset', type=str, required=True, help="dataset path")
    parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")
    parser.add_argument('--feature_transform', action='store_true', help="use feature transform")

    opt = parser.parse_args()
    print(opt)

    blue = lambda x: '\033[94m' + x + '\033[0m'

    opt.manualSeed = random.randint(1, 10000)  # fix seed
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if opt.dataset_type == 'shapenet':
        dataset = ShapeNetDataset(
            root=opt.dataset,
            classification=True,
            npoints=opt.num_points)

        test_dataset = ShapeNetDataset(
            root=opt.dataset,
            classification=True,
            split='test',
            npoints=opt.num_points,
            data_augmentation=False)
    elif opt.dataset_type == 'modelnet40':
        dataset = ModelNetDataset(
            root=opt.dataset,
            npoints=opt.num_points,
            split='trainval')

        test_dataset = ModelNetDataset(
            root=opt.dataset,
            split='test',
            npoints=opt.num_points,
            data_augmentation=False)
    else:
        exit('wrong dataset type')


    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batchSize,
        shuffle=True,
        num_workers=int(opt.workers))

    testdataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=opt.batchSize,
            shuffle=True,
            num_workers=int(opt.workers))

    print(len(dataset), len(test_dataset))
    num_classes = len(dataset.classes)
    print('classes', num_classes)

    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)

    if opt.model != '':
        classifier.load_state_dict(torch.load(opt.model))


    optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    classifier.cuda()

    num_batch = len(dataset) / opt.batchSize

    ## python train_classification.py --dataset ../dataset --nepoch=4 --dataset_type  shapenet
    for epoch in range(opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            points, target = data
            target = target[:, 0]
            showpoints(points[0].numpy())
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()
            pred, trans, trans_feat = classifier(points)
            loss = F.nll_loss(pred, target)
            if opt.feature_transform:
                loss += feature_transform_regularizer(trans_feat) * 0.001
            loss.backward()
            optimizer.step()
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(opt.batchSize)))

            if i % 10 == 0:
                j, data = next(enumerate(testdataloader, 0))
                points, target = data
                target = target[:, 0]
                points = points.transpose(2, 1)
                points, target = points.cuda(), target.cuda()
                classifier = classifier.eval()
                pred, _, _ = classifier(points)
                loss = F.nll_loss(pred, target)
                pred_choice = pred.data.max(1)[1]
                correct = pred_choice.eq(target.data).cpu().sum()
                print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize)))

        scheduler.step()  # step the LR schedule after the epoch's optimizer updates
        torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))

    total_correct = 0
    total_testset = 0
    for i,data in tqdm(enumerate(testdataloader, 0)):
        points, target = data
        target = target[:, 0]
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        classifier = classifier.eval()
        pred, _, _ = classifier(points)
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()
        total_correct += correct.item()
        total_testset += points.size()[0]

    print("final accuracy {}".format(total_correct / float(total_testset)))
Example 4
                            comp_pre_label_dist[i, 0] = pred_labels_dist_cls[
                                i, majority_pre_id]
                            comp_pre_label_dist[i, 1] = sum_[
                                i] - pred_labels_dist_cls[i, majority_pre_id]
                        elif class_id < pred_labels_dist_cls.shape[1]:
                            comp_pre_label_dist[i, 1] = pred_labels_dist_cls[
                                i, majority_pre_id]
                            comp_pre_label_dist[i, 0] = sum_[
                                i] - pred_labels_dist_cls[i, majority_pre_id]
                        else:
                            comp_pre_label_dist[i, 0] = 0
                            comp_pre_label_dist[i, 1] = sum_[i]

                    seg_loss += loss_r(comp_pre_label_dist, comp_gt_label)

        loss = feature_transform_regularizer(
            trans) * 0.001 + seg_loss / 2500 + num_obj_loss
        loss.backward()
        optimizer.step()
        # correct = pred_labels.eq(gt_labels.data).cpu().sum()
        correct = 0
        # total = 0
        pred_labels_cpu = pred_labels.cpu()
        gt_labels_cpu = gt_labels.cpu()
        gt_numobj_cpu, _ = torch.max(gt_labels_cpu, dim=1)
        pred_numobj_cpu = pred_numobj.cpu()
        for b in range(batch_size):
            pred_numobj_ = pred_numobj_cpu[b]
            pred_labels_ = pred_labels_cpu[b]
            gt_labels_ = gt_labels_cpu[b].T
            gt_numobj_ = gt_numobj_cpu[b]
            # print("1:",  pred_labels_.min(), pred_labels_.max())
Example 5
def PointNetSeg():
    classifier = point_net_seg(num_classes, feature_transform=opt.feature_transform)

    if opt.model != '':
        classifier.load_state_dict(torch.load(opt.model))  # if you have trained model params

    optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    classifier.cuda()  # model load to gpu

    num_batch = len(dataset) / opt.batch_size
    # save loss and acc:
    train_loss = {}
    test_loss = {}
    train_acc = {}
    test_acc = {}

    for epoch in range(opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            points, target = data
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()

            try:
                pred, trans, trans_feat = classifier(points)
            except RuntimeError as exception:
                if "out of memory" in str(exception):
                    print("WARNING: out of memory")
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                    continue  # skip this batch; pred/trans_feat are undefined after an OOM
                else:
                    raise exception

            # pred, trans, trans_feat = classifier(points)
            pred = pred.view(-1,  num_classes)  # [B*N, k]
            target = target.view(-1, 1)[:, 0] - 1  # ShapeNet's label is from 1 to k
            loss = F.nll_loss(pred, target)  # -x[class]
            if opt.feature_transform:
                loss += feature_transform_regularizer(trans_feat) * 0.001
            loss.backward()
            optimizer.step()
            pred_seg = pred.data.max(1)[1]  # [B*N, k] -> [B*N]; max(1) returns (values, indices)
            correct = pred_seg.eq(target.data).cpu().sum()
            print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch,
                                                               loss.item(),
                                                               correct.item()/float(opt.batch_size * 2500)))
            # add initial train loss and acc in first epoch:
            if epoch == 0 and i == 0:
                train_loss[epoch] = loss.item()
                train_acc[epoch] = correct.item() / float(opt.batch_size * 2500)

            if i%10 == 0:
                #  add train loss and acc in each epoch:
                if i+10 > num_batch:
                    train_loss[epoch+1] = loss.item()
                    train_acc[epoch+1] = correct.item() / float(opt.batch_size * 2500)

                j, data = next(enumerate(test_dataloader, 0))
                points, target = data
                points = points.transpose(2, 1)
                points, target = points.cuda(), target.cuda()
                classifier = classifier.eval()
                pred, _, _ = classifier(points)
                pred = pred.view(-1, num_classes)
                target = target.view(-1, 1)[:, 0] - 1
                loss = F.nll_loss(pred, target)
                pred_seg = pred.data.max(1)[1]
                correct = pred_seg.eq(target.data).cpu().sum()
                print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch,
                                                                blue('test'),
                                                                loss.item(),
                                                                correct.item() / float(opt.batch_size * 2500)))
                # add initial test loss and acc in first epoch:
                if epoch == 0 and i == 0:
                    test_loss[epoch] = loss.item()
                    test_acc[epoch] = correct.item() / float(opt.batch_size * 2500)
                # add test loss and acc in each epoch:
                if i+10 > num_batch:
                    test_loss[epoch+1] = loss.item()
                    test_acc[epoch+1] = correct.item() / float(opt.batch_size * 2500)

        scheduler.step()
        # save checkpoint every epoch:
        torch.save(classifier.state_dict(), '%s/seg_model_%s_%d.pth' % (opt.outf, opt.class_choice, epoch))

    # benchmark mIOU:
    shape_ious = []
    for i, data in tqdm(enumerate(test_dataloader, 0)):
        points, target = data
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        classifier = classifier.eval()
        pred, _, _ = classifier(points)
        pred_seg = pred.data.max(2)[1]  # [B, N, k] -> [B, N]; max(2) returns (values, indices)

        pred_np = pred_seg.cpu().data.numpy()  # [B, N]
        target_np = target.cpu().data.numpy() - 1

        for shape_idx in range(target_np.shape[0]):
            parts = range(num_classes)
            part_ious = []
            for part in parts:
                I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part))
                U = np.sum(np.logical_or(pred_np[shape_idx] == part, target_np[shape_idx] == part))
                if U == 0:
                    iou = 1  # if the union of ground truth and prediction points is empty, count the part IoU as 1
                else:
                    iou = I / U
                part_ious.append(iou)
            shape_ious.append(np.mean(part_ious))

    print('mIOU for class {}: {}'.format(opt.class_choice, np.mean(shape_ious)))
    return train_loss, test_loss, train_acc, test_acc
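
The mIoU benchmark at the end of the function computes, for every shape, the IoU of each part label and averages them. The same arithmetic as a small standalone helper (the names here are mine, not from the original script):

import numpy as np

def shape_miou(pred_np, target_np, num_parts):
    # pred_np, target_np: (N,) integer part labels for a single shape.
    part_ious = []
    for part in range(num_parts):
        p, t = pred_np == part, target_np == part
        union = np.logical_or(p, t).sum()
        # Part absent from both prediction and ground truth: count its IoU as 1.
        part_ious.append(1.0 if union == 0 else np.logical_and(p, t).sum() / union)
    return float(np.mean(part_ious))
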
Example 6
def PointNetCls():
    classifier = point_net_cls(k=num_classes, feature_transform=opt.feature_transform)  # default is False!

    if opt.model != '':
        classifier.load_state_dict(torch.load(opt.model))  # if you have trained model params

    optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    classifier.cuda()  # model load to gpu

    num_batch = len(dataset) / opt.batch_size
    # save loss and acc:
    train_loss = {}
    test_loss = {}
    train_acc = {}
    test_acc = {}

    for epoch in range(opt.nepoch):
        # scheduler.step()
        for i, data in enumerate(dataloader, 0):
            points, target = data
            target = target[:, 0]  # [B, 1]->>size([B])
            # print(target.shape)
            points = points.transpose(2, 1)  # [B=32, 3, N]
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()  # for training mode
            # pred, trans, trans_feat = classifier(points)
            try:
                pred, trans, trans_feat = classifier(points)
            except RuntimeError as exception:
                if "out of memory" in str(exception):
                    print("WARNING: out of memory")
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                    continue  # skip this batch; pred/trans_feat are undefined after an OOM
                else:
                    raise exception

            loss = F.nll_loss(pred, target)
            if opt.feature_transform:
                loss += feature_transform_regularizer(trans_feat) * 0.01
            loss.backward()
            optimizer.step()
            pred_cls = pred.data.max(1)[1]  # [B, k] -> [B]
            correct = pred_cls.eq(target.data).cpu().sum()  # num of correct predict in batch_i
            print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch,
                                                               loss.item(),
                                                               correct.item() / float(opt.batch_size)))
            # add initial train loss and acc in first epoch:
            if epoch == 0 and i == 0:
                train_loss[epoch] = loss.item()
                train_acc[epoch] = correct.item() / float(opt.batch_size)

            # show acc in one batch test_data every 10 batch_size:
            if i % 10 == 0:
                #  add train loss and acc in each epoch:
                if i+10 > num_batch:
                    train_loss[epoch+1] = loss.item()
                    train_acc[epoch+1] = correct.item() / float(opt.batch_size)

                j, data = next(enumerate(test_dataloader, 0))
                points, target = data
                target = target[:, 0]
                points = points.transpose(2, 1)
                points, target = points.cuda(), target.cuda()
                classifier = classifier.eval()  # for evaluation mode
                pred, _, _ = classifier(points)
                loss = F.nll_loss(pred, target)

                pred_cls = pred.data.max(1)[1]
                correct = pred_cls.eq(target.data).cpu().sum()
                print(correct.item(), opt.batch_size)
                print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'),
                                                                loss.item(),
                                                                correct.item() / float(opt.batch_size)))
                # add initial test loss and acc in first epoch:
                if epoch == 0 and i == 0:
                    test_loss[epoch] = loss.item()
                    test_acc[epoch] = correct.item() / float(opt.batch_size)
                # add test loss and acc in each epoch:
                if i+10 > num_batch:
                    test_loss[epoch+1] = loss.item()
                    test_acc[epoch+1] = correct.item() / float(opt.batch_size)

        scheduler.step()
        # save checkpoint every epoch:
        torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))  # default: 'cls/cls_model_0.pth'

    # calculate acc on whole test dataset:
    total_correct = 0
    total_testset = 0
    for i, data in tqdm(enumerate(test_dataloader, 0)):
        points, target = data
        target = target[:, 0]
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        classifier = classifier.eval()
        pred, _, _ = classifier(points)
        pred_cls = pred.data.max(1)[1]
        correct = pred_cls.eq(target.data).cpu().sum()
        total_correct += correct.item()
        total_testset += points.size()[0]  # add batch_size

    print('final accuracy {}'.format(total_correct / float(total_testset)))
    return train_loss, test_loss, train_acc, test_acc
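
The whole-test-set accuracy loops in these examples call `classifier.eval()` inside the batch loop and leave autograd enabled. A slightly tidier sketch of the same measurement, reusing the `classifier` and `test_dataloader` from the snippet above, with gradients disabled:

classifier.eval()
total_correct, total_seen = 0, 0
with torch.no_grad():
    for points, target in test_dataloader:
        target = target[:, 0]
        points, target = points.transpose(2, 1).cuda(), target.cuda()
        pred, _, _ = classifier(points)
        total_correct += (pred.data.max(1)[1] == target).sum().item()
        total_seen += points.size(0)
print('final accuracy {}'.format(total_correct / float(total_seen)))
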