def get_config(): """Config for training a patch-transformer on JFT.""" config = ml_collections.ConfigDict() # Directory for the version de-dup'd from BiT downstream test-sets. config.dataset = 'imagenet21k' config.val_split = 'full[:102400]' config.train_split = 'full[102400:]' config.num_classes = 21843 config.init_head_bias = -10.0 config.trial = 0 config.batch_size = 1024 config.num_epochs = 90 pp_common = '|value_range(-1, 1)' config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common config.pp_train += f'|onehot({config.num_classes}, on=0.9999, off=0.0001)' config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common config.pp_eval += f'|onehot({config.num_classes})' config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. config.log_training_steps = 1000 config.log_eval_steps = 10000 # NOTE: eval is very fast O(seconds) so it's fine to run it often. config.checkpoint_steps = 17250 config.checkpoint_timeout = 10 # Model section config.model = ml_collections.ConfigDict() config.model.patches = ml_collections.ConfigDict() config.model.patches.size = [16, 16] config.model.hidden_size = 768 config.model.transformer = ml_collections.ConfigDict() config.model.transformer.attention_dropout_rate = 0. config.model.transformer.dropout_rate = 0.1 config.model.transformer.mlp_dim = 3072 config.model.transformer.num_heads = 12 config.model.transformer.num_layers = 12 config.model.classifier = 'token' # Or 'gap' config.model.representation_size = 768 # Optimizer section config.optim_name = 'Adam' config.optim = ml_collections.ConfigDict() config.optim.weight_decay = 0.03 # TODO(lbeyer): make a mini-language like preprocessings. config.lr = ml_collections.ConfigDict() config.lr.base = 0.001 # LR has to be lower for larger models! config.lr.warmup_steps = 10_000 config.lr.decay_type = 'linear' config.lr.linear_end = 1e-5 # Few-shot eval section config.fewshot = common_fewshot.get_fewshot() config.fewshot.log_steps = 10_000 return config
def get_config(): """Config.""" config = ml_collections.ConfigDict() config.dataset = 'jft/entity:1.0.0' config.val_split = 'test[:49511]' # aka tiny_test/test[:5%] in task_adapt config.train_split = 'train' # task_adapt used train+validation so +64167 config.num_classes = 18291 config.init_head_bias = -10.0 # ~= ln(1/18k) ~= ln(1/num_classes) config.trial = 0 config.batch_size = 4096 config.num_epochs = 7 pp_common = '|value_range(-1, 1)' pp_common += f'|onehot({config.num_classes})' # To use ancestor 'smearing', use this line instead: # pp_common += f'|onehot({config.num_classes}, key='labels_extended', key_result='labels') # pylint: disable=line-too-long pp_common += '|keep(["image", "labels"])' config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. config.log_training_steps = 5000 config.log_eval_steps = 10000 config.checkpoint_steps = 15000 config.checkpoint_timeout = 10 # Model section config.model = ml_collections.ConfigDict() config.model.patches = ml_collections.ConfigDict() config.model.patches.size = [32, 32] config.model.hidden_size = 1024 config.model.transformer = ml_collections.ConfigDict() config.model.transformer.attention_dropout_rate = 0. config.model.transformer.dropout_rate = 0. config.model.transformer.mlp_dim = 4096 config.model.transformer.num_heads = 16 config.model.transformer.num_layers = 24 config.model.classifier = 'token' # Or 'gap' config.model.representation_size = 1024 # Heteroscedastic config.model.multiclass = False config.model.temperature = 0.2 config.model.mc_samples = 100 config.model.num_factors = 50 config.model.param_efficient = True # BatchEnsemble config.model.transformer.be_layers = (21, 22, 23) config.model.transformer.ens_size = 3 config.model.transformer.random_sign_init = -0.5 # GP config.model.use_gp = False # Use momentum-based (i.e., non-exact) covariance update for pre-training. # This is because the exact covariance update can be unstable for pretraining, # since it involves inverting a precision matrix accumulated over 300M data. config.model.covmat_momentum = .999 config.model.ridge_penalty = 1. # No need to use mean field adjustment for pretraining. config.model.mean_field_factor = -1. # Optimizer section config.optim_name = 'Adam' config.optim = ml_collections.ConfigDict() config.optim.weight_decay = 0.1 config.optim.beta1 = 0.9 config.optim.beta2 = 0.999 config.weight_decay = None # No explicit weight decay config.grad_clip_norm = 1.0 # setting from mark's rec config.lr = ml_collections.ConfigDict() config.lr.base = 6e-4 # LR has to be lower for larger models! config.lr.warmup_steps = 10_000 config.lr.decay_type = 'linear' config.lr.linear_end = 1e-5 # Few-shot eval section config.fewshot = common_fewshot.get_fewshot() config.fewshot.representation_layer = 'pre_ens_logits' config.fewshot.log_steps = 100_000 return config
def get_config(): """Config.""" config = ml_collections.ConfigDict() config.seed = 0 config.dataset = 'imagenet21k' config.val_split = 'full[:102400]' config.train_split = 'full[102400:]' config.num_classes = 21843 config.init_head_bias = -10.0 config.trial = 0 config.batch_size = 4096 config.num_epochs = 90 pp_common = '|value_range(-1, 1)' config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common config.pp_train += f'|onehot({config.num_classes}, on=0.9999, off=0.0001)' config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common config.pp_eval += f'|onehot({config.num_classes})' config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. config.log_training_steps = 50 config.log_eval_steps = 1000 config.checkpoint_steps = 5000 # Model section config.model = ml_collections.ConfigDict() config.model.patches = ml_collections.ConfigDict() config.model.patches.size = [32, 32] config.model.hidden_size = 1024 config.model.transformer = ml_collections.ConfigDict() config.model.transformer.attention_dropout_rate = 0. config.model.transformer.dropout_rate = 0.1 config.model.transformer.mlp_dim = 4096 config.model.transformer.num_heads = 16 config.model.transformer.num_layers = 24 config.model.classifier = 'token' # Or 'gap' config.model.representation_size = 1024 # BatchEnsemble parameters. config.model.transformer.be_layers = (21, 22, 23) config.model.transformer.ens_size = 3 config.model.transformer.random_sign_init = -0.5 config.fast_weight_lr_multiplier = 1.0 # Optimizer section config.optim_name = 'Adam' config.optim = ml_collections.ConfigDict() config.optim.beta1 = 0.9 config.optim.beta2 = 0.999 config.weight_decay = 0.1 config.grad_clip_norm = 1.0 config.lr = ml_collections.ConfigDict() config.lr.base = 0.001 # LR has to be lower for larger models! config.lr.warmup_steps = 10_000 config.lr.decay_type = 'linear' config.lr.linear_end = 1e-5 # Few-shot eval section config.fewshot = common_fewshot.get_fewshot() config.fewshot.log_steps = 25_000 return config
def get_config(): """Config for training a patch-transformer on JFT.""" config = ml_collections.ConfigDict() config.seed = 0 # Directory for the version de-dup'd from BiT downstream test-sets. config.dataset = 'jft/entity:1.0.0' config.val_split = 'test[:49511]' # aka tiny_test/test[:5%] in task_adapt config.train_split = 'train' # task_adapt used train+validation so +64167 config.num_classes = 18291 config.init_head_bias = -10.0 config.trial = 0 config.batch_size = 4096 config.num_epochs = 7 pp_common = '|value_range(-1, 1)' pp_common += f'|onehot({config.num_classes})' # To use ancestor "smearing", use this line instead: # pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels") # pylint: disable=line-too-long pp_common += '|keep(["image", "labels"])' config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. config.log_training_steps = 10000 config.log_eval_steps = 73230 # ~= steps_per_epoch # NOTE: Save infrequently to prevent crowding the disk space. config.checkpoint_steps = 17250 config.checkpoint_timeout = 10 # Model section config.model = ml_collections.ConfigDict() config.model.patches = ml_collections.ConfigDict() config.model.patches.size = [32, 32] config.model.hidden_size = 1024 config.model.transformer = ml_collections.ConfigDict() config.model.transformer.attention_dropout_rate = 0. config.model.transformer.dropout_rate = 0. config.model.transformer.mlp_dim = 4096 config.model.transformer.num_heads = 16 config.model.transformer.num_layers = 24 config.model.classifier = 'token' # Or 'gap' config.model.representation_size = 1024 # Heteroscedastic config.het = ml_collections.ConfigDict() config.het.multiclass = False config.het.temperature = 1.5 config.het.mc_samples = 1000 config.het.num_factors = 50 config.het.param_efficient = True # Gaussian process layer section config.gp_layer = ml_collections.ConfigDict() # Use momentum-based (i.e., non-exact) covariance update for pre-training. # This is because the exact covariance update can be unstable for pretraining, # since it involves inverting a precision matrix accumulated over 300M data. config.gp_layer.covmat_momentum = .999 config.gp_layer.ridge_penalty = 1. # No need to use mean field adjustment for pretraining. config.gp_layer.mean_field_factor = -1. # Optimizer section config.optim_name = 'Adam' config.optim = ml_collections.ConfigDict() config.optim.weight_decay = 0.1 config.optim.beta1 = 0.9 config.optim.beta2 = 0.999 config.weight_decay = None # No explicit weight decay config.grad_clip_norm = 1.0 # TODO(lbeyer): make a mini-language like preprocessings. config.lr = ml_collections.ConfigDict() config.lr.base = 6e-4 # LR has to be lower for larger models! config.lr.warmup_steps = 10_000 config.lr.decay_type = 'linear' config.lr.linear_end = 1e-5 # Few-shot eval section config.fewshot = common_fewshot.get_fewshot() config.fewshot.log_steps = 50_000 return config
def get_config(): """Config for training a patch-transformer on JFT.""" config = ml_collections.ConfigDict() config.seed = 0 # Directory for the version de-dup'd from BiT downstream test-sets. config.dataset = 'jft/entity:1.0.0' config.val_split = 'test[:49511]' # aka tiny_test/test[:5%] in task_adapt config.train_split = 'train' # task_adapt used train+validation so +64167 config.num_classes = 18291 config.init_head_bias = -10.0 config.trial = 0 config.batch_size = 4096 config.num_epochs = 5 pp_common = '|value_range(-1, 1)' pp_common += f'|onehot({config.num_classes})' # To use ancestor 'smearing', use this line instead: # pp_common += f'|onehot({config.num_classes}, key='labels_extended', key_result='labels') # pylint: disable=line-too-long pp_common += '|keep(["image", "labels"])' config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. config.log_training_steps = 50 config.log_eval_steps = 1000 # NOTE: eval is very fast O(seconds) so it's fine to run it often. config.checkpoint_steps = 1000 # Model section config.model = ml_collections.ConfigDict() config.model.patches = ml_collections.ConfigDict() config.model.patches.size = [32, 32] config.model.hidden_size = 512 config.model.transformer = ml_collections.ConfigDict() config.model.transformer.attention_dropout_rate = 0. config.model.transformer.dropout_rate = 0. config.model.transformer.mlp_dim = 2048 config.model.transformer.num_heads = 8 config.model.transformer.num_layers = 8 config.model.classifier = 'token' # Or 'gap' config.model.representation_size = 512 # Optimizer section config.optim_name = 'Adam' config.optim = ml_collections.ConfigDict() config.optim.weight_decay = 0.1 config.optim.beta1 = 0.9 config.optim.beta2 = 0.999 config.weight_decay = None # No explicit weight decay # TODO(lbeyer): make a mini-language like preprocessings. config.lr = ml_collections.ConfigDict() config.lr.base = 0.001 config.lr.warmup_steps = 10_000 config.lr.decay_type = 'linear' config.lr.linear_end = 1e-5 # Few-shot eval section config.fewshot = common_fewshot.get_fewshot() config.fewshot.log_steps = 25_000 return config
def get_config(): """Config.""" config = ml_collections.ConfigDict() config.seed = 0 # JFT parameters. config.dataset = 'jft/entity:1.0.0' config.val_split = 'test[:49511]' # aka tiny_test/test[:5%] in task_adapt config.train_split = 'train' # task_adapt used train+validation so +64167 config.num_classes = 18291 config.init_head_bias = -10.0 # ~= ln(1/18k) ~= ln(1/num_classes) pp_common = '|value_range(-1, 1)' pp_common += f'|onehot({config.num_classes})' # To use ancestor 'smearing', use this line instead: # pp_common += f'|onehot({config.num_classes}, key='labels_extended', key_result='labels') # pylint: disable=line-too-long pp_common += '|keep(["image", "labels"])' config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. # Model parameters. config.model = ml_collections.ConfigDict() config.model.patches = ml_collections.ConfigDict() config.model.patches.size = [32, 32] config.model.hidden_size = 768 config.model.transformer = ml_collections.ConfigDict() config.model.transformer.attention_dropout_rate = 0. config.model.transformer.dropout_rate = 0. config.model.transformer.mlp_dim = 3072 config.model.transformer.num_heads = 12 config.model.transformer.num_layers = 12 config.model.classifier = 'token' # Or 'gap' config.model.representation_size = 768 # Optimizer section config.optim_name = 'Adam' config.optim = ml_collections.ConfigDict() config.optim.beta1 = 0.9 config.optim.beta2 = 0.999 config.weight_decay = 0.1 config.lr = ml_collections.ConfigDict() config.lr.base = 8e-4 # LR likely has to be lower for larger models! config.lr.warmup_steps = 10_000 config.lr.decay_type = 'linear' config.lr.linear_end = 1e-5 config.disable_preemption_reproducibility = True config.batch_size = 4096 # Global batch size. config.num_epochs = 7 config.log_training_steps = 50 config.log_eval_steps = 1000 config.checkpoint_steps = 5000 config.checkpoint_timeout = 10 config.prefetch_to_device = 2 config.trial = 0 # Few-shot eval section config.fewshot = common_fewshot.get_fewshot() config.fewshot.log_steps = 25_000 return config
def get_config(): """Config for training a patch-transformer on JFT.""" config = ml_collections.ConfigDict() config.seed = 0 config.dataset = 'imagenet21k' config.val_split = 'full[:102400]' config.train_split = 'full[102400:]' config.num_classes = 21843 config.init_head_bias = -10.0 config.trial = 0 config.batch_size = 4096 config.num_epochs = 90 pp_common = '|value_range(-1, 1)' config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common config.pp_train += f'|onehot({config.num_classes}, on=0.9999, off=0.0001)' config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common config.pp_eval += f'|onehot({config.num_classes})' config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. config.log_training_steps = 10000 config.log_eval_steps = 3003 # ~= steps_per_epoch # NOTE: Save infrequently to prevent crowding the disk space. config.checkpoint_steps = 17250 config.checkpoint_timeout = 10 # Model section config.model = ml_collections.ConfigDict() config.model.patches = ml_collections.ConfigDict() config.model.patches.size = [32, 32] config.model.hidden_size = 1024 config.model.transformer = ml_collections.ConfigDict() config.model.transformer.attention_dropout_rate = 0. config.model.transformer.dropout_rate = 0.1 config.model.transformer.mlp_dim = 4096 config.model.transformer.num_heads = 16 config.model.transformer.num_layers = 24 config.model.classifier = 'token' # Or 'gap' config.model.representation_size = 1024 # Heteroscedastic config.het = ml_collections.ConfigDict() config.het.multiclass = False config.het.temperature = 1.5 config.het.mc_samples = 1000 config.het.num_factors = 50 config.het.param_efficient = True # Gaussian process layer section config.gp_layer = ml_collections.ConfigDict() # Use momentum-based (i.e., non-exact) covariance update for pre-training. # This is because the exact covariance update can be unstable for pretraining, # since it involves inverting a precision matrix accumulated over 300M data. config.gp_layer.covmat_momentum = .999 config.gp_layer.ridge_penalty = 1. # No need to use mean field adjustment for pretraining. config.gp_layer.mean_field_factor = -1. # Optimizer section config.optim_name = 'Adam' config.optim = ml_collections.ConfigDict() config.optim.weight_decay = 0.03 config.grad_clip_norm = 1.0 config.optim.beta1 = 0.9 config.optim.beta2 = 0.999 # TODO(lbeyer): make a mini-language like preprocessings. config.lr = ml_collections.ConfigDict() # LR has to be lower for GP layer and on larger models. config.lr.base = 6e-4 # LR has to be lower for larger models! config.lr.warmup_steps = 10_000 config.lr.decay_type = 'linear' config.lr.linear_end = 1e-5 # Few-shot eval section config.fewshot = common_fewshot.get_fewshot() config.fewshot.log_steps = 50_000 return config
def get_config(): """Config.""" config = ml_collections.ConfigDict() config.seed = 0 config.dataset = 'jft/entity:1.0.0' config.val_split = 'test[:49511]' # aka tiny_test/test[:5%] in task_adapt config.train_split = 'train' # task_adapt used train+validation so +64167 config.num_classes = 18291 config.init_head_bias = -10.0 config.resume = '/cns/tp-d/home/trandustin/baselines-jft-0211_032549/1/checkpoint.npz' config.trial = 0 config.batch_size = 4096 config.num_epochs = 14 config.prefetch_to_device = 2 # TODO(trandustin): To resume properly, I removed this setting. Not sure what # it's doing. # config.disable_preemption_reproducibility = True pp_common = '|value_range(-1, 1)' pp_common += f'|onehot({config.num_classes})' pp_common += '|keep(["image", "labels"])' config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. config.log_training_steps = 50 config.log_eval_steps = 1000 config.checkpoint_steps = 5000 config.checkpoint_timeout = 10 # Model section config.model = ml_collections.ConfigDict() config.model.patches = ml_collections.ConfigDict() config.model.patches.size = [14, 14] config.model.hidden_size = 1280 config.model.transformer = ml_collections.ConfigDict() config.model.transformer.attention_dropout_rate = 0. config.model.transformer.dropout_rate = 0. config.model.transformer.mlp_dim = 5120 config.model.transformer.num_heads = 16 config.model.transformer.num_layers = 32 config.model.classifier = 'token' # Or 'gap' config.model.representation_size = 1280 # BatchEnsemble section # Using last n=5 layers was chosen somewhat arbitrarily, >3 from L/32. config.model.transformer.be_layers = (27, 28, 29, 30, 31) config.model.transformer.ens_size = 3 config.model.transformer.random_sign_init = 0.5 config.fast_weight_lr_multiplier = 1.0 # Optimizer section # We use Adam HP to lower memory. config.optim_name = 'adam_hp' config.optim = ml_collections.ConfigDict() config.optim.beta1 = 0.9 config.optim.beta2 = 0.999 config.weight_decay = 0.1 config.grad_clip_norm = 10.0 config.lr = ml_collections.ConfigDict() # Note original ViT-H/14 uses 4e-4 and no grad clip norm until 130K steps, # then 3e-4 and grad_clip_norm=10.0 after. config.lr.base = 3e-4 # LR has to be lower for larger models! config.lr.warmup_steps = 10_000 config.lr.decay_type = 'linear' config.lr.linear_end = 1e-5 # Few-shot eval section config.fewshot = common_fewshot.get_fewshot() config.fewshot.log_steps = 25_000 return config
def get_config(): """Config for training on JFT300M. Batch size 4096 fits on DF4x4.""" config = ml_collections.ConfigDict() config.seed = 0 # JFT parameters. config.dataset = 'jft/entity:1.0.0' config.val_split = 'test[:49511]' # aka tiny_test/test[:5%] in task_adapt config.train_split = 'train' # task_adapt used train+validation so +64167 config.num_classes = 18291 config.init_head_bias = -10.0 # ~= ln(1/18k) ~= ln(1/num_classes) pp_common = '|value_range(-1, 1)' pp_common += f'|onehot({config.num_classes})' # To use ancestor 'smearing', use this line instead: # pp_common += f'|onehot({config.num_classes}, key='labels_extended', key_result='labels') # pylint: disable=line-too-long pp_common += '|keep(["image", "labels"])' config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok. # Model section config.model = ml_collections.ConfigDict() config.model.patches = ml_collections.ConfigDict() config.model.patches.size = [32, 32] config.model.hidden_size = 512 config.model.representation_size = 512 config.model.classifier = 'token' config.model.transformer = ml_collections.ConfigDict() config.model.transformer.num_layers = 8 config.model.transformer.dropout_rate = 0.0 config.model.transformer.mlp_dim = 2048 config.model.transformer.num_heads = 8 config.model.transformer.attention_dropout_rate = 0.0 # BatchEnsemble parameters config.model.transformer.be_layers = (5, 6, 7) config.model.transformer.ens_size = 2 config.model.transformer.random_sign_init = -0.5 config.fast_weight_lr_multiplier = 1.0 # Optimizer parameters. config.optim_name = 'Adam' config.optim = ml_collections.ConfigDict(dict(beta1=0.9, beta2=0.999)) config.weight_decay = 0.1 # TODO(trandustin): Potentially add weight decay only on slow weights, similar # to original BE and ghassen's BE ViT code. # config.weight_decay = [0.1] # config.weight_decay_pattern = [".*/kernel"] # Does not decay fast-weights. config.grad_clip_norm = None config.lr = ml_collections.ConfigDict() config.lr.base = 1e-3 # LR likely has to be lower for larger models! config.lr.warmup_steps = 10_000 config.lr.decay_type = 'linear' config.lr.linear_end = 1e-5 config.batch_size = 4096 # Global batch size. config.num_epochs = 5 config.log_training_steps = 50 config.log_eval_steps = 1000 config.checkpoint_steps = 5000 config.checkpoint_timeout = 10 config.prefetch_to_device = 2 config.trial = 0 # Few-shot eval section config.fewshot = common_fewshot.get_fewshot() config.fewshot.log_steps = 25_000 return config