def test_featurizer(self, dataset_name):
    """Check that ``networks.Featurizer`` produces correctly-sized features.

    A Featurizer is built for ``dataset.input_shape`` using the default ERM
    hyperparameters of ``dataset_name``. It is applied to the inputs of the
    first minibatch returned by ``helpers.make_minibatches``. The resulting
    tensor must have shape ``[batch_size, featurizer.n_outputs]``.

    Args:
        dataset_name: name passed to ``datasets.get_dataset_class``.
    """
    batch_size = 8
    hparams = hparams_registry.default_hparams('ERM', dataset_name)
    dataset_class = datasets.get_dataset_class(dataset_name)
    # An empty root path and no test environments are passed here.
    # NOTE(review): presumably the parameterized datasets need no files on
    # disk -- confirm.
    dataset = dataset_class('', [], hparams)
    # [0][0]: the inputs of the first minibatch.
    first_inputs = helpers.make_minibatches(dataset, batch_size)[0][0]
    featurizer = networks.Featurizer(dataset.input_shape, hparams).cuda()
    features = featurizer(first_inputs)
    self.assertEqual(list(features.shape), [batch_size, featurizer.n_outputs])
def test_init_update_predict(self, dataset_name, algorithm_name):
    """Check that an algorithm can be constructed, updated and queried.

    The test first instantiates ``algorithm_name`` for ``dataset_name`` with
    that pair's default hyperparameters. It then runs three ``update`` steps
    on the same minibatches and asserts that none of them returns None.
    Finally, it switches the model to eval mode and checks that ``predict``
    on the first minibatch's inputs has shape ``[batch_size, num_classes]``.

    Args:
        dataset_name: name passed to ``datasets.get_dataset_class``.
        algorithm_name: name passed to ``algorithms.get_algorithm_class``.
    """
    batch_size = 8
    update_steps = 3
    hparams = hparams_registry.default_hparams(algorithm_name, dataset_name)
    dataset = datasets.get_dataset_class(dataset_name)('', [], hparams)
    minibatches = helpers.make_minibatches(dataset, batch_size)
    algorithm_class = algorithms.get_algorithm_class(algorithm_name)
    model = algorithm_class(
        dataset.input_shape,
        dataset.num_classes,
        len(dataset),
        hparams,
    ).cuda()
    for _step in range(update_steps):
        self.assertIsNotNone(model.update(minibatches))
    model.eval()
    predictions = model.predict(minibatches[0][0])
    expected_shape = [batch_size, dataset.num_classes]
    self.assertEqual(list(predictions.shape), expected_shape)
def test_dataset_erm(self, dataset_name):
    """Check that ERM can take one update step on a real dataset.

    The test also verifies that ``datasets.NUM_ENVIRONMENTS[dataset_name]``
    equals ``len(dataset)`` for the loaded dataset. The dataset root is read
    from the ``DATA_DIR`` environment variable by direct indexing, so the
    variable must be set.

    Args:
        dataset_name: name passed to ``datasets.get_dataset_class``.
    """
    batch_size = 8
    hparams = hparams_registry.default_hparams('ERM', dataset_name)
    dataset_class = datasets.get_dataset_class(dataset_name)
    dataset = dataset_class(os.environ['DATA_DIR'], [], hparams)
    self.assertEqual(datasets.NUM_ENVIRONMENTS[dataset_name], len(dataset))
    erm_class = algorithms.get_algorithm_class('ERM')
    erm = erm_class(
        dataset.input_shape,
        dataset.num_classes,
        len(dataset),
        hparams,
    ).cuda()
    minibatches = helpers.make_minibatches(dataset, batch_size)
    # Success means this single step completes without raising.
    erm.update(minibatches)
# Report library versions so a run's output records the environment it used.
print("Environment:")
print(f"\tPython: {sys.version.split(' ')[0]}")
print(f"\tPyTorch: {torch.__version__}")
print(f"\tTorchvision: {torchvision.__version__}")
print(f"\tCUDA: {torch.version.cuda}")
print(f"\tCUDNN: {torch.backends.cudnn.version()}")
print(f"\tNumPy: {np.__version__}")
print(f"\tPIL: {PIL.__version__}")

# Echo every parsed command-line argument, sorted by name.
print('Args:')
for k, v in sorted(vars(args).items()):
    print(f'\t{k}: {v}')

# hparams_seed == 0 selects the registry defaults. Any other value samples a
# random configuration seeded from (hparams_seed, trial_seed).
hparams = (
    hparams_registry.default_hparams(args.algorithm, args.dataset)
    if args.hparams_seed == 0
    else hparams_registry.random_hparams(
        args.algorithm,
        args.dataset,
        misc.seed_hash(args.hparams_seed, args.trial_seed),
    )
)
# An explicit JSON string in --hparams overrides individual entries.
if args.hparams:
    hparams.update(json.loads(args.hparams))

print('HParams:')
for k, v in sorted(hparams.items()):
    print(f'\t{k}: {v}')

# Seed every RNG source in the same order as before, then request
# deterministic cuDNN kernels.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(args.seed)
torch.backends.cudnn.deterministic = True
parser.add_argument("--data_dir", type=str) parser.add_argument("--output_dir", type=str) args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) datasets_to_save = [ "OfficeHome", "TerraIncognita", "DomainNet", "RotatedMNIST", "ColoredMNIST", "SVIRO", ] for dataset_name in tqdm(datasets_to_save): hparams = hparams_registry.default_hparams("ERM", dataset_name) dataset = datasets.get_dataset_class(dataset_name)( args.data_dir, list(range(datasets.num_environments(dataset_name))), hparams) for env_idx, env in enumerate(tqdm(dataset)): for i in tqdm(range(50)): idx = random.choice(list(range(len(env)))) x, y = env[idx] while y > 10: idx = random.choice(list(range(len(env)))) x, y = env[idx] if x.shape[0] == 2: x = torch.cat([x, torch.zeros_like(x)], dim=0)[:3, :, :] if x.min() < 0: mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None] std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
from tqdm import tqdm if __name__ == '__main__': parser = argparse.ArgumentParser(description='Domain generalization') parser.add_argument('--data_dir', type=str) parser.add_argument('--output_dir', type=str) args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) datasets_to_save = [ 'OfficeHome', 'TerraIncognita', 'DomainNet', 'RotatedMNIST', 'ColoredMNIST' ] for dataset_name in tqdm(datasets_to_save): hparams = hparams_registry.default_hparams('ERM', dataset_name) dataset = datasets.get_dataset_class(dataset_name)( args.data_dir, list(range(datasets.NUM_ENVIRONMENTS[dataset_name])), hparams) for env_idx, env in enumerate(tqdm(dataset)): for i in tqdm(range(50)): idx = random.choice(list(range(len(env)))) x, y = env[idx] while y > 10: idx = random.choice(list(range(len(env)))) x, y = env[idx] if x.shape[0] == 2: x = torch.cat([x, torch.zeros_like(x)], dim=0)[:3, :, :] if x.min() < 0: mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None] std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]