@hparam.s
class LossAllStages:
  """Hyperparameters relating to the loss.

  This only includes those parameters that don't vary based on training stage.
  """
  contrastive = hparam.nest(ContrastiveLoss)
  cross_entropy = hparam.nest(CrossEntropyLoss)
  include_bias_in_weight_decay = hparam.field(default=True, abbrev='bd')
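
# Illustrative sketch: overriding a stage-invariant loss setting on an
# instance. Attribute assignment is assumed to behave as it does for
# attrs-style classes produced by @hparam.s.
def _example_loss_all_stages():
  loss = LossAllStages()
  loss.include_bias_in_weight_decay = False  # exclude bias terms from decay
  return loss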

@hparam.s
class Stage:
  """Hyperparameters relating to a single stage of training."""

  # Training is structured as having multiple stages, currently 2. The stages
  # are run sequentially, all using the same Graph. This enables using
  # different loss settings, learning rate schedules, and optimizers in each
  # stage. The primary use case is standard contrastive training, where
  # stage 1 trains with just the contrastive loss, applied to the projection
  # head and the encoder, and stage 2 trains with just the cross-entropy
  # loss, applied only to the classification head. It's possible to train
  # with a single stage by setting `train_epochs` for the unused stage to 0.
  # A single stage can use multiple losses simultaneously by setting the
  # appropriate weights in that stage's LossStage parameters.
  training = hparam.nest(TrainingStage)
  loss = hparam.nest(LossStage)
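
# Illustrative sketch of the single-stage setup described above: zeroing a
# stage's `train_epochs` disables it. `train_epochs` is assumed to be a
# field of TrainingStage, reached through the `training` nest.
def _example_disabled_stage():
  stage = Stage()
  stage.training.train_epochs = 0  # this stage is skipped entirely
  return stage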

@hparam.s
class InputData:
  """Hyperparameters relating to input data."""

  input_fn = hparam.field(default='imagenet', abbrev='ds')
  preprocessing = hparam.nest(ImagePreprocessing)
  max_samples = hparam.field(default=-1, abbrev='ms')
  label_noise_prob = hparam.field(default=0., abbrev='ln')

  # If True, during training the dataset is sharded per TPU host, rather than
  # each host independently iterating over the full dataset. This guarantees
  # that the same sample won't appear in the same global batch on different
  # TPU cores, and also saves memory when caching the dataset.
  shard_per_host = hparam.field(default=True, abbrev='hs')
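
# Illustrative sketch (values are arbitrary): a small, noisy configuration
# for debugging, built from the fields above.
def _example_debug_input_data():
  data = InputData()
  data.max_samples = 1024      # cap the number of training samples
  data.label_noise_prob = 0.1  # randomly corrupt 10% of training labels
  data.shard_per_host = False  # each host iterates the full dataset
  return data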

@hparam.s
class HParams:
  """Hyperparameters."""

  # This unfortunately needs to be in the root and be named the same as its
  # abbreviation so that Bootstrap's TPUEstimator wrapper can find it.
  bs = hparam.field(default=2048, abbrev='bs')
  architecture = hparam.nest(Architecture)
  loss_all_stages = hparam.nest(LossAllStages)
  stage_1 = hparam.nest(Stage, prefix='s1')
  stage_2 = hparam.nest(Stage, prefix='s2')
  eval = hparam.nest(Eval)
  input_data = hparam.nest(InputData)
  warm_start = hparam.nest(WarmStart)
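
# Illustrative sketch of a two-stage configuration built from the root class
# above. `train_epochs` comes from the Stage comment; the specific values
# here are assumptions, not recommended settings.
def _example_two_stage_hparams():
  hp = HParams()
  hp.bs = 4096                            # global batch size
  hp.stage_1.training.train_epochs = 700  # contrastive pretraining
  hp.stage_2.training.train_epochs = 90   # cross-entropy head training
  return hp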

# Fixture for duplicate-abbreviation detection: the 'tbp' abbrev collides
# with one nested via ValidBoolParams.
class InvalidDuplicateParams:
  bool_params = hparam.nest(ValidBoolParams)
  duplicate_param = hparam.field('tbp', default='Im a duplicate')

@hparam.s
class ValidHParams:
  nested_params = hparam.nest(ValidNestedParams1)
  list_params = hparam.nest(ValidListParams)

@hparam.s
class ValidNestedParams1:
  nested_params = hparam.nest(ValidNestedParams2)
  bool_params = hparam.nest(ValidBoolParams)
  enum_params = hparam.nest(ValidEnumParams)

@hparam.s
class ValidNestedParams2:
  numeric_params = hparam.nest(ValidNumericParams)
  string_params1 = hparam.nest(ValidStringParams, prefix='s1')
  string_params2 = hparam.nest(ValidStringParams, prefix='s2')

def test_invalid_nest_instance(self):
  nested = ValidNestedParams1()
  with self.assertRaisesRegex(
      TypeError, r'nest\(\) must be passed a class, not an instance.'):
    hparam.nest(nested)

def test_invalid_nest_non_hparam(self):
  with self.assertRaisesRegex(
      TypeError, 'Nested hparams classes must use the @hparam.s decorator'):
    hparam.nest(NonHParams)
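
# Illustrative companion sketch (an assumed test, not taken from the suite):
# nesting a properly decorated class should succeed without raising.
def test_valid_nest_sketch(self):
  hparam.nest(ValidNestedParams1)  # a decorated class: no error expected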