Example #1
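Example #1 wires up training for a SeqModel: main collects the command-line FLAGS into an HParams object, creates the output directory, loads the in-distribution training/validation and OOD validation datasets, restores the latest checkpoint if one exists, and then trains the model.
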
def main(_):

    tf.logging.set_verbosity(tf.logging.INFO)
    random.seed(FLAGS.random_seed)

    params = contrib_training.HParams(
        num_steps=FLAGS.num_steps,
        val_freq=FLAGS.val_freq,
        seq_len=FLAGS.seq_len,
        batch_size=FLAGS.batch_size,
        emb_variable=FLAGS.emb_variable,
        emb_size=FLAGS.emb_size,
        vocab_size=4,
        hidden_lstm_size=FLAGS.hidden_lstm_size,
        norm_lstm=FLAGS.norm_lstm,
        dropout_rate=FLAGS.dropout_rate,
        learning_rate=FLAGS.learning_rate,
        reg_type=FLAGS.reg_type,
        reg_weight=FLAGS.reg_weight,
        out_dir=FLAGS.out_dir,
        in_tr_data_dir=FLAGS.in_tr_data_dir,
        in_val_data_dir=FLAGS.in_val_data_dir,
        ood_val_data_dir=FLAGS.ood_val_data_dir,
        master=FLAGS.master,
        save_meta=FLAGS.save_meta,
        filter_label=FLAGS.filter_label,
        mutation_rate=FLAGS.mutation_rate,
    )

    # setup output directory
    create_out_dir(params)

    # load datasets
    params.add_hparam('in_tr_file_pattern', 'in_tr')
    params.add_hparam('in_val_file_pattern', 'in_val')
    params.add_hparam('ood_val_file_pattern', 'ood_val')
    (in_tr_dataset, in_val_dataset, ood_val_dataset) = load_datasets(params)

    # print parameter settings
    tf.logging.info(params)
    with tf.gfile.GFile(os.path.join(params.model_dir, 'params.json'),
                        mode='w') as f:
        # to_json() already returns a JSON string; wrapping it in json.dumps
        # would double-encode it
        f.write(params.to_json(sort_keys=True))

    # construct model
    model = SeqModel(params)
    model.reset()

    # if a previous model ckpt exists, restore the model from there
    tf.logging.info('model dir=%s',
                    os.path.join(params.model_dir, '*.ckpt.index'))
    prev_steps, ckpt_file = utils.get_latest_ckpt(params.model_dir)
    if ckpt_file:
        tf.logging.info('previous ckpt exists, prev_steps=%s', prev_steps)
        model.restore_from_ckpt(ckpt_file)

    # training
    model.train(in_tr_dataset, in_val_dataset, ood_val_dataset, prev_steps)
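
Both examples resume training through utils.get_latest_ckpt, whose implementation is not shown here. A minimal sketch of what it might look like, assuming checkpoint files inside model_dir embed the global step in their names (the filename pattern and the helper's exact behavior are assumptions, not the repository's actual code):

import os
import re

import tensorflow.compat.v1 as tf  # the examples use the TF1-style API


def get_latest_ckpt(model_dir):
    """Returns (prev_steps, ckpt_file) for the newest checkpoint, or (0, None)."""
    index_files = tf.gfile.Glob(os.path.join(model_dir, '*.ckpt.index'))
    if not index_files:
        return 0, None

    def step_of(path):
        # assume names like 'model_12000.ckpt.index', where 12000 is the step
        match = re.search(r'(\d+)\.ckpt\.index$', os.path.basename(path))
        return int(match.group(1)) if match else -1

    latest = max(index_files, key=step_of)
    # drop the '.index' suffix to get the prefix expected by tf.train.Saver
    return step_of(latest), latest[:-len('.index')]
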
Example #2
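Example #2 does the same for a supervised SeqPredModel: in addition to the setup above, it loads the class-label dictionary, computes per-class weights from the training sample sizes, and adds both to the hyperparameters before training.
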
def main(_):

    tf.logging.set_verbosity(tf.logging.INFO)
    random.seed(FLAGS.random_seed)

    params = contrib_training.HParams(
        embedding=FLAGS.embedding,
        num_steps=FLAGS.num_steps,
        val_freq=FLAGS.val_freq,
        seq_len=FLAGS.seq_len,
        batch_size=FLAGS.batch_size,
        emb_size=FLAGS.emb_size,
        vocab_size=4,
        hidden_lstm_size=FLAGS.hidden_lstm_size,
        hidden_dense_size=FLAGS.hidden_dense_size,
        dropout_rate=FLAGS.dropout_rate,
        learning_rate=FLAGS.learning_rate,
        num_motifs=FLAGS.num_motifs,
        len_motifs=FLAGS.len_motifs,
        temperature=FLAGS.temperature,
        reweight_sample=FLAGS.reweight_sample,
        l2_reg=FLAGS.l2_reg,
        out_dir=FLAGS.out_dir,
        in_tr_data_dir=FLAGS.in_tr_data_dir,
        in_val_data_dir=FLAGS.in_val_data_dir,
        ood_val_data_dir=FLAGS.ood_val_data_dir,
        master=FLAGS.master,
        save_meta=FLAGS.save_meta,
        label_dict_file=FLAGS.label_dict_file,
        mutation_rate=FLAGS.mutation_rate,
        epsilon=FLAGS.epsilon,
    )

    # create output directories
    create_out_dir(params)

    # load datasets and labels for training
    params.add_hparam('in_tr_file_pattern', 'in_tr')
    params.add_hparam('in_val_file_pattern', 'in_val')
    params.add_hparam('ood_val_file_pattern', 'ood_val')
    label_sample_size, in_tr_dataset, in_val_dataset, ood_val_dataset = load_datasets_and_labels(
        params)
    params.add_hparam('n_class', len(label_sample_size))
    tf.logging.info('label_sample_size=%s', label_sample_size)

    # compute weights for the labels
    # load the dictionary of class labels
    # key: class name (string), value: encoded class label (int)
    with tf.gfile.GFile(params.label_dict_file, 'rb') as f_label_code:
        # yaml.safe_load also parses JSON, so the label file can be either format
        params.add_hparam('label_dict', yaml.safe_load(f_label_code))
        tf.logging.info('# of label_dict=%s', len(params.label_dict))

    label_weights = utils.compute_label_weights_using_sample_size(
        params.label_dict, label_sample_size)
    params.add_hparam('label_weights', label_weights)

    # print parameter settings
    tf.logging.info(params)
    with tf.gfile.GFile(os.path.join(params.model_dir, 'params.json'),
                        mode='w') as f:
        # to_json() already returns a JSON string; wrapping it in json.dumps
        # would double-encode it
        f.write(params.to_json(sort_keys=True))

    # construct model
    tf.logging.info('create model')
    model = SeqPredModel(params)
    model.reset()

    # if a previous model ckpt exists, restore the model from there
    tf.logging.info('model dir=%s',
                    os.path.join(params.model_dir, '*.ckpt.index'))
    prev_steps, ckpt_file = utils.get_latest_ckpt(params.model_dir)
    if ckpt_file:
        tf.logging.info('previous ckpt exists, prev_steps=%s', prev_steps)
        model.restore_from_ckpt(ckpt_file)

    # training
    tf.logging.info('start training')
    model.train(in_tr_dataset, in_val_dataset, ood_val_dataset, prev_steps)
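
Example #2 turns class imbalance into per-class loss weights with utils.compute_label_weights_using_sample_size. That helper is not shown either; a plausible sketch, assuming label_sample_size maps class names to sample counts and that weights are inverse frequencies normalized to a mean of 1 (both assumptions about the actual implementation):

import numpy as np


def compute_label_weights_using_sample_size(label_dict, label_sample_size):
    """Weights each class inversely to its sample count; mean weight is 1."""
    n_class = len(label_dict)
    counts = np.ones(n_class, dtype=np.float64)
    for name, label in label_dict.items():
        # label_dict maps class name (string) to encoded class label (int)
        counts[label] = max(label_sample_size.get(name, 1), 1)
    weights = 1.0 / counts
    return weights * n_class / weights.sum()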