def test_train_and_evaluate(self, dataset_source):
  """Runs one training step on a tiny model and checks a checkpoint is saved.

  Args:
    dataset_source: Either 'tfds' (a slice of the CIFAR-10 train split) or
      'directory' (a synthetic image-folder dataset written into the work
      directory).

  Raises:
    ValueError: If `dataset_source` is not one of the values above.
  """
  config = common.get_config()
  config.model = models.get_testing_config()
  # Global batch of 64 split over 2 accumulation steps; a single total step
  # keeps the test fast.
  config.batch = 64
  config.accum_steps = 2
  config.batch_eval = 8
  config.total_steps = 1

  with tempfile.TemporaryDirectory() as workdir:
    if dataset_source == 'tfds':
      config.dataset = 'cifar10'
      config.pp = ml_collections.ConfigDict({
          'train': 'train[:98%]',
          'test': 'test',
          'crop': 224,
      })
    elif dataset_source == 'directory':
      dataset_root = os.path.join(workdir, 'dataset')
      config.dataset = dataset_root
      config.pp = ml_collections.ConfigDict({'crop': 224})
      # Lay out <root>/<split>/<class>/<index>.jpg, 8 images per class,
      # each holding the JPG_BLACK_1PX bytes.
      image_paths = (
          os.path.join(dataset_root, split, label, f'{index}.jpg')
          for split in ('train', 'test')
          for label in ('test1', 'test2')
          for index in range(8))
      for image_path in image_paths:
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        with open(image_path, 'wb') as fp:
          fp.write(JPG_BLACK_1PX)
    else:
      raise ValueError(f'Unknown dataset_source: "{dataset_source}"')

    config.pretrained_dir = workdir
    test_utils.create_checkpoint(config.model, f'{workdir}/testing.npz')
    _ = train.train_and_evaluate(config, workdir)
    self.assertTrue(os.path.exists(f'{workdir}/checkpoint_1'))
def test_load_pretrained(self):
  """Smoke-tests loading a synthetic checkpoint into freshly initialized params.

  Writes a testing checkpoint, initializes a 2-class VisionTransformer on a
  single all-ones [1, 32, 32, 3] input, and calls
  `checkpoint.load_pretrained` with the initialized params. The test passes
  if no exception is raised.
  """
  model_config = config_lib.get_testing_config()
  # Use a private directory instead of the shared system temp dir. This keeps
  # the checkpoint away from other tests/runs that write the same filename,
  # and it is removed when the test finishes.
  with tempfile.TemporaryDirectory() as tempdir:
    pretrained_path = f'{tempdir}/testing.npz'
    test_utils.create_checkpoint(model_config, pretrained_path)
    model = models.VisionTransformer(num_classes=2, **model_config)
    variables = model.init(
        jax.random.PRNGKey(0),
        inputs=jnp.ones([1, 32, 32, 3], jnp.float32),
        train=False,
    )
    checkpoint.load_pretrained(
        pretrained_path=pretrained_path,
        init_params=variables['params'],
        model_config=model_config)
def test_main(self):
  """Runs the inference-time benchmark and checks it writes summary events.

  A testing checkpoint is written into an isolated work directory, and
  `inference_time.inference_time` is run against it. The test then asserts
  that at least one `events.out.tfevents.*` file was produced there.
  """
  config = config_lib.get_config()
  config.num_classes = 10
  config.image_size = 224
  config.batch = 8
  config.model_name = 'testing'
  model_config = models.get_testing_config()
  # A fresh directory is required for the assertion to mean anything. In the
  # shared system temp dir, event files left by earlier runs would satisfy
  # the glob even if this run wrote nothing.
  with tempfile.TemporaryDirectory() as workdir:
    config.pretrained_dir = workdir
    test_utils.create_checkpoint(model_config, f'{workdir}/testing.npz')
    inference_time.inference_time(config, workdir)
    self.assertNotEmpty(glob.glob(f'{workdir}/events.out.tfevents.*'))
def test_train_and_evaluate(self):
  """Runs one training step on CIFAR-10 and checks `model.npz` is written.

  Uses the testing model config with a slice of the CIFAR-10 train split,
  resize 448 / crop 384 preprocessing, and a single total step.
  """
  config = common.get_config()
  config.model = models.get_testing_config()
  config.dataset = 'cifar10'
  config.pp = ml_collections.ConfigDict(
      {'train': 'train[:98%]', 'test': 'test', 'resize': 448, 'crop': 384})
  config.batch = 64
  config.accum_steps = 2
  config.batch_eval = 8
  config.total_steps = 1
  # An isolated work directory ensures the `model.npz` assertion checks this
  # run's output, not a file left in the shared temp dir by an earlier run.
  with tempfile.TemporaryDirectory() as workdir:
    config.pretrained_dir = workdir
    test_utils.create_checkpoint(config.model, f'{workdir}/testing.npz')
    _ = train.train_and_evaluate(config, workdir)
    self.assertTrue(os.path.exists(f'{workdir}/model.npz'))