def __init__(self, hidden_layers, input_size=784, num_classes=10):
  """Initializes the neural network.

  Args:
    hidden_layers: List of ints specifying the sizes of hidden layers. Could
      be empty.
    input_size: Length of the input array. The network receives the input
      image as a flattened 1-d array. Defaults to 784 (28*28), the default
      image size for MNIST.
    num_classes: The number of output classes. Defaults to 10.
  """
  hidden_layers = [input_size] + hidden_layers + [num_classes]
  self.weights = []
  self.biases = []
  for i in range(len(hidden_layers) - 1):
    # TODO(srbs): This is manually cast to float32 to avoid the cast in
    # np.dot since backprop fails for tf.cast op.
    self.weights.append(
        np.array(
            np.random.randn(hidden_layers[i + 1], hidden_layers[i]),
            copy=False,
            dtype=float32))
    self.biases.append(
        np.array(
            np.random.randn(hidden_layers[i + 1]), copy=False,
            dtype=float32))
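# Hedged usage sketch (not from this file): `Model` is a hypothetical name
# for the enclosing class, which this excerpt does not show. With two hidden
# layers of 128 and 32 units, the layer sizes become [784, 128, 32, 10], so
# the loop above produces weight matrices of shapes (128, 784), (32, 128)
# and (10, 32).
model = Model(hidden_layers=[128, 32])
assert [w.shape for w in model.weights] == [(128, 784), (32, 128), (10, 32)]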
def test_vmap_out_axes_leaf_types(self):
  with self.assertRaisesRegex(
      TypeError, r'vmap out_axes must be an int, None, or .*'):
    extensions.vmap(
        lambda x: x, out_axes=(tf_np.array([1., 2.]),))(
            tf_np.array([1., 2.]))
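# For contrast, a sketch of a well-formed call (assuming this vmap mirrors
# jax.vmap semantics): out_axes leaves must be ints or None, not arrays.
extensions.vmap(lambda x: x, out_axes=0)(tf_np.array([1., 2.]))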
def train(self, x, y, learning_rate=0.01):
  """Runs a single training pass.

  Args:
    x: 2-d array of size batch_size x image_size.
    y: 2-d array of size batch_size x num_classes in one-hot notation.
    learning_rate: The learning rate.
  """
  x = np.array(x, copy=False)
  y = np.array(y, copy=False)

  def mean_squared_error(x, y):
    diff = x - y
    return np.sum(diff * diff) / len(x)

  wb_tensors = [p.data for p in self.weights + self.biases]
  with tf.GradientTape() as g:
    g.watch(wb_tensors)
    loss = mean_squared_error(self.forward(x), y)
  gradients = g.gradient(loss.data, wb_tensors)
  new_weights_and_biases = []
  for v, dv in zip(self.weights + self.biases, gradients):
    new_weights_and_biases.append(v - learning_rate * dv)
  total_len = len(new_weights_and_biases)
  self.weights = new_weights_and_biases[:total_len // 2]
  self.biases = new_weights_and_biases[total_len // 2:]
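# Hedged sketch of a single training step on synthetic data. `Model` is the
# hypothetical class name assumed above; `onp` is plain NumPy, used here
# only to fabricate a batch. The tape watches the raw weight/bias tensors,
# so the update above is ordinary SGD on a mean-squared-error loss.
import numpy as onp

model = Model(hidden_layers=[128, 32])
x_batch = onp.random.rand(64, 784).astype(onp.float32)
y_batch = onp.eye(10, dtype=onp.float32)[onp.random.randint(0, 10, size=64)]
model.train(x_batch, y_batch, learning_rate=0.1)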
def testIndexDtypeError(self):
  # https://github.com/google/jax/issues/2795
  jnp.array(1)  # get rid of startup warning
  with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("error")
    jnp.zeros(5).at[::2].set(1)
    self.assertLen(w, 0)
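# A minimal sketch of the behavior this test guards: .at[...] updates are
# functional, returning a new array and leaving the operand untouched, and
# a plain slice index should raise no dtype warning.
x = jnp.zeros(5)
y = x.at[::2].set(1)  # y == [1., 0., 1., 0., 1.]; x is still all zeros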
def testBooleanIndexingWithEmptyResult(self):
  # based on a TensorFlow Probability test that started failing after #1623
  x = jnp.array([-1])
  mask = jnp.array([False])
  ans = x[mask]  # doesn't crash
  expected = onp.array([-1])[onp.array([False])]
  self.assertAllClose(ans, expected, check_dtypes=False)
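# Sketch of the indexing behavior under test: a boolean mask that selects no
# elements yields an empty array rather than an error. (This works only
# outside jit, since the result shape is data-dependent.)
jnp.array([1, 2, 3])[jnp.array([False, False, False])]  # -> empty array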
def test_vmap_mismatched_axis_sizes_error_message_issue_705(self):
  # https://github.com/google/jax/issues/705
  with self.assertRaisesRegex(
      ValueError, 'vmap must have at least one non-None value in in_axes'):
    # If the output is mapped, there must be a non-None in_axes
    extensions.vmap(lambda x: x, in_axes=None)(tf_np.array([1., 2.]))

  # Error is: TypeError: only integer scalar arrays can be converted to a
  # scalar index
  with self.assertRaisesRegex(
      ValueError,
      'vmap out_axes specification must be a tree prefix of the '
      'corresponding value.*'):
    extensions.vmap(lambda x: x, in_axes=0, out_axes=(2, 3))(
        tf_np.array([1., 2.]))
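# For contrast, a sketch of a well-formed vmap call (again assuming this
# vmap mirrors jax.vmap semantics): map over axis 0 of the first argument
# while broadcasting the second, which in_axes=(0, None) expresses.
extensions.vmap(lambda x, y: x + y, in_axes=(0, None))(
    tf_np.array([[1., 2.], [3., 4.]]), tf_np.array([10., 20.]))
# -> [[11., 22.], [13., 24.]]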
def sugar_fn(op, indexer, x, y):
  x = jnp.array(x)
  return {
      UpdateOps.UPDATE: x.at[indexer].set,
      UpdateOps.ADD: x.at[indexer].add,
      UpdateOps.MUL: x.at[indexer].mul,
      UpdateOps.MIN: x.at[indexer].min,
      UpdateOps.MAX: x.at[indexer].max,
  }[op](y)
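# Hedged usage sketch: dispatch an ADD through the table above, which is
# equivalent to jnp.array([0, 0, 0, 0]).at[1:3].add(5). `UpdateOps` is the
# enum defined elsewhere in this test module, not shown in this excerpt.
indexer = slice(1, 3)
sugar_fn(UpdateOps.ADD, indexer, [0, 0, 0, 0], 5)  # -> [0, 5, 5, 0]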
def evaluate(self, x, y):
  """Returns the number of correct predictions.

  Args:
    x: 2-d array of size batch_size x image_size.
    y: 2-d array of size batch_size x num_classes.

  Returns:
    A scalar, the number of correct predictions.
  """
  y_actual = np.argmax(y, axis=1)
  y_predicted = np.argmax(self.forward(x), axis=1)
  return int(
      np.sum(np.array(y_actual == y_predicted, copy=False, dtype=int32)))
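# Hedged sketch (hypothetical `model`, `x_test`, `y_test` names): since the
# method returns a count of correct predictions, accuracy is that count
# divided by the batch size.
accuracy = model.evaluate(x_test, y_test) / len(x_test)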
def f(x):
  if isinstance(x, (tf.Tensor, tf.IndexedSlices)):
    return tf_np.array(x, copy=False)
  else:
    return x
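# Minimal sketch of what f does: TF tensors (and IndexedSlices) are wrapped
# as tf-numpy ndarrays without a copy; anything else passes through as-is.
f(tf.constant([1., 2.]))  # -> tf_np ndarray view of the tensor
f(3.0)                    # -> 3.0, returned unchanged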