def load_model(filename, map_location="cpu"):
    """Rebuild an algorithm instance from a saved checkpoint dictionary.

    Args:
        filename: Path (or file-like object) accepted by ``torch.load``. The
            loaded object must be a mapping with the keys ``"args"`` (itself
            containing ``"algorithm"``), ``"model_input_shape"``,
            ``"model_num_classes"``, ``"model_num_domains"``,
            ``"model_hparams"`` and ``"model_dict"``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written from a GPU run can be loaded on a machine
            without CUDA. The saved tensors are copied into the newly
            constructed model's parameters by ``load_state_dict``, so this
            only controls where the checkpoint is staged while loading. Pass
            ``None`` to restore tensors to the devices they were saved from.

    Returns:
        The algorithm object built from the stored constructor arguments,
        with the stored state dict loaded into it.

    Raises:
        KeyError: If any of the keys listed above is missing from the dump.
    """
    dump = torch.load(filename, map_location=map_location)

    # The algorithm name saved with the run's arguments selects the class to
    # instantiate; the remaining entries are its constructor arguments.
    algorithm_class = algorithms.get_algorithm_class(dump["args"]["algorithm"])
    algorithm = algorithm_class(
        dump["model_input_shape"],
        dump["model_num_classes"],
        dump["model_num_domains"],
        dump["model_hparams"],
    )
    algorithm.load_state_dict(dump["model_dict"])
    return algorithm
def test_init_update_predict(self, dataset_name, algorithm_name):
    """Smoke-test one algorithm on one dataset end to end.

    Builds the algorithm with its default hyperparameters, runs three update
    steps (each must return a non-None value), then switches to eval mode and
    checks that predicting on the first minibatch's inputs yields an output
    of shape ``[batch, num_classes]``.
    """
    per_env_batch = 8
    defaults = hparams_registry.default_hparams(algorithm_name, dataset_name)

    # Empty root path and empty list as the second argument (presumably the
    # test environments) -- NOTE(review): assumes the datasets exercised here
    # do not need files on disk; confirm against the test parameterization.
    ds = datasets.get_dataset_class(dataset_name)('', [], defaults)
    batches = helpers.make_minibatches(ds, per_env_batch)

    # len(ds) is passed as the third constructor argument (the domain count).
    model = algorithms.get_algorithm_class(algorithm_name)(
        ds.input_shape, ds.num_classes, len(ds), defaults).cuda()

    for _step in range(3):
        self.assertIsNotNone(model.update(batches))

    model.eval()
    first_inputs = batches[0][0]
    predicted_shape = list(model.predict(first_inputs).shape)
    self.assertEqual(predicted_shape, [per_env_batch, ds.num_classes])
def test_dataset_erm(self, dataset_name):
    """Check that ERM completes a single update step on ``dataset_name``.

    Also checks that ``datasets.NUM_ENVIRONMENTS[dataset_name]`` equals the
    number of environments (``len()``) of the constructed dataset. Reads the
    dataset root from the ``DATA_DIR`` environment variable.
    """
    per_env_batch = 8
    erm_hparams = hparams_registry.default_hparams('ERM', dataset_name)

    # Resolve the class before reading DATA_DIR, matching the original
    # evaluation order of the one-line constructor call.
    dataset_cls = datasets.get_dataset_class(dataset_name)
    ds = dataset_cls(os.environ['DATA_DIR'], [], erm_hparams)

    registered_envs = datasets.NUM_ENVIRONMENTS[dataset_name]
    self.assertEqual(registered_envs, len(ds))

    erm_cls = algorithms.get_algorithm_class('ERM')
    model = erm_cls(ds.input_shape, ds.num_classes, len(ds), erm_hparams).cuda()

    batches = helpers.make_minibatches(ds, per_env_batch)
    model.update(batches)
batch_size=64, num_workers=dataset.N_WORKERS) for env, _ in (in_splits + out_splits + uda_splits) ] eval_weights = [ None for _, weights in (in_splits + out_splits + uda_splits) ] eval_loader_names = ['env{}_in'.format(i) for i in range(len(in_splits))] eval_loader_names += [ 'env{}_out'.format(i) for i in range(len(out_splits)) ] eval_loader_names += [ 'env{}_uda'.format(i) for i in range(len(uda_splits)) ] algorithm_class = algorithms.get_algorithm_class(args.algorithm) algorithm = algorithm_class(dataset.input_shape, dataset.num_classes, len(dataset) - len(args.test_envs), hparams) if algorithm_dict is not None: algorithm.load_state_dict(algorithm_dict) algorithm.to(device) if hasattr(algorithm, 'network'): algorithm.network = DataParallelPassthrough(algorithm.network) else: for m in algorithm.children(): m = DataParallelPassthrough(m) train_minibatches_iterator = zip(*train_loaders) uda_minibatches_iterator = zip(*uda_loaders)