Exemplo n.º 1
0
 def __init__(self, architecture_config, training_config, callbacks_config):
     """Set up a U-Net with a regularized Adam optimizer, loss and callbacks.

     Args:
         architecture_config: dict providing ``'model_params'`` (kwargs for
             ``UNet``), ``'regularizer_params'`` (kwargs for the weight
             regularizer) and ``'optimizer_params'`` (kwargs for ``optim.Adam``).
         training_config: passed through to the base-class constructor.
         callbacks_config: passed through to the base-class constructor; read
             back below as ``self.callbacks_config`` (presumably stored by the
             base class -- confirm).
     """
     super().__init__(architecture_config, training_config, callbacks_config)

     self.model = UNet(**architecture_config['model_params'])

     # The regularizer maps the model to the parameters (or parameter groups)
     # that Adam will optimize.
     self.weight_regularization = weight_regularization_unet
     regularizer_kwargs = architecture_config['regularizer_params']
     param_groups = self.weight_regularization(self.model, **regularizer_kwargs)
     adam_kwargs = architecture_config['optimizer_params']
     self.optimizer = optim.Adam(param_groups, **adam_kwargs)

     # Presumably (output name, loss callable, weight) -- confirm with the base class.
     self.loss_function = [
         ('multichannel_map', multiclass_segmentation_loss, 1.0),
     ]
     self.callbacks = callbacks_unet(self.callbacks_config)
Exemplo n.º 2
0
class PyTorchUNetStream(Model):
    """U-Net model whose ``transform`` streams softmax predictions lazily.

    Rather than materializing every prediction, ``transform`` returns a
    generator that yields one softmax-normalized prediction per input sample.
    """

    def __init__(self, architecture_config, training_config, callbacks_config):
        """Build the U-Net, its regularized Adam optimizer, loss and callbacks.

        Args:
            architecture_config: dict providing ``'model_params'`` (kwargs for
                ``UNet``), ``'regularizer_params'`` (kwargs for the weight
                regularizer) and ``'optimizer_params'`` (kwargs for
                ``optim.Adam``).
            training_config: passed through to ``Model.__init__``.
            callbacks_config: passed through to ``Model.__init__``; read back
                below as ``self.callbacks_config`` (presumably stored by the
                base class -- confirm).
        """
        super().__init__(architecture_config, training_config,
                         callbacks_config)
        self.model = UNet(**architecture_config['model_params'])
        self.weight_regularization = weight_regularization_unet
        self.optimizer = optim.Adam(
            self.weight_regularization(
                self.model, **architecture_config['regularizer_params']),
            **architecture_config['optimizer_params'])
        self.loss_function = [('multichannel_map',
                               multiclass_segmentation_loss, 1.0)]
        self.callbacks = callbacks_unet(self.callbacks_config)

    def transform(self, datagen, validation_datagen=None):
        """Return lazily computed predictions keyed by the single output name.

        Args:
            datagen: ``(batch_generator, steps)`` pair consumed by
                ``_transform``.
            validation_datagen: forwarded to ``_transform``, which ignores it.

        Returns:
            dict ``{'<output_name>_prediction': generator}``, where the
            generator yields one softmax-normalized prediction per sample.

        Raises:
            NotImplementedError: if the model does not have exactly one output.
        """
        if len(self.output_names) != 1:
            raise NotImplementedError(
                'PyTorchUNetStream supports exactly one output, got {}'.format(
                    len(self.output_names)))
        output_generator = self._transform(datagen, validation_datagen)
        return {
            '{}_prediction'.format(self.output_names[0]): output_generator
        }

    def _transform(self, datagen, validation_datagen=None):
        """Yield softmax predictions one sample at a time.

        The model is put in eval mode when iteration starts. It is restored to
        train mode when the generator finishes, raises, or is closed
        (explicitly or on garbage collection).

        Args:
            datagen: ``(batch_generator, steps)``; at most ``steps`` batches
                are consumed from ``batch_generator``.
            validation_datagen: unused.

        Yields:
            numpy array: softmax over axis 0 (presumably the channel axis --
            confirm) of one sample's network output.
        """
        self.model.eval()
        batch_gen, steps = datagen
        try:
            for batch_id, data in enumerate(batch_gen):
                # Loaders may yield lists such as [inputs, targets]; only the
                # inputs are needed for prediction.
                X = data[0] if isinstance(data, list) else data

                # NOTE(review): `volatile` is the PyTorch < 0.4 way to disable
                # autograd. In PyTorch >= 0.4 it is deprecated and has no
                # effect; switch to `torch.no_grad()` once the project targets
                # that version.
                X = Variable(X, volatile=True)
                if torch.cuda.is_available():
                    X = X.cuda()

                outputs_batch = self.model(X).data.cpu().numpy()
                for output in outputs_batch:
                    yield softmax(output, axis=0)

                # enumerate() is 0-based, so batch_id + 1 is the number of
                # batches consumed so far; stop once `steps` have been used.
                if batch_id + 1 == steps:
                    break
        finally:
            # Restore train mode even if the consumer stops early or an
            # exception propagates, so the model is never left in eval mode.
            self.model.train()
 def set_model(self):
     """Instantiate ``self.model`` according to the configured encoder.

     An ``'encoder'`` of ``'from_scratch'`` builds a ``UNet`` from the whole
     ``model_params`` dict. Any other value is looked up in
     ``PRETRAINED_NETWORKS``. The model is then built from that entry's
     ``'model'`` factory and ``'model_config'`` kwargs, and the
     weight-initialization hook is replaced with a no-op.
     """
     model_params = self.architecture_config['model_params']
     encoder = model_params['encoder']
     if encoder != 'from_scratch':
         pretrained = PRETRAINED_NETWORKS[encoder]
         self.model = pretrained['model'](**pretrained['model_config'])
         # Presumably so the default initializer does not overwrite the
         # pretrained weights -- confirm.
         self._initialize_model_weights = lambda: None
         return
     self.model = UNet(**model_params)
Exemplo n.º 4
0
 def __init__(self, architecture_config, training_config, callbacks_config):
     """Set up a U-Net with a regularized Adam optimizer, loss and callbacks.

     Args:
         architecture_config: dict providing ``'model_params'`` (kwargs for
             ``UNet``), ``'regularizer_params'`` (kwargs for the weight
             regularizer) and ``'optimizer_params'`` (kwargs for ``optim.Adam``).
         training_config: passed through to the base-class constructor.
         callbacks_config: passed through to the base-class constructor; read
             back below as ``self.callbacks_config`` (presumably stored by the
             base class -- confirm).
     """
     super().__init__(architecture_config, training_config, callbacks_config)

     self.model = UNet(**architecture_config['model_params'])

     # The regularizer maps the model to the parameters (or parameter groups)
     # that Adam will optimize.
     self.weight_regularization = weight_regularization_unet
     regularizer_kwargs = architecture_config['regularizer_params']
     param_groups = self.weight_regularization(self.model, **regularizer_kwargs)
     adam_kwargs = architecture_config['optimizer_params']
     self.optimizer = optim.Adam(param_groups, **adam_kwargs)

     self.loss_function = segmentation_loss
     self.callbacks = build_callbacks(self.callbacks_config)