Example #1
    def __init__(self,
                 train_file,
                 lr_root,
                 hr_root,
                 crop_size=[256, 256],
                 flip=True,
                 extension="bmp",
                 is_train=True):
        system_log.WriteLine("reading dataset")
        self.is_train = is_train
        self.flip = flip
        self.crop_size = crop_size
        self.lr_seq = []
        self.hr_seq = []
        with open(train_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                element = line.strip('\n').split(' ')  # each line: "<lr_video_name> <hr_video_name>"
                lr_video_name = element[0]
                hr_video_name = element[1]
                lr_video_folder = os.path.join(lr_root, lr_video_name)
                hr_video_folder = os.path.join(hr_root, hr_video_name)

                max_items = len(os.listdir(hr_video_folder))  # number of frames in this clip
                for i in range(1, max_items + 1):
                    lr_path = os.path.join(
                        lr_video_folder,
                        "%s_%03d.%s" % (lr_video_name, i, extension))
                    hr_path = os.path.join(
                        hr_video_folder,
                        "%s_%03d.%s" % (hr_video_name, i, extension))
                    self.lr_seq.append(lr_path)
                    self.hr_seq.append(hr_path)

        system_log.WriteLine(f"total frame {len(self.hr_seq)}")
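For context, here is a minimal __getitem__ sketch that would pair with a constructor like this one. It is a hedged illustration, not the original implementation; it assumes "import cv2 as cv", "import numpy as np", "import random", a 4x LR-to-HR scale, and that crop_size is given in LR pixels.

    def __getitem__(self, index):
        # hypothetical companion method (not in the original listing)
        lr = cv.imread(self.lr_seq[index])   # H x W x C, uint8
        hr = cv.imread(self.hr_seq[index])   # 4H x 4W x C, uint8
        if self.is_train:
            ch, cw = self.crop_size
            y = random.randint(0, lr.shape[0] - ch)
            x = random.randint(0, lr.shape[1] - cw)
            lr = lr[y:y + ch, x:x + cw]
            hr = hr[4 * y:4 * (y + ch), 4 * x:4 * (x + cw)]
            if self.flip and random.random() < 0.5:
                lr, hr = lr[:, ::-1].copy(), hr[:, ::-1].copy()
        # HWC -> CHW float32; the training loops below divide by 255
        return (lr.transpose(2, 0, 1).astype(np.float32),
                hr.transpose(2, 0, 1).astype(np.float32))

This returns the two-tuple that the SISR training loop in Example #9 unpacks (for lr_seq, hr_seq in train_loader).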
Example #2
    def __init__(self,
                 train_file,
                 lr_root,
                 hr_root,
                 depth,
                 crop_size=[256, 256],
                 flip=True,
                 extension="bmp",
                 is_train=True):
        system_log.WriteLine("reading dataset")
        self.is_train = is_train
        self.flip = flip
        self.crop_size = crop_size
        self.lr_seq = []
        self.hr_seq = []
        with open(train_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                element = line.strip('\n').split(' ')
                lr_video_name = element[0]
                hr_video_name = element[1]
                lr_video_folder = os.path.join(lr_root, lr_video_name)
                hr_video_folder = os.path.join(hr_root, hr_video_name)

                name_list = [
                    name for name in os.listdir(hr_video_folder)
                    if f".{extension}" in name
                ]
                max_items = len(name_list)
                # if hr_video_name == "10091373":
                #     print(f"max_items = {max_items}")
                #     for name in os.listdir(hr_video_folder):
                #         print(f"name = {name}")
                # print(f"hr_video_name = {hr_video_name},  max_items = {max_items}")

                # one sample per center frame: a window of `depth` consecutive
                # LR frames around it, paired with the center HR frame
                for i in range((depth // 2) + 1, max_items - (depth // 2) + 1):
                    lr_sub_seq = []
                    for j in range(-(depth // 2), (depth // 2) + 1):
                        lr_index = i + j
                        lr_path = os.path.join(
                            lr_video_folder, "%s_%03d.%s" %
                            (lr_video_name, lr_index, extension))
                        lr_sub_seq.append(lr_path)
                    self.lr_seq.append(lr_sub_seq)
                    hr_path = os.path.join(
                        hr_video_folder,
                        "%s_%03d.%s" % (hr_video_name, i, extension))
                    self.hr_seq.append(hr_path)

        system_log.WriteLine(f"total frame {len(self.hr_seq)}")
Example #3
    def __init__(self,
                 train_file,
                 lr_root,
                 hr_root,
                 depth,
                 crop_size=[256, 256],
                 flip=True,
                 extension="bmp",
                 is_train=True):
        system_log.WriteLine("reading dataset")
        self.is_train = is_train
        self.flip = flip
        self.crop_size = crop_size
        self.lr_seq = []
        self.hr_seq = []
        with open(train_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                element = line.strip('\n').split(' ')
                lr_video_name = element[0]
                hr_video_name = element[1]
                lr_video_folder = os.path.join(lr_root, lr_video_name)
                hr_video_folder = os.path.join(hr_root, hr_video_name)

                images_name = os.listdir(hr_video_folder)    # e.g. Youku_00250_h_GT_001.bmp / Youku_00250_l_001.bmp
                images_name.sort(key=lambda x: int(x[-7:-4]))    # sort by the 3-digit frame index
                first_image_index = int(images_name[0][-7:-4])
                end_image_index = int(images_name[-1][-7:-4])
                _lr_video_name = images_name[0][:-12]    # Youku_00250_
                _lr_video_name = f"{_lr_video_name}l"
                _hr_video_name = _lr_video_name[:-1]    # Youku_00250_
                _hr_video_name = f"{_hr_video_name}h_GT"    # Youku_00250_h_GT

                for i in range(first_image_index, end_image_index+1):
                    lr_sub_seq = []
                    first_lr_path = os.path.join(lr_video_folder, "%s_%03d.%s"%(_lr_video_name,first_image_index,extension))
                    end_lr_path = os.path.join(lr_video_folder, "%s_%03d.%s"%(_lr_video_name,end_image_index,extension))
                    for j in range(-(depth//2),(depth//2)+1):
                        lr_index = i + j
                        # clamp neighbors outside the clip to its first/last frame
                        if first_image_index <= lr_index <= end_image_index:
                            lr_path = os.path.join(lr_video_folder, "%s_%03d.%s"%(_lr_video_name,lr_index,extension))
                            lr_sub_seq.append(lr_path)
                        elif lr_index < first_image_index:
                            lr_sub_seq.append(first_lr_path)
                        else:
                            lr_sub_seq.append(end_lr_path)
                    self.lr_seq.append(lr_sub_seq)
                    hr_path = os.path.join(hr_video_folder, "%s_%03d.%s"%(_hr_video_name,i,extension))
                    self.hr_seq.append(hr_path)

        system_log.WriteLine(f"total frame {len(self.hr_seq)}")
Example #4
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate"""
    lr = system_config.lr
    for count, epoch_idx in enumerate(system_config.multi_step):
        if epoch == epoch_idx:
            for _ in range(count + 1):
                lr = lr * 0.1

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    system_log.WriteLine(f"update learning rate to {lr}")
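A quick sanity check of the multi-step rule above (the base lr and milestones here are illustrative, not from the original config): because lr restarts from the base value on every call, the decayed rate only sticks when the function is called exactly at a milestone epoch, which is why Example #9 guards the call with "if epoch_idx in system_config.multi_step".

base_lr = 1e-4          # illustrative only
multi_step = [10, 20]   # illustrative only
for epoch in [1, 10, 15, 20]:
    lr = base_lr
    for count, epoch_idx in enumerate(multi_step):
        if epoch == epoch_idx:
            for _ in range(count + 1):
                lr = lr * 0.1
    print(epoch, lr)    # expected: 1 -> 1e-04, 10 -> 1e-05, 15 -> 1e-04 (back to base), 20 -> 1e-06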
Example #5
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate"""
    learning_mode = system_config.learning_mode
    lr = system_config.lr
    if learning_mode == "cosin":
        c = (np.cos((epoch % 100) * 0.02 * math.pi) + 1) / 2
        lr = (lr * c) + 0.000001
    elif learning_mode == "normal":
        multi_step = system_config.multi_step
        for i in range(1, len(multi_step)):
            # lr = base lr * 0.5**i inside the i-th milestone interval;
            # epochs at or below multi_step[0], or past the last milestone, keep the base lr
            if multi_step[i - 1] < epoch <= multi_step[i]:
                lr = lr * 0.5**i

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    system_log.WriteLine(f"update learning rate to {lr}")
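Evaluated on its own, the "cosin" branch gives a schedule that restarts every 100 epochs and never drops below the 1e-6 floor; the base lr below is an assumed value for illustration.

import math
import numpy as np

base_lr = 1e-4   # illustrative only
for epoch in [0, 25, 50, 75, 100]:
    c = (np.cos((epoch % 100) * 0.02 * math.pi) + 1) / 2
    print(epoch, base_lr * c + 0.000001)
# roughly 1.01e-4, 5.1e-5, 1e-6, 5.1e-5, then 1.01e-4 again at epoch 100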
Example #6
                # flip W and H
                out = image_forward(net, torch.flip(tensor_in, (-2, -1)))
                out = torch.flip(out, (-2, -1))
                out = out.detach().cpu().numpy()[0]
                output_f = output_f + out

                output_f = output_f / 4    # average of the four flip variants
                output_f = output_f * 255
                output_f = output_f.transpose(1, 2, 0)    # CHW -> HWC for cv.imwrite

                image_index = int(frame_list[frame_idx][-7:-4])
                output_path = os.path.join(
                    output_img_folder,
                    "%s_%03d.%s" % (folder_name[:-2], image_index, extension))
                cv.imwrite(output_path, output_f)
                system_log.WriteLine(f"Write Image to {output_path}")

        # sub frame
        calc_index = [1, 26, 51, 76]
        for folder_name in sub_frame_folder_list:
            folder_path = os.path.join(input_folder, folder_name)
            frame_list = []
            for frame_path in os.listdir(folder_path):
                frame_list.append(frame_path)  # Youku_00850_l_001.bmp

            frame_list.sort(key=lambda x: int(x[-7:-4]))
            first_index = int(frame_list[0][-7:-4])
            last_index = int(frame_list[-1][-7:-4])
            frame_name = frame_list[0][:13]
            output_img_folder = os.path.join(img_folder, frame_name)
            if not os.path.exists(output_img_folder):
Example #7
                out = image_forward(net, torch.flip(tensor_in, (-2, -1)))
                out = torch.flip(out, (-2, -1))
                out = out.detach().cpu().numpy()
                output_f = output_f + out

                output_f = output_f / 4
                output_f = output_f * 255

                for i in range(batch_size):
                    cur_output_index = frame_index - batch_size + i
                    cv_img = output_f[i]
                    cv_img = cv_img.transpose(1, 2, 0)
                    output_path = os.path.join(output_img_folder,
                                               frames_list[cur_output_index])
                    cv.imwrite(output_path, cv_img)
                    system_log.WriteLine(f"write image to {output_path}")

    # # bmp 2 mp4
    # for video_name in video_list:
    #     hr_name = f"{video_name}.mp4"
    #     hr_name = os.path.join(final_folder, hr_name)
    #     output_img_folder = os.path.join(img_folder, video_name)
    #     shell_merge = f"ffmpeg -i {output_img_folder}/{video_name}_%03d.{extension}  -pix_fmt yuv420p  -vsync 0 {hr_name} -y"
    #     os.system(shell_merge)

    # # zip
    # system_log.WriteLine(f"zip...")
    # zipDir(f"{final_folder}",f"{output_folder}/result.zip")

    system_log.WriteLine(f"all done")
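Examples #6 and #7 each show only the last of the four forward passes that feed the running sum before it is divided by 4. Below is a minimal, hedged sketch of that flip-averaging (test-time augmentation) step, assuming image_forward is the same forward wrapper used above and returns a tensor with the input's layout.

def flip_tta(net, tensor_in):
    # hedged reconstruction of the 4-way flip averaging implied by "output_f / 4"
    output_f = 0
    for dims in [None, (-1,), (-2,), (-2, -1)]:    # identity, flip W, flip H, flip both
        x = torch.flip(tensor_in, dims) if dims else tensor_in
        out = image_forward(net, x)
        out = torch.flip(out, dims) if dims else out    # undo the flip on the output
        output_f = output_f + out.detach().cpu().numpy()
    return output_f / 4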
Example #8
def train_vsr(cfg, model_path=None):
    cfg_file = cfg
    with open(cfg_file, 'r') as f:
        system_config.update_config(json.load(f))
    system_log.set_filepath(system_config.log_path)

    lr_root = "dataset/train/540p/"
    hr_root = "dataset/train/4K/"

    if system_config.Stage2:
        if system_config.MiniTest:
            train_file = "dataset/train/miniTest.txt"
            validation_file = "dataset/train/miniTest.txt"
        else:
            train_file = "dataset/train/train.txt"
            validation_file = "dataset/train/validation.txt"

        lr_root = "dataset/train/540p/"
        hr_root = "dataset/train/4K/"

    else:
        if system_config.MiniTest:
            train_file = "dataset/train/miniTest.txt"
            validation_file = "dataset/train/miniTest.txt"
        else:
            train_file = "dataset/train/train"
            validation_file = "dataset/train/validation"
            if system_config.seg_frame:
                train_file = f"{train_file}_seg"
                validation_file = f"{validation_file}_seg"
                lr_root = f"{lr_root}_seg"
                hr_root = f"{hr_root}_seg"

            train_file = f"{train_file}.txt"
            validation_file = f"{validation_file}.txt"

    if system_config.Stage2:
        train_loader = torch.utils.data.DataLoader(dataset_loader_vsr_stage2(train_file, lr_root, hr_root, system_config.depth, crop_size=system_config.input_size, flip=system_config.flip, extension=system_config.extension, is_train=True), \
            batch_size=system_config.batch_size, shuffle = True, num_workers=10, pin_memory=True)
        validation_loader = torch.utils.data.DataLoader(dataset_loader_vsr_stage2(validation_file, lr_root, hr_root, system_config.depth, extension=system_config.extension, is_train=False), \
            batch_size=system_config.validation_batch_size, shuffle = True, num_workers=10, pin_memory=True)
    else:
        train_loader = torch.utils.data.DataLoader(dataset_loader_vsr(train_file, lr_root, hr_root, system_config.depth, crop_size=system_config.input_size, flip=system_config.flip, extension=system_config.extension, is_train=True), \
            batch_size=system_config.batch_size, shuffle = True, num_workers=10, pin_memory=True)
        validation_loader = torch.utils.data.DataLoader(dataset_loader_vsr(validation_file, lr_root, hr_root, system_config.depth, extension=system_config.extension, is_train=False), \
            batch_size=system_config.validation_batch_size, shuffle = True, num_workers=10, pin_memory=True)

    if system_config.net == "EDVR":
        # positional args presumably: nf, nframes, groups, front_RBs, back_RBs
        net = EDVR_arch.EDVR(128,
                             system_config.depth,
                             8,
                             5,
                             40,
                             predeblur=False,
                             HR_in=False)
    elif system_config.net == "EDVR_CBAM":
        net = EDVR_arch.EDVR_CBAM(128,
                                  system_config.depth,
                                  8,
                                  5,
                                  40,
                                  predeblur=False,
                                  HR_in=False)
    elif system_config.net == "EDVR_CBAM_Stage2":
        net = EDVR_arch.EDVR_CBAM(128,
                                  system_config.depth,
                                  8,
                                  5,
                                  40,
                                  predeblur=False,
                                  HR_in=True)
    elif system_config.net == "EDVR_DUF":
        net = EDVR_arch.EDVR_DUF(128,
                                 system_config.depth,
                                 8,
                                 5,
                                 40,
                                 predeblur=False,
                                 HR_in=False)
    elif system_config.net == "EDVR_DUF_V2":
        net = EDVR_arch.EDVR_DUF_v2(128,
                                    system_config.depth,
                                    8,
                                    5,
                                    40,
                                    predeblur=False,
                                    HR_in=False)
    elif system_config.net == "EDVR_V2":
        net = EDVR_arch.EDVR_v2(128,
                                system_config.depth,
                                8,
                                10,
                                40,
                                predeblur=False,
                                HR_in=False)
    elif system_config.net == "EDVR_FUSION":
        net = EDVR_arch.EDVR_Fusion(128,
                                    system_config.depth,
                                    8,
                                    5,
                                    40,
                                    predeblur=False,
                                    HR_in=False)
    elif system_config.net == "EDVR_FUSION_CBAM":
        net = EDVR_arch.EDVR_Fusion_CBAM(128,
                                         system_config.depth,
                                         8,
                                         5,
                                         40,
                                         predeblur=False,
                                         HR_in=False)
    elif system_config.net == "EDVR_FUSION_WD":
        net = EDVR_arch.EDVR_Fusion_WD(128,
                                       system_config.depth,
                                       8,
                                       5,
                                       40,
                                       predeblur=False,
                                       HR_in=False)
    elif system_config.net == "EDVR_Denoise":
        net = EDVR_arch.EDVR_Denoise(128,
                                     system_config.depth,
                                     8,
                                     5,
                                     5,
                                     5,
                                     40,
                                     predeblur=False,
                                     HR_in=False)
    elif system_config.net == "EDVR_CBAM_Nonlocal":
        net = EDVR_arch.EDVR_CBAM_Nonlocal(128,
                                           system_config.depth,
                                           8,
                                           3,
                                           system_config.non_local[0],
                                           2,
                                           25,
                                           system_config.non_local[1],
                                           10,
                                           system_config.non_local[2],
                                           5,
                                           predeblur=False,
                                           HR_in=False)
    elif system_config.net == "EDVR_CBAM_Denoise_Nonlocal":
        net = EDVR_arch.EDVR_Denoise_Nonlocal(128,
                                              system_config.depth,
                                              8,
                                              5,
                                              5,
                                              3,
                                              system_config.non_local[0],
                                              2,
                                              25,
                                              system_config.non_local[1],
                                              10,
                                              system_config.non_local[2],
                                              5,
                                              predeblur=False,
                                              HR_in=False)
    elif system_config.net == "EDVR_CBAM_Denoise":
        net = EDVR_arch.EDVR_Denoise(128,
                                     system_config.depth,
                                     8,
                                     5,
                                     5,
                                     5,
                                     40,
                                     predeblur=False,
                                     HR_in=False)

    if model_path is not None:
        net.load_state_dict(torch.load(model_path))
        system_log.WriteLine(f"loading model from {model_path}")

    net = net.cuda()
    net = torch.nn.DataParallel(net)

    train_loss_iter = AverageMeter()
    train_loss_total = AverageMeter()
    loss_fun = CharbonnierLoss()
    mse_fun = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=system_config.lr)
    min_loss = np.inf
    max_psnr = 0

    system_log.WriteLine(f"config: {system_config.config_all()}")

    for epoch_idx in range(1, system_config.max_epochs + 1):
        train_loss_iter.reset()
        net.train()

        adjust_learning_rate(optimizer, epoch_idx)
        # if epoch_idx in system_config.multi_step:
        #     adjust_learning_rate(optimizer, epoch_idx)

        start = time.time()
        for lr_seq, lr_seq_reverse, hr_seq in train_loader:
            iter_start = time.time()
            tensor_lr = torch.Tensor(lr_seq / 255).cuda()
            tensor_hr = torch.Tensor(hr_seq / 255).cuda()

            out = net(tensor_lr)
            loss = loss_fun(out, tensor_hr)

            if system_config.PP_loss:
                # consistency regularizer: the prediction from the (presumably
                # temporally reversed) LR window should match the forward prediction
                tensor_lr_reverse = torch.Tensor(lr_seq_reverse / 255).cuda()
                out_reverse = net(tensor_lr_reverse)
                pp_loss = mse_fun(out, out_reverse)
                loss += pp_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_iter.update(float(loss))
            iter_end = time.time()

            system_log.WriteLine(
                f"Epoch[{epoch_idx}/{system_config.max_epochs}]:train_loss_epoch:{train_loss_iter.avg:.8f}, loss_iter: {loss:.8f}, cost time: {(iter_end-iter_start):.8f}sec!"
            )

        end = time.time()
        train_loss_total.update(train_loss_iter.avg)
        system_log.WriteLine(
            f"Epoch[{epoch_idx}/{system_config.max_epochs}]: train_loss_total:{train_loss_total.avg:.8f}, train_loss_iter:{train_loss_iter.avg:.8f}, cost time: {(end-start):.8f}sec!"
        )

        # min loss
        if min_loss > train_loss_iter.avg:
            saved_model = net.module
            torch.save(saved_model.state_dict(),
                       system_config.min_loss_model_path)
            system_log.WriteLine(
                f"min loss update from {min_loss} to {train_loss_iter.avg}, save model to {system_config.min_loss_model_path}"
            )
            min_loss = train_loss_iter.avg

        # save ckpt
        saved_model = net.module
        torch.save(saved_model.state_dict(),
                   system_config.ckpt_path.format(epoch_idx))

        # validation
        if epoch_idx % system_config.validation_per_epochs == 0:
            val_start = time.time()
            psnr = validation(net, validation_loader)
            val_end = time.time()

            system_log.WriteLine(
                f"Validation: psnr:{psnr:.8f}, cost time: {(val_end-val_start):8f}sec!"
            )
            if max_psnr < psnr:
                # save model
                saved_model = net.module
                torch.save(saved_model.state_dict(),
                           system_config.best_model_path)
                system_log.WriteLine(
                    f"psnr update from {max_psnr} to {psnr}, save model to {system_config.best_model_path}"
                )
                max_psnr = psnr

    system_log.WriteLine(f"train done!")
    system_log.WriteLine(f"min_loss: {min_loss}, max_psnr: {max_psnr}")
Example #9
def train_sisr(cfg):
    cfg_file = cfg
    with open(cfg_file, 'r') as f:
        system_config.update_config(json.load(f))
    system_log.set_filepath(system_config.log_path)

    lr_root = "/dfsdata2/share-group/aisz_group/tianchi/round2/LR/images/" + system_config.extension
    hr_root = "/dfsdata2/share-group/aisz_group/tianchi/round2/HR/images/" + system_config.extension

    if not system_config.MiniTest:
        train_file = "/dfsdata2/share-group/aisz_group/tianchi/round2/train/train"
        validation_file = "/dfsdata2/share-group/aisz_group/tianchi/round2/validation/val"
        if system_config.seg_frame:
            train_file = f"{train_file}_seg"
            validation_file = f"{validation_file}_seg"
            lr_root = f"{lr_root}_seg"
            hr_root = f"{hr_root}_seg"

        train_file = f"{train_file}.txt"
        validation_file = f"{validation_file}.txt"

    else:
        train_file = "/dfsdata2/liuwei79_data/ImageDatabase/tianchi/round2/miniTest.txt"
        validation_file = "/dfsdata2/liuwei79_data/ImageDatabase/tianchi/round2/miniTest.txt"

    train_loader = torch.utils.data.DataLoader(dataset_loader_sisr(train_file, lr_root, hr_root, crop_size=system_config.input_size, flip=system_config.flip, extension=system_config.extension, is_train=True), \
        batch_size=system_config.batch_size, shuffle = True, num_workers=10, pin_memory=True)
    validation_loader = torch.utils.data.DataLoader(dataset_loader_sisr(validation_file, lr_root, hr_root, extension=system_config.extension, is_train=False), \
        batch_size=system_config.validation_batch_size, shuffle = True, num_workers=10, pin_memory=True)

    if system_config.net == "WDSR_A":
        net = WDSR_A(4, system_config.n_resblocks, 64, 192).cuda()

    net = torch.nn.DataParallel(net)

    train_loss_iter = AverageMeter()
    train_loss_total = AverageMeter()
    loss_fun = CharbonnierLoss()
    optimizer = optim.Adam(net.parameters(), lr=system_config.lr)
    min_loss = np.inf
    max_psnr = 0

    system_log.WriteLine(f"config: {system_config.config_all()}")

    for epoch_idx in range(1, system_config.max_epochs + 1):
        train_loss_iter.reset()
        net.train()

        if epoch_idx in system_config.multi_step:
            adjust_learning_rate(optimizer, epoch_idx)

        start = time.time()
        for lr_seq, hr_seq in train_loader:
            iter_start = time.time()
            tensor_lr = torch.Tensor(lr_seq / 255).cuda()
            tensor_hr = torch.Tensor(hr_seq / 255).cuda()

            out = net(tensor_lr)
            loss = loss_fun(out, tensor_hr)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_iter.update(float(loss))
            iter_end = time.time()

            system_log.WriteLine(
                f"Epoch[{epoch_idx}/{system_config.max_epochs}]:train_loss_epoch:{train_loss_iter.avg:.8f}, loss_iter: {loss:.8f}, cost time: {(iter_end-iter_start):.8f}sec!"
            )

        end = time.time()
        train_loss_total.update(train_loss_iter.avg)
        system_log.WriteLine(
            f"Epoch[{epoch_idx}/{system_config.max_epochs}]: train_loss_total:{train_loss_total.avg:.8f}, train_loss_iter:{train_loss_iter.avg:.8f}, cost time: {(end-start):.8f}sec!"
        )

        # min loss
        if min_loss > train_loss_iter.avg:
            saved_model = net.module
            torch.save(saved_model.state_dict(),
                       system_config.min_loss_model_path)
            system_log.WriteLine(
                f"min loss update from {min_loss} to {train_loss_iter.avg}, save model to {system_config.min_loss_model_path}"
            )
            min_loss = train_loss_iter.avg

        # save ckpt
        saved_model = net.module
        torch.save(saved_model.state_dict(),
                   system_config.ckpt_path.format(epoch_idx))

        # validation
        if epoch_idx % system_config.validation_per_epochs == 0:
            val_start = time.time()
            psnr = validation(net, validation_loader)
            val_end = time.time()

            system_log.WriteLine(
                f"Validation: psnr:{psnr:.8f}, cost time: {(val_end-val_start):8f}sec!"
            )
            if max_psnr < psnr:
                # save model
                saved_model = net.module
                torch.save(saved_model.state_dict(),
                           system_config.best_model_path)
                system_log.WriteLine(
                    f"psnr update from {max_psnr} to {psnr}, save model to {system_config.best_model_path}"
                )
                max_psnr = psnr

    system_log.WriteLine(f"train done!")
    system_log.WriteLine(f"min_loss: {min_loss}, max_psnr: {max_psnr}")