Example #1
def construct_train_fn(config, operations=None):
    """
    Function to construct the training function based on the config.

    Parameters
    ----------
    config: dict holding the model configuration.

    operations: optional list of additional preprocessing callables applied after decoding.

    Returns
    -------
    train_fn: callable which is passed to the train(...) call of an estimator.
        It prepares the dataset and returns it in a format suitable for the Estimator API.
    """

    if operations is None:
        operations = []

    cfg_train_ds = cutil.safe_get('training', config)

    # Create decode operation
    decode_op = construct_decode_op(config['features'])

    # Create unzip operation
    unzip_op = construct_unzip_op()

    # Build the preprocessing pipeline: decode first, then any caller-supplied
    # and configured operations, and finally unzip into (features, label).
    operations.insert(0, decode_op)
    if 'operations' in cfg_train_ds:
        for op in cfg_train_ds['operations']:
            operations.append(cutil.get_function(op['module'], op['name']))

    operations.append(unzip_op)
    preprocess = cutil.concatenate_functions(operations)

    def train_fn():
        """
        Function which is passed to .train(...) call of an estimator object.

        Returns
        -------
        dataset: tf.data.Dataset object with elements ({'f0': v0, ... 'fx': vx}, label).
        """
        # Load the dataset from the TFRecord file.
        dataset = tf.data.TFRecordDataset(cfg_train_ds['filename'])

        # Apply the preprocessing pipeline in parallel across the available CPU cores.
        dataset = dataset.map(preprocess, num_parallel_calls=os.cpu_count())

        # Peek at a single element to estimate its in-memory size.
        sample = tf.data.experimental.get_single_element(dataset.take(1))
        element_size = get_deep_size(sample)

        # Shuffle with a buffer sized to fill roughly half of system memory.
        buffer_size = tf.constant(
            int((virtual_memory().total / 2) / element_size), tf.int64)
        dataset = dataset.shuffle(buffer_size)

        # Batch, prefetch one batch ahead and repeat indefinitely for training.
        dataset = dataset.batch(config['batch'])
        dataset = dataset.prefetch(buffer_size=1)
        return dataset.repeat()

    return train_fn
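The train_fn above follows the standard tf.data input pipeline pattern: map a preprocessing function over the records, then shuffle, batch, prefetch and repeat. Below is a self-contained sketch of the same pattern on a synthetic in-memory dataset, for reference; the feature name 'f0' and the sizes are arbitrary assumptions rather than values from the project.

import os
import tensorflow as tf

def toy_train_fn():
    # Synthetic stand-in for the TFRecord dataset used above.
    values = tf.range(100)
    dataset = tf.data.Dataset.from_tensor_slices((values, values % 2))
    # Map raw elements into the ({'f0': v0, ...}, label) structure expected by an estimator.
    dataset = dataset.map(
        lambda x, y: ({'f0': tf.cast(x, tf.float32)}, y),
        num_parallel_calls=os.cpu_count())
    # Shuffle, batch, prefetch one batch ahead and repeat indefinitely.
    dataset = dataset.shuffle(buffer_size=100)
    dataset = dataset.batch(8)
    dataset = dataset.prefetch(buffer_size=1)
    return dataset.repeat()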
Example #2
def parse_component(inputs: dict, config: dict, outputs: dict):
    """
    Function to parse a dict holding the description for a component.
    A component is defined by an input and a number of layers.

    This function is meant to be called in the model function of a tf.estimator.Estimator and eases model creation.

    The input description is used to build the feature_column and input layer.
    The input is then extended with a batch dimension.

    Parameters
    ----------
    inputs: dict mapping from string to input tensor.

    config: dict holding the keys 'input' (input specification), 'layers' (the list of layer descriptions following the input) and 'output' (the name(s) under which the component output is stored).

    outputs: dict to which this component's output tensors are added.

    Returns
    -------
    layers: list(tf.layers.Layer), all layers added for this component.
            Layers not inheriting from tf.layers.Layer are passed as functions.

    variables: list(tf.Variable), list of all variables associated with the layers of this component.

    function: callable which performs a forward pass of features through the network.
    """

    layers = list()
    variables = list()
    funcs = list()

    # Determine the input shape(s) for the following layers.
    if not isinstance(config['input'], list):
        shape = inputs[config['input']].get_shape()
    else:
        shape = [inputs[key].get_shape() for key in config['input']]

    # Parse each layer specified in layers and append them to collections.
    for desc in config['layers']:
        layer, variable, function, shape = parse_layer(shape, desc)
        if layer is not None:
            layers.append(layer)
        if variable is not None:
            variables.append(variable)
        funcs.append(function)

    # Chain the per-layer functions into one forward pass and apply it to the input tensor.
    function = cutil.concatenate_functions(funcs)
    output_tensors = function(inputs[config['input']])

    # Multiple outputs: store each tensor in outputs under its configured name.
    if isinstance(config['output'], collections.abc.Iterable) and isinstance(
            output_tensors, tuple):
        for key, value in zip(config['output'], output_tensors):
            if isinstance(value, tf.Tensor):
                outputs.update({key: tf.identity(value, name=key)})
            else:
                outputs.update({key: value})
    else:
        outputs.update({
            config['output']:
            tf.identity(output_tensors, name=config['output'])
        })
    return layers, variables, function
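parse_component reads three keys from config. The dict below is a hypothetical example, purely to illustrate the expected structure; the layer descriptions are placeholders whose actual schema is defined by parse_layer.

component_config = {
    'input': 'image',            # key into the inputs dict (or a list of keys)
    'layers': [                  # parsed one by one by parse_layer
        {'type': 'conv2d'},      # placeholder descriptions, hypothetical keys
        {'type': 'dense'},
    ],
    'output': 'logits',          # name under which the result is stored in outputs
}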
Example #3
def main(argv):
    # Compose increment and double (helpers defined elsewhere in the example module) into one callable.
    incr_and_double = cutil.concatenate_functions([increment, double])
    print(incr_and_double(1))
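increment and double are assumed to be simple helper functions defined alongside this example. Below is a self-contained sketch of the same idea, with an assumed left-to-right composition helper standing in for cutil.concatenate_functions.

def concatenate_functions(functions):
    # Assumed behavior: apply the callables left to right, feeding each
    # result into the next one. The real cutil helper may differ in detail.
    def composed(value):
        for fn in functions:
            value = fn(value)
        return value
    return composed

def increment(x):
    return x + 1

def double(x):
    return x * 2

incr_and_double = concatenate_functions([increment, double])
print(incr_and_double(1))  # prints 4 under left-to-right composition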