def electra_base_search_aggregator() -> dict:
    """Return an ELECTRA-base config that searches over aggregator settings.

    Starts from ``update_5best(electra_base_no_hpo())``, allows 10 HPO trials,
    and exposes two categorical dimensions in the MultimodalTextModel search
    space: the aggregation network type and ``base_feature_units``.
    """
    config = update_5best(electra_base_no_hpo())
    config['hpo_params']['num_trials'] = 10

    text_model_space = config['models']['MultimodalTextModel']['search_space']
    # Candidate ways of combining the per-modality feature vectors.
    text_model_space['model.network.agg_net.agg_type'] = space.Categorical(
        'attention', 'mean', 'max', 'concat')
    # NOTE(review): -1 presumably means "no explicit projection size"; confirm.
    text_model_space['model.network.base_feature_units'] = space.Categorical(
        -1, 128)
    return config
# Example #2
# 0
def electra_base_late_fusion_concate_e10_avg3():
    """Return an ELECTRA-base config using late fusion by concatenation.

    On top of ``electra_base_no_hpo()`` this enables n-best averaging with
    ``nbest = 3``, concatenates the modality features in the aggregation
    network, and aggregates categorical features.

    NOTE(review): the "e10" suffix suggests 10 training epochs, but this
    function does not set ``optimization.num_train_epochs``; presumably the
    value comes from ``electra_base_no_hpo`` — confirm.
    """
    config = electra_base_no_hpo()
    search_space = config['models']['MultimodalTextModel']['search_space']
    overrides = (
        ('model.use_avg_nbest', True),
        ('optimization.nbest', 3),
        ('model.network.agg_net.agg_type', 'concat'),
        ('model.network.aggregate_categorical', True),
    )
    for key, value in overrides:
        search_space[key] = value
    return config
def electra_base_grid_search() -> dict:
    """Return an ELECTRA-base config that searches optimization settings.

    Uses ``update_5best(electra_base_no_hpo())`` as the base, runs 12 HPO
    trials with the ``'random'`` search strategy, and makes the epoch count,
    learning rate, layer-wise LR decay and weight decay categorical.

    NOTE(review): despite the name, the strategy is random, and 12 trials
    cover only part of the 2 * 2 * 3 * 3 = 36 combinations.
    """
    config = update_5best(electra_base_no_hpo())
    hpo_params = config['hpo_params']
    hpo_params['num_trials'] = 12
    hpo_params['search_strategy'] = 'random'

    search_space = config['models']['MultimodalTextModel']['search_space']
    candidates = (
        ('optimization.num_train_epochs', space.Categorical(5, 10)),
        ('optimization.lr', space.Categorical(1E-4, 5E-5)),
        ('optimization.layerwise_lr_decay', space.Categorical(0.8, 0.9, 1.0)),
        ('optimization.wd', space.Categorical(0.01, 1E-4, 0.0)),
    )
    for key, choices in candidates:
        search_space[key] = choices
    return config
# Search-space overrides applied by ``electra_models_with_fusion_strategies``
# for each supported fusion strategy. Keys are written into the
# MultimodalTextModel search space in the listed order. All values are
# immutable scalars, so sharing this table across calls is safe.
_FUSION_STRATEGY_OVERRIDES = {
    'late_fusion_mean': {
        'model.network.agg_net.agg_type': 'mean',
        'model.network.aggregate_categorical': True,
    },
    'late_fusion_max': {
        'model.network.agg_net.agg_type': 'max',
        'model.network.aggregate_categorical': True,
    },
    'late_fusion_concat': {
        'model.network.agg_net.agg_type': 'concat',
        'model.network.aggregate_categorical': True,
    },
    'late_fusion_concat_gates': {
        'model.network.agg_net.agg_type': 'concat',
        'model.network.aggregate_categorical': True,
        'model.network.numerical_net.gated_activation': True,
        'model.network.categorical_agg.gated_activation': True,
    },
    'early_fusion': {
        'model.network.agg_net.agg_type': 'attention_token',
        'model.network.aggregate_categorical': False,
    },
    'early_fusion_layer3_units128': {
        'model.network.agg_net.agg_type': 'attention_token',
        'model.network.agg_net.attention_net.num_layers': 3,
        'model.network.agg_net.attention_net.units': 128,
        'model.network.aggregate_categorical': False,
    },
    'early_fusion_layer3': {
        'model.network.agg_net.agg_type': 'attention_token',
        'model.network.agg_net.attention_net.num_layers': 3,
        'model.network.aggregate_categorical': False,
    },
    'early_fusion_layer3_leaky': {
        'model.network.agg_net.agg_type': 'attention_token',
        'model.network.agg_net.attention_net.num_layers': 3,
        'model.network.agg_net.attention_net.activation': 'leaky',
        'model.network.aggregate_categorical': False,
    },
    'early_fusion_layer3_leaky_units128': {
        'model.network.agg_net.agg_type': 'attention_token',
        'model.network.agg_net.attention_net.num_layers': 3,
        'model.network.agg_net.attention_net.activation': 'leaky',
        'model.network.agg_net.attention_net.units': 128,
        'model.network.aggregate_categorical': False,
    },
    'all_text': {
        'model.network.agg_net.agg_type': 'concat',
        'preprocessing.categorical.convert_to_text': True,
        'preprocessing.numerical.convert_to_text': True,
    },
}


def electra_models_with_fusion_strategies(model_type, fusion_strategy,
                                          num_epochs, average3):
    """Build an ELECTRA config for a given backbone size and fusion strategy.

    Args:
        model_type: Backbone size: ``'small'``, ``'base'`` or ``'large'``.
            Selects ``electra_<size>_no_hpo()``, which is then passed through
            ``update_5best``.
        fusion_strategy: A key of ``_FUSION_STRATEGY_OVERRIDES``. It selects
            which aggregation / preprocessing settings are written into the
            MultimodalTextModel search space.
        num_epochs: Value stored under ``optimization.num_train_epochs``.
        average3: When truthy, sets ``model.use_avg_nbest = True`` and
            ``optimization.nbest = 3``.

    Returns:
        The config dict produced by ``update_5best``, with its
        MultimodalTextModel search space updated in place.

    Raises:
        NotImplementedError: If ``fusion_strategy`` or ``model_type`` is not
            supported. The exception type is kept for existing callers.
    """
    # Validate the strategy first so an unknown value fails before any base
    # config is built. TypeError covers unhashable inputs, which the original
    # ``==`` chain simply treated as "not supported".
    try:
        overrides = _FUSION_STRATEGY_OVERRIDES[fusion_strategy]
    except (KeyError, TypeError):
        raise NotImplementedError(
            'Unsupported fusion_strategy {!r}; expected one of {}.'.format(
                fusion_strategy, sorted(_FUSION_STRATEGY_OVERRIDES))) from None

    # The factories are referenced only inside their own branch, so only the
    # selected backbone's factory has to exist at call time.
    if model_type == 'small':
        cfg = update_5best(electra_small_no_hpo())
    elif model_type == 'base':
        cfg = update_5best(electra_base_no_hpo())
    elif model_type == 'large':
        cfg = update_5best(electra_large_no_hpo())
    else:
        raise NotImplementedError(
            "Unsupported model_type {!r}; expected one of "
            "'small', 'base', 'large'.".format(model_type))

    search_space = cfg['models']['MultimodalTextModel']['search_space']
    if average3:
        search_space['model.use_avg_nbest'] = True
        search_space['optimization.nbest'] = 3
    search_space['optimization.num_train_epochs'] = num_epochs
    # Item assignment (rather than dict.update) keeps the original
    # per-key write semantics and order.
    for key, value in overrides.items():
        search_space[key] = value
    return cfg