예제 #1
0
def evaluate_anil(params, features, head, test_tasks, device, features_path,
                  head_path):
    """Load saved ANIL weights from disk and evaluate them on ``test_tasks``.

    Args:
        params: Mapping of run settings; ``params['inner_lr']`` is read here
            and the whole mapping is forwarded to ``evaluate``.
        features: Feature-extractor module whose ``state_dict`` is restored
            from ``features_path``. It is modified in place.
        head: Classifier head whose ``state_dict`` is restored from
            ``head_path``. It is then wrapped in ``MAML`` before evaluation.
        test_tasks: Task collection forwarded unchanged to ``evaluate``.
        device: Device the modules are moved to; also used as the
            ``map_location`` when reading the checkpoints.
        features_path: Path of the saved feature-extractor state dict.
        head_path: Path of the saved head state dict.

    Returns:
        Whatever ``evaluate`` returns for the restored model.
    """
    # map_location remaps the stored tensors onto `device` while loading, so a
    # checkpoint written on a GPU can still be restored on a host where that
    # GPU is not available (e.g. CPU-only evaluation).
    features.load_state_dict(torch.load(features_path, map_location=device))
    features.to(device)

    head.load_state_dict(torch.load(head_path, map_location=device))
    # The rebinding is local: the caller's `head` object is not replaced.
    head = MAML(head, lr=params['inner_lr'])
    head.to(device)

    loss = torch.nn.CrossEntropyLoss(reduction='mean')

    return evaluate(params, test_tasks, head, loss, device, features=features)
예제 #2
0
def _load_anil_weights(features, head, features_path, head_path, device):
    """Restore both state dicts from disk and move the modules to ``device``."""
    # map_location lets checkpoints saved on another device (e.g. CUDA) be
    # restored on the device actually used for this run.
    features.load_state_dict(torch.load(features_path, map_location=device))
    features.to(device)

    head.load_state_dict(torch.load(head_path, map_location=device))
    head.to(device)


def run_anil(params, test_tasks, device):
    """Build the ANIL model and run the enabled evaluations / experiments.

    The feature extractor is a ``ConvBase`` followed by a flattening
    ``Lambda``; the head is a ``torch.nn.Linear`` wrapped in ``MAML``.
    Depending on the flags described below, the function evaluates every
    saved checkpoint, evaluates the final model, and runs the continual
    learning and representation experiments.

    Args:
        params: Mapping of run settings. The keys ``'dataset'``, ``'ways'``,
            ``'inner_lr'`` and ``'shots'`` are read here.
        test_tasks: Task collection forwarded to the evaluation helpers.
        device: Device the modules are moved to and checkpoints are mapped to.

    Note:
        ``eval_iters``, ``base_path``, ``meta_test``, ``cl_exp``, ``rep_exp``,
        ``cl_params`` and ``rep_params`` are not parameters or locals of this
        function; they are resolved from an enclosing scope (presumably
        module-level configuration -- confirm against the rest of the file).
        ``rep_params['layers']`` is overwritten in place when ``rep_exp`` is
        truthy.
    """
    # ANIL
    if 'omni' == params['dataset']:
        print('Loading Omniglot model')
        fc_neurons = 128
        features = ConvBase(output_size=64,
                            hidden=32,
                            channels=1,
                            max_pool=False)
    else:
        print('Loading Mini-ImageNet model')
        fc_neurons = 1600
        features = ConvBase(output_size=64, channels=3, max_pool=True)
    features = torch.nn.Sequential(features,
                                   Lambda(lambda x: x.view(-1, fc_neurons)))
    head = torch.nn.Linear(fc_neurons, params['ways'])
    # NOTE(review): evaluate_anil wraps the head it receives in MAML again, so
    # the evaluation paths below operate on a doubly wrapped head -- confirm
    # this matches how the checkpoints were saved.
    head = MAML(head, lr=params['inner_lr'])

    # Evaluate the model at every checkpoint
    if eval_iters:
        ckpnt = base_path + "/model_checkpoints/"
        model_ckpnt_results = {}
        # The context manager closes the directory handle even if an
        # evaluation raises part-way through the listing.
        with os.scandir(ckpnt) as entries:
            for model_ckpnt in entries:
                # Match on the file name only: testing or substituting on the
                # full path would misfire whenever a directory component
                # (e.g. inside base_path) happens to contain "features".
                file_name = model_ckpnt.name
                if not file_name.endswith(".pt"):
                    continue
                if "features" not in file_name:
                    continue

                features_path = model_ckpnt.path
                head_path = os.path.join(
                    ckpnt, file_name.replace("features", "head"))

                print(f'Testing {model_ckpnt.path}')
                res = evaluate_anil(params, features, head, test_tasks,
                                    device, features_path, head_path)
                model_ckpnt_results[model_ckpnt.path] = res

        with open(base_path + '/ckpnt_results.json', 'w') as fp:
            json.dump(model_ckpnt_results, fp, sort_keys=True, indent=4)

    final_features = base_path + '/features.pt'
    final_head = base_path + '/head.pt'

    if meta_test:
        evaluate_anil(params, features, head, test_tasks, device,
                      final_features, final_head)

    if cl_exp:
        print("Running Continual Learning experiment...")
        _load_anil_weights(features, head, final_features, final_head, device)

        loss = torch.nn.CrossEntropyLoss(reduction='mean')

        run_cl_exp(base_path,
                   head,
                   loss,
                   test_tasks,
                   device,
                   params['ways'],
                   params['shots'],
                   cl_params=cl_params,
                   features=features)

    if rep_exp:
        # Reload from disk so this experiment starts from the saved final
        # weights regardless of what ran before it.
        _load_anil_weights(features, head, final_features, final_head, device)

        loss = torch.nn.CrossEntropyLoss(reduction='mean')

        # Only check head change
        rep_params['layers'] = [-1]

        print("Running Representation experiment...")
        run_rep_exp(base_path,
                    head,
                    loss,
                    test_tasks,
                    device,
                    params['ways'],
                    params['shots'],
                    rep_params=rep_params,
                    features=features)