    def __init__(self, config_filename, output_model_name):
        ''' Load/validate model information from a config file '''
        config = ConfigParser(allow_no_value=True)
        config.read(config_filename)
        my_config = config['Settings']
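        # Create (or reuse) the Spark session shared by all pipeline steps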
        self.spark = pyspark.sql.SparkSession.builder.appName('vienna') \
         .getOrCreate()

        # Load storage account info
        self.storage_account_name = ensure_str(
            my_config['storage_account_name'])
        self.storage_account_key = ensure_str(my_config['storage_account_key'])
        self.container_pretrained_models = ensure_str(
            my_config['container_pretrained_models'])
        self.container_trained_models = ensure_str(
            my_config['container_trained_models'])
        self.container_data_o16n = ensure_str(my_config['container_data_o16n'])
        self.container_prediction_results = ensure_str(
            my_config['container_prediction_results'])
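        # Name of the CSV file where operationalization (o16n) predictions
        # will be written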
        self.predictions_filename = '{}_predictions_o16n.csv'.format(
            output_model_name)

        # Load blob service and ensure containers are available
        blob_service = BlockBlobService(self.storage_account_name,
                                        self.storage_account_key)
        container_list = [i.name for i in blob_service.list_containers()]
        for container in [
                self.container_pretrained_models,
                self.container_trained_models, self.container_data_o16n,
                self.container_prediction_results
        ]:
            assert container in container_list, \
                'Could not find container {} in storage account {}'.format(
                    container, self.storage_account_name)

        # Load information on the named model
        self.output_model_name = output_model_name
        description = blob_service.get_blob_to_text(
            container_name=self.container_trained_models,
            blob_name='{}/model.info'.format(self.output_model_name))
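        # model.info is expected to hold one comma-separated key,value pair
        # per line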
        description_dict = {}
        for line in description.content.splitlines():
            if len(line.strip()) == 0:
                continue
            # Split on the first comma only, in case a value contains commas
            key, val = line.strip().split(',', 1)
            description_dict[key] = val
        self.model_source = description_dict['model_source']
        self.pretrained_model_type = description_dict['pretrained_model_type']

        # Create pipeline components common to both model types
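        # UDF to extract the image's filename from its path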
        self.extract_path_udf = udf(lambda row: os.path.basename(row.path),
                                    StringType())
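        # UnrollImage flattens each image into a numeric vector column for
        # downstream scoring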
        self.unroller = mmlspark.UnrollImage().setInputCol('image') \
         .setOutputCol('unrolled')
        return

    def __init__(self, config_filename, pretrained_model_type,
                 mmlspark_model_type, output_model_name):
        ''' Load static info for cluster/job creation from a config file '''
        config = ConfigParser(allow_no_value=True)
        config.read(config_filename)
        my_config = config['Settings']
        self.spark = pyspark.sql.SparkSession.builder.appName('vienna') \
         .getOrCreate()

        self.pretrained_model_type = pretrained_model_type
        self.mmlspark_model_type = mmlspark_model_type
        self.output_model_name = output_model_name

        # Storage account where results will be written
        self.storage_account_name = ensure_str(
            my_config['storage_account_name'])
        self.storage_account_key = ensure_str(my_config['storage_account_key'])
        self.container_pretrained_models = ensure_str(
            my_config['container_pretrained_models'])
        self.container_trained_models = ensure_str(
            my_config['container_trained_models'])
        self.container_data_training = ensure_str(
            my_config['container_data_training'])
        self.container_data_testing = ensure_str(
            my_config['container_data_testing'])
        self.container_prediction_results = ensure_str(
            my_config['container_prediction_results'])

        # URIs where data will be loaded or saved
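        # Training and test images are expected as PNG files one directory
        # level down (one folder per label)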
        self.train_uri = 'wasb://{}@{}.blob.core.windows.net/*/*.png'.format(
            self.container_data_training, self.storage_account_name)
        self.test_uri = 'wasb://{}@{}.blob.core.windows.net/*/*.png'.format(
            self.container_data_testing, self.storage_account_name)
        self.model_uri = 'wasb://{}@{}.blob.core.windows.net/{}'.format(
            self.container_pretrained_models, self.storage_account_name,
            'ResNet_18.model' if pretrained_model_type == 'resnet18'
            else 'AlexNet.model')
        self.output_uri = 'wasb://{}@{}.blob.core.windows.net/{}/model'.format(
            self.container_trained_models, self.storage_account_name,
            output_model_name)
        self.predictions_filename = '{}_predictions_test_set.csv'.format(
            output_model_name)

        # Load the pretrained model
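        # The output node used for featurization differs between ResNet_18
        # and AlexNet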
        self.last_layer_name = 'z.x' if (pretrained_model_type == 'resnet18') \
         else 'h2_d'
        self.cntk_model = mmlspark.CNTKModel().setInputCol('unrolled') \
         .setOutputCol('features') \
         .setModelLocation(self.spark, self.model_uri) \
         .setOutputNodeName(self.last_layer_name)

        # Initialize other Spark pipeline components
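        # The label is the name of the directory containing each image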
        self.extract_label_udf = udf(
            lambda row: os.path.basename(os.path.dirname(row.path)),
            StringType())
        self.extract_path_udf = udf(lambda row: row.path, StringType())
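        # Replace the model-type string with an untrained Spark ML classifier
        # of the requested type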
        if mmlspark_model_type == 'randomforest':
            self.mmlspark_model_type = RandomForestClassifier(numTrees=20,
                                                              maxDepth=5)
        elif mmlspark_model_type == 'logisticregression':
            self.mmlspark_model_type = LogisticRegression(regParam=0.01,
                                                          maxIter=10)
        else:
            raise ValueError('Unsupported mmlspark_model_type: {}'.format(
                mmlspark_model_type))
        self.unroller = mmlspark.UnrollImage().setInputCol('image') \
         .setOutputCol('unrolled')

        return