def __init__(self, architecture_config, training_config, callbacks_config):
    """Set up a UNet model together with its Adam optimizer, loss and callbacks.

    Args:
        architecture_config: mapping read for the keys 'model_params'
            (keyword arguments for UNet), 'regularizer_params' (keyword
            arguments for weight_regularization_unet) and 'optimizer_params'
            (keyword arguments for optim.Adam).
        training_config: forwarded unchanged to the base-class constructor.
        callbacks_config: forwarded unchanged to the base-class constructor.
    """
    super().__init__(architecture_config, training_config, callbacks_config)

    unet_kwargs = architecture_config['model_params']
    self.model = UNet(**unet_kwargs)

    self.weight_regularization = weight_regularization_unet
    # Adam receives whatever the regularizer returns for this model
    # (presumably parameter groups with per-group settings — confirm).
    regularized_params = self.weight_regularization(
        self.model, **architecture_config['regularizer_params'])
    adam_kwargs = architecture_config['optimizer_params']
    self.optimizer = optim.Adam(regularized_params, **adam_kwargs)

    # NOTE(review): tuple presumably means (output name, loss fn, weight) —
    # verify against the base Model's training loop.
    self.loss_function = [('multichannel_map', multiclass_segmentation_loss, 1.0)]
    # self.callbacks_config is presumably populated by the base-class constructor.
    self.callbacks = callbacks_unet(self.callbacks_config)
class PyTorchUNetStream(Model):
    """UNet model whose ``transform`` yields per-sample softmax predictions lazily."""

    def __init__(self, architecture_config, training_config, callbacks_config):
        """Build the UNet, its Adam optimizer, the loss list and the callbacks.

        Args:
            architecture_config: mapping read for 'model_params' (UNet kwargs),
                'regularizer_params' (weight_regularization_unet kwargs) and
                'optimizer_params' (optim.Adam kwargs).
            training_config: forwarded unchanged to the base-class constructor.
            callbacks_config: forwarded unchanged to the base-class constructor.
        """
        super().__init__(architecture_config, training_config, callbacks_config)
        self.model = UNet(**architecture_config['model_params'])
        self.weight_regularization = weight_regularization_unet
        self.optimizer = optim.Adam(
            self.weight_regularization(
                self.model, **architecture_config['regularizer_params']),
            **architecture_config['optimizer_params'])
        # NOTE(review): presumably (output name, loss fn, weight) — confirm.
        self.loss_function = [('multichannel_map', multiclass_segmentation_loss, 1.0)]
        # self.callbacks_config is presumably set by the base-class constructor.
        self.callbacks = callbacks_unet(self.callbacks_config)

    def transform(self, datagen, validation_datagen=None):
        """Return ``{'<output>_prediction': generator}`` for the single output.

        The generator is lazy: no batch is processed until it is iterated.

        Raises:
            NotImplementedError: if ``self.output_names`` does not contain
                exactly one name.
        """
        if len(self.output_names) != 1:
            raise NotImplementedError
        output_generator = self._transform(datagen, validation_datagen)
        return {'{}_prediction'.format(self.output_names[0]): output_generator}

    def _transform(self, datagen, validation_datagen=None):
        """Yield ``softmax(output, axis=0)`` for every sample of every batch.

        Args:
            datagen: a ``(batch_generator, steps)`` pair. A batch that is a
                list is reduced to its first element (the inputs).
            validation_datagen: accepted for interface compatibility; unused.

        The model is switched to eval mode for inference. Train mode is restored
        in a ``finally`` clause, so it happens even if the consumer stops
        iterating early (the generator is closed) or the forward pass raises.
        """
        batch_gen, steps = datagen
        use_cuda = torch.cuda.is_available()  # loop-invariant
        self.model.eval()
        try:
            for batch_id, data in enumerate(batch_gen):
                X = data[0] if isinstance(data, list) else data
                # NOTE(review): `volatile` is the pre-0.4 PyTorch inference flag;
                # newer releases ignore it — consider torch.no_grad() if the
                # project's torch version allows.
                X = Variable(X, volatile=True)
                if use_cuda:
                    X = X.cuda()

                outputs_batch = self.model(X).data.cpu().numpy()
                for output in outputs_batch:
                    # axis 0 is presumably the class/channel axis — confirm.
                    yield softmax(output, axis=0)

                # NOTE(review): batch_id is 0-based, so on an unbounded
                # generator this stops after steps + 1 batches; finite
                # generators with exactly `steps` batches end before this fires.
                if batch_id == steps:
                    break
        finally:
            self.model.train()
def set_model(self):
    """Instantiate ``self.model`` according to the configured encoder.

    ``architecture_config['model_params']['encoder']`` selects the network:
    ``'from_scratch'`` builds a UNet from all of ``model_params``. Any other
    value is looked up in ``PRETRAINED_NETWORKS`` and built from that entry's
    ``'model'`` factory and ``'model_config'`` kwargs; in that case
    ``self._initialize_model_weights`` is replaced with a no-op.
    """
    model_params = self.architecture_config['model_params']
    encoder = model_params['encoder']

    if encoder != 'from_scratch':
        network_spec = PRETRAINED_NETWORKS[encoder]
        self.model = network_spec['model'](**network_spec['model_config'])
        # Disable the weight-initialisation hook so it does nothing for
        # pretrained networks (presumably to keep their loaded weights).
        self._initialize_model_weights = lambda: None
        return

    self.model = UNet(**model_params)
def __init__(self, architecture_config, training_config, callbacks_config):
    """Set up a UNet with an Adam optimizer, ``segmentation_loss`` and callbacks.

    Args:
        architecture_config: mapping read for 'model_params' (UNet kwargs),
            'regularizer_params' (weight_regularization_unet kwargs) and
            'optimizer_params' (optim.Adam kwargs).
        training_config: forwarded unchanged to the base-class constructor.
        callbacks_config: forwarded unchanged to the base-class constructor.
    """
    super().__init__(architecture_config, training_config, callbacks_config)

    network = UNet(**architecture_config['model_params'])
    self.model = network

    self.weight_regularization = weight_regularization_unet
    optimizer_input = self.weight_regularization(
        network, **architecture_config['regularizer_params'])
    self.optimizer = optim.Adam(
        optimizer_input, **architecture_config['optimizer_params'])

    self.loss_function = segmentation_loss
    # self.callbacks_config is presumably populated by the base-class constructor.
    self.callbacks = build_callbacks(self.callbacks_config)