def print_results_tables(records, selection_method, latex):
    """Given all records, print a results table for each dataset."""
    # Attach each group's sweep accuracy and drop groups that have none.
    def _with_sweep_acc(group):
        return {**group,
                'sweep_acc': selection_method.sweep_acc(group['records'])}

    grouped_records = (get_grouped_records(records)
                       .map(_with_sweep_acc)
                       .filter(lambda g: g['sweep_acc'] is not None))

    # Algorithm names in the predefined order, then any extras.
    alg_names = Q(records).select('args.algorithm').unique()
    known = [n for n in algorithms.ALGORITHMS if n in alg_names]
    extras = [n for n in alg_names if n not in algorithms.ALGORITHMS]
    alg_names = known + extras

    # Dataset names in lexicographic order.
    dataset_names = Q(records).select('args.dataset').unique().sorted()

    # One results table per dataset: rows are algorithms, columns test envs.
    for dataset in dataset_names:
        test_envs = range(datasets.NUM_ENVIRONMENTS[dataset])
        table = [[None] * len(test_envs) for _ in alg_names]
        for row, algorithm in enumerate(alg_names):
            for col, test_env in enumerate(test_envs):
                trial_accs = (grouped_records
                              .filter_equals(
                                  'dataset, algorithm, test_env',
                                  (dataset, algorithm, test_env))
                              .select('sweep_acc'))
                table[row][col] = format_mean(trial_accs, latex)

        col_labels = ['Algorithm',
                      *datasets.get_dataset_class(dataset).ENVIRONMENT_NAMES]
        header_text = (f'Dataset: {dataset}, '
                       f'model selection method: {selection_method.name}')
        print_table(table, header_text, alg_names, list(col_labels),
                    colwidth=20, latex=latex)

    # Print an 'averages' table: one column per dataset, averaging the
    # per-trial-seed sweep accuracies.
    table = [[None] * len(dataset_names) for _ in alg_names]
    for row, algorithm in enumerate(alg_names):
        for col, dataset in enumerate(dataset_names):
            trial_averages = (grouped_records
                              .filter_equals('algorithm, dataset',
                                             (algorithm, dataset))
                              .group('trial_seed')
                              .map(lambda trial_seed, group:
                                   group.select('sweep_acc').mean()))
            table[row][col] = format_mean(trial_averages, latex)

    col_labels = ['Algorithm', *dataset_names]
    header_text = f'Averages, model selection method: {selection_method.name}'
    print_table(table, header_text, alg_names, col_labels, colwidth=25,
                latex=latex)
def test_featurizer(self, dataset_name):
    """Test that Featurizer() returns a module which can take a
    correctly-sized input and return a correctly-sized output."""
    batch_size = 8
    hparams = hparams_registry.default_hparams('ERM', dataset_name)
    dataset = datasets.get_dataset_class(dataset_name)('', [], hparams)
    # The first batch of the first environment supplies the sample input.
    x = helpers.make_minibatches(dataset, batch_size)[0][0]
    featurizer = networks.Featurizer(dataset.input_shape, hparams).cuda()
    features = featurizer(x)
    self.assertEqual(list(features.shape),
                     [batch_size, featurizer.n_outputs])
def test_init_update_predict(self, dataset_name, algorithm_name):
    """Test that a given algorithm inits, updates and predicts without
    raising errors."""
    batch_size = 8
    hparams = hparams_registry.default_hparams(algorithm_name, dataset_name)
    dataset = datasets.get_dataset_class(dataset_name)('', [], hparams)
    minibatches = helpers.make_minibatches(dataset, batch_size)
    # Instantiate the algorithm under test on the GPU.
    algorithm = algorithms.get_algorithm_class(algorithm_name)(
        dataset.input_shape, dataset.num_classes, len(dataset),
        hparams).cuda()
    # Run a few update steps; each must return a non-None result.
    for _ in range(3):
        step_vals = algorithm.update(minibatches)
        self.assertIsNotNone(step_vals)
    algorithm.eval()
    predictions = algorithm.predict(minibatches[0][0])
    self.assertEqual(list(predictions.shape),
                     [batch_size, dataset.num_classes])
def test_dataset_erm(self, dataset_name):
    """
    Test that ERM can complete one step on a given dataset without raising
    an error. Also test that NUM_ENVIRONMENTS[dataset] is set correctly.
    """
    batch_size = 8
    hparams = hparams_registry.default_hparams('ERM', dataset_name)
    dataset = datasets.get_dataset_class(dataset_name)(
        os.environ['DATA_DIR'], [], hparams)
    # The registered environment count must match the actual dataset.
    self.assertEqual(datasets.NUM_ENVIRONMENTS[dataset_name], len(dataset))
    erm_class = algorithms.get_algorithm_class('ERM')
    erm = erm_class(dataset.input_shape, dataset.num_classes,
                    len(dataset), hparams).cuda()
    minibatches = helpers.make_minibatches(dataset, batch_size)
    erm.update(minibatches)
def todo_rename(records, selection_method, latex):
    """Print a per-dataset results table (plus an overall averages table)
    for all records under the given model-selection method."""
    # Attach each group's sweep accuracy; drop groups without one.
    grouped_records = reporting.get_grouped_records(records).map(
        lambda group: dict(
            group,
            sweep_acc=selection_method.sweep_acc(group["records"]))
    ).filter(lambda g: g["sweep_acc"] is not None)

    # Algorithms in their predefined order, then any unknown ones.
    alg_names = Q(records).select("args.algorithm").unique()
    alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names]
                 + [n for n in alg_names if n not in algorithms.ALGORITHMS])

    # Datasets present in the records, in canonical datasets.DATASETS order.
    dataset_names = Q(records).select("args.dataset").unique().sorted()
    dataset_names = [d for d in datasets.DATASETS if d in dataset_names]

    for dataset in dataset_names:
        if latex:
            print()
            print("\\subsubsection{{{}}}".format(dataset))
        test_envs = range(datasets.num_environments(dataset))

        # One row per algorithm; one column per test env plus an "Avg" column.
        table = [[None] * (len(test_envs) + 1) for _ in alg_names]
        for row, algorithm in enumerate(alg_names):
            env_means = []
            for col, test_env in enumerate(test_envs):
                trial_accs = (grouped_records
                              .filter_equals(
                                  "dataset, algorithm, test_env",
                                  (dataset, algorithm, test_env))
                              .select("sweep_acc"))
                mean, err, table[row][col] = format_mean(trial_accs, latex)
                env_means.append(mean)
            # "X" marks rows with any missing env result.
            if None in env_means:
                table[row][-1] = "X"
            else:
                table[row][-1] = "{:.1f}".format(
                    sum(env_means) / len(env_means))

        col_labels = [
            "Algorithm",
            *datasets.get_dataset_class(dataset).ENVIRONMENTS,
            "Avg"
        ]
        header_text = (f"Dataset: {dataset}, "
                       f"model selection method: {selection_method.name}")
        print_table(table, header_text, alg_names, list(col_labels),
                    colwidth=20, latex=latex)

    # Print an "averages" table
    if latex:
        print()
        print("\\subsubsection{Averages}")

    table = [[None] * (len(dataset_names) + 1) for _ in alg_names]
    for row, algorithm in enumerate(alg_names):
        dataset_means = []
        for col, dataset in enumerate(dataset_names):
            # Average each trial seed's sweep accuracies, then summarize.
            trial_averages = (grouped_records
                              .filter_equals("algorithm, dataset",
                                             (algorithm, dataset))
                              .group("trial_seed")
                              .map(lambda trial_seed, group:
                                   group.select("sweep_acc").mean()))
            mean, err, table[row][col] = format_mean(trial_averages, latex)
            dataset_means.append(mean)
        if None in dataset_means:
            table[row][-1] = "X"
        else:
            table[row][-1] = "{:.1f}".format(
                sum(dataset_means) / len(dataset_means))

    col_labels = ["Algorithm", *dataset_names, "Avg"]
    header_text = f"Averages, model selection method: {selection_method.name}"
    print_table(table, header_text, alg_names, col_labels, colwidth=25,
                latex=latex)
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)

datasets_to_save = [
    "OfficeHome",
    "TerraIncognita",
    "DomainNet",
    "RotatedMNIST",
    "ColoredMNIST",
    "SVIRO",
]

for dataset_name in tqdm(datasets_to_save):
    hparams = hparams_registry.default_hparams("ERM", dataset_name)
    env_ids = list(range(datasets.num_environments(dataset_name)))
    dataset = datasets.get_dataset_class(dataset_name)(
        args.data_dir, env_ids, hparams)
    for env_idx, env in enumerate(tqdm(dataset)):
        for _ in tqdm(range(50)):
            # Rejection-sample an example whose label is at most 10.
            idx = random.choice(range(len(env)))
            x, y = env[idx]
            while y > 10:
                idx = random.choice(range(len(env)))
                x, y = env[idx]
            if x.shape[0] == 2:
                # Zero-pad two-channel images up to three channels.
                x = torch.cat([x, torch.zeros_like(x)], dim=0)[:3, :, :]
            if x.min() < 0:
                # Negative pixels suggest ImageNet normalization; undo it
                # so all values are non-negative again.
                mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
                std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
                x = (x * std) + mean
            assert x.min() >= 0
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Domain generalization')
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--output_dir', type=str)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    datasets_to_save = [
        'OfficeHome',
        'TerraIncognita',
        'DomainNet',
        'RotatedMNIST',
        'ColoredMNIST',
    ]

    for dataset_name in tqdm(datasets_to_save):
        hparams = hparams_registry.default_hparams('ERM', dataset_name)
        env_ids = list(range(datasets.NUM_ENVIRONMENTS[dataset_name]))
        dataset = datasets.get_dataset_class(dataset_name)(
            args.data_dir, env_ids, hparams)
        for env_idx, env in enumerate(tqdm(dataset)):
            for _ in tqdm(range(50)):
                # Rejection-sample an example whose label is at most 10.
                idx = random.choice(range(len(env)))
                x, y = env[idx]
                while y > 10:
                    idx = random.choice(range(len(env)))
                    x, y = env[idx]
                if x.shape[0] == 2:
                    # Zero-pad two-channel images up to three channels.
                    x = torch.cat([x, torch.zeros_like(x)], dim=0)[:3, :, :]
                if x.min() < 0:
                    # Negative pixels suggest ImageNet normalization; undo
                    # it so all values are non-negative again.
                    mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
                    std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
                    x = (x * std) + mean
                assert x.min() >= 0
def print_results_tables(records, selection_method, latex):
    """Given all records, print a results table for each dataset.

    Args:
        records: Q of run records loaded from sweep output.
        selection_method: model-selection object providing sweep_accs()
            and a human-readable .name.
        latex: if True, emit LaTeX markup around each dataset's table.
    """
    # Annotate each (dataset, algorithm, test_env) group with its sweep
    # accuracies under the given model-selection method.
    grouped_records = reporting.get_grouped_records(records).map(
        lambda group: {
            **group,
            "sweep_accs": selection_method.sweep_accs(group["records"])
        })

    # Read algorithm names and sort (predefined order first, extras after).
    alg_names = Q(records).select("args.algorithm").unique()
    alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] +
                 [n for n in alg_names if n not in algorithms.ALGORITHMS])

    # Read dataset names; keep them in canonical datasets.DATASETS order.
    dataset_names = Q(records).select("args.dataset").unique().sorted()
    dataset_names = [d for d in datasets.DATASETS if d in dataset_names]

    for dataset in dataset_names:
        if latex:
            print()
            print("\\subsubsection{{{}}}".format(dataset))
        test_envs = range(datasets.num_environments(dataset))

        table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names]
        for i, algorithm in enumerate(alg_names):
            means = []
            stdevs = []
            for j, test_env in enumerate(test_envs):
                # Best-effort lookup: a missing or empty group yields NaN so
                # the rest of the table still prints. (Bug fix: this was a
                # bare `except:`, which also swallowed KeyboardInterrupt and
                # SystemExit.)
                try:
                    acc = grouped_records.filter_equals(
                        "dataset, algorithm, test_env",
                        (dataset, algorithm, test_env)
                    )[0]['sweep_accs'][0]
                    mean = acc['test_acc']
                    stdev = acc['test_acc_std']
                except Exception:
                    mean = float('nan')
                    stdev = float('nan')
                means.append(mean)
                stdevs.append(stdev)
                _, _, table[i][j] = format_mean(mean, stdev, latex)
            # "Avg" column: mean of the env means; stdevs combine in
            # quadrature (standard error of the average).
            avg_mean = np.mean(means)
            avg_stdev = np.sqrt(np.sum(np.array(stdevs)**2)) / len(stdevs)
            _, _, table[i][-1] = format_mean(avg_mean, avg_stdev, latex)

        col_labels = [
            "Algorithm",
            *datasets.get_dataset_class(dataset).ENVIRONMENTS,
            "Avg"
        ]
        header_text = (f"Dataset: {dataset}, "
                       f"model selection method: {selection_method.name}")
        print_table(table, header_text, alg_names, list(col_labels),
                    colwidth=20, latex=latex)

    # Print an "averages" table
    if latex:
        print()
        print("\\subsubsection{Averages}")

    table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names]
    for i, algorithm in enumerate(alg_names):
        means = []
        for j, dataset in enumerate(dataset_names):
            # Same best-effort policy as above (bug fix: was a bare except).
            try:
                mean = (grouped_records
                        .filter_equals("algorithm, dataset",
                                       (algorithm, dataset))
                        .select(lambda x: x['sweep_accs'][0]['test_acc'])
                        .mean())
            except Exception:
                mean = float('nan')
            mean *= 100.
            table[i][j] = "{:.1f}".format(mean)
            means.append(mean)
        # Guard against an empty dataset list (would be ZeroDivisionError).
        table[i][-1] = ("{:.1f}".format(sum(means) / len(means))
                        if means else "X")

    col_labels = ["Algorithm", *dataset_names, "Avg"]
    header_text = f"Averages, model selection method: {selection_method.name}"
    print_table(table, header_text, alg_names, col_labels, colwidth=25,
                latex=latex)