Example #1
@hparam.s
class LossAllStages:
    """Hyperparameters relating to the loss.

    This only includes those parameters that don't vary based on training stage.
    """
    contrastive = hparam.nest(ContrastiveLoss)
    cross_entropy = hparam.nest(CrossEntropyLoss)
    include_bias_in_weight_decay = hparam.field(default=True, abbrev='bd')
Example #2
@hparam.s
class Stage:
    """Hyperparameters relating to a single stage of training."""
    # Training is structured as multiple stages (currently 2) that are run
    # sequentially, all using the same Graph. This enables using different loss
    # settings, learning rate schedules, and optimizers in each stage. The
    # primary use case is standard contrastive training, where stage 1 trains
    # with just the contrastive loss applied to the projection head and the
    # encoder, and stage 2 trains with just the cross-entropy loss applied to
    # the classification head. It's possible to train with a single stage by
    # setting `train_epochs` for the unused stage to 0. A single stage can use
    # multiple losses simultaneously by setting the appropriate weights in that
    # stage's LossStage parameters.
    training = hparam.nest(TrainingStage)
    loss = hparam.nest(LossStage)
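The comment in `Stage` mentions disabling a stage by zeroing its `train_epochs`. A minimal sketch of a single-stage (contrastive-only) setup, assuming the root `HParams` class from Example #4, that instances are mutable after construction, and that `TrainingStage` exposes a `train_epochs` field; none of these details are confirmed by the snippets themselves:

hp = HParams()
hp.stage_1.training.train_epochs = 100  # contrastive pretraining only
hp.stage_2.training.train_epochs = 0    # stage 2 never runs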
Example #3
@hparam.s
class InputData:
    """Hyperparameters relating to input data."""
    input_fn = hparam.field(default='imagenet', abbrev='ds')
    preprocessing = hparam.nest(ImagePreprocessing)
    max_samples = hparam.field(default=-1, abbrev='ms')
    label_noise_prob = hparam.field(default=0., abbrev='ln')
    # If True, during training the dataset is sharded per TPU host, rather than
    # each host independently iterating over the full dataset. This guarantees
    # that the same sample won't appear in the same global batch on different TPU
    # cores, and also saves memory when caching the dataset.
    shard_per_host = hparam.field(default=True, abbrev='hs')
Example #4
@hparam.s
class HParams:
    """Hyperparameters."""
    # This unfortunately needs to be in the root and be named the same as its
    # abbreviation so that Bootstrap's TPUEstimator wrapper can find it.
    bs = hparam.field(default=2048, abbrev='bs')
    architecture = hparam.nest(Architecture)
    loss_all_stages = hparam.nest(LossAllStages)
    stage_1 = hparam.nest(Stage, prefix='s1')
    stage_2 = hparam.nest(Stage, prefix='s2')
    eval = hparam.nest(Eval)
    input_data = hparam.nest(InputData)
    warm_start = hparam.nest(WarmStart)
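A hedged sketch of how this root config composes, assuming the `@hparam.s` decorator generates an attrs-style keyword constructor and attribute access for nested fields (an assumption, not confirmed by these examples):

hp = HParams(bs=4096)                  # override the root batch size
assert hp.bs == 4096
assert hp.input_data.input_fn == 'imagenet'
# stage_1 and stage_2 nest the same Stage class; the 's1'/'s2' prefixes keep
# their abbreviated parameter names from colliding in the flattened namespace.
assert type(hp.stage_1) is type(hp.stage_2)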
Example #5
class InvalidDuplicateParams:
    bool_params = hparam.nest(ValidBoolParams)
    # Invalid: 'tbp' duplicates an abbreviation already used by a field
    # inside ValidBoolParams.
    duplicate_param = hparam.field('tbp', default="I'm a duplicate")
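A hedged sketch of how a test might exercise this fixture, assuming validation fires when `@hparam.s` is applied and raises `ValueError`; both the error type and the timing are assumptions, not taken from these examples:

def test_duplicate_abbrev(self):
    with self.assertRaises(ValueError):  # exact error type is assumed
        @hparam.s
        class Duplicated:  # hypothetical local re-declaration for the test
            bool_params = hparam.nest(ValidBoolParams)
            duplicate_param = hparam.field('tbp', default="I'm a duplicate")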
Example #6
@hparam.s
class ValidHParams:
    nested_params = hparam.nest(ValidNestedParams1)
    list_params = hparam.nest(ValidListParams)
Example #7
@hparam.s
class ValidNestedParams1:
    nested_params = hparam.nest(ValidNestedParams2)
    bool_params = hparam.nest(ValidBoolParams)
    enum_params = hparam.nest(ValidEnumParams)
Example #8
@hparam.s
class ValidNestedParams2:
    numeric_params = hparam.nest(ValidNumericParams)
    string_params1 = hparam.nest(ValidStringParams, prefix='s1')
    string_params2 = hparam.nest(ValidStringParams, prefix='s2')
Example #9
def test_invalid_nest_instance(self):
    nested = ValidNestedParams1()
    with self.assertRaisesRegex(
            TypeError,
            r'nest\(\) must be passed a class, not an instance.'):
        hparam.nest(nested)
Example #10
def test_invalid_nest_non_hparam(self):
    with self.assertRaisesRegex(
            TypeError,
            'Nested hparams classes must use the @hparam.s decorator'):
        hparam.nest(NonHParams)
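For contrast with the failing cases in Examples #9 and #10, a minimal sketch of the pattern those checks enforce; the class names here are hypothetical, and only `@hparam.s`, `hparam.field()`, and `hparam.nest()` come from the examples above:

@hparam.s
class DecoratedParams:
    value = hparam.field(default=1, abbrev='v')

class PlainConfig:  # no @hparam.s decorator
    value = 1

@hparam.s
class Parent:
    ok = hparam.nest(DecoratedParams)   # accepted: a decorated class
    # bad = hparam.nest(PlainConfig)         # would raise TypeError (Example #10)
    # bad = hparam.nest(DecoratedParams())   # instance, not a class (Example #9)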