Example #1
def iterate(mode, args, loader, model, optimizer, logger, epoch):
    actual_epoch = epoch - args.start_epoch + args.start_epoch_bias

    block_average_meter = AverageMeter()
    block_average_meter.reset(False)
    average_meter = AverageMeter()
    meters = [block_average_meter, average_meter]

    # switch to appropriate mode
    assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
        "unsupported mode: {}".format(mode)
    if mode == 'train':
        model.train()
        lr = helper.adjust_learning_rate(args.lr, optimizer, actual_epoch,
                                         args)
    else:
        model.eval()
        lr = 0

    torch.cuda.empty_cache()
    for i, batch_data in enumerate(loader):
        dstart = time.time()
        batch_data = {
            key: val.to(device)
            for key, val in batch_data.items() if val is not None
        }

        gt = batch_data['gt'] if mode not in ('test_prediction', 'test_completion') else None
        data_time = time.time() - dstart

        # time the forward pass
        start = time.time()
        if args.network_model == 'e':
            st1_pred, st2_pred, pred = model(batch_data)
        else:
            pred = model(batch_data)

        if args.evaluate:
            gpu_time = time.time() - start

        depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None

        # intermediate supervision weights, decayed over the first training rounds
        st1_loss, st2_loss, loss = 0, 0, 0
        round1, round2 = 1, 3
        if actual_epoch <= round1:
            w_st1, w_st2 = 0.2, 0.2
        elif actual_epoch <= round2:
            w_st1, w_st2 = 0.05, 0.05
        else:
            w_st1, w_st2 = 0, 0

        if mode == 'train':
            # Loss 1: direct depth supervision from the ground-truth label;
            # mask=1 indicates that a pixel does not have ground-truth labels
            depth_loss = depth_criterion(pred, gt)

            if args.network_model == 'e':
                st1_loss = depth_criterion(st1_pred, gt)
                st2_loss = depth_criterion(st2_pred, gt)
                loss = (1 - w_st1 - w_st2) * depth_loss \
                    + w_st1 * st1_loss + w_st2 * st2_loss
            else:
                loss = depth_loss

            if i % multi_batch_size == 0:
                optimizer.zero_grad()
            loss.backward()

            if i % multi_batch_size == (multi_batch_size - 1) or i == (len(loader) - 1):
                optimizer.step()
            print("loss:", loss.item(), " epoch:", epoch, " ", i, "/", len(loader))

        if mode == "test_completion":
            str_i = str(i)
            path_i = str_i.zfill(10) + '.png'
            path = os.path.join(args.data_folder_save, path_i)
            vis_utils.save_depth_as_uint16png_upload(pred, path)

        if not args.evaluate:
            # in train mode, gpu_time also covers the loss and backward pass
            gpu_time = time.time() - start
        # measure accuracy and record loss
        with torch.no_grad():
            mini_batch_size = next(iter(batch_data.values())).size(0)
            result = Result()
            if mode not in ('test_prediction', 'test_completion'):
                result.evaluate(pred.data, gt.data, photometric_loss)
                for m in meters:
                    m.update(result, gpu_time, data_time, mini_batch_size)

                if mode != 'train':
                    logger.conditional_print(mode, i, epoch, lr, len(loader),
                                             block_average_meter,
                                             average_meter)
                logger.conditional_save_img_comparison(mode, i, batch_data,
                                                       pred, epoch)
                logger.conditional_save_pred(mode, i, pred, epoch)

    avg = logger.conditional_save_info(mode, average_meter, epoch)
    is_best = logger.rank_conditional_save_best(mode, avg, epoch)
    if is_best and mode != "train":
        logger.save_img_comparison_as_best(mode, epoch)
    logger.conditional_summarize(mode, avg, is_best)

    return avg, is_best
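
Note on the pattern above: the i % multi_batch_size checks implement gradient accumulation, stepping the optimizer only once per multi_batch_size mini-batches. A minimal, self-contained sketch of that pattern (the model, loader, and criterion arguments here are illustrative assumptions, not objects from the example above):

import torch

def train_with_accumulation(model, loader, optimizer, criterion, multi_batch_size=4):
    # Gradients from multi_batch_size consecutive mini-batches are summed
    # before a single optimizer step, emulating a larger effective batch.
    for i, (x, y) in enumerate(loader):
        if i % multi_batch_size == 0:
            optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()  # .backward() accumulates into the .grad buffers
        if i % multi_batch_size == multi_batch_size - 1 or i == len(loader) - 1:
            optimizer.step()  # the second condition flushes a trailing partial group
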
Example #2
                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i)
                writer.add_scalar('psnr/val_psnr', val_psnr_meter.avg, i)

                format_str = '===> Iter [{:d}/{:d}] Val_Loss: {:.6f}, Val_PSNR: {:.4f}'
                print(
                    format_str.format(i, config['training']['iterations'],
                                      val_loss_meter.avg, val_psnr_meter.avg))
                sys.stdout.flush()

                if val_psnr_meter.avg >= best_val_psnr:
                    best_val_psnr = val_psnr_meter.avg
                    ckpt = {
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'best_val_psnr': best_val_psnr,
                        'iter': i
                    }
                    cfg_name = os.path.splitext(os.path.basename(args.config))[0]
                    path = '{}/{}/{}_{}.pth'.format(
                        config['training']['checkpoint_folder'], cfg_name, cfg_name, i)
                    torch.save(ckpt, path)

                val_loss_meter.reset()
                val_psnr_meter.reset()

            if i >= config['training']['iterations']:
                still_training = False
                break
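
Note: the checkpoint saved above bundles model, optimizer, and scheduler state along with the best PSNR and iteration counter. A minimal sketch of the matching restore (the function name and map_location choice are assumptions for illustration):

import torch

def load_checkpoint(path, model, optimizer, scheduler, device='cpu'):
    # Restore everything the loop above saved, returning the bookkeeping
    # values needed to resume training where it left off.
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    return ckpt['best_val_psnr'], ckpt['iter']
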
Example #3
def train(cfg, writer, logger):
    init_random()

    device = torch.device("cuda:{}".format(cfg['model']['default_gpu'])
                          if torch.cuda.is_available() else 'cpu')

    # create dataSet
    data_sets = create_dataset(cfg, writer, logger)  # source_train / target_train / source_valid / target_valid, each with a matching *_loader
    if cfg.get('valset') == 'gta5':
        val_loader = data_sets.source_valid_loader
    else:
        val_loader = data_sets.target_valid_loader
    logger.info('source train batchsize is {}'.format(data_sets.source_train_loader.args.get('batch_size')))
    print('source train batchsize is {}'.format(data_sets.source_train_loader.args.get('batch_size')))
    logger.info('target train batchsize is {}'.format(data_sets.target_train_loader.batch_size))
    print('target train batchsize is {}'.format(data_sets.target_train_loader.batch_size))
    logger.info('valset is {}'.format(cfg.get('valset')))
    print('valset is {}'.format(cfg.get('valset')))
    logger.info('val batch_size is {}'.format(val_loader.batch_size))
    print('val batch_size is {}'.format(val_loader.batch_size))

    # create model
    model = CustomModel(cfg, writer, logger)

    # LOSS function
    loss_fn = get_loss_function(cfg)

    # load category anchors
    objective_vectors = torch.load('category_anchors')
    model.objective_vectors = objective_vectors['objective_vectors']
    model.objective_vectors_num = objective_vectors['objective_num']

    # Setup Metrics
    running_metrics_val = RunningScore(cfg['data']['target']['n_class'])
    source_running_metrics_val = RunningScore(cfg['data']['source']['n_class'])
    val_loss_meter, source_val_loss_meter = AverageMeter(), AverageMeter()
    time_meter = AverageMeter()

    # begin training
    model.iter = 0
    epochs = cfg['training']['epochs']
    for epoch in tqdm(range(epochs)):
        if model.iter > cfg['training']['train_iters']:
            break

        for (target_image, target_label, target_img_name) in tqdm(data_sets.target_train_loader):
            start_ts = time.time()
            model.iter += 1
            if model.iter > cfg['training']['train_iters']:
                break

            ############################
            # train on source & target #
            ############################
            # get data
            images, labels, source_img_name = data_sets.source_train_loader.next()
            images, labels = images.to(device), labels.to(device)
            target_image, target_label = target_image.to(device), target_label.to(device)

            # init model
            model.train(logger=logger)
            if cfg['training'].get('freeze_bn'):
                model.freeze_bn_apply()
            model.optimizer_zero_grad()

            # train for one batch
            loss, loss_cls_L2, loss_pseudo = model.step(images, labels, target_image, target_label)
            model.scheduler_step()

            if loss_cls_L2 > 10:
                logger.warning('loss_cls_L2 abnormal!!')

            # print
            time_meter.update(time.time() - start_ts)
            if (model.iter + 1) % cfg['training']['print_interval'] == 0:
                unchanged_cls_num = 0
                fmt_str = "Epoches [{:d}/{:d}] Iter [{:d}/{:d}]  Loss: {:.4f} " \
                          "Loss_cls_L2: {:.4f}  Loss_pseudo: {:.4f}  Time/Image: {:.4f} "
                print_str = fmt_str.format(epoch + 1, epochs, model.iter + 1, cfg['training']['train_iters'],
                                           loss.item(), loss_cls_L2, loss_pseudo,
                                           time_meter.avg / cfg['data']['source']['batch_size'])

                print(print_str)
                logger.info(print_str)
                logger.info('unchanged number of objective class vector: {}'.format(unchanged_cls_num))
                writer.add_scalar('loss/train_loss', loss.item(), model.iter + 1)
                writer.add_scalar('loss/train_cls_L2Loss', loss_cls_L2, model.iter + 1)
                writer.add_scalar('loss/train_pseudoLoss', loss_pseudo, model.iter + 1)
                time_meter.reset()

                score_cl, _ = model.metrics.running_metrics_val_clusters.get_scores()
                logger.info('clus_IoU: {}'.format(score_cl["Mean IoU : \t"]))
                logger.info('clus_Recall: {}'.format(model.metrics.calc_mean_Clu_recall()))
                logger.info('clus_Acc: {}'.format(
                    np.mean(model.metrics.classes_recall_clu[:, 0] / model.metrics.classes_recall_clu[:, 2])))

                score_cl, _ = model.metrics.running_metrics_val_threshold.get_scores()
                logger.info('thr_IoU: {}'.format(score_cl["Mean IoU : \t"]))
                logger.info('thr_Recall: {}'.format(model.metrics.calc_mean_Thr_recall()))
                logger.info('thr_Acc: {}'.format(
                    np.mean(model.metrics.classes_recall_thr[:, 0] / model.metrics.classes_recall_thr[:, 2])))

            # evaluation
            if (model.iter + 1) % cfg['training']['val_interval'] == 0 or \
                    (model.iter + 1) == cfg['training']['train_iters']:
                validation(model, logger, writer, data_sets, device, running_metrics_val, val_loss_meter, loss_fn,
                           source_val_loss_meter, source_running_metrics_val, iters=model.iter)

                torch.cuda.empty_cache()
                logger.info('Best iou until now is {}'.format(model.best_iou))

            # monitoring the accuracy and recall of CAG-based PLA and probability-based PLA
            monitor(model)

            model.metrics.reset()
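
Note: AverageMeter is used throughout these examples but never defined. A minimal sketch consistent with the scalar usage in Examples #2 through #6 (reset(), update(val) or update(val, n), and an .avg attribute); Example #1's meter takes a Result object and extra arguments, so its variant differs:

class AverageMeter:
    """Tracks the running average of a scalar metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is the latest measurement, n the number of samples it covers
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
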
Example #4
def train(model, train_loader, valid_loader1, valid_loader2, optimizer, scheduler, num_epochs, eval_every, margin, device, name):
    IOU_list = []
    best_IOU = 1
    global_step = 0
    train_loss = AverageMeter()
    local_train_loss = AverageMeter()
    best_train_loss = float("Inf")
    best_val_loss = float("Inf")
    loss_list = []
    total_step = len(train_loader)*num_epochs
    print(f'total steps: {total_step}')
    for epoch in range(num_epochs):
        print(f'epoch {epoch+1}')
        for _, data in enumerate(tqdm(train_loader)):
            model.train()
            inputs = data['image'].to(device) # inputs
            target = data['target'].to(device) # targets
            embeddings = model(inputs)
            loss = batch_hard_triplet_loss(target, embeddings, margin=margin)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.update(loss.item())
            local_train_loss.update(loss.item())
            global_step += 1
            current_lr = optimizer.param_groups[0]['lr']
            ### print
            if global_step % eval_every == 0:
                print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f} ({:.4f}), lr: {:.4f}'
                          .format(epoch+1, num_epochs, global_step, total_step, local_train_loss.avg, train_loss.avg, current_lr))
                if local_train_loss.avg < best_train_loss:
                    best_train_loss = local_train_loss.avg
                    print('Best train loss:', local_train_loss.avg)
                loss_list.append(local_train_loss.avg)
                local_train_loss.reset()
        # valid
        with torch.no_grad():
            model.eval()
            val_loss = AverageMeter()
            for _, valid_data in enumerate(valid_loader1):
                inputs = valid_data['image'].to(device) # inputs
                target = valid_data['target'].to(device) # targets
                embeddings = model(inputs)
                valid_loss = batch_hard_triplet_loss(target, embeddings, margin=margin)
                val_loss.update(valid_loss.item())
        dist1 = result(model,valid_loader1,device, loss_fn='triplet')
        dist2 = result(model,valid_loader2,device, loss_fn='triplet')
        try:
            hist_range = [np.floor(min(dist1.min(), dist2.min())),
                          np.ceil(max(dist1.max(), dist2.max()))]
            same_hist = plt.hist(dist1, 100, range=hist_range, alpha=0.5, label='same')
            diff_hist = plt.hist(dist2, 100, range=hist_range, alpha=0.5, label='diff')
            plt.legend(loc='upper right')
            plt.savefig('result/distribution_epoch'+str(epoch+1)+'.png')
            difference = same_hist[0] - diff_hist[0]
            # search for the crossover bin only between the two histogram modes
            difference[:same_hist[0].argmax()] = np.inf
            difference[diff_hist[0].argmax():] = np.inf
            idx = np.where(difference <= 0)[0].min()
            dist_threshold = (same_hist[1][idx] + same_hist[1][idx - 1]) / 2
            overlap = np.sum(dist1 >= dist_threshold) + np.sum(dist2 <= dist_threshold)
            IOU = overlap / (dist1.shape[0] * 2 - overlap)
        except (ValueError, IndexError):
            # if the embeddings collapse, the histograms never cross and no threshold exists
            print("Model results in collapse")
            dist_threshold, overlap, IOU = None, None, float('inf')

        print('dist_threshold:', dist_threshold, 'overlap:', overlap, 'IOU:', IOU)
        plt.clf()
        IOU_list.append(IOU)
        if IOU < best_IOU:
            best_IOU = IOU
            save(name,model,optimizer,scheduler)

        print('Valid loss:',val_loss.avg)
        if val_loss.avg < best_val_loss:
            best_val_loss = val_loss.avg
            print('Best valid loss:', best_val_loss)
        val_loss.reset()
        # count the step of the scheduler in each epoch
        scheduler.step()
    # loss graph
    steps = range(len(loss_list))
    plt.plot(steps, loss_list)
    plt.title('Train loss')
    plt.ylabel('Loss')
    plt.xlabel('Steps')
    plt.savefig('train_loss.png')
    plt.clf()
    print('Finished Training')
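
Note: the threshold search in the try block above reads bin counts and edges out of plt.hist's return value. The same computation with np.histogram avoids drawing entirely; a sketch under the same assumptions as above (two 1-D distance arrays of equal length, same-pair distances concentrated below diff-pair distances):

import numpy as np

def distance_threshold_iou(dist_same, dist_diff, bins=100):
    # Shared bin range spanning both distributions.
    lo = np.floor(min(dist_same.min(), dist_diff.min()))
    hi = np.ceil(max(dist_same.max(), dist_diff.max()))
    same_counts, edges = np.histogram(dist_same, bins=bins, range=(lo, hi))
    diff_counts, _ = np.histogram(dist_diff, bins=bins, range=(lo, hi))
    # Search only between the two modes for the first bin where the
    # diff histogram overtakes the same histogram.
    difference = (same_counts - diff_counts).astype(float)
    difference[:same_counts.argmax()] = np.inf
    difference[diff_counts.argmax():] = np.inf
    k = np.where(difference <= 0)[0].min()
    threshold = (edges[k] + edges[k - 1]) / 2
    # Samples on the wrong side of the threshold form the overlap.
    overlap = np.sum(dist_same >= threshold) + np.sum(dist_diff <= threshold)
    iou = overlap / (dist_same.shape[0] * 2 - overlap)
    return threshold, overlap, iou
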
Example #5
def train2(model, train_loader, valid_loader1, valid_loader2, metric_crit, optimizer, scheduler, num_epochs, eval_every, num_class, device, name):
    IOU_list = []
    best_IOU = 1
    global_step = 0
    # running loss over the whole training run
    train_loss = AverageMeter()
    # running loss since the last eval_every step
    local_train_loss = AverageMeter()
    # records local_train_loss.avg at each eval_every step
    loss_list = []
    best_train_loss = float("Inf")
    total_step = len(train_loader)*num_epochs
    print(f'total steps: {total_step}')
    for epoch in range(num_epochs):
        print(f'epoch {epoch+1}')
        for _, data in enumerate(tqdm(train_loader)):
            model.train()
            # original image
            inputs = data['image'].to(device)
            #targets = data['target'].to(device)
            outputs = model(inputs)
            loss = loss_fn(metric_crit, data, outputs, num_class, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.update(loss.cpu().item(), inputs.size(0))
            local_train_loss.update(loss.cpu().item(), inputs.size(0))
            global_step += 1
            current_lr = optimizer.param_groups[0]['lr']
            if global_step % eval_every == 0:
                print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f} ({:.4f}), lr: {:.4f}'
                          .format(epoch+1, num_epochs, global_step, total_step, local_train_loss.avg, train_loss.avg, current_lr))
                if local_train_loss.avg < best_train_loss:
                    best_train_loss = local_train_loss.avg
                    print('Best train loss:', local_train_loss.avg)
                loss_list.append(local_train_loss.avg)
                local_train_loss.reset()
        # val
        dist1 = result(model,valid_loader1,device, loss_fn='arcface')
        dist2 = result(model,valid_loader2,device, loss_fn='arcface')
        try:
            hist_range = [np.floor(min(dist1.min(), dist2.min())),
                          np.ceil(max(dist1.max(), dist2.max()))]
            same_hist = plt.hist(dist1, 100, range=hist_range, alpha=0.5, label='same')
            diff_hist = plt.hist(dist2, 100, range=hist_range, alpha=0.5, label='diff')
            plt.legend(loc='upper right')
            plt.savefig('result/distribution_epoch'+str(epoch+1)+'.png')
            difference = same_hist[0] - diff_hist[0]
            # search for the crossover bin only between the two histogram modes
            difference[:same_hist[0].argmax()] = np.inf
            difference[diff_hist[0].argmax():] = np.inf
            idx = np.where(difference <= 0)[0].min()
            dist_threshold = (same_hist[1][idx] + same_hist[1][idx - 1]) / 2
            overlap = np.sum(dist1 >= dist_threshold) + np.sum(dist2 <= dist_threshold)
            IOU = overlap / (dist1.shape[0] * 2 - overlap)
        except (ValueError, IndexError):
            # if the embeddings collapse, the histograms never cross and no threshold exists
            print("Model results in collapse")
            dist_threshold, overlap, IOU = None, None, float('inf')
        print('dist_threshold:', dist_threshold, 'overlap:', overlap, 'IOU:', IOU)
        plt.clf()
        IOU_list.append(IOU)
        if IOU < best_IOU:
            best_IOU = IOU
            save(name,model,optimizer,scheduler)
        scheduler.step()
    # loss graph
    steps = range(len(loss_list))
    plt.plot(steps, loss_list)
    plt.title('Train loss')
    plt.ylabel('Loss')
    plt.xlabel('Steps')
    plt.savefig('train_loss.png')
    plt.clf()
    print('Finished Training')
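
Note: batch_hard_triplet_loss (Example #4) and batch_all_triplet_loss (Example #6) are imported elsewhere and not shown. A minimal batch-hard sketch, assuming L2 distances between embeddings and integer class labels; this is one common formulation, not necessarily the exact one these examples use:

import torch

def batch_hard_triplet_loss(labels, embeddings, margin=0.3):
    # Pairwise Euclidean distances between all embeddings in the batch.
    dist = torch.cdist(embeddings, embeddings)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    # Hardest positive: the farthest same-label sample (the zero
    # self-distance on the diagonal never wins the max).
    hardest_pos = (dist * same.float()).max(dim=1).values
    # Hardest negative: the closest different-label sample.
    hardest_neg = dist.masked_fill(same, float('inf')).min(dim=1).values
    return torch.relu(hardest_pos - hardest_neg + margin).mean()
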
Example #6
def train(model, train_loader, valid_loader, valid_loader1, valid_loader2,
          optimizer, scheduler, num_epochs, eval_every, margin, device, name):
    epoch_loss_list = {'train': [], 'valid': []}
    IOU_list = []
    global_step = 0
    train_loss = AverageMeter()
    valid_loss = AverageMeter()
    best_IOU = 1
    total_step = len(train_loader) * num_epochs
    print(f'total steps: {total_step}')
    for epoch in range(num_epochs):
        print(f'epoch {epoch+1}')
        for _, data in enumerate(tqdm(train_loader)):
            model.train()
            inputs = data['image'].to(device)  # inputs
            target = data['target'].to(device)  # targets
            embeddings = model(inputs)
            loss, _ = batch_all_triplet_loss(target,
                                             embeddings,
                                             margin=margin,
                                             epoch=epoch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.update(loss.item())
            global_step += 1
            current_lr = optimizer.param_groups[0]['lr']

            ### periodic validation on valid_loader
            if global_step % eval_every == 0:
                model.eval()
                with torch.no_grad():  # no gradients needed during validation
                    for _, vdata in enumerate(tqdm(valid_loader)):
                        inputs = vdata['image'].to(device)  # inputs
                        target = vdata['target'].to(device)  # targets
                        embeddings = model(inputs)
                        vloss, _ = batch_all_triplet_loss(target, embeddings,
                                                          margin=margin,
                                                          epoch=epoch)
                        valid_loss.update(vloss.item())

                print(
                    'Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}, lr: {:.4f}'
                    .format(epoch + 1, num_epochs, global_step, total_step,
                            train_loss.avg, valid_loss.avg, current_lr))

        # valid
        dist1 = result(model, valid_loader1, device)
        dist2 = result(model, valid_loader2, device)
        hist_range = [np.floor(min(dist1.min(), dist2.min())),
                      np.ceil(max(dist1.max(), dist2.max()))]
        same_hist = plt.hist(dist1, 100, range=hist_range, alpha=0.5, label='same')
        diff_hist = plt.hist(dist2, 100, range=hist_range, alpha=0.5, label='diff')
        plt.legend(loc='upper right')
        plt.savefig('../Result/distribution_epoch' + str(epoch + 1) + '.png')
        difference = same_hist[0] - diff_hist[0]
        # search for the crossover bin only between the two histogram modes
        difference[:same_hist[0].argmax()] = np.inf
        difference[diff_hist[0].argmax():] = np.inf
        idx = np.where(difference <= 0)[0].min()
        dist_threshold = (same_hist[1][idx] + same_hist[1][idx - 1]) / 2
        overlap = np.sum(dist1 >= dist_threshold) + np.sum(dist2 <= dist_threshold)
        IOU = overlap / (dist1.shape[0] * 2 - overlap)
        print('dist_threshold:', dist_threshold, 'overlap:', overlap, 'IOU:', IOU)
        plt.clf()

        epoch_loss_list['train'].append(train_loss.avg)
        epoch_loss_list['valid'].append(valid_loss.avg)
        IOU_list.append(IOU)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 9))
        ax1.plot(range(len(epoch_loss_list['train'])), epoch_loss_list['train'], label='train_loss')
        ax1.plot(range(len(epoch_loss_list['valid'])), epoch_loss_list['valid'], label='valid_loss')
        ax2.plot(range(len(IOU_list)), IOU_list, label='IOU')
        ax1.legend(prop={'size': 15})
        ax2.legend(prop={'size': 15})
        plt.savefig('../Result/loss.png')
        plt.clf()

        if IOU < best_IOU:
            best_IOU = IOU
            save(name, model, optimizer, scheduler)

        train_loss.reset()
        valid_loss.reset()

        # count the step of the scheduler in each epoch
        scheduler.step()
    print('Finished Training')
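
Note: the save(name, model, optimizer, scheduler) helper called in Examples #4 through #6 is not defined here. A plausible minimal sketch (the checkpoint file naming is an assumption):

import torch

def save(name, model, optimizer, scheduler):
    # Bundle everything needed to resume training into one checkpoint file.
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }, name + '.pth')
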
Example #7
def main(args):
    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('sem_seg')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    root_dataset = args.root
    NUM_CLASSES = 2
    NUM_POINT = args.npoint
    BATCH_SIZE = args.batch_size

    print("start loading training data ...")

    TRAIN_DATASET = BigredDataSet(
        root=root_dataset,
        is_train=True,
        is_validation=False,
        is_test=False,
        num_channel=args.num_channel,
        test_code=False,
    )

    print("start loading test data ...")

    TEST_DATASET = BigredDataSet(
        root=root_dataset,
        is_train=False,
        is_validation=True,
        is_test=False,
        num_channel=args.num_channel,
        test_code=False,
    )
    trainDataLoader = torch.utils.data.DataLoader(
        TRAIN_DATASET, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=args.num_worker, pin_memory=True, drop_last=True,
        worker_init_fn=lambda x: np.random.seed(x + int(time.time())))
    testDataLoader = torch.utils.data.DataLoader(
        TEST_DATASET, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=args.num_worker, pin_memory=True, drop_last=True)
    weights = torch.Tensor(TRAIN_DATASET.labelweights).cuda()

    log_string("The number of training data is: %d" % len(TRAIN_DATASET))
    log_string("The number of test data is: %d" % len(TEST_DATASET))

    '''MODEL LOADING'''
    MODEL = importlib.import_module(args.model)
    shutil.copy('models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('models/pointnet_util.py', str(experiment_dir))

    classifier = MODEL.get_model(NUM_CLASSES, num_channel=args.num_channel)
    # args.gpu is a string such as '0,1'; its max character is the highest
    # single-digit device id, so this builds the list [0, ..., max_id]
    gpu_list = list(range(int(max(args.gpu)) + 1))

    classifier = torch.nn.DataParallel(classifier, device_ids=gpu_list).cuda()
    criterion = MODEL.get_loss().cuda()

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('Linear') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)

    try:
        checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0
        classifier = classifier.apply(weights_init)

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)

    def bn_momentum_adjust(m, momentum):
        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
            m.momentum = momentum

    LEARNING_RATE_CLIP = 1e-5
    MOMENTUM_ORIGINAL = 0.1
    MOMENTUM_DECAY = 0.5
    MOMENTUM_DECAY_STEP = args.step_size

    global_epoch = 0
    best_value = 0
    writer = SummaryWriter()
    counter_play = 0

    mean_miou = AverageMeter()
    mean_acc = AverageMeter()
    mean_loss = AverageMeter()

    print("len(trainDataLoader)",len(trainDataLoader))
    print("len(trainDataLoader)",len(testDataLoader))


    for epoch in range(start_epoch,args.epoch):
        num_batches = len(trainDataLoader)

        '''Train on chopped scenes'''
        log_string('**** Epoch %d (%d/%s) ****' % (global_epoch + 1, epoch + 1, args.epoch))
        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
        log_string('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECAY ** (epoch // MOMENTUM_DECAY_STEP))
        if momentum < 0.01:
            momentum = 0.01
        print('BN momentum updated to: %f' % momentum)
        classifier = classifier.apply(lambda x: bn_momentum_adjust(x,momentum))
        classifier.train()
        mean_miou.reset()
        mean_acc.reset()
        mean_loss.reset()
        for i, data in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
            points, target = data
            points, target = points.float().cuda(),target.long().cuda()
            points = points.transpose(2, 1)
            optimizer.zero_grad()
            classifier = classifier.train()
            seg_pred, trans_feat = classifier(points)
            pred_val = seg_pred.contiguous().cpu().data.numpy()
            seg_pred = seg_pred.contiguous().view(-1, NUM_CLASSES)
            batch_label2 = target.cpu().data.numpy()

            pred_val = np.argmax(pred_val, 2)

            batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy()
            target = target.view(-1, 1)[:, 0]
            loss = criterion(seg_pred, target, trans_feat, weights)
            loss.backward()
            optimizer.step()
            pred_choice = seg_pred.cpu().data.max(1)[1].numpy()
            correct = np.sum(pred_choice == batch_label)

            current_seen_class = [0 for _ in range(NUM_CLASSES)]
            current_correct_class = [0 for _ in range(NUM_CLASSES)]
            current_iou_deno_class = [0 for _ in range(NUM_CLASSES)]
            for l in range(NUM_CLASSES):
                current_seen_class[l] = np.sum((batch_label2 == l))
                current_correct_class[l] = np.sum((pred_val == l) & (batch_label2 == l))
                current_iou_deno_class[l] = np.sum(((pred_val == l) | (batch_label2 == l)))

            m_iou = np.mean(np.array(current_correct_class) / (np.array(current_iou_deno_class, dtype=float) + 1e-6))
            loss_num = loss.item()
            acc_num = correct / float(args.batch_size * args.npoint)
            writer.add_scalar('training_loss', loss_num, counter_play)
            writer.add_scalar('training_accuracy', acc_num, counter_play)
            writer.add_scalar('training_mIoU', m_iou, counter_play)
            counter_play = counter_play + 1
            mean_miou.update(m_iou)
            mean_acc.update(acc_num)
            mean_loss.update(loss_num)

        train_ave_miou = mean_miou.avg
        train_ave_loss = mean_loss.avg
        train_ave_acc = mean_acc.avg
        log_string('Training point avg class IoU: %f' % train_ave_miou)
        log_string('Training mean loss: %f' % train_ave_loss)
        log_string('Training accuracy: %f' % train_ave_acc)

        print("----------------------Validation----------------------")
        mean_miou.reset()
        mean_acc.reset()
        mean_loss.reset()
        classifier.eval()
        with torch.no_grad():
            num_batches = len(testDataLoader)
            labelweights = np.zeros(NUM_CLASSES)
            log_string('---- EPOCH %03d EVALUATION ----' % (global_epoch + 1))
            for i, data in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
                points, target = data
                points = points.transpose(2, 1)
                points, target = points.cuda(), target.cuda()
                classifier = classifier.eval()
                pred, _ = classifier(points)
                pred = pred.view(-1, NUM_CLASSES)
                target = target.view(-1, 1)[:, 0]
                pred_choice = pred.data.max(1)[1]
                correct = pred_choice.eq(target.data).cpu().sum()
                pred_np = pred_choice.cpu().data.numpy()
                target_np = target.cpu().data.numpy()
                m_iou = mIoU(pred_np, target_np)
                mean_miou.update(m_iou)
                mean_acc.update(correct.item() / float(args.batch_size * 20000))

        val_ave_miou = mean_miou.avg
        val_ave_acc = mean_acc.avg
        writer.add_scalar('Validation_ave_miou', val_ave_miou, epoch)
        writer.add_scalar('Validation_ave_acc', val_ave_acc, epoch)
        print('Epoch: %d' % epoch)
        print('Validation_ave_miou: %f' % val_ave_miou)
        print('Train_ave_val_ave_miou: %f' % train_ave_miou)
        package = dict()
        package['state_dict'] = classifier.state_dict()
        package['optimizer'] = optimizer.state_dict()
        package['Train_ave_val_ave_miou'] = train_ave_miou
        package['Train_ave_acc'] = train_ave_acc
        package['Train_ave_loss'] = train_ave_loss
        package['Validation_ave_miou'] = val_ave_miou
        package['Validation_ave_acc'] = val_ave_acc
        package['epoch'] = epoch
        package['global_epoch'] = global_epoch
        package['time'] = time.ctime()
        package['num_channel'] = args.num_channel
        package['num_gpu'] = args.num_gpu

        savepath = str(checkpoints_dir)+'/val_miou_'+str(val_ave_miou)+'_val_acc_'+str(val_ave_acc)+'_'+str(epoch)+'.pth'
        torch.save(package, savepath)

        print('Is Best? ', best_value < val_ave_miou)
        if best_value < val_ave_miou:
            best_value = val_ave_miou
            savepath = str(checkpoints_dir) + '/best_model_valmiou_' + str(val_ave_miou) + '_' + str(epoch) + '.pth'
            log_string('Saving at %s' % savepath)
            torch.save(package, savepath)

        log_string('Best mIoU: %f' % best_value)
        global_epoch += 1
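
Note: the mIoU(pred_np, target_np) call in the validation loop above is not defined in this example. A minimal per-batch mean-IoU sketch consistent with the binary (NUM_CLASSES = 2) setup:

import numpy as np

def mIoU(pred, target, num_classes=2):
    # Mean intersection-over-union over the classes present in the batch.
    ious = []
    for c in range(num_classes):
        inter = np.sum((pred == c) & (target == c))
        union = np.sum((pred == c) | (target == c))
        if union > 0:
            ious.append(inter / union)
    return float(np.mean(ious))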