Пример #1
0
    def __init__(self, bucket_prefix: str, cache_path: str = None) -> None:
        '''
        Abstraction of storage of files on GCS object storage. Do not call
        this constructor but call the GcpFileStorage.setup() factory method

        :param bucket_prefix: prefix of the GCS bucket, to which
        '-<StorageType.PRIVATE.value>' and '-<StorageType.PUBLIC.value>'
        will be appended
        :param cache_path: path to the cache on the local file system. If no
        cache_path is specified, a local cache will not be used. This is the
        configuration to use when running multiple pods in parallel
        '''

        super().__init__(cache_path, cloud_type=CloudType.GCP)

        self._client = storage.Client()

        self.domain = 'storage.cloud.google.com'

        # Both mappings are keyed by the StorageType *value* (not the enum
        # member). NOTE(review): the annotations assume those values are
        # strings -- confirm against the StorageType definition.
        self.buckets: Dict[str, str] = {}
        # We keep a cache of Buckets. We call them 'clients' to
        # remain consistent with the implementations for AWS and Azure
        self.clients: Dict[str, Bucket] = {}
        for storage_type in (StorageType.PRIVATE, StorageType.PUBLIC):
            key = storage_type.value
            self.buckets[key] = f'{bucket_prefix}-{key}'
            self.clients[key] = Bucket(self._client, self.buckets[key])

        _LOGGER.debug(
            'Initialized GCP SDK for buckets %s and %s',
            self.buckets[StorageType.PRIVATE.value],
            self.buckets[StorageType.PUBLIC.value]
        )
Пример #2
0
    def create_bucket(self, bucket_name, requester_pays=None, project=None):
        """Create a new bucket and return it.

        For example:

        .. literalinclude:: snippets.py
            :start-after: [START create_bucket]
            :end-before: [END create_bucket]

        This implements "storage.buckets.insert".

        :type bucket_name: str
        :param bucket_name: The bucket name to create.

        :type requester_pays: bool
        :param requester_pays:
            (Optional) Whether requester pays for API requests for this
            bucket and its blobs. Left untouched on the bucket when ``None``.

        :type project: str
        :param project: (Optional) the project under which the bucket is to
                        be created.  If not passed, uses the project set on
                        the client.

        :rtype: :class:`google.cloud.storage.bucket.Bucket`
        :returns: The newly created bucket.
        :raises: :class:`google.cloud.exceptions.Conflict` if the bucket
                 already exists.
        """
        new_bucket = Bucket(self, name=bucket_name)

        # Only override the billing flag when the caller expressed a choice.
        if requester_pays is not None:
            new_bucket.requester_pays = requester_pays

        new_bucket.create(client=self, project=project)
        return new_bucket
Пример #3
0
    def test_get_bucket_with_object_miss(self):
        """get_bucket() with a Bucket object raises NotFound on an HTTP 404."""
        from google.cloud.exceptions import NotFound
        from google.cloud.storage.bucket import Bucket

        creds = _make_credentials()
        client = self._make_one(project="PROJECT", credentials=creds)

        missing_name = "nonesuch"
        missing_bucket = Bucket(client, missing_name)

        url_parts = [
            client._connection.API_BASE_URL,
            "storage",
            client._connection.API_VERSION,
            "b",
            "nonesuch?projection=noAcl",
        ]
        expected_url = "/".join(url_parts)

        # The fake session answers the single expected request with a 404.
        not_found_response = _make_json_response({}, status=http_client.NOT_FOUND)
        session = _make_requests_session([not_found_response])
        client._http_internal = session

        with self.assertRaises(NotFound):
            client.get_bucket(missing_bucket)

        session.request.assert_called_once_with(
            method="GET",
            url=expected_url,
            data=mock.ANY,
            headers=mock.ANY,
            timeout=self._get_default_timeout(),
        )
Пример #4
0
    def get_bucket(self, bucket_name):
        """Fetch an existing bucket by name.

        The bucket object is built locally and then reloaded through this
        client. If the bucket isn't found, this will raise a
        :class:`google.cloud.exceptions.NotFound`.

        For example::

          >>> try:
          ...   bucket = client.get_bucket('my-bucket')
          ... except google.cloud.exceptions.NotFound:
          ...   print('Sorry, that bucket does not exist!')

        This implements "storage.buckets.get".

        :type bucket_name: string
        :param bucket_name: The name of the bucket to get.

        :rtype: :class:`google.cloud.storage.bucket.Bucket`
        :returns: The bucket matching the name provided.
        :raises: :class:`google.cloud.exceptions.NotFound`
        """
        found = Bucket(self, name=bucket_name)
        found.reload(client=self)
        return found
Пример #5
0
    def download_blob_to_file(self,
                              blob_or_uri,
                              file_obj,
                              start=None,
                              end=None):
        """Download the contents of a blob object or blob URI into a file-like object.

        Args:
            blob_or_uri (Union[ \
            :class:`~google.cloud.storage.blob.Blob`, \
             str, \
            ]):
                The blob resource to pass or URI to download. Anything that
                exposes ``download_to_file`` is treated as a blob; anything
                else is parsed as a ``gs://bucket/object`` URI.
            file_obj (file):
                A file handle to which to write the blob's data.
            start (int):
                Optional. The first byte in a range to be downloaded.
            end (int):
                Optional. The last byte in a range to be downloaded.

        Raises:
            ValueError: If a URI is given whose scheme is not ``gs``.

        Examples:
            Download a blob using a blob resource.

            >>> from google.cloud import storage
            >>> client = storage.Client()

            >>> bucket = client.get_bucket('my-bucket-name')
            >>> blob = storage.Blob('path/to/blob', bucket)

            >>> with open('file-to-download-to', 'wb') as file_obj:
            ...     client.download_blob_to_file(blob, file_obj)  # API request.


            Download a blob using a URI.

            >>> from google.cloud import storage
            >>> client = storage.Client()

            >>> with open('file-to-download-to', 'wb') as file_obj:
            ...     client.download_blob_to_file(
            ...         'gs://bucket_name/path/to/blob', file_obj)


        """
        # Look the method up instead of wrapping the download in
        # ``except AttributeError``: an AttributeError raised *inside* a real
        # blob's download must propagate rather than be mistaken for
        # "this argument is a URI".
        download = getattr(blob_or_uri, "download_to_file", None)
        if download is None:
            scheme, netloc, path, _query, _frag = urlsplit(blob_or_uri)
            if scheme != "gs":
                raise ValueError("URI scheme must be gs")
            bucket = Bucket(self, name=netloc)
            # urlsplit keeps the "/" separating bucket and object, which is
            # not part of the object name; drop exactly that one character.
            download = Blob(path[1:], bucket).download_to_file

        download(file_obj, client=self, start=start, end=end)
Пример #6
0
    def test_create_bucket_with_object_conflict(self):
        """create_bucket() with a Bucket object raises Conflict on HTTP 409."""
        from google.cloud.exceptions import Conflict
        from google.cloud.storage.bucket import Bucket

        project = "PROJECT"
        other_project = "OTHER_PROJECT"
        creds = _make_credentials()
        client = self._make_one(project=project, credentials=creds)

        bucket_name = "bucket-name"
        existing_bucket = Bucket(client, bucket_name)

        # The request must target the explicitly passed project, not the
        # client's own one.
        url_parts = [
            client._connection.API_BASE_URL,
            "storage",
            client._connection.API_VERSION,
            "b?project=%s" % (other_project, ),
        ]
        expected_url = "/".join(url_parts)

        conflict_payload = {"error": {"message": "Conflict"}}
        conflict_response = _make_json_response(
            conflict_payload, status=http_client.CONFLICT)
        session = _make_requests_session([conflict_response])
        client._http_internal = session

        with self.assertRaises(Conflict):
            client.create_bucket(existing_bucket, project=other_project)

        session.request.assert_called_once_with(method="POST",
                                                url=expected_url,
                                                data=mock.ANY,
                                                headers=mock.ANY)
        sent_body = session.request.call_args_list[0][1]["data"]
        expected_body = {"name": bucket_name}
        self.assertEqual(expected_body, json.loads(sent_body))
Пример #7
0
    def test_get_bucket_with_object_hit(self):
        """get_bucket() with a Bucket object returns a Bucket on success."""
        from google.cloud.storage.bucket import Bucket

        creds = _make_credentials()
        client = self._make_one(project="PROJECT", credentials=creds)

        bucket_name = "bucket-name"
        requested = Bucket(client, bucket_name)

        url_parts = [
            client._connection.API_BASE_URL,
            "storage",
            client._connection.API_VERSION,
            "b",
            "%s?projection=noAcl" % (bucket_name,),
        ]
        expected_url = "/".join(url_parts)

        # The fake session answers with a minimal bucket resource.
        payload = {"name": bucket_name}
        session = _make_requests_session([_make_json_response(payload)])
        client._http_internal = session

        result = client.get_bucket(requested)

        self.assertIsInstance(result, Bucket)
        self.assertEqual(result.name, bucket_name)
        session.request.assert_called_once_with(
            method="GET",
            url=expected_url,
            data=mock.ANY,
            headers=mock.ANY,
            timeout=self._get_default_timeout(),
        )
Пример #8
0
    def test_list_blobs(self):
        """list_blobs() with a Bucket object issues one GET on /b/<name>/o."""
        from google.cloud.storage.bucket import Bucket

        BUCKET_NAME = "bucket-name"

        creds = _make_credentials()
        client = self._make_one(project="PROJECT", credentials=creds)
        fake_connection = _make_connection({"items": []})

        # Swap the client's connection property for the fake one for the
        # duration of the call.
        connection_patch = mock.patch(
            "google.cloud.storage.client.Client._connection",
            new_callable=mock.PropertyMock,
        )
        with connection_patch as connection_property:
            connection_property.return_value = fake_connection

            target = Bucket(client, BUCKET_NAME)
            found = list(client.list_blobs(target))

            self.assertEqual(found, [])
            fake_connection.api_request.assert_called_once_with(
                method="GET",
                path="/b/%s/o" % BUCKET_NAME,
                query_params={"projection": "noAcl"},
                timeout=self._get_default_timeout(),
            )
 def __init__(self, authenticate_file, bucket_name):
     """Bind this object to a single bucket.

         authenticate_file: path to the Google auth (service account) JSON file.
         bucket_name: name of the bucket this object is associated with.
     """
     gcs_client = Client.from_service_account_json(authenticate_file)
     # Constructing the Bucket object directly; no lookup call is made here.
     self.__bucket = Bucket(gcs_client, bucket_name)
Пример #10
0
    def get_bucket(self, bucket_name):
        """Fetch an existing bucket by name.

        The bucket object is built locally and then reloaded through this
        client. If the bucket isn't found, this will raise a
        :class:`google.cloud.exceptions.NotFound`.

        For example:

        .. literalinclude:: snippets.py
            :start-after: [START get_bucket]
            :end-before: [END get_bucket]

        This implements "storage.buckets.get".

        :type bucket_name: str
        :param bucket_name: The name of the bucket to get.

        :rtype: :class:`google.cloud.storage.bucket.Bucket`
        :returns: The bucket matching the name provided.
        :raises: :class:`google.cloud.exceptions.NotFound`
        """
        fetched = Bucket(self, name=bucket_name)
        fetched.reload(client=self)
        return fetched
Пример #11
0
    def create_bucket(self, bucket_name):
        """Create a new bucket and return it.

        For example:

        .. code-block:: python

          >>> bucket = client.create_bucket('my-bucket')
          >>> print(bucket)
          <Bucket: my-bucket>

        This implements "storage.buckets.insert".

        :type bucket_name: str
        :param bucket_name: The bucket name to create.

        :rtype: :class:`google.cloud.storage.bucket.Bucket`
        :returns: The newly created bucket.
        :raises: :class:`google.cloud.exceptions.Conflict` if the bucket
                 already exists.
        """
        created = Bucket(self, name=bucket_name)
        created.create(client=self)
        return created
Пример #12
0
    def create_bucket(self, bucket_or_name, requester_pays=None, project=None):
        """API call: create a new bucket via a POST request.

        See
        https://cloud.google.com/storage/docs/json_api/v1/buckets/insert

        Args:
            bucket_or_name (Union[ \
                :class:`~google.cloud.storage.bucket.Bucket`, \
                 str, \
            ]):
                An already-built bucket resource, or the name of the bucket
                to create.
            requester_pays (bool):
                Optional. Whether requester pays for API requests for this
                bucket and its blobs. Left untouched when ``None``.
            project (str):
                Optional. The project under which the bucket is to be
                created. If not passed, uses the project set on the client.

        Returns:
            google.cloud.storage.bucket.Bucket
                The newly created bucket.

        Raises:
            google.cloud.exceptions.Conflict
                If the bucket already exists.

        Examples:
            Create a bucket using a string.

            .. literalinclude:: snippets.py
                :start-after: [START create_bucket]
                :end-before: [END create_bucket]

            Create a bucket using a resource.

            >>> from google.cloud import storage
            >>> client = storage.Client()

            >>> # Set properties on a plain resource object.
            >>> bucket = storage.Bucket(client, "my-bucket-name")
            >>> bucket.location = "europe-west6"
            >>> bucket.storage_class = "COLDLINE"

            >>> # Pass that resource object to the client.
            >>> bucket = client.create_bucket(bucket)  # API request.

        """
        # A ready-made Bucket is used as-is; anything else is taken as a name.
        target = (
            bucket_or_name
            if isinstance(bucket_or_name, Bucket)
            else Bucket(self, name=bucket_or_name)
        )

        if requester_pays is not None:
            target.requester_pays = requester_pays

        target.create(client=self, project=project)
        return target
Пример #13
0
 def _makeOne(self, client=None, name=None, properties=None):
     """Build a Bucket for tests, defaulting to a stub client/connection.

     ``properties`` (or an empty dict when falsy) is installed directly as
     the bucket's ``_properties``.
     """
     from google.cloud.storage.bucket import Bucket

     if client is None:
         client = _Client(_Connection())

     instance = Bucket(client, name=name)
     instance._properties = properties or {}
     return instance
Пример #14
0
    def get_items_from_response(self, response):
        """Factory method which yields :class:`.Bucket` items from a response.

        Each entry under the response's ``'items'`` key (none if the key is
        absent) becomes a bucket owned by ``self.client``, populated with the
        entry's properties.

        :type response: dict
        :param response: The JSON API response for a page of buckets.
        """
        for resource in response.get('items', []):
            page_bucket = Bucket(self.client, resource.get('name'))
            page_bucket._set_properties(resource)
            yield page_bucket
Пример #15
0
    def test_list_blobs_w_all_arguments_and_user_project(self):
        """list_blobs() forwards every optional argument as a query param."""
        from google.cloud.storage.bucket import Bucket

        BUCKET_NAME = "name"
        USER_PROJECT = "user-project-123"
        MAX_RESULTS = 10
        PAGE_TOKEN = "ABCD"
        PREFIX = "subfolder"
        DELIMITER = "/"
        VERSIONS = True
        PROJECTION = "full"
        FIELDS = "items/contentLanguage,nextPageToken"
        EXPECTED = dict(
            maxResults=10,
            pageToken=PAGE_TOKEN,
            prefix=PREFIX,
            delimiter=DELIMITER,
            versions=VERSIONS,
            projection=PROJECTION,
            fields=FIELDS,
            userProject=USER_PROJECT,
        )

        creds = _make_credentials()
        client = self._make_one(project=USER_PROJECT, credentials=creds)
        fake_connection = _make_connection({"items": []})

        # Swap the client's connection property for the fake one for the
        # duration of the call.
        connection_patch = mock.patch(
            "google.cloud.storage.client.Client._connection",
            new_callable=mock.PropertyMock,
        )
        with connection_patch as connection_property:
            connection_property.return_value = fake_connection

            billed_bucket = Bucket(client, BUCKET_NAME, user_project=USER_PROJECT)
            listing = client.list_blobs(
                bucket_or_name=billed_bucket,
                max_results=MAX_RESULTS,
                page_token=PAGE_TOKEN,
                prefix=PREFIX,
                delimiter=DELIMITER,
                versions=VERSIONS,
                projection=PROJECTION,
                fields=FIELDS,
                timeout=42,
            )
            found = list(listing)

            self.assertEqual(found, [])
            fake_connection.api_request.assert_called_once_with(
                method="GET",
                path="/b/%s/o" % BUCKET_NAME,
                query_params=EXPECTED,
                timeout=42,
            )
Пример #16
0
    def _item_to_value(self, item):
        """Turn one JSON bucket resource into a native bucket object.

        The bucket is owned by the parent's client and populated with the
        resource's properties.

        :type item: dict
        :param item: An item to be converted to a bucket.

        :rtype: :class:`.Bucket`
        :returns: The next bucket in the page.
        """
        owner = self._parent.client
        converted = Bucket(owner, item.get('name'))
        converted._set_properties(item)
        return converted
Пример #17
0
    def bucket(self):
        """Return the cached bucket, fetching or creating it on first use.

        When nothing is cached the bucket is looked up by ``_bucket_name``;
        if that raises ``NotFound`` a new bucket is created in
        ``self.project.location`` and the creation is logged.
        """
        if self._bucket:
            return self._bucket

        try:
            self._bucket = self.client.get_bucket(self._bucket_name)
        except NotFound:
            # Cache the handle first, then create it remotely.
            self._bucket = Bucket(client=self.client, name=self._bucket_name)
            self._bucket.create(client=self.client,
                                location=self.project.location)
            message = 'Bucket {} not found and was created.'.format(
                self._bucket.name)
            logging.info(message)

        return self._bucket
Пример #18
0
    def bucket(self, bucket_name):
        """Factory constructor for bucket object.

        .. note::
          No HTTP request is made; a bucket object owned by this client
          is simply instantiated.

        :type bucket_name: str
        :param bucket_name: The name of the bucket to be instantiated.

        :rtype: :class:`google.cloud.storage.bucket.Bucket`
        :returns: The bucket object created.
        """
        handle = Bucket(client=self, name=bucket_name)
        return handle
Пример #19
0
    def test_create_w_extra_properties(self):
        """create_bucket() sends every property set on the Bucket object."""
        from google.cloud.storage.client import Client
        from google.cloud.storage.bucket import Bucket

        BUCKET_NAME = "bucket-name"
        PROJECT = "PROJECT"
        CORS = [
            dict(
                maxAgeSeconds=60,
                methods=["*"],
                origin=["https://example.com/frontend"],
                responseHeader=["X-Custom-Header"],
            )
        ]
        LIFECYCLE_RULES = [{"action": {"type": "Delete"}, "condition": {"age": 365}}]
        LOCATION = "eu"
        LABELS = {"color": "red", "flavor": "cherry"}
        STORAGE_CLASS = "NEARLINE"
        # The resource the API is expected to receive (and, via the fake
        # connection, to echo back).
        DATA = {
            "name": BUCKET_NAME,
            "cors": CORS,
            "lifecycle": {"rule": LIFECYCLE_RULES},
            "location": LOCATION,
            "storageClass": STORAGE_CLASS,
            "versioning": {"enabled": True},
            "billing": {"requesterPays": True},
            "labels": LABELS,
        }

        fake_connection = _make_connection(DATA)
        client = Client(project=PROJECT)
        client._base_connection = fake_connection

        configured = Bucket(client=client, name=BUCKET_NAME)
        configured.cors = CORS
        configured.lifecycle_rules = LIFECYCLE_RULES
        configured.storage_class = STORAGE_CLASS
        configured.versioning_enabled = True
        configured.requester_pays = True
        configured.labels = LABELS
        client.create_bucket(configured, location=LOCATION)

        fake_connection.api_request.assert_called_once_with(
            method="POST",
            path="/b",
            query_params={"project": PROJECT},
            data=DATA,
            _target_object=configured,
            timeout=self._get_default_timeout(),
        )
Пример #20
0
    def get_bucket(self, bucket_or_name):
        """API call: retrieve a bucket via a GET request.

        See
        https://cloud.google.com/storage/docs/json_api/v1/buckets/get

        Args:
            bucket_or_name (Union[ \
                :class:`~google.cloud.storage.bucket.Bucket`, \
                 str, \
            ]):
                An existing bucket resource to refresh, or the name of the
                bucket to retrieve.

        Returns:
            google.cloud.storage.bucket.Bucket
                The bucket matching the name provided.

        Raises:
            google.cloud.exceptions.NotFound
                If the bucket is not found.

        Examples:
            Retrieve a bucket using a string.

            .. literalinclude:: snippets.py
                :start-after: [START get_bucket]
                :end-before: [END get_bucket]

            Get a bucket using a resource.

            >>> from google.cloud import storage
            >>> client = storage.Client()

            >>> # Fetch the bucket once by name.
            >>> bucket = client.get_bucket("my-bucket-name")

            >>> # Time passes. Another program may have modified the bucket
            ... # in the meantime, so you want to get the latest state.
            >>> bucket = client.get_bucket(bucket)  # API request.

        """
        # A ready-made Bucket is reloaded in place; anything else is a name.
        target = (
            bucket_or_name
            if isinstance(bucket_or_name, Bucket)
            else Bucket(self, name=bucket_or_name)
        )

        target.reload(client=self)
        return target
Пример #21
0
def _item_to_bucket(iterator, item):
    """Turn one JSON bucket resource into a native bucket object.

    The bucket is owned by the iterator's client and populated with the
    resource's properties.

    :type iterator: :class:`~google.api_core.page_iterator.Iterator`
    :param iterator: The iterator that has retrieved the item.

    :type item: dict
    :param item: An item to be converted to a bucket.

    :rtype: :class:`.Bucket`
    :returns: The next bucket in the page.
    """
    converted = Bucket(iterator.client, item.get("name"))
    converted._set_properties(item)
    return converted
Пример #22
0
    def bucket(self, bucket_name, user_project=None):
        """Factory constructor for bucket object.

        .. note::
          No HTTP request is made; a bucket object owned by this client
          is simply instantiated.

        :type bucket_name: str
        :param bucket_name: The name of the bucket to be instantiated.

        :type user_project: str
        :param user_project: (Optional) the project ID to be billed for API
                             requests made via the bucket.

        :rtype: :class:`google.cloud.storage.bucket.Bucket`
        :returns: The bucket object created.
        """
        handle = Bucket(
            client=self,
            name=bucket_name,
            user_project=user_project,
        )
        return handle
Пример #23
0
    def _bucket_arg_to_bucket(self, bucket_or_name):
        """Helper to normalise a bucket-or-name argument into a bucket.

        Args:
            bucket_or_name (Union[ \
                :class:`~google.cloud.storage.bucket.Bucket`, \
                 str, \
            ]):
                An existing bucket resource, or the name to build one from.

        Returns:
            google.cloud.storage.bucket.Bucket
                The given bucket unchanged, or a new bucket object owned by
                this client when a name was passed.
        """
        if isinstance(bucket_or_name, Bucket):
            return bucket_or_name
        return Bucket(self, name=bucket_or_name)
Пример #24
0
    def test_create_bucket_w_object_success(self):
        """create_bucket() with a Bucket object posts its properties."""
        from google.cloud.storage.bucket import Bucket

        project = "PROJECT"
        creds = _make_credentials()
        client = self._make_one(project=project, credentials=creds)

        bucket_name = "bucket-name"
        configured = Bucket(client, bucket_name)
        configured.storage_class = "COLDLINE"
        configured.requester_pays = True

        url_parts = [
            client._connection.API_BASE_URL,
            "storage",
            client._connection.API_VERSION,
            "b?project=%s" % (project, ),
        ]
        expected_url = "/".join(url_parts)

        # The same resource is both the expected request body and the fake
        # server's response.
        expected_body = {
            "name": bucket_name,
            "billing": {
                "requesterPays": True
            },
            "storageClass": "COLDLINE",
        }
        session = _make_requests_session([_make_json_response(expected_body)])
        client._http_internal = session

        created = client.create_bucket(configured)

        self.assertIsInstance(created, Bucket)
        self.assertEqual(created.name, bucket_name)
        self.assertTrue(created.requester_pays)
        session.request.assert_called_once_with(method="POST",
                                                url=expected_url,
                                                data=mock.ANY,
                                                headers=mock.ANY,
                                                timeout=mock.ANY)
        sent_body = session.request.call_args_list[0][1]["data"]
        self.assertEqual(expected_body, json.loads(sent_body))
Пример #25
0
    def create_bucket(self, bucket_name):
        """Create a new bucket and return it.

        For example:

        .. literalinclude:: storage_snippets.py
            :start-after: [START create_bucket]
            :end-before: [END create_bucket]

        This implements "storage.buckets.insert".

        :type bucket_name: str
        :param bucket_name: The bucket name to create.

        :rtype: :class:`google.cloud.storage.bucket.Bucket`
        :returns: The newly created bucket.
        :raises: :class:`google.cloud.exceptions.Conflict` if the bucket
                 already exists.
        """
        created = Bucket(self, name=bucket_name)
        created.create(client=self)
        return created
Пример #26
0
def sign(duration: str, key_file: click.File, resource: str) -> None:
    """
    Generate a signed URL that embeds authentication data
    so the URL can be used by someone who does not have a Google account.

    This tool exists to overcome a shortcoming of gsutil signurl that limits
    expiration to 7 days only.

    KEY_FILE should be a path to a JSON file containing service account private key.
    See gsutil signurl --help for details

    RESOURCE is a GCS location in the form <bucket>/<path>
    (do not add either "gs://" or "http://...")

    Example: gcs-signurl /tmp/creds.json /foo-bucket/bar-file.txt
    """
    # Local import: the surrounding file is only known to expose the
    # ``datetime`` class itself.
    from datetime import timezone

    # Leading slashes are tolerated; the first path segment is the bucket.
    bucket_name, _, path = resource.lstrip("/").partition("/")
    creds = service_account.Credentials.from_service_account_file(
        key_file.name)

    # Use an aware UTC timestamp. google-cloud-storage interprets a naive
    # datetime as UTC, so naive local time would shift the expiry by the
    # machine's UTC offset.
    till = datetime.now(timezone.utc) + _DurationToTimeDelta(duration)

    # Ignoring potential warning about end user credentials.
    # We don't actually do any operations on the client, but
    # unfortunately the only public API in google-cloud-storage package
    # requires building client->bucket->blob
    message = "Your application has authenticated using end user credentials from Google Cloud SDK"
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=message)
        client = Client()
    bucket = Bucket(client, bucket_name)
    blob = Blob(path, bucket)

    # Not passing version argument - to support compatibility with
    # google-cloud-storage<=1.14.0. They default to version 2 and hopefully
    # will not change it anytime soon.
    signed_url = blob.generate_signed_url(expiration=till, credentials=creds)
    click.echo(signed_url)
Пример #27
0
def _get_client_bucket(name: str) -> Tuple[Client, Bucket]:
    """Create a default client and a bucket object named ``name`` bound to it.

    Returns the pair ``(client, bucket)``.
    """
    gcs_client = Client()
    return gcs_client, Bucket(gcs_client, name)
Пример #28
0
def main(checkpoint, **args):
    """Load a SampleRNN checkpoint and run the generator on it.

    :param checkpoint: path to the checkpoint file; it is passed to
        ``torch.load`` and also stored in the parameters as ``exp``.
    :param args: extra options merged over the defaults below. The body
        reads ``args['bucket']`` and ``args['cuda']`` unconditionally, so
        callers must supply both; ``debug`` is optional.
    """
    # 'debug' switches the log level from INFO to NOTSET.
    task_id = setup_logging(
        'gen', logging.NOTSET if args.get('debug', False) else logging.INFO)

    # Defaults first; keyword arguments (exp and **args) override them.
    params = dict(
        {
            'n_rnn': 3,
            'dim': 1024,
            'learn_h0': False,
            'q_levels': 256,
            'weight_norm': True,
            'frame_sizes': [16, 16, 4],
            'sample_rate': 16000,
            'n_samples': 1,
            'sample_length': 16000 * 60 * 4,
            'sampling_temperature': 1,
            'q_method': QMethod.LINEAR,
        },
        exp=checkpoint,
        **args)
    logging.info(str(params))
    logging.info('booting')

    # dataset = storage_client.list_blobs(bucket, prefix=path)
    # for blob in dataset:
    #   blob.download_to_filename(blob.name)

    # Stays None when no bucket is configured; upload() below checks this.
    bucket = None

    if args['bucket']:
        logging.debug('setup google storage bucket {}'.format(args['bucket']))
        storage_client = storage.Client()
        bucket = Bucket(storage_client, args['bucket'])

        # NOTE(review): presumably fetches the checkpoint from the bucket
        # before it is loaded below -- confirm in preload_checkpoint.
        preload_checkpoint(checkpoint, storage_client, bucket)

    # Results directory: two levels above the checkpoint, named by task id.
    results_path = os.path.abspath(
        os.path.join(checkpoint, os.pardir, os.pardir, task_id))
    ensure_dir_exists(results_path)

    checkpoint = os.path.abspath(checkpoint)

    # The lambda's ``storage`` parameter shadows the ``storage`` module used
    # above, but only inside the lambda. Tensors are moved to CUDA device 0
    # when args['cuda'] is truthy and left where they were loaded otherwise.
    tmp_pretrained_state = torch.load(
        checkpoint,
        map_location=lambda storage, loc: storage.cuda(0)
        if args['cuda'] else storage)

    # Load all tensors onto GPU 1
    # torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))

    pretrained_state = OrderedDict()

    for k, v in tmp_pretrained_state.items():
        # Delete "model." from key names since loading the checkpoint automatically attaches it
        layer_name = k.replace("model.", "")
        pretrained_state[layer_name] = v
        # print("k: {}, layer_name: {}, v: {}".format(k, layer_name, np.shape(v)))

    # Create model with same parameters as used in training
    model = SampleRNN(frame_sizes=params['frame_sizes'],
                      n_rnn=params['n_rnn'],
                      dim=params['dim'],
                      learn_h0=params['learn_h0'],
                      q_levels=params['q_levels'],
                      weight_norm=params['weight_norm'])
    if params['cuda']:
        model = model.cuda()

    # Load pretrained model
    model.load_state_dict(pretrained_state)

    def upload(file_path):
        """Upload ``file_path`` to the bucket; a no-op without a bucket."""
        if bucket is None:
            return

        # The blob name is the path with the current working directory
        # prefix removed (original note: "remove prefix /app").
        name = file_path.replace(os.path.abspath(os.curdir) + '/', '')
        blob = Blob(name, bucket)
        logging.info('uploading {}'.format(name))
        blob.upload_from_filename(file_path)

    # Only the second element of the quantizer pair is needed here.
    (_, dequantize) = quantizer(params['q_method'])
    gen = Gen(Runner(model), params['cuda'])
    gen.register_plugin(
        GeneratorPlugin(results_path, params['n_samples'],
                        params['sample_length'], params['sample_rate'],
                        params['q_levels'], dequantize,
                        params['sampling_temperature'], upload))

    gen.run()
Пример #29
0
def main(exp, dataset, **params):
    """Train a SampleRNN predictor on a dataset.

    :param exp: experiment identifier, stored in the parameters as ``exp``.
    :param dataset: dataset name; joined onto ``params['datasets_path']`` to
        locate the data.
    :param params: overrides merged on top of the module-level
        ``default_params``.
    """
    params = dict(default_params, exp=exp, dataset=dataset, **params)
    print(params)

    # Both stay None when no bucket is configured; they are still passed to
    # load_last_checkpoint below, and upload() checks ``bucket``.
    storage_client = None
    bucket = None

    path = os.path.join(params['datasets_path'], params['dataset'])

    if params['bucket']:
        storage_client = storage.Client()
        bucket = Bucket(storage_client, params['bucket'])
        # NOTE(review): presumably downloads the dataset from the bucket
        # into ``path`` -- confirm in preload_dataset.
        preload_dataset(path, storage_client, bucket)

    results_path = setup_results_dir(params)
    tee_stdout(os.path.join(results_path, 'log'))

    (quantize, dequantize) = quantizer(params['q_method'])
    model = SampleRNN(frame_sizes=params['frame_sizes'],
                      n_rnn=params['n_rnn'],
                      dim=params['dim'],
                      learn_h0=params['learn_h0'],
                      q_levels=params['q_levels'],
                      weight_norm=params['weight_norm'])
    predictor = Predictor(model, dequantize)
    # Note the identity check: any value other than the literal False
    # (including None or 0) enables CUDA here.
    if params['cuda'] is not False:
        print(params['cuda'])
        model = model.cuda()
        predictor = predictor.cuda()

    optimizer = gradient_clipping(
        torch.optim.Adam(predictor.parameters(), lr=params['learning_rate']))

    # data_loader is called below with (start, end, eval=...) fractions:
    # [0, val_split) for training, [val_split, test_split) for validation
    # and [test_split, 1) for test.
    data_loader = make_data_loader(path, model.lookback, quantize, params)
    test_split = 1 - params['test_frac']
    val_split = test_split - params['val_frac']

    trainer = Trainer(predictor,
                      sequence_nll_loss_bits,
                      optimizer,
                      data_loader(0, val_split, eval=False),
                      cuda=params['cuda'])

    # Restore epoch/iteration counters and weights when a checkpoint exists.
    checkpoints_path = os.path.join(results_path, 'checkpoints')
    checkpoint_data = load_last_checkpoint(checkpoints_path, storage_client,
                                           bucket)
    if checkpoint_data is not None:
        (state_dict, epoch, iteration) = checkpoint_data
        trainer.epochs = epoch
        trainer.iterations = iteration
        predictor.load_state_dict(state_dict)

    trainer.register_plugin(
        TrainingLossMonitor(smoothing=params['loss_smoothing']))
    trainer.register_plugin(
        ValidationPlugin(data_loader(val_split, test_split, eval=True),
                         data_loader(test_split, 1, eval=True)))
    trainer.register_plugin(SchedulerPlugin(params['lr_scheduler_step']))

    def upload(file_path):
        """Best-effort upload of ``file_path``; a no-op without a bucket."""
        if bucket is None:
            return

        # The blob name is the path with the current working directory
        # prefix removed.
        name = file_path.replace(os.path.abspath(os.curdir) + '/', '')
        blob = Blob(name, bucket)
        try:
            blob.upload_from_filename(file_path, timeout=300)
        except Exception as e:
            # Upload failures are printed and otherwise ignored, so they do
            # not interrupt the caller.
            print(str(e))

    trainer.register_plugin(AbsoluteTimeMonitor())

    samples_path = os.path.join(results_path, 'samples')
    trainer.register_plugin(
        SaverPlugin(checkpoints_path, params['keep_old_checkpoints'], upload))
    trainer.register_plugin(
        GeneratorPlugin(samples_path,
                        params['n_samples'],
                        params['sample_length'],
                        params['sample_rate'],
                        params['q_levels'],
                        dequantize,
                        params['sampling_temperature'],
                        upload=upload))
    trainer.register_plugin(
        Logger(['training_loss', 'validation_loss', 'test_loss', 'time']))
    trainer.register_plugin(
        StatsPlugin(
            results_path,
            iteration_fields=[
                'training_loss',
                #('training_loss', 'running_avg'),
                'time'
            ],
            epoch_fields=[
                'training_loss', ('training_loss', 'running_avg'),
                'validation_loss', 'test_loss', 'time'
            ],
            plots={
                'loss': {
                    'x':
                    'iteration',
                    'ys': [
                        'training_loss',
                        # ('training_loss', 'running_avg'),
                        'validation_loss',
                        'test_loss'
                    ],
                    'log_y':
                    True
                }
            }))

    init_comet(params, trainer, samples_path, params['n_samples'],
               params['sample_rate'])

    trainer.run(params['epoch_limit'])
Пример #30
0
 def bucket(self):
     """Return the bucket object, constructing and caching it on first use."""
     if self._bucket is not None:
         return self._bucket

     # Nothing cached yet: build the object from the configured name.
     self._bucket = Bucket(self.client, name=self.bucket_name)
     return self._bucket