def _read_img(self, img_p):
    img = np.array(Image.open(img_p)).transpose(2, 0, 1)  # Turn into CHW
    C, H, W = img.shape
    # Check number of channels
    if C == 4:
        print('*** WARN: Will discard 4th (alpha) channel.')
        img = img[:3, ...]
    elif C != 3:
        raise EncodeError(f'Image has {C} channels, expected 3 or 4.')
    # Convert to 1CHW torch tensor
    img = torch.from_numpy(img).unsqueeze(0).long()
    # Check padding
    padding = self._padding_fac()
    if H % padding != 0 or W % padding != 0:
        print(f'*** WARN: image shape ({H}x{W}) not divisible by {padding}. Will pad...')
        img = MultiscaleBlueprint.pad(img, fac=padding)
    return img
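# Rough worked example of the divisibility check above (illustrative numbers, not
# from the repo): the padding factor is 2**num_scales (cf. `encode` further down),
# so with num_scales = 3 it is 8, and a 3x500x333 image would be padded up to
# 3x504x336 before it can go through the multiscale network.
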
class MultiscaleTrainer(Trainer):
    def __init__(self,
                 ms_config_p, dl_config_p,
                 log_dir_root, log_config: LogConfig,
                 num_workers,
                 saver: Saver, restorer: TrainRestorer=None,
                 sw_cls=vis.safe_summary_writer.SafeSummaryWriter):
        """
        :param ms_config_p: Path to the multiscale config file, see README
        :param dl_config_p: Path to the dataloader config file, see README
        :param log_dir_root: All outputs (checkpoints, tensorboard) will be saved here.
        :param log_config: Instance of train.trainer.LogConfig, contains intervals.
        :param num_workers: Number of workers to use for DataLoading, see train.py
        :param saver: Saver instance to use.
        :param restorer: Instance of TrainRestorer, if we need to restore
        """

        # Read configs
        # config_ms = config for the network (ms = multiscale)
        # config_dl = config for data loading
        (self.config_ms, self.config_dl), rel_paths = ft.unzip(map(config_parser.parse, [ms_config_p, dl_config_p]))
        # Update config_ms depending on global_config
        global_config.update_config(self.config_ms)
        # Create data loaders
        dl_train, dl_val = self._get_dataloaders(num_workers)
        # Create blueprint. A blueprint collects the network as well as the losses in one class, for easy reuse
        # during testing.
        self.blueprint = MultiscaleBlueprint(self.config_ms)
        print('Network:', self.blueprint.net)
        # Setup optimizer
        optim_cls = {'RMSprop': optim.RMSprop,
                     'Adam': optim.Adam,
                     'SGD': optim.SGD,
                     }[self.config_ms.optim]
        net = self.blueprint.net
        self.optim = optim_cls(net.parameters(), self.config_ms.lr.initial,
                               weight_decay=self.config_ms.weight_decay)
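        # For example (hypothetical config values): with config_ms.optim == 'Adam',
        # lr.initial == 1e-4 and weight_decay == 0, the line above amounts to
        #   self.optim = optim.Adam(net.parameters(), 1e-4, weight_decay=0)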
        # Calculate a rough estimate of the time per batch (does not take into account that CUDA is async,
        # but good enough to get a feeling during training).
        self.time_accumulator = timer.TimeAccumulator()
        # Restore network if requested
        skip_to_itr = self.maybe_restore(restorer)
        if skip_to_itr is not None:  # i.e., we have a restorer
            print('Skipping to {}...'.format(skip_to_itr))
        # Create LR schedule to update parameters
        self.lr_schedule = lr_schedule.from_spec(
                self.config_ms.lr.schedule, self.config_ms.lr.initial, [self.optim], epoch_len=len(dl_train))

        # --- All nn.Modules are setup ---
        print('-' * 80)

        # create log dir and summary writer
        self.log_dir = Trainer.get_log_dir(log_dir_root, rel_paths, restorer)
        self.log_date = logdir_helpers.log_date_from_log_dir(self.log_dir)
        self.ckpt_dir = os.path.join(self.log_dir, CKPTS_DIR_NAME)
        print(f'Checkpoints will be saved to {self.ckpt_dir}')
        saver.set_out_dir(self.ckpt_dir)


        # Create summary writer
        sw = sw_cls(self.log_dir)
        self.summarizer = vis.summarizable_module.Summarizer(sw)
        net.register_summarizer(self.summarizer)
        self.blueprint.register_summarizer(self.summarizer)
        # superclass setup
        super(MultiscaleTrainer, self).__init__(dl_train, dl_val, [self.optim], net, sw,
                                                max_epochs=self.config_dl.max_epochs,
                                                log_config=log_config, saver=saver, skip_to_itr=skip_to_itr)

    def modules_to_save(self):
        return {'net': self.blueprint.net,
                'optim': self.optim}
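    # Minimal sketch of how a Saver might consume this mapping (assumed behaviour,
    # not necessarily this project's Saver implementation):
    #   state = {name: m.state_dict() for name, m in trainer.modules_to_save().items()}
    #   torch.save(state, os.path.join(ckpt_dir, 'ckpt_{:010d}.pt'.format(itr)))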

    def _get_dataloaders(self, num_workers, shuffle_train=True):
        assert self.config_dl.train_imgs_glob is not None
        print('Cropping to {}'.format(self.config_dl.crop_size))
        to_tensor_transform = transforms.Compose(
                [transforms.RandomCrop(self.config_dl.crop_size),
                 transforms.RandomHorizontalFlip(),
                 images_loader.IndexImagesDataset.to_tensor_uint8_transform()])
        # NOTE: if there are images in your training set with dimensions < 128, training will abort at some point,
        # because the cropper fails. See README, section about data preparation.
        min_size = self.config_dl.crop_size
        if min_size <= 128:
            min_size = None
        ds_train = images_loader.IndexImagesDataset(
                images=images_loader.ImagesCached(
                        self.config_dl.train_imgs_glob,
                        self.config_dl.image_cache_pkl,
                        min_size=min_size),
                to_tensor_transform=to_tensor_transform)

        dl_train = DataLoader(ds_train, self.config_dl.batchsize_train, shuffle=shuffle_train,
                              num_workers=num_workers)
        print('Created DataLoader [train] {} batches -> {} imgs'.format(
                len(dl_train), self.config_dl.batchsize_train * len(dl_train)))

        ds_val = self._get_ds_val(
                self.config_dl.val_glob,
                crop=self.config_dl.crop_size,
                truncate=self.config_dl.num_val_batches * self.config_dl.batchsize_val)
        dl_val = DataLoader(
                ds_val, self.config_dl.batchsize_val, shuffle=False,
                num_workers=num_workers, drop_last=True)
        print('Created DataLoader [val] {} batches -> {} imgs'.format(
                len(dl_val), self.config_dl.batchsize_val * len(dl_val)))

        return dl_train, dl_val
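    # Shape example (crop_size=128 and batchsize_train=30 are made-up values): each
    # batch from dl_train is a dict whose 'raw' entry is a uint8 tensor of shape
    # [30, 3, 128, 128], randomly cropped and horizontally flipped by the transform above.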

    def _get_ds_val(self, images_spec, crop=False, truncate=False):
        img_to_tensor_t = [images_loader.IndexImagesDataset.to_tensor_uint8_transform()]
        if crop:
            img_to_tensor_t.insert(0, transforms.CenterCrop(crop))
        img_to_tensor_t = transforms.Compose(img_to_tensor_t)

        fixed_first = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'fixedimg.jpg')
        if not os.path.isfile(fixed_first):
            print(f'INFO: No file found at {fixed_first}')
            fixed_first = None

        ds = images_loader.IndexImagesDataset(
                images=images_loader.ImagesCached(
                        images_spec, self.config_dl.image_cache_pkl,
                        min_size=self.config_dl.val_glob_min_size),
                to_tensor_transform=img_to_tensor_t,
                fixed_first=fixed_first)  # fix the first image for consistency in TensorBoard

        if truncate:
            ds = pe.TruncatedDataset(ds, num_elemens=truncate)

        return ds

    def train_step(self, i, batch, log, log_heavy, load_time=None):
        """
        :param i: current step
        :param batch: dict with 'idx', 'raw'
        :param log: if True, write (light) summaries and console output for this step
        :param log_heavy: if True, also write the heavy (image) summaries
        :param load_time: optional string with the data-loading time, appended to the console output
        """
        self.lr_schedule.update(i)
        self.net.zero_grad()

        values = Values('{:.3e}', ' | ')

        with self.time_accumulator.execute():
            idxs, img_batch, s = self.blueprint.unpack(batch)

            with self.summarizer.maybe_enable(prefix='train', flag=log, global_step=i):
                out = self.blueprint.forward(img_batch)

            with self.summarizer.maybe_enable(prefix='train', flag=log_heavy, global_step=i):
                loss_pc, nonrecursive_bpsps, _ = self.blueprint.get_loss(out)

            total_loss = loss_pc
            total_loss.backward()
            self.optim.step()

            values['loss'] = loss_pc
            values['bpsp'] = sum(nonrecursive_bpsps)

        if not log:
            return

        mean_time_per_batch = self.time_accumulator.mean_time_spent()
        imgs_per_second = self.config_dl.batchsize_train / mean_time_per_batch
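        # e.g. with batchsize_train=30 and a mean batch time of 0.5 s this gives
        # 60 img/s; only a rough figure, since CUDA executes asynchronously.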

        print('{} {: 6d}: {} // {:.3f} img/s '.format(
                self.log_date, i, values.get_str(), imgs_per_second) + (load_time or ''))

        values.write(self.sw, i)

        # Gradients
        params = [('all', self.net.parameters())]
        for name, ps in params:
            tot = pe.get_total_grad_norm(ps)
            self.sw.add_scalar('grads/{}/total'.format(name), tot, i)

        # log LR
        lrs = list(self.get_lrs())
        assert len(lrs) == 1
        self.sw.add_scalar('train/lr', lrs[0], i)

        if not log_heavy:
            return

        self.blueprint.add_image_summaries(self.sw, out, i, 'train')

    def validation_loop(self, i):
        bs = pe.BatchSummarizer(self.sw, i)
        val_start = time.time()
        for j, batch in enumerate(self.dl_val):
            idxs, img_batch, s = self.blueprint.unpack(batch)

            # Only log TB summaries for first batch
            with self.summarizer.maybe_enable(prefix='val', flag=j == 0, global_step=i):
                out = self.blueprint.forward(img_batch)
                loss_pc, nonrecursive_bpsps, _ = self.blueprint.get_loss(out)

            bs.append('val/bpsp', sum(nonrecursive_bpsps))

            if j > 0:
                continue

            self.blueprint.add_image_summaries(self.sw, out, i, 'val')

        val_duration = time.time() - val_start
        num_imgs = len(self.dl_val.dataset)
        time_per_img = val_duration/num_imgs

        output_strs = bs.output_summaries()
        output_strs = ['{: 6d}'.format(i)] + output_strs + ['({:.3f} s/img)'.format(time_per_img)]
        output_str = ' | '.join(output_strs)
        sep = '-' * len(output_str)
        print('\n'.join([sep, output_str, sep]))
    def encode(self, img, pout):
        """
        Encode image to disk at path `pout`.
        :param img: uint8 tensor of shape CHW or 1CHW
        :param pout: path
        :return actual_bpsp
        """
        assert not os.path.isfile(pout)
        if len(img.shape) == 3:
            img = img.unsqueeze(0)  # 1CHW
        assert len(img.shape) == 4 and img.shape[0] == 1 and img.shape[1] == 3, img.shape
        assert img.dtype == torch.int64, img.dtype

        if auto_crop.needs_crop(img):
            print('Need to encode individual crops!')

            c = auto_crop.CropLossCombinator()
            for i, img_crop in enumerate(auto_crop.iter_crops(img)):
                bpsp_crop = self.encode(
                    img_crop, pout + part_suffix_helper.make_part_suffix(i))
                c.add(bpsp_crop, np.prod(img_crop.shape[-2:]))
            return c.get_bpsp()
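            # Presumably the combinator returns the pixel-weighted mean of the per-crop
            # rates, e.g. crops of 512x512 at 2.8 bpsp and 512x256 at 3.2 bpsp combine
            # to (2.8*262144 + 3.2*131072) / 393216 ~= 2.93 bpsp.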

        # TODO: Note that recursive is not supported.
        padding = 2**self.blueprint.net.config_ms.num_scales
        _, _, H, W = img.shape
        if H % padding != 0 or W % padding != 0:
            print(f'*** INFO: image shape ({H}x{W}) not divisible by {padding}, will pad.')
            img, padding_tuple = pad.pad(
                img, fac=padding, mode=MultiscaleBlueprint.get_padding_mode())
        else:
            padding_tuple = (0, 0, 0, 0)
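        # padding_tuple is written into the bitstream (write_padding_tuple below) so the
        # decoder can strip the padding again, e.g. with num_scales=3 (padding=8) a
        # 1x3x373x500 image is padded to 1x3x376x504 and the added rows/columns are recorded.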

        img = img.float()

        with self.times.run('[-] encode forwardpass'):
            out = self.blueprint.net(img)

        if self.compare_with_theory:
            with self.times.run('[-] get loss'):
                loss_out = self.blueprint.get_loss(out)

        self.blueprint.net.zero_grad()

        entropy_coding_bytes = []  # bytes used by different scales

        with open(pout, 'wb') as fout:
            write_padding_tuple(padding_tuple, fout)
            for scale, dmll, uniform in self.iter_scale_dmll():
                with self.times.prefix_scope(f'[{scale}]'):
                    if uniform:
                        entropy_coding_bytes.append(
                            self.encode_uniform(dmll, out.S[scale], fout))
                    else:
                        entropy_coding_bytes.append(
                            self.encode_scale(scale, dmll, out, img, fout))
                    fout.write(_MAGIC_VALUE_SEP)

        num_subpixels = np.prod(img.shape)
        actual_num_bytes = os.path.getsize(pout)
        actual_bpsp = actual_num_bytes * 8 / num_subpixels
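        # Worked example (illustrative numbers): a 1x3x512x768 input has
        # 1*3*512*768 = 1,179,648 subpixels, so a 442,368-byte file corresponds to
        # 442368 * 8 / 1179648 = 3.0 bits per subpixel (bpsp).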

        if self.compare_with_theory:
            assumed_bpsps = [
                b * 8 / num_subpixels for b in entropy_coding_bytes
            ]
            tostr = lambda l: ' | '.join(map('{:.3f}'.format, l)) + f' => {sum(l):.3f}'
            overhead = (sum(assumed_bpsps) / sum(loss_out.nonrecursive_bpsps) -
                        1) * 100
            info = f'Bitrates:\n' \
                f'theory:  {tostr(loss_out.nonrecursive_bpsps)}\n' \
                f'assumed: {tostr(list(reversed(assumed_bpsps)))} [{overhead:.2f}%]\n' \
                f'actual:                                => {actual_bpsp:.3f} [{actual_num_bytes} bytes]'
            print(info)
            return actual_bpsp
        else:
            return actual_bpsp
    def __init__(self, log_date, flags, restore_itr, l3c=False):
        """
        :param flags:
            log_dir
            img
            filter_filenames
            max_imgs_per_folder
            # out_dir
            crop
            recursive
            sample
            write_to_files
            compare_theory
            time_report
            overwrite_cache
        """
        self.flags = flags

        test_log_dir_root = self.flags.log_dir.rstrip(os.path.sep) + '_test'
        global_config.reset()

        config_ps, experiment_dir = MultiscaleTester.get_configs_experiment_dir(
            'ms', self.flags.log_dir, log_date)
        self.log_date = logdir_helpers.log_date_from_log_dir(experiment_dir)
        (self.config_ms, _), _ = ft.unzip(map(config_parser.parse, config_ps))
        global_config.update_config(self.config_ms)

        self.recursive = _parse_recursive_flag(self.flags.recursive,
                                               config_ms=self.config_ms)
        if self.flags.write_to_files and self.recursive:
            raise NotImplementedError(
                '--write_to_file not implemented for --recursive')

        if self.recursive:
            print(f'--recursive={self.recursive}')

        blueprint = MultiscaleBlueprint(self.config_ms)
        blueprint.set_eval()
        self.blueprint = blueprint

        self.restorer = saver.Restorer(paths.get_ckpts_dir(experiment_dir))
        self.restore_itr, ckpt_p = self.restorer.get_ckpt_for_itr(restore_itr)
        self.restorer.restore({'net': self.blueprint.net}, ckpt_p, strict=True)

        # test_log_dir/0311_1057 cr oi_012
        self.test_log_dir = os.path.join(test_log_dir_root,
                                         os.path.basename(experiment_dir))
        if self.flags.reset_entire_cache and os.path.isdir(self.test_log_dir):
            print(f'Removing test_log_dir={self.test_log_dir}...')
            time.sleep(1)
            shutil.rmtree(self.test_log_dir)
        os.makedirs(self.test_log_dir, exist_ok=True)
        self.test_output_cache = TestOutputCache(self.test_log_dir)

        self.times = cuda_timer.StackTimeLogger() if self.flags.write_to_files else None

        # Import only if needed, as it imports torchac
        if self.flags.write_to_files:
            check_correct_torchac_backend_available()
            from bitcoding.bitcoding import Bitcoding
            self.bc = Bitcoding(self.blueprint,
                                times=self.times,
                                compare_with_theory=self.flags.compare_theory)
        elif l3c:  # Called from l3c.py
            from bitcoding.bitcoding import Bitcoding
            self.bc = Bitcoding(self.blueprint, times=no_op.NoOp)
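        # Hypothetical usage sketch (argument values are made up; flag handling not
        # verified against the actual CLI):
        #   tester = MultiscaleTester('0311_1057', flags, restore_itr=500000, l3c=True)
        #   bpsp = tester.bc.encode(img, pout='out.l3c')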