예제 #1
0
    def test_09_add_command_with_metric_for_wrong_entity(self):
        """Check that ``--metric`` options given for a dataset are not stored.

        A dataset entity is added with two ``--metric`` options; afterwards
        the workspace spec is expected to contain no ``metrics`` entries.
        """
        repo_type = DATASETS
        self.set_up_add()

        create_spec(self, repo_type, self.tmp_dir)
        workspace = os.path.join(self.tmp_dir, repo_type, DATASET_NAME)

        os.makedirs(os.path.join(workspace, 'data'))

        create_file(workspace, 'file1', '0')

        metrics_options = '--metric Accuracy 1 --metric Recall 2'

        self.assertIn(
            output_messages['INFO_ADDING_PATH'] % repo_type,
            check_output(MLGIT_ADD %
                         (repo_type, DATASET_NAME, metrics_options)))
        index = os.path.join(ML_GIT_DIR, repo_type, 'index', 'metadata',
                             DATASET_NAME, 'INDEX.yaml')
        self._check_index(index, ['data/file1'], [])

        with open(os.path.join(workspace, DATASET_NAME + '.spec')) as spec:
            spec_file = yaml_processor.load(spec)
        metrics = spec_file[get_spec_key(repo_type)].get('metrics', {})
        # assertEqual shows the unexpected metrics in the failure message,
        # which assertTrue(metrics == {}) would hide.
        self.assertEqual(metrics, {})
예제 #2
0
 def _add_associate_entity_metadata(self, metadata, specs):
     """Store the head tag and sha of related entities in ``metadata``.

     A dataset listed in ``specs`` is associated when the current repo type
     is labels or models; labels listed in ``specs`` are associated when the
     current repo type is models. Nothing is written for a related entity
     whose refs head has no tag.
     """
     dataset = EntityType.DATASETS.value
     labels = EntityType.LABELS.value
     model = EntityType.MODELS.value
     entity_spec_key = get_spec_key(self.__repo_type)

     # (related entity type, repo types allowed to reference it, spec key)
     associations = (
         (dataset, [labels, model], DATASET_SPEC_KEY),
         (labels, [model], LABELS_SPEC_KEY),
     )
     for related, allowed_types, related_key in associations:
         if related not in specs or self.__repo_type not in allowed_types:
             continue
         related_spec = specs[related]
         refs_path = get_refs_path(self.__config, related)
         head_tag, head_sha = Refs(refs_path, related_spec, related).head()
         if head_tag is None:
             continue
         if related == dataset:
             template = output_messages['INFO_ASSOCIATE_DATASETS']
         else:
             template = 'Associate labels [%s]-[%s] to the %s.'
         log.info(template % (related_spec, head_tag, self.__repo_type),
                  class_name=LOCAL_REPOSITORY_CLASS_NAME)
         metadata[entity_spec_key][related_key] = {
             'tag': head_tag,
             'sha': head_sha
         }
예제 #3
0
    def test_33_create_entity_and_gdriveh_storage_with_wizard(self):
        """Create a dataset through the wizard using a gdriveh storage.

        After the wizard run, the bucket must be listed under the gdriveh
        section of the project config and the generated spec must point to
        ``<storage_type>://<bucket_name>``.
        """
        entity_type = DATASETS
        entity_name = entity_type + '-ex'
        self.assertIn(
            output_messages['INFO_INITIALIZED_PROJECT_IN'] % self.tmp_dir,
            check_output(MLGIT_INIT))
        self.assertIn(
            output_messages['INFO_ADD_REMOTE'] %
            (os.path.join(self.tmp_dir, GIT_PATH), entity_type),
            check_output(MLGIT_REMOTE_ADD %
                         (entity_type,
                          (os.path.join(self.tmp_dir, GIT_PATH)))))
        self.assertNotIn(ERROR_MESSAGE,
                         check_output(MLGIT_ENTITY_INIT % entity_type))
        bucket_name = 'test-wizard'
        storage_type = StorageType.GDRIVEH.value
        runner = CliRunner()
        # The wizard answers are supplied through stdin, one per line.
        runner.invoke(entity.datasets,
                      ['create', entity_name, '--wizard'],
                      input='\n'.join([
                          'category', 'strict', 'X', GDRIVEH, bucket_name, ''
                      ]))

        with open(os.path.join(self.tmp_dir, ML_GIT_DIR, 'config.yaml'),
                  'r') as c:
            config = yaml_processor.load(c)
        # assertIn reports the missing bucket and the container on failure.
        self.assertIn(bucket_name, config[STORAGE_CONFIG_KEY][GDRIVEH])

        spec = os.path.join(self.tmp_dir, entity_type, entity_name,
                            entity_name + '.spec')
        with open(spec, 'r') as s:
            spec_file = yaml_processor.load(s)
        self.assertEqual(
            spec_file[get_spec_key(entity_type)]['manifest']
            [STORAGE_SPEC_KEY], storage_type + '://' + bucket_name)
예제 #4
0
    def __commit_metadata(self, full_metadata_path, index_path, metadata,
                          specs, ws_path):
        """Fill in the manifest section of ``metadata`` and commit the spec.

        Copies the entity README from the index to ``full_metadata_path``
        when one exists, records the manifest file name, the workspace size
        and the file amount, lets the plugin add its own metadata, associates
        related entities and finally commits the spec.

        Returns the storage value read from the manifest.
        Re-raises any exception raised while copying the README.
        """
        idx_path = os.path.join(index_path, 'metadata', self._spec)
        log.debug(output_messages['DEBUG_COMMIT_SPEC'] % self._spec,
                  class_name=METADATA_CLASS_NAME)

        # Carry the README over to the metadata path when the index has one.
        readme_name = 'README.md'
        index_readme = os.path.join(idx_path, readme_name)
        if os.path.exists(index_readme):
            metadata_readme = os.path.join(full_metadata_path, readme_name)
            try:
                shutil.copy2(index_readme, metadata_readme)
            except Exception as e:
                log.error(output_messages['ERROR_COULD_NOT_FIND_README'],
                          class_name=METADATA_CLASS_NAME)
                raise e

        amount, workspace_size = self._get_amount_and_size_of_workspace_files(
            full_metadata_path, ws_path)

        # Update the manifest in place; it is the same object stored in
        # ``metadata``.
        manifest = metadata[get_spec_key(self.__repo_type)]['manifest']
        manifest['files'] = MANIFEST_FILE
        manifest['size'] = humanize.naturalsize(workspace_size)
        manifest['amount'] = amount
        storage = manifest[STORAGE_SPEC_KEY]

        PluginCaller(manifest).call(ADD_METADATA, ws_path, manifest)

        # Add the related-entity references, then persist the spec.
        self._add_associate_entity_metadata(metadata, specs)
        self.__commit_spec(full_metadata_path, metadata)

        return storage
예제 #5
0
    def check_created_folders(self,
                              entity_type,
                              storage_type=S3H,
                              version=1,
                              bucket_name='fake_storage'):
        """Assert that the ``<entity_type>-ex`` entity was created as expected.

        Verifies that the data folder, the spec file and the README exist,
        that the spec holds the expected storage, name and version, and that
        ``entity_type`` appears in the project config.
        """
        entity_name = entity_type + '-ex'
        entity_dir = os.path.join(self.tmp_dir, entity_type, entity_name)
        folder_data = os.path.join(entity_dir, 'data')
        spec = os.path.join(entity_dir, entity_name + '.spec')
        readme = os.path.join(entity_dir, 'README.md')

        # Check existence before reading, so a missing file is reported as an
        # assertion failure rather than a FileNotFoundError from open().
        self.assertTrue(os.path.exists(folder_data))
        self.assertTrue(os.path.exists(spec))
        self.assertTrue(os.path.exists(readme))

        with open(spec, 'r') as s:
            spec_file = yaml_processor.load(s)
        entity_spec = spec_file[get_spec_key(entity_type)]
        self.assertEqual(entity_spec['manifest'][STORAGE_SPEC_KEY],
                         storage_type + '://' + bucket_name)
        self.assertEqual(entity_spec['name'], entity_name)
        self.assertEqual(entity_spec['version'], version)

        with open(os.path.join(self.tmp_dir, ML_GIT_DIR, 'config.yaml'),
                  'r') as y:
            config = yaml_processor.load(y)
        self.assertIn(entity_type, config)
예제 #6
0
파일: helper.py 프로젝트: tspthomas/ml-git
def create_spec(self,
                model,
                tmpdir,
                version=1,
                mutability=STRICT,
                storage_type=STORAGE_TYPE,
                artifact_name=None):
    """Write a sample spec file for an entity and assert that it exists.

    The spec is written to ``<tmpdir>/<model>/<artifact_name>/
    <artifact_name>.spec``; ``artifact_name`` falls back to ``<model>-ex``
    when it is empty or not given.
    """
    artifact_name = artifact_name or f'{model}-ex'
    spec_key = get_spec_key(model)
    entity_spec = {
        'categories': ['computer-vision', 'images'],
        'mutability': mutability,
        'manifest': {
            'files': 'MANIFEST.yaml',
            STORAGE_SPEC_KEY: '%s://mlgit' % storage_type
        },
        'name': artifact_name,
        'version': version
    }
    spec_file = os.path.join(tmpdir, model, artifact_name,
                             f'{artifact_name}.spec')
    with open(spec_file, 'w') as stream:
        yaml_processor.dump({spec_key: entity_spec}, stream)
    self.assertTrue(os.path.exists(spec_file))
예제 #7
0
    def __commit_metadata(self, full_metadata_path, index_path, metadata,
                          specs, ws_path):
        """Fill in the manifest section of ``metadata`` and commit the spec.

        Hands the entity README and the ml-git ignore file from the index to
        ``_copy_to_metadata_path``, records the manifest file name, the
        workspace size and the file amount, lets the plugin add its own
        metadata, associates related entities and finally commits the spec.

        Returns the storage value read from the manifest.
        """
        idx_path = os.path.join(index_path, 'metadata', self._spec)
        log.debug(output_messages['DEBUG_COMMIT_SPEC'] % self._spec,
                  class_name=METADATA_CLASS_NAME)

        # README first, then the ignore file, both taken from the index.
        for file_name in ('README.md', MLGIT_IGNORE_FILE_NAME):
            self._copy_to_metadata_path(os.path.join(idx_path, file_name),
                                        full_metadata_path, file_name)

        amount, workspace_size = self._get_amount_and_size_of_workspace_files(
            full_metadata_path, ws_path)

        # Update the manifest in place; it is the same object stored in
        # ``metadata``.
        manifest = metadata[get_spec_key(self.__repo_type)]['manifest']
        manifest['files'] = MANIFEST_FILE
        manifest['size'] = humanize.naturalsize(workspace_size)
        manifest['amount'] = amount
        storage = manifest[STORAGE_SPEC_KEY]

        PluginCaller(manifest).call(ADD_METADATA, ws_path, manifest)

        # Add the related-entity references, then persist the spec.
        self._add_associate_entity_metadata(metadata, specs)
        self.__commit_spec(full_metadata_path, metadata)

        return storage
예제 #8
0
    def test_10_add_command_with_metric_file(self):
        """Check that metrics read from ``--metrics-file`` end up in the spec.

        A CSV file with ``Accuracy`` and ``Recall`` values is passed to the
        add command of a model entity; both values are then expected in the
        ``metrics`` section of the workspace spec.
        """
        repo_type = MODELS
        entity_name = '{}-ex'.format(repo_type)
        self.set_up_add(repo_type)

        create_spec(self, repo_type, self.tmp_dir)
        workspace = os.path.join(self.tmp_dir, repo_type, entity_name)

        os.makedirs(os.path.join(workspace, 'data'))

        create_file(workspace, 'file1', '0')

        csv_file = os.path.join(self.tmp_dir, 'metrics.csv')

        self.create_csv_file(csv_file, {'Accuracy': 1, 'Recall': 2})

        metrics_options = '--metrics-file="{}"'.format(csv_file)

        self.assertIn(
            output_messages['INFO_ADDING_PATH'] % repo_type,
            check_output(MLGIT_ADD %
                         (repo_type, entity_name, metrics_options)))
        index = os.path.join(ML_GIT_DIR, repo_type, 'index', 'metadata',
                             entity_name, 'INDEX.yaml')
        self._check_index(index, ['data/file1'], [])

        with open(os.path.join(workspace, entity_name + '.spec')) as spec:
            spec_file = yaml_processor.load(spec)
        metrics = spec_file[get_spec_key(repo_type)].get('metrics', {})
        # Dedicated assertions print the compared values on failure, unlike
        # assertTrue/assertFalse over an ``==`` expression.
        self.assertNotEqual(metrics, {})
        self.assertEqual(metrics['Accuracy'], 1)
        self.assertEqual(metrics['Recall'], 2)
예제 #9
0
 def _check_spec_version(self, repo_type, expected_version):
     """Assert that the workspace spec of ``<repo_type>-ex`` has a version.

     The version defaults to 0 when the spec has no ``version`` entry.
     """
     entity_name = '{}-ex'.format(repo_type)
     # Resolve the workspace under the entity's own repo type, so the path
     # stays consistent with ``entity_name`` for non-dataset entities too.
     workspace = os.path.join(self.tmp_dir, repo_type, entity_name)
     with open(os.path.join(workspace, entity_name + '.spec')) as spec:
         spec_file = yaml_processor.load(spec)
     version = spec_file[get_spec_key(repo_type)].get('version', 0)
     # assertEquals is a deprecated alias that was removed in Python 3.12.
     self.assertEqual(version, expected_version)
예제 #10
0
    def test_30_create_entity_and_s3h_storage_with_wizard(self):
        """Create a dataset through the wizard using an s3h storage.

        After the wizard run, the bucket must be configured under the s3h
        section of the project config with the given profile, endpoint URL
        and region, and the entity folder must contain the data directory,
        the README and a spec pointing to ``<storage_type>://<bucket_name>``.
        """
        entity_type = DATASETS
        entity_name = entity_type + '-ex'
        self.assertIn(
            output_messages['INFO_INITIALIZED_PROJECT_IN'] % self.tmp_dir,
            check_output(MLGIT_INIT))
        self.assertIn(
            output_messages['INFO_ADD_REMOTE'] %
            (os.path.join(self.tmp_dir, GIT_PATH), entity_type),
            check_output(MLGIT_REMOTE_ADD %
                         (entity_type,
                          (os.path.join(self.tmp_dir, GIT_PATH)))))
        self.assertNotIn(ERROR_MESSAGE,
                         check_output(MLGIT_ENTITY_INIT % entity_type))

        bucket_name = 'test-wizard'
        endpoint_url = 'www.url.com'
        region = 'us-east-1'
        storage_type = StorageType.S3H.value
        runner = CliRunner()
        # The wizard answers are supplied through stdin, one per line.
        runner.invoke(entity.datasets,
                      ['create', entity_name, '--wizard'],
                      input='\n'.join([
                          'category', 'strict', 'X', storage_type, bucket_name,
                          PROFILE, endpoint_url, region, ''
                      ]))

        with open(os.path.join(self.tmp_dir, ML_GIT_DIR, 'config.yaml'),
                  'r') as c:
            config = yaml_processor.load(c)
        s3h_config = config[STORAGE_CONFIG_KEY][S3H]
        # assertIn reports the missing bucket and the container on failure.
        self.assertIn(bucket_name, s3h_config)
        bucket_config = s3h_config[bucket_name]
        self.assertEqual(PROFILE,
                         bucket_config['aws-credentials']['profile'])
        self.assertEqual(endpoint_url, bucket_config['endpoint-url'])
        self.assertEqual(region, bucket_config['region'])

        entity_dir = os.path.join(self.tmp_dir, entity_type, entity_name)
        folder_data = os.path.join(entity_dir, 'data')
        spec = os.path.join(entity_dir, entity_name + '.spec')
        readme = os.path.join(entity_dir, 'README.md')

        # Check existence before reading, so a missing file is reported as an
        # assertion failure rather than a FileNotFoundError from open().
        self.assertTrue(os.path.exists(folder_data))
        self.assertTrue(os.path.exists(spec))
        self.assertTrue(os.path.exists(readme))

        with open(spec, 'r') as s:
            spec_file = yaml_processor.load(s)
        self.assertEqual(
            spec_file[get_spec_key(entity_type)]['manifest']
            [STORAGE_SPEC_KEY], storage_type + '://' + bucket_name)
예제 #11
0
    def metadata_tag(self, metadata):
        """Build the tag for ``metadata``.

        The tag is the value returned by ``__metadata_spec`` followed by the
        entity version, joined with a double underscore.
        """
        separator = '__'
        entity_spec_key = get_spec_key(self.__repo_type)
        spec_prefix = self.__metadata_spec(metadata, separator)
        version = str(metadata[entity_spec_key]['version'])

        new_tag = separator.join([spec_prefix, version])

        log.debug(output_messages['DEBUG_NEW_TAG_CREATED'] % new_tag,
                  class_name=METADATA_CLASS_NAME)
        return new_tag
예제 #12
0
파일: config.py 프로젝트: HPInc/ml-git
def create_workspace_tree_structure(repo_type,
                                    artifact_name,
                                    categories,
                                    storage_type,
                                    bucket_name,
                                    version,
                                    imported_dir,
                                    mutability,
                                    entity_dir=''):
    """Create the workspace folders and initial spec of a new entity.

    The entity lives in ``<root>/<repo_type>/<entity_dir>/<artifact_name>``.
    Its ``data`` folder is either imported from ``imported_dir`` or created
    empty. When no spec file exists yet, the spec and an empty README are
    written.

    Returns True when the spec was written, False when a spec file was
    already present.

    Raises Exception when the entity path is not inside the repo type
    directory, and PermissionError when the entity path already exists.
    """
    # Resolve the entity location relative to the project root.
    repo_type_dir = os.path.join(get_root_path(), repo_type)
    artifact_path = os.path.join(repo_type_dir, entity_dir, artifact_name)
    if not path_is_parent(repo_type_dir, artifact_path):
        raise Exception(
            output_messages['ERROR_INVALID_ENTITY_DIR'].format(entity_dir))
    if os.path.exists(artifact_path):
        raise PermissionError(output_messages['INFO_ENTITY_NAME_EXISTS'])

    # Populate the data folder: import the given directory or start empty.
    data_path = os.path.join(artifact_path, 'data')
    if imported_dir is None:
        os.makedirs(data_path)
    else:
        import_dir(imported_dir, data_path)

    spec_path = os.path.join(artifact_path, artifact_name + SPEC_EXTENSION)
    readme_path = os.path.join(artifact_path, 'README.md')
    spec_already_present = os.path.isfile(spec_path)

    bucket = FAKE_STORAGE if bucket_name is None else bucket_name
    storage = '%s://%s' % (storage_type, bucket)
    entity_spec_key = get_spec_key(repo_type)
    spec_structure = {
        entity_spec_key: {
            'categories': categories,
            'manifest': {
                STORAGE_SPEC_KEY: storage
            },
            'name': artifact_name,
            'mutability': mutability,
            'version': version
        }
    }

    if spec_already_present:
        return False

    # Write the spec and create an empty README next to it.
    yaml_save(spec_structure, spec_path)
    with open(readme_path, 'w'):
        pass
    return True
예제 #13
0
    def _get_metrics(self, spec, sha):
        """Render the metrics stored in a spec as a two-column table.

        Returns an empty string when the spec has no metrics; otherwise a
        string made of a newline, the metrics key, a colon and the table.
        """
        spec_file = self._get_spec_content(spec, sha)
        entity_spec = spec_file[get_spec_key(self.__repo_type)]
        metrics = entity_spec.get(PERFORMANCE_KEY, {})
        if not metrics:
            return ''

        table = PrettyTable()
        table.field_names = ['Name', 'Value']
        # Left-align both columns.
        for column in ('Name', 'Value'):
            table.align[column] = 'l'
        for metric_name, metric_value in metrics.items():
            table.add_row([metric_name, metric_value])
        return '\n{}:\n{}'.format(PERFORMANCE_KEY, table.get_string())
예제 #14
0
    def __metadata_spec(self, metadata, sep):
        """Build the ``<categories><sep><name>`` part of an entity tag.

        ``categories`` may be a list, whose items are joined with ``sep``,
        or a single value used as is.

        Returns the joined string, or None (after logging an error) when the
        entity has no categories or the final join fails.
        """
        entity_spec_key = get_spec_key(self.__repo_type)
        entity_metadata = metadata[entity_spec_key]
        cats = entity_metadata['categories']
        if cats is None:
            log.error(output_messages['ERROR_ENTITY_NEEDS_CATATEGORY'])
            return None

        # isinstance also accepts list subclasses, which an exact
        # ``type(cats) is list`` comparison would treat as a single category.
        categories = sep.join(cats) if isinstance(cats, list) else cats

        # Generate Spec from Dataset Name & Categories
        try:
            return sep.join([categories, entity_metadata['name']])
        except Exception:
            # Best effort: report the expected spec layout instead of raising.
            log.error(output_messages['ERROR_INVALID_DATASET_SPEC'] %
                      (get_sample_spec_doc('somebucket', entity_spec_key)))
            return None
예제 #15
0
    def metadata_print(metadata_file, spec_name):
        """Print the metadata section of every entity type found in a file.

        For each entity type whose spec key is present in the loaded
        metadata, a ``-- <entity type> : <spec_name> --`` header is printed
        followed by the YAML text of that section. Entity types whose
        section is missing or cannot be printed are skipped.
        """
        md = yaml_load(metadata_file)

        for section in EntityType.to_list():
            spec_key = get_spec_key(section)
            try:
                # A failed lookup means this entity type is not described by
                # the file (e.g. no dataset section in a model spec).
                section_metadata = md[spec_key]
                print('-- %s : %s --' % (section, spec_name))
                print(get_yaml_str(section_metadata))
            except Exception:
                # Best effort: skip sections that are absent or unprintable.
                continue
예제 #16
0
    def test_01_mutability_strict_push(self):
        """Changing a strict dataset spec to flexible must make ``add`` fail.

        The add command output is expected to contain the
        mutability-cannot-change error message.
        """
        entity_type = DATASETS
        entity_name = entity_type + '-ex'
        self._create_entity_with_mutability(entity_type, STRICT)
        self._checkout_entity(entity_type)

        spec_path = os.path.join(self.tmp_dir, entity_type, entity_name,
                                 entity_name + '.spec')
        entity_spec_key = get_spec_key(entity_type)

        # Confirm the checked-out mutability, then rewrite it in the spec.
        ws_spec = self._verify_mutability(entity_spec_key, STRICT, spec_path)
        self._change_mutability(entity_spec_key, FLEXIBLE, spec_path, ws_spec)

        create_file(os.path.join(entity_type, entity_name), 'file2', '012')

        add_output = check_output(MLGIT_ADD % (entity_type, entity_name, ''))
        self.assertIn(output_messages['ERROR_MUTABILITY_CANNOT_CHANGE'],
                      add_output)
예제 #17
0
 def _get_tag_info(self, spec, tag):
     """Collect the information shown for a tag.

     Returns a ``(tag_info, tag_table)`` pair: a dict with the authored date
     of the tag commit, the tag name, the related dataset and labels
     information and the metrics, plus the table built from it by
     ``_create_tag_info_table``.
     """
     entity_spec_key = get_spec_key(self.__repo_type)
     spec_file = self._get_spec_content(spec, tag.commit)[entity_spec_key]

     # Only the info part of each related entity is used here.
     _, dataset_info = self._get_related_entity_info(spec_file,
                                                     DATASET_SPEC_KEY)
     _, labels_info = self._get_related_entity_info(spec_file,
                                                    LABELS_SPEC_KEY)

     authored_at = time.localtime(tag.commit.authored_date)
     tag_info = {
         DATE: time.strftime('%Y-%m-%d %H:%M:%S', authored_at),
         TAG: tag.name,
         RELATED_DATASET_TABLE_INFO: dataset_info,
         RELATED_LABELS_TABLE_INFO: labels_info
     }
     metrics = spec_file.get(PERFORMANCE_KEY, {})
     tag_info[PERFORMANCE_KEY] = metrics
     return tag_info, self._create_tag_info_table(tag_info, metrics)
예제 #18
0
    def test_02_mutability_flexible_push(self):
        """Changing a flexible model spec to strict must make ``add`` fail.

        The add command output is expected to contain the
        mutability-cannot-change error message.
        """
        entity_type = MODELS
        entity_name = entity_type + '-ex'
        self._create_entity_with_mutability(entity_type, FLEXIBLE)
        self._checkout_entity(entity_type,
                              'computer-vision__images__models-ex__1')

        entity_dir = os.path.join(self.tmp_dir, entity_type, entity_name)
        spec_path = os.path.join(self.tmp_dir, entity_type, entity_name,
                                 entity_name + '.spec')
        entity_spec_key = get_spec_key(entity_type)

        # Confirm the checked-out mutability, then rewrite it in the spec.
        ws_spec = self._verify_mutability(entity_spec_key, FLEXIBLE,
                                          spec_path)
        self._change_mutability(entity_spec_key, STRICT, spec_path, ws_spec)

        create_file(entity_dir, 'file2', '012')

        add_output = check_output(MLGIT_ADD % (entity_type, entity_name, ''))
        self.assertIn(output_messages['ERROR_MUTABILITY_CANNOT_CHANGE'],
                      add_output)
예제 #19
0
    def tag_exists(self, index_path, version=None):
        """Check whether the tag derived from the indexed spec already exists.

        Returns the ``(full_metadata_path, entity_sub_path, metadata)`` triple
        produced by ``_full_metadata_path``. The triple is returned as is when
        its metadata is None, and ``(None, None, None)`` is returned (after
        logging an error) when the generated tag already exists. A truthy
        ``version`` overrides the version stored in the metadata before the
        tag is generated.
        """
        spec_name = self._spec + SPEC_EXTENSION
        spec_file = os.path.join(index_path, 'metadata', self._spec,
                                 spec_name)
        full_metadata_path, entity_sub_path, metadata = (
            self._full_metadata_path(spec_file))
        result = (full_metadata_path, entity_sub_path, metadata)
        if metadata is None:
            return result

        if version:
            metadata[get_spec_key(self.__repo_type)]['version'] = version

        # Generate the tag that would be associated with the commit and look
        # it up in the ml-git repository.
        tag = self.metadata_tag(metadata)
        existing_tags = self._tag_exists(tag)
        if len(existing_tags) > 0:
            log.error(output_messages[
                'ERROR_TAG_ALREADY_EXISTS_CONSIDER_USER_VERSION'] %
                      (tag, self.__repo_type),
                      class_name=METADATA_CLASS_NAME)
            return None, None, None
        return result
예제 #20
0
파일: test_api.py 프로젝트: HPInc/ml-git
def get_sample_spec(type):
    """Return a sample spec dict for the given entity type.

    Model specs additionally reference a dataset tag and a labels tag;
    labels specs reference a dataset tag.
    """
    entity_spec = {
        'categories': ['test'],
        'manifest': {
            'amount': 2,
            'files': 'MANIFEST.yaml',
            'size': '7.2 kB',
            'storage': 's3h://mlgit-bucket'
        },
        'mutability': 'mutable',
        'name': '{}-ex'.format(type),
        'version': 1
    }
    spec_test = {get_spec_key(type): entity_spec}

    # Related entities referenced by the spec, as (entity type, tag) pairs.
    dataset_ref = (EntityType.DATASETS.value, 'test__datasets-ex__1')
    labels_ref = (EntityType.LABELS.value, 'test__labels-ex__1')
    if type == EntityType.MODELS.value:
        related_refs = (dataset_ref, labels_ref)
    elif type == EntityType.LABELS.value:
        related_refs = (dataset_ref,)
    else:
        related_refs = ()

    for related_type, related_tag in related_refs:
        entity_spec[get_spec_key(related_type)] = {'tag': related_tag}
    return spec_test
예제 #21
0
파일: helper.py 프로젝트: tspthomas/ml-git
def init_repository(entity,
                    self,
                    version=1,
                    storage_type=S3H,
                    profile=PROFILE,
                    artifact_name=None,
                    category='images'):
    """Initialise an ml-git project for ``entity`` and write a sample spec.

    Runs the init, remote add, storage add and entity init commands,
    asserting the expected message of each one, edits the project config,
    then creates the entity workspace with a spec file. ``artifact_name``
    falls back to ``<entity>-ex`` when it is empty or not given.
    """
    artifact_name = artifact_name or f'{entity}-ex'

    # Project init: the expected message depends on whether the project
    # directory already exists.
    if os.path.exists(os.path.join(self.tmp_dir, ML_GIT_DIR)):
        init_message = output_messages['INFO_ALREADY_IN_RESPOSITORY']
    else:
        init_message = (output_messages['INFO_INITIALIZED_PROJECT_IN'] %
                        self.tmp_dir)
    self.assertIn(init_message, check_output(MLGIT_INIT))

    self.assertIn(
        output_messages['INFO_ADD_REMOTE'] %
        (os.path.join(self.tmp_dir, GIT_PATH), entity),
        check_output(MLGIT_REMOTE_ADD %
                     (entity, os.path.join(self.tmp_dir, GIT_PATH))))

    # Storage add: pick the expected message and the command to run.
    if storage_type == GDRIVEH:
        storage_message = (
            output_messages['INFO_ADD_STORAGE_WITHOUT_PROFILE'] %
            (storage_type, BUCKET_NAME))
        storage_command = MLGIT_STORAGE_ADD_WITH_TYPE % (
            BUCKET_NAME, profile, storage_type)
    elif profile is not None:
        storage_message = (output_messages['INFO_ADD_STORAGE'] %
                           (storage_type, BUCKET_NAME, profile))
        storage_command = MLGIT_STORAGE_ADD_WITH_TYPE % (
            BUCKET_NAME, profile, storage_type)
    else:
        storage_message = (
            output_messages['INFO_ADD_STORAGE_WITHOUT_PROFILE'] %
            (storage_type, BUCKET_NAME))
        storage_command = MLGIT_STORAGE_ADD_WITHOUT_CREDENTIALS % BUCKET_NAME
    self.assertIn(storage_message, check_output(storage_command))

    self.assertIn(
        output_messages['INFO_METADATA_INIT'] %
        (os.path.join(self.tmp_dir, GIT_PATH),
         os.path.join(self.tmp_dir, ML_GIT_DIR, entity, 'metadata')),
        check_output(MLGIT_ENTITY_INIT % entity))

    edit_config_yaml(os.path.join(self.tmp_dir, ML_GIT_DIR), storage_type)

    # Create the entity workspace and write its spec.
    workspace = os.path.join(self.tmp_dir, entity, artifact_name)
    os.makedirs(workspace)
    spec_key = get_spec_key(entity)
    entity_spec = {
        'categories': ['computer-vision', category],
        'manifest': {
            'files': 'MANIFEST.yaml',
            STORAGE_SPEC_KEY: '%s://mlgit' % storage_type
        },
        'mutability': STRICT,
        'name': artifact_name,
        'version': version
    }
    spec_file = os.path.join(self.tmp_dir, entity, artifact_name,
                             f'{artifact_name}.spec')
    with open(spec_file, 'w') as stream:
        yaml_processor.dump({spec_key: entity_spec}, stream)
    self.assertTrue(os.path.exists(spec_file))