class ZillowHousingResponse(object):
    """Rent-estimate details for a single property, built from a Zillow response.

    Attributes:
        zpid: Value of the ``"zpid"`` key of the response.
        address: Result of ``ZillowAddress.from_dict`` on the ``"address"``
            entry, or ``None`` when that entry is missing.
        estimated_rent_price: ``decimal.Decimal`` built from the ``"#text"``
            entry of the rent estimate amount, or ``None``.
        estimated_rent_price_currency: The ``"@currency"`` entry of the rent
            estimate amount.
        last_estimated: ``"last-updated"`` entry of the rent estimate, passed
            through ``parse_zillow_datetime`` unless it is ``None``.
        url: The ``"mapthishome"`` entry of the ``"links"`` section.
    """

    zpid = attr.ib()
    address = attr.ib(converter=optional(ZillowAddress.from_dict))
    estimated_rent_price = attr.ib(converter=optional(decimal.Decimal))
    estimated_rent_price_currency = attr.ib()
    last_estimated = attr.ib(converter=optional(parse_zillow_datetime))
    url = attr.ib()

    @classmethod
    def from_zillow_response(cls, dic):
        """Build an instance from the mapping ``dic`` of a Zillow response.

        Missing keys are tolerated: every lookup falls back to ``None`` (or to
        an empty mapping for the nested sections), so absent data ends up as
        ``None`` attributes.
        """
        # A KeyError here means either section is absent; use an empty mapping
        # so the ``.get`` calls below simply yield None.
        try:
            amount = dic["rentzestimate"]["amount"]
        except KeyError:
            amount = {}

        rentzestimate = dic.get("rentzestimate", {})
        links = dic.get("links", {})

        return cls(
            zpid=dic.get("zpid"),
            address=dic.get("address"),
            estimated_rent_price=amount.get("#text"),
            estimated_rent_price_currency=amount.get("@currency"),
            last_estimated=rentzestimate.get("last-updated"),
            url=links.get("mapthishome"),
        )

    @property
    def estimated_rent_price_for_display(self):
        """Return the currency and the estimated rent joined by one space."""
        currency = self.estimated_rent_price_currency
        price = self.estimated_rent_price
        return f"{currency} {price}"
class EvalConfig: """class that represents [EVAL] section of config.toml file Attributes ---------- csv_path : str path to where dataset was saved as a csv. checkpoint_path : str path to directory with checkpoint files saved by Torch, to reload model output_dir : str Path to location where .csv files with evaluation metrics should be saved. labelmap_path : str path to 'labelmap.json' file. models : list of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet' batch_size : int number of samples per batch presented to models during training. num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. Default is 2. device : str Device on which to work with model + data. Defaults to 'cuda' if torch.cuda.is_available is True. spect_scaler_path : str path to a saved SpectScaler object used to normalize spectrograms. If spectrograms were normalized and this is not provided, will give incorrect results. """ # required, external files checkpoint_path = attr.ib(converter=expanded_user_path, validator=is_a_file) labelmap_path = attr.ib(converter=expanded_user_path, validator=is_a_file) output_dir = attr.ib(converter=expanded_user_path, validator=is_a_directory) # required, model / dataloader models = attr.ib(converter=comma_separated_list, validator=[instance_of(list), is_valid_model_name]) batch_size = attr.ib(converter=int, validator=instance_of(int)) # csv_path is actually 'required' but we can't enforce that here because cli.prep looks at # what sections are defined to figure out where to add csv_path after it creates the csv csv_path = attr.ib(converter=converters.optional(expanded_user_path), validator=validators.optional(is_a_file), default=None) # optional, transform spect_scaler_path = attr.ib( converter=converters.optional(expanded_user_path), validator=validators.optional(is_a_file), default=None) # optional, data loader num_workers = attr.ib(validator=instance_of(int), default=2) device = 
attr.ib(validator=instance_of(str), default=device.get_default())
def test_fail(self):
    """
    An error raised by the wrapped converter reaches the caller unchanged.
    """
    to_int = optional(int)

    with pytest.raises(ValueError):
        to_int("not_an_int")
def test_success_with_none(self):
    """
    None is passed through untouched instead of being converted.
    """
    to_int = optional(int)
    converted = to_int(None)

    assert converted is None
def test_success_with_type(self):
    """
    A value other than None is handed to the wrapped converter.
    """
    to_int = optional(int)
    converted = to_int("42")

    assert converted == 42
class Account(BaseAAEntity):
    """Entity with type name ``"Account"``.

    Fields:
        provider: converted with ``enums.ProviderType``.
        username: stored as given.
        access_token: stored as given; may be ``None`` per its type comment.
        access_token_expires_at: parsed with ``ciso8601.parse_datetime``
            unless the supplied value is ``None``.
    """
    TYPENAME = "Account"

    provider = attr.attrib(
        converter=enums.ProviderType)  # type: enums.ProviderType
    username = attr.attrib()  # type: str
    access_token = attr.attrib()  # type: Optional[str]
    # converters.optional leaves None untouched, so only real values are parsed.
    access_token_expires_at = attr.attrib(converter=converters.optional(
        ciso8601.parse_datetime), )  # type: Optional[datetime.datetime]
class SFTPClientOptions(object):
    """
    Client options for sending SFTP files.

    Every required value is coerced on assignment (``str`` or ``int``); the two
    optional values are coerced with ``str`` only when they are not ``None``.

    :param host: the host of the SFTP server
    :param port: the port of the SFTP server
    :param fingerprint: the expected fingerprint of the host
    :param user: the user to login as
    :param identity: the identity file, optional and like the "-i" command line option
    :param password: an optional password
    """
    host = attr.ib(converter=str)  # type: str
    port = attr.ib(converter=int)  # type: int
    fingerprint = attr.ib(converter=str)  # type: str
    user = attr.ib(converter=str)  # type: str
    identity = attr.ib(converter=optional(str), default=None)  # type: Optional[str]
    password = attr.ib(converter=optional(str), default=None)  # type: Optional[str]
class LearncurveConfig(TrainConfig):
    """class that represents [LEARNCURVE] section of config.toml file

    Only ``train_set_durs``, ``num_replicates`` and ``previous_run_path`` are
    declared in this class body; the remaining attributes listed below are
    expected to come from ``TrainConfig`` (NOTE(review): confirm that every
    one of them is actually defined there).

    Attributes
    ----------
    models : list
        of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet'
    csv_path : str
        path to where dataset was saved as a csv.
    num_epochs : int
        number of training epochs. One epoch = one iteration through the entire
        training set.
    normalize_spectrograms : bool
        if True, use spect.utils.data.SpectScaler to normalize the spectrograms.
        Normalization is done by subtracting off the mean for each frequency bin
        of the training set and then dividing by the std for that frequency bin.
        This same normalization is then applied to validation + test data.
    ckpt_step : int
        step/epoch at which to save to checkpoint file.
        Default is None, in which case checkpoint is only saved at the last epoch.
    patience : int
        number of epochs to wait without the error dropping before stopping the
        training. Default is None, in which case training continues for num_epochs
    train_set_durs : list
        of int, durations in seconds of subsets taken from training data
        to create a learning curve, e.g. [5, 10, 15, 20].
        Declared here as a keyword-only attribute without a default, so a list
        must be supplied. NOTE(review): earlier documentation described a
        default of None (when training a single model on all available
        training data); that does not match this declaration -- confirm.
    num_replicates : int
        number of times to replicate training for each training set duration
        to better estimate mean accuracy for a training set of that size.
        Each replicate uses a different randomly drawn subset of the training
        data (but of the same duration). Keyword-only, no default.
    save_only_single_checkpoint_file : bool
        if True, save only one checkpoint file instead of separate files every time
        we save. Default is True.
    use_train_subsets_from_previous_run : bool
        if True, use training subsets saved in a previous run. Default is False.
        Requires setting previous_run_path option in config.toml file.
    previous_run_path : str
        path to results directory from a previous run.
        Used for training if use_train_subsets_from_previous_run is True.
        Default is None.
    """
    train_set_durs = attr.ib(validator=instance_of(list), kw_only=True)
    num_replicates = attr.ib(validator=instance_of(int), kw_only=True)
    # None is accepted as-is; any other value is expanded and must name a directory.
    previous_run_path = attr.ib(
        converter=converters.optional(expanded_user_path),
        validator=validators.optional(is_a_directory),
        default=None,
    )
class SpectParamsConfig:
    """represents parameters for making spectrograms from audio and saving in files

    Attributes
    ----------
    fft_size : int
        size of window for Fast Fourier transform, number of time bins. Default is 512.
    step_size : int
        step size for Fast Fourier transform. Default is 64.
    freq_cutoffs : tuple
        of two elements, lower and higher frequencies. Used to bandpass filter audio
        (using a Butter filter) before generating spectrogram.
        Default is None, in which case no bandpass filtering is applied.
    transform_type : str
        one of {'log_spect', 'log_spect_plus_one'}.
        'log_spect' transforms the spectrogram to log(spectrogram), and
        'log_spect_plus_one' does the same thing but adds one to each element.
        Default is None. If None, no transform is applied.
    thresh : float
        threshold minimum power for log spectrogram.
        Converted with ``float`` when not None. Default is None.
    spect_key : str
        key for accessing spectrogram in files. Default is 's'.
    freqbins_key : str
        key for accessing vector of frequency bins in files. Default is 'f'.
    timebins_key : str
        key for accessing vector of time bins in files. Default is 't'.
    audio_path_key : str
        key for accessing path to source audio file for spectogram in files.
        Default is 'audio_path'.
    """
    fft_size = attr.ib(converter=int, validator=instance_of(int), default=512)
    step_size = attr.ib(converter=int, validator=instance_of(int), default=64)
    freq_cutoffs = attr.ib(
        validator=validators.optional(freq_cutoffs_validator), default=None
    )
    # Values other than None are coerced to float before validation.
    thresh = attr.ib(
        converter=converters.optional(float),
        validator=validators.optional(instance_of(float)),
        default=None,
    )
    transform_type = attr.ib(
        validator=validators.optional([instance_of(str), is_valid_transform_type]),
        default=None,
    )
    spect_key = attr.ib(validator=instance_of(str), default="s")
    freqbins_key = attr.ib(validator=instance_of(str), default="f")
    timebins_key = attr.ib(validator=instance_of(str), default="t")
    audio_path_key = attr.ib(validator=instance_of(str), default="audio_path")
class FeedlyEntry:
    """One feed entry assembled from an upstream JSON item.

    ``from_upstream`` is the usual way to build an instance: it copies every
    truthy item value whose key matches a field name, derives ``url`` and
    ``source`` from the item, and registers the entry's summary markup.
    """

    url: str = attr.ib(validator=instance_of(str))
    source: str = attr.ib(repr=False)
    published: datetime = attr.ib(converter=utils.datetime_converters)
    updated: datetime = attr.ib(
        default=None, converter=optional(utils.datetime_converters), repr=False)
    keywords: Keywords = attr.ib(
        converter=utils.ensure_collection(lowercase_set),
        factory=lowercase_set,
        repr=False)
    author: Optional[str] = attr.ib(default='', repr=False)
    title: Optional[str] = attr.ib(default='', repr=False)
    markup: Dict[str, str] = attr.ib(factory=dict, repr=False)
    hyperlinks: HyperlinkStore = attr.ib(factory=HyperlinkStore, repr=False)

    @classmethod
    def from_upstream(cls, item: JSONDict) -> FeedlyEntry:
        """Create an entry from the upstream JSON object ``item``."""
        # Falsy values are skipped so that the field defaults apply instead.
        looked_up = ((name, item.get(name)) for name in attr.fields_dict(cls))
        kwargs = {name: value for name, value in looked_up if value}

        # url and source are always derived, overriding anything copied above.
        kwargs['url'] = cls._get_page_url(item)
        kwargs['source'] = cls._get_source_url(item)

        entry = cls(**kwargs)
        cls._set_markup(entry, item)
        return entry

    @staticmethod
    def _get_page_url(item):
        """Pick the page URL: the first 'alternate' href, else a usable 'originId'."""
        origin = urlsplit(item.get('originId', ''))
        page_url = origin.geturl() if origin.netloc else ''

        alternates = item.get('alternate')
        if alternates and alternates != 'none':
            page_url = alternates[0]['href']
        return page_url

    @staticmethod
    def _get_source_url(item):
        """Return the feed URI for the item's 'origin', or '' when there is none."""
        origin = item.get('origin')
        if not origin:
            return ''
        return get_feed_uri(origin.get('streamId', '/'))

    @staticmethod
    def _set_markup(entry, item):
        """Register the item's 'content' (or else 'summary') text as 'summary' markup."""
        container = item.get('content', item.get('summary'))
        if not container:
            return

        html = container.get('content')
        if html:
            entry.add_markup('summary', html)

    @staticmethod
    def _filter_attrib(attrib: attr.Attribute, value: Any) -> bool:
        """Keep only attributes whose name does not begin with an underscore."""
        first_char = attrib.name[0]
        return first_char != '_'

    def add_markup(self, name, markup):
        """Store ``markup`` under ``name`` and feed it to the hyperlink store."""
        self.markup[name] = markup
        self.hyperlinks.parse_html(self.url, markup)

    def for_json(self) -> JSONDict:
        """Return the entry as a dict, leaving out underscore-prefixed attributes."""
        return attr.asdict(self, filter=self._filter_attrib)
class SMTPConfig(ConfigBase):
    """SMTP connection settings.

    Fields:
        host: required; must be a ``str``.
        port: must be a ``str`` (not an int); defaults to ``"465"``.
        tls: passed through ``SMTPTLSConfig``; defaults to ``"tls"``.
        credentials: built with ``Credentials.from_dict`` unless ``None``;
            defaults to ``None``.
    """
    host = attr.ib(validator=instance_of(str))
    # The port is validated as a string, so numeric values must be quoted.
    port = attr.ib(default="465", validator=instance_of(str))
    tls = attr.ib(converter=SMTPTLSConfig, default="tls")
    credentials = attr.ib(converter=optional(Credentials.from_dict), default=None)
class PrepConfig:
    """class to represent [PREP] section of config.toml file

    Attributes
    ----------
    data_dir : str
        path to directory with files from which to make dataset
    output_dir : str
        Path to location where data sets should be saved.
        Declared below without a default, so a value must be supplied and it
        must pass the ``is_a_directory`` validator.
        NOTE(review): earlier documentation described a default of None that
        falls back to the current working directory; that does not match this
        declaration -- confirm which is intended.
    audio_format : str
        format of audio files. One of {'wav', 'cbin'}. Default is None.
    spect_format : str
        format of files containing spectrograms as 2-d matrices.
        One of {'mat', 'npy'}. Default is None.
    annot_format : str
        format of annotations. Any format that can be used with the
        crowsetta library is valid. Default is None.
    annot_file : str
        Path to a single annotation file. Default is None.
        Used when a single file contains annotations for multiple audio files.
    labelset : set
        of str or int, the set of labels that correspond to annotated segments
        that a network should learn to segment and classify. Note that if there
        are segments that are not annotated, e.g. silent gaps between songbird
        syllables, then `vak` will assign a dummy label to those segments --
        you don't have to give them a label here. Default is None.
    train_dur : float
        total duration of training set, in seconds. When creating a learning curve,
        training subsets of shorter duration (specified by the 'train_set_durs'
        option in the LEARNCURVE section of a config.toml file) will be drawn
        from this set. Default is None.
    val_dur : float
        total duration of validation set, in seconds. Default is None.
    test_dur : float
        total duration of test set, in seconds. Default is None.
    """
    data_dir = attr.ib(converter=expanded_user_path, validator=is_a_directory)
    output_dir = attr.ib(converter=expanded_user_path, validator=is_a_directory)

    audio_format = attr.ib(validator=validators.optional(is_audio_format), default=None)
    spect_format = attr.ib(validator=validators.optional(is_spect_format), default=None)
    annot_file = attr.ib(converter=converters.optional(expanded_user_path),
                         validator=validators.optional(is_a_file),
                         default=None)
    annot_format = attr.ib(validator=validators.optional(is_annot_format), default=None)

    # The TOML value is converted before validation; None skips both steps.
    labelset = attr.ib(converter=converters.optional(labelset_from_toml_value),
                       validator=validators.optional(instance_of(set)),
                       default=None)

    train_dur = attr.ib(
        converter=converters.optional(duration_from_toml_value),
        validator=validators.optional(is_valid_duration),
        default=None)
    val_dur = attr.ib(converter=converters.optional(duration_from_toml_value),
                      validator=validators.optional(is_valid_duration),
                      default=None)
    test_dur = attr.ib(converter=converters.optional(duration_from_toml_value),
                       validator=validators.optional(is_valid_duration),
                       default=None)
class Outage(object):
    """An outage.

    Attributes:
        start (:class:`~arrow.arrow.Arrow`): the start time of the outage
            period.
        finish (:class:`~arrow.arrow.Arrow`): the ending time of the outage
            period.
        before (:class:`~arrow.arrow.Arrow`, optional): last time check was ok,
            if available.
        after (:class:`~arrow.arrow.Arrow`, optional): next time check was ok,
            if available.
        meta: (dict, optional): arbitrary metadata about this outage.
    """

    # ``converter=`` is the supported spelling; the older ``convert=`` keyword
    # was deprecated by attrs and later removed. ``optional`` lets None pass
    # through without calling ``arrow.get``.
    start = attr.ib(converter=optional(arrow.get))
    finish = attr.ib(converter=optional(arrow.get))
    before = attr.ib(converter=optional(arrow.get), default=None)
    after = attr.ib(converter=optional(arrow.get), default=None)
    # A factory gives every instance its own dict.
    meta = attr.ib(default=attr.Factory(dict))

    def for_json(self):
        """Return a representation of this object as a dict.

        Example:

            >>> t = arrow.utcnow()
            >>> Outage(t, t).for_json() == {
            ...     "start": t,
            ...     "finish": t,
            ...     "before": None,
            ...     "after": None,
            ...     "meta": {}
            ... }
            True

        Returns:
            dict: a mapping from field name to field value.
        """
        return attr.asdict(self)

    @property
    def humanized_duration(self):
        """Distance between ``start`` and ``finish`` as humanized text."""
        return self.start.humanize(other=self.finish, only_distance=True)

    def humanize(self):
        """Return a dict with the formatted begin, end and duration."""
        return {
            'Begin': self.start.format(),
            'End': self.finish.format(),
            'Duration': self.humanized_duration,
        }

    @classmethod
    def fields(cls):
        """Return the field names for this class.

        Example:

            >>> Outage.fields()
            ['start', 'finish', 'before', 'after', 'meta']

        Returns:
            list: a list of names as (:class:`str`) instances.
        """
        return list(map(attrgetter('name'), attr.fields(cls)))
class Param:
    """Specification of a parameter that is to be constrained.

    Attributes
    ----------
    name
        Name of the parameter.
    _min, _max
        Hard bounds, passed to the initializer as ``min`` and ``max``
        (see :meth:`new`). Default to ``-inf`` and ``inf``.
    prior
        Frozen scipy distribution, or None. Defaults to a uniform
        distribution over ``[_min, _max]`` when both bounds are finite,
        otherwise None.
    fiducial
        Optional value, converted with ``float`` when given.
    latex
        Defaults to ``texify(name)``.
    ref
        Object used by :meth:`generate_ref`; defaults to ``prior``.
    determines
        Tuple of names; defaults to ``(name,)``.
    transforms
        Tuple of callables or None, one per entry of ``determines`` by
        default; None entries pass the value through unchanged.
    """

    name = attr.ib()
    _min: numeric = attr.ib(-np.inf)
    _max: numeric = attr.ib(np.inf)
    prior = attr.ib(
        kw_only=True,
        validator=vld.optional(vld.instance_of(stats.distributions.rv_frozen)),
    )
    fiducial = attr.ib(
        None,
        type=float,
        converter=cnv.optional(float),
        validator=vld.optional(vld.instance_of(float)),
        kw_only=True,
    )
    latex = attr.ib(kw_only=True)
    ref = attr.ib(kw_only=True)
    determines = attr.ib(
        converter=tuplify,
        kw_only=True,
        validator=vld.deep_iterable(vld.instance_of(str)),
    )
    transforms = attr.ib(converter=tuplify, kw_only=True)

    @latex.default
    def _ltx_default(self):
        return texify(self.name)

    @ref.default
    def _ref_default(self):
        return self.prior

    @prior.default
    def _prior_default(self) -> stats.distributions.rv_frozen | None:
        # Without two finite bounds there is no uniform distribution to build.
        if np.isinf(self._min) or np.isinf(self._max):
            return None
        return stats.uniform(self._min, self._max - self._min)

    @determines.default
    def _determines_default(self):
        return (self.name,)

    @transforms.default
    def _transforms_default(self):
        return (None,) * len(self.determines)

    @transforms.validator
    def _transforms_validator(self, attribute, value):
        for val in value:
            if val is not None and not callable(val):
                raise TypeError("transforms must be a list of callables")

    @property
    def min(self) -> float:
        """The minimum boundary of the prior, helpful for constraints."""
        if self.prior is None:
            return self._min
        elif isinstance(self.prior, type(stats.uniform(0, 1))):
            return self.prior.support()[0]
        else:
            return -np.inf

    @property
    def max(self) -> float:
        """The maximum boundary of the prior, helpful for constraints."""
        if self.prior is None:
            return self._max
        elif isinstance(self.prior, type(stats.uniform(0, 1))):
            return self.prior.support()[1]
        else:
            return np.inf

    @cached_property
    def is_alias(self):
        """True when no entry of ``transforms`` is a callable."""
        return all(pm is None for pm in self.transforms)

    @cached_property
    def is_pure_alias(self):
        """True for an alias that determines exactly one name."""
        return self.is_alias and len(self.determines) == 1

    def transform(self, val):
        """Yield ``val`` passed through each transform (None means identity)."""
        for pm in self.transforms:
            if pm is None:
                yield val
            else:
                yield pm(val)

    def generate_ref(self, n=1):
        """Draw ``n`` reference values from ``ref``.

        ``ref`` is first tried as an object with an ``rvs(size=n)`` method and
        then as a callable taking ``size=n``.

        Raises
        ------
        ValueError
            If ``ref`` is None, or if a drawn value lies outside the domain
            of the parameter.
        TypeError
            If ``ref`` can be used neither way.
        """
        if self.ref is None:
            raise ValueError("Must specify a valid function for ref to generate refs.")

        try:
            ref = self.ref.rvs(size=n)
        except AttributeError:
            try:
                ref = self.ref(size=n)
            except TypeError as err:
                raise TypeError(
                    f"parameter '{self.name}' does not have a valid value for ref"
                ) from err

        if self.prior is None:
            # An explicit ref may be given without a prior; the domain is then
            # the hard [_min, _max] interval, exactly as in logprior().
            drawn = np.asarray(ref)
            outside_domain = np.any((drawn < self._min) | (drawn > self._max))
        else:
            outside_domain = np.any(self.prior.pdf(ref) == 0)

        if outside_domain:
            raise ValueError(
                f"param {self.name} produced a reference value outside its domain."
            )

        return ref

    def logprior(self, val):
        """Return the log-prior of ``val``.

        Without a prior this is 0 inside ``[_min, _max]`` and ``-inf`` outside.
        """
        if self.prior is None:
            if self._min > val or self._max < val:
                return -np.inf
            else:
                return 0

        return self.prior.logpdf(val)

    def clone(self, **kwargs):
        """Return a copy of this instance with ``kwargs`` applied as changes."""
        return attr.evolve(self, **kwargs)

    def new(self, p: Parameter) -> Param:
        """Create a new :class:`Param`.

        Any missing info from this instance filled in by the given instance.
        """
        assert isinstance(p, Parameter)
        assert self.determines == (p.name,)

        # NOTE(review): while assertions are enabled, the assert above already
        # rejects more than one determined name, so this branch is only
        # reachable under ``python -O``.
        if len(self.determines) > 1:
            raise ValueError("Cannot create new Param if it is not just an alias")

        default_range = (list(self.transform(p.min))[0], list(self.transform(p.max))[0])

        return Param(
            name=self.name,
            min=max(self._min, min(default_range)),
            max=min(self._max, max(default_range)),
            fiducial=self.fiducial if self.fiducial is not None else p.fiducial,
            latex=self.latex
            if (self.latex != self.name or self.name != p.name)
            else p.latex,
            # attr.NOTHING makes the initializer fall back to the field default.
            ref=self.ref or attr.NOTHING,
            prior=self.prior or attr.NOTHING,
            determines=self.determines,
            transforms=self.transforms,
        )

    def __getstate__(self):
        """Obtain a simple input state of the class that can initialize it."""
        out = attr.asdict(self)

        # Entries that merely repeat their defaults are dropped.
        if self.transforms == (None,):
            del out["transforms"]
        if self.ref is None:
            del out["ref"]
        if self.determines == (self.name,):
            del out["determines"]
        if self.latex == self.name:
            del out["latex"]

        return out

    def as_dict(self):
        """Simple representation of the class as a dict.

        No "name" is included in the dict.
        """
        out = self.__getstate__()
        del out["name"]
        return out
class PredictConfig: """class that represents [PREDICT] section of config.toml file Attributes ---------- csv_path : str path to where dataset was saved as a csv. checkpoint_path : str path to directory with checkpoint files saved by Torch, to reload model labelmap_path : str path to 'labelmap.json' file. models : list of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet' batch_size : int number of samples per batch presented to models during training. num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. Default is 2. device : str Device on which to work with model + data. Defaults to 'cuda' if torch.cuda.is_available is True. spect_scaler_path : str path to a saved SpectScaler object used to normalize spectrograms. If spectrograms were normalized and this is not provided, will give incorrect results. annot_csv_filename : str name of .csv file containing predicted annotations. Default is None, in which case the name of the dataset .csv is used, with '.annot.csv' appended to it. output_dir : str path to location where .csv containing predicted annotation should be saved. Defaults to current working directory. min_segment_dur : float minimum duration of segment, in seconds. If specified, then any segment with a duration less than min_segment_dur is removed from lbl_tb. Default is None, in which case no segments are removed. majority_vote : bool if True, transform segments containing multiple labels into segments with a single label by taking a "majority vote", i.e. assign all time bins in the segment the most frequently occurring label in the segment. This transform can only be applied if the labelmap contains an 'unlabeled' label, because unlabeled segments makes it possible to identify the labeled segments. Default is False. save_net_outputs : bool if True, save 'raw' outputs of neural networks before they are converted to annotations. Default is False. 
Typically the output will be "logits" to which a softmax transform might be applied. For each item in the dataset--each row in the `csv_path` .csv-- the output will be saved in a separate file in `output_dir`, with the extension `{MODEL_NAME}.output.npz`. E.g., if the input is a spectrogram with `spect_path` filename `gy6or6_032312_081416.npz`, and the network is `TweetyNet`, then the net output file will be `gy6or6_032312_081416.tweetynet.output.npz`. """ # required, external files checkpoint_path = attr.ib(converter=expanded_user_path, validator=is_a_file) labelmap_path = attr.ib(converter=expanded_user_path, validator=is_a_file) # required, model / dataloader models = attr.ib( converter=comma_separated_list, validator=[instance_of(list), is_valid_model_name], ) batch_size = attr.ib(converter=int, validator=instance_of(int)) # csv_path is actually 'required' but we can't enforce that here because cli.prep looks at # what sections are defined to figure out where to add csv_path after it creates the csv csv_path = attr.ib( converter=converters.optional(expanded_user_path), validator=validators.optional(is_a_file), default=None, ) # optional, transform spect_scaler_path = attr.ib( converter=converters.optional(expanded_user_path), validator=validators.optional(is_a_file), default=None, ) # optional, data loader num_workers = attr.ib(validator=instance_of(int), default=2) device = attr.ib(validator=instance_of(str), default=device.get_default()) annot_csv_filename = attr.ib(validator=validators.optional( instance_of(str)), default=None) output_dir = attr.ib( converter=expanded_user_path, validator=is_a_directory, default=Path(os.getcwd()), ) min_segment_dur = attr.ib(validator=validators.optional( instance_of(float)), default=None) majority_vote = attr.ib(validator=instance_of(bool), default=True) save_net_outputs = attr.ib(validator=instance_of(bool), default=False)
class TrainConfig: """class that represents [TRAIN] section of config.toml file Attributes ---------- models : list comma-separated list of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet' csv_path : str path to where dataset was saved as a csv. num_epochs : int number of training epochs. One epoch = one iteration through the entire training set. batch_size : int number of samples per batch presented to models during training. root_results_dir : str directory in which results will be created. The vak.cli.train function will create a subdirectory in this directory each time it runs. num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. device : str Device on which to work with model + data. Defaults to 'cuda' if torch.cuda.is_available is True. shuffle: bool if True, shuffle training data before each epoch. Default is True. normalize_spectrograms : bool if True, use spect.utils.data.SpectScaler to normalize the spectrograms. Normalization is done by subtracting off the mean for each frequency bin of the training set and then dividing by the std for that frequency bin. This same normalization is then applied to validation + test data. val_step : int Step on which to estimate accuracy using validation set. If val_step is n, then validation is carried out every time the global step / n is a whole number, i.e., when val_step modulo the global step is 0. Default is None, in which case no validation is done. ckpt_step : int Step on which to save to checkpoint file. If ckpt_step is n, then a checkpoint is saved every time the global step / n is a whole number, i.e., when ckpt_step modulo the global step is 0. Default is None, in which case checkpoint is only saved at the last epoch. patience : int number of validation steps to wait without performance on the validation set improving before stopping the training. Default is None, in which case training only stops after the specified number of epochs. 
""" # required models = attr.ib( converter=comma_separated_list, validator=[instance_of(list), is_valid_model_name], ) num_epochs = attr.ib(converter=int, validator=instance_of(int)) batch_size = attr.ib(converter=int, validator=instance_of(int)) root_results_dir = attr.ib(converter=expanded_user_path, validator=is_a_directory) # optional # csv_path is actually 'required' but we can't enforce that here because cli.prep looks at # what sections are defined to figure out where to add csv_path after it creates the csv csv_path = attr.ib( converter=converters.optional(expanded_user_path), validator=validators.optional(is_a_file), default=None, ) results_dirname = attr.ib( converter=converters.optional(expanded_user_path), validator=validators.optional(is_a_directory), default=None, ) normalize_spectrograms = attr.ib( converter=bool_from_str, validator=validators.optional(instance_of(bool)), default=False, ) num_workers = attr.ib(validator=instance_of(int), default=2) device = attr.ib(validator=instance_of(str), default=device.get_default()) shuffle = attr.ib( converter=bool_from_str, validator=instance_of(bool), default=True ) val_step = attr.ib( converter=converters.optional(int), validator=validators.optional(instance_of(int)), default=None, ) ckpt_step = attr.ib( converter=converters.optional(int), validator=validators.optional(instance_of(int)), default=None, ) patience = attr.ib( converter=converters.optional(int), validator=validators.optional(instance_of(int)), default=None, )
class ProxyPlugin(StateMachine):
    'Proxy USB communications using a ubq_core enabled hardware device.'

    #: Address to listen to for USB device.
    _device_addr = attr.ib(converter=optional(str), default=None)

    #: Port to listen to for USB device.
    _device_port = attr.ib(converter=optional(int), default=None)

    #: Address to send to for USB host.
    _host_addr = attr.ib(converter=optional(str), default=None)

    #: Port to send to for USB host.
    _host_port = attr.ib(converter=optional(int), default=None)

    #: Timeout for select statement that waits for incoming USBQ packets
    timeout = attr.ib(converter=int, default=1)

    # States
    idle = State('idle', initial=True)
    running = State('running')

    # Valid state transitions
    start = idle.to(running)
    reset = running.to(idle) | idle.to(idle)
    reload = idle.to(running)

    EMPTY = []
    MANAGEMENT_MSG = {
        ManagementMessage.ManagementType.NEW_DEVICE: 'New device connected to USBQ proxy',
        ManagementMessage.ManagementType.RESET: 'Device reset sent from USBQ proxy',
    }

    def __attrs_post_init__(self):
        """Finish construction: set up state and open the UDP sockets in use."""
        # Workaround to mesh attr and StateMachine
        super().__init__()

        self._socks = []
        self._device_dst = None
        self._detected_host = False
        self._detected_device = False

        # A side is proxied only when both its address and its port are given.
        self._proxy_device = not (
            self._device_addr is None or self._device_port is None)
        self._proxy_host = not (
            self._host_addr is None or self._host_port is None)

        if self._proxy_device:
            log.info(
                f'Device listen to {self._device_addr}:{self._device_port}')
            device_sock = self._device_sock = socket.socket(
                socket.AF_INET, socket.SOCK_DGRAM)
            device_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            device_sock.setblocking(False)
            device_sock.bind((self._device_addr, self._device_port))
            self._socks.append(device_sock)

        if self._proxy_host:
            log.info(f'Host send to {self._host_addr}:{self._host_port}')
            host_sock = self._host_sock = socket.socket(
                socket.AF_INET, socket.SOCK_DGRAM)
            host_sock.setblocking(False)
            self._host_dst = (self._host_addr, self._host_port)
            self._socks.append(host_sock)

    def _has_data(self, socks, timeout=0):
        """Return True when select() reports one of ``socks`` as readable."""
        readable, _writable, _errored = select.select(
            socks, self.EMPTY, socks, timeout)
        return len(readable) != 0

    @hookimpl
    def usbq_host_has_packet(self):
        # Falls through (returning None) when this plugin has nothing to report.
        if self._proxy_host and self._has_data([self._host_sock]):
            return True

    @hookimpl
    def usbq_device_has_packet(self):
        # Falls through (returning None) when this plugin has nothing to report.
        if self._proxy_device and self._has_data([self._device_sock]):
            return True

    @hookimpl
    def usbq_wait_for_packet(self):
        # Poll for data from non-proxy source
        pending = []
        if not self._proxy_host:
            pending.extend(pm.hook.usbq_host_has_packet())
        if not self._proxy_device:
            pending.extend(pm.hook.usbq_device_has_packet())

        # Only wait on our own sockets when nothing is queued elsewhere.
        if any(pending) or self._has_data(self._socks, timeout=self.timeout):
            return True

    @hookimpl
    def usbq_get_host_packet(self):
        payload, self._host_dst = self._host_sock.recvfrom(4096)

        if not self._detected_host:
            log.info('First USBQ host packet detected from proxy')
            self._detected_host = True

        return payload

    @hookimpl
    def usbq_get_device_packet(self):
        payload, self._device_dst = self._device_sock.recvfrom(4096)

        if not self._detected_device:
            log.info('First USBQ device packet detected from proxy')
            self._detected_device = True

        return payload

    @hookimpl
    def usbq_send_host_packet(self, data):
        sent = self._host_sock.sendto(data, self._host_dst)
        return sent > 0

    @hookimpl
    def usbq_send_device_packet(self, data):
        # No destination is known until a device packet has been received.
        if self._device_dst is None:
            return None
        sent = self._device_sock.sendto(data, self._device_dst)
        return sent > 0

    @hookimpl
    def usbq_log_pkt(self, pkt):
        if ManagementMessage not in pkt:
            return

        mgmt_type = pkt.content.management_type
        message = self.MANAGEMENT_MSG.get(mgmt_type, None)
        if message is not None:
            log.info(message)

        if mgmt_type == ManagementMessage.ManagementType.RESET:
            self._detected_device = False

    def on_start(self):
        log.info('Starting proxy.')

    def _send_host_mgmt(self, pkt):
        """Wrap ``pkt`` as a management message, encode it and send it to the host."""
        wrapped = USBMessageDevice(
            type=USBMessageHost.MitmType.MANAGEMENT, content=pkt)
        encoded = pm.hook.usbq_host_encode(pkt=wrapped)
        self.usbq_send_host_packet(encoded)

    def _send_device_mgmt(self, pkt):
        """Wrap ``pkt`` as a management message, encode it and send it to the device."""
        wrapped = USBMessageHost(
            type=USBMessageDevice.MitmType.MANAGEMENT, content=pkt)
        encoded = pm.hook.usbq_device_encode(pkt=wrapped)
        self.usbq_send_device_packet(encoded)

    def on_reset(self):
        log.info('Reset device.')
        reset_msg = ManagementMessage(
            management_type=ManagementMessage.ManagementType.RESET,
            management_content=ManagementReset(),
        )
        self._send_device_mgmt(reset_msg)

    def on_reload(self):
        log.info('Reload device.')
        reload_msg = ManagementMessage(
            management_type=ManagementMessage.ManagementType.RELOAD,
            management_content=ManagementReload(),
        )
        self._send_device_mgmt(reload_msg)
class PredictConfig: """class that represents [PREDICT] section of config.toml file Attributes ---------- csv_path : str path to where dataset was saved as a csv. checkpoint_path : str path to directory with checkpoint files saved by Torch, to reload model labelmap_path : str path to 'labelmap.json' file. annot_format : str format of annotations. Any format that can be used with the crowsetta library is valid. models : list of model names. e.g., 'models = TweetyNet, GRUNet, ConvNet' batch_size : int number of samples per batch presented to models during training. num_workers : int Number of processes to use for parallel loading of data. Argument to torch.DataLoader. Default is 2. device : str Device on which to work with model + data. Defaults to 'cuda' if torch.cuda.is_available is True. spect_scaler_path : str path to a saved SpectScaler object used to normalize spectrograms. If spectrograms were normalized and this is not provided, will give incorrect results. to_format_kwargs : dict keyword arguments for crowsetta `to_format` function. Defined in .toml config file as a table. An example for the notmat annotation format (as a dictionary) is: {'min_syl_dur': 10., 'min_silent_dur', 6., 'threshold': 1500}. 
""" # required, external files checkpoint_path = attr.ib(converter=expanded_user_path, validator=is_a_file) labelmap_path = attr.ib(converter=expanded_user_path, validator=is_a_file) # required, for annotation annot_format = attr.ib(validator=is_annot_format) # required, model / dataloader models = attr.ib(converter=comma_separated_list, validator=[instance_of(list), is_valid_model_name]) batch_size = attr.ib(converter=int, validator=instance_of(int)) # csv_path is actually 'required' but we can't enforce that here because cli.prep looks at # what sections are defined to figure out where to add csv_path after it creates the csv csv_path = attr.ib(converter=converters.optional(expanded_user_path), validator=validators.optional(is_a_file), default=None) # optional to_format_kwargs = attr.ib(validator=validators.optional( instance_of(dict)), default=None) # optional, transform spect_scaler_path = attr.ib( converter=converters.optional(expanded_user_path), validator=validators.optional(is_a_file), default=None) # optional, data loader num_workers = attr.ib(validator=instance_of(int), default=2) device = attr.ib(validator=instance_of(str), default=device.get_default())