def get_device_spec(self, device_id):
    """
    Look up the stored spec for a device.

    Parameters
    ----------
    device_id
        Key used to query ``self._device_specs``.

    Returns
    -------
    The spec stored under ``device_id``.

    Notes
    -----
    A missing (``None``) entry is reported through ``assert_`` with
    ``RuntimeError`` as the exception type; ``assert_`` is assumed to raise
    that exception when its condition is false.
    """
    spec = self._device_specs.get(device_id)
    found = spec is not None
    assert_(
        found,
        f"device_id {device_id} not found in specs. "
        f"Consider calling dry_run() first.",
        RuntimeError)
    return spec
def unbatcher(self, batch):
    """
    Split a batch into a list of its entries along the leading axis.

    Parameters
    ----------
    batch
        Object exposing ``.shape`` and iteration over its first axis; it must
        be 4-dimensional (NCHW) or 5-dimensional (NCDHW).

    Returns
    -------
    list
        ``list(batch)``, i.e. one item per entry of the leading axis.

    Notes
    -----
    The rank check is delegated to ``utils.assert_`` with
    ``TikIO.ShapeError`` as the exception type.
    """
    rank = len(batch.shape)
    utils.assert_(
        rank in (4, 5),
        "`batch` must either be a NCHW or NCDHW tensor, "
        f"got one with dimension {rank} "
        "instead.",
        TikIO.ShapeError)
    return list(batch)
def halo(self, value):
    """
    Store the halo, broadcasting a single int to every dynamic axis.

    Parameters
    ----------
    value : int or sequence
        An int is repeated ``len(self.dynamic_shape)`` times. Anything else
        must have the same length as ``self.dynamic_shape`` and is stored
        as given.

    Notes
    -----
    A length mismatch is reported through ``assert_`` with ``ValueError``
    as the exception type.
    """
    num_axes = len(self.dynamic_shape)
    if not isinstance(value, int):
        # Per-axis halo: it has to match the dimensionality of the network.
        assert_(
            len(value) == num_axes,
            f"Halo of a {num_axes}-D network cannot "
            f"be {len(value)}-D.",
            ValueError)
        self._halo = value
        return
    # Scalar halo: use the same value along every axis.
    self._halo = [value] * num_axes
def parse_inputs(self, inputs):
    """
    Normalise ``inputs`` to a sequence of ``TikIn`` objects.

    Parameters
    ----------
    inputs : TikIn, np.ndarray, torch.Tensor, list or tuple
        * a single ``TikIn`` is wrapped in a list;
        * a single array/tensor is wrapped as ``[TikIn([inputs])]``;
        * a list/tuple is returned unchanged after checking (through
          ``utils.assert_``) that every element is a ``TikIn``.

    Returns
    -------
    list or tuple
        The normalised inputs.

    Raises
    ------
    TypeError
        If ``inputs`` is none of the supported types.
    """
    if isinstance(inputs, TikIn):
        return [inputs]
    if isinstance(inputs, (np.ndarray, torch.Tensor)):
        return [TikIn([inputs])]
    if not isinstance(inputs, (list, tuple)):
        raise TypeError("Inputs must be list TikIn objects.")
    only_tikins = all(isinstance(item, TikIn) for item in inputs)
    utils.assert_(only_tikins, "Inputs must all be TikIn objects.")
    return inputs
def validate_shape(tensors):
    """
    Check that all tensors share one shape of a supported rank.

    Parameters
    ----------
    tensors : sequence
        Objects exposing ``.shape``. They must all have the same shape, and
        that shape must be 2-, 3- or 4-dimensional.

    Returns
    -------
    The ``tensors`` argument, unchanged.

    Notes
    -----
    Failures are reported through ``utils.assert_`` with
    ``TikIn.ShapeError`` (shape mismatch or empty input) and
    ``TikIO.ShapeError`` (unsupported rank) as exception types.
    """
    shapes = [tensor.shape for tensor in tensors]
    # The condition must be a single boolean: a list of per-tensor
    # comparisons is truthy whenever it is non-empty, which would let
    # mismatching shapes through. Empty input is rejected here as well.
    utils.assert_(
        len(shapes) > 0 and all(shape == shapes[0] for shape in shapes),
        f"Input `tensors` to TikIn must all have the same shape. "
        f"Got tensors of shape: {shapes}",
        TikIn.ShapeError)
    utils.assert_(
        len(shapes[0]) in [2, 3, 4],
        f"Tensors must be of dimensions "
        f"2 (HW), 3 (CHW/DHw), or 4 (CDHW). "
        f"Got {len(shapes[0])} instead.",
        TikIO.ShapeError)
    return tensors
def batch_inputs(self, inputs):
    """
    Build one batch per input using the configured input shape(s).

    Parameters
    ----------
    inputs : sequence
        Objects exposing ``batcher(input_shape)``.

    Returns
    -------
    list
        ``input.batcher(input_shape)`` for every (input, shape) pair.

    Raises
    ------
    TypeError
        If the first entry of the configured input shapes is neither an int
        nor a list/tuple.

    Notes
    -----
    The shapes are read with ``self.get('input_shape', assert_exist=True)``.
    A count mismatch between inputs and shapes is reported through
    ``utils.assert_`` with ``ValueError`` as the exception type.
    """
    input_shapes = self.get('input_shape', assert_exist=True)
    assert isinstance(input_shapes, (list, tuple))
    # The configuration may hold a single shape (a flat sequence of ints) or
    # one shape per input. Normalise to the per-input form by repeating a
    # single shape for every input.
    head = input_shapes[0]
    if isinstance(head, int):
        input_shapes = [input_shapes] * len(inputs)
    elif not isinstance(head, (list, tuple)):
        raise TypeError(
            f"`input_shapes` must be a list/tuple of ints or "
            f"lists/tuples or ints. Got list/tuple of {type(head)}."
        )
    utils.assert_(
        len(input_shapes) == len(inputs),
        f"Expecting {len(inputs)} inputs, got {len(input_shapes)} input shapes.",
        ValueError)
    return [
        tik_input.batcher(shape)
        for tik_input, shape in zip(inputs, input_shapes)
    ]
def compute_halo(self, device_id=0, set_=True):
    """
    Compute the per-axis halo from the input/output spatial size difference.

    Parameters
    ----------
    device_id : int
        Index into ``self.devices``.
    set_ : bool
        When true, the result is also assigned to ``self.halo``.

    Returns
    -------
    list
        Half of the input-minus-output size for every spatial axis.

    Notes
    -----
    Only symmetric halos (even size differences) are supported; anything
    else is reported through ``assert_`` with ``RuntimeError``.

    NOTE: the real forward pass
    (``self.model.to(device)(input_tensor.to(device))``) is currently
    disabled. The output is a zero tensor with the same shape as the input,
    so every size difference, and hence the halo, is 0.
    """
    # The lookup still fails for an unknown device_id; the device itself is
    # only needed by the disabled forward pass.
    device = self.devices[device_id]
    # Smallest possible image keeps the evaluation quick.
    probe_shape = (1, self.channels, *self.dynamic_shape.base_shape)
    input_tensor = torch.zeros(*probe_shape)
    output_tensor = torch.zeros(*probe_shape)
    # Assuming NCHW or NCDHW, the first two axes do not contribute to the halo.
    spatial_in = input_tensor.shape[2:]
    spatial_out = output_tensor.shape[2:]
    shape_difference = [
        size_in - size_out
        for size_in, size_out in zip(spatial_in, spatial_out)
    ]
    is_symmetric = all(diff % 2 == 0 for diff in shape_difference)
    assert_(is_symmetric, "Only symmetric halos are supported.", RuntimeError)
    halo = [diff // 2 for diff in shape_difference]
    if set_:
        self.halo = halo
    return halo
def batcher(self, network_input_shape: list):
    """
    Build batch for the network.

    Parameters
    ----------
    network_input_shape: list
        Input shape to the network. Its length selects the network input
        format (3 -> 'CHW', 4 -> 'CDHW') and its first entry is the number
        of input channels the network expects.

    Returns
    -------
    torch.Tensor
        Result of ``torch.stack`` along a new leading axis.

    Raises
    ------
    KeyError
        If ``network_input_shape`` does not have 3 or 4 entries.
    ValueError
        If ``self.format`` is not one of 'HW', 'CDHW' or 'CHW/DHW'.

    Notes
    -----
    Incompatibilities between the input format and the network input format
    are reported through ``utils.assert_`` with ``TikIn.ShapeError`` as the
    exception type.
    """
    network_input_format = {3: 'CHW', 4: 'CDHW'}[len(network_input_shape)]
    if self.format == 'HW':
        utils.assert_(
            network_input_format == 'CHW',
            f"Input format is HW, which is not compatible with the network "
            f"input format {network_input_format} (must be CHW).",
            TikIn.ShapeError)
        utils.assert_(
            network_input_shape[0] == 1,
            f"Input format is HW, for which the number of input channels (C) "
            f"to the network must be 1. Got C = {network_input_shape[0]} instead.",
            TikIn.ShapeError)
        # Prepend a singleton channel axis: HW -> CHW with C = 1.
        pre_cat = self.reshape(1, *self.shape)
    elif self.format == 'CDHW':
        utils.assert_(
            network_input_format == 'CDHW',
            f"Input format (CDHW) is not compatible with network input format "
            f"({network_input_format}).",
            TikIn.ShapeError)
        utils.assert_(
            self.shape[0] == network_input_shape[0],
            f"Number of input channels in input ({self.shape[0]}) is not "
            f"consistent with what the network expects ({network_input_shape[0]}).",
            TikIn.ShapeError)
        pre_cat = self.tensors
    elif self.format == 'CHW/DHW':
        # A 3-D input is ambiguous; the network input format decides whether
        # it is read as CHW or as DHW.
        if network_input_format == 'CHW':
            # input format is CHW
            utils.assert_(
                self.shape[0] == network_input_shape[0],
                f"Number of input channels in input ({self.shape[0]}) is "
                f"not compatible with the number of input channels to "
                f"the network ({network_input_shape[0]})",
                TikIn.ShapeError)
            pre_cat = self.tensors
        elif network_input_format == 'CDHW':
            # input format is DHW
            utils.assert_(
                network_input_shape[0] == 1,
                f"Input format DHW requires that the number of input channels (C) "
                f"to the network is 1. Got C = {network_input_shape[0]} instead.",
                TikIn.ShapeError)
            # Prepend a singleton channel axis: DHW -> CDHW with C = 1.
            pre_cat = self.reshape(1, *self.shape)
        else:
            # Not reachable with the current format table; kept as a guard
            # in case further network input formats are added.
            raise TikIn.ShapeError(
                f"Input format {self.format} is not compatible with "
                f"the network input format {network_input_format}.")
    else:
        raise ValueError("Internal Error: Invalid Format.")
    # Concatenate to a batch
    batch = torch.stack(pre_cat, dim=0)
    return batch
def tensors(self):
    """
    Return the stored tensors, requiring that they have been defined.

    Returns
    -------
    The value of ``self._tensors``.

    Notes
    -----
    An undefined (``None``) value is reported through ``utils.assert_``
    with ``ValueError`` as the exception type.
    """
    stored = self._tensors
    is_defined = stored is not None
    utils.assert_(
        is_defined,
        "Trying to acess `TikIn.tensors` with it yet to be defined.",
        ValueError)
    return stored