Example no. 1
0
def convert_to_onnx(pth, length):
    """Export a trained PyTorch checkpoint to an ONNX file.

    The model architecture is rebuilt from the ``config.json`` that sits
    next to this script, the weights in *pth* are loaded, and the model is
    traced with a random one-hot sequence of the given *length*.  The ONNX
    file is written beside the checkpoint with an ``.onnx`` extension.
    """
    # Rebuild the model from the project config located beside this script.
    here = path.dirname(path.abspath(__file__))
    config = ConfigParser.from_json(path.join(here, 'config.json'))
    logger = config.get_logger('convert', 1)

    model = config.init_obj('arch', module_arch)
    checkpoint = torch.load(pth, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    # Random one-hot encoding over the DNA alphabet serves as the trace input
    # (shape: 1 x length x 4 after adding the batch dimension).
    alphabet = list('ACGT')
    one_hot = np.eye(len(alphabet))
    sample = one_hot[np.random.choice(one_hot.shape[0], size=length)]
    dummy_input = torch.FloatTensor(np.expand_dims(sample, axis=0))

    exported_onnx_file = pth.rsplit('.', 1)[0] + '.onnx'

    logger.info('Converting to ONNX model: {}'.format(exported_onnx_file))

    # Tracing tends to emit warnings that are irrelevant to the export;
    # silence them for the duration of the call.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        torch.onnx.export(
            model,
            dummy_input,
            exported_onnx_file,
            export_params=True,
            opset_version=10,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size", 1: "sequence"},
                "output": {0: "batch_size", 1: "sequence"},
            },
        )
Example no. 2
0
    # NOTE(review): this is the tail of a larger function — `args` (presumably
    # an argparse.ArgumentParser), `cd`, and `__version__` are defined above
    # this excerpt; confirm against the full file.
    args.add_argument(
        '--chunk_size',
        default=1024,
        type=int,
        help=
        'chunk_size * threads reads to process per thread.(default: 1024) \n{}.'
        .format(
            'When chunk_size=1024 and threads=20, each process will load 1024 reads, in total consumming ~20G memory'
        ))

    args.add_argument('-v',
                      '--version',
                      action='version',
                      version='%(prog)s {version}'.format(version=__version__))

    # `args` may already be a parsed-options tuple when supplied by a caller;
    # only run parse_args() on an actual parser.  TODO confirm caller contract.
    if not isinstance(args, tuple):
        args = args.parse_args()
    # Fall back to the config.json next to this script when --config is absent.
    if args.config is None:
        config_file = os.path.join(cd, 'config.json')
    else:
        config_file = args.config
    config = ConfigParser.from_json(config_file)

    # Pin OpenMP to a single thread per process — presumably parallelism is
    # managed by the Predictor's own worker processes; verify before changing.
    os.environ['OMP_NUM_THREADS'] = '1'
    # os.environ['MKL_NUM_THREADS'] = '1'

    # Build the predictor, load model weights, and run prediction.
    seq_pred = Predictor(config, args)
    seq_pred.load_model()
    seq_pred.run()