Example #1
def train(**kwargs):
    """
    训练
    训练的主要步骤如下:
    - 定义网络
    - 定义数据
    - 定义损失函数和优化器
    - 计算重要指标
    - 开始训练
      - 训练网络
      - 可视化各种指标
      - 计算在验证集上的指标
    :param kwargs:
    :return:
    """
    # 根据命令行更新参数
    opt.parse(kwargs)
    vis = Visualizer(opt.env)

    # step1: define the network
    model = getattr(models, opt.model)()
    # model = models.ResNet34()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu:
        model.cuda()

    # step2: data
    train_data = DogCat(opt.train_data_root, train=True)
    val_data = DogCat(opt.train_data_root, train=False)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers)

    # step3: loss function and optimizer
    criterion = t.nn.CrossEntropyLoss()
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(),
                             lr=lr,
                             weight_decay=opt.weight_decay)

    # step4: metrics: the smoothed loss and a confusion matrix
    loss_meter = meter.AverageValueMeter()
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e100

    # training
    for epoch in range(opt.max_epoch):
        loss_meter.reset()
        confusion_matrix.reset()
        for ii, (data, label) in enumerate(train_dataloader):
            if opt.use_gpu:
                data = data.cuda()
                label = label.cuda()
            optimizer.zero_grad()
            score = model(data)
            loss = criterion(score, label.long())
            loss.backward()
            optimizer.step()

            # update the metrics and visualize
            loss_meter.add(loss.item())
            confusion_matrix.add(score.data, label.data)
            if ii % opt.print_freq == opt.print_freq - 1:
                vis.plot('loss', loss_meter.value()[0])

        model.save()

        # compute and visualize metrics on the validation set
        val_cm, val_accuracy = val(model, val_dataloader)
        vis.plot('val_accuracy', val_accuracy)
        vis.log(
            "epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}"
            .format(epoch=epoch,
                    loss=loss_meter.value()[0],
                    val_cm=str(val_cm.value()),
                    train_cm=str(confusion_matrix.value()),
                    lr=lr))

        # if the loss no longer decreases, lower the learning rate
        if loss_meter.value()[0] > previous_loss:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        previous_loss = loss_meter.value()[0]
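
Examples #1 and #7 call a val(model, dataloader) helper that is not shown in this section. The sketch below is inferred from how it is used (it must return a ConfusionMeter and an accuracy); the internals are an assumption, not the original code, and it reuses the same t/meter/opt names as Example #1.

def val(model, dataloader):
    """Hedged sketch: compute the confusion matrix and accuracy on the validation set."""
    model.eval()
    confusion_matrix = meter.ConfusionMeter(2)
    with t.no_grad():
        for data, label in dataloader:
            if opt.use_gpu:
                data, label = data.cuda(), label.cuda()
            score = model(data)
            confusion_matrix.add(score.detach(), label.long().detach())
    model.train()
    cm_value = confusion_matrix.value()
    # accuracy in percent for the two-class (dog/cat) problem
    accuracy = 100.0 * (cm_value[0][0] + cm_value[1][1]) / cm_value.sum()
    return confusion_matrix, accuracy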
Example #2
        data_loader = load_image_datasets(batch_size=batch_size)

    dataset_size = len(data_loader)
    net.set_input()
    for batch_id, (x, _) in enumerate(data_loader):
        total_iter += 1

        net.forward(x)
        net.optimize_parameters()

        # Print and plot training information
        if total_iter % opt.print_freq == 0:  # print training losses and save logging information to the disk
            losses.append(net.get_loss())
            loss_dict = {'total_loss': losses[-1]}

            visualizer.plot(total_iter, losses, names=['total_loss'])
            visualizer.print(epoch, loss_dict, time.time() - iter_time)
            iter_time = time.time()

        # Save the checkpoint
        if total_iter % opt.save_epoch_freq == 0:
            net.save_networks('latest', checkpoint_dir)
            net.save_networks(epoch + 1, checkpoint_dir)

        # Display the result image
        if total_iter % opt.display_freq == 0:
            out_img_path = os.path.join(
                images_dir, "{}_{}.png".format(opt.model, epoch + 1))
            torchvision.utils.save_image(net.get_image(), out_img_path)
            visualizer.display_image(out_img_path)
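
Example #2 is only a fragment: the loop drives a wrapper object net through set_input, forward, optimize_parameters, get_loss, get_image and save_networks, none of which are shown. The minimal sketch below illustrates one way such a wrapper could look; every method body is an assumption inferred from the calls above, not the original implementation.

import os
import torch
import torch.nn as nn

class Net:
    """Illustrative wrapper matching the interface used in Example #2."""

    def __init__(self, model, lr=1e-3):
        self.model = model
        self.criterion = nn.MSELoss()  # placeholder objective
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.x = self.output = self.loss = None

    def set_input(self):
        # the original probably prepares/normalizes the incoming batch; left as a stub
        pass

    def forward(self, x):
        self.x = x
        self.output = self.model(x)

    def optimize_parameters(self):
        self.optimizer.zero_grad()
        self.loss = self.criterion(self.output, self.x)  # e.g. a reconstruction loss
        self.loss.backward()
        self.optimizer.step()

    def get_loss(self):
        return self.loss.item()

    def get_image(self):
        return self.output.detach().cpu()

    def save_networks(self, label, checkpoint_dir):
        torch.save(self.model.state_dict(),
                   os.path.join(checkpoint_dir, 'net_{}.pth'.format(label)))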
Example #3
def main():
    opt = BaseOptions().parse()
    device = torch.device(
        "cuda:{}".format(opt.gpu_ids[0]) if len(opt.gpu_ids) > 0 else "cpu")

    # create save dir
    save_root = os.path.join('checkpoints', opt.name)
    if not os.path.isdir(save_root):
        os.makedirs(save_root)

    # get the data
    climate_data = ClimateDataset(opt=opt, phase='train')
    climate_data_loader = DataLoader(climate_data,
                                     batch_size=opt.batch_size,
                                     shuffle=True,
                                     num_workers=int(opt.n_threads))
    val_data = ClimateDataset(opt=opt, phase='val')
    val_data_loader = DataLoader(val_data,
                                 batch_size=opt.batch_size,
                                 shuffle=True,
                                 num_workers=int(opt.n_threads))

    # load the model
    model = SDVAE(opt=opt, device=device).to(device)

    initial_epoch = 0
    if opt.load_epoch >= 0:
        save_name = "epoch_{}.pth".format(opt.load_epoch)
        save_dir = os.path.join(save_root, save_name)
        model.load_state_dict(torch.load(save_dir))
        initial_epoch = opt.load_epoch + 1

    if opt.phase == 'train':
        # get optimizer
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
        # todo try amsgrad=True
        viz = Visualizer(opt,
                         n_images=5,
                         training_size=len(climate_data_loader.dataset),
                         n_batches=len(climate_data_loader))

        for epoch_idx in range(opt.n_epochs):
            mse_list = []
            kld_list = []
            cycle_loss_list = []
            loss_list = []

            epoch = initial_epoch + epoch_idx
            img_id = 0

            # timing
            epoch_start_time = time.time()
            iter_data_start_time = time.time()
            iter_data_time = 0
            iter_time = 0

            for batch_idx, data in enumerate(climate_data_loader, 0):
                iter_start_time = time.time()

                optimizer.zero_grad()

                recon_pr, mu, log_var = model(
                    fine_pr=data['fine_pr'].to(device),
                    coarse_pr=data['coarse_pr'].to(device),
                    orog=data['orog'].to(device),
                    coarse_uas=data['coarse_uas'].to(device),
                    coarse_vas=data['coarse_vas'].to(device),
                    coarse_psl=data['coarse_psl'].to(device))

                mse, kld, cycle_loss, loss = model.loss_function(
                    recon_pr, data['fine_pr'].to(device), mu, log_var,
                    data['coarse_pr'].to(device))

                if loss.item() < float('inf'):
                    loss.backward()

                    mse_list += [mse.item()]
                    kld_list += [kld.item()]
                    cycle_loss_list += [cycle_loss.item()]
                    loss_list += [loss.item()]

                    optimizer.step()
                else:
                    print("inf loss")

                # timing
                iter_time += time.time() - iter_start_time
                iter_data_time += iter_start_time - iter_data_start_time

                if batch_idx % opt.log_interval == 0 and batch_idx > 0:
                    viz.print(epoch, batch_idx,
                              np.mean(mse_list[-opt.log_interval:]),
                              np.mean(kld_list[-opt.log_interval:]),
                              np.mean(cycle_loss_list[-opt.log_interval:]),
                              np.mean(loss_list[-opt.log_interval:]),
                              iter_time, iter_data_time,
                              sum(data['time']).item())

                    iter_data_time = 0
                    iter_time = 0

                if batch_idx % opt.plot_interval == 0 and batch_idx > 0:
                    img_id += 1
                    image_name = "Epoch{}_Image{}.jpg".format(epoch, img_id)
                    if opt.model == "mse_vae":
                        viz.plot(fine_pr=data['fine_pr'].to(device),
                                 recon_pr=recon_pr,
                                 image_name=image_name)
                    elif opt.model == "gamma_vae":
                        viz.plot(fine_pr=data['fine_pr'].to(device),
                                 recon_pr=recon_pr['p'] * recon_pr['alpha'] *
                                 recon_pr['beta'],
                                 image_name=image_name)
                    else:
                        raise ValueError("model {} is not implemented".format(
                            opt.model))

                if batch_idx % opt.save_latest_interval == 0 and batch_idx > 0:
                    save('latest', save_root, opt.gpu_ids, model)
                    print('saved latest epoch after {} iterations'.format(
                        batch_idx))

                if batch_idx % opt.eval_val_loss == 0 and batch_idx > 0:
                    # switch model to evaluation mode
                    model.eval()
                    # calculate val loss
                    val_loss_sum = np.zeros(
                        4)  # val_mse, val_kld, val_cycle_loss, val_loss
                    inf_losses = 0  # nr of sets where loss was inf
                    for val_batch_idx, val_data in enumerate(val_data_loader, 0):
                        recon_pr, mu, log_var = model(
                            fine_pr=val_data['fine_pr'].to(device),
                            coarse_pr=val_data['coarse_pr'].to(device),
                            orog=val_data['orog'].to(device),
                            coarse_uas=val_data['coarse_uas'].to(device),
                            coarse_vas=val_data['coarse_vas'].to(device),
                            coarse_psl=val_data['coarse_psl'].to(device))
                        val_loss = model.loss_function(
                            recon_pr, val_data['fine_pr'].to(device), mu,
                            log_var, val_data['coarse_pr'].to(device))
                        val_loss = [l.item() for l in val_loss]
                        if val_loss[-1] < float('inf'):
                            val_loss_sum += val_loss
                        else:
                            inf_losses += 1
                        if val_batch_idx >= opt.eval_val_loss:
                            break

                    n_val = opt.eval_val_loss - inf_losses
                    viz.print_eval(
                        epoch=epoch,
                        val_mse=val_loss_sum[0] / n_val,
                        val_kld=val_loss_sum[1] / n_val,
                        val_cycle_loss=val_loss_sum[2] / n_val,
                        val_loss=val_loss_sum[3] / n_val,
                        inf_losses=inf_losses,
                        train_mse=np.mean(mse_list[-opt.eval_val_loss:]),
                        train_kld=np.mean(kld_list[-opt.eval_val_loss:]),
                        train_cycle_loss=np.mean(
                            cycle_loss_list[-opt.eval_val_loss:]),
                        train_loss=np.mean(loss_list[-opt.eval_val_loss:]))

                    model.train()

                iter_data_start_time = time.time()

            if epoch % opt.save_interval == 0:
                save(epoch, save_root, opt.gpu_ids, model)
            epoch_time = time.time() - epoch_start_time
            viz.print_epoch(epoch=epoch,
                            epoch_mse=np.mean(mse_list),
                            epoch_kld=np.mean(kld_list),
                            epoch_cycle_loss=np.mean(cycle_loss_list),
                            epoch_loss=np.mean(loss_list),
                            epoch_time=epoch_time)
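
Example #3 checkpoints through a save(label, save_root, gpu_ids, model) helper that is not defined here. A plausible sketch is given below, assuming the file naming has to match the "epoch_{}.pth" path used when resuming at the top of main(); the CPU/GPU round-trip is an assumption.

def save(label, save_root, gpu_ids, model):
    """Hedged sketch of the checkpoint helper used in Example #3."""
    save_path = os.path.join(save_root, 'epoch_{}.pth'.format(label))
    # serialize from CPU so the checkpoint can be loaded on any device
    torch.save(model.cpu().state_dict(), save_path)
    if len(gpu_ids) > 0 and torch.cuda.is_available():
        model.to(torch.device('cuda:{}'.format(gpu_ids[0])))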
Example #4
def main():
    # load data
    train_loader = torch.utils.data.DataLoader(NYUDepthDataset(
        cfg.trainval_data_root,
        'train',
        sample_num=cfg.sample_num,
        superpixel=False,
        relative=False,
        transform=True),
                                               batch_size=cfg.batch_size,
                                               shuffle=True,
                                               num_workers=cfg.num_workers,
                                               drop_last=True)
    print('Train Batches:', len(train_loader))

    # val_loader = torch.utils.data.DataLoader(NYUDepthDataset(cfg.trainval_data_root, 'val', transform=True),
    #                                          batch_size=cfg.batch_size, shuffle=True,
    #                                          num_workers=cfg.num_workers, drop_last=True)
    # print('Validation Batches:', len(val_loader))

    test_set = NyuDepthMat(
        cfg.test_data_root,
        '/home/ans/PycharmProjects/SDFCN/data/testIdxs.txt')
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=cfg.batch_size,
                                              shuffle=True,
                                              drop_last=True)

    # train_set = NyuDepthMat(cfg.test_data_root, '/home/ans/PycharmProjects/SDFCN/data/trainIdxs.txt')
    # train_loader = torch.utils.data.DataLoader(train_set,
    #                                           batch_size=cfg.batch_size,
    #                                           shuffle=True, drop_last=True)
    # train_loader = test_loader
    #
    val_loader = test_loader
    # load model and weight
    # model = FCRN(cfg.batch_size)
    model = ResDUCNet(model=torchvision.models.resnet50(pretrained=False))
    init_upsample = False
    # print(model)

    loss_fn = berHu()

    if cfg.use_gpu:
        print('Use CUDA')
        model = model.cuda()
        # loss_fn = berHu().cuda()
        # loss_fn = torch.nn.MSELoss().cuda()
        loss_fn = torch.nn.L1Loss().cuda()

    start_epoch = 0
    best_val_err = 10e3

    if cfg.resume_from_file:
        if os.path.isfile(cfg.resume_file):
            print("=> loading checkpoint '{}'".format(cfg.resume_file))
            checkpoint = torch.load(cfg.resume_file)
            # start_epoch = checkpoint['epoch']
            start_epoch = 0
            # model.load_state_dict(checkpoint['state_dict'])
            model.load_state_dict(checkpoint['model_state'])
            # print("=> loaded checkpoint '{}' (epoch {})"
            #       .format(cfg.resume_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(cfg.resume_file))
    # else:
    #     if init_upsample:
    #         print('Loading weights from ', cfg.weights_file)
    #         # bone_state_dict = load_weights(model, cfg.weights_file, dtype)
    #         model.load_state_dict(load_weights(model, cfg.weights_file, dtype))
    #     else:
    #         print('Loading weights from ', cfg.resnet50_file)
    #         pretrained_dict = torch.load(cfg.resnet50_file)
    #         model_dict = model.state_dict()
    #         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    #         model_dict.update(pretrained_dict)
    #         model.load_state_dict(model_dict)
    #     print('Weights loaded.')

    # val_error, val_rmse = validate(val_loader, model, loss_fn)
    # print('before train: val_error %f, rmse: %f' % (val_error, val_rmse))

    vis = Visualizer(cfg.env)
    # 4.Optim
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    print("optimizer set.")
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=cfg.step,
                                    gamma=cfg.lr_decay)

    for epoch in range(cfg.num_epochs):

        scheduler.step()
        # print(optimizer.state_dict()['param_groups'][0]['lr'])
        print('Starting train epoch %d / %d, lr=%f' %
              (start_epoch + epoch + 1, cfg.num_epochs,
               optimizer.state_dict()['param_groups'][0]['lr']))

        model.train()
        running_loss = 0
        count = 0
        epoch_loss = 0

        for i_batch, sample_batched in enumerate(train_loader):
            input_var = Variable(sample_batched['rgb'].type(dtype))
            depth_var = Variable(sample_batched['depth'].type(dtype))

            optimizer.zero_grad()
            output = model(input_var)
            loss = loss_fn(output, depth_var)

            if i_batch % cfg.print_freq == cfg.print_freq - 1:
                print('{0} batches, loss:{1}'.format(i_batch + 1,
                                                     loss.data.cpu().item()))
                vis.plot('loss', loss.data.cpu().item())

            if i_batch % (cfg.print_freq * 10) == (cfg.print_freq * 10) - 1:
                vis.depth('pred', output)
                # vis.imshow('img', sample_batched['rgb'].type(dtype))
                vis.depth('depth', sample_batched['depth'].type(dtype))

            count += 1
            running_loss += loss.data.cpu().numpy()

            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        val_error, val_rmse = validate(val_loader, model, loss_fn, vis=vis)
        vis.plot('val_error', val_error)
        vis.plot('val_rmse', val_rmse)
        vis.log('epoch:{epoch},lr={lr},epoch_loss:{loss},val_error:{val_error}'.
                format(epoch=start_epoch + epoch + 1,
                       loss=epoch_loss,
                       val_error=val_error,
                       lr=optimizer.state_dict()['param_groups'][0]['lr']))

        if val_error < best_val_err:
            best_val_err = val_error
            if not os.path.exists(cfg.checkpoint_dir):
                os.mkdir(cfg.checkpoint_dir)

            torch.save(
                {
                    'epoch': start_epoch + epoch + 1,
                    'state_dict': model.state_dict(),
                    # 'optimizer': optimizer.state_dict(),
                },
                os.path.join(
                    cfg.checkpoint_dir,
                    '{}_{}_epoch_{}_{}'.format(cfg.checkpoint, cfg.env,
                                               start_epoch + epoch + 1,
                                               cfg.checkpoint_postfix)))

    torch.save(
        {
            'epoch': start_epoch + epoch + 1,
            'state_dict': model.state_dict(),
            # 'optimizer': optimizer.state_dict(),
        },
        os.path.join(
            cfg.checkpoint_dir,
            '{}_{}_epoch_{}_{}'.format(cfg.checkpoint, cfg.env,
                                       start_epoch + epoch + 1,
                                       cfg.checkpoint_postfix)))
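
Examples #4 and #5 instantiate a berHu() loss whose definition is not shown. berHu is the reverse Huber loss often used for depth regression (L1 for small residuals, scaled L2 for large ones). The module below is a standard formulation of that loss, not necessarily the author's exact implementation.

import torch
import torch.nn as nn

class berHu(nn.Module):
    """Reverse Huber (berHu) loss, a common objective for depth regression."""

    def forward(self, pred, target):
        diff = torch.abs(pred - target)
        # the threshold c is commonly set to one fifth of the largest absolute error in the batch
        c = 0.2 * diff.max().detach()
        # L1 below the threshold, (d^2 + c^2) / (2c) above it; the two branches meet at |d| == c
        l2_part = (diff ** 2 + c ** 2) / (2 * c + 1e-8)
        return torch.where(diff <= c, diff, l2_part).mean()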
Example #5
def main():
    # load data
    train_loader = torch.utils.data.DataLoader(NYUDepthDataset(
        cfg.trainval_data_root,
        'train',
        sample_num=cfg.sample_num,
        superpixel=False,
        relative=True,
        transform=True),
                                               batch_size=cfg.batch_size,
                                               shuffle=True,
                                               num_workers=cfg.num_workers,
                                               drop_last=True)
    print('Train Batches:', len(train_loader))

    # val_loader = torch.utils.data.DataLoader(NYUDepthDataset(cfg.trainval_data_root, 'val', transform=True),
    #                                          batch_size=cfg.batch_size, shuffle=True,
    #                                          num_workers=cfg.num_workers, drop_last=True)
    # print('Validation Batches:', len(val_loader))

    test_set = NyuDepthMat(
        cfg.test_data_root,
        '/home/ans/PycharmProjects/SDFCN/data/testIdxs.txt')
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=cfg.batch_size,
                                              shuffle=True,
                                              drop_last=True)

    # train_set = NyuDepthMat(cfg.test_data_root, '/home/ans/PycharmProjects/SDFCN/data/trainIdxs.txt')
    # train_loader = torch.utils.data.DataLoader(train_set,
    #                                           batch_size=cfg.batch_size,
    #                                           shuffle=True, drop_last=True)
    # train_loader = test_loader
    #
    val_loader = test_loader
    # load model and weight
    # model = FCRN(cfg.batch_size)
    model = DUCNet(model=torchvision.models.resnet50(pretrained=True))
    init_upsample = False
    # print(model)

    # loss_fn = berHu()

    if cfg.use_gpu:
        print('Use CUDA')
        model = model.cuda()
        berhu_loss = berHu().cuda()
        rela_loss = relativeloss().cuda()
        loss_fn = torch.nn.MSELoss().cuda()
    else:
        exit(0)

    start_epoch = 0
    # resume_from_file = False
    best_val_err = 10e3

    vis = Visualizer(cfg.env)
    print('Created visdom environment:', cfg.env)
    # 4.Optim
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    print("optimizer set.")
    scheduler = lr_scheduler.StepLR(optimizer, step_size=cfg.step, gamma=0.1)

    for epoch in range(cfg.num_epochs):

        scheduler.step()
        print('Starting train epoch %d / %d, lr=%f' %
              (start_epoch + epoch + 1, cfg.num_epochs,
               optimizer.state_dict()['param_groups'][0]['lr']))

        model.train()
        running_loss = 0
        count = 0
        epoch_loss = 0

        for i_batch, sample_batched in enumerate(train_loader):
            input_var = Variable(sample_batched['rgb'].type(dtype))
            depth_var = Variable(sample_batched['depth'].type(dtype))

            optimizer.zero_grad()
            output = model(input_var)
            # loss = loss_fn(output, depth_var)
            loss1 = loss_fn(output, depth_var)
            Ah, Aw, Bh, Bw = generate_relative_pos(sample_batched['center'])

            loss2 = rela_loss(output[..., 0, Ah, Aw], output[..., 0, Bh, Bw],
                              sample_batched['ord'])
            loss = loss1 + loss2

            if i_batch % cfg.print_freq == cfg.print_freq - 1:
                print('{0} batches, loss:{1}, berhu:{2}, relative:{3}'.format(
                    i_batch + 1,
                    loss.data.cpu().item(),
                    loss1.data.cpu().item(),
                    loss2.data.cpu().item()))
                vis.plot('loss', loss.data.cpu().item())

            if i_batch % (cfg.print_freq * 10) == (cfg.print_freq * 10) - 1:
                vis.depth('pred', output)
                # vis.imshow('img', sample_batched['rgb'].type(dtype))
                vis.depth('depth', sample_batched['depth'].type(dtype))

            count += 1
            running_loss += loss.data.cpu().numpy()

            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / count
        print('epoch loss:', epoch_loss)

        val_error, val_rmse = validate(val_loader, model, loss_fn, vis=vis)
        vis.plot('val_error', val_error)
        vis.plot('val_rmse', val_rmse)
        vis.log('epoch:{epoch},lr={lr},epoch_loss:{loss},val_error:{val_error}'.
                format(epoch=start_epoch + epoch + 1,
                       loss=epoch_loss,
                       val_error=val_error,
                       lr=optimizer.state_dict()['param_groups'][0]['lr']))

        if val_error < best_val_err:
            best_val_err = val_error
            if not os.path.exists(cfg.checkpoint_dir):
                os.mkdir(cfg.checkpoint_dir)

            torch.save(
                {
                    'epoch': start_epoch + epoch + 1,
                    'state_dict': model.state_dict(),
                    # 'optimizer': optimizer.state_dict(),
                },
                os.path.join(
                    cfg.checkpoint_dir,
                    '{}_{}_epoch_{}_{}'.format(cfg.checkpoint, cfg.env,
                                               start_epoch + epoch + 1,
                                               cfg.checkpoint_postfix)))

    torch.save(
        {
            'epoch': start_epoch + epoch + 1,
            'state_dict': model.state_dict(),
            # 'optimizer': optimizer.state_dict(),
        },
        os.path.join(
            cfg.checkpoint_dir,
            '{}_{}_epoch_{}_{}'.format(cfg.checkpoint, cfg.env,
                                       start_epoch + epoch + 1,
                                       cfg.checkpoint_postfix)))
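
Example #5 adds a pairwise term: relativeloss() is applied to pairs of predicted depths selected by generate_relative_pos(), with an ordinal label from sample_batched['ord']. Neither helper is defined in this section. The sketch below shows a typical ranking-style relative-depth loss in the spirit of Chen et al., "Single-Image Depth Perception in the Wild"; the signature mirrors the call above, while the internals are an assumption.

import torch
import torch.nn as nn

class relativeloss(nn.Module):
    """Hedged sketch of a pairwise relative-depth (ordinal) loss."""

    def forward(self, depth_a, depth_b, ord_label):
        # ord_label: +1 if point A should be farther than B, -1 if closer, 0 if roughly equal
        ord_label = ord_label.to(depth_a.device).float()
        diff = depth_a - depth_b
        # ranking loss for ordered pairs, squared difference for "equal" pairs
        ranking = torch.log(1 + torch.exp(-ord_label * diff))
        return torch.where(ord_label == 0, diff ** 2, ranking).mean()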
        raise RuntimeError("Cannot find launch file {}".format(launch_file))

    input_arguments = {}
    for argument in args.launch_info[1::]:
        if ":=" not in argument:
            raise RuntimeError(
                "Input arguments must follow format 'ARG:VALUE'; got '{}']".
                format(argument))
        arg, value = argument.split(":=")
        input_arguments[arg] = value

    if args.verbose:
        level = logging.DEBUG
    elif args.quiet:
        level = logging.WARNING
    else:
        level = logging.INFO
    logging.basicConfig(level=level)
    logger = logging.getLogger(__name__)

    # parse launch file
    logger.info("Analyzing {} with arguments {}".format(
        launch_file, input_arguments))
    graph = utils.parser.build_graph(launch_file, input_arguments,
                                     args.verbose)

    # construct visualizer and plot
    visualizer = Visualizer(launch_file, graph)
    if not args.noplot:
        visualizer.plot()
Example #7
def train(opt):
    # update the configuration
    vis = Visualizer(opt.env)

    # step1: load the model
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu: model.cuda()

    # step2: data
    train_data = DogCat(opt.train_data_root, train=True)
    val_data = DogCat(opt.train_data_root, train=False)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    # no augmentation transforms on the validation data
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers)

    # step3: loss function and optimizer
    # cross-entropy loss
    criterion = t.nn.CrossEntropyLoss()
    # learning rate
    lr = opt.lr
    # Adam optimizer
    optimizer = t.optim.Adam(model.parameters(),
                             lr=lr,
                             weight_decay=opt.weight_decay)

    # step4: metrics: the smoothed loss and a confusion matrix
    # computes the mean and standard deviation of all values; used to track the average loss over an epoch
    loss_meter = meter.AverageValueMeter()
    # tracks the per-class classification results, i.e. the confusion matrix
    confusion_matrix = meter.ConfusionMeter(2)
    previous_loss = 1e100

    # training
    for epoch in range(opt.max_epoch):

        loss_meter.reset()
        confusion_matrix.reset()

        # index,(data, label)
        for ii, (data, label) in enumerate(train_dataloader):

            # train on this batch
            input = data
            target = label
            if opt.use_gpu:
                input = input.cuda()
                target = target.cuda()
            # zero the gradients
            optimizer.zero_grad()
            score = model(input)
            # loss
            loss = criterion(score, target)
            loss.backward()
            # optimizer step
            optimizer.step()

            # update the metrics and visualize
            loss_meter.add(loss.item())
            confusion_matrix.add(score.data, target.data)

            if ii % opt.print_freq == opt.print_freq - 1:
                vis.plot('loss', loss_meter.value()[0])

        # checkpoint
        model.save()

        # compute and visualize metrics on the validation set
        val_cm, val_accuracy = val(model, val_dataloader)
        vis.plot('val_accuracy', val_accuracy)
        vis.log(
            "epoch:{epoch},lr:{lr},loss:{loss},train_cm:{train_cm},val_cm:{val_cm}"
            .format(epoch=epoch,
                    loss=loss_meter.value()[0],
                    val_cm=str(val_cm.value()),
                    train_cm=str(confusion_matrix.value()),
                    lr=lr))

        # if the loss no longer decreases, lower the learning rate
        if loss_meter.value()[0] > previous_loss:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        previous_loss = loss_meter.value()[0]
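
Examples #1 and #7 also rely on a module-level opt configuration object whose attributes (env, model, lr, print_freq, ...) can be overridden from the command line via opt.parse(kwargs). A minimal sketch of that convention follows; the default values are placeholders and the attribute list is inferred from the fields used above.

class DefaultConfig:
    """Hedged sketch of the config object assumed by Examples #1 and #7."""
    env = 'default'                  # visdom environment
    model = 'ResNet34'               # name of a class defined in models
    train_data_root = './data/train'
    load_model_path = None           # path to a pretrained model, None to skip loading
    batch_size = 32
    use_gpu = True
    num_workers = 4
    print_freq = 20                  # plot the loss every print_freq batches
    max_epoch = 10
    lr = 0.001
    lr_decay = 0.95                  # lr = lr * lr_decay when the loss stops decreasing
    weight_decay = 1e-4

    def parse(self, kwargs):
        # overwrite the defaults with command-line keyword arguments
        for k, v in kwargs.items():
            if not hasattr(self, k):
                raise AttributeError("opt has no attribute {}".format(k))
            setattr(self, k, v)

opt = DefaultConfig()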