Example #1
def main(params):
    logging.info("Loading the datasets...")
    train_iter, dev_iter, test_iterator, DE, EN = load_dataset(
        params.data_path, params.train_batch_size, params.dev_batch_size)
    de_size, en_size = len(DE.vocab), len(EN.vocab)
    logging.info("[DE Vocab Size]: {}, [EN Vocab Size]: {}".format(
        de_size, en_size))
    logging.info("- done.")

    params.src_vocab_size = de_size
    params.tgt_vocab_size = en_size
    params.sos_index = EN.vocab.stoi["<s>"]
    params.pad_token = EN.vocab.stoi["<pad>"]
    params.eos_index = EN.vocab.stoi["</s>"]
    params.itos = EN.vocab.itos
    params.SRC = DE
    params.TRG = EN

    # make the Seq2Seq model
    model = make_seq2seq_model(params)

    # default optimizer
    optimizer = optim.Adam(model.parameters(), lr=params.lr)

    if params.model_type == "Transformer":
        criterion = LabelSmoothingLoss(params.label_smoothing,
                                       params.tgt_vocab_size,
                                       params.pad_token).to(params.device)
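        # The ScheduledOptimizer wrapper presumably applies the Transformer
        # warmup schedule (an assumption, not confirmed by this snippet):
        #   lr(step) = factor * d_model**-0.5
        #              * min(step**-0.5, step * n_warmup_steps**-1.5)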
        optimizer = ScheduledOptimizer(optimizer=optimizer,
                                       d_model=params.hidden_size,
                                       factor=2,
                                       n_warmup_steps=params.n_warmup_steps)
        scheduler = None
    else:
        criterion = nn.NLLLoss(reduction="sum", ignore_index=params.pad_token)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=params.patience, factor=.1, verbose=True)

    # initialize the Trainer
    trainer = Trainer(model, optimizer, scheduler, criterion, train_iter,
                      dev_iter, params)

    if params.restore_file:
        restore_path = os.path.join(params.model_dir, "checkpoints",
                                    params.restore_file)
        logging.info("Restoring parameters from {}".format(restore_path))
        Trainer.load_checkpoint(model, restore_path, optimizer)

    # train the model
    trainer.train()
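
This example assumes a `params` object carrying data paths, batch sizes, vocabulary indices and model hyperparameters. A hypothetical way to wire the entry point up from the command line (the flag names, defaults and config handling here are assumptions, not part of the example):

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a Seq2Seq MT model")
    parser.add_argument("--data_path", required=True)
    parser.add_argument("--model_dir", required=True)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--dev_batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--restore_file", default=None)
    # Any remaining hyperparameters (hidden_size, model_type, device, ...)
    # would be merged in from a saved config file (hypothetical).
    params = parser.parse_args()
    main(params)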
Example #2
def main():
    args = parse_train_arg()
    task = task_dict[args.task]

    init_distributed_mode(args)
    logger = init_logger(args)

    if hasattr(args, 'base_model_name'):
        logger.warning('Argument base_model_name is deprecated! Use `--table-bert-extra-config` instead!')

    init_signal_handler()

    train_data_dir = args.data_dir / 'train'
    dev_data_dir = args.data_dir / 'dev'
    table_bert_config = task['config'].from_file(
        args.data_dir / 'config.json', **args.table_bert_extra_config)

    if args.is_master:
        args.output_dir.mkdir(exist_ok=True, parents=True)
        with (args.output_dir / 'train_config.json').open('w') as f:
            json.dump(vars(args), f, indent=2, sort_keys=True, default=str)

        logger.info(f'Table Bert Config: {table_bert_config.to_log_string()}')

        # copy the table bert config file to the working directory
        # shutil.copy(args.data_dir / 'config.json', args.output_dir / 'tb_config.json')
        # save table BERT config
        table_bert_config.save(args.output_dir / 'tb_config.json')

    assert args.data_dir.is_dir(), \
        "--data_dir should point to the folder of files made by pregenerate_training_data.py!"

    if args.cpu:
        device = torch.device('cpu')
    else:
        device = torch.device(f'cuda:{torch.cuda.current_device()}')

    logger.info("device: {} gpu_id: {}, distributed training: {}, 16-bits training: {}".format(
        device, args.local_rank, bool(args.multi_gpu), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    real_batch_size = args.train_batch_size  # // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not args.cpu:
        torch.cuda.manual_seed_all(args.seed)

    if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
        logger.warning(f"Output directory ({args.output_dir}) already exists and is not empty!")
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Prepare model
    if args.multi_gpu and args.global_rank != 0:
        torch.distributed.barrier()

    if args.no_init:
        raise NotImplementedError
    else:
        model = task['model'](table_bert_config)

    if args.multi_gpu and args.global_rank == 0:
        torch.distributed.barrier()

    if args.fp16:
        model = model.half()

    model = model.to(device)
    if args.multi_gpu:
        if args.ddp_backend == 'pytorch':
            model = nn.parallel.DistributedDataParallel(
                model,
                find_unused_parameters=True,
                device_ids=[args.local_rank], output_device=args.local_rank,
                broadcast_buffers=False
            )
        else:
            import apex
            model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True)

        model_ptr = model.module
    else:
        model_ptr = model

    # set up update parameters for LR scheduler
    dataset_cls = task['dataset']

    train_set_info = dataset_cls.get_dataset_info(train_data_dir, args.max_epoch)
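    # Illustrative arithmetic for the update count below (made-up numbers):
    # with 1,000,000 training instances, train_batch_size=8, world_size=4 and
    # gradient_accumulation_steps=2, total_num_updates = 1000000 // 8 // 4 // 2
    # = 15625, and warmup_updates = 10% of that = 1562.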
    total_num_updates = train_set_info['total_size'] // args.train_batch_size // args.world_size // args.gradient_accumulation_steps
    args.max_epoch = train_set_info['max_epoch']
    logger.info(f'Train data size: {train_set_info["total_size"]} for {args.max_epoch} epochs, total num. updates: {total_num_updates}')

    args.total_num_update = total_num_updates
    args.warmup_updates = int(total_num_updates * 0.1)

    trainer = Trainer(model, args)

    checkpoint_file = args.output_dir / 'model.ckpt.bin'
    is_resumed = False
    # trainer.save_checkpoint(checkpoint_file)
    if checkpoint_file.exists():
        logger.info(f'Loading checkpoint file {checkpoint_file}')
        is_resumed = True
        trainer.load_checkpoint(checkpoint_file)

    model.train()

    # we also partition the dev set for every local process
    logger.info('Loading dev set...')
    sys.stdout.flush()
    dev_set = dataset_cls(epoch=0, training_path=dev_data_dir, tokenizer=model_ptr.tokenizer, config=table_bert_config,
                          multi_gpu=args.multi_gpu, debug=args.debug_dataset)

    logger.info("***** Running training *****")
    logger.info(f"  Current config: {args}")

    if trainer.num_updates > 0:
        logger.info(f'Resume training at epoch {trainer.epoch}, '
                    f'epoch step {trainer.in_epoch_step}, '
                    f'global step {trainer.num_updates}')

    start_epoch = trainer.epoch
    for epoch in range(start_epoch, args.max_epoch):  # inclusive
        model.train()

        with torch.random.fork_rng(devices=None if args.cpu else [device.index]):
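            # Seeding inside a forked RNG scope makes the per-epoch shuffle
            # reproducible without disturbing the global RNG used by training.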
            torch.random.manual_seed(131 + epoch)

            epoch_dataset = dataset_cls(epoch=trainer.epoch, training_path=train_data_dir, config=table_bert_config,
                                        tokenizer=model_ptr.tokenizer, multi_gpu=args.multi_gpu, debug=args.debug_dataset)
            train_sampler = RandomSampler(epoch_dataset)
            train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=real_batch_size,
                                          num_workers=0,
                                          collate_fn=epoch_dataset.collate)

        samples_iter = GroupedIterator(iter(train_dataloader), args.gradient_accumulation_steps)
        trainer.resume_batch_loader(samples_iter)

        with tqdm(total=len(samples_iter), initial=trainer.in_epoch_step,
                  desc=f"Epoch {epoch}", file=sys.stdout, disable=not args.is_master, miniters=100) as pbar:

            for samples in samples_iter:
                logging_output = trainer.train_step(samples)

                pbar.update(1)
                pbar.set_postfix_str(', '.join(f"{k}: {v:.4f}" for k, v in logging_output.items()))

                if (
                    0 < trainer.num_updates and
                    trainer.num_updates % args.save_checkpoint_every_niter == 0 and
                    args.is_master
                ):
                    # Save model checkpoint
                    logger.info("** ** * Saving checkpoint file ** ** * ")
                    trainer.save_checkpoint(checkpoint_file)

            logger.info(f'Epoch {epoch} finished.')

            if args.is_master:
                # Save a trained table_bert
                logger.info("** ** * Saving fine-tuned table_bert ** ** * ")
                model_to_save = model_ptr  # Only save the table_bert itself
                output_model_file = args.output_dir / f"pytorch_model_epoch{epoch:02d}.bin"
                torch.save(model_to_save.state_dict(), str(output_model_file))

            # perform validation
            logger.info("** ** * Perform validation ** ** * ")
            dev_results = trainer.validate(dev_set)

            if args.is_master:
                logger.info('** ** * Validation Results ** ** * ')
                logger.info(f'Epoch {epoch} Validation Results: {dev_results}')

            # flush logging information to disk
            sys.stderr.flush()

        trainer.next_epoch()
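
The GroupedIterator above (presumably along the lines of fairseq's fairseq.data.iterators.GroupedIterator) bundles every `gradient_accumulation_steps` consecutive batches into one list, so a single `trainer.train_step(samples)` call can accumulate gradients over the whole group before the optimizer steps. A minimal sketch of that grouping, under that assumption:

from itertools import islice

def grouped(iterable, chunk_size):
    """Yield lists of up to chunk_size consecutive items from an iterable."""
    it = iter(iterable)
    while True:
        chunk = list(islice(it, chunk_size))
        if not chunk:
            return
        yield chunk

# e.g. list(grouped(range(7), 3)) == [[0, 1, 2], [3, 4, 5], [6]]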
Example #3
def main(params, greedy, beam_size, test):
    """
    The main function for decoding a trained MT model
    Arguments:
        params: parameters related to the `model` that is being decoded
        greedy: whether or not to do greedy decoding
        beam_size: size of beam if doing beam search
    """
    print("Loading dataset...")
    _, dev_iter, test_iterator, DE, EN = load_dataset(params.data_path,
                                                      params.train_batch_size,
                                                      params.dev_batch_size)
    de_size, en_size = len(DE.vocab), len(EN.vocab)
    print("[DE Vocab Size: ]: {}, [EN Vocab Size]: {}".format(
        de_size, en_size))

    params.src_vocab_size = de_size
    params.tgt_vocab_size = en_size
    params.sos_index = EN.vocab.stoi["<s>"]
    params.pad_token = EN.vocab.stoi["<pad>"]
    params.eos_index = EN.vocab.stoi["</s>"]
    params.itos = EN.vocab.itos

    device = torch.device('cuda' if params.cuda else 'cpu')
    params.device = device

    # make the Seq2Seq model
    model = make_seq2seq_model(params)

    # load the saved model for evaluation
    if params.average > 1:
        print("Averaging the last {} checkpoints".format(params.average))
        checkpoint = {}
        checkpoint["state_dict"] = average_checkpoints(params.model_dir,
                                                       params.average)
        model = Trainer.load_checkpoint(model, checkpoint)
    else:
        model_path = os.path.join(params.model_dir, "checkpoints",
                                  params.model_file)
        print("Restoring parameters from {}".format(model_path))
        model = Trainer.load_checkpoint(model, model_path)

    # evaluate on the test set
    if test:
        print("Doing Beam Search on the Test Set")
        test_decoder = Translator(model, test_iterator, params, device)
        test_beam_search_outputs = test_decoder.beam_decode(
            beam_width=beam_size)
        test_decoder.output_decoded_translations(
            test_beam_search_outputs,
            "beam_search_outputs_size_test={}.en".format(beam_size))
        return

    # instantiate a Translator object to translate SRC language to TRG language using Greedy/Beam Decoding
    decoder = Translator(model, dev_iter, params, device)

    if greedy:
        print("Doing Greedy Decoding...")
        greedy_outputs = decoder.greedy_decode(max_len=100)
        decoder.output_decoded_translations(greedy_outputs,
                                            "greedy_outputs.en")

        print("Evaluating BLEU Score on Greedy Tranlsation...")
        subprocess.call([
            './utils/eval.sh', params.model_dir + "outputs/greedy_outputs.en"
        ])

    if beam_size:
        print("Doing Beam Search...")
        beam_search_outputs = decoder.beam_decode(beam_width=beam_size)
        decoder.output_decoded_translations(
            beam_search_outputs,
            "beam_search_outputs_size={}.en".format(beam_size))

        print("Evaluating BLEU Score on Beam Search Translation")
        subprocess.call([
            './utils/eval.sh', params.model_dir +
            "outputs/beam_search_outputs_size={}.en".format(beam_size)
        ])
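
The `average_checkpoints` call in the checkpoint-averaging branch above presumably averages the parameter tensors of the last N saved checkpoints elementwise before they are loaded into the model. A minimal sketch of that operation, assuming each checkpoint stores its weights under a "state_dict" key (the function name and file handling here are illustrative, not the project's actual helper):

import torch

def average_state_dicts(checkpoint_paths):
    """Elementwise average of the parameters stored in several checkpoints."""
    averaged, n = None, len(checkpoint_paths)
    for path in checkpoint_paths:
        state = torch.load(path, map_location="cpu")
        state_dict = state.get("state_dict", state)
        if averaged is None:
            averaged = {k: v.clone().float() for k, v in state_dict.items()}
        else:
            for k, v in state_dict.items():
                averaged[k] += v.float()
    return {k: v / n for k, v in averaged.items()}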