Example #1
0
 def Params(cls) -> InstantiableParams:  # pylint:disable=invalid-name
     """Task parameters.

     Defines the top-level task fields (`name`, `model`, `train`, `metrics`)
     plus a nested `train` Params object holding the training-loop settings:
     learner, step budget, checkpoint saving/retention, summary and eval
     intervals, and input sharding for SPMD models.

     Returns:
       An `InstantiableParams` for `cls` with the fields above defined.
     """
     p = InstantiableParams(cls)
     p.Define('name', '',
              'Name of this task object, must be a valid identifier.')
     p.Define('model', None,
              'The underlying JAX model encapsulating all the layers.')
     p.Define('train', py_utils.Params(),
              'Params to control how this task should be trained.')
     p.Define('metrics', None, 'How metrics are computed.')
     # All training-related knobs live on the nested `train` Params.
     tp = p.train
     tp.Define('learner', learners_lib.Learner.Params(),
               'One or a list of learners.')
     # NOTE(review): `1e7` is a float literal; consumers that need an integer
     # step count presumably cast it — confirm against the training loop.
     tp.Define('num_train_steps', 1e7,
               'Maximum number of training steps to run.')
     # TODO(bf-jax): Add an option to perform this wrt. a time duration.
     tp.Define(
         'save_interval_steps', 5000,
         'How frequently to save a model checkpoint in terms of the number of '
         'training steps.')
     tp.Define(
         'save_keep_interval_duration', '12h',
         'How frequently to keep a saved model checkpoint as a duration string '
         'such as `1h` for one hour or `90m` for 90 minutes. This is performed '
         'in addition to keeping the most recent `max_to_keep` checkpoint '
         'files.')
     tp.Define('save_max_to_keep', 10,
               'The maximum number of recent checkpoints to keep.')
     tp.Define(
         'summary_interval_steps', 100,
         'How frequently to generate summaries in terms of the number of '
         'training steps.')
     tp.Define(
         'norm_summary_interval_steps', 500,
         'How frequently to generate expensive summaries computing the norms '
         'of variables in terms of the number of training steps.')
     tp.Define(
         'eval_interval_steps', 100,
         'How frequently to evaluate the model on the evaluation splits in '
         'terms of the number of training steps.')
     # Adjacent string literals are joined with no separator, so every piece
     # except the last must end with a space for the description to read
     # correctly.
     tp.Define(
         'inputs_split_mapping', None,
         'The PartitionSpec for inputs such as inputs, labels, targets, '
         'paddings, num words etc. This is only relevant for SPMD sharded '
         'models. By default it is None, which means all the inputs are '
         'replicated. For sharding inputs, this is a `NestedMap` with keys '
         '`map_1d`, `map_2d`, ..., etc., which specifies how to shard the '
         'inputs of that dimension.')
     return p
Example #2
0
 def Params(cls) -> InstantiableParams:
   """Sequence-model task parameters.

   Extends the parent class's Params with the encoder-decoder `model`, the
   `return_predictions` flag, a nested `decoder` Params (`seqlen`, `eos_id`)
   and `label_smoothing_prob`.

   Returns:
     The parent class's params with the fields above defined.
   """
   p = super().Params()
   p.Define('model', layers.TransformerEncoderDecoder.Params(),
            'Sequence model layer for this task.')
   # Adjacent string literals are joined with no separator; each piece but
   # the last ends with a space so the help text reads correctly.
   p.Define(
       'return_predictions', False, 'Whether to return predictions during '
       'eval. Returning predictions is more expensive, but may be useful '
       'for debugging.')
   # Decoding settings are grouped in a nested Params attached as `decoder`.
   decoder_p = py_utils.Params()
   decoder_p.Define('seqlen', 0, 'Maximum output sequence length.')
   decoder_p.Define(
       'eos_id', 2,
       'The id of EOS token indicating the termination of decoding.')
   p.Define('decoder', decoder_p, 'Decoder params.')
   p.Define(
       'label_smoothing_prob', 0.0,
       'If > 0.0, smooth out one-hot prob by spreading this amount of'
       ' prob mass to all other tokens.')
   return p
Example #3
0
  def Params(cls) -> InstantiableParams:
    """Language-model task parameters.

    Extends the parent class's Params with the `lm` layer, the
    `return_predictions` flag and a nested `decoder` Params holding
    greedy-search settings (`seqlen`, `min_prefix_len`, `eos_id`,
    `max_decode_steps`).

    Returns:
      The parent class's params with the fields above defined.
    """
    p = super().Params()
    p.Define('lm', layers.TransformerLm.Params(), 'LM layer.')
    # Adjacent string literals are joined with no separator; each piece but
    # the last ends with a space so the help text reads correctly.
    p.Define(
        'return_predictions', False, 'Whether to return predictions during '
        'eval. Returning predictions is more expensive, but may be useful '
        'for debugging.')

    # Greedy-search settings are grouped in a nested Params attached as
    # `decoder`.
    greedy_search_p = py_utils.Params()
    greedy_search_p.Define('seqlen', 0, 'Maximum output sequence length.')
    greedy_search_p.Define(
        'min_prefix_len', 5,
        'Minimum number of tokens picked to be used as decoding prefix.')
    greedy_search_p.Define(
        'eos_id', 2,
        'The id of EOS token indicating the termination of greedy search.')
    greedy_search_p.Define(
        'max_decode_steps', None,
        'If not None, the max decode steps for each example. If None, this '
        'is set to `seqlen`, which contains prefix.')
    p.Define('decoder', greedy_search_p, 'Decoder param.')
    return p