def test_on_fold(fold: int, data_of_fold: Dict[int, Tuple[List[Tensor], List[int]]], train_folds_of_fold: Dict[int, List[int]], model_factory: ModelFactory) -> Tuple[float, float, np.ndarray]:
    """Fit a fresh model for one cross-validation fold and score it.

    The held-out data is ``data_of_fold[fold]``; the fit data comes from the
    sibling helper ``train_data_for_fold``. Returns the triple
    (train accuracy, test accuracy, test confusion matrix).
    """
    print(f"Testing fold {fold}")
    fit_items, fit_labels = train_data_for_fold(fold, data_of_fold, train_folds_of_fold)
    held_items, held_labels = data_of_fold[fold]

    clf = model_factory.get()
    clf.fit(fit_items, fit_labels)

    # Predict on the training split first, then on the held-out split,
    # mirroring the order the metrics are reported in.
    fit_pred = clf.predict(fit_items)
    held_pred = clf.predict(held_items)

    acc_train = accuracy_score(fit_labels, fit_pred)
    acc_test = accuracy_score(held_labels, held_pred)
    confusion = confusion_matrix(held_labels, held_pred)
    return acc_train, acc_test, confusion
def test(model_path, dataset_path):
    """Rebuild the inference graph for the dataset at *dataset_path*, restore the
    checkpoint found in *model_path*, and export it as a TensorFlow SavedModel
    under ``<model_path>/saved_model``.

    NOTE(review): despite its name, this function never runs an evaluation --
    ``test_iterator``, ``distance_op`` and ``log_path`` are built/derived but
    unused; it only restores and re-exports the model.
    """
    log_path = os.path.join(model_path, 'test_log')  # unused in this body
    # Re-read the training-time config saved next to the checkpoint so the
    # rebuilt graph matches the trained one.
    config = ConfigReader(os.path.join(model_path, 'config.json')).read()
    dataset_extractor = DatasetExtractor(dataset_path, config)
    dataset_test, test_size = dataset_extractor.extract()
    test_iterator = dataset_test.make_one_shot_iterator()  # unused below
    # TF1 feedable-iterator plumbing: a string handle fed at session.run time
    # selects which concrete iterator drives the graph.
    dataset_handle = tf.placeholder(tf.string, shape=[])
    feedable_iterator = tf.data.Iterator.from_string_handle(dataset_handle, dataset_test.output_types, dataset_test.output_shapes, dataset_test.output_classes)
    signal, label, signal_len, _ = feedable_iterator.get_next()
    label = tf.cast(label, dtype=tf.int32)
    model = ModelFactory.get(config.model_name, signal, config)
    # NOTE(review): optimizer/decoder/distance_op are constructed but never run
    # here; the optimizer presumably adds variables (e.g. slot variables) that
    # the checkpoint contains, so removing these lines could break
    # saver.restore -- confirm before pruning.
    optimizer = OptimizerFactory.get(config.optimizer, model.logits, label, signal_len)
    decoder = DecoderFactory.get(config.decoder, model.logits, signal_len)
    distance_op = tf.reduce_mean(tf.edit_distance(tf.cast(decoder.decoded, dtype=tf.int32), label))
    saver = tf.train.Saver()
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    saver.restore(sess, os.path.join(model_path, "model.ckpt"))
    # Export the restored graph in SavedModel format so it can be served
    # without the raw checkpoint.
    tf.saved_model.simple_save(sess, os.path.join(model_path, 'saved_model'), inputs={"signal": signal, "lengths": signal_len}, outputs={"logits": model.logits})
def train(config_path, train_dataset_path, val_dataset_path, output_path, early_stop=False):
    """Train a model described by *config_path* and write checkpoints/logs to
    *output_path*.

    Per epoch (detected by the step counter reaching the dataset size) the
    model is checkpointed; if ``config.validate`` or *early_stop* is set, a
    validation pass computes mean loss and mean edit distance, and training
    stops early when BOTH metrics worsened relative to the previous epoch.
    After training, a final test pass runs over a slice of the validation
    dataset and its metrics are logged.
    """
    keep_training = True
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    # Keep a copy of the config next to the checkpoints for reproducibility.
    copyfile(config_path, os.path.join(output_path, 'config.json'))
    log_path = os.path.join(output_path, 'log')
    config = ConfigReader(config_path).read()
    train_dataset_extractor = DatasetExtractor(train_dataset_path, config)
    val_dataset_extractor = DatasetExtractor(val_dataset_path, config)
    dataset_train, train_size = train_dataset_extractor.extract()
    train_iterator = dataset_train.make_one_shot_iterator()
    dataset_val, val_size = val_dataset_extractor.extract()
    # NOTE(review): both slices are .take() prefixes of the SAME dataset, so
    # the 300-element "test" set contains the 75-element validation set --
    # confirm whether the test slice was meant to be .skip(75).take(300).
    dataset_test = dataset_val.take(300)
    dataset_val = dataset_val.take(75)
    # Validation iterator is initializable so it can be re-run every epoch.
    val_iterator = dataset_val.make_initializable_iterator()
    test_iterator = dataset_test.make_one_shot_iterator()
    # TF1 feedable-iterator plumbing: a string handle fed at session.run time
    # selects which concrete iterator (train/val/test) drives the graph.
    dataset_handle = tf.placeholder(tf.string, shape=[])
    feedable_iterator = tf.data.Iterator.from_string_handle(
        dataset_handle, dataset_train.output_types, dataset_train.output_shapes,
        dataset_train.output_classes)
    signal, label, signal_len, _ = feedable_iterator.get_next()
    label = tf.cast(label, dtype=tf.int32)
    model = ModelFactory.get(config.model_name, signal, config)
    optimizer = OptimizerFactory.get(config.optimizer, model.logits, label, signal_len)
    decoder = DecoderFactory.get(config.decoder, model.logits, signal_len)
    # Mean edit distance between decoded predictions and the labels.
    distance_op = tf.reduce_mean(
        tf.edit_distance(tf.cast(decoder.decoded, dtype=tf.int32), label))
    saver = tf.train.Saver()
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    if config.debug:
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)
    sess.run(tf.global_variables_initializer())
    training_handle = sess.run(train_iterator.string_handle())
    validation_handle = sess.run(val_iterator.string_handle())
    test_handle = sess.run(test_iterator.string_handle())
    epoch = 0
    steps = 0
    previous_print_length = 0  # length of the last in-place progress line
    losses = []
    # Previous epoch's validation metrics, used for the early-stop comparison.
    prev_val_distance = None
    prev_val_loss = None
    while keep_training:
        try:
            # One optimization step; loss is fetched for logging.
            loss_value, _ = sess.run(
                [optimizer.loss, optimizer.optimizer],
                feed_dict={dataset_handle: training_handle})
            losses.append(loss_value)
            steps += config.batch_size
            # Overwrite the previous progress line in place with backspaces.
            if previous_print_length > 0:
                print('\b' * previous_print_length, end='', flush=True)
            message = f"Epoch: {epoch} Step: {steps} Step Loss:{loss_value} Epoch Loss: {np.mean(losses)}"
            log_to_file(log_path, message)
            previous_print_length = len(message)
            print(message, end='', flush=True)
            # Epoch boundary: processed-item count reached the dataset size.
            if steps >= train_size:
                saver.save(sess, os.path.join(output_path, f"model.ckpt"))
                if config.validate or early_stop:
                    distances = []
                    val_losses = []
                    # Re-initialize so validation starts from the beginning.
                    sess.run(val_iterator.initializer)
                    while True:
                        try:
                            distance, val_loss = sess.run(
                                [distance_op, optimizer.loss],
                                feed_dict={dataset_handle: validation_handle})
                            distances.append(distance)
                            val_losses.append(val_loss)
                        except tf.errors.InvalidArgumentError as e:
                            # Log and re-raise: validation data/graph mismatch
                            # is fatal, not an end-of-data condition.
                            log_to_file(log_path, e.message)
                            raise e
                        except tf.errors.OutOfRangeError:
                            break
                    mean_distance = np.mean(distances)
                    mean_val_loss = np.mean(val_losses)
                    # Early stop only when BOTH metrics got worse than last epoch.
                    if prev_val_distance is not None and prev_val_loss is not None:
                        if prev_val_distance < mean_distance and prev_val_loss < mean_val_loss:
                            keep_training = False
                    prev_val_loss = mean_val_loss
                    prev_val_distance = mean_distance
                    print(flush=True)
                    log_message = f"Epoch: {epoch} Validation Loss: {mean_val_loss} Edit Distance: {mean_distance}"
                    print(log_message, flush=True)
                    log_to_file(log_path, log_message)
                # Reset per-epoch counters.
                epoch += 1
                steps = 0
                previous_print_length = 0
                losses = []
        except tf.errors.OutOfRangeError:
            break  # End of dataset
    # Final checkpoint after the training loop exits.
    saver.save(sess, os.path.join(output_path, "model.ckpt"))
    # Final evaluation pass over the (one-shot) test slice.
    test_distances = []
    test_losses = []
    while True:
        try:
            test_distance, test_loss = sess.run(
                [distance_op, optimizer.loss],
                feed_dict={dataset_handle: test_handle})
            test_distances.append(test_distance)
            test_losses.append(test_loss)
        except tf.errors.OutOfRangeError:
            break
    mean_test_distance = np.mean(test_distances)
    mean_test_loss = np.mean(test_losses)
    print(flush=True)
    log_message = f"Test Loss: {mean_test_loss} Edit Distance: {mean_test_distance}"
    print(log_message, flush=True)
    log_to_file(log_path, log_message)
"""Smoke-test entry point: build the model named on the command line.

Usage: python <script> MODEL_NAME
"""
import sys

import tensorflow as tf

from model import ModelFactory

if __name__ == "__main__":
    # Placeholder input with shape [None, 300, 1]; only graph construction
    # is exercised here, nothing is run in a session.
    placeholder = tf.placeholder(tf.float32, [None, 300, 1])
    model = ModelFactory.get(sys.argv[1], placeholder)