def main(_):
  """Trains an LSTM sequence model, restoring from a checkpoint if present.

  Builds hyperparameters from FLAGS, prepares the output directory, loads
  the in-distribution training/validation datasets and the OOD validation
  dataset, constructs a SeqModel, restores the latest checkpoint from the
  model directory when one exists, and runs training.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  random.seed(FLAGS.random_seed)

  params = contrib_training.HParams(
      num_steps=FLAGS.num_steps,
      val_freq=FLAGS.val_freq,
      seq_len=FLAGS.seq_len,
      batch_size=FLAGS.batch_size,
      emb_variable=FLAGS.emb_variable,
      emb_size=FLAGS.emb_size,
      vocab_size=4,  # 4-symbol alphabet (presumably DNA bases A/C/G/T) — confirm
      hidden_lstm_size=FLAGS.hidden_lstm_size,
      norm_lstm=FLAGS.norm_lstm,
      dropout_rate=FLAGS.dropout_rate,
      learning_rate=FLAGS.learning_rate,
      reg_type=FLAGS.reg_type,
      reg_weight=FLAGS.reg_weight,
      out_dir=FLAGS.out_dir,
      in_tr_data_dir=FLAGS.in_tr_data_dir,
      in_val_data_dir=FLAGS.in_val_data_dir,
      ood_val_data_dir=FLAGS.ood_val_data_dir,
      master=FLAGS.master,
      save_meta=FLAGS.save_meta,
      filter_label=FLAGS.filter_label,
      mutation_rate=FLAGS.mutation_rate,
  )

  # Set up the output directory (presumably also defines params.model_dir,
  # which is used below — confirm against create_out_dir).
  create_out_dir(params)

  # Load datasets.
  params.add_hparam('in_tr_file_pattern', 'in_tr')
  params.add_hparam('in_val_file_pattern', 'in_val')
  params.add_hparam('ood_val_file_pattern', 'ood_val')
  (in_tr_dataset, in_val_dataset, ood_val_dataset) = load_datasets(params)

  # Log and persist the parameter settings.
  tf.logging.info(params)
  with tf.gfile.GFile(
      os.path.join(params.model_dir, 'params.json'), mode='w') as f:
    # BUG FIX: HParams.to_json() already returns a serialized JSON string;
    # wrapping it in json.dumps() double-encoded it into a quoted, escaped
    # string. Write the JSON directly.
    f.write(params.to_json(sort_keys=True))

  # Construct the model.
  model = SeqModel(params)
  model.reset()

  # If a previous model checkpoint exists, restore the model from there.
  # BUG FIX: log the directory that is actually searched (model_dir, not
  # out_dir), matching the get_latest_ckpt() call below.
  tf.logging.info('model dir=%s',
                  os.path.join(params.model_dir, '*.ckpt.index'))
  prev_steps, ckpt_file = utils.get_latest_ckpt(params.model_dir)
  if ckpt_file:
    tf.logging.info('previous ckpt exist, prev_steps=%s', prev_steps)
    model.restore_from_ckpt(ckpt_file)

  # Training.
  model.train(in_tr_dataset, in_val_dataset, ood_val_dataset, prev_steps)
def main(_):
  """Trains a SeqPredModel classifier, restoring from a checkpoint if present.

  Builds hyperparameters from FLAGS, prepares the output directory, loads
  labeled training/validation datasets plus an OOD validation dataset,
  derives per-class weights from the label dictionary and per-label sample
  sizes, constructs a SeqPredModel, restores the latest checkpoint when one
  exists, and runs training.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  random.seed(FLAGS.random_seed)

  params = contrib_training.HParams(
      embedding=FLAGS.embedding,
      num_steps=FLAGS.num_steps,
      val_freq=FLAGS.val_freq,
      seq_len=FLAGS.seq_len,
      batch_size=FLAGS.batch_size,
      emb_size=FLAGS.emb_size,
      vocab_size=4,  # 4-symbol alphabet (presumably DNA bases A/C/G/T) — confirm
      hidden_lstm_size=FLAGS.hidden_lstm_size,
      hidden_dense_size=FLAGS.hidden_dense_size,
      dropout_rate=FLAGS.dropout_rate,
      learning_rate=FLAGS.learning_rate,
      num_motifs=FLAGS.num_motifs,
      len_motifs=FLAGS.len_motifs,
      temperature=FLAGS.temperature,
      reweight_sample=FLAGS.reweight_sample,
      l2_reg=FLAGS.l2_reg,
      out_dir=FLAGS.out_dir,
      in_tr_data_dir=FLAGS.in_tr_data_dir,
      in_val_data_dir=FLAGS.in_val_data_dir,
      ood_val_data_dir=FLAGS.ood_val_data_dir,
      master=FLAGS.master,
      save_meta=FLAGS.save_meta,
      label_dict_file=FLAGS.label_dict_file,
      mutation_rate=FLAGS.mutation_rate,
      epsilon=FLAGS.epsilon,
  )

  # Create output directories (presumably also defines params.model_dir,
  # which is used below — confirm against create_out_dir).
  create_out_dir(params)

  # Load datasets and labels for training.
  params.add_hparam('in_tr_file_pattern', 'in_tr')
  params.add_hparam('in_val_file_pattern', 'in_val')
  params.add_hparam('ood_val_file_pattern', 'ood_val')
  label_sample_size, in_tr_dataset, in_val_dataset, ood_val_dataset = (
      load_datasets_and_labels(params))
  params.add_hparam('n_class', len(label_sample_size))
  tf.logging.info('label_sample_size=%s', label_sample_size)

  # Compute weights for labels.
  # Load the dictionary for class labels.
  # Key: class name (string), values: encoded class label (int).
  # FIX: os.path.join() with a single argument was a no-op; use the path
  # directly. yaml.safe_load also parses JSON, so this reads a JSON file.
  with tf.gfile.GFile(params.label_dict_file, 'rb') as f_label_code:
    params.add_hparam('label_dict', yaml.safe_load(f_label_code))
  tf.logging.info('# of label_dict=%s', len(params.label_dict))

  label_weights = utils.compute_label_weights_using_sample_size(
      params.label_dict, label_sample_size)
  params.add_hparam('label_weights', label_weights)

  # Log and persist the parameter settings.
  tf.logging.info(params)
  with tf.gfile.GFile(
      os.path.join(params.model_dir, 'params.json'), mode='w') as f:
    # BUG FIX: HParams.to_json() already returns a serialized JSON string;
    # wrapping it in json.dumps() double-encoded it into a quoted, escaped
    # string. Write the JSON directly.
    f.write(params.to_json(sort_keys=True))

  # Construct the model.
  tf.logging.info('create model')
  model = SeqPredModel(params)
  model.reset()

  # If a previous model checkpoint exists, restore the model from there.
  tf.logging.info('model dir=%s',
                  os.path.join(params.model_dir, '*.ckpt.index'))
  prev_steps, ckpt_file = utils.get_latest_ckpt(params.model_dir)
  if ckpt_file:
    tf.logging.info('previous ckpt exist, prev_steps=%s', prev_steps)
    model.restore_from_ckpt(ckpt_file)

  # Training.
  tf.logging.info('start training')  # BUG FIX: typo 'strart' in log message.
  model.train(in_tr_dataset, in_val_dataset, ood_val_dataset, prev_steps)