def experiment():
    """Run the baseline experiment: train h0 and log test metrics every epoch.

    Configuration comes from ``get_args(None)``. A comet.ml ``Experiment`` is
    created and named from the learning rate, gamma, batch size, model variant
    and dataset. torch / numpy / ``random`` are seeded with ``args.seed``. The
    network is trained for ``args.initial_epochs`` epochs, and ``log_fn``
    evaluates it on the "test" split and logs the result at each epoch.
    """
    args = get_args(None)
    logger = Experiment(comet_ml_key,
                        project_name="active-label-shift-adaptation")
    logger.set_name("Baseline lr {} g {} bs {} {} {}".format(
        args.lr, args.gamma, args.batch_size,
        "simple " if args.simple_model else "", args.dataset))
    logger.log_parameters(vars(args))

    # Seed the experiment
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    print("Running seed ", seed)
    # Verify CUDA before selecting a device: set_device would otherwise fail
    # first (with a less direct error) and this check would never be reached.
    assert torch.cuda.is_available()
    torch.cuda.set_device(args.device)

    # Shuffle dataset
    dataset = get_datasets(args)

    # Train h0
    net_cls = get_net_cls(args)
    network = net_cls(args.num_cls).to(args.device)

    def log_fn(epoch):
        """Evaluate ``network`` on the test split and log metrics at ``epoch``."""
        network.eval()
        accuracy = evaluate(
            network,
            dataset.iterate(args.infer_batch_size, False, split="test"),
            args.device, args)
        logger.log_metrics(accuracy, step=epoch)

    train(network, dataset=dataset, epochs=args.initial_epochs, args=args,
          log_fn=log_fn)
def _iwal_predict_probs(network, data, dataset, args):
    """Return p(y | x) for one batch as a numpy array (rows = examples).

    The network output is exponentiated, so it is treated as log-probabilities.
    Unless ``args.train_iw`` or ``args.only_rlls_infer`` is set, each row is
    multiplied by ``dataset.label_weights`` and renormalized to sum to 1.
    The caller controls the network's train/eval mode.
    """
    output = torch.exp(network(data)).cpu().data.numpy()
    if not args.train_iw and not args.only_rlls_infer:
        output = output * dataset.label_weights
        output = output / np.sum(output, axis=1)[:, None]
    return output


def _iwal_label_scale(ys, dataset, args):
    """Per-point multiplier for the disagreement score, or None if not diversifying.

    "guess": weight each predicted label by its normalized inverse frequency
    among ``ys``. "subguess": weight by normalized ``sqrt(dataset.first_weight)``.
    Any other ``args.diversify`` value returns None (no scaling).
    """
    if args.diversify == "guess":
        yunique, ycounts = np.unique(ys, return_counts=True)
        freqs = np.array(ycounts, dtype=np.float32) / np.sum(ycounts)
        freqs += 1e-4
        weights = 1 / freqs
        weights = weights / np.sum(weights)
        weight_of = dict(zip(yunique, weights))
    elif args.diversify == "subguess":
        weights = np.sqrt(dataset.first_weight)
        weights += 1e-4
        weights = weights / np.sum(weights)
        # Index by class id. NOTE(review): assumes first_weight has one entry
        # per class (0..num_cls-1); zipping with np.unique(ys) misaligned
        # labels and weights whenever some class was never predicted.
        weight_of = {y: weights[y] for y in np.unique(ys)}
    else:
        return None
    return np.array([weight_of[y] for y in ys])


def iwal_bootstrap(network, net_cls, dataset, args=None):
    """Stream-based IWAL sampler driven by disagreement among stochastic passes.

    For each of ``args.num_batches`` rounds:
      * Run ``args.bald_size`` forward passes of ``network`` in train mode over
        the online split. Train mode keeps any stochastic layers (e.g. dropout)
        active, so the passes can differ.
      * For each point, the disagreement is the largest (max - min) probability
        gap over any class. Its sampling probability is
        ``iwal_normalizer + (1 - iwal_normalizer) * disagreement``, optionally
        scaled per predicted label by ``args.diversify`` ("guess"/"subguess"),
        then clipped to [0, 1].
      * Walk the online split from where the previous round stopped and label
        each point with that probability, until ``label_batch_size`` new labels
        are collected.
      * Fine-tune the network on the labeled data, optionally re-estimate the
        label shift, and yield the network.

    ``net_cls`` is unused; it is kept for interface parity with
    ``general_sampling``.
    """
    # Batch-mode settings
    batch_size = int(dataset.online_len() / args.num_batches)
    label_batch_size = int(args.sample_prop * batch_size)
    labeled_ptrs = np.array([], dtype=np.int32)

    committee = [network]
    # First online index not yet considered by the stream. Advancing past the
    # last *sampled* index prevents the same point from being labeled twice.
    next_ptr = 0

    for _ in range(args.num_batches):
        # all_probs[m, k, i] = p_m(y = k | x_i) for stochastic pass m.
        all_probs = np.zeros(
            (args.bald_size, args.num_cls, dataset.online_len()))
        ys = None
        with torch.no_grad():
            network.train()
            for pass_i in range(args.bald_size):
                outputs = [
                    _iwal_predict_probs(network, data.to(args.device),
                                        dataset, args)
                    for data, _, _ in dataset.iterate(
                        batch_size=args.infer_batch_size, shuffle=False,
                        split="online")
                ]
                all_probs[pass_i] = np.concatenate(outputs, axis=0).T
            network.eval()

            # Predicted labels, used for diversification.
            # NOTE(review): "overguess" computes ys but no scaling rule exists
            # for it here, so it behaves like "none".
            if args.diversify in ["guess", "overguess", "subguess"]:
                ys = np.concatenate([
                    np.argmax(_iwal_predict_probs(
                        network, image.to(args.device), dataset, args), axis=1)
                    for image, _, _ in dataset.iterate(
                        batch_size=args.infer_batch_size, shuffle=False,
                        split="online")
                ])

        # -> (num_online, bald_size, num_cls)
        all_probs = np.transpose(all_probs, [2, 0, 1])
        # Per point and label: largest disagreement in probability.
        disagreement = np.max(all_probs, axis=1) - np.min(all_probs, axis=1)
        assert disagreement.shape == (dataset.online_len(), args.num_cls)
        # Per point: largest disagreement over labels.
        disagreement = np.max(disagreement, axis=1)

        sample_probs = args.iwal_normalizer + \
            (1 - args.iwal_normalizer) * disagreement
        print("Original sample_probs:", sample_probs, len(sample_probs),
              np.mean(sample_probs), np.var(sample_probs))
        # Diversification must act *before* sample_probs is used; the scale is
        # positive, so this equals scaling every pass's probabilities for that
        # point before taking max - min.
        scale = _iwal_label_scale(ys, dataset, args)
        if scale is not None:
            sample_probs = args.iwal_normalizer + \
                (1 - args.iwal_normalizer) * disagreement * scale
        sample_probs = np.clip(sample_probs, 0, 1)

        # Stream through the online points, labeling each with its probability
        # until exactly label_batch_size new labels are collected.
        new_ptrs = []
        for i in range(next_ptr, dataset.online_len()):
            if len(new_ptrs) >= label_batch_size:
                break
            if random.random() < sample_probs[i]:
                new_ptrs.append(i)
                next_ptr = i + 1

        labeled_ptrs = np.concatenate(
            [labeled_ptrs, np.array(new_ptrs, dtype=np.int32)])
        dataset.label_ptrs(labeled_ptrs)

        # Note sample proportion
        print("Sample proportion: ",
              len(labeled_ptrs) / dataset.online_len())

        # Train networks on current batch status
        for this_network in committee:
            this_network.train()
            train(this_network, dataset, epochs=args.partial_epochs,
                  lr=args.finetune_lr, args=args)
            this_network.eval()

        # Handle reweighting procedure
        if args.iterative_iw:
            label_shift(committee[0], dataset, args)

        yield committee[0]
def _gs_predict_probs(network, image, dataset, args):
    """Return p(y | x) for one batch as a numpy array (rows = examples).

    The network output is exponentiated, so it is treated as log-probabilities.
    Unless ``args.train_iw`` or ``args.only_rlls_infer`` is set, rows are
    multiplied by ``dataset.label_weights`` and renormalized to sum to 1.
    The caller controls the network's train/eval mode.
    """
    p = torch.exp(network(image)).cpu().data.numpy()
    if not args.train_iw and not args.only_rlls_infer:
        p = p * dataset.label_weights
        p = p / np.sum(p, axis=1)[:, None]
    return p


def _gs_vote_counts(predictions):
    """For a (members x points) int label array, count of the modal label per point.

    Matches ``scipy.stats.mode(predictions, axis=0).count[0]`` under SciPy < 1.11.
    It is computed directly because SciPy >= 1.11 drops the leading axis
    (keepdims=False), which made ``.count[0]`` a single scalar.
    """
    if predictions.shape[1] == 0:
        return np.zeros(0, dtype=np.int64)
    return np.apply_along_axis(lambda col: np.bincount(col).max(), 0,
                               predictions)


def _gs_batch_stat(committee, image, y, dataset, args):
    """Priority statistic for one online batch (smaller = queried first).

    Returns None when ``args.sampling_strategy`` is not recognised; no stat is
    recorded in that case.
    """
    strategy = args.sampling_strategy
    if strategy == "qbc":
        # Agreement count across independently trained committee members.
        predictions = []
        for member in committee:
            member.eval()
            predictions.append(np.argmax(
                _gs_predict_probs(member, image, dataset, args), axis=1))
        return _gs_vote_counts(np.stack(predictions))
    if strategy == "bald":
        # Agreement count across repeated train-mode passes of one network.
        lead = committee[0]
        lead.train()
        predictions = [
            np.argmax(_gs_predict_probs(lead, image, dataset, args), axis=1)
            for _ in range(args.bald_size)
        ]
        return _gs_vote_counts(np.stack(predictions))
    if strategy == "random":
        return np.random.uniform(size=(len(image), ))
    if strategy not in ("cheat", "margin", "maxent"):
        return None

    lead = committee[0]
    lead.eval()
    p = _gs_predict_probs(lead, image, dataset, args)
    if strategy == "cheat":
        # False (misclassified) sorts before True.
        return np.equal(np.argmax(p, axis=1), y)
    if strategy == "margin":
        sorted_p = np.sort(p)
        return sorted_p[:, -1] - sorted_p[:, -2]
    # maxent: negative entropy, so high-entropy points sort first.
    return -np.sum(-p * np.log(p + 1e-9), axis=1)


def _gs_select_batch(sorted_ptrs, ys, label_batch_size, dataset, args):
    """Choose the online indices to label next from priority-sorted candidates.

    "none": the first ``label_batch_size`` candidates.
    "guess"/"cheat": bucket candidates by label in ``ys`` (predicted for
    "guess", true for "cheat") and take the same number ``size`` from each
    bucket. ``size`` is the smallest ceil-share that the remaining buckets can
    supply.
    "overguess": take ceil(label_weight_k / sum(label_weights) * budget) from
    bucket k.

    Returns an int64 array; it is never float, even when a bucket is empty.

    Raises:
        ValueError: if no equal per-label allocation exists ("guess"/"cheat").
    """
    if args.diversify == "none":
        return np.asarray(sorted_ptrs[:label_batch_size], dtype=np.int64)

    # Bucket candidates by label, preserving priority order within a bucket.
    ptrs_by_label = {label: [] for label in range(args.num_cls)}
    for ptr in sorted_ptrs:
        ptrs_by_label[ys[ptr]].append(ptr)

    chosen = []
    if args.diversify == "overguess":
        total_weight = sum(dataset.label_weights)
        for k, ptrs in ptrs_by_label.items():
            size = math.ceil(dataset.label_weights[k] / total_weight *
                             label_batch_size)
            chosen.extend(ptrs[:size])
        return np.asarray(chosen, dtype=np.int64)

    # Of remaining ptrs per label, find most equal allocation.
    label_lens = sorted(len(ptrs) for ptrs in ptrs_by_label.values())
    size = -1
    for i, length in enumerate(label_lens):
        candidate = math.ceil((label_batch_size - sum(label_lens[:i])) /
                              len(label_lens[i:]))
        if candidate <= length:
            size = candidate
            break
    if size == -1:
        raise ValueError(
            "Could not find a per-label allocation for {} labels".format(
                label_batch_size))
    for ptrs in ptrs_by_label.values():
        chosen.extend(ptrs[:size])
    return np.asarray(chosen, dtype=np.int64)


def general_sampling(network, net_cls, dataset, args=None):  # pylint: disable=R0914,R0912,R0915
    """Pool-based batch-mode active sampling over the online split.

    Each of ``args.num_batches`` rounds:
      * Scores every online point with ``args.sampling_strategy`` ("qbc",
        "bald", "cheat", "margin", "maxent", "random"); smaller scores are
        queried first.
      * With ``args.domainsep``, pushes a random subset of points to +inf.
      * Picks ``label_batch_size`` unlabeled points, optionally balanced
        across labels (``args.diversify``), and marks them labeled.
      * Fine-tunes every committee member, optionally re-estimates the label
        shift, and yields ``committee[0]``.

    For "qbc" the committee is extended with ``args.vs_size - 1`` freshly
    trained networks built from ``net_cls``.

    Raises:
        ValueError: on the first ``next()`` if ``args.diversify`` is not one of
            "none", "guess", "cheat", "overguess"; or if no balanced
            allocation exists.
    """
    if args.diversify not in ("none", "guess", "cheat", "overguess"):
        raise ValueError("Unsupported diversify mode: {}".format(
            args.diversify))

    # Batch-mode settings
    batch_size = int(dataset.online_len() / args.num_batches)
    label_batch_size = int(args.sample_prop * batch_size)
    labeled_ptrs = np.array([], dtype=np.int32)

    # Initialize committee
    committee = [network]
    if args.sampling_strategy in ["qbc"]:
        for _ in range(args.vs_size - 1):
            this_network = net_cls(args.num_cls).to(args.device)
            this_network.train()
            train(this_network, dataset, epochs=args.initial_epochs,
                  args=args, milestones=LONG_MILESTONES)
            this_network.eval()
            committee.append(this_network)

    # "cheat" diversification buckets by true labels (fixed for the run).
    ys = None
    if args.diversify == "cheat":
        ys = np.array(
            [dataset.train_labels[i][0]
             for i in dataset.indices(split="online")],
            dtype=np.int32)

    for _ in range(args.num_batches):
        stats = []  # Smaller stats means higher priority
        sep_stats = []  # Bigger value means more important
        with torch.no_grad():
            for image, y, _ in dataset.iterate(
                    batch_size=args.infer_batch_size, shuffle=False,
                    split="online"):
                image = image.to(args.device)
                if args.domainsep:
                    lead = committee[0]
                    lead.eval()
                    p = torch.exp(lead(image)).cpu().data.numpy()
                    sep_stats.append(np.sum(-p * np.log(p + 1e-9), axis=1))
                stat = _gs_batch_stat(committee, image, y, dataset, args)
                if stat is not None:
                    stats.append(stat)

            if args.diversify in ["guess", "overguess"]:
                # Produce new ys from the current network's predictions.
                network.eval()
                ys = np.concatenate([
                    np.argmax(_gs_predict_probs(
                        network, image.to(args.device), dataset, args),
                        axis=1)
                    for image, _, _ in dataset.iterate(
                        batch_size=args.infer_batch_size, shuffle=False,
                        split="online")
                ])

        stats = np.concatenate(stats)
        if args.domainsep:
            sep_stats = np.concatenate(sep_stats)
            sep_stats[sep_stats < 0.5] = 0
            # NOTE(review): sep_stats is thresholded above but unused by the
            # rule below, which keeps each point with probability 1/2
            # (comparison of two independent uniforms). Confirm intent.
            sep_odds = np.random.uniform(size=sep_stats.shape)
            selection = np.greater(sep_odds,
                                   np.random.uniform(size=sep_stats.shape))
            # qbc/bald counts and cheat booleans cannot hold +inf.
            stats = stats.astype(np.float64)
            stats[~selection] = np.inf

        # Rank unlabeled points by priority and pick this round's batch.
        new_ptrs = np.setdiff1d(np.arange(len(stats)), labeled_ptrs)
        sorted_ptrs = new_ptrs[np.argsort(stats[new_ptrs])]
        labeled_ptrs = np.concatenate([
            labeled_ptrs,
            _gs_select_batch(sorted_ptrs, ys, label_batch_size, dataset,
                             args),
        ])
        assert len(np.unique(labeled_ptrs)) == len(labeled_ptrs)
        dataset.label_ptrs(labeled_ptrs)

        # Note sample proportion
        print("Sample proportion: ",
              len(labeled_ptrs) / dataset.online_len())

        # Train networks on current batch status
        if args.domainsep:
            committee[0].train()
            sep_train(committee[0], dataset, epochs=args.partial_epochs,
                      lr=args.finetune_lr, args=args)
            committee[0].eval()
        for this_network in committee:
            this_network.train()
            train(this_network, dataset, epochs=args.partial_epochs,
                  lr=args.finetune_lr, args=args)
            this_network.eval()

        # Handle reweighting procedure
        if args.iterative_iw:
            label_shift(committee[0], dataset, args)

        yield committee[0]
def experiment():
    """Run the finetuning label-shift experiment.

    Builds a run name from ``get_args(None)``. When ``args.log`` is set, the run
    is refused if a comet.ml experiment with exactly that name already exists.
    The whole online split is then marked labeled.

    h0 is trained for ``args.initial_epochs`` epochs, logging test metrics
    under prefix "initial". The label shift is estimated with ``label_shift``,
    and the network is trained again for ``args.initial_epochs`` epochs,
    logging under prefix "shifted". When ``args.iterative_iw`` is set, the shift
    is re-estimated before each shifted evaluation. Finally the value returned
    by ``label_shift`` is logged as "IW MSE".

    Raises:
        ValueError: if ``args.log`` is set and a comet.ml experiment with this
            name already exists.
    """
    args = get_args(None)
    name = "Finetune {}:{}:{} {}:{} {}:{}:{}:{} {}{}v{}".format(
        args.dataset,
        args.dataset_cap,
        args.warmstart_ratio,
        args.shift_strategy,
        args.dirichlet_alpha,
        args.shift_correction,
        args.rlls_reg,
        args.rlls_lambda,
        args.lr,
        "IW " if args.train_iw else "NOIW ",
        "ITIW " if args.iterative_iw else "NOITIW ",
        args.version,
    )

    # Initialize comet.ml; refuse to overwrite an existing run of this name.
    if args.log:
        comet_api = api.API(api_key=comet_ml_key)
        exps = comet_api.get_experiments(
            "ericzhao28",
            project_name="active-label-shift-adaptation",
            pattern=name)
        for exp in exps:
            if exp.get_name() == name:
                raise ValueError("EXP EXISTS!")
    logger = Experiment(comet_ml_key,
                        project_name="active-label-shift-adaptation")
    logger.set_name(name)
    logger.log_parameters(vars(args))

    # Seed the experiment
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    print("Running seed ", seed)
    # Verify CUDA before selecting a device: set_device would otherwise fail
    # first and this check would never be reached.
    assert torch.cuda.is_available()
    torch.cuda.set_device(args.device)

    # Shuffle dataset, then mark every online point as labeled.
    dataset = get_datasets(args)
    dataset.label_ptrs(np.arange(dataset.online_len()))

    # Train h0
    net_cls = get_net_cls(args)
    network = net_cls(args.num_cls).to(args.device)

    def _evaluate_and_log(prefix, epoch):
        """Evaluate on the test split and log the metrics under ``prefix``."""
        accuracy = evaluate(
            network,
            dataset.iterate(args.infer_batch_size, False, split="test"),
            args.device, args, label_weights=dataset.label_weights)
        print(accuracy)
        logger.log_metrics(accuracy, prefix=prefix, step=epoch)

    def log_fn(epoch):
        network.eval()
        _evaluate_and_log("initial", epoch)

    train(network, dataset=dataset, epochs=args.initial_epochs, args=args,
          log_fn=log_fn)

    # Get source shift corrections
    lsmse = label_shift(network, dataset, args)

    def log_fn_shifted(epoch):
        network.eval()
        if args.iterative_iw:
            label_shift(network, dataset, args)
        _evaluate_and_log("shifted", epoch)

    train(network, dataset=dataset, epochs=args.initial_epochs, args=args,
          log_fn=log_fn_shifted)
    if args.iterative_iw:
        lsmse = label_shift(network, dataset, args)
    logger.log_metrics({"IW MSE": lsmse}, prefix="initial")
def experiment(args, logger, name, seed=None):
    """Run one active label-shift experiment and return its metric history.

    Seeds torch / numpy / ``random`` (with ``seed``, or ``args.seed`` when
    ``seed`` is None). It then loads h0 from its checkpoint, or trains and
    saves it, and estimates the label shift. After that it iterates the
    sampler: ``iwal_bootstrap`` when ``args.sampling_strategy == "iwal"``,
    otherwise ``general_sampling``. After the initial model and after every
    yielded network, the latest value of each metric is logged to ``logger``,
    printed, and written with ``save_to_csv``.

    Returns:
        dict mapping metric name to its list of values: "Number of labels",
        "Source shift", "Uniform source shift", plus every key returned by
        ``evaluate``.
    """
    # Seed the experiment
    if seed is None:
        seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    print("Running seed ", seed)
    # torch.cuda.set_device(args.device)
    # assert torch.cuda.is_available()
    if args.reweight:
        assert args.iterative_iw
    assert args.shift_correction in ["rlls", "bbse", "cheat", "none"]

    # Shuffle dataset
    dataset = get_datasets(args)

    # Load
    net_cls = get_net_cls(args)
    network = net_cls(args.num_cls).to(args.device)

    # Train h0
    # milestones = NABIRDS_MILESTONES if args.dataset == "nabirds" else LONG_MILESTONES
    milestones = LONG_MILESTONES

    # Check for a cached h0
    if args.dataset == "nabirds":
        fname = "content/data/%s_%s_%s_%s_%s.h5" % (
            args.seed, args.dataset_cap, args.dataset, args.nabirdstype,
            args.version)
    else:
        fname = "content/data/%s_%s_%s_%s_%s_%s.h5" % (
            args.seed, args.dataset_cap, args.dataset, args.shift_strategy,
            args.warmstart_ratio, args.version)
    # Check, load and save at the same path, so a checkpoint written by one
    # run is found by the next (the save location is home_dir + fname).
    checkpoint_path = home_dir + fname
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path,
                                map_location=torch.device(args.device))
        network.load_state_dict(checkpoint)
        network = network.to(args.device)
    else:
        train(network, dataset=dataset, epochs=args.initial_epochs,
              args=args, milestones=milestones)
        if args.domainsep:
            sep_train(network, dataset=dataset, epochs=args.initial_epochs,
                      args=args)
        os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
        torch.save(network.state_dict(), checkpoint_path)
    label_shift(network, dataset, args)

    # Initialize sampling strategy
    if args.sampling_strategy == "iwal":
        iterator = iwal_bootstrap(network, net_cls, dataset, args)
    else:
        iterator = general_sampling(network, net_cls, dataset, args)

    # Initialize results
    _, initial_labeled_shift, initial_uniform_labeled_shift = \
        measure_composition(dataset)
    initial_accuracy = evaluate(
        network, dataset.iterate(args.infer_batch_size, False, split="test"),
        args.device, args, label_weights=dataset.label_weights)
    num_labeled = [0]
    labeled_shifts = [initial_labeled_shift]
    uniform_labeled_shifts = [initial_uniform_labeled_shift]
    accuracies = {k: [v] for k, v in initial_accuracy.items()}

    def record():
        """Log, print and save the latest metric values; return the metrics."""
        metrics = {
            "Number of labels": num_labeled,
            "Source shift": labeled_shifts,
            "Uniform source shift": uniform_labeled_shifts,
        }
        metrics.update(accuracies)
        step = metrics["Number of labels"][-1]
        for k, v in metrics.items():
            logger.log_metric(k, v[-1], step=step)
        print(metrics)
        save_to_csv(name, metrics, args, logger)
        return metrics

    metrics = record()

    # Begin sampling
    for current in iterator:
        # Evaluate current network
        current.eval()
        accuracy = evaluate(
            current,
            dataset.iterate(args.infer_batch_size, False, split="test"),
            args.device, args, label_weights=dataset.label_weights)
        for k, v in accuracy.items():
            accuracies[k].append(v)
        num_labeled.append(dataset.online_labeled_len())
        new_optimal_weights, labeled_shift, uniform_labeled_shift = \
            measure_composition(dataset)
        labeled_shifts.append(labeled_shift)
        uniform_labeled_shifts.append(uniform_labeled_shift)
        print("Optimal weights for new source", new_optimal_weights)

        # Record metrics
        metrics = record()

    logger.log_metric("Done", True)
    return metrics