def preprocess_for_eval(
    self,
    raw_input_batch: spec.Tensor,
    raw_label_batch: spec.Tensor,
    train_mean: spec.Tensor,
    train_stddev: spec.Tensor) -> "tuple[spec.Tensor, spec.Tensor]":
  """Flatten an input batch and move inputs and labels to the device.

  The original annotation said `-> spec.Tensor`, but the function returns a
  2-tuple of tensors; the annotation is corrected here (kept as a string so
  it is never evaluated at runtime).

  Args:
    raw_input_batch: input batch; all non-batch dims are flattened into one
      feature dim (presumably image height/width — TODO confirm with caller).
    raw_label_batch: batch of labels, moved to DEVICE unchanged.
    train_mean: unused; accepted for interface compatibility.
    train_stddev: unused; accepted for interface compatibility.

  Returns:
    Tuple of (flattened input batch, label batch), both on DEVICE.
  """
  # Normalization stats are part of the shared interface but not applied here.
  del train_mean
  del train_stddev
  batch_size = raw_input_batch.size()[0]
  # Collapse all trailing dimensions into a single feature dimension.
  raw_input_batch = raw_input_batch.view(batch_size, -1)
  return (raw_input_batch.to(DEVICE), raw_label_batch.to(DEVICE))
def __call__(self, x: spec.Tensor, train: bool):
  """Apply a two-layer sigmoid MLP; returns per-class log-probabilities.

  `train` is accepted for interface compatibility but unused — the model has
  no train-time-only behavior (no dropout or batch norm).
  """
  del train
  input_size = 28 * 28
  num_hidden = 128
  num_classes = 10
  # Flatten each example to a (batch, 784) matrix before the dense stack.
  flat = x.reshape((x.shape[0], input_size))
  # NOTE: Dense modules must be created in this order so Flax assigns the
  # same parameter names (Dense_0, Dense_1) as before.
  hidden = nn.sigmoid(nn.Dense(features=num_hidden, use_bias=True)(flat))
  logits = nn.Dense(features=num_classes, use_bias=True)(hidden)
  return nn.log_softmax(logits)
def forward(self, x: spec.Tensor):
  """Flatten the batch and run it through the wrapped network `self.net`."""
  batch_size = x.size()[0]
  # Collapse all non-batch dims into one feature dim; `view` requires the
  # input to be contiguous, matching the original behavior exactly.
  flattened = x.view(batch_size, -1)
  return self.net(flattened)
def _pytorch_to_jax(x: spec.Tensor):
  """Convert a PyTorch tensor to a JAX array via the DLPack protocol."""
  # DLPack export requires a contiguous layout; without this the conversion
  # can produce wrong results. See https://github.com/google/jax/issues/8082
  contiguous = x.contiguous()
  capsule = torch.utils.dlpack.to_dlpack(contiguous)
  return jax.dlpack.from_dlpack(capsule)
def preprocess_for_eval(
    self,
    raw_input_batch: spec.Tensor,
    raw_label_batch: spec.Tensor,
    train_mean: spec.Tensor,
    train_stddev: spec.Tensor) -> "tuple[spec.Tensor, spec.Tensor]":
  """Cast the input batch to float32 and move inputs and labels to the device.

  The original annotation said `-> spec.Tensor`, but the function returns a
  2-tuple of tensors; the annotation is corrected here (kept as a string so
  it is never evaluated at runtime).

  Args:
    raw_input_batch: input batch; cast to float32 and moved to DEVICE.
    raw_label_batch: batch of labels, moved to DEVICE unchanged.
    train_mean: unused; accepted for interface compatibility.
    train_stddev: unused; accepted for interface compatibility.

  Returns:
    Tuple of (float32 input batch, label batch), both on DEVICE.
  """
  # Normalization stats are part of the shared interface but not applied here.
  del train_mean
  del train_stddev
  return (raw_input_batch.float().to(DEVICE), raw_label_batch.to(DEVICE))