Beispiel #1
0
def create_embedding_model(hparams, model_creator, scope=None):
    """Create the embedding graph, model, and iterator.

    Args:
        hparams: A ``tf_training.HParams`` whose ``mode`` equals
            ``ModeKeys.EMBEDDING``. ``batch_size`` and ``shuffle`` are read
            from it and forwarded to the iterator.
        model_creator: Callable invoked with the values of the model
            parameters (from ``param_utils.get_model_params``) unpacked as
            keyword arguments; its return value becomes the model.
        scope: Optional container name; ``"embedding"`` is used when falsy.

    Returns:
        An ``EmbeddingModel`` holding the graph, the model, the iterator and
        the feature/label placeholders.

    Raises:
        AssertionError: If ``hparams`` is not a ``tf_training.HParams`` or
            its mode is not ``ModeKeys.EMBEDDING``.
    """
    print("# Creating EmbeddingModel...")

    # Validate before hparams is dereferenced or any graph state is built, so
    # bad input fails immediately with a descriptive message. These stay as
    # assertions so the exception type seen by callers is unchanged.
    assert isinstance(hparams, tf_training.HParams), (
        "hparams must be a tf_training.HParams, got %s" %
        type(hparams).__name__)
    mode = hparams.get('mode')
    assert mode == ModeKeys.EMBEDDING, (
        "create_embedding_model requires mode ModeKeys.EMBEDDING, got %r" %
        (mode, ))

    batch_size = hparams.get('batch_size')
    shuffle = hparams.get('shuffle')

    graph = tf.Graph()
    with graph.as_default(), tf.container(scope or "embedding"):
        # Data is fed through placeholders: rank-3 float features and a
        # rank-1 vector of int64 labels.
        features_placeholder = tf.placeholder(shape=(None, None, None),
                                              dtype=tf.float32)
        labels_placeholder = tf.placeholder(shape=(None, ), dtype=tf.int64)

        iterator = DataSetIterator(features=features_placeholder,
                                   labels=labels_placeholder,
                                   batch_size=batch_size,
                                   shuffle=shuffle)

        model_params = param_utils.get_model_params(hparams, iterator)
        model = model_creator(**model_params.values())

    return EmbeddingModel(graph=graph,
                          model=model,
                          iterator=iterator,
                          features_placeholder=features_placeholder,
                          labels_placeholder=labels_placeholder)
Beispiel #2
0
def create_train_model(hparams, model_creator, scope=None):
    """Create the train graph, model, and iterator.

    Args:
        hparams: A ``tf_training.HParams`` whose ``mode`` is one of
            ``ModeKeys.TRAIN``, ``ModeKeys.TRAIN_CLASSIFIER`` or
            ``ModeKeys.FINE_TUNE``. ``batch_size`` and ``shuffle`` are read
            from it and forwarded to the iterator.
        model_creator: Callable invoked with the values of the model
            parameters (from ``param_utils.get_model_params``) unpacked as
            keyword arguments; its return value becomes the model.
        scope: Optional container name; ``"train"`` is used when falsy.

    Returns:
        A ``TrainModel`` holding the graph, the model, the iterator and the
        feature/label/skip-count placeholders.

    Raises:
        AssertionError: If ``hparams`` is not a ``tf_training.HParams`` or
            its mode is not one of the accepted training modes.
    """
    print("# Creating TrainModel...")

    # Validate before hparams is dereferenced or any graph state is built, so
    # bad input fails immediately with a descriptive message. These stay as
    # assertions so the exception type seen by callers is unchanged.
    assert isinstance(hparams, tf_training.HParams), (
        "hparams must be a tf_training.HParams, got %s" %
        type(hparams).__name__)
    mode = hparams.get('mode')
    assert mode in [
        ModeKeys.TRAIN, ModeKeys.TRAIN_CLASSIFIER, ModeKeys.FINE_TUNE
    ], "create_train_model does not support mode %r" % (mode, )

    batch_size = hparams.get('batch_size')
    shuffle = hparams.get('shuffle')

    graph = tf.Graph()
    with graph.as_default(), tf.container(scope or "train"):
        # Data is fed through placeholders: rank-3 float features, a rank-1
        # vector of int64 labels, and a scalar int64 skip count.
        features_placeholder = tf.placeholder(shape=(None, None, None),
                                              dtype=tf.float32)
        labels_placeholder = tf.placeholder(shape=(None, ), dtype=tf.int64)
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        iterator = DataSetIterator(features=features_placeholder,
                                   labels=labels_placeholder,
                                   skip_count=skip_count_placeholder,
                                   batch_size=batch_size,
                                   shuffle=shuffle)

        model_params = param_utils.get_model_params(hparams, iterator)
        model = model_creator(**model_params.values())

    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      features_placeholder=features_placeholder,
                      labels_placeholder=labels_placeholder,
                      skip_count_placeholder=skip_count_placeholder)
# NOTE(review): a function with this same name is defined earlier in this
# file; if both really live in one module, this later definition replaces it.
def create_train_model(hparams, model_creator, scope=None):
    """Create the train graph, model, and iterator from data files.

    Training and vocabulary file paths are assembled as
    ``<data_dir>/<prefix>.<suffix>`` from ``hparams``.

    Args:
        hparams: A ``tf_training.HParams`` providing ``data_dir``,
            ``train_prefix``, ``vocab_prefix``, ``src_suffix``,
            ``tgt_suffix``, ``batch_size`` and ``num_buckets``.
        model_creator: Callable invoked with the values of the model
            parameters (with ``mode`` set to ``ModeKeys.TRAIN``) unpacked as
            keyword arguments; its return value becomes the model.
        scope: Optional container name; ``"train"`` is used when falsy.

    Returns:
        A ``TrainModel`` holding the graph, the model, the iterator and the
        skip-count placeholder.

    Raises:
        AssertionError: If ``hparams`` is not a ``tf_training.HParams``.
    """
    print("# Creating TrainModel...")

    # Validate before any hparams attribute is read or graph state is built,
    # so a wrong argument type is reported as such. This stays an assertion
    # so the exception type seen by callers is unchanged.
    assert isinstance(hparams, tf_training.HParams), (
        "hparams must be a tf_training.HParams, got %s" %
        type(hparams).__name__)

    src_train_file = "%s/%s.%s" % (hparams.data_dir, hparams.train_prefix,
                                   hparams.src_suffix)
    tgt_train_file = "%s/%s.%s" % (hparams.data_dir, hparams.train_prefix,
                                   hparams.tgt_suffix)
    src_vocab_file = "%s/%s.%s" % (hparams.data_dir, hparams.vocab_prefix,
                                   hparams.src_suffix)
    tgt_vocab_file = "%s/%s.%s" % (hparams.data_dir, hparams.vocab_prefix,
                                   hparams.tgt_suffix)
    batch_size = hparams.batch_size
    num_buckets = hparams.num_buckets

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "train"):
        # Scalar int64 fed at run time and handed to the iterator.
        skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)

        vocabulary = Vocabulary(src_vocab_file=src_vocab_file,
                                tgt_vocab_file=tgt_vocab_file)

        iterator = TrainIterator(vocabulary=vocabulary,
                                 src_data_file=src_train_file,
                                 tgt_data_file=tgt_train_file,
                                 batch_size=batch_size,
                                 num_buckets=num_buckets,
                                 skip_count=skip_count_placeholder)

        model_params = get_model_params(hparams=hparams,
                                        vocabulary=vocabulary,
                                        iterator=iterator)
        # The mode is fixed here rather than taken from hparams.
        model_params.add_hparam('mode', ModeKeys.TRAIN)

        model = model_creator(**model_params.values())

    return TrainModel(graph=graph,
                      model=model,
                      iterator=iterator,
                      skip_count_placeholder=skip_count_placeholder)
Beispiel #4
0
def run_train(features, labels, hparams):
    """Build a ``VggishModelEst`` and train its classifier on in-memory data.

    Ensures ``hparams.out_dir`` and its ``ckpts`` subdirectory exist, builds
    the estimator wrapper from the model parameters, and trains it with a
    shuffled numpy input function.

    Args:
        features: Passed to the input function under the ``"features"`` key.
        labels: Passed to the input function as the targets ``y``.
        hparams: Supplies ``out_dir``, ``class_weights``, ``batch_size`` and
            ``num_epochs``, and is forwarded to
            ``param_utils.get_model_params``.
    """
    output_root = hparams.out_dir
    utils.ensure_path_exist(output_root)
    checkpoint_dir = os.path.join(output_root, "ckpts")
    utils.ensure_path_exist(checkpoint_dir)

    # No iterator is involved on this path, hence the explicit None.
    estimator_kwargs = param_utils.get_model_params(hparams, None).values()
    estimator = VggishModelEst(**estimator_kwargs)

    make_input_fn = tf.estimator.inputs.numpy_input_fn
    inputs = {"features": features, "class_weights": hparams.class_weights}
    train_input_fn = make_input_fn(x=inputs,
                                   y=labels,
                                   batch_size=hparams.batch_size,
                                   num_epochs=hparams.num_epochs,
                                   shuffle=True)

    estimator.classifier.train(input_fn=train_input_fn)
def create_eval_model(hparams, model_creator, scope=None):
    """Create the eval graph, model, source/target file placeholders, and iterator.

    Vocabulary file paths are assembled as ``<data_dir>/<vocab_prefix>.<suffix>``
    from ``hparams``; the evaluation data files are supplied at run time
    through scalar string placeholders.

    Args:
        hparams: A ``tf_training.HParams`` providing ``data_dir``,
            ``vocab_prefix``, ``src_suffix``, ``tgt_suffix``, ``batch_size``
            and ``num_buckets``.
        model_creator: Callable invoked with the values of the model
            parameters (with ``mode`` set to ``ModeKeys.EVAL``) unpacked as
            keyword arguments; its return value becomes the model.
        scope: Optional container name; ``"eval"`` is used when falsy.

    Returns:
        An ``EvalModel`` holding the graph, the model, both file
        placeholders and the iterator.

    Raises:
        AssertionError: If ``hparams`` is not a ``tf_training.HParams``.
    """
    print("# Creating EvalModel...")

    # Validate before any hparams attribute is read or graph state is built,
    # so a wrong argument type is reported as such. This stays an assertion
    # so the exception type seen by callers is unchanged.
    assert isinstance(hparams, tf_training.HParams), (
        "hparams must be a tf_training.HParams, got %s" %
        type(hparams).__name__)

    src_vocab_file = "%s/%s.%s" % (hparams.data_dir, hparams.vocab_prefix,
                                   hparams.src_suffix)
    tgt_vocab_file = "%s/%s.%s" % (hparams.data_dir, hparams.vocab_prefix,
                                   hparams.tgt_suffix)
    batch_size = hparams.batch_size
    num_buckets = hparams.num_buckets

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "eval"):
        # Scalar string placeholders handed to the iterator as its data files.
        src_eval_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
        tgt_eval_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)

        vocabulary = Vocabulary(src_vocab_file=src_vocab_file,
                                tgt_vocab_file=tgt_vocab_file)

        iterator = EvalIterator(vocabulary=vocabulary,
                                src_data_file=src_eval_file_placeholder,
                                tgt_data_file=tgt_eval_file_placeholder,
                                batch_size=batch_size,
                                num_buckets=num_buckets)

        model_params = get_model_params(hparams=hparams,
                                        vocabulary=vocabulary,
                                        iterator=iterator)
        # The mode is fixed here rather than taken from hparams.
        model_params.add_hparam('mode', ModeKeys.EVAL)

        model = model_creator(**model_params.values())

    return EvalModel(graph=graph,
                     model=model,
                     src_file_placeholder=src_eval_file_placeholder,
                     tgt_file_placeholder=tgt_eval_file_placeholder,
                     iterator=iterator)
def create_infer_model(hparams, model_creator, scope=None):
    """Create the inference graph, model, input placeholders, and iterator.

    Vocabulary file paths are assembled as ``<data_dir>/<vocab_prefix>.<suffix>``
    from ``hparams``; the source data and the batch size are supplied at run
    time through placeholders.

    Args:
        hparams: A ``tf_training.HParams`` providing ``data_dir``,
            ``vocab_prefix``, ``src_suffix`` and ``tgt_suffix``.
        model_creator: Callable invoked with the values of the model
            parameters (with ``mode`` set to ``ModeKeys.INFER``) unpacked as
            keyword arguments; its return value becomes the model.
        scope: Optional container name; ``"infer"`` is used when falsy.

    Returns:
        An ``InferModel`` holding the graph, the model, the source-data and
        batch-size placeholders, and the iterator.

    Raises:
        AssertionError: If ``hparams`` is not a ``tf_training.HParams``.
    """
    print("# Creating InferModel...")

    # Validate before any hparams attribute is read or graph state is built,
    # so a wrong argument type is reported as such. This stays an assertion
    # so the exception type seen by callers is unchanged.
    assert isinstance(hparams, tf_training.HParams), (
        "hparams must be a tf_training.HParams, got %s" %
        type(hparams).__name__)

    src_vocab_file = "%s/%s.%s" % (hparams.data_dir, hparams.vocab_prefix,
                                   hparams.src_suffix)
    tgt_vocab_file = "%s/%s.%s" % (hparams.data_dir, hparams.vocab_prefix,
                                   hparams.tgt_suffix)

    graph = tf.Graph()

    with graph.as_default(), tf.container(scope or "infer"):
        # Rank-1 string tensor of source inputs and a scalar int64 batch
        # size, both fed at run time.
        src_data_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
        batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)

        vocabulary = Vocabulary(src_vocab_file=src_vocab_file,
                                tgt_vocab_file=tgt_vocab_file)

        iterator = InferIterator(vocabulary=vocabulary,
                                 src_data=src_data_placeholder,
                                 batch_size=batch_size_placeholder)

        model_params = get_model_params(hparams=hparams,
                                        vocabulary=vocabulary,
                                        iterator=iterator)
        # The mode is fixed here rather than taken from hparams.
        model_params.add_hparam('mode', ModeKeys.INFER)

        model = model_creator(**model_params.values())

    return InferModel(graph=graph,
                      model=model,
                      src_data_placeholder=src_data_placeholder,
                      batch_size_placeholder=batch_size_placeholder,
                      iterator=iterator)