def cached_listdir_imgs(p, min_size=None, discard_shitty=True) -> Images:
    """List images under `p`, filtered by `min_size` and `discard_shitty`.

    `p` may be a directory path, a list of paths (results are joined), or a
    tuple `(path, resample)`: an int `resample` repeats the images that many
    times, a float in [0, 1) keeps that fraction as a subsample.
    """
    if isinstance(p, list):
        return _joined(cached_listdir_imgs(p_, min_size, discard_shitty)
                       for p_ in p)
    if isinstance(p, tuple):
        p, resample = p
        if isinstance(resample, int):
            return cached_listdir_imgs(p, min_size, discard_shitty).repeat(resample)
        if isinstance(resample, float):
            assert 0 <= resample < 1, resample
            images = cached_listdir_imgs(p, min_size, discard_shitty)
            subsample = int(resample * len(images))
            return images.subsample(subsample)
        raise ValueError('Invalid type for resample:', resample)
    if not os.path.isdir(p):
        raise NotADirectoryError(p)
    ps = []
    with timer.execute(f'>>> filter [min_size={min_size}; discard_s={discard_shitty}]'):
        for img in _iter_imgs(p):
            if min_size and img.smallest_size < min_size:
                continue
            if discard_shitty and img.shitty:
                continue
            ps.append(img.full_p)
    return Images(
        ps,
        id=f'{os.path.basename(p.rstrip(os.path.sep))}_{min_size}_dS={discard_shitty}')
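
# Usage sketch (hypothetical paths; illustrates the list/tuple dispatch above,
# where an int `resample` repeats a directory's images and a float in [0, 1)
# keeps that fraction as a subsample):
#
#   train_imgs = cached_listdir_imgs(
#       ['/data/train_a',           # plain directory
#        ('/data/train_b', 3),      # repeated 3x via Images.repeat
#        ('/data/train_c', 0.25)],  # 25% subsample via Images.subsample
#       min_size=128, discard_shitty=True)
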
def validation_step(self, i, kind):
    if kind == 'fixed_first':
        # Quick check: run on the fixed first validation batch and log summaries.
        with self.summarizer.enable(prefix='val', global_step=i):
            x_n, q = self.blueprint.unpack_batch_pad(self.fixed_first_val)
            out: ClassifierOut = self.blueprint.forward(x_n)
    elif kind == 'validation_set':
        # Full pass: evaluate every image, cache per-image results, log means.
        with timer.execute(f'>>> Running on {len(self.ds_val)} images of validation set [s]'):
            test_id = TestID(self.ds_val.id, i)
            test_results = TestResults()
            for idx, img in enumerate(self.ds_val):
                filename = self.ds_val.get_filename(idx)
                x_n, q = self.blueprint.unpack_batch_pad(img)
                out: ClassifierOut = self.blueprint.forward(x_n)
                loss = self.blueprint.loss(out.q_logits, q)
                test_results.set_from_loss(loss, filename)
                test_results.set(filename, 'acc',
                                 self.blueprint.get_accuracy(out.q_logits, q))
            _, test_output_cache = multiscale_tester.get_test_log_dir_and_cache(
                self.log_dir_root, os.path.basename(self.log_dir))
            test_output_cache[test_id] = test_results
            print(f'VALIDATE {i: 6d}: {test_results.means_str()}')
            for key, value in test_results.means_dict().items():
                self.sw.add_scalar(f'val_set/{self.ds_val.id}/{key}', value, i)
    else:
        raise ValueError('Invalid kind', kind)
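
# Usage sketch (hypothetical `trainer` instance; the two kinds dispatched above):
#
#   trainer.validation_step(i=global_step, kind='fixed_first')     # one fixed batch, summaries only
#   trainer.validation_step(i=global_step, kind='validation_set')  # full set, cached + per-key scalars
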
def cached_listdir_imgs_max(p, max_size=None, discard_shitty=True):
    """Like cached_listdir_imgs, but keeps only images *smaller* than `max_size`."""
    ps = []
    filtered_max = 0
    with timer.execute(f'>>> filter [max_size={max_size}; discard_s={discard_shitty}]'):
        for img in _iter_imgs(p):
            if max_size and img.smallest_size >= max_size:
                filtered_max += 1
                continue
            if discard_shitty and img.shitty:
                continue
            ps.append(img.full_p)
    print('Filtered', filtered_max, 'imgs!')
    return Images(
        ps,
        id=f'{os.path.basename(p.rstrip(os.path.sep))}_{max_size}_dS={discard_shitty}')
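
# Usage sketch (hypothetical path): complements cached_listdir_imgs by keeping
# only images whose smallest side is *below* max_size:
#
#   small_imgs = cached_listdir_imgs_max('/data/train', max_size=512,
#                                        discard_shitty=True)
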
def encode_decode_to_file_ctx(syms, prediction_net: probclass.PredictionNetwork,
                              syms_format='HWC', verbose=False):
    """
    Encode symbols with arithmetic coding to disk, then decode and verify.

    :param syms: HWC or CHW depending on syms_format, symbols of one image.
        Or BHWC, BCHW, in which case the number of bits needed for all batches
        is returned.
    :param prediction_net: network providing the probability model (must be the
        same for encoding and decoding for arithmetic coding to be correct).
    :return: number of bits to encode all symbols in `syms`
    """
    _print = print if verbose else no_op.NoOp()

    if len(syms.shape) == 4:  # batched input: encode each image separately
        num_batches = syms.shape[0]
        return np.sum([
            encode_decode_to_file_ctx(syms[b, ...], prediction_net,
                                      syms_format, verbose)
            for b in range(num_batches)])

    assert len(syms.shape) == 3, 'Expected HWC or CHW'
    assert syms_format in ('HWC', 'CHW')
    if syms_format == 'HWC':
        _print('Transposing symbols for encoding...')
        syms = np.transpose(syms, (2, 0, 1))

    # ---
    _print('Preparing encode...')
    foutid, fout_p = tempfile.mkstemp()
    ctx_shape = prediction_net.input_ctx_shape
    get_freqs = ft.compose(ac.SimpleFrequencyTable, prediction_net.get_freqs)
    get_pr = prediction_net.get_pr

    # encode
    with timer.execute('Encoding time [s]'):
        _print('Encoding symbols of shape {} ({} symbols) with context shape {}...'
               .format(syms.shape, np.prod(syms.shape), ctx_shape))
        syms_padded = prediction_net.pad_symbols_volume(syms)
        virtual_num_bits, first_sym, theoretical_bit_cost = _encode(
            foutid, syms_padded, ctx_shape, get_freqs, get_pr, _print)
        assert abs(virtual_num_bits - theoretical_bit_cost) < 50, \
            'Virtual: {} -- Theoretical: {}'.format(virtual_num_bits,
                                                    theoretical_bit_cost)

    # bit count
    actual_num_bits = os.path.getsize(fout_p) * 8
    assert actual_num_bits == virtual_num_bits, '{} != {}'.format(
        actual_num_bits, virtual_num_bits)

    # decode
    with timer.execute('Decoding time [s]'):
        _print('Decoding symbols to shape {}, first_sym={}...'.format(
            syms_padded.shape, first_sym))
        syms_dec_padded = _decode(fout_p, syms_padded.shape, ctx_shape,
                                  first_sym, get_freqs, _print)
        syms_dec = prediction_net.undo_pad_symbols_volume(syms_dec_padded)

    # checkin' (takes no time)
    np.testing.assert_array_equal(syms, syms_dec)
    _print('Decoded symbols match input!')

    # cleanup
    os.remove(fout_p)

    return actual_num_bits
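
# Usage sketch (assumes a trained `prediction_net` and a symbol volume `syms`
# of shape (H, W, C); neither name is defined in this module):
#
#   num_bits = encode_decode_to_file_ctx(syms, prediction_net,
#                                        syms_format='HWC', verbose=True)
#   bpp = num_bits / (syms.shape[0] * syms.shape[1])  # bits per pixel
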