def test_scan_computes_correct_size_for_multiple_containers(self, tmpdir):
        """``_scan`` sums each utterance's byte size over all containers."""
        # Frame count of every utterance, one entry per container (c1, c2, c3).
        frames_per_utt = {
            'utt-1': (6, 2, 1),
            'utt-2': (2, 1, 3),
            'utt-3': (9, 4, 8),
        }
        conts = [
            containers.Container(os.path.join(tmpdir.strpath, file_name))
            for file_name in ('c1.h5', 'c2.h5', 'c3.h5')
        ]

        for cont_idx, cont in enumerate(conts):
            cont.open()
            for utt_id, frames in frames_per_utt.items():
                feats = np.random.random((frames[cont_idx], 6))
                cont.set(utt_id, feats.astype(np.float32))

        loader = partitioning.PartitioningContainerLoader(
            list(frames_per_utt), conts, '1000', shuffle=True, seed=88)

        sizes = loader._scan()

        item_size = np.dtype(np.float32).itemsize
        expected_sizes = {
            utt_id: sum(frames) * 6 * item_size
            for utt_id, frames in frames_per_utt.items()
        }
        assert sizes == expected_sizes
    def test_get_lengths_returns_correct_lengths_for_multiple_containers(
            self, tmpdir):
        """``_get_all_lengths`` maps each utterance to its per-container lengths."""
        # Frame count of every utterance, one entry per container (c1, c2, c3).
        frames_per_utt = {
            'utt-1': (6, 2, 1),
            'utt-2': (2, 1, 3),
            'utt-3': (9, 4, 8),
        }
        conts = [
            containers.Container(os.path.join(tmpdir.strpath, file_name))
            for file_name in ('c1.h5', 'c2.h5', 'c3.h5')
        ]

        for cont_idx, cont in enumerate(conts):
            cont.open()
            for utt_id, frames in frames_per_utt.items():
                feats = np.random.random((frames[cont_idx], 6))
                cont.set(utt_id, feats.astype(np.float32))

        loader = partitioning.PartitioningContainerLoader(
            list(frames_per_utt), conts, '1000', shuffle=True, seed=88)

        lengths = loader._get_all_lengths()

        assert len(lengths) == 3
        for utt_id, frames in frames_per_utt.items():
            assert lengths[utt_id] == frames
    def test_load_partition_data(self, tmpdir):
        """Loaded partitions carry the expected ids and the stored arrays."""
        container = containers.Container(
            os.path.join(tmpdir.strpath, 'c1.h5'))
        container.open()

        layout = [
            ('utt-1', 6),
            ('utt-2', 2),
            ('utt-3', 9),
            ('utt-4', 2),
            ('utt-5', 5),
        ]
        stored = {
            utt_id: np.random.random((frames, 6)).astype(np.float32)
            for utt_id, frames in layout
        }
        for utt_id, data in stored.items():
            container.set(utt_id, data)

        loader = partitioning.PartitioningContainerLoader(
            [utt_id for utt_id, _ in layout],
            container,
            '250',
            shuffle=False)

        # Without shuffling, the 250-byte limit groups the utterances like this.
        expected_partitions = [
            ['utt-1', 'utt-2'],
            ['utt-3'],
            ['utt-4', 'utt-5'],
        ]
        for part_idx, expected_ids in enumerate(expected_partitions):
            part = loader.load_partition_data(part_idx)
            assert part.info.utt_ids == expected_ids
            for pos, utt_id in enumerate(expected_ids):
                assert np.allclose(part.utt_data[pos], stored[utt_id])
    def test_reload_creates_different_partitions_on_second_run(self, tmpdir):
        """Reloading a shuffling loader must yield a different partitioning.

        The partitionings count as different if the number of partitions
        changes, or if the count is equal but at least one partition holds a
        different list of utterance ids.
        """
        c1 = containers.Container(os.path.join(tmpdir.strpath, 'c1.h5'))
        c1.open()
        c1.set('utt-1', np.random.random((6, 6)).astype(np.float32))
        c1.set('utt-2', np.random.random((2, 6)).astype(np.float32))
        c1.set('utt-3', np.random.random((9, 6)).astype(np.float32))
        c1.set('utt-4', np.random.random((2, 6)).astype(np.float32))
        c1.set('utt-5', np.random.random((5, 6)).astype(np.float32))

        loader = partitioning.PartitioningContainerLoader(
            ['utt-1', 'utt-2', 'utt-3', 'utt-4', 'utt-5'],
            c1,
            '250',
            shuffle=True,
            seed=100)

        # Snapshot (copy) the ids before reloading. The comparison then holds
        # whether ``reload`` replaces the partitions or mutates them in place.
        utt_ids_one = [list(p.utt_ids) for p in loader.partitions]
        loader.reload()
        utt_ids_two = [list(p.utt_ids) for p in loader.partitions]

        # List inequality covers both "different number of partitions" and
        # "same number, but some partition differs". The previous version
        # computed ``len(...) == len(...)`` and then passed unconditionally
        # whenever the counts matched.
        assert utt_ids_one != utt_ids_two
    def test_reload_creates_correct_partitions(self, tmpdir):
        """Unshuffled partitions follow insertion order and the 250-byte limit."""
        container = containers.Container(
            os.path.join(tmpdir.strpath, 'c1.h5'))
        container.open()
        for utt_id, frames in (('utt-1', 6), ('utt-2', 2), ('utt-3', 9),
                               ('utt-4', 2), ('utt-5', 5)):
            container.set(utt_id,
                          np.random.random((frames, 6)).astype(np.float32))

        loader = partitioning.PartitioningContainerLoader(
            ['utt-1', 'utt-2', 'utt-3', 'utt-4', 'utt-5'],
            container,
            '250',
            shuffle=False)

        # (utt_ids, utt_lengths, size in bytes) for each expected partition.
        expected = [
            (['utt-1', 'utt-2'], [(6, ), (2, )], 192),
            (['utt-3'], [(9, )], 216),
            (['utt-4', 'utt-5'], [(2, ), (5, )], 168),
        ]

        assert len(loader.partitions) == 3
        for idx, (ids, lengths, size) in enumerate(expected):
            assert loader.partitions[idx].utt_ids == ids
            assert loader.partitions[idx].utt_lengths == lengths
            assert loader.partitions[idx].size == size
    def test_raises_error_if_utt_is_missing_in_container(self, tmpdir):
        """Requesting an utterance the container lacks raises ``ValueError``."""
        container = containers.Container(
            os.path.join(tmpdir.strpath, 'c1.h5'))
        container.open()
        for utt_id, frames in (('utt-1', 6), ('utt-3', 9)):
            container.set(utt_id,
                          np.random.random((frames, 6)).astype(np.float32))

        # 'utt-2' is requested but was never written to the container.
        requested_ids = ['utt-1', 'utt-2', 'utt-3']
        with pytest.raises(ValueError):
            partitioning.PartitioningContainerLoader(
                requested_ids, container, '250', shuffle=True, seed=88)
    def test_reload_creates_no_partition_with_no_utterances(self, tmpdir):
        """An empty utterance list produces no partitions at all."""
        container_path = os.path.join(tmpdir.strpath, 'c1.h5')
        empty_container = containers.Container(container_path)
        empty_container.open()

        loader = partitioning.PartitioningContainerLoader(
            [], empty_container, '250', shuffle=False)

        partition_count = len(loader.partitions)
        assert partition_count == 0
    def test_scan_computes_correct_size_for_one_container(self, tmpdir):
        """``_scan`` reports each utterance's byte size for a single container."""
        frames_per_utt = {'utt-1': 6, 'utt-2': 2, 'utt-3': 9}

        container = containers.Container(
            os.path.join(tmpdir.strpath, 'c1.h5'))
        container.open()
        for utt_id, frames in frames_per_utt.items():
            container.set(utt_id,
                          np.random.random((frames, 6)).astype(np.float32))

        loader = partitioning.PartitioningContainerLoader(
            list(frames_per_utt), container, '250', shuffle=True, seed=88)

        sizes = loader._scan()

        # Each frame holds 6 float32 values.
        bytes_per_frame = 6 * np.dtype(np.float32).itemsize
        assert sizes == {
            utt_id: frames * bytes_per_frame
            for utt_id, frames in frames_per_utt.items()
        }