Code example #1
    def test_values_as_expected(self):
        """Test whether processed dictionaries match expected values."""

        metric_dict = {
            'cat1': {
                'key': 2.0,
                'denominator': 1.0
            },
            'cat2': {
                'key': 2.0,
                'denominator': 2.0
            },
        }

        processed_metrics = metric_utils.process_metrics(metric_dict)
        expected_result = {
            'cat1_key': 2.0,
            'cat1_denom': 1.0,
            'cat2_key': 1.0,
            'cat2_denom': 2.0,
        }
        self.assertEqual(processed_metrics, expected_result)

        metric_dict = {
            'cat1': {
                'key': 2.0,
                'denominator': 1.0
            },
            'cat2': {
                'key': 2.0,
                'denominator': 2.0
            },
        }

        processed_metrics = metric_utils.process_metrics(metric_dict,
                                                         prefix='pref')
        expected_result = {
            'pref/cat1_key': 2.0,
            'pref/cat1_denom': 1.0,
            'pref/cat2_key': 1.0,
            'pref/cat2_denom': 2.0,
        }
        self.assertEqual(processed_metrics, expected_result)
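
The assertions above pin down the behavior under test: each nested metric is flattened to '{category}_{key}', non-denominator values are divided by the category's 'denominator' (which is itself kept as '{category}_denom'), and an optional prefix is joined with '/'. A minimal hypothetical sketch consistent with both assertions -- the real metric_utils.process_metrics may differ in details such as error handling:

def process_metrics(metric_dict, prefix=None):
    """Flatten nested metrics, dividing each value by its denominator."""
    processed = {}
    for category, metrics in metric_dict.items():
        denominator = metrics['denominator']
        for key, value in metrics.items():
            if key == 'denominator':
                flat_key = f'{category}_denom'
                flat_value = denominator
            else:
                flat_key = f'{category}_{key}'
                flat_value = value / denominator
            if prefix is not None:
                flat_key = f'{prefix}/{flat_key}'
            processed[flat_key] = flat_value
    return processed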
Code example #2
    def test_reproduce_paper_evals(self, num_chunks):
        """Reproduce results from https://www.aclweb.org/anthology/P18-1009.pdf."""
        num_samples = self.labels.shape[0]
        chunk_size = num_samples // num_chunks
        metrics = []
        for chunk_start in range(0, num_samples, chunk_size):
            chunk_end = min(chunk_start + chunk_size, num_samples)
            labels = self.labels[chunk_start:chunk_end]
            predictions = self.predictions[chunk_start:chunk_end]
            current_metrics = ultra_fine_entity_typing_task.get_prediction_recall_metrics(
                labels, predictions)
            current_metrics = jax.tree_map(lambda x: jnp.expand_dims(x, 0),
                                           current_metrics)
            metrics.append(current_metrics)

        metrics = common_utils.get_metrics(metrics)
        metrics_sum = jax.tree_map(jnp.sum, metrics)
        processed_metrics = metric_utils.process_metrics(metrics_sum)
        self.assertAlmostEqual(processed_metrics['total_precision_value'],
                               0.481,
                               places=3)
        self.assertAlmostEqual(processed_metrics['total_recall_value'],
                               0.232,
                               places=3)
        self.assertAlmostEqual(
            processed_metrics['coarse_grained_precision_value'],
            0.603,
            places=3)
        self.assertAlmostEqual(
            processed_metrics['coarse_grained_recall_value'], 0.616, places=3)
        self.assertAlmostEqual(
            processed_metrics['fine_grained_precision_value'], 0.404, places=3)
        self.assertAlmostEqual(processed_metrics['fine_grained_recall_value'],
                               0.384,
                               places=3)
        self.assertAlmostEqual(
            processed_metrics['ultra_fine_grained_precision_value'],
            0.428,
            places=3)
        self.assertAlmostEqual(
            processed_metrics['ultra_fine_grained_recall_value'],
            0.088,
            places=3)
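
Chunked evaluation reproduces the single-pass numbers because get_prediction_recall_metrics returns raw numerators and denominators: summing them across chunks and normalizing once at the end (via process_metrics) is equivalent to one pass over all samples. A toy illustration with made-up values, mirroring the metric layout:

import jax
import jax.numpy as jnp

# Two chunks' raw counts (hypothetical values).
chunk_metrics = [
    {'total_precision': {'value': jnp.array(3.0),
                         'denominator': jnp.array(5.0)}},
    {'total_precision': {'value': jnp.array(1.0),
                         'denominator': jnp.array(3.0)}},
]

# Sum corresponding leaves across chunks: value 4.0, denominator 8.0.
summed = jax.tree_map(lambda *leaves: sum(leaves), *chunk_metrics)

# process_metrics would then report total_precision_value == 4.0 / 8.0 == 0.5,
# identical to evaluating both chunks as a single batch.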
Code example #3
File: trainer.py  Project: greck2908/language
def evaluate(
    eval_step_fn,
    train_state: ts.TrainState,
    model_vars: Dict[str, Any],
    eval_data: Sequence[Dict[str, Any]],
) -> Tuple[Dict[str, Any],
           Sequence[Tuple[Dict[str, Any], Optional[Dict[str, Any]]]]]:
  """Evaluate current parameters and return a dictionary with metrics.

  Args:
    eval_step_fn: Partial eval step that takes model params and inputs only.
    train_state: Contains model params, loss fn, and grad update fn.
    model_vars: Model variables that are not optimized.
    eval_data: Sequence of evaluation batches.

  Returns:
    Dictionary of metrics aggregated over all evaluation steps and the info
    for the very first batch (batch itself and corresponding auxiliary output).
  """

  logging.info('Performing evaluation.')
  eval_metrics = []
  eval_auxiliary = []
  for batch in eval_data:
    batch = jax.tree_map(jnp.asarray, batch)
    metrics, auxiliary_output = eval_step_fn(
        train_state,
        model_vars,
        batch,
    )
    eval_metrics.append(metrics)
    batch_auxiliary = (jax.device_get(batch), jax.device_get(auxiliary_output))
    eval_auxiliary.append(batch_auxiliary)
  eval_metrics = common_utils.get_metrics(eval_metrics)
  eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
  eval_summary = metric_utils.process_metrics(eval_metrics_sums, prefix='eval')
  return eval_summary, eval_auxiliary
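
For context, the train() function in example #4 below constructs the eval_step_fn argument by partially applying a per-device eval_step and pmapping it; condensed from that example (eval_step, model_config, train_state, model_vars, and eval_data all come from the surrounding code there):

import functools
import jax

p_eval_step = jax.pmap(
    functools.partial(eval_step, model_config=model_config),
    axis_name='batch')
eval_summary, eval_auxiliary = evaluate(
    eval_step_fn=p_eval_step,
    train_state=train_state,
    model_vars=model_vars,
    eval_data=eval_data,
)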
Code example #4
File: trainer.py  Project: greck2908/language
def train(config: ml_collections.ConfigDict):
  """Run training."""

  # Establish host information
  local_device_count = jax.local_device_count()
  host_count = jax.process_count()
  host_id = jax.process_index()

  task = task_registry.get_registered_task(config.task_name)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)

  model_config = ml_collections.FrozenConfigDict(config.model_config)
  model = task.build_model(model_config)

  # Initialization needs to be pmapped because models use collective ops.
  # Create dummy input
  dummy_input = {
      key: jnp.tile(value, (local_device_count,) + (1,) * value.ndim)
      for key, value in task.dummy_input(config).items()
  }
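  # Note: with reps (local_device_count, 1, ..., 1), jnp.tile prepends a
  # device axis -- e.g. a (4, 8) feature becomes (local_device_count, 4, 8),
  # the per-device layout jax.pmap expects for its mapped arguments.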

  rng, init_rng = jax.random.split(rng)
  init_rng = jax.random.split(init_rng, local_device_count)

  logging.info('Initializing model.')
  initial_variables = jax.pmap(
      model.init, 'batch', static_broadcasted_argnums=2)(init_rng, dummy_input,
                                                         True)
  logging.info('Finished initializing model.')
  initial_variables = initial_variables.unfreeze()

  if config.load_weights is not None:
    logging.info('Loading model weights from file')
    loaded_variables = task.load_weights(config)
    unexpected, missing = checkpoint_utils.merge_nested_dicts(
        initial_variables, loaded_variables)
    logging.info('*** Unexpected features: ***')
    for feature_name in unexpected:
      logging.info('\t%s', feature_name)
    logging.info('*** Missing features: ***')
    for feature_name in missing:
      logging.info('\t%s', feature_name)

  model_vars = {
      key: value for key, value in initial_variables.items() if key != 'params'
  }

  learning_rate_fn = optim_utils.create_learning_rate_scheduler(
      learning_rate=config.learning_rate,
      warmup=config.warmup,
      warmup_steps=config.get('warmup_steps', None),
      linear_decay=config.linear_decay,
      max_steps=config.num_train_steps,
      decay_minimum_factor=config.get('decay_minimum_factor', None),
  )

  if config.weight_decay_exclude is not None:
    decay_mask = optim_utils.create_dict_mask(initial_variables['params'],
                                              config.weight_decay_exclude)
  else:
    decay_mask = None
  tx = optax.adamw(
      learning_rate=learning_rate_fn,
      weight_decay=config.weight_decay,
      b1=0.9,
      b2=0.999,
      eps=1e-6,
      mask=decay_mask)
  if config.grad_clip is not None:
    tx = optax.chain(tx, optax.clip_by_global_norm(config.grad_clip))

  ignore_k_nans = config.get('ignore_k_nans')
  if ignore_k_nans is not None:
    tx = optax.apply_if_finite(tx, max_consecutive_errors=ignore_k_nans)

  loss_fn = task.make_loss_fn(config)
  train_state = ts.TrainState.create(
      apply_fn=loss_fn,
      params=jax_utils.unreplicate(initial_variables['params']),
      tx=tx,
  )

  # We access model params only from train state.
  del initial_variables

  # Restore unreplicated train state from last checkpoint
  train_state = checkpoints.restore_checkpoint(config.model_dir, train_state)
  # Grab last step.
  start_step = int(train_state.step)

  writer = metric_writers.create_default_writer(
      config.model_dir, just_logging=jax.process_index() > 0)
  if start_step == 0:
    writer.write_hparams(config.to_dict())

  dropout_rngs = jax.random.split(rng, local_device_count)

  del rng

  # Load datasets
  logging.info('Loading dataset.')

  # Make sure we don't re-use same data if we load weights or checkpoint
  seed = config.seed + start_step
  if config.load_weights:
    seed = seed + hash(config.load_weights)

  name_to_features = task.get_name_to_features(config)
  preprocess_fn = task.make_preprocess_fn(config)
  collater_fn = task.make_collater_fn(config)

  train_data = data_utils.load_multi_dataset(
      datasets_config=config.train_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=True,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=seed,
  )
  train_iter = iter(train_data)

  pad_eval = config.get('pad_eval', False)
  if pad_eval:
    logging.info('Eval data is padded such that no samples are dropped.')
  else:
    logging.warning('Eval data is NOT padded -- some samples might be dropped.')

  eval_data = data_utils.load_multi_dataset(
      datasets_config=config.eval_data,
      name_to_features=name_to_features,
      preprocess_fn=preprocess_fn,
      collater_fn=collater_fn,
      is_training=False,
      per_device_batch_size=config.per_device_batch_size,
      local_device_count=local_device_count,
      host_count=host_count,
      host_id=host_id,
      seed=config.seed,
      pad_eval=pad_eval,
  )
  eval_data = list(eval_data)
  logging.info('Loaded %d samples for evaluation.', len(eval_data))

  # Setup postprocessing_fn for saving samples occasionally.
  if config.get('save_samples_every_steps') is not None:
    if config.get('save_samples_every_steps') % config.eval_every_steps != 0:
      raise ValueError(
          '`eval_every_steps` must divide `save_samples_every_steps`.')
    postprocessing_fn = task.make_output_postprocess_fn(config)

  # Training loop
  logging.info('Starting training.')

  # Replicate train state.
  train_state = jax_utils.replicate(train_state)

  # Compile multi-device versions of the train and eval steps.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          model_config=model_config,
      ),
      axis_name='batch',
      donate_argnums=(0,),
  )  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          model_config=model_config,
      ),
      axis_name='batch')

  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)

  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=config.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and perform a training step.
      with jax.profiler.StepTraceAnnotation('train', step_num=step):
        batch = jax.tree_map(jnp.asarray, train_iter.get_next())
        train_state, metrics = p_train_step(
            train_state,
            model_vars,
            batch,
            dropout_rngs,
        )
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.get_metrics(train_metrics)
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          summary = metric_utils.process_metrics(metrics_sums, prefix='train')
          summary['learning_rate'] = learning_rate_fn(step)

          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed('eval'):
          eval_results, eval_auxiliary = evaluate(
              eval_step_fn=p_eval_step,
              train_state=train_state,
              model_vars=model_vars,
              eval_data=eval_data,
          )
          writer.write_scalars(step, eval_results)

          if config.get('save_samples_every_steps') is not None:
            with report_progress.timed('save_samples'):
              if config.get('save_first_batch_only', True):
                postprocessing_input = [eval_auxiliary[0]]
              else:
                postprocessing_input = eval_auxiliary
              eval_processed = [
                  postprocessing_fn(batch, auxiliary_output)
                  for batch, auxiliary_output in postprocessing_input
              ]
              data_utils.save_samples_to_json(eval_processed, config, step)

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (
          step % config.checkpoint_every_steps == 0 or is_last_step)
      if (config.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving checkpoint at step %s', step)
          checkpoints.save_checkpoint(
              config.model_dir,
              jax_utils.unreplicate(train_state),
              step,
              keep=config.get('keep_checkpoints', 1),
              keep_every_n_steps=config.get('keep_checkpoint_every_steps'),
          )

      save_model = (
          config.save_every_steps and
          (step % config.save_every_steps == 0 or is_last_step) and step != 0)
      if (save_model and jax.process_index() == 0):
        with report_progress.timed('checkpoint'):
          logging.info('Saving weights at step %s', step)
          save_path = os.path.join(config.model_dir, 'weights',
                                   'step' + str(step))
          # By default, save only encoder weights
          weights = jax_utils.unreplicate(train_state).params['encoder']
          checkpoint_utils.save_weights(save_path, weights)
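
The train_step being pmapped is not shown on this page, but the call site above fixes its interface: positional (train_state, model_vars, batch, dropout_rng) arguments, a model_config keyword bound via functools.partial, and a (new_train_state, metrics) return. A hypothetical skeleton matching that interface; the apply_fn call signature is an assumption, since it depends on what task.make_loss_fn returns:

def train_step(train_state, model_vars, batch, dropout_rng, *, model_config):
  """One per-device update step (interface inferred from the call site)."""
  del model_config  # The real step would use this to configure the model.

  def loss_fn(params):
    # Assumed signature for the loss fn created by task.make_loss_fn.
    loss, metrics = train_state.apply_fn(
        {'params': params, **model_vars}, batch,
        rngs={'dropout': dropout_rng})
    return loss, metrics

  (_, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(
      train_state.params)
  # Average gradients across devices before the optimizer update.
  grads = jax.lax.pmean(grads, axis_name='batch')
  return train_state.apply_gradients(grads=grads), metrics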