Ejemplo n.º 1
0
def skyline_iteration_provider(model):
    """
    Builds the optimizer, learning rate scheduler and fp32 optimizer wrapper
    for 'model' and returns a closure that runs one training iteration.

    All hyper-parameters are read from the namespace returned by get_args().

    :param model: model to train; called as model(src, src_len, tgt, tgt_len)
        and expected to return the loss
    :return: function iteration(src, src_len, tgt, tgt_len) that runs the
        forward pass, the backward pass and one optimizer step
    """
    args = get_args()

    # Collect keyword arguments for the optimizer and for the scheduler.
    optimizer_name = args.optimizer
    optimizer_kwargs = {'lr': args.lr}
    scheduler_kwargs = {
        name: getattr(args, name)
        for name in ('warmup_steps', 'remain_steps', 'decay_interval',
                     'decay_steps', 'decay_factor')
    }

    # Hard-coded number of batches produced by the training data loader.
    train_loader_len = 437268
    iters_per_epoch = train_loader_len // args.train_iter_size
    total_train_iters = iters_per_epoch * args.epochs

    optimizer_cls = torch.optim.__dict__[optimizer_name]
    optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
    scheduler = WarmupMultiStepLR(optimizer, total_train_iters,
                                  **scheduler_kwargs)
    fp_optimizer = Fp32Optimizer(model, args.grad_clip)

    def iteration(src, src_len, tgt, tgt_len):
        # forward pass -> backward pass -> weight update
        loss = model(src, src_len, tgt, tgt_len)
        loss.backward()
        fp_optimizer.step(optimizer, scheduler, update=True)

    return iteration
Ejemplo n.º 2
0
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 scheduler_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 batch_first=False,
                 save_info=None,
                 save_path='.',
                 train_iterations=0,
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 cuda=True,
                 distributed=False,
                 intra_epoch_eval=0,
                 iter_size=1,
                 translator=None,
                 verbose=False):
        """
        Constructor for the Seq2SeqTrainer.

        :param model: model to train
        :param criterion: criterion (loss function)
        :param opt_config: dictionary with options for the optimizer; must
            contain the key 'optimizer' (name of a class in torch.optim), the
            remaining items are passed to the optimizer as keyword arguments.
            The dictionary is copied, the caller's object is not modified
        :param scheduler_config: dictionary with options for the learning rate
            scheduler
        :param print_freq: prints short summary every 'print_freq' iterations
        :param save_freq: saves checkpoint every 'save_freq' iterations
        :param grad_clip: coefficient for gradient clipping
        :param batch_first: if True the model uses (batch,seq,feature) tensors,
            if false the model uses (seq, batch, feature)
        :param save_info: dict with additional state stored in each checkpoint;
            None (default) creates a new empty dict for this instance
        :param save_path: path to the directory for checkpoints
        :param train_iterations: total number of training iterations to execute
        :param checkpoint_filename: name of files with checkpoints
        :param keep_checkpoints: max number of checkpoints to keep
        :param math: arithmetic type, 'fp16' or 'fp32'
        :param cuda: if True use cuda, if False train on cpu
        :param distributed: if True run distributed training
        :param intra_epoch_eval: number of additional eval runs within each
            training epoch
        :param iter_size: number of iterations between weight updates
        :param translator: instance of Translator, runs inference on test set
        :param verbose: enables verbose logging
        :raises ValueError: if 'math' is neither 'fp16' nor 'fp32'
        """
        # Fail fast: an unknown math mode used to surface much later as an
        # UnboundLocalError on 'params'.
        if math not in ('fp16', 'fp32'):
            raise ValueError(
                f"Unsupported math mode: {math!r} (expected 'fp16' or 'fp32')")

        super().__init__()
        self.model = model
        self.criterion = criterion
        self.epoch = 0
        # A '{}' default would be a single dict shared by every trainer
        # created without an explicit 'save_info'.
        self.save_info = {} if save_info is None else save_info
        self.save_path = save_path
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        # Private copy: 'optimizer' is popped below and that must not leak
        # into the dictionary owned by the caller.
        self.opt_config = dict(opt_config)
        self.cuda = cuda
        self.distributed = distributed
        self.print_freq = print_freq
        self.batch_first = batch_first
        self.verbose = verbose
        self.loss = None
        self.translator = translator
        self.intra_epoch_eval = intra_epoch_eval
        self.iter_size = iter_size

        if cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        if math == 'fp16':
            self.model = self.model.half()

        if distributed:
            self.model = DDP(self.model)

        if math == 'fp16':
            # the optimizer updates the fp32 parameters exposed by the
            # Fp16Optimizer wrapper instead of the (half) model parameters
            self.fp_optimizer = Fp16Optimizer(self.model, grad_clip)
            params = self.fp_optimizer.fp32_params
        else:
            self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
            params = self.model.parameters()

        opt_name = self.opt_config.pop('optimizer')
        self.optimizer = torch.optim.__dict__[opt_name](params,
                                                        **self.opt_config)
        logging.info(f'Using optimizer: {self.optimizer}')

        self.scheduler = WarmupMultiStepLR(self.optimizer, train_iterations,
                                           **scheduler_config)
Ejemplo n.º 3
0
class Seq2SeqTrainer:
    """
    Seq2SeqTrainer

    Wraps a model, a criterion, an optimizer and a learning rate scheduler
    and provides the training loop (optimize), the validation loop (evaluate)
    and checkpoint handling (save / load).
    """
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 scheduler_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 batch_first=False,
                 save_info=None,
                 save_path='.',
                 train_iterations=0,
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 cuda=True,
                 distributed=False,
                 intra_epoch_eval=0,
                 iter_size=1,
                 translator=None,
                 verbose=False):
        """
        Constructor for the Seq2SeqTrainer.

        :param model: model to train
        :param criterion: criterion (loss function)
        :param opt_config: dictionary with options for the optimizer; must
            contain the key 'optimizer' (name of a class in torch.optim), the
            remaining items are passed to the optimizer as keyword arguments.
            The dictionary is copied, the caller's object is not modified
        :param scheduler_config: dictionary with options for the learning rate
            scheduler
        :param print_freq: prints short summary every 'print_freq' iterations
        :param save_freq: saves checkpoint every 'save_freq' iterations
        :param grad_clip: coefficient for gradient clipping
        :param batch_first: if True the model uses (batch,seq,feature) tensors,
            if false the model uses (seq, batch, feature)
        :param save_info: dict with additional state stored in each checkpoint;
            None (default) creates a new empty dict for this instance
        :param save_path: path to the directory for checkpoints
        :param train_iterations: total number of training iterations to execute
        :param checkpoint_filename: name of files with checkpoints
        :param keep_checkpoints: max number of checkpoints to keep
        :param math: arithmetic type, 'fp16' or 'fp32'
        :param cuda: if True use cuda, if False train on cpu
        :param distributed: if True run distributed training
        :param intra_epoch_eval: number of additional eval runs within each
            training epoch
        :param iter_size: number of iterations between weight updates
        :param translator: instance of Translator, runs inference on test set
        :param verbose: enables verbose logging
        :raises ValueError: if 'math' is neither 'fp16' nor 'fp32'
        """
        # Fail fast: an unknown math mode used to surface much later as an
        # UnboundLocalError on 'params'.
        if math not in ('fp16', 'fp32'):
            raise ValueError(
                f"Unsupported math mode: {math!r} (expected 'fp16' or 'fp32')")

        super().__init__()
        self.model = model
        self.criterion = criterion
        self.epoch = 0
        # feed_data() writes 'iteration' into this dict; a '{}' default would
        # be one object shared (and mutated) by every trainer instance.
        self.save_info = {} if save_info is None else save_info
        self.save_path = save_path
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        # Private copy: 'optimizer' is popped below and that must not leak
        # into the dictionary owned by the caller.
        self.opt_config = dict(opt_config)
        self.cuda = cuda
        self.distributed = distributed
        self.print_freq = print_freq
        self.batch_first = batch_first
        self.verbose = verbose
        self.loss = None
        self.translator = translator
        self.intra_epoch_eval = intra_epoch_eval
        self.iter_size = iter_size

        if cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        if math == 'fp16':
            self.model = self.model.half()

        if distributed:
            self.model = DDP(self.model)

        if math == 'fp16':
            # the optimizer updates the fp32 parameters exposed by the
            # Fp16Optimizer wrapper instead of the (half) model parameters
            self.fp_optimizer = Fp16Optimizer(self.model, grad_clip)
            params = self.fp_optimizer.fp32_params
        else:
            self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
            params = self.model.parameters()

        opt_name = self.opt_config.pop('optimizer')
        self.optimizer = torch.optim.__dict__[opt_name](params,
                                                        **self.opt_config)
        logging.info(f'Using optimizer: {self.optimizer}')

        self.scheduler = WarmupMultiStepLR(self.optimizer, train_iterations,
                                           **scheduler_config)

    def iterate(self, src, tgt, update=True, training=True):
        """
        Performs one iteration of the training/validation.

        :param src: pair (batch of examples from the source language,
            sequence lengths)
        :param tgt: pair (batch of examples from the target language,
            sequence lengths)
        :param update: if True: optimizer does update of the weights
        :param training: if True: executes optimizer
        :return: tuple (loss per target token, loss per sentence, dict with
            the number of 'src' and 'tgt' tokens in the batch)
        """
        src, src_length = src
        tgt, tgt_length = tgt
        src_length = torch.LongTensor(src_length)
        tgt_length = torch.LongTensor(tgt_length)

        # one token per target sequence is excluded because the labels are
        # the target shifted by one position (see tgt_labels below)
        num_toks = {}
        num_toks['tgt'] = int((tgt_length - 1).sum())
        num_toks['src'] = int(src_length.sum())

        if self.cuda:
            src = src.cuda()
            src_length = src_length.cuda()
            tgt = tgt.cuda()

        # the model is fed the target without its last position and is
        # scored against the target without its first position
        if self.batch_first:
            output = self.model(src, src_length, tgt[:, :-1])
            tgt_labels = tgt[:, 1:]
            T, B = output.size(1), output.size(0)
        else:
            output = self.model(src, src_length, tgt[:-1])
            tgt_labels = tgt[1:]
            T, B = output.size(0), output.size(1)

        loss = self.criterion(output.view(T * B, -1),
                              tgt_labels.contiguous().view(-1))

        # keep the unscaled value for reporting, then normalize the loss by
        # the batch size and the number of accumulation steps
        loss_per_batch = loss.item()
        loss /= (B * self.iter_size)

        if training:
            self.fp_optimizer.step(loss, self.optimizer, self.scheduler,
                                   update)

        loss_per_token = loss_per_batch / num_toks['tgt']
        loss_per_sentence = loss_per_batch / B

        return loss_per_token, loss_per_sentence, num_toks

    def feed_data(self, data_loader, training=True):
        """
        Runs training or validation on batches from data_loader.

        :param data_loader: data loader
        :param training: if True runs training else runs validation
        :return: tuple (average loss per token, average total tokens/s), both
            taken after the meters are reduced
        """
        if training:
            assert self.optimizer is not None
            # evenly spaced evaluation points inside the epoch (both ends
            # excluded), aligned to iterations that are multiples of iter_size
            eval_fractions = np.linspace(0, 1, self.intra_epoch_eval + 2)[1:-1]
            iters_with_update = len(data_loader) // self.iter_size
            eval_iters = (eval_fractions * iters_with_update).astype(int)
            eval_iters = eval_iters * self.iter_size
            eval_iters = set(eval_iters)

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses_per_token = AverageMeter(skip_first=False)
        losses_per_sentence = AverageMeter(skip_first=False)

        tot_tok_time = AverageMeter()
        src_tok_time = AverageMeter()
        tgt_tok_time = AverageMeter()

        batch_size = data_loader.batch_size

        end = time.time()
        for i, (src, tgt) in enumerate(data_loader):
            self.save_counter += 1
            # measure data loading time
            data_time.update(time.time() - end)

            # weights are updated on the last of every 'iter_size' iterations
            update = i % self.iter_size == self.iter_size - 1

            # do a train/evaluate iteration
            stats = self.iterate(src, tgt, update, training=training)
            loss_per_token, loss_per_sentence, num_toks = stats

            # measure accuracy and record loss
            losses_per_token.update(loss_per_token, num_toks['tgt'])
            losses_per_sentence.update(loss_per_sentence, batch_size)

            # measure elapsed time
            elapsed = time.time() - end
            batch_time.update(elapsed)
            src_tok_time.update(num_toks['src'] / elapsed)
            tgt_tok_time.update(num_toks['tgt'] / elapsed)
            tot_num_toks = num_toks['tgt'] + num_toks['src']
            tot_tok_time.update(tot_num_toks / elapsed)
            self.loss = losses_per_token.avg

            if training and i in eval_iters:
                test_bleu, _ = self.translator.run(calc_bleu=True,
                                                   epoch=self.epoch,
                                                   iteration=i)

                log = []
                log += [f'TRAIN [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'BLEU: {test_bleu:.2f}']
                log = '\t'.join(log)
                logging.info(log)

                # back to training mode after the intra-epoch evaluation
                self.model.train()
                self.preallocate(data_loader, training=True)

            if i % self.print_freq == 0:
                phase = 'TRAIN' if training else 'VALIDATION'
                log = []
                log += [f'{phase} [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})']
                log += [f'Data {data_time.val:.2e} ({data_time.avg:.2e})']
                log += [
                    f'Tok/s {tot_tok_time.val:.0f} ({tot_tok_time.avg:.0f})'
                ]
                if self.verbose:
                    log += [
                        f'Src tok/s {src_tok_time.val:.0f} ({src_tok_time.avg:.0f})'
                    ]
                    log += [
                        f'Tgt tok/s {tgt_tok_time.val:.0f} ({tgt_tok_time.avg:.0f})'
                    ]
                    log += [
                        f'Loss/sentence {losses_per_sentence.val:.1f} ({losses_per_sentence.avg:.1f})'
                    ]
                log += [
                    f'Loss/tok {losses_per_token.val:.4f} ({losses_per_token.avg:.4f})'
                ]
                if training:
                    lr = self.optimizer.param_groups[0]['lr']
                    log += [f'LR {lr:.3e}']
                log = '\t'.join(log)
                logging.info(log)

            save_chkpt = (self.save_counter %
                          self.save_freq) == (self.save_freq - 1)
            if training and save_chkpt:
                self.save_counter = 0
                self.save_info['iteration'] = i
                # checkpoint_counter cycles over range(keep_checkpoints), so
                # old files are overwritten; -1 is returned only when the
                # cycle is empty (keep_checkpoints == 0)
                identifier = next(self.checkpoint_counter, -1)
                if identifier != -1:
                    with sync_workers() as rank:
                        if rank == 0:
                            self.save(identifier=identifier)

            end = time.time()

        tot_tok_time.reduce('sum')
        losses_per_token.reduce('mean')

        return losses_per_token.avg, tot_tok_time.avg

    def preallocate(self, data_loader, training):
        """
        Generates maximum sequence length batch and runs forward and backward
        pass without updating model parameters.

        :param data_loader: data loader
        :param training: if True preallocates memory for backward pass
        """
        batch_size = data_loader.batch_size
        max_len = data_loader.dataset.max_len

        src_length = [max_len] * batch_size
        tgt_length = [max_len] * batch_size

        if self.batch_first:
            shape = (batch_size, max_len)
        else:
            shape = (max_len, batch_size)

        # synthetic batch: every position holds the constant token id 4
        src = torch.full(shape, 4, dtype=torch.int64)
        tgt = torch.full(shape, 4, dtype=torch.int64)
        src = src, src_length
        tgt = tgt, tgt_length
        self.iterate(src, tgt, update=False, training=training)
        # discard the gradients produced by the synthetic batch
        self.model.zero_grad()

    def optimize(self, data_loader):
        """
        Sets model in training mode, preallocates memory and runs training on
        data provided by data_loader.

        :param data_loader: data loader
        :return: output of feed_data (average loss per token, average total
            tokens/s)
        """
        torch.set_grad_enabled(True)
        self.model.train()
        torch.cuda.empty_cache()
        self.preallocate(data_loader, training=True)
        output = self.feed_data(data_loader, training=True)
        self.model.zero_grad()
        torch.cuda.empty_cache()
        return output

    def evaluate(self, data_loader):
        """
        Sets model in eval mode, disables gradients, preallocates memory and
        runs validation on data provided by data_loader.

        Gradient computation stays disabled after this method returns; it is
        re-enabled by the next call to optimize().

        :param data_loader: data loader
        :return: output of feed_data (average loss per token, average total
            tokens/s)
        """
        torch.set_grad_enabled(False)
        self.model.eval()
        torch.cuda.empty_cache()
        self.preallocate(data_loader, training=False)
        output = self.feed_data(data_loader, training=False)
        self.model.zero_grad()
        torch.cuda.empty_cache()
        return output

    def load(self, filename):
        """
        Loads checkpoint from filename.

        Restores model weights, optimizer state, scheduler state, epoch and
        loss. If 'filename' is not an existing file an error is logged and
        the trainer is left unchanged.

        :param filename: path to the checkpoint file
        """
        if os.path.isfile(filename):
            checkpoint = torch.load(filename, map_location={'cuda:0': 'cpu'})
            if self.distributed:
                # the distributed wrapper keeps the real model in '.module'
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.fp_optimizer.initialize_model(self.model)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.epoch = checkpoint['epoch']
            self.loss = checkpoint['loss']
            logging.info(f'Loaded checkpoint {filename} (epoch {self.epoch})')
        else:
            logging.error(f'Invalid checkpoint: {filename}')

    def save(self, identifier=None, is_best=False, save_all=False):
        """
        Stores checkpoint to a file.

        The three options are independent, the same state may be written to
        up to three files in 'save_path'.

        :param identifier: identifier for periodic checkpoint
        :param is_best: if True stores checkpoint to 'model_best.pth'
        :param save_all: if True stores checkpoint after completed training
            epoch
        """
        def write_checkpoint(state, filename):
            filename = os.path.join(self.save_path, filename)
            logging.info(f'Saving model to {filename}')
            torch.save(state, filename)

        if self.distributed:
            # the distributed wrapper keeps the real model in '.module'
            model_state = self.model.module.state_dict()
        else:
            model_state = self.model.state_dict()

        state = {
            'epoch': self.epoch,
            'state_dict': model_state,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'loss': getattr(self, 'loss', None),
        }
        # entries from save_info win over the built-in keys on collision
        state = {**state, **self.save_info}

        if identifier is not None:
            filename = self.checkpoint_filename % identifier
            write_checkpoint(state, filename)

        if is_best:
            filename = 'model_best.pth'
            write_checkpoint(state, filename)

        if save_all:
            filename = f'checkpoint_epoch_{self.epoch:03d}.pth'
            write_checkpoint(state, filename)
Ejemplo n.º 4
0
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 scheduler_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 save_info=None,
                 save_dir='.',
                 train_iterations=0,
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 loss_scaling=None,
                 intra_epoch_eval=0,
                 prealloc_mode='always',
                 iter_size=1,
                 translator=None,
                 verbose=False):
        """
        Constructor for the Seq2SeqTrainer.

        :param model: model to train
        :param criterion: criterion (loss function)
        :param opt_config: dictionary with options for the optimizer; must
            contain the key 'optimizer' (name of a class in torch.optim), the
            remaining items are passed to the optimizer as keyword arguments.
            The dictionary is copied, the caller's object is not modified
        :param scheduler_config: dictionary with options for the learning rate
            scheduler
        :param print_freq: prints short summary every 'print_freq' iterations
        :param save_freq: saves checkpoint every 'save_freq' iterations
        :param grad_clip: coefficient for gradient clipping
        :param save_info: dict with additional state stored in each checkpoint;
            None (default) creates a new empty dict for this instance
        :param save_dir: path to the directory for checkpoints
        :param train_iterations: total number of training iterations to execute
        :param checkpoint_filename: name of files with checkpoints
        :param keep_checkpoints: max number of checkpoints to keep
        :param math: arithmetic type; 'manual_fp16', 'fp32' and 'fp16' are
            handled here
        :param loss_scaling: options for dynamic loss scaling; the keys
            'init_scale' and 'upscale_interval' are read when math is
            'manual_fp16' or 'fp16'. None (default) means an empty dict
        :param intra_epoch_eval: number of additional eval runs within each
            training epoch
        :param prealloc_mode: controls preallocation,
            choices=['off', 'once', 'always']
        :param iter_size: number of iterations between weight updates
        :param translator: instance of Translator, runs inference on test set
        :param verbose: enables verbose logging
        """
        super().__init__()
        # '{}' defaults would be single dicts shared by every trainer created
        # without explicit arguments.
        if loss_scaling is None:
            loss_scaling = {}

        self.model = model
        self.criterion = criterion
        self.epoch = 0
        self.save_info = {} if save_info is None else save_info
        self.save_dir = save_dir
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        # Private copy: 'optimizer' is popped below and that must not leak
        # into the dictionary owned by the caller.
        self.opt_config = dict(opt_config)
        self.device = next(model.parameters()).device
        self.print_freq = print_freq
        self.verbose = verbose
        self.loss = None
        self.translator = translator
        self.intra_epoch_eval = intra_epoch_eval
        self.iter_size = iter_size
        self.prealloc_mode = prealloc_mode
        self.preallocated = False

        self.distributed = torch.distributed.is_initialized()
        self.batch_first = model.batch_first

        params = self.model.parameters()

        # NOTE(review): for a 'math' value other than 'manual_fp16', 'fp32'
        # or 'fp16' no fp_optimizer is created -- confirm whether such values
        # should be rejected here.
        if math == 'manual_fp16':
            self.fp_optimizer = FP16Optimizer(
                self.model,
                grad_clip,
                loss_scale=loss_scaling['init_scale'],
                dls_upscale_interval=loss_scaling['upscale_interval'])
            # the optimizer updates the fp32 parameters exposed by the
            # FP16Optimizer wrapper instead of the model parameters
            params = self.fp_optimizer.fp32_params
        elif math == 'fp32':
            self.fp_optimizer = FP32Optimizer(self.model, grad_clip)

        opt_name = self.opt_config.pop('optimizer')
        self.optimizer = torch.optim.__dict__[opt_name](params,
                                                        **self.opt_config)
        logging.info(f'Using optimizer: {self.optimizer}')

        self.scheduler = WarmupMultiStepLR(self.optimizer, train_iterations,
                                           **scheduler_config)

        if math == 'fp16':
            # amp.initialize returns the patched model and optimizer, both
            # are replaced on the trainer
            self.model, self.optimizer = amp.initialize(
                self.model,
                self.optimizer,
                cast_model_outputs=torch.float16,
                keep_batchnorm_fp32=False,
                opt_level='O2')

            self.fp_optimizer = AMPOptimizer(
                self.model,
                grad_clip,
                loss_scale=loss_scaling['init_scale'],
                dls_upscale_interval=loss_scaling['upscale_interval'])

        # wrapping happens last, after the optimizer wrappers received the
        # unwrapped model
        if self.distributed:
            self.model = DistributedDataParallel(self.model)
Ejemplo n.º 5
0
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 scheduler_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 batch_first=False,
                 save_info=None,
                 save_path='.',
                 train_iterations=0,
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 cuda=True,
                 distributed=False,
                 intra_epoch_eval=0,
                 iter_size=1,
                 translator=None,
                 verbose=False):
        """
        Constructor for the Seq2SeqTrainer.

        :param model: model to train
        :param criterion: criterion (loss function)
        :param opt_config: dictionary with options for the optimizer; must
            contain the key 'optimizer' (name of a class in torch.optim), the
            remaining items are passed to the optimizer as keyword arguments.
            The dictionary is copied, the caller's object is not modified
        :param scheduler_config: dictionary with keyword arguments for
            WarmupMultiStepLR
        :param print_freq: stored as self.print_freq
        :param save_freq: stored as self.save_freq
        :param grad_clip: passed to the Fp16Optimizer / Fp32Optimizer wrapper
        :param batch_first: stored as self.batch_first
        :param save_info: dict stored as self.save_info; None (default)
            creates a new empty dict for this instance
        :param save_path: stored as self.save_path
        :param train_iterations: passed to WarmupMultiStepLR
        :param checkpoint_filename: stored as self.checkpoint_filename
        :param keep_checkpoints: length of the cycle of checkpoint identifiers
        :param math: arithmetic type, 'fp16' or 'fp32'
        :param cuda: if True model and criterion are moved to cuda
        :param distributed: if True the model is wrapped in DDP
        :param intra_epoch_eval: stored as self.intra_epoch_eval
        :param iter_size: stored as self.iter_size
        :param translator: stored as self.translator
        :param verbose: stored as self.verbose
        :raises ValueError: if 'math' is neither 'fp16' nor 'fp32'
        """
        # Fail fast: an unknown math mode used to surface much later as an
        # UnboundLocalError on 'params'.
        if math not in ('fp16', 'fp32'):
            raise ValueError(
                f"Unsupported math mode: {math!r} (expected 'fp16' or 'fp32')")

        super().__init__()
        self.model = model
        self.criterion = criterion
        self.epoch = 0
        # A '{}' default would be a single dict shared by every trainer
        # created without an explicit 'save_info'.
        self.save_info = {} if save_info is None else save_info
        self.save_path = save_path
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        # Private copy: 'optimizer' is popped below and that must not leak
        # into the dictionary owned by the caller.
        self.opt_config = dict(opt_config)
        self.cuda = cuda
        self.distributed = distributed
        self.print_freq = print_freq
        self.batch_first = batch_first
        self.verbose = verbose
        self.loss = None
        self.translator = translator
        self.intra_epoch_eval = intra_epoch_eval
        self.iter_size = iter_size

        if cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        if math == 'fp16':
            self.model = self.model.half()

        if distributed:
            self.model = DDP(self.model)

        if math == 'fp16':
            # the optimizer updates the fp32 parameters exposed by the
            # Fp16Optimizer wrapper instead of the (half) model parameters
            self.fp_optimizer = Fp16Optimizer(self.model, grad_clip)
            params = self.fp_optimizer.fp32_params
        else:
            self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
            params = self.model.parameters()

        opt_name = self.opt_config.pop('optimizer')
        self.optimizer = torch.optim.__dict__[opt_name](params,
                                                        **self.opt_config)
        logging.info(f'Using optimizer: {self.optimizer}')

        self.scheduler = WarmupMultiStepLR(self.optimizer, train_iterations,
                                           **scheduler_config)
Ejemplo n.º 6
0
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 scheduler_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 batch_first=False,
                 save_info=None,
                 save_path='.',
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 cuda=True,
                 distributed=False,
                 distributed_overlap_allreduce=False,
                 distributed_overlap_allreduce_messagesize=1e7,
                 intra_epoch_eval=0,
                 translator=None,
                 verbose=False,
                 arch="gnmt"):
        """
        Constructor for the Seq2SeqTrainer.

        :param model: model to train
        :param criterion: criterion (loss function)
        :param opt_config: dictionary with options for the optimizer; must
            contain the keys 'optimizer' ('FusedAdam' or the name of a class
            in torch.optim) and 'lr', all items except 'optimizer' are passed
            to the optimizer as keyword arguments. The dictionary is copied,
            the caller's object is not modified
        :param scheduler_config: dictionary with the keys 'lr_method',
            'warmup_iters', 'remain_steps' and 'decay_steps', passed to
            WarmupMultiStepLR
        :param print_freq: stored as self.print_freq
        :param save_freq: stored as self.save_freq
        :param grad_clip: passed to the Fp16Optimizer / Fp32Optimizer wrapper
        :param batch_first: stored as self.batch_first
        :param save_info: dict stored as self.save_info; None (default)
            creates a new empty dict for this instance
        :param save_path: stored as self.save_path
        :param checkpoint_filename: stored as self.checkpoint_filename
        :param keep_checkpoints: length of the cycle of checkpoint identifiers
        :param math: arithmetic type, 'fp16' or 'fp32'
        :param cuda: if True model and criterion are moved to cuda
        :param distributed: if True the model is wrapped in
            apex.parallel.DistributedDataParallel
        :param distributed_overlap_allreduce: negated and passed as
            'delay_allreduce' to the distributed wrapper
        :param distributed_overlap_allreduce_messagesize: passed as
            'message_size' to the distributed wrapper
        :param intra_epoch_eval: stored as self.intra_epoch_eval
        :param translator: stored as self.translator
        :param verbose: stored as self.verbose
        :param arch: stored as self.arch
        :raises ValueError: if 'math' is neither 'fp16' nor 'fp32'
        """
        # Fail fast: an unknown math mode used to surface much later as an
        # UnboundLocalError on 'params'.
        if math not in ('fp16', 'fp32'):
            raise ValueError(
                f"Unsupported math mode: {math!r} (expected 'fp16' or 'fp32')")

        super().__init__()
        self.model = model
        self.criterion = criterion
        self.epoch = 0
        # A '{}' default would be a single dict shared by every trainer
        # created without an explicit 'save_info'.
        self.save_info = {} if save_info is None else save_info
        self.save_path = save_path
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        # Private copy: 'optimizer' is popped below and that must not leak
        # into the dictionary owned by the caller.
        self.opt_config = dict(opt_config)
        self.cuda = cuda
        self.distributed = distributed
        self.print_freq = print_freq
        self.batch_first = batch_first
        self.verbose = verbose
        self.loss = None
        self.translator = translator
        self.intra_epoch_eval = intra_epoch_eval
        self.arch = arch

        # only forwarded to the distributed wrapper in the fp16 branch
        self.retain_allreduce_buffers = True
        self.gradient_average = False

        if cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        if math == 'fp16':
            self.model = self.model.half()
            if distributed:
                self.model = apex.parallel.DistributedDataParallel(
                    self.model,
                    message_size=distributed_overlap_allreduce_messagesize,
                    delay_allreduce=(not distributed_overlap_allreduce),
                    retain_allreduce_buffers=self.retain_allreduce_buffers,
                    gradient_average=self.gradient_average)
            self.fp_optimizer = Fp16Optimizer(self.model, grad_clip)
            # NOTE(review): fp32_params is wrapped in a list here --
            # presumably it is a single tensor; confirm against Fp16Optimizer
            params = [self.fp_optimizer.fp32_params]
        else:
            if distributed:
                self.model = apex.parallel.DistributedDataParallel(
                    self.model,
                    message_size=distributed_overlap_allreduce_messagesize,
                    delay_allreduce=(not distributed_overlap_allreduce))
            self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
            params = self.model.parameters()

        opt_name = self.opt_config.pop('optimizer')
        if opt_name == 'FusedAdam':
            self.optimizer = apex.optimizers.FusedAdam(params,
                                                       **self.opt_config)
        else:
            self.optimizer = torch.optim.__dict__[opt_name](params,
                                                            **self.opt_config)

        # NOTE(review): the optimizer is always reported as ADAM and the
        # 'betas' / 'eps' defaults are read unconditionally, so this assumes
        # an Adam-style optimizer -- confirm that no other one is configured.
        gnmt_print(key=mlperf_log.OPT_NAME, value=mlperf_log.ADAM)
        gnmt_print(key=mlperf_log.OPT_LR, value=self.opt_config['lr'])
        gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA1,
                   value=self.optimizer.defaults['betas'][0])
        gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA2,
                   value=self.optimizer.defaults['betas'][1])
        gnmt_print(key=mlperf_log.OPT_HP_ADAM_EPSILON,
                   value=self.optimizer.defaults['eps'])

        self.scheduler = WarmupMultiStepLR(
            self.optimizer,
            lr_method=scheduler_config["lr_method"],
            warmup_iters=scheduler_config["warmup_iters"],
            remain_steps=scheduler_config["remain_steps"],
            decay_steps=scheduler_config["decay_steps"])

        logging.info(f'Using optimizer: {self.optimizer}')
Ejemplo n.º 7
0
class Seq2SeqTrainer:
    """
    Seq2SeqTrainer
    """
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 scheduler_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 batch_first=False,
                 save_info={},
                 save_path='.',
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 cuda=True,
                 distributed=False,
                 distributed_overlap_allreduce=False,
                 distributed_overlap_allreduce_messagesize=1e7,
                 intra_epoch_eval=0,
                 translator=None,
                 verbose=False,
                 arch="gnmt"):
        super(Seq2SeqTrainer, self).__init__()
        self.model = model
        self.criterion = criterion
        self.epoch = 0
        self.save_info = save_info
        self.save_path = save_path
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        self.opt_config = opt_config
        self.cuda = cuda
        self.distributed = distributed
        self.print_freq = print_freq
        self.batch_first = batch_first
        self.verbose = verbose
        self.loss = None
        self.translator = translator
        self.intra_epoch_eval = intra_epoch_eval
        self.arch = arch

        self.retain_allreduce_buffers = True
        self.gradient_average = False

        if cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        if math == 'fp16':
            self.model = self.model.half()
            if distributed:
                # self.model = apex.parallel.DistributedDataParallel(self.model, message_size=10000000, delay_allreduce=True)
                self.model = apex.parallel.DistributedDataParallel(
                    self.model,
                    message_size=distributed_overlap_allreduce_messagesize,
                    delay_allreduce=(not distributed_overlap_allreduce),
                    retain_allreduce_buffers=self.retain_allreduce_buffers,
                    gradient_average=self.gradient_average)
            self.fp_optimizer = Fp16Optimizer(self.model, grad_clip)
            params = [self.fp_optimizer.fp32_params]
        elif math == 'fp32':
            if distributed:
                # self.model = apex.parallel.DistributedDataParallel(self.model, message_size=10000000, delay_allreduce=True)
                self.model = apex.parallel.DistributedDataParallel(
                    self.model,
                    message_size=distributed_overlap_allreduce_messagesize,
                    delay_allreduce=(not distributed_overlap_allreduce))
            self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
            params = self.model.parameters()

        opt_name = opt_config.pop('optimizer')
        if opt_name == 'FusedAdam':
            self.optimizer = apex.optimizers.FusedAdam(params, **opt_config)
        else:
            self.optimizer = torch.optim.__dict__[opt_name](params,
                                                            **opt_config)

        gnmt_print(key=mlperf_log.OPT_NAME, value=mlperf_log.ADAM)
        gnmt_print(key=mlperf_log.OPT_LR, value=opt_config['lr'])
        gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA1,
                   value=self.optimizer.defaults['betas'][0])
        gnmt_print(key=mlperf_log.OPT_HP_ADAM_BETA2,
                   value=self.optimizer.defaults['betas'][1])
        gnmt_print(key=mlperf_log.OPT_HP_ADAM_EPSILON,
                   value=self.optimizer.defaults['eps'])

        self.scheduler = WarmupMultiStepLR(
            self.optimizer,
            lr_method=scheduler_config["lr_method"],
            warmup_iters=scheduler_config["warmup_iters"],
            remain_steps=scheduler_config["remain_steps"],
            decay_steps=scheduler_config["decay_steps"])

        logging.info(f'Using optimizer: {self.optimizer}')

    def iterate(self, src, tgt, update=True, training=True):
        src, src_length = src
        tgt, tgt_length = tgt
        src_length = torch.LongTensor(src_length)
        tgt_length = torch.LongTensor(tgt_length)

        num_toks = {}
        num_toks['tgt'] = int(sum(tgt_length - 1))
        num_toks['src'] = int(sum(src_length))

        if self.cuda:
            src = src.cuda()
            src_length = src_length.cuda()
            tgt = tgt.cuda()

        if self.batch_first:
            output = self.model(src, src_length, tgt[:, :-1])
            tgt_labels = tgt[:, 1:]
            T, B = output.size(1), output.size(0)
        else:
            output = self.model(src, src_length, tgt[:-1])
            tgt_labels = tgt[1:]
            T, B = output.size(0), output.size(1)

        loss = self.criterion(output.view(T * B, -1),
                              tgt_labels.contiguous().view(-1))

        loss_per_batch = loss.item()
        loss /= B

        if training:
            self.fp_optimizer.step(loss, self.optimizer, self.scheduler,
                                   update)

        loss_per_token = loss_per_batch / num_toks['tgt']
        loss_per_sentence = loss_per_batch / B

        return loss_per_token, loss_per_sentence, num_toks

    def feed_data(self, data_loader, training=True):
        """
        Runs training or validation on batches from data_loader while
        profiling the whitelisted modules.

        The first batch is used to build a ``torchsummary`` model summary.
        Every iteration then runs under ``torchprofiler.Profiling``; the loop
        stops once the iteration index reaches NUM_STEPS_TO_PROFILE. The
        per-layer forward/backward times averaged over the profiled
        iterations are printed, attached to the matching summary entries and,
        when training, passed to ``create_graph`` (output location
        ``profiles/<self.arch>``).

        :param data_loader: data loader
        :param training: if True runs training else runs validation
        :returns: tuple of (average loss per token, average tokens per
            second) taken from the meters after reduction
        :raises ValueError: if data_loader yields no batches
        """
        if training:
            assert self.optimizer is not None
            # iteration indices at which an intra-epoch evaluation is run
            eval_fractions = np.linspace(0, 1, self.intra_epoch_eval + 2)[1:-1]
            eval_iters = (eval_fractions * len(data_loader)).astype(int)
            eval_iters = set(eval_iters)

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses_per_token = AverageMeter()
        losses_per_sentence = AverageMeter()

        tot_tok_time = AverageMeter()
        src_tok_time = AverageMeter()
        tgt_tok_time = AverageMeter()

        batch_size = data_loader.batch_size
        layer_timestamps = []
        # the per-layer timing table is always printed, independently of
        # self.verbose (which only controls the per-iteration log line)
        verbose = True

        module_whitelist = ["EmuBidirLSTM", "RecurrentAttention", "Classifier"]

        # Build the model summary from the first batch only.
        first_batch = next(iter(data_loader), None)
        if first_batch is None:
            raise ValueError('data_loader yielded no batches to profile')
        src, tgt = first_batch
        (src, src_length) = src
        (tgt, tgt_length) = tgt
        src_length = torch.LongTensor(src_length)
        if self.cuda:
            # honour the device choice made in the constructor
            src_length = src_length.cuda()
            src = src.cuda()
            tgt = tgt.cuda()
        # same decoder input slicing as in iterate()
        decoder_input = tgt[:, :-1] if self.batch_first else tgt[:-1]
        model_input = (src, src_length, decoder_input)
        summary = torchsummary.summary(model=self.model,
                                       module_whitelist=module_whitelist,
                                       model_input=model_input,
                                       verbose=True)

        end = time.time()
        NUM_STEPS_TO_PROFILE = 100  # profile 100 steps
        for i, (src, tgt) in enumerate(data_loader):
            self.save_counter += 1
            # measure data loading time
            data_time.update(time.time() - end)

            with torchprofiler.Profiling(self.model, module_whitelist) as p:
                # do a train/evaluate iteration
                stats = self.iterate(src, tgt, training=training)
                loss_per_token, loss_per_sentence, num_toks = stats
            print(str(p))
            layer_timestamps.append(p.processed_times())

            # measure accuracy and record loss
            losses_per_token.update(loss_per_token, num_toks['tgt'])
            losses_per_sentence.update(loss_per_sentence, batch_size)

            # measure elapsed time
            elapsed = time.time() - end
            batch_time.update(elapsed)
            src_tok_time.update(num_toks['src'] / elapsed)
            tgt_tok_time.update(num_toks['tgt'] / elapsed)
            tot_num_toks = num_toks['tgt'] + num_toks['src']
            tot_tok_time.update(tot_num_toks / elapsed)
            self.loss = losses_per_token.avg

            if training and i in eval_iters:
                test_bleu, _ = self.translator.run(calc_bleu=True,
                                                   epoch=self.epoch,
                                                   iteration=i)

                log = []
                log += [f'TRAIN [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'BLEU: {test_bleu:.2f}']
                log = '\t'.join(log)
                logging.info(log)

                # switch back to training mode after the evaluation
                self.model.train()
                self.preallocate(data_loader, training=True)

            if i % self.print_freq == 0:
                phase = 'TRAIN' if training else 'VALIDATION'
                log = []
                log += [f'{phase} [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})']
                log += [f'Data {data_time.val:.5f} ({data_time.avg:.5f})']
                log += [
                    f'Tok/s {tot_tok_time.val:.0f} ({tot_tok_time.avg:.0f})'
                ]
                if self.verbose:
                    log += [
                        f'Src tok/s {src_tok_time.val:.0f} ({src_tok_time.avg:.0f})'
                    ]
                    log += [
                        f'Tgt tok/s {tgt_tok_time.val:.0f} ({tgt_tok_time.avg:.0f})'
                    ]
                    log += [
                        f'Loss/sentence {losses_per_sentence.val:.1f} ({losses_per_sentence.avg:.1f})'
                    ]
                log += [
                    f'Loss/tok {losses_per_token.val:.4f} ({losses_per_token.avg:.4f})'
                ]
                lr = [
                    param_group['lr']
                    for param_group in self.optimizer.param_groups
                ]
                log += [f'Learning Rate {lr}']
                log = '\t'.join(log)
                logging.info(log)

            # stop once enough iterations were profiled
            if i >= NUM_STEPS_TO_PROFILE:
                break

            save_chkpt = (self.save_counter %
                          self.save_freq) == (self.save_freq - 1)
            if training and save_chkpt:
                self.save_counter = 0
                self.save_info['iteration'] = i
                identifier = next(self.checkpoint_counter, -1)
                if identifier != -1:
                    with sync_workers() as rank:
                        if rank == 0:
                            self.save(identifier=identifier)

            end = time.time()

        if not layer_timestamps:
            raise ValueError('data_loader yielded no batches to profile')

        if verbose:
            print(
                "\n==========================================================")
            print("Layer Type    Forward Time (ms)    Backward Time (ms)")
            print("==========================================================")

        # Average the forward/backward time of every profiled layer over all
        # profiled iterations; the layer list of the first iteration is the
        # reference.
        num_profiled = len(layer_timestamps)
        tot_accounted_time = 0.0
        per_layer_times = []
        for layer_idx, reference_entry in enumerate(layer_timestamps[0]):
            layer_type = str(reference_entry[0])
            layer_forward_time_sum = 0.0
            layer_backward_time_sum = 0.0
            for step_timestamps in layer_timestamps:
                layer_forward_time_sum += (step_timestamps[layer_idx][2] /
                                           1000)
                layer_backward_time_sum += (step_timestamps[layer_idx][5] /
                                            1000)
            layer_entry = (layer_type,
                           layer_forward_time_sum / num_profiled,
                           layer_backward_time_sum / num_profiled)
            per_layer_times.append(layer_entry)
            if verbose:
                print(layer_entry[0], layer_entry[1], layer_entry[2])
            tot_accounted_time += (layer_entry[1] + layer_entry[2])

        print("Total accounted time: %.3f ms" % tot_accounted_time)

        # Attach each measurement to the first summary entry with the same
        # layer name that has not received timings yet.
        for layer_type, forward_time, backward_time in per_layer_times:
            for summary_elem in summary:
                if str(summary_elem['layer_name']) != layer_type:
                    continue
                if 'forward_time' in summary_elem and 'backward_time' in summary_elem:
                    continue
                summary_elem['forward_time'] = forward_time
                summary_elem['backward_time'] = backward_time
                break

        if training:
            # src/tgt hold the last batch yielded by the profiling loop
            create_graph(self.model, module_whitelist, (src, tgt), summary,
                         os.path.join("profiles", self.arch))

        tot_tok_time.reduce('sum')
        losses_per_token.reduce('mean')

        return losses_per_token.avg, tot_tok_time.avg

    def preallocate(self, data_loader, training):
        """
        Generates maximum sequence length batch and runs forward and backward
        pass without updating model parameters.

        :param data_loader: data loader
        :param training: if True preallocates memory for backward pass
        """
        batch_size = data_loader.batch_size
        max_len = data_loader.dataset.max_len

        src_length = [max_len] * batch_size
        tgt_length = [max_len] * batch_size

        if self.batch_first:
            shape = (batch_size, max_len)
        else:
            shape = (max_len, batch_size)

        src = torch.full(shape, 4, dtype=torch.int64)
        tgt = torch.full(shape, 4, dtype=torch.int64)
        src = src, src_length
        tgt = tgt, tgt_length
        self.iterate(src, tgt, update=False, training=training)

    def optimize(self, data_loader):
        """
        Sets model in training mode, preallocates memory and runs training on
        data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(True)
        self.model.train()
        torch.cuda.empty_cache()
        self.preallocate(data_loader, training=True)
        output = self.feed_data(data_loader, training=True)
        torch.cuda.empty_cache()
        return output

    def evaluate(self, data_loader):
        """
        Sets model in eval mode, disables gradients, preallocates memory and
        runs validation on data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(False)
        self.model.eval()
        torch.cuda.empty_cache()
        self.preallocate(data_loader, training=False)
        output = self.feed_data(data_loader, training=False)
        torch.cuda.empty_cache()
        return output

    def load(self, filename):
        """
        Loads checkpoint from filename.

        :param filename: path to the checkpoint file
        """
        if os.path.isfile(filename):
            checkpoint = torch.load(filename, map_location={'cuda:0': 'cpu'})
            self.model.load_state_dict(checkpoint['state_dict'])
            self.fp_optimizer.initialize_model(self.model)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.epoch = checkpoint['epoch']
            self.loss = checkpoint['loss']
            logging.info(f'Loaded checkpoint {filename} (epoch {self.epoch})')
        else:
            logging.error(f'Invalid checkpoint: {filename}')

    def save(self, identifier=None, is_best=False, save_all=False):
        """
        Stores checkpoint to a file.

        :param identifier: identifier for periodic checkpoint
        :param is_best: if True stores checkpoint to 'model_best.pth'
        :param save_all: if True stores checkpoint after completed training
            epoch
        """
        def write_checkpoint(state, filename):
            filename = os.path.join(self.save_path, filename)
            logging.info(f'Saving model to {filename}')
            torch.save(state, filename)

        state = {
            'epoch': self.epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'loss': getattr(self, 'loss', None),
        }
        state = dict(list(state.items()) + list(self.save_info.items()))

        if identifier is not None:
            filename = self.checkpoint_filename % identifier
            write_checkpoint(state, filename)

        if is_best:
            filename = 'model_best.pth'
            write_checkpoint(state, filename)

        if save_all:
            filename = f'checkpoint_epoch_{self.epoch:03d}.pth'
            write_checkpoint(state, filename)
Ejemplo n.º 8
0
def main():
    """
    Launches data-parallel multi-gpu training.

    Parses the command line arguments, initializes the device, distributed
    backend, logging and random seeds, builds the tokenizer, the GNMT model,
    the criterion and the trainer, creates the data loaders and the
    translator, optionally resumes from a checkpoint and then runs the epoch
    loop. MLPerf compliance events are emitted around initialization, every
    epoch, every evaluation and at the end of the run.
    """

    # keep MLPerf compliance records from propagating to parent loggers
    mlperf_compliance.mlperf_log.LOGGER.propagate = False

    mlperf_compliance.mlperf_log.setdefault(
        root_dir=os.path.dirname(os.path.abspath(__file__)),
        benchmark=mlperf_compliance.constants.GNMT,
        stack_offset=1,
        extra_print=False
        )

    mlperf_print(key=mlperf_compliance.constants.INIT_START,
                 log_all_ranks=True)

    args = parse_args()
    device = utils.set_device(args.cuda, args.local_rank)
    distributed = utils.init_distributed(args.cuda)

    # preinit and warmup streams/ groups for apex DDP communicators
    allreduce_communicators=None
    if distributed and args.apex_num_allreduce_streams > 1:
        bucket_pgs = [torch.distributed.new_group() for _ in range(args.apex_num_allreduce_streams)]
        bucket_streams = [torch.cuda.Stream() for _ in range(args.apex_num_allreduce_streams)]
        # run one small all_reduce per (process group, stream) pair
        for pg, stream in zip(bucket_pgs,bucket_streams):
            with torch.cuda.stream(stream):
                torch.distributed.all_reduce(torch.cuda.FloatTensor(1), group=pg)
        allreduce_communicators=(bucket_pgs,bucket_streams)

    args.rank = utils.get_rank()

    if not args.cudnn:
        torch.backends.cudnn.enabled = False

    # create directory for results
    save_path = os.path.join(args.results_dir, args.save)
    args.save_path = save_path
    os.makedirs(save_path, exist_ok=True)

    # setup logging
    log_filename = f'log_rank_{utils.get_rank()}.log'
    utils.setup_logging(args.log_all_ranks,
                        os.path.join(save_path, log_filename))

    if args.env:
        utils.log_env_info()

    logging.info(f'Saving results to: {save_path}')
    logging.info(f'Run arguments: {args}')

    # automatically set train_iter_size based on train_global_batch_size,
    # world_size and per-worker train_batch_size
    if args.train_global_batch_size is not None:
        global_bs = args.train_global_batch_size
        bs = args.train_batch_size
        world_size = utils.get_world_size()
        # the global batch must split evenly across workers and iterations
        assert global_bs % (bs * world_size) == 0
        args.train_iter_size = global_bs // (bs * world_size)
        logging.info(f'Global batch size was set in the config, '
                     f'Setting train_iter_size to {args.train_iter_size}')
    # setup L2 promotion
    if args.cuda:
        utils.l2_promote()

    # one worker seed per rank is used to seed torch on this process; the
    # shuffling seeds are handed to the training data loader below
    worker_seeds, shuffling_seeds = utils.setup_seeds(args.seed, args.epochs,
                                                      device)
    worker_seed = worker_seeds[args.rank]
    logging.info(f'Worker {args.rank} is using worker seed: {worker_seed}')
    torch.manual_seed(worker_seed)

    # build tokenizer
    # https://github.com/mlperf/policies/issues/201
    pad_vocab = utils.pad_vocabulary(args.math)
    tokenizer = Tokenizer(os.path.join(args.dataset_dir, config.VOCAB_FNAME),
                          pad_vocab)

    vocab_size = tokenizer.vocab_size

    # build GNMT model
    model_config = {'hidden_size': args.hidden_size,
                    'num_layers': args.num_layers,
                    'dropout': args.dropout, 'batch_first': False,
                    'share_embedding': args.share_embedding,
                    'fusion': args.fused_attention}
    model = GNMT(vocab_size=vocab_size, **model_config)
    logging.info(model)

    # define loss function (criterion) and optimizer
    criterion = build_criterion(vocab_size, config.PAD, args.smoothing,
                                args.fused_xentropy)

    # extra optimizer options are given on the command line as a Python
    # literal and merged into the optimizer config
    opt_config = {'optimizer': args.optimizer, 'lr': args.lr}
    opt_config.update(literal_eval(args.optimizer_extra))
    logging.info(f'Training optimizer config: {opt_config}')

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info(f'Number of parameters: {num_parameters}')

    # create trainer
    save_info = {'model_config': model_config, 'config': args, 'tokenizer':
                 tokenizer.get_state()}
    loss_scaling = {'init_scale': args.init_scale, 'upscale_interval':
                    args.upscale_interval}
    trainer_options = dict(
        criterion=criterion,
        grad_clip=args.grad_clip,
        iter_size=args.train_iter_size,
        save_path=save_path,
        save_freq=args.save_freq,
        save_info=save_info,
        opt_config=opt_config,
        batch_first=model.batch_first,
        keep_checkpoints=args.keep_checkpoints,
        math=args.math,
        loss_scaling=loss_scaling,
        print_freq=args.print_freq,
        cuda=args.cuda,
        distributed=distributed,
        distributed_overlap_allreduce=args.enable_apex_allreduce_overlap,
        distributed_overlap_num_allreduce_streams=args.apex_num_allreduce_streams,
        distributed_overlap_allreduce_messagesize=args.apex_message_size,
        distributed_overlap_allreduce_communicators=allreduce_communicators,
        intra_epoch_eval=args.intra_epoch_eval,
        prealloc_mode=args.prealloc_mode)

    trainer_options['model'] = model
    trainer = trainers.Seq2SeqTrainer(**trainer_options)

    # preallocation happens before INIT_STOP / RUN_START are logged
    trainer.preallocate(args.train_batch_size, args.max_length_train,
                        training=True)

    mlperf_print(key=mlperf_compliance.constants.INIT_STOP,
                 sync=True)
    mlperf_print(key=mlperf_compliance.constants.RUN_START,
                 sync=True)
    utils.barrier()

    mlperf_print(key=mlperf_compliance.constants.MAX_SEQUENCE_LENGTH,
                 value=args.max_length_train,
                 metadata={'method': 'discard'})

    # training data: either a preprocessed binary file or the raw parallel
    # text files tokenized through a LazyParallelDataset
    if args.use_preproc_data:
        train_data = PreprocessedDataset(
            min_len=args.min_length_train,
            max_len=args.max_length_train,
            vocab_size=tokenizer.vocab_size,
            )
        train_data.read_data(
            os.path.join(args.preproc_data_dir, 'training.bin'),
            tokenizer.vocab_size,
            )
        train_data.prepare()
    else:
        train_data = LazyParallelDataset(
            src_fname=os.path.join(args.dataset_dir, config.SRC_TRAIN_FNAME),
            tgt_fname=os.path.join(args.dataset_dir, config.TGT_TRAIN_FNAME),
            tokenizer=tokenizer,
            min_len=args.min_length_train,
            max_len=args.max_length_train,
            sort=False,
            max_size=args.max_size,
            )

    test_data = TextDataset(
        src_fname=os.path.join(args.dataset_dir, config.SRC_TEST_FNAME),
        tokenizer=tokenizer,
        min_len=args.min_length_test,
        max_len=args.max_length_test,
        sort=True)

    batching_opt = {'shard_size': args.shard_size,
                    'num_buckets': args.num_buckets}

    # get data loaders
    train_loader = train_data.get_loader(batch_size=args.train_batch_size,
                                         seeds=shuffling_seeds,
                                         batch_first=model.batch_first,
                                         shuffle=True,
                                         batching=args.batching,
                                         batching_opt=batching_opt,
                                         num_workers=args.train_loader_workers)

    mlperf_print(key=mlperf_compliance.constants.GLOBAL_BATCH_SIZE,
                 value=args.train_batch_size * utils.get_world_size(),
                 sync=False)

    test_loader = test_data.get_loader(batch_size=args.test_batch_size,
                                       batch_first=model.batch_first,
                                       shuffle=False,
                                       num_workers=args.test_loader_workers)

    translator = Translator(model=model,
                            tokenizer=tokenizer,
                            loader=test_loader,
                            beam_size=args.beam_size,
                            max_seq_len=args.max_length_test,
                            len_norm_factor=args.len_norm_factor,
                            len_norm_const=args.len_norm_const,
                            cov_penalty_factor=args.cov_penalty_factor,
                            cuda=args.cuda,
                            print_freq=args.print_freq,
                            dataset_dir=args.dataset_dir,
                            target_bleu=args.target_bleu,
                            save_path=args.save_path)

    # the schedule length depends on the loader size, so the scheduler is
    # created only now and attached to the already constructed trainer
    total_train_iters = len(train_loader) // args.train_iter_size * args.epochs

    scheduler_config = {'warmup_steps': args.warmup_steps,
                        'remain_steps': args.remain_steps,
                        'decay_interval': args.decay_interval,
                        'decay_steps': args.decay_steps,
                        'decay_factor': args.decay_factor}

    logging.info(f'Training LR schedule config: {scheduler_config}')
    scheduler = WarmupMultiStepLR(trainer.optimizer, total_train_iters,
                                  **scheduler_config)
    trainer.scheduler = scheduler
    trainer.translator = translator

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint_file = args.resume
        # a directory is resolved to the 'model_best.pth' file inside it
        if os.path.isdir(checkpoint_file):
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth')
        if os.path.isfile(checkpoint_file):
            trainer.load(checkpoint_file)
        else:
            logging.error(f'No checkpoint found at {args.resume}')

    # training loop
    # break_training / test_bleu are only updated when args.eval is set
    break_training = False
    test_bleu = None
    for epoch in range(args.start_epoch, args.epochs):
        # compliance events report 1-based epoch numbers
        mlperf_print(key=mlperf_compliance.constants.BLOCK_START,
                     metadata={'first_epoch_num': epoch + 1,
                               'epoch_count': 1},
                     sync=True)
        mlperf_print(key=mlperf_compliance.constants.EPOCH_START,
                     metadata={'epoch_num': epoch + 1},
                     sync=True)

        logging.info(f'Starting epoch {epoch}')
        train_loader.sampler.set_epoch(epoch)

        trainer.epoch = epoch
        train_loss, train_perf = trainer.optimize(train_loader)

        mlperf_print(key=mlperf_compliance.constants.EPOCH_STOP,
                     metadata={'epoch_num': epoch + 1},
                     sync=True)

        if args.eval:
            mlperf_print(key=mlperf_compliance.constants.EVAL_START,
                         metadata={'epoch_num': epoch + 1},
                         sync=True)
            test_bleu, break_training = translator.run(calc_bleu=True,
                                                       epoch=epoch)
            mlperf_print(key=mlperf_compliance.constants.EVAL_ACCURACY,
                         value=test_bleu,
                         metadata={'epoch_num': epoch + 1},
                         sync=False)
            mlperf_print(key=mlperf_compliance.constants.EVAL_STOP,
                         metadata={'epoch_num': epoch + 1},
                         sync=True)

        acc_log = []
        acc_log += [f'Summary: Epoch: {epoch}']
        acc_log += [f'Training Loss: {train_loss:.4f}']
        if args.eval:
            acc_log += [f'Test BLEU: {test_bleu:.2f}']

        perf_log = []
        perf_log += [f'Performance: Epoch: {epoch}']
        perf_log += [f'Training: {train_perf:.0f} Tok/s']

        # the epoch summary is logged by rank 0 only
        if args.rank == 0:
            logging.info('\t'.join(acc_log))
            logging.info('\t'.join(perf_log))

        logging.info(f'Finished epoch {epoch}')
        mlperf_print(key=mlperf_compliance.constants.BLOCK_STOP,
                     metadata={'first_epoch_num': epoch + 1},
                     sync=True)

        if break_training:
            break

    if args.use_preproc_data:
        train_data.finalize()

    # the run counts as successful only if translator.run() requested the
    # early stop; finishing all epochs without it is reported as 'aborted'
    status = 'success' if break_training else 'aborted'
    mlperf_print(key=mlperf_compliance.constants.RUN_STOP,
                 metadata={'status': status},
                 sync=True)