def make_custom_export_strategy(name,
                                convert_fn,
                                feature_columns,
                                export_input_fn,
                                use_core_columns=False):
    """Creates an `ExportStrategy` that exports and converts a GTFlow model.

    The returned strategy exports a SavedModel, reloads it to read the
    serialized tree ensemble, optionally hands the parsed ensemble to
    `convert_fn`, and writes a `feature_importances` file into the export's
    `assets.extra` directory.

    Args:
      name: A string, the name of the export strategy.
      convert_fn: A callable that converts the tree proto to the desired
        format and saves it to the desired location. May be None to skip
        conversion.
      feature_columns: A list of feature columns.
      export_input_fn: A function that takes no arguments and returns an
        `InputFnOps`.
      use_core_columns: Forwarded unchanged to
        `gbdt_batch.extract_features`. Defaults to False.

    Returns:
      An `ExportStrategy`.
    """
    servo_strategy = saved_model_export_utils.make_export_strategy(
        serving_input_fn=export_input_fn, strip_default_attrs=True)
    serving_ops = export_input_fn()
    extracted = gbdt_batch.extract_features(
        serving_ops.features, feature_columns, use_core_columns)
    (sorted_feature_names, dense_floats, sparse_float_indices, _, _,
     sparse_int_indices, _, _) = extracted

    def export_fn(estimator, export_dir, checkpoint_path=None,
                  eval_result=None):
        """Exports to SavedModel, then converts it to other formats."""
        result_dir = servo_strategy.export(
            estimator, export_dir, checkpoint_path, eval_result)
        with ops.Graph().as_default() as graph, \
                tf_session.Session(graph=graph) as sess:
            saved_model_loader.load(sess, [tag_constants.SERVING], result_dir)
            # Note: This is GTFlow internal API and might change.
            serialize_op = graph.get_operation_by_name(
                "ensemble_model/TreeEnsembleSerialize")
            _, serialized_ensemble = sess.run(serialize_op.outputs)
            dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
            dtec.ParseFromString(serialized_ensemble)

            num_dense = len(dense_floats)
            num_sparse_float = len(sparse_float_indices)
            num_sparse_int = len(sparse_int_indices)

            # The converted output lands in the same folder as the SavedModel.
            if convert_fn:
                convert_fn(dtec, sorted_feature_names, num_dense,
                           num_sparse_float, num_sparse_int, result_dir,
                           eval_result)

            importances = _get_feature_importances(
                dtec, sorted_feature_names, num_dense, num_sparse_float,
                num_sparse_int)
            # Most important feature first.
            ranked = sorted(importances.items(), key=lambda item: -item[1])

            assets_dir = os.path.join(compat.as_bytes(result_dir),
                                      compat.as_bytes("assets.extra"))
            gfile.MakeDirs(assets_dir)
            importances_path = os.path.join(
                compat.as_bytes(assets_dir),
                compat.as_bytes("feature_importances"))
            report = "\n".join(
                "%s, %f" % (feature, score) for feature, score in ranked)
            with gfile.GFile(importances_path, "w") as f:
                f.write(report)
        return result_dir

    return export_strategy.ExportStrategy(
        name, export_fn, strip_default_attrs=True)
def make_export_strategy(serving_input_fn,
                         default_output_alternative_key=None,
                         assets_extra=None,
                         as_text=False,
                         exports_to_keep=5,
                         strip_default_attrs=None):
    """Builds an ExportStrategy for use with Experiment.

    Args:
      serving_input_fn: A function that takes no arguments and returns an
        `InputFnOps`.
      default_output_alternative_key: The name of the head to serve when an
        incoming serving request does not explicitly request a specific head.
        Must be `None` if the estimator inherits from
        @{tf.estimator.Estimator} or for single-headed models.
      assets_extra: A dict specifying how to populate the assets.extra
        directory within the exported SavedModel. Each key gives the
        destination path (including the filename) relative to the
        assets.extra directory; the value gives the full path of the source
        file to copy. Copying a single file without renaming it looks like
        `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
      as_text: Whether to write the SavedModel proto in text format.
      exports_to_keep: Number of exports to keep. Older exports will be
        garbage-collected. Defaults to 5. Set to None to disable garbage
        collection.
      strip_default_attrs: Boolean. If True, default attrs in the `GraphDef`
        will be stripped on write. This is recommended for better forward
        compatibility of the resulting `SavedModel`.

    Returns:
      An ExportStrategy that can be passed to the Experiment constructor.
    """

    def export_fn(estimator, export_dir_base, checkpoint_path=None,
                  strip_default_attrs=False):
        """Exports the given Estimator as a SavedModel.

        Args:
          estimator: The Estimator to export.
          export_dir_base: A string containing a directory to write the
            exported graph and checkpoints.
          checkpoint_path: The checkpoint path to export. If None (the
            default), the most recent checkpoint found within the model
            directory is chosen.
          strip_default_attrs: Boolean. If `True`, default-valued attributes
            will be removed from the NodeDefs.

        Returns:
          The string path to the exported directory.

        Raises:
          ValueError: If `estimator` is a @{tf.estimator.Estimator} instance
            and `default_output_alternative_key` was specified.
        """
        # Keyword arguments shared by both estimator flavours.
        export_kwargs = {
            'assets_extra': assets_extra,
            'as_text': as_text,
            'checkpoint_path': checkpoint_path,
            'strip_default_attrs': strip_default_attrs,
        }
        if not isinstance(estimator, core_estimator.Estimator):
            # Only the non-core estimator accepts the alternative key.
            export_kwargs['default_output_alternative_key'] = (
                default_output_alternative_key)
        elif default_output_alternative_key is not None:
            raise ValueError(
                'default_output_alternative_key is not supported in core '
                'Estimator. Given: {}'.format(default_output_alternative_key))

        exported_path = estimator.export_savedmodel(
            export_dir_base, serving_input_fn, **export_kwargs)
        garbage_collect_exports(export_dir_base, exports_to_keep)
        return exported_path

    return export_strategy.ExportStrategy(
        'Servo', export_fn, strip_default_attrs)
def extend_export_strategy(base_export_strategy,
                           post_export_fn,
                           post_export_name=None):
    """Extend ExportStrategy, calling post_export_fn after export.

    Args:
      base_export_strategy: An ExportStrategy that can be passed to the
        Experiment constructor.
      post_export_fn: A user-specified function to call after exporting the
        SavedModel. Takes two arguments - the path to the SavedModel exported
        by base_export_strategy and the directory where to export the
        SavedModel modified by the post_export_fn. Returns the path to the
        exported SavedModel.
      post_export_name: The directory name under the export base directory
        where SavedModels generated by the post_export_fn will be written. If
        None, the directory name of base_export_strategy is used.

    Returns:
      An ExportStrategy that can be passed to the Experiment constructor.
    """

    def export_fn(estimator, export_dir_base, checkpoint_path=None):
        """Exports the given Estimator as a SavedModel and invokes post_export_fn.

        Both temporary directories created here are removed again even when
        the base export, `post_export_fn`, validation or the final rename
        fails, so a failed export does not leave `temp-*` folders behind in
        `export_dir_base`.

        Args:
          estimator: The Estimator to export.
          export_dir_base: A string containing a directory to write the
            exported graphs and checkpoint.
          checkpoint_path: The checkpoint path to export. If None (the
            default), the most recent checkpoint found within the model
            directory is chosen.

        Returns:
          The string path to the SavedModel indicated by post_export_fn.

        Raises:
          ValueError: If `estimator` is a `tf.estimator.Estimator` instance
            and `default_output_alternative_key` was specified or if
            post_export_fn does not return a valid directory.
          RuntimeError: If unable to create temporary or final export
            directory.
        """
        tmp_base_export_folder = 'temp-base-export-' + str(int(time.time()))
        tmp_base_export_dir = os.path.join(export_dir_base,
                                           tmp_base_export_folder)
        # A pre-existing directory is not ours to delete, so this check stays
        # outside the try/finally that owns the cleanup.
        if gfile.Exists(tmp_base_export_dir):
            raise RuntimeError('Failed to obtain base export directory')
        gfile.MakeDirs(tmp_base_export_dir)
        try:
            tmp_base_export = base_export_strategy.export(
                estimator, tmp_base_export_dir, checkpoint_path)

            tmp_post_export_folder = (
                'temp-post-export-' + str(int(time.time())))
            tmp_post_export_dir = os.path.join(export_dir_base,
                                               tmp_post_export_folder)
            if gfile.Exists(tmp_post_export_dir):
                raise RuntimeError('Failed to obtain temp export directory')
            gfile.MakeDirs(tmp_post_export_dir)
            try:
                tmp_post_export = post_export_fn(tmp_base_export,
                                                 tmp_post_export_dir)
                if not tmp_post_export.startswith(tmp_post_export_dir):
                    raise ValueError(
                        'post_export_fn must return a sub-directory of {}'
                        .format(tmp_post_export_dir))

                # Re-root the produced export directly under export_dir_base.
                post_export_relpath = os.path.relpath(tmp_post_export,
                                                      tmp_post_export_dir)
                post_export = os.path.join(export_dir_base,
                                           post_export_relpath)
                if gfile.Exists(post_export):
                    raise RuntimeError(
                        'Failed to obtain final export directory')
                gfile.Rename(tmp_post_export, post_export)
            finally:
                gfile.DeleteRecursively(tmp_post_export_dir)
        finally:
            gfile.DeleteRecursively(tmp_base_export_dir)
        return post_export

    name = post_export_name if post_export_name else base_export_strategy.name
    return export_strategy.ExportStrategy(name, export_fn)
def make_export_strategy(train_config, args, keep_target, assets_extra=None):
    """Builds an ExportStrategy that exports a CSV-serving SavedModel.

    Each export is written to a timestamped directory under the export base
    directory, only the last 3 exports are kept, and the newest export is
    copied to `<args.job_dir>/evaluation_model` (when `keep_target` is
    truthy) or `<args.job_dir>/model` (otherwise).

    Args:
      train_config: Forwarded to `serving_from_csv_input` and
        `make_output_tensors`.
      args: Command line args; `args.job_dir` is read here and `args` is
        forwarded to the helpers above.
      keep_target: Selects the evaluation or prediction flavour of the
        export directories and is forwarded to the helpers above.
      assets_extra: Optional dict mapping a destination path relative to the
        export's `assets.extra` directory to the source file to copy there.

    Returns:
      An ExportStrategy.
    """

    def export_fn(estimator, export_dir_base, checkpoint_path=None,
                  eval_result=None):
        """Exports `estimator` and returns the timestamped export directory."""
        with ops.Graph().as_default() as graph:
            contrib_variables.create_global_step(graph)

            serving_ops = serving_from_csv_input(
                train_config, args, keep_target)
            model_fn_ops = estimator._call_model_fn(
                serving_ops.features, None, model_fn_lib.ModeKeys.INFER)
            outputs = make_output_tensors(
                train_config=train_config,
                args=args,
                input_ops=serving_ops,
                model_fn_ops=model_fn_ops,
                keep_target=keep_target)
            predict_signature = signature_def_utils.predict_signature_def(
                serving_ops.default_inputs, outputs)
            signature_def_map = {'serving_default': predict_signature}

            # Fall back to the latest checkpoint in the model directory.
            checkpoint_path = checkpoint_path or saver.latest_checkpoint(
                estimator._model_dir)
            if not checkpoint_path:
                raise NotFittedError(
                    "Couldn't find trained model at %s."
                    % estimator._model_dir)

            export_dir = saved_model_export_utils.get_timestamped_export_dir(
                export_dir_base)

            with tf_session.Session('') as session:
                # NOTE(review): these two ops are created but never run or
                # referenced; kept so the exported graph is unchanged.
                variables.local_variables_initializer()
                data_flow_ops.tables_initializer()

                restorer = saver.Saver(
                    variables.global_variables(), sharded=True)
                restorer.restore(session, checkpoint_path)

                init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    data_flow_ops.tables_initializer())

                # Perform the export.
                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session,
                    [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=init_op)
                builder.save(False)

            # Copy the extra assets next to the SavedModel.
            if assets_extra:
                extra_root = os.path.join(compat.as_bytes(export_dir),
                                          compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    destination = os.path.join(
                        compat.as_bytes(extra_root),
                        compat.as_bytes(dest_relative))
                    gfile.MakeDirs(os.path.dirname(destination))
                    gfile.Copy(source, destination)

        # Only keep the last 3 models.
        saved_model_export_utils.garbage_collect_exports(
            export_dir_base, exports_to_keep=3)

        # Save the last model to the model folder.
        # export_dir_base = A/B/intermediate_models/
        final_dir = os.path.join(
            args.job_dir, 'evaluation_model' if keep_target else 'model')
        if file_io.is_directory(final_dir):
            file_io.delete_recursively(final_dir)
        file_io.recursive_create_dir(final_dir)
        _recursive_copy(export_dir, final_dir)

        return export_dir

    intermediate_dir = ('intermediate_evaluation_models' if keep_target
                        else 'intermediate_prediction_models')
    return export_strategy.ExportStrategy(intermediate_dir, export_fn)
def make_best_model_export_strategy(serving_input_fn,
                                    exports_to_keep=1,
                                    model_dir=None,
                                    event_file_pattern=None,
                                    compare_fn=None,
                                    default_output_alternative_key=None,
                                    strip_default_attrs=None):
    """Creates a custom ExportStrategy for use with tf.contrib.learn.Experiment.

    Args:
      serving_input_fn: A function that takes no arguments and returns an
        `InputFnOps`.
      exports_to_keep: An integer indicating how many historical best models
        need to be preserved.
      model_dir: Directory where model parameters, graph etc. are saved. It is
        used to load eval metrics when the export strategy is created, so the
        best metrics are not lost if the export strategy gets preempted and
        only the best model is exported regardless of preemption. If None,
        the export strategy is not preemption-safe. To be preemption-safe,
        both model_dir and event_file_pattern are needed.
      event_file_pattern: Event file name pattern relative to model_dir, e.g.
        "eval_continuous/*.tfevents.*". If None, the export strategy is not
        preemption-safe. To be preemption-safe, both model_dir and
        event_file_pattern are needed.
      compare_fn: A function that selects the 'best' candidate from a
        dictionary of evaluation results keyed by corresponding checkpoint
        path.
      default_output_alternative_key: The key for the default serving
        signature for multi-headed inference graphs.
      strip_default_attrs: Boolean. If True, default attrs in the `GraphDef`
        will be stripped on write. This is recommended for better forward
        compatibility of the resulting `SavedModel`.

    Returns:
      An ExportStrategy that can be passed to the Experiment constructor.
    """
    servo_strategy = make_export_strategy(
        serving_input_fn,
        exports_to_keep=exports_to_keep,
        default_output_alternative_key=default_output_alternative_key,
        strip_default_attrs=strip_default_attrs)

    # The full pattern is only available when both pieces were supplied.
    if model_dir and event_file_pattern:
        full_event_file_pattern = os.path.join(model_dir, event_file_pattern)
    else:
        full_event_file_pattern = None
    selector = BestModelSelector(full_event_file_pattern, compare_fn)

    def export_fn(estimator, export_dir_base, checkpoint_path,
                  eval_result=None):
        """Exports the given Estimator as a SavedModel.

        Args:
          estimator: The Estimator to export.
          export_dir_base: A string containing a directory to write the
            exported graph and checkpoints.
          checkpoint_path: The checkpoint path to export. If falsy, the most
            recent checkpoint found within the model directory is chosen.
          eval_result: Evaluation result handed to the best-model selector
            together with the checkpoint path.

        Returns:
          The string path to the exported directory, or '' when nothing was
          exported.
        """
        if not checkpoint_path:
            # TODO(b/67425018): switch to
            #    checkpoint_path = estimator.latest_checkpoint()
            #  as soon as contrib is cleaned up and we can thus be sure that
            #  estimator is a tf.estimator.Estimator and not a
            #  tf.contrib.learn.Estimator
            checkpoint_path = checkpoint_management.latest_checkpoint(
                estimator.model_dir)

        best_checkpoint, best_eval_result = selector.update(
            checkpoint_path, eval_result)
        if not best_checkpoint or best_eval_result is None:
            return ''

        export_dir = os.path.join(export_dir_base,
                                  os.path.basename(best_checkpoint))
        return servo_strategy.export(
            estimator, export_dir, best_checkpoint, best_eval_result)

    return export_strategy.ExportStrategy('best_model', export_fn)
def make_best_model_export_strategy(serving_input_fn,
                                    exports_to_keep=1,
                                    compare_fn=None,
                                    default_output_alternative_key=None):
    """Creates a custom ExportStrategy for use with tf.contrib.learn.Experiment.

    Args:
      serving_input_fn: A function that takes no arguments and returns an
        `InputFnOps`.
      exports_to_keep: An integer indicating how many historical best models
        need to be preserved.
      compare_fn: A function that selects the 'best' candidate from a
        dictionary of evaluation results keyed by corresponding checkpoint
        path.
      default_output_alternative_key: The key for the default serving
        signature for multi-headed inference graphs.

    Returns:
      An ExportStrategy that can be passed to the Experiment constructor.
    """
    wrapped_strategy = make_export_strategy(
        serving_input_fn,
        exports_to_keep=exports_to_keep,
        default_output_alternative_key=default_output_alternative_key)
    tracker = BestModelSelector(compare_fn)

    def export_fn(estimator, export_dir_base, checkpoint_path,
                  eval_result=None):
        """Exports the given Estimator as a SavedModel.

        Args:
          estimator: The Estimator to export.
          export_dir_base: A string containing a directory to write the
            exported graph and checkpoints.
          checkpoint_path: The checkpoint path to export. If falsy, the most
            recent checkpoint found within the model directory is chosen.
          eval_result: Evaluation result handed to the best-model selector
            together with the checkpoint path.

        Returns:
          The string path to the exported directory, or '' when nothing was
          exported.
        """
        # TODO(b/67425018): switch to
        #    checkpoint_path = estimator.latest_checkpoint()
        #  as soon as contrib is cleaned up and we can thus be sure that
        #  estimator is a tf.estimator.Estimator and not a
        #  tf.contrib.learn.Estimator
        candidate = checkpoint_path or saver.latest_checkpoint(
            estimator.model_dir)

        selection = tracker.update(candidate, eval_result)
        chosen_checkpoint, chosen_eval_result = selection

        should_export = bool(chosen_checkpoint) and (
            chosen_eval_result is not None)
        if should_export:
            target_dir = os.path.join(
                export_dir_base, os.path.basename(chosen_checkpoint))
            return wrapped_strategy.export(
                estimator, target_dir, chosen_checkpoint, chosen_eval_result)
        return ''

    return export_strategy.ExportStrategy('best_model', export_fn)
def make_export_strategy(args, keep_target, assets_extra, features, schema,
                         stats):
    """Makes an ExportStrategy whose prediction graph takes CSV input.

    Each export is written to a timestamped directory under the export base
    directory, only the last 3 exports are kept, and the newest export is
    copied to `<args.job_dir>/evaluation_model` (when `keep_target` is
    truthy) or `<args.job_dir>/model` (otherwise).

    Args:
      args: Command line args. `args.analysis` and `args.job_dir` are read.
      keep_target: If true, the target column is returned in the prediction
        graph. The target column must then also exist in the input data.
      assets_extra: Other files to copy to the output folder, as a dict
        mapping a destination path relative to `assets.extra` to the source
        file path. May be falsy to copy nothing.
      features: Features dict.
      schema: Schema list; each entry is indexed by `'name'`.
      stats: Stats dict.

    Returns:
      An ExportStrategy.
    """
    target_name = feature_transforms.get_target_name(features)
    # NOTE(review): csv_header is not referenced by export_fn below; the
    # computation is kept for its side effects (e.g. `remove` raises when the
    # target is absent from the schema).
    csv_header = [column['name'] for column in schema]
    if not keep_target:
        csv_header.remove(target_name)

    def _as_tensor_infos(tensors):
        """Maps every tensor in the dict to its TensorInfo proto."""
        return {
            key: tf.saved_model.utils.build_tensor_info(tensor)
            for key, tensor in six.iteritems(tensors)
        }

    def export_fn(estimator, export_dir_base, checkpoint_path=None,
                  eval_result=None):
        """Exports `estimator` and returns the timestamped export directory."""
        with ops.Graph().as_default() as graph:
            contrib_variables.create_global_step(graph)

            serving_ops = (
                feature_transforms.build_csv_serving_tensors_for_training_step(
                    args.analysis, features, schema, stats, keep_target))
            model_fn_ops = estimator._call_model_fn(
                serving_ops.features, None, model_fn_lib.ModeKeys.INFER)
            outputs = make_prediction_output_tensors(
                args=args,
                features=features,
                input_ops=serving_ops,
                model_fn_ops=model_fn_ops,
                keep_target=keep_target)

            # Don't use signature_def_utils.predict_signature_def as that
            # renames tensor names if there is only 1 input/output tensor!
            signature_inputs = _as_tensor_infos(serving_ops.default_inputs)
            signature_outputs = _as_tensor_infos(outputs)
            serving_signature = signature_def_utils.build_signature_def(
                signature_inputs,
                signature_outputs,
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
            signature_def_map = {'serving_default': serving_signature}

            # Fall back to the latest checkpoint in the model directory.
            checkpoint_path = checkpoint_path or saver.latest_checkpoint(
                estimator._model_dir)
            if not checkpoint_path:
                raise ValueError(
                    "Couldn't find trained model at %s."
                    % estimator._model_dir)

            export_dir = saved_model_export_utils.get_timestamped_export_dir(
                export_dir_base)

            # Prefer the saver supplied by the model's scaffold, if any.
            scaffold = model_fn_ops.scaffold
            restorer = scaffold.saver if scaffold is not None else None
            if restorer is None:
                restorer = saver.Saver(sharded=True)

            with tf_session.Session('') as session:
                restorer.restore(session, checkpoint_path)
                init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    resources.initialize_resources(
                        resources.shared_resources()),
                    tf.tables_initializer())

                # Perform the export.
                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session,
                    [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=init_op)
                builder.save(False)

            # Copy the extra assets next to the SavedModel.
            if assets_extra:
                extra_root = os.path.join(compat.as_bytes(export_dir),
                                          compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    destination = os.path.join(
                        compat.as_bytes(extra_root),
                        compat.as_bytes(dest_relative))
                    file_io.recursive_create_dir(os.path.dirname(destination))
                    file_io.copy(source, destination)

        # Only keep the last 3 models.
        saved_model_export_utils.garbage_collect_exports(
            export_dir_base, exports_to_keep=3)

        # Save the last model to the model folder.
        # export_dir_base = A/B/intermediate_models/
        final_dir = os.path.join(
            args.job_dir, 'evaluation_model' if keep_target else 'model')
        if file_io.is_directory(final_dir):
            file_io.delete_recursively(final_dir)
        file_io.recursive_create_dir(final_dir)
        recursive_copy(export_dir, final_dir)

        return export_dir

    intermediate_dir = ('intermediate_evaluation_models' if keep_target
                        else 'intermediate_prediction_models')
    return export_strategy.ExportStrategy(intermediate_dir, export_fn)
def make_best_model_export_strategy(serving_input_fn,
                                    exports_to_keep=1,
                                    compare_fn=None,
                                    default_output_alternative_key=None,
                                    strip_default_attrs=False):
    # pylint: disable=line-too-long
    """Creates a custom ExportStrategy for use with tf.contrib.learn.Experiment.

    Args:
      serving_input_fn: A function that takes no arguments and returns an
        `InputFnOps`.
      exports_to_keep: An integer indicating how many historical best models
        need to be preserved.
      compare_fn: A function that selects the 'best' candidate from a
        dictionary of evaluation results keyed by corresponding checkpoint
        path.
      default_output_alternative_key: The key for the default serving
        signature for multi-headed inference graphs.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will
        be removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

    Returns:
      An ExportStrategy that can be passed to the Experiment constructor.
    """
    # pylint: enable=line-too-long
    selector = BestModelSelector(compare_fn)
    delegate = make_export_strategy(
        serving_input_fn,
        exports_to_keep=exports_to_keep,
        default_output_alternative_key=default_output_alternative_key,
        strip_default_attrs=strip_default_attrs)

    def export_fn(estimator, export_dir_base, checkpoint_path,
                  eval_result=None):
        """Exports the given Estimator as a SavedModel.

        Args:
          estimator: The Estimator to export.
          export_dir_base: A string containing a directory to write the
            exported graph and checkpoints.
          checkpoint_path: The checkpoint path to export. If falsy, the most
            recent checkpoint found within the model directory is chosen.
          eval_result: Evaluation result handed to the best-model selector
            together with the checkpoint path.

        Returns:
          The string path to the exported directory, or '' when nothing was
          exported.
        """
        if not checkpoint_path:
            # TODO(b/67425018): switch to
            #    checkpoint_path = estimator.latest_checkpoint()
            #  as soon as contrib is cleaned up and we can thus be sure that
            #  estimator is a tf.estimator.Estimator and not a
            #  tf.contrib.learn.Estimator
            checkpoint_path = saver.latest_checkpoint(estimator.model_dir)

        winner_path, winner_eval = selector.update(
            checkpoint_path, eval_result)

        # Nothing to export unless the selector produced both pieces.
        if not winner_path:
            return ''
        if winner_eval is None:
            return ''

        winner_name = os.path.basename(winner_path)
        return delegate.export(
            estimator,
            os.path.join(export_dir_base, winner_name),
            winner_path,
            winner_eval)

    return export_strategy.ExportStrategy('best_model', export_fn)
def make_export_strategy(args, keep_target, assets_extra, features, schema):
    """Makes an ExportStrategy whose graph applies the saved transform.

    Each export is written to a timestamped directory under the export base
    directory, only the last 3 exports are kept, and the newest export is
    copied to `<args.job_dir>/evaluation_model` (when `keep_target` is
    truthy) or `<args.job_dir>/model` (otherwise).

    Args:
      args: Command line args. `args.analysis_output_dir` and `args.job_dir`
        are read.
      keep_target: If true, the target column is returned in the prediction
        graph. The target column must then also exist in the input data.
      assets_extra: Other files to copy to the output folder, as a dict
        mapping a destination path relative to `assets.extra` to the source
        file path. May be falsy to copy nothing.
      features: Features dict.
      schema: Schema list; each entry is indexed by `'name'`.

    Returns:
      An ExportStrategy.
    """
    target_name = get_target_name(features)
    raw_metadata = metadata_io.read_metadata(
        os.path.join(args.analysis_output_dir, RAW_METADATA_DIR))

    # Raw feature keys expected at serving time; the target is dropped unless
    # the caller asked to keep it.
    csv_header = [column['name'] for column in schema]
    if not keep_target:
        csv_header.remove(target_name)

    def export_fn(estimator, export_dir_base, checkpoint_path=None,
                  eval_result=None):
        """Exports `estimator` and returns the timestamped export directory."""
        with ops.Graph().as_default() as graph:
            contrib_variables.create_global_step(graph)

            transform_dir = os.path.join(args.analysis_output_dir,
                                         TRANSFORM_FN_DIR)
            serving_input_fn = (
                input_fn_maker.build_default_transforming_serving_input_fn(
                    raw_metadata=raw_metadata,
                    transform_savedmodel_dir=transform_dir,
                    raw_label_keys=[target_name],
                    raw_feature_keys=csv_header,
                    convert_scalars_to_vectors=True))
            serving_ops = serving_input_fn()

            model_fn_ops = estimator._call_model_fn(
                serving_ops.features, None, model_fn_lib.ModeKeys.INFER)
            outputs = make_prediction_output_tensors(
                args=args,
                features=features,
                input_ops=serving_ops,
                model_fn_ops=model_fn_ops,
                keep_target=keep_target)
            predict_signature = signature_def_utils.predict_signature_def(
                serving_ops.default_inputs, outputs)
            signature_def_map = {'serving_default': predict_signature}

            # Fall back to the latest checkpoint in the model directory.
            checkpoint_path = checkpoint_path or saver.latest_checkpoint(
                estimator._model_dir)
            if not checkpoint_path:
                raise ValueError(
                    "Couldn't find trained model at %s."
                    % estimator._model_dir)

            export_dir = saved_model_export_utils.get_timestamped_export_dir(
                export_dir_base)

            with tf_session.Session('') as session:
                # NOTE(review): these two ops are created but never run or
                # referenced; kept so the exported graph is unchanged.
                variables.local_variables_initializer()
                data_flow_ops.tables_initializer()

                restorer = saver.Saver(
                    variables.global_variables(), sharded=True)
                restorer.restore(session, checkpoint_path)

                init_op = control_flow_ops.group(
                    variables.local_variables_initializer(),
                    data_flow_ops.tables_initializer())

                # Perform the export.
                builder = saved_model_builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    session,
                    [tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    assets_collection=ops.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS),
                    legacy_init_op=init_op)
                builder.save(False)

            # Copy the extra assets next to the SavedModel.
            if assets_extra:
                extra_root = os.path.join(compat.as_bytes(export_dir),
                                          compat.as_bytes('assets.extra'))
                for dest_relative, source in assets_extra.items():
                    destination = os.path.join(
                        compat.as_bytes(extra_root),
                        compat.as_bytes(dest_relative))
                    file_io.recursive_create_dir(os.path.dirname(destination))
                    file_io.copy(source, destination)

        # Only keep the last 3 models.
        saved_model_export_utils.garbage_collect_exports(
            export_dir_base, exports_to_keep=3)

        # Save the last model to the model folder.
        # export_dir_base = A/B/intermediate_models/
        final_dir = os.path.join(
            args.job_dir, 'evaluation_model' if keep_target else 'model')
        if file_io.is_directory(final_dir):
            file_io.delete_recursively(final_dir)
        file_io.recursive_create_dir(final_dir)
        recursive_copy(export_dir, final_dir)

        return export_dir

    intermediate_dir = ('intermediate_evaluation_models' if keep_target
                        else 'intermediate_prediction_models')
    return export_strategy.ExportStrategy(intermediate_dir, export_fn)