コード例 #1
0
 def test_all(self):
     """Round-trip a PS checkpoint: save, mutate, restore, then verify.

     Creates a 4-element int32 variable ``w`` initialised to zeros and saves
     a checkpoint. It then adds ``[1, 2, 3, 4]`` to ``w`` through
     ``ps_assign_add_op`` and restores the checkpoint. The test expects the
     restore to bring ``w`` back to its saved, all-zero value.
     """
     # Save and restore must refer to the same checkpoint version; keep it in
     # one place so the two calls cannot drift apart.
     # NOTE(review): the restore_op helper encodes a string version with
     # _string_to_int8, whereas here a 0-d int8 array is passed directly --
     # confirm the op treats both forms as intended.
     ckpt_version = np.array(123, dtype=np.int8)

     var = xdl.Variable(name="w",
                        dtype=DataType.int32,
                        shape=[4],
                        initializer=xdl.Zeros())
     execute(xdl.variable_registers())
     execute(xdl.global_initializers())

     # Snapshot the freshly initialised (all-zero) variable.
     execute(xdl.ps_save_op(ckpt_version=ckpt_version))

     # Mutate the variable on the PS so the restore has something to undo.
     add_op = xdl.ps_assign_add_op(var_name="w",
                                   var_type="index",
                                   delta=np.array([1, 2, 3, 4],
                                                  dtype=np.int32))
     execute(add_op)
     # assert_array_equal reports which elements differ (and checks shape) on
     # failure, unlike assertTrue((a == b).all()), which only says
     # "False is not true".
     np.testing.assert_array_equal(execute(var.value), [1, 2, 3, 4])

     # Restoring the checkpoint should discard the addition above.
     execute(xdl.ps_restore_op(ckpt_version=ckpt_version))
     np.testing.assert_array_equal(execute(var.value), [0, 0, 0, 0])
コード例 #2
0
 def restore_op(self, version):
     """Build (but do not run) a PS restore op for checkpoint ``version``.

     Args:
         version: checkpoint identifier. It is converted with
             ``_string_to_int8`` before being handed to the op (presumably a
             string encoded as an int8 array -- confirm against that helper).

     Returns:
         The op returned by ``xdl.ps_restore_op``; the caller is responsible
         for executing it.
     """
     encoded_version = _string_to_int8(version)
     # Pass by keyword, matching the other ps_save_op/ps_restore_op call
     # sites, so the argument cannot bind to the wrong positional parameter.
     return xdl.ps_restore_op(ckpt_version=encoded_version)