예제 #1
0
 def test_lifetime_cannot_update_with_none(self):
     """Resetting ``lifetime`` to ``None`` on an existing file raises ``RuntimeError``."""
     init_kwargs = {'uri': self.uri, 'type': type, 'lifetime': '1day'}
     datalake_file = DatalakeFile(None, **init_kwargs)
     # The file starts with a lifetime of '1day'; clearing it must be rejected.
     with self.assertRaises(RuntimeError):
         datalake_file.lifetime = None
예제 #2
0
 def test_get_download_uri(self):
     """``_get_download_uri`` returns the ``download_uri`` entry of the file info.

     The API connection is mocked so that ``api_request`` yields
     ``self.file_info``; the request must be issued exactly once.
     """
     datalake_file = DatalakeFile(MagicMock(), uri=self.uri, type=type)
     api_request = datalake_file._api._connection.api_request
     api_request.return_value = self.file_info
     self.assertEqual(datalake_file._get_download_uri(),
                      self.file_info['download_uri'])
     api_request.assert_called_once()
예제 #3
0
    def test_commit_without_changes(self):
        """``commit()`` on an unmodified file returns a truthy value.

        The metadata endpoint is expected to be called once with the
        unchanged metadata, and the lifetime endpoint not at all.
        """
        filename_key = 'x-abeja-meta-filename'
        filename = 'DcZzLGkV4AA8FQc.jpg'

        mock_api = MagicMock()
        datalake_file = DatalakeFile(mock_api,
                                     uri=self.uri,
                                     type=type,
                                     metadata={filename_key: filename})
        # Stub the remote lookup so that commit() sees this file info.
        datalake_file.get_file_info = MagicMock(return_value={
            'url_expires_on': '2018-06-04T05:04:46+00:00',
            'uploaded_at': '2018-06-01T05:22:44+00:00',
            'metadata': {filename_key: filename},
            'file_id': self.file_id,
            'download_uri': 'https://example.com/dummy_download_uri',
            'content_type': 'image/jpeg',
        })

        committed = datalake_file.commit()
        self.assertTrue(committed)

        mock_api.put_channel_file_metadata.assert_called_once_with(
            self.channel_id, self.file_id, metadata={filename_key: filename})
        mock_api.put_channel_file_lifetime.assert_not_called()
예제 #4
0
 def test_file_info(self):
     """``get_file_info`` returns what the API connection's ``api_request`` yields.

     The mocked ``api_request`` must be invoked exactly once.
     """
     datalake_file = DatalakeFile(MagicMock(), uri=self.uri, type=type)
     api_request = datalake_file._api._connection.api_request
     api_request.return_value = self.file_info
     self.assertEqual(datalake_file.get_file_info(), self.file_info)
     api_request.assert_called_once()
예제 #5
0
 def test_get_json(self):
     """``get_json`` returns the data produced by ``_get_json_from_remote``.

     The remote getter is replaced by an autospecced stub and must be
     called once with no arguments.
     """
     datalake_file = DatalakeFile(None, uri=self.uri, type=type)
     remote_stub = create_autospec(datalake_file._get_json_from_remote,
                                   return_value=self.json_data)
     datalake_file._get_json_from_remote = remote_stub
     self.assertEqual(datalake_file.get_json(), self.json_data)
     remote_stub.assert_called_once_with()
예제 #6
0
 def test_get_content(self):
     """``get_content`` returns the bytes produced by ``_get_content_from_remote``.

     The remote getter is replaced by an autospecced stub and must be
     called once with no arguments.
     """
     datalake_file = DatalakeFile(None, uri=self.uri, type=type)
     remote_stub = create_autospec(datalake_file._get_content_from_remote,
                                   return_value=self.binary_data)
     datalake_file._get_content_from_remote = remote_stub
     self.assertEqual(datalake_file.get_content(), self.binary_data)
     remote_stub.assert_called_once_with()
예제 #7
0
 def test_to_source_data(self):
     """``to_source_data`` includes ``data_type`` only once ``type`` is set."""
     datalake_file = DatalakeFile(None, uri=self.uri)

     # Without a type only the URI is reported.
     uri_only = {'data_uri': self.uri}
     assert datalake_file.to_source_data() == uri_only

     # After assigning a type it shows up next to the URI.
     datalake_file.type = 'image/jpeg'
     uri_and_type = {'data_uri': self.uri, 'data_type': 'image/jpeg'}
     assert datalake_file.to_source_data() == uri_and_type
예제 #8
0
 def test_get_iter_lines(self):
     """``get_iter_lines`` yields the lines from ``_get_iter_lines_from_remote``.

     The yielded pieces, joined with an empty separator, are compared
     against ``self.text_data``.
     """
     # Must be set before _generate_iter_lines() is invoked below.
     self.text_data = 'a\nb\nc'
     datalake_file = DatalakeFile(None, uri=self.uri, type=type)
     remote_stub = create_autospec(
         datalake_file._get_iter_lines_from_remote,
         return_value=self._generate_iter_lines())
     datalake_file._get_iter_lines_from_remote = remote_stub
     joined = ''.join(datalake_file.get_iter_lines())
     self.assertEqual(joined, self.text_data)
     remote_stub.assert_called_once_with()
예제 #9
0
 def test_get_text_using_cache(self):
     """``get_text`` reads a file cached under ``TEST_MOUNT_DIR`` without going remote.

     A file named after ``self.file_id`` is written into the channel's
     directory first; the remote getter must then never be called.
     """
     cache_dir = '{}/{}'.format(TEST_MOUNT_DIR, self.channel_id)
     cache_file = '{}/{}'.format(cache_dir, self.file_id)
     os.makedirs(cache_dir, exist_ok=True)
     with open(cache_file, 'w') as f:
         f.write(self.text_data)

     datalake_file = DatalakeFile(None, uri=self.uri, type=type)
     remote_stub = MagicMock()
     datalake_file._get_text_from_remote = remote_stub
     self.assertEqual(datalake_file.get_text(), self.text_data)
     remote_stub.assert_not_called()
예제 #10
0
 def test_get_iter_content(self):
     """``get_iter_content`` yields the chunks from ``_get_iter_content_from_remote``.

     The concatenated chunks must equal ``self.binary_data`` and the
     remote getter must receive the requested chunk size.
     """
     datalake_file = DatalakeFile(None, uri=self.uri, type=type)
     remote_stub = create_autospec(
         datalake_file._get_iter_content_from_remote,
         return_value=self._generate_iter_content())
     datalake_file._get_iter_content_from_remote = remote_stub
     chunks = datalake_file.get_iter_content(chunk_size=128)
     self.assertEqual(b''.join(chunks), self.binary_data)
     remote_stub.assert_called_once_with(128)
예제 #11
0
    def test_do_download(self):
        """``_do_download`` issues a GET to the download URI and returns the response.

        ``_get_download_uri`` is stubbed to return a dummy URL and the
        connection's ``request`` is mocked to return ``self.text_data``.
        """
        mock_api = MagicMock()
        datalake_file = DatalakeFile(mock_api, uri=self.uri, type=type)
        dummy_url = 'dummy url'
        datalake_file._get_download_uri = MagicMock(return_value=dummy_url)
        datalake_file._api._connection.request.return_value = self.text_data

        # Assert on the value returned by _do_download() itself rather than
        # on the value configured on the mock, so the check is not
        # tautological.
        data = datalake_file._do_download()

        self.assertEqual(data, self.text_data)
        datalake_file._get_download_uri.assert_called_once()
        datalake_file._api._connection.request.assert_called_with('GET',
                                                                  dummy_url,
                                                                  stream=False)
예제 #12
0
    def test_do_download_error_handling(self):
        """An HTTP 400 raised by the connection surfaces as ``BadRequest``."""
        # Build an HTTPError carrying a 400 response with a small body.
        response = requests.models.Response()
        response.status_code = 400
        response._content = 'test error'.encode('utf-8')
        http_error = requests.exceptions.HTTPError()
        http_error.response = response

        datalake_file = DatalakeFile(MagicMock(), uri=self.uri, type=type)
        datalake_file._get_download_uri = MagicMock(return_value='dummy url')
        datalake_file._api._connection.request.side_effect = http_error

        with self.assertRaises(BadRequest):
            datalake_file._do_download()
예제 #13
0
    def _get_label(self, i):
        """Load the segmentation label image for dataset item ``i``.

        The item's ``segmentation`` attribute supplies the channel and file
        ids of a PNG fetched through ``DatalakeFile``. The image is
        converted to palette mode ('P') and returned as an ``int32`` array
        in which pixels equal to ``self.max_id`` are replaced with ``-1``.

        :param i: index into ``self.dataset_list``
        :return: numpy ``int32`` array of label ids
        """
        segmentation = self.dataset_list[i].attributes['segmentation']
        channel_id = segmentation['channel_id']
        file_id = segmentation['file_id']

        label_source = DatalakeFile(
            api=self.client,
            channel_id=channel_id,
            file_id=file_id,
            uri='datalake://{}/{}'.format(channel_id, file_id),
            type='image/png')

        label_file = Image.open(io.BytesIO(label_source.get_content()))
        try:
            label = np.asarray(label_file.convert('P'), dtype=np.int32)
        finally:
            label_file.close()

        # Label id max_id is for unlabeled pixels.
        label[label == self.max_id] = -1
        return label
예제 #14
0
    def test_commit_failed_in_updating_lifetime(self):
        """A failing lifetime update makes ``commit()`` raise ``HttpError``.

        The metadata endpoint is expected to be called twice: first with
        the new metadata (including the added label), then with the
        original metadata only.
        """
        filename_key = 'x-abeja-meta-filename'
        filename = 'DcZzLGkV4AA8FQc.jpg'

        mock_api = MagicMock()
        datalake_file = DatalakeFile(mock_api,
                                     uri=self.uri,
                                     type=type,
                                     metadata={filename_key: filename})
        datalake_file.get_file_info = MagicMock(return_value={
            'url_expires_on': '2018-06-04T05:04:46+00:00',
            'uploaded_at': '2018-06-01T05:22:44+00:00',
            'metadata': {filename_key: filename},
            'file_id': self.file_id,
            'download_uri': 'https://example.com/dummy_download_uri',
            'content_type': 'image/jpeg',
        })
        # Make the lifetime endpoint fail with a 400.
        mock_api.put_channel_file_lifetime.side_effect = HttpError(
            error='bad_request',
            error_description='bad request',
            status_code=400,
            url='dummy')

        datalake_file.metadata['label'] = 'cat'
        datalake_file.lifetime = '1week'
        with self.assertRaises(HttpError):
            datalake_file.commit()

        mock_api.put_channel_file_lifetime.assert_called_once_with(
            self.channel_id, self.file_id, lifetime='1week')
        self.assertEqual(mock_api.put_channel_file_metadata.call_count, 2)

        # Expected metadata payloads, in call order.
        expected_per_call = [
            {filename_key: filename, 'x-abeja-meta-label': 'cat'},
            {filename_key: filename},
        ]
        recorded_calls = mock_api.put_channel_file_metadata.call_args_list
        for recorded, expected_metadata in zip(recorded_calls,
                                               expected_per_call):
            self.assertListEqual(
                list(recorded),
                [(self.channel_id, self.file_id),
                 {'metadata': expected_metadata}])
예제 #15
0
def file_factory(client: APIClient, uri: str, type: str,
                 **kwargs) -> SourceData:
    """Build the file object that matches the scheme of ``uri``.

    :param client: API client handed to the created file object
    :param uri: URI of the file; its scheme selects the file class
    :param type: type value forwarded to ``DatalakeFile``
        (not passed on for http/https URIs)
    :param kwargs: extra keyword arguments forwarded to ``DatalakeFile``
        (not passed on for http/https URIs)
    :return: a ``DatalakeFile`` for ``datalake`` URIs, an ``HTTPFile``
        for ``http`` and ``https`` URIs
    :raises: UnsupportedURI if given uri is not supported
    """
    scheme = urlparse(uri).scheme
    if scheme == 'datalake':
        return DatalakeFile(client, uri=uri, type=type, **kwargs)
    if scheme in ('http', 'https'):
        return HTTPFile(client, uri=uri)
    raise UnsupportedURI('{} is not supported.'.format(uri))
예제 #16
0
    def __getitem__(self, index):
        """Return the ``(image, label)`` tensor pair for the item at ``index``.

        The image comes from the item's first source data entry. The label
        is a PNG referenced by the item's ``segmentation`` attribute,
        converted to palette mode ('P') and to ``int32``, with pixels equal
        to ``self.max_id`` replaced by ``-1``. If ``self.transform`` is
        set it is applied to both arrays.

        :param index: index into ``self.dataset_list``
        :return: tuple of the image tensor with axes reordered by
            ``permute(2, 0, 1)`` (presumably HWC to CHW - confirm against
            the dataset) and the label tensor cast with ``long()``
        """
        item = self.dataset_list[index]

        # Image
        img = Image.open(io.BytesIO(self.read_data(item.source_data[0])))

        # Label
        segmentation = item.attributes['segmentation']
        channel_id = segmentation['channel_id']
        file_id = segmentation['file_id']
        label_source = DatalakeFile(
            api=self.client,
            channel_id=channel_id,
            file_id=file_id,
            uri='datalake://{}/{}'.format(channel_id, file_id),
            type='image/png')

        label_file = Image.open(io.BytesIO(self.read_data(label_source)))
        try:
            label = np.asarray(label_file.convert('P'), dtype=np.int32)
        finally:
            label_file.close()

        img = np.array(img)
        # Label id max_id is for unlabeled pixels.
        label[label == self.max_id] = -1

        if self.transform:
            img, label = self.transform(img, label)

        image_tensor = torch.from_numpy(img.copy()).permute(2, 0, 1)
        label_tensor = torch.from_numpy(label.copy()).long()
        return image_tensor, label_tensor
0
 def test_get_content_from_remote(self):
     """``_get_content_from_remote`` extracts the binary content of the download response."""
     datalake_file = DatalakeFile(None, uri=self.uri, type=type)
     # Bypass the real download with a prepared response object.
     datalake_file._do_download = MagicMock(
         return_value=self._build_content_response())
     self.assertEqual(datalake_file._get_content_from_remote(),
                      self.binary_data)
예제 #18
0
 def test_lifetime_invalid(self):
     """Constructing a file with the lifetime ``'invalid'`` raises ``RuntimeError``."""
     init_kwargs = {'uri': self.uri, 'type': type, 'lifetime': 'invalid'}
     with self.assertRaises(RuntimeError):
         DatalakeFile(None, **init_kwargs)
예제 #19
0
 def test_get_json_from_remote(self):
     """``_get_json_from_remote`` extracts the JSON data of the download response."""
     datalake_file = DatalakeFile(None, uri=self.uri, type=type)
     # Bypass the real download with a prepared response object.
     datalake_file._do_download = MagicMock(
         return_value=self._build_json_response())
     self.assertEqual(datalake_file._get_json_from_remote(),
                      self.json_data)
예제 #20
0
 def test_lifetime_1month(self):
     """Assigning the lifetime ``'1month'`` must not raise."""
     target = DatalakeFile(None, uri=self.uri, type=type)
     # No assertion: the test passes as long as the assignment succeeds.
     target.lifetime = '1month'
예제 #21
0
 def test_convert_to_file_id(self):
     """``_convert_to_file_id`` maps the path ``'/<file_id>'`` to ``<file_id>``."""
     datalake_file = DatalakeFile(None, uri=self.uri, type=type)
     slash_prefixed = '/{}'.format(self.file_id)
     self.assertEqual(datalake_file._convert_to_file_id(slash_prefixed),
                      self.file_id)
예제 #22
0
        'personal_access_token': ABEJA_PLATFORM_TOKEN
    }

    client = Client(organization_id=args.organization, credential=credential)
    datalake_client = DatalakeClient(credential=credential)
    dataset = client.get_dataset(args.dataset)

    dataset_list = dataset.dataset_items.list(prefetch=False)

    for d in dataset_list:
        break

    file_content = d.source_data[0].get_content()
    file_like_object = io.BytesIO(file_content)

    img = Image.open(file_like_object)
    img.show()

    uri = d.attributes['segmentation'][0]['uri']
    ftype = 'image/png'
    datalake_file = DatalakeFile(datalake_client,
                                 args.organization,
                                 uri=uri,
                                 type=ftype)

    file_content = datalake_file.get_content()
    file_like_object = io.BytesIO(file_content)

    img = Image.open(file_like_object)
    img.show()
예제 #23
0
 def test_init(self):
     """The constructor exposes ``uri``, ``type``, ``channel_id`` and ``file_id``.

     ``channel_id`` and ``file_id`` are not passed in, so they are
     expected to be derived from the URI.
     """
     created = DatalakeFile(None, uri=self.uri, type=type)
     self.assertEqual(created.uri, self.uri)
     self.assertEqual(created.type, type)
     self.assertEqual(created.channel_id, self.channel_id)
     self.assertEqual(created.file_id, self.file_id)