def lattice_layer(input_tensor,
                  lattice_sizes,
                  is_monotone=None,
                  output_dim=1,
                  interpolation_type='hypercube',
                  lattice_initializer=None,
                  l1_reg=None,
                  l2_reg=None,
                  l1_torsion_reg=None,
                  l2_torsion_reg=None,
                  l1_laplacian_reg=None,
                  l2_laplacian_reg=None):
  """Builds a lattice layer on top of `input_tensor`.

  Creates the lattice parameter variable, interpolates the input through it,
  and optionally builds a monotonicity projection op and a regularization term.

  Args:
    input_tensor: [batch_size, input_dim] tensor.
    lattice_sizes: List with the lattice size of each input dimension.
    is_monotone: A list of input_dim booleans, a single boolean, or None.
      None or False means no monotonicity constraint. A single True constrains
      every input coordinate to be non-decreasing. With a list,
      is_monotone[k] == True makes the output non-decreasing with respect to
      input_tensor[?, k].
    output_dim: Number of outputs.
    interpolation_type: 'hypercube' or 'simplex'.
    lattice_initializer: (Optional) Initializer for the lattice parameters, a
      2D tensor [output_dim, parameter_dim] where parameter_dim ==
      lattice_sizes[0] * ... * lattice_sizes[input_dim - 1]. When None, the
      lattice_param_as_linear initializer is used with a linear weight of 1.0
      for each monotone dimension and 0.0 for every other dimension.
    l1_reg: (float) L1 regularization amount.
    l2_reg: (float) L2 regularization amount.
    l1_torsion_reg: (float) L1 torsion regularization amount.
    l2_torsion_reg: (float) L2 torsion regularization amount.
    l1_laplacian_reg: (list of floats or float) L1 Laplacian regularization
      amount per dimension. A single float applies to all dimensions.
    l2_laplacian_reg: (list of floats or float) L2 Laplacian regularization
      amount per dimension. A single float applies to all dimensions.

  Returns:
    A tuple of:
    * output tensor of shape [batch_size, output_dim]
    * parameter tensor of shape [output_dim, parameter_dim]
    * None, or a list of projection ops that must be applied at each step (or
      every so many steps) to project the parameters back into the feasible
      (monotone) space.
    * None or a regularization loss, if regularization is configured.

  Raises:
    ValueError: for invalid parameters.
  """
  if interpolation_type not in _VALID_INTERPOLATION_TYPES:
    raise ValueError('interpolation_type should be one of {}'.format(
        _VALID_INTERPOLATION_TYPES))

  if lattice_initializer is None:
    # Default initializer: a linear function whose slope is 1 along each
    # monotone dimension and 0 along every other dimension.
    if is_monotone:
      is_monotone = tools.cast_to_list(
          is_monotone, len(lattice_sizes), 'is_monotone')
      slopes = [1.0 if flag else 0.0 for flag in is_monotone]
    else:
      slopes = [0.0] * len(lattice_sizes)
    lattice_initializer = lattice_param_as_linear(
        lattice_sizes, output_dim, linear_weights=slopes)

  variable_name = interpolation_type + '_lattice_parameters'
  lattice_params = variable_scope.get_variable(
      variable_name, initializer=lattice_initializer)

  lattice_output = lattice_ops.lattice(
      input_tensor,
      lattice_params,
      lattice_sizes,
      interpolation_type=interpolation_type)

  projection_ops = None
  with ops.name_scope('lattice_monotonic_projection'):
    if is_monotone:
      is_monotone = tools.cast_to_list(
          is_monotone, len(lattice_sizes), 'is_monotone')
      projected_params = monotone_lattice(
          lattice_params, lattice_sizes=lattice_sizes, is_monotone=is_monotone)
      # The projection is applied as an additive correction to the variable.
      correction = projected_params - lattice_params
      projection_ops = [lattice_params.assign_add(correction)]

  regularizer_amounts = {
      'l1_reg': l1_reg,
      'l2_reg': l2_reg,
      'l1_torsion_reg': l1_torsion_reg,
      'l2_torsion_reg': l2_torsion_reg,
      'l1_laplacian_reg': l1_laplacian_reg,
      'l2_laplacian_reg': l2_laplacian_reg,
  }
  with ops.name_scope('lattice_regularization'):
    regularization = regularizers.lattice_regularization(
        lattice_params, lattice_sizes, **regularizer_amounts)

  return (lattice_output, lattice_params, projection_ops, regularization)
def lattice_layer(input_tensor,
                  lattice_sizes,
                  is_monotone=None,
                  output_min=None,
                  output_max=None,
                  output_dim=1,
                  interpolation_type='hypercube',
                  lattice_initializer=None,
                  **regularizer_amounts):
  """Builds a lattice layer on top of `input_tensor`.

  Creates the lattice parameter variable, interpolates the input through it,
  and optionally builds a projection op (monotonicity and/or output bounds)
  and a regularization term.

  Args:
    input_tensor: [batch_size, input_dim] tensor.
    lattice_sizes: List with the lattice size of each input dimension.
    is_monotone: A list of input_dim booleans, a single boolean, or None.
      None or False means no monotonicity constraint. A single True constrains
      every input coordinate to be non-decreasing. With a list,
      is_monotone[k] == True makes the output non-decreasing with respect to
      input_tensor[?, k].
    output_min: Optional output lower bound.
    output_max: Optional output upper bound.
    output_dim: Number of outputs.
    interpolation_type: 'hypercube' or 'simplex'.
    lattice_initializer: (Optional) Initializer for the lattice parameters, a
      2D tensor [output_dim, parameter_dim] where parameter_dim ==
      lattice_sizes[0] * ... * lattice_sizes[input_dim - 1]. When None, the
      lattice_param_as_linear initializer is used with
      linear_weights=[1] * len(lattice_sizes).
    **regularizer_amounts: Keyword args of regularization amounts, forwarded
      to regularizers.lattice_regularization(). Keyword names should be among
      regularizers.LATTICE_ONE_DIMENSIONAL_REGULARIZERS or
      regularizers.LATTICE_MULTI_DIMENSIONAL_REGULARIZERS. Multi-dimensional
      regularizers take a float; one-dimensional regularizers take a float or
      a list of floats, where a single float applies to all dimensions.

  Returns:
    A tuple of:
    * output tensor of shape [batch_size, output_dim]
    * parameter tensor of shape [output_dim, parameter_dim]
    * None, or a list of projection ops that must be applied at each step (or
      every so many steps) to project the model to a feasible space: used for
      bounding the outputs or for imposing monotonicity.
    * None or a regularization loss, if regularization is configured.

  Raises:
    ValueError: for invalid parameters.
  """
  if interpolation_type not in _VALID_INTERPOLATION_TYPES:
    raise ValueError('interpolation_type should be one of {}'.format(
        _VALID_INTERPOLATION_TYPES))

  if lattice_initializer is None:
    # Default initializer: a linear function with unit slope in every
    # dimension.
    unit_slopes = [1.0] * len(lattice_sizes)
    lattice_initializer = lattice_param_as_linear(
        lattice_sizes, output_dim, linear_weights=unit_slopes)

  variable_name = interpolation_type + '_lattice_parameters'
  lattice_params = tf.compat.v1.get_variable(
      variable_name, initializer=lattice_initializer)

  lattice_output = lattice_ops.lattice(
      input_tensor,
      lattice_params,
      lattice_sizes,
      interpolation_type=interpolation_type)

  with tf.name_scope('lattice_monotonic_projection'):
    if not is_monotone and output_min is None and output_max is None:
      # Nothing constrains the parameters, so no projection is needed.
      projection_ops = None
    else:
      # Apply the constraints in a fixed order: monotonicity first, then the
      # lower bound, then the upper bound.
      feasible_params = lattice_params
      if is_monotone:
        is_monotone = tools.cast_to_list(
            is_monotone, len(lattice_sizes), 'is_monotone')
        feasible_params = monotone_lattice(
            feasible_params,
            lattice_sizes=lattice_sizes,
            is_monotone=is_monotone)
      if output_min is not None:
        feasible_params = tf.maximum(feasible_params, output_min)
      if output_max is not None:
        feasible_params = tf.minimum(feasible_params, output_max)
      # The projection is applied as an additive correction to the variable.
      correction = feasible_params - lattice_params
      projection_ops = [lattice_params.assign_add(correction)]

  with tf.name_scope('lattice_regularization'):
    regularization = regularizers.lattice_regularization(
        lattice_params, lattice_sizes, **regularizer_amounts)

  return (lattice_output, lattice_params, projection_ops, regularization)