Exemplo n.º 1
0
 def forward(self, data: Union[Tensor, List[Tensor]], state: Dict[str, Any]) -> Union[Tensor, List[Tensor]]:
     """Run the wrapped model on `data`, refreshing its recorded input spec once per epoch.

     Args:
         data: The input tensor(s) to feed through the model.
         state: Execution metadata; 'mode' and 'epoch' are read here.

     Returns:
         The model's output for `data`.
     """
     is_training = self.trainable and state['mode'] == "train"
     current_epoch = state['epoch']
     if self.epoch_spec != current_epoch:
         # Record the model input spec once per epoch (consumed by TensorBoard / Traceability)
         self.model.fe_input_spec = FeInputSpec(data, self.model)
         self.epoch_spec = current_epoch
     return feed_forward(self.model, data, training=is_training)
Exemplo n.º 2
0
 def update_gradient(gradient):
     """Exercise `gradient` on a fixed batch and verify the computed gradient values.

     NOTE(review): closure — relies on `self.tf_model` and `self.assertTrue`
     from the enclosing test method.
     """
     with tf.GradientTape(persistent=True) as tape:
         batch = tf.constant([[1.0, 1.0, 1.0], [1.0, -1.0, -0.5]])
         pred = feed_forward(self.tf_model, batch)
         output = gradient.forward(data=[pred], state={"tape": tape})
         expected = np.array([[2.0], [0.0], [0.5]], dtype="float32")
         self.assertTrue(is_equal(output[0][0].numpy(), expected))
Exemplo n.º 3
0
 def forward(self, data: Union[Tensor, List[Tensor]],
             state: Dict[str, Any]) -> Union[Tensor, List[Tensor]]:
     """Run the wrapped model, handling multi-input unpacking and intermediate outputs.

     Args:
         data: Input tensor(s); unpacked as positional args when `self.multi_inputs`.
         state: Execution metadata; 'mode' and 'epoch' are read here.

     Returns:
         The model output, extended with any captured intermediate outputs.
     """
     is_training = state['mode'] == "train" and self.trainable
     current_epoch = state['epoch']
     if self.epoch_spec != current_epoch and isinstance(self.model, torch.nn.Module):
         # Record the model input spec once per epoch (consumed by TensorBoard / Traceability)
         self.model.fe_input_spec = FeInputSpec(data, self.model)
         self.epoch_spec = current_epoch
     if self.multi_inputs:
         data = feed_forward(self.model, *data, training=is_training)
     else:
         data = feed_forward(self.model, data, training=is_training)
     extras = []
     for holder in self.intermediate_outputs:
         extras.append(_unpack_output(holder, self.device))
         # Clearing frees pytorch memory; tf tensors persist until the next forward
         holder.clear()
     if extras:
         data = to_list(data) + extras
     return data
Exemplo n.º 4
0
 def test_torch_model_input(self):
     """GradientOp should yield the expected gradient for a torch-model prediction."""
     gradient = GradientOp(finals="pred",
                           outputs="grad",
                           model=self.torch_model)
     with tf.GradientTape(persistent=True) as tape:
         batch = torch.tensor([[1.0, 1.0, 1.0], [1.0, -1.0, -0.5]])
         pred = feed_forward(model=self.torch_model, x=batch)
         output = gradient.forward(data=[pred], state={"tape": tape})
     expected = np.array([[2.0, 0.0, 0.5]], dtype="float32")
     self.assertTrue(is_equal(output[0][0].numpy(), expected))