def _prepare_traces(self, run_modes: Set[str]) -> None:
    """Build `self.traces_in_use` from the user traces plus the framework's default traces.

    The user traces are copied, and a Logger is appended when `self.system.log_steps` is not None. When training, a
    TrainEssential trace is placed at the front of the list, and a warning is printed if no ModelSaver / BestModelSaver
    is active for `run_modes`. When evaluating (and the pipeline offers an 'eval' mode), an EvalEssential trace is
    inserted at index 1. Finally, every trace active for `run_modes` receives a reference to `self.system`.

    Args:
        run_modes: The current execution modes.
    """
    self.traces_in_use = list(self.traces)
    if self.system.log_steps is not None:
        self.traces_in_use.append(Logger())

    if "train" in run_modes:
        self.traces_in_use.insert(0, TrainEssential(monitor_names=self.monitor_names))
        active_traces = get_current_items(self.traces_in_use, run_modes=run_modes)
        if not any(isinstance(trace, (ModelSaver, BestModelSaver)) for trace in active_traces):
            print("FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.")

    wants_eval = "eval" in run_modes and "eval" in self.pipeline.get_modes()
    if wants_eval:
        self.traces_in_use.insert(1, EvalEssential(monitor_names=self.monitor_names))

    # Give every active trace access to the shared system object.
    for trace in get_current_items(self.traces_in_use, run_modes=run_modes):
        trace.system = self.system
def _prepare_traces(self, run_modes: Set[str]) -> None:
    """Build `self.traces_in_use` from the user traces plus the framework's default traces.

    The user traces are copied, and a Logger is appended when `self.system.log_steps` is not None. The active traces
    are then walked in sorted order. Any name in a trace's `fe_monitor_names` that is not among the outputs of that
    trace or an earlier one is added to the monitor names handed to TrainEssential / EvalEssential. When training, a
    warning is printed if no ModelSaver / BestModelSaver is active. Finally, every trace active for `run_modes` receives
    a reference to `self.system`.

    Args:
        run_modes: The current execution modes.
    """
    self.traces_in_use = list(self.traces)
    if self.system.log_steps is not None:
        self.traces_in_use.append(Logger())

    # Collect monitor names requested by traces but not produced by that trace or any trace sorted before it.
    seen_outputs = set()
    extra_keys = set()
    ordered_traces = sort_traces(get_current_items(self.traces_in_use, run_modes=run_modes))
    for trace in ordered_traces:
        seen_outputs.update(trace.outputs)
        extra_keys.update(trace.fe_monitor_names - seen_outputs)

    # Add the essential traces (each one receives its own freshly built monitor-name set).
    if "train" in run_modes:
        train_monitors = self.monitor_names.union(extra_keys)
        self.traces_in_use.insert(0, TrainEssential(monitor_names=train_monitors))
        active_traces = get_current_items(self.traces_in_use, run_modes=run_modes)
        if not any(isinstance(trace, (ModelSaver, BestModelSaver)) for trace in active_traces):
            print("FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.")
    if "eval" in run_modes and "eval" in self.pipeline.get_modes():
        eval_monitors = self.monitor_names.union(extra_keys)
        self.traces_in_use.insert(1, EvalEssential(monitor_names=eval_monitors))

    # Give every active trace access to the shared system object.
    for trace in get_current_items(self.traces_in_use, run_modes=run_modes):
        trace.system = self.system
def _verify_inputs(self) -> None:
    """Ensure that all ops are TensorOps and all postprocessing ops are NumpyOps.

    Raises:
        AssertionError: If any of `self.ops` is not a TensorOp, or any of `self.postprocessing` is not a NumpyOp.
    """
    # Raise explicitly instead of using `assert`, so these checks still run under `python -O`. The exception type stays
    # AssertionError, so existing callers are unaffected.
    for op in get_current_items(self.ops):
        if not isinstance(op, TensorOp):
            raise AssertionError("unsupported op format, Network ops must be TensorOps")
    for op in get_current_items(self.postprocessing):
        if not isinstance(op, NumpyOp):
            raise AssertionError("unsupported op format, Network postprocessing must be NumpyOps")
def _warmup(self, eager: bool = True) -> None:
    """Perform a test run of each pipeline and network signature epoch to make sure that training won't fail later.

    Traces are not executed in the warmup since they are likely to contain state variables which could become
    corrupted by running extra steps. Two kinds of mode are skipped: the 'test' mode, and any mode/epoch pair without
    data. For every other mode offered by the pipeline, each signature epoch is checked as follows:
    one batch is pulled from the pipeline, every key required by the traces is checked against what the pipeline,
    network, or another trace produces, and a single network step is run on that batch.

    Args:
        eager: Whether to run the training in eager mode. This is only related to TensorFlow training because PyTorch
            by nature is always in eager mode.

    Raises:
        AssertionError: Raised in any of these cases:
            - a pipeline batch is not a dictionary;
            - a trace requires a key that nothing produces during some epoch/mode;
            - a name in `self.monitor_names` is never produced by the pipeline or the network in any checked epoch.
    """
    all_traces = get_current_items(self.traces_in_use, run_modes={"train", "eval"})
    sort_traces(all_traces)  # This ensures that the traces can sort properly for on_begin and on_end
    # Monitor names still unaccounted for; every checked epoch removes the keys its pipeline/network produce.
    monitor_names = self.monitor_names
    for mode in self.pipeline.get_modes() - {"test"}:
        scheduled_items = self.pipeline.get_scheduled_items(mode) + self.network.get_scheduled_items(
            mode) + self.get_scheduled_items(mode)
        # Presumably the epochs at which some scheduled item changes value; see get_signature_epochs to confirm.
        signature_epochs = get_signature_epochs(scheduled_items, self.system.total_epochs, mode=mode)
        epochs_with_data = self.pipeline.get_epochs_with_data(total_epochs=self.system.total_epochs, mode=mode)
        for epoch in signature_epochs:
            if epoch not in epochs_with_data:
                continue
            network_output_keys = self.network.get_all_output_keys(mode, epoch)
            network_input_keys = self.network.get_effective_input_keys(mode, epoch)
            trace_input_keys = set()
            # "*" counts as always available, so wildcard trace inputs never show up as unmet requirements.
            trace_output_keys = {"*"}
            traces = get_current_items(self.traces_in_use, run_modes=mode, epoch=epoch)
            for idx, trace in enumerate(traces):
                if idx > 0:  # ignore TrainEssential and EvalEssential's inputs for unmet requirement checking
                    trace_input_keys.update(trace.inputs)
                trace_output_keys.update(trace.outputs)
            # key checking
            # The pipeline is asked for the network's inputs plus any trace inputs the network does not itself produce.
            loader = self._configure_loader(
                self.pipeline.get_loader(mode, epoch, output_keys=trace_input_keys - network_output_keys | network_input_keys))
            with Suppressor():
                if isinstance(loader, tf.data.Dataset):
                    batch = list(loader.take(1))[0]
                else:
                    batch = next(iter(loader))
            batch = self._configure_tensor(loader, batch)
            assert isinstance(batch, dict), "please make sure data output format is dictionary"
            pipeline_output_keys = to_set(batch.keys())
            monitor_names = monitor_names - (pipeline_output_keys | network_output_keys)
            unmet_requirements = trace_input_keys - (pipeline_output_keys | network_output_keys | trace_output_keys)
            assert not unmet_requirements, \
                "found missing key(s) during epoch {} mode {}: {}".format(epoch, mode, unmet_requirements)
            # NOTE(review): the return value of sort_traces is discarded here. This call appears to only validate that
            # the traces can be ordered, so `traces[0]` below is still the first trace in its original position
            # (the essential trace skipped above). Also assumes `traces` is non-empty; otherwise IndexError.
            sort_traces(traces, available_outputs=pipeline_output_keys | network_output_keys)
            trace_input_keys.update(traces[0].inputs)
            self.network.load_epoch(mode, epoch, output_keys=trace_input_keys, warmup=True, eager=eager)
            self.network.run_step(batch)
            self.network.unload_epoch()
    assert not monitor_names, "found missing key(s): {}".format(monitor_names)
def load_epoch(self,
               mode: str,
               epoch: int,
               output_keys: Optional[Set[str]] = None,
               warmup: bool = False,
               eager: bool = False) -> None:
    """Prepare the network to run a given epoch and mode.

    This method is necessary since schedulers and op mode restrictions may result in different computation graphs every
    epoch.

    Args:
        mode: The mode to prepare to execute. One of 'train', 'eval', 'test', or 'infer'.
        epoch: The epoch to prepare to execute.
        output_keys: What keys can be moved from the GPU back to the CPU after executing a step. When given (and
            non-empty), the network's outputs are restricted to these keys plus the inputs of the active
            postprocessing ops.
        warmup: Whether to prepare to execute it warmup mode or not (end users can likely ignore this argument).
        eager: Whether to run the training in eager mode. This is only related to TensorFlow training because PyTorch
            by nature is always in eager mode.
    """
    self.effective_inputs[mode] = self.get_effective_input_keys(mode, epoch)
    self.effective_outputs[mode] = self.get_all_output_keys(mode, epoch)
    if output_keys:
        # Keep only the requested outputs, plus whatever the postprocessing ops need as input.
        self.effective_outputs[mode] = self.effective_outputs[mode].intersection(
            output_keys) | self._get_effective_postprocessing_input_keys(mode, epoch)
    self.epoch_ops = get_current_items(self.ops, mode, epoch)
    self.epoch_postprocessing = get_current_items(self.postprocessing, mode, epoch)
    # `set.union(*[])` raises TypeError when no ops are active this epoch. Starting from an empty set handles that case
    # and gives the same result otherwise.
    self.epoch_models = set().union(*(op.get_fe_models() for op in self.epoch_ops))
    gradient_ops = [op for op in self.epoch_ops if op.fe_retain_graph() is not None]
    last_idx = len(gradient_ops) - 1
    for idx, gradient_op in enumerate(gradient_ops):
        # Ask every gradient op except the last one to retain the graph (presumably so later gradient ops in the same
        # step can still backpropagate through it).
        gradient_op.fe_retain_graph(idx != last_idx)
    self.epoch_state = {
        "warmup": warmup,
        "mode": mode,
        "req_grad": len(gradient_ops) > 0,
        "epoch": epoch,
        "deferred": {},
        "eager": eager
    }
    # warmup: bool, mode: str, req_grad: bool, epoch: int, deferred: Dict[str, List[Callable]], eager: bool
    for model in self.epoch_models:
        if hasattr(model, "optimizer") and model.optimizer is not None:
            if isinstance(model.optimizer, Scheduler):
                model.current_optimizer = model.optimizer.get_current_value(epoch)
            else:
                model.current_optimizer = model.optimizer
def _verify_dataset(self, dataset: DataSource, **kwargs) -> bool:
    """Check that this pipeline's arguments make sense for the given `dataset`.

    For a PyTorch Dataset, the batch sizes, ops and process count are validated. For a DataLoader or tf.data.Dataset,
    a warning is printed for each of `batch_size`, `ops` and `num_process` in `kwargs` that is not None, since those
    arguments only apply to built-in datasets.

    Args:
        dataset: The dataset to validate against.
        **kwargs: Must contain 'batch_size', 'ops' and 'num_process' (read when `dataset` is a DataLoader or
            tf.data.Dataset).

    Returns:
        True iff the `dataset` is a PyTorch Dataset (as opposed to a DataLoader or tf.data.Dataset).

    Raises:
        AssertionError: If the `kwargs` are found to be invalid based on the given `dataset`.
        ValueError: If the `dataset` is of an unknown type.
    """
    if isinstance(dataset, Dataset):
        valid_modes = {"train", "eval", "test", "infer"}
        for size in get_current_items(to_list(self.batch_size)):
            assert isinstance(size, (int, dict)), "unsupported batch_size format: {}".format(type(size))
            if isinstance(size, dict):
                assert set(size.keys()) <= valid_modes, "batch size dictionaries must be keyed on mode"
                assert all(isinstance(value, int) for value in size.values()), \
                    "batch size dictionary values must be integers"
        for op in get_current_items(self.ops):
            assert isinstance(op, NumpyOp), "unsupported op format, must provide NumpyOp in Pipeline"
        assert isinstance(self.num_process, int), "number of processes must be an integer"
        return True

    if isinstance(dataset, (DataLoader, tf.data.Dataset)):
        for arg_name in ("batch_size", "ops", "num_process"):
            if kwargs[arg_name] is not None:
                print("FastEstimator-Warn: {} will only be used for built-in dataset".format(arg_name))
        return False

    raise ValueError("Unsupported dataset type: {}".format(type(dataset)))
def __init__(
    self,
    target_type: str,
    device: Optional[torch.device],
    ops: Iterable[Union[TensorOp, Scheduler[TensorOp]]],
    postprocessing: Union[None, NumpyOp, Scheduler[NumpyOp], Iterable[Union[NumpyOp, Scheduler[NumpyOp]]]] = None
) -> None:
    """Store and build the network ops, collect their models, and reset the per-epoch bookkeeping.

    Args:
        target_type: The framework name passed to each op's `build` call.
        device: The device passed to each op's `build` call.
        ops: The network ops (or schedulers of ops).
        postprocessing: Optional postprocessing ops (or schedulers of ops).

    Raises:
        ValueError: If some but not all of the collected models use mixed precision.
    """
    self.ops = to_list(ops)
    self.target_type = target_type
    self.device = device
    for current_op in get_current_items(self.ops):
        current_op.build(framework=self.target_type, device=self.device)
    self.models = to_list(_collect_models(ops))
    self.postprocessing = to_list(postprocessing)
    self._verify_inputs()

    # Per-epoch state; presumably populated later by load_epoch.
    self.effective_inputs = dict()
    self.effective_outputs = dict()
    self.epoch_ops = []
    self.epoch_postprocessing = []
    self.epoch_models = set()
    self.epoch_state = dict()

    # Mixed precision must be all-or-nothing across models.
    precision_flags = [model.mixed_precision for model in self.models]
    self.mixed_precision = any(precision_flags)
    if self.mixed_precision and not all(precision_flags):
        raise ValueError("Cannot mix full precision and mixed-precision models")
def load_epoch(self, mode: str, epoch: int, output_keys: Optional[Set[str]] = None, warmup: bool = False) -> None:
    """Configure the network for one (mode, epoch) combination.

    Schedulers and op mode restrictions can change the computation graph from epoch to epoch, so this resolves the
    active ops, models, gradient ops and optimizers for the requested combination.

    Args:
        mode: The mode to prepare to execute. One of 'train', 'eval', 'test', or 'infer'.
        epoch: The epoch to prepare to execute.
        output_keys: What keys must be moved from the GPU back to the CPU after executing a step. When given (and
            non-empty), the network outputs are restricted to these keys.
        warmup: Whether to prepare to execute it warmup mode or not (end users can likely ignore this argument).
    """
    self.effective_inputs[mode] = self.get_effective_input_keys(mode, epoch)
    produced = self.get_all_output_keys(mode, epoch)
    self.effective_outputs[mode] = produced.intersection(output_keys) if output_keys else produced

    self.epoch_ops = get_current_items(self.ops, mode, epoch)
    self.epoch_models = {op.model for op in self.epoch_ops if isinstance(op, (UpdateOp, ModelOp))}

    # Every op exposing `retain_graph` except the final one keeps its graph alive.
    grad_ops = [op for op in self.epoch_ops if hasattr(op, "retain_graph")]
    final_position = len(grad_ops) - 1
    for position, grad_op in enumerate(grad_ops):
        grad_op.retain_graph = position != final_position

    self.epoch_state = {"warmup": warmup, "mode": mode, "req_grad": len(grad_ops) > 0, "epoch": epoch}

    # Resolve each model's optimizer for this epoch (schedulers yield their epoch-specific value).
    for model in self.epoch_models:
        if not hasattr(model, "optimizer") or model.optimizer is None:
            continue
        optimizer = model.optimizer
        model.current_optimizer = optimizer.get_current_value(epoch) if isinstance(optimizer, Scheduler) else optimizer
def _get_op_split(self, mode: str, epoch: int, ds_id: str) -> Tuple[List[NumpyOp], Batch, List[NumpyOp]]:
    """Split the active pipeline ops into those applied before and after batching.

    Ops preceding a `Batch` op are instance ops, and ops following it are batch ops. If several `Batch` ops are
    present, the last one wins. If none is present, a default `Batch()` is returned and every op is an instance op.

    Args:
        mode: The current mode.
        epoch: The current epoch.
        ds_id: The current dataset.

    Returns:
        (instance ops, batch info, batch ops).
    """
    batch_info = Batch()
    current_ops = get_current_items(self.ops, run_modes=mode, epoch=epoch, ds_id=ds_id)
    pre_batch = []
    post_batch = []
    seen_batch = False
    for op in current_ops:
        if isinstance(op, Batch):
            batch_info = op
            seen_batch = True
        elif seen_batch:
            post_batch.append(op)
        else:
            pre_batch.append(op)
    return pre_batch, batch_info, post_batch
def _start(self, run_modes: Set[str], eager: bool) -> None:
    """Run the outer training loop.

    The trace on_begin hooks run first. If 'train' or 'eval' is in `run_modes`, one step happens first: when
    `self.system.epoch_idx > 0` and the pipeline has eval data at that epoch, that eval epoch is re-run (the restore
    case). Then every remaining epoch runs its train and/or eval epochs. Otherwise, a single epoch runs in whatever mode
    is already set on the system. An EarlyStop exception ends the loop early, and the trace on_end hooks still run
    afterwards.

    Args:
        run_modes: The current execution modes.
        eager: Whether to run the training in eager mode. This is only related to TensorFlow training because PyTorch
            by nature is always in eager mode.
    """
    all_traces = sort_traces(get_current_items(self.traces_in_use, run_modes=run_modes))

    def run_epoch_in(mode: str) -> None:
        # Switch the system into `mode` and execute one epoch of it.
        self.system.mode = mode
        self._run_epoch(eager=eager)

    try:
        self._run_traces_on_begin(traces=all_traces)
        if "train" not in run_modes and "eval" not in run_modes:
            self._run_epoch(eager=eager)
        else:
            # If the training is re-starting from a restore wizard, it should re-run the last eval epoch
            if self.system.epoch_idx > 0 and "eval" in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                run_epoch_in("eval")
            for self.system.epoch_idx in range(self.system.epoch_idx + 1, self.system.total_epochs + 1):
                for mode in ("train", "eval"):
                    if mode in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                        run_epoch_in(mode)
    except EarlyStop:
        pass  # On early stopping we still want to run the final traces and return results
    self._run_traces_on_end(traces=all_traces)
def _start(self, run_modes: Set[str]) -> None:
    """Run the outer training loop.

    The trace on_begin hooks run first. If 'train' or 'eval' is in `run_modes`, every epoch after
    `self.system.epoch_idx` runs its train and/or eval epochs. Otherwise, a single epoch runs in whatever mode is
    already set on the system. An EarlyStop exception ends the loop early, and the trace on_end hooks still run
    afterwards.

    Args:
        run_modes: The current execution modes.
    """
    all_traces = sort_traces(get_current_items(self.traces_in_use, run_modes=run_modes))
    try:
        self._run_traces_on_begin(traces=all_traces)
        if "train" not in run_modes and "eval" not in run_modes:
            self._run_epoch()
        else:
            for self.system.epoch_idx in range(self.system.epoch_idx + 1, self.system.total_epochs + 1):
                for mode in ("train", "eval"):
                    if mode in self.pipeline.get_modes(epoch=self.system.epoch_idx):
                        self.system.mode = mode
                        self._run_epoch()
    except EarlyStop:
        pass  # On early stopping we still want to run the final traces and return results
    self._run_traces_on_end(traces=all_traces)
def _draw_diagram(self, mode: str, epoch: int) -> pydot.Dot:
    """Draw a summary diagram of the FastEstimator Ops / Traces.

    Args:
        mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
        epoch: The epoch to summarize.

    Returns:
        A pydot digraph representing the execution flow.
    """
    ds = self.system.pipeline.data[mode]
    if isinstance(ds, Scheduler):
        ds = ds.get_current_value(epoch)
    pipe_ops = get_current_items(
        self.system.pipeline.ops, run_modes=mode, epoch=epoch) if isinstance(ds, Dataset) else []
    net_ops = get_current_items(self.system.network.ops, run_modes=mode, epoch=epoch)
    traces = sort_traces(get_current_items(self.system.traces, run_modes=mode, epoch=epoch))
    diagram = pydot.Dot(compound='true')  # Compound lets you draw edges which terminate at sub-graphs
    diagram.set('rankdir', 'TB')
    diagram.set('dpi', 300)
    diagram.set_node_defaults(shape='box')

    # Make the dataset the first of the pipeline ops
    pipe_ops.insert(0, ds)
    label_last_seen = defaultdict(lambda: str(id(ds)))  # Where was this key last generated

    # Only annotate the pipeline title with a batch size when one actually applies to this mode/epoch. A missing
    # (None) batch size must not leak into the title as the text "None".
    batch_size_label = ""
    if isinstance(ds, Dataset) and not isinstance(ds, BatchDataset):
        batch_size = self.system.pipeline.batch_size
        if isinstance(batch_size, Scheduler):
            batch_size = batch_size.get_current_value(epoch)
        if isinstance(batch_size, dict):
            # Per-mode batch sizes: show only the one for the mode being drawn.
            batch_size = batch_size.get(mode)
        if batch_size is not None:
            batch_size_label = f" (Batch Size: {batch_size})"

    self._draw_subgraph(diagram, diagram, label_last_seen, f'Pipeline{batch_size_label}', pipe_ops)
    self._draw_subgraph(diagram, diagram, label_last_seen, 'Network', net_ops)
    self._draw_subgraph(diagram, diagram, label_last_seen, 'Traces', traces)
    return diagram
def _verify_inputs(self) -> None:
    """Ensure that every op in `self.ops` is a TensorOp or a LambdaOp.

    Raises:
        AssertionError: If any op returned by `get_current_items(self.ops)` is neither a TensorOp nor a LambdaOp.
    """
    for op in get_current_items(self.ops):
        # Raise explicitly instead of using `assert`, so the check still runs under `python -O`. The exception type is
        # unchanged for existing callers.
        if not isinstance(op, (TensorOp, LambdaOp)):
            raise AssertionError("unsupported op format, must provide TensorOp in Network")
def test_estimator_prepare_traces_check_all_trace_have_system(self):
    """After _prepare_traces, every active trace should reference the estimator's system."""
    all_modes = {"train", "eval", "test"}
    estimator = fe.Estimator(pipeline=self.pipeline, network=self.network, epochs=1)
    estimator._prepare_traces(all_modes)
    active_traces = get_current_items(estimator.traces_in_use, run_modes=all_modes)
    for active_trace in active_traces:
        self.assertEqual(active_trace.system, estimator.system)
def __init__(self, ops: Iterable[Union[TensorOp, Scheduler[TensorOp]]]) -> None:
    """Initialize a PyTorch network.

    Each active op is built for the 'torch' framework. The device is 'cuda:0' when CUDA is available, else 'cpu'. A
    GradScaler is created when any model uses mixed precision.

    Args:
        ops: The network ops (or schedulers of ops).
    """
    super().__init__(ops)
    for current_op in get_current_items(self.ops):
        current_op.build(framework='torch')
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    self.device = torch.device(device_name)
    precision_flags = [model.mixed_precision for model in self.models]
    if any(precision_flags):
        # NOTE(review): self.scaler is only assigned on this path; confirm the base class provides a default otherwise.
        self.scaler = torch.cuda.amp.GradScaler()
def get_loss_keys(self) -> Set[str]:
    """Collect the loss keys reported by every op in this network.

    Returns:
        The union of `get_fe_loss_keys()` over all ops returned by `get_current_items(self.ops)`.
    """
    collected: Set[str] = set()
    for current_op in get_current_items(self.ops):
        collected |= current_op.get_fe_loss_keys()
    return collected
def _verify_dataset(self, dataset: DataSource, **kwargs) -> bool:
    """Check that this pipeline's arguments make sense for the given `dataset`.

    For a PyTorch Dataset, the batch sizes, ops and process count are validated. For a DataLoader or tf.data.Dataset,
    a warning is issued (via `warnings.warn`) for each of `batch_size`, `ops` and `num_process` in `kwargs` that is not
    None, since those arguments only apply to built-in datasets.

    Args:
        dataset: The dataset to validate against.
        **kwargs: Must contain 'batch_size', 'ops' and 'num_process' (read when `dataset` is a DataLoader or
            tf.data.Dataset).

    Returns:
        True iff the `dataset` is a PyTorch Dataset (as opposed to a DataLoader or tf.data.Dataset).

    Raises:
        AssertionError: If the `kwargs` are found to be invalid based on the given `dataset`.
        ValueError: If the `dataset` is of an unknown type.
    """
    if isinstance(dataset, Dataset):
        for size in get_current_items(to_list(self.batch_size)):
            assert isinstance(size, int), "unsupported batch_size format: {}".format(type(size))
        for op in get_current_items(self.ops):
            assert isinstance(op, (NumpyOp, LambdaOp)), "unsupported op format, must provide NumpyOp in Pipeline"
        assert isinstance(self.num_process, int), "number of processes must be an integer"
        return True

    if isinstance(dataset, (DataLoader, tf.data.Dataset)):
        for arg_name in ("batch_size", "ops", "num_process"):
            if kwargs[arg_name] is not None:
                warnings.warn(f"{arg_name} will only be used for built-in dataset")
        return False

    raise ValueError("Unsupported dataset type: {}".format(type(dataset)))
def on_epoch_begin(self, data: Data) -> None:
    """Decide which keys to monitor for the upcoming epoch.

    The choice depends on `self.inputs`:
    - explicit inputs: they are used as-is;
    - no inputs: the network's loss keys are used;
    - inputs containing "*": the loss keys plus the outputs of every trace active for the current mode and epoch.
    """
    has_inputs = bool(self.inputs)
    if has_inputs and "*" not in self.inputs:
        self.monitor_keys = self.inputs
        return
    self.monitor_keys = self.system.network.get_loss_keys()
    if has_inputs:
        # Wildcard: also watch everything the active traces produce.
        active_traces = get_current_items(self.system.traces, run_modes=self.system.mode, epoch=self.system.epoch_idx)
        for active_trace in active_traces:
            self.monitor_keys.update(active_trace.outputs)
def get_loss_keys(self) -> Set[str]:
    """Collect the keys consumed by the network's UpdateOps.

    Returns:
        The union of the `inputs` of every UpdateOp returned by `get_current_items(self.ops)`.
    """
    update_ops = (op for op in get_current_items(self.ops) if isinstance(op, UpdateOp))
    return set().union(*(update_op.inputs for update_op in update_ops))
def get_loader(self,
               mode: str,
               epoch: int = 1,
               shuffle: Optional[bool] = None,
               output_keys: Optional[Set[str]] = None) -> Union[DataLoader, tf.data.Dataset]:
    """Get a data loader from the Pipeline for a given `mode` and `epoch`.

    Data sources that are not FastEstimator Datasets (after resolving any Scheduler) are returned unchanged.
    FastEstimator Datasets are wrapped in an `OpDataset` that applies this pipeline's active ops, and then in a PyTorch
    `DataLoader`.

    Args:
        mode: The execution mode for the loader. This can be 'train', 'eval' or 'test'.
        epoch: The epoch index for the loader. Note that epoch indices are 1-indexed.
        shuffle: Whether to shuffle the data. If None, the value for shuffle is based on mode. NOTE: This argument is
            only used with FastEstimator Datasets.
        output_keys: What keys can be produced from pipeline. If None, all keys will be considered.

    Returns:
        A data loader for the given `mode` and `epoch`.
    """
    source = self.data[mode]
    if isinstance(source, Scheduler):
        source = source.get_current_value(epoch)
    if not isinstance(source, Dataset):
        return source

    # Resolve the configured batch size for this epoch and mode.
    batch_size = self.batch_size
    if isinstance(batch_size, Scheduler):
        batch_size = batch_size.get_current_value(epoch)
    if isinstance(batch_size, dict):
        batch_size = batch_size[mode]

    pre_batched = isinstance(source, BatchDataset)
    if pre_batched:
        source.pad_value = self.pad_value

    if shuffle is None:
        # Default: shuffle training data only, and only when a batch size is configured.
        shuffle = mode == "train" and batch_size is not None

    collate_fn = self.collate_fn
    if collate_fn is None and self.pad_value is not None:
        collate_fn = self._pad_batch_collate

    # Results will be immediately converted to tensors, so don't need deep_remainder
    op_dataset = OpDataset(source, get_current_items(self.ops, mode, epoch), mode, output_keys, deep_remainder=False)

    if pre_batched:
        # A BatchDataset already yields whole batches, so the DataLoader must not batch again. Shuffling goes through a
        # RandomSampler rather than the `shuffle` flag.
        loader_batch_size = None
        loader_shuffle = False
        sampler = RandomSampler(op_dataset) if shuffle else None
    else:
        loader_batch_size = batch_size
        loader_shuffle = shuffle
        sampler = None

    return DataLoader(op_dataset,
                      batch_size=loader_batch_size,
                      shuffle=loader_shuffle,
                      sampler=sampler,
                      num_workers=self.num_process,
                      drop_last=False if loader_batch_size is None else self.drop_last,
                      worker_init_fn=lambda _: np.random.seed(random.randint(0, 2**32 - 1)),
                      collate_fn=collate_fn)
def _collect_models(ops: Iterable[Union[TensorOp, Scheduler[TensorOp]]]) -> Set[Model]:
    """Gather every model referenced by the given ops.

    Args:
        ops: The ops (or schedulers of ops) to search through.

    Returns:
        The union of `get_fe_models()` over all ops returned by `get_current_items(ops)`.
    """
    collected: Set[Model] = set()
    for current_op in get_current_items(ops):
        collected |= current_op.get_fe_models()
    return collected
def _draw_diagram(self, mode: str, epoch: int) -> pydot.Dot:
    """Render the data source, pipeline ops, network ops and traces for one mode/epoch as a pydot graph.

    Args:
        mode: The execution mode to summarize ('train', 'eval', 'test', or 'infer').
        epoch: The epoch to summarize.

    Returns:
        A pydot digraph representing the execution flow.
    """
    ds = self.system.pipeline.data[mode]
    if isinstance(ds, Scheduler):
        ds = ds.get_current_value(epoch)

    if isinstance(ds, FEDataset):
        pipe_ops = get_current_items(self.system.pipeline.ops, run_modes=mode, epoch=epoch)
    else:
        pipe_ops = []  # pipeline ops are only drawn for FE datasets
    net_ops = get_current_items(self.system.network.ops, run_modes=mode, epoch=epoch)
    traces = get_current_items(self.system.traces, run_modes=mode, epoch=epoch)

    diagram = pydot.Dot()
    for attr_name, attr_value in (('rankdir', 'TB'), ('dpi', 300)):
        diagram.set(attr_name, attr_value)
    diagram.set_node_defaults(shape='record')

    # The data source is the node every key is assumed to come from until something else generates it.
    ds_node_name = str(id(ds))
    ds_class_name = ds.__class__.__name__
    diagram.add_node(
        pydot.Node(ds_node_name,
                   label=f'{ds_class_name} ({FEID(id(ds))})',
                   texlbl=HrefFEID(FEID(id(ds)), name=ds_class_name).dumps()))
    label_last_seen = defaultdict(lambda: ds_node_name)  # Where was this key last generated

    for title, items in (('Pipeline', pipe_ops), ('Network', net_ops), ('Traces', traces)):
        self._draw_subgraph(diagram, label_last_seen, title, items)
    return diagram
def get_all_output_keys(self, mode: str, epoch: int) -> Set[str]:
    """Collect every key written by the network's ops for a given `mode` and `epoch`.

    Args:
        mode: The execution mode to consider. One of 'train', 'eval', 'test', or 'infer'.
        epoch: The epoch number to consider when searching for outputs.

    Returns:
        The union of `outputs` over the ops active for `mode` and `epoch`.
    """
    active_ops = get_current_items(self.ops, mode, epoch)
    return set().union(*(active_op.outputs for active_op in active_ops))
def _run_epoch(self) -> None:
    """A method to perform an epoch of activity.

    This method requires that the current mode and epoch already be specified within the self.system object.

    The epoch runs in these steps:
    1. Collect the traces active for the current mode/epoch and the union of their inputs.
    2. Build a loader from the pipeline and load the network for this epoch, passing those trace inputs as
       `output_keys`.
    3. Fetch the first batch and sort the traces given the keys that batch and the network make available.
    4. Run on_epoch_begin.
    5. For each batch, run on_batch_begin, a network step, and on_batch_end.
    6. Run on_epoch_end and unload the network.
    """
    traces = get_current_items(self.traces_in_use, run_modes=self.system.mode, epoch=self.system.epoch_idx)
    # Every key any active trace reads; the network is told these are the outputs worth keeping.
    trace_input_keys = set()
    for trace in traces:
        trace_input_keys.update(trace.inputs)
    loader = self._configure_loader(
        self.pipeline.get_loader(self.system.mode, self.system.epoch_idx))
    iterator = iter(loader)
    self.network.load_epoch(mode=self.system.mode, epoch=self.system.epoch_idx, output_keys=trace_input_keys)
    self.system.batch_idx = None
    # NOTE(review): this first next() is outside the try below, so an empty loader raises StopIteration out of this
    # method (skipping on_epoch_end and network.unload_epoch). Confirm that an empty loader cannot happen here.
    with Suppressor():
        batch = next(iterator)
    traces = self._sort_traces(
        traces,
        available_outputs=to_set(batch.keys()) | self.network.get_all_output_keys(self.system.mode, self.system.epoch_idx))
    self._run_traces_on_epoch_begin(traces=traces)
    while True:
        try:
            # The global step only advances during training; the batch index advances in every mode.
            if self.system.mode == "train":
                self.system.update_global_step()
            self.system.update_batch_idx()
            batch = self._configure_tensor(loader, batch)
            self._run_traces_on_batch_begin(batch, traces=traces)
            batch, prediction = self.network.run_step(batch)
            self._run_traces_on_batch_end(batch, prediction, traces=traces)
            # For DataLoaders, stop once batch_idx reaches the configured max steps for the current train/eval mode.
            if isinstance(loader, DataLoader) and (
                    (self.system.batch_idx == self.system.max_train_steps_per_epoch and self.system.mode == "train") or
                    (self.system.batch_idx == self.system.max_eval_steps_per_epoch and self.system.mode == "eval")):
                raise StopIteration
            with Suppressor():
                batch = next(iterator)
        except StopIteration:
            # NOTE(review): a StopIteration raised by run_step or a trace hook would also silently end the epoch here.
            break
    self._run_traces_on_epoch_end(traces=traces)
    self.network.unload_epoch()
def get_effective_input_keys(self, mode: str, epoch: int) -> Set[str]:
    """Find the keys the network must be fed from outside for a given `mode` and `epoch`.

    Ops are walked in order. A key counts as an external input only if no earlier op has already produced it.

    Args:
        mode: The execution mode to consider. One of 'train', 'eval', 'test', or 'infer'.
        epoch: The epoch number to consider for determining inputs.

    Returns:
        The necessary inputs for the network to execute the given `epoch` and `mode`.
    """
    required = set()
    produced = set()
    for active_op in get_current_items(self.ops, mode, epoch):
        required.update(key for key in active_op.inputs if key not in produced)
        produced.update(active_op.outputs)
    return required
def transform(self, data: Dict[str, Any], mode: str, epoch: int = 1) -> Dict[str, Any]:
    """Run the pipeline ops active for `mode` and `epoch` on a single data instance.

    The input dictionary is deep-copied first, so the caller's object is not modified. After the ops run, every value
    gains a leading axis of size 1 (via `np.expand_dims(value, 0)`).

    Args:
        data: Input data in dictionary format.
        mode: The execution mode in which to run. This can be "train", "eval", "test" or "infer".
        epoch: The epoch index to run. Note that epoch indices are 1-indexed.

    Returns:
        The transformed data.
    """
    result = deepcopy(data)
    active_ops = get_current_items(self.ops, mode, epoch)
    # forward_numpyop's return value is unused, so it presumably updates `result` in place.
    forward_numpyop(active_ops, result, {'mode': mode})
    for name, array in result.items():
        result[name] = np.expand_dims(array, 0)
    return result
def _verify_inputs(self, **kwargs) -> None:
    """Validate the Pipeline arguments against its data sources.

    Every data source is checked with `_verify_dataset`. If none of them is a FastEstimator dataset, `batch_size`,
    `ops` and `num_process` must all be None.

    Args:
        **kwargs: A collection of variable / value pairs to validate.

    Raises:
        AssertionError: If `batch_size`, `ops`, or `num_process` were specified in the absence of a FastEstimator
            Dataset.
    """
    verdicts = [self._verify_dataset(dataset, **kwargs) for dataset in get_current_items(self.data.values())]
    if any(verdicts):
        return
    requirements = (
        ("batch_size", "Pipeline only supports batch_size with built-in (FE) datasets"),
        ("ops", "Pipeline only supports ops with built-in (FE) datasets"),
        ("num_process", "Pipeline only support num_process with built-in (FE) datasets"),
    )
    for arg_name, message in requirements:
        assert kwargs[arg_name] is None, message
def __init__(
    self,
    target_type: str,
    device: Optional[torch.device],
    ops: Iterable[Union[TensorOp, Scheduler[TensorOp]]],
    postprocessing: Union[None, NumpyOp, Scheduler[NumpyOp], Iterable[Union[NumpyOp, Scheduler[NumpyOp]]]] = None
) -> None:
    """Store and build the network ops, collect their models, and reset the per-epoch bookkeeping.

    Args:
        target_type: The framework name passed to each op's `build` call.
        device: The device passed to each op's `build` call.
        ops: The network ops (or schedulers of ops).
        postprocessing: Optional postprocessing ops (or schedulers of ops).
    """
    self.ops = to_list(ops)
    self.target_type = target_type
    self.device = device
    for current_op in get_current_items(self.ops):
        current_op.build(framework=self.target_type, device=self.device)
    self.models = to_list(_collect_models(ops))
    self.postprocessing = to_list(postprocessing)
    self._verify_inputs()

    # Per-epoch state; presumably populated later by load_epoch.
    self.effective_inputs = dict()
    self.effective_outputs = dict()
    self.epoch_ops = []
    self.epoch_postprocessing = []
    self.epoch_models = set()
    self.epoch_state = dict()
    # No gradient scaler by default.
    self.scaler = None
def _run_epoch(self, eager: bool) -> None:
    """A method to perform an epoch of activity.

    This method requires that the current mode and epoch already be specified within the self.system object.

    The epoch-level traces get on_epoch_begin once. Then, for each dataset id the pipeline reports for this epoch and
    mode, these steps run:
    1. Load the network.
    2. Open a loader from the pipeline.
    3. Sort that dataset's traces given the keys of the first batch and the network outputs.
    4. Run on_ds_begin for the PerDSTraces.
    5. Step through the batches.
    6. Run on_ds_end, writing into a shared `Data`, then unload the network.
    Finally the epoch-level traces get on_epoch_end with that shared `Data`.

    Args:
        eager: Whether to run the training in eager mode. This is only related to TensorFlow training because PyTorch
            by nature is always in eager mode.
    """
    ds_ids = self.pipeline.get_ds_ids(self.system.epoch_idx, self.system.mode)
    epoch_traces = sort_traces(get_current_items(
        self.traces_in_use, run_modes=self.system.mode, epoch=self.system.epoch_idx),
                               ds_ids=ds_ids)
    self._run_traces_on_epoch_begin(traces=epoch_traces)
    self.system.batch_idx = None
    end_epoch_data = Data()  # We will aggregate data over on_ds_end and put it into on_epoch_end for printing
    # run for each dataset
    for self.system.ds_id in ds_ids:
        ds_traces = get_current_items(self.traces_in_use,
                                      run_modes=self.system.mode,
                                      epoch=self.system.epoch_idx,
                                      ds_id=self.system.ds_id)
        # Every key any active trace for this dataset reads.
        trace_input_keys = set()
        for ds_trace in ds_traces:
            trace_input_keys.update(ds_trace.inputs)
        network_input_keys = self.network.get_effective_input_keys(mode=self.system.mode,
                                                                   epoch=self.system.epoch_idx,
                                                                   ds_id=self.system.ds_id)
        network_output_keys = self.network.get_all_output_keys(mode=self.system.mode,
                                                               epoch=self.system.epoch_idx,
                                                               ds_id=self.system.ds_id)
        self.network.load_epoch(mode=self.system.mode,
                                epoch=self.system.epoch_idx,
                                ds_id=self.system.ds_id,
                                output_keys=trace_input_keys,
                                eager=eager)
        # The pipeline is asked for the network's inputs plus any trace inputs the network does not itself produce.
        with self.pipeline(mode=self.system.mode,
                           epoch=self.system.epoch_idx,
                           ds_id=self.system.ds_id,
                           steps_per_epoch=self.system.steps_per_epoch,
                           output_keys=trace_input_keys - network_output_keys | network_input_keys) as loader:
            loader = self._configure_loader(loader)
            iterator = iter(loader)
            # NOTE(review): this first next() is outside the try below, so an empty loader raises StopIteration out
            # of this method. Confirm that an empty loader cannot happen here.
            with Suppressor():
                batch = next(iterator)
            ds_traces = sort_traces(ds_traces,
                                    available_outputs=to_set(batch.keys()) | network_output_keys,
                                    ds_ids=ds_ids)
            per_ds_traces = [trace for trace in ds_traces if isinstance(trace, PerDSTrace)]
            self._run_traces_on_ds_begin(traces=per_ds_traces)
            while True:
                try:
                    # The global step only advances during training; the batch index advances in every mode.
                    if self.system.mode == "train":
                        self.system.update_global_step()
                    self.system.update_batch_idx()
                    batch = self._configure_tensor(loader, batch)
                    self._run_traces_on_batch_begin(batch, traces=ds_traces)
                    batch, prediction = self.network.run_step(batch)
                    self._run_traces_on_batch_end(batch, prediction, traces=ds_traces)
                    # For DataLoaders, stop once batch_idx reaches the configured steps for the current train/eval mode.
                    if isinstance(loader, DataLoader) and (
                        (self.system.batch_idx == self.system.train_steps_per_epoch and self.system.mode == "train") or
                        (self.system.batch_idx == self.system.eval_steps_per_epoch and self.system.mode == "eval")):
                        raise StopIteration
                    with Suppressor():
                        batch = next(iterator)
                except StopIteration:
                    # NOTE(review): a StopIteration raised by run_step or a trace hook would also end the dataset here.
                    break
            self._run_traces_on_ds_end(traces=per_ds_traces, data=end_epoch_data)
        self.network.unload_epoch()
    self._run_traces_on_epoch_end(traces=epoch_traces, data=end_epoch_data)
def __call__(self,
             mode: str,
             epoch: int = 1,
             ds_id: str = '',
             shuffle: Optional[bool] = None,
             steps_per_epoch: Optional[int] = None,
             output_keys: Optional[Set[str]] = None) -> 'Pipeline':
    """Prepare this Pipeline for a given `mode` and `epoch`.

    A given pipeline can only provide one loader at a time. This helps to prevent issues with multi-threading.

    ```python
    pipe = Pipeline(...)
    with pipe(mode='eval', epoch=2) as loader:
        for batch in loader:
            print(batch)
    ```

    Args:
        mode: The execution mode for the loader. This can be 'train', 'eval' or 'test'.
        epoch: The epoch index for the loader. Note that epoch indices are 1-indexed.
        ds_id: The dataset id to consider for the loader.
        shuffle: Whether to shuffle the data. If None, the value for shuffle is based on mode. NOTE: This argument is
            only used with FastEstimator Datasets.
        steps_per_epoch: Training or Evaluation will be cut short or extended to complete N steps even if loader is not
            yet exhausted. If None, all data will be used.
        output_keys: What keys can be produced from pipeline. If None or empty, all keys will be considered.

    Returns:
        The pipeline, but with `mode` and `epoch` set for use in a loader.

    Raises:
        ValueError: If called while the pipeline already has an active loader (i.e. `self.ctx_lock` is already held).
    """
    # Make sure that a loader isn't currently instantiated with other settings. The lock is only held while the ctx_*
    # fields below are configured, and is released before returning.
    acquired = self.ctx_lock.acquire(blocking=False)
    if not acquired:
        raise ValueError("You cannot invoke a Pipeline's __call__ method while it already has an active loader.")
    try:
        self.ctx_mode = mode
        self.ctx_epoch = epoch
        self.ctx_ds_id = ds_id
        self.ctx_shuffle = mode == 'train' if shuffle is None else shuffle
        self.ctx_steps_per_epoch = steps_per_epoch
        self.ctx_output_keys = output_keys or set()
        self.ctx_ops, self.ctx_batch_info, self.ctx_batch_ops = self._get_op_split(mode=mode, epoch=epoch, ds_id=ds_id)
        # Figure out which input keys are required by the batch ops (so they don't get pruned too early)
        self.ctx_batch_input_keys = set()
        batch_produced_keys = set()
        for op in get_current_items(self.ctx_batch_ops, mode, epoch, ds_id=ds_id):
            self.ctx_batch_input_keys.update(key for key in op.inputs if key not in batch_produced_keys)
            batch_produced_keys.update(op.outputs)
        # Decide on the batch size (this might still be ignored later if the user is using a BatchDataset)
        self.ctx_batch_size = self.ctx_batch_info.batch_size
        if self.ctx_batch_size is None:
            batch_size = self.batch_size
            if isinstance(batch_size, Scheduler):
                batch_size = batch_size.get_current_value(self.ctx_epoch)
            self.ctx_batch_size = batch_size
    finally:
        # Release even if configuration fails. Otherwise the lock stays held and every later call is rejected with the
        # ValueError above.
        self.ctx_lock.release()
    return self