def constrained_tensors(constraint_fn, shape, dtype=np.float32, elements=None):
  """Build a strategy that draws Tensors mapped into a constrained space.

  Arrays of the requested `shape` and `dtype` are drawn element-by-element
  from `elements` and then pushed through `constraint_fn`.

  Args:
    constraint_fn: Function taking a Tensor in the unconstrained space and
      returning the corresponding Tensor in the desired constrained space.
    shape: Shape of the drawn Tensors, as a Python list.
    dtype: Dtype of the drawn Tensors.
    elements: Optional strategy for drawing individual array elements. When
      omitted, finite floats in [-200, 200] at the dtype's bit width are used
      for floating dtypes, and `hps.booleans()` for boolean dtypes.

  Returns:
    tensors: A strategy drawing constrained Tensors of the given shape.

  Raises:
    NotImplementedError: If `elements` is None and `dtype` is neither a
      floating-point nor a boolean dtype.
    AssertionError: (When drawing, eager mode only.) If `constraint_fn`
      produces a non-finite floating-point value.
  """
  # TODO(bjp): Allow a wider range of floats, e.g. finite floats spanning
  # [finfo(float32).min / 2, finfo(float32).max / 2] instead of [-200, 200].
  if elements is None:
    if dtype_util.is_floating(dtype):
      bit_width = 8 * np.dtype(dtype_util.as_numpy_dtype(dtype)).itemsize
      elements = hps.floats(
          -200, 200, allow_nan=False, allow_infinity=False, width=bit_width)
    elif dtype_util.is_bool(dtype):
      elements = hps.booleans()
    else:
      raise NotImplementedError(dtype)

  def _apply_constraint(raw):
    """Map one drawn array into the constrained space, checking finiteness."""
    constrained = constraint_fn(tf.convert_to_tensor(raw, dtype_hint=dtype))
    # The finiteness check only runs eagerly; in graph mode it would be too
    # expensive.
    needs_check = (dtype_util.is_floating(constrained.dtype) and
                   tf.executing_eagerly())
    if needs_check:
      values = np.array(constrained)
      if not np.all(np.isfinite(values)):
        raise AssertionError(
            '{} generated non-finite param value: {}'.format(
                constraint_fn, values))
    return constrained

  raw_arrays = hpnp.arrays(dtype=dtype, shape=shape, elements=elements)
  return raw_arrays.map(_apply_constraint)
def expand_right_dims(x, broadcast=False):
  """Reshape `x` so it can broadcast against tensors of the output shape.

  The last dimension of `x` stays in place. The leading dimensions of `x` are
  brought up to at least `tf.size(y_ref_shape_left)` dimensions by prepending
  singletons. One trailing singleton is appended per entry of
  `y_ref_shape_right`.

  `y_ref_shape_left` and `y_ref_shape_right` are read from the enclosing
  scope. NOTE(review): presumably they are the output shape split around the
  axis being processed; confirm against the enclosing function.

  Args:
    x: Tensor to reshape.
    broadcast: Python `bool`. If `True`, the reshaped tensor is additionally
      broadcast to the shape formed by broadcasting the leading dims of `x`
      with `y_ref_shape_left`, then the last dim of `x`, then
      `y_ref_shape_right`.

  Returns:
    The reshaped, and optionally broadcast, tensor.
  """
  x_shape = tf.shape(x)
  x_leading_shape = x_shape[:-1]
  x_last_dim = x_shape[-1:]

  # Broadcasting against a vector of ones left-pads `x_leading_shape` with 1s
  # up to the rank of `y_ref_shape_left` without changing existing dims.
  padded_leading_shape = tf.broadcast_dynamic_shape(
      x_leading_shape,
      tf.ones([tf.size(y_ref_shape_left)], dtype=tf.int32))
  trailing_ones = tf.ones([tf.size(y_ref_shape_right)], dtype=tf.int32)
  reshaped = tf.reshape(
      x, tf.concat((padded_leading_shape, x_last_dim, trailing_ones), axis=0))
  if not broadcast:
    return reshaped

  target_shape = tf.concat(
      (tf.broadcast_dynamic_shape(x_leading_shape, y_ref_shape_left),
       x_last_dim,
       y_ref_shape_right),
      axis=0)
  if dtype_util.is_bool(x.dtype):
    # An all-False mask of the target shape: OR-ing it in broadcasts `reshaped`
    # while leaving every boolean value unchanged.
    return reshaped | tf.cast(tf.zeros(target_shape), tf.bool)
  # Adding a zero tensor of the target shape broadcasts numeric tensors.
  return reshaped + tf.zeros(target_shape, dtype=x.dtype)
def expand_ends(x, broadcast=False):
  """Reshape `x` so it can broadcast against tensors of the output shape.

  Assumes the output shape is `A + x.shape + B` with `rank(A) == axis`. `x` is
  given `axis` leading singleton dimensions and `tf.size(y_ref_shape_right)`
  trailing singleton dimensions, so its own dimensions line up with the
  middle block of the output.

  `axis`, `y_ref_shape_left` and `y_ref_shape_right` are read from the
  enclosing scope. NOTE(review): presumably `y_ref_shape_left` has `axis`
  entries (i.e. it is `A`); confirm against the enclosing function.

  Args:
    x: Tensor to reshape.
    broadcast: Python `bool`. If `True`, the reshaped tensor is additionally
      broadcast to `y_ref_shape_left + shape(x) + y_ref_shape_right`.

  Returns:
    The reshaped, and optionally broadcast, tensor.
  """
  x_shape = tf.shape(x)
  # Pad the shape *vector* with 1s: `axis` entries on the left and one per
  # entry of `y_ref_shape_right` on the right.
  singleton_padded_shape = tf.pad(
      tensor=x_shape,
      paddings=[[axis, tf.size(y_ref_shape_right)]],
      constant_values=1)
  reshaped = tf.reshape(x, singleton_padded_shape)
  if not broadcast:
    return reshaped

  full_shape = tf.concat(
      (y_ref_shape_left, x_shape, y_ref_shape_right), axis=0)
  if dtype_util.is_bool(x.dtype):
    # An all-False mask of the full shape: OR-ing it in broadcasts `reshaped`
    # while leaving every boolean value unchanged.
    return reshaped | tf.cast(tf.zeros(full_shape), tf.bool)
  # Adding a zero tensor of the full shape broadcasts numeric tensors.
  return reshaped + tf.zeros(full_shape, dtype=x.dtype)