Example #1
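# Imports required by this example. Project-local modules (sepconv, calcPSNR,
# mydataset, transform_train, Network, Opter, weights_init) are assumed to be
# importable from the surrounding project.
import os
import time

import torch
import torch.optim as optim
from torch.autograd import Variable as var
from tensorboardX import SummaryWriter  # or: from torch.utils.tensorboard import SummaryWriter
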
def main(lr, batch_size, epoch, gpu, train_set, valid_set):

    # ------------- Part for tensorboard --------------
    writer = SummaryWriter(comment="_equal_CZAR")
    # ------------- Part for tensorboard --------------

    # -------------- Some prepare ---------------------
    torch.backends.cudnn.enabled = True
    torch.cuda.set_device(gpu)
    # torch.set_default_tensor_type('torch.cuda.FloatTensor')
    # -------------- Some prepare ---------------------

    BATCH_SIZE = batch_size
    EPOCH = epoch

    LEARNING_RATE = lr
    beta1 = 0.9
    beta2 = 0.999

    trainset = mydataset(train_set, transform_train)
    valset = mydataset(valid_set)
    trainLoader = torch.utils.data.DataLoader(trainset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=True)
    valLoader = torch.utils.data.DataLoader(valset,
                                            batch_size=1,
                                            shuffle=False)

    opter = Opter(128, 128, batch_size)

    SepConvNet = Network(opter).cuda()
    SepConvNet.apply(weights_init)
    # SepConvNet.load_state_dict(torch.load('/mnt/hdd/xiasifeng/sepconv/sepconv_mutiscale_LD/SepConv_iter33-ltype_fSATD_fs-lr_0.001-trainloss_0.1497-evalloss_0.1357-evalpsnr_29.6497.pkl'))

    # SepConvNet_cost = nn.MSELoss().cuda()
    # SepConvNet_cost = nn.L1Loss().cuda()
    SepConvNet_cost = sepconv.SATDLoss().cuda()
    SepConvNet_optimizer = optim.Adamax(SepConvNet.parameters(),
                                        lr=LEARNING_RATE,
                                        betas=(beta1, beta2))
    SepConvNet_schedule = optim.lr_scheduler.ReduceLROnPlateau(
        SepConvNet_optimizer,
        mode='max',  # stepped with validation PSNR, where higher is better
        factor=0.1,
        patience=3,
        verbose=True,
        min_lr=1e-5)

    # ----------------  Time part -------------------
    start_time = time.time()
    global_step = 0
    # ----------------  Time part -------------------

    for epoch in range(0, EPOCH):
        SepConvNet.train().cuda()
        cnt = 0
        sumloss = 0.0  # accumulates loss over the whole training set
        tsumloss = 0.0  # accumulates loss over one print interval
        printinterval = 300
        print("---------------[Epoch%3d]---------------" % (epoch + 1))
        for imgL, imgR, label in trainLoader:
            global_step = global_step + 1
            cnt = cnt + 1
            SepConvNet_optimizer.zero_grad()

            imgL = var(imgL).cuda()
            imgR = var(imgR).cuda()
            label = var(label).cuda()
            with torch.no_grad():
                # Note: the optical flow is computed in the backward
                # direction here (imgR -> imgL).
                diff = opter.calcOpt(imgR, imgL)

            warped, output = SepConvNet(diff, imgL, imgR)
            loss_out = SepConvNet_cost(output, label)
            loss_warp = SepConvNet_cost(warped, label)
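            # Supervise both the intermediate warped frame and the refined
            # final output; the 0.5/0.5 weights treat them equally.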
            loss = 0.5 * loss_out + 0.5 * loss_warp
            loss.backward()
            SepConvNet_optimizer.step()
            sumloss = sumloss + loss_out.data.item()
            tsumloss = tsumloss + loss_out.data.item()

            if cnt % printinterval == 0:
                writer.add_image("Target image", label[0], cnt)
                writer.add_image("Warped image", warped[0], cnt)
                writer.add_image("Final image", output[0], cnt)
                writer.add_scalar('Train Batch SATD loss',
                                  loss_out.data.item(),
                                  int(global_step / printinterval))
                writer.add_scalar('Train Interval SATD loss',
                                  tsumloss / printinterval,
                                  int(global_step / printinterval))
                print(
                    'Epoch [%d/%d], Iter [%d/%d], Time [%4.4f], Batch loss [%.6f], Interval loss [%.6f]'
                    % (epoch + 1, EPOCH, cnt, len(trainset) // BATCH_SIZE,
                       time.time() - start_time, loss_out.data.item(),
                       tsumloss / printinterval))
                tsumloss = 0.0
        print('Epoch [%d/%d], iter: %d, Time [%4.4f], Avg Loss [%.6f]' %
              (epoch + 1, EPOCH, cnt, time.time() - start_time, sumloss / cnt))

        # ---------------- Part for validation ----------------
        trainloss = sumloss / cnt
        SepConvNet.eval().cuda()
        evalcnt = 0
        pos = 0.0
        sumloss = 0.0
        psnr = 0.0
        for imgL, imgR, label in valLoader:
            imgL = var(imgL).cuda()
            imgR = var(imgR).cuda()
            label = var(label).cuda()
            with torch.no_grad():
                # Note: the optical flow is computed in the backward
                # direction here (imgR -> imgL).
                diff = opter.calcOpt(imgR, imgL)
                warped, output = SepConvNet(diff, imgL, imgR)
                loss_out = SepConvNet_cost(output, label)
                loss_warp = SepConvNet_cost(warped, label)
                loss = 0.5 * loss_out + 0.5 * loss_warp
                sumloss = sumloss + loss_out.data.item()
                psnr = psnr + calcPSNR.calcPSNR(output.cpu().data.numpy(),
                                                label.cpu().data.numpy())
                evalcnt = evalcnt + 1
        # ------------- Tensorboard part -------------
        writer.add_scalar("Valid SATD loss", sumloss / evalcnt, epoch)
        writer.add_scalar("Valid PSNR", psnr / valset.__len__(), epoch)
        # ------------- Tensorboard part -------------
        print('Validation loss [%.6f],  Average PSNR [%.4f]' %
              (sumloss / evalcnt, psnr / valset.__len__()))
        SepConvNet_schedule.step(psnr / valset.__len__())
        torch.save(
            SepConvNet.state_dict(),
            os.path.join(
                '.', 'equal_CZAR_iter' + str(epoch + 1) + '-ltype_fSATD_fs' +
                '-lr_' + str(LEARNING_RATE) + '-trainloss_' +
                str(round(trainloss, 4)) + '-evalloss_' +
                str(round(sumloss / evalcnt, 4)) + '-evalpsnr_' +
                str(round(psnr / len(valset), 4)) + '.pkl'))
    writer.close()
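
A minimal command-line wrapper for main (a sketch: the flag names and defaults here are assumptions, not part of the original script):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--train_set', type=str, required=True)
    parser.add_argument('--valid_set', type=str, required=True)
    args = parser.parse_args()
    main(args.lr, args.batch_size, args.epoch, args.gpu,
         args.train_set, args.valid_set)
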
Example #2
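# Imports required by the training variants below. Project-local modules
# (sepconv, calcPSNR, mydataset, vimeodataset, transform_train, Network,
# Opter) are assumed to be importable from the surrounding project.
import os
import time

import torch
import torch.nn.functional as func
import torch.optim as optim
from torch.autograd import Variable as var
from tensorboardX import SummaryWriter  # or: from torch.utils.tensorboard import SummaryWriter
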
def main(lr, batch_size, epoch, gpu, train_set, valid_set):
    # ------------- Part for tensorboard --------------
    writer = SummaryWriter(log_dir='tb/ft1_baseline_mask')
    # ------------- Part for tensorboard --------------
    torch.backends.cudnn.enabled = True
    torch.cuda.set_device(gpu)

    BATCH_SIZE = batch_size
    EPOCH = epoch

    LEARNING_RATE = lr
    beta1 = 0.9
    beta2 = 0.999

    trainset = mydataset(train_set, transform_train)
    valset = mydataset(valid_set)
    trainLoader = torch.utils.data.DataLoader(trainset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=True)
    valLoader = torch.utils.data.DataLoader(valset,
                                            batch_size=1,
                                            shuffle=False)

    SepConvNet = Network().cuda()
    # SepConvNet.apply(weights_init)
    SepConvNet.load_state_dict(
        torch.load(
            '/mnt/hdd/iku/ISCAS/train/mask_baseline_iter52-ltype_fSATD_fs-lr_0.001-trainloss_0.1279-evalloss_0.1181-evalpsnr_29.6526.pkl'
        ))

    # MSE_cost = nn.MSELoss().cuda()
    # SepConvNet_cost = nn.L1Loss().cuda()
    SepConvNet_cost = sepconv.SATDLoss().cuda()
    SepConvNet_optimizer = optim.Adamax(SepConvNet.parameters(),
                                        lr=LEARNING_RATE,
                                        betas=(beta1, beta2))
    SepConvNet_schedule = optim.lr_scheduler.ReduceLROnPlateau(
        SepConvNet_optimizer,
        mode='max',  # stepped with validation PSNR, where higher is better
        factor=0.1,
        patience=3,
        verbose=True,
        min_lr=1e-7)

    # ----------------  Time part -------------------
    start_time = time.time()
    global_step = 0
    # ----------------  Time part -------------------

    # ---------------- Opt part -----------------------
    opter = Opter(gpu)
    # -------------------------------------------------

    for epoch in range(0, EPOCH):
        SepConvNet.train().cuda()
        cnt = 0
        sumloss = 0.0  # accumulates loss over the whole training set
        tsumloss = 0.0  # accumulates loss over one print interval
        printinterval = 300
        print("---------------[Epoch%3d]---------------" % (epoch + 1))
        for imgL, imgR, label in trainLoader:
            global_step = global_step + 1
            cnt = cnt + 1
            SepConvNet_optimizer.zero_grad()
            imgL = var(imgL).cuda()
            imgR = var(imgR).cuda()
            label = var(label).cuda()

            output = SepConvNet(imgL, imgR)
            loss = SepConvNet_cost(output, label)
            loss.backward()
            SepConvNet_optimizer.step()

            sumloss = sumloss + loss.data.item()
            tsumloss = tsumloss + loss.data.item()

            if cnt % printinterval == 0:
                writer.add_image("Ref image", imgR[0], cnt)
                writer.add_image("Pred image", output[0], cnt)
                writer.add_image("Target image", label[0], cnt)
                writer.add_scalar('Train Batch SATD loss', loss.data.item(),
                                  int(global_step / printinterval))
                writer.add_scalar('Train Interval SATD loss',
                                  tsumloss / printinterval,
                                  int(global_step / printinterval))
                print(
                    'Epoch [%d/%d], Iter [%d/%d], Time [%4.4f], Batch loss [%.6f], Interval loss [%.6f]'
                    % (epoch + 1, EPOCH, cnt, len(trainset) // BATCH_SIZE,
                       time.time() - start_time, loss.data.item(),
                       tsumloss / printinterval))
                tsumloss = 0.0
        print('Epoch [%d/%d], iter: %d, Time [%4.4f], Avg Loss [%.6f]' %
              (epoch + 1, EPOCH, cnt, time.time() - start_time, sumloss / cnt))

        # ---------------- Part for validation ----------------
        trainloss = sumloss / cnt
        SepConvNet.eval().cuda()
        evalcnt = 0
        pos = 0.0
        sumloss = 0.0
        psnr = 0.0
        for imgL, imgR, label in valLoader:
            imgL = var(imgL).cuda()
            imgR = var(imgR).cuda()
            label = var(label).cuda()
            with torch.no_grad():
                output = SepConvNet(imgL, imgR)
                loss = SepConvNet_cost(output, label)

                sumloss = sumloss + loss.data.item()
                psnr = psnr + calcPSNR.calcPSNR(output.cpu().data.numpy(),
                                                label.cpu().data.numpy())
                evalcnt = evalcnt + 1
        # ------------- Tensorboard part -------------
        writer.add_scalar("Valid SATD loss", sumloss / evalcnt, epoch)
        writer.add_scalar("Valid PSNR", psnr / valset.__len__(), epoch)
        # ------------- Tensorboard part -------------
        print('Validation loss [%.6f],  Average PSNR [%.4f]' %
              (sumloss / evalcnt, psnr / valset.__len__()))
        SepConvNet_schedule.step(psnr / valset.__len__())
        torch.save(
            SepConvNet.state_dict(),
            os.path.join(
                '.', 'ft1_mask_baseline_iter' + str(epoch + 1) +
                '-ltype_fSATD_fs' + '-lr_' + str(LEARNING_RATE) +
                '-trainloss_' + str(round(trainloss, 4)) + '-evalloss_' +
                str(round(sumloss / evalcnt, 4)) + '-evalpsnr_' +
                str(round(psnr / len(valset), 4)) + '.pkl'))
    writer.close()
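
The checkpoints saved above encode their training metrics directly in the filename. A small helper (hypothetical, shown here only to document the naming scheme) can recover them:

import re

def parse_ckpt_name(name):
    # Matches names like
    # 'ft1_mask_baseline_iter52-ltype_fSATD_fs-lr_0.001-trainloss_0.1279-evalloss_0.1181-evalpsnr_29.6526.pkl'
    m = re.search(r'iter(\d+)-ltype_\w+-lr_([\d.e-]+)-trainloss_([\d.]+)'
                  r'-evalloss_([\d.]+)-evalpsnr_([\d.]+)\.pkl$', name)
    if m is None:
        return None
    it, lr, trainloss, evalloss, psnr = m.groups()
    return {'iter': int(it), 'lr': float(lr), 'trainloss': float(trainloss),
            'evalloss': float(evalloss), 'evalpsnr': float(psnr)}
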
def main(lr, batch_size, epoch, gpu, train_set, valid_set):
    # ------------- Part for tensorboard --------------
    # writer = SummaryWriter(log_dir='tb/LSTM_ft1')
    # ------------- Part for tensorboard --------------
    torch.backends.cudnn.enabled = True
    torch.cuda.set_device(gpu)

    BATCH_SIZE = batch_size
    EPOCH = epoch

    LEARNING_RATE = lr
    beta1 = 0.9
    beta2 = 0.999

    trainset = vimeodataset(train_set, 'filelist.txt', transform_train)
    valset = vimeodataset(valid_set, 'test.txt')
    trainLoader = torch.utils.data.DataLoader(trainset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=True)
    valLoader = torch.utils.data.DataLoader(valset,
                                            batch_size=BATCH_SIZE,
                                            shuffle=False)
    assert len(valset) % BATCH_SIZE == 0


    SepConvNet = Network().cuda()
    # SepConvNet.apply(weights_init)
    SepConvNet.load_my_state_dict(
        torch.load(
            'tail_LSTM_iter15-ltype_fSATD_fs-lr_0.001-trainloss_0.6045-evalloss_0.1127-evalpsnr_30.2671.pkl',
            map_location='cuda:%d' % (gpu)))
    # SepConvNet.load_state_dict(torch.load('beta_LSTM_iter8-ltype_fSATD_fs-lr_0.001-trainloss_0.557-evalloss_0.1165-evalpsnr_29.8361.pkl'))
            
    # Note: children 0-27 make up the raw (pretrained) model; only the
    # children listed in grad_list keep requires_grad=True below.
    grad_list = [18, 19, 20, 21, 25, 26, 27]
    child_cnt = 0
    for child in SepConvNet.children():
        if child_cnt in grad_list:
            child_cnt += 1
            continue
        child_cnt += 1
        for param in child.parameters():
            param.requires_grad = False

    # cs = list(SepConvNet.children())
    # ps = list(cs[17].parameters())
    # IPython.embed()
    # exit()
    # MSE_cost = nn.MSELoss().cuda()
    # SepConvNet_cost = nn.L1Loss().cuda()
    SepConvNet_cost = sepconv.SATDLoss().cuda()
    # SepConvNet_optimizer = optim.Adamax(SepConvNet.parameters(), lr=LEARNING_RATE, betas=(beta1, beta2))
    SepConvNet_optimizer = optim.Adamax(filter(lambda p: p.requires_grad,
                                               SepConvNet.parameters()),
                                        lr=LEARNING_RATE,
                                        betas=(beta1, beta2))
    SepConvNet_schedule = optim.lr_scheduler.ReduceLROnPlateau(
        SepConvNet_optimizer,
        mode='max',  # stepped with validation PSNR, where higher is better
        factor=0.1,
        patience=3,
        verbose=True)

    # ----------------  Time part -------------------
    start_time = time.time()
    global_step = 0
    # ----------------  Time part -------------------


    # ---------------- Opt part -----------------------
    # opter = Opter(gpu)
    # -------------------------------------------------
    # print('[!] Ready to train!')
    # IPython.embed()

    for epoch in range(0, EPOCH):
        SepConvNet.train().cuda()
        cnt = 0
        sumloss = 0.0  # accumulates forward loss over the whole training set
        tsumloss = 0.0  # accumulates forward loss over one print interval

        sumloss_b = 0.0  # accumulates backward loss over the whole training set
        tsumloss_b = 0.0  # accumulates backward loss over one print interval

        printinterval = 500
        print("---------------[Epoch%3d]---------------" % (epoch + 1))
        for label_list in trainLoader:
            bad_list = label_list[7:]
            label_list = label_list[:7]
            # IPython.embed()
            # exit()
            global_step = global_step + 1
            cnt = cnt + 1

            for i in range(5):
                imgL = var(bad_list[i]).cuda()
                imgR = var(bad_list[i + 1]).cuda()
                poor_label = var(bad_list[i + 2]).cuda()

                label = var(label_list[i + 2]).cuda()
                label_L = var(label_list[i]).cuda()

                # ----------- Forward prediction -----------
                SepConvNet_optimizer.zero_grad()
                if i == 0:
                    output_f, stat = SepConvNet(imgL, imgR, 0)
                else:
                    output_f, stat = SepConvNet(imgL, imgR, 0, res_c, stat)

                loss = SepConvNet_cost(output_f, label)

                loss.backward(retain_graph=True)
                SepConvNet_optimizer.step()  # assumed: the sibling variants step after every backward()

                sumloss = sumloss + loss.data.item()
                tsumloss = tsumloss + loss.data.item()


                # ----------- Backward prediction -----------
                SepConvNet_optimizer.zero_grad()
                
                output_b, stat = SepConvNet(output_f, imgR, 1, tensorHidden=stat)

                loss = SepConvNet_cost(output_b, label_L)
                # Retain the graph between recurrent steps; free it on the
                # last step of the sequence.
                if i < 4:
                    loss.backward(retain_graph=True)
                else:
                    loss.backward(retain_graph=False)
                SepConvNet_optimizer.step()  # assumed: the sibling variants step after every backward()

                sumloss_b = sumloss_b + loss.data.item()
                tsumloss_b = tsumloss_b + loss.data.item()

                res_f = poor_label - output_f
                res_b = imgL - output_b
                res_c = torch.cat([res_f, res_b], 1)
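                # res_c stacks the forward and backward residuals along the
                # channel dimension; it conditions the next recurrent step.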
            
            
            if cnt % printinterval == 0:
                print(
                    'Epoch [%d/%d], Iter [%d/%d], Time [%4.4f], Back loss [%.6f], Interval loss [%.6f]'
                    % (epoch + 1, EPOCH, cnt, len(trainset) // BATCH_SIZE,
                       time.time() - start_time,
                       tsumloss_b / printinterval / 5,
                       tsumloss / printinterval / 5))
                tsumloss = 0.0
                tsumloss_b = 0.0
        print('Epoch [%d/%d], iter: %d, Time [%4.4f], Avg Loss [%.6f]' %
              (epoch + 1, EPOCH, cnt, time.time() - start_time,
               sumloss / cnt / 5))


        # ---------------- Part for validation ----------------
        trainloss = sumloss / cnt
        SepConvNet.eval().cuda()
        evalcnt = 0
        pos = 0.0
        sumloss = 0.0
        sumloss_b = 0.0
        psnr = 0.0
        psnr_b = 0.0
        for label_list in valLoader:
            bad_list = label_list[7:]
            label_list = label_list[:7]
            loss_s = []
            with torch.no_grad():
                for i in range(5):
                    imgL = var(bad_list[i]).cuda()
                    imgR = var(bad_list[i + 1]).cuda()
                    poor_label = var(bad_list[i + 2]).cuda()

                    label = var(label_list[i + 2]).cuda()
                    label_L = var(label_list[i]).cuda()

                    # ----------- Forward prediction -----------
                    if i == 0:
                        output_f, stat = SepConvNet(imgL, imgR, 0)
                    else:
                        output_f, stat = SepConvNet(imgL, imgR, 0, res_c, stat)
                    loss = SepConvNet_cost(output_f, label)

                    psnr = psnr + calcPSNR.calcPSNR(
                        output_f.cpu().data.numpy(), label.cpu().data.numpy())
                    sumloss = sumloss + loss.data.item()


                    # ----------- Backward prediction -----------
                    output_b, stat = SepConvNet(output_f, imgR, 1, tensorHidden=stat)

                    loss_b = SepConvNet_cost(output_b, label_L)
                    psnr_b = psnr_b + calcPSNR.calcPSNR(
                        output_b.cpu().data.numpy(),
                        label_L.cpu().data.numpy())
                    sumloss_b = sumloss_b + loss_b.data.item()

                    res_f = poor_label - output_f
                    res_b = imgL - output_b
                    res_c = torch.cat([res_f, res_b], 1)


                evalcnt = evalcnt + 5

        # ------------- Tensorboard part -------------
        # writer.add_scalar("Valid SATD loss", sumloss / evalcnt, epoch)
        # writer.add_scalar("Valid PSNR", psnr / valset.__len__(), epoch)
        # ------------- Tensorboard part -------------
        print('Validation loss [%.6f],  Average PSNR [%.4f], [!] Backward loss [%.6f], PSNR [%.4f]' %
              (sumloss / evalcnt, psnr / evalcnt, sumloss_b / evalcnt,
               psnr_b / evalcnt))
        SepConvNet_schedule.step(psnr / evalcnt)
        torch.save(
            SepConvNet.state_dict(),
            os.path.join(
                '.', 'test_share_dual_LSTM_iter' + str(epoch + 1) +
                '-ltype_fSATD_fs' + '-lr_' + str(LEARNING_RATE) +
                '-trainloss_' + str(round(trainloss, 4)) + '-evalloss_' +
                str(round(sumloss / evalcnt, 4)) + '-evalpsnr_' +
                str(round(psnr / evalcnt, 4)) + '.pkl'))
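
The Vimeo loaders above yield the seven clean septuplet frames followed by their degraded counterparts, and each training step slides a three-frame window across them. A minimal illustration of that indexing (not part of the training code):

frames = list(range(7))
triplets = [(frames[i], frames[i + 1], frames[i + 2]) for i in range(5)]
# -> [(0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6)]
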
def main(lr, batch_size, epoch, gpu, train_set, valid_set):
    # ------------- Part for tensorboard --------------
    # writer = SummaryWriter(log_dir='tb/LSTM_ft1')
    # ------------- Part for tensorboard --------------
    torch.backends.cudnn.enabled = True
    torch.cuda.set_device(gpu)

    BATCH_SIZE = batch_size
    EPOCH = epoch

    LEARNING_RATE = lr
    beta1 = 0.9
    beta2 = 0.999

    trainset = vimeodataset(train_set, 'filelist.txt', transform_train)
    valset = vimeodataset(valid_set, 'test.txt')
    trainLoader = torch.utils.data.DataLoader(trainset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=True)
    valLoader = torch.utils.data.DataLoader(valset,
                                            batch_size=BATCH_SIZE,
                                            shuffle=False)
    assert len(valset) % BATCH_SIZE == 0

    # SepConvNet.apply(weights_init)
    # SepConvNet.load_state_dict(torch.load('beta_LSTM_iter8-ltype_fSATD_fs-lr_0.001-trainloss_0.557-evalloss_0.1165-evalpsnr_29.8361.pkl'))

    SepConvNet = Network().cuda()
    SepConvNet.load_my_state_dict(
        torch.load(
            'ft2_baseline_iter86-ltype_fSATD_fs-lr_0.001-trainloss_0.1249-evalloss_0.1155-evalpsnr_29.9327.pkl',
            map_location='cuda:%d' % (gpu)))

    # MSE_cost = nn.MSELoss().cuda()
    # SepConvNet_cost = nn.L1Loss().cuda()
    SepConvNet_cost = sepconv.SATDLoss().cuda()
    SepConvNet_optimizer = optim.Adamax(SepConvNet.parameters(),
                                        lr=LEARNING_RATE,
                                        betas=(beta1, beta2))
    SepConvNet_schedule = optim.lr_scheduler.ReduceLROnPlateau(
        SepConvNet_optimizer,
        mode='max',  # stepped with validation PSNR, where higher is better
        factor=0.1,
        patience=3,
        verbose=True,
        min_lr=1e-6)

    # ----------------  Time part -------------------
    start_time = time.time()
    global_step = 0
    # ----------------  Time part -------------------

    # ---------------- Opt part -----------------------
    # opter = Opter(gpu)
    # -------------------------------------------------

    for epoch in range(0, EPOCH):
        SepConvNet.train().cuda()
        cnt = 0
        sumloss = 0.0  # accumulates loss over the whole training set
        tsumloss = 0.0  # accumulates loss over one print interval
        printinterval = 100
        print("---------------[Epoch%3d]---------------" % (epoch + 1))
        for label_list in trainLoader:
            bad_list = label_list[7:]
            label_list = label_list[:7]
            # IPython.embed()
            # exit()
            global_step = global_step + 1
            cnt = cnt + 1
            loss_s = []
            for i in range(5):
                imgL = var(bad_list[i]).cuda()
                imgR = var(bad_list[i + 1]).cuda()
                label = var(label_list[i + 2]).cuda()
                poor_label = var(bad_list[i + 2]).cuda()
                SepConvNet_optimizer.zero_grad()
                if i == 0:
                    output, stat = SepConvNet(imgL, imgR)
                else:
                    output, stat = SepConvNet(imgL, imgR, res, stat)
                res = poor_label - output
                loss = SepConvNet_cost(output, label)

                # Retain the graph between recurrent steps; free it on the
                # last step of the sequence.
                loss.backward(retain_graph=(i < 4))
                SepConvNet_optimizer.step()

                sumloss = sumloss + loss.data.item()
                tsumloss = tsumloss + loss.data.item()

            if cnt % printinterval == 0:
                print(
                    'Epoch [%d/%d], Iter [%d/%d], Time [%4.4f], Batch loss [%.6f], Interval loss [%.6f]'
                    % (epoch + 1, EPOCH, cnt, len(trainset) // BATCH_SIZE,
                       time.time() - start_time, loss.data.item(),
                       tsumloss / printinterval / 5))
                tsumloss = 0.0
        print('Epoch [%d/%d], iter: %d, Time [%4.4f], Avg Loss [%.6f]' %
              (epoch + 1, EPOCH, cnt, time.time() - start_time,
               sumloss / cnt / 5))

        # ---------------- Part for validation ----------------
        trainloss = sumloss / cnt
        SepConvNet.eval().cuda()
        evalcnt = 0
        pos = 0.0
        sumloss = 0.0
        psnr = 0.0
        for label_list in valLoader:

            bad_list = label_list[7:]
            label_list = label_list[:7]
            loss_s = []
            with torch.no_grad():
                for i in range(5):

                    imgL = var(bad_list[i]).cuda()
                    imgR = var(bad_list[i + 1]).cuda()
                    label = var(label_list[i + 2]).cuda()
                    poor_label = var(bad_list[i + 2]).cuda()

                    if i == 0:
                        output, stat = SepConvNet(imgL, imgR)
                    else:
                        output, stat = SepConvNet(imgL, imgR, res, stat)
                    psnr = psnr + calcPSNR.calcPSNR(
                        output.cpu().data.numpy(),
                        label.cpu().data.numpy())
                    res = poor_label - output

                    loss = SepConvNet_cost(output, label)
                    sumloss = sumloss + loss.data.item()

                evalcnt = evalcnt + 5

        # ------------- Tensorboard part -------------
        # writer.add_scalar("Valid SATD loss", sumloss / evalcnt, epoch)
        # writer.add_scalar("Valid PSNR", psnr / valset.__len__(), epoch)
        # ------------- Tensorboard part -------------
        print('Validation loss [%.6f],  Average PSNR [%.4f]' %
              (sumloss / evalcnt, psnr / evalcnt))
        SepConvNet_schedule.step(psnr / evalcnt)
        torch.save(
            SepConvNet.state_dict(),
            os.path.join(
                '.', 'tail2_LSTM_iter' + str(epoch + 1) + '-ltype_fSATD_fs' +
                '-lr_' + str(LEARNING_RATE) + '-trainloss_' +
                str(round(trainloss, 4)) + '-evalloss_' +
                str(round(sumloss / evalcnt, 4)) + '-evalpsnr_' +
                str(round(psnr / evalcnt, 4)) + '.pkl'))
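
Freezing submodules, as the variants above and below do, is easier to audit with a quick count of what stays trainable. A small helper (hypothetical, not in the original project) could be:

def count_trainable(net):
    # Total number of parameters that still receive gradients.
    return sum(p.numel() for p in net.parameters() if p.requires_grad)

# e.g. print('Trainable parameters:', count_trainable(SepConvNet))
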
def main(lr, batch_size, epoch, gpu, train_set, valid_set):
    # ------------- Part for tensorboard --------------
    # writer = SummaryWriter(log_dir='tb/LSTM_ft1')
    # ------------- Part for tensorboard --------------
    torch.backends.cudnn.enabled = True
    torch.cuda.set_device(gpu)

    BATCH_SIZE = batch_size
    EPOCH = epoch

    LEARNING_RATE = lr
    beta1 = 0.9
    beta2 = 0.999

    trainset = vimeodataset(train_set, 'filelist.txt', transform_train)
    valset = vimeodataset(valid_set, 'test.txt')
    trainLoader = torch.utils.data.DataLoader(trainset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=True)
    valLoader = torch.utils.data.DataLoader(valset,
                                            batch_size=BATCH_SIZE * 2,
                                            shuffle=False)
    assert len(valset) % BATCH_SIZE == 0

    SepConvNet = Network().cuda()
    # SepConvNet.apply(weights_init)
    # SepConvNet.load_my_state_dict(torch.load('SepConv_iter95-ltype_fSATD_fs-lr_0.0001-trainloss_0.1441-evalloss_0.1324-evalpsnr_29.9585.pkl', map_location="cuda:%d"%(gpu)))
    SepConvNet.load_my_state_dict(
        torch.load(
            'SepConv_iter95-ltype_fSATD_fs-lr_0.0001-trainloss_0.1441-evalloss_0.1324-evalpsnr_29.9585.pkl',
            map_location="cuda:%d" % (gpu)))

    # MSE_cost = nn.MSELoss().cuda()
    # SepConvNet_cost = nn.L1Loss().cuda()

    child_cnt = 0

    # Children in skip_childs are skipped by the freezing loop below and stay
    # trainable; the twelve children listed here are the ones that get frozen.
    skip_childs = list(
        set(range(33)) - set([14, 15, 16, 17, 20, 21, 22, 23, 26, 27, 28, 29]))
    for child in SepConvNet.children():
        # print('-----------  Children:%d ----------------' % (child_cnt))
        # print(child)
        param_cnt = 0
        if child_cnt not in skip_childs:
            for param in child.parameters():
                # print("Param: %d in child: %d is frozen" % (param_cnt, child_cnt))
                param.requires_grad = False
                param_cnt += 1
        child_cnt += 1

    SepConvNet_cost = sepconv.SATDLoss().cuda()
    # SepConvNet_optimizer = optim.Adamax(SepConvNet.parameters(), lr=LEARNING_RATE, betas=(beta1, beta2))
    SepConvNet_optimizer = optim.Adamax(filter(lambda p: p.requires_grad,
                                               SepConvNet.parameters()),
                                        lr=LEARNING_RATE,
                                        betas=(beta1, beta2))
    SepConvNet_schedule = optim.lr_scheduler.ReduceLROnPlateau(
        SepConvNet_optimizer,
        mode='max',  # stepped with validation PSNR, where higher is better
        factor=0.1,
        patience=3,
        verbose=True,
        min_lr=1e-6)
    # IPython.embed()
    # exit()
    # ----------------  Time part -------------------
    start_time = time.time()
    global_step = 0
    # ----------------  Time part -------------------

    # ---------------- Opt part -----------------------
    # opter = Opter(gpu)
    # -------------------------------------------------

    for epoch in range(0, EPOCH):
        SepConvNet.train().cuda()
        cnt = 0
        sumloss = 0.0  # accumulates loss over the whole training set
        tsumloss = 0.0  # accumulates loss over one print interval
        printinterval = 300
        print("---------------[Epoch%3d]---------------" % (epoch + 1))
        for label_list in trainLoader:
            bad_list = label_list[7:]
            label_list = label_list[:7]
            # IPython.embed()
            # exit()
            global_step = global_step + 1
            cnt = cnt + 1
            loss_s = []
            for i in range(5):
                imgL = var(bad_list[i]).cuda()
                imgR = var(bad_list[i + 1]).cuda()
                label = var(label_list[i + 2]).cuda()
                poor_label = var(bad_list[i + 2]).cuda()
                SepConvNet_optimizer.zero_grad()
                if i == 0:
                    output, output_a, output_b, stat = SepConvNet(imgL, imgR)
                else:
                    output, output_a, output_b, stat = SepConvNet(
                        imgL, imgR, res, stat)
                res = poor_label - output

                # Multi-scale SATD loss: the full-resolution output plus the
                # 1/4- and 1/2-scale side outputs, each compared against a
                # bilinearly downsampled label.
                label_4 = func.interpolate(
                    label, size=(label.shape[2] // 4, label.shape[3] // 4),
                    mode='bilinear', align_corners=True)
                label_2 = func.interpolate(
                    label, size=(label.shape[2] // 2, label.shape[3] // 2),
                    mode='bilinear', align_corners=True)
                loss = 0.5 * SepConvNet_cost(output, label) + \
                       0.2 * SepConvNet_cost(output_a, label_4) + \
                       0.3 * SepConvNet_cost(output_b, label_2)

                # Retain the graph between recurrent steps; free it on the
                # last step of the sequence.
                loss.backward(retain_graph=(i < 4))
                SepConvNet_optimizer.step()

                sumloss = sumloss + loss.data.item()
                tsumloss = tsumloss + loss.data.item()

            if cnt % printinterval == 0:
                print(
                    'Epoch [%d/%d], Iter [%d/%d], Time [%4.4f], Batch loss [%.6f], Interval loss [%.6f]'
                    % (epoch + 1, EPOCH, cnt, len(trainset) // BATCH_SIZE,
                       time.time() - start_time, loss.data.item(),
                       tsumloss / printinterval / 5))
                tsumloss = 0.0
        print('Epoch [%d/%d], iter: %d, Time [%4.4f], Avg Loss [%.6f]' %
              (epoch + 1, EPOCH, cnt, time.time() - start_time,
               sumloss / cnt / 5))

        # ---------------- Part for validation ----------------
        trainloss = sumloss / cnt
        SepConvNet.eval().cuda()
        evalcnt = 0
        pos = 0.0
        sumloss = 0.0
        psnr = 0.0
        for label_list in valLoader:

            bad_list = label_list[7:]
            label_list = label_list[:7]
            loss_s = []
            with torch.no_grad():
                for i in range(5):
                    imgL = var(bad_list[i]).cuda()
                    imgR = var(bad_list[i + 1]).cuda()
                    label = var(label_list[i + 2]).cuda()
                    poor_label = var(bad_list[i + 2]).cuda()

                    if i == 0:
                        output, output_a, output_b, stat = SepConvNet(
                            imgL, imgR)
                    else:
                        output, output_a, output_b, stat = SepConvNet(
                            imgL, imgR, res, stat)
                    psnr = psnr + calcPSNR.calcPSNR(
                        output.cpu().data.numpy(),
                        label.cpu().data.numpy())
                    res = poor_label - output

                    loss = SepConvNet_cost(output, label)
                    sumloss = sumloss + loss.data.item()

                evalcnt = evalcnt + 5

        # ------------- Tensorboard part -------------
        # writer.add_scalar("Valid SATD loss", sumloss / evalcnt, epoch)
        # writer.add_scalar("Valid PSNR", psnr / valset.__len__(), epoch)
        # ------------- Tensorboard part -------------
        print('Validation loss [%.6f],  Average PSNR [%.4f]' %
              (sumloss / evalcnt, psnr / evalcnt))
        SepConvNet_schedule.step(psnr / evalcnt)
        torch.save(
            SepConvNet.state_dict(),
            os.path.join(
                '.', 'multiscale_test_LSTM_iter' + str(epoch + 1) +
                '-ltype_fSATD_fs' + '-lr_' + str(LEARNING_RATE) +
                '-trainloss_' + str(round(trainloss, 4)) + '-evalloss_' +
                str(round(sumloss / evalcnt, 4)) + '-evalpsnr_' +
                str(round(psnr / evalcnt, 4)) + '.pkl'))
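
To reload any of the checkpoints saved above for evaluation, a minimal sketch (the checkpoint path and GPU index are placeholders):

def load_for_eval(ckpt_path, gpu=0):
    # Restore a trained SepConvNet and switch it to inference mode.
    net = Network().cuda()
    net.load_state_dict(torch.load(ckpt_path, map_location='cuda:%d' % gpu))
    net.eval()
    return net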