Example #1
    def act(self, action):
        self.psnr_pre = self.psnr
        if action == self.action_size - 1:  # stop
            self.terminal = True
        else:
            feed_dict = {self.inputs[action]: self.img}
            with self.graphs[action].as_default():
                with self.sessions[action].as_default():
                    im_out = self.sessions[action].run(self.outputs[action],
                                                       feed_dict=feed_dict)
            self.img = im_out
        self.psnr = psnr_cal(self.img, self.img_gt)

        # max step
        if self.count >= self.stop_step - 1:
            self.terminal = True

        # stop if too bad
        if self.psnr < self.psnr_init:
            self.terminal = True

        # calculate reward
        self.reward = self.reward_function(self.psnr, self.psnr_pre)
        self.count += 1

        return self.img, self.reward, self.terminal
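
A minimal sketch of how this step function might be driven from a training episode; the environment variable env, the agent object, and its select_action / observe methods are hypothetical names for illustration, not part of the original code.

# assumed objects: env (this environment), agent, num_episodes
for episode in range(num_episodes):
    img, reward, _, terminal = env.new_image()      # see Example #2
    while not terminal:
        action = agent.select_action(img)           # hypothetical agent API
        img, reward, terminal = env.act(action)     # the step shown above
        agent.observe(img, action, reward, terminal)
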
Example #2
    def new_image(self):
        self.terminal = False
        while self.data_index < self.data_len:
            self.img = self.data[self.data_index:self.data_index + 1, ...]
            self.img_gt = self.label[self.data_index:self.data_index + 1, ...]
            self.psnr = psnr_cal(self.img, self.img_gt)
            if self.psnr > 50:  # ignore too smooth samples and rule out 'inf'
                self.data_index += 1
            else:
                break

        # update training file
        if self.data_index >= self.data_len:
            if self.train_max > 1:
                self.train_cur += 1
                if self.train_cur >= self.train_max:
                    self.train_cur = 0

                # load new file
                print('loading file No.%d' % (self.train_cur + 1))
                f = h5py.File(self.train_list[self.train_cur], 'r')
                self.data = f['data'].value
                self.label = f['label'].value
                self.data_len = len(self.data)
                f.close()

            # start from beginning
            self.data_index = 0
            while True:
                self.img = self.data[self.data_index:self.data_index + 1, ...]
                self.img_gt = self.label[self.data_index:self.data_index + 1,
                                         ...]
                self.psnr = psnr_cal(self.img, self.img_gt)
                if self.psnr > 50:  # ignore too smooth samples and rule out 'inf'
                    self.data_index += 1
                else:
                    break

        self.reward = 0
        self.count = 0
        self.psnr_init = self.psnr
        self.data_index += 1
        return self.img, self.reward, 0, self.terminal
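
The helpers psnr_cal and step_psnr_reward are used throughout these snippets but not shown. A plausible sketch, assuming float images in the [0, 1] range and a reward equal to the per-step PSNR gain; the real implementations may differ.

import numpy as np

def psnr_cal(img, img_gt):
    # mean squared error over all pixels and channels
    mse = np.mean((np.asarray(img, dtype=np.float64) -
                   np.asarray(img_gt, dtype=np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 10. * np.log10(1. / mse)  # peak value assumed to be 1.0

def step_psnr_reward(psnr, psnr_pre):
    # reward is the PSNR improvement produced by the last action
    return psnr - psnr_pre
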
Example #3
    def act_test(self, action, step=0):
        reward_all = np.zeros(action.shape)
        psnr_all = np.zeros(action.shape)
        if step == 0:
            self.test_imgs = self.data_test.copy()
            self.test_temp_imgs = self.data_test.copy()
            self.test_pre_imgs = self.data_test.copy()
            self.test_steps = np.zeros(len(action), dtype=int)
        for k in range(len(action)):
            if step == 0:
                img_in = self.data_test[k:k + 1, ...].copy()
            else:
                img_in = self.test_imgs[k:k + 1, ...].copy()
            img_label = self.label_test[k:k + 1, ...].copy()
            self.test_temp_imgs[k:k + 1, ...] = img_in.copy()
            psnr_pre = psnr_cal(img_in, img_label)
            if action[k] == self.action_size - 1 or self.test_steps[k] == self.stop_step:
                # stop action, or this image has already stopped
                img_out = img_in.copy()
                self.test_steps[k] = self.stop_step  # terminal flag
            else:
                feed_dict = {self.inputs[action[k]]: img_in}
                with self.graphs[action[k]].as_default():
                    with self.sessions[action[k]].as_default():
                        with tf.device('/gpu:0'):
                            img_out = self.sessions[action[k]].run(
                                self.outputs[action[k]], feed_dict=feed_dict)
                self.test_steps[k] += 1
            self.test_pre_imgs[k:k + 1, ...] = self.test_temp_imgs[k:k + 1, ...].copy()
            self.test_imgs[k:k + 1, ...] = img_out.copy()  # keep intermediate results
            psnr = psnr_cal(img_out, img_label)
            reward = self.reward_function(psnr, psnr_pre=psnr_pre)
            psnr_all[k] = psnr
            reward_all[k] = reward

        if self.is_train:
            return reward_all.mean(), psnr_all.mean(), self.base_psnr
        else:
            return reward_all, psnr_all, self.base_psnr
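
A sketch of how act_test might be called to run a full multi-step evaluation of one test batch; env and choose_actions are assumed names (choose_actions stands in for whatever policy produces one action index per image).

import numpy as np

# assumed: env holds the environment above, choose_actions is the policy
for step in range(env.stop_step):
    batch = env.data_test if step == 0 else env.test_imgs
    actions = choose_actions(batch)                      # hypothetical policy call
    rewards, psnrs, base_psnr = env.act_test(actions, step=step)
print('base PSNR %.3f -> final PSNR %.3f' % (base_psnr, np.mean(psnrs)))
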
Example #4
    def update_test_data(self):
        self.test_cur = self.test_cur + len(self.data_test)
        test_end = min(self.test_total, self.test_cur + self.test_batch)
        if self.test_cur >= test_end:
            return False  #failed
        else:
            self.data_test = self.data_all[self.test_cur:test_end, ...]
            self.label_test = self.label_all[self.test_cur:test_end, ...]

            # update base psnr
            self.base_psnr = 0.
            for k in range(len(self.data_test)):
                self.base_psnr += psnr_cal(self.data_test[k, ...],
                                           self.label_test[k, ...])
            self.base_psnr /= len(self.data_test)
            return True  #successful
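
A sketch of how update_test_data could drive evaluation over the full test set in chunks of test_batch images; evaluate_batch is a hypothetical helper standing in for the multi-step rollout of Example #3.

while True:
    evaluate_batch(env)                # assumed: runs act_test over env.data_test
    if not env.update_test_data():     # False once the test set is exhausted
        break
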
Example #5
    def update_test_data(self):
        self.test_cur = self.test_cur + len(self.data_test)
        test_end = min(self.test_total, self.test_cur + self.test_batch)
        if self.test_cur >= test_end:
            return False  # failed
        else:
            self.data_test = self.data_all[self.test_cur:test_end, ...]
            self.label_test = self.label_all[self.test_cur:test_end, ...]
            # swap axes if shape is not right (for mixed data)
            if self.data_test.shape[-1] > 3:
                self.data_test = np.swapaxes(self.data_test, 1, 2)
                self.data_test = np.swapaxes(self.data_test, 2, 3)
                self.label_test = np.swapaxes(self.label_test, 1, 2)
                self.label_test = np.swapaxes(self.label_test, 2, 3)
            # update base psnr
            self.base_psnr = 0.
            for k in range(len(self.data_test)):
                self.base_psnr += psnr_cal(self.data_test[k, ...],
                                           self.label_test[k, ...])
            self.base_psnr /= len(self.data_test)
            return True  # successful
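
The two swapaxes calls move axis 1 to the last position, so a channels-first batch (N, C, H, W) becomes channels-last (N, H, W, C), which is what the shape[-1] > 3 check is guarding; the concrete sizes below are placeholders.

import numpy as np

x = np.zeros((4, 3, 63, 63))      # assumed channels-first batch
x = np.swapaxes(x, 1, 2)          # (4, 63, 3, 63)
x = np.swapaxes(x, 2, 3)          # (4, 63, 63, 3)
print(x.shape)                    # -> (4, 63, 63, 3), channels-last
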
Example #6
    def __init__(self, config):
        self.reward = 0
        self.terminal = True
        self.stop_step = config.stop_step
        self.reward_func = config.reward_func
        self.is_train = config.is_train
        self.count = 0  # count restoration step
        self.psnr, self.psnr_pre, self.psnr_init = 0., 0., 0.

        if self.is_train:
            # training data
            self.train_list = [
                config.train_dir + file
                for file in os.listdir(config.train_dir)
                if file.endswith('.h5')
            ]
            self.train_cur = 0
            self.train_max = len(self.train_list)
            f = h5py.File(self.train_list[self.train_cur], 'r')
            self.data = f['data'].value
            self.label = f['label'].value
            f.close()
            self.data_index = 0
            self.data_len = len(self.data)

            # validation data
            f = h5py.File(config.val_dir + os.listdir(config.val_dir)[0], 'r')
            self.data_test = f['data'].value
            self.label_test = f['label'].value
            f.close()
            self.data_all = self.data_test
            self.label_all = self.label_test
        else:
            # test data
            self.test_batch = config.test_batch
            self.test_in = config.test_dir + config.dataset + '_in/'
            self.test_gt = config.test_dir + config.dataset + '_gt/'
            list_in = [
                self.test_in + name for name in os.listdir(self.test_in)
            ]
            list_in.sort()
            list_gt = [
                self.test_gt + name for name in os.listdir(self.test_gt)
            ]
            list_gt.sort()
            self.name_list = [
                os.path.splitext(os.path.basename(file))[0] for file in list_in
            ]
            self.data_all, self.label_all = load_imgs(list_in, list_gt)
            self.test_total = len(list_in)
            self.test_cur = 0

            # data reformat, because the data for tools training are in a different format
            self.data_all = data_reformat(self.data_all)
            self.label_all = data_reformat(self.label_all)
            self.data_test = self.data_all[
                0:min(self.test_batch, self.test_total), ...]
            self.label_test = self.label_all[
                0:min(self.test_batch, self.test_total), ...]

        # input PSNR
        self.base_psnr = 0.
        for k in range(len(self.data_all)):
            self.base_psnr += psnr_cal(self.data_all[k, ...],
                                       self.label_all[k, ...])
        self.base_psnr /= len(self.data_all)

        # reward functions
        self.rewards = {'step_psnr_reward': step_psnr_reward}
        self.reward_function = self.rewards[self.reward_func]

        # build toolbox
        self.action_size = 12 + 1
        toolbox_path = 'toolbox/'
        self.graphs = []
        self.sessions = []
        self.inputs = []
        self.outputs = []
        for idx in range(12):
            g = tf.Graph()
            with g.as_default():
                # load graph
                saver = tf.train.import_meta_graph(toolbox_path + 'tool%02d.meta' % (idx + 1))
                # input data
                input_data = g.get_tensor_by_name('Placeholder:0')
                self.inputs.append(input_data)
                # get the output
                output_data = g.get_tensor_by_name('sum:0')
                self.outputs.append(output_data)
                # save graph
                self.graphs.append(g)
            sess = tf.Session(graph=g,
                              config=tf.ConfigProto(log_device_placement=True))
            with g.as_default():
                with sess.as_default():
                    saver.restore(sess, toolbox_path + 'tool%02d' % (idx + 1))
                    self.sessions.append(sess)
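
data_reformat is called here but not defined in these snippets. A sketch consistent with the inline reformatting shown in Example #7 (BGR-to-RGB channel swap followed by an H/W axis swap); treat it as an assumption about the original helper.

import numpy as np

def data_reformat(data):
    # swap the B and R channels (BGR -> RGB), then swap H and W,
    # mirroring the inline code in Example #7
    out = data.copy()
    out[:, :, :, 0] = data[:, :, :, 2]
    out[:, :, :, 2] = data[:, :, :, 0]
    out = np.swapaxes(out, 1, 2)
    return out
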
Example #7
    def __init__(self, config):

        screen_width, screen_height = config.screen_width, config.screen_height
        self.dims = (screen_width, screen_height)
        self.test_batch = config.test_batch
        self.test_in = 'test_images/' + config.dataset + '_in/'
        self.test_gt = 'test_images/' + config.dataset + '_gt/'
        self._screen = None
        self.reward = 0
        self.terminal = True
        self.stop_step = config.stop_step
        self.reward_func = config.reward_func

        # test data
        list_in = [self.test_in + name for name in os.listdir(self.test_in)]
        list_in.sort()
        list_gt = [self.test_gt + name for name in os.listdir(self.test_gt)]
        list_gt.sort()
        self.data_all, self.label_all = load_imgs(list_in, list_gt)
        self.test_total = len(list_in)
        self.test_cur = 0

        # BGR --> RGB, swap H and W
        # This is because the data for tools training are in a different format
        # You don't need to do so with your own tools
        temp = self.data_all.copy()
        self.data_all[:, :, :, 0] = temp[:, :, :, 2]
        self.data_all[:, :, :, 2] = temp[:, :, :, 0]
        self.data_all = np.swapaxes(self.data_all, 1, 2)
        temp = self.label_all.copy()
        self.label_all[:, :, :, 0] = temp[:, :, :, 2]
        self.label_all[:, :, :, 2] = temp[:, :, :, 0]
        self.label_all = np.swapaxes(self.label_all, 1, 2)

        self.data_test = self.data_all[0:min(self.test_batch, self.test_total),
                                       ...]
        self.label_test = self.label_all[
            0:min(self.test_batch, self.test_total), ...]

        # reward functions
        self.rewards = {'step_psnr_reward': step_psnr_reward}
        self.reward_function = self.rewards[self.reward_func]

        # base_psnr (input psnr)
        self.base_psnr = 0.
        for k in range(len(self.data_all)):
            self.base_psnr += psnr_cal(self.data_all[k, ...],
                                       self.label_all[k, ...])
        self.base_psnr /= len(self.data_all)

        self.data = np.array([[[[0]]]])
        self._data_index = 0
        self._data_len = len(self.data)

        # build toolbox
        self.action_size = 12 + 1
        toolbox_path = 'toolbox/'
        self.graphs = []
        self.sessions = []
        self.inputs = []
        self.outputs = []
        for idx in range(12):
            g = tf.Graph()
            with g.as_default():
                # load graph
                saver = tf.train.import_meta_graph(toolbox_path + 'tool%02d.meta' % (idx + 1))
                # input data
                input_data = g.get_tensor_by_name('Placeholder:0')
                self.inputs.append(input_data)
                # get the output
                output_data = g.get_tensor_by_name('sum:0')
                self.outputs.append(output_data)
                # save graph
                self.graphs.append(g)
            sess = tf.Session(graph=g,
                              config=tf.ConfigProto(log_device_placement=True))
            with g.as_default():
                with sess.as_default():
                    saver.restore(sess, toolbox_path + 'tool%02d' % (idx + 1))
                    self.sessions.append(sess)
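
A minimal sketch of the config object this constructor expects; only the attribute names are taken from the code above, and every value below is a placeholder.

from types import SimpleNamespace

config = SimpleNamespace(
    screen_width=63, screen_height=63,     # placeholder dimensions
    test_batch=32,
    dataset='mine',                        # expects test_images/mine_in/ and mine_gt/
    stop_step=3,
    reward_func='step_psnr_reward',
)
env = MyEnvironment(config)                # hypothetical class name for this snippet
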
Example #8
def main(arg):
    print("===> Loading datasets")
    lr_list = glob.glob(os.path.join(args.data_lr, '*'))
    hr_list = glob.glob(os.path.join(args.data_hr, '*'))
    data_set = DatasetLoader(lr_list, hr_list, arg.patch_size, arg.scale)
    train_loader = DataLoader(data_set,
                              batch_size=arg.batch_size,
                              num_workers=arg.workers,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=True)

    print("===> Building model")
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    device_ids = list(range(args.gpus))
    model = RCAN(arg)
    criterion = nn.L1Loss(reduction='sum')

    print("===> Setting GPU")
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    criterion = criterion.cuda()

    # optionally resume from a checkpoint
    if arg.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(arg.resume))
            checkpoint = torch.load(arg.resume)
            new_state_dict = OrderedDict()
            for k, v in checkpoint.items():
                namekey = 'module.' + k  # add 'module.' prefix to match the DataParallel wrapper
                new_state_dict[namekey] = v
            model.load_state_dict(new_state_dict)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    print("===> Setting Optimizer")
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=arg.lr,
                           weight_decay=arg.weight_decay,
                           betas=(0.9, 0.999),
                           eps=1e-08)

    print("===> Training")
    for epoch in range(args.start_epoch, args.epochs):
        adjust_lr(optimizer, epoch)
        model.train()
        losses = AverageMeter()
        psnrs = AverageMeter()
        with tqdm(total=(len(data_set) -
                         len(data_set) % args.batch_size)) as t:
            t.set_description('epoch:{}/{} lr={}'.format(
                epoch, args.epochs - 1, optimizer.param_groups[0]["lr"]))

            for data in train_loader:
                data_x, data_y = Variable(data[0]), Variable(
                    data[1], requires_grad=False)

                data_x = data_x.type(torch.FloatTensor)
                data_y = data_y.type(torch.FloatTensor)

                data_x = data_x.cuda()
                data_y = data_y.cuda()

                pred = model(data_x)
                # pix loss
                loss = criterion(pred, data_y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pred = pred.cpu()
                pred = pred.detach().numpy().astype(np.float32)

                data_y = data_y.cpu()
                data_y = data_y.numpy().astype(np.float32)

                psnr = psnr_cal(pred, data_y)
                mean_loss = loss.item() / (args.batch_size * args.n_colors *
                                           ((args.patch_size * args.scale)**2))
                losses.update(mean_loss)
                psnrs.update(psnr)

                t.set_postfix(loss='Loss: {losses.val:.3f} ({losses.avg:.3f})'
                              ' PSNR: {psnrs.val:.3f} ({psnrs.avg:.3f})'.
                              format(losses=losses, psnrs=psnrs))

                t.update(len(data[0]))

        # save model
        model_out_path = os.path.join(args.checkpoint,
                                      "model_epoch_{}_rcan.pth".format(epoch))
        if not os.path.exists(args.checkpoint):
            os.makedirs(args.checkpoint)
        torch.save(model.module.state_dict(), model_out_path)
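
AverageMeter and adjust_lr are referenced in this training loop but not defined here. A sketch of the usual running-average meter and a simple step-decay schedule, offered as assumptions rather than the original implementations.

class AverageMeter(object):
    """Keeps the most recent value and a running average."""

    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0., 0., 0, 0.

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_lr(optimizer, epoch, base_lr=1e-4, step=200, gamma=0.5):
    # hypothetical step decay: multiply the learning rate by gamma every `step` epochs
    lr = base_lr * (gamma ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
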