def test_add_decayed_weights_with_mask(self):
    """Tests that add_decayed_weights leaves masked (bias/BatchNorm) params undecayed when a mask is specified in hps."""

    class Foo(nn.Module):
        """Dummy model."""
        train: bool
        filters: int

        @nn.compact
        def __call__(self, x):
            x = nn.Conv(self.filters, (1, 1), use_bias=False, dtype=jnp.float32)(x)
            x = nn.BatchNorm(
                use_running_average=not self.train,
                momentum=0.9,
                epsilon=1e-5,
                dtype=jnp.float32)(x)
            return x

    tx = from_hparams(
        ml_collections.ConfigDict({
            '0': {
                'element': 'add_decayed_weights',
                'hps': {
                    'weight_decay': 1e-4,
                    'mask': 'bias_bn'
                }
            }
        }))
    key = jax.random.PRNGKey(0)
    x = jnp.ones((5, 4, 4, 3))
    y = jax.random.uniform(key, (5, 4, 4, 7))
    foo_vars = flax.core.unfreeze(Foo(filters=7, train=True).init(key, x))

    @self.variant
    def train_step(params, x, y):
        y1, new_batch_stats = Foo(filters=7, train=True).apply(
            params, x, mutable=['batch_stats'])
        return jnp.abs(y - y1).sum(), new_batch_stats

    state = self.variant(tx.init)(foo_vars['params'])
    grads, _ = jax.grad(train_step, has_aux=True)(foo_vars, x, y)
    updates, state = self.variant(tx.update)(
        dict(grads['params']), state, foo_vars['params'])
    chex.assert_trees_all_close(updates['BatchNorm_0'],
                                grads['params']['BatchNorm_0'])
def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()

    # General
    config.debug = False
    config.experiment_name = ''  # Set in train.sh
    config.on_gcp = True
    config.local_output_dir = '/tmp'
    config.use_gpu = True
    config.random_seed = 42

    # Model params
    config.model_key = 'resnet18'
    config.cascaded = True
    config.tdl_mode = 'OSD'  # OSD, EWS, noise
    config.tdl_alpha = 0.0
    config.noise_var = 0.0
    config.bn_time_affine = False
    config.bn_time_stats = True

    # Loss params
    config.lambda_val = 0.0
    config.normalize_loss = False

    # Dataset params
    config.dataset_name = 'CIFAR10'  # CIFAR10, CIFAR100, TinyImageNet
    config.val_split = 0.1
    config.split_idxs_root = None
    config.augmentation_noise_type = 'occlusion'
    config.num_workers = 16
    config.drop_last = False

    # Train params
    config.batch_size = 128
    config.epochs = 100
    config.eval_freq = 10
    config.upload_freq = 5

    # Optimizer params
    config.learning_rate = 0.1
    config.momentum = 0.9
    config.weight_decay = 0.0005
    config.nesterov = True
    config.lr_milestones = [30, 60, 120, 150]
    config.lr_schedule_gamma = 0.2

    return config
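# Hedged usage sketch (not part of the original snippet above): a get_config()
# factory like this is conventionally exposed to a training binary through
# ml_collections.config_flags, which lets individual fields be overridden on the
# command line, e.g. `--config=path/to/this_file.py --config.batch_size=256`.
# The flag name 'config' is an illustrative assumption, not this project's own.
from absl import flags
from ml_collections import config_flags

config_flags.DEFINE_config_file(
    'config', default=None, help_string='Path to a file defining get_config().')

# After absl parses flags, the resulting ConfigDict is available as:
#   config = flags.FLAGS.config
#   print(config.batch_size, config.learning_rate)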
def test_embedding_layer(self):
    config = ml_collections.ConfigDict({
        "batch_size": 3,
        "vocab_size": 1000,
        "d_emb": 32,
        "max_seq_length": 64,
        "type_vocab_size": 2,
        "d_model": 4,
        "dropout_rate": 0.1,
        "dtype": jnp.float32
    })
    frozen_config = ml_collections.FrozenConfigDict(config)
    rng = jax.random.PRNGKey(100)
    embedding_layer = layers.EmbeddingLayer(config=frozen_config)
    init_batch = {
        "input_ids": jnp.ones((1, frozen_config.max_seq_length), jnp.int32),
        "type_ids": jnp.ones((1, frozen_config.max_seq_length), jnp.int32)
    }
    params = init_layer_variables(rng, embedding_layer, init_batch)["params"]

    expected_keys = {
        "word", "position", "type", "layer_norm", "hidden_mapping_in"
    }
    self.assertEqual(params.keys(), expected_keys)

    rng, init_rng = jax.random.split(rng)
    inputs = {
        "input_ids":
            jax.random.randint(
                init_rng,
                (frozen_config.batch_size, frozen_config.max_seq_length),
                minval=0,
                maxval=13),
        "type_ids":
            jax.random.randint(
                init_rng,
                (frozen_config.batch_size, frozen_config.max_seq_length),
                minval=0,
                maxval=2)
    }
    outputs = embedding_layer.apply({"params": params}, rngs={"dropout": rng},
                                    **inputs)
    self.assertEqual(
        outputs.shape,
        (frozen_config.batch_size, frozen_config.max_seq_length,
         frozen_config.d_model))
def get_r50_l16_config():
    """Returns the customized Resnet50 + ViT-L/16 configuration."""
    config = get_l16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 2
    config.activation = 'softmax'
    return config
def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()
    config.batch_size = 512
    config.eval_frequency = TRAIN_EXAMPLES // config.batch_size
    config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS
    config.num_eval_steps = VALID_EXAMPLES // config.batch_size
    config.weight_decay = 0.
    config.grad_clip_norm = None

    config.save_checkpoints = True
    config.restore_checkpoints = True
    config.checkpoint_freq = (TRAIN_EXAMPLES //
                              config.batch_size) * NUM_EPOCHS // 2
    config.random_seed = 0

    config.learning_rate = .001
    config.factors = 'constant * linear_warmup * cosine_decay'
    config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1
    config.steps_per_cycle = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS

    # model params
    config.model = ml_collections.ConfigDict()
    config.model.num_layers = 1
    config.model.num_heads = 2
    config.model.emb_dim = 32
    config.model.dropout_rate = 0.1

    config.model.qkv_dim = config.model.emb_dim // 2
    config.model.mlp_dim = config.model.qkv_dim * 2
    config.model.attention_dropout_rate = 0.1
    config.model.classifier_pool = 'MEAN'
    config.model.learn_pos_emb = False

    config.trial = 0  # dummy for repeated runs.
    return config
def get_config(config_string):
    """Return an instance of ConfigDict depending on `config_string`."""
    possible_structures = {
        'linear': ml_collections.ConfigDict({
            'model_constructor': 'snt.Linear',
            'model_config': ml_collections.ConfigDict({
                'output_size': 42,
            })
        }),
        'lstm': ml_collections.ConfigDict({
            'model_constructor': 'snt.LSTM',
            'model_config': ml_collections.ConfigDict({
                'hidden_size': 108,
            })
        })
    }
    return possible_structures[config_string]
def get_fp_quant_hparams_config(quantized, quantized_reductions):
    """Create appropriate setting for the 'QuantHParams' field used by softmax and LayerNormAQT."""
    if quantized_reductions and not quantized:
        raise ValueError("If `quantized` is False, `quantized_reductions` "
                         "must also be False.")
    if quantized:
        quant_hparams = ml_collections.ConfigDict({
            "prec":
                config_schema_utils.get_fp_config(),
            "reduction_prec":
                config_schema_utils.get_fp_config()
                if quantized_reductions else None
        })
    else:
        quant_hparams = None
    return quant_hparams
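# Hedged usage sketch for the helper above (assumes config_schema_utils.get_fp_config()
# returns a ConfigDict describing a floating-point precision, as the code above implies;
# the demo function name is illustrative, not part of the original module).
def _example_fp_quant_usage():
    # Quantization disabled: no hparams are produced.
    assert get_fp_quant_hparams_config(quantized=False, quantized_reductions=False) is None
    # Quantization enabled: both precision fields are populated.
    hparams = get_fp_quant_hparams_config(quantized=True, quantized_reductions=True)
    assert set(hparams.keys()) == {"prec", "reduction_prec"}
    # Quantized reductions without quantization are rejected.
    try:
        get_fp_quant_hparams_config(quantized=False, quantized_reductions=True)
    except ValueError:
        pass  # Expected.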
def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = "seg"
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 4
    config.n_skip = 3
    config.activation = "softmax"

    return config
def get_eval_config():
    """Configuration related to model evaluation."""
    eval_config = ml_collections.ConfigDict()
    eval_config.eval_once = False
    eval_config.save_output = True
    # The size of chunks for evaluation inferences;
    # set to the value that fits your GPU/TPU memory.
    eval_config.chunk = 4096
    eval_config.inference = False
    eval_config.mvsn_style = False
    eval_config.return_coarse = False
    eval_config.checkpoint_step = -1
    return eval_config
def set_dataset_config(dataset, config):
    """Sets dataset-related configs."""
    config.dataset = mlc.ConfigDict()
    if dataset == "imagenet2012":
        config.dataset_name = dataset
        config.dataset.train_split = "train"
        config.dataset.val_split = "validation"
        config.dataset.num_classes = 1000
        config.dataset.input_shape = (224, 224, 3)
    else:
        assert dataset == "cifar10"
        config.dataset_name = dataset
        config.dataset.train_split = "train"
        config.dataset.val_split = "test"
        config.dataset.num_classes = 10
        config.dataset.input_shape = (32, 32, 3)
def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()

    config.learning_rate = 0.1
    config.momentum = 0.9
    config.batch_size = 128
    config.num_epochs = 10
    config.cache = False
    config.half_precision = False

    # If num_train_steps==-1 then the number of training steps is calculated from
    # num_epochs using the entire dataset. Similarly for steps_per_eval.
    config.num_train_steps = -1
    config.steps_per_eval = -1
    return config
def test_train_and_evaluate(self):
    workdir = tempfile.gettempdir()
    config = common.get_config()
    config.model = models.get_testing_config()
    config.dataset = 'cifar10'
    config.pp = ml_collections.ConfigDict(
        {'train': 'train[:98%]', 'test': 'test', 'resize': 448, 'crop': 384})
    config.batch = 64
    config.accum_steps = 2
    config.batch_eval = 8
    config.total_steps = 1
    config.pretrained_dir = workdir

    test_utils.create_checkpoint(config.model, f'{workdir}/testing.npz')
    opt_pmap = train.train_and_evaluate(config, workdir)
    self.assertTrue(os.path.exists(f'{workdir}/model.npz'))
def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    config.pretrained_path = '/home/cjj/Documents/TransUNet/model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 19
    config.n_skip = 3
    config.activation = 'softmax'

    return config
def get_block_config(parent_config, block_kind):
    """Creates a ConfigDict corresponding to wmt_mlperf.models.Encoder[Decoder]1DBlock.HParams."""
    config = ml_collections.ConfigDict()
    config_schema_utils.set_default_reference(config, parent_config, "mlp_block")
    if block_kind == BlockKind.encoder:
        config_schema_utils.set_default_reference(config, parent_config,
                                                  "attention")
    elif block_kind == BlockKind.decoder:
        config_schema_utils.set_default_reference(
            config,
            parent_config, ["self_attention", "enc_dec_attention"],
            parent_field="attention")
    else:
        raise ValueError(f"Unknown block_kind {block_kind}")
    config.lock()
    return config
def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    config.pretrained_path = '/home/viplab/data/R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 2
    config.n_skip = 3
    config.activation = 'softmax'

    return config
def get_residual_config(parent_config):
    """Creates ConfigDict corresponding to imagenet.models.ResidualBlock.HParams."""
    config = ml_collections.ConfigDict()
    config_schema_utils.set_default_reference(
        config,
        parent_config, ["conv_proj", "conv_1", "conv_2", "conv_3"],
        parent_field="conv")
    # TODO(b/179063860): The input distribution is an intrinsic model
    # property and shouldn't be part of the model configuration. Update
    # the hparam dataclasses to eliminate the input_distribution field and
    # then delete this.
    config.conv_proj.quant_act.input_distribution = "positive"
    config.conv_2.quant_act.input_distribution = "positive"
    config.conv_3.quant_act.input_distribution = "positive"

    config.lock()
    return config
def get_config():
    """Return the default configuration."""
    config = ml_collections.ConfigDict()
    config.num_steps = 10000  # Number of training steps to perform.
    config.batch_size = 128  # Batch size.
    config.learning_rate = 0.01  # Learning rate.

    # Number of samples to draw for prediction.
    config.num_prediction_samples = 500

    # Batch size to use for prediction. Ideally as big as possible, but may need
    # to be reduced for memory reasons depending on the value of
    # `num_prediction_samples`.
    config.prediction_batch_size = 500

    # Multiplier for the likelihood term in the loss.
    config.likelihood_multiplier = 5.

    # Multiplier for the MMD constraint term in the loss.
    config.constraint_multiplier = 0.

    # Scaling factor to use in KL term.
    config.beta = 1.0

    # The number of samples we draw from each latent variable distribution.
    config.mmd_sample_size = 100

    # Directory into which results should be placed. By default it is the empty
    # string, in which case no saving will occur. The directory specified will
    # be created if it does not exist.
    config.output_dir = ''

    # The index of the step at which to turn on the constraint multiplier. For
    # steps prior to this the multiplier will be zero.
    config.constraint_turn_on_step = 0

    # The random seed for tensorflow that is applied to the graph iff the value
    # is non-negative. By default the seed is not constrained.
    config.seed = -1

    # When doing fair inference, don't sample when given a sample for the
    # baseline gender.
    config.baseline_passthrough = False

    return config
def load_splits(self):
    """Load dataset splits using the tfds loader."""
    self.data_iters = ml_collections.ConfigDict()
    for key, split in self.splits.items():
        logging.info('Loading %s split of the %s dataset.', split.name,
                     self.name)
        ds, num_examples = self.load_split_from_tfds(
            name=self.name,
            batch_size=split['batch_size'],
            train=split['train'],
            split=split.name,
            shuffle_seed=self.shuffle_seed)
        self.splits[key].num_examples = num_examples
        self.data_iters[key] = self.create_data_iter(ds, split['batch_size'])
def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()
    config.seed = 1

    config.dataset = "imagenet-lt"
    config.model_name = "resnet50"
    config.sampling = "uniform"
    config.add_color_jitter = True

    config.loss = "ce"
    config.learning_rate = 0.1
    config.learning_rate_schedule = "cosine"
    config.warmup_epochs = 0
    config.sgd_momentum = 0.9

    if config.dataset == "imagenet-lt":
        config.weight_decay = 0.0005
        config.num_epochs = 90
    elif config.dataset == "inaturalist18":
        config.weight_decay = 0.0002
        config.num_epochs = 200
    else:
        raise ValueError(f"Dataset {config.dataset} not supported.")

    config.global_batch_size = 128
    # If num_train_steps==-1 then the number of training steps is calculated from
    # num_epochs.
    config.num_train_steps = -1
    config.num_eval_steps = -1

    config.log_loss_every_steps = 500
    config.eval_every_steps = 1000
    config.checkpoint_every_steps = 1000
    config.shuffle_buffer_size = 1000

    config.trial = 0  # dummy for repeated runs.

    # Distillation parameters
    config.proj_dim = -1
    config.distill_teacher = ""
    config.distill_alpha = 0.0
    config.distill_fd_beta = 0.0

    return config
def get_b16_none():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.pretrained_filename = "ViT-B_16.npz"
    config.image_size = 224
    config.patch_size = 16
    config.n_layers = 12
    config.hidden_size = 768
    config.n_heads = 12
    config.name = "B16_None"
    config.mlp_dim = 3072
    config.dropout = 0.1
    config.filters = 9
    config.kernel_size = 1
    config.upsampling_factor = 16
    config.hybrid = False
    return config
def set_splits(self):
    """Defines the splits of the dataset.

    For multi-environment datasets, we have three main splits: test, train and
    valid. Each of these splits is a dict mapping environment id to the
    information about that particular split of that dataset:

    self.splits
    |__ Test
    |   |__ Eval env 1
    |   |__ Eval env 2
    |
    |__ Train
    |   |__ Train env 1
    |   |__ Train env 2
    |
    |__ Valid
        |__ Eval env 1
        |__ Eval env 2
    """
    train_split_configs = ml_collections.ConfigDict()
    valid_split_configs = ml_collections.ConfigDict()
    test_split_configs = ml_collections.ConfigDict()

    for env_name in self.train_environments:
        env_id = self.env2id(env_name)
        train_split_configs[str(env_id)] = ml_collections.ConfigDict(
            dict(
                name=self.get_env_split_name('train', env_name),
                batch_size=self.batch_size,
                train=True))

    for env_name in self.eval_environments:
        env_id = self.env2id(env_name)
        valid_split_configs[str(env_id)] = ml_collections.ConfigDict(
            dict(
                name=self.get_env_split_name('validation', env_name),
                batch_size=self.eval_batch_size,
                train=False))
        test_split_configs[str(env_id)] = ml_collections.ConfigDict(
            dict(
                name=self.get_env_split_name('test', env_name),
                batch_size=self.eval_batch_size,
                train=False))

    self.splits = ml_collections.ConfigDict(
        dict(
            test=test_split_configs,
            validation=valid_split_configs,
            train=train_split_configs))
def get_config():
    """Returns a training configuration."""
    config = ml_collections.ConfigDict()
    config.rng_seed = 0
    config.num_trajectories = 1
    config.single_step_predictions = True
    config.num_samples = 1000
    config.split_on = 'times'
    config.train_split_proportion = 80 / 1000
    config.time_delta = 1.
    config.train_time_jump_range = (1, 10)
    config.test_time_jumps = (1, 2, 5, 10, 20, 50)
    config.num_train_steps = 5000
    config.latent_size = 100
    config.activation = 'sigmoid'
    config.model = 'action-angle-network'
    config.encoder_decoder_type = 'flow'
    config.flow_type = 'shear'
    config.num_flow_layers = 10
    config.num_coordinates = 2
    if config.flow_type == 'masked_coupling':
        config.flow_spline_range_min = -3
        config.flow_spline_range_max = 3
        config.flow_spline_bins = 100
    config.polar_action_angles = True
    config.scaler = 'identity'
    config.learning_rate = 1e-3
    config.batch_size = 100
    config.eval_cadence = 50
    config.simulation = 'shm'
    config.regularizations = ml_collections.FrozenConfigDict({
        'actions': 1.,
        'angular_velocities': 0.,
        'encoded_decoded_differences': 0.,
    })
    config.simulation_parameter_ranges = ml_collections.FrozenConfigDict({
        'phi': (0, 0),
        'A': (1, 10),
        'm': (1, 5),
        'w': (0.05, 0.1),
    })
    return config
def test_reference_to_self(self):
    # Test adding a new field to a configdict which is a reference to an
    # existing field in the same configdict instance.
    config = ml_collections.ConfigDict({'parent': 1})
    config_schema_utils.set_default_reference(
        config, config, 'child', parent_field='parent')
    self.assertEqual(config.child, 1)
    self.assertEqual(config.parent, 1)

    config.parent = 5
    self.assertEqual(config.parent, 5)
    self.assertEqual(config.child, 5)

    config.child = 10
    self.assertEqual(config.parent, 5)
    self.assertEqual(config.child, 10)
def setUp(self):
    super().setUp()
    self.config = ml_collections.ConfigDict(self.config)
    self.model_config = self.config.model_config
    encoder_config = self.model_config.encoder_config
    self.max_length = encoder_config.max_length
    self.max_sample_mentions = self.config.max_sample_mentions
    self.collater_fn = text_classifier.TextClassifier.make_collater_fn(
        self.config)
    self.postprocess_fn = text_classifier.TextClassifier.make_output_postprocess_fn(
        self.config)

    model = text_classifier.TextClassifier.build_model(self.model_config)
    dummy_input = text_classifier.TextClassifier.dummy_input(self.config)
    init_rng = jax.random.PRNGKey(0)
    self.init_parameters = model.init(init_rng, dummy_input, True)
def get_lit_l16ti_config():
    """Returns a LiT model with ViT-Large and tiny text towers."""
    config = ml_collections.ConfigDict()
    config.model_name = 'LiT-L16Ti'
    config.out_dim = (None, 1024)

    config.image = get_l16_config()

    config.text_model = 'text_transformer'
    config.text = {}
    config.text.width = 192
    config.text.num_layers = 12
    config.text.mlp_dim = 768
    config.text.num_heads = 3
    config.text.vocab_size = 16_000

    config.pp = {}
    config.pp.tokenizer_name = 'sentencepiece'
    config.pp.size = 224
    config.pp.max_len = 16
    return config
def main(_):
    cfg = ml_collections.ConfigDict()
    cfg.integer_field = 123

    # Locking prohibits the addition and deletion of new fields but allows
    # modification of existing values. Locking happens automatically during
    # loading through flags.
    cfg.lock()
    try:
        cfg.intagar_field = 124  # Raises AttributeError and suggests valid field.
    except AttributeError as e:
        print(e)
    cfg.integer_field = -123  # Works fine.

    with cfg.unlocked():
        cfg.intagar_field = 1555  # Works fine too.
    print(cfg)
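# Hedged companion sketch to the demo above (assumes only the ml_collections
# ConfigDict API already used in this file plus its `is_locked` property; the
# helper name is illustrative): the lock survives a temporary `unlocked()` block.
def _example_lock_status():
    cfg = ml_collections.ConfigDict({'integer_field': 123})
    cfg.lock()
    assert cfg.is_locked
    with cfg.unlocked():
        cfg.extra_field = 'added while temporarily unlocked'
    assert cfg.is_locked  # The lock is restored when the context exits.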
def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()

    # The initial learning rate.
    config.learning_rate = 0.001
    # Learning rate decay, applied each optimization step.
    config.lr_decay = 0.999995

    # Batch size to use for data-dependent initialization.
    config.init_batch_size = 16
    # Batch size for training.
    config.batch_size = 64
    # Number of training epochs.
    config.num_epochs = 200
    # Dropout rate.
    config.dropout_rate = 0.5

    # Number of resnet layers per block.
    config.n_resnet = 5
    # Number of features in each conv layer.
    config.n_feature = 160
    # Number of components in the output distribution.
    config.n_logistic_mix = 10

    # Exponential decay rate of the sum of previous model iterates during Polyak
    # averaging.
    config.polyak_decay = 0.9995

    # Batch size for sampling.
    config.sample_batch_size = 256
    # Random number generator seed for sampling.
    config.sample_rng_seed = 0
    # Integer for PRNG random seed.
    config.seed = 0

    return config
def get_config():
    default_config = dict(
        # General parameters
        function="",  # Specifies the function to be approximated, e.g., SineShirp.
        max=0.0,  # Specifies the maximum value on which the function will be
        # evaluated, e.g., -15.0.
        min=0.0,  # Specifies the minimum value on which the function will be
        # evaluated, e.g., 0.0.
        no_samples=0,  # Specifies the number of samples that will be taken from
        # the selected function, between the min and max values.
        padding=0,  # Specifies an amount of zero padding steps which will be
        # concatenated at the end of the sequence created by
        # function(min, max, no_samples).
        optim="",  # The optimizer to be used, e.g., Adam.
        lr=0.0,  # The lr to be used, e.g., 0.001.
        no_iterations=0,  # The number of training iterations to be executed,
        # e.g., 20000.
        seed=0,  # The seed of the run, e.g., 0.
        device="",  # The device on which the model will be deployed, e.g., cuda.
        # Parameters of ConvKernel
        kernelnet_norm_type="",  # If model == CKCNN, the normalization type to
        # be used in the MLPs parameterizing the convolutional kernels. If
        # kernelnet_activation_function == Sine, no normalization will be used,
        # e.g., LayerNorm.
        kernelnet_activation_function="",  # If model == CKCNN, the activation
        # function used in the MLPs parameterizing the convolutional kernels,
        # e.g., Sine.
        kernelnet_no_hidden=0,  # If model == CKCNN, the number of hidden units
        # used in the MLPs parameterizing the convolutional kernels, e.g., 32.
        # Parameters of SIREN
        kernelnet_omega_0=0.0,  # If model == CKCNN and
        # kernelnet_activation_function == Sine, the value of the omega_0
        # parameter, e.g., 30.
        comment="",  # An additional comment to be added to the config.path
        # parameter specifying where the network parameters will be
        # saved / loaded from.
    )
    default_config = ml_collections.ConfigDict(default_config)
    return default_config
def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()

    # Base output directory for experiments (checkpoints, summaries); './output' if not valid.
    config.output_base_dir = ''
    config.data_dir = '/data/'  # use --config.data_dir arg to set without modifying config file
    config.dataset = 'imagenet2012:5.0.0'
    config.num_classes = 1000  # FIXME not currently used

    config.model = 'tf_efficientnet_b0'
    config.image_size = 0  # set from model defaults if 0
    config.batch_size = 224
    config.eval_batch_size = 100  # set to config.batch_size if 0
    config.lr = 0.016
    config.label_smoothing = 0.1
    config.weight_decay = 1e-5  # l2 weight penalty added to loss
    config.ema_decay = .99997

    config.opt = 'rmsproptf'
    config.opt_eps = .001
    config.opt_beta1 = 0.9
    config.opt_beta2 = 0.9
    config.opt_weight_decay = 0.  # by default, weight decay not applied in opt, l2 penalty above is used

    config.lr_schedule = 'step'
    config.lr_decay_rate = 0.97
    config.lr_decay_epochs = 2.4
    config.lr_warmup_epochs = 5.
    config.lr_minimum = 1e-6
    config.num_epochs = 450

    config.cache = False
    config.half_precision = True
    config.drop_rate = 0.2
    config.drop_path_rate = 0.2

    # If num_train_steps==-1 then the number of training steps is calculated from
    # num_epochs using the entire dataset. Similarly for steps_per_eval.
    config.num_train_steps = -1
    config.steps_per_eval = -1

    return config
def get_cifar_config():
    config = mlc.ConfigDict()
    config.seed = 0

    ###################
    # Train Config
    config.train = mlc.ConfigDict()
    config.train.seed = 0
    config.train.epochs = 90
    config.train.device_batch_size = 64
    config.train.log_epochs = 1

    # Dataset section
    config.train.dataset_name = "cifar10"
    config.train.dataset = mlc.ConfigDict()
    config.train.dataset.train_split = "train"
    config.train.dataset.val_split = "test"
    config.train.dataset.num_classes = 10
    config.train.dataset.input_shape = (32, 32, 3)

    # Model section
    config.train.model_name = "resnet18"
    config.train.model = mlc.ConfigDict()

    # Optimizer section
    config.train.optim = mlc.ConfigDict()
    config.train.optim.optax_name = "trace"  # momentum
    config.train.optim.optax = mlc.ConfigDict()
    config.train.optim.optax.decay = 0.9
    config.train.optim.optax.nesterov = False

    config.train.optim.wd = 1e-4
    config.train.optim.wd_mults = [(".*", 1.0)]
    config.train.optim.grad_clip_norm = 1.0

    # Learning rate section
    config.train.optim.lr = 0.1
    config.train.optim.schedule = mlc.ConfigDict()
    config.train.optim.schedule.warmup_epochs = 5
    config.train.optim.schedule.decay_type = "cosine"
    # Base batch-size being 256.
    config.train.optim.schedule.scale_with_batchsize = True

    return config