def load_mmlspark_model_components(config):
    ''' Loads all components needed to apply a trained MMLSpark model

    Three objects are attached to `config` as attributes:
      - `cntk_model`: an mmlspark.CNTKModel reading column 'unrolled' and
        writing column 'features', pointed at the pretrained model file in
        blob storage and at that model's output node.
      - `mmlspark_model`: the TrainedClassifierModel loaded from blob storage.
      - `tf`: an mmlspark.IndexToValue transform from column 'scored_labels'
        to column 'pred_label'.

    Args:
        config: object providing `pretrained_model_type`,
            `container_pretrained_models`, `container_trained_models`,
            `storage_account_name`, `output_model_name` and `spark`.
            It is modified in place.

    Returns:
        The same `config` object, with the three attributes above set.

    Raises:
        ValueError: if `config.pretrained_model_type` is neither 'resnet18'
            nor 'alexnet'.
    '''
    # Pick the pretrained featurization model file and its output node
    pretrained_model_type = config.pretrained_model_type
    if pretrained_model_type == 'resnet18':
        model_filename = 'ResNet_18.model'
        last_layer_name = 'z.x'
    elif pretrained_model_type == 'alexnet':
        model_filename = 'AlexNet.model'
        last_layer_name = 'h2_d'
    else:
        # Without this branch the names above stay unbound and the failure
        # surfaces later as an UnboundLocalError that hides the bad value.
        raise ValueError(
            'Unsupported pretrained_model_type {!r}; expected '
            "'resnet18' or 'alexnet'".format(pretrained_model_type))

    # Load the pretrained featurization model
    model_uri = 'wasb://{}@{}.blob.core.windows.net/{}'.format(
        config.container_pretrained_models,
        config.storage_account_name,
        model_filename)
    config.cntk_model = mmlspark.CNTKModel().setInputCol('unrolled') \
        .setOutputCol('features').setModelLocation(config.spark, model_uri) \
        .setOutputNodeName(last_layer_name)

    # Load the MMLSpark-trained model
    mmlspark_uri = 'wasb://{}@{}.blob.core.windows.net/{}/model'.format(
        config.container_trained_models,
        config.storage_account_name,
        config.output_model_name)
    config.mmlspark_model = TrainedClassifierModel.load(mmlspark_uri)

    # Load the transform that will convert model output from indices to strings
    config.tf = mmlspark.IndexToValue().setInputCol('scored_labels') \
        .setOutputCol('pred_label')

    return config
def main(pretrained_model_type, mmlspark_model_type, config_filename,
         output_model_name, sample_frac):
    ''' Trains an MMLSpark classifier, scores the test set, uploads results

    Steps, in order: build a ConfigFile and call
    write_model_summary_to_blob; log the run parameters; fit a
    TrainClassifier on the data at config.train_uri and save it to
    config.output_uri; score the data at config.test_uri and log the
    accuracy; upload the per-file predictions as CSV text to blob storage.

    Args:
        pretrained_model_type: passed to ConfigFile and logged.
        mmlspark_model_type: passed to ConfigFile and
            write_model_summary_to_blob, and logged.
        config_filename: passed to ConfigFile and logged.
        output_model_name: passed to ConfigFile and logged.
        sample_frac: forwarded to load_data for both the training and the
            test set (presumably a sampling fraction -- confirm against
            load_data).

    Returns:
        None
    '''
    # Build the run configuration and write the model summary
    config = ConfigFile(config_filename, pretrained_model_type,
                        mmlspark_model_type, output_model_name)
    write_model_summary_to_blob(config, mmlspark_model_type)

    # Record the parameters of this run, in a fixed order
    logger = get_azureml_logger()
    logged_parameters = (
        ('amlrealworld.aerial_image_classification.run_mmlspark', 'true'),
        ('pretrained_model_type', pretrained_model_type),
        ('mmlspark_model_type', mmlspark_model_type),
        ('config_filename', config_filename),
        ('output_model_name', output_model_name),
        ('sample_frac', sample_frac),
    )
    for key, value in logged_parameters:
        logger.log(key, value)

    # Fit the classifier on the training data, then persist it
    training_data = load_data(config.train_uri, config, sample_frac)
    trainer = mmlspark.TrainClassifier(
        model=config.mmlspark_model_type, labelCol='label')
    classifier = trainer.fit(training_data)
    classifier.write().overwrite().save(config.output_uri)

    # Score the test data and report the accuracy metric
    held_out_data = load_data(config.test_uri, config, sample_frac)
    scored = classifier.transform(held_out_data)
    evaluator = mmlspark.ComputeModelStatistics(evaluationMetric='accuracy')
    accuracy_table = evaluator.transform(scored)
    accuracy_table.show()
    logger.log('accuracy_on_test_set', accuracy_table.first()['accuracy'])

    # Convert scored label indices to values and keep the reporting columns
    index_to_label = mmlspark.IndexToValue()
    index_to_label = index_to_label.setInputCol('scored_labels')
    index_to_label = index_to_label.setOutputCol('pred_label')
    labelled = index_to_label.transform(scored)
    labelled = labelled.select('filepath', 'label', 'pred_label')
    csv_text = labelled.toPandas().to_csv(index=False)

    # Upload the predictions as CSV text to the results container
    storage = BlockBlobService(config.storage_account_name,
                               config.storage_account_key)
    storage.create_container(config.container_prediction_results)
    storage.create_blob_from_text(config.container_prediction_results,
                                  config.predictions_filename,
                                  csv_text)