# save the variables var_dict = get_variables_as_dict(model_vs) saver = VariableSaver(var_dict, config.save_dir) saver.save() print('=' * 30 + 'result' + '=' * 30) pprint(best_valid_metrics) if __name__ == '__main__': # get config obj config = ExpConfig() # parse the arguments arg_parser = ArgumentParser() register_config_arguments(config, arg_parser) arg_parser.parse_args(sys.argv[1:]) config.x_dim = get_data_dim(config.dataset) print_with_title('Configurations', pformat(config.to_dict()), after='\n') # open the result object and prepare for result directories if specified results = MLResults(config.result_dir) results.save_config(config) # save experiment settings for review results.make_dirs(config.save_dir, exist_ok=True) with warnings.catch_warnings(): # suppress DeprecationWarning from NumPy caused by codes in TensorFlow-Probability warnings.filterwarnings("ignore", category=DeprecationWarning, module='numpy') main()
def main():
    """Build the transfer config from the command line and run the requested stages.

    The options of an ``ExpConfig_transfer`` instance are registered on an
    ``ArgumentParser``, and ``sys.argv[1:]`` is parsed into it.
    ``config.executable_action`` is then read as a comma-separated list of
    stage ids. All spaces are stripped first, so ``"1, 3,5"`` selects stages
    1, 3 and 5.

    Whatever order the ids are given in, the selected stages run in the fixed
    order below:

    1. ``train_part1`` over ``train_start_time`` .. ``train_last_time``.
    2. ``get_all_machines_z`` over ``train_z_start_time`` .. ``train_z_last_time``.
    3. ``train_part2`` (uses no time window).
    4. ``train_part3`` over ``train_start_time`` .. ``train_last_time``.
    5. ``train_part4`` over the test window, with the ``train_z`` window
       passed as the historical window.

    Every stage receives the same dataset: the string ids ``"0"`` .. ``"532"``.
    Unrecognised stage ids are silently ignored.
    """
    cfg = ExpConfig_transfer()

    # Expose the config's options as CLI flags, then parse the real argv into it.
    parser = ArgumentParser()
    register_config_arguments(cfg, parser)
    parser.parse_args(sys.argv[1:])

    # Only membership tests are done on these ids, so a set behaves the same as the list.
    requested = set(cfg.executable_action.replace(" ", "").split(','))

    # NOTE(review): 533 is a hard-coded machine count; confirm it matches the dataset in use.
    machine_ids = [str(idx) for idx in range(533)]

    # Stage 1
    if '1' in requested:
        train_part1(
            machine_ids,
            start_time=datetime_to_timestamp(cfg.train_start_time),
            last_time=datetime_to_timestamp(cfg.train_last_time),
            action1_model_dir=cfg.action1_model_dir,
            action1_result_dir=cfg.action1_result_dir,
            action1_GPU_device_number=cfg.action1_GPU_device_number,
            action1_sample_ratio=cfg.action1_sample_ratio,
            action1_machine_sample_ratio=cfg.action1_machine_sample_ratio,
        )

    # Stage 2
    if '2' in requested:
        get_all_machines_z(
            machine_ids,
            start_time=datetime_to_timestamp(cfg.train_z_start_time),
            last_time=datetime_to_timestamp(cfg.train_z_last_time),
            action1_model_dir=cfg.action1_model_dir,
            action1_result_dir=cfg.action1_result_dir,
            action2_run_parallel_number=cfg.action2_run_parallel_number,
        )

    # Stage 3
    if '3' in requested:
        train_part2(
            machine_ids,
            action1_result_dir=cfg.action1_result_dir,
            action3_z_file_dir=cfg.action3_z_file_dir,
            action3_z_distance_matrix_name=cfg.action3_z_distance_matrix_name,
            action3_cluster_number=cfg.action3_cluster_number,
            action3_cluster_png_filename=cfg.action3_cluster_png_filename,
            action3_cluster_result_filename=cfg.action3_cluster_result_filename,
            action3_machine_file_name=cfg.action3_machine_file_name,
        )

    # Stage 4
    if '4' in requested:
        train_part3(
            machine_ids,
            start_time=datetime_to_timestamp(cfg.train_start_time),
            last_time=datetime_to_timestamp(cfg.train_last_time),
            action3_z_file_dir=cfg.action3_z_file_dir,
            action3_cluster_result_filename=cfg.action3_cluster_result_filename,
            action3_machine_file_name=cfg.action3_machine_file_name,
            action4_run_parallel_number=cfg.action4_run_parallel_number,
            action4_model_dir_prefix=cfg.action4_model_dir_prefix,
            action4_result_dir_prefix=cfg.action4_result_dir_prefix,
            action1_model_dir=cfg.action1_model_dir,
            action4_cluster_max_machine=cfg.action4_cluster_max_machine,
        )

    # Stage 5
    if '5' in requested:
        # Explicit timestamps on the config take precedence. Otherwise the
        # datetime fields are converted.
        # NOTE(review): when test_start_timestamp is set, test_last_timestamp
        # is used as-is, even if it is None; confirm configs always set both.
        if cfg.test_start_timestamp is not None:
            test_start = cfg.test_start_timestamp
            test_last = cfg.test_last_timestamp
        else:
            test_start = datetime_to_timestamp(cfg.test_start_time)
            test_last = datetime_to_timestamp(cfg.test_last_time)

        train_part4(
            machine_ids,
            historical_start_time=datetime_to_timestamp(cfg.train_z_start_time),
            historical_last_time=datetime_to_timestamp(cfg.train_z_last_time),
            start_time=test_start,
            last_time=test_last,
            action3_z_file_dir=cfg.action3_z_file_dir,
            action3_cluster_result_filename=cfg.action3_cluster_result_filename,
            action3_machine_file_name=cfg.action3_machine_file_name,
            action5_run_parallel_number=cfg.action5_run_parallel_number,
            action4_model_dir_prefix=cfg.action4_model_dir_prefix,
            action4_save_path=cfg.action4_save_path,
            action4_result_dir_prefix=cfg.action4_result_dir_prefix,
            get_historical_data_info_flag=cfg.action5_get_historical_data_info_flag,
            get_threshold_flag=cfg.action5_get_threshold_flag,
        )