Example #1
def prediction(text):
    params = Params('config/params.json')

    # load tokenizer and torchtext Fields (context managers close the files)
    with open('pickles/tokenizer.pickle', 'rb') as pickle_tokenizer:
        cohesion_scores = pickle.load(pickle_tokenizer)
    tokenizer = LTokenizer(scores=cohesion_scores)

    with open('pickles/kor.pickle', 'rb') as pickle_kor:
        kor = pickle.load(pickle_kor)
    with open('pickles/eng.pickle', 'rb') as pickle_eng:
        eng = pickle.load(pickle_eng)
    eos_idx = eng.vocab.stoi['<eos>']

    # select model and load trained model
    model = Transformer(params)
    model.load_state_dict(torch.load(params.save_model))
    model.to(params.device)
    model.eval()

    # convert input into tensor and forward it through selected model
    tokenized = tokenizer.tokenize(text)
    indexed = [kor.vocab.stoi[token] for token in tokenized]

    source = torch.LongTensor(indexed).unsqueeze(0).to(params.device)  # [1, source_len]: unsqueeze to add batch size
    target = torch.zeros(1, params.max_len).type_as(source.data)       # [1, max_len]

    encoder_output = model.encoder(source)
    next_symbol = eng.vocab.stoi['<sos>']

    for i in range(0, params.max_len):
        target[0][i] = next_symbol
        if next_symbol == eos_idx:
            break
        decoder_output, _ = model.decoder(target, source, encoder_output)  # [1, target length, output dim]
        prob = decoder_output.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[i]
        next_symbol = next_word.item()

    # truncate at the first <eos> token; fall back to the full length if none was generated
    eos_positions = torch.where(target[0] == eos_idx)[0]
    eos_pos = eos_positions[0].item() if len(eos_positions) > 0 else params.max_len
    target = target[0][:eos_pos].unsqueeze(0)

    # translation_tensor = [target length] filled with word indices
    target, attention_map = model(source, target)
    target = target.squeeze(0).max(dim=-1)[1]

    # drop the <eos> token when assembling the reply (the hard-coded 3 was the <eos> index in this vocab)
    reply_token = [eng.vocab.itos[token] for token in target if token != eos_idx]
    reply = ' '.join(reply_token)

    #display_attention(tokenized, reply_token, attention_map[4].squeeze(0)[:-1])
    return reply 
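
A minimal way to exercise this function (a sketch: the pickle files, the checkpoint referenced by params.save_model, and the Korean sample sentence are assumptions):

if __name__ == '__main__':
    # hypothetical driver for prediction() above
    print(prediction('저는 학생입니다'))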
Example #2
def predict(config):
    params = Params('config/params.json')

    # load tokenizer and torchtext Fields (context managers close the files)
    with open('pickles/tokenizer.pickle', 'rb') as pickle_tokenizer:
        cohesion_scores = pickle.load(pickle_tokenizer)
    tokenizer = LTokenizer(scores=cohesion_scores)

    with open('pickles/kor.pickle', 'rb') as pickle_kor:
        kor = pickle.load(pickle_kor)

    with open('pickles/eng.pickle', 'rb') as pickle_eng:
        eng = pickle.load(pickle_eng)

    # select model and load trained model
    model = Transformer(params)

    model.load_state_dict(torch.load(params.save_model))
    model.to(params.device)
    model.eval()

    text = clean_text(config.input)  # avoid shadowing the built-in input()

    # convert input into tensor and forward it through selected model
    tokenized = tokenizer.tokenize(text)
    indexed = [kor.vocab.stoi[token] for token in tokenized]

    source = torch.LongTensor(indexed).unsqueeze(0).to(
        params.device)  # [1, source length]: unsqueeze to add batch size
    target = torch.zeros(1, params.max_len).type_as(source.data)

    encoder_output = model.encoder(source)
    next_symbol = eng.vocab.stoi['<sos>']

    for i in range(0, params.max_len):
        target[0][i] = next_symbol
        dec_output = model.decoder(target, source, encoder_output)
        # dec_output = [1, target length, output dim]
        prob = dec_output.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[i]
        next_symbol = next_word.item()

    # translation_tensor = [target length] filled with word indices
    # note: the greedy loop above never breaks early on <eos>, so target is
    # fully filled before this final forward pass
    target = model(source, target)
    target = torch.argmax(target.squeeze(0), -1)
    translation = [eng.vocab.itos[token] for token in target][1:]  # drop the leading <sos>

    translation = ' '.join(translation)
    print(f'kor> {config.input}')
    print(f'eng> {translation.capitalize()}')
Example #3
def make_model(src_vocab,
               tgt_vocab,
               N=6,
               d_model=512,
               d_ff=2048,
               h=8,
               dropout=0.1):
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model).to(args.device)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout).to(args.device)
    position = PositionalEncoding(d_model, dropout).to(args.device)
    model = Transformer(
        Encoder(
            EncoderLayer(d_model, c(attn), c(ff), dropout).to(args.device),
            N).to(args.device),
        Decoder(
            DecoderLayer(d_model, c(attn), c(attn), c(ff),
                         dropout).to(args.device), N).to(args.device),
        nn.Sequential(
            Embeddings(d_model, src_vocab).to(args.device), c(position)),
        nn.Sequential(
            Embeddings(d_model, tgt_vocab).to(args.device), c(position)),
        Generator(d_model, tgt_vocab)).to(args.device)

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model.to(args.device)
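
A hedged usage sketch (the vocab sizes below are placeholders, and the global `args.device` the function reads is assumed to exist):

# hypothetical call; every numeric value here is an assumption for illustration
model = make_model(src_vocab=32000, tgt_vocab=32000, N=6, d_model=512)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'# of parameters: {n_params}')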
Example #4
def training(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(
        logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Dataloader setting
    write_log(logger, "Load data...")
    gc.disable()
    dataset_dict = {
        'train': CustomDataset(data_path=args.preprocessed_path,
                               phase='train'),
        'valid': CustomDataset(data_path=args.preprocessed_path,
                               phase='valid'),
        'test': CustomDataset(data_path=args.preprocessed_path, phase='test')
    }
    unique_menu_count = dataset_dict['train'].unique_count()
    dataloader_dict = {
        'train':
        DataLoader(dataset_dict['train'],
                   drop_last=True,
                   batch_size=args.batch_size,
                   shuffle=True,
                   pin_memory=True,
                   num_workers=args.num_workers,
                   collate_fn=PadCollate()),
        'valid':
        DataLoader(dataset_dict['valid'],
                   drop_last=False,
                   batch_size=args.batch_size,
                   shuffle=False,
                   pin_memory=True,
                   num_workers=args.num_workers,
                   collate_fn=PadCollate()),
        'test':
        DataLoader(dataset_dict['test'],
                   drop_last=False,
                   batch_size=args.batch_size,
                   shuffle=False,
                   pin_memory=True,
                   num_workers=args.num_workers,
                   collate_fn=PadCollate())
    }
    gc.enable()
    write_log(
        logger,
        f"Total number of training samples and iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}"
    )

    #===================================#
    #===========Model setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, "Instantiating models...")
    model = Transformer(model_type=args.model_type,
                        input_size=unique_menu_count,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        dropout=args.dropout)
    model = model.train()
    model = model.to(device)

    # 2) Optimizer setting
    optimizer = optimizer_select(model, args)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    criterion = nn.MSELoss()
    scaler = GradScaler(enabled=True)  # native AMP; paired with autocast() below

    # 2) Model resume
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(os.path.join(args.model_path,
                                             'checkpoint.pth.tar'),
                                map_location='cpu')
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        model = model.train()
        model = model.to(device)
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    best_val_rmse = float('inf')

    write_log(logger, 'Train start!')

    for epoch in range(start_epoch, args.num_epochs):
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                train_start_time = time.time()
                freq = 0
            elif phase == 'valid':
                model.eval()
                val_loss = 0
                val_rmse = 0

            for i, (src_menu, label_lunch,
                    label_supper) in enumerate(dataloader_dict[phase]):

                # Optimizer setting
                optimizer.zero_grad()

                # Input, output setting
                src_menu = src_menu.to(device, non_blocking=True)
                label_lunch = label_lunch.float().to(device, non_blocking=True)
                label_supper = label_supper.float().to(device,
                                                       non_blocking=True)

                # Model
                with torch.set_grad_enabled(phase == 'train'):
                    with autocast(enabled=True):
                        if args.model_type == 'sep':
                            logit = model(src_menu)
                            logit_lunch = logit[:, 0]
                            logit_supper = logit[:, 0]
                        elif args.model_type == 'total':
                            logit = model(src_menu)
                            logit_lunch = logit[:, 0]
                            logit_supper = logit[:, 1]

                    # Loss calculate
                    loss_lunch = criterion(logit_lunch, label_lunch)
                    loss_supper = criterion(logit_supper, label_supper)
                    loss = loss_lunch + loss_supper

                # Back-propagation
                if phase == 'train':
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    # Scheduler setting
                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                # Print loss value
                rmse_loss = torch.sqrt(loss)
                if phase == 'train':
                    if i == 0 or freq == args.print_freq or i == len(
                            dataloader_dict['train']) - 1:
                        batch_log = "[Epoch:%d][%d/%d] train_MSE_loss:%2.3f  | train_RMSE_loss:%2.3f | learning_rate:%3.6f | spend_time:%3.2fmin" \
                                % (epoch+1, i, len(dataloader_dict['train']),
                                loss.item(), rmse_loss.item(), optimizer.param_groups[0]['lr'],
                                (time.time() - train_start_time) / 60)
                        write_log(logger, batch_log)
                        freq = 0
                    freq += 1
                elif phase == 'valid':
                    val_loss += loss.item()
                    val_rmse += rmse_loss.item()

            if phase == 'valid':
                val_loss /= len(dataloader_dict['valid'])
                val_rmse /= len(dataloader_dict['valid'])
                write_log(logger, 'Validation Loss: %3.3f' % val_loss)
                write_log(logger, 'Validation RMSE: %3.3f' % val_rmse)

                if val_rmse < best_val_rmse:
                    write_log(logger, 'Checkpoint saving...')
                    if not os.path.exists(args.save_path):
                        os.mkdir(args.save_path)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'scaler': scaler.state_dict()
                        }, os.path.join(args.save_path, 'checkpoint_cap.pth.tar'))
                    best_val_rmse = val_rmse
                    best_epoch = epoch
                else:
                    else_log = f'Still {best_epoch} epoch RMSE({round(best_val_rmse, 3)}) is better...'
                    write_log(logger, else_log)

    # 3) Print results
    write_log(logger, f'Best Epoch: {best_epoch+1}')
    write_log(logger, f'Best RMSE: {round(best_val_rmse, 3)}')
Example #5
def training(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(
        logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Data open
    write_log(logger, "Load data...")
    gc.disable()
    with open(os.path.join(args.preprocess_path, 'processed.pkl'), 'rb') as f:
        data_ = pickle.load(f)
        train_src_indices = data_['train_src_indices']
        valid_src_indices = data_['valid_src_indices']
        train_trg_indices = data_['train_trg_indices']
        valid_trg_indices = data_['valid_trg_indices']
        src_word2id = data_['src_word2id']
        trg_word2id = data_['trg_word2id']
        src_vocab_num = len(src_word2id)
        trg_vocab_num = len(trg_word2id)
        del data_
    gc.enable()
    write_log(logger, "Finished loading data!")

    # 2) Dataloader setting
    dataset_dict = {
        'train':
        CustomDataset(train_src_indices,
                      train_trg_indices,
                      min_len=args.min_len,
                      src_max_len=args.src_max_len,
                      trg_max_len=args.trg_max_len),
        'valid':
        CustomDataset(valid_src_indices,
                      valid_trg_indices,
                      min_len=args.min_len,
                      src_max_len=args.src_max_len,
                      trg_max_len=args.trg_max_len),
    }
    dataloader_dict = {
        'train':
        DataLoader(dataset_dict['train'],
                   drop_last=True,
                   batch_size=args.batch_size,
                   shuffle=True,
                   pin_memory=True,
                   num_workers=args.num_workers),
        'valid':
        DataLoader(dataset_dict['valid'],
                   drop_last=False,
                   batch_size=args.batch_size,
                   shuffle=False,
                   pin_memory=True,
                   num_workers=args.num_workers)
    }
    write_log(
        logger,
        f"Total number of training samples and iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}"
    )

    #===================================#
    #===========Train setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, 'Instantiating model...')
    model = Transformer(
        src_vocab_num=src_vocab_num,
        trg_vocab_num=trg_vocab_num,
        pad_idx=args.pad_id,
        bos_idx=args.bos_id,
        eos_idx=args.eos_id,
        d_model=args.d_model,
        d_embedding=args.d_embedding,
        n_head=args.n_head,
        dim_feedforward=args.dim_feedforward,
        num_common_layer=args.num_common_layer,
        num_encoder_layer=args.num_encoder_layer,
        num_decoder_layer=args.num_decoder_layer,
        src_max_len=args.src_max_len,
        trg_max_len=args.trg_max_len,
        dropout=args.dropout,
        embedding_dropout=args.embedding_dropout,
        trg_emb_prj_weight_sharing=args.trg_emb_prj_weight_sharing,
        emb_src_trg_weight_sharing=args.emb_src_trg_weight_sharing,
        parallel=args.parallel)
    model.train()
    model = model.to(device)
    tgt_mask = model.generate_square_subsequent_mask(args.trg_max_len - 1,
                                                     device)

    # 2) Optimizer & Learning rate scheduler setting
    optimizer = optimizer_select(model, args)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    scaler = GradScaler()

    # 3) Model resume
    start_epoch = 0
    if args.resume:
        write_log(logger, 'Resume model...')
        checkpoint = torch.load(
            os.path.join(args.save_path, 'checkpoint.pth.tar'))
        start_epoch = checkpoint['epoch']  # the loop below already starts at start_epoch + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        scaler.load_state_dict(checkpoint['scaler'])
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    best_val_acc = 0

    write_log(logger, 'Training start!')

    for epoch in range(start_epoch + 1, args.num_epochs + 1):
        start_time_e = time()
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                freq = 0
            if phase == 'valid':
                write_log(logger, 'Validation start...')
                val_loss = 0
                val_acc = 0
                model.eval()
            for i, (src, trg) in enumerate(
                    tqdm(dataloader_dict[phase],
                         bar_format='{l_bar}{bar:30}{r_bar}{bar:-2b}')):

                # Optimizer setting
                optimizer.zero_grad(set_to_none=True)

                # Input, output setting
                src = src.to(device, non_blocking=True)
                trg = trg.to(device, non_blocking=True)

                trg_sequences_target = trg[:, 1:]
                non_pad = trg_sequences_target != args.pad_id
                trg_sequences_target = trg_sequences_target[
                    non_pad].contiguous().view(-1)

                # Train
                if phase == 'train':

                    # Loss calculate
                    with autocast():
                        predicted = model(src,
                                          trg[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(predicted,
                                                    trg_sequences_target,
                                                    args.pad_id)

                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                    # Print loss value (training phase only)
                    if i == 0 or freq == args.print_freq or i == len(
                            dataloader_dict['train']) - 1:
                        acc = (predicted.max(dim=1)[1] == trg_sequences_target
                               ).sum() / len(trg_sequences_target)
                        iter_log = "[Epoch:%03d][%03d/%03d] train_loss:%03.3f | train_acc:%03.2f%% | learning_rate:%1.6f | spend_time:%02.2fmin" % \
                            (epoch, i, len(dataloader_dict['train']),
                            loss.item(), acc*100, optimizer.param_groups[0]['lr'],
                            (time() - start_time_e) / 60)
                        write_log(logger, iter_log)
                        freq = 0
                    freq += 1

                # Validation
                if phase == 'valid':
                    with torch.no_grad():
                        predicted = model(src,
                                          trg[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        loss = F.cross_entropy(predicted, trg_sequences_target)
                    val_loss += loss.item()
                    val_acc += (predicted.max(dim=1)[1] == trg_sequences_target
                                ).sum() / len(trg_sequences_target)
                    if args.scheduler == 'reduce_valid':
                        scheduler.step(val_loss)
                    if args.scheduler == 'lambda':
                        scheduler.step()

            if phase == 'valid':
                val_loss /= len(dataloader_dict[phase])
                val_acc /= len(dataloader_dict[phase])
                write_log(logger, 'Validation Loss: %3.3f' % val_loss)
                write_log(logger,
                          'Validation Accuracy: %3.2f%%' % (val_acc * 100))
                if val_acc > best_val_acc:
                    write_log(logger, 'Checkpoint saving...')
                    torch.save(
                        {
                            'epoch': epoch,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'scaler': scaler.state_dict()
                        }, f'checkpoint_{args.parallel}.pth.tar')
                    best_val_acc = val_acc
                    best_epoch = epoch
                else:
                    else_log = f'Still {best_epoch} epoch accuracy ({round(best_val_acc.item()*100, 2)}%) is better...'
                    write_log(logger, else_log)

    # 3) Print results
    print(f'Best Epoch: {best_epoch}')
    print(f'Best Accuracy: {round(best_val_acc.item() * 100, 2)}%')
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--problem', required=True)
    parser.add_argument('--train_step', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=4096)
    parser.add_argument('--max_length', type=int, default=100)
    parser.add_argument('--n_layers', type=int, default=6)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--filter_size', type=int, default=2048)
    parser.add_argument('--warmup', type=int, default=16000)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--label_smoothing', type=float, default=0.1)
    parser.add_argument('--val_every', type=int, default=5)
    parser.add_argument('--output_dir', type=str, default='./output')
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--summary_grad', action='store_true')
    opt = parser.parse_args()

    device = torch.device('cpu' if opt.no_cuda else 'cuda')

    if not os.path.exists(opt.output_dir + '/last/models'):
        os.makedirs(opt.output_dir + '/last/models')
    if not os.path.exists(opt.data_dir):
        os.makedirs(opt.data_dir)

    train_data, validation_data, i_vocab_size, t_vocab_size, opt = \
        problem.prepare(opt.problem, opt.data_dir, opt.max_length,
                        opt.batch_size, device, opt)
    if i_vocab_size is not None:
        print("# of vocabs (input):", i_vocab_size)
    print("# of vocabs (target):", t_vocab_size)

    if os.path.exists(opt.output_dir + '/last/models/last_model.pt'):
        print("Load a checkpoint...")
        last_model_path = opt.output_dir + '/last/models'
        model, global_step = utils.load_checkpoint(last_model_path, device,
                                                   is_eval=False)
    else:
        model = Transformer(i_vocab_size, t_vocab_size,
                            n_layers=opt.n_layers,
                            hidden_size=opt.hidden_size,
                            filter_size=opt.filter_size,
                            dropout_rate=opt.dropout,
                            share_target_embedding=opt.share_target_embedding,
                            has_inputs=opt.has_inputs)
        model = model.to(device=device)
        global_step = 0

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("# of parameters: {}".format(num_params))

    optimizer = LRScheduler(
        filter(lambda x: x.requires_grad, model.parameters()),
        opt.hidden_size, opt.warmup, step=global_step)

    writer = SummaryWriter(opt.output_dir + '/last')
    val_writer = SummaryWriter(opt.output_dir + '/last/val')
    best_val_loss = float('inf')

    for t_step in range(opt.train_step):
        print("Epoch", t_step)
        start_epoch_time = time.time()
        global_step = train(train_data, model, opt, global_step,
                            optimizer, t_vocab_size, opt.label_smoothing,
                            writer)
        print("Epoch Time: {:.2f} sec".format(time.time() - start_epoch_time))

        if t_step % opt.val_every != 0:
            continue

        val_loss = validation(validation_data, model, global_step,
                              t_vocab_size, val_writer, opt)
        utils.save_checkpoint(model, opt.output_dir + '/last/models',
                              global_step, val_loss < best_val_loss)
        best_val_loss = min(val_loss, best_val_loss)
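
A hedged invocation sketch for this script (the problem name below is an assumption):

# hypothetical command line
# python train.py --problem wmt32k --batch_size 4096 --output_dir ./output

Example #7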
def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(logging.Formatter(" %(asctime)s - %(message)s"))
    logger.addHandler(handler)
    logger.propagate = False

    write_log(logger, "Load data")

    def load_data(args):
        gc.disable()
        with open(f"{args.preprocessed_data_path}/hanja_korean_word2id.pkl",
                  "rb") as f:
            data = pickle.load(f)
            hanja_word2id = data['hanja_word2id']
            korean_word2id = data['korean_word2id']

        with open(f"{args.preprocessed_data_path}/preprocessed_test.pkl",
                  "rb") as f:
            data = pickle.load(f)
            test_hanja_indices = data['hanja_indices']
            test_korean_indices = data['korean_indices']

        gc.enable()
        write_log(logger, "Finished loading data!")
        return hanja_word2id, korean_word2id, test_hanja_indices, test_korean_indices

    hanja_word2id, korean_word2id, test_hanja_indices, test_korean_indices = load_data(
        args)
    hanja_vocab_num = len(hanja_word2id)
    korean_vocab_num = len(korean_word2id)

    hk_dataset = HanjaKoreanDataset(test_hanja_indices,
                                    test_korean_indices,
                                    min_len=args.min_len,
                                    src_max_len=args.src_max_len,
                                    trg_max_len=args.trg_max_len)
    hk_loader = DataLoader(hk_dataset,
                           drop_last=True,
                           batch_size=args.hk_batch_size,
                           num_workers=4,
                           prefetch_factor=4,
                           pin_memory=True)
    write_log(logger, f"hanja-korean: {len(hk_dataset)}, {len(hk_loader)}")
    del test_hanja_indices, test_korean_indices

    write_log(logger, "Build model")
    model = Transformer(hanja_vocab_num,
                        korean_vocab_num,
                        pad_idx=args.pad_idx,
                        bos_idx=args.bos_idx,
                        eos_idx=args.eos_idx,
                        src_max_len=args.src_max_len,
                        trg_max_len=args.trg_max_len,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        num_decoder_layer=args.num_decoder_layer,
                        num_mask_layer=args.num_mask_layer)

    model.load_state_dict(
        torch.load(args.checkpoint_path, map_location=device)['model'])
    model.src_output_linear = None
    model.src_output_linear2 = None
    model.src_output_norm = None
    model.mask_encoders = None
    model = model.to(device)
    model.eval()

    write_log(logger, "Load SentencePiece model")
    parser = spm.SentencePieceProcessor()
    parser.Load(os.path.join(args.preprocessed_data_path, 'm_korean.model'))

    predicted_list = list()
    label_list = list()
    every_batch = torch.arange(0,
                               args.beam_size * args.hk_batch_size,
                               args.beam_size,
                               device=device)
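    # precompute one causal (square subsequent) mask per possible target length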
    tgt_masks = {
        l: model.generate_square_subsequent_mask(l, device)
        for l in range(1, args.trg_max_len + 1)
    }

    with torch.no_grad():
        for src_sequences, trg_sequences in tqdm(hk_loader):
            src_sequences = src_sequences.to(device)
            label_list.extend(trg_sequences.tolist())

            # Encoding
            # encoder_out: (src_seq, batch_size, d_model)
            # src_key_padding_mask: (batch_size, src_seq)
            encoder_out = model.src_embedding(src_sequences).transpose(0, 1)
            src_key_padding_mask = (src_sequences == model.pad_idx)
            for encoder in model.encoders:
                encoder_out = encoder(
                    encoder_out, src_key_padding_mask=src_key_padding_mask)

            # Expanding
            # encoder_out: (src_seq, batch_size * k, d_model)
            # src_key_padding_mask: (batch_size * k, src_seq)
            src_seq_size = encoder_out.size(0)
            src_key_padding_mask = src_key_padding_mask.view(
                args.hk_batch_size, 1, -1).repeat(1, args.beam_size, 1)
            src_key_padding_mask = src_key_padding_mask.view(-1, src_seq_size)
            encoder_out = encoder_out.view(-1, args.hk_batch_size, 1,
                                           args.d_model).repeat(
                                               1, 1, args.beam_size, 1)
            encoder_out = encoder_out.view(src_seq_size, -1, args.d_model)

            # Scores save vector & decoding list setting
            scores_save = torch.zeros(args.beam_size * args.hk_batch_size,
                                      1,
                                      device=device)
            top_k_scores = torch.zeros(args.beam_size * args.hk_batch_size,
                                       1,
                                       device=device)
            complete_seqs = dict()
            complete_ind = set()

            # Decoding start token setting
            seqs = torch.tensor([[model.bos_idx]],
                                dtype=torch.long,
                                device=device)
            seqs = seqs.repeat(args.beam_size * args.hk_batch_size,
                               1).contiguous()

            for step in range(model.trg_max_len):
                # Decoder setting
                # tgt_mask: (out_seq)
                # tgt_key_padding_mask: (batch_size * k, out_seq)
                tgt_mask = tgt_masks[seqs.size(1)]
                tgt_key_padding_mask = (seqs == model.pad_idx)

                # Decoding sentence
                # decoder_out: (out_seq, batch_size * k, d_model)
                decoder_out = model.trg_embedding(seqs).transpose(0, 1)
                for decoder in model.decoders:
                    decoder_out = decoder(
                        decoder_out,
                        encoder_out,
                        tgt_mask=tgt_mask,
                        memory_key_padding_mask=src_key_padding_mask,
                        tgt_key_padding_mask=tgt_key_padding_mask)

                # Score calculate
                # scores: (batch_size * k, vocab_num)
                scores = F.gelu(model.trg_output_linear(decoder_out[-1]))
                scores = model.trg_output_linear2(
                    model.trg_output_norm(scores))
                scores = F.log_softmax(scores, dim=1)

                # Repetition Penalty
                if step > 0 and args.repetition_penalty > 0:
                    prev_ix = next_word_inds.view(-1)
                    for index, prev_token_id in enumerate(prev_ix):
                        scores[index][prev_token_id] *= args.repetition_penalty

                # Add score
                scores = top_k_scores.expand_as(scores) + scores
                if step == 0:
                    # scores: (batch_size, vocab_num)
                    # top_k_scores: (batch_size, k)
                    scores = scores[::args.beam_size]
                    # forbid <eos> as the very first generated token
                    scores[:, model.eos_idx] = float('-inf')
                    top_k_scores, top_k_words = scores.topk(
                        args.beam_size, 1, True, True)
                else:
                    # top_k_scores: (batch_size * k, out_seq)
                    top_k_scores, top_k_words = scores.view(
                        args.hk_batch_size, -1).topk(args.beam_size, 1, True,
                                                     True)

                # Previous and Next word extract
                # seqs: (batch_size * k, out_seq + 1)
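                # top_k_words indexes the flattened (beam * vocab) score matrix,
                # so floor division recovers the source beam and modulo the token id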
                prev_word_inds = top_k_words // korean_vocab_num
                next_word_inds = top_k_words % korean_vocab_num
                top_k_scores = top_k_scores.view(
                    args.hk_batch_size * args.beam_size, -1)
                top_k_words = top_k_words.view(
                    args.hk_batch_size * args.beam_size, -1)
                seqs = seqs[prev_word_inds.view(-1) + every_batch.unsqueeze(
                    1).repeat(1, args.beam_size).view(-1)]
                seqs = torch.cat([
                    seqs,
                    next_word_inds.view(args.beam_size * args.hk_batch_size,
                                        -1)
                ],
                                 dim=1)

                # Find and Save Complete Sequences Score
                eos_ind = torch.where(
                    next_word_inds.view(-1) == model.eos_idx)[0]
                if len(eos_ind) > 0:
                    eos_ind = eos_ind.tolist()
                    complete_ind_add = set(eos_ind) - complete_ind
                    complete_ind_add = list(complete_ind_add)
                    complete_ind.update(eos_ind)
                    if len(complete_ind_add) > 0:
                        scores_save[complete_ind_add] = top_k_scores[
                            complete_ind_add]
                        for ix in complete_ind_add:
                            complete_seqs[ix] = seqs[ix].tolist()

            # If eos token doesn't exist in sequence
            score_save_pos = torch.where(scores_save == 0)
            if len(score_save_pos[0]) > 0:
                for ix in score_save_pos[0].tolist():
                    complete_seqs[ix] = seqs[ix].tolist()
                scores_save[score_save_pos] = top_k_scores[score_save_pos]

            # Beam Length Normalization
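            # GNMT-style length penalty, except this variant reuses beam_size as
            # the additive constant: lp = ((len + k) ** alpha) / ((k + 1) ** alpha)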
            lp = torch.tensor([
                len(complete_seqs[i])
                for i in range(args.hk_batch_size * args.beam_size)
            ],
                              device=device)
            lp = (((lp + args.beam_size)**args.beam_alpha) /
                  ((args.beam_size + 1)**args.beam_alpha))
            scores_save = scores_save / lp.unsqueeze(1)

            # Predicted and Label processing
            ind = scores_save.view(args.hk_batch_size, args.beam_size,
                                   -1).argmax(dim=1)
            ind_expand = ind.view(-1) + every_batch
            predicted_list.extend(
                [complete_seqs[i] for i in ind_expand.tolist()])

    with open(
            f'./results_beam_{args.beam_size}_{args.beam_alpha}_{args.repetition_penalty}.pkl',
            'wb') as f:
        pickle.dump(
            {
                'prediction':
                predicted_list,
                'label':
                label_list,
                'prediction_decode':
                [parser.DecodeIds(pred) for pred in predicted_list],
                'label_decode':
                [parser.DecodeIds(label) for label in label_list]
            }, f)
Example #8
def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_data(args):
        gc.disable()
        with open(f"{args.preprocessed_data_path}/hanja_korean_word2id.pkl",
                  "rb") as f:
            data = pickle.load(f)
            hanja_word2id = data['hanja_word2id']
            korean_word2id = data['korean_word2id']

        with open(f"{args.preprocessed_data_path}/preprocessed_test.pkl",
                  "rb") as f:
            data = pickle.load(f)
            test_hanja_indices = data['hanja_indices']
            test_additional_hanja_indices = data['additional_hanja_indices']

        gc.enable()
        return hanja_word2id, korean_word2id, test_hanja_indices, test_additional_hanja_indices

    hanja_word2id, korean_word2id, hanja_indices, additional_hanja_indices = load_data(
        args)
    hanja_vocab_num = len(hanja_word2id)
    korean_vocab_num = len(korean_word2id)

    print('Loader and Model Setting...')
    h_dataset = HanjaDataset(hanja_indices,
                             additional_hanja_indices,
                             hanja_word2id,
                             min_len=args.min_len,
                             src_max_len=args.src_max_len)
    h_loader = DataLoader(h_dataset,
                          drop_last=True,
                          batch_size=args.batch_size,
                          num_workers=4,
                          prefetch_factor=4)

    model = Transformer(hanja_vocab_num,
                        korean_vocab_num,
                        pad_idx=args.pad_idx,
                        bos_idx=args.bos_idx,
                        eos_idx=args.eos_idx,
                        src_max_len=args.src_max_len,
                        trg_max_len=args.trg_max_len,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        num_decoder_layer=args.num_decoder_layer,
                        num_mask_layer=args.num_mask_layer)

    model.load_state_dict(
        torch.load(args.checkpoint_path, map_location='cpu')['model'])
    model.decoders = None
    model.trg_embedding = None
    model.trg_output_linear = None
    model.trg_output_linear2 = None
    model.trg_output_norm = None
    model = model.to(device)
    model.eval()

    masking_acc = defaultdict(float)

    with torch.no_grad():
        for inputs, labels in h_loader:
            # Setting
            inputs = inputs.to(device)
            labels = labels.to(device)
            masked_position = labels != args.pad_idx
            masked_labels = labels[masked_position].contiguous().view(
                -1).unsqueeze(1)
            total_mask_count = masked_labels.size(0)

            # Prediction, output: Batch * Length * Vocab
            pred = model.reconstruct_predict(inputs,
                                             masked_position=masked_position)
            _, pred = pred.topk(10, 1, True, True)

            # Top1, 5, 10
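            # broadcasting (N, 1) labels against the (N, k) top-k predictions;
            # each row matches at most once, so the sum is a plain hit count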
            masking_acc[1] += (torch.sum(
                masked_labels == pred[:, :1]).item()) / total_mask_count
            masking_acc[5] += (torch.sum(
                masked_labels == pred[:, :5]).item()) / total_mask_count
            masking_acc[10] += (torch.sum(
                masked_labels == pred).item()) / total_mask_count

    for key in masking_acc.keys():
        masking_acc[key] /= len(h_loader)

    for key, value in masking_acc.items():
        print(f'Top {key} Accuracy: {value:.4f}')

    with open('./mask_result.pkl', 'wb') as f:
        pickle.dump(masking_acc, f)
Example #9
File: train.py Project: Jwoo5/temp
    t_total = len(train_loader) * args.epoch
    
    lr = args.lr
    params = []
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                # biases: double learning rate, no weight decay
                params += [{'params': [value], 'lr': lr * 2,
                            'weight_decay': 0}]
            else:
                # weights: base learning rate with weight decay
                params += [{'params': [value], 'lr': lr,
                            'weight_decay': 0.0005}]
    
    lr = lr * 0.1
    optimizer = torch.optim.Adam(params)
    model.to(config.device)
    wandb.watch(model)    

    # torch.autograd.Variable has been a no-op since PyTorch 0.4; plain tensors suffice
    inputs = torch.FloatTensor(1).to(config.device)
    label = torch.LongTensor(1).to(config.device)

    best_epoch, best_loss, best_score = 0, 0, 0
    if os.path.isfile(args.save):
        best_epoch, best_loss, best_score = model.load(args.save)
        print(f"rank: {config.device} load state dict from: {args.save}")

    offset = best_epoch
Example #10
def main(proc_id, args):
    trg_sp = spm.SentencePieceProcessor()
    trg_sp.Load(args.spm_trg_path)
    trg_vocab_num = trg_sp.piece_size()
    bos_id = trg_sp.bos_id()
    eos_id = trg_sp.eos_id()
    pad_id = trg_sp.pad_id()
    src_vocab = requests.get(f'{args.api_url}/getMetaData').json()['src_vocab']
    unk_id = src_vocab['<unk>']

    device = torch.device(f"cuda:{proc_id}")
    model = Transformer(len(src_vocab),
                        trg_vocab_num,
                        pad_idx=pad_id,
                        bos_idx=bos_id,
                        eos_idx=eos_id,
                        src_max_len=args.src_max_len,
                        trg_max_len=args.trg_max_len,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        num_decoder_layer=args.num_decoder_layer,
                        num_mask_layer=args.num_mask_layer)

    model.load_state_dict(
        torch.load(args.checkpoint_path, map_location=device)['model'])
    model.src_output_linear = None
    model.src_output_linear2 = None
    model.src_output_norm = None
    model.mask_encoders = None
    model = model.to(device)
    model = model.eval()

    tgt_masks = {
        l: model.generate_square_subsequent_mask(l, device)
        for l in range(1, args.trg_max_len + 1)
    }

    while True:
        data = requests.get(f'{args.api_url}/getData').json()
        pred_data = {'file': data['file'], 'content': []}
        parsed_ids = []
        for d in data['content']:
            parsed_id = [src_vocab.get(c, unk_id) for c in d['hanja']]
            if args.min_len <= len(parsed_id) <= args.src_max_len:
                input_id = np.zeros(args.src_max_len, dtype=np.int64)
                input_id[:len(parsed_id)] = parsed_id
                parsed_ids.append(input_id)
                pred_data['content'].append(d)

        num_iter = ceil(len(parsed_ids) / args.batch_size)
        batch_size_ = args.batch_size
        predicted_num = 0

        with torch.no_grad():
            batch_indices = torch.arange(0,
                                         args.beam_size * args.batch_size,
                                         args.beam_size,
                                         device=device)
            for iter_ in range(num_iter):
                iter_time = time()
                src_sequences = parsed_ids[iter_ *
                                           args.batch_size:(iter_ + 1) *
                                           args.batch_size]

                scores_save = torch.zeros(args.beam_size * args.batch_size,
                                          1,
                                          device=device)
                top_k_scores = torch.zeros(args.beam_size * args.batch_size,
                                           1,
                                           device=device)
                complete_seqs = dict()
                complete_ind = set()
                if len(src_sequences) < args.batch_size:
                    batch_size_ = len(src_sequences)
                    batch_indices = torch.arange(0,
                                                 args.beam_size * batch_size_,
                                                 args.beam_size,
                                                 device=device)
                    scores_save = torch.zeros(args.beam_size * batch_size_,
                                              1,
                                              device=device)
                    top_k_scores = torch.zeros(args.beam_size * batch_size_,
                                               1,
                                               device=device)

                src_sequences = torch.cat([
                    torch.as_tensor(seq, dtype=torch.long, device=device)
                    for seq in src_sequences
                ])
                src_sequences = src_sequences.view(batch_size_,
                                                   args.src_max_len)

                # Encoding
                # encoder_out: (src_seq, batch_size, d_model), src_key_padding_mask: (batch_size, src_seq)
                encoder_out = model.src_embedding(src_sequences).transpose(
                    0, 1)
                src_key_padding_mask = (src_sequences == pad_id)
                for encoder in model.encoders:
                    encoder_out = encoder(
                        encoder_out, src_key_padding_mask=src_key_padding_mask)

                # Expanding
                # encoder_out: (src_seq, batch_size*k, d_model), src_key_padding_mask: (batch_size*k, src_seq)
                src_seq_size = encoder_out.size(0)
                src_key_padding_mask = src_key_padding_mask.view(
                    batch_size_, 1, -1).repeat(1, args.beam_size, 1)
                src_key_padding_mask = src_key_padding_mask.view(
                    -1, src_seq_size)
                encoder_out = encoder_out.view(-1, batch_size_, 1,
                                               args.d_model).repeat(
                                                   1, 1, args.beam_size, 1)
                encoder_out = encoder_out.view(src_seq_size, -1, args.d_model)

                # Decoding start token setting
                seqs = torch.tensor([[bos_id]],
                                    dtype=torch.long,
                                    device=device)
                seqs = seqs.repeat(args.beam_size * batch_size_,
                                   1).contiguous()

                for step in range(model.trg_max_len):
                    # Decoder setting
                    # tgt_mask: (out_seq), tgt_key_padding_mask: (batch_size * k, out_seq)
                    tgt_mask = tgt_masks[seqs.size(1)]
                    tgt_key_padding_mask = (seqs == pad_id)

                    # Decoding sentence
                    # decoder_out: (out_seq, batch_size * k, d_model)
                    decoder_out = model.trg_embedding(seqs).transpose(0, 1)
                    for decoder in model.decoders:
                        decoder_out = decoder(
                            decoder_out,
                            encoder_out,
                            tgt_mask=tgt_mask,
                            memory_key_padding_mask=src_key_padding_mask,
                            tgt_key_padding_mask=tgt_key_padding_mask)

                    # Score calculate
                    # scores: (batch_size * k, vocab_num)
                    scores = F.gelu(model.trg_output_linear(decoder_out[-1]))
                    scores = model.trg_output_linear2(
                        model.trg_output_norm(scores))
                    scores = F.log_softmax(scores, dim=1)

                    # Repetition Penalty
                    if step > 0 and args.repetition_penalty > 0:
                        prev_ix = next_word_inds.view(-1)
                        for index, prev_token_id in enumerate(prev_ix):
                            scores[index][
                                prev_token_id] *= args.repetition_penalty

                    # Add score
                    scores = top_k_scores.expand_as(scores) + scores
                    if step == 0:
                        # scores: (batch_size, vocab_num)
                        # top_k_scores: (batch_size, k)
                        scores = scores[::args.beam_size]
                        # forbid <eos> as the very first generated token
                        scores[:, eos_id] = float('-inf')
                        top_k_scores, top_k_words = scores.topk(
                            args.beam_size, 1, True, True)
                    else:
                        # top_k_scores: (batch_size * k, out_seq)
                        top_k_scores, top_k_words = scores.view(
                            batch_size_, -1).topk(args.beam_size, 1, True,
                                                  True)

                    # Previous and Next word extract
                    # seqs: (batch_size * k, out_seq + 1)
                    prev_word_inds = top_k_words // trg_vocab_num
                    next_word_inds = top_k_words % trg_vocab_num
                    top_k_scores = top_k_scores.view(
                        batch_size_ * args.beam_size, -1)
                    top_k_words = top_k_words.view(
                        batch_size_ * args.beam_size, -1)
                    seqs = seqs[prev_word_inds.view(-1) +
                                batch_indices.unsqueeze(1).repeat(
                                    1, args.beam_size).view(-1)]
                    seqs = torch.cat([
                        seqs,
                        next_word_inds.view(args.beam_size * batch_size_, -1)
                    ],
                                     dim=1)

                    # Find and Save Complete Sequences Score
                    eos_ind = torch.where(next_word_inds.view(-1) == eos_id)[0]
                    if len(eos_ind) > 0:
                        eos_ind = eos_ind.tolist()
                        complete_ind_add = set(eos_ind) - complete_ind
                        complete_ind_add = list(complete_ind_add)
                        complete_ind.update(eos_ind)
                        if len(complete_ind_add) > 0:
                            scores_save[complete_ind_add] = top_k_scores[
                                complete_ind_add]
                            for ix in complete_ind_add:
                                complete_seqs[ix] = seqs[ix].tolist()

                # If eos token doesn't exist in sequence
                score_save_pos = torch.where(scores_save == 0)
                if len(score_save_pos[0]) > 0:
                    for ix in score_save_pos[0].tolist():
                        complete_seqs[ix] = seqs[ix].tolist()
                    scores_save[score_save_pos] = top_k_scores[score_save_pos]

                # Beam Length Normalization
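                # same GNMT-style length penalty as above, with beam_size reused
                # as the additive constant in the exponentiated ratio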
                lp = torch.tensor([
                    len(complete_seqs[i])
                    for i in range(batch_size_ * args.beam_size)
                ],
                                  device=device)
                lp = (((lp + args.beam_size)**args.beam_alpha) /
                      ((args.beam_size + 1)**args.beam_alpha))
                scores_save = scores_save / lp.unsqueeze(1)

                # Predicted and Label processing
                ind = scores_save.view(batch_size_, args.beam_size,
                                       -1).argmax(dim=1)
                ind = (ind.view(-1) + batch_indices).tolist()
                for i in ind:
                    predicted_sequence = trg_sp.decode_ids(complete_seqs[i])
                    pred_data['content'][predicted_num][
                        'predicted_sequence'] = predicted_sequence
                    predicted_num += 1

                iter_time = time() - iter_time
                print(
                    f"{proc_id} - iter: {iter_ + 1}/{num_iter}, {iter_time:.2f}"
                )

        res = requests.post(f'{args.api_url}/commitData',
                            json=pred_data).json()
        print(f"{proc_id} - Progress: {res['progress']}, {pred_data['file']}")
        if res['progress'] == 'finish':
            return
Example #11
def build_model(
    opts,
    source_vocab_size: int,
    target_vocab_size: int,
    source_pad_id: int,
    target_sos_id: int,
    target_eos_id: int,
    target_pad_id: int,
    device: torch.device,
):
    """ Builds our Transformer model

        Parameters
        ----------
        opts
            The options from the command line
        source_vocab_size : int
            The vocab size in the source language
        target_vocab_size : int
            The vocab size in the target language
        source_pad_id : int
            The ID of a padding token in the source language
        target_sos_id : int
            The ID of a start-of-sequence token in the target language
        target_eos_id : int
            The ID of an end-of-sequence token in the target language
        target_pad_id : int
            The ID of a padding token in the target language
        device : torch.device
            The device to run the model on
            
        Returns
        -------
        model : Transformer
            The model
    """

    encoder = Encoder(
        source_vocab_size,
        opts.source_word_embedding_size,
        opts.encoder_num_layers,
        opts.encoder_num_attention_heads,
        opts.encoder_pf_size,
        opts.encoder_dropout,
        device,
    )

    decoder = Decoder(
        target_vocab_size,
        opts.target_word_embedding_size,
        opts.decoder_num_layers,
        opts.decoder_num_attention_heads,
        opts.decoder_pf_size,
        opts.decoder_dropout,
        device,
    )

    model = Transformer(
        encoder,
        decoder,
        source_pad_id,
        target_sos_id,
        target_eos_id,
        target_pad_id,
        device,
    )
    model.to(device)

    return model
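
A minimal sketch of calling build_model (the `opts` namespace and all numeric IDs below are hypothetical placeholders following the docstring):

# hypothetical wiring; field values are assumptions for illustration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = build_model(opts,
                    source_vocab_size=16000, target_vocab_size=16000,
                    source_pad_id=0, target_sos_id=1,
                    target_eos_id=2, target_pad_id=0,
                    device=device)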
Example #12
class Trainer:
    def __init__(self,
                 params,
                 mode,
                 train_iter=None,
                 valid_iter=None,
                 test_iter=None):
        self.params = params

        # Train mode
        if mode == 'train':
            self.train_iter = train_iter
            self.valid_iter = valid_iter

        # Test mode
        else:
            self.test_iter = test_iter

        self.model = Transformer(self.params)
        self.model.to(self.params.device)

        # Scheduling optimizer
        self.optimizer = ScheduledAdam(optim.Adam(self.model.parameters(),
                                                  betas=(0.9, 0.98),
                                                  eps=1e-9),
                                       hidden_dim=params.hidden_dim,
                                       warm_steps=params.warm_steps)

        self.criterion = nn.CrossEntropyLoss(ignore_index=self.params.pad_idx)
        self.criterion.to(self.params.device)

    def train(self):
        print(self.model)
        print(
            f'The model has {self.model.count_params():,} trainable parameters'
        )
        best_valid_loss = float('inf')

        for epoch in range(self.params.num_epoch):
            self.model.train()
            epoch_loss = 0
            start_time = time.time()

            for batch in self.train_iter:
                # For each batch, first zero the gradients
                self.optimizer.zero_grad()
                source = batch.kor
                target = batch.eng

                # target sentence consists of <sos> and following tokens (except the <eos> token)
                output = self.model(source, target[:, :-1])[0]

                # ground truth sentence consists of tokens and <eos> token (except the <sos> token)
                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)
                # output = [(batch size * target length - 1), output dim]
                # target = [(batch size * target length - 1)]
                loss = self.criterion(output, target)
                loss.backward()

                # clip the gradients to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.params.clip)

                self.optimizer.step()

                # .item() extracts the Python scalar from a single-element tensor
                epoch_loss += loss.item()

            train_loss = epoch_loss / len(self.train_iter)
            valid_loss = self.evaluate()

            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(self.model.state_dict(), self.params.save_model)

            print(
                f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s'
            )
            print(
                f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}'
            )

    def evaluate(self):
        self.model.eval()
        epoch_loss = 0

        with torch.no_grad():
            for batch in self.valid_iter:
                source = batch.kor
                target = batch.eng

                output = self.model(source, target[:, :-1])[0]

                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)

                loss = self.criterion(output, target)

                epoch_loss += loss.item()

        return epoch_loss / len(self.valid_iter)

    def inference(self):
        self.model.load_state_dict(torch.load(self.params.save_model))
        self.model.eval()
        epoch_loss = 0

        with torch.no_grad():
            for batch in self.test_iter:
                source = batch.kor
                target = batch.eng

                output = self.model(source, target[:, :-1])[0]

                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)

                loss = self.criterion(output, target)

                epoch_loss += loss.item()

        test_loss = epoch_loss / len(self.test_iter)
        print(f'Test Loss: {test_loss:.3f}')
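
A hedged sketch of driving this Trainer (the params file and the torchtext iterators are assumptions carried over from the earlier examples):

# hypothetical wiring; train_iter and valid_iter come from the project's data pipeline
params = Params('config/params.json')
trainer = Trainer(params, mode='train', train_iter=train_iter, valid_iter=valid_iter)
trainer.train()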