Example #1
    def conv_op_factory(self, in_channels, out_channels, part, index):

        if part == 'down':
            return torch.nn.Sequential(
                ConvELU2D(in_channels=in_channels,  out_channels=out_channels, kernel_size=3),
                ConvELU2D(in_channels=out_channels,  out_channels=out_channels, kernel_size=3)
            ), False
        elif part == 'bottom':
            return torch.nn.Sequential(
                ConvReLU2D(in_channels=in_channels,  out_channels=out_channels, kernel_size=3),
                ConvReLU2D(in_channels=out_channels,  out_channels=out_channels, kernel_size=3),
            ), False
        elif part == 'up':
            # are we in the very last block?
            if index  == 0:
                return torch.nn.Sequential(
                    ConvELU2D(in_channels=in_channels,  out_channels=out_channels, kernel_size=3),
                    Conv2D(in_channels=out_channels,  out_channels=out_channels, kernel_size=3)
                ), False
            else:
                return torch.nn.Sequential(
                    ConvELU2D(in_channels=in_channels,   out_channels=out_channels, kernel_size=3),
                    ConvReLU2D(in_channels=out_channels,  out_channels=out_channels, kernel_size=3)
                ), False
        else:
            raise RuntimeError("unknown part '{}': expected 'down', 'bottom' or 'up'".format(part))
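# For context: a factory like the one above is meant to live on a subclass of
# inferno's UNetBase, which calls conv_op_factory(in_channels, out_channels,
# part, index) once per 'down', 'bottom' and 'up' block. The sketch below is
# an assumption based on the inferno UNet tutorial; the class name
# MySimple2DUnet and the depth argument are illustrative, and the UNetBase
# import path and constructor signature may differ between inferno versions.
import torch
from inferno.extensions.models import UNetBase


class MySimple2DUnet(UNetBase):
    def __init__(self, in_channels, out_channels, depth=3, **kwargs):
        super(MySimple2DUnet, self).__init__(in_channels=in_channels,
                                             out_channels=out_channels,
                                             dim=2, depth=depth, **kwargs)

    def conv_op_factory(self, in_channels, out_channels, part, index):
        # the factory shown above goes here; it returns the conv op for the
        # block together with a boolean flag, exactly as in the snippet above
        raise NotImplementedError

# model = MySimple2DUnet(in_channels=1, out_channels=2, depth=3)
# prediction = model(torch.rand(1, 1, 64, 64))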
Example #2
    def __init__(self, in_channels, out_channels, activated):
        super(CheapConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if activated:
            self.convs = torch.nn.Sequential(
                ConvActivation(in_channels=in_channels,
                               out_channels=in_channels,
                               depthwise=True,
                               kernel_size=(3, 3),
                               activation='ReLU',
                               dim=2),
                ConvReLU2D(in_channels=in_channels,
                           out_channels=out_channels,
                           kernel_size=(1, 1)))
        else:
            self.convs = torch.nn.Sequential(
                ConvActivation(in_channels=in_channels,
                               out_channels=in_channels,
                               depthwise=True,
                               kernel_size=(3, 3),
                               activation='ReLU',
                               dim=2),
                Conv2D(in_channels=in_channels,
                       out_channels=out_channels,
                       kernel_size=(1, 1)))
def build_big_model(image_channels, pred_channels=1, no_sigm=False):
    if no_sigm:
        return torch.nn.Sequential(
            ConvReLU2D(in_channels=image_channels,
                       out_channels=8,
                       kernel_size=3),
            inf_model.ResBlockUNet(dim=2,
                                   in_channels=8,
                                   out_channels=pred_channels,
                                   activated=False),
        )
    else:
        return torch.nn.Sequential(
            ConvReLU2D(in_channels=image_channels,
                       out_channels=8,
                       kernel_size=3),
            inf_model.ResBlockUNet(dim=2,
                                   in_channels=8,
                                   out_channels=pred_channels,
                                   activated=False),
            torch.nn.Sigmoid()
        )
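
# Hedged usage sketch (not from the original): run a dummy batch through the
# builder above to sanity-check the output shape. Assumes build_big_model and
# its imports (ConvReLU2D, inf_model) are in scope; the 64x64 spatial size and
# single grayscale input channel are arbitrary choices.
import torch

model = build_big_model(image_channels=1, pred_channels=1, no_sigm=False)
with torch.no_grad():
    prediction = model(torch.rand(1, 1, 64, 64))
print(prediction.shape)  # (1, pred_channels, H, W)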
Example #4
    def __init__(self, in_channels, out_channels, activated):
        super(CheapConvBlock, self).__init__()
        self.activated = activated
        self.in_channels = in_channels
        self.out_channels = out_channels
        if in_channels != out_channels:
            self.start = ConvReLU2D(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=(1, 1))
        else:
            self.start = None
        self.conv_a = CheapConv(in_channels=out_channels,
                                out_channels=out_channels,
                                activated=True)
        self.conv_b = CheapConv(in_channels=out_channels,
                                out_channels=out_channels,
                                activated=False)
        self.activation = torch.nn.ReLU()
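
    # Hedged sketch (not part of the original snippet): a forward pass that
    # matches the members set up above, i.e. a residual block with an optional
    # 1x1 input projection and an optional final ReLU.
    def forward(self, x):
        x_input = x if self.start is None else self.start(x)
        x = self.conv_a(x_input)
        x = self.conv_b(x)
        x = x + x_input              # residual connection
        if self.activated:
            x = self.activation(x)
        return x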
Example #5
model = torch.nn.Sequential(
    ResBlockUNet(dim=2,
                 in_channels=image_channels,
                 out_channels=pred_channels,
                 activated=False),
    RemoveSingletonDimension(dim=1),
    torch.nn.Sigmoid()
)

##############################################################################
# While the model above will work in principle, it has some drawbacks.
# Within the UNet, the number of features is increased by a multiplicative
# factor while going down, the so-called gain. The default value for the gain is 2.
# Since we start with only a single channel, we could either increase the gain
# or use some convolutions to increase the number of channels
# before the UNet (a sketch of the gain variant follows the model below).
from inferno.extensions.layers import ConvReLU2D
model_a = torch.nn.Sequential(
    ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3),
    ResBlockUNet(dim=2,
                 in_channels=5,
                 out_channels=pred_channels,
                 activated=False,
                 res_block_kwargs=dict(batchnorm=True, size=2)),
    RemoveSingletonDimension(dim=1)
    # torch.nn.Sigmoid()
)
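
##############################################################################
# The other option mentioned above is to raise the gain instead of adding a
# channel-expanding convolution in front of the UNet. This is a hedged sketch:
# it assumes ResBlockUNet forwards a ``gain`` keyword to the underlying UNet
# base class, which may not hold for every inferno version.
model_b = torch.nn.Sequential(
    ResBlockUNet(dim=2,
                 in_channels=image_channels,
                 out_channels=pred_channels,
                 gain=3,
                 activated=False,
                 res_block_kwargs=dict(batchnorm=True, size=2)),
    RemoveSingletonDimension(dim=1)
)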

##############################################################################
# Training
# ----------------------------
# To train the UNet, we use inferno's Trainer class.
# Since we train many models later on in this example, we encapsulate
# the training in a function (see :ref:`sphx_glr_auto_examples_trainer.py` for details).
labelset_train = HDF5VolumeLoader(path='./stardistance.h5', path_in_h5_dataset='data',
                                  transforms=trans2, **yaml2dict('config_train.yml')['slicing_config_truth'])
trainset = Zip(imageset_train, labelset_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCHSIZE,
                                          shuffle=True, num_workers=2)
imageset_val = HDF5VolumeLoader(path='./val-volume.h5', path_in_h5_dataset='data',
                                transforms=trans, **yaml2dict('config_val.yml')['slicing_config'])
labelset_val = HDF5VolumeLoader(path='./stardistance_val.h5', path_in_h5_dataset='data',
                                transforms=trans2, **yaml2dict('config_val.yml')['slicing_config_truth'])
valset = Zip(imageset_val, labelset_val)
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCHSIZE,
                                        shuffle=True, num_workers=2)


net = torch.nn.Sequential(
    ConvReLU2D(in_channels=1, out_channels=3, kernel_size=3),
    UNet(in_channels=3, out_channels=N_DIRECTIONS, dim=2, final_activation='ReLU')
    )

trainer = Trainer(net)

trainer.bind_loader('train', trainloader)
trainer.bind_loader('validate', valloader)

trainer.save_to_directory('./checkpoints')
trainer.save_every((200, 'iterations'))
trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
                                       log_images_every='never'),
                     log_directory=LOG_DIRECTORY)
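
# Hedged sketch of the remaining trainer setup (the snippet above stops after
# binding loaders and the logger). The criterion, optimizer settings and epoch
# count below are illustrative choices, not taken from the original.
trainer.build_criterion(torch.nn.MSELoss())
trainer.build_optimizer('Adam', lr=1e-3)
trainer.set_max_num_epochs(10)
if torch.cuda.is_available():
    trainer.cuda()
trainer.fit()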