Example no. 1
def main(args):
    """
        Calculate loss and perplexity on the training and validation sets
    """
    logging.info('Commencing Validation!')
    torch.manual_seed(42)
    np.random.seed(42)

    utils.init_logging(args)

    # Load dictionaries [for each language]
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(
        split='train') if not args.train_on_tiny else load_data(
            split='tiny_train')
    valid_dataset = load_data(split='valid')

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    logging.info('Built a model with {:d} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum')

    if torch.cuda.is_available() and args.cuda:
        model = model.cuda()
    # Instantiate optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer)  # lr_scheduler


    # Note: validate() below takes the dataset directly, so this loader is unused here
    train_loader = \
        torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater,
                                    batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
                                                               0, shuffle=True, seed=42))

    # Calculate loss and perplexity on the training and validation sets
    # (a sketch of the validate helper follows this example)
    train_perplexity = validate(args, model, criterion, train_dataset, 0)

    valid_perplexity = validate(args, model, criterion, valid_dataset, 0)
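
Example no. 1 only calls validate; the helper itself is never shown. Below is a minimal sketch of what such a routine could look like, assuming the (args, model, criterion, dataset, epoch) signature used above, the same BatchSampler and utils helpers as the surrounding code, and perplexity computed as exp(summed cross-entropy / token count). The real validate (which in Example no. 2 also returns the loss) may differ.

import math

import torch
from tqdm import tqdm


def validate(args, model, criterion, dataset, epoch):
    """Hypothetical sketch: average per-token loss on `dataset` and the corresponding perplexity."""
    loader = torch.utils.data.DataLoader(
        dataset, num_workers=1, collate_fn=dataset.collater,
        batch_sampler=BatchSampler(dataset, args.max_tokens, args.batch_size, 1, 0,
                                   shuffle=False, seed=42))
    model.eval()
    total_loss, total_tokens = 0.0, 0
    with torch.no_grad():
        for sample in tqdm(loader, desc='| Epoch {:03d} | valid'.format(epoch), leave=False):
            if torch.cuda.is_available() and args.cuda:
                sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue
            output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs'])
            # criterion uses reduction='sum', so this is the summed token-level cross-entropy
            total_loss += criterion(output.view(-1, output.size(-1)),
                                    sample['tgt_tokens'].view(-1)).item()
            total_tokens += sample['num_tokens']
    return math.exp(total_loss / total_tokens)  # perplexity
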
Example no. 2
def main(args):
    """ Main training function. Trains the translation model over the course of several epochs, including dynamic
    learning rate adjustment and gradient clipping. """

    logging.info('Commencing training!')
    torch.manual_seed(42)

    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    valid_dataset = load_data(split='valid')

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    logging.info('Built a model with {:d} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum')
    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    # Instantiate optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           factor=0.5,
                                                           patience=1)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer,
                                       scheduler)  # lr_scheduler
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    # Track validation performance for early stopping
    bad_epochs = 0
    best_validate = float('inf')

    for epoch in range(last_epoch + 1, args.max_epoch):
        ## BPE Dropout
        # Set the seed equal to the epoch number
        # (this guarantees the same seeds across training runs, while each epoch still gets its own seed;
        #  one possible bpe_dropout_if_needed helper is sketched after this example)
        seed = epoch

        bpe_dropout_if_needed(seed, args.bpe_dropout)

        # Load the BPE (dropout-ed) training data
        train_dataset = load_data(
            split='train') if not args.train_on_tiny else load_data(
                split='tiny_train')
        train_loader = \
            torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater,
                                        batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
                                                                   0, shuffle=True, seed=42))
        model.train()
        stats = OrderedDict()
        stats['loss'] = 0
        stats['lr'] = 0
        stats['num_tokens'] = 0
        stats['batch_size'] = 0
        stats['grad_norm'] = 0
        stats['clip'] = 0
        # Display progress
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=False)

        # Iterate over the training set
        for i, sample in enumerate(progress_bar):
            if args.cuda:
                sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue
            model.train()

            output, _ = model(sample['src_tokens'], sample['src_lengths'],
                              sample['tgt_inputs'])
            loss = \
                criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths'])
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()
            optimizer.zero_grad()

            # Update statistics for progress bar
            total_loss, num_tokens, batch_size = \
                loss.item(), sample['num_tokens'], len(sample['src_tokens'])
            stats['loss'] += total_loss * len(sample['src_lengths']) / num_tokens
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Calculate validation loss
        valid_perplexity, valid_loss = validate(args, model, criterion,
                                                valid_dataset, epoch)
        model.train()

        # Scheduler step
        if args.adaptive_lr:
            scheduler.step(valid_loss)

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, scheduler, epoch,
                                  valid_perplexity)  # lr_scheduler

        # Check whether to terminate training
        if valid_perplexity < best_validate:
            best_validate = valid_perplexity
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= args.patience:
            logging.info(
                'No validation set improvements observed for {:d} epochs. Early stop!'
                .format(args.patience))
            break
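
Example no. 2 reloads the training split every epoch after calling a bpe_dropout_if_needed helper that is not shown. The sketch below is one plausible form of it, assuming subword-nmt's BPE-dropout support (BPE.segment takes a dropout probability) and hypothetical codes/file paths; the actual helper, its inputs and its outputs are assumptions, not part of the original code.

import random

from subword_nmt.apply_bpe import BPE


def bpe_dropout_if_needed(seed, dropout,
                          codes_path='bpe.codes',    # hypothetical paths: the real helper's
                          raw_path='train.raw.en',   # inputs and outputs are not shown
                          out_path='train.en'):      # in the original code
    """Re-segment the raw training text with BPE dropout so each epoch sees a fresh segmentation."""
    if dropout <= 0:
        return  # keep the deterministic BPE segmentation already on disk
    random.seed(seed)  # subword-nmt draws its dropout decisions from the random module
    with open(codes_path, encoding='utf-8') as codes_file:
        bpe = BPE(codes_file)
    with open(raw_path, encoding='utf-8') as fin, open(out_path, 'w', encoding='utf-8') as fout:
        for line in fin:
            fout.write(bpe.segment(line.strip(), dropout) + '\n')

In practice such a helper would be invoked once per language so that load_data(split='train') reads the freshly re-segmented files.
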
Example no. 3
def main(args):
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported.')
    torch.manual_seed(args.seed)
    torch.cuda.set_device(args.device_id)
    utils.init_logging(args)

    if args.distributed_world_size > 1:
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            init_method=args.distributed_init_method,
            world_size=args.distributed_world_size,
            rank=args.distributed_rank)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({}) with {} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({}) with {} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{}.{}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{}.{}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(split='train')
    valid_dataset = load_data(split='valid')

    # Build model and criterion
    model = models.build_model(args, src_dict, tgt_dict).cuda()
    logging.info('Built a model with {} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum').cuda()

    # Build an optimizer and a learning rate schedule
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, min_lr=args.min_lr, factor=args.lr_shrink)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer, lr_scheduler)
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            num_workers=args.num_workers,
            collate_fn=train_dataset.collater,
            batch_sampler=BatchSampler(train_dataset,
                                       args.max_tokens,
                                       args.batch_size,
                                       args.distributed_world_size,
                                       args.distributed_rank,
                                       shuffle=True,
                                       seed=args.seed))

        model.train()
        stats = {
            'loss': 0.,
            'lr': 0.,
            'num_tokens': 0.,
            'batch_size': 0.,
            'grad_norm': 0.,
            'clip': 0.
        }
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=(args.distributed_rank != 0))

        for i, sample in enumerate(progress_bar):
            sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue

            # Forward and backward pass
            output, _ = model(sample['src_tokens'], sample['src_lengths'],
                              sample['tgt_inputs'])
            loss = criterion(output.view(-1, output.size(-1)),
                             sample['tgt_tokens'].view(-1))
            optimizer.zero_grad()
            loss.backward()

            # Reduce gradients across all GPUs (a sketch of these helpers follows this example)
            if args.distributed_world_size > 1:
                utils.reduce_grads(model.parameters())
                total_loss, num_tokens, batch_size = map(
                    sum,
                    zip(*utils.all_gather_list(
                        [loss.item(), sample['num_tokens'], len(sample['src_tokens'])])))
            else:
                total_loss, num_tokens, batch_size = \
                    loss.item(), sample['num_tokens'], len(sample['src_tokens'])

            # Normalize gradients by number of tokens and perform clipping
            for param in model.parameters():
                param.grad.data.div_(num_tokens)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            optimizer.step()

            # Update statistics for progress bar
            stats['loss'] += total_loss / num_tokens / math.log(2)
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Adjust learning rate based on validation loss
        valid_loss = validate(args, model, criterion, valid_dataset, epoch)
        lr_scheduler.step(valid_loss)

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, lr_scheduler, epoch,
                                  valid_loss)
        if optimizer.param_groups[0]['lr'] <= args.min_lr:
            logging.info('Done training!')
            break
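
Example no. 3 is the only distributed variant and leans on utils.reduce_grads and utils.all_gather_list, neither of which is shown. A rough sketch of what they might do on top of torch.distributed follows; both bodies are assumptions about the codebase, not its actual implementation.

import pickle

import torch
import torch.distributed as dist


def reduce_grads(parameters):
    """Sum gradients in place across all workers (plain synchronous data parallelism)."""
    for param in parameters:
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)


def all_gather_list(data):
    """Gather an arbitrary picklable object from every worker and return one entry per rank."""
    world_size = dist.get_world_size()
    payload = torch.tensor(list(pickle.dumps(data)), dtype=torch.uint8, device='cuda')
    size = torch.tensor([payload.numel()], dtype=torch.long, device='cuda')
    # Exchange payload sizes first, then pad every payload to the largest one
    sizes = [torch.zeros(1, dtype=torch.long, device='cuda') for _ in range(world_size)]
    dist.all_gather(sizes, size)
    max_size = int(max(s.item() for s in sizes))
    padded = torch.zeros(max_size, dtype=torch.uint8, device='cuda')
    padded[:payload.numel()] = payload
    buffers = [torch.zeros(max_size, dtype=torch.uint8, device='cuda') for _ in range(world_size)]
    dist.all_gather(buffers, padded)
    return [pickle.loads(bytes(buf[:int(s.item())].tolist()))
            for buf, s in zip(buffers, sizes)]

With gradients summed here and then divided by the globally gathered num_tokens in the training loop, every worker applies the same per-token-normalized update.
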
Example no. 4
def main(args):
    """ Main training function. Trains the translation model over the course of several epochs, including dynamic
    learning rate adjustment and gradient clipping. """
    logging.info('Commencing training!')
    torch.manual_seed(42)
    np.random.seed(42)

    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data, '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict, tgt_dict=tgt_dict)

    train_dataset = load_data(split='train') if not args.train_on_tiny else load_data(split='tiny_train')
    valid_dataset = load_data(split='valid')

    # yichao: enable cuda (a device-agnostic move helper is sketched after this example)
    use_cuda = torch.cuda.is_available() and args.device == 'cuda'
    device = torch.device('cuda' if use_cuda else 'cpu')
    logging.info('Using device: {}'.format(device))

    # Build model and optimization criterion
    # yichao: enable cuda, i.e. add .to(device)
    model = models.build_model(args, src_dict, tgt_dict).to(device)
    logging.info('Built a model with {:d} parameters'.format(sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx, reduction='sum').to(device)

    # Instantiate optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer)  # lr_scheduler
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    # Track validation performance for early stopping
    bad_epochs = 0
    best_validate = float('inf')

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = \
            torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater,
                                        batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
                                                                   0, shuffle=True, seed=42))
        model.train()
        stats = OrderedDict()
        stats['loss'] = 0
        stats['lr'] = 0
        stats['num_tokens'] = 0
        stats['batch_size'] = 0
        stats['grad_norm'] = 0
        stats['clip'] = 0

        # Display progress
        progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False, disable=False)

        # Iterate over the training set
        for i, sample in enumerate(progress_bar):

            if len(sample) == 0:
                continue
            model.train()

            '''
            ___QUESTION-1-DESCRIBE-F-START___
            Describe what the following lines of code do.
            '''
            # yichao: enable cuda
            sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs'], sample['tgt_tokens'] = \
                sample['src_tokens'].to(device), sample['src_lengths'].to(device), \
                sample['tgt_inputs'].to(device), sample['tgt_tokens'].to(device)

            output, _ = model(sample['src_tokens'], sample['src_lengths'], sample['tgt_inputs'])

            loss = \
                criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths'])
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
            optimizer.step()
            optimizer.zero_grad()
            '''___QUESTION-1-DESCRIBE-F-END___'''

            # Update statistics for progress bar
            total_loss, num_tokens, batch_size = loss.item(), sample['num_tokens'], len(sample['src_tokens'])
            stats['loss'] += total_loss * len(sample['src_lengths']) / sample['num_tokens']
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix({key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items()},
                                     refresh=True)

        logging.info('Epoch {:03d}: {}'.format(epoch, ' | '.join(key + ' {:.4g}'.format(
            value / len(progress_bar)) for key, value in stats.items())))

        # Calculate validation loss
        valid_perplexity = validate(args, model, criterion, valid_dataset, epoch)
        model.train()

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, epoch, valid_perplexity)  # lr_scheduler

        # Check whether to terminate training
        if valid_perplexity < best_validate:
            best_validate = valid_perplexity
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= args.patience:
            logging.info('No validation set improvements observed for {:d} epochs. Early stop!'.format(args.patience))
            break
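
Example no. 4 moves the four sample tensors to the device one by one, while the other examples call utils.move_to_cuda. A small device-agnostic helper in the same spirit is sketched below; the name move_to_device and its recursive handling of containers are assumptions, not part of the original utilities.

import torch


def move_to_device(sample, device):
    """Recursively move every tensor in a (possibly nested) sample to `device`."""
    if torch.is_tensor(sample):
        return sample.to(device, non_blocking=True)
    if isinstance(sample, dict):
        return {key: move_to_device(value, device) for key, value in sample.items()}
    if isinstance(sample, (list, tuple)):
        return type(sample)(move_to_device(value, device) for value in sample)
    return sample  # ints, strings, etc. are returned unchanged

With such a helper, the four manual .to(device) calls inside the QUESTION-1 block would collapse to sample = move_to_device(sample, device).
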
Example no. 5
def main(args):
    """ Main training function. Trains the translation model over the course of several epochs, including dynamic
    learning rate adjustment and gradient clipping. """

    logging.info('Commencing training!')
    torch.manual_seed(42)

    utils.init_logging(args)

    # Load dictionaries
    src_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.source_lang)))
    logging.info('Loaded a source dictionary ({:s}) with {:d} words'.format(
        args.source_lang, len(src_dict)))
    tgt_dict = Dictionary.load(
        os.path.join(args.data, 'dict.{:s}'.format(args.target_lang)))
    logging.info('Loaded a target dictionary ({:s}) with {:d} words'.format(
        args.target_lang, len(tgt_dict)))

    # Load datasets
    def load_data(split):
        return Seq2SeqDataset(
            src_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.source_lang)),
            tgt_file=os.path.join(args.data,
                                  '{:s}.{:s}'.format(split, args.target_lang)),
            src_dict=src_dict,
            tgt_dict=tgt_dict)

    train_dataset = load_data(
        split='train') if not args.train_on_tiny else load_data(
            split='tiny_train')
    valid_dataset = load_data(split='valid')

    # Build model and optimization criterion
    model = models.build_model(args, src_dict, tgt_dict)
    model_rev = models.build_model(args, tgt_dict, src_dict)
    logging.info('Built a model with {:d} parameters'.format(
        sum(p.numel() for p in model.parameters())))
    criterion = nn.CrossEntropyLoss(ignore_index=src_dict.pad_idx,
                                    reduction='sum')
    criterion2 = nn.MSELoss(reduction='sum')
    if args.cuda:
        model = model.cuda()
        model_rev = model_rev.cuda()
        criterion = criterion.cuda()

    # Instantiate optimizer and learning rate scheduler
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer)  # lr_scheduler
    utils.load_checkpoint_rev(args, model_rev, optimizer)  # lr_scheduler
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1

    # Track validation performance for early stopping
    bad_epochs = 0
    best_validate = float('inf')

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = \
            torch.utils.data.DataLoader(train_dataset, num_workers=1, collate_fn=train_dataset.collater,
                                        batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, 1,
                                                                   0, shuffle=True, seed=42))
        model.train()
        model_rev.train()
        stats = OrderedDict()
        stats['loss'] = 0
        stats['lr'] = 0
        stats['num_tokens'] = 0
        stats['batch_size'] = 0
        stats['grad_norm'] = 0
        stats['clip'] = 0
        # Display progress
        progress_bar = tqdm(train_loader,
                            desc='| Epoch {:03d}'.format(epoch),
                            leave=False,
                            disable=False)

        # Iterate over the training set
        for i, sample in enumerate(progress_bar):
            if args.cuda:
                sample = utils.move_to_cuda(sample)
            if len(sample) == 0:
                continue
            model.train()

            (output, att), src_out = model(sample['src_tokens'],
                                           sample['src_lengths'],
                                           sample['tgt_inputs'])
            # print(sample['src_lengths'])
            # print(sample['tgt_inputs'].size())
            # print(sample['src_tokens'].size())
            # Build decoder inputs for the reverse model by rotating the first
            # sequence in the batch one position to the right
            src_inputs = sample['src_tokens'].clone()
            src_inputs[0, 1:] = sample['src_tokens'][0, :-1]
            src_inputs[0, 0] = sample['src_tokens'][0, -1]
            # Adjust the source lengths by the difference between the padded target and source widths
            tgt_lengths = sample['src_lengths'].clone()  # torch.tensor([sample['tgt_tokens'].size(1)])
            tgt_lengths += sample['tgt_inputs'].size(1) - sample['src_tokens'].size(1)
            # print(tgt_lengths)
            # print(sample['num_tokens'])

            # if args.cuda:
            #     tgt_lengths = tgt_lengths.cuda()
            (output_rev, att_rev), src_out_rev = model_rev(sample['tgt_tokens'],
                                                           tgt_lengths, src_inputs)

            # notice that those are without masks already
            # print(sample['tgt_tokens'].view(-1))
            # Extra penalty terms computed from the two directions' attention and encoder
            # outputs (one possible form of get_diff is sketched after this example)
            d, d_rev = get_diff(att, src_out, att_rev, src_out_rev)

            loss = \
                criterion(output.view(-1, output.size(-1)), sample['tgt_tokens'].view(-1)) / len(sample['src_lengths']) + d + \
                criterion(output_rev.view(-1, output_rev.size(-1)), sample['src_tokens'].view(-1)) / len(tgt_lengths) + d_rev
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.clip_norm)
            # loss_rev = \
            #     criterion(output_rev.view(-1, output_rev.size(-1)), sample['src_tokens'].view(-1)) / len(tgt_lengths)
            # loss_rev.backward()
            # grad_norm_rev = torch.nn.utils.clip_grad_norm_(model_rev.parameters(), args.clip_norm)
            optimizer.step()
            optimizer.zero_grad()

            # Update statistics for progress bar
            total_loss, num_tokens, batch_size = \
                (loss - d - d_rev).item(), sample['num_tokens'], len(sample['src_tokens'])
            stats['loss'] += total_loss * len(sample['src_lengths']) / num_tokens
            # stats['loss_rev'] += loss_rev.item() * len(sample['src_lengths']) / sample['src_tokens'].size(0) / sample['src_tokens'].size(1)
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += num_tokens / len(sample['src_tokens'])
            stats['batch_size'] += batch_size
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {
                    key: '{:.4g}'.format(value / (i + 1))
                    for key, value in stats.items()
                },
                refresh=True)

        logging.info('Epoch {:03d}: {}'.format(
            epoch, ' | '.join(key + ' {:.4g}'.format(value / len(progress_bar))
                              for key, value in stats.items())))

        # Calculate validation loss
        valid_perplexity = validate(args, model, model_rev, criterion,
                                    valid_dataset, epoch)
        model.train()
        model_rev.train()

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, model_rev, optimizer, epoch,
                                  valid_perplexity)  # lr_scheduler

        # Check whether to terminate training
        if valid_perplexity < best_validate:
            best_validate = valid_perplexity
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs >= args.patience:
            logging.info(
                'No validation set improvements observed for {:d} epochs. Early stop!'
                .format(args.patience))
            break
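
Example no. 5 trains a forward and a reverse model jointly and adds two terms, d and d_rev, returned by get_diff(att, src_out, att_rev, src_out_rev). That function is not shown; the unused nn.MSELoss criterion suggests some squared-error agreement penalty between the two directions, but this is only a guess. The sketch below is purely illustrative of that idea and is not the author's actual get_diff.

import torch.nn.functional as F


def get_diff(att, src_out, att_rev, src_out_rev):
    """Illustrative guess: penalize disagreement between forward and reverse attention maps.

    att:     forward attention,  assumed shape (batch, tgt_len, src_len)
    att_rev: reverse attention,  assumed shape (batch, src_len, tgt_len)
    src_out / src_out_rev: encoder outputs of the two directions (unused in this guess)
    """
    # Crop to a common window in case padded widths differ between the two directions
    tgt_len = min(att.size(1), att_rev.size(2))
    src_len = min(att.size(2), att_rev.size(1))
    forward = att[:, :tgt_len, :src_len]
    backward = att_rev[:, :src_len, :tgt_len].transpose(1, 2)
    # Symmetric squared-error terms, each treating the other direction as a fixed target
    d = F.mse_loss(forward, backward.detach(), reduction='sum')
    d_rev = F.mse_loss(backward, forward.detach(), reduction='sum')
    return d, d_rev
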