def get_model_variables(getter,
                        name,
                        rename=None,
                        shape=None,
                        dtype=None,
                        initializer=None,
                        regularizer=None,
                        trainable=True,
                        collections=None,
                        caching_device=None,
                        partitioner=None,
                        use_resource=None,
                        **_):
    """Custom getter that routes variable access through `model_variable`.

    This ensures variables are retrieved in a consistent way for core layers.

    Args:
      getter: The wrapped getter; forwarded to `model_variable` as
        `custom_getter`.
      name: Full variable name, a '/'-separated path whose last component
        is treated as the variable's short name.
      rename: Optional dict of strings defining alteration of tensor names.
        If the short name is a key, the last path component is replaced by
        the mapped value; otherwise `name` is used unchanged.
      shape, dtype, initializer, regularizer, trainable, collections,
      caching_device, partitioner, use_resource: Forwarded unchanged to
        `tf_variables.model_variable`.
      **_: Any further keyword arguments are accepted and ignored.

    Returns:
      The result of `tf_variables.model_variable`.
    """
    # `rpartition` splits off only the final path component; when `name`
    # has no '/', both `scope` and `slash` are empty strings.
    scope, slash, leaf = name.rpartition('/')
    if rename and leaf in rename:
        name = scope + slash + rename[leaf]

    return tf_variables.model_variable(
        name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=regularizer,
        collections=collections,
        trainable=trainable,
        caching_device=caching_device,
        partitioner=partitioner,
        custom_getter=getter,
        use_resource=use_resource,
    )
# Esempio n. 2 (Example no. 2)
# 0
def _model_variable_getter(getter,
                           name,
                           shape=None,
                           dtype=None,
                           initializer=None,
                           regularizer=None,
                           trainable=True,
                           collections=None,
                           caching_device=None,
                           partitioner=None,
                           rename=None,
                           use_resource=None,
                           **_):
    """Getter that uses model_variable for compatibility with core layers."""
    name_components = name.split('/')
    short_name = name_components[-1]
    if rename and short_name in rename:
        name_components[-1] = rename[short_name]
        name = '/'.join(name_components)
    return tf_variables.model_variable(name,
                                       shape=shape,
                                       dtype=dtype,
                                       initializer=initializer,
                                       regularizer=regularizer,
                                       collections=collections,
                                       trainable=trainable,
                                       caching_device=caching_device,
                                       partitioner=partitioner,
                                       custom_getter=getter,
                                       use_resource=use_resource)