Example #1
def interpolator(measures_per_image_iter, grid, interp_mode='linear'):
    accumulated_values = np.zeros_like(grid, np.float64)
    # Count values per bin
    N = np.zeros_like(grid, np.int64)
    num_imgs = 0
    num_errors = 0
    for img_description, (bpps, values) in measures_per_image_iter:
        assert len(bpps) >= 2, 'Missing values for {}'.format(img_description)
        assert bpps[0] >= bpps[-1]

        num_imgs += 1
        # Build the interpolation function for this image's measured curve
        try:
            fq = scipy.interpolate.interp1d(bpps, values, interp_mode)
        except ValueError as e:
            print(bpps, values)
            print(e)
            exit(1)
        for i, bpp in enumerate(grid):
            try:
                accumulated_values[i] += fq(bpp)
                N[i] += 1
            except ValueError:  # bpp outside this image's measured range
                num_errors += 1
                continue
    # Keep only grid bins covered by enough images, then average per bin
    grid, values = ft.unzip(
        (bpp, m / n) for bpp, m, n in zip(grid, accumulated_values, N)
        if n > _REQUIRED_BINS * num_imgs)
    return grid, values
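A minimal, self-contained sketch of the averaging idea in interpolator(), for intuition only: each per-image rate-distortion curve is interpolated onto a shared bpp grid, and each bin averages only the images that cover it. All names and numbers below are illustrative, not from the repo.

import numpy as np
import scipy.interpolate

grid = np.linspace(0.5, 2.0, 8)
curves = {  # per-image (bpps descending, values), as interpolator() expects
    'img_a': ([2.0, 1.0, 0.5], [0.99, 0.95, 0.90]),
    'img_b': ([1.8, 0.9, 0.4], [0.98, 0.94, 0.89]),
}
acc = np.zeros_like(grid)
n = np.zeros_like(grid, np.int64)
for bpps, values in curves.values():
    f = scipy.interpolate.interp1d(bpps, values, 'linear')  # sorts x internally
    for i, bpp in enumerate(grid):
        try:
            acc[i] += f(bpp)
            n[i] += 1
        except ValueError:  # bpp outside this image's measured range
            pass
for g, a, c in zip(grid, acc, n):
    if c:  # average only bins covered by at least one image
        print(f'{g:.2f} bpp -> {a / c:.4f}')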
Example #2
def get_interpolated_values_bpg_jp2k(bpg_or_jp2k_dir, grid, metric):
    """ :returns grid, values"""
    ps = other_codecs.all_measures_file_ps(bpg_or_jp2k_dir)
    if len(ps) == 0:
        raise CodecDistanceReadException('No matches in {}'.format(bpg_or_jp2k_dir))
    measures_per_image_iter = ((p, ft.unzip(sorted(other_codecs.read_measures(p, metric), reverse=True))) for p in ps)
    return interpolator(measures_per_image_iter, grid, interp_mode='linear')
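For reference, the sorted+unzip transform used in the generator above, sketched with plain tuples. Here ft.unzip is assumed to behave like zip(*iterable), and the (bpp, value) pairs stand in for what read_measures(p, metric) yields.

pairs = [(0.5, 0.90), (2.0, 0.99), (1.0, 0.95)]  # illustrative measurements
bpps, values = zip(*sorted(pairs, reverse=True))
print(bpps)    # (2.0, 1.0, 0.5) -- descending, so bpps[0] >= bpps[-1] holds
print(values)  # (0.99, 0.95, 0.90)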
Example #3
def interpolator(measures_per_image_iter, grid, interp_mode='linear'):
    accumulated_values = np.zeros_like(grid, np.float64)
    # Count values per bin
    N = np.zeros_like(grid, np.int64)
    num_imgs = 0
    num_errors = 0

    for img_description, (bpps, values) in measures_per_image_iter:
        assert_exc(
            len(bpps) >= 2, 'Missing values for {}'.format(img_description),
            OtherCodecsReadException)
        assert_exc(bpps[0] >= bpps[-1],
                   f'First bpp < last: {bpps[0]} < {bpps[-1]}',
                   OtherCodecsReadException)

        num_imgs += 1

        # create interpolation function
        try:
            fq = scipy.interpolate.interp1d(bpps, values, interp_mode)
        except ValueError as e:
            raise OtherCodecsReadException(f'ValueError while creating fq: {e}')

        for i, bpp in enumerate(grid):
            try:
                accumulated_values[i] += fq(bpp)
                N[i] += 1
            except ValueError:  # bpp outside this image's measured range
                num_errors += 1
                continue
    try:
        grid, values = ft.unzip(
            (bpp, v / n) for bpp, v, n in zip(grid, accumulated_values, N)
            if n > _REQUIRED_BINS * num_imgs)
    except ValueError:
        # Re-raise with debug info: the raw grid, the accumulators, and the
        # (possibly empty) filtered result whose unpacking failed above.
        raise ValueError(
            grid, accumulated_values, N,
            list(
                ft.unzip((bpp, v / n)
                         for bpp, v, n in zip(grid, accumulated_values, N)
                         if n > _REQUIRED_BINS * num_imgs)))
    return grid, values
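assert_exc and OtherCodecsReadException are not shown in this listing; a plausible shape for them, consistent with how they are called above, would be the following (an assumption, not the repo's actual implementation):

class OtherCodecsReadException(Exception):
    """Raised when other-codec measure files cannot be read or interpolated."""

def assert_exc(condition, message, exc_cls=AssertionError):
    # Assumed helper: like assert, but raising a caller-chosen exception class.
    if not condition:
        raise exc_cls(message)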
Example #4
def interpolate_ours(measures_readers, grid, interp_mode, metric):
    measures_per_image = defaultdict(list)
    for measures_reader in measures_readers:
        for img_name, bpp, value in measures_reader.iter_metric(metric):
            measures_per_image[img_name].append((bpp, value))

    # Make sure every job has a value for every image
    for img_name, values in measures_per_image.items():
        assert len(values) == len(measures_readers), '{}: {}'.format(img_name, len(values))

    return interpolator(
            ((img_name, ft.unzip(sorted(bpps_values, reverse=True)))
             for img_name, bpps_values in measures_per_image.items()),
            grid, interp_mode)
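The grouping step in interpolate_ours() can be exercised in isolation. FakeReader below is a hypothetical stand-in for one job's measures reader whose iter_metric(metric) yields (img_name, bpp, value) triples, which matches how it is consumed above.

from collections import defaultdict

class FakeReader:  # hypothetical stand-in for one job's measures reader
    def __init__(self, rows):
        self.rows = rows
    def iter_metric(self, metric):
        return iter(self.rows)

readers = [FakeReader([('img_a', 0.5, 0.90)]),
           FakeReader([('img_a', 1.0, 0.95)])]
measures_per_image = defaultdict(list)
for reader in readers:
    for img_name, bpp, value in reader.iter_metric('ms-ssim'):
        measures_per_image[img_name].append((bpp, value))
print(dict(measures_per_image))  # {'img_a': [(0.5, 0.9), (1.0, 0.95)]}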
Example #5
    def __init__(self,
                 ms_config_p, dl_config_p,
                 log_dir_root, log_config: LogConfig,
                 num_workers,
                 saver: Saver, restorer: TrainRestorer=None,
                 sw_cls=vis.safe_summary_writer.SafeSummaryWriter):
        """
        :param ms_config_p: Path to the multiscale config file, see README
        :param dl_config_p: Path to the dataloader config file, see README
        :param log_dir_root: All outputs (checkpoints, tensorboard) will be saved here.
        :param log_config: Instance of train.trainer.LogConfig, contains intervals.
        :param num_workers: Number of workers to use for DataLoading, see train.py
        :param saver: Saver instance to use.
        :param restorer: Instance of TrainRestorer, if we need to restore
        """

        # Read configs
        # config_ms = config for the network (ms = multiscale)
        # config_dl = config for data loading
        (self.config_ms, self.config_dl), rel_paths = ft.unzip(map(config_parser.parse, [ms_config_p, dl_config_p]))
        # Update config_ms depending on global_config
        global_config.update_config(self.config_ms)
        # Create data loaders
        dl_train, dl_val = self._get_dataloaders(num_workers)
        # Create blueprint. A blueprint collects the network as well as the losses in one class, for easy reuse
        # during testing.
        self.blueprint = MultiscaleBlueprint(self.config_ms)
        print('Network:', self.blueprint.net)
        # Setup optimizer
        optim_cls = {'RMSprop': optim.RMSprop,
                     'Adam': optim.Adam,
                     'SGD': optim.SGD,
                     }[self.config_ms.optim]
        net = self.blueprint.net
        self.optim = optim_cls(net.parameters(), self.config_ms.lr.initial,
                               weight_decay=self.config_ms.weight_decay)
        # Calculate a rough estimate for time per batch (does not take into account that CUDA is async,
        # but good enough to get a feeling during training).
        self.time_accumulator = timer.TimeAccumulator()
        # Restore network if requested
        skip_to_itr = self.maybe_restore(restorer)
        if skip_to_itr is not None:  # i.e., we have a restorer
            print('Skipping to {}...'.format(skip_to_itr))
        # Create LR schedule to update parameters
        self.lr_schedule = lr_schedule.from_spec(
                self.config_ms.lr.schedule, self.config_ms.lr.initial, [self.optim], epoch_len=len(dl_train))

        # --- All nn.Modules are setup ---
        print('-' * 80)

        # create log dir and summary writer
        self.log_dir = Trainer.get_log_dir(log_dir_root, rel_paths, restorer)
        self.log_date = logdir_helpers.log_date_from_log_dir(self.log_dir)
        self.ckpt_dir = os.path.join(self.log_dir, CKPTS_DIR_NAME)
        print(f'Checkpoints will be saved to {self.ckpt_dir}')
        saver.set_out_dir(self.ckpt_dir)


        # Create summary writer
        sw = sw_cls(self.log_dir)
        self.summarizer = vis.summarizable_module.Summarizer(sw)
        net.register_summarizer(self.summarizer)
        self.blueprint.register_summarizer(self.summarizer)
        # superclass setup
        super(MultiscaleTrainer, self).__init__(dl_train, dl_val, [self.optim], net, sw,
                                                max_epochs=self.config_dl.max_epochs,
                                                log_config=log_config, saver=saver, skip_to_itr=skip_to_itr)
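The optimizer setup above follows a common config-string-to-class lookup. A minimal sketch of the same pattern with a toy network; the learning rate and weight decay below are illustrative, not the repo's defaults.

import torch.nn as nn
import torch.optim as optim

net = nn.Linear(8, 8)  # toy network standing in for blueprint.net
optim_cls = {'RMSprop': optim.RMSprop,
             'Adam': optim.Adam,
             'SGD': optim.SGD}['Adam']  # config_ms.optim would supply the key
optimizer = optim_cls(net.parameters(), 1e-4, weight_decay=1e-5)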
Example #6
def interpolated_curve(log_dir_root, job_ids, dataset, grid, interp_mode,
                       plot_interp_of_ours, plot_mean_of_ours,
                       plot_ids_of_ours, metric, x_range, y_range, use_latex,
                       output_path, paper_plot):
    if not output_path:
        output_path = 'plot_{}.png'.format(TITLES[dataset])

    cmap = plt.cm.get_cmap('cool')

    style = {
        LABEL_OURS: ('0', '-', 3),
        LABEL_RB: (cmap(0.9), '-', 1.5),
        LABEL_BPG: (cmap(0.7), '-', 1.5),
        LABEL_JP2K: (cmap(0.45), '-', 1.5),
        LABEL_JP: (cmap(0.2), '-', 1.5),
        LABEL_WEBP: (cmap(0.1), '-', 1.5),
        LABEL_JOHNSTON: (cmap(0.7), '--', 1.5),
        LABEL_BALLE: (cmap(0.45), '--', 1.5),
        LABEL_THEIS: (cmap(0.2), '--', 1.5),
    }

    pos = {
        LABEL_OURS: 10,
        LABEL_RB: 9,
        LABEL_JOHNSTON: 8,
        LABEL_BPG: 7,
        LABEL_BALLE: 6,
        LABEL_JP2K: 5,
        LABEL_THEIS: 4,
        LABEL_JP: 3,
        LABEL_WEBP: 2,
        'Fig. 1': 11
    }

    plt.figure(figsize=(6, 6))
    if use_latex:
        plt.rc('text', usetex=True)
        plt.rc('font', family='serif', serif=['Computer Modern Roman'])

    for codec_short_name, measures_dir in CODECS[dataset].items():
        measures_dir = os.path.join(constants.OTHER_CODECS_ROOT, measures_dir)
        label = get_label_from_codec_short_name(codec_short_name)
        col, line_style, line_width = style[label]
        assert os.path.exists(measures_dir), measures_dir
        this_grid, this_msssims = get_interpolated_values_bpg_jp2k(
            measures_dir, grid, metric)
        dashes = (5, 1) if line_style == '--' else []
        plt.plot(this_grid,
                 this_msssims,
                 label=label,
                 linewidth=line_width,
                 color=col,
                 dashes=dashes)

    if dataset == 'kodak':
        for name, data in [(LABEL_RB, _RIPPEL_KODAK)]:
            col, line_style, line_width = style[name]
            dashes = (5, 1) if line_style == '--' else []
            plt.plot(*ft.unzip(data),
                     label=name,
                     color=col,
                     linewidth=line_width,
                     dashes=dashes)

    for job_ids_group in job_ids.split(';'):
        measures_readers = get_measures_readers(log_dir_root, job_ids_group, dataset)
        print('\n'.join(m.p for m in measures_readers))

        if measures_readers:  # may be empty if no job_ids are passed
            col, line_style, line_width = style['Ours']
            if plot_interp_of_ours:
                ours_grid, ours_msssim = interpolate_ours(
                    measures_readers, grid, interp_mode, metric)
                dashes = (5, 1) if line_style == '--' else []
                plt.plot(ours_grid,
                         ours_msssim,
                         label='Ours',
                         color=col,
                         linewidth=line_width,
                         dashes=dashes)
            if plot_mean_of_ours:
                plot_ours_mean(measures_readers, metric, col, plot_ids_of_ours)

    if paper_plot:
        col, line_style, line_width = style['Ours']
        dashes = (5, 1) if line_style == '--' else []
        plt.plot(*ft.unzip(CVPR_FIG1),
                 label='Fig. 1',
                 color=col,
                 linewidth=line_width,
                 dashes=dashes)

    plt.title('{} on {}'.format(metric.upper(), TITLES[dataset]))
    plt.xlabel('bpp', labelpad=-5)
    plt.grid()

    ax = plt.gca()
    handles, labels = ax.get_legend_handles_labels()
    labels, handles = zip(
        *sorted(zip(labels, handles), reverse=True, key=lambda t: pos[t[0]]))
    ax.legend(handles,
              labels,
              loc=4,
              prop={'size': 12},
              fancybox=True,
              framealpha=0.7)

    ax.yaxis.grid(b=True, which='both', color='0.8', linestyle='-')
    ax.xaxis.grid(b=True, which='major', color='0.8', linestyle='-')
    ax.set_axisbelow(True)

    ax.minorticks_on()
    ax.yaxis.set_minor_locator(mpl.ticker.AutoMinorLocator(2))

    plt.xlim(x_range)
    plt.ylim(y_range)
    print('Saving {}...'.format(output_path))
    plt.savefig(output_path, bbox_inches='tight')
    plt.close()
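The legend ordering above relies on sorting (label, handle) pairs by the pos dict before calling ax.legend. The same trick in isolation, with illustrative labels:

import matplotlib.pyplot as plt

pos = {'Ours': 10, 'BPG': 7, 'JPEG': 3}  # higher value = listed first
fig, ax = plt.subplots()
for label in ['JPEG', 'Ours', 'BPG']:
    ax.plot([0, 1], [0, 1], label=label)
handles, labels = ax.get_legend_handles_labels()
labels, handles = zip(*sorted(zip(labels, handles), reverse=True,
                              key=lambda t: pos[t[0]]))
ax.legend(handles, labels)  # legend now reads: Ours, BPG, JPEG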
Example #7
def parse_configs(*configs):
    """ Parse multiple configs """
    return ft.unzip(map(parse, configs))
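parse_configs() relies on ft.unzip to split (config, rel_path) pairs into two tuples. A sketch with a hypothetical parse that returns such a pair; the real config_parser.parse is not shown in this listing.

def parse(p):  # hypothetical: returns (parsed_config, relative_path)
    return {'src': p}, p.split('/')[-1]

configs, rel_paths = zip(*map(parse, ['cfgs/ms.cf', 'cfgs/dl.cf']))
print(configs)    # ({'src': 'cfgs/ms.cf'}, {'src': 'cfgs/dl.cf'})
print(rel_paths)  # ('ms.cf', 'dl.cf')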
Example #8
    def __init__(self, log_date, flags, restore_itr, l3c=False):
        """
        :param flags:
            log_dir
            img
            filter_filenames
            max_imgs_per_folder
            # out_dir
            crop
            recursive
            sample
            write_to_files
            compare_theory
            time_report
            overwrite_cache
        """
        self.flags = flags

        test_log_dir_root = self.flags.log_dir.rstrip(os.path.sep) + '_test'
        global_config.reset()

        config_ps, experiment_dir = MultiscaleTester.get_configs_experiment_dir(
            'ms', self.flags.log_dir, log_date)
        self.log_date = logdir_helpers.log_date_from_log_dir(experiment_dir)
        (self.config_ms, _), _ = ft.unzip(map(config_parser.parse, config_ps))
        global_config.update_config(self.config_ms)

        self.recursive = _parse_recursive_flag(self.flags.recursive,
                                               config_ms=self.config_ms)
        if self.flags.write_to_files and self.recursive:
            raise NotImplementedError(
                '--write_to_file not implemented for --recursive')

        if self.recursive:
            print(f'--recursive={self.recursive}')

        blueprint = MultiscaleBlueprint(self.config_ms)
        blueprint.set_eval()
        self.blueprint = blueprint

        self.restorer = saver.Restorer(paths.get_ckpts_dir(experiment_dir))
        self.restore_itr, ckpt_p = self.restorer.get_ckpt_for_itr(restore_itr)
        self.restorer.restore({'net': self.blueprint.net}, ckpt_p, strict=True)

        # test_log_dir/0311_1057 cr oi_012
        self.test_log_dir = os.path.join(test_log_dir_root,
                                         os.path.basename(experiment_dir))
        if self.flags.reset_entire_cache and os.path.isdir(self.test_log_dir):
            print(f'Removing test_log_dir={self.test_log_dir}...')
            time.sleep(1)
            shutil.rmtree(self.test_log_dir)
        os.makedirs(self.test_log_dir, exist_ok=True)
        self.test_output_cache = TestOutputCache(self.test_log_dir)

        self.times = (cuda_timer.StackTimeLogger()
                      if self.flags.write_to_files else None)

        # Import only if needed, as it imports torchac
        if self.flags.write_to_files:
            check_correct_torchac_backend_available()
            from bitcoding.bitcoding import Bitcoding
            self.bc = Bitcoding(self.blueprint,
                                times=self.times,
                                compare_with_theory=self.flags.compare_theory)
        elif l3c:  # Called from l3c.py
            from bitcoding.bitcoding import Bitcoding
            self.bc = Bitcoding(self.blueprint, times=no_op.NoOp)
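Bitcoding is imported inside __init__ on purpose, since it transitively imports torchac. The deferred-import pattern in isolation; the module name below is a harmless stand-in, not the repo's dependency.

def make_bitcoder(write_to_files):
    if not write_to_files:
        return None
    # Deferred import: only pay the import cost when the feature is used.
    import json as heavy_module  # stands in for bitcoding.bitcoding
    return heavy_module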
Example #9
def plot_measured_dataset(
        out_dir,
        points_ours: Optional[Dict[str, List[Tuple[str, float,
                                                   float]]]] = None):
    bpp_grid = np.linspace(0, 3, 100)
    ylims = {'psnr': 26}
    for metric in SUPPORTED_METRICS:
        title = metric
        out_p = os.path.join(out_dir, title + '.pdf')
        print(f'Creating plot {out_p}...')
        plt.figure(figsize=(10, 5))
        plt.title(title)
        xlim = bpp_grid[-1]
        if points_ours and metric in points_ours:
            xlim = max(xlim, max(bpp for _, bpp, _ in points_ours[metric]))
        plt.xlim((0, xlim))
        if points_ours and metric in ylims and metric in points_ours:
            plt.ylim(
                (ylims[metric],
                 np.ceil(max(psnr for _, _, psnr in points_ours[metric]))))

        for codec_dir in os_ext.listdir_paths(out_dir):
            ps = all_measures_file_ps(codec_dir)
            if len(ps) == 0:
                # print(f'Nothing in {codec_dir}, skipping...')
                continue

            if any(count_num_measures(p) == 0 for p in ps):
                print(
                    f'Found incomplete measures file, skipping {os.path.basename(codec_dir)}'
                )
                continue

            measures_per_image_iter = ((p,
                                        ft.unzip(
                                            sorted(read_measures(p, metric),
                                                   reverse=True))) for p in ps)

            try:
                grid, values = interpolator(measures_per_image_iter,
                                            bpp_grid,
                                            interp_mode='linear')
            except OtherCodecsReadException as e:
                print(f'*** Skipping {os.path.basename(codec_dir)}: {e}')
                continue
            except ValueError as e:
                print(e)
                print(codec_dir)
                continue

            codec_name = os.path.basename(codec_dir)
            plt.plot(grid, values, label=codec_name)
            # plt.ylim((0.4, 1))

        if points_ours and metric in points_ours:
            ours_for_metric = points_ours[metric]
            for name, bpp, value in ours_for_metric:
                plt.scatter(bpp, value, label=name)

        plt.legend()
        plt.savefig(out_p, bbox_inches='tight')
        plt.close()
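The axis handling above widens the x-limit so our own measured points stay visible. The same computation stand-alone, with made-up points:

import numpy as np

bpp_grid = np.linspace(0, 3, 100)
points_ours = {'psnr': [('img_a', 3.4, 31.2), ('img_b', 2.1, 29.8)]}
metric = 'psnr'
xlim = bpp_grid[-1]
if points_ours and metric in points_ours:
    xlim = max(xlim, max(bpp for _, bpp, _ in points_ours[metric]))
print(xlim)  # 3.4 -- the grid end, widened to cover the measured points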
Example #10
    def __init__(self,
                 config_p,
                 dl_config_p,
                 log_dir_root,
                 log_config: LogConfig,
                 num_workers,
                 saver: Saver,
                 restorer: TrainRestorer = None,
                 sw_cls=vis.safe_summary_writer.SafeSummaryWriter):
        """
        :param config_p: Path to the network config file, see README
        :param dl_config_p: Path to the dataloader config file, see README
        :param log_dir_root: All outputs (checkpoints, tensorboard) will be saved here.
        :param log_config: Instance of train.trainer.LogConfig, contains intervals.
        :param num_workers: Number of workers to use for DataLoading, see train.py
        :param saver: Saver instance to use.
        :param restorer: Instance of TrainRestorer, if we need to restore
        """
        self.style = MultiscaleTrainer.get_style_from_config(config_p)
        self.blueprint_cls = {
            'enhancement': EnhancementBlueprint,
            'classifier': ClassifierBlueprint
        }[self.style]

        global_config.declare_used('filter_imgs')

        # Read configs
        # config = config for the network
        # config_dl = config for data loading
        (self.config, self.config_dl), rel_paths = ft.unzip(
            map(config_parser.parse, [config_p, dl_config_p]))
        # TODO only read by enhancement classes
        self.config.is_residual = self.config_dl.is_residual_dataset

        # Update global_config given config.global_config
        global_config_config_keys = global_config.add_from_str_without_overwriting(
            self.config.global_config)
        # Update config_ms depending on global_config
        global_config.update_config(self.config)

        if self.style == 'enhancement':
            EnhancementBlueprint.read_evenly_spaced_bins(self.config_dl)

        self._custom_init()

        # Create data loaders
        dl_train, self.ds_val, self.fixed_first_val = self._get_dataloaders(
            num_workers)
        # Create blueprint. A blueprint collects the network as well as the losses in one class, for easy reuse
        # during testing.
        self.blueprint = self.blueprint_cls(self.config)
        print('Network:', self.blueprint.net)
        # Setup optimizer
        optim_cls = {
            'RMSprop': optim.RMSprop,
            'Adam': optim.Adam,
            'SGD': optim.SGD,
        }[self.config.optim]
        net = self.blueprint.net
        self.optim = optim_cls(net.parameters(),
                               self.config.lr.initial,
                               weight_decay=self.config.weight_decay)
        # Calculate a rough estimate for time per batch (does not take into account that CUDA is async,
        # but good enough to get a feeling during training).
        self.time_accumulator = timer.TimeAccumulator()
        # Restore network if requested
        skip_to_itr = self.maybe_restore(restorer)
        if skip_to_itr is not None:  # i.e., we have a restorer
            print('Skipping to {}...'.format(skip_to_itr))
        # Create LR schedule to update parameters
        self.lr_schedule = lr_schedule.from_spec(self.config.lr.schedule,
                                                 self.config.lr.initial,
                                                 [self.optim],
                                                 epoch_len=len(dl_train))

        # --- All nn.Modules are setup ---
        print('-' * 80)

        # create log dir and summary writer
        self.log_dir_root = log_dir_root
        global_config_values = global_config.values(
            ignore=global_config_config_keys)
        self.log_dir = Trainer.get_log_dir(
            log_dir_root,
            rel_paths,
            restorer,
            global_config_values=global_config_values)
        self.log_date = logdir_helpers.log_date_from_log_dir(self.log_dir)
        self.ckpt_dir = os.path.join(self.log_dir, CKPTS_DIR_NAME)
        print(f'Checkpoints will be saved to {self.ckpt_dir}')
        saver.set_out_dir(self.ckpt_dir)

        if global_config.get('ds_syn', None):
            underlying = dl_train.dataset
            while not isinstance(underlying, _CheckerboardDataset):
                underlying = underlying.ds
            underlying.save_all(self.log_dir)

        # Create summary writer
        sw = sw_cls(self.log_dir)
        self.summarizer = vis.summarizable_module.Summarizer(sw)
        net.register_summarizer(self.summarizer)
        self.blueprint.register_summarizer(self.summarizer)

        # Try to write filenames somewhere
        try:
            dl_train.dataset.write_file_names_to_txt(self.log_dir)
        except AttributeError:
            raise AttributeError(
                f'dl_train.dataset of type {type(dl_train.dataset)} does not support '
                f'write_file_names_to_txt(log_dir)!')

        # superclass setup
        super(MultiscaleTrainer,
              self).__init__(dl_train, [self.optim],
                             net,
                             sw,
                             max_epochs=self.config_dl.max_epochs,
                             log_config=log_config,
                             saver=saver,
                             skip_to_itr=skip_to_itr)
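The ds_syn branch above unwraps nested dataset wrappers via .ds until it reaches a _CheckerboardDataset. The unwrap loop in isolation, with toy classes that are illustrative, not the repo's types:

class Wrapper:  # toy stand-in for a dataset wrapper exposing .ds
    def __init__(self, ds):
        self.ds = ds

class CheckerboardDataset:  # toy stand-in for _CheckerboardDataset
    def save_all(self, log_dir):
        print('saving synthetic data to', log_dir)

underlying = Wrapper(Wrapper(CheckerboardDataset()))
while not isinstance(underlying, CheckerboardDataset):
    underlying = underlying.ds  # peel off one wrapper per iteration
underlying.save_all('/tmp/log_dir')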