Example #1
class ElementPlot(PlotlyPlot, GenericElementPlot):

    aspect = param.Parameter(default='cube',
                             doc="""
        The aspect ratio mode of the plot. By default, a plot may
        select its own appropriate aspect ratio but sometimes it may
        be necessary to force a square aspect ratio (e.g. to display
        the plot as an element of a grid). The modes 'auto' and
        'equal' correspond to the axis modes of the same name in
        matplotlib; a numeric value may also be passed.""")

    bgcolor = param.ClassSelector(class_=(str, tuple),
                                  default=None,
                                  doc="""
        If set, bgcolor overrides the background color of the axis.""")

    invert_axes = param.ObjectSelector(default=False,
                                       doc="""
        Inverts the axes of the plot. Note that this parameter may not
        always be respected by all plots but should be respected by
        adjoined plots when appropriate.""")

    invert_xaxis = param.Boolean(default=False,
                                 doc="""
        Whether to invert the plot x-axis.""")

    invert_yaxis = param.Boolean(default=False,
                                 doc="""
        Whether to invert the plot y-axis.""")

    invert_zaxis = param.Boolean(default=False,
                                 doc="""
        Whether to invert the plot z-axis.""")

    labelled = param.List(default=['x', 'y'],
                          doc="""
        Whether to plot the 'x' and 'y' labels.""")

    logx = param.Boolean(default=False,
                         doc="""
         Whether to apply log scaling to the x-axis of the Chart.""")

    logy = param.Boolean(default=False,
                         doc="""
         Whether to apply log scaling to the y-axis of the Chart.""")

    logz = param.Boolean(default=False,
                         doc="""
         Whether to apply log scaling to the z-axis of the Chart.""")

    margins = param.NumericTuple(default=(50, 50, 50, 50),
                                 doc="""
         Margins in pixel values specified as a tuple of the form
         (left, bottom, right, top).""")

    show_legend = param.Boolean(default=False,
                                doc="""
        Whether to show legend for the plot.""")

    xaxis = param.ObjectSelector(
        default='bottom',
        objects=['top', 'bottom', 'bare', 'top-bare', 'bottom-bare', None],
        doc="""
        Whether and where to display the xaxis; the 'bare' options suppress
        all axis labels, including ticks and the xlabel. Valid options are
        'top', 'bottom', 'bare', 'top-bare', 'bottom-bare' and None.""")

    xticks = param.Parameter(default=None,
                             doc="""
        Ticks along x-axis specified as an integer, explicit list of
        tick locations, list of tuples containing the locations and
        labels or a matplotlib tick locator object. If set to None
        default matplotlib ticking behavior is applied.""")

    yaxis = param.ObjectSelector(
        default='left',
        objects=['left', 'right', 'bare', 'left-bare', 'right-bare', None],
        doc="""
        Whether and where to display the yaxis; the 'bare' options suppress
        all axis labels, including ticks and the ylabel. Valid options are
        'left', 'right', 'bare', 'left-bare', 'right-bare' and None.""")

    yticks = param.Parameter(default=None,
                             doc="""
        Ticks along y-axis specified as an integer, explicit list of
        tick locations, list of tuples containing the locations and
        labels or a matplotlib tick locator object. If set to None
        default matplotlib ticking behavior is applied.""")

    zlabel = param.String(default=None,
                          doc="""
        An explicit override of the z-axis label, if set takes precedence
        over the dimension label.""")

    trace_type = None

    def initialize_plot(self, ranges=None):
        """
        Initializes a new plot object with the last available frame.
        """
        # Get element key and ranges for frame
        fig = self.generate_plot(self.keys[-1], ranges)
        self.drawn = True
        return fig

    def generate_plot(self, key, ranges):
        element = self._get_frame(key)
        if element is None:
            return self.handles['fig']
        plot_opts = self.lookup_options(element, 'plot').options
        self.set_param(
            **{k: v
               for k, v in plot_opts.items() if k in self.params()})
        self.style = self.lookup_options(element, 'style')

        ranges = self.compute_ranges(self.hmap, key, ranges)
        ranges = util.match_spec(element, ranges)

        data_args, data_kwargs = self.get_data(element, ranges)
        opts = self.graph_options(element, ranges)
        graph = self.init_graph(data_args, dict(opts, **data_kwargs))
        self.handles['graph'] = graph

        layout = self.init_layout(key, element, ranges)
        self.handles['layout'] = layout

        if isinstance(graph, dict) and 'data' in graph:
            merge_figure(graph, {'layout': layout})
            self.handles['fig'] = graph
            return self.handles['fig']
        else:
            if not isinstance(graph, list):
                graph = [graph]
            fig = dict(data=graph, layout=layout)
            self.handles['fig'] = fig
            return fig

    def graph_options(self, element, ranges):
        if self.overlay_dims:
            legend = ', '.join([
                d.pprint_value_string(v) for d, v in self.overlay_dims.items()
            ])
        else:
            legend = element.label

        opts = dict(showlegend=self.show_legend,
                    legendgroup=element.group,
                    name=legend)

        return opts

    def init_graph(self, plot_args, plot_kwargs):
        plot_kwargs['type'] = self.trace_type
        return dict(*plot_args, **plot_kwargs)

    def get_data(self, element, ranges):
        # Subclasses return the (data_args, data_kwargs) consumed by init_graph.
        return (), {}

    def get_aspect(self, xspan, yspan):
        """
        Computes the aspect ratio of the plot
        """
        return self.width / self.height

    def init_layout(self, key, element, ranges, xdim=None, ydim=None):
        l, b, r, t = self.get_extents(element, ranges)

        options = {}

        xdim = element.get_dimension(0) if xdim is None else xdim
        ydim = element.get_dimension(1) if ydim is None else ydim
        xlabel, ylabel, zlabel = self._get_axis_labels([xdim, ydim])

        if self.invert_axes:
            xlabel, ylabel = ylabel, xlabel
            l, b, r, t = b, l, t, r

        if 'x' not in self.labelled:
            xlabel = ''
        if 'y' not in self.labelled:
            ylabel = ''

        if xdim:
            xaxis = dict(range=[l, r], title=xlabel)
            if self.logx:
                xaxis['type'] = 'log'
            options['xaxis'] = xaxis

        if ydim:
            yaxis = dict(range=[b, t], title=ylabel)
            if self.logy:
                yaxis['type'] = 'log'
            options['yaxis'] = yaxis

        l, b, r, t = self.margins
        margin = dict(l=l, r=r, b=b, t=t, pad=4)
        return dict(width=self.width,
                    height=self.height,
                    title=self._format_title(key, separator=' '),
                    plot_bgcolor=self.bgcolor,
                    margin=margin,
                    **options)

    def update_frame(self, key, ranges=None):
        """
        Updates an existing plot with data corresponding
        to the key.
        """
        self.generate_plot(key, ranges)
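
# A minimal sketch (not part of the original class) of the figure dictionary
# that generate_plot() assembles: traces built by init_graph() go under 'data'
# and the output of init_layout() under 'layout'. The values are illustrative.
graph = [dict(type='scatter', x=[0, 1, 2], y=[0, 1, 4], name='example')]
layout = dict(width=400, height=400, title='Example',
              xaxis=dict(range=[0, 2], title='x'),
              yaxis=dict(range=[0, 4], title='y'),
              margin=dict(l=50, r=50, b=50, t=50, pad=4))
fig = dict(data=graph, layout=layout)  # same structure as self.handles['fig']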
Example #2
class aggregate(ElementOperation):
    """
    aggregate implements 2D binning for any valid HoloViews Element
    type using datashader. I.e., this operation turns a HoloViews
    Element or overlay of Elements into an hv.Image or an overlay of
    hv.Images by rasterizing it, which provides a fixed-size
    representation independent of the original dataset size.

    By default it will simply count the number of values in each bin
    but other aggregators can be supplied implementing mean, max, min
    and other reduction operations.

    The bins of the aggregate are defined by the width and height and
    the x_range and y_range. If x_sampling or y_sampling are supplied
    the operation will ensure that a bin is no smaller than the minimum
    sampling distance by reducing the width and height when zoomed
    in beyond the minimum sampling distance.
    """

    aggregator = param.ClassSelector(class_=ds.reductions.Reduction,
                                     default=ds.count())

    dynamic = param.Boolean(default=True,
                            doc="""
       Enables dynamic processing by default.""")

    height = param.Integer(default=400,
                           doc="""
       The height of the aggregated image in pixels.""")

    width = param.Integer(default=400,
                          doc="""
       The width of the aggregated image in pixels.""")

    x_range = param.NumericTuple(default=None,
                                 length=2,
                                 doc="""
       The x_range as a tuple of min and max x-value. Auto-ranges
       if set to None.""")

    y_range = param.NumericTuple(default=None,
                                 length=2,
                                 doc="""
       The y_range as a tuple of min and max y-value. Auto-ranges
       if set to None.""")

    x_sampling = param.Number(default=None,
                              doc="""
        Specifies the smallest allowed sampling interval along the x-axis.""")

    y_sampling = param.Number(default=None,
                              doc="""
        Specifies the smallest allowed sampling interval along the y-axis.""")

    streams = param.List(default=[RangeXY],
                         doc="""
        List of streams that are applied if dynamic=True, allowing
        for dynamic interaction with the plot.""")

    element_type = param.ClassSelector(class_=(Dataset, ),
                                       instantiate=False,
                                       is_instance=False,
                                       default=GridImage,
                                       doc="""
        The type of the returned Elements, must be a 2D Dataset type.""")

    @classmethod
    def get_agg_data(cls, obj, category=None):
        """
        Reduces any Overlay or NdOverlay of Elements into a single
        xarray Dataset that can be aggregated.
        """
        paths = []
        kdims = obj.kdims
        vdims = obj.vdims
        x, y = obj.dimensions(label=True)[:2]
        is_df = lambda x: (isinstance(x, Dataset) and
                           x.interface in DF_INTERFACES)
        if isinstance(obj, Path):
            glyph = 'line'
            for p in obj.data:
                df = pd.DataFrame(p, columns=obj.dimensions('key', True))
                if isinstance(obj, Contours) and obj.vdims and obj.level:
                    df[obj.vdims[0].name] = obj.level
                paths.append(df)
        elif isinstance(obj, CompositeOverlay):
            for key, el in obj.data.items():
                x, y, element, glyph = cls.get_agg_data(el)
                df = element.data if is_df(element) else element.dframe()
                if isinstance(obj, NdOverlay):
                    df = df.assign(
                        **dict(zip(obj.dimensions('key', True), key)))
                paths.append(df)
            kdims += element.kdims
            vdims = element.vdims
        elif isinstance(obj, Element):
            glyph = 'line' if isinstance(obj, Curve) else 'points'
            paths.append(obj.data if is_df(obj) else obj.dframe())
        if len(paths) > 1:
            if glyph == 'line':
                path = paths[0][:1]
                if isinstance(path, dd.DataFrame):
                    path = path.compute()
                empty = path.copy()
                empty.iloc[0, :] = (np.NaN, ) * empty.shape[1]
                paths = [elem for path in paths for elem in (path, empty)][:-1]
            if all(isinstance(path, dd.DataFrame) for path in paths):
                df = dd.concat(paths)
            else:
                paths = [
                    path.compute() if isinstance(path, dd.DataFrame) else path
                    for path in paths
                ]
                df = pd.concat(paths)
        else:
            df = paths[0]
        if category and df[category].dtype.name != 'category':
            df[category] = df[category].astype('category')
        return x, y, Dataset(df, kdims=kdims, vdims=vdims), glyph

    def _process(self, element, key=None):
        agg_fn = self.p.aggregator
        category = agg_fn.column if isinstance(agg_fn, ds.count_cat) else None
        x, y, data, glyph = self.get_agg_data(element, category)

        xstart, xend = self.p.x_range if self.p.x_range else data.range(x)
        ystart, yend = self.p.y_range if self.p.y_range else data.range(y)

        # Compute highest allowed sampling density
        width, height = self.p.width, self.p.height
        if self.p.x_sampling:
            x_range = xend - xstart
            width = int(min([(x_range / self.p.x_sampling), width]))
        if self.p.y_sampling:
            y_range = yend - ystart
            height = int(min([(y_range / self.p.y_sampling), height]))

        cvs = ds.Canvas(plot_width=width,
                        plot_height=height,
                        x_range=(xstart, xend),
                        y_range=(ystart, yend))

        column = agg_fn.column
        if column and isinstance(agg_fn, ds.count_cat):
            name = '%s Count' % agg_fn.column
        else:
            name = column
        vdims = [
            element.get_dimension(column)(name)
            if column else Dimension('Count')
        ]
        params = dict(get_param_values(element),
                      kdims=element.dimensions()[:2],
                      datatype=['xarray'],
                      vdims=vdims)

        agg = getattr(cvs, glyph)(data, x, y, self.p.aggregator)
        if agg.ndim == 2:
            return self.p.element_type(agg, **params)
        else:
            return NdOverlay(
                {
                    c: self.p.element_type(agg.sel(**{column: c}), **params)
                    for c in agg.coords[column].data
                },
                kdims=[data.get_dimension(column)])
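
# A brief usage sketch (not from the original module), assuming numpy, holoviews
# and datashader are importable alongside the aggregate operation above;
# parameter values are illustrative.
import numpy as np
import holoviews as hv
import datashader as ds

# Rasterize a large point set into a fixed 300x300 count aggregate.
points = hv.Points(np.random.randn(100000, 2))
rasterized = aggregate(points, width=300, height=300,
                       aggregator=ds.count(), dynamic=False)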
Example #3
class shade(Operation):
    """
    shade applies a normalization function followed by colormapping to
    an Image or NdOverlay of Images, returning an RGB Element.
    The data must be in the form of a 2D or 3D DataArray, but NdOverlays
    of 2D Images will be automatically converted to a 3D array.

    In the 2D case data is normalized and colormapped, while a 3D
    array representing categorical aggregates will be supplied a color
    key for each category. The colormap (cmap) may be supplied as an
    Iterable or a Callable.
    """

    cmap = param.ClassSelector(default=fire,
                               class_=(Iterable, Callable, dict),
                               doc="""
        Iterable or callable which returns colors as hex colors, or a
        dict mapping each category directly to a color. A callable must
        map scalar values between 0 and 1 to colors.""")

    normalization = param.ClassSelector(default='eq_hist',
                                        class_=(basestring, Callable),
                                        doc="""
        The normalization operation applied before colormapping.
        Valid options include 'linear', 'log', 'eq_hist', 'cbrt',
        and any valid transfer function that accepts data, mask, nbins
        arguments.""")

    clims = param.NumericTuple(default=None,
                               length=2,
                               doc="""
        Min and max data values to use for colormap interpolation, when
        wishing to override autoranging.
        """)

    link_inputs = param.Boolean(default=True,
                                doc="""
        By default, the link_inputs parameter is set to True so that
        when applying shade, backends that support linked streams
        update RangeXY streams on the inputs of the shade operation.
        Disable when you do not want the resulting plot to be interactive,
        e.g. when trying to display an interactive plot a second time.""")

    @classmethod
    def concatenate(cls, overlay):
        """
        Concatenates an NdOverlay of Image types into a single 3D
        xarray Dataset.
        """
        if not isinstance(overlay, NdOverlay):
            raise ValueError('Only NdOverlays can be concatenated')
        xarr = xr.concat([v.data.T for v in overlay.values()],
                         pd.Index(overlay.keys(), name=overlay.kdims[0].name))
        params = dict(get_param_values(overlay.last),
                      vdims=overlay.last.vdims,
                      kdims=overlay.kdims + overlay.last.kdims)
        return Dataset(xarr.T, datatype=['xarray'], **params)

    @classmethod
    def uint32_to_uint8(cls, img):
        """
        Cast uint32 RGB image to 4 uint8 channels.
        """
        return np.flipud(img.view(dtype=np.uint8).reshape(img.shape + (4, )))

    @classmethod
    def rgb2hex(cls, rgb):
        """
        Convert RGB(A) tuple to hex.
        """
        if len(rgb) > 3:
            rgb = rgb[:-1]
        return "#{0:02x}{1:02x}{2:02x}".format(*(int(v * 255) for v in rgb))

    def _process(self, element, key=None):
        if isinstance(element, NdOverlay):
            bounds = element.last.bounds
            element = self.concatenate(element)
        else:
            bounds = element.bounds

        vdim = element.vdims[0].name
        array = element.data[vdim]
        kdims = element.kdims

        # Compute shading options depending on whether
        # it is a categorical or regular aggregate
        shade_opts = dict(how=self.p.normalization)
        if element.ndims > 2:
            kdims = element.kdims[1:]
            categories = array.shape[-1]
            if not self.p.cmap:
                pass
            elif isinstance(self.p.cmap, dict):
                shade_opts['color_key'] = self.p.cmap
            elif isinstance(self.p.cmap, Iterable):
                shade_opts['color_key'] = [
                    c for i, c in zip(range(categories), self.p.cmap)
                ]
            else:
                colors = [
                    self.p.cmap(s) for s in np.linspace(0, 1, categories)
                ]
                shade_opts['color_key'] = map(self.rgb2hex, colors)
        elif not self.p.cmap:
            pass
        elif isinstance(self.p.cmap, Callable):
            colors = [self.p.cmap(s) for s in np.linspace(0, 1, 256)]
            shade_opts['cmap'] = map(self.rgb2hex, colors)
        else:
            shade_opts['cmap'] = self.p.cmap

        if self.p.clims:
            shade_opts['span'] = self.p.clims
        elif ds_version > '0.5.0' and self.p.normalization != 'eq_hist':
            shade_opts['span'] = element.range(vdim)

        with warnings.catch_warnings():
            warnings.filterwarnings(
                'ignore', r'invalid value encountered in true_divide')
            if np.isnan(array.data).all():
                arr = np.zeros(array.data.shape, dtype=np.uint32)
                img = array.copy()
                img.data = arr
            else:
                img = tf.shade(array, **shade_opts)
        params = dict(get_param_values(element),
                      kdims=kdims,
                      bounds=bounds,
                      vdims=RGB.vdims[:])
        return RGB(self.uint32_to_uint8(img.data), **params)
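
# A brief sketch (not from the original module) of chaining the two operations
# defined above, assuming the same imports are available; the colormap and
# normalization choices are illustrative.
import numpy as np
import holoviews as hv

points = hv.Points(np.random.randn(100000, 2))
agg = aggregate(points, width=300, height=300, dynamic=False)           # 2D aggregate
rgb = shade(agg, cmap=['lightblue', 'darkblue'], normalization='log')   # RGB element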
Example #4
class Plot(param.Parameterized):
    """
    A Plot object returns either a matplotlib figure object (when
    called or indexed) or a matplotlib animation object as
    appropriate. Plots take element objects such as Image,
    Contours or Points as inputs and plot them in the
    appropriate format. As views may vary over time, all plots support
    animation via the anim() method.
    """

    figure_alpha = param.Number(default=1.0, bounds=(0, 1), doc="""
        Alpha of the overall figure background.""")

    figure_bounds = param.NumericTuple(default=(0.15, 0.15, 0.85, 0.85),
                                       doc="""
        The bounds of the overall figure as a 4-tuple of the form
        (left, bottom, right, top), defining the size of the border
        around the subplots.""")

    figure_inches = param.NumericTuple(default=(4, 4), doc="""
        The overall matplotlib figure size in inches.""")

    figure_latex = param.Boolean(default=False, doc="""
        Whether to use LaTeX text in the overall figure.""")

    figure_rcparams = param.Dict(default={}, doc="""
        matplotlib rc parameters to apply to the overall figure.""")

    figure_size = param.Integer(default=100, bounds=(1, None), doc="""
        Size relative to the supplied overall figure_inches in percent.""")

    finalize_hooks = param.HookList(default=[], doc="""
        Optional list of hooks called when finalizing an axis.
        The hook is passed the full set of plot handles and the
        displayed object.""")

    sublabel_format = param.String(default=None, allow_None=True, doc="""
        Allows labeling the subaxes in each plot with various formatters
        including {Alpha}, {alpha}, {numeric}, {Roman} and {roman}.""")

    sublabel_position = param.NumericTuple(default=(-0.35, 0.85), doc="""
         Position relative to the plot for placing the optional subfigure label.""")

    sublabel_size = param.Number(default=18, doc="""
         Size of optional subfigure label.""")

    normalize = param.Boolean(default=True, doc="""
        Whether to compute ranges across all Elements at this level
        of plotting. Allows selecting normalization at different levels
        for nested data containers.""")

    projection = param.ObjectSelector(default=None,
                                      objects=['3d', 'polar', None], doc="""
        The projection of the plot axis, default of None is equivalent to
        2D plot, 3D and polar plots are also supported.""")

    show_frame = param.Boolean(default=True, doc="""
        Whether or not to show a complete frame around the plot.""")

    show_title = param.Boolean(default=True, doc="""
        Whether to display the plot title.""")

    title_format = param.String(default="{label} {group}", doc="""
        The formatting string for the title of this plot.""")

    # A list of matplotlib keyword arguments that may be supplied via a
    # style options object. Each subclass should override this
    # parameter to list every option that works correctly.
    style_opts = []

    # A mapping from ViewableElement types to their corresponding side plot types
    sideplots = {}


    def __init__(self, figure=None, axis=None, dimensions=None, subplots=None,
                 layout_dimensions=None, uniform=True, keys=None, subplot=False,
                 adjoined=None, layout_num=0, **params):
        self.adjoined = adjoined
        self.subplots = subplots
        self.subplot = figure is not None or subplot
        self.dimensions = dimensions
        self.layout_num = layout_num
        self.layout_dimensions = layout_dimensions
        self.keys = keys
        self.uniform = uniform

        self._create_fig = True
        self.drawn = False
        # List of handles to matplotlib objects for animation update
        self.handles = {} if figure is None else {'fig': figure}

        super(Plot, self).__init__(**params)
        size_scale = self.figure_size / 100.
        self.figure_inches = (self.figure_inches[0] * size_scale,
                              self.figure_inches[1] * size_scale)
        self.handles['axis'] = self._init_axis(axis)


    def compute_ranges(self, obj, key, ranges):
        """
        Given an object, a specific key and the normalization options
        this method will find the specified normalization options on
        the appropriate OptionTree, group the elements according to
        the selected normalization option (i.e. either per frame or
        over the whole animation) and finally compute the dimension
        ranges in each group. The new set of ranges is returned.
        """
        if obj is None or not self.normalize:
            return OrderedDict()
        all_table = all(isinstance(el, Table)
                        for el in obj.traverse(lambda x: x, [Element]))
        if all_table:
            return OrderedDict()
        # Get inherited ranges
        ranges = {} if ranges is None or self.adjoined else dict(ranges)

        # Get element identifiers from current object and resolve
        # with selected normalization options
        norm_opts = self._get_norm_opts(obj)

        # Traverse displayed object if normalization applies
        # at this level, and ranges for the group have not
        # been supplied from a composite plot
        elements = []
        return_fn = lambda x: x if isinstance(x, Element) else None
        for group, (axiswise, framewise) in norm_opts.items():
            if group in ranges:
                continue # Skip if ranges are already computed
            elif not framewise and not self.adjoined: # Traverse to get all elements
                elements = obj.traverse(return_fn, [group])
            elif key is not None: # Traverse to get elements for each frame
                elements = self._get_frame(key).traverse(return_fn, [group])
            if not axiswise or (not framewise and isinstance(obj, HoloMap)): # Compute new ranges
                self._compute_group_range(group, elements, ranges)
        return ranges


    def _get_norm_opts(self, obj):
        """
        Gets the normalization options for a LabelledData object by
        traversing the object to find elements and their ids.
        The id is then used to select the appropriate OptionsTree,
        accumulating the normalization options into a dictionary.
        Returns a dictionary of normalization options for each
        element in the tree.
        """
        norm_opts = {}

        # Get all elements' type.group.label specs and ids
        type_val_fn = lambda x: (x.id, (type(x).__name__, sanitize_identifier(x.group, escape=False),
                                        sanitize_identifier(x.label, escape=False))) \
            if isinstance(x, Element) else None
        element_specs = {(idspec[0], idspec[1]) for idspec in obj.traverse(type_val_fn)
                         if idspec is not None}

        # Group elements specs by ID and override normalization
        # options sequentially
        key_fn = lambda x: -1 if x[0] is None else x[0]
        id_groups = groupby(sorted(element_specs, key=key_fn), key_fn)
        for gid, element_spec_group in id_groups:
            gid = None if gid == -1 else gid
            group_specs = [el for _, el in element_spec_group]
            optstree = Store.custom_options.get(gid, Store.options)
            # Get the normalization options for the current id
            # and match against customizable elements
            for opts in optstree:
                path = tuple(opts.path.split('.')[1:])
                applies = any(path == spec[:i] for spec in group_specs
                              for i in range(1, 4))
                if applies and 'norm' in opts.groups:
                    nopts = opts['norm'].options
                    if 'axiswise' in nopts or 'framewise' in nopts:
                        norm_opts.update({path: (opts['norm'].options.get('axiswise', False),
                                                 opts['norm'].options.get('framewise', False))})
        element_specs = [spec for eid, spec in element_specs]
        norm_opts.update({spec: (False, False) for spec in element_specs
                          if not any(spec[1:i] in norm_opts.keys() for i in range(1, 3))})
        return norm_opts


    @staticmethod
    def _compute_group_range(group, elements, ranges):
        # Iterate over all elements in a normalization group
        # and accumulate their ranges into the supplied dictionary.
        elements = [el for el in elements if el is not None]
        for el in elements:
            for dim in el.dimensions(label=True):
                dim_range = el.range(dim)
                if group not in ranges: ranges[group] = OrderedDict()
                if dim in ranges[group]:
                    ranges[group][dim] = find_minmax(ranges[group][dim], dim_range)
                else:
                    ranges[group][dim] = dim_range


    def _get_frame(self, key):
        """
        Required on each Plot type to get the data corresponding
        just to the current frame out from the object.
        """
        pass


    def _frame_title(self, key, group_size=2):
        """
        Returns the formatted dimension group strings
        for a particular frame.
        """
        if self.layout_dimensions is not None:
            dimensions, key = zip(*self.layout_dimensions.items())
        elif not self.uniform or len(self) == 1 or self.layout_num:
            return ''
        else:
            key = key if isinstance(key, tuple) else (key,)
            dimensions = self.dimensions
        dimension_labels = [dim.pprint_value_string(k) for dim, k in
                            zip(dimensions, key)]
        groups = [', '.join(dimension_labels[i*group_size:(i+1)*group_size])
                  for i in range(len(dimension_labels))]
        return '\n '.join(g for g in groups if g)


    def _init_axis(self, axis):
        """
        Return an axis which may need to be initialized from
        a new figure.
        """
        if not self.subplot and self._create_fig:
            rc_params = self.figure_rcparams
            if self.figure_latex:
                rc_params['text.usetex'] = True
            with matplotlib.rc_context(rc=rc_params):
                fig = plt.figure()
                self.handles['fig'] = fig
                l, b, r, t = self.figure_bounds
                fig.subplots_adjust(left=l, bottom=b, right=r, top=t)
                fig.patch.set_alpha(self.figure_alpha)
                fig.set_size_inches(list(self.figure_inches))
                axis = fig.add_subplot(111, projection=self.projection)
                axis.set_aspect('auto')

        return axis


    def _subplot_label(self, axis):
        layout_num = self.layout_num if self.subplot else 1
        if self.sublabel_format and not self.adjoined and layout_num > 0:
            from mpl_toolkits.axes_grid1.anchored_artists import AnchoredText
            labels = {}
            if '{Alpha}' in self.sublabel_format:
                labels['Alpha'] = str(chr(layout_num+64))
            elif '{alpha}' in self.sublabel_format:
                labels['alpha'] = str(chr(layout_num+96))
            elif '{numeric}' in self.sublabel_format:
                labels['numeric'] = self.layout_num
            elif '{Roman}' in self.sublabel_format:
                labels['Roman'] = int_to_roman(layout_num)
            elif '{roman}' in self.sublabel_format:
                labels['roman'] = int_to_roman(layout_num).lower()
            at = AnchoredText(self.sublabel_format.format(**labels), loc=3,
                              bbox_to_anchor=self.sublabel_position, frameon=False,
                              prop=dict(size=self.sublabel_size, weight='bold'),
                              bbox_transform=axis.transAxes)
            at.patch.set_visible(False)
            axis.add_artist(at)


    def _finalize_axis(self, key):
        """
        General method to finalize the axis and plot.
        """
        if 'title' in self.handles:
            self.handles['title'].set_visible(self.show_title)

        self.drawn = True
        if self.subplot:
            return self.handles['axis']
        else:
            plt.draw()
            fig = self.handles['fig']
            plt.close(fig)
            return fig


    def __getitem__(self, frame):
        """
        Get the matplotlib figure at the given frame number.
        """
        if frame >= len(self):
            self.warning("Showing last frame available: %d" % len(self))
            frame = len(self) - 1
        if not self.drawn: self.handles['fig'] = self()
        self.update_frame(self.keys[frame])
        return self.handles['fig']


    def anim(self, start=0, stop=None, fps=30):
        """
        Method to return a matplotlib animation. The start and stop
        frames may be specified as well as the fps.
        """
        figure = self()
        frames = self.keys[start:stop]
        anim = animation.FuncAnimation(figure, self.update_frame,
                                       frames=frames,
                                       interval=1000.0/fps)
        # Close the figure handle
        plt.close(figure)
        return anim

    def __len__(self):
        """
        Returns the total number of available frames.
        """
        return len(self.keys)


    def __call__(self, ranges=None):
        """
        Return a matplotlib figure.
        """
        raise NotImplementedError


    def update_frame(self, key, ranges=None):
        """
        Updates the current frame of the plot.
        """
        raise NotImplementedError


    def update_handles(self, axis, view, key, ranges=None):
        """
        Should be called by the update_frame class to update
        any handles on the plot.
        """
        pass
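
# A minimal sketch (not from the original codebase) of what a concrete subclass
# must supply: __call__ draws the first frame and update_frame redraws the
# artists for a given key, after which indexing and anim() work unchanged.
# The class name and the {key: (xs, ys)} data layout are illustrative.
class SimpleCurvePlot(Plot):
    """Illustrative subclass; `data` maps frame keys to (xs, ys) arrays."""

    def __init__(self, data, **params):
        keys = sorted(data)
        super(SimpleCurvePlot, self).__init__(keys=keys, **params)
        self.data = data

    def __call__(self, ranges=None):
        ax = self.handles['axis']
        self.handles['artist'] = ax.plot(*self.data[self.keys[0]])[0]
        return self._finalize_axis(self.keys[0])

    def update_frame(self, key, ranges=None):
        xs, ys = self.data[key]
        self.handles['artist'].set_data(xs, ys)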
Example #5
class SegmentationModelBase(ModelConfigBase):
    """
    A class that holds all settings that are specific to segmentation models.
    """

    #: The segmentation model architecture to use.
    #: Valid options are defined at :class:`ModelArchitectureConfig`: 'Basic (DeepMedic)', 'UNet3D', 'UNet2D'
    architecture: str = param.String(
        "Basic",
        doc="The model architecture (for example, UNet). Valid options are"
        "UNet3D, UNet2D, Basic (DeepMedic)")

    #: The loss type to use during training.
    #: Valid options are defined at :class:`SegmentationLoss`: "SoftDice", "CrossEntropy", "Focal", "Mixture"
    loss_type: SegmentationLoss = param.ClassSelector(
        default=SegmentationLoss.SoftDice,
        class_=SegmentationLoss,
        instantiate=False,
        doc="The loss_type to use")

    #: List of pairs of weights, loss types and class-weight-power values for use when loss_type is
    #: :attr:`SegmentationLoss.MixtureLoss`.
    mixture_loss_components: Optional[List[MixtureLossComponent]] = param.List(
        None,
        class_=MixtureLossComponent,
        instantiate=False,
        doc=
        "List of pairs of weights, loss types and class-weight-power values for use when loss_type is MixtureLoss"
    )

    #: For weighted loss, power to which to raise the weights per class. If this is None, loss is not weighted.
    loss_class_weight_power: Optional[float] = param.Number(
        None,
        allow_None=True,
        doc="Power to which to raise class weights for loss "
        "function; default value will depend on loss_type")

    #: Gamma value for focal loss: weight for each pixel is posterior likelihood to the power -focal_loss_gamma.
    focal_loss_gamma: float = param.Number(
        1.0,
        doc="Gamma value for focal loss: weight for each pixel is "
        "posterior likelihood to the power -focal_loss_gamma.")

    #: The spacing X, Y, Z expected for all images in the dataset
    dataset_expected_spacing_xyz: Optional[TupleFloat3] = param.NumericTuple(
        None,
        length=3,
        allow_None=True,
        doc="The spacing X, Y, Z expected for all images in the dataset")

    #: The number of feature channels at different stages of the model.
    feature_channels: List[int] = param.List(
        None,
        class_=int,
        bounds=(1, None),
        instantiate=False,
        doc="The number of feature channels at different stages of the model.")

    #: The size of the convolution kernels.
    kernel_size: int = param.Integer(
        3, bounds=(1, None), doc="The size of the convolution kernels.")

    #: The number of image levels used in Unet (in encoding and decoding paths).
    num_downsampling_paths: int = param.Integer(
        4,
        bounds=(1, None),
        instantiate=False,
        doc=
        "The number of levels used in a UNet architecture in encoding and decoding paths."
    )

    #: The size of the random crops that will be drawn from the input images during training. This is also the
    #: input size of the model.
    crop_size: TupleInt3 = IntTuple(
        (1, 1, 1),
        length=3,
        doc="The size of the random crops that will be "
        "drawn from the input images. This is also the "
        "input size of the model.")

    #: The names of the image input channels that the model consumes. These channels must be present in the
    #: dataset.csv file.
    image_channels: List[str] = param.List(
        None,
        class_=str,
        bounds=(1, None),
        instantiate=False,
        doc="The names of the image input channels that the model consumes. "
        "These channels must be present in the dataset.csv file")

    #: The names of the ground truth channels that the model consumes. These channels must be present in the
    #: dataset.csv file
    ground_truth_ids: List[str] = param.List(
        None,
        class_=str,
        bounds=(1, None),
        instantiate=False,
        doc="The names of the ground truth channels that the model consumes. "
        "These channels must be present in the dataset.csv file")

    #: The name of the channel that contains the `inside/outside body` information (to mask out the background).
    #: This channel must be present in the dataset
    mask_id: Optional[str] = param.String(
        None,
        allow_None=True,
        doc="The name of the channel that contains the "
        "`inside/outside body` information."
        "This channel must be present in the dataset")

    #: The type of image normalization that should be applied. Must be None, or of type
    #: :attr:`PhotometricNormalizationMethod`: Unchanged, SimpleNorm, MriWindow, CtWindow, TrimmedNorm
    norm_method: PhotometricNormalizationMethod = \
        param.ClassSelector(default=PhotometricNormalizationMethod.CtWindow,
                            class_=PhotometricNormalizationMethod,
                            instantiate=False,
                            doc="The type of image normalization that should be applied. Must be one of None, "
                                "Unchanged, SimpleNorm, MriWindow , CtWindow, TrimmedNorm")

    #: The Window setting for the :attr:`PhotometricNormalizationMethod.CtWindow` normalization.
    window: int = param.Integer(
        600,
        bounds=(0, None),
        doc="The Window setting for the 'CtWindow' normalization.")

    #: The level setting for the :attr:`PhotometricNormalizationMethod.CtWindow` normalization.
    level: int = param.Integer(
        50, doc="The level setting for the 'CtWindow' normalization.")

    #: The value range that image normalization should produce. This is the input range to the network.
    output_range: TupleFloat2 = param.NumericTuple(
        (-1.0, 1.0),
        length=2,
        doc="The value range that image normalization should produce. "
        "This is the input range to the network.")

    #: If true, create additional plots during image normalization.
    debug_mode: bool = param.Boolean(
        False,
        doc="If true, create additional plots during image normalization.")

    #: Tail parameter allows window range to be extended to the right; used in
    #: :attr:`PhotometricNormalizationMethod.MriWindow`. The value must be a list with one entry per input channel
    #: if the model has multiple input channels
    tail: List[float] = param.List(
        None,
        class_=float,
        doc=
        "Tail parameter allows window range to be extended to right, Used in MriWindow."
        " The value must be a list with one entry per input channel "
        "if the model has multiple input channels.")

    #: Sharpen parameter specifies number of standard deviations from mean to be included in window range.
    #: Used in :attr:`PhotometricNormalizationMethod.MriWindow`
    sharpen: float = param.Number(
        0.9,
        doc="Sharpen parameter specifies number of standard deviations "
        "from mean to be included in window range. Used in MriWindow")

    #: Percentile at which to trim input distribution prior to normalization. Used in
    #: :attr:`PhotometricNormalizationMethod.TrimmedNorm`
    trim_percentiles: TupleFloat2 = param.NumericTuple(
        (1.0, 99.0),
        length=2,
        doc="Percentile at which to trim input distribution prior "
        "to normalization. Used in TrimmedNorm")

    #: Padding mode to use for training and inference. See :attr:`PaddingMode` for valid options.
    padding_mode: PaddingMode = param.ClassSelector(
        default=PaddingMode.Edge,
        class_=PaddingMode,
        instantiate=False,
        doc="Padding mode to use for training and inference")

    #: The batch size to use for inference forward pass.
    inference_batch_size: int = param.Integer(
        8,
        bounds=(1, None),
        doc="The batch size to use for inference forward pass")

    #: The crop size to use for model testing. If nothing is specified, crop_size parameter is used instead,
    #: i.e. training and testing crop size will be the same.
    test_crop_size: Optional[TupleInt3] = IntTuple(
        None,
        length=3,
        allow_None=True,
        doc="The crop size to use for model testing. "
        "If nothing is specified, "
        "crop_size parameter is used instead, "
        "i.e. training and testing crop size "
        "will be the same.")

    #: The per-class probabilities for picking a center point of a crop.
    class_weights: Optional[List[float]] = param.List(
        None,
        class_=float,
        bounds=(1, None),
        allow_None=True,
        instantiate=False,
        doc="The per-class probabilities for picking a center point of "
        "a crop.")

    #: Layer name hierarchy (parent, child recursive) as given by the model definition. If None, no activation maps will be saved.
    activation_map_layers: Optional[List[str]] = param.List(
        None,
        class_=str,
        allow_None=True,
        bounds=(1, None),
        instantiate=False,
        doc="Layer name hierarchy (parent, child "
        "recursive) as by model definition. If None, "
        "no activation maps will be saved")

    #: The aggregation method to use when testing ensemble models. See :attr: `EnsembleAggregationType` for options.
    ensemble_aggregation_type: EnsembleAggregationType = param.ClassSelector(
        default=EnsembleAggregationType.Average,
        class_=EnsembleAggregationType,
        instantiate=False,
        doc="The aggregation method to use when"
        "testing ensemble models.")

    #: The size of the smoothing kernel in mm to be used for smoothing posteriors before computing the final
    #: segmentations. No smoothing is performed if set to None.
    posterior_smoothing_mm: Optional[TupleInt3] = param.NumericTuple(
        None,
        length=3,
        allow_None=True,
        doc="The size of the smoothing kernel in mm to be "
        "used for smoothing posteriors before "
        "computing the final segmentations. No "
        "smoothing is performed if set to None")

    #: If True save image and segmentations for one image in a batch for each training epoch
    store_dataset_sample: bool = param.Boolean(
        False,
        doc="If True save image and segmentations for one image"
        "in a batch for each training epoch")

    #: List of (name, container) pairs, where name is a descriptive name and container is an Azure ML storage account
    #: container name to be used for statistical comparisons
    comparison_blob_storage_paths: List[Tuple[str, str]] = param.List(
        None,
        class_=tuple,
        allow_None=True,
        doc=
        "List of (name, container) pairs, where name is a descriptive name and container is a "
        "Azure ML storage account container name to be used for statistical comparisons"
    )

    #: List of rules for structures that should be prevented from sharing the same slice.
    #: These are not applied if :attr:`disable_extra_postprocessing` is True.
    #: Parameter should be a list of :attr:`SliceExclusionRule` objects.
    slice_exclusion_rules: List[SliceExclusionRule] = param.List(
        default=[],
        class_=SliceExclusionRule,
        allow_None=False,
        doc=
        "List of rules for structures that should be prevented from sharing the same slice; "
        "not applied if disable_extra_postprocessing is True.")

    #: List of rules for class pairs whose summed probability is used to create the segmentation map from predicted
    #: posterior probabilities.
    #: These are not applied if :attr:`disable_extra_postprocessing` is True.
    #: Parameter should be a list of :attr:`SummedProbabilityRule` objects.
    summed_probability_rules: List[SummedProbabilityRule] = param.List(
        default=[],
        class_=SummedProbabilityRule,
        allow_None=False,
        doc=
        "List of rules for class pairs whose summed probability is used to create the segmentation map from "
        "predicted posterior probabilities; not applied if disable_extra_postprocessing is True."
    )

    #: Whether to ignore :attr:`slice_exclusion_rules` and :attr:`summed_probability_rules` even if defined
    disable_extra_postprocessing: bool = param.Boolean(
        False,
        doc=
        "Whether to ignore slice_exclusion_rules and summed_probability_rules even if defined"
    )

    #: User friendly display names to be used for each of the predicted GT classes. Default is ground_truth_ids if
    #: None provided
    ground_truth_ids_display_names: List[str] = param.List(
        None,
        class_=str,
        bounds=(1, None),
        instantiate=False,
        allow_None=True,
        doc="User friendly display names to be used for each of "
        "the predicted GT classes. Default is ground_truth_ids "
        "if None provided")

    #: Colours in (R, G, B) for the structures, same order as in ground_truth_ids_display_names
    colours: List[TupleInt3] = param.List(
        None,
        class_=tuple,
        bounds=(1, None),
        instantiate=False,
        allow_None=True,
        doc="Colours in (R, G, B) for the structures, same order as in "
        "ground_truth_ids_display_names")

    #: List of bool specifying if structures need filling holes. If True, the output of the model for that class
    #: will include postprocessing to fill holes, in the same order as in ground_truth_ids_display_names
    fill_holes: List[bool] = param.List(
        None,
        class_=bool,
        bounds=(1, None),
        instantiate=False,
        allow_None=True,
        doc="List of bool specifiying if structures need filling holes. If True "
        "output of the model for that class includes postprocessing to fill holes, "
        "in the same order as in ground_truth_ids_display_names")

    roi_interpreted_types: List[str] = param.List(
        None,
        class_=str,
        bounds=(1, None),
        instantiate=False,
        allow_None=True,
        doc="List of str with the ROI interpreted Types. Possible values "
        "(None, CTV, ORGAN, EXTERNAL)")

    interpreter: str = param.String(
        "Default_Interpreter",
        doc="The interpreter that created the DICOM-RT file")

    manufacturer: str = param.String(
        "Default_Manufacturer",
        doc="The manufacturer that created the DICOM-RT file")

    _inference_stride_size: Optional[TupleInt3] = IntTuple(
        None,
        length=3,
        allow_None=True,
        doc="The stride size in the inference pipeline. "
        "At most, this should be the output_size to "
        "avoid gaps in output posterior image. If it "
        "is not specified, its value is set to "
        "output size.")
    _center_size: Optional[TupleInt3] = IntTuple(None,
                                                 length=3,
                                                 allow_None=True)
    _train_output_size: Optional[TupleInt3] = IntTuple(None,
                                                       length=3,
                                                       allow_None=True)
    _test_output_size: Optional[TupleInt3] = IntTuple(None,
                                                      length=3,
                                                      allow_None=True)

    #: Dictionary of types to enforce for certain DataFrame columns, where key is column name and value is desired type.
    col_type_converters: Optional[Dict[str, Any]] = param.Dict(
        None,
        doc="Dictionary of types to enforce for certain "
        "DataFrame columns, where key is column name "
        "and value is desired type.",
        allow_None=True,
        instantiate=False)

    _largest_connected_component_foreground_classes: LARGEST_CC_TYPE = \
        param.List(None, class_=None, bounds=(1, None), instantiate=False, allow_None=True,
                   doc="The names of the ground truth channels for which to select the largest connected component in "
                       "the model predictions as an inference post-processing step. Alternatively, a member of the "
                       "list can be a tuple (name, threshold), where name is a channel name and threshold is a value "
                       "between 0 and 0.5 such that disconnected components will be kept if their volume (relative "
                       "to the whole structure) exceeds that value.")

    #: If true, various overview plots with results are generated during model evaluation. Set to False if you see
    #: non-deterministic pull request build failures.
    is_plotting_enabled: bool = param.Boolean(
        True,
        doc="If true, various overview plots with results are generated "
        "during model evaluation. Set to False if you see "
        "non-deterministic pull request build failures.")
    show_patch_sampling: int = param.Integer(
        1,
        bounds=(0, None),
        doc="Number of patients from the training set for which the effect of"
        "patch sampling will be shown. Nifti images and thumbnails for each"
        "of the first N subjects in the training set will be "
        "written to the outputs folder.")

    def __init__(self,
                 center_size: Optional[TupleInt3] = None,
                 inference_stride_size: Optional[TupleInt3] = None,
                 min_l_rate: float = 0,
                 largest_connected_component_foreground_classes:
                 LARGEST_CC_TYPE = None,
                 **params: Any):
        super().__init__(**params)
        self.test_crop_size = self.test_crop_size if self.test_crop_size is not None else self.crop_size
        self.inference_stride_size = inference_stride_size
        self.min_l_rate = min_l_rate
        self.largest_connected_component_foreground_classes = largest_connected_component_foreground_classes
        self._center_size = center_size
        self._model_category = ModelCategory.Segmentation

    def validate(self) -> None:
        """
        Validates the parameters stored in the present object.
        """
        super().validate()
        check_is_any_of("Architecture", self.architecture,
                        vars(ModelArchitectureConfig).keys())

        def len_or_zero(lst: Optional[List[Any]]) -> int:
            return 0 if lst is None else len(lst)

        if self.kernel_size % 2 == 0:
            raise ValueError(
                "The kernel size must be an odd number (kernel_size: {})".
                format(self.kernel_size))

        if self.architecture != ModelArchitectureConfig.UNet3D:
            if any_pairwise_larger(self.center_size, self.crop_size):
                raise ValueError(
                    "Each center_size should be less than or equal to the crop_size "
                    "(center_size: {}, crop_size: {}".format(
                        self.center_size, self.crop_size))
        else:
            if self.crop_size != self.center_size:
                raise ValueError(
                    "For UNet3D, the center size of each dimension should be equal to the crop size "
                    "(center_size: {}, crop_size: {}".format(
                        self.center_size, self.crop_size))

        self.validate_inference_stride_size(self.inference_stride_size,
                                            self.get_output_size())

        # check to make sure there is no overlap between image and ground-truth channels
        image_gt_intersect = np.intersect1d(self.image_channels,
                                            self.ground_truth_ids)
        if len(image_gt_intersect) != 0:
            raise ValueError(
                "Channels: {} were found in both image_channels, and ground_truth_ids"
                .format(image_gt_intersect))

        valid_norm_methods = [
            method.value for method in PhotometricNormalizationMethod
        ]
        check_is_any_of("norm_method", self.norm_method.value,
                        valid_norm_methods)

        if len(self.trim_percentiles
               ) < 2 or self.trim_percentiles[0] >= self.trim_percentiles[1]:
            raise ValueError(
                "Thresholds should contain lower and upper percentile thresholds, but got: {}"
                .format(self.trim_percentiles))

        if len_or_zero(self.class_weights) != (
                len_or_zero(self.ground_truth_ids) + 1):
            raise ValueError(
                "class_weights needs to be equal to number of ground_truth_ids + 1"
            )
        if self.class_weights is None:
            raise ValueError("class_weights must be set.")
        SegmentationModelBase.validate_class_weights(self.class_weights)
        if self.ground_truth_ids is None:
            raise ValueError("ground_truth_ids is None")
        if len(self.ground_truth_ids_display_names) != len(
                self.ground_truth_ids):
            raise ValueError(
                "len(ground_truth_ids_display_names)!=len(ground_truth_ids)")
        if len(self.ground_truth_ids_display_names) != len(self.colours):
            raise ValueError(
                "len(ground_truth_ids_display_names)!=len(colours)")
        if len(self.ground_truth_ids_display_names) != len(self.fill_holes):
            raise ValueError(
                "len(ground_truth_ids_display_names)!=len(fill_holes)")
        if self.mean_teacher_alpha is not None:
            raise ValueError(
                "Mean teacher model is currently only supported for ScalarModels."
                "Please reset mean_teacher_alpha to None.")
        if not self.disable_extra_postprocessing:
            if self.slice_exclusion_rules is not None:
                for rule in self.slice_exclusion_rules:
                    rule.validate(self.ground_truth_ids)
            if self.summed_probability_rules is not None:
                for rule in self.summed_probability_rules:
                    rule.validate(self.ground_truth_ids)

    @staticmethod
    def validate_class_weights(class_weights: List[float]) -> None:
        """
        Checks that the given list of class weights is valid: The weights must be positive and add up to 1.0.
        Raises a ValueError if that is not the case.
        """
        if not isclose(sum(class_weights), 1.0):
            raise ValueError(
                f'class_weights needs to add to 1 but it was: {sum(class_weights)}'
            )
        if np.any(np.array(class_weights) < 0):
            raise ValueError(
                "class_weights must have non-negative values only, found: {}".
                format(class_weights))
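
    # Illustrative checks (not part of the original class):
    #   validate_class_weights([0.05, 0.45, 0.5])  passes: non-negative, sums to 1.0
    #   validate_class_weights([0.2, 0.9])         raises ValueError (sum is 1.1)
    #   validate_class_weights([-0.1, 1.1])        raises ValueError (negative weight)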

    @staticmethod
    def validate_inference_stride_size(
            inference_stride_size: Optional[TupleInt3],
            output_size: Optional[TupleInt3]) -> None:
        """
        Checks that patch stride size is positive and smaller than output patch size to ensure that posterior
        predictions are obtained for all pixels
        """
        if inference_stride_size is not None:
            if any_smaller_or_equal_than(inference_stride_size, 0):
                raise ValueError(
                    "inference_stride_size must be > 0 in all dimensions, found: {}"
                    .format(inference_stride_size))

            if output_size is not None:
                if any_pairwise_larger(inference_stride_size, output_size):
                    raise ValueError(
                        "inference_stride_size must be <= output_size in all dimensions"
                        "Found: output_size={}, inference_stride_size={}".
                        format(output_size, inference_stride_size))
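        # Illustrative example (not in the original source): with output_size
        # (64, 64, 64) a stride of (32, 32, 32) is accepted, while (0, 32, 32)
        # or (80, 64, 64) raise a ValueError, because the stride must be positive
        # and no larger than the output patch in every dimension.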

    @property
    def number_of_image_channels(self) -> int:
        """
        Gets the number of image input channels that the model has (usually 1 CT channel, or multiple MR).
        """
        return 0 if self.image_channels is None else len(self.image_channels)

    @property
    def number_of_classes(self) -> int:
        """
        Returns the number of ground truth ids, including the background class.
        """
        return 1 if self.ground_truth_ids is None else len(
            self.ground_truth_ids) + 1

    @property
    def center_size(self) -> TupleInt3:
        """
        Gets the size of the center crop that the model predicts.
        """
        if self._center_size is None:
            return get_center_size(arch=self.architecture,
                                   crop_size=self.crop_size)
        logging.warning(
            "'center_size' argument will soon be deprecated. Output shapes are inferred from models on the fly."
        )
        return self._center_size

    @property
    def inference_stride_size(self) -> Optional[TupleInt3]:
        """
        Gets the stride size that should be used when stitching patches at inference time.
        """
        if self._inference_stride_size is None:
            return self.get_output_size(ModelExecutionMode.TEST)
        return self._inference_stride_size

    @inference_stride_size.setter
    def inference_stride_size(self, val: Optional[TupleInt3]) -> None:
        """
        Sets the inference stride size with given value. This setter is used if output shape needs to be
        determined dynamically at run time
        """
        self._inference_stride_size = val
        self.validate_inference_stride_size(inference_stride_size=val,
                                            output_size=self.get_output_size(
                                                ModelExecutionMode.TEST))

    @property
    def example_images_folder(self) -> Path:
        """
        Gets the full path in which the example images should be stored during training.
        """
        return self.outputs_folder / EXAMPLE_IMAGES_FOLDER

    @property
    def largest_connected_component_foreground_classes(
            self) -> LARGEST_CC_TYPE:
        """
        Gets the list of classes for which the largest connected components should be computed when predicting.
        """
        return self._largest_connected_component_foreground_classes

    @largest_connected_component_foreground_classes.setter
    def largest_connected_component_foreground_classes(
            self, value: LARGEST_CC_TYPE) -> None:
        """
        Sets the list of classes for which the largest connected components should be computed when predicting.
        """
        pairs: Optional[List[Tuple[str, Optional[float]]]] = None
        if value is not None:
            # Set all members to be tuples rather than just class names.
            pairs = [
                val if isinstance(val, tuple) else (val, None) for val in value
            ]
            class_names = set(pair[0] for pair in pairs)
            unknown_labels = class_names - set(self.ground_truth_ids)
            if unknown_labels:
                raise ValueError(
                    f"Found unknown labels {unknown_labels} in largest_connected_component_foreground_classes: "
                    f"labels must exist in [{self.ground_truth_ids}]")
            bad_thresholds = [
                pair[1] for pair in pairs
                if (pair[1] is not None) and (pair[1] <= 0.0 or pair[1] > 0.5)
            ]  # type: ignore
            if bad_thresholds:
                raise ValueError(
                    f"Found bad threshold(s) {bad_thresholds} in largest_connected_component_foreground_classes: "
                    "thresholds must be positive and at most 0.5.")

        self._largest_connected_component_foreground_classes = pairs

    def read_dataset_into_dataframe_and_pre_process(self) -> None:
        """
        Loads a dataset from the dataset_csv file, and stores it in the present object.
        """
        assert self.local_dataset is not None  # for mypy
        self.dataset_data_frame = pd.read_csv(
            self.local_dataset / self.dataset_csv,
            dtype=str,
            converters=self.col_type_converters,
            low_memory=False)
        self.pre_process_dataset_dataframe()

    def get_parameter_search_hyperdrive_config(
            self, run_config: ScriptRunConfig) -> HyperDriveConfig:
        """
        Turns the given AzureML estimator (settings for running a job in AzureML) into a configuration object
        for doing hyperparameter searches.

        :param run_config: The settings for running a single AzureML job.
        :return: A HyperDriveConfig object for running multiple AzureML jobs.
        """
        return super().get_parameter_search_hyperdrive_config(run_config)

    def get_model_train_test_dataset_splits(
            self, dataset_df: DataFrame) -> DatasetSplits:
        """
        Computes the training, validation and test splits for the model, from a dataframe that contains
        the full dataset.

        :param dataset_df: A dataframe that contains the full dataset that the model is using.
        :return: An instance of DatasetSplits with dataframes for training, validation and testing.
        """
        return super().get_model_train_test_dataset_splits(dataset_df)

    def get_output_size(
        self,
        execution_mode: ModelExecutionMode = ModelExecutionMode.TRAIN
    ) -> Optional[TupleInt3]:
        """
        Returns shape of model's output tensor for training, validation and testing inference modes
        """
        if execution_mode in (ModelExecutionMode.TRAIN, ModelExecutionMode.VAL):
            return self._train_output_size
        elif execution_mode == ModelExecutionMode.TEST:
            return self._test_output_size
        raise ValueError(
            "Unknown execution mode '{}' for function 'get_output_size'".
            format(execution_mode))

    def set_derived_model_properties(self, model: Any) -> None:
        """
        Updates the model config parameters that depend on how the segmentation model is structured.
        In particular, this computes the model's output size for the training and the inference crops.
        If the inference stride size is not set, then set it to be equal to the size of the inference output patches.
        """
        logging.info(
            f"Computing model output size when fed with training crops of size {self.crop_size}"
        )
        self._train_output_size = model.get_output_shape(
            input_shape=self.crop_size)
        logging.info(
            f"Computing model output size when fed with inference crops of size {self.test_crop_size}"
        )
        self._test_output_size = model.get_output_shape(
            input_shape=self.test_crop_size)
        if self.inference_stride_size is None:
            self.inference_stride_size = self._test_output_size
        else:
            if any_pairwise_larger(self.inference_stride_size,
                                   self._test_output_size):
                raise ValueError(
                    f"The inference stride size {self.inference_stride_size} must be smaller than the "
                    f"model's output size {self._test_output_size} in each dimension."
                )

    def class_and_index_with_background(self) -> Dict[str, int]:
        """
        Returns a dict of class names to indices, including the background class.
        The class index assumes that background is class 0, foreground starts at 1.
        For example, if the ground_truth_ids are ["foo", "bar"], the result
        is {"background": 0, "foo": 1, "bar": 2}

        :return: A dict with one entry per entry in ground_truth_ids, plus one for the background class.
        """
        classes = {BACKGROUND_CLASS_NAME: 0}
        classes.update({x: i + 1 for i, x in enumerate(self.ground_truth_ids)})
        return classes

    def create_and_set_torch_datasets(self,
                                      for_training: bool = True,
                                      for_inference: bool = True) -> None:
        """
        Creates torch datasets for all model execution modes, and stores them in the object.
        """
        from InnerEye.ML.dataset.cropping_dataset import CroppingDataset
        from InnerEye.ML.dataset.full_image_dataset import FullImageDataset

        dataset_splits = self.get_dataset_splits()
        crop_transforms = self.get_cropped_image_sample_transforms()
        full_image_transforms = self.get_full_image_sample_transforms()
        if for_training:
            self._datasets_for_training = {
                ModelExecutionMode.TRAIN: CroppingDataset(
                    self,
                    dataset_splits.train,
                    cropped_sample_transforms=crop_transforms.train,  # type: ignore
                    full_image_sample_transforms=full_image_transforms.train),  # type: ignore
                ModelExecutionMode.VAL: CroppingDataset(
                    self,
                    dataset_splits.val,
                    cropped_sample_transforms=crop_transforms.val,  # type: ignore
                    full_image_sample_transforms=full_image_transforms.val),  # type: ignore
            }
        if for_inference:
            self._datasets_for_inference = {
                mode: FullImageDataset(
                    self,
                    dataset_splits[mode],
                    full_image_sample_transforms=full_image_transforms.test
                )  # type: ignore
                for mode in ModelExecutionMode if len(dataset_splits[mode]) > 0
            }

    def create_model(self) -> Any:
        """
        Creates a PyTorch model from the settings stored in the present object.
        :return: The network model as a torch.nn.Module object
        """
        # Use a local import here to avoid reliance on pytorch too early.
        # Return type should be BaseModel, but that would also introduce reliance on pytorch.
        from InnerEye.ML.utils.model_util import build_net
        return build_net(self)

    def get_full_image_sample_transforms(
            self) -> ModelTransformsPerExecutionMode:
        """
        Get transforms to perform on full image samples for each model execution mode.
        By default only PhotometricNormalization is performed.
        """
        from InnerEye.ML.utils.transforms import Compose3D
        from InnerEye.ML.photometric_normalization import PhotometricNormalization

        photometric_transformation = Compose3D(
            transforms=[PhotometricNormalization(self, use_gpu=False)])
        return ModelTransformsPerExecutionMode(
            train=photometric_transformation,
            val=photometric_transformation,
            test=photometric_transformation)

    def get_cropped_image_sample_transforms(
            self) -> ModelTransformsPerExecutionMode:
        """
        Get transforms to perform on cropped samples for each model execution mode.
        By default no transformation is performed.
        """
        return ModelTransformsPerExecutionMode()
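
The class_and_index_with_background and number_of_classes members above define how InnerEye lays out class indices. A minimal, self-contained sketch of that layout (illustrative only, not the InnerEye API; the structure names are made up):

ground_truth_ids = ["spleen", "liver"]               # hypothetical structure names
classes = {"background": 0}
classes.update({name: i + 1 for i, name in enumerate(ground_truth_ids)})
assert classes == {"background": 0, "spleen": 1, "liver": 2}
assert len(ground_truth_ids) + 1 == 3                # what number_of_classes returns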
Example #6
class ColorbarPlot(ElementPlot):

    clabel = param.String(default=None,
                          doc="""
        An explicit override of the color bar label, if set takes precedence
        over the title key in colorbar_opts.""")

    clim = param.NumericTuple(default=(np.nan, np.nan),
                              length=2,
                              doc="""
       User-specified colorbar axis range limits for the plot, as a tuple (low,high).
       If specified, takes precedence over data and dimension ranges.""")

    colorbar = param.Boolean(default=False,
                             doc="""
        Whether to draw a colorbar.""")

    color_levels = param.ClassSelector(default=None,
                                       class_=(int, list),
                                       doc="""
        Number of discrete colors to use when colormapping or a set of color
        intervals defining the range of values to map each color to.""")

    clipping_colors = param.Dict(default={},
                                 doc="""
        Dictionary to specify colors for clipped values, allows
        setting color for NaN values and for values above and below
        the min and max value. The min, max or NaN color may specify
        an RGB(A) color as a color hex string of the form #FFFFFF or
        #FFFFFFFF or a length 3 or length 4 tuple specifying values in
        the range 0-1 or a named HTML color.""")

    cbar_padding = param.Number(default=0.01,
                                doc="""
        Padding between colorbar and other plots.""")

    cbar_ticks = param.Parameter(default=None,
                                 doc="""
        Ticks along colorbar-axis specified as an integer, explicit
        list of tick locations, list of tuples containing the
        locations and labels or a matplotlib tick locator object. If
        set to None default matplotlib ticking behavior is
        applied.""")

    cbar_width = param.Number(default=0.05,
                              doc="""
        Width of the colorbar as a fraction of the main plot""")

    symmetric = param.Boolean(default=False,
                              doc="""
        Whether to make the colormap symmetric around zero.""")

    _colorbars = {}

    _default_nan = '#8b8b8b'

    def __init__(self, *args, **kwargs):
        super(ColorbarPlot, self).__init__(*args, **kwargs)
        self._cbar_extend = 'neither'

    def _adjust_cbar(self, cbar, label, dim):
        noalpha = math.floor(self.style[self.cyclic_index].get('alpha',
                                                               1)) == 1
        if (cbar.solids and noalpha):
            cbar.solids.set_edgecolor("face")
        cbar.set_label(label)
        if isinstance(self.cbar_ticks, ticker.Locator):
            cbar.ax.yaxis.set_major_locator(self.cbar_ticks)
        elif self.cbar_ticks == 0:
            cbar.set_ticks([])
        elif isinstance(self.cbar_ticks, int):
            locator = ticker.MaxNLocator(self.cbar_ticks)
            cbar.ax.yaxis.set_major_locator(locator)
        elif isinstance(self.cbar_ticks, list):
            if all(isinstance(t, tuple) for t in self.cbar_ticks):
                ticks, labels = zip(*self.cbar_ticks)
            else:
                ticks, labels = zip(*[(t, dim.pprint_value(t))
                                      for t in self.cbar_ticks])
            cbar.set_ticks(ticks)
            cbar.set_ticklabels(labels)

    def _finalize_artist(self, element):
        if self.colorbar:
            dims = [
                h for k, h in self.handles.items() if k.endswith('color_dim')
            ]
            for d in dims:
                self._draw_colorbar(element, d)

    def _draw_colorbar(self, element=None, dimension=None, redraw=True):
        if element is None:
            element = self.hmap.last
        artist = self.handles.get('artist', None)
        fig = self.handles['fig']
        axis = self.handles['axis']
        ax_colorbars, position = ColorbarPlot._colorbars.get(
            id(axis), ([], None))
        specs = [spec[:2] for _, _, spec, _ in ax_colorbars]
        spec = util.get_spec(element)

        if position is None or not redraw:
            if redraw:
                fig.canvas.draw()
            bbox = axis.get_position()
            l, b, w, h = bbox.x0, bbox.y0, bbox.width, bbox.height
        else:
            l, b, w, h = position

        # Get colorbar label
        if isinstance(dimension, dim):
            dimension = dimension.dimension
        dimension = element.get_dimension(dimension)
        if self.clabel:
            label = self.clabel
        elif dimension:
            label = dimension.pprint_label
        elif element.vdims:
            label = element.vdims[0].pprint_label
        elif dimension is None:
            label = ''

        padding = self.cbar_padding
        width = self.cbar_width
        if spec[:2] not in specs:
            offset = len(ax_colorbars)
            scaled_w = w * width
            cax = fig.add_axes([
                l + w + padding + (scaled_w + padding + w * 0.15) * offset, b,
                scaled_w, h
            ])
            cbar = fig.colorbar(artist,
                                cax=cax,
                                ax=axis,
                                extend=self._cbar_extend)
            self._adjust_cbar(cbar, label, dimension)
            self.handles['cax'] = cax
            self.handles['cbar'] = cbar
            ylabel = cax.yaxis.get_label()
            self.handles['bbox_extra_artists'] += [cax, ylabel]
            ax_colorbars.append((artist, cax, spec, label))

        for i, (artist, cax, spec, label) in enumerate(ax_colorbars):
            scaled_w = w * width
            cax.set_position([
                l + w + padding + (scaled_w + padding + w * 0.15) * i, b,
                scaled_w, h
            ])

        ColorbarPlot._colorbars[id(axis)] = (ax_colorbars, (l, b, w, h))

    def _norm_kwargs(self,
                     element,
                     ranges,
                     opts,
                     vdim,
                     values=None,
                     prefix=''):
        """
        Returns valid color normalization kwargs
        to be passed to matplotlib plot function.
        """
        dim_name = dim_range_key(vdim)
        if values is None:
            if isinstance(vdim, dim):
                values = vdim.apply(element, flat=True)
            else:
                expanded = not (
                    isinstance(element, Dataset) and element.interface.multi
                    and (getattr(element, 'level', None) is not None
                         or element.interface.isscalar(element, vdim.name)))
                values = np.asarray(
                    element.dimension_values(vdim, expanded=expanded))

        # Store dimension being colormapped for colorbars
        if prefix + 'color_dim' not in self.handles:
            self.handles[prefix + 'color_dim'] = vdim

        clim = opts.pop(prefix + 'clims', None)

        # check if there's an actual value (not np.nan)
        if clim is None and util.isfinite(self.clim).all():
            clim = self.clim

        if clim is None:
            if not len(values):
                clim = (0, 0)
                categorical = False
            elif values.dtype.kind in 'uif':
                if dim_name in ranges:
                    clim = ranges[dim_name]['combined']
                elif isinstance(vdim, dim):
                    if values.dtype.kind == 'M':
                        clim = values.min(), values.max()
                    elif len(values) == 0:
                        clim = np.NaN, np.NaN
                    else:
                        try:
                            with warnings.catch_warnings():
                                warnings.filterwarnings(
                                    'ignore',
                                    r'All-NaN (slice|axis) encountered')
                                clim = (np.nanmin(values), np.nanmax(values))
                        except:
                            clim = np.NaN, np.NaN
                else:
                    clim = element.range(vdim)
                if self.logz:
                    # Lower clim must be >0 when logz=True
                    # Choose the maximum between the lowest non-zero value
                    # and the overall range
                    if clim[0] == 0:
                        clim = (values[values != 0].min(), clim[1])
                if self.symmetric:
                    clim = -np.abs(clim).max(), np.abs(clim).max()
                categorical = False
            else:
                range_key = dim_range_key(vdim)
                if range_key in ranges and 'factors' in ranges[range_key]:
                    factors = ranges[range_key]['factors']
                else:
                    factors = util.unique_array(values)
                clim = (0, len(factors) - 1)
                categorical = True
        else:
            categorical = values.dtype.kind not in 'uif'

        if self.logz:
            if self.symmetric:
                norm = mpl_colors.SymLogNorm(vmin=clim[0],
                                             vmax=clim[1],
                                             linthresh=clim[1] / np.e)
            else:
                norm = mpl_colors.LogNorm(vmin=clim[0], vmax=clim[1])
            opts[prefix + 'norm'] = norm
        opts[prefix + 'vmin'] = clim[0]
        opts[prefix + 'vmax'] = clim[1]

        cmap = opts.get(prefix + 'cmap', opts.get('cmap', 'viridis'))
        if values.dtype.kind not in 'OSUM':
            ncolors = None
            if isinstance(self.color_levels, int):
                ncolors = self.color_levels
            elif isinstance(self.color_levels, list):
                ncolors = len(self.color_levels) - 1
                if isinstance(cmap, list) and len(cmap) != ncolors:
                    raise ValueError(
                        'The number of colors in the colormap '
                        'must match the intervals defined in the '
                        'color_levels, expected %d colors found %d.' %
                        (ncolors, len(cmap)))
            try:
                el_min, el_max = np.nanmin(values), np.nanmax(values)
            except ValueError:
                el_min, el_max = -np.inf, np.inf
        else:
            ncolors = clim[-1] + 1
            el_min, el_max = -np.inf, np.inf
        vmin = -np.inf if opts[prefix + 'vmin'] is None else opts[prefix + 'vmin']
        vmax = np.inf if opts[prefix + 'vmax'] is None else opts[prefix + 'vmax']
        if el_min < vmin and el_max > vmax:
            self._cbar_extend = 'both'
        elif el_min < vmin:
            self._cbar_extend = 'min'
        elif el_max > vmax:
            self._cbar_extend = 'max'

        # Define special out-of-range colors on colormap
        colors = {}
        for k, val in self.clipping_colors.items():
            if val == 'transparent':
                colors[k] = {'color': 'w', 'alpha': 0}
            elif isinstance(val, tuple):
                colors[k] = {
                    'color': val[:3],
                    'alpha': val[3] if len(val) > 3 else 1
                }
            elif isinstance(val, util.basestring):
                color = val
                alpha = 1
                if color.startswith('#') and len(color) == 9:
                    alpha = int(color[-2:], 16) / 255.
                    color = color[:-2]
                colors[k] = {'color': color, 'alpha': alpha}

        if not isinstance(cmap, mpl_colors.Colormap):
            if isinstance(cmap, dict):
                factors = util.unique_array(values)
                palette = [
                    cmap.get(
                        f,
                        colors.get('NaN',
                                   {'color': self._default_nan})['color'])
                    for f in factors
                ]
            else:
                palette = process_cmap(cmap, ncolors, categorical=categorical)
                if isinstance(self.color_levels, list):
                    palette, (vmin, vmax) = color_intervals(palette,
                                                            self.color_levels,
                                                            clip=(vmin, vmax))
            cmap = mpl_colors.ListedColormap(palette)
        if 'max' in colors: cmap.set_over(**colors['max'])
        if 'min' in colors: cmap.set_under(**colors['min'])
        if 'NaN' in colors: cmap.set_bad(**colors['NaN'])
        opts[prefix + 'cmap'] = cmap
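
A hedged usage sketch for the listing above: these ColorbarPlot parameters are normally set as plot options on a HoloViews element rendered with the matplotlib backend, rather than by instantiating the class directly. The option names below (colorbar, clim, symmetric, clipping_colors) mirror the parameters declared above; the data is arbitrary.

import numpy as np
import holoviews as hv
hv.extension('matplotlib')

img = hv.Image(np.random.randn(50, 50))
img = img.opts(colorbar=True, clim=(-2, 2), symmetric=True,
               clipping_colors={'min': 'blue', 'max': 'red', 'NaN': '#8b8b8b'})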
Example #7
File: stats.py  Project: zzwei1/holoviews
class univariate_kde(Operation):
    """
    Computes a 1D kernel density estimate (KDE) along the supplied
    dimension. Kernel density estimation is a non-parametric way to
    estimate the probability density function of a random variable.

    The KDE works by placing a Gaussian kernel at each sample with
    the supplied bandwidth. These kernels are then summed to produce
    the density estimate. By default a good bandwidth is determined
    using the bw_method but it may be overridden by an explicit value.
    """

    bw_method = param.ObjectSelector(default='scott',
                                     objects=['scott', 'silverman'],
                                     doc="""
        Method of automatically determining KDE bandwidth.""")

    bandwidth = param.Number(default=None,
                             doc="""
        Allows supplying an explicit bandwidth value rather than relying on the scott or silverman method."""
                             )

    cut = param.Number(default=3,
                       doc="""
        Draw the estimate to cut * bw from the extreme data points.""")

    bin_range = param.NumericTuple(default=None,
                                   length=2,
                                   doc="""
        Specifies the range within which to compute the KDE.""")

    dimension = param.String(default=None,
                             doc="""
        Along which dimension of the Element to compute the KDE.""")

    filled = param.Boolean(default=True,
                           doc="""
        Controls whether to return filled or unfilled KDE.""")

    n_samples = param.Integer(default=100,
                              doc="""
        Number of samples to compute the KDE over.""")

    groupby = param.ClassSelector(default=None,
                                  class_=(basestring, Dimension),
                                  doc="""
      Defines a dimension to group by, returning an NdOverlay of density estimates."""
                                  )

    def _process(self, element, key=None):
        if self.p.groupby:
            if not isinstance(element, Dataset):
                raise ValueError(
                    'Cannot use histogram groupby on non-Dataset Element')
            grouped = element.groupby(self.p.groupby,
                                      group_type=Dataset,
                                      container_type=NdOverlay)
            self.p.groupby = None
            return grouped.map(self._process, Dataset)

        try:
            from scipy import stats
        except ImportError:
            raise ImportError('%s operation requires SciPy to be installed.' %
                              type(self).__name__)

        params = {}
        if isinstance(element, Distribution):
            selected_dim = element.kdims[0]
            if element.group != type(element).__name__:
                params['group'] = element.group
            params['label'] = element.label
            vdim = element.vdims[0]
            vdim_name = '{}_density'.format(selected_dim.name)
            vdim_label = '{} Density'.format(selected_dim.label)
            vdims = [
                vdim(vdim_name, label=vdim_label)
                if vdim.name == 'Density' else vdim
            ]
        else:
            if self.p.dimension:
                selected_dim = element.get_dimension(self.p.dimension)
            else:
                dimensions = element.vdims + element.kdims
                if not dimensions:
                    raise ValueError(
                        "%s element does not declare any dimensions "
                        "to compute the kernel density estimate on." %
                        type(element).__name__)
                selected_dim = dimensions[0]
            vdim_name = '{}_density'.format(selected_dim.name)
            vdim_label = '{} Density'.format(selected_dim.label)
            vdims = [Dimension(vdim_name, label=vdim_label)]

        data = element.dimension_values(selected_dim)
        bin_range = self.p.bin_range or element.range(selected_dim)
        if bin_range == (0, 0) or any(not np.isfinite(r) for r in bin_range):
            bin_range = (0, 1)
        elif bin_range[0] == bin_range[1]:
            bin_range = (bin_range[0] - 0.5, bin_range[1] + 0.5)

        data = data[np.isfinite(data)] if len(data) else []
        if len(data) > 1:
            kde = stats.gaussian_kde(data)
            if self.p.bandwidth:
                kde.set_bandwidth(self.p.bandwidth)
            bw = kde.scotts_factor() * data.std(ddof=1)
            if self.p.bin_range:
                xs = np.linspace(bin_range[0], bin_range[1], self.p.n_samples)
            else:
                xs = _kde_support(bin_range, bw, self.p.n_samples, self.p.cut,
                                  selected_dim.range)
            ys = kde.evaluate(xs)
        else:
            xs = np.linspace(bin_range[0], bin_range[1], self.p.n_samples)
            ys = np.full_like(xs, 0)

        element_type = Area if self.p.filled else Curve
        return element_type((xs, ys),
                            kdims=[selected_dim],
                            vdims=vdims,
                            **params)
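
A hedged usage sketch for the operation above: in HoloViews this operation is importable from holoviews.operation.stats and requires SciPy. It turns an element such as a Distribution into a filled Area (or an unfilled Curve) holding the estimated density.

import numpy as np
import holoviews as hv
from holoviews.operation.stats import univariate_kde

dist = hv.Distribution(np.random.randn(1000))
density = univariate_kde(dist, bandwidth=0.3, n_samples=200, filled=False)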
Example #8
class ModelCGCAL(ModelGCAL):
    """
    A continuous version of GCAL, changing as few settings as
    necessary to obtain a continuous model of time.
    """

    GCAL_sequence = param.Boolean(default=True,
                                  doc="""
        Whether to match the GCAL training sequence or not (for
        Gaussians).

        Relies on bad, fragile hacks and is only required for
        comparison purposes and debugging.""")

    continuous = param.Boolean(default=True,
                               doc="""
       Switch between continuous GCAL and regular GCAL for debugging
       purposes and to compare behaviour.""")

    timescale = param.Number(default=240.0,
                             constant=True,
                             doc="""
       Multiplicative factor between simulation time and milliseconds:

          milliseconds = topo.sim.time() * timescale.

       NOTE: This is a conversion factor and is NOT topo.sim.time()!""")

    saccade_duration = param.Integer(default=240,
                                     doc="""
        The *maximum* length of a saccade, i.e. the largest possible
        multiple of timestep that fits within this duration will be
        used. This parameter is used to compute a concrete value
        of the saccade_duration property.

        NOTE: The notion of a 'saccade' is what links earlier models
        with activity resets (such as GCAL) to a continuous
        timebase. In short, this parameter corresponds to the temporal
        duration used to present each training stimulus.

        The notion of a saccade is appropriate for primate visual
        development but is not *necessary* in continuous models.""")

    timestep = param.Integer(default=12,
                             doc="""
        The simulation time (milliseconds) used to 'clock' the model.""")

    # Delays and time constants

    lgn_afferent_delay = param.Number(default=12.0,
                                      doc="""
       The afferent delay (milliseconds) from the retina to the LGN.""")

    v1_afferent_delay = param.Number(default=12.0,
                                     doc="""
       The afferent delay (milliseconds) from the LGN to V1.""")

    lgn_hysteresis = param.Number(default=0.05,
                                  allow_None=True,
                                  doc="""
        The time constant for the LGN sheet (per millisecond) if supplied.""")

    v1_hysteresis = param.Number(default=None,
                                 allow_None=True,
                                 doc="""
        The time constant for the V1 sheet (per millisecond) if supplied.""")

    # Parameters affecting LGN PSTH profiles

    gain_control_delay = param.Number(default=12.0,
                                      doc="""
        Millisecond delay of lateral gain-control projections in the
        LGN. Primary parameter controlling shape of LGN PSTHs.""")

    gain_control_strength = param.Number(default=0.6,
                                         doc="""
        Strength of lateral gain-control projections in the
        LGN. Controls the overall strength of the gain control.""")

    # Inclusion of 'dr' in the list below is a hack - it is only there
    # to enable the 'period' parameter to the PatternCoordinator when
    # 'period' should be available regardless of the dimension used.
    dims = param.List(['xy', 'or', 'dr'],
                      doc="""
      The addition of 'dr' allows for direction map development (for
      non-zero speeds) and allows the reset period to be customized.""")

    speed = param.Number(default=0.0,
                         doc="""
      The speed of translation of the training patterns.""")

    # Existing parameters tweaked for tuning.

    lgn_aff_strength = param.Number(default=0.7,
                                    bounds=(0.0, None),
                                    doc="""
        Overall strength of the afferent projection to the LGN.

        Note: This parameter overrides the behaviour of the
        strength_factor parameter.""")

    aff_strength = param.Number(default=1.1,
                                doc="""
        Overall strength of the afferent projection to V1.

        Note: This parameter overrides the aff_strength default
        inherited from GCAL.""")

    learning_rate = param.Number(default=1.0,
                                 doc="""
       Overall scaling factor for projection learning rates where a
       value of 1.0 (default) doesn't modify the learning rates
       relative to GCAL.""")

    snapshot_learning = param.NumericTuple(default=None,
                                           allow_None=True,
                                           length=3,
                                           doc="""
                                           Three tuple e.g (240,130,0.051)""")

    kappa_bias = param.Number(default=None, allow_None=True)

    def __init__(self, **params):
        super(ModelCGCAL, self).__init__(**params)
        if (not self.continuous) and self.lgn_aff_strength == 0.7:
            self.lgn_aff_strength = 2.33
            self.warning(
                'Setting lgn_aff_strength from 0.7 to 2.33 as continuous=False'
            )
        if (not self.continuous) and self.aff_strength == 1.1:
            self.aff_strength = 1.5
            self.warning(
                'Setting aff_strength from 1.1 to 1.5 as continuous=False')

    def property_setup(self, properties):
        properties = super(ModelCGCAL, self).property_setup(properties)

        # In simulation time units
        properties['period'] = (self.timestep / self.timescale
                                if self.continuous else 1.0)
        properties['steps_per_saccade'] = (self.saccade_duration //
                                           self.timestep)
        properties['saccade_duration'] = (properties['steps_per_saccade'] *
                                          properties['period']
                                          if self.continuous else 1.0)

        # Disables multiple lagged projections (dim='dr')
        properties['lags'] = [0]
        return properties

    def training_pattern_setup(self, **overrides):
        assert not (self.GCAL_sequence and self.kappa_bias)
        if self.kappa_bias:
            OR_coordinator = VonMisesORCoordinator.instance(
                kappa=self.kappa_bias)
            feature_coordinators = OrderedDict([('xy',
                                                 [XCoordinator, YCoordinator]),
                                                ('or', OR_coordinator)])
            overrides['feature_coordinators'] = feature_coordinators

        return super(ModelCGCAL, self).training_pattern_setup(
            **dict(overrides, reset_period=self['saccade_duration']))

    #============================#
    # Temporal scaling equations #
    #============================#

    def projection_learning_rate(self, rate):
        """
        For afferent, excitatory V1 and inhibitory V1.
        """
        if not self.continuous: return rate
        if self.snapshot_learning: return rate
        return self.learning_rate * (rate / self['steps_per_saccade'])

    def homeostatic_learning_rate(self, rate=0.01):
        """
        A rate of 0.01 is default - can parameterize later if needed.

        Note: The learning rate only needs to be adjusted for the number
        of steps if homeostasis is continually applied.
        """
        if not self.homeostasis: return 0.0
        elif not self.continuous: return rate
        else: return rate / self['steps_per_saccade']

    def hysteresis_constant(self, constant):
        """
        Compute the *continuous* hysteresis constant for V1 and LGN.

        If constant is set to None, hysteresis is disabled.
        """
        if not self.continuous: return 1.0
        return 1.0 if constant is None else (constant * self.timestep)
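
    # Worked example with the defaults declared above (illustrative only):
    # period = timestep / timescale = 12 / 240 = 0.05 simulation-time units,
    # steps_per_saccade = saccade_duration // timestep = 240 // 12 = 20, so in
    # the continuous case projection_learning_rate(r) returns
    # learning_rate * r / 20, and delays such as lgn_afferent_delay map to
    # 12 / 240 = 0.05 simulation-time units.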

    #===================#
    # Sheet definitions #
    #===================#

    @Model.Continuous
    def LGN(self, properties):
        """
        Filters parameters based on sheet type and applies hysteresis
        for continuous models.

        Note that there is a difference between the two orderings:

        * [hysteresis, rectify]: Applies hysteresis on the voltage or
                                 some other real biophysical variable

        * [rectify, hysteresis]: Applies hysteresis to the
                                 positive-valued 'spiking' variable.

        """
        filtered = ['tsettle', 'strict_tsettle'] if self.continuous else []
        params = {
            k: v
            for k, v in super(ModelCGCAL, self).LGN(properties).items()
            if k not in filtered
        }

        lgn_time_constant = (self.hysteresis_constant(self.lgn_hysteresis)
                             if self.continuous else 1.0)
        hysteresis = transferfn.Hysteresis(time_constant=lgn_time_constant)
        return dict(params,
                    output_fns=[transferfn.misc.HalfRectify(), hysteresis])

    @Model.Continuous
    def V1(self, properties):
        parameters = super(ModelCGCAL, self).V1(properties)
        parameters = {
            k: v
            for k, v in parameters.items()
            if k not in (['tsettle'] if self.continuous else [])
        }

        v1_time_constant = (self.hysteresis_constant(self.v1_hysteresis)
                            if self.continuous else 1.0)
        hysteresis = transferfn.Hysteresis(time_constant=v1_time_constant)

        homeostasis = transferfn.misc.HomeostaticResponse(
            period=0.0,
            t_init=self.t_init,
            target_activity=self.target_activity,
            learning_rate=self.homeostatic_learning_rate())

        if self.snapshot_learning:
            parameters['snapshot_learning'] = self.snapshot_learning

        return dict(parameters, output_fns=[homeostasis])

    #========================#
    # Projection definitions #
    #========================#

    @Model.SharedWeightCFProjection
    def afferent(self, src_properties, dest_properties):
        "Names set to avoid duplicate projection names."
        name = ('AfferentOn' if dest_properties['polarity'] == 'On'
                else 'AfferentOff')
        lgn_aff = super(ModelCGCAL, self).afferent(src_properties,
                                                   dest_properties)
        return dict(lgn_aff,
                    name=name,
                    strength=self.lgn_aff_strength,
                    delay=self.lgn_afferent_delay / self.timescale)

    @Model.CFProjection
    def V1_afferent(self, src_properties, dest_properties):
        "Projection delay and learning rate modified"
        paramlist = super(ModelCGCAL,
                          self).V1_afferent(src_properties, dest_properties)[0]
        return dict(paramlist,
                    delay=self.v1_afferent_delay / self.timescale,
                    strength=self.aff_strength,
                    learning_rate=self.projection_learning_rate(self.aff_lr))

    @Model.SharedWeightCFProjection
    def lateral_gain_control(self, src_properties, dest_properties):
        "Projection delay modified"
        params = super(ModelCGCAL,
                       self).lateral_gain_control(src_properties,
                                                  dest_properties)
        return dict(params,
                    strength=self.gain_control_strength,
                    delay=self.gain_control_delay / self.timescale)

    @Model.CFProjection
    def lateral_excitatory(self, src_properties, dest_properties):
        "Projection delay and learning rate modified"
        params = super(ModelCGCAL,
                       self).lateral_excitatory(src_properties,
                                                dest_properties)
        return dict(params,
                    delay=self.timestep / self.timescale,
                    learning_rate=self.projection_learning_rate(self.exc_lr))

    @Model.CFProjection
    def lateral_inhibitory(self, src_properties, dest_properties):
        "Projection delay and learning rate modified"
        params = super(ModelCGCAL,
                       self).lateral_inhibitory(src_properties,
                                                dest_properties)
        return dict(params,
                    delay=self.timestep / self.timescale,
                    learning_rate=self.projection_learning_rate(self.inh_lr))

    def setup(self, *args, **params):
        spec = super(ModelCGCAL, self).setup(*args, **params)
        if self.GCAL_sequence and self.continuous:
            time_factor = int(self.saccade_duration / self.timescale)
            # By default, should be 240 for TCAL, 1 for GCAL
            apply_GCAL_training_sequence(spec, time_factor)
            return spec
        else:
            return spec
Example #9
class histogram(Operation):
    """
    Returns a Histogram of the input element data, binned into
    num_bins over the bin_range (if specified) along the specified
    dimension.
    """

    bin_range = param.NumericTuple(default=None,
                                   length=2,
                                   doc="""
      Specifies the range within which to compute the bins.""")

    dimension = param.String(default=None,
                             doc="""
      Along which dimension of the Element to compute the histogram.""")

    frequency_label = param.String(default='{dim} Frequency',
                                   doc="""
      Format string defining the label of the frequency dimension of the Histogram."""
                                   )

    groupby = param.ClassSelector(default=None,
                                  class_=(basestring, Dimension),
                                  doc="""
      Defines a dimension to group the Histogram returning an NdOverlay of Histograms."""
                                  )

    individually = param.Boolean(default=True,
                                 doc="""
      Specifies whether the histogram will be rescaled for each Element in a UniformNdMapping."""
                                 )

    log = param.Boolean(default=False,
                        doc="""
      Whether to use base 10 logarithmic samples for the bin edges.""")

    mean_weighted = param.Boolean(default=False,
                                  doc="""
      Whether the weighted frequencies are averaged.""")

    normed = param.ObjectSelector(default=True,
                                  objects=[True, False, 'integral', 'height'],
                                  doc="""
      Controls normalization behavior.  If `True` or `'integral'`, then
      `density=True` is passed to np.histogram, and the distribution
      is normalized such that the integral is unity.  If `False`,
      then the frequencies will be raw counts. If `'height'`, then the
      frequencies are normalized such that the max bin height is unity.""")

    nonzero = param.Boolean(default=False,
                            doc="""
      Whether to use only nonzero values when computing the histogram""")

    num_bins = param.Integer(default=20,
                             doc="""
      Number of bins in the histogram.""")

    weight_dimension = param.String(default=None,
                                    doc="""
       Name of the dimension the weighting should be drawn from.""")

    style_prefix = param.String(default=None,
                                allow_None=True,
                                doc="""
      Used for setting a common style for histograms in a HoloMap or AdjointLayout."""
                                )

    def _process(self, view, key=None):
        if self.p.groupby:
            if not isinstance(view, Dataset):
                raise ValueError(
                    'Cannot use histogram groupby on non-Dataset Element')
            grouped = view.groupby(self.p.groupby,
                                   group_type=Dataset,
                                   container_type=NdOverlay)
            self.p.groupby = None
            return grouped.map(self._process, Dataset)

        if self.p.dimension:
            selected_dim = self.p.dimension
        else:
            selected_dim = [d.name for d in view.vdims + view.kdims][0]
        data = np.array(view.dimension_values(selected_dim))
        if self.p.nonzero:
            mask = data > 0
            data = data[mask]
        if self.p.weight_dimension:
            weights = np.array(view.dimension_values(self.p.weight_dimension))
            if self.p.nonzero:
                weights = weights[mask]
        else:
            weights = None

        data = data[np.isfinite(data)]
        hist_range = self.p.bin_range or view.range(selected_dim)
        # Avoids range issues including zero bin range and empty bins
        if hist_range == (0, 0) or any(not np.isfinite(r) for r in hist_range):
            hist_range = (0, 1)
        if self.p.log:
            bin_min = max([abs(hist_range[0]), data[data > 0].min()])
            edges = np.logspace(np.log10(bin_min), np.log10(hist_range[1]),
                                self.p.num_bins + 1)
        else:
            edges = np.linspace(hist_range[0], hist_range[1],
                                self.p.num_bins + 1)
        normed = (False if self.p.mean_weighted and self.p.weight_dimension
                  else self.p.normed)

        if len(data):
            if normed:
                # This covers True, 'height', 'integral'
                hist, edges = np.histogram(data,
                                           density=True,
                                           range=hist_range,
                                           weights=weights,
                                           bins=edges)
                if normed == 'height':
                    hist /= hist.max()
            else:
                # normed is falsy in this branch, so raw counts are requested
                hist, edges = np.histogram(data,
                                           density=False,
                                           range=hist_range,
                                           weights=weights,
                                           bins=edges)
                if self.p.weight_dimension and self.p.mean_weighted:
                    hist_mean, _ = np.histogram(data,
                                                density=False,
                                                range=hist_range,
                                                bins=self.p.num_bins)
                    hist /= hist_mean
        else:
            hist = np.zeros(self.p.num_bins)
        hist[np.isnan(hist)] = 0

        params = {}
        if self.p.weight_dimension:
            params['vdims'] = [view.get_dimension(self.p.weight_dimension)]
        else:
            label = self.p.frequency_label.format(dim=selected_dim)
            params['vdims'] = [
                Dimension('{}_frequency'.format(selected_dim), label=label)
            ]

        if view.group != view.__class__.__name__:
            params['group'] = view.group

        return Histogram((hist, edges),
                         kdims=[view.get_dimension(selected_dim)],
                         label=view.label,
                         **params)
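
A hedged usage sketch for the listing above, assuming it is the histogram operation exposed as holoviews.operation.histogram: it bins the 'y' values of a Points element into 30 bins and rescales the bin heights so the tallest bin is 1.

import numpy as np
import holoviews as hv
from holoviews.operation import histogram

points = hv.Points(np.random.randn(1000, 2))
hist = histogram(points, dimension='y', num_bins=30, normed='height')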
Example #10
class ElementPlot(PlotlyPlot, GenericElementPlot):

    aspect = param.Parameter(default='cube', doc="""
        The aspect ratio mode of the plot. By default, a plot may
        select its own appropriate aspect ratio but sometimes it may
        be necessary to force a square aspect ratio (e.g. to display
        the plot as an element of a grid). The modes 'auto' and
        'equal' correspond to the axis modes of the same name in
        matplotlib, a numeric value may also be passed.""")

    bgcolor = param.ClassSelector(class_=(str, tuple), default=None, doc="""
        If set bgcolor overrides the background color of the axis.""")

    invert_axes = param.ObjectSelector(default=False, doc="""
        Inverts the axes of the plot. Note that this parameter may not
        always be respected by all plots but should be respected by
        adjoined plots when appropriate.""")

    invert_xaxis = param.Boolean(default=False, doc="""
        Whether to invert the plot x-axis.""")

    invert_yaxis = param.Boolean(default=False, doc="""
        Whether to invert the plot y-axis.""")

    invert_zaxis = param.Boolean(default=False, doc="""
        Whether to invert the plot z-axis.""")

    labelled = param.List(default=['x', 'y', 'z'], doc="""
        Whether to label the 'x', 'y' and 'z' axes.""")

    logx = param.Boolean(default=False, doc="""
         Whether to apply log scaling to the x-axis of the Chart.""")

    logy  = param.Boolean(default=False, doc="""
         Whether to apply log scaling to the y-axis of the Chart.""")

    logz  = param.Boolean(default=False, doc="""
         Whether to apply log scaling to the y-axis of the Chart.""")

    margins = param.NumericTuple(default=(50, 50, 50, 50), doc="""
         Margins in pixel values specified as a tuple of the form
         (left, bottom, right, top).""")

    show_legend = param.Boolean(default=False, doc="""
        Whether to show legend for the plot.""")

    xaxis = param.ObjectSelector(default='bottom',
                                 objects=['top', 'bottom', 'bare', 'top-bare',
                                          'bottom-bare', None], doc="""
        Whether and where to display the xaxis, bare options allow suppressing
        all axis labels including ticks and xlabel. Valid options are 'top',
        'bottom', 'bare', 'top-bare' and 'bottom-bare'.""")

    xticks = param.Parameter(default=None, doc="""
        Ticks along x-axis specified as an integer, explicit list of
        tick locations, list of tuples containing the locations.""")

    yaxis = param.ObjectSelector(default='left',
                                 objects=['left', 'right', 'bare', 'left-bare',
                                          'right-bare', None], doc="""
        Whether and where to display the yaxis, bare options allow suppressing
        all axis labels including ticks and ylabel. Valid options are 'left',
        'right', 'bare', 'left-bare' and 'right-bare'.""")

    yticks = param.Parameter(default=None, doc="""
        Ticks along y-axis specified as an integer, explicit list of
        tick locations, list of tuples containing the locations.""")

    zlabel = param.String(default=None, doc="""
        An explicit override of the z-axis label, if set takes precedence
        over the dimension label.""")

    zticks = param.Parameter(default=None, doc="""
        Ticks along z-axis specified as an integer, explicit list of
        tick locations, list of tuples containing the locations.""")

    trace_kwargs = {}

    _style_key = None

    # Whether vectorized styles are applied per trace
    _per_trace = False

    # Declare which styles cannot be mapped to a non-scalar dimension
    _nonvectorized_styles = []

    def initialize_plot(self, ranges=None):
        """
        Initializes a new plot object with the last available frame.
        """
        # Get element key and ranges for frame
        fig = self.generate_plot(self.keys[-1], ranges)
        self.drawn = True
        return fig


    def generate_plot(self, key, ranges, element=None):
        if element is None:
            element = self._get_frame(key)

        if element is None:
            return self.handles['fig']

        # Set plot options
        plot_opts = self.lookup_options(element, 'plot').options
        self.param.set_param(**{k: v for k, v in plot_opts.items()
                                if k in self.params()})

        # Get ranges
        ranges = self.compute_ranges(self.hmap, key, ranges)
        ranges = util.match_spec(element, ranges)

        # Get style
        self.style = self.lookup_options(element, 'style')
        style = self.style[self.cyclic_index]

        # Get data and options and merge them
        data = self.get_data(element, ranges, style)
        opts = self.graph_options(element, ranges, style)
        graphs = []
        for i, d in enumerate(data):
            # Initialize traces
            traces = self.init_graph(d, opts, index=i)
            graphs.extend(traces)
        self.handles['graphs'] = graphs

        # Initialize layout
        layout = self.init_layout(key, element, ranges)
        self.handles['layout'] = layout

        # Create figure and return it
        self.drawn = True
        fig = dict(data=graphs, layout=layout)
        self.handles['fig'] = fig
        return fig


    def graph_options(self, element, ranges, style):
        if self.overlay_dims:
            legend = ', '.join([d.pprint_value_string(v) for d, v in
                                self.overlay_dims.items()])
        else:
            legend = element.label

        opts = dict(
            showlegend=self.show_legend, legendgroup=element.group,
            name=legend, **self.trace_kwargs)

        if self._style_key is not None:
            styles = self._apply_transforms(element, ranges, style)
            opts[self._style_key] = {STYLE_ALIASES.get(k, k): v
                                     for k, v in styles.items()}
        else:
            opts.update({STYLE_ALIASES.get(k, k): v
                         for k, v in style.items() if k != 'cmap'})

        return opts


    def init_graph(self, data, options, index=0):
        trace = dict(options)
        for k, v in data.items():
            if k in trace and isinstance(trace[k], dict):
                trace[k].update(v)
            else:
                trace[k] = v

        if self._style_key and self._per_trace:
            vectorized = {k: v for k, v in options[self._style_key].items()
                          if isinstance(v, np.ndarray)}
            trace[self._style_key] = dict(trace[self._style_key])
            for s, val in vectorized.items():
                trace[self._style_key][s] = val[index]
        return [trace]


    def get_data(self, element, ranges, style):
        return []


    def get_aspect(self, xspan, yspan):
        """
        Computes the aspect ratio of the plot
        """
        return self.width/self.height


    def _get_axis_dims(self, element):
        """Returns the dimensions corresponding to each axis.

        Should return a list of dimensions or list of lists of
        dimensions, which will be formatted to label the axis
        and to link axes.
        """
        dims = element.dimensions()[:3]
        pad = [None]*max(3-len(dims), 0)
        return dims + pad
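        # For example (illustrative): an element declaring only an 'x' and a 'y'
        # dimension yields [Dimension('x'), Dimension('y'), None]; the trailing
        # None pads the result out to the three axes a plot may have.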


    def _apply_transforms(self, element, ranges, style):
        new_style = dict(style)
        for k, v in dict(style).items():
            if isinstance(v, util.basestring):
                if k == 'marker' and v in 'xsdo':
                    continue
                elif v in element:
                    v = dim(v)
                elif any(d==v for d in self.overlay_dims):
                    v = dim([d for d in self.overlay_dims if d==v][0])

            if not isinstance(v, dim):
                continue
            elif (not v.applies(element) and v.dimension not in self.overlay_dims):
                new_style.pop(k)
                self.warning('Specified %s dim transform %r could not be applied, as not all '
                             'dimensions could be resolved.' % (k, v))
                continue

            if len(v.ops) == 0 and v.dimension in self.overlay_dims:
                val = self.overlay_dims[v.dimension]
            else:
                val = v.apply(element, ranges=ranges, flat=True)

            if (not util.isscalar(val) and len(util.unique_array(val)) == 1
                and not 'color' in k):
                val = val[0]

            if not util.isscalar(val):
                if k in self._nonvectorized_styles:
                    element = type(element).__name__
                    raise ValueError('Mapping a dimension to the "{style}" '
                                     'style option is not supported by the '
                                     '{element} element using the {backend} '
                                     'backend. To map the "{dim}" dimension '
                                     'to the {style} use a groupby operation '
                                     'to overlay your data along the dimension.'.format(
                                         style=k, dim=v.dimension, element=element,
                                         backend=self.renderer.backend))

            # If color is not valid colorspec add colormapper
            numeric = isinstance(val, np.ndarray) and val.dtype.kind in 'uifMm'
            if ('color' in k and isinstance(val, np.ndarray) and numeric):
                copts = self.get_color_opts(v, element, ranges, style)
                new_style.pop('cmap', None)
                new_style.update(copts)
            new_style[k] = val
        return new_style


    def init_layout(self, key, element, ranges):
        el = element.traverse(lambda x: x, [Element])
        el = el[0] if el else element

        extent = self.get_extents(element, ranges)

        if len(extent) == 4:
            l, b, r, t = extent
        else:
            l, b, z0, r, t, z1 = extent

        options = {}

        dims = self._get_axis_dims(el)
        if len(dims) > 2:
            xdim, ydim, zdim = dims
        else:
            xdim, ydim = dims
            zdim = None
        xlabel, ylabel, zlabel = self._get_axis_labels(dims)

        if self.invert_axes:
            xlabel, ylabel = ylabel, xlabel
            ydim, xdim = xdim, ydim
            l, b, r, t = b, l, t, r

        if 'x' not in self.labelled:
            xlabel = ''
        if 'y' not in self.labelled:
            ylabel = ''
        if 'z' not in self.labelled:
            zlabel = ''

        if xdim:
            xrange = [r, l] if self.invert_xaxis else [l, r]
            xaxis = dict(range=xrange, title=xlabel)
            if self.logx:
                xaxis['type'] = 'log'
            self._get_ticks(xaxis, self.xticks)
        else:
            xaxis = {}

        if ydim:
            yrange = [t, b] if self.invert_yaxis else [b, t]
            yaxis = dict(range=yrange, title=ylabel)
            if self.logy:
                yaxis['type'] = 'log'
            self._get_ticks(yaxis, self.yticks)
        else:
            yaxis = {}

        if self.projection == '3d':
            scene = dict(xaxis=xaxis, yaxis=yaxis)
            if zdim:
                zrange = [z1, z0] if self.invert_zaxis else [z0, z1]
                zaxis = dict(range=zrange, title=zlabel)
                if self.logz:
                    zaxis['type'] = 'log'
                self._get_ticks(zaxis, self.zticks)
                scene['zaxis'] = zaxis
            if self.aspect == 'cube':
                scene['aspectmode'] = 'cube'
            else:
                scene['aspectmode'] = 'manual'
                scene['aspectratio'] = self.aspect
            options['scene'] = scene
        else:
            l, b, r, t = self.margins
            options['xaxis'] = xaxis
            options['yaxis'] = yaxis
            options['margin'] = dict(l=l, r=r, b=b, t=t, pad=4)

        return dict(width=self.width, height=self.height,
                    title=self._format_title(key, separator=' '),
                    plot_bgcolor=self.bgcolor, **options)

    def _get_ticks(self, axis, ticker):
        axis_props = {}
        if isinstance(ticker, (tuple, list)):
            if all(isinstance(t, tuple) for t in ticker):
                ticks, labels = zip(*ticker)
                labels = [l if isinstance(l, util.basestring) else str(l)
                              for l in labels]
                axis_props['tickvals'] = ticks
                axis_props['ticktext'] = labels
            else:
                axis_props['tickvals'] = ticker
            axis.update(axis_props)

    def update_frame(self, key, ranges=None, element=None):
        """
        Updates an existing plot with data corresponding
        to the key.
        """
        self.generate_plot(key, ranges, element)
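
For reference, here is a minimal standalone sketch (not part of the source) of the merge that init_graph performs: the shared trace options are copied into each per-trace data dict, nested dicts are merged rather than replaced, and vectorized style values are sliced to a scalar per trace index. The option and data values below are invented for illustration.

import numpy as np

options = {'name': 'curve',
           'marker': {'size': np.array([5, 10, 15]), 'symbol': 'circle'}}
data = [{'x': [0, 1], 'y': [1, 2], 'marker': {'color': 'red'}}]

traces = []
for i, d in enumerate(data):
    trace = dict(options)
    for k, v in d.items():
        if k in trace and isinstance(trace[k], dict):
            trace[k] = dict(trace[k], **v)      # merge nested style dicts
        else:
            trace[k] = v
    # slice vectorized style values so each trace receives a scalar
    trace['marker'] = {k: (v[i] if isinstance(v, np.ndarray) else v)
                       for k, v in trace['marker'].items()}
    traces.append(trace)
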
Example #11
class ColorbarPlot(ElementPlot):

    clim = param.NumericTuple(default=(np.nan, np.nan), length=2, doc="""
       User-specified colorbar axis range limits for the plot, as a tuple (low,high).
       If specified, takes precedence over data and dimension ranges.""")

    colorbar = param.Boolean(default=False, doc="""
        Whether to display a colorbar.""")

    color_levels = param.ClassSelector(default=None, class_=(int, list), doc="""
        Number of discrete colors to use when colormapping or a set of color
        intervals defining the range of values to map each color to.""")

    colorbar_opts = param.Dict(default={}, doc="""
        Allows setting additional colorbar options, including borderwidth,
        showexponent, nticks, outlinecolor, thickness, bgcolor, outlinewidth,
        bordercolor, ticklen, xpad, ypad, tickangle...""")

    symmetric = param.Boolean(default=False, doc="""
        Whether to make the colormap symmetric around zero.""")

    def get_color_opts(self, eldim, element, ranges, style):
        opts = {}
        dim_name = dim_range_key(eldim)
        if self.colorbar:
            if isinstance(eldim, dim):
                title = str(eldim) if eldim.ops else str(eldim)[1:-1]
            else:
                title = eldim.pprint_label
            opts['colorbar'] = dict(title=title, **self.colorbar_opts)
            opts['showscale'] = True
        else:
            opts['showscale'] = False

        if eldim:
            auto = False
            if util.isfinite(self.clim).all():
                cmin, cmax = self.clim
            elif dim_name in ranges:
                cmin, cmax = ranges[dim_name]['combined']
            elif isinstance(eldim, dim):
                cmin, cmax = np.nan, np.nan
                auto = True
            else:
                cmin, cmax = element.range(dim_name)
            if self.symmetric:
                cabs = np.abs([cmin, cmax])
                cmin, cmax = -cabs.max(), cabs.max()
        else:
            auto = True
            cmin, cmax = None, None

        cmap = style.pop('cmap', 'viridis')
        colorscale = get_colorscale(cmap, self.color_levels, cmin, cmax)

        # Reduce colorscale length to <= 255 to work around
        # https://github.com/plotly/plotly.js/issues/3699. Plotly.js performs
        # colorscale interpolation internally so reducing the number of colors
        # here makes very little difference to the displayed colorscale.
        #
        # Note that we need to be careful to make sure the first and last
        # colorscale pairs, colorscale[0] and colorscale[-1], are preserved
        # as the first and last in the subsampled colorscale
        if isinstance(colorscale, list) and len(colorscale) > 255:
            last_clr_pair = colorscale[-1]
            step = int(np.ceil(len(colorscale) / 255))
            colorscale = colorscale[0::step]
            colorscale[-1] = last_clr_pair

        if cmin is not None:
            opts['cmin'] = cmin
        if cmax is not None:
            opts['cmax'] = cmax
        opts['cauto'] = auto
        opts['colorscale'] = colorscale
        return opts
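
The colorscale subsampling workaround above can be illustrated in isolation. A minimal sketch, assuming a plotly-style colorscale given as a list of (fraction, color) pairs; the toy colorscale below is invented for illustration.

import numpy as np

colorscale = [(i / 999., 'rgb(%d,0,0)' % (i % 256)) for i in range(1000)]

if len(colorscale) > 255:
    last_clr_pair = colorscale[-1]
    step = int(np.ceil(len(colorscale) / 255))
    colorscale = colorscale[0::step]
    colorscale[-1] = last_clr_pair      # keep the original endpoint

assert len(colorscale) <= 255
assert colorscale[0][0] == 0 and colorscale[-1][0] == 1
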
Example #12
class GeoAnnotator(param.Parameterized):
    """
    Provides support for drawing polygons and points on top of a map.
    """

    tile_url = param.String(
        default='http://c.tile.openstreetmap.org/{Z}/{X}/{Y}.png',
        doc="URL for the tile source",
        precedence=-1)

    extent = param.NumericTuple(default=(np.nan, ) * 4,
                                doc="""
         Initial extent if no data is provided.""",
                                precedence=-1)

    path_type = param.ClassSelector(default=Polygons,
                                    class_=Path,
                                    is_instance=False,
                                    doc="""
         The element type to draw into.""")

    polys = param.ClassSelector(class_=Path,
                                precedence=-1,
                                doc="""
         Polygon or Path element to annotate""")

    points = param.ClassSelector(class_=Points,
                                 precedence=-1,
                                 doc="""
         Point element to annotate""")

    num_points = param.Integer(default=None,
                               doc="""
         Maximum number of points to allow drawing (unlimited by default).""")

    num_polys = param.Integer(default=None,
                              doc="""
         Maximum number of polygons to allow drawing (unlimited by default)."""
                              )

    height = param.Integer(default=500,
                           doc="Height of the plot",
                           precedence=-1)

    width = param.Integer(default=900, doc="Width of the plot", precedence=-1)

    def __init__(self, polys=None, points=None, crs=None, **params):
        super(GeoAnnotator, self).__init__(**params)
        plot_opts = dict(height=self.height, width=self.width)
        self.tiles = WMTS(self.tile_url,
                          extents=self.extent,
                          crs=ccrs.PlateCarree()).opts(plot=plot_opts)
        polys = [] if polys is None else polys
        points = [] if points is None else points
        crs = ccrs.GOOGLE_MERCATOR if crs is None else crs
        self._tools = [CheckpointTool(), RestoreTool(), ClearTool()]
        if not isinstance(polys, Path):
            polys = self.path_type(polys, crs=crs)
        self._init_polys(polys)
        if not isinstance(points, Points):
            points = Points(points, self.polys.kdims, crs=crs)
        self._init_points(points)

    @param.depends('polys', watch=True)
    @preprocess
    def _init_polys(self, polys=None):
        opts = dict(tools=self._tools,
                    finalize_hooks=[initialize_tools],
                    color_index=None)
        polys = self.polys if polys is None else polys
        self.polys = polys.options(**opts)
        self.poly_stream = PolyDraw(source=self.polys,
                                    data={},
                                    show_vertices=True,
                                    num_objects=self.num_polys)
        self.vertex_stream = PolyEdit(source=self.polys,
                                      vertex_style={'nonselection_alpha': 0.5})

    @param.depends('points', watch=True)
    @preprocess
    def _init_points(self, points=None):
        opts = dict(tools=self._tools,
                    finalize_hooks=[initialize_tools],
                    color_index=None)
        points = self.points if points is None else points
        self.points = points.options(**opts)
        self.point_stream = PointDraw(source=self.points,
                                      drag=True,
                                      data={},
                                      num_objects=self.num_points)

    def pprint(self):
        params = dict(self.get_param_values())
        name = params.pop('name')
        string = '%s\n%s\n\n' % (name, '-' * len(name))
        for item in sorted(params.items()):
            string += '  %s: %s\n' % (item)
        print(string)

    @param.depends('points', 'polys')
    def map_view(self):
        return self.tiles * self.polys * self.points

    def panel(self):
        return pn.Row(self.map_view)
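
A hypothetical usage sketch, assuming the geoviews/bokeh/panel stack that GeoAnnotator builds on is installed and a bokeh plotting backend is active:

# Hypothetical usage; GeoAnnotator arguments and tile defaults as defined above
annotator = GeoAnnotator(num_points=10, num_polys=2)
app = annotator.panel()   # a panel Row wrapping the tiles * polys * points overlay
app.show()                # or app.servable() when embedded in a Panel app
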
Example #13
class LayoutPlot(GenericLayoutPlot, CompositePlot):
    """
    A LayoutPlot accepts either a Layout or a NdLayout and
    displays the elements in a cartesian grid in scanline order.
    """

    absolute_scaling = param.Boolean(default=False,
                                     doc="""
      If aspect_weight is enabled absolute_scaling determines whether
      axes are scaled relative to the widest plot or whether the
      aspect scales the axes in absolute terms.""")

    aspect_weight = param.Number(default=0,
                                 doc="""
      Weighting of the individual aspects when computing the Layout
      grid aspects and overall figure size.""")

    fig_bounds = param.NumericTuple(default=(0.05, 0.05, 0.95, 0.95),
                                    doc="""
      The bounds of the figure as a 4-tuple of the form
      (left, bottom, right, top), defining the size of the border
      around the subplots.""")

    tight = param.Boolean(default=False,
                          doc="""
      Tightly fit the axes in the layout within the fig_bounds
      and tight_padding.""")

    tight_padding = param.Parameter(default=3,
                                    doc="""
      Integer or tuple specifying the padding in inches in a tight layout.""")

    hspace = param.Number(default=0.5,
                          doc="""
      Specifies the space between horizontally adjacent elements in the grid.
      Default value is set conservatively to avoid overlap of subplots.""")

    vspace = param.Number(default=0.1,
                          doc="""
      Specifies the space between vertically adjacent elements in the grid.
      Default value is set conservatively to avoid overlap of subplots.""")

    fontsize = param.Parameter(default={'title': 16}, allow_None=True)

    def __init__(self, layout, **params):
        super(LayoutPlot, self).__init__(layout=layout, **params)
        self.subplots, self.subaxes, self.layout = self._compute_gridspec(
            layout)

    def _compute_gridspec(self, layout):
        """
        Computes the tallest and widest cell for each row and column
        by examining the Layouts in the GridSpace. The GridSpec is then
        instantiated and the LayoutPlots are configured with the
        appropriate embedded layout_types. The first element of the
        returned tuple is a dictionary of all the LayoutPlots indexed
        by row and column. The second dictionary in the tuple supplies
        the grid indices needed to instantiate the axes for each
        LayoutPlot.
        """
        layout_items = layout.grid_items()
        layout_dimensions = layout.kdims if isinstance(layout,
                                                       NdLayout) else None

        layouts = {}
        col_widthratios, row_heightratios = {}, {}
        for (r, c) in self.coords:
            # Get view at layout position and wrap in AdjointLayout
            _, view = layout_items.get((r, c), (None, None))
            layout_view = view if isinstance(
                view, AdjointLayout) else AdjointLayout([view])
            layouts[(r, c)] = layout_view

            # Compute shape of AdjointLayout element
            layout_lens = {1: 'Single', 2: 'Dual', 3: 'Triple'}
            layout_type = layout_lens[len(layout_view)]

            # Get aspects
            main = layout_view.main
            main = main.last if isinstance(main, HoloMap) else main
            main_options = self.lookup_options(main,
                                               'plot').options if main else {}
            if main and not isinstance(main_options.get('aspect', 1),
                                       basestring):
                main_aspect = np.nan if isinstance(
                    main, Empty) else main_options.get('aspect', 1)
                main_aspect = self.aspect_weight * main_aspect + 1 - self.aspect_weight
            else:
                main_aspect = np.nan

            if layout_type in ['Dual', 'Triple']:
                el = layout_view.get('right', None)
                eltype = type(el)
                if el and eltype in MPLPlot.sideplots:
                    plot_type = MPLPlot.sideplots[type(el)]
                    ratio = 0.6 * (plot_type.subplot_size +
                                   plot_type.border_size)
                    width_ratios = [4, 4 * ratio]
                else:
                    width_ratios = [4, 1]
            else:
                width_ratios = [4]

            inv_aspect = 1. / main_aspect if main_aspect else np.NaN
            if layout_type in ['Embedded Dual', 'Triple']:
                el = layout_view.get('top', None)
                eltype = type(el)
                if el and eltype in MPLPlot.sideplots:
                    plot_type = MPLPlot.sideplots[type(el)]
                    ratio = 0.6 * (plot_type.subplot_size +
                                   plot_type.border_size)
                    height_ratios = [4 * ratio, 4]
                else:
                    height_ratios = [1, 4]
            else:
                height_ratios = [4]

            if not isinstance(main_aspect, (basestring, type(None))):
                width_ratios = [
                    wratio * main_aspect for wratio in width_ratios
                ]
                height_ratios = [
                    hratio * inv_aspect for hratio in height_ratios
                ]
            layout_shape = (len(width_ratios), len(height_ratios))

            # For each row and column record the width and height ratios
            # of the LayoutPlot with the most horizontal or vertical splits
            # and largest aspect
            prev_heights = row_heightratios.get(r, (0, []))
            if layout_shape[1] > prev_heights[0]:
                row_heightratios[r] = [layout_shape[1], prev_heights[1]]
            row_heightratios[r][1].append(height_ratios)

            prev_widths = col_widthratios.get(c, (0, []))
            if layout_shape[0] > prev_widths[0]:
                col_widthratios[c] = (layout_shape[0], prev_widths[1])
            col_widthratios[c][1].append(width_ratios)

        col_splits = [v[0] for _c, v in sorted(col_widthratios.items())]
        row_splits = [v[0] for _r, v in sorted(row_heightratios.items())]

        widths = np.array([
            r for col in col_widthratios.values() for ratios in col[1]
            for r in ratios
        ]) / 4

        wr_unnormalized = compute_ratios(col_widthratios, False)
        hr_list = compute_ratios(row_heightratios)
        wr_list = compute_ratios(col_widthratios)

        # Compute the number of rows and cols
        cols, rows = len(wr_list), len(hr_list)

        wr_list = [r if np.isfinite(r) else 1 for r in wr_list]
        hr_list = [r if np.isfinite(r) else 1 for r in hr_list]

        width = sum([r if np.isfinite(r) else 1 for r in wr_list])
        yscale = width / sum([(1 / v) * 4 if np.isfinite(v) else 4
                              for v in wr_unnormalized])
        if self.absolute_scaling:
            width = width * np.nanmax(widths)

        xinches, yinches = None, None
        if not isinstance(self.fig_inches, (tuple, list)):
            xinches = self.fig_inches * width
            yinches = xinches / yscale
        elif self.fig_inches[0] is None:
            xinches = self.fig_inches[1] * yscale
            yinches = self.fig_inches[1]
        elif self.fig_inches[1] is None:
            xinches = self.fig_inches[0]
            yinches = self.fig_inches[0] / yscale
        if xinches and yinches:
            self.handles['fig'].set_size_inches([xinches, yinches])

        self.gs = gridspec.GridSpec(rows,
                                    cols,
                                    width_ratios=wr_list,
                                    height_ratios=hr_list,
                                    wspace=self.hspace,
                                    hspace=self.vspace)

        # Situate all the Layouts in the grid and compute the gridspec
        # indices for all the axes required by each LayoutPlot.
        gidx = 0
        layout_count = 0
        tight = self.tight
        collapsed_layout = layout.clone(shared_data=False, id=layout.id)
        frame_ranges = self.compute_ranges(layout, None, None)
        frame_ranges = OrderedDict([
            (key, self.compute_ranges(layout, key, frame_ranges))
            for key in self.keys
        ])
        layout_subplots, layout_axes = {}, {}
        for r, c in self.coords:
            # Compute the layout type from shape
            wsplits = col_splits[c]
            hsplits = row_splits[r]
            if (wsplits, hsplits) == (1, 1):
                layout_type = 'Single'
            elif (wsplits, hsplits) == (2, 1):
                layout_type = 'Dual'
            elif (wsplits, hsplits) == (1, 2):
                layout_type = 'Embedded Dual'
            elif (wsplits, hsplits) == (2, 2):
                layout_type = 'Triple'

            # Get the AdjoinLayout at the specified coordinate
            view = layouts[(r, c)]
            positions = AdjointLayoutPlot.layout_dict[layout_type]

            # Create temporary subplots to get projections types
            # to create the correct subaxes for all plots in the layout
            _, _, projs = self._create_subplots(layouts[(r, c)],
                                                positions,
                                                None,
                                                frame_ranges,
                                                create=False)
            gidx, gsinds = self.grid_situate(gidx, layout_type, cols)

            layout_key, _ = layout_items.get((r, c), (None, None))
            if isinstance(layout, NdLayout) and layout_key:
                layout_dimensions = OrderedDict(
                    zip(layout_dimensions, layout_key))

            # Generate the axes and create the subplots with the appropriate
            # axis objects, handling any Empty objects.
            obj = layouts[(r, c)]
            empty = isinstance(obj.main, Empty)
            if empty:
                obj = AdjointLayout([])
            else:
                layout_count += 1
            subaxes = [
                plt.subplot(self.gs[ind], projection=proj)
                for ind, proj in zip(gsinds, projs)
            ]
            subplot_data = self._create_subplots(
                obj,
                positions,
                layout_dimensions,
                frame_ranges,
                dict(zip(positions, subaxes)),
                num=0 if empty else layout_count)
            subplots, adjoint_layout, _ = subplot_data
            layout_axes[(r, c)] = subaxes

            # Generate the AdjointLayoutsPlot which will coordinate
            # plotting of AdjointLayouts in the larger grid
            plotopts = self.lookup_options(view, 'plot').options
            layout_plot = AdjointLayoutPlot(adjoint_layout,
                                            layout_type,
                                            subaxes,
                                            subplots,
                                            fig=self.handles['fig'],
                                            **plotopts)
            layout_subplots[(r, c)] = layout_plot
            tight = not any(
                type(p) is GridPlot
                for p in layout_plot.subplots.values()) and tight
            if layout_key:
                collapsed_layout[layout_key] = adjoint_layout

        # Apply tight layout if enabled and incompatible
        # GridPlot isn't present.
        if tight:
            if isinstance(self.tight_padding, (tuple, list)):
                wpad, hpad = self.tight_padding
                padding = dict(w_pad=wpad, h_pad=hpad)
            else:
                padding = dict(w_pad=self.tight_padding,
                               h_pad=self.tight_padding)
            self.gs.tight_layout(self.handles['fig'],
                                 rect=self.fig_bounds,
                                 **padding)

        return layout_subplots, layout_axes, collapsed_layout

    def grid_situate(self, current_idx, layout_type, subgrid_width):
        """
        Situate the current AdjointLayoutPlot in a LayoutPlot. The
        LayoutPlot specifies a layout_type into which the AdjointLayoutPlot
        must be embedded. This enclosing layout is guaranteed to have
        enough cells to display all the views.

        Based on this enforced layout format, a starting index
        supplied by LayoutPlot (indexing into a large gridspec
        arrangement) is updated to the appropriate embedded value. It
        will also return a list of gridspec indices associated with
        all the required layout axes.
        """
        # Set the layout configuration as situated in a NdLayout

        if layout_type == 'Single':
            start, inds = current_idx + 1, [current_idx]
        elif layout_type == 'Dual':
            start, inds = current_idx + 2, [current_idx, current_idx + 1]

        bottom_idx = current_idx + subgrid_width
        if layout_type == 'Embedded Dual':
            bottom = ((current_idx + 1) % subgrid_width) == 0
            grid_idx = (bottom_idx if bottom else current_idx) + 1
            start, inds = grid_idx, [current_idx, bottom_idx]
        elif layout_type == 'Triple':
            bottom = ((current_idx + 2) % subgrid_width) == 0
            grid_idx = (bottom_idx if bottom else current_idx) + 2
            start, inds = grid_idx, [
                current_idx, current_idx + 1, bottom_idx, bottom_idx + 1
            ]

        return start, inds

    def _create_subplots(self,
                         layout,
                         positions,
                         layout_dimensions,
                         ranges,
                         axes={},
                         num=1,
                         create=True):
        """
        Plot all the views contained in the AdjointLayout Object using axes
        appropriate to the layout configuration. All the axes are
        supplied by LayoutPlot - the purpose of the call is to
        invoke subplots with correct options and styles and hide any
        empty axes as necessary.
        """
        subplots = {}
        projections = []
        adjoint_clone = layout.clone(shared_data=False, id=layout.id)
        subplot_opts = dict(show_title=False, adjoined=layout)
        for pos in positions:
            # Pos will be one of 'main', 'top' or 'right' or None
            view = layout.get(pos, None)
            ax = axes.get(pos, None)
            if view is None:
                projections.append(None)
                continue

            # Determine projection type for plot
            projections.append(self._get_projection(view))

            if not create:
                continue

            # Customize plotopts depending on position.
            plotopts = self.lookup_options(view, 'plot').options

            # Options common for any subplot
            override_opts = {}
            sublabel_opts = {}
            if pos == 'main':
                own_params = self.get_param_values(onlychanged=True)
                sublabel_opts = {
                    k: v
                    for k, v in own_params if 'sublabel_' in k
                }
                if not isinstance(view, GridSpace):
                    override_opts = dict(aspect='square')
            elif pos == 'right':
                right_opts = dict(invert_axes=True, xaxis=None)
                override_opts = dict(subplot_opts, **right_opts)
            elif pos == 'top':
                top_opts = dict(yaxis=None)
                override_opts = dict(subplot_opts, **top_opts)

            # Override the plotopts as required
            plotopts = dict(sublabel_opts, **plotopts)
            plotopts.update(override_opts, fig=self.handles['fig'])
            vtype = view.type if isinstance(view, HoloMap) else view.__class__
            if isinstance(view, GridSpace):
                plotopts['create_axes'] = ax is not None
            if pos == 'main':
                plot_type = Store.registry['matplotlib'][vtype]
            else:
                plot_type = MPLPlot.sideplots[vtype]
            num = num if len(self.coords) > 1 else 0
            subplots[pos] = plot_type(view,
                                      axis=ax,
                                      keys=self.keys,
                                      dimensions=self.dimensions,
                                      layout_dimensions=layout_dimensions,
                                      ranges=ranges,
                                      subplot=True,
                                      uniform=self.uniform,
                                      layout_num=num,
                                      **plotopts)
            if isinstance(view,
                          (Element, HoloMap, Collator, CompositeOverlay)):
                adjoint_clone[pos] = subplots[pos].hmap
            else:
                adjoint_clone[pos] = subplots[pos].layout
        return subplots, adjoint_clone, projections

    def initialize_plot(self):
        key = self.keys[-1]
        ranges = self.compute_ranges(self.layout, key, None)
        for subplot in self.subplots.values():
            subplot.initialize_plot(ranges=ranges)

        # Create title handle
        if self.show_title and len(self.coords) > 1:
            title = self._format_title(key)
            title = self.handles['fig'].suptitle(title,
                                                 **self._fontsize('title'))
            self.handles['title'] = title

        return self._finalize_axis(None)
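
For illustration only, the index arithmetic in grid_situate can be traced by hand. A minimal sketch for a 'Triple' layout, assuming a current index of 0 and a subgrid width of 4 (both values chosen arbitrarily):

current_idx, subgrid_width = 0, 4   # assumed example values

# A 'Triple' layout occupies a 2x2 block of gridspec cells
bottom_idx = current_idx + subgrid_width                    # cell one row below
inds = [current_idx, current_idx + 1, bottom_idx, bottom_idx + 1]
bottom = ((current_idx + 2) % subgrid_width) == 0
start = (bottom_idx if bottom else current_idx) + 2

print(inds, start)   # [0, 1, 4, 5] 2
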
Example #14
class MPLPlot(DimensionedPlot):
    """
    An MPLPlot object draws a matplotlib figure object when called or
    indexed but can also return a matplotlib animation object as
    appropriate. MPLPlots take element objects such as Image, Contours
    or Points as inputs and plots them in the appropriate format using
    matplotlib. As HoloMaps are supported, all plots support animation
    via the anim() method.
    """

    renderer = MPLRenderer
    sideplots = {}

    fig_alpha = param.Number(default=1.0,
                             bounds=(0, 1),
                             doc="""
        Alpha of the overall figure background.""")

    fig_bounds = param.NumericTuple(default=(0.15, 0.15, 0.85, 0.85),
                                    doc="""
        The bounds of the overall figure as a 4-tuple of the form
        (left, bottom, right, top), defining the size of the border
        around the subplots.""")

    fig_inches = param.Parameter(default=4,
                                 doc="""
        The overall matplotlib figure size in inches.  May be set as
        an integer in which case it will be used to autocompute a
        size. Alternatively may be set with an explicit tuple or list,
        in which case it will be applied directly after being scaled
        by fig_size. If either the width or height is set to None,
        it will be computed automatically.""")

    fig_latex = param.Boolean(default=False,
                              doc="""
        Whether to use LaTeX text in the overall figure.""")

    fig_rcparams = param.Dict(default={},
                              doc="""
        matplotlib rc parameters to apply to the overall figure.""")

    fig_size = param.Integer(default=100,
                             bounds=(1, None),
                             doc="""
        Size relative to the supplied overall fig_inches in percent.""")

    initial_hooks = param.HookList(default=[],
                                   doc="""
        Optional list of hooks called before plotting the data onto
        the axis. The hook is passed the plot object and the displayed
        object, other plotting handles can be accessed via plot.handles.""")

    final_hooks = param.HookList(default=[],
                                 doc="""
        Optional list of hooks called when finalizing an axis.
        The hook is passed the plot object and the displayed
        object, other plotting handles can be accessed via plot.handles.""")

    finalize_hooks = param.HookList(default=[],
                                    doc="""
        Optional list of hooks called when finalizing an axis.
        The hook is passed the plot object and the displayed
        object, other plotting handles can be accessed via plot.handles.""")

    sublabel_format = param.String(default=None,
                                   allow_None=True,
                                   doc="""
        Allows labeling the subaxes in each plot with various formatters
        including {Alpha}, {alpha}, {numeric} and {roman}.""")

    sublabel_position = param.NumericTuple(default=(-0.35, 0.85),
                                           doc="""
         Position relative to the plot for placing the optional subfigure label."""
                                           )

    sublabel_size = param.Number(default=18,
                                 doc="""
         Size of optional subfigure label.""")

    projection = param.Parameter(default=None,
                                 doc="""
        The projection of the plot axis, default of None is equivalent to
        2D plot, '3d' and 'polar' are also supported by matplotlib by default.
        May also supply a custom projection that is either a matplotlib
        projection type or implements the `_as_mpl_axes` method.""")

    show_frame = param.Boolean(default=True,
                               doc="""
        Whether or not to show a complete frame around the plot.""")

    _close_figures = True

    def __init__(self, fig=None, axis=None, **params):
        self._create_fig = True
        super(MPLPlot, self).__init__(**params)
        # List of handles to matplotlib objects for animation update
        scale = self.fig_size / 100.
        if isinstance(self.fig_inches, (tuple, list)):
            self.fig_inches = [
                None if i is None else i * scale for i in self.fig_inches
            ]
        else:
            self.fig_inches *= scale
        fig, axis = self._init_axis(fig, axis)
        self.handles['fig'] = fig
        self.handles['axis'] = axis

        if self.final_hooks and self.finalize_hooks:
            self.warning('Set either final_hooks or deprecated '
                         'finalize_hooks, not both.')
        self.finalize_hooks = self.final_hooks

    def _init_axis(self, fig, axis):
        """
        Return an axis which may need to be initialized from
        a new figure.
        """
        if not fig and self._create_fig:
            rc_params = self.fig_rcparams
            if self.fig_latex:
                rc_params['text.usetex'] = True
            with mpl.rc_context(rc=rc_params):
                fig = plt.figure()
                l, b, r, t = self.fig_bounds
                inches = self.fig_inches
                fig.subplots_adjust(left=l, bottom=b, right=r, top=t)
                fig.patch.set_alpha(self.fig_alpha)
                if isinstance(inches, (tuple, list)):
                    inches = list(inches)
                    if inches[0] is None:
                        inches[0] = inches[1]
                    elif inches[1] is None:
                        inches[1] = inches[0]
                    fig.set_size_inches(list(inches))
                else:
                    fig.set_size_inches([inches, inches])
                axis = fig.add_subplot(111, projection=self.projection)
                axis.set_aspect('auto')

        return fig, axis

    def _subplot_label(self, axis):
        layout_num = self.layout_num if self.subplot else 1
        if self.sublabel_format and not self.adjoined and layout_num > 0:
            from mpl_toolkits.axes_grid1.anchored_artists import AnchoredText
            labels = {}
            if '{Alpha}' in self.sublabel_format:
                labels['Alpha'] = int_to_alpha(layout_num - 1)
            elif '{alpha}' in self.sublabel_format:
                labels['alpha'] = int_to_alpha(layout_num - 1, upper=False)
            elif '{numeric}' in self.sublabel_format:
                labels['numeric'] = self.layout_num
            elif '{Roman}' in self.sublabel_format:
                labels['Roman'] = int_to_roman(layout_num)
            elif '{roman}' in self.sublabel_format:
                labels['roman'] = int_to_roman(layout_num).lower()
            at = AnchoredText(self.sublabel_format.format(**labels),
                              loc=3,
                              bbox_to_anchor=self.sublabel_position,
                              frameon=False,
                              prop=dict(size=self.sublabel_size,
                                        weight='bold'),
                              bbox_transform=axis.transAxes)
            at.patch.set_visible(False)
            axis.add_artist(at)
            self.handles['sublabel'] = at.txt.get_children()[0]

    def _finalize_axis(self, key):
        """
        General method to finalize the axis and plot.
        """
        if 'title' in self.handles:
            self.handles['title'].set_visible(self.show_title)

        self.drawn = True
        if self.subplot:
            return self.handles['axis']
        else:
            fig = self.handles['fig']
            if not getattr(self, 'overlaid', False) and self._close_figures:
                plt.close(fig)
            return fig

    @property
    def state(self):
        return self.handles['fig']

    def anim(self, start=0, stop=None, fps=30):
        """
        Method to return a matplotlib animation. The start and stop
        frames may be specified as well as the fps.
        """
        figure = self.initialize_plot()
        anim = animation.FuncAnimation(figure,
                                       self.update_frame,
                                       frames=self.keys,
                                       interval=1000.0 / fps)
        # Close the figure handle
        if self._close_figures: plt.close(figure)
        return anim

    def update(self, key):
        rc_params = self.fig_rcparams
        if self.fig_latex:
            rc_params['text.usetex'] = True
        mpl.rcParams.update(rc_params)
        if len(self) == 1 and key == 0 and not self.drawn:
            return self.initialize_plot()
        return self.__getitem__(key)
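
The figure-size handling in __init__ can be illustrated in isolation. A minimal standalone sketch (not from the source), assuming an example fig_size of 150 percent: fig_size scales fig_inches, and a None entry in a (width, height) tuple is filled in from the other entry later on.

fig_size = 150                      # percent, assumed example value
scale = fig_size / 100.

for fig_inches in (4, (8, None), (6, 3)):
    if isinstance(fig_inches, (tuple, list)):
        scaled = [None if i is None else i * scale for i in fig_inches]
    else:
        scaled = fig_inches * scale
    print(fig_inches, '->', scaled)
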
Example #15
class _layout_sankey(Operation):
    """
    Computes a Sankey diagram from a Graph element for internal use in
    the Sankey element constructor.

    Adapted from d3-sankey under BSD-3 license.
    """

    bounds = param.NumericTuple(default=(0, 0, 1000, 500))

    node_width = param.Number(default=15,
                              doc="""
        Width of the nodes.""")

    node_padding = param.Integer(default=10,
                                 doc="""
        Number of pixels of padding relative to the bounds.""")

    iterations = param.Integer(default=32,
                               doc="""
        Number of iterations to run the layout algorithm.""")

    def _process(self, element, key=None):
        nodes, edges, graph = self.layout(element, **self.p)
        params = get_param_values(element)
        return Sankey((element.data, nodes, edges), sankey=graph, **params)

    def layout(self, element, **params):
        self.p = param.ParamOverrides(self, params)
        graph = {'nodes': [], 'links': []}
        self.computeNodeLinks(element, graph)
        self.computeNodeValues(graph)
        self.computeNodeDepths(graph)
        self.computeNodeBreadths(graph)
        self.computeLinkBreadths(graph)
        paths = self.computePaths(graph)

        node_data = []
        for node in graph['nodes']:
            node_data.append((np.mean([node['x0'], node['x1']]),
                              np.mean([node['y0'], node['y1']]),
                              node['index']) + tuple(node['values']))
        if element.nodes.ndims == 3:
            kdims = element.nodes.kdims
        elif element.nodes.ndims:
            kdims = element.node_type.kdims[:2] + element.nodes.kdims[-1:]
        else:
            kdims = element.node_type.kdims
        nodes = element.node_type(node_data,
                                  kdims=kdims,
                                  vdims=element.nodes.vdims)
        edges = element.edge_type(paths)
        return nodes, edges, graph

    def computePaths(self, graph):
        paths = []
        for link in graph['links']:
            source, target = link['source'], link['target']
            x0, y0 = source['x1'], link['y0']
            x1, y1 = target['x0'], link['y1']
            start = np.array([(x0, link['width'] + y0), (x0, y0)])
            src = (x0, y0)
            ctr1 = ((x0 + x1) / 2., y0)
            ctr2 = ((x0 + x1) / 2., y1)
            tgt = (x1, y1)
            bottom = quadratic_bezier(src, tgt, ctr1, ctr2)
            mid = np.array([(x1, y1), (x1, y1 + link['width'])])

            xmid = (x0 + x1) / 2.
            y0 = y0 + link['width']
            y1 = y1 + link['width']
            src = (x1, y1)
            ctr1 = (xmid, y1)
            ctr2 = (xmid, y0)
            tgt = (x0, y0)
            top = quadratic_bezier(src, tgt, ctr1, ctr2)
            spline = np.concatenate([start, bottom, mid, top])
            paths.append(spline)
        return paths

    @classmethod
    def weightedSource(cls, link):
        return cls.nodeCenter(link['source']) * link['value']

    @classmethod
    def weightedTarget(cls, link):
        return cls.nodeCenter(link['target']) * link['value']

    @classmethod
    def nodeCenter(cls, node):
        return (node['y0'] + node['y1']) / 2

    @classmethod
    def ascendingBreadth(cls, a, b):
        return int(a['y0'] - b['y0'])

    @classmethod
    def ascendingSourceBreadth(cls, a, b):
        return cls.ascendingBreadth(a['source'],
                                    b['source']) | a['index'] - b['index']

    @classmethod
    def ascendingTargetBreadth(cls, a, b):
        return cls.ascendingBreadth(a['target'],
                                    b['target']) | a['index'] - b['index']

    @classmethod
    def computeNodeLinks(cls, element, graph):
        """
        Populate the sourceLinks and targetLinks for each node.
        Also, if the source and target are not objects, assume they are indices.
        """
        index = element.nodes.kdims[-1]
        node_map = {}
        if element.nodes.vdims:
            values = zip(*(element.nodes.dimension_values(d)
                           for d in element.nodes.vdims))
        else:
            values = cycle([tuple()])
        for index, vals in zip(element.nodes.dimension_values(index), values):
            node = {
                'index': index,
                'sourceLinks': [],
                'targetLinks': [],
                'values': vals
            }
            graph['nodes'].append(node)
            node_map[index] = node

        links = [element.dimension_values(d) for d in element.dimensions()[:3]]
        for i, (src, tgt, value) in enumerate(zip(*links)):
            source, target = node_map[src], node_map[tgt]
            link = dict(index=i, source=source, target=target, value=value)
            graph['links'].append(link)
            source['sourceLinks'].append(link)
            target['targetLinks'].append(link)

    @classmethod
    def computeNodeValues(cls, graph):
        """
        Compute the value (size) of each node by summing the associated links.
        """
        for node in graph['nodes']:
            source_val = np.sum([l['value'] for l in node['sourceLinks']])
            target_val = np.sum([l['value'] for l in node['targetLinks']])
            node['value'] = max([source_val, target_val])

    def computeNodeDepths(self, graph):
        """
        Iteratively assign the depth (x-position) for each node.
        Nodes are assigned the maximum depth of incoming neighbors plus one;
        nodes with no incoming links are assigned depth zero, while
        nodes with no outgoing links are assigned the maximum depth.
        """
        nodes = graph['nodes']
        depth = 0
        while nodes:
            next_nodes = []
            for node in nodes:
                node['depth'] = depth
                for link in node['sourceLinks']:
                    if not any(link['target'] is node for node in next_nodes):
                        next_nodes.append(link['target'])
            nodes = next_nodes
            depth += 1
            if depth > 10000:
                raise RecursionError(
                    'Sankey diagrams only support acyclic graphs.')

        nodes = graph['nodes']
        depth = 0
        while nodes:
            next_nodes = []
            for node in nodes:
                node['height'] = depth
                for link in node['targetLinks']:
                    if not any(link['source'] is node for node in next_nodes):
                        next_nodes.append(link['source'])
            nodes = next_nodes
            depth += 1
            if depth > 10000:
                raise RecursionError(
                    'Sankey diagrams only support acyclic graphs.')

        x0, _, x1, _ = self.p.bounds
        dx = self.p.node_width
        kx = (x1 - x0 - dx) / (depth - 1)
        for node in graph['nodes']:
            d = node['depth'] if node['sourceLinks'] else depth - 1
            node['x0'] = x0 + max([0, min([depth - 1, np.floor(d)]) * kx])
            node['x1'] = node['x0'] + dx

    def computeNodeBreadths(self, graph):
        node_map = OrderedDict()
        for n in graph['nodes']:
            if n['x0'] not in node_map:
                node_map[n['x0']] = []
            node_map[n['x0']].append(n)

        _, y0, _, y1 = self.p.bounds
        py = self.p.node_padding

        def initializeNodeBreadth():
            kys = []
            for nodes in node_map.values():
                nsum = np.sum([node['value'] for node in nodes])
                ky = (y1 - y0 - (len(nodes) - 1) * py) / nsum
                kys.append(ky)
            ky = np.min(kys)

            for nodes in node_map.values():
                for i, node in enumerate(nodes):
                    node['y0'] = i
                    node['y1'] = i + node['value'] * ky

            for link in graph['links']:
                link['width'] = link['value'] * ky

        def relaxLeftToRight(alpha):
            for nodes in node_map.values():
                for node in nodes:
                    if not node['targetLinks']:
                        continue
                    weighted = sum(
                        [self.weightedSource(l) for l in node['targetLinks']])
                    tsum = sum([l['value'] for l in node['targetLinks']])
                    center = self.nodeCenter(node)
                    dy = (weighted / tsum - center) * alpha
                    node['y0'] += dy
                    node['y1'] += dy

        def relaxRightToLeft(alpha):
            for nodes in list(node_map.values())[::-1]:
                for node in nodes:
                    if not node['sourceLinks']:
                        continue
                    weighted = sum(
                        [self.weightedTarget(l) for l in node['sourceLinks']])
                    tsum = sum([l['value'] for l in node['sourceLinks']])
                    center = self.nodeCenter(node)
                    dy = (weighted / tsum - center) * alpha
                    node['y0'] += dy
                    node['y1'] += dy

        def resolveCollisions():
            for nodes in node_map.values():
                y = y0
                nodes.sort(key=cmp_to_key(self.ascendingBreadth))
                for node in nodes:
                    dy = y - node['y0']
                    if dy > 0:
                        node['y0'] += dy
                        node['y1'] += dy
                    y = node['y1'] + py

                dy = y - py - y1
                if dy > 0:
                    node['y0'] -= dy
                    node['y1'] -= dy
                    y = node['y0']
                    for node in nodes[:-1][::-1]:
                        dy = node['y1'] + py - y
                        if dy > 0:
                            node['y0'] -= dy
                            node['y1'] -= dy
                        y = node['y0']

        initializeNodeBreadth()
        resolveCollisions()
        alpha = 1
        for _ in range(self.p.iterations):
            alpha = alpha * 0.99
            relaxRightToLeft(alpha)
            resolveCollisions()
            relaxLeftToRight(alpha)
            resolveCollisions()

    @classmethod
    def computeLinkBreadths(cls, graph):
        for node in graph['nodes']:
            node['sourceLinks'].sort(
                key=cmp_to_key(cls.ascendingTargetBreadth))
            node['targetLinks'].sort(
                key=cmp_to_key(cls.ascendingSourceBreadth))

        for node in graph['nodes']:
            y0 = y1 = node['y0']
            for link in node['sourceLinks']:
                link['y0'] = y0
                y0 += link['width']
            for link in node['targetLinks']:
                link['y1'] = y1
                y1 += link['width']
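
A standalone toy sketch (not from the source) of the node-value computation in computeNodeValues, using the same graph structure of nodes carrying 'sourceLinks' and 'targetLinks' lists; the link values are invented:

import numpy as np

link_ab = {'value': 3}
link_ac = {'value': 2}
node_a = {'sourceLinks': [link_ab, link_ac], 'targetLinks': []}
node_b = {'sourceLinks': [], 'targetLinks': [link_ab]}

for node in (node_a, node_b):
    source_val = np.sum([l['value'] for l in node['sourceLinks']])
    target_val = np.sum([l['value'] for l in node['targetLinks']])
    node['value'] = max([source_val, target_val])

print(node_a['value'], node_b['value'])   # 5 3
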
Example #16
class decimate(Operation):
    """
    Decimates any column based Element to a specified number of random
    rows if the current view defined by the x_range and y_range
    contains more than max_samples. By default the operation returns a
    DynamicMap with a RangeXY stream allowing dynamic downsampling.
    """

    dynamic = param.Boolean(default=True,
                            doc="""
       Enables dynamic processing by default.""")

    link_inputs = param.Boolean(default=True,
                                doc="""
         By default, the link_inputs parameter is set to True so that
         when applying shade, backends that support linked streams
         update RangeXY streams on the inputs of the shade operation.""")

    max_samples = param.Integer(default=5000,
                                doc="""
        Maximum number of samples to display at the same time.""")

    random_seed = param.Integer(default=42,
                                doc="""
        Seed used to initialize randomization.""")

    streams = param.List(default=[RangeXY],
                         doc="""
        List of streams that are applied if dynamic=True, allowing
        for dynamic interaction with the plot.""")

    x_range = param.NumericTuple(default=None,
                                 length=2,
                                 doc="""
       The x_range as a tuple of min and max x-value. Auto-ranges
       if set to None.""")

    y_range = param.NumericTuple(default=None,
                                 length=2,
                                 doc="""
       The y_range as a tuple of min and max y-value. Auto-ranges
       if set to None.""")

    def _process_layer(self, element, key=None):
        if not isinstance(element, Dataset):
            raise ValueError("Cannot downsample non-Dataset types.")
        if element.interface not in column_interfaces:
            element = element.clone(tuple(element.columns().values()))

        xstart, xend = self.p.x_range if self.p.x_range else element.range(0)
        ystart, yend = self.p.y_range if self.p.y_range else element.range(1)

        # Slice element to current ranges
        xdim, ydim = element.dimensions(label=True)[0:2]
        sliced = element.select(**{xdim: (xstart, xend), ydim: (ystart, yend)})

        if len(sliced) > self.p.max_samples:
            prng = np.random.RandomState(self.p.random_seed)
            return element.iloc[prng.choice(len(sliced), self.p.max_samples,
                                            False)]
        return sliced

    def _process(self, element, key=None):
        return element.map(self._process_layer, Element)
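
A hypothetical usage sketch, assuming holoviews with a plotting backend loaded; the random point cloud is invented for illustration:

import numpy as np
import holoviews as hv

points = hv.Points(np.random.randn(100000, 2))
# Static downsampling to at most 1000 rows; with dynamic=True (the default) a
# DynamicMap with a RangeXY stream is returned and re-decimates on zoom/pan.
downsampled = decimate(points, max_samples=1000, dynamic=False)
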
class ScalarModelBase(ModelConfigBase):
    aggregation_type: AggregationType = param.ClassSelector(default=AggregationType.Average, class_=AggregationType,
                                                            doc="The type of global pooling aggregation to use between"
                                                                " the encoder and the classifier.")
    loss_type: ScalarLoss = param.ClassSelector(default=ScalarLoss.BinaryCrossEntropyWithLogits, class_=ScalarLoss,
                                                instantiate=False, doc="The loss_type to use")
    image_channels: List[str] = param.List(class_=str,
                                           doc="Identifies the rows of the dataset file that contain image file paths.")
    image_file_column: Optional[str] = param.String(default=None, allow_None=True,
                                                    doc="The column that contains the path to image files.")
    expected_column_values: List[Tuple[str, str]] = \
        param.List(default=None, doc="List of tuples with column name and expected value to filter rows in the "
        f"{DATASET_CSV_FILE_NAME}",
                   allow_None=True)
    label_channels: Optional[List[str]] = \
        param.List(default=None, allow_None=True,
                   doc="Identifies the row of a dataset file that contains the label value.")
    label_value_column: str = param.String(doc="The column in the dataset file that contains the label value.")
    non_image_feature_channels: Union[List[str], Dict[str, List[str]]] = \
        ListOrDictParam(doc="Specifies the rows of a dataset file from which additional feature values should be read. "
                            "The channels can be specified as a List of channels to be used for all non-imaging "
                            "features, or as a Dict mapping features to specific channels. The helper function "
                            "`get_non_image_features_dict` is available to construct this dictionary.")
    numerical_columns: List[str] = \
        param.List(class_=str,
                   default=[],
                   doc="Specifies the columns of a dataset file from which additional numerical "
                       "feature values should be read.")
    categorical_columns: List[str] = \
        param.List(class_=str,
                   default=[],
                   doc="Specifies the columns of a dataset file from which additional "
                       "categorical feature values should be read.")

    subject_column: str = \
        param.String(default=CSV_SUBJECT_HEADER, allow_None=False,
                     doc="The name of the column that contains the patient/subject identifier. Default: 'subject'")
    channel_column: str = \
        param.String(default=CSV_CHANNEL_HEADER, allow_None=False,
                     doc="The name of the column that contains image channel information, for identifying multiple "
                         "rows belonging to the same subject. Default: 'channel'")

    add_differences_for_features: List[str] = \
        param.List(class_=str,
                   doc="When using sequence datasets, this specifies the columns in the dataset for which additional "
                       "features should be added. For all columns given here, the feature differences between index i "
                       "and index 0 (start of the sequence) are added as additional features.")
    traverse_dirs_when_loading: bool = \
        param.Boolean(doc="If true, image file names in datasets do not need to contain "
                          "the full path. Before loading, all files will be enumerated "
                          "recursively. If false, the image file name must be fully "
                          "given in the dataset file (relative to root path).")
    load_segmentation: bool = \
        param.Boolean(default=False, doc="If True the segmentations from hdf5 files will be loaded. If False, only"
                                         "the images will be loaded.")
    center_crop_size: Optional[TupleInt3] = \
        param.NumericTuple(default=None, allow_None=True, length=3,
                           doc="If given, the loaded images and segmentations will be cropped to the given size."
                               "Size is given in pixels. The crop will be taken from the center of the image. "
                               "Crop size should be in the form (crop_z, crop_y, crop_x)."
                               "If your dataset has 2D images, center crop should have singleton first dimension,"
                               "i.e. (1, ) + (crop_y, crop_x)")

    image_size: Optional[TupleInt3] = \
        param.NumericTuple(default=None, allow_None=True, length=3,
                           doc="If given, images will be resized to these dimensions immediately after loading from"
                               "file."
                               "Image size should be in the form (size_z, size_y, size_x)."
                               "If your dataset has 2D images, image size should have singleton first dimension,"
                               "i.e. (1, ) + (size_y, size_x)")

    categorical_feature_encoder: Optional[OneHotEncoderBase] = param.ClassSelector(OneHotEncoderBase,
                                                                                   allow_None=True,
                                                                                   instantiate=False,
                                                                                   doc="The one hot encoding scheme "
                                                                                       "for categorical data if "
                                                                                       "required")
    num_dataset_reader_workers: int = param.Integer(default=0, bounds=(-1, None),
                                                    doc="Number of workers (processes) to use for dataset "
                                                        "reading. Default is 0 which means only the main thread "
                                                        "will be used. Set to -1 for maximum parallelism level.")

    ensemble_aggregation_type: EnsembleAggregationType = param.ClassSelector(default=EnsembleAggregationType.Average,
                                                                             class_=EnsembleAggregationType,
                                                                             instantiate=False,
                                                                             doc="The aggregation method to use when"
                                                                                 "testing ensemble models.")
    number_of_cross_validation_splits_per_fold: int = param.Integer(0, bounds=(0, None),
                                                                    doc="Number of cross validation splits for k-fold "
                                                                        "cross validation within a fold.")

    cross_validation_sub_fold_split_index: int = param.Integer(DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, bounds=(-1, None),
                                                               doc="The index of the cross validation fold this model "
                                                                   "is associated with when performing k-fold cross "
                                                                   "validation")
    dataset_stats_hook: Optional[Callable[[Dict[ModelExecutionMode, Any]], None]] = \
        param.Callable(default=None,
                       allow_None=True,
                       doc="A hook that is called with a dictionary that maps from train/val/test to the actual "
                           "dataset, to do customized statistics.")

    def __init__(self, num_dataset_reader_workers: int = 0, **params: Any) -> None:
        super().__init__(**params)
        self._model_category = ModelCategory.Regression \
            if self.is_regression_model else ModelCategory.Classification
        if not self.is_offline_run:
            self.num_dataset_reader_workers = 0
            logging.info("dataset reader parallelization is supported only locally, setting "
                         "num_dataset_reader_workers to 0 as this is an AML run.")
        else:
            self.num_dataset_reader_workers = num_dataset_reader_workers

    def validate(self) -> None:
        if not self.perform_cross_validation and self.perform_sub_fold_cross_validation:
            raise ValueError("Cannot perform sub fold cross validation if not running in cross validation mode"
                             " found, please set number_of_cross_validation_splits >= 2")
        if self.number_of_cross_validation_splits_per_fold == 1:
            raise ValueError("At least two sub folds must be required when performing sub fold cross validation,"
                             " but number_of_cross_validation_splits_per_fold was set to 1")

    @property
    def is_classification_model(self) -> bool:
        """
        Returns whether the model is a classification model
        """
        return self.loss_type.is_classification_loss()

    @property
    def is_regression_model(self) -> bool:
        """
        Returns whether the model is a regression model
        """
        return self.loss_type.is_regression_loss()

    @property
    def is_non_imaging_model(self) -> bool:
        """
        Returns whether the model uses non image features only
        """
        return len(self.image_channels) == 0

    @property
    def perform_sub_fold_cross_validation(self) -> bool:
        """
        True if sub fold cross validation will be performed as part of the training procedure.
        :return:
        """
        return self.number_of_cross_validation_splits_per_fold > 1

    def get_total_number_of_non_imaging_features(self) -> int:
        """Returns the total number of non imaging features expected in the input"""
        return self.get_total_number_of_numerical_non_imaging_features() + \
               self.get_total_number_of_categorical_non_imaging_features()

    def get_total_number_of_numerical_non_imaging_features(self) -> int:
        """Returns the total number of numerical non imaging features expected in the input"""
        if len(self.numerical_columns) == 0:
            return 0
        else:
            features_channels_dict = self.get_non_image_feature_channels_dict()
            return sum([len(features_channels_dict[col]) for col in self.numerical_columns])

    def get_total_number_of_categorical_non_imaging_features(self) -> int:
        """
        Returns the total number of categorical non imaging features expected in the input, e.g. for the
        categorical channels A and B the total number would be 2 (feature channels A and B) * 4
        (the number of bits required to one-hot encode a single channel):
        A | True, No   => [1, 0, 0, 1]
        B | False, Yes => [0, 1, 1, 0]
        """
        if self.categorical_columns and not self.categorical_feature_encoder:
            raise ValueError(f"Found {len(self.categorical_columns)} categorical columns, but "
                             f"one_hot_encoder is None. Either set one_hot_encoder explicitly "
                             f"or make sure property is accessed after the dataset dataframe has been loaded.")
        elif not self.categorical_feature_encoder:
            return 0
        else:
            features_channels_dict = self.get_non_image_feature_channels_dict()
            if self.categorical_columns is None:
                return 0
            return sum([len(features_channels_dict[col]) * self.categorical_feature_encoder.get_feature_length(col)
                        for col in self.categorical_columns])

    def get_non_image_feature_channels_dict(self) -> Dict:
        """
        Converts the provided non_image_feature_channels from a List to a Dictionary mapping each feature to its
        channels, assigning the channels stored under the default key to every feature that is not explicitly listed.
        The conversion is done here in one place to avoid repeating it several times throughout the code.
        """
        if not self.non_image_feature_channels:
            return {}

        if isinstance(self.non_image_feature_channels, List):
            non_image_feature_channels_dict = {KEY_FOR_DEFAULT_CHANNEL: self.non_image_feature_channels}
        else:
            non_image_feature_channels_dict = self.non_image_feature_channels.copy()
        all_non_image_features = self.numerical_columns.copy()
        if self.categorical_columns:
            all_non_image_features.extend(self.categorical_columns)

        # Map each feature to its channels
        for column in all_non_image_features:
            if column not in self.non_image_feature_channels:
                try:
                    non_image_feature_channels_dict[column] = non_image_feature_channels_dict[KEY_FOR_DEFAULT_CHANNEL]
                except KeyError:
                    raise KeyError(f"The column {column} is not present in the non_image_features dictionary and the"
                                   f"default key {KEY_FOR_DEFAULT_CHANNEL} is missing.")
        # Delete default key
        non_image_feature_channels_dict.pop(KEY_FOR_DEFAULT_CHANNEL, None)
        return non_image_feature_channels_dict

    def filter_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Filters the dataframe based on the expected column values given in expected_column_values.
        :param df: the input dataframe
        :return: the filtered dataframe
        """

        def _dataframe_stats(df: pd.DataFrame) -> str:
            """
            Creates a human readable string that contains the number of rows and the number of unique subjects.
            :return: A string like "12 rows, 5 unique subjects. "
            """
            total_rows = len(df)
            if self.subject_column in df:
                unique_subjects = len(df[self.subject_column].unique())
                message = f"{unique_subjects} unique subjects"
            else:
                message = f"subject column '{self.subject_column}' not present"
            return f"{total_rows} rows, {message}. "

        logging.info(f"Before filtering: {_dataframe_stats(df)}")
        if self.expected_column_values is not None:
            for column_name, expected_value in self.expected_column_values:
                df = df[df[column_name] == expected_value]
                logging.info(f"After filtering for 'column[{column_name}] == {expected_value}': {_dataframe_stats(df)}")
        logging.info(f"Final: {_dataframe_stats(df)}")
        return df

    def get_label_transform(self) -> Union[Callable, List[Callable]]:
        """Return a transformation or list of transformation
        to apply to the labels.
        """
        return LabelTransformation.identity

    def read_dataset_into_dataframe_and_pre_process(self) -> None:
        assert self.local_dataset is not None
        file_path = self.local_dataset / DATASET_CSV_FILE_NAME
        self.dataset_data_frame = pd.read_csv(file_path, dtype=str, low_memory=False)
        self.pre_process_dataset_dataframe()

    def pre_process_dataset_dataframe(self) -> None:
        # some empty values on numeric columns get converted to nan but we want ''
        assert self.dataset_data_frame is not None
        df = self.dataset_data_frame.fillna('')
        self.dataset_data_frame = self.filter_dataframe(df)
        # update the one-hot encoder based on this dataframe
        if self.categorical_columns:
            from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
            self.categorical_feature_encoder = CategoricalToOneHotEncoder.create_from_dataframe(
                dataframe=self.dataset_data_frame,
                columns=self.categorical_columns
            )

    def create_torch_datasets(self, dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
        from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
        image_transforms = self.get_image_sample_transforms()
        train = ScalarDataset(args=self, data_frame=dataset_splits.train,
                              name="training", sample_transforms=image_transforms.train)  # type: ignore
        val = ScalarDataset(args=self, data_frame=dataset_splits.val, feature_statistics=train.feature_statistics,
                            name="validation", sample_transforms=image_transforms.val)  # type: ignore
        test = ScalarDataset(args=self, data_frame=dataset_splits.test, feature_statistics=train.feature_statistics,
                             name="test", sample_transforms=image_transforms.test)  # type: ignore

        return {
            ModelExecutionMode.TRAIN: train,
            ModelExecutionMode.VAL: val,
            ModelExecutionMode.TEST: test
        }

    def create_and_set_torch_datasets(self, for_training: bool = True, for_inference: bool = True) -> None:
        """
        Creates and sets torch datasets for all model execution modes, and stores them in the self._datasets field.
        It also calls the hook to compute statistics for the train/val/test datasets.
        """
        # For models other than segmentation models, it is easier to create both training and inference datasets
        # in one go, ignoring the arguments.
        if self._datasets_for_training is None and self._datasets_for_inference is None:
            datasets = self.create_torch_datasets(self.get_dataset_splits())
            self._datasets_for_training = {mode: datasets[mode]
                                           for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]}
            self._datasets_for_inference = datasets
            for split, dataset in datasets.items():
                logging.info(f"{split.value}: {len(dataset)} subjects. Detailed status: {dataset.status}")
            if self.dataset_stats_hook:
                try:
                    self.dataset_stats_hook(datasets)
                except Exception as ex:
                    print_exception(ex, message="Error while calling the hook for computing dataset statistics.")

    def get_training_class_counts(self) -> Dict:
        if self._datasets_for_training is None:
            self.create_and_set_torch_datasets(for_inference=False)
        assert self._datasets_for_training is not None  # for mypy
        return self._datasets_for_training[ModelExecutionMode.TRAIN].get_class_counts()

    def create_model(self) -> Any:
        pass

    def get_post_loss_logits_normalization_function(self) -> Callable:
        """
        Post loss normalization function to apply to the logits produced by the model.
        :return:
        """
        import torch
        if self.loss_type.is_classification_loss():
            return torch.nn.Sigmoid()
        elif self.loss_type.is_regression_loss():
            return torch.nn.Identity()  # type: ignore
        else:
            raise NotImplementedError("get_post_loss_logits_normalization_function not implemented for "
                                      f"loss type: {self.loss_type}")

    def get_parameter_search_hyperdrive_config(self, estimator: Estimator) -> HyperDriveConfig:
        return super().get_parameter_search_hyperdrive_config(estimator)

    def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
        return super().get_model_train_test_dataset_splits(dataset_df)

    def get_image_sample_transforms(self) -> ModelTransformsPerExecutionMode:
        """
        Get transforms to perform on samples for each model execution mode.
        By default, no transformation is performed.
        For data augmentation, specify a Compose3D for the training execution mode.
        """
        return ModelTransformsPerExecutionMode()

    def get_cross_validation_dataset_splits(self, dataset_split: DatasetSplits) -> DatasetSplits:
        """
        When running cross validation, this method returns the dataset split that should be used for the
        currently executed cross validation split. If sub fold cross validation is required,
        then the training set corresponding to the currently executed cross validation split is further split
        into a child fold, which has the same validation set as the parent fold.

        :param dataset_split: The full dataset, split into training, validation and test section.
        :return: The dataset split with training and validation sections shuffled according to the current
        cross validation index.
        """
        split_for_current_fold = super().get_cross_validation_dataset_splits(dataset_split)
        if self.perform_sub_fold_cross_validation:
            # create a sub fold based on the training set and set the validation set
            # as the validation set of the split.
            val_split = split_for_current_fold.val
            split_for_current_fold.val = pd.DataFrame()
            sub_fold_split = split_for_current_fold.get_k_fold_cross_validation_splits(
                self.number_of_cross_validation_splits_per_fold)[self.cross_validation_sub_fold_split_index]
            sub_fold_split.val = val_split
            return sub_fold_split
        else:
            return split_for_current_fold

    def get_effective_random_seed(self) -> int:
        seed = super().get_effective_random_seed()
        if self.perform_sub_fold_cross_validation:
            # Offset the random seed based on the parent cross validation split index and the
            # sub fold index, so that each sub fold starts from a different initial random state.
            seed += (self.cross_validation_split_index * self.number_of_cross_validation_splits_per_fold) \
                   + self.cross_validation_sub_fold_split_index
        return seed

    def get_total_number_of_cross_validation_runs(self) -> int:
        if self.perform_sub_fold_cross_validation:
            return self.number_of_cross_validation_splits * self.number_of_cross_validation_splits_per_fold
        else:
            return super().get_total_number_of_cross_validation_runs()

    def get_cross_validation_hyperdrive_sampler(self) -> GridParameterSampling:
        if self.perform_sub_fold_cross_validation:
            return GridParameterSampling(parameter_space={
                CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY: choice(list(range(self.number_of_cross_validation_splits))),
                CROSS_VALIDATION_SUB_FOLD_SPLIT_INDEX_TAG_KEY: choice(list(range(
                    self.number_of_cross_validation_splits_per_fold))),
            })
        else:
            return super().get_cross_validation_hyperdrive_sampler()

    def should_wait_for_other_cross_val_child_runs(self) -> bool:
        """
        Returns True if the current run is an online run and is the 0th cross validation split and
        0th sub fold split if sub fold cross validation is being performed.
        In this case, this will be the run that will wait for all other child runs to finish in order
        to aggregate their results.
        :return:
        """
        should_wait_child = True
        if self.perform_sub_fold_cross_validation and self.cross_validation_sub_fold_split_index != 0:
            should_wait_child = False
        return (not self.is_offline_run) and self.cross_validation_split_index == 0 and should_wait_child
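
The channel-mapping and feature-count logic above can be illustrated with a small standalone sketch; this is not part of the example, and all column and channel names below are made up:

# Hedged sketch (not from the example above): how the channel mapping and the
# feature counts are expected to combine, using made-up names.
KEY_FOR_DEFAULT_CHANNEL = "DEFAULT"

non_image_feature_channels = {"DEFAULT": ["week0", "week1"], "age": ["week0"]}
numerical_columns = ["age", "weight"]
categorical_columns = ["smoker"]

channels = dict(non_image_feature_channels)
for column in numerical_columns + categorical_columns:
    # Features without an explicit entry fall back to the channels of the default key.
    channels.setdefault(column, channels[KEY_FOR_DEFAULT_CHANNEL])
channels.pop(KEY_FOR_DEFAULT_CHANNEL, None)

# Assuming a one-hot encoding length of 2 for "smoker", the totals mirror the
# get_total_number_of_*_non_imaging_features helpers:
numerical_total = sum(len(channels[c]) for c in numerical_columns)          # 1 + 2 = 3
categorical_total = sum(len(channels[c]) * 2 for c in categorical_columns)  # 2 * 2 = 4
print(channels, numerical_total, categorical_total)
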
예제 #18
0
class Sheet(EventProcessor,
            SheetCoordinateSystem):  # pylint: disable-msg=W0223
    """
    The generic base class for neural sheets.

    See SheetCoordinateSystem for how Sheet represents space, and
    EventProcessor for how Sheet handles time.

    output_fns are functions that take an activity matrix and produce
    an identically shaped output matrix. The default is having no
    output_fns.
    """
    __abstract = True

    nominal_bounds = BoundingRegionParameter(BoundingBox(radius=0.5),
                                             constant=True,
                                             doc="""
            User-specified BoundingBox of the Sheet coordinate area
            covered by this Sheet.  The left and right bounds--if
            specified--will always be observed, but the top and bottom
            bounds may be adjusted to ensure the density in the y
            direction is the same as the density in the x direction.
            In such a case, the top and bottom bounds are adjusted
            so that the center y point remains the same, and each
            bound is as close as possible to its specified value. The
            actual value of this Parameter is not adjusted, but the
            true bounds may be found from the 'bounds' attribute
            of this object.
            """)

    nominal_density = param.Number(default=10,
                                   constant=True,
                                   doc="""
            User-specified number of processing units per 1.0 distance
            horizontally or vertically in Sheet coordinates. The actual
            number may be different because of discretization; the matrix
            needs to tile the plane exactly, and for that to work the
            density might need to be adjusted.  For instance, an area of 3x2
            cannot have a density of 2 in each direction. The true density
            may be obtained from either the xdensity or ydensity attribute
            (since these are identical for a Sheet).
            """)

    plastic = param.Boolean(True,
                            doc="""
            Setting this to False tells the Sheet not to change its
            permanent state (e.g. any connection weights) based on
            incoming events.
            """)

    precedence = param.Number(default=0.1,
                              softbounds=(0.0, 1.0),
                              doc="""
            Allows a sorting order for Sheets, e.g. in the GUI.""")

    row_precedence = param.Number(default=0.5,
                                  softbounds=(0.0, 1.0),
                                  doc="""
            Allows grouping of Sheets before sorting precedence is
            applied, e.g. for two-dimensional plots in the GUI.""")

    layout_location = param.NumericTuple(default=(-1, -1),
                                         precedence=-1,
                                         doc="""
            Location for this Sheet in an arbitrary pixel-based space
            in which Sheets can be laid out for visualization.""")

    output_fns = param.HookList(
        default=[],
        class_=TransferFn,
        doc="Output function(s) to apply (if apply_output_fns is true) to this Sheet's activity.")

    apply_output_fns = param.Boolean(
        default=True,
        doc="Whether to apply the output_fn after computing an Activity matrix."
    )

    properties = param.Dict(default={},
                            doc="""
       A dictionary of property values associated with the Sheet
       object.  For instance, the dictionary:

       {'polarity':'ON', 'eye':'Left'}

       could be used to indicate a left, LGN Sheet with ON-surround
       receptive fields.""")

    def _get_density(self):
        return self.xdensity

    density = property(_get_density,
                       doc="""The sheet's true density (i.e. the
        xdensity, which is equal to the ydensity for a Sheet.)""")

    def __init__(self, **params):
        """
        Initialize this object as an EventProcessor, then also as
        a SheetCoordinateSystem with equal xdensity and ydensity.

        views is a Layout, which stores associated measurements,
        i.e. representations of the sheet for use by analysis or plotting
        code.
        """
        EventProcessor.__init__(self, **params)

        # Initialize this object as a SheetCoordinateSystem, with
        # the same density along y as along x.
        SheetCoordinateSystem.__init__(self, self.nominal_bounds,
                                       self.nominal_density)

        n_units = round((self.lbrt[2] - self.lbrt[0]) * self.xdensity, 0)
        if n_units < 1:
            raise ValueError(
                "Sheet bounds and density must be specified such that the "
                "sheet has at least one unit in each direction; "
                + self.name + " does not.")

        # setup the activity matrix
        self.activity = zeros(self.shape, activity_type)

        # For non-plastic inputs
        self.__saved_activity = []
        self._plasticity_setting_stack = []

        self.views = Layout()
        self.views.Maps = Layout()
        self.views.Curves = Layout()

    ### JABALERT: This should be deleted now that sheet_views is public
    ### JC: shouldn't we keep that, or at least write a function in
    ### utils that deletes a value in a dictionary without returning an
    ### error if the key is not in the dict?  I leave for the moment,
    ### and have to ask Jim to advise.
    def release_sheet_view(self, view_name):
        """
        Delete the dictionary entry with key 'view_name' to save
        memory.
        """
        if view_name in self.views.Maps:
            self.views.Maps[view_name] = None

    # CB: what to call this? sheetcoords()? sheetcoords_of_grid()? idxsheetcoords()?
    def sheetcoords_of_idx_grid(self):
        """
        Return an array of x-coordinates and an array of y-coordinates
        corresponding to the activity matrix of the sheet.
        """
        nrows, ncols = self.activity.shape

        C, R = meshgrid(arange(ncols), arange(nrows))

        X, Y = self.matrixidx2sheet(R, C)
        return X, Y

    # CB: check whether we need this function any more.
    def row_col_sheetcoords(self):
        """
        Return an array of Y-coordinates corresponding to the rows of
        the activity matrix of the sheet, and an array of
        X-coordinates corresponding to the columns.
        """
        # The row and column centers are returned in matrix (not
        # sheet) order (hence the reversals below).
        nrows, ncols = self.activity.shape
        return self.matrixidx2sheet(arange(nrows - 1, -1, -1),
                                    arange(ncols))[::-1]

    # CBALERT: to be removed once other code uses
    # row_col_sheetcoords() or sheetcoords_of_idx_grid().
    def sheet_rows(self):
        return self.row_col_sheetcoords()[0]

    def sheet_cols(self):
        return self.row_col_sheetcoords()[1]

    # CEBALERT: haven't really thought about what to put in this. The
    # way it is now, subclasses could make a super.activate() call to
    # avoid repeating some stuff.
    def activate(self):
        """
        Collect activity from each projection, combine it to calculate
        the activity for this sheet, and send the result out.

        Subclasses will need to override this method to whatever it
        means to calculate activity in that subclass.
        """
        if self.apply_output_fns:
            for of in self.output_fns:
                of(self.activity)

        self.send_output(src_port='Activity', data=self.activity)

    def state_push(self):
        """
        Save the current state of this sheet to an internal stack.

        This method is used by operations that need to test the
        response of the sheet without permanently altering its state,
        e.g. for measuring maps or probing the current behavior
        non-invasively.  By default, only the activity pattern of this
        sheet is saved, but subclasses should add saving for any
        additional state that they maintain, or strange bugs are
        likely to occur.  The state can be restored using state_pop().

        Note that Sheets that do learning need not save the
        values of all connection weights, if any, because
        plasticity can be turned off explicitly.  Thus this method
        is intended only for shorter-term state.
        """
        self.__saved_activity.append(array(self.activity))
        EventProcessor.state_push(self)
        for of in self.output_fns:
            if hasattr(of, 'state_push'):
                of.state_push()

    def state_pop(self):
        """
        Pop the most recently saved state off the stack.

        See state_push() for more details.
        """
        self.activity = self.__saved_activity.pop()
        EventProcessor.state_pop(self)
        for of in self.output_fns:
            if hasattr(of, 'state_pop'):
                of.state_pop()

    def activity_len(self):
        """Return the number of items that have been saved by state_push()."""
        return len(self.__saved_activity)

    def override_plasticity_state(self, new_plasticity_state):
        """
        Temporarily override plasticity of medium and long term internal state.

        This function should be implemented by all subclasses so that
        it preserves the ability of the Sheet to compute activity,
        i.e. to operate over a short time scale, while preventing any
        lasting changes to the state (if new_plasticity_state=False).

        Any operation that does not have any lasting state, such as
        those affecting only the current activity level, should not
        be affected by this call.

        By default, simply saves a copy of the plastic flag to an
        internal stack (so that it can be restored by
        restore_plasticity_state()), and then sets plastic to
        new_plasticity_state.
        """
        self._plasticity_setting_stack.append(self.plastic)
        self.plastic = new_plasticity_state

    def restore_plasticity_state(self):
        """
        Restores plasticity of medium and long term internal state after
        an override_plasticity_state call.

        This function should be implemented by all subclasses to
        remove the effect of the most recent override_plasticity_state call,
        i.e. to restore plasticity of any type that was overridden.
        """
        self.plastic = self._plasticity_setting_stack.pop()

    def n_bytes(self):
        """
        Return a lower bound for the memory taken by this sheet, in bytes.

        Typically, this number will include the activity array and any
        similar arrays, plus any other significant data owned (in some
        sense) by this Sheet.  It will not usually include memory
        taken by the Python dictionary or various "housekeeping"
        attributes, which usually contribute only a small amount to
        the memory requirements.

        Subclasses should reimplement this method if they store a
        significant amount of data other than in the activity array.
        """
        return self.activity.nbytes

    def __getitem__(self, coords):
        metadata = AttrDict(precedence=self.precedence,
                            row_precedence=self.row_precedence,
                            timestamp=self.simulation.time())

        image = Image(self.activity.copy(),
                      self.bounds,
                      label=self.name,
                      group='Activity')[coords]
        image.metadata = metadata
        return image
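
The state_push/state_pop pattern above can be mimicked outside of Topographica; the following self-contained sketch (plain NumPy, not Sheet itself) shows the save-perturb-restore cycle used when probing a sheet non-invasively:

# Hedged, standalone analogue of Sheet.state_push/state_pop (illustrative only).
import numpy as np

class TinySheet:
    """Toy stand-in for Sheet that only keeps an activity array and a state stack."""
    def __init__(self, shape=(3, 3)):
        self.activity = np.zeros(shape)
        self._saved_activity = []

    def state_push(self):
        # Snapshot the current activity, as Sheet.state_push does.
        self._saved_activity.append(self.activity.copy())

    def state_pop(self):
        # Restore the most recently saved activity.
        self.activity = self._saved_activity.pop()

sheet = TinySheet()
sheet.state_push()       # save state before probing
sheet.activity += 1.0    # simulate a measurement response
sheet.state_pop()        # restore: the probe leaves no lasting change
assert sheet.activity.sum() == 0.0
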
예제 #19
0
File: stats.py Project: zzwei1/holoviews
class bivariate_kde(Operation):
    """
    Computes a 2D kernel density estimate (KDE) of the first two
    dimensions in the input data. Kernel density estimation is a
    non-parametric way to estimate the probability density function of
    a random variable.

    The KDE works by placing a 2D Gaussian kernel at each sample with
    the supplied bandwidth. These kernels are then summed to produce
    the density estimate. By default a good bandwidth is determined
    using the bw_method but it may be overridden by an explicit value.
    """

    contours = param.Boolean(default=True,
                             doc="""
        Whether to compute contours from the KDE, determines whether to
        return an Image or Contours/Polygons.""")

    bw_method = param.ObjectSelector(default='scott',
                                     objects=['scott', 'silverman'],
                                     doc="""
        Method of automatically determining KDE bandwidth""")

    bandwidth = param.Number(default=None,
                             doc="""
        Allows supplying explicit bandwidth value rather than relying
        on scott or silverman method.""")

    cut = param.Number(default=3,
                       doc="""
        Draw the estimate to cut * bw from the extreme data points.""")

    filled = param.Boolean(default=False,
                           doc="""
        Controls whether to return filled or unfilled contours.""")

    n_samples = param.Integer(default=100,
                              doc="""
        Number of samples to compute the KDE over.""")

    x_range = param.NumericTuple(default=None,
                                 length=2,
                                 doc="""
       The x_range as a tuple of min and max x-value. Auto-ranges
       if set to None.""")

    y_range = param.NumericTuple(default=None,
                                 length=2,
                                 doc="""
       The y_range as a tuple of min and max y-value. Auto-ranges
       if set to None.""")

    def _process(self, element, key=None):
        try:
            from scipy import stats
        except ImportError:
            raise ImportError('%s operation requires SciPy to be installed.' %
                              type(self).__name__)

        if len(element.dimensions()) < 2:
            raise ValueError("bivariate_kde can only be computed on elements "
                             "declaring at least two dimensions.")
        xdim, ydim = element.dimensions()[:2]
        params = {}
        if isinstance(element, Bivariate):
            if element.group != type(element).__name__:
                params['group'] = element.group
            params['label'] = element.label
            vdim = element.vdims[0]
        else:
            vdim = 'Density'

        data = element.array([0, 1]).T
        xmin, xmax = self.p.x_range or element.range(0)
        ymin, ymax = self.p.y_range or element.range(1)
        if any(not np.isfinite(v) for v in (xmin, xmax)):
            xmin, xmax = -0.5, 0.5
        elif xmin == xmax:
            xmin, xmax = xmin - 0.5, xmax + 0.5
        if any(not np.isfinite(v) for v in (ymin, ymax)):
            ymin, ymax = -0.5, 0.5
        elif ymin == ymax:
            ymin, ymax = ymin - 0.5, ymax + 0.5

        data = (data[:, np.isfinite(data).min(axis=0)]
                if data.shape[1] > 1 else np.empty((2, 0)))
        if data.shape[1] > 1:
            kde = stats.gaussian_kde(data)
            if self.p.bandwidth:
                kde.set_bandwidth(self.p.bandwidth)
            bw = kde.scotts_factor() * data.std(ddof=1)
            if self.p.x_range:
                xs = np.linspace(xmin, xmax, self.p.n_samples)
            else:
                xs = _kde_support((xmin, xmax), bw, self.p.n_samples,
                                  self.p.cut, xdim.range)
            if self.p.y_range:
                ys = np.linspace(ymin, ymax, self.p.n_samples)
            else:
                ys = _kde_support((ymin, ymax), bw, self.p.n_samples,
                                  self.p.cut, ydim.range)
            xx, yy = cartesian_product([xs, ys], False)
            positions = np.vstack([xx.ravel(), yy.ravel()])
            f = np.reshape(kde(positions).T, xx.shape)
        elif self.p.contours:
            eltype = Polygons if self.p.filled else Contours
            return eltype([], kdims=[xdim, ydim], vdims=[vdim])
        else:
            xs = np.linspace(xmin, xmax, self.p.n_samples)
            ys = np.linspace(ymin, ymax, self.p.n_samples)
            f = np.zeros((self.p.n_samples, self.p.n_samples))

        img = Image((xs, ys, f.T),
                    kdims=element.dimensions()[:2],
                    vdims=[vdim],
                    **params)
        if self.p.contours:
            cntr = contours(img, filled=self.p.filled)
            return cntr.clone(cntr.data[1:], **params)
        return img
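
A hedged usage sketch of the operation above, assuming holoviews, numpy, scipy (and matplotlib for the contour path) are installed; parameter names follow the declarations in the class:

# Hedged usage sketch, not part of the original example.
import numpy as np
import holoviews as hv
from holoviews.operation.stats import bivariate_kde

points = hv.Points(np.random.randn(500, 2))
# contours=False returns an Image of the density estimate.
density_img = bivariate_kde(points, contours=False, n_samples=100, bw_method='scott')
# contours=True (the default) returns Contours, or Polygons when filled=True.
filled_contours = bivariate_kde(points, contours=True, filled=True, cut=3)
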
예제 #20
0
class shade(LinkableOperation):
    """
    shade applies a normalization function followed by colormapping to
    an Image or NdOverlay of Images, returning an RGB Element.
    The data must be in the form of a 2D or 3D DataArray, but NdOverlays
    of 2D Images will be automatically converted to a 3D array.

    In the 2D case data is normalized and colormapped, while a 3D
    array representing categorical aggregates will be supplied a color
    key for each category. The colormap (cmap) for the 2D case may be
    supplied as an Iterable or a Callable.
    """

    cmap = param.ClassSelector(class_=(Iterable, Callable, dict), doc="""
        Iterable or callable which returns colors as hex colors
        or web color names (as defined by datashader), to be used
        for the colormap of single-layer datashader output.
        Callable type must allow mapping colors between 0 and 1.
        The default value of None reverts to Datashader's default
        colormap.""")

    color_key = param.ClassSelector(class_=(Iterable, Callable, dict), doc="""
        Iterable or callable that returns colors as hex colors, to
        be used for the color key of categorical datashader output.
        Callable type must allow mapping colors for supplied values
        between 0 and 1.""")

    normalization = param.ClassSelector(default='eq_hist',
                                        class_=(basestring, Callable),
                                        doc="""
        The normalization operation applied before colormapping.
        Valid options include 'linear', 'log', 'eq_hist', 'cbrt',
        and any valid transfer function that accepts data, mask, nbins
        arguments.""")

    clims = param.NumericTuple(default=None, length=2, doc="""
        Min and max data values to use for colormap interpolation, when
        wishing to override autoranging.
        """)

    min_alpha = param.Number(default=40, doc="""
        The minimum alpha value to use for non-empty pixels when doing
        colormapping, in [0, 255].  Use a higher value to avoid
        undersaturation, i.e. poorly visible low-value datapoints, at
        the expense of the overall dynamic range.""")

    @classmethod
    def concatenate(cls, overlay):
        """
        Concatenates an NdOverlay of Image types into a single 3D
        xarray Dataset.
        """
        if not isinstance(overlay, NdOverlay):
            raise ValueError('Only NdOverlays can be concatenated')
        xarr = xr.concat([v.data.transpose() for v in overlay.values()],
                         pd.Index(overlay.keys(), name=overlay.kdims[0].name))
        params = dict(get_param_values(overlay.last),
                      vdims=overlay.last.vdims,
                      kdims=overlay.kdims+overlay.last.kdims)
        return Dataset(xarr.transpose(), datatype=['xarray'], **params)


    @classmethod
    def uint32_to_uint8(cls, img):
        """
        Cast uint32 RGB image to 4 uint8 channels.
        """
        return np.flipud(img.view(dtype=np.uint8).reshape(img.shape + (4,)))


    @classmethod
    def rgb2hex(cls, rgb):
        """
        Convert RGB(A) tuple to hex.
        """
        if len(rgb) > 3:
            rgb = rgb[:-1]
        return "#{0:02x}{1:02x}{2:02x}".format(*(int(v*255) for v in rgb))


    @classmethod
    def to_xarray(cls, element):
        if issubclass(element.interface, XArrayInterface):
            return element
        data = tuple(element.dimension_values(kd, expanded=False)
                     for kd in element.kdims)
        data += tuple(element.dimension_values(vd, flat=False)
                      for vd in element.vdims)
        dtypes = [dt for dt in element.datatype if dt != 'xarray']
        return element.clone(data, datatype=['xarray']+dtypes,
                             bounds=element.bounds,
                             xdensity=element.xdensity,
                             ydensity=element.ydensity)


    def _process(self, element, key=None):
        element = element.map(self.to_xarray, Image)
        if isinstance(element, NdOverlay):
            bounds = element.last.bounds
            xdensity = element.last.xdensity
            ydensity = element.last.ydensity
            element = self.concatenate(element)
        elif isinstance(element, Overlay):
            return element.map(self._process, [Element])
        else:
            xdensity = element.xdensity
            ydensity = element.ydensity
            bounds = element.bounds

        vdim = element.vdims[0].name
        array = element.data[vdim]
        kdims = element.kdims

        # Compute shading options depending on whether
        # it is a categorical or regular aggregate
        shade_opts = dict(how=self.p.normalization, min_alpha=self.p.min_alpha)
        if element.ndims > 2:
            kdims = element.kdims[1:]
            categories = array.shape[-1]
            if not self.p.color_key:
                pass
            elif isinstance(self.p.color_key, dict):
                shade_opts['color_key'] = self.p.color_key
            elif isinstance(self.p.color_key, Iterable):
                shade_opts['color_key'] = [c for i, c in
                                           zip(range(categories), self.p.color_key)]
            else:
                colors = [self.p.color_key(s) for s in np.linspace(0, 1, categories)]
                shade_opts['color_key'] = map(self.rgb2hex, colors)
        elif not self.p.cmap:
            pass
        elif isinstance(self.p.cmap, Callable):
            colors = [self.p.cmap(s) for s in np.linspace(0, 1, 256)]
            shade_opts['cmap'] = map(self.rgb2hex, colors)
        else:
            shade_opts['cmap'] = self.p.cmap

        if self.p.clims:
            shade_opts['span'] = self.p.clims
        elif ds_version > '0.5.0' and self.p.normalization != 'eq_hist':
            shade_opts['span'] = element.range(vdim)

        for d in kdims:
            if array[d.name].dtype.kind == 'M':
                array[d.name] = array[d.name].astype('datetime64[us]').astype('int64')

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', r'invalid value encountered in true_divide')
            if np.isnan(array.data).all():
                arr = np.zeros(array.data.shape, dtype=np.uint32)
                img = array.copy()
                img.data = arr
            else:
                img = tf.shade(array, **shade_opts)
        params = dict(get_param_values(element), kdims=kdims,
                      bounds=bounds, vdims=RGB.vdims[:],
                      xdensity=xdensity, ydensity=ydensity)
        return RGB(self.uint32_to_uint8(img.data), **params)
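
A hedged usage sketch of shade, assuming holoviews and datashader are installed; rasterize is used here only to produce the aggregate Image that shade colormaps:

# Hedged usage sketch, not part of the original example.
import numpy as np
import holoviews as hv
from holoviews.operation.datashader import rasterize, shade

points = hv.Points(np.random.randn(10000, 2))
agg = rasterize(points, width=200, height=200)      # 2D aggregate Image
rgb = shade(agg, cmap=['lightblue', 'darkblue'], normalization='log', min_alpha=40)
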
예제 #21
0
class ViolinPlot(BoxWhiskerPlot):

    bandwidth = param.Number(default=None,
                             doc="""
        Allows supplying explicit bandwidth value rather than relying
        on scott or silverman method.""")

    clip = param.NumericTuple(default=None,
                              length=2,
                              doc="""
        A tuple of a lower and upper bound to clip the violin at.""")

    cut = param.Number(default=5,
                       doc="""
        Draw the estimate to cut * bw from the extreme data points.""")

    inner = param.ObjectSelector(objects=['box', 'quartiles', 'stick', None],
                                 default='box',
                                 doc="""
        Inner visual indicator for distribution values:

          * box - A small box plot
          * stick - Lines indicating each sample value
          * quartiles - Indicates first, second and third quartiles
        """)

    violin_width = param.Number(default=0.8,
                                doc="""
       Relative width of the violin""")

    # Map each glyph to a style group
    _style_groups = {
        'patch': 'violin',
        'segment': 'stats',
        'vbar': 'box',
        'scatter': 'median',
        'hbar': 'box'
    }

    _draw_order = ['patch', 'segment', 'vbar', 'hbar', 'circle', 'scatter']

    style_opts = ([
        glyph + p for p in fill_properties + line_properties
        for glyph in ('violin_', 'box_')
    ] + ['stats_' + p for p in line_properties] + [
        '_'.join([glyph, p]) for p in ('color', 'alpha')
        for glyph in ('box', 'violin', 'stats', 'median')
    ])

    _stat_fns = [partial(np.percentile, q=q) for q in [25, 50, 75]]

    def _kde_data(self, el, key, **kwargs):
        vdim = el.vdims[0]
        values = el.dimension_values(vdim)
        if self.clip:
            vdim = vdim(range=self.clip)
            el = el.clone(vdims=[vdim])
        kde = univariate_kde(el, dimension=vdim, **kwargs)
        xs, ys = (kde.dimension_values(i) for i in range(2))
        ys = (ys / ys.max()) * (self.violin_width / 2.)
        ys = [
            key + (sign * y, ) for sign, vs in ((-1, ys), (1, ys[::-1]))
            for y in vs
        ]
        kde = {'x': np.concatenate([xs, xs[::-1]]), 'y': ys}

        bars, segments, scatter = defaultdict(list), defaultdict(list), {}
        values = el.dimension_values(vdim)
        values = values[np.isfinite(values)]
        if self.inner == 'quartiles':
            for stat_fn in self._stat_fns:
                stat = stat_fn(values)
                sidx = np.argmin(np.abs(xs - stat))
                sx, sy = xs[sidx], ys[sidx]
                segments['x'].append(sx)
                segments['y0'].append(key + (-sy[-1], ))
                segments['y1'].append(sy)
        elif self.inner == 'stick':
            for value in values:
                sidx = np.argmin(np.abs(xs - value))
                sx, sy = xs[sidx], ys[sidx]
                segments['x'].append(sx)
                segments['y0'].append(key + (-sy[-1], ))
                segments['y1'].append(sy)
        elif self.inner == 'box':
            xpos = key + (0, )
            q1, q2, q3 = (np.percentile(values, q=q)
                          for q in range(25, 100, 25))
            iqr = q3 - q1
            upper = min(q3 + 1.5 * iqr, np.nanmax(values))
            lower = max(q1 - 1.5 * iqr, np.nanmin(values))
            segments['x'].append(xpos)
            segments['y0'].append(lower)
            segments['y1'].append(upper)
            bars['x'].append(xpos)
            bars['bottom'].append(q1)
            bars['top'].append(q3)
            scatter['x'] = xpos
            scatter['y'] = q2
        return kde, segments, bars, scatter

    def get_data(self, element, ranges, style):
        if element.kdims:
            groups = element.groupby(element.kdims).data
        else:
            groups = dict([((element.label, ), element)])

        # Define glyph-data mapping
        if self.invert_axes:
            bar_map = {
                'y': 'x',
                'left': 'bottom',
                'right': 'top',
                'height': 0.1
            }
            kde_map = {'x': 'x', 'y': 'y'}
            if self.inner == 'box':
                seg_map = {'x0': 'y0', 'x1': 'y1', 'y0': 'x', 'y1': 'x'}
            else:
                seg_map = {'x0': 'x', 'x1': 'x', 'y0': 'y0', 'y1': 'y1'}
            scatter_map = {'x': 'y', 'y': 'x'}
            bar_glyph = 'hbar'
        else:
            bar_map = {
                'x': 'x',
                'bottom': 'bottom',
                'top': 'top',
                'width': 0.1
            }
            kde_map = {'x': 'y', 'y': 'x'}
            if self.inner == 'box':
                seg_map = {'x0': 'x', 'x1': 'x', 'y0': 'y0', 'y1': 'y1'}
            else:
                seg_map = {'y0': 'x', 'y1': 'x', 'x0': 'y0', 'x1': 'y1'}
            scatter_map = {'x': 'x', 'y': 'y'}
            bar_glyph = 'vbar'

        elstyle = self.lookup_options(element, 'style')
        kwargs = {'bandwidth': self.bandwidth, 'cut': self.cut}

        data, mapping = {}, {}
        seg_data, bar_data, scatter_data = (defaultdict(list)
                                            for i in range(3))
        for i, (key, g) in enumerate(groups.items()):
            key = decode_bytes(key)
            gkey = 'patch_%d' % i
            kde, segs, bars, scatter = self._kde_data(g, key, **kwargs)
            for k, v in segs.items():
                seg_data[k] += v
            for k, v in bars.items():
                bar_data[k] += v
            for k, v in scatter.items():
                scatter_data[k].append(v)
            data[gkey] = kde
            patch_style = {
                k[7:]: v
                for k, v in elstyle[i].items() if k.startswith('violin')
            }
            mapping[gkey] = dict(kde_map, **patch_style)

        if seg_data:
            data['segment_1'] = {
                k: v if isinstance(v[0], tuple) else np.array(v)
                for k, v in seg_data.items()
            }
            mapping['segment_1'] = seg_map
        if bar_data:
            data[bar_glyph + '_1'] = {
                k: v if isinstance(v[0], tuple) else np.array(v)
                for k, v in bar_data.items()
            }
            mapping[bar_glyph + '_1'] = bar_map
        if scatter_data:
            data['scatter_1'] = {
                k: v if isinstance(v[0], tuple) else np.array(v)
                for k, v in scatter_data.items()
            }
            mapping['scatter_1'] = scatter_map
        return data, mapping, style
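
The inner-box statistics computed in _kde_data above (quartiles plus 1.5 * IQR whiskers clamped to the data range) can be reproduced standalone with NumPy; this sketch is illustrative only:

# Hedged, standalone sketch of the inner-box statistics (not Bokeh plotting code).
import numpy as np

values = np.random.randn(200)
q1, q2, q3 = (np.percentile(values, q=q) for q in (25, 50, 75))
iqr = q3 - q1
upper = min(q3 + 1.5 * iqr, np.nanmax(values))  # upper whisker, clamped to the data max
lower = max(q1 - 1.5 * iqr, np.nanmin(values))  # lower whisker, clamped to the data min
print(lower, q1, q2, q3, upper)
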
예제 #22
0
class tuning_curve(PylabPlotCommand):
    """
    Plot a tuning curve for a feature, such as orientation, contrast, or size.

    The curve datapoints are collected from the curve_dict for
    the units at the specified coordinates in the specified sheet
    (where the units and sheet may be set by a GUI, using
    topo.analysis.featureresponses.UnitCurveCommand.sheet and
    topo.analysis.featureresponses.UnitCurveCommand.coords,
    or by hand).
    """

    center = param.Boolean(default=True, doc="""
        Centers the tuning curve around the maximally responding feature.""")

    coords = param.List(default=[(0 , 0)], doc="""
        List of coordinates of units to measure.""")

    group_by = param.List(default=['Contrast'], doc="""
        Feature dimensions for which curves are overlaid.""")

    legend = param.Boolean(default=True, doc="""
        Whether or not to include a legend in the plot.""")

    relative_labels = param.Boolean(default=False, doc="""
        Relabel the x-axis with values relative to the preferred.""")

    sheet = param.ObjectSelector(default=None, doc="""
        Name of the sheet to use in measurements.""")

    x_axis = param.String(default='', doc="""
        Feature to plot on the x axis of the tuning curve""")

    # Disable and hide parameters inherited from the base class
    coord = param.NumericTuple(constant=True,  precedence=-1)

    def __call__(self, **params):
        p = ParamOverrides(self, params, allow_extra_keywords=True)

        x_axis = p.x_axis.capitalize()
        stack = p.sheet.views.Curves[x_axis.capitalize()+"Tuning"]
        time = stack.dim_range('Time')[1]

        curves = []
        if stack.dimension_labels[0] == 'X':
            for coord in p.coords:
                x, y = coord
                current_stack = stack[x, y, time, :, :, :]
                curve_stack = current_stack.sample(X=x, Y=y).collate(p.x_axis.capitalize())
                curves.append(curve_stack.overlay_dimensions(p.group_by))
        else:
            current_stack = stack[time, :, :, :]
            curve_stack = current_stack.sample(coords=p.coords).collate(p.x_axis.capitalize())
            overlaid_curves = curve_stack.overlay_dimensions(p.group_by)
            if not isinstance(curves, GridLayout): curves = [overlaid_curves]

        figs = []
        for coord, curve in zip(p.coords,curves):
            fig = plt.figure()
            ax = plt.subplot(111)
            plot = DataPlot if isinstance(curve.last, DataOverlay) else CurvePlot
            plot(curve, center=p.center, relative_labels=p.relative_labels,
                 show_legend=p.legend)(ax)
            self._generate_figure(p, fig)
            figs.append((coord, fig))

        return figs


    def _generate_figure(self, p, fig):
        """
        Helper function to display a figure on screen or save to a file.

        p should be a ParamOverrides instance containing the current
        set of parameters.
        """

        plt.show._needmain=False
        if p.filename is not None:
            # JABALERT: need to reformat this as for other plots
            fullname=p.filename+p.filename_suffix+str(topo.sim.time())+"."+p.file_format
            fig.savefig(normalize_path(fullname), dpi=p.file_dpi)
        elif p.display_window:
            self._set_windowtitle(p.title)
            fig.show()
        else:
            plt.close(fig)
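
A minimal, self-contained sketch of the save-or-show branching that _generate_figure performs, using plain matplotlib and omitting the Topographica-specific path and window handling:

# Hedged sketch of the save/show/close branching (illustrative only).
import matplotlib.pyplot as plt

def save_or_show(fig, filename=None, file_format='png', dpi=100, display=True):
    # Mirrors the three branches above: save to file, show a window, or discard.
    if filename is not None:
        fig.savefig(f"{filename}.{file_format}", dpi=dpi)
    elif display:
        fig.show()
    else:
        plt.close(fig)

fig = plt.figure()
plt.plot([0, 1], [0, 1])
save_or_show(fig, display=False)
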
예제 #23
0
class LayoutPlot(CompositePlot):
    """
    A LayoutPlot accepts either a Layout or a NdLayout and
    displays the elements in a cartesian grid in scanline order.
    """

    figure_bounds = param.NumericTuple(default=(0.05, 0.05, 0.95, 0.95),
                                       doc="""
        The bounds of the figure as a 4-tuple of the form
        (left, bottom, right, top), defining the size of the border
        around the subplots.""")

    horizontal_spacing = param.Number(default=0.5, doc="""
      Specifies the space between horizontally adjacent elements in the grid.
      Default value is set conservatively to avoid overlap of subplots.""")

    vertical_spacing = param.Number(default=0.2, doc="""
      Specifies the space between vertically adjacent elements in the grid.
      Default value is set conservatively to avoid overlap of subplots.""")

    def __init__(self, layout, **params):
        if not isinstance(layout, (NdLayout, Layout)):
            raise ValueError("LayoutPlot only accepts Layout objects.")
        if len(layout.values()) == 0:
            raise ValueError("Cannot display empty layout")

        self.layout = layout.map(Compositor.collapse_element, [CompositeOverlay])
        self.subplots = {}
        self.rows, self.cols = layout.shape
        self.coords = list(product(range(self.rows),
                                   range(self.cols)))
        dimensions, keys = traversal.unique_dimkeys(layout)
        plotopts = Store.lookup_options(layout, 'plot').options
        super(LayoutPlot, self).__init__(keys=keys, dimensions=dimensions,
                                         uniform=traversal.uniform(layout),
                                         **dict(plotopts, **params))
        self.subplots, self.subaxes, self.layout = self._compute_gridspec(layout)


    def _compute_gridspec(self, layout):
        """
        Computes the tallest and widest cell for each row and column
        by examining the Layouts in the GridSpace. The GridSpec is then
        instantiated and the LayoutPlots are configured with the
        appropriate embedded layout_types. The first element of the
        returned tuple is a dictionary of all the LayoutPlots indexed
        by row and column. The second dictionary in the tuple supplies
        the grid indices needed to instantiate the axes for each
        LayoutPlot.
        """
        layout_items = layout.grid_items()
        layout_dimensions = layout.key_dimensions if isinstance(layout, NdLayout) else None

        layouts = {}
        row_heightratios, col_widthratios = {}, {}
        for (r, c) in self.coords:
            # Get view at layout position and wrap in AdjointLayout
            _, view = layout_items.get((r, c), (None, None))
            layout_view = view if isinstance(view, AdjointLayout) else AdjointLayout([view])
            layouts[(r, c)] = layout_view

            # Compute shape of AdjointLayout element
            layout_lens = {1:'Single', 2:'Dual', 3:'Triple'}
            layout_type = layout_lens[len(layout_view)]
            width_ratios = AdjointLayoutPlot.layout_dict[layout_type]['width_ratios']
            height_ratios = AdjointLayoutPlot.layout_dict[layout_type]['height_ratios']
            layout_shape = (len(width_ratios), len(height_ratios))

            # For each row and column record the width and height ratios
            # of the LayoutPlot with the most horizontal or vertical splits
            if layout_shape[0] > row_heightratios.get(r, (0, None))[0]:
                row_heightratios[r] = (layout_shape[1], height_ratios)
            if layout_shape[1] > col_widthratios.get(c, (0, None))[0]:
                col_widthratios[c] = (layout_shape[0], width_ratios)

        # In order of row/column collect the largest width and height ratios
        height_ratios = [v[1] for k, v in sorted(row_heightratios.items())]
        width_ratios = [v[1] for k, v in sorted(col_widthratios.items())]
        # Compute the number of rows and cols
        cols = np.sum([len(wr) for wr in width_ratios])
        rows = np.sum([len(hr) for hr in height_ratios])
        # Flatten the width and height ratio lists
        wr_list = [wr for wrs in width_ratios for wr in wrs]
        hr_list = [hr for hrs in height_ratios for hr in hrs]

        self.gs = gridspec.GridSpec(rows, cols,
                                    width_ratios=wr_list,
                                    height_ratios=hr_list,
                                    wspace=self.horizontal_spacing,
                                    hspace=self.vertical_spacing)
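        # An illustration of the gridspec arithmetic above (hypothetical input): a
        # one-row Layout holding a plain element ('Single', a 1x1 subgrid) next to
        # an element with a right-adjoined view ('Dual', a 2x1 subgrid) produces one
        # height-ratio group and width-ratio groups of lengths 1 and 2, i.e. a
        # GridSpec with 1 row and 3 columns in which the second AdjointLayoutPlot
        # occupies two cells.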

        # Situate all the Layouts in the grid and compute the gridspec
        # indices for all the axes required by each LayoutPlot.
        gidx = 0
        collapsed_layout = layout.clone(shared_data=False, id=layout.id)
        frame_ranges = self.compute_ranges(layout, None, None)
        frame_ranges = OrderedDict([(key, self.compute_ranges(layout, key, frame_ranges))
                                    for key in self.keys])
        layout_subplots, layout_axes = {}, {}
        for num, (r, c) in enumerate(self.coords):
            # Compute the layout type from shape
            wsplits = len(width_ratios[c])
            hsplits = len(height_ratios[r])
            if (wsplits, hsplits) == (1,1):
                layout_type = 'Single'
            elif (wsplits, hsplits) == (2,1):
                layout_type = 'Dual'
            elif (wsplits, hsplits) == (1,2):
                layout_type = 'Embedded Dual'
            elif (wsplits, hsplits) == (2,2):
                layout_type = 'Triple'

            # Get the AdjointLayout at the specified coordinate
            view = layouts[(r, c)]
            positions = AdjointLayoutPlot.layout_dict[layout_type]['positions']

            # Create temporary subplots to get projections types
            # to create the correct subaxes for all plots in the layout
            temp_subplots, new_layout = self._create_subplots(layouts[(r, c)], positions,
                                                              None, frame_ranges)
            gidx, gsinds, projs = self.grid_situate(temp_subplots, gidx, layout_type, cols)

            layout_key, _ = layout_items.get((r, c), (None, None))
            if isinstance(layout, NdLayout) and layout_key:
                layout_dimensions = OrderedDict(zip(layout_dimensions, layout_key))

            # Generate the axes and create the subplots with the appropriate
            # axis objects
            subaxes = [plt.subplot(self.gs[ind], projection=proj)
                       for ind, proj in zip(gsinds, projs)]
            subplots, adjoint_layout = self._create_subplots(layouts[(r, c)], positions,
                                                             layout_dimensions, frame_ranges,
                                                             dict(zip(positions, subaxes)),
                                                             num=num+1)
            layout_axes[(r, c)] = subaxes

            # Generate the AdjointLayoutPlot which will coordinate
            # plotting of AdjointLayouts in the larger grid
            plotopts = Store.lookup_options(view, 'plot').options
            layout_plot = AdjointLayoutPlot(adjoint_layout, layout_type, subaxes, subplots,
                                            figure=self.handles['fig'], **plotopts)
            layout_subplots[(r, c)] = layout_plot
            if layout_key:
                collapsed_layout[layout_key] = adjoint_layout

        if self.show_title and len(self.coords) > 1:
            self.handles['title'] = self.handles['fig'].suptitle('', fontsize=16)

        return layout_subplots, layout_axes, collapsed_layout


    def grid_situate(self, subplots, current_idx, layout_type, subgrid_width):
        """
        Situate the current AdjointLayoutPlot in a LayoutPlot. The
        LayoutPlot specifies a layout_type into which the AdjointLayoutPlot
        must be embedded. This enclosing layout is guaranteed to have
        enough cells to display all the views.

        Based on this enforced layout format, a starting index
        supplied by LayoutPlot (indexing into a large gridspec
        arrangement) is updated to the appropriate embedded value. It
        will also return a list of gridspec indices associated with
        all the required layout axes.
        """
        # Set the layout configuration as situated in a NdLayout

        if layout_type == 'Single':
            positions = ['main']
            start, inds = current_idx+1, [current_idx]
        elif layout_type == 'Dual':
            positions = ['main', 'right']
            start, inds = current_idx+2, [current_idx, current_idx+1]

        bottom_idx = current_idx + subgrid_width
        if layout_type == 'Embedded Dual':
            positions = [None, None, 'main', 'right']
            bottom = ((current_idx+1) % subgrid_width) == 0
            grid_idx = (bottom_idx if bottom else current_idx)+1
            start, inds = grid_idx, [current_idx, bottom_idx]
        elif layout_type == 'Triple':
            positions = ['top', None, 'main', 'right']
            bottom = ((current_idx+2) % subgrid_width) == 0
            grid_idx = (bottom_idx if bottom else current_idx) + 2
            start, inds = grid_idx, [current_idx, current_idx+1,
                              bottom_idx, bottom_idx+1]
        projs = [subplots.get(pos, Plot).projection for pos in positions]

        return start, inds, projs


    def _create_subplots(self, layout, positions, layout_dimensions, ranges, axes={}, num=1):
        """
        Plot all the views contained in the AdjointLayout Object using axes
        appropriate to the layout configuration. All the axes are
        supplied by LayoutPlot - the purpose of the call is to
        invoke subplots with correct options and styles and hide any
        empty axes as necessary.
        """
        subplots = {}
        adjoint_clone = layout.clone(shared_data=False, id=layout.id)
        subplot_opts = dict(show_title=False, adjoined=layout)
        for pos in positions:
            # Pos will be one of 'main', 'top' or 'right' or None
            view = layout.get(pos, None)
            ax = axes.get(pos, None)
            if view is None:
                continue
            # Customize plotopts depending on position.
            plotopts = Store.lookup_options(view, 'plot').options
            # Options common for any subplot
            override_opts = {}
            sublabel_opts = {}
            if pos == 'main':
                own_params = self.get_param_values(onlychanged=True)
                sublabel_opts = {k: v for k, v in own_params
                                 if 'sublabel_' in k}
                override_opts = dict(aspect='square')
            elif pos == 'right':
                right_opts = dict(orientation='vertical',
                                  show_xaxis=None, show_yaxis='left')
                override_opts = dict(subplot_opts, **right_opts)
            elif pos == 'top':
                top_opts = dict(show_xaxis='bottom', show_yaxis=None)
                override_opts = dict(subplot_opts, **top_opts)

            # Override the plotopts as required
            plotopts = dict(sublabel_opts, **plotopts)
            plotopts.update(override_opts, figure=self.handles['fig'])
            vtype = view.type if isinstance(view, HoloMap) else view.__class__
            if isinstance(view, GridSpace):
                raster_fn = lambda x: (isinstance(x, Raster)
                                       or not isinstance(x, Element))
                all_raster = all(view.traverse(raster_fn))
                if all_raster:
                    from .raster import RasterGridPlot
                    plot_type = RasterGridPlot
                else:
                    plot_type = GridPlot
                plotopts['create_axes'] = ax is not None
            else:
                if pos == 'main':
                    plot_type = Store.registry[vtype]
                else:
                    plot_type = Plot.sideplots[vtype]
            num = num if len(self.coords) > 1 else 0
            subplots[pos] = plot_type(view, axis=ax, keys=self.keys,
                                      dimensions=self.dimensions,
                                      layout_dimensions=layout_dimensions,
                                      ranges=ranges, subplot=True,
                                      uniform=self.uniform, layout_num=num,
                                      **plotopts)
            if issubclass(plot_type, CompositePlot):
                adjoint_clone[pos] = subplots[pos].layout
            else:
                adjoint_clone[pos] = subplots[pos].map
        return subplots, adjoint_clone


    def update_handles(self, axis, view, key, ranges=None):
        """
        Should be called by update_frame to update
        any handles on the plot.
        """
        if self.show_title and 'title' in self.handles and len(self.coords) > 1:
            self.handles['title'].set_text(self._format_title(key))


    def __call__(self):
        axis = self.handles['axis']
        self.update_handles(axis, None, self.keys[-1])

        ranges = self.compute_ranges(self.layout, self.keys[-1], None)
        for subplot in self.subplots.values():
            subplot(ranges=ranges)

        return self._finalize_axis(None)
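
# A hedged usage sketch (not part of the example above; the element names are
# hypothetical): with this matplotlib-based backend a LayoutPlot is typically
# constructed from a Layout of elements and rendered by calling it, e.g.
#
#     layout = Curve([(0, 0), (1, 1)]) + Scatter([(0, 1), (1, 0)])
#     plot = LayoutPlot(layout)
#     rendered = plot()   # draws every AdjointLayoutPlot and finalizes the figure
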
class ScalarModelBase(ModelConfigBase):
    class_names: List[str] = param.List(
        class_=str,
        default=[DEFAULT_KEY],
        bounds=(1, None),
        doc=
        "The label names for each label class in the dataset and model output "
        "in the case of binary and multi-label classification tasks."
        "The order of the names should match the order of label class indices "
        "in dataset.csv"
        "For multi-label classification, this field is required."
        "For binary classification, this field must be a list of size 1, and "
        "is by default ['Default'], but can optionally be set to a more "
        "descriptive "
        "name for the positive class.")
    target_names: List[str] = param.List(
        class_=str,
        default=None,
        bounds=(1, None),
        doc=
        "The label names for each output target, used for logging metrics and "
        "reporting results. If provided, the length of this list must match the "
        "number of model outputs (and of transformed labels, if defined; see "
        "get_posthoc_label_transform()). By default, this inherits the value of "
        "class_names at initialisation. This will be ignored in sequence models, "
        "as target_names are determined automatically based on"
        "sequence_target_positions")
    aggregation_type: AggregationType = param.ClassSelector(
        default=AggregationType.Average,
        class_=AggregationType,
        doc="The type of global pooling aggregation to use between"
        " the encoder and the classifier.")
    loss_type: ScalarLoss = param.ClassSelector(
        default=ScalarLoss.BinaryCrossEntropyWithLogits,
        class_=ScalarLoss,
        instantiate=False,
        doc="The loss_type to use")
    image_channels: List[str] = param.List(
        class_=str,
        doc=
        "Identifies the rows of the dataset file that contain image file paths."
    )
    image_file_column: Optional[str] = param.String(
        default=None,
        allow_None=True,
        doc="The column that contains the path to image files.")
    expected_column_values: List[Tuple[str, str]] = \
        param.List(default=None,
                   doc="List of tuples with column name and expected value to filter rows in the dataset csv file",
                   allow_None=True)
    label_channels: Optional[List[str]] = \
        param.List(default=None, allow_None=True,
                   doc="Identifies the row of a dataset file that contains the label value.")
    label_value_column: str = param.String(
        doc="The column in the dataset file that contains the label value.")
    non_image_feature_channels: Union[List[str], Dict[str, List[str]]] = \
        ListOrDictParam(doc="Specifies the rows of a dataset file from which additional feature values should be read. "
                            "The channels can be specified as a List of channels to be used for all non imaging "
                            "features, or as a Dict mapping features to specific channels. The helper function "
                            "`get_non_image_features_dict` is available to construct this dictionary.")
    numerical_columns: List[str] = \
        param.List(class_=str,
                   default=[],
                   doc="Specifies the columns of a dataset file from which additional numerical "
                       "feature values should be read.")
    categorical_columns: List[str] = \
        param.List(class_=str,
                   default=[],
                   doc="Specifies the columns of a dataset file from which additional "
                       "catagorical feature values should be read.")

    subject_column: str = \
        param.String(default=CSV_SUBJECT_HEADER, allow_None=False,
                     doc="The name of the column that contains the patient/subject identifier. Default: 'subject'")
    channel_column: str = \
        param.String(default=CSV_CHANNEL_HEADER, allow_None=False,
                     doc="The name of the column that contains image channel information, for identifying multiple "
                         "rows belonging to the same subject. Default: 'channel'")

    add_differences_for_features: List[str] = \
        param.List(class_=str,
                   doc="When using sequence datasets, this specifies the columns in the dataset for which additional"
                       "features should be added. For all columns given here, the feature differences between index i"
                       "and index 0 (start of the sequence) are added as additional features.")
    traverse_dirs_when_loading: bool = \
        param.Boolean(doc="If true, image file names in datasets do no need to contain "
                          "the full path. Before loading, all files will be enumerated "
                          "recursively. If false, the image file name must be fully "
                          "given in the dataset file (relative to root path)")
    load_segmentation: bool = \
        param.Boolean(default=False, doc="If True, the segmentations from hdf5 files will be loaded. If False, only "
                                         "the images will be loaded.")
    center_crop_size: Optional[TupleInt3] = \
        param.NumericTuple(default=None, allow_None=True, length=3,
                           doc="If given, the loaded images and segmentations will be cropped to the given size."
                               "Size is given in pixels. The crop will be taken from the center of the image. "
                               "Crop size should be in the form (crop_z, crop_y, crop_x)."
                               "If your dataset has 2D images, center crop should have singleton first dimension,"
                               "i.e. (1, ) + (crop_y, crop_x)")

    image_size: Optional[TupleInt3] = \
        param.NumericTuple(default=None, allow_None=True, length=3,
                           doc="If given, images will be resized to these dimensions immediately after loading from"
                               "file."
                               "Image size should be in the form (size_z, size_y, size_x)."
                               "If your dataset has 2D images, image size should have singleton first dimension,"
                               "i.e. (1, ) + (size_y, size_x)")

    categorical_feature_encoder: Optional[
        OneHotEncoderBase] = param.ClassSelector(
            OneHotEncoderBase,
            allow_None=True,
            instantiate=False,
            doc="The one hot encoding scheme "
            "for categorical data if "
            "required")
    num_dataset_reader_workers: int = param.Integer(
        default=0,
        bounds=(-1, None),
        doc="Number of workers (processes) to use for dataset "
        "reading. Default is 0 which means only the main thread "
        "will be used. Set to -1 for maximum parallelism level.")

    ensemble_aggregation_type: EnsembleAggregationType = param.ClassSelector(
        default=EnsembleAggregationType.Average,
        class_=EnsembleAggregationType,
        instantiate=False,
        doc="The aggregation method to use when"
        "testing ensemble models.")

    dataset_stats_hook: Optional[Callable[[Dict[ModelExecutionMode, Any]], None]] = \
        param.Callable(default=None,
                       allow_None=True,
                       doc="A hook that is called with a dictionary that maps from train/val/test to the actual "
                           "dataset, to do customized statistics.")

    def __init__(self,
                 num_dataset_reader_workers: int = 0,
                 **params: Any) -> None:
        super().__init__(**params)
        self._model_category = ModelCategory.Regression \
            if self.is_regression_model else ModelCategory.Classification
        if not self.is_offline_run:
            self.num_dataset_reader_workers = 0
            logging.info(
                "dataset reader parallelization is supported only locally, setting "
                "num_dataset_reader_workers to 0 as this is an AML run.")
        else:
            self.num_dataset_reader_workers = num_dataset_reader_workers
        if self.target_names is None:
            self.target_names = self.class_names
        # Report generation assumes that results for the test set are available when we do cross validation on
        # all ScalarModels.
        self.inference_on_test_set = True

    def validate(self) -> None:
        if len(self.class_names) > 1 and not self.is_classification_model:
            raise ValueError(
                "Multiple label classes supported only for classification tasks."
            )

    @property
    def is_classification_model(self) -> bool:
        """
        Returns whether the model is a classification model
        """
        return self.loss_type.is_classification_loss()

    @property
    def is_regression_model(self) -> bool:
        """
        Returns whether the model is a regression model
        """
        return self.loss_type.is_regression_loss()

    @property
    def is_non_imaging_model(self) -> bool:
        """
        Returns whether the model uses non image features only
        """
        return len(self.image_channels) == 0

    def should_generate_multilabel_report(self) -> bool:
        """Determines whether to produce a multilabel report. Override this to implement custom behaviour."""
        return len(self.class_names) > 1

    def get_total_number_of_non_imaging_features(self) -> int:
        """Returns the total number of non imaging features expected in the input"""
        return self.get_total_number_of_numerical_non_imaging_features() + \
               self.get_total_number_of_categorical_non_imaging_features()

    def get_total_number_of_numerical_non_imaging_features(self) -> int:
        """Returns the total number of numerical non imaging features expected in the input"""
        if len(self.numerical_columns) == 0:
            return 0
        else:
            features_channels_dict = self.get_non_image_feature_channels_dict()
            return sum([
                len(features_channels_dict[col])
                for col in self.numerical_columns
            ])

    def get_total_number_of_categorical_non_imaging_features(self) -> int:
        """
        Returns the total number of categorical non imaging features expected in the input, e.g. for the
        categorical channels A and B the total number would be: 2 (feature channels A and B) * 4
        (which is the number of bits required to one-hot encode a single channel)
        A| True, No => [1, 0, 0, 1]
        B| False, Yes => [0, 1, 1, 0]
        """
        if self.categorical_columns and not self.categorical_feature_encoder:
            raise ValueError(
                f"Found {len(self.categorical_columns)} categorical columns, but "
                f"one_hot_encoder is None. Either set one_hot_encoder explicitly "
                f"or make sure property is accessed after the dataset dataframe has been loaded."
            )
        elif not self.categorical_feature_encoder:
            return 0
        else:
            features_channels_dict = self.get_non_image_feature_channels_dict()
            if self.categorical_columns is None:
                return 0
            return sum([
                len(features_channels_dict[col]) *
                self.categorical_feature_encoder.get_feature_length(col)
                for col in self.categorical_columns
            ])

    def get_non_image_feature_channels_dict(self) -> Dict:
        """
        Converts the provided non_image_feature_channels from a List to a Dictionary mapping each feature to its
        channels. Features that are not explicitly listed are assigned the channels stored under the default key.
        The conversion is done in this helper to avoid repeating it throughout the code.
        """
        if not self.non_image_feature_channels:
            return {}

        if isinstance(self.non_image_feature_channels, List):
            non_image_feature_channels_dict = {
                DEFAULT_KEY: self.non_image_feature_channels
            }
        else:
            non_image_feature_channels_dict = self.non_image_feature_channels.copy(
            )
        all_non_image_features = self.numerical_columns.copy()
        if self.categorical_columns:
            all_non_image_features.extend(self.categorical_columns)

        # Map each feature to its channels
        for column in all_non_image_features:
            if column not in self.non_image_feature_channels:
                try:
                    non_image_feature_channels_dict[
                        column] = non_image_feature_channels_dict[DEFAULT_KEY]
                except KeyError:
                    raise KeyError(
                        f"The column {column} is not present in the non_image_features dictionary and the"
                        f"default key {DEFAULT_KEY} is missing.")
        # Delete default key
        non_image_feature_channels_dict.pop(DEFAULT_KEY, None)
        return non_image_feature_channels_dict

    def filter_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Filters the dataframe based on expected values in columns.
        :param df: the input dataframe
        :return: the filtered dataframe
        """
        def _dataframe_stats(df: pd.DataFrame) -> str:
            """
            Creates a human readable string that contains the number of rows and the number of unique subjects.
            :return: A string like "12 rows, 5 unique subjects. "
            """
            total_rows = len(df)
            if self.subject_column in df:
                unique_subjects = len(df[self.subject_column].unique())
                message = f"{unique_subjects} unique subjects"
            else:
                message = f"subject column '{self.subject_column}' not present"
            return f"{total_rows} rows, {message}. "

        logging.info(f"Before filtering: {_dataframe_stats(df)}")
        if self.expected_column_values is not None:
            for column_name, expected_value in self.expected_column_values:
                df = df[df[column_name] == expected_value]
                logging.info(
                    f"After filtering for 'column[{column_name}] == {expected_value}': {_dataframe_stats(df)}"
                )
        logging.info(f"Final: {_dataframe_stats(df)}")
        return df

    def get_label_transform(self) -> Union[Callable, List[Callable]]:
        """Return a transformation or list of transformation
        to apply to the labels.
        """
        return LabelTransformation.identity

    def read_dataset_into_dataframe_and_pre_process(self) -> None:
        assert self.local_dataset is not None
        file_path = self.local_dataset / self.dataset_csv
        self.dataset_data_frame = pd.read_csv(file_path,
                                              dtype=str,
                                              low_memory=False)
        self.pre_process_dataset_dataframe()

    def pre_process_dataset_dataframe(self) -> None:
        # some empty values on numeric columns get converted to nan but we want ''
        assert self.dataset_data_frame is not None
        df = self.dataset_data_frame.fillna('')
        self.dataset_data_frame = self.filter_dataframe(df)
        # update the one-hot encoder based on this dataframe
        if self.categorical_columns:
            from InnerEye.ML.utils.dataset_util import CategoricalToOneHotEncoder
            self.categorical_feature_encoder = CategoricalToOneHotEncoder.create_from_dataframe(
                dataframe=self.dataset_data_frame,
                columns=self.categorical_columns)

    def create_torch_datasets(
            self,
            dataset_splits: DatasetSplits) -> Dict[ModelExecutionMode, Any]:
        from InnerEye.ML.dataset.scalar_dataset import ScalarDataset
        sample_transform = self.get_scalar_item_transform()
        assert sample_transform.train is not None  # for mypy
        assert sample_transform.val is not None  # for mypy
        assert sample_transform.test is not None  # for mypy
        train = ScalarDataset(args=self,
                              data_frame=dataset_splits.train,
                              name="training",
                              sample_transform=sample_transform.train)
        val = ScalarDataset(args=self,
                            data_frame=dataset_splits.val,
                            feature_statistics=train.feature_statistics,
                            name="validation",
                            sample_transform=sample_transform.val)
        test = ScalarDataset(args=self,
                             data_frame=dataset_splits.test,
                             feature_statistics=train.feature_statistics,
                             name="test",
                             sample_transform=sample_transform.test)

        return {
            ModelExecutionMode.TRAIN: train,
            ModelExecutionMode.VAL: val,
            ModelExecutionMode.TEST: test
        }

    def create_and_set_torch_datasets(self,
                                      for_training: bool = True,
                                      for_inference: bool = True) -> None:
        """
        Creates and sets torch datasets for all model execution modes, and stores them in the self._datasets field.
        It also calls the hook to compute statistics for the train/val/test datasets.
        """
        # For models other than segmentation models, it is easier to create both training and inference datasets
        # in one go, ignoring the arguments.
        if self._datasets_for_training is None and self._datasets_for_inference is None:
            datasets = self.create_torch_datasets(self.get_dataset_splits())
            self._datasets_for_training = {
                mode: datasets[mode]
                for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]
            }
            self._datasets_for_inference = datasets
            for split, dataset in datasets.items():
                logging.info(
                    f"{split.value}: {len(dataset)} subjects. Detailed status: {dataset.status}"
                )
            if self.dataset_stats_hook:
                try:
                    self.dataset_stats_hook(datasets)
                except Exception as ex:
                    print_exception(
                        ex,
                        message=
                        "Error while calling the hook for computing dataset statistics."
                    )

    def get_training_class_counts(self) -> Dict:
        if self._datasets_for_training is None:
            self.create_and_set_torch_datasets(for_inference=False)
        assert self._datasets_for_training is not None  # for mypy
        return self._datasets_for_training[
            ModelExecutionMode.TRAIN].get_class_counts()

    def get_total_number_of_training_samples(self) -> int:
        if self._datasets_for_training is None:
            self.create_and_set_torch_datasets(for_inference=False)
        assert self._datasets_for_training is not None  # for mypy
        return len(self._datasets_for_training[ModelExecutionMode.TRAIN])

    def create_model(self) -> Any:
        pass

    def get_loss_function(self) -> Callable:
        """Returns a custom loss function to be used with ScalarLoss.CustomClassification or CustomRegression."""
        assert self.loss_type in {ScalarLoss.CustomClassification, ScalarLoss.CustomRegression}, \
            f"get_loss_function() should be called only for custom loss types (received {self.loss_type})"
        raise NotImplementedError(
            f"get_loss_function() must be implemented for loss type {self.loss_type}"
        )

    def get_post_loss_logits_normalization_function(self) -> Callable:
        """
        Post loss normalization function to apply to the logits produced by the model.
        :return:
        """
        import torch
        if self.loss_type.is_classification_loss():
            return torch.nn.Sigmoid()
        elif self.loss_type.is_regression_loss():
            return torch.nn.Identity()  # type: ignore
        else:
            raise NotImplementedError(
                "get_post_loss_logits_normalization_function not implemented for "
                f"loss type: {self.loss_type}")

    def get_parameter_search_hyperdrive_config(
            self, run_config: ScriptRunConfig) -> HyperDriveConfig:
        return super().get_parameter_search_hyperdrive_config(run_config)

    def get_model_train_test_dataset_splits(
            self, dataset_df: pd.DataFrame) -> DatasetSplits:
        return super().get_model_train_test_dataset_splits(dataset_df)

    def get_image_transform(self) -> ModelTransformsPerExecutionMode:
        """
        Get transforms to apply to images for each model execution mode.
        By default, no transformation is performed.
        """
        return ModelTransformsPerExecutionMode()

    def get_segmentation_transform(self) -> ModelTransformsPerExecutionMode:
        """
        Get transforms to apply to segmentation map inputs for each model execution mode.
        By default, no transformation is performed.
        """
        return ModelTransformsPerExecutionMode()

    def get_scalar_item_transform(self) -> ModelTransformsPerExecutionMode:
        from InnerEye.ML.dataset.scalar_dataset import ScalarItemAugmentation
        image_transform = self.get_image_transform()
        segmentation_transform = self.get_segmentation_transform()
        return ModelTransformsPerExecutionMode(
            train=ScalarItemAugmentation(image_transform.train,
                                         segmentation_transform.train),
            val=ScalarItemAugmentation(image_transform.val,
                                       segmentation_transform.val),
            test=ScalarItemAugmentation(image_transform.test,
                                        segmentation_transform.test))

    def create_metric_computers(self) -> ModuleDict:
        """
        Gets a set of objects that compute all the metrics for the type of model that is being trained,
        across all prediction targets (sequence positions when using a sequence model).
        :return: A dictionary mapping from names of prediction targets to a list of metric computers.
        """
        # The metric computers should be stored in an object that derives from torch.nn.Module,
        # so that they are picked up when moving the whole LightningModule to GPU.
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/4713
        return ModuleDict(
            {p: self._get_metrics_computers()
             for p in self.target_names})

    def _get_metrics_computers(self) -> ModuleList:
        """
        Gets the objects that compute metrics for the present kind of models, for a single prediction target.
        """
        if self.is_classification_model:
            return ModuleList([
                Accuracy05(),
                AccuracyAtOptimalThreshold(),
                OptimalThreshold(),
                FalsePositiveRateOptimalThreshold(),
                FalseNegativeRateOptimalThreshold(),
                AreaUnderRocCurve(),
                AreaUnderPrecisionRecallCurve(),
                BinaryCrossEntropyWithLogits()
            ])
        else:
            return ModuleList(
                [MeanAbsoluteError(),
                 MeanSquaredError(),
                 ExplainedVariance()])

    def compute_and_log_metrics(self, logits: torch.Tensor,
                                targets: torch.Tensor, subject_ids: List[str],
                                is_training: bool, metrics: ModuleDict,
                                logger: DataframeLogger, current_epoch: int,
                                data_split: ModelExecutionMode) -> None:
        """
        Computes all the metrics for a given (logits, labels) pair, and writes them to the loggers.
        :param logits: The model output before normalization.
        :param targets: The expected model outputs.
        :param subject_ids: The subject IDs for the present minibatch.
        :param is_training: If True, write the metrics as training metrics, otherwise as validation metrics.
        :param metrics: A dictionary mapping from names of prediction targets to a list of metric computers,
        as returned by create_metric_computers.
        :param logger: An object of type DataframeLogger which can be used for logging within this function.
        :param current_epoch: Current epoch number.
        :param data_split: ModelExecutionMode object indicating if this is the train or validation split.
        :return:
        """
        per_subject_outputs: List[Tuple[str, str, torch.Tensor,
                                        torch.Tensor]] = []
        for i, (prediction_target, metric_list) in enumerate(metrics.items()):
            # mask the model outputs and labels if required
            masked = get_masked_model_outputs_and_labels(
                logits[:, i, ...], targets[:, i, ...], subject_ids)
            # compute metrics on valid masked tensors only
            if masked is not None:
                _logits = masked.model_outputs.data
                _posteriors = self.get_post_loss_logits_normalization_function(
                )(_logits)
                # Classification metrics expect labels as integers, but they are float throughout the rest of the code
                labels_dtype = torch.int if self.is_classification_model else _posteriors.dtype
                _labels = masked.labels.data.to(dtype=labels_dtype)
                _subject_ids = masked.subject_ids
                assert _subject_ids is not None
                for metric in metric_list:
                    if isinstance(
                            metric,
                            ScalarMetricsBase) and metric.compute_from_logits:
                        metric(_logits, _labels)
                    else:
                        metric(_posteriors, _labels)
                per_subject_outputs.extend(
                    zip(_subject_ids, [prediction_target] * len(_subject_ids),
                        _posteriors.tolist(), _labels.tolist()))
        # Write a full breakdown of per-subject predictions and labels to a file. These files are local to the current
        # rank in distributed training, and will be aggregated after training.
        for subject, prediction_target, model_output, label in per_subject_outputs:
            logger.add_record({
                LoggingColumns.Epoch.value: current_epoch,
                LoggingColumns.Patient.value: subject,
                LoggingColumns.Hue.value: prediction_target,
                LoggingColumns.ModelOutput.value: model_output,
                LoggingColumns.Label.value: label,
                LoggingColumns.DataSplit.value: data_split.value
            })
Example #25
class DeepLearningConfig(GenericConfig, CudaAwareConfig):
    """
    A class that holds all settings that are shared across segmentation models and regression/classification models.
    """
    _model_category: ModelCategory = param.ClassSelector(
        class_=ModelCategory,
        doc="The high-level model category described by this config.")
    _model_name: str = param.String(
        None,
        doc="The human readable name of the model (for example, Liver). This is "
        "usually set from the class name.")

    random_seed: int = param.Integer(
        42, doc="The seed to use for all random number generators.")
    azure_dataset_id: str = param.String(
        doc=
        "If provided, the ID of the dataset to use. This dataset must exist as a "
        "folder of the same name in the 'datasets' "
        "container in the datasets storage account.")
    local_dataset: Optional[Path] = param.ClassSelector(
        class_=Path,
        default=None,
        allow_None=True,
        doc="The path of the dataset to use, when training is running "
        "outside Azure.")
    num_dataload_workers: int = param.Integer(
        8,
        bounds=(0, None),
        doc="The number of data loading workers (processes). When set to 0,"
        "data loading is running in the same process (no process startup "
        "cost, hence good for use in unit testing. However, it "
        "does not give the same result as running with 1 worker process)")
    shuffle: bool = param.Boolean(
        True,
        doc="If true, the dataset will be shuffled randomly during training.")
    num_epochs: int = param.Integer(100,
                                    bounds=(1, None),
                                    doc="Number of epochs to train.")
    start_epoch: int = param.Integer(
        0,
        bounds=(0, None),
        doc="The first epoch to train. Set to 0 to start a new "
        "training. Set to a value larger than zero for starting"
        " from a checkpoint.")

    l_rate: float = param.Number(1e-4,
                                 doc="The initial learning rate",
                                 bounds=(0, None))
    _min_l_rate: float = param.Number(
        0.0,
        doc=
        "The minimum learning rate for the Polynomial and Cosine schedulers.",
        bounds=(0.0, None))
    l_rate_scheduler: LRSchedulerType = param.ClassSelector(
        default=LRSchedulerType.Polynomial,
        class_=LRSchedulerType,
        instantiate=False,
        doc="Learning rate decay method (Cosine, Polynomial, "
        "Step, MultiStep or Exponential)")
    l_rate_exponential_gamma: float = param.Number(
        0.9,
        doc="Controls the rate of decay for the Exponential "
        "LR scheduler.")
    l_rate_step_gamma: float = param.Number(
        0.1, doc="Controls the rate of decay for the "
        "Step LR scheduler.")
    l_rate_step_step_size: int = param.Integer(
        50, bounds=(0, None), doc="The step size for Step LR scheduler")
    l_rate_multi_step_gamma: float = param.Number(
        0.1,
        doc="Controls the rate of decay for the "
        "MultiStep LR scheduler.")
    l_rate_multi_step_milestones: Optional[List[int]] = param.List(
        None,
        bounds=(1, None),
        allow_None=True,
        class_=int,
        doc="The milestones for MultiStep decay.")
    l_rate_polynomial_gamma: float = param.Number(
        1e-4,
        doc="Controls the rate of decay for the "
        "Polynomial LR scheduler.")
    l_rate_warmup: LRWarmUpType = param.ClassSelector(
        default=LRWarmUpType.NoWarmUp,
        class_=LRWarmUpType,
        instantiate=False,
        doc="The type of learning rate warm up to use. "
        "Can be NoWarmUp (default) or Linear.")
    l_rate_warmup_epochs: int = param.Integer(
        0,
        bounds=(0, None),
        doc="Number of warmup epochs (linear warmup) before the "
        "scheduler starts decaying the learning rate. "
        "For example, if you are using MultiStepLR with "
        "milestones [50, 100, 200] and warmup epochs = 100, warmup "
        "will last for 100 epochs and the first decay of LR "
        "will happen on epoch 150")
    optimizer_type: OptimizerType = param.ClassSelector(
        default=OptimizerType.Adam,
        class_=OptimizerType,
        instantiate=False,
        doc="The optimizer_type to use")
    opt_eps: float = param.Number(
        1e-4, doc="The epsilon parameter of RMSprop or Adam")
    rms_alpha: float = param.Number(0.9, doc="The alpha parameter of RMSprop")
    adam_betas: TupleFloat2 = param.NumericTuple(
        (0.9, 0.999),
        length=2,
        doc="The betas parameter of Adam, default is (0.9, 0.999)")
    momentum: float = param.Number(
        0.6, doc="The momentum parameter of the optimizers")
    weight_decay: float = param.Number(
        1e-4, doc="The weight decay used to control L2 regularization")

    save_start_epoch: int = param.Integer(
        100,
        bounds=(0, None),
        doc="Save epoch checkpoints only when epoch is "
        "larger or equal to this value.")
    save_step_epochs: int = param.Integer(
        50,
        bounds=(0, None),
        doc="Save epoch checkpoints when epoch number is a "
        "multiple of save_step_epochs")
    train_batch_size: int = param.Integer(
        4,
        bounds=(0, None),
        doc="The number of crops that make up one minibatch during training.")
    detect_anomaly: bool = param.Boolean(
        False,
        doc="If true, test gradients for anomalies (NaN or Inf) during "
        "training.")
    use_mixed_precision: bool = param.Boolean(
        False,
        doc="If true, mixed precision training is activated during "
        "training.")
    use_model_parallel: bool = param.Boolean(
        False,
        doc="If true, neural network model is partitioned across all "
        "available GPUs to fit in a large model. It shall not be used "
        "together with data parallel.")
    test_diff_epochs: Optional[int] = param.Integer(
        None,
        doc="Number of different epochs of the same model to test",
        allow_None=True)
    test_step_epochs: Optional[int] = param.Integer(
        None, doc="How many epochs to move for each test", allow_None=True)
    test_start_epoch: Optional[int] = param.Integer(
        None,
        doc="The first epoch on which testing should run.",
        allow_None=True)
    monitoring_interval_seconds: int = param.Integer(
        0,
        doc="Seconds delay between logging GPU/CPU resource "
        "statistics. If 0 or less, do not log any resource "
        "statistics.")
    number_of_cross_validation_splits: int = param.Integer(
        0,
        bounds=(0, None),
        doc="Number of cross validation splits for k-fold cross "
        "validation")
    cross_validation_split_index: int = param.Integer(
        DEFAULT_CROSS_VALIDATION_SPLIT_INDEX,
        bounds=(-1, None),
        doc="The index of the cross validation fold this model is "
        "associated with when performing k-fold cross validation")
    file_system_config: DeepLearningFileSystemConfig = param.ClassSelector(
        default=DeepLearningFileSystemConfig(),
        class_=DeepLearningFileSystemConfig,
        instantiate=False,
        doc="File system related configs")
    pin_memory: bool = param.Boolean(
        True, doc="Value of pin_memory argument to DataLoader")
    _overrides: Dict[str, Any] = param.Dict(
        instantiate=True,
        doc="Model config properties that were overridden from the commandline"
    )
    restrict_subjects: Optional[str] = \
        param.String(doc="Use at most this number of subjects for train, val, or test set (must be > 0 or None). "
                         "If None, do not modify the train, val, or test sets. If a string of the form 'i,j,k' where "
                         "i, j and k are integers, modify just the corresponding sets (i for train, j for val, k for "
                         "test). If any of i, j or j are missing or are negative, do not modify the corresponding "
                         "set. Thus a value of 20,,5 means limit training set to 20, keep validation set as is, and "
                         "limit test set to 5. If any of i,j,k is '+', discarded members of the other sets are added "
                         "to that set.",
                     allow_None=True)
    perform_training_set_inference: bool = \
        param.Boolean(False,
                      doc="If False (default), run full image inference on validation and test set after training. If "
                          "True, also run full image inference on the training set")
    perform_validation_and_test_set_inference: bool = \
        param.Boolean(True,
                      doc="If True (default), run full image inference on validation and test set after training.")
    _metrics_data_frame_loggers: MetricsDataframeLoggers = param.ClassSelector(
        default=None,
        class_=MetricsDataframeLoggers,
        instantiate=False,
        doc="Data frame loggers for this model "
        "config")
    _dataset_data_frame: Optional[DataFrame] = \
        param.DataFrame(default=None,
                        doc="The dataframe that contains the dataset for the model. This is usually read from disk "
                            "from dataset.csv")
    _use_gpu: Optional[bool] = param.Boolean(
        None,
        doc="If true, a CUDA capable GPU with at least 1 device is "
        "available. If None, the use_gpu property has not yet been called.")
    avoid_process_spawn_in_data_loaders: bool = \
        param.Boolean(is_windows(), doc="If True, use a data loader logic that avoids spawning new processes at the "
                                        "start of each epoch. This speeds up training on both Windows and Linux, but "
                                        "on Linux, inference is currently disabled as the data loaders hang. "
                                        "If False, use the default data loader logic that starts new processes for "
                                        "each epoch.")
    # The default multiprocessing start_method in both PyTorch and the Python standard library is "fork" for Linux and
    # "spawn" (the only available method) for Windows. There is some evidence that using "forkserver" on Linux
    # can reduce the chance of stuck jobs.
    multiprocessing_start_method: MultiprocessingStartMethod = \
        param.ClassSelector(class_=MultiprocessingStartMethod,
                            default=(MultiprocessingStartMethod.spawn if is_windows()
                                     else MultiprocessingStartMethod.fork),
                            doc="Method to be used to start child processes in pytorch. Should be one of forkserver, "
                                "fork or spawn. If not specified, fork is used on Linux and spawn on Windows. "
                                "Set to forkserver as a possible remedy for stuck jobs.")
    output_to: Optional[str] = \
        param.String(default=None,
                     doc="If provided, the run outputs will be written to the given folder. If not provided, outputs "
                         "will go into a subfolder of the project root folder.")
    max_batch_grad_cam: int = param.Integer(
        default=0,
        doc="Max number of validation batches for which "
        "to save gradCam images. By default "
        "visualizations are saved for all images "
        "in the validation set")
    label_smoothing_eps: float = param.Number(
        0.0,
        bounds=(0.0, 1.0),
        doc="Target smoothing value for label smoothing")
    log_to_parent_run: bool = param.Boolean(
        default=False,
        doc="If true, hyperdrive child runs will log their metrics"
        "to their parent run.")

    use_imbalanced_sampler_for_training: bool = param.Boolean(
        default=False,
        doc="If True, use an imbalanced sampler during training.")
    drop_last_batch_in_training: bool = param.Boolean(
        default=False,
        doc="If True, drop the last incomplete batch during"
        "training. If all batches are complete, no batch gets "
        "dropped. If False, keep all batches.")
    log_summaries_to_files: bool = param.Boolean(
        default=True,
        doc=
        "If True, model summaries are logged to files in logs/model_summaries; "
        "if False, to stdout or driver log")
    mean_teacher_alpha: float = param.Number(
        bounds=(0, 1),
        allow_None=True,
        default=None,
        doc="If this value is set, the mean teacher model will be computed. "
        "Currently only supported for scalar models. In this case, we only "
        "report metrics and cross-validation results for "
        "the mean teacher model. Likewise the model used for inference "
        "is the mean teacher model. The student model is only used for "
        "training. Alpha is the momentum term for weight updates of the mean "
        "teacher model. After each training step the mean teacher model "
        "weights are updated using mean_teacher_"
        "weight = alpha * (mean_teacher_weight) "
        " + (1-alpha) * (current_student_weights). ")

    def __init__(self, **params: Any) -> None:
        self._model_name = type(self).__name__
        # This should be annotated as torch.utils.data.Dataset, but we don't want to import torch here.
        self._datasets_for_training: Optional[Dict[ModelExecutionMode,
                                                   Any]] = None
        self._datasets_for_inference: Optional[Dict[ModelExecutionMode,
                                                    Any]] = None
        super().__init__(throw_if_unknown_param=True, **params)
        logging.info("Creating the default output folder structure.")
        self.create_filesystem(fixed_paths.repository_root_directory())

    def validate(self) -> None:
        """
        Validates the parameters stored in the present object.
        """
        if len(self.adam_betas) < 2:
            raise ValueError(
                "The adam_betas parameter should be the coefficients used for computing running averages of "
                "gradient and its square")

        if self.azure_dataset_id is None and self.local_dataset is None:
            raise ValueError(
                "Either of local_dataset or azure_dataset_id must be set.")

        if self.number_of_cross_validation_splits == 1:
            raise ValueError(
                f"At least two splits required to perform cross validation found "
                f"number_of_cross_validation_splits={self.number_of_cross_validation_splits}"
            )
        if 0 < self.number_of_cross_validation_splits <= self.cross_validation_split_index:
            raise ValueError(
                f"Cross validation split index is out of bounds: {self.cross_validation_split_index}, "
                f"which is invalid for CV with {self.number_of_cross_validation_splits} splits."
            )
        elif self.number_of_cross_validation_splits == 0 and self.cross_validation_split_index != -1:
            raise ValueError(
                f"Cross validation split index must be -1 for a non cross validation run, "
                f"found number_of_cross_validation_splits = {self.number_of_cross_validation_splits} "
                f"and cross_validation_split_index={self.cross_validation_split_index}"
            )

        if self.l_rate_scheduler == LRSchedulerType.MultiStep:
            if not self.l_rate_multi_step_milestones:
                raise ValueError(
                    "Must specify l_rate_multi_step_milestones to use LR scheduler MultiStep"
                )
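            # For example: milestones [10, 20, 30] pass the check below, while
            # [20, 10] or [10, 10, 20] fail because sorted(set(...)) differs from
            # the list itself.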
            if sorted(set(self.l_rate_multi_step_milestones)
                      ) != self.l_rate_multi_step_milestones:
                raise ValueError(
                    "l_rate_multi_step_milestones must be a strictly increasing list"
                )
            if self.l_rate_multi_step_milestones[0] <= 0:
                raise ValueError(
                    "l_rate_multi_step_milestones cannot be negative or 0.")

    @property
    def model_name(self) -> str:
        """
        Gets the human readable name of the model (e.g., Liver). This is usually set from the class name.
        :return: A model name as a string.
        """
        return self._model_name

    @property
    def model_category(self) -> ModelCategory:
        """
        Gets the high-level model category that this configuration object represents (segmentation or scalar output).
        """
        return self._model_category

    @property
    def is_segmentation_model(self) -> bool:
        """
        Returns True if the present model configuration belongs to the high-level category ModelCategory.Segmentation.
        """
        return self.model_category == ModelCategory.Segmentation

    @property
    def is_scalar_model(self) -> bool:
        """
        Returns True if the present model configuration belongs to the high-level category ModelCategory.Scalar
        i.e. for Classification or Regression models.
        """
        return self.model_category.is_scalar

    @property
    def compute_grad_cam(self) -> bool:
        return self.max_batch_grad_cam > 0

    @property
    def min_l_rate(self) -> float:
        return self._min_l_rate

    @min_l_rate.setter
    def min_l_rate(self, value: float) -> None:
        if value > self.l_rate:
            raise ValueError(
                "l_rate must be >= min_l_rate, found: {}, {}".format(
                    self.l_rate, value))
        self._min_l_rate = value

    @property
    def outputs_folder(self) -> Path:
        """Gets the full path in which the model outputs should be stored."""
        return self.file_system_config.outputs_folder

    @property
    def logs_folder(self) -> Path:
        """Gets the full path in which the model logs should be stored."""
        return self.file_system_config.logs_folder

    @property
    def checkpoint_folder(self) -> str:
        """Gets the full path in which the model checkpoints should be stored during training."""
        return str(self.outputs_folder / CHECKPOINT_FOLDER)

    @property
    def visualization_folder(self) -> Path:
        """Gets the full path in which the visualizations notebooks should be saved during training."""
        return self.outputs_folder / VISUALIZATION_FOLDER

    @property
    def perform_cross_validation(self) -> bool:
        """
        True if cross validation will be performed as part of the training procedure.
        :return:
        """
        return self.number_of_cross_validation_splits > 1

    @property
    def overrides(self) -> Optional[Dict[str, Any]]:
        return self._overrides

    @property
    def dataset_data_frame(self) -> Optional[DataFrame]:
        """
        Gets the pandas data frame that the model uses.
        :return:
        """
        return self._dataset_data_frame

    @dataset_data_frame.setter
    def dataset_data_frame(self, data_frame: Optional[DataFrame]) -> None:
        """
        Sets the pandas data frame that the model uses.
        :param data_frame: The data frame to set.
        """
        self._dataset_data_frame = data_frame

    @property
    def metrics_data_frame_loggers(self) -> MetricsDataframeLoggers:
        """
        Gets the metrics data frame loggers for this config.
        :return:
        """
        return self._metrics_data_frame_loggers

    def set_output_to(self, output_to: PathOrString) -> None:
        """
        Adjusts the file system settings in the present object such that all outputs are written to the given folder.
        :param output_to: The absolute path to a folder that should contain the outputs.
        """
        if isinstance(output_to, Path):
            output_to = str(output_to)
        self.output_to = output_to
        self.create_filesystem()

    def create_filesystem(
        self, project_root: Path = fixed_paths.repository_root_directory()
    ) -> None:
        """
        Creates new file system settings (outputs folder, logs folder) based on the information stored in the
        present object. If any of the folders do not yet exist, they are created.
        :param project_root: The root folder for the codebase that triggers the training run.
        """
        self.file_system_config = DeepLearningFileSystemConfig.create(
            project_root=project_root,
            model_name=self.model_name,
            is_offline_run=self.is_offline_run,
            output_to=self.output_to)

    def create_dataframe_loggers(self) -> None:
        """
        Initializes the metrics loggers that are stored in self._metrics_data_frame_loggers
        :return:
        """
        self._metrics_data_frame_loggers = MetricsDataframeLoggers(
            outputs_folder=self.outputs_folder)

    def should_load_checkpoint_for_training(self) -> bool:
        """Returns true if start epoch > 0, that is, if an existing checkpoint is used to continue training."""
        return self.start_epoch > 0

    def should_save_epoch(self, epoch: int) -> bool:
        """Returns True if the present epoch should be saved, as per the save_start_epoch and save_step_epochs
        settings. Epoch writing starts with the first epoch that is >= save_start_epoch, and that
        is evenly divisible by save_step_epochs. A checkpoint is always written for the last epoch (num_epochs),
        such that it is easy to overwrite num_epochs on the commandline without having to change the test parameters
        at the same time.
        :param epoch: The current epoch. The first epoch is assumed to be 1."""
        should_save_epoch = epoch >= self.save_start_epoch \
                            and epoch % self.save_step_epochs == 0
        is_last_epoch = epoch == self.num_epochs
        return should_save_epoch or is_last_epoch

    def get_train_epochs(self) -> List[int]:
        """
        Returns the epochs for which training will be performed.
        :return:
        """
        return list(range(self.start_epoch + 1, self.num_epochs + 1))

    def get_total_number_of_training_epochs(self) -> int:
        """
        Returns the number of epochs for which a model will be trained.
        :return:
        """
        return len(self.get_train_epochs())

    def get_total_number_of_save_epochs(self) -> int:
        """
        Returns the number of epochs for which a model checkpoint will be saved.
        :return:
        """
        return len(
            list(filter(self.should_save_epoch, self.get_train_epochs())))

    def get_total_number_of_validation_epochs(self) -> int:
        """
        Returns the number of epochs for which a model will be validated.
        :return:
        """
        return self.get_total_number_of_training_epochs()

    def get_test_epochs(self) -> List[int]:
        """
        Returns the list of epochs for which the model should be evaluated on full images in the test set.
        These are all epochs starting at self.test_start_epoch, in intervals of self.test_step_epochs.
        The last training epoch is always included. If any of the self.test_* fields is missing (set to None),
        only the last training epoch is returned.
        :return:
        """
        test_epochs = {self.num_epochs}
        if self.test_diff_epochs is not None and self.test_start_epoch is not None and \
                self.test_step_epochs is not None:
            for j in range(self.test_diff_epochs):
                epoch = self.test_start_epoch + self.test_step_epochs * j
                if epoch > self.num_epochs:
                    break
                test_epochs.add(epoch)
        return sorted(test_epochs)

    def get_path_to_checkpoint(self, epoch: int) -> Path:
        """
        Returns the full path to a checkpoint for the given epoch.
        :param epoch: the epoch number
        :return: path to a checkpoint given an epoch
        """
        return create_checkpoint_path(
            path=fixed_paths.repository_root_directory() /
            self.checkpoint_folder,
            epoch=epoch)

    def get_effective_random_seed(self) -> int:
        """
        Returns the random seed set as part of this configuration. If the configuration corresponds
        to a cross validation split, then the cross validation fold index will be added to the
        set random seed in order to return the effective random seed.
        :return:
        """
        seed = self.random_seed
        if self.perform_cross_validation:
            # offset the random seed based on the cross validation split index so each
            # fold has a different initial random state.
            seed += self.cross_validation_split_index
        return seed

    @property  # type: ignore
    def use_gpu(self) -> bool:  # type: ignore
        """
        Returns True if a CUDA capable GPU is present and should be used, False otherwise.
        """
        if self._use_gpu is None:
            # Use a local import here because we don't want the whole file to depend on pytorch.
            from InnerEye.ML.utils.ml_util import is_gpu_available
            self._use_gpu = is_gpu_available()
        return self._use_gpu

    @use_gpu.setter
    def use_gpu(self, value: bool) -> None:
        """
        Sets the flag that controls the use of the GPU. Raises a ValueError if the value is True, but no GPU is
        present.
        """
        if value:
            # Use a local import here because we don't want the whole file to depend on pytorch.
            from InnerEye.ML.utils.ml_util import is_gpu_available
            if not is_gpu_available():
                raise ValueError(
                    "Can't set use_gpu to True if there is not CUDA capable GPU present."
                )
        self._use_gpu = value

    @property
    def use_data_parallel(self) -> bool:
        """
        Data parallel is used if GPUs are usable and the number of CUDA devices is greater than 1.
        :return:
        """
        _devices = self.get_cuda_devices()
        return _devices is not None and len(_devices) > 1

    def write_args_file(self, root: Optional[Path] = None) -> None:
        """
        Writes the current config to disk. The file is written either to the given folder, or if omitted,
        to the default outputs folder.
        """
        dst = (root or self.outputs_folder) / ARGS_TXT
        dst.write_text(data=str(self))

    def should_wait_for_other_cross_val_child_runs(self) -> bool:
        """
        Returns True if the current run is an online run and is the 0th cross validation split.
        In this case, this will be the run that will wait for all other child runs to finish in order
        to aggregate their results.
        :return:
        """
        return (
            not self.is_offline_run) and self.cross_validation_split_index == 0

    @property
    def is_offline_run(self) -> bool:
        """
        Returns True if the run is executing outside AzureML, or False if inside AzureML.
        """
        return is_offline_run_context(RUN_CONTEXT)

    @property
    def compute_mean_teacher_model(self) -> bool:
        """
        Returns True if the mean teacher model should be computed.
        """
        return self.mean_teacher_alpha is not None

    def __str__(self) -> str:
        """Returns a string describing the present object, as a list of key == value pairs."""
        arguments_str = "\nArguments:\n"
        property_dict = vars(self)
        keys = sorted(property_dict)
        for key in keys:
            arguments_str += "\t{:18}: {}\n".format(key, property_dict[key])
        return arguments_str
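A minimal, standalone sketch of the checkpoint-saving rule documented in should_save_epoch above, assuming the same save_start_epoch / save_step_epochs / num_epochs semantics; it deliberately does not import the configuration class and the parameter values are illustrative.

def should_save_epoch(epoch, save_start_epoch, save_step_epochs, num_epochs):
    # A checkpoint is written once the epoch reaches save_start_epoch and is
    # evenly divisible by save_step_epochs; the last epoch is always written.
    should_save = epoch >= save_start_epoch and epoch % save_step_epochs == 0
    return should_save or epoch == num_epochs

# With save_start_epoch=10, save_step_epochs=20 and num_epochs=50,
# checkpoints would be written for epochs 20, 40 and 50.
saved = [e for e in range(1, 51) if should_save_epoch(e, 10, 20, 50)]
print(saved)  # [20, 40, 50]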
Example #26
class run_batch(ParameterizedFunction):
    """
    Run a Topographica simulation in batch mode.

    Features:

      - Generates a unique, well-defined name for each 'experiment'
        (i.e. simulation run) based on the date, script file, and
        parameter settings. Note that very long names may be truncated
        (see the max_name_length parameter).

      - Allows parameters to be varied on the command line,
        making it easy to compare different settings

      - Saves a script capturing the simulation state periodically,
        to preserve parameter values from old experiments and to allow
        them to be reproduced exactly later

      - Can perform user-specified analysis routines periodically,
        to monitor the simulation as it progresses.

      - Stores commandline output (stdout) in the output directory

    A typical use of this function is for remote execution of a large
    number of simulations with different parameters, often on remote
    machines (such as clusters).

    The script_file parameter defines the .ty script we want to run in
    batch mode. The output_directory defines the root directory in
    which a unique individual directory will be created for this
    particular run.  The optional analysis_fn can be any python
    function to be called at each of the simulation iterations defined
    in the analysis times list.  The analysis_fn should perform
    whatever analysis of the simulation you want to perform, such as
    plotting or calculating some statistics.  The analysis_fn should
    avoid using any GUI functions (i.e., should not import anything
    from topo.tkgui), and it should save all of its results into
    files.

    As a special case, a number can be passed for the times list, in
    which case it is used to scale a default list of times up to
    10000; e.g. times=2 will select a default list of times up to
    20000.  Alternatively, an explicit list of times can be supplied.

    Any other optional parameters supplied will be set in the main
    namespace before any scripts are run.  They will also be used to
    construct a unique topo.sim.name for the file, and they will be
    encoded into the simulation directory name, to make it clear how
    each simulation differs from the others.

    If requested by setting snapshot=True, saves a snapshot at the
    end of the simulation.

    If available and requested by setting vc_info=True, prints
    the revision number and any outstanding diffs from the version
    control system.

    Note that this function alters param.normalize_path.prefix so that
    all output goes into the same location. The original value of
    param.normalize_path.prefix is deliberately not restored at the
    end of the function so that the output of any subsequent commands
    will go into the same place.
    """
    output_directory = param.String("Output")

    analysis_fn = param.Callable(default_analysis_function)

    times = param.Parameter(1.0)

    snapshot = param.Boolean(True)

    vc_info = param.Boolean(True)

    dirname_prefix = param.String(default="",
                                  doc="""
        Optional prefix for the directory name (allowing e.g. easy
        grouping).""")

    tag = param.String(default="",
                       doc="""
        Optional tag to embed in directory prefix to allow unique
        directory naming across multiple independent batches that
        share a common timestamp.""")

    # CB: do any platforms also have a maximum total path length?
    max_name_length = param.Number(default=200,
                                   doc="""
        The experiment's directory name will be truncated at this
        number of characters (since most filesystems have a
        limit).""")

    name_time_format = param.String(default="%Y%m%d%H%M",
                                    doc="""
        String format for the time included in the output directory
        and file names.  See the Python time module library
        documentation for codes.

        E.g. Adding '%S' to the default would include seconds.""")

    timestamp = param.NumericTuple(default=(0, 0),
                                   doc="""
        Optional override of timestamp in Python struct_time 8-tuple format.
        Useful when running many run_batch commands as part of a group with
        a shared timestamp. By default, the timestamp used is the time when
        run_batch is started.""")

    save_global_params = param.Boolean(default=True,
                                       doc="""
        Whether to save the script's global_parameters to a pickle in
        the output_directory after the script has been loaded (for
        e.g. future inspection of the experiment).""")

    dirname_params_filter = param.Callable(param_formatter.instance(),
                                           doc="""
        Function to control how the parameter names will appear in the
        output_directory's name.""")

    metadata_dir = param.String(doc="""Specifies the name of a
        subdirectory used to output metadata from run_batch if set.""")

    def _truncate(self, p, s):
        """
        If s is longer than the max_name_length parameter, truncate it
        (and indicate that it has been truncated).
        """
        # '___' at the end is supposed to represent '...'
        return s if len(s) <= p.max_name_length else s[0:p.max_name_length -
                                                       3] + '___'

    def __call__(self, script_file, **params_to_override):
        p = ParamOverrides(self, params_to_override, allow_extra_keywords=True)

        import os
        import shutil

        # Construct simulation name, etc.
        scriptbase = re.sub(r'\.ty$', '', os.path.basename(script_file))
        prefix = ""
        if p.timestamp == (0, 0): prefix += time.strftime(p.name_time_format)
        else: prefix += time.strftime(p.name_time_format, p.timestamp)

        prefix += "_" + scriptbase + "_" + p.tag
        simname = prefix

        # Construct parameter-value portion of filename; should do more filtering
        # CBENHANCEMENT: should provide chance for user to specify a
        # function (i.e. make this a function, and have a parameter to
        # allow the function to be overridden).
        # And sort by name by default? Skip ones that aren't different
        # from default, or at least put them at the end?
        prefix += p.dirname_params_filter(p.extra_keywords())

        # Set provided parameter values in main namespace
        from topo.misc.commandline import global_params
        global_params.set_in_context(**p.extra_keywords())

        # Create output directories
        if not os.path.isdir(normalize_path(p.output_directory)):
            try:
                os.mkdir(normalize_path(p.output_directory))
            except OSError:
                pass  # Catches potential race condition (simultaneous run_batch runs)

        dirname = self._truncate(p, p.dirname_prefix + prefix)
        dirpath = normalize_path(os.path.join(p.output_directory, dirname))
        normalize_path.prefix = dirpath
        metadata_dir = os.path.join(normalize_path.prefix, p.metadata_dir)
        simpath = os.path.join(metadata_dir, simname)

        if os.path.isdir(normalize_path.prefix):
            print "Batch run: Warning -- directory already exists!"
            print "Run aborted; wait one minute before trying again, or else rename existing directory: \n" + \
                  normalize_path.prefix

            sys.exit(-1)
        else:
            os.makedirs(metadata_dir)
            print "Batch run output will be in " + normalize_path.prefix

        if p.vc_info:
            _print_vc_info(simpath + ".diffs")

        hostinfo = "Host: " + " ".join(platform.uname())
        topographicalocation = "Topographica: " + os.path.abspath(sys.argv[0])
        topolocation = "topo package: " + os.path.abspath(topo.__file__)
        scriptlocation = "script: " + os.path.abspath(script_file)

        starttime = time.time()
        startnote = "Batch run started at %s." % time.strftime(
            "%a %d %b %Y %H:%M:%S +0000", time.gmtime())

        # store a re-runnable copy of the command used to start this batch run
        try:
            # pipes.quote is undocumented, so I'm not sure which
            # versions of python include it (I checked python 2.6 and
            # 2.7 on linux; they both have it).
            import pipes
            quotefn = pipes.quote
        except (ImportError, AttributeError):
            # command will need a human to insert quotes before it can be re-used
            quotefn = lambda x: x

        command_used_to_start = string.join([quotefn(arg) for arg in sys.argv])

        # CBENHANCEMENT: would be nice to separately write out a
        # runnable script that does everything necessary to
        # re-generate results (applies diffs etc).

        # Shadow stdout to a .out file in the output directory, so that
        # print statements will go to both the file and to stdout.
        batch_output = open(normalize_path(simpath + ".out"), 'w')
        batch_output.write(command_used_to_start + "\n")
        sys.stdout = MultiFile(batch_output, sys.stdout)

        print
        print hostinfo
        print topographicalocation
        print topolocation
        print scriptlocation
        print
        print startnote

        from topo.misc.commandline import auto_import_commands
        auto_import_commands()

        # Ensure that saved state includes all parameter values
        from topo.command import save_script_repr
        param.parameterized.script_repr_suppress_defaults = False

        # Save a copy of the script file for reference
        shutil.copy2(script_file, normalize_path.prefix)
        shutil.move(normalize_path(scriptbase + ".ty"),
                    normalize_path(simpath + ".ty"))

        # Default case: times is just a number that scales a standard list of times
        times = p.times
        if not isinstance(times, list):
            times = [
                t * times for t in
                [0, 50, 100, 500, 1000, 2000, 3000, 4000, 5000, 10000]
            ]

        # Run script in main
        error_count = 0
        initial_warning_count = param.parameterized.warning_count
        try:
            execfile(script_file, __main__.__dict__)  #global_params.context
            global_params.check_for_unused_names()
            if p.save_global_params:
                _save_parameters(p.extra_keywords(),
                                 simpath + ".global_params.pickle")
            print_sizes()
            topo.sim.name = simname

            # Run each segment, doing the analysis and saving the script state each time
            for run_to in times:
                topo.sim.run(run_to - topo.sim.time())
                p.analysis_fn()
                normalize_path.prefix = metadata_dir
                save_script_repr()
                normalize_path.prefix = dirpath
                elapsedtime = time.time() - starttime
                param.Parameterized(name="run_batch").message(
                    "Elapsed real time %02d:%02d." %
                    (int(elapsedtime / 60), int(elapsedtime % 60)))

            if p.snapshot:
                save_snapshot()

        except:
            error_count += 1
            import traceback
            traceback.print_exc(file=sys.stdout)
            sys.stderr.write("Warning -- Error detected: execution halted.\n")

        print "\nBatch run completed at %s." % time.strftime(
            "%a %d %b %Y %H:%M:%S +0000", time.gmtime())
        print "There were %d error(s) and %d warning(s)%s." % \
              (error_count,(param.parameterized.warning_count-initial_warning_count),
               ((" (plus %d warning(s) prior to entering run_batch)"%initial_warning_count
                 if initial_warning_count>0 else "")))

        # restore stdout
        sys.stdout = sys.__stdout__
        batch_output.close()
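A hedged usage sketch based only on the call signature and docstring above: the script path, the extra keyword (cortex_density) and the time list are illustrative assumptions, not values taken from any particular project.

# Hypothetical invocation: run a .ty script in batch mode, overriding one of
# its global parameters and requesting analysis at two simulation times.
run_batch('examples/tiny.ty',
          output_directory='Output',
          times=[100, 1000],
          snapshot=False,
          cortex_density=5)  # extra keyword, set in the script's namespace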
Example #27
class ResamplingOperation(Operation):
    """
    Abstract baseclass for resampling operations
    """

    dynamic = param.Boolean(default=True,
                            doc="""
       Enables dynamic processing by default.""")

    expand = param.Boolean(default=True,
                           doc="""
       Whether the x_range and y_range should be allowed to expand
       beyond the extent of the data.  Setting this value to True is
       useful for the case where you want to ensure a certain size of
       output grid, e.g. if you are doing masking or other arithmetic
       on the grids.  A value of False ensures that the grid is only
       just as large as it needs to be to contain the data, which will
       be faster and use less memory if the resulting aggregate is
       being overlaid on a much larger background.""")

    height = param.Integer(default=400,
                           doc="""
       The height of the output image in pixels.""")

    width = param.Integer(default=400,
                          doc="""
       The width of the output image in pixels.""")

    x_range = param.NumericTuple(default=None,
                                 length=2,
                                 doc="""
       The x_range as a tuple of min and max x-value. Auto-ranges
       if set to None.""")

    y_range = param.NumericTuple(default=None,
                                 length=2,
                                 doc="""
       The y_range as a tuple of min and max y-value. Auto-ranges
       if set to None.""")

    x_sampling = param.Number(default=None,
                              doc="""
        Specifies the smallest allowed sampling interval along the x-axis.""")

    y_sampling = param.Number(default=None,
                              doc="""
        Specifies the smallest allowed sampling interval along the y-axis.""")

    target = param.ClassSelector(class_=Image,
                                 doc="""
        A target Image which defines the desired x_range, y_range,
        width and height.
    """)

    streams = param.List(default=[PlotSize, RangeXY],
                         doc="""
        List of streams that are applied if dynamic=True, allowing
        for dynamic interaction with the plot.""")

    element_type = param.ClassSelector(class_=(Dataset, ),
                                       instantiate=False,
                                       is_instance=False,
                                       default=Image,
                                       doc="""
        The type of the returned Elements, must be a 2D Dataset type.""")

    link_inputs = param.Boolean(default=True,
                                doc="""
        By default, the link_inputs parameter is set to True so that
        when applying shade, backends that support linked streams
        update RangeXY streams on the inputs of the shade operation.
        Disable when you do not want the resulting plot to be interactive,
        e.g. when trying to display an interactive plot a second time.""")

    def _get_sampling(self, element, x, y):
        target = self.p.target
        if target:
            x_range, y_range = target.range(x), target.range(y)
            height, width = target.dimension_values(2, flat=False).shape
        else:
            if x is None or y is None:
                x_range = self.p.x_range or (-0.5, 0.5)
                y_range = self.p.y_range or (-0.5, 0.5)
            else:
                if self.p.expand or not self.p.x_range:
                    x_range = self.p.x_range or element.range(x)
                else:
                    x0, x1 = self.p.x_range
                    ex0, ex1 = element.range(x)
                    x_range = max([x0, ex0]), min([x1, ex1])
                if self.p.expand or not self.p.y_range:
                    y_range = self.p.y_range or element.range(y)
                else:
                    y0, y1 = self.p.y_range
                    ey0, ey1 = element.range(y)
                    y_range = max([y0, ey0]), min([y1, ey1])
            width, height = self.p.width, self.p.height
        (xstart, xend), (ystart, yend) = x_range, y_range

        # Compute highest allowed sampling density
        xspan = xend - xstart
        yspan = yend - ystart
        if self.p.x_sampling:
            width = int(min([(xspan / self.p.x_sampling), width]))
        if self.p.y_sampling:
            height = int(min([(yspan / self.p.y_sampling), height]))
        xunit, yunit = float(xspan) / width, float(yspan) / height
        xs, ys = (np.linspace(xstart + xunit / 2., xend - xunit / 2., width),
                  np.linspace(ystart + yunit / 2., yend - yunit / 2., height))
        return (x_range, y_range), (xs, ys), (width, height)
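The core of _get_sampling above is the cell-centre computation: the requested span is divided into width (or height) cells and the sample coordinates are placed at the centre of each cell. A minimal numpy sketch of that step, with illustrative numbers:

import numpy as np

xstart, xend, width = 0.0, 10.0, 5     # illustrative range and pixel count
xunit = float(xend - xstart) / width   # size of one output cell
xs = np.linspace(xstart + xunit / 2., xend - xunit / 2., width)
print(xs)  # [1. 3. 5. 7. 9.] -- one coordinate per cell centre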
Example #28
File: options.py  Project: gyenney/Tools
class Palette(Cycle):
    """
    Palettes allow easily specifying a discrete sampling
    of an existing colormap. Palettes may be supplied a key
    to look up a function in the colormaps class attribute.
    The function should accept a float scalar in the
    specified range and return an RGB(A) tuple. The number
    of samples may also be specified as a parameter.

    The range and samples may conveniently be overridden
    with the __getitem__ method.
    """

    key = param.String(default='grayscale',
                       doc="""
       Palettes look up the Palette values based on some key.""")

    range = param.NumericTuple(default=(0, 1),
                               doc="""
        The range from which the Palette values are sampled.""")

    samples = param.Integer(default=32,
                            doc="""
        The number of samples in the given range to supply to
        the sample_fn.""")

    sample_fn = param.Callable(default=np.linspace,
                               doc="""
        The function to generate the samples, by default linear.""")

    reverse = param.Boolean(default=False,
                            doc="""
        Whether to reverse the palette.""")

    # A list of available colormaps
    colormaps = {'grayscale': grayscale}

    def __init__(self, key, **params):
        super(Cycle, self).__init__(key=key, **params)
        self.values = self._get_values()

    def __getitem__(self, slc):
        """
        Provides a convenient interface to override the
        range and samples parameters of the Cycle.
        Supplying a slice step or index overrides the
        number of samples. Unsupplied slice values will be
        inherited.
        """
        (start, stop), step = self.range, self.samples
        if isinstance(slc, slice):
            if slc.start is not None:
                start = slc.start
            if slc.stop is not None:
                stop = slc.stop
            if slc.step is not None:
                step = slc.step
        else:
            step = slc
        return self(range=(start, stop), samples=step)

    def _get_values(self):
        cmap = self.colormaps[self.key]
        (start, stop), steps = self.range, self.samples
        samples = [cmap(n) for n in self.sample_fn(start, stop, steps)]
        return samples[::-1] if self.reverse else samples
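A standalone sketch of the __getitem__ override semantics above, using a plain grayscale function in place of the colormaps registry; the helper name palette_samples is an illustration, not part of the Palette API.

import numpy as np

def grayscale(v):
    # Stand-in for an entry in Palette.colormaps: scalar in -> RGB tuple out.
    return (float(v), float(v), float(v))

def palette_samples(cmap, rng=(0, 1), samples=32, slc=None):
    (start, stop), steps = rng, samples
    if isinstance(slc, slice):          # slice start/stop override the range,
        start = slc.start if slc.start is not None else start
        stop = slc.stop if slc.stop is not None else stop
        steps = slc.step if slc.step is not None else steps
    elif slc is not None:               # a bare integer overrides the sample count
        steps = slc
    return [cmap(n) for n in np.linspace(start, stop, steps)]

# Equivalent of Palette('grayscale')[3]: three evenly spaced gray values.
print(palette_samples(grayscale, slc=3))
# [(0.0, 0.0, 0.0), (0.5, 0.5, 0.5), (1.0, 1.0, 1.0)]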
Example #29
class ResamplingOperation(LinkableOperation):
    """
    Abstract baseclass for resampling operations
    """

    dynamic = param.Boolean(default=True,
                            doc="""
       Enables dynamic processing by default.""")

    expand = param.Boolean(default=True,
                           doc="""
       Whether the x_range and y_range should be allowed to expand
       beyond the extent of the data.  Setting this value to True is
       useful for the case where you want to ensure a certain size of
       output grid, e.g. if you are doing masking or other arithmetic
       on the grids.  A value of False ensures that the grid is only
       just as large as it needs to be to contain the data, which will
       be faster and use less memory if the resulting aggregate is
       being overlaid on a much larger background.""")

    height = param.Integer(default=400,
                           doc="""
       The height of the output image in pixels.""")

    width = param.Integer(default=400,
                          doc="""
       The width of the output image in pixels.""")

    x_range = param.NumericTuple(default=None,
                                 length=2,
                                 doc="""
       The x_range as a tuple of min and max x-value. Auto-ranges
       if set to None.""")

    y_range = param.NumericTuple(default=None,
                                 length=2,
                                 doc="""
       The y_range as a tuple of min and max y-value. Auto-ranges
       if set to None.""")

    x_sampling = param.Number(default=None,
                              doc="""
        Specifies the smallest allowed sampling interval along the x-axis.""")

    y_sampling = param.Number(default=None,
                              doc="""
        Specifies the smallest allowed sampling interval along the y-axis.""")

    target = param.ClassSelector(class_=Image,
                                 doc="""
        A target Image which defines the desired x_range, y_range,
        width and height.
    """)

    streams = param.List(default=[PlotSize, RangeXY],
                         doc="""
        List of streams that are applied if dynamic=True, allowing
        for dynamic interaction with the plot.""")

    element_type = param.ClassSelector(class_=(Dataset, ),
                                       instantiate=False,
                                       is_instance=False,
                                       default=Image,
                                       doc="""
        The type of the returned Elements, must be a 2D Dataset type.""")

    precompute = param.Boolean(default=False,
                               doc="""
        Whether to apply precomputing operations. Precomputing can
        speed up resampling operations by avoiding unnecessary
        recomputation if the supplied element does not change between
        calls. The cost of enabling this option is that the memory
        used to represent this internal state is not freed between
        calls.""")

    @bothmethod
    def instance(self_or_cls, **params):
        inst = super(ResamplingOperation, self_or_cls).instance(**params)
        inst._precomputed = {}
        return inst

    def _get_sampling(self, element, x, y):
        target = self.p.target
        if target:
            x_range, y_range = target.range(x), target.range(y)
            height, width = target.dimension_values(2, flat=False).shape
        else:
            if x is None or y is None:
                x_range = self.p.x_range or (-0.5, 0.5)
                y_range = self.p.y_range or (-0.5, 0.5)
            else:
                if self.p.expand or not self.p.x_range:
                    x_range = self.p.x_range or element.range(x)
                else:
                    x0, x1 = self.p.x_range
                    ex0, ex1 = element.range(x)
                    x_range = np.max([x0, ex0]), np.min([x1, ex1])
                if x_range[0] == x_range[1]:
                    x_range = (x_range[0] - 0.5, x_range[0] + 0.5)

                if self.p.expand or not self.p.y_range:
                    y_range = self.p.y_range or element.range(y)
                else:
                    y0, y1 = self.p.y_range
                    ey0, ey1 = element.range(y)
                    y_range = np.max([y0, ey0]), np.min([y1, ey1])
            width, height = self.p.width, self.p.height
        (xstart, xend), (ystart, yend) = x_range, y_range

        xtype = 'numeric'
        if isinstance(xstart, datetime_types) or isinstance(
                xend, datetime_types):
            xstart, xend = dt_to_int(xstart, 'ns'), dt_to_int(xend, 'ns')
            xtype = 'datetime'
        elif not np.isfinite(xstart) and not np.isfinite(xend):
            if element.get_dimension_type(x) in datetime_types:
                xstart, xend = 0, 10000
                xtype = 'datetime'
            else:
                xstart, xend = 0, 1
        elif xstart == xend:
            xstart, xend = (xstart - 0.5, xend + 0.5)
        x_range = (xstart, xend)

        ytype = 'numeric'
        if isinstance(ystart, datetime_types) or isinstance(
                yend, datetime_types):
            ystart, yend = dt_to_int(ystart, 'ns'), dt_to_int(yend, 'ns')
            ytype = 'datetime'
        elif not np.isfinite(ystart) and not np.isfinite(yend):
            if element.get_dimension_type(y) in datetime_types:
                ystart, yend = 0, 10000
                ytype = 'datetime'
            else:
                ystart, yend = 0, 1
        elif ystart == yend:
            ystart, yend = (ystart - 0.5, yend + 0.5)
        y_range = (ystart, yend)

        # Compute highest allowed sampling density
        xspan = xend - xstart
        yspan = yend - ystart
        if self.p.x_sampling:
            width = int(min([(xspan / self.p.x_sampling), width]))
        if self.p.y_sampling:
            height = int(min([(yspan / self.p.y_sampling), height]))
        width, height = max([width, 1]), max([height, 1])
        xunit, yunit = float(xspan) / width, float(yspan) / height
        xs, ys = (np.linspace(xstart + xunit / 2., xend - xunit / 2., width),
                  np.linspace(ystart + yunit / 2., yend - yunit / 2., height))

        return (x_range, y_range), (xs, ys), (width, height), (xtype, ytype)
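Compared with the earlier version (Example #27), this variant also guards against degenerate inputs: a zero-width range is padded by half a unit on each side, and the pixel counts are clamped to at least one. A tiny standalone sketch of those two guards, with illustrative values:

# Zero-width range: pad by 0.5 on each side, as in _get_sampling above.
xstart, xend = 3.0, 3.0
if xstart == xend:
    xstart, xend = xstart - 0.5, xend + 0.5
print((xstart, xend))   # (2.5, 3.5)

# A very coarse x_sampling can drive the width to zero; clamp it back to 1.
xspan, x_sampling, width = 1.0, 10.0, 400
width = int(min([xspan / x_sampling, width]))   # -> 0
width = max([width, 1])
print(width)            # 1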
Example #30
class histogram(ElementOperation):
    """
    Returns a Histogram of the input element data, binned into
    num_bins over the bin_range (if specified) along the specified
    dimension.

    If adjoin is True, the histogram will be returned adjoined to the
    Element as a side-plot.
    """

    adjoin = param.Boolean(default=True,
                           doc="""
      Whether to adjoin the histogram to the ViewableElement.""")

    bin_range = param.NumericTuple(default=(0, 0),
                                   doc="""
      Specifies the range within which to compute the bins.""")

    dimension = param.String(default=None,
                             doc="""
      Along which dimension of the ViewableElement to compute the histogram."""
                             )

    individually = param.Boolean(default=True,
                                 doc="""
      Specifies whether the histogram will be rescaled for each Raster in a UniformNdMapping."""
                                 )

    mean_weighted = param.Boolean(default=False,
                                  doc="""
      Whether the weighted frequencies are averaged.""")

    normed = param.Boolean(default=True,
                           doc="""
      Whether the histogram frequencies are normalized.""")

    nonzero = param.Boolean(default=False,
                            doc="""
      Whether to use only nonzero values when computing the histogram""")

    num_bins = param.Integer(default=20,
                             doc="""
      Number of bins in the histogram.""")

    weight_dimension = param.String(default=None,
                                    doc="""
       Name of the dimension the weighting should be drawn from""")

    style_prefix = param.String(default=None,
                                allow_None=None,
                                doc="""
      Used for setting a common style for histograms in a HoloMap or AdjointLayout."""
                                )

    def _process(self, view, key=None):
        if self.p.dimension:
            selected_dim = self.p.dimension
        else:
            selected_dim = [d.name for d in view.vdims + view.kdims][0]
        data = np.array(view.dimension_values(selected_dim))
        if self.p.nonzero:
            mask = data > 0
            data = data[mask]
        if self.p.weight_dimension:
            weights = np.array(view.dimension_values(self.p.weight_dimension))
            if self.p.nonzero:
                weights = weights[mask]
        else:
            weights = None
        hist_range = find_minmax((np.nanmin(data), np.nanmax(data)), (0, -float('inf')))\
            if self.p.bin_range is None else self.p.bin_range

        # Avoids range issues including zero bin range and empty bins
        if hist_range == (0, 0):
            hist_range = (0, 1)
        data = data[np.invert(np.isnan(data))]
        normed = False if self.p.mean_weighted and self.p.weight_dimension else self.p.normed
        try:
            hist, edges = np.histogram(data[np.isfinite(data)],
                                       normed=normed,
                                       range=hist_range,
                                       weights=weights,
                                       bins=self.p.num_bins)
            if not normed and self.p.weight_dimension and self.p.mean_weighted:
                hist_mean, _ = np.histogram(data[np.isfinite(data)],
                                            normed=normed,
                                            range=hist_range,
                                            bins=self.p.num_bins)
                hist /= hist_mean
        except:
            edges = np.linspace(hist_range[0], hist_range[1],
                                self.p.num_bins + 1)
            hist = np.zeros(self.p.num_bins)

        hist[np.isnan(hist)] = 0

        params = {}
        if self.p.weight_dimension:
            params['vdims'] = [view.get_dimension(self.p.weight_dimension)]
        if view.group != view.__class__.__name__:
            params['group'] = view.group

        hist_view = Histogram(hist,
                              edges,
                              kdims=[view.get_dimension(selected_dim)],
                              label=view.label,
                              **params)

        return (view << hist_view) if self.p.adjoin else hist_view
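The mean_weighted branch in _process above computes a per-bin average of the weight values by dividing the weight-summed histogram by the plain counts. A minimal numpy sketch of that step, independent of the Histogram element; the data and weights are illustrative:

import numpy as np

data = np.array([0.1, 0.2, 0.6, 0.7, 0.9])
weights = np.array([1.0, 3.0, 2.0, 4.0, 6.0])
bins, rng = 2, (0.0, 1.0)

weighted, edges = np.histogram(data, bins=bins, range=rng, weights=weights)
counts, _ = np.histogram(data, bins=bins, range=rng)
mean_per_bin = weighted / counts   # average weight of the samples in each bin
print(edges)          # [0.  0.5 1. ]
print(mean_per_bin)   # [2. 4.]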