def get_config():
  """Return config object for training."""
  config = base_config.get_base_config()
  config.random_seed = 0
  images_per_epoch = 1281167
  train_batch_size = 2048
  num_epochs = 300
  steps_per_epoch = images_per_epoch / train_batch_size
  config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size)
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(config=dict(
          lr=1e-3,
          num_epochs=num_epochs,
          image_size=224,
          num_classes=1000,
          which_dataset='imagenet',
          loss='softmax_cross_entropy',
          transpose=True,
          dtype=jnp.bfloat16,
          lr_schedule=dict(
              name='cosine_decay_schedule',
              kwargs=dict(init_value=1e-3, decay_steps=config.training_steps)),
          optimizer_weights=dict(
              name='adamw',
              kwargs=dict(b1=0.9, b2=0.999, weight_decay=0.05)),
          optimizer_biases=dict(name='adam', kwargs=dict(b1=0.9, b2=0.999)),
          model=dict(
              name='BoTNet',
              config_kwargs=dict(stage_sizes=[3, 4, 6, 6], dtype=jnp.bfloat16)),
          augment_name='cutmix_mixup_randaugment_405')))
  return config

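# Hedged sanity check of the step arithmetic above (an illustrative helper,
# not part of the original config): 1,281,167 ImageNet training images at a
# global batch size of 2048 give ~625.6 steps per epoch, and 300 epochs give
# 187,670 total training steps, which is the value `decay_steps` receives.
def _check_botnet_step_arithmetic():
  images_per_epoch = 1281167
  train_batch_size = 2048
  num_epochs = 300
  steps_per_epoch = images_per_epoch / train_batch_size  # Float, ~625.57.
  training_steps = (images_per_epoch * num_epochs) // train_batch_size
  assert round(steps_per_epoch, 2) == 625.57
  assert training_steps == 187670
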
def test_best_checkpoint_saves_only_at_improved_best_metrics(self):
  config = base_config.get_base_config()
  config.best_model_eval_metric = _FITNESS_METRIC_KEY
  config.training_steps = 100
  ckpt = DummyCheckpoint()
  writer = mock.Mock()
  train.evaluate(DummyExperiment, config, ckpt, writer, jaxline_mode="eval")
  # The first step will always checkpoint.
  self.assertLen(ckpt._state_list, len(_IMPROVEMENT_STEPS) + 1)
  checkpointed_states = [s.global_step for s in ckpt._state_list]
  self.assertEqual(checkpointed_states, [0] + _IMPROVEMENT_STEPS)

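# The test above relies on module-level fixtures that are not shown here
# (DummyExperiment, DummyCheckpoint, and the two constants). As a rough
# sketch of the intent only -- these values are assumptions, not the real
# jaxline test fixtures -- the constants could look like:
_FITNESS_METRIC_KEY = 'fitness'          # Metric monitored for best-checkpointing.
_IMPROVEMENT_STEPS = [2, 5, 15, 27, 99]  # Steps where the metric improves, so a
                                         # checkpoint is expected at step 0 plus
                                         # each of these steps.
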
def get_config(debug: bool = False) -> config_dict.ConfigDict:
  """Get Jaxline experiment config."""
  config = base_config.get_base_config()
  # E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below)
  config.restore_path = config_dict.placeholder(str)
  training_batch_size = 64
  eval_batch_size = 64

  ## Experiment config.
  loss_config_name = 'RegressionLossConfig'
  loss_kwargs = dict(
      exponent=1.,  # 2 for l2 loss, 1 for l1 loss, etc...
  )
  dataset_config = dict(
      data_root=config_dict.placeholder(str),
      augment_with_random_mirror_symmetry=True,
      k_fold_split_id=config_dict.placeholder(int),
      num_k_fold_splits=config_dict.placeholder(int),
      # Options: "in" or "out".
      # Filter=in would keep the samples with nans in the conformer features.
      # Filter=out would keep the samples with no NaNs anywhere in the conformer
      # features.
      filter_in_or_out_samples_with_nans_in_conformers=(
          config_dict.placeholder(str)),
      cached_conformers_file=config_dict.placeholder(str))
  model_config = dict(
      mlp_hidden_size=512,
      mlp_layers=2,
      latent_size=512,
      use_layer_norm=False,
      num_message_passing_steps=32,
      shared_message_passing_weights=False,
      mask_padding_graph_at_every_step=True,
      loss_config_name=loss_config_name,
      loss_kwargs=loss_kwargs,
      processor_mode='resnet',
      global_reducer='sum',
      node_reducer='sum',
      dropedge_rate=0.1,
      dropnode_rate=0.1,
      aux_multiplier=0.1,
      add_relative_distance=True,
      add_relative_displacement=True,
      add_absolute_positions=False,
      position_normalization=2.,
      relative_displacement_normalization=1.,
      ignore_globals=False,
      ignore_globals_from_final_layer_for_predictions=True,
  )
  if debug:
    # Make network smaller.
    model_config.update(dict(
        mlp_hidden_size=32,
        mlp_layers=1,
        latent_size=32,
        num_message_passing_steps=1))

  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              debug=debug,
              predictions_dir=config_dict.placeholder(str),
              ema=True,
              ema_decay=0.9999,
              sample_random=0.05,
              optimizer=dict(
                  name='adam',
                  optimizer_kwargs=dict(b1=.9, b2=.95),
                  lr_schedule=dict(
                      warmup_steps=int(5e4),
                      decay_steps=int(5e5),
                      init_value=1e-5,
                      peak_value=1e-4,
                      end_value=0.,
                  ),
              ),
              model=model_config,
              dataset_config=dataset_config,
              # As a rule of thumb, use the following statistics:
              # Avg. # nodes in graph: 16.
              # Avg. # edges in graph: 40.
              training=dict(
                  dynamic_batch_size={
                      'n_node': 256 if debug else 16 * training_batch_size,
                      'n_edge': 512 if debug else 40 * training_batch_size,
                      'n_graph': 2 if debug else training_batch_size,
                  },),
              evaluation=dict(
                  split='valid',
                  dynamic_batch_size=dict(
                      n_node=256 if debug else 16 * eval_batch_size,
                      n_edge=512 if debug else 40 * eval_batch_size,
                      n_graph=2 if debug else eval_batch_size,
                  )))))

  ## Training loop config.
  config.training_steps = int(5e6)
  config.checkpoint_dir = '/tmp/checkpoint/pcq/'
  config.train_checkpoint_all_hosts = False
  config.save_checkpoint_interval = 300
  config.log_train_data_interval = 60
  config.log_tensors_interval = 60
  config.best_model_eval_metric = 'mae'
  config.best_model_eval_metric_higher_is_better = False

  return config

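# Example of filling in the placeholder fields above before running. This is a
# sketch: the paths, the fold count and the filter choice are hypothetical
# values, not defaults from the original project. Fields created with
# config_dict.placeholder() simply start out as None and can be assigned later.
def _example_pcq_config():
  config = get_config(debug=True)
  config.restore_path = '/data/pretrained_models/k0_seed100'
  exp = config.experiment_kwargs.config
  exp.dataset_config.data_root = '/data/pcq/'                    # Hypothetical.
  exp.dataset_config.k_fold_split_id = 0
  exp.dataset_config.num_k_fold_splits = 10                      # Hypothetical.
  exp.dataset_config.filter_in_or_out_samples_with_nans_in_conformers = 'in'
  exp.dataset_config.cached_conformers_file = '/data/pcq/conformers.pkl'  # Hypothetical.
  return config
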
def get_config(arg_string):
  """Return config object for training."""
  args = arg_string.split(",")
  if len(args) != 3:
    raise ValueError(
        "You must provide exactly three arguments separated by a "
        "comma - model_config_name,sweep_index,dataset_name.")
  model_config_name, sweep_index, dataset_name = args
  sweep_index = int(sweep_index)
  config = base_config.get_base_config()
  config.random_seed = 123109801
  config.eval_modes = ("eval", "eval_metric")

  # Get the model config and the sweeps.
  if model_config_name not in globals():
    raise ValueError(
        f"The config name {model_config_name} does not exist in "
        f"jaxline_configs.py")
  config_and_sweep_fn = globals()[model_config_name]
  model_config, sweeps = config_and_sweep_fn()

  if not os.environ.get(_DATASETS_PATH_VAR_NAME, None):
    raise ValueError(
        f"You need to set the {_DATASETS_PATH_VAR_NAME} environment variable.")
  dm_hamiltonian_suite_path = os.environ[_DATASETS_PATH_VAR_NAME]
  dataset_folder = os.path.join(dm_hamiltonian_suite_path, dataset_name)

  # Experiment config. Note that batch_size is per device.
  # In the experiments we ran on 4 GPUs, so the effective batch size was 128.
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(config=dict(
          dataset_folder=dataset_folder,
          model_kwargs=model_config,
          num_extrapolation_steps=60,
          drop_stats_containing=("neg_log_p_x", "l2_over_time", "neg_elbo"),
          optimizer=dict(
              name="adam",
              kwargs=dict(
                  learning_rate=1.5e-4,
                  b1=0.9,
                  b2=0.999,
              )),
          training=dict(
              batch_size=32,
              burnin_steps=5,
              num_epochs=None,
              lagging_vae=False),
          evaluation=dict(batch_size=64,),
          evaluation_metric=dict(
              batch_size=5,
              batch_n=20,
              num_eval_metric_steps=60,
              max_poly_order=5,
              max_jacobian_score=1000,
              rsq_threshold=0.9,
              sym_threshold=0.05,
              evaluation_point_n=10,
              weight_tolerance=1e-03,
              max_iter=1000,
              cv=2,
              alpha_min_logspace=-4,
              alpha_max_logspace=-0.5,
              alpha_step_n=10,
              calculate_fully_after_steps=40000,
          ),
          evaluation_metric_mlp=dict(
              batch_size=64,
              batch_n=10000,
              datapoint_param_multiplier=1000,
              num_eval_metric_steps=60,
              evaluation_point_n=10,
              evaluation_trajectory_n=50,
              rsq_threshold=0.9,
              sym_threshold=0.05,
              ridge_lambda=0.01,
              model=dict(
                  num_units=4,
                  num_layers=4,
                  activation="tanh",
              ),
              optimizer=dict(name="adam", kwargs=dict(learning_rate=1.5e-3,)),
          ),
          evaluation_vpt=dict(
              batch_size=5,
              batch_n=2,
              vpt_threshold=0.025,
          ))))

  # Training loop config.
  config.training_steps = int(500000)
  config.interval_type = "steps"
  config.log_tensors_interval = 50
  config.log_train_data_interval = 50
  config.log_all_train_data = False
  config.save_checkpoint_interval = 100
  config.checkpoint_dir = "/tmp/physics_inspired_models/"
  config.train_checkpoint_all_hosts = False
  config.eval_specific_checkpoint_dir = ""

  config.update_from_flattened_dict(sweeps[sweep_index])

  return config

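# Example call (a sketch only; the config name, dataset name and environment
# variable value below are hypothetical -- the real model config names live in
# jaxline_configs.py and the dataset names come from the Hamiltonian dynamics
# suite on disk):
def _example_physics_config():
  os.environ.setdefault(
      _DATASETS_PATH_VAR_NAME, '/data/dm_hamiltonian_dynamics_suite')
  # Format: "model_config_name,sweep_index,dataset_name".
  return get_config('my_model_config,0,mass_spring')
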
def get_config(): """Return config object for training.""" config = base_config.get_base_config() # Experiment config. train_batch_size = 1024 # Global batch size. images_per_epoch = 1281167 num_epochs = 90 steps_per_epoch = images_per_epoch / train_batch_size config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size) config.random_seed = 0 config.experiment_kwargs = config_dict.ConfigDict( dict( config=dict( lr=0.1, num_epochs=num_epochs, label_smoothing=0.1, model='ResNet', image_size=224, use_ema=False, ema_decay=0.9999, # Quatros nuevos amigos ema_start=0, which_ema='tf1_ema', augment_name=None, # 'mixup_cutmix', augment_before_mix=True, eval_preproc='crop_resize', train_batch_size=train_batch_size, eval_batch_size=50, eval_subset='test', num_classes=1000, which_dataset='imagenet', fake_data=False, which_loss='softmax_cross_entropy', # For now, must be softmax transpose=True, # Use the double-transpose trick? bfloat16=False, lr_schedule=dict( name='WarmupCosineDecay', kwargs=dict(num_steps=config.training_steps, start_val=0, min_val=0, warmup_steps=5 * steps_per_epoch), ), lr_scale_by_bs=True, optimizer=dict( name='SGD', kwargs={ 'momentum': 0.9, 'nesterov': True, 'weight_decay': 1e-4, }, ), model_kwargs=dict( width=4, which_norm='BatchNorm', norm_kwargs=dict( create_scale=True, create_offset=True, decay_rate=0.9, ), # cross_replica_axis='i'), variant='ResNet50', activation='relu', drop_rate=0.0, ), ), )) # Training loop config: log and checkpoint every minute config.log_train_data_interval = 60 config.log_tensors_interval = 60 config.save_checkpoint_interval = 60 config.eval_specific_checkpoint_dir = '' return config
def get_config(debug: bool = False) -> config_dict.ConfigDict:
  """Get Jaxline experiment config."""
  config = base_config.get_base_config()
  config.random_seed = 42
  # E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below)
  config.restore_path = config_dict.placeholder(str)
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(config=dict(
          debug=debug,
          predictions_dir=config_dict.placeholder(str),
          # 5 for model selection and early stopping, 50 for final eval.
          num_eval_iterations_to_ensemble=5,
          dataset_kwargs=dict(
              data_root='/data/',
              online_subsampling_kwargs=dict(
                  max_nb_neighbours_per_type=[
                      [[40, 20, 0, 40], [0, 0, 0, 0], [0, 0, 0, 0]],
                      [[40, 20, 0, 40], [40, 0, 10, 0], [0, 0, 0, 0]],
                  ],
                  remove_future_nodes=True,
                  deduplicate_nodes=True,
              ),
              ratio_unlabeled_data_to_labeled_data=10.0,
              k_fold_split_id=config_dict.placeholder(int),
              use_all_labels_when_not_training=False,
              use_dummy_adjacencies=debug,
          ),
          optimizer=dict(
              name='adamw',
              kwargs=dict(weight_decay=1e-5, b1=0.9, b2=0.999),
              learning_rate_schedule=dict(
                  use_schedule=True,
                  base_learning_rate=1e-2,
                  warmup_steps=50000,
                  total_steps=config.get_ref('training_steps'),
              ),
          ),
          model_config=dict(
              mlp_hidden_sizes=[32] if debug else [512],
              latent_size=32 if debug else 256,
              num_message_passing_steps=2 if debug else 4,
              activation='relu',
              dropout_rate=0.3,
              dropedge_rate=0.25,
              disable_edge_updates=True,
              use_sent_edges=True,
              normalization_type='layer_norm',
              aggregation_function='sum',
          ),
          training=dict(
              loss_config=dict(
                  bgrl_loss_config=dict(
                      stop_gradient_for_supervised_loss=False,
                      bgrl_loss_scale=1.0,
                      symmetrize=True,
                      first_graph_corruption_config=dict(
                          feature_drop_prob=0.4,
                          edge_drop_prob=0.2,
                      ),
                      second_graph_corruption_config=dict(
                          feature_drop_prob=0.4,
                          edge_drop_prob=0.2,
                      ),
                  ),
              ),
              # GPU memory may require reducing the `256`s below to `48`.
              dynamic_batch_size_config=dict(
                  n_node=256 if debug else 340 * 256,
                  n_edge=512 if debug else 720 * 256,
                  n_graph=4 if debug else 256,
              ),
          ),
          eval=dict(
              split='valid',
              ema_annealing_schedule=dict(
                  use_schedule=True,
                  base_rate=0.999,
                  total_steps=config.get_ref('training_steps')),
              dynamic_batch_size_config=dict(
                  n_node=256 if debug else 340 * 128,
                  n_edge=512 if debug else 720 * 128,
                  n_graph=4 if debug else 128,
              ),
          ))))

  ## Training loop config.
  config.training_steps = 500000
  config.checkpoint_dir = '/tmp/checkpoint/mag/'
  config.train_checkpoint_all_hosts = False
  config.log_train_data_interval = 10
  config.log_tensors_interval = 10
  config.save_checkpoint_interval = 30
  config.best_model_eval_metric = 'accuracy'
  config.best_model_eval_metric_higher_is_better = True

  return config

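# Note on config.get_ref('training_steps') above: it returns an ml_collections
# FieldReference, so `total_steps` stays tied to config.training_steps even
# though that value is only assigned after experiment_kwargs is built. A
# minimal standalone illustration of that behaviour (a sketch, not part of the
# original config):
def _demo_get_ref():
  from ml_collections import config_dict
  cfg = config_dict.ConfigDict(dict(training_steps=1))
  schedule = config_dict.ConfigDict(
      dict(total_steps=cfg.get_ref('training_steps')))
  cfg.training_steps = 500000
  # The reference resolves to the latest value on access.
  assert schedule.total_steps == 500000
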
def get_config(): """Return config object for training.""" use_debug_settings = IS_LOCAL config = base_config.get_base_config() # Experiment config. local_batch_size = 2 # Modify this to adapt to your custom distributed learning setup num_devices = 1 config.train_batch_size = local_batch_size * num_devices config.n_epochs = 110 def _default_or_debug(default_value, debug_value): return debug_value if use_debug_settings else default_value n_train_examples = N_TRAIN_EXAMPLES num_classes = N_CLASSES config.experiment_kwargs = config_dict.ConfigDict( dict(config=dict( optimizer=dict( base_lr=5e-4, max_norm=10.0, # < 0 to turn off. schedule_type='constant_cosine', weight_decay=1e-1, decay_pos_embs=True, scale_by_batch=True, cosine_decay_kwargs=dict( init_value=0.0, warmup_epochs=0, end_value=0.0, ), step_decay_kwargs=dict( decay_boundaries=[0.5, 0.8, 0.95], decay_rate=0.1, ), constant_cosine_decay_kwargs=dict( constant_fraction=0.5, end_value=0.0, ), optimizer='lamb', # Optimizer-specific kwargs: adam_kwargs=dict( b1=0.9, b2=0.999, eps=1e-8, ), lamb_kwargs=dict( b1=0.9, b2=0.999, eps=1e-6, ), ), # Don't specify output_channels - it's not used for # classifiers. model=dict( perceiver_kwargs=dict( input_preprocessor=dict( prep_type='pixels', # Channels for conv/conv1x1 preprocessing: num_channels=64, # ------------------------- # Position encoding arguments: # ------------------------- position_encoding_type='fourier', concat_or_add_pos='concat', spatial_downsample=1, # If >0, project position to this size: project_pos_dim=-1, trainable_position_encoding_kwargs=dict( num_channels=258, # Match default # for Fourier. init_scale=0.02, ), fourier_position_encoding_kwargs=dict( num_bands=64, max_resolution=(224, 224), sine_only=False, concat_pos=True, ), ), encoder=dict( num_self_attends_per_block=_default_or_debug(6, 2), # Weights won't be shared if num_blocks is set to 1. num_blocks=_default_or_debug(8, 2), z_index_dim=512, num_z_channels=1024, num_cross_attend_heads=1, num_self_attend_heads=8, cross_attend_widening_factor=1, self_attend_widening_factor=1, dropout_prob=0.0, # Position encoding for the latent array. z_pos_enc_init_scale=0.02, cross_attention_shape_for_attn='kv', use_query_residual=True, ), decoder=dict( num_z_channels=1024, use_query_residual=True, # Position encoding for the output logits. position_encoding_type='trainable', trainable_position_encoding_kwargs=dict( num_channels=1024, init_scale=0.02, ), ), ), ), training=dict(images_per_epoch=n_train_examples, label_smoothing=0.1, n_epochs=config.get_oneway_ref('n_epochs'), batch_size=config.get_oneway_ref( 'train_batch_size')), data=dict( num_classes=num_classes, # Run on smaller images to debug. im_dim=_default_or_debug(224, 32), augmentation=dict( # Typical randaug params: # num_layers in [1, 3] # magnitude in [5, 30] # Set randaugment to None to disable. randaugment=dict(num_layers=4, magnitude=5), cutmix=True, # Mixup alpha should be in [0, 1]. # Set to None to disable. mixup_alpha=0.2, ), ), evaluation=dict( subset='test', batch_size=2, ), ))) # Training loop config. 
config.training_steps = get_training_steps( config.get_oneway_ref('train_batch_size'), config.get_oneway_ref('n_epochs')) config.log_train_data_interval = 60 config.log_tensors_interval = 60 config.save_checkpoint_interval = 300 config.eval_specific_checkpoint_dir = '' config.best_model_eval_metric = 'eval_top_1_acc' config.checkpoint_dir = '/tmp/perceiver_imagnet_checkpoints' config.train_checkpoint_all_hosts = False # Prevents accidentally setting keys that aren't recognized (e.g. in tests). config.lock() return config
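# `get_training_steps` is defined elsewhere in the Perceiver experiment module.
# A plausible sketch of it, assuming it simply converts epochs into optimiser
# steps from the global number of training examples and the (one-way reference)
# batch size -- this is an assumption, not the verified original:
def get_training_steps(batch_size, n_epochs):
  return (N_TRAIN_EXAMPLES * n_epochs) // batch_size
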
def get_config(): """Return config object for training.""" config = base_config.get_base_config() # Batch size, training steps and data. num_classes = 10 num_epochs = 400 # Gowal et al. (2020) and Rebuffi et al. (2021) use 1024 as batch size. # Reducing this batch size may require further adjustments to the batch # normalization decay or the learning rate. If you have to use a batch size # of 256, reduce the number of emulated workers to 1 (it should match the # results of using a batch size of 1024 with 4 workers). train_batch_size = 1024 def steps_from_epochs(n): return max(int(n * 50_000 / train_batch_size), 1) num_steps = steps_from_epochs(num_epochs) test_batch_size = train_batch_size # Specify the path to the downloaded data. You can download data from # https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness. # If the path is set to "cifar10_ddpm.npz" and is not found in the current # directory, the corresponding data will be downloaded. extra_npz = 'cifar10_ddpm.npz' # Learning rate. learning_rate = .1 * max(train_batch_size / 256, 1.) learning_rate_warmup = steps_from_epochs(10) learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps, learning_rate_warmup) # Model definition. model_ctor = model_zoo.WideResNet model_kwargs = dict(num_classes=num_classes, depth=28, width=10, activation='swish') # Attack used during training (can be None). epsilon = 8 / 255 train_attack = attacks.UntargetedAttack( attacks.PGD(attacks.Adam( optax.piecewise_constant_schedule(init_value=.1, boundaries_and_scales={5: .1})), num_steps=10, initialize_fn=attacks.linf_initialize_fn(epsilon), project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))), loss_fn=attacks.untargeted_kl_divergence) # Attack used during evaluation (can be None). eval_attack = attacks.UntargetedAttack(attacks.PGD( attacks.Adam(learning_rate_fn=optax.piecewise_constant_schedule( init_value=.1, boundaries_and_scales={ 20: .1, 30: .01 })), num_steps=40, initialize_fn=attacks.linf_initialize_fn(epsilon), project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))), loss_fn=attacks.untargeted_margin) config.experiment_kwargs = config_dict.ConfigDict( dict(config=dict( epsilon=epsilon, num_classes=num_classes, # Results from various publications use 4 worker machines, which results # in slight differences when using less worker machines. To compensate for # such discrepancies, we emulate these additional workers. Set to zero, # when using more than 4 workers. emulated_workers=4, dry_run=False, save_final_checkpoint_as_npy=True, model=dict(constructor=model_ctor, kwargs=model_kwargs), training=dict(batch_size=train_batch_size, learning_rate=learning_rate_fn, weight_decay=5e-4, swa_decay=.995, use_cutmix=False, supervised_batch_ratio=.3, extra_data_path=extra_npz, extra_label_smoothing=.1, attack=train_attack), evaluation=dict( # If `interval` is positive, synchronously evaluate at regular # intervals. Setting it to zero will not evaluate while training, # unless `--jaxline_mode` is set to `train_eval_multithreaded`, which # asynchronously evaluates checkpoints. 
interval=steps_from_epochs(40), batch_size=test_batch_size, attack=eval_attack), ))) config.checkpoint_dir = '/tmp/jaxline/robust' config.train_checkpoint_all_hosts = False config.training_steps = num_steps config.interval_type = 'steps' config.log_train_data_interval = steps_from_epochs(.5) config.log_tensors_interval = steps_from_epochs(.5) config.save_checkpoint_interval = steps_from_epochs(40) config.eval_specific_checkpoint_dir = '' return config
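# Hedged sanity check of the interval arithmetic above (an illustrative helper,
# not part of the original config): with CIFAR-10's 50,000 training images and
# a batch size of 1024, 400 epochs is 19,531 steps, logging every half epoch
# happens every 24 steps, and checkpoints/evaluations every 40 epochs land
# every 1,953 steps.
def _check_robust_training_intervals():
  train_batch_size = 1024
  def steps_from_epochs(n):
    return max(int(n * 50_000 / train_batch_size), 1)
  assert steps_from_epochs(400) == 19531
  assert steps_from_epochs(.5) == 24
  assert steps_from_epochs(40) == 1953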