import ml_collections

# NOTE: `common_fewshot` provides the shared few-shot-eval defaults used by
# these configs; the exact import path is an assumption and depends on the
# surrounding repository layout.
import common_fewshot


def get_config():
  """Config for training a patch-transformer on JFT."""
  config = ml_collections.ConfigDict()

  # Directory for the version de-dup'd from BiT downstream test-sets.
  config.dataset = 'imagenet21k'
  config.val_split = 'full[:102400]'
  config.train_split = 'full[102400:]'
  config.num_classes = 21843
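  # ~= ln(1/21843) ~= ln(1/num_classes), so the head starts near the label prior.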
  config.init_head_bias = -10.0

  config.trial = 0
  config.batch_size = 1024
  config.num_epochs = 90

  pp_common = '|value_range(-1, 1)'
  config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
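  # Soften the one-hot training targets (on=0.9999 / off=0.0001), i.e. light label smoothing.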
  config.pp_train += f'|onehot({config.num_classes}, on=0.9999, off=0.0001)'
  config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
  config.pp_eval += f'|onehot({config.num_classes})'
  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

  config.log_training_steps = 1000
  config.log_eval_steps = 10000
  # NOTE: eval is very fast O(seconds) so it's fine to run it often.
  config.checkpoint_steps = 17250
  config.checkpoint_timeout = 10

  # Model section
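  # (ViT-B/16-sized backbone: 16x16 patches, width 768, 12 layers, 12 heads.)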
  config.model = ml_collections.ConfigDict()
  config.model.patches = ml_collections.ConfigDict()
  config.model.patches.size = [16, 16]
  config.model.hidden_size = 768
  config.model.transformer = ml_collections.ConfigDict()
  config.model.transformer.attention_dropout_rate = 0.
  config.model.transformer.dropout_rate = 0.1
  config.model.transformer.mlp_dim = 3072
  config.model.transformer.num_heads = 12
  config.model.transformer.num_layers = 12
  config.model.classifier = 'token'  # Or 'gap'
  config.model.representation_size = 768

  # Optimizer section
  config.optim_name = 'Adam'
  config.optim = ml_collections.ConfigDict()
  config.optim.weight_decay = 0.03

  # TODO(lbeyer): make a mini-language like preprocessings.
  config.lr = ml_collections.ConfigDict()
  config.lr.base = 0.001  # LR has to be lower for larger models!
  config.lr.warmup_steps = 10_000
  config.lr.decay_type = 'linear'
  config.lr.linear_end = 1e-5

  # Few-shot eval section
  config.fewshot = common_fewshot.get_fewshot()
  config.fewshot.log_steps = 10_000
  return config
def get_config():
  """Config."""
  config = ml_collections.ConfigDict()

  config.dataset = 'jft/entity:1.0.0'
  config.val_split = 'test[:49511]'  # aka tiny_test/test[:5%] in task_adapt
  config.train_split = 'train'  # task_adapt used train+validation so +64167
  config.num_classes = 18291
  config.init_head_bias = -10.0    # ~= ln(1/18k) ~= ln(1/num_classes)

  config.trial = 0
  config.batch_size = 4096
  config.num_epochs = 7

  pp_common = '|value_range(-1, 1)'
  pp_common += f'|onehot({config.num_classes})'
  # To use ancestor 'smearing', use this line instead:
  # pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels")'  # pylint: disable=line-too-long
  pp_common += '|keep(["image", "labels"])'
  config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
  config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

  config.log_training_steps = 5000
  config.log_eval_steps = 10000
  config.checkpoint_steps = 15000
  config.checkpoint_timeout = 10

  # Model section
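  # (ViT-L/32-sized backbone: 32x32 patches, width 1024, 24 layers, 16 heads.)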
  config.model = ml_collections.ConfigDict()
  config.model.patches = ml_collections.ConfigDict()
  config.model.patches.size = [32, 32]
  config.model.hidden_size = 1024
  config.model.transformer = ml_collections.ConfigDict()
  config.model.transformer.attention_dropout_rate = 0.
  config.model.transformer.dropout_rate = 0.
  config.model.transformer.mlp_dim = 4096
  config.model.transformer.num_heads = 16
  config.model.transformer.num_layers = 24
  config.model.classifier = 'token'  # Or 'gap'
  config.model.representation_size = 1024

  # Heteroscedastic
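  # Low-rank heteroscedastic head: num_factors covariance factors, mc_samples
  # Monte Carlo samples to integrate over the noisy logits, and a fixed softmax
  # temperature; param_efficient selects the parameter-efficient variant.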
  config.model.multiclass = False
  config.model.temperature = 0.2
  config.model.mc_samples = 100
  config.model.num_factors = 50
  config.model.param_efficient = True

  # BatchEnsemble
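  # Only the listed (last three) transformer blocks are converted to
  # BatchEnsemble layers, each holding ens_size rank-1 fast weights.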
  config.model.transformer.be_layers = (21, 22, 23)
  config.model.transformer.ens_size = 3
  config.model.transformer.random_sign_init = -0.5

  # GP
  config.model.use_gp = False
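  # The GP-head settings below only take effect when use_gp is enabled.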
  # Use momentum-based (i.e., non-exact) covariance update for pre-training.
  # This is because the exact covariance update can be unstable for pretraining,
  # since it involves inverting a precision matrix accumulated over ~300M data points.
  config.model.covmat_momentum = .999
  config.model.ridge_penalty = 1.
  # No need to use mean field adjustment for pretraining.
  config.model.mean_field_factor = -1.

  # Optimizer section
  config.optim_name = 'Adam'
  config.optim = ml_collections.ConfigDict()
  config.optim.weight_decay = 0.1
  config.optim.beta1 = 0.9
  config.optim.beta2 = 0.999
  config.weight_decay = None  # No explicit weight decay
  config.grad_clip_norm = 1.0  # Setting from Mark's recommendation.
  config.lr = ml_collections.ConfigDict()
  config.lr.base = 6e-4  # LR has to be lower for larger models!
  config.lr.warmup_steps = 10_000
  config.lr.decay_type = 'linear'
  config.lr.linear_end = 1e-5

  # Few-shot eval section
  config.fewshot = common_fewshot.get_fewshot()
  config.fewshot.representation_layer = 'pre_ens_logits'
  config.fewshot.log_steps = 100_000
  return config
def get_config():
    """Config."""
    config = ml_collections.ConfigDict()

    config.seed = 0

    config.dataset = 'imagenet21k'
    config.val_split = 'full[:102400]'
    config.train_split = 'full[102400:]'
    config.num_classes = 21843
    config.init_head_bias = -10.0

    config.trial = 0
    config.batch_size = 4096
    config.num_epochs = 90

    pp_common = '|value_range(-1, 1)'
    config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
    config.pp_train += f'|onehot({config.num_classes}, on=0.9999, off=0.0001)'
    config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
    config.pp_eval += f'|onehot({config.num_classes})'
    config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

    config.log_training_steps = 50
    config.log_eval_steps = 1000
    config.checkpoint_steps = 5000

    # Model section
    config.model = ml_collections.ConfigDict()
    config.model.patches = ml_collections.ConfigDict()
    config.model.patches.size = [32, 32]
    config.model.hidden_size = 1024
    config.model.transformer = ml_collections.ConfigDict()
    config.model.transformer.attention_dropout_rate = 0.
    config.model.transformer.dropout_rate = 0.1
    config.model.transformer.mlp_dim = 4096
    config.model.transformer.num_heads = 16
    config.model.transformer.num_layers = 24
    config.model.classifier = 'token'  # Or 'gap'
    config.model.representation_size = 1024

    # BatchEnsemble parameters.
    config.model.transformer.be_layers = (21, 22, 23)
    config.model.transformer.ens_size = 3
    config.model.transformer.random_sign_init = -0.5
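    # Learning-rate multiplier applied to the BatchEnsemble fast weights.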
    config.fast_weight_lr_multiplier = 1.0

    # Optimizer section
    config.optim_name = 'Adam'
    config.optim = ml_collections.ConfigDict()
    config.optim.beta1 = 0.9
    config.optim.beta2 = 0.999
    config.weight_decay = 0.1
    config.grad_clip_norm = 1.0

    config.lr = ml_collections.ConfigDict()
    config.lr.base = 0.001  # LR has to be lower for larger models!
    config.lr.warmup_steps = 10_000
    config.lr.decay_type = 'linear'
    config.lr.linear_end = 1e-5

    # Few-shot eval section
    config.fewshot = common_fewshot.get_fewshot()
    config.fewshot.log_steps = 25_000
    return config
def get_config():
  """Config for training a patch-transformer on JFT."""
  config = ml_collections.ConfigDict()

  config.seed = 0

  # Directory for the version de-dup'd from BiT downstream test-sets.
  config.dataset = 'jft/entity:1.0.0'
  config.val_split = 'test[:49511]'  # aka tiny_test/test[:5%] in task_adapt
  config.train_split = 'train'  # task_adapt used train+validation so +64167
  config.num_classes = 18291
  config.init_head_bias = -10.0

  config.trial = 0
  config.batch_size = 4096
  config.num_epochs = 7

  pp_common = '|value_range(-1, 1)'
  pp_common += f'|onehot({config.num_classes})'
  # To use ancestor "smearing", use this line instead:
  # pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels")'  # pylint: disable=line-too-long
  pp_common += '|keep(["image", "labels"])'
  config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
  config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

  config.log_training_steps = 10000
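  # ~300M JFT train examples / 4096 global batch ~= 73K steps per epoch.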
  config.log_eval_steps = 73230  # ~= steps_per_epoch
  # NOTE: Save infrequently to avoid filling up the disk.
  config.checkpoint_steps = 17250
  config.checkpoint_timeout = 10

  # Model section
  config.model = ml_collections.ConfigDict()
  config.model.patches = ml_collections.ConfigDict()
  config.model.patches.size = [32, 32]
  config.model.hidden_size = 1024
  config.model.transformer = ml_collections.ConfigDict()
  config.model.transformer.attention_dropout_rate = 0.
  config.model.transformer.dropout_rate = 0.
  config.model.transformer.mlp_dim = 4096
  config.model.transformer.num_heads = 16
  config.model.transformer.num_layers = 24
  config.model.classifier = 'token'  # Or 'gap'
  config.model.representation_size = 1024

  # Heteroscedastic
  config.het = ml_collections.ConfigDict()
  config.het.multiclass = False
  config.het.temperature = 1.5
  config.het.mc_samples = 1000
  config.het.num_factors = 50
  config.het.param_efficient = True

  # Gaussian process layer section
  config.gp_layer = ml_collections.ConfigDict()
  # Use momentum-based (i.e., non-exact) covariance update for pre-training.
  # This is because the exact covariance update can be unstable for pretraining,
  # since it involves inverting a precision matrix accumulated over ~300M data points.
  config.gp_layer.covmat_momentum = .999
  config.gp_layer.ridge_penalty = 1.
  # No need to use mean field adjustment for pretraining.
  config.gp_layer.mean_field_factor = -1.
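  # (The negative value switches the adjustment off.)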

  # Optimizer section
  config.optim_name = 'Adam'
  config.optim = ml_collections.ConfigDict()
  config.optim.weight_decay = 0.1
  config.optim.beta1 = 0.9
  config.optim.beta2 = 0.999
  config.weight_decay = None  # No explicit weight decay
  config.grad_clip_norm = 1.0

  # TODO(lbeyer): make a mini-language like preprocessings.
  config.lr = ml_collections.ConfigDict()
  config.lr.base = 6e-4  # LR has to be lower for larger models!
  config.lr.warmup_steps = 10_000
  config.lr.decay_type = 'linear'
  config.lr.linear_end = 1e-5

  # Few-shot eval section
  config.fewshot = common_fewshot.get_fewshot()
  config.fewshot.log_steps = 50_000
  return config
def get_config():
    """Config for training a patch-transformer on JFT."""
    config = ml_collections.ConfigDict()

    config.seed = 0

    # Directory for the version de-dup'd from BiT downstream test-sets.
    config.dataset = 'jft/entity:1.0.0'
    config.val_split = 'test[:49511]'  # aka tiny_test/test[:5%] in task_adapt
    config.train_split = 'train'  # task_adapt used train+validation so +64167
    config.num_classes = 18291
    config.init_head_bias = -10.0

    config.trial = 0
    config.batch_size = 4096
    config.num_epochs = 5

    pp_common = '|value_range(-1, 1)'
    pp_common += f'|onehot({config.num_classes})'
    # To use ancestor 'smearing', use this line instead:
    # pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels")'  # pylint: disable=line-too-long
    pp_common += '|keep(["image", "labels"])'
    config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
    config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
    config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

    config.log_training_steps = 50
    config.log_eval_steps = 1000
    # NOTE: eval is very fast O(seconds) so it's fine to run it often.
    config.checkpoint_steps = 1000

    # Model section
    config.model = ml_collections.ConfigDict()
    config.model.patches = ml_collections.ConfigDict()
    config.model.patches.size = [32, 32]
    config.model.hidden_size = 512
    config.model.transformer = ml_collections.ConfigDict()
    config.model.transformer.attention_dropout_rate = 0.
    config.model.transformer.dropout_rate = 0.
    config.model.transformer.mlp_dim = 2048
    config.model.transformer.num_heads = 8
    config.model.transformer.num_layers = 8
    config.model.classifier = 'token'  # Or 'gap'
    config.model.representation_size = 512

    # Optimizer section
    config.optim_name = 'Adam'
    config.optim = ml_collections.ConfigDict()
    config.optim.weight_decay = 0.1
    config.optim.beta1 = 0.9
    config.optim.beta2 = 0.999
    config.weight_decay = None  # No explicit weight decay

    # TODO(lbeyer): make a mini-language like preprocessings.
    config.lr = ml_collections.ConfigDict()
    config.lr.base = 0.001
    config.lr.warmup_steps = 10_000
    config.lr.decay_type = 'linear'
    config.lr.linear_end = 1e-5

    # Few-shot eval section
    config.fewshot = common_fewshot.get_fewshot()
    config.fewshot.log_steps = 25_000
    return config
def get_config():
    """Config."""
    config = ml_collections.ConfigDict()

    config.seed = 0

    # JFT parameters.
    config.dataset = 'jft/entity:1.0.0'
    config.val_split = 'test[:49511]'  # aka tiny_test/test[:5%] in task_adapt
    config.train_split = 'train'  # task_adapt used train+validation so +64167
    config.num_classes = 18291
    config.init_head_bias = -10.0  # ~= ln(1/18k) ~= ln(1/num_classes)

    pp_common = '|value_range(-1, 1)'
    pp_common += f'|onehot({config.num_classes})'
    # To use ancestor 'smearing', use this line instead:
    # pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels")'  # pylint: disable=line-too-long
    pp_common += '|keep(["image", "labels"])'
    config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
    config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
    config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

    # Model parameters.
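    # (ViT-B/32-sized backbone: 32x32 patches, width 768, 12 layers, 12 heads.)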
    config.model = ml_collections.ConfigDict()
    config.model.patches = ml_collections.ConfigDict()
    config.model.patches.size = [32, 32]
    config.model.hidden_size = 768
    config.model.transformer = ml_collections.ConfigDict()
    config.model.transformer.attention_dropout_rate = 0.
    config.model.transformer.dropout_rate = 0.
    config.model.transformer.mlp_dim = 3072
    config.model.transformer.num_heads = 12
    config.model.transformer.num_layers = 12
    config.model.classifier = 'token'  # Or 'gap'
    config.model.representation_size = 768

    # Optimizer section
    config.optim_name = 'Adam'
    config.optim = ml_collections.ConfigDict()
    config.optim.beta1 = 0.9
    config.optim.beta2 = 0.999
    config.weight_decay = 0.1

    config.lr = ml_collections.ConfigDict()
    config.lr.base = 8e-4  # LR likely has to be lower for larger models!
    config.lr.warmup_steps = 10_000
    config.lr.decay_type = 'linear'
    config.lr.linear_end = 1e-5
    config.disable_preemption_reproducibility = True

    config.batch_size = 4096  # Global batch size.
    config.num_epochs = 7

    config.log_training_steps = 50
    config.log_eval_steps = 1000

    config.checkpoint_steps = 5000
    config.checkpoint_timeout = 10

    config.prefetch_to_device = 2
    config.trial = 0

    # Few-shot eval section
    config.fewshot = common_fewshot.get_fewshot()
    config.fewshot.log_steps = 25_000
    return config
def get_config():
  """Config for training a patch-transformer on JFT."""
  config = ml_collections.ConfigDict()

  config.seed = 0

  config.dataset = 'imagenet21k'
  config.val_split = 'full[:102400]'
  config.train_split = 'full[102400:]'
  config.num_classes = 21843
  config.init_head_bias = -10.0

  config.trial = 0
  config.batch_size = 4096
  config.num_epochs = 90

  pp_common = '|value_range(-1, 1)'
  config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
  config.pp_train += f'|onehot({config.num_classes}, on=0.9999, off=0.0001)'
  config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
  config.pp_eval += f'|onehot({config.num_classes})'
  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

  config.log_training_steps = 10000
  config.log_eval_steps = 3003  # ~= steps_per_epoch
  # NOTE: Save infrequently to avoid filling up the disk.
  config.checkpoint_steps = 17250
  config.checkpoint_timeout = 10

  # Model section
  config.model = ml_collections.ConfigDict()
  config.model.patches = ml_collections.ConfigDict()
  config.model.patches.size = [32, 32]
  config.model.hidden_size = 1024
  config.model.transformer = ml_collections.ConfigDict()
  config.model.transformer.attention_dropout_rate = 0.
  config.model.transformer.dropout_rate = 0.1
  config.model.transformer.mlp_dim = 4096
  config.model.transformer.num_heads = 16
  config.model.transformer.num_layers = 24
  config.model.classifier = 'token'  # Or 'gap'
  config.model.representation_size = 1024

  # Heteroscedastic
  config.het = ml_collections.ConfigDict()
  config.het.multiclass = False
  config.het.temperature = 1.5
  config.het.mc_samples = 1000
  config.het.num_factors = 50
  config.het.param_efficient = True

  # Gaussian process layer section
  config.gp_layer = ml_collections.ConfigDict()
  # Use momentum-based (i.e., non-exact) covariance update for pre-training.
  # This is because the exact covariance update can be unstable for pretraining,
  # since it involves inverting a precision matrix accumulated over ~300M data points.
  config.gp_layer.covmat_momentum = .999
  config.gp_layer.ridge_penalty = 1.
  # No need to use mean field adjustment for pretraining.
  config.gp_layer.mean_field_factor = -1.

  # Optimizer section
  config.optim_name = 'Adam'
  config.optim = ml_collections.ConfigDict()
  config.optim.weight_decay = 0.03
  config.grad_clip_norm = 1.0
  config.optim.beta1 = 0.9
  config.optim.beta2 = 0.999

  # TODO(lbeyer): make a mini-language like preprocessings.
  config.lr = ml_collections.ConfigDict()
  # LR has to be lower with the GP layer and for larger models.
  config.lr.base = 6e-4
  config.lr.warmup_steps = 10_000
  config.lr.decay_type = 'linear'
  config.lr.linear_end = 1e-5

  # Few-shot eval section
  config.fewshot = common_fewshot.get_fewshot()
  config.fewshot.log_steps = 50_000
  return config
def get_config():
    """Config."""
    config = ml_collections.ConfigDict()

    config.seed = 0

    config.dataset = 'jft/entity:1.0.0'
    config.val_split = 'test[:49511]'  # aka tiny_test/test[:5%] in task_adapt
    config.train_split = 'train'  # task_adapt used train+validation so +64167
    config.num_classes = 18291
    config.init_head_bias = -10.0

    config.resume = '/cns/tp-d/home/trandustin/baselines-jft-0211_032549/1/checkpoint.npz'
    config.trial = 0
    config.batch_size = 4096
    config.num_epochs = 14
    config.prefetch_to_device = 2
    # TODO(trandustin): To resume properly, I removed this setting. Not sure what
    # it's doing.
    # config.disable_preemption_reproducibility = True

    pp_common = '|value_range(-1, 1)'
    pp_common += f'|onehot({config.num_classes})'
    pp_common += '|keep(["image", "labels"])'
    config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
    config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
    config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

    config.log_training_steps = 50
    config.log_eval_steps = 1000
    config.checkpoint_steps = 5000
    config.checkpoint_timeout = 10

    # Model section
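    # (ViT-H/14-sized backbone: 14x14 patches, width 1280, 32 layers, 16 heads;
    # see the ViT-H/14 note in the LR section below.)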
    config.model = ml_collections.ConfigDict()
    config.model.patches = ml_collections.ConfigDict()
    config.model.patches.size = [14, 14]
    config.model.hidden_size = 1280
    config.model.transformer = ml_collections.ConfigDict()
    config.model.transformer.attention_dropout_rate = 0.
    config.model.transformer.dropout_rate = 0.
    config.model.transformer.mlp_dim = 5120
    config.model.transformer.num_heads = 16
    config.model.transformer.num_layers = 32
    config.model.classifier = 'token'  # Or 'gap'
    config.model.representation_size = 1280

    # BatchEnsemble section
    # Using the last n=5 layers was chosen somewhat arbitrarily (more than the 3 used for L/32).
    config.model.transformer.be_layers = (27, 28, 29, 30, 31)
    config.model.transformer.ens_size = 3
    config.model.transformer.random_sign_init = 0.5
    config.fast_weight_lr_multiplier = 1.0

    # Optimizer section
    # We use Adam HP to lower memory.
    config.optim_name = 'adam_hp'
    config.optim = ml_collections.ConfigDict()
    config.optim.beta1 = 0.9
    config.optim.beta2 = 0.999
    config.weight_decay = 0.1
    config.grad_clip_norm = 10.0

    config.lr = ml_collections.ConfigDict()
    # Note original ViT-H/14 uses 4e-4 and no grad clip norm until 130K steps,
    # then 3e-4 and grad_clip_norm=10.0 after.
    config.lr.base = 3e-4  # LR has to be lower for larger models!
    config.lr.warmup_steps = 10_000
    config.lr.decay_type = 'linear'
    config.lr.linear_end = 1e-5

    # Few-shot eval section
    config.fewshot = common_fewshot.get_fewshot()
    config.fewshot.log_steps = 25_000
    return config
def get_config():
    """Config for training on JFT300M. Batch size 4096 fits on DF4x4."""
    config = ml_collections.ConfigDict()

    config.seed = 0

    # JFT parameters.
    config.dataset = 'jft/entity:1.0.0'
    config.val_split = 'test[:49511]'  # aka tiny_test/test[:5%] in task_adapt
    config.train_split = 'train'  # task_adapt used train+validation so +64167
    config.num_classes = 18291
    config.init_head_bias = -10.0  # ~= ln(1/18k) ~= ln(1/num_classes)

    pp_common = '|value_range(-1, 1)'
    pp_common += f'|onehot({config.num_classes})'
    # To use ancestor 'smearing', use this line instead:
    # pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels")'  # pylint: disable=line-too-long
    pp_common += '|keep(["image", "labels"])'
    config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
    config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
    config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.

    # Model section
    config.model = ml_collections.ConfigDict()
    config.model.patches = ml_collections.ConfigDict()
    config.model.patches.size = [32, 32]
    config.model.hidden_size = 512
    config.model.representation_size = 512
    config.model.classifier = 'token'
    config.model.transformer = ml_collections.ConfigDict()
    config.model.transformer.num_layers = 8
    config.model.transformer.dropout_rate = 0.0
    config.model.transformer.mlp_dim = 2048
    config.model.transformer.num_heads = 8
    config.model.transformer.attention_dropout_rate = 0.0

    # BatchEnsemble parameters
    config.model.transformer.be_layers = (5, 6, 7)
    config.model.transformer.ens_size = 2
    config.model.transformer.random_sign_init = -0.5
    config.fast_weight_lr_multiplier = 1.0

    # Optimizer parameters.
    config.optim_name = 'Adam'
    config.optim = ml_collections.ConfigDict(dict(beta1=0.9, beta2=0.999))
    config.weight_decay = 0.1
    # TODO(trandustin): Potentially add weight decay only on slow weights, similar
    # to original BE and ghassen's BE ViT code.
    # config.weight_decay = [0.1]
    # config.weight_decay_pattern = [".*/kernel"]  # Does not decay fast-weights.
    config.grad_clip_norm = None

    config.lr = ml_collections.ConfigDict()
    config.lr.base = 1e-3  # LR likely has to be lower for larger models!
    config.lr.warmup_steps = 10_000
    config.lr.decay_type = 'linear'
    config.lr.linear_end = 1e-5

    config.batch_size = 4096  # Global batch size.
    config.num_epochs = 5

    config.log_training_steps = 50
    config.log_eval_steps = 1000

    config.checkpoint_steps = 5000
    config.checkpoint_timeout = 10

    config.prefetch_to_device = 2
    config.trial = 0

    # Few-shot eval section
    config.fewshot = common_fewshot.get_fewshot()
    config.fewshot.log_steps = 25_000
    return config
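

# Usage sketch (an assumption, not taken from the snippets above): these
# get_config() files are typically consumed through ml_collections'
# config_flags, with the training binary pointed at one config file on the
# command line. Only the flag wiring and ConfigDict access are shown; the
# trainer itself is hypothetical.
from absl import app
from absl import flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file('config', None, 'Path to a get_config() file.')


def main(argv):
  del argv  # Unused.
  config = FLAGS.config  # The ConfigDict built by get_config().
  # Read a few of the fields set above; derived quantities such as steps per
  # epoch would come from config.batch_size and the train-split size.
  print(config.model.transformer.num_layers, config.lr.base, config.num_epochs)


if __name__ == '__main__':
  app.run(main)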