Beispiel #1
0
def pa(context_features,
       context_labels,
       max_iter=40,
       ad_opt='linear',
       lr=0.1,
       distance='cos'):
    """
    PA method: learn a task-specific linear transformation on the support set
    during meta-testing, adapting the features to a discriminative space.

    Args:
        context_features: support-set features; ``size(1)`` is taken as the
            feature dimension (presumably a 2-D ``(n_support, dim)`` tensor —
            confirm against callers).
        context_labels: support-set labels, passed straight to
            ``prototype_loss``.
        max_iter: number of Adadelta optimization steps.
        ad_opt: adaptation type; only ``'linear'`` is supported.
        lr: Adadelta learning rate.
        distance: distance name forwarded to ``prototype_loss``.

    Returns:
        A list holding the learned transformation tensor, in the form
        ``apply_selection`` consumes.

    Raises:
        ValueError: if ``ad_opt`` is not a supported adaptation type.
    """
    if ad_opt != 'linear':
        # Without this check the parameter list stays empty and the optimizer
        # fails with an opaque "empty parameter list" error.
        raise ValueError(f"unsupported ad_opt {ad_opt!r}; expected 'linear'")

    input_dim = context_features.size(1)
    output_dim = input_dim
    # Identity initialisation shaped (out, in, 1, 1); presumably used by
    # apply_selection as a 1x1 convolution weight — confirm.
    vartheta = [
        torch.eye(output_dim, input_dim).unsqueeze(-1).unsqueeze(-1).to(
            device).requires_grad_(True)
    ]

    optimizer = torch.optim.Adadelta(vartheta, lr=lr)
    for _ in range(max_iter):
        optimizer.zero_grad()
        selected_features = apply_selection(context_features, vartheta)
        # The support set serves as both prototypes and queries.
        loss, _, _ = prototype_loss(selected_features,
                                    context_labels,
                                    selected_features,
                                    context_labels,
                                    distance=distance)

        loss.backward()
        optimizer.step()
    return vartheta
Beispiel #2
0
def main():
    """Evaluate SUR feature selection on every Meta-Dataset test domain.

    For each domain, ``n_episodes`` test tasks are sampled. Support and query
    images are embedded by every training-domain extractor, ``sur`` learns
    the selection weights on the support set, and nearest-prototype accuracy
    on the query set is recorded. A table of mean accuracy +- 1.96 standard
    errors is printed at the end.
    """
    n_episodes = 5

    session_config = tf.compat.v1.ConfigProto()
    session_config.gpu_options.allow_growth = True

    # Domains whose extractors are combined, and domains evaluated on.
    train_domains = TRAIN_METADATASET_NAMES
    eval_domains = ALL_METADATASET_NAMES
    reader = MetaDatasetEpisodeReader('test',
                                      train_set=train_domains,
                                      validation_set=train_domains,
                                      test_set=eval_domains)

    backbone_models = DATASET_MODELS_DICT[args['model.backbone']]
    embed_many = get_domain_extractors(train_domains, backbone_models, args)

    method_names = ['SUR']
    accuracies = {}
    for domain in eval_domains:
        print(domain)
        accuracies[domain] = {method: [] for method in method_names}

        with tf.compat.v1.Session(config=session_config) as session:
            for _ in tqdm(range(n_episodes)):
                task = reader.get_test_task(session, domain)
                support_feats = embed_many(task['context_images'])
                query_feats = embed_many(task['target_images'])
                support_labels = task['context_labels'].to(device)
                query_labels = task['target_labels'].to(device)

                # Learn per-extractor weights on the support set, then apply
                # them to both support and query features before scoring.
                weights = sur(support_feats, support_labels, max_iter=40)
                _, stats, _ = prototype_loss(
                    apply_selection(support_feats, weights), support_labels,
                    apply_selection(query_feats, weights), query_labels)
                accuracies[domain]['SUR'].append(stats['acc'])

    # One row per domain: "mean +- half-width" in percent.
    rows = []
    for domain in eval_domains:
        cells = [domain]
        for method in method_names:
            scores = np.array(accuracies[domain][method]) * 100
            half_width = 1.96 * scores.std() / np.sqrt(len(scores))
            cells.append(f"{scores.mean():0.2f} +- {half_width:0.2f}")
        rows.append(cells)

    print(tabulate(rows, headers=['model \\ data'] + method_names,
                   floatfmt=".2f"))
    print("\n")
Beispiel #3
0
def main():
    """Evaluate a restored backbone with a nearest-centroid (NCC) classifier.

    Restores the 'best' checkpoint. Then, for each domain in
    ``args['data.test']``, it samples ``n_tasks`` test episodes, embeds
    support and query images without gradients, and records the accuracy
    reported by ``prototype_loss``. Finally it prints a table of mean
    accuracy +- 1.96 standard errors per domain.
    """
    n_tasks = 600

    trainsets = args['data.train']
    valsets = args['data.val']
    testsets = args['data.test']
    test_loader = MetaDatasetEpisodeReader('test', trainsets, valsets,
                                           testsets)
    model = get_model(None, args)
    checkpointer = CheckPointer(args, model, optimizer=None)
    checkpointer.restore_model(ckpt='best', strict=False)
    model.eval()

    accs_names = ['NCC']
    per_domain = {}

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as session:
        for domain in testsets:
            print(domain)
            per_domain[domain] = {name: [] for name in accs_names}

            for _ in tqdm(range(n_tasks)):
                with torch.no_grad():
                    episode = test_loader.get_test_task(session, domain)
                    support = model.embed(episode['context_images'])
                    query = model.embed(episode['target_images'])
                    _, stats, _ = prototype_loss(support,
                                                 episode['context_labels'],
                                                 query,
                                                 episode['target_labels'])
                per_domain[domain]['NCC'].append(stats['acc'])

    # Summarise each domain as "mean +- half-width" in percent.
    rows = []
    for domain in testsets:
        cells = [domain]
        for name in accs_names:
            scores = np.array(per_domain[domain][name]) * 100
            half_width = 1.96 * scores.std() / np.sqrt(len(scores))
            cells.append(f"{scores.mean():0.2f} +- {half_width:0.2f}")
        rows.append(cells)

    print(tabulate(rows, headers=['model \\ data'] + accs_names,
                   floatfmt=".2f"))
    print("\n")
Beispiel #4
0
def tsa(context_images,
        context_labels,
        model,
        max_iter=40,
        lr=0.1,
        lr_beta=1,
        distance='cos'):
    """
    Optimize task-specific parameters attached to the ResNet backbone,
    e.g. adapters (alpha) and/or the pre-classifier alignment mapping (beta).

    The groups to optimize are read from the module-level
    ``args['test.tsa_opt']``:

    * If it contains ``'alpha'``, every parameter whose name contains
      ``'alpha'`` is tuned at learning rate ``lr``.
    * If it contains ``'beta'``, every parameter whose name contains
      ``'beta'`` is tuned at learning rate ``lr_beta``.

    The model's parameters are updated in place.

    Args:
        context_images: support-set images, fed to ``model.embed``.
        context_labels: support-set labels.
        model: backbone exposing ``embed``, ``beta`` and named parameters.
        max_iter: number of Adadelta steps.
        lr: learning rate of the alpha group (the optimizer default).
        lr_beta: learning rate of the beta group.
        distance: distance name forwarded to ``prototype_loss``.

    Returns:
        None; the adapted parameters live on ``model``.

    Raises:
        ValueError: if ``test.tsa_opt`` selects neither alpha nor beta.
    """
    model.eval()
    tsa_opt = args['test.tsa_opt']
    tune_alpha = 'alpha' in tsa_opt
    tune_beta = 'beta' in tsa_opt
    if not (tune_alpha or tune_beta):
        # Otherwise the optimizer gets no parameters and fails opaquely.
        raise ValueError(
            f"test.tsa_opt must contain 'alpha' and/or 'beta', got {tsa_opt!r}")

    alpha_params = [v for k, v in model.named_parameters() if 'alpha' in k]
    beta_params = [v for k, v in model.named_parameters() if 'beta' in k]
    params = []
    if tune_alpha:
        params.append({'params': alpha_params})
    if tune_beta:
        params.append({'params': beta_params, 'lr': lr_beta})

    optimizer = torch.optim.Adadelta(params, lr=lr)

    if not tune_alpha:
        # Backbone features do not change, so embed once without autograd.
        with torch.no_grad():
            context_features = model.embed(context_images)

    for _ in range(max_iter):
        optimizer.zero_grad()
        model.zero_grad()

        if tune_alpha:
            # Adapters change every step, so features must be recomputed.
            context_features = model.embed(context_images)
        if tune_beta:
            # Adapt features with the PA mapping (beta).
            aligned_features = model.beta(context_features)
        else:
            aligned_features = context_features
        # The support set serves as both prototypes and queries.
        loss, _, _ = prototype_loss(aligned_features,
                                    context_labels,
                                    aligned_features,
                                    context_labels,
                                    distance=distance)

        loss.backward()
        optimizer.step()
Beispiel #5
0
def sur(context_features_dict, context_labels, max_iter=40):
    """
    SUR method: learn selection weights (lambda), one per feature extractor.

    The weights start at zero and are tuned with Adadelta to minimise the
    prototype loss on the support set, which is used as both prototypes and
    queries. The learning rate is 3000 divided by the number of distinct
    support labels.

    Args:
        context_features_dict: mapping of extractor -> support features; its
            length sets the number of weights.
        context_labels: support-set labels (tensor).
        max_iter: number of optimization steps.

    Returns:
        Tensor of shape ``(1, 1, len(context_features_dict))`` holding the
        learned weights (with ``requires_grad`` set).
    """
    n_extractors = len(context_features_dict)
    weights = torch.zeros([1, 1, n_extractors]).to(device)
    weights.requires_grad_(True)

    num_ways = np.unique(context_labels.cpu().numpy()).size
    opt = torch.optim.Adadelta([weights], lr=(3e+3 / num_ways))

    for _ in range(max_iter):
        opt.zero_grad()
        selected = apply_selection(context_features_dict, weights)
        loss = prototype_loss(selected, context_labels,
                              selected, context_labels)[0]
        loss.backward()
        opt.step()
    return weights
Beispiel #6
0
def main():
    """Evaluate PA (per-task linear adaptation) on top of a restored backbone.

    Restores the 'best' checkpoint and draws ``n_tasks`` test episodes for
    every Meta-Dataset domain. For each episode:

    * features are extracted without gradients;
    * ``pa`` learns a linear map on the support set, with lr 0.1 for
      training domains and 1 for the others;
    * nearest-prototype accuracy on the mapped query features is recorded.

    Per-domain means are printed as each domain finishes. The final table is
    printed and also saved as a ``.npy`` file under ``<out.dir>/weights``.
    """
    n_tasks = 600

    trainsets, valsets, testsets = (args['data.train'], args['data.val'],
                                    args['data.test'])
    # Override the configured splits: evaluate on every Meta-Dataset domain
    # (remove the testsets line to evaluate on args['data.test'] instead).
    testsets = ALL_METADATASET_NAMES
    trainsets = TRAIN_METADATASET_NAMES
    # The training domains are also passed as the validation split.
    test_loader = MetaDatasetEpisodeReader('test',
                                           trainsets,
                                           trainsets,
                                           testsets,
                                           test_type=args['test.type'])
    model = get_model(None, args)
    checkpointer = CheckPointer(args, model, optimizer=None)
    checkpointer.restore_model(ckpt='best', strict=False)
    model.eval()

    accs_names = ['NCC']
    results = {}

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = False
    with tf.compat.v1.Session(config=config) as session:
        for dataset in testsets:
            # Training (seen) domains get a smaller adaptation step.
            lr = 0.1 if dataset in TRAIN_METADATASET_NAMES else 1
            print(dataset)
            results[dataset] = {name: [] for name in accs_names}

            for _ in tqdm(range(n_tasks)):
                with torch.no_grad():
                    sample = test_loader.get_test_task(session, dataset)
                    context_features = model.embed(sample['context_images'])
                    target_features = model.embed(sample['target_images'])
                    context_labels = sample['context_labels']
                    target_labels = sample['target_labels']

                # Learn the linear map on the support set, then score mapped
                # queries against mapped support prototypes.
                vartheta = pa(context_features,
                              context_labels,
                              max_iter=40,
                              lr=lr,
                              distance=args['test.distance'])
                _, stats_dict, _ = prototype_loss(
                    apply_selection(context_features, vartheta),
                    context_labels,
                    apply_selection(target_features, vartheta),
                    target_labels,
                    distance=args['test.distance'])
                results[dataset]['NCC'].append(stats_dict['acc'])

            scores = np.array(results[dataset]['NCC']) * 100
            print(f"{dataset}: test_acc {scores.mean():.2f}%")

    print('results of {}'.format(args['model.name']))
    rows = []
    for dataset_name in testsets:
        row = [dataset_name]
        for model_name in accs_names:
            acc = np.array(results[dataset_name][model_name]) * 100
            half_width = 1.96 * acc.std() / np.sqrt(len(acc))
            row.append(f"{acc.mean():0.2f} +- {half_width:0.2f}")
        rows.append(row)

    save_dir = check_dir(os.path.join(args['out.dir'], 'weights'), True)
    file_name = '{}-{}-{}-{}-test-results.npy'.format(args['model.name'],
                                                      args['test.type'], 'pa',
                                                      args['test.distance'])
    np.save(os.path.join(save_dir, file_name), {'rows': rows})

    print(tabulate(rows, headers=['model \\ data'] + accs_names,
                   floatfmt=".2f"))
    print("\n")
Beispiel #7
0
def train():
    """Train a multi-domain backbone with cross-entropy and NCC validation.

    Batches come from a single ``MetaDatasetBatchReader`` over
    ``args['data.train']``.

    Every ``args['train.eval_freq']`` iterations the model is evaluated on
    ``args['train.eval_size']`` episodes per validation domain, using a
    nearest-prototype classifier. A checkpoint is then saved, flagged as best
    when the mean validation accuracy improves.

    Training resumes from the last checkpoint when it exists and
    ``args['train.resume']`` is set.

    Raises:
        ValueError: if ``args['train.lr_policy']`` is not a recognised policy.
    """
    # initialize datasets and loaders
    trainsets, valsets, testsets = args['data.train'], args['data.val'], args[
        'data.test']
    train_loader = MetaDatasetBatchReader('train',
                                          trainsets,
                                          valsets,
                                          testsets,
                                          batch_size=args['train.batch_size'])
    val_loader = MetaDatasetEpisodeReader('val', trainsets, valsets, testsets)

    # initialize model and optimizer; the model is built for the summed class
    # count of all training domains
    num_train_classes = sum(train_loader.dataset_to_n_cats.values())
    model = get_model(num_train_classes, args)
    optimizer = get_optimizer(model, args, params=model.get_parameters())

    # Restoring the last checkpoint
    checkpointer = CheckPointer(args, model, optimizer=optimizer)
    if os.path.isfile(checkpointer.last_ckpt) and args['train.resume']:
        start_iter, best_val_loss, best_val_acc =\
            checkpointer.restore_model(ckpt='last', strict=False)
    else:
        print('No checkpoint restoration')
        best_val_loss = 999999999
        best_val_acc = start_iter = 0

    # define learning rate policy
    lr_policy = args['train.lr_policy']
    if lr_policy == "step":
        lr_manager = UniformStepLR(optimizer, args, start_iter)
    elif "exp_decay" in lr_policy:
        lr_manager = ExpDecayLR(optimizer, args, start_iter)
    elif "cosine" in lr_policy:
        lr_manager = CosineAnnealRestartLR(optimizer, args, start_iter)
    else:
        # Without this, lr_manager is unbound and the first step() crashes.
        raise ValueError(
            f"unsupported train.lr_policy {lr_policy!r}; expected 'step' or "
            "a policy containing 'exp_decay' or 'cosine'")

    # defining the summary writer
    writer = SummaryWriter(checkpointer.model_path)

    # Training loop
    max_iter = args['train.max_iter']
    epoch_loss = {name: [] for name in trainsets}
    epoch_acc = {name: [] for name in trainsets}
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.compat.v1.Session(config=config) as session:
        for i in tqdm(range(max_iter)):
            # skip iterations already covered by a restored checkpoint
            if i < start_iter:
                continue

            optimizer.zero_grad()
            sample = train_loader.get_train_batch(session)
            batch_dataset = sample['dataset_name']
            logits = model.forward(sample['images'])
            labels = sample['labels']
            batch_loss, stats_dict, _ = cross_entropy_loss(logits, labels)
            epoch_loss[batch_dataset].append(stats_dict['loss'])
            epoch_acc[batch_dataset].append(stats_dict['acc'])

            batch_loss.backward()
            optimizer.step()
            lr_manager.step(i)

            # every 200 iterations log the running per-domain training stats
            if (i + 1) % 200 == 0:
                for dataset_name in trainsets:
                    # the loss was previously logged under a "-train_acc" tag
                    writer.add_scalar(f"loss/{dataset_name}-train_loss",
                                      np.mean(epoch_loss[dataset_name]), i)
                    writer.add_scalar(f"accuracy/{dataset_name}-train_acc",
                                      np.mean(epoch_acc[dataset_name]), i)
                    epoch_loss[dataset_name], epoch_acc[dataset_name] = [], []

                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], i)

            # Evaluation inside the training loop
            if (i + 1) % args['train.eval_freq'] == 0:
                model.eval()
                dataset_accs, dataset_losses = [], []
                for valset in valsets:
                    val_losses, val_accs = [], []
                    for j in tqdm(range(args['train.eval_size'])):
                        with torch.no_grad():
                            sample = val_loader.get_validation_task(
                                session, valset)
                            context_features = model.embed(
                                sample['context_images'])
                            target_features = model.embed(
                                sample['target_images'])
                            context_labels = sample['context_labels']
                            target_labels = sample['target_labels']
                            _, stats_dict, _ = prototype_loss(
                                context_features, context_labels,
                                target_features, target_labels)
                        val_losses.append(stats_dict['loss'])
                        val_accs.append(stats_dict['acc'])

                    # write summaries per validation set
                    dataset_acc, dataset_loss = np.mean(
                        val_accs) * 100, np.mean(val_losses)
                    dataset_accs.append(dataset_acc)
                    dataset_losses.append(dataset_loss)
                    writer.add_scalar(f"loss/{valset}/val_loss", dataset_loss,
                                      i)
                    writer.add_scalar(f"accuracy/{valset}/val_acc",
                                      dataset_acc, i)
                    print(
                        f"{valset}: val_acc {dataset_acc:.2f}%, val_loss {dataset_loss:.3f}"
                    )

                # write summaries averaged over datasets
                avg_val_loss, avg_val_acc = np.mean(dataset_losses), np.mean(
                    dataset_accs)
                writer.add_scalar(f"loss/avg_val_loss", avg_val_loss, i)
                writer.add_scalar(f"accuracy/avg_val_acc", avg_val_acc, i)

                # saving checkpoints; "best" is decided by mean val accuracy
                if avg_val_acc > best_val_acc:
                    best_val_loss, best_val_acc = avg_val_loss, avg_val_acc
                    is_best = True
                    print('Best model so far!')
                else:
                    is_best = False
                checkpointer.save_checkpoint(i,
                                             best_val_acc,
                                             best_val_loss,
                                             is_best,
                                             optimizer=optimizer,
                                             state_dict=model.get_state_dict())

                model.train()
                print(f"Trained and evaluated at {i}")

    writer.close()
    if start_iter < max_iter:
        print(
            f"""Done training with best_mean_val_loss: {best_val_loss:.3f}, best_avg_val_acc: {best_val_acc:.2f}%"""
        )
    else:
        print(
            f"""No training happened. Loaded checkpoint at {start_iter}, while max_iter was {max_iter}"""
        )
Beispiel #8
0
def train():
    """Multi-domain training with distillation from single-domain extractors.

    Each iteration draws one batch per training domain, one reader per
    domain. The multi-domain (MDL) model is trained on three terms:

    * per-domain cross-entropy, weighted by ``LOSSWEIGHTS``;
    * a feature-distillation term: ``distillation_loss(opt='kernelcka')``
      between L2-normalised, adaptor-mapped MDL features and detached
      single-domain (SDL) features, scaled by ``args['train.sigma']``;
    * a logit-distillation term: ``DistillKL(T=4)`` between MDL and SDL
      logits, scaled by ``args['train.beta']``.

    The two distillation terms are weighted per domain by a linear
    ``WeightAnnealing`` schedule times ``KDFLOSSWEIGHTS`` / ``KDPLOSSWEIGHTS``.

    Validation with a nearest-prototype classifier runs every
    ``args['train.eval_freq']`` iterations. A checkpoint is then saved, with
    adaptor and running-statistics state in ``extra``.

    Raises:
        ValueError: if ``args['train.lr_policy']`` is not a recognised policy.
    """
    # initialize datasets and loaders
    trainsets, valsets, testsets = args['data.train'], args['data.val'], args[
        'data.test']

    # One batch reader per training domain; also record each domain's class
    # count and build its distillation-weight annealing schedule.
    train_loaders = []
    num_train_classes = dict()
    kd_weight_annealing = dict()
    for trainset in trainsets:
        train_loader = MetaDatasetBatchReader('train', [trainset],
                                              valsets,
                                              testsets,
                                              batch_size=BATCHSIZES[trainset])
        train_loaders.append(train_loader)
        num_train_classes[trainset] = train_loader.num_classes('train')
        # setting up knowledge distillation losses weights annealing
        kd_weight_annealing[trainset] = WeightAnnealing(
            T=int(args['train.cosine_anneal_freq'] * KDANNEALING[trainset]))
    val_loader = MetaDatasetEpisodeReader('val', trainsets, valsets, testsets)

    # initialize model and optimizer (the model receives the per-domain class
    # counts — presumably one head per training domain)
    model = get_model(list(num_train_classes.values()), args)
    model_name_temp = args['model.name']
    # KL-divergence loss on logits
    criterion_div = DistillKL(T=4)
    # get a MTL model initialized by ImageNet pretrained model and deactivate the pretrained flag
    args['model.pretrained'] = False
    optimizer = get_optimizer(model, args, params=model.get_parameters())
    # adaptors for aligning features between MDL and SDL models
    adaptors = adaptor(num_datasets=len(trainsets),
                       dim_in=512,
                       opt=args['adaptor.opt']).to(device)
    optimizer_adaptor = torch.optim.Adam(adaptors.parameters(),
                                         lr=0.1,
                                         weight_decay=5e-4)

    # loading single domain learning networks
    extractor_domains = trainsets
    dataset_models = DATASET_MODELS_DICT[args['model.backbone']]
    embed_many = get_domain_extractors(extractor_domains, dataset_models, args,
                                       num_train_classes)

    # restoring the last checkpoint; args['model.name'] is reset to the value
    # saved before loading the extractors (which may modify it — confirm)
    args['model.name'] = model_name_temp
    checkpointer = CheckPointer(args, model, optimizer=optimizer)
    if os.path.isfile(checkpointer.out_last_ckpt) and args['train.resume']:
        # NOTE(review): adaptor weights and optimizer_adaptor state are saved
        # via ``extra`` below; confirm restore_out_model restores them too.
        start_iter, best_val_loss, best_val_acc =\
            checkpointer.restore_out_model(ckpt='last')
    else:
        print('No checkpoint restoration')
        best_val_loss = 999999999
        best_val_acc = start_iter = 0

    # define learning rate policy (same policy for backbone and adaptors)
    lr_policy = args['train.lr_policy']
    if lr_policy == "step":
        lr_manager = UniformStepLR(optimizer, args, start_iter)
        lr_manager_ad = UniformStepLR(optimizer_adaptor, args, start_iter)
    elif "exp_decay" in lr_policy:
        lr_manager = ExpDecayLR(optimizer, args, start_iter)
        lr_manager_ad = ExpDecayLR(optimizer_adaptor, args, start_iter)
    elif "cosine" in lr_policy:
        lr_manager = CosineAnnealRestartLR(optimizer, args, start_iter)
        lr_manager_ad = CosineAnnealRestartLR(optimizer_adaptor, args,
                                              start_iter)
    else:
        # Without this, the LR managers are unbound and step() crashes later.
        raise ValueError(
            f"unsupported train.lr_policy {lr_policy!r}; expected 'step' or "
            "a policy containing 'exp_decay' or 'cosine'")

    # defining the summary writer
    writer = SummaryWriter(checkpointer.out_path)

    # Training loop
    max_iter = args['train.max_iter']
    epoch_loss = {name: [] for name in trainsets}
    epoch_kd_f_loss = {name: [] for name in trainsets}
    epoch_kd_p_loss = {name: [] for name in trainsets}
    epoch_acc = {name: [] for name in trainsets}
    epoch_val_loss = {name: [] for name in valsets}
    epoch_val_acc = {name: [] for name in valsets}
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = False
    with tf.compat.v1.Session(config=config) as session:
        for i in tqdm(range(max_iter)):
            # skip iterations already covered by a restored checkpoint
            if i < start_iter:
                continue

            optimizer.zero_grad()
            optimizer_adaptor.zero_grad()

            samples = []
            images = dict()
            num_samples = []
            # loading images and labels, one batch per training domain
            for name, train_loader in zip(trainsets, train_loaders):
                sample = train_loader.get_train_batch(session)
                samples.append(sample)
                images[name] = sample['images']
                num_samples.append(sample['images'].size(0))

            # One MDL forward over the concatenated batches; num_samples is
            # passed so the model can split outputs per domain (presumably).
            logits, mtl_features = model.forward(torch.cat(list(
                images.values()),
                                                           dim=0),
                                                 num_samples,
                                                 kd=True)
            stl_features, stl_logits = embed_many(images,
                                                  return_type='list',
                                                  kd=True,
                                                  logits=True)
            mtl_features = adaptors(mtl_features)

            batch_losses = []
            kd_f_losses = 0
            kd_p_losses = 0
            for t_indx, trainset in enumerate(trainsets):
                batch_loss, stats_dict, _ = cross_entropy_loss(
                    logits[t_indx], samples[t_indx]['labels'])
                batch_losses.append(batch_loss * LOSSWEIGHTS[trainset])
                batch_dataset = samples[t_indx]['dataset_name']
                epoch_loss[batch_dataset].append(stats_dict['loss'])
                epoch_acc[batch_dataset].append(stats_dict['acc'])
                # feature distillation on L2-normalised features; the SDL
                # (teacher) side is detached
                ft = torch.nn.functional.normalize(stl_features[t_indx],
                                                   p=2,
                                                   dim=1,
                                                   eps=1e-12)
                fs = torch.nn.functional.normalize(mtl_features[t_indx],
                                                   p=2,
                                                   dim=1,
                                                   eps=1e-12)
                kd_f_losses_ = distillation_loss(fs,
                                                 ft.detach(),
                                                 opt='kernelcka')
                # logit distillation
                # NOTE(review): stl_logits is not detached (unlike ft); confirm
                # embed_many runs without autograd.
                kd_p_losses_ = criterion_div(logits[t_indx],
                                             stl_logits[t_indx])
                kd_weight = kd_weight_annealing[trainset](
                    t=i, opt='linear') * KDFLOSSWEIGHTS[trainset]
                bam_weight = kd_weight_annealing[trainset](
                    t=i, opt='linear') * KDPLOSSWEIGHTS[trainset]
                # only add a distillation term while its weight is positive
                if kd_weight > 0:
                    kd_f_losses = kd_f_losses + kd_f_losses_ * kd_weight
                if bam_weight > 0:
                    kd_p_losses = kd_p_losses + kd_p_losses_ * bam_weight
                epoch_kd_f_loss[batch_dataset].append(kd_f_losses_.item())
                epoch_kd_p_loss[batch_dataset].append(kd_p_losses_.item())

            batch_loss = torch.stack(batch_losses).sum()
            kd_f_loss = kd_f_losses * args['train.sigma']
            kd_p_loss = kd_p_losses * args['train.beta']
            batch_loss = batch_loss + kd_f_loss + kd_p_loss

            batch_loss.backward()
            optimizer.step()
            optimizer_adaptor.step()
            lr_manager.step(i)
            lr_manager_ad.step(i)

            # every 200 iterations log the running per-domain training stats
            if (i + 1) % 200 == 0:
                for dataset_name in trainsets:
                    writer.add_scalar(f"loss/{dataset_name}-train_loss",
                                      np.mean(epoch_loss[dataset_name]), i)
                    writer.add_scalar(f"accuracy/{dataset_name}-train_acc",
                                      np.mean(epoch_acc[dataset_name]), i)
                    writer.add_scalar(
                        f"kd_f_loss/{dataset_name}-train_kd_f_loss",
                        np.mean(epoch_kd_f_loss[dataset_name]), i)
                    writer.add_scalar(
                        f"kd_p_loss/{dataset_name}-train_kd_p_loss",
                        np.mean(epoch_kd_p_loss[dataset_name]), i)
                    epoch_loss[dataset_name] = []
                    epoch_acc[dataset_name] = []
                    epoch_kd_f_loss[dataset_name] = []
                    epoch_kd_p_loss[dataset_name] = []

                writer.add_scalar('learning_rate',
                                  optimizer.param_groups[0]['lr'], i)

            # Evaluation inside the training loop
            if (i + 1) % args['train.eval_freq'] == 0:
                model.eval()
                dataset_accs, dataset_losses = [], []
                for valset in valsets:
                    val_losses, val_accs = [], []
                    for j in tqdm(range(args['train.eval_size'])):
                        with torch.no_grad():
                            sample = val_loader.get_validation_task(
                                session, valset)
                            context_features = model.embed(
                                sample['context_images'])
                            target_features = model.embed(
                                sample['target_images'])
                            context_labels = sample['context_labels']
                            target_labels = sample['target_labels']
                            _, stats_dict, _ = prototype_loss(
                                context_features, context_labels,
                                target_features, target_labels)
                        val_losses.append(stats_dict['loss'])
                        val_accs.append(stats_dict['acc'])

                    # write summaries per validation set
                    dataset_acc, dataset_loss = np.mean(
                        val_accs) * 100, np.mean(val_losses)
                    dataset_accs.append(dataset_acc)
                    dataset_losses.append(dataset_loss)
                    epoch_val_loss[valset].append(dataset_loss)
                    epoch_val_acc[valset].append(dataset_acc)
                    writer.add_scalar(f"loss/{valset}/val_loss", dataset_loss,
                                      i)
                    writer.add_scalar(f"accuracy/{valset}/val_acc",
                                      dataset_acc, i)
                    print(
                        f"{valset}: val_acc {dataset_acc:.2f}%, val_loss {dataset_loss:.3f}"
                    )

                # write summaries averaged over datasets
                avg_val_loss, avg_val_acc = np.mean(dataset_losses), np.mean(
                    dataset_accs)
                writer.add_scalar(f"loss/avg_val_loss", avg_val_loss, i)
                writer.add_scalar(f"accuracy/avg_val_acc", avg_val_acc, i)

                # saving checkpoints; "best" is decided by mean val accuracy
                if avg_val_acc > best_val_acc:
                    best_val_loss, best_val_acc = avg_val_loss, avg_val_acc
                    is_best = True
                    print('Best model so far!')
                else:
                    is_best = False
                extra_dict = {
                    'epoch_loss': epoch_loss,
                    'epoch_acc': epoch_acc,
                    'epoch_val_loss': epoch_val_loss,
                    'epoch_val_acc': epoch_val_acc,
                    'adaptors': adaptors.state_dict(),
                    'optimizer_adaptor': optimizer_adaptor.state_dict()
                }
                checkpointer.save_checkpoint(i,
                                             best_val_acc,
                                             best_val_loss,
                                             is_best,
                                             optimizer=optimizer,
                                             state_dict=model.get_state_dict(),
                                             extra=extra_dict)

                model.train()
                print(f"Trained and evaluated at {i}")

    writer.close()
    if start_iter < max_iter:
        print(
            f"""Done training with best_mean_val_loss: {best_val_loss:.3f}, best_avg_val_acc: {best_val_acc:.2f}%"""
        )
    else:
        print(
            f"""No training happened. Loaded checkpoint at {start_iter}, while max_iter was {max_iter}"""
        )
Beispiel #9
0
def main():
    """Evaluate SUR on pre-extracted features read from an LMDB dump.

    Support and query features for every training-domain extractor are
    loaded from the dump named ``args['dump.name']`` (default
    ``'test_dump'``), so no backbone is run here.

    For each Meta-Dataset domain, ``sur`` learns selection weights on the
    support features, and nearest-prototype accuracy on the query features
    is recorded. A table of mean accuracy +- 1.96 standard errors is printed.
    """
    # Passed to LMDBDataset; presumably the number of episodes per domain.
    LIMITER = 600

    # Setting up datasets
    extractor_domains = TRAIN_METADATASET_NAMES
    all_test_datasets = ALL_METADATASET_NAMES
    dump_name = args['dump.name'] if args['dump.name'] else 'test_dump'
    testset = LMDBDataset(extractor_domains, all_test_datasets,
                          args['model.backbone'], 'test', dump_name, LIMITER)

    # Domain extractors are deliberately not built: the dump already holds
    # their features, and loading them only costs time and GPU memory.

    accs_names = ['SUR']
    all_accs = dict()
    # Go over all test datasets
    for test_dataset in all_test_datasets:
        print(test_dataset)
        testset.set_sampling_dataset(test_dataset)
        # batch_size=None: the loader yields one episode per item, unbatched.
        test_loader = DataLoader(testset,
                                 batch_size=None,
                                 batch_sampler=None,
                                 num_workers=16)
        all_accs[test_dataset] = {name: [] for name in accs_names}

        for sample in tqdm(test_loader):
            context_labels = sample['context_labels'].to(device)
            target_labels = sample['target_labels'].to(device)
            context_features_dict = {
                k: v.to(device)
                for k, v in sample['context_feature_dict'].items()
            }
            target_features_dict = {
                k: v.to(device)
                for k, v in sample['target_feature_dict'].items()
            }

            # optimize selection parameters and perform feature selection
            selection_params = sur(context_features_dict,
                                   context_labels,
                                   max_iter=40)
            selected_context = apply_selection(context_features_dict,
                                               selection_params)
            selected_target = apply_selection(target_features_dict,
                                              selection_params)

            final_acc = prototype_loss(selected_context, context_labels,
                                       selected_target,
                                       target_labels)[1]['acc']
            all_accs[test_dataset]['SUR'].append(final_acc)

    # Make a nice accuracy table: "mean +- half-width" in percent
    rows = []
    for dataset_name in all_test_datasets:
        row = [dataset_name]
        for model_name in accs_names:
            acc = np.array(all_accs[dataset_name][model_name]) * 100
            mean_acc = acc.mean()
            conf = (1.96 * acc.std()) / np.sqrt(len(acc))
            row.append(f"{mean_acc:0.2f} +- {conf:0.2f}")
        rows.append(row)

    table = tabulate(rows,
                     headers=['model \\ data'] + accs_names,
                     floatfmt=".2f")
    print(table)
    print("\n")