コード例 #1
0
def eval_nn(inputp, config, output, normalization, network_type, model_file):     \
            # pylint: disable=too-many-arguments
    """Evaluate a dataset with a neural network stored on disk.

    arguments:
    inputp -- path to the ROOT dataset
    config -- path to the branches config file
    output -- output name for the sqlite database (overwrites the 'test' table)
    normalization -- path to the txt file with normalization constants
    network_type -- network type identifier, forwarded unchanged to
                    _eval_dataset
    model_file -- path to the saved keras model (e.g. an .h5 file), loaded
                  with keras.models.load_model
    """

    # Register the single-component mixture-density loss under the name
    # 'loss' so load_model can resolve the custom loss referenced by the
    # saved model. NOTE(review): confirm nb_components matches the model
    # being loaded.
    model = keras.models.load_model(
        model_file,
        custom_objects={'loss': mixture_density_loss(nb_components=1)})

    # Input tree name is fixed to 'NNinput'; branch names (including meta
    # branches) come from the config file.
    _eval_dataset(model=model,
                  network_type=network_type,
                  path=inputp,
                  tree='NNinput',
                  branches=utils.get_data_config_names(config, meta=True),
                  norm=utils.load_normalization(normalization),
                  dbpath=output)
コード例 #2
0
ファイル: train.py プロジェクト: madved/phones-las
def input_fn(dataset_filename,
             vocab_filename,
             norm_filename=None,
             num_channels=39,
             batch_size=8,
             num_epochs=1,
             binf2phone=None,
             num_parallel_calls=32,
             max_frames=-1,
             max_symbols=-1):
    """Build the training input pipeline.

    Args:
        dataset_filename: dataset path passed to ``utils.read_dataset``.
        vocab_filename: vocabulary file used to build the lookup table.
        norm_filename: optional normalization file; when given, its
            means/stds are loaded with ``utils.load_normalization``,
            otherwise no normalization constants are passed on.
        num_channels: number of feature channels, passed to
            ``utils.read_dataset``.
        batch_size: forwarded to ``utils.process_dataset``.
        num_epochs: forwarded to ``utils.process_dataset``.
        binf2phone: optional table (presumably a pandas DataFrame) indexed
            by binary features with one column per symbol. When given,
            targets are float vectors of length ``len(binf2phone.index)``
            instead of strings.
        num_parallel_calls: forwarded to ``utils.process_dataset``.
        max_frames: forwarded to ``utils.process_dataset``.
        max_symbols: forwarded to ``utils.process_dataset``.

    Returns:
        The dataset returned by ``utils.process_dataset``.
    """
    binary_targets = binf2phone is not None
    labels_shape = [] if not binary_targets else len(binf2phone.index)
    labels_dtype = tf.string if not binary_targets else tf.float32
    dataset = utils.read_dataset(dataset_filename,
                                 num_channels,
                                 labels_shape=labels_shape,
                                 labels_dtype=labels_dtype)
    vocab_table = utils.create_vocab_table(vocab_filename)

    if norm_filename is not None:
        # Use the argument itself; the previous code read a global
        # ``args.norm``, which breaks when this function is imported/called
        # outside the CLI script.
        means, stds = utils.load_normalization(norm_filename)
    else:
        means = stds = None

    # With binary targets, SOS/EOS become their binary-feature vectors.
    sos = binf2phone[utils.SOS].values if binary_targets else utils.SOS
    eos = binf2phone[utils.EOS].values if binary_targets else utils.EOS

    dataset = utils.process_dataset(dataset,
                                    vocab_table,
                                    sos,
                                    eos,
                                    means,
                                    stds,
                                    batch_size,
                                    num_epochs,
                                    binary_targets=binary_targets,
                                    labels_shape=labels_shape,
                                    num_parallel_calls=num_parallel_calls,
                                    max_frames=max_frames,
                                    max_symbols=max_symbols)

    return dataset
コード例 #3
0
ファイル: infer.py プロジェクト: madved/phones-las
def input_fn(dataset_filename,
             vocab_filename,
             norm_filename=None,
             num_channels=39,
             batch_size=8,
             take=0,
             binf2phone=None):
    """Build the inference input pipeline.

    Args:
        dataset_filename: dataset path passed to ``utils.read_dataset``.
        vocab_filename: vocabulary file used to build the lookup table.
        norm_filename: optional normalization file; when given, its
            means/stds are loaded with ``utils.load_normalization``,
            otherwise no normalization constants are passed on.
        num_channels: number of feature channels, passed to
            ``utils.read_dataset``.
        batch_size: forwarded to ``utils.process_dataset``.
        take: if > 0, keep only the first ``take`` elements of the
            processed dataset (presumably batches).
        binf2phone: optional table (presumably a pandas DataFrame) indexed
            by binary features with one column per symbol. When given,
            targets are float vectors of length ``len(binf2phone.index)``
            instead of strings.

    Returns:
        The processed dataset, run for a single epoch in inference mode.
    """
    binary_targets = binf2phone is not None
    labels_shape = [] if not binary_targets else len(binf2phone.index)
    labels_dtype = tf.string if not binary_targets else tf.float32
    dataset = utils.read_dataset(dataset_filename,
                                 num_channels,
                                 labels_shape=labels_shape,
                                 labels_dtype=labels_dtype)
    vocab_table = utils.create_vocab_table(vocab_filename)

    if norm_filename is not None:
        # Use the argument, not a global ``args.norm``.
        means, stds = utils.load_normalization(norm_filename)
    else:
        means = stds = None

    # With binary targets, SOS/EOS become their binary-feature vectors.
    sos = binf2phone[utils.SOS].values if binary_targets else utils.SOS
    eos = binf2phone[utils.EOS].values if binary_targets else utils.EOS

    dataset = utils.process_dataset(dataset,
                                    vocab_table,
                                    sos,
                                    eos,
                                    means,
                                    stds,
                                    batch_size,
                                    1,  # a single pass over the data
                                    binary_targets=binary_targets,
                                    labels_shape=labels_shape,
                                    is_infer=True)

    # Use the ``take`` parameter; the previous code read a global
    # ``args.take`` and ignored the argument.
    if take > 0:
        dataset = dataset.take(take)
    return dataset
コード例 #4
0
def input_fn(dataset_filename,
             vocab_filename,
             norm_filename=None,
             num_channels=39,
             batch_size=8,
             num_epochs=1,
             num_parallel_calls=32,
             max_frames=-1,
             max_symbols=-1,
             take=0,
             is_infer=False):
    """Build an input pipeline from a dataset file.

    Args:
        dataset_filename: dataset path passed to ``read_dataset``.
        vocab_filename: vocabulary file used to build the lookup table.
        norm_filename: optional normalization file. It is used only if it
            exists on disk; a missing file silently disables normalization.
        num_channels: number of feature channels, passed to
            ``read_dataset``.
        batch_size: forwarded to ``process_dataset``.
        num_epochs: forwarded to ``process_dataset``.
        num_parallel_calls: forwarded to ``process_dataset``.
        max_frames: forwarded to ``process_dataset``.
        max_symbols: forwarded to ``process_dataset``.
        take: if > 0, keep only the first ``take`` elements of the
            processed dataset (presumably batches).
        is_infer: when True, ``is_infer=True`` is forwarded to
            ``process_dataset``.

    Returns:
        The processed dataset.
    """
    dataset = read_dataset(dataset_filename, num_channels)
    vocab_table = utils.create_vocab_table(vocab_filename)

    if norm_filename is not None and os.path.exists(norm_filename):
        means, stds = utils.load_normalization(norm_filename)
    else:
        means = stds = None

    sos = utils.SOS
    eos = utils.EOS

    # ``is_infer`` used to be accepted but silently ignored. Forward it only
    # when requested so the default (training) call to process_dataset is
    # exactly as before.
    extra_kwargs = {'is_infer': True} if is_infer else {}

    dataset = process_dataset(dataset,
                              vocab_table,
                              sos,
                              eos,
                              means,
                              stds,
                              batch_size,
                              num_epochs,
                              num_parallel_calls=num_parallel_calls,
                              max_frames=max_frames,
                              max_symbols=max_symbols,
                              **extra_kwargs)

    if take > 0:
        dataset = dataset.take(take)

    return dataset
コード例 #5
0
def input_fn(features,
             vocab_filename,
             norm_filename=None,
             num_channels=39,
             batch_size=8,
             ground_truth=None):
    """Build an inference pipeline from in-memory feature arrays.

    Args:
        features: indexable sequence of feature arrays. The last dimension
            of ``features[0]`` fixes the static feature size.
        vocab_filename: vocabulary file used to build the lookup table.
        norm_filename: optional normalization file; when given, its
            means/stds are loaded with ``utils.load_normalization``.
        num_channels: accepted for signature compatibility; not used here.
        batch_size: upper bound on the batch size (see note below).
        ground_truth: optional sequence paired element-wise with
            ``features``; when given, the generator yields
            ``(feature, truth)`` tuples typed ``(float32, string)``.

    Returns:
        The dataset returned by ``utils.process_dataset`` (single epoch,
        inference mode).
    """
    has_truth = ground_truth is not None
    feature_shape = tf.TensorShape([None, features[0].shape[-1]])

    if has_truth:
        output_types = (tf.float32, tf.string)
        output_shapes = (feature_shape,
                         tf.TensorShape([None, ground_truth[0].shape[-1]]))
    else:
        output_types = tf.float32
        output_shapes = feature_shape

    def gen():
        source = zip(features, ground_truth) if has_truth else features
        yield from source

    dataset = tf.data.Dataset.from_generator(gen, output_types, output_shapes)
    vocab_table = utils.create_vocab_table(vocab_filename)

    means, stds = (utils.load_normalization(norm_filename)
                   if norm_filename is not None else (None, None))

    # NOTE(review): this caps the batch size by features[0].shape[0] (the
    # first dimension of the first feature item), not by len(features);
    # confirm which was intended.
    effective_batch = min(features[0].shape[0], batch_size)
    single_epoch = 1

    return utils.process_dataset(dataset,
                                 vocab_table,
                                 utils.SOS,
                                 utils.EOS,
                                 means,
                                 stds,
                                 effective_batch,
                                 single_epoch,
                                 is_infer=True)
コード例 #6
0
def input_fn(features, vocab_filename, norm_filename=None):
    """Build a single-example-per-batch inference pipeline from features.

    Args:
        features: indexable sequence of float feature arrays. The last
            dimension of ``features[0]`` fixes the static feature size.
        vocab_filename: vocabulary file used to build the lookup table.
        norm_filename: optional normalization file; when given, its
            means/stds are loaded with ``utils.load_normalization``.

    Returns:
        The dataset returned by ``utils.process_dataset`` with batch size 1,
        one epoch, in inference mode.
    """
    element_shape = tf.TensorShape([None, features[0].shape[-1]])

    def gen():
        yield from features

    dataset = tf.data.Dataset.from_generator(gen, tf.float32, element_shape)
    vocab_table = utils.create_vocab_table(vocab_filename)

    means, stds = (utils.load_normalization(norm_filename)
                   if norm_filename is not None else (None, None))

    single_batch = 1
    single_epoch = 1
    return utils.process_dataset(dataset,
                                 vocab_table,
                                 utils.SOS,
                                 utils.EOS,
                                 means,
                                 stds,
                                 single_batch,
                                 single_epoch,
                                 is_infer=True)