Пример #1
0
def get_config():
  """Return the Sine-SPE transformer config with a ReLU FAVOR kernel.

  Starts from ``base_tc_config.get_config()`` and overrides it to:
    * seed 0 and a "transformer" model;
    * attention from ``favor.make_fast_generalized_attention`` with
      deterministic features and ``jax.nn.relu`` as the kernel;
    * query/key transforms built by ``make_spe_transform_fn`` around
      ``spe.SineSPE`` (64 realizations, 10 sines), with ``add_pos_emb`` off;
    * batch size 8, learning rate rescaled by 8/32, and 30000 training steps.
  """
  config = base_tc_config.get_config()
  config.random_seed = 0
  config.model_type = "transformer"

  head_dim = config.qkv_dim // config.num_heads
  config.attention_fn = favor.make_fast_generalized_attention(
      qkv_dim=head_dim,
      features_type='deterministic',
      kernel_fn=jax.nn.relu,
      lax_scan_unroll=16,
  )

  sine_spe_kwargs = {'num_realizations': 64, 'num_sines': 10}
  spe_factory = functools.partial(
      make_spe_transform_fn,
      spe_cls=spe.SineSPE,
      spe_kwargs=sine_spe_kwargs,
  )
  # Presumably positions are supplied by the SPE transform, which is why the
  # additive positional embedding is disabled -- confirm in the model code.
  config.model_kwargs = {
      'add_pos_emb': False,
      'qk_transform_fn_factory': spe_factory,
  }

  # NOTE(review): 32 looks like the batch size the base learning rate was
  # tuned for (TODO confirm against base_tc_config); the rate is scaled
  # linearly to the new batch of 8.
  reference_batch, new_batch = 32, 8
  config.batch_size = new_batch
  config.learning_rate = config.learning_rate / reference_batch * new_batch
  config.num_train_steps = 30000
  return config
Пример #2
0
def get_config():
    """Return the transformer config using FAVOR fast softmax attention.

    Only ``model_type`` and ``attention_fn`` are overridden. Everything else
    comes from ``base_tc_config.get_config()``.
    """
    config = base_tc_config.get_config()
    config.model_type = "transformer"
    per_head_dim = config.qkv_dim // config.num_heads
    config.attention_fn = favor.make_fast_softmax_attention(
        qkv_dim=per_head_dim,
        lax_scan_unroll=16,
    )
    return config
def get_config():
    """Return the factorized-random Synthesizer configuration.

    Starts from ``base_tc_config.get_config()``, selects the "synthesizer"
    model, sets ``ignore_dot_product`` and the "factorized_random" mode, and
    uses ``k = 32``.
    """
    cfg = base_tc_config.get_config()
    cfg.model_type = "synthesizer"
    # Presumably drops the standard query-key dot-product term, leaving only
    # synthesized attention weights -- confirm against the model code.
    cfg.ignore_dot_product = True
    cfg.synthesizer_mode = "factorized_random"
    # NOTE(review): k is presumably the factorization size for this mode.
    cfg.k = 32
    return cfg
Пример #4
0
def get_config():
    """Return the transformer config with an ELU+1 generalized-attention kernel.

    Based on ``base_tc_config.get_config()``. Only the model type and the
    attention function (``favor.make_fast_generalized_attention`` with
    deterministic features) are overridden.
    """
    def elu_plus_one(x):
        # ELU is bounded below by -1 (default alpha), so shifting it up by
        # one keeps the kernel output non-negative.
        return jax.nn.elu(x) + 1

    config = base_tc_config.get_config()
    config.model_type = "transformer"
    config.attention_fn = favor.make_fast_generalized_attention(
        qkv_dim=config.qkv_dim // config.num_heads,
        features_type='deterministic',
        kernel_fn=elu_plus_one,
        lax_scan_unroll=16,
    )
    return config
Пример #5
0
def get_config():
    """Return the seed-2 FAVOR softmax config with a halved batch.

    Starts from ``base_tc_config.get_config()`` and uses
    ``favor.make_fast_softmax_attention``. Batch size and learning rate are
    both divided by 2, and training runs for 30000 steps.
    """
    cfg = base_tc_config.get_config()
    cfg.random_seed = 2
    cfg.model_type = "transformer"

    head_dim = cfg.qkv_dim // cfg.num_heads
    cfg.attention_fn = favor.make_fast_softmax_attention(
        qkv_dim=head_dim,
        lax_scan_unroll=16,
    )

    # Shrink the batch and scale the learning rate down by the same factor.
    shrink = 2
    cfg.batch_size = cfg.batch_size // shrink
    cfg.learning_rate = cfg.learning_rate / shrink
    cfg.num_train_steps = 30000
    return cfg
Пример #6
0
def get_config():
    """Return the ReLU-kernel FAVOR transformer config at batch size 8.

    Starts from ``base_tc_config.get_config()``, fixes the seed to 0, and
    uses ``favor.make_fast_generalized_attention`` with deterministic
    features and a ReLU kernel. The batch is reduced to 8, the learning rate
    is rescaled by 8/32, and training runs for 30000 steps.
    """
    config = base_tc_config.get_config()
    config.random_seed = 0
    config.model_type = "transformer"

    head_dim = config.qkv_dim // config.num_heads
    config.attention_fn = favor.make_fast_generalized_attention(
        qkv_dim=head_dim,
        features_type='deterministic',
        kernel_fn=jax.nn.relu,
        lax_scan_unroll=16,
    )

    # NOTE(review): 32 looks like the batch size the base learning rate was
    # tuned for (TODO confirm against base_tc_config); the rate is scaled
    # linearly to the new batch of 8.
    reference_batch, new_batch = 32, 8
    config.batch_size = new_batch
    config.learning_rate = config.learning_rate / reference_batch * new_batch
    config.num_train_steps = 30000
    return config
Пример #7
0
def get_config():
    """Return the FAVOR softmax transformer config with a shared ConvSPE.

    Starts from ``base_tc_config.get_config()`` with seed 0 and
    ``favor.make_fast_softmax_attention``. Query/key transforms come from
    ``make_spe_transform_fn`` around ``spe.ConvSPE`` (64 realizations, kernel
    size 128, ``shared=True``), with ``add_pos_emb`` off. The batch is 8, the
    learning rate is rescaled by 8/32, and training runs for 30000 steps.
    """
    config = base_tc_config.get_config()
    config.random_seed = 0
    config.model_type = "transformer"

    head_dim = config.qkv_dim // config.num_heads
    config.attention_fn = favor.make_fast_softmax_attention(
        qkv_dim=head_dim,
        lax_scan_unroll=16,
    )

    conv_spe_kwargs = {'num_realizations': 64, 'kernel_size': 128}
    # NOTE(review): shared=True presumably reuses one SPE module across
    # layers -- confirm in make_spe_transform_fn.
    qk_factory = functools.partial(
        make_spe_transform_fn,
        spe_cls=spe.ConvSPE,
        spe_kwargs=conv_spe_kwargs,
        shared=True,
    )
    config.model_kwargs = {
        'add_pos_emb': False,
        'qk_transform_fn_factory': qk_factory,
    }

    # NOTE(review): 32 looks like the reference batch size for the base
    # learning rate (TODO confirm); the rate is scaled linearly to 8.
    reference_batch, new_batch = 32, 8
    config.batch_size = new_batch
    config.learning_rate = config.learning_rate / reference_batch * new_batch
    config.num_train_steps = 30000
    return config
Пример #8
0
def get_config():
  """Return the plain-transformer baseline configuration.

  Identical to ``base_tc_config.get_config()`` except that ``model_type`` is
  set to ``"transformer"``.
  """
  cfg = base_tc_config.get_config()
  cfg.model_type = "transformer"
  return cfg
def get_config():
    """Return the Sinkhorn-model configuration.

    Identical to ``base_tc_config.get_config()`` except that ``model_type`` is
    set to ``"sinkhorn"``.
    """
    cfg = base_tc_config.get_config()
    cfg.model_type = "sinkhorn"
    return cfg
def get_config():
    """Return the Linformer configuration.

    Starts from ``base_tc_config.get_config()``, selects the "linformer"
    model, and sets ``low_rank_features`` to 32.
    """
    cfg = base_tc_config.get_config()
    cfg.model_type = "linformer"
    # NOTE(review): presumably the rank of Linformer's low-rank projection.
    cfg.low_rank_features = 32
    return cfg