def train(model,
          train_loader,
          test_loader,
          mode='EDSR_Baseline',
          save_image_every=50,
          save_model_every=10,
          test_model_every=1,
          epoch_start=0,
          num_epochs=1000,
          device=None,
          refresh=True):
    """Train ``model`` to reconstruct images from blurred and/or noisy copies.

    For every training batch the first tensor of the loader tuple (``lr``) is
    degraded with a randomly chosen blur kernel and/or additive uniform noise,
    passed through ``model``, and the L1 distance between ``lr`` and the first
    entry of the model's second output is minimised with Adam.

    Side effects: creates (and, when ``refresh`` is true, first deletes)
    ``./results/<date>/<mode>``, ``./weights/<date>/<mode>`` and
    ``./logger/<date>_<mode>``; writes preview images, checkpoints,
    TensorBoard scalars and ``./hist_<date>_<mode>.csv``.

    Args:
        model: Module called as ``model(x)`` and returning a 2-tuple whose
            second element is indexable; element ``[0]`` is used as the
            restored image.
        train_loader: Iterable of 3-tuples; only the first item is used.
        test_loader: Iterable of 3-tuples ``(lr, _, fname)``; ``fname[0]`` is
            a path-like string used to name the saved images.
        mode: Run label used in directory names and progress-bar text.
        save_image_every: Save a training preview every this many steps.
        save_model_every: Save a checkpoint every this many epochs.
        test_model_every: Run the test loop every this many epochs.
        epoch_start: Index of the first epoch.
        num_epochs: Number of epochs to run.
        device: Torch device; defaults to CUDA when available, else CPU.
        refresh: Delete the output directories of the same date/mode first.

    Returns:
        None.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    today = datetime.datetime.now().strftime('%Y.%m.%d')

    result_dir = f'./results/{today}/{mode}'
    weight_dir = f'./weights/{today}/{mode}'
    logger_dir = f'./logger/{today}_{mode}'
    csv_path = f'./hist_{today}_{mode}.csv'
    output_dirs = (result_dir, weight_dir, logger_dir)
    if refresh:
        # Remove each directory independently: with a single shared ``try`` a
        # missing first directory would leave the remaining ones untouched.
        for output_dir in output_dirs:
            try:
                shutil.rmtree(output_dir)
            except FileNotFoundError:
                pass
    for output_dir in output_dirs:
        os.makedirs(output_dir, exist_ok=True)
    logger = SummaryWriter(log_dir=logger_dir, flush_secs=2)
    model = model.to(device)

    optim = torch.optim.Adam(list(model.parameters()), lr=1e-4)
    # The scheduler is stepped once per batch (see the training loop), so the
    # learning rate is multiplied by 0.99 every 1000 optimisation steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optim,
                                                step_size=1000,
                                                gamma=0.99)
    criterion = torch.nn.L1Loss()
    GMSD = GMSD_quality().to(device)
    mshf = MSHF(3, 3).to(device)

    start_time = time.time()
    print(f'Training Start || Mode: {mode}')

    step = 0
    pfix = OrderedDict()
    pfix_test = OrderedDict()

    hist = dict()
    hist['mode'] = f'{today}_{mode}'
    for key in ['epoch', 'psnr', 'ssim', 'ms-ssim']:
        hist[key] = []

    # One blur module per (kernel size, sigma) pair, built once up front.
    kernel_sizes = [3, 5, 7]
    sigmas = [0.4, 0.8, 1.0, 1.2, 1.6, 2.0]
    blurs = {
        ksize: {
            sigma: Blur(ksize=ksize, sigma=sigma).to(device)
            for sigma in sigmas
        }
        for ksize in kernel_sizes
    }
    noise_sigma = 0.3

    def uniform_noise(ref):
        # torch.rand_like samples uniformly from [0, 1), so this noise is
        # non-negative and bounded by ``noise_sigma``.
        return torch.rand_like(ref, device=ref.device) * noise_sigma

    def mean_or_nan(values):
        # NaN instead of an unbound/stale value when a loader was empty.
        return np.array(values).mean() if values else float('nan')

    def show_test_metrics(pbar_test, current, means):
        # Push the latest and the running test metrics to the progress bar.
        psnr, ssim, msssim = current
        psnr_mean, ssim_mean, msssim_mean = means
        pfix_test['psnr'] = f'{psnr:.4f}'
        pfix_test['ssim'] = f'{ssim:.4f}'
        pfix_test['msssim'] = f'{msssim:.4f}'
        pfix_test['psnr_mean'] = f'{psnr_mean:.4f}'
        pfix_test['ssim_mean'] = f'{ssim_mean:.4f}'
        pfix_test['msssim_mean'] = f'{msssim_mean:.4f}'
        pbar_test.set_postfix(pfix_test)

    def tile_channels(tensor):
        # Lay the channels of a (B, C, H, W) tensor side by side along width.
        return torch.cat(
            [tensor[:, i, :, :] for i in range(tensor.shape[1])], dim=-1)

    for epoch in range(epoch_start, epoch_start + num_epochs):

        if epoch == 0:
            # Snapshot of the untrained weights.
            # NOTE(review): this uses the same file name as the checkpoint of
            # the first finished epoch, which overwrites it when
            # ``save_model_every`` is 1.
            torch.save(model.state_dict(),
                       f'{weight_dir}/epoch_{epoch+1:04d}.pth')

            # Short warm-up pass over the first two test batches.
            with torch.no_grad():
                with tqdm(
                        test_loader,
                        desc=
                        f'{mode} || Warming Up || Test Epoch {epoch}/{num_epochs}',
                        position=0,
                        leave=True) as pbar_test:
                    psnrs = []
                    ssims = []
                    msssims = []
                    for lr, _, _ in pbar_test:
                        lr = lr.to(device)

                        # Strongest blur plus noise, as in the test loop.
                        lr_input = blurs[7][2.0](lr)
                        lr_input = lr_input + uniform_noise(lr)

                        _, features = model(lr_input)
                        dr = features[0]

                        psnr, ssim, msssim = evaluate(lr, dr)

                        psnrs.append(psnr)
                        ssims.append(ssim)
                        msssims.append(msssim)

                        show_test_metrics(
                            pbar_test, (psnr, ssim, msssim),
                            (mean_or_nan(psnrs), mean_or_nan(ssims),
                             mean_or_nan(msssims)))
                        if len(psnrs) > 1:
                            break

        with tqdm(train_loader,
                  desc=f'{mode} || Epoch {epoch+1}/{num_epochs}',
                  position=0,
                  leave=True) as pbar:
            psnrs = []
            ssims = []
            msssims = []
            losses = []
            psnr_mean = ssim_mean = msssim_mean = float('nan')
            for lr, _, _ in pbar:
                lr = lr.to(device)

                # Random degradation; the order of the RNG draws is kept
                # (kernel size, sigma, degradation type, then the noise).
                ksize_ = random.choice(kernel_sizes)
                sigma_ = random.choice(sigmas)
                blur = blurs[ksize_][sigma_]

                dnd = random.choice(['blur', 'noise', 'blur_and_noise'])
                if dnd == 'blur':
                    lr_input = blur(lr)
                elif dnd == 'noise':
                    lr_input = lr + uniform_noise(lr)
                else:
                    lr_input = blur(lr)
                    lr_input = lr_input + uniform_noise(lr)

                _, features = model(lr_input)
                dr = features[0]

                # The GMSD map is only used for the preview image, so it does
                # not need an autograd graph.
                with torch.no_grad():
                    gmsd = GMSD(lr, dr)

                # training
                loss = criterion(lr, dr)
                optim.zero_grad()
                loss.backward()
                optim.step()
                scheduler.step()

                # training history
                elapsed = sec2time(time.time() - start_time)
                pfix['Step'] = f'{step+1}'
                pfix['Loss'] = f'{loss.item():.4f}'

                psnr, ssim, msssim = evaluate(lr, dr)

                psnrs.append(psnr)
                ssims.append(ssim)
                msssims.append(msssim)

                psnr_mean = mean_or_nan(psnrs)
                ssim_mean = mean_or_nan(ssims)
                msssim_mean = mean_or_nan(msssims)

                pfix['PSNR'] = f'{psnr:.2f}'
                pfix['SSIM'] = f'{ssim:.4f}'
                pfix['PSNR_mean'] = f'{psnr_mean:.2f}'
                pfix['SSIM_mean'] = f'{ssim_mean:.4f}'

                free_gpu = get_gpu_memory()[0]

                pfix['free GPU'] = f'{free_gpu}MiB'
                pfix['Elapsed'] = f'{elapsed}'

                pbar.set_postfix(pfix)
                losses.append(loss.item())

                if step % save_image_every == 0:
                    imsave([lr_input[0], dr[0], lr[0], gmsd[0]],
                           f'{result_dir}/epoch_{epoch+1}_iter_{step:05d}.jpg')

                step += 1

            logger.add_scalar("Loss/train", mean_or_nan(losses), epoch + 1)
            logger.add_scalar("PSNR/train", psnr_mean, epoch + 1)
            logger.add_scalar("SSIM/train", ssim_mean, epoch + 1)

            if (epoch + 1) % save_model_every == 0:
                torch.save(model.state_dict(),
                           f'{weight_dir}/epoch_{epoch+1:04d}.pth')

            if (epoch + 1) % test_model_every == 0:

                # NOTE(review): the model is not switched to eval mode here;
                # confirm it has no dropout/batch-norm layers.
                with torch.no_grad():
                    with tqdm(
                            test_loader,
                            desc=f'{mode} || Test Epoch {epoch+1}/{num_epochs}',
                            position=0,
                            leave=True) as pbar_test:
                        psnrs = []
                        ssims = []
                        msssims = []
                        # Reset so an empty test loader records NaN instead of
                        # the training means.
                        psnr_mean = ssim_mean = msssim_mean = float('nan')
                        for lr, _, fname in pbar_test:

                            fname = fname[0].split('/')[-1].split('.pt')[0]

                            lr = lr.to(device)

                            # Fixed, strongest degradation for testing.
                            lr_input = blurs[7][2.0](lr)
                            lr_input = lr_input + uniform_noise(lr)

                            _, features = model(lr_input)
                            dr = features[0]

                            mshf_lr = mshf(lr)
                            mshf_dr = mshf(dr)

                            gmsd = GMSD(lr, dr)

                            psnr, ssim, msssim = evaluate(lr, dr)

                            psnrs.append(psnr)
                            ssims.append(ssim)
                            msssims.append(msssim)

                            psnr_mean = mean_or_nan(psnrs)
                            ssim_mean = mean_or_nan(ssims)
                            msssim_mean = mean_or_nan(msssims)

                            show_test_metrics(
                                pbar_test, (psnr, ssim, msssim),
                                (psnr_mean, ssim_mean, msssim_mean))

                            imsave([lr_input[0], dr[0], lr[0], gmsd[0]],
                                   f'{result_dir}/{fname}.jpg')

                            # Restored responses on top, reference responses
                            # below, channels tiled left to right.
                            mshf_vis = torch.cat(
                                (tile_channels(mshf_dr),
                                 tile_channels(mshf_lr)),
                                dim=-2)

                            imsave(mshf_vis, f'{result_dir}/MSHF_{fname}.jpg')

                        hist['epoch'].append(epoch + 1)
                        hist['psnr'].append(psnr_mean)
                        hist['ssim'].append(ssim_mean)
                        hist['ms-ssim'].append(msssim_mean)

                        logger.add_scalar("PSNR/test", psnr_mean, epoch + 1)
                        logger.add_scalar("SSIM/test", ssim_mean, epoch + 1)
                        logger.add_scalar("MS-SSIM/test", msssim_mean,
                                          epoch + 1)

                        # The full history is rewritten after every test run.
                        df = pd.DataFrame(hist)
                        df.to_csv(csv_path)
# Example no. 2
# 0
    # Average time per run. With more than two runs, the slowest and fastest
    # timings are zeroed out and excluded from the average.
    # NOTE(review): if argmax and argmin are the same index (all timings
    # equal) only one entry is zeroed while the divisor still drops two.
    if args.n_epoch > 2:
        argmax = t_list.argmax()
        argmin = t_list.argmin()
        t_list[argmax] = 0
        t_list[argmin] = 0
        t_avg = np.sum(t_list) / (args.n_epoch - 2)
    else:
        t_avg = np.sum(t_list) / args.n_epoch
    # Report total wall time, throughput (images/s) and latency (ms/image).
    print(
        'Finish %d images for %d times in %.4fs, speed = %.4f image/s (%.4f ms/image)'
        % (args.n_sample, args.n_epoch, t_end - t_start, args.n_sample / t_avg,
           t_avg * 1000.0 / args.n_sample))

    print('===================== benchmark finished =====================')

    # Imported locally, right before its only use in this fragment.
    from utils import get_gpu_memory
    gpu_mem = get_gpu_memory()

    # Save the results: one text file per (library, network, dtype, batch
    # size) combination under cache/results.
    res_dir = 'cache/results'
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)

    res_file_path = os.path.join(
        res_dir,
        '%s_%s_%s_%d.txt' % (DLLIB, args.network, args.dtype, args.batch_size))
    # Single line: library, network, dtype, batch size, images/s, GPU memory.
    # NOTE(review): ``gpu_mem`` is formatted with %d, so get_gpu_memory() is
    # presumably expected to return a single number here -- confirm.
    with open(res_file_path, 'w') as fd:
        fd.write('%s %s %s %d %f %d\n' %
                 (DLLIB, args.network, args.dtype, args.batch_size,
                  args.n_sample / t_avg, gpu_mem))
# Example no. 3
# 0
def train(model,
          train_loader,
          test_loader,
          mode='EDSR_Baseline',
          save_image_every=50,
          save_model_every=10,
          test_model_every=1,
          num_epochs=1000,
          device=None,
          refresh=True):
    """Train a super-resolution ``model`` with an auxiliary edge loss.

    Each batch ``(lr, hr, _)`` is passed through ``model``, which returns
    ``(sr, lr_edge, features)``. The optimised loss is
    ``L1(hr, sr) + 0.1 * L1(lr_edge, Edge(hr))``.

    Side effects: creates (and, when ``refresh`` is true, first deletes)
    ``./results/<date>/<mode>``, ``./weights/<date>/<mode>`` and
    ``./logger/<date>_<mode>``; writes preview images, checkpoints,
    TensorBoard scalars and ``./hist_<date>_<mode>.csv``.

    Args:
        model: Module called as ``model(lr)`` and returning a 3-tuple
            ``(sr, lr_edge, features)``.
        train_loader: Iterable of 3-tuples ``(lr, hr, _)``.
        test_loader: Iterable of 3-tuples ``(lr, hr, fname)``; ``fname[0]`` is
            a path-like string used to name the saved images.
        mode: Run label used in directory names and progress-bar text.
        save_image_every: Save a training preview every this many steps.
        save_model_every: Save a checkpoint every this many epochs.
        test_model_every: Run the test loop every this many epochs.
        num_epochs: Number of epochs to run.
        device: Torch device; defaults to CUDA when available, else CPU.
        refresh: Delete the output directories of the same date/mode first.

    Returns:
        None.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    today = datetime.datetime.now().strftime('%Y.%m.%d')

    result_dir = f'./results/{today}/{mode}'
    weight_dir = f'./weights/{today}/{mode}'
    logger_dir = f'./logger/{today}_{mode}'
    csv_path = f'./hist_{today}_{mode}.csv'
    output_dirs = (result_dir, weight_dir, logger_dir)
    if refresh:
        # Remove each directory independently: with a single shared ``try`` a
        # missing first directory would leave the remaining ones untouched.
        for output_dir in output_dirs:
            try:
                shutil.rmtree(output_dir)
            except FileNotFoundError:
                pass
    for output_dir in output_dirs:
        os.makedirs(output_dir, exist_ok=True)
    logger = SummaryWriter(log_dir=logger_dir, flush_secs=2)
    model = model.to(device)

    optim = torch.optim.Adam(list(model.parameters()), lr=1e-4)
    # The scheduler is stepped once per batch (see the training loop), so the
    # learning rate is multiplied by 0.99 every 1000 optimisation steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optim,
                                                step_size=1000,
                                                gamma=0.99)
    criterion = torch.nn.L1Loss()

    # Edge extractor applied to the HR target for the auxiliary loss.
    ED = Edge().to(device)

    start_time = time.time()
    print(f'Training Start || Mode: {mode}')

    step = 0
    pfix = OrderedDict()
    pfix_test = OrderedDict()

    hist = dict()
    hist['mode'] = f'{today}_{mode}'
    for key in ['epoch', 'psnr', 'ssim', 'ms-ssim']:
        hist[key] = []

    def mean_or_nan(values):
        # NaN instead of an unbound/stale value when a loader was empty.
        return np.array(values).mean() if values else float('nan')

    def show_test_metrics(pbar_test, current, means):
        # Push the latest and the running test metrics to the progress bar.
        psnr, ssim, msssim = current
        psnr_mean, ssim_mean, msssim_mean = means
        pfix_test['psnr'] = f'{psnr:.4f}'
        pfix_test['ssim'] = f'{ssim:.4f}'
        pfix_test['msssim'] = f'{msssim:.4f}'
        pfix_test['psnr_mean'] = f'{psnr_mean:.4f}'
        pfix_test['ssim_mean'] = f'{ssim_mean:.4f}'
        pfix_test['msssim_mean'] = f'{msssim_mean:.4f}'
        pbar_test.set_postfix(pfix_test)

    def pad_to_hr_height(lr_img, hr_img):
        # Append zero rows below ``lr_img`` until it is as tall as ``hr_img``
        # so both can be shown side by side. For exact x2 / x4 inputs this is
        # the same as stacking one / three zero images under the LR image,
        # and it also covers every other height ratio.
        missing_rows = hr_img.shape[-2] - lr_img.shape[-2]
        if missing_rows <= 0:
            return lr_img
        filler = lr_img.new_zeros(
            (*lr_img.shape[:-2], missing_rows, lr_img.shape[-1]))
        return torch.cat((lr_img, filler), dim=-2)

    for epoch in range(num_epochs):

        if epoch == 0:
            # Snapshot of the untrained weights.
            # NOTE(review): this uses the same file name as the checkpoint of
            # the first finished epoch, which overwrites it when
            # ``save_model_every`` is 1.
            torch.save(model.state_dict(),
                       f'{weight_dir}/epoch_{epoch+1:04d}.pth')

            # Short warm-up pass over the first two test batches.
            with torch.no_grad():
                with tqdm(
                        test_loader,
                        desc=
                        f'Mode: {mode} || Warming Up || Test Epoch {epoch}/{num_epochs}',
                        position=0,
                        leave=True) as pbar_test:
                    psnrs = []
                    ssims = []
                    msssims = []
                    for lr, hr, _ in pbar_test:
                        lr = lr.to(device)
                        hr = hr.to(device)

                        sr, _, _ = model(lr)
                        sr = quantize(sr)

                        psnr, ssim, msssim = evaluate(hr, sr)

                        psnrs.append(psnr)
                        ssims.append(ssim)
                        msssims.append(msssim)

                        show_test_metrics(
                            pbar_test, (psnr, ssim, msssim),
                            (mean_or_nan(psnrs), mean_or_nan(ssims),
                             mean_or_nan(msssims)))
                        if len(psnrs) > 1:
                            break

        with tqdm(train_loader,
                  desc=f'Mode: {mode} || Epoch {epoch+1}/{num_epochs}',
                  position=0,
                  leave=True) as pbar:
            psnrs = []
            ssims = []
            msssims = []
            losses = []
            psnr_mean = ssim_mean = msssim_mean = float('nan')
            for lr, hr, _ in pbar:
                lr = lr.to(device)
                hr = hr.to(device)

                # prediction
                sr, lr_edge, _ = model(lr)

                # Auxiliary loss between the predicted edge map and the edges
                # of the HR target.
                hr_edge = ED(hr)
                loss_edge = criterion(lr_edge, hr_edge)

                # training
                loss = criterion(hr, sr)
                loss_tot = loss + 0.1 * loss_edge
                optim.zero_grad()
                loss_tot.backward()
                optim.step()
                scheduler.step()

                # training history
                elapsed = sec2time(time.time() - start_time)
                pfix['Step'] = f'{step+1}'
                pfix['Loss'] = f'{loss.item():.4f}'
                pfix['Loss Edge'] = f'{loss_edge.item():.4f}'

                # Metrics and previews use the quantized prediction.
                sr = quantize(sr)
                psnr, ssim, msssim = evaluate(hr, sr)

                psnrs.append(psnr)
                ssims.append(ssim)
                msssims.append(msssim)

                psnr_mean = mean_or_nan(psnrs)
                ssim_mean = mean_or_nan(ssims)
                msssim_mean = mean_or_nan(msssims)

                pfix['PSNR'] = f'{psnr:.2f}'
                pfix['SSIM'] = f'{ssim:.4f}'
                pfix['PSNR_mean'] = f'{psnr_mean:.2f}'
                pfix['SSIM_mean'] = f'{ssim_mean:.4f}'

                free_gpu = get_gpu_memory()[0]

                pfix['free GPU'] = f'{free_gpu}MiB'
                pfix['Elapsed'] = f'{elapsed}'

                pbar.set_postfix(pfix)
                losses.append(loss.item())

                if step % save_image_every == 0:
                    xz = pad_to_hr_height(lr[0], hr[0])
                    imsave([xz, sr[0], hr[0]],
                           f'{result_dir}/epoch_{epoch+1}_iter_{step:05d}.jpg')

                step += 1

            logger.add_scalar("Loss/train", mean_or_nan(losses), epoch + 1)
            logger.add_scalar("PSNR/train", psnr_mean, epoch + 1)
            logger.add_scalar("SSIM/train", ssim_mean, epoch + 1)

            if (epoch + 1) % save_model_every == 0:
                torch.save(model.state_dict(),
                           f'{weight_dir}/epoch_{epoch+1:04d}.pth')

            if (epoch + 1) % test_model_every == 0:

                # NOTE(review): the model is not switched to eval mode here;
                # confirm it has no dropout/batch-norm layers.
                with torch.no_grad():
                    with tqdm(
                            test_loader,
                            desc=
                            f'Mode: {mode} || Test Epoch {epoch+1}/{num_epochs}',
                            position=0,
                            leave=True) as pbar_test:
                        psnrs = []
                        ssims = []
                        msssims = []
                        # Reset so an empty test loader records NaN instead of
                        # the training means.
                        psnr_mean = ssim_mean = msssim_mean = float('nan')
                        for lr, hr, fname in pbar_test:

                            fname = fname[0].split('/')[-1].split('.pt')[0]

                            lr = lr.to(device)
                            hr = hr.to(device)

                            sr, _, _ = model(lr)
                            sr = quantize(sr)

                            psnr, ssim, msssim = evaluate(hr, sr)

                            psnrs.append(psnr)
                            ssims.append(ssim)
                            msssims.append(msssim)

                            psnr_mean = mean_or_nan(psnrs)
                            ssim_mean = mean_or_nan(ssims)
                            msssim_mean = mean_or_nan(msssims)

                            show_test_metrics(
                                pbar_test, (psnr, ssim, msssim),
                                (psnr_mean, ssim_mean, msssim_mean))

                            xz = pad_to_hr_height(lr[0], hr[0])
                            imsave([xz, sr[0], hr[0]],
                                   f'{result_dir}/{fname}.jpg')

                        hist['epoch'].append(epoch + 1)
                        hist['psnr'].append(psnr_mean)
                        hist['ssim'].append(ssim_mean)
                        hist['ms-ssim'].append(msssim_mean)

                        logger.add_scalar("PSNR/test", psnr_mean, epoch + 1)
                        logger.add_scalar("SSIM/test", ssim_mean, epoch + 1)
                        logger.add_scalar("MS-SSIM/test", msssim_mean,
                                          epoch + 1)

                        # The full history is rewritten after every test run.
                        df = pd.DataFrame(hist)
                        df.to_csv(csv_path)
# Example no. 4
# 0
def train(model,
          train_loader,
          test_loader,
          mode='EDSR_Baseline',
          save_image_every=50,
          save_model_every=10,
          test_model_every=1,
          epoch_start=0,
          num_epochs=1000,
          device=None,
          refresh=True,
          scale=2,
          today=None):
    """Train a multi-output super-resolution ``model`` with a GMSD soft mask.

    ``model(lr)`` returns ``(sr, srx2, srx1)``. ``sr`` is compared with
    ``hr``, ``srx2`` with ``hr`` bicubically downscaled by 2 and ``srx1``
    with ``hr`` downscaled by 4; the total loss is
    ``loss + 0.25 * lossx2 + 0.125 * lossx1`` (all L1). Once the batch PSNR
    reaches ``40 - 2 * scale`` the three losses are weighted by a mask
    derived from the GMSD map of ``(hr, sr)``.

    Side effects: creates (and, when ``refresh`` is true, first deletes)
    ``./results/<today>/<mode>``, ``./weights/<today>/<mode>`` and
    ``./logger/<today>_<mode>``; writes preview images, checkpoints,
    TensorBoard scalars and ``./hist_<today>_<mode>.csv``.

    Args:
        model: Module called as ``model(lr)`` and returning a 3-tuple.
        train_loader: Iterable of 3-tuples ``(lr, hr, _)``.
        test_loader: Iterable of 3-tuples ``(lr, hr, fname)``; ``fname[0]`` is
            a path-like string used to name the saved images.
        mode: Run label used in directory names and progress-bar text.
        save_image_every: Save a training preview every this many steps.
        save_model_every: Save a checkpoint every this many epochs.
        test_model_every: Run the test loop every this many epochs.
        epoch_start: Index of the first epoch.
        num_epochs: Number of epochs to run.
        device: Torch device; defaults to CUDA when available, else CPU.
        refresh: Delete the output directories of the same date/mode first.
        scale: Only used for the soft-mask PSNR threshold ``40 - 2 * scale``.
        today: Date label for the output paths; defaults to the current date
            formatted as ``%Y.%m.%d``.

    Returns:
        The trained ``model`` (moved to ``device``).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if today is None:
        today = datetime.datetime.now().strftime('%Y.%m.%d')

    result_dir = f'./results/{today}/{mode}'
    weight_dir = f'./weights/{today}/{mode}'
    logger_dir = f'./logger/{today}_{mode}'
    csv_path = f'./hist_{today}_{mode}.csv'
    output_dirs = (result_dir, weight_dir, logger_dir)
    if refresh:
        # Remove each directory independently: with a single shared ``try`` a
        # missing first directory would leave the remaining ones untouched.
        for output_dir in output_dirs:
            try:
                shutil.rmtree(output_dir)
            except FileNotFoundError:
                pass
    for output_dir in output_dirs:
        os.makedirs(output_dir, exist_ok=True)
    logger = SummaryWriter(log_dir=logger_dir, flush_secs=2)
    model = model.to(device)

    optim = torch.optim.Adam(list(model.parameters()), lr=1e-4)
    # The scheduler is stepped once per batch (see the training loop), so the
    # learning rate is multiplied by 0.99 every 1000 optimisation steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optim,
                                                step_size=1000,
                                                gamma=0.99)
    criterion = torch.nn.L1Loss()
    GMSD = GMSD_quality().to(device)
    # Still constructed as before, but not used by the active code path.
    opening = Opening().to(device)
    blur = Blur().to(device)
    mshf = MSHF(3, 3).to(device)

    # Bicubic downscalers used for the intermediate targets and masks.
    downx2_bicubic = nn.Upsample(scale_factor=1 / 2,
                                 mode='bicubic',
                                 align_corners=False)
    downx4_bicubic = nn.Upsample(scale_factor=1 / 4,
                                 mode='bicubic',
                                 align_corners=False)

    start_time = time.time()
    print(f'Training Start || Mode: {mode}')

    step = 0
    pfix = OrderedDict()
    pfix_test = OrderedDict()

    hist = dict()
    hist['mode'] = f'{today}_{mode}'
    for key in ['epoch', 'psnr', 'ssim', 'ms-ssim']:
        hist[key] = []

    def mean_or_nan(values):
        # NaN instead of an unbound/stale value when a loader was empty.
        return np.array(values).mean() if values else float('nan')

    def show_test_means(pbar_test, means):
        # Push the running test metrics to the progress bar.
        psnr_mean, ssim_mean, msssim_mean = means
        pfix_test['psnr_mean'] = f'{psnr_mean:.4f}'
        pfix_test['ssim_mean'] = f'{ssim_mean:.4f}'
        pfix_test['msssim_mean'] = f'{msssim_mean:.4f}'
        pbar_test.set_postfix(pfix_test)

    def pad_to_hr_height(lr_img, hr_img):
        # Append zero rows below ``lr_img`` until it is as tall as ``hr_img``
        # so both can be shown side by side. For exact x2 / x4 inputs this is
        # the same as stacking one / three zero images under the LR image,
        # and it also covers every other height ratio.
        missing_rows = hr_img.shape[-2] - lr_img.shape[-2]
        if missing_rows <= 0:
            return lr_img
        filler = lr_img.new_zeros(
            (*lr_img.shape[:-2], missing_rows, lr_img.shape[-1]))
        return torch.cat((lr_img, filler), dim=-2)

    def tile_channels(tensor):
        # Lay the channels of a (B, C, H, W) tensor side by side along width.
        return torch.cat(
            [tensor[:, i, :, :] for i in range(tensor.shape[1])], dim=-1)

    for epoch in range(epoch_start, epoch_start + num_epochs):

        if epoch == 0:
            # Snapshot of the untrained weights.
            # NOTE(review): this uses the same file name as the checkpoint of
            # the first finished epoch, which overwrites it when
            # ``save_model_every`` is 1.
            torch.save(model.state_dict(),
                       f'{weight_dir}/epoch_{epoch+1:04d}.pth')

            # Short warm-up pass over the first two test batches.
            with torch.no_grad():
                with tqdm(
                        test_loader,
                        desc=
                        f'{mode} || Warming Up || Test Epoch {epoch}/{num_epochs}',
                        position=0,
                        leave=True) as pbar_test:
                    psnrs = []
                    ssims = []
                    msssims = []
                    for lr, hr, _ in pbar_test:
                        lr = lr.to(device)
                        hr = hr.to(device)

                        sr, _, _ = model(lr)

                        sr = quantize(sr)

                        psnr, ssim, msssim = evaluate(hr, sr)

                        psnrs.append(psnr)
                        ssims.append(ssim)
                        msssims.append(msssim)

                        show_test_means(
                            pbar_test,
                            (mean_or_nan(psnrs), mean_or_nan(ssims),
                             mean_or_nan(msssims)))
                        if len(psnrs) > 1:
                            break

        with tqdm(train_loader,
                  desc=f'{mode} || Epoch {epoch+1}/{num_epochs}',
                  position=0,
                  leave=True) as pbar:
            psnrs = []
            ssims = []
            msssims = []
            losses = []
            psnr_mean = ssim_mean = msssim_mean = float('nan')
            for lr, hr, _ in pbar:
                lr = lr.to(device)
                hr = hr.to(device)

                # Intermediate targets: HR downscaled by 4 and by 2.
                hrx1 = downx4_bicubic(hr)
                hrx2 = downx2_bicubic(hr)

                # prediction
                sr, srx2, srx1 = model(lr)

                # The GMSD map only feeds the detached mask and the preview
                # image, so it does not need an autograd graph.
                with torch.no_grad():
                    gmsd = GMSD(hr, sr)

                sr_ = quantize(sr)
                psnr, ssim, msssim = evaluate(hr, sr_)

                # Mask the losses only once this batch is reconstructed well.
                soft_mask = psnr >= 40 - 2 * scale

                if soft_mask:
                    # Threshold the normalised GMSD map, blur it, rescale it
                    # to [0, 1] and lift it so the weights lie in [0.2, 1].
                    gmask = gmsd / gmsd.max()
                    gmask = (gmask > 0.2) * 1.0
                    gmask = blur(gmask)
                    gmask = (gmask - gmask.min()) / (gmask.max() -
                                                     gmask.min() + 1e-7)
                    gmask = (gmask + 0.25) / 1.25
                    gmask = gmask.detach()

                    gmaskx2 = downx2_bicubic(gmask)
                    gmaskx1 = downx4_bicubic(gmask)

                    loss = criterion(sr * gmask, hr * gmask)
                    lossx2 = criterion(srx2 * gmaskx2, hrx2 * gmaskx2)
                    lossx1 = criterion(srx1 * gmaskx1, hrx1 * gmaskx1)
                else:
                    loss = criterion(sr, hr)
                    lossx2 = criterion(srx2, hrx2)
                    lossx1 = criterion(srx1, hrx1)

                # training
                loss_tot = loss + 0.25 * lossx2 + 0.125 * lossx1
                optim.zero_grad()
                loss_tot.backward()
                optim.step()
                scheduler.step()

                # training history
                elapsed = sec2time(time.time() - start_time)
                pfix['Loss'] = f'{loss.item():.4f}'
                pfix['x2'] = f'{lossx2.item():.4f}'
                pfix['x1'] = f'{lossx1.item():.4f}'

                psnrs.append(psnr)
                ssims.append(ssim)
                msssims.append(msssim)

                psnr_mean = mean_or_nan(psnrs)
                ssim_mean = mean_or_nan(ssims)
                msssim_mean = mean_or_nan(msssims)

                pfix['PSNR_mean'] = f'{psnr_mean:.2f}'
                pfix['SSIM_mean'] = f'{ssim_mean:.4f}'

                free_gpu = get_gpu_memory()[0]

                pfix['Elapsed'] = f'{elapsed}'
                pfix['free GPU'] = f'{free_gpu}MiB'

                pbar.set_postfix(pfix)
                losses.append(loss.item())

                if step % save_image_every == 0:
                    xz = pad_to_hr_height(lr[0], hr[0])
                    imsave([xz, sr[0], hr[0], gmsd[0]],
                           f'{result_dir}/epoch_{epoch+1}_iter_{step:05d}.jpg')

                step += 1

            logger.add_scalar("Loss/train", mean_or_nan(losses), epoch + 1)
            logger.add_scalar("PSNR/train", psnr_mean, epoch + 1)
            logger.add_scalar("SSIM/train", ssim_mean, epoch + 1)

            if (epoch + 1) % save_model_every == 0:
                torch.save(model.state_dict(),
                           f'{weight_dir}/epoch_{epoch+1:04d}.pth')

            if (epoch + 1) % test_model_every == 0:

                # NOTE(review): the model is not switched to eval mode here;
                # confirm it has no dropout/batch-norm layers.
                with torch.no_grad():
                    with tqdm(
                            test_loader,
                            desc=f'{mode} || Test Epoch {epoch+1}/{num_epochs}',
                            position=0,
                            leave=True) as pbar_test:
                        psnrs = []
                        ssims = []
                        msssims = []
                        # Reset so an empty test loader records NaN instead of
                        # the training means.
                        psnr_mean = ssim_mean = msssim_mean = float('nan')
                        for lr, hr, fname in pbar_test:

                            fname = fname[0].split('/')[-1].split('.pt')[0]

                            lr = lr.to(device)
                            hr = hr.to(device)

                            sr, _, _ = model(lr)

                            # Both maps use the prediction before quantizing.
                            mshf_hr = mshf(hr)
                            mshf_sr = mshf(sr)

                            gmsd = GMSD(hr, sr)

                            sr = quantize(sr)

                            psnr, ssim, msssim = evaluate(hr, sr)

                            psnrs.append(psnr)
                            ssims.append(ssim)
                            msssims.append(msssim)

                            psnr_mean = mean_or_nan(psnrs)
                            ssim_mean = mean_or_nan(ssims)
                            msssim_mean = mean_or_nan(msssims)

                            show_test_means(
                                pbar_test,
                                (psnr_mean, ssim_mean, msssim_mean))

                            xz = pad_to_hr_height(lr[0], hr[0])
                            imsave([xz, sr[0], hr[0], gmsd[0]],
                                   f'{result_dir}/{fname}.jpg')

                            # SR responses on top, HR responses below,
                            # channels tiled left to right.
                            mshf_vis = torch.cat(
                                (tile_channels(mshf_sr),
                                 tile_channels(mshf_hr)),
                                dim=-2)

                            imsave(mshf_vis, f'{result_dir}/MSHF_{fname}.jpg')

                        hist['epoch'].append(epoch + 1)
                        hist['psnr'].append(psnr_mean)
                        hist['ssim'].append(ssim_mean)
                        hist['ms-ssim'].append(msssim_mean)

                        logger.add_scalar("PSNR/test", psnr_mean, epoch + 1)
                        logger.add_scalar("SSIM/test", ssim_mean, epoch + 1)
                        logger.add_scalar("MS-SSIM/test", msssim_mean,
                                          epoch + 1)

                        # The full history is rewritten after every test run.
                        df = pd.DataFrame(hist)
                        df.to_csv(csv_path)
    return model
# Example no. 5
# 0
              position=0,
              leave=True) as pbar:
        for lr, hr, _ in pbar:

            # prediction
            pred = model(lr)

            # training
            loss = criterion(hr, pred)
            optim.zero_grad()
            loss.backward()
            optim.step()
            scheduler.step()

            # training history
            free_gpu = get_gpu_memory()[0]
            elapsed_time = time.time() - start_time
            elapsed = sec2time(elapsed_time)
            pfix['Step'] = f'{step+1}'
            pfix['Loss'] = f'{loss.item():.4f}'
            pfix['free GPU'] = f'{free_gpu}MiB'
            pfix['Elapsed'] = f'{elapsed}'
            hist['Iter'].append(step)
            hist['Loss'].append(loss.item())
            pbar.set_postfix(pfix)

            if step % save_image_every == 0:

                z = torch.zeros_like(lr[0])
                xz = torch.cat((lr[0], z), dim=-2)
                img = torch.cat((xz, pred[0], hr[0]), dim=-1)
# Example no. 6
# 0
# Remaining command-line options: number of epochs and the finetune switch
# (an int used as a truthy flag further down).
parser.add_argument("-e","--epochs", type=int, default=25)
parser.add_argument("-f","--finetune", type=int, default=0)
args = parser.parse_args()
print("Arguments: {}".format(args))

# Reject unknown model names before doing any expensive work.
model_names = available_models()
if args.model_name not in model_names:
    raise ValueError("Wrong model name, please select one among {}".format(model_names))
        
# Report the software and hardware environment of this run.
print("OS:", sys.platform)
print("Python:", sys.version)
print("PyTorch:", torch.__version__)
print("Numpy:", np.__version__)
print("Number of CPU processors:", get_number_processors())
print("GPU:", get_gpu_name())
print("GPU memory: {}".format(get_gpu_memory()))
print("CUDA version:", get_cuda_version())
print("CUDNN version:", torch.backends.cudnn.version())

torch.backends.cudnn.benchmark=True # enables cudnn's auto-tuner


# Datasets
dataset = create_dataset(args.dataset_path, batch_size=args.batch_size)


# Training
# Only the finetune branch is visible in this fragment.
if args.finetune:
    model, metrics = finetune(dataset, args.model_name, SETS, args.epochs, args.gpus, 
                              args.learning_rate, args.momentum, args.learning_rate_step, 
                              args.learning_rate_epochs, verbose=True)
# Example no. 7
# 0
    # Build the response with one rating per requested content id, in the
    # order the ids were given.
    for content_id in content_ids:
        new_result['content_ids'].append(content_id)
        # A rating exists only if the id can be mapped and the mapped key is
        # present in ``result``.
        if content_id in content_ids_map and content_ids_map[
                content_id] in result:
            new_result['ratings'].append(
                float(result[content_ids_map[content_id]]))
        else:
            # Placeholder rating for ids without a prediction.
            new_result['ratings'].append(999.0)  # TODO Fake data
    return make_response(jsonify(new_result), STATUS_OK)


# Report the software and hardware environment at import time.
print("OS: ", sys.platform)
print("Python: ", sys.version)
print("PyTorch: ", torch.__version__)
print("Numpy: ", np.__version__)
print("Number of CPU processors: ", get_number_processors())
print("GPU: ", get_gpu_name())
print("GPU memory: ", get_gpu_memory())
print("CUDA: ", get_cuda_version())
print("USE_GPU: ", USE_GPU)

# Load data and model as global variables. This runs at import time, not only
# when the module is executed as a script.
data_layer, inv_userIdMap, inv_itemIdMap = load_train_data(TRAIN)
rencoder_api = load_recommender(data_layer.vector_dim, HIDDEN, ACTIVATION,
                                DROPOUT, MODEL_PATH)
muids_map, content_ids_map = load_train_muid_and_content_id(
    TRAIN_MUID, TRAIN_CONTENT_ID)

# Start the server only when executed directly.
if __name__ == "__main__":
    run_server()
def train(model,
          model_sr,
          train_loader,
          test_loader,
          mode='EDSR_Baseline',
          save_image_every=50,
          save_model_every=10,
          test_model_every=1,
          epoch_start=0,
          num_epochs=1000,
          device=None,
          refresh=True):
    """Train ``model`` to separate ``model_sr`` outputs from real HR images.

    For every training batch ``model_sr(lr)`` produces an image ``sr``.
    ``model`` is then optimised with an L1 loss so that ``model(sr)`` moves
    toward all zeros and ``model(hr)`` toward all ones.  Only ``model``'s
    parameters are handed to the optimiser; ``model_sr`` is used purely as a
    fixed generator of inputs.

    Args:
        model: Module being trained.
        model_sr: Module called as ``model_sr(lr)``; must return a pair whose
            first element is fed to ``model``.  Not updated here.
        train_loader: Iterable yielding ``(lr, hr, name)`` triples.
        test_loader: Iterable yielding ``(lr, hr, name)`` triples.
        mode: Label used in output directory names and progress-bar text.
        save_image_every: Accepted for signature compatibility; unused here.
        save_model_every: Save ``model``'s weights every this many epochs.
        test_model_every: Run the test pass every this many epochs.
        epoch_start: Index of the first epoch.
        num_epochs: Number of epochs to run.
        device: Target device; defaults to ``'cuda'`` when available,
            otherwise ``'cpu'``.
        refresh: When true, delete this run's existing result, weight and
            logger directories before starting.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    today = datetime.datetime.now().strftime('%Y.%m.%d')

    result_dir = f'./results/{today}/{mode}'
    weight_dir = f'./weights/{today}/{mode}'
    logger_dir = f'./logger/{today}_{mode}'
    output_dirs = (result_dir, weight_dir, logger_dir)

    if refresh:
        # Remove each directory on its own: with a single shared try block, a
        # missing first directory would skip deletion of the remaining ones.
        for stale_dir in output_dirs:
            try:
                shutil.rmtree(stale_dir)
            except FileNotFoundError:
                pass
    for out_dir in output_dirs:
        os.makedirs(out_dir, exist_ok=True)

    logger = SummaryWriter(log_dir=logger_dir, flush_secs=2)
    model = model.to(device)
    model_sr = model_sr.to(device)

    optim = torch.optim.Adam(list(model.parameters()), lr=1e-4)
    # The scheduler is stepped once per batch (see the training loop), so the
    # learning rate is multiplied by 0.99 every 1000 optimisation steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1000, gamma=0.99)
    criterion = torch.nn.L1Loss()
    # NOTE(review): GMSD and mshf are never used in this function.  They are
    # kept only because their constructors are not visible here and may have
    # side effects -- confirm and remove.
    GMSD = GMSD_quality().to(device)
    mshf = MSHF(3, 3).to(device)

    print(f'Training Start || Mode: {mode}')

    step = 0
    pfix = OrderedDict()
    pfix_test = OrderedDict()

    for epoch in range(epoch_start, epoch_start+num_epochs):

        if epoch == 0:
            # Snapshot the untrained weights.
            # NOTE(review): this uses the same ``epoch+1`` file name as the
            # periodic checkpoint below, so an epoch-1 checkpoint overwrites it.
            torch.save(model.state_dict(), f'{weight_dir}/epoch_{epoch+1:04d}.pth')

            # Warm-up: push at most two test batches through both networks and
            # the metric code before any training happens.
            with torch.no_grad(), tqdm(test_loader, desc=f'{mode} || Warming Up || Test Epoch {epoch}/{num_epochs}', position=0, leave=True) as pbar_test:
                psnrs = []
                ssims = []
                msssims = []
                for lr, hr, _ in pbar_test:
                    lr = lr.to(device)
                    hr = hr.to(device)

                    sr, _ = model_sr(lr)
                    # The output is discarded; the call only exercises
                    # ``model``'s forward pass.
                    model(sr)

                    psnr, ssim, msssim = evaluate(hr, quantize(sr))
                    psnrs.append(psnr)
                    ssims.append(ssim)
                    msssims.append(msssim)

                    pfix_test['psnr'] = f'{psnr:.4f}'
                    pfix_test['ssim'] = f'{ssim:.4f}'
                    pfix_test['msssim'] = f'{msssim:.4f}'
                    pfix_test['psnr_mean'] = f'{np.array(psnrs).mean():.4f}'
                    pfix_test['ssim_mean'] = f'{np.array(ssims).mean():.4f}'
                    pfix_test['msssim_mean'] = f'{np.array(msssims).mean():.4f}'

                    pbar_test.set_postfix(pfix_test)
                    if len(psnrs) > 1:
                        break

        with tqdm(train_loader, desc=f'{mode} || Epoch {epoch+1}/{num_epochs}', position=0, leave=True) as pbar:
            for lr, hr, _ in pbar:
                lr = lr.to(device)
                hr = hr.to(device)

                # model_sr is not being optimised, so run it without autograd.
                # Otherwise every backward pass would also walk model_sr and
                # accumulate gradients there that nothing ever zeroes.
                with torch.no_grad():
                    sr, _ = model_sr(lr)

                fake = model(sr)
                loss_fake = criterion(fake, torch.zeros_like(fake))

                real = model(hr)
                loss_real = criterion(real, torch.ones_like(real))

                loss_tot = loss_fake + loss_real
                optim.zero_grad()
                loss_tot.backward()
                optim.step()
                scheduler.step()

                pfix['Step'] = f'{step+1}'
                pfix['Loss real'] = f'{loss_real.item():.4f}'
                pfix['Loss fake'] = f'{loss_fake.item():.4f}'
                pbar.set_postfix(pfix)
                step += 1

            if (epoch+1) % save_model_every == 0:
                torch.save(model.state_dict(), f'{weight_dir}/epoch_{epoch+1:04d}.pth')

            if (epoch+1) % test_model_every == 0:
                # Test pass: only reports the two loss terms on the progress
                # bar; nothing is written to disk or to the summary writer.
                with torch.no_grad(), tqdm(test_loader, desc=f'{mode} || Test Epoch {epoch+1}/{num_epochs}', position=0, leave=True) as pbar_test:
                    for lr, hr, _ in pbar_test:
                        lr = lr.to(device)
                        hr = hr.to(device)

                        sr, _ = model_sr(lr)

                        fake = model(sr)
                        loss_fake = criterion(fake, torch.zeros_like(fake))

                        real = model(hr)
                        loss_real = criterion(real, torch.ones_like(real))

                        pfix_test['Step'] = f'{step+1}'
                        pfix_test['Loss real'] = f'{loss_real.item():.4f}'
                        pfix_test['Loss fake'] = f'{loss_fake.item():.4f}'

                        pbar_test.set_postfix(pfix_test)

    # Release the summary writer's resources once training is over.
    logger.close()
# Example no. 9
# 0
def train(model, train_loader, test_loader, mode='EDSR_Baseline', save_image_every=50, save_model_every=10, test_model_every=1, epoch_start=0, num_epochs=1000, device=None, refresh=True, scale=2, today=None):
    """Train ``model`` to recover clean images from Gaussian-noised copies.

    Each batch's ``hr`` image gets zero-mean Gaussian noise added, with a
    standard deviation drawn at random from ``10/255``, ``30/255`` or
    ``50/255``, and is clamped to ``[0, 1]``.  ``model`` is optimised with an
    L1 loss between its output on the noisy image and the clean ``hr``.
    Per-epoch train/test metrics go to a ``SummaryWriter``; test metrics are
    also written to ``./hist_{today}_{mode}.csv`` after every test pass.

    Args:
        model: Module being trained.
        train_loader: Iterable yielding ``(lr, hr, name)`` triples; only
            ``hr`` is used.
        test_loader: Iterable yielding ``(lr, hr, name)`` triples; ``hr`` and
            the first entry of ``name`` are used.
        mode: Label used in output paths and progress-bar text.
        save_image_every: Accepted for signature compatibility; unused here.
        save_model_every: Save the weights every this many epochs.
        test_model_every: Run the test pass every this many epochs.
        epoch_start: Index of the first epoch.
        num_epochs: Number of epochs to run.
        device: Target device; defaults to ``'cuda'`` when available,
            otherwise ``'cpu'``.
        refresh: When true, delete this run's existing result, weight and
            logger directories before starting.
        scale: Accepted for signature compatibility; unused here.
        today: Date string used in output paths; defaults to the current
            date formatted as ``%Y.%m.%d``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if today is None:
        today = datetime.datetime.now().strftime('%Y.%m.%d')

    result_dir = f'./results/{today}/{mode}'
    weight_dir = f'./weights/{today}/{mode}'
    logger_dir = f'./logger/{today}_{mode}'
    csv = f'./hist_{today}_{mode}.csv'
    output_dirs = (result_dir, weight_dir, logger_dir)

    if refresh:
        # Remove each directory on its own: with a single shared try block, a
        # missing first directory would skip deletion of the remaining ones.
        for stale_dir in output_dirs:
            try:
                shutil.rmtree(stale_dir)
            except FileNotFoundError:
                pass
    for out_dir in output_dirs:
        os.makedirs(out_dir, exist_ok=True)

    logger = SummaryWriter(log_dir=logger_dir, flush_secs=2)
    model = model.to(device)

    params = list(model.parameters())
    optim = torch.optim.Adam(params, lr=1e-4)
    # The scheduler is stepped once per batch (see the training loop), so the
    # learning rate is multiplied by 0.99 every 1000 optimisation steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1000, gamma=0.99)
    criterion = torch.nn.L1Loss()

    start_time = time.time()
    print(f'Training Start || Mode: {mode}')

    step = 0
    pfix = OrderedDict()
    pfix_test = OrderedDict()

    hist = dict()
    hist['mode'] = f'{today}_{mode}'
    for key in ['epoch', 'psnr', 'ssim', 'ms-ssim']:
        hist[key] = []

    # Candidate noise standard deviations.
    sigmas = [10/255, 30/255, 50/255]

    def add_noise(clean):
        """Return ``clean`` plus Gaussian noise of a randomly chosen sigma, clamped to [0, 1]."""
        noise_sigma = np.random.choice(sigmas)
        noisy = clean + torch.randn_like(clean) * noise_sigma
        return torch.clamp(noisy, 0, 1)

    for epoch in range(epoch_start, epoch_start+num_epochs):
        if epoch == 0:
            # Snapshot the untrained weights.
            # NOTE(review): this uses the same ``epoch+1`` file name as the
            # periodic checkpoint below, so an epoch-1 checkpoint overwrites it.
            torch.save(model.state_dict(), f'{weight_dir}/epoch_{epoch+1:04d}.pth')

            # Warm-up: push at most two test batches through the model and the
            # metric code before any training happens.
            with torch.no_grad():
                with tqdm(test_loader, desc=f'{mode} || Warming Up || Test Epoch {epoch}/{num_epochs}', position=0, leave=True) as pbar_test:
                    psnrs = []
                    ssims = []
                    msssims = []
                    for _, hr, _ in pbar_test:
                        hr = hr.to(device)

                        hr_input = add_noise(hr)
                        sr = quantize(model(hr_input))

                        psnr, ssim, msssim = evaluate(hr, sr)

                        psnrs.append(psnr)
                        ssims.append(ssim)
                        msssims.append(msssim)

                        psnr_mean = np.array(psnrs).mean()
                        ssim_mean = np.array(ssims).mean()
                        msssim_mean = np.array(msssims).mean()

                        pfix_test['psnr_mean'] = f'{psnr_mean:.4f}'
                        pfix_test['ssim_mean'] = f'{ssim_mean:.4f}'
                        pfix_test['msssim_mean'] = f'{msssim_mean:.4f}'

                        pbar_test.set_postfix(pfix_test)
                        if len(psnrs) > 1:
                            break

        with tqdm(train_loader, desc=f'{mode} || Epoch {epoch+1}/{num_epochs}', position=0, leave=True) as pbar:
            psnrs = []
            ssims = []
            msssims = []
            losses = []
            # The loader's first item is never read, so it is not moved to
            # the device.
            for _, hr, _ in pbar:
                hr = hr.to(device)

                hr_input = add_noise(hr)
                sr = model(hr_input)

                # Metrics are for display/logging only; keep them out of the
                # autograd graph.
                with torch.no_grad():
                    psnr, ssim, msssim = evaluate(hr, quantize(sr))

                loss = criterion(sr, hr)
                optim.zero_grad()
                loss.backward()
                optim.step()
                scheduler.step()

                # training history
                elapsed = sec2time(time.time() - start_time)
                pfix['Step'] = f'{step+1}'
                pfix['Loss'] = f'{loss.item():.4f}'

                psnrs.append(psnr)
                ssims.append(ssim)
                msssims.append(msssim)

                psnr_mean = np.array(psnrs).mean()
                ssim_mean = np.array(ssims).mean()
                msssim_mean = np.array(msssims).mean()

                pfix['PSNR_mean'] = f'{psnr_mean:.2f}'
                pfix['SSIM_mean'] = f'{ssim_mean:.4f}'

                free_gpu = get_gpu_memory()[0]

                pfix['free GPU'] = f'{free_gpu}MiB'
                pfix['Elapsed'] = f'{elapsed}'

                pbar.set_postfix(pfix)
                losses.append(loss.item())

                step += 1

            logger.add_scalar("Loss/train", np.array(losses).mean(), epoch+1)
            logger.add_scalar("PSNR/train", psnr_mean, epoch+1)
            logger.add_scalar("SSIM/train", ssim_mean, epoch+1)

            if (epoch+1) % save_model_every == 0:
                torch.save(model.state_dict(), f'{weight_dir}/epoch_{epoch+1:04d}.pth')

            if (epoch+1) % test_model_every == 0:

                with torch.no_grad():
                    with tqdm(test_loader, desc=f'{mode} || Test Epoch {epoch+1}/{num_epochs}', position=0, leave=True) as pbar_test:
                        psnrs = []
                        ssims = []
                        msssims = []
                        for _, hr, fname in pbar_test:

                            # Output file stem: last '/'-separated path
                            # component, cut at the first '.pt'.
                            fname = fname[0].split('/')[-1].split('.pt')[0]

                            hr = hr.to(device)

                            hr_input = add_noise(hr)
                            sr = quantize(model(hr_input))

                            psnr, ssim, msssim = evaluate(hr, sr)

                            psnrs.append(psnr)
                            ssims.append(ssim)
                            msssims.append(msssim)

                            psnr_mean = np.array(psnrs).mean()
                            ssim_mean = np.array(ssims).mean()
                            msssim_mean = np.array(msssims).mean()

                            pfix_test['psnr_mean'] = f'{psnr_mean:.4f}'
                            pfix_test['ssim_mean'] = f'{ssim_mean:.4f}'
                            pfix_test['msssim_mean'] = f'{msssim_mean:.4f}'

                            pbar_test.set_postfix(pfix_test)

                            # Save noisy input, model output and clean target
                            # for the first sample of the batch.
                            imsave([hr_input[0], sr[0], hr[0]], f'{result_dir}/{fname}.jpg')

                        hist['epoch'].append(epoch+1)
                        hist['psnr'].append(psnr_mean)
                        hist['ssim'].append(ssim_mean)
                        hist['ms-ssim'].append(msssim_mean)

                        logger.add_scalar("PSNR/test", psnr_mean, epoch+1)
                        logger.add_scalar("SSIM/test", ssim_mean, epoch+1)
                        logger.add_scalar("MS-SSIM/test", msssim_mean, epoch+1)

                        # Rewrite the whole history file after every test pass.
                        df = pd.DataFrame(hist)
                        df.to_csv(csv)