Example #1
    def call(self, inputs):
        # Apply the small ("wee") kernel, then broadcast its bias over the batch.
        x = tf.matmul(a=inputs, b=self.wee_kernel)
        x += tf.broadcast_to(self.wee_bias, shape=(1,) + x.shape[1:])

        # Apply the big kernel via a batched dot product, then add its bias.
        x = batch_dot_product(x,
                              tf.expand_dims(self.big_kernel, axis=0),
                              axis=1)
        x += self.big_bias

        return x
Example #2
    # TODO: test on non-random input, to see whether simpleoption beats vanilla.
    def call(self, inputs):
        last_axis = len(inputs.shape) - 1

        # Mark inputs that contain a NaN along the last axis.
        nan_mask = has_nan(inputs, axis=last_axis)
        nan_mask = tf.expand_dims(nan_mask, axis=last_axis)

        # Needed: otherwise the weights become NaN at some point (tf bug?).
        inputs = tf.where(nan_mask, tf.zeros_like(inputs), inputs)

        # Apply kernel.
        x = tensor_dot(Along(inputs, [last_axis]), Along(self.kernel, [0]))

        x_shape = tf.shape(x)

        # Apply bias.
        x += tf.broadcast_to(self.bias, shape=x_shape)

        # Where the input contained NaN, emit the learned "none" representation.
        output_nan_mask = tf.broadcast_to(nan_mask, x_shape)
        broadcast_none_repr = tf.broadcast_to(self.none_repr, shape=x_shape)
        return tf.where(output_nan_mask, broadcast_none_repr, x)
Example #3
from typing import List

import tensorflow as tf


def broadcast_along(x: tf.Tensor, shape: List[int],
                    axes: List[int]) -> tf.Tensor:
    """Broadcast `x` to `shape`, inserting size-1 dimensions at the (sorted) `axes`."""
    reshape_shape = list(shape)  # copy, so the caller's list is not mutated
    j = 0  # index into `axes`
    k = 0  # index into `x.shape`
    for i in range(len(shape)):
        if j < len(axes) and i == axes[j]:
            reshape_shape[i] = 1  # new axis: reshape to size 1, then broadcast
            j += 1
        else:
            assert reshape_shape[i] == x.shape[k]
            k += 1

    x = tf.reshape(x, reshape_shape)
    return tf.broadcast_to(x, shape)
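
A small usage sketch (not from the original source), assuming the broadcast_along definition above: the non-axis dimensions of shape must match x.shape, and the listed axes become broadcast dimensions.

x = tf.reshape(tf.range(6, dtype=tf.float32), (2, 3))  # shape (2, 3)
y = broadcast_along(x, shape=[2, 5, 3], axes=[1])      # shape (2, 5, 3)
print(y.shape)  # each slice y[:, i, :] equals x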
Example #4
from src.commons.imports import tf
from src.commons.tensorflow.extention import with_noise
from src.commons.tensorflow.maker import NAN

ones = tf.ones(shape=(2, 3, 4))
twos = 2 * ones

# Bitwise AND requires integer tensors, so cast the float tensors first.
z = tf.bitwise.bitwise_and(tf.cast(ones, dtype=tf.dtypes.int32),
                           tf.cast(twos, dtype=tf.dtypes.int32))
print("z:", z)

print(ones[0, 1:2, 2:])

# Build a boolean mask and broadcast it over the last dimension.
mask = tf.constant(((True, False, False), (False, True, False)),
                   dtype=tf.dtypes.bool)
print(mask)
mask = tf.broadcast_to(tf.expand_dims(mask, axis=2), shape=(2, 3, 4))
print(mask)

# Take entries from `ones` where the mask is True, from `twos` elsewhere.
mixed = tf.where(mask, ones, twos)

print(mixed)

# Replace a random 20% of the entries with NaN, then reduce along the last axis.
noisy_ones = with_noise(noise=NAN, noise_proportion=0.2)(ones)
print(noisy_ones)
print(tf.reduce_min(noisy_ones, axis=2))
print(tf.reduce_max(noisy_ones, axis=2))
Example #5
def nans(shape: List[int], dtype: tf.dtypes.DType) -> tf.Tensor:
    """Return a tensor of the given shape and dtype filled with NaN."""
    return tf.broadcast_to(tf.cast(NAN, dtype), shape)
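
A brief usage sketch (not from the original source), assuming NAN is a scalar NaN constant such as the one imported from src.commons.tensorflow.maker in Example #4:

x = nans(shape=[2, 3], dtype=tf.float32)
print(x)                  # every entry is nan
print(tf.math.is_nan(x))  # every entry is True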
Example #6
    def closure(t: tf.Tensor) -> tf.Tensor:
        # `noise` and `noise_proportion` are captured from the enclosing scope.
        mask = bernoulli(t.shape, noise_proportion)
        broadcast_noise = tf.broadcast_to(noise, t.shape)
        return tf.where(mask, broadcast_noise, t)
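
For context, a minimal sketch (an assumption, not the repository's actual code) of how this closure could sit inside the with_noise factory called in Example #4; bernoulli is assumed here to return a boolean mask that is True with the given probability, and the with_noise signature follows the call site seen above.

import tensorflow as tf


def bernoulli(shape, probability: float) -> tf.Tensor:
    # Assumed helper: boolean mask, True with the given probability.
    return tf.random.uniform(shape) < probability


def with_noise(noise, noise_proportion: float):
    def closure(t: tf.Tensor) -> tf.Tensor:
        mask = bernoulli(t.shape, noise_proportion)
        broadcast_noise = tf.broadcast_to(noise, t.shape)
        return tf.where(mask, broadcast_noise, t)

    return closure


noisy = with_noise(noise=float("nan"), noise_proportion=0.2)(tf.ones((2, 3)))
print(noisy)  # roughly 20% of the entries are nan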