예제 #1
 def _runIterAndCheck(self, lattice_sizes, expected_vertices):
   """Checks tools.lattice_indices_generator against expected vertices.

   Builds a tools.LatticeStructure from lattice_sizes, then for every
   (index, vertices) pair the generator yields, asserts via assertItemsEqual
   that the yielded vertices match expected_vertices[index].
   """
   structure = tools.LatticeStructure(lattice_sizes)
   pairs = tools.lattice_indices_generator(structure)
   for idx, actual in pairs:
     expected = expected_vertices[idx]
     self.assertItemsEqual(actual, expected)
예제 #2
def lattice_param_as_linear(lattice_sizes, output_dim, linear_weights=1.0):
    """Returns lattice parameter that represents a normalized linear function.

    For simplicity, let's assume output_dim == 1 (when output_dim > 1 you get
    output_dim lattices, one for each linear function). This function returns a
    lattice parameter so that

      lattice_param' * phi(x) = 1 / len(lattice_sizes) *
        (sum_k x[k] * linear_weights[k] / (lattice_sizes[k] - 1) + bias)

    where phi(x) is the lattice interpolation weight and
    bias = -sum_k linear_weights[k] / 2.

    The normalization in the weights and the bias term make the output lie in
    the range [-0.5, 0.5] when every member of linear_weights is 1.0.
    In addition, the bias term makes the expected value zero when x[k] is drawn
    from the uniform distribution over [0, lattice_sizes[k] - 1].

    The returned lattice_param can be used to initialize a lattice layer as a
    linear function.

    Args:
      lattice_sizes: (list or tuple of ints) A list of lattice sizes of each
        dimension.
      output_dim: (int) number of outputs.
      linear_weights: (float, list of floats, list of list of floats) linear
        function's weight terms. linear_weights[k][n] == kth output's nth
        weight.
        If float, then all the weights use one value as
        [[linear_weights] * len(lattice_sizes)] * output_dim.
        If list of floats, then len(linear_weights) == len(lattice_sizes) is
        expected, and the weights are [linear_weights] * output_dim, i.e., all
        output dimensions get the same linear_weights.

    Returns:
      List of list of floats with size (output_dim, number_of_lattice_param),
      one row per output.

    Raises:
      ValueError: * Any element in lattice_sizes is less than 2.
        * lattice_sizes is empty.
        * linear_weights is of an unsupported type (including an empty list),
          or the shape of linear_weights is not the expected one.
    """
    if not lattice_sizes:
        raise ValueError('lattice_sizes should not be empty')
    for lattice_size in lattice_sizes:
        if lattice_size < 2:
            # Wrap in a 1-tuple so a tuple-valued lattice_sizes is formatted as
            # a single value instead of being unpacked by the % operator.
            raise ValueError(
                'All elements in lattice_sizes are expected to greater '
                'than equal to 2, but got %s' % (lattice_sizes,))

    lattice_rank = len(lattice_sizes)

    # Normalize linear_weights into an (output_dim x lattice_rank) matrix.
    # Rows may alias each other; they are only read below, never mutated.
    if isinstance(linear_weights, float):
        linear_weight_matrix = [[linear_weights] * lattice_rank] * output_dim
    elif isinstance(linear_weights, list):
        # linear_weights is independent of lattice_sizes, so it may be empty;
        # check explicitly instead of letting linear_weights[0] raise IndexError.
        if not linear_weights:
            raise ValueError('linear_weights should not be empty')
        # Branch on the type of the first element.
        if isinstance(linear_weights[0], float):
            if len(linear_weights) != lattice_rank:
                raise ValueError(
                    'A number of elements in linear_weights (%d) != lattice rank (%d)'
                    % (len(linear_weights), lattice_rank))
            # Repeating same weights for all output_dim.
            linear_weight_matrix = [linear_weights] * output_dim
        elif isinstance(linear_weights[0], list):
            # 2d matrix case.
            if len(linear_weights) != output_dim:
                raise ValueError(
                    'A number of lists in linear_weights (%d) != output_dim (%d)'
                    % (len(linear_weights), output_dim))
            for linear_weight in linear_weights:
                if len(linear_weight) != lattice_rank:
                    raise ValueError(
                        'linear_weights contain more than one list whose length != '
                        'lattice rank(%d)' % lattice_rank)
            linear_weight_matrix = linear_weights
        else:
            raise ValueError(
                'Only list of float or list of list of floats are supported')
    else:
        raise ValueError(
            'Only float or list of float or list of list of floats are supported.')

    # Create lattice structure to enumerate (index, vertex) pairs.
    lattice_structure = tools.LatticeStructure(lattice_sizes)

    # Per-dimension denominators lattice_rank * (lattice_sizes[k] - 1); they do
    # not depend on the output row or the vertex, so compute them once.
    denominators = [float(lattice_rank * (size - 1)) for size in lattice_sizes]

    lattice_parameters = []
    for linear_weight_per_output in linear_weight_matrix:
        # Bias term -(sum_k w_k / 2) / lattice_rank, shared by every vertex.
        bias = -sum(linear_weight_per_output, 0.0) / (2.0 * lattice_rank)
        lattice_parameter = [bias] * lattice_structure.num_vertices
        for idx, vertex in tools.lattice_indices_generator(lattice_structure):
            for dim, denominator in enumerate(denominators):
                lattice_parameter[idx] += (
                    linear_weight_per_output[dim] * float(vertex[dim]) /
                    denominator)
        # One parameter row per output.
        lattice_parameters.append(lattice_parameter)

    return lattice_parameters