예제 #1
0
    def test_train_test_split(self):
        """Check the bookkeeping produced by ``Client.train_test_split``.

        After downloading annotations and splitting with ``prop=0.7``, the
        test asserts that:
        - ``index_url`` has one entry per image in ``dict_annotations["images"]``;
        - ``train_list`` and ``train_list_id`` have the same length;
        - ``eval_list`` and ``eval_list_id`` have the same length.

        The API token is read from the ``TEST_API_TOKEN`` environment variable
        so no credential lives in source; the test is skipped when it is unset.
        """
        import os

        api_token = os.environ.get("TEST_API_TOKEN")
        if not api_token:
            self.skipTest("TEST_API_TOKEN is not set")

        client = Client(api_token=api_token)
        client.checkout_project(
            project_token="4b003477-3b31-4f74-8952-8a9dc879b0ec")
        client.create_network(network_name="test_creation_network_0")
        client.dl_annotations()
        client.train_test_split(prop=0.7)

        self.assertEqual(len(client.index_url),
                         len(client.dict_annotations["images"]))
        self.assertEqual(len(client.train_list), len(client.train_list_id))
        self.assertEqual(len(client.eval_list), len(client.eval_list_id))
예제 #2
0
 def test_dl_annotations(self):
     """Check that ``Client.dl_annotations`` populates ``dict_annotations``.

     The API token is read from the ``TEST_API_TOKEN`` environment variable
     so no credential lives in source; the test is skipped when it is unset.
     """
     import os

     api_token = os.environ.get("TEST_API_TOKEN")
     if not api_token:
         self.skipTest("TEST_API_TOKEN is not set")

     client = Client(api_token=api_token)
     client.checkout_project(
         project_token="4b003477-3b31-4f74-8952-8a9dc879b0ec")
     client.dl_annotations()
     # assertGreater reports the actual length on failure, unlike assertTrue.
     self.assertGreater(len(client.dict_annotations), 0)
예제 #3
0
    def test_create_dataset(self):
        """Check dataset creation and that a duplicate name is rejected.

        The first ``create_dataset`` call must return a ``str`` id. A second
        call with the same name, made from a fresh ``Client``, must raise
        ``ValueError``.

        The API token is read from the ``TEST_API_TOKEN`` environment variable
        so no credential lives in source; the test is skipped when it is unset.
        """
        import os
        import uuid

        api_token = os.environ.get("TEST_API_TOKEN")
        if not api_token:
            self.skipTest("TEST_API_TOKEN is not set")

        # This test asserts that reusing a dataset name raises ValueError, so a
        # fixed name would presumably make every run after the first fail on
        # the initial creation. A per-run suffix keeps the test repeatable.
        dataset_name = "test_dataset_0_" + uuid.uuid4().hex[:8]

        client = Client(api_token=api_token)
        ds_id = client.create_dataset(dataset_name=dataset_name)
        self.assertIsInstance(ds_id, str)

        # Construct the second client outside assertRaises so that only the
        # duplicate create_dataset call can satisfy the ValueError expectation
        # (a ValueError from the constructor must not count as a pass).
        second_client = Client(api_token=api_token)
        with self.assertRaises(ValueError):
            second_client.create_dataset(dataset_name=dataset_name)

    def test_upload_and_create_dataset(self):
        """Run ``Client.create_and_upload_dataset`` on ``test_images/``.

        The test passes as long as the call does not raise. ``test_images/``
        is a relative path, presumably resolved against the current working
        directory.

        The API token is read from the ``TEST_API_TOKEN`` environment variable
        so no credential lives in source; the test is skipped when it is unset.
        """
        import os

        api_token = os.environ.get("TEST_API_TOKEN")
        if not api_token:
            self.skipTest("TEST_API_TOKEN is not set")

        client = Client(api_token=api_token)
        # The original rebound ``client`` to this call's return value and never
        # used it; the call is kept, the misleading rebinding is dropped.
        # TODO(review): assert on the result once its return type is confirmed.
        client.create_and_upload_dataset(
            dataset_name="test_dataset_1", path_to_images="test_images/")


if __name__ == '__main__':
    # Ad-hoc debug run rather than the test suite (call unittest.main() here
    # to run the tests instead). It checks out a project and the "ssd_base"
    # network, downloads annotations and pictures, splits the data, builds
    # the label map, then prints the first four elements of the first item
    # yielded by tf_vars_generator.
    import os

    # Read the credential from the environment instead of committing it.
    api_token = os.environ.get("TEST_API_TOKEN")
    if not api_token:
        raise SystemExit("Set TEST_API_TOKEN to run this script.")

    client = Client(api_token=api_token)
    client.checkout_project(
        project_token="d8e65668-9e18-421e-966c-7daa8d7c7497")
    model_name = "ssd_base"
    client.checkout_network(model_name)
    client.dl_annotations()
    client.dl_pictures()
    client.train_test_split()
    client.generate_labelmap()
    tf_vars = client.tf_vars_generator(client.label_map,
                                       annotation_type="rectangle")
    first_item = next(tf_vars)
    print(first_item[:4])