Exemplo n.º 1
0
             experiment_name='VarNaming_subtoken_edge_ablation',
             experiment_run_log_id=experiment_run_log_id,
             model_name='VarNamingFixedVocabGGNN',
             model_label='all_edge',
             n_workers=8,
             n_batch=250 * 4,
             evaluation_metrics=('evaluate_full_name_accuracy',
                                 'evaluate_subtokenwise_accuracy',
                                 'evaluate_edit_distance'),
             model_params_to_load='best.params',
             skip_s3_sync=skip_s3_sync,
             test=test),
        dict(seed=5145,
             gpu_ids=(0, 1, 2, 3),
             dataset_name='18_popular_mavens',
             experiment_name='VarNaming_subtoken_edge_ablation',
             experiment_run_log_id=experiment_run_log_id,
             model_name='VarNamingNameGraphVocabGGNN',
             model_label='syntax_edge',
             n_workers=8,
             n_batch=250 * 4,
             evaluation_metrics=('evaluate_full_name_accuracy',
                                 'evaluate_subtokenwise_accuracy',
                                 'evaluate_edit_distance'),
             model_params_to_load='best.params',
             skip_s3_sync=skip_s3_sync,
             test=test),
    ])
    run_command_on_remote('local', evaluate_models_for_experiment,
                          list_of_kwargs)
              debug=False)),
        (aws_config['remote_ids']['box1'],
         dict(dataset_name='18_popular_mavens',
              experiment_name='FITB_vocab_comparison',
              experiment_run_log_id=experiment_run_log_id,
              seed=5145,
              gpu_ids=(0, 1, 2, 3),
              model_name='FITBGSCVocabGGNN',
              model_label='syntax_edge',
              model_kwargs=dict(hidden_size=64,
                                type_emb_size=30,
                                name_emb_size=31,
                                n_msg_pass_iters=8),
              init_fxn_name='Xavier',
              init_fxn_kwargs=dict(),
              loss_fxn_name='FITBLoss',
              loss_fxn_kwargs=dict(),
              optimizer_name='Adam',
              optimizer_kwargs={'learning_rate': .0002},
              val_fraction=0.15,
              n_workers=8,
              n_epochs=200,
              evaluation_metrics=('evaluate_FITB_accuracy', ),
              n_batch=250 * 4,
              skip_s3_sync=False,
              debug=False)),
    ]
    for instance_id, kwargs in instance_ids_train_kwargs:
        time.sleep(1)
        run_command_on_remote(instance_id, train_model_for_experiment, kwargs)
Exemplo n.º 3
0
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
from data.AugmentedAST import syntax_only_excluded_edge_types
from experiments import aws_config
from experiments.make_tasks_and_preprocess_for_experiment import make_tasks_and_preprocess
from experiments.run_command_on_remote import run_command_on_remote

if __name__ == '__main__':
    # Script entry point: assemble the keyword arguments for
    # make_tasks_and_preprocess and dispatch the call to the remote machine
    # registered as 'box1' in aws_config.

    # The two preprocessing variants share the model name and both kwargs
    # dicts; they differ only in the label and in the third tuple field
    # (an empty frozenset vs. syntax_only_excluded_edge_types).
    # NOTE(review): the third field is presumably the set of edge types to
    # exclude from the graph - confirm against make_tasks_and_preprocess.
    label_and_edge_types = (
        ('all_edge_except_subtoken', frozenset()),
        ('syntax_edge_except_subtoken', syntax_only_excluded_edge_types),
    )

    kwargs = {
        'seed': 515,
        'dataset_name': '18_popular_mavens',
        'experiment_name': 'VarNaming_subtoken_edge_ablation',
        'task_names': ['VarNamingTask'],
        'n_jobs': 30,
        # Fresh dict objects are created for every entry so the variants
        # never share mutable state.
        'model_names_labels_and_prepro_kwargs': [
            (
                'VarNamingGSCVocabGGNN',
                label,
                edge_types,
                {'max_name_encoding_length': 30, 'add_edges': False},
                {'max_nodes_per_graph': 500},
            )
            for label, edge_types in label_and_edge_types
        ],
        'skip_make_tasks': False,
    }

    remote_id = aws_config['remote_ids']['box1']
    run_command_on_remote(remote_id, make_tasks_and_preprocess, kwargs)
Exemplo n.º 4
0
                 'evaluate_subtokenwise_accuracy',
                 'evaluate_edit_distance',
                 'evaluate_length_weighted_edit_distance',
                 'evaluate_top_5_full_name_accuracy',
             ),
             model_params_to_load='model_checkpoint_epoch_23.params',
             skip_s3_sync=skip_s3_sync,
             test=test),
        dict(seed=5145,
             gpu_ids=(0, 1, 2, 3),
             dataset_name='18_popular_mavens',
             experiment_name='VarNaming_subtoken_edge_ablation',
             experiment_run_log_id=experiment_run_log_id,
             model_name='VarNamingGSCVocabGGNN',
             model_label='syntax_edge_except_subtoken',
             n_workers=8,
             n_batch=250 * 4,
             evaluation_metrics=(
                 'evaluate_full_name_accuracy',
                 'evaluate_subtokenwise_accuracy',
                 'evaluate_edit_distance',
                 'evaluate_length_weighted_edit_distance',
                 'evaluate_top_5_full_name_accuracy',
             ),
             model_params_to_load='model_checkpoint_epoch_36.params',
             skip_s3_sync=skip_s3_sync,
             test=test),
    ])
    run_command_on_remote(aws_config['remote_ids']['box1'],
                          evaluate_models_for_experiment, list_of_kwargs)