def tower_loss(scope, x, seqlen, labels, full_seq_len):
    """Calculating the loss on a single GPU.
    
    Args:
        scope (String): prefix string describe the tower name, e.g. 'tower_0'
        x (Float): Tensor of shape [batch_size, max_time], batch of input signal.
        seqlen (Int): Tensor of shape [batch_size], length of sequence in batch.
        labels (Int): Sparse Tensor, true labels.

    Returns:
        Tensor of shape [batch_size] containing the loss for a batch of data.
    """
    logits, _ = model.inference(x,
                                seqlen,
                                training=True,
                                full_sequence_len=full_seq_len)
    sparse_labels = dense2sparse(labels)
    _ = model.loss(logits, seqlen, sparse_labels)
    error = model.prediction(logits, seqlen, sparse_labels)
    losses = tf.get_collection('losses', scope)
    total_loss = tf.add_n(losses, name='total_loss')
    for l in losses + [total_loss]:
        tf.summary.scalar(l.op.name, l)
    tf.summary.scalar(error.op.name, error)
    return total_loss, error
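
# Hedged sketch (an illustration, not part of the original script) of how
# tower_loss is usually driven in the classic TF1 multi-tower pattern.
# `num_gpus` and the helper name build_towers are assumptions; a real script
# would also slice the input batch so each tower sees its own shard.
def build_towers(x, seqlen, labels, full_seq_len, num_gpus=2):
    tower_losses, tower_errors = [], []
    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('tower_%d' % i) as scope:
                    loss, error = tower_loss(scope, x, seqlen, labels,
                                             full_seq_len)
                    # Reuse the model variables for the remaining towers.
                    tf.get_variable_scope().reuse_variables()
                    tower_losses.append(loss)
                    tower_errors.append(error)
    # Average the per-tower scalars into a single loss / error.
    return tf.reduce_mean(tower_losses), tf.reduce_mean(tower_errors)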
def compile_train_graph(config, hp):
    """Build the training graph from `config` and the hyper-parameters `hp`.

    Returns a lightweight namespace whose attributes are the graph endpoints
    (placeholders, loss, train op, prediction, saver and summary).
    """
    class net:
        pass
    net.training = tf.placeholder(tf.bool)
    net.global_step = tf.get_variable('global_step', trainable=False, shape=(),
                                      dtype=tf.int32,
                                      initializer=tf.zeros_initializer())
    net.x = tf.placeholder(tf.float32, shape=[hp.batch_size, hp.sequence_len])
    net.seq_length = tf.placeholder(tf.int32, shape=[hp.batch_size])
    net.y_indexs = tf.placeholder(tf.int64)
    net.y_values = tf.placeholder(tf.int32)
    net.y_shape = tf.placeholder(tf.int64)
    net.y = tf.SparseTensor(net.y_indexs, net.y_values, net.y_shape)
    net.logits, net.ratio = model.inference(net.x, net.seq_length, net.training,
                                            hp.sequence_len, configure=config)
    if 'fl_gamma' in config:
        net.ctc_loss = model.loss(net.logits, net.seq_length, net.y,
                                  fl_gamma=config['fl_gamma'])
    else:
        net.ctc_loss = model.loss(net.logits, net.seq_length, net.y)
    net.opt = model.train_opt(hp.step_rate,
                              hp.max_steps,
                              global_step=net.global_step,
                              opt_name=config['opt_method'])
    if hp.gradient_clip is None:
        net.step = net.opt.minimize(net.ctc_loss, global_step=net.global_step)
    else:
        # Clip each gradient to the configured norm before applying the update.
        net.gradients, net.variables = zip(*net.opt.compute_gradients(net.ctc_loss))
        net.gradients = [None if gradient is None
                         else tf.clip_by_norm(gradient, hp.gradient_clip)
                         for gradient in net.gradients]
        net.step = net.opt.apply_gradients(zip(net.gradients, net.variables),
                                           global_step=net.global_step)
    net.error, net.errors, net.y_ = model.prediction(net.logits, net.seq_length, net.y)
    net.init = tf.global_variables_initializer()
    # Keep both the standard saveable objects and any variables in the
    # MOVING_AVERAGE_VARIABLES collection in the checkpoint.
    net.variable_to_restore = set(variables._all_saveable_objects() +
                                  tf.moving_average_variables())
    net.saver = tf.train.Saver(var_list=net.variable_to_restore, 
                               save_relative_paths=True)
    net.summary = tf.summary.merge_all()
    return net
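
# Hedged usage sketch for compile_train_graph (an illustration, not part of the
# original project): the graph is built once and then driven through the `net`
# namespace it returns. `sess`, the batch arrays and `sparse_labels` are assumed
# to come from the caller's data pipeline.
def run_one_step(sess, net, batch_x, seq_len, sparse_labels):
    """Run a single optimisation step and return the CTC loss value."""
    indices, values, shape = sparse_labels  # sparse label triplet
    feed = {net.x: batch_x,
            net.seq_length: seq_len,
            net.y_indexs: indices,
            net.y_values: values,
            net.y_shape: shape,
            net.training: True}
    loss_val, _ = sess.run([net.ctc_loss, net.step], feed_dict=feed)
    return loss_val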
def chiron_train():
    """Train the network according to FLAGS, feeding batches through placeholders."""
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    default_config = os.path.join(FLAGS.log_dir, FLAGS.model_name,
                                  'model.json')

    if FLAGS.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model Json file has not been found in model log directory")
    else:
        config_file = FLAGS.configure

    config = model.read_config(config_file)

    logits, ratio = model.inference(x,
                                    seq_length,
                                    training,
                                    FLAGS.sequence_len,
                                    configure=config)
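    # `ratio` is the temporal down-sampling factor of the network; raw signal
    # lengths are divided by it before being fed as sequence lengths below.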
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(FLAGS.step_rate,
                          FLAGS.max_steps,
                          global_step=global_step,
                          opt_name=config['opt_method'])
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

    if not FLAGS.retrain:
        sess.run(init)
        print("Model init finished, begin loading data.\n")
    else:
        saver.restore(
            sess, tf.train.latest_checkpoint(FLAGS.log_dir + FLAGS.model_name))
        print("Model loading finished, begin loading data.\n")
    summary_writer = tf.summary.FileWriter(
        FLAGS.log_dir + FLAGS.model_name + '/summary/', sess.graph)
    model.save_model(default_config, config)

    train_ds, valid_ds = generate_train_valid_datasets()
    start = time.time()

    for i in range(FLAGS.max_steps):
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
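        # batch_y comes back as a sparse label triplet (indices, values, dense
        # shape) matching the three label placeholders defined above.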
        indxs, values, shape = batch_y
        feed_dict = {
            x: batch_x,
            seq_length: seq_len / ratio,
            y_indexs: indxs,
            y_values: values,
            y_shape: shape,
            training: True
        }
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            valid_x, valid_len, valid_y = valid_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                x: valid_x,
                seq_length: valid_len / ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                training: True
            }
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
            "Step %d/%d Epoch %d, batch number %d, train_loss: %5.3f validate_edit_distance: %5.3f Elapsed Time/step: %5.3f" \
            % (i, FLAGS.max_steps, train_ds.epochs_completed,
               train_ds.index_in_epoch, loss_val, error_val,
               (end - start) / (i + 1)))
            saver.save(sess,
                       FLAGS.log_dir + FLAGS.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % (FLAGS.log_dir + FLAGS.model_name))
    print("Reads number %d" % (train_ds.reads_n))
    saver.save(sess,
               FLAGS.log_dir + FLAGS.model_name + '/final.ckpt',
               global_step=global_step_val)
def train(hparam):
    """Main training function.
    This will train a Neural Network with the given dataset.

    Args:
        hparam: hyper parameter for training the neural network
            data_dir: String, the path of the data(binary batch files) directory.
            log-dir: String, the path to save the trained model.
            sequence-len: Int, length of input signal.
            batch-size: Int.
            step-rate: Float, step rate of the optimizer.
            max-steps: Int, max training steps.
            kmer: Int, size of the dna kmer.
            model-name: String, model will be saved at log-dir/model-name.
            retrain: Boolean, if True, the model will be reload from log-dir/model-name.

    """
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())

    x, seq_length, train_labels = inputs(hparam.data_dir,
                                         hparam.batch_size,
                                         hparam.sequence_len,
                                         for_valid=False)
    y = dense2sparse(train_labels)
    default_config = os.path.join(hparam.log_dir, hparam.model_name,
                                  'model.json')
    if hparam.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model Json file has not been found in model log directory")
    else:
        config_file = hparam.configure
    config = model.read_config(config_file)
    logits, ratio = model.inference(x,
                                    seq_length,
                                    training,
                                    hparam.sequence_len,
                                    configure=config,
                                    apply_ratio=True)
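    # Rescale the raw sequence lengths by the network's down-sampling ratio so
    # they match the time dimension of the logits fed to the CTC loss.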
    seq_length = tf.cast(tf.ceil(tf.cast(seq_length, tf.float32) / ratio),
                         tf.int32)
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(hparam.step_rate,
                          hparam.max_steps,
                          global_step=global_step)
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    model.save_model(default_config, config)
    if not hparam.retrain:
        sess.run(init)
        print("Model init finished, begin training. \n")
    else:
        saver.restore(
            sess,
            tf.train.latest_checkpoint(hparam.log_dir + hparam.model_name))
        print("Model loaded finished, begin training. \n")
    summary_writer = tf.summary.FileWriter(
        hparam.log_dir + hparam.model_name + '/summary/', sess.graph)
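    # inputs() above builds a queue-based input pipeline, so its queue runners
    # must be started before any sess.run call that pulls a batch.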
    _ = tf.train.start_queue_runners(sess=sess)

    start = time.time()
    for i in range(hparam.max_steps):
        feed_dict = {training: True}
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            feed_dict = {training: True}
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
                "Step %d/%d ,  loss: %5.3f edit_distance: %5.3f Elapsed Time/batch: %5.3f" \
                % (i, hparam.max_steps, loss_val, error_val,
                   (end - start) / (i + 1)))
            saver.save(sess,
                       hparam.log_dir + hparam.model_name + '/model.ckpt',
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str,
                                       global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % (hparam.log_dir + hparam.model_name))
    saver.save(sess,
               hparam.log_dir + hparam.model_name + '/final.ckpt',
               global_step=global_step_val)
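
# Hedged sketch of the dense2sparse helper used in tower_loss and train above
# (assumption: labels arrive as a dense [batch, max_label_len] integer tensor
# padded with -1; the project's actual helper may use a different padding value).
def dense2sparse(labels, pad_value=-1):
    indices = tf.where(tf.not_equal(labels, pad_value))
    values = tf.gather_nd(labels, indices)
    dense_shape = tf.shape(labels, out_type=tf.int64)
    return tf.SparseTensor(indices, values, dense_shape)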