예제 #1
0
 def _runIterAndCheck(self, lattice_sizes, expected_vertices):
   """Checks every (index, vertex) pair yielded by lattice_indices_generator.

   Args:
     lattice_sizes: lattice sizes used to build the tools.LatticeStructure.
     expected_vertices: indexable by the generated index; each entry is
       compared with assertItemsEqual (order-insensitive) against the vertex
       the generator yields for that index.
   """
   structure = tools.LatticeStructure(lattice_sizes)
   generated = tools.lattice_indices_generator(structure)
   for idx, vertex in generated:
     expected = expected_vertices[idx]
     self.assertItemsEqual(vertex, expected)
예제 #2
0
def _is_real_weight(value):
    """Returns True if value is a real number (e.g. int or float) but not a bool."""
    import numbers  # Local import keeps this block self-contained.
    return isinstance(value, numbers.Real) and not isinstance(value, bool)


def _check_real_weights(weights):
    """Raises ValueError unless every entry of weights is a real number."""
    for weight in weights:
        if not _is_real_weight(weight):
            raise ValueError(
                'linear_weights should only contain numbers, but got %r' %
                (weight,))


def _linear_weight_matrix(linear_weights, lattice_rank, output_dim):
    """Validates linear_weights and expands it to output_dim rows of weights.

    Accepts the forms documented in lattice_param_as_linear.

    Returns:
      A list/tuple with output_dim rows, each a sequence of lattice_rank
      numbers. Rows may share the same underlying object, so callers must not
      mutate them.
    Raises:
      ValueError: if linear_weights has an unsupported type, is empty, contains
        a non-numeric entry, or has the wrong shape.
    """
    if _is_real_weight(linear_weights):
        return [[linear_weights] * lattice_rank] * output_dim
    if not isinstance(linear_weights, (list, tuple)):
        raise ValueError(
            'Only float or list of float or list of list of floats are supported.'
        )
    if not linear_weights:
        # The length of linear_weights is independent of lattice_sizes, so an
        # empty list must be rejected explicitly (indexing [0] would crash).
        raise ValueError('linear_weights should not be empty')

    # Branch on the first element: a number means a single weight vector
    # shared by all outputs; a list means one weight vector per output.
    first = linear_weights[0]
    if _is_real_weight(first):
        if len(linear_weights) != lattice_rank:
            raise ValueError(
                'A number of elements in linear_weights (%d) != lattice rank (%d)'
                % (len(linear_weights), lattice_rank))
        _check_real_weights(linear_weights)
        # Repeating same weights for all output_dim.
        return [linear_weights] * output_dim
    if isinstance(first, (list, tuple)):
        # 2d matrix case.
        if len(linear_weights) != output_dim:
            raise ValueError(
                'A number of lists in linear_weights (%d) != output_dim (%d)'
                % (len(linear_weights), output_dim))
        for linear_weight in linear_weights:
            if not isinstance(linear_weight, (list, tuple)):
                raise ValueError(
                    'Only list of float or list of list of floats are supported')
            if len(linear_weight) != lattice_rank:
                raise ValueError(
                    'linear_weights contain more than one list whose length != '
                    'lattice rank(%d)' % lattice_rank)
            _check_real_weights(linear_weight)
        return linear_weights
    raise ValueError(
        'Only list of float or list of list of floats are supported')


def lattice_param_as_linear(lattice_sizes, output_dim, linear_weights=1.0):
    """Returns lattice parameter that represents a normalized linear function.

    For simplicity, let's assume output_dim == 1 (when output_dim > 1 you get
    output_dim lattices, one for each linear function). This function returns
    a lattice parameter so that

      lattice_param' * phi(x) = 1 / len(lattice_sizes) *
        (sum_k x[k] * linear_weights[k] / (lattice_sizes[k] - 1) + bias)

    where phi(x) is the lattice interpolation weight and
    bias = -sum_k linear_weights[k] / 2.

    The normalization in the weights and the bias term make the output lie in
    the range [-0.5, 0.5] when every member of linear_weights is 1.0.
    In addition, the bias term makes the expected value zero when x[k] is from
    the uniform distribution over [0, lattice_sizes[k] - 1].

    The returned lattice_param can be used to initialize a lattice layer as a
    linear function.

    Args:
      lattice_sizes: (list or tuple of ints) Lattice size of each dimension.
      output_dim: (int) Number of outputs.
      linear_weights: (number, list of numbers, or list of lists of numbers)
        The linear function's weights; linear_weights[k][n] is the k-th
        output's n-th weight. Any real number (e.g. int or float) is accepted
        as a weight, but bools are rejected. Tuples may be used in place of
        lists.
        If a number, every weight uses that value, i.e. the weights are
        [[linear_weights] * len(lattice_sizes)] * output_dim.
        If a flat list, len(linear_weights) == len(lattice_sizes) is expected
        and the weights are [linear_weights] * output_dim, i.e. all outputs
        share the same linear_weights.
        If a list of lists, it must have output_dim rows, each of length
        len(lattice_sizes).

    Returns:
      List of lists of floats with shape (output_dim, number_of_lattice_param).

    Raises:
      ValueError: * lattice_sizes is empty.
        * Any element in lattice_sizes is less than 2.
        * linear_weights has an unsupported type, is empty, contains a
          non-numeric entry, or does not have the expected shape.
    """
    if not lattice_sizes:
        raise ValueError('lattice_sizes should not be empty')
    for lattice_size in lattice_sizes:
        if lattice_size < 2:
            # Wrap in a 1-tuple so a tuple-valued lattice_sizes is formatted
            # as one value instead of being unpacked into the % arguments.
            raise ValueError(
                'All elements in lattice_sizes are expected to greater '
                'than equal to 2, but got %s' % (lattice_sizes,))

    lattice_rank = len(lattice_sizes)
    linear_weight_matrix = _linear_weight_matrix(
        linear_weights, lattice_rank, output_dim)

    # Create lattice structure to enumerate (index, vertex) pairs.
    lattice_structure = tools.LatticeStructure(lattice_sizes)

    lattice_parameters = []
    for linear_weight_per_output in linear_weight_matrix:
        # Bias term -sum_k w[k] / 2, pre-divided by lattice_rank like the slope
        # terms. Accumulated sequentially (not sum()/math.fsum) so float
        # results stay bit-identical to earlier versions.
        sum_of_weights = 0.0
        for weight in linear_weight_per_output:
            sum_of_weights += weight
        sum_of_weights /= (2.0 * lattice_rank)
        lattice_parameter = [-sum_of_weights] * lattice_structure.num_vertices
        for (idx,
             vertex) in tools.lattice_indices_generator(lattice_structure):
            for dim in range(lattice_rank):
                # Slope term: w[dim] * x[dim] / (rank * (size[dim] - 1)).
                lattice_parameter[idx] += (linear_weight_per_output[dim] *
                                           float(vertex[dim]) /
                                           float(lattice_rank *
                                                 (lattice_sizes[dim] - 1)))
        lattice_parameters.append(lattice_parameter)

    return lattice_parameters