예제 #1
0
    for _ in range(dims.size(0)):
        data = data.sum(1)
    mask = torch.ones([data.size(0)], dtype=torch.uint8)
    dims = dims[:0]  # empty tensor
    return data, mask, dims


@torch.jit.script
def batch_from_scalar_tensor(data):
    # Lift a single tensor into the (data, mask, dims) triple used by the
    # other batch_* functions, with a batch of size one.
    # Presumably called with a 0-dim (scalar) tensor, given the name; the code
    # itself prepends a size-1 dimension to whatever tensor it receives.
    #
    # Returns:
    #   data - the input with a new leading dimension of size 1
    #   mask - uint8 tensor [1] filled with ones
    #   dims - empty uint8 tensor (no per-dimension flags)
    no_dims = torch.zeros([0], dtype=torch.uint8)
    single_mask = torch.ones([1], dtype=torch.uint8)
    batched = torch.unsqueeze(data, 0)
    return batched, single_mask, no_dims


# Register each scripted batch implementation's graph under its operator
# name, in this exact order. "neg", "add", "sub" and "mul" are registered
# twice: once with the tensor implementation and once with the *_scalar one.
for _op_name, _batch_fn in (
    ("tanh", batch_tanh),
    ("sigmoid", batch_sigmoid),
    ("relu", batch_relu),
    ("neg", batch_neg),
    ("neg", batch_neg_scalar),
    ("add", batch_add),
    ("add", batch_add_scalar),
    ("sub", batch_sub),
    ("sub", batch_sub_scalar),
    ("mul", batch_mul),
    ("mul", batch_mul_scalar),
    ("div", batch_div),
    ("matmul", batch_matmul),
    ("mm", batch_mm),
    ("fmod", batch_fmod),
    ("zeros_like", batch_zeros_like),
):
    torch.register_batch_operator(_op_name, _batch_fn.graph)
# Do not leave the loop variables behind in the module namespace.
del _op_name, _batch_fn
예제 #2
0
@torch.jit.script
def batch_select(data, mask, dims, dim, index):
    # Batched ``Tensor.select``: remove dimension ``dim`` from the batch by
    # taking slice ``index`` along it, and update mask and dims to match.
    #
    # Params:
    #   data, mask, dims - the batch triple; ``dims`` holds one flag per
    #                      non-batch dimension of ``data``
    #   dim   - dimension to select, counted including the leading batch
    #           dimension (hence the ``dims[dim - 1]`` lookup below)
    #   index - position to take along ``dim``
    # ``dim`` and ``index`` are unannotated, so TorchScript types them as
    # Tensors; they are converted to ints explicitly.
    # NOTE: dim == 0 (the batch dimension) is not supported and is not checked.
    #
    # Returns the (data, mask, dims) triple with dimension ``dim`` removed.
    dim_i = int(dim)
    index_i = int(index)
    data = data.select(dim_i, index_i)
    if bool(dims[dim_i - 1]):
        # Flagged dimension: the mask presumably spans the full extent here,
        # so take the same slice as the data.
        mask = mask.select(dim_i, index_i)
    else:
        # Unflagged dimension: the mask presumably has extent 1 here
        # (broadcast over the data), so slice 0 is the only valid position.
        mask = mask.select(dim_i, 0)
    # Drop the flag belonging to the removed dimension.
    dims = torch.cat((dims[:dim_i - 1], dims[dim_i:dims.size(0)]))
    return data, mask, dims


# Assumes data (the condition), data1 and data2 have the same size.
@torch.jit.script
def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
    # Batched ``torch.where``: take elements from ``data1`` where the condition
    # batch ``data`` is true and from ``data2`` elsewhere.
    # The condition's own ``mask`` and ``dims`` are currently not used.
    #
    # Returns the (data, mask, dims) triple of the result.
    res_data = torch.where(data, data1, data2)
    # Each result mask element comes from the same branch as its data element.
    res_mask = torch.where(data, mask1, mask2)
    # Element-wise OR of the per-dimension flags: a flag is set in the result
    # if it is set in either branch (assumes 0/1 flag values). Python ``or``
    # would call bool() on a multi-element tensor and raise instead.
    res_dims = dims1 | dims2
    return res_data, res_mask, res_dims


# Register each scripted batch implementation's graph under its operator
# name, preserving the original registration order.
_BATCH_OPERATORS = (
    ("tanh", batch_tanh),
    ("sigmoid", batch_sigmoid),
    ("add", batch_add),
    ("mul", batch_mul),
    ("matmul", batch_matmul),
    ("mm", batch_mm),
    ("select", batch_select),
    ("where", batch_where),
)
for _op_name, _batch_fn in _BATCH_OPERATORS:
    torch.register_batch_operator(_op_name, _batch_fn.graph)
# Leave the module namespace as the original straight-line calls did.
del _op_name, _batch_fn, _BATCH_OPERATORS