Example #1
0
def load_model(opts, n_classes=4):
    """Instantiate the segmentation network selected by ``opts``.

    ``opts.model`` picks the architecture ('unet', 'fcn', 'deeplab' or
    'deeplabv3+'). For every architecture except 'unet', ``opts.backbone``
    picks the encoder. The returned model has ``n_classes`` stored on it as
    an attribute.

    Raises:
        NotImplementedError: if the model type, or the backbone requested for
            that model type, is not one of the supported names.
    """
    architecture = opts.model

    if architecture == 'unet':
        model = UNet(n_classes=n_classes)
    else:
        # Factories are wrapped in lambdas so that a constructor name is only
        # resolved when its entry is actually selected.
        factories = {
            'fcn': {
                'resnet50': lambda: load_fcn_resnet50(n_classes),
                'resnet101': lambda: load_fcn_resnet101(n_classes),
            },
            'deeplab': {
                'resnet50': lambda: load_deeplab_resnet50(n_classes),
                'resnet101': lambda: load_deeplab_resnet101(n_classes),
            },
            'deeplabv3+': {
                'resnet101': lambda: DeepLabv3_plus_resnet(n_classes),
                'xception': lambda: DeepLabv3_plus_xception(n_classes),
            },
        }
        if architecture not in factories:
            raise NotImplementedError("Invalid model type specified")

        make_model = factories[architecture].get(opts.backbone)
        if make_model is None:
            raise NotImplementedError("Invalid backbone specified")
        model = make_model()

    model.n_classes = n_classes
    return model
Example #2
0
def main(_):
    """Build the UNet inside a TensorFlow session and train it from ``args``.

    The positional argument is ignored; every setting is read from the
    module-level ``args`` namespace.
    """
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with tf.Session(config=session_config) as sess:
        model = UNet(args.experiment_dir,
                     batch_size=args.batch_size,
                     experiment_id=args.experiment_id,
                     input_width=args.image_size,
                     output_width=args.image_size,
                     embedding_num=args.embedding_num,
                     embedding_dim=args.embedding_dim,
                     L1_penalty=args.L1_penalty)
        model.register_session(sess)

        # no_target_source is only passed when labels are flipped.
        build_kwargs = dict(is_training=True, inst_norm=args.inst_norm)
        if args.flip_labels:
            build_kwargs['no_target_source'] = True
        model.build_model(**build_kwargs)

        # --fine_tune is a comma-separated list of integer ids, or empty.
        if args.fine_tune:
            fine_tune_list = {int(token) for token in args.fine_tune.split(",")}
        else:
            fine_tune_list = None

        model.train(lr=args.lr,
                    epoch=args.epoch,
                    resume=args.resume,
                    schedule=args.schedule,
                    freeze_encoder=args.freeze_encoder,
                    fine_tune=fine_tune_list,
                    sample_steps=args.sample_steps,
                    checkpoint_steps=args.checkpoint_steps,
                    flip_labels=args.flip_labels)
Example #3
0
    def __init__(self, seq_length, color_channels, unet_path="pretrained/unet.mdl",
                 discrim_path="pretrained/dicrim.mdl",
                 facenet_path="pretrained/facenet.mdl",
                 vgg_path="",
                 embedding_size=1000,
                 unet_depth=3,
                 unet_filts=32,
                 facenet_filts=32,
                 resnet=18):
        """Create the UNet, discriminator and FaceNet models with their losses and optimizers.

        Saved weights are restored for each network whose path points at an
        existing file. ``seq_length`` is accepted but not used here.

        NOTE: ``self.vgg_loss_network`` is only defined when ``vgg_path`` is an
        existing file.
        """
        self.color_channels = color_channels
        self.margin = 0.5
        self.writer = SummaryWriter(log_dir="logs")

        self.unet_path = unet_path
        self.discrim_path = discrim_path
        self.facenet_path = facenet_path

        self.unet = UNet(in_channels=color_channels, out_channels=color_channels,
                         depth=unet_depth,
                         start_filts=unet_filts,
                         up_mode="upsample",
                         merge_mode='concat').to(device)

        # The discriminator and the face network share one architecture.
        face_model_kwargs = dict(embedding_size=embedding_size,
                                 start_filts=facenet_filts,
                                 in_channels=color_channels,
                                 resnet=resnet,
                                 pretrained=False)
        self.discrim = FaceNetModel(**face_model_kwargs).to(device)
        self.facenet = FaceNetModel(**face_model_kwargs).to(device)

        # Restore whichever weight files are present on disk.
        restorable = (
            ("unet", self.unet, unet_path),
            ("discrim", self.discrim, discrim_path),
            ("facenet", self.facenet, facenet_path),
        )
        for label, network, weights_file in restorable:
            if os.path.isfile(weights_file):
                network.load_state_dict(torch.load(weights_file))
                print(f"{label} loaded")

        if os.path.isfile(vgg_path):
            self.vgg_loss_network = LossNetwork(vgg_face_dag(vgg_path)).to(device)
            self.vgg_loss_network.eval()
            print("vgg loaded")

        self.mse_loss_function = nn.MSELoss().to(device)
        self.discrim_loss_function = nn.BCELoss().to(device)
        self.triplet_loss_function = TripletLoss(margin=self.margin)

        adam_betas = (0.9, 0.999)
        self.unet_optimizer = torch.optim.Adam(self.unet.parameters(), betas=adam_betas)
        self.discrim_optimizer = torch.optim.Adam(self.discrim.parameters(), betas=adam_betas)
        self.facenet_optimizer = torch.optim.Adam(self.facenet.parameters(), betas=adam_betas)
Example #4
0
def train():
    """Train the UNet for 500 epochs, running a 'train' then a 'val' phase each epoch.

    After every phase the mean batch loss is printed, results are visualised
    and the per-phase loss history is passed to ``tensorboard``. The learning
    rate scheduler is stepped once per epoch, after the 'train' phase.
    """
    # Init data
    train_dataset, val_dataset = prepare_datasets()
    loaders = dict(
        train=DataLoader(train_dataset, batch_size=10, shuffle=True),
        val=DataLoader(val_dataset, batch_size=10, shuffle=True),
    )

    # Init model
    model = UNet().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer,
                                                       gamma=0.984)
    loss_fn = nn.BCELoss()

    # Created once, outside the loops, so each phase's history accumulates
    # across epochs instead of being reset to a single entry every phase.
    epoch_losses = dict(train=[], val=[])

    epochs = 500
    for epoch in range(epochs):
        for phase in ('train', 'val'):
            is_train = phase == 'train'
            model = model.train() if is_train else model.eval()

            loader = loaders[phase]
            running_loss = []

            # Scoped as a context manager so the previous global grad mode is
            # restored afterwards (it used to stay disabled after the last
            # 'val' phase, even once this function had returned).
            with torch.set_grad_enabled(is_train):
                for imgs, masks in loader:
                    imgs = imgs.cuda()
                    masks = masks.cuda()

                    outputs = model(imgs)
                    loss = loss_fn(outputs, masks)
                    running_loss.append(loss.item())

                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # End of epoch for this phase
                mean_loss = np.mean(running_loss)
                print(f'{epoch}) {phase} loss: {mean_loss}')
                visualize_results(loader, model, epoch, phase)

            epoch_losses[phase].append(mean_loss)
            tensorboard(epoch_losses[phase], phase)

            if is_train:
                scheduler.step()
Example #5
0
def _build_network(args, input_shape, enable_decoder):
    """Import and instantiate the network named by ``args.net``.

    Each network module is imported lazily, only when that network is chosen.
    'unet' and 'tiramisu' are wrapped in a one-element list; the capsule
    network constructors' return values are passed through unchanged.

    Raises:
        Exception: if ``args.net`` is not a known network type.
    """
    if args.net == 'unet':
        from models.unet import UNet
        return [UNet(input_shape)]
    if args.net == 'tiramisu':
        from models.densenets import DenseNetFCN
        return [DenseNetFCN(input_shape)]
    if args.net == 'segcapsr1':
        from segcapsnet.capsnet import CapsNetR1
        return CapsNetR1(input_shape)
    if args.net == 'segcapsr3':
        from segcapsnet.capsnet import CapsNetR3
        return CapsNetR3(input_shape, args.num_class, enable_decoder)
    if args.net == 'capsbasic':
        from segcapsnet.capsnet import CapsNetBasic
        return CapsNetBasic(input_shape)
    raise Exception('Unknown network type specified: {}'.format(args.net))


def create_model(args, input_shape, enable_decoder=True):
    """Create the network selected by ``args.net``.

    With ``args.gpus <= 1`` the network is built directly. Otherwise it is
    built inside a ``tf.device("/cpu:0")`` scope. ``enable_decoder`` is only
    forwarded to the 'segcapsr3' network.

    Returns:
        Whatever the selected constructor yields: a one-element list for
        'unet' / 'tiramisu', the constructor's own return value for the
        capsule networks.

    Raises:
        Exception: if ``args.net`` is not a known network type.
    """
    # CPU or single GPU
    if args.gpus <= 1:
        return _build_network(args, input_shape, enable_decoder)

    # Multiple GPUs
    with tf.device("/cpu:0"):
        return _build_network(args, input_shape, enable_decoder)
def visualize_voc_unet():
    """Load the saved 'voc_unet' model and visualise it on the Pascal VOC 'val' split.

    Output goes to the '<model name>_visualization/' directory.
    """
    from data.voc2012_loader_segmentation import PascalVOCSegmentation
    from torch.utils.data.dataloader import DataLoader
    from visualize.visualize import visualize
    from models.unet import UNet

    val_set = PascalVOCSegmentation('val')
    val_loader = DataLoader(val_set, batch_size=16, shuffle=False, num_workers=0)

    net = UNet(outputs=21, name='voc_unet')
    net.load()

    output_dir = net.name + '_visualization/'
    visualize(net, val_loader, output_dir)
Example #7
0
def run():
    """Build the UNet from ``CFG``, load its data and evaluate it.

    The training step is currently commented out.
    """
    unet = UNet(CFG)
    unet.load_data()
    unet.build()
    # unet.train()  # training is disabled in this entry point
    unet.evaluate()
def validate(state_dict_path, use_gpu, device):
    """Evaluate a saved UNet on the BRATS2018 validation split.

    Args:
        state_dict_path: path of the saved ``state_dict`` to load.
        use_gpu: when false, the checkpoint is mapped onto the CPU while
            loading; otherwise it is mapped onto ``device``.
        device: device the model and each batch are moved to.

    Returns:
        The mean dice score: the batch-size-weighted sum of per-batch dice
        scores divided by the dataset length. Previously this was only
        printed.
    """
    model = UNet(n_channels=1, n_classes=2)
    model.load_state_dict(torch.load(state_dict_path, map_location='cpu' if not use_gpu else device))
    model.to(device)
    # Switch to evaluation mode once, rather than on every batch.
    model.eval()

    val_transforms = transforms.Compose([
        ToTensor(),
        NormalizeBRATS()])

    BraTS_val_ds = BRATS2018('./BRATS2018',
                             data_set='val',
                             seg_type='et',
                             scan_type='t1ce',
                             transform=val_transforms)

    data_loader = DataLoader(BraTS_val_ds, batch_size=2, shuffle=False, num_workers=0)

    running_dice_score = 0.

    with torch.no_grad():
        for imgs, targets in data_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)

            outputs = model(imgs)
            preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)

            # Weight by batch size so a smaller final batch is not over-counted.
            running_dice_score += dice_score(preds, targets) * targets.size(0)
            print('running dice score: {:.6f}'.format(running_dice_score))

    dice = running_dice_score / len(BraTS_val_ds)
    print('mean dice score of the validating set: {:.6f}'.format(dice))
    return dice
Example #9
0
class EventGANBase(object):
    """Wrapper around a pretrained EventGAN generator UNet.

    The generator weights are read from the "gen" entry of the latest
    checkpoint found in ``options.checkpoint_dir``.
    """

    def __init__(self, options):
        """Build the generator and restore its weights from the latest checkpoint.

        Reads ``n_image_channels``, ``n_time_bins``, ``sn`` and
        ``checkpoint_dir`` from ``options``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.generator = UNet(num_input_channels=2*options.n_image_channels,
                              num_output_channels=options.n_time_bins * 2,
                              skip_type='concat',
                              activation='relu',
                              num_encoders=4,
                              base_num_channels=32,
                              num_residual_blocks=2,
                              norm='BN',
                              use_upsample_conv=True,
                              with_activation=True,
                              sn=options.sn,
                              multi=False)
        latest_checkpoint = get_latest_checkpoint(options.checkpoint_dir)
        # Remap stored tensors onto the selected device; without this a
        # checkpoint saved on a GPU cannot be loaded on a CPU-only host, even
        # though the CPU fallback is chosen above.
        checkpoint = torch.load(latest_checkpoint, map_location=self.device)
        self.generator.load_state_dict(checkpoint["gen"])
        self.generator.to(self.device)

    def forward(self, images, is_train=False):
        """Run the generator on ``images`` and return its output.

        A 3-D input (2xHxW) gets a leading batch axis; the input must then be
        Bx2xHxW. With ``is_train=False`` the generator runs in eval mode
        without gradients and is switched back to train mode afterwards.
        """
        if len(images.shape) == 3:
            images = images[None, ...]
        assert len(images.shape) == 4 and images.shape[1] == 2, \
            "Input images must be either 2xHxW or Bx2xHxW."
        if not is_train:
            with torch.no_grad():
                self.generator.eval()
                event_volume = self.generator(images)
            self.generator.train()
        else:
            event_volume = self.generator(images)

        return event_volume
Example #10
0
def get_model(model_name):
    """Return a new instance of the model registered under ``model_name``.

    Supported names are 'vgg16', 'unet', 'deeplab', 'affinitynet' and
    'wasscam'. The module for each model is imported only when that model is
    requested.

    Raises:
        ValueError: if ``model_name`` is not a supported name. The previous
            code raised an undefined ``Error`` name here, which surfaced as a
            NameError.
    """
    if model_name == 'vgg16':
        from models.vgg16 import Vgg16GAP
        return Vgg16GAP(name="vgg16")

    if model_name == 'unet':
        from models.unet import UNet
        return UNet()

    if model_name == 'deeplab':
        from models.deeplab import DeepLab
        return DeepLab(name="deeplab")

    if model_name == 'affinitynet':
        from models.aff_net import AffNet
        return AffNet(name="affinitynet")

    if model_name == 'wasscam':
        from models.wass import WASS
        return WASS()

    raise ValueError('Model name has no implementation')
Example #11
0
def get_net(input_depth, NET_TYPE, pad, upsample_mode, n_channels=3, act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4, num_scales=5, downsample_mode='stride'):
    """Build the network selected by ``NET_TYPE``.

    Supported types are 'ResNet', 'skip', 'texture_nets', 'UNet' and
    'identity'. For 'skip', each of ``skip_n33d`` / ``skip_n33u`` /
    ``skip_n11`` may be an int (repeated ``num_scales`` times) or an explicit
    per-scale sequence.

    NOTE(review): only the 'skip' branch uses ``n_channels``; 'ResNet' and
    'UNet' hard-code 3 output channels. Confirm whether that is intended.

    Raises:
        AssertionError: if ``NET_TYPE`` is unknown, or if 'identity' is
            requested with ``input_depth != 3``. These are raised explicitly
            so the checks still run under ``python -O``.
    """
    def per_scale(value):
        # An int means "use the same channel count at every scale".
        return [value] * num_scales if isinstance(value, int) else value

    if NET_TYPE == 'ResNet':
        # TODO
        net = ResNet(input_depth, 3, 10, 16, 1, nn.BatchNorm2d, False)
    elif NET_TYPE == 'skip':
        net = skip(input_depth, n_channels,
                   num_channels_down=per_scale(skip_n33d),
                   num_channels_up=per_scale(skip_n33u),
                   num_channels_skip=per_scale(skip_n11),
                   upsample_mode=upsample_mode, downsample_mode=downsample_mode,
                   need_sigmoid=True, need_bias=True, pad=pad, act_fun=act_fun)
    elif NET_TYPE == 'texture_nets':
        net = get_texture_nets(inp=input_depth, ratios=[32, 16, 8, 4, 2, 1], fill_noise=False, pad=pad)
    elif NET_TYPE == 'UNet':
        net = UNet(num_input_channels=input_depth, num_output_channels=3,
                   feature_scale=4, more_layers=0, concat_x=False,
                   upsample_mode=upsample_mode, pad=pad, norm_layer=nn.BatchNorm2d, need_sigmoid=True, need_bias=True)
    elif NET_TYPE == 'identity':
        if input_depth != 3:
            raise AssertionError("'identity' network requires input_depth == 3")
        net = nn.Sequential()
    else:
        # Previously `assert False`, which is stripped under -O and let the
        # function fall through to an unbound `net`.
        raise AssertionError("Unknown NET_TYPE: {!r}".format(NET_TYPE))

    return net
Example #12
0
	def get_model(self, dataset):
		"""Create the network named by ``self.args.model``, on the GPU and in train mode.

		When ``self.args.resume`` is set, saved state is restored through
		``setup.load_save``. ``self.args.start_epoch`` is then set to
		``self.args.resume_epoch``, or to 0 when not resuming.

		Raises:
			ValueError: if ``self.args.model`` is not a supported name.
		"""
		cfg = self.args
		name = cfg.model

		# GCN variants built from (num_classes, img_size, k). Classes are
		# wrapped in lambdas so a name is only resolved when selected.
		spatial_gcn_classes = {
			"GCN": lambda: GCN,
			"GCN_DENSENET": lambda: GCN_DENSENET,
			"GCN_DECONV": lambda: GCN_DECONV,
			"GCN_PSP": lambda: GCN_PSP,
			"GCN_COMB": lambda: GCN_COMBINED,
		}

		if name == "UNet":
			model = UNet(dataset.num_classes).cuda()
		elif name == "GCN_RESNEXT":
			# The ResNeXt variant takes no image size.
			model = GCN_RESNEXT(dataset.num_classes, k=cfg.K).cuda()
		elif name in spatial_gcn_classes:
			net_cls = spatial_gcn_classes[name]()
			model = net_cls(dataset.num_classes, dataset.img_size, k=cfg.K).cuda()
		else:
			raise ValueError("Invalid model arg.")

		if cfg.resume:
			setup.load_save(model, cfg)
			cfg.start_epoch = cfg.resume_epoch
		else:
			cfg.start_epoch = 0

		model.train()
		return model
    def __init__(self, in_channels=12, use_model='unet', use_d8=False,
                 learning_rate=0.02, adam_epsilon=1e-8, **kwargs):
        """Save the hyperparameters and build the segmentation model named by ``use_model``.

        ``use_model`` is matched case-insensitively. Model-specific settings
        are read from ``self.hparams``, and only for the model that is
        actually selected. When ``use_d8`` is true an extra embedding,
        ``self.d8_emb``, is created.

        Raises:
            Exception: if ``use_model`` is not a known model name.
        """
        super(DrainageNetworkExtractor, self).__init__()
        self.save_hyperparameters()

        model_key = use_model.lower()

        # Lazy builders: constructors and hparams are only touched for the
        # selected model.
        builders = {
            'unet': lambda: UNet(
                n_channels=in_channels, n_classes=12, bilinear=self.hparams.bilinear),
            'lhn_unet': lambda: LHNUNet(
                n_channels=in_channels, n_classes=12,
                n_classes_l1=self.hparams.n_classes_l1, n_classes_l2=self.hparams.n_classes_l2,
                n_classes_l3=self.hparams.n_classes_l3, n_classes_l4=self.hparams.n_classes_l4),
            'deep_lab': lambda: DeepLab(
                backbone=self.hparams.backbone, in_channels=in_channels, num_classes=12,
                sync_bn=self.hparams.sync_bn, freeze_bn=self.hparams.freeze_bn,
                output_stride=self.hparams.output_stride),
            'modsegnet': lambda: ModSegNet(
                num_classes=12, n_init_features=in_channels, drop_rate=self.hparams.drop_rate),
            'segnet': lambda: SegNet(
                num_classes=12, n_init_features=in_channels, drop_rate=self.hparams.drop_rate,
                use_kriging_loss=self.hparams.use_kriging_loss),
            'aspp_segnet': lambda: ASPPSegNet(
                num_classes=12, n_init_features=in_channels,
                use_kriging_loss=self.hparams.use_kriging_loss),
            'sp_segnet': lambda: SPSegNet(
                num_classes=12, n_init_features=in_channels),
            'dl_segnet': lambda: DLSegNet(
                num_classes=12, n_init_features=in_channels,
                drop_rate=self.hparams.drop_rate),
        }

        build_model = builders.get(model_key)
        if build_model is None:
            raise Exception(f"{use_model} is not implemented")
        self.model = build_model()

        if use_d8:
            self.d8_emb = nn.Embedding(9, 3, max_norm=1)
    def add_model_specific_args(parent_parser):
        """Extend ``parent_parser`` with the extractor's options plus those of the chosen model.

        The command line is pre-parsed with ``parse_known_args`` to find
        ``--use-model``; the matching model class then adds its own
        arguments. An unrecognised model name leaves the parser with the
        common options only.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        for switch in ('--use-coord', '--use-d8', '--use-slope', '--use-curvature'):
            parser.add_argument(switch, action='store_true')
        parser.add_argument('--in-channels', type=int, default=12)
        parser.add_argument('--use-model', type=str, default='unet')
        parser.add_argument('--learning-rate', type=float, default=0.02)
        parser.add_argument('--adam-epsilon', type=float, default=1e-8)
        parser.add_argument('--use-kriging-loss', action='store_true')
        parser.fromfile_prefix_chars = "@"

        known_args, _unparsed = parser.parse_known_args()
        model_key = known_args.use_model.lower()

        # Lambdas keep the class lookup lazy: only the selected model's class
        # name is resolved.
        extenders = {
            'unet': lambda p: UNet.add_model_specific_args(p),
            'lhn_unet': lambda p: LHNUNet.add_model_specific_args(p),
            'deep_lab': lambda p: DeepLab.add_model_specific_args(p),
            'modsegnet': lambda p: ModSegNet.add_model_specific_args(p),
            'segnet': lambda p: SegNet.add_model_specific_args(p),
            'aspp_segnet': lambda p: ASPPSegNet.add_model_specific_args(p),
            'sp_segnet': lambda p: SPSegNet.add_model_specific_args(p),
            'dl_segnet': lambda p: DLSegNet.add_model_specific_args(p),
        }
        extend = extenders.get(model_key)
        if extend is not None:
            parser = extend(parser)

        return parser
    def init_fn(self):
        """Build the cycle UNet, its optimizer, the image loss and the datasets.

        The channel counts depend on ``self.options.model``:
        'flow' uses ``2 * n_time_bins`` inputs and 2 outputs; 'recons' uses
        ``1 + n_image_channels`` inputs and ``n_image_channels`` outputs.

        Raises:
            ValueError: if ``self.options.model`` is neither 'flow' nor
                'recons'.
        """
        if self.options.model == 'flow':
            num_input_channels = self.options.n_time_bins * 2
            num_output_channels = 2
        elif self.options.model == 'recons':
            # For the reconstruction model, we sum the event volume across the time dimension, so
            # that the network only sees a single channel event input, plus the prev image.
            num_input_channels = 1 + self.options.n_image_channels
            num_output_channels = self.options.n_image_channels
        else:
            # The literal braces are doubled: unescaped, str.format treated
            # "{EventGAN, flow, recons}" as a field name and raised KeyError
            # instead of this ValueError.
            raise ValueError(
                "Class was initialized with an invalid model {}"
                ", only {{EventGAN, flow, recons}} are supported.".format(
                    self.options.model))

        self.cycle_unet = UNet(num_input_channels=num_input_channels,
                               num_output_channels=num_output_channels,
                               skip_type='concat',
                               activation='tanh',
                               num_encoders=4,
                               base_num_channels=32,
                               num_residual_blocks=2,
                               norm='BN',
                               use_upsample_conv=True,
                               multi=True)

        self.models_dict = {"model": self.cycle_unet}
        model_params = self.cycle_unet.parameters()

        optimizer = radam.RAdam(list(model_params),
                                lr=self.options.lrc,
                                weight_decay=self.options.wd,
                                betas=(self.options.lr_decay, 0.999))

        # Image loss is L1 minus SSIM.
        self.ssim = pytorch_ssim.SSIM()
        self.l1 = nn.L1Loss(reduction="mean")
        self.image_loss = lambda x, y: self.l1(x, y) - self.ssim(x, y)

        self.optimizers_dict = {"optimizer": optimizer}

        self.train_ds, self.train_sampler = event_loader.get_and_concat_datasets(
            self.options.train_file, self.options, train=True)
        self.validation_ds, self.validation_sampler = event_loader.get_and_concat_datasets(
            self.options.validation_file, self.options, train=False)

        self.cdl_kwargs["collate_fn"] = event_utils.none_safe_collate
        self.cdl_kwargs["sampler"] = self.train_sampler
Example #16
0
    def _decomposer(self):
        """
        Build the image decomposer: a UNet whose bottleneck feeds a small
        regressor for the non-spatial vector z, and whose upsampling path
        produces the spatial anatomy output (binary myocardium mask).
        :return a Keras model named 'Decomposer' with outputs [anatomy, z]
        """
        image_in = Input(self.conf.input_shape)

        unet = UNet(self.conf.input_shape, residual=False)
        unet.unet_bottleneck(unet.unet_downsample(image_in))

        # Z regressor: two conv/BN/LeakyReLU stages on the bottleneck,
        # followed by two dense layers.
        z = unet.bottleneck
        for n_filters in (256, 64):
            z = Conv2D(n_filters, 3, strides=1, padding='same')(z)
            z = BatchNormalization()(z)
            z = LeakyReLU()(z)
        z = Flatten()(z)
        z = Dense(32)(z)
        z = LeakyReLU()(z)
        z = Dense(16, activation='sigmoid')(z)

        # Anatomy branch: the UNet decoder applied to the same bottleneck.
        anatomy = unet.out(unet.unet_upsample(unet.bottleneck))

        decomposer = Model(inputs=image_in, outputs=[anatomy, z], name='Decomposer')
        log.info('Decomposer')
        decomposer.summary(print_fn=log.info)
        return decomposer
def build(conf, name='Enc_Anatomy'):
    """
    Build a UNet based encoder to extract anatomical information from the image.

    The UNet output goes through a 1x1 softmax convolution with
    ``conf.out_channels`` filters and, when ``conf.rounding`` is set, a
    ``Rounding`` layer. Returns the Keras model, named ``name``.
    """
    encoder = UNet(conf)
    encoder.input = Input(shape=conf.input_shape)

    # downsample -> bottleneck -> upsample
    downsampled = encoder.unet_downsample(encoder.input, encoder.normalise)
    encoder.unet_bottleneck(downsampled, encoder.normalise)
    upsampled = encoder.unet_upsample(encoder.bottleneck, encoder.normalise)

    anatomy_head = Conv2D(conf.out_channels, 1, padding='same',
                          activation='softmax', name='conv_anatomy')
    anatomy = anatomy_head(upsampled)
    if conf.rounding:
        anatomy = Rounding()(anatomy)

    encoder_model = Model(inputs=encoder.input, outputs=anatomy, name=name)
    log.info('Enc_Anatomy')
    encoder_model.summary(print_fn=log.info)
    return encoder_model
Example #18
0
def load_finetuned_model(args, baseline_model):
    """Create the augmentation and reweighting networks and optionally restore a checkpoint.

    :param args: namespace; ``args.load_finetune_checkpoint`` is the path of
        the checkpoint to restore, or a falsy value to skip restoring.
    :param baseline_model: the elementary model; its weights are restored
        from the same checkpoint when one is given.
    :return: ``(augment_net, reweighting_net, baseline_model)``, all moved to
        CUDA and set to train mode.
    """
    # augment_net = Net(0, 0.0, 32, 3, 0.0, num_classes=32**2 * 3, do_res=True)
    # TODO(PV): Initialize UNet properly
    # TODO (JON): DEPTH 1 WORKED WELL.  Changed upconv to upsample.  Use a wf of 2.
    augment_net = UNet(in_channels=3, n_classes=3, depth=1, wf=2, padding=True, batch_norm=False,
                       do_noise_channel=True,
                       up_mode='upsample', use_identity_residual=True)

    # This network outputs scalar weights to be applied element-wise to the per-example losses
    from models.simple_models import CNN, Net
    imsize, in_channel = 32, 3
    reweighting_net = Net(0, 0.0, imsize, in_channel, 0.0, num_classes=1)
    # resnet_cifar.resnet20(num_classes=1)

    checkpoint_path = args.load_finetune_checkpoint
    if checkpoint_path:
        saved = torch.load(checkpoint_path)
        baseline_model.load_state_dict(saved['elementary_model_state_dict'])
        augment_net.load_state_dict(saved['augment_model_state_dict'])
        try:
            reweighting_net.load_state_dict(saved['reweighting_model_state_dict'])
        except KeyError:
            # Best effort: a checkpoint without reweighting weights is accepted.
            pass

    networks = tuple(net.cuda() for net in (augment_net, reweighting_net, baseline_model))
    for net in networks:
        net.train()
    return networks
Example #19
0
    def __init__(self, img_size, hidden_size):
        """Create the combiner's UNet (21 input channels, 1 output channel).

        ``img_size`` and ``hidden_size`` are accepted but not used here.
        """
        super(Combiner, self).__init__()

        unet_config = {
            'in_channels': 21,
            'out_channels': 1,
            'depth': 3,
            'start_filts': 64,
            'up_mode': "upsample",
            'merge_mode': 'concat',
        }
        self.unet = UNet(**unet_config)
Example #20
0
 def __init__(self, flownet_backbone):
     """Create the generator, discriminator and optical-flow network.

     ``flownet_backbone == '2sd'`` selects ``FlowNet2SD``; any other value
     selects ``lite_flow.Network``.
     """
     super(convAE, self).__init__()
     self.generator = UNet(input_channels=12, output_channel=3)
     self.discriminator = PixelDiscriminator(input_nc=3)
     self.flownet_backbone = flownet_backbone
     self.flow_net = FlowNet2SD() if flownet_backbone == '2sd' else lite_flow.Network()
Example #21
0
 def __init__(self, options):
     """Build the generator UNet and restore its weights from the latest checkpoint.

     Reads ``n_image_channels``, ``n_time_bins``, ``sn`` and
     ``checkpoint_dir`` from ``options``. The weights come from the "gen"
     entry of the checkpoint.
     """
     self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     self.generator = UNet(num_input_channels=2*options.n_image_channels,
                           num_output_channels=options.n_time_bins * 2,
                           skip_type='concat',
                           activation='relu',
                           num_encoders=4,
                           base_num_channels=32,
                           num_residual_blocks=2,
                           norm='BN',
                           use_upsample_conv=True,
                           with_activation=True,
                           sn=options.sn,
                           multi=False)
     latest_checkpoint = get_latest_checkpoint(options.checkpoint_dir)
     # Remap stored tensors onto the selected device; without this a
     # checkpoint saved on a GPU cannot be loaded on a CPU-only host, even
     # though the CPU fallback is chosen above.
     checkpoint = torch.load(latest_checkpoint, map_location=self.device)
     self.generator.load_state_dict(checkpoint["gen"])
     self.generator.to(self.device)
    def __init__(self, path_to_shape_net_weights='', n_classes=15):
        """Create the segmentation UNet, the shape UNet and a channel-wise softmax.

        When ``path_to_shape_net_weights`` is non-empty, the shape UNet's
        weights are loaded from that file.

        NOTE(review): ``n_classes`` is not used; the shape UNet's channel
        count is hard-coded to 15. Confirm whether they should be linked.
        """
        super(SH_UNet, self).__init__()

        self.unet = UNet((3, 512, 512))
        self.shapeUNet = ShapeUNet((15, 512, 512))
        self.softmax = nn.Softmax(dim=1)

        if not path_to_shape_net_weights:
            return
        shape_net_state = torch.load(path_to_shape_net_weights)
        self.shapeUNet.load_state_dict(shape_net_state)
Example #23
0
    def _compile(self):
        """
        Compiles model (architecture, loss function, optimizers, etc.).
        Initialises the network, the loss function, the optimizer, etc.
        """

        print('Noise2Noise: Learning Image Restoration without Clean Data (Lethinen et al., 2018)')

        # Model: 3x3=9 input channels for Monte Carlo (it uses 3 HDR buffers),
        # otherwise 3. (Original note: the Monte Carlo related code has been
        # removed.)
        self.is_mc = self.p.noise_type == 'mc'
        self.model = UNet(in_channels=9 if self.is_mc else 3)

        # In training mode, also set up the optimizer, scheduler and loss.
        if self.trainable:
            adam_betas, adam_eps = self.p.adam[:2], self.p.adam[2]
            self.optim = Adam(self.model.parameters(),
                              lr=self.p.learning_rate,
                              betas=adam_betas,
                              eps=adam_eps)

            # Learning rate adjustment
            self.scheduler = lr_scheduler.ReduceLROnPlateau(
                self.optim, patience=self.p.nb_epochs/4, factor=0.5, verbose=True)

            # Loss function: 'hdr' (Monte Carlo only), 'l2', otherwise L1.
            loss_name = self.p.loss
            if loss_name == 'hdr':
                assert self.is_mc, 'Using HDR loss on non Monte Carlo images'
                self.loss = HDRLoss()
            elif loss_name == 'l2':
                self.loss = nn.MSELoss()
            else:
                self.loss = nn.L1Loss()

        # CUDA support
        self.use_cuda = torch.cuda.is_available() and self.p.cuda
        if self.use_cuda:
            self.model = self.model.cuda()
            if self.trainable:
                self.loss = self.loss.cuda()
    def build(self, conf):
        """Build two anatomy encoders, one per modality, sharing decoder and output layer.

        Each encoder has its own UNet downsampling path. Both go through
        ``self.evaluate_decoder`` and the same 1x1 softmax convolution, with
        an optional ``Rounding`` layer when ``conf.rounding`` is set.

        Returns:
            ``[encoder1, encoder2]``: Keras models named
            'Enc_Anatomy_<modality>' for ``self.modalities[0]`` and ``[1]``.
        """
        # One UNet encoder (input + downsampling path) per modality.
        unets, bottoms = [], []
        for _ in range(2):
            unet = UNet(conf)
            unet.input = Input(shape=conf.input_shape)
            bottoms.append(unet.unet_downsample(unet.input, unet.normalise))
            unets.append(unet)

        self.build_decoder(conf)

        # The deepest skip connection only exists with more than 3 downsamplings.
        anatomy_features = []
        for unet, bottom in zip(unets, bottoms):
            deepest_skip = unet.d_l3 if conf.downsample > 3 else None
            anatomy_features.append(
                self.evaluate_decoder(conf, bottom, deepest_skip,
                                      unet.d_l2, unet.d_l1, unet.d_l0))

        # Shared output layer applied to both decoders.
        shared_head = Conv2D(conf.out_channels, 1, padding='same',
                             activation='softmax', name='conv_anatomy')
        anatomies = [shared_head(features) for features in anatomy_features]

        if conf.rounding:
            anatomies = [Rounding()(anatomy) for anatomy in anatomies]

        return [
            Model(inputs=unet.input,
                  outputs=anatomies[idx],
                  name='Enc_Anatomy_%s' % self.modalities[idx])
            for idx, unet in enumerate(unets)
        ]
 def build_unet(self, in_channels, n_class, kernels, strides):
     """Create a UNet from the given topology plus the settings in ``self.args``.

     ``norm``, ``negative_slope``, ``deep_supervision`` and ``dim`` are read
     from ``self.args``.
     """
     cfg = self.args
     unet_kwargs = dict(
         in_channels=in_channels,
         n_class=n_class,
         kernels=kernels,
         strides=strides,
         normalization_layer=cfg.norm,
         negative_slope=cfg.negative_slope,
         deep_supervision=cfg.deep_supervision,
         dimension=cfg.dim,
     )
     return UNet(**unet_kwargs)
Example #26
0
def train():
    """Train a ``UNet`` one epoch at a time and periodically save its weights.

    Every setting is read from the module-level ``cfg`` object: the model
    input shape, compile options, generator arguments, number of epochs and
    the checkpoint directory. When ``cfg.validate`` is truthy a second
    generator is built and passed to each fit call as validation data.

    Weights are written to ``cfg.save_weights_path`` as ``model.<epoch>``
    (zero-based epoch index) after every ``cfg.epochs_save`` completed
    epochs. A non-positive or falsy ``cfg.epochs_save`` disables saving.
    """
    model = UNet(cfg.input_shape)

    # Compile the model and print its summary.
    model.compile(optimizer=cfg.optimizer, loss=cfg.loss, metrics=cfg.metrics)
    print_summary(model=model)

    # Training data generator.
    G1 = imageSegmentationGenerator(cfg.train_images, cfg.train_annotations,
                                    cfg.train_batch_size, cfg.n_classes,
                                    cfg.input_shape[0], cfg.input_shape[1],
                                    cfg.output_shape[0], cfg.output_shape[1])

    # Keyword arguments shared by every per-epoch fit call. ``epochs=1``
    # because the outer loop below drives the epoch count so that weights
    # can be saved between epochs.
    fit_kwargs = dict(
        steps_per_epoch=cfg.train_steps_per_epoch,
        workers=cfg.workers,
        epochs=1,
        verbose=1,
        use_multiprocessing=cfg.use_multiprocessing,
    )

    # Validation data generator, only built when validation is enabled.
    if cfg.validate:
        G2 = imageSegmentationGenerator(cfg.val_images, cfg.val_annotations,
                                        cfg.val_batch_size, cfg.n_classes,
                                        cfg.input_shape[0], cfg.input_shape[1],
                                        cfg.output_shape[0],
                                        cfg.output_shape[1])
        fit_kwargs.update(validation_data=G2,
                          validation_steps=cfg.validate_steps_per_epoch)

    save_every = cfg.epochs_save
    for ep in range(cfg.epochs):
        # 1. Train for a single epoch (with validation if configured).
        model.fit_generator(G1, **fit_kwargs)

        # 2. Save the weights after every ``save_every`` completed epochs.
        # A modulo on the epoch count keeps the interval constant; the
        # previous hand-maintained counter was reset to 1 and then
        # incremented in the same iteration, which shortened every interval
        # after the first and saved only once when the interval was 1.
        if save_every and save_every > 0 and (ep + 1) % save_every == 0:
            save_weights_name = 'model.{}'.format(ep)
            save_weights_path = os.path.join(cfg.save_weights_path,
                                             save_weights_name)
            model.save_weights(save_weights_path)
Example #27
0
    def __init__(self, input_topic, output_topic, resize_width, resize_height,
                 model_path, force_cpu):
        """Load the ``UNet`` weights and set up the ROS publisher/subscriber.

        Args:
            input_topic: Topic subscribed to for ``CompressedImage`` messages;
                each message is handed to ``self.image_cb``.
            output_topic: Topic on which ``ImMsg`` messages are published.
            resize_width: Width passed to the ``UNet`` constructor and stored
                on the instance.
            resize_height: Height passed to the ``UNet`` constructor and
                stored on the instance.
            model_path: Path of the state dict loaded into the network.
            force_cpu: When truthy, keep the network on the CPU even if CUDA
                is available.
        """
        self.bridge = CvBridge()

        # Stay on the CPU when requested *or* when no CUDA device exists.
        # The previous ``force_cpu and is_available()`` evaluated to False on
        # CUDA-less hosts, which sent the model to ``.cuda()`` and failed.
        self.force_cpu = bool(force_cpu) or not torch.cuda.is_available()

        self.graph = UNet([3, resize_width, resize_height], 3)
        # Load onto the CPU first so checkpoints saved from a GPU can be read
        # on any host; the model is moved to the GPU below when appropriate.
        self.graph.load_state_dict(torch.load(model_path, map_location='cpu'))

        self.resize_width, self.resize_height = resize_width, resize_height

        if not self.force_cpu:
            self.graph.cuda()
        self.graph.eval()
        self.to_tensor = transforms.Compose([transforms.ToTensor()])

        self.publisher = rospy.Publisher(output_topic, ImMsg, queue_size=1)
        self.raw_subscriber = rospy.Subscriber(input_topic,
                                               CompressedImage,
                                               self.image_cb,
                                               queue_size=1,
                                               buff_size=10**8)
Example #28
0
 def build_nnunet(self):
     """Build ``self.model`` from the parameters derived from ``self.args``.

     ``get_unet_params(self.args)`` supplies the input channels, class count,
     kernels, strides and patch size. The patch size is stored on
     ``self.patch_size``, ``self.n_class`` is set to the class count minus
     one, and the constructed ``UNet`` is stored on ``self.model``. On the
     main process the filters, kernels and strides are printed.
     """
     unet_params = get_unet_params(self.args)
     in_channels, n_class, kernels, strides, patch_size = unet_params
     self.patch_size = patch_size
     self.n_class = n_class - 1

     options = self.args
     model_kwargs = {
         "in_channels": in_channels,
         "n_class": n_class,
         "kernels": kernels,
         "strides": strides,
         "dimension": options.dim,
         "residual": options.residual,
         "attention": options.attention,
         "drop_block": options.drop_block,
         "normalization_layer": options.norm,
         "negative_slope": options.negative_slope,
         "deep_supervision": options.deep_supervision,
     }
     self.model = UNet(**model_kwargs)

     if is_main_process():
         summary = f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}"
         print(summary)
def detect_noise_regions(image, args):
    """Predict noisy regions in ``image`` and save the patches to inpaint.

    The U-Net checkpoint at ``<args.checkpoints>/unet/UNet.pth`` is loaded
    onto the module-level ``device``, ``predict_img`` produces a prediction
    for ``image``, and ``search_inpaint_area`` turns image and prediction
    into patches and labels. Each patch/label pair is written as a PNG under
    ``<args.output>/patches`` and ``<args.output>/labels``.

    Args:
        image: Input image; converted with ``np.array`` before the search.
        args: Options object providing ``checkpoints``, ``output`` and
            ``input`` paths.

    Returns:
        Tuple ``(patches_dir, labels_dir, absolute_position,
        relative_position)`` where the two positions are returned unchanged
        from ``search_inpaint_area``.
    """
    # load noise segmentation network (U-Net)
    unet_model_path = os.path.join(args.checkpoints, 'unet', 'UNet.pth')
    net = UNet(n_channels=3, n_classes=1).to(device)
    # map_location keeps loading working when the checkpoint was saved from a
    # device other than the one in use now (e.g. GPU checkpoint, CPU host).
    net.load_state_dict(torch.load(unet_model_path, map_location=device))
    net.eval()

    # predict noise regions
    predict = predict_img(net, device, image)

    # search inpaint patches
    patches, labels, _, absolute_position, relative_position = search_inpaint_area(
        np.array(image), np.array(predict.convert('RGB')))

    # save inpaint patches
    patches_dir = os.path.join(args.output, 'patches')
    labels_dir = os.path.join(args.output, 'labels')
    os.makedirs(patches_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)
    # File stem is the part of the input's base name before the first dot.
    filename = os.path.basename(args.input).split('.')[0]
    for counter, (patch, label) in enumerate(zip(patches, labels)):
        # Patch and label share the same file name in their own directories.
        out_name = '{}-{:0>3d}.png'.format(filename, counter)
        Image.fromarray(patch).save(os.path.join(patches_dir, out_name))
        Image.fromarray(label).save(os.path.join(labels_dir, out_name))
    return patches_dir, labels_dir, absolute_position, relative_position
Example #30
0
def train():
    """Build a ``UNet`` from the run arguments and start the training loop.

    Arguments come from ``setup_run_arguments()``. CUDA is used when
    available, otherwise the CPU. If ``args.from_pretrained`` is set, that
    state dict is loaded (mapped to the chosen device) before the network is
    moved to the device and handed to ``training_loop.run``.
    """
    args = setup_run_arguments()

    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    print(f"[INFO] Initializing UNet-model using: {device}")

    net = UNet(n_channels=args.n_channels,
               n_classes=args.n_classes,
               bilinear=True)

    # Optionally warm-start from an existing checkpoint.
    if args.from_pretrained:
        pretrained_state = torch.load(args.from_pretrained,
                                      map_location=device)
        net.load_state_dict(pretrained_state)

    net.to(device=device)

    run_kwargs = {
        "network": net,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "lr": args.learning_rate,
        "device": device,
        "n_classes": args.n_classes,
        "val_percent": args.val_percent,
        "image_dir": args.image_dir,
        "mask_dir": args.mask_dir,
        "checkpoint_path": args.checkpoint_path,
        "loss": args.loss,
        "num_workers": args.num_workers,
    }
    training_loop.run(**run_kwargs)