Example #1
    def test_add_decayed_weights_with_mask(self):
        """Test mask is not added for add_decayed_weights if specified in hps."""
        class Foo(nn.Module):
            """Dummy model."""

            train: bool
            filters: int

            @nn.compact
            def __call__(self, x):
                x = nn.Conv(self.filters, (1, 1),
                            use_bias=False,
                            dtype=jnp.float32)(x)
                x = nn.BatchNorm(use_running_average=not self.train,
                                 momentum=0.9,
                                 epsilon=1e-5,
                                 dtype=jnp.float32)(x)
                return x

        tx = from_hparams(
            ml_collections.ConfigDict({
                '0': {
                    'element': 'add_decayed_weights',
                    'hps': {
                        'weight_decay': 1e-4,
                        'mask': 'bias_bn'
                    }
                }
            }))
        key = jax.random.PRNGKey(0)
        x = jnp.ones((5, 4, 4, 3))
        y = jax.random.uniform(key, (5, 4, 4, 7))

        foo_vars = flax.core.unfreeze(Foo(filters=7, train=True).init(key, x))

        @self.variant
        def train_step(params, x, y):
            y1, new_batch_stats = Foo(filters=7, train=True).apply(
                params, x, mutable=['batch_stats'])

            return jnp.abs(y - y1).sum(), new_batch_stats

        state = self.variant(tx.init)(foo_vars['params'])
        grads, _ = jax.grad(train_step, has_aux=True)(foo_vars, x, y)
        updates, state = self.variant(tx.update)(dict(grads['params']), state,
                                                 foo_vars['params'])

        chex.assert_trees_all_close(updates['BatchNorm_0'],
                                    grads['params']['BatchNorm_0'])
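For reference, a minimal sketch of the kind of parameter mask the 'bias_bn' hparam presumably expands to, written directly against optax and flax; the helper name and matching rules here are illustrative assumptions, not the library's implementation.

import flax
import optax

def bias_bn_mask(params):
    # True means "apply weight decay"; biases and anything under a BatchNorm
    # module are excluded, which is what the assertion above relies on.
    flat = flax.traverse_util.flatten_dict(params)
    mask = {
        path: not (path[-1] == 'bias' or any('BatchNorm' in part for part in path))
        for path in flat
    }
    return flax.traverse_util.unflatten_dict(mask)

tx_sketch = optax.add_decayed_weights(weight_decay=1e-4, mask=bias_bn_mask)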
Example #2
def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()

    # General
    config.debug = False
    config.experiment_name = ''  # Set in train.sh
    config.on_gcp = True
    config.local_output_dir = '/tmp'
    config.use_gpu = True
    config.random_seed = 42

    # Model params
    config.model_key = 'resnet18'
    config.cascaded = True
    config.tdl_mode = 'OSD'  # OSD, EWS, noise
    config.tdl_alpha = 0.0
    config.noise_var = 0.0
    config.bn_time_affine = False
    config.bn_time_stats = True

    # Loss params
    config.lambda_val = 0.0
    config.normalize_loss = False

    # Dataset params
    config.dataset_name = 'CIFAR10'  # CIFAR10, CIFAR100, TinyImageNet
    config.val_split = 0.1
    config.split_idxs_root = None
    config.augmentation_noise_type = 'occlusion'
    config.num_workers = 16
    config.drop_last = False

    # Train params
    config.batch_size = 128
    config.epochs = 100
    config.eval_freq = 10
    config.upload_freq = 5

    # Optimizer params
    config.learning_rate = 0.1
    config.momentum = 0.9
    config.weight_decay = 0.0005
    config.nesterov = True

    config.lr_milestones = [30, 60, 120, 150]
    config.lr_schedule_gamma = 0.2

    return config
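These optimizer and schedule fields map naturally onto a PyTorch training script; a minimal wiring sketch, assuming torch.optim's SGD and MultiStepLR (the placeholder model is illustrative only).

import torch

config = get_config()
model = torch.nn.Linear(8, 8)  # placeholder model, just to show the wiring
optimizer = torch.optim.SGD(model.parameters(),
                            lr=config.learning_rate,
                            momentum=config.momentum,
                            weight_decay=config.weight_decay,
                            nesterov=config.nesterov)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=config.lr_milestones,
    gamma=config.lr_schedule_gamma)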
Example #3
    def test_embedding_layer(self):
        config = ml_collections.ConfigDict({
            "batch_size": 3,
            "vocab_size": 1000,
            "d_emb": 32,
            "max_seq_length": 64,
            "type_vocab_size": 2,
            "d_model": 4,
            "dropout_rate": 0.1,
            "dtype": jnp.float32
        })
        frozen_config = ml_collections.FrozenConfigDict(config)
        rng = jax.random.PRNGKey(100)

        embedding_layer = layers.EmbeddingLayer(config=frozen_config)
        init_batch = {
            "input_ids": jnp.ones((1, frozen_config.max_seq_length),
                                  jnp.int32),
            "type_ids": jnp.ones((1, frozen_config.max_seq_length), jnp.int32)
        }
        params = init_layer_variables(rng, embedding_layer,
                                      init_batch)["params"]

        expected_keys = {
            "word", "position", "type", "layer_norm", "hidden_mapping_in"
        }
        self.assertEqual(params.keys(), expected_keys)

        rng, init_rng = jax.random.split(rng)
        inputs = {
            "input_ids":
            jax.random.randint(
                init_rng,
                (frozen_config.batch_size, frozen_config.max_seq_length),
                minval=0,
                maxval=13),
            "type_ids":
            jax.random.randint(
                init_rng,
                (frozen_config.batch_size, frozen_config.max_seq_length),
                minval=0,
                maxval=2)
        }
        outputs = embedding_layer.apply({"params": params},
                                        rngs={"dropout": rng},
                                        **inputs)
        self.assertEqual(outputs.shape,
                         (frozen_config.batch_size,
                          frozen_config.max_seq_length, frozen_config.d_model))
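For context, the FrozenConfigDict behavior the test leans on: it offers the same attribute access as ConfigDict but is immutable, so the hyperparameters cannot drift during the test. A minimal check:

frozen = ml_collections.FrozenConfigDict({"d_model": 4})
assert frozen.d_model == 4
# Reassigning a field (e.g. frozen.d_model = 8) raises, because the dict is frozen.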
Example #4
def get_r50_l16_config():
    """Returns the Resnet50 + ViT-L/16 configuration. customized """
    config = get_l16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 2
    config.activation = 'softmax'
    return config
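Example #5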
def get_config():
  """Get the default hyperparameter configuration."""
  config = ml_collections.ConfigDict()
  config.batch_size = 512
  config.eval_frequency = TRAIN_EXAMPLES // config.batch_size
  config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS
  config.num_eval_steps = VALID_EXAMPLES // config.batch_size
  config.weight_decay = 0.
  config.grad_clip_norm = None

  config.save_checkpoints = True
  config.restore_checkpoints = True
  config.checkpoint_freq = (TRAIN_EXAMPLES //
                            config.batch_size) * NUM_EPOCHS // 2
  config.random_seed = 0

  config.learning_rate = .001
  config.factors = 'constant * linear_warmup * cosine_decay'
  config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1
  config.steps_per_cycle = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS

  # model params
  config.model = ml_collections.ConfigDict()
  config.model.num_layers = 1
  config.model.num_heads = 2
  config.model.emb_dim = 32
  config.model.dropout_rate = 0.1

  config.model.qkv_dim = config.model.emb_dim // 2
  config.model.mlp_dim = config.model.qkv_dim * 2
  config.model.attention_dropout_rate = 0.1
  config.model.classifier_pool = 'MEAN'
  config.model.learn_pos_emb = False

  config.trial = 0  # dummy for repeated runs.
  return config
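Example #6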
def get_config(config_string):
    """Return an instance of ConfigDict depending on `config_string`."""
    possible_structures = {
        'linear':
        ml_collections.ConfigDict({
            'model_constructor':
            'snt.Linear',
            'model_config':
            ml_collections.ConfigDict({
                'output_size': 42,
            })
        }),
        'lstm':
        ml_collections.ConfigDict({
            'model_constructor':
            'snt.LSTM',
            'model_config':
            ml_collections.ConfigDict({
                'hidden_size': 108,
            })
        })
    }

    return possible_structures[config_string]
Example #7
def get_fp_quant_hparams_config(quantized, quantized_reductions):
    """Create appropriate setting for 'QuantHParams' field used by softmax and LayerNormAQT."""
    if quantized_reductions and not quantized:
        raise ValueError("If `quantized` is False, `quantized_reductions` "
                         "must also be False.")
    if quantized:
        quant_hparams = ml_collections.ConfigDict({
            "prec":
            config_schema_utils.get_fp_config(),
            "reduction_prec":
            config_schema_utils.get_fp_config()
            if quantized_reductions else None
        })
    else:
        quant_hparams = None
    return quant_hparams
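The guard above admits three of the four input combinations; a usage sketch of the resulting values (the nested precision configs come from config_schema_utils, so only their presence is shown).

get_fp_quant_hparams_config(quantized=True, quantized_reductions=True)    # prec and reduction_prec set
get_fp_quant_hparams_config(quantized=True, quantized_reductions=False)   # reduction_prec is None
get_fp_quant_hparams_config(quantized=False, quantized_reductions=False)  # returns None
# get_fp_quant_hparams_config(quantized=False, quantized_reductions=True) raises ValueError.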
Example #8
def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = "seg"
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 4
    config.n_skip = 3
    config.activation = "softmax"

    return config
Example #9
def get_eval_config():
  """Configuration relation to model evaluation."""
  eval_config = ml_collections.ConfigDict()

  eval_config.eval_once = False
  eval_config.save_output = True
  # The chunk size used for evaluation inference;
  # set it to a value that fits your GPU/TPU memory.
  eval_config.chunk = 4096
  eval_config.inference = False

  eval_config.mvsn_style = False
  eval_config.return_coarse = False
  eval_config.checkpoint_step = -1

  return eval_config
Example #10
def set_dataset_config(dataset, config):
  """Sets dataset-related configs."""
  config.dataset = mlc.ConfigDict()
  if dataset == "imagenet2012":
    config.dataset_name = dataset
    config.dataset.train_split = "train"
    config.dataset.val_split = "validation"
    config.dataset.num_classes = 1000
    config.dataset.input_shape = (224, 224, 3)
  else:
    assert dataset == "cifar10"
    config.dataset_name = dataset
    config.dataset.train_split = "train"
    config.dataset.val_split = "test"
    config.dataset.num_classes = 10
    config.dataset.input_shape = (32, 32, 3)
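A quick usage sketch of the helper above (mlc is the ml_collections alias used in this example):

config = mlc.ConfigDict()
set_dataset_config("cifar10", config)
assert config.dataset_name == "cifar10"
assert config.dataset.num_classes == 10
assert config.dataset.input_shape == (32, 32, 3)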
Example #11
def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()

    config.learning_rate = 0.1
    config.momentum = 0.9
    config.batch_size = 128
    config.num_epochs = 10

    config.cache = False
    config.half_precision = False

    # If num_train_steps==-1 then the number of training steps is calculated from
    # num_epochs using the entire dataset. Similarly for steps_per_eval.
    config.num_train_steps = -1
    config.steps_per_eval = -1
    return config
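As the comment says, -1 acts as a sentinel; a sketch of the usual derivation in the training loop, where num_train_examples and num_eval_examples are assumed to come from the dataset metadata.

steps_per_epoch = num_train_examples // config.batch_size
num_steps = (config.num_train_steps if config.num_train_steps != -1
             else steps_per_epoch * config.num_epochs)
steps_per_eval = (config.steps_per_eval if config.steps_per_eval != -1
                  else num_eval_examples // config.batch_size)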
Example #12
  def test_train_and_evaluate(self):
    workdir = tempfile.gettempdir()

    config = common.get_config()
    config.model = models.get_testing_config()
    config.dataset = 'cifar10'
    config.pp = ml_collections.ConfigDict(
        {'train': 'train[:98%]', 'test': 'test', 'resize': 448, 'crop': 384})
    config.batch = 64
    config.accum_steps = 2
    config.batch_eval = 8
    config.total_steps = 1
    config.pretrained_dir = workdir

    test_utils.create_checkpoint(config.model, f'{workdir}/testing.npz')
    opt_pmap = train.train_and_evaluate(config, workdir)
    self.assertTrue(os.path.exists(f'{workdir}/model.npz'))
Example #13
def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    config.pretrained_path = '/home/cjj/Documents/TransUNet/model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 19
    config.n_skip = 3
    config.activation = 'softmax'

    return config
Example #14
def get_block_config(parent_config, block_kind):
    """Create a ConfigDict corresponding to wmt_mlperf.models.Encoder[Decoder]1DBlock.HParams."""
    config = ml_collections.ConfigDict()
    config_schema_utils.set_default_reference(config, parent_config,
                                              "mlp_block")
    if block_kind == BlockKind.encoder:
        config_schema_utils.set_default_reference(config, parent_config,
                                                  "attention")
    elif block_kind == BlockKind.decoder:
        config_schema_utils.set_default_reference(
            config,
            parent_config, ["self_attention", "enc_dec_attention"],
            parent_field="attention")
    else:
        raise ValueError(f"Unknown block_kind {block_kind}")
    config.lock()
    return config
Example #15
def get_r50_b16_config():
    """Returns the Resnet50 + ViT-B/16 configuration."""
    config = get_b16_config()
    config.patches.grid = (16, 16)
    config.resnet = ml_collections.ConfigDict()
    config.resnet.num_layers = (3, 4, 9)
    config.resnet.width_factor = 1

    config.classifier = 'seg'
    config.pretrained_path = '/home/viplab/data/R50+ViT-B_16.npz'
    config.decoder_channels = (256, 128, 64, 16)
    config.skip_channels = [512, 256, 64, 16]
    config.n_classes = 2
    config.n_skip = 3
    config.activation = 'softmax'

    return config
Example #16
def get_residual_config(parent_config):
    """Creates ConfigDict corresponding to imagenet.models.ResidualBlock.HParams."""
    config = ml_collections.ConfigDict()
    config_schema_utils.set_default_reference(
        config,
        parent_config, ["conv_proj", "conv_1", "conv_2", "conv_3"],
        parent_field="conv")
    # TODO(b/179063860): The input distribution is an intrinsic model
    # property and shouldn't be part of the model configuration. Update
    # the hparam dataclasses to eliminate the input_distribution field and
    # then delete this.
    config.conv_proj.quant_act.input_distribution = "positive"
    config.conv_2.quant_act.input_distribution = "positive"
    config.conv_3.quant_act.input_distribution = "positive"

    config.lock()
    return config
Example #17
def get_config():
    """Return the default configuration."""
    config = ml_collections.ConfigDict()

    config.num_steps = 10000  # Number of training steps to perform.
    config.batch_size = 128  # Batch size.
    config.learning_rate = 0.01  # Learning rate

    # Number of samples to draw for prediction.
    config.num_prediction_samples = 500

    # Batch size to use for prediction. Ideally as big as possible, but may need
    # to be reduced for memory reasons depending on the value of
    # `num_prediction_samples`.
    config.prediction_batch_size = 500

    # Multiplier for the likelihood term in the loss
    config.likelihood_multiplier = 5.

    # Multiplier for the MMD constraint term in the loss
    config.constraint_multiplier = 0.

    # Scaling factor to use in KL term.
    config.beta = 1.0

    # The number of samples we draw from each latent variable distribution.
    config.mmd_sample_size = 100

    # Directory into which results should be placed. By default it is the empty
    # string, in which case no saving will occur. The directory specified will be
    # created if it does not exist.
    config.output_dir = ''

    # The index of the step at which to turn on the constraint multiplier. For
    # steps prior to this the multiplier will be zero.
    config.constraint_turn_on_step = 0

    # The random seed for tensorflow that is applied to the graph iff the value is
    # non-negative. By default the seed is not constrained.
    config.seed = -1

    # When doing fair inference, don't sample when given a sample for the baseline
    # gender.
    config.baseline_passthrough = False

    return config
Example #18
    def load_splits(self):
        """Load dataset splits using tfds loader."""
        self.data_iters = ml_collections.ConfigDict()
        for key, split in self.splits.items():
            logging.info('Loading %s  split of the %s dataset.', split.name,
                         self.name)
            ds, num_examples = self.load_split_from_tfds(
                name=self.name,
                batch_size=split['batch_size'],
                train=split['train'],
                split=split.name,
                shuffle_seed=self.shuffle_seed)

            self.splits[key].num_examples = num_examples

            self.data_iters[key] = self.create_data_iter(
                ds, split['batch_size'])
Example #19
def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()
    config.seed = 1

    config.dataset = "imagenet-lt"
    config.model_name = "resnet50"
    config.sampling = "uniform"

    config.add_color_jitter = True
    config.loss = "ce"

    config.learning_rate = 0.1
    config.learning_rate_schedule = "cosine"
    config.warmup_epochs = 0
    config.sgd_momentum = 0.9

    if config.dataset == "imagenet-lt":
        config.weight_decay = 0.0005
        config.num_epochs = 90
    elif config.dataset == "inaturalist18":
        config.weight_decay = 0.0002
        config.num_epochs = 200
    else:
        raise ValueError(f"Dataset {config.dataset} not supported.")

    config.global_batch_size = 128
    # If num_train_steps==-1 then the number of training steps is calculated from
    # num_epochs.
    config.num_train_steps = -1
    config.num_eval_steps = -1

    config.log_loss_every_steps = 500
    config.eval_every_steps = 1000
    config.checkpoint_every_steps = 1000
    config.shuffle_buffer_size = 1000

    config.trial = 0  # dummy for repeated runs.

    # Distillation parameters
    config.proj_dim = -1
    config.distill_teacher = ""
    config.distill_alpha = 0.0
    config.distill_fd_beta = 0.0

    return config
Example #20
def get_b16_none():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.pretrained_filename = "ViT-B_16.npz"
    config.image_size = 224
    config.patch_size = 16
    config.n_layers = 12
    config.hidden_size = 768
    config.n_heads = 12
    config.name = "B16_None"
    config.mlp_dim = 3072
    config.dropout = 0.1
    config.filters = 9
    config.kernel_size = 1
    config.upsampling_factor = 16
    config.hybrid = False
    return config
Example #21
    def set_splits(self):
        """Define splits of the dataset.

    For multi-environment datasets, we have three main splits:
    test, train, and valid.
    Each of these splits is a dict mapping environment id to the information
    about that particular split of that dataset:

    self.splits
    |__Test
    |  |__Eval env 1
    |  |__Eval env 2
    |
    |__ Train
    |  |__ Train env 1
    |  |__ Train env 2
    |
    |__ Valid
       |__Eval env 1
       |__Eval env 2
    """

        train_split_configs = ml_collections.ConfigDict()
        valid_split_configs = ml_collections.ConfigDict()
        test_split_configs = ml_collections.ConfigDict()

        for env_name in self.train_environments:
            env_id = self.env2id(env_name)
            train_split_configs[str(env_id)] = ml_collections.ConfigDict(
                dict(name=self.get_env_split_name('train', env_name),
                     batch_size=self.batch_size,
                     train=True))

        for env_name in self.eval_environments:
            env_id = self.env2id(env_name)
            valid_split_configs[str(env_id)] = ml_collections.ConfigDict(
                dict(name=self.get_env_split_name('validation', env_name),
                     batch_size=self.eval_batch_size,
                     train=False))
            test_split_configs[str(env_id)] = ml_collections.ConfigDict(
                dict(name=self.get_env_split_name('test', env_name),
                     batch_size=self.eval_batch_size,
                     train=False))

        self.splits = ml_collections.ConfigDict(
            dict(test=test_split_configs,
                 validation=valid_split_configs,
                 train=train_split_configs))
Example #22
def get_config():
    """Returns a training configuration."""
    config = ml_collections.ConfigDict()
    config.rng_seed = 0
    config.num_trajectories = 1
    config.single_step_predictions = True
    config.num_samples = 1000
    config.split_on = 'times'
    config.train_split_proportion = 80 / 1000
    config.time_delta = 1.
    config.train_time_jump_range = (1, 10)
    config.test_time_jumps = (1, 2, 5, 10, 20, 50)
    config.num_train_steps = 5000
    config.latent_size = 100
    config.activation = 'sigmoid'
    config.model = 'action-angle-network'
    config.encoder_decoder_type = 'flow'
    config.flow_type = 'shear'
    config.num_flow_layers = 10
    config.num_coordinates = 2
    if config.flow_type == 'masked_coupling':
        config.flow_spline_range_min = -3
        config.flow_spline_range_max = 3
        config.flow_spline_bins = 100
    config.polar_action_angles = True
    config.scaler = 'identity'
    config.learning_rate = 1e-3
    config.batch_size = 100
    config.eval_cadence = 50
    config.simulation = 'shm'
    config.regularizations = ml_collections.FrozenConfigDict({
        'actions':
        1.,
        'angular_velocities':
        0.,
        'encoded_decoded_differences':
        0.,
    })
    config.simulation_parameter_ranges = ml_collections.FrozenConfigDict({
        'phi': (0, 0),
        'A': (1, 10),
        'm': (1, 5),
        'w': (0.05, 0.1),
    })
    return config
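Example #23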
    def test_reference_to_self(self):
        # Test adding a new field to a configdict which is a reference to an
        # existing field in the same configdict instance.
        config = ml_collections.ConfigDict({'parent': 1})
        config_schema_utils.set_default_reference(config,
                                                  config,
                                                  'child',
                                                  parent_field='parent')
        self.assertEqual(config.child, 1)
        self.assertEqual(config.parent, 1)

        config.parent = 5
        self.assertEqual(config.parent, 5)
        self.assertEqual(config.child, 5)

        config.child = 10
        self.assertEqual(config.parent, 5)
        self.assertEqual(config.child, 10)
Example #24
  def setUp(self):
    super().setUp()
    self.config = ml_collections.ConfigDict(self.config)

    self.model_config = self.config.model_config
    encoder_config = self.model_config.encoder_config

    self.max_length = encoder_config.max_length
    self.max_sample_mentions = self.config.max_sample_mentions
    self.collater_fn = text_classifier.TextClassifier.make_collater_fn(
        self.config)
    self.postprocess_fn = text_classifier.TextClassifier.make_output_postprocess_fn(
        self.config)

    model = text_classifier.TextClassifier.build_model(self.model_config)
    dummy_input = text_classifier.TextClassifier.dummy_input(self.config)
    init_rng = jax.random.PRNGKey(0)
    self.init_parameters = model.init(init_rng, dummy_input, True)
Example #25
def get_lit_l16ti_config():
  """Returns a LiT model with ViT-Large and tiny text towers."""
  config = ml_collections.ConfigDict()
  config.model_name = 'LiT-L16Ti'
  config.out_dim = (None, 1024)
  config.image = get_l16_config()
  config.text_model = 'text_transformer'
  config.text = {}
  config.text.width = 192
  config.text.num_layers = 12
  config.text.mlp_dim = 768
  config.text.num_heads = 3
  config.text.vocab_size = 16_000
  config.pp = {}
  config.pp.tokenizer_name = 'sentencepiece'
  config.pp.size = 224
  config.pp.max_len = 16
  return config
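Note that `config.text = {}` and `config.pp = {}` only work with the attribute assignments that follow because ConfigDict converts dict values into nested ConfigDicts on assignment; a minimal standalone check:

cfg = ml_collections.ConfigDict()
cfg.text = {}                      # stored as a nested ConfigDict, not a plain dict
cfg.text.width = 192
assert isinstance(cfg.text, ml_collections.ConfigDict)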
Example #26
def main(_):
    cfg = ml_collections.ConfigDict()
    cfg.integer_field = 123

    # Locking prohibits the addition and deletion of new fields but allows
    # modification of existing values. Locking happens automatically during
    # loading through flags.
    cfg.lock()
    try:
        cfg.intagar_field = 124  # Raises AttributeError and suggests valid field.
    except AttributeError as e:
        print(e)
    cfg.integer_field = -123  # Works fine.

    with cfg.unlocked():
        cfg.intagar_field = 1555  # Works fine too.

    print(cfg)
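The "loading through flags" mentioned in the comment usually looks like the following; a minimal sketch, with the config file path as a placeholder.

from absl import app, flags
from ml_collections import config_flags

config_flags.DEFINE_config_file('my_config', default='path/to/config.py')
FLAGS = flags.FLAGS

def run(_):
    cfg = FLAGS.my_config  # Arrives locked; fields can still be overridden on the
                           # command line, e.g. --my_config.integer_field=7.
    print(cfg)

if __name__ == '__main__':
    app.run(run)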
Example #27
def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()

    # The initial learning rate.
    config.learning_rate = 0.001

    # Learning rate decay, applied each optimization step.
    config.lr_decay = 0.999995

    # Batch size to use for data-dependent initialization.
    config.init_batch_size = 16

    # Batch size for training.
    config.batch_size = 64

    # Number of training epochs.
    config.num_epochs = 200

    # Dropout rate.
    config.dropout_rate = 0.5

    # Number of resnet layers per block.
    config.n_resnet = 5

    # Number of features in each conv layer.
    config.n_feature = 160

    # Number of components in the output distribution.
    config.n_logistic_mix = 10

    # Exponential decay rate of the sum of previous model iterates during Polyak
    # averaging.
    config.polyak_decay = 0.9995

    # Batch size for sampling.
    config.sample_batch_size = 256
    # Random number generator seed for sampling.
    config.sample_rng_seed = 0

    # Integer for PRNG random seed.
    config.seed = 0

    return config
Example #28
def get_config():
    default_config = dict(
        # General parameters
        function="",
        # Specifies the function to be approximated, e.g., SineShirp
        max=0.0,
        # Specifies the maximum value on which the function will be evaluated. e.g., -15.0.
        min=0.0,
        # Specifies the minimum value on which the function will be evaluated, e.g., 0.0.
        no_samples=0,
        # Specifies the number of samples that will be taken from the selected function,
        # between the min and max values.
        padding=0,
        # Specifies an amount of zero padding steps which will be concatenated at the end of
        # the sequence created by function(min, max, no_samples).
        optim="",
        # The optimizer to be used, e.g., Adam.
        lr=0.0,
        # The learning rate to be used, e.g., 0.001.
        no_iterations=0,
        # The number of training iterations to be executed, e.g., 20000.
        seed=0,
        # The seed of the run. e.g., 0.
        device="",
        # The device in which the model will be deployed, e.g., cuda.
        # Parameters of ConvKernel
        kernelnet_norm_type="",
        # If model == CKCNN, the normalization type to be used in the MLPs parameterizing the convolutional
        # kernels. If kernelnet_activation_function==Sine, no normalization will be used. e.g., LayerNorm.
        kernelnet_activation_function="",
        # If model == CKCNN, the activation function used in the MLPs parameterizing the convolutional
        # kernels. e.g., Sine.
        kernelnet_no_hidden=0,
        # If model == CKCNN, the number of hidden units used in the MLPs parameterizing the convolutional
        # kernels. e.g., 32.
        # Parameters of SIREN
        kernelnet_omega_0=0.0,
        # If model == CKCNN, kernelnet_activation_function==Sine, the value of the omega_0 parameter, e.g., 30.
        comment="",
        # An additional comment to be added to the config.path parameter specifying where
        # the network parameters will be saved / loaded from.
    )
    default_config = ml_collections.ConfigDict(default_config)
    return default_config
Example #29
def get_config():
    """Get the default hyperparameter configuration."""
    config = ml_collections.ConfigDict()

    # Base output directory for experiments (checkpoints, summaries); defaults to './output' if not valid.
    config.output_base_dir = ''

    config.data_dir = '/data/'  # use --config.data_dir arg to set without modifying config file
    config.dataset = 'imagenet2012:5.0.0'
    config.num_classes = 1000  # FIXME not currently used

    config.model = 'tf_efficientnet_b0'
    config.image_size = 0  # set from model defaults if 0
    config.batch_size = 224
    config.eval_batch_size = 100  # set to config.batch_size if 0
    config.lr = 0.016
    config.label_smoothing = 0.1
    config.weight_decay = 1e-5  # l2 weight penalty added to loss
    config.ema_decay = .99997

    config.opt = 'rmsproptf'
    config.opt_eps = .001
    config.opt_beta1 = 0.9
    config.opt_beta2 = 0.9
    config.opt_weight_decay = 0.  # by default, weight decay not applied in opt, l2 penalty above is used

    config.lr_schedule = 'step'
    config.lr_decay_rate = 0.97
    config.lr_decay_epochs = 2.4
    config.lr_warmup_epochs = 5.
    config.lr_minimum = 1e-6
    config.num_epochs = 450

    config.cache = False
    config.half_precision = True

    config.drop_rate = 0.2
    config.drop_path_rate = 0.2

    # If num_train_steps==-1 then the number of training steps is calculated from
    # num_epochs using the entire dataset. Similarly for steps_per_eval.
    config.num_train_steps = -1
    config.steps_per_eval = -1
    return config
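For illustration, a sketch of the step schedule implied by lr_decay_rate / lr_decay_epochs and the warmup fields; this is a generic reconstruction, not necessarily the repository's exact implementation.

def step_lr(epoch, cfg):
    # Linear warmup, then multiply by lr_decay_rate every lr_decay_epochs epochs,
    # never going below lr_minimum.
    if epoch < cfg.lr_warmup_epochs:
        return cfg.lr * epoch / cfg.lr_warmup_epochs
    num_decays = (epoch - cfg.lr_warmup_epochs) // cfg.lr_decay_epochs
    return max(cfg.lr * cfg.lr_decay_rate ** num_decays, cfg.lr_minimum)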
Example #30
def get_cifar_config():
    config = mlc.ConfigDict()

    config.seed = 0

    ###################
    # Train Config

    config.train = mlc.ConfigDict()
    config.train.seed = 0
    config.train.epochs = 90
    config.train.device_batch_size = 64
    config.train.log_epochs = 1

    # Dataset section
    config.train.dataset_name = "cifar10"
    config.train.dataset = mlc.ConfigDict()
    config.train.dataset.train_split = "train"
    config.train.dataset.val_split = "test"
    config.train.dataset.num_classes = 10
    config.train.dataset.input_shape = (32, 32, 3)

    # Model section
    config.train.model_name = "resnet18"
    config.train.model = mlc.ConfigDict()

    # Optimizer section
    config.train.optim = mlc.ConfigDict()
    config.train.optim.optax_name = "trace"  # momentum
    config.train.optim.optax = mlc.ConfigDict()
    config.train.optim.optax.decay = 0.9
    config.train.optim.optax.nesterov = False

    config.train.optim.wd = 1e-4
    config.train.optim.wd_mults = [(".*", 1.0)]
    config.train.optim.grad_clip_norm = 1.0

    # Learning rate section
    config.train.optim.lr = 0.1
    config.train.optim.schedule = mlc.ConfigDict()
    config.train.optim.schedule.warmup_epochs = 5
    config.train.optim.schedule.decay_type = "cosine"
    # Scale the schedule with the batch size; the base batch size is 256.
    config.train.optim.schedule.scale_with_batchsize = True

    return config