Example #1
def make_model(src_vocab,
               tgt_vocab,
               N=6,
               d_model=512,
               d_ff=2048,
               h=8,
               dropout=0.1):
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model).to(args.device)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout).to(args.device)
    position = PositionalEncoding(d_model, dropout).to(args.device)
    model = Transformer(
        Encoder(
            EncoderLayer(d_model, c(attn), c(ff), dropout).to(args.device),
            N).to(args.device),
        Decoder(
            DecoderLayer(d_model, c(attn), c(attn), c(ff),
                         dropout).to(args.device), N).to(args.device),
        nn.Sequential(
            Embeddings(d_model, src_vocab).to(args.device), c(position)),
        nn.Sequential(
            Embeddings(d_model, tgt_vocab).to(args.device), c(position)),
        Generator(d_model, tgt_vocab)).to(args.device)

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model.to(args.device)
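A minimal usage sketch for make_model, assuming the module-level args.device is already defined (the function moves every sub-module to it) and that the Annotated-Transformer-style classes it references are importable; the vocabulary sizes below are hypothetical:

model = make_model(src_vocab=32000, tgt_vocab=32000, N=6, d_model=512)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")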
Example #2
def training(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(
        logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Data open
    write_log(logger, "Load data...")
    gc.disable()
    with open(os.path.join(args.preprocess_path, 'processed.pkl'), 'rb') as f:
        data_ = pickle.load(f)
        train_src_indices = data_['train_src_indices']
        valid_src_indices = data_['valid_src_indices']
        train_trg_indices = data_['train_trg_indices']
        valid_trg_indices = data_['valid_trg_indices']
        src_word2id = data_['src_word2id']
        trg_word2id = data_['trg_word2id']
        src_vocab_num = len(src_word2id)
        trg_vocab_num = len(trg_word2id)
        del data_
    gc.enable()
    write_log(logger, "Finished loading data!")

    # 2) Dataloader setting
    dataset_dict = {
        'train':
        CustomDataset(train_src_indices,
                      train_trg_indices,
                      min_len=args.min_len,
                      src_max_len=args.src_max_len,
                      trg_max_len=args.trg_max_len),
        'valid':
        CustomDataset(valid_src_indices,
                      valid_trg_indices,
                      min_len=args.min_len,
                      src_max_len=args.src_max_len,
                      trg_max_len=args.trg_max_len),
    }
    dataloader_dict = {
        'train':
        DataLoader(dataset_dict['train'],
                   drop_last=True,
                   batch_size=args.batch_size,
                   shuffle=True,
                   pin_memory=True,
                   num_workers=args.num_workers),
        'valid':
        DataLoader(dataset_dict['valid'],
                   drop_last=False,
                   batch_size=args.batch_size,
                   shuffle=False,
                   pin_memory=True,
                   num_workers=args.num_workers)
    }
    write_log(
        logger,
        f"Training set size / iterations per epoch: {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}"
    )

    #===================================#
    #===========Train setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, 'Instantiating model...')
    model = Transformer(
        src_vocab_num=src_vocab_num,
        trg_vocab_num=trg_vocab_num,
        pad_idx=args.pad_id,
        bos_idx=args.bos_id,
        eos_idx=args.eos_id,
        d_model=args.d_model,
        d_embedding=args.d_embedding,
        n_head=args.n_head,
        dim_feedforward=args.dim_feedforward,
        num_common_layer=args.num_common_layer,
        num_encoder_layer=args.num_encoder_layer,
        num_decoder_layer=args.num_decoder_layer,
        src_max_len=args.src_max_len,
        trg_max_len=args.trg_max_len,
        dropout=args.dropout,
        embedding_dropout=args.embedding_dropout,
        trg_emb_prj_weight_sharing=args.trg_emb_prj_weight_sharing,
        emb_src_trg_weight_sharing=args.emb_src_trg_weight_sharing,
        parallel=args.parallel)
    model.train()
    model = model.to(device)
    tgt_mask = model.generate_square_subsequent_mask(args.trg_max_len - 1,
                                                     device)

    # 2) Optimizer & Learning rate scheduler setting
    optimizer = optimizer_select(model, args)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    scaler = GradScaler()

    # 3) Model resume
    start_epoch = 0
    if args.resume:
        write_log(logger, 'Resume model...')
        checkpoint = torch.load(
            os.path.join(args.save_path, 'checkpoint.pth.tar'))
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        scaler.load_state_dict(checkpoint['scaler'])
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    best_val_acc = 0

    write_log(logger, 'Training start!')

    for epoch in range(start_epoch + 1, args.num_epochs + 1):
        start_time_e = time()
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            if phase == 'valid':
                write_log(logger, 'Validation start...')
                val_loss = 0
                val_acc = 0
                model.eval()
            for i, (src, trg) in enumerate(
                    tqdm(dataloader_dict[phase],
                         bar_format='{l_bar}{bar:30}{r_bar}{bar:-2b}')):

                # Optimizer setting
                optimizer.zero_grad(set_to_none=True)

                # Input, output setting
                src = src.to(device, non_blocking=True)
                trg = trg.to(device, non_blocking=True)

                trg_sequences_target = trg[:, 1:]
                non_pad = trg_sequences_target != args.pad_id
                trg_sequences_target = trg_sequences_target[
                    non_pad].contiguous().view(-1)

                # Train
                if phase == 'train':

                    # Loss calculate
                    with autocast():
                        predicted = model(src,
                                          trg[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(predicted,
                                                    trg_sequences_target,
                                                    args.pad_id)

                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                    # Print loss value only training
                    if i == 0 or freq == args.print_freq or i == len(
                            dataloader_dict['train']) - 1:
                        acc = (predicted.max(dim=1)[1] == trg_sequences_target
                               ).sum() / len(trg_sequences_target)
                        iter_log = "[Epoch:%03d][%03d/%03d] train_loss:%03.3f | train_acc:%03.2f%% | learning_rate:%1.6f | spend_time:%02.2fmin" % \
                            (epoch, i, len(dataloader_dict['train']),
                            loss.item(), acc*100, optimizer.param_groups[0]['lr'],
                            (time() - start_time_e) / 60)
                        write_log(logger, iter_log)
                        freq = 0
                    freq += 1

                # Validation
                if phase == 'valid':
                    with torch.no_grad():
                        predicted = model(src,
                                          trg[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        loss = F.cross_entropy(predicted, trg_sequences_target)
                    val_loss += loss.item()
                    val_acc += (predicted.max(dim=1)[1] == trg_sequences_target
                                ).sum() / len(trg_sequences_target)
                    if args.scheduler == 'reduce_valid':
                        scheduler.step(val_loss)
                    if args.scheduler == 'lambda':
                        scheduler.step()

            if phase == 'valid':
                val_loss /= len(dataloader_dict[phase])
                val_acc /= len(dataloader_dict[phase])
                write_log(logger, 'Validation Loss: %3.3f' % val_loss)
                write_log(logger,
                          'Validation Accuracy: %3.2f%%' % (val_acc * 100))
                if val_acc > best_val_acc:
                    write_log(logger, 'Checkpoint saving...')
                    torch.save(
                        {
                            'epoch': epoch,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'scaler': scaler.state_dict()
                        }, f'checkpoint_{args.parallel}.pth.tar')
                    best_val_acc = val_acc
                    best_epoch = epoch
                else:
                    else_log = f'Epoch {best_epoch} still has the best accuracy ({round(best_val_acc.item()*100, 2)}%)...'
                    write_log(logger, else_log)

    # 3) Print results
    print(f'Best Epoch: {best_epoch}')
    print(f'Best Accuracy: {round(best_val_acc.item(), 2)}')
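The loop above calls a label_smoothing_loss helper that is not shown. A minimal sketch of such a loss (assumed, the project's own version may differ; pad_id is accepted but unused here because padded positions were already removed via non_pad):

import torch
import torch.nn.functional as F

def label_smoothing_loss(pred, target, pad_id=None, eps=0.1):
    # pred: (N, vocab_size) logits; target: (N,) gold token indices.
    n_class = pred.size(-1)
    log_prob = F.log_softmax(pred, dim=-1)
    one_hot = torch.zeros_like(log_prob).scatter_(1, target.unsqueeze(1), 1)
    smoothed = one_hot * (1.0 - eps) + eps / n_class
    return -(smoothed * log_prob).sum(dim=-1).mean()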
Example #3
def training(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(
        logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Dataloader setting
    write_log(logger, "Load data...")
    gc.disable()
    dataset_dict = {
        'train': CustomDataset(data_path=args.preprocessed_path,
                               phase='train'),
        'valid': CustomDataset(data_path=args.preprocessed_path,
                               phase='valid'),
        'test': CustomDataset(data_path=args.preprocessed_path, phase='test')
    }
    unique_menu_count = dataset_dict['train'].unique_count()
    dataloader_dict = {
        'train':
        DataLoader(dataset_dict['train'],
                   drop_last=True,
                   batch_size=args.batch_size,
                   shuffle=True,
                   pin_memory=True,
                   num_workers=args.num_workers,
                   collate_fn=PadCollate()),
        'valid':
        DataLoader(dataset_dict['valid'],
                   drop_last=False,
                   batch_size=args.batch_size,
                   shuffle=False,
                   pin_memory=True,
                   num_workers=args.num_workers,
                   collate_fn=PadCollate()),
        'test':
        DataLoader(dataset_dict['test'],
                   drop_last=False,
                   batch_size=args.batch_size,
                   shuffle=False,
                   pin_memory=True,
                   num_workers=args.num_workers,
                   collate_fn=PadCollate())
    }
    gc.enable()
    write_log(
        logger,
        f"Training set size / iterations per epoch: {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}"
    )

    #===================================#
    #===========Model setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, "Instantiating models...")
    model = Transformer(model_type=args.model_type,
                        input_size=unique_menu_count,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        dropout=args.dropout)
    model = model.train()
    model = model.to(device)

    # 2) Optimizer setting
    optimizer = optimizer_select(model, args)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    criterion = nn.MSELoss()
    scaler = GradScaler(enabled=True)

    # 3) Model resume
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(os.path.join(args.model_path,
                                             'checkpoint.pth.tar'),
                                map_location='cpu')
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        model = model.train()
        model = model.to(device)
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    best_val_rmse = 9999999

    write_log(logger, 'Train start!')

    for epoch in range(start_epoch, args.num_epochs):
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                train_start_time = time.time()
                freq = 0
            elif phase == 'valid':
                model.eval()
                val_loss = 0
                val_rmse = 0

            for i, (src_menu, label_lunch,
                    label_supper) in enumerate(dataloader_dict[phase]):

                # Optimizer setting
                optimizer.zero_grad()

                # Input, output setting
                src_menu = src_menu.to(device, non_blocking=True)
                label_lunch = label_lunch.float().to(device, non_blocking=True)
                label_supper = label_supper.float().to(device,
                                                       non_blocking=True)

                # Model
                with torch.set_grad_enabled(phase == 'train'):
                    with autocast(enabled=True):
                        if args.model_type == 'sep':
                            logit = model(src_menu)
                            logit_lunch = logit[:, 0]
                            logit_supper = logit[:, 0]
                        elif args.model_type == 'total':
                            logit = model(src_menu)
                            logit_lunch = logit[:, 0]
                            logit_supper = logit[:, 1]

                    # Loss calculate
                    loss_lunch = criterion(logit_lunch, label_lunch)
                    loss_supper = criterion(logit_supper, label_supper)
                    loss = loss_lunch + loss_supper

                # Back-propagation
                if phase == 'train':
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    # Scheduler setting
                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                # Print loss value
                rmse_loss = torch.sqrt(loss)
                if phase == 'train':
                    if i == 0 or freq == args.print_freq or i == len(
                            dataloader_dict['train']) - 1:
                        batch_log = "[Epoch:%d][%d/%d] train_MSE_loss:%2.3f  | train_RMSE_loss:%2.3f | learning_rate:%3.6f | spend_time:%3.2fmin" \
                                % (epoch+1, i, len(dataloader_dict['train']),
                                loss.item(), rmse_loss.item(), optimizer.param_groups[0]['lr'],
                                (time.time() - train_start_time) / 60)
                        write_log(logger, batch_log)
                        freq = 0
                    freq += 1
                elif phase == 'valid':
                    val_loss += loss.item()
                    val_rmse += rmse_loss.item()

        if phase == 'valid':
            val_loss /= len(dataloader_dict['valid'])
            val_rmse /= len(dataloader_dict['valid'])
            write_log(logger, 'Validation Loss: %3.3f' % val_loss)
            write_log(logger, 'Validation RMSE: %3.3f' % val_rmse)

            if val_rmse < best_val_rmse:
                write_log(logger, 'Checkpoint saving...')
                if not os.path.exists(args.save_path):
                    os.mkdir(args.save_path)
                torch.save(
                    {
                        'epoch': epoch,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'scaler': scaler.state_dict()
                    }, os.path.join(args.save_path, f'checkpoint_cap.pth.tar'))
                best_val_rmse = val_rmse
                best_epoch = epoch
            else:
                else_log = f'Epoch {best_epoch} still has the best RMSE ({round(best_val_rmse, 3)})...'
                write_log(logger, else_log)

    # 3) Print results
    write_log(logger, f'Best Epoch: {best_epoch+1}')
    write_log(logger, f'Best RMSE: {round(best_val_rmse, 3)}')
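Both training functions above log through write_log and a TqdmLoggingHandler, neither of which is shown. A minimal sketch of such helpers (assumed), where the handler routes records through tqdm.write so progress bars are not broken:

import logging
import tqdm

class TqdmLoggingHandler(logging.Handler):
    def emit(self, record):
        try:
            tqdm.tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)

def write_log(logger, message):
    # Only log when a logger was actually configured (e.g. rank 0 in distributed runs).
    if logger is not None:
        logger.info(message)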
Example #4
  torch.backends.cudnn.deterministic = True

  model = Transformer(
    source_vocab_size=SOURCE_VOCAB_SIZE,
    target_vocab_size=TARGET_VOCAB_SIZE,
    source_padding_index=SRC_PAD_IDX,
    target_padding_index=TRG_PAD_IDX,
    embedding_size=const.EMBEDDING_SIZE,
    number_of_layers=const.NUMBER_OF_LAYERS,
    number_of_heads=const.NUMBER_OF_HEADS,
    forward_expansion=const.FORWARD_EXPANSION,
    device=device,
  ).to(device)

  model.apply(model_utils.initialize_weights)
  optimizer = torch.optim.Adam(model.parameters(), lr=const.LEARNING_RATE)
  cross_entropy = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

  print(f'The model has {model_utils.count_parameters(model):,} trainable parameters')

  trainer = Trainer(
    const=const,
    optimizer=optimizer,
    criterion=cross_entropy,
    device=device,
  )

  trainer.train(
    model=model,
    train_iterator=train_iterator,
    valid_iterator=valid_iterator,
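The snippet above relies on model_utils.initialize_weights and model_utils.count_parameters, which are not shown; a plausible sketch of such helpers (assumed):

import torch.nn as nn

def initialize_weights(module):
    # Xavier-initialize weight matrices; leave 1-D parameters (biases, norms) untouched.
    if hasattr(module, 'weight') and module.weight is not None and module.weight.dim() > 1:
        nn.init.xavier_uniform_(module.weight.data)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)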
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--problem', required=True)
    parser.add_argument('--train_step', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=4096)
    parser.add_argument('--max_length', type=int, default=100)
    parser.add_argument('--n_layers', type=int, default=6)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--filter_size', type=int, default=2048)
    parser.add_argument('--warmup', type=int, default=16000)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--label_smoothing', type=float, default=0.1)
    parser.add_argument('--val_every', type=int, default=5)
    parser.add_argument('--output_dir', type=str, default='./output')
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--summary_grad', action='store_true')
    opt = parser.parse_args()

    device = torch.device('cpu' if opt.no_cuda else 'cuda')

    if not os.path.exists(opt.output_dir + '/last/models'):
        os.makedirs(opt.output_dir + '/last/models')
    if not os.path.exists(opt.data_dir):
        os.makedirs(opt.data_dir)

    train_data, validation_data, i_vocab_size, t_vocab_size, opt = \
        problem.prepare(opt.problem, opt.data_dir, opt.max_length,
                        opt.batch_size, device, opt)
    if i_vocab_size is not None:
        print("# of vocabs (input):", i_vocab_size)
    print("# of vocabs (target):", t_vocab_size)

    if os.path.exists(opt.output_dir + '/last/models/last_model.pt'):
        print("Load a checkpoint...")
        last_model_path = opt.output_dir + '/last/models'
        model, global_step = utils.load_checkpoint(last_model_path, device,
                                                   is_eval=False)
    else:
        model = Transformer(i_vocab_size, t_vocab_size,
                            n_layers=opt.n_layers,
                            hidden_size=opt.hidden_size,
                            filter_size=opt.filter_size,
                            dropout_rate=opt.dropout,
                            share_target_embedding=opt.share_target_embedding,
                            has_inputs=opt.has_inputs)
        model = model.to(device=device)
        global_step = 0

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("# of parameters: {}".format(num_params))

    optimizer = LRScheduler(
        filter(lambda x: x.requires_grad, model.parameters()),
        opt.hidden_size, opt.warmup, step=global_step)

    writer = SummaryWriter(opt.output_dir + '/last')
    val_writer = SummaryWriter(opt.output_dir + '/last/val')
    best_val_loss = float('inf')

    for t_step in range(opt.train_step):
        print("Epoch", t_step)
        start_epoch_time = time.time()
        global_step = train(train_data, model, opt, global_step,
                            optimizer, t_vocab_size, opt.label_smoothing,
                            writer)
        print("Epoch Time: {:.2f} sec".format(time.time() - start_epoch_time))

        if t_step % opt.val_every != 0:
            continue

        val_loss = validation(validation_data, model, global_step,
                              t_vocab_size, val_writer, opt)
        utils.save_checkpoint(model, opt.output_dir + '/last/models',
                              global_step, val_loss < best_val_loss)
        best_val_loss = min(val_loss, best_val_loss)
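The LRScheduler above is built from hidden_size, warmup, and the current step, which matches the inverse-square-root warmup schedule of the original Transformer paper; a sketch of that learning-rate rule (the repository's class may apply an extra scale factor):

def noam_learning_rate(step, hidden_size, warmup):
    # lr = hidden_size^-0.5 * min(step^-0.5, step * warmup^-1.5)
    step = max(step, 1)
    return hidden_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)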
Example #6
save_folder = 'weights'
save_file = "triu_mask_model.pkl"
step = 1000
total_loss = -1.
src_sequence_size = 8
tgt_sequence_size = 8

if __name__ == "__main__":
    dataset = Dataset(bd.en_dict, bd.cn_dict, bd.sentence_pair_demo,
                      src_sequence_size, tgt_sequence_size)
    model = Transformer(src_vocab_size=len(bd.en_dict),
                        tgt_vocab_size=len(bd.cn_dict),
                        word_emb_dim=8,
                        tgt_sequence_size=8)
    loss_f = torch.nn.NLLLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    model.train()
    upper_tri = get_upper_triangular(8)
    for i in range(step):
        optimizer.zero_grad()
        src, tgt_in, tgt_out, _, _ = dataset.get_batch(batch_size=1)
        output = model(src, tgt_in, tgt_mask=upper_tri)
        loss = loss_f(torch.log(output), tgt_out)
        if total_loss < 0:
            total_loss = loss.detach().numpy()
        else:
            total_loss = total_loss * 0.95 + loss.detach().numpy() * 0.05
        loss.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            print("step: ", i + 1, "loss:", total_loss)
Example #7
def main():
    ''' 
    Usage:
    python train.py -data_pkl m30k_deen_shr.pkl -log m30k_deen_shr -embs_share_weight -proj_share_weight -label_smoothing -save_model trained -b 256 -warmup 128000
    '''

    parser = argparse.ArgumentParser()

    parser.add_argument('-data_pkl',
                        default=None)  # all-in-1 data pickle or bpe field

    parser.add_argument('-train_path', default=None)  # bpe encoded data
    parser.add_argument('-val_path', default=None)  # bpe encoded data

    parser.add_argument('-epoch', type=int, default=10)
    parser.add_argument('-b', '--batch_size', type=int, default=2048)

    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-warmup', '--n_warmup_steps', type=int, default=4000)

    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default=None)
    parser.add_argument('-save_mode',
                        type=str,
                        choices=['all', 'best'],
                        default='best')

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model

    if not opt.log and not opt.save_model:
        raise ValueError('No experiment result will be saved.')

    if opt.batch_size < 2048 and opt.n_warmup_steps <= 4000:
        print('[Warning] The warmup steps may be not enough.\n' \
              '(sz_b, warmup) = (2048, 4000) is the official setting.\n' \
              'Using smaller batch w/o longer warmup may cause ' \
              'the warmup stage ends with only little data trained.')

    device = torch.device('cuda' if opt.cuda else 'cpu')

    # ========= Loading Dataset =========#

    if all((opt.train_path, opt.val_path)):
        training_data, validation_data = prepare_dataloaders_from_bpe_files(
            opt, device)
    elif opt.data_pkl:
        training_data, validation_data = prepare_dataloaders(opt, device)
    else:
        raise ValueError('Specify either -data_pkl or both -train_path and -val_path.')

    print(opt)

    transformer = Transformer(opt.src_vocab_size,
                              opt.trg_vocab_size,
                              src_pad_idx=opt.src_pad_idx,
                              trg_pad_idx=opt.trg_pad_idx,
                              trg_emb_prj_weight_sharing=opt.proj_share_weight,
                              src_emb_prj_weight_sharing=opt.embs_share_weight,
                              d_k=opt.d_k,
                              d_v=opt.d_v,
                              d_model=opt.d_model,
                              d_word_vec=opt.d_word_vec,
                              d_inner=opt.d_inner_hid,
                              n_layers=opt.n_layers,
                              n_head=opt.n_head,
                              dropout=opt.dropout).to(device)
    model_path = 'checkpoints/pretrained.chkpt'
    checkpoint = torch.load(model_path, map_location=device)
    transformer.load_state_dict(checkpoint['model'])
    optimizer = ScheduledOptim(
        optim.Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-09),
        2.0, opt.d_model, opt.n_warmup_steps)

    train(transformer, training_data, validation_data, optimizer, device, opt)
Example #8
    warmup_steps = configuration['hyper_params']['warmup_steps']

    # dataset params
    batch_size = configuration['train_dataset_params']['loader_params']['batch_size']
    vocab_size = configuration['train_dataset_params']['loader_params']['vocab_size']
    data_pct = configuration['train_dataset_params']['loader_params']['pct']
    num_workers = configuration['train_dataset_params']['loader_params']['num_workers']

    print(f"Setting up vocabulary...")
    train_data_loader, valid_data_loader = create_data_loaders(
        data_pct=data_pct, batch_size=batch_size, vocab_size=vocab_size, num_workers=num_workers)
    
    print(f"Training model with \
        epochs={epochs} \
        batch_size={batch_size} \
        vocab_size={vocab_size} \
        warmup_steps={warmup_steps} \
        training_examples={len(train_data_loader) * batch_size} \
        on device={device}")

    model = Transformer(vocab_size, vocab_size, d_model, d_hidden, n_heads, N)
    optimizer = NoamOptimizer(torch.optim.Adam(model.parameters(), betas=(.9,.98), eps=1e-9, lr=0.), d_model, warmup_steps)
    criterion = LabelSmoothingCrossEntropy().to(device)

    train_losses, valid_losses = train(model, epochs, criterion, train_data_loader, valid_data_loader, device)

    print("Model finished training, plotting losses...")
    plot_losses(train_losses, valid_losses)


    
Example #9
                               num_hidden_layers=6,
                               num_attn_head=8,
                               hidden_act='gelu',
                               device=device,
                               feed_forward_size=2048,
                               padding_idx=0,
                               share_embeddings=True,
                               enc_max_seq_length=128,
                               dec_max_seq_length=128)

    model = Transformer(config).to(config.device)

    dataset = CustomDataset(src_lines, trg_lines, tokenizer, config)
    data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                     mode='min',
                                                     patience=2)

    train_continue = False
    plus_epoch = 30
    if train_continue:
        weights = sorted(glob.glob('./weight/transformer_*'),
                         key=lambda path: int(path.split('_')[-1]))
        last_epoch = int(weights[-1].split('_')[-1])
        weight_path = weights[-1].replace('\\', '/')
        print('weight info of last epoch', weight_path)
        model.load_state_dict(torch.load(weight_path))
        total_epoch = last_epoch + plus_epoch
    else:
        last_epoch = 0
Example #10
def make_transformer(model_config, vocab):
    attention = MultiHeadedAttention(
        head_num=model_config['head_num'],
        feature_size=model_config['feature_size'],
        dropout=model_config['dropout_rate']
    )
    attention_with_cache = MultiHeadedAttentionWithCache(
        head_num=model_config['head_num'],
        feature_size=model_config['feature_size'],
        dropout=model_config['dropout_rate']
    )
    feed_forward = PositionWiseFeedForward(
        input_dim=model_config['feature_size'],
        ff_dim=model_config['feedforward_dim'],
        dropout=model_config['dropout_rate']
    )

    model = Transformer(
        src_embedding_layer=Embeddings(
            vocab_size=len(vocab['src']),
            emb_size=model_config['feature_size'],
            dropout=model_config['dropout_rate'],
            max_len=5000
        ),
        trg_embedding_layer=Embeddings(
            vocab_size=len(vocab['trg']),
            emb_size=model_config['feature_size'],
            dropout=model_config['dropout_rate'],
            max_len=5000
        ),
        encoder=TransformerEncoder(
            layer=TransformerEncoderLayer(feature_size=model_config['feature_size'],
                                          self_attention_layer=copy.deepcopy(attention),
                                          feed_forward_layer=copy.deepcopy(feed_forward),
                                          dropout_rate=model_config['dropout_rate'],
                                          layer_norm_rescale=model_config['layer_norm_rescale']),
            feature_size=model_config['feature_size'],
            num_layers=model_config['num_layers'],
            layer_norm_rescale=model_config['layer_norm_rescale'],
        ),
        decoder=TransformerDecoder(
            layer=TransformerDecoderLayer(feature_size=model_config['feature_size'],
                                          self_attention_layer=copy.deepcopy(attention_with_cache),
                                          cross_attention_layer=copy.deepcopy(attention_with_cache),
                                          feed_forward_layer=copy.deepcopy(feed_forward),
                                          dropout_rate=model_config['dropout_rate'],
                                          layer_norm_rescale=model_config['layer_norm_rescale'], ),
            num_layers=model_config['num_layers'],
            feature_size=model_config['feature_size'],
            layer_norm_rescale=model_config['layer_norm_rescale'],
        ),
        generator=SimpleGenerator(feature_size=model_config['feature_size'],
                                  vocab_size=len(vocab['trg']),
                                  bias=model_config['generator_bias']),
        vocab=vocab,
        share_decoder_embedding=model_config['share_decoder_embedding'],
        share_enc_dec_embedding=model_config['share_enc_dec_embedding'],
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model
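A usage sketch for make_transformer with hypothetical configuration values; src_vocab and trg_vocab stand in for whatever vocabulary objects the project builds (the factory only needs len() on them):

model_config = {
    'head_num': 8,
    'feature_size': 512,
    'dropout_rate': 0.1,
    'feedforward_dim': 2048,
    'num_layers': 6,
    'layer_norm_rescale': True,
    'generator_bias': False,
    'share_decoder_embedding': True,
    'share_enc_dec_embedding': False,
}
model = make_transformer(model_config, vocab={'src': src_vocab, 'trg': trg_vocab})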
Example #11
def train_model(train_iterator, val_iterator, test_iterator):
    batch_size = 32
    vocab_size = len(train_iterator.word2index)
    dmodel = 64
    output_size = 2
    padding_idx = train_iterator.word2index['<PAD>']
    n_layers = 4
    ffnn_hidden_size = dmodel * 2
    heads = 8
    pooling = 'max'
    dropout = 0.5
    label_smoothing = 0.1
    learning_rate = 0.001
    epochs = 30
    CUDA = torch.cuda.is_available()
    max_len = 0
    for batches in train_iterator:
        x_lengths = batches['x_lengths']
        if max(x_lengths) > max_len:
            max_len = int(max(x_lengths))
    model = Transformer(vocab_size, dmodel, output_size, max_len, padding_idx, n_layers, \
                        ffnn_hidden_size, heads, pooling, dropout)
    if CUDA:
        model.cuda()

    if label_smoothing:
        loss_fn = LabelSmoothingLoss(output_size, label_smoothing)
    else:
        loss_fn = nn.NLLLoss()
    model.add_loss_fn(loss_fn)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.add_optimizer(optimizer)

    device = torch.device('cuda' if CUDA else 'cpu')

    model.add_device(device)
    params = {
        'batch_size': batch_size,
        'dmodel': dmodel,
        'n_layers': n_layers,
        'ffnn_hidden_size': ffnn_hidden_size,
        'heads': heads,
        'pooling': pooling,
        'dropout': dropout,
        'label_smoothing': label_smoothing,
        'learning_rate': learning_rate
    }

    train_writer = SummaryWriter('runs/transformer_train')

    val_writer = SummaryWriter('runs/transformer_val')

    early_stop = EarlyStopping(wait_epochs=3)

    train_losses_list, train_avg_loss_list, train_accuracy_list = [], [], []
    eval_avg_loss_list, eval_accuracy_list, conf_matrix_list = [], [], []
    for epoch in range(epochs):

        try:
            print('\nStart epoch [{}/{}]'.format(epoch + 1, epochs))

            train_losses, train_avg_loss, train_accuracy = model.train_model(
                train_iterator)

            train_losses_list.append(train_losses)
            train_avg_loss_list.append(train_avg_loss)
            train_accuracy_list.append(train_accuracy)

            _, eval_avg_loss, eval_accuracy, conf_matrix = model.evaluate_model(
                val_iterator)

            eval_avg_loss_list.append(eval_avg_loss)
            eval_accuracy_list.append(eval_accuracy)
            conf_matrix_list.append(conf_matrix)

            print(
                '\nEpoch [{}/{}]: Train accuracy: {:.3f}. Train loss: {:.4f}. Evaluation accuracy: {:.3f}. Evaluation loss: {:.4f}' \
                .format(epoch + 1, epochs, train_accuracy, train_avg_loss, eval_accuracy, eval_avg_loss))

            train_writer.add_scalar('Training loss', train_avg_loss, epoch)
            val_writer.add_scalar('Validation loss', eval_avg_loss, epoch)

            if early_stop.stop(eval_avg_loss, model, delta=0.003):
                break

        finally:
            train_writer.close()
            val_writer.close()

    _, test_avg_loss, test_accuracy, test_conf_matrix = model.evaluate_model(
        test_iterator)
    print('Test accuracy: {:.3f}. Test loss: {:.3f}'.format(
        test_accuracy, test_avg_loss))
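EarlyStopping above is not defined in the snippet. A minimal sketch consistent with how it is called (wait_epochs at construction, stop(eval_loss, model, delta) returning True to halt); the real class may also checkpoint the model it receives:

class EarlyStopping:
    def __init__(self, wait_epochs=3):
        self.wait_epochs = wait_epochs
        self.counter = 0
        self.best_loss = float('inf')

    def stop(self, val_loss, model=None, delta=0.0):
        # Reset the patience counter on a sufficiently large improvement.
        if val_loss < self.best_loss - delta:
            self.best_loss = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.wait_epochs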
Example #12
    return np.mean(l), np.mean(p)


#=================================main=================================

p = Personas()
writer = SummaryWriter(log_dir=config.save_path)
# Build model, optimizer, and set states
if not (config.load_frompretrain == 'None'):
    meta_net = Transformer(p.vocab,
                           model_file_path=config.load_frompretrain,
                           is_eval=False)
else:
    meta_net = Transformer(p.vocab)
if config.meta_optimizer == 'sgd':
    meta_optimizer = torch.optim.SGD(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer == 'adam':
    meta_optimizer = torch.optim.Adam(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer == 'noam':
    meta_optimizer = NoamOpt(
        config.hidden_dim, 1, 4000,
        torch.optim.Adam(meta_net.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
else:
    raise ValueError

meta_batch_size = config.meta_batch_size
tasks = p.get_personas('train')
#tasks_loader = {t: p.get_data_loader(persona=t,batch_size=config.batch_size, split='train') for t in tasks}
Example #13
model = Transformer(src_pad_idx=src_pad_idx,
                    trg_pad_idx=trg_pad_idx,
                    trg_sos_idx=trg_sos_idx,
                    d_model=d_model,
                    enc_voc_size=enc_voc_size,
                    dec_voc_size=dec_voc_size,
                    max_len=max_len,
                    ffn_hidden=ffn_hidden,
                    n_head=n_head,
                    n_layers=n_layers,
                    drop_prob=drop_prob,
                    device=device).to(device)

model.apply(initialize_weights)
optimizer = Adam(params=model.parameters(),
                 lr=init_lr,
                 weight_decay=weight_decay,
                 eps=adam_eps)

criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)


def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg

        optimizer.zero_grad()
Example #14
class Trainer:
    def __init__(self,
                 params,
                 mode,
                 train_iter=None,
                 valid_iter=None,
                 test_iter=None):
        self.params = params

        # Train mode
        if mode == 'train':
            self.train_iter = train_iter
            self.valid_iter = valid_iter

        # Test mode
        else:
            self.test_iter = test_iter

        self.model = Transformer(self.params)
        self.model.to(self.params.device)

        # Scheduling optimizer
        self.optimizer = ScheduledAdam(optim.Adam(self.model.parameters(),
                                                  betas=(0.9, 0.98),
                                                  eps=1e-9),
                                       hidden_dim=params.hidden_dim,
                                       warm_steps=params.warm_steps)

        self.criterion = nn.CrossEntropyLoss(ignore_index=self.params.pad_idx)
        self.criterion.to(self.params.device)

    def train(self):
        print(self.model)
        print(
            f'The model has {self.model.count_params():,} trainable parameters'
        )
        best_valid_loss = float('inf')

        for epoch in range(self.params.num_epoch):
            self.model.train()
            epoch_loss = 0
            start_time = time.time()

            for batch in self.train_iter:
                # For each batch, first zero the gradients
                self.optimizer.zero_grad()
                source = batch.kor
                target = batch.eng

                # target sentence consists of <sos> and following tokens (except the <eos> token)
                output = self.model(source, target[:, :-1])[0]

                # ground truth sentence consists of tokens and <eos> token (except the <sos> token)
                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)
                # output = [(batch size * target length - 1), output dim]
                # target = [(batch size * target length - 1)]
                loss = self.criterion(output, target)
                loss.backward()

                # clip the gradients to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.params.clip)

                self.optimizer.step()

                # .item() extracts the Python scalar from a single-element tensor.
                epoch_loss += loss.item()

            train_loss = epoch_loss / len(self.train_iter)
            valid_loss = self.evaluate()

            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(self.model.state_dict(), self.params.save_model)

            print(
                f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s'
            )
            print(
                f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}'
            )

    def evaluate(self):
        self.model.eval()
        epoch_loss = 0

        with torch.no_grad():
            for batch in self.valid_iter:
                source = batch.kor
                target = batch.eng

                output = self.model(source, target[:, :-1])[0]

                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)

                loss = self.criterion(output, target)

                epoch_loss += loss.item()

        return epoch_loss / len(self.valid_iter)

    def inference(self):
        self.model.load_state_dict(torch.load(self.params.save_model))
        self.model.eval()
        epoch_loss = 0

        with torch.no_grad():
            for batch in self.test_iter:
                source = batch.kor
                target = batch.eng

                output = self.model(source, target[:, :-1])[0]

                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)

                loss = self.criterion(output, target)

                epoch_loss += loss.item()

        test_loss = epoch_loss / len(self.test_iter)
        print(f'Test Loss: {test_loss:.3f}')
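epoch_time above converts a start/end timestamp pair into whole minutes and seconds; a sketch of that small helper (assumed):

def epoch_time(start_time, end_time):
    elapsed = end_time - start_time
    minutes = int(elapsed // 60)
    seconds = int(elapsed - minutes * 60)
    return minutes, seconds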
Example #15
    p, l = [],[]
    for batch in test_iter:
        loss, ppl, _ = model.train_one_batch(batch, train=False)
        l.append(loss)
        p.append(ppl)
    return np.mean(l), np.mean(p)

#=================================main=================================

p = Personas()
writer = SummaryWriter(log_dir=config.save_path)
# Build model, optimizer, and set states
if not (config.load_frompretrain=='None'): meta_net = Transformer(p.vocab,model_file_path=config.load_frompretrain,is_eval=False)
else: meta_net = Transformer(p.vocab)
if config.meta_optimizer=='sgd':
    meta_optimizer = torch.optim.SGD(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer=='adam':
    meta_optimizer = torch.optim.Adam(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer=='noam':
    meta_optimizer = NoamOpt(config.hidden_dim, 1, 4000, torch.optim.Adam(meta_net.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
else:
    raise ValueError

meta_batch_size = config.meta_batch_size
tasks = p.get_personas('train')
#tasks_loader = {t: p.get_data_loader(persona=t,batch_size=config.batch_size, split='train') for t in tasks}
tasks_iter = make_infinite_list(tasks)


# meta early stop
patience = 50
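make_infinite_list above is assumed to cycle over the finite task list indefinitely; a minimal sketch (the real helper may reshuffle between passes):

import itertools

def make_infinite_list(items):
    return itertools.cycle(items)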
Example #16
def main(args):
    comm = MPI.COMM_WORLD
    world_size = comm.Get_size()
    rank = comm.Get_rank()
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(args.master_port)
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
    device = torch.device("cuda")

    logger = None
    tb_logger = None
    if rank == 0:
        if not os.path.exists(args.save_path):
            os.mkdir(args.save_path)
        if not os.path.exists(args.tensorboard_log_dir):
            os.mkdir(args.tensorboard_log_dir)
        tb_logger = SummaryWriter(
            f"{args.tensorboard_log_dir}/{args.model_name}")

        logger = logging.getLogger(__name__)
        logger.setLevel(logging.DEBUG)
        handler = TqdmLoggingHandler()
        handler.setFormatter(logging.Formatter(" %(asctime)s - %(message)s"))
        logger.addHandler(handler)
        logger.propagate = False

    write_log(logger, "Load data")

    def load_data(args):
        gc.disable()
        with open(f"{args.preprocessed_data_path}/hanja_korean_word2id.pkl",
                  "rb") as f:
            data = pickle.load(f)
            hanja_word2id = data['hanja_word2id']
            korean_word2id = data['korean_word2id']

        with open(f"{args.preprocessed_data_path}/preprocessed_train.pkl",
                  "rb") as f:
            data = pickle.load(f)
            train_hanja_indices = data['hanja_indices']
            train_korean_indices = data['korean_indices']
            train_additional_hanja_indices = data['additional_hanja_indices']

        with open(f"{args.preprocessed_data_path}/preprocessed_valid.pkl",
                  "rb") as f:
            data = pickle.load(f)
            valid_hanja_indices = data['hanja_indices']
            valid_korean_indices = data['korean_indices']
            valid_additional_hanja_indices = data['additional_hanja_indices']

        gc.enable()
        write_log(logger, "Finished loading data!")
        return (hanja_word2id, korean_word2id, train_hanja_indices,
                train_korean_indices, train_additional_hanja_indices,
                valid_hanja_indices, valid_korean_indices,
                valid_additional_hanja_indices)

    # load data
    (hanja_word2id, korean_word2id, train_hanja_indices, train_korean_indices,
     train_additional_hanja_indices, valid_hanja_indices, valid_korean_indices,
     valid_additional_hanja_indices) = load_data(args)
    hanja_vocab_num = len(hanja_word2id)
    korean_vocab_num = len(korean_word2id)

    hk_dataset = HanjaKoreanDataset(train_hanja_indices,
                                    train_korean_indices,
                                    min_len=args.min_len,
                                    src_max_len=args.src_max_len,
                                    trg_max_len=args.trg_max_len)
    hk_sampler = DistributedSampler(hk_dataset,
                                    num_replicas=world_size,
                                    rank=rank)
    hk_loader = DataLoader(hk_dataset,
                           drop_last=True,
                           batch_size=args.hk_batch_size,
                           sampler=hk_sampler,
                           num_workers=args.num_workers,
                           prefetch_factor=4,
                           pin_memory=True)
    write_log(logger, f"hanja-korean: {len(hk_dataset)}, {len(hk_loader)}")

    h_dataset = HanjaDataset(train_hanja_indices,
                             train_additional_hanja_indices,
                             hanja_word2id,
                             min_len=args.min_len,
                             src_max_len=args.src_max_len)
    h_sampler = DistributedSampler(h_dataset,
                                   num_replicas=world_size,
                                   rank=rank)
    h_loader = DataLoader(h_dataset,
                          drop_last=True,
                          batch_size=args.h_batch_size,
                          sampler=h_sampler,
                          num_workers=args.num_workers,
                          prefetch_factor=4,
                          pin_memory=True)
    write_log(logger, f"hanja: {len(h_dataset)}, {len(h_loader)}")

    hk_valid_dataset = HanjaKoreanDataset(valid_hanja_indices,
                                          valid_korean_indices,
                                          min_len=args.min_len,
                                          src_max_len=args.src_max_len,
                                          trg_max_len=args.trg_max_len)
    hk_valid_sampler = DistributedSampler(hk_valid_dataset,
                                          num_replicas=world_size,
                                          rank=rank)
    hk_valid_loader = DataLoader(hk_valid_dataset,
                                 drop_last=True,
                                 batch_size=args.hk_batch_size,
                                 sampler=hk_valid_sampler)
    write_log(
        logger,
        f"hanja-korean-valid: {len(hk_valid_dataset)}, {len(hk_valid_loader)}")

    h_valid_dataset = HanjaDataset(valid_hanja_indices,
                                   valid_additional_hanja_indices,
                                   hanja_word2id,
                                   min_len=args.min_len,
                                   src_max_len=args.src_max_len)
    h_valid_sampler = DistributedSampler(h_valid_dataset,
                                         num_replicas=world_size,
                                         rank=rank)
    h_valid_loader = DataLoader(h_valid_dataset,
                                drop_last=True,
                                batch_size=args.h_batch_size,
                                sampler=h_valid_sampler)
    write_log(logger, f"hanja: {len(h_valid_dataset)}, {len(h_valid_loader)}")

    del (train_hanja_indices, train_korean_indices,
         train_additional_hanja_indices, valid_hanja_indices,
         valid_korean_indices, valid_additional_hanja_indices)

    write_log(logger, "Build model")
    model = Transformer(hanja_vocab_num,
                        korean_vocab_num,
                        pad_idx=args.pad_idx,
                        bos_idx=args.bos_idx,
                        eos_idx=args.eos_idx,
                        src_max_len=args.src_max_len,
                        trg_max_len=args.trg_max_len,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dropout=args.dropout,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        num_decoder_layer=args.num_decoder_layer,
                        num_mask_layer=args.num_mask_layer).to(device)
    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[device],
                                                find_unused_parameters=True)
    for param in model.parameters():
        dist.broadcast(param.data, 0)

    dist.barrier()
    write_log(
        logger,
        f"Total Parameters: {sum([p.nelement() for p in model.parameters()])}")

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = Ralamb(params=optimizer_grouped_parameters, lr=args.lr)

    total_iters = round(
        len(hk_loader) / args.num_grad_accumulate * args.epochs)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, round(total_iters * args.warmup_ratio), total_iters)
    scaler = GradScaler()

    start_epoch = 0
    if args.resume:

        def load_states():
            checkpoint = torch.load(
                f'{args.save_path}/{args.model_name}_ckpt.pt',
                map_location='cpu')
            start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            scaler.load_state_dict(checkpoint['scaler'])
            return start_epoch

        start_epoch = load_states()

    write_log(logger, f"Training start - Total iter: {total_iters}\n")
    iter_num = round(len(hk_loader) / args.num_grad_accumulate)
    global_step = start_epoch * iter_num
    hk_iter = iter(hk_loader)
    h_iter = iter(h_loader)
    model.train()
    tgt_mask = Transformer.generate_square_subsequent_mask(
        args.trg_max_len - 1, device)

    # validation
    validate(model, tgt_mask, h_valid_loader, hk_valid_loader, rank, logger,
             tb_logger, 0, device)

    for epoch in range(start_epoch + 1, args.epochs + 1):
        while True:
            start = time.time()
            finish_epoch = False
            trans_top5, trans_loss, mask_top5, mask_loss = 0.0, 0.0, 0.0, 0.0

            if args.train_reconstruct:
                optimizer.zero_grad(set_to_none=True)
                for _ in range(args.num_grad_accumulate):
                    try:
                        src_sequences, trg_sequences = next(h_iter)
                    except StopIteration:
                        h_sampler.set_epoch(epoch)
                        h_iter = iter(h_loader)
                        src_sequences, trg_sequences = next(h_iter)

                    trg_sequences = trg_sequences.to(device)
                    src_sequences = src_sequences.to(device)
                    non_pad = trg_sequences != args.pad_idx
                    trg_sequences = trg_sequences[non_pad].contiguous().view(
                        -1)

                    with autocast():
                        predicted = model.module.reconstruct_predict(
                            src_sequences, masked_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(
                            predicted,
                            trg_sequences) / args.num_grad_accumulate

                    scaler.scale(loss).backward()

                    if global_step % args.print_freq == 0:
                        mask_top5 += accuracy(predicted, trg_sequences,
                                              5) / args.num_grad_accumulate
                        mask_loss += loss.detach().item()

                for param in model.parameters():
                    if param.grad is not None:
                        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                        param.grad.data = param.grad.data / world_size

                scaler.step(optimizer)
                scaler.update()

            if args.train_translate:
                optimizer.zero_grad(set_to_none=True)
                for _ in range(args.num_grad_accumulate):
                    try:
                        src_sequences, trg_sequences = next(hk_iter)
                    except StopIteration:
                        hk_sampler.set_epoch(epoch)
                        hk_iter = iter(hk_loader)
                        src_sequences, trg_sequences = next(hk_iter)
                        finish_epoch = True

                    trg_sequences = trg_sequences.to(device)
                    trg_sequences_target = trg_sequences[:, 1:]
                    src_sequences = src_sequences.to(device)
                    non_pad = trg_sequences_target != args.pad_idx
                    trg_sequences_target = trg_sequences_target[
                        non_pad].contiguous().view(-1)

                    with autocast():
                        predicted = model(src_sequences,
                                          trg_sequences[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(
                            predicted,
                            trg_sequences_target) / args.num_grad_accumulate

                    scaler.scale(loss).backward()

                    if global_step % args.print_freq == 0:
                        trans_top5 += accuracy(predicted, trg_sequences_target,
                                               5) / args.num_grad_accumulate
                        trans_loss += loss.detach().item()

                for param in model.parameters():
                    if param.grad is not None:
                        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                        param.grad.data = param.grad.data / world_size

                scaler.step(optimizer)
                scaler.update()

            scheduler.step()

            # Print status
            if global_step % args.print_freq == 0:
                if args.train_reconstruct:
                    mask_top5 = torch.cuda.FloatTensor([mask_top5])
                    mask_loss = torch.cuda.FloatTensor([mask_loss])
                    dist.all_reduce(mask_top5, op=dist.ReduceOp.SUM)
                    dist.all_reduce(mask_loss, op=dist.ReduceOp.SUM)
                    mask_top5 = (mask_top5 / world_size).item()
                    mask_loss = (mask_loss / world_size).item()

                if args.train_translate:
                    trans_top5 = torch.cuda.FloatTensor([trans_top5])
                    trans_loss = torch.cuda.FloatTensor([trans_loss])
                    dist.all_reduce(trans_top5, op=dist.ReduceOp.SUM)
                    dist.all_reduce(trans_loss, op=dist.ReduceOp.SUM)
                    trans_top5 = (trans_top5 / world_size).item()
                    trans_loss = (trans_loss / world_size).item()

                if rank == 0:
                    batch_time = time.time() - start
                    write_log(
                        logger,
                        f'[{global_step}/{total_iters}, {epoch}]\tIter time: {batch_time:.3f}\t'
                        f'Trans loss: {trans_loss:.3f}\tMask_loss: {mask_loss:.3f}\t'
                        f'Trans@5: {trans_top5:.3f}\tMask@5: {mask_top5:.3f}')

                    tb_logger.add_scalar('loss/translate', trans_loss,
                                         global_step)
                    tb_logger.add_scalar('loss/mask', mask_loss, global_step)
                    tb_logger.add_scalar('top5/translate', trans_top5,
                                         global_step)
                    tb_logger.add_scalar('top5/mask', mask_top5, global_step)
                    tb_logger.add_scalar('batch/time', batch_time, global_step)
                    tb_logger.add_scalar('batch/lr',
                                         optimizer.param_groups[0]['lr'],
                                         global_step)

            global_step += 1
            if finish_epoch:
                break

        # validation
        validate(model, tgt_mask, h_valid_loader, hk_valid_loader, rank,
                 logger, tb_logger, epoch, device)
        # save model
        if rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'model': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'scaler': scaler.state_dict()
                }, f'{args.save_path}/{args.model_name}_ckpt.pt')
            write_log(logger, f"***** {epoch}th model updated! *****")