def test_trained_model(args, final_model_id):
    """Evaluate each task's checkpoint with the shared module of a final model.

    For every task ``t`` in ``args.taskcla``, this loads
    ``model_{t}.pth.tar`` from ``args.checkpoint``. It then overwrites that
    network's ``shared`` sub-module with the one stored in
    ``model_{final_model_id}.pth.tar``. Finally it runs ``appr.inference`` on
    task ``t``'s test split and prints the loss and accuracy.

    Args:
        args: Experiment namespace. Mutated in place: ``seed`` is set to 0,
            and ``taskcla`` / ``inputsize`` (plus ``lrs`` for the
            ``multidatasets`` experiment) are copied from the data generator.
        final_model_id: Id of the checkpoint whose shared module is reused
            for every task.
    """
    # NOTE(review): the seed is forced to 0 before building the data
    # generator, presumably so DatasetGen reproduces the task split used in
    # training — confirm.
    args.seed = 0

    print('Instantiate data generators and model...')
    dataloader = datagenerator.DatasetGen(args)
    args.taskcla, args.inputsize = dataloader.taskcla, dataloader.inputsize
    if args.experiment == 'multidatasets':
        args.lrs = dataloader.lrs

    def load_net(model_id):
        """Return a new ``network.Net`` loaded from ``model_{model_id}.pth.tar``."""
        net = network.Net(args)
        checkpoint = torch.load(
            os.path.join(args.checkpoint, 'model_{}.pth.tar'.format(model_id)),
            # Load onto the CPU so checkpoints written on a GPU also load on
            # CPU-only hosts; the network is moved to args.device afterwards.
            map_location='cpu')
        net.load_state_dict(checkpoint['model_state_dict'])
        return net

    # The final model's shared module is the same for every task, so read the
    # final checkpoint once instead of once per task.
    final_shared = deepcopy(load_net(final_model_id).shared.state_dict())

    def get_model(test_data_id):
        """Return task ``test_data_id``'s network with the final shared module, on args.device."""
        test_net = load_net(test_data_id)
        test_net.shared.load_state_dict(final_shared)
        return test_net.to(args.device)

    for t, _ in args.taskcla:
        print('*' * 250)
        dataset = dataloader.get(t)
        print(' ' * 105, 'Dataset {:2d} ({:s})'.format(t + 1,
                                                       dataset[t]['name']))
        print('*' * 250)

        # Model: task-specific checkpoint + final shared module
        test_model = get_model(test_data_id=t)

        # Approach
        appr = approach(test_model, args, network=network)

        # Test
        test_res = appr.inference(dataset[t]['test'], t, model=test_model)

        print('>>> Test on task {:2d} - {:15s}: loss={:.3f}, acc={:5.4f}% <<<'.
              format(t, dataset[t]['name'], test_res['loss_t'],
                     test_res['acc_t']))
# Example #2
def run(args, run_id):
    """Train on the tasks in ``args.taskcla`` sequentially and evaluate.

    After training on task ``t``, every task ``u <= t`` is evaluated with the
    model returned by ``appr.load_model(u)`` on task ``u``'s test split. The
    results fill row ``t`` of the accuracy and loss matrices. The accuracy
    matrix is written to ``os.path.join(args.checkpoint, args.output)`` after
    every task.

    Args:
        args: Experiment namespace. ``taskcla`` and ``inputsize`` (plus
            ``lrs`` for the ``multidatasets`` experiment) are set on it from
            the data generator.
        run_id: Identifier forwarded to ``utils.print_log_acc_bwt``.

    Returns:
        ``(avg_acc, gem_bwt)`` as produced by ``utils.print_log_acc_bwt``.
    """
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # Deterministic cuDNN kernels trade speed for reproducible results
        # (matching the paper). Setting cudnn.benchmark = True instead gives
        # a faster but non-deterministic run.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Data loader
    print('Instantiate data generators and model...')
    dataloader = datagenerator.DatasetGen(args)
    args.taskcla = dataloader.taskcla
    args.inputsize = dataloader.inputsize
    if args.experiment == 'multidatasets':
        args.lrs = dataloader.lrs

    # Model
    model = network.Net(args).to(args.device)
    model.print_model_size()

    # Approach
    appr = approach(model, args, network=network)

    # acc[t, u] / lss[t, u]: result on task u after training on task t
    n_tasks = len(args.taskcla)
    acc = np.zeros((n_tasks, n_tasks), dtype=np.float32)
    lss = np.zeros((n_tasks, n_tasks), dtype=np.float32)

    banner = '*' * 250
    for task_id, _ in args.taskcla:
        print(banner)
        dataset = dataloader.get(task_id)
        print(' ' * 105, 'Dataset {:2d} ({:s})'.format(task_id + 1, dataset[task_id]['name']))
        print(banner)

        appr.train(task_id, dataset[task_id])
        print('-' * 250)
        print()

        # Re-evaluate every task seen so far. Per the original comment,
        # load_model(u) gives task u's model with the current shared module.
        for seen_id in range(task_id + 1):
            seen_model = appr.load_model(seen_id)
            res = appr.test(dataset[seen_id]['test'], seen_id, model=seen_model)
            print('>>> Test on task {:2d} - {:15s}: loss={:.3f}, acc={:5.1f}% <<<'.format(
                seen_id, dataset[seen_id]['name'], res['loss_t'], res['acc_t']))
            acc[task_id, seen_id] = res['acc_t']
            lss[task_id, seen_id] = res['loss_t']

        # Persist the accuracy matrix after every task
        acc_path = os.path.join(args.checkpoint, args.output)
        print()
        print('Saved accuracies at ' + acc_path)
        np.savetxt(acc_path, acc, '%.6f')

    # t-SNE embeddings for tensorboard (miniimagenet only); uses the dataset
    # and task id left over from the last loop iteration.
    wants_tsne = args.tsne == 'yes' and args.experiment == 'miniimagenet'
    if wants_tsne:
        appr.get_tsne_embeddings_first_ten_tasks(dataset, model=appr.load_model(task_id))
        appr.get_tsne_embeddings_last_three_tasks(dataset, model=appr.load_model(task_id))

    avg_acc, gem_bwt = utils.print_log_acc_bwt(
        args.taskcla, acc, lss, output_path=args.checkpoint, run_id=run_id)
    return avg_acc, gem_bwt