Example #1
def chiron_train():
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    default_config = os.path.join(FLAGS.log_dir, FLAGS.model_name,
                                  'model.json')

    if FLAGS.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model Json file has not been found in model log directory")
    else:
        config_file = FLAGS.configure

    config = model.read_config(config_file)

    logits, ratio = model.inference(x,
                                    seq_length,
                                    training,
                                    FLAGS.sequence_len,
                                    configure=config)
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(FLAGS.step_rate,
                          FLAGS.max_steps,
                          global_step=global_step,
                          opt_name=config['opt_method'])
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

    if not FLAGS.retrain:
        sess.run(init)
        print("Model init finished, begin loading data. \n")
    else:
        saver.restore(
            sess, tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
        print("Model loaded finished, begin loading data. \n")
    summary_writer = tf.summary.FileWriter(
        FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)
    model.save_model(default_config, config)

    train_ds, valid_ds = generate_train_valid_datasets()
    start = time.time()

    for i in range(FLAGS.max_steps):
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            x: batch_x,
            seq_length: seq_len / ratio,
            y_indexs: indxs,
            y_values: values,
            y_shape: shape,
            training: True
        }
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            valid_x, valid_len, valid_y = valid_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                x: valid_x,
                seq_length: valid_len / ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True
            }
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
            "Step %d/%d Epoch %d, batch number %d, train_loss: %5.3f validate_edit_distance: %5.3f Elapsed Time/step: %5.3f" \
            % (i, FLAGS.max_steps, train_ds.epochs_completed,
               train_ds.index_in_epoch, loss_val, error_val,
               (end - start) / (i + 1)))
            saver.save(sess,
                       FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % (FLAGS.log_dir + FLAGS.model_name))
    print("Reads number %d" % (train_ds.reads_n))
    saver.save(sess,
               FLAGS.log_dir + FLAGS.model_name + '/final.ckpt',
               global_step=global_step_val)
Example #2
def train(hparams):
    """Main training function.
    This will train a Neural Network with the given dataset.

    Args:
        hparams: hyper parameter for training the neural network
            data-dir: String, the path of the data(binary batch files) directory.
            log-dir: String, the path to save the trained model.
            sequence-len: Int, length of input signal.
            batch-size: Int.
            step-rate: Float, step rate of the optimizer.
            max-steps: Int, max training steps.
            kmer: Int, size of the dna kmer.
            model-name: String, model will be saved at log-dir/model-name.
            retrain: Boolean, if True, the model will be reload from log-dir/model-name.

    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        training = tf.placeholder(tf.bool)
        global_step = tf.get_variable('global_step',
                                      trainable=False,
                                      shape=(),
                                      dtype=tf.int32,
                                      initializer=tf.zeros_initializer())

        opt = model.train_opt(hparams.step_rate,
                              hparams.max_steps,
                              global_step=global_step)
        x, seq_length, train_labels = inputs(hparams.data_dir,
                                             int(hparams.batch_size *
                                                 hparams.ngpus),
                                             for_valid=False)
        split_y = tf.split(train_labels, hparams.ngpus, axis=0)
        split_seq_length = tf.split(seq_length, hparams.ngpus, axis=0)
        split_x = tf.split(x, hparams.ngpus, axis=0)
        tower_grads = []
        default_config = os.path.join(hparams.log_dir, hparams.model_name,
                                      'model.json')
        if hparams.retrain:
            if os.path.isfile(default_config):
                config_file = default_config
            else:
                raise ValueError(
                    "Model Json file has not been found in model log directory"
                )
        else:
            config_file = hparams.configure
        config = model.read_config(config_file)
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(hparams.ngpus):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('%s_%d' % ('gpu_tower', i)) as scope:
                        loss, error = tower_loss(
                            scope,
                            split_x[i],
                            split_seq_length[i],
                            split_y[i],
                            full_seq_len=hparams.sequence_len,
                            config=config)
                        tf.get_variable_scope().reuse_variables()
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                      scope)
                        grads = opt.compute_gradients(loss)
                        tower_grads.append(grads)
        grads = average_gradients(tower_grads)
        for grad, var in grads:
            if grad is not None:
                summaries.append(
                    tf.summary.histogram(var.op.name + '/gradients', grad))
        for var in tf.trainable_variables():
            summaries.append(tf.summary.histogram(var.op.name, var))
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        var_averages = tf.train.ExponentialMovingAverage(
            decay=model.MOVING_AVERAGE_DECAY)
        var_averages_op = var_averages.apply(tf.trainable_variables())
        train_op = tf.group(apply_gradient_op, var_averages_op)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
        summary = tf.summary.merge_all()

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))
        model.save_model(default_config, config)
        if not hparams.retrain:
            sess.run(init)
            print("Model init finished, begin training. \n")
        else:
            saver.restore(
                sess,
                tf.train.latest_checkpoint(hparams.log_dir +
                                           hparams.model_name))
            print("Model loaded finished, begin training. \n")
        summary_writer = tf.summary.FileWriter(
            hparams.log_dir + hparams.model_name + '/summary/', sess.graph)
        _ = tf.train.start_queue_runners(sess=sess)

        start = time.time()
        for i in range(hparams.max_steps):
            feed_dict = {training: True}
            loss_val, _ = sess.run([loss, train_op], feed_dict=feed_dict)
            if i % 10 == 0:
                global_step_val = tf.train.global_step(sess, global_step)
                feed_dict = {training: True}
                error_val = sess.run(error, feed_dict=feed_dict)
                end = time.time()
                print(
                    "Step %d/%d ,  loss: %5.3f edit_distance: %5.3f Elapsed Time/batch: %5.3f" \
                    % (i, hparams.max_steps, loss_val, error_val,
                    (end - start) / (i + 1)))
                saver.save(sess,
                           hparams.log_dir + hparams.model_name +
                           '/model.ckpt',
                           global_step=global_step_val)
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str,
                                           global_step=global_step_val)
                summary_writer.flush()
        global_step_val = tf.train.global_step(sess, global_step)
        print("Model %s saved." % (hparams.log_dir + hparams.model_name))
        saver.save(sess,
                   hparams.log_dir + hparams.model_name + '/final.ckpt',
                   global_step=global_step_val)
Example #3
def train():
    default_config = os.path.join(FLAGS.log_dir, FLAGS.model_name,
                                  'model.json')
    if FLAGS.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model Json file has not been found in model log directory")
    else:
        config_file = FLAGS.configure
    config = model.read_config(config_file)
    print("Begin training using following setting:")
    with open(os.path.join(FLAGS.log_dir, FLAGS.model_name, 'train_config'),
              'w+') as log_f:
        for pro in dir(FLAGS):
            if not pro.startswith('_'):
                print("%s:%s" % (pro, getattr(FLAGS, pro)))
                log_f.write("%s:%s\n" % (pro, getattr(FLAGS, pro)))
    net = compile_train_graph(config, FLAGS)
    sess = tf.Session(
        config=tf.ConfigProto(inter_op_parallelism_threads=FLAGS.threads,
                              intra_op_parallelism_threads=FLAGS.threads,
                              allow_soft_placement=True))
    if not FLAGS.retrain:
        sess.run(net.init)
        print("Model init finished, begin loading data. \n")
    else:
        net.saver.restore(
            sess, tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
        print("Model loaded finished, begin loading data. \n")
    summary_writer = tf.summary.FileWriter(
        FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)
    model.save_model(default_config, config)
    train_ds, valid_ds = generate_train_valid_datasets(
        initial_offset=DEFAULT_OFFSET)
    start = time.time()
    resample_n = 0
    for i in range(FLAGS.max_steps):
        if (FLAGS.resample_after_epoch > 0 and
                train_ds.epochs_completed >= FLAGS.resample_after_epoch):
            resample_n += 1
            train_ds, valid_ds = generate_train_valid_datasets(
                initial_offset=resample_n * FLAGS.offset_increment +
                DEFAULT_OFFSET)
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            net.x: batch_x,
            net.seq_length: seq_len / net.ratio,
            net.y_indexs: indxs,
            net.y_values: values,
            net.y_shape: shape,
            net.training: True
        }
        loss_val, _ = sess.run([net.ctc_loss, net.step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, net.global_step)
            valid_x, valid_len, valid_y = valid_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                net.x: valid_x,
                net.seq_length: valid_len / net.ratio,
                net.y_indexs: indxs,
                net.y_values: values,
                net.y_shape: shape,
                net.training: True
            }
            error_val = sess.run(net.error, feed_dict=feed_dict)
            #            x_val,errors_val,y_predict,y = sess.run([x,errors,y_,y],feed_dict = feed_dict)
            #            predict_seq,_ = sparse2dense([y_predict,0])
            #            true_seq,_ = sparse2dense([[y],0])
            end = time.time()
            print(
            "Step %d/%d Epoch %d, batch number %d, train_loss: %5.3f validate_edit_distance: %5.3f Elapsed Time/step: %5.3f" \
            % (i, FLAGS.max_steps, train_ds.epochs_completed,
               train_ds.index_in_epoch, loss_val, error_val,
               (end - start) / (i + 1)))
            net.saver.save(sess,
                           FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                           global_step=global_step_val)
            summary_str = sess.run(net.summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, net.global_step)
    print("Model %s saved." % (FLAGS.log_dir + FLAGS.model_name))
    print("Reads number %d" % (train_ds.reads_n))
    net.saver.save(sess,
                   FLAGS.log_dir + FLAGS.model_name + '/final.ckpt',
                   global_step=global_step_val)
Example #4
def train(hparam):
    """Main training function.
    This will train a Neural Network with the given dataset.

    Args:
        hparam: hyper parameter for training the neural network
            data_dir: String, the path of the data(binary batch files) directory.
            log-dir: String, the path to save the trained model.
            sequence-len: Int, length of input signal.
            batch-size: Int.
            step-rate: Float, step rate of the optimizer.
            max-steps: Int, max training steps.
            kmer: Int, size of the dna kmer.
            model-name: String, model will be saved at log-dir/model-name.
            retrain: Boolean, if True, the model will be reload from log-dir/model-name.

    """
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())

    x, seq_length, train_labels = inputs(hparam.data_dir,
                                         hparam.batch_size,
                                         hparam.sequence_len,
                                         for_valid=False)
    y = dense2sparse(train_labels)
    default_config = os.path.join(hparam.log_dir, hparam.model_name,
                                  'model.json')
    if hparam.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model Json file has not been found in model log directory")
    else:
        config_file = hparam.configure
    config = model.read_config(config_file)
    logits, ratio = model.inference(x,
                                    seq_length,
                                    training,
                                    hparam.sequence_len,
                                    configure=config,
                                    apply_ratio=True)
    seq_length = tf.cast(tf.ceil(tf.cast(seq_length, tf.float32) / ratio),
                         tf.int32)
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(hparam.step_rate,
                          hparam.max_steps,
                          global_step=global_step)
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    model.save_model(default_config, config)
    if not hparam.retrain:
        sess.run(init)
        print("Model init finished, begin training. \n")
    else:
        saver.restore(
            sess,
            tf.train.latest_checkpoint(hparam.log_dir + hparam.model_name))
        print("Model loaded finished, begin training. \n")
    summary_writer = tf.summary.FileWriter(
        hparam.log_dir + hparam.model_name + '/summary/', sess.graph)
    _ = tf.train.start_queue_runners(sess=sess)

    start = time.time()
    for i in range(hparam.max_steps):
        feed_dict = {training: True}
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            feed_dict = {training: True}
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
                "Step %d/%d ,  loss: %5.3f edit_distance: %5.3f Elapsed Time/batch: %5.3f" \
                % (i, hparam.max_steps, loss_val, error_val,
                   (end - start) / (i + 1)))
            saver.save(sess,
                       hparam.log_dir + hparam.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % (hparam.log_dir + hparam.model_name))
    saver.save(sess,
               hparam.log_dir + hparam.model_name + '/final.ckpt',
               global_step=global_step_val)