def main(args):
    """Train a bidirectional language model (BiLM) and save its weights.

    Expects ``args`` to provide: ``train_data`` / ``valid_data`` (paths to
    pickled DataFrames with a ``repl_words`` column), ``word_emb_size``,
    ``lstm_unit_size``, ``epochs``, ``batch_size``, ``early_stopping``,
    and ``output`` (path for the saved state_dict).
    """
    train_df = pd.read_pickle(args.train_data)
    valid_df = pd.read_pickle(args.valid_data)

    # Build the word vocabulary from the training split only, so that
    # validation uses exactly the same index mapping.
    tokenizer = Tokenizer()
    tokenizer.fit_word(train_df.repl_words.tolist())
    train_sentences_idx = sentence_preprocessing(train_df, tokenizer)
    valid_sentences_idx = sentence_preprocessing(valid_df, tokenizer)

    bi_lm_model = BiLM(args.word_emb_size, args.lstm_unit_size,
                       len(tokenizer.vocab_word))

    if torch.cuda.device_count() > 1:
        print("Use", torch.cuda.device_count(), "GPUs.")
        bi_lm_model = torch.nn.DataParallel(bi_lm_model)
    elif torch.cuda.device_count() == 1:
        print("Use single GPU.")
    else:
        print("Use CPU.")
    # NOTE(review): relies on a module-level `device` — confirm it matches
    # the CUDA branch chosen above.
    bi_lm_model.to(device)

    bi_lm_model = train(bi_lm_model, train_sentences_idx, valid_sentences_idx,
                        args.epochs, args.batch_size, args.early_stopping)

    # Bug fix: a DataParallel wrapper prefixes every state_dict key with
    # "module.", which would break loading into a plain BiLM later.
    # Save the underlying model's state_dict instead.
    if isinstance(bi_lm_model, torch.nn.DataParallel):
        torch.save(bi_lm_model.module.state_dict(), args.output)
    else:
        torch.save(bi_lm_model.state_dict(), args.output)
def get_tokenizer(is_transfer, sentences=None,
                  vocab_path="../data/all_word_vocab.json"):
    """Build a Tokenizer with the fixed BIO tag vocabulary.

    Args:
        is_transfer: if truthy, load a pre-built word vocabulary from
            ``vocab_path`` (transfer-learning setup); otherwise fit the
            word vocabulary on ``sentences``.
        sentences: training sentences to fit on; required when
            ``is_transfer`` is falsy.
        vocab_path: JSON file holding the shared word vocabulary
            (generalized from the previously hard-coded path; the default
            preserves the old behavior).

    Returns:
        The configured Tokenizer.

    Raises:
        ValueError: if ``is_transfer`` is falsy and ``sentences`` is None.
    """
    tokenizer = Tokenizer()
    # Fixed BIO tagging scheme; <START>/<STOP> are CRF-style boundary tags.
    tokenizer.vocab_tag = {
        '<PAD>': 0,
        'B': 1,
        'I': 2,
        'O': 3,
        '<START>': 4,
        '<STOP>': 5
    }
    if is_transfer:
        # Reuse the vocabulary built on the full corpus so indices match
        # the pre-trained model.
        with open(vocab_path, 'r', encoding='utf-8') as f:
            tokenizer.vocab_word = json.load(f)
    else:
        if sentences is None:
            raise ValueError(
                "sentences must be provided when is_transfer is falsy")
        tokenizer.fit_word(sentences)
    return tokenizer