def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)
    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
    )

    tokenizer = get_tokenizer(dataset_config, model_config)

    checkpoint_manager = CheckpointManager(exp_dir)
    checkpoint = checkpoint_manager.load_checkpoint("best.tar")
    model = CharCNN(num_classes=model_config.num_classes,
                    embedding_dim=model_config.embedding_dim,
                    vocab=tokenizer.vocab)
    model.load_state_dict(checkpoint["model_state_dict"])

    summary_manager = SummaryManager(exp_dir)
    filepath = getattr(dataset_config, args.data)
    ds = Corpus(filepath, tokenizer.split_and_transform)
    dl = DataLoader(ds, batch_size=args.batch_size, num_workers=4)

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    summary = evaluate(model, dl, {
        "loss": nn.CrossEntropyLoss(),
        "acc": acc
    }, device)

    summary_manager.load("summary.json")
    summary_manager.update({f"{args.data}": summary})
    summary_manager.save("summary.json")
    print(f"loss: {summary['loss']:.3f}, acc: {summary['acc']:.2%}")
Example #2
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
    )

    preprocessor = get_preprocessor(dataset_config, coarse_split_fn=split_morphs, fine_split_fn=split_jamos)

    # model (restore)
    checkpoint_manager = CheckpointManager(exp_dir)
    checkpoint = checkpoint_manager.load_checkpoint("best.tar")
    model = SAN(model_config.num_classes, preprocessor.coarse_vocab, preprocessor.fine_vocab,
                model_config.fine_embedding_dim, model_config.hidden_dim, model_config.multi_step,
                model_config.prediction_drop_ratio)
    model.load_state_dict(checkpoint["model_state_dict"])

    # evaluation
    filepath = getattr(dataset_config, args.data)
    ds = Corpus(filepath, preprocessor.preprocess)
    dl = DataLoader(ds, batch_size=args.batch_size, num_workers=4, collate_fn=batchify)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    summary_manager = SummaryManager(exp_dir)
    summary = evaluate(model, dl, {"loss": log_loss, "acc": acc}, device)

    summary_manager.load("summary.json")
    summary_manager.update({f"{args.data}": summary})
    summary_manager.save("summary.json")

    print(f"loss: {summary['loss']:.3f}, acc: {summary['acc']:.2%}")
Example #3
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
        f"_teacher_forcing_ratio_{args.teacher_forcing_ratio}")

    src_processor, tgt_processor = get_processor(dataset_config)

    # model (restore)
    encoder = BidiEncoder(src_processor.vocab, model_config.encoder_hidden_dim,
                          model_config.drop_ratio)
    decoder = AttnDecoder(
        tgt_processor.vocab,
        model_config.method,
        model_config.encoder_hidden_dim * 2,
        model_config.decoder_hidden_dim,
        model_config.drop_ratio,
    )

    checkpoint_manager = CheckpointManager(exp_dir)
    checkpoint = checkpoint_manager.load_checkpoint("best.tar")
    encoder.load_state_dict(checkpoint["encoder_state_dict"])
    decoder.load_state_dict(checkpoint["decoder_state_dict"])

    encoder.eval()
    decoder.eval()

    # evaluation
    summary_manager = SummaryManager(exp_dir)
    filepath = getattr(dataset_config, args.data)
    ds = NMTCorpus(filepath, src_processor.process, tgt_processor.process)
    dl = DataLoader(
        ds,
        args.batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=batchify,
        drop_last=False,
    )

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    encoder.to(device)
    decoder.to(device)

    loss = evaluate(encoder, decoder, tgt_processor.vocab, dl, device)
    summary = {"perplexity": np.exp(loss)}
    summary_manager.load("summary.json")
    summary_manager.update({"{}".format(args.data): summary})
    summary_manager.save("summary.json")
    print("perplexity: {:.3f}".format(np.exp(loss)))
Example #4
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
    )

    tokenizer = get_tokenizer(dataset_config)

    # model (restore)
    checkpoint_manager = CheckpointManager(exp_dir)
    checkpoint = checkpoint_manager.load_checkpoint("best.tar")
    model = SAN(num_classes=model_config.num_classes,
                lstm_hidden_dim=model_config.lstm_hidden_dim,
                da=model_config.da,
                r=model_config.r,
                hidden_dim=model_config.hidden_dim,
                vocab=tokenizer.vocab)
    model.load_state_dict(checkpoint["model_state_dict"])

    # evaluation
    filepath = getattr(dataset_config, args.data)
    ds = Corpus(filepath, tokenizer.split_and_transform)
    dl = DataLoader(ds,
                    batch_size=args.batch_size,
                    num_workers=4,
                    collate_fn=batchify)

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    summary_manager = SummaryManager(exp_dir)
    summary = evaluate(model, dl, {
        "loss": nn.CrossEntropyLoss(),
        "acc": acc
    }, device)

    summary_manager.load("summary.json")
    summary_manager.update({f"{args.data}": summary})
    summary_manager.save("summary.json")

    print("loss: {:.3f}, acc: {:.2%}".format(summary["loss"], summary["acc"]))
Example #5
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)
    ptr_config_info = Config(f"conf/pretrained/{model_config.type}.json")

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
        f"_weight_decay_{args.weight_decay}")

    preprocessor = get_preprocessor(ptr_config_info, model_config)

    with open(ptr_config_info.config, mode="r") as io:
        ptr_config = json.load(io)

    # model (restore)
    checkpoint_manager = CheckpointManager(exp_dir)
    checkpoint = checkpoint_manager.load_checkpoint('best.tar')
    config = BertConfig()
    config.update(ptr_config)
    model = SentenceClassifier(config,
                               num_classes=model_config.num_classes,
                               vocab=preprocessor.vocab)
    model.load_state_dict(checkpoint['model_state_dict'])

    # evaluation
    filepath = getattr(dataset_config, args.data)
    ds = Corpus(filepath, preprocessor.preprocess)
    dl = DataLoader(ds, batch_size=args.batch_size, num_workers=4)
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    summary_manager = SummaryManager(exp_dir)
    summary = evaluate(model, dl, {
        'loss': nn.CrossEntropyLoss(),
        'acc': acc
    }, device)

    summary_manager.load('summary.json')
    summary_manager.update({'{}'.format(args.data): summary})
    summary_manager.save('summary.json')

    print('loss: {:.3f}, acc: {:.2%}'.format(summary['loss'], summary['acc']))
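Note: `CheckpointManager` and `SummaryManager` appear in every example. A minimal sketch consistent with how they are used here, assuming they are thin wrappers around `torch.save`/`torch.load` and a JSON file (the actual utilities may do more):

import json
import torch
from pathlib import Path


class CheckpointManager:
    # saves and restores torch checkpoints inside the experiment directory
    def __init__(self, exp_dir):
        self._exp_dir = Path(exp_dir)

    def save_checkpoint(self, state, filename):
        torch.save(state, self._exp_dir / filename)

    def load_checkpoint(self, filename):
        return torch.load(self._exp_dir / filename, map_location="cpu")


class SummaryManager:
    # keeps a dict of metrics and persists it as JSON in the experiment directory
    def __init__(self, exp_dir):
        self._exp_dir = Path(exp_dir)
        self._summary = {}

    @property
    def summary(self):
        return self._summary

    def update(self, summary):
        self._summary.update(summary)

    def load(self, filename):
        with open(self._exp_dir / filename, mode="r") as io:
            self._summary = json.load(io)

    def save(self, filename):
        with open(self._exp_dir / filename, mode="w") as io:
            json.dump(self._summary, io, indent=4)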
Example #6
    model = MaLSTM(
        num_classes=model_config.num_classes,
        hidden_dim=model_config.hidden_dim,
        vocab=tokenizer.vocab,
    )
    model.load_state_dict(checkpoint["model_state_dict"])

    # evaluation
    filepath = getattr(data_config, args.dataset)
    ds = Corpus(filepath, tokenizer.split_and_transform)
    dl = DataLoader(ds,
                    batch_size=model_config.batch_size,
                    num_workers=4,
                    collate_fn=batchify)

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    summary_manager = SummaryManager(model_dir)
    summary = evaluate(model, dl, {
        "loss": nn.CrossEntropyLoss(),
        "acc": acc
    }, device)

    summary_manager.load("summary.json")
    summary_manager.update({"{}".format(args.dataset): summary})
    summary_manager.save("summary.json")

    print("loss: {:.3f}, acc: {:.2%}".format(summary["loss"], summary["acc"]))
Example #7
def main(parser):

    args = parser.parse_args()

    if args.fp16:
        from apex import amp

    data_dir = Path(args.data_dir)
    model_dir = Path(args.model_dir)
    model_config = Config(json_path=model_dir / 'config.json')
    model_config.learning_rate = args.lr
    model_config.batch_size = args.batch_size

    model_config.vocab_size = len(vocab)
    print("vocabulary length: ", len(vocab))

    # Train & Val Datasets
    train_data_dir = "../data/NER-master/말뭉치 - 형태소_개체명"
    tr_ds = NamedEntityRecognitionDataset(train_data_dir=train_data_dir, vocab=vocab, \
                                          tokenizer=bert_tokenizer, maxlen=model_config.maxlen, model_dir=model_dir)
    tr_dl = DataLoader(tr_ds,
                       batch_size=model_config.batch_size,
                       shuffle=True,
                       num_workers=2,
                       drop_last=False)

    val_data_dir = "../data/NER-master/validation_set"
    val_ds = NamedEntityRecognitionDataset(train_data_dir=val_data_dir, vocab=vocab, \
                                           tokenizer=bert_tokenizer, maxlen=model_config.maxlen, model_dir=model_dir)
    val_dl = DataLoader(val_ds,
                        batch_size=model_config.batch_size,
                        shuffle=True,
                        num_workers=2,
                        drop_last=False)

    # Model
    model = BertMulti_CRF(config=model_config,
                          num_classes=len(tr_ds.ner_to_index),
                          vocab=vocab)
    #model = BertMulti_Only(config=model_config, num_classes=len(tr_ds.ner_to_index), vocab=vocab)
    #model = BiLSTM(config=model_config, num_classes=len(tr_ds.ner_to_index), vocab=vocab)
    #model = BiLSTM_CRF(config=model_config, num_classes=len(tr_ds.ner_to_index))
    model.train()

    # optim
    train_examples_len = len(tr_ds)
    val_examples_len = len(val_ds)
    print("num of train: {}, num of val: {}".format(train_examples_len,
                                                    val_examples_len))

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01
        },
        {
            'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]

    # num_train_optimization_steps = int(train_examples_len / model_config.batch_size / model_config.gradient_accumulation_steps) * model_config.epochs
    t_total = len(tr_dl) // model_config.gradient_accumulation_steps * model_config.epochs
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=model_config.learning_rate,
                      eps=model_config.adam_epsilon)
    #optimizer = torch.optim.Adam(model.parameters(), model_config.learning_rate)
    if args.lr_schedule:
        scheduler = WarmupLinearSchedule(
            optimizer, warmup_steps=model_config.warmup_steps, t_total=t_total)
        #lmbda = lambda epoch: 0.5
        #scheduler = LambdaLR(optimizer, lr_lambda=lmbda)

    #Create model output directory
    output_dir = os.path.join(
        model_dir,
        '{}-lr{}-bs{}'.format(model.name, model_config.learning_rate,
                              model_config.batch_size))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    #checkpoint_manager = CheckpointManager(model_dir)
    summary_manager = SummaryManager(output_dir)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    '''
    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    '''
    model.to(device)

    if args.fp16:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    if args.continue_train:
        revert_to_best(model, optimizer, output_dir)
        logging.info("==== continue training: %s ====", '{}-lr{}-bs{}' \
                    .format(model.name, model_config.learning_rate, model_config.batch_size))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(tr_ds))
    logger.info("  Num Epochs = %d", model_config.epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                model_config.batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                model_config.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    log_file = open('{}/log.tsv'.format(output_dir), 'at')
    print('{}\t{}\t{}\t{}\t{}\t{}\t{}'.format('epoch', 'train loss', 'eval_loss', 'eval global accuracy', \
                                              'micro_f1_score', 'macro_f1_score', 'learning_rate'), file=log_file)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_dev_acc, best_dev_loss = 0.0, 99999999999.0
    best_epoch = 0
    best_steps = 0
    patience = args.patience
    f_scores = []
    model.zero_grad()
    set_seed()
    criterion = nn.CrossEntropyLoss()

    train_begin = datetime.now()
    '''
    train_iterator = trange(int(model_config.epochs), desc="Epoch")  
    for _epoch, _ in enumerate(train_iterator):
    '''
    for _epoch in range(model_config.epochs):
        #epoch_iterator = tqdm(tr_dl, desc="Iteration")
        epoch_iterator = tr_dl
        epoch = _epoch

        for step, batch in enumerate(epoch_iterator):

            model.train()
            #print(batch)

            x_input, token_type_ids, y_real = map(lambda elm: elm.to(device),
                                                  batch)
            #print(x_input.size(), token_type_ids.size(), y_real.size())  # all are batch_size * max_len
            #print(y_real)
            if model.name == "BertMulti_Only":
                y_out = model(x_input, token_type_ids, y_real)
                y_out.requires_grad_()
                y_out.contiguous()
                y_real.contiguous()
                y_real_ = y_real.view(-1)
                y_out_ = y_out.view(-1, len(tr_ds.ner_to_index))
                loss = criterion(y_out_, y_real_)
                _, sequence_of_tags = F.softmax(y_out, dim=2).max(2)
            elif model.name == "BiLSTM":
                y_out = model(x_input, token_type_ids, y_real)
                y_out.requires_grad_()
                y_out.contiguous()
                y_real.contiguous()

                y_out1 = F.log_softmax(y_out, dim=2)
                y_out1 = y_out1.view(-1, len(tr_ds.ner_to_index))

                y_real_ = y_real.view(-1)
                mask = (y_real_ != 1).float()
                #print(len(mask))
                original_len = int(torch.sum(mask))
                #print(x_input[0], y_real[0], original_len, '\n')
                y_out1 = y_out1[range(y_out1.shape[0]), y_real_] * mask
                loss = -torch.sum(y_out1) / original_len

                _, sequence_of_tags = F.softmax(y_out, dim=2).max(2)
            else:
                log_likelihood, sequence_of_tags = model(
                    x_input, token_type_ids, y_real)
                loss = -1 * log_likelihood

            if model_config.gradient_accumulation_steps > 1:
                loss = loss / model_config.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           model_config.max_grad_norm)
            tr_loss += loss.item()

            if (step + 1) % model_config.gradient_accumulation_steps == 0:
                optimizer.step()
                if args.lr_schedule:
                    scheduler.step()  # Update learning rate schedule
                    #print(scheduler.state_dict())
                model.zero_grad()
                global_step += 1

                with torch.no_grad():
                    sequence_of_tags = torch.tensor(sequence_of_tags).to(
                        device)
                    #print(sequence_of_tags.size(), y_real.size())
                    mb_acc = (sequence_of_tags == y_real
                              ).float()[y_real != vocab['[PAD]']].mean()

                tr_acc = mb_acc.item()
                tr_loss_avg = tr_loss / global_step
                tr_summary = {'loss': tr_loss_avg, 'acc': tr_acc}

                if (step + 1) % 20 == 0:
                    logging.info('epoch : {}, global_step : {}, tr_loss: {:.3f}, tr_acc: {:.2%}' \
                                 .format(epoch + 1, global_step, tr_summary['loss'], tr_summary['acc']))

                # evaluation and save model
                if model_config.logging_steps > 0 and global_step % model_config.logging_steps == 0:

                    eval_summary = evaluate(model, val_dl)

                    f_scores.append(eval_summary['macro_f1_score'])

                    # Save model checkpoint
                    summary = {'train': tr_summary, 'eval': eval_summary}
                    summary_manager.update(summary)
                    summary_manager.save('summary.json')

                    # Save
                    is_best = eval_summary["macro_f1_score"] >= best_dev_acc  # acc criterion (should compare against val_acc, not train_acc)
                    is_best_str = 'BEST' if is_best else '< {:.4f}'.format(
                        max(f_scores))
                    logging.info(
                        '[Loss trn]  [Loss dev]  [global acc]  [micro f1]  [macro f1]  [best]  [global step]  [LR]'
                    )
                    logging.info('{:8.2f}  {:9.2f}  {:9.2f}  {:11.4f}  {:9.4f} {:4}  {:9}  {:14.8f}' \
                                 .format((tr_loss - logging_loss) / model_config.logging_steps, eval_summary['eval_loss'], \
                                         eval_summary['eval_global_acc'], eval_summary['micro_f1_score'], \
                                         eval_summary['macro_f1_score'], is_best_str, global_step, model_config.learning_rate))
                    print('{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(epoch, tr_loss, \
                                                              eval_summary['eval_loss'], eval_summary['eval_global_acc'], \
                                                              eval_summary['micro_f1_score'], eval_summary['macro_f1_score'], \
                                                              model_config.learning_rate), file=log_file)
                    log_file.flush()

                    logging_loss = tr_loss

                    if is_best:
                        best_dev_acc = eval_summary["macro_f1_score"]
                        best_dev_loss = eval_summary["eval_loss"]
                        best_steps = global_step
                        best_epoch = epoch
                        #checkpoint_manager.save_checkpoint(state, 'best-epoch-{}-step-{}-acc-{:.3f}.bin'.format(epoch + 1, global_step, best_dev_acc))
                        #logging.info("Saving model checkpoint as best-epoch-{}-step-{}-acc-{:.3f}.bin".format(epoch + 1, global_step, best_dev_acc))
                        logging.info(
                            "Saving model at epoch{}, step{} in {}".format(
                                epoch, global_step, output_dir))
                        torch.save(model.state_dict(),
                                   '{}/model.state'.format(output_dir))
                        torch.save(optimizer.state_dict(),
                                   '{}/optim.state'.format(output_dir))
                        patience = args.patience

                    else:
                        revert_to_best(model, optimizer, output_dir)
                        patience -= 1
                        logging.info("==== revert to epoch[%d], step%d. F1 score: %.4f, patience: %d ====", \
                                     best_epoch, best_steps, max(f_scores), patience)

                        if patience == 0:
                            break

        else:

            continue

        break

    #print("global_step = {}, average loss = {}".format(global_step, tr_loss / global_step))

    train_end = datetime.now()
    train_elapsed = elapsed(train_end - train_begin)
    logging.info('==== training time elapsed: %s, epoch: %s ====',
                 train_elapsed, epoch)

    return global_step, tr_loss / global_step, best_steps
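Note: `revert_to_best` is not shown in this example. A plausible sketch, assuming it simply mirrors the two `torch.save` calls in the checkpointing branch above:

import torch


def revert_to_best(model, optimizer, output_dir):
    # reload the weights and optimizer state written when the best macro F1 was observed
    model.load_state_dict(
        torch.load('{}/model.state'.format(output_dir), map_location='cpu'))
    optimizer.load_state_dict(
        torch.load('{}/optim.state'.format(output_dir), map_location='cpu'))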
Example #8
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
    )

    if not exp_dir.exists():
        exp_dir.mkdir(parents=True)

    if args.fix_seed:
        torch.manual_seed(777)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    tokenizer = get_tokenizer(dataset_config)
    tr_dl, val_dl = get_data_loaders(dataset_config, model_config, tokenizer,
                                     args.batch_size)

    # model
    model = ConvRec(num_classes=model_config.num_classes,
                    embedding_dim=model_config.embedding_dim,
                    hidden_dim=model_config.hidden_dim,
                    vocab=tokenizer.vocab)

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(params=model.parameters(), lr=args.learning_rate)
    scheduler = ReduceLROnPlateau(opt, patience=5)
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    writer = SummaryWriter('{}/runs'.format(exp_dir))
    checkpoint_manager = CheckpointManager(exp_dir)
    summary_manager = SummaryManager(exp_dir)
    best_val_loss = 1e+10

    for epoch in tqdm(range(args.epochs), desc='epochs'):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc='steps', total=len(tr_dl)):
            x_mb, y_mb = map(lambda elm: elm.to(device), mb)

            opt.zero_grad()
            y_hat_mb = model(x_mb)
            mb_loss = loss_fn(y_hat_mb, y_mb)
            mb_loss.backward()
            opt.step()

            with torch.no_grad():
                mb_acc = acc(y_hat_mb, y_mb)

            tr_loss += mb_loss.item()
            tr_acc += mb_acc.item()

            if (epoch * len(tr_dl) + step) % args.summary_step == 0:
                val_loss = evaluate(model, val_dl, {'loss': loss_fn},
                                    device)['loss']
                writer.add_scalars('loss', {
                    'train': tr_loss / (step + 1),
                    'val': val_loss
                },
                                   epoch * len(tr_dl) + step)
                model.train()
        else:
            tr_loss /= (step + 1)
            tr_acc /= (step + 1)

            tr_summary = {'loss': tr_loss, 'acc': tr_acc}
            val_summary = evaluate(model, val_dl, {
                'loss': loss_fn,
                'acc': acc
            }, device)
            scheduler.step(val_summary['loss'])
            tqdm.write('epoch : {}, tr_loss: {:.3f}, val_loss: '
                       '{:.3f}, tr_acc: {:.2%}, val_acc: {:.2%}'.format(
                           epoch + 1, tr_summary['loss'], val_summary['loss'],
                           tr_summary['acc'], val_summary['acc']))

            val_loss = val_summary['loss']
            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'opt_state_dict': opt.state_dict()
                }
                summary = {'train': tr_summary, 'validation': val_summary}

                summary_manager.update(summary)
                summary_manager.save('summary.json')
                checkpoint_manager.save_checkpoint(state, 'best.tar')

                best_val_loss = val_loss
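Note: several examples pass `collate_fn=batchify` to `DataLoader`. A minimal sketch for the classification examples, assuming each dataset item is a `(token_indices, label)` pair and that index 0 is the padding index (both are assumptions, not taken from the source):

import torch
from torch.nn.utils.rnn import pad_sequence


def batchify(batch):
    # pad variable-length token sequences to the longest one in the mini-batch
    tokens, labels = zip(*batch)
    tokens = pad_sequence([torch.as_tensor(t) for t in tokens],
                          batch_first=True, padding_value=0)
    return tokens, torch.as_tensor(labels)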
Example #9
                    min_length=model_config.min_length,
                    pad_val=tokenizer.vocab.to_indices(' '))
    val_dl = DataLoader(val_ds,
                        batch_size=model_config.batch_size,
                        collate_fn=batchify)

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(params=model.parameters(), lr=model_config.learning_rate)
    scheduler = ReduceLROnPlateau(opt, patience=5)
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    writer = SummaryWriter('{}/runs'.format(model_dir))
    checkpoint_manager = CheckpointManager(model_dir)
    summary_manager = SummaryManager(model_dir)
    best_val_loss = 1e+10

    for epoch in tqdm(range(model_config.epochs), desc='epochs'):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc='steps', total=len(tr_dl)):
            x_mb, y_mb = map(lambda elm: elm.to(device), mb)

            opt.zero_grad()
            y_hat_mb = model(x_mb)
            mb_loss = loss_fn(y_hat_mb, y_mb)
            mb_loss.backward()
Example #10
        print('gpu is available')
        torch.cuda.empty_cache()

    if torch.cuda.device_count() > 1:
        print('multiple gpus are available')
        if args.gpu is not None:
            device_ids = [int(x) for x in args.gpu]
            model = DataParallel(model, device_ids=device_ids)
        else:
            model = DataParallel(model)
    model.to(device)
    criterion.to(device)

    writer = SummaryWriter(save_dir / f'runs_{args.model}')
    checkpoint_manager = CheckpointManager(save_dir)
    summary_manager = SummaryManager(save_dir)
    summary_manager.update(experiment_summary)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=1e-5)
    sampler = BernoulliNegativeSampler(kg_train)
    tr_dl = DataLoader(kg_train, batch_size=args.batch_size)
    val_dl = DataLoader(kg_valid, batch_size=args.batch_size)

    best_val_loss = 1e+10
    for epoch in tqdm(range(args.epochs), desc='epochs'):
        tr_loss = 0
        model.train()

        for step, batch in enumerate(tr_dl):
Example #11
    checkpoint_manager = CheckpointManager(model_dir)
    checkpoint = checkpoint_manager.load_checkpoint('best_snu_{}.tar'.format(
        args.pretrained_config))

    config = BertConfig(ptr_config.config)
    model = SentenceClassifier(config,
                               num_classes=model_config.num_classes,
                               vocab=preprocessor.vocab)
    model.load_state_dict(checkpoint['model_state_dict'])

    # evaluation
    filepath = getattr(data_config, args.dataset)
    ds = Corpus(filepath, preprocessor.preprocess)
    dl = DataLoader(ds, batch_size=model_config.batch_size, num_workers=4)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    summary_manager = SummaryManager(model_dir)
    summary = evaluate(model, dl, {
        'loss': nn.CrossEntropyLoss(),
        'acc': acc
    }, device)

    summary_manager.load('summary_snu_{}.json'.format(args.pretrained_config))
    summary_manager.update({'{}'.format(args.dataset): summary})
    summary_manager.save('summary_snu_{}.json'.format(args.pretrained_config))

    print('loss: {:.3f}, acc: {:.2%}'.format(summary['loss'], summary['acc']))
Example #12
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
    )

    if not exp_dir.exists():
        exp_dir.mkdir(parents=True)

    if args.fix_seed:
        torch.manual_seed(777)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    preprocessor = get_preprocessor(dataset_config,
                                    coarse_split_fn=split_morphs,
                                    fine_split_fn=split_jamos)
    tr_dl, val_dl = get_data_loaders(dataset_config,
                                     preprocessor,
                                     args.batch_size,
                                     collate_fn=batchify)

    # model
    model = SAN(model_config.num_classes, preprocessor.coarse_vocab,
                preprocessor.fine_vocab, model_config.fine_embedding_dim,
                model_config.hidden_dim, model_config.multi_step,
                model_config.prediction_drop_ratio)

    opt = optim.Adam(model.parameters(), lr=args.learning_rate)
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    writer = SummaryWriter(f"{exp_dir}/runs")
    checkpoint_manager = CheckpointManager(exp_dir)
    summary_manager = SummaryManager(exp_dir)
    best_val_loss = 1e10

    for epoch in tqdm(range(args.epochs), desc="epochs"):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc="steps", total=len(tr_dl)):
            qa_mb, qb_mb, y_mb = map(
                lambda elm: (el.to(device) for el in elm)
                if isinstance(elm, tuple) else elm.to(device), mb)
            opt.zero_grad()
            y_hat_mb = model((qa_mb, qb_mb))
            mb_loss = log_loss(y_hat_mb, y_mb)
            mb_loss.backward()
            opt.step()

            with torch.no_grad():
                mb_acc = acc(y_hat_mb, y_mb)

            tr_loss += mb_loss.item()
            tr_acc += mb_acc.item()

            if (epoch * len(tr_dl) + step) % args.summary_step == 0:
                val_loss = evaluate(model, val_dl, {"loss": log_loss},
                                    device)["loss"]
                writer.add_scalars("loss", {
                    "train": tr_loss / (step + 1),
                    "val": val_loss
                },
                                   epoch * len(tr_dl) + step)
                model.train()
        else:
            tr_loss /= step + 1
            tr_acc /= step + 1

            tr_summary = {"loss": tr_loss, "acc": tr_acc}
            val_summary = evaluate(model, val_dl, {
                "loss": log_loss,
                "acc": acc
            }, device)
            tqdm.write(
                f"epoch: {epoch+1}\n"
                f"tr_loss: {tr_summary['loss']:.3f}, val_loss: {val_summary['loss']:.3f}\n"
                f"tr_acc: {tr_summary['acc']:.2%}, val_acc: {val_summary['acc']:.2%}"
            )

            val_loss = val_summary["loss"]
            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "opt_state_dict": opt.state_dict(),
                }
                summary = {"train": tr_summary, "validation": val_summary}

                summary_manager.update(summary)
                summary_manager.save("summary.json")
                checkpoint_manager.save_checkpoint(state, "best.tar")

                best_val_loss = val_loss
Example #13
    # tokenizer
    with open(data_config.token_vocab, mode="rb") as io:
        token_vocab = pickle.load(io)
    with open(data_config.label_vocab, mode="rb") as io:
        label_vocab = pickle.load(io)
    token_tokenizer = Tokenizer(token_vocab, split_to_self)
    label_tokenizer = Tokenizer(label_vocab, split_to_self)

    # model (restore)
    checkpoint_manager = CheckpointManager(model_dir)
    checkpoint = checkpoint_manager.load_checkpoint(args.restore_file + ".tar")
    model = BilstmCRF(label_vocab, token_vocab, model_config.lstm_hidden_dim)
    model.load_state_dict(checkpoint["model_state_dict"])

    # evaluation
    summary_manager = SummaryManager(model_dir)
    filepath = getattr(data_config, args.data_name)
    ds = Corpus(
        filepath,
        token_tokenizer.split_and_transform,
        label_tokenizer.split_and_transform,
    )
    dl = DataLoader(ds,
                    batch_size=model_config.batch_size,
                    num_workers=4,
                    collate_fn=batchify)
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    f1_score = get_f1_score(model, dl, device)
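Note: `get_f1_score` is not shown. A rough sketch of the kind of computation it performs, assuming the CRF model returns `(score, decoded_tags)` in its forward pass and that padding labels are excluded; the actual implementation and label handling may differ:

import torch
from sklearn.metrics import f1_score


def get_f1_score(model, data_loader, device, pad_label_id=0):
    # pad_label_id is an assumed padding label index
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for x_mb, y_mb in data_loader:
            x_mb, y_mb = x_mb.to(device), y_mb.to(device)
            _, tags_mb = model(x_mb, y_mb)  # assumed to return (score, decoded tags)
            tags_mb = torch.as_tensor(tags_mb).to(device)
            keep = y_mb != pad_label_id
            y_true.extend(y_mb[keep].tolist())
            y_pred.extend(tags_mb[keep].tolist())
    return f1_score(y_true, y_pred, average="macro")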
Example #14
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)
    ptr_config_info = Config(f"conf/pretrained/{model_config.type}.json")

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
        f"_weight_decay_{args.weight_decay}")

    if not exp_dir.exists():
        exp_dir.mkdir(parents=True)

    if args.fix_seed:
        torch.manual_seed(777)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    preprocessor = get_preprocessor(ptr_config_info, model_config)

    with open(ptr_config_info.config, mode="r") as io:
        ptr_config = json.load(io)

    # model
    config = BertConfig()
    config.update(ptr_config)
    model = PairwiseClassifier(config,
                               num_classes=model_config.num_classes,
                               vocab=preprocessor.vocab)
    bert_pretrained = torch.load(ptr_config_info.bert)
    model.load_state_dict(bert_pretrained, strict=False)

    tr_dl, val_dl = get_data_loaders(dataset_config, preprocessor,
                                     args.batch_size)

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam([
        {
            "params": model.bert.parameters(),
            "lr": args.learning_rate / 100
        },
        {
            "params": model.classifier.parameters(),
            "lr": args.learning_rate
        },
    ],
                     weight_decay=args.weight_decay)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    writer = SummaryWriter(f'{exp_dir}/runs')
    checkpoint_manager = CheckpointManager(exp_dir)
    summary_manager = SummaryManager(exp_dir)
    best_val_loss = 1e+10

    for epoch in tqdm(range(args.epochs), desc='epochs'):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc='steps', total=len(tr_dl)):
            x_mb, x_types_mb, y_mb = map(lambda elm: elm.to(device), mb)
            opt.zero_grad()
            y_hat_mb = model(x_mb, x_types_mb)
            mb_loss = loss_fn(y_hat_mb, y_mb)
            mb_loss.backward()
            opt.step()

            with torch.no_grad():
                mb_acc = acc(y_hat_mb, y_mb)

            tr_loss += mb_loss.item()
            tr_acc += mb_acc.item()

            if (epoch * len(tr_dl) + step) % args.summary_step == 0:
                val_loss = evaluate(model, val_dl, {'loss': loss_fn},
                                    device)['loss']
                writer.add_scalars('loss', {
                    'train': tr_loss / (step + 1),
                    'val': val_loss
                },
                                   epoch * len(tr_dl) + step)
                model.train()
        else:
            tr_loss /= (step + 1)
            tr_acc /= (step + 1)

            tr_summary = {'loss': tr_loss, 'acc': tr_acc}
            val_summary = evaluate(model, val_dl, {
                'loss': loss_fn,
                'acc': acc
            }, device)
            tqdm.write(
                f"epoch: {epoch+1}\n"
                f"tr_loss: {tr_summary['loss']:.3f}, val_loss: {val_summary['loss']:.3f}\n"
                f"tr_acc: {tr_summary['acc']:.2%}, val_acc: {val_summary['acc']:.2%}"
            )

            val_loss = val_summary['loss']
            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'opt_state_dict': opt.state_dict()
                }
                summary = {'train': tr_summary, 'validation': val_summary}

                summary_manager.update(summary)
                summary_manager.save('summary.json')
                checkpoint_manager.save_checkpoint(state, 'best.tar')

                best_val_loss = val_loss
Example #15
    pad_sequence = PadSequence(length=model_config.length,
                               pad_val=vocab.to_indices(vocab.padding_token))
    tokenizer = Tokenizer(vocab=vocab,
                          split_fn=split_to_jamo,
                          pad_fn=pad_sequence)

    # model (restore)
    checkpoint_manager = CheckpointManager(model_dir)
    checkpoint = checkpoint_manager.load_checkpoint(args.restore_file + '.tar')
    model = CharCNN(num_classes=model_config.num_classes,
                    embedding_dim=model_config.embedding_dim,
                    vocab=tokenizer.vocab)
    model.load_state_dict(checkpoint['model_state_dict'])

    # evaluation
    summary_manager = SummaryManager(model_dir)
    filepath = getattr(data_config, args.data_name)
    ds = Corpus(filepath, tokenizer.split_and_transform)
    dl = DataLoader(ds, batch_size=model_config.batch_size, num_workers=4)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    summary = evaluate(model, dl, {
        'loss': nn.CrossEntropyLoss(),
        'acc': acc
    }, device)

    summary_manager.load('summary.json')
    summary_manager.update({'{}'.format(args.data_name): summary})
Example #16
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
    )

    if not exp_dir.exists():
        exp_dir.mkdir(parents=True)

    if args.fix_seed:
        torch.manual_seed(777)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    tokenizer = get_tokenizer(dataset_config, model_config)
    tr_dl, val_dl = get_data_loaders(dataset_config, tokenizer, args.batch_size)
    model = SenCNN(num_classes=model_config.num_classes, vocab=tokenizer.vocab)

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(params=model.parameters(), lr=args.learning_rate)
    scheduler = ReduceLROnPlateau(opt, patience=5)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    writer = SummaryWriter(f"{exp_dir}/runs")
    checkpoint_manager = CheckpointManager(exp_dir)
    summary_manager = SummaryManager(exp_dir)
    best_val_loss = 1e10

    for epoch in tqdm(range(args.epochs), desc="epochs"):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc="steps", total=len(tr_dl)):
            x_mb, y_mb = map(lambda elm: elm.to(device), mb)

            opt.zero_grad()
            y_hat_mb = model(x_mb)
            mb_loss = loss_fn(y_hat_mb, y_mb)
            mb_loss.backward()
            clip_grad_norm_(model._fc.weight, 5)
            opt.step()

            with torch.no_grad():
                mb_acc = acc(y_hat_mb, y_mb)

            tr_loss += mb_loss.item()
            tr_acc += mb_acc.item()

            if (epoch * len(tr_dl) + step) % args.summary_step == 0:
                val_loss = evaluate(model, val_dl, {"loss": loss_fn}, device)["loss"]
                writer.add_scalars("loss", {"train": tr_loss / (step + 1), "validation": val_loss},
                                   epoch * len(tr_dl) + step)
                model.train()
        else:
            tr_loss /= step + 1
            tr_acc /= step + 1

            tr_summary = {"loss": tr_loss, "acc": tr_acc}
            val_summary = evaluate(model, val_dl, {"loss": loss_fn, "acc": acc}, device)
            scheduler.step(val_summary["loss"])
            tqdm.write(f"epoch: {epoch+1}\n"
                       f"tr_loss: {tr_summary['loss']:.3f}, val_loss: {val_summary['loss']:.3f}\n"
                       f"tr_acc: {tr_summary['acc']:.2%}, val_acc: {val_summary['acc']:.2%}")

            val_loss = val_summary["loss"]
            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "opt_state_dict": opt.state_dict(),
                }
                summary = {
                    "epoch": epoch + 1,
                    "train": tr_summary,
                    "validation": val_summary,
                }

                summary_manager.update(summary)
                summary_manager.save("summary.json")
                checkpoint_manager.save_checkpoint(state, "best.tar")

                best_val_loss = val_loss
Example #17
                                   type=list,
                                   help='Set GPU for training')

if __name__ == '__main__':
    args = parser.parse_args()
    data_dir = Path(args.data_dir) / args.data
    restore_dir = Path(args.restore_dir) / args.data

    assert args.data in ['wikidatasets',
                         'fb15k'], "Invalid knowledge graph dataset"
    if args.data == 'wikidatasets':
        data_dir = data_dir / args.which
        restore_dir = restore_dir / args.which
    restore_dir = restore_dir / args.model

    summary_manager = SummaryManager(restore_dir)
    summary_manager.load(f'summary_{args.model}.json')
    previous_summary = summary_manager.summary
    ent_dim = previous_summary['Experiment Summary']['entity dimension']
    rel_dim = previous_summary['Experiment Summary']['relation dimension']
    limit = previous_summary['Experiment Summary']['limit']
    margin = previous_summary['Experiment Summary']['margin']

    with open(data_dir / 'kg_test.pkl', mode='rb') as io:
        kg_test = pickle.load(io)
    with open(data_dir / 'kg_valid.pkl', mode='rb') as io:
        kg_valid = pickle.load(io)

    # restore model
    assert args.model in ['TransE', 'TransR', 'DistMult',
                          'TransD'], "Invalid Knowledge Graph Embedding Model"
Example #18
    val_dl = DataLoader(val_ds, batch_size=model_config.batch_size, num_workers=4)

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(
        [
            {"params": model.bert.parameters(), "lr": model_config.learning_rate / 100},
            {"params": model.classifier.parameters(), "lr": model_config.learning_rate},

        ], weight_decay=5e-4)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    writer = SummaryWriter('{}/runs_{}'.format(model_dir, args.type))
    checkpoint_manager = CheckpointManager(model_dir)
    summary_manager = SummaryManager(model_dir)
    best_val_loss = 1e+10

    for epoch in tqdm(range(model_config.epochs), desc='epochs'):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc='steps', total=len(tr_dl)):
            x_mb, y_mb = map(lambda elm: elm.to(device), mb)
            opt.zero_grad()
            y_hat_mb = model(x_mb)
            mb_loss = loss_fn(y_hat_mb, y_mb)
            mb_loss.backward()
            opt.step()
Example #19
        {
            "params": model.classifier.parameters(),
            "lr": model_config.learning_rate
        },
    ],
                     weight_decay=5e-4)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    writer = SummaryWriter('{}/runs_{}'.format(model_dir,
                                               args.pretrained_config))

    checkpoint_manager = CheckpointManager(model_dir)
    summary_manager = SummaryManager(model_dir)
    best_val_loss = 1e+10

    for epoch in tqdm(range(model_config.epochs), desc='epochs'):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc='steps', total=len(tr_dl)):
            x_mb, y_mb = map(lambda elm: elm.to(device), mb)
            opt.zero_grad()
            y_hat_mb = model(x_mb)
            mb_loss = loss_fn(y_hat_mb, y_mb)
            mb_loss.backward()
            opt.step()
Example #20
    # model (restore)
    checkpoint_manager = CheckpointManager(model_dir)
    checkpoint = checkpoint_manager.load_checkpoint('best_{}.tar'.format(
        args.type))
    config = BertConfig(ptr_config.config)
    model = PairwiseClassifier(config,
                               num_classes=model_config.num_classes,
                               vocab=preprocessor.vocab)
    model.load_state_dict(checkpoint['model_state_dict'])

    # evaluation
    filepath = getattr(data_config, args.dataset)
    ds = Corpus(filepath, preprocessor.preprocess)
    dl = DataLoader(ds, batch_size=model_config.batch_size, num_workers=4)

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    summary_manager = SummaryManager(model_dir)
    summary = evaluate(model, dl, {
        'loss': nn.CrossEntropyLoss(),
        'acc': acc
    }, device)

    summary_manager.load('summary_{}.json'.format(args.type))
    summary_manager.update({'{}'.format(args.dataset): summary})
    summary_manager.save('summary_{}.json'.format(args.type))

    print('loss: {:.3f}, acc: {:.2%}'.format(summary['loss'], summary['acc']))
Example #21
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
        f"_teacher_forcing_ratio_{args.teacher_forcing_ratio}")

    if not exp_dir.exists():
        exp_dir.mkdir(parents=True)

    if args.fix_seed:
        torch.manual_seed(777)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # processor
    src_processor, tgt_processor = get_processor(dataset_config)

    # data_loaders
    tr_dl, val_dl = get_data_loaders(dataset_config, src_processor,
                                     tgt_processor, args.batch_size)

    # model
    encoder = BidiEncoder(src_processor.vocab, model_config.encoder_hidden_dim,
                          model_config.drop_ratio)
    decoder = AttnDecoder(
        tgt_processor.vocab,
        model_config.method,
        model_config.encoder_hidden_dim * 2,
        model_config.decoder_hidden_dim,
        model_config.drop_ratio,
    )

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    encoder.to(device)
    decoder.to(device)

    writer = SummaryWriter("{}/runs".format(exp_dir))
    checkpoint_manager = CheckpointManager(exp_dir)
    summary_manager = SummaryManager(exp_dir)
    best_val_loss = 1e10

    opt = optim.RMSprop(
        [{
            "params": encoder.parameters()
        }, {
            "params": decoder.parameters()
        }],
        lr=args.learning_rate,
    )
    scheduler = ReduceLROnPlateau(opt, patience=5)

    for epoch in tqdm(range(args.epochs), desc="epochs"):
        tr_loss = 0

        encoder.train()
        decoder.train()

        for step, mb in tqdm(enumerate(tr_dl), desc="steps", total=len(tr_dl)):
            mb_loss = 0
            src_mb, tgt_mb = map(lambda elm: elm.to(device), mb)
            opt.zero_grad()

            # encoder
            enc_outputs_mb, src_length_mb, enc_hc_mb = encoder(src_mb)

            # decoder
            dec_input_mb = torch.ones((tgt_mb.size()[0], 1),
                                      device=device).long()
            dec_input_mb *= tgt_processor.vocab.to_indices(
                tgt_processor.vocab.bos_token)
            dec_hc_mb = None
            tgt_length_mb = tgt_mb.ne(
                tgt_processor.vocab.to_indices(
                    tgt_processor.vocab.padding_token)).sum(dim=1)
            tgt_mask_mb = sequence_mask(tgt_length_mb, tgt_length_mb.max())

            use_teacher_forcing = random.random() > args.teacher_forcing_ratio

            if use_teacher_forcing:
                for t in range(tgt_length_mb.max()):
                    dec_output_mb, dec_hc_mb = decoder(dec_input_mb, dec_hc_mb,
                                                       enc_outputs_mb,
                                                       src_length_mb)
                    sequence_loss = mask_nll_loss(dec_output_mb, tgt_mb[:, [t]],
                                                  tgt_mask_mb[:, [t]])
                    mb_loss += sequence_loss
                    dec_input_mb = tgt_mb[:, [t]]  # next input is current target
                else:
                    mb_loss /= tgt_length_mb.max()
            else:
                for t in range(tgt_length_mb.max()):
                    dec_output_mb, dec_hc_mb = decoder(dec_input_mb, dec_hc_mb,
                                                       enc_outputs_mb,
                                                       src_length_mb)
                    sequence_loss = mask_nll_loss(dec_output_mb, tgt_mb[:, [t]],
                                                  tgt_mask_mb[:, [t]])
                    mb_loss += sequence_loss
                    dec_input_mb = dec_output_mb.topk(1).indices
                else:
                    mb_loss /= tgt_length_mb.max()

            # update params
            mb_loss.backward()
            nn.utils.clip_grad_norm_(encoder.parameters(), args.clip_norm)
            nn.utils.clip_grad_norm_(decoder.parameters(), args.clip_norm)
            opt.step()

            tr_loss += mb_loss.item()

            if (epoch * len(tr_dl) + step) % args.summary_step == 0:
                val_loss = evaluate(encoder, decoder, tgt_processor.vocab,
                                    val_dl, device)
                writer.add_scalars(
                    "perplexity",
                    {
                        "train": np.exp(tr_loss / (step + 1)),
                        "validation": np.exp(val_loss)
                    },
                    epoch * len(tr_dl) + step,
                )
                encoder.train()
                decoder.train()

        else:
            tr_loss /= step + 1

            tr_summary = {"perplexity": np.exp(tr_loss)}
            val_loss = evaluate(encoder, decoder, tgt_processor.vocab, val_dl,
                                device)
            val_summary = {"perplexity": np.exp(val_loss)}
            scheduler.step(np.exp(val_loss))

            tqdm.write("epoch : {}, tr_ppl: {:.3f}, val_ppl: "
                       "{:.3f}".format(epoch + 1, tr_summary["perplexity"],
                                       val_summary["perplexity"]))

            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    "epoch": epoch + 1,
                    "encoder_state_dict": encoder.state_dict(),
                    "decoder_state_dict": decoder.state_dict(),
                    "opt_state_dict": opt.state_dict(),
                }
                summary = {
                    "epoch": epoch + 1,
                    "train": tr_summary,
                    "validation": val_summary,
                }

                summary_manager.update(summary)
                summary_manager.save("summary.json")
                checkpoint_manager.save_checkpoint(state, "best.tar")

                best_val_loss = val_loss
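Note: `sequence_mask` and `mask_nll_loss` are not shown. A minimal sketch consistent with how they are called above, assuming the decoder emits per-token probabilities (which is why `topk` is applied directly to its output):

import torch


def sequence_mask(lengths, max_len):
    # boolean mask of shape (batch, max_len): True inside each sequence, False on padding
    positions = torch.arange(int(max_len), device=lengths.device)
    return positions[None, :] < lengths[:, None]


def mask_nll_loss(decoder_output, target, mask):
    # negative log-likelihood of the target token, averaged over unmasked positions only
    gathered = torch.gather(decoder_output, 1, target)
    cross_entropy = -torch.log(gathered)
    return cross_entropy.masked_select(mask).mean()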