예제 #1
0
def main(args):
    """Train the method selected by ``args.method`` and write submission files.

    A cross-entropy loss is built and passed, together with ``args`` and
    ``args.use_cuda``, to ``methods[args.method]``; the resulting object's
    ``train_model(tune=False)`` returns the statistics and test predictions
    that are written under ``submissions/<args.sub_dir>/``, next to a
    snapshot of the code in the current directory.

    Args:
        args: parsed command-line namespace. This function reads ``method``,
            ``use_cuda`` and ``sub_dir``; the selected method may read more.

    Files written in ``submissions/<args.sub_dir>/``:
        code_snapshot: copy of the code in ``"."``.
        metadata.txt: one value per line - average validation accuracy,
            elapsed time, average RAM usage, max RAM usage, average external
            memory size, max external memory size.
        test_preds.txt: one prediction per line.
    """
    # print args recap
    print(args, end="\n\n")
    loss = torch.nn.CrossEntropyLoss()

    method = methods[args.method](args, loss, args.use_cuda)
    valid_acc, elapsed, ram_usage, ext_mem_sz, preds = method.train_model(
        tune=False)

    # directory with the code snapshot to generate the results;
    # exist_ok=True avoids the race of a separate exists() check
    sub_dir = 'submissions/' + args.sub_dir
    os.makedirs(sub_dir, exist_ok=True)
    # copy code
    create_code_snapshot(".", sub_dir + "/code_snapshot")

    with open(sub_dir + "/metadata.txt", "w") as wf:
        for obj in [
                np.average(valid_acc), elapsed,
                np.average(ram_usage),
                np.max(ram_usage),
                np.average(ext_mem_sz),
                np.max(ext_mem_sz)
        ]:
            wf.write(str(obj) + "\n")

    with open(sub_dir + "/test_preds.txt", "w") as wf:
        for pred in preds:
            wf.write(str(pred) + "\n")
예제 #2
0
def main(args):
    """Run the replay baseline on CORe50 and write the submission files.

    For each incremental training batch ``(x, y, t)``:

    - ``args.replay_examples`` patterns are sampled without replacement from
      the batch and prepended to an external memory that keeps every
      previously sampled pattern.
    - The whole memory is appended to the batch and the classifier is
      trained on the result.
    - The model is evaluated on the full validation set.

    Finally a code snapshot, ``metadata.txt`` and ``test_preds.txt`` are
    written under ``submissions/<args.sub_dir>/``.

    Args:
        args: parsed command-line namespace; reads ``scenario``,
            ``preload_data``, ``classifier``, ``n_classes``, ``lr``,
            ``replay_examples``, ``batch_size``, ``epochs`` and ``sub_dir``.

    Raises:
        ValueError: if ``args.classifier`` is not ``'ResNet18'``, or (from
            ``np.random.choice``) if ``args.replay_examples`` exceeds the
            size of a training batch.
    """
    # print args recap
    print(args, end="\n\n")

    # do not remove this line
    start = time.time()

    # Create the dataset object for example with the "ni, multi-task-nc, or nic
    # tracks" and assuming the core50 location in ./core50/data/
    dataset = CORE50(root='core50/data/',
                     scenario=args.scenario,
                     preload=args.preload_data)

    # Get the validation set
    print("Recovering validation set...")
    full_valdidset = dataset.get_full_valid_set()

    # model: pretrained ResNet-18 with a fresh 512 -> n_classes output layer
    if args.classifier == 'ResNet18':
        classifier = models.resnet18(pretrained=True)
        classifier.fc = torch.nn.Linear(512, args.n_classes)
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # `classifier` a few lines below.
        raise ValueError("Unsupported classifier: {!r}".format(
            args.classifier))

    opt = torch.optim.SGD(classifier.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    # vars to update over time
    valid_acc = []
    ext_mem_sz = []
    ram_usage = []
    heads = []  # per-task copies of `fc`; only filled for multi-task-nc
    ext_mem = None  # [replay_x, replay_y]; grows by replay_examples per batch

    # loop over the training incremental batches (x, y, t)
    for i, train_batch in enumerate(dataset):
        train_x, train_y, t = train_batch

        # sample the replay patterns from the current batch
        idxs_cur = np.random.choice(train_x.shape[0],
                                    args.replay_examples,
                                    replace=False)

        if i == 0:
            ext_mem = [train_x[idxs_cur], train_y[idxs_cur]]
        else:
            ext_mem = [
                np.concatenate((train_x[idxs_cur], ext_mem[0])),
                np.concatenate((train_y[idxs_cur], ext_mem[1]))
            ]

        # Append the whole memory (which now also contains the patterns just
        # sampled from this batch) to the current batch.
        train_x = np.concatenate((train_x, ext_mem[0]))
        train_y = np.concatenate((train_y, ext_mem[1]))

        print("----------- batch {0} -------------".format(i))
        print("x shape: {0}, y shape: {1}".format(train_x.shape,
                                                  train_y.shape))
        print("Task Label: ", t)

        # train the classifier on the current batch/task
        _, _, stats = train_net(opt,
                                classifier,
                                criterion,
                                args.batch_size,
                                train_x,
                                train_y,
                                t,
                                args.epochs,
                                preproc=preprocess_imgs)
        if args.scenario == "multi-task-nc":
            heads.append(copy.deepcopy(classifier.fc))

        # collect statistics
        ext_mem_sz += stats['disk']
        ram_usage += stats['ram']

        # test on the validation set
        stats, _ = test_multitask(classifier,
                                  full_valdidset,
                                  args.batch_size,
                                  preproc=preprocess_imgs,
                                  multi_heads=heads,
                                  last_layer_name="fc",
                                  verbose=False)

        valid_acc += stats['acc']
        print("------------------------------------------")
        print("Avg. acc: {}".format(stats['acc']))
        print("------------------------------------------")

    # Generate submission.zip
    # directory with the code snapshot to generate the results
    sub_dir = 'submissions/' + args.sub_dir
    os.makedirs(sub_dir, exist_ok=True)

    # copy code
    create_code_snapshot(".", sub_dir + "/code_snapshot")

    # generating metadata.txt: with all the data used for the CLScore
    elapsed = (time.time() - start) / 60  # minutes
    print("Training Time: {}m".format(elapsed))
    with open(sub_dir + "/metadata.txt", "w") as wf:
        for obj in [
                np.average(valid_acc), elapsed,
                np.average(ram_usage),
                np.max(ram_usage),
                np.average(ext_mem_sz),
                np.max(ext_mem_sz)
        ]:
            wf.write(str(obj) + "\n")

    # test_preds.txt: with a list of labels separated by "\n"
    print("Final inference on test set...")
    full_testset = dataset.get_full_test_set()
    _, preds = test_multitask(classifier,
                              full_testset,
                              args.batch_size,
                              preproc=preprocess_imgs,
                              multi_heads=heads,
                              last_layer_name="fc",
                              verbose=False)

    with open(sub_dir + "/test_preds.txt", "w") as wf:
        for pred in preds:
            wf.write(str(pred) + "\n")

    print("Experiment completed.")
예제 #3
0
def main(args):
    """Train on CORe50 with an EWC penalty and write the submission files.

    For every incremental training batch ``(x, y, t)``:

    - The classifier is trained with ``train_net_ewc``, which receives the
      accumulated ``fisher_dict`` / ``optpar_dict`` and ``ewc_lambda``.
    - ``on_task_update`` is called on the batch just trained on; per the
      original notes it computes the Fisher values of the completed task.
    - The model is evaluated on a reduced validation set.

    Finally a code snapshot, ``metadata.txt`` and ``test_preds.txt`` are
    written under ``submissions/<args.sub_dir>/``.

    Args:
        args: parsed command-line namespace; reads ``scenario``,
            ``preload_data``, ``classifier``, ``n_classes``, ``lr``,
            ``batch_size``, ``epochs`` and ``sub_dir``.

    Raises:
        ValueError: if ``args.classifier`` is not ``'ResNet18'``.
    """
    # print args recap
    print(args, end="\n\n")

    # do not remove this line
    start = time.time()

    # Create the dataset object for the "ni", "multi-task-nc" or "nic" track,
    # assuming the core50 data lives in ./core50/data/.
    # TODO: check whether CORE50 offers a way to shrink the dataset (and the
    # runtime) for quick test runs.
    dataset = CORE50(root='core50/data/',
                     scenario=args.scenario,
                     preload=args.preload_data)

    # Get the validation set
    print("Recovering validation set...")
    # `reduced=True` asks the (custom) CORE50 for a reduced validation set
    # instead of the default full one.
    full_valdidset = dataset.get_full_valid_set(reduced=True)

    # model: pretrained ResNet-18 with a fresh 512 -> n_classes output layer
    if args.classifier == 'ResNet18':
        classifier = models.resnet18(pretrained=True)
        classifier.fc = torch.nn.Linear(512, args.n_classes)
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # `classifier` a few lines below.
        raise ValueError("Unsupported classifier: {!r}".format(
            args.classifier))

    # stochastic gradient descent; CrossEntropyLoss = LogSoftmax + NLLLoss
    opt = torch.optim.SGD(classifier.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    # vars to update over time
    valid_acc = []
    ext_mem_sz = []
    ram_usage = []
    heads = []  # per-task copies of `fc`; only filled for multi-task-nc

    # EWC state.
    # NOTE(review): ewc_lambda is passed straight to train_net_ewc and
    # presumably scales the EWC penalty; other examples used 0.01-0.8, so
    # this value may need tuning.
    ewc_lambda = 4
    # Fisher values (presumably filled by on_task_update)
    fisher_dict = {}
    # previous optimized weight values (presumably filled by on_task_update)
    optpar_dict = {}

    # Debug switch: set to True to train only on the first `train_size`
    # patterns of each batch (quicker test runs).
    limit_size = False
    train_size = 3200  # 11900 was also tried

    # loop over the training incremental batches (x, y, t)
    for i, train_batch in enumerate(dataset):
        train_x, train_y, t = train_batch

        if limit_size:
            train_x = train_x[0:train_size]
            train_y = train_y[0:train_size]

        # Print current batch number, shapes and task label
        print("----------- batch {0} -------------".format(i))
        print("x shape: {0}, y shape: {1}".format(train_x.shape,
                                                  train_y.shape))
        print("Task Label: ", t)

        # train on the current batch with the EWC penalty; returns stats
        _, _, stats = train_net_ewc(opt,
                                    classifier,
                                    criterion,
                                    args.batch_size,
                                    train_x,
                                    train_y,
                                    t,
                                    fisher_dict,
                                    optpar_dict,
                                    ewc_lambda,
                                    args.epochs,
                                    preproc=preprocess_imgs)

        # multi-task-nc: keep a deep copy of this task's output head
        if args.scenario == "multi-task-nc":
            heads.append(copy.deepcopy(classifier.fc))
        ext_mem_sz += stats['disk']
        ram_usage += stats['ram']

        # training on this task is complete: compute its Fisher values
        on_task_update(t,
                       train_x,
                       train_y,
                       fisher_dict,
                       optpar_dict,
                       classifier,
                       opt,
                       criterion,
                       args.batch_size,
                       preproc=preprocess_imgs)

        # evaluate on the validation set using the stored per-task heads
        stats, _ = test_multitask(classifier,
                                  full_valdidset,
                                  args.batch_size,
                                  preproc=preprocess_imgs,
                                  multi_heads=heads)

        valid_acc += stats['acc']
        print("------------------------------------------")
        print("Avg. acc: {}".format(stats['acc']))
        print("------------------------------------------")

    # Generate submission.zip
    # directory with the code snapshot to generate the results
    sub_dir = 'submissions/' + args.sub_dir
    os.makedirs(sub_dir, exist_ok=True)

    # copy code
    create_code_snapshot(".", sub_dir + "/code_snapshot")

    # generating metadata.txt: with all the data used for the CLScore
    elapsed = (time.time() - start) / 60  # minutes
    print("Training Time: {}m".format(elapsed))
    with open(sub_dir + "/metadata.txt", "w") as wf:
        for obj in [
                np.average(valid_acc), elapsed,
                np.average(ram_usage),
                np.max(ram_usage),
                np.average(ext_mem_sz),
                np.max(ext_mem_sz)
        ]:
            wf.write(str(obj) + "\n")

    # run final full test
    # test_preds.txt: with a list of labels separated by "\n"
    print("Final inference on test set...")
    full_testset = dataset.get_full_test_set()
    # Pass the stored per-task heads, as the validation call does. Without
    # them, multi-task-nc test predictions would presumably all come from the
    # current `fc` (the last task's head).
    _, preds = test_multitask(classifier,
                              full_testset,
                              args.batch_size,
                              preproc=preprocess_imgs,
                              multi_heads=heads)

    with open(sub_dir + "/test_preds.txt", "w") as wf:
        for pred in preds:
            wf.write(str(pred) + "\n")

    print("Experiment completed.")
예제 #4
0
def main(args):
    """Train a scenario-specific wrapper on CORe50 and write a submission.

    Chooses ``NI_wrap`` / ``NC_wrap`` / ``NIC_wrap`` from ``args.scenario``,
    calls its ``train()`` (which returns RAM/disk stats and the validation
    accuracy history), then ``test()`` on the full test set.

    Files written under ``submissions/<args.sub_dir>/``:

    - a code snapshot
    - ``metadata.txt``
    - ``valid_hist.txt``
    - ``test_preds.txt``

    Args:
        args: parsed command-line namespace; reads ``scenario``, ``load``
            and ``sub_dir``.

    Raises:
        ValueError: if ``args.scenario`` is not ``'ni'``, ``'multi-task-nc'``
            or ``'nic'``.
    """
    # print args recap
    print(args, end="\n\n")

    # do not remove this line
    start = time.time()

    # NOTE(review): machine-specific absolute path, used both for the data
    # root and as the wrappers' `path`; consider making it configurable.
    project_root = '/home/jbonato/Documents/cvpr_clvision_challenge/'

    # Create the dataset object (always preloaded, regardless of args)
    dataset = CORE50(root=project_root + 'core50/data/',
                     scenario=args.scenario,
                     preload=True)

    # Get the validation set
    print("Recovering validation set...")
    full_valdidset = dataset.get_full_valid_set()
    # NOTE(review): hard-coded to the first CUDA device
    device0 = torch.device('cuda:0')

    # code for training: pick the wrapper for the requested scenario
    if args.scenario == 'ni':
        wrapper_cls = NI_wrap
    elif args.scenario == 'multi-task-nc':
        wrapper_cls = NC_wrap
    elif args.scenario == 'nic':
        wrapper_cls = NIC_wrap
    else:
        # Fail with a clear message instead of an UnboundLocalError later.
        raise ValueError("Unsupported scenario: {!r}".format(args.scenario))
    trainer = wrapper_cls(dataset,
                          full_valdidset,
                          device=device0,
                          path=project_root,
                          load=args.load)

    stats, valid_acc = trainer.train()
    ram_usage = np.asarray(stats['ram'])
    ext_mem_sz = np.asarray(stats['disk'])

    # Generate submission.zip
    # directory with the code snapshot to generate the results
    sub_dir = 'submissions/' + args.sub_dir
    os.makedirs(sub_dir, exist_ok=True)

    # copy code
    create_code_snapshot(".", sub_dir + "/code_snapshot")

    # generating metadata.txt: with all the data used for the CLScore
    elapsed = (time.time() - start) / 60  # minutes
    print("Training Time: {}m".format(elapsed))
    with open(sub_dir + "/metadata.txt", "w") as wf:
        for obj in [
                np.average(valid_acc), elapsed,
                np.average(ram_usage),
                np.max(ram_usage),
                np.average(ext_mem_sz),
                np.max(ext_mem_sz)
        ]:
            wf.write(str(obj) + "\n")
    # whole validation-accuracy history on a single line
    with open(sub_dir + "/valid_hist.txt", "w") as wf:
        wf.write(str(valid_acc) + "\n")

    # test_preds.txt: with a list of labels separated by "\n"
    print("Final inference on test set...")
    full_testset = dataset.get_full_test_set()

    pred = trainer.test(full_testset, standalone=False)

    # one prediction per line, iterating over pred's first axis
    with open(sub_dir + "/test_preds.txt", "w") as wf:
        for label in pred:
            wf.write(str(label) + "\n")

    print("Experiment completed.")
예제 #5
0
def main(args):
    """Train on CORe50 with a rehearsal memory and write the submission files.

    Each incremental training batch is shuffled.

    - On the first batch, half of it is stored in the memory
      (``all_x``/``all_y``) and the model trains on the batch alone.
    - On every later batch ``i``, the model trains on the shuffled union of
      memory and batch. The memory is then rebuilt from the first
      ``i/(i+1)`` of the shuffled old memory plus the first ``1/(i+1)`` of
      the shuffled new batch (integer division).

    After each batch the model is evaluated on a reduced validation set.
    Finally a code snapshot, ``metadata.txt`` and ``test_preds.txt`` are
    written under ``submissions/<args.sub_dir>/``.

    Args:
        args: parsed command-line namespace; reads ``scenario``,
            ``preload_data``, ``classifier``, ``n_classes``, ``lr``,
            ``batch_size``, ``epochs`` and ``sub_dir``.

    Raises:
        ValueError: if ``args.classifier`` is not ``'ResNet18'``.
    """
    # print args recap
    print(args, end="\n\n")

    # do not remove this line
    start = time.time()

    # Create the dataset object for the "ni", "multi-task-nc" or "nic" track,
    # assuming the core50 data lives in ./core50/data/.
    # TODO: check whether CORE50 offers a way to shrink the dataset (and the
    # runtime) for quick test runs.
    dataset = CORE50(root='core50/data/',
                     scenario=args.scenario,
                     preload=args.preload_data)

    # Get the validation set
    print("Recovering validation set...")
    # `reduced=True` asks the (custom) CORE50 for a reduced validation set
    # instead of the default full one.
    full_valdidset = dataset.get_full_valid_set(reduced=True)

    # model: pretrained ResNet-18 with a fresh 512 -> n_classes output layer
    if args.classifier == 'ResNet18':
        classifier = models.resnet18(pretrained=True)
        classifier.fc = torch.nn.Linear(512, args.n_classes)
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # `classifier` a few lines below.
        raise ValueError("Unsupported classifier: {!r}".format(
            args.classifier))

    # stochastic gradient descent; CrossEntropyLoss = LogSoftmax + NLLLoss
    opt = torch.optim.SGD(classifier.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    # vars to update over time
    valid_acc = []
    ext_mem_sz = []
    ram_usage = []
    heads = []  # per-task copies of `fc`; only filled for multi-task-nc

    # loop over the incremental training batches (x, y, t)
    for i, train_batch in enumerate(dataset):
        train_x, train_y, t = train_batch

        # shuffle new data
        train_x, train_y = shuffle_in_unison((train_x, train_y), seed=0)

        if i == 0:
            # first batch: store its first half as the rehearsal memory
            all_x = train_x[0:train_x.shape[0] // 2]
            all_y = train_y[0:train_y.shape[0] // 2]
        else:
            # later batches: train on memory + new data
            all_x, all_y = shuffle_in_unison((all_x, all_y), seed=0)

            # keep the new batch aside for the memory update below
            temp_x = train_x
            temp_y = train_y

            train_x = np.append(all_x, train_x, axis=0)
            train_y = np.append(all_y, train_y)
            train_x, train_y = shuffle_in_unison((train_x, train_y), seed=0)

            # rebuild the memory: i/(i+1) of the old memory plus 1/(i+1)
            # of the new batch
            temp_x, temp_y = shuffle_in_unison((temp_x, temp_y), seed=0)
            keep_old = (all_x.shape[0] // (i + 1)) * i
            keep_new = temp_x.shape[0] // (i + 1)
            all_x = np.append(all_x[0:keep_old], temp_x[0:keep_new], axis=0)
            all_y = np.append(all_y[0:keep_old], temp_y[0:keep_new])
            del temp_x
            del temp_y

        # Print current batch number, shapes and task label
        print("----------- batch {0} -------------".format(i))
        print("x shape: {0}, y shape: {1}".format(train_x.shape,
                                                  train_y.shape))
        print("Task Label: ", t)

        # train the classifier on the current (possibly augmented) batch
        _, _, stats = train_net(opt,
                                classifier,
                                criterion,
                                args.batch_size,
                                train_x,
                                train_y,
                                t,
                                args.epochs,
                                preproc=preprocess_imgs)

        # multi-task-nc: keep a deep copy of this task's output head
        if args.scenario == "multi-task-nc":
            heads.append(copy.deepcopy(classifier.fc))
        ext_mem_sz += stats['disk']
        ram_usage += stats['ram']

        # evaluate on the validation set using the stored per-task heads
        stats, _ = test_multitask(classifier,
                                  full_valdidset,
                                  args.batch_size,
                                  preproc=preprocess_imgs,
                                  multi_heads=heads)

        valid_acc += stats['acc']
        print("------------------------------------------")
        print("Avg. acc: {}".format(stats['acc']))
        print("------------------------------------------")

    # Generate submission.zip
    # directory with the code snapshot to generate the results
    sub_dir = 'submissions/' + args.sub_dir
    os.makedirs(sub_dir, exist_ok=True)

    # copy code
    create_code_snapshot(".", sub_dir + "/code_snapshot")

    # generating metadata.txt: with all the data used for the CLScore
    elapsed = (time.time() - start) / 60  # minutes
    print("Training Time: {}m".format(elapsed))
    with open(sub_dir + "/metadata.txt", "w") as wf:
        for obj in [
                np.average(valid_acc), elapsed,
                np.average(ram_usage),
                np.max(ram_usage),
                np.average(ext_mem_sz),
                np.max(ext_mem_sz)
        ]:
            wf.write(str(obj) + "\n")

    # run final full test
    # test_preds.txt: with a list of labels separated by "\n"
    print("Final inference on test set...")
    full_testset = dataset.get_full_test_set()
    # Pass the stored per-task heads, as the validation call does. Without
    # them, multi-task-nc test predictions would presumably all come from the
    # current `fc` (the last task's head).
    _, preds = test_multitask(classifier,
                              full_testset,
                              args.batch_size,
                              preproc=preprocess_imgs,
                              multi_heads=heads)

    with open(sub_dir + "/test_preds.txt", "w") as wf:
        for pred in preds:
            wf.write(str(pred) + "\n")

    print("Experiment completed.")
예제 #6
0
def main(args):

    # print args recap
    print(args, end="\n\n")

    # do not remove this line
    start = time.time()

    # Create the dataset object for example with the "ni, multi-task-nc, or nic
    # tracks" and assuming the core50 location in ./core50/data/
    dataset = CORE50(root='core50/data/',
                     scenario=args.scenario,
                     preload=args.preload_data)

    # Get the validation set
    print("Recovering validation set...")
    full_valdidset = dataset.get_full_valid_set()

    # model
    if args.classifier == 'ResNet18':
        # classifier = models.resnet18(pretrained=True)
        classifier = resnet18(pretrained=True)
        classifier.fc = torch.nn.Linear(512, args.n_classes)

    if args.classifier == 'ResNet34':
        # classifier = models.resnet34(pretrained=True)
        classifier = resnet34(pretrained=True)
        classifier.fc = torch.nn.Linear(512, args.n_classes)

    if args.classifier == 'ResNet50':
        classifier = resnet50(pretrained=True)
        classifier.fc = torch.nn.Linear(2048, args.n_classes)

    if args.classifier == 'MobileNetV2':
        classifier = mobilenet_v2(pretrained=True)
        classifier.classifier = torch.nn.Sequential(
            torch.nn.Dropout(0.2),
            torch.nn.Linear(1280, args.n_classes),
        )

    if args.classifier == 'mnasnet':
        classifier = mnasnet1_0(pretrained=True)
        classifier.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True),
            torch.nn.Linear(1280, args.n_classes))

    if args.classifier == 'siamese':
        classifier = SiameseNetwork(n_class=args.n_classes)

    if args.optimizer == 'sgd':
        opt = torch.optim.SGD(classifier.parameters(), lr=args.lr)
    if args.optimizer == 'adam':
        opt = torch.optim.Adam(classifier.parameters(), lr=args.lr)
    if args.optimizer == 'radam':
        opt = RAdam(classifier.parameters(), lr=args.lr)
    # for param in classifier.conv1.parameters() or param in classifier.layer1.parameters() or param in classifier.relu.parameters():
    #     param.requires_grad = False
    # opt = torch.optim.SGD(filter(lambda p: p.requires_grad, classifier.parameters()), lr=args.lr)
    regularization_terms = {}
    task_count = 0
    criterion = torch.nn.CrossEntropyLoss()
    if args.classifier == 'siamese':
        # criterion = ContrastiveLoss(margin=2.)
        criterion = ContrastiveAndCELoss(margin=10.)

    # vars to update over time
    valid_acc = []
    ext_mem_sz = []
    ram_usage = []
    heads = []
    ext_mem = None

    # loop over the training incremental batches (x, y, t)
    for i, train_batch in enumerate(dataset):
        train_x, train_y, t = train_batch

        # adding eventual replay patterns to the current batch
        idxs_cur = np.random.choice(train_x.shape[0],
                                    args.replay_examples,
                                    replace=False)

        if i == 0:
            ext_mem = [train_x[idxs_cur], train_y[idxs_cur]]
        else:
            ext_mem = [
                np.concatenate((train_x[idxs_cur], ext_mem[0])),
                np.concatenate((train_y[idxs_cur], ext_mem[1]))
            ]

        train_x = np.concatenate((train_x, ext_mem[0]))
        train_y = np.concatenate((train_y, ext_mem[1]))

        print("----------- batch {0} -------------".format(i))
        print(
            "x shape: {0}, y shape: {1}, ext_mem_x shape: {2}, ext_mem_y shape: {3}"
            .format(train_x.shape, train_y.shape, ext_mem[0].shape,
                    ext_mem[1].shape))
        print("Task Label: ", t)

        # train the classifier on the current batch/task
        # _, _, stats, regularization_terms, classifier = train_net_mrcl(
        #     opt, classifier, criterion, args.batch_size, train_x, train_y, i,
        #     regularization_terms, task_count, args.regularize_mode, args.icarl,
        #     args.epochs, preproc=preprocess_imgs
        # )
        i_es = 0
        lr = args.lr
        valid_acc2 = []
        for iepoch in range(args.epochs):
            if args.classifier != 'siamese':
                _, _, stats, regularization_terms = train_net(
                    opt,
                    classifier,
                    criterion,
                    args.batch_size,
                    train_x,
                    train_y,
                    i,
                    regularization_terms,
                    task_count,
                    args.regularize_mode,
                    args.icarl,
                    args.aug,
                    args.resize_shape,
                    train_ep=1,
                    preproc=preprocess_imgs)
            else:
                _, _, stats, regularization_terms = train_net_siamese(
                    opt,
                    classifier,
                    criterion,
                    args.batch_size,
                    train_x,
                    train_y,
                    i,
                    regularization_terms,
                    task_count,
                    args.regularize_mode,
                    args.icarl,
                    args.aug,
                    args.resize_shape,
                    train_ep=1,
                    preproc=preprocess_imgs)
            # stats2, _ = test_multitask(
            #     classifier, full_valdidset, args.batch_size, args.aug, args.resize_shape,
            #     preproc=preprocess_imgs, multi_heads=heads, verbose=False
            # )
            # valid_acc2 += stats2['acc']
            # print("{}th epoch, Avg. acc: {}".format(iepoch+1, stats2['acc']))
            # if stats2['acc']==max(valid_acc2):
            #     classifier_max = copy.deepcopy(classifier)
            #     print('update classifier using {}th'.format(iepoch+1))
            # else:
            #     i_es +=1  # counter for early stop
            #     if i_es>2:
            #         if args.optimizer =='sgd' and lr>=0.0001:
            #             lr = 0.1*lr
            #             opt = torch.optim.SGD(classifier.parameters(), lr=lr)
            #             i_es = 0
            #             print('learning rate reduced to {}'.format(lr))
            #         else:
            #             print('early stop at epoch {}'.format(iepoch+1))
            #             classifier = copy.deepcopy(classifier_max)
            #             break
        # classifier = copy.deepcopy(classifier_max)

        if args.scenario == "multi-task-nc":
            if args.classifier == 'mnasnet':
                heads.append(copy.deepcopy(classifier.classifier[1]))
            elif args.classifier == 'siamese':
                heads.append(copy.deepcopy(classifier.cnn1.classifier[1]))
            else:
                heads.append(copy.deepcopy(classifier.fc))

        ### not using the nearest neighbour classifier in icarl
        # if args.icarl:
        #     exem_class = []
        #     classifier = maybe_cuda(classifier)
        #     classifier.eval()
        #     # update exemplar features
        #     if args.classifier=='mnasnet':
        #         for i_class in range(args.n_classes):
        #             exemplar_features = []
        #             # nb_iters = (i_class==ext_mem[1]).sum()//args.batch_size + 1
        #             # x_i = torch.from_numpy(preprocess_imgs(ext_mem[0][i_class==ext_mem[1]])).type(torch.FloatTensor)

        #             nb_iters = (i_class==train_y).sum()//args.batch_size + 1
        #             x_i = torch.from_numpy(preprocess_imgs(train_x[i_class==train_y])).type(torch.FloatTensor)
        #             if len(x_i)==0:
        #                 print('no exemplars to be updated for class {}...'.format(i_class))
        #                 break
        #             with torch.no_grad():
        #                 for it in range(nb_iters):
        #                     start = it * args.batch_size
        #                     end = (it + 1) * args.batch_size
        #                     x_i_mb = maybe_cuda(x_i[start:end])
        #                     if x_i_mb.shape[0]==0:
        #                         break
        #                     feat_exem = classifier.layers(x_i_mb).mean([2, 3])  # mnasnet
        #                     exemplar_features.extend(np.array(feat_exem.cpu()))
        #                 mean_exem_feats = np.mean(exemplar_features, axis=0)
        #                 # mean_exem_feats = mean_exem_feats / np.linalg.norm(mean_exem_feats) # Normalize
        #                 exem_class.append(torch.from_numpy(mean_exem_feats).type(torch.FloatTensor))
        #         # classify with nearest
        #         stats_icarl, _ = test_multitask_icarl(
        #             classifier, full_valdidset, exem_class, args.batch_size,
        #             preproc=preprocess_imgs, multi_heads=heads, verbose=False
        #         )
        #         print("icarl avg. acc: {}".format(stats_icarl['acc']))
        #     else:
        #         for i_class in range(args.n_classes):
        #             nb_iters = (i_class==ext_mem[1]).sum()//args.batch_size if (i_class==ext_mem[1]).sum()//args.batch_size >0 else (i_class==ext_mem[1]).sum()
        #             x_i = torch.from_numpy(preprocess_imgs(ext_mem[0][i_class==ext_mem[1]])).type(torch.FloatTensor)
        #             if len(x_i)==0:
        #                 break
        #             with torch.no_grad():
        #                 for it in range(nb_iters):
        #                     start = it * args.batch_size
        #                     end = (it + 1) * args.batch_size
        #                     x_i_mb = maybe_cuda(x_i[start:end])
        #                     feat_exem = classifier.conv1(x_i_mb)
        #                     feat_exem = classifier.bn1(feat_exem)
        #                     feat_exem = classifier.relu(feat_exem)
        #                     print(feat_exem.shape)
        #                     feat_exem = classifier.maxpool(feat_exem)
        #                     feat_exem = classifier.layer1(feat_exem)
        #                     feat_exem = classifier.layer2(feat_exem)
        #                     feat_exem = classifier.layer3(feat_exem)
        #                     feat_exem = classifier.layer4(feat_exem)
        #                     feat_exem = classifier.avgpool(feat_exem)
        #                     feat_exem = feat_exem.view(feat_exem.size(0), -1)
        #                     exemplar_features.append(np.array(feat_exem.cpu()))
        #                 mean_exem_feats = np.mean(np.mean(exemplar_features, axis=0), axis=0)
        #                 mean_exem_feats = mean_exem_feats / np.linalg.norm(mean_exem_feats) # Normalize
        #                 exem_class.append(torch.from_numpy(mean_exem_feats).type(torch.FloatTensor))
        #         # classify with nearest
        #         stats_icarl, _ = test_multitask_icarl(
        #             classifier, full_valdidset, exem_class, args.batch_size,
        #             preproc=preprocess_imgs, multi_heads=heads, verbose=False
        #         )
        #         print("icarl avg. acc: {}".format(stats_icarl['acc']))
        ###

        task_count += 1

        # collect statistics
        ext_mem_sz += stats['disk']
        ram_usage += stats['ram']

        # test on the validation set
        if args.classifier != 'siamese':
            stats, _ = test_multitask(classifier,
                                      full_valdidset,
                                      args.batch_size,
                                      args.aug,
                                      args.resize_shape,
                                      preproc=preprocess_imgs,
                                      multi_heads=heads,
                                      verbose=False)
        else:
            exem_class = []
            classifier = maybe_cuda(classifier)
            classifier.eval()
            # update exemplar features
            for i_class in range(args.n_classes):
                exemplar_features = []
                nb_iters = (i_class == train_y).sum() // args.batch_size + 1
                x_i = preprocess_imgs(train_x[i_class == train_y],
                                      aug=False,
                                      resize_shape=args.resize_shape).type(
                                          torch.FloatTensor)
                if len(x_i) == 0:
                    print('no exemplars to be updated for class {}...'.format(
                        i_class))
                    break
                with torch.no_grad():
                    for it in range(nb_iters):
                        start = it * args.batch_size
                        end = (it + 1) * args.batch_size
                        x_i_mb = maybe_cuda(x_i[start:end])
                        if x_i_mb.shape[0] == 0:
                            break
                        feat_exem = classifier.forward_once(x_i_mb)
                        i = random.choice(
                            range(x_i_mb.shape[0])
                        )  # sample one exemplar for each iteration per class
                        # exemplar_features.extend(np.array(feat_exem.cpu()))
                        exemplar_features.append(np.array(feat_exem[i].cpu()))
                    # mean_exem_feats = np.mean(exemplar_features, axis=0)
                    # # mean_exem_feats = mean_exem_feats / np.linalg.norm(mean_exem_feats) # Normalize
                    # exem_class.append(torch.from_numpy(mean_exem_feats).type(torch.FloatTensor))
                    exem_class.append(
                        torch.from_numpy(np.array(exemplar_features)).type(
                            torch.FloatTensor))
            # classify with nearest
            stats, _ = test_multitask_siamese(classifier,
                                              full_valdidset,
                                              exem_class,
                                              args.batch_size,
                                              args.aug,
                                              args.resize_shape,
                                              preproc=preprocess_imgs,
                                              multi_heads=heads,
                                              verbose=False)

        valid_acc += stats['acc']
        print("------------------------------------------")
        print("Avg. acc: {}".format(stats['acc']))
        print("------------------------------------------")

    # Generate submission.zip
    # directory with the code snapshot to generate the results
    sub_dir = 'submissions/' + args.sub_dir
    if not os.path.exists(sub_dir):
        os.makedirs(sub_dir)

    # copy code
    create_code_snapshot(".", sub_dir + "/code_snapshot")

    # generating metadata.txt: with all the data used for the CLScore
    elapsed = (time.time() - start) / 60
    print("Training Time: {}m".format(elapsed))
    print("Final average valid acc: {}".format(np.average(valid_acc)))
    with open(sub_dir + "/metadata.txt", "w") as wf:
        for obj in [
                np.average(valid_acc), elapsed,
                np.average(ram_usage),
                np.max(ram_usage),
                np.average(ext_mem_sz),
                np.max(ext_mem_sz)
        ]:
            wf.write(str(obj) + "\n")

    # test_preds.txt: with a list of labels separated by "\n"
    print("Final inference on test set...")
    full_testset = dataset.get_full_test_set()
    if args.classifier != 'siamese':
        stats, preds = test_multitask(classifier,
                                      full_testset,
                                      args.batch_size,
                                      args.aug,
                                      args.resize_shape,
                                      preproc=preprocess_imgs,
                                      multi_heads=heads,
                                      verbose=False)
    else:
        stats, _ = test_multitask_siamese(classifier,
                                          full_testset,
                                          exem_class,
                                          args.batch_size,
                                          args.aug,
                                          args.resize_shape,
                                          preproc=preprocess_imgs,
                                          multi_heads=heads,
                                          verbose=False)

    with open(sub_dir + "/test_preds.txt", "w") as wf:
        for pred in preds:
            wf.write(str(pred) + "\n")

    print("Experiment completed.")