for _ in range(dims.size(0)): data = data.sum(1) mask = torch.ones([data.size(0)], dtype=torch.uint8) dims = dims[:0] # empty tensor return data, mask, dims @torch.jit.script def batch_from_scalar_tensor(data): data = data.unsqueeze(0) mask = torch.ones([1], dtype=torch.uint8) dims = torch.zeros([0], dtype=torch.uint8) return data, mask, dims torch.register_batch_operator("tanh", batch_tanh.graph) torch.register_batch_operator("sigmoid", batch_sigmoid.graph) torch.register_batch_operator("relu", batch_relu.graph) torch.register_batch_operator("neg", batch_neg.graph) torch.register_batch_operator("neg", batch_neg_scalar.graph) torch.register_batch_operator("add", batch_add.graph) torch.register_batch_operator("add", batch_add_scalar.graph) torch.register_batch_operator("sub", batch_sub.graph) torch.register_batch_operator("sub", batch_sub_scalar.graph) torch.register_batch_operator("mul", batch_mul.graph) torch.register_batch_operator("mul", batch_mul_scalar.graph) torch.register_batch_operator("div", batch_div.graph) torch.register_batch_operator("matmul", batch_matmul.graph) torch.register_batch_operator("mm", batch_mm.graph) torch.register_batch_operator("fmod", batch_fmod.graph) torch.register_batch_operator("zeros_like", batch_zeros_like.graph)
# Batched ``select``: pick ``index`` along ``dim`` from a BatchTensor given
# as a (data, mask, dims) triple.
#
# ``dims`` holds one flag per non-batch dimension, so the flag for ``dim`` is
# ``dims[dim - 1]``. Both ``data`` and ``mask`` are reduced along ``dim``, and
# that flag is removed from the returned ``dims``.
#
# Precondition: ``dim >= 1`` -- the batch dimension (0) cannot be selected.
# With ``dim == 0``, ``dims[dim - 1]`` would silently read the LAST flag.
# The check is not enforced here (it was previously left as commented-out code).
@torch.jit.script
def batch_select(data, mask, dims, dim, index):
    data = data.select(dim, index)
    # Assumes the BatchTensor convention (confirm against BatchTensor
    # construction): a set flag marks a dynamic dimension, whose mask has the
    # same extent as data and must be indexed with ``index``. A clear flag
    # means the mask has size 1 along ``dim`` (shared by all positions), so
    # only index 0 is valid there. Selecting ``index`` on that size-1 mask
    # would go out of range for any index > 0.
    if dims[dim - 1]:
        mask = mask.select(dim, index)
    else:
        mask = mask.select(dim, 0)
    # Drop the flag of the dimension that was just selected away.
    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
    return data, mask, dims


# Batched ``where``: ``data`` is the condition and chooses, elementwise,
# between the (data, mask) pairs of the two branches.
# Assumes data, data1 and data2 have the same size.
# NOTE(review): the condition's own ``mask`` and ``dims`` are not used.
@torch.jit.script
def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
    res_data = torch.where(data, data1, data2)
    res_mask = torch.where(data, mask1, mask2)
    # A dimension is flagged in the result if it is flagged in either branch.
    # This must be an elementwise OR: Python's ``or`` would evaluate the truth
    # value of the whole ``dims1`` tensor (an error for more than one element)
    # and return one operand as-is instead of combining the flags.
    res_dims = dims1.__or__(dims2)
    return res_data, res_mask, res_dims


# Register the batched implementations with the JIT batching pass. The
# batch_tanh / batch_sigmoid / batch_add / batch_mul / batch_matmul / batch_mm
# graphs are defined elsewhere in this file.
# NOTE(review): several of these ops also appear to be registered elsewhere in
# this file; confirm the duplicate registration is intended.
torch.register_batch_operator("tanh", batch_tanh.graph)
torch.register_batch_operator("sigmoid", batch_sigmoid.graph)
torch.register_batch_operator("add", batch_add.graph)
torch.register_batch_operator("mul", batch_mul.graph)
torch.register_batch_operator("matmul", batch_matmul.graph)
torch.register_batch_operator("mm", batch_mm.graph)
torch.register_batch_operator("select", batch_select.graph)
torch.register_batch_operator("where", batch_where.graph)