Ejemplo n.º 1
0
    def test_python_command_execution(self):
        """Test command line execution.

        Exports the estimator to Java through the Python API, compiles and
        renames that reference file, then ports a pickled copy of the same
        estimator through the command-line interface (``-i`` input flag) and
        checks that the two generated Java sources are identical.
        """
        import sys  # sys.executable: run the CLI with the current interpreter

        Shell.call('rm -rf tmp')
        Shell.call('mkdir tmp')

        # Reference output via the Python API: tmp/<tmp_fn>.java
        filename = '{}.java'.format(self.tmp_fn)
        cp_src = os.path.join('tmp', filename)
        with open(cp_src, 'w') as f:
            porter = Porter(self.estimator)
            out = porter.export(method_name='predict', class_name=self.tmp_fn)
            f.write(out)
        # $ javac tmp/<tmp_fn>.java
        cmd = ' '.join(['javac', cp_src])
        Shell.call(cmd)

        # Rename estimator for comparison:
        filename = '{}_2.java'.format(self.tmp_fn)
        cp_dest = os.path.join('tmp', filename)
        # $ mv tmp/<tmp_fn>.java tmp/<tmp_fn>_2.java
        cmd = ' '.join(['mv', cp_src, cp_dest])
        Shell.call(cmd)

        # Dump estimator:
        filename = '{}.pkl'.format(self.tmp_fn)
        pkl_path = os.path.join('tmp', filename)
        joblib.dump(self.estimator, pkl_path)

        # Port estimator through the CLI. Use the running interpreter (a bare
        # `python` on PATH may be missing or a different version), and the
        # same class name as the reference export so that the CLI writes
        # back to `cp_src`.
        cmd = '{} -m sklearn_porter.cli.__main__ -i {} --class_name {}'.format(
            sys.executable, pkl_path, self.tmp_fn)
        Shell.call(cmd)

        # Compare file contents byte by byte, not just os.stat() signatures:
        equal = filecmp.cmp(cp_src, cp_dest, shallow=False)

        self.assertTrue(equal)
Ejemplo n.º 2
0
    def test_python_command_execution(self):
        """Test command line execution.

        Exports the estimator to Java through the Python API, compiles and
        renames that reference file, then ports a pickled copy of the same
        estimator through the command-line interface (positional input path)
        and checks that the two generated Java sources are identical.
        """
        import sys  # sys.executable: run the CLI with the current interpreter

        Shell.call('rm -rf tmp')
        Shell.call('mkdir tmp')

        # Reference output via the Python API: tmp/<tmp_fn>.java
        filename = '{}.java'.format(self.tmp_fn)
        cp_src = os.path.join('tmp', filename)
        with open(cp_src, 'w') as f:
            porter = Porter(self.estimator)
            out = porter.export(method_name='predict', class_name=self.tmp_fn)
            f.write(out)
        # $ javac tmp/<tmp_fn>.java
        cmd = ' '.join(['javac', cp_src])
        Shell.call(cmd)

        # Rename estimator for comparison:
        filename = '{}_2.java'.format(self.tmp_fn)
        cp_dest = os.path.join('tmp', filename)
        # $ mv tmp/<tmp_fn>.java tmp/<tmp_fn>_2.java
        cmd = ' '.join(['mv', cp_src, cp_dest])
        Shell.call(cmd)

        # Dump estimator:
        filename = '{}.pkl'.format(self.tmp_fn)
        pkl_path = os.path.join('tmp', filename)
        joblib.dump(self.estimator, pkl_path)

        # Port estimator through the CLI. Use the running interpreter (a bare
        # `python` on PATH may be missing or a different version), and the
        # same class name as the reference export so that the CLI writes
        # back to `cp_src`.
        cmd = '{} -m sklearn_porter.cli.__main__ {} --class_name {}'.format(
            sys.executable, pkl_path, self.tmp_fn)
        Shell.call(cmd)

        # Compare file contents byte by byte, not just os.stat() signatures:
        equal = filecmp.cmp(cp_src, cp_dest, shallow=False)

        self.assertTrue(equal)
Ejemplo n.º 3
0
def main():
    """Port a pickled estimator to another language from the command line.

    Parses ``sys.argv[1:]`` with ``parse_args``, loads the estimator from a
    ``.pkl``/``.mdl`` file and transpiles it with ``Porter``. When the
    ``pipe`` argument is set the result is printed to stdout. Otherwise,
    unless the ``data`` argument is set, it is written to the ``output``
    directory if that is an existing directory, or else next to the input
    file. Exits the process with an ``Error: ...`` message when the input is
    invalid or porting fails.
    """
    args = parse_args(sys.argv[1:])

    # Check input data:
    input_path = str(args.get('input'))
    if not input_path.endswith(('.pkl', '.mdl')) \
            or not os.path.isfile(input_path):
        error = 'No valid estimator in pickle format was found.'
        sys.exit('Error: {}'.format(error))

    # Load data. `sklearn.externals.joblib` was removed in scikit-learn 0.23,
    # so fall back to the standalone joblib package it vendored.
    try:
        from sklearn.externals import joblib
    except ImportError:
        import joblib
    # NOTE(review): unpickling can execute arbitrary code; only load
    # estimator files from trusted sources.
    estimator = joblib.load(input_path)

    # Determine the target programming language: the first explicit language
    # flag found overrides the (defaulted) `language` argument.
    language = str(args.get('language'))
    languages = ['c', 'java', 'js', 'go', 'php', 'ruby']
    for key in languages:
        if args.get(key):  # found explicit assignment
            language = key
            break

    # Define destination directory. Fall back to the input file's directory
    # when no usable output directory was given. `os.path.dirname` yields ''
    # for a bare filename, and `os.path.join` below then targets the current
    # directory instead of the filesystem root.
    dest_dir = str(args.get('output'))
    if dest_dir == '' or not os.path.isdir(dest_dir):
        dest_dir = os.path.dirname(input_path)

    # Port estimator:
    try:
        class_name = args.get('class_name')
        method_name = args.get('method_name')
        with_export = bool(args.get('export'))
        with_checksum = bool(args.get('checksum'))
        porter = Porter(estimator, language=language)
        # NOTE(review): `float_formatter` is not defined in this function;
        # presumably a module-level name elsewhere in the file - confirm.
        output = porter.export(class_name=class_name,
                               method_name=method_name,
                               export_dir=dest_dir,
                               export_data=with_export,
                               export_append_checksum=with_checksum,
                               num_format=float_formatter,
                               details=True)
    except Exception as exception:
        # Top-level CLI boundary: report any porting failure and exit.
        sys.exit('Error: {}'.format(str(exception)))
    else:
        # Print transpiled estimator to the console:
        if bool(args.get('pipe', False)):
            print(output.get('estimator'))
            sys.exit(0)

        # Save transpiled estimator unless only the model data is wanted:
        only_data = bool(args.get('data'))
        if not only_data:
            filename = output.get('filename')
            dest_path = os.path.join(dest_dir, filename)
            with open(dest_path, 'w') as file_:
                file_.write(output.get('estimator'))
Ejemplo n.º 4
0
 def test_commands_generation_for_c(self):
     """Check the gcc compile command and native run command for C."""
     compile_cmd, run_cmd = Porter._get_commands('mdl.c', 'mdl', 'c')
     self.assertEqual(compile_cmd, 'gcc mdl.c -lm -o mdl')
     self.assertEqual(run_cmd, './mdl')
Ejemplo n.º 5
0
 def test_commands_generation_for_php(self):
     """Check PHP commands: no compile step, run with ``php -f``."""
     comp_cmd, exec_cmd = Porter._get_commands('Mdl.php', 'Mdl', 'php')
     # No compile command is expected for PHP; compare to None by identity.
     self.assertIsNone(comp_cmd)
     self.assertEqual(exec_cmd, 'php -f Mdl.php')
Ejemplo n.º 6
0
 def test_commands_generation_for_java(self):
     """Check the javac compile command and classpath run command for Java."""
     compile_cmd, run_cmd = Porter._get_commands('Mdl.java', 'Mdl', 'java')
     self.assertEqual(compile_cmd, 'javac Mdl.java')
     self.assertEqual(run_cmd, 'java -classpath . Mdl')
Ejemplo n.º 7
0
 def test_commands_generation_for_java(self):
     """Check the javac compile command and classpath run command for Java."""
     compile_cmd, run_cmd = Porter._get_commands('Mdl.java', 'Mdl', 'java')
     self.assertEqual(compile_cmd, 'javac Mdl.java')
     self.assertEqual(run_cmd, 'java -classpath . Mdl')
Ejemplo n.º 8
0
 def test_filename_generation_for_php(self):
     """Check PHP filenames: stripped, first letter capitalized, ``.php``."""
     cases = (('mdl', 'Mdl.php'), (' mdl ', 'Mdl.php'), ('MDL', 'MDL.php'))
     for class_name, expected in cases:
         self.assertEqual(Porter._get_filename(class_name, 'php'), expected)
Ejemplo n.º 9
0
 def test_porter_args_method(self):
     """Test invalid method name."""
     # An unknown method name must be rejected at construction time.
     with self.assertRaises(AttributeError):
         Porter(self.estimator, method='invalid')
Ejemplo n.º 10
0
 def test_filename_generation_for_java(self):
     """Check Java filenames: stripped, first letter capitalized, ``.java``."""
     cases = (('mdl', 'Mdl.java'), (' mdl ', 'Mdl.java'), ('MDL', 'MDL.java'))
     for class_name, expected in cases:
         self.assertEqual(Porter._get_filename(class_name, 'java'), expected)
Ejemplo n.º 11
0
 def test_commands_generation_for_go(self):
     """Check the ``go build`` compile command and native run command."""
     compile_cmd, run_cmd = Porter._get_commands('mdl.go', 'mdl', 'go')
     self.assertEqual(compile_cmd, 'go build -o mdl mdl.go')
     self.assertEqual(run_cmd, './mdl')
Ejemplo n.º 12
0
 def test_commands_generation_for_js(self):
     """Check JavaScript commands: no compile step, run with ``node``."""
     comp_cmd, exec_cmd = Porter._get_commands('mdl.js', 'mdl', 'js')
     # No compile command is expected for JS; compare to None by identity.
     self.assertIsNone(comp_cmd)
     self.assertEqual(exec_cmd, 'node mdl.js')
Ejemplo n.º 13
0
 def test_commands_generation_for_c(self):
     """Check the gcc compile command and native run command for C."""
     compile_cmd, run_cmd = Porter._get_commands('mdl.c', 'mdl', 'c')
     self.assertEqual(compile_cmd, 'gcc mdl.c -lm -o mdl')
     self.assertEqual(run_cmd, './mdl')
Ejemplo n.º 14
0
 def test_commands_generation_for_php(self):
     """Check PHP commands: no compile step, run with ``php -f``."""
     comp_cmd, exec_cmd = Porter._get_commands('Mdl.php', 'Mdl', 'php')
     # No compile command is expected for PHP; compare to None by identity.
     self.assertIsNone(comp_cmd)
     self.assertEqual(exec_cmd, 'php -f Mdl.php')
Ejemplo n.º 15
0
 def test_commands_generation_for_js(self):
     """Check JavaScript commands: no compile step, run with ``node``."""
     comp_cmd, exec_cmd = Porter._get_commands('mdl.js', 'mdl', 'js')
     # No compile command is expected for JS; compare to None by identity.
     self.assertIsNone(comp_cmd)
     self.assertEqual(exec_cmd, 'node mdl.js')
Ejemplo n.º 16
0
 def test_commands_generation_for_go(self):
     """Check the ``go build`` compile command and native run command."""
     compile_cmd, run_cmd = Porter._get_commands('mdl.go', 'mdl', 'go')
     self.assertEqual(compile_cmd, 'go build -o mdl mdl.go')
     self.assertEqual(run_cmd, './mdl')
Ejemplo n.º 17
0
 def test_filename_generation_for_php(self):
     """Check PHP filenames: stripped, first letter capitalized, ``.php``."""
     cases = (('mdl', 'Mdl.php'), (' mdl ', 'Mdl.php'), ('MDL', 'MDL.php'))
     for class_name, expected in cases:
         self.assertEqual(Porter._get_filename(class_name, 'php'), expected)
Ejemplo n.º 18
0
 def test_porter_args_language(self):
     """Test invalid programming language."""
     # An unknown target language must be rejected at construction time.
     with self.assertRaises(AttributeError):
         Porter(self.estimator, language='invalid')
Ejemplo n.º 19
0
 def test_filename_generation_for_c(self):
     """Check C filenames: stripped, case preserved, ``.c`` extension."""
     cases = (('mdl', 'mdl.c'), (' mdl ', 'mdl.c'), ('MDL', 'MDL.c'))
     for class_name, expected in cases:
         self.assertEqual(Porter._get_filename(class_name, 'c'), expected)
Ejemplo n.º 20
0
 def test_filename_generation_for_java(self):
     """Check Java filenames: stripped, first letter capitalized, ``.java``."""
     cases = (('mdl', 'Mdl.java'), (' mdl ', 'Mdl.java'), ('MDL', 'MDL.java'))
     for class_name, expected in cases:
         self.assertEqual(Porter._get_filename(class_name, 'java'), expected)
Ejemplo n.º 21
0
 def test_filename_generation_for_ruby(self):
     """Check Ruby filenames: stripped, case preserved, ``.rb`` extension."""
     cases = (('mdl', 'mdl.rb'), (' mdl ', 'mdl.rb'), ('MDL', 'MDL.rb'))
     for class_name, expected in cases:
         self.assertEqual(Porter._get_filename(class_name, 'ruby'), expected)
Ejemplo n.º 22
0
 def test_filename_generation_for_c(self):
     """Check C filenames: stripped, case preserved, ``.c`` extension."""
     cases = (('mdl', 'mdl.c'), (' mdl ', 'mdl.c'), ('MDL', 'MDL.c'))
     for class_name, expected in cases:
         self.assertEqual(Porter._get_filename(class_name, 'c'), expected)
Ejemplo n.º 23
0
 def test_filename_generation_for_ruby(self):
     """Check Ruby filenames: stripped, case preserved, ``.rb`` extension."""
     cases = (('mdl', 'mdl.rb'), (' mdl ', 'mdl.rb'), ('MDL', 'MDL.rb'))
     for class_name, expected in cases:
         self.assertEqual(Porter._get_filename(class_name, 'ruby'), expected)