Beispiel #1
0
def test_zeros():
    """Check op/tensor key semantics and chunking of tensors built by ``zeros``."""
    base = zeros((2, 3, 4))
    # Iterating the (2, 3, 4) tensor is expected to yield 2 items.
    assert len(list(base)) == 2
    assert base.op.gpu is False

    # Only chunk_size differs: the op key must match, the tensor key must not.
    rechunked = zeros((2, 3, 4), chunk_size=1)
    assert base.op.key == rechunked.op.key
    assert base.key != rechunked.key

    # A different shape must change both the op key and the tensor key.
    other_shape = zeros((2, 3, 3))
    assert base.op.key != other_shape.op.key
    assert base.key != other_shape.key

    # Create two zeros chunk ops by hand, with different shapes and indices;
    # they must share neither op keys nor chunk keys.
    specs = (((3, 3), (0, 0)), ((3, 4), (0, 1)))
    first, second = [
        TensorZeros(dtype=base.dtype).new_chunk(None, shape=shape, index=idx)
        for shape, idx in specs
    ]
    assert first.op.key != second.op.key
    assert first.key != second.key

    # After tiling a (100, 100) zeros tensor into 50x50 chunks, every chunk
    # must carry the same op key and the same chunk key.
    tiled = tile(zeros((100, 100), chunk_size=50))
    assert len({chunk.op.key for chunk in tiled.chunks}) == 1
    assert len({chunk.key for chunk in tiled.chunks}) == 1
Beispiel #2
0
    def testZeros(self):
        # Check op/tensor key semantics and chunking of tensors built by zeros.
        base = zeros((2, 3, 4))
        # Iterating the (2, 3, 4) tensor is expected to yield 2 items.
        self.assertEqual(len(list(base)), 2)
        self.assertFalse(base.op.gpu)

        # Only chunk_size differs: the op key must match, the tensor key must not.
        rechunked = zeros((2, 3, 4), chunk_size=1)
        self.assertEqual(base.op.key, rechunked.op.key)
        self.assertNotEqual(base.key, rechunked.key)

        # A different shape must change both the op key and the tensor key.
        other_shape = zeros((2, 3, 3))
        self.assertNotEqual(base.op.key, other_shape.op.key)
        self.assertNotEqual(base.key, other_shape.key)

        # Create two zeros chunk ops by hand, with different shapes and
        # indices; they must share neither op keys nor chunk keys.
        specs = (((3, 3), (0, 0)), ((3, 4), (0, 1)))
        first, second = [
            TensorZeros(dtype=base.dtype).new_chunk(None, shape=shape, index=idx)
            for shape, idx in specs
        ]
        self.assertNotEqual(first.op.key, second.op.key)
        self.assertNotEqual(first.key, second.key)

        # After tiling a (100, 100) zeros tensor into 50x50 chunks, every
        # chunk must carry the same op key and the same chunk key.
        tiled = zeros((100, 100), chunk_size=50).tiles()
        self.assertEqual(len({chunk.op.key for chunk in tiled.chunks}), 1)
        self.assertEqual(len({chunk.key for chunk in tiled.chunks}), 1)