Example #1
0
    def test_plus_is_minus_variable_local(self):
        """_PlusIsMinusTensor wrapping local Variables.

        Verifies the printed hook chain of a wrapped Variable (the nested
        ``- -`` entries are its grad chain), the shorter chain printed for the
        result of ``Variable.add``, and that both ``lhs.add(rhs)`` and
        ``torch.add(lhs, rhs)`` turn ``[5, 6] + [3, 4]`` into ``[2, 2]``.
        """
        lhs = sy.Variable(torch.FloatTensor([5, 6]))
        rhs = sy.Variable(torch.FloatTensor([3, 4]))
        lhs = sy._PlusIsMinusTensor().on(lhs)
        rhs = sy._PlusIsMinusTensor().on(rhs)

        operand_chain = '\n'.join([
            'Variable > _PlusIsMinusTensor > _LocalTensor',
            ' - FloatTensor > _PlusIsMinusTensor > _LocalTensor',
            ' - - Variable > _PlusIsMinusTensor > _LocalTensor',
            '   - FloatTensor > _PlusIsMinusTensor > _LocalTensor',
        ])
        assert torch_utils.chain_print(lhs, display=False) == operand_chain

        def check_difference(result):
            # Drop the _PlusIsMinusTensor layer from the data chain so that
            # torch.equal compares the underlying values directly.
            result.data.child = result.data.child.child
            assert torch.equal(result.data, torch.FloatTensor([2, 2]))

        result = lhs.add(rhs)

        # The result's printed chain has no nested grad entries.
        result_chain = '\n'.join([
            'Variable > _PlusIsMinusTensor > _LocalTensor',
            ' - FloatTensor > _PlusIsMinusTensor > _LocalTensor',
        ])
        assert torch_utils.chain_print(result, display=False) == result_chain
        check_difference(result)

        # The functional torch.add form must be intercepted the same way.
        check_difference(torch.add(lhs, rhs))
Example #2
0
    def test_plus_is_minus_tensor_local(self):
        """_PlusIsMinusTensor wrapping local FloatTensors.

        Checks that wrapped operands and the result of ``add`` print the same
        three-level hook chain, and that both ``lhs.add(rhs)`` and
        ``torch.add(lhs, rhs)`` turn ``[5, 6] + [3, 4]`` into ``[2, 2]``.
        """
        lhs = torch.FloatTensor([5, 6])
        rhs = torch.FloatTensor([3, 4])
        lhs = sy._PlusIsMinusTensor().on(lhs)
        rhs = sy._PlusIsMinusTensor().on(rhs)

        wrapped_chain = 'FloatTensor > _PlusIsMinusTensor > _LocalTensor'
        assert torch_utils.chain_print(lhs, display=False) == wrapped_chain

        def unwrap_and_check(result):
            # Skip the _PlusIsMinusTensor link so torch.equal sees plain data.
            result.child = result.child.child
            assert torch.equal(result, torch.FloatTensor([2, 2]))

        summed = lhs.add(rhs)
        assert torch_utils.chain_print(summed, display=False) == wrapped_chain
        unwrap_and_check(summed)

        # Same check through the functional torch.add entry point.
        unwrap_and_check(torch.add(lhs, rhs))
Example #3
0
    def test_plus_is_minus_tensor_remote(self):
        """_PlusIsMinusTensor on FloatTensors sent to the remote worker ``bob``.

        After both operands are sent, ``x.add(y)`` must return a local pointer
        (``FloatTensor > _PointerTensor``). The object on bob must keep the
        _PlusIsMinusTensor chain. Once ``get()`` brings the result back, the
        local chain is restored and the value is ``[5, 6] - [3, 4] == [2, 2]``.
        """
        x = torch.FloatTensor([5, 6])
        y = torch.FloatTensor([3, 4])
        x = sy._PlusIsMinusTensor().on(x)
        y = sy._PlusIsMinusTensor().on(y)

        # random.randint needs integer bounds: a float such as 10e10 is
        # deprecated since Python 3.10 and raises TypeError from 3.12.
        # 10 ** 11 is the same upper bound as the original 10e10.
        id1 = random.randint(0, 10 ** 11)
        id2 = random.randint(0, 10 ** 11)
        x.send(bob, ptr_id=id1)
        y.send(bob, ptr_id=id2)

        z = x.add(y)
        assert torch_utils.chain_print(
            z, display=False) == 'FloatTensor > _PointerTensor'

        # Check chain on remote
        ptr_id = z.child.id_at_location
        assert torch_utils.chain_print(
            bob._objects[ptr_id].parent,
            display=False) == 'FloatTensor > _PlusIsMinusTensor > _LocalTensor'

        z.get()
        assert torch_utils.chain_print(
            z,
            display=False) == 'FloatTensor > _PlusIsMinusTensor > _LocalTensor'

        # cut chain for the equality check
        z.child = z.child.child
        assert torch.equal(z, torch.FloatTensor([2, 2]))
Example #4
0
    def test_plus_is_minus_variable_remote(self):
        """_PlusIsMinusTensor on Variables sent to the remote worker ``bob``.

        Both operands are sent with explicit remote ids for the Variable and
        its data, then added through the pointers. The test checks:
        the local pointer chain; that bob owns the pointed-to objects; the hook
        chain on bob for a sent operand and for the result; and the chain and
        value (``[5, 6] - [3, 4] == [2, 2]``) after ``get()``.
        """
        x = sy.Variable(torch.FloatTensor([5, 6]))
        y = sy.Variable(torch.FloatTensor([3, 4]))
        x = sy._PlusIsMinusTensor().on(x)
        y = sy._PlusIsMinusTensor().on(y)

        # random.randint needs integer bounds: a float such as 10e10 is
        # deprecated since Python 3.10 and raises TypeError from 3.12.
        # 10 ** 11 is the same upper bound as the original 10e10.
        id1 = random.randint(0, 10 ** 11)
        id2 = random.randint(0, 10 ** 11)
        id11 = random.randint(0, 10 ** 11)
        id21 = random.randint(0, 10 ** 11)
        x.send(bob, new_id=id1, new_data_id=id11)
        y.send(bob, new_id=id2, new_data_id=id21)

        z = x.add(y)
        pointer_display = ('Variable > _PointerTensor\n'
                           ' - FloatTensor > _PointerTensor\n'
                           ' - - Variable > _PointerTensor\n'
                           '   - FloatTensor > _PointerTensor')
        assert torch_utils.chain_print(z, display=False) == pointer_display

        # The objects the result's pointers refer to are owned by bob.
        assert bob._objects[z.id_at_location].owner.id == 'bob'
        assert bob._objects[z.data.id_at_location].owner.id == 'bob'

        # Check chain on remote for a sent (leaf) operand: its full hook
        # chain, grad included, is present on bob.
        ptr_id = x.child.id_at_location
        display = ('Variable > _PlusIsMinusTensor > _LocalTensor\n'
                   ' - FloatTensor > _PlusIsMinusTensor > _LocalTensor\n'
                   ' - - Variable > _PlusIsMinusTensor > _LocalTensor\n'
                   '   - FloatTensor > _PlusIsMinusTensor > _LocalTensor')
        assert torch_utils.chain_print(bob._objects[ptr_id].parent,
                                       display=False) == display

        # Check chain on remote for the result.
        # TODO: for now we don't reconstruct the grad chain on non-leaf
        # variables (here, a leaf variable is one that we sent), because we
        # don't care about their gradient. If we ever do, revisit this.
        ptr_id = z.child.id_at_location
        result_display = ('Variable > _PlusIsMinusTensor > _LocalTensor\n'
                          ' - FloatTensor > _PlusIsMinusTensor > _LocalTensor\n'
                          ' - - Variable > _LocalTensor\n'
                          '   - FloatTensor > _LocalTensor')
        assert torch_utils.chain_print(bob._objects[ptr_id].parent,
                                       display=False) == result_display

        # After get(), the local result carries the same chain it had on bob.
        z.get()
        assert torch_utils.chain_print(z, display=False) == result_display

        # cut chain for the equality check
        z.data.child = z.data.child.child
        assert torch.equal(z.data, torch.FloatTensor([2, 2]))