def test_python_command_execution(self):
    """Test command line execution.

    Ports ``self.estimator`` to Java twice: once through the Python API
    (``Porter.export``) and once through the ``sklearn_porter.cli`` entry
    point. It then asserts that both runs produced identical files.
    """
    Shell.call('rm -rf tmp')
    Shell.call('mkdir tmp')

    # Port estimator through the Python API:
    filename = '{}.java'.format(self.tmp_fn)
    cp_src = os.path.join('tmp', filename)
    with open(cp_src, 'w') as f:
        porter = Porter(self.estimator)
        out = porter.export(method_name='predict', class_name=self.tmp_fn)
        f.write(out)
    # $ javac tmp/<tmp_fn>.java
    cmd = ' '.join(['javac', cp_src])
    Shell.call(cmd)

    # Rename the API-generated file so it can be compared with the CLI output:
    filename = '{}_2.java'.format(self.tmp_fn)
    cp_dest = os.path.join('tmp', filename)
    # $ mv tmp/<tmp_fn>.java tmp/<tmp_fn>_2.java
    cmd = ' '.join(['mv', cp_src, cp_dest])
    Shell.call(cmd)

    # Dump estimator:
    filename = '{}.pkl'.format(self.tmp_fn)
    pkl_path = os.path.join('tmp', filename)
    joblib.dump(self.estimator, pkl_path)

    # Port estimator through the CLI. Use the same class name as the API
    # export above (previously hard-coded to 'Brain'), so both outputs are
    # comparable regardless of self.tmp_fn. This presumably also makes the
    # CLI write to cp_src; confirm that the CLI names its output file after
    # the class name.
    cmd = 'python -m sklearn_porter.cli.__main__ -i {} --class_name {}'.format(
        pkl_path, self.tmp_fn)
    Shell.call(cmd)

    # Compare file contents:
    equal = filecmp.cmp(cp_src, cp_dest)
    self.assertTrue(equal)
def main():
    """Command line entry point: transpile a pickled estimator to source code.

    Steps:

    - Load the estimator from a ``.pkl``/``.mdl`` file given by ``input``.
    - Pick the target language: an explicit language flag overrides
      ``language``.
    - Port the estimator with ``Porter``.
    - Either print the result (``pipe``) or write it into ``output``, falling
      back to the input file's directory.

    The process exits via ``sys.exit`` on invalid input or any porting error.
    """
    args = parse_args(sys.argv[1:])

    # Check input data:
    input_path = str(args.get('input'))
    if not input_path.endswith(('.pkl', '.mdl')) \
            or not os.path.isfile(input_path):
        error = 'No valid estimator in pickle format was found.'
        sys.exit('Error: {}'.format(error))

    # Load data. `sklearn.externals.joblib` was deprecated in scikit-learn
    # 0.21 and removed in 0.23, so prefer the standalone package.
    try:
        import joblib
    except ImportError:
        from sklearn.externals import joblib
    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code. Only port estimators from trusted sources.
    estimator = joblib.load(input_path)

    # Determine the target programming language:
    language = str(args.get('language'))  # with default language
    languages = ['c', 'java', 'js', 'go', 'php', 'ruby']
    for key in languages:
        if args.get(key):  # found explicit assignment
            language = key
            break

    # Define destination path: fall back to the directory of the input file.
    # os.path.dirname returns '' for a bare file name such as 'model.pkl'.
    dest_dir = str(args.get('output'))
    if dest_dir == '' or not os.path.isdir(dest_dir):
        dest_dir = os.path.dirname(input_path)

    # Port estimator:
    try:
        class_name = args.get('class_name')
        method_name = args.get('method_name')
        with_export = bool(args.get('export'))
        with_checksum = bool(args.get('checksum'))
        porter = Porter(estimator, language=language)
        # NOTE(review): float_formatter is not defined in this function;
        # presumably a module-level helper, so confirm it exists.
        output = porter.export(class_name=class_name,
                               method_name=method_name,
                               export_dir=dest_dir,
                               export_data=with_export,
                               export_append_checksum=with_checksum,
                               num_format=float_formatter,
                               details=True)
    except Exception as exception:
        # Catch any exception and exit the process:
        sys.exit('Error: {}'.format(str(exception)))
    else:
        # Print transpiled estimator to the console:
        if bool(args.get('pipe', False)):
            print(output.get('estimator'))
            sys.exit(0)

        only_data = bool(args.get('data'))
        if not only_data:
            filename = output.get('filename')
            # os.path.join (rather than dest_dir + os.sep + filename) keeps
            # the file relative when dest_dir is ''. Plain concatenation
            # would target the filesystem root.
            dest_path = os.path.join(dest_dir, filename)
            # Save transpiled estimator:
            with open(dest_path, 'w') as file_:
                file_.write(output.get('estimator'))
def test_porter_args_language(self):
    """An unknown target language must make Porter raise AttributeError."""
    with self.assertRaises(AttributeError):
        Porter(self.estimator, language='invalid')
def test_porter_args_method(self):
    """An unknown method name must make Porter raise AttributeError."""
    with self.assertRaises(AttributeError):
        Porter(self.estimator, method='invalid')