class MySchema(ArgSchema):
    """Schema declaring one required field for each ``fields`` type below.

    Several attribute names (``dict``, ``float``, ``list``, ``slice``)
    shadow builtins inside the class body. They are kept unchanged
    because the attribute names are the input keys of the schema.
    """

    boolean = fields.Boolean(required=True)
    date = fields.Date(required=True)
    datetime = fields.DateTime(required=True)
    # Previously spelled ``requied=True``; the misspelled keyword is not
    # the ``required`` flag, so this field was not marked required like
    # its siblings.
    decimal = fields.Decimal(required=True)
    dict = fields.Dict(required=True)
    email = fields.Email(required=True)
    float = fields.Float(required=True)
    inputdir = fields.InputDir(required=True)
    inputfile = fields.InputFile(required=True)
    integer = fields.Int(required=True)
    list = fields.List(
        fields.Int,
        required=True,
        cli_as_single_argument=True,
    )
    localdatetime = fields.LocalDateTime(required=True)
    nested = fields.Nested(MyNestedSchema, required=True)
    number = fields.Number(required=True)
    numpyarray = fields.NumpyArray(dtype="uint8", required=True)
    outputdir = fields.OutputDir(required=True)
    outputfile = fields.OutputFile(required=True)
    raw = fields.Raw(required=True)
    slice = fields.Slice(required=True)
    string = fields.Str(required=True)
    time = fields.Time(required=True)
    timedelta = fields.TimeDelta(required=True)
    url = fields.URL(required=True)
    uuid = fields.UUID(required=True)
class ClassifierArtifactsInputSchema(ArgSchema):
    """Input arguments for generating classifier artifacts.

    Three required input files (video, ROI json, graph pickle), one
    required output directory, and optional generation settings.
    """

    # ---- Required inputs -------------------------------------------------
    video_path = fields.InputFile(
        description="Path to motion corrected(+denoised) video.",
        required=True,
    )
    roi_path = fields.InputFile(
        description="Path to json file containing detected ROIs",
        required=True,
    )
    graph_path = fields.InputFile(
        description="Path to pickle file containing full movie graph.",
        required=True,
    )

    # ---- Where the artifacts are written ---------------------------------
    out_dir = fields.OutputDir(
        description="Output directory to put artifacts.",
        required=True,
    )

    # ---- Optional generation settings ------------------------------------
    low_quantile = fields.Float(
        description="Low quantile to saturate/clip to.",
        default=0.2,
        required=False,
    )
    high_quantile = fields.Float(
        description="High quantile to saturate/clip to.",
        default=0.99,
        required=False,
    )
    cutout_size = fields.Int(
        description="Size of square cutout in pixels.",
        default=128,
        required=False,
    )
    # ``None`` is accepted and is also the default value.
    selected_rois = fields.List(
        fields.Int,
        description=(
            "Specific subset of ROIs by ROI id in the experiment FOV to "
            "produce artifacts for. Only ROIs specified in this will have "
            "artifacts output."
        ),
        default=None,
        allow_none=True,
        required=False,
    )
class DffJobSchema(ArgSchema):
    """Input arguments for the dF/F computation job.

    After loading, ``filter_s_to_frames`` adds two derived keys,
    ``short_filter_frames`` and ``long_filter_frames``, holding the
    median filter lengths in frames (always odd).
    """

    input_file = H5InputFile(
        required=True,
        description=("Input h5 file containing fluorescence traces and the "
                     "associated ROI IDs (in datasets specified by the keys "
                     "'input_dataset' and 'roi_field', respectively).")
    )
    output_file = fields.OutputFile(
        required=True,
        description="h5 file to write the results of dff computation."
    )
    movie_frame_rate_hz = fields.Float(
        required=True,
        description=("Acquisition frame rate for the trace data in "
                     "`input_dataset`")
    )
    log_level = fields.Int(
        required=False,
        default=20      # logging.INFO
    )
    input_dataset = fields.Str(
        required=False,
        default="FC",
        description="Key of h5 dataset to use from `input_file`."
    )
    roi_field = fields.Str(
        required=False,
        default="roi_names",
        description=("The h5 dataset key in both the `input_file` and "
                     "`output_file` containing ROI IDs associated with "
                     "traces.")
    )
    output_dataset = fields.Str(
        required=False,
        default="data",
        description=("h5 dataset key used to store the computed dff traces "
                     "in `output_file`.")
    )
    sigma_dataset = fields.Str(
        required=False,
        default="sigma_dff",
        description=("h5 dataset key used to store the estimated noise "
                     "standard deviation for the dff traces in `output_file`.")
    )
    baseline_frames_dataset = fields.Str(
        required=False,
        default="num_small_baseline_frames",
        description=("h5 dataset key used to store the number of small "
                     "baseline frames (where the computed baseline of the "
                     "fluorescence trace was smaller than its estimated "
                     "noise standard deviation) in `output_file`.")
    )
    long_baseline_filter_s = fields.Int(
        required=False,
        default=600,
        description=("Number of seconds to use in the rolling median "
                     "filter for computing the baseline activity. "
                     "The length of the filter is the frame rate of the "
                     "signal in Hz * the long baseline filter seconds ("
                     "+1 if the result is even, since the median filter "
                     "length must be odd).")
    )
    short_filter_s = fields.Float(
        required=False,
        default=3.333,
        description=("Number of seconds to use in the rolling median "
                     "filter for the short timescale detrending. "
                     "The length of the filter is the frame rate of the "
                     "signal in Hz * the short baseline filter seconds ("
                     "+1 if the result is even, since the median filter "
                     "length must be odd).")
    )
    n_parallel_workers = fields.Int(
        required=False,
        default=1,
        description="number of parallel workers")

    @post_load
    def filter_s_to_frames(self, item, **kwargs):
        """Convert the filter durations from seconds to odd frame counts.

        Reads ``movie_frame_rate_hz``, ``short_filter_s`` and
        ``long_baseline_filter_s`` from ``item`` and stores the results
        under ``short_filter_frames`` and ``long_filter_frames``.

        Parameters
        ----------
        item : dict
            The loaded data; modified in place.
        **kwargs
            Extra arguments passed by the ``post_load`` hook; unused.

        Returns
        -------
        dict
            The same ``item`` with the two frame-count keys added.
        """
        frame_rate = item["movie_frame_rate_hz"]

        def to_odd_frames(seconds):
            # Round to the nearest whole frame, then bump even counts by
            # one because the median filter length has to be odd.
            frames = int(np.round(frame_rate * seconds))
            return frames if frames % 2 else frames + 1

        item["short_filter_frames"] = to_odd_frames(item["short_filter_s"])
        item["long_filter_frames"] = to_odd_frames(
            item["long_baseline_filter_s"])
        return item
class InferenceInputSchema(ArgSchema):
    """Argschema parser for module as a script.

    Besides declaring the fields, the schema has three hooks:

    * ``determine_classifier_model_path`` (pre-load) fills in
      ``classifier_model_path`` from the model registry when the caller
      did not supply it.
    * ``validate_classifier_model_path`` checks that the model exists on
      S3 or on the local file system.
    * ``check_keys_exist`` (post-load) checks that the configured data
      keys are present in the two trace h5 files.
    """

    neuropil_traces_path = H5InputFile(
        required=True,
        description=(
            "Path to neuropil traces from an experiment (h5 format). "
            "The order of the traces in the dataset should correspond to "
            "the order of masks in `roi_masks_path`."))
    neuropil_traces_data_key = fields.Str(
        required=False,
        missing="data",
        description=("Key in `neuropil_traces_path` h5 file where data array "
                     "is stored."))
    neuropil_trace_names_key = fields.Str(
        required=False,
        missing="roi_names",
        # Trailing space added: the adjacent literals used to join as
        # "describesthe".
        description=("Key in `neuropil_traces_path` h5 file which describes "
                     "the roi name (id) associated with each trace."))
    traces_path = H5InputFile(
        required=True,
        description=(
            "Path to traces extracted from an experiment (h5 format). "
            "The order of the traces in the dataset should correspond to "
            "the order of masks in `roi_masks_path`."))
    traces_data_key = fields.Str(
        required=False,
        missing="data",
        description=("Key in `traces_path` h5 file where data array is "
                     "stored."))
    trace_names_key = fields.Str(
        required=False,
        missing="roi_names",
        # Trailing space added here too (same "describesthe" join).
        description=("Key in `traces_path` h5 file which describes "
                     "the roi name (id) associated with each trace."))
    roi_masks_path = fields.InputFile(
        required=True,
        description=("Path to json file of segmented ROI masks. The file "
                     "records must conform to the schema "
                     "`DenseROISchema`"))
    rig = fields.Str(
        required=True,
        description=("Name of the ophys rig used for the experiment."))
    depth = fields.Int(
        required=True,
        description=("Imaging depth for the experiment."))
    full_genotype = fields.Str(
        required=True,
        description=("Genotype of the experimental subject."))
    targeted_structure = fields.Str(
        required=True,
        description=("Name of the brain structure targeted by imaging."))
    # Declared required, but the pre-load hook below supplies a value
    # from the model registry when the key is absent from the input.
    classifier_model_path = fields.Str(
        required=True,
        description=("Path to model. Can either be an s3 location or a "
                     "path on the local file system. The output of the model "
                     "should be 0 if the ROI is classified as not a cell, "
                     "and 1 if the ROI is classified as a cell. If this "
                     "field is not provided, the classifier model registry "
                     "DynamoDB will be queried."))
    trace_sampling_rate = fields.Int(
        required=False,
        missing=31,
        description=("Sampling rate of trace (frames per second). By default "
                     "trace sampling rates are assumed to be 31 Hz (inherited "
                     "from the source motion_corrected.h5 movie)."))
    desired_trace_sampling_rate = fields.Int(
        required=False,
        missing=4,
        validate=lambda x: x > 0,
        description=("Target rate to downsample trace data (frames per "
                     "second). Will use average bin values for downsampling."))
    output_json = fields.OutputFile(
        required=True,
        description="Filepath to dump json output.")
    model_registry_table_name = fields.Str(
        required=False,
        missing="ROIClassifierRegistry",
        description=("The name of the DynamoDB table containing "
                     "the ROI classifier model registry."))
    model_registry_env = fields.Str(
        required=False,
        # A tuple (not a set) keeps the order of `{choices}` in the error
        # message stable between runs.
        validate=OneOf(('dev', 'stage', 'prod'),
                       error=("'{input}' is not a valid value for the "
                              "'model_registry_env' field. Possible "
                              "valid options are: {choices}")),
        missing="prod",
        description=(
            "Which environment to query when searching for a "
            "classifier model path from the classifier model "
            "registry. Possible options are: ['dev', 'stage', 'prod']"))

    # The options below are set by the LIMS queue but are not necessary to run
    # the code.
    motion_corrected_movie_path = fields.InputFile(
        required=False,
        default=None,
        allow_none=True,
        description=("Path to motion corrected video."))
    movie_frame_rate_hz = fields.Float(
        required=False,
        default=None,
        allow_none=True,
        description=("The frame rate (in Hz) of the optical physiology "
                     "movie to be Suite2P segmented. Used in conjunction "
                     "with 'bin_duration' to derive an 'nbinned' "
                     "Suite2P value."))

    @pre_load
    def determine_classifier_model_path(self, data: dict, **kwargs) -> dict:
        """Fill in ``classifier_model_path`` from the registry if absent.

        When the key is missing from ``data``, the active model for the
        requested environment is looked up through
        ``utils.RegistryConnection`` and stored in ``data`` (in place).

        Returns
        -------
        dict
            The same ``data`` mapping, possibly with
            ``classifier_model_path`` added.
        """
        if "classifier_model_path" not in data:
            # Can't rely on field `missing` param as it doesn't get filled in
            # until deserialization/validation. The get defaults should match
            # the 'missing' param for the model_registry_table_name and
            # model_registry_env fields.
            table_name = data.get("model_registry_table_name",
                                  "ROIClassifierRegistry")
            model_env = data.get("model_registry_env", "prod")
            model_registry = utils.RegistryConnection(table_name=table_name)
            model_path = model_registry.get_active_model(env=model_env)
            data["classifier_model_path"] = model_path
        return data

    @validates("classifier_model_path")
    def validate_classifier_model_path(self, uri: str, **kwargs):
        """Check to see if file exists (either s3 or local file).

        Raises
        ------
        ValidationError
            If an ``s3://`` URI answers with HTTP 404, or if a local
            path does not exist.
        ClientError
            Any other S3 client error is re-raised unchanged.
        """
        if uri.startswith("s3://"):
            s3 = boto3.client("s3")
            parsed = urlparse(uri, allow_fragments=False)
            try:
                s3.head_object(Bucket=parsed.netloc,
                               Key=parsed.path.lstrip("/"))
            except ClientError as e:
                if e.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
                    raise ValidationError(
                        f"Object at URI {uri} does not exist.")
                else:
                    raise e from None
        else:
            if not os.path.exists(uri):
                raise ValidationError(f"File at '{uri}' does not exist.")

    @post_load
    def check_keys_exist(self, data: dict, **kwargs) -> dict:
        """For h5 files, check that the passed key exists in the data.

        Only the two data keys are checked (``neuropil_traces_data_key``
        and ``traces_data_key``); the trace-name keys are not.

        Raises
        ------
        ValidationError
            If a data key is not present in its h5 file.
        """
        pairs = [("neuropil_traces_path", "neuropil_traces_data_key"),
                 ("traces_path", "traces_data_key")]
        for h5file, key in pairs:
            with h5py.File(data[h5file], "r") as f:
                if data[key] not in f:
                    raise ValidationError(
                        f"Key '{data[key]}' ({key}) was missing in h5 file "
                        f"{data[h5file]} ({h5file}).")
        return data