def train(conf):
    """Build the input pipeline and BiLSTM model described by ``conf`` and train it.

    Attributes read from ``conf``: ``embeddings``, ``chars_file``,
    ``word_length``, ``input_directory``, ``validation_split``,
    ``batch_size`` and ``sequence_length``. ``conf`` itself is also handed
    to ``BiLSTM`` and ``Training``.
    """
    vocabulary, raw_vectors = load_vectors(conf.embeddings)
    embedding_matrix = np.array(raw_vectors)

    char_mapping = load_char_mapping(conf.chars_file)
    mapper = InputMapping(char_mapping, vocabulary, conf.word_length)

    network = BiLSTM(
        conf,
        characters=n_chars(char_mapping),
        pretrained=embedding_matrix,
    )

    # load_dataset returns a 3-tuple unpacked here as
    # (training data, validation data, pos_weight).
    train_data, validation_data, pos_weight = mapper.load_dataset(
        conf.input_directory,
        conf.validation_split,
        conf.batch_size,
        conf.sequence_length,
    )
    Training(network, conf, train_data, validation_data, pos_weight).run()
def create_processor(conf):
    """Build a SentenceProcessor backed by a TorchScript-compiled BiLSTM.

    Any of ``conf.embeddings``, ``conf.chars_file``, ``conf.hparams_file`` and
    ``conf.model_file`` that is ``None`` is filled in (mutating ``conf``) from
    the corresponding ``sentences.*`` key of ``load_config()``.

    Args:
        conf: configuration object. ``hparams_file`` and ``model_file`` must
            support ``.open()`` (e.g. ``pathlib.Path``).

    Returns:
        A ``SentenceProcessor`` wrapping the input mapping and the scripted
        model in eval mode, with weights loaded from ``conf.model_file``.
    """
    config = load_config()
    if conf.embeddings is None:
        conf.embeddings = Path(config['sentences.wordEmbeddings'])
    if conf.chars_file is None:
        conf.chars_file = Path(config['sentences.charsFile'])
    if conf.hparams_file is None:
        conf.hparams_file = Path(config['sentences.hparamsFile'])
    if conf.model_file is None:
        conf.model_file = Path(config['sentences.modelFile'])

    logger.info('Loading hparams from: {}'.format(conf.hparams_file))
    # NOTE(review): parsed with the module-level ``Loader``; if that is not a
    # safe loader, hparams_file must only ever point at trusted files.
    with conf.hparams_file.open('r') as f:
        d = yaml.load(f, Loader)

    # Plain attribute bag so hyperparameters are reachable as hparams.<name>.
    class Hparams:
        pass

    hparams = Hparams()
    hparams.__dict__.update(d)

    logger.info('Loading word embeddings from: "{}"'.format(conf.embeddings))
    words, vectors = load_vectors(conf.embeddings)
    vectors = np.array(vectors)

    logger.info('Loading characters from: {}'.format(conf.chars_file))
    char_mapping = load_char_mapping(conf.chars_file)
    input_mapping = InputMapping(char_mapping, words, hparams.word_length)

    model = BiLSTM(hparams, n_chars(char_mapping), vectors)
    model = torch.jit.script(model)
    model.eval()

    logger.info('Loading model weights from: {}'.format(conf.model_file))
    with conf.model_file.open('rb') as f:
        # This function never moves the model to another device, so map the
        # checkpoint onto the CPU. Without map_location, a checkpoint saved
        # from CUDA tensors fails to load on a machine with no CUDA.
        state_dict = torch.load(f, map_location='cpu')
    model.load_state_dict(state_dict)

    return SentenceProcessor(input_mapping, model)
def create_processor(conf):
    """Build a SentenceProcessor whose BiLSTM lives on a chosen torch device.

    Process-wide side effects: sets torch intra-op and inter-op thread counts
    to 1, calls ``logging.basicConfig(level=logging.INFO)``, runs
    ``check_data(conf.download_data)`` and makes ``'fork'`` the
    torch.multiprocessing start method.

    Any of ``conf.embeddings``, ``conf.chars_file``, ``conf.hparams_file`` and
    ``conf.model_file`` that is ``None`` is filled in (mutating ``conf``) from
    the corresponding ``sentences.*`` key of ``load_config()``.

    Device selection: ``conf.torch_device`` if it is not ``None``; otherwise
    ``"cpu"`` when ``conf.force_cpu`` is truthy or CUDA is unavailable, else
    ``"cuda"``.

    Args:
        conf: configuration object. ``hparams_file`` and ``model_file`` must
            support ``.open()`` (e.g. ``pathlib.Path``).

    Returns:
        A ``SentenceProcessor`` wrapping the input mapping, the eval-mode model
        (moved to the selected device, with shared memory) and the device.

    Raises:
        RuntimeError: if a multiprocessing start method other than ``'fork'``
            was already set, or if the inter-op thread count was already fixed
            to a value other than 1.
    """
    torch.set_num_threads(1)
    # set_num_interop_threads may be called only once per process, and only
    # before inter-op parallel work starts. Skip it when the count is already
    # 1, so that a repeated call does not raise.
    if torch.get_num_interop_threads() != 1:
        torch.set_num_interop_threads(1)
    logging.basicConfig(level=logging.INFO)
    check_data(conf.download_data)

    config = load_config()
    if conf.embeddings is None:
        conf.embeddings = Path(config['sentences.wordEmbeddings'])
    if conf.chars_file is None:
        conf.chars_file = Path(config['sentences.charsFile'])
    if conf.hparams_file is None:
        conf.hparams_file = Path(config['sentences.hparamsFile'])
    if conf.model_file is None:
        conf.model_file = Path(config['sentences.modelFile'])

    if conf.torch_device is not None:
        device = conf.torch_device
    else:
        device = "cpu" if conf.force_cpu or not torch.cuda.is_available() else "cuda"
    device = torch.device(device)
    logger.info('Using torch device: "{}"'.format(repr(device)))

    logger.info('Loading hparams from: {}'.format(conf.hparams_file))
    # NOTE(review): parsed with the module-level ``Loader``; if that is not a
    # safe loader, hparams_file must only ever point at trusted files.
    with conf.hparams_file.open('r') as f:
        d = yaml.load(f, Loader)

    # Plain attribute bag so hyperparameters are reachable as hparams.<name>.
    class Hparams:
        pass

    hparams = Hparams()
    hparams.__dict__.update(d)

    logger.info('Loading word embeddings from: "{}"'.format(conf.embeddings))
    words, vectors = load_vectors(conf.embeddings)
    vectors = np.array(vectors)

    logger.info('Loading characters from: {}'.format(conf.chars_file))
    char_mapping = load_char_mapping(conf.chars_file)
    input_mapping = InputMapping(char_mapping, words, hparams.word_length)

    model = BiLSTM(hparams, n_chars(char_mapping), vectors)
    model.eval()
    model.to(device=device)
    model.share_memory()

    logger.info('Loading model weights from: {}'.format(conf.model_file))
    with conf.model_file.open('rb') as f:
        # Deserialize straight onto the selected device. Without map_location,
        # a checkpoint saved from CUDA tensors fails to load when force_cpu is
        # set or CUDA is unavailable.
        state_dict = torch.load(f, map_location=device)
    model.load_state_dict(state_dict)

    # set_start_method raises RuntimeError whenever a start method is already
    # set, even when it is already 'fork'. Only request 'fork' when the
    # current method differs; a conflicting earlier choice still raises.
    if torch.multiprocessing.get_start_method(allow_none=True) != 'fork':
        torch.multiprocessing.set_start_method('fork')

    processor = SentenceProcessor(input_mapping, model, device)
    return processor