Code example #1
class ZillowHousingResponse(object):
    zpid = attr.ib()
    address = attr.ib(converter=optional(ZillowAddress.from_dict))
    estimated_rent_price = attr.ib(converter=optional(decimal.Decimal))
    estimated_rent_price_currency = attr.ib()
    last_estimated = attr.ib(converter=optional(parse_zillow_datetime))
    url = attr.ib()

    @classmethod
    def from_zillow_response(cls, dic):
        try:
            rent_estimate = dic["rentzestimate"]["amount"]
        except KeyError:
            rent_estimate = {}

        return cls(
            zpid=dic.get("zpid"),
            address=dic.get("address"),
            estimated_rent_price=rent_estimate.get("#text"),
            estimated_rent_price_currency=rent_estimate.get("@currency"),
            last_estimated=dic.get("rentzestimate", {}).get("last-updated"),
            url=dic.get("links", {}).get("mapthishome"),
        )

    @property
    def estimated_rent_price_for_display(self):
        return f"{self.estimated_rent_price_currency} {self.estimated_rent_price}"
Code example #2
class EvalConfig:
    """class that represents [EVAL] section of config.toml file

    Attributes
    ----------
    csv_path : str
        path to where dataset was saved as a csv.
    checkpoint_path : str
        path to directory with checkpoint files saved by Torch, to reload model
    output_dir : str
        Path to location where .csv files with evaluation metrics should be saved.
    labelmap_path : str
        path to 'labelmap.json' file.
    models : list
        of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet'
    batch_size : int
        number of samples per batch presented to models during training.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to torch.DataLoader. Default is 2.
    device : str
        Device on which to work with model + data.
        Defaults to 'cuda' if torch.cuda.is_available is True.
    spect_scaler_path : str
        path to a saved SpectScaler object used to normalize spectrograms.
        If spectrograms were normalized and this is not provided, will give
        incorrect results.
    """
    # required, external files
    checkpoint_path = attr.ib(converter=expanded_user_path,
                              validator=is_a_file)
    labelmap_path = attr.ib(converter=expanded_user_path, validator=is_a_file)
    output_dir = attr.ib(converter=expanded_user_path,
                         validator=is_a_directory)

    # required, model / dataloader
    models = attr.ib(converter=comma_separated_list,
                     validator=[instance_of(list), is_valid_model_name])
    batch_size = attr.ib(converter=int, validator=instance_of(int))

    # csv_path is actually 'required' but we can't enforce that here because cli.prep looks at
    # what sections are defined to figure out where to add csv_path after it creates the csv
    csv_path = attr.ib(converter=converters.optional(expanded_user_path),
                       validator=validators.optional(is_a_file),
                       default=None)

    # optional, transform
    spect_scaler_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        validator=validators.optional(is_a_file),
        default=None)

    # optional, data loader
    num_workers = attr.ib(validator=instance_of(int), default=2)
    device = attr.ib(validator=instance_of(str), default=device.get_default())
Code example #3
File: test_converters.py Project: bocoup/wpt-docs
    def test_fail(self):
        """
        Propagates the underlying conversion error when conversion fails.
        """
        c = optional(int)
        with pytest.raises(ValueError):
            c("not_an_int")
Code example #4
    def test_success_with_none(self):
        """
        Nothing happens if None.
        """
        c = optional(int)

        assert c(None) is None
Code example #5
    def test_success_with_type(self):
        """
        Wrapped converter is used as usual if value is not None.
        """
        c = optional(int)

        assert c("42") == 42
Code example #6
class Account(BaseAAEntity):
    TYPENAME = "Account"

    provider = attr.attrib(
        converter=enums.ProviderType)  # type: enums.ProviderType
    username = attr.attrib()  # type: str
    access_token = attr.attrib()  # type: Optional[str]
    access_token_expires_at = attr.attrib(converter=converters.optional(
        ciso8601.parse_datetime), )  # type: Optional[datetime.datetime]
Code example #7
class SFTPClientOptions(object):
    """
    Client options for sending SFTP files.

    :param host: the host of the SFTP server
    :param port: the port of the SFTP server
    :param fingerprint: the expected fingerprint of the host
    :param user: the user to login as
    :param identity: the identity file, optional and like the "-i" command line option
    :param password: an optional password
    """
    host = attr.ib(converter=str)  # type: str
    port = attr.ib(converter=int)  # type: int
    fingerprint = attr.ib(converter=str)  # type: str
    user = attr.ib(converter=str)  # type: str
    identity = attr.ib(converter=optional(str),
                       default=None)  # type: Optional[str]
    password = attr.ib(converter=optional(str),
                       default=None)  # type: Optional[str]
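
One reason identity and password above wrap str in optional rather than using converter=str directly: converters also run on default values, so a bare str converter would turn the default None into the literal string "None". A minimal self-contained sketch (hypothetical names) of that distinction:

import attr
from attr.converters import optional


@attr.s
class ConnectOptions:
    host = attr.ib(converter=str)
    # optional(str) leaves the default None alone instead of producing "None"
    identity = attr.ib(converter=optional(str), default=None)


opts = ConnectOptions("sftp.example.com")
assert opts.identity is None
assert ConnectOptions("sftp.example.com", "id_rsa").identity == "id_rsa"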
Code example #8
class LearncurveConfig(TrainConfig):
    """class that represents [LEARNCURVE] section of config.toml file

    Attributes
    ----------
    models : list
        of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet'
    csv_path : str
        path to where dataset was saved as a csv.
    num_epochs : int
        number of training epochs. One epoch = one iteration through the entire
        training set.
    normalize_spectrograms : bool
        if True, use spect.utils.data.SpectScaler to normalize the spectrograms.
        Normalization is done by subtracting off the mean for each frequency bin
        of the training set and then dividing by the std for that frequency bin.
        This same normalization is then applied to validation + test data.
    ckpt_step : int
        step/epoch at which to save to checkpoint file.
        Default is None, in which case checkpoint is only saved at the last epoch.
    patience : int
        number of epochs to wait without the error dropping before stopping the
        training. Default is None, in which case training continues for num_epochs
    train_set_durs : list
        of int, durations in seconds of subsets taken from training data
        to create a learning curve, e.g. [5, 10, 15, 20]. Default is None
        (when training a single model on all available training data).
    num_replicates : int
        number of times to replicate training for each training set duration
        to better estimate mean accuracy for a training set of that size.
        Each replicate uses a different randomly drawn subset of the training
        data (but of the same duration).
    save_only_single_checkpoint_file : bool
        if True, save only one checkpoint file instead of separate files every time
        we save. Default is True.
    use_train_subsets_from_previous_run : bool
        if True, use training subsets saved in a previous run. Default is False.
        Requires setting previous_run_path option in config.toml file.
    previous_run_path : str
        path to results directory from a previous run.
        Used for training if use_train_subsets_from_previous_run is True.
    """

    train_set_durs = attr.ib(validator=instance_of(list), kw_only=True)
    num_replicates = attr.ib(validator=instance_of(int), kw_only=True)
    previous_run_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        validator=validators.optional(is_a_directory),
        default=None,
    )
Code example #9
class SpectParamsConfig:
    """represents parameters for making spectrograms from audio and saving in files

    Attributes
    ----------
    fft_size : int
        size of window for Fast Fourier transform, number of time bins. Default is 512.
    step_size : int
        step size for Fast Fourier transform. Default is 64.
    freq_cutoffs : tuple
        of two elements, lower and higher frequencies. Used to bandpass filter audio
        (using a Butter filter) before generating spectrogram.
        Default is None, in which case no bandpass filtering is applied.
    transform_type : str
        one of {'log_spect', 'log_spect_plus_one'}.
        'log_spect' transforms the spectrogram to log(spectrogram), and
        'log_spect_plus_one' does the same thing but adds one to each element.
        Default is None. If None, no transform is applied.
    thresh : float
        threshold minimum power for log spectrogram. Default is None.
    spect_key : str
        key for accessing spectrogram in files. Default is 's'.
    freqbins_key : str
        key for accessing vector of frequency bins in files. Default is 'f'.
    timebins_key : str
        key for accessing vector of time bins in files. Default is 't'.
    audio_path_key : str
        key for accessing path to source audio file for spectrogram in files.
        Default is 'audio_path'.
    """

    fft_size = attr.ib(converter=int, validator=instance_of(int), default=512)
    step_size = attr.ib(converter=int, validator=instance_of(int), default=64)
    freq_cutoffs = attr.ib(
        validator=validators.optional(freq_cutoffs_validator), default=None
    )
    thresh = attr.ib(
        converter=converters.optional(float),
        validator=validators.optional(instance_of(float)),
        default=None,
    )
    transform_type = attr.ib(
        validator=validators.optional([instance_of(str), is_valid_transform_type]),
        default=None,
    )
    spect_key = attr.ib(validator=instance_of(str), default="s")
    freqbins_key = attr.ib(validator=instance_of(str), default="f")
    timebins_key = attr.ib(validator=instance_of(str), default="t")
    audio_path_key = attr.ib(validator=instance_of(str), default="audio_path")
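
The thresh field above shows the idiom used throughout these config classes for truly optional options: wrap the converter in converters.optional, wrap the validator in validators.optional, and give the field default=None so an omitted option simply stays None. A self-contained sketch with a hypothetical option name:

import attr
from attr import converters, validators
from attr.validators import instance_of


@attr.s
class WindowConfig:
    # both converter and validator must tolerate None, or the default would fail
    window_size = attr.ib(
        converter=converters.optional(int),
        validator=validators.optional(instance_of(int)),
        default=None,
    )


assert WindowConfig().window_size is None       # omitted option stays None
assert WindowConfig("512").window_size == 512   # values from the config get coerced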
Code example #10
File: test_converters.py Project: Coder206/servo
    def test_success_with_none(self):
        """
        Nothing happens if None.
        """
        c = optional(int)
        assert c(None) is None
Code example #11
File: test_converters.py Project: Coder206/servo
    def test_success_with_type(self):
        """
        Wrapped converter is used as usual if value is not None.
        """
        c = optional(int)
        assert c("42") == 42
Code example #12
class FeedlyEntry:
    url: str = attr.ib(validator=instance_of(str))
    source: str = attr.ib(repr=False)
    published: datetime = attr.ib(converter=utils.datetime_converters)
    updated: datetime = attr.ib(default=None,
                                converter=optional(utils.datetime_converters),
                                repr=False)

    keywords: Keywords = attr.ib(
        converter=utils.ensure_collection(lowercase_set),
        factory=lowercase_set,
        repr=False)
    author: Optional[str] = attr.ib(default='', repr=False)
    title: Optional[str] = attr.ib(default='', repr=False)

    markup: Dict[str, str] = attr.ib(factory=dict, repr=False)
    hyperlinks: HyperlinkStore = attr.ib(factory=HyperlinkStore, repr=False)

    @classmethod
    def from_upstream(cls, item: JSONDict) -> FeedlyEntry:
        data = {}
        for name in attr.fields_dict(cls):
            value = item.get(name)
            if value:
                data[name] = value
        data['url'] = cls._get_page_url(item)
        data['source'] = cls._get_source_url(item)
        entry = cls(**data)
        cls._set_markup(entry, item)
        return entry

    @staticmethod
    def _get_page_url(item):
        url = urlsplit(item.get('originId', ''))
        if url.netloc:
            url = url.geturl()
        else:
            url = ''
            alt = item.get('alternate')
            if alt and alt != 'none':
                url = alt[0]['href']
        return url

    @staticmethod
    def _get_source_url(item):
        source = item.get('origin')
        if source:
            return get_feed_uri(source.get('streamId', '/'))
        return ''

    @staticmethod
    def _set_markup(entry, item):
        content = item.get('content', item.get('summary'))
        if content:
            content = content.get('content')
        if content:
            entry.add_markup('summary', content)

    @staticmethod
    def _filter_attrib(attrib: attr.Attribute, value: Any) -> bool:
        return attrib.name[0] != '_'

    def add_markup(self, name, markup):
        self.markup[name] = markup
        self.hyperlinks.parse_html(self.url, markup)

    def for_json(self) -> JSONDict:
        dict_ = attr.asdict(self, filter=self._filter_attrib)
        return dict_
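
optional also composes with arbitrary parsing callables, which is how the updated field above tolerates feed items without an update timestamp. A stand-alone sketch of that shape, using the standard library's datetime.fromisoformat as a stand-in for utils.datetime_converters:

from datetime import datetime

import attr
from attr.converters import optional


@attr.s
class Entry:
    published = attr.ib(converter=datetime.fromisoformat)
    # a missing "updated" timestamp stays None instead of being parsed
    updated = attr.ib(converter=optional(datetime.fromisoformat), default=None)


assert Entry("2021-05-01T12:00:00").updated is None
assert Entry("2021-05-01T12:00:00", "2021-05-02T08:30:00").updated.day == 2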
Code example #13
File: config.py Project: heindsight/dhcp-notify
class SMTPConfig(ConfigBase):
    host = attr.ib(validator=instance_of(str))
    port = attr.ib(default="465", validator=instance_of(str))
    tls = attr.ib(converter=SMTPTLSConfig, default="tls")
    credentials = attr.ib(converter=optional(Credentials.from_dict),
                          default=None)
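
Here the converter is a classmethod rather than a type, so an absent credentials table stays None while a present one is turned into a full object. A self-contained sketch under hypothetical names (not the dhcp-notify classes):

import attr
from attr.converters import optional


@attr.s
class BasicAuth:
    username = attr.ib()
    password = attr.ib()

    @classmethod
    def from_dict(cls, dic):
        return cls(username=dic["username"], password=dic["password"])


@attr.s
class MailSettings:
    # optional() lets the whole credentials section be omitted
    credentials = attr.ib(converter=optional(BasicAuth.from_dict), default=None)


assert MailSettings().credentials is None
assert MailSettings({"username": "u", "password": "p"}).credentials.username == "u"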
Code example #14
class PrepConfig:
    """class to represent [PREP] section of config.toml file

    Attributes
    ----------
    data_dir : str
        path to directory with files from which to make dataset
    output_dir : str
        Path to location where data sets should be saved. Default is None,
        in which case data sets are saved in the current working directory.
    audio_format : str
        format of audio files. One of {'wav', 'cbin'}.
    spect_format : str
        format of files containing spectrograms as 2-d matrices.
        One of {'mat', 'npy'}.
    annot_format : str
        format of annotations. Any format that can be used with the
        crowsetta library is valid.
    annot_file : str
        Path to a single annotation file. Default is None.
        Used when a single file contains annotations for multiple audio files.
    labelset : set
        of str or int, the set of labels that correspond to annotated segments
        that a network should learn to segment and classify. Note that if there
        are segments that are not annotated, e.g. silent gaps between songbird
        syllables, then `vak` will assign a dummy label to those segments
        -- you don't have to give them a label here.
    train_dur : float
        total duration of training set, in seconds. When creating a learning curve,
        training subsets of shorter duration (specified by the 'train_set_durs' option
        in the LEARNCURVE section of a config.toml file) will be drawn from this set.
    val_dur : float
        total duration of validation set, in seconds.
    test_dur : float
        total duration of test set, in seconds.
    """
    data_dir = attr.ib(converter=expanded_user_path, validator=is_a_directory)
    output_dir = attr.ib(converter=expanded_user_path,
                         validator=is_a_directory)

    audio_format = attr.ib(validator=validators.optional(is_audio_format),
                           default=None)
    spect_format = attr.ib(validator=validators.optional(is_spect_format),
                           default=None)
    annot_file = attr.ib(converter=converters.optional(expanded_user_path),
                         validator=validators.optional(is_a_file),
                         default=None)
    annot_format = attr.ib(validator=validators.optional(is_annot_format),
                           default=None)

    labelset = attr.ib(converter=converters.optional(labelset_from_toml_value),
                       validator=validators.optional(instance_of(set)),
                       default=None)

    train_dur = attr.ib(
        converter=converters.optional(duration_from_toml_value),
        validator=validators.optional(is_valid_duration),
        default=None)
    val_dur = attr.ib(converter=converters.optional(duration_from_toml_value),
                      validator=validators.optional(is_valid_duration),
                      default=None)
    test_dur = attr.ib(converter=converters.optional(duration_from_toml_value),
                       validator=validators.optional(is_valid_duration),
                       default=None)
Code example #15
File: outage.py Project: kouk/pingdom-uptime-report
class Outage(object):
    """An outage.

    Attributes:
        start (:class:`~arrow.arrow.Arrow`): the start time of the
            outage period.
        finish (:class:`~arrow.arrow.Arrow`): the ending time of the
            outage period.
        before (:class:`~arrow.arrow.Arrow`, optional): last time check was ok,
            if available.
        after (:class:`~arrow.arrow.Arrow`, optional): next time check was ok,
            if available.
        meta (dict, optional): arbitrary metadata about this outage.

    """

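    # note: "convert=" is the older attrs spelling of what is now "converter="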
    start = attr.ib(convert=optional(arrow.get))
    finish = attr.ib(convert=optional(arrow.get))
    before = attr.ib(convert=optional(arrow.get), default=None)
    after = attr.ib(convert=optional(arrow.get), default=None)
    meta = attr.ib(default=attr.Factory(dict))

    def for_json(self):
        """Return a representation of this object as a dict.

        Example:

            >>> t = arrow.utcnow()
            >>> Outage(t, t).for_json() == {
            ...    "start": t,
            ...    "finish": t,
            ...    "before": None,
            ...    "after": None,
            ...    "meta": {}
            ... }
            True

        Returns:
            dict: a dict representation of this outage.
        """
        return attr.asdict(self)

    @property
    def humanized_duration(self):
        return self.start.humanize(other=self.finish, only_distance=True)

    def humanize(self):
        return {
            'Begin': self.start.format(),
            'End': self.finish.format(),
            'Duration': self.humanized_duration,
        }

    @classmethod
    def fields(cls):
        """Return the field names for this class.

        Example:

            >>> Outage.fields()
            ['start', 'finish', 'before', 'after', 'meta']

        Returns:
            list: a list of names as (:class:`str`) instances.
        """
        return list(map(attrgetter('name'), attr.fields(cls)))
Code example #16
File: parameters.py Project: steven-murray/yabf
class Param:
    """Specification of a parameter that is to be constrained."""

    name = attr.ib()
    _min: numeric = attr.ib(-np.inf)
    _max: numeric = attr.ib(np.inf)

    prior = attr.ib(
        kw_only=True,
        validator=vld.optional(vld.instance_of(stats.distributions.rv_frozen)),
    )

    fiducial = attr.ib(
        None,
        type=float,
        converter=cnv.optional(float),
        validator=vld.optional(vld.instance_of(float)),
        kw_only=True,
    )
    latex = attr.ib(kw_only=True)
    ref = attr.ib(kw_only=True)
    determines = attr.ib(
        converter=tuplify,
        kw_only=True,
        validator=vld.deep_iterable(vld.instance_of(str)),
    )
    transforms = attr.ib(converter=tuplify, kw_only=True)

    @latex.default
    def _ltx_default(self):
        return texify(self.name)

    @ref.default
    def _ref_default(self):
        return self.prior

    @prior.default
    def _prior_default(self) -> stats.distributions.rv_frozen | None:
        if np.isinf(self._min) or np.isinf(self._max):
            return None

        return stats.uniform(self._min, self._max - self._min)

    @determines.default
    def _determines_default(self):
        return (self.name,)

    @transforms.default
    def _transforms_default(self):
        return (None,) * len(self.determines)

    @transforms.validator
    def _transforms_validator(self, attribute, value):
        for val in value:
            if val is not None and not callable(val):
                raise TypeError("transforms must be a list of callables")

    @property
    def min(self) -> float:
        """The minimum boundary of the prior, helpful for constraints."""
        if self.prior is None:
            return self._min
        elif isinstance(self.prior, type(stats.uniform(0, 1))):
            return self.prior.support()[0]
        else:
            return -np.inf

    @property
    def max(self) -> float:
        """The maximum boundary of the prior, helpful for constraints."""
        if self.prior is None:
            return self._max
        elif isinstance(self.prior, type(stats.uniform(0, 1))):
            return self.prior.support()[1]
        else:
            return np.inf

    @cached_property
    def is_alias(self):
        return all(pm is None for pm in self.transforms)

    @cached_property
    def is_pure_alias(self):
        return self.is_alias and len(self.determines) == 1

    def transform(self, val):
        for pm in self.transforms:
            if pm is None:
                yield val
            else:
                yield pm(val)

    def generate_ref(self, n=1):
        if self.ref is None:
            raise ValueError("Must specify a valid function for ref to generate refs.")

        try:
            ref = self.ref.rvs(size=n)
        except AttributeError:
            try:
                ref = self.ref(size=n)
            except TypeError:
                raise TypeError(
                    f"parameter '{self.name}' does not have a valid value for ref"
                )

        if np.any(self.prior.pdf(ref) == 0):
            raise ValueError(
                f"param {self.name} produced a reference value outside its domain."
            )

        return ref

    def logprior(self, val):
        if self.prior is None:
            if self._min > val or self._max < val:
                return -np.inf
            else:
                return 0

        return self.prior.logpdf(val)

    def clone(self, **kwargs):
        return attr.evolve(self, **kwargs)

    def new(self, p: Parameter) -> Param:
        """Create a new :class:`Param`.

        Any missing info from this instance filled in by the given instance.
        """
        assert isinstance(p, Parameter)
        assert self.determines == (p.name,)

        if len(self.determines) > 1:
            raise ValueError("Cannot create new Param if it is not just an alias")

        default_range = (list(self.transform(p.min))[0], list(self.transform(p.max))[0])

        return Param(
            name=self.name,
            min=max(self._min, min(default_range)),
            max=min(self._max, max(default_range)),
            fiducial=self.fiducial if self.fiducial is not None else p.fiducial,
            latex=self.latex
            if (self.latex != self.name or self.name != p.name)
            else p.latex,
            ref=self.ref or attr.NOTHING,
            prior=self.prior or attr.NOTHING,
            determines=self.determines,
            transforms=self.transforms,
        )

    def __getstate__(self):
        """Obtain a simple input state of the class that can initialize it."""
        out = attr.asdict(self)

        if self.transforms == (None,):
            del out["transforms"]
        if self.ref is None:
            del out["ref"]
        if self.determines == (self.name,):
            del out["determines"]
        if self.latex == self.name:
            del out["latex"]

        return out

    def as_dict(self):
        """Simple representation of the class as a dict.

        No "name" is included in the dict.
        """
        out = self.__getstate__()
        del out["name"]
        return out
Code example #17
class PredictConfig:
    """class that represents [PREDICT] section of config.toml file

     Attributes
     ----------
     csv_path : str
         path to where dataset was saved as a csv.
     checkpoint_path : str
         path to directory with checkpoint files saved by Torch, to reload model
     labelmap_path : str
         path to 'labelmap.json' file.
     models : list
         of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet'
     batch_size : int
         number of samples per batch presented to models during training.
     num_workers : int
         Number of processes to use for parallel loading of data.
         Argument to torch.DataLoader. Default is 2.
     device : str
         Device on which to work with model + data.
         Defaults to 'cuda' if torch.cuda.is_available is True.
     spect_scaler_path : str
         path to a saved SpectScaler object used to normalize spectrograms.
         If spectrograms were normalized and this is not provided, will give
         incorrect results.
     annot_csv_filename : str
         name of .csv file containing predicted annotations.
         Default is None, in which case the name of the dataset .csv
         is used, with '.annot.csv' appended to it.
     output_dir : str
         path to location where .csv containing predicted annotation
         should be saved. Defaults to current working directory.
     min_segment_dur : float
         minimum duration of segment, in seconds. If specified, then
         any segment with a duration less than min_segment_dur is
         removed from lbl_tb. Default is None, in which case no
         segments are removed.
     majority_vote : bool
         if True, transform segments containing multiple labels
         into segments with a single label by taking a "majority vote",
         i.e. assign all time bins in the segment the most frequently
         occurring label in the segment. This transform can only be
         applied if the labelmap contains an 'unlabeled' label,
         because unlabeled segments makes it possible to identify
         the labeled segments. Default is True.
     save_net_outputs : bool
         if True, save 'raw' outputs of neural networks
         before they are converted to annotations. Default is False.
         Typically the output will be "logits"
         to which a softmax transform might be applied.
         For each item in the dataset--each row in  the `csv_path` .csv--
         the output will be saved in a separate file in `output_dir`,
         with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a
         spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`,
         and the network is `TweetyNet`, then the net output file
         will be `gy6or6_032312_081416.tweetynet.output.npz`.
    """

    # required, external files
    checkpoint_path = attr.ib(converter=expanded_user_path,
                              validator=is_a_file)
    labelmap_path = attr.ib(converter=expanded_user_path, validator=is_a_file)

    # required, model / dataloader
    models = attr.ib(
        converter=comma_separated_list,
        validator=[instance_of(list), is_valid_model_name],
    )
    batch_size = attr.ib(converter=int, validator=instance_of(int))

    # csv_path is actually 'required' but we can't enforce that here because cli.prep looks at
    # what sections are defined to figure out where to add csv_path after it creates the csv
    csv_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        validator=validators.optional(is_a_file),
        default=None,
    )

    # optional, transform
    spect_scaler_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        validator=validators.optional(is_a_file),
        default=None,
    )

    # optional, data loader
    num_workers = attr.ib(validator=instance_of(int), default=2)
    device = attr.ib(validator=instance_of(str), default=device.get_default())

    annot_csv_filename = attr.ib(validator=validators.optional(
        instance_of(str)),
                                 default=None)
    output_dir = attr.ib(
        converter=expanded_user_path,
        validator=is_a_directory,
        default=Path(os.getcwd()),
    )
    min_segment_dur = attr.ib(validator=validators.optional(
        instance_of(float)),
                              default=None)
    majority_vote = attr.ib(validator=instance_of(bool), default=True)
    save_net_outputs = attr.ib(validator=instance_of(bool), default=False)
Code example #18
File: train.py Project: Luke-Poeppel/vak
class TrainConfig:
    """class that represents [TRAIN] section of config.toml file

    Attributes
    ----------
    models : list
        comma-separated list of model names.
        e.g., 'models = TweetyNet, GRUNet, ConvNet'
    csv_path : str
        path to where dataset was saved as a csv.
    num_epochs : int
        number of training epochs. One epoch = one iteration through the entire
        training set.
    batch_size : int
        number of samples per batch presented to models during training.
    root_results_dir : str
        directory in which results will be created.
        The vak.cli.train function will create
        a subdirectory in this directory each time it runs.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to torch.DataLoader.
    device : str
        Device on which to work with model + data.
        Defaults to 'cuda' if torch.cuda.is_available is True.
    shuffle: bool
        if True, shuffle training data before each epoch. Default is True.
    normalize_spectrograms : bool
        if True, use spect.utils.data.SpectScaler to normalize the spectrograms.
        Normalization is done by subtracting off the mean for each frequency bin
        of the training set and then dividing by the std for that frequency bin.
        This same normalization is then applied to validation + test data.
    val_step : int
        Step on which to estimate accuracy using validation set.
        If val_step is n, then validation is carried out every time
        the global step / n is a whole number, i.e., when val_step modulo the global step is 0.
        Default is None, in which case no validation is done.
    ckpt_step : int
        Step on which to save to checkpoint file.
        If ckpt_step is n, then a checkpoint is saved every time
        the global step / n is a whole number, i.e., when ckpt_step modulo the global step is 0.
        Default is None, in which case checkpoint is only saved at the last epoch.
    patience : int
        number of validation steps to wait without performance on the
        validation set improving before stopping the training.
        Default is None, in which case training only stops after the specified number of epochs.
    """

    # required
    models = attr.ib(
        converter=comma_separated_list,
        validator=[instance_of(list), is_valid_model_name],
    )
    num_epochs = attr.ib(converter=int, validator=instance_of(int))
    batch_size = attr.ib(converter=int, validator=instance_of(int))
    root_results_dir = attr.ib(converter=expanded_user_path, validator=is_a_directory)

    # optional
    # csv_path is actually 'required' but we can't enforce that here because cli.prep looks at
    # what sections are defined to figure out where to add csv_path after it creates the csv
    csv_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        validator=validators.optional(is_a_file),
        default=None,
    )

    results_dirname = attr.ib(
        converter=converters.optional(expanded_user_path),
        validator=validators.optional(is_a_directory),
        default=None,
    )

    normalize_spectrograms = attr.ib(
        converter=bool_from_str,
        validator=validators.optional(instance_of(bool)),
        default=False,
    )

    num_workers = attr.ib(validator=instance_of(int), default=2)
    device = attr.ib(validator=instance_of(str), default=device.get_default())
    shuffle = attr.ib(
        converter=bool_from_str, validator=instance_of(bool), default=True
    )

    val_step = attr.ib(
        converter=converters.optional(int),
        validator=validators.optional(instance_of(int)),
        default=None,
    )
    ckpt_step = attr.ib(
        converter=converters.optional(int),
        validator=validators.optional(instance_of(int)),
        default=None,
    )
    patience = attr.ib(
        converter=converters.optional(int),
        validator=validators.optional(instance_of(int)),
        default=None,
    )
Code example #19
class ProxyPlugin(StateMachine):
    'Proxy USB communications using a ubq_core enabled hardware device.'

    #: Address to listen to for USB device.
    _device_addr = attr.ib(converter=optional(str), default=None)

    #: Port to listen to for USB device.
    _device_port = attr.ib(converter=optional(int), default=None)

    #: Address to send to for USB host.
    _host_addr = attr.ib(converter=optional(str), default=None)

    #: Port to send to for USB host.
    _host_port = attr.ib(converter=optional(int), default=None)

    #: Timeout for select statement that waits for incoming USBQ packets
    timeout = attr.ib(converter=int, default=1)

    # States
    idle = State('idle', initial=True)
    running = State('running')

    # Valid state transitions
    start = idle.to(running)
    reset = running.to(idle) | idle.to(idle)
    reload = idle.to(running)

    EMPTY = []
    MANAGEMENT_MSG = {
        ManagementMessage.ManagementType.NEW_DEVICE:
        'New device connected to USBQ proxy',
        ManagementMessage.ManagementType.RESET:
        'Device reset sent from USBQ proxy',
    }

    def __attrs_post_init__(self):
        # Workaround to mesh attr and StateMachine
        super().__init__()
        self._socks = []
        self._proxy_host = True
        self._proxy_device = True
        self._device_dst = None
        self._detected_host = False
        self._detected_device = False

        if self._device_addr is None or self._device_port is None:
            self._proxy_device = False

        if self._host_addr is None or self._host_port is None:
            self._proxy_host = False

        if self._proxy_device:
            log.info(
                f'Device listen to {self._device_addr}:{self._device_port}')
            self._device_sock = socket.socket(socket.AF_INET,
                                              socket.SOCK_DGRAM)
            self._device_sock.setsockopt(socket.SOL_SOCKET,
                                         socket.SO_REUSEADDR, 1)
            self._device_sock.setblocking(False)
            self._device_sock.bind((self._device_addr, self._device_port))
            self._socks.append(self._device_sock)

        if self._proxy_host:
            log.info(f'Host send to {self._host_addr}:{self._host_port}')
            self._host_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            self._host_sock.setblocking(False)
            self._host_dst = (self._host_addr, self._host_port)
            self._socks.append(self._host_sock)

    def _has_data(self, socks, timeout=0):
        (read, write, error) = select.select(socks, self.EMPTY, socks, timeout)
        if len(read) != 0:
            return True
        return False

    @hookimpl
    def usbq_host_has_packet(self):
        if self._proxy_host:
            if self._has_data([self._host_sock]):
                return True

    @hookimpl
    def usbq_device_has_packet(self):
        if self._proxy_device:
            if self._has_data([self._device_sock]):
                return True

    @hookimpl
    def usbq_wait_for_packet(self):
        # Poll for data from non-proxy source
        queued_data = []
        if not self._proxy_host:
            queued_data += pm.hook.usbq_host_has_packet()
        if not self._proxy_device:
            queued_data += pm.hook.usbq_device_has_packet()

        if any(queued_data):
            return True
        else:
            # Wait
            if self._has_data(self._socks, timeout=self.timeout):
                return True

    @hookimpl
    def usbq_get_host_packet(self):
        data, self._host_dst = self._host_sock.recvfrom(4096)

        if not self._detected_host:
            log.info('First USBQ host packet detected from proxy')
            self._detected_host = True

        return data

    @hookimpl
    def usbq_get_device_packet(self):
        data, self._device_dst = self._device_sock.recvfrom(4096)

        if not self._detected_device:
            log.info('First USBQ device packet detected from proxy')
            self._detected_device = True

        return data

    @hookimpl
    def usbq_send_host_packet(self, data):
        return self._host_sock.sendto(data, self._host_dst) > 0

    @hookimpl
    def usbq_send_device_packet(self, data):
        if self._device_dst is not None:
            return self._device_sock.sendto(data, self._device_dst) > 0

    @hookimpl
    def usbq_log_pkt(self, pkt):
        if ManagementMessage in pkt:
            msg = self.MANAGEMENT_MSG.get(pkt.content.management_type, None)
            if msg is not None:
                log.info(msg)

            if pkt.content.management_type == ManagementMessage.ManagementType.RESET:
                self._detected_device = False

    def on_start(self):
        log.info('Starting proxy.')

    def _send_host_mgmt(self, pkt):
        data = pm.hook.usbq_host_encode(pkt=USBMessageDevice(
            type=USBMessageHost.MitmType.MANAGEMENT, content=pkt))
        self.usbq_send_host_packet(data)

    def _send_device_mgmt(self, pkt):
        data = pm.hook.usbq_device_encode(pkt=USBMessageHost(
            type=USBMessageDevice.MitmType.MANAGEMENT, content=pkt))
        self.usbq_send_device_packet(data)

    def on_reset(self):
        log.info('Reset device.')

        self._send_device_mgmt(
            ManagementMessage(
                management_type=ManagementMessage.ManagementType.RESET,
                management_content=ManagementReset(),
            ))

    def on_reload(self):
        log.info('Reload device.')

        self._send_device_mgmt(
            ManagementMessage(
                management_type=ManagementMessage.ManagementType.RELOAD,
                management_content=ManagementReload(),
            ))
Code example #20
class PredictConfig:
    """class that represents [PREDICT] section of config.toml file

    Attributes
    ----------
    csv_path : str
        path to where dataset was saved as a csv.
    checkpoint_path : str
        path to directory with checkpoint files saved by Torch, to reload model
    labelmap_path : str
        path to 'labelmap.json' file.
    annot_format : str
        format of annotations. Any format that can be used with the
        crowsetta library is valid.
    models : list
        of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet'
    batch_size : int
        number of samples per batch presented to models during training.
    num_workers : int
        Number of processes to use for parallel loading of data.
        Argument to torch.DataLoader. Default is 2.
    device : str
        Device on which to work with model + data.
        Defaults to 'cuda' if torch.cuda.is_available is True.
    spect_scaler_path : str
        path to a saved SpectScaler object used to normalize spectrograms.
        If spectrograms were normalized and this is not provided, will give
        incorrect results.
    to_format_kwargs : dict
        keyword arguments for crowsetta `to_format` function.
        Defined in .toml config file as a table.
        An example for the notmat annotation format (as a dictionary) is:
        {'min_syl_dur': 10., 'min_silent_dur': 6., 'threshold': 1500}.
    """
    # required, external files
    checkpoint_path = attr.ib(converter=expanded_user_path,
                              validator=is_a_file)
    labelmap_path = attr.ib(converter=expanded_user_path, validator=is_a_file)

    # required, for annotation
    annot_format = attr.ib(validator=is_annot_format)

    # required, model / dataloader
    models = attr.ib(converter=comma_separated_list,
                     validator=[instance_of(list), is_valid_model_name])
    batch_size = attr.ib(converter=int, validator=instance_of(int))

    # csv_path is actually 'required' but we can't enforce that here because cli.prep looks at
    # what sections are defined to figure out where to add csv_path after it creates the csv
    csv_path = attr.ib(converter=converters.optional(expanded_user_path),
                       validator=validators.optional(is_a_file),
                       default=None)

    # optional
    to_format_kwargs = attr.ib(validator=validators.optional(
        instance_of(dict)),
                               default=None)

    # optional, transform
    spect_scaler_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        validator=validators.optional(is_a_file),
        default=None)

    # optional, data loader
    num_workers = attr.ib(validator=instance_of(int), default=2)
    device = attr.ib(validator=instance_of(str), default=device.get_default())