# Imports shared by the scripts in this section (assuming the pre-1.0 AllenNLP
# API, which provides BasicIterator and Seq2SeqPredictor). UnsupervisedBTReader
# is a project-specific dataset reader defined elsewhere in the repository.
import csv
import logging
import random

import torch
from tqdm import tqdm

from allennlp.common import Params
from allennlp.data import DatasetReader, Vocabulary
from allennlp.data.iterators import BasicIterator
from allennlp.models import Model
from allennlp.predictors import Seq2SeqPredictor


def main(config: str, model_th: str, dataset: str, hypo_file: str,
         ref_file: str, batch_size: int, no_gpu: bool):
    """Generate predictions and write hypothesis/reference files, one
    space-joined sentence per line, for later evaluation."""
    logger = logging.getLogger(__name__)

    logger.info("Loading configuration parameters")
    params = Params.from_file(config)

    vocab_params = params.pop("vocabulary")
    vocab = Vocabulary.from_params(vocab_params)

    reader_params = params.pop("dataset_reader")
    reader_name = reader_params.pop("type")
    # reader_params["lazy"] = True  # make sure we do not load the entire dataset
    reader = DatasetReader.by_name(reader_name).from_params(reader_params)

    logger.info("Reading data from {}".format(dataset))
    data = reader.read(dataset)

    iterator = BasicIterator(batch_size=batch_size)
    iterator.index_with(vocab)
    batches = iterator._create_batches(data, shuffle=False)

    logger.info("Loading model")
    model_params = params.pop("model")
    model_name = model_params.pop("type")
    model = Model.by_name(model_name).from_params(model_params, vocab=vocab)
    if not no_gpu:
        model.cuda(0)

    with open(model_th, 'rb') as f:
        if no_gpu:
            state_dict = torch.load(f, map_location=torch.device('cpu'))
        else:
            state_dict = torch.load(f)
        model.load_state_dict(state_dict)

    predictor = Seq2SeqPredictor(model, reader)
    model.eval()

    with open(hypo_file, 'w') as hf, open(ref_file, 'w') as rf:
        logger.info("Generating predictions")
        for sample in tqdm(batches):
            s = list(sample)
            pred = predictor.predict_batch_instance(s)
            for inst, p in zip(s, pred):
                # Best hypothesis from beam search.
                print(" ".join(p["predicted_tokens"][0]), file=hf)
                # Gold target, with the start/end sentinel tokens stripped.
                print(" ".join(t.text for t in inst["target_tokens"][1:-1]), file=rf)
def main(config: str, model_th: str, dataset: str, out_file):
    """Translate every sentence in `dataset` and write a TSV of
    (line id, source, prediction) rows to `out_file`."""
    logger = logging.getLogger(__name__)

    logger.info("Loading model and data")
    params = Params.from_file(config)

    vocab_params = params.pop("vocabulary")
    vocab = Vocabulary.from_params(vocab_params)

    reader_params = Params({
        "source_token_indexers": {
            "tokens": {
                "type": "single_id",
                "namespace": "tokens"
            }
        },
        "target_namespace": "tokens"
    })
    # reader_name = reader_params.pop("type")
    # reader_params["lazy"] = True  # make sure we do not load the entire dataset
    reader = UnsupervisedBTReader.from_params(reader_params)

    logger.info("Reading data from {}".format(dataset))
    data = reader.read(dataset)

    iterator = BasicIterator(batch_size=32)
    iterator.index_with(vocab)
    batches = iterator._create_batches(data, shuffle=False)

    logger.info("Loading model")
    model_params = params.pop("model")
    model_name = model_params.pop("type")
    model = Model.by_name(model_name).from_params(model_params, vocab=vocab)
    model.cuda(0)

    with open(model_th, 'rb') as f:
        model.load_state_dict(torch.load(f))

    predictor = Seq2SeqPredictor(model, reader)
    model.eval()

    line_id = 0
    writer = csv.writer(out_file, delimiter="\t")

    logger.info("Generating predictions")
    for sample in tqdm(batches):
        s = list(sample)
        pred = predictor.predict_batch_instance(s)
        for inst, p in zip(s, pred):
            writer.writerow((
                line_id,
                # Source sentence, with the start/end sentinel tokens stripped.
                " ".join(t.text for t in inst["source_tokens"][1:-1]),
                # Best hypothesis from beam search.
                " ".join(p["predicted_tokens"][0]),
            ))
            line_id += 1
def main(config: str, model_th: str, dataset: str, seed: int):
    """Print SOURCE/GOLD/GEN triples for one randomly chosen batch, for quick
    qualitative inspection of a trained model."""
    logger = logging.getLogger(__name__)

    logger.info("Loading model and data")
    params = Params.from_file(config)

    vocab_params = params.pop("vocabulary")
    vocab = Vocabulary.from_params(vocab_params)

    reader_params = params.pop("dataset_reader")
    reader_name = reader_params.pop("type")
    reader_params["lazy"] = True  # make sure we do not load the entire dataset
    reader = DatasetReader.by_name(reader_name).from_params(reader_params)

    data = reader.read(dataset)

    iterator = BasicIterator(batch_size=10)
    iterator.index_with(vocab)
    batches = iterator._create_batches(data, shuffle=False)

    model_params = params.pop("model")
    model_name = model_params.pop("type")
    model = Model.by_name(model_name).from_params(model_params, vocab=vocab)
    # model.cuda(cuda_device)

    with open(model_th, 'rb') as f:
        # The model stays on the CPU here, so map GPU-saved weights to the CPU.
        model.load_state_dict(torch.load(f, map_location=torch.device('cpu')))

    predictor = Seq2SeqPredictor(model, reader)
    model.eval()

    logger.info("Generating predictions")
    random.seed(seed)

    # Walk a random number of batches into the (lazy) dataset and pick one.
    samples = []
    for b in batches:
        samples.append(b)
        if random.random() > 0.6:
            break
    sample = list(random.choice(samples))

    pred = predictor.predict_batch_instance(sample)
    for inst, p in zip(sample, pred):
        print()
        print("SOURCE:", " ".join(t.text for t in inst["source_tokens"]))
        print("GOLD:", " ".join(t.text for t in inst["target_tokens"]))
        print("GEN:", p["predicted_tokens"])
def main(config: str):
    """Report the number of parameters of the model described by `config`."""
    logger = logging.getLogger(__name__)

    logger.info("Loading model and data")
    params = Params.from_file(config)

    vocab_params = params.pop("vocabulary")
    vocab = Vocabulary.from_params(vocab_params)

    logger.info("Loading model")
    model_params = params.pop("model")
    model_name = model_params.pop("type")
    model = Model.by_name(model_name).from_params(model_params, vocab=vocab)

    print("Number of parameters:", count_parameters(model))
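# `count_parameters` is defined elsewhere in the repository; below is a minimal
# sketch of what such a helper typically looks like (an assumption, not
# necessarily the repository's actual implementation):
def count_parameters(model: torch.nn.Module) -> int:
    """Total number of trainable parameters in `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Hypothetical command-line entry point (assumed; the original scripts may use
# click or another CLI library). It illustrates how the parameter-count script
# above could be invoked, e.g. `python count_params.py experiment.jsonnet`.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Count the parameters of a model described by an AllenNLP config.")
    parser.add_argument("config", help="Path to the experiment configuration file")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    main(args.config)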