def tower_loss(scope, x, seqlen, labels, full_seq_len):
    """Calculate the loss on a single GPU tower.

    Args:
        scope (String): prefix string describing the tower name, e.g. 'tower_0'.
        x (Float): Tensor of shape [batch_size, max_time], batch of input signal.
        seqlen (Int): Tensor of shape [batch_size], length of each sequence in the batch.
        labels (Int): dense Tensor of true labels; converted to a SparseTensor internally.
        full_seq_len (Int): full length of the input signal.

    Returns:
        A (total_loss, error) tuple: the summed loss collected from the 'losses'
        collection under this tower's scope, and the edit-distance error.
    """
    logits, _ = model.inference(x, seqlen, training=True,
                                full_sequence_len=full_seq_len)
    sparse_labels = dense2sparse(labels)
    # model.loss registers the CTC loss in the 'losses' collection as a side effect.
    _ = model.loss(logits, seqlen, sparse_labels)
    error = model.prediction(logits, seqlen, sparse_labels)
    # Sum every loss registered under this tower's scope (e.g. CTC + regularizers).
    losses = tf.get_collection('losses', scope)
    total_loss = tf.add_n(losses, name='total_loss')
    for l in losses + [total_loss]:
        tf.summary.scalar(l.op.name, l)
    tf.summary.scalar(error.op.name, error)
    return total_loss, error
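# tower_loss above and train() below both call a dense2sparse helper that is
# defined elsewhere in the codebase. A minimal sketch of such a conversion is
# given here for reference only; it assumes dense labels padded with a sentinel
# value (the padding_value argument), which may not match the real helper.
def _dense2sparse_sketch(dense_labels, padding_value=0):
    """Convert a [batch_size, max_label_len] dense label Tensor to a SparseTensor."""
    indices = tf.where(tf.not_equal(dense_labels, padding_value))
    values = tf.gather_nd(dense_labels, indices)
    shape = tf.shape(dense_labels, out_type=tf.int64)
    return tf.SparseTensor(indices=indices, values=values, dense_shape=shape)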
def compile_train_graph(config, hp):
    """Build the training graph from a model configuration and hyper parameters.

    Returns:
        A `net` namespace holding the graph handles (placeholders, loss, train
        op, prediction ops, init op, saver and merged summaries).
    """
    # Simple namespace object to collect the graph handles.
    class net:
        pass
    net.training = tf.placeholder(tf.bool)
    net.global_step = tf.get_variable('global_step',
                                      trainable=False,
                                      shape=(),
                                      dtype=tf.int32,
                                      initializer=tf.zeros_initializer())
    net.x = tf.placeholder(tf.float32, shape=[hp.batch_size, hp.sequence_len])
    net.seq_length = tf.placeholder(tf.int32, shape=[hp.batch_size])
    net.y_indexs = tf.placeholder(tf.int64)
    net.y_values = tf.placeholder(tf.int32)
    net.y_shape = tf.placeholder(tf.int64)
    net.y = tf.SparseTensor(net.y_indexs, net.y_values, net.y_shape)
    net.logits, net.ratio = model.inference(net.x, net.seq_length,
                                            net.training, hp.sequence_len,
                                            configure=config)
    if 'fl_gamma' in config.keys():
        # Use focal loss with the configured gamma if one is provided.
        net.ctc_loss = model.loss(net.logits, net.seq_length, net.y,
                                  fl_gamma=config['fl_gamma'])
    else:
        net.ctc_loss = model.loss(net.logits, net.seq_length, net.y)
    net.opt = model.train_opt(hp.step_rate, hp.max_steps,
                              global_step=net.global_step,
                              opt_name=config['opt_method'])
    if hp.gradient_clip is None:
        net.step = net.opt.minimize(net.ctc_loss, global_step=net.global_step)
    else:
        # Clip each gradient by norm before applying, leaving None gradients alone.
        net.gradients, net.variables = zip(*net.opt.compute_gradients(net.ctc_loss))
        net.gradients = [None if gradient is None
                         else tf.clip_by_norm(gradient, hp.gradient_clip)
                         for gradient in net.gradients]
        net.step = net.opt.apply_gradients(zip(net.gradients, net.variables),
                                           global_step=net.global_step)
    net.error, net.errors, net.y_ = model.prediction(net.logits, net.seq_length, net.y)
    net.init = tf.global_variables_initializer()
    # `variables` here is tensorflow.python.ops.variables; _all_saveable_objects
    # is a private TF1 API listing everything the default Saver would save.
    net.variable_to_restore = set(variables._all_saveable_objects() +
                                  tf.moving_average_variables())
    net.saver = tf.train.Saver(var_list=net.variable_to_restore,
                               save_relative_paths=True)
    net.summary = tf.summary.merge_all()
    return net
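# A minimal usage sketch for compile_train_graph, assuming a parsed `config`
# dict, an `hp` hyper-parameter object, and a batch provider with the same
# (signal, sequence length, sparse label triple) contract as the datasets used
# in chiron_train() below; every name except compile_train_graph is illustrative.
def _run_one_step_sketch(sess, net, batch_x, seq_len, batch_y):
    indxs, values, shape = batch_y  # sparse label triple: (indices, values, shape)
    loss_val, _ = sess.run([net.ctc_loss, net.step],
                           feed_dict={net.x: batch_x,
                                      net.seq_length: seq_len,
                                      net.y_indexs: indxs,
                                      net.y_values: values,
                                      net.y_shape: shape,
                                      net.training: True})
    return loss_val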
def chiron_train():
    """Train a model with feed_dict-based input, configured by command line FLAGS."""
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.sequence_len])
    seq_length = tf.placeholder(tf.int32, shape=[FLAGS.batch_size])
    y_indexs = tf.placeholder(tf.int64)
    y_values = tf.placeholder(tf.int32)
    y_shape = tf.placeholder(tf.int64)
    y = tf.SparseTensor(y_indexs, y_values, y_shape)
    default_config = os.path.join(FLAGS.log_dir, FLAGS.model_name, 'model.json')
    if FLAGS.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model JSON file has not been found in the model log directory.")
    else:
        config_file = FLAGS.configure
    config = model.read_config(config_file)
    logits, ratio = model.inference(x, seq_length, training,
                                    FLAGS.sequence_len, configure=config)
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(FLAGS.step_rate, FLAGS.max_steps,
                          global_step=global_step,
                          opt_name=config['opt_method'])
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    if not FLAGS.retrain:
        sess.run(init)
        print("Model init finished, begin loading data. \n")
    else:
        saver.restore(sess, tf.train.latest_checkpoint(
            os.path.join(FLAGS.log_dir, FLAGS.model_name)))
        print("Model loading finished, begin loading data. \n")
    summary_writer = tf.summary.FileWriter(
        os.path.join(FLAGS.log_dir, FLAGS.model_name, 'summary'), sess.graph)
    model.save_model(default_config, config)
    train_ds, valid_ds = generate_train_valid_datasets()
    start = time.time()
    for i in range(FLAGS.max_steps):
        batch_x, seq_len, batch_y = train_ds.next_batch(FLAGS.batch_size)
        indxs, values, shape = batch_y
        feed_dict = {
            x: batch_x,
            # Rescale raw signal lengths to the logit time scale; integer
            # division keeps the int32 placeholder dtype under Python 3.
            seq_length: seq_len // ratio,
            y_indexs: indxs,
            y_values: values,
            y_shape: shape,
            training: True
        }
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            valid_x, valid_len, valid_y = valid_ds.next_batch(FLAGS.batch_size)
            indxs, values, shape = valid_y
            feed_dict = {
                x: valid_x,
                seq_length: valid_len // ratio,
                y_indexs: indxs,
                y_values: values,
                y_shape: shape,
                # Evaluate the validation batch in inference mode.
                training: False
            }
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
                "Step %d/%d Epoch %d, batch number %d, train_loss: %5.3f "
                "validate_edit_distance: %5.3f Elapsed Time/step: %5.3f"
                % (i, FLAGS.max_steps, train_ds.epochs_completed,
                   train_ds.index_in_epoch, loss_val, error_val,
                   (end - start) / (i + 1)))
            saver.save(sess,
                       os.path.join(FLAGS.log_dir, FLAGS.model_name, 'model.ckpt'),
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % os.path.join(FLAGS.log_dir, FLAGS.model_name))
    print("Reads number %d" % train_ds.reads_n)
    saver.save(sess,
               os.path.join(FLAGS.log_dir, FLAGS.model_name, 'final.ckpt'),
               global_step=global_step_val)
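# model.inference returns `ratio`, the temporal downsampling factor of the
# network, so the sequence lengths fed to the CTC loss must be rescaled to the
# logit time scale. A hedged illustration of that rescaling, mirroring the
# in-graph tf.ceil rescaling used by train() below; the helper name is
# illustrative:
def _scale_seq_length_sketch(seq_len, ratio):
    import numpy as np
    # e.g. a raw signal of length 512 with ratio 4 -> ceil(512 / 4) = 128 logit frames.
    return np.ceil(np.asarray(seq_len, np.float32) / ratio).astype(np.int32)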
def train(hparam):
    """Main training function. Trains a neural network with the given dataset.

    Args:
        hparam: hyper-parameter object for training the neural network, with
            the following attributes:
            data_dir: String, path of the data (binary batch files) directory.
            log_dir: String, path to save the trained model.
            sequence_len: Int, length of the input signal.
            batch_size: Int, batch size.
            step_rate: Float, step rate of the optimizer.
            max_steps: Int, maximum number of training steps.
            kmer: Int, size of the DNA k-mer.
            model_name: String, the model will be saved at log_dir/model_name.
            retrain: Boolean, if True the model is reloaded from log_dir/model_name.
    """
    training = tf.placeholder(tf.bool)
    global_step = tf.get_variable('global_step',
                                  trainable=False,
                                  shape=(),
                                  dtype=tf.int32,
                                  initializer=tf.zeros_initializer())
    x, seq_length, train_labels = inputs(hparam.data_dir,
                                         hparam.batch_size,
                                         hparam.sequence_len,
                                         for_valid=False)
    y = dense2sparse(train_labels)
    default_config = os.path.join(hparam.log_dir, hparam.model_name, 'model.json')
    if hparam.retrain:
        if os.path.isfile(default_config):
            config_file = default_config
        else:
            raise ValueError(
                "Model JSON file has not been found in the model log directory.")
    else:
        config_file = hparam.configure
    config = model.read_config(config_file)
    logits, ratio = model.inference(x, seq_length, training,
                                    hparam.sequence_len, configure=config,
                                    apply_ratio=True)
    # Rescale the raw signal lengths to the logit time scale in-graph.
    seq_length = tf.cast(tf.ceil(tf.cast(seq_length, tf.float32) / ratio), tf.int32)
    ctc_loss = model.loss(logits, seq_length, y)
    opt = model.train_opt(hparam.step_rate, hparam.max_steps,
                          global_step=global_step)
    step = opt.minimize(ctc_loss, global_step=global_step)
    error = model.prediction(logits, seq_length, y)
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    summary = tf.summary.merge_all()
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    model.save_model(default_config, config)
    if not hparam.retrain:
        sess.run(init)
        print("Model init finished, begin training. \n")
    else:
        saver.restore(sess, tf.train.latest_checkpoint(
            os.path.join(hparam.log_dir, hparam.model_name)))
        print("Model loading finished, begin training. \n")
    summary_writer = tf.summary.FileWriter(
        os.path.join(hparam.log_dir, hparam.model_name, 'summary'), sess.graph)
    # Input batches come from the queue-based inputs() pipeline, so the feed
    # only needs the `training` flag.
    _ = tf.train.start_queue_runners(sess=sess)
    start = time.time()
    for i in range(hparam.max_steps):
        feed_dict = {training: True}
        loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict)
        if i % 10 == 0:
            global_step_val = tf.train.global_step(sess, global_step)
            # Note: there is no separate validation set here; the error is
            # measured on a training batch drawn from the queue.
            feed_dict = {training: True}
            error_val = sess.run(error, feed_dict=feed_dict)
            end = time.time()
            print(
                "Step %d/%d, loss: %5.3f edit_distance: %5.3f Elapsed Time/batch: %5.3f"
                % (i, hparam.max_steps, loss_val, error_val,
                   (end - start) / (i + 1)))
            saver.save(sess,
                       os.path.join(hparam.log_dir, hparam.model_name, 'model.ckpt'),
                       global_step=global_step_val)
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, global_step=global_step_val)
            summary_writer.flush()
    global_step_val = tf.train.global_step(sess, global_step)
    print("Model %s saved." % os.path.join(hparam.log_dir, hparam.model_name))
    saver.save(sess,
               os.path.join(hparam.log_dir, hparam.model_name, 'final.ckpt'),
               global_step=global_step_val)
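# tower_loss at the top of this file follows the classic multi-GPU "tower"
# pattern but has no driver in this section. A hedged sketch of one, assuming
# the inputs() pipeline and hparam fields used by train() above; per-tower
# gradient averaging is omitted and the helper name is illustrative.
def _build_towers_sketch(hparam, num_gpus=2):
    tower_losses, tower_errors = [], []
    for i in range(num_gpus):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i) as scope:
                x, seq_length, labels = inputs(hparam.data_dir,
                                               hparam.batch_size,
                                               hparam.sequence_len,
                                               for_valid=False)
                loss, error = tower_loss(scope, x, seq_length, labels,
                                         hparam.sequence_len)
                # Share one set of model variables across all towers.
                tf.get_variable_scope().reuse_variables()
                tower_losses.append(loss)
                tower_errors.append(error)
    return tower_losses, tower_errors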