Example #1
 def test_featurizer(self, dataset_name):
     """Test that Featurizer() returns a module which can take a
     correctly-sized input and return a correctly-sized output."""
     batch_size = 8
     hparams = hparams_registry.default_hparams('ERM', dataset_name)
     dataset = datasets.get_dataset_class(dataset_name)('', [], hparams)
     input_ = helpers.make_minibatches(dataset, batch_size)[0][0]
     input_shape = dataset.input_shape
     algorithm = networks.Featurizer(input_shape, hparams).cuda()
     output = algorithm(input_)
     self.assertEqual(list(output.shape), [batch_size, algorithm.n_outputs])
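
Most of the examples build on the same pair of calls: hparams_registry.default_hparams for the hyperparameter dictionary and datasets.get_dataset_class for the dataset constructor. A minimal standalone sketch of that call pattern, assuming the DomainBed package layout (domainbed.datasets, domainbed.hparams_registry) and a placeholder data directory:

from domainbed import datasets, hparams_registry

# Default hyperparameters for the ERM algorithm on RotatedMNIST.
hparams = hparams_registry.default_hparams('ERM', 'RotatedMNIST')

# Dataset classes take (root_dir, test_environment_indices, hparams).
dataset = datasets.get_dataset_class('RotatedMNIST')('/path/to/data', [], hparams)
print(dataset.input_shape, dataset.num_classes, len(dataset))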
Example #2
 def test_init_update_predict(self, dataset_name, algorithm_name):
     """Test that a given algorithm inits, updates and predicts without raising
     errors."""
     batch_size = 8
     hparams = hparams_registry.default_hparams(algorithm_name,
                                                dataset_name)
     dataset = datasets.get_dataset_class(dataset_name)('', [], hparams)
     minibatches = helpers.make_minibatches(dataset, batch_size)
     algorithm_class = algorithms.get_algorithm_class(algorithm_name)
     algorithm = algorithm_class(dataset.input_shape, dataset.num_classes,
                                 len(dataset), hparams).cuda()
     for _ in range(3):
         self.assertIsNotNone(algorithm.update(minibatches))
     algorithm.eval()
     self.assertEqual(list(algorithm.predict(minibatches[0][0]).shape),
                      [batch_size, dataset.num_classes])
Example #3
 def test_dataset_erm(self, dataset_name):
     """
     Test that ERM can complete one step on a given dataset without raising
     an error.
     Also test that NUM_ENVIRONMENTS[dataset] is set correctly.
     """
     batch_size = 8
     hparams = hparams_registry.default_hparams('ERM', dataset_name)
     dataset = datasets.get_dataset_class(dataset_name)(
         os.environ['DATA_DIR'], [], hparams)
     self.assertEqual(datasets.NUM_ENVIRONMENTS[dataset_name], len(dataset))
     algorithm = algorithms.get_algorithm_class('ERM')(dataset.input_shape,
                                                       dataset.num_classes,
                                                       len(dataset),
                                                       hparams).cuda()
     minibatches = helpers.make_minibatches(dataset, batch_size)
     algorithm.update(minibatches)
Example #4
    print("Environment:")
    print("\tPython: {}".format(sys.version.split(" ")[0]))
    print("\tPyTorch: {}".format(torch.__version__))
    print("\tTorchvision: {}".format(torchvision.__version__))
    print("\tCUDA: {}".format(torch.version.cuda))
    print("\tCUDNN: {}".format(torch.backends.cudnn.version()))
    print("\tNumPy: {}".format(np.__version__))
    print("\tPIL: {}".format(PIL.__version__))

    print('Args:')
    for k, v in sorted(vars(args).items()):
        print('\t{}: {}'.format(k, v))

    if args.hparams_seed == 0:
        hparams = hparams_registry.default_hparams(args.algorithm,
                                                   args.dataset)
    else:
        hparams = hparams_registry.random_hparams(
            args.algorithm, args.dataset,
            misc.seed_hash(args.hparams_seed, args.trial_seed))
    if args.hparams:
        hparams.update(json.loads(args.hparams))

    print('HParams:')
    for k, v in sorted(hparams.items()):
        print('\t{}: {}'.format(k, v))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
Example #5
    parser.add_argument("--data_dir", type=str)
    parser.add_argument("--output_dir", type=str)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    datasets_to_save = [
        "OfficeHome",
        "TerraIncognita",
        "DomainNet",
        "RotatedMNIST",
        "ColoredMNIST",
        "SVIRO",
    ]

    for dataset_name in tqdm(datasets_to_save):
        hparams = hparams_registry.default_hparams("ERM", dataset_name)
        dataset = datasets.get_dataset_class(dataset_name)(
            args.data_dir,
            list(range(datasets.num_environments(dataset_name))), hparams)
        for env_idx, env in enumerate(tqdm(dataset)):
            for i in tqdm(range(50)):
                idx = random.choice(list(range(len(env))))
                x, y = env[idx]
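                # re-draw samples until the label is at most 10, so only a small set of classes is kept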
                while y > 10:
                    idx = random.choice(list(range(len(env))))
                    x, y = env[idx]
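                # pad 2-channel images (e.g. ColoredMNIST) with a zero channel to get 3 channels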
                if x.shape[0] == 2:
                    x = torch.cat([x, torch.zeros_like(x)], dim=0)[:3, :, :]
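                # negative values indicate the image was normalized; these are the standard ImageNet mean/std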
                if x.min() < 0:
                    mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
                    std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
Example #6
from tqdm import tqdm

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Domain generalization')
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--output_dir', type=str)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    datasets_to_save = [
        'OfficeHome', 'TerraIncognita', 'DomainNet', 'RotatedMNIST',
        'ColoredMNIST'
    ]

    for dataset_name in tqdm(datasets_to_save):
        hparams = hparams_registry.default_hparams('ERM', dataset_name)
        dataset = datasets.get_dataset_class(dataset_name)(
            args.data_dir,
            list(range(datasets.NUM_ENVIRONMENTS[dataset_name])), hparams)
        for env_idx, env in enumerate(tqdm(dataset)):
            for i in tqdm(range(50)):
                idx = random.choice(list(range(len(env))))
                x, y = env[idx]
                while y > 10:
                    idx = random.choice(list(range(len(env))))
                    x, y = env[idx]
                if x.shape[0] == 2:
                    x = torch.cat([x, torch.zeros_like(x)], dim=0)[:3, :, :]
                if x.min() < 0:
                    mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
                    std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
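
Examples #5 and #6 are both cut off just after the ImageNet normalization constants are defined; what presumably follows is a de-normalize-and-save step. A minimal self-contained sketch of that step (the helper name, the file-naming scheme, and the use of torchvision.utils.save_image are assumptions, not code recovered from the original script):

import os
import torch
from torchvision.utils import save_image

def denormalize_and_save(x, path):
    """Undo ImageNet normalization if the tensor looks normalized, then write it as an image."""
    if x.min() < 0:  # negative values indicate normalized input
        mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
        std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]
        x = x * std + mean
    save_image(x.clamp(0, 1), path)

# Inside the loops above, this would be called roughly as:
# denormalize_and_save(x, os.path.join(args.output_dir,
#     '{}_env{}_{}.png'.format(dataset_name, env_idx, i)))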