Example #1
def test(model,
         test_dataset,
         sampleprotos_mu,
         sampleprotos_logvar,
         args,
         cuda=False):

    model.train(mode=False)  #reset to test mode

    sampler = BalancedDatasetSampler(test_dataset, args.dataset_samples,
                                     args.test_size)

    if cuda:
        loader = torch.utils.data.DataLoader(test_dataset,
                                             shuffle=False,
                                             sampler=sampler,
                                             batch_size=args.dataset_classes *
                                             args.dataset_samples,
                                             num_workers=0,
                                             pin_memory=True)
    else:
        loader = torch.utils.data.DataLoader(test_dataset,
                                             shuffle=False,
                                             sampler=sampler,
                                             batch_size=args.dataset_classes *
                                             args.dataset_samples,
                                             num_workers=0)

    total_tested = 0
    total_correct = 0
    with torch.no_grad():
        data_stream = tqdm(enumerate(loader, 1))
        for batch_index, (imgs, labels, dummy) in data_stream:
            # break on test size.
            if total_tested >= args.test_size:
                break
            # test the model.
            # prepare the data.
            #print(imgs.size())
            imgs = imgs.view(args.dataset_classes, args.dataset_samples,
                             args.dataset_channels, args.dataset_width,
                             args.dataset_width)
            imgs = imgs[:, 0, :, :, :]  #take the first sample to test
            imgs = Variable(imgs).cuda() if cuda else Variable(imgs)

            loss, lossinfor = model.loss_proto(
                imgs, sampleprotos_mu, sampleprotos_logvar,
                args.dataset_nsupport, args.dataset_nquery,
                args.dataset_classes, args.dataset_channels,
                args.dataset_width, cuda, args.temperature)
            # update statistics.
            total_correct = total_correct + lossinfor['acc']
            total_tested = total_tested + 1

    model.train(mode=True)  #reset to training mode
    precision = total_correct / total_tested

    return precision
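# --- Illustrative sketch (not from the original code): the episodic batch view used in test() ---
# The loader yields a flat batch of dataset_classes * dataset_samples images; test() views it
# as (classes, samples, C, W, W) and keeps only the first sample of each class. The sizes
# below are hypothetical.
import torch

n_classes, n_samples, channels, width = 5, 10, 1, 28
flat_batch = torch.randn(n_classes * n_samples, channels, width, width)

episodic = flat_batch.view(n_classes, n_samples, channels, width, width)
first_per_class = episodic[:, 0, :, :, :]   # one test image per class
print(first_per_class.shape)                # torch.Size([5, 1, 28, 28])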
Example #2
def storeOriProto(model, train_dataset, args, cuda=False):

    sampler = BalancedDatasetSampler(train_dataset, args.dataset_total, 1)
    if cuda:
        loader = torch.utils.data.DataLoader(train_dataset,
                                             shuffle=False,
                                             sampler=sampler,
                                             batch_size=args.dataset_total *
                                             args.dataset_current_classes,
                                             num_workers=0,
                                             pin_memory=True)
    else:
        loader = torch.utils.data.DataLoader(train_dataset,
                                             shuffle=False,
                                             sampler=sampler,
                                             batch_size=args.dataset_total *
                                             args.dataset_current_classes,
                                             num_workers=0)

    total_proto = torch.Tensor()
    total_tested = 0
    data_stream = tqdm(enumerate(loader, 1))
    for batch_index, (imgs, labels, dummy) in data_stream:

        if total_tested >= args.oriproto_eachsize:
            break
        # test the model.
        # prepare the data.
        #print('img1')
        #print(imgs.size())
        imgs = imgs.view(args.dataset_current_classes, args.dataset_total,
                         args.dataset_channels, args.dataset_width,
                         args.dataset_width)
        #print('img2')
        #print(imgs.size())
        #current = imgs[:,0:args.proto_size-1,:]
        current = imgs

        if total_proto.nelement() == 0:
            total_proto = current.detach().cpu().clone()
        else:
            total_proto = torch.cat(
                (total_proto, current.detach().cpu().clone()), 1)

        total_tested = total_tested + 1
        #print('current')
        #print(current.size())

    total_proto = total_proto.view(args.dataset_current_classes,
                                   args.dataset_total, args.dataset_channels,
                                   args.dataset_width, args.dataset_width)
    return total_proto
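# --- Illustrative sketch (not from the original code): how storeOriProto accumulates raw images ---
# Each loader batch is viewed as (current_classes, dataset_total, C, W, W) and the per-class
# image stacks are concatenated along dim 1 until the budget is reached. Sizes are hypothetical.
import torch

current_classes, per_batch, channels, width = 5, 4, 1, 28
total_proto = torch.Tensor()                  # starts empty, as in the function above
for _ in range(3):                            # three loader batches, illustratively
    batch = torch.randn(current_classes, per_batch, channels, width, width)
    if total_proto.nelement() == 0:
        total_proto = batch.clone()
    else:
        total_proto = torch.cat((total_proto, batch.clone()), 1)
print(total_proto.shape)                      # torch.Size([5, 12, 1, 28, 28])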
def storeProto(model, train_dataset, args, cuda=False):

    model.train(mode=False)  #reset to test mode

    sampler = BalancedDatasetSampler(train_dataset, args.dataset_samples,
                                     args.dataset_episodes)

    if cuda:
        loader = torch.utils.data.DataLoader(train_dataset,
                                             shuffle=False,
                                             sampler=sampler,
                                             batch_size=args.batch_size,
                                             num_workers=0,
                                             pin_memory=True)
    else:
        loader = torch.utils.data.DataLoader(train_dataset,
                                             shuffle=False,
                                             sampler=sampler,
                                             batch_size=args.batch_size,
                                             num_workers=0)

    total_proto = torch.Tensor()
    total_tested = 0
    
    with torch.no_grad():
        data_stream = tqdm(enumerate(loader, 1))
        for batch_index, (imgs, labels, dummy) in data_stream:
            # break on test size.
            if (total_tested >= args.dataset_episodes) or (total_tested >= args.oriproto_eachsize):
                break
            # test the model.
            # prepare the data.
            imgs = imgs.view(args.dataset_current_classes,
                             args.dataset_samples, args.dataset_channels,
                             args.dataset_width, args.dataset_width)
            imgs = Variable(imgs).cuda() if cuda else Variable(imgs)

            current = model.getHiddenReps(imgs, args.dataset_nsupport,
                                          args.dataset_nquery,
                                          args.dataset_current_classes,
                                          args.dataset_channels,
                                          args.dataset_width)

            # update statistics.
            if total_proto.nelement() == 0:
                total_proto = current.detach().cpu().reshape(
                    args.dataset_current_classes, 1,
                    args.dataset_nquery + args.dataset_nsupport, -1)
            else:
                current = current.detach().cpu().clone().reshape(
                    args.dataset_current_classes, 1,
                    args.dataset_nquery + args.dataset_nsupport, -1)
                #print('current')
                #print(current.size())
                #print('total_proto')
                #print(total_proto.size())
                total_proto = torch.cat((total_proto, current), 1)

            total_tested = total_tested + 1
            
    model.train(mode=True) #reset to training mode
    #total_proto = total_proto/total_tested
    #print(total_proto)
    
    return total_proto
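# --- Illustrative sketch (not from the original code): the embedding buffer built by storeProto ---
# Per episode, the hidden representations are reshaped to (current_classes, 1,
# nsupport + nquery, feat_dim) and stacked along dim 1, one slice per stored episode.
# The feature size and episode count below are hypothetical.
import torch

current_classes, nsupport, nquery, feat_dim, episodes = 5, 1, 5, 64, 3
total_proto = torch.Tensor()
for _ in range(episodes):
    reps = torch.randn(current_classes * (nsupport + nquery), feat_dim)  # stand-in for getHiddenReps output
    reps = reps.reshape(current_classes, 1, nsupport + nquery, -1)
    if total_proto.nelement() == 0:
        total_proto = reps
    else:
        total_proto = torch.cat((total_proto, reps), 1)
print(total_proto.shape)                      # torch.Size([5, 3, 6, 64])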
def train(model,
          train_datasets,
          test_datasets,
          task_output_space,
          args,
          cuda=False):

    #optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) #for current task
    #optimizer2 = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) #for previous task

    optimizer = optim.SGD(model.parameters(), lr=args.lr)  #for current task
    optimizer2 = optim.SGD(model.parameters(), lr=args.lr)  #for previous task

    model.train()

    precision_record = defaultdict(list)

    protos = {}
    oriprotos = {}

    lossinfor2 = {'loss': float('NaN'), 'acc': float('NaN')}

    args.dataset_samples = args.dataset_nquery + args.dataset_nsupport

    # Decide split ordering
    task_names = sorted(list(task_output_space.keys()), key=int)
    print('Task order:', task_names)
    test_dataset_all = []
    train_dataset_all = []
    for task, train_dataset in enumerate(train_datasets, 1):
        train_name = task_names[task - 1]
        train_dataset = train_datasets[train_name]
        test_dataset = test_datasets[train_name]
        train_dataset_all.append(train_dataset)
        test_dataset_all.append(test_dataset)

    for task, train_dataset in enumerate(train_datasets, 1):
        best_precision = 0  #the best precision for the current task

        if task == 1:
            args.dataset_classes = args.first_split_size
            args.batch_size = args.first_split_size * args.dataset_samples
            args.dataset_current_classes = args.first_split_size
        else:
            args.dataset_classes = args.dataset_classes + args.other_split_size
            args.batch_size = args.other_split_size * args.dataset_samples
            args.dataset_current_classes = args.other_split_size
            args.epochs_per_task = args.epochs_per_tasknext

        train_name = task_names[task - 1]
        train_dataset = train_datasets[train_name]

        if task > 1:
            args.dataset_episodes = args.dataset_nextepisodes

        for epoch in range(1, args.epochs_per_task + 1):
            iternum = 0
            sampler = BalancedDatasetSampler(train_dataset,
                                             args.dataset_samples,
                                             args.dataset_episodes)

            if cuda:
                loader = torch.utils.data.DataLoader(
                    train_dataset,
                    shuffle=False,
                    sampler=sampler,
                    batch_size=args.batch_size,
                    num_workers=0,
                    pin_memory=True)
            else:
                loader = torch.utils.data.DataLoader(
                    train_dataset,
                    shuffle=False,
                    sampler=sampler,
                    batch_size=args.batch_size,
                    num_workers=0)

            data_stream = tqdm(enumerate(loader, 1))
            for batch_index, (imgs, labels, dummy) in data_stream:

                # prepare the data.
                imgs = imgs.view(args.dataset_current_classes,
                                 args.dataset_samples, args.dataset_channels,
                                 args.dataset_width, args.dataset_width)
                if cuda: imgs = imgs.cuda()

                # combine with previous images
                oldimgs = torch.Tensor()
                for t in range(args.dataset_classes -
                               args.dataset_current_classes):

                    #generate a random index for selecting ori protos
                    rn = random.randint(0, args.oriproto_eachsize - 1)

                    samplesingleimgs = oriprotos[t][rn, :].view(
                        1, args.dataset_samples, args.dataset_channels,
                        args.dataset_width, args.dataset_width)
                    if cuda: samplesingleimgs = samplesingleimgs.cuda()

                    if oldimgs.nelement() == 0:
                        oldimgs = samplesingleimgs
                    else:
                        oldimgs = torch.cat((oldimgs, samplesingleimgs), 0)

                if cuda: oldimgs = oldimgs.cuda()
                imgs = torch.cat((oldimgs, imgs), 0)
                imgs = Variable(imgs).cuda() if cuda else Variable(imgs)

                # run the model and backpropagate the errors based on current task + previous protos
                optimizer.zero_grad()
                loss, lossinfor = model.loss_initial(
                    imgs, args.dataset_nsupport, args.dataset_nquery,
                    args.dataset_classes, args.dataset_channels,
                    args.dataset_width, args.temperature)  #,
                #protos, cuda)
                loss.backward()
                optimizer.step()

                #every replay_freq batches, backpropagate through the original protos
                if iternum % args.replay_freq == 0:
                    # run the model and backpropagate the errors based on previous oriprotos
                    if task > 1:
                        sampleimgs = torch.Tensor()
                        sampleprotos = torch.Tensor()
                        for t in range(args.dataset_classes -
                                       args.dataset_current_classes):

                            #generate a random index for selecting ori protos
                            rn = random.randint(0, args.oriproto_eachsize - 1)
                            samplesingleimgs = oriprotos[t][rn, :].view(
                                1, args.dataset_samples, args.dataset_channels,
                                args.dataset_width, args.dataset_width)
                            samplesingleprotos = protos[t][rn, :].view(
                                1, args.dataset_samples, -1)

                            if cuda: samplesingleimgs = samplesingleimgs.cuda()
                            if cuda:
                                samplesingleprotos = samplesingleprotos.cuda()

                            if sampleimgs.nelement() == 0:
                                sampleimgs = samplesingleimgs
                                sampleprotos = samplesingleprotos
                            else:
                                sampleimgs = torch.cat(
                                    (sampleimgs, samplesingleimgs), 0)
                                sampleprotos = torch.cat(
                                    (sampleprotos, samplesingleprotos), 0)

                        sampleprotos = Variable(sampleprotos).cuda(
                        ) if cuda else Variable(sampleprotos)
                        sampleimgs = Variable(sampleimgs).cuda(
                        ) if cuda else Variable(sampleimgs)

                        sampleprotos = sampleprotos.reshape(
                            (args.dataset_classes -
                             args.dataset_current_classes) *
                            (args.dataset_samples), -1)
                        #print('sampleimgs')
                        #print(sampleimgs.size())
                        #print('sampleprotos')
                        #print(sampleprotos.size())

                        optimizer2.zero_grad()
                        loss2, lossinfor2 = model.loss_proto(
                            sampleimgs, args.dataset_nsupport,
                            args.dataset_nquery, args.dataset_classes -
                            args.dataset_current_classes,
                            args.dataset_channels, args.dataset_width,
                            sampleprotos, cuda, args.temperature,
                            args.model_mode)
                        loss2.backward()
                        optimizer2.step()

                iternum = iternum + 1
                #print("loss: "+"{}".format(lossinfor['loss'])+"; accuracy: "+ "{}".format(lossinfor['acc']))
                #print(imgs.size());x1 = imgs[1][1][0].view(28,28,1); plt.imshow(x1[:,:,0]); break

                data_stream.set_description(
                    ('task: {task}/{tasks} | '
                     'epoch: {epoch}/{epochs} | '
                     'progress: [{trained}/{total}] ({progress:.0f}%) | '
                     'prec: {prec:.4} | '
                     'loss: {loss:.4} | '
                     'prev_prec: {prevprec:.4} | '
                     'prev_loss: {prevloss:.4} | ').format(
                         task=task,
                         tasks=len(train_datasets),
                         epoch=epoch,
                         epochs=args.epochs_per_task,
                         trained=batch_index * args.batch_size,
                         total=args.dataset_episodes * args.dataset_samples,
                         progress=(100. * batch_index * args.batch_size /
                                   len(loader.dataset)),
                         prec=lossinfor['acc'],
                         loss=lossinfor['loss'],
                         prevprec=lossinfor2['acc'],
                         prevloss=lossinfor2['loss'],
                     ))

            #end of epoch; test in validation set; save best model for this task
            val_dataset = torch.utils.data.ConcatDataset(
                train_dataset_all[:task])
            current_precision = validate(model, val_dataset, args, cuda)

            #if current_precision > best_precision:
            best_precision = current_precision
            save_checkpoint(model, args, epoch, current_precision, task)

        #end of task;  evaluate on current task and previous tasks
        load_checkpoint(model, args, task, cuda)
        #current_proto, current_oriproto = storeOriProto_Proto(model, train_dataset, args, cuda)

        args.oriproto_eachsize = int(
            math.floor(
                args.oriproto_size /
                args.dataset_classes))  #floor the number of oriprotos per task
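        # Illustrative arithmetic (hypothetical numbers): with oriproto_size = 200
        # and dataset_classes = 10, each stored entry keeps floor(200 / 10) = 20
        # images, so total replay storage stays roughly constant as classes accumulate.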
        current_oriproto = storeOriProto(model, train_dataset, args, cuda)
        args.proto_eachsize = int(
            math.floor(args.proto_size / args.dataset_classes))
        current_proto = storeProto(model, train_dataset, args, cuda)
        #        print('current_proto')
        #        print(current_proto.size())

        for t in range(args.dataset_current_classes):
            oriprotos[args.dataset_classes - args.dataset_current_classes +
                      t] = current_oriproto[t]
            protos[args.dataset_classes - args.dataset_current_classes +
                   t] = current_proto[t]

        #print('protos')
        #print(protos.size())

        for i in range(len(train_datasets)):
            if i + 1 <= task:
                #reduce the storage of previous protos
                oriprotos[i] = oriprotos[i][:args.oriproto_eachsize, :]
                protos[i] = protos[i][:args.oriproto_eachsize, :]
                #print('task')
                #print(i)
                #print('reduced proto size')
                #print(protos[i].size())

        #validate performance in all previous tasks
        val_dataset = torch.utils.data.ConcatDataset(test_dataset_all[:task])
        prec = test(model, val_dataset, args, cuda)
        precision_record[task - 1].append(prec)
        print(precision_record)
        #save precision_record as tensors for plotting later
        path = os.path.join(
            args.result_dir,
            '{firstname}_{secondname}-precision_record.pt'.format(
                firstname=args.model_name, secondname=args.n_repeats))
        torch.save(precision_record, path)
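# --- Illustrative sketch (not from the original code): constant-size replay storage ---
# At the end of each task the training loop above stores raw images per class in a dict
# and then truncates every entry to floor(oriproto_size / total_classes), so the overall
# replay budget stays roughly fixed. All sizes below are hypothetical.
import math
import torch

oriproto_size, channels, width = 60, 1, 28
oriprotos = {c: torch.randn(20, channels, width, width) for c in range(5)}  # 5 classes stored so far

total_classes = 6                                          # one new class arrives
oriprotos[5] = torch.randn(20, channels, width, width)
eachsize = int(math.floor(oriproto_size / total_classes))  # 60 / 6 -> 10 images per class
for c in oriprotos:
    oriprotos[c] = oriprotos[c][:eachsize]
print({c: tuple(v.shape) for c, v in oriprotos.items()})   # every entry is now (10, 1, 28, 28)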
Example #5
def train(model,
          train_datasets,
          test_datasets,
          task_output_space,
          args,
          cuda=False):

    #optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) #for current task
    #optimizer2 = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) #for previous task

    optimizer = optim.SGD(model.parameters(), lr=args.lr)  #for current task
    optimizer2 = optim.SGD(model.parameters(), lr=args.lr)  #for previous task

    model.train()

    precision_record = defaultdict(list)

    protos_mu = {}
    protos_logvar = {}
    oriprotos = {}

    lossinfor2 = {'loss': float('NaN'), 'acc': float('NaN')}

    # Decide split ordering
    task_names = sorted(list(task_output_space.keys()), key=int)
    print('Task order:', task_names)
    test_dataset_all = []
    train_dataset_all = []
    numclass_all = []

    for task, train_dataset in enumerate(train_datasets, 1):
        train_name = task_names[task - 1]
        train_dataset = train_datasets[train_name]
        test_dataset = test_datasets[train_name]
        train_dataset_all.append(train_dataset)
        test_dataset_all.append(test_dataset)

    for task, train_dataset in enumerate(train_datasets, 1):
        best_precision = 0  #the best precision for the current task

        if task == 1:
            args.dataset_classes = args.first_split_size
            numclass_all.append(args.dataset_classes)
            args.batch_size = args.first_split_size * args.dataset_samples
            args.dataset_current_classes = args.first_split_size
            args.dataset_episodes = int(
                math.floor(args.dataset_total / args.dataset_samples))
        else:
            args.dataset_classes = args.dataset_classes + args.other_split_size
            numclass_all.append(args.dataset_classes)
            args.dataset_samples = args.oriproto_eachsize * 2
            args.dataset_episodes = int(
                math.floor(args.dataset_total / args.dataset_samples))
            args.batch_size = args.other_split_size * args.dataset_samples
            args.dataset_current_classes = args.other_split_size

        #print('nsample: ' + str(args.dataset_samples) + '; nepisodes: ' + str(args.dataset_episodes))
        train_name = task_names[task - 1]
        train_dataset = train_datasets[train_name]

        if task > 1:
            #args.dataset_episodes = args.dataset_nextepisodes
            args.epochs_per_task = args.epochs_per_tasknext
            args.epochs_per_tasknext = int(args.epochs_per_tasknext / 2)
            if args.epochs_per_tasknext < 10:
                args.epochs_per_tasknext = 10

        for epoch in range(1, args.epochs_per_task + 1):
            iternum = 0
            sampler = BalancedDatasetSampler(train_dataset,
                                             args.dataset_samples,
                                             args.dataset_episodes)

            if cuda:
                loader = torch.utils.data.DataLoader(
                    train_dataset,
                    shuffle=False,
                    sampler=sampler,
                    batch_size=args.batch_size,
                    num_workers=0,
                    pin_memory=True)
            else:
                loader = torch.utils.data.DataLoader(
                    train_dataset,
                    shuffle=False,
                    sampler=sampler,
                    batch_size=args.batch_size,
                    num_workers=0)

            data_stream = tqdm(enumerate(loader, 1))
            for batch_index, (imgs, labels, dummy) in data_stream:

                # prepare the data.
                imgs = imgs.view(args.dataset_current_classes,
                                 args.dataset_samples, args.dataset_channels,
                                 args.dataset_width, args.dataset_width)
                imgs = imgs[:,
                            np.random.permutation(
                                np.arange(args.dataset_samples)), :, :, :]
                #print(imgs.size())
                if cuda: imgs = imgs.cuda()

                # combine with previous images
                oldimgs = torch.Tensor()
                for t in range(args.dataset_classes -
                               args.dataset_current_classes):

                    #take all stored ori protos for class t (the random single-image selection below is disabled)
                    #rn = random.randint(0,args.oriproto_eachsize-1)
                    #size: (num_class = 1, num_sample = oriproto_eachsize, num_channel, num_width, num_width)
                    samplesingleimgs = oriprotos[t][:, :].view(
                        1, args.oriproto_eachsize, args.dataset_channels,
                        args.dataset_width, args.dataset_width)
                    if cuda: samplesingleimgs = samplesingleimgs.cuda()

                    if oldimgs.nelement() == 0:
                        oldimgs = samplesingleimgs
                    else:
                        oldimgs = torch.cat((oldimgs, samplesingleimgs), 0)
                #print('oldimgs')
                #print(oldimgs.size())
                # combine with previous protos
                oldprotos_mu = torch.Tensor()
                oldprotos_logvar = torch.Tensor()
                for t in range(args.dataset_classes -
                               args.dataset_current_classes):
                    #always from the prev task
                    samplesprotos_mu = protos_mu[task - 2][t, :]
                    samplesprotos_mu = samplesprotos_mu.view(1, args.z_dim)
                    samplesprotos_logvar = protos_logvar[task - 2][t, :]
                    samplesprotos_logvar = samplesprotos_logvar.view(
                        1, args.z_dim)
                    if cuda: samplesprotos_mu = samplesprotos_mu.cuda()
                    if cuda: samplesprotos_logvar = samplesprotos_logvar.cuda()

                    if oldprotos_mu.nelement() == 0:
                        oldprotos_mu = samplesprotos_mu
                        oldprotos_logvar = samplesprotos_logvar
                    else:
                        oldprotos_mu = torch.cat(
                            (oldprotos_mu, samplesprotos_mu), 0)
                        oldprotos_logvar = torch.cat(
                            (oldprotos_logvar, samplesprotos_logvar), 0)

                if cuda: oldprotos_mu = Variable(oldprotos_mu).cuda()
                if cuda: oldprotos_logvar = Variable(oldprotos_logvar).cuda()
                #print('oldprotos_mu')
                #print(oldprotos_mu.size())
                #print('oldprotos_logvar')
                #print(oldprotos_logvar.size())
                #imgs = torch.cat((oldimgs,imgs),0)
                #imgs = Variable(imgs).cuda() if cuda else Variable(imgs)

                # run the model and backpropagate the errors based on current task + previous protos
                optimizer.zero_grad()
                #print('model.loss_initial')
                loss, lossinfor = model.loss_initial(
                    imgs, oldimgs, oldprotos_mu, oldprotos_logvar,
                    args.dataset_samples, args.dataset_current_classes,
                    args.dataset_classes, args.dataset_nsupport,
                    args.dataset_nquery, args.dataset_channels,
                    args.dataset_width, args.temperature)  #,
                #protos, cuda)
                loss.backward()
                optimizer.step()

                if iternum % args.replay_freq == 0:
                    # run the model and backpropagate the errors based on previous protos
                    for T in range(task - 1):
                        #if task > 1:

                        # combine with previous images
                        oldimgs = oldimgs.view(
                            args.dataset_classes -
                            args.dataset_current_classes,
                            args.oriproto_eachsize, args.dataset_channels,
                            args.dataset_width, args.dataset_width)
                        #print('oldimgs')
                        #print(oldimgs.size())
                        # combine with previous protos

                        #range(task-1) iterates in ascending order; range(task-2,-1,-1) in descending order; range(task-2,task-3,-1) covers the previous task only
                        for t in range(task - 1):

                            optimizer2.zero_grad()
                            #always from the prev task
                            oldprotos_mu = protos_mu[t][:, :]
                            #samplesprotos_mu = samplesprotos_mu.view(1,args.z_dim)
                            oldprotos_logvar = protos_logvar[t][:, :]
                            #always take only one example out of each class and do backprop
                            rn = random.randint(0, args.oriproto_eachsize - 1)
                            subsetoldimgs = oldimgs[:numclass_all[t],
                                                    rn, :].view(
                                                        numclass_all[t], 1,
                                                        args.dataset_channels,
                                                        args.dataset_width,
                                                        args.dataset_width)
                            #print('train recall func - oldprotos_mu')
                            #print(oldprotos_mu.size())
                            #print('train recall func - oldprotos_logvar')
                            #print(oldprotos_logvar.size())
                            #print('train recall func - subsetoldimgs')
                            #print(subsetoldimgs.size())
                            if cuda:
                                subsetoldimgs = Variable(subsetoldimgs).cuda()
                            if cuda:
                                oldprotos_mu = Variable(oldprotos_mu).cuda()
                            if cuda:
                                oldprotos_logvar = Variable(
                                    oldprotos_logvar).cuda()

                            loss2, lossinfor2 = model.loss_proto(
                                subsetoldimgs, oldprotos_mu, oldprotos_logvar,
                                args.dataset_nsupport, args.dataset_nquery,
                                numclass_all[t], args.dataset_channels,
                                args.dataset_width, cuda, args.temperature)
                            loss2.backward()
                            optimizer2.step()

                iternum = iternum + 1
                #print("loss: "+"{}".format(lossinfor['loss'])+"; accuracy: "+ "{}".format(lossinfor['acc']))
                #print(imgs.size());x1 = imgs[1][1][0].view(28,28,1); plt.imshow(x1[:,:,0]); break

                data_stream.set_description(
                    ('task: {task}/{tasks} | '
                     'epoch: {epoch}/{epochs} | '
                     'progress: [{trained}/{total}] ({progress:.0f}%) | '
                     'prec: {prec:.4} | '
                     'loss: {loss:.4} | '
                     'prev_prec: {prevprec:.4} | '
                     'prev_loss: {prevloss:.4} | ').format(
                         task=task,
                         tasks=len(train_datasets),
                         epoch=epoch,
                         epochs=args.epochs_per_task,
                         trained=batch_index * args.batch_size /
                         args.dataset_current_classes,
                         total=args.dataset_episodes * args.dataset_samples,
                         progress=(
                             100. * batch_index * args.batch_size /
                             (args.dataset_current_classes *
                              args.dataset_episodes * args.dataset_samples)),
                         prec=lossinfor['acc'],
                         loss=lossinfor['loss'],
                         prevprec=lossinfor2['acc'],
                         prevloss=lossinfor2['loss'],
                     ))

            #end of epoch; test in validation set; save best model for this task
            val_dataset = torch.utils.data.ConcatDataset(
                train_dataset_all[:task])
            #print(len(val_dataset))
            current_precision = validate(model, val_dataset, args, cuda)
            #print('validation and current_precision: ')
            #print(current_precision)
            #print(best_precision)
            #if current_precision > best_precision:
            #print('saving current model')
            best_precision = current_precision
            #print(best_precision)
            save_checkpoint(model, args, epoch, current_precision, task)
            #else:
            #load_checkpoint(model, args, task, cuda)

        #end of task;  evaluate on current task and previous tasks

        load_checkpoint(model, args, task, cuda)
        #current_proto, current_oriproto = storeOriProto_Proto(model, train_dataset, args, cuda)

        args.oriproto_eachsize = int(
            math.floor(
                args.oriproto_size /
                args.dataset_classes))  #floor the number of oriprotos per task
        current_oriproto = storeOriProto(model, train_dataset, args, cuda)
        #print('current_oriproto')
        #print(current_oriproto.size())
        for t in range(args.dataset_current_classes):
            oriprotos[args.dataset_classes - args.dataset_current_classes +
                      t] = current_oriproto[t]

        current_proto_mu, current_proto_logvar = storeProto(
            model, oriprotos, args, cuda)
        current_proto_mu = current_proto_mu.view(args.dataset_classes,
                                                 args.z_dim)
        current_proto_logvar = current_proto_logvar.view(
            args.dataset_classes, args.z_dim)
        protos_mu[task - 1] = current_proto_mu
        protos_logvar[task - 1] = current_proto_logvar
        #print('train func - protos_mu')
        #print(protos_mu[task-1])
        savefilename = args.result_dir + 'protos_' + str(task) + '.mat'
        scipy.io.savemat(
            savefilename, {
                'protosmu': current_proto_mu.detach().cpu().numpy(),
                'protos_logvar': current_proto_logvar.detach().cpu().numpy()
            })
        #print('train func - protos_logvar')
        #print(protos_logvar[task-1])

        #shrink the storage to make it constant
        for i in range(args.dataset_classes):
            #reduce the storage of previous protos
            oriprotos[i] = oriprotos[i][:args.oriproto_eachsize, :]
            #print('train func - oriprotos[i]:')
            #print(oriprotos[i].size())

        #validate performance in all previous tasks
        test_dataset_con = torch.utils.data.ConcatDataset(
            test_dataset_all[:task])
        #print(len(val_dataset))
        prec = test(model, test_dataset_con, protos_mu[task - 1],
                    protos_logvar[task - 1], args, cuda)
        precision_record[task - 1].append(prec)
        #print(precision_record)
        print('Testing results: (task = ' + str(task) + ')')
        #print(precision_record[task-1])
        print(precision_record)
        #save precision_record as tensors for plotting later
        path = os.path.join(
            args.result_dir,
            '{firstname}_{secondname}-precision_record.pt'.format(
                firstname=args.model_name, secondname=args.n_repeats))
        torch.save(precision_record, path)
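# --- Illustrative sketch (not from the original code): gathering stored prototype statistics ---
# In the loop above, protos_mu / protos_logvar hold one (classes_seen, z_dim) tensor per
# finished task; the per-class rows are stacked along dim 0 before being passed to
# loss_initial / loss_proto. The z_dim and class counts below are hypothetical.
import torch

z_dim, old_classes = 16, 5
protos_mu = {0: torch.randn(old_classes, z_dim)}      # statistics saved after task 1
protos_logvar = {0: torch.randn(old_classes, z_dim)}

oldprotos_mu = torch.Tensor()
oldprotos_logvar = torch.Tensor()
for t in range(old_classes):
    mu = protos_mu[0][t, :].view(1, z_dim)
    logvar = protos_logvar[0][t, :].view(1, z_dim)
    if oldprotos_mu.nelement() == 0:
        oldprotos_mu, oldprotos_logvar = mu, logvar
    else:
        oldprotos_mu = torch.cat((oldprotos_mu, mu), 0)
        oldprotos_logvar = torch.cat((oldprotos_logvar, logvar), 0)
print(oldprotos_mu.shape, oldprotos_logvar.shape)     # torch.Size([5, 16]) twice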
def storeOriProto_Proto(model, train_dataset, args, cuda=False):

    model.train(mode=False)  #reset to test mode

    sampler = BalancedDatasetSampler(train_dataset, args.dataset_samples,
                                     args.dataset_classes,
                                     args.dataset_episodes)

    if cuda:
        loader = torch.utils.data.DataLoader(train_dataset,
                                             shuffle=False,
                                             sampler=sampler,
                                             batch_size=args.batch_size,
                                             num_workers=0,
                                             pin_memory=True)
    else:
        loader = torch.utils.data.DataLoader(train_dataset,
                                             shuffle=False,
                                             sampler=sampler,
                                             batch_size=args.batch_size,
                                             num_workers=0)

    total_proto = torch.Tensor()
    total_tested = 0
       
    with torch.no_grad():
        data_stream = tqdm(enumerate(loader, 1))
        for batch_index, (imgs, labels) in data_stream:
            
            # prepare the data.
            imgs = imgs.view(args.dataset_classes,
                             args.dataset_nquery + args.dataset_nsupport,
                             args.dataset_channels, args.dataset_width,
                             args.dataset_width)
            imgs = Variable(imgs).cuda() if cuda else Variable(imgs)

            current = model.getHiddenReps(imgs, args.dataset_nsupport,
                                          args.dataset_nquery,
                                          args.dataset_classes,
                                          args.dataset_channels,
                                          args.dataset_width)
            
            # update statistics.
            if total_proto.nelement() == 0:
                total_proto = current.detach().cpu()
            else:
                total_proto = total_proto + current.detach().cpu()   
                
            total_tested = total_tested + 1                
            
    model.train(mode=True) #reset to training mode
    total_proto = total_proto/total_tested
    #print(total_proto)
    
    ################# find best ori_proto #####################
    model.train(mode=False)  #reset to test mode

    sampler = BalancedDatasetSampler(train_dataset, args.dataset_samples,
                                     args.dataset_classes,
                                     args.dataset_episodes)

    if cuda:
        loader = torch.utils.data.DataLoader(train_dataset,
                                             shuffle=False,
                                             sampler=sampler,
                                             batch_size=args.batch_size,
                                             num_workers=0,
                                             pin_memory=True)
    else:
        loader = torch.utils.data.DataLoader(train_dataset,
                                             shuffle=False,
                                             sampler=sampler,
                                             batch_size=args.batch_size,
                                             num_workers=0)

    total_ratio = 0
    with torch.no_grad():
        data_stream = tqdm(enumerate(loader, 1))
        for batch_index, (imgs, labels) in data_stream:
            
            # prepare the data.
            imgs = imgs.view(args.dataset_classes,
                             args.dataset_nquery + args.dataset_nsupport,
                             args.dataset_channels, args.dataset_width,
                             args.dataset_width)
            imgs = Variable(imgs).cuda() if cuda else Variable(imgs)

            current_ratio = model.calculateDist_protos(
                imgs, args.dataset_nsupport, args.dataset_nquery,
                args.dataset_classes, args.dataset_channels,
                args.dataset_width, total_proto, args.temperature)

            # keep a CPU copy of the candidate images in either case
            # (previously `current` was only assigned when cuda was True)
            current = imgs.cpu().clone() if cuda else imgs.clone()

            if total_ratio < current_ratio:
                total_ratio = current_ratio
                best_oriproto = current.clone()
               
        #print('current')
        #print(current.size())
    model.train(mode=True) #reset to training mode
    
    return total_proto, best_oriproto
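# --- Illustrative sketch (not from the original code): keeping the best-scoring episode ---
# storeOriProto_Proto scores each candidate batch with model.calculateDist_protos and keeps
# the images of the highest-scoring one. The scoring below is a random stand-in; only the
# bookkeeping pattern is shown, with hypothetical episode shapes.
import torch

best_ratio = -1.0                                   # guarantees the first candidate is kept
best_oriproto = None
for _ in range(4):                                  # four candidate episodes
    imgs = torch.randn(5, 6, 1, 28, 28)             # (classes, nsupport + nquery, C, W, W)
    ratio = torch.rand(1).item()                    # stand-in for the model's ratio score
    if ratio > best_ratio:
        best_ratio = ratio
        best_oriproto = imgs.cpu().clone()
print(best_ratio, best_oriproto.shape)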