# NOTE: each get_config() in this listing lives in its own config file in
# the repo; the shared imports are reproduced once here. The standard-library
# and ml_collections imports are certain; repo-internal modules
# (base_listops_config, favor, spe, make_spe_transform_fn) come from the
# surrounding repo, whose exact module paths are not shown in this listing.
import functools

import jax
from ml_collections import config_dict


def get_config():
  """Get the default hyperparameter configuration."""
  config = base_listops_config.get_config()
  config.random_seed = 1
  config.model_type = "transformer"
  config.attention_fn = favor.make_fast_generalized_attention(
      qkv_dim=config.qkv_dim // config.num_heads,
      features_type='deterministic',  # fixed (non-random) ReLU feature map
      kernel_fn=jax.nn.relu,
      lax_scan_unroll=16)
  config.model_kwargs = dict(
      add_pos_emb=False,  # SPE replaces the absolute positional embedding
      qk_transform_fn_factory=functools.partial(
          make_spe_transform_fn,
          spe_cls=spe.SineSPE,
          spe_kwargs=dict(
              num_realizations=64,
              num_sines=10,
          ),
      ),
  )
  config.batch_size = 8
  config.learning_rate = config.learning_rate / 32 * 8
  config.num_train_steps = 10000
  config.eval_frequency = config.eval_frequency * 4
  return config
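# Several configs in this listing shrink the batch size from the base value
# of 32 down to 8 and rescale the learning rate linearly to match, i.e.
# config.learning_rate / 32 * 8. A minimal sketch of that heuristic
# (scale_lr is a hypothetical helper, not part of the repo):
def scale_lr(base_lr, base_batch=32, new_batch=8):
  """Linearly rescale the learning rate with the batch size."""
  return base_lr / base_batch * new_batch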
def get_config():
  """Get the default hyperparameter configuration."""
  config = base_listops_config.get_config()
  config.model_type = "performer"
  config.model_kwargs = config_dict.create(
      attention_fn_cls="softmax",
      attention_fn_kwargs=config_dict.create(
          ortho_scaling=0.0, nb_features=2))
  return config
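# For reference, config_dict is ml_collections.config_dict: create() builds
# a nested ConfigDict with attribute access. A minimal, self-contained
# sketch of the structure produced above:
from ml_collections import config_dict

_kwargs = config_dict.create(
    attention_fn_cls="softmax",
    attention_fn_kwargs=config_dict.create(ortho_scaling=0.0, nb_features=2))
assert _kwargs.attention_fn_kwargs.nb_features == 2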
def get_config():
  """Get the default hyperparameter configuration."""
  config = base_listops_config.get_config()
  config.model_type = "transformer"
  config.attention_fn = favor.make_fast_generalized_attention(
      qkv_dim=config.qkv_dim // config.num_heads,
      features_type='ortho',  # orthogonal random features
      kernel_fn=jax.nn.relu,
      lax_scan_unroll=16)
  return config
def get_config():
  """Get the default hyperparameter configuration."""
  config = base_listops_config.get_config()
  config.random_seed = 2
  config.model_type = "transformer"
  config.attention_fn = favor.make_fast_softmax_attention(
      qkv_dim=config.qkv_dim // config.num_heads,
      lax_scan_unroll=16)
  config.batch_size = 8
  config.learning_rate = config.learning_rate / 32 * 8
  config.num_train_steps = 10000
  return config
def get_config():
  """Get the default hyperparameter configuration."""
  config = base_listops_config.get_config()
  config.random_seed = 0
  config.model_type = "transformer"
  config.attention_fn = favor.make_fast_generalized_attention(
      qkv_dim=config.qkv_dim // config.num_heads,
      features_type='deterministic',
      kernel_fn=jax.nn.relu,
      lax_scan_unroll=16)
  config.batch_size = 8
  config.learning_rate = config.learning_rate / 32 * 8
  config.num_train_steps = 10000
  return config
def get_config():
  """Get the default hyperparameter configuration."""
  config = base_listops_config.get_config()
  config.random_seed = 2
  config.model_type = "transformer"
  config.attention_fn = favor.make_fast_softmax_attention(
      qkv_dim=config.qkv_dim // config.num_heads,
      lax_scan_unroll=16)
  config.model_kwargs = dict(
      add_pos_emb=False,
      qk_transform_fn_factory=functools.partial(
          make_spe_transform_fn,
          spe_cls=spe.SineSPE,
          spe_kwargs=dict(num_realizations=64, num_sines=10),
      ))
  config.batch_size = 8
  config.learning_rate = config.learning_rate / 32 * 8
  config.num_train_steps = 10000
  return config
def get_config():
  """Get the default hyperparameter configuration."""
  config = base_listops_config.get_config()
  config.model_type = "transformer"
  return config
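# Configs like these are typically consumed by a training script through
# ml_collections' config_flags, e.g. --config=path/to/config.py. A hedged
# usage sketch (the flag name and entry point are assumptions, not taken
# from this listing):
from absl import app
from ml_collections import config_flags

_CONFIG = config_flags.DEFINE_config_file('config')


def main(_):
  config = _CONFIG.value
  print(config.model_type, config.num_train_steps)


if __name__ == '__main__':
  app.run(main)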