def get_config():
    """Build the hyperparameter configuration.

    Starts from ``base_cifar10_config.get_config()`` and overrides the
    training schedule, the model sizes, and the attention function. Queries
    and keys are transformed by ``make_spe_transform_fn`` using
    ``spe.SineSPE``; attention is produced by
    ``favor.make_fast_softmax_attention``.

    Returns:
        The updated configuration object.
    """
    cfg = base_cifar10_config.get_config()

    cfg.random_seed = 1
    cfg.model_type = "transformer"
    cfg.learning_rate = .00025
    cfg.batch_size = 96

    # Whole batches that fit in TRAIN_EXAMPLES at the configured batch size.
    train_batches = TRAIN_EXAMPLES // cfg.batch_size

    # The multiplication is applied before the floor division here, so this
    # is not necessarily equal to 4 * train_batches.
    cfg.eval_frequency = 4 * TRAIN_EXAMPLES // cfg.batch_size
    cfg.num_train_steps = train_batches * NUM_EPOCHS
    cfg.num_eval_steps = VALID_EXAMPLES // cfg.batch_size
    cfg.factors = 'constant * linear_warmup * cosine_decay'
    cfg.warmup = train_batches

    cfg.model.dropout_rate = 0.3
    cfg.model.attention_dropout_rate = 0.2
    cfg.model.learn_pos_emb = True
    cfg.model.num_layers = 1
    cfg.model.emb_dim = 128
    cfg.model.qkv_dim = 64
    cfg.model.mlp_dim = 128
    cfg.model.num_heads = 8
    cfg.model.classifier_pool = "CLS"
    cfg.model.add_pos_emb = False

    # The same realization count is handed to the SPE module and used as the
    # qkv_dim of the attention function.
    realization_count = 32
    spe_options = {"num_realizations": realization_count, "num_sines": 10}
    cfg.model.qk_transform_fn_factory = functools.partial(
        make_spe_transform_fn,
        spe_cls=spe.SineSPE,
        spe_kwargs=spe_options,
    )
    cfg.attention_fn = favor.make_fast_softmax_attention(
        qkv_dim=realization_count,
        lax_scan_unroll=16,
    )
    return cfg
def get_config():
    """Build the hyperparameter configuration.

    Starts from ``base_cifar10_config.get_config()`` and overrides the
    training schedule and model sizes. The attention function comes from
    ``favor.make_fast_generalized_attention`` with deterministic features
    and ``jax.nn.relu`` as the kernel function.

    Returns:
        The updated configuration object.
    """
    cfg = base_cifar10_config.get_config()

    cfg.random_seed = 0
    cfg.model_type = "transformer"
    cfg.learning_rate = .00025
    cfg.batch_size = 96

    # Whole batches that fit in TRAIN_EXAMPLES at the configured batch size.
    train_batches = TRAIN_EXAMPLES // cfg.batch_size

    cfg.eval_frequency = train_batches
    cfg.num_train_steps = train_batches * NUM_EPOCHS
    cfg.num_eval_steps = VALID_EXAMPLES // cfg.batch_size
    cfg.factors = 'constant * linear_warmup * cosine_decay'
    cfg.warmup = train_batches

    cfg.model.dropout_rate = 0.3
    cfg.model.attention_dropout_rate = 0.2
    cfg.model.learn_pos_emb = True
    cfg.model.num_layers = 1
    cfg.model.emb_dim = 128
    cfg.model.qkv_dim = 64
    cfg.model.mlp_dim = 128
    cfg.model.num_heads = 8
    cfg.model.classifier_pool = "CLS"

    # qkv_dim passed to the attention factory is the model qkv_dim split
    # evenly across the heads.
    per_head_dim = cfg.model.qkv_dim // cfg.model.num_heads
    cfg.attention_fn = favor.make_fast_generalized_attention(
        qkv_dim=per_head_dim,
        features_type='deterministic',
        kernel_fn=jax.nn.relu,
        lax_scan_unroll=16,
    )
    return cfg
def get_config():
    """Build the hyperparameter configuration.

    Starts from ``base_cifar10_config.get_config()``, selects the
    ``"sinkhorn"`` model type and overrides the block size, layer count,
    weight decay and dropout rate.

    Returns:
        The updated configuration object.
    """
    cfg = base_cifar10_config.get_config()

    cfg.model_type = "sinkhorn"

    # Model-level overrides.
    cfg.model.block_size = 64
    cfg.model.num_layers = 1

    # Regularization overrides.
    cfg.weight_decay = 0.7
    cfg.model.dropout_rate = 0.1
    return cfg
def get_config():
    """Build the hyperparameter configuration.

    Starts from ``base_cifar10_config.get_config()``, selects the
    ``"longformer"`` model type and overrides the model sizes and the
    classifier pooling mode.

    Returns:
        The updated configuration object.
    """
    cfg = base_cifar10_config.get_config()

    cfg.model_type = "longformer"

    # Depth and width overrides.
    cfg.model.num_layers = 4
    cfg.model.emb_dim = 128
    cfg.model.qkv_dim = 64
    cfg.model.mlp_dim = 128
    cfg.model.num_heads = 4

    cfg.model.classifier_pool = "MEAN"
    return cfg
def get_config():
    """Get the hyperparameter configuration.

    Starts from ``base_cifar10_config.get_config()``, selects the
    ``"bigbird"`` model type and overrides the block size, the model sizes
    and the classifier pooling mode.

    Returns:
        The updated configuration object.
    """
    config = base_cifar10_config.get_config()
    config.model_type = "bigbird"
    config.model.block_size = 64
    config.model.num_layers = 1
    config.model.emb_dim = 128
    config.model.qkv_dim = 64
    config.model.mlp_dim = 128
    # The field is spelled `num_heads` (plural), matching the other model
    # configs; a differently spelled attribute would not override the head
    # count taken from the base configuration.
    config.model.num_heads = 4
    config.model.classifier_pool = "CLS"
    return config
def get_config():
    """Build the hyperparameter configuration.

    Starts from ``base_cifar10_config.get_config()``, selects the
    ``"synthesizer"`` model type with ``synthesizer_mode`` set to
    ``"random"``, and overrides the model sizes and classifier pooling mode.

    Returns:
        The updated configuration object.
    """
    cfg = base_cifar10_config.get_config()

    cfg.model_type = "synthesizer"
    cfg.model.synthesizer_mode = "random"

    # Depth and width overrides.
    cfg.model.num_layers = 1
    cfg.model.emb_dim = 128
    cfg.model.qkv_dim = 64
    cfg.model.mlp_dim = 128
    cfg.model.num_heads = 8

    cfg.model.classifier_pool = "CLS"
    return cfg
def get_config():
    """Build the hyperparameter configuration.

    Starts from ``base_cifar10_config.get_config()`` and overrides the
    learning-rate schedule and model sizes. The attention function comes
    from ``favor.make_fast_softmax_attention``.

    Returns:
        The updated configuration object.
    """
    cfg = base_cifar10_config.get_config()

    cfg.random_seed = 0
    cfg.model_type = "transformer"
    cfg.learning_rate = .00019
    cfg.batch_size = 96
    cfg.factors = 'constant * linear_warmup * cosine_decay'

    # Warmup length: whole batches that fit in TRAIN_EXAMPLES at the
    # configured batch size.
    cfg.warmup = TRAIN_EXAMPLES // cfg.batch_size

    cfg.model.dropout_rate = 0.3
    cfg.model.attention_dropout_rate = 0.2
    cfg.model.learn_pos_emb = True
    cfg.model.num_layers = 1
    cfg.model.emb_dim = 128
    cfg.model.qkv_dim = 64
    cfg.model.mlp_dim = 128
    cfg.model.num_heads = 8
    cfg.model.classifier_pool = "CLS"

    # qkv_dim passed to the attention factory is the model qkv_dim split
    # evenly across the heads.
    per_head_dim = cfg.model.qkv_dim // cfg.model.num_heads
    cfg.attention_fn = favor.make_fast_softmax_attention(
        qkv_dim=per_head_dim,
        lax_scan_unroll=16,
    )
    return cfg
def get_config():
    """Build the hyperparameter configuration.

    Takes ``base_cifar10_config.get_config()`` unchanged apart from setting
    ``model_type`` to ``"transformer"``.

    Returns:
        The updated configuration object.
    """
    cfg = base_cifar10_config.get_config()
    cfg.model_type = "transformer"
    return cfg