Example 1
def get_config():
    config = base_config.get_base_config()

    config.random_seed = 0
    images_per_epoch = 1281167
    train_batch_size = 2048
    num_epochs = 300
    steps_per_epoch = images_per_epoch / train_batch_size  # Unused in this config.
    config.training_steps = ((images_per_epoch * num_epochs) //
                             train_batch_size)
    config.experiment_kwargs = config_dict.ConfigDict(
        dict(config=dict(
            lr=1e-3,
            num_epochs=num_epochs,
            image_size=224,
            num_classes=1000,
            which_dataset='imagenet',
            loss='softmax_cross_entropy',
            transpose=True,
            dtype=jnp.bfloat16,
            lr_schedule=dict(name='cosine_decay_schedule',
                             kwargs=dict(init_value=1e-3,
                                         decay_steps=config.training_steps)),
            optimizer_weights=dict(
                name='adamw', kwargs=dict(b1=0.9, b2=0.999, weight_decay=0.05)),
            optimizer_biases=dict(name='adam', kwargs=dict(b1=0.9, b2=0.999)),
            model=dict(name='BoTNet',
                       config_kwargs=dict(stage_sizes=[3, 4, 6, 6],
                                          dtype=jnp.bfloat16)),
            augment_name='cutmix_mixup_randaugment_405')))
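
The `lr_schedule`, `optimizer_weights` and `optimizer_biases` entries above are plain dicts of names and kwargs; the experiment code that consumes this config (not shown here) is what turns them into concrete objects. A minimal sketch of that translation, assuming the names map directly onto the optax constructors of the same name:

import optax

# Reproduce the derived step count from the config above.
decay_steps = (1281167 * 300) // 2048  # == config.training_steps == 187670

schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=decay_steps)
weights_opt = optax.adamw(learning_rate=schedule, b1=0.9, b2=0.999,
                          weight_decay=0.05)
biases_opt = optax.adam(learning_rate=schedule, b1=0.9, b2=0.999)
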
Example 2
    def test_best_checkpoint_saves_only_at_improved_best_metrics(self):
        config = base_config.get_base_config()
        config.best_model_eval_metric = _FITNESS_METRIC_KEY
        config.training_steps = 100
        ckpt = DummyCheckpoint()
        writer = mock.Mock()
        train.evaluate(DummyExperiment,
                       config,
                       ckpt,
                       writer,
                       jaxline_mode="eval")

        # The first step will always checkpoint.
        self.assertLen(ckpt._state_list, len(_IMPROVEMENT_STEPS) + 1)
        checkpointed_states = [s.global_step for s in ckpt._state_list]
        self.assertEqual(checkpointed_states, [0] + _IMPROVEMENT_STEPS)
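
The test relies on module-level fixtures (`_FITNESS_METRIC_KEY`, `_IMPROVEMENT_STEPS`, `DummyExperiment`, `DummyCheckpoint`) that are not included in this snippet. A hypothetical pair of definitions consistent with the assertions above; the names come from the test, the values are purely illustrative:

_FITNESS_METRIC_KEY = "fitness"
# Hypothetical: steps at which DummyExperiment reports an improved value of the
# fitness metric; the test expects one saved state per improvement, plus the
# unconditional checkpoint at step 0.
_IMPROVEMENT_STEPS = [2, 5, 15, 27, 99]
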
Example 3
def get_config(debug: bool = False) -> config_dict.ConfigDict:
  """Get Jaxline experiment config."""
  config = base_config.get_base_config()
  # E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below)
  config.restore_path = config_dict.placeholder(str)

  training_batch_size = 64
  eval_batch_size = 64

  ## Experiment config.
  loss_config_name = 'RegressionLossConfig'
  loss_kwargs = dict(
      exponent=1.,  # 2 for L2 loss, 1 for L1 loss, etc.
  )

  dataset_config = dict(
      data_root=config_dict.placeholder(str),
      augment_with_random_mirror_symmetry=True,
      k_fold_split_id=config_dict.placeholder(int),
      num_k_fold_splits=config_dict.placeholder(int),
      # Options: "in" or "out".
      # "in" keeps only the samples with NaNs in the conformer features.
      # "out" keeps only the samples with no NaNs anywhere in the conformer
      # features.
      filter_in_or_out_samples_with_nans_in_conformers=(
          config_dict.placeholder(str)),
      cached_conformers_file=config_dict.placeholder(str))

  model_config = dict(
      mlp_hidden_size=512,
      mlp_layers=2,
      latent_size=512,
      use_layer_norm=False,
      num_message_passing_steps=32,
      shared_message_passing_weights=False,
      mask_padding_graph_at_every_step=True,
      loss_config_name=loss_config_name,
      loss_kwargs=loss_kwargs,
      processor_mode='resnet',
      global_reducer='sum',
      node_reducer='sum',
      dropedge_rate=0.1,
      dropnode_rate=0.1,
      aux_multiplier=0.1,
      add_relative_distance=True,
      add_relative_displacement=True,
      add_absolute_positions=False,
      position_normalization=2.,
      relative_displacement_normalization=1.,
      ignore_globals=False,
      ignore_globals_from_final_layer_for_predictions=True,
  )

  if debug:
    # Make network smaller.
    model_config.update(dict(
        mlp_hidden_size=32,
        mlp_layers=1,
        latent_size=32,
        num_message_passing_steps=1))

  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              debug=debug,
              predictions_dir=config_dict.placeholder(str),
              ema=True,
              ema_decay=0.9999,
              sample_random=0.05,
              optimizer=dict(
                  name='adam',
                  optimizer_kwargs=dict(b1=.9, b2=.95),
                  lr_schedule=dict(
                      warmup_steps=int(5e4),
                      decay_steps=int(5e5),
                      init_value=1e-5,
                      peak_value=1e-4,
                      end_value=0.,
                  ),
              ),
              model=model_config,
              dataset_config=dataset_config,
              # As a rule of thumb, use the following statistics:
              # Avg. # nodes in graph: 16.
              # Avg. # edges in graph: 40.
              training=dict(
                  dynamic_batch_size={
                      'n_node': 256 if debug else 16 * training_batch_size,
                      'n_edge': 512 if debug else 40 * training_batch_size,
                      'n_graph': 2 if debug else training_batch_size,
                  },),
              evaluation=dict(
                  split='valid',
                  dynamic_batch_size=dict(
                      n_node=256 if debug else 16 * eval_batch_size,
                      n_edge=512 if debug else 40 * eval_batch_size,
                      n_graph=2 if debug else eval_batch_size,
                  )))))

  ## Training loop config.
  config.training_steps = int(5e6)
  config.checkpoint_dir = '/tmp/checkpoint/pcq/'
  config.train_checkpoint_all_hosts = False
  config.save_checkpoint_interval = 300
  config.log_train_data_interval = 60
  config.log_tensors_interval = 60
  config.best_model_eval_metric = 'mae'
  config.best_model_eval_metric_higher_is_better = False

  return config
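
Several fields above are declared with `config_dict.placeholder(...)`: they start out as `None` and must be filled in (for example via command-line overrides) before the experiment can use them. A small, self-contained illustration of the behaviour, using a hypothetical field and path:

from ml_collections import config_dict

cfg = config_dict.ConfigDict()
cfg.data_root = config_dict.placeholder(str)  # typed placeholder, defaults to None
assert cfg.data_root is None
cfg.data_root = '/data/pcq/'  # OK: matches the declared type
# cfg.data_root = 123         # would raise a TypeError
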
Example 4
def get_config(arg_string):
    """Return config object for training."""
    args = arg_string.split(",")
    if len(args) != 3:
        raise ValueError(
            "You must provide exactly three comma-separated arguments: "
            "model_config_name,sweep_index,dataset_name.")
    model_config_name, sweep_index, dataset_name = args
    sweep_index = int(sweep_index)

    config = base_config.get_base_config()
    config.random_seed = 123109801
    config.eval_modes = ("eval", "eval_metric")

    # Get the model config and the sweeps
    if model_config_name not in globals():
        raise ValueError(
            f"The config name {model_config_name} does not exist in "
            f"jaxline_configs.py")
    config_and_sweep_fn = globals()[model_config_name]
    model_config, sweeps = config_and_sweep_fn()

    if not os.environ.get(_DATASETS_PATH_VAR_NAME, None):
        raise ValueError(f"You need to set the {_DATASETS_PATH_VAR_NAME}")
    dm_hamiltonian_suite_path = os.environ[_DATASETS_PATH_VAR_NAME]
    dataset_folder = os.path.join(dm_hamiltonian_suite_path, dataset_name)

    # Experiment config. Note that batch_size is per device.
    # Our experiments ran on 4 GPUs, so the effective batch size was 128.
    config.experiment_kwargs = config_dict.ConfigDict(
        dict(config=dict(
            dataset_folder=dataset_folder,
            model_kwargs=model_config,
            num_extrapolation_steps=60,
            drop_stats_containing=("neg_log_p_x", "l2_over_time", "neg_elbo"),
            optimizer=dict(name="adam",
                           kwargs=dict(
                               learning_rate=1.5e-4,
                               b1=0.9,
                               b2=0.999,
                           )),
            training=dict(batch_size=32,
                          burnin_steps=5,
                          num_epochs=None,
                          lagging_vae=False),
            evaluation=dict(batch_size=64, ),
            evaluation_metric=dict(
                batch_size=5,
                batch_n=20,
                num_eval_metric_steps=60,
                max_poly_order=5,
                max_jacobian_score=1000,
                rsq_threshold=0.9,
                sym_threshold=0.05,
                evaluation_point_n=10,
                weight_tolerance=1e-03,
                max_iter=1000,
                cv=2,
                alpha_min_logspace=-4,
                alpha_max_logspace=-0.5,
                alpha_step_n=10,
                calculate_fully_after_steps=40000,
            ),
            evaluation_metric_mlp=dict(
                batch_size=64,
                batch_n=10000,
                datapoint_param_multiplier=1000,
                num_eval_metric_steps=60,
                evaluation_point_n=10,
                evaluation_trajectory_n=50,
                rsq_threshold=0.9,
                sym_threshold=0.05,
                ridge_lambda=0.01,
                model=dict(
                    num_units=4,
                    num_layers=4,
                    activation="tanh",
                ),
                optimizer=dict(name="adam",
                               kwargs=dict(learning_rate=1.5e-3, )),
            ),
            evaluation_vpt=dict(
                batch_size=5,
                batch_n=2,
                vpt_threshold=0.025,
            ))))

    # Training loop config.
    config.training_steps = 500_000
    config.interval_type = "steps"
    config.log_tensors_interval = 50
    config.log_train_data_interval = 50
    config.log_all_train_data = False

    config.save_checkpoint_interval = 100
    config.checkpoint_dir = "/tmp/physics_inspired_models/"
    config.train_checkpoint_all_hosts = False
    config.eval_specific_checkpoint_dir = ""

    config.update_from_flattened_dict(sweeps[sweep_index])
    return config
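
The sweep applied by `update_from_flattened_dict` is just a flat dict whose keys are dotted paths into the config. A self-contained sketch of the mechanism (the field names here are illustrative, not the actual sweep contents):

from ml_collections import config_dict

cfg = config_dict.ConfigDict(
    dict(optimizer=dict(kwargs=dict(learning_rate=1.5e-4))))
cfg.update_from_flattened_dict(
    {"optimizer.kwargs.learning_rate": 3e-4})  # dotted paths as keys
assert cfg.optimizer.kwargs.learning_rate == 3e-4
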
Example 5
def get_config():
    """Return config object for training."""
    config = base_config.get_base_config()

    # Experiment config.
    train_batch_size = 1024  # Global batch size.
    images_per_epoch = 1281167
    num_epochs = 90
    steps_per_epoch = images_per_epoch / train_batch_size
    config.training_steps = ((images_per_epoch * num_epochs) //
                             train_batch_size)
    config.random_seed = 0
    config.experiment_kwargs = config_dict.ConfigDict(
        dict(
            config=dict(
                lr=0.1,
                num_epochs=num_epochs,
                label_smoothing=0.1,
                model='ResNet',
                image_size=224,
                use_ema=False,
                ema_decay=0.9999,  # Four nines.
                ema_start=0,
                which_ema='tf1_ema',
                augment_name=None,  # 'mixup_cutmix',
                augment_before_mix=True,
                eval_preproc='crop_resize',
                train_batch_size=train_batch_size,
                eval_batch_size=50,
                eval_subset='test',
                num_classes=1000,
                which_dataset='imagenet',
                fake_data=False,
                which_loss='softmax_cross_entropy',  # For now, must be softmax
                transpose=True,  # Use the double-transpose trick?
                bfloat16=False,
                lr_schedule=dict(
                    name='WarmupCosineDecay',
                    kwargs=dict(num_steps=config.training_steps,
                                start_val=0,
                                min_val=0,
                                warmup_steps=5 * steps_per_epoch),
                ),
                lr_scale_by_bs=True,
                optimizer=dict(
                    name='SGD',
                    kwargs={
                        'momentum': 0.9,
                        'nesterov': True,
                        'weight_decay': 1e-4,
                    },
                ),
                model_kwargs=dict(
                    width=4,
                    which_norm='BatchNorm',
                    norm_kwargs=dict(
                        create_scale=True,
                        create_offset=True,
                        decay_rate=0.9,
                    ),  # cross_replica_axis='i'),
                    variant='ResNet50',
                    activation='relu',
                    drop_rate=0.0,
                ),
            ), ))

    # Training loop config: log and checkpoint every minute
    config.log_train_data_interval = 60
    config.log_tensors_interval = 60
    config.save_checkpoint_interval = 60
    config.eval_specific_checkpoint_dir = ''

    return config
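
For the concrete values above (1,281,167 ImageNet training images, 90 epochs, global batch size 1024) the derived quantities work out to:

steps_per_epoch = 1281167 / 1024         # ~1251.14 (kept as a float)
training_steps = (1281167 * 90) // 1024  # 112602
warmup_steps = 5 * steps_per_epoch       # ~6255.7, passed to WarmupCosineDecay
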
Example 6
def get_config(debug: bool = False) -> config_dict.ConfigDict:
    """Get Jaxline experiment config."""
    config = base_config.get_base_config()
    config.random_seed = 42
    # E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below)
    config.restore_path = config_dict.placeholder(str)
    config.experiment_kwargs = config_dict.ConfigDict(
        dict(config=dict(
            debug=debug,
            predictions_dir=config_dict.placeholder(str),
            # 5 for model selection and early stopping, 50 for final eval.
            num_eval_iterations_to_ensemble=5,
            dataset_kwargs=dict(
                data_root='/data/',
                online_subsampling_kwargs=dict(
                    max_nb_neighbours_per_type=[
                        [[40, 20, 0, 40], [0, 0, 0, 0], [0, 0, 0, 0]],
                        [[40, 20, 0, 40], [40, 0, 10, 0], [0, 0, 0, 0]],
                    ],
                    remove_future_nodes=True,
                    deduplicate_nodes=True,
                ),
                ratio_unlabeled_data_to_labeled_data=10.0,
                k_fold_split_id=config_dict.placeholder(int),
                use_all_labels_when_not_training=False,
                use_dummy_adjacencies=debug,
            ),
            optimizer=dict(
                name='adamw',
                kwargs=dict(weight_decay=1e-5, b1=0.9, b2=0.999),
                learning_rate_schedule=dict(
                    use_schedule=True,
                    base_learning_rate=1e-2,
                    warmup_steps=50000,
                    total_steps=config.get_ref('training_steps'),
                ),
            ),
            model_config=dict(
                mlp_hidden_sizes=[32] if debug else [512],
                latent_size=32 if debug else 256,
                num_message_passing_steps=2 if debug else 4,
                activation='relu',
                dropout_rate=0.3,
                dropedge_rate=0.25,
                disable_edge_updates=True,
                use_sent_edges=True,
                normalization_type='layer_norm',
                aggregation_function='sum',
            ),
            training=dict(
                loss_config=dict(bgrl_loss_config=dict(
                    stop_gradient_for_supervised_loss=False,
                    bgrl_loss_scale=1.0,
                    symmetrize=True,
                    first_graph_corruption_config=dict(
                        feature_drop_prob=0.4,
                        edge_drop_prob=0.2,
                    ),
                    second_graph_corruption_config=dict(
                        feature_drop_prob=0.4,
                        edge_drop_prob=0.2,
                    ),
                ), ),
                # GPU memory constraints may require reducing the `256`s below to `48`.
                dynamic_batch_size_config=dict(
                    n_node=256 if debug else 340 * 256,
                    n_edge=512 if debug else 720 * 256,
                    n_graph=4 if debug else 256,
                ),
            ),
            eval=dict(
                split='valid',
                ema_annealing_schedule=dict(use_schedule=True,
                                            base_rate=0.999,
                                            total_steps=config.get_ref(
                                                'training_steps')),
                dynamic_batch_size_config=dict(
                    n_node=256 if debug else 340 * 128,
                    n_edge=512 if debug else 720 * 128,
                    n_graph=4 if debug else 128,
                ),
            ))))

    ## Training loop config.
    config.training_steps = 500000
    config.checkpoint_dir = '/tmp/checkpoint/mag/'
    config.train_checkpoint_all_hosts = False
    config.log_train_data_interval = 10
    config.log_tensors_interval = 10
    config.save_checkpoint_interval = 30
    config.best_model_eval_metric = 'accuracy'
    config.best_model_eval_metric_higher_is_better = True

    return config
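
`config.get_ref('training_steps')` stores a lazy reference rather than the current value, which is why it can be used inside `experiment_kwargs` even though `config.training_steps` only gets its final value further down. A minimal, self-contained illustration of the mechanism:

from ml_collections import config_dict

cfg = config_dict.ConfigDict()
cfg.training_steps = 100
cfg.total_steps = cfg.get_ref('training_steps')  # lazy FieldReference
cfg.training_steps = 500000                      # later reassignment...
assert cfg.total_steps == 500000                 # ...is seen through the reference
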
Example 7
def get_config():
    """Return config object for training."""
    use_debug_settings = IS_LOCAL
    config = base_config.get_base_config()

    # Experiment config.
    local_batch_size = 2
    # Modify this to adapt to your custom distributed learning setup
    num_devices = 1
    config.train_batch_size = local_batch_size * num_devices
    config.n_epochs = 110

    def _default_or_debug(default_value, debug_value):
        return debug_value if use_debug_settings else default_value

    n_train_examples = N_TRAIN_EXAMPLES
    num_classes = N_CLASSES

    config.experiment_kwargs = config_dict.ConfigDict(
        dict(config=dict(
            optimizer=dict(
                base_lr=5e-4,
                max_norm=10.0,  # < 0 to turn off.
                schedule_type='constant_cosine',
                weight_decay=1e-1,
                decay_pos_embs=True,
                scale_by_batch=True,
                cosine_decay_kwargs=dict(
                    init_value=0.0,
                    warmup_epochs=0,
                    end_value=0.0,
                ),
                step_decay_kwargs=dict(
                    decay_boundaries=[0.5, 0.8, 0.95],
                    decay_rate=0.1,
                ),
                constant_cosine_decay_kwargs=dict(
                    constant_fraction=0.5,
                    end_value=0.0,
                ),
                optimizer='lamb',
                # Optimizer-specific kwargs:
                adam_kwargs=dict(
                    b1=0.9,
                    b2=0.999,
                    eps=1e-8,
                ),
                lamb_kwargs=dict(
                    b1=0.9,
                    b2=0.999,
                    eps=1e-6,
                ),
            ),
            # Don't specify output_channels - it's not used for
            # classifiers.
            model=dict(
                perceiver_kwargs=dict(
                    input_preprocessor=dict(
                        prep_type='pixels',
                        # Channels for conv/conv1x1 preprocessing:
                        num_channels=64,
                        # -------------------------
                        # Position encoding arguments:
                        # -------------------------
                        position_encoding_type='fourier',
                        concat_or_add_pos='concat',
                        spatial_downsample=1,
                        # If >0, project position to this size:
                        project_pos_dim=-1,
                        trainable_position_encoding_kwargs=dict(
                            num_channels=258,  # Match the default channel count for Fourier.
                            init_scale=0.02,
                        ),
                        fourier_position_encoding_kwargs=dict(
                            num_bands=64,
                            max_resolution=(224, 224),
                            sine_only=False,
                            concat_pos=True,
                        ),
                    ),
                    encoder=dict(
                        num_self_attends_per_block=_default_or_debug(6, 2),
                        # Weights won't be shared if num_blocks is set to 1.
                        num_blocks=_default_or_debug(8, 2),
                        z_index_dim=512,
                        num_z_channels=1024,
                        num_cross_attend_heads=1,
                        num_self_attend_heads=8,
                        cross_attend_widening_factor=1,
                        self_attend_widening_factor=1,
                        dropout_prob=0.0,
                        # Position encoding for the latent array.
                        z_pos_enc_init_scale=0.02,
                        cross_attention_shape_for_attn='kv',
                        use_query_residual=True,
                    ),
                    decoder=dict(
                        num_z_channels=1024,
                        use_query_residual=True,
                        # Position encoding for the output logits.
                        position_encoding_type='trainable',
                        trainable_position_encoding_kwargs=dict(
                            num_channels=1024,
                            init_scale=0.02,
                        ),
                    ),
                ), ),
            training=dict(images_per_epoch=n_train_examples,
                          label_smoothing=0.1,
                          n_epochs=config.get_oneway_ref('n_epochs'),
                          batch_size=config.get_oneway_ref(
                              'train_batch_size')),
            data=dict(
                num_classes=num_classes,
                # Run on smaller images to debug.
                im_dim=_default_or_debug(224, 32),
                augmentation=dict(
                    # Typical randaug params:
                    # num_layers in [1, 3]
                    # magnitude in [5, 30]
                    # Set randaugment to None to disable.
                    randaugment=dict(num_layers=4, magnitude=5),
                    cutmix=True,
                    # Mixup alpha should be in [0, 1].
                    # Set to None to disable.
                    mixup_alpha=0.2,
                ),
            ),
            evaluation=dict(
                subset='test',
                batch_size=2,
            ),
        )))

    # Training loop config.
    config.training_steps = get_training_steps(
        config.get_oneway_ref('train_batch_size'),
        config.get_oneway_ref('n_epochs'))
    config.log_train_data_interval = 60
    config.log_tensors_interval = 60
    config.save_checkpoint_interval = 300
    config.eval_specific_checkpoint_dir = ''
    config.best_model_eval_metric = 'eval_top_1_acc'
    config.checkpoint_dir = '/tmp/perceiver_imagenet_checkpoints'
    config.train_checkpoint_all_hosts = False

    # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
    config.lock()

    return config
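
`get_training_steps` is defined elsewhere in the same module. A hypothetical implementation consistent with how the other configs in this collection derive their step counts (an assumption, not the module's actual code):

# Hypothetical: mirrors the images_per_epoch * num_epochs // batch_size pattern;
# N_TRAIN_EXAMPLES is the module-level constant referenced in get_config above.
def get_training_steps(batch_size, n_epochs):
    return (N_TRAIN_EXAMPLES * n_epochs) // batch_size
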
Example 8
def get_config():
    """Return config object for training."""
    config = base_config.get_base_config()

    # Batch size, training steps and data.
    num_classes = 10
    num_epochs = 400
    # Gowal et al. (2020) and Rebuffi et al. (2021) use a batch size of 1024.
    # Reducing this batch size may require further adjustments to the batch
    # normalization decay or the learning rate. If you have to use a batch size
    # of 256, reduce the number of emulated workers to 1 (this should match the
    # results of using a batch size of 1024 with 4 workers).
    train_batch_size = 1024

    def steps_from_epochs(n):
        return max(int(n * 50_000 / train_batch_size), 1)

    num_steps = steps_from_epochs(num_epochs)
    test_batch_size = train_batch_size
    # Specify the path to the downloaded data. You can download data from
    # https://github.com/deepmind/deepmind-research/tree/master/adversarial_robustness.
    # If the path is set to "cifar10_ddpm.npz" and is not found in the current
    # directory, the corresponding data will be downloaded.
    extra_npz = 'cifar10_ddpm.npz'

    # Learning rate.
    learning_rate = .1 * max(train_batch_size / 256, 1.)
    learning_rate_warmup = steps_from_epochs(10)
    learning_rate_fn = utils.get_cosine_schedule(learning_rate, num_steps,
                                                 learning_rate_warmup)

    # Model definition.
    model_ctor = model_zoo.WideResNet
    model_kwargs = dict(num_classes=num_classes,
                        depth=28,
                        width=10,
                        activation='swish')

    # Attack used during training (can be None).
    epsilon = 8 / 255
    train_attack = attacks.UntargetedAttack(
        attacks.PGD(attacks.Adam(
            optax.piecewise_constant_schedule(init_value=.1,
                                              boundaries_and_scales={5: .1})),
                    num_steps=10,
                    initialize_fn=attacks.linf_initialize_fn(epsilon),
                    project_fn=attacks.linf_project_fn(epsilon,
                                                       bounds=(0., 1.))),
        loss_fn=attacks.untargeted_kl_divergence)

    # Attack used during evaluation (can be None).
    eval_attack = attacks.UntargetedAttack(attacks.PGD(
        attacks.Adam(learning_rate_fn=optax.piecewise_constant_schedule(
            init_value=.1, boundaries_and_scales={
                20: .1,
                30: .01
            })),
        num_steps=40,
        initialize_fn=attacks.linf_initialize_fn(epsilon),
        project_fn=attacks.linf_project_fn(epsilon, bounds=(0., 1.))),
                                           loss_fn=attacks.untargeted_margin)

    config.experiment_kwargs = config_dict.ConfigDict(
        dict(config=dict(
            epsilon=epsilon,
            num_classes=num_classes,
            # Results from various publications use 4 worker machines, which leads
            # to slight differences when fewer worker machines are used. To
            # compensate for such discrepancies, we emulate the additional workers.
            # Set to zero when using more than 4 workers.
            emulated_workers=4,
            dry_run=False,
            save_final_checkpoint_as_npy=True,
            model=dict(constructor=model_ctor, kwargs=model_kwargs),
            training=dict(batch_size=train_batch_size,
                          learning_rate=learning_rate_fn,
                          weight_decay=5e-4,
                          swa_decay=.995,
                          use_cutmix=False,
                          supervised_batch_ratio=.3,
                          extra_data_path=extra_npz,
                          extra_label_smoothing=.1,
                          attack=train_attack),
            evaluation=dict(
                # If `interval` is positive, evaluation runs synchronously at
                # regular intervals. Setting it to zero disables evaluation during
                # training, unless `--jaxline_mode` is set to
                # `train_eval_multithreaded`, which evaluates checkpoints
                # asynchronously.
                interval=steps_from_epochs(40),
                batch_size=test_batch_size,
                attack=eval_attack),
        )))

    config.checkpoint_dir = '/tmp/jaxline/robust'
    config.train_checkpoint_all_hosts = False
    config.training_steps = num_steps
    config.interval_type = 'steps'
    config.log_train_data_interval = steps_from_epochs(.5)
    config.log_tensors_interval = steps_from_epochs(.5)
    config.save_checkpoint_interval = steps_from_epochs(40)
    config.eval_specific_checkpoint_dir = ''
    return config
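
With train_batch_size = 1024 and the 50,000 CIFAR-10 training images assumed by `steps_from_epochs`, the intervals above translate to the following concrete step counts:

num_steps = int(400 * 50_000 / 1024)      # 19531 total training steps
warmup = int(10 * 50_000 / 1024)          # 488 learning-rate warmup steps
log_interval = int(.5 * 50_000 / 1024)    # 24 steps between log writes
ckpt_interval = int(40 * 50_000 / 1024)   # 1953 steps between checkpoints/evals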