예제 #1
0
    def testConcatDataset(self):
        """ConcatDataset chains its parts end to end and indexes across them."""
        first = dataset.ListDataset(elem_list=[0, 1, 2, 3])
        second = dataset.ListDataset(elem_list=[10, 11, 13])
        combined = dataset.ConcatDataset([first, second])

        self.assertEqual(len(combined), 7)
        # (index, expected value): both ends of the first part, then the
        # first and last element of the second part.
        for idx, expected in ((0, 0), (3, 3), (4, 10), (6, 13)):
            self.assertEqual(combined[idx], expected)
예제 #2
0
    def testListDataset(self):
        """ListDataset accepts a list, a LongTensor or an ndarray as elem_list."""
        sources = (
            [0, 1, 2],
            torch.LongTensor([0, 1, 2]),
            np.asarray([0, 1, 2]),
        )
        for elems in sources:
            ds = dataset.ListDataset(elem_list=elems, load=lambda x: x)
            self.assertEqual(len(ds), 3)
            self.assertEqual(ds[0], 0)
예제 #3
0
 def testBatchDataset(self):
     """BatchDataset groups consecutive samples into fixed-size batches.

     ``torch.range`` is deprecated and end-inclusive. ``torch.arange(0, 16)``
     builds the same sixteen values 0..15. ``torch.range`` is kept only as a
     fallback for torch builds that lack ``arange``.
     """
     if hasattr(torch, 'arange'):
         t = torch.arange(0, 16).long()
     else:
         t = torch.range(0, 15).long()
     batchsize = 8
     d = dataset.ListDataset(t, lambda x: {'input': x})
     d = dataset.BatchDataset(d, batchsize)
     ex = d[0]['input']
     # The first batch should hold the first ``batchsize`` samples: 0..batchsize-1.
     self.assertEqual(len(ex), batchsize)
     self.assertEqual(ex[-1], batchsize - 1)
예제 #4
0
    def testListDataset_path(self):
        """The path argument is joined to each element before the loader runs."""
        def add_bar(item):
            return 'bar/' + str(item)

        items = [0, 1, 2]
        ds = dataset.ListDataset(items, add_bar, 'foo')
        self.assertEqual(len(ds), 3)
        # Element 2 under path 'foo' becomes 'foo/2'; the loader prepends 'bar/'.
        self.assertEqual(ds[2], 'bar/foo/2')
예제 #5
0
    def testListDataset(self):
        """ListDataset works over list, LongTensor and ndarray element sources."""
        def identity(x):
            return x

        def check(elems):
            ds = dataset.ListDataset(elem_list=elems, load=identity)
            self.assertEqual(len(ds), 3)
            self.assertEqual(ds[0], 0)

        check([0, 1, 2])
        check(torch.LongTensor([0, 1, 2]))
        check(np.asarray([0, 1, 2]))
예제 #6
0
    def testListDataset_file(self):
        """ListDataset built from a file yields one element per line, path-prefixed.

        ``tempfile.mkstemp`` returns an already-open OS descriptor. It is
        wrapped with ``os.fdopen`` so it gets closed instead of leaked. The
        temporary file is removed in ``finally`` so a failing assertion does
        not leave it behind.
        """
        fd, filename = tempfile.mkstemp()
        try:
            with os.fdopen(fd, 'w') as f:
                for i in range(0, 50):
                    f.write(str(i) + '\n')

            d = dataset.ListDataset(filename, lambda x: x, 'foo')
            self.assertEqual(len(d), 50)
            self.assertEqual(d[15], 'foo/15')
        finally:
            os.remove(filename)
예제 #7
0
 def testBatchDataset(self):
     """The first batch of a BatchDataset holds samples 0..batchsize-1."""
     values = (torch.arange(0, 16) if hasattr(torch, "arange")
               else torch.range(0, 15)).long()
     batchsize = 8
     batched = dataset.BatchDataset(
         dataset.ListDataset(values, lambda x: {"input": x}),
         batchsize,
     )
     first_inputs = batched[0]["input"]
     self.assertEqual(len(first_inputs), batchsize)
     self.assertEqual(first_inputs[-1], batchsize - 1)
예제 #8
0
    def testListDataset_file(self):
        """ListDataset built from a file yields one element per line, path-prefixed.

        ``tempfile.mkstemp`` returns an already-open OS descriptor. It is
        wrapped with ``os.fdopen`` so it gets closed instead of leaked. The
        temporary file is removed in ``finally`` so a failing assertion does
        not leave it behind.
        """
        fd, filename = tempfile.mkstemp()
        try:
            with os.fdopen(fd, "w") as f:
                for i in range(0, 50):
                    f.write(str(i) + "\n")

            d = dataset.ListDataset(filename, lambda x: x, "foo")
            self.assertEqual(len(d), 50)
            self.assertEqual(d[15], "foo/15")
        finally:
            os.remove(filename)
예제 #9
0
    def testSplitDataset_fractions(self):
        """SplitDataset accepts partition sizes given as fractions of the total."""
        source = dataset.ListDataset(elem_list=[0, 1, 2, 3])
        split = dataset.SplitDataset(source, {'train': 0.75, 'val': 0.25})

        # (partition, expected length, index to probe, expected value)
        expectations = (
            ('train', 3, 2, 2),
            ('val', 1, 0, 3),
        )
        for name, length, index, value in expectations:
            split.select(name)
            self.assertEqual(len(split), length)
            self.assertEqual(split[index], value)
예제 #10
0
    def testSplitDataset(self):
        """SplitDataset partitions by absolute sizes, directly and via .split()."""
        def check(split, name, length, index, value):
            # Activate one partition, then verify its size and one element.
            split.select(name)
            self.assertEqual(len(split), length)
            self.assertEqual(split[index], value)

        source = dataset.ListDataset(elem_list=[0, 1, 2, 3])
        direct = dataset.SplitDataset(source, {'train': 3, 'val': 1})
        check(direct, 'train', 3, 2, 2)
        check(direct, 'val', 1, 0, 3)

        # test fluent api
        fluent = source.split({'train': 3, 'val': 1})
        check(fluent, 'train', 3, 2, 2)
예제 #11
0
    def testSplitDataset(self):
        """SplitDataset partitions by absolute sizes, directly and via .split()."""
        source = dataset.ListDataset(elem_list=[0, 1, 2, 3])

        direct = dataset.SplitDataset(source, {"train": 3, "val": 1})
        for name, length, index, value in (("train", 3, 2, 2), ("val", 1, 0, 3)):
            direct.select(name)
            self.assertEqual(len(direct), length)
            self.assertEqual(direct[index], value)

        # test fluent api
        fluent = source.split({"train": 3, "val": 1})
        fluent.select("train")
        self.assertEqual(len(fluent), 3)
        self.assertEqual(fluent[2], 2)
예제 #12
0
 def testListDataset_path(self):
     """Elements are resolved under the 'foo' path, then passed to the loader."""
     loader = 'bar/{}'.format
     items = [0, 1, 2]
     ds = dataset.ListDataset(items, loader, 'foo')
     self.assertEqual(len(ds), 3)
     self.assertEqual(ds[2], 'bar/foo/2')
예제 #13
0
 def testListDataset_path(self):
     """A str.format loader receives the path-joined element ('foo/<elem>')."""
     ds = dataset.ListDataset([0, 1, 2], "bar/{}".format, "foo")
     expected_len, probe, expected = 3, 2, "bar/foo/2"
     self.assertEqual(len(ds), expected_len)
     self.assertEqual(ds[probe], expected)