def main():
    """Classify DBSCAN clusters of nuScenes radar points and plot the result.

    Configuration is read from the module-level ``args`` object. For every
    sample of the selected nuScenes split (visited in random order) the
    function:

    1. merges the point clouds of all sensors in ``args.sensors``;
    2. keeps the annotations whose category/attribute pair matches one of the
       pairs in ``args.categories`` / ``args.attributes`` and builds a labelled
       box for each of them;
    3. clusters the moving radar detections with DBSCAN and classifies every
       cluster with the loaded network;
    4. renders ground truth, predictions and the six camera images in one
       matplotlib figure and shows it (``plt.show()``).

    Raises:
        Exception: If ``args.model_name`` is not a known network, or if no
            checkpoint exists at the expected path.
    """
    # The merged point cloud is published as a module-level global, exactly
    # as the original code did; the helper closures below read it from there.
    global radar_pc

    # Local import: DBSCAN is only needed by this entry point.
    from sklearn.cluster import DBSCAN

    # --- SELECT DEVICES ---
    # Select either gpu or cpu
    device = torch.device("cuda" if args.cuda else "cpu")
    # Select among available GPUs
    if args.cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
            str(x) for x in args.gpudevice)

    # --- INIT NETWORK MODEL ---
    projdir = sys.path[0]
    # Path for loading the network.
    # NOTE(review): the backslash separator is Windows-specific; on POSIX this
    # yields a directory literally named "experiment\checkpoints". Left as is
    # so that checkpoints written by the training scripts are still found.
    saveloadpath = os.path.join(projdir, 'experiment\\checkpoints',
                                args.exp_name + '.pth')
    Path(os.path.dirname(saveloadpath)).mkdir(exist_ok=True, parents=True)
    # Build the selected network model and move it to the right device
    if args.model_name == 'pointnet':
        classifier = PointNetCls(dim=args.pointCoordDim,
                                 num_class=len(args.categories),
                                 feature_transform=args.feature_transform)
    elif args.model_name == 'pointnet2':
        classifier = PointNet2ClsMsg(dim=args.pointCoordDim,
                                     num_class=len(args.categories))
    else:
        raise Exception(
            'Argument "model_name" does not match existent networks')
    classifier = classifier.to(device)

    # --- LOAD NETWORK (MANDATORY FOR INFERENCE) ---
    if not os.path.exists(saveloadpath):
        raise Exception('Model in: {} does not exists'.format(saveloadpath))
    print('Using pretrained model found...')
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (and vice versa).
    checkpoint = torch.load(saveloadpath, map_location=device)
    classifier.load_state_dict(checkpoint['model_state_dict'])
    # put classifier in evaluation mode
    classifier = classifier.eval()

    # --- INIT TRANSFORMATIONS AND DATASET ---
    # NOTE(review): DataAugmentation() is applied at inference time here; if
    # it is a random perturbation the predictions are not deterministic -
    # confirm this is intended.
    dataTransformations = transforms.Compose([
        ToSeries(),
        DataAugmentation(),
        Resampling(maxPointsPerFrame=10),
        ToTensor()
    ])
    nusc = NuScenes(version=args.nuscenes_eval_dir,
                    dataroot=args.nuscenes_dir,
                    verbose=True)

    def get_ann_properties(nusc, ann_token: str):
        """Print and return (category, first attribute) of an annotation."""
        ann_rec = nusc.get('sample_annotation', ann_token)
        category = ann_rec['category_name']
        attr = [
            nusc.get('attribute', attr_token)['name']
            for attr_token in ann_rec['attribute_tokens']
        ]
        if not attr:
            attr = ['']
        print('object category:{:30} attr:{:40} center:{:20}'.format(
            category, str(attr),
            str(ann_rec['translation'] - radar_pc.A_cs_2_gl[:3, 3])))
        # for now only the first attribute is used (an object may have more)
        return category, attr[0]

    def get_obj_label(ann_token):
        """Return the index of the first matching category/attribute pair.

        Returns ``None`` when the annotation matches none of the desired
        pairs, i.e. the object is not of interest for the classifier.
        """
        cat, att = get_ann_properties(nusc, ann_token)
        for label_idx, (desired_cat, desired_att) in enumerate(
                zip(args.categories, args.attributes)):
            if cat in desired_cat and att in desired_att:
                return label_idx
        return None

    # --- ITERATE OVER SAMPLES (in random order) ---
    sample_order = list(range(len(nusc.sample)))
    random.shuffle(sample_order)
    for i in sample_order:
        sample_rec = nusc.sample[i]
        sample_token = sample_rec['token']

        # read all sensors data and merge them to a common reference frame
        sensorlist = [
            MyRadarPointCloud.from_sample_token(nusc, sample_token, sensor)
            for sensor in args.sensors
        ]
        radar_pc = MyRadarPointCloud.merge(sensorlist, 0)

        # Keep only the annotations of interest and label them in one pass,
        # so every annotation is looked up (and printed) exactly once.
        print('\n\nGround-truth objects on this frame:')
        labelled_anns = []
        for ann_token in nusc.get('sample', sample_token)['anns']:
            label = get_obj_label(ann_token)
            if label is not None:
                labelled_anns.append((ann_token, label))

        # create bounding boxes from annotations
        ann_boxes = []
        for ann_token, label in labelled_anns:
            box = radar_pc.box(ann_token)
            box.label = label
            box.name = args.objectNames[label]
            ann_boxes.append(box)

        # --- APPLY DBSCAN ON EACH SCENE ---
        points_scene = radar_pc.points  # presumably <5, N> - TODO confirm
        # Keep only detections whose norm over rows 2:4 (presumably the
        # velocity components - TODO confirm) exceeds 0.25.
        idx_moving = np.linalg.norm(points_scene[2:4, :].T, axis=1) > 0.25
        points_scene = points_scene[:, idx_moving]
        if points_scene.shape[1] > 0:
            cluster_labels = DBSCAN(eps=3, min_samples=1).fit(
                points_scene[:2, :].T).labels_
        else:
            # DBSCAN rejects empty input; a scene without moving points
            # simply has no clusters.
            cluster_labels = np.empty(0, dtype=int)
        # Labels run from 0 to max inclusive, hence the "+ 1" (the original
        # range(max(labels)) silently dropped the last cluster).
        num_clusters = int(cluster_labels.max()) + 1 if cluster_labels.size else 0

        # --- For each cluster, run network and predict class ---
        pred_results = []
        with torch.no_grad():  # inference only, no autograd graph needed
            for cluster_idx in range(num_clusters):
                # select points from the current cluster
                points_obj = points_scene[:, cluster_labels == cluster_idx]
                # apply necessary transformations
                features = dataTransformations(points_obj)
                # convert to a float tensor with a leading batch dimension
                features = torch.tensor(features).type(
                    torch.FloatTensor).unsqueeze(0).to(device)
                # calculate network prediction
                pred = classifier(features).argmax().item()
                # store result
                pred_results.append((points_obj, pred))

        # --- RENDER ---
        OBJCOLORS = ['red', 'green', 'blue', 'grey', 'orange', 'white']
        fig = plt.figure(constrained_layout=True)
        gs = GridSpec(2, 6, figure=fig)
        gs.update(wspace=0, hspace=0)
        # axs[0]/axs[1]: the two bird's-eye views, axs[2:]: camera images
        axs = [
            fig.add_subplot(gs[1, :3]),
            fig.add_subplot(gs[1, 3:]),
        ] + [fig.add_subplot(gs[0, col]) for col in range(6)]

        # render annotation boxes
        for box in ann_boxes:
            color = OBJCOLORS[box.label]
            box.render(axs[0], colors=(color, ) * 3)
            box.render(axs[1], colors=(color, ) * 3)
        # render cars
        radar_pc.car.render(axs[0], colors=('orange', 'k', 'k'))
        radar_pc.car.render(axs[1], colors=('orange', 'k', 'k'))
        # render pointcloud
        radar_pc._render_pc(axs[0], radar_pc.points, color_channel='k')
        for cluster_points, pred in pred_results:
            # Clusters predicted as class 5 are not drawn.
            # NOTE(review): presumably class 5 is a background/"other"
            # class - confirm against args.objectNames.
            if pred == 5:
                continue
            axs[1].scatter(cluster_points[0, :],
                           cluster_points[1, :],
                           s=30,
                           c=OBJCOLORS[pred],
                           edgecolors='k',
                           linewidths=1,
                           zorder=100)
            # plot ellipses (needs at least two points)
            if cluster_points.shape[1] > 1:
                confidence_ellipse(axs[1],
                                   cluster_points[0, :],
                                   cluster_points[1, :],
                                   edgecolor=OBJCOLORS[pred],
                                   facecolor=OBJCOLORS[pred],
                                   alpha=0.5)
        # render camera images
        cameras = [
            'CAM_BACK_LEFT', 'CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT',
            'CAM_BACK_RIGHT', 'CAM_BACK'
        ]
        for cam_ax, cam_name in zip(axs[2:], cameras):
            cam_token = sample_rec['data'][cam_name]
            cam_rec = nusc.get('sample_data', cam_token)
            img_filename = os.path.join(nusc.dataroot, cam_rec['filename'])
            img = mpimg.imread(img_filename)
            cam_ax.get_xaxis().set_visible(False)
            cam_ax.get_yaxis().set_visible(False)
            cam_ax.imshow(img)
            cam_ax.set_title(cam_name)
        # format
        legend_elements = [
            Line2D([0], [0], color=color, lw=4, label=name)
            for color, name in zip(OBJCOLORS, args.objectNames)
        ]
        for ax in axs[:2]:
            ax.legend(handles=legend_elements,
                      loc='upper right',
                      prop={'size': 8})
            ax.axis('equal')
        mng = plt.get_current_fig_manager()
        try:
            mng.window.showMaximized()
        except AttributeError:
            # showMaximized() is not provided by every matplotlib backend;
            # maximizing is cosmetic, so continue with the default size.
            pass
        plt.tight_layout()
        lims = (-60, 60)
        axs[0].set_xlim(lims)
        axs[0].set_ylim(lims)
        axs[0].set_title(
            'Radar detection points and ground truth bounding boxes of MOVING objects'
        )
        axs[1].set_xlim(lims)
        axs[1].set_ylim(lims)
        axs[1].set_title(
            'Segmented segmentation of detection points (DBSCAN + PointNet classifier)\n3-sigma ellipses for each cluster and ground-truth bbs'
        )
        plt.tight_layout()
        plt.show()

    # NOTE(review): tb_writer is never created in this function; unless it is
    # a module-level global this line raises NameError - confirm and remove.
    tb_writer.close()
def main(args):
    """Train a PointNet / PointNet++ classifier on the file-based radar dataset.

    The function creates the checkpoint and tensorboard directories, builds
    the train/test datasets, resumes from an existing checkpoint when one is
    found, and then trains for the remaining epochs. Every 10th epoch the
    network is evaluated on the train and test sets, the metrics are logged
    to tensorboard, and a checkpoint is written.

    Args:
        args: Namespace-like object holding the run configuration. The
            attributes read here are: ``cuda``, ``gpudevice``, ``exp_name``,
            ``datapath``, ``batchsize``, ``num_workers``, ``model_name``,
            ``pointCoordDim``, ``numclasses``, ``feature_transform``,
            ``optimizer``, ``lr``, ``decay_rate``, ``lr_epoch_half``,
            ``epoch`` and ``train_metric``.

    Raises:
        Exception: If ``args.model_name`` or ``args.optimizer`` is not one of
            the supported values.
    """
    # --- SELECT DEVICES ---
    # Select either gpu or cpu
    device = torch.device("cuda" if args.cuda else "cpu")
    # Select among available GPUs
    if args.cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
            str(x) for x in args.gpudevice)

    # --- CREATE EXPERIMENTS DIRECTORY AND LOGGERS IN TENSORBOARD ---
    projdir = sys.path[0]
    # Path for saving and loading the network.
    # NOTE(review): the backslash separator is Windows-specific; on POSIX it
    # becomes part of the directory name.
    saveloadpath = os.path.join(projdir, 'experiment\\checkpoints',
                                args.exp_name + '.pth')
    Path(os.path.dirname(saveloadpath)).mkdir(exist_ok=True, parents=True)
    tblogdir = os.path.join(projdir, 'experiment\\tensorboardX',
                            args.exp_name)
    Path(tblogdir).mkdir(exist_ok=True, parents=True)
    # Writer used for all tensorboard logging below; flush_secs is the delay
    # (in seconds) between flushes of pending events to disk.
    tb_writer = SummaryWriter(logdir=tblogdir,
                              flush_secs=3,
                              write_to_disk=True)

    # --- INIT DATASETS AND DATALOADER ---
    # Ts=0.082 is the frame period passed to the reader (one radar frame
    # every 82 ms); 80% of the data is used for training.
    train_dataset, test_dataset, class_names = read_dataset(
        args.datapath, Ts=0.082, train_test_split=0.8)

    # Test dataset: no data augmentation, so the metrics are not perturbed.
    test_dataTransformations = transforms.Compose(
        [NormalizeTime(), Resampling(maxPointsPerFrame=10)])
    testDataset = RadarClassDataset(dataset=test_dataset,
                                    transforms=test_dataTransformations,
                                    sequence_length=1)
    # Train dataset: same pipeline plus data augmentation.
    train_dataTransformations = transforms.Compose([
        NormalizeTime(),
        DataAugmentation(),
        Resampling(maxPointsPerFrame=10)
    ])
    trainDataset = RadarClassDataset(dataset=train_dataset,
                                     transforms=train_dataTransformations,
                                     sequence_length=1)
    # Training data loader, args.batchsize items per batch
    trainDataLoader = DataLoader(trainDataset,
                                 batch_size=args.batchsize,
                                 shuffle=True,
                                 num_workers=args.num_workers)

    # --- INIT NETWORK MODEL ---
    # Build the selected network model and move it to the right device
    if args.model_name == 'pointnet':
        classifier = PointNetCls(dim=args.pointCoordDim,
                                 num_class=args.numclasses,
                                 feature_transform=args.feature_transform)
    elif args.model_name == 'pointnet2':
        classifier = PointNet2ClsMsg(
            dim=args.pointCoordDim,
            num_class=args.numclasses,
        )
    else:
        raise Exception(
            'Argument "model_name" does not match existent networks')
    classifier = classifier.to(device)

    # --- LOAD NETWORK IF EXISTS ---
    if os.path.exists(saveloadpath):
        print('Using pretrained model found...')
        # map_location lets a checkpoint saved on a GPU be resumed on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(saveloadpath, map_location=device)
        # Resume with the epoch following the last saved one.
        start_epoch = checkpoint['epoch'] + 1
        iteration = checkpoint['iteration']
        best_test_acc = checkpoint['test_accuracy']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        # NOTE(review): the checkpoint also stores 'optimizer_state_dict',
        # but it is not restored, so optimizer state and the learning-rate
        # schedule restart from args.lr on resume - confirm this is intended.
    else:
        print('No existing model, starting training from scratch...')
        # Epochs and iterations are counted from 1.
        start_epoch = 1
        iteration = 1
        best_test_acc = 0

    # --- CREATE OPTIMIZER ---
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=args.lr,
                                    momentum=0.9)
    elif args.optimizer == 'ADAM':
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=args.decay_rate)
    else:
        # Fail with a clear message instead of a NameError on `optimizer`.
        raise Exception(
            'Argument "optimizer" does not match supported optimizers (SGD, ADAM)')
    # Halve (gamma=0.5) the learning rate every args.lr_epoch_half epochs
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_epoch_half,
        gamma=0.5)

    # log info
    printparams = 'Model parameters:' + json.dumps(
        vars(args), indent=4, sort_keys=True)
    print(printparams)
    tb_writer.add_text('hyper-parameters', printparams, iteration)
    tb_writer.add_text(
        'dataset', 'dataset sample size: training: {}, test: {}'.format(
            train_dataset.shape[0], test_dataset.shape[0]), iteration)

    # --- START TRAINING ---
    for epoch in range(start_epoch, args.epoch + 1):
        print('Epoch %d/%s:' % (epoch, args.epoch))

        # Log the current learning rate
        tb_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'],
                             iteration)

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data  # (B x S:seq x C:features x N:points) , (B x S:seq)
            # Drop the sequence dimension (equal to 1), cast to float and
            # move to the device -> (B x C x N) , (B)
            points = points.squeeze(dim=1).float().to(device)
            target = target.float().to(device)
            # Reset gradients
            optimizer.zero_grad()
            # Back to training mode (test() below may have changed it)
            classifier = classifier.train()
            # Forward propagation
            pred = classifier(points)
            # The network outputs log_softmax, and
            # "log_softmax -> nll_loss" == CrossEntropyLoss.
            loss = F.nll_loss(pred, target.long())
            if args.model_name == 'pointnet':
                loss += feature_transform_regularizer(classifier.trans) * 0.001
                if args.feature_transform:
                    loss += feature_transform_regularizer(
                        classifier.trans_feat) * 0.001
            # Back propagate
            loss.backward()
            # Update weights
            optimizer.step()
            # Log the training loss once every 5 batches
            if not batch_id % 5:
                tb_writer.add_scalar('train_loss/cross_entropy', loss.item(),
                                     iteration)
            iteration += 1

        scheduler.step()

        # --- TEST AND SAVE NETWORK ---
        if not epoch % 10:  # only every 10th epoch
            # Perform predictions on the training data.
            train_targ, train_pred = test(classifier,
                                          trainDataset,
                                          device,
                                          num_workers=args.num_workers,
                                          batch_size=1800)
            # Perform predictions on the testing data.
            test_targ, test_pred = test(classifier,
                                        testDataset,
                                        device,
                                        num_workers=args.num_workers,
                                        batch_size=1800)

            train_acc = metrics_accuracy(train_targ, train_pred)
            test_acc = metrics_accuracy(test_targ, test_pred)
            print('\r Training loss: {}'.format(loss.item()))
            print('Train Accuracy: {}\nTest Accuracy: {}'.format(
                train_acc, test_acc))
            tb_writer.add_scalars('metrics/accuracy', {
                'train': train_acc,
                'test': test_acc
            }, iteration)

            # Confusion matrix on the test set: printed, then logged to
            # tensorboard both as absolute counts and normalized.
            confmatrix_test = metrics_confusion_matrix(test_targ, test_pred)
            print('Test confusion matrix: \n', confmatrix_test)
            fig, ax = plot_confusion_matrix(confmatrix_test,
                                            class_names,
                                            normalize=False,
                                            title='Test Confusion Matrix')
            fig_n, ax_n = plot_confusion_matrix(
                confmatrix_test,
                class_names,
                normalize=True,
                title='Test Confusion Matrix - Normalized')
            tb_writer.add_figure('test_confusion_matrix/abs',
                                 fig,
                                 global_step=iteration,
                                 close=True)
            tb_writer.add_figure('test_confusion_matrix/norm',
                                 fig_n,
                                 global_step=iteration,
                                 close=True)

            # Precision-recall curve per class (one-vs-rest).
            for idx, clsname in enumerate(class_names):
                # exp(log_softmax) gives the class probability.
                test_pred_binary = torch.exp(test_pred[:, idx])
                test_targ_binary = test_targ.eq(idx)
                tb_writer.add_pr_curve(tag='pr_curves/' + clsname,
                                       labels=test_targ_binary,
                                       predictions=test_pred_binary,
                                       global_step=iteration)

            # --- SAVE NETWORK ---
            # The checkpoint is written on every evaluation, so
            # "best_test_acc" is in fact the most recent test accuracy.
            best_test_acc = test_acc
            state = {
                'epoch': epoch,
                'iteration': iteration,
                'train_accuracy': train_acc if args.train_metric else 0.0,
                'test_accuracy': best_test_acc,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, saveloadpath)
            print('Model saved!!!')

    print('Best Accuracy: %f' % best_test_acc)

    tb_writer.close()
# Example #3
# 0
def main():
    """Train a PointNet / PointNet++ classifier on nuScenes radar data.

    Configuration is read from the module-level ``args`` object. The function
    creates the checkpoint and tensorboard directories, builds the nuScenes
    train/test datasets, resumes from an existing checkpoint when one is
    found and trains with a focal loss. Every ``args.test_every_X_epochs``
    epochs the network is evaluated and the metrics are logged to
    tensorboard; every ``args.save_every_X_epochs`` epochs a checkpoint is
    written.

    Raises:
        Exception: If ``args.model_name`` or ``args.optimizer`` is not one of
            the supported values.
    """
    # --- SELECT DEVICES ---
    # Select either gpu or cpu
    device = torch.device("cuda" if args.cuda else "cpu")
    # Select among available GPUs
    if args.cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
            str(x) for x in args.gpudevice)

    # --- CREATE EXPERIMENTS DIRECTORY AND LOGGERS IN TENSORBOARD ---
    projdir = sys.path[0]
    # Path for saving and loading the network.
    # NOTE(review): the backslash separator is Windows-specific; on POSIX it
    # becomes part of the directory name.
    saveloadpath = os.path.join(projdir, 'experiment\\checkpoints',
                                args.exp_name + '.pth')
    Path(os.path.dirname(saveloadpath)).mkdir(exist_ok=True, parents=True)
    tblogdir = os.path.join(projdir, 'experiment\\tensorboardX',
                            args.exp_name)
    Path(tblogdir).mkdir(exist_ok=True, parents=True)
    # Writer used for all tensorboard logging below; flush_secs is the delay
    # (in seconds) between flushes of pending events to disk.
    tb_writer = SummaryWriter(logdir=tblogdir,
                              flush_secs=3,
                              write_to_disk=True)

    # --- INIT DATASETS AND DATALOADER ---
    # The data loader returns (B:batch, S:seq, C:features, N:points)
    train_transformations = transforms.Compose([
        ToSeries(),
        DataAugmentation(),
        Resampling(maxPointsPerFrame=10),
        ToTensor()
    ])
    # The test set must not be augmented, otherwise the reported metrics are
    # perturbed by the augmentation.
    # NOTE(review): assumes DataAugmentation only perturbs the points and is
    # not required by Resampling/ToTensor - confirm.
    test_transformations = transforms.Compose([
        ToSeries(),
        Resampling(maxPointsPerFrame=10),
        ToTensor()
    ])
    # Init nuScenes datasets
    nusc_train = NuScenes(version=args.nuscenes_train_dir,
                          dataroot=args.nuscenes_dir,
                          verbose=True)
    train_dataset = RadarClassDataset(nusc_train,
                                      categories=args.categories,
                                      sensors=args.sensors,
                                      transforms=train_transformations,
                                      sequence_length=1)
    nusc_test = NuScenes(version=args.nuscenes_test_dir,
                         dataroot=args.nuscenes_dir,
                         verbose=True)
    test_dataset = RadarClassDataset(nusc_test,
                                     categories=args.categories,
                                     sensors=args.sensors,
                                     transforms=test_transformations,
                                     sequence_length=1)
    # Init training data loader
    trainDataLoader = DataLoader(train_dataset,
                                 batch_size=args.batchsize,
                                 shuffle=True,
                                 num_workers=args.num_workers)

    # --- INIT NETWORK MODEL ---
    # Build the selected network model and move it to the right device
    if args.model_name == 'pointnet':
        classifier = PointNetCls(dim=args.pointCoordDim,
                                 num_class=len(args.categories),
                                 feature_transform=args.feature_transform)
    elif args.model_name == 'pointnet2':
        classifier = PointNet2ClsMsg(dim=args.pointCoordDim,
                                     num_class=len(args.categories))
    else:
        raise Exception(
            'Argument "model_name" does not match existent networks')
    classifier = classifier.to(device)

    # --- INIT LOSS FUNCTION ---
    loss_fun = FocalLoss(gamma=args.focalLoss_gamma,
                         num_classes=len(args.categories),
                         alpha=args.weight_cat).to(device)

    # --- LOAD NETWORK IF EXISTS ---
    if os.path.exists(saveloadpath):
        print('Using pretrained model found...')
        # map_location lets a checkpoint saved on a GPU be resumed on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(saveloadpath, map_location=device)
        # Resume with the epoch following the last saved one.
        start_epoch = checkpoint['epoch'] + 1
        iteration = checkpoint['iteration']
        best_test_acc = checkpoint['test_accuracy']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        # NOTE(review): the checkpoint also stores 'optimizer_state_dict',
        # but it is not restored, so optimizer state and the learning-rate
        # schedule restart from args.lr on resume - confirm this is intended.
    else:
        print('No existing model, starting training from scratch...')
        # Epochs and iterations are counted from 1.
        start_epoch = 1
        iteration = 1
        best_test_acc = 0

    # --- CREATE OPTIMIZER ---
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(
            classifier.parameters(),
            lr=args.lr,
            momentum=0.9)
    elif args.optimizer == 'ADAM':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate)
    else:
        # Fail with a clear message instead of a NameError on `optimizer`.
        raise Exception(
            'Argument "optimizer" does not match supported optimizers (SGD, ADAM)')
    # Halve (gamma=0.5) the learning rate every args.lr_epoch_half epochs
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_epoch_half, gamma=0.5)

    # Log info
    printparams = 'Model parameters:' + json.dumps(
        vars(args), indent=4, sort_keys=True)
    print(printparams)
    tb_writer.add_text('hyper-parameters', printparams, iteration)
    tb_writer.add_text(
        'dataset', 'dataset sample size: training: {}, test: {}'.format(
            len(train_dataset), len(test_dataset)), iteration)

    # --- START TRAINING ---
    for epoch in range(start_epoch, args.epoch + 1):
        print('Epoch %d/%s:' % (epoch, args.epoch))
        # Log the current learning rate
        tb_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'],
                             iteration)

        for batch_id, data in tqdm(enumerate(trainDataLoader, 0),
                                   total=len(trainDataLoader),
                                   smoothing=0.9):
            points, target = data  # (B x S:seq x C:features x N:points) , (B x S:seq)
            # Drop the sequence dimension (equal to 1), cast to float and
            # move to the device -> (B x C x N) , (B)
            points = points.squeeze(dim=1).float().to(device)
            target = target.float().to(device)
            # Reset gradients
            optimizer.zero_grad()
            # Back to training mode (test() below may have changed it)
            classifier = classifier.train()
            # Forward propagation
            pred = classifier(points)
            # Focal loss on the network output
            loss = loss_fun(pred, target.long())
            if args.model_name == 'pointnet':
                loss += feature_transform_regularizer(classifier.trans) * 0.001
                if args.feature_transform:
                    loss += feature_transform_regularizer(
                        classifier.trans_feat) * 0.001
            # Back propagate
            loss.backward()
            # Update weights
            optimizer.step()
            # Log the training loss once every 5 batches
            if not batch_id % 5:
                tb_writer.add_scalar('train_loss/cross_entropy', loss.item(),
                                     iteration)
            iteration += 1

            # Print the confusion matrix of the current batch every 20
            # iterations
            if not iteration % 20:
                confmatrix_train = metrics_confusion_matrix(target, pred)
                print('\nTrain confusion matrix: \n', confmatrix_train)

        # --- TEST NETWORK ---
        if not epoch % int(args.test_every_X_epochs):  # only every X epochs
            # Perform predictions on the training data.
            train_targ, train_pred = test(classifier, train_dataset, device,
                                          num_workers=0, batch_size=512)
            # Perform predictions on the testing data.
            test_targ, test_pred = test(classifier, test_dataset, device,
                                        num_workers=0, batch_size=512)

            train_acc = metrics_accuracy(train_targ, train_pred)
            test_acc = metrics_accuracy(test_targ, test_pred)
            print('\r Training loss: {}'.format(loss.item()))
            print('Train Accuracy: {}\nTest Accuracy: {}'.format(
                train_acc, test_acc))
            tb_writer.add_scalars('metrics/accuracy',
                                  {'train': train_acc, 'test': test_acc},
                                  iteration)

            # Confusion matrix on the test set: printed, then logged to
            # tensorboard both as absolute counts and normalized.
            confmatrix_test = metrics_confusion_matrix(test_targ, test_pred)
            print('Test confusion matrix: \n', confmatrix_test)
            fig, ax = plot_confusion_matrix(confmatrix_test,
                                            args.categories,
                                            normalize=False,
                                            title='Test Confusion Matrix')
            fig_n, ax_n = plot_confusion_matrix(
                confmatrix_test,
                args.categories,
                normalize=True,
                title='Test Confusion Matrix - Normalized')
            tb_writer.add_figure('test_confusion_matrix/abs', fig,
                                 global_step=iteration, close=True)
            tb_writer.add_figure('test_confusion_matrix/norm', fig_n,
                                 global_step=iteration, close=True)

            # Precision-recall curve per class (one-vs-rest).
            for idx, clsname in enumerate(args.categories):
                # exp(log_softmax) gives the class probability.
                test_pred_binary = torch.exp(test_pred[:, idx])
                test_targ_binary = test_targ.eq(idx)
                tb_writer.add_pr_curve(tag='pr_curves/' + clsname,
                                       labels=test_targ_binary,
                                       predictions=test_pred_binary,
                                       global_step=iteration)

            # Track the best test accuracy seen so far.
            # NOTE: saving is not tied to it; the model is saved every
            # args.save_every_X_epochs epochs below.
            best_test_acc = max(best_test_acc, test_acc)

        # --- SAVE NETWORK ---
        if not epoch % int(args.save_every_X_epochs):
            print('Best Accuracy: %f' % best_test_acc)
            state = {
                'epoch': epoch,
                'iteration': iteration,
                'test_accuracy': best_test_acc,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, saveloadpath)
            print('Model saved!!!')

        scheduler.step()

    tb_writer.close()