Example #1
def main(cfg):
    # ===== 0.random seed and set logger ===== #
    logger = cfg._get_test_logger()
    logger.write('# === MAIN TEST === #')
    torch.manual_seed(cfg.random_seed)
    torch.cuda.manual_seed(cfg.random_seed)

    # ===== 1.load test data ===== #
    logger.write('==> Test Data loading...')
    test_dataset = cfg.dataset('test', cfg.data_param, cfg.transform_param,
                               logger)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=10,
        shuffle=False,
        num_workers=cfg.training_setting_param.num_workers,
        pin_memory=False,
        drop_last=True)
    logger.write('==> Test Data loaded Successfully!')

    # ===== 2.load the network ===== #
    logger.write('==> Model loading...')
    net = get_model(cfg.model_param).to(cfg.device)
    assert osp.exists(cfg.test_setting_param.model_path)
    logger.write("==> loading pretrained model '{}'".format(
        cfg.test_setting_param.model_path))
    net = load_pretrained_model(net, cfg.test_setting_param.model_path)
    logger.write('==> Model loaded Successfully!')

    # ===== 3.main test process ===== #
    cfg.test(cfg, net, test_loader, logger)
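
Note: the snippet calls load_pretrained_model without defining it. A minimal sketch of a plausible implementation, assuming the checkpoint file holds either a bare state_dict or wraps one under a 'state_dict' key (an assumption, not the repository's actual helper):

import torch

def load_pretrained_model(net, model_path):
    # hypothetical helper: load pretrained weights from `model_path` into `net`
    ckpt = torch.load(model_path, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)
    # strict=False tolerates benign key mismatches (e.g. 'module.' prefixes)
    net.load_state_dict(state_dict, strict=False)
    return net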
Example #2
def main(cfg):
    # ===== 0.random seed and set logger ===== #
    logger = cfg._get_train_logger()
    logger.write('# ===== MAIN TRAINING ===== #')
    torch.manual_seed(cfg.random_seed)
    torch.cuda.manual_seed(cfg.random_seed)

    # ===== 1.load train and val dataset ===== #
    logger.write('==> Data loading...')
    train_dataset = cfg.dataset('train', cfg.data_param, cfg.transform_param,
                                logger)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.training_setting_param.batch_size_per_gpu *
        cfg.devices_num,
        shuffle=True,
        num_workers=cfg.training_setting_param.num_workers,
        pin_memory=True,
        drop_last=True)
    val_dataset = cfg.dataset('val', cfg.data_param, cfg.transform_param,
                              logger)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.training_setting_param.batch_size_per_gpu *
        cfg.devices_num,
        shuffle=False,
        num_workers=cfg.training_setting_param.num_workers,
        pin_memory=False,
        drop_last=True)
    logger.write('==> Data loaded Successfully!')

    # ===== 2.load the network ===== #
    logger.write('==> Model loading...')
    time.sleep(2)  # brief pause (cosmetic) before building the model
    net = get_model(cfg.model_param).to(cfg.device)
    # dummy input matching the configured size (created but not used below)
    dummy_input = torch.rand(cfg.transform_param.input_size)
    params = count_param(net)
    logger.write('%s total parameters: %.2fM (%d)' %
                 (cfg.model_param.model_name, params / 1e6, params))

    # ===== 3.define loss and optimizer ===== #
    #cfg.criterion = get_criterion(cfg.loss_param, cfg.task_name)
    cfg.optimizer = get_optim(net, cfg.training_setting_param)

    # ===== 4.optionally load a pretrained model or resume from one specific epoch ===== #
    if cfg.pretrained_model_path is not None:
        logger.write("==> loading pretrained model '{}'".format(
            cfg.pretrained_model_path))
        net = load_pretrained_model(net, cfg.pretrained_model_path)
    if cfg.resume_ckpt_path is not None:
        if os.path.isfile(cfg.resume_ckpt_path):
            logger.write("==> loading checkpoint '{}'".format(
                cfg.resume_ckpt_path))
            net, cfg.resume_epoch, resume_metric_best = resume(
                net, cfg.resume_ckpt_path)
            logger.write("==> loaded checkpoint '{}' (epoch {})".format(
                cfg.resume_ckpt_path, cfg.resume_epoch))
        else:
            logger.write("==> no checkpoint found at '{}'".format(
                cfg.resume_ckpt_path))
            os._exit(1)  # exit non-zero: the requested checkpoint is missing

    # ===== 5.main training process ===== #
    cudnn.benchmark = True
    metric_best = resume_metric_best if cfg.resume_ckpt_path is not None else -1000
    start_epoch = cfg.resume_epoch if cfg.resume_ckpt_path is not None else 0
    logger.set_names([
        'Epoch', 'Loss', 'Metric_Train', 'Metric_Val', 'Is_Best', 'Epoch_Time',
        'End_Time'
    ])

    for epoch in range(start_epoch, cfg.training_setting_param.epochs):
        epoch_start_time = time.time()
        print('==> Epoch: [{0}] Prepare Training... '.format(epoch))
        torch.cuda.empty_cache()
        # train for one epoch
        loss_train, metric_train = cfg.train(cfg, net, train_loader, epoch)
        # evaluate on validation set
        metric_val = cfg.val(cfg, net, val_loader)
        # save best checkpoints
        is_best = metric_val > metric_best
        metric_best = max(metric_val, metric_best)
        save_ckpt(cfg, epoch, net, metric_val, metric_best, is_best)
        # log epoch time and estimate the final end
        epoch_time = (time.time() - epoch_start_time) / 60
        final_end_time = estimate_final_end_time(
            cfg.training_setting_param.epochs, epoch, epoch_time)
        logger.append([
            epoch, loss_train, metric_train, metric_val,
            str(int(is_best)), epoch_time, final_end_time
        ])
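
Note: count_param and estimate_final_end_time are used but not shown. Plausible sketches, offered as assumptions about their behavior rather than the repository's code:

import datetime

def count_param(net):
    # hypothetical: total number of parameters in the model
    return sum(p.numel() for p in net.parameters())

def estimate_final_end_time(total_epochs, epoch, epoch_time_min):
    # hypothetical: project the remaining epochs onto wall-clock time
    remaining_min = (total_epochs - epoch - 1) * epoch_time_min
    eta = datetime.datetime.now() + datetime.timedelta(minutes=remaining_min)
    return eta.strftime('%m-%d %H:%M')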
Example #3
def _train(args):
    _set_seed(args["seed"])

    factory.set_device(args)

    inc_dataset = factory.get_data(args)
    args["classes_order"] = inc_dataset.class_order
    print(inc_dataset.class_order)
    model = factory.get_model(args)

    results = results_utils.get_template_results(args)

    memory = None

    for _ in range(inc_dataset.n_tasks):
        task_info, train_loader, val_loader, test_loader = inc_dataset.new_task(
            memory)
        if task_info["task"] == args["max_task"]:
            break

        model.set_task_info(task=task_info["task"],
                            total_n_classes=task_info["max_class"],
                            increment=task_info["increment"],
                            n_train_data=task_info["n_train_data"],
                            n_test_data=task_info["n_test_data"],
                            n_tasks=task_info["max_task"])

        model.eval()
        model.before_task(train_loader, val_loader)
        print("Train on {}->{}.".format(task_info["min_class"],
                                        task_info["max_class"]))
        model.train()
        model.train_task(train_loader, val_loader)
        model.eval()
        model.after_task(inc_dataset)
        #model._save_model("../model_weight/cifar_(2000)_ours_005009dissEweightDist_v2_withoutbias_order{}_task{}.pkl".format(args["order"], task_info["task"]))

        print("Eval on {}->{}.".format(0, task_info["max_class"]))
        #ypred, ytrue = model.eval_task(test_loader)
        ypred, ytrue, ynpred, yntrue, y_top5, yn_top5 = model.eval_task(
            test_loader)

        acc_stats = utils.compute_accuracy(ynpred,
                                           yntrue,
                                           task_size=args["increment"])
        print(acc_stats)
        results["results"].append(acc_stats)

        ####################################################################################
        #classifier 100 classes
        acc_stats = utils.compute_accuracy(ypred,
                                           ytrue,
                                           task_size=args["increment"])
        print('classifier:     ', acc_stats)

        with open(
                '../results/results_txt/cifar_(1000)_learn50_ours_without_Norm_classifier.txt',
                "a") as accuracy:
            for i in acc_stats.values():
                accuracy.write(str(i) + " ")
            accuracy.write("\n")
        """     
        #top5
        acc_stats = utils.compute_accuracy(y_top5, ytrue, task_size=args["increment"])
        print('classifier top5:',acc_stats)
        
        with open('../results/results_txt/cifar_(2000)_ours_04dissEweightDist_nobias_classifier_top5.txt', "a") as accuracy:
            for i in acc_stats.values():
                accuracy.write(str(i) + " ")
            accuracy.write("\n")
        """
        #nearest 100 classes
        acc_stats = utils.compute_accuracy(ynpred,
                                           yntrue,
                                           task_size=args["increment"])
        print('nearest:        ', acc_stats)

        with open(
                '../results/results_txt/cifar_(1000)_learn50_ours_without_Norm_nearest.txt',
                "a") as accuracy:
            for i in acc_stats.values():
                accuracy.write(str(i) + " ")
            accuracy.write("\n")
        """     
        #top5
        acc_stats = utils.compute_accuracy(yn_top5, yntrue, task_size=args["increment"])
        print('nearest top5:   ',acc_stats)

        with open('../results/results_txt/cifar_(2000)_ours_04dissEweightDist_nobias_nearest_top5.txt', "a") as accuracy:
            for i in acc_stats.values():
                accuracy.write(str(i) + " ")
            accuracy.write("\n")
        """
        """ 
        ###############################Confusion matrix######################################
        if _ == (inc_dataset.n_tasks-1) :
            confusion = confusion_matrix(ytrue, ypred)
            confusion_plot = sn.heatmap(confusion, annot=False, cbar=False,
                     xticklabels =10,yticklabels =10, square = True)
            fig = confusion_plot.get_figure()
            fig.savefig("../confusion_matrix/shuffle/cifar_(2000)_ours_04dissEweightDist_nobias/classifier"+str(args["order"])+".png")
            
            confusion = confusion_matrix(yntrue, ynpred)
            confusion_plot = sn.heatmap(confusion, annot=False, cbar=False,
                     xticklabels =10,yticklabels =10, square = True)
            fig = confusion_plot.get_figure()
            fig.savefig("../confusion_matrix/shuffle/cifar_(2000)_ours_04dissEweightDist_nobias/nearest"+str(args["order"])+".png")
        #####################################################################################
        """
        memory = model.get_memory()

    print("Average Incremental Accuracy: {}.".format(
        results_utils.compute_avg_inc_acc(results["results"])))

    if args["name"]:
        results_utils.save_results(results, args["name"])

    del model
    del inc_dataset
    torch.cuda.empty_cache()
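
Note: the final print relies on results_utils.compute_avg_inc_acc. A minimal sketch, assuming each per-task acc_stats dict exposes an overall accuracy under a "total" key (the actual key layout depends on utils.compute_accuracy):

def compute_avg_inc_acc(results):
    # hypothetical: mean of the per-task overall accuracies, i.e. the
    # "average incremental accuracy" reported in incremental-learning papers
    return sum(task_acc["total"] for task_acc in results) / len(results)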
Example #4
def _train(args):
    _set_seed(args["seed"])

    factory.set_device(args)

    inc_dataset = factory.get_data(args)
    args["classes_order"] = inc_dataset.class_order

    model = factory.get_model(args)

    results = results_utils.get_template_results(args)

    memory = None

    for _ in range(inc_dataset.n_tasks):
        task_info, train_loader, val_loader, test_loader = inc_dataset.new_task(
            memory)
        if task_info["task"] == args["max_task"]:
            break

        model.set_task_info(task=task_info["task"],
                            total_n_classes=task_info["max_class"],
                            increment=task_info["increment"],
                            n_train_data=task_info["n_train_data"],
                            n_test_data=task_info["n_test_data"],
                            n_tasks=task_info["max_task"])

        model.eval()
        model.before_task(train_loader, val_loader)
        print("Train on {}->{}.".format(task_info["min_class"],
                                        task_info["max_class"]))
        model.train()
        model.train_task(train_loader, val_loader)
        model.eval()
        model.after_task(inc_dataset)

        print("Eval on {}->{}.".format(0, task_info["max_class"]))
        ypred, ytrue = model.eval_task(test_loader)
        acc_stats = utils.compute_accuracy(ypred,
                                           ytrue,
                                           task_size=args["increment"])
        print(acc_stats)
        results["results"].append(acc_stats)

        memory = model.get_memory()

    print("Average Incremental Accuracy: {}.".format(
        results_utils.compute_avg_inc_acc(results["results"])))

    if args["name"]:
        results_utils.save_results(results, args["name"])

    ######################

    #with open('closs_L-_all_200.txt', 'w', newline='') as f:
    #    mywrite = csv.writer(f)
    #    mywrite.writerow(model._classification_loss)
    #with open('gloss_L-_all_200.txt', 'w', newline='') as f:
    #    mywrite = csv.writer(f)
    #    mywrite.writerow(model._graph_loss)

    del model
    del inc_dataset
    torch.cuda.empty_cache()
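
Note: Examples #3-#5 all begin with _set_seed(args["seed"]). A reasonable sketch, assuming it seeds every RNG the loaders and model might touch (the actual helper may differ):

import random

import numpy as np
import torch

def _set_seed(seed):
    # hypothetical: seed Python, NumPy, CPU and all CUDA RNGs for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)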
Example #5
def _train(args):
    _set_seed(args["seed"])

    factory.set_device(args)

    inc_dataset = factory.get_data(args)
    
    # build ltl data loader (train:val = 0.5:0.5)
    ltl_inc_dataset = factory.get_ltl_data(args)
    # build ss data loader (batch size=1)
    ss_inc_dataset = factory.get_ss_data(args)
    gb_inc_dataset = factory.get_gb_data(args)
    
    args["classes_order"] = inc_dataset.class_order
    print(inc_dataset.class_order)
    model = factory.get_model(args)
    
    results = results_utils.get_template_results(args)

    memory = None
    
    ### load pretrained model (metric) ######################
    #model._network.load_state_dict(torch.load('./net.weight'))
    model._metric_network.load_state_dict(torch.load('../../iCaRL_new/model_weight/cifar_pretrain_metric_withNorm_nearest.pkl'))
    # freeze the metric network so it is not updated during training
    for param in model._metric_network.parameters():
        param.requires_grad = False
    model._metric_network.eval()
    ######################################
    
    for _ in range(inc_dataset.n_tasks):
        # this data loader is for pretraining the model
        task_info, train_loader, val_loader, test_loader = inc_dataset.new_task(memory)

        # this data loader is for the ss generator (not yet implemented);
        # note it overwrites task_info, so both datasets must share task indices
        task_info, ss_train_loader, ss_val_loader, ss_test_loader = ss_inc_dataset.new_task(memory)

        
        if task_info["task"] == args["max_task"]:
            break
        
        # prepare the learning-to-learn stages (1st: pretrain, 2nd: ltl)
        print("#################"+str(task_info["task"])+"#####################################")
        if task_info["task"] == 0:
            # this data loader is for training gamma and beta (cross-domain method)
            ltl_task_info, ltl_train_loader, ltl_val_loader, ltl_test_loader = ltl_inc_dataset.new_task(memory)
            
            # data loader for training all 100 classes together
            # (to check whether the scaling and shifting can work)
            gb_task_info, gb_train_loader, gb_val_loader, gb_test_loader = gb_inc_dataset.new_task(None)
            
            # 1st pretrain
            model.set_task_info(
                task=task_info["task"],
                total_n_classes=task_info["max_class"],
                increment=task_info["increment"],
                n_train_data=task_info["n_train_data"],
                n_test_data=task_info["n_test_data"],
                n_tasks=task_info["max_task"]
            )
            # adjust the controllers for the mtl weights, feature-wise layers and requires_grad flags
            #change_ft(model, ft=False)
            #change_mtl(model, mtl=False)
            #change_ss_flag(model, flag=False) 
            #change_weight_requires_grad(model, normal_grad_need=True, mtl_grad_need=False)
            #change_fw_requires_grad(model, fw_need=False)
            
            model.eval()
            model.before_task(train_loader, val_loader)
            print("Train on {}->{}.".format(task_info["min_class"], task_info["max_class"]))
            """
            model.train()
            model.train_task(train_loader, val_loader)
            """
            model._network.load_state_dict(torch.load('./pre_net.weight'))
            model.eval()
            # torch.save(model._network.state_dict(), './pre_net.weight')
            model.after_task(inc_dataset)
            
            print("Eval on {}->{}.".format(0, task_info["max_class"]))
            
            # classify all 100 classes
            ypred, ytrue, ynpred, yntrue, y_top5, yn_top5 = model.eval_task(train_loader)            

            ####################################################################################
            #classifier 100 classes
            acc_stats = utils.compute_accuracy(ypred, ytrue, task_size=args["increment"])
            print('train classifier:     ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_classifier.txt', "a") as accuracy:
                accuracy.write("train\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")

            #nearest 100 classes
            acc_stats = utils.compute_accuracy(ynpred, yntrue, task_size=args["increment"])
            print('train nearest:        ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_nearest.txt', "a") as accuracy:
                accuracy.write("train\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")
            ####################################################################################
            ypred, ytrue, ynpred, yntrue, y_top5, yn_top5 = model.eval_task(test_loader)            

            ####################################################################################
            #classifier 100 classes
            acc_stats = utils.compute_accuracy(ypred, ytrue, task_size=args["increment"])
            print('classifier:     ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_classifier.txt', "a") as accuracy:
                accuracy.write("test\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")

            #nearest 100 classes
            acc_stats = utils.compute_accuracy(ynpred, yntrue, task_size=args["increment"])
            print('nearest:        ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_nearest.txt', "a") as accuracy:
                accuracy.write("test\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")
            ####################################################################################
            
            # 2nd ltl
            model.set_task_info(
                task=ltl_task_info["task"],
                total_n_classes=ltl_task_info["max_class"],
                increment=ltl_task_info["increment"],
                n_train_data=ltl_task_info["n_train_data"],
                n_test_data=ltl_task_info["n_test_data"],
                n_tasks=ltl_task_info["max_task"]
            )
            # use feature-wise layers during training
            #change_ft(model, ft=True)

            model.eval()
            model.before_task_to_2nd_ltl(ltl_train_loader, ltl_val_loader)
            print("Train on {}->{}.".format(task_info["min_class"], task_info["max_class"]))
            
            # model.train()
            # model.ltl_train_task(ltl_train_loader, ltl_val_loader)
            
            model._network.load_state_dict(torch.load('./ss_net.weight'))
            model.eval()
            # torch.save(model._network.state_dict(), './ss_net.weight')
            model.after_task(inc_dataset)
            
            print("Eval on {}->{}.".format(0, task_info["max_class"]))
            
            # classify all 100 classes
            ypred, ytrue, ynpred, yntrue, y_top5, yn_top5 = model.eval_task(train_loader)            

            ####################################################################################
            #classifier 100 classes
            acc_stats = utils.compute_accuracy(ypred, ytrue, task_size=args["increment"])
            print('train classifier:     ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_classifier.txt', "a") as accuracy:
                accuracy.write("train\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")

            #nearest 100 classes
            acc_stats = utils.compute_accuracy(ynpred, yntrue, task_size=args["increment"])
            print('train nearest:        ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_nearest.txt', "a") as accuracy:
                accuracy.write("train\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")
            ####################################################################################
            ypred, ytrue, ynpred, yntrue, y_top5, yn_top5 = model.eval_task(test_loader)            

            ####################################################################################
            #classifier 100 classes
            acc_stats = utils.compute_accuracy(ypred, ytrue, task_size=args["increment"])
            print('classifier:     ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_classifier.txt', "a") as accuracy:
                accuracy.write("test\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")

            #nearest 100 classes
            acc_stats = utils.compute_accuracy(ynpred, yntrue, task_size=args["increment"])
            print('nearest:        ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_nearest.txt', "a") as accuracy:
                accuracy.write("test\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")
            ####################################################################################
            
            # train the ss generator:
            # disable ft and enable mtl to train the scaling and shifting
            #change_ft(model, ft=False)
            #change_mtl(model, mtl=True)
            #change_ss_flag(model, flag=True)
            
            
            # freeze the normal weights
            #chage_extractor_requires_grad(model, ex_need_1=False, ex_need_2=False, ex_need_3=False, ex_need=False)            

            # for the scaling and shifting (gamma, beta)
            #chage_extractor_requires_grad_mtl(model, ex_need_1=False, ex_need_2=False, ex_need_3=False, ex_need_1_bn_fw=False, ex_need_2_bn_fw=False, ex_need_3_bn_fw=False, ex_need=False) # <- not needed now
            
            #chage_extractor_requires_grad_gb(model, ex_need_1_bn_fw=True, ex_need_2_bn_fw=True, ex_need_3_bn_fw=True, ex_need=True)
            
            # train the scaling and shifting
            # gb_train_loader: 100 classes 
            model.before_ss(train_loader, val_loader, ss=True)
            model.train()
            model.train_ss(train_loader, val_loader)
            model.eval()
            
            # use inc_dataset to store the exemplars
#             model.after_task(inc_dataset)
#             model.after_task(ss_inc_dataset)
#             model.after_task(inc_dataset)
#             model.ss_after_task(ss_inc_dataset)

            print('finish!!!!!!!!')
            # gb_after_task: doesn't store the exemplars
            model.gb_after_task(inc_dataset)
            
            print("Eval on {}->{}.".format(0, task_info["max_class"]))
            
            # classify all 100 classes
            ypred, ytrue, ynpred, yntrue, y_top5, yn_top5 = model.eval_task(train_loader, ss=1)            

            ####################################################################################
            #classifier 100 classes
            acc_stats = utils.compute_accuracy(ypred, ytrue, task_size=args["increment"])
            print('train classifier:     ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_classifier.txt', "a") as accuracy:
                accuracy.write("train\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")

            #nearest 100 classes
            acc_stats = utils.compute_accuracy(ynpred, yntrue, task_size=args["increment"])
            print('train nearest:        ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_nearest.txt', "a") as accuracy:
                accuracy.write("train\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")
            ####################################################################################
            ypred, ytrue, ynpred, yntrue, y_top5, yn_top5 = model.eval_task(test_loader, ss=1)            

            ####################################################################################
            #classifier 100 classes
            acc_stats = utils.compute_accuracy(ypred, ytrue, task_size=args["increment"])
            print('classifier:     ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_classifier.txt', "a") as accuracy:
                accuracy.write("test\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")

            #nearest 100 classes
            acc_stats = utils.compute_accuracy(ynpred, yntrue, task_size=args["increment"])
            print('nearest:        ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_nearest.txt', "a") as accuracy:
                accuracy.write("test\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")
            memory = model.get_memory()
            ####################################################################################
            
        #elif task_info["task"] == 1:    
        #    break
            
        else:  # this branch is not yet adapted for use with the ss generator
            #change_ft(model, ft=False)
            #change_mtl(model, mtl=True)
            #change_ss_flag(model, flag=True)
                        
            # close the updates
            #chage_extractor_requires_grad(model, ex_need_1=False, ex_need_2=False, ex_need_3=False, ex_need=False)
            
            # for meta-transfer learning        
            #chage_extractor_requires_grad_mtl(model, ex_need_1=True, ex_need_2=True, ex_need_3=True, ex_need_1_bn_fw=True, ex_need_2_bn_fw=True, ex_need_3_bn_fw=True, ex_need=True)    
    
            # only open the mtl_weight of stages 1, 2 and 3
#             chage_extractor_requires_grad_mtl(model, ex_need_1=True, ex_need_2=True, ex_need_3=True, ex_need_1_bn_fw=False, ex_need_2_bn_fw=False, ex_need_3_bn_fw=False)

            # only open the mtl_weight of stage_3
#             chage_extractor_requires_grad_mtl(model, ex_need_1=False, ex_need_2=False, ex_need_3=True, ex_need_1_bn_fw=False, ex_need_2_bn_fw=False, ex_need_3_bn_fw=False)

            
            model.set_task_info(
                task=task_info["task"],
                total_n_classes=task_info["max_class"],
                increment=task_info["increment"],
                n_train_data=task_info["n_train_data"],
                n_test_data=task_info["n_test_data"],
                n_tasks=task_info["max_task"]
            )

            model.eval()
            model.before_task(train_loader, val_loader)    # extend the classifier
            #model.before_mtl(ss_train_loader, ss_val_loader)
            model.before_ss(train_loader, val_loader, ss=True, cls_layer=True)
            print("Train on {}->{}.".format(task_info["min_class"], task_info["max_class"]))
            model.train()
            #model.train_gb(train_loader, val_loader)
            model.train_ss(train_loader, val_loader)
            
            #################### train ssg after classifier #######################################
            # model.before_ss(train_loader, val_loader, ss=True)
            # print("Train on {}->{}.".format(task_info["min_class"], task_info["max_class"]))
            # model.train()
            # model.train_gb(train_loader, val_loader)
            # model.train_ss(train_loader, val_loader)
            #######################################################################################

            model.eval()
            model.gb_after_task(inc_dataset)
            

            print("Eval on {}->{}.".format(0, task_info["max_class"]))
            #ypred, ytrue = model.eval_task(test_loader)
            
            # batch size = 1        
            ypred, ytrue, ynpred, yntrue, y_top5, yn_top5 = model.eval_task(train_loader, ss=1)            

            ####################################################################################
            #classifier 100 classes
            acc_stats = utils.compute_accuracy(ypred, ytrue, task_size=args["increment"])
            print('train classifier:     ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_classifier.txt', "a") as accuracy:
                accuracy.write("train\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")

            #nearest 100 classes
            acc_stats = utils.compute_accuracy(ynpred, yntrue, task_size=args["increment"])
            print('train nearest:        ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_nearest.txt', "a") as accuracy:
                accuracy.write("train\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")
            ####################################################################################
            
            ypred, ytrue, ynpred, yntrue, y_top5, yn_top5 = model.eval_task(test_loader, ss=1)          
            
            acc_stats = utils.compute_accuracy(ynpred, yntrue, task_size=args["increment"])
            print(acc_stats)
            results["results"].append(acc_stats)
            ####################################################################################
            #classifier 100 classes
            acc_stats = utils.compute_accuracy(ypred, ytrue, task_size=args["increment"])
            print('classifier:     ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_classifier.txt', "a") as accuracy:
                accuracy.write("test\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")

            #nearest 100 classes
            acc_stats = utils.compute_accuracy(ynpred, yntrue, task_size=args["increment"])
            print('nearest:        ',acc_stats)

            with open('./results/results_txt/'+args["name"]+'_nearest.txt', "a") as accuracy:
                accuracy.write("test\n")
                for i in acc_stats.values():
                    accuracy.write(str(i) + " ")
                accuracy.write("\n")

            memory = model.get_memory()
            ####################################################################################
    """        
    print(
        "Average Incremental Accuracy: {}.".format(
            results_utils.compute_avg_inc_acc(results["results"])
        )
    )
    """
    #if args["name"]:
    #    results_utils.save_results(results, args["name"])

    del model
    del inc_dataset
    del ss_inc_dataset
    del ltl_inc_dataset
    torch.cuda.empty_cache()
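
Note: Example #5 repeats the same open/append accuracy-logging block many times. A hypothetical helper (an editorial sketch, not part of the source) would collapse each repetition into one call:

def log_accuracy(acc_stats, name, head, split):
    # append one line of accuracy values to ./results/results_txt/<name>_<head>.txt
    path = './results/results_txt/' + name + '_' + head + '.txt'
    with open(path, 'a') as f:
        f.write(split + '\n')
        f.write(' '.join(str(v) for v in acc_stats.values()) + '\n')

Usage, e.g.: log_accuracy(acc_stats, args["name"], 'classifier', 'test').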