Пример #1
0
    def __init__(self,
                 trainable,
                 pool_size,
                 fraction,
                 stddev=1.0,
                 initializer=tf.random_normal_initializer,
                 use_kronecker_product=False,
                 index_store_type=virtual_variable.IndexStoreType.basic):
        """Builds a `ProductVariablePool`.

        Args:
          trainable: boolean, whether the created variables are trainable.
          pool_size: int, total number of virtual variables required. The
            actual number of virtual variables created can be larger than
            the number specified by this argument.
          fraction: float in (0, 1], the fraction of `pool_size` of variables
            to create.
          stddev: float, standard deviation for the variable pool. Defaults to
            1.0.
          initializer: A tf initializer, e.g. `truncated_normal_initializer`
            or `random_normal_initializer`. Defaults to
            `tf.random_normal_initializer`.
          use_kronecker_product: boolean. When true the pool is built by
            `_create_kronecker_variables` (Kronecker product), otherwise by
            `_create_matmul_variables` (matrix product).
          index_store_type: IndexStoreType, handed to
            `virtual_variable.get_index_store` to pick the index store class.

        Raises:
          ValueError: if `fraction` is not in the interval (0, 1].
        """
        if fraction > 1.0 or fraction <= 0:
            raise ValueError('fraction %f must be >0 and <=1.0' % fraction)

        self._scope_name = 'ProductVariablePool'

        # Choose the factory that produces the core variables and their product.
        build_pool = (_create_kronecker_variables
                      if use_kronecker_product
                      else _create_matmul_variables)

        with tf.variable_scope(self._scope_name):
            core_vars, flat_size, product = build_pool(
                pool_size, fraction, initializer, stddev, trainable)

        self._core_variables = core_vars
        # Flatten the product into a 1-D pool of `flat_size` virtual variables.
        self._virtual_variables = tf.reshape(
            product, [flat_size], name='weight_pool')

        # The index store is sized by the flattened pool, not by `pool_size`.
        store_factory = virtual_variable.get_index_store(index_store_type)
        self._index_store = store_factory(flat_size)
Пример #2
0
    def __init__(self,
                 trainable,
                 stddev,
                 pool_size,
                 fraction,
                 initializer,
                 seed=HASH_POOL_SEED,
                 index_store_type=virtual_variable.IndexStoreType.basic):
        """Creates an instance of `HashVariablePool`.

        A single `hash` variable of `floor(fraction * pool_size)` entries is
        created, together with an index array of length `pool_size` whose
        entry `i` is `i % hash_size`. That array is handed to
        `_set_hash_indices`.

        Args:
          trainable: boolean, indicates whether the created variables are
            trainable or not.
          stddev: float, standard deviation for the variable pool; forwarded
            to `initializer` as its `stddev` keyword.
          pool_size: int, total number of virtual variables required.
          fraction: float in (0, 1], the fraction of `pool_size` of variables
            to create.
          initializer: A tf initializer class/factory, e.g.
            `truncated_normal_initializer` or `random_normal_initializer`. It
            is called as `initializer(stddev=stddev)`.
          seed: Integer. Accepted for interface compatibility but currently
            ignored: the index layout is deterministic.
          index_store_type: IndexStoreType, key of SUPPORTED_INDEX_STORES. The
            padding index store is not supported by HashVariablePool yet.

        Raises:
          ValueError: if `fraction` is not in (0, 1], if the resulting hash
            size is not positive, or if the selected index store is of the
            padding type.
        """
        del seed  # Unused: see the note on index generation below.
        if fraction <= 0 or fraction > 1.0:
            raise ValueError('fraction %f must be >0 and <=1.0' % fraction)
        self._scope_name = 'HashVariablePool'
        self._hash_indices = None

        hash_size = int(np.floor(fraction * pool_size))
        # `<= 0` (rather than `== 0`) also rejects a negative `pool_size`,
        # which would otherwise slip through and ask for a negative-sized
        # variable further down.
        if hash_size <= 0:
            raise ValueError(
                'fraction %f too low, results in 0 size hash for pool size %d.'
                % (fraction, pool_size))

        index_store_cls = virtual_variable.get_index_store(index_store_type)
        self._index_store = index_store_cls(pool_size)
        if self._index_store.type == virtual_variable.IndexStoreType.padding:
            raise ValueError(
                'HashVariablePool does not support PaddingIndexStore '
                'yet.')

        # Cycle 0..hash_size-1 often enough to cover the pool, then truncate.
        # hash_size * replicas >= pool_size + 1 because of the ceil, so the
        # slice below always yields exactly `pool_size` entries.
        replicas = int(np.ceil(float(pool_size + 1) / hash_size))
        indices = np.tile(np.arange(hash_size), replicas)[:pool_size]

        # NOTE: a seeded random permutation of the indices was sketched here
        # at some point but is disabled; the cyclic layout above is what is
        # used, which is why `seed` has no effect.
        self._set_hash_indices(indices)

        with tf.variable_scope(self._scope_name):
            self._hash = tf.get_variable(
                'hash', [hash_size],
                trainable=trainable,
                initializer=initializer(stddev=stddev))
            self._core_variables = [self._hash]