experiment_name='VarNaming_subtoken_edge_ablation', experiment_run_log_id=experiment_run_log_id, model_name='VarNamingFixedVocabGGNN', model_label='all_edge', n_workers=8, n_batch=250 * 4, evaluation_metrics=('evaluate_full_name_accuracy', 'evaluate_subtokenwise_accuracy', 'evaluate_edit_distance'), model_params_to_load='best.params', skip_s3_sync=skip_s3_sync, test=test), dict(seed=5145, gpu_ids=(0, 1, 2, 3), dataset_name='18_popular_mavens', experiment_name='VarNaming_subtoken_edge_ablation', experiment_run_log_id=experiment_run_log_id, model_name='VarNamingNameGraphVocabGGNN', model_label='syntax_edge', n_workers=8, n_batch=250 * 4, evaluation_metrics=('evaluate_full_name_accuracy', 'evaluate_subtokenwise_accuracy', 'evaluate_edit_distance'), model_params_to_load='best.params', skip_s3_sync=skip_s3_sync, test=test), ]) run_command_on_remote('local', evaluate_models_for_experiment, list_of_kwargs)
debug=False)), (aws_config['remote_ids']['box1'], dict(dataset_name='18_popular_mavens', experiment_name='FITB_vocab_comparison', experiment_run_log_id=experiment_run_log_id, seed=5145, gpu_ids=(0, 1, 2, 3), model_name='FITBGSCVocabGGNN', model_label='syntax_edge', model_kwargs=dict(hidden_size=64, type_emb_size=30, name_emb_size=31, n_msg_pass_iters=8), init_fxn_name='Xavier', init_fxn_kwargs=dict(), loss_fxn_name='FITBLoss', loss_fxn_kwargs=dict(), optimizer_name='Adam', optimizer_kwargs={'learning_rate': .0002}, val_fraction=0.15, n_workers=8, n_epochs=200, evaluation_metrics=('evaluate_FITB_accuracy', ), n_batch=250 * 4, skip_s3_sync=False, debug=False)), ] for instance_id, kwargs in instance_ids_train_kwargs: time.sleep(1) run_command_on_remote(instance_id, train_model_for_experiment, kwargs)
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Launch task creation + preprocessing for the VarNaming subtoken-edge
# ablation: one VarNamingTask over the '18_popular_mavens' dataset, prepared
# for two VarNamingGSCVocabGGNN variants, executed on the 'box1' remote.
from data.AugmentedAST import syntax_only_excluded_edge_types
from experiments import aws_config
from experiments.make_tasks_and_preprocess_for_experiment import make_tasks_and_preprocess
from experiments.run_command_on_remote import run_command_on_remote

if __name__ == '__main__':
    # (model label, edge-type set) for each ablation variant, in run order.
    # NOTE(review): the second element is presumably the set of AST edge types
    # to exclude (empty frozenset = keep every edge type) -- confirm against
    # make_tasks_and_preprocess.
    ablation_variants = (
        ('all_edge_except_subtoken', frozenset()),
        ('syntax_edge_except_subtoken', syntax_only_excluded_edge_types),
    )

    # One 5-tuple per variant: (model name, label, edge-type set,
    # model kwargs, preprocessing kwargs). The comprehension builds fresh dict
    # objects on every iteration, so no dict is shared between variants.
    # NOTE(review): add_edges=False presumably disables the subtoken edges the
    # labels refer to -- verify in the GSC model's preprocessing code.
    model_specs = [
        (
            'VarNamingGSCVocabGGNN',
            label,
            edge_types,
            {'max_name_encoding_length': 30, 'add_edges': False},
            {'max_nodes_per_graph': 500},
        )
        for label, edge_types in ablation_variants
    ]

    preprocess_kwargs = {
        'seed': 515,
        'dataset_name': '18_popular_mavens',
        'experiment_name': 'VarNaming_subtoken_edge_ablation',
        'task_names': ['VarNamingTask'],
        'n_jobs': 30,
        'model_names_labels_and_prepro_kwargs': model_specs,
        'skip_make_tasks': False,
    }

    target_box = aws_config['remote_ids']['box1']
    run_command_on_remote(target_box, make_tasks_and_preprocess, preprocess_kwargs)
'evaluate_subtokenwise_accuracy', 'evaluate_edit_distance', 'evaluate_length_weighted_edit_distance', 'evaluate_top_5_full_name_accuracy', ), model_params_to_load='model_checkpoint_epoch_23.params', skip_s3_sync=skip_s3_sync, test=test), dict(seed=5145, gpu_ids=(0, 1, 2, 3), dataset_name='18_popular_mavens', experiment_name='VarNaming_subtoken_edge_ablation', experiment_run_log_id=experiment_run_log_id, model_name='VarNamingGSCVocabGGNN', model_label='syntax_edge_except_subtoken', n_workers=8, n_batch=250 * 4, evaluation_metrics=( 'evaluate_full_name_accuracy', 'evaluate_subtokenwise_accuracy', 'evaluate_edit_distance', 'evaluate_length_weighted_edit_distance', 'evaluate_top_5_full_name_accuracy', ), model_params_to_load='model_checkpoint_epoch_36.params', skip_s3_sync=skip_s3_sync, test=test), ]) run_command_on_remote(aws_config['remote_ids']['box1'], evaluate_models_for_experiment, list_of_kwargs)