def prepare(batch):
    NUM_SENTS = 10
    RETRIEVER = data.DocRetriever(WIKI_DOCS_PATH, tfidf_get_docs)
    # Alternative: retrieve gold documents instead of TF-IDF hits.
    # RETRIEVER = data.OracleDocRetriever("data/wiki.db")
    SELECTOR = Selector(Embedder())
    return data.collate(
        batch,
        NUM_SENTS,
        RETRIEVER,
        SELECTOR,
        oracle_doc_ret=isinstance(RETRIEVER, data.OracleDocRetriever),
    )
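# Usage sketch (illustrative, not from this repo): `prepare` has the shape of
# a torch DataLoader collate_fn, so it would typically be wired up as below.
# `claims_dataset` and the batch size are hypothetical placeholders.
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(claims_dataset, batch_size=16, collate_fn=prepare)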
def sample_qxy_debug_data(self, data):
    xs, ys = collate(data)
    device = self.px.embed.weight.device
    xs, ys = xs.to(device), ys.to(device)
    with torch.no_grad():
        logprob_qxy = self.qxy.logprob(ys, xs).cpu().numpy().tolist()
        logprob_px = self.px.logprob(xs).cpu().numpy().tolist()
        # pyx returns a per-instance NLL; negate to get log-probabilities,
        # and convert to a list like the other scores.
        logprob_pyx = (-self.pyx(xs, ys, per_instance=True)).cpu().numpy().tolist()
    # Tensors are time-major; transpose to batch-major before decoding tokens.
    xs = xs.cpu().numpy().transpose().tolist()
    ys = ys.cpu().numpy().transpose().tolist()
    sxs = self.print_tokens(self.vocab_x, xs)
    sys = self.print_tokens(self.vocab_y, ys)
    plist = []
    for x, y, lqxy, lpx, lpyx in zip(sxs, sys, logprob_qxy, logprob_px,
                                     logprob_pyx):
        plist.append(
            f"x: {x} \t y: {y} \t logpqxy: {lqxy} \t logpx: {lpx} \t logpyx: {lpyx}"
        )
    return plist
def main():
    args = parser.parse_args()
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    dictionary = Dictionary.load(args.vocab_path)
    dictionary.truncate(args.max_vocab_size)
    test_dataset = SummaryDataset(os.path.join(args.data_path, 'test'),
                                  dictionary=dictionary,
                                  max_article_size=args.max_source_positions,
                                  max_summary_size=args.max_target_positions,
                                  max_elements=10 if args.debug else None)
    test_sampler = SequentialSampler(test_dataset)
    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.num_workers,
        collate_fn=lambda samples: collate(samples, dictionary.pad_index,
                                           dictionary.eos_index))

    summarization_task = SummarizationTask(args, dictionary)
    if args.model == 'transformer':
        args.local_transformer = False
        transformer.transformer_small(args)
        model = transformer.TransformerModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'lstm':
        lstm.base_architecture(args)
        args.criterion = None
        model = lstm.LSTMModel.build_model(args,
                                           summarization_task).to(args.device)
    elif args.model == 'lightconv':
        args.encoder_conv_type = 'lightweight'
        args.decoder_conv_type = 'lightweight'
        args.weight_softmax = True
        lightconv.lightconv_small(args)
        model = lightconv.LightConvModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'localtransformer':
        args.local_transformer = True
        transformer.transformer_small(args)
        model = transformer.TransformerModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'transformer_conv':
        transformer_conv.transformer_conv_small(args)
        model = transformer_conv.TransformerConvModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'transformer_mc':
        transformer_mc.transformer_mc_small(args)
        model = transformer_mc.TransformerMCModel.build_model(
            args, summarization_task).to(args.device)

    if args.model_path:
        model.load_state_dict(torch.load(args.model_path))

    generator = SequenceGenerator(dictionary,
                                  beam_size=args.beam_size,
                                  max_len_b=args.max_target_positions)

    avg_rouge_score = defaultdict(float)
    for batch_idx, batch in enumerate(test_dataloader):
        src_tokens = batch['net_input']['src_tokens'].to(args.device)
        src_lengths = batch['net_input']['src_lengths'].to(args.device)
        references = batch['target']
        references = [
            remove_special_tokens(ref, dictionary) for ref in references
        ]
        references = [dictionary.string(ref) for ref in references]

        hypos = generator.generate([model], {
            'net_input': {
                'src_tokens': src_tokens,
                'src_lengths': src_lengths
            }
        })
        hypotheses = [hypo[0]['tokens'] for hypo in hypos]
        assert len(hypotheses) == src_tokens.size(0)  # one hypothesis per batch element
        hypotheses = [
            remove_special_tokens(hypo, dictionary) for hypo in hypotheses
        ]
        hypotheses = [dictionary.string(hyp) for hyp in hypotheses]

        if args.verbose:
            print("\nComparison references/hypotheses:")
            for ref, hypo in zip(references, hypotheses):
                print(ref)
                print(hypo)
                print()

        avg_rouge_score_batch = compute_rouge.compute_score(
            references, hypotheses)
        print("rouge for this batch:", avg_rouge_score_batch)
        compute_rouge.update(avg_rouge_score, batch_idx * args.batch_size,
                             avg_rouge_score_batch, len(hypotheses))
    return avg_rouge_score
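# A minimal sketch of the running-mean update that `compute_rouge.update` is
# assumed to perform above, given that it receives the count of examples seen
# so far (`batch_idx * args.batch_size`) and the size of the new batch. The
# helper below is illustrative only, not the repo's actual API.
def _update_running_mean(avg_scores, n_seen, batch_scores, batch_size):
    # Per metric: new_mean = (old_mean * n + batch_mean * m) / (n + m)
    for metric, batch_mean in batch_scores.items():
        prev = avg_scores.get(metric, 0.0)
        avg_scores[metric] = ((prev * n_seen + batch_mean * batch_size) /
                              (n_seen + batch_size))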
def main():
    args = parser.parse_args()
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    dictionary = Dictionary.load(args.vocab_path)
    summarization_task = SummarizationTask(args, dictionary)
    if args.model == 'transformer':
        args.local_transformer = False
        transformer.transformer_small(args)
        model = transformer.TransformerModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'lstm':
        lstm.base_architecture(args)
        args.criterion = None
        model = lstm.LSTMModel.build_model(args,
                                           summarization_task).to(args.device)
    elif args.model == 'lightconv':
        args.encoder_conv_type = 'lightweight'
        args.decoder_conv_type = 'lightweight'
        args.weight_softmax = True
        lightconv.lightconv_small(args)
        model = lightconv.LightConvModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'localtransformer':
        args.local_transformer = True
        transformer.transformer_small(args)
        model = transformer.TransformerModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'transformer_conv':
        transformer_conv.transformer_conv_small(args)
        model = transformer_conv.TransformerConvModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'transformer_mc':
        transformer_mc.transformer_mc_small(args)
        model = transformer_mc.TransformerMCModel.build_model(
            args, summarization_task).to(args.device)

    # Benchmark throughput on dummy data over a range of article lengths.
    total_speeds = []
    len_articles = list(
        range(args.min_len_article, args.max_len_article + args.len_step,
              args.len_step))
    for len_article in len_articles:
        dataset = DummySummaryDataset(args.total_sents, len_article,
                                      args.len_summaries, dictionary)
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            collate_fn=lambda samples: collate(samples, dictionary.pad_index,
                                               dictionary.eos_index))
        print("MODEL {} ARTICLE LEN {} SUMMARY LEN {}".format(
            args.model, len_article, args.len_summaries))
        total_sents, total_time = run_pass(model, dataloader,
                                           dictionary.pad_index, args.device,
                                           args.encoder_only)
        print("MODEL {} SPEED {} T/SENT {}".format(
            args.model, total_sents / total_time, total_time / total_sents))
        print()
        total_speeds.append(total_sents / total_time)

    if not os.path.isdir(os.path.join(args.save_dir, args.model)):
        os.makedirs(os.path.join(args.save_dir, args.model))
    filename = "full_model_" if not args.encoder_only else "encoder_only_"
    filename += str(args.min_len_article) + "_" + str(args.max_len_article) + "_"
    filename += 'total_speeds.npy'
    np.save(os.path.join(args.save_dir, args.model, filename),
            np.array(list(zip(len_articles, total_speeds))))
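# Illustrative follow-up (not repo code): the saved .npy holds
# (article_length, sentences_per_second) pairs, so the speed curve can be
# inspected directly after a run. The path below is a made-up example.
#
#     speeds = np.load("results/transformer/full_model_100_1000_total_speeds.npy")
#     for length, sents_per_sec in speeds:
#         print("len={:5d}  sents/sec={:.1f}".format(int(length), sents_per_sec))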
def main():
    args = parser.parse_args()
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    dictionary = Dictionary.load(args.vocab_path)
    dictionary.truncate(args.max_vocab_size)
    # The -2 leaves room for the special tokens appended during collation.
    train_dataset = SummaryDataset(
        os.path.join(args.data_path, 'train'),
        dictionary=dictionary,
        max_article_size=args.max_source_positions - 2,
        max_summary_size=args.max_target_positions - 2,
        max_elements=20 if args.debug else None)
    val_dataset = SummaryDataset(
        os.path.join(args.data_path, 'val'),
        dictionary=dictionary,
        max_article_size=args.max_source_positions - 2,
        max_summary_size=args.max_target_positions - 2,
        max_elements=20 if args.debug else None)

    # TODO maybe change the sampler to group texts of similar lengths
    train_sampler = RandomSampler(train_dataset, replacement=False)
    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.num_workers,
        collate_fn=lambda samples: collate(samples, dictionary.pad_index,
                                           dictionary.eos_index))
    val_dataloader = DataLoader(
        dataset=val_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=lambda samples: collate(samples, dictionary.pad_index,
                                           dictionary.eos_index))
    dataloaders = {'train': train_dataloader, 'val': val_dataloader}

    # Model selection (this chain is duplicated in the evaluation and
    # benchmarking entry points; a factored-out sketch follows this function).
    summarization_task = SummarizationTask(args, dictionary)
    if args.model == 'transformer':
        args.local_transformer = False
        transformer.transformer_small(args)
        model = transformer.TransformerModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'lstm':
        lstm.base_architecture(args)
        args.criterion = None
        model = lstm.LSTMModel.build_model(args,
                                           summarization_task).to(args.device)
    elif args.model == 'lightconv':
        args.encoder_conv_type = 'lightweight'
        args.decoder_conv_type = 'lightweight'
        args.weight_softmax = True
        lightconv.lightconv_small(args)
        model = lightconv.LightConvModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'localtransformer':
        args.local_transformer = True
        transformer.transformer_small(args)
        model = transformer.TransformerModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'transformer_conv':
        transformer_conv.transformer_conv_small(args)
        model = transformer_conv.TransformerConvModel.build_model(
            args, summarization_task).to(args.device)
    elif args.model == 'transformer_mc':
        transformer_mc.transformer_mc_small(args)
        model = transformer_mc.TransformerMCModel.build_model(
            args, summarization_task).to(args.device)

    criterion = nn.CrossEntropyLoss(reduction='mean')
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params=model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(params=model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=args.exponential_decay)

    if args.flag == "":
        args.flag = 'train_transformer_{date:%Y-%m-%d_%H:%M:%S}'.format(
            date=datetime.datetime.now())
    if not os.path.isdir(os.path.join(args.save_dir, args.flag)):
        os.makedirs(os.path.join(args.save_dir, args.flag))

    print("Launching training with:\n"
          "optimizer: {}\n lr: {}\n exponential_decay: {}\n"
          " momentum: {}\n weight_decay: {}\n batch_size: {}\n".format(
              args.optimizer, args.lr, args.exponential_decay, args.momentum,
              args.weight_decay, args.batch_size))

    train(dataloaders,
          model,
          criterion,
          optimizer,
          lr_scheduler,
          args.device,
          dictionary.pad_index,
          save_dir=os.path.join(args.save_dir, args.flag),
          n_epochs=args.n_epochs,
          save=args.save,
          debug=args.debug,
          dictionary=dictionary)
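# The model-selection if/elif chain above appears verbatim in all three entry
# points. A sketch of one way to factor it out; the helper itself is a
# suggestion, not existing repo code (the module and class names mirror the
# ones used above).
def build_model(args, task):
    if args.model in ('transformer', 'localtransformer'):
        args.local_transformer = (args.model == 'localtransformer')
        transformer.transformer_small(args)
        cls = transformer.TransformerModel
    elif args.model == 'lstm':
        lstm.base_architecture(args)
        args.criterion = None
        cls = lstm.LSTMModel
    elif args.model == 'lightconv':
        args.encoder_conv_type = 'lightweight'
        args.decoder_conv_type = 'lightweight'
        args.weight_softmax = True
        lightconv.lightconv_small(args)
        cls = lightconv.LightConvModel
    elif args.model == 'transformer_conv':
        transformer_conv.transformer_conv_small(args)
        cls = transformer_conv.TransformerConvModel
    elif args.model == 'transformer_mc':
        transformer_mc.transformer_mc_small(args)
        cls = transformer_mc.TransformerMCModel
    else:
        raise ValueError("unknown model: {}".format(args.model))
    return cls.build_model(args, task).to(args.device)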
def main():
    np.random.seed(FLAGS.seed)
    torch.manual_seed(FLAGS.seed)
    dataset, val_dataset = data.get_dataset(ENV)
    model = Model(ENV, dataset)
    if FLAGS.gpu:
        model.cuda()
    trainer = Trainer(model)

    first_epoch = 0
    skip_flat = False
    if FLAGS.resume_epoch is not None:
        first_epoch, skip_flat = _restore(model, trainer)

    loader = torch_data.DataLoader(
        dataset,
        batch_size=FLAGS.n_batch_examples,
        shuffle=True,
        num_workers=2,
        collate_fn=lambda items: data.collate(items, dataset))
    vtrain_loader = torch_data.DataLoader(
        dataset,
        batch_size=FLAGS.n_batch_examples,
        shuffle=False,
        num_workers=1,
        sampler=list(range(FLAGS.n_val_examples)),
        collate_fn=lambda items: data.collate(items, dataset, val=True))
    val_loader = torch_data.DataLoader(
        val_dataset,
        batch_size=FLAGS.n_batch_examples,
        shuffle=False,
        num_workers=1,
        collate_fn=lambda items: data.collate(items, val_dataset, val=True))

    @hlog.fn('exec')
    def execute():
        for d, l, n, f in [
                (dataset, vtrain_loader, 'train', lambda m: m.act),
                (dataset, vtrain_loader, 'train_h', lambda m: m.act_hier),
                (val_dataset, val_loader, 'val', lambda m: m.act),
                (val_dataset, val_loader, 'val_h', lambda m: m.act_hier),
        ]:
            interact.execute(model, d, l, ENV, n, f, dump=True)

    @hlog.fn('train', timer=False)
    def train_step(batch):
        if FLAGS.train_flat_on_parse:
            batch_parses = [parses[task.task_id] for task in batch.tasks]
        else:
            batch_parses = None
        seq_batch = data.SeqBatch.of(batch, dataset, parses=batch_parses)
        step_batch = data.StepBatch.of(batch, dataset, parses=batch_parses)
        step_loss, step_stats = model.score_step(step_batch)
        seq_loss, seq_stats = model.score_seq(seq_batch)
        stats = Stats() + step_stats + seq_stats
        loss = step_loss + seq_loss
        trainer.step(train_loss=loss)
        return stats

    @hlog.fn('hier', timer=False)
    def hier_step(batch, parses):
        batch_parses = [parses[task.task_id] for task in batch.tasks]
        seq_batch = data.StepBatch.of(
            batch, dataset, hier=True, parses=batch_parses)
        hier_loss, hier_stats = model.score_hier(seq_batch)
        stats = Stats() + hier_stats
        trainer.step(hier_loss=hier_loss)
        return stats

    @hlog.fn('val', timer=False)
    def val_step():
        stats = Stats()
        loss = 0
        for i_batch, batch in enumerate(val_loader):
            seq_batch = data.SeqBatch.of(batch, dataset)
            step_batch = data.StepBatch.of(batch, dataset)
            step_loss, step_stats = model.score_step(step_batch)
            seq_loss, seq_stats = model.score_seq(seq_batch)
            stats += Stats() + step_stats + seq_stats
            loss += seq_loss + step_loss
        trainer.step(val_loss=loss)
        _log(stats)

    # execute()

    with hlog.task('learn'):
        parses = defaultdict(list)
        for i_epoch in hlog.loop(
                'epoch_%03d', counter=range(first_epoch, FLAGS.n_epochs)):
            # flat step
            n_flat_passes = 0 if skip_flat else FLAGS.n_flat_passes
            skip_flat = False
            for i_pass in hlog.loop('flat_%03d', range(n_flat_passes)):
                stats = Stats()
                for i_batch, batch in hlog.loop(
                        'batch_%05d', enumerate(loader), timer=False):
                    stats += train_step(batch)
                _log(stats)
                val_step()
                if (i_pass + 1) % FLAGS.n_exec == 0:
                    execute()
            _save(model, trainer, i_epoch, FLAT_TAG)

            # parse
            if (i_epoch + 1) % FLAGS.n_parse == 0:
                parses = {}
                with hlog.task('parse'):
                    for i_batch, batch in hlog.loop(
                            'batch_%05d', enumerate(loader), timer=False):
                        seq_batch = data.SeqBatch.of(batch, dataset)
                        batch_parses = model.parse(seq_batch)
                        assert not any(k in parses for k in batch_parses)
                        parses.update(batch_parses)

            # hier step
            for i_pass in hlog.loop('hier_%03d', range(FLAGS.n_hier_passes)):
                stats = Stats()
                for i_batch, batch in hlog.loop(
                        'batch_%05d', enumerate(loader), timer=False):
                    stats += hier_step(batch, parses)
                _log(stats)
                if (i_pass + 1) % FLAGS.n_exec == 0:
                    execute()
            _save(model, trainer, i_epoch, HIER_TAG)
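# `_save` and `_restore` are called above but not defined in this excerpt. A
# minimal sketch of one plausible shape using plain torch checkpoints; the
# path scheme, the (first_epoch, skip_flat) resume logic, and the decision to
# persist only model weights (omitting trainer/optimizer state) are all
# assumptions for illustration.
def _save_sketch(model, trainer, i_epoch, tag):
    torch.save({'epoch': i_epoch, 'tag': tag, 'model': model.state_dict()},
               'model_%03d_%s.chk' % (i_epoch, tag))

def _restore_sketch(model, path):
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['model'])
    if ckpt['tag'] == FLAT_TAG:
        # Flat passes of this epoch are done; resume at the parse/hier stage.
        return ckpt['epoch'], True
    # Hier stage finished; resume from the next epoch with flat passes.
    return ckpt['epoch'] + 1, False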