class ExpConfig(Config):
    """Experiment settings, declared as class-level attributes.

    The groups below cover the dataset, the model architecture and its
    hyperparameters, training and evaluation settings, the POT `level`,
    and the output locations.
    """

    # ---- dataset ----
    dataset = "machine-1-1"
    # Derived from the default `dataset` when the class body is executed.
    x_dim = get_data_dim(dataset)

    # ---- model architecture ----
    use_connected_z_q = True
    use_connected_z_p = True

    # ---- model parameters ----
    z_dim = 3
    rnn_cell = 'GRU'  # one of 'GRU', 'LSTM' or 'Basic'
    rnn_num_hidden = 500
    window_length = 100
    dense_dim = 500
    posterior_flow_type = 'nf'  # either 'nf' or None
    nf_layers = 20  # only used for 'nf'
    max_epoch = 10
    train_start = 0
    max_train_size = None  # `None` means the full train set
    batch_size = 50
    l2_reg = 1e-4
    initial_lr = 1e-3
    lr_anneal_factor = 0.5
    lr_anneal_epoch_freq = 40
    lr_anneal_step_freq = 400
    std_epsilon = 0.0001

    # ---- evaluation parameters ----
    test_n_z = 1024
    test_batch_size = 50
    test_start = 0
    max_test_size = None  # `None` means the full test set
    valid_step_freq = 100
    gradient_clip_norm = 10.0
    early_stop = True  # whether to apply the early-stop method

    # ---- POT parameters ----
    # Recommended values for `level`:
    #   SMAP:        0.07
    #   MSL:         0.01
    #   SMD group 1: 0.0050
    #   SMD group 2: 0.0075
    #   SMD group 3: 0.0001
    level = 0.01

    # ---- outputs ----
    save_z = False  # whether to save the z sampled in hidden space
    # Whether to get the score on each dim; if `True`, the score will be a
    # 2-dim ndarray.
    get_score_on_dim = False
    save_dir = 'model'
    restore_dir = None  # if not None, restore variables from this dir
    result_dir = 'result'  # where to save the result file
    train_score_filename = 'train_score.pkl'
    test_score_filename = 'test_score.pkl'
# NOTE(review): the next five statements are the tail of a routine whose
# beginning lies outside this excerpt. The original indentation was lost, so
# confirm the enclosing scope (and re-indent accordingly) before relying on
# this layout.
# Gather the variables of `model_vs` into a dict and hand them to a
# VariableSaver bound to `config.save_dir` -- presumably this persists the
# trained variables to that directory; verify against VariableSaver.
var_dict = get_variables_as_dict(model_vs)
saver = VariableSaver(var_dict, config.save_dir)
saver.save()
# Print a 'result' banner followed by the best validation metrics.
print('=' * 30 + 'result' + '=' * 30)
pprint(best_valid_metrics)


# Script entry point: build the config, apply command-line overrides, prepare
# the result directories, then run main().
if __name__ == '__main__':
    # get config obj
    config = ExpConfig()

    # parse the arguments
    arg_parser = ArgumentParser()
    register_config_arguments(config, arg_parser)
    # The parse result is discarded -- presumably register_config_arguments
    # wires the parser so parsed values are written onto `config`; confirm.
    arg_parser.parse_args(sys.argv[1:])
    # Recompute x_dim from the (possibly overridden) dataset; the class-level
    # x_dim was derived from the default dataset only.
    config.x_dim = get_data_dim(config.dataset)

    print_with_title('Configurations', pformat(config.to_dict()), after='\n')

    # open the result object and prepare for result directories if specified
    results = MLResults(config.result_dir)
    results.save_config(config)  # save experiment settings for review
    results.make_dirs(config.save_dir, exist_ok=True)
    with warnings.catch_warnings():
        # suppress DeprecationWarning from NumPy caused by code in
        # TensorFlow-Probability; the filter is scoped to this `with` block
        warnings.filterwarnings("ignore", category=DeprecationWarning, module='numpy')
        main()