def train_process(tower_mems, tower_new_mems, train_op, train_hooks, ckpt_dir):
    # Training loop
    per_core_bsz = FLAGS.batch_size // FLAGS.num_core_per_host

    tower_mems_np = [[
        np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model],
                 dtype=np.float32) for _ in range(FLAGS.n_layer)
    ] for _ in range(FLAGS.num_core_per_host)]

    logger.info("Start train transformer-xl lm for dataset:{}".format(
        FLAGS.dataset))

    with MonitoredSession(session_creator=ChiefSessionCreator(
            config=tf.ConfigProto(allow_soft_placement=True),
            checkpoint_dir=ckpt_dir),
                          hooks=train_hooks) as sess:

        fetches = [tower_new_mems, train_op]
        feed_dict = {}
        while not sess.should_stop():
            # Segment-level recurrence with state reuse: feed the previous
            # step's memories back in, then fetch the updated memories.
            for i in range(FLAGS.num_core_per_host):
                for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
                    feed_dict[m] = m_np

            fetched = sess.run(fetches, feed_dict=feed_dict)
            tower_mems_np, _ = fetched
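The loop above carries the Transformer-XL memories between segments as NumPy arrays, feeding the previous step's memories back in and fetching the updated ones. Stripped of the model specifics, the MonitoredSession pattern is just a should_stop() loop; a minimal self-contained sketch with a toy train op (assuming TensorFlow 1.x):

import tensorflow as tf

step = tf.train.get_or_create_global_step()
counter = tf.Variable(0.0)
# Toy "train op": bump a counter and advance the global step.
train_op = tf.group(counter.assign_add(1.0), step.assign_add(1))

with tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(),
        hooks=[tf.train.StopAtStepHook(last_step=10)]) as sess:
    while not sess.should_stop():
        sess.run(train_op)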
Example #2
    def _init_sess(self):
        with self.graph.as_default():
            # Initialize variables
            init_op = tf.global_variables_initializer()
            tf.add_to_collection(tf.GraphKeys.INIT_OP, init_op)
            local_init_op = tf.local_variables_initializer()
            tf.add_to_collection(tf.GraphKeys.LOCAL_INIT_OP, local_init_op)
            self._uninit_vars = tf.report_uninitialized_variables()

            # Retrieve summary writer and create MonitoredSession
            # https://github.com/tensorflow/tensorflow/issues/11350
            # https://github.com/tensorflow/tensorflow/blob/
            #   a7e225350abeed719f634ef71cd9d908424877b2/tensorflow/python/
            #   training/basic_session_run_hooks.py#L337
            # TODO Use tf.contrib.summary
            self.summary_writer = tf.summary.FileWriter(self.cfg.save_path)
            tf_config = tf.ConfigProto(allow_soft_placement=True)
            sess_creator = ChiefSessionCreator(
                config=tf_config,
                checkpoint_dir=self.cfg.restore_path)
            self._hooks = self.get_hooks()
            sess_gen = MonitoredSession(session_creator=sess_creator,
                                        hooks=self._hooks)

            return sess_gen
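ChiefSessionCreator builds a default Scaffold when none is passed, and Scaffold.finalize() picks its init ops out of the GraphKeys.INIT_OP and LOCAL_INIT_OP collections populated above; that is why the returned session needs no explicit initialization run. A minimal sketch of that handshake (hypothetical variable, TensorFlow 1.x):

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    v = tf.Variable(3.0)
    # Scaffold.finalize() looks this collection up for its init op.
    tf.add_to_collection(tf.GraphKeys.INIT_OP, tf.global_variables_initializer())
    creator = tf.train.ChiefSessionCreator()
    with tf.train.MonitoredSession(session_creator=creator) as sess:
        print(sess.run(v))  # 3.0: initialized via the INIT_OP collection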
Example #3
def main():
    # A local file pattern; an hdfs:// URI pattern works here as well.
    filelist = tf.train.match_filenames_once(["part*"])
    filename_queue = tf.train.string_input_producer(filelist)

    reader = user_ops.SmStandardKvReader("[dat]", "[common]", trim=True)
    file_name, record = reader.read(filename_queue)

    input_schema = "wd_input_schema.json"
    parse_schema = "wd_parse_schema.json"

    standard_kv_parser = lib_parser.StandardKvParser([record], input_schema,
                                                     parse_schema)
    tensor_dict = standard_kv_parser.get_tensor_dict()

    i = 0
    init_op = [tf.local_variables_initializer(),
               tf.global_variables_initializer()]
    with MonitoredSession() as sess:
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            begin_time = time.time()
            while not coord.should_stop():
                if i % 100 == 0:
                    print(i)
                sess.run(tensor_dict)
                i += 1
        except tf.errors.OutOfRangeError:
            pass  # input files exhausted
        finally:
            # These calls were unreachable in the original: the loop only
            # exits via an exception, so stop and join the threads here.
            coord.request_stop()
            coord.join(threads)
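MonitoredSession launches any registered queue runners with an internal coordinator at creation time, and the default Scaffold already runs the global and local initializers, so most of the manual plumbing above is redundant. A reduced sketch of the same loop, reusing the tensor_dict defined above:

with MonitoredSession() as sess:
    i = 0
    try:
        while not sess.should_stop():
            if i % 100 == 0:
                print(i)
            sess.run(tensor_dict)
            i += 1
    except tf.errors.OutOfRangeError:
        pass  # input files exhausted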
Example #4
def main():
    filelist = tf.train.match_filenames_once(["mnist_train.tfrecord"])
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer(filelist)
    # TFRecordReader.read returns (key, serialized_example), not decoded
    # (image, label) tensors; decoding would need tf.parse_single_example.
    key, serialized_example = reader.read(filename_queue)
    init_op = [
        tf.local_variables_initializer(),
        tf.global_variables_initializer()
    ]
    with MonitoredSession() as sess:
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        i = 0
        try:
            while not coord.should_stop():
                i += 1
                sess.run(serialized_example)
        except tf.errors.OutOfRangeError:
            pass  # input file exhausted
        finally:
            coord.request_stop()
            coord.join(threads)
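Queue runners were deprecated late in the TF 1.x line in favor of tf.data; a sketch of the same record-draining loop with a dataset (assuming the same local mnist_train.tfrecord file):

import tensorflow as tf

dataset = tf.data.TFRecordDataset("mnist_train.tfrecord")
next_record = dataset.make_one_shot_iterator().get_next()

with tf.train.MonitoredSession() as sess:
    try:
        while not sess.should_stop():
            sess.run(next_record)
    except tf.errors.OutOfRangeError:
        pass  # end of file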
Example #5
def interpolate():
    properties = get_properties(FLAGS)
    logdir = setup_logdir(FLAGS, properties)
    noise = tf.placeholder(dtype=tf.float32,
                           shape=[FLAGS.batch_size, FLAGS.z_dim])
    model = get_model(FLAGS, properties, logdir, noise)
    generated_seqs = get_generated_seqs(model)
    session_creator = ChiefSessionCreator(
        master='',
        checkpoint_filename_with_path=tf.train.latest_checkpoint(logdir))
    seqs = []
    with MonitoredSession(session_creator=session_creator,
                          hooks=None) as session:
        noise1 = np.random.uniform(-1, 1, FLAGS.z_dim)
        noise2 = np.random.uniform(-1, 1, FLAGS.z_dim)
        n = np.stack([
            slerp(ratio, noise1, noise2)
            for ratio in np.linspace(0, 1, FLAGS.batch_size)
        ])
        results, d_scores = session.run(
            [generated_seqs, model.discriminator_fake], feed_dict={noise: n})
        for i in range(FLAGS.batch_size):
            seqs.append(Sequence(id=i, seq=results[i], d_score=d_scores[i]))
        print(
            sequences_to_fasta(seqs,
                               properties['class_mapping'],
                               escape=False,
                               strip_zeros=True))
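The slerp used above is spherical linear interpolation, which keeps the intermediate noise vectors at a comparable norm instead of pulling them toward the origin as plain linear interpolation would. A standard NumPy version, shown here only as an assumption about what the project's own helper does:

import numpy as np

def slerp(ratio, v0, v1):
    # Spherical linear interpolation between vectors v0 and v1.
    omega = np.arccos(np.clip(
        np.dot(v0 / np.linalg.norm(v0), v1 / np.linalg.norm(v1)), -1.0, 1.0))
    if np.isclose(omega, 0.0):
        # Nearly parallel vectors: fall back to plain linear interpolation.
        return (1.0 - ratio) * v0 + ratio * v1
    return (np.sin((1.0 - ratio) * omega) * v0
            + np.sin(ratio * omega) * v1) / np.sin(omega)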
Example #6
def get_discriminator_results():
    properties = get_properties(FLAGS)
    logdir = setup_logdir(FLAGS, properties)
    noise = tf.placeholder(dtype=tf.float32,
                           shape=[FLAGS.batch_size, FLAGS.z_dim])
    model = get_model(FLAGS, properties, logdir, noise)
    s1 = [FLAGS.batch_size, properties[SEQ_LENGTH]]
    input = tf.placeholder(dtype=tf.int32, shape=s1)
    data = tf.expand_dims(tf.transpose(tf.one_hot(input, FLAGS.n_seqs, axis=1),
                                       [0, 2, 1]),
                          axis=1)
    s2 = [FLAGS.batch_size]
    labels = tf.placeholder(dtype=tf.float32, shape=[FLAGS.batch_size])
    with tf.variable_scope('model', reuse=tf.AUTO_REUSE):
        d, d_h = model.get_discriminator_result(data, labels, reuse=True)

    fasta_seqs = fasta_to_numpy(FLAGS.fasta_path, properties[SEQ_LENGTH])
    session_creator = ChiefSessionCreator(
        master='',
        checkpoint_filename_with_path=tf.train.latest_checkpoint(logdir))
    seqs = []
    with MonitoredSession(session_creator=session_creator,
                          hooks=None) as session:
        for i in range(0, len(fasta_seqs), FLAGS.batch_size):
            print("Processing batch ", i)
            batch = fasta_seqs[i:i + FLAGS.batch_size]
            l = len(batch)
            if l < FLAGS.batch_size:
                batch = np.vstack([
                    batch,
                    np.zeros([FLAGS.batch_size - l, properties[SEQ_LENGTH]])
                ])
            d_scores, step = session.run([d, tf.train.get_global_step()],
                                         feed_dict={
                                             input: batch,
                                             labels: np.zeros(s2)
                                         })
            for j in range(l):
                seqs.append(
                    Sequence(id=j + i,
                             seq=fasta_seqs[j + i],
                             d_score=d_scores[j]))
        fasta = sequences_to_fasta(seqs,
                                   properties['class_mapping'],
                                   escape=False,
                                   strip_zeros=True)
        time_stamp = time.strftime('%H_%M_%S', time.gmtime())
        original_name = os.path.splitext(os.path.basename(FLAGS.fasta_path))[0]
        path = os.path.join(
            logdir, '{}_scores_{}_{}.fasta'.format(original_name, step,
                                                   time_stamp))
        with open(path, 'w') as f:
            print(fasta, file=f)
            tf.logging.info('{} sequences stored in {}'.format(
                len(seqs), path))
Example #7
def generate_sequences():
    properties = get_properties(FLAGS)
    logdir = setup_logdir(FLAGS, properties)
    tf.logging.info('Noise will have standard deviation of {}'.format(
        FLAGS.stddev))
    noise = tf.random.truncated_normal([FLAGS.batch_size, FLAGS.z_dim],
                                       stddev=FLAGS.stddev,
                                       dtype=tf.float32)
    model = get_model(FLAGS, properties, logdir, noise)
    if FLAGS.one_hot:
        generated_seqs = tf.squeeze(tf.argmax(model.fake_x, axis=-1))
    else:
        generated_seqs = convert_to_acid_ids(model.fake_x)
    seqs = []
    session_creator = ChiefSessionCreator(
        master='',
        checkpoint_filename_with_path=tf.train.latest_checkpoint(logdir))
    with MonitoredSession(session_creator=session_creator,
                          hooks=None) as session:
        while True:
            results, step = session.run(
                [generated_seqs, tf.train.get_global_step()], None)
            id = len(seqs)
            for i in range(FLAGS.batch_size):
                seqs.append(Sequence(id=id + i, seq=results[i]))
            if len(seqs) >= FLAGS.n_seqs:
                break
    time_stamp = time.strftime('%H_%M_%S', time.gmtime())
    path = os.path.join(logdir,
                        'generated_{}_{}.fasta'.format(step, time_stamp))
    fasta = sequences_to_fasta(seqs,
                               properties['class_mapping'],
                               escape=False,
                               strip_zeros=True)
    if FLAGS.blast:
        db_path = os.path.join(
            FLAGS.data_dir, FLAGS.dataset,
            FLAGS.blast_db.replace("\\", os.sep) + "_" + FLAGS.running_mode)
        blast_results, err = get_local_blast_results(logdir, db_path, fasta)
        seqs, evalues, similarities, identity = update_sequences_with_blast_results(
            blast_results, seqs)
        print_stats([("Evalue", evalues), ("BLOMSUM45", similarities),
                     ("Identity", identity)], len(seqs))
        fasta = sequences_to_fasta(seqs,
                                   properties['class_mapping'],
                                   escape=False,
                                   strip_zeros=True)
    with open(path, 'w') as f:
        print(fasta, file=f)
        tf.logging.info('{} sequences stored in {}'.format(len(seqs), path))
    tf.logging.info('Finished evaluation at ' +
                    time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
Example #8
def raw_results():
    properties = get_properties(FLAGS)
    logdir = setup_logdir(FLAGS, properties)
    noise = tf.random.truncated_normal([FLAGS.batch_size, FLAGS.z_dim],
                                       stddev=0.5,
                                       dtype=tf.float32)
    model = get_model(FLAGS, properties, logdir, noise)
    raw_generations = tf.squeeze(model.fake_x)
    session_creator = ChiefSessionCreator(
        master='',
        checkpoint_filename_with_path=tf.train.latest_checkpoint(logdir))
    with MonitoredSession(session_creator=session_creator,
                          hooks=None) as session:
        results, step = session.run(
            [raw_generations, tf.train.get_global_step()], None)
        time_stamp = time.strftime('%H_%M_%S', time.gmtime())
        path = os.path.join(logdir, 'raw_{}_{}.npz'.format(step, time_stamp))
        with open(path, 'wb') as f:
            np.savez(f, results)
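np.savez with a positional argument stores the array under the key 'arr_0', so the dump can be read back like this (the filename is hypothetical):

import numpy as np

with np.load("raw_1000_12_00_00.npz") as data:  # hypothetical step/timestamp
    raw_generations = data["arr_0"]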
if __name__ == "__main__":

    def generator_fn():
        yield ([[1, 2, 3], [1, 1, 0], [1, 1, 1]], [[1, 2, 1], [1, 1, 0],
                                                   [1, 1, 1]])
        yield ([[0, 1, 0], [1, 0, 2], [3, 0, 0]], [[0, 1, 0], [2, 1, 2],
                                                   [0, 3, 1]])

    ds = tf.data.Dataset.from_generator(generator_fn, (tf.int32, tf.int32),
                                        ([3, 3], [3, 3]))
    y_true, y_pred = ds.make_one_shot_iterator().get_next()

    result = correct_rate(y_true, y_pred)

    with MonitoredSession(hooks=[TensorObserveHook()]) as sess:
        while not sess.should_stop():
            try:
                sess.run([result[1]])  # run the update op on each batch
            except OutOfRangeError:
                break

        final_value = sess.run(result[0])
        # Check the final metric value
        assert np.allclose(final_value, 0.5)
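correct_rate evidently follows the tf.metrics convention of returning a (value, update_op) pair: run the update op once per batch, then read the value at the end. The same loop shape works with a stock metric, for example:

import numpy as np
import tensorflow as tf

labels = tf.constant([1, 0, 1, 1])
preds = tf.constant([1, 1, 1, 0])
value, update_op = tf.metrics.accuracy(labels, preds)

with tf.train.MonitoredSession() as sess:
    sess.run(update_op)  # accumulate statistics for this batch
    assert np.allclose(sess.run(value), 0.5)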

Example #10
def MonitoredTrainingSession(master='',
                             is_chief=True,
                             checkpoint_dir=None,
                             scaffold=None,
                             hooks=None,
                             chief_only_hooks=None,
                             save_checkpoint_secs=USE_DEFAULT,
                             save_summaries_steps=USE_DEFAULT,
                             save_summaries_secs=USE_DEFAULT,
                             config=None,
                             stop_grace_period_secs=120,
                             log_step_count_steps=100,
                             save_checkpoint_steps=USE_DEFAULT,
                             summary_dir=None):
    if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT:
        save_summaries_steps = 100
        save_summaries_secs = None
    elif save_summaries_secs == USE_DEFAULT:
        save_summaries_secs = None
    elif save_summaries_steps == USE_DEFAULT:
        save_summaries_steps = None

    if (save_checkpoint_steps == USE_DEFAULT
            and save_checkpoint_secs == USE_DEFAULT):
        save_checkpoint_steps = None
        save_checkpoint_secs = 600
    elif save_checkpoint_secs == USE_DEFAULT:
        save_checkpoint_secs = None
    elif save_checkpoint_steps == USE_DEFAULT:
        save_checkpoint_steps = None

    scaffold = scaffold or Scaffold()

    all_hooks = []
    if is_chief and chief_only_hooks:
        all_hooks.extend(chief_only_hooks)

    session_creator = ChiefSessionCreator(scaffold=scaffold,
                                          checkpoint_dir=checkpoint_dir,
                                          master=master,
                                          config=config)

    summary_dir = summary_dir or checkpoint_dir
    if summary_dir:
        if (save_summaries_steps
                and save_summaries_steps > 0) or (save_summaries_secs
                                                  and save_summaries_secs > 0):
            all_hooks.append(
                basic_session_run_hooks.SummarySaverHook(
                    scaffold=scaffold,
                    save_steps=save_summaries_steps,
                    save_secs=save_summaries_secs,
                    output_dir=summary_dir))

    if checkpoint_dir:
        if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
                save_checkpoint_steps and save_checkpoint_steps > 0):
            all_hooks.append(
                basic_session_run_hooks.CheckpointSaverHook(
                    checkpoint_dir,
                    save_steps=save_checkpoint_steps,
                    save_secs=save_checkpoint_secs,
                    scaffold=scaffold))

    if hooks:
        all_hooks.extend(hooks)

    hvd_info_rank0('all hooks {}'.format(all_hooks))
    return MonitoredSession(session_creator=session_creator,
                            hooks=all_hooks,
                            stop_grace_period_secs=stop_grace_period_secs)
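This is a trimmed-down take on stock tf.train.MonitoredTrainingSession (no StepCounterHook or step-count logging; the hook list is echoed once via hvd_info_rank0 for Horovod runs), and it is called the same way. A hypothetical call site, assuming a train_op and a global step already exist in the default graph:

with MonitoredTrainingSession(checkpoint_dir='/tmp/ckpt',
                              save_checkpoint_secs=600,
                              hooks=[tf.train.StopAtStepHook(last_step=1000)]) as sess:
    while not sess.should_stop():
        sess.run(train_op)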