Example #1
File: metrics.py  Project: MVPTylerE/GeoSR
import pytorch_msssim


def ssim(y_true, y_pred):
    """
    Multi-scale SSIM between two image batches, via
    https://github.com/jorge-pessoa/pytorch-msssim
    """
    m = pytorch_msssim.MSSSIM()
    ssim_tensor = m(y_true, y_pred)
    # np.float() has been removed from NumPy; a plain float() conversion suffices.
    return float(ssim_tensor)
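A minimal usage sketch (not part of the original project), assuming the pytorch_msssim package from the linked repository; MSSSIM expects 4-D NCHW float batches with values in [0, 1]:

import torch

y_true = torch.rand(2, 3, 256, 256)                              # stand-in ground-truth batch
y_pred = (y_true + 0.05 * torch.randn_like(y_true)).clamp(0, 1)  # noisy "reconstruction"
print(ssim(y_true, y_pred))                                      # scalar float close to 1.0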
Example #2
def train(args):

    print('Number of GPUs available: ' + str(torch.cuda.device_count()))
    model = nn.DataParallel(CAEP(num_resblocks).cuda())
    print('Done Setup Model.')

    dataset = BSDS500Crop128(args.dataset_path)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=args.shuffle,
                            num_workers=args.num_workers)
    testset = Kodak(args.testset_path)
    testloader = DataLoader(testset,
                            batch_size=len(testset),
                            num_workers=args.num_workers)
    print(
        f"Done Setup Training DataLoader: {len(dataloader)} batches of size {args.batch_size}"
    )
    print(f"Done Setup Testing DataLoader: {len(testset)} Images")

    MSE = nn.MSELoss()
    SSIM = pytorch_msssim.SSIM().cuda()
    MSSSIM = pytorch_msssim.MSSSIM().cuda()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=1e-10)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=10,
        verbose=True,
    )

    writer = SummaryWriter(log_dir=f'TBXLog/{args.exp_name}')

    # ADMM variables
    Z = torch.zeros(16, 32, 32).cuda()
    U = torch.zeros(16, 32, 32).cuda()
    Z.requires_grad = False
    U.requires_grad = False

    if args.load != '':
        pretrained_state_dict = torch.load(f"./chkpt/{args.load}/model.state")
        current_state_dict = model.state_dict()
        current_state_dict.update(pretrained_state_dict)
        model.load_state_dict(current_state_dict)
        # Z = torch.load(f"./chkpt/{args.load}/Z.state")
        # U = torch.load(f"./chkpt/{args.load}/U.state")
        if args.load == args.exp_name:
            optimizer.load_state_dict(
                torch.load(f"./chkpt/{args.load}/opt.state"))
            scheduler.load_state_dict(
                torch.load(f"./chkpt/{args.load}/lr.state"))
        print('Model Params Loaded.')

    model.train()

    for ei in range(args.res_epoch + 1, args.res_epoch + args.num_epochs + 1):
        # train
        train_loss = 0
        train_ssim = 0
        train_msssim = 0
        train_psnr = 0
        train_peanalty = 0
        train_bpp = 0
        avg_c = torch.zeros(16, 32, 32).cuda()
        avg_c.requires_grad = False

        for bi, crop in enumerate(dataloader):
            x = crop.cuda()
            y, c = model(x)

            psnr = compute_psnr(x, y)
            mse = MSE(y, x)
            ssim = SSIM(x, y)
            msssim = MSSSIM(x, y)

            mix = 1000 * (1 - msssim) + 1000 * (1 - ssim) + 1e4 * mse + (45 - psnr)
            # ADMM Step 1
            peanalty = rho / 2 * torch.norm(c - Z + U, 2)
            bpp = compute_bpp(c, x.shape[0], 'crop', save=False)

            avg_c += torch.mean(c.detach() /
                                (len(dataloader) * args.admm_every),
                                dim=0)

            loss = mix + peanalty
            if ei == 1 and args.load != args.exp_name:
                loss = 1e5 * mse  # warm up

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(
                '[%3d/%3d][%5d/%5d] Loss: %f, SSIM: %f, MSSSIM: %f, PSNR: %f, Norm of Code: %f, BPP: %.2f'
                % (ei, args.num_epochs + args.res_epoch, bi, len(dataloader),
                   loss, ssim, msssim, psnr, peanalty, bpp))
            writer.add_scalar('batch_train/loss', loss,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/ssim', ssim,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/msssim', msssim,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/psnr', psnr,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/norm', peanalty,
                              ei * len(dataloader) + bi)
            writer.add_scalar('batch_train/bpp', bpp,
                              ei * len(dataloader) + bi)

            train_loss += loss.item() / len(dataloader)
            train_ssim += ssim.item() / len(dataloader)
            train_msssim += msssim.item() / len(dataloader)
            train_psnr += psnr.item() / len(dataloader)
            train_peanalty += peanalty.item() / len(dataloader)
            train_bpp += bpp / len(dataloader)

        writer.add_scalar('epoch_train/loss', train_loss, ei)
        writer.add_scalar('epoch_train/ssim', train_ssim, ei)
        writer.add_scalar('epoch_train/msssim', train_msssim, ei)
        writer.add_scalar('epoch_train/psnr', train_psnr, ei)
        writer.add_scalar('epoch_train/norm', train_peanalty, ei)
        writer.add_scalar('epoch_train/bpp', train_bpp, ei)

        if ei % args.admm_every == args.admm_every - 1:
            # ADMM Step 2: sparsify avg_c + U by zeroing the positions selected
            # by the argsort-based mask, i.e. roughly pruning_ratio of the
            # 16 x 32 x 32 code entries.
            keep = int((1 - pruning_ratio) * 16 * 32 * 32)
            order = torch.Tensor(
                np.argsort((avg_c + U).data.cpu().numpy(), axis=None))
            mask = (order >= keep).view(16, 32, 32).cuda()
            Z = (avg_c + U).masked_fill_(mask, value=0)
            # ADMM Step 3: dual-variable update.
            U += avg_c - Z

        # test
        model.eval()
        val_loss = 0
        val_ssim = 0
        val_msssim = 0
        val_psnr = 0
        val_peanalty = 0
        val_bpp = 0
        for bi, (img, patches, _) in enumerate(testloader):
            avg_loss = 0
            avg_ssim = 0
            avg_msssim = 0
            avg_psnr = 0
            avg_peanalty = 0
            avg_bpp = 0
            for i in range(6):
                for j in range(4):
                    x = torch.Tensor(patches[:, i, j, :, :, :]).cuda()
                    y, c = model(x)

                    psnr = compute_psnr(x, y)
                    mse = MSE(y, x)
                    ssim = SSIM(x, y)
                    msssim = MSSSIM(x, y)

                    mix = 1000 * (1 - msssim) + 1000 * (
                        1 - ssim) + 1e4 * mse + (45 - psnr)

                    peanalty = rho / 2 * torch.norm(c - Z + U, 2)
                    bpp = compute_bpp(c,
                                      x.shape[0],
                                      f'Kodak_patches_{i}_{j}',
                                      save=True)
                    loss = mix + peanalty

                    avg_loss += loss.item() / 24
                    avg_ssim += ssim.item() / 24
                    avg_msssim += msssim.item() / 24
                    avg_psnr += psnr.item() / 24
                    avg_peanalty += peanalty.item() / 24
                    avg_bpp += bpp / 24

            save_kodak_img(model, img, 0, patches, writer, ei)
            save_kodak_img(model, img, 10, patches, writer, ei)
            save_kodak_img(model, img, 20, patches, writer, ei)

            val_loss += avg_loss
            val_ssim += avg_ssim
            val_msssim += avg_msssim
            val_psnr += avg_psnr
            val_peanalty += avg_peanalty
            val_bpp += avg_bpp
        print(
            '*Kodak: [%3d/%3d] Loss: %f, SSIM: %f, MSSSIM: %f, Norm of Code: %f, BPP: %.2f'
            % (ei, args.num_epochs + args.res_epoch, val_loss, val_ssim,
               val_msssim, val_peanalty, val_bpp))

        # bz = call('tar -jcvf ./code/code.tar.bz ./code', shell=True)
        # total_code_size = os.stat('./code/code.tar.bz').st_size
        # total_bpp = total_code_size * 8 / 24 / 768 / 512

        writer.add_scalar('test/loss', val_loss, ei)
        writer.add_scalar('test/ssim', val_ssim, ei)
        writer.add_scalar('test/msssim', val_msssim, ei)
        writer.add_scalar('test/psnr', val_psnr, ei)
        writer.add_scalar('test/norm', val_peanalty, ei)
        writer.add_scalar('test/bpp', val_bpp, ei)
        # writer.add_scalar('test/total_bpp', total_bpp, ei)
        model.train()

        scheduler.step(train_loss)

        # save model
        if ei % args.save_every == args.save_every - 1:
            torch.save(model.state_dict(),
                       f"./chkpt/{args.exp_name}/model.state")
            torch.save(optimizer.state_dict(),
                       f"./chkpt/{args.exp_name}/opt.state")
            torch.save(scheduler.state_dict(),
                       f"./chkpt/{args.exp_name}/lr.state")
            torch.save(Z, f"./chkpt/{args.exp_name}/Z.state")
            torch.save(U, f"./chkpt/{args.exp_name}/U.state")

    writer.close()
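The objective in train() above combines an SSIM/MS-SSIM/MSE/PSNR distortion term with an ADMM penalty that pulls the latent code c toward the sparse target Z. A self-contained sketch of that computation, assuming NCHW inputs in [0, 1]; psnr_from_mse is a hypothetical stand-in for the project's compute_psnr:

import torch
import torch.nn.functional as F
import pytorch_msssim

SSIM = pytorch_msssim.SSIM()
MSSSIM = pytorch_msssim.MSSSIM()


def psnr_from_mse(mse, max_val=1.0):
    # Standard PSNR for images in [0, max_val].
    return 10 * torch.log10(max_val ** 2 / mse)


def admm_objective(x, y, c, Z, U, rho=1.0):
    mse = F.mse_loss(y, x)
    ssim = SSIM(x, y)
    msssim = MSSSIM(x, y)
    psnr = psnr_from_mse(mse)
    # Same weighting as train(): push SSIM/MS-SSIM toward 1, MSE toward 0,
    # and PSNR toward 45 dB.
    mix = 1000 * (1 - msssim) + 1000 * (1 - ssim) + 1e4 * mse + (45 - psnr)
    # ADMM Step 1: quadratic penalty keeping the code c near Z - U.
    penalty = rho / 2 * torch.norm(c - Z + U, 2)
    return mix + penalty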
Example #3
dataloader_secret = torch.utils.data.DataLoader(dataset_secret,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                num_workers=1)

# initialize the model and load the params

encoder = model.Encoder()
encoder = encoder.to(device)

# decoder (discriminator)
decoder = model.Decoder()
decoder = decoder.to(device)

ssim_loss = pytorch_ssim.SSIM()
mssim_loss = pytorch_msssim.MSSSIM()
mse_loss = nn.MSELoss()
# dis_loss=nn.BCELoss()

print('loading params')

path = model_dir + '/' + str(epoch) + '.pth.tar'  # load the params saved at this epoch

checkpoint = torch.load(path, map_location='cpu')
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['deocoder_state_dict'])

cover_ssmi_train = checkpoint['cover_ssmi']
secret_ssmi_train = checkpoint['secret_ssmi']
network_loss = checkpoint['net_loss']
encoder.eval()
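For reference, a checkpoint compatible with the loading code above could have been written roughly as follows (a sketch, not code from the project; the dictionary keys, including the 'deocoder_state_dict' spelling, are copied from the snippet):

# Hypothetical save-side counterpart; key names mirror the loading code exactly.
torch.save(
    {
        'encoder_state_dict': encoder.state_dict(),
        'deocoder_state_dict': decoder.state_dict(),
        'cover_ssmi': cover_ssmi_train,
        'secret_ssmi': secret_ssmi_train,
        'net_loss': network_loss,
    }, model_dir + '/' + str(epoch) + '.pth.tar')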
Example #4
def main():
    global opt, model
    opt = parser.parse_args()
    print(opt)

    save_path = os.path.join('.', "model", "{}_{}".format(opt.model, opt.ID))
    log_dir = './records/{}_{}/'.format(opt.model, opt.ID)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    cuda = opt.cuda
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    opt.seed = random.randint(1, 10000)
    # opt.seed = 4222
    coeff_mse = opt.coeff_totalloss
    coeff_J = opt.coeff_J
    print("Random Seed: ", opt.seed)

    cudnn.benchmark = True

    print("===> Loading datasets")
    train_set = DatasetFromHdf5(opt.traindata, opt.patchSize, opt.aug)
    training_data_loader = DataLoader(dataset=train_set,
                                      num_workers=opt.threads,
                                      batch_size=opt.batchSize,
                                      shuffle=True)

    print("===> Building model")
    if opt.model == 'dense':
        model = Dense()
    else:
        raise ValueError("no known model of {}".format(opt.model))
    criterion = nn.MSELoss()
    Absloss = nn.L1Loss()
    ssim_loss = pytorch_msssim.MSSSIM()

    #loss_var = torch.std()
    if opt.freeze:
        model.freeze_pretrained()

    print("===> Setting GPU")
    if cuda:
        model = torch.nn.DataParallel(model).cuda()
        criterion = criterion.cuda()
        Absloss = Absloss.cuda()
        ssim_loss = ssim_loss.cuda()
        #loss_var = loss_var.cuda()
        vgg = Vgg16(requires_grad=False).cuda()

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("===> loading checkpoint: {}".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("===> no checkpoint found at {}".format(opt.resume))

    # optionally copy weights from a checkpoint
    if opt.pretrained:
        if os.path.isfile(opt.pretrained):
            pretrained_dict = torch.load(opt.pretrained)['model'].state_dict()
            print("===> load model {}".format(opt.pretrained))
            model_dict = model.state_dict()
            # filter out unnecessary keys
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            print("\t...loaded parameters:")
            for k, v in pretrained_dict.items():
                print("\t\t+{}".format(k))
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)
            # weights = torch.load(opt.pretrained)
            # model.load_state_dict(weights['model'].state_dict())
        else:
            print("===> no model found at {}".format(opt.pretrained))

    print("===> Setting Optimizer")
    optimizer = optim.Adam(
        model.parameters(), lr=opt.lr,
        weight_decay=opt.weight_decay)  #weight_decay=opt.weight_decay

    print("===> Training")
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):

        # Evaluate validation dataset and save images
        if epoch % 1 == 0:
            save_val_path = os.path.join('test', opt.model + '_' + opt.ID)
            checkdirctexist(save_val_path)
            image_list = glob.glob(os.path.join(opt.valdataset, '*.png'))
            for image_name in image_list:
                print("Processing ", image_name)
                img = cv2.imread(image_name)
                img = img.astype(np.float32)
                H, W, C = img.shape
                P = 512
                print("\t\tBreak image into patches of {}x{}".format(P, P))

                # Pad H and W up to the next multiple of 32 so the network's
                # strided stages divide evenly; the padding is cropped off
                # again before the results are saved.
                Wk = W
                Hk = H
                if W % 32:
                    Wk = W + (32 - W % 32)
                if H % 32:
                    Hk = H + (32 - H % 32)
                img = np.pad(img, ((0, Hk - H), (0, Wk - W), (0, 0)),
                             'reflect')
                im_input = img / 255.0
                im_input = np.expand_dims(np.rollaxis(im_input, 2), axis=0)
                im_input_rollback = np.rollaxis(im_input[0], 0, 3)
                with torch.no_grad():
                    im_input = Variable(torch.from_numpy(im_input).float())
                    im_input = im_input.cuda()
                    model.eval()

                    J, J1, J2, J3, w1, w2, w3 = model(im_input, opt)
                    im_output = J

                im_output = im_output.cpu()
                im_output_forsave = get_image_for_save(im_output)
                J1_output = J1.cpu()
                J1_output_forsave = get_image_for_save(J1_output)
                J2_output = J2.cpu()
                J2_output_forsave = get_image_for_save(J2_output)
                J3_output = J3.cpu()
                J3_output_forsave = get_image_for_save(J3_output)
                W1_output = w1.cpu()
                W1_output_forsave = get_image_for_save(W1_output)
                W2_output = w2.cpu()
                W2_output_forsave = get_image_for_save(W2_output)
                W3_output = w3.cpu()
                W3_output_forsave = get_image_for_save(W3_output)

                path, filename = os.path.split(image_name)

                # Crop the reflection padding back off before saving.
                im_output_forsave = im_output_forsave[0:H, 0:W, :]
                J1_output_forsave = J1_output_forsave[0:H, 0:W, :]
                J2_output_forsave = J2_output_forsave[0:H, 0:W, :]
                J3_output_forsave = J3_output_forsave[0:H, 0:W, :]
                W1_output_forsave = W1_output_forsave[0:H, 0:W, :]
                W2_output_forsave = W2_output_forsave[0:H, 0:W, :]
                W3_output_forsave = W3_output_forsave[0:H, 0:W, :]

                cv2.imwrite(
                    os.path.join(save_val_path,
                                 "{}_IM_{}".format(epoch - 1, filename)),
                    im_output_forsave)
                cv2.imwrite(
                    os.path.join(save_val_path,
                                 "{}_J1_{}".format(epoch - 1, filename)),
                    J1_output_forsave)
                cv2.imwrite(
                    os.path.join(save_val_path,
                                 "{}_J2_{}".format(epoch - 1, filename)),
                    J2_output_forsave)
                cv2.imwrite(
                    os.path.join(save_val_path,
                                 "{}_J3_{}".format(epoch - 1, filename)),
                    J3_output_forsave)
                cv2.imwrite(
                    os.path.join(save_val_path,
                                 "{}_W1_{}".format(epoch - 1, filename)),
                    W1_output_forsave)
                cv2.imwrite(
                    os.path.join(save_val_path,
                                 "{}_W2_{}".format(epoch - 1, filename)),
                    W2_output_forsave)
                cv2.imwrite(
                    os.path.join(save_val_path,
                                 "{}_W3_{}".format(epoch - 1, filename)),
                    W3_output_forsave)
        train(training_data_loader, optimizer, model, criterion, Absloss,
              ssim_loss, epoch, vgg)
        save_checkpoint(model, epoch, save_path)
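get_image_for_save is not shown in this example; a plausible minimal version (an assumption, not the project's actual helper), given that the model outputs 1 x C x H x W float tensors in [0, 1] and cv2.imwrite expects H x W x C uint8 arrays:

import numpy as np

def get_image_for_save(tensor):
    # Assumed helper: 1 x C x H x W float tensor in [0, 1] -> H x W x C uint8 image.
    img = tensor.detach().numpy()[0]
    img = np.rollaxis(img, 0, 3)       # CHW -> HWC, mirroring the rollaxis calls above
    img = np.clip(img * 255.0, 0.0, 255.0)
    return img.astype(np.uint8)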