コード例 #1
0
def test_on_fold(
        fold: int,
        data_of_fold: Dict[int, Tuple[List[Tensor], List[int]]],
        train_folds_of_fold: Dict[int, List[int]],
        model_factory: ModelFactory) -> Tuple[float, float, np.ndarray]:
    """Fit a model for one cross-validation fold and score it.

    The training split is whatever ``train_data_for_fold`` assembles for
    ``fold``; the held-out split is ``data_of_fold[fold]``. A model is taken
    from ``model_factory.get()``, fitted on the training split, and then used
    to predict both splits.

    Args:
        fold: key of the held-out fold in ``data_of_fold``.
        data_of_fold: maps each fold key to an ``(items, labels)`` pair.
        train_folds_of_fold: forwarded to ``train_data_for_fold``;
            presumably maps a fold to the folds it trains on — TODO confirm.
        model_factory: object whose ``get()`` supplies the model to fit.

    Returns:
        ``(train_accuracy, test_accuracy, test_confusion)``: accuracy on the
        training split, accuracy on the held-out split, and the held-out
        confusion matrix.
    """
    print(f"Testing fold {fold}")
    fit_items, fit_labels = train_data_for_fold(
        fold, data_of_fold, train_folds_of_fold)
    held_out_items, held_out_labels = data_of_fold[fold]

    classifier = model_factory.get()
    classifier.fit(fit_items, fit_labels)
    fit_predicted = classifier.predict(fit_items)
    held_out_predicted = classifier.predict(held_out_items)

    # Tuple elements are evaluated left to right.
    return (
        accuracy_score(fit_labels, fit_predicted),
        accuracy_score(held_out_labels, held_out_predicted),
        confusion_matrix(held_out_labels, held_out_predicted),
    )
コード例 #2
0
ファイル: test.py プロジェクト: plubon/Basecaller
def test(model_path, dataset_path):
    """Restore ``<model_path>/model.ckpt`` and export it as a SavedModel.

    Despite the name, nothing is evaluated here. The steps are:

    1. Rebuild the graph described by ``<model_path>/config.json`` on top of
       the dataset at ``dataset_path``.
    2. Restore the checkpoint weights into that graph.
    3. Write the graph with ``tf.saved_model.simple_save`` to
       ``<model_path>/saved_model``, with inputs ``signal`` and ``lengths``
       and output ``logits``.

    NOTE(review): ``simple_save`` is documented to refuse an existing export
    directory, so a second run on the same ``model_path`` likely fails —
    confirm.
    """
    config = ConfigReader(os.path.join(model_path, 'config.json')).read()
    dataset_extractor = DatasetExtractor(dataset_path, config)
    dataset_test, _ = dataset_extractor.extract()
    # The iterator, optimizer, decoder and distance ops are not used below,
    # but building them adds ops (and possibly variables) to the default
    # graph. They are kept so the restored/exported graph stays the same.
    test_iterator = dataset_test.make_one_shot_iterator()
    dataset_handle = tf.placeholder(tf.string, shape=[])
    feedable_iterator = tf.data.Iterator.from_string_handle(
        dataset_handle, dataset_test.output_types,
        dataset_test.output_shapes, dataset_test.output_classes)
    signal, label, signal_len, _ = feedable_iterator.get_next()
    label = tf.cast(label, dtype=tf.int32)
    model = ModelFactory.get(config.model_name, signal, config)
    optimizer = OptimizerFactory.get(config.optimizer, model.logits, label,
                                     signal_len)
    decoder = DecoderFactory.get(config.decoder, model.logits, signal_len)
    distance_op = tf.reduce_mean(
        tf.edit_distance(tf.cast(decoder.decoded, dtype=tf.int32), label))
    saver = tf.train.Saver()
    # The context manager closes the session once the export is done, even
    # if restore or export raises; previously the session was never closed.
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        saver.restore(sess, os.path.join(model_path, "model.ckpt"))
        tf.saved_model.simple_save(
            sess,
            os.path.join(model_path, 'saved_model'),
            inputs={"signal": signal, "lengths": signal_len},
            outputs={"logits": model.logits})
コード例 #3
0
def _mean_metrics(sess, metric_ops, dataset_handle, iterator_handle,
                  log_path):
    """Run ``metric_ops`` over one full pass of a dataset and average them.

    Batches are pulled through the feedable iterator by feeding
    ``iterator_handle`` for the ``dataset_handle`` placeholder, until
    ``tf.errors.OutOfRangeError`` marks the end of the pass. An
    ``InvalidArgumentError`` is written to ``log_path`` and re-raised.

    Returns:
        A list holding ``np.mean`` of each op's per-batch values, in the
        order of ``metric_ops``. If the pass yields no batches, each entry is
        ``np.mean([])``, i.e. ``nan`` with a RuntimeWarning.
    """
    per_op_values = [[] for _ in metric_ops]
    while True:
        try:
            batch_values = sess.run(
                metric_ops, feed_dict={dataset_handle: iterator_handle})
        except tf.errors.InvalidArgumentError as e:
            log_to_file(log_path, e.message)
            raise
        except tf.errors.OutOfRangeError:
            break
        for values, value in zip(per_op_values, batch_values):
            values.append(value)
    return [np.mean(values) for values in per_op_values]


def train(config_path,
          train_dataset_path,
          val_dataset_path,
          output_path,
          early_stop=False):
    """Train the configured model, checkpointing every epoch, then test it.

    Outputs:

    - ``config_path`` is copied to ``<output_path>/config.json`` and read
      with ``ConfigReader``.
    - Per-step progress, validation and test messages are appended to
      ``<output_path>/log`` via ``log_to_file``.
    - The checkpoint is written to ``<output_path>/model.ckpt`` at every
      epoch boundary and once more after training ends.

    An epoch ends once the step counter, which grows by
    ``config.batch_size`` per batch, reaches the training-set size reported
    by ``DatasetExtractor.extract``.

    Validation runs at every epoch boundary when ``config.validate`` or
    ``early_stop`` is set.

    Training stops when either:

    - the training iterator is exhausted, or
    - ``early_stop`` is True and both the mean validation loss and the mean
      validation edit distance rose compared with the previous epoch.

    The validation split is the first 75 elements of the validation dataset.
    The test split is the next 300, so test metrics are not computed on the
    batches that drive early stopping.
    """
    os.makedirs(output_path, exist_ok=True)
    copyfile(config_path, os.path.join(output_path, 'config.json'))
    log_path = os.path.join(output_path, 'log')
    checkpoint_path = os.path.join(output_path, "model.ckpt")
    config = ConfigReader(config_path).read()
    train_dataset_extractor = DatasetExtractor(train_dataset_path, config)
    val_dataset_extractor = DatasetExtractor(val_dataset_path, config)
    dataset_train, train_size = train_dataset_extractor.extract()
    train_iterator = dataset_train.make_one_shot_iterator()
    dataset_val_full, _ = val_dataset_extractor.extract()
    # Both splits are carved from the validation dataset. take/skip keeps
    # them disjoint so the test set does not contain the validation batches.
    # NOTE(review): this assumes extract() yields a stable order; if it
    # reshuffles on every iteration, the splits can still overlap.
    dataset_val = dataset_val_full.take(75)
    dataset_test = dataset_val_full.skip(75).take(300)
    val_iterator = dataset_val.make_initializable_iterator()
    test_iterator = dataset_test.make_one_shot_iterator()
    # One feedable iterator serves all three splits. The string handle fed
    # for ``dataset_handle`` selects which split the next batch comes from.
    dataset_handle = tf.placeholder(tf.string, shape=[])
    feedable_iterator = tf.data.Iterator.from_string_handle(
        dataset_handle, dataset_train.output_types,
        dataset_train.output_shapes, dataset_train.output_classes)
    signal, label, signal_len, _ = feedable_iterator.get_next()
    label = tf.cast(label, dtype=tf.int32)
    model = ModelFactory.get(config.model_name, signal, config)
    optimizer = OptimizerFactory.get(config.optimizer, model.logits, label,
                                     signal_len)
    decoder = DecoderFactory.get(config.decoder, model.logits, signal_len)
    distance_op = tf.reduce_mean(
        tf.edit_distance(tf.cast(decoder.decoded, dtype=tf.int32), label))
    metric_ops = [distance_op, optimizer.loss]
    saver = tf.train.Saver()
    base_sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    sess = (tf_debug.LocalCLIDebugWrapperSession(base_sess)
            if config.debug else base_sess)
    try:
        sess.run(tf.global_variables_initializer())
        training_handle = sess.run(train_iterator.string_handle())
        validation_handle = sess.run(val_iterator.string_handle())
        test_handle = sess.run(test_iterator.string_handle())

        epoch = 0
        steps = 0
        previous_print_length = 0
        # Running total avoids re-averaging the whole epoch every step.
        epoch_loss_total = 0.0
        epoch_batches = 0
        prev_val_distance = None
        prev_val_loss = None
        keep_training = True
        while keep_training:
            try:
                loss_value, _ = sess.run(
                    [optimizer.loss, optimizer.optimizer],
                    feed_dict={dataset_handle: training_handle})
            except tf.errors.OutOfRangeError:
                break  # End of dataset
            epoch_loss_total += float(loss_value)
            epoch_batches += 1
            steps += config.batch_size
            # Overwrite the previous progress message in place.
            if previous_print_length > 0:
                print('\b' * previous_print_length, end='', flush=True)
            message = (f"Epoch: {epoch} Step: {steps} Step Loss:{loss_value} "
                       f"Epoch Loss: {epoch_loss_total / epoch_batches}")
            log_to_file(log_path, message)
            previous_print_length = len(message)
            print(message, end='', flush=True)
            if steps < train_size:
                continue

            # Epoch boundary: checkpoint, end the progress line, and
            # optionally validate.
            saver.save(sess, checkpoint_path)
            print(flush=True)
            if config.validate or early_stop:
                sess.run(val_iterator.initializer)
                mean_distance, mean_val_loss = _mean_metrics(
                    sess, metric_ops, dataset_handle, validation_handle,
                    log_path)
                # Stop only when early stopping was requested and both
                # validation metrics got worse than in the previous epoch.
                if (early_stop
                        and prev_val_distance is not None
                        and prev_val_loss is not None
                        and prev_val_distance < mean_distance
                        and prev_val_loss < mean_val_loss):
                    keep_training = False
                prev_val_loss = mean_val_loss
                prev_val_distance = mean_distance
                log_message = f"Epoch: {epoch} Validation Loss: {mean_val_loss} Edit Distance: {mean_distance}"
                print(log_message, flush=True)
                log_to_file(log_path, log_message)
            epoch += 1
            steps = 0
            previous_print_length = 0
            epoch_loss_total = 0.0
            epoch_batches = 0

        saver.save(sess, checkpoint_path)
        mean_test_distance, mean_test_loss = _mean_metrics(
            sess, metric_ops, dataset_handle, test_handle, log_path)
        print(flush=True)
        log_message = f"Test Loss: {mean_test_loss} Edit Distance: {mean_test_distance}"
        print(log_message, flush=True)
        log_to_file(log_path, log_message)
    finally:
        # Close the underlying session even if training or evaluation raises.
        base_sess.close()
コード例 #4
0
from model import ModelFactory
import tensorflow as tf
import sys

if __name__ == "__main__":
    # Smoke check: build the graph of the model named on the command line on
    # top of a float32 placeholder of shape (None, 300, 1).
    # Usage: python <this script> MODEL_NAME
    if len(sys.argv) < 2:
        sys.exit("usage: {} MODEL_NAME".format(sys.argv[0]))
    placeholder = tf.placeholder(tf.float32, [None, 300, 1])
    # NOTE(review): elsewhere ModelFactory.get is called as
    # get(name, signal, config); confirm a call without config is supported.
    model = ModelFactory.get(sys.argv[1], placeholder)