def test_train_and_evaluate(self, dataset_source):
    """Runs one training step end to end and checks a checkpoint is written.

    The run uses the testing model config with batch 64, 2 accumulation
    steps, eval batch 8 and a single total step. Everything is written into a
    temporary directory that is removed when the test finishes.

    Args:
      dataset_source: 'tfds' to train on the 'cifar10' dataset name, or
        'directory' to train on an image directory tree generated on the fly
        (train/test splits, two classes, eight files per class, each holding
        the bytes of JPG_BLACK_1PX).

    Raises:
      ValueError: If `dataset_source` is neither 'tfds' nor 'directory'.
    """
    config = common.get_config()
    config.model = models.get_testing_config()
    config.total_steps = 1
    config.batch_eval = 8
    config.accum_steps = 2
    config.batch = 64

    with tempfile.TemporaryDirectory() as workdir:
      # Reject unsupported sources up front so the branches below only deal
      # with the two known cases.
      if dataset_source not in ('tfds', 'directory'):
        raise ValueError(f'Unknown dataset_source: "{dataset_source}"')

      if dataset_source == 'tfds':
        config.dataset = 'cifar10'
        preprocessing = {
            'train': 'train[:98%]',
            'test': 'test',
            'crop': 224
        }
        config.pp = ml_collections.ConfigDict(preprocessing)
      else:
        dataset_root = os.path.join(workdir, 'dataset')
        config.dataset = dataset_root
        config.pp = ml_collections.ConfigDict({'crop': 224})
        class_dirs = [
            os.path.join(dataset_root, split, label)
            for split in ('train', 'test')
            for label in ('test1', 'test2')
        ]
        for class_dir in class_dirs:
          os.makedirs(class_dir, exist_ok=True)
          for index in range(8):
            image_path = os.path.join(class_dir, f'{index}.jpg')
            with open(image_path, 'wb') as image_file:
              image_file.write(JPG_BLACK_1PX)

      # The checkpoint is written as `testing.npz` into the directory that
      # `config.pretrained_dir` points at.
      config.pretrained_dir = workdir
      test_utils.create_checkpoint(config.model, f'{workdir}/testing.npz')

      train.train_and_evaluate(config, workdir)
      checkpoint_written = os.path.exists(f'{workdir}/checkpoint_1')
      self.assertTrue(checkpoint_written)
 def test_load_pretrained(self):
     """Smoke-tests `checkpoint.load_pretrained` on a freshly written checkpoint.

     A checkpoint for the testing model config is created, a two-class
     `VisionTransformer` is initialised on a `[1, 32, 32, 3]` float32 input of
     ones, and the checkpoint is loaded against those initial parameters. The
     return value is not inspected; the test passes if nothing raises.
     """
     model_config = config_lib.get_testing_config()
     # Use a private directory rather than the shared system temp dir so the
     # checkpoint file cannot collide with other tests or concurrent runs and
     # is removed once the test is done.
     with tempfile.TemporaryDirectory() as tempdir:
         checkpoint_path = f'{tempdir}/testing.npz'
         test_utils.create_checkpoint(model_config, checkpoint_path)
         model = models.VisionTransformer(num_classes=2, **model_config)
         variables = model.init(
             jax.random.PRNGKey(0),
             inputs=jnp.ones([1, 32, 32, 3], jnp.float32),
             train=False,
         )
         checkpoint.load_pretrained(
             pretrained_path=checkpoint_path,
             init_params=variables['params'],
             model_config=model_config,
         )
예제 #3
0
  def test_main(self):
    """Runs `inference_time.inference_time` and checks an event file appears.

    The config is set to 10 classes, image size 224, batch 8 and the
    'testing' model name. A matching checkpoint is written as `testing.npz`
    into the work directory, which is also used as `config.pretrained_dir`.
    """
    config = config_lib.get_config()
    config.num_classes = 10
    config.image_size = 224
    config.batch = 8
    config.model_name = 'testing'
    model_config = models.get_testing_config()

    # A private, initially empty work directory is essential here: the
    # assertion globs for event files, and in the shared system temp dir a
    # leftover file from an earlier run would satisfy it even if this run
    # wrote nothing.
    with tempfile.TemporaryDirectory() as workdir:
      config.pretrained_dir = workdir
      test_utils.create_checkpoint(model_config, f'{workdir}/testing.npz')
      inference_time.inference_time(config, workdir)
      self.assertNotEmpty(glob.glob(f'{workdir}/events.out.tfevents.*'))
예제 #4
0
  def test_train_and_evaluate(self):
    """Runs one training step on 'cifar10' and checks `model.npz` is written.

    The run uses the testing model config with batch 64, 2 accumulation
    steps, eval batch 8 and a single total step; preprocessing resizes to 448
    and crops to 384. A matching checkpoint is written as `testing.npz` into
    the work directory, which is also used as `config.pretrained_dir`.
    """
    config = common.get_config()
    config.model = models.get_testing_config()
    config.dataset = 'cifar10'
    config.pp = ml_collections.ConfigDict(
        {'train': 'train[:98%]', 'test': 'test', 'resize': 448, 'crop': 384})
    config.batch = 64
    config.accum_steps = 2
    config.batch_eval = 8
    config.total_steps = 1

    # A private, initially empty work directory keeps the existence check
    # honest: in the shared system temp dir a `model.npz` left by an earlier
    # run would make the assertion pass regardless of what this run did.
    with tempfile.TemporaryDirectory() as workdir:
      config.pretrained_dir = workdir
      test_utils.create_checkpoint(config.model, f'{workdir}/testing.npz')
      # The return value is not needed; only the written artifact is checked.
      train.train_and_evaluate(config, workdir)
      self.assertTrue(os.path.exists(f'{workdir}/model.npz'))