示例#1
0
 def dense_update(self, var, grad):
     """Build the dense momentum-apply op for ``var`` from ``grad``.

     The hyper-parameters come from this object's ``_lr``, ``_momentum``
     and ``_use_nesterov`` attributes. The target variable is identified
     by ``var.name`` and ``var.vtype``.

     Args:
         var: Variable to update; only its ``name`` and ``vtype``
             attributes are read here.
         grad: Gradient passed through unchanged to the op.

     Returns:
         Whatever ``xdl.ps_dense_apply_momentum_op`` returns (presumably an
         op to be executed on the parameter server -- confirm in xdl docs).
     """
     op_kwargs = dict(
         learning_rate=self._lr,
         momentum=self._momentum,
         grad=grad,
         var_name=var.name,
         var_type=var.vtype,
         use_nesterov=self._use_nesterov,
     )
     return xdl.ps_dense_apply_momentum_op(**op_kwargs)
 def test_all(self):
     """Apply the dense momentum op twice to a ones-initialized variable.

     The test registers and initializes a 4-element variable ``w`` filled
     with ones. It then runs the momentum op (lr=0.5, momentum=0.9,
     grad=[1, 2, 3, 4], no Nesterov) twice and checks ``w`` after each run.
     The expected numbers are consistent with the classic rule
     ``accum = momentum * accum + grad; w -= lr * accum``. That rule is
     derived from the values below; confirm it against the op's
     implementation.
     """
     var = xdl.Variable(name="w",
                        dtype=DataType.float,
                        shape=[4],
                        initializer=xdl.Ones())
     execute(xdl.variable_registers())
     execute(xdl.global_initializers())
     # np.float was only an alias for the builtin float (i.e. float64). It
     # was removed in NumPy 1.24; np.float64 keeps the exact same dtype.
     op = xdl.ps_dense_apply_momentum_op(
         learning_rate=np.array(0.5, dtype=np.float64),
         momentum=np.array(0.9, dtype=np.float64),
         grad=np.array([1, 2, 3, 4], dtype=np.float32),
         var_name="w",
         var_type="index",
         use_nesterov=False)

     # Step 1: accum = grad, so w = 1 - 0.5 * grad.
     execute(op)
     ret = execute(var.value)
     # Use a tolerance instead of exact equality. Values such as 0.9, 1.9
     # or -0.45 are not exactly representable in binary floating point, so
     # the results can differ from the float32 literals in the last bits.
     # assert_allclose also reports which elements differ on failure.
     np.testing.assert_allclose(
         ret, np.array([0.5, 0, -0.5, -1], dtype=np.float32),
         rtol=1e-6, atol=1e-6)

     # Step 2: accum = 0.9 * grad + grad = 1.9 * grad,
     # so w = step-1 value - 0.5 * 1.9 * grad.
     execute(op)
     ret = execute(var.value)
     np.testing.assert_allclose(
         ret, np.array([-0.45, -1.9, -3.35, -4.8], dtype=np.float32),
         rtol=1e-6, atol=1e-6)