import logging
import os
import warnings
from collections import defaultdict, namedtuple
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

import arpes.config
from arpes.typing import DataType

# Helpers such as rebin, normalize_dim, labeled_fermi_surface, path_for_plot,
# and remap_coords_to are provided elsewhere in the arpes package.


def normalize_to_dataset(data: DataType):
    """Coerces `data` into an `xr.Dataset`, loading it by UUID if a string is passed."""
    from arpes.io import load_dataset

    if isinstance(data, xr.Dataset):
        return data

    if isinstance(data, str):
        return load_dataset(dataset_uuid=data)
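
# Example usage (a sketch; 'f81bb829' stands in for a real dataset UUID):
#
#     ds = normalize_to_dataset('f81bb829')  # loads by UUID
#     ds = normalize_to_dataset(ds)          # already a Dataset, returned unchanged
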
def attach_extra_dataset_columns(path, **kwargs):
    """Loads every scan listed in a cleaned dataset spreadsheet and fills in extra
    columns (currently `spectrum_type`) from each scan's attributes or accessors.
    """
    from arpes.io import load_dataset
    import arpes.xarray_extensions  # pylint: disable=unused-import, redefined-outer-name

    base_filename, extension = os.path.splitext(path)
    if extension not in _DATASET_EXTENSIONS:
        logging.warning('Not an Excel dataset file: %s', path)
        return None

    if 'cleaned' in base_filename:
        new_filename = base_filename + extension
    else:
        new_filename = base_filename + '.cleaned' + extension
    assert os.path.exists(new_filename), 'Expected a cleaned dataset at {}'.format(new_filename)

    ds = pd.read_excel(new_filename, **kwargs)

    ColumnDef = namedtuple('ColumnDef', ['default', 'source'])
    add_columns = {
        'spectrum_type': ColumnDef('', 'attr'),
    }

    for column, definition in add_columns.items():
        ds[column] = definition.default

    # Add required columns
    if 'id' not in ds:
        ds['id'] = np.nan

    if 'path' not in ds:
        ds['path'] = ''

    # Load each scan and populate the extra columns from it
    for index, row in ds.sort_index().iterrows():
        try:
            scan = load_dataset(dataset_uuid=row.id, df=ds)
        except ValueError as e:
            logging.warning(str(e))
            logging.warning('Skipping %s! Unable to load scan.', row.id)
            continue

        for column, definition in add_columns.items():
            if definition.source == 'accessor':
                ds.loc[index, column] = getattr(scan.S, column)
            elif definition.source == 'attr':
                ds.loc[index, column] = scan.attrs[column]

    # Overwrite the cleaned spreadsheet with the augmented table
    with pd.ExcelWriter(new_filename) as excel_writer:
        ds.to_excel(excel_writer, index=False)

    return ds.set_index('file')
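
# Example usage (a sketch; the spreadsheet path and file index are hypothetical):
#
#     df = attach_extra_dataset_columns('datasets/june_beamtime.xlsx')
#     df.loc['scan_041', 'spectrum_type']  # filled from the loaded scan's attrs
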
    def make_tool(self,
                  arr: Union[xr.DataArray, str],
                  notebook_url=None,
                  notebook_handle=True,
                  **kwargs):
        """Launches this tool as a Bokeh application inside the notebook,
        loading, cleaning, and rebinning the data first as needed.
        """
        from bokeh.application import Application
        from bokeh.application.handlers import FunctionHandler
        from bokeh.io import show

        from arpes.io import load_dataset

        def generate_url(port):
            if port is None:
                return 'localhost:8888'

            return 'localhost:{}'.format(port)

        if notebook_url is None:
            notebook_url = generate_url(arpes.config.CONFIG.get('PORT'))

        if isinstance(arr, str):
            arr = load_dataset(arr)
            if 'cycle' in arr.dims and len(arr.dims) > 3:
                warnings.warn('Summing over cycle')
                arr = arr.sum('cycle', keep_attrs=True)

        if self.auto_zero_nans and {'kx', 'ky', 'kz', 'kp'}.intersection(
                set(arr.dims)):
            # We need to copy and make sure to clear any nan values, because bokeh
            # does not send these over the wire for some reason
            arr = arr.copy()
            np.nan_to_num(arr.values, copy=False)

        # rebin any axes that have more than `self.rebin_size` pixels
        if self.auto_rebin and np.any(np.asarray(arr.shape) > self.rebin_size):
            reduction = {
                d: (s // self.rebin_size) + 1
                for d, s in arr.S.dshape.items()
            }
            warnings.warn('Rebinning with {}'.format(reduction))

            arr = rebin(arr, reduction=reduction)

            # TODO pass in a reference to the original copy of the array and make sure that
            # preparation tasks move over transparently

        self.arr = arr
        handler = FunctionHandler(self.tool_handler)
        app = Application(handler)
        show(app, notebook_url=notebook_url, notebook_handle=notebook_handle)

        return self.app_context
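
    # Example usage (a sketch; `ImageTool` names a hypothetical subclass that
    # supplies `tool_handler`, `auto_rebin`, `rebin_size`, and `auto_zero_nans`):
    #
    #     tool = ImageTool()
    #     ctx = tool.make_tool(fermi_map, notebook_url='localhost:8889')
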
def normalize_to_spectrum(data: DataType):
    """Coerces `data` into a spectrum-like `xr.DataArray`, loading by UUID if necessary."""
    from arpes.io import load_dataset

    if isinstance(data, xr.Dataset):
        if 'up' in data.data_vars:
            return data.up

        return data.S.spectrum

    if isinstance(data, str):
        return normalize_to_spectrum(load_dataset(dataset_uuid=data))

    # not guaranteed to be a spectrum, but close enough
    return data
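
# Example usage (a sketch of the behavior for each input type):
#
#     normalize_to_spectrum(dataset)      # -> dataset.up or dataset.S.spectrum
#     normalize_to_spectrum('some-uuid')  # loads by UUID, then normalizes
#     normalize_to_spectrum(data_array)   # returned unchanged
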
def hv_reference_scan(data,
                      out=None,
                      e_cut=-0.05,
                      bkg_subtraction=0.8,
                      **kwargs):
    """Plots the Fermi surface of `data` near `e_cut` and marks the photon
    energies of the other scans recorded in the same dataset.
    """
    from arpes.io import load_dataset

    fs = data.S.fat_sel(eV=e_cut)
    fs = normalize_dim(fs, 'hv', keep_id=True)
    fs.data -= bkg_subtraction * np.mean(fs.data)
    fs.data[fs.data < 0] = 0

    _, ax = labeled_fermi_surface(fs, hold=True, **kwargs)

    all_scans = data.attrs['df']
    all_scans = all_scans[all_scans.id != data.attrs['id']]
    # exclude XPS spectra; anything else, including 'hv_map' scans, passes
    all_scans = all_scans[all_scans.spectrum_type != 'xps_spectrum']

    scans_by_hv = defaultdict(list)
    for _, row in all_scans.iterrows():
        scan = load_dataset(row.id)

        scans_by_hv[round(scan.S.hv)].append(scan.S.label.replace('_', ' '))

    dim_order = ax.dim_order
    handles = []
    handle_labels = []

    prop_cycle = plt.rcParams['axes.prop_cycle']
    colors_cycle = prop_cycle.by_key()['color']

    for line_color, (hv, labels) in zip(colors_cycle, scans_by_hv.items()):
        full_label = '\n'.join(labels)

        # determine direction
        if dim_order[0] == 'hv':
            # photon energy is along the x axis, we want an axvline
            handles.append(ax.axvline(hv, label=full_label, color=line_color))
        else:
            # photon energy is along the y axis, we want an axhline
            handles.append(ax.axhline(hv, label=full_label, color=line_color))

        handle_labels.append(full_label)

    plt.legend(handles, handle_labels)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    plt.show()
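
# Example usage (a sketch; assumes `hv_map` was loaded with its dataset
# spreadsheet attached at attrs['df']):
#
#     hv_reference_scan(hv_map, e_cut=-0.1, out='hv_reference.png')
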
def reference_scan_fermi_surface(data, out=None, **kwargs):
    """Plots the Fermi surface of `data` and overlays the trajectories of the
    scans which reference it, remapped into the coordinates of `data`.
    """
    from arpes.io import load_dataset

    fs = data.S.fermi_surface
    _, ax = labeled_fermi_surface(fs, hold=True, **kwargs)

    referenced_scans = data.S.referenced_scans
    handles = []
    for index, row in referenced_scans.iterrows():
        scan = load_dataset(row.id)
        remapped_coords = remap_coords_to(scan, data)
        dim_order = ax.dim_order
        ls = ax.plot(remapped_coords[dim_order[0]],
                     remapped_coords[dim_order[1]],
                     label=index.replace('_', ' '))
        handles.append(ls[0])

    plt.legend(handles=handles)

    if out is not None:
        plt.savefig(path_for_plot(out), dpi=400)
        return path_for_plot(out)

    plt.show()
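
# Example usage (a sketch; assumes `fmap.S.referenced_scans` is populated
# from the dataset spreadsheet):
#
#     reference_scan_fermi_surface(fmap, out='referenced_fs.png')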