示例#1
0
文件: generate.py 项目: pluiez/thseq
 def infer(batch):
     """Decode a single batch with the globally configured model and options.

     The batch is moved to the GPU with ``utils.move_cuda`` and its source
     tensor ``batch['x']`` is handed to ``model.infer`` together with the
     decoding settings read from ``args`` (``k`` — presumably the beam size —
     length penalty / alpha, step-wise penalty, min/max length coefficients,
     and ``topk``). The result of ``model.infer`` is returned unchanged.
     """
     cuda_batch = utils.move_cuda(batch)
     # Positional decoding options, in the exact order model.infer expects.
     decode_opts = (
         args.k,
         args.penalty,
         args.alpha,
         args.step_wise_penalty,
         args.min_len_a,
         args.min_len_b,
         args.max_len_a,
         args.max_len_b,
     )
     return model.infer(cuda_batch['x'], *decode_opts, topk=args.topk)
示例#2
0
文件: train.py 项目: pluiez/thseq
    def validate_logp(model, data_loader):
        """Return the per-token log-likelihood of ``data_loader`` under ``model``.

        Only the master process evaluates; every other rank returns ``None``.

        Args:
            model: the model (possibly wrapped in DDP); ``model(batch)`` must
                return logits for the targets in ``batch['y']``.
            data_loader: iterable of batches; each batch provides ``'y'``
                (target ids) and ``'true_tokens_y'`` (the token count used for
                normalisation).

        Returns:
            ``-sum(token NLL) / num_tokens``, or ``None`` on non-master ranks.

        Raises:
            ZeroDivisionError: if ``data_loader`` yields no tokens at all.
        """
        if not utils.is_master():
            return None

        model.eval()
        num_tokens = 0
        total_loss = 0
        # For DDP, enter no_sync() (presumably so master-only forward passes do not
        # set up gradient-synchronisation state); otherwise use a no-op context.
        sync_ctx = model.no_sync() if isinstance(model, DDP) else contextlib.ExitStack()
        # Validation never calls backward(), so skip building autograd graphs.
        with sync_ctx, torch.no_grad():
            for batch in data_loader:
                batch = utils.move_cuda(batch)
                logits = model(batch)
                # Assumes logits are (batch, len, vocab); cross_entropy wants the
                # class dimension second. Upcast to fp32 for a stable summed loss.
                nll_loss = F.cross_entropy(logits.transpose(2, 1).float(), batch['y'], reduction='sum')
                total_loss += nll_loss.item()
                num_tokens += batch['true_tokens_y']
        return -total_loss / num_tokens
示例#3
0
文件: trainer.py 项目: pluiez/thseq
    def fb(self, batch: dict, dummy=False):
        """Run one forward/backward pass on ``batch`` and return the scalar losses.

        The model is switched to training mode, the batch is moved to the GPU,
        and the criterion is evaluated inside ``self._amp_autocast()``. The loss
        is scaled by ``self.grad_scaler`` before ``backward()``. Gradients are
        accumulated only; no optimizer step happens here.

        Args:
            batch: the input batch; ``batch['y']`` holds the targets.
            dummy: if true, the loss is multiplied by zero before backward so
                the pass contributes no gradient signal. ``nll_loss`` is still
                returned as computed.

        Returns:
            ``(loss, nll_loss)`` converted to Python floats.
        """
        self.model.train()
        gpu_batch = utils.move_cuda(batch)

        with self._amp_autocast():
            output = self.model(gpu_batch)
            loss, nll_loss = self.criterion(output, gpu_batch['y'])
            if dummy:
                # Keep the graph but zero its contribution. Presumably this keeps
                # ranks in lockstep when a worker has no real batch.
                loss = loss * 0.0

        self.grad_scaler.scale(loss).backward()
        loss_value, nll_value = loss.item(), nll_loss.item()

        if self.args.profile:
            utils.profile_nan(self.model)

        return loss_value, nll_value
示例#4
0
文件: generate.py 项目: pluiez/thseq
def main(args):
    """Translate the input described by ``args`` and report throughput on stderr.

    Loads the model and vocabularies from ``args.checkpoints``, moves the model
    to the GPU, and switches it to eval mode. Every batch from
    ``get_iterator`` is then fed through a ``Translator``. Finally, sentence
    and token counts, elapsed time, and speed are written to ``sys.stderr``.

    Args:
        args: parsed command-line options (checkpoints, decoding settings such
            as ``k``/``penalty``/``alpha``/length limits/``topk``, ``bpe``,
            ``reverse``, ``verbose``, ...).
    """
    logger.info('Loading checkpoints ...')
    model, vocabularies = load(args.checkpoints, verbose=args.verbose)
    s_vocab, t_vocab = vocabularies
    model = utils.move_cuda(model)
    # Decoding must not run with training-time behaviour such as dropout.
    # model.infer does not switch modes itself (the training-side BLEU validator
    # calls model.eval() before infer). This is harmless if load() already did it.
    model.eval()

    def infer(batch):
        # Decode one batch with the decoding options taken from the CLI.
        batch = utils.move_cuda(batch)
        return model.infer(
            batch['x'], args.k,
            args.penalty,
            args.alpha,
            args.step_wise_penalty,
            args.min_len_a,
            args.min_len_b,
            args.max_len_a,
            args.max_len_b,
            topk=args.topk
        )

    translator = Translator(infer, args.bpe, args.reverse, args.topk, s_vocab, args.verbose)

    meter = utils.SpeedMeter()

    logger.info('Building iterator ...')
    it = get_iterator(args, s_vocab)

    n_tok = 0
    n_snt = 0
    meter.start()

    logger.info('Start generation ...')
    for batch in it:
        translator.translate(batch)
        n_tok += batch['size_x']
        n_snt += len(batch['index'])
        meter.stop(batch['size_x'])

    # Read the duration once so both reported figures use the same value, and
    # avoid ZeroDivisionError when nothing was translated.
    duration = meter.duration
    snt_speed = n_snt / duration if duration > 0 else 0.0
    sys.stderr.write(
        f'Sentences = {n_snt}, Tokens = {n_tok}, \n'
        f'Time = {datetime.timedelta(seconds=duration)}, \n'
        f'Speed = {meter.avg:.2f} tok/s, {snt_speed:.2f} snt/s\n'
    )
示例#5
0
文件: train.py 项目: pluiez/thseq
    def validate_bleu(model, data_loader, beam_size: int = 1):
        """Decode ``data_loader`` with ``model`` and return the BLEU of the 1-best outputs.

        Only the master process evaluates; every other rank returns ``None``.

        Args:
            model: the model (possibly wrapped in DDP) exposing ``infer``.
            data_loader: iterable of batches providing ``'x'`` (source input
                passed to ``infer``) and ``'refs'`` (reference translations).
            beam_size: beam width passed to ``infer``.

        Returns:
            The value of ``utils.bleu(refs, hyps)``, or ``None`` on non-master ranks.
        """
        if not utils.is_master():
            return None

        model.eval()
        sync_ctx = model.no_sync() if isinstance(model, DDP) else contextlib.ExitStack()
        # Decoding never calls backward(), so skip building autograd graphs.
        with sync_ctx, torch.no_grad():
            # DDP does not expose the wrapped module's custom methods, so reach
            # through .module to get infer().
            infer = model.module.infer if isinstance(model, DDP) else model.infer

            hyps = []
            refs = []
            for batch in data_loader:
                batch = utils.move_cuda(batch)
                # infer appears to return one k-best list per sentence; keep the
                # token sequence of the top-ranked hypothesis.
                n_best_lists = infer(batch['x'], beam_size)
                hyps.extend(k_best[0]['tokens'] for k_best in n_best_lists)
                refs.extend(batch['refs'])

            return utils.bleu(refs, hyps)
示例#6
0
文件: train.py 项目: pluiez/thseq
def run(args):
    """Train a model end-to-end as configured by the parsed command-line ``args``.

    In order, this function:
    1. Dumps the config to ``<checkpoint>/config.json`` (master only).
    2. Seeds the RNGs and loads the vocabularies.
    3. Builds the model, criterion, optimizer and LR scheduler.
    4. Looks up ``utils.checkpoint.load_latest(args.checkpoint)``.
    5. Optionally casts to half precision and creates the ``Trainer``.
    6. Restores state from that checkpoint or, failing that, from
       ``try_prepare_finetune``.
    7. Wraps the model in DDP when a process group is initialised.
    8. Runs ``Trainer.train`` with an optional dev-set validator and a
       TensorBoard writer.

    Raises:
        ValueError: if ``args.finetune`` is set while a checkpoint to resume
            from already exists in ``args.checkpoint``.
    """
    disable_numba_logging()
    if utils.is_master():
        # Only the master writes the config, avoiding concurrent writes to one file.
        with Path(args.checkpoint, 'config.json').open('w', encoding='utf-8') as w:
            w.write(json.dumps(vars(args), indent=4, sort_keys=True))
    utils.seed(args.seed)
    sv, tv = load_vocab(args)

    logger.info('Building model')
    model = utils.move_cuda(models.build(args, [sv, tv]))
    criterion = utils.move_cuda(losses.CrossEntropyLoss(args.label_smoothing, ignore_index=tv.pad_id))
    optimizer = utils.move_cuda(optim.build(args, model.parameters()))
    lr_scheduler = lr_schedulers.build(args, args.lr, optimizer)

    state_dict = utils.checkpoint.load_latest(args.checkpoint)

    if not state_dict:
        logger.info(f'Model: \n{model}')
    elif args.finetune:
        # Resuming an existing run and fine-tuning are mutually exclusive.
        raise ValueError('fine-tuning is not available while trying to resume training from an existing checkpoint.')

    logger.info(f'FP16: {args.fp16}')
    if args.fp16 and args.fp16 != 'none':
        # Compute capability < 7 (pre-Volta) lacks tensor cores, so fp16 brings little speed-up.
        if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] < 7:
            logger.warning('Target device does not support acceleration with --fp16')
        if args.fp16 == 'half':
            # Pure half precision: cast model and criterion to fp16. Other --fp16
            # modes are presumably handled by the Trainer's autocast/grad scaler.
            model = model.half()
            criterion = criterion.half()

    train_itr = build_train_itr(args, sv, tv)
    trainer = Trainer(args, model, criterion, [sv, tv], optimizer, lr_scheduler, train_itr)

    logger.info(f'Max sentences = {args.max_sentences}, '
                f'max tokens = {args.max_tokens} ')

    # With no checkpoint to resume from, fall back to a fine-tuning initialisation (if any).
    state_dict = state_dict or try_prepare_finetune(args, model)

    stat_parameters(model)

    # Restore training process
    if state_dict:
        logger.info('Resuming from given checkpoint')
        # no_load_model presumably skips the model weights when fine-tuning.
        trainer.load_state_dict(state_dict, no_load_model=bool(args.finetune))
        del state_dict

    # Wrap in DDP only after state restoration, presumably so checkpoint keys
    # match the bare (unwrapped) model. The trainer is re-pointed at the wrapper.
    if dist.is_initialized():
        model = DDP(
            model,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device(),
            find_unused_parameters=True
        )
        trainer.model = model

    check_devices(args.dist_world_size)

    def after_epoch_callback():
        logger.info(f'Finished epoch {trainer.epoch}. ')

    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        logger.warning('Tensorboard is not available.')
        SummaryWriter = utils.DummySummaryWriter

    # Only the master process writes TensorBoard events.
    if not utils.is_master():
        SummaryWriter = utils.DummySummaryWriter

    # purge_step: when resuming, discard stale events at or after the restored global step.
    writer = SummaryWriter(log_dir=Path(args.checkpoint, 'tensorboard'),
                           purge_step=trainer.global_step or None)

    validate_callback = None
    if args.dev:
        validate_callback = get_validator(args, model, sv, tv)
    trainer.timer.start()

    with writer:
        with torch.autograd.profiler.record_function('train_loop'):
            trainer.train(validate_callback=validate_callback,
                          before_epoch_callback=None,
                          after_epoch_callback=after_epoch_callback,
                          summary_writer=writer)

    # `// 1` drops fractional seconds from the reported duration.
    logger.info(f'Training finished @ {time.strftime("%b %d, %Y, %H:%M:%S", time.localtime())}, '
                f'took {datetime.timedelta(seconds=trainer.elapse // 1)}')

    logger.info(f'Best validation score: {trainer.best_score}, @ {trainer.best_at}')