import jax
import jax.numpy as jnp

from nfnets import experiment


def test_experiment():
  """Tests the main experiment."""
  config = experiment.get_config()
  exp_config = config.experiment_kwargs.config
  # Shrink the batch sizes and model width, and switch to fake data, so the
  # test runs quickly on a single host without a dataset on disk.
  exp_config.train_batch_size = 2
  exp_config.eval_batch_size = 2
  exp_config.lr = 0.1
  exp_config.fake_data = True
  exp_config.model_kwargs.width = 2
  print(exp_config.model_kwargs)
  xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
  # Replicate the step counter and RNG key across all local devices, as the
  # pmapped step function expects one copy per device.
  bcast = jax.pmap(lambda x: x)
  global_step = bcast(jnp.zeros(jax.local_device_count()))
  rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
  print('Taking a single experiment step for test purposes!')
  result = xp.step(global_step, rng)
  print(f'Step successfully taken, resulting metrics are {result}')
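# A self-contained sketch of the broadcast pattern used in the test above,
# assuming only jax and jax.numpy. `jax.pmap` over the identity function
# shards an array whose leading axis equals the local device count, one
# slice per device, which is how the test replicates the global step and
# the RNG key.
import jax
import jax.numpy as jnp


def broadcast_to_devices(x):
  """Replicates `x` once per local device along a new leading axis."""
  n = jax.local_device_count()
  return jax.pmap(lambda y: y)(jnp.stack([x] * n))


keys = broadcast_to_devices(jax.random.PRNGKey(0))
assert keys.shape == (jax.local_device_count(), 2)  # A PRNGKey is two uint32s.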
from ml_collections import config_dict

from nfnets import experiment


def get_config():
  """Return config object for training."""
  config = experiment.get_config()

  # Experiment config.
  train_batch_size = 4096  # Global batch size.
  images_per_epoch = 1281167
  num_epochs = 360
  steps_per_epoch = images_per_epoch / train_batch_size
  config.training_steps = ((images_per_epoch * num_epochs)
                           // train_batch_size)
  config.random_seed = 0
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              lr=0.1,
              num_epochs=num_epochs,
              label_smoothing=0.1,
              model='NFNet',
              image_size=224,
              use_ema=True,
              ema_decay=0.99999,
              ema_start=0,
              augment_name=None,
              augment_before_mix=False,
              eval_preproc='resize_crop_32',
              train_batch_size=train_batch_size,
              eval_batch_size=50,
              eval_subset='test',
              num_classes=1000,
              which_dataset='imagenet',
              which_loss='softmax_cross_entropy',  # One of softmax or sigmoid.
              bfloat16=True,
              lr_schedule=dict(
                  name='WarmupCosineDecay',
                  kwargs=dict(num_steps=config.training_steps,
                              start_val=0,
                              min_val=0.0,
                              warmup_steps=5 * steps_per_epoch),
              ),
              lr_scale_by_bs=True,
              optimizer=dict(
                  name='SGD_AGC',
                  kwargs={'momentum': 0.9, 'nesterov': True,
                          'weight_decay': 2e-5,
                          'clipping': 0.01, 'eps': 1e-3},
              ),
              model_kwargs=dict(
                  variant='F0',
                  width=1.0,
                  se_ratio=0.5,
                  alpha=0.2,
                  stochdepth_rate=0.25,
                  drop_rate=None,  # Use the variant's native drop rate.
                  activation='gelu',
                  final_conv_mult=2,
                  final_conv_ch=None,
                  use_two_convs=True,
              ),
          )))

  # Unlike NF-RegNets, use the same weight decay for all variants, but vary
  # the RandAugment level.
  variant = config.experiment_kwargs.config.model_kwargs.variant
  # RandAugment levels (e.g. '405' = 4 layers, magnitude 5; '205' = 2 layers,
  # magnitude 5).
  augment = {'F0': '405', 'F1': '410', 'F2': '410', 'F3': '415',
             'F4': '415', 'F5': '415', 'F6': '415', 'F7': '415'}[variant]
  aug_base_name = 'cutmix_mixup_randaugment'
  config.experiment_kwargs.config.augment_name = f'{aug_base_name}_{augment}'
  return config
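# The RandAugment level strings above pack two values into three digits: the
# first digit is the number of layers and the remaining digits are the
# magnitude, so '405' means 4 layers at magnitude 5 and '410' means 4 layers
# at magnitude 10. A hypothetical decoder for illustration only; the real
# parsing lives in the augmentation code, which is not shown here.
def parse_randaugment_level(level):
  """Splits e.g. '405' into (num_layers=4, magnitude=5)."""
  return int(level[0]), int(level[1:])


assert parse_randaugment_level('405') == (4, 5)
assert parse_randaugment_level('415') == (4, 15)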
from ml_collections import config_dict

from nfnets import experiment


def get_config():
  """Return config object for training."""
  config = experiment.get_config()

  # Experiment config.
  train_batch_size = 1024  # Global batch size.
  images_per_epoch = 1281167
  num_epochs = 360
  steps_per_epoch = images_per_epoch / train_batch_size
  config.training_steps = ((images_per_epoch * num_epochs)
                           // train_batch_size)
  config.random_seed = 0
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(config=dict(
          lr=0.4,
          num_epochs=num_epochs,
          label_smoothing=0.1,
          model='NF_RegNet',
          image_size=224,
          use_ema=True,
          ema_decay=0.99999,  # Five nines, friends.
          ema_start=0,
          augment_name='mixup_cutmix',
          train_batch_size=train_batch_size,
          eval_batch_size=50,
          eval_subset='test',
          num_classes=1000,
          which_dataset='imagenet',
          which_loss='softmax_cross_entropy',  # One of softmax or sigmoid.
          bfloat16=False,
          lr_schedule=dict(
              name='WarmupCosineDecay',
              kwargs=dict(num_steps=config.training_steps,
                          start_val=0,
                          min_val=0.001,
                          warmup_steps=5 * steps_per_epoch),
          ),
          lr_scale_by_bs=False,
          optimizer=dict(
              name='SGD',
              kwargs={'momentum': 0.9, 'nesterov': True,
                      'weight_decay': 5e-5},
          ),
          model_kwargs=dict(
              variant='B0',
              width=0.75,
              expansion=2.25,
              se_ratio=0.5,
              alpha=0.2,
              stochdepth_rate=0.1,
              drop_rate=None,
              activation='silu',
          ),
      )))

  # Set weight decay based on variant (scaled as 5e-5 + 1e-5 * level).
  variant = config.experiment_kwargs.config.model_kwargs.variant
  weight_decay = {'B0': 5e-5, 'B1': 6e-5, 'B2': 7e-5,
                  'B3': 8e-5, 'B4': 9e-5, 'B5': 1e-4}[variant]
  config.experiment_kwargs.config.optimizer.kwargs.weight_decay = weight_decay
  return config
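# The lookup table above follows the rule in its comment: weight decay is
# 5e-5 plus 1e-5 per variant level. A hypothetical helper that derives the
# same values from the variant name, for illustration only:
def weight_decay_for_variant(variant):
  """Computes 5e-5 + 1e-5 * level for variants 'B0' through 'B5'."""
  return 5e-5 + 1e-5 * int(variant[1:])


assert abs(weight_decay_for_variant('B0') - 5e-5) < 1e-12
assert abs(weight_decay_for_variant('B5') - 1e-4) < 1e-12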
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Quick script to test that experiment can import and run."""
import jax
import jax.numpy as jnp

from nfnets import experiment

config = experiment.get_config()
exp_config = config.experiment_kwargs.config
exp_config.train_batch_size = 2
exp_config.eval_batch_size = 2
exp_config.lr = 0.1
exp_config.fake_data = True
exp_config.model_kwargs.width = 2
print(exp_config.model_kwargs)
xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
bcast = jax.pmap(lambda x: x)
global_step = bcast(jnp.zeros(jax.local_device_count()))
rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
print('Taking a single experiment step for test purposes!')
result = xp.step(global_step, rng)
print(f'Step successfully taken, resulting metrics are {result}')
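# The script above relies on `fake_data=True` so it can run without ImageNet
# on disk. A minimal sketch of what a fake batch could look like, assuming
# NHWC float images and integer class labels at the configured sizes; the
# actual input pipeline in `nfnets` may use different keys and dtypes.
import jax.numpy as jnp


def fake_batch(batch_size=2, image_size=224):
  """Returns a zeroed image/label batch with plausible shapes and dtypes."""
  images = jnp.zeros((batch_size, image_size, image_size, 3), jnp.float32)
  labels = jnp.zeros((batch_size,), jnp.int32)  # All class index 0.
  return {'images': images, 'labels': labels}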