Ejemplo n.º 1
0
def main():
    """Parse the command line and run the converter picked by ``sub_command``.

    Recognised sub-commands are ``onnx2tnn``, ``caffe2tnn`` and ``tf2tnn``.
    For any other value a short notice is printed and nothing is converted.
    """
    cli = args_parser.parse_args()
    command = cli.sub_command

    if command == 'onnx2tnn':
        model_file = parse_path.parse_path(cli.onnx_path)
        out_dir = parse_path.parse_path(cli.output_dir)
        tnn_version, do_optimize, use_half = (cli.version, cli.optimize,
                                              cli.half)
        # Both paths are passed through parse_path a second time before use.
        model_file = parse_path.parse_path(model_file)
        out_dir = parse_path.parse_path(out_dir)
        onnx2tnn.convert(model_file, out_dir, tnn_version, do_optimize,
                         use_half)
        return

    if command == 'caffe2tnn':
        prototxt = parse_path.parse_path(cli.proto_path)
        weights = parse_path.parse_path(cli.model_path)
        out_dir = parse_path.parse_path(cli.output_dir)
        tnn_version, do_optimize, use_half = (cli.version, cli.optimize,
                                              cli.half)
        caffe2tnn.convert(prototxt, weights, out_dir, tnn_version,
                          do_optimize, use_half)
        return

    if command == 'tf2tnn':
        graph_file = parse_path.parse_path(cli.tf_path)
        out_dir = parse_path.parse_path(cli.output_dir)
        names_in, names_out = cli.input_names, cli.output_names
        tnn_version, do_optimize, use_half = (cli.version, cli.optimize,
                                              cli.half)
        tf2tnn.convert(graph_file, names_in, names_out, out_dir, tnn_version,
                       do_optimize, use_half)
        return

    # No known converter matched the requested sub-command.
    print("Do not support convert!")
Ejemplo n.º 2
0
def main():
    """Parse the command line and dispatch to the selected model converter.

    Recognised sub-commands are ``onnx2tnn``, ``caffe2tnn`` and ``tf2tnn``.
    Each branch collects its options from the parsed arguments, resolves the
    path-like ones through ``parse_path.parse_path`` and forwards everything
    positionally to the matching ``convert`` function. Any other sub-command
    only prints a notice.
    """
    cli = args_parser.parse_args()
    command = cli.sub_command

    if command == 'onnx2tnn':
        model_file = parse_path.parse_path(cli.onnx_path)
        out_dir = parse_path.parse_path(cli.output_dir)
        names_in = cli.input_names
        tnn_version, do_optimize, use_half = (cli.version, cli.optimize,
                                              cli.half)
        do_align = cli.align
        raw_input, raw_reference = cli.input_file_path, cli.refer_file_path
        # The model and output paths are passed through parse_path twice.
        model_file = parse_path.parse_path(model_file)
        out_dir = parse_path.parse_path(out_dir)
        input_data = parse_path.parse_path(raw_input)
        reference_data = parse_path.parse_path(raw_reference)
        onnx2tnn.convert(model_file, out_dir, tnn_version, do_optimize,
                         use_half, do_align, input_data, reference_data,
                         names_in)
        return

    if command == 'caffe2tnn':
        prototxt = parse_path.parse_path(cli.proto_path)
        weights = parse_path.parse_path(cli.model_path)
        out_dir = parse_path.parse_path(cli.output_dir)
        tnn_version, do_optimize, use_half = (cli.version, cli.optimize,
                                              cli.half)
        do_align = cli.align
        raw_input, raw_reference = cli.input_file_path, cli.refer_file_path
        input_data = parse_path.parse_path(raw_input)
        reference_data = parse_path.parse_path(raw_reference)
        caffe2tnn.convert(prototxt, weights, out_dir, tnn_version,
                          do_optimize, use_half, do_align, input_data,
                          reference_data)
        return

    if command == 'tf2tnn':
        graph_file = parse_path.parse_path(cli.tf_path)
        out_dir = parse_path.parse_path(cli.output_dir)
        names_in, names_out = cli.input_names, cli.output_names
        tnn_version, do_optimize, use_half = (cli.version, cli.optimize,
                                              cli.half)
        do_align = cli.align
        keep_consts = cli.not_fold_const
        raw_input, raw_reference = cli.input_file_path, cli.refer_file_path
        input_data = parse_path.parse_path(raw_input)
        reference_data = parse_path.parse_path(raw_reference)
        tf2tnn.convert(graph_file, names_in, names_out, out_dir, tnn_version,
                       do_optimize, use_half, do_align, keep_consts,
                       input_data, reference_data)
        return

    # No known converter matched the requested sub-command.
    print("Do not support convert!")
def main():
    """Command-line entry point: set up logging, then run one converter.

    The log level is DEBUG when the ``debug`` option is exactly ``True`` and
    INFO otherwise. Recognised sub-commands are ``onnx2tnn``, ``caffe2tnn``,
    ``tf2tnn`` and ``tflite2tnn``. A converter failure is caught and logged
    rather than propagated. With no sub-command the parser's help is printed;
    an unknown sub-command only logs a notice.
    """
    cli_parser = args_parser.parse_args()
    opts = cli_parser.parse_args()

    debug_enabled: bool = opts.debug
    log_level = logging.DEBUG if debug_enabled is True else logging.INFO
    logging.basicConfig(level=log_level, format='')

    rule = "-" * 10
    logging.info(
        "\n{}  convert model, please wait a moment {}\n".format(rule, rule))

    command = opts.sub_command

    if command == 'onnx2tnn':
        model_file = parse_path.parse_path(opts.onnx_path)
        out_dir = parse_path.parse_path(opts.output_dir)
        tnn_version, do_optimize, use_half = (opts.version, opts.optimize,
                                              opts.half)
        do_align = opts.align
        raw_input, raw_reference = opts.input_file_path, opts.refer_file_path
        # The model and output paths are passed through parse_path twice.
        model_file = parse_path.parse_path(model_file)
        out_dir = parse_path.parse_path(out_dir)
        input_data = parse_path.parse_path(raw_input)
        reference_data = parse_path.parse_path(raw_reference)
        # The names are handed over as one string in which every name is
        # followed by a single space; None is kept when no names were given.
        names_in = None
        if opts.input_names is not None:
            names_in = "".join(name + " " for name in opts.input_names)
        try:
            onnx2tnn.convert(model_file, out_dir, tnn_version, do_optimize,
                             use_half, do_align, input_data, reference_data,
                             names_in, debug_mode=debug_enabled)
        except Exception as exc:
            logging.error("Conversion to  tnn failed :(\n")
            logging.error(exc)
        return

    if command == 'caffe2tnn':
        prototxt = parse_path.parse_path(opts.proto_path)
        weights = parse_path.parse_path(opts.model_path)
        out_dir = parse_path.parse_path(opts.output_dir)
        tnn_version, do_optimize, use_half = (opts.version, opts.optimize,
                                              opts.half)
        do_align = opts.align
        raw_input, raw_reference = opts.input_file_path, opts.refer_file_path
        input_data = parse_path.parse_path(raw_input)
        reference_data = parse_path.parse_path(raw_reference)
        try:
            caffe2tnn.convert(prototxt, weights, out_dir, tnn_version,
                              do_optimize, use_half, do_align, input_data,
                              reference_data, debug_mode=debug_enabled)
        except Exception as exc:
            logging.error("Conversion to  tnn failed :(\n")
            logging.error(exc)
        return

    if command == 'tf2tnn':
        graph_file = parse_path.parse_path(opts.tf_path)
        out_dir = parse_path.parse_path(opts.output_dir)
        names_in, names_out = opts.input_names, opts.output_names
        tnn_version, do_optimize, use_half = (opts.version, opts.optimize,
                                              opts.half)
        do_align = opts.align
        keep_consts = opts.not_fold_const
        raw_input, raw_reference = opts.input_file_path, opts.refer_file_path
        input_data = parse_path.parse_path(raw_input)
        reference_data = parse_path.parse_path(raw_reference)
        try:
            tf2tnn.convert(graph_file, names_in, names_out, out_dir,
                           tnn_version, do_optimize, use_half, do_align,
                           keep_consts, input_data, reference_data,
                           debug_mode=debug_enabled)
        except Exception as exc:
            logging.error("\nConversion to  tnn failed :(\n")
            logging.error(exc)
        return

    if command == 'tflite2tnn':
        graph_file = parse_path.parse_path(opts.tf_path)
        out_dir = parse_path.parse_path(opts.output_dir)
        tnn_version = opts.version
        do_align = opts.align
        raw_input, raw_reference = opts.input_file_path, opts.refer_file_path
        input_data = parse_path.parse_path(raw_input)
        reference_data = parse_path.parse_path(raw_reference)
        try:
            tflite2tnn.convert(graph_file, out_dir, tnn_version, do_align,
                               input_data, reference_data,
                               debug_mode=debug_enabled)
        except Exception as exc:
            logging.error("\n Conversion to  tnn failed :(\n")
            logging.error(exc)
        return

    if command is None:
        # No sub-command was given: show the usage text.
        cli_parser.print_help()
        return

    logging.info("Do not support convert!")
Ejemplo n.º 4
0
def main():
    """Command-line entry point: run the converter picked by ``sub_command``.

    Recognised sub-commands are ``onnx2tnn``, ``caffe2tnn`` and ``tf2tnn``.
    Each branch collects its options from the parsed arguments, resolves the
    path-like ones through ``parse_path.parse_path`` and calls the matching
    ``convert`` function.

    A converter failure is caught so the program does not abort with a
    traceback; the failure banner is logged followed by the exception itself,
    so the cause of the failure is not lost. With no sub-command the parser's
    help is printed; an unknown sub-command only logs a notice.
    """
    parser = args_parser.parse_args()
    args = parser.parse_args()

    logging.info("\n{}  convert model, please wait a moment {}\n".format(
        "-" * 10, "-" * 10))

    if args.sub_command == 'onnx2tnn':
        onnx_path = parse_path.parse_path(args.onnx_path)
        output_dir = parse_path.parse_path(args.output_dir)
        input_names = args.input_names
        version = args.version
        optimize = args.optimize
        half = args.half
        align = args.align
        input_file = args.input_file_path
        ref_file = args.refer_file_path
        # NOTE(review): the model and output paths are passed through
        # parse_path twice; presumably parse_path is idempotent and the second
        # pass could be dropped -- confirm against its implementation.
        onnx_path = parse_path.parse_path(onnx_path)
        output_dir = parse_path.parse_path(output_dir)
        input_file = parse_path.parse_path(input_file)
        ref_file = parse_path.parse_path(ref_file)

        try:
            onnx2tnn.convert(onnx_path, output_dir, version, optimize, half,
                             align, input_file, ref_file, input_names)
        except Exception as err:
            logging.error("Conversion to  tnn failed :(\n")
            # Report the cause as well as the banner.
            logging.error(err)

    elif args.sub_command == 'caffe2tnn':
        proto_path = parse_path.parse_path(args.proto_path)
        model_path = parse_path.parse_path(args.model_path)
        output_dir = parse_path.parse_path(args.output_dir)
        version = args.version
        optimize = args.optimize
        half = args.half
        align = args.align
        input_file = args.input_file_path
        ref_file = args.refer_file_path
        input_file = parse_path.parse_path(input_file)
        ref_file = parse_path.parse_path(ref_file)
        try:
            caffe2tnn.convert(proto_path, model_path, output_dir, version,
                              optimize, half, align, input_file, ref_file)
        except Exception as err:
            logging.error("Conversion to  tnn failed :(\n")
            # Report the cause as well as the banner.
            logging.error(err)

    elif args.sub_command == 'tf2tnn':
        tf_path = parse_path.parse_path(args.tf_path)
        output_dir = parse_path.parse_path(args.output_dir)
        input_names = args.input_names
        output_names = args.output_names
        version = args.version
        optimize = args.optimize
        half = args.half
        align = args.align
        not_fold_const = args.not_fold_const
        input_file = args.input_file_path
        ref_file = args.refer_file_path
        input_file = parse_path.parse_path(input_file)
        ref_file = parse_path.parse_path(ref_file)

        try:
            tf2tnn.convert(tf_path, input_names, output_names, output_dir,
                           version, optimize, half, align, not_fold_const,
                           input_file, ref_file)
        except Exception as err:
            logging.error("\nConversion to  tnn failed :(\n")
            # Report the cause as well as the banner.
            logging.error(err)
    elif args.sub_command is None:
        # No sub-command was given: show the usage text.
        parser.print_help()
    else:
        logging.info("Do not support convert!")
Ejemplo n.º 5
0
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

from models.etm import ETM
from utils.args_parser import parse_args, flatten_cfg, mkdir, save_yaml, newest
from utils.constants import Cte
import argparse

# Command-line interface: one option, -f/--config_file, naming the
# configuration file to load (defaults to 'params/trainer.yaml').
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('-f',
                    '--config_file',
                    default='params/trainer.yaml',
                    type=str)
args = parser.parse_args()

# Load the configuration referenced on the command line.
cfg = parse_args(args.config_file)

# Seed with the value stored under the 'seed' key of the configuration.
pl.seed_everything(cfg['seed'])

# %% Load dataset
data_module = None
if cfg['dataset']['name'] == Cte.NG:
    # Imported here so the dataset module is only loaded when it is selected.
    from datasets.news_group import NewsGroupDataModule
    data_module = NewsGroupDataModule(**cfg['dataset']['params'])

if data_module is None:
    # Raise explicitly instead of using ``assert``: assertions are removed
    # under ``python -O``, which would let an unsupported dataset name fall
    # through to an opaque AttributeError on None below.
    raise ValueError(
        f"Unsupported dataset name: {cfg['dataset']['name']!r}")

# The model parameters take their vocabulary size from the data module.
cfg['model']['params']['vocab_size'] = data_module.vocab_size

# %% Load model
model = ETM(**cfg['model']['params'])