Esempio n. 1
0
    def get_unet(self, load_best=True):
        """
        Get the segmentation network branch as a standalone U-Net

        A fresh UNet2D is built with the hyperparameters of this module's encoder/decoder. The current encoder and
        decoder weights are then copied into it.

        :param load_best: flag that specifies whether the parameters of the best checkpoint (the path reported by the
                          trainer's checkpoint callback) should first be loaded into this module (True by default)
        :return: a U-Net module
        """

        # load parameters of best model
        if load_best:
            # Map the checkpoint to CPU first. load_state_dict copies every tensor into this module's existing
            # parameters (on whatever device they live), so nothing is lost. This also avoids a load failure when a
            # checkpoint written from GPU tensors is restored on a machine without that device.
            checkpoint = torch.load(self.trainer.checkpoint_callback.best_model_path, map_location='cpu')
            self.load_state_dict(checkpoint['state_dict'])

        # initialize new UNet2D model with the same architecture as this module's branch
        net = UNet2D(in_channels=self.encoder.in_channels,
                     coi=self.coi,
                     feature_maps=self.encoder.feature_maps,
                     levels=self.encoder.levels,
                     skip_connections=self.decoder.skip_connections,
                     norm=self.encoder.norm,
                     activation=self.encoder.activation,
                     dropout_enc=self.encoder.dropout)

        # copy the (possibly just restored) parameters into the standalone model
        net.encoder.load_state_dict(self.encoder.state_dict())
        net.decoder.load_state_dict(self.decoder.state_dict())

        return net
Esempio n. 2
0
    def get_unet(self):
        """
        Build a standalone U-Net from this module's segmentation branch

        The new UNet2D is configured from the encoder/decoder hyperparameters. It receives copies of the current
        encoder and decoder weights.

        :return: a U-Net module
        """
        enc, dec = self.encoder, self.decoder
        unet = UNet2D(in_channels=enc.in_channels,
                      coi=self.coi,
                      feature_maps=enc.feature_maps,
                      levels=enc.levels,
                      skip_connections=dec.skip_connections,
                      norm=enc.norm,
                      activation=enc.activation,
                      dropout_enc=enc.dropout)

        # transfer the current weights into the freshly built network (encoder first, then decoder)
        for source, target in ((enc, unet.encoder), (dec, unet.decoder)):
            target.load_state_dict(source.state_dict())

        return unet
Esempio n. 3
0
    def get_unet(self, tar=True):
        """
        Build a standalone U-Net from either the target or the source branch

        The architecture hyperparameters are read from self.encoder and self.decoder. Only the encoder weights depend
        on the chosen branch; the decoder weights always come from self.decoder.

        :param tar: if truthy, copy the target encoder (self.tar_encoder); otherwise copy the source encoder
                    (self.src_encoder)
        :return: a U-Net module
        """
        cfg = self.encoder
        unet = UNet2D(in_channels=cfg.in_channels, coi=self.coi, feature_maps=cfg.feature_maps,
                      levels=cfg.levels, skip_connections=self.decoder.skip_connections, norm=cfg.norm)

        branch_encoder = self.tar_encoder if tar else self.src_encoder
        unet.encoder.load_state_dict(branch_encoder.state_dict())
        unet.decoder.load_state_dict(self.decoder.state_dict())

        return unet
Esempio n. 4
0
    df['labels'],
    split_orientation=df['split-orientation'],
    split_location=df['split-location'],
    input_shape=input_shape,
    len_epoch=args.len_epoch,
    type=df['type'],
    train=False)
# wrap the source-domain datasets in data loaders
train_loader_src = DataLoader(train_src, batch_size=args.train_batch_size)
test_loader_src = DataLoader(test_src, batch_size=args.test_batch_size)

# ---- Build the network ----
print('[%s] Building the network' % (datetime.datetime.now()))
net = UNet2D(feature_maps=args.fm, levels=args.levels, norm=args.norm, activation=args.activation,
             coi=args.classes_of_interest)

# ---- Setup optimization for training ----
print('[%s] Setting up optimization for training' % (datetime.datetime.now()))
optimizer = optim.Adam(net.parameters(), lr=args.lr)
# learning rate is multiplied by args.gamma every args.step_size scheduler steps
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

# ---- Train the network ----
print('[%s] Starting training' % (datetime.datetime.now()))
net.train_net(train_loader_src,
Esempio n. 5
0
                                    input_shape=input_shape,
                                    len_epoch=args.len_epoch,
                                    type='pngseq',
                                    in_channels=args.in_channels,
                                    batch_size=args.test_batch_size,
                                    orientations=args.orientations)
# wrap the datasets in data loaders
train_loader = DataLoader(train, batch_size=args.train_batch_size)
test_loader = DataLoader(test, batch_size=args.test_batch_size)

# ---- Build the network ----
print_frm('Building the network')
net = UNet2D(in_channels=args.in_channels, feature_maps=args.fm, levels=args.levels,
             dropout_enc=args.dropout, dropout_dec=args.dropout,  # same dropout rate in both paths
             norm=args.norm, activation=args.activation, coi=args.classes_of_interest)

# ---- Setup optimization for training ----
print_frm('Setting up optimization for training')
optimizer = optim.Adam(net.parameters(), lr=args.lr)
# learning rate is multiplied by args.gamma every args.step_size scheduler steps
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

# ---- Train the network ----
print_frm('Starting training')