def log_fn(epoch):
    """Evaluate ``network`` on the test split and log the resulting metrics.

    ``network``, ``dataset``, ``args`` and ``logger`` are resolved from the
    enclosing/global scope, not passed in.

    NOTE(review): another ``def log_fn`` with the same name follows in this
    file; if both live in the same scope this definition is shadowed — confirm.

    Args:
        epoch: step value forwarded to ``logger.log_metrics``.
    """
    network.eval()
    test_batches = dataset.iterate(args.infer_batch_size, False, split="test")
    test_accuracy = evaluate(network, test_batches, args.device, args)
    logger.log_metrics(test_accuracy, step=epoch)
def log_fn(epoch):
    """Evaluate ``network`` on the test split with the dataset's label weights.

    The metrics are printed, then logged under the ``"initial"`` prefix.
    ``network``, ``dataset``, ``args`` and ``logger`` come from the
    enclosing/global scope.

    Args:
        epoch: step value forwarded to ``logger.log_metrics``.
    """
    network.eval()
    test_batches = dataset.iterate(args.infer_batch_size, False, split="test")
    results = evaluate(
        network,
        test_batches,
        args.device,
        args,
        label_weights=dataset.label_weights,
    )
    print(results)
    logger.log_metrics(results, prefix="initial", step=epoch)
def log_fn_shifted(epoch):
    """Evaluate ``network`` on the test split and log under the "shifted" prefix.

    When ``args.iterative_iw`` is set, ``label_shift`` is run first.
    Presumably it refreshes ``dataset.label_weights``, which the evaluation
    below reads — confirm. ``network``, ``dataset``, ``args`` and ``logger``
    come from the enclosing/global scope.

    Args:
        epoch: step value forwarded to ``logger.log_metrics``.
    """
    network.eval()
    if args.iterative_iw:
        label_shift(network, dataset, args)
    test_batches = dataset.iterate(args.infer_batch_size, False, split="test")
    results = evaluate(
        network,
        test_batches,
        args.device,
        args,
        label_weights=dataset.label_weights,
    )
    print(results)
    logger.log_metrics(results, prefix="shifted", step=epoch)
def _h0_checkpoint_name(args):
    """Return the relative file name under which the initial network h0 is cached."""
    # NOTE(review): keyed on args.seed, not on the per-run ``seed`` passed to
    # experiment(), so runs with different seeds share one h0 — confirm intended.
    if args.dataset == "nabirds":
        return "content/data/%s_%s_%s_%s_%s.h5" % (
            args.seed, args.dataset_cap, args.dataset, args.nabirdstype,
            args.version)
    return "content/data/%s_%s_%s_%s_%s_%s.h5" % (
        args.seed, args.dataset_cap, args.dataset, args.shift_strategy,
        args.warmstart_ratio, args.version)


def _load_or_train_h0(network, dataset, args):
    """Load cached h0 weights into ``network`` or train them and write the cache.

    Returns the network, moved to ``args.device``.
    """
    # Use ONE path for the existence check, the load and the save. Previously
    # the check/load used the bare relative name while the save prepended
    # ``home_dir``, so the cache was never hit unless cwd == home_dir.
    ckpt_path = home_dir + _h0_checkpoint_name(args)
    milestones = LONG_MILESTONES
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location=torch.device(args.device))
        network.load_state_dict(checkpoint)
        return network.to(args.device)

    train(network, dataset=dataset, epochs=args.initial_epochs, args=args,
          milestones=milestones)
    if args.domainsep:
        sep_train(network, dataset=dataset, epochs=args.initial_epochs, args=args)
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(ckpt_path) or ".", exist_ok=True)
    torch.save(network.state_dict(), ckpt_path)
    return network


def _evaluate_on_test(network, dataset, args):
    """Evaluate ``network`` (switched to eval mode) on the test split with label weights."""
    network.eval()
    return evaluate(network,
                    dataset.iterate(args.infer_batch_size, False, split="test"),
                    args.device, args, label_weights=dataset.label_weights)


def _build_metrics(num_labeled, labeled_shifts, uniform_labeled_shifts, accuracies):
    """Assemble the metric-name -> series dict (series lists are shared, not copied)."""
    metrics = {
        "Number of labels": num_labeled,
        "Source shift": labeled_shifts,
        "Uniform source shift": uniform_labeled_shifts,
    }
    metrics.update(accuracies)
    return metrics


def _log_and_save(name, metrics, args, logger):
    """Log the latest value of every series, print the dict and write it to CSV."""
    step = metrics["Number of labels"][-1]
    for key, series in metrics.items():
        logger.log_metric(key, series[-1], step=step)
    print(metrics)
    save_to_csv(name, metrics, args, logger)


def experiment(args, logger, name, seed=None):
    """Run LS experiment.

    The run proceeds in these steps:

    1. Seed torch/numpy/random.
    2. Build the dataset and network.
    3. Load the cached initial network h0, or train and cache it.
    4. Run ``label_shift``.
    5. Iterate over the networks yielded by the chosen sampling strategy.
       For the initial network and for each yielded network, evaluate on the
       test split, record the composition/shift measurements, then log and
       save the accumulated metrics.

    Args:
        args: experiment configuration namespace.
        logger: metrics logger (``log_metric``); also passed to ``save_to_csv``.
        name: run name forwarded to ``save_to_csv``.
        seed: RNG seed; ``args.seed`` is used when None.

    Returns:
        dict mapping metric name to the list of values recorded so far: one
        entry for the initial network plus one per network from the iterator.
    """
    if seed is None:
        seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    print("Running seed ", seed)

    if args.reweight:
        assert args.iterative_iw
    assert args.shift_correction in ["rlls", "bbse", "cheat", "none"]

    dataset = get_datasets(args)
    net_cls = get_net_cls(args)
    network = net_cls(args.num_cls).to(args.device)

    network = _load_or_train_h0(network, dataset, args)
    label_shift(network, dataset, args)

    if args.sampling_strategy == "iwal":
        iterator = iwal_bootstrap(network, net_cls, dataset, args)
    else:
        iterator = general_sampling(network, net_cls, dataset, args)

    # Initial measurements. _evaluate_on_test puts the network in eval mode,
    # matching the per-iteration evaluation below (the original skipped it here).
    _, initial_labeled_shift, initial_uniform_labeled_shift = measure_composition(
        dataset)
    initial_accuracy = _evaluate_on_test(network, dataset, args)

    num_labeled = [0]
    labeled_shifts = [initial_labeled_shift]
    uniform_labeled_shifts = [initial_uniform_labeled_shift]
    accuracies = {k: [v] for k, v in initial_accuracy.items()}
    metrics = _build_metrics(num_labeled, labeled_shifts, uniform_labeled_shifts,
                             accuracies)
    _log_and_save(name, metrics, args, logger)

    # Begin sampling: each yielded network is evaluated and its metrics appended.
    for network in iterator:
        accuracy = _evaluate_on_test(network, dataset, args)
        for k, v in accuracy.items():
            accuracies[k].append(v)
        num_labeled.append(dataset.online_labeled_len())
        new_optimal_weights, labeled_shift, uniform_labeled_shift = measure_composition(
            dataset)
        labeled_shifts.append(labeled_shift)
        uniform_labeled_shifts.append(uniform_labeled_shift)
        print("Optimal weights for new source", new_optimal_weights)

        metrics = _build_metrics(num_labeled, labeled_shifts,
                                 uniform_labeled_shifts, accuracies)
        _log_and_save(name, metrics, args, logger)

    logger.log_metric("Done", True)
    return metrics