def load_model(opts, n_classes=4):
    """Instantiate the segmentation network described by ``opts``.

    Args:
        opts: Namespace with a ``model`` attribute ('unet', 'fcn', 'deeplab'
            or 'deeplabv3+'); every model except 'unet' also reads
            ``opts.backbone``.
        n_classes: Number of output classes passed to the constructor.

    Returns:
        The constructed model, with ``n_classes`` stored as an attribute.

    Raises:
        NotImplementedError: If the model type or its backbone is unsupported.
    """
    if opts.model == 'unet':
        model = UNet(n_classes=n_classes)
    else:
        # Constructors are wrapped in lambdas so only the selected one is
        # looked up and called.
        builders_by_model = {
            'fcn': {
                'resnet50': lambda: load_fcn_resnet50(n_classes),
                'resnet101': lambda: load_fcn_resnet101(n_classes),
            },
            'deeplab': {
                'resnet50': lambda: load_deeplab_resnet50(n_classes),
                'resnet101': lambda: load_deeplab_resnet101(n_classes),
            },
            'deeplabv3+': {
                'resnet101': lambda: DeepLabv3_plus_resnet(n_classes),
                'xception': lambda: DeepLabv3_plus_xception(n_classes),
            },
        }
        backbone_builders = builders_by_model.get(opts.model)
        if backbone_builders is None:
            raise NotImplementedError("Invalid model type specified")
        build = backbone_builders.get(opts.backbone)
        if build is None:
            raise NotImplementedError("Invalid backbone specified")
        model = build()
    model.n_classes = n_classes
    return model
def main(_):
    """Build the UNet model and train it inside a TensorFlow session.

    All hyper-parameters are read from the module-level ``args`` namespace.
    The positional argument is ignored (NOTE(review): presumably the argv
    handed over by ``tf.app.run`` -- confirm).
    """
    session_config = tf.ConfigProto()
    # Allocate GPU memory on demand instead of reserving it all up front.
    session_config.gpu_options.allow_growth = True
    with tf.Session(config=session_config) as sess:
        model = UNet(args.experiment_dir,
                     batch_size=args.batch_size,
                     experiment_id=args.experiment_id,
                     input_width=args.image_size,
                     output_width=args.image_size,
                     embedding_num=args.embedding_num,
                     embedding_dim=args.embedding_dim,
                     L1_penalty=args.L1_penalty)
        model.register_session(sess)

        build_options = {'is_training': True, 'inst_norm': args.inst_norm}
        if args.flip_labels:
            build_options['no_target_source'] = True
        model.build_model(**build_options)

        # ``args.fine_tune`` is an optional comma-separated list of integer ids.
        fine_tune_ids = None
        if args.fine_tune:
            fine_tune_ids = {int(token) for token in args.fine_tune.split(",")}

        model.train(lr=args.lr,
                    epoch=args.epoch,
                    resume=args.resume,
                    schedule=args.schedule,
                    freeze_encoder=args.freeze_encoder,
                    fine_tune=fine_tune_ids,
                    sample_steps=args.sample_steps,
                    checkpoint_steps=args.checkpoint_steps,
                    flip_labels=args.flip_labels)
def __init__(self, seq_length, color_channels, unet_path="pretrained/unet.mdl",
             discrim_path="pretrained/dicrim.mdl", facenet_path="pretrained/facenet.mdl",
             vgg_path="", embedding_size=1000, unet_depth=3, unet_filts=32,
             facenet_filts=32, resnet=18):
    """Create the UNet, discriminator and FaceNet networks, their losses and optimizers.

    Weights are restored from ``unet_path``, ``discrim_path`` and
    ``facenet_path`` when those files exist. A VGG-Face loss network is built
    only when ``vgg_path`` is an existing file; otherwise
    ``self.vgg_loss_network`` is not set.

    Args:
        seq_length: Accepted but not used by this constructor.
        color_channels: Input/output channel count for every network.
        unet_path, discrim_path, facenet_path: Optional pretrained weight files.
        vgg_path: Optional weight file passed to ``vgg_face_dag``.
        embedding_size, facenet_filts, resnet: ``FaceNetModel`` configuration.
        unet_depth, unet_filts: ``UNet`` depth and ``start_filts``.
    """
    self.color_channels = color_channels
    self.margin = 0.5
    self.writer = SummaryWriter(log_dir="logs")
    self.unet_path = unet_path
    self.discrim_path = discrim_path
    self.facenet_path = facenet_path
    self.unet = UNet(in_channels=color_channels, out_channels=color_channels,
                     depth=unet_depth, start_filts=unet_filts, up_mode="upsample",
                     merge_mode='concat').to(device)
    self.discrim = FaceNetModel(embedding_size=embedding_size, start_filts=facenet_filts,
                                in_channels=color_channels, resnet=resnet,
                                pretrained=False).to(device)
    self.facenet = FaceNetModel(embedding_size=embedding_size, start_filts=facenet_filts,
                                in_channels=color_channels, resnet=resnet,
                                pretrained=False).to(device)
    pretrained = (
        (self.unet, unet_path, "unet loaded"),
        (self.discrim, discrim_path, "discrim loaded"),
        (self.facenet, facenet_path, "facenet loaded"),
    )
    for network, weights_path, message in pretrained:
        if os.path.isfile(weights_path):
            # The networks already live on ``device``; mapping the checkpoint
            # there too lets GPU-saved weights load on a CPU-only host.
            network.load_state_dict(torch.load(weights_path, map_location=device))
            print(message)
    if os.path.isfile(vgg_path):
        self.vgg_loss_network = LossNetwork(vgg_face_dag(vgg_path)).to(device)
        self.vgg_loss_network.eval()
        print("vgg loaded")
    self.mse_loss_function = nn.MSELoss().to(device)
    self.discrim_loss_function = nn.BCELoss().to(device)
    self.triplet_loss_function = TripletLoss(margin=self.margin)
    self.unet_optimizer = torch.optim.Adam(self.unet.parameters(), betas=(0.9, 0.999))
    self.discrim_optimizer = torch.optim.Adam(self.discrim.parameters(), betas=(0.9, 0.999))
    self.facenet_optimizer = torch.optim.Adam(self.facenet.parameters(), betas=(0.9, 0.999))
def train():
    """Train a UNet with BCE loss for 500 epochs, validating after each epoch.

    For every epoch and phase the mean batch loss is printed, passed to
    ``visualize_results`` and appended to a per-phase loss history that is
    handed to ``tensorboard``. The learning-rate scheduler steps once per
    training phase.
    """
    # Init data
    train_dataset, val_dataset = prepare_datasets()
    loaders = dict(
        train=DataLoader(train_dataset, batch_size=10, shuffle=True),
        val=DataLoader(val_dataset, batch_size=10, shuffle=True),
    )

    # Init model
    model = UNet().cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, amsgrad=True)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.984)
    loss_fn = nn.BCELoss()
    epochs = 500

    # Created once so the history accumulates across epochs. It used to be
    # re-created for every phase, so it never held more than one value.
    epoch_losses = dict(train=[], val=[])

    for epoch in range(epochs):
        for phase in ('train', 'val'):
            is_train = phase == 'train'
            model = model.train() if is_train else model.eval()
            loader = loaders[phase]
            running_loss = []
            # Scope grad mode to this phase instead of flipping the global
            # flag, so gradients are not left disabled after validation.
            with torch.set_grad_enabled(is_train):
                for imgs, masks in loader:
                    imgs = imgs.cuda()
                    masks = masks.cuda()
                    outputs = model(imgs)
                    loss = loss_fn(outputs, masks)
                    running_loss.append(loss.item())
                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # End of epoch
                mean_loss = np.mean(running_loss)
                print(f'{epoch}) {phase} loss: {mean_loss}')
                visualize_results(loader, model, epoch, phase)
            epoch_losses[phase].append(mean_loss)
            tensorboard(epoch_losses[phase], phase)
            if is_train:
                scheduler.step()
def create_model(args, input_shape, enable_decoder=True):
    """Build the network selected by ``args.net``.

    With more than one GPU (``args.gpus > 1``) the model is constructed
    inside ``tf.device("/cpu:0")``; otherwise on the default device.

    Args:
        args: Namespace with ``gpus``, ``net`` and (for 'segcapsr3')
            ``num_class``.
        input_shape: Shape passed to the network constructor.
        enable_decoder: Forwarded to ``CapsNetR3`` only.

    Returns:
        ``[model]`` for 'unet' and 'tiramisu'; for the capsule networks,
        whatever the ``CapsNet*`` constructor returns (presumably a list of
        models -- confirm).

    Raises:
        ValueError: If ``args.net`` is not a known network type.
    """
    if args.gpus <= 1:
        return _build_network(args, input_shape, enable_decoder)
    # Multiple GPUs: build the model on the CPU.
    with tf.device("/cpu:0"):
        return _build_network(args, input_shape, enable_decoder)


def _build_network(args, input_shape, enable_decoder):
    """Construct the model(s) for ``args.net`` under the current device scope."""
    net = args.net
    if net == 'unet':
        from models.unet import UNet
        return [UNet(input_shape)]
    if net == 'tiramisu':
        from models.densenets import DenseNetFCN
        return [DenseNetFCN(input_shape)]
    if net == 'segcapsr1':
        from segcapsnet.capsnet import CapsNetR1
        return CapsNetR1(input_shape)
    if net == 'segcapsr3':
        from segcapsnet.capsnet import CapsNetR3
        return CapsNetR3(input_shape, args.num_class, enable_decoder)
    if net == 'capsbasic':
        from segcapsnet.capsnet import CapsNetBasic
        return CapsNetBasic(input_shape)
    raise ValueError('Unknown network type specified: {}'.format(net))
def visualize_voc_unet():
    """Visualise predictions of the 21-class 'voc_unet' model on VOC2012 'val'.

    The model's saved weights are loaded with ``load()``, and results are
    written by ``visualize`` under '<model name>_visualization/'.
    """
    from data.voc2012_loader_segmentation import PascalVOCSegmentation
    from torch.utils.data.dataloader import DataLoader
    from visualize.visualize import visualize
    from models.unet import UNet

    val_split = PascalVOCSegmentation('val')
    val_loader = DataLoader(val_split, batch_size=16, shuffle=False, num_workers=0)

    net = UNet(outputs=21, name='voc_unet')
    net.load()

    output_dir = net.name + '_visualization/'
    visualize(net, val_loader, output_dir)
def run(train=False):
    """Build the model, load its data and evaluate it.

    Training is skipped by default, matching the previous behaviour (the
    call was commented out while the docstring still claimed it trained).
    Pass ``train=True`` to train between building and evaluating.
    """
    model = UNet(CFG)
    model.load_data()
    model.build()
    if train:
        model.train()
    model.evaluate()
def validate(state_dict_path, use_gpu, device):
    """Compute the mean Dice score of a 1-channel, 2-class UNet on BraTS 2018 'val'.

    Uses the 't1ce' scans with 'et' segmentation targets.

    Args:
        state_dict_path: Path to the saved UNet ``state_dict``.
        use_gpu: If falsy the weights are loaded onto the CPU, otherwise onto
            ``device``.
        device: Device the model and batches are moved to.

    Returns:
        The mean Dice score, which is also printed. It was previously only
        printed; returning it is compatible with callers that ignored ``None``.
    """
    model = UNet(n_channels=1, n_classes=2)
    model.load_state_dict(torch.load(state_dict_path,
                                     map_location=device if use_gpu else 'cpu'))
    model.to(device)
    # Set eval mode once rather than on every batch.
    model.eval()

    val_transforms = transforms.Compose([ToTensor(), NormalizeBRATS()])
    BraTS_val_ds = BRATS2018('./BRATS2018',
                             data_set='val',
                             seg_type='et',
                             scan_type='t1ce',
                             transform=val_transforms)
    data_loader = DataLoader(BraTS_val_ds, batch_size=2, shuffle=False, num_workers=0)

    running_dice_score = 0.
    with torch.no_grad():
        for imgs, targets in data_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = model(imgs)
            preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)
            # Weight by batch size so the final division by the dataset length
            # gives a per-sample mean (assumes dice_score returns a batch mean).
            running_dice_score += dice_score(preds, targets) * targets.size(0)
            print('running dice score: {:.6f}'.format(running_dice_score))

    dice = running_dice_score / len(BraTS_val_ds)
    print('mean dice score of the validating set: {:.6f}'.format(dice))
    return dice
class EventGANBase(object):
    """Inference wrapper around a pretrained EventGAN UNet generator."""

    def __init__(self, options):
        """Build the generator and load weights from the newest checkpoint.

        Args:
            options: Needs ``n_image_channels``, ``n_time_bins``, ``sn`` and
                ``checkpoint_dir``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.generator = UNet(num_input_channels=2 * options.n_image_channels,
                              num_output_channels=options.n_time_bins * 2,
                              skip_type='concat',
                              activation='relu',
                              num_encoders=4,
                              base_num_channels=32,
                              num_residual_blocks=2,
                              norm='BN',
                              use_upsample_conv=True,
                              with_activation=True,
                              sn=options.sn,
                              multi=False)
        latest_checkpoint = get_latest_checkpoint(options.checkpoint_dir)
        # map_location lets a GPU-saved checkpoint load when only a CPU is present.
        checkpoint = torch.load(latest_checkpoint, map_location=self.device)
        self.generator.load_state_dict(checkpoint["gen"])
        self.generator.to(self.device)

    def forward(self, images, is_train=False):
        """Run the generator on a pair of images.

        Args:
            images: Tensor of shape 2xHxW or Bx2xHxW; a 3-D input gets a
                leading batch dimension.
            is_train: If True, run with autograd in the generator's current
                mode. Otherwise run in eval mode under ``torch.no_grad()`` and
                then restore whatever mode the generator was in before.

        Returns:
            The generator output.
        """
        if len(images.shape) == 3:
            images = images[None, ...]
        assert len(images.shape) == 4 and images.shape[1] == 2, \
            "Input images must be either 2xHxW or Bx2xHxW."
        if is_train:
            return self.generator(images)
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                event_volume = self.generator(images)
        finally:
            # Restore the previous mode instead of forcing train mode.
            self.generator.train(was_training)
        return event_volume
def get_model(model_name):
    """Instantiate a model by name, importing its module lazily.

    Supported names: 'vgg16', 'unet', 'deeplab', 'affinitynet', 'wasscam'.

    Raises:
        ValueError: For an unknown ``model_name``. The previous code raised
            the undefined name ``Error``, which surfaced as a ``NameError``.
    """
    if model_name == 'vgg16':
        from models.vgg16 import Vgg16GAP
        return Vgg16GAP(name="vgg16")
    if model_name == 'unet':
        from models.unet import UNet
        return UNet()
    if model_name == 'deeplab':
        from models.deeplab import DeepLab
        return DeepLab(name="deeplab")
    if model_name == 'affinitynet':
        from models.aff_net import AffNet
        return AffNet(name="affinitynet")
    if model_name == 'wasscam':
        from models.wass import WASS
        return WASS()
    raise ValueError('Model name has no implementation: {}'.format(model_name))
def get_net(input_depth, NET_TYPE, pad, upsample_mode, n_channels=3, act_fun='LeakyReLU',
            skip_n33d=128, skip_n33u=128, skip_n11=4, num_scales=5,
            downsample_mode='stride'):
    """Build a generator network selected by ``NET_TYPE``.

    Args:
        input_depth: Number of input channels.
        NET_TYPE: One of 'ResNet', 'skip', 'texture_nets', 'UNet', 'identity'.
        pad: Padding mode forwarded to the network constructors.
        upsample_mode: Upsampling mode for 'skip' and 'UNet'.
        n_channels: Output channels of the 'skip' network.
        act_fun: Activation name for the 'skip' network.
        skip_n33d, skip_n33u, skip_n11: Channel counts for the 'skip' network.
            An int is repeated ``num_scales`` times; a sequence is used as-is.
        num_scales: Number of scales used when expanding int channel counts.
        downsample_mode: Downsampling mode for the 'skip' network.

    Returns:
        The constructed network.

    Raises:
        ValueError: For an unknown ``NET_TYPE``, or an 'identity' net whose
            ``input_depth`` is not 3. These used to be ``assert``s, which are
            stripped under ``python -O``.
    """
    def per_scale(channels):
        # Expand a scalar channel count to one entry per scale.
        return [channels] * num_scales if isinstance(channels, int) else channels

    if NET_TYPE == 'ResNet':
        # TODO
        net = ResNet(input_depth, 3, 10, 16, 1, nn.BatchNorm2d, False)
    elif NET_TYPE == 'skip':
        net = skip(input_depth, n_channels,
                   num_channels_down=per_scale(skip_n33d),
                   num_channels_up=per_scale(skip_n33u),
                   num_channels_skip=per_scale(skip_n11),
                   upsample_mode=upsample_mode, downsample_mode=downsample_mode,
                   need_sigmoid=True, need_bias=True, pad=pad, act_fun=act_fun)
    elif NET_TYPE == 'texture_nets':
        net = get_texture_nets(inp=input_depth, ratios=[32, 16, 8, 4, 2, 1],
                               fill_noise=False, pad=pad)
    elif NET_TYPE == 'UNet':
        net = UNet(num_input_channels=input_depth, num_output_channels=3,
                   feature_scale=4, more_layers=0, concat_x=False,
                   upsample_mode=upsample_mode, pad=pad, norm_layer=nn.BatchNorm2d,
                   need_sigmoid=True, need_bias=True)
    elif NET_TYPE == 'identity':
        if input_depth != 3:
            raise ValueError(
                f"'identity' net requires input_depth == 3, got {input_depth}")
        net = nn.Sequential()
    else:
        raise ValueError(f"Unknown NET_TYPE: {NET_TYPE!r}")
    return net
def get_model(self, dataset):
    """Create the CUDA model named by ``self.args.model`` and prepare it for training.

    When ``self.args.resume`` is set, weights are restored via
    ``setup.load_save`` and ``self.args.start_epoch`` becomes
    ``self.args.resume_epoch``; otherwise it is 0.

    Raises:
        ValueError: If ``self.args.model`` is not a known model name.
    """
    # Lambdas defer reading ``dataset``/``self.args.K`` until the chosen
    # model is actually built.
    factories = {
        "GCN": lambda: GCN(dataset.num_classes, dataset.img_size, k=self.args.K),
        "UNet": lambda: UNet(dataset.num_classes),
        "GCN_DENSENET": lambda: GCN_DENSENET(dataset.num_classes, dataset.img_size, k=self.args.K),
        "GCN_DECONV": lambda: GCN_DECONV(dataset.num_classes, dataset.img_size, k=self.args.K),
        "GCN_PSP": lambda: GCN_PSP(dataset.num_classes, dataset.img_size, k=self.args.K),
        "GCN_COMB": lambda: GCN_COMBINED(dataset.num_classes, dataset.img_size, k=self.args.K),
        "GCN_RESNEXT": lambda: GCN_RESNEXT(dataset.num_classes, k=self.args.K),
    }
    factory = factories.get(self.args.model)
    if factory is None:
        raise ValueError("Invalid model arg.")
    model = factory().cuda()

    start_epoch = 0
    if self.args.resume:
        setup.load_save(model, self.args)
        start_epoch = self.args.resume_epoch
    self.args.start_epoch = start_epoch

    model.train()
    return model
def __init__(self, in_channels=12, use_model='unet', use_d8=False, learning_rate=0.02,
             adam_epsilon=1e-8, **kwargs):
    """Select and build the segmentation backbone (12 output classes).

    Args:
        in_channels: Number of input channels for the backbone.
        use_model: Case-insensitive backbone name: 'unet', 'lhn_unet',
            'deep_lab', 'modsegnet', 'segnet', 'aspp_segnet', 'sp_segnet' or
            'dl_segnet'.
        use_d8: If True, also create a 9-entry, 3-dim ``nn.Embedding``
            (``self.d8_emb``).
        learning_rate, adam_epsilon: Not read here; recorded by
            ``save_hyperparameters()``.
        **kwargs: Model-specific options. They are read back through
            ``self.hparams`` (presumably populated by
            ``save_hyperparameters()`` -- confirm).

    Raises:
        Exception: If ``use_model`` is not a known backbone.
    """
    super(DrainageNetworkExtractor, self).__init__()
    self.save_hyperparameters()
    num_out = 12
    # Each factory reads only the hparams its backbone needs.
    backbone_factories = {
        'unet': lambda: UNet(n_channels=in_channels, n_classes=num_out,
                             bilinear=self.hparams.bilinear),
        'lhn_unet': lambda: LHNUNet(n_channels=in_channels, n_classes=num_out,
                                    n_classes_l1=self.hparams.n_classes_l1,
                                    n_classes_l2=self.hparams.n_classes_l2,
                                    n_classes_l3=self.hparams.n_classes_l3,
                                    n_classes_l4=self.hparams.n_classes_l4),
        'deep_lab': lambda: DeepLab(backbone=self.hparams.backbone, in_channels=in_channels,
                                    num_classes=num_out, sync_bn=self.hparams.sync_bn,
                                    freeze_bn=self.hparams.freeze_bn,
                                    output_stride=self.hparams.output_stride),
        'modsegnet': lambda: ModSegNet(num_classes=num_out, n_init_features=in_channels,
                                       drop_rate=self.hparams.drop_rate),
        'segnet': lambda: SegNet(num_classes=num_out, n_init_features=in_channels,
                                 drop_rate=self.hparams.drop_rate,
                                 use_kriging_loss=self.hparams.use_kriging_loss),
        'aspp_segnet': lambda: ASPPSegNet(num_classes=num_out, n_init_features=in_channels,
                                          use_kriging_loss=self.hparams.use_kriging_loss),
        'sp_segnet': lambda: SPSegNet(num_classes=num_out, n_init_features=in_channels),
        'dl_segnet': lambda: DLSegNet(num_classes=num_out, n_init_features=in_channels,
                                      drop_rate=self.hparams.drop_rate),
    }
    factory = backbone_factories.get(use_model.lower())
    if factory is None:
        raise Exception(f"{use_model} is not implemented")
    self.model = factory()
    if use_d8:
        self.d8_emb = nn.Embedding(9, 3, max_norm=1)
def add_model_specific_args(parent_parser):
    """Return a parser with the extractor's options plus those of the chosen backbone.

    ``--use-model`` is pre-parsed from the real command line (``sys.argv``)
    so that the selected backbone class can register its own arguments.
    Arguments may also be read from '@file' files. An unknown backbone name
    adds no extra arguments.
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    for switch in ('--use-coord', '--use-d8', '--use-slope', '--use-curvature'):
        parser.add_argument(switch, action='store_true')
    parser.add_argument('--in-channels', type=int, default=12)
    parser.add_argument('--use-model', type=str, default='unet')
    parser.add_argument('--learning-rate', type=float, default=0.02)
    parser.add_argument('--adam-epsilon', type=float, default=1e-8)
    parser.add_argument('--use-kriging-loss', action='store_true')
    parser.fromfile_prefix_chars = "@"

    known_args, _unparsed = parser.parse_known_args()
    # Lambdas keep each backbone class lookup lazy.
    extenders = {
        'unet': lambda p: UNet.add_model_specific_args(p),
        'lhn_unet': lambda p: LHNUNet.add_model_specific_args(p),
        'deep_lab': lambda p: DeepLab.add_model_specific_args(p),
        'modsegnet': lambda p: ModSegNet.add_model_specific_args(p),
        'segnet': lambda p: SegNet.add_model_specific_args(p),
        'aspp_segnet': lambda p: ASPPSegNet.add_model_specific_args(p),
        'sp_segnet': lambda p: SPSegNet.add_model_specific_args(p),
        'dl_segnet': lambda p: DLSegNet.add_model_specific_args(p),
    }
    extend = extenders.get(known_args.use_model.lower())
    return parser if extend is None else extend(parser)
def init_fn(self):
    """Set up the cycle UNet, its RAdam optimizer, the losses and the datasets.

    ``self.options.model`` selects the channel layout:
      * 'flow': ``2 * n_time_bins`` input channels, 2 output channels.
      * 'recons': one summed event channel plus the previous image in, and
        ``n_image_channels`` out.

    Raises:
        ValueError: For any other ``self.options.model``.
    """
    if self.options.model == 'flow':
        num_input_channels = self.options.n_time_bins * 2
        num_output_channels = 2
    elif self.options.model == 'recons':
        # For the reconstruction model, we sum the event volume across the time
        # dimension, so that the network only sees a single channel event
        # input, plus the prev image.
        num_input_channels = 1 + self.options.n_image_channels
        num_output_channels = self.options.n_image_channels
    else:
        # The literal braces must be doubled: unescaped, str.format treats
        # "{EventGAN, flow, recons}" as a replacement field and raises
        # KeyError instead of this ValueError.
        raise ValueError(
            "Class was initialized with an invalid model {}"
            ", only {{EventGAN, flow, recons}} are supported.".format(
                self.options.model))

    self.cycle_unet = UNet(num_input_channels=num_input_channels,
                           num_output_channels=num_output_channels,
                           skip_type='concat',
                           activation='tanh',
                           num_encoders=4,
                           base_num_channels=32,
                           num_residual_blocks=2,
                           norm='BN',
                           use_upsample_conv=True,
                           multi=True)

    self.models_dict = {"model": self.cycle_unet}
    model_params = self.cycle_unet.parameters()
    # ``lr_decay`` is passed as Adam-style beta1.
    optimizer = radam.RAdam(list(model_params),
                            lr=self.options.lrc,
                            weight_decay=self.options.wd,
                            betas=(self.options.lr_decay, 0.999))

    self.ssim = pytorch_ssim.SSIM()
    self.l1 = nn.L1Loss(reduction="mean")
    # L1 distance minus SSIM similarity: lower is better for both terms.
    self.image_loss = lambda x, y: self.l1(x, y) - self.ssim(x, y)

    self.optimizers_dict = {"optimizer": optimizer}

    self.train_ds, self.train_sampler = event_loader.get_and_concat_datasets(
        self.options.train_file, self.options, train=True)
    self.validation_ds, self.validation_sampler = event_loader.get_and_concat_datasets(
        self.options.validation_file, self.options, train=False)

    self.cdl_kwargs["collate_fn"] = event_utils.none_safe_collate
    self.cdl_kwargs["sampler"] = self.train_sampler
def _decomposer(self):
    """
    Build an image decomposer into a spatial binary mask of the myocardium and a
    non-spatial vector z of the remaining image information.

    The anatomy output comes from the UNet decoder path; z is a 16-dimensional
    sigmoid vector regressed from the UNet bottleneck.
    :return a Keras model of the decomposer
    """
    image = Input(self.conf.input_shape)
    unet = UNet(self.conf.input_shape, residual=False)
    downsampled = unet.unet_downsample(image)
    unet.unet_bottleneck(downsampled)

    # Z regressor: two conv/BN/LeakyReLU stages, then a small dense head.
    # Layers are created in the same order as before (256-filter stage first).
    modality = unet.bottleneck
    for n_filters in (256, 64):
        modality = Conv2D(n_filters, 3, strides=1, padding='same')(modality)
        modality = BatchNormalization()(modality)
        modality = LeakyReLU()(modality)
    modality = Flatten()(modality)
    modality = Dense(32)(modality)
    modality = LeakyReLU()(modality)
    modality = Dense(16, activation='sigmoid')(modality)

    upsampled = unet.unet_upsample(unet.bottleneck)
    anatomy = unet.out(upsampled)

    decomposer = Model(inputs=image, outputs=[anatomy, modality], name='Decomposer')
    log.info('Decomposer')
    decomposer.summary(print_fn=log.info)
    return decomposer
def build(conf, name='Enc_Anatomy'):
    """
    Build a UNet-based encoder that extracts anatomical information from an image.

    :param conf: configuration; ``input_shape``, ``out_channels`` and ``rounding``
                 are read here (plus whatever ``UNet(conf)`` reads).
    :param name: name given to the returned Keras model.
    :return: Keras model mapping ``conf.input_shape`` to a ``conf.out_channels``
             softmax map, followed by ``Rounding`` when ``conf.rounding`` is set.
    """
    unet = UNet(conf)
    unet.input = Input(shape=conf.input_shape)

    # Encoder -> bottleneck -> decoder, each using the UNet's normalisation.
    downsampled = unet.unet_downsample(unet.input, unet.normalise)
    unet.unet_bottleneck(downsampled, unet.normalise)
    upsampled = unet.unet_upsample(unet.bottleneck, unet.normalise)

    to_anatomy = Conv2D(conf.out_channels, 1, padding='same',
                        activation='softmax', name='conv_anatomy')
    anatomy = to_anatomy(upsampled)
    if conf.rounding:
        anatomy = Rounding()(anatomy)

    encoder = Model(inputs=unet.input, outputs=anatomy, name=name)
    log.info('Enc_Anatomy')
    encoder.summary(print_fn=log.info)
    return encoder
def load_finetuned_model(args, baseline_model):
    """
    Build the augmentation and loss-reweighting networks, optionally restoring a checkpoint.

    :param args: namespace; ``args.load_finetune_checkpoint`` is an optional
                 checkpoint path. When set, the baseline, augment and (if present)
                 reweighting weights are restored from it.
    :param baseline_model: the elementary model whose weights may be restored.
    :return: ``(augment_net, reweighting_net, baseline_model)``, all moved to CUDA
             and put in training mode.
    """
    # TODO(PV): Initialize UNet properly
    # TODO (JON): DEPTH 1 WORKED WELL. Changed upconv to upsample. Use a wf of 2.
    augment_net = UNet(in_channels=3, n_classes=3, depth=1, wf=2, padding=True,
                       batch_norm=False, do_noise_channel=True, up_mode='upsample',
                       use_identity_residual=True)

    # This network outputs scalar weights to be applied element-wise to the
    # per-example losses.
    from models.simple_models import Net
    imsize, in_channel = 32, 3
    reweighting_net = Net(0, 0.0, imsize, in_channel, 0.0, num_classes=1)

    if args.load_finetune_checkpoint:
        checkpoint = torch.load(args.load_finetune_checkpoint)
        baseline_model.load_state_dict(checkpoint['elementary_model_state_dict'])
        augment_net.load_state_dict(checkpoint['augment_model_state_dict'])
        try:
            reweighting_state = checkpoint['reweighting_model_state_dict']
        except KeyError:
            pass  # Checkpoint has no reweighting weights; keep the fresh init.
        else:
            reweighting_net.load_state_dict(reweighting_state)

    augment_net = augment_net.cuda()
    reweighting_net = reweighting_net.cuda()
    baseline_model = baseline_model.cuda()
    for net in (augment_net, reweighting_net, baseline_model):
        net.train()
    return augment_net, reweighting_net, baseline_model
def __init__(self, img_size, hidden_size):
    """Create the combiner's UNet (21 input channels -> 1 output channel).

    ``img_size`` and ``hidden_size`` are accepted but not used here.
    """
    super(Combiner, self).__init__()
    unet_config = {
        'in_channels': 21,
        'out_channels': 1,
        'depth': 3,
        'start_filts': 64,
        'up_mode': "upsample",
        'merge_mode': 'concat',
    }
    self.unet = UNet(**unet_config)
def __init__(self, flownet_backbone):
    """Assemble the generator, the pixel discriminator and the optical-flow network.

    Args:
        flownet_backbone: '2sd' selects ``FlowNet2SD``; any other value selects
            ``lite_flow.Network``.
    """
    super(convAE, self).__init__()
    self.generator = UNet(input_channels=12, output_channel=3)
    self.discriminator = PixelDiscriminator(input_nc=3)
    self.flownet_backbone = flownet_backbone
    self.flow_net = FlowNet2SD() if flownet_backbone == '2sd' else lite_flow.Network()
def __init__(self, options):
    """Build the EventGAN generator and load weights from the newest checkpoint.

    Args:
        options: Needs ``n_image_channels``, ``n_time_bins``, ``sn`` and
            ``checkpoint_dir``.
    """
    self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.generator = UNet(num_input_channels=2 * options.n_image_channels,
                          num_output_channels=options.n_time_bins * 2,
                          skip_type='concat',
                          activation='relu',
                          num_encoders=4,
                          base_num_channels=32,
                          num_residual_blocks=2,
                          norm='BN',
                          use_upsample_conv=True,
                          with_activation=True,
                          sn=options.sn,
                          multi=False)
    latest_checkpoint = get_latest_checkpoint(options.checkpoint_dir)
    # map_location lets a GPU-saved checkpoint load when only a CPU is present,
    # matching the CPU fallback chosen for self.device above.
    checkpoint = torch.load(latest_checkpoint, map_location=self.device)
    self.generator.load_state_dict(checkpoint["gen"])
    self.generator.to(self.device)
def __init__(self, path_to_shape_net_weights='', n_classes=15):
    """Build the image UNet, the shape-prior ShapeUNet and a channel-wise softmax.

    Args:
        path_to_shape_net_weights: Optional path to ShapeUNet weights; an empty
            string skips loading.
        n_classes: Currently unused; the ShapeUNet input is fixed at 15
            channels. NOTE(review): possibly meant to drive that 15 -- confirm
            before wiring it in.
    """
    super(SH_UNet, self).__init__()
    self.unet = UNet((3, 512, 512))
    self.shapeUNet = ShapeUNet((15, 512, 512))
    self.softmax = nn.Softmax(dim=1)
    if path_to_shape_net_weights:
        # The ShapeUNet has not been moved to a GPU here (presumably it is on
        # the CPU), so load onto the CPU. This also lets GPU-saved weights load
        # on machines without CUDA.
        self.shapeUNet.load_state_dict(
            torch.load(path_to_shape_net_weights, map_location='cpu'))
def _compile(self):
    """Compile the model: network, loss function, optimizer and LR scheduler.

    The optimizer, scheduler and loss are only created when ``self.trainable``.
    The model (and, when trainable, the loss) are moved to CUDA if it is
    available and ``self.p.cuda`` is set.
    """
    print('Noise2Noise: Learning Image Restoration without Clean Data (Lethinen et al., 2018)')

    # Monte Carlo input stacks 3 HDR buffers (3x3 = 9 channels); otherwise RGB.
    if self.p.noise_type == 'mc':
        self.is_mc, in_channels = True, 9
    else:
        self.is_mc, in_channels = False, 3
    self.model = UNet(in_channels=in_channels)

    # Optimizer, scheduler and loss are only needed for training.
    if self.trainable:
        beta1, beta2, epsilon = self.p.adam[0], self.p.adam[1], self.p.adam[2]
        self.optim = Adam(self.model.parameters(), lr=self.p.learning_rate,
                          betas=(beta1, beta2), eps=epsilon)

        # Halve the learning rate when the monitored metric plateaus.
        self.scheduler = lr_scheduler.ReduceLROnPlateau(
            self.optim, patience=self.p.nb_epochs / 4, factor=0.5, verbose=True)

        loss_name = self.p.loss
        if loss_name == 'hdr':
            assert self.is_mc, 'Using HDR loss on non Monte Carlo images'
            self.loss = HDRLoss()
        elif loss_name == 'l2':
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.L1Loss()

    # CUDA support
    self.use_cuda = torch.cuda.is_available() and self.p.cuda
    if not self.use_cuda:
        return
    self.model = self.model.cuda()
    if self.trainable:
        self.loss = self.loss.cuda()
def build(self, conf):
    """Build two anatomy encoders that share a decoder and a final 1x1 softmax conv.

    Returns:
        ``[encoder1, encoder2]``: Keras models named 'Enc_Anatomy_<modality>'
        for ``self.modalities[0]`` and ``self.modalities[1]``.
    """
    # One UNet encoder per modality, created in order (encoder 1 first).
    encoders, downsampled = [], []
    for _ in range(2):
        enc = UNet(conf)
        enc.input = Input(shape=conf.input_shape)
        downsampled.append(enc.unet_downsample(enc.input, enc.normalise))
        encoders.append(enc)

    self.build_decoder(conf)

    anatomies = []
    for enc, down in zip(encoders, downsampled):
        # The deepest skip connection only exists with more than 3 downsamplings.
        skip_l3 = enc.d_l3 if conf.downsample > 3 else None
        anatomies.append(
            self.evaluate_decoder(conf, down, skip_l3, enc.d_l2, enc.d_l1, enc.d_l0))

    # Created once so both encoders share its weights.
    shared_conv = Conv2D(conf.out_channels, 1, padding='same',
                         activation='softmax', name='conv_anatomy')
    outputs = [shared_conv(anatomy) for anatomy in anatomies]
    if conf.rounding:
        outputs = [Rounding()(out) for out in outputs]

    modality_names = (self.modalities[0], self.modalities[1])
    return [
        Model(inputs=enc.input, outputs=out, name='Enc_Anatomy_%s' % modality)
        for enc, out, modality in zip(encoders, outputs, modality_names)
    ]
def build_unet(self, in_channels, n_class, kernels, strides):
    """Return a UNet built from the given geometry plus options in ``self.args``.

    ``self.args`` supplies ``norm``, ``negative_slope``, ``deep_supervision``
    and ``dim``.
    """
    opts = self.args
    unet_kwargs = dict(
        in_channels=in_channels,
        n_class=n_class,
        kernels=kernels,
        strides=strides,
        normalization_layer=opts.norm,
        negative_slope=opts.negative_slope,
        deep_supervision=opts.deep_supervision,
        dimension=opts.dim,
    )
    return UNet(**unet_kwargs)
def train():
    """Train the UNet with Keras generators and save weights periodically.

    Runs ``cfg.epochs`` single-epoch ``fit_generator`` calls. When
    ``cfg.validate`` is set, each call also validates on a validation
    generator. Weights are saved as ``model.<epoch>`` under
    ``cfg.save_weights_path`` after every ``cfg.epochs_save``-th epoch.
    """
    model = UNet(cfg.input_shape)
    # Compile and print the model.
    model.compile(optimizer=cfg.optimizer, loss=cfg.loss, metrics=cfg.metrics)
    print_summary(model=model)

    # Training data generator.
    train_gen = imageSegmentationGenerator(cfg.train_images, cfg.train_annotations,
                                           cfg.train_batch_size, cfg.n_classes,
                                           cfg.input_shape[0], cfg.input_shape[1],
                                           cfg.output_shape[0], cfg.output_shape[1])
    fit_kwargs = dict(steps_per_epoch=cfg.train_steps_per_epoch,
                      workers=cfg.workers,
                      epochs=1,
                      verbose=1,
                      use_multiprocessing=cfg.use_multiprocessing)
    if cfg.validate:
        # Validation data generator.
        fit_kwargs['validation_data'] = imageSegmentationGenerator(
            cfg.val_images, cfg.val_annotations, cfg.val_batch_size, cfg.n_classes,
            cfg.input_shape[0], cfg.input_shape[1],
            cfg.output_shape[0], cfg.output_shape[1])
        fit_kwargs['validation_steps'] = cfg.validate_steps_per_epoch

    for ep in range(cfg.epochs):
        # 1. Train for one epoch.
        model.fit_generator(train_gen, **fit_kwargs)

        # 2. Save weights after epochs epochs_save-1, 2*epochs_save-1, ...
        # The old counter reset to 1 and was then incremented, so after the
        # first save it fired every epochs_save-1 epochs (never again when
        # epochs_save == 1).
        if cfg.epochs_save > 0 and (ep + 1) % cfg.epochs_save == 0:
            save_weights_name = 'model.{}'.format(ep)
            save_weights_path = os.path.join(cfg.save_weights_path, save_weights_name)
            model.save_weights(save_weights_path)
def __init__(self, input_topic, output_topic, resize_width, resize_height, model_path,
             force_cpu):
    """Load the UNet and wire up the ROS image subscriber and publisher.

    Args:
        input_topic: Topic of incoming ``CompressedImage`` messages.
        output_topic: Topic the results are published on.
        resize_width, resize_height: Size stored for resizing; also passed to
            the UNet constructor.
        model_path: Path to the UNet ``state_dict``.
        force_cpu: Run on the CPU even if CUDA is available. The CPU is also
            used whenever CUDA is unavailable.
    """
    self.bridge = CvBridge()
    self.graph = UNet([3, resize_width, resize_height], 3)
    # Load onto the CPU first; the model is moved to the GPU below when used.
    self.graph.load_state_dict(torch.load(model_path, map_location='cpu'))
    # Use the CPU when asked to, or when CUDA is missing. The previous
    # ``force_cpu and torch.cuda.is_available()`` was inverted: it called
    # .cuda() on hosts without CUDA, including when force_cpu was True.
    self.force_cpu = force_cpu or not torch.cuda.is_available()
    self.resize_width, self.resize_height = resize_width, resize_height
    if not self.force_cpu:
        self.graph.cuda()
    self.graph.eval()

    self.to_tensor = transforms.Compose([transforms.ToTensor()])
    self.publisher = rospy.Publisher(output_topic, ImMsg, queue_size=1)
    # Large buff_size so a whole compressed image fits in the receive buffer.
    self.raw_subscriber = rospy.Subscriber(input_topic, CompressedImage, self.image_cb,
                                           queue_size=1, buff_size=10**8)
def build_nnunet(self):
    """Create ``self.model`` (a UNet) from the parameters given by ``get_unet_params``.

    Also sets ``self.patch_size`` and ``self.n_class`` (the returned class
    count minus one -- NOTE(review): presumably excluding background,
    confirm). The main process prints the filter, kernel and stride
    configuration.
    """
    unet_params = get_unet_params(self.args)
    in_channels, n_class, kernels, strides, self.patch_size = unet_params
    self.n_class = n_class - 1

    opts = self.args
    self.model = UNet(
        in_channels=in_channels,
        n_class=n_class,
        kernels=kernels,
        strides=strides,
        dimension=opts.dim,
        residual=opts.residual,
        attention=opts.attention,
        drop_block=opts.drop_block,
        normalization_layer=opts.norm,
        negative_slope=opts.negative_slope,
        deep_supervision=opts.deep_supervision,
    )

    if not is_main_process():
        return
    print(f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}")
def detect_noise_regions(image, args):
    """Segment noise regions with the pretrained UNet and save inpainting patches.

    Args:
        image: Input image. It is passed to ``predict_img`` and ``np.array``
            (presumably a PIL image -- confirm).
        args: Needs ``checkpoints`` (directory containing 'unet/UNet.pth'),
            ``output`` (where 'patches/' and 'labels/' are created) and
            ``input`` (source path, used to name the saved files).

    Returns:
        ``(patches_dir, labels_dir, absolute_position, relative_position)``.
    """
    # Load the noise segmentation network (U-Net) on the module-level device.
    unet_model_path = os.path.join(args.checkpoints, 'unet', 'UNet.pth')
    net = UNet(n_channels=3, n_classes=1).to(device)
    net.load_state_dict(torch.load(unet_model_path, map_location=device))
    net.eval()

    # Predict noise regions and search for inpainting patches.
    predict = predict_img(net, device, image)
    patches, labels, _, absolute_position, relative_position = search_inpaint_area(
        np.array(image), np.array(predict.convert('RGB')))

    # Save the inpainting patches and their labels.
    patches_dir = os.path.join(args.output, 'patches')
    labels_dir = os.path.join(args.output, 'labels')
    os.makedirs(patches_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    # splitext keeps dots inside the stem ('a.b.png' -> 'a.b'). The previous
    # split('.')[0] cut it to 'a', so different inputs could overwrite each
    # other's patches.
    filename = os.path.splitext(os.path.basename(args.input))[0]
    for counter, (patch, label) in enumerate(zip(patches, labels)):
        out_name = '{}-{:0>3d}.png'.format(filename, counter)
        Image.fromarray(patch).save(os.path.join(patches_dir, out_name))
        Image.fromarray(label).save(os.path.join(labels_dir, out_name))

    return patches_dir, labels_dir, absolute_position, relative_position
def train():
    """Initialise a UNet, optionally from pretrained weights, and run the training loop.

    All options come from ``setup_run_arguments()``. CUDA is used when
    available, otherwise the CPU.
    """
    args = setup_run_arguments()
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(f"[INFO] Initializing UNet-model using: {device}")

    net = UNet(n_channels=args.n_channels, n_classes=args.n_classes, bilinear=True)
    pretrained_path = args.from_pretrained
    if pretrained_path:
        net.load_state_dict(torch.load(pretrained_path, map_location=device))
    net.to(device=device)

    loop_options = dict(
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.learning_rate,
        device=device,
        n_classes=args.n_classes,
        val_percent=args.val_percent,
        image_dir=args.image_dir,
        mask_dir=args.mask_dir,
        checkpoint_path=args.checkpoint_path,
        loss=args.loss,
        num_workers=args.num_workers,
    )
    training_loop.run(network=net, **loop_options)