示例#1
0
 def construct(self, tensor_a, tensor_b):
     # Cast both inputs to float16 and combine them with self.sub (presumably
     # a subtraction op). Forward the combined result to self.net together
     # with two keyword arguments. key1 is the cast first input; key2 is a
     # dict holding that same tensor under "key".
     half_a = F.mixed_precision_cast(mstype.float16, tensor_a)
     half_b = F.mixed_precision_cast(mstype.float16, tensor_b)
     diff = self.sub(half_a, half_b)
     nested = {"key": half_a}
     return self.net(diff, key1=half_a, key2=nested)
示例#2
0
 def construct(self, image, target, weight,
               scale=None, center=None, score=None, idx=None):
     # Run the backbone on `image`. Promote the backbone output, `target`
     # and `weight` to float32, then return self._loss_fn on those three.
     # `scale`, `center`, `score` and `idx` are accepted but not used here.
     features = self._backbone(image)
     pred_fp32 = F.mixed_precision_cast(mstype.float32, features)
     target_fp32 = F.mixed_precision_cast(mstype.float32, target)
     weight_fp32 = F.mixed_precision_cast(mstype.float32, weight)
     return self._loss_fn(pred_fp32, target_fp32, weight_fp32)
示例#3
0
    def construct(self, positions, forces, energy):
        # Evaluate self._network on `positions`. Also compute -1 times the
        # result of self.grad_op(self._network) at the same positions
        # (presumably the gradient, i.e. predicted forces).
        # Returns (eloss, floss, outputs, energy, foutputs, forces).
        # Each loss is the integer 0 when its loss function is None.
        net_out = self._network(positions)
        neg_grad = -1 * self.grad_op(self._network)(positions)

        if self.add_cast_fp32:
            # Labels go through mixed_precision_cast; the network output goes
            # through a plain F.cast. Everything ends up as float32.
            # NOTE(review): neg_grad is not cast here while `forces` is.
            # Confirm the force loss is meant to see mixed dtypes.
            forces = F.mixed_precision_cast(ms.float32, forces)
            energy = F.mixed_precision_cast(ms.float32, energy)
            net_out = F.cast(net_out, ms.float32)

        if self._energy_fn is not None:
            eloss = self._energy_fn(net_out, energy)
        else:
            eloss = 0

        if self._force_fn is not None:
            floss = self._force_fn(neg_grad, forces)
        else:
            floss = 0

        return eloss, floss, net_out, energy, neg_grad, forces
示例#4
0
def test_mixed_precision_cast():
    """Check that F.mixed_precision_cast turns a float32 Tensor into float16."""
    source = Tensor(np.ones((2, 3), dtype=np.float32))
    casted = F.mixed_precision_cast(mstype.float16, source)
    assert casted.dtype == mstype.float16
示例#5
0
 def construct(self, tensor_c, **kwargs):
     # Cast `tensor_c` to float16 and pass the whole `kwargs` dict through
     # mixed_precision_cast as well. Then accumulate, via self.add, the cast
     # input plus the cast kwargs["key1"] and kwargs["key2"]["key"] entries.
     half_c = F.mixed_precision_cast(mstype.float16, tensor_c)
     half_kwargs = F.mixed_precision_cast(mstype.float16, kwargs)
     partial_sum = self.add(half_c, half_kwargs["key1"])
     return self.add(partial_sum, half_kwargs["key2"]["key"])
示例#6
0
 def construct(self, data, label):
     # Run the backbone on `data`. Promote `label` and then the backbone
     # output to float32, and return self._loss_fn(output, label).
     logits = self._backbone(data)
     label_fp32 = F.mixed_precision_cast(mstype.float32, label)
     logits_fp32 = F.mixed_precision_cast(mstype.float32, logits)
     return self._loss_fn(logits_fp32, label_fp32)