Example 1
def test(cfg: Namespace) -> None:
    assert cfg.checkpoint not in [None, ""]
    assert cfg.device == "cpu" or (cfg.device == "cuda"
                                   and T.cuda.is_available())

    exp_dir = ROOT_EXP_DIR / cfg.exp_name
    os.makedirs(exp_dir / "out", exist_ok=True)
    cfg.to_file(exp_dir / "test_config.json")
    logger.info(f"[exp dir={exp_dir}]")

    model = CAE()
    model.load_state_dict(T.load(cfg.checkpoint))
    model.eval()
    if cfg.device == "cuda":
        model.cuda()
    logger.info(f"[model={cfg.checkpoint}] on {cfg.device}")

    dataloader = DataLoader(dataset=ImageFolder720p(cfg.dataset_path),
                            batch_size=1,
                            shuffle=cfg.shuffle)
    logger.info(f"[dataset={cfg.dataset_path}]")

    loss_criterion = nn.MSELoss()

    for batch_idx, data in enumerate(dataloader, start=1):
        img, patches, _ = data
        if cfg.device == "cuda":
            patches = patches.cuda()

        if batch_idx % cfg.batch_every == 0:
            pass  # placeholder: no per-batch action in this example

        out = T.zeros(6, 10, 3, 128, 128)
        avg_loss = 0

        for i in range(6):
            for j in range(10):
                x = patches[:, :, i, j, :, :]  # already moved to cfg.device above
                y = model(x)
                out[i, j] = y.data

                loss = loss_criterion(y, x)
                avg_loss += (1 / 60) * loss.item()

        logger.debug("[%5d/%5d] avg_loss: %f", batch_idx, len(dataloader),
                     avg_loss)

        # save output: reassemble the 6x10 patch grid into a (3, 768, 1280) image
        out = out.permute(0, 3, 1, 4, 2).reshape(768, 1280, 3).permute(2, 0, 1)

        y = T.cat((img[0], out), dim=2)
        save_imgs(
            imgs=y.unsqueeze(0),
            to_size=(3, 768, 2 * 1280),
            name=exp_dir / f"out/test_{batch_idx}.png",
        )
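The permute/reshape steps above undo the 6x10 patch grid. A minimal shape check of that reassembly (a sketch, assuming the (6, 10, 3, 128, 128) layout used in this example):

import torch

out = torch.rand(6, 10, 3, 128, 128)
img = out.permute(0, 3, 1, 4, 2).reshape(768, 1280, 3).permute(2, 0, 1)
assert img.shape == (3, 768, 1280)
# patch (i, j) lands at rows 128*i:128*(i+1), cols 128*j:128*(j+1)
assert torch.equal(img[:, :128, :128], out[0, 0])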
Example 2
File: test.py Project: zy23456/cae
def test(cfg: Namespace) -> None:
    logger.info("=== Testing ===")

    # initial setup
    prologue(cfg)

    model = CAE()
    model.load_state_dict(torch.load(cfg.chkpt))
    model.eval()
    if cfg.device == "cuda":
        model.cuda()

    logger.info("Loaded model")

    dataset = ImageFolder720p(cfg.dataset_path)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=cfg.shuffle)

    logger.info("Loaded data")

    loss_criterion = nn.MSELoss()

    for batch_idx, data in enumerate(dataloader, start=1):
        img, patches, _ = data
        if cfg.device == 'cuda':
            patches = patches.cuda()

        if batch_idx % cfg.batch_every == 0:
            pass  # placeholder: no per-batch action in this example

        out = torch.zeros(6, 10, 3, 128, 128)
        # enc = torch.zeros(6, 10, 16, 8, 8)
        avg_loss = 0

        for i in range(6):
            for j in range(10):
                x = patches[:, :, i, j, :, :]  # Variable is deprecated; patches already on device
                y = model(x)
                out[i, j] = y.data

                loss = loss_criterion(y, x)
                avg_loss += (1 / 60) * loss.item()

        logger.debug('[%5d/%5d] avg_loss: %f' %
                     (batch_idx, len(dataloader), avg_loss))

        # save output: reassemble the 6x10 patch grid into a (3, 768, 1280) image
        out = out.permute(0, 3, 1, 4, 2).reshape(768, 1280, 3).permute(2, 0, 1)

        y = torch.cat((img[0], out), dim=2)
        save_imgs(
            imgs=y.unsqueeze(0),
            to_size=(3, 768, 2 * 1280),
            name=f"../experiments/{cfg.exp_name}/out/test_{batch_idx}.png")

    # final setup
    epilogue(cfg)
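prologue and epilogue are not defined in these listings. Judging only from the call sites (prologue(cfg) returns a writer in the train examples; epilogue(cfg) or epilogue(cfg, writer) runs at the end), a plausible minimal pair might look like this; the directory layout is a guess:

import os
from argparse import Namespace
from torch.utils.tensorboard import SummaryWriter

def prologue(cfg: Namespace) -> SummaryWriter:
    # hypothetical: create the experiment tree and a tensorboard writer
    exp_dir = f"../experiments/{cfg.exp_name}"
    for d in ("out", "chkpt", "logs"):
        os.makedirs(os.path.join(exp_dir, d), exist_ok=True)
    return SummaryWriter(os.path.join(exp_dir, "logs"))

def epilogue(cfg: Namespace, writer: SummaryWriter = None) -> None:
    # hypothetical: final cleanup; close the writer if one was created
    if writer is not None:
        writer.close()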
Example 3
def test(cfg: Namespace) -> None:
    assert cfg.checkpoint not in [None, ""]
    assert cfg.device == "cpu" or (cfg.device == "cuda"
                                   and T.cuda.is_available())

    exp_dir = ROOT_EXP_DIR / cfg.exp_name
    os.makedirs(exp_dir / "out", exist_ok=True)
    cfg.to_file(exp_dir / "test_config.json")
    logger.info(f"[exp dir={exp_dir}]")

    model = CAE()
    model.load_state_dict(T.load(cfg.checkpoint))
    model.eval()
    if cfg.device == "cuda":
        model.cuda()
    logger.info(f"[model={cfg.checkpoint}] on {cfg.device}")

    dataloader = DataLoader(dataset=ImageFolder720p(cfg.dataset_path),
                            batch_size=1,
                            shuffle=cfg.shuffle)
    logger.info(f"[dataset={cfg.dataset_path}]")

    loss_criterion = nn.MSELoss()

    for batch_idx, data in enumerate(dataloader, start=1):
        img, patches, _ = data
        print('the patches shape is:', patches.shape)
        # print(_)
        # plt.imshow(patches[0,:,3,1,:,:].permute(1,2,0))
        # # plt.imshow(patches[0].permute(1,2,0))
        # plt.show()
        if cfg.device == "cuda":
            patches = patches.cuda()

        if batch_idx % cfg.batch_every == 0:
            pass  # placeholder: no per-batch action in this example

        out = T.zeros(6, 10, 3, 128, 128)
        avg_loss = 0
        foo = []
        for i in range(6):
            for j in range(10):
                x = patches[:, :, i, j, :, :]  # already moved to cfg.device above
                print('the x shape is:', x.shape)
                y = model(x)
                print('the y shape is:', y.shape)
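The shape printed above comes from the dataset, not the model. ImageFolder720p is not shown in these listings; a sketch of how it plausibly builds that patches tensor, assuming a plain non-overlapping 128x128 split of a 768x1280 frame (720p padded to a multiple of 128):

import torch

img = torch.rand(3, 768, 1280)                         # one padded frame
patches = img.unfold(1, 128, 128).unfold(2, 128, 128)  # (3, 6, 10, 128, 128)
patches = patches.unsqueeze(0)                         # add the batch dim
print(patches.shape)  # torch.Size([1, 3, 6, 10, 128, 128])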
Example 4
def train(cfg: Namespace) -> None:
    logger.info("=== Training ===")

    # initial setup
    writer = prologue(cfg)

    # train-related code
    model = CAE()
    model.train()
    if cfg.device == "cuda":
        model.cuda()
    logger.debug(f"Model loaded on {cfg.device}")

    dataset = ImageFolder720p(cfg.dataset_path)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.batch_size,
                            shuffle=cfg.shuffle,
                            num_workers=cfg.num_workers)
    logger.debug("Data loaded")

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.learning_rate,
                           weight_decay=1e-5)
    loss_criterion = nn.MSELoss()
    # scheduler = ...

    avg_loss, epoch_avg = 0.0, 0.0
    ts = 0

    # train-loop
    for epoch_idx in range(cfg.start_epoch, cfg.num_epochs + 1):

        # scheduler.step()

        for batch_idx, data in enumerate(dataloader, start=1):
            img, y, patches, _ = data

            if cfg.device == "cuda":
                patches = patches.cuda()

            avg_loss_per_image = 0.0  # initialize the per-image loss
            for i in range(1):
                for j in range(1):
                    optimizer.zero_grad()

                    x = patches[:, :, i, j, :, :]
                    y = model(x)
                    loss = loss_criterion(y, x)

                    avg_loss_per_image += loss.item()  # one 256x256 patch per image

                    loss.backward()
                    optimizer.step()

            avg_loss += avg_loss_per_image
            epoch_avg += avg_loss_per_image

            if batch_idx % cfg.batch_every == 0:
                writer.add_scalar("train/avg_loss", avg_loss / cfg.batch_every,
                                  ts)

                for name, param in model.named_parameters():
                    writer.add_histogram(name, param, ts)

                logger.debug('[%3d/%3d][%5d/%5d] avg_loss: %.8f' %
                             (epoch_idx, cfg.num_epochs, batch_idx,
                              len(dataloader), avg_loss / cfg.batch_every))
                avg_loss = 0.0
                ts += 1

            if batch_idx % cfg.save_every == 0:
                # out = torch.zeros(6, 10, 3, 128, 128)
                out = torch.zeros(1, 1, 3, 256, 256)
                for i in range(1):
                    for j in range(1):
                        x = patches[0, :, i, j, :, :].unsqueeze(0)  # already on cfg.device
                        out[i, j] = model(x).cpu().data

                # reassemble the (here 1x1) patch grid into a (3, 256, 256) image
                out = out.permute(0, 3, 1, 4, 2).reshape(256, 256, 3).permute(2, 0, 1)

                y = torch.cat((img[0], out), dim=2).unsqueeze(0)
                # save_imgs(imgs=y, to_size=(3, 768, 2 * 1280), name=f"/data2/TDL/paper_fabric/workdir/4_30_short_eassy/{cfg.exp_name}/out/out_{epoch_idx}_{batch_idx}.png")
                save_imgs(
                    imgs=y,
                    to_size=(3, 256, 2 * 256),
                    name=f"/data2/TDL/paper_fabric/workdir/4_30_short_eassy/{cfg.exp_name}/out/out_{epoch_idx}_{batch_idx}.png",
                )

        # -- batch-loop

        if epoch_idx % cfg.epoch_every == 0:
            epoch_avg /= (len(dataloader) * cfg.epoch_every)

            writer.add_scalar("train/epoch_avg_loss", epoch_avg,
                              epoch_idx // cfg.epoch_every)

            logger.info("Epoch avg = %.8f" % epoch_avg)
            epoch_avg = 0.0
            torch.save(
                model.state_dict(),
                f"/data2/TDL/paper_fabric/workdir/4_30_short_eassy/{cfg.exp_name}/chkpt/model_{epoch_idx}.pth"
            )

    # -- train-loop

    # save final model
    torch.save(
        model.state_dict(),
        f"/data2/TDL/paper_fabric/workdir/4_30_short_eassy/{cfg.exp_name}/model_final.pth"
    )

    # final setup
    epilogue(cfg, writer)
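The CAE model these functions train is not included in the listings. As a purely illustrative stand-in (not the project's actual architecture), a convolutional autoencoder that maps a (B, 3, 128, 128) patch back to its own shape could look like:

import torch.nn as nn

class TinyCAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=2, padding=1),   # 128 -> 64
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),  # 64 -> 32
            nn.ReLU(inplace=True),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 32 -> 64
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1),   # 64 -> 128
            nn.Sigmoid(),  # keep reconstructions in [0, 1]
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))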
Example 5
def train(cfg: Namespace) -> None:
    print(cfg.device)
    assert cfg.device == 'cpu' or (cfg.device == 'cuda'
                                   and T.cuda.is_available())

    logger.info('training: experiment %s' % (cfg.exp_name))

    # make dir-tree
    exp_dir = ROOT_DIR / 'experiments' / cfg.exp_name

    for d in ['out', 'checkpoint', 'logs']:
        os.makedirs(exp_dir / d, exist_ok=True)

    cfg.to_file(exp_dir / 'train_config.txt')

    # tb writer
    writer = SummaryWriter(exp_dir / 'logs')

    model = CAE()
    model.train()
    if cfg.device == 'cuda':
        model.cuda()
    logger.info(f'loaded model on {cfg.device}')

    dataset = ImageFolder720p(cfg.dataset_path)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.batch_size,
                            shuffle=cfg.shuffle,
                            num_workers=cfg.num_workers)
    logger.info('loaded dataset')

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.learning_rate,
                           weight_decay=1e-5)
    loss_criterion = nn.MSELoss()

    avg_loss, epoch_avg = 0.0, 0.0
    ts = 0

    # EPOCHS
    for epoch_idx in range(cfg.start_epoch, cfg.num_epochs + 1):
        # BATCHES
        for batch_idx, data in enumerate(dataloader, start=1):
            img, patches, _ = data

            if cfg.device == 'cuda':
                patches = patches.cuda()

            avg_loss_per_image = 0.0
            for i in range(6):
                for j in range(10):
                    optimizer.zero_grad()

                    x = patches[:, :, i, j, :, :]
                    y = model(x)
                    loss = loss_criterion(y, x)

                    avg_loss_per_image += (1 / 60) * loss.item()

                    loss.backward()
                    optimizer.step()

            avg_loss += avg_loss_per_image
            epoch_avg += avg_loss_per_image

            if batch_idx % cfg.batch_every == 0:
                writer.add_scalar('train/avg_loss', avg_loss / cfg.batch_every,
                                  ts)

                for name, param in model.named_parameters():
                    writer.add_histogram(name, param, ts)

                logger.debug('[%3d/%3d][%5d/%5d] avg_loss: %.8f' %
                             (epoch_idx, cfg.num_epochs, batch_idx,
                              len(dataloader), avg_loss / cfg.batch_every))

                avg_loss = 0.0
                ts += 1
            # -- end batch every

            if batch_idx % cfg.save_every == 0:
                out = T.zeros(6, 10, 3, 128, 128)
                for i in range(6):
                    for j in range(10):
                        x = patches[0, :, i, j, :, :].unsqueeze(0)  # already on cfg.device
                        out[i, j] = model(x).cpu().data

                # reassemble the 6x10 patch grid into a (3, 768, 1280) image
                out = out.permute(0, 3, 1, 4, 2).reshape(768, 1280, 3).permute(2, 0, 1)

                y = T.cat((img[0], out), dim=2).unsqueeze(0)
                save_imgs(imgs=y,
                          to_size=(3, 768, 2 * 1280),
                          name=exp_dir / f'out/{epoch_idx}_{batch_idx}.png')
            # -- end save every
        # -- end batches

        if epoch_idx % cfg.epoch_every == 0:
            epoch_avg /= (len(dataloader) * cfg.epoch_every)

            writer.add_scalar('train/epoch_avg_loss',
                              epoch_avg,
                              epoch_idx // cfg.epoch_every)

            logger.info('Epoch avg = %.8f' % epoch_avg)
            epoch_avg = 0.0

            T.save(model.state_dict(),
                   exp_dir / f'checkpoint/model_{epoch_idx}.state')
        # -- end epoch every

    # -- end epoch

    # save final model
    T.save(model.state_dict(), exp_dir / 'model_final.state')

    # cleaning
    writer.close()
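save_imgs is also undefined here. Given its call sites (imgs of shape (1, C, H, W), to_size=(C, H, W), name=path), a plausible minimal version, sketched under those assumptions:

import torch
import torch.nn.functional as F
import torchvision.utils as vutils

def save_imgs(imgs: torch.Tensor, to_size, name) -> None:
    # resize to the requested (C, H, W), then write the tensor to disk
    _, h, w = to_size
    imgs = F.interpolate(imgs, size=(h, w), mode="bilinear", align_corners=False)
    vutils.save_image(imgs, str(name))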
Example 6
def train(cfg: Namespace) -> None:
    assert cfg.device == "cpu" or (cfg.device == "cuda"
                                   and T.cuda.is_available())

    root_dir = Path(__file__).resolve().parents[1]

    logger.info("training: experiment %s" % (cfg.exp_name))

    # make dir-tree
    exp_dir = root_dir / "experiments" / cfg.exp_name

    for d in ["out", "checkpoint", "logs"]:
        os.makedirs(exp_dir / d, exist_ok=True)

    cfg.to_file(exp_dir / "train_config.json")

    # tensorboard writer
    tb_writer = SummaryWriter(exp_dir / "logs")
    logger.info("started tensorboard writer")

    model = CAE()
    model.train()
    if cfg.device == "cuda":
        model.cuda()
    logger.info(f"loaded model on {cfg.device}")

    dataloader = DataLoader(
        dataset=ImageFolder720p(cfg.dataset_path),
        batch_size=cfg.batch_size,
        shuffle=cfg.shuffle,
        num_workers=cfg.num_workers,
    )
    logger.info(f"loaded dataset from {cfg.dataset_path}")

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.learning_rate,
                           weight_decay=1e-5)
    loss_criterion = nn.MSELoss()

    avg_loss, epoch_avg = 0.0, 0.0
    ts = 0

    # EPOCHS
    for epoch_idx in range(cfg.start_epoch, cfg.num_epochs + 1):
        # BATCHES
        for batch_idx, data in enumerate(dataloader, start=1):
            img, patches, _ = data

            if cfg.device == "cuda":
                patches = patches.cuda()

            avg_loss_per_image = 0.0
            for i in range(6):
                for j in range(10):
                    optimizer.zero_grad()

                    x = patches[:, :, i, j, :, :]
                    y = model(x)
                    loss = loss_criterion(y, x)

                    avg_loss_per_image += (1 / 60) * loss.item()

                    loss.backward()
                    optimizer.step()

            avg_loss += avg_loss_per_image
            epoch_avg += avg_loss_per_image

            if batch_idx % cfg.batch_every == 0:
                tb_writer.add_scalar("train/avg_loss",
                                     avg_loss / cfg.batch_every, ts)

                for name, param in model.named_parameters():
                    tb_writer.add_histogram(name, param, ts)

                logger.debug("[%3d/%3d][%5d/%5d] avg_loss: %.8f" % (
                    epoch_idx,
                    cfg.num_epochs,
                    batch_idx,
                    len(dataloader),
                    avg_loss / cfg.batch_every,
                ))

                avg_loss = 0.0
                ts += 1
            # -- end batch every

            if batch_idx % cfg.save_every == 0:
                out = T.zeros(6, 10, 3, 128, 128)
                for i in range(6):
                    for j in range(10):
                        x = patches[0, :, i, j, :, :].unsqueeze(0)  # already on cfg.device
                        out[i, j] = model(x).cpu().data

                # reassemble the 6x10 patch grid into a (3, 768, 1280) image
                out = out.permute(0, 3, 1, 4, 2).reshape(768, 1280, 3).permute(2, 0, 1)

                y = T.cat((img[0], out), dim=2).unsqueeze(0)
                save_imgs(
                    imgs=y,
                    to_size=(3, 768, 2 * 1280),
                    name=exp_dir / f"out/{epoch_idx}_{batch_idx}.png",
                )
            # -- end save every
        # -- end batches

        if epoch_idx % cfg.epoch_every == 0:
            epoch_avg /= len(dataloader) * cfg.epoch_every

            tb_writer.add_scalar(
                "train/epoch_avg_loss",
                epoch_avg,
                epoch_idx // cfg.epoch_every,
            )

            logger.info("Epoch avg = %.8f" % epoch_avg)
            epoch_avg = 0.0

            T.save(model.state_dict(),
                   exp_dir / f"checkpoint/model_{epoch_idx}.pth")
        # -- end epoch every
    # -- end epoch

    # save final model
    T.save(model.state_dict(), exp_dir / "model_final.pth")

    # cleaning
    tb_writer.close()
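All of these functions read the same kind of cfg: Namespace. An illustrative configuration covering the fields they use (values are placeholders; note the examples disagree on checkpoint vs. chkpt for the weights path):

from argparse import Namespace

cfg = Namespace(
    exp_name="cae_baseline",
    device="cuda",                        # or "cpu"
    dataset_path="/path/to/720p/frames",
    checkpoint="model_final.pth",         # some examples call this cfg.chkpt
    batch_size=8,
    shuffle=True,
    num_workers=4,
    learning_rate=1e-4,
    start_epoch=1,
    num_epochs=100,
    batch_every=10,                       # log every N batches
    save_every=100,                       # dump reconstructions every N batches
    epoch_every=1,                        # checkpoint every N epochs
)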
Example 7
def test(cfg: Namespace) -> None:
    logger.info("=== Testing ===")

    if cfg.alg == "32":
        from autoencoder2bpp import AutoEncoder
    elif cfg.alg == "16":
        from autoencoder025bpp import AutoEncoder
    elif cfg.alg == "8":
        from autoencoder006bpp import AutoEncoder

    # initial setup
    prologue(cfg)

    model = AutoEncoder(cudaD=(cfg.device == "cuda"))
    model.load_state_dict(torch.load(cfg.chkpt))
    model.eval()
    if cfg.device == "cuda":
        model.cuda()

    logger.info("Loaded model")

    dataset = ImageFolder720p(cfg.dataset_path)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=cfg.shuffle)

    logger.info("Loaded data")

    loss_criterion = nn.MSELoss()
    if cfg.weigh_path is not None:
        loss_criterion = weighted_loss_function

    for batch_idx, data in enumerate(dataloader, start=1):
        img, patches, _, weights = data
        if cfg.device == 'cuda':
            patches = patches.cuda()

        avg_loss = 0

        x = patches  # already on cfg.device; Variable is deprecated
        y = model(x)
        out = y.data.cpu()

        if cfg.weigh_path:
            w = weights[:, None, :, :]
            w = torch.cat((w, w, w), dim=1)  # broadcast the weight map to 3 channels
            if cfg.device == "cuda":
                w = w.cuda()
            loss = loss_criterion(y, x, w)
        else:
            loss = loss_criterion(y, x)
        avg_loss += loss.item()

        logger.debug('[%5d/%5d] avg_loss: %f' %
                     (batch_idx, len(dataloader), avg_loss))

        # save output
        out = out.reshape(3, 256, 256)

        #print(model.encoded)
        concat = torch.cat((img[0], out), dim=2).unsqueeze(0)
        save_imgs(
            imgs=concat,
            to_size=(3, 256, 2 * 256),
            name=f"../experiments/{cfg.exp_name}/out_test/test_{batch_idx}.png"
        )

    # final setup
    epilogue(cfg)
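Example 7 picks an AutoEncoder variant by target bitrate (cfg.alg selects roughly 2, 0.25, or 0.06 bpp, judging by the module names). For a feel for the arithmetic, take the latent shape commented out in Example 2 (16x8x8 per 128x128 patch) and assume 8-bit quantization of both pixels and latents (the quantizer itself is not shown):

patch_bits = 3 * 128 * 128 * 8    # 393216 bits per raw RGB patch
latent_bits = 16 * 8 * 8 * 8      # 8192 bits per latent code
print(patch_bits / latent_bits)   # 48.0x compression
print(latent_bits / (128 * 128))  # 0.5 bits per pixel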
Example 8
def train(cfg: Namespace) -> None:
    logger.info("=== Training ===")

    # initial setup
    writer = prologue(cfg)

    # train-related code
    model = CAE_Classify()
    model.load_state_dict(torch.load(cfg.load_from), strict=False)  # load pretrained weights
    model.train()
    if cfg.device == "cuda":
        model.cuda()
    logger.debug(f"Model loaded on {cfg.device}")

    dataset = ImageFolder720p(cfg.dataset_path)
    test_dataset = ImageFolder720p(cfg.test_dataset_path)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=50,
                                 shuffle=cfg.shuffle,
                                 num_workers=cfg.num_workers)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.batch_size,
                            shuffle=cfg.shuffle,
                            num_workers=cfg.num_workers)
    logger.debug("Data loaded")

    # grab one fixed test batch for periodic accuracy checks
    for index, data in enumerate(test_dataloader):
        if index > 0:
            break
        _, test_y, test_patch, _ = data

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                           lr=cfg.learning_rate,
                           weight_decay=1e-5)
    loss_criterion = torch.nn.CrossEntropyLoss()
    # scheduler = ...

    avg_loss, epoch_avg = 0.0, 0.0
    ts = 0

    # train-loop
    for epoch_idx in range(cfg.start_epoch, cfg.num_epochs + 1):

        # scheduler.step()

        for batch_idx, data in enumerate(dataloader, start=1):
            img, y, patches, _ = data

            if cfg.device == "cuda":
                patches = patches.cuda()

            avg_loss_per_image = 0.0  # initialize the per-image loss
            optimizer.zero_grad()

            x = Variable(patches[:, :, 0, 0, :, :])
            pred_y = model(x)

            # loss = F.nll_loss(pred_y, y.cuda())
            loss = loss_criterion(pred_y, y.cuda())

            avg_loss_per_image += loss.item()

            loss.backward()
            optimizer.step()

            avg_loss += avg_loss_per_image
            epoch_avg += avg_loss_per_image

            if batch_idx % cfg.batch_every == 0:
                writer.add_scalar("train/avg_loss", avg_loss / cfg.batch_every, ts)

                x = test_patch[:, :, 0, 0, :, :].cuda()
                test_output = model(x).cpu()
                pred_y = torch.max(test_output, 1)[1].cpu().numpy()
                accuracy = float((pred_y == test_y.cpu().numpy()).astype(int).sum()) / float(test_y.size(0))


                for name, param in model.named_parameters():
                    writer.add_histogram(name, param, ts)

                logger.debug(
                    '[%3d/%3d][%5d/%5d] avg_loss: %.8f accuracy: %.8f'%
                    (epoch_idx, cfg.num_epochs, batch_idx, len(dataloader), avg_loss / cfg.batch_every, accuracy)
                )
                avg_loss = 0.0
                ts += 1


        # -- batch-loop

        if epoch_idx % cfg.epoch_every == 0:
            epoch_avg /= (len(dataloader) * cfg.epoch_every)
            accuracy = 0.0
            for index, data in enumerate(test_dataloader):
                _, test_y, patches, _ = data
                x = patches[:, :, 0, 0, :, :].cuda()
                test_output = model(x)
                pred_y = torch.max(test_output, 1)[1].cpu().numpy()
                accuracy += float((test_y.cpu().numpy() == pred_y).astype(int).sum()) / float(test_y.size(0))
            accuracy /= (index + 1)  # enumerate starts at 0, so index + 1 batches were averaged

            writer.add_scalar("train/epoch_avg_loss", epoch_avg,
                              epoch_idx // cfg.epoch_every)

            logger.info("Epoch avg = %.8f  Accuracy = %.8f" % (epoch_avg, accuracy))
            epoch_avg = 0.0
            torch.save(
                model.state_dict(),
                f"/data2/TDL/paper_fabric/workdir/4_30_short_eassy/{cfg.exp_name}/chkpt/model_{epoch_idx}.pth"
            )

    # -- train-loop

    # save final model
    torch.save(
        model.state_dict(),
        f"/data2/TDL/paper_fabric/workdir/4_30_short_eassy/{cfg.exp_name}/model_final.pth"
    )

    # final setup
    epilogue(cfg, writer)
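Example 8 loads autoencoder weights into a classifier with strict=False, which skips state-dict entries that do not line up instead of raising. A minimal sketch of that pattern with toy modules (all names here are illustrative):

import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)  # same key as in Encoder
        self.head = nn.Linear(8, 10)               # new; absent from the checkpoint

state = Encoder().state_dict()
model = Classifier()
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias'] stay at their init
print(result.unexpected_keys)  # []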
Example 9
def train(cfg: Namespace) -> None:

    # initial setup
    writer = prologue(cfg)
    logger.info("=== Training ===")

    if cfg.alg == "32":
        from autoencoder2bpp import AutoEncoder
    elif cfg.alg == "16":
        from autoencoder025bpp import AutoEncoder
    elif cfg.alg == "8":
        from autoencoder006bpp import AutoEncoder

    # train-related code
    if cfg.device == "cuda":
        model = AutoEncoder(cudaD=True)
    else:
        model = AutoEncoder(cudaD=False)

    if cfg.chkpt:
        model.load_state_dict(torch.load(cfg.chkpt))
    model.train()
    if cfg.device == "cuda":
        model.cuda()

    logger.debug("Model loaded")

    dataset = ImageFolder720p(cfg.dataset_path, cfg.weigh_path)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.batch_size,
                            shuffle=cfg.shuffle,
                            num_workers=cfg.num_workers)

    logger.debug("Data loaded")

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.learning_rate,
                           weight_decay=1e-5)
    loss_criterion = nn.MSELoss()
    if cfg.weigh_path is not None:
        loss_criterion = weighted_loss_function

    avg_loss, epoch_avg = 0.0, 0.0
    ts = 0

    # train-loop
    for epoch_idx in range(cfg.start_epoch, cfg.num_epochs + 1):

        # scheduler.step()

        for batch_idx, data in enumerate(dataloader, start=1):
            img, patches, _, weights = data

            if cfg.device == "cuda":
                patches = patches.cuda()

            avg_loss_per_image = 0.0

            optimizer.zero_grad()

            x = Variable(patches)
            y = model(x)

            if cfg.weigh_path:
                w = Variable(weights)
                w = w[:, None, :, :]
                w = torch.cat((w, w, w), dim=1)
                if cfg.device == "cuda":
                    w = w.cuda()
                loss = loss_criterion(y, x, w)
            else:
                loss = loss_criterion(y, x)

            avg_loss_per_image += loss.item()

            loss.backward()
            optimizer.step()

            avg_loss += avg_loss_per_image
            epoch_avg += avg_loss_per_image

            if batch_idx % cfg.batch_every == 0:
                writer.add_scalar("train/avg_loss", avg_loss / cfg.batch_every,
                                  ts)

                for name, param in model.named_parameters():
                    writer.add_histogram(name, param, ts)

                logger.debug('[%3d/%3d][%5d/%5d] avg_loss: %.8f' %
                             (epoch_idx, cfg.num_epochs, batch_idx,
                              len(dataloader), avg_loss / cfg.batch_every))
                avg_loss = 0.0
                ts += 1

            if batch_idx % cfg.save_every == 0:
                x = patches[0, :, :, :].unsqueeze(0)  # patches are already on cfg.device
                out = model(x).cpu().data.reshape(3, 256, 256)

                y = torch.cat((img[0], out), dim=2).unsqueeze(0)
                save_imgs(
                    imgs=y,
                    to_size=(3, 256, 2 * 256),
                    name=f"../experiments/{cfg.exp_name}/out/out_{epoch_idx}_{batch_idx}.png"
                )

        # -- batch-loop

        if epoch_idx % cfg.epoch_every == 0:
            epoch_avg /= (len(dataloader) * cfg.epoch_every)

            writer.add_scalar("train/epoch_avg_loss",
                              epoch_avg,
                              epoch_idx // cfg.epoch_every)

            logger.info("Epoch avg = %.8f" % epoch_avg)
            epoch_avg = 0.0
            torch.save(
                model.state_dict(),
                f"../experiments/{cfg.exp_name}/chkpt/model_{epoch_idx}.pth")

    # -- train-loop

    # save final model
    torch.save(model.state_dict(),
               f"../experiments/{cfg.exp_name}/model_final.pth")

    # final setup
    epilogue(cfg, writer)
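weighted_loss_function, used by Examples 7 and 9 when cfg.weigh_path is set, is not shown either. A plausible version matching the loss_criterion(y, x, w) call sites is a per-pixel weighted MSE (the surrounding code already broadcasts w to three channels):

import torch

def weighted_loss_function(y: torch.Tensor, x: torch.Tensor,
                           w: torch.Tensor) -> torch.Tensor:
    # scale each pixel's squared error by its weight, then average
    return torch.mean(w * (y - x) ** 2)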