Example #1
def get_config():
    """Get the hyperparameter configuration."""
    config = base_cifar10_config.get_config()
    config.random_seed = 1
    config.model_type = "transformer"
    config.learning_rate = .00025
    config.batch_size = 96
    config.eval_frequency = 4 * TRAIN_EXAMPLES // config.batch_size
    config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS
    config.num_eval_steps = VALID_EXAMPLES // config.batch_size
    config.factors = 'constant * linear_warmup * cosine_decay'
    config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1

    config.model.dropout_rate = 0.3
    config.model.attention_dropout_rate = 0.2
    config.model.learn_pos_emb = True
    config.model.num_layers = 1
    config.model.emb_dim = 128
    config.model.qkv_dim = 64
    config.model.mlp_dim = 128
    config.model.num_heads = 8
    config.model.classifier_pool = "CLS"
    config.model.add_pos_emb = False
    # Queries/keys are transformed with sinusoidal stochastic positional
    # encoding (SineSPE) before Performer-style fast softmax attention.
    num_realizations = 32
    config.model.qk_transform_fn_factory = functools.partial(
        make_spe_transform_fn,
        spe_cls=spe.SineSPE,
        spe_kwargs=dict(num_realizations=num_realizations, num_sines=10))
    config.attention_fn = favor.make_fast_softmax_attention(
        qkv_dim=num_realizations, lax_scan_unroll=16)
    return config
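These examples are snippets from larger config modules, so names such as TRAIN_EXAMPLES, VALID_EXAMPLES, NUM_EPOCHS, base_cifar10_config, functools, jax, favor, spe and make_spe_transform_fn are assumed to be imported or defined at module level. A minimal, hypothetical preamble for the snippet above might look like the following (the exact import paths and constant values depend on the enclosing codebase):

import functools

import jax

# Illustrative module-level constants; the real values come from the
# enclosing CIFAR-10 config module.
NUM_EPOCHS = 200
TRAIN_EXAMPLES = 45000
VALID_EXAMPLES = 10000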
Example #2
def get_config():
    """Get the hyperparameter configuration."""
    config = base_cifar10_config.get_config()
    config.random_seed = 0
    config.model_type = "transformer"
    config.learning_rate = .00025
    config.batch_size = 96
    config.eval_frequency = TRAIN_EXAMPLES // config.batch_size
    config.num_train_steps = (TRAIN_EXAMPLES // config.batch_size) * NUM_EPOCHS
    config.num_eval_steps = VALID_EXAMPLES // config.batch_size
    config.factors = 'constant * linear_warmup * cosine_decay'
    config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1

    config.model.dropout_rate = 0.3
    config.model.attention_dropout_rate = 0.2
    config.model.learn_pos_emb = True
    config.model.num_layers = 1
    config.model.emb_dim = 128
    config.model.qkv_dim = 64
    config.model.mlp_dim = 128
    config.model.num_heads = 8
    config.model.classifier_pool = "CLS"

    # Performer-style generalized attention with a deterministic ReLU kernel
    # (no random feature sampling), sized to the per-head dimension.
    config.attention_fn = favor.make_fast_generalized_attention(
        qkv_dim=config.model.qkv_dim // config.model.num_heads,
        features_type='deterministic',
        kernel_fn=jax.nn.relu,
        lax_scan_unroll=16)
    return config
Example #3
def get_config():
    """Get the hyperparameter configuration."""
    config = base_cifar10_config.get_config()
    config.model_type = "sinkhorn"
    config.model.block_size = 64
    config.model.num_layers = 1
    config.weight_decay = 0.7
    config.model.dropout_rate = 0.1
    return config
Example #4
def get_config():
    """Get the hyperparameter configuration."""
    config = base_cifar10_config.get_config()
    config.model_type = "longformer"
    config.model.num_layers = 4
    config.model.emb_dim = 128
    config.model.qkv_dim = 64
    config.model.mlp_dim = 128
    config.model.num_heads = 4
    config.model.classifier_pool = "MEAN"
    return config
Example #5
def get_config():
    """Get the hyperparameter configuration."""
    config = base_cifar10_config.get_config()
    config.model_type = "bigbird"
    config.model.block_size = 64
    config.model.num_layers = 1
    config.model.emb_dim = 128
    config.model.qkv_dim = 64
    config.model.mlp_dim = 128
    config.model.num_heads = 4
    config.model.classifier_pool = "CLS"
    return config
Example #6
def get_config():
    """Get the hyperparameter configuration."""
    config = base_cifar10_config.get_config()
    config.model_type = "synthesizer"
    config.model.synthesizer_mode = "random"
    config.model.num_layers = 1
    config.model.emb_dim = 128
    config.model.qkv_dim = 64
    config.model.mlp_dim = 128
    config.model.num_heads = 8
    config.model.classifier_pool = "CLS"

    return config
Example #7
def get_config():
  """Get the hyperparameter configuration."""
  config = base_cifar10_config.get_config()
  config.random_seed = 0
  config.model_type = "transformer"
  config.learning_rate = .00019
  config.batch_size = 96
  config.factors = 'constant * linear_warmup * cosine_decay'
  config.warmup = (TRAIN_EXAMPLES // config.batch_size) * 1
  config.model.dropout_rate = 0.3
  config.model.attention_dropout_rate = 0.2
  config.model.learn_pos_emb = True
  config.model.num_layers = 1
  config.model.emb_dim = 128
  config.model.qkv_dim = 64
  config.model.mlp_dim = 128
  config.model.num_heads = 8
  config.model.classifier_pool = "CLS"
  # FAVOR+ fast softmax attention (Performer), sized to the per-head dimension.
  config.attention_fn = favor.make_fast_softmax_attention(
    qkv_dim=config.model.qkv_dim // config.model.num_heads,
    lax_scan_unroll=16)
  return config
Example #8
def get_config():
  """Get the hyperparameter configuration."""
  config = base_cifar10_config.get_config()
  config.model_type = "transformer"
  return config
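As a rough usage sketch (not part of the original examples), a training script would typically build the config once and then read hyperparameters from it. The snippet below assumes an ml_collections-style ConfigDict, which is what nested attribute access such as config.model.num_layers suggests; the stand-in get_config() is purely illustrative and only mirrors field names from the examples above:

import ml_collections


def get_config():
    """Stand-in for any of the get_config() examples above."""
    config = ml_collections.ConfigDict()
    config.model_type = "transformer"
    config.batch_size = 96
    config.model = ml_collections.ConfigDict()
    config.model.num_layers = 1
    return config


config = get_config()
print(config.model_type, config.batch_size, config.model.num_layers)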