Example #1
0
def main(args):
    """Train the two-class PointConv classifier on the fusion dataset.

    Creates a timestamped experiment directory with ``checkpoints`` and
    ``logs`` sub-directories, trains ``PointConvClsSsg`` with the optimizer
    selected by ``args.optimizer``, evaluates after every epoch and saves a
    checkpoint whenever the test accuracy matches or beats the best so far
    (never on epoch index 0).

    Args:
        args: Parsed command-line namespace. Reads ``gpu``, ``model_name``,
            ``batchsize``, ``pretrain``, ``optimizer``, ``learning_rate``,
            ``decay_rate``, ``epoch`` and ``train_metric``.

    Raises:
        ValueError: If ``args.optimizer`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    # --- Hyper parameters ---
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    datapath = "../../Dataset/fusion"

    # --- Create directories ---
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    file_dir = experiment_dir / ('%sSSGPD-%s' % (args.model_name, timestamp))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    # --- Logging ---
    # The namespace passed by the caller is used as-is; re-parsing the command
    # line here would silently discard it.
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Join with the path operator so the file lands inside ``logs/``; plain
    # string concatenation produced ``.../logstrain_<name>_cls.txt``.
    file_handler = logging.FileHandler(
        str(log_dir / ('train_%s_cls.txt' % args.model_name)))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(
        '---------------------------------------------------TRANING---------------------------------------------------'
    )
    logger.info('PARAMETER ...')
    logger.info(args)

    # --- Data loading ---
    trainDataset = OneViewDatasetLoader(grasp_points_num=1024,
                                        dataset_path=datapath,
                                        tag='train')
    testDataset = OneViewDatasetLoader(grasp_points_num=1024,
                                       dataset_path=datapath,
                                       tag='test')
    trainDataLoader = torch.utils.data.DataLoader(trainDataset,
                                                  batch_size=args.batchsize,
                                                  shuffle=True)
    testDataLoader = torch.utils.data.DataLoader(testDataset,
                                                 batch_size=args.batchsize,
                                                 shuffle=False)

    # --- Model loading ---
    num_class = 2
    classifier = PointConvClsSsg(num_class).cuda()
    if args.pretrain is not None:
        print('Use pretrain model...')
        logger.info('Use pretrain model')
        checkpoint = torch.load(args.pretrain)
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
    else:
        print('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unsupported optimizer: %r' % (args.optimizer,))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=30,
                                                gamma=0.5)
    global_epoch = 0
    global_step = 0
    best_tst_accuracy = 0.0
    blue = lambda x: '\033[94m' + x + '\033[0m'

    # --- Training ---
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        print('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        logger.info('Epoch %d (%d/%s):', global_epoch + 1, epoch + 1,
                    args.epoch)

        # Evaluation below switches the model to eval mode, so re-enable
        # training mode at the start of every epoch.
        classifier.train()
        for points, target in tqdm(trainDataLoader,
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points = points.float().cuda()
            # reshape(-1) instead of squeeze(): squeeze() would turn a
            # single-sample batch into a 0-d tensor, which nll_loss rejects.
            # Assumes one label per sample -- TODO confirm against the dataset.
            target = target.long().reshape(-1).cuda()

            optimizer.zero_grad()
            pred = classifier(points)
            loss = F.nll_loss(pred, target)
            loss.backward()
            optimizer.step()
            global_step += 1

        # Always evaluate in eval mode, whether or not the train metric ran.
        train_acc = test(classifier.eval(),
                         trainDataLoader) if args.train_metric else None
        acc = test(classifier.eval(), testDataLoader)

        # Update the best score before reporting so the log shows the
        # current best rather than the previous one.
        if (acc >= best_tst_accuracy) and epoch > 0:
            best_tst_accuracy = acc
            logger.info('Save model...')
            save_checkpoint(global_epoch + 1,
                            train_acc if args.train_metric else 0.0, acc,
                            classifier, optimizer, str(checkpoints_dir),
                            args.model_name)
            print('Saving model....')

        # Loss of the last training batch of this epoch.
        print('\r Loss: %f' % loss.item())
        logger.info('Loss: %.2f', loss.item())
        if args.train_metric:
            print('Train Accuracy: %f' % train_acc)
            logger.info('Train Accuracy: %f', (train_acc))
        print(
            '\r Test %s: %f   ***  %s: %f' %
            (blue('Accuracy'), acc, blue('Best Accuracy'), best_tst_accuracy))
        logger.info('Test Accuracy: %f  *** Best Test Accuracy: %f', acc,
                    best_tst_accuracy)

        # Step the LR schedule after the epoch's optimizer updates; stepping
        # first skips the initial learning rate on PyTorch >= 1.1.
        scheduler.step()
        global_epoch += 1
    print('Best Accuracy: %f' % best_tst_accuracy)

    logger.info('End of training...')
Example #2
0
def main(args):
    """Evaluate the combined "k-but-1" and binary classifiers on ModelNet.

    A 39-way classifier whose last output is a composite class is combined
    with a 2-way classifier that splits that composite class. For each batch
    the outputs of both models are summed over ``args.num_view`` generated
    views, merged into 40 scores, remapped to the dataset's label ids and
    compared with the targets. Prints overall and per-class accuracy.

    Args:
        args: Parsed command-line namespace. Reads ``gpu``, ``model_name``,
            ``batchsize``, ``kb1checkpoint``, ``binarycheckpoint`` and
            ``num_view``.

    Exits the process (status 0, as before) if either checkpoint is missing.
    """
    import shutil  # local import: only needed to archive the checkpoint

    # --- Hyper parameters ---
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    datapath = './data/ModelNet/'

    # --- Create directories ---
    experiment_dir = Path('./eval_experiment/')
    experiment_dir.mkdir(exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    file_dir = experiment_dir / ('%sModelNet40-%s' %
                                 (args.model_name, timestamp))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    # Copy without a shell (paths with spaces/metacharacters are safe) and
    # only when a checkpoint was actually supplied.
    if args.kb1checkpoint is not None:
        shutil.copy(args.kb1checkpoint, str(checkpoints_dir))
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    # --- Logging ---
    # The namespace passed by the caller is used as-is; re-parsing the command
    # line here would silently discard it.
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Join with the path operator so the file lands inside ``logs/``; plain
    # string concatenation produced ``.../logseval_<name>_cls.txt``.
    file_handler = logging.FileHandler(
        str(log_dir / ('eval_%s_cls.txt' % args.model_name)))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(
        '---------------------------------------------------EVAL---------------------------------------------------'
    )
    logger.info('PARAMETER ...')
    logger.info(args)

    # --- Data loading ---
    logger.info('Load dataset ...')
    train_data, train_label, test_data, test_label = load_data(
        datapath, classification=True)
    logger.info("The number of training data is: %d", train_data.shape[0])
    logger.info("The number of test data is: %d", test_data.shape[0])
    testDataset = ModelNetDataLoader(test_data, test_label)
    testDataLoader = torch.utils.data.DataLoader(testDataset,
                                                 batch_size=args.batchsize,
                                                 shuffle=False)

    # --- Model loading ---
    num_class = 39
    kb1classifier = PointConvClsSsg(num_class).cuda()
    if args.kb1checkpoint is None:
        print('Please load k but 1 Checkpoint to eval...')
        sys.exit(0)
    print('Load k but 1 CheckPoint...')
    logger.info('Load k but 1 CheckPoint')
    kb1checkpoint = torch.load(args.kb1checkpoint)
    kb1classifier.load_state_dict(kb1checkpoint['model_state_dict'])

    num_class1 = 2
    binaryclassifier = PointConvClsSsg(num_class1).cuda()
    if args.binarycheckpoint is None:
        print('Please load binary Checkpoint to eval...')
        sys.exit(0)
    print('Load binary CheckPoint...')
    logger.info('Load binary CheckPoint')
    binarycheckpoint = torch.load(args.binarycheckpoint)
    binaryclassifier.load_state_dict(binarycheckpoint['model_state_dict'])

    # Map merged prediction indices back to dataset label ids. Built once,
    # as a lookup tensor, instead of re-vectorising a dict for every batch.
    mapper_dict = {
        **{key: key + 1
           for key in range(12, 32)},
        **{key: key + 2
           for key in range(32, 38)},
        **{
            38: 33,
            39: 12
        }
    }
    num_merged = (num_class - 1) + num_class1
    label_lut = torch.tensor(
        [mapper_dict.get(idx, idx) for idx in range(num_merged)],
        dtype=torch.long).cuda()

    # --- Evaluation ---
    logger.info('Start evaluating...')
    print('Start evaluating...')

    kb1classifier.eval()
    binaryclassifier.eval()
    total_correct = 0
    total_seen = 0
    preds = []
    for pointcloud, target in tqdm(testDataLoader,
                                   total=len(testDataLoader),
                                   smoothing=0.9):
        target = target[:, 0].cuda()
        batch_size = pointcloud.shape[0]
        pred_view = torch.zeros(batch_size, num_class).cuda()
        binary_view = torch.zeros(batch_size, num_class1).cuda()

        # Sum both models' outputs over the generated views.
        for _ in range(args.num_view):
            pointcloud = generate_new_view(pointcloud)
            points = pointcloud.permute(0, 2, 1).cuda()
            with torch.no_grad():
                pred_view += kb1classifier(points)
                binary_view += binaryclassifier(points)

        # The composite class has the largest k-but-1 label: add its score to
        # both binary scores (broadcast over the two columns).
        binary_pred_logprob = pred_view[:, -1:] + binary_view
        # Concatenate to get scores for all (40) classes.
        pred_logprob = torch.cat((pred_view[:, :-1], binary_pred_logprob),
                                 dim=1)
        pred_choice = label_lut[pred_logprob.max(1)[1]]

        preds.append(pred_choice.cpu().numpy())
        total_correct += pred_choice.eq(target.long()).sum().item()
        # Count from the batch itself so this works even with zero views.
        total_seen += batch_size

    accuracy = total_correct / float(total_seen)

    # Per-class accuracy is the diagonal of the row-normalised confusion
    # matrix.
    cm = confusion_matrix(test_label.ravel(), np.concatenate(preds).ravel())
    cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    t = pd.read_table('data/ModelNet/shape_names.txt', names=['label'])
    d = dict(zip(t.label, cm.diagonal()))
    print('Total Accuracy: %f' % accuracy)
    print('Accuracy per class:', d)

    logger.info('Total Accuracy: %f' % accuracy)
    logger.info('End of evaluation...')
def main(args):
    """Train the 40-class PointConv classifier on ModelNet40.

    Creates a timestamped experiment directory with ``checkpoints`` and
    ``logs`` sub-directories, trains ``PointConvClsSsg`` on augmented point
    clouds, evaluates after every epoch and saves a checkpoint whenever the
    test accuracy matches or beats the best so far (only after epoch index 5).

    Args:
        args: Parsed command-line namespace. Reads ``gpu``, ``model_name``,
            ``num_point``, ``normal``, ``batchsize``, ``num_workers``,
            ``pretrain``, ``optimizer``, ``learning_rate``, ``decay_rate``
            and ``epoch``.

    Raises:
        ValueError: If ``args.optimizer`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    # --- Hyper parameters ---
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # --- Create directories ---
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    file_dir = experiment_dir / ('%s_ModelNet40-%s' % (args.model_name, timestamp))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    # --- Logging ---
    # The namespace passed by the caller is used as-is; re-parsing the command
    # line here would silently discard it.
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Join with the path operator so the file lands inside ``logs/``; plain
    # string concatenation produced ``.../logstrain_<name>_cls.txt``.
    file_handler = logging.FileHandler(str(log_dir / ('train_%s_cls.txt' % args.model_name)))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(
        '---------------------------------------------------TRANING---------------------------------------------------')
    logger.info('PARAMETER ...')
    logger.info(args)

    # --- Data loading ---
    logger.info('Load dataset ...')
    DATA_PATH = './data/modelnet40_normal_resampled/'

    TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='train', normal_channel=args.normal)
    TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test', normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batchsize, shuffle=True,
                                                  num_workers=args.num_workers)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batchsize, shuffle=False,
                                                 num_workers=args.num_workers)

    logger.info("The number of training data is: %d", len(TRAIN_DATASET))
    logger.info("The number of test data is: %d", len(TEST_DATASET))

    seed = 3
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # --- Model loading ---
    num_class = 40
    classifier = PointConvClsSsg(num_class).cuda()
    if args.pretrain is not None:
        print('Use pretrain model...')
        logger.info('Use pretrain model')
        checkpoint = torch.load(args.pretrain)
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
    else:
        print('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unsupported optimizer: %r' % (args.optimizer,))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_tst_accuracy = 0.0
    blue = lambda x: '\033[94m' + x + '\033[0m'

    # --- Training ---
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        print('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        logger.info('Epoch %d (%d/%s):', global_epoch + 1, epoch + 1, args.epoch)
        # Count samples rather than averaging per-batch accuracies, so a
        # smaller final batch does not skew the epoch's train accuracy.
        total_correct = 0
        total_seen = 0

        classifier.train()
        for points, target in tqdm(trainDataLoader, total=len(trainDataLoader), smoothing=0.9):
            points = points.data.numpy()
            # Augment the data: randomly scale and shift the point cloud,
            # then randomly drop some points.
            jittered_data = provider.random_scale_point_cloud(points[:, :, 0:3], scale_low=2.0 / 3, scale_high=3 / 2.0)
            jittered_data = provider.shift_point_cloud(jittered_data, shift_range=0.2)
            points[:, :, 0:3] = jittered_data
            points = provider.random_point_dropout_v2(points)
            # NOTE(review): the return value is discarded; if shuffle_points
            # returns a new array instead of shuffling in place, this call has
            # no effect -- confirm against the provider module.
            provider.shuffle_points(points)
            points = torch.Tensor(points)
            target = target[:, 0]

            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()

            # Only the first three channels are fed to the network; no extra
            # per-point features are passed.
            pred = classifier(points[:, :3, :], None)
            loss = F.nll_loss(pred, target.long())
            pred_choice = pred.data.max(1)[1]
            total_correct += pred_choice.eq(target.long().data).cpu().sum().item()
            total_seen += points.size()[0]
            loss.backward()
            optimizer.step()
            global_step += 1

        train_acc = total_correct / float(total_seen) if total_seen else float('nan')
        print('Train Accuracy: %f' % train_acc)
        logger.info('Train Accuracy: %f' % train_acc)

        acc = test(classifier, testDataLoader)

        if (acc >= best_tst_accuracy) and epoch > 5:
            best_tst_accuracy = acc
            logger.info('Save model...')
            save_checkpoint(
                global_epoch + 1,
                train_acc,
                acc,
                classifier,
                optimizer,
                str(checkpoints_dir),
                args.model_name)
            print('Saving model....')

        # Loss of the last training batch of this epoch.
        print('\r Loss: %f' % loss.item())
        logger.info('Loss: %.2f', loss.item())
        print('\r Test %s: %f   ***  %s: %f' % (blue('Accuracy'), acc, blue('Best Accuracy'), best_tst_accuracy))
        logger.info('Test Accuracy: %f  *** Best Test Accuracy: %f', acc, best_tst_accuracy)

        # Step the LR schedule after the epoch's optimizer updates; stepping
        # first skips the initial learning rate on PyTorch >= 1.1.
        scheduler.step()
        global_epoch += 1
    print('Best Accuracy: %f' % best_tst_accuracy)

    logger.info('End of training...')
def main(args):
    """Evaluate a 39-class PointConv checkpoint on the ModelNet test split.

    For each batch the classifier output is summed over ``args.num_view``
    generated views and the arg-max is compared with the targets. Prints and
    logs the overall accuracy.

    Args:
        args: Parsed command-line namespace. Reads ``gpu``, ``model_name``,
            ``batchsize``, ``checkpoint`` and ``num_view``.

    Exits the process (status 0, as before) if no checkpoint is given.
    """
    import shutil  # local import: only needed to archive the checkpoint

    # --- Hyper parameters ---
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    datapath = './data/ModelNet/'

    # --- Create directories ---
    experiment_dir = Path('./eval_experiment/')
    experiment_dir.mkdir(exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    file_dir = experiment_dir / ('%sModelNet40-%s' %
                                 (args.model_name, timestamp))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    # Copy without a shell (paths with spaces/metacharacters are safe) and
    # only when a checkpoint was actually supplied.
    if args.checkpoint is not None:
        shutil.copy(args.checkpoint, str(checkpoints_dir))
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    # --- Logging ---
    # The namespace passed by the caller is used as-is; re-parsing the command
    # line here would silently discard it.
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Join with the path operator so the file lands inside ``logs/``; plain
    # string concatenation produced ``.../logseval_<name>_cls.txt``.
    file_handler = logging.FileHandler(
        str(log_dir / ('eval_%s_cls.txt' % args.model_name)))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(
        '---------------------------------------------------EVAL---------------------------------------------------'
    )
    logger.info('PARAMETER ...')
    logger.info(args)

    # --- Data loading ---
    logger.info('Load dataset ...')
    train_data, train_label, test_data, test_label = load_data(
        datapath, classification=True)
    logger.info("The number of training data is: %d", train_data.shape[0])
    logger.info("The number of test data is: %d", test_data.shape[0])
    testDataset = ModelNetDataLoader(test_data, test_label)
    testDataLoader = torch.utils.data.DataLoader(testDataset,
                                                 batch_size=args.batchsize,
                                                 shuffle=False)

    # --- Model loading ---
    num_class = 39
    classifier = PointConvClsSsg(num_class).cuda()
    if args.checkpoint is None:
        print('Please load Checkpoint to eval...')
        sys.exit(0)
    print('Load CheckPoint...')
    logger.info('Load CheckPoint')
    checkpoint = torch.load(args.checkpoint)
    classifier.load_state_dict(checkpoint['model_state_dict'])

    # --- Evaluation ---
    logger.info('Start evaluating...')
    print('Start evaluating...')

    classifier.eval()
    total_correct = 0
    total_seen = 0
    for pointcloud, target in tqdm(testDataLoader,
                                   total=len(testDataLoader),
                                   smoothing=0.9):
        target = target[:, 0].cuda()
        batch_size = pointcloud.shape[0]
        pred_view = torch.zeros(batch_size, num_class).cuda()

        # Sum the classifier output over the generated views.
        for _ in range(args.num_view):
            pointcloud = generate_new_view(pointcloud)
            points = pointcloud.permute(0, 2, 1).cuda()
            with torch.no_grad():
                pred_view += classifier(points)

        pred_choice = pred_view.max(1)[1]
        total_correct += pred_choice.eq(target.long()).sum().item()
        # Count from the batch itself so this works even with zero views.
        total_seen += batch_size

    accuracy = total_correct / float(total_seen)
    print('Total Accuracy: %f' % accuracy)

    logger.info('Total Accuracy: %f' % accuracy)
    logger.info('End of evaluation...')
Example #5
0
def main(args):
    """Evaluate a 40-class PointConv checkpoint on the ModelNet40 test split.

    Loads the checkpoint, runs one forward pass per test batch (first three
    channels as coordinates, the remaining channels as features) and reports
    the fraction of correctly classified samples.

    Args:
        args: Parsed command-line namespace. Reads ``gpu``, ``model_name``,
            ``num_point``, ``normal``, ``batchsize``, ``num_workers`` and
            ``checkpoint``.

    Exits the process (status 0, as before) if no checkpoint is given.
    """
    import shutil  # local import: only needed to archive the checkpoint

    # --- Hyper parameters ---
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # --- Create directories ---
    experiment_dir = Path('./eval_experiment/')
    experiment_dir.mkdir(exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    file_dir = experiment_dir / ('%s_ModelNet40-%s' % (args.model_name, timestamp))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    # Copy without a shell (paths with spaces/metacharacters are safe) and
    # only when a checkpoint was actually supplied.
    if args.checkpoint is not None:
        shutil.copy(args.checkpoint, str(checkpoints_dir))
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    # --- Logging ---
    # The namespace passed by the caller is used as-is; re-parsing the command
    # line here would silently discard it.
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Join with the path operator so the file lands inside ``logs/``; plain
    # string concatenation produced ``.../logseval_<name>_cls.txt``.
    file_handler = logging.FileHandler(str(log_dir / ('eval_%s_cls.txt' % args.model_name)))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info('---------------------------------------------------EVAL---------------------------------------------------')
    logger.info('PARAMETER ...')
    logger.info(args)

    # --- Data loading ---
    logger.info('Load dataset ...')
    DATA_PATH = './data/modelnet40_normal_resampled/'

    TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test', normal_channel=args.normal)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batchsize, shuffle=False, num_workers=args.num_workers)
    logger.info("The number of test data is: %d", len(TEST_DATASET))

    seed = 3
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # --- Model loading ---
    num_class = 40
    classifier = PointConvClsSsg(num_class).cuda()
    if args.checkpoint is None:
        print('Please load Checkpoint to eval...')
        sys.exit(0)
    print('Load CheckPoint...')
    logger.info('Load CheckPoint')
    checkpoint = torch.load(args.checkpoint)
    classifier.load_state_dict(checkpoint['model_state_dict'])

    # --- Evaluation ---
    logger.info('Start evaluating...')
    print('Start evaluating...')

    classifier = classifier.eval()
    # Count samples rather than averaging per-batch accuracies: a smaller
    # final batch would otherwise be over-weighted in the reported total.
    total_correct = 0
    total_seen = 0
    for pointcloud, target in tqdm(testDataLoader, total=len(testDataLoader), smoothing=0.9):
        target = target[:, 0]

        points = pointcloud.permute(0, 2, 1)
        points, target = points.cuda(), target.cuda()
        with torch.no_grad():
            pred = classifier(points[:, :3, :], points[:, 3:, :])
        pred_choice = pred.data.max(1)[1]
        total_correct += pred_choice.eq(target.long().data).cpu().sum().item()
        total_seen += points.size()[0]

    # An empty loader yields NaN, matching the old np.mean([]) result.
    accuracy = total_correct / float(total_seen) if total_seen else float('nan')
    print('Total Accuracy: %f'%accuracy)

    logger.info('Total Accuracy: %f'%accuracy)
    logger.info('End of evaluation...')
Example #6
0
def main(args):
    """Train the 40-class PointConv classifier on ModelNet.

    Creates a timestamped experiment directory with ``checkpoints`` and
    ``logs`` sub-directories, trains ``PointConvClsSsg`` with the optimizer
    selected by ``args.optimizer``, evaluates after every epoch and saves a
    checkpoint whenever the test accuracy matches or beats the best so far
    (only after epoch index 5).

    Args:
        args: Parsed command-line namespace. Reads ``gpu``, ``model_name``,
            ``batchsize``, ``pretrain``, ``optimizer``, ``learning_rate``,
            ``decay_rate``, ``epoch`` and ``train_metric``.

    Raises:
        ValueError: If ``args.optimizer`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    # --- Hyper parameters ---
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    datapath = './data/ModelNet/'

    # --- Create directories ---
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    file_dir = experiment_dir / ('%sModelNet40-%s' %
                                 (args.model_name, timestamp))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    # --- Logging ---
    # The namespace passed by the caller is used as-is; re-parsing the command
    # line here would silently discard it.
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Join with the path operator so the file lands inside ``logs/``; plain
    # string concatenation produced ``.../logstrain_<name>_cls.txt``.
    file_handler = logging.FileHandler(
        str(log_dir / ('train_%s_cls.txt' % args.model_name)))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info(
        '---------------------------------------------------TRANING---------------------------------------------------'
    )
    logger.info('PARAMETER ...')
    logger.info(args)

    # --- Data loading ---
    logger.info('Load dataset ...')
    train_data, train_label, test_data, test_label = load_data(
        datapath, classification=True)
    logger.info("The number of training data is: %d", train_data.shape[0])
    logger.info("The number of test data is: %d", test_data.shape[0])
    trainDataset = ModelNetDataLoader(train_data, train_label)
    testDataset = ModelNetDataLoader(test_data, test_label)
    trainDataLoader = torch.utils.data.DataLoader(trainDataset,
                                                  batch_size=args.batchsize,
                                                  shuffle=True)
    testDataLoader = torch.utils.data.DataLoader(testDataset,
                                                 batch_size=args.batchsize,
                                                 shuffle=False)

    # --- Model loading ---
    num_class = 40
    classifier = PointConvClsSsg(num_class).cuda()
    if args.pretrain is not None:
        print('Use pretrain model...')
        logger.info('Use pretrain model')
        checkpoint = torch.load(args.pretrain)
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
    else:
        print('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=0.01,
                                    momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError('Unsupported optimizer: %r' % (args.optimizer,))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=30,
                                                gamma=0.5)
    global_epoch = 0
    global_step = 0
    best_tst_accuracy = 0.0
    blue = lambda x: '\033[94m' + x + '\033[0m'

    # --- Training ---
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        print('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        logger.info('Epoch %d (%d/%s):', global_epoch + 1, epoch + 1,
                    args.epoch)

        # Evaluation below switches the model to eval mode, so re-enable
        # training mode at the start of every epoch.
        classifier.train()
        for points, target in tqdm(trainDataLoader,
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()

            optimizer.zero_grad()
            pred = classifier(points)
            loss = F.nll_loss(pred, target.long())
            loss.backward()
            optimizer.step()
            global_step += 1

        # Always evaluate in eval mode, whether or not the train metric ran.
        train_acc = test(classifier.eval(), trainDataLoader,
                         False) if args.train_metric else None
        acc = test(classifier.eval(), testDataLoader, False)

        # Update the best score before reporting so the log shows the
        # current best rather than the previous one.
        if (acc >= best_tst_accuracy) and epoch > 5:
            best_tst_accuracy = acc
            logger.info('Save model...')
            save_checkpoint(global_epoch + 1,
                            train_acc if args.train_metric else 0.0, acc,
                            classifier, optimizer, str(checkpoints_dir),
                            args.model_name)
            print('Saving model....')

        # Loss of the last training batch of this epoch.
        print('\r Loss: %f' % loss.item())
        logger.info('Loss: %.2f', loss.item())
        if args.train_metric:
            print('Train Accuracy: %f' % train_acc)
            logger.info('Train Accuracy: %f', (train_acc))
        print(
            '\r Test %s: %f   ***  %s: %f' %
            (blue('Accuracy'), acc, blue('Best Accuracy'), best_tst_accuracy))
        logger.info('Test Accuracy: %f  *** Best Test Accuracy: %f', acc,
                    best_tst_accuracy)

        # Step the LR schedule after the epoch's optimizer updates; stepping
        # first skips the initial learning rate on PyTorch >= 1.1.
        scheduler.step()
        global_epoch += 1
    print('Best Accuracy: %f' % best_tst_accuracy)

    logger.info('End of training...')