def main(params):
    logging.info("Loading the datasets...")
    train_iter, dev_iter, test_iterator, DE, EN = load_dataset(
        params.data_path, params.train_batch_size, params.dev_batch_size)
    de_size, en_size = len(DE.vocab), len(EN.vocab)
    logging.info("[DE Vocab Size]: {}, [EN Vocab Size]: {}".format(
        de_size, en_size))
    logging.info("- done.")

    params.src_vocab_size = de_size
    params.tgt_vocab_size = en_size
    params.sos_index = EN.vocab.stoi["<s>"]
    params.pad_token = EN.vocab.stoi["<pad>"]
    params.eos_index = EN.vocab.stoi["</s>"]
    params.itos = EN.vocab.itos
    params.SRC = DE
    params.TRG = EN

    # make the Seq2Seq model
    model = make_seq2seq_model(params)

    # default optimizer
    optimizer = optim.Adam(model.parameters(), lr=params.lr)

    if params.model_type == "Transformer":
        # the Transformer uses label smoothing and a warmup LR schedule
        criterion = LabelSmoothingLoss(params.label_smoothing,
                                       params.tgt_vocab_size,
                                       params.pad_token).to(params.device)
        optimizer = ScheduledOptimizer(optimizer=optimizer,
                                       d_model=params.hidden_size,
                                       factor=2,
                                       n_warmup_steps=params.n_warmup_steps)
        scheduler = None
    else:
        criterion = nn.NLLLoss(reduction="sum", ignore_index=params.pad_token)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=params.patience, factor=.1, verbose=True)

    # initialize the Trainer
    trainer = Trainer(model, optimizer, scheduler, criterion, train_iter,
                      dev_iter, params)

    # optionally restore model/optimizer state from a saved checkpoint
    if params.restore_file:
        restore_path = os.path.join(params.model_dir, "checkpoints",
                                    params.restore_file)
        logging.info("Restoring parameters from {}".format(restore_path))
        Trainer.load_checkpoint(model, restore_path, optimizer)

    # train the model
    trainer.train()
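# --- Illustrative sketch (not part of the original script) ---
# `ScheduledOptimizer` above is constructed with d_model, factor, and
# n_warmup_steps, which matches the inverse-square-root ("Noam") schedule
# from "Attention Is All You Need". Below is a minimal sketch of such a
# wrapper, assuming that schedule; the project's actual implementation
# may differ in details.
class NoamOptimizerSketch:
    """lr = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)."""

    def __init__(self, optimizer, d_model, factor, n_warmup_steps):
        self.optimizer = optimizer
        self.d_model = d_model
        self.factor = factor
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0

    def step(self):
        # bump the step count, recompute the LR, then take an optimizer step
        self.n_steps += 1
        lr = (self.factor * self.d_model ** -0.5
              * min(self.n_steps ** -0.5,
                    self.n_steps * self.n_warmup_steps ** -1.5))
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()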
def main():
    args = parse_train_arg()
    task = task_dict[args.task]
    init_distributed_mode(args)
    logger = init_logger(args)

    if hasattr(args, 'base_model_name'):
        logger.warning('Argument base_model_name is deprecated! '
                       'Use `--table-bert-extra-config` instead!')

    init_signal_handler()

    train_data_dir = args.data_dir / 'train'
    dev_data_dir = args.data_dir / 'dev'

    table_bert_config = task['config'].from_file(
        args.data_dir / 'config.json', **args.table_bert_extra_config)

    if args.is_master:
        args.output_dir.mkdir(exist_ok=True, parents=True)
        with (args.output_dir / 'train_config.json').open('w') as f:
            json.dump(vars(args), f, indent=2, sort_keys=True, default=str)

        logger.info(f'Table Bert Config: {table_bert_config.to_log_string()}')

        # save the table BERT config to the working directory
        table_bert_config.save(args.output_dir / 'tb_config.json')

    assert args.data_dir.is_dir(), \
        "--data_dir should point to the folder of files made by pregenerate_training_data.py!"

    if args.cpu:
        device = torch.device('cpu')
    else:
        device = torch.device(f'cuda:{torch.cuda.current_device()}')

    logger.info("device: {} gpu_id: {}, distributed training: {}, 16-bits training: {}".format(
        device, args.local_rank, bool(args.multi_gpu), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    real_batch_size = args.train_batch_size  # // args.gradient_accumulation_steps

    # seed all RNGs for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if not args.cpu:
        torch.cuda.manual_seed_all(args.seed)

    if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
        logger.warning(f"Output directory ({args.output_dir}) already exists and is not empty!")

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Prepare model: non-master ranks wait at the barrier so the master
    # process can initialize the weights first
    if args.multi_gpu and args.global_rank != 0:
        torch.distributed.barrier()

    if args.no_init:
        raise NotImplementedError
    else:
        model = task['model'](table_bert_config)

    if args.multi_gpu and args.global_rank == 0:
        torch.distributed.barrier()

    if args.fp16:
        model = model.half()

    model = model.to(device)
    if args.multi_gpu:
        if args.ddp_backend == 'pytorch':
            model = nn.parallel.DistributedDataParallel(
                model,
                find_unused_parameters=True,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False
            )
        else:
            import apex
            model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True)

        model_ptr = model.module
    else:
        model_ptr = model

    # set up update counts for the LR scheduler
    dataset_cls = task['dataset']
    train_set_info = dataset_cls.get_dataset_info(train_data_dir, args.max_epoch)
    total_num_updates = (train_set_info['total_size'] // args.train_batch_size
                         // args.world_size // args.gradient_accumulation_steps)
    args.max_epoch = train_set_info['max_epoch']
    logger.info(f'Train data size: {train_set_info["total_size"]} for {args.max_epoch} epochs, '
                f'total num. updates: {total_num_updates}')

    args.total_num_update = total_num_updates
    args.warmup_updates = int(total_num_updates * 0.1)

    trainer = Trainer(model, args)

    checkpoint_file = args.output_dir / 'model.ckpt.bin'
    is_resumed = False
    if checkpoint_file.exists():
        logger.info(f'Loading checkpoint file {checkpoint_file}')
        is_resumed = True
        trainer.load_checkpoint(checkpoint_file)

    model.train()

    # we also partition the dev set for every local process
    logger.info('Loading dev set...')
    sys.stdout.flush()
    dev_set = dataset_cls(epoch=0,
                          training_path=dev_data_dir,
                          tokenizer=model_ptr.tokenizer,
                          config=table_bert_config,
                          multi_gpu=args.multi_gpu,
                          debug=args.debug_dataset)

    logger.info("***** Running training *****")
    logger.info(f"  Current config: {args}")

    if trainer.num_updates > 0:
        logger.info(f'Resume training at epoch {trainer.epoch}, '
                    f'epoch step {trainer.in_epoch_step}, '
                    f'global step {trainer.num_updates}')

    start_epoch = trainer.epoch
    for epoch in range(start_epoch, args.max_epoch):  # inclusive
        model.train()

        # use an epoch-dependent seed so data order differs across epochs
        # but stays reproducible when a run is resumed
        with torch.random.fork_rng(devices=None if args.cpu else [device.index]):
            torch.random.manual_seed(131 + epoch)

            epoch_dataset = dataset_cls(epoch=trainer.epoch,
                                        training_path=train_data_dir,
                                        config=table_bert_config,
                                        tokenizer=model_ptr.tokenizer,
                                        multi_gpu=args.multi_gpu,
                                        debug=args.debug_dataset)
            train_sampler = RandomSampler(epoch_dataset)
            train_dataloader = DataLoader(epoch_dataset,
                                          sampler=train_sampler,
                                          batch_size=real_batch_size,
                                          num_workers=0,
                                          collate_fn=epoch_dataset.collate)

        samples_iter = GroupedIterator(iter(train_dataloader), args.gradient_accumulation_steps)
        trainer.resume_batch_loader(samples_iter)

        with tqdm(total=len(samples_iter), initial=trainer.in_epoch_step,
                  desc=f"Epoch {epoch}", file=sys.stdout,
                  disable=not args.is_master, miniters=100) as pbar:
            for samples in samples_iter:
                logging_output = trainer.train_step(samples)

                pbar.update(1)
                pbar.set_postfix_str(', '.join(f"{k}: {v:.4f}" for k, v in logging_output.items()))

                if (
                    0 < trainer.num_updates and
                    trainer.num_updates % args.save_checkpoint_every_niter == 0 and
                    args.is_master
                ):
                    # Save model checkpoint
                    logger.info("** ** * Saving checkpoint file ** ** * ")
                    trainer.save_checkpoint(checkpoint_file)

        logger.info(f'Epoch {epoch} finished.')

        if args.is_master:
            # Save a trained table_bert
            logger.info("** ** * Saving fine-tuned table_bert ** ** * ")
            model_to_save = model_ptr  # Only save the table_bert itself
            output_model_file = args.output_dir / f"pytorch_model_epoch{epoch:02d}.bin"
            torch.save(model_to_save.state_dict(), str(output_model_file))

        # perform validation
        logger.info("** ** * Perform validation ** ** * ")
        dev_results = trainer.validate(dev_set)

        if args.is_master:
            logger.info('** ** * Validation Results ** ** * ')
            logger.info(f'Epoch {epoch} Validation Results: {dev_results}')

        # flush logging information to disk
        sys.stderr.flush()

        trainer.next_epoch()
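# --- Illustrative sketch (not part of the original script) ---
# `GroupedIterator` above chunks the batch stream into groups of
# `gradient_accumulation_steps` batches so that each `trainer.train_step`
# call can accumulate gradients over one group before an optimizer update.
# A minimal sketch of such an iterator (an assumption about its behavior;
# the version used here evidently also supports len() for the tqdm bar):
import itertools

class GroupedIteratorSketch:
    """Yield lists of up to `chunk_size` consecutive items from `iterable`."""

    def __init__(self, iterable, chunk_size):
        self.itr = iter(iterable)
        self.chunk_size = chunk_size

    def __iter__(self):
        return self

    def __next__(self):
        chunk = list(itertools.islice(self.itr, self.chunk_size))
        if not chunk:
            raise StopIteration
        return chunk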
def main(params, greedy, beam_size, test):
    """
    The main function for decoding a trained MT model

    Arguments:
        params: parameters related to the `model` that is being decoded
        greedy: whether or not to do greedy decoding
        beam_size: size of beam if doing beam search
        test: whether to decode the test set instead of the dev set
    """
    print("Loading dataset...")
    _, dev_iter, test_iterator, DE, EN = load_dataset(
        params.data_path, params.train_batch_size, params.dev_batch_size)
    de_size, en_size = len(DE.vocab), len(EN.vocab)
    print("[DE Vocab Size]: {}, [EN Vocab Size]: {}".format(de_size, en_size))

    params.src_vocab_size = de_size
    params.tgt_vocab_size = en_size
    params.sos_index = EN.vocab.stoi["<s>"]
    params.pad_token = EN.vocab.stoi["<pad>"]
    params.eos_index = EN.vocab.stoi["</s>"]
    params.itos = EN.vocab.itos

    device = torch.device('cuda' if params.cuda else 'cpu')
    params.device = device

    # make the Seq2Seq model
    model = make_seq2seq_model(params)

    # load the saved model for evaluation
    if params.average > 1:
        print("Averaging the last {} checkpoints".format(params.average))
        checkpoint = {}
        checkpoint["state_dict"] = average_checkpoints(params.model_dir, params.average)
        model = Trainer.load_checkpoint(model, checkpoint)
    else:
        model_path = os.path.join(params.model_dir, "checkpoints", params.model_file)
        print("Restoring parameters from {}".format(model_path))
        model = Trainer.load_checkpoint(model, model_path)

    # evaluate on the test set
    if test:
        print("Doing Beam Search on the Test Set")
        test_decoder = Translator(model, test_iterator, params, device)
        test_beam_search_outputs = test_decoder.beam_decode(beam_width=beam_size)
        test_decoder.output_decoded_translations(
            test_beam_search_outputs,
            "beam_search_outputs_size_test={}.en".format(beam_size))
        return

    # instantiate a Translator object to translate the SRC language to the
    # TRG language using greedy/beam decoding
    decoder = Translator(model, dev_iter, params, device)

    if greedy:
        print("Doing Greedy Decoding...")
        greedy_outputs = decoder.greedy_decode(max_len=100)
        decoder.output_decoded_translations(greedy_outputs, "greedy_outputs.en")

        print("Evaluating BLEU Score on Greedy Translation...")
        subprocess.call([
            './utils/eval.sh',
            os.path.join(params.model_dir, "outputs", "greedy_outputs.en")
        ])

    if beam_size:
        print("Doing Beam Search...")
        beam_search_outputs = decoder.beam_decode(beam_width=beam_size)
        decoder.output_decoded_translations(
            beam_search_outputs,
            "beam_search_outputs_size={}.en".format(beam_size))

        print("Evaluating BLEU Score on Beam Search Translation")
        subprocess.call([
            './utils/eval.sh',
            os.path.join(params.model_dir, "outputs",
                         "beam_search_outputs_size={}.en".format(beam_size))
        ])
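# --- Illustrative sketch (not part of the original script) ---
# `average_checkpoints` above averages the weights of the last N saved
# checkpoints, a common way to obtain a slightly stronger final model than
# any single checkpoint. A minimal sketch, assuming checkpoints are dicts
# with a 'state_dict' entry stored as *.pt files under
# <model_dir>/checkpoints/; the file layout and name here are hypothetical.
import glob
import os
import torch

def average_last_checkpoints_sketch(model_dir, n):
    paths = sorted(glob.glob(os.path.join(model_dir, "checkpoints", "*.pt")))[-n:]
    averaged = None
    for path in paths:
        state = torch.load(path, map_location="cpu")["state_dict"]
        if averaged is None:
            # clone so we do not modify the loaded tensors in place
            averaged = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                averaged[k] += v.float()
    return {k: v / len(paths) for k, v in averaged.items()}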