Example #1
    def run(self, trial=None):
        params = self.params
        is_ps = params["params_search"]
        term_size = shutil.get_terminal_size().columns

        # Show Hyper Params
        if trial is not None:
            sequence_length = params["sequence_length"]
            hyper_params = get_hyper_params(trial, sequence_length)
            self.params.update(hyper_params)
            if trial.number > 0:
                print("\n")
            out_str = " Trial: {} ".format(trial.number + 1)
            print(out_str.center(term_size, "="))
            print("\n" + " Current Hyper Params ".center(term_size, "-"))
            print(sorted(hyper_params.items()))
            print("-" * term_size + "\n")

        # Generate Batch Iterators
        train_loader = data.Iterator(
            self.train,
            batch_size=params["batch_size"],
            device=params["device"],
            train=True,
        )

        valid_loader = data.Iterator(
            self.valid,
            batch_size=params["batch_size"],
            device=params["device"],
            train=False,
            sort=False,
        )

        if not is_ps:
            test_loader = data.Iterator(
                self.test,
                batch_size=params["batch_size"],
                device=params["device"],
                train=False,
                sort=False,
            )

        # Calc Number of Batches per Split
        params["train_batch_total"] = math.ceil(
            len(self.train) / params["batch_size"])

        params["valid_batch_total"] = math.ceil(
            len(self.valid) / params["batch_size"])

        if not is_ps:
            params["test_batch_total"] = math.ceil(
                len(self.test) / params["batch_size"])

        # Define xml-cnn model
        model = xml_cnn(params, self.TEXT.vocab.vectors)
        model = model.to(params["device"])
        epochs = params["epochs"]
        learning_rate = params["learning_rate"]

        # Define Optimizer
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        if not is_ps:
            ms = [int(epochs * 0.5), int(epochs * 0.75)]
            scheduler = MultiStepLR(optimizer, milestones=ms, gamma=0.1)

        best_epoch = 1
        num_of_unchanged = 1

        measure = params["measure"]
        measure = measure[:-3] if "f1" in measure else measure
        if not is_ps:
            save_best_model_path = params["model_cache_path"] + "best_model.pkl"
        # Training loop
        for epoch in range(1, epochs + 1):
            if is_ps:
                out_str = " Epoch: {} ".format(epoch)
            else:
                lr = scheduler.get_last_lr()[0]
                term_size = shutil.get_terminal_size().columns
                out_str = " Epoch: {} (lr={:.20f}) ".format(epoch, lr)
            print(out_str.center(term_size, "-"))

            # Train for one epoch
            training(params, model, train_loader, optimizer)

            # Validate
            val_measure_epoch_i = validating_testing(params, model,
                                                     valid_loader)

            # Record and save the best model
            if epoch < 2:
                best_val_measure = val_measure_epoch_i
                if not is_ps:
                    torch.save(model, save_best_model_path)
            elif best_val_measure < val_measure_epoch_i:
                best_epoch = epoch
                best_val_measure = val_measure_epoch_i
                num_of_unchanged = 1
                if not is_ps:
                    torch.save(model, save_best_model_path)
            else:
                num_of_unchanged += 1

            # Show Best Epoch
            out_str = " Best Epoch: {} (" + measure + ": {:.10f}, "
            out_str = out_str.format(best_epoch, best_val_measure)
            if params["early_stopping"]:
                remaining = params["early_stopping"] - num_of_unchanged
                out_str += "ES Remaining: {}) "
                out_str = out_str.format(remaining)
            else:
                out_str += "ES: False) "
            print("\n" + out_str.center(term_size, "-") + "\n")

            # Early Stopping
            if early_stopping(num_of_unchanged, params["early_stopping"]):
                break

            if not is_ps:
                scheduler.step()

        if is_ps:
            # Show Best Trials
            if self.best_trial_measure < best_val_measure:
                self.best_trial_measure = best_val_measure
                self.num_of_trial = trial.number + 1
            out_str = " Best Trial: {} (" + measure + ": {:.20f}) "
            out_str = out_str.format(self.num_of_trial,
                                     self.best_trial_measure)
            print(out_str.center(term_size, "="))
        else:
            # Testing on Best Epoch Model
            model = torch.load(save_best_model_path)
            test_measure = validating_testing(params,
                                              model,
                                              test_loader,
                                              is_valid=False)
            out_str = " Finished "
            print("\n\n" + out_str.center(term_size, "=") + "\n")

            out_str = " Best Epoch: {} (" + measure + ": {:.20f}) "
            out_str = out_str.format(best_epoch, test_measure)
            print("\n" + out_str.center(term_size, "-") + "\n")

        return 1 - best_val_measure
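Note on the pattern above: the milestones sit at 50% and 75% of the epoch budget, and the rate is cut by gamma=0.1 at each one. A minimal, self-contained sketch (stand-in model and a hypothetical epoch count) showing how MultiStepLR applies those cuts:

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR

model = nn.Linear(4, 2)                        # stand-in model
epochs = 8                                     # hypothetical epoch budget
optimizer = optim.Adam(model.parameters(), lr=0.01)
ms = [int(epochs * 0.5), int(epochs * 0.75)]   # milestones at epochs 4 and 6
scheduler = MultiStepLR(optimizer, milestones=ms, gamma=0.1)

for epoch in range(1, epochs + 1):
    optimizer.step()                 # the real training step would go here
    scheduler.step()                 # decay after the optimizer has stepped
    print(epoch, scheduler.get_last_lr())  # [0.01] -> [0.001] -> [0.0001]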
Example #2
def train_model_multistage_lowlight():

    device = DEVICE
    # Prepare the data
    train_set = HsiCubicTrainDataset('./data/train_lowlight_patchsize32/')
    #print('trainset32 training example:', len(train_set32))

    #train_set_64 = HsiCubicTrainDataset('./data/train_lowlight_patchsize64/')

    #train_set_list = [train_set32, train_set_64]
    #train_set = ConcatDataset(train_set_list)  # the sample sizes must all match, otherwise concatenation fails
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # Load the test label data
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # Load the test data
    batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # Create the model
    net = MultiStageHSID(K)
    init_params(net)
    #net = nn.DataParallel(net).to(device)
    net = net.to(device)

    # Create the optimizer
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0.999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)

    # Define the loss function
    #criterion = nn.MSELoss()

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    best_psnr = 0
    best_epoch = 0
    best_iter = 0
    start_epoch = 1
    num_epoch = 100

    for epoch in range(start_epoch, num_epoch + 1):
        epoch_start_time = time.time()
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            #print('batch_idx=', batch_idx)
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            residual = net(noisy, cubic)
            #loss = loss_fuction(residual, label-noisy)
            # Sum the per-stage losses with the builtin sum so the result
            # stays a torch tensor and supports backward(); np.sum would not
            loss = sum(
                loss_fuction(residual[j], label) for j in range(len(residual))
            )
            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # step++: each loop iteration (one processed batch) counts as one step
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        # Step the scheduler once per epoch, after the optimizer updates
        # (calling it before any optimizer.step() skips the first lr value)
        scheduler.step()

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/hsid_multistage_patchsize64_{epoch}.pth")

        # Evaluate on the test set
        net.eval()
        for batch_idx, (noisy_test, cubic_test,
                        label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():

                residual = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual[0]

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

                if batch_idx == 49:
                    residual_squeezed = torch.squeeze(residual[0], axis=0)
                    denoised_band_squeezed = torch.squeeze(denoised_band,
                                                           axis=0)
                    label_test_squeezed = torch.squeeze(label_test, axis=0)
                    noisy_test_squeezed = torch.squeeze(noisy_test, axis=0)
                    tb_writer.add_image(f"images/{epoch}_restored",
                                        denoised_band_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_residual",
                                        residual_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_label",
                                        label_test_squeezed,
                                        1,
                                        dataformats='CHW')
                    tb_writer.add_image(f"images/{epoch}_noisy",
                                        noisy_test_squeezed,
                                        1,
                                        dataformats='CHW')

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        # Compute PSNR, SSIM and SAM
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'average SAM': sam
        }, epoch)  # makes it easy to see which epoch performs best

        # Save the best model
        if psnr > best_psnr:
            best_psnr = psnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'epoch': epoch,
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                }, f"checkpoints/hsid_multistage_patchsize64_best.pth")

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

        print(
            "------------------------------------------------------------------"
        )
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".
              format(epoch,
                     time.time() - epoch_start_time, gen_epoch_loss,
                     scheduler.get_lr()[0]))
        print(
            "------------------------------------------------------------------"
        )

        # Save the latest model
        torch.save(
            {
                'epoch': epoch,
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict()
            }, os.path.join('./checkpoints', "model_latest.pth"))
    tb_writer.close()
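As corrected above: since PyTorch 1.1, scheduler.step() belongs after the epoch's optimizer updates; calling it at the top of the loop consumes the first learning-rate value before any training happens and triggers a runtime warning. A minimal sketch of the warning-free per-epoch pattern, using a dummy batch:

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR

net = nn.Linear(8, 1)                            # stand-in network
opt = optim.Adam(net.parameters(), lr=1e-3)
sched = MultiStepLR(opt, milestones=[40, 60, 80], gamma=0.1)

for epoch in range(100):
    for x, y in [(torch.randn(2, 8), torch.randn(2, 1))]:  # dummy loader
        opt.zero_grad()
        loss = nn.functional.mse_loss(net(x), y)
        loss.backward()
        opt.step()                               # optimizer first ...
    sched.step()                                 # ... scheduler once per epoch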
Example #3
class RunnerFSL(object):
    def __init__(self, data_train, classes):
        # all data
        self.data_train = data_train
        self.task_train = FSLDataset(self.data_train, classes,
                                     Config.fsl_num_way, Config.fsl_num_shot)
        self.task_train_loader = DataLoader(self.task_train,
                                            Config.fsl_batch_size,
                                            True,
                                            num_workers=Config.num_workers)

        # model
        self.matching_net = RunnerTool.to_cuda(Config.fsl_matching_net)
        self.matching_net = RunnerTool.to_cuda(
            nn.DataParallel(self.matching_net))
        cudnn.benchmark = True
        RunnerTool.to_cuda(self.matching_net.apply(RunnerTool.weights_init))
        self.norm = Normalize(2)

        # optim
        self.matching_net_optim = torch.optim.Adam(
            self.matching_net.parameters(), lr=Config.fsl_learning_rate)
        self.matching_net_scheduler = MultiStepLR(self.matching_net_optim,
                                                  Config.fsl_lr_schedule,
                                                  gamma=1 / 3)

        # loss
        self.fsl_loss = RunnerTool.to_cuda(nn.MSELoss())

        # Eval
        self.test_tool_fsl = FSLTestTool(
            self.matching_test,
            data_root=Config.data_root,
            num_way=Config.fsl_num_way,
            num_shot=Config.fsl_num_shot,
            episode_size=Config.fsl_episode_size,
            test_episode=Config.fsl_test_episode,
            transform=self.task_train.transform_test)
        pass

    def load_model(self):
        if os.path.exists(Config.mn_dir):
            self.matching_net.load_state_dict(torch.load(Config.mn_dir))
            Tools.print("load matching net success from {}".format(
                Config.mn_dir))
        pass

    def matching(self, task_data):
        data_batch_size, data_image_num, data_num_channel, data_width, data_weight = task_data.shape
        data_x = task_data.view(-1, data_num_channel, data_width, data_weight)
        net_out = self.matching_net(data_x)
        z = net_out.view(data_batch_size, data_image_num, -1)

        # Features
        z_support, z_query = z.split(Config.fsl_num_shot * Config.fsl_num_way,
                                     dim=1)
        z_batch_size, z_num, z_dim = z_support.shape
        z_support = z_support.view(z_batch_size,
                                   Config.fsl_num_way * Config.fsl_num_shot,
                                   z_dim)
        z_query_expand = z_query.expand(
            z_batch_size, Config.fsl_num_way * Config.fsl_num_shot, z_dim)

        # Similarity
        z_support = self.norm(z_support)
        similarities = torch.sum(z_support * z_query_expand, -1)
        similarities = torch.softmax(similarities, dim=1)
        similarities = similarities.view(z_batch_size, Config.fsl_num_way,
                                         Config.fsl_num_shot)
        predicts = torch.mean(similarities, dim=-1)
        return predicts

    def matching_test(self, samples, batches):
        batch_num, _, _, _ = batches.shape

        sample_z = self.matching_net(samples)  # 5x64*5*5
        batch_z = self.matching_net(batches)  # 75x64*5*5
        z_support = sample_z.view(Config.fsl_num_way * Config.fsl_num_shot, -1)
        z_query = batch_z.view(batch_num, -1)
        _, z_dim = z_query.shape

        z_support_expand = z_support.unsqueeze(0).expand(
            batch_num, Config.fsl_num_way * Config.fsl_num_shot, z_dim)
        z_query_expand = z_query.unsqueeze(1).expand(
            batch_num, Config.fsl_num_way * Config.fsl_num_shot, z_dim)

        # Similarity
        z_support_expand = self.norm(z_support_expand)
        similarities = torch.sum(z_support_expand * z_query_expand, -1)
        similarities = torch.softmax(similarities, dim=1)
        similarities = similarities.view(batch_num, Config.fsl_num_way,
                                         Config.fsl_num_shot)
        predicts = torch.mean(similarities, dim=-1)
        return predicts

    def train(self):
        Tools.print()
        Tools.print("Training...")
        best_accuracy = 0.0

        for epoch in range(1, 1 + Config.fsl_train_epoch):
            self.matching_net.train()

            Tools.print()
            all_loss, is_ok_total, is_ok_acc = 0.0, 0, 0
            for task_data, task_labels, task_index, task_ok in tqdm(
                    self.task_train_loader):
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                ###########################################################################
                # 1 calculate features
                relations = self.matching(task_data)

                # 2 loss
                loss = self.fsl_loss(relations, task_labels)
                all_loss += loss.item()

                # 3 backward
                self.matching_net.zero_grad()
                loss.backward()
                self.matching_net_optim.step()

                # count correct predictions / total episode labels
                is_ok_acc += torch.sum(torch.cat(task_ok))
                is_ok_total += torch.cat(task_ok).numel()
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print("{:6} loss:{:.3f} ok:{:.3f}({}/{}) lr:{}".format(
                epoch, all_loss / len(self.task_train_loader),
                int(is_ok_acc) / int(is_ok_total), is_ok_acc, is_ok_total,
                self.matching_net_scheduler.get_last_lr()))
            self.matching_net_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val
            if epoch % Config.fsl_val_freq == 0:
                self.matching_net.eval()

                val_accuracy = self.test_tool_fsl.val(episode=epoch,
                                                      is_print=True,
                                                      has_test=False)

                if val_accuracy > best_accuracy:
                    best_accuracy = val_accuracy
                    torch.save(self.matching_net.state_dict(), Config.mn_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch))
                    pass
                pass
            ###########################################################################
            pass

        pass

    pass
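Both runner classes print self.matching_net_scheduler.get_last_lr() as-is; it returns a list with one learning rate per parameter group, which is why the log shows brackets. A small sketch (hypothetical two-group setup) making that visible:

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR

net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
opt = optim.Adam([
    {"params": net[0].parameters(), "lr": 1e-3},  # group 0
    {"params": net[1].parameters(), "lr": 1e-4},  # group 1
])
sched = MultiStepLR(opt, milestones=[2], gamma=1 / 3)

for epoch in range(3):
    opt.step()
    sched.step()
    print(sched.get_last_lr())  # one entry per group; both cut by 1/3 at epoch 2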
Example #4
def main_train(args):
    # Parse command-line arguments
    if args.resume_training is not None:
        if not os.path.isfile(args.resume_training):
            print(f"{args.resume_training} is not a valid file!")
            return
        else:
            print(f"Loading checkpoint: {args.resume_training}")
    cuda = args.cuda
    resume = args.resume_training
    batch_size = args.batch_size
    milestones = args.milestones
    lr = args.lr
    total_epoch = args.epochs
    resume_checkpoint_filename = args.resume_training
    best_model_name = args.best_model_name
    checkpoint_name = args.best_model_name
    data_path = args.data_path
    start_epoch = 1

    print("加载数据....")
    dataset = ISONetData(data_path=data_path)
    dataset_test = ISONetData(data_path=data_path, train=False)
    data_loader = DataLoader(dataset=dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=6,
                             pin_memory=True)
    data_loader_test = DataLoader(dataset=dataset_test,
                                  batch_size=batch_size,
                                  shuffle=False)
    print("成功加载数据...")
    print(f"训练集数量: {len(dataset)}")
    print(f"验证集数量: {len(dataset_test)}")

    model_path = Path("models")
    checkpoint_path = model_path.joinpath("checkpoint")

    if not model_path.exists():
        model_path.mkdir()
    if not checkpoint_path.exists():
        checkpoint_path.mkdir()

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        print("cuda 无效!")
        cuda = False

    net = ISONet()
    criterion = nn.MSELoss(reduction="mean")
    optimizer = optim.Adam(net.parameters(), lr=lr)

    if cuda:
        net = net.to(device=device)
        criterion = criterion.to(device=device)

    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=milestones,
                            gamma=0.1)
    writer = SummaryWriter()

    # Resume training
    if resume:
        print("Resuming training...")
        checkpoint = torch.load(
            checkpoint_path.joinpath(resume_checkpoint_filename))
        net.load_state_dict(checkpoint["net"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        resume_epoch = checkpoint["epoch"]
        best_test_loss = checkpoint["best_test_loss"]

        start_epoch = resume_epoch + 1
        print(f"Resuming from epoch [{start_epoch}]...")
        print(f"Previous best loss: [{best_test_loss}]...")
    else:
        # Initialize weights
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
        # No checkpoint was loaded, so there is no best loss yet
        best_test_loss = 0

    record = 0
    for epoch in range(start_epoch, total_epoch):
        print(f"开始第 [{epoch}] 轮训练...")
        net.train()
        writer.add_scalar("Train/Learning Rate",
                          scheduler.get_last_lr()[0], epoch)
        for i, (data, label) in enumerate(data_loader, 0):
            if i == 0:
                start_time = int(time.time())
            if cuda:
                data = data.to(device=device)
                label = label.to(device=device)
            label = label.unsqueeze(1)

            optimizer.zero_grad()

            output = net(data)

            loss = criterion(output, label)

            loss.backward()

            optimizer.step()
            if i % 500 == 499:
                end_time = int(time.time())
                use_time = end_time - start_time

                print(
                    f">>> epoch[{epoch}] loss[{loss:.4f}]  {i * batch_size}/{len(dataset)} lr{scheduler.get_last_lr()} ",
                    end="")
                left_time = ((len(dataset) - i * batch_size) / 500 /
                             batch_size) * (end_time - start_time)
                print(
                    f"Elapsed: [{use_time}]s, estimated time remaining: [{left_time:.2f}]s"
                )
                start_time = end_time
            # Log to TensorBoard
            if i % 128 == 127:
                writer.add_scalar("Train/loss", loss, record)
                record += 1

        # validate
        print("测试模型...")
        net.eval()

        test_loss = 0
        with torch.no_grad():
            loss_t = nn.MSELoss(reduction="mean")
            if cuda:
                loss_t = loss_t.to(device)
            for data, label in data_loader_test:
                if cuda:
                    data = data.to(device)
                    label = label.to(device)
                # expand dim
                label = label.unsqueeze_(1)
                predict = net(data)
                # sum up batch loss
                test_loss += loss_t(predict, label).item()

        test_loss /= len(dataset_test)
        test_loss *= batch_size
        print(
            f'\nTest Data: Average batch[{batch_size}] loss: {test_loss:.4f}\n'
        )
        scheduler.step()

        writer.add_scalar("Test/Loss", test_loss, epoch)

        if best_test_loss == 0:
            print("Saving model...")
            torch.save(net.state_dict(), model_path.joinpath(best_model_name))
            best_test_loss = test_loss
        else:
            # Save the better model
            if test_loss < best_test_loss:
                print("Found a better model, saving...")
                torch.save(net.state_dict(),
                           model_path.joinpath(best_model_name))
                best_test_loss = test_loss

        # Build the checkpoint after best_test_loss has been updated so the
        # saved value is current rather than one epoch stale
        checkpoint = {
            "net": net.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "scheduler": scheduler.state_dict(),
            "best_test_loss": best_test_loss
        }

        # Save a checkpoint periodically
        if epoch % args.save_every_epochs == 0:
            c_time = time2str()
            torch.save(
                checkpoint,
                checkpoint_path.joinpath(
                    f"{checkpoint_name}_{epoch}_{c_time}.cpth"))
            print(f"Saved checkpoint: [{checkpoint_name}_{epoch}_{c_time}.cpth]...\n")
Example #5
class Runner(object):
    def __init__(self):
        self.best_accuracy = 0.0

        # all data
        self.data_train = CIFARDataset.get_data_all(Config.data_root)
        self.task_train = CIFARDataset(self.data_train, Config.num_way,
                                       Config.num_shot)
        self.task_train_loader = DataLoader(self.task_train,
                                            Config.batch_size,
                                            True,
                                            num_workers=Config.num_workers)

        # model
        self.matching_net = RunnerTool.to_cuda(Config.matching_net)
        RunnerTool.to_cuda(self.matching_net.apply(RunnerTool.weights_init))
        self.norm = Normalize(2)

        # loss
        self.loss = RunnerTool.to_cuda(nn.MSELoss())

        # optim
        self.matching_net_optim = torch.optim.Adam(
            self.matching_net.parameters(), lr=Config.learning_rate)
        self.matching_net_scheduler = MultiStepLR(self.matching_net_optim,
                                                  Config.train_epoch_lr,
                                                  gamma=0.5)

        self.test_tool = FSLTestTool(self.matching_test,
                                     data_root=Config.data_root,
                                     num_way=Config.num_way_test,
                                     num_shot=Config.num_shot,
                                     episode_size=Config.episode_size,
                                     test_episode=Config.test_episode,
                                     transform=self.task_train.transform_test)
        pass

    def load_model(self):
        if os.path.exists(Config.mn_dir):
            self.matching_net.load_state_dict(torch.load(Config.mn_dir))
            Tools.print("load proto net success from {}".format(Config.mn_dir),
                        txt_path=Config.log_file)
        pass

    def matching(self, task_data):
        data_batch_size, data_image_num, data_num_channel, data_width, data_weight = task_data.shape
        data_x = task_data.view(-1, data_num_channel, data_width, data_weight)
        net_out = self.matching_net(data_x)
        z = net_out.view(data_batch_size, data_image_num, -1)

        # Features
        z_support, z_query = z.split(Config.num_shot * Config.num_way, dim=1)
        z_batch_size, z_num, z_dim = z_support.shape
        z_support = z_support.view(z_batch_size,
                                   Config.num_way * Config.num_shot, z_dim)
        z_query_expand = z_query.expand(z_batch_size,
                                        Config.num_way * Config.num_shot,
                                        z_dim)

        # Similarity
        z_support = self.norm(z_support)
        similarities = torch.sum(z_support * z_query_expand, -1)
        similarities = torch.softmax(similarities, dim=1)
        similarities = similarities.view(z_batch_size, Config.num_way,
                                         Config.num_shot)
        predicts = torch.mean(similarities, dim=-1)
        return predicts

    def matching_test(self, samples, batches):
        batch_num, _, _, _ = batches.shape

        sample_z = self.matching_net(samples)  # 5x64*5*5
        batch_z = self.matching_net(batches)  # 75x64*5*5
        z_support = sample_z.view(Config.num_way_test * Config.num_shot, -1)
        z_query = batch_z.view(batch_num, -1)
        _, z_dim = z_query.shape

        z_support_expand = z_support.unsqueeze(0).expand(
            batch_num, Config.num_way_test * Config.num_shot, z_dim)
        z_query_expand = z_query.unsqueeze(1).expand(
            batch_num, Config.num_way_test * Config.num_shot, z_dim)

        # Similarity
        z_support_expand = self.norm(z_support_expand)
        similarities = torch.sum(z_support_expand * z_query_expand, -1)
        similarities = torch.softmax(similarities, dim=1)
        similarities = similarities.view(batch_num, Config.num_way_test,
                                         Config.num_shot)
        predicts = torch.mean(similarities, dim=-1)
        return predicts

    def train(self):
        Tools.print()
        Tools.print("Training...", txt_path=Config.log_file)

        for epoch in range(1, 1 + Config.train_epoch):
            self.matching_net.train()

            Tools.print()
            all_loss = 0.0
            for task_data, task_labels, task_index in tqdm(
                    self.task_train_loader):
                task_data, task_labels = RunnerTool.to_cuda(
                    task_data), RunnerTool.to_cuda(task_labels)

                # 1 calculate features
                predicts = self.matching(task_data)

                # 2 loss
                loss = self.loss(predicts, task_labels)
                all_loss += loss.item()

                # 3 backward
                self.matching_net.zero_grad()
                loss.backward()
                self.matching_net_optim.step()
                ###########################################################################
                pass

            ###########################################################################
            # print
            Tools.print("{:6} loss:{:.3f} lr:{}".format(
                epoch, all_loss / len(self.task_train_loader),
                self.matching_net_scheduler.get_last_lr()),
                        txt_path=Config.log_file)

            self.matching_net_scheduler.step()
            ###########################################################################

            ###########################################################################
            # Val
            if epoch % Config.val_freq == 0:
                Tools.print()
                Tools.print("Test {} {} .......".format(
                    epoch, Config.model_name),
                            txt_path=Config.log_file)
                self.matching_net.eval()

                val_accuracy = self.test_tool.val(episode=epoch,
                                                  is_print=True,
                                                  has_test=False)
                if val_accuracy > self.best_accuracy:
                    self.best_accuracy = val_accuracy
                    torch.save(self.matching_net.state_dict(), Config.mn_dir)
                    Tools.print("Save networks for epoch: {}".format(epoch),
                                txt_path=Config.log_file)
                    pass
                pass
            ###########################################################################
            pass

        pass

    pass
Example #6
class Trainer:
    def __init__(self, cfg):
        self.cfg = cfg
        self.paths = cfg['paths']
        self.net_params = cfg['net']
        self.train_params = cfg['train']
        self.trans_params = cfg['train']['transforms']

        self.checkpoints = self.paths['checkpoints']
        Path(self.checkpoints).mkdir(parents=True, exist_ok=True)
        shutil.copyfile('config.yaml', f'{self.checkpoints}/config.yaml')

        self.update_interval = self.paths['update_interval']

        # amp training; with enabled=False the scaler and autocast become
        # no-ops, so the same training loop works without mixed precision
        self.use_amp = self.train_params['mixed_precision']
        self.scaler = GradScaler(enabled=self.use_amp)

        # data setup
        dataset_name = self.train_params['dataset']
        self.use_multi = dataset_name == 'multi'
        print(f'Using dataset: {dataset_name}')
        self.train_dataset = get_pedestrian_dataset(
            dataset_name,
            self.paths,
            augment=get_train_transforms(self.trans_params),
            mode='train',
            multi_datasets=self.train_params['multi_datasets']
            if self.use_multi else None)
        print(f'Train dataset: {len(self.train_dataset)} samples')

        self.val_dataset = get_pedestrian_dataset(
            dataset_name,
            self.paths,
            augment=get_val_transforms(self.trans_params),
            mode='val',
            multi_datasets=self.train_params['multi_datasets']
            if self.use_multi else None)
        print(f'Val dataset: {len(self.val_dataset)} samples')

        tests_data = self.train_params['test_datasets']
        self.test_datasets = [
            get_pedestrian_dataset(d_name,
                                   self.paths,
                                   augment=get_test_transforms(
                                       self.trans_params),
                                   mode='test') for d_name in tests_data
        ]

        self.criterion = AnchorFreeLoss(self.train_params)

        self.writer = Writer(self.paths['log_dir'])
        print('Tensorboard logs are saved to: {}'.format(
            self.paths['log_dir']))

        self.sched_type = self.train_params['scheduler']
        self.scheduler = None
        self.optimizer = None

    def save_checkpoints(self, epoch, net):
        path = osp.join(self.checkpoints, f'Epoch_{epoch}.pth')
        torch.save(
            {
                'epoch': epoch,
                'net_state_dict': net.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()
            }, path)

    def train(self):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

        batch_size = self.train_params['batch_size']
        self.batch_size = batch_size
        num_workers = self.train_params['num_workers']
        pin_memory = self.train_params['pin_memory']
        print('Batch-size = {}'.format(batch_size))

        train_loader = DataLoader(self.train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=pin_memory,
                                  drop_last=True)
        val_loader = DataLoader(self.val_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                drop_last=False)

        # net setup
        print('Preparing net: ')
        net = get_fpn_net(self.net_params)
        # train setup
        lr = self.train_params['lr']
        epochs = self.train_params['epochs']
        weight_decay = self.train_params['weight_decay']

        self.optimizer = optim.Adam(net.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    eps=1e-4)
        if self.net_params['pretrained']:
            checkpoint = torch.load(self.net_params['pretrained_model'],
                                    map_location="cuda")
            net.load_state_dict(checkpoint['net_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            for p in self.optimizer.param_groups:
                p['lr'] = lr
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()
            print('CHECKPOINT LOADED')
        net.cuda()

        first_epoch = 0
        # scheduler
        if self.sched_type == 'ocp':
            last_epoch = -1 if first_epoch == 0 else first_epoch * len(
                train_loader)
            self.scheduler = OneCycleLR(
                self.optimizer,
                max_lr=lr,
                epochs=epochs,
                last_epoch=last_epoch,
                steps_per_epoch=len(train_loader),
                pct_start=self.train_params['ocp_params']['max_lr_pct'])
        elif self.sched_type == 'multi_step':
            last_epoch = -1 if first_epoch == 0 else first_epoch
            self.scheduler = MultiStepLR(
                self.optimizer,
                milestones=self.train_params['multi_params']['milestones'],
                gamma=self.train_params['multi_params']['gamma'],
                last_epoch=last_epoch)

        # start training

        net.train()
        val_rate = self.train_params['val_rate']
        test_rate = self.train_params['test_rate']
        for epoch in range(first_epoch, epochs):
            self.train_epoch(net, train_loader, epoch)

            if self.sched_type != 'ocp':
                self.writer.log_lr(epoch, self.scheduler.get_last_lr()[0])
                self.scheduler.step()

            if (epoch + 1) % val_rate == 0 or epoch == epochs - 1:
                self.eval(net, val_loader, epoch * len(train_loader))
            if (epoch + 1) % (val_rate *
                              test_rate) == 0 or epoch == epochs - 1:
                self.test_ap(net, epoch)
                self.save_checkpoints(epoch, net)

    def train_epoch(self, net, loader, epoch):
        net.train()
        loss_metric = LossMetric(self.cfg)
        probs = ProbsAverageMeter()

        for mini_batch_i, read_mini_batch in tqdm(enumerate(loader),
                                                  desc=f'Epoch {epoch}:',
                                                  ascii=True,
                                                  total=len(loader)):
            data, labels = read_mini_batch
            data = data.cuda()
            labels = [label.cuda() for label in labels]

            with amp.autocast(enabled=self.use_amp):
                out = net(data)
                loss_dict, hm_probs = self.criterion(out, labels)
                loss = loss_metric.calculate_loss(loss_dict)
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            probs.update(hm_probs)

            if self.sched_type == 'ocp':
                self.scheduler.step()

            loss_metric.add_sample(loss_dict)

            if mini_batch_i % self.update_interval == 0:
                if self.sched_type == 'ocp':
                    # TODO write average lr
                    self.writer.log_lr(epoch * len(loader) + mini_batch_i,
                                       self.scheduler.get_last_lr()[0])
                self.writer.log_training(epoch * len(loader) + mini_batch_i,
                                         loss_metric)
        self.writer.log_probs(epoch, probs.get_average())

    def eval(self, net, loader, step):
        net.eval()
        loss_metric = LossMetric(self.cfg)
        with torch.no_grad():
            for _, read_mini_batch in tqdm(enumerate(loader),
                                           desc='Val:',
                                           ascii=True,
                                           total=len(loader)):
                data, labels = read_mini_batch
                data = data.cuda()
                labels = [label.cuda() for label in labels]
                with amp.autocast(enabled=self.use_amp):
                    out = net(data)
                    loss_dict, _ = self.criterion(out, labels)

                loss_metric.add_sample(loss_dict)

            self.writer.log_eval(step, loss_metric)

    def test_ap(self, net, epoch):
        for dataset in self.test_datasets:
            ap, _ = test(net, dataset, batch_size=self.batch_size)
            self.writer.log_ap(epoch, ap, dataset.name())
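The Trainer above illustrates the two stepping cadences side by side: OneCycleLR is stepped once per batch inside train_epoch, while MultiStepLR is stepped once per epoch in train. A stripped-down sketch of that dispatch (hypothetical sizes):

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR, OneCycleLR

sched_type = 'ocp'                  # or 'multi_step'
epochs, steps_per_epoch = 3, 10     # hypothetical sizes
net = nn.Linear(4, 1)
opt = optim.Adam(net.parameters(), lr=1e-3)

if sched_type == 'ocp':
    sched = OneCycleLR(opt, max_lr=1e-3, epochs=epochs,
                       steps_per_epoch=steps_per_epoch, pct_start=0.3)
else:
    sched = MultiStepLR(opt, milestones=[2], gamma=0.1)

for epoch in range(epochs):
    for step in range(steps_per_epoch):
        opt.step()
        if sched_type == 'ocp':
            sched.step()            # per-batch for the one-cycle policy
    if sched_type != 'ocp':
        sched.step()                # per-epoch for step decay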
Example #7
def train_model(config):
    # Define hyper-parameters.
    depth = int(config["DnCNN"]["depth"])
    n_channels = int(config["DnCNN"]["n_channels"])
    img_channel = int(config["DnCNN"]["img_channel"])
    kernel_size = int(config["DnCNN"]["kernel_size"])
    use_bnorm = config.getboolean("DnCNN", "use_bnorm")
    epochs = int(config["DnCNN"]["epoch"])
    batch_size = int(config["DnCNN"]["batch_size"])
    train_data_dir = config["DnCNN"]["train_data_dir"]
    test_data_dir = config["DnCNN"]["test_data_dir"]
    eta_min = float(config["DnCNN"]["eta_min"])
    eta_max = float(config["DnCNN"]["eta_max"])
    dose = float(config["DnCNN"]["dose"])
    model_save_dir = config["DnCNN"]["model_save_dir"]
    train_mode = config["DnCNN"]["train_mode"]
    log_file_name = config["DnCNN"]["log_file_name"]

    # Save logs to txt file.
    log_dir = config["DnCNN"]["log_dir"]
    if train_mode == "residual":
        log_dir = Path(log_dir) / "res_learning_smaller_lr_dose{}".format(str(int(dose * 100)))
    elif train_mode == "direct":
        log_dir = Path(log_dir) / "direct_predict_smaller_lr_dose{}".format(str(int(dose * 100)))
    log_file = log_dir / log_file_name

    # Define device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Initiate a DnCNN instance.
    # Load the model to device and set the model to training.
    model = DnCNN(depth=depth, n_channels=n_channels,
                  img_channel=img_channel,
                  use_bnorm=use_bnorm,
                  kernel_size=kernel_size)

    model = model.to(device)
    model.train()

    # Define loss criterion and optimizer
    optimizer = optim.Adam(model.parameters(), lr=5e-5)
    scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60], gamma=0.2)
    criterion = LossFunc(reduction="mean", weight_mse=1, weight_nll=0, total_dose=dose * 100)

    # Get a validation test set and corrupt with noise for validation performance.
    # For every epoch, use this pre-determined noisy images.
    test_file_list = glob.glob(test_data_dir + "/*.png")
    xs_test = []
    # Can't directly convert the xs_test from list to ndarray because some images are 512*512
    # while the rest are 256*256.
    for i in range(len(test_file_list)):
        img = cv2.imread(test_file_list[i], 0)
        img = np.array(img, dtype="float32") / 255.0
        img = np.expand_dims(img, axis=0)
        img_noisy, _ = nm(img, eta_min, eta_max, dose, t=100)
        xs_test.append((img_noisy, img))
    
    # Get a validation train set and corrupt with noise.
    # For every epoch, use this pre-determined noisy images to see the training performance.
    train_file_list = glob.glob(train_data_dir + "/*.png")
    xs_train = []
    for i in range(len(train_file_list)):
        img = cv2.imread(train_file_list[i], 0)
        img = np.array(img, dtype="float32") / 255.0
        img = np.expand_dims(img, axis=0)
        img_noisy, _ = nm(img, eta_min, eta_max, dose, t=100)
        xs_train.append((img_noisy, img))
    
    # Train the model.
    loss_store = []
    epoch_loss_store = []
    psnr_store = []
    ssim_store = []

    psnr_tr_store = []
    ssim_tr_store = []
    for epoch in range(epochs):
        # For each epoch, generate clean augmented patches from the training directory.
        # Convert the data from uint8 to float32 then scale them to make it in [0, 1].
        # Then make the patches to be of shape [N, C, H, W],
        # where N is the batch size, C is the number of color channels.
        # H and W are height and width of image patches.
        xs = dg.datagenerator(data_dir=train_data_dir)
        xs = xs.astype("float32") / 255.0
        xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))

        train_set = dg.DenoisingDatatset(xs, eta_min, eta_max, dose)
        train_loader = DataLoader(dataset=train_set, num_workers=4,
                                  drop_last=True, batch_size=batch_size,
                                  shuffle=True)  # TODO: if drop_last=True, the dropping in the
                                                 # TODO: data_generator is not necessary?

        # train_loader_test = next(iter(train_loader))

        t_start = timer()
        epoch_loss = 0
        for idx, data in enumerate(train_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs, mode=train_mode)

            loss = criterion(outputs, labels)

            loss_store.append(loss.item())
            epoch_loss += loss.item()

            loss.backward()

            optimizer.step()

            if idx % 100 == 0:
                print("Epoch [{} / {}], step [{} / {}], loss = {:.5f}, lr = {:.8f}, elapsed time = {:.2f}s".format(
                    epoch + 1, epochs, idx, len(train_loader), loss.item(), *scheduler.get_last_lr(), timer()-t_start))

        epoch_loss_store.append(epoch_loss / len(train_loader))

        # At each epoch validate the result.
        model = model.eval()

        # First, validate on the training set. (This is slow; the timing
        # lines below are left commented out.)
        tr_psnr = []
        tr_ssim = []
        # t_start = timer()
        with torch.no_grad():
            for idx, test_data in enumerate(xs_train):
                inputs, labels = test_data
                inputs = np.expand_dims(inputs, axis=0)
                inputs = torch.from_numpy(inputs).to(device)
                labels = labels.squeeze()

                outputs = model(inputs, mode=train_mode)
                outputs = outputs.squeeze().cpu().detach().numpy()

                tr_psnr.append(peak_signal_noise_ratio(labels, outputs))
                tr_ssim.append(structural_similarity(outputs, labels))
        psnr_tr_store.append(sum(tr_psnr) / len(tr_psnr))
        ssim_tr_store.append(sum(tr_ssim) / len(tr_ssim))
        # print("Elapsed time = {}".format(timer() - t_start))

        print("Validation on train set: epoch [{} / {}], aver PSNR = {:.2f}, aver SSIM = {:.4f}".format(
            epoch + 1, epochs, psnr_tr_store[-1], ssim_tr_store[-1]))

        # Validate on test set
        val_psnr = []
        val_ssim = []
        with torch.no_grad():
            for idx, test_data in enumerate(xs_test):
                inputs, labels = test_data
                inputs = np.expand_dims(inputs, axis=0)
                inputs = torch.from_numpy(inputs).to(device)
                labels = labels.squeeze()

                outputs = model(inputs, mode=train_mode)
                outputs = outputs.squeeze().cpu().detach().numpy()

                val_psnr.append(peak_signal_noise_ratio(labels, outputs))
                val_ssim.append(structural_similarity(outputs, labels))

        psnr_store.append(sum(val_psnr) / len(val_psnr))
        ssim_store.append(sum(val_ssim) / len(val_ssim))

        print("Validation on test set: epoch [{} / {}], aver PSNR = {:.2f}, aver SSIM = {:.4f}".format(
            epoch + 1, epochs, psnr_store[-1], ssim_store[-1]))

        # Set model to train mode again.
        model = model.train()

        scheduler.step()

        # Save model
        save_model(model, model_save_dir, epoch, dose * 100, train_mode)

        # Save the loss and validation PSNR, SSIM.

        if not log_dir.exists():
            Path.mkdir(log_dir)
        with open(log_file, "a+") as fh:
            fh.write("{} Epoch [{} / {}], loss = {:.6f}, train PSNR = {:.2f}dB, train SSIM = {:.4f}, "
                     "validation PSNR = {:.2f}dB, validation SSIM = {:.4f}\n".format(
                     datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"),
                     epoch + 1, epochs, epoch_loss_store[-1],
                     psnr_tr_store[-1], ssim_tr_store[-1],
                     psnr_store[-1], ssim_store[-1]))
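With lr=5e-5, milestones=[20, 40, 60] and gamma=0.2, the schedule above runs at 5e-5, then 1e-5, 2e-6 and 4e-7. MultiStepLR's closed-form rate at a given epoch is base_lr * gamma ** bisect_right(milestones, epoch), which a few lines can verify:

from bisect import bisect_right

base_lr, gamma, milestones = 5e-5, 0.2, [20, 40, 60]

def lr_at(epoch):
    # the number of milestones already passed sets the decay power
    return base_lr * gamma ** bisect_right(milestones, epoch)

for epoch in (0, 20, 40, 60):
    print(epoch, lr_at(epoch))  # 5e-05, 1e-05, 2e-06, 4e-07 (approximately)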
Example #8
class ParameterServer(object):
    def __init__(self, i, args, model, bounded_delay):
        self.ps_index = i
        self.net = model.cuda()
        self.optimizer = optim.SGD(self.net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        self.sched = MultiStepLR(self.optimizer, [70], gamma=0.1)  # TODO
        self.criterion = nn.CrossEntropyLoss()
        self.args = args
        self.num_of_workers = args.world_size
        self.num_of_groups = 4
        self.worker_per_group = 4
        self.push_num = [0] * self.worker_per_group
        self.pss = []
        self.bounded_delay = bounded_delay
        self.local_iter_list = self.worker_per_group * [0]
        self.sum_gradients = []
        sys.stdout = open(f'{self.args.stdout}/ps{self.ps_index:02}_stdout.log', 'a+', 1)
        sys.stderr = open(f'{self.args.stdout}/ps{self.ps_index:02}_stdout.log', 'a+', 1)
        self.epoch_timelist = [[-1] * self.num_of_groups for i in
                               range(self.worker_per_group)]  # timelist of all workers
        self.epoch_grouplist = [-1] * self.num_of_groups  # mean iteration time of the group
        self.initial_ratio = 0.4
        self.comp_ratio = [0.4, 0.4, 0.4, 0.4]  # one compression ratio per group

    def init_pss(self, ps0, ps1, ps2):
        self.pss = [ps0, ps1, ps2]

    def desparsify(self, values, indices, ctx):
        # Rebuild the dense gradient from (values, indices) and the original
        # numel/shape carried in ctx
        numel, shape = ctx
        tensor_decompressed = torch.zeros(numel, dtype=values.dtype, layout=values.layout, device=values.device)
        tensor_decompressed.scatter_(0, indices.long(), values)
        return tensor_decompressed.view(shape)

    def apply_gradients(self, *gradients):
        total = sum(self.push_num)
        degradients = []
        for tensors in gradients:
            for values, indices, ctx in tensors:
                afterdesparsify = self.desparsify(values, indices, ctx)
                degradients.append(afterdesparsify)
        summed_gradients = [
            torch.stack(gradient_zip).sum(dim=0) / self.args.world_size
            for gradient_zip in zip(degradients)
        ]
        self.optimizer.zero_grad()
        self.net.set_gradients(summed_gradients)
        self.optimizer.step()
        if total == 0:
            self.sum_gradients = summed_gradients
        if total % 4 == 0 and total != 0:
            self.sum_gradients = [
                torch.stack(gradient_zip).sum(dim=0)
                for gradient_zip in zip(self.sum_gradients, summed_gradients)
            ]
            for i in range(3):
                self.pss[i].ps_syn.remote(*(self.sum_gradients))
            self.sum_gradients.clear()
        elif total % 4 == 1:
            # self.sum_gradients.append(gradients)
            self.sum_gradients = summed_gradients
            # print("sumgradi",summed_gradients)
        else:
            self.sum_gradients = [
                torch.stack(gradient_zip).sum(dim=0)
                for gradient_zip in zip(self.sum_gradients, summed_gradients)
            ]

    def apply_allgradients(self, *gradients):
        summed_gradients = [
            torch.stack(gradient_zip).sum(dim=0) / self.args.world_size
            for gradient_zip in zip(*gradients)
        ]
        self.optimizer.zero_grad()
        self.net.set_gradients(summed_gradients)
        self.optimizer.step()
        for i in range(3):
            self.pss[i].ps_allsyn.remote(*gradients)

    def ps_allsyn(self, *gradients):
        summed_gradients = [
            torch.stack(gradient_zip).sum(dim=0) / self.args.world_size
            for gradient_zip in zip(*gradients)
        ]
        self.optimizer.zero_grad()
        self.net.set_gradients(summed_gradients)
        self.optimizer.step()
        # return self.net.get_weights()

    def ps_syn(self, summed_gradients):
        self.optimizer.zero_grad()
        self.net.set_gradients(summed_gradients)
        self.optimizer.step()

    def get_weights(self):
        return {k: v.cuda() for k, v in self.net.state_dict().items() if 'weight' in k or 'bias' in k}

    def lr_sched(self):
        self.sched.step()
        print(self.sched.get_last_lr()[0])

    def blocked(self, worker_index, local_iter):
        self.local_iter_list[int(worker_index / 4)] = local_iter
        min_iter = min(self.local_iter_list)
        return local_iter > min_iter + self.bounded_delay

    def get_time(self, worker_index, epoch_time):  # decide the fastest or the slowest
        self.epoch_timelist[worker_index % 4][int(worker_index / 4)] = epoch_time
        workercount = 0
        sumtime = 0
        for i in range(4):
            # some workers haven't finish 1 epoch
            if self.epoch_timelist[i].count(-1) != 0:
                return -1
        for i in range(4):
            workercount += 1
            sumtime += self.epoch_timelist[worker_index % 4][i]
        self.epoch_grouplist[worker_index % 4] = sumtime / workercount

        if self.epoch_grouplist[worker_index % 4] == max(self.epoch_grouplist):
            return 1
        elif self.epoch_grouplist[worker_index % 4] == min(self.epoch_grouplist):
            return 0
        else:
            return 2
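desparsify above is the receiving half of a top-k gradient compression scheme: the sender keeps only the largest-magnitude entries and ships (values, indices) together with the original numel and shape as ctx. The sender is not part of this example; a hedged sketch of what it might look like, paired with the same scatter-based decompression:

import torch

def sparsify(tensor, ratio=0.4):
    # Keep the top `ratio` fraction of entries by magnitude (hypothetical sender)
    flat = tensor.view(-1)
    k = max(1, int(flat.numel() * ratio))
    _, indices = torch.topk(flat.abs(), k)
    values = flat[indices]                   # signed values at those positions
    ctx = (flat.numel(), tensor.shape)       # matches desparsify's ctx
    return values, indices, ctx

def desparsify(values, indices, ctx):
    numel, shape = ctx
    dense = torch.zeros(numel, dtype=values.dtype, device=values.device)
    dense.scatter_(0, indices.long(), values)
    return dense.view(shape)

grad = torch.randn(3, 4)
values, indices, ctx = sparsify(grad, ratio=1.0)
assert torch.allclose(desparsify(values, indices, ctx), grad)  # lossless at ratio 1.0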
Example #9
def train_model(config):
    # Define hyper-parameters.
    depth = int(config["DnCNN"]["depth"])
    n_channels = int(config["DnCNN"]["n_channels"])
    img_channel = int(config["DnCNN"]["img_channel"])
    kernel_size = int(config["DnCNN"]["kernel_size"])
    use_bnorm = config.getboolean("DnCNN", "use_bnorm")
    epochs = int(config["DnCNN"]["epoch"])
    batch_size = int(config["DnCNN"]["batch_size"])
    train_data_dir = config["DnCNN"]["train_data_dir"]
    test_data_dir = config["DnCNN"]["test_data_dir"]
    eta_min = float(config["DnCNN"]["eta_min"])
    eta_max = float(config["DnCNN"]["eta_max"])
    dose = float(config["DnCNN"]["dose"])
    model_save_dir = config["DnCNN"]["model_save_dir"]

    # Save logs to txt file.
    log_dir = config["DnCNN"]["log_dir"]
    log_dir = Path(log_dir) / "dose{}".format(str(int(dose * 100)))
    log_file = log_dir / "train_result.txt"

    # Define device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Initiate a DnCNN instance.
    # Load the model to device and set the model to training.
    model = DnCNN(depth=depth, n_channels=n_channels,
                  img_channel=img_channel,
                  use_bnorm=use_bnorm,
                  kernel_size=kernel_size)

    model = model.to(device)
    model.train()

    # Define loss criterion and optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)
    criterion = LossFunc(reduction="mean")

    # Get a validation test set and corrupt with noise for validation performance.
    # For every epoch, use this pre-determined noisy images.
    test_file_list = glob.glob(test_data_dir + "/*.png")
    xs_test = []
    # Can't directly convert the xs_test from list to ndarray because some images are 512*512
    # while the rest are 256*256.
    for i in range(len(test_file_list)):
        img = cv2.imread(test_file_list[i], 0)
        img = np.array(img, dtype="float32") / 255.0
        img = np.expand_dims(img, axis=0)
        img_noisy, _ = nm(img, eta_min, eta_max, dose, t=100)
        xs_test.append((img_noisy, img))

    # Train the model.
    loss_store = []
    epoch_loss_store = []
    psnr_store = []
    ssim_store = []

    psnr_tr_store = []
    ssim_tr_store = []
    
    loss_mse = torch.nn.MSELoss()

    # load the VGG network used for the perceptual loss
    # (.to(device) also works when CUDA is unavailable)
    vgg = Vgg16().to(device)
    
    
    for epoch in range(epochs):
        # For each epoch, generate clean augmented patches from the training directory.
        # Convert the data from uint8 to float32 then scale them to make it in [0, 1].
        # Then make the patches to be of shape [N, C, H, W],
        # where N is the batch size, C is the number of color channels.
        # H and W are height and width of image patches.
        xs = dg.datagenerator(data_dir=train_data_dir)
        xs = xs.astype("float32") / 255.0
        xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))

        train_set = dg.DenoisingDatatset(xs, eta_min, eta_max, dose)
        train_loader = DataLoader(dataset=train_set, num_workers=4,
                                  drop_last=True, batch_size=batch_size,
                                  shuffle=True)  # TODO: with drop_last=True, is the dropping
                                                 # in data_generator still necessary?

        # train_loader_test = next(iter(train_loader))

        t_start = timer()
        epoch_loss = 0
        for idx, data in enumerate(train_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            img_batch_read = len(inputs)

            optimizer.zero_grad()

            outputs = model(inputs)
            
            # labels can serve as both the style and the content target

            # style image normalization (ImageNet stats), currently unused:
            # style_transform = transforms.Compose([
            #     normalize_tensor_transform()  # normalize with ImageNet values
            # ])
            # labels_t = style_transform(labels)

            # VGG expects 3-channel input, so tile the grayscale images
            labels_t = labels.repeat(1, 3, 1, 1)
            outputs_t = outputs.repeat(1, 3, 1, 1)
            
            # VGG feature maps of the target and of the network output
            y_c_features = vgg(labels_t)
            style_gram = [gram(fmap) for fmap in y_c_features]

            y_hat_features = vgg(outputs_t)
            y_hat_gram = [gram(fmap) for fmap in y_hat_features]
            
            # calculate style loss
            style_loss = 0.0
            for j in range(4):
                style_loss += loss_mse(y_hat_gram[j], style_gram[j][:img_batch_read])
            style_loss = STYLE_WEIGHT * style_loss
            aggregate_style_loss = style_loss

            # calculate content loss (h_relu_2_2)
            recon = y_c_features[1]
            recon_hat = y_hat_features[1]
            content_loss = CONTENT_WEIGHT * loss_mse(recon_hat, recon)
            aggregate_content_loss = content_loss
            
            loss = aggregate_content_loss + aggregate_style_loss
#             loss = criterion(outputs, labels)
            
            loss_store.append(loss.item())
            epoch_loss += loss.item()

            loss.backward()

            optimizer.step()

            if idx % 100 == 0:
                print("Epoch [{} / {}], step [{} / {}], loss = {:.5f}, lr = {:.6f}, elapsed time = {:.2f}s".format(
                    epoch + 1, epochs, idx, len(train_loader), loss.item(), scheduler.get_last_lr()[0], timer() - t_start))

        epoch_loss_store.append(epoch_loss / len(train_loader))

        # At each epoch validate the result.
        model = model.eval()

        # # Validate on the training set first; this takes a long time, so it is commented out.
        # tr_psnr = []
        # tr_ssim = []
        # # t_start = timer()
        # with torch.no_grad():
        #     for idx, train_data in enumerate(train_loader):
        #         inputs, labels = train_data
        #         # print(inputs.shape)
        #         # inputs = np.expand_dims(inputs, axis=0)
        #         # inputs = torch.from_numpy(inputs).to(device)
        #         inputs = inputs.to(device)
        #         labels = labels.squeeze().numpy()
        #
        #         outputs = model(inputs)
        #         outputs = outputs.squeeze().cpu().detach().numpy()
        #
        #         tr_psnr.append(peak_signal_noise_ratio(labels, outputs))
        #         tr_ssim.append(structural_similarity(outputs, labels))
        # psnr_tr_store.append(sum(tr_psnr) / len(tr_psnr))
        # ssim_tr_store.append(sum(tr_ssim) / len(tr_ssim))
        # # print("Elapsed time = {}".format(timer() - t_start))
        #
        # print("Validation on train set: epoch [{} / {}], aver PSNR = {:.2f}, aver SSIM = {:.4f}".format(
        #     epoch + 1, epochs, psnr_tr_store[-1], ssim_tr_store[-1]))

        # Validate on test set
        val_psnr = []
        val_ssim = []
        with torch.no_grad():
            for idx, test_data in enumerate(xs_test):
                inputs, labels = test_data
                inputs = np.expand_dims(inputs, axis=0)
                inputs = torch.from_numpy(inputs).to(device)
                labels = labels.squeeze()

                outputs = model(inputs)
                outputs = outputs.squeeze().cpu().detach().numpy()

                val_psnr.append(peak_signal_noise_ratio(labels, outputs))
                val_ssim.append(structural_similarity(outputs, labels))

        psnr_store.append(sum(val_psnr) / len(val_psnr))
        ssim_store.append(sum(val_ssim) / len(val_ssim))

        print("Validation on test set: epoch [{} / {}], aver PSNR = {:.2f}, aver SSIM = {:.4f}".format(
            epoch + 1, epochs, psnr_store[-1], ssim_store[-1]))

        # Set model to train mode again.
        model = model.train()

        scheduler.step()

        # Save model
        save_model(model, model_save_dir, epoch, dose * 100)

        # Save the loss and validation PSNR, SSIM.

        if not log_dir.exists():
            log_dir.mkdir(parents=True)
        with open(log_file, "a+") as fh:
            # fh.write("{} Epoch [{} / {}], loss = {:.6f}, train PSNR = {:.2f}dB, train SSIM = {:.4f}, "
            #          "validation PSNR = {:.2f}dB, validation SSIM = {:.4f}".format(
            #          datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"),
            #          epoch + 1, epochs, epoch_loss_store[-1],
            #          psnr_tr_store[-1], ssim_tr_store[-1],
            #          psnr_store[-1], ssim_store[-1]))
            fh.write("{} Epoch [{} / {}], loss = {:.6f}, "
                     "validation PSNR = {:.2f}dB, validation SSIM = {:.4f}\n".format(
                     datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"),
                     epoch + 1, epochs, epoch_loss_store[-1],
                     psnr_store[-1], ssim_store[-1]))

        # np.savetxt(log_file, np.hstack((epoch + 1, epoch_loss_store[-1], psnr_store[-1], ssim_store[-1])), fmt="%.6f", delimiter=",  ")

        fig, ax = plt.subplots()
        ax.plot(loss_store[-len(train_loader):])
        ax.set_title("Losses over the last epoch ({} iterations)".format(len(train_loader)))
        ax.set_xlabel("iteration")
        fig.show()
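
The style term above compares Gram matrices of VGG feature maps, but gram() itself is not defined in this snippet. A minimal sketch of the usual batched formulation; the normalization by channel count and spatial size is an assumption:

import torch

def gram(fmap):
    # fmap: [N, C, H, W] feature maps from one VGG layer
    n, c, h, w = fmap.size()
    feats = fmap.view(n, c, h * w)
    # batched channel-by-channel inner products, normalized by layer size
    return feats.bmm(feats.transpose(1, 2)) / (c * h * w)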
def train_model_residual_lowlight_twostage_unet():

    learning_rate = INIT_LEARNING_RATE * 0.5
    device = DEVICE
    # prepare the training data
    train_set = HsiCubicTrainDataset('./data/train_lowlight/')
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # load the test labels
    mat_src_path = './data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # load the test data
    batch_size = 1
    test_data_dir = './data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # build the model
    net = TwoStageHSIDWithUNet(K)
    init_params(net)
    #net = nn.DataParallel(net).to(device)
    net = net.to(device)

    # create the optimizer
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0.999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[40, 60, 80], gamma=0.1)

    # define the loss function
    #criterion = nn.MSELoss()

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    for epoch in range(NUM_EPOCHS):
        scheduler.step()
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            #print('batch_idx=', batch_idx)
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            residual, residual_stage2 = net(noisy, cubic)
            loss = loss_function_with_tvloss(
                residual, label - noisy) + loss_function_with_tvloss(
                    residual_stage2, label - noisy)

            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # each loop iteration (one batch) counts as one step
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/two_stage_unet_hsid_{epoch}.pth")

        # evaluation code
        net.eval()
        for batch_idx, (noisy_test, cubic_test,
                        label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():

                residual, residual_stage2 = net(noisy_test, cubic_test)
                denoised_band = noisy_test + residual_stage2

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        # report PSNR, SSIM, and SAM
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'average SAM': sam
        }, epoch)  # makes it easy to spot which epoch performs best

    tb_writer.close()
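
Both stages above are trained against the residual (label - noisy) with loss_function_with_tvloss, which is not defined in this snippet. A plausible minimal sketch, assuming an MSE data term plus a total-variation smoothness penalty; the tv_weight name and value are assumptions:

import torch.nn.functional as F

def loss_function_with_tvloss(pred, target, tv_weight=1e-4):
    # data term: mean squared error against the residual target
    mse = F.mse_loss(pred, target)
    # total variation: absolute differences between neighboring pixels
    tv = (pred[:, :, 1:, :] - pred[:, :, :-1, :]).abs().mean() + \
         (pred[:, :, :, 1:] - pred[:, :, :, :-1]).abs().mean()
    return mse + tv_weight * tv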
def train_model():

    device = DEVICE
    # prepare the training data
    train_set = HsiCubicTrainDataset('./data/train_cubic/')
    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # load the test labels
    test_label_hsi = np.load('./data/origin/test_washington.npy')

    # load the test data
    test_data_dir = './data/test_level25/'
    test_set = HsiTrainDataset(test_data_dir)

    test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

    # build the model
    net = HSID_1x3(K)
    init_params(net)
    net = nn.DataParallel(net).to(device)

    # create the optimizer
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0.999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer,
                            milestones=[15, 30, 45],
                            gamma=0.25)

    # define the loss function
    #criterion = nn.MSELoss()

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    for epoch in range(NUM_EPOCHS):

        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):

            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()
            denoised_img = net(noisy, cubic)
            loss = loss_fuction(denoised_img, label)

            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # each loop iteration (one batch) counts as one step
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/hsid_{epoch}.pth")

        # inference code
        net.eval()
        for batch_idx, (noisy, label) in enumerate(test_dataloader):
            noisy = noisy.type(torch.FloatTensor)
            label = label.type(torch.FloatTensor)

            batch_size, width, height, band_num = noisy.shape
            denoised_hsi = np.zeros((width, height, band_num))

            noisy = noisy.to(DEVICE)
            label = label.to(DEVICE)

            with torch.no_grad():
                for i in range(band_num):  # process each band in turn
                    current_noisy_band = noisy[:, :, :, i]
                    current_noisy_band = current_noisy_band[:, None]

                    adj_spectral_bands = get_adjacent_spectral_bands(
                        noisy, K,
                        i)  # shape: batch_size, width, height, band_num
                    adj_spectral_bands = torch.transpose(
                        adj_spectral_bands, 3, 1
                    )  # swap dims 1 and 3; shape: batch_size, band_num, height, width
                    denoised_band = net(current_noisy_band, adj_spectral_bands)

                    denoised_band_numpy = denoised_band.cpu().numpy().astype(
                        np.float32)
                    denoised_band_numpy = np.squeeze(denoised_band_numpy)

                    denoised_hsi[:, :, i] = denoised_band_numpy

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        # report PSNR, SSIM, and SAM
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))

    tb_writer.close()
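
The per-band test loop relies on get_adjacent_spectral_bands, which is not shown here. A minimal sketch under the assumption that it gathers K neighboring bands of band i, clamping the window at the spectral boundaries; the real helper may instead exclude band i itself:

def get_adjacent_spectral_bands(hsi, k, i):
    # hsi: [batch_size, width, height, band_num]; returns k bands around band i
    band_num = hsi.shape[3]
    # clamp the window so it stays inside [0, band_num)
    start = min(max(i - k // 2, 0), band_num - k)
    return hsi[:, :, :, start:start + k]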
Example #12
class ParameterServer(object):
    def __init__(self, i, args, model, bounded_delay):
        self.ps_index = i
        self.net = model.cuda()
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=args.lr,
                                   momentum=0.9,
                                   weight_decay=5e-4)
        if args.dataset == "cifar10" and args.model == "resnet18":
            self.sched = MultiStepLR(self.optimizer, [50, 70],
                                     gamma=0.1)  # TODO
        if args.dataset == "cifar10" and args.model == "vgg19":
            self.sched = MultiStepLR(self.optimizer, [70, 100],
                                     gamma=0.1)  # TODO
        self.criterion = nn.CrossEntropyLoss()
        self.args = args
        self.num_of_workers = args.world_size
        self.num_of_groups = 4
        self.worker_per_group = 4
        self.push_num = [0] * self.worker_per_group
        self.pss = []
        self.bounded_delay = bounded_delay
        self.local_iter_list = self.worker_per_group * [0]
        self.sum_gradients = []
        sys.stdout = open(
            f'{self.args.stdout}/ps{self.ps_index:02}_stdout.log', 'a+', 1)
        sys.stderr = open(
            f'{self.args.stdout}/ps{self.ps_index:02}_stdout.log', 'a+', 1)

    def init_pss(self, ps0, ps1, ps2):
        self.pss = [ps0, ps1, ps2]

    def apply_gradients(self, worker_index, *gradients):
        # total number of pushes received so far across the group
        total = 0
        for ele in range(0, len(self.push_num)):
            total = total + self.push_num[ele]
        itr = self.push_num[int((worker_index - self.ps_index) / 4)]
        summed_gradients = [
            torch.stack(gradient_zip).sum(dim=0) / self.args.world_size
            for gradient_zip in zip(*gradients)
        ]
        self.optimizer.zero_grad()
        self.net.set_gradients(summed_gradients)
        self.optimizer.step()
        if total == 0:
            self.sum_gradients = summed_gradients
        if total % 4 == 0 and total != 0:
            # every fourth push, fold in the new gradients and broadcast
            # the accumulated sum to the peer servers
            self.sum_gradients = [
                torch.stack(gradient_zip).sum(dim=0)
                for gradient_zip in zip(self.sum_gradients, summed_gradients)
            ]
            for i in range(3):
                self.pss[i].ps_syn.remote(i, worker_index, itr, time.time(),
                                          *(self.sum_gradients))
            self.sum_gradients.clear()
        elif total % 4 == 1:
            # self.sum_gradients.append(gradients)
            self.sum_gradients = summed_gradients
            # print("sumgradi",summed_gradients)
        else:
            self.sum_gradients = [
                torch.stack(gradient_zip).sum(dim=0)
                for gradient_zip in zip(self.sum_gradients, summed_gradients)
            ]

            # return self.net.get_weights()

    def ps_syn(self, i, worker_index, itr, push_time, *summed_gradients):
        # signature matches the apply_gradients call site above
        self.optimizer.zero_grad()
        self.net.set_gradients(summed_gradients)
        self.optimizer.step()

    def lr_sched(self):
        self.sched.step()
        print("ps_index,", self.ps_index, self.sched.get_last_lr()[0])

    def get_weights(self):
        return {
            k: v.cuda()
            for k, v in self.net.state_dict().items()
            if 'weight' in k or 'bias' in k
        }

    def blocked(self, worker_index, local_iter):
        self.local_iter_list[int(worker_index / 4)] = local_iter
        min_iter = min(self.local_iter_list)
        return local_iter > min_iter + self.bounded_delay
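
Both parameter-server variants call net.set_gradients(...), a helper the snippets assume exists on the model. A minimal sketch of how such a helper is commonly written; the class name is an assumption:

import torch
import torch.nn as nn

class ModelWithGradIO(nn.Module):
    def set_gradients(self, gradients):
        # copy externally computed gradients onto this model's parameters,
        # in parameter order, so optimizer.step() can consume them
        for grad, param in zip(gradients, self.parameters()):
            if grad is not None:
                param.grad = torch.as_tensor(grad).to(param.device)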
Example #13
def train_model_residual_lowlight():

    device = DEVICE
    # prepare the training data
    train_set = HsiCubicTrainDataset('../HSID/data/train_lowlight/')
    print('total training example:', len(train_set))

    train_loader = DataLoader(dataset=train_set,
                              batch_size=BATCH_SIZE,
                              shuffle=True)

    # load the test labels
    mat_src_path = '../HSID/data/test_lowlight/origin/soup_bigcorn_orange_1ms.mat'
    test_label_hsi = scio.loadmat(mat_src_path)['label']

    # load the test data
    batch_size = 1
    test_data_dir = '../HSID/data/test_lowlight/cubic/'
    test_set = HsiCubicLowlightTestDataset(test_data_dir)
    test_dataloader = DataLoader(dataset=test_set,
                                 batch_size=batch_size,
                                 shuffle=False)

    batch_size, channel, width, height = next(iter(test_dataloader))[0].shape

    band_num = len(test_dataloader)
    denoised_hsi = np.zeros((width, height, band_num))

    # build the model
    net = ENCAM()
    #init_params(net)  # ENCAM already initializes its weights via self._initialize_weights()
    net = net.to(device)
    #net = nn.DataParallel(net)
    #net = net.to(device)

    # create the optimizer
    #hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE, betas=(0.9, 0.999))
    hsid_optimizer = optim.Adam(net.parameters(), lr=INIT_LEARNING_RATE)
    scheduler = MultiStepLR(hsid_optimizer, milestones=[15, 30, 45], gamma=0.1)

    # define the loss function
    #criterion = nn.MSELoss()

    global tb_writer
    tb_writer = get_summary_writer(log_dir='logs')

    gen_epoch_loss_list = []

    cur_step = 0

    first_batch = next(iter(train_loader))

    best_psnr = 0
    best_epoch = 0
    best_iter = 0

    for epoch in range(NUM_EPOCHS):
        scheduler.step()
        print(epoch, 'lr={:.6f}'.format(scheduler.get_last_lr()[0]))
        gen_epoch_loss = 0

        net.train()
        #for batch_idx, (noisy, cubic, label) in enumerate([first_batch] * 300):
        for batch_idx, (noisy, cubic, label) in enumerate(train_loader):
            #noisy, cubic, label = next(iter(train_loader))  # take one batch from the dataloader
            #print('batch_idx=', batch_idx)
            noisy = noisy.to(device)
            label = label.to(device)
            cubic = cubic.to(device)

            hsid_optimizer.zero_grad()
            #denoised_img = net(noisy, cubic)
            #loss = loss_fuction(denoised_img, label)

            residual = net(noisy, cubic)
            loss = loss_fuction(residual, label - noisy)

            loss.backward()  # compute gradients
            hsid_optimizer.step()  # update parameters

            gen_epoch_loss += loss.item()

            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(
                        f"Epoch {epoch}: Step {cur_step}: Batch_idx {batch_idx}: MSE loss: {loss.item()}"
                    )
                else:
                    print("Pretrained initial state")

            tb_writer.add_scalar("MSE loss", loss.item(), cur_step)

            # each loop iteration (one batch) counts as one step
            cur_step += 1

        gen_epoch_loss_list.append(gen_epoch_loss)
        tb_writer.add_scalar("mse epoch loss", gen_epoch_loss, epoch)

        #scheduler.step()
        #print("Decaying learning rate to %g" % scheduler.get_last_lr()[0])

        torch.save(
            {
                'gen': net.state_dict(),
                'gen_opt': hsid_optimizer.state_dict(),
            }, f"checkpoints/encam_{epoch}.pth")

        # evaluation code

        net.eval()
        for batch_idx, (noisy_test, cubic_test,
                        label_test) in enumerate(test_dataloader):
            noisy_test = noisy_test.type(torch.FloatTensor)
            label_test = label_test.type(torch.FloatTensor)
            cubic_test = cubic_test.type(torch.FloatTensor)

            noisy_test = noisy_test.to(DEVICE)
            label_test = label_test.to(DEVICE)
            cubic_test = cubic_test.to(DEVICE)

            with torch.no_grad():

                # downsample the input
                #noisy_permute = noisy.permute(0, 3, 1, 2)  # swap dims 1 and 3; shape: batch_size, band_num, height, width
                #label_permute = label.permute(0, 3, 1, 2)
                noisy_test_down = F.interpolate(noisy_test,
                                                scale_factor=0.5,
                                                mode='bilinear')
                cubic_test_squeeze = torch.squeeze(cubic_test, 0)
                cubic_test_down = F.interpolate(cubic_test_squeeze,
                                                scale_factor=0.5,
                                                mode='bilinear')
                cubic_test_down_unsqueeze = torch.unsqueeze(cubic_test_down, 0)
                residual = net(noisy_test_down, cubic_test_down_unsqueeze)
                denoised_band = noisy_test_down + residual

                # upsample back to the original size
                denoised_band = F.interpolate(denoised_band,
                                              scale_factor=2,
                                              mode='bilinear')

                denoised_band_numpy = denoised_band.cpu().numpy().astype(
                    np.float32)
                denoised_band_numpy = np.squeeze(denoised_band_numpy)

                denoised_hsi[:, :, batch_idx] = denoised_band_numpy

        psnr = PSNR(denoised_hsi, test_label_hsi)
        ssim = SSIM(denoised_hsi, test_label_hsi)
        sam = SAM(denoised_hsi, test_label_hsi)

        # report PSNR, SSIM, and SAM
        print("=====averPSNR:{:.3f}=====averSSIM:{:.4f}=====averSAM:{:.3f}".
              format(psnr, ssim, sam))
        tb_writer.add_scalars("validation metrics", {
            'average PSNR': psnr,
            'average SSIM': ssim,
            'average SAM': sam
        }, epoch)  # makes it easy to spot which epoch performs best

        # save the best model
        if psnr > best_psnr:
            best_psnr = psnr
            best_epoch = epoch
            best_iter = cur_step
            torch.save(
                {
                    'gen': net.state_dict(),
                    'gen_opt': hsid_optimizer.state_dict(),
                }, f"checkpoints/encam_best.pth")

        print(
            "[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]"
            % (epoch, cur_step, psnr, best_epoch, best_iter, best_psnr))

    tb_writer.close()
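
The HSID examples above score results with PSNR, SSIM, and SAM. PSNR and SSIM are standard image metrics; SAM (spectral angle mapper) is specific to hyperspectral data. A minimal sketch of the usual per-pixel formulation, assuming inputs shaped [width, height, band_num] and a result in degrees:

import numpy as np

def SAM(pred, target, eps=1e-8):
    # flatten to per-pixel spectra: [num_pixels, band_num]
    p = pred.reshape(-1, pred.shape[-1])
    t = target.reshape(-1, target.shape[-1])
    # angle between each predicted and reference spectrum
    cos = (p * t).sum(axis=1) / (np.linalg.norm(p, axis=1) * np.linalg.norm(t, axis=1) + eps)
    return float(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))).mean())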