Example #1
                writer.add_scalars('loss', {'train': tr_loss / (step + 1),
                                            'val': val_loss}, epoch * len(tr_dl) + step)
                tqdm.write('global_step: {:3}, tr_loss: {:.3f}, val_loss: {:.3f}'.format(epoch * len(tr_dl) + step,
                                                                                         tr_loss / (step + 1),
                                                                                         val_loss))
                model.train()
        else:  # for/else: runs once the loop over tr_dl has finished the epoch
            tr_loss /= (step + 1)
            tr_acc /= (step + 1)

            tr_summary = {'loss': tr_loss, 'acc': tr_acc}
            val_summary = evaluate(model, val_dl, {'loss': loss_fn, 'acc': acc}, device)
            tqdm.write('epoch : {}, tr_loss: {:.3f}, val_loss: '
                       '{:.3f}, tr_acc: {:.2%}, val_acc: {:.2%}'.format(epoch + 1, tr_summary['loss'],
                                                                        val_summary['loss'], tr_summary['acc'],
                                                                        val_summary['acc']))

            val_loss = val_summary['loss']
            is_best = val_loss < best_val_loss

            if is_best:
                state = {'epoch': epoch + 1,
                         'model_state_dict': model.state_dict(),
                         'opt_state_dict': opt.state_dict()}
                summary = {'train': tr_summary, 'validation': val_summary}

                summary_manager.update(summary)
                summary_manager.save('summary_{}.json'.format(args.type))
                checkpoint_manager.save_checkpoint(state, 'best_{}.tar'.format(args.type))

                best_val_loss = val_loss
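
The accuracy metric `acc` accumulated in most of these examples is not defined on this page. A minimal sketch of such a helper, assuming it receives raw logits and integer class labels, could be:

import torch

def acc(y_hat, y):
    """Hypothetical batch accuracy: fraction of argmax predictions matching the labels."""
    with torch.no_grad():
        return (y_hat.argmax(dim=-1) == y).float().mean()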
Example #2
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
    )

    if not exp_dir.exists():
        exp_dir.mkdir(parents=True)

    if args.fix_seed:
        torch.manual_seed(777)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    tokenizer = get_tokenizer(dataset_config)
    tr_dl, val_dl = get_data_loaders(dataset_config, model_config, tokenizer,
                                     args.batch_size)

    # model
    model = ConvRec(num_classes=model_config.num_classes,
                    embedding_dim=model_config.embedding_dim,
                    hidden_dim=model_config.hidden_dim,
                    vocab=tokenizer.vocab)

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(params=model.parameters(), lr=args.learning_rate)
    scheduler = ReduceLROnPlateau(opt, patience=5)
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    writer = SummaryWriter('{}/runs'.format(exp_dir))
    checkpoint_manager = CheckpointManager(exp_dir)
    summary_manager = SummaryManager(exp_dir)
    best_val_loss = 1e+10

    for epoch in tqdm(range(args.epochs), desc='epochs'):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc='steps', total=len(tr_dl)):
            x_mb, y_mb = map(lambda elm: elm.to(device), mb)

            opt.zero_grad()
            y_hat_mb = model(x_mb)
            mb_loss = loss_fn(y_hat_mb, y_mb)
            mb_loss.backward()
            opt.step()

            with torch.no_grad():
                mb_acc = acc(y_hat_mb, y_mb)

            tr_loss += mb_loss.item()
            tr_acc += mb_acc.item()

            if (epoch * len(tr_dl) + step) % args.summary_step == 0:
                val_loss = evaluate(model, val_dl, {'loss': loss_fn},
                                    device)['loss']
                writer.add_scalars('loss', {
                    'train': tr_loss / (step + 1),
                    'val': val_loss
                },
                                   epoch * len(tr_dl) + step)
                model.train()
        else:
            tr_loss /= (step + 1)
            tr_acc /= (step + 1)

            tr_summary = {'loss': tr_loss, 'acc': tr_acc}
            val_summary = evaluate(model, val_dl, {
                'loss': loss_fn,
                'acc': acc
            }, device)
            scheduler.step(val_summary['loss'])
            tqdm.write('epoch : {}, tr_loss: {:.3f}, val_loss: '
                       '{:.3f}, tr_acc: {:.2%}, val_acc: {:.2%}'.format(
                           epoch + 1, tr_summary['loss'], val_summary['loss'],
                           tr_summary['acc'], val_summary['acc']))

            val_loss = val_summary['loss']
            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'opt_state_dict': opt.state_dict()
                }
                summary = {'train': tr_summary, 'validation': val_summary}

                summary_manager.update(summary)
                summary_manager.save('summary.json')
                checkpoint_manager.save_checkpoint(state, 'best.tar')

                best_val_loss = val_loss
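
Most of the snippets also call an `evaluate(model, data_loader, metrics, device)` helper that is not reproduced here (Examples #5 and #6 use variants with different signatures). A plausible minimal version, assuming each metric is a callable over `(y_hat, y)` and the loader yields `(x, y)` batches:

import torch

def evaluate(model, data_loader, metrics, device):
    """Hypothetical sketch: average each metric over the loader, e.g. {'loss': ..., 'acc': ...}."""
    model.eval()
    totals = {name: 0.0 for name in metrics}
    with torch.no_grad():
        for x_mb, y_mb in data_loader:
            x_mb, y_mb = x_mb.to(device), y_mb.to(device)
            y_hat_mb = model(x_mb)
            for name, metric_fn in metrics.items():
                totals[name] += metric_fn(y_hat_mb, y_mb).item()
    return {name: total / len(data_loader) for name, total in totals.items()}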
Example #3
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
    )

    if not exp_dir.exists():
        exp_dir.mkdir(parents=True)

    if args.fix_seed:
        torch.manual_seed(777)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    tokenizer = get_tokenizer(dataset_config, model_config)
    tr_dl, val_dl = get_data_loaders(dataset_config, tokenizer, args.batch_size)
    model = SenCNN(num_classes=model_config.num_classes, vocab=tokenizer.vocab)

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(params=model.parameters(), lr=args.learning_rate)
    scheduler = ReduceLROnPlateau(opt, patience=5)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    writer = SummaryWriter(f"{exp_dir}/runs")
    checkpoint_manager = CheckpointManager(exp_dir)
    summary_manager = SummaryManager(exp_dir)
    best_val_loss = 1e10

    for epoch in tqdm(range(args.epochs), desc="epochs"):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc="steps", total=len(tr_dl)):
            x_mb, y_mb = map(lambda elm: elm.to(device), mb)

            opt.zero_grad()
            y_hat_mb = model(x_mb)
            mb_loss = loss_fn(y_hat_mb, y_mb)
            mb_loss.backward()
            clip_grad_norm_(model._fc.weight, 5)
            opt.step()

            with torch.no_grad():
                mb_acc = acc(y_hat_mb, y_mb)

            tr_loss += mb_loss.item()
            tr_acc += mb_acc.item()

            if (epoch * len(tr_dl) + step) % args.summary_step == 0:
                val_loss = evaluate(model, val_dl, {"loss": loss_fn}, device)["loss"]
                writer.add_scalars("loss", {"train": tr_loss / (step + 1), "validation": val_loss},
                                   epoch * len(tr_dl) + step)
                model.train()
        else:
            tr_loss /= step + 1
            tr_acc /= step + 1

            tr_summary = {"loss": tr_loss, "acc": tr_acc}
            val_summary = evaluate(model, val_dl, {"loss": loss_fn, "acc": acc}, device)
            scheduler.step(val_summary["loss"])
            tqdm.write(f"epoch: {epoch+1}\n"
                       f"tr_loss: {tr_summary['loss']:.3f}, val_loss: {val_summary['loss']:.3f}\n"
                       f"tr_acc: {tr_summary['acc']:.2%}, val_acc: {val_summary['acc']:.2%}")

            val_loss = val_summary["loss"]
            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "opt_state_dict": opt.state_dict(),
                }
                summary = {
                    "epoch": epoch + 1,
                    "train": tr_summary,
                    "validation": val_summary,
                }

                summary_manager.update(summary)
                summary_manager.save("summary.json")
                checkpoint_manager.save_checkpoint(state, "best.tar")

                best_val_loss = val_loss
Example #4
            tr_loss /= (step + 1)
            tr_acc /= (step + 1)

            tr_summary = {'loss': tr_loss, 'acc': tr_acc}
            val_summary = evaluate(model, val_dl, {
                'loss': loss_fn,
                'acc': acc
            }, device)
            scheduler.step(val_summary['loss'])
            tqdm.write('epoch : {}, tr_loss: {:.3f}, val_loss: '
                       '{:.3f}, tr_acc: {:.2%}, val_acc: {:.2%}'.format(
                           epoch + 1, tr_summary['loss'], val_summary['loss'],
                           tr_summary['acc'], val_summary['acc']))

            val_loss = val_summary['loss']
            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'opt_state_dict': opt.state_dict()
                }
                summary = {'train': tr_summary, 'validation': val_summary}

                summary_manager.update(summary)
                summary_manager.save('summary.json')
                checkpoint_manager.save_checkpoint(state, 'best.tar')

                best_val_loss = val_loss
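
All of the examples delegate persistence to `CheckpointManager` and `SummaryManager` utilities that are not shown on this page. A minimal sketch of what such classes might look like, assuming they simply wrap `torch.save` and `json.dump` under the experiment directory:

import json
from pathlib import Path

import torch

class CheckpointManager:
    """Hypothetical sketch: writes a state dict into the experiment directory."""
    def __init__(self, exp_dir):
        self._exp_dir = Path(exp_dir)

    def save_checkpoint(self, state, filename):
        torch.save(state, self._exp_dir / filename)

class SummaryManager:
    """Hypothetical sketch: accumulates metric dicts and dumps them as JSON."""
    def __init__(self, exp_dir):
        self._exp_dir = Path(exp_dir)
        self.summary = {}

    def update(self, summary):
        self.summary.update(summary)

    def save(self, filename):
        with open(self._exp_dir / filename, mode="w") as io:
            json.dump(self.summary, io, indent=4)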
Example #5
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
        f"_teacher_forcing_ratio_{args.teacher_forcing_ratio}")

    if not exp_dir.exists():
        exp_dir.mkdir(parents=True)

    if args.fix_seed:
        torch.manual_seed(777)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # processor
    src_processor, tgt_processor = get_processor(dataset_config)

    # data_loaders
    tr_dl, val_dl = get_data_loaders(dataset_config, src_processor,
                                     tgt_processor, args.batch_size)

    # model
    encoder = BidiEncoder(src_processor.vocab, model_config.encoder_hidden_dim,
                          model_config.drop_ratio)
    decoder = AttnDecoder(
        tgt_processor.vocab,
        model_config.method,
        model_config.encoder_hidden_dim * 2,
        model_config.decoder_hidden_dim,
        model_config.drop_ratio,
    )

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    encoder.to(device)
    decoder.to(device)

    writer = SummaryWriter("{}/runs".format(exp_dir))
    checkpoint_manager = CheckpointManager(exp_dir)
    summary_manager = SummaryManager(exp_dir)
    best_val_loss = 1e10

    opt = optim.RMSprop(
        [{"params": encoder.parameters()}, {"params": decoder.parameters()}],
        lr=args.learning_rate,
    )
    scheduler = ReduceLROnPlateau(opt, patience=5)

    for epoch in tqdm(range(args.epochs), desc="epochs"):
        tr_loss = 0

        encoder.train()
        decoder.train()

        for step, mb in tqdm(enumerate(tr_dl), desc="steps", total=len(tr_dl)):
            mb_loss = 0
            src_mb, tgt_mb = map(lambda elm: elm.to(device), mb)
            opt.zero_grad()

            # encoder
            enc_outputs_mb, src_length_mb, enc_hc_mb = encoder(src_mb)

            # decoder
            dec_input_mb = torch.ones((tgt_mb.size()[0], 1),
                                      device=device).long()
            dec_input_mb *= tgt_processor.vocab.to_indices(
                tgt_processor.vocab.bos_token)
            dec_hc_mb = None
            tgt_length_mb = tgt_mb.ne(
                tgt_processor.vocab.to_indices(
                    tgt_processor.vocab.padding_token)).sum(dim=1)
            tgt_mask_mb = sequence_mask(tgt_length_mb, tgt_length_mb.max())

            use_teacher_forcing = random.random() > args.teacher_forcing_ratio

            if use_teacher_forcing:
                for t in range(tgt_length_mb.max()):
                    dec_output_mb, dec_hc_mb = decoder(dec_input_mb, dec_hc_mb,
                                                       enc_outputs_mb, src_length_mb)
                    sequence_loss = mask_nll_loss(dec_output_mb, tgt_mb[:, [t]],
                                                  tgt_mask_mb[:, [t]])
                    mb_loss += sequence_loss
                    dec_input_mb = tgt_mb[:, [t]]  # next input is the current target
                else:
                    mb_loss /= tgt_length_mb.max()
            else:
                for t in range(tgt_length_mb.max()):
                    dec_output_mb, dec_hc_mb = decoder(dec_input_mb, dec_hc_mb,
                                                       enc_outputs_mb, src_length_mb)
                    sequence_loss = mask_nll_loss(dec_output_mb, tgt_mb[:, [t]],
                                                  tgt_mask_mb[:, [t]])
                    mb_loss += sequence_loss
                    dec_input_mb = dec_output_mb.topk(1).indices  # feed back the model's own prediction
                else:
                    mb_loss /= tgt_length_mb.max()

            # update params
            mb_loss.backward()
            nn.utils.clip_grad_norm_(encoder.parameters(), args.clip_norm)
            nn.utils.clip_grad_norm_(decoder.parameters(), args.clip_norm)
            opt.step()

            tr_loss += mb_loss.item()

            if (epoch * len(tr_dl) + step) % args.summary_step == 0:
                val_loss = evaluate(encoder, decoder, tgt_processor.vocab,
                                    val_dl, device)
                writer.add_scalars(
                    "perplexity",
                    {
                        "train": np.exp(tr_loss / (step + 1)),
                        "validation": np.exp(val_loss)
                    },
                    epoch * len(tr_dl) + step,
                )
                encoder.train()
                decoder.train()

        else:
            tr_loss /= step + 1

            tr_summary = {"perplexity": np.exp(tr_loss)}
            val_loss = evaluate(encoder, decoder, tgt_processor.vocab, val_dl,
                                device)
            val_summary = {"perplexity": np.exp(val_loss)}
            scheduler.step(np.exp(val_loss))

            tqdm.write("epoch : {}, tr_ppl: {:.3f}, val_ppl: "
                       "{:.3f}".format(epoch + 1, tr_summary["perplexity"],
                                       val_summary["perplexity"]))

            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    "epoch": epoch + 1,
                    "encoder_state_dict": encoder.state_dict(),
                    "decoder_state_dict": decoder.state_dict(),
                    "opt_state_dict": opt.state_dict(),
                }
                summary = {
                    "epoch": epoch + 1,
                    "train": tr_summary,
                    "validation": val_summary,
                }

                summary_manager.update(summary)
                summary_manager.save("summary.json")
                checkpoint_manager.save_checkpoint(state, "best.tar")

                best_val_loss = val_loss
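
Example #5 also relies on two masking helpers, `sequence_mask` and `mask_nll_loss`, that are not included. A hedged sketch, assuming the decoder returns per-token probabilities of shape (batch, vocab) and padded positions are excluded from the loss:

import torch

def sequence_mask(lengths, max_len):
    """Hypothetical sketch: boolean mask (batch, max_len), True before each sequence's length."""
    positions = torch.arange(int(max_len), device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)

def mask_nll_loss(dec_output, target, mask):
    """Hypothetical sketch: NLL for one decoding step, averaged over non-padded positions."""
    gold_prob = torch.gather(dec_output, 1, target)      # (batch, 1) probability of the gold token
    step_loss = -torch.log(gold_prob.clamp_min(1e-12))   # (batch, 1)
    return step_loss.masked_select(mask).mean()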
Example #6
            model.train()
        else:
            avg_tr_loss /= (step + 1)
            tr_acc /= tr_num_y

            avg_val_loss, val_acc = evaluate(model, loss_fn, val_dl, dev)

            tr_summary = {"loss": avg_tr_loss, "acc": tr_acc}
            val_summary = {"loss": avg_val_loss, "acc": val_acc}

            tqdm.write(
                "epoch: {}, tr_loss: {:.3f}, val_loss: {:.3f}, tr_acc: {:.3f}, val_acc: {:.3f}"
                .format(epoch + 1, tr_summary["loss"], val_summary["loss"],
                        tr_summary["acc"], val_summary["acc"]))

            is_best = avg_val_loss < best_val_loss

            if is_best:
                state = {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "opt_state_dict": opt.state_dict()
                }
                summary = {"train": tr_summary, "validation": val_summary}

                summary_manager.update(summary)
                summary_manager.save("summary.json")
                checkpoint_manager.save_checkpoint(state, "best.tar")

                best_val_loss = avg_val_loss
Example #7
            tr_acc /= (step + 1)

            tr_summary = {'loss': tr_loss, 'acc': tr_acc}
            val_summary = evaluate(model, val_dl, {
                'loss': loss_fn,
                'acc': acc
            }, device)
            tqdm.write('epoch : {}, tr_loss: {:.3f}, val_loss: '
                       '{:.3f}, tr_acc: {:.2%}, val_acc: {:.2%}'.format(
                           epoch + 1, tr_summary['loss'], val_summary['loss'],
                           tr_summary['acc'], val_summary['acc']))

            val_loss = val_summary['loss']
            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'opt_state_dict': opt.state_dict()
                }
                summary = {'train': tr_summary, 'validation': val_summary}

                summary_manager.update(summary)
                summary_manager.save('summary_snu_{}.json'.format(
                    args.pretrained_config))
                checkpoint_manager.save_checkpoint(
                    state, 'best_snu_{}.tar'.format(args.pretrained_config))

                best_val_loss = val_loss
Example #8
            n_h, n_t = sampler.corrupt_batch(h, t, r)
            with torch.no_grad():
                pos, neg = model(h, t, n_h, n_t, r)
                loss = criterion(pos, neg)
                val_loss += loss.item()
        val_loss /= (step + 1)
        writer.add_scalars('loss', {'train': tr_loss, 'val': val_loss}, epoch)
        if (epoch + 1) % args.summary_step == 0:
            tqdm.write(
                'Epoch {} | train loss: {:.5f}, valid loss: {:.5f}'.format(
                    epoch + 1, tr_loss, val_loss))
        model.normalize_parameters()
        is_best = val_loss < best_val_loss
        if is_best:
            state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            summary = {
                'Training Summary': {
                    'training loss': round(tr_loss, 4),
                    'validation loss': round(val_loss, 4)
                }
            }

            summary_manager.update(summary)
            summary_manager.save(f'summary_{args.model}.json')
            checkpoint_manager.save_checkpoint(state, f'best_{args.model}.tar')
            best_val_loss = val_loss
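
Example #8 scores each positive triple against a corrupted one and feeds both scores to `criterion(pos, neg)`. The criterion itself is not shown; a hedged sketch of a margin-based ranking loss with that two-argument signature, assuming the model returns distance-style scores (lower is better) as in TransE, could be:

import torch
import torch.nn as nn

class MarginLoss(nn.Module):
    """Hypothetical sketch: positive triples should be closer than corrupted ones by `margin`."""
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, pos, neg):
        # hinge on (margin + positive distance - negative distance)
        return torch.relu(self.margin + pos - neg).mean()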
Example #9
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
    )

    if not exp_dir.exists():
        exp_dir.mkdir(parents=True)

    if args.fix_seed:
        torch.manual_seed(777)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    preprocessor = get_preprocessor(dataset_config,
                                    coarse_split_fn=split_morphs,
                                    fine_split_fn=split_jamos)
    tr_dl, val_dl = get_data_loaders(dataset_config,
                                     preprocessor,
                                     args.batch_size,
                                     collate_fn=batchify)

    # model
    model = SAN(model_config.num_classes, preprocessor.coarse_vocab,
                preprocessor.fine_vocab, model_config.fine_embedding_dim,
                model_config.hidden_dim, model_config.multi_step,
                model_config.prediction_drop_ratio)

    opt = optim.Adam(model.parameters(), lr=args.learning_rate)
    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    writer = SummaryWriter(f"{exp_dir}/runs")
    checkpoint_manager = CheckpointManager(exp_dir)
    summary_manager = SummaryManager(exp_dir)
    best_val_loss = 1e10

    for epoch in tqdm(range(args.epochs), desc="epochs"):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc="steps", total=len(tr_dl)):
            qa_mb, qb_mb, y_mb = map(
                lambda elm: (el.to(device) for el in elm)
                if isinstance(elm, tuple) else elm.to(device), mb)
            opt.zero_grad()
            y_hat_mb = model((qa_mb, qb_mb))
            mb_loss = log_loss(y_hat_mb, y_mb)
            mb_loss.backward()
            opt.step()

            with torch.no_grad():
                mb_acc = acc(y_hat_mb, y_mb)

            tr_loss += mb_loss.item()
            tr_acc += mb_acc.item()

            if (epoch * len(tr_dl) + step) % args.summary_step == 0:
                val_loss = evaluate(model, val_dl, {"loss": log_loss},
                                    device)["loss"]
                writer.add_scalars("loss", {
                    "train": tr_loss / (step + 1),
                    "val": val_loss
                },
                                   epoch * len(tr_dl) + step)
                model.train()
        else:
            tr_loss /= step + 1
            tr_acc /= step + 1

            tr_summary = {"loss": tr_loss, "acc": tr_acc}
            val_summary = evaluate(model, val_dl, {
                "loss": log_loss,
                "acc": acc
            }, device)
            tqdm.write(
                f"epoch: {epoch+1}\n"
                f"tr_loss: {tr_summary['loss']:.3f}, val_loss: {val_summary['loss']:.3f}\n"
                f"tr_acc: {tr_summary['acc']:.2%}, val_acc: {val_summary['acc']:.2%}"
            )

            val_loss = val_summary["loss"]
            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "opt_state_dict": opt.state_dict(),
                }
                summary = {"train": tr_summary, "validation": val_summary}

                summary_manager.update(summary)
                summary_manager.save("summary.json")
                checkpoint_manager.save_checkpoint(state, "best.tar")

                best_val_loss = val_loss
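
Example #9 optimises a `log_loss` function instead of `nn.CrossEntropyLoss`; its definition is not included. A hedged sketch, assuming the SAN model already outputs class probabilities rather than logits:

import torch
import torch.nn.functional as F

def log_loss(y_hat, y):
    """Hypothetical sketch: negative log-likelihood computed from probability outputs."""
    return F.nll_loss(torch.log(y_hat.clamp_min(1e-12)), y)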
Example #10
    state = {
        'global_step': e + 1,
        'model_state_dict': model.state_dict(),
        'opt_state_dict': optimizer.state_dict()
    }

    summary = {'train': tr_summary, 'eval': eval_summary}
    summary_manager.update(summary)
    print("summary: ", summary)
    summary_manager.save('summary.json')

    # save
    is_best = eval_summary['f1'] >= best_dev_f1

    if is_best:
        best_dev_f1 = eval_summary['f1']
        checkpoint_manager.save_checkpoint(
            state, 'best-epoch-{}-f1-{:.3f}.bin'.format(e + 1, best_dev_f1))
        print("model checkpoint has been saved: best-epoch-{}-f1-{:.3f}.bin".
              format(e + 1, best_dev_f1))

        ## print classification report and save confusion matrix
        # cr_save_path = model_dir / 'best-epoch-{}-f1-{:.3f}-cr.csv'.format(e + 1, best_dev_f1)
        # cm_save_path = model_dir / 'best-epoch-{}-f1-{:.3f}-cm.png'.format(e + 1, best_dev_f1)
    else:
        torch.save(
            state,
            os.path.join(
                output_dir,
                'model-epoch-{}-f1-{:.3f}.bin'.format(e + 1,
                                                      eval_summary["f1"])))
        print("model checkpoint has been saved: best-epoch-{}-f1-{:.3f}.bin".
              format(e + 1, eval_summary['f1']))
Example #11
def main(args):
    dataset_config = Config(args.dataset_config)
    model_config = Config(args.model_config)
    ptr_config_info = Config(f"conf/pretrained/{model_config.type}.json")

    exp_dir = Path("experiments") / model_config.type
    exp_dir = exp_dir.joinpath(
        f"epochs_{args.epochs}_batch_size_{args.batch_size}_learning_rate_{args.learning_rate}"
        f"_weight_decay_{args.weight_decay}")

    if not exp_dir.exists():
        exp_dir.mkdir(parents=True)

    if args.fix_seed:
        torch.manual_seed(777)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    preprocessor = get_preprocessor(ptr_config_info, model_config)

    with open(ptr_config_info.config, mode="r") as io:
        ptr_config = json.load(io)

    # model
    config = BertConfig()
    config.update(ptr_config)
    model = PairwiseClassifier(config,
                               num_classes=model_config.num_classes,
                               vocab=preprocessor.vocab)
    bert_pretrained = torch.load(ptr_config_info.bert)
    model.load_state_dict(bert_pretrained, strict=False)

    tr_dl, val_dl = get_data_loaders(dataset_config, preprocessor,
                                     args.batch_size)

    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(
        [
            {"params": model.bert.parameters(), "lr": args.learning_rate / 100},
            {"params": model.classifier.parameters(), "lr": args.learning_rate},
        ],
        weight_decay=args.weight_decay,
    )

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    writer = SummaryWriter(f'{exp_dir}/runs')
    checkpoint_manager = CheckpointManager(exp_dir)
    summary_manager = SummaryManager(exp_dir)
    best_val_loss = 1e+10

    for epoch in tqdm(range(args.epochs), desc='epochs'):

        tr_loss = 0
        tr_acc = 0

        model.train()
        for step, mb in tqdm(enumerate(tr_dl), desc='steps', total=len(tr_dl)):
            x_mb, x_types_mb, y_mb = map(lambda elm: elm.to(device), mb)
            opt.zero_grad()
            y_hat_mb = model(x_mb, x_types_mb)
            mb_loss = loss_fn(y_hat_mb, y_mb)
            mb_loss.backward()
            opt.step()

            with torch.no_grad():
                mb_acc = acc(y_hat_mb, y_mb)

            tr_loss += mb_loss.item()
            tr_acc += mb_acc.item()

            if (epoch * len(tr_dl) + step) % args.summary_step == 0:
                val_loss = evaluate(model, val_dl, {'loss': loss_fn},
                                    device)['loss']
                writer.add_scalars('loss', {
                    'train': tr_loss / (step + 1),
                    'val': val_loss
                },
                                   epoch * len(tr_dl) + step)
                model.train()
        else:
            tr_loss /= (step + 1)
            tr_acc /= (step + 1)

            tr_summary = {'loss': tr_loss, 'acc': tr_acc}
            val_summary = evaluate(model, val_dl, {
                'loss': loss_fn,
                'acc': acc
            }, device)
            tqdm.write(
                f"epoch: {epoch+1}\n"
                f"tr_loss: {tr_summary['loss']:.3f}, val_loss: {val_summary['loss']:.3f}\n"
                f"tr_acc: {tr_summary['acc']:.2%}, val_acc: {val_summary['acc']:.2%}"
            )

            val_loss = val_summary['loss']
            is_best = val_loss < best_val_loss

            if is_best:
                state = {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'opt_state_dict': opt.state_dict()
                }
                summary = {'train': tr_summary, 'validation': val_summary}

                summary_manager.update(summary)
                summary_manager.save('summary.json')
                checkpoint_manager.save_checkpoint(state, 'best.tar')

                best_val_loss = val_loss
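
Example #11 fine-tunes the BERT body at one hundredth of the classifier's learning rate by passing two parameter groups to `optim.Adam`; options supplied as keyword arguments (here `weight_decay`) act as defaults applied to every group that does not override them. A small standalone illustration of that behaviour:

import torch.nn as nn
import torch.optim as optim

# two groups with different learning rates; weight_decay is shared by both
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
opt = optim.Adam(
    [{"params": model[0].parameters(), "lr": 1e-5},
     {"params": model[1].parameters(), "lr": 1e-3}],
    weight_decay=1e-2,
)
for group in opt.param_groups:
    print(group["lr"], group["weight_decay"])  # 1e-05 0.01 and 0.001 0.01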