Code example #1
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.dataloader(subset,
                             batch_size=None,
                             max_tokens=args.max_tokens,
                             max_positions=args.max_positions,
                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            ntokens = sum(s['ntokens'] for s in sample)
            loss = trainer.valid_step(sample, criterion)
            loss_meter.update(loss, ntokens)
            t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)

        val_loss = loss_meter.avg
        t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
            val_loss, math.pow(2, val_loss)))

    # return the average validation loss
    return val_loss
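Every example on this page relies on AverageMeter, but none of them defines it. Below is a minimal sketch of the conventional implementation (as in fairseq and the official PyTorch examples); the field names (val, sum, count, avg) and the weighted update(val, n) signature are inferred from how the meters are used above.

class AverageMeter(object):
    """Tracks a (weighted) running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0    # most recent value
        self.sum = 0    # weighted sum of all values
        self.count = 0  # total weight seen (e.g. token count)
        self.avg = 0    # running weighted average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count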
Code example #2
File: trainer.py  Project: srbutler/MaskGAN.pytorch
    def rollout_generator(self, num_rollouts, samples):
        masked, unmasked, lengths, mask = samples
        batch_size, seq_len = samples[0].size()
        meter = AverageMeter()
        ppl_meter = defaultdict(lambda: AverageMeter())
        self.opt.zero_grad()
        pbar = _tqdm(num_rollouts, 'generator-rollout')

        for rollout in pbar:
            loss, generated, ppl = self.model(masked,
                                              lengths,
                                              mask,
                                              unmasked,
                                              tag="g-step")
            loss = loss.sum() / batch_size
            loss.backward()
            meter.update(-1 * loss.item())  # log the negated loss as the advantage
            # for key in ppl:
            #     ppl[key] = ppl[key].sum() / batch_size
            #     ppl_meter[key].update(ppl[key].item())
        self.opt.step()
        self.logger.log("generator/advantage", self.step, meter.avg)
        # for key in ppl_meter:
        #     self.logger.log("ppl/{}".format(key), ppl_meter[key].avg)

        self.debug('train', samples, generated)
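Note that rollout_generator calls loss.backward() once per rollout but self.opt.step() only once, after the loop: gradients from every rollout accumulate in the parameters' .grad buffers and a single update applies their sum. A standalone sketch of that pattern (the function and argument names here are hypothetical, not the project's API):

def accumulate_and_step(model, optimizer, batches, loss_fn):
    """Accumulate gradients over several batches, then update once."""
    optimizer.zero_grad()
    for batch in batches:
        loss = loss_fn(model, batch)  # one rollout / micro-batch
        loss.backward()               # gradients are summed into .grad
    optimizer.step()                  # single update with the accumulated sum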
Code example #3
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
    """Train the model for one epoch."""

    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    gnorm_meter = AverageMeter()  # gradient norm

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss, grad_norm = trainer.train_step(sample, criterion)

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
            gnorm_meter.update(grad_norm)

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
            ]), refresh=False)

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
        t.write(
            fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
                       round(wps_meter.elapsed_time), round(wps_meter.avg),
                       round(wpb_meter.avg), round(bsz_meter.avg), lr,
                       clip_meter.avg * 100, gnorm_meter.avg))
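The wps_meter above is a TimeMeter rather than an AverageMeter: it tracks a rate (events per second), not an average value. A sketch modeled on fairseq's meters.py (the real implementation may differ in detail):

import time

class TimeMeter(object):
    """Computes the average occurrence of some event per second."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def elapsed_time(self):
        return time.time() - self.start

    @property
    def avg(self):
        return self.n / self.elapsed_time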
Code example #4
File: trainer.py  Project: srbutler/MaskGAN.pytorch
    def rollout_critic(self, num_rollouts, samples):
        masked, unmasked, lengths, mask = samples
        batch_size, seq_len = samples[0].size()
        meter = AverageMeter()
        self.opt.zero_grad()
        pbar = _tqdm(num_rollouts, 'critic-rollout')
        for rollout in pbar:
            loss = self.model(masked, lengths, mask, unmasked, tag="c-step")
            loss = loss.sum() / batch_size
            loss.backward()
            meter.update(loss.item())

        self.opt.step()
        self.logger.log("critic/loss", self.step, meter.avg)
Code example #5
def validate(args, epoch, trainer, dataset, max_positions, subset):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.eval_dataloader(
        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
    )
    loss_meter = AverageMeter()
    nll_loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    prefix = 'valid on \'{}\' subset'.format(subset)
    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
        for _, sample in data.skip_group_enumerator(t, args.num_gpus):
            loss_dict = trainer.valid_step(sample)
            ntokens = sum(s['ntokens'] for s in sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            if 'nll_loss' in loss_dict:
                nll_loss = loss_dict['nll_loss']
                nll_loss_meter.update(nll_loss, ntokens)

            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('valid loss', round(loss_meter.avg, 2)),
            ] + extra_postfix))

        t.print(collections.OrderedDict([
            ('valid loss', round(loss_meter.avg, 2)),
            ('valid ppl', get_perplexity(nll_loss_meter.avg
                                         if nll_loss_meter.count > 0
                                         else loss_meter.avg)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))

    # return the average validation loss
    return loss_meter.avg
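get_perplexity is never defined in these examples. Since the surrounding code treats losses as base-2 log-likelihoods (examples #1 and #9 compute math.pow(2, val_loss) directly), a plausible definition, guarded against overflow for very large losses, is:

import math

def get_perplexity(loss):
    try:
        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')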
Code example #6
File: trainer.py  Project: srbutler/MaskGAN.pytorch
    def rollout_discriminator(self, num_rollouts, samples):
        masked, unmasked, lengths, mask = samples
        real, fake = AverageMeter(), AverageMeter()
        batch_size, seq_len = samples[0].size()

        self.opt.zero_grad()
        pbar = _tqdm(num_rollouts, 'discriminator-rollout')

        for rollout in pbar:
            real_loss = self.model(masked,
                                   lengths,
                                   mask,
                                   unmasked,
                                   tag="d-step",
                                   real=True)

            real_loss = real_loss.sum() / batch_size

            with torch.no_grad():
                net_output = self.model(masked,
                                        lengths,
                                        mask,
                                        unmasked,
                                        tag="g-step")
                generated = net_output[1]

            fake_loss = self.model(masked,
                                   lengths,
                                   mask,
                                   generated,
                                   tag="d-step",
                                   real=False)

            fake_loss = fake_loss.sum() / batch_size

            loss = (real_loss + fake_loss) / 2
            loss.backward()

            real.update(real_loss.item())
            fake.update(fake_loss.item())

        self.opt.step()
        self.logger.log("discriminator/real", self.step, real.avg)
        self.logger.log("discriminator/fake", self.step, fake.avg)
        self.logger.log("discriminator", self.step, real.avg + fake.avg)
Code example #7
def validate(args, epoch, trainer, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.dataloader(subset,
                             batch_size=None,
                             max_tokens=args.max_tokens,
                             max_positions=args.max_positions,
                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            loss_dict = trainer.valid_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f}'.format(loss_meter.avg)),
            ] + extra_postfix), refresh=False)

        val_loss = loss_meter.avg
        fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
            val_loss, get_perplexity(val_loss))
        fmt += ''.join(' | {} {:.4f}'.format(k, meter.avg)
                       for k, meter in extra_meters.items())
        t.write(fmt)

    # return the average validation loss
    return val_loss
Code example #8
File: train.py  Project: ahiroto/ParlAI
def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.eval_dataloader(
        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
    )
    loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    prefix = 'valid on \'{}\' subset'.format(subset)
    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            loss_dict = trainer.valid_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('valid loss', round(loss_meter.avg, 2)),
            ] + extra_postfix))

        t.print(collections.OrderedDict([
            ('valid loss', round(loss_meter.avg, 2)),
            ('valid ppl', get_perplexity(loss_meter.avg)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))

    # return the average validation loss
    return loss_meter.avg
Code example #9
File: train.py  Project: Novemser/sum
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.dataloader(subset,
                             batch_size=None,
                             max_tokens=args.max_tokens,
                             max_positions=args.max_positions)
    loss_meter = AverageMeter()
    rouge_greedy_meter = AverageMeter()
    rouge_sampled_meter = AverageMeter()

    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
    with progress_bar(itr, desc, leave=False) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            ntokens = sum(s['ntokens'] for s in sample)
            loss, mean_rouge_greedy, mean_rouge_sampled = trainer.valid_step(
                sample, criterion)
            loss_meter.update(loss, ntokens)
            rouge_greedy_meter.update(mean_rouge_greedy, 1)
            rouge_sampled_meter.update(mean_rouge_sampled, 1)
            t.set_postfix(
                collections.OrderedDict([
                    ('loss', '{:.2f}'.format(loss_meter.avg)),
                    ('ROUGE-L/f (greedy)',
                     '{:.4f}'.format(rouge_greedy_meter.avg)),
                    ('ROUGE-L/f (sampled)',
                     '{:.4f}'.format(rouge_sampled_meter.avg))
                ]))

        val_loss = loss_meter.avg
        t.write(
            desc +
            ' | valid loss {:2.2f} | valid ppl {:3.2f} | ROUGE-L (greedy): {:.4f} | ROUGE-L (sampled): {:.4f}'
            .format(val_loss, math.pow(2, val_loss), rouge_greedy_meter.avg,
                    rouge_sampled_meter.avg))

    # return the average validation loss
    return val_loss
Code example #10
def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
    """Train the model for one epoch."""

    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    desc = '| epoch {:03d}'.format(epoch)
    trainer.set_seed(args.seed + epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))

            t.set_postfix(collections.OrderedDict([
                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
                ('wps', '{:5d}'.format(round(wps_meter.avg))),
                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
            ] + extra_postfix), refresh=False)

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
            loss_meter.avg, get_perplexity(loss_meter.avg))
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
            round(wps_meter.elapsed_time), round(wps_meter.avg),
            round(wpb_meter.avg))
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
            round(bsz_meter.avg), lr, clip_meter.avg * 100)
        fmt += ''.join(' | {} {:.4f}'.format(k, meter.avg)
                       for k, meter in extra_meters.items())
        t.write(fmt)
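Several examples wrap the batch iterator in progress_bar(itr, desc, leave=False) and then call t.set_postfix(...) and t.write(...). Both methods match tqdm's interface, so a minimal stand-in could be the sketch below (the real fairseq helper supports several display backends; this is just the assumed tqdm case):

from tqdm import tqdm

def progress_bar(iterable, desc, leave=False):
    """Thin tqdm wrapper; tqdm instances already provide set_postfix/write."""
    return tqdm(iterable, desc=desc, leave=leave)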
Code example #11
def train(args, epoch, batch_offset, trainer, dataset, max_positions,
          num_gpus):
    """Train the model for one epoch."""

    seed = args.seed + epoch
    torch.manual_seed(seed)
    trainer.set_seed(seed)

    itr = dataset.train_dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        seed=seed,
        epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    lr = trainer.get_lr()
    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            nsentences = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss,
                              nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(
                collections.OrderedDict([
                    ('loss', loss_meter),
                    ('wps', round(wps_meter.avg)),
                    ('wpb', round(wpb_meter.avg)),
                    ('bsz', round(bsz_meter.avg)),
                    ('lr', lr),
                    ('clip', '{:.0%}'.format(clip_meter.avg)),
                ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        t.print(
            collections.OrderedDict([
                ('train loss', round(loss_meter.avg, 2)),
                ('train ppl', get_perplexity(loss_meter.avg)),
                ('s/checkpoint', round(wps_meter.elapsed_time)),
                ('words/s', round(wps_meter.avg)),
                ('words/batch', round(wpb_meter.avg)),
                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
            ] + [(k, meter.avg) for k, meter in extra_meters.items()]))
Code example #12
def validate(val_loader, r, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    n = r.num_iterations(loader_size=len(val_loader))
    if args.num_minibatches is not None:
        n = min(n, args.num_minibatches)
    r.eval(n)
    if not is_first_stage(): val_loader = None
    r.set_loader(val_loader)

    end = time.time()
    epoch_start_time = time.time()

    if args.no_input_pipelining:
        num_warmup_minibatches = 0
    else:
        num_warmup_minibatches = r.num_warmup_minibatches

    if args.verbose_frequency > 0:
        print("Letting in %d warm-up minibatches" % num_warmup_minibatches)
        print("Running validation for %d minibatches" % n)

    with torch.no_grad():
        for i in range(num_warmup_minibatches):
            r.run_forward()

        for i in range(n - num_warmup_minibatches):
            # perform forward pass
            r.run_forward()
            r.run_ack()

            if is_last_stage():
                output, target, loss, num_tokens = (
                    r.output, r.target, r.loss.item(), r.num_tokens())

                # measure accuracy and record loss
                # prec1, prec5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss, output.size(0))
                # top1.update(prec1[0], output.size(0))
                # top5.update(prec5[0], output.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    print(
                        'Test: [{0}][{1}/{2}]\t'
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Memory: {memory:.3f} ({cached_memory:.3f})\t'
                        'Loss: {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                            epoch,
                            i,
                            n,
                            batch_time=batch_time,
                            loss=losses,
                            memory=(float(torch.cuda.memory_allocated()) /
                                    10**9),
                            cached_memory=(float(torch.cuda.memory_cached()) /
                                           10**9)))
                    import sys
                    sys.stdout.flush()

        if is_last_stage():
            # top1/top5 remain at zero while the accuracy() updates above are commented out
            print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
                top1=top1, top5=top5))

        for i in range(num_warmup_minibatches):
            r.run_ack()

        # wait for all helper threads to complete
        r.wait()

        print('Epoch %d: %.3f seconds' %
              (epoch, time.time() - epoch_start_time))
        print("Epoch start time: %.3f, epoch end time: %.3f" %
              (epoch_start_time, time.time()))

    return top1.avg
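The commented-out prec1/prec5 lines above assume the standard accuracy helper from the PyTorch ImageNet example, which the snippet never defines. For reference:

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res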
Code example #13
def train(train_loader, r, optimizer, epoch, lr_scheduler):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    n = 10000
    # n = r.num_iterations(loader_size=len(train_loader))
    if args.num_minibatches is not None:
        n = min(n, args.num_minibatches)
    # accumulation = 32
    accumulation = n
    n -= (n % accumulation)
    assert n % accumulation == 0
    r.train(n)
    if not is_first_stage(): train_loader = None
    r.set_loader(train_loader)

    end = time.time()
    epoch_start_time = time.time()

    if args.no_input_pipelining:
        num_warmup_minibatches = 0
    else:
        num_warmup_minibatches = r.num_warmup_minibatches

    if args.verbose_frequency > 0:
        print("Letting in %d warm-up minibatches" % num_warmup_minibatches)
        print("Running training for %d minibatches" % n)

    r.set_loss_scale(4 / accumulation)
    total_updates = n // accumulation
    for t in range(total_updates):
        # start num_warmup_minibatches forward passes
        for i in range(num_warmup_minibatches):
            r.run_forward()

        for i in range(accumulation - num_warmup_minibatches):
            end = time.time()
            # perform forward pass
            r.run_forward()

            if is_last_stage():
                # measure accuracy and record loss
                output, target, loss, num_tokens = (
                    r.output, r.target, r.loss.item(), r.num_tokens())
                # print(loss, num_tokens)
                losses.update(loss / num_tokens / math.log(2), num_tokens)

            # perform backward pass
            r.run_backward()

            if is_last_stage():
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                epoch_time = (end - epoch_start_time) / 3600.0
                full_epoch_time = (epoch_time /
                                   float(accumulation * t + i + 1)) * float(n)

                if (t * accumulation + i +
                        num_warmup_minibatches) % args.print_freq == 0:
                    print(
                        'Stage: [{0}] Epoch: [{1}][{2}/{3}]\t'
                        'Time({timestamp}): {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Epoch time [hr]: {epoch_time:.3f} ({full_epoch_time:.3f})\t'
                        'Memory: {memory:.3f} ({cached_memory:.3f})\t'
                        'Loss: {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                            args.stage,
                            epoch,
                            accumulation * t + i + 1,
                            n,
                            timestamp=time.time(),
                            batch_time=batch_time,
                            epoch_time=epoch_time,
                            full_epoch_time=full_epoch_time,
                            loss=losses,  # top1=top1, top5=top5,
                            memory=(float(torch.cuda.memory_allocated()) /
                                    10**9),
                            cached_memory=(float(torch.cuda.memory_cached()) /
                                           10**9)))
                    import sys
                    sys.stdout.flush()
            else:
                if i + num_warmup_minibatches == accumulation - 1 and (
                        t * accumulation + i + num_warmup_minibatches
                ) % args.print_freq < accumulation:
                    print(
                        'Stage: [{0}] Epoch: [{1}][{2}/{3}]\tMemory: {memory:.3f} ({cached_memory:.3f})'
                        .format(
                            args.stage,
                            epoch,
                            accumulation * t + i + 1,
                            n,
                            memory=(float(torch.cuda.memory_allocated()) /
                                    10**9),
                            cached_memory=(float(torch.cuda.memory_cached()) /
                                           10**9)))
                    import sys
                    sys.stdout.flush()

            # if i == 500 and args.local_rank == 0:
            #     subprocess.Popen(['python', 'usage.py', 'gpu.log'])

        # finish remaining backward passes
        for i in range(num_warmup_minibatches):
            r.run_backward()

        # optimizer.step()
        if args.fp16:
            r.zero_grad()
        else:
            optimizer.zero_grad()
        num_updates = epoch * total_updates + t + 1
        lr_scheduler.step_update(num_updates)

    # wait for all helper threads to complete
    r.wait()

    print("Epoch %d: %.3f seconds" % (epoch, time.time() - epoch_start_time))
    print("Epoch start time: %.3f, epoch end time: %.3f" %
          (epoch_start_time, time.time()))
Code example #14
class DDPTrainer:
    """Main class for data parallel training.

    This class supports data parallel training, where multiple workers each
    have a full model replica and gradients are accumulated synchronously via
    torch.distributed.all_reduce.
    """
    def __init__(self, args, model):

        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args

        self.model = model.cuda()
        self.criterion = CRITERION_REGISTRY[args.criterion](args).cuda()
        self.optimizer = optim.build_optimizer(self.args,
                                               self.model.parameters())
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(
            self.args, self.optimizer)
        self.scaler = amp.GradScaler(enabled=self.args.amp, init_scale=2**15)

        if self.args.distributed_world_size > 1:
            self.model = DDP(model)

        self._buffered_stats = defaultdict(lambda: [])
        self._num_updates = 0
        self._optim_history = None
        self.throughput_meter = TimeMeter()
        self.avg_loss_meter = AverageMeter()

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            utils.save_state(
                filename,
                self.args,
                self.get_model(),
                self.criterion,
                self.optimizer,
                self.lr_scheduler,
                self._num_updates,
                self._optim_history,
                extra_state,
            )

    def load_checkpoint(self, filename, load_optim=True):
        """Load all training state from a checkpoint file."""
        extra_state, optim_history, last_optim_state = \
            utils.load_model_state(filename, self.get_model())

        if last_optim_state is not None:
            # rebuild optimizer after loading model, since params may have changed
            #self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
            self.lr_scheduler = lr_scheduler.build_lr_scheduler(
                self.args, self.optimizer)

            if load_optim:
                self._optim_history = optim_history
                # only reload optimizer and lr_scheduler if they match
                last_optim = self._optim_history[-1]
                if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                    self.lr_scheduler.load_state_dict(
                        last_optim['lr_scheduler_state'])
                    if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
                        self.optimizer.load_state_dict(last_optim_state)

                self._num_updates = last_optim['num_updates']

        return extra_state

    def train_step(self, sample, update_params=True, last_step=False):
        """Do forward, backward and parameter update."""
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.model.train()
        if isinstance(self.model, DDP):
            if last_step:
                self.model.disable_allreduce()
            else:
                self.model.enable_allreduce()

        # forward and backward pass
        sample = self._prepare_sample(sample)
        loss, oom_fwd = self._forward(sample)

        # If this is the last batch, the forward pass may be skipped on some
        # workers; a batch with sample_size 0 is not counted in the weighted loss
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
            'loss': utils.item(loss.data) if loss is not None else 0,
        }
        sample_size = sample['ntokens'] if sample is not None else 0
        oom_bwd = self._backward(loss)

        # buffer stats and logging outputs
        self._buffered_stats['sample_sizes'].append(sample_size)
        self._buffered_stats['logging_outputs'].append(logging_output)
        self._buffered_stats['ooms_fwd'].append(oom_fwd)
        self._buffered_stats['ooms_bwd'].append(oom_bwd)

        # update parameters
        if update_params and not last_step:
            # gather logging outputs from all replicas
            sample_sizes = self._buffered_stats['sample_sizes']
            logging_outputs = self._buffered_stats['logging_outputs']
            ooms_fwd = self._buffered_stats['ooms_fwd']
            ooms_bwd = self._buffered_stats['ooms_bwd']
            if self.args.distributed_world_size > 1:
                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                    lambda l: list(chain.from_iterable(l)),
                    zip(*distributed_utils.all_gather_list((sample_sizes,
                                                            logging_outputs,
                                                            ooms_fwd,
                                                            ooms_bwd))))
            ooms_fwd = sum(ooms_fwd)
            ooms_bwd = sum(ooms_bwd)
            ooms = ooms_fwd + ooms_bwd  # this is always <= distributed_world_size

            if ooms == self.args.distributed_world_size:
                print('| WARNING: OOM in all workers, skipping batch')
                self.zero_grad()
                return

            # aggregate stats and logging outputs
            grad_denom = sum(sample_sizes)
            for p in self.model.parameters():
                if p.requires_grad and p.grad is not None:
                    p.grad /= grad_denom

            self._opt()

            # Handle logging
            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
            self.throughput_meter.update(ntokens)
            info_log_data = {
                'tokens/s': self.throughput_meter.avg,
                'tokens': ntokens,
                'loss': sum(log.get('loss', 0)
                            for log in logging_outputs) / ntokens / math.log(2),
            }
            self.avg_loss_meter.update(info_log_data['loss'])
            debug_log_data = {
                'batch_size': sum(log.get('nsentences', 0)
                                  for log in logging_outputs),
                'lr': self.get_lr(),
                'grad_denom': grad_denom,
                'updates': 1,
            }

            DLLogger.log(step=self._num_updates,
                         data=info_log_data,
                         verbosity=0)
            DLLogger.log(step=self._num_updates,
                         data=debug_log_data,
                         verbosity=1)

            self.clear_buffered_stats()

    def _forward(self, sample):
        loss = None
        oom = 0
        try:
            if sample is not None:
                with amp.autocast(enabled=self.args.amp):
                    # calculate loss and sample size
                    logits, _ = self.model(**sample['net_input'])
                    target = sample['target']
                    probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                    loss = self.criterion(probs, target)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print(
                    '| WARNING: ran out of memory in worker {}, skipping batch'
                    .format(self.args.distributed_rank),
                    force=True)
                oom = 1
                loss = None
            else:
                raise e
        return loss, oom

    def _backward(self, loss):
        oom = 0
        if loss is not None:
            try:
                self.scaler.scale(loss).backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(
                        '| WARNING: ran out of memory in worker {}, skipping batch'
                        .format(self.args.distributed_rank),
                        force=True)
                    oom = 1
                    self.zero_grad()
                else:
                    raise e
        return oom

    def _opt(self):
        # take an optimization step
        self.scaler.step(self.optimizer.optimizer)
        self.scaler.update()
        self.zero_grad()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

    def valid_step(self, sample):
        """Do forward pass in evaluation mode."""
        self.model.eval()
        # forward pass
        sample = self._prepare_sample(sample)
        with torch.no_grad():
            loss, oom_fwd = self._forward(sample)
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
        }
        loss = loss.item() if loss is not None else 0
        assert not oom_fwd, 'Ran out of memory during validation'

        # gather logging outputs from all GPUs
        if self.args.distributed_world_size > 1:
            losses, logging_outputs = zip(
                *distributed_utils.all_gather_list((loss, logging_output)))
        else:
            losses = [loss]
            logging_outputs = [logging_output]

        weight = sum(log.get('ntokens', 0) for log in logging_outputs)
        scaled_loss = sum(losses) / weight / math.log(2)

        return scaled_loss

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, update_params=False)
        self.zero_grad()
        self.clear_buffered_stats()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        self._buffered_stats.clear()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)

    def lr_step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(num_updates)

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_throughput_meter(self):
        """Get the throughput meter"""
        return self.throughput_meter

    def get_model(self):
        """Get the model replica."""
        return self.model.module if isinstance(self.model, DDP) else self.model

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def _prepare_sample(self, sample):
        if not sample:
            return None
        return utils.move_to_cuda(sample)
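For context, a hypothetical driver loop for the DDPTrainer above might look like the following; train_iter, valid_iter, dummy_batch, and args.max_epoch are assumptions, not part of the class:

trainer = DDPTrainer(args, model)
trainer.dummy_train_step(dummy_batch)      # warm the caching allocator

for epoch in range(args.max_epoch):
    for sample in train_iter:              # assumed epoch batch iterator
        trainer.train_step(sample, update_params=True)
    val_losses = [trainer.valid_step(s) for s in valid_iter]
    val_loss = sum(val_losses) / len(val_losses)
    trainer.lr_step(epoch, val_loss)       # adjust LR from validation loss
    trainer.save_checkpoint('checkpoint{}.pt'.format(epoch),
                            {'epoch': epoch, 'val_loss': val_loss})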
Code example #15
File: main.py  Project: sosp21paper326/naspipe
def main():

    args = parser.parse_args()

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)
    model = TransformerModel.build_model(args, task).cuda()
    criterion = task.build_criterion(args).cuda()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 betas=eval(args.adam_betas),
                                 eps=args.adam_eps,
                                 weight_decay=args.weight_decay)

    # Load dataset splits
    load_dataset_splits(task, ['train', 'valid'])

    epoch_itr = data.EpochBatchIterator(
        dataset=task.dataset(args.train_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=(args.max_source_positions, args.max_target_positions),
        ignore_invalid_inputs=True,
        required_batch_size_multiple=8,
        seed=1,
        num_shards=1,
        shard_id=0,
    )

    losses = AverageMeter()

    encoder_layer_forward = [
        AverageMeter() for _ in range(len(model.encoder.layers[0].layer))
    ]
    decoder_layer_forward = [
        AverageMeter() for _ in range(len(model.decoder.layers[0].layer))
    ]
    encoder_layer_backward = [
        AverageMeter() for _ in range(len(model.encoder.layers[0].layer))
    ]
    decoder_layer_backward = [
        AverageMeter() for _ in range(len(model.decoder.layers[0].layer))
    ]

    def measure_hook(forward, backward):
        def hook(module, input, output):
            for i, layer in enumerate(module.layer):

                if len(input) == 2:
                    x, _ = input
                else:
                    x, = input
                x = x.detach().clone().requires_grad_()

                # warm-up
                for _ in range(5):
                    if isinstance(layer, nn.MultiheadAttention):
                        out, _ = layer(x, x, x)
                    else:
                        out = layer(x)
                    torch.autograd.backward(out, out)

                starter = torch.cuda.Event(enable_timing=True)
                ender = torch.cuda.Event(enable_timing=True)
                for _ in range(50):
                    starter.record()
                    if isinstance(layer, nn.MultiheadAttention):
                        out, _ = layer(x, x, x)
                    else:
                        out = layer(x)
                    ender.record()
                    torch.cuda.synchronize()
                    forward[i].update(starter.elapsed_time(ender))

                    starter.record()
                    torch.autograd.backward(out, out)
                    ender.record()
                    torch.cuda.synchronize()
                    backward[i].update(starter.elapsed_time(ender))

        return hook

    for layer in model.encoder.layers:
        layer.register_forward_hook(
            measure_hook(encoder_layer_forward, encoder_layer_backward))

    for layer in model.decoder.layers:
        layer.register_forward_hook(
            measure_hook(decoder_layer_forward, decoder_layer_backward))

    embed_forward = AverageMeter()
    embed_backward = AverageMeter()

    def embed_hook(module, input, output):
        tokens, _ = input

        # warm-up
        for _ in range(5):
            x = module.embed_scale * module.embed_tokens(tokens)
            x += module.embed_positions(tokens)
            torch.autograd.backward(x, x)

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        for _ in range(50):
            starter.record()
            x = module.embed_scale * module.embed_tokens(tokens)
            x += module.embed_positions(tokens)
            ender.record()
            torch.cuda.synchronize()
            embed_forward.update(starter.elapsed_time(ender))

            starter.record()
            torch.autograd.backward(x, x)
            ender.record()
            torch.cuda.synchronize()
            embed_backward.update(starter.elapsed_time(ender))

    model.encoder.register_forward_hook(embed_hook)

    linear_forward = AverageMeter()
    linear_backward = AverageMeter()

    def linear_hook(module, input, output):
        _, encode_out = input
        encode_out = encode_out.detach().clone().requires_grad_()

        # warm-up
        for _ in range(5):
            x = encode_out.transpose(0, 1)
            out = F.linear(x, module.embed_out)
            torch.autograd.backward(out, out)

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        for _ in range(50):
            starter.record()
            x = encode_out.transpose(0, 1)
            out = F.linear(x, module.embed_out)
            ender.record()
            torch.cuda.synchronize()
            linear_forward.update(starter.elapsed_time(ender))

            starter.record()
            torch.autograd.backward(out, out)
            ender.record()
            torch.cuda.synchronize()
            linear_backward.update(starter.elapsed_time(ender))

    model.decoder.register_forward_hook(linear_hook)

    itr = epoch_itr.next_epoch_itr()
    max_positions = (args.max_source_positions, args.max_target_positions)
    for i, sample in enumerate(itr):
        sample = task.dataset('train').get_dummy_batch(args.max_tokens,
                                                       max_positions)
        sample = utils.move_to_cuda(sample)
        loss, _, logging_output = criterion(model, sample)
        num_tokens = logging_output['ntokens']
        losses.update(loss.item() / num_tokens / math.log(2), num_tokens)
        if i % 100 == 0:
            print('Loss: {loss.val:.4f} ({loss.avg:.4f})'.format(loss=losses))
            print(
                'Time: {forward_time.avg:.3f} ({backward_time.avg:.3f})'
                '{forward_time_decoder.avg:.3f} ({backward_time_decoder.avg:.3f})'
                .format(forward_time=encoder_layer_forward[0],
                        backward_time=encoder_layer_backward[0],
                        forward_time_decoder=decoder_layer_forward[-1],
                        backward_time_decoder=decoder_layer_backward[-1]))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        break

    stat = {i: {} for i in range(len(decoder_layer_forward))}
    for i, (f, b) in enumerate(zip(encoder_layer_forward,
                                   encoder_layer_backward)):
        stat[i]['encoder'] = {}
        stat[i]['encoder']['forward'] = f.avg
        stat[i]['encoder']['backward'] = b.avg

    for i, (f, b) in enumerate(zip(decoder_layer_forward,
                                   decoder_layer_backward)):
        stat[i]['decoder'] = {}
        stat[i]['decoder']['forward'] = f.avg
        stat[i]['decoder']['backward'] = b.avg

    stat['embed'] = {}
    stat['embed']['forward'] = embed_forward.avg
    stat['embed']['backward'] = embed_backward.avg

    stat['linear'] = {}
    stat['linear']['forward'] = linear_forward.avg
    stat['linear']['backward'] = linear_backward.avg

    with open('time.json', 'w') as file:
        json.dump(stat, file, indent=4)
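The hooks above all repeat the same CUDA-event timing idiom. Extracted into a standalone helper (a sketch; module and input are placeholders), it looks like this:

import torch

def time_forward_ms(module, x, iters=50, warmup=5):
    """Average forward latency of `module` in milliseconds, via CUDA events."""
    for _ in range(warmup):
        module(x)
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    total = 0.0
    for _ in range(iters):
        starter.record()
        module(x)
        ender.record()
        torch.cuda.synchronize()  # elapsed_time is only valid after a sync
        total += starter.elapsed_time(ender)
    return total / iters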
Code example #16
File: train.py  Project: ahiroto/ParlAI
def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
    """Train the model for one epoch."""

    seed = args.seed + epoch
    torch.manual_seed(seed)
    trainer.set_seed(seed)

    itr = dataset.train_dataloader(
        args.train_subset, num_workers=args.workers,
        max_tokens=args.max_tokens, max_sentences=args.max_sentences,
        max_positions=max_positions, seed=seed, epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
    wps_meter = TimeMeter()       # words per second
    clip_meter = AverageMeter()   # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    lr = trainer.get_lr()
    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            nsentences = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('loss', loss_meter),
                ('wps', round(wps_meter.avg)),
                ('wpb', round(wpb_meter.avg)),
                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
                ('clip', '{:.0%}'.format(clip_meter.avg)),
            ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        t.print(collections.OrderedDict([
            ('train loss', round(loss_meter.avg, 2)),
            ('train ppl', get_perplexity(loss_meter.avg)),
            ('s/checkpoint', round(wps_meter.elapsed_time)),
            ('words/s', round(wps_meter.avg)),
            ('words/batch', round(wpb_meter.avg)),
            ('bsz', round(bsz_meter.avg)),
            ('lr', lr),
            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))
Code example #17
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    from fairseq.sequence_scorer import SequenceScorer
    generator = task.build_generator(args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    avg_ranks = AverageMeter()
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        all_ents = []
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample,
                                        prefix_tokens)
            if 'ents' in sample:
                all_ents.extend(sample['ents'])
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens,
                                                     args.remove_bpe,
                                                     escape_unk=True)

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str),
                              file=output_file)
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str),
                              file=output_file)

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                        hypo_tokens=hypo['tokens'].int().cpu(),
                        src_str=src_str,
                        alignment=hypo['alignment'],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )

                    if not args.quiet:
                        score = hypo['score'] / math.log(2)  # convert to base 2
                        print('H-{}\t{}\t{}'.format(sample_id, score,
                                                    hypo_str),
                              file=output_file)
                        # positional scores, converted from base e to base 2
                        pos_scores = hypo['positional_scores'].div_(math.log(2))
                        print('P-{}\t{}'.format(
                            sample_id,
                            ' '.join('{:.4f}'.format(x)
                                     for x in pos_scores.tolist())),
                              file=output_file)

                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id, ' '.join([
                                    '{}-{}'.format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in alignment
                                ])),
                                  file=output_file)

                        if args.print_step:
                            print('I-{}\t{}'.format(sample_id, hypo['steps']),
                                  file=output_file)

                        if getattr(args, 'retain_iter_history', False):
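                            # Iterative-refinement decoders retain each intermediate
                            # hypothesis; emit one E-* line per refinement step.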
                            for step, h in enumerate(hypo['history']):
                                _, h_str, _ = utils.post_process_prediction(
                                    hypo_tokens=h['tokens'].int().cpu(),
                                    src_str=src_str,
                                    alignment=None,
                                    align_dict=None,
                                    tgt_dict=tgt_dict,
                                    remove_bpe=None,
                                )
                                print('E-{}_{}\t{}'.format(
                                    sample_id, step, h_str),
                                      file=output_file)

                        if getattr(args, 'score_reference', False):
                            print('R-{}\t{:.4f}'.format(
                                sample_id, hypo['avg_ranks']),
                                  file=output_file)

                    # Score only the top hypothesis
                    if getattr(args, 'score_reference', False):
                        avg_ranks.update(hypo['avg_ranks'])
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(
                                target_str, add_if_not_exist=True)
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)
            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info(
        'Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'
        .format(num_sentences, gen_timer.n, gen_timer.sum,
                num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        logger.info('Generate {} with beam={}: {}'.format(
            args.gen_subset, args.beam, scorer.result_string()))
    if getattr(args, 'score_reference', False):
        logger.info('Average rank of reference={:.4f}, Entropy={:.4f}'.format(
            avg_ranks.avg,
            torch.cat(all_ents, dim=0).mean()))

    return scorer
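
The H-/P-/E- lines above report model scores in base 2, obtained by dividing natural-log scores by math.log(2), i.e. the identity log2(x) = ln(x) / ln(2); the same convention is why perplexity is computed as math.pow(2, loss) in these examples. A minimal standalone check of that conversion, using a made-up probability rather than anything from the code above:

import math

ln_prob = math.log(0.25)                 # natural-log probability of a hypothetical token
score_base2 = ln_prob / math.log(2)      # same conversion the generate loop applies
assert abs(score_base2 - (-2.0)) < 1e-9  # log2(0.25) == -2
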
Code example #18
0
File: train.py Project: Novemser/sum
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
    """Train the model for one epoch."""

    itr = dataset.dataloader(
        args.train_subset,
        num_workers=args.workers,
        max_tokens=args.max_tokens,
        seed=args.seed,
        epoch=epoch,
        max_positions=args.max_positions,
        sample_without_replacement=args.sample_without_replacement)
    ###print("itr:"+str(itr))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()  # sentences per batch
    wpb_meter = AverageMeter()  # words per batch
    wps_meter = TimeMeter()  # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    gnorm_meter = AverageMeter()  # gradient norm

    desc = '| epoch {:03d}'.format(epoch)
    lr = trainer.get_lr()
    with progress_bar(itr, desc, leave=False) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            ###print("i:"+str(i)+" sample:"+str(sample)) ###id,src_tokens,input_tokens,input_positions,target,src_positions,ntokens
            ###print("i:"+str(i)+" sample len:"+str(len(sample))+" sample id:"+str(sample[0]['id'])+" sample src_tokens:"+str(sample[0]['src_tokens'][0]))
            aggregate_res = trainer.train_step(sample, criterion)
            mixed_loss = aggregate_res.loss
            ml_loss = aggregate_res.ml_loss
            grad_norm = aggregate_res.grad_norm
            rl_loss = aggregate_res.rl_loss
            mean_rouge_greedy = aggregate_res.mean_rouge_greedy
            mean_rouge_sampled = aggregate_res.mean_rouge_sampled
            mean_sum_log_prob = aggregate_res.mean_sum_log_prob

            ntokens = sum(s['ntokens'] for s in sample)
            src_size = sum(s['src_tokens'].size(0) for s in sample)
            loss_meter.update(ml_loss, ntokens)
            bsz_meter.update(src_size)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
            gnorm_meter.update(grad_norm)

            t.set_postfix(
                collections.OrderedDict([
                    ('loss', '{:.2f} ({:.2f})'.format(ml_loss,
                                                      loss_meter.avg)),
                    ('wps', '{:5d}'.format(round(wps_meter.avg))),
                    ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
                    ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                    ('lr', lr),
                    ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
                    ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
                ]))

            if args.enable_rl:
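                # When RL is enabled, the trainer returns a mixed ML/RL loss
                # plus ROUGE-based reward diagnostics; print them each step.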
                fmt_other = 'mixed_loss: {:^10.4f} | ml_loss: {:^10.4f}'
                fmt_other += '| rl_loss: {:^10.4f} | mean_rouge_greedy: {:^10.4f}'
                fmt_other += '| mean_rouge_sampled: {:^10.4f} | mean_sum_log_prob: {:^10.4f}'
                print(
                    fmt_other.format(mixed_loss, ml_loss, rl_loss,
                                     mean_rouge_greedy, mean_rouge_sampled,
                                     mean_sum_log_prob))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                trainer.save_checkpoint(args, epoch, i + 1)

        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
        t.write(
            fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
                       round(wps_meter.elapsed_time), round(wps_meter.avg),
                       round(wpb_meter.avg), round(bsz_meter.avg), lr,
                       clip_meter.avg * 100, gnorm_meter.avg))
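
The training and validation loops above rely on two small meter helpers, AverageMeter and TimeMeter. A minimal sketch of the behaviour these examples assume (not necessarily the original implementation):

import time

class AverageMeter:
    """Running (weighted) average, as used for loss/words-per-batch stats."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # `n` weights the update, e.g. a loss averaged over `ntokens`.
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else 0.0

class TimeMeter:
    """Rate over wall-clock time, as used for words-per-second stats."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def elapsed_time(self):
        return time.time() - self.start

    @property
    def avg(self):
        return self.n / self.elapsed_time

Usage mirrors the loops above: after processing a batch, call wps_meter.update(ntokens); wps_meter.avg then reports tokens per second since the last reset(), which is why the first mini-batch is excluded by resetting the meter.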