def model_size(config):
    """Build the model, restore it from the latest checkpoint, and report its size."""
    logger = logging.getLogger('')
    config.train.num_gpus = 1
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())

        # Reload variables from disk.
        if tf.train.latest_checkpoint(config.model_dir):
            available_vars = available_variables(config.model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess, tf.train.latest_checkpoint(config.model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
                logger.info('=================================')
                import example.ctc.ctc_util as ctc_util
                logger.info(ctc_util.print_nnet_info())
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')
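# NOTE: `available_variables` is a project-local helper used throughout this file but
# not defined here. A minimal sketch of what it is assumed to do follows: list the
# variables stored in the latest checkpoint and keep only those graph variables whose
# name and shape match, so a partial restore never fails. The helper name with the
# `_sketch` suffix and its body are illustrative assumptions, not the real code
# (it relies on the module-level `tensorflow as tf` import).
def _available_variables_sketch(checkpoint_dir):
    """Hypothetical re-implementation of `available_variables`, for illustration only."""
    reader = tf.train.NewCheckpointReader(tf.train.latest_checkpoint(checkpoint_dir))
    ckpt_shapes = reader.get_variable_to_shape_map()
    matched = []
    for v in tf.global_variables():
        name = v.name.split(':')[0]
        # Keep a variable only if the checkpoint has it with the same shape.
        if name in ckpt_shapes and v.get_shape().as_list() == ckpt_shapes[name]:
            matched.append(v)
    return matched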
def train(config):
    """Train a model with a config file."""
    logger = logging.getLogger('')
    data_reader = DataReader(config=config)
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(config.model_dir, graph=model.graph)

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())

        # Reload variables from disk.
        if tf.train.latest_checkpoint(config.model_dir):
            available_vars = available_variables(config.model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess, tf.train.latest_checkpoint(config.model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        evaluator = Evaluator()
        evaluator.init_from_existed(model, sess, data_reader)

        global dev_bleu, toleration
        dev_bleu = evaluator.evaluate(**config.dev) if config.train.eval_on_dev else 0
        toleration = config.train.toleration

        def train_one_step(batch):
            feat_batch, target_batch, batch_size = batch
            feed_dict = expand_feed_dict({
                model.src_pls: feat_batch,
                model.dst_pls: target_batch
            })
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate, model.loss, model.train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
            return step, lr, loss

        def maybe_save_model():
            global dev_bleu, toleration
            new_dev_bleu = evaluator.evaluate(**config.dev) if config.train.eval_on_dev else dev_bleu + 1
            if new_dev_bleu >= dev_bleu:
                mp = config.model_dir + '/model_step_{}'.format(step)
                model.saver.save(sess, mp)
                logger.info('Save model in %s.' % mp)
                toleration = config.train.toleration
                dev_bleu = new_dev_bleu
            else:
                toleration -= 1

        step = 0
        for epoch in range(1, config.train.num_epochs + 1):
            for batch in data_reader.get_training_batches_with_buckets():
                # Train normal instances.
                start_time = time.time()
                step, lr, loss = train_one_step(batch)
                logger.info(
                    'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}\tbatch_size: {5}'
                    .format(epoch, step, lr, loss, time.time() - start_time, batch[2]))

                # Save model.
                if config.train.save_freq > 0 and step % config.train.save_freq == 0:
                    maybe_save_model()

                if config.train.num_steps and step >= config.train.num_steps:
                    break

            # Save model once per epoch if config.train.save_freq is less than or equal to zero.
            if config.train.save_freq <= 0:
                maybe_save_model()

            # Early stop.
            if toleration <= 0:
                break

        logger.info("Finish training.")
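# NOTE: `expand_feed_dict` is another project-local helper that is not defined in this
# file. In the multi-GPU setup above, `model.src_pls` / `model.dst_pls` appear to be
# tuples of per-device placeholders, so the helper presumably shards each batch across
# them. The sketch below is an assumption about its behaviour, not the actual
# implementation; `numpy` is assumed to be available as in the rest of the project.
def _expand_feed_dict_sketch(feed_dict):
    """Hypothetical version: map {tuple_of_placeholders: batch} to {placeholder: shard}."""
    import numpy as np
    expanded = {}
    for placeholders, value in feed_dict.items():
        if not isinstance(placeholders, (list, tuple)):
            expanded[placeholders] = value
            continue
        # Split the batch evenly over the per-GPU placeholders.
        shards = np.array_split(value, len(placeholders))
        for pl, shard in zip(placeholders, shards):
            expanded[pl] = shard
    return expanded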
def train(config, num_epoch, last_pretrain_model_dir, pretrain_model_dir, model_dir, block_idx):
    """Pretrain the model up to block `block_idx`, warm-starting from the previous stage's checkpoint."""
    logger = logging.getLogger('')
    config.num_blocks = block_idx
    # if block_idx >= 2:
    #     config.train.var_filter = ('encoder/block_' + str(block_idx - 1) + '|' +
    #                                'decoder/block_' + str(block_idx - 1) + '|' +
    #                                'encoder/src_embedding' + '|' + 'decoder/dst_embedding')
    # if block_idx >= 2:
    #     config.train.var_filter = ('encoder/block_' + str(block_idx - 1) + '|' +
    #                                'decoder/block_' + str(block_idx - 1))
    logger.info("config.num_blocks=" + str(config.num_blocks) +
                ',config.train.var_filter=' + str(config.train.var_filter))

    data_reader = DataReader(config=config)
    model = eval(config.model)(config=config, num_gpus=config.train.num_gpus)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(pretrain_model_dir, graph=model.graph)

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())

        # Reload variables from the previous pretraining stage's checkpoint.
        if tf.train.latest_checkpoint(last_pretrain_model_dir):
            # available_vars = available_variables_without_global_step(last_pretrain_model_dir)
            available_vars = available_variables(last_pretrain_model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess, tf.train.latest_checkpoint(last_pretrain_model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        evaluator = Evaluator()
        evaluator.init_from_existed(model, sess, data_reader)

        global dev_bleu, toleration
        dev_bleu = evaluator.evaluate(**config.dev) if config.train.eval_on_dev else 0
        toleration = config.train.toleration

        def train_one_step(batch):
            feat_batch, target_batch = batch
            feed_dict = expand_feed_dict({
                model.src_pls: feat_batch,
                model.dst_pls: target_batch
            })
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate, model.loss, model.train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                logger.info('pretrain summary_writer...')
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
                summary_writer.flush()
            return step, lr, loss

        def maybe_save_model(model_dir, is_save_global_step=True):
            global dev_bleu, toleration
            new_dev_bleu = evaluator.evaluate(**config.dev) if config.train.eval_on_dev else dev_bleu + 1
            if new_dev_bleu >= dev_bleu:
                mp = model_dir + '/pretrain_model_step_{}'.format(step)
                # model.saver.save(sess, mp)
                if is_save_global_step:
                    model.saver.save(sess, mp)
                else:
                    # Drop the global step so the next stage restarts its schedule.
                    variables_without_global_step = global_variables_without_global_step()
                    saver = tf.train.Saver(var_list=variables_without_global_step, max_to_keep=10)
                    saver.save(sess, mp)
                logger.info('Save model in %s.' % mp)
                toleration = config.train.toleration
                dev_bleu = new_dev_bleu
            else:
                toleration -= 1

        step = 0
        for epoch in range(1, num_epoch + 1):
            for batch in data_reader.get_training_batches_with_buckets():
                # Train normal instances.
                start_time = time.time()
                step, lr, loss = train_one_step(batch)
                logger.info(
                    'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}'
                    .format(epoch, step, lr, loss, time.time() - start_time))

                if config.train.num_steps and step >= config.train.num_steps:
                    break

            # Early stop.
            if toleration <= 0:
                break

        maybe_save_model(pretrain_model_dir)
        if model_dir:
            maybe_save_model(model_dir, False)

        logger.info("Finish pretrain layer=" + str(block_idx))
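# NOTE: `global_variables_without_global_step` is a project-local helper used by
# `maybe_save_model` above but not defined in this file. A plausible one-liner is
# sketched below: keep every global variable except the `global_step` counter, so the
# checkpoint handed to the next pretraining stage restarts its learning-rate schedule
# from step 0. Treat the sketch as an assumption about the real helper.
def _global_variables_without_global_step_sketch():
    """Hypothetical version of `global_variables_without_global_step`."""
    return [v for v in tf.global_variables() if 'global_step' not in v.name]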
def train(config): """Train a model with a config file.""" logger = logging.getLogger('') data_reader = DataReader(config=config) model = eval(config.model)(config=config, num_gpus=config.train.num_gpus) model.build_train_model(test=config.train.eval_on_dev) train_op, loss_op = model.get_train_op(name=None) global_saver = tf.train.Saver() sess_config = tf.ConfigProto() sess_config.gpu_options.allow_growth = True sess_config.allow_soft_placement = True summary_writer = tf.summary.FileWriter(config.model_dir) with tf.Session(config=sess_config) as sess: # Initialize all variables. sess.run(tf.global_variables_initializer()) # Reload variables from disk. if tf.train.latest_checkpoint(config.model_dir): available_vars = available_variables(config.model_dir) if available_vars: saver = tf.train.Saver(var_list=available_vars) saver.restore(sess, tf.train.latest_checkpoint(config.model_dir)) for v in available_vars: logger.info('Reload {} from disk.'.format(v.name)) else: logger.info('Nothing to be reload from disk.') else: logger.info('Nothing to be reload from disk.') evaluator = Evaluator() evaluator.init_from_existed(model, sess, data_reader) global dev_bleu, toleration dev_bleu = evaluator.evaluate(**config.dev) if config.train.eval_on_dev else 0 toleration = config.train.toleration def train_one_step(batch, loss_op, train_op): feed_dict = expand_feed_dict({model.src_pls: batch[0], model.dst_pls: batch[1]}) step, lr, loss, _ = sess.run( [model.global_step, model.learning_rate, loss_op, train_op], feed_dict=feed_dict) if step % config.train.summary_freq == 0: summary = sess.run(model.summary_op, feed_dict=feed_dict) summary_writer.add_summary(summary, global_step=step) return step, lr, loss def maybe_save_model(): global dev_bleu, toleration def save(): mp = config.model_dir + '/model_step_{}'.format(step) global_saver.save(sess, mp) logger.info('Save model in %s.' % mp) if config.train.eval_on_dev: new_dev_bleu = evaluator.evaluate(**config.dev) summary = tf.Summary(value=[tf.Summary.Value(tag="dev_bleu", simple_value=new_dev_bleu)]) summary_writer.add_summary(summary, step) if config.train.toleration is None: save() else: if new_dev_bleu >= dev_bleu: save() toleration = config.train.toleration dev_bleu = new_dev_bleu else: toleration -= 1 else: save() try: step = 0 for epoch in range(1, config.train.num_epochs+1): for batch in data_reader.get_training_batches(epoches=1): # Train normal instances. start_time = time.time() step, lr, loss = train_one_step(batch, loss_op, train_op) logger.info( 'epoch: {0}\tstep: {1}\tlr: {2:.6f}\tloss: {3:.4f}\ttime: {4:.4f}'. format(epoch, step, lr, loss, time.time() - start_time)) # Save model if config.train.save_freq > 0 \ and step > 0 \ and step % config.train.save_freq == 0: maybe_save_model() if config.train.num_steps is not None and step >= config.train.num_steps: raise BreakLoopException("BreakLoop") if toleration is not None and toleration <= 0: raise BreakLoopException("BreakLoop") # Save model per epoch if config.train.save_freq is less or equal than zero if config.train.save_freq <= 0: maybe_save_model() except BreakLoopException as e: logger.info(e) logger.info("Finish training.")
def train(config):
    """Train a model with a config file, reading features and labels from TFRecord files."""
    logger = logging.getLogger('')

    train_graph = tf.Graph()
    print(config.train.tfrecord_pattern)
    data_files = tf.gfile.Glob(config.train.tfrecord_pattern)
    logging.info("Found {} tfrecord files".format(len(data_files)))

    with train_graph.as_default():
        # Build the input pipeline: parse, shuffle, repeat, and batch the TFRecords.
        data_holder = tf.placeholder(tf.string, shape=[None])
        dataset = tf.data.TFRecordDataset(data_holder, num_parallel_reads=config.train.read_threads)
        dataset = dataset.map(parse_function_var, num_parallel_calls=config.train.read_threads)
        shuffle_data = True
        if shuffle_data is True:
            dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.repeat(config.train.num_epochs).batch(config.train.batchsize_read)
        iterator = dataset.make_initializable_iterator()
        feat_shape_tensor, feat_tensor, label_shape_tensor, label_tensor = iterator.get_next()
        feat_tensor = tf.sparse_tensor_to_dense(feat_tensor)
        label_tensor = tf.sparse_tensor_to_dense(label_tensor)
        label_tensor = tf.cast(label_tensor, tf.int32)
        feat_tensor_shapeop = tf.shape(feat_tensor)
        feat_tensor = tf.reshape(feat_tensor, [feat_tensor_shapeop[0], -1, config.train.input_dim])

    model = eval(config.model)(config=config,
                               num_gpus=config.train.num_gpus,
                               X=feat_tensor,
                               Y=label_tensor,
                               tensor_graph=train_graph)
    model.build_train_model(test=config.train.eval_on_dev)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True

    summary_writer = tf.summary.FileWriter(config.model_dir, graph=model.graph)

    with tf.Session(config=sess_config, graph=model.graph) as sess:
        # Initialize all variables.
        sess.run(tf.global_variables_initializer())
        sess.run(iterator.initializer, feed_dict={data_holder: data_files})

        # Reload variables from disk.
        if tf.train.latest_checkpoint(config.model_dir):
            available_vars = available_variables(config.model_dir)
            if available_vars:
                saver = tf.train.Saver(var_list=available_vars)
                saver.restore(sess, tf.train.latest_checkpoint(config.model_dir))
                for v in available_vars:
                    logger.info('Reload {} from disk.'.format(v.name))
            else:
                logger.info('Nothing to reload from disk.')
        else:
            logger.info('Nothing to reload from disk.')

        global dev_bleu, toleration
        dev_bleu = 0
        toleration = config.train.toleration

        def train_one_step(batch):
            feat_batch, target_batch, batch_size = batch
            feed_dict = expand_feed_dict({
                model.src_pls: feat_batch,
                model.dst_pls: target_batch
            })
            step, lr, loss, _ = sess.run(
                [model.global_step, model.learning_rate, model.loss, model.train_op],
                feed_dict=feed_dict)
            if step % config.train.summary_freq == 0:
                summary = sess.run(model.summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary, global_step=step)
            return step, lr, loss

        def maybe_save_model():
            global dev_bleu, toleration
            new_dev_bleu = dev_bleu + 1
            if new_dev_bleu >= dev_bleu:
                mp = config.model_dir + '/model_step_{}'.format(step)
                model.saver.save(sess, mp)
                logger.info('Save model in %s.' % mp)
                toleration = config.train.toleration
                dev_bleu = new_dev_bleu
            else:
                toleration -= 1

        step = 0
        while True:
            try:
                pre_train_time = time.time()
                feat_shape, feat, label_shape, label = sess.run(
                    [feat_shape_tensor, feat_tensor, label_shape_tensor, label_tensor])
                batch = (feat, label, feat.shape[0])
                # logging.info("This batch has {} samples".format(feat.shape[0]))
                # logging.info("The feat shape is {}".format(feat.shape))

                # Train normal instances.
                start_time = time.time()
                step, lr, loss = train_one_step(batch)
                logger.info(
                    'step: {0}\tlr: {1:.6f}\tloss: {2:.4f}\ttrain_time: {3:.4f}\tpre_train_time: {4:.5f}\tbatch_size: {5}'
                    .format(step, lr, loss, time.time() - start_time, start_time - pre_train_time, batch[2]))

                # Save model.
                pre_train_time = time.time()
                if config.train.save_freq > 0 and step % config.train.save_freq == 0:
                    maybe_save_model()

                if config.train.num_steps and step >= config.train.num_steps:
                    break

                # Save model if config.train.save_freq is less than or equal to zero.
                if config.train.save_freq <= 0:
                    maybe_save_model()

                # Early stop.
                if toleration <= 0:
                    break
            except tf.errors.OutOfRangeError:
                logging.info("All data done!")
                # Leave the loop once the dataset is exhausted.
                break

        logger.info("Finish training.")
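# NOTE: `parse_function_var` is the project's TFRecord parser and is not defined in this
# file. Judging by how its outputs are used above (two dense shape vectors plus two
# sparse tensors that are densified and reshaped with `config.train.input_dim`), it
# presumably looks roughly like the sketch below. The feature keys 'feat_shape', 'feat',
# 'label_shape', and 'label' are assumptions, not the actual TFRecord schema.
def _parse_function_var_sketch(serialized_example):
    """Hypothetical per-example parser returning (feat_shape, feat, label_shape, label)."""
    features = tf.parse_single_example(
        serialized_example,
        features={
            'feat_shape': tf.FixedLenFeature([2], tf.int64),   # (num_frames, input_dim)
            'feat': tf.VarLenFeature(tf.float32),               # flattened acoustic features
            'label_shape': tf.FixedLenFeature([1], tf.int64),   # (num_labels,)
            'label': tf.VarLenFeature(tf.int64),                 # target token ids
        })
    return (features['feat_shape'], features['feat'],
            features['label_shape'], features['label'])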