Example #1
def load_ipython_extension(shell):
    import traitlets

    shell.add_traits(last_parent=traitlets.Any(),
                     cell_ids=traitlets.Set(),
                     id=traitlets.Any())
    shell.events.register(pre_run_cell.__name__, pre_run_cell)
    Kernel.patch()
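
Note: `pre_run_cell` and `Kernel` are defined elsewhere in the extension module. The `add_traits` call itself works on any `HasTraits` instance; a minimal sketch, assuming nothing beyond traitlets:

import traitlets

class Tracker(traitlets.HasTraits):
    pass

tracker = Tracker()
# add_traits attaches new trait descriptors to an existing instance at runtime
tracker.add_traits(cell_ids=traitlets.Set(), last_parent=traitlets.Any())
tracker.cell_ids = {"cell-1", "cell-2"}
print(tracker.cell_ids)  # {'cell-1', 'cell-2'}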
Example #2
import traitlets
from nbconvert import preprocessors


class LiteraryTagAllowListPreprocessor(preprocessors.Preprocessor):
    allow_cell_tags = traitlets.Set(traitlets.Unicode(),
                                    default_value={"export", "docstring"})

    def check_cell_conditions(self, cell, resources: dict, index: int) -> bool:
        tags = cell.metadata.get("tags", [])
        return bool(self.allow_cell_tags.intersection(tags))

    def preprocess(self, nb, resources: dict):
        nb.cells = [
            self.preprocess_cell(cell, resources, i)[0]
            for i, cell in enumerate(nb.cells)
            if self.check_cell_conditions(cell, resources, i)
        ]
        return nb, resources

    def preprocess_cell(self, cell, resources: dict, index: int):
        return cell, resources
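
A minimal usage sketch, assuming nbformat is installed: build a notebook in memory and keep only the cells whose tags intersect the allow list.

import nbformat.v4 as nbf

nb = nbf.new_notebook()
nb.cells = [
    nbf.new_code_cell("x = 1", metadata={"tags": ["export"]}),
    nbf.new_code_cell("scratch()", metadata={"tags": []}),
]

pre = LiteraryTagAllowListPreprocessor()
nb, resources = pre.preprocess(nb, {})
print(len(nb.cells))  # 1 -- only the tagged cell survives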
Example #3
class AiidaLabApp(traitlets.HasTraits):
    """Manage installation status of an AiiDA lab app.

    Arguments:

        name (str):
            Name of the AiiDA lab app.
        app_data (dict):
            Dictionary containing the app metadata.
        aiidalab_apps_path (str):
            Path to directory at which the app is expected to be installed.
    """

    path = traitlets.Unicode(allow_none=True, read_only=True)
    install_info = traitlets.Unicode()

    available_release_lines = traitlets.Set(traitlets.Unicode)
    installed_release_line = traitlets.Unicode(allow_none=True)
    installed_version = traitlets.Unicode(allow_none=True)
    updates_available = traitlets.Bool(read_only=True, allow_none=True)

    busy = traitlets.Bool(read_only=True)
    modified = traitlets.Bool(read_only=True, allow_none=True)

    class AppPathFileSystemEventHandler(FileSystemEventHandler):
        """Internal event handeler for app path file system events."""
        def __init__(self, app):
            self.app = app

        def on_any_event(self, event):
            """Refresh app for any event."""
            self.app.refresh_async()

    def __init__(self, name, app_data, aiidalab_apps_path):
        super().__init__()

        if app_data is not None:
            self._git_url = app_data['git_url']
            self._meta_url = app_data['meta_url']
            self._git_remote_refs = app_data['gitinfo']
            self.categories = app_data['categories']
            self._meta_info = app_data['metainfo']
        else:
            self._git_url = None
            self._meta_url = None
            self._git_remote_refs = {}
            self.categories = None
            self._meta_info = None

        self._observer = None
        self._check_install_status_changed_thread = None

        self.name = name
        self.path = os.path.join(aiidalab_apps_path, self.name)
        self.refresh_async()
        self._watch_repository()

    @traitlets.default('modified')
    def _default_modified(self):
        if self.is_installed():
            return self._repo.dirty()
        return None

    @traitlets.default('busy')
    def _default_busy(self):  # pylint: disable=no-self-use
        return False

    @contextmanager
    def _show_busy(self):
        """Apply this decorator to indicate that the app is busy during execution."""
        self.set_trait('busy', True)
        try:
            yield
        finally:
            self.set_trait('busy', False)

    def _watch_repository(self):
        """Watch the app repository for file system events.

        The app state is refreshed automatically for all events.
        """
        if self._observer is None and os.path.isdir(self.path):
            event_handler = self.AppPathFileSystemEventHandler(self)

            self._observer = Observer()
            self._observer.schedule(event_handler, self.path, recursive=True)
            self._observer.start()

        if self._check_install_status_changed_thread is None:

            def check_install_status_changed():
                installed = self.is_installed()
                while not self._check_install_status_changed_thread.stop_flag:
                    if installed != self.is_installed():
                        installed = self.is_installed()
                        self.refresh()
                    sleep(1)

            self._check_install_status_changed_thread = Thread(
                target=check_install_status_changed)
            self._check_install_status_changed_thread.stop_flag = False
            self._check_install_status_changed_thread.start()

    def _stop_watch_repository(self, timeout=None):
        """Stop watching the app repository for file system events."""
        if self._observer is not None:
            self._observer.stop()
            self._observer.join(timeout=timeout)
            if not self._observer.is_alive():
                self._observer = None

        if self._check_install_status_changed_thread is not None:
            self._check_install_status_changed_thread.stop_flag = True
            self._check_install_status_changed_thread.join(timeout=timeout)
            if not self._check_install_status_changed_thread.is_alive():
                self._check_install_status_changed_thread = None

    def __del__(self):  # pylint: disable=missing-docstring
        self._stop_watch_repository(1)

    def in_category(self, category):
        # TODO: test what happens if the category is not defined.
        return category in self.categories

    def is_installed(self):
        """The app is installed if the corresponding folder is present."""
        return os.path.isdir(self.path)

    def _has_git_repo(self):
        """Check if the app has a .git folder in it."""
        try:
            Repo(self.path)
            return True
        except NotGitRepository:
            return False

    def install_app(self, version=None):
        """Installing the app."""
        assert self._git_url is not None
        if version is None:
            version = 'git:refs/heads/' + AIIDALAB_DEFAULT_GIT_BRANCH

        with self._show_busy():
            assert version.startswith('git:refs/heads/')
            branch = re.sub(r'git:refs\/heads\/', '', version)

            if not os.path.isdir(self.path):  # clone first
                check_output([
                    'git', 'clone', '--branch', branch, self._git_url,
                    self.path
                ],
                             cwd=os.path.dirname(self.path),
                             stderr=STDOUT)

            check_output(['git', 'checkout', '-f', branch],
                         cwd=self.path,
                         stderr=STDOUT)
            self.refresh()
            self._watch_repository()
            return branch

    def update_app(self, _=None):
        """Perform app update."""
        assert self._git_url is not None
        with self._show_busy():
            fetch(repo=self._repo, remote_location=self._git_url)
            tracked_branch = self._repo.get_tracked_branch()
            check_output(['git', 'reset', '--hard', tracked_branch],
                         cwd=self.path,
                         stderr=STDOUT)
            self.refresh_async()

    def uninstall_app(self, _=None):
        """Perfrom app uninstall."""
        # Perform uninstall process.
        with self._show_busy():
            self._stop_watch_repository()
            try:
                shutil.rmtree(self.path)
            except FileNotFoundError:
                raise RuntimeError("App was already uninstalled!")
            self.refresh()
            self._watch_repository()

    def check_for_updates(self):
        """Check whether there is an update available for the installed release line."""
        try:
            assert self._git_url is not None
            branch_ref = 'refs/heads/' + self._repo.branch().decode()
            assert self._repo.get_tracked_branch() is not None
            remote_update_available = self._git_remote_refs.get(
                branch_ref) != self._repo.head().decode()
            self.set_trait(
                'updates_available', remote_update_available
                or self._repo.update_available())
        except (AssertionError, RuntimeError):
            self.set_trait('updates_available', None)

    def _available_release_lines(self):
        """"Return all available release lines (local and remote)."""
        for branch in self._repo.list_branches():
            yield 'git:refs/heads/' + branch.decode()
        for ref in self._git_remote_refs:
            if ref.startswith('refs/heads/'):
                yield 'git:' + ref

    @throttled(calls_per_second=1)
    def refresh(self):
        """Refresh app state."""
        with self._show_busy():
            with self.hold_trait_notifications():
                if self.is_installed() and self._has_git_repo():
                    self.available_release_lines = set(
                        self._available_release_lines())
                    try:
                        self.installed_release_line = 'git:refs/heads/' + self._repo.branch(
                        ).decode()
                    except RuntimeError:
                        self.installed_release_line = None
                    self.installed_version = self._repo.head().decode()
                    self.check_for_updates()
                    self.set_trait('modified', self._repo.dirty())
                else:
                    self.available_release_lines = set()
                    self.installed_release_line = None
                    self.installed_version = None
                    self.set_trait('updates_available', None)
                    self.set_trait('modified', None)

    def refresh_async(self):
        """Asynchronized (non-blocking) refresh of the app state."""
        refresh_thread = Thread(target=self.refresh)
        refresh_thread.start()

    @property
    def metadata(self):
        """Return metadata dictionary. Give the priority to the local copy (better for the developers)."""
        if self.is_installed():
            try:
                with open(os.path.join(self.path,
                                       'metadata.json')) as json_file:
                    return json.load(json_file)
            except IOError:
                return dict()
        elif self._meta_info is not None:
            return dict(self._meta_info)
        elif self._meta_url is None:
            raise RuntimeError(
                f"Requested app '{self.name}' is not installed and is also not registered on the app registry."
            )
        else:
            return requests.get(self._meta_url).json()

    def _get_from_metadata(self, what):
        """Get information from metadata."""

        try:
            return "{}".format(self.metadata[what])
        except KeyError:
            if not os.path.isfile(os.path.join(self.path, 'metadata.json')):
                return '({}) metadata.json file is not present'.format(what)
            return 'the field "{}" is not present in metadata.json file'.format(
                what)

    @property
    def authors(self):
        return self._get_from_metadata('authors')

    @property
    def description(self):
        return self._get_from_metadata('description')

    @property
    def title(self):
        return self._get_from_metadata('title')

    @property
    def url(self):
        """Provide explicit link to Git repository."""
        return self._git_url

    @property
    def more(self):
        return """<a href=./single_app.ipynb?app={}>Manage App</a>""".format(
            self.name)

    @property
    def logo(self):
        """Return logo object. Give the priority to the local version"""

        # For some reason standard ipw.Image() app does not work properly.
        res = ipw.HTML('<img src="./aiidalab_logo_v4.svg">',
                       layout={
                           'width': '100px',
                           'height': '100px'
                       })

        # Checking whether the 'logo' key is present in metadata dictionary.
        if 'logo' not in self.metadata:
            res.value = '<img src="./aiidalab_logo_v4.svg">'

        # If 'logo' key is present and the app is installed.
        elif self.is_installed():
            res.value = '<img src="{}">'.format(
                os.path.join('..', self.name, self.metadata['logo']))

        # If not installed, getting file from the remote git repository.
        else:
            # Remove .git if present.
            html_link = os.path.splitext(self._git_url)[0]

            # We expect it to always be a git repository
            html_link += '/master/' + self.metadata['logo']
            if 'github.com' in html_link:
                html_link = html_link.replace('github.com',
                                              'raw.githubusercontent.com')
                if html_link.endswith('.svg'):
                    html_link += '?sanitize=true'
            res.value = '<img src="{}">'.format(html_link)

        return res

    @property
    def _repo(self):
        """Returns Git repository."""
        if not self.is_installed():
            raise AppNotInstalledException("The app is not installed")
        return Repo(self.path)

    def render_app_manager_widget(self):
        """Display widget to manage the app."""
        try:
            return AppManagerWidget(self, with_version_selector=True)
        except Exception as error:  # pylint: disable=broad-except
            return ipw.HTML(
                '<div style="font-size: 30px; text-align:center;">'
                f'Unable to show app widget due to error: {error}'
                '</div>',
                layout={'width': '600px'})
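
The `_show_busy` context manager is the part worth reusing: it guarantees the read-only `busy` trait is reset even if the wrapped block raises. A stripped-down sketch of the same pattern, assuming only traitlets:

from contextlib import contextmanager

import traitlets

class Worker(traitlets.HasTraits):
    busy = traitlets.Bool(read_only=True)

    @contextmanager
    def show_busy(self):
        # set_trait bypasses the read-only guard from inside the class
        self.set_trait('busy', True)
        try:
            yield
        finally:
            self.set_trait('busy', False)

worker = Worker()
with worker.show_busy():
    print(worker.busy)  # True
print(worker.busy)      # False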
Example #4
def Set(klass: T, **kw) -> typing.Set[T]:
    return traitlets.Set(klass, **kw)
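
This wrapper presumably sits alongside `T = typing.TypeVar("T")` and imports of `typing` and `traitlets`; it only changes the static return type, since at runtime a plain `traitlets.Set` trait is returned. A hypothetical usage sketch:

import traitlets

class Tags(traitlets.HasTraits):
    # type-checks as typing.Set[...] while behaving as a normal trait
    names = Set(traitlets.Unicode())

tags = Tags(names={"alpha", "beta"})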
Example #5
class CenterLimitedPSSMModel(BaseEstimator, RegressorMixin,
                             traitlets.HasTraits):

    alpha_center = traitlets.Float(default_value=0, allow_none=False, min=0)
    alpha_flank = traitlets.Float(default_value=0, allow_none=False, min=0)

    error_upper_lim = traitlets.Float(default_value=10, allow_none=False)
    error_lower_lim = traitlets.Float(default_value=-10, allow_none=False)

    max_data_upweight = traitlets.Float(default_value=1.0,
                                        allow_none=False,
                                        min=1.0)

    init_EC50_max = traitlets.Float(default_value=150.0, allow_none=False)
    init_k_max = traitlets.Float(default_value=80.0, allow_none=False)
    init_c0 = traitlets.Float(default_value=4.0, allow_none=False)

    flanking_window = traitlets.Integer(default_value=4,
                                        allow_none=False,
                                        min=0)

    init_aas = traitlets.Set(
        trait=traitlets.Enum(IUPAC.IUPACProtein.letters),
        default_value=None,
        allow_none=True,
    )

    def _get_param_names(self):
        return self.trait_names()

    @classmethod
    def from_state(cls, params):
        instance = cls.__new__(cls)
        instance.__setstate__(params)
        return instance

    def get_state(self):
        return self.__getstate__()

    def __setstate__(self, state):
        traitlets.HasTraits.__setstate__(self, copy.deepcopy(state))

        if state.get("fit_coeffs_", None) is not None:
            self.setup()
            self.fit_coeffs_ = state["fit_coeffs_"]

    def __getstate__(self):
        state = traitlets.HasTraits.__getstate__(self)
        statekeys = (
            "_trait_values",
            "_fit_coeffs",
            "_trait_validators",
            "_trait_notifiers",
            "_cross_validation_lock",
        )

        for k in set(state) - set(statekeys):
            del state[k]

        return state

    def setup(self):
        """Validate parameters and setup estimator."""

        self.dat = {
            "seq": T.lmatrix("seq"),
            "targ": T.dvector("targ"),
            "data_weights": T.dvector("data_weights"),
        }

        self.dat["seq"].tag.test_value = numpy.random.randint(21, size=(5, 30))
        self.dat["targ"].tag.test_value = numpy.random.random(5)
        self.dat["data_weights"].tag.test_value = numpy.random.random(5)

        self.vs = {
            "outer_PSSM": T.dmatrix("outer_PSSM"),
            "P1_PSSM": T.dvector("P1_PSSM"),
            "c0": T.scalar("c0"),
            "EC50_max": T.scalar("EC50_max"),
            "k_max": T.scalar("k_max"),
        }

        self.v_dtypes = {
            "outer_PSSM": (float, (2 * self.flanking_window + 1, 21)),
            "P1_PSSM": (float, 21),
            "c0": (float, ()),
            "EC50_max": (float, ()),
            "k_max": (float, ()),
        }

        for v in self.vs:
            self.vs[v].tag.test_value = numpy.random.random(
                self.v_dtypes[v][1])

        self.coeff_dtype = numpy.dtype([
            (n, ) + t for n, t in list(self.v_dtypes.items())
        ])

        self.targets = {}
        self._build_model()

        self.functions = {
            k:
            theano.function(list(self.dat.values()) + list(self.vs.values()),
                            outputs=f,
                            name=k,
                            on_unused_input="ignore")
            for k, f in list(self.targets.items())
        }
        self.predict_fn = theano.function([self.dat["seq"]] +
                                          list(self.vs.values()),
                                          outputs=self.targets["score"],
                                          name="predict")

    def encode_seqs(self, seqs):
        if isinstance(seqs, list) or seqs.dtype == 'O':
            assert len(set(map(len,
                               seqs))) == 1, "Sequences of unequal length."
            seqs = numpy.array([numpy.frombuffer(s.encode(), dtype="S1") for s in seqs])

        if seqs.dtype == numpy.dtype("S1"):
            IUPAC_protein_letters_plus_Z = numpy.frombuffer(
                (IUPAC.IUPACProtein.letters + 'Z').encode(), dtype="S1")
            for aa in numpy.unique(seqs):
                assert aa in IUPAC_protein_letters_plus_Z, \
                    "Non-IUPAC/non-Z letter code in input sequences: %r" % aa
            seqs = numpy.searchsorted(IUPAC_protein_letters_plus_Z, seqs)

        assert numpy.issubdtype(seqs.dtype, numpy.integer), "Invalid sequence dtype: %r" % seqs.dtype
        assert seqs.min() >= 0 and seqs.max(
        ) < 21, "Invalid encoded sequence range: (%s, %s)" % (seqs.min(),
                                                              seqs.max())

        return seqs

    def _build_model(self):

        outer_PSSM = self.vs["outer_PSSM"]
        P1_PSSM = self.vs["P1_PSSM"]

        window_size = self.flanking_window * 2 + 1
        window_pos = list(range(self.flanking_window)) + list(
            range(self.flanking_window + 1, window_size))

        seq = self.dat["seq"]

        point_scores = [
            outer_PSSM[i, seq[:, i:-(window_size - 1 - i) if i < window_size -
                              1 else None]] for i in window_pos
        ]

        point_scores.append(
            P1_PSSM[seq[:, self.flanking_window:-(self.flanking_window)]])

        k_max = self.vs["k_max"]

        pssm_score = sum(point_scores)

        c0 = self.vs["c0"]

        ind_score = k_max / (1.0 + T.exp(c0 - pssm_score))  # Eq. 19

        EC50_max = self.vs["EC50_max"]

        score = T.log(EC50_max / (ind_score.sum(axis=-1) + 1)) / T.log(
            3)  # Eq. 18, converted to a log EC50 instead of a raw EC50

        targ = self.dat["targ"]
        data_weights = self.dat["data_weights"]

        error_upper_lim = self.error_upper_lim
        error_lower_lim = self.error_lower_lim

        error = T.clip(targ - score, error_lower_lim, error_upper_lim)

        mse = T.mean(T.square(error) + (0.25 * T.abs_(targ - score)))  # Eq. 20
        weighted_mse = T.sum(
            (T.square(error) + (0.25 * T.abs_(targ - score))) * data_weights
        ) / T.sum(
            data_weights)  #Eq. 20, including weights on the data (not used)
        regularization = (self.alpha_flank *
                          T.square(outer_PSSM[:, 0:-1])).sum() + (
                              self.alpha_center * T.abs_(P1_PSSM[0:-1])).sum()

        loss = weighted_mse + regularization
        loss_jacobian = dict(
            list(zip(self.vs, T.jacobian(loss, list(self.vs.values())))))

        self.targets = {
            "score": score,
            "pssm_score": pssm_score,
            "ind_score": ind_score,
            "mse": mse,
            "weighted_mse": weighted_mse,
            "regularization": regularization,
            "loss": loss,
            "loss_jacobian": loss_jacobian,
        }

    def opt_cycle(self, target_vars, **opt_options):
        packed_dtype = numpy.dtype([(n, ) + self.v_dtypes[n]
                                    for n in target_vars])

        def eval_mse(packed_vars):
            vpack = packed_vars.copy().view(packed_dtype).reshape(())
            vs = self.fit_coeffs_.copy()
            for n in vpack.dtype.names:
                vs[n] = vpack[n]

            return self.eval_f("mse", coeffs=vs)

        def eval_loss(packed_vars):
            vpack = packed_vars.copy().view(packed_dtype).reshape(())
            vs = self.fit_coeffs_.copy()
            for n in vpack.dtype.names:
                vs[n] = vpack[n]

            return self.eval_f("loss", coeffs=vs)

        def eval_loss_jac(packed_vars):
            vpack = packed_vars.copy().view(packed_dtype).reshape(())
            vs = self.fit_coeffs_.copy()
            for n in vpack.dtype.names:
                vs[n] = vpack[n]

            jacobian = self.eval_f("loss_jacobian", coeffs=vs)
            jpack = numpy.zeros((), dtype=packed_dtype)
            for n in jpack.dtype.names:
                jpack[n] = jacobian[n]
            return jpack.reshape(1).view(float)

        def iter_callback(packed_vars):
            if ic[0] % 10 == 0:
                logging.debug("iter: %03i mse: %.3f loss: %.3f", ic[0],
                              eval_mse(packed_vars), eval_loss(packed_vars))
            ic[0] += 1

        ic = [0]
        start_coeffs = numpy.empty((), packed_dtype)
        for n in start_coeffs.dtype.names:
            start_coeffs[n] = self.fit_coeffs_[n]

        opt_result = scipy.optimize.minimize(
            fun=eval_loss,
            jac=eval_loss_jac,
            x0=start_coeffs.reshape(1).view(float),
            callback=iter_callback,
            **opt_options)

        opt_result.packed_x = opt_result.x.copy().view(packed_dtype).reshape(
            ())
        for n in opt_result.packed_x.dtype.names:
            self.fit_coeffs_[n] = opt_result.packed_x[n]

        logging.info("last_opt iter: %03i fun: %.3f mse: %.3f", ic[0],
                     opt_result.fun, eval_mse(opt_result.packed_x))

        return opt_result

    def fit(self, X, y):
        from scipy.stats import gaussian_kde
        self.setup()

        self.fit_X_ = self.encode_seqs(X)
        self.fit_y_ = y

        data_density = gaussian_kde(y)
        data_y_range = numpy.linspace(min(y), max(y), 1000)
        max_data_density = max(data_density(data_y_range))

        self.data_weights = numpy.clip(max_data_density / data_density(y), 1.0,
                                       self.max_data_upweight)

        self.fit_coeffs_ = numpy.zeros((), self.coeff_dtype)

        self.fit_coeffs_["outer_PSSM"][self.flanking_window] = 1
        if self.init_aas:
            self.fit_coeffs_["P1_PSSM"] = [
                1 if aa in self.init_aas else 0
                for aa in (IUPAC.IUPACProtein.letters + 'Z')
            ]
        else:
            self.fit_coeffs_["P1_PSSM"] = 0

        self.fit_coeffs_["c0"] = self.init_c0
        self.fit_coeffs_["k_max"] = self.init_k_max

        self.fit_coeffs_["EC50_max"] = self.init_EC50_max

        opt_cycles = [
            ("P1_PSSM", ),
            ("c0", "k_max", "EC50_max"),
            ("P1_PSSM", ),
            ("outer_PSSM", "P1_PSSM"),
            ("c0", "k_max", "EC50_max"),
            ("outer_PSSM", "P1_PSSM"),
            ("c0", "EC50_max", "k_max"),
            ("outer_PSSM", "P1_PSSM", "k_max"),
            ("c0", "EC50_max", "outer_PSSM", "P1_PSSM", "k_max"),
        ]

        for i, vset in enumerate(opt_cycles):
            logging.info("opt_cycle: %i vars: %s", i, vset)

            self._last_opt_result = self.opt_cycle(
                vset,
                tol=1e-3 if i < len(opt_cycles) - 1 else 2e-4  #2e-5
            )

    def predict(self, X):
        return self.predict_fn(
            self.encode_seqs(X),
            **{n: self.fit_coeffs_[n]
               for n in self.fit_coeffs_.dtype.names})

    def eval_f(self, fname, X=None, y=None, data_weights=None, coeffs=None):
        if coeffs is None:
            coeffs = self.fit_coeffs_
        if X is None:
            X = self.fit_X_
        else:
            X = self.encode_seqs(X)

        if y is None:
            y = self.fit_y_
        if data_weights is None:
            data_weights = self.data_weights

        return self.functions[fname](
            seq=X,
            targ=y,
            data_weights=data_weights,
            **{n: coeffs[n]
               for n in coeffs.dtype.names})
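
Most of the heavy lifting above is Theano; the hyperparameter handling is plain traitlets validation (`min=`, `allow_none=`, element-wise `Enum` checks). A minimal sketch of that mechanism on its own:

import traitlets

class Params(traitlets.HasTraits):
    alpha_center = traitlets.Float(default_value=0, allow_none=False, min=0)
    init_aas = traitlets.Set(trait=traitlets.Enum(list("ACDEFGHIKLMNPQRSTVWY")),
                             allow_none=True)

params = Params()
params.alpha_center = 0.5       # accepted
try:
    params.alpha_center = -1.0  # rejected: below min=0
except traitlets.TraitError as err:
    print(err)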
Example #6
class Coordinates1d(BaseCoordinates):
    """
    Base class for 1-dimensional coordinates.

    Coordinates1d objects contain values and metadata for a single dimension of coordinates. :class:`podpac.Coordinates` and
    :class:`StackedCoordinates` use Coordinates1d objects.

    Parameters
    ----------
    name : str
        Dimension name, one of 'lat', 'lon', 'time', or 'alt'.
    coordinates : array, read-only
        Full array of coordinate values.

    See Also
    --------
    :class:`ArrayCoordinates1d`, :class:`UniformCoordinates1d`
    """

    name = Dimension(allow_none=True)
    _properties = tl.Set()

    @tl.observe("name")
    def _set_property(self, d):
        if d["name"] is not None:
            self._properties.add(d["name"])

    def _set_name(self, value):
        # set name if it is not set already, otherwise check that it matches
        if "name" not in self._properties:
            self.name = value
        elif self.name != value:
            raise ValueError("Dimension mismatch, %s != %s" %
                             (value, self.name))

    # ------------------------------------------------------------------------------------------------------------------
    # standard methods
    # ------------------------------------------------------------------------------------------------------------------

    def __repr__(self):
        if self.name is None:
            name = "%s" % (self.__class__.__name__, )
        else:
            name = "%s(%s)" % (self.__class__.__name__, self.name)

        if self.ndim == 1:
            desc = "Bounds[%s, %s], N[%d]" % (self.bounds[0], self.bounds[1],
                                              self.size)
        else:
            desc = "Bounds[%s, %s], N[%s], Shape%s" % (
                self.bounds[0], self.bounds[1], self.size, self.shape)

        return "%s: %s" % (name, desc)

    def _eq_base(self, other):
        """ used by child __eq__ methods for common checks """
        if not isinstance(other, Coordinates1d):
            return False

        # defined coordinate properties should match
        for name in self._properties.union(other._properties):
            if getattr(self, name) != getattr(other, name):
                return False

        # shortcuts (not strictly necessary)
        for name in ["shape", "is_monotonic", "is_descending", "is_uniform"]:
            if getattr(self, name) != getattr(other, name):
                return False

        return True

    def __len__(self):
        return self.shape[0]

    def __contains__(self, item):
        try:
            item = make_coord_value(item)
        except Exception:
            return False

        if type(item) != self.dtype:
            return False

        return item in self.coordinates

    # ------------------------------------------------------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------------------------------------------------------

    @property
    def dims(self):
        if self.name is None:
            raise TypeError(
                "cannot access dims property of unnamed Coordinates1d")
        return (self.name, )

    @property
    def xcoords(self):
        """:dict: xarray coords"""

        if self.name is None:
            raise ValueError("Cannot get xcoords for unnamed Coordinates1d")

        return {self.name: (self.xdims, self.coordinates)}

    @property
    def dtype(self):
        """:type: Coordinates dtype.

        ``float`` for numerical coordinates and numpy ``datetime64`` for datetime coordinates.
        """

        raise NotImplementedError

    @property
    def deltatype(self):
        if self.dtype is np.datetime64:
            return np.timedelta64
        else:
            return self.dtype

    @property
    def is_monotonic(self):
        raise NotImplementedError

    @property
    def is_descending(self):
        raise NotImplementedError

    @property
    def is_uniform(self):
        raise NotImplementedError

    @property
    def start(self):
        raise NotImplementedError

    @property
    def stop(self):
        raise NotImplementedError

    @property
    def step(self):
        raise NotImplementedError

    @property
    def bounds(self):
        """ Low and high coordinate bounds. """

        raise NotImplementedError

    @property
    def properties(self):
        """:dict: Dictionary of the coordinate properties. """

        return {key: getattr(self, key) for key in self._properties}

    @property
    def definition(self):
        """:dict: Serializable 1d coordinates definition."""
        return self._get_definition(full=False)

    @property
    def full_definition(self):
        """:dict: Serializable 1d coordinates definition, containing all properties. For internal use."""
        return self._get_definition(full=True)

    def _get_definition(self, full=True):
        raise NotImplementedError

    @property
    def _full_properties(self):
        return {"name": self.name}

    # ------------------------------------------------------------------------------------------------------------------
    # Methods
    # ------------------------------------------------------------------------------------------------------------------

    def copy(self):
        """
        Make a deep copy of the 1d Coordinates.

        Returns
        -------
        :class:`Coordinates1d`
            Copy of the coordinates.
        """

        raise NotImplementedError

    def simplify(self):
        """Get the simplified/optimized representation of these coordinates.

        Returns
        -------
        simplified : Coordinates1d
            simplified version of the coordinates
        """

        raise NotImplementedError

    def get_area_bounds(self, boundary):
        """
        Get low and high coordinate area bounds.

        Arguments
        ---------
        boundary : float, timedelta, array, None
            Boundary offsets in this dimension.

            * For a centered uniform boundary (same for every coordinate), use a single positive float or timedelta
                offset. This represents the "total segment length" / 2.
            * For a uniform boundary (segment or polygon same for every coordinate), use an array of float or
                timedelta offsets
            * For a fully specified boundary, use an array of boundary arrays (2-D array, N_coords x boundary spec),
                 one per coordinate. The boundary_spec can be a single number, two numbers, or an array of numbers.
            * For point coordinates, use None.

        Returns
        -------
        low: float, np.datetime64
            low area bound
        high: float, np.datetime64
            high area bound
        """

        # point coordinates
        if boundary is None:
            return self.bounds

        # empty coordinates
        if self.size == 0:
            return self.bounds

        if np.array(boundary).ndim == 0:
            # shortcut for uniform centered boundary
            boundary = make_coord_delta(boundary)
            lo_offset = -boundary
            hi_offset = boundary
        elif np.array(boundary).ndim == 1:
            # uniform boundary polygon
            boundary = make_coord_delta_array(boundary)
            lo_offset = min(boundary)
            hi_offset = max(boundary)
        else:
            L, H = self.argbounds
            lo_offset = min(make_coord_delta_array(boundary[L]))
            hi_offset = max(make_coord_delta_array(boundary[H]))

        lo, hi = self.bounds
        lo = add_coord(lo, lo_offset)
        hi = add_coord(hi, hi_offset)

        return lo, hi

    def _select_empty(self, return_index):
        I = []
        if return_index:
            return self[I], I
        else:
            return self[I]

    def _select_full(self, return_index):
        I = slice(None)
        if return_index:
            return self[I], I
        else:
            return self[I]

    def select(self, bounds, return_index=False, outer=False):
        """
        Get the coordinate values that are within the given bounds.

        The default selection returns coordinates that are within the bounds::

            In [1]: c = ArrayCoordinates1d([0, 1, 2, 3], name='lat')

            In [2]: c.select([1.5, 2.5]).coordinates
            Out[2]: array([2.])

        The *outer* selection returns the minimal set of coordinates that contain the bounds::

            In [3]: c.select([1.5, 2.5], outer=True).coordinates
            Out[3]: array([1., 2., 3.])

        The *outer* selection also returns a boundary coordinate if a bound is outside these coordinates' bounds but
        *inside* its area bounds::

            In [4]: c.select([3.25, 3.35], outer=True).coordinates
            Out[4]: array([3.0], dtype=float64)

            In [5]: c.select([10.0, 11.0], outer=True).coordinates
            Out[5]: array([], dtype=float64)

        Parameters
        ----------
        bounds : (low, high) or dict
            Selection bounds. If a dictionary of dim -> (low, high) bounds is supplied, the bounds matching these
            coordinates will be selected if available, otherwise the full coordinates will be returned.
        outer : bool, optional
            If True, do an *outer* selection. Default False.
        return_index : bool, optional
            If True, return index for the selection in addition to coordinates. Default False.

        Returns
        -------
        selection : :class:`Coordinates1d`
            Coordinates1d object with coordinates within the bounds.
        I : slice, boolean array
            index or slice for the selected coordinates (only if return_index=True)
        """

        # empty case
        if self.dtype is None:
            return self._select_empty(return_index)

        if isinstance(bounds, dict):
            bounds = bounds.get(self.name)
            if bounds is None:
                return self._select_full(return_index)

        bounds = make_coord_value(bounds[0]), make_coord_value(bounds[1])

        # check type
        if not isinstance(bounds[0], self.dtype):
            raise TypeError(
                "Input bounds do not match the coordinates dtype (%s != %s)" %
                (type(bounds[0]), self.dtype))
        if not isinstance(bounds[1], self.dtype):
            raise TypeError(
                "Input bounds do not match the coordinates dtype (%s != %s)" %
                (type(bounds[1]), self.dtype))

        my_bounds = self.bounds

        # If the bounds are of instance datetime64, then the comparison should happen at the lowest precision
        if self.dtype == np.datetime64:
            my_bounds, bounds = lower_precision_time_bounds(
                my_bounds, bounds, outer)

        # full
        if my_bounds[0] >= bounds[0] and my_bounds[1] <= bounds[1]:
            return self._select_full(return_index)

        # none
        if my_bounds[0] > bounds[1] or my_bounds[1] < bounds[0]:
            return self._select_empty(return_index)

        # partial, implemented in child classes
        return self._select(bounds, return_index, outer)

    def _select(self, bounds, return_index, outer):
        raise NotImplementedError

    def _transform(self, transformer):
        if self.name != "alt":
            # this assumes that the transformer does not have a spatial transform
            return self.copy()

        # transform "alt" coordinates
        from podpac.core.coordinates.array_coordinates1d import ArrayCoordinates1d

        _, _, tcoordinates = transformer.transform(np.zeros(self.shape),
                                                   np.zeros(self.shape),
                                                   self.coordinates)
        return ArrayCoordinates1d(tcoordinates, **self.properties)

    def issubset(self, other):
        """Report whether other coordinates contains these coordinates.

        Arguments
        ---------
        other : Coordinates, Coordinates1d
            Other coordinates to check

        Returns
        -------
        issubset : bool
            True if these coordinates are a subset of the other coordinates.
        """

        from podpac.core.coordinates import Coordinates

        if isinstance(other, Coordinates):
            if self.name not in other.dims:
                return False
            other = other[self.name]

        # short-cuts that don't require checking coordinates
        if self.size == 0:
            return True

        if other.size == 0:
            return False

        if self.dtype != other.dtype:
            return False

        if self.bounds[0] < other.bounds[0] or self.bounds[1] > other.bounds[1]:
            return False

        # check actual coordinates using built-in set method issubset
        # for datetimes, convert to the higher resolution
        my_coordinates = self.coordinates.ravel()
        other_coordinates = other.coordinates.ravel()

        if self.dtype == np.datetime64:
            if my_coordinates[0].dtype < other_coordinates[0].dtype:
                my_coordinates = my_coordinates.astype(other_coordinates.dtype)
            elif other_coordinates[0].dtype < my_coordinates[0].dtype:
                other_coordinates = other_coordinates.astype(
                    my_coordinates.dtype)

        return set(my_coordinates).issubset(other_coordinates)
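
The `_properties` bookkeeping, an observer that records which traits were explicitly assigned, is independent of the coordinate logic. A minimal sketch, assuming plain traitlets:

import traitlets as tl

class Record(tl.HasTraits):
    name = tl.Unicode(allow_none=True)
    _properties = tl.Set()

    @tl.observe("name")
    def _set_property(self, change):
        # change["name"] is the name of the trait that was assigned
        self._properties.add(change["name"])

record = Record()
print(record._properties)  # set()
record.name = "lat"
print(record._properties)  # {'name'}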
Example #7
class IpyPubMain(Configurable):

    conversion = T.Unicode(
        "latex_ipypublish_main",
        help="key or path to conversion configuration").tag(config=True)

    plugin_folder_paths = T.Set(
        T.Unicode(),
        default_value=(),
        help="a list of folders containing conversion configurations",
    ).tag(config=True)

    @validate("plugin_folder_paths")
    def _validate_plugin_folder_paths(self, proposal):
        folder_paths = proposal["value"]
        for path in folder_paths:
            if not os.path.exists(path):
                raise TraitError(
                    "the configuration folder path does not exist: "
                    "{}".format(path))
        return proposal["value"]

    outpath = T.Union(
        [T.Unicode(), T.Instance(pathlib.Path)],
        allow_none=True,
        default_value=None,
        help="path to output converted files",
    ).tag(config=True)

    folder_suffix = T.Unicode(
        "_files",
        help=("suffix for the folder name where content will be dumped "
              "(e.g. internal images). "
              "It will be a sanitized version of the input filename, "
              "followed by the suffix"),
    ).tag(config=True)

    ignore_prefix = T.Unicode(
        "_", help=("prefixes to ignore, "
                   "when finding notebooks to merge")).tag(config=True)

    meta_path_placeholder = T.Unicode(
        "${meta_path}",
        help=("all string values in the export configuration containing "
              "this placeholder will be be replaced with the path to the "
              "notebook from which the metadata was obtained"),
    ).tag(config=True)

    files_folder_placeholder = T.Unicode(
        "${files_path}",
        help=(
            "all string values in the export configuration containing "
            "this placeholder will be be replaced with the path "
            "(relative to outpath) to the folder where files will be dumped"),
    ).tag(config=True)

    validate_nb_metadata = T.Bool(
        True,
        help=("before running the exporter, validate that "
              "the notebook level metadata is valid again the schema"),
    ).tag(config=True)

    pre_conversion_funcs = T.Dict(help=(
        "a mapping of file extensions to functions that can convert "
        "that file type: Instance(nbformat.NotebookNode) = func(pathstr)")).tag(
            config=True)

    @default("pre_conversion_funcs")
    def _default_pre_conversion_funcs(self):
        try:
            import jupytext  # noqa: F401
        except ImportError:
            return {}

        try:
            from jupytext import read
        except ImportError:
            # this is deprecated in newer versions
            from jupytext import readf as read  # noqa: F401

        return {".Rmd": read, ".md": read}

    @validate("pre_conversion_funcs")
    def _validate_pre_conversion_funcs(self, proposal):
        for ext, func in proposal["value"].items():
            if not ext.startswith("."):
                raise TraitError("the extension key should start with a '.': "
                                 "{}".format(ext))
            try:
                func("string")
                # TODO should do this safely with inspect,
                # but no obvious solution
                # to check if it only requires one string argument
            except TypeError:
                raise TraitError("the function for {} can not be "
                                 "called with a single string arg: "
                                 "{}".format(ext, func))
            except Exception:
                pass
        return proposal["value"]

    log_to_stdout = T.Bool(
        True, help="whether to log to sys.stdout").tag(config=True)

    log_level_stdout = T.Enum(
        [
            "debug", "info", "warning", "error", "DEBUG", "INFO", "WARNING",
            "ERROR"
        ],
        default_value="INFO",
        help="the logging level to output to stdout",
    ).tag(config=True)

    log_stdout_formatstr = T.Unicode("%(levelname)s:%(name)s:%(message)s").tag(
        config=True)

    log_to_file = T.Bool(False, help="whether to log to file").tag(config=True)

    log_level_file = T.Enum(
        [
            "debug", "info", "warning", "error", "DEBUG", "INFO", "WARNING",
            "ERROR"
        ],
        default_value="INFO",
        help="the logging level to output to file",
    ).tag(config=True)

    log_file_path = T.Unicode(
        None,
        allow_none=True,
        help="if None, will output to {outdir}/{ipynb_name}.nbpub.log",
    ).tag(config=True)

    log_file_formatstr = T.Unicode("%(levelname)s:%(name)s:%(message)s").tag(
        config=True)

    default_ppconfig_kwargs = T.Dict(
        trait=T.Bool(),
        default_value=(
            ("pdf_in_temp", False),
            ("pdf_debug", False),
            ("launch_browser", False),
        ),
        help=("convenience arguments for constructing the post-processors "
              "default configuration"),
    ).tag(config=True)

    default_pporder_kwargs = T.Dict(
        trait=T.Bool(),
        default_value=(
            ("dry_run", False),
            ("clear_existing", False),
            ("dump_files", False),
            ("create_pdf", False),
            ("serve_html", False),
            ("slides", False),
        ),
        help=("convenience arguments for constructing the post-processors "
              "default list"),
    ).tag(config=True)

    # TODO validate that default_ppconfig/pporder_kwargs can be parsed to funcs

    default_exporter_config = T.Dict(
        help="default configuration for exporters").tag(config=True)

    @default("default_exporter_config")
    def _default_exporter_config(self):
        temp = "${files_path}/{unique_key}_{cell_index}_{index}{extension}"
        return {
            "ExtractOutputPreprocessor": {
                "output_filename_template": temp
            }
        }

    def _create_default_ppconfig(self,
                                 pdf_in_temp=False,
                                 pdf_debug=False,
                                 launch_browser=False):
        """create a default config for postprocessors"""
        return Config({
            "PDFExport": {
                "files_folder": "${files_path}",
                "convert_in_temp": pdf_in_temp,
                "debug_mode": pdf_debug,
                "open_in_browser": launch_browser,
                "skip_mime": False,
            },
            "RunSphinx": {
                "open_in_browser": launch_browser
            },
            "RemoveFolder": {
                "files_folder": "${files_path}"
            },
            "CopyResourcePaths": {
                "files_folder": "${files_path}"
            },
            "ConvertBibGloss": {
                "files_folder": "${files_path}"
            },
        })

    def _create_default_pporder(
        self,
        dry_run=False,
        clear_existing=False,
        dump_files=False,
        create_pdf=False,
        serve_html=False,
        slides=False,
    ):
        """create a default list of postprocessors to run"""
        default_pprocs = [
            "remove-blank-lines",
            "remove-trailing-space",
            "filter-output-files",
        ]
        if slides:
            default_pprocs.append("fix-slide-refs")
        if not dry_run:
            if clear_existing:
                default_pprocs.append("remove-folder")
            default_pprocs.append("write-text-file")
            if dump_files or create_pdf or serve_html:
                default_pprocs.extend([
                    "write-resource-files", "copy-resource-paths",
                    "convert-bibgloss"
                ])
            if create_pdf:
                default_pprocs.append("pdf-export")
            elif serve_html:
                default_pprocs.append("reveal-server")

        return default_pprocs

    @property
    def logger(self):
        return logging.getLogger("ipypublish")

    @contextmanager
    def _log_handlers(self, ipynb_name, outdir):

        root = logging.getLogger()
        root_level = root.level
        log_handlers = []

        try:
            root.setLevel(logging.DEBUG)

            if self.log_to_stdout:
                # setup logging to terminal
                slogger = logging.StreamHandler(sys.stdout)
                slogger.setLevel(
                    getattr(logging, self.log_level_stdout.upper()))
                formatter = logging.Formatter(self.log_stdout_formatstr)
                slogger.setFormatter(formatter)
                slogger.propagate = False
                root.addHandler(slogger)
                log_handlers.append(slogger)

            if self.log_to_file:
                # setup logging to file
                if self.log_file_path:
                    path = self.log_file_path
                else:
                    path = os.path.join(outdir, ipynb_name + ".nbpub.log")

                if not os.path.exists(os.path.dirname(path)):
                    os.makedirs(os.path.dirname(path))

                flogger = logging.FileHandler(path, "w")
                flogger.setLevel(getattr(logging, self.log_level_file.upper()))
                formatter = logging.Formatter(self.log_file_formatstr)
                flogger.setFormatter(formatter)
                flogger.propagate = False
                root.addHandler(flogger)
                log_handlers.append(flogger)

            yield

        finally:

            root.setLevel(root_level)
            for handler in log_handlers:
                handler.close()
                root.removeHandler(handler)

    def __init__(self, config=None):
        """
        Public constructor

        Parameters
        ----------
        config: traitlets.config.Config
            User configuration instance.

        """
        # with_default_config = self.default_config
        # if config:
        #     with_default_config.merge(config)
        if config is None:
            config = {}
        if not isinstance(config, Config):
            config = Config(config)
        with_default_config = config

        super(IpyPubMain, self).__init__(config=with_default_config)

    def __call__(self, ipynb_path, nb_node=None):
        """see IpyPubMain.publish"""
        return self.publish(ipynb_path, nb_node)

    def publish(self, ipynb_path, nb_node=None):
        """ convert one or more Jupyter notebooks to a published format

        paths can be string of an existing file or folder,
        or a pathlib.Path like object

        all files linked in the documents are placed into a single files_folder

        Parameters
        ----------
        ipynb_path: str or pathlib.Path
            notebook file or directory
        nb_node: None or nbformat.NotebookNode
            a pre-converted notebook

        Returns
        --------
        outdata: dict
            containing keys;
            "outpath", "exporter", "stream", "main_filepath", "resources"

        """
        # setup the input and output paths
        if isinstance(ipynb_path, string_types):
            ipynb_path = pathlib.Path(ipynb_path)
        ipynb_name, ipynb_ext = os.path.splitext(ipynb_path.name)
        outdir = (os.path.join(os.getcwd(), "converted")
                  if self.outpath is None else str(self.outpath))

        with self._log_handlers(ipynb_name, outdir):

            if not ipynb_path.exists() and not nb_node:
                handle_error(
                    "the notebook path does not exist: {}".format(ipynb_path),
                    IOError,
                    self.logger,
                )

            # log start of conversion
            self.logger.info("started ipypublish v{0} at {1}".format(
                ipypublish.__version__, time.strftime("%c")))
            self.logger.info("logging to: {}".format(
                os.path.join(outdir, ipynb_name + ".nbpub.log")))
            self.logger.info("running for ipynb(s) at: {0}".format(ipynb_path))
            self.logger.info("with conversion configuration: {0}".format(
                self.conversion))

            if nb_node is None and ipynb_ext in self.pre_conversion_funcs:
                func = self.pre_conversion_funcs[ipynb_ext]
                self.logger.info("running pre-conversion with: {}".format(
                    inspect.getmodule(func)))
                try:
                    nb_node = func(ipynb_path)
                except Exception as err:
                    handle_error(
                        "pre-conversion failed for {}: {}".format(
                            ipynb_path, err),
                        err,
                        self.logger,
                    )

            # doesn't work with folders
            # if (ipynb_ext != ".ipynb" and nb_node is None):
            #     handle_error(
            #         'the file extension is not associated with any '
            #         'pre-converter: {}'.format(ipynb_ext),
            # TypeError, self.logger)

            if nb_node is None:
                # merge all notebooks
                # TODO allow notebooks to remain separate
                # (would require creating a main.tex with the preamble in etc )
                # Could make everything a 'PyProcess',
                # with support for multiple streams
                final_nb, meta_path = merge_notebooks(
                    ipynb_path, ignore_prefix=self.ignore_prefix)
            else:
                final_nb, meta_path = (nb_node, ipynb_path)

            # validate the notebook metadata against the schema
            if self.validate_nb_metadata:
                nb_metadata_schema = read_file_from_directory(
                    get_module_path(schema),
                    "doc_metadata.schema.json",
                    "doc_metadata.schema",
                    self.logger,
                    interp_ext=True,
                )
                try:
                    jsonschema.validate(final_nb.metadata, nb_metadata_schema)
                except jsonschema.ValidationError as err:
                    handle_error(
                        "validation of notebook level metadata failed: {}\n"
                        "see the doc_metadata.schema.json for full spec".
                        format(err.message),
                        jsonschema.ValidationError,
                        logger=self.logger,
                    )

            # set text replacements for export configuration
            replacements = {
                self.meta_path_placeholder:
                str(meta_path),
                self.files_folder_placeholder:
                "{}{}".format(get_valid_filename(ipynb_name),
                              self.folder_suffix),
            }

            self.logger.debug("notebooks meta path: {}".format(meta_path))

            # load configuration file
            (
                exporter_cls,
                jinja_template,
                econfig,
                pprocs,
                pconfig,
            ) = self._load_config_file(replacements)

            # run nbconvert
            self.logger.info("running nbconvert")
            exporter, stream, resources = self.export_notebook(
                final_nb, exporter_cls, econfig, jinja_template)

            # postprocess results
            main_filepath = os.path.join(outdir,
                                         ipynb_name + exporter.file_extension)

            for post_proc_name in pprocs:
                proc_class = find_entry_point(
                    post_proc_name,
                    "ipypublish.postprocessors",
                    self.logger,
                    "ipypublish",
                )
                proc = proc_class(pconfig)
                stream, main_filepath, resources = proc.postprocess(
                    stream, exporter.output_mimetype, main_filepath, resources)

            self.logger.info("process finished successfully")

        return {
            "outpath": outdir,
            "exporter": exporter,
            "stream": stream,
            "main_filepath": main_filepath,
            "resources": resources,
        }

    def _load_config_file(self, replacements):
        # find conversion configuration
        self.logger.info("finding conversion configuration: {}".format(
            self.conversion))
        export_config_path = None
        if isinstance(self.conversion, string_types):
            outformat_path = pathlib.Path(self.conversion)
        else:
            outformat_path = self.conversion
        if outformat_path.exists():  # TODO use pathlib approach
            # if is outformat is a path that exists, use that
            export_config_path = outformat_path
        else:
            # else search internally
            export_config_path = get_export_config_path(
                self.conversion, self.plugin_folder_paths)

        if export_config_path is None:
            handle_error(
                "could not find conversion configuration: {}".format(
                    self.conversion),
                IOError,
                self.logger,
            )

        # read conversion configuration and create
        self.logger.info("loading conversion configuration")
        data = load_export_config(export_config_path)
        self.logger.info("creating exporter")
        exporter_cls = create_exporter_cls(data["exporter"]["class"])
        self.logger.info("creating template and loading filters")
        template_name = "template_file"
        jinja_template = load_template(template_name, data["template"])
        self.logger.info("creating process configuration")
        export_config = self._create_export_config(data["exporter"],
                                                   template_name, replacements)
        pprocs, pproc_config = self._create_pproc_config(
            data.get("postprocessors", {}), replacements)

        return (exporter_cls, jinja_template, export_config, pprocs,
                pproc_config)

    def _create_export_config(self, exporter_data, template_name,
                              replacements):
        # type: (dict, str, Dict[str, str]) -> Config
        config = {}
        exporter_name = exporter_data["class"].split(".")[-1]

        config[exporter_name + ".template_file"] = template_name
        config[exporter_name + ".filters"] = exporter_data.get("filters", [])

        preprocessors = []
        for preproc in exporter_data.get("preprocessors", []):
            preprocessors.append(preproc["class"])
            preproc_name = preproc["class"].split(".")[-1]
            for name, val in preproc.get("args", {}).items():
                config[preproc_name + "." + name] = val

        config[exporter_name + ".preprocessors"] = preprocessors

        for name, val in exporter_data.get("other_args", {}).items():
            config[name] = val

        final_config = self.default_exporter_config
        final_config.update(config)

        replace_placeholders(final_config, replacements)

        return dict_to_config(final_config, True)
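
    # A worked illustration of the flattening above, using a hypothetical
    # exporter_data block (the class path and argument names are illustrative
    # only, not taken from a real ipypublish configuration):
    #
    #     exporter_data = {
    #         "class": "nbconvert.exporters.LatexExporter",
    #         "filters": [],
    #         "preprocessors": [
    #             {"class": "some.module.MyPreproc",
    #              "args": {"metadata_key": "ipub"}},
    #         ],
    #     }
    #
    # would yield, before merging defaults and replacing placeholders:
    #
    #     {"LatexExporter.template_file": "template_file",
    #      "LatexExporter.filters": [],
    #      "MyPreproc.metadata_key": "ipub",
    #      "LatexExporter.preprocessors": ["some.module.MyPreproc"]}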

    def _create_pproc_config(self, pproc_data, replacements):

        if "order" in pproc_data:
            pprocs_list = pproc_data["order"]
        else:
            pprocs_list = self._create_default_pporder(
                **self.default_pporder_kwargs)

        pproc_config = self._create_default_ppconfig(
            **self.default_ppconfig_kwargs)

        if "config" in pproc_data:
            override_config = pproc_data["config"]
            pproc_config.update(override_config)

        replace_placeholders(pproc_config, replacements)

        return pprocs_list, pproc_config
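
    # The "postprocessors" block consumed above has the following (hypothetical)
    # shape; "order" lists entry-point names resolved later via find_entry_point,
    # and "config" overrides the default postprocessor trait values:
    #
    #     {"order": ["a-postprocessor-name", "another-one"],
    #      "config": {"SomePostProcessor.some_trait": "value"}}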

    def export_notebook(self, final_nb, exporter_cls, config, jinja_template):

        kwargs = {"config": config}
        if jinja_template is not None:
            kwargs["extra_loaders"] = [jinja_template]
        try:
            exporter = exporter_cls(**kwargs)
        except TypeError:
            self.logger.warning("the exporter class can not be parsed "
                                "the arguments: {}".format(list(
                                    kwargs.keys())))
            exporter = exporter_cls()

        body, resources = exporter.from_notebook_node(final_nb)
        return exporter, body, resources
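
The fallback above mirrors the plain nbconvert API. As a minimal sketch of that flow, assuming an HTMLExporter and an input file named notebook.ipynb (both illustrative; any exporter class resolved by create_exporter_cls behaves the same way):

import nbformat
from traitlets.config import Config
from nbconvert.exporters import HTMLExporter

# read the notebook and configure the exporter via a traitlets Config,
# just as the econfig built by _create_export_config is passed in above
nb = nbformat.read("notebook.ipynb", as_version=4)
config = Config()
config.HTMLExporter.exclude_input_prompt = True

exporter = HTMLExporter(config=config)
body, resources = exporter.from_notebook_node(nb)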
Example #8
class MolViz2D(MessageWidget):
    """
    Draws 2D molecular representations with D3.js
    """
    _view_name = Unicode('MolWidget2DView').tag(sync=True)
    _model_name = Unicode('MolWidget2DModel').tag(sync=True)
    _view_module = Unicode('nbmolviz-js').tag(sync=True)
    _model_module = Unicode('nbmolviz-js').tag(sync=True)

    charge = traitlets.Float().tag(sync=True)
    uuid = traitlets.Unicode().tag(sync=True)
    graph = traitlets.Dict().tag(sync=True)
    clicked_bond_indices = traitlets.Tuple((-1, -1)).tag(sync=True)
    _atom_colors = traitlets.Dict({}).tag(sync=True)
    width = traitlets.Float().tag(sync=True)
    height = traitlets.Float().tag(sync=True)
    selected_atom_indices = traitlets.Set(set()).tag(sync=True)

    def __init__(self, atoms, charge=-150, width=400, height=350, **kwargs):

        kwargs.update(width=width, height=height)
        super(MolViz2D, self).__init__(**kwargs)

        try:
            self.atoms = atoms.atoms
        except AttributeError:
            self.atoms = atoms
        else:
            self.entity = atoms
        self.width = width
        self.height = height
        self.uuid = 'mol2d' + str(uuid.uuid4())
        self.charge = charge
        self._clicks_enabled = False
        self.graph = self.to_graph(self.atoms)

    def to_graph(self, atoms):
        """Turn a set of atoms into a graph
        Should return a dict of the form
        {nodes:[a1,a2,a3...],
        links:[b1,b2,b3...]}
        where ai = {atom:[atom name],color='black',size=1,index:i}
        and bi = {bond:[order],source:[i1],dest:[i2],
                color/category='black',distance=22.0,strength=1.0}
        You can assign an explicit color with "color" OR
        get automatically assigned unique colors using "category"
        """
        raise NotImplementedError(
            "This method must be implemented by the interface class")

    def set_atom_style(self, atoms=None, fill_color=None, outline_color=None):
        if atoms is None:
            indices = list(range(len(self.atoms)))
        else:
            indices = [self.get_atom_index(atom) for atom in atoms]
        spec = {}
        if fill_color is not None:
            spec['fill'] = translate_color(fill_color, prefix='#')
        if outline_color is not None:
            spec['stroke'] = translate_color(outline_color, prefix='#')
        self.viewer('setAtomStyle', [indices, spec])

    def set_bond_style(self,
                       bonds,
                       color=None,
                       width=None,
                       dash_length=None,
                       opacity=None):
        """
        :param bonds: List of atoms
        :param color:
        :param width:
        :param dash_length:
        :return:
        """
        atom_pairs = [[self.get_atom_index(a) for a in pair] for pair in bonds]
        spec = {}
        if width is not None: spec['stroke-width'] = str(width) + 'px'
        if color is not None: spec['stroke'] = color
        if dash_length is not None:
            spec['stroke-dasharray'] = str(dash_length) + 'px'
        if opacity is not None: spec['opacity'] = opacity
        if not spec: raise ValueError('No bond style specified!')
        self.viewer('setBondStyle', [atom_pairs, spec])

    def set_atom_label(self,
                       atom,
                       text=None,
                       text_color=None,
                       size=None,
                       font=None):
        atomidx = self.get_atom_index(atom)
        self._change_label('setAtomLabel', atomidx, text, text_color, size,
                           font)

    def set_bond_label(self,
                       bond,
                       text=None,
                       text_color=None,
                       size=None,
                       font=None):
        bondids = [self.get_atom_index(a) for a in bond]
        self._change_label('setBondLabel', bondids, text, text_color, size,
                           font)

    def _change_label(self, driver_function, obj_index, text, text_color, size,
                      font):
        spec = {}
        if size is not None:
            if not isinstance(size, str):
                size = str(size) + 'pt'
            spec['font-size'] = size
        if text_color is not None:
            # note: this strangely doesn't always work if you send it a color name
            spec['fill'] = text_color
        if font is not None:
            spec['font'] = font
        self.viewer(driver_function, [obj_index, text, spec])

    def highlight_atoms(self, atoms):
        indices = [self.get_atom_index(atom) for atom in atoms]
        self.viewer('updateHighlightAtoms', [indices])

    def get_atom_index(self, atom):
        raise NotImplementedError(
            "This method must be implemented by the interface class")

    def set_click_callback(self, callback=None, enabled=True):
        """
        :param callback: Callback can have signature (), (trait_name), (trait_name,old), or (trait_name,old,new)
        :type callback: callable
        :param enabled:
        :return:
        """
        if not enabled: return  # TODO: FIX THIS
        assert callable(callback)
        self._clicks_enabled = True
        self.on_trait_change(callback, 'selected_atom_indices')
        self.click_callback = callback

    def set_color(self, color, atoms=None, render=None):
        self.set_atom_style(fill_color=color, atoms=atoms)

    def set_colors(self, colormap, render=True):
        """
        Args:
         colormap(Mapping[str,List[Atoms]]): mapping of colors to atoms
        """
        for color, atoms in colormap.items():
            self.set_color(atoms=atoms, color=color)
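
to_graph and get_atom_index are deliberately left abstract above. Below is a minimal sketch of a concrete interface subclass, assuming each atom is a (name, bonded_indices) tuple; SimpleMolViz2D and that atom format are hypothetical, not part of nbmolviz:

class SimpleMolViz2D(MolViz2D):
    def to_graph(self, atoms):
        # build the {nodes: [...], links: [...]} dict documented on to_graph
        nodes = [{'atom': name, 'color': 'black', 'size': 1, 'index': i}
                 for i, (name, _) in enumerate(atoms)]
        links = [{'bond': 1, 'source': i, 'dest': j, 'color': 'black',
                  'distance': 22.0, 'strength': 1.0}
                 for i, (_, bonded) in enumerate(atoms)
                 for j in bonded if i < j]
        return {'nodes': nodes, 'links': links}

    def get_atom_index(self, atom):
        # atoms are identified by their position in the original list
        return self.atoms.index(atom)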
Example #9
class Coordinates1d(BaseCoordinates):
    """
    Base class for 1-dimensional coordinates.

    Coordinates1d objects contain values and metadata for a single dimension of coordinates. :class:`Coordinates` and
    :class:`StackedCoordinates` use Coordinates1d objects.

    The following coordinates types (``ctype``) are supported:

     * 'point': each coordinate represents a single location
     * 'left': each coordinate is the left endpoint of its segment
     * 'right': each coordinate is the right endpoint of its segment
     * 'midpoint': segment endpoints are at the midpoints between coordinate values.

    The ``bounds`` are always the low and high coordinate value. For *point* coordinates, the ``area_bounds`` are the
    same as the ``bounds``. For *segment* coordinates (left, right, and midpoint), the ``area_bounds`` include the
    portion of the segments above and below the ``bounds``.
    
    Parameters
    ----------
    name : str
        Dimension name, one of 'lat', 'lon', 'time', or 'alt'.
    coordinates : array, read-only
        Full array of coordinate values.
    units : podpac.Units
        Coordinate units.
    coord_ref_sys : str
        Coordinate reference system.
    ctype : str
        Coordinates type: 'point', 'left', 'right', or 'midpoint'.
    segment_lengths : array, float, timedelta
        When ctype is a segment type, the segment lengths for the coordinates. This may be single coordinate delta for
        uniform segment lengths or an array of coordinate deltas corresponding to the coordinates for variable lengths.

    See Also
    --------
    :class:`ArrayCoordinates1d`, :class:`UniformCoordinates1d`
    """

    name = tl.Enum(['lat', 'lon', 'time', 'alt'], allow_none=True)
    units = tl.Instance(Units, allow_none=True, read_only=True)
    coord_ref_sys = tl.Enum(['WGS84', 'SPHER_MERC'],
                            allow_none=True,
                            read_only=True)
    ctype = tl.Enum(['point', 'left', 'right', 'midpoint'], read_only=True)
    segment_lengths = tl.Any(read_only=True)

    _properties = tl.Set()
    _segment_lengths = tl.Bool()

    def __init__(self,
                 name=None,
                 ctype=None,
                 units=None,
                 segment_lengths=None,
                 coord_ref_sys=None):
        """*Do not use.*"""

        if name is not None:
            self.name = name

        if ctype is not None:
            self.set_trait('ctype', ctype)

        if units is not None:
            self.set_trait('units', units)

        if coord_ref_sys is not None:
            self.set_trait('coord_ref_sys', coord_ref_sys)

        if segment_lengths is not None:
            if np.array(segment_lengths).ndim == 0:
                segment_lengths = make_coord_delta(segment_lengths)
            else:
                segment_lengths = make_coord_delta_array(segment_lengths)
                segment_lengths.setflags(write=False)

            self.set_trait('segment_lengths', segment_lengths)

        super(Coordinates1d, self).__init__()

    @tl.observe('name', 'units', 'coord_ref_sys', 'ctype')
    def _set_property(self, d):
        self._properties.add(d['name'])

    @tl.observe('segment_lengths')
    def _set_segment_lengths(self, d):
        self._segment_lengths = True

    @tl.validate('segment_lengths')
    def _validate_segment_lengths(self, d):
        val = d['value']

        if self.ctype == 'point':
            if val is not None:
                raise TypeError(
                    "segment_lengths must be None when ctype='point'")
            return None

        if isinstance(val, np.ndarray):
            if val.size != self.size:
                raise ValueError(
                    "coordinates and segment_lengths size mismatch, %d != %d" %
                    (self.size, val.size))
            if not np.issubdtype(val.dtype, self.deltatype):
                raise ValueError(
                    "coordinates and segment_lengths dtype mismatch, %s != %s"
                    % (val.dtype, self.deltatype))

        else:
            if self.size > 0 and not isinstance(val, self.deltatype):
                raise TypeError(
                    "coordinates and segment_lengths type mismatch, %s != %s" %
                    (self.deltatype, type(val)))

        if np.any(np.array(val).astype(float) <= 0.0):
            raise ValueError("segment_lengths must be positive")

        return val

    @tl.default('coord_ref_sys')
    def _default_coord_ref_sys(self):
        return DEFAULT_COORD_REF_SYS

    def __repr__(self):
        return "%s(%s): Bounds[%s, %s], N[%d], ctype['%s']" % (
            self.__class__.__name__, self.name
            or '?', self.bounds[0], self.bounds[1], self.size, self.ctype)

    def __eq__(self, other):
        if not isinstance(other, Coordinates1d):
            return False

        # defined coordinate properties should match
        for name in self._properties.union(other._properties):
            if getattr(self, name) != getattr(other, name):
                return False

        # shortcuts (not strictly necessary)
        for name in ['size', 'is_monotonic', 'is_descending', 'is_uniform']:
            if getattr(self, name) != getattr(other, name):
                return False

        # only check if one of the coordinates has custom segment lengths
        if self._segment_lengths or other._segment_lengths:
            if not np.all(self.segment_lengths == other.segment_lengths):
                return False

        return True

    def from_definition(self, d):
        raise NotImplementedError

    # ------------------------------------------------------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------------------------------------------------------

    @property
    def dims(self):
        if self.name is None:
            raise TypeError(
                "cannot access dims property of unnamed Coordinates1d")
        return [self.name]

    @property
    def udims(self):
        return self.dims

    @property
    def coordinates(self):
        """:array, read-only: Full array of coordinates values."""

        raise NotImplementedError

    @property
    def dtype(self):
        """:type: Coordinates dtype.

        ``float`` for numerical coordinates and numpy ``datetime64`` for datetime coordinates.
        """

        raise NotImplementedError

    @property
    def deltatype(self):
        if self.dtype is np.datetime64:
            return np.timedelta64
        else:
            return self.dtype

    @property
    def size(self):
        """Number of coordinates. """

        raise NotImplementedError

    @property
    def is_monotonic(self):
        raise NotImplementedError

    @property
    def is_descending(self):
        raise NotImplementedError

    @property
    def is_uniform(self):
        raise NotImplementedError

    @property
    def bounds(self):
        """ Low and high coordinate bounds. """

        raise NotImplementedError

    @property
    def area_bounds(self):
        """
        Low and high coordinate area bounds.

        When ctype != 'point', this includes the portions of the segments beyond the coordinate bounds.
        """

        # point ctypes, just use bounds
        if self.ctype == 'point':
            return self.bounds

        # empty coordinates [np.nan, np.nan]
        if self.size == 0:
            return self.bounds

        # segment ctypes, calculated
        L, H = self.argbounds
        lo, hi = self.bounds

        if not isinstance(self.segment_lengths, np.ndarray):
            lo_length = hi_length = self.segment_lengths  # uniform segment_lengths
        else:
            lo_length = self.segment_lengths[L]
            hi_length = self.segment_lengths[H]

        if self.ctype == 'left':
            hi = add_coord(hi, hi_length)
        elif self.ctype == 'right':
            lo = add_coord(lo, -lo_length)
        elif self.ctype == 'midpoint':
            lo = add_coord(lo, -divide_delta(lo_length, 2.0))
            hi = add_coord(hi, divide_delta(hi_length, 2.0))

        # read-only array with the correct dtype
        area_bounds = np.array([lo, hi], dtype=self.dtype)
        area_bounds.setflags(write=False)
        return area_bounds

    @property
    def properties(self):
        """:dict: Dictionary of the coordinate properties. """

        return {key: getattr(self, key) for key in self._properties}

    @property
    def definition(self):
        """ Serializable 1d coordinates definition."""

        raise NotImplementedError

    # ------------------------------------------------------------------------------------------------------------------
    # Methods
    # ------------------------------------------------------------------------------------------------------------------

    def copy(self, **kwargs):
        """
        Make a deep copy of the 1d Coordinates.

        The coordinates properties will be copied. Any provided keyword arguments will override these properties.

        *Note: Defined in child classes.*

        Arguments
        ---------
        name : str, optional
            Dimension name. One of 'lat', 'lon', 'alt', and 'time'.
        coord_ref_sys : str, optional
            Coordinates reference system
        ctype : str, optional
            Coordinates type. One of 'point', 'midpoint', 'left', 'right'.
        units : podpac.Units, optional
            Coordinates units.

        Returns
        -------
        :class:`Coordinates1d`
            Copy of the coordinates, with provided properties.
        """

        raise NotImplementedError

    def _select_empty(self, return_indices):
        I = []
        if return_indices:
            return self[I], I
        else:
            return self[I]

    def _select_full(self, return_indices):
        I = slice(None)
        if return_indices:
            return self[I], I
        else:
            return self[I]

    def intersect(self, other, return_indices=False, outer=False):
        """
        Get the coordinate values that are within the bounds of a given coordinates object.

        If a Coordinates1d ``other`` is provided, then this dimension must match the other dimension
        (``self.name == other.name``). If a multidimensional :class:`Coordinates` ``other`` is provided, then the
        corresponding 1d coordinates are used for the intersection if available, and otherwise the entire coordinates
        are returned.

        The default intersection selects coordinates that are within the other coordinates bounds::

            In [1]: c = ArrayCoordinates1d([0, 1, 2, 3], name='lat')

            In [2]: other = ArrayCoordinates1d([1.5, 2.5], name='lat')

            In [3]: c.intersect(other).coordinates
            Out[3]: array([2.])

        The *outer* intersection selects the minimal set of coordinates that contain the other coordinates::
        
            In [4]: c.intersect(other, outer=True).coordinates
            Out[4]: array([1., 2., 3.])

        The *outer* intersection also selects a boundary coordinate if the other coordinates are outside this
        coordinates bounds but *inside* its area bounds::
        
            In [5]: c.area_bounds
            Out[5]: array([-0.5,  3.5])

            In [6]: other1 = podpac.coordinates.ArrayCoordinates1d([3.25], name='lat')
            
            In [7]: other2 = podpac.coordinates.ArrayCoordinates1d([3.75], name='lat')

            In [8]: c.intersect(other1, outer=True).coordinates
            Out[8]: array([3.])

            In [9]: c.intersect(other2, outer=True).coordinates
            Out[9]: array([], dtype=float64)
        
        Parameters
        ----------
        other : :class:`Coordinates1d`, :class:`StackedCoordinates`, :class:`Coordinates`
            Coordinates to intersect with.
        outer : bool, optional
            If True, do an *outer* intersection. Default False.
        return_indices : bool, optional
            If True, return slice or indices for the selection in addition to coordinates. Default False.
        
        Returns
        -------
        intersection : :class:`Coordinates1d`
            Coordinates1d object with coordinates within the other coordinates bounds.
        I : slice or list
            index or slice for the intersected coordinates (only if return_indices=True)
        
        Raises
        ------
        ValueError
            If the coordinates names do not match, when intersecting with a Coordinates1d other.

        See Also
        --------
        select : Get the coordinates within the given bounds.
        """

        from podpac.core.coordinates import Coordinates, StackedCoordinates

        if not isinstance(other, (BaseCoordinates, Coordinates)):
            raise TypeError("Cannot intersect with type '%s'" % type(other))

        if isinstance(other, (Coordinates, StackedCoordinates)):
            # short-circuit
            if self.name not in other.udims:
                return self._select_full(return_indices)

            other = other[self.name]

        if self.name != other.name:
            raise ValueError(
                "Cannot intersect mismatched dimensions ('%s' != '%s')" %
                (self.name, other.name))

        if self.dtype is not None and other.dtype is not None and self.dtype != other.dtype:
            raise ValueError(
                "Cannot intersect mismatched dtypes ('%s' != '%s')" %
                (self.dtype, other.dtype))

        if self.units != other.units:
            raise NotImplementedError(
                "Still need to implement handling different units")

        # no valid other bounds, empty
        if other.size == 0:
            return self._select_empty(return_indices)

        return self.select(other.bounds,
                           return_indices=return_indices,
                           outer=outer)

    def select(self, bounds, return_indices=False, outer=False):
        """
        Get the coordinate values that are within the given bounds.

        The default selection returns coordinates that are within the given bounds::

            In [1]: c = ArrayCoordinates1d([0, 1, 2, 3], name='lat')

            In [2]: c.select([1.5, 2.5]).coordinates
            Out[2]: array([2.])

        The *outer* selection returns the minimal set of coordinates that contain the given bounds::
        
            In [3]: c.select([1.5, 2.5], outer=True).coordinates
            Out[3]: array([1., 2., 3.])

        The *outer* selection also returns a boundary coordinate if the given bounds are outside this
        coordinates bounds but *inside* its area bounds::
        
            In [4]: c.select([3.25, 3.35], outer=True).coordinates
            Out[4]: array([3.])

            In [5]: c.select([10.0, 11.0], outer=True).coordinates
            Out[5]: array([], dtype=float64)

        *Note: Defined in child classes.*
        
        Parameters
        ----------
        bounds : low, high
            selection bounds
        outer : bool, optional
            If True, do an *outer* selection. Default False.
        return_indices : bool, optional
            If True, return slice or indices for the selection in addition to coordinates. Default False.

        Returns
        -------
        selection : :class:`Coordinates1d`
            Coordinates1d object with coordinates within the given bounds.
        I : slice or list
            index or slice for the selected coordinates (only if return_indices=True)
        """

        raise NotImplementedError
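
A short usage sketch of the select/intersect API documented above, assuming ArrayCoordinates1d is importable from podpac as in the docstring examples:

from podpac.core.coordinates import ArrayCoordinates1d

# midpoint segments extend the area bounds half a step past the coordinate bounds
c = ArrayCoordinates1d([0, 1, 2, 3], name='lat', ctype='midpoint')
print(c.area_bounds)                    # [-0.5, 3.5]

other = ArrayCoordinates1d([1.5, 2.5], name='lat')
subset = c.intersect(other)             # coordinates within other's bounds -> [2.0]
outer = c.intersect(other, outer=True)  # minimal covering set -> [1.0, 2.0, 3.0]

# select works directly on (low, high) bounds and can also return the indices
subset, I = c.select([1.5, 2.5], return_indices=True)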