def ops_saved_summery(net_name=cfg.NET.__name__,
                      dataset_name=cfg.DATA.name(),
                      mode=Mode.ALL_MODES,
                      ps='*',
                      ones_range=('*', '*'),
                      acc_loss='*',
                      gran_thresh='*',
                      init_acc='*',
                      batch_size=cfg.BATCH_SIZE,
                      max_samples=cfg.TEST_SET_SIZE):
    """Summarise, per layer, the operations saved by a final-result mask.

    The FINAL_RESULT_REC matching the filter arguments is located ('*' values
    are forwarded to RecordFinder as wildcards). For each layer mask in the
    record, the spatial layers are reset, only that layer's mask is applied,
    the test set is evaluated and ``(ops_saved, ops_total, test_acc)`` is
    stored. The per-layer list is cached in ``cfg.RESULTS_DIR`` as
    ``ops_summery_<record filename>.pkl`` and reused by later calls. A CSV
    with columns layer/ops_saved/ops_total is rewritten on every call.

    Args:
        net_name, dataset_name, mode, ps, ones_range, acc_loss, gran_thresh,
            init_acc: record filters forwarded to RecordFinder.
        batch_size: batch size used both to size the spatial layers and to
            iterate the test set.
        max_samples: maximum number of test samples to evaluate.

    Returns:
        List of ``(ops_saved, ops_total, test_acc)`` tuples indexed by layer,
        or None when no matching record exists.
    """
    rec_finder = RecordFinder(net_name, dataset_name, ps, ones_range,
                              gran_thresh, acc_loss, init_acc)
    final_rec_fn = rec_finder.find_rec_filename(mode,
                                                RecordType.FINAL_RESULT_REC)
    if final_rec_fn is None:
        print('No Record found')
        return None
    rec = load_from_file(final_rec_fn, '')
    print(rec)

    separator = '----------------------------------------------------------------'
    base_fn = 'ops_summery_' + rec.filename
    summery_fn_pkl = os.path.join(cfg.RESULTS_DIR, base_fn + '.pkl')
    if os.path.exists(summery_fn_pkl):
        # Cached result from an earlier run; skip the network evaluation.
        arr = load_from_file(summery_fn_pkl, path='')
    else:
        nn = NeuralNet()
        data = Datasets.get(dataset_name, cfg.DATASET_DIR)
        # Size the spatial layers with the same batch size the test generator
        # below yields, so a non-default ``batch_size`` stays consistent.
        nn.net.initialize_spatial_layers(data.shape(), batch_size,
                                         rec.patch_size)
        test_gen, _ = data.testset(batch_size=batch_size,
                                   max_samples=max_samples)

        arr = [None] * len(rec.mask)
        for idx, layer in enumerate(rec.mask):
            # Presumably clears the mask applied in the previous iteration,
            # so each entry reflects this layer's mask alone.
            nn.net.reset_spatial()
            print(separator)

            nn.net.strict_mask_update(update_ids=[idx], masks=[layer])
            _, test_acc, _ = nn.test(test_gen)
            ops_saved, ops_total = nn.net.num_ops()

            arr[idx] = (ops_saved, ops_total, test_acc)
            nn.net.print_ops_summary()

        print(separator)
        nn.net.reset_spatial()
        save_to_file(arr, use_default=False, path='', filename=summery_fn_pkl)

    out_path = os.path.join(cfg.RESULTS_DIR, base_fn + ".csv")
    with open(out_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['layer', 'ops_saved', 'ops_total'])
        writer.writerows([idx, r[0], r[1]] for idx, r in enumerate(arr))

    return arr
# Example #2
# 0
 def _quantizier_main(self, rec_type, in_rec):
     """Run (or resume) one quantizer stage and return its output record.

     For ``RecordType.lQ_RESUME`` a LayerQuantizier is built from ``in_rec``.
     It receives the path of an existing LayerQ resume record, or None when
     there is none, and this stage returns None. For any other record type, a
     PatchQuantizier (``pQ_REC``) or ChannelQuantizier (anything else) is
     built. It is seeded with a previously saved record of that type when one
     exists, and its ``output_rec`` is returned. The network is initialised
     and the simulation run only when the quantizer reports it is not finished.
     """
     finder = self.record_finder
     is_layer_resume = rec_type == RecordType.lQ_RESUME

     if is_layer_resume:
         resume_path = finder.find_rec_filename(in_rec.mode,
                                                RecordType.lQ_RESUME)
         quantizier = LayerQuantizier(in_rec, self.init_acc,
                                      self.max_acc_loss, self.ps,
                                      self.ones_range, self.get_total_ops(),
                                      resume_path)
     else:
         saved_fn = finder.find_rec_filename(in_rec.mode, rec_type)
         if rec_type == RecordType.pQ_REC:
             quantizier_cls = PatchQuantizier
         else:
             quantizier_cls = ChannelQuantizier
         ctor_args = [in_rec, self.init_acc, self.max_acc_loss, self.ps]
         if saved_fn is not None:
             # Continue from the previously saved quantizer state.
             ctor_args.append(load_from_file(saved_fn, ''))
         quantizier = quantizier_cls(*ctor_args)

     if not quantizier.is_finised():
         self._init_nn()
         quantizier.simulate(self.nn, self.test_gen)

     return None if is_layer_resume else quantizier.output_rec
# Example #3
# 0
    def gen_first_lvl_results(self, mode):
        """Generate, or resume, the first-level record for ``mode``.

        Each mask yielded by ``mf.gen_masks_with_resume`` is indexed by
        (layer, channel, patch, pattern_idx). The mask is applied to its layer,
        the test set is evaluated, and ops saved / total ops / test accuracy
        are added to the record with ``rcs.addRecord``. The record is saved to
        cfg.RESULTS_DIR periodically and once more at the end.

        If a FIRST_LVL_REC already exists for ``mode``, it is loaded. When its
        ``find_resume_point()`` returns None, it is returned as-is with no
        further simulation. Otherwise generation continues from that point.

        Args:
            mode: mode used to look up the record; stored in a newly created
                record.

        Returns:
            The loaded or newly generated Record.
        """
        rec_filename = self.record_finder.find_rec_filename(
            mode, RecordType.FIRST_LVL_REC)
        if rec_filename is not None:
            rcs = load_from_file(rec_filename, path='')
            st_point = rcs.find_resume_point()
            if st_point is None:
                # Nothing left to resume: the existing record is complete.
                return rcs

        # Per-layer spatial sizes for the current dataset's input shape.
        layers_layout = self.nn.net.generate_spatial_sizes(dat.shape())
        self._init_nn()

        if rec_filename is None:
            # New record. NOTE(review): the third argument appears to flag
            # whether patterns come from the patch size (True, self.ps passed)
            # or from explicitly supplied input patterns (False) - confirm
            # against Record.
            if self.input_patterns is None:
                rcs = Record(layers_layout, self.gran_thresh, True, mode,
                             self.init_acc, self.ps, self.ones_range)
            else:
                rcs = Record(layers_layout, self.gran_thresh, False, mode,
                             self.init_acc, self.input_patterns,
                             self.ones_range)
            # Fresh run: presumably start at index 0 of each of the four
            # (layer, channel, patch, pattern) resume coordinates.
            st_point = [0] * 4

            if INNAS_COMP:
                rcs.filename = 'DEBUG_' + rcs.filename

        print('==> Result will be saved to ' +
              os.path.join(cfg.RESULTS_DIR, rcs.filename))
        save_counter = 0
        for layer, channel, patch, pattern_idx, mask in tqdm(
                mf.gen_masks_with_resume(self.ps,
                                         rcs.all_patterns,
                                         rcs.mode,
                                         rcs.gran_thresh,
                                         layers_layout,
                                         resume_params=st_point)):
            self.nn.net.strict_mask_update(update_ids=[layer],
                                           masks=[torch.from_numpy(mask)])

            if INNAS_COMP:
                # Network is not evaluated; placeholder values are recorded.
                test_acc = 100
                ops_saved = 100
                ops_total = 100
            else:
                _, test_acc, _ = self.nn.test(self.test_gen)
                ops_saved, ops_total = self.nn.net.num_ops()
                # Presumably clears the applied mask before the next mask.
                self.nn.net.reset_spatial()
            rcs.addRecord(ops_saved, ops_total, test_acc, layer, channel,
                          patch, pattern_idx)

            # Periodic checkpoint. Because of `>`, a save happens every
            # SAVE_INTERVAL + 1 records.
            save_counter += 1
            if save_counter > cfg.SAVE_INTERVAL:
                save_to_file(rcs, True, cfg.RESULTS_DIR)
                save_counter = 0

        save_to_file(rcs, True, cfg.RESULTS_DIR)
        print('==> Result saved to ' +
              os.path.join(cfg.RESULTS_DIR, rcs.filename))
        return rcs
def plot_ops_saved_vs_max_acc_loss(net_name,
                                   dataset_name,
                                   ps,
                                   ones_range,
                                   gran_thresh,
                                   acc_loss_opts,
                                   init_acc,
                                   modes=None,
                                   title=None):
    """Plot operations-saved ratio against the maximum allowed accuracy loss.

    One line is drawn per mode, using the FINAL_RESULT_REC found for each
    value in ``acc_loss_opts``. Modes with no record at all are skipped. When
    a baseline record is available, two things are added: a horizontal line
    at its ops-saved ratio, and a vertical line at the accuracy loss it
    incurred (``init_acc - baseline_acc``). The figure is saved as a PDF in
    cfg.RESULTS_DIR.

    Args:
        net_name, dataset_name, ps, ones_range, gran_thresh, init_acc: record
            filters forwarded to RecordFinder.
        acc_loss_opts: max-accuracy-loss values forming the x axis.
        modes: modes to plot, resolved through ``get_modes``.
        title: optional text appended to the first title line.
    """
    bs_line_rec = get_baseline_rec(net_name, dataset_name, ps, init_acc)
    plt.figure()
    if bs_line_rec is not None:
        baseline_ratio = round(bs_line_rec.ops_saved / bs_line_rec.total_ops,
                               3)
        plt.plot(acc_loss_opts, [baseline_ratio] * len(acc_loss_opts),
                 '--',
                 label='baseline')
        # Separate label so the legend does not list two identical
        # "baseline" entries for the horizontal and vertical lines.
        plt.axvline(x=bs_line_rec.init_acc - bs_line_rec.baseline_acc,
                    linestyle='--',
                    label='baseline acc loss')

    for mode in get_modes(modes):
        ops_saved = [None] * len(acc_loss_opts)
        for idx, acc_loss in enumerate(acc_loss_opts):
            rec_finder = RecordFinder(net_name, dataset_name, ps, ones_range,
                                      gran_thresh, acc_loss, init_acc)
            fn = rec_finder.find_rec_filename(mode,
                                              RecordType.FINAL_RESULT_REC)
            if fn is not None:
                final_rec = load_from_file(fn, '')
                ops_saved[idx] = round(
                    final_rec.ops_saved / final_rec.total_ops, 3)
        # Only draw modes for which at least one record was found.
        if any(v is not None for v in ops_saved):
            plt.plot(acc_loss_opts, ops_saved, 'o--', label=gran_dict[mode])

    plt.xlabel('max acc loss [%]')
    plt.ylabel('operations saved [%]')

    if title is None:
        title = ''
    plt.title(
        f'Operations Saved vs Maximum Allowed Accuracy Loss {title}\n'
        f'{net_name}, {dataset_name}, INITIAL ACC:{init_acc} \n'
        f'PATCH SIZE:{ps}, ONES:{ones_range[0]}-{ones_range[1]-1}, GRANULARITY:{gran_thresh}\n'
        f'LQ{cfg.LQ_OPTION}, CQ{cfg.CQ_OPTION}r{cfg.CHANNELQ_UPDATE_RATIO}, PQ{cfg.PQ_OPTION}r{cfg.PATCHQ_UPDATE_RATIO}'
    )

    plt.legend()
    plt.savefig(
        f'{cfg.RESULTS_DIR}/ops_saved_vs_max_acc_loss_{net_name}_{dataset_name}_acc{init_acc}_'
        +
        f'LQ{cfg.LQ_OPTION}_CQ{cfg.CQ_OPTION}r{cfg.CHANNELQ_UPDATE_RATIO}_PQ{cfg.PQ_OPTION}r{cfg.PATCHQ_UPDATE_RATIO}_'
        + f'ps{ps}_ones{ones_range[0]}x{ones_range[1]}_mg{gran_thresh}.pdf')
# Example #5
# 0
def debug_layer_q(acc_loss, regexes=None):
    """Load saved LayerQ records matching filename patterns and call ``debug`` on them.

    Args:
        acc_loss: max-accuracy-loss value substituted into the default patterns.
        regexes: optional iterable of filename prefixes. They are glob patterns,
            not regular expressions, taken relative to cfg.RESULTS_DIR, and
            ``*pkl`` is appended to each. Defaults to three hard-coded
            ResNet18Spatial / CIFAR10 record prefixes.

    A pattern that matches no file is reported and skipped instead of raising
    IndexError. When several files match, the lexicographically first one is
    used, so the choice does not depend on the arbitrary order of
    ``glob.glob``.
    """
    print('debuging...')
    if regexes is None:
        regexes = [
            f'LayerQ{cfg.LQ_OPTION}_' +
            f'ma{acc_loss}_PatchQ_ma{acc_loss}_ResNet18Spatial_CIFAR10_acc93.5_uniform_filters_ps2_ones1x3_mg10_',
            f'LayerQ{cfg.LQ_OPTION}_' +
            f'ma{acc_loss}_ChannelQ_ma{acc_loss}_ResNet18Spatial_CIFAR10_acc93.5_uniform_patch_ps2_ones1x3_mg10_',
            f'LayerQ{cfg.LQ_OPTION}_' +
            f'ma{acc_loss}_ResNet18Spatial_CIFAR10_acc93.5_uniform_layer_ps2_ones1x3_mg10_'
        ]
    for regex in regexes:
        pattern = f'{cfg.RESULTS_DIR}/{regex}*pkl'
        print(pattern)
        matches = sorted(glob.glob(pattern))
        if not matches:
            print(f'No file matches {pattern}, skipping')
            continue
        fn = matches[0]
        rec = load_from_file(fn, '')
        rec.debug(os.path.basename(fn))
def get_baseline_rec(net_name, dataset_name, ps, init_acc):
    """Return the baseline record for the given network/dataset/patch size/accuracy.

    If no BASELINE_REC file exists yet, an Optimizer is created and asked to
    produce one with ``base_line_result()``, and the lookup is repeated. If it
    still cannot be found, a hint about TEST_SET_SIZE is printed and None is
    returned.
    """
    finder = RecordFinder(net_name, dataset_name, ps, ('*', '*'), '*', '*',
                          init_acc)

    def _lookup():
        return finder.find_rec_filename(None, RecordType.BASELINE_REC)

    baseline_fn = _lookup()
    if baseline_fn is None:
        Optimizer(ps, (None, None), None, None).base_line_result()
        baseline_fn = _lookup()

    if baseline_fn is not None:
        return load_from_file(baseline_fn, '')

    print(
        f' !!! Was not able to get baseline result for initial accuracy of {init_acc} !!!'
    )
    print(' !!! Adjust TEST_SET_SIZE in Config.py !!!')
    return None
def plot_ops_saved_vs_ones(net_name,
                           dataset_name,
                           ps,
                           ones_possibilities,
                           gran_thresh,
                           acc_loss,
                           init_acc,
                           modes=None):
    """Plot operations-saved ratio against the number of ones in the patterns.

    For each value ``ones`` in ``ones_possibilities``, the FINAL_RESULT_REC
    with ones range ``(ones, ones + 1)`` is looked up. One line is drawn per
    mode that has at least one record. The figure is saved as a PDF in
    cfg.RESULTS_DIR.

    Args:
        net_name, dataset_name, ps, gran_thresh, acc_loss, init_acc: record
            filters forwarded to RecordFinder.
        ones_possibilities: numbers of ones forming the x axis.
        modes: modes to plot, resolved through ``get_modes``.
    """
    plt.figure()
    for mode in get_modes(modes):
        ops_saved = [None] * len(ones_possibilities)
        has_results = False
        for idx, ones in enumerate(ones_possibilities):
            rec_finder = RecordFinder(net_name, dataset_name, ps,
                                      (ones, ones + 1), gran_thresh, acc_loss,
                                      init_acc)
            fn = rec_finder.find_rec_filename(mode,
                                              RecordType.FINAL_RESULT_REC)
            if fn is not None:
                rec = load_from_file(fn, '')
                ops_saved[idx] = round(rec.ops_saved / rec.total_ops, 3)
                has_results = True
        if has_results:
            plt.plot(ones_possibilities,
                     ops_saved,
                     'o--',
                     label=gran_dict[mode])
    plt.xlabel('number of ones')
    plt.ylabel('operations saved [%]')
    plt.title(
        f'Operations Saved vs Number of Ones \n'
        f'{net_name}, {dataset_name}, INITIAL ACC:{init_acc} \n'
        f'PATCH SIZE:{ps}, MAX ACC LOSS:{acc_loss}, GRANULARITY:{gran_thresh}')
    plt.legend()
    # '_' separates dataset name and accuracy, matching the naming used by
    # the other ops-saved plots (previously "...CIFAR10acc93.5...").
    plt.savefig(
        f'{cfg.RESULTS_DIR}/ops_saved_vs_number_of_ones_{net_name}_{dataset_name}_'
        + f'acc{init_acc}_ps{ps}_ma{acc_loss}_mg{gran_thresh}.pdf')
# Example #8
# 0
    def __init__(self,
                 patch_size,
                 ones_range,
                 gran_thresh,
                 max_acc_loss,
                 init_acc=None,
                 test_size=cfg.TEST_SET_SIZE,
                 patterns_idx=None):
        """Prepare a mask search: network, test generator, initial accuracy and record finder.

        Args:
            patch_size: spatial patch size, stored as ``self.ps``.
            ones_range: half-open (low, high) range for the number of ones. It
                is ignored when ``patterns_idx`` is given; in that case it is
                derived from the loaded patterns file.
            gran_thresh: granularity threshold.
            max_acc_loss: maximum accepted accuracy loss.
            init_acc: initial accuracy. Measured on the test set when None, and
                forced to DEBUG_INIT_ACC when INNAS_COMP is set.
            test_size: maximum number of test samples the test generator yields.
            patterns_idx: optional cluster index. When given, patterns are
                loaded from ``all_patterns_ps{ps}_cluster{idx}.pkl`` in
                cfg.RESULTS_DIR.
        """
        self.ps = patch_size
        self.max_acc_loss = max_acc_loss
        self.gran_thresh = gran_thresh

        if patterns_idx is None:
            self.ones_range = ones_range
            self.input_patterns = None
        else:
            patterns_rec = load_from_file(
                f'all_patterns_ps{self.ps}_cluster{patterns_idx}.pkl',
                path=cfg.RESULTS_DIR)
            # NOTE(review): presumably index 1 holds the number of ones and
            # index 2 the pattern array - confirm against the file's writer.
            self.ones_range = (patterns_rec[1], patterns_rec[1] + 1)
            self.input_patterns = patterns_rec[2]

        self.full_net_run_time = None
        self.total_ops = None

        self.nn = NeuralNet()
        self.nn.net.initialize_spatial_layers(dat.shape(), cfg.BATCH_SIZE,
                                              self.ps)
        # Honour the test_size argument (it used to be ignored in favour of
        # cfg.TEST_SET_SIZE, which remains its default).
        self.test_gen, _ = dat.testset(batch_size=cfg.BATCH_SIZE,
                                       max_samples=test_size)
        self.test_set_size = test_size
        if INNAS_COMP:
            init_acc = DEBUG_INIT_ACC
        if init_acc is None:
            _, test_acc, correct = self.nn.test(self.test_gen)
            print(f'==> Asserted test-acc of: {test_acc} [{correct}]\n ')
            self.init_acc = test_acc  # TODO - Fix initialize bug
        else:
            self.init_acc = init_acc
        # Use the effective ones range: with patterns_idx it differs from the
        # argument, and records are created with self.ones_range.
        self.record_finder = RecordFinder(cfg.NET.__name__, dat.name(),
                                          patch_size, self.ones_range,
                                          gran_thresh, max_acc_loss,
                                          self.init_acc)
# Example #9
# 0
    def __init__(self, rec, init_acc, max_acc_loss, patch_size, ones_range, total_ops, resume_param_path=None, default_in_pattern=None):
        """Set up layer-level quantization from an input record.

        Args:
            rec: input record. Its ``all_patterns``, ``mode``,
                ``layers_layout``, ``filename`` and ``gen_pattern_lists`` are
                used.
            init_acc: initial accuracy. ``init_acc - max_acc_loss`` is stored
                as ``min_acc`` and passed to ``rec.gen_pattern_lists``.
            max_acc_loss: maximum accepted accuracy loss.
            patch_size: spatial patch size.
            ones_range: stored as-is.
            total_ops: total operation count, stored as-is.
            resume_param_path: path of a saved resume record to continue from.
                A fresh LayerQuantResumeRec is created when None.
            default_in_pattern: default pattern. When None, an all-ones array
                is built according to ``rec.mode``.
        """
        self.patch_size = patch_size
        self.input_patterns = rec.all_patterns
        self.min_acc = init_acc - max_acc_loss
        self.mode = rec.mode
        self.layers_layout = rec.layers_layout
        self.max_acc_loss = max_acc_loss
        self.ones_range = ones_range
        self.total_ops = total_ops

        # Keep only element [0][0] of each layer's pattern list.
        pattern_lists = rec.gen_pattern_lists(self.min_acc)
        self.input = [layer_lists[0][0] for layer_lists in pattern_lists]

        # DEFAULT sums the per-layer option counts, PRODUCT multiplies them.
        # NOTE(review): the count starts at 1 in both modes, so DEFAULT
        # yields sum + 1 - confirm the extra 1 is intended.
        pattern_count = 1
        for layer_input in self.input:
            if cfg.LQ_OPTION == cfg.LQ_modes.DEFAULT:
                pattern_count += len(layer_input)
            elif cfg.LQ_OPTION == cfg.LQ_modes.PRODUCT:
                pattern_count *= len(layer_input)
        self.no_of_patterns = pattern_count
        self._clean_input()
        self.product_iter = None

        if default_in_pattern is None:
            if rec.mode == Mode.UNIFORM_LAYER:
                # NOTE(review): both sides use shape[0]; correct only for
                # square patterns - confirm.
                side = self.input_patterns.shape[0]
                default_in_pattern = np.ones((side, side),
                                             dtype=self.input_patterns.dtype)
            elif rec.mode == Mode.UNIFORM_FILTERS:
                default_in_pattern = np.ones((self.patch_size, self.patch_size),
                                             dtype=self.input_patterns[0][0][0].dtype)
            else:
                default_in_pattern = np.ones((self.patch_size, self.patch_size),
                                             dtype=self.input_patterns[0][0].dtype)
        self.default_in_pattern = default_in_pattern

        debug_prefix = 'DEBUG_' if LQ_DEBUG else ''
        self.resume_param_filename = (debug_prefix
                                      + f'LayerQ{cfg.LQ_OPTION.value}_ma'
                                      + str(max_acc_loss) + '_' + rec.filename)

        if resume_param_path is not None:
            self.resume_rec = load_from_file(resume_param_path, path='')
        else:
            lengths = [len(layer_input) for layer_input in self.input]
            self.resume_rec = LayerQuantResumeRec(len(self.input), lengths,
                                                  max_acc_loss, self.input)
# Example #10
# 0
 def create_FR_with_different_acc_loss(self, mode, acc_loss):
     """Build the best final-result record for ``acc_loss`` from existing LayerQ resume records.

     ``self.record_finder.max_acc_loss`` is temporarily set to '*' so that
     LayerQ resume records saved under any max-accuracy-loss value are found.
     Each one is reloaded into a LayerQuantizier and asked for its final mask
     under ``acc_loss``. The result with the largest ``ops_saved`` is kept and
     saved to cfg.RESULTS_DIR. The finder's filter is restored even if
     loading or simulation raises.

     Returns:
         The best final-result record, or None if no LayerQ resume record
         exists for ``mode``.
     """
     self.record_finder.max_acc_loss = '*'
     best_FR = None
     try:
         for lq_rec_fn in self.record_finder.find_all_recs_fns(
                 mode, RecordType.lQ_RESUME):
             # Resume files are named 'LayerQ<opt>_ma<loss>_<input rec name>';
             # dropping the first two '_' fields recovers the input record.
             # NOTE(review): a 'DEBUG_' prefix would break this split.
             in_rec_fn = '_'.join(os.path.basename(lq_rec_fn).split('_')[2:])
             lq = LayerQuantizier(
                 load_from_file(in_rec_fn, path=cfg.RESULTS_DIR),
                 self.init_acc, self.max_acc_loss, self.ps, self.ones_range,
                 self.get_total_ops(), lq_rec_fn)
             final_rec = lq.find_final_mask(acc_loss,
                                            nn=self.nn,
                                            test_gen=self.test_gen)
             if best_FR is None or best_FR.ops_saved < final_rec.ops_saved:
                 best_FR = final_rec
     finally:
         self.record_finder.max_acc_loss = self.max_acc_loss

     if best_FR is None:
         print(f'No LayerQ resume record found for mode {mode}')
         return None
     print(best_FR)
     save_to_file(best_FR, True, cfg.RESULTS_DIR)
     print('==> result saved to ' + best_FR.filename)
     return best_FR
# Example #11
# 0
def show_final_mask_grid_resnet18(channel=0,
                                  net_name='*',
                                  dataset_name='*',
                                  mode=Mode.ALL_MODES,
                                  ps='*',
                                  ones_range=('*', '*'),
                                  acc_loss='*',
                                  gran_thresh='*',
                                  init_acc='*',
                                  filename=None,
                                  font_size=22,
                                  show_patch_grid=False):
    """Draw one channel of every layer mask of a final record on a fixed grid.

    The subplot placement (``st`` / ``span`` on an 18x24 grid) is hard-coded
    for 17 mask layers of decreasing spatial size, i.e. a ResNet-18 layout.
    Masks are shown in grey (0 black, 1 white). Optionally, red grid lines are
    drawn on patch boundaries.

    Args:
        channel: channel index to display from each layer mask.
        net_name, dataset_name, mode, ps, ones_range, acc_loss, gran_thresh,
            init_acc: record filters forwarded to RecordFinder.
        filename: if given, save the figure there; otherwise show it.
        font_size: matplotlib font size (updates global rcParams).
        show_patch_grid: draw patch-boundary grid lines.

    Returns:
        The loaded record, or None if no record matches.
    """
    rec_finder = RecordFinder(net_name, dataset_name, ps, ones_range,
                              gran_thresh, acc_loss, init_acc)
    final_rec_fn = rec_finder.find_rec_filename(mode,
                                                RecordType.FINAL_RESULT_REC)
    if final_rec_fn is None:
        print('No Record found')
        return None
    rec = load_from_file(final_rec_fn, '')

    # Hand-tuned placement: top-left corner and square span for each layer.
    shift = 4
    grid = (18, 24)
    st = ((0, 0), (0, 8), (0, 16), (8, 0), (8, 8), (8, 16), (8, 20), (12, 16),
          (12, 20), (16, 0 + shift), (16, 2 + shift), (16, 4 + shift),
          (16, 6 + shift), (16, 8 + shift), (16, 10 + shift), (16, 12 + shift),
          (16, 14 + shift))
    span = (8, 8, 8, 8, 8, 4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2)

    fig = plt.figure()
    fig.set_figheight(10)
    fig.set_figwidth(13)
    plt.rcParams.update({'font.size': font_size})
    for l_to_plot_idx, l_to_plot in enumerate(rec.mask):
        channel_mask = l_to_plot.numpy()[channel]
        n_rows, n_cols = channel_mask.shape[0], channel_mask.shape[1]

        plt.subplot2grid(grid,
                         st[l_to_plot_idx],
                         colspan=span[l_to_plot_idx],
                         rowspan=span[l_to_plot_idx])
        plt.imshow(channel_mask, cmap=plt.cm.gray)  # (0:black, 1:white)
        plt.title(f'{l_to_plot_idx}')
        ax = plt.gca()
        # Minor ticks on patch boundaries. The x axis runs along columns
        # (shape[1]) and the y axis along rows (shape[0]).
        ax.set_xticks(np.arange(-.5, n_cols - 1, rec.patch_size), minor=True)
        ax.set_yticks(np.arange(-.5, n_rows - 1, rec.patch_size), minor=True)
        # Gridlines based on minor ticks
        if show_patch_grid:
            ax.grid(which='minor', color='r', linestyle='-', linewidth=2)
        plt.tick_params(axis='both',
                        which='major',
                        bottom=False,
                        top=False,
                        left=False,
                        right=False,
                        labelbottom=False,
                        labelleft=False)
    plt.tight_layout()
    plt.subplots_adjust(top=1)
    if filename is None:
        plt.show()
    else:
        plt.savefig(filename)

    return rec
# Example #12
# 0
def show_channel_grid(layer=0,
                      net_name='*',
                      dataset_name='*',
                      mode=Mode.ALL_MODES,
                      ps='*',
                      ones_range=('*', '*'),
                      acc_loss='*',
                      gran_thresh='*',
                      init_acc='*',
                      filename=None,
                      font_size=22,
                      show_patch_grid=False):
    """Draw every channel of one layer mask of a final record in a square grid.

    The grid is ``ceil(sqrt(C)) x ceil(sqrt(C))`` for ``C`` channels. Cells
    beyond the last channel are hidden. Masks are shown in grey (0 black,
    1 white). Optionally, red grid lines are drawn on patch boundaries.

    Args:
        layer: index of the layer mask to display.
        net_name, dataset_name, mode, ps, ones_range, acc_loss, gran_thresh,
            init_acc: record filters forwarded to RecordFinder.
        filename: if given, save the figure there; otherwise show it.
        font_size: matplotlib font size (updates global rcParams).
        show_patch_grid: draw patch-boundary grid lines.

    Returns:
        The loaded record, or None if no record matches.
    """
    rec_finder = RecordFinder(net_name, dataset_name, ps, ones_range,
                              gran_thresh, acc_loss, init_acc)
    final_rec_fn = rec_finder.find_rec_filename(mode,
                                                RecordType.FINAL_RESULT_REC)
    if final_rec_fn is None:
        print('No Record found')
        return None
    rec = load_from_file(final_rec_fn, '')

    layer_mask = rec.mask[layer].numpy()

    no_of_channels = layer_mask.shape[0]
    rows = math.ceil(math.sqrt(no_of_channels))
    # squeeze=False keeps `axs` 2-D even for a 1x1 grid (single channel),
    # where plt.subplots would otherwise return a bare Axes.
    fig, axs = plt.subplots(nrows=rows, ncols=rows, squeeze=False)
    fig.set_figheight(30)
    fig.set_figwidth(30)
    plt.rcParams.update({'font.size': font_size})

    for c in range(no_of_channels):
        row_idx, col_idx = divmod(c, rows)
        ax = axs[row_idx][col_idx]
        channel_mask = layer_mask[c]
        ax.imshow(channel_mask, cmap=plt.cm.gray)  # (0:black, 1:white)
        ax.set_title(f'{c}')

        # Minor ticks on patch boundaries. The x axis runs along columns
        # (shape[1]) and the y axis along rows (shape[0]).
        ax.set_xticks(np.arange(-.5, channel_mask.shape[1] - 1,
                                rec.patch_size),
                      minor=True)
        ax.set_yticks(np.arange(-.5, channel_mask.shape[0] - 1,
                                rec.patch_size),
                      minor=True)
        # Gridlines based on minor ticks
        if show_patch_grid:
            ax.grid(which='minor', color='r', linestyle='-', linewidth=2)
        ax.tick_params(axis='both',
                       which='major',
                       bottom=False,
                       top=False,
                       left=False,
                       right=False,
                       labelbottom=False,
                       labelleft=False)

    # Hide the unused cells of the square grid.
    for c in range(no_of_channels, rows * rows):
        row_idx, col_idx = divmod(c, rows)
        axs[row_idx][col_idx].axis('off')

    plt.tight_layout()
    if filename is None:
        plt.show()
    else:
        plt.savefig(filename)

    return rec
# Example #13
# 0
def show_final_mask(show_all_layers=False,
                    layers_to_show=None,
                    show_all_channels=False,
                    channels_to_show=None,
                    plot_3D=False,
                    net_name='*',
                    dataset_name='*',
                    mode=Mode.ALL_MODES,
                    ps='*',
                    ones_range=('*', '*'),
                    acc_loss='*',
                    gran_thresh='*',
                    init_acc='*'):
    """Plot the zero fraction of each layer mask, then display selected layers/channels.

    Layer selection: all layers if ``show_all_layers``; otherwise
    ``layers_to_show``. If that is None, the single layer with the highest
    zero fraction is shown.

    Channel selection applies to modes other than UNIFORM_FILTERS and
    UNIFORM_LAYER; those two modes always show channel 0.
    - ``show_all_channels``: every channel.
    - ``channels_to_show`` None: first, middle and last channel.
    - A list of lists: ``channels_to_show[i]`` is used for the i-th shown layer.
    - A single non-list value: that one channel.
    - Any other list (including an empty one): used as-is.

    Args:
        plot_3D: also call ``show_layer`` for each non-uniform layer shown.
        net_name, dataset_name, mode, ps, ones_range, acc_loss, gran_thresh,
            init_acc: record filters forwarded to RecordFinder.

    Returns:
        The loaded record, or None if no record matches.
    """
    rec_finder = RecordFinder(net_name, dataset_name, ps, ones_range,
                              gran_thresh, acc_loss, init_acc)
    final_rec_fn = rec_finder.find_rec_filename(mode,
                                                RecordType.FINAL_RESULT_REC)
    if final_rec_fn is None:
        print('No Record found')
        return None
    rec = load_from_file(final_rec_fn, '')
    print(rec)
    mask_size = [
        rec.layers_layout[l][0] * rec.layers_layout[l][1] *
        rec.layers_layout[l][2] for l in range(len(rec.mask))
    ]
    # Fraction (0..1) of zero entries in each layer mask.
    zeros_in_each_layer = [
        np.count_nonzero(rec.mask[l].numpy() == 0) / mask_size[l]
        for l in range(len(rec.mask))
    ]

    plt.figure()
    tick_label = [str(l) for l in range(len(rec.mask))]
    plt.bar(list(range(len(rec.mask))),
            zeros_in_each_layer,
            tick_label=tick_label)
    plt.xlabel('layer index')
    plt.ylabel('zeros [%]')
    plt.title('[%] of Zeros in each Prediction Layer for the Chosen Mask')
    plt.show()

    if show_all_layers:
        layers_to_show = range(len(rec.mask))
    elif layers_to_show is None:
        layers_to_show = [
            max(range(len(zeros_in_each_layer)),
                key=zeros_in_each_layer.__getitem__)
        ]

    for idx, l_to_plot_idx in enumerate(layers_to_show):
        l_to_plot = rec.mask[l_to_plot_idx].numpy()
        layout = rec.layers_layout[l_to_plot_idx]
        if rec.mode == Mode.UNIFORM_FILTERS or rec.mode == Mode.UNIFORM_LAYER:  # all channels in layer are the same
            show_channel(l_to_plot_idx, 0, layout, l_to_plot[0],
                         rec.patch_size)
            continue

        if plot_3D:
            show_layer(l_to_plot_idx, layout, l_to_plot)
        if show_all_channels:
            channels = range(layout[0])
        elif channels_to_show is None:
            channels = [0, round(layout[0] / 2), layout[0] - 1]
        elif (isinstance(channels_to_show, list) and channels_to_show
              and isinstance(channels_to_show[0], list)):
            # Per-layer channel lists, indexed by position in layers_to_show.
            channels = channels_to_show[idx]
        elif not isinstance(channels_to_show, list):
            channels = [channels_to_show]
        else:
            channels = channels_to_show
        for channel in channels:
            show_channel(l_to_plot_idx, channel, layout, l_to_plot[channel],
                         rec.patch_size)

    return rec