Example #1
0
 def log_fn(epoch):
     """Evaluate the network on the test split and log accuracy at `epoch`."""
     network.eval()
     test_batches = dataset.iterate(args.infer_batch_size, False, split="test")
     accuracy = evaluate(network, test_batches, args.device, args)
     logger.log_metrics(accuracy, step=epoch)
Example #2
0
 def log_fn(epoch):
     """Evaluate on the test split with label weights, print the result, and
     log it under the 'initial' prefix at step `epoch`."""
     network.eval()
     test_batches = dataset.iterate(args.infer_batch_size,
                                    False,
                                    split="test")
     accuracy = evaluate(network, test_batches, args.device, args,
                         label_weights=dataset.label_weights)
     print(accuracy)
     logger.log_metrics(accuracy, prefix="initial", step=epoch)
Example #3
0
 def log_fn_shifted(epoch):
     """Optionally re-estimate label shift, then evaluate on the test split,
     print the result, and log it under the 'shifted' prefix at step `epoch`."""
     network.eval()
     if args.iterative_iw:
         label_shift(network, dataset, args)
     test_batches = dataset.iterate(args.infer_batch_size,
                                    False,
                                    split="test")
     accuracy = evaluate(network, test_batches, args.device, args,
                         label_weights=dataset.label_weights)
     print(accuracy)
     logger.log_metrics(accuracy, prefix="shifted", step=epoch)
Example #4
0
def _record_round_metrics(name, num_labeled, labeled_shifts,
                          uniform_labeled_shifts, accuracies, args, logger):
    """Assemble the round's metrics dict, log each series' latest value
    (stepped by the current label count), print it, and append to the CSV.

    Returns the assembled metrics dict.
    """
    metrics = {
        "Number of labels": num_labeled,
        "Source shift": labeled_shifts,
        "Uniform source shift": uniform_labeled_shifts,
    }
    metrics.update(accuracies)
    for k, v in metrics.items():
        logger.log_metric(k, v[-1], step=num_labeled[-1])
    print(metrics)
    save_to_csv(name, metrics, args, logger)
    return metrics


def experiment(args, logger, name, seed=None):
    """Run one label-shift (LS) active-learning experiment.

    Seeds all RNGs, trains (or loads a cached) initial network h0, applies
    label-shift correction, then iterates the chosen sampling strategy,
    recording accuracy and source-shift metrics after every round.

    Args:
        args: experiment configuration namespace (dataset, device, seeds, ...).
        logger: metric sink exposing ``log_metric(key, value, step=...)``.
        name: run identifier passed through to ``save_to_csv``.
        seed: RNG seed; falls back to ``args.seed`` when None.

    Returns:
        dict mapping metric name to the list of values recorded per round.
    """

    # Seed the experiment
    if seed is None:
        seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    print("Running seed ", seed)
    # torch.cuda.set_device(args.device)
    # assert torch.cuda.is_available()
    # Reweighting only makes sense with iterative importance weighting.
    if args.reweight:
        assert args.iterative_iw

    assert args.shift_correction in ["rlls", "bbse", "cheat", "none"]

    # Shuffle dataset
    dataset = get_datasets(args)

    # Load the network architecture for this dataset.
    net_cls = get_net_cls(args)
    network = net_cls(args.num_cls).to(args.device)

    # Train h0
    # milestones = NABIRDS_MILESTONES if args.dataset == "nabirds" else LONG_MILESTONES
    milestones = LONG_MILESTONES

    # Build the checkpoint filename for the cached h0.
    if args.dataset == "nabirds":
        fname = "content/data/%s_%s_%s_%s_%s.h5" % (
            args.seed, args.dataset_cap, args.dataset, args.nabirdstype,
            args.version)
    else:
        fname = "content/data/%s_%s_%s_%s_%s_%s.h5" % (
            args.seed, args.dataset_cap, args.dataset, args.shift_strategy,
            args.warmstart_ratio, args.version)
    # BUG FIX: existence was checked on `fname` while the checkpoint was saved
    # to `home_dir + fname`, so a cached h0 was never reused unless the working
    # directory happened to be home_dir. Use the save location consistently.
    ckpt_path = home_dir + fname
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path,
                                map_location=torch.device(args.device))
        network.load_state_dict(checkpoint)
        network = network.to(args.device)
    else:
        train(network,
              dataset=dataset,
              epochs=args.initial_epochs,
              args=args,
              milestones=milestones)
        if args.domainsep:
            sep_train(network,
                      dataset=dataset,
                      epochs=args.initial_epochs,
                      args=args)
        torch.save(network.state_dict(), ckpt_path)
    label_shift(network, dataset, args)

    # Initialize sampling strategy
    if args.sampling_strategy == "iwal":
        iterator = iwal_bootstrap(network, net_cls, dataset, args)
    else:
        iterator = general_sampling(network, net_cls, dataset, args)

    # Initialize results: baseline composition and accuracy before sampling.
    _, initial_labeled_shift, initial_uniform_labeled_shift = measure_composition(
        dataset)
    initial_accuracy = evaluate(network,
                                dataset.iterate(args.infer_batch_size,
                                                False,
                                                split="test"),
                                args.device,
                                args,
                                label_weights=dataset.label_weights)

    num_labeled = [0]
    labeled_shifts = [initial_labeled_shift]
    uniform_labeled_shifts = [initial_uniform_labeled_shift]
    accuracies = {k: [v] for k, v in initial_accuracy.items()}

    metrics = _record_round_metrics(name, num_labeled, labeled_shifts,
                                    uniform_labeled_shifts, accuracies, args,
                                    logger)

    # Begin sampling: the iterator yields a (re)trained network per round.
    for network in iterator:
        # Evaluate current network
        network.eval()
        accuracy = evaluate(network,
                            dataset.iterate(args.infer_batch_size,
                                            False,
                                            split="test"),
                            args.device,
                            args,
                            label_weights=dataset.label_weights)
        for k, v in accuracy.items():
            accuracies[k].append(v)
        num_labeled.append(dataset.online_labeled_len())

        new_optimal_weights, labeled_shift, uniform_labeled_shift = measure_composition(
            dataset)
        labeled_shifts.append(labeled_shift)
        uniform_labeled_shifts.append(uniform_labeled_shift)
        print("Optimal weights for new source", new_optimal_weights)

        # Record metrics for this round.
        metrics = _record_round_metrics(name, num_labeled, labeled_shifts,
                                        uniform_labeled_shifts, accuracies,
                                        args, logger)

    logger.log_metric("Done", True)
    return metrics