def main(args):
    """Train a radar point-cloud classifier (PointNet or PointNet++).

    Reads train/test splits with ``read_dataset``, trains the selected
    network, logs scalars/figures/PR curves to TensorBoard and, every 10
    epochs, evaluates on both splits and overwrites the checkpoint
    ``<projdir>/experiment/checkpoints/<exp_name>.pth``. If that checkpoint
    already exists, training resumes from the epoch after the stored one.

    Args:
        args: parsed command-line namespace. Attributes read here: ``cuda``,
            ``gpudevice``, ``exp_name``, ``datapath``, ``batchsize``,
            ``num_workers``, ``model_name`` ('pointnet' or 'pointnet2'),
            ``pointCoordDim``, ``numclasses``, ``feature_transform``,
            ``optimizer`` ('SGD' or 'ADAM'), ``lr``, ``decay_rate``,
            ``lr_epoch_half``, ``epoch`` and ``train_metric``.

    Raises:
        Exception: if ``args.model_name`` is not a known network.
        ValueError: if ``args.optimizer`` is neither 'SGD' nor 'ADAM'.
    """
    # --- SELECT DEVICES ---
    # Select either gpu or cpu.
    device = torch.device("cuda" if args.cuda else "cpu")
    # Select among available GPUs. CUDA reads CUDA_VISIBLE_DEVICES when it
    # initialises, so this has to happen before the first CUDA call.
    if args.cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
            str(x) for x in args.gpudevice)

    # --- CREATE EXPERIMENTS DIRECTORY AND LOGGERS IN TENSORBOARD ---
    projdir = sys.path[0]
    # Path for saving and loading the network. Path components are passed
    # separately: a single 'experiment\\checkpoints' string is only a nested
    # path on Windows and becomes one oddly named directory elsewhere.
    saveloadpath = os.path.join(projdir, 'experiment', 'checkpoints',
                                args.exp_name + '.pth')
    Path(os.path.dirname(saveloadpath)).mkdir(exist_ok=True, parents=True)
    tblogdir = os.path.join(projdir, 'experiment', 'tensorboardX',
                            args.exp_name)
    Path(tblogdir).mkdir(exist_ok=True, parents=True)
    # TensorBoard writer; flush_secs is how many seconds pending events may
    # wait before being written to disk.
    tb_writer = SummaryWriter(logdir=tblogdir,
                              flush_secs=3,
                              write_to_disk=True)
    try:
        # --- INIT DATASETS AND DATALOADER (FOR SINGLE EPOCH) ---
        # Ts is the frame period: the radar records one frame every 82 ms
        # (kept for sequence models such as an LSTM).
        train_dataset, test_dataset, class_names = read_dataset(
            args.datapath, Ts=0.082, train_test_split=0.8)

        # Test pipeline: NO data augmentation; frames are resampled with
        # maxPointsPerFrame=10.
        test_dataTransformations = transforms.Compose(
            [NormalizeTime(), Resampling(maxPointsPerFrame=10)])
        testDataset = RadarClassDataset(dataset=test_dataset,
                                        transforms=test_dataTransformations,
                                        sequence_length=1)
        # Train pipeline: the test pipeline plus DataAugmentation.
        train_dataTransformations = transforms.Compose([
            NormalizeTime(),
            DataAugmentation(),
            Resampling(maxPointsPerFrame=10)
        ])
        trainDataset = RadarClassDataset(dataset=train_dataset,
                                         transforms=train_dataTransformations,
                                         sequence_length=1)
        trainDataLoader = DataLoader(trainDataset,
                                     batch_size=args.batchsize,
                                     shuffle=True,
                                     num_workers=args.num_workers)

        # --- INIT NETWORK MODEL ---
        if args.model_name == 'pointnet':
            classifier = PointNetCls(dim=args.pointCoordDim,
                                     num_class=args.numclasses,
                                     feature_transform=args.feature_transform)
        elif args.model_name == 'pointnet2':
            classifier = PointNet2ClsMsg(
                dim=args.pointCoordDim,
                num_class=args.numclasses,
            )
        else:
            raise Exception(
                'Argument "model_name" does not match existent networks')
        classifier = classifier.to(device)

        # --- LOAD NETWORK IF EXISTS ---
        if os.path.exists(saveloadpath):
            print('Using pretrained model found...')
            # map_location lets a checkpoint written on GPU be resumed on CPU.
            checkpoint = torch.load(saveloadpath, map_location=device)
            # Continue with the epoch after the one stored in the checkpoint.
            start_epoch = checkpoint['epoch'] + 1
            iteration = checkpoint['iteration']
            best_test_acc = checkpoint['test_accuracy']
            classifier.load_state_dict(checkpoint['model_state_dict'])
        else:
            print('No existing model, starting training from scratch...')
            # Epochs and iterations are counted from 1.
            start_epoch = 1
            iteration = 1
            best_test_acc = 0

        # --- CREATE OPTIMIZER ---
        # NOTE(review): the optimizer state stored in the checkpoint is not
        # restored on resume; momentum / Adam moments start fresh.
        if args.optimizer == 'SGD':
            optimizer = torch.optim.SGD(classifier.parameters(),
                                        lr=args.lr,
                                        momentum=0.9)
        elif args.optimizer == 'ADAM':
            optimizer = torch.optim.Adam(classifier.parameters(),
                                         lr=args.lr,
                                         betas=(0.9, 0.999),
                                         eps=1e-08,
                                         weight_decay=args.decay_rate)
        else:
            # Fail fast instead of a NameError on the scheduler line below.
            raise ValueError(
                'Argument "optimizer" must be "SGD" or "ADAM", got {!r}'.format(
                    args.optimizer))
        # Halve (gamma=0.5) the learning rate every args.lr_epoch_half epochs.
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.lr_epoch_half, gamma=0.5)

        # Log the run configuration and dataset sizes.
        printparams = 'Model parameters:' + json.dumps(
            vars(args), indent=4, sort_keys=True)
        print(printparams)
        tb_writer.add_text('hyper-parameters', printparams, iteration)
        tb_writer.add_text(
            'dataset', 'dataset sample size: training: {}, test: {}'.format(
                train_dataset.shape[0], test_dataset.shape[0]), iteration)

        # --- START TRAINING ---
        for epoch in range(start_epoch, args.epoch + 1):
            print('Epoch %d/%s:' % (epoch, args.epoch))

            # Log the learning rate used for this epoch.
            tb_writer.add_scalar('learning_rate',
                                 optimizer.param_groups[0]['lr'], iteration)

            for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                       total=len(trainDataLoader),
                                       smoothing=0.9):
                # (B:batch x S:seq x C:features x N:points), (B x S:seq)
                points, target = data
                # Drop the sequence dimension (sequence_length == 1), cast to
                # float and move to the device: (B x C x N), (B).
                points = points.squeeze(dim=1).float().to(device)
                target = target.float().to(device)
                optimizer.zero_grad()
                classifier = classifier.train()
                pred = classifier(points)
                # The networks output log_softmax, so nll_loss here is the
                # cross-entropy loss.
                loss = F.nll_loss(pred, target.long())
                if args.model_name == 'pointnet':
                    # PointNet transform regularisation terms.
                    loss += feature_transform_regularizer(
                        classifier.trans) * 0.001
                    if args.feature_transform:
                        loss += feature_transform_regularizer(
                            classifier.trans_feat) * 0.001
                loss.backward()
                optimizer.step()
                # Log the training loss once every 5 batches.
                if not batch_id % 5:
                    tb_writer.add_scalar('train_loss/cross_entropy',
                                         loss.item(), iteration)
                iteration += 1

            # Step the LR schedule after this epoch's optimizer updates.
            scheduler.step()

            # --- TEST AND SAVE NETWORK (every 10 epochs) ---
            if not epoch % 10:
                train_targ, train_pred = test(classifier,
                                              trainDataset,
                                              device,
                                              num_workers=args.num_workers,
                                              batch_size=1800)
                test_targ, test_pred = test(classifier,
                                            testDataset,
                                            device,
                                            num_workers=args.num_workers,
                                            batch_size=1800)

                train_acc = metrics_accuracy(train_targ, train_pred)
                test_acc = metrics_accuracy(test_targ, test_pred)
                print('\r Training loss: {}'.format(loss.item()))
                print('Train Accuracy: {}\nTest Accuracy: {}'.format(
                    train_acc, test_acc))
                tb_writer.add_scalars('metrics/accuracy', {
                    'train': train_acc,
                    'test': test_acc
                }, iteration)

                # Test confusion matrix, logged as absolute and normalized
                # figures.
                confmatrix_test = metrics_confusion_matrix(test_targ, test_pred)
                print('Test confusion matrix: \n', confmatrix_test)
                fig, _ = plot_confusion_matrix(confmatrix_test,
                                               class_names,
                                               normalize=False,
                                               title='Test Confusion Matrix')
                fig_n, _ = plot_confusion_matrix(
                    confmatrix_test,
                    class_names,
                    normalize=True,
                    title='Test Confusion Matrix - Normalized')
                tb_writer.add_figure('test_confusion_matrix/abs',
                                     fig,
                                     global_step=iteration,
                                     close=True)
                tb_writer.add_figure('test_confusion_matrix/norm',
                                     fig_n,
                                     global_step=iteration,
                                     close=True)

                # One-vs-rest precision/recall curve per class; exp() turns
                # the log_softmax outputs back into probabilities.
                for idx, clsname in enumerate(class_names):
                    test_pred_binary = torch.exp(test_pred[:, idx])
                    test_targ_binary = test_targ.eq(idx)
                    tb_writer.add_pr_curve(tag='pr_curves/' + clsname,
                                           labels=test_targ_binary,
                                           predictions=test_pred_binary,
                                           global_step=iteration)

                # --- SAVE NETWORK ---
                # Saved on every evaluation, not only on improvement, so
                # best_test_acc holds the most recent test accuracy.
                best_test_acc = test_acc
                state = {
                    'epoch': epoch,
                    'iteration': iteration,
                    'train_accuracy': train_acc if args.train_metric else 0.0,
                    'test_accuracy': best_test_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, saveloadpath)
                print('Model saved!!!')

        print('Best Accuracy: %f' % best_test_acc)
    finally:
        # Flush and close the event file even if training raises.
        tb_writer.close()
Example #2
0
def main(args):
    """Train a PointNet / PointNet++ classifier on the 20-class ObjectNN data.

    Creates a timestamped experiment directory under ``./experiment/`` with
    ``checkpoints/`` and ``logs/`` sub-directories, logs progress to a text
    file, evaluates on the test split after every epoch and saves a
    checkpoint whenever test accuracy matches or beats the best so far
    (only once ``epoch > 5``).

    Args:
        args: parsed command-line namespace. Attributes read here: ``gpu``,
            ``rotation`` (``None`` or a string whose characters [0:2] and
            [3:5] hold the two integer bounds), ``model_name``,
            ``batchsize``, ``feature_transform``, ``pretrain``,
            ``optimizer`` ('SGD' or 'Adam'), ``learning_rate``,
            ``decay_rate``, ``epoch`` and ``train_metric``.

    Raises:
        ValueError: if ``args.optimizer`` is neither 'SGD' nor 'Adam'.
    """
    # --- HYPER PARAMETERS ---
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    datapath = './data/objecnn20_data_hdf5_2048/'
    if args.rotation is not None:
        ROTATION = (int(args.rotation[0:2]), int(args.rotation[3:5]))
    else:
        ROTATION = None

    # --- CREATE DIR ---
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    file_dir = Path(str(experiment_dir) + '/%sObjectNNClf-' % args.model_name
                    + timestamp)
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    # --- LOG ---
    # Use the ``args`` passed in; re-parsing the command line here would
    # silently discard a caller-supplied namespace.
    logger = logging.getLogger(args.model_name)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(str(log_dir) + '/train_%s_ObjectNNClf.txt' % args.model_name)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info('---------------------------------------------------TRANING---------------------------------------------------')
    logger.info('PARAMETER ...')
    logger.info(args)

    # --- DATA LOADING ---
    logger.info('Load dataset ...')
    train_data, train_label, test_data, test_label = load_data(datapath, classification=True)
    logger.info("The number of training data is: %d", train_data.shape[0])
    logger.info("The number of test data is: %d", test_data.shape[0])
    trainDataset = ObjectNNDataLoader(train_data, train_label, rotation=ROTATION)
    if ROTATION is not None:
        print('The range of training rotation is', ROTATION)
    testDataset = ObjectNNDataLoader(test_data, test_label, rotation=ROTATION)

    trainDataLoader = torch.utils.data.DataLoader(trainDataset, batch_size=args.batchsize, shuffle=True)
    testDataLoader = torch.utils.data.DataLoader(testDataset, batch_size=args.batchsize, shuffle=False)

    # --- MODEL LOADING ---
    num_class = 20
    # NOTE(review): PointNet2ClsMsg is built without num_class; confirm its
    # default matches the 20 ObjectNN classes.
    if args.model_name == 'pointnet':
        classifier = PointNetCls(num_class, args.feature_transform).cuda()
    else:
        classifier = PointNet2ClsMsg().cuda()
    if args.pretrain is not None:
        print('Use pretrain model...')
        logger.info('Use pretrain model')
        checkpoint = torch.load(args.pretrain)
        start_epoch = checkpoint['epoch']
        # Load only the checkpoint entries whose keys exist in this model.
        model_dict = classifier.state_dict()
        pretrained_dict = {k: v for k, v in checkpoint['model_state_dict'].items() if k in model_dict}
        model_dict.update(pretrained_dict)
        classifier.load_state_dict(model_dict)
    else:
        print('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'SGD':
        # NOTE(review): SGD uses a fixed lr of 0.01, not args.learning_rate.
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        # Fail fast instead of a NameError on the scheduler line below.
        raise ValueError('Argument "optimizer" must be "SGD" or "Adam", got %r' % (args.optimizer,))
    # Adjust the learning rate: halve it every 20 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    global_epoch = 0
    global_step = 0
    best_tst_accuracy = 0.0
    blue = lambda x: '\033[94m' + x + '\033[0m'

    # --- TRAINING ---
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        print('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        logger.info('Epoch %d (%d/%s):', global_epoch + 1, epoch + 1, args.epoch)

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
            points, target = data
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            classifier = classifier.train()
            pred, trans_feat, global_feature = classifier(points)
            loss = F.nll_loss(pred, target.long())
            if args.feature_transform and args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001
            loss.backward()
            optimizer.step()
            global_step += 1

        # Step the LR schedule after this epoch's optimizer.step() calls;
        # stepping it first is the pre-1.1 order and shifts the schedule.
        scheduler.step()

        # Evaluate in eval mode whether or not train accuracy is requested.
        classifier = classifier.eval()
        train_acc = test(classifier, trainDataLoader) if args.train_metric else None
        acc = test(classifier, testDataLoader)

        print('\r Loss: %f' % loss.item())
        logger.info('Loss: %.2f', loss.item())
        if args.train_metric:
            print('Train Accuracy: %f' % train_acc)
            logger.info('Train Accuracy: %f', (train_acc))
        print('\r Test %s: %f' % (blue('Accuracy'), acc))
        logger.info('Test Accuracy: %f', acc)

        if (acc >= best_tst_accuracy) and epoch > 5:
            best_tst_accuracy = acc
            logger.info('Save model...')
            # Store the index of the next epoch to run. Resuming reads it back
            # as start_epoch, so it must be absolute, not relative to this run.
            # Assumes save_checkpoint stores its first argument as 'epoch'.
            save_checkpoint(
                epoch + 1,
                train_acc if args.train_metric else 0.0,
                acc,
                classifier,
                optimizer,
                str(checkpoints_dir),
                args.model_name)
            print('Saving model....')
        global_epoch += 1
    print('Best Accuracy: %f' % best_tst_accuracy)

    logger.info('End of training...')
Example #3
0
        model = nn.DataParallel(model, device_ids=device_id).cuda()
# Adam optimizer over all model parameters, learning rate taken from args.lr.
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Multiply the learning rate by 0.5 every 30 scheduler.step() calls.
scheduler = StepLR(optimizer, step_size=30, gamma=0.5)

def train(model, loader, epoch):
    print('we are training')
<<<<<<< Updated upstream
<<<<<<< HEAD
    optimizer.step()
=======
>>>>>>> a9fc64760b36adc06d7d772e52c04514aa96e708
=======
    optimizer.step()
>>>>>>> Stashed changes
    scheduler.step()
    model.train()
    torch.set_grad_enabled(True)
    correct = 0
    dataset_size = 0
    for batch_idx, (data, target) in enumerate(loader):
<<<<<<< Updated upstream
<<<<<<< HEAD
        #print("Now is the batch" + str(batch_idx))
=======
        print("Now is the batch" + str(batch_idx))
>>>>>>> a9fc64760b36adc06d7d772e52c04514aa96e708
=======
        #print("Now is the batch" + str(batch_idx))
>>>>>>> Stashed changes
        dataset_size += data.shape[0]
        data, target = data.float(), target.long().squeeze()
Example #4
0
def train(args):
    """Train a PointNet / PointNet++ classifier on the ModelNet40 data.

    Data is loaded from ``'experiment/data/modelnet40_ply_hdf5_2048/'``.
    After every epoch the model is evaluated on the test split. Whenever test
    accuracy matches or beats the best so far, the (DataParallel) state dict
    is saved to ``./experiment/clf/<model_name>/clf-<model>-<acc>-<epoch>.pth``.

    Args:
        args: namespace with ``model_name``, ``augment``, ``batch_size``,
            ``feature_transform``, ``gpu``, ``pretrain`` (``None`` or a
            checkpoint path in the naming scheme above), ``optimizer``
            ('SGD' or 'Adam'), ``learning_rate``, ``decay_rate`` and
            ``epoch``.

    Raises:
        ValueError: if ``args.optimizer`` is neither 'SGD' nor 'Adam'.
    """
    mkdir('./experiment/')
    checkpoints_dir = mkdir('./experiment/clf/%s/' % (args.model_name))
    train_data, train_label, test_data, test_label = load_data(
        'experiment/data/modelnet40_ply_hdf5_2048/')

    trainDataset = ModelNetDataLoader(train_data,
                                      train_label,
                                      data_augmentation=args.augment)
    trainDataLoader = DataLoader(trainDataset,
                                 batch_size=args.batch_size,
                                 shuffle=True)

    testDataset = ModelNetDataLoader(test_data, test_label)
    testDataLoader = torch.utils.data.DataLoader(testDataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False)

    log.info('Building Model', args.model_name)
    if args.model_name == 'pointnet':
        num_class = 40
        model = PointNetCls(num_class, args.feature_transform).cuda()
    else:
        model = PointNet2ClsMsg().cuda()

    torch.backends.cudnn.benchmark = True
    model = torch.nn.DataParallel(model).cuda()
    log.debug('Using gpu:', args.gpu)

    if args.pretrain is not None:
        log.info('Use pretrain model...')
        state_dict = torch.load(args.pretrain)
        model.load_state_dict(state_dict)
        # Checkpoints are named 'clf-<model>-<acc>-<epoch>.pth' (see below)
        # and hold the weights *after* that epoch, so resume with the next
        # one instead of re-training it.
        init_epoch = int(args.pretrain[:-4].split('-')[-1]) + 1
        log.info('start epoch from', init_epoch)
    else:
        log.info('Training from scratch')
        init_epoch = 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        # Fail fast instead of a NameError on the scheduler line below.
        raise ValueError('Argument "optimizer" must be "SGD" or "Adam", '
                         'got %r' % (args.optimizer,))

    # Halve the learning rate every 20 epochs, never letting it drop below
    # LEARNING_RATE_CLIP.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=20,
                                                gamma=0.5)
    LEARNING_RATE_CLIP = 1e-5

    best_tst_accuracy = 0.0

    log.info('Start training...')
    for epoch in range(init_epoch, args.epoch):
        # Clip the scheduled learning rate from below for this epoch.
        lr = max(optimizer.param_groups[0]['lr'], LEARNING_RATE_CLIP)

        log.debug(job='clf',
                  model=args.model_name,
                  gpu=args.gpu,
                  epoch='%d/%s' % (epoch, args.epoch),
                  lr=lr)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()
            model = model.train()
            pred, trans_feat = model(points)
            loss = F.nll_loss(pred, target.long())
            if args.feature_transform and args.model_name == 'pointnet':
                loss += feature_transform_reguliarzer(trans_feat) * 0.001
            loss.backward()
            optimizer.step()

        # Step the LR schedule after this epoch's optimizer.step() calls;
        # stepping it first is the pre-1.1 order and shifts the schedule.
        scheduler.step()

        log.debug('clear cuda cache')
        torch.cuda.empty_cache()

        # Evaluate with BatchNorm/Dropout in inference mode.
        model = model.eval()
        acc = test_clf(model, testDataLoader)
        log.info(loss='%.5f' % loss.item())
        log.info(Test_Accuracy='%.5f' % acc)

        if acc >= best_tst_accuracy:
            best_tst_accuracy = acc
            fn_pth = 'clf-%s-%.5f-%04d.pth' % (args.model_name, acc, epoch)
            log.debug('Saving model....', fn_pth)
            torch.save(model.state_dict(),
                       os.path.join(checkpoints_dir, fn_pth))

    log.info(Best_Accuracy=best_tst_accuracy)
    log.info('End of training...')
Example #5
0
def main():
    """Train a radar point-cloud classifier on nuScenes radar data.

    Configuration comes from the module-level ``args`` namespace: ``cuda``,
    ``gpudevice``, ``exp_name``, ``nuscenes_dir``, ``nuscenes_train_dir``,
    ``nuscenes_test_dir``, ``categories``, ``sensors``, ``batchsize``,
    ``num_workers``, ``model_name`` ('pointnet' or 'pointnet2'),
    ``pointCoordDim``, ``feature_transform``, ``focalLoss_gamma``,
    ``weight_cat``, ``optimizer`` ('SGD' or 'ADAM'), ``lr``, ``decay_rate``,
    ``lr_epoch_half``, ``epoch``, ``test_every_X_epochs`` and
    ``save_every_X_epochs``.

    The network is trained with a focal loss. It is evaluated every
    ``test_every_X_epochs`` epochs and checkpointed to
    ``<projdir>/experiment/checkpoints/<exp_name>.pth`` every
    ``save_every_X_epochs`` epochs. An existing checkpoint is resumed.

    Raises:
        Exception: if ``args.model_name`` is not a known network.
        ValueError: if ``args.optimizer`` is neither 'SGD' nor 'ADAM'.
    """
    # --- SELECT DEVICES ---
    device = torch.device("cuda" if args.cuda else "cpu")
    # CUDA reads CUDA_VISIBLE_DEVICES when it initialises, so this has to
    # happen before the first CUDA call.
    if args.cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
            str(x) for x in args.gpudevice)

    # --- CREATE EXPERIMENTS DIRECTORY AND LOGGERS IN TENSORBOARD ---
    projdir = sys.path[0]
    # Path components are passed separately: a single
    # 'experiment\\checkpoints' string is only a nested path on Windows.
    saveloadpath = os.path.join(projdir, 'experiment', 'checkpoints',
                                args.exp_name + '.pth')
    Path(os.path.dirname(saveloadpath)).mkdir(exist_ok=True, parents=True)
    tblogdir = os.path.join(projdir, 'experiment', 'tensorboardX',
                            args.exp_name)
    Path(tblogdir).mkdir(exist_ok=True, parents=True)
    # flush_secs: how many seconds pending events may wait before being
    # written to disk.
    tb_writer = SummaryWriter(logdir=tblogdir, flush_secs=3,
                              write_to_disk=True)
    try:
        # --- INIT DATASETS AND DATALOADER (FOR SINGLE EPOCH) ---
        # Batches come out as (B:batch, S:seq, C:features, N:points).
        train_transformations = transforms.Compose([
            ToSeries(),
            DataAugmentation(),
            Resampling(maxPointsPerFrame=10),
            ToTensor()
        ])
        # The test split must not be augmented: same pipeline minus
        # DataAugmentation.
        test_transformations = transforms.Compose([
            ToSeries(),
            Resampling(maxPointsPerFrame=10),
            ToTensor()
        ])
        nusc_train = NuScenes(version=args.nuscenes_train_dir,
                              dataroot=args.nuscenes_dir, verbose=True)
        train_dataset = RadarClassDataset(nusc_train,
                                          categories=args.categories,
                                          sensors=args.sensors,
                                          transforms=train_transformations,
                                          sequence_length=1)
        nusc_test = NuScenes(version=args.nuscenes_test_dir,
                             dataroot=args.nuscenes_dir, verbose=True)
        test_dataset = RadarClassDataset(nusc_test,
                                         categories=args.categories,
                                         sensors=args.sensors,
                                         transforms=test_transformations,
                                         sequence_length=1)
        trainDataLoader = DataLoader(train_dataset,
                                     batch_size=args.batchsize,
                                     shuffle=True,
                                     num_workers=args.num_workers)

        # --- INIT NETWORK MODEL ---
        if args.model_name == 'pointnet':
            classifier = PointNetCls(dim=args.pointCoordDim,
                                     num_class=len(args.categories),
                                     feature_transform=args.feature_transform)
        elif args.model_name == 'pointnet2':
            classifier = PointNet2ClsMsg(dim=args.pointCoordDim,
                                         num_class=len(args.categories))
        else:
            raise Exception(
                'Argument "model_name" does not match existent networks')
        classifier = classifier.to(device)

        # --- INIT LOSS FUNCTION ---
        loss_fun = FocalLoss(gamma=args.focalLoss_gamma,
                             num_classes=len(args.categories),
                             alpha=args.weight_cat).to(device)

        # --- LOAD NETWORK IF EXISTS ---
        if os.path.exists(saveloadpath):
            print('Using pretrained model found...')
            # map_location lets a checkpoint written on GPU be resumed on CPU.
            checkpoint = torch.load(saveloadpath, map_location=device)
            # Continue with the epoch after the one stored in the checkpoint.
            start_epoch = checkpoint['epoch'] + 1
            iteration = checkpoint['iteration']
            best_test_acc = checkpoint['test_accuracy']
            classifier.load_state_dict(checkpoint['model_state_dict'])
        else:
            print('No existing model, starting training from scratch...')
            # Epochs and iterations are counted from 1.
            start_epoch = 1
            iteration = 1
            best_test_acc = 0

        # --- CREATE OPTIMIZER ---
        # NOTE(review): the optimizer state stored in the checkpoint is not
        # restored on resume.
        if args.optimizer == 'SGD':
            optimizer = torch.optim.SGD(
                classifier.parameters(),
                lr=args.lr,
                momentum=0.9)
        elif args.optimizer == 'ADAM':
            optimizer = torch.optim.Adam(
                classifier.parameters(),
                lr=args.lr,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=args.decay_rate)
        else:
            # Fail fast instead of a NameError on the scheduler line below.
            raise ValueError(
                'Argument "optimizer" must be "SGD" or "ADAM", got {!r}'.format(
                    args.optimizer))
        # Halve (gamma=0.5) the learning rate every args.lr_epoch_half epochs.
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.lr_epoch_half, gamma=0.5)

        # Log the run configuration and dataset sizes.
        printparams = 'Model parameters:' + json.dumps(
            vars(args), indent=4, sort_keys=True)
        print(printparams)
        tb_writer.add_text('hyper-parameters', printparams, iteration)
        tb_writer.add_text(
            'dataset', 'dataset sample size: training: {}, test: {}'.format(
                len(train_dataset), len(test_dataset)), iteration)

        # --- START TRAINING ---
        for epoch in range(start_epoch, args.epoch + 1):
            print('Epoch %d/%s:' % (epoch, args.epoch))
            # Log the learning rate used for this epoch.
            tb_writer.add_scalar('learning_rate',
                                 optimizer.param_groups[0]['lr'], iteration)

            for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                       total=len(trainDataLoader),
                                       smoothing=0.9):
                # (B:batch x S:seq x C:features x N:points), (B x S:seq)
                points, target = data
                # Drop the sequence dimension (sequence_length == 1), cast to
                # float and move to the device: (B x C x N), (B).
                points = points.squeeze(dim=1).float().to(device)
                target = target.float().to(device)
                optimizer.zero_grad()
                classifier = classifier.train()
                pred = classifier(points)
                # Focal loss instead of plain nll_loss(log_softmax).
                loss = loss_fun(pred, target.long())
                if args.model_name == 'pointnet':
                    # PointNet transform regularisation terms.
                    loss += feature_transform_regularizer(
                        classifier.trans) * 0.001
                    if args.feature_transform:
                        loss += feature_transform_regularizer(
                            classifier.trans_feat) * 0.001
                loss.backward()
                optimizer.step()
                # Log once every 5 batches. The tag name is kept for continuity
                # even though the loss is the focal loss.
                if not batch_id % 5:
                    tb_writer.add_scalar('train_loss/cross_entropy',
                                         loss.item(), iteration)
                iteration += 1

                # Print a confusion matrix for the current training batch
                # every 20 iterations. detach(): the metric does not need the
                # autograd graph.
                if not iteration % 20:
                    confmatrix_train = metrics_confusion_matrix(
                        target, pred.detach())
                    print('\nTrain confusion matrix: \n', confmatrix_train)

            # --- TEST NETWORK (every args.test_every_X_epochs epochs) ---
            if not epoch % int(args.test_every_X_epochs):
                train_targ, train_pred = test(classifier, train_dataset, device,
                                              num_workers=0, batch_size=512)
                test_targ, test_pred = test(classifier, test_dataset, device,
                                            num_workers=0, batch_size=512)

                train_acc = metrics_accuracy(train_targ, train_pred)
                test_acc = metrics_accuracy(test_targ, test_pred)
                print('\r Training loss: {}'.format(loss.item()))
                print('Train Accuracy: {}\nTest Accuracy: {}'.format(
                    train_acc, test_acc))
                tb_writer.add_scalars('metrics/accuracy',
                                      {'train': train_acc, 'test': test_acc},
                                      iteration)

                # Test confusion matrix, logged as absolute and normalized
                # figures.
                confmatrix_test = metrics_confusion_matrix(test_targ, test_pred)
                print('Test confusion matrix: \n', confmatrix_test)
                fig, _ = plot_confusion_matrix(
                    confmatrix_test, args.categories, normalize=False,
                    title='Test Confusion Matrix')
                fig_n, _ = plot_confusion_matrix(
                    confmatrix_test, args.categories, normalize=True,
                    title='Test Confusion Matrix - Normalized')
                tb_writer.add_figure('test_confusion_matrix/abs', fig,
                                     global_step=iteration, close=True)
                tb_writer.add_figure('test_confusion_matrix/norm', fig_n,
                                     global_step=iteration, close=True)

                # One-vs-rest precision/recall curve per class; exp() turns
                # log_softmax outputs back into probabilities.
                for idx, clsname in enumerate(args.categories):
                    test_pred_binary = torch.exp(test_pred[:, idx])
                    test_targ_binary = test_targ.eq(idx)
                    tb_writer.add_pr_curve(tag='pr_curves/' + clsname,
                                           labels=test_targ_binary,
                                           predictions=test_pred_binary,
                                           global_step=iteration)

                # Track the best test accuracy seen so far.
                best_test_acc = max(best_test_acc, test_acc)

            # --- SAVE NETWORK (every args.save_every_X_epochs epochs) ---
            # Saved on a fixed cadence, not only when accuracy improves.
            if not epoch % int(args.save_every_X_epochs):
                print('Best Accuracy: %f' % best_test_acc)
                state = {
                    'epoch': epoch,
                    'iteration': iteration,
                    'test_accuracy': best_test_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, saveloadpath)
                print('Model saved!!!')

            # Step the LR schedule once per epoch, after the optimizer updates.
            scheduler.step()
    finally:
        # Flush and close the event file even if training raises.
        tb_writer.close()