def get_predictions(model, serialization_dir, reader, device):
    """Generate per-question predictions for the DROP dev set.

    Args:
        model: trained model whose forward() returns 'question_id' and
            'answer' entries (batch size 1 is assumed via the asserts below).
        serialization_dir: directory containing a saved 'vocabulary' subdir.
        reader: dataset reader used to load the dev file.
        device: device id the batches are moved to.

    Returns:
        Dict mapping question id -> predicted answer value (or count).

    Raises:
        ValueError: if an answer dict has neither a 'value' nor 'count' key.
    """
    dev = reader.read('raw_data/drop/drop_dataset_dev.json')
    vocab = Vocabulary.from_files(join(serialization_dir, 'vocabulary'))
    iterator = BasicIterator(batch_size=1)
    iterator.index_with(vocab)
    dev_iter = iterator(dev, num_epochs=1)
    dev_batches = [batch for batch in dev_iter]
    dev_batches = move_to_device(dev_batches, device)

    predictions = {}
    with torch.no_grad():
        for batch in tqdm(dev_batches):
            out = model(**batch)
            assert len(out['question_id']) == 1
            assert len(out['answer']) == 1
            query_id = out['question_id'][0]
            answer = out['answer'][0]
            if 'value' in answer:
                prediction = answer['value']
            elif 'count' in answer:
                prediction = answer['count'].item()
            else:
                # Fail loudly with context instead of a bare ValueError().
                raise ValueError(
                    "Unrecognized answer format for question {}: expected a "
                    "'value' or 'count' key, got {}".format(
                        query_id, sorted(answer.keys())))
            predictions[query_id] = prediction
    print(model.get_metrics())
    torch.cuda.empty_cache()
    return predictions
def test_token_to_indices_batch_size_2(self):
    """Verify AllenNLP's padding logic is not overwritten for a
    batch of two sentences with different lengths."""
    sentences = [
        "I have a good dog called Killi .",
        "He fetches me stuff ."
    ]
    instances = [
        Instance({
            "elmo_chunky": TextField(
                [Token(word) for word in sentence.split()],
                {'test_chunky': self.indexer},
            )
        })
        for sentence in sentences
    ]

    iterator = BasicIterator()
    iterator.index_with(Vocabulary())
    # Take only the first (and only) batch of the epoch.
    batch = next(iter(iterator(instances, num_epochs=1, shuffle=False)))

    expected_lengths = [8, 5]
    mask = batch['elmo_chunky']['mask']
    seg_map = batch['elmo_chunky']['seg_map']
    char_ids = batch['elmo_chunky']['character_ids']
    assert (mask > 0).sum(dim=1).tolist() == expected_lengths
    assert (seg_map > -1).sum(dim=1).tolist() == expected_lengths
    assert ((char_ids > 0).sum(dim=2) == 50).sum(dim=1).tolist() == expected_lengths
def main(config: str, model_th: str, dataset: str, hypo_file: str,
         ref_file: str, batch_size: int, no_gpu: bool):
    """Generate hypothesis/reference files from a trained seq2seq model.

    Loads the model described by ``config`` with weights from ``model_th``,
    runs it over ``dataset``, and writes one predicted sentence per line to
    ``hypo_file`` with the corresponding gold target (start/end tokens
    stripped) in ``ref_file``.
    """
    logger = logging.getLogger(__name__)

    logger.info("Loading configuration parameters")
    params = Params.from_file(config)

    vocab_params = params.pop("vocabulary")
    vocab = Vocabulary.from_params(vocab_params)

    reader_params = params.pop("dataset_reader")
    reader_name = reader_params.pop("type")
    reader = DatasetReader.by_name(reader_name).from_params(reader_params)

    logger.info("Reading data from {}".format(dataset))
    data = reader.read(dataset)

    iterator = BasicIterator(batch_size=batch_size)
    iterator.index_with(vocab)
    # NOTE(review): _create_batches is a private AllenNLP API; it yields
    # Batch objects that are consumed instance by instance below.
    batches = iterator._create_batches(data, shuffle=False)

    logger.info("Loading model")
    model_params = params.pop("model")
    model_name = model_params.pop("type")
    model = Model.by_name(model_name).from_params(model_params, vocab=vocab)
    if not no_gpu:
        model.cuda(0)

    with open(model_th, 'rb') as f:
        if no_gpu:
            # Remap CUDA tensors in the checkpoint onto the CPU.
            state_dict = torch.load(f, map_location=torch.device('cpu'))
        else:
            state_dict = torch.load(f)
        model.load_state_dict(state_dict)

    predictor = Seq2SeqPredictor(model, reader)
    model.eval()

    with open(hypo_file, 'w') as hf, open(ref_file, 'w') as rf:
        logger.info("Generating predictions")
        for sample in tqdm(batches):
            s = list(sample)
            pred = predictor.predict_batch_instance(s)
            for inst, p in zip(s, pred):
                # Best beam for the hypothesis; gold target without <s>/</s>.
                print(" ".join(p["predicted_tokens"][0]), file=hf)
                print(" ".join(t.text for t in inst["target_tokens"][1:-1]),
                      file=rf)
def main(config: str, model_th: str, dataset: str, out_file):
    """Translate ``dataset`` with a trained back-translation model and
    write TSV rows of (line id, source sentence, predicted sentence)."""
    logger = logging.getLogger(__name__)
    logger.info("Loading model and data")

    params = Params.from_file(config)
    vocab = Vocabulary.from_params(params.pop("vocabulary"))

    # The reader configuration is fixed rather than taken from the file.
    reader = UnsupervisedBTReader.from_params(Params({
        "source_token_indexers": {
            "tokens": {
                "type": "single_id",
                "namespace": "tokens"
            }
        },
        "target_namespace": "tokens"
    }))

    logger.info("Reading data from {}".format(dataset))
    instances = reader.read(dataset)

    data_iterator = BasicIterator(batch_size=32)
    data_iterator.index_with(vocab)
    batches = data_iterator._create_batches(instances, shuffle=False)

    logger.info("Loading model")
    model_params = params.pop("model")
    model_cls = Model.by_name(model_params.pop("type"))
    model = model_cls.from_params(model_params, vocab=vocab)
    model.cuda(0)
    with open(model_th, 'rb') as f:
        model.load_state_dict(torch.load(f))
    model.eval()

    predictor = Seq2SeqPredictor(model, reader)
    writer = csv.writer(out_file, delimiter="\t")
    line_id = 0

    logger.info("Generating predictions")
    for batch in tqdm(batches):
        batch_instances = list(batch)
        preds = predictor.predict_batch_instance(batch_instances)
        for inst, pred in zip(batch_instances, preds):
            # Strip the <s>/</s> sentinels from the source tokens.
            source_text = " ".join(t.text for t in inst["source_tokens"][1:-1])
            predicted_text = " ".join(pred["predicted_tokens"][0])
            writer.writerow((line_id, source_text, predicted_text))
            line_id += 1
def main():
    """Evaluate a saved SNLI model on the validation and test sets,
    printing the model's metrics for each."""
    args = parse_args()
    ckpt = Path(args.checkpoint)
    run_dir = ckpt.parent
    params = Params.from_file(run_dir / 'params.json')
    train_params, model_params = params.pop('train'), params.pop('model')

    reader = SnliReader(
        tokenizer=WordTokenizer(start_tokens=['<s>'], end_tokens=['</s>']),
        token_indexers={'tokens': SingleIdTokenIndexer(lowercase_tokens=True)})

    valid_dataset = reader.read(train_params.pop('valid_dataset_path'))
    if not args.test_dataset:
        test_path = train_params.pop('test_dataset_path')
    else:
        test_path = args.test_dataset
    test_dataset = reader.read(test_path)
    if args.only_label:
        # Restrict evaluation to examples with the requested gold label.
        test_dataset = [inst for inst in test_dataset
                        if inst.fields['label'].label == args.only_label]

    vocab = Vocabulary.from_files(run_dir / 'vocab')
    random.shuffle(valid_dataset)

    # Weights come from the checkpoint; skip re-reading pretrained embeddings.
    model_params['token_embedder']['pretrained_file'] = None
    model = SNLIModel(params=model_params, vocab=vocab)
    model.load_state_dict(torch.load(ckpt, map_location='cpu'), strict=False)
    model.to(args.cuda_device)
    model.eval()
    torch.set_grad_enabled(False)

    iterator = BasicIterator(batch_size=32)
    iterator.index_with(vocab)

    for dataset in (valid_dataset, test_dataset):
        batches = iterator(dataset, shuffle=False, num_epochs=1)
        model.get_metrics(reset=True)
        for batch in tqdm(batches):
            batch = move_to_device(batch, cuda_device=args.cuda_device)
            model(premise=batch['premise'],
                  hypothesis=batch['hypothesis'],
                  label=batch['label'])
        pprint(model.get_metrics())
def main():
    """Qualitative inspection: for 10 validation premises, generate one
    hypothesis per label and print them alongside the premise."""
    args = parse_args()
    ckpt = Path(args.checkpoint)
    run_dir = ckpt.parent
    params = Params.from_file(run_dir / 'params.json')
    train_params, model_params = params.pop('train'), params.pop('model')

    reader = SnliReader(
        tokenizer=WordTokenizer(start_tokens=['<s>'], end_tokens=['</s>']),
        token_indexers={'tokens': SingleIdTokenIndexer(lowercase_tokens=True)})
    valid_dataset = reader.read(train_params.pop('valid_dataset_path'))

    vocab = Vocabulary.from_files(run_dir / 'vocab')
    random.shuffle(valid_dataset)

    model_params['token_embedder']['pretrained_file'] = None
    model = SNLIModel(params=model_params, vocab=vocab)
    model.load_state_dict(torch.load(ckpt, map_location='cpu'), strict=False)
    model.eval()

    iterator = BasicIterator(batch_size=1)
    iterator.index_with(vocab)
    batch_stream = iterator(valid_dataset)

    label_token_to_index = vocab.get_token_to_index_vocabulary('labels')
    for _ in range(10):
        batch = next(batch_stream)
        premise_tokens = batch['premise']['tokens']
        print('----')
        print(' '.join(model.convert_to_readable_text(premise_tokens)[0]))
        for label, label_index in label_token_to_index.items():
            label_tensor = torch.tensor([label_index])
            # A fresh latent code is sampled for every label on purpose.
            enc_embs = model.embed(premise_tokens)
            enc_mask = get_text_field_mask(batch['premise'])
            enc_hidden = model.encode(inputs=enc_embs,
                                      mask=enc_mask,
                                      drop_start_token=True)
            code, kld = model.sample_code_and_compute_kld(enc_hidden)
            generated = model.generate(code=code,
                                       label=label_tensor,
                                       max_length=enc_mask.sum(1) * 2,
                                       beam_size=10,
                                       lp_alpha=args.lp_alpha)
            print(label)
            print(' '.join(model.convert_to_readable_text(generated[:, 0])[0]))
def main():
    """Generate a label-conditioned hypothesis for every validation premise
    and write SNLI-style JSON lines (sentence1, sentence2, gold_label).
    """
    args = parse_args()
    checkpoint_path = Path(args.checkpoint)
    checkpoint_dir = checkpoint_path.parent
    params_path = checkpoint_dir / 'params.json'
    vocab_dir = checkpoint_dir / 'vocab'

    params = Params.from_file(params_path)
    train_params, model_params = params.pop('train'), params.pop('model')

    tokenizer = WordTokenizer(start_tokens=['<s>'], end_tokens=['</s>'])
    token_indexer = SingleIdTokenIndexer(lowercase_tokens=True)
    dataset_reader = SnliReader(
        tokenizer=tokenizer, token_indexers={'tokens': token_indexer})
    valid_dataset = dataset_reader.read(
        train_params.pop('valid_dataset_path'))

    vocab = Vocabulary.from_files(vocab_dir)
    # Weights come from the checkpoint; skip re-reading pretrained embeddings.
    model_params['token_embedder']['pretrained_file'] = None
    model = SNLIModel(params=model_params, vocab=vocab)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'),
                          strict=False)
    model.to(args.device)
    model.eval()

    iterator = BasicIterator(batch_size=args.batch_size)
    iterator.index_with(vocab)
    generator = iterator(valid_dataset, num_epochs=1, shuffle=False)
    label_index_to_token = vocab.get_index_to_token_vocabulary('labels')

    # BUG FIX: the output file was never closed; a context manager flushes
    # buffered lines even if generation raises part-way through.
    with open(args.out, 'w') as out_file:
        for batch in tqdm(generator):
            premise_tokens = batch['premise']['tokens']
            enc_embs = model.embed(premise_tokens.to(args.device))
            enc_mask = get_text_field_mask(batch['premise']).to(args.device)
            enc_hidden = model.encode(inputs=enc_embs,
                                      mask=enc_mask,
                                      drop_start_token=True)
            code, kld = model.sample_code_and_compute_kld(enc_hidden)
            # Drop the leading <s> token when rendering the premise.
            pre_text = model.convert_to_readable_text(premise_tokens[:, 1:])
            label_tensor = batch['label'].to(args.device)
            generated = model.generate(code=code,
                                       label=label_tensor,
                                       max_length=25,
                                       beam_size=10,
                                       lp_alpha=args.lp_alpha)
            text = model.convert_to_readable_text(generated[:, 0])
            for pre_text_b, text_b, label_index_b in zip(pre_text, text,
                                                         label_tensor):
                obj = {'sentence1': ' '.join(pre_text_b),
                       'sentence2': ' '.join(text_b),
                       'gold_label':
                           label_index_to_token[label_index_b.item()]}
                out_file.write(json.dumps(obj))
                out_file.write('\n')
def main(config: str, model_th: str, dataset: str, seed: int):
    """Print source/gold/generated triples for one randomly chosen batch."""
    logger = logging.getLogger(__name__)
    logger.info("Loading model and data")

    params = Params.from_file(config)
    vocab = Vocabulary.from_params(params.pop("vocabulary"))

    reader_params = params.pop("dataset_reader")
    reader_name = reader_params.pop("type")
    # Lazy reading avoids loading the entire dataset into memory.
    reader_params["lazy"] = True
    reader = DatasetReader.by_name(reader_name).from_params(reader_params)
    data = reader.read(dataset)

    iterator = BasicIterator(batch_size=10)
    iterator.index_with(vocab)
    batches = iterator._create_batches(data, shuffle=False)

    model_params = params.pop("model")
    model_cls = Model.by_name(model_params.pop("type"))
    model = model_cls.from_params(model_params, vocab=vocab)
    with open(model_th, 'rb') as f:
        model.load_state_dict(torch.load(f))
    predictor = Seq2SeqPredictor(model, reader)
    model.eval()

    logger.info("Generating predictions")
    random.seed(seed)
    # Collect a random-length prefix of batches, then pick one at random.
    candidates = []
    for batch in batches:
        candidates.append(batch)
        if random.random() > 0.6:
            break
    chosen = list(random.choice(candidates))
    predictions = predictor.predict_batch_instance(chosen)
    for inst, pred in zip(chosen, predictions):
        print()
        print("SOURCE:", " ".join([t.text for t in inst["source_tokens"]]))
        print("GOLD:", " ".join([t.text for t in inst["target_tokens"]]))
        print("GEN:", pred["predicted_tokens"])
def measure_perplexity(self, file_name: str, batch_size: int = 50,
                       is_including_unk: bool = True):
    """Accumulate perplexity statistics over every sentence in a file.

    Returns the final PerplexityState after a single pass over the data.
    """
    oov_index = self.vocab.get_token_index(DEFAULT_OOV_TOKEN)
    state = PerplexityState(oov_index, is_including_unk)

    data_iterator = BasicIterator(batch_size=batch_size)
    data_iterator.index_with(self.vocab)
    instances = self.reader.read(file_name)

    for batch_count, batch in enumerate(
            data_iterator(instances, num_epochs=1), start=1):
        state = self._measure_perplexity_on_batch(batch, state)
        # Progress report; batch_count * batch_size over-counts the final
        # partial batch but matches the original reporting.
        logger.info(
            "Measure_perplexity: {} sentences processed, {}".format(
                batch_count * batch_size, state))
    return state
def get_predictions(abert, reader, device):
    """Run ``abert`` over the DROP dev set and collect answer values.

    Returns a dict mapping question id -> predicted answer 'value'.
    """
    dev_instances = reader.read('raw_data/drop/drop_dataset_dev.json')
    iterator = BasicIterator(batch_size=1)
    # Indexing against a fresh Vocabulary: the model presumably does its
    # own (e.g. wordpiece) tokenization.
    iterator.index_with(Vocabulary())
    batches = move_to_device(
        list(iterator(dev_instances, num_epochs=1)), device)

    predictions = {}
    with torch.no_grad():
        for batch in tqdm(batches):
            output = abert(**batch)
            assert len(output['question_id']) == 1
            assert len(output['answer']) == 1
            predictions[output['question_id'][0]] = \
                output['answer'][0]['value']
    torch.cuda.empty_cache()
    return predictions
def main():
    """Train the semi-supervised SNLI model described by a params file.

    Reads labeled (and optionally unlabeled) SNLI data, builds and saves
    the vocabulary, then runs the Trainer under ``args.save``.
    """
    args = parse_args()
    params = Params.from_file(args.params)
    save_dir = Path(args.save)
    save_dir.mkdir(parents=True)
    # Keep a copy of the configuration next to the checkpoints.
    params.to_file(save_dir / 'params.json')
    train_params, model_params = params.pop('train'), params.pop('model')

    # Seed both torch and random for reproducibility.
    random_seed = train_params.pop_int('random_seed', 2019)
    torch.manual_seed(random_seed)
    random.seed(random_seed)

    # Mirror stdout/stderr into a log file inside the save directory.
    log_filename = save_dir / 'stdout.log'
    sys.stdout = TeeLogger(filename=log_filename,
                           terminal=sys.stdout,
                           file_friendly_terminal_output=False)
    sys.stderr = TeeLogger(filename=log_filename,
                           terminal=sys.stderr,
                           file_friendly_terminal_output=False)

    tokenizer = WordTokenizer(
        start_tokens=['<s>'],
        end_tokens=['</s>'],
    )
    token_indexer = SingleIdTokenIndexer(lowercase_tokens=True)
    dataset_reader = SnliReader(tokenizer=tokenizer,
                                token_indexers={'tokens': token_indexer})
    train_labeled_dataset_path = train_params.pop('train_labeled_dataset_path')
    train_unlabeled_dataset_path = train_params.pop(
        'train_unlabeled_dataset_path', None)
    train_labeled_dataset = dataset_reader.read(train_labeled_dataset_path)
    # Drop examples longer than 30 tokens.
    train_labeled_dataset = filter_dataset_by_length(
        dataset=train_labeled_dataset, max_length=30)
    if train_unlabeled_dataset_path is not None:
        train_unlabeled_dataset = dataset_reader.read(
            train_unlabeled_dataset_path)
        train_unlabeled_dataset = filter_dataset_by_length(
            dataset=train_unlabeled_dataset, max_length=30)
    else:
        train_unlabeled_dataset = []
    valid_dataset = dataset_reader.read(train_params.pop('valid_dataset_path'))

    # Vocabulary is built from both labeled and unlabeled training data.
    vocab = Vocabulary.from_instances(
        instances=train_labeled_dataset + train_unlabeled_dataset,
        max_vocab_size=train_params.pop_int('max_vocab_size', None))
    vocab.save_to_files(save_dir / 'vocab')

    labeled_batch_size = train_params.pop_int('labeled_batch_size')
    unlabeled_batch_size = train_params.pop_int('unlabeled_batch_size')
    labeled_iterator = BasicIterator(batch_size=labeled_batch_size)
    unlabeled_iterator = BasicIterator(batch_size=unlabeled_batch_size)
    labeled_iterator.index_with(vocab)
    unlabeled_iterator.index_with(vocab)
    # With no unlabeled data, a None iterator signals the trainer to skip
    # the unsupervised pass.
    if not train_unlabeled_dataset:
        unlabeled_iterator = None

    model = SNLIModel(params=model_params, vocab=vocab)
    optimizer = optim.Adam(params=model.parameters(),
                           lr=train_params.pop_float('lr', 1e-3))
    summary_writer = SummaryWriter(log_dir=save_dir / 'log')

    kl_anneal_rate = train_params.pop_float('kl_anneal_rate', None)
    if kl_anneal_rate is None:
        kl_weight_scheduler = None
    else:
        # Linear KL annealing from 0 up to 1.
        kl_weight_scheduler = (lambda step: min(1.0, kl_anneal_rate * step))
        # NOTE(review): indentation reconstructed from a collapsed source —
        # assumed kl_weight starts at 0 only when annealing is enabled;
        # confirm against version history.
        model.kl_weight = 0.0

    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      labeled_iterator=labeled_iterator,
                      unlabeled_iterator=unlabeled_iterator,
                      train_labeled_dataset=train_labeled_dataset,
                      train_unlabeled_dataset=train_unlabeled_dataset,
                      validation_dataset=valid_dataset,
                      summary_writer=summary_writer,
                      serialization_dir=save_dir,
                      num_epochs=train_params.pop('num_epochs', 50),
                      iters_per_epoch=len(train_labeled_dataset) // labeled_batch_size,
                      write_summary_every=100,
                      validate_every=2000,
                      patience=2,
                      clip_grad_max_norm=5,
                      kl_weight_scheduler=kl_weight_scheduler,
                      cuda_device=train_params.pop_int('cuda_device', 0),
                      early_stop=train_params.pop_bool('early_stop', True))
    trainer.train()
def main():
    """Train the separated Quora paraphrase LVM described by a params file.

    Reads labeled (and optionally unlabeled) Quora data, builds and saves
    the vocabulary, then runs SeparatedLVMTrainer under ``args.save``.
    """
    args = parse_args()
    params = Params.from_file(args.params)
    save_dir = Path(args.save)
    save_dir.mkdir(parents=True)
    # Keep a copy of the configuration next to the checkpoints.
    params.to_file(save_dir / 'params.json')
    train_params, model_params = params.pop('train'), params.pop('model')

    # Seed both torch and random for reproducibility.
    random_seed = train_params.pop_int('random_seed', 2019)
    torch.manual_seed(random_seed)
    random.seed(random_seed)

    # Mirror stdout/stderr into a log file inside the save directory.
    log_filename = save_dir / 'stdout.log'
    sys.stdout = TeeLogger(filename=log_filename,
                           terminal=sys.stdout,
                           file_friendly_terminal_output=False)
    sys.stderr = TeeLogger(filename=log_filename,
                           terminal=sys.stderr,
                           file_friendly_terminal_output=False)

    # Data is pre-tokenized, so split on whitespace only.
    tokenizer = WordTokenizer(word_splitter=JustSpacesWordSplitter(),
                              start_tokens=['<s>'],
                              end_tokens=['</s>'])
    token_indexer = SingleIdTokenIndexer(lowercase_tokens=True)
    dataset_reader = QuoraParaphraseDatasetReader(
        tokenizer=tokenizer, token_indexers={'tokens': token_indexer})
    train_labeled_dataset_path = train_params.pop('train_labeled_dataset_path')
    train_unlabeled_dataset_path = train_params.pop(
        'train_unlabeled_dataset_path', None)
    train_labeled_dataset = dataset_reader.read(train_labeled_dataset_path)
    # Drop examples longer than 35 tokens.
    train_labeled_dataset = filter_dataset_by_length(
        dataset=train_labeled_dataset, max_length=35)
    if train_unlabeled_dataset_path is not None:
        train_unlabeled_dataset = dataset_reader.read(
            train_unlabeled_dataset_path)
        train_unlabeled_dataset = filter_dataset_by_length(
            dataset=train_unlabeled_dataset, max_length=35)
    else:
        train_unlabeled_dataset = []
    valid_dataset = dataset_reader.read(train_params.pop('valid_dataset_path'))

    # Vocabulary is built from both labeled and unlabeled training data.
    vocab = Vocabulary.from_instances(
        instances=train_labeled_dataset + train_unlabeled_dataset,
        max_vocab_size=train_params.pop_int('max_vocab_size', None))
    vocab.save_to_files(save_dir / 'vocab')

    labeled_batch_size = train_params.pop_int('labeled_batch_size')
    unlabeled_batch_size = train_params.pop_int('unlabeled_batch_size')
    labeled_iterator = BasicIterator(batch_size=labeled_batch_size)
    unlabeled_iterator = BasicIterator(batch_size=unlabeled_batch_size)
    labeled_iterator.index_with(vocab)
    unlabeled_iterator.index_with(vocab)
    # With no unlabeled data, a None iterator signals the trainer to skip
    # the unsupervised pass.
    if not train_unlabeled_dataset:
        unlabeled_iterator = None

    model = SeparatedQuoraModel(params=model_params, vocab=vocab)
    optimizer = optim.Adam(params=model.parameters())
    summary_writer = SummaryWriter(log_dir=save_dir / 'log')

    trainer = SeparatedLVMTrainer(
        model=model,
        optimizer=optimizer,
        labeled_iterator=labeled_iterator,
        unlabeled_iterator=unlabeled_iterator,
        train_labeled_dataset=train_labeled_dataset,
        train_unlabeled_dataset=train_unlabeled_dataset,
        validation_dataset=valid_dataset,
        summary_writer=summary_writer,
        serialization_dir=save_dir,
        num_epochs=train_params.pop('num_epochs', 50),
        iters_per_epoch=len(train_labeled_dataset) // labeled_batch_size,
        write_summary_every=100,
        validate_every=2000,
        patience=train_params.pop('patience', 2),
        clip_grad_max_norm=5,
        cuda_device=train_params.pop_int('cuda_device', 0))
    trainer.train()
# Evaluate a saved attention model on the train/val/test splits.
vocab = Vocabulary.from_files('./.vocab/snli_vocab')
glove = Embedding(vocab.get_vocab_size(), 300)
### Choose and load model here
model = RocktaschelEtAlAttention(vocab, glove, word_by_word=False).to("cuda")
with open(
        './.serialization_data/C.E. Attention_Adam_32_0.1_0.0003_5e-05_True/best.th',
        'rb') as f:
    model.load_state_dict(torch.load(f))
model.to('cuda')
# NOTE(review): `t` is presumably the dataset reader defined earlier in the
# file — not visible in this chunk; confirm. The predictor itself is unused
# below.
predictor = NLIPredictor(model=model, dataset_reader=t)
iterator = BasicIterator(batch_size=32)
iterator.index_with(vocab)
final = evaluate(model, train_dataset, iterator, cuda_device=0,
                 batch_weight_key=None)
print(final)
final = evaluate(model, val_dataset, iterator, cuda_device=0,
                 batch_weight_key=None)
print(final)
# NOTE(review): this chunk is truncated — the call below is incomplete here
# and presumably continues (cuda_device=0, batch_weight_key=None) in the
# original file.
final = evaluate(model, test_dataset, iterator,
def main():
    """Fine-tune a pretrained Quora paraphrase model.

    Loads an optional pretrained checkpoint, adds fine-tuning parameters,
    and runs FineTuningTrainer with separate main/aux optimizers and
    optional KL-weight and Gumbel-temperature annealing schedules.
    """
    args = parse_args()
    params = Params.from_file(args.params)
    save_dir = Path(args.save)
    save_dir.mkdir(parents=True)
    # Keep a copy of the configuration next to the checkpoints.
    params.to_file(save_dir / 'params.json')
    train_params, model_params = params.pop('train'), params.pop('model')

    # Seed both torch and random for reproducibility.
    random_seed = train_params.pop_int('random_seed', 2019)
    torch.manual_seed(random_seed)
    random.seed(random_seed)

    # Mirror stdout/stderr into a log file inside the save directory.
    log_filename = save_dir / 'stdout.log'
    sys.stdout = TeeLogger(filename=log_filename,
                           terminal=sys.stdout,
                           file_friendly_terminal_output=False)
    sys.stderr = TeeLogger(filename=log_filename,
                           terminal=sys.stderr,
                           file_friendly_terminal_output=False)

    # Data is pre-tokenized, so split on whitespace only.
    tokenizer = WordTokenizer(word_splitter=JustSpacesWordSplitter(),
                              start_tokens=['<s>'],
                              end_tokens=['</s>'])
    token_indexer = SingleIdTokenIndexer(lowercase_tokens=True)
    dataset_reader = QuoraParaphraseDatasetReader(
        tokenizer=tokenizer, token_indexers={'tokens': token_indexer})
    train_labeled_dataset = dataset_reader.read(
        train_params.pop('train_labeled_dataset_path'))
    train_unlabeled_dataset = dataset_reader.read(
        train_params.pop('train_unlabeled_dataset_path'))
    valid_dataset = dataset_reader.read(train_params.pop('valid_dataset_path'))
    # Drop examples longer than 35 tokens.
    train_labeled_dataset = filter_dataset_by_length(
        dataset=train_labeled_dataset, max_length=35)
    train_unlabeled_dataset = filter_dataset_by_length(
        dataset=train_unlabeled_dataset, max_length=35)

    # Vocabulary is built from both labeled and unlabeled training data.
    vocab = Vocabulary.from_instances(
        instances=train_labeled_dataset + train_unlabeled_dataset,
        max_vocab_size=train_params.pop_int('max_vocab_size', None))
    vocab.save_to_files(save_dir / 'vocab')

    labeled_batch_size = train_params.pop_int('labeled_batch_size')
    unlabeled_batch_size = train_params.pop_int('unlabeled_batch_size')
    labeled_iterator = BasicIterator(batch_size=labeled_batch_size)
    unlabeled_iterator = BasicIterator(batch_size=unlabeled_batch_size)
    labeled_iterator.index_with(vocab)
    unlabeled_iterator.index_with(vocab)

    pretrained_checkpoint_path = train_params.pop('pretrained_checkpoint_path',
                                                  None)
    model = QuoraModel(params=model_params,
                       vocab=vocab)
    if pretrained_checkpoint_path:
        # Warm-start fine-tuning from the pretrained weights.
        model.load_state_dict(
            torch.load(pretrained_checkpoint_path, map_location='cpu'))
    model.add_finetune_parameters(
        con_autoweight=train_params.pop_bool('con_autoweight', False),
        con_y_weight=train_params.pop_float('con_y_weight'),
        con_z_weight=train_params.pop_float('con_z_weight'),
        con_z2_weight=train_params.pop_float('con_z2_weight'))

    # Main and auxiliary parameter groups get separate Adam optimizers
    # with independent learning rates.
    main_optimizer = optim.Adam(params=model.finetune_main_parameters(
        exclude_generator=train_params.pop_bool('exclude_generator')),
        lr=train_params.pop_float('lr', 1e-3))
    aux_optimizer = optim.Adam(params=model.finetune_aux_parameters(),
                               lr=train_params.pop_float('aux_lr', 1e-4))
    summary_writer = SummaryWriter(log_dir=save_dir / 'log')

    kl_anneal_rate = train_params.pop_float('kl_anneal_rate', None)
    if kl_anneal_rate is None:
        kl_weight_scheduler = None
    else:
        # Linear KL annealing from 0 up to 1.
        kl_weight_scheduler = (lambda step: min(1.0, kl_anneal_rate * step))
        # NOTE(review): indentation reconstructed from a collapsed source —
        # assumed kl_weight is zeroed only when annealing is enabled.
        model.kl_weight = 0.0
    gumbel_anneal_rate = train_params.pop_float('gumbel_anneal_rate', None)
    if gumbel_anneal_rate is None:
        gumbel_temperature_scheduler = None
    else:
        # Linear temperature decay from 1.0, floored at 0.1.
        gumbel_temperature_scheduler = (
            lambda step: max(0.1, 1.0 - gumbel_anneal_rate * step))
        # NOTE(review): same indentation caveat as model.kl_weight above.
        model.gumbel_temperature = 1.0

    trainer = FineTuningTrainer(
        model=model,
        main_optimizer=main_optimizer,
        aux_optimizer=aux_optimizer,
        labeled_iterator=labeled_iterator,
        unlabeled_iterator=unlabeled_iterator,
        train_labeled_dataset=train_labeled_dataset,
        train_unlabeled_dataset=train_unlabeled_dataset,
        validation_dataset=valid_dataset,
        summary_writer=summary_writer,
        serialization_dir=save_dir,
        num_epochs=train_params.pop_int('num_epochs', 50),
        iters_per_epoch=len(train_labeled_dataset) // labeled_batch_size,
        write_summary_every=100,
        validate_every=1000,
        patience=train_params.pop_int('patience', 5),
        clip_grad_max_norm=train_params.pop_float('grad_max_norm', 5.0),
        kl_weight_scheduler=kl_weight_scheduler,
        gumbel_temperature_scheduler=gumbel_temperature_scheduler,
        cuda_device=train_params.pop_int('cuda_device', 0))
    trainer.train()