Пример #1
0
    def __init__(self,
                 optims,
                 initial,
                 decay_fac,
                 decay_interval_itr=None,
                 decay_interval_epoch=None,
                 epoch_len=None,
                 warm_restart=None,
                 warm_restart_schedule=None):
        """
        Exponential LR decay by `decay_fac` once per decay interval.

        :param optims: optimizers; forwarded to the base-class constructor.
        :param initial: initial learning rate.
        :param decay_fac: multiplicative decay factor.
        :param decay_interval_itr: decay interval in iterations. Exactly one of this and
            `decay_interval_epoch` must be given.
        :param decay_interval_epoch: decay interval in epochs; requires `epoch_len` and is
            converted to iterations.
        :param epoch_len: iterations per epoch, used for epoch -> iteration conversion.
        :param warm_restart: stored as `warm_restart_itr`. In epoch mode it is given in epochs
            and converted to iterations; otherwise it is used as-is. (Its exact meaning is
            defined by the LR-update code elsewhere — TODO confirm.)
        :param warm_restart_schedule: stored as-is for use on warm restarts.
        """
        super(ExponentialDecayLRSchedule, self).__init__(optims)
        assert_exc((decay_interval_itr is not None) ^
                   (decay_interval_epoch is not None),
                   'Need either iter or epoch')
        # Compare against None explicitly: a value of 0 passes the XOR check above, and a
        # truthiness test would skip the conversion and leave decay_every_itr as None.
        if decay_interval_epoch is not None:
            # assert_exc instead of a bare assert, so the check survives `python -O`.
            assert_exc(epoch_len is not None,
                       'epoch_len is required when decay_interval_epoch is given')
            decay_interval_itr = int(decay_interval_epoch * epoch_len)
            if warm_restart is not None:
                # In epoch mode, warm_restart is expressed in epochs as well.
                warm_restart = int(warm_restart * epoch_len)
        self.initial = initial
        self.decay_fac = decay_fac
        self.decay_every_itr = decay_interval_itr

        self.warm_restart_itr = warm_restart
        self.warm_restart_schedule = warm_restart_schedule

        self.last_warm_restart = 0
Пример #2
0
 def _enable(self, prefix, global_step):
     """
     Enable logging of `prefix`, storing `global_step` alongside it.

     :param prefix: non-empty str that must not end in '/'.
     :param global_step: stored as self.global_step.
     """
     assert_exc(isinstance(prefix, str),
                'prefix must be str, got {}'.format(prefix))
     # Check emptiness explicitly; indexing prefix[-1] would raise an opaque IndexError on ''.
     assert_exc(len(prefix) > 0, 'prefix must not be empty')
     assert_exc(not prefix.endswith('/'),
                'prefix must not end in /, got {}'.format(prefix))
     self.enabled_prefix = prefix
     self.global_step = global_step
Пример #3
0
 def get_ckpt_for_itr(self, itr, before_time=None):
     """
     Gets (itrc, ckpt_p) where itrc <= itr, i.e., the latest ckpt at or before `itr`.
     Special values: itr == -1 -> newest ckpt.

     If `before_time` is given, ckpts whose file modification time (os.path.getmtime) is
     later than `before_time` are discarded first. This filter applies for every `itr`,
     not only for itr == -1.

     :param itr: iteration to look up, or -1 for the newest ckpt.
     :param before_time: optional timestamp compared against os.path.getmtime of each ckpt.
     :return: (itrc, ckpt_p) tuple.
     :raises CkeckpointLoadingException: if no ckpts remain, or the earliest is after `itr`.
     """
     # sorted list of (itr, ckpt_p); ckpts[0] / ckpts[-1] are relied on as oldest / newest.
     ckpts = list(self.itr_ckpt())
     if before_time is not None:
         ckpts_before_time = [(i, p) for i, p in ckpts
                              if os.path.getmtime(p) <= before_time]
         print(
             f'*** Ignoring {len(ckpts) - len(ckpts_before_time)} ckpts after {before_time}'
         )
         ckpts = ckpts_before_time
     no_ckpts_msg = 'No ckpts found in {}'.format(self._out_dir)
     if before_time is not None:
         # Make it obvious that the time filter may be what removed all ckpts.
         no_ckpts_msg += ' before {}'.format(before_time)
     assert_exc(len(ckpts) > 0, no_ckpts_msg, CkeckpointLoadingException)
     if itr == -1:
         return ckpts[-1]
     first_itrc, _ = ckpts[0]
     assert_exc(first_itrc <= itr,
                'Earliest ckpt {} is after {}'.format(first_itrc, itr),
                CkeckpointLoadingException)
     # Walk from newest to oldest; the assertion above guarantees a hit.
     for itrc, ckpt_p in reversed(ckpts):
         if itrc <= itr:
             return itrc, ckpt_p
     raise ValueError('Unexpected, {}, {}'.format(itr, ckpts))
Пример #4
0
 def _parse_cos_spec():
     """
     Build a CosineDecayLRSchedule from the spec `s` of the enclosing scope.

     The spec is "<lrmax>SPEC_SEP<lrmin>SPEC_SEP<unit><T>", where unit 'i' means T is an
     iteration count and 'e' means T is in epochs. `optims` and `epoch_len` are also taken
     from the enclosing scope.
     """
     lr_max_str, lr_min_str, period_spec = s.split(SPEC_SEP)
     unit = period_spec[0]
     period = period_spec[1:]
     assert_exc(unit in ('i', 'e'), 'Invalid spec: {}'.format(s))
     if unit == 'i':
         T_itr, T_epoch = int(period), None
     else:
         T_itr, T_epoch = None, float(period)
     return CosineDecayLRSchedule(optims, float(lr_max_str), float(lr_min_str),
                                  T_itr, T_epoch, epoch_len)
Пример #5
0
 def add(self, v):
     """
     Store `v` (flattened) in the next row of the ring buffer, wrapping around when full.

     The buffer is allocated lazily on the first call, using v's dtype and element count.
     Every later value must have the same number of elements.

     :param v: value converted with pe.tensor_to_np.
     :raises BufferSizeMismatch: if v's element count differs from the buffer's row width.
     """
     v = pe.tensor_to_np(v)
     # v.size is an int (1 for 0-d arrays); np.prod(()) would return the float 1.0,
     # which np.zeros rejects as a shape entry.
     num_values = v.size
     if self._buffer is None:
         print(f'Creating {v.dtype} buffer for {self._name}: {self._buffer_size}x{num_values}')
         self._buffer = np.zeros((self._buffer_size, num_values), dtype=v.dtype)
     assert_exc(self._buffer.shape[1] == num_values, (self._buffer.shape, v.shape, num_values), BufferSizeMismatch)
     self._buffer[self._idx, :] = v.ravel()
     self._idx = (self._idx + 1) % self._buffer_size
     # Number of valid rows, capped at the buffer size.
     self._filled_idx = min(self._filled_idx + 1, self._buffer_size)
Пример #6
0
 def __getitem__(self, items):
     """
     Allows usage as
     lr, L, num_layers = config['lr', 'L', 'num_layers']

     All requested names are validated before any value is handed out, so an invalid
     name fails at the lookup instead of partway through unpacking.

     :param items: a single parameter name, or a tuple of names.
     :return: an iterator over the requested values, in order.
     :raises AttributeError: if any name is not a parameter.
     """
     if not isinstance(items, tuple):
         items = [items]
     for item in items:
         assert_exc(item in self.__dict__,
                    'Invalid parameter: {}'.format(item), AttributeError)
     # Return an iterator (like the former generator) so next()/iteration callers still work.
     return iter([self.__dict__[item] for item in items])
Пример #7
0
 def get_ckpt_for_itr(self, itr):
     """
     Return (itrc, ckpt_p) of the newest ckpt whose itrc does not exceed `itr`.
     Special value: itr == -1 selects the newest ckpt overall.
     """
     ckpts = list(self.itr_ckpt())
     assert_exc(len(ckpts) > 0, 'No ckpts found in {}'.format(self._out_dir))
     if itr == -1:
         return ckpts[-1]
     earliest_itr, _ = ckpts[0]
     assert_exc(earliest_itr <= itr, 'Earliest ckpt {} is after {}'.format(earliest_itr, itr))
     # Scan newest -> oldest and take the first ckpt not past `itr`.
     match = next(((c_itr, c_path) for c_itr, c_path in reversed(ckpts) if c_itr <= itr), None)
     if match is None:
         raise ValueError('Unexpected, {}, {}'.format(itr, ckpts))
     return match
Пример #8
0
    def decode(self, pin, png_out_p):
        """
        Decode the L3C-encoded file `pin` and write the image as a PNG to `png_out_p`.

        Fails via assert_exc(..., DecodeError) if the output directory does not exist or
        `png_out_p` does not end in '.png'.
        """
        out_dir = os.path.dirname(os.path.abspath(png_out_p))
        dir_ok = os.path.isdir(out_dir)
        assert_exc(dir_ok, f'png_out_p directory ({out_dir}) does not exists!', DecodeError)
        is_png = png_out_p.endswith('.png')
        assert_exc(is_png, f'png_out_p must end in .png, got {png_out_p}', DecodeError)

        self._write_img(self.bc.decode(pin), png_out_p)
        print(f'---\nDecoded: {png_out_p}')
Пример #9
0
    def encode(self, img_p, pout, overwrite=False):
        """
        Read the image at `img_p`, move it to pe.DEVICE and encode it with self.bc into `pout`.

        :param overwrite: if True, an existing file at `pout` is removed first.
        Fails via assert_exc(..., EncodeError) if the output directory does not exist, or if
        `pout` exists and `overwrite` is False.
        """
        out_dir = os.path.dirname(os.path.abspath(pout))
        assert_exc(os.path.isdir(out_dir),
                   f'pout directory ({out_dir}) does not exists!',
                   EncodeError)
        if overwrite:
            if os.path.isfile(pout):
                print(f'Removing {pout}...')
                os.remove(pout)
        assert_exc(not os.path.isfile(pout),
                   f'{pout} exists. Consider --overwrite', EncodeError)

        img = self._read_img(img_p).to(pe.DEVICE)
        self.bc.encode(img, pout=pout)
        print('---\nSaved:', pout)
Пример #10
0
    def __init__(self,
                 root_dir_or_img,
                 max_imgs=None,
                 skip_hidden=False,
                 append_id=None):
        """
        :param root_dir_or_img: Either a directory with images or the path of a single image.
        :param max_imgs: If given, subsample deterministically (evenly spaced over the sorted
            paths) to only contain max_imgs. Only applies when given a directory.
        :param skip_hidden: If given, skip images starting with '.'
        :param append_id: If given, append `append_id` to self.id
        :raises ValueError: if root_dir_or_img is a directory that contains no images.
        :raises FileNotFoundError: if root_dir_or_img is neither a directory nor an existing file.
        """
        self.root_dir_or_img = root_dir_or_img

        if os.path.isdir(root_dir_or_img):
            root_dir = root_dir_or_img
            self.name = os.path.basename(root_dir.rstrip('/'))
            self.ps = sorted(p for p in os_ext.listdir_paths(root_dir)
                             if has_image_ext(p))
            if skip_hidden:
                self.ps = self._filter_hidden(self.ps)
            if max_imgs and max_imgs < len(self.ps):
                print('Subsampling to use {} imgs of {}...'.format(
                    max_imgs, self.name))
                # Builtin int: the np.int alias was deprecated in NumPy 1.20 and removed in 1.24.
                idxs = np.linspace(0, len(self.ps) - 1, max_imgs, dtype=int)
                self.ps = np.array(self.ps)[idxs].tolist()
                assert len(self.ps) == max_imgs
            assert_exc(
                len(self.ps) > 0, 'No images found in {}'.format(root_dir),
                ValueError)
            self.id = '{}_{}'.format(self.name, len(self.ps))
            self._str = 'Testset({}): in {}, {} images'.format(
                self.name, root_dir, len(self.ps))
        else:
            img = root_dir_or_img
            assert_exc(os.path.isfile(img), 'Does not exist: {}'.format(img),
                       FileNotFoundError)
            self.name = os.path.basename(img)
            self.ps = [img]
            self.id = img
            self._str = 'Testset([{}]): 1 image'.format(self.name)
        if append_id:
            self.id += append_id
Пример #11
0
def interpolator(measures_per_image_iter, grid, interp_mode='linear'):
    """
    Average per-image (bpp, value) curves on a common bpp grid.

    Each image's curve is interpolated with scipy.interpolate.interp1d and evaluated at every
    grid point. Grid points where evaluation raises ValueError (with interp1d's defaults: points
    outside that image's bpp range) are skipped for that image. A grid point is kept only if
    more than `_REQUIRED_BINS * num_imgs` images contributed to it.

    :param measures_per_image_iter: iterable of (img_description, (bpps, values)); every bpps
        needs at least two entries and must satisfy bpps[0] >= bpps[-1].
    :param grid: bpp values to evaluate at.
    :param interp_mode: passed as `kind` to interp1d.
    :return: (grid, values) for the kept grid points, as produced by ft.unzip.
    :raises OtherCodecsReadException: on malformed per-image measures.
    :raises ValueError: if ft.unzip fails (presumably when no grid point is kept — confirm).
    """
    accumulated_values = np.zeros_like(grid, np.float64)
    # Number of images that contributed to each grid point
    N = np.zeros_like(grid, np.int64)
    num_imgs = 0

    for img_description, (bpps, values) in measures_per_image_iter:
        assert_exc(
            len(bpps) >= 2, 'Missing values for {}'.format(img_description),
            OtherCodecsReadException)
        assert_exc(bpps[0] >= bpps[-1],
                   f'First bpp < last: {bpps[0]} < {bpps[-1]}',
                   OtherCodecsReadException)

        num_imgs += 1

        # create interpolation function
        try:
            fq = scipy.interpolate.interp1d(bpps, values, interp_mode)
        except ValueError as e:
            raise OtherCodecsReadException('ValueError while creating fq:', e) from e

        for i, bpp in enumerate(grid):
            try:
                interpolated = fq(bpp)
            except ValueError:
                # bpp outside this image's range (interp1d raises by default); skip it.
                continue
            accumulated_values[i] += interpolated
            N[i] += 1

    min_contributions = _REQUIRED_BINS * num_imgs
    # Materialize once so the same data can be reported if unzipping fails.
    kept = [(bpp, v / n)
            for bpp, v, n in zip(grid, accumulated_values, N)
            if n > min_contributions]
    try:
        grid, values = ft.unzip(kept)
    except ValueError as e:
        raise ValueError(grid, accumulated_values, N, kept) from e
    return grid, values
Пример #12
0
def to_image(t):
    """
    Convert a tensor / array into an HW3 uint8 image.

    :param t: tensor or np.ndarray, may be of shape NCHW / CHW with C=1 or 3 / HW, dtype float32 or uint8. If float32:
    must be in [0, 1]
    :return: HW3 uint8 np.ndarray. For NCHW only the first batch element is used, C=1 is
        replicated to 3 channels, and float32 input is scaled by 255 and rounded to the
        nearest integer.
    """
    if not isinstance(t, np.ndarray):
        t = pe.tensor_to_np(t)
    # - t is numpy array
    if t.ndim == 4:
        # - t has batch dimension, only use first
        t = t[0, ...]
    elif t.ndim == 2:
        t = np.expand_dims(t, 0)  # Now 1HW
    assert_exc(t.ndim == 3, 'Invalid shape: {}'.format(t.shape))
    # - t is 3 dimensional CHW numpy array
    if t.dtype != np.uint8:
        assert_exc(t.dtype == np.float32, 'Expected either uint8 or float32, got {}'.format(t.dtype))
        _check_range(t, 0, 1)
        # Round instead of truncating: a float32 k/255 can scale back to slightly below k,
        # which a plain astype(np.uint8) would map to k - 1.
        t = np.round(t * 255.).astype(np.uint8)
    # - t is uint8 numpy array
    num_channels = t.shape[0]
    if num_channels == 3:
        t = np.transpose(t, (1, 2, 0))
    elif num_channels == 1:
        t = np.stack([t[0, :, :] for _ in range(3)], -1)
    else:
        raise ValueError('Expected CHW, got {}'.format(t.shape))
    assert_exc(t.ndim == 3 and t.shape[2] == 3, str(t.shape))
    # - t is uint8 numpy array of shape HW3
    return t
Пример #13
0
def create_curves_for_images(root_dir, out_dir, grid, mode):
    """
    Write a measures file into `out_dir` for every .png image found under `root_dir`.

    Images are split across task-array jobs via task_array.job_enumerate. Images whose
    measures file already holds len(grid) complete entries are skipped.

    :param grid: passed to the codec's measure function.
    :param mode: codec key: 'bpg', 'balle', 'jp2k', 'jp' or 'webp' (KeyError otherwise).
    """
    # Exclude 'tmp' here so every job sees the same file list; otherwise
    # task_array.job_enumerate would split a different number of files in each job.
    all_img_ps = _get_image_paths(root_dir, exclude={'tmp'})
    assert_exc(len(all_img_ps) > 0, 'No images found', ValueError)

    num_non_png = sum(1 for p in all_img_ps if not p.endswith('.png'))
    assert_exc(
        num_non_png == 0,
        f'Only .pngs are supported by this code! Found {num_non_png} others.'
    )

    measure_fns = {
        'bpg': bpg_measure_over_interval,
        'balle': balle_measure_over_interval,
        'jp2k': jp2k_measure_over_interval,
        'jp': jp_measure_over_interval,
        'webp': webp_measure_over_interval,
    }
    measure_over_interval = measure_fns[mode]

    durations = []
    for i, img_p in task_array.job_enumerate(all_img_ps):
        base = os.path.basename(img_p)
        print('>>>', task_array.TASK_ID, 'compresses', base)
        img_name = os.path.splitext(base)[0]
        start = time.time()
        mf = measures_file_p(out_dir, img_name)
        if complete_measures_file_exists(mf, num_ops=len(grid)):
            print(f'Found output for {img_name}, skipping...')
            continue
        with open(mf, 'w+') as f:
            measure_over_interval(img_p, f, grid)
        durations.append(time.time() - start)
        # Estimate from the 15 most recently processed images.
        avg_time = np.mean(durations[-15:])
        remaining_min = avg_time * (len(all_img_ps) - i) / 60
        print('Time left: {:.2f}min'.format(remaining_min))
Пример #14
0
def _parse_exp_spec(s, optims, initial_lr, epoch_len):
    """
    Parse an exponential-decay LR spec into an ExponentialDecayLRSchedule.

    Accepted forms (fields separated by SPEC_SEP):
      <fac> <kind><interval>
      <fac> <kind><interval> warm <warm_start> <warm_fac> <warm_kind><warm_interval>
    where kind is 'i' (interval in iterations) or 'e' (interval in epochs). The trailing
    "<warm_fac> <warm_kind><warm_interval>" part is parsed recursively and passed as
    warm_restart_schedule; warm_start is passed as warm_restart.
    """
    if s.count(SPEC_SEP) > 2:
        fac, interval, warm, warm_start, warm_fac, warm_interval = s.split(
            SPEC_SEP)
        # assert_exc instead of a bare assert, so the check survives `python -O`.
        assert_exc(warm == 'warm', 'Invalid spec: {}'.format(s))
        warm_start = int(warm_start)
        warm_schedule = _parse_exp_spec(
            SPEC_SEP.join([warm_fac, warm_interval]), optims, initial_lr,
            epoch_len)
    else:
        fac, interval = s.split(SPEC_SEP)
        warm_start, warm_schedule = None, None
    kind, interval = interval[0], interval[1:]
    assert_exc(kind in ('i', 'e'), 'Invalid spec: {}'.format(s))
    decay_interval_itr = int(interval) if kind == 'i' else None
    decay_interval_epoch = float(interval) if kind == 'e' else None
    return ExponentialDecayLRSchedule(optims,
                                      initial_lr,
                                      float(fac),
                                      decay_interval_itr,
                                      decay_interval_epoch,
                                      epoch_len,
                                      warm_restart=warm_start,
                                      warm_restart_schedule=warm_schedule)
Пример #15
0
def get_experiment_dir(log_dir, experiment_spec):
    """
    Resolve `experiment_spec` to an existing experiment directory.

    experiment_spec: if it is a logdate, find the single entry in log_dir whose name starts
    with it; otherwise assume log_dir/experiment_spec exists (or experiment_spec itself when
    log_dir is None).
    :return experiment dir, no slash at the end. containing /ckpts
    """
    if logdir_helpers.is_log_date(experiment_spec):  # assume that log_dir/restore* matches
        assert_exc(log_dir is not None, 'Can only infer experiment_dir from log_date if log_dir is not None')
        # Escape both parts so glob metacharacters in the paths match literally; only the
        # trailing '*' acts as a wildcard.
        restore_dir_glob = os.path.join(glob.escape(log_dir), glob.escape(experiment_spec) + '*')
        restore_dir_possible = glob.glob(restore_dir_glob)
        assert_exc(len(restore_dir_possible) == 1, 'Expected one match for {}, got {}'.format(
                restore_dir_glob, restore_dir_possible))
        experiment_spec = restore_dir_possible[0]
    elif log_dir is not None:
        experiment_spec = os.path.join(log_dir, experiment_spec)
    # else: with log_dir None, experiment_spec is used as a path on its own
    # (os.path.join(None, ...) would raise TypeError).
    experiment_dir = experiment_spec.rstrip(os.path.sep)
    assert_exc(os.path.isdir(experiment_dir), 'Invalid experiment_dir: {}'.format(experiment_dir))
    return experiment_dir
Пример #16
0
def get_ckpts_dir(experiment_dir, ensure_exists=True):
    """
    Return the path `experiment_dir`/CKPTS_DIR_NAME.

    :param ensure_exists: if True, fail via assert_exc when that path is not a directory.
    """
    ckpts_p = os.path.join(experiment_dir, CKPTS_DIR_NAME)
    if not ensure_exists:
        return ckpts_p
    assert_exc(os.path.isdir(ckpts_p), 'Not found: {}'.format(ckpts_p))
    return ckpts_p
Пример #17
0
def _assert_dir_exists_in_cwd(p):
    """Fail via assert_exc unless `p` is a directory; the message names the current cwd."""
    is_dir = os.path.isdir(p)
    assert_exc(is_dir, f'{p} not in {os.getcwd()}')
Пример #18
0
 def filter_filenames(self, filter_filenames):
     """
     Keep only the paths in self.ps whose file name without extension is in `filter_filenames`.

     :param filter_filenames: collection of bare names (no directory, no extension) to keep.
     Fails via assert_exc if nothing survives; self.ps is left unchanged in that case.
     """
     def _stem(p):
         # 'some/dir/img.png' -> 'img'
         return os.path.splitext(os.path.basename(p))[0]

     kept = [p for p in self.ps if _stem(p) in filter_filenames]
     # Validate before assigning so a failed filter does not wipe self.ps.
     assert_exc(
         len(kept) > 0,
         'No files after filtering for {}'.format(filter_filenames))
     self.ps = kept
Пример #19
0
def _check_range(a, lo, hi):
    """Fail via assert_exc unless every value of `a` lies in the closed interval [lo, hi]."""
    a_lo = np.min(a)
    a_hi = np.max(a)
    in_range = lo <= a_lo and a_hi <= hi
    assert_exc(in_range, 'Invalid range: [{}, {}]. Expected: [{}, {}]'.format(a_lo, a_hi, lo, hi))