def train(hparams): """Main training function. This will train a Neural Network with the given dataset. Args: hparams: hyper parameter for training the neural network data-dir: String, the path of the data(binary batch files) directory. log-dir: String, the path to save the trained model. sequence-len: Int, length of input signal. batch-size: Int. step-rate: Float, step rate of the optimizer. max-steps: Int, max training steps. kmer: Int, size of the dna kmer. model-name: String, model will be saved at log-dir/model-name. retrain: Boolean, if True, the model will be reload from log-dir/model-name. """ with tf.Graph().as_default(), tf.device('/cpu:0'): training = tf.placeholder(tf.bool) global_step = tf.get_variable('global_step', trainable=False, shape=(), dtype=tf.int32, initializer=tf.zeros_initializer()) opt = model.train_opt(hparams.step_rate, hparams.max_steps, global_step=global_step) x, seq_length, train_labels = inputs(hparams.data_dir, int(hparams.batch_size * hparams.ngpus), for_valid=False) split_y = tf.split(train_labels, hparams.ngpus, axis=0) split_seq_length = tf.split(seq_length, hparams.ngpus, axis=0) split_x = tf.split(x, hparams.ngpus, axis=0) tower_grads = [] with tf.variable_scope(tf.get_variable_scope()): for i in range(hparams.ngpus): with tf.device('/gpu:%d' % i): with tf.name_scope('%s_%d' % ('gpu_tower', i)) as scope: loss, error = tower_loss( scope, split_x[i], split_seq_length[i], split_y[i], full_seq_len=hparams.sequence_len) tf.get_variable_scope().reuse_variables() summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope) grads = opt.compute_gradients(loss) tower_grads.append(grads) grads = average_gradients(tower_grads) for grad, var in grads: if grad is not None: summaries.append( tf.summary.histogram(var.op.name + '/gradients', grad)) for var in tf.trainable_variables(): summaries.append(tf.summary.histogram(var.op.name, var)) apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) var_averages = tf.train.ExponentialMovingAverage( decay=model.MOVING_AVERAGE_DECAY) var_averages_op = var_averages.apply(tf.trainable_variables()) train_op = tf.group(apply_gradient_op, var_averages_op) init = tf.global_variables_initializer() saver = tf.train.Saver() summary = tf.summary.merge_all() sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) save_model(hparams.log_dir, hparams.model_name) if not hparams.retrain: sess.run(init) print("Model init finished, begin training. \n") else: saver.restore( sess, tf.train.latest_checkpoint(hparams.log_dir + hparams.model_name)) print("Model loaded finished, begin training. 
\n") summary_writer = tf.summary.FileWriter( hparams.log_dir + hparams.model_name + '/summary/', sess.graph) _ = tf.train.start_queue_runners(sess=sess) start = time.time() for i in range(hparams.max_steps): feed_dict = {training: True} loss_val, _ = sess.run([loss, train_op], feed_dict=feed_dict) if i % 10 == 0: global_step_val = tf.train.global_step(sess, global_step) feed_dict = {training: True} error_val = sess.run(error, feed_dict=feed_dict) end = time.time() print( "Step %d/%d , loss: %5.3f edit_distance: %5.3f Elapsed Time/batch: %5.3f" \ % (i, hparams.max_steps, loss_val, error_val, (end - start) / (i + 1))) saver.save(sess, hparams.log_dir + hparams.model_name + '/model.ckpt', global_step=global_step_val) summary_str = sess.run(summary, feed_dict=feed_dict) summary_writer.add_summary(summary_str, global_step=global_step_val) summary_writer.flush() global_step_val = tf.train.global_step(sess, global_step) print("Model %s saved." % (hparams.log_dir + hparams.model_name)) saver.save(sess, hparams.log_dir + hparams.model_name + '/final.ckpt', global_step=global_step_val)
def train(hparam): """Main training function. This will train a Neural Network with the given dataset. Args: hparam: hyper parameter for training the neural network data_dir: String, the path of the data(binary batch files) directory. log-dir: String, the path to save the trained model. sequence-len: Int, length of input signal. batch-size: Int. step-rate: Float, step rate of the optimizer. max-steps: Int, max training steps. kmer: Int, size of the dna kmer. model-name: String, model will be saved at log-dir/model-name. retrain: Boolean, if True, the model will be reload from log-dir/model-name. """ training = tf.placeholder(tf.bool) global_step = tf.get_variable('global_step', trainable=False, shape=(), dtype=tf.int32, initializer=tf.zeros_initializer()) x, seq_length, train_labels = inputs(hparam.data_dir, hparam.batch_size, hparam.sequence_len, for_valid=False) y = dense2sparse(train_labels) default_config = os.path.join(hparam.log_dir, hparam.model_name, 'model.json') if hparam.retrain: if os.path.isfile(default_config): config_file = default_config else: raise ValueError( "Model Json file has not been found in model log directory") else: config_file = hparam.configure config = model.read_config(config_file) logits, ratio = model.inference(x, seq_length, training, hparam.sequence_len, configure=config, apply_ratio=True) seq_length = tf.cast(tf.ceil(tf.cast(seq_length, tf.float32) / ratio), tf.int32) ctc_loss = model.loss(logits, seq_length, y) opt = model.train_opt(hparam.step_rate, hparam.max_steps, global_step=global_step) step = opt.minimize(ctc_loss, global_step=global_step) error = model.prediction(logits, seq_length, y) init = tf.global_variables_initializer() saver = tf.train.Saver() summary = tf.summary.merge_all() sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) model.save_model(default_config, config) if not hparam.retrain: sess.run(init) print("Model init finished, begin training. \n") else: saver.restore( sess, tf.train.latest_checkpoint(hparam.log_dir + hparam.model_name)) print("Model loaded finished, begin training. \n") summary_writer = tf.summary.FileWriter( hparam.log_dir + hparam.model_name + '/summary/', sess.graph) _ = tf.train.start_queue_runners(sess=sess) start = time.time() for i in range(hparam.max_steps): feed_dict = {training: True} loss_val, _ = sess.run([ctc_loss, step], feed_dict=feed_dict) if i % 10 == 0: global_step_val = tf.train.global_step(sess, global_step) feed_dict = {training: True} error_val = sess.run(error, feed_dict=feed_dict) end = time.time() print( "Step %d/%d , loss: %5.3f edit_distance: %5.3f Elapsed Time/batch: %5.3f" \ % (i, hparam.max_steps, loss_val, error_val, (end - start) / (i + 1))) saver.save(sess, hparam.log_dir + hparam.model_name + '/model.ckpt', global_step=global_step_val) summary_str = sess.run(summary, feed_dict=feed_dict) summary_writer.add_summary(summary_str, global_step=global_step_val) summary_writer.flush() global_step_val = tf.train.global_step(sess, global_step) print("Model %s saved." % (hparam.log_dir + hparam.model_name)) saver.save(sess, hparam.log_dir + hparam.model_name + '/final.ckpt', global_step=global_step_val)