Example #1
def main(args):
    # Create the log and model directories.
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)
    # Build the batched training and validation data loaders.
    data_loader = get_loader(
        input_dir=args.input_dir,
        input_vqa_train='train.npy',
        input_vqa_valid='valid.npy',
        max_qst_length=args.max_qst_length,
        max_num_ans=args.max_num_ans,
        batch_size=args.batch_size,
        num_workers=args.num_workers)
    # Size of the question vocabulary (VqaDataset -> VocabDict).
    qst_vocab_size = data_loader['train'].dataset.qst_vocab.vocab_size
    # Size of the answer vocabulary (number of candidate answers).
    ans_vocab_size = data_loader['train'].dataset.ans_vocab.vocab_size
    # Index of <unk> in the answer vocabulary.
    ans_unk_idx = data_loader['train'].dataset.ans_vocab.unk2idx

    # Instantiate the model.
    model = VqaModel(
        embed_size=args.embed_size,
        qst_vocab_size=qst_vocab_size,
        ans_vocab_size=ans_vocab_size,
        word_embed_size=args.word_embed_size,
        num_layers=args.num_layers,
        hidden_size=args.hidden_size).to(device)

    criterion = nn.CrossEntropyLoss()  # cross-entropy loss for multi-class classification
    # Collect all parameters that should be trained.
    params = list(model.img_encoder.fc.parameters()) \
        + list(model.qst_encoder.parameters()) \
        + list(model.fc1.parameters()) \
        + list(model.fc2.parameters())
    # Optimizer that updates the parameters.
    optimizer = optim.Adam(params, lr=args.learning_rate)
    # Learning-rate schedule.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    for epoch in range(args.num_epochs):  # for each epoch
        for phase in ['train', 'valid']:  # run the training and validation sets in turn
            running_loss = 0.0  # running sum of cross-entropy losses this epoch
            running_corr_exp1 = 0  # predictions that hit a valid answer this epoch
            running_corr_exp2 = 0  # same, but hits on <unk> are not counted
            # Number of batches per epoch.
            batch_step_size = len(data_loader[phase].dataset) / args.batch_size

            if phase == 'train':  # training phase: update the model
                model.train()
            else:
                model.eval()  # validation phase: evaluation only

            # For each mini-batch:
            for batch_idx, batch_sample in enumerate(data_loader[phase]):
                image = batch_sample['image'].to(device)  # 4-D tensor
                question = batch_sample['question'].to(device)  # 2-D tensor
                label = batch_sample['answer_label'].to(device)  # [batch_size], one label per sample
                multi_choice = batch_sample['answer_multi_choice']  # not tensor, list.

                optimizer.zero_grad()  # reset gradients first
                # Track gradients only during training.
                with torch.set_grad_enabled(phase == 'train'):
                    # Forward pass to get the output logits.
                    output = model(image, question)      # [batch_size, ans_vocab_size=1000]
                    # Index of the largest logit, i.e. the predicted answer label.
                    _, pred_exp1 = torch.max(output, 1)  # [batch_size]
                    pred_exp2 = pred_exp1.clone()        # copy for the <unk>-masked metric
                    loss = criterion(output, label)  # compute the loss

                    if phase == 'train':  # training: update the parameters from the loss
                        loss.backward()
                        optimizer.step()

                # Evaluation metric of 'multiple choice'
                # Exp1: our model prediction to '<unk>' IS accepted as the answer.
                # Exp2: our model prediction to '<unk>' is NOT accepted as the answer.
                # Set labels predicted as <unk> to -9999 so they can never match.
                pred_exp2[pred_exp2 == ans_unk_idx] = -9999
                running_loss += loss.item()  # accumulate each batch's loss
                # Stack per-answer comparisons into [max_num_ans, batch_size] and count
                # samples whose prediction appears among their (possibly repeated) valid answers.
                running_corr_exp1 += torch.stack([(ans == pred_exp1.cpu()) for ans in multi_choice]).any(dim=0).sum()
                # Hits on <unk> are not counted.
                running_corr_exp2 += torch.stack([(ans == pred_exp2.cpu()) for ans in multi_choice]).any(dim=0).sum()

                # Print the average loss in a mini-batch.
                if batch_idx % 100 == 0:
                    print('| {} SET | Epoch [{:02d}/{:02d}], Step [{:04d}/{:04d}], Loss: {:.4f}'
                          .format(phase.upper(), epoch+1, args.num_epochs, batch_idx, int(batch_step_size), loss.item()))

            if phase == 'train':
                # Step the LR schedule once per epoch, after the optimizer
                # updates (the ordering required since PyTorch 1.1).
                scheduler.step()

            # Print the average loss and accuracy in an epoch.
            # Mean per-batch loss over this epoch.
            epoch_loss = running_loss / batch_step_size
            # Compute both accuracies. (The denominator is the total sample count, so the ratio is slightly off.)
            epoch_acc_exp1 = running_corr_exp1.double() / len(data_loader[phase].dataset)      # multiple choice
            epoch_acc_exp2 = running_corr_exp2.double() / len(data_loader[phase].dataset)      # multiple choice
            print('| {} SET | Epoch [{:02d}/{:02d}], Loss: {:.4f}, Acc(Exp1): {:.4f}, Acc(Exp2): {:.4f} \n'
                  .format(phase.upper(), epoch+1, args.num_epochs, epoch_loss, epoch_acc_exp1, epoch_acc_exp2))
            # Log the loss and accuracy in an epoch.
            # Save this epoch's mean batch loss and both accuracies (.item() extracts the scalar).
            with open(os.path.join(args.log_dir, '{}-log-epoch-{:02d}.txt'
                      .format(phase, epoch+1)), 'w') as f:
                f.write(str(epoch+1) + '\t'
                        + str(epoch_loss) + '\t'
                        + str(epoch_acc_exp1.item()) + '\t'
                        + str(epoch_acc_exp2.item()))

        # Save the model checkpoints.
        # After both phases, save the model state every save_step epochs.
        if (epoch+1) % args.save_step == 0:
            torch.save({'epoch': epoch+1, 'state_dict': model.state_dict()},
                       os.path.join(args.model_dir, 'model-epoch-{:02d}.ckpt'.format(epoch+1)))
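
Examples #1 and #2 read every hyperparameter from an `args` namespace. A minimal argparse driver sketch, assuming the flag names mirror the attributes `main()` accesses (the default values here are illustrative assumptions, not the original repository's):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Paths used by main() (defaults are assumptions for illustration).
    parser.add_argument('--input_dir', type=str, default='./datasets')
    parser.add_argument('--log_dir', type=str, default='./logs')
    parser.add_argument('--model_dir', type=str, default='./models')
    # Data and model hyperparameters.
    parser.add_argument('--max_qst_length', type=int, default=30)
    parser.add_argument('--max_num_ans', type=int, default=10)
    parser.add_argument('--embed_size', type=int, default=1024)
    parser.add_argument('--word_embed_size', type=int, default=300)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_size', type=int, default=512)
    # Optimization schedule.
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--step_size', type=int, default=10)
    parser.add_argument('--gamma', type=float, default=0.1)
    parser.add_argument('--num_epochs', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--save_step', type=int, default=1)
    main(parser.parse_args())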
Example #2
def main(args):

    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)

    data_loader = get_loader(input_dir=args.input_dir,
                             input_vqa_train='train.npy',
                             input_vqa_valid='valid.npy',
                             max_qst_length=args.max_qst_length,
                             max_num_ans=args.max_num_ans,
                             batch_size=args.batch_size,
                             num_workers=args.num_workers)

    qst_vocab_size = data_loader['train'].dataset.qst_vocab.vocab_size
    ans_vocab_size = data_loader['train'].dataset.ans_vocab.vocab_size
    ans_unk_idx = data_loader['train'].dataset.ans_vocab.unk2idx

    model = VqaModel(embed_size=args.embed_size,
                     qst_vocab_size=qst_vocab_size,
                     ans_vocab_size=ans_vocab_size,
                     word_embed_size=args.word_embed_size,
                     num_layers=args.num_layers,
                     hidden_size=args.hidden_size).to(device)

    criterion = nn.CrossEntropyLoss()

    params = list(model.img_encoder.fc.parameters()) \
        + list(model.qst_encoder.parameters()) \
        + list(model.fc1.parameters()) \
        + list(model.fc2.parameters())

    optimizer = optim.Adam(params, lr=args.learning_rate)
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=args.step_size,
                                    gamma=args.gamma)

    for epoch in range(args.num_epochs):

        for phase in ['train', 'valid']:

            running_loss = 0.0
            running_corr_exp1 = 0
            running_corr_exp2 = 0
            batch_step_size = len(data_loader[phase].dataset) / args.batch_size

            if phase == 'train':
                model.train()
            else:
                model.eval()

            for batch_idx, batch_sample in enumerate(data_loader[phase]):

                image = batch_sample['image'].to(device)
                question = batch_sample['question'].to(device)
                label = batch_sample['answer_label'].to(device)
                multi_choice = batch_sample[
                    'answer_multi_choice']  # not tensor, list.

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):

                    output = model(
                        image, question)  # [batch_size, ans_vocab_size=1000]
                    _, pred_exp1 = torch.max(output, 1)  # [batch_size]
                    pred_exp2 = pred_exp1.clone()        # [batch_size]
                    loss = criterion(output, label)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Evaluation metric of 'multiple choice'
                # Exp1: our model prediction to '<unk>' IS accepted as the answer.
                # Exp2: our model prediction to '<unk>' is NOT accepted as the answer.
                pred_exp2[pred_exp2 == ans_unk_idx] = -9999
                running_loss += loss.item()
                running_corr_exp1 += torch.stack([
                    (ans == pred_exp1.cpu()) for ans in multi_choice
                ]).any(dim=0).sum()
                running_corr_exp2 += torch.stack([
                    (ans == pred_exp2.cpu()) for ans in multi_choice
                ]).any(dim=0).sum()

                # Print the average loss in a mini-batch.
                if batch_idx % 100 == 0:
                    print(
                        '| {} SET | Epoch [{:02d}/{:02d}], Step [{:04d}/{:04d}], Loss: {:.4f}'
                        .format(phase.upper(), epoch + 1, args.num_epochs,
                                batch_idx, int(batch_step_size), loss.item()))

            if phase == 'train':
                # Step the LR schedule after the epoch's optimizer updates
                # (the ordering required since PyTorch 1.1).
                scheduler.step()

            # Print the average loss and accuracy in an epoch.
            epoch_loss = running_loss / batch_step_size
            epoch_acc_exp1 = running_corr_exp1.double() / len(
                data_loader[phase].dataset)  # multiple choice
            epoch_acc_exp2 = running_corr_exp2.double() / len(
                data_loader[phase].dataset)  # multiple choice

            print(
                '| {} SET | Epoch [{:02d}/{:02d}], Loss: {:.4f}, Acc(Exp1): {:.4f}, Acc(Exp2): {:.4f} \n'
                .format(phase.upper(), epoch + 1, args.num_epochs, epoch_loss,
                        epoch_acc_exp1, epoch_acc_exp2))

            # Log the loss and accuracy in an epoch.
            with open(
                    os.path.join(args.log_dir,
                                 '{}-log-epoch-{:02d}.txt'.format(
                                     phase, epoch + 1)), 'w') as f:
                f.write(
                    str(epoch + 1) + '\t' + str(epoch_loss) + '\t' +
                    str(epoch_acc_exp1.item()) + '\t' +
                    str(epoch_acc_exp2.item()))

        # Save the model checkpoints.
        if (epoch + 1) % args.save_step == 0:
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict()
            },
                       os.path.join(
                           args.model_dir,
                           'model-epoch-{:02d}.ckpt'.format(epoch + 1)))
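
The `torch.stack(...).any(dim=0).sum()` idiom in the loop above counts a sample as correct when its prediction matches any of its acceptable answers. A self-contained sketch with toy tensors (here `multi_choice` is a list over answer slots, each entry holding one label per sample, matching the loader's layout):

import torch

multi_choice = [torch.tensor([4, 7, 2]),   # acceptable answer slot 1, per sample
                torch.tensor([5, 7, 9])]   # acceptable answer slot 2, per sample
pred = torch.tensor([4, 1, 9])             # model predictions, one per sample

# Compare each slot against the predictions, then OR across slots:
# a sample counts as correct if any slot matched its prediction.
hits = torch.stack([ans == pred for ans in multi_choice]).any(dim=0)
print(hits)        # tensor([ True, False,  True])
print(hits.sum())  # tensor(2): 2 of the 3 samples counted as correct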
Example #3
def main(cfg):
    gpu_id = cfg["hyperparameters"]["gpu_id"]
    # Use GPU if available
    if gpu_id >= 0:
        assert torch.cuda.is_available()
        device = torch.device("cuda:" + str(gpu_id))
        print("Using GPU {} | {}".format(gpu_id,
                                         torch.cuda.get_device_name(gpu_id)))
    elif gpu_id == -1:
        device = torch.device("cpu")
        print("Using the CPU")
    else:
        raise NotImplementedError(
            "Device ID {} not recognized. gpu_id = 0, 1, 2 etc. Use -1 for CPU"
            .format(gpu_id))

    data_loader = get_loader(
        input_dir=cfg["paths"]["input"],
        input_vqa_train="train.npy",
        input_vqa_valid="valid.npy",
        max_qst_length=cfg["hyperparameters"]["max_input_length"],
        max_num_ans=cfg["hyperparameters"]["max_num_answers"],
        batch_size=cfg["hyperparameters"]["batch_size"],
        num_workers=6)

    qst_vocab_size = data_loader['train'].dataset.qst_vocab.vocab_size
    ans_vocab_size = data_loader['train'].dataset.ans_vocab.vocab_size
    ans_list = data_loader['train'].dataset.ans_vocab.word_list
    ans_unk_idx = data_loader['train'].dataset.ans_vocab.unk2idx
    cfg["hyperparameters"]["qst_vocab_size"] = qst_vocab_size
    cfg["hyperparameters"]["ans_vocab_size"] = ans_vocab_size

    assert not (cfg["hyperparameters"]["use_dnc_c"]
                and cfg["hyperparameters"]["use_dnc_q"])
    _set_seed(cfg["hyperparameters"]["seed"])
    if cfg["hyperparameters"]["use_dnc_c"]:
        model = VqaModelDncC(cfg).to(device)
        net_name = "dnc_C_" + str(cfg["dnc"]["number"])

    elif cfg["hyperparameters"]["use_dnc_q"]:
        model = VqaModelDncQ(cfg).to(device)
        net_name = "dnc_Q"

    else:
        model = VqaModel(cfg).to(device)
        net_name = "Baseline"

        # embed_size=cfg["hyperparameters"]["commun_embed_size"],
        # qst_vocab_size=qst_vocab_size,
        # ans_vocab_size=ans_vocab_size,
        # word_embed_size=cfg["hyperparameters"]["embedding_dim"],
        # num_layers=args.num_layers,
        # hidden_size=args.hidden_size).to(device)

    criterion = nn.CrossEntropyLoss()
    if cfg["hyperparameters"]["use_dnc_c"]:
        dnc_params = {
            "params": model.dnc.parameters(),
            "lr": cfg["dnc_c"]["lr"]
        }
        img_encoder_params = {"params": model.img_encoder.fc.parameters()}
        qst_encoder_params = {"params": model.qst_encoder.fc.parameters()}
        if cfg["hyperparameters"]["optimizer"] == "adam":
            optimizer = optim.Adam(
                [dnc_params, img_encoder_params, qst_encoder_params],
                lr=cfg["hyperparameters"]["lr"],
                weight_decay=cfg["hyperparameters"]["weight_decay"])
        elif cfg["hyperparameters"]["optimizer"] == "sgd":
            optimizer = optim.SGD(
                [dnc_params, img_encoder_params, qst_encoder_params],
                lr=cfg["hyperparameters"]["lr"],
                weight_decay=cfg["hyperparameters"]["weight_decay"])
    elif cfg["hyperparameters"]["use_dnc_q"]:
        dnc_params = {
            "params": model.qst_encoder.dnc_q.parameters(),
            "lr": cfg["dnc_q"]["lr"]
        }
        embed_params = {"params": model.qst_encoder.word2vec.parameters()}
        img_encoder_params = {"params": model.img_encoder.fc.parameters()}
        #qst_encoder_params = {"params": model.qst_encoder.fc.parameters()}
        fc1_params = {"params": model.fc1.parameters()}
        fc2_params = {"params": model.fc2.parameters()}

        if cfg["hyperparameters"]["optimizer"] == "adam":
            optimizer = optim.Adam(
                [
                    dnc_params, embed_params, img_encoder_params, fc1_params,
                    fc2_params
                ],
                lr=cfg["hyperparameters"]["lr"],
                weight_decay=cfg["hyperparameters"]["weight_decay"])
        elif cfg["hyperparameters"]["optimizer"] == "sgd":
            optimizer = optim.SGD(
                [
                    dnc_params, embed_params, img_encoder_params, fc1_params,
                    fc2_params
                ],
                lr=cfg["hyperparameters"]["lr"],
                weight_decay=cfg["hyperparameters"]["weight_decay"])
    else:
        params = list(model.img_encoder.fc.parameters()) \
            + list(model.qst_encoder.parameters()) \
            + list(model.fc1.parameters()) \
            + list(model.fc2.parameters())
        optimizer = optim.Adam(params, lr=cfg["hyperparameters"]["lr"])
    print("Training " + net_name)
    scheduler = lr_scheduler.StepLR(
        optimizer,
        step_size=cfg["hyperparameters"]["lr_reduce_after"],
        gamma=cfg["hyperparameters"]["lr_decay_rate"])
    summary_writer = SummaryWriter(logdir=cfg["logging"]["tensorboard_dir"])
    tr_iter = 0
    val_iter = 0
    lr = 0
    lr_dnc = 0

    for epoch in range(cfg["hyperparameters"]["num_epochs"]):

        for phase in ['train', 'valid']:
            if cfg["hyperparameters"]["use_dnc_c"]:
                if cfg["dnc"]["number"] == 1:
                    model.dnc.update_batch_size(
                        cfg["hyperparameters"]["batch_size"])
                    h, mem = model.dnc.reset()
                elif cfg["dnc"]["number"] == 0:
                    (mem, rv) = model.dnc.init_hidden(
                        None, cfg["hyperparameters"]["batch_size"], True)
                else:
                    raise ValueError("No dnc number " + str(cfg["dnc"]["number"]))
            if cfg["hyperparameters"]["use_dnc_q"]:
                (chx, mhx, rv) = (None, None, None)

            running_loss = 0.0
            dataloader = data_loader[phase]
            batch_step_size = len(
                dataloader.dataset) / cfg["hyperparameters"]["batch_size"]
            if phase == 'train':
                model.train()
            else:
                model.eval()
                val_predictions = []
            pbar = tqdm(dataloader)
            pbar.set_description("{} | Epoch {}/{}".format(
                phase, epoch, cfg["hyperparameters"]["num_epochs"]))
            for batch_idx, batch_sample in enumerate(pbar):

                image = batch_sample['image'].to(device)
                question = batch_sample['question'].to(device)
                label = batch_sample['answer_label'].to(device)
                multi_choice = batch_sample[
                    'answer_multi_choice']  # not tensor, list.
                if image.size(0) != cfg["hyperparameters"]["batch_size"]:
                    if cfg["hyperparameters"]["use_dnc_c"]:
                        if cfg["dnc"]["number"] == 1:
                            model.dnc.update_batch_size(image.size(0))
                            h, mem = model.dnc.reset()
                        elif cfg["dnc"]["number"] == 0:
                            (mem, rv) = model.dnc.init_hidden(
                                None, image.size(0), False)
                        else:
                            raise ValueError("No dnc number " +
                                             str(cfg["dnc"]["number"]))
                    if cfg["hyperparameters"]["use_dnc_q"]:
                        (chx, mhx, rv) = (None, None, None)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    if cfg["hyperparameters"]["use_dnc_c"]:
                        if cfg["dnc"]["number"] == 1:
                            output, h, mem = model(image,
                                                   question,
                                                   h=h,
                                                   mem=mem)
                        elif cfg["dnc"]["number"] == 0:
                            output, (mem, rv), v = model(image,
                                                         question,
                                                         mem=mem,
                                                         rv=rv)

                    elif cfg["hyperparameters"]["use_dnc_q"]:
                        output, (chx, mhx, rv), v = model(image,
                                                          question,
                                                          chx=chx,
                                                          mhx=mhx,
                                                          rv=rv)
                    else:
                        output = model(
                            image,
                            question)  # [batch_size, ans_vocab_size=1000]
                    _, pred = torch.max(output, 1)  # [batch_size]
                    # _, pred_exp2 = torch.max(output, 1)  # [batch_size]
                    loss = criterion(output, label)
                    if phase == 'train':
                        loss.backward()
                        # if iter % cfg["hyperparameters"]["grad_flow_interval"] == 0:
                        #     plot_grad_flow(model.named_parameters(), cfg["hyperparameters"]["grad_flow_dir"], str(tr_iter))
                        if cfg["hyperparameters"]["use_clip_grad"]:
                            nn.utils.clip_grad_norm_(
                                model.parameters(),
                                cfg["hyperparameters"]["clip_value"])
                        optimizer.step()

                        if cfg["hyperparameters"]["use_dnc_c"]:
                            lr_dnc = optimizer.param_groups[0]["lr"]
                            lr = optimizer.param_groups[1]["lr"]
                            dict_lr = {"DNC": lr_dnc, "Rest": lr}
                            summary_writer.add_scalars("lr",
                                                       dict_lr,
                                                       global_step=tr_iter)
                        elif cfg["hyperparameters"]["use_dnc_q"]:
                            lr_dnc = optimizer.param_groups[0]["lr"]
                            lr = optimizer.param_groups[1]["lr"]
                            dict_lr = {"DNC": lr_dnc, "Rest": lr}
                            summary_writer.add_scalars("lr",
                                                       dict_lr,
                                                       global_step=tr_iter)
                        else:
                            lr = optimizer.param_groups[0]["lr"]
                            summary_writer.add_scalar("lr",
                                                      lr,
                                                      global_step=tr_iter)

                    else:
                        question_ids = batch_sample["question_id"].tolist()
                        pred = pred.tolist()
                        pred = [ans_list[i] for i in pred]
                        for id_, ans in zip(question_ids, pred):
                            val_predictions.append({
                                "question_id": id_,
                                "answer": ans
                            })
                    if cfg["hyperparameters"]["use_dnc_c"]:
                        if cfg["dnc"]["number"] == 1:
                            mem = repackage_hidden(mem)
                        elif cfg["dnc"]["number"] == 0:
                            mem = {
                                k: (v.detach() if isinstance(v, var) else v)
                                for k, v in mem.items()
                            }
                            rv = rv.detach()
                    elif cfg["hyperparameters"]["use_dnc_q"]:
                        mhx = {
                            k: (v.detach() if isinstance(v, var) else v)
                            for k, v in mhx.items()
                        }

                # Evaluation metric of 'multiple choice'
                # Exp1: our model prediction to '<unk>' IS accepted as the answer.
                # Exp2: our model prediction to '<unk>' is NOT accepted as the answer.
                # pred_exp2[pred_exp2 == ans_unk_idx] = -9999
                running_loss += loss.item()
                summary_writer.add_scalar(
                    "Loss/" + phase + "_Batch",
                    loss.item(),
                    global_step=tr_iter if phase == "train" else val_iter)
                # running_corr_exp1 += torch.stack([(ans == pred_exp1.cpu()) for ans in multi_choice]).any(dim=0).sum()
                # running_corr_exp2 += torch.stack([(ans == pred_exp2.cpu()) for ans in multi_choice]).any(dim=0).sum()

                # Print the average loss in a mini-batch.
                # if batch_idx % 10 == 0:
                #     print('| {} SET | Epoch [{:02d}/{:02d}], Step [{:04d}/{:04d}], Loss: {:.4f}'
                #           .format(phase.upper(), epoch+1, args.num_epochs, batch_idx, int(batch_step_size), loss.item()))

                if phase == "train":
                    tr_iter += 1
                else:
                    val_iter += 1
            if phase == "train":
                scheduler.step()

            # Print the average loss and accuracy in an epoch.
            epoch_loss = running_loss / batch_step_size
            summary_writer.add_scalar("Loss/" + phase + "_Epoch",
                                      epoch_loss,
                                      global_step=epoch)
            if phase == "valid":
                valFile = os.path.join(cfg["logging"]["results_dir"],
                                       "val_res.json")
                with open(valFile, 'w') as f:
                    json.dump(val_predictions, f)
                annFile = cfg["paths"]["json_a_path_val"]
                quesFile = cfg["paths"]["json_q_path_val"]
                vqa = VQA(annFile, quesFile)
                vqaRes = vqa.loadRes(valFile, quesFile)
                vqaEval = VQAEval(vqa, vqaRes, n=2)
                vqaEval.evaluate()
                acc_overall = vqaEval.accuracy['overall']
                # acc_perQuestionType = vqaEval.accuracy['perQuestionType']
                # acc_perAnswerType = vqaEval.accuracy['perAnswerType']
                summary_writer.add_scalar("Acc/overall_" + phase + "_Epoch",
                                          acc_overall,
                                          global_step=epoch)
                # summary_writer.add_scalar("Acc/perQues" + phase + "_Epoch", epoch_loss, global_step=epoch)
                # summary_writer.add_scalar("Acc/" + phase + "_Epoch", epoch_loss, global_step=epoch)

            # epoch_acc_exp1 = running_corr_exp1.double() / len(data_loader[phase].dataset)      # multiple choice
            # epoch_acc_exp2 = running_corr_exp2.double() / len(data_loader[phase].dataset)      # multiple choice

            # print('| {} SET | Epoch [{:02d}/{:02d}], Loss: {:.4f}, Acc(Exp1): {:.4f}, Acc(Exp2): {:.4f} \n'
            #       .format(phase.upper(), epoch+1, args.num_epochs, epoch_loss, epoch_acc_exp1, epoch_acc_exp2))

            # Log the loss and accuracy in an epoch.
            # with open(os.path.join(args.log_dir, '{}-log-epoch-{:02}.txt')
            #           .format(phase, epoch+1), 'w') as f:
            #     f.write(str(epoch+1) + '\t'
            #             + str(epoch_loss) + '\t'
            #             + str(epoch_acc_exp1.item()) + '\t'
            #             + str(epoch_acc_exp2.item()))

        # Save the model checkpoints.

        _save_checkpoint(net_name, model, optimizer, epoch, tr_iter, val_iter,
                         lr, cfg)