コード例 #1
0
def cached_listdir_imgs(p, min_size=None, discard_shitty=True) -> Images:
    """Collect image paths below `p`, optionally filtered by size and quality.

    :param p: one of
        - a directory path;
        - a list of specs, each listed recursively and combined via `_joined`;
        - a tuple `(spec, resample)`: an int `resample` repeats the listing of
          `spec` via `Images.repeat`; a float in [0, 1) keeps
          `int(resample * len(images))` entries via `Images.subsample`.
    :param min_size: if truthy, images whose `smallest_size` is below it are
        skipped.
    :param discard_shitty: if True, images whose `shitty` flag is set are
        skipped.
    :return: an `Images` collection whose id encodes the directory name,
        `min_size` and `discard_shitty`.
    :raises NotADirectoryError: if a plain path spec is not a directory.
    :raises ValueError: if a tuple's `resample` is neither an int nor a float.
    """
    def _recurse(spec):
        # Same filters are applied at every level of nesting.
        return cached_listdir_imgs(spec, min_size, discard_shitty)

    if isinstance(p, list):
        return _joined(_recurse(spec) for spec in p)

    if isinstance(p, tuple):
        inner, resample = p
        # int is checked first, so bools are treated as repeat counts.
        if isinstance(resample, int):
            return _recurse(inner).repeat(resample)
        if isinstance(resample, float):
            assert 0 <= resample < 1, resample
            imgs = _recurse(inner)
            return imgs.subsample(int(resample * len(imgs)))
        raise ValueError('Invalid type for resample:', resample)

    if not os.path.isdir(p):
        raise NotADirectoryError(p)

    def _keep(img):
        if min_size and img.smallest_size < min_size:
            return False
        return not (discard_shitty and img.shitty)

    label = f'>>> filter [min_size={min_size}; discard_s={discard_shitty}]'
    with timer.execute(label):
        kept = [img.full_p for img in _iter_imgs(p) if _keep(img)]

    dir_name = os.path.basename(p.rstrip(os.path.sep))
    return Images(kept, id=f'{dir_name}_{min_size}_dS={discard_shitty}')
コード例 #2
0
    def validation_step(self, i, kind):
        """Run one validation pass at global step `i`.

        :param i: global step; passed to the summarizer, used in the `TestID`,
            printed, and used as the step for `self.sw.add_scalar`.
        :param kind: 'fixed_first' -- forward `self.fixed_first_val` through the
            blueprint with the summarizer enabled (prefix 'val').
            'validation_set' -- evaluate every image of `self.ds_val`, record
            per-image loss and accuracy in a `TestResults`, store it in the
            multiscale tester's output cache, print the means and log them as
            scalars under 'val_set/<dataset id>/<key>'.
        :raises ValueError: if `kind` is neither of the above.
        """
        if kind not in ('fixed_first', 'validation_set'):
            raise ValueError('Invalid kind', kind)

        if kind == 'fixed_first':
            with self.summarizer.enable(prefix='val', global_step=i):
                # The output is not used here; presumably forward() emits its
                # summaries through the enabled summarizer -- confirm.
                x_n, _ = self.blueprint.unpack_batch_pad(self.fixed_first_val)
                self.blueprint.forward(x_n)
            return

        ds_val = self.ds_val
        with timer.execute(f'>>> Running on {len(ds_val)} images of validation set [s]'):
            blueprint = self.blueprint
            test_id = TestID(ds_val.id, i)
            test_results = TestResults()

            for idx, img in enumerate(ds_val):
                filename = ds_val.get_filename(idx)
                x_n, q = blueprint.unpack_batch_pad(img)
                out = blueprint.forward(x_n)
                loss = blueprint.loss(out.q_logits, q)
                test_results.set_from_loss(loss, filename)
                accuracy = blueprint.get_accuracy(out.q_logits, q)
                test_results.set(filename, 'acc', accuracy)

            _, test_output_cache = multiscale_tester.get_test_log_dir_and_cache(
                self.log_dir_root, os.path.basename(self.log_dir))
            test_output_cache[test_id] = test_results
            print(f'VALIDATE  {i: 6d}: {test_results.means_str()}')
            for key, value in test_results.means_dict().items():
                self.sw.add_scalar(f'val_set/{ds_val.id}/{key}', value, i)
コード例 #3
0
def cached_listdir_imgs_max(p, max_size=None, discard_shitty=True):
    """List image paths in directory `p`, keeping only images below a size cap.

    :param p: directory to scan.
    :param max_size: if truthy, images whose `smallest_size` is >= `max_size`
        are dropped and counted; a falsy value (None or 0) disables the cap.
    :param discard_shitty: if True, images whose `shitty` flag is set are also
        dropped (these are not included in the printed count).
    :return: an `Images` collection of the kept full paths, with id
        '<dirname>_<max_size>_dS=<discard_shitty>'.
    :raises NotADirectoryError: if `p` is not a directory.
    """
    # Same fail-fast check as cached_listdir_imgs: reject a bad path with a
    # clear error instead of handing it to _iter_imgs.
    if not os.path.isdir(p):
        raise NotADirectoryError(p)

    ps = []
    filtered_max = 0
    with timer.execute(
            f'>>> filter [max_size={max_size}; discard_s={discard_shitty}]'):
        for img in _iter_imgs(p):
            if max_size and img.smallest_size >= max_size:
                filtered_max += 1  # only size-filtered images are counted
                continue
            if discard_shitty and img.shitty:
                continue
            ps.append(img.full_p)
    print('Filtered', filtered_max, 'imgs!')
    return Images(
        ps,
        id=
        f'{os.path.basename(p.rstrip(os.path.sep))}_{max_size}_dS={discard_shitty}'
    )
コード例 #4
0
ファイル: bit_counter.py プロジェクト: cmatija/DL_project
def encode_decode_to_file_ctx(syms,
                              prediction_net: probclass.PredictionNetwork,
                              syms_format='HWC',
                              verbose=False):
    """
    Encode symbols with arithmetic coding to a temporary file, decode them
    again and check that the round trip reproduces the input.

    :param syms: HWC or CHW depending on syms_format, symbols of one image. Or BHWC, BCHW, in which case each
    image is coded separately and the number of bits needed for all batches is returned.
    :param prediction_net: network supplying `input_ctx_shape`, `get_freqs`, `get_pr` and the
    `pad_symbols_volume` / `undo_pad_symbols_volume` helpers used for encoding and decoding.
    :param syms_format: 'HWC' or 'CHW'; HWC input is transposed to CHW before encoding.
    :param verbose: if True, print progress messages.
    :return: number of bits to encode all symbols in `syms` (size of the encoded file * 8,
    summed over batches for 4D input)
    :raises AssertionError: if the input has the wrong rank/format, the virtual bit count deviates
    from the theoretical cost by 50 or more bits, the file size disagrees with the virtual bit count,
    or the decoded symbols differ from the input.
    """
    _print = print if verbose else no_op.NoOp()

    if len(syms.shape) == 4:
        # Batched input: code each image on its own and sum the bit counts.
        num_batches = syms.shape[0]
        return np.sum([
            encode_decode_to_file_ctx(syms[b, ...], prediction_net,
                                      syms_format, verbose)
            for b in range(num_batches)
        ])

    assert len(syms.shape) == 3, 'Expected HWC or CHW'
    assert syms_format in ('HWC', 'CHW')
    if syms_format == 'HWC':
        _print('Transposing symbols for encoding...')
        syms = np.transpose(syms, (2, 0, 1))

    # ---
    _print('Preparing encode...')

    foutid, fout_p = tempfile.mkstemp()
    try:
        ctx_shape = prediction_net.input_ctx_shape
        get_freqs = ft.compose(ac.SimpleFrequencyTable, prediction_net.get_freqs)
        get_pr = prediction_net.get_pr

        # encode
        with timer.execute('Encoding time [s]'):
            _print(
                'Encoding symbols of shape {} ({} symbols) with context shape {}...'
                .format(syms.shape, np.prod(syms.shape), ctx_shape))

            syms_padded = prediction_net.pad_symbols_volume(syms)
            virtual_num_bits, first_sym, theoretical_bit_cost = _encode(
                foutid, syms_padded, ctx_shape, get_freqs, get_pr, _print)
            assert abs(virtual_num_bits - theoretical_bit_cost
                       ) < 50, 'Virtual: {} -- Theoretical: {}'.format(
                           virtual_num_bits, theoretical_bit_cost)

        # bit count
        actual_num_bits = os.path.getsize(fout_p) * 8
        assert actual_num_bits == virtual_num_bits, '{} != {}'.format(
            actual_num_bits, virtual_num_bits)

        # decode
        with timer.execute('Decoding time [s]'):
            _print('Decoding symbols to shape {}, first_sym={}...'.format(
                syms_padded.shape, first_sym))

            syms_dec_padded = _decode(fout_p, syms_padded.shape, ctx_shape,
                                      first_sym, get_freqs, _print)
            syms_dec = prediction_net.undo_pad_symbols_volume(syms_dec_padded)

        # checkin' (takes no time)
        np.testing.assert_array_equal(syms, syms_dec)
        _print('Decoded symbols match input!')

        return actual_num_bits
    finally:
        # Remove the temp file even when an assertion or the decoder fails;
        # previously a failing run left the file behind in the temp dir.
        os.remove(fout_p)