Example #1
import jax
import jax.numpy as jnp

from nfnets import experiment


def test_experiment():
    """Tests the main experiment."""
    config = experiment.get_config()
    exp_config = config.experiment_kwargs.config
    exp_config.train_batch_size = 2
    exp_config.eval_batch_size = 2
    exp_config.lr = 0.1
    exp_config.fake_data = True
    exp_config.model_kwargs.width = 2
    print(exp_config.model_kwargs)

    xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
    bcast = jax.pmap(lambda x: x)
    global_step = bcast(jnp.zeros(jax.local_device_count()))
    rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
    print('Taking a single experiment step for test purposes!')
    result = xp.step(global_step, rng)
    print(f'Step successfully taken, resulting metrics are {result}')
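
A direct way to run the function above (a sketch; the original file may instead rely on pytest discovery):

if __name__ == '__main__':
    test_experiment()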
Example #2

from ml_collections import config_dict  # provides the ConfigDict used below

from nfnets import experiment  # assuming the base config lives here, as in the test above


def get_config():
  """Return config object for training."""
  config = experiment.get_config()

  # Experiment config.
  train_batch_size = 4096  # Global batch size.
  images_per_epoch = 1281167
  num_epochs = 360
  steps_per_epoch = images_per_epoch / train_batch_size
  config.training_steps = ((images_per_epoch * num_epochs) // train_batch_size)
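  # With these defaults: 1281167 * 360 // 4096 = 112602 total training steps;
  # steps_per_epoch is a float (~312.8), so warmup_steps below is ~1564.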
  config.random_seed = 0

  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              lr=0.1,
              num_epochs=num_epochs,
              label_smoothing=0.1,
              model='NFNet',
              image_size=224,
              use_ema=True,
              ema_decay=0.99999,
              ema_start=0,
              augment_name=None,
              augment_before_mix=False,
              eval_preproc='resize_crop_32',
              train_batch_size=train_batch_size,
              eval_batch_size=50,
              eval_subset='test',
              num_classes=1000,
              which_dataset='imagenet',
              which_loss='softmax_cross_entropy',  # One of softmax or sigmoid
              bfloat16=True,
              lr_schedule=dict(
                  name='WarmupCosineDecay',
                  kwargs=dict(num_steps=config.training_steps,
                              start_val=0,
                              min_val=0.0,
                              warmup_steps=5*steps_per_epoch),
                  ),
              lr_scale_by_bs=True,
              optimizer=dict(
                  name='SGD_AGC',
                  kwargs={'momentum': 0.9, 'nesterov': True,
                          'weight_decay': 2e-5,
                          'clipping': 0.01, 'eps': 1e-3},
              ),
              model_kwargs=dict(
                  variant='F0',
                  width=1.0,
                  se_ratio=0.5,
                  alpha=0.2,
                  stochdepth_rate=0.25,
                  drop_rate=None,  # Use native drop-rate
                  activation='gelu',
                  final_conv_mult=2,
                  final_conv_ch=None,
                  use_two_convs=True,
                  ),
              )))

  # Unlike NF-RegNets, use the same weight decay for all, but vary RA levels
  variant = config.experiment_kwargs.config.model_kwargs.variant
  # RandAugment levels (e.g. 405 = 4 layers, magnitude 5, 205 = 2 layers, mag 5)
  augment = {'F0': '405', 'F1': '410', 'F2': '410', 'F3': '415',
             'F4': '415', 'F5': '415', 'F6': '415', 'F7': '415'}[variant]
  aug_base_name = 'cutmix_mixup_randaugment'
  config.experiment_kwargs.config.augment_name = f'{aug_base_name}_{augment}'
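  # For the default 'F0' variant this resolves to 'cutmix_mixup_randaugment_405'.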

  return config
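
For illustration, a minimal sketch of poking at the returned config (assuming get_config above is importable; note that augment_name is resolved from the variant inside get_config, so changing the variant afterwards does not update it):

config = get_config()
config.experiment_kwargs.config.model_kwargs.variant = 'F1'
# Still reflects the default 'F0' choice made inside get_config().
print(config.experiment_kwargs.config.augment_name)  # cutmix_mixup_randaugment_405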
Example #3

from ml_collections import config_dict  # provides the ConfigDict used below

from nfnets import experiment  # assuming the base config lives here, as above


def get_config():
    """Return config object for training."""
    config = experiment.get_config()

    # Experiment config.
    train_batch_size = 1024  # Global batch size.
    images_per_epoch = 1281167
    num_epochs = 360
    steps_per_epoch = images_per_epoch / train_batch_size
    config.training_steps = ((images_per_epoch * num_epochs) //
                             train_batch_size)
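    # With these defaults: 1281167 * 360 // 1024 = 450410 total training steps;
    # steps_per_epoch is a float (~1251.1), so warmup_steps below is ~6256.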
    config.random_seed = 0

    config.experiment_kwargs = config_dict.ConfigDict(
        dict(config=dict(
            lr=0.4,
            num_epochs=num_epochs,
            label_smoothing=0.1,
            model='NF_RegNet',
            image_size=224,
            use_ema=True,
            ema_decay=0.99999,  # Five nines, friends
            ema_start=0,
            augment_name='mixup_cutmix',
            train_batch_size=train_batch_size,
            eval_batch_size=50,
            eval_subset='test',
            num_classes=1000,
            which_dataset='imagenet',
            which_loss='softmax_cross_entropy',  # One of softmax or sigmoid
            bfloat16=False,
            lr_schedule=dict(
                name='WarmupCosineDecay',
                kwargs=dict(num_steps=config.training_steps,
                            start_val=0,
                            min_val=0.001,
                            warmup_steps=5 * steps_per_epoch),
            ),
            lr_scale_by_bs=False,
            optimizer=dict(
                name='SGD',
                kwargs={
                    'momentum': 0.9,
                    'nesterov': True,
                    'weight_decay': 5e-5,
                },
            ),
            model_kwargs=dict(
                variant='B0',
                width=0.75,
                expansion=2.25,
                se_ratio=0.5,
                alpha=0.2,
                stochdepth_rate=0.1,
                drop_rate=None,
                activation='silu',
            ),
        )))

    # Set weight decay based on variant (scaled as 5e-5 + 1e-5 * level)
    variant = config.experiment_kwargs.config.model_kwargs.variant
    weight_decay = {
        'B0': 5e-5,
        'B1': 6e-5,
        'B2': 7e-5,
        'B3': 8e-5,
        'B4': 9e-5,
        'B5': 1e-4
    }[variant]
    config.experiment_kwargs.config.optimizer.kwargs.weight_decay = weight_decay

    return config
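
A short sketch along the same lines (assuming get_config above is importable), showing that the optimizer's weight decay tracks the variant set inside this function:

config = get_config()
opt_kwargs = config.experiment_kwargs.config.optimizer.kwargs
print(opt_kwargs.weight_decay)  # 5e-05 for the default 'B0' variant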
Example #4
"""Quick script to test that experiment can import and run."""
import jax
import jax.numpy as jnp
from nfnets import experiment

config = experiment.get_config()
exp_config = config.experiment_kwargs.config
exp_config.train_batch_size = 2
exp_config.eval_batch_size = 2
exp_config.lr = 0.1
exp_config.fake_data = True
exp_config.model_kwargs.width = 2
print(exp_config.model_kwargs)

xp = experiment.Experiment('train', exp_config, jax.random.PRNGKey(0))
bcast = jax.pmap(lambda x: x)
global_step = bcast(jnp.zeros(jax.local_device_count()))
rng = bcast(jnp.stack([jax.random.PRNGKey(0)] * jax.local_device_count()))
print('Taking a single experiment step for test purposes!')
result = xp.step(global_step, rng)
print(f'Step successfully taken, resulting metrics are {result}')
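
On a single-machine CPU run, the pmap calls above only see one device. A common workaround (a sketch, not part of the original script) is to ask XLA for several virtual host devices before jax is first imported:

import os
# Must be set before the first `import jax` anywhere in the process.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
print(jax.local_device_count())  # 8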