def test_all(self):
    """Round-trip variable ``w`` through save, mutate, and restore.

    Creates an int32 variable ``w`` of shape ``[4]`` with a zeros initializer
    and saves a checkpoint. It then adds ``[1, 2, 3, 4]`` via
    ``ps_assign_add_op`` and checks the new value. Finally it restores the
    checkpoint and checks that the value is back to all zeros.
    """
    # A single version value is shared by save and restore, so both ops
    # refer to the same checkpoint.
    ckpt_version = np.array(123, dtype=np.int8)

    var = xdl.Variable(name="w", dtype=DataType.int32, shape=[4],
                       initializer=xdl.Zeros())
    execute(xdl.variable_registers())
    execute(xdl.global_initializers())

    execute(xdl.ps_save_op(ckpt_version=ckpt_version))

    add_op = xdl.ps_assign_add_op(
        var_name="w", var_type="index",
        delta=np.array([1, 2, 3, 4], dtype=np.int32))
    execute(add_op)
    # Unlike `(a == b).all()`, assert_array_equal also checks shape. The
    # former can pass by broadcasting against a wrongly-shaped result.
    # assert_array_equal also reports the differing elements on failure.
    np.testing.assert_array_equal(execute(var.value), np.array([1, 2, 3, 4]))

    execute(xdl.ps_restore_op(ckpt_version=ckpt_version))
    np.testing.assert_array_equal(execute(var.value), np.array([0, 0, 0, 0]))
def save_op(self, version):
    """Return an ``xdl.ps_save_op`` for the given checkpoint ``version``.

    ``version`` is first encoded with ``_string_to_int8``; presumably it is a
    string version name (TODO confirm against callers). The result is passed
    positionally as the op's first argument.
    """
    encoded_version = _string_to_int8(version)
    return xdl.ps_save_op(encoded_version)