예제 #1
0
  def test_cpu(self):
    """Runs full MNIST CPU training and reports regression metrics."""
    model_dir = self.get_tmp_model_dir()
    start = time.time()
    mnist_lib.train_and_evaluate(
        config=config_lib.get_config(), model_dir=model_dir)
    elapsed = time.time() - start

    # The eval summaries carry everything needed for the regression metrics.
    summaries = self.read_summaries(model_dir)
    times, _, accuracies = zip(*summaries['eval_accuracy'])
    times = np.array(times)
    # Mean wall-clock gap between consecutive eval summaries.
    sec_per_epoch = np.mean(np.diff(times))
    final_accuracy = accuracies[-1]

    # Assertions are deferred until the test finishes, so the metrics are
    # always reported and benchmark success is judged on all assertions.
    self.assertBetween(final_accuracy, 0.98, 1.0)

    # Report metrics/extras through the benchmark reporting API.
    self.report_wall_time(elapsed)
    self.report_metrics({
        'sec_per_epoch': sec_per_epoch,
        'accuracy': final_accuracy,
    })
    self.report_extras({
        'model_name': 'MNIST',
        'description': 'CPU test for MNIST.'
    })
예제 #2
0
def parse_option():
    """Parses command-line flags and returns the run configuration."""
    parser = argparse.ArgumentParser(
        description='Train image-based re-id model')
    add = parser.add_argument
    add('--cfg',
        type=str,
        required=True,
        metavar="FILE",
        help='path to config file')
    # Datasets
    add('--root', type=str, help="your root path to data directory")
    add('--dataset',
        type=str,
        help="market1501, cuhk03, dukemtmcreid, msmt17")
    # Miscs
    add('--output', type=str, help="your output path to save model and logs")
    add('--resume', type=str, metavar='PATH')
    add('--eval', action='store_true', help="evaluation only")
    add('--tag', type=str, help='tag for log file')
    add('--gpu',
        default='0',
        type=str,
        help='gpu device ids for CUDA_VISIBLE_DEVICES')

    # Unknown flags are tolerated; only the known ones feed the config.
    known_args, _ = parser.parse_known_args()
    return get_config(known_args)
예제 #3
0
    def test_train_and_evaluate(self):
        """Runs one train/eval/predict step of a tiny model on mocked data."""
        # Shrink the model so the single step finishes quickly.
        config = default.get_config()
        config.num_layers = 1
        config.qkv_dim = 128
        config.emb_dim = 128
        config.mlp_dim = 512
        config.num_heads = 2

        # Tiny corpus/vocab and exactly one step of each phase.
        config.max_corpus_chars = 1000
        config.vocab_size = 32
        config.batch_size = 8
        config.num_train_steps = 1
        config.num_eval_steps = 1
        config.num_predict_steps = 1
        config.max_target_length = 32
        config.max_eval_target_length = 32
        config.max_predict_length = 32

        workdir = tempfile.mkdtemp()

        # Go two directories up to the root of the flax directory; mocked
        # TFDS metadata lives under its .tfds directory.
        flax_root_dir = pathlib.Path(__file__).parents[2]
        data_dir = str(flax_root_dir) + '/.tfds/metadata'

        with tfds.testing.mock_data(num_examples=128, data_dir=data_dir):
            train.train_and_evaluate(config, workdir)
        logging.info('workdir content: %s', tf.io.gfile.listdir(workdir))
예제 #4
0
def get_test_config():
    """Returns a tiny configuration suitable for fast tests."""
    config = default.get_config()
    config.num_epochs = 1
    config.init_batch_size = 8
    config.batch_size = 8
    # NOTE(review): 'n_resent' looks like a possible typo (n_resnet?) —
    # confirm against the default config's field names.
    config.n_resent = 1
    config.n_feature = 8
    return config
예제 #5
0
def get_config():
  """Returns hyperparameters for training on 8 x Nvidia V100 GPUs."""
  # Start from the defaults so field definitions are not duplicated.
  config = default_lib.get_config()
  config.cache = True
  config.batch_size = 512
  return config
예제 #6
0
def get_config():
    """Returns hyperparameters for the fake-data benchmark."""
    # Start from the defaults so field definitions are not duplicated.
    config = default_lib.get_config()
    config.half_precision = True
    config.num_epochs = 5
    # batch_size must be set before the derived step counts below.
    config.batch_size = 256 * jax.device_count()

    # The input pipeline used to derive these from dataset sizes:
    #   steps_per_epoch = input_pipeline.TRAIN_IMAGES // batch_size
    #   steps_per_eval  = input_pipeline.EVAL_IMAGES  // batch_size
    # NOTE(review): with batch_size > 1024 (i.e. more than 4 devices) the
    # floor division yields 0 steps — confirm that is intended.
    config.num_train_steps = 1024 // config.batch_size
    config.steps_per_eval = 512 // config.batch_size

    return config
예제 #7
0
def get_config():
    """Returns default hyperparameters for the deepvoxels 'greek' scene."""
    cfg = default_lib.get_config()

    cfg.dataset_type = "deepvoxels"
    cfg.shape = "greek"

    # Sampling and rendering settings.
    cfg.use_viewdirs = True
    cfg.white_bkgd = True
    cfg.batching = False
    cfg.num_rand = 1024
    cfg.num_samples = 64
    cfg.num_importance = 64

    return cfg
예제 #8
0
  def test_train_and_evaluate(self):
    """Single-step training/evaluation smoke test on mocked data."""
    # Tensorboard metrics land in a throwaway directory.
    workdir = tempfile.mkdtemp()

    # Go two directories up to the root of the flax directory; the mocked
    # TFDS metadata lives under its .tfds directory.
    flax_root_dir = pathlib.Path(__file__).parents[2]
    data_dir = str(flax_root_dir) + "/.tfds/metadata"

    # Minimal training configuration: one epoch with tiny batches.
    config = default.get_config()
    config.batch_size = 8
    config.num_epochs = 1

    with tfds.testing.mock_data(num_examples=8, data_dir=data_dir):
      train.train_and_evaluate(config=config, workdir=workdir)
예제 #9
0
def get_config():
    """Returns default hyperparameters for an llff dataset."""
    cfg = default_lib.get_config()

    cfg.dataset_type = "llff"
    cfg.llff.factor = 8
    cfg.llff.hold = 8

    # Sampling and rendering settings.
    cfg.use_viewdirs = True
    cfg.raw_noise_std = 1.0
    cfg.num_rand = 1024
    cfg.num_samples = 64
    cfg.num_importance = 64

    return cfg
예제 #10
0
def get_config():
    """Returns default hyperparameters for a blender dataset."""
    cfg = default_lib.get_config()

    cfg.dataset_type = "blender"

    # Optimization schedule.
    cfg.num_steps = 500000
    cfg.lr_decay = 500

    # Sampling and rendering settings.
    cfg.use_viewdirs = True
    cfg.white_bkgd = True
    cfg.batching = False
    cfg.num_rand = 1024
    cfg.num_samples = 64
    cfg.num_importance = 128

    return cfg
예제 #11
0
    def test_train_and_evaluate(self):
        """Runs one epoch of MNIST training/eval against mocked data."""
        # Tensorboard metrics go to a temporary directory.
        model_dir = tempfile.mkdtemp()

        # Go two directories up to the root of the flax directory; mocked
        # TFDS metadata lives under its .tfds directory.
        flax_root_dir = pathlib.Path(__file__).parents[2]
        data_dir = str(flax_root_dir) + '/.tfds/metadata'

        # Minimal configuration: a single epoch with small batches.
        config = config_lib.get_config()
        config.batch_size = 8
        config.num_epochs = 1

        with tfds.testing.mock_data(num_examples=8, data_dir=data_dir):
            mnist_lib.train_and_evaluate(config=config, model_dir=model_dir)
예제 #12
0
def get_config():
    """Returns default hyperparameters for the deepvoxels 'cube' scene."""
    cfg = default_lib.get_config()

    cfg.dataset_type = "deepvoxels"
    cfg.shape = "cube"

    # Optimization schedule.
    cfg.num_steps = 200000
    cfg.lr_decay = 200

    # Sampling and rendering settings.
    cfg.use_viewdirs = True
    cfg.white_bkgd = True
    cfg.batching = False
    cfg.num_rand = 4096
    cfg.num_samples = 64
    cfg.num_importance = 128

    return cfg
예제 #13
0
  def test_train_and_evaluate(self):
    """One-step ImageNet training/evaluation loop over mocked data."""
    # Tensorboard metrics are written to a throwaway directory.
    model_dir = tempfile.mkdtemp()

    # Go two directories up to the root of the flax directory; mocked TFDS
    # metadata lives under its .tfds directory.
    flax_root_dir = pathlib.Path(__file__).parents[2]
    data_dir = str(flax_root_dir) + '/.tfds/metadata'

    # Single-example, single-step configuration.
    config = default_lib.get_config()
    config.num_epochs = 1
    config.batch_size = 1
    config.num_train_steps = 1
    config.steps_per_eval = 1

    with tfds.testing.mock_data(num_examples=1, data_dir=data_dir):
      imagenet_lib.train_and_evaluate(model_dir=model_dir, config=config)
예제 #14
0
  def _get_datasets(self):
    """Builds tiny mocked WMT train/eval/predict datasets."""
    # Tiny corpus/vocab; target lengths come from module-level constants.
    config = default.get_config()
    config.per_device_batch_size = 1
    config.max_corpus_chars = 1000
    config.vocab_size = 32
    config.max_target_length = _TARGET_LENGTH
    config.max_eval_target_length = _EVAL_TARGET_LENGTH
    config.max_predict_length = _PREDICT_TARGET_LENGTH

    # Sentencepiece model is written into a temporary location.
    vocab_path = os.path.join(tempfile.mkdtemp(), 'sentencepiece_model')

    # Go two directories up to the root of the flax directory; mocked TFDS
    # metadata lives under its .tfds directory.
    flax_root_dir = pathlib.Path(__file__).parents[2]
    data_dir = str(flax_root_dir) + '/.tfds/metadata'

    with tfds.testing.mock_data(num_examples=128, data_dir=data_dir):
      train_ds, eval_ds, predict_ds, _ = input_pipeline.get_wmt_datasets(
          n_devices=2, config=config, vocab_path=vocab_path)
    return train_ds, eval_ds, predict_ds
예제 #15
0
def get_config():
    """Returns default hyperparameters for the llff 'fern' scene."""
    cfg = default_lib.get_config()

    cfg.dataset_type = "llff"
    cfg.shape = "fern"
    cfg.llff.hold = 8
    cfg.down_factor = 4

    # Optimization schedule.
    cfg.num_steps = 200000
    cfg.lr_decay = 250

    # Sampling and rendering settings.
    cfg.use_viewdirs = True
    cfg.raw_noise_std = 1.0
    cfg.num_rand = 4096
    cfg.num_samples = 64
    cfg.num_importance = 128
    cfg.num_poses = 120

    return cfg
예제 #16
0
def get_config():
    """Returns default hyperparameters for a half-resolution blender run."""
    cfg = default_lib.get_config()

    cfg.dataset_type = "blender"
    cfg.half_res = True

    # Optimization schedule.
    cfg.num_steps = 200000

    # Sampling and rendering settings.
    cfg.use_viewdirs = True
    cfg.white_bkgd = True
    cfg.batching = False
    cfg.num_rand = 1024
    cfg.num_samples = 64
    cfg.num_importance = 64

    # Logging/visualization cadence and preview resolution.
    cfg.i_print = 500
    cfg.i_img = 5000
    cfg.render_factor = 2

    return cfg
예제 #17
0
파일: train_test.py 프로젝트: voicedm/flax
    def test_train_step_updates_parameters(self):
        """Checks that a single train step changes every parameter array."""
        # Build a model and its initial train state.
        config = default_config.get_config()
        config.vocab_size = 13
        rng = jax.random.PRNGKey(config.seed)
        model = train.model_from_config(config)
        state = train.create_train_state(rng, config, model)

        # A two-example batch of short token sequences with dummy labels.
        batch = {
            'token_ids': np.array([[2, 4, 3], [2, 6, 3]], dtype=np.int32),
            'length': np.array([2, 3], dtype=np.int32),
            'label': np.zeros(2, dtype=np.int32),
        }
        new_state, metrics = jax.jit(train.train_step)(
            state, batch, {'dropout': rng})
        self.assertIsInstance(new_state, train.TrainState)
        self.assertIsInstance(metrics, train.Metrics)

        # Every leaf of the parameter tree must have moved.
        for before, after in zip(jax.tree_leaves(state.params),
                                 jax.tree_leaves(new_state.params)):
            self.assertFalse(np.allclose(before, after))