Пример #1
0
 def get_query_indicies(self):
     """Return the indices of the samples queried so far.

     A warning is logged when fewer than ``MAX_QUERIES`` indices have been
     recorded.

     Returns:
         list: The ``query_indicies`` list itself (not a copy).
     """
     n_queries = len(self.query_indicies)
     if n_queries < self.MAX_QUERIES:
         # Report the count (not the whole index list) so the message says
         # what it claims to say and stays short for large query budgets.
         LOG.warning(
             f'Called get_query_indicies but number of queries '
             f'({n_queries}) is less than maximum number '
             f'({self.MAX_QUERIES})')
     return self.query_indicies
Пример #2
0
    def step(self, action):
        """Score one classification action and advance to the next sample.

        Args:
            action: Predicted class index for the current sample.

        Returns:
            tuple: ``(observation, reward, done, info)``. ``reward`` is 1 for a
            correct prediction and 0 otherwise. When ``storage.get_x`` raises
            ``IndexError`` (no sample left), the observation is an all-zero
            ``uint8`` array of ``correct_shape``, ``done`` is True and
            ``info`` is a copy of the good/wrong counters; otherwise ``info``
            is an empty dict.
        """
        predicted_label = action
        # The true class is the argmax of the stored label (presumably a
        # one-hot vector — confirm against storage.get_y).
        true_label = np.argmax(self.storage.get_y(self._counter))

        LOG.debug(f'Predicted label: {predicted_label}, '
                  f'true_label: {true_label}')

        if predicted_label == true_label:
            self.stats['good'] += 1
            reward = 1
        else:
            self.stats['wrong'] += 1
            reward = 0

        self._counter += 1
        info = {}
        try:
            self._state = self.storage.get_x(self._counter)\
                              .view(*self.correct_shape).numpy()
            done = False
        except IndexError:
            # No more samples: emit a blank frame and end the episode.
            self._state = np.zeros(self.correct_shape,
                                   dtype=np.uint8)
            done = True
            # Hand out a copy: callers/wrappers may add keys to ``info``
            # (e.g. via ``info.update``), which must not leak into the
            # environment's own running counters.
            info = dict(self.stats)
            LOG.info(f'End of epoch, results: {self.stats}')

        return self._state, reward, done, info
Пример #3
0
    def save(self,
             save_path=None,
             force=False,
             as_png=False,
             as_tensor=False,
             filename='{}.npy'):
        """Save the loaded images to disk (by default as numpy arrays).

        The target directory is ``<save_path>/<config name>``, where
        ``save_path`` falls back to the save path from the config. Before
        saving, trailing empty elements are trimmed with ``_clean`` and, if
        the config asks for it, the data is shuffled.

        Args:
            save_path (str, optional): Directory to save into. If unspecified,
                the save path from the config file is used.
            force (bool, optional): If set, create the target directory when
                it does not exist. Defaults to False.
            as_png (bool, optional): If set, write each image as a separate
                ``<index>_<label>.png`` file instead of arrays.
            as_tensor (bool, optional): If set (and ``as_png`` is not), save
                images and labels with ``torch.save`` instead of ``np.save``.
            filename (str, optional): Format string for the array/tensor file
                names; ``'x'`` and ``'y'`` are substituted for images and
                labels. Ignored when ``as_png`` is set. Defaults to
                ``'{}.npy'``.

        Raises:
            FileNotFoundError: If ``force`` is False and the target directory
                does not exist.
        """
        # Trim first, then shuffle: ``_clean`` drops elements from the *end*
        # of the arrays, which only hits the empty slots while the data is
        # still in load order. Shuffling first would scatter those slots and
        # make ``_clean`` discard real samples instead.
        self._clean()

        if self.cm.do_shuffle():
            self.shuffle()

        name = self.cm.get_config_name()
        path = save_path if save_path is not None else self.cm.get_save_path()
        path = os.path.join(path, name)
        if not force and not os.path.isdir(path):
            raise FileNotFoundError(
                f'Path ({path}) was not found and force was set to False')
        # create directories if they do not exist
        Path(path).mkdir(parents=True, exist_ok=True)

        if as_png:
            # NOTE(review): cv2.imwrite expects BGR data in a displayable
            # range; images converted to RGB and/or normalized to float while
            # loading may come out with swapped channels or near-black.
            for counter, (img, label) in enumerate(zip(self._x, self._y)):
                # Own name so the ``filename`` argument is not clobbered.
                png_name = f'{counter}_{label}.png'
                img_path = os.path.join(path, png_name)
                cv2.imwrite(img_path, img)
                LOG.debug(f'save image: {img_path}')
        elif as_tensor:
            color_channels = self._get_channels_number()
            # NOTE(review): ``view`` only reinterprets memory; if ``_x`` is
            # laid out as (N, H, W, C), multi-channel images get scrambled and
            # a ``permute`` would be needed — confirm the stored layout.
            torch.save(
                torch.Tensor(self._x).view(-1, color_channels, self.IMG_SIZE,
                                           self.IMG_SIZE),
                os.path.join(path, filename.format('x')))
            torch.save(torch.Tensor(self._y),
                       os.path.join(path, filename.format('y')))
        else:
            # np.save(file, arr): the array is np.save's second argument.
            np.save(os.path.join(path, filename.format('x')), self._x)
            np.save(os.path.join(path, filename.format('y')), self._y)
        LOG.info(f'Data was saved in directory {path}')
Пример #4
0
 def shuffle(self):
     """Randomly reorder the loaded samples, keeping images and labels paired.

     One random permutation is applied to both ``_x`` and ``_y``. Nothing
     checks that images were actually loaded; only equal lengths are
     asserted.
     """
     n_samples = len(self._x)
     assert n_samples == len(self._y)
     order = np.random.permutation(n_samples)
     self._x, self._y = self._x[order], self._y[order]
     LOG.info('Data was shuffled')
Пример #5
0
 def _clean(self):
     """Remove trailing empty elements from the ``_x`` and ``_y`` arrays.

     Empty elements appear when images fail to load; the number of loading
     failures is tracked in ``n_exceptions_while_loading`` and that many
     elements are dropped from the end of both arrays. The counter is then
     reset to 0, so a repeated call (``save`` calls this every time) does
     not cut off real samples.

     NOTE(review): this assumes each failed load leaves exactly one trailing
     empty slot — confirm against how ``_x``/``_y`` are allocated.
     """
     LOG.info(f'Remove last {self.n_exceptions_while_loading} images...')
     if not self.n_exceptions_while_loading:
         return
     self._x = self._x[:-self.n_exceptions_while_loading]
     self._y = self._y[:-self.n_exceptions_while_loading]
     # The empty tail is gone now; make any further call a no-op.
     self.n_exceptions_while_loading = 0
Пример #6
0
def remove_corrupted_images(path):
    """Delete every file in ``path`` that OpenCV cannot read as an image.

    Only regular files directly inside ``path`` are checked. Sub-directories
    are skipped: ``cv2.imread`` cannot read them either, and ``os.remove``
    raises when given a directory.

    Args:
        path (str): Directory to scan (not recursive).

    Returns:
        list[str]: Full paths of the files that were removed.
    """
    removed_files = []
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if not os.path.isfile(full_path):
            continue
        try:
            img = cv2.imread(full_path)
        except OSError:
            img = None
        # imread signals an unreadable/corrupt file by returning None
        if img is None:
            os.remove(full_path)
            removed_files.append(full_path)
    LOG.info(f'Removed files: {removed_files}')
    return removed_files
Пример #7
0
    def reset(self):
        """Start a new episode by clearing all per-episode bookkeeping.

        Meant to be called at the beginning of every epoch.

        Returns:
            The initial observation: ``tensor_to_numpy_1d`` applied to the
            freshly computed state vector (the earlier docstring called it a
            ``list`` — confirm the actual type).
        """
        # Counters and accumulators are reset before the state vector is
        # built, so if ``_get_state_vector`` reads any of them it sees the
        # fresh values.
        self._queries, self._counter, self.entropy = 0, 0, 0
        self.query_indicies, self.reward_arr = [], []
        initial_state = self._get_state_vector()
        self._state = tensor_to_numpy_1d(initial_state)
        LOG.info('reset environment')
        return self._state
Пример #8
0
    def __init__(self, dm, y_oracle):
        """Build the classification environment on top of ``dm.train``.

        Args:
            dm: Data manager; its ``train`` attribute is used as the sample
                storage.
            y_oracle: Accepted but not used by this constructor.
        """
        super(ClassificationEnv, self).__init__()
        self.dm = dm
        self.storage = self.dm.train

        side = CONFIG['img_size']
        self.img_size = side
        # Observations are single-channel ``side x side`` frames.
        self.correct_shape = (side, side, 1)

        # Running position in the storage and prediction tallies.
        self._counter = 0
        self.stats = {'good': 0, 'wrong': 0}

        # NOTE(review): three discrete actions presumably means three classes
        # — confirm this matches the number of labels in the data.
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            shape=(side, side, 1), low=0, high=255, dtype=np.uint8)

        LOG.info('Environment initialized')
Пример #9
0
def create_csv_file(target_file, data_dir, skips=None):
    """Write a ``<relative path>,<label>`` CSV index of a class-per-dir tree.

    Every sub-directory of ``data_dir`` (visited in sorted order) is one
    class and gets the next integer label, starting at 0. Entries named in
    ``skips`` and plain files directly in ``data_dir`` are ignored and do not
    consume a label. Each row holds ``<class dir>/<entry>`` (relative to
    ``data_dir``) and its label; rows within a class are sorted by name so
    the file is reproducible.

    Args:
        target_file (str): CSV file to create (overwritten if it exists).
        data_dir (str): Root directory containing one sub-directory per class.
        skips (list[str], optional): Names in ``data_dir`` to ignore.
    """
    import csv

    skipped = set(skips) if skips is not None else set()

    # newline='' + lineterminator='\n' keeps plain "\n" row endings, while
    # csv.writer quotes paths that contain commas or quote characters.
    with open(target_file, 'w', newline='', encoding='utf-8') as tf:
        writer = csv.writer(tf, lineterminator='\n')
        label = 0
        for fdir in sorted(os.listdir(data_dir)):
            if fdir in skipped:
                continue

            dirpath = os.path.join(data_dir, fdir)
            if not os.path.isdir(dirpath):
                continue

            LOG.info(f'Start loading from {dirpath}...')
            for f in tqdm(sorted(os.listdir(dirpath))):
                writer.writerow([os.path.join(fdir, f), label])
            label += 1
    LOG.info(f'CSV file {target_file} saved')
Пример #10
0
    def step(self, action):
        """Step the wrapped env and record the action, reward and observation.

        Args:
            action: Action forwarded unchanged to ``self.env.step``.

        Returns:
            tuple: ``(obs, reward, done, info)`` from the wrapped env. When
            the episode is done, ``info`` is a new dict extended with
            ``episode_length``, ``total_reward``, ``rewards`` and
            ``observations``; the dict returned by the wrapped env itself is
            left untouched.
        """
        obs, reward, done, info = self.env.step(action)

        # Only the recorded history goes through round_to3; the caller gets
        # the raw reward back.
        self.reward_arr.append(self.Reward(action, round_to3(reward)))
        self.observations.append(obs)

        if done:
            # Copy before extending: the wrapped env may hand out a dict it
            # keeps using internally (e.g. running stats), and adding episode
            # summaries to it would corrupt that state.
            info = dict(info)
            info.update({
                'episode_length': self.env.get_counter(),
                'total_reward': self._calculate_total_reward(),
                'rewards': self.reward_arr,
                'observations': self.observations,
            })

            if self.autolog:
                LOG.info(f"episode length: {info['episode_length']}, "
                         f"total reward: {info['total_reward']}")

        return obs, reward, done, info
Пример #11
0
    def __init__(self, dm, model, config):
        """Set up the query environment.

        Args:
            dm: Data manager, stored as ``self.dm``.
            model: Model, stored as ``self.model``.
            config (dict): Must provide ``img_size``, ``max_queries``,
                ``reward_multiplier``, ``reward_treshold``,
                ``query_punishment`` and ``left_queries_punishment``.
        """
        super(QueryEnv, self).__init__()
        self.dm = dm
        self.model = model

        # Episode bookkeeping.
        self._queries = 0
        self._counter = 0
        self.query_indicies = []

        self.IMG_SIZE = config['img_size']
        self.MAX_QUERIES = config['max_queries']
        self.REWARD_MULT = config['reward_multiplier']

        # NOTE(review): the label count comes from the global CONFIG rather
        # than from ``config`` — confirm the two always agree.
        n_labels = len(CONFIG['labels'])
        # calculate_reward needs REWARD_THR to exist while the maximum reward
        # is computed, so start from 0 and set the real threshold (a fraction
        # of the maximum reward) afterwards.
        self.REWARD_THR = 0
        self.MAX_REWARD = self._get_max_reward(n_labels)
        self.REWARD_THR = config['reward_treshold'] * self.MAX_REWARD

        self.QUERY_PUNISH = config['query_punishment']
        self.LEFT_QUERIES_PUNISH = config['left_queries_punishment']

        LOG.info(f'max reward: {self.MAX_REWARD}, '
                 f'reward threshold: {self.REWARD_THR}')

        # Two discrete actions; observations are softmax outputs in [0, 1],
        # one value per label.
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(n_labels,), dtype=np.float32)
Пример #12
0
    def load_raw(self):  # TODO: prepare 2 arrays: processed and only reshaped
        """Load images and apply the preprocessing selected in the config.

        Images are read class by class from the directories returned by
        ``cm.get_imgs_path_to_label()`` and stored in ``self._x`` /
        ``self._y`` in load order. The number of successfully loaded images
        is capped cumulatively: after class ``i`` at most
        ``MAX_IMGS_PER_CLASS * (i + 1)`` images are stored in total, so a
        class with too few images lets later classes load more. Images that
        fail to load are logged, skipped and counted in
        ``n_exceptions_while_loading``.
        """
        path_to_label = self.cm.get_imgs_path_to_label()
        paths = list(path_to_label.keys())
        labels = list(path_to_label.values())
        self._fail_if_path_is_not_dir(paths)
        self._reset()
        self.MAX_IMGS_PER_CLASS = self.MAX_IMGS // len(paths)

        # The read mode is the same for every file; decide it once.
        grayscale = self.cm.do_grayscale()
        color_mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR

        k = 0  # index of the next free slot in _x / _y
        for ctr, (class_path, label) in enumerate(zip(paths, labels)):
            LOG.info(f'Start loading images from path: {class_path}...')

            for f in tqdm(os.listdir(class_path)):
                if k >= self.MAX_IMGS_PER_CLASS * (ctr + 1):
                    LOG.info(f'Maximum number of load attemps is reached ({k})'
                             f' for class {label}')
                    break
                path = os.path.join(class_path, f)
                try:
                    img = self._read_preprocessed_image(path, color_mode,
                                                        grayscale)
                except OSError as e:
                    LOG.warning(
                        f'Error while loading image from path {path}: {e}')
                    self.n_exceptions_while_loading += 1
                    continue

                self._x[k] = img
                self._y[k] = label

                self.balance_counter[self.cm.get_label_name(ctr)] += 1
                k += 1

    def _read_preprocessed_image(self, path, color_mode, grayscale):
        """Read one image and apply the preprocessing enabled in the config.

        Steps (each optional per config): BGR->RGB conversion for color
        reads, ``_rescale``, min-max normalization to [0, 1] as float32, mean
        centering, division by the standard deviation.

        Raises:
            OSError: If OpenCV cannot read the file.
        """
        img = cv2.imread(path, flags=color_mode)  # reads as bgr
        if img is None:
            raise OSError('Failed to read an image')
        # A grayscale read has a single channel, which COLOR_BGR2RGB rejects;
        # only color images need the channel swap.
        if not grayscale:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self._rescale(img)

        if self.cm.do_normalization():
            img = cv2.normalize(img,
                                img,
                                0,
                                1,
                                cv2.NORM_MINMAX,
                                dtype=cv2.CV_32F)

        center = self.cm.do_centering()
        standardize = self.cm.do_standarization()
        # In-place float arithmetic on an integer (e.g. uint8) image raises a
        # numpy casting error, so switch to float32 first when needed.
        if (center or standardize) and img.dtype.kind != 'f':
            img = img.astype('float32')
        if center:
            img -= img.mean()
        if standardize:
            std = img.std()
            # A constant image has std 0; leave it unscaled instead of
            # filling it with NaN/inf.
            if std > 0:
                img /= std
        return img
Пример #13
0
def save_img(img, path):
    """Write ``img`` to ``path`` with OpenCV.

    ``cv2.imwrite`` signals failure by returning False rather than raising,
    so the result is checked, a warning is logged on failure, and the flag
    is returned to the caller.

    Args:
        img: Image array to write.
        path (str): Destination file path; the extension selects the format.

    Returns:
        bool: True if the image was written, False otherwise.
    """
    ok = bool(cv2.imwrite(path, img))
    if ok:
        LOG.debug(f'save image: {path}')
    else:
        LOG.warning(f'failed to save image: {path}')
    return ok