Beispiel #1
0
 def preprocess_for_eval(self, raw_input_batch: spec.Tensor,
                         raw_label_batch: spec.Tensor, train_mean: spec.Tensor,
                         train_stddev: spec.Tensor) -> spec.Tensor:
   """Flatten an eval batch to two dimensions and move it to DEVICE.

   Args:
     raw_input_batch: input tensor; its first dimension is kept and all
       remaining dimensions are merged into a single one.
     raw_label_batch: label tensor; moved to DEVICE without other changes.
     train_mean: accepted for interface compatibility, not used.
     train_stddev: accepted for interface compatibility, not used.

   Returns:
     An `(inputs, labels)` pair, both placed on DEVICE. NOTE(review): the
     signature annotates a single tensor, but a 2-tuple is what comes back.
   """
   del train_mean, train_stddev  # Statistics are not needed here.
   batch_size = raw_input_batch.size(0)
   # Inputs are moved first, then labels.
   flat_inputs = raw_input_batch.view(batch_size, -1).to(DEVICE)
   labels = raw_label_batch.to(DEVICE)
   return flat_inputs, labels
Beispiel #2
0
 def __call__(self, x: spec.Tensor, train: bool):
     """Apply a single-hidden-layer sigmoid MLP and return log-softmax output.

     Args:
       x: input batch; reshaped to `(x.shape[0], 28 * 28)` before the first
         dense layer.
       train: accepted for interface compatibility, not used.

     Returns:
       The `nn.log_softmax` of the second dense layer's 10-feature output.
     """
     del train  # No train/eval-specific behaviour in this model.
     flat_dim, hidden_dim, out_dim = 28 * 28, 128, 10
     flat = x.reshape((x.shape[0], flat_dim))
     # The two Dense layers are created in the same order as before
     # (hidden first, output second).
     hidden = nn.sigmoid(nn.Dense(features=hidden_dim, use_bias=True)(flat))
     logits = nn.Dense(features=out_dim, use_bias=True)(hidden)
     return nn.log_softmax(logits)
Beispiel #3
0
 def forward(self, x: spec.Tensor):
     """Flatten every example in `x` and feed the result to `self.net`.

     Args:
       x: input batch; the first dimension is kept and the remaining
         dimensions are merged into one via `view`.

     Returns:
       Whatever `self.net` returns for the flattened batch.
     """
     batch_size = x.size(0)
     flattened = x.view(batch_size, -1)
     return self.net(flattened)
Beispiel #4
0
def _pytorch_to_jax(x: spec.Tensor):
  """Hand a PyTorch tensor over to JAX through a DLPack capsule.

  Args:
    x: tensor to convert; `contiguous()` is called on it before export.

  Returns:
    The result of `jax.dlpack.from_dlpack` applied to the exported capsule.
  """
  # Workaround for https://github.com/google/jax/issues/8082: export only
  # after making the tensor contiguous.
  contiguous = x.contiguous()
  capsule = torch.utils.dlpack.to_dlpack(contiguous)
  return jax.dlpack.from_dlpack(capsule)
 def preprocess_for_eval(self, raw_input_batch: spec.Tensor,
                         raw_label_batch: spec.Tensor, train_mean: spec.Tensor,
                         train_stddev: spec.Tensor) -> spec.Tensor:
   """Cast eval inputs with `.float()` and move inputs and labels to DEVICE.

   Args:
     raw_input_batch: input tensor; converted with `.float()` and then
       moved to DEVICE.
     raw_label_batch: label tensor; moved to DEVICE without other changes.
     train_mean: accepted for interface compatibility, not used.
     train_stddev: accepted for interface compatibility, not used.

   Returns:
     An `(inputs, labels)` pair, both placed on DEVICE. NOTE(review): the
     signature annotates a single tensor, but a 2-tuple is what comes back.
   """
   del train_mean, train_stddev  # Statistics are not needed here.
   # Inputs are moved first, then labels.
   inputs_on_device = raw_input_batch.float().to(DEVICE)
   labels_on_device = raw_label_batch.to(DEVICE)
   return inputs_on_device, labels_on_device