Example #1
0
    def build_model(self):
        """Build the segmentation network(s) for ``self.model_type`` and their optimizers.

        Sets ``self.unet`` (moved to ``self.device`` and wrapped in
        ``torch.nn.DataParallel``) and ``self.optimizer_unet`` (Adam,
        lr=``self.lr``, weight_decay=1e-5).

        Only the 'ASM_U_Net' type also builds an analyzer network. For it,
        ``self.analyzer`` is set up the same way, with its own optimizer
        ``self.optimizer_analyzer`` (Adam, lr=``self.lr * 0.2``). For every
        other model type, both attributes are set to ``None``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        # Build into locals first so an unknown model_type leaves self untouched.
        analyzer = None
        if self.model_type == 'U_Net':
            unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2U_Net':
            unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'ASM_U_Net':
            unet = Structure_U_Net(img_ch=3, output_ch=1)
            analyzer = Analyzer_U_Net(img_ch=1, output_ch=1)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type!r}")

        self.unet = unet
        self.analyzer = analyzer

        self.optimizer_unet = optim.Adam(params=list(self.unet.parameters()),
                                         lr=self.lr,
                                         weight_decay=1e-5)
        # Only ASM_U_Net has an analyzer network to optimize.
        if self.analyzer is None:
            self.optimizer_analyzer = None
        else:
            # The analyzer trains at a fifth of the U-Net learning rate.
            self.optimizer_analyzer = optim.Adam(
                params=list(self.analyzer.parameters()),
                lr=self.lr * 0.2,
                weight_decay=1e-5)

        self.unet.to(self.device)
        self.unet = torch.nn.DataParallel(self.unet)
        if self.analyzer is not None:
            self.analyzer.to(self.device)
            self.analyzer = torch.nn.DataParallel(self.analyzer)
Example #2
0
    def build_model(self):
        """Build ``self.unet`` for ``self.model_type`` and its Adam optimizer.

        Every variant takes ``self.img_ch`` input channels and produces one
        output channel. R2U_Net and R2AttU_Net also receive ``self.t``.
        The optimizer is Adam with lr=``self.lr`` and betas
        (``self.beta1``, ``self.beta2``). The network is moved to
        ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        # Lambdas defer both construction and name lookup until the selected
        # entry is called, so only the chosen class is ever instantiated.
        builders = {
            'U_Net': lambda: U_Net(img_ch=self.img_ch, output_ch=1),
            'R2U_Net': lambda: R2U_Net(img_ch=self.img_ch, output_ch=1,
                                       t=self.t),
            'AttU_Net': lambda: AttU_Net(img_ch=self.img_ch, output_ch=1),
            'R2AttU_Net': lambda: R2AttU_Net(img_ch=self.img_ch, output_ch=1,
                                             t=self.t),
            'MixU_Net': lambda: MixU_Net(img_ch=self.img_ch, output_ch=1),
            'MixAttU_Net': lambda: MixAttU_Net(img_ch=self.img_ch,
                                               output_ch=1),
            'MixR2U_Net': lambda: MixR2U_Net(img_ch=self.img_ch, output_ch=1),
            'MixR2AttU_Net': lambda: MixR2AttU_Net(img_ch=self.img_ch,
                                                   output_ch=1),
            'GhostU_Net': lambda: GhostU_Net(img_ch=self.img_ch, output_ch=1),
            'GhostU_Net1': lambda: GhostU_Net1(img_ch=self.img_ch,
                                               output_ch=1),
            'GhostU_Net2': lambda: GhostU_Net2(img_ch=self.img_ch,
                                               output_ch=1),
        }
        try:
            build = builders[self.model_type]
        except KeyError:
            raise ValueError(
                f"Unknown model_type: {self.model_type!r}") from None
        self.unet = build()

        self.optimizer = optim.Adam(list(self.unet.parameters()),
                                    lr=self.lr,
                                    betas=[self.beta1, self.beta2])
        self.unet.to(self.device)
Example #3
0
    def build_model(self):
        """Build ``self.unet`` for ``self.model_type`` and its Adam optimizer.

        The optimizer uses lr=``self.lr`` and betas taken from
        ``self.beta_list``. The network is moved to ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        # NOTE(review): most variants here take 1 input channel, but R2U_Net,
        # R2AttU_Net and R2UIternet take 3 -- confirm this matches the data.
        if self.model_type == 'UNet':
            self.unet = UNet(n_channels=1, n_classes=1)
        elif self.model_type == 'R2U_Net':
            # TODO: changed for green image channel
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=1, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'Iternet':
            self.unet = Iternet(n_channels=1, n_classes=1)
        elif self.model_type == 'AttUIternet':
            self.unet = AttUIternet(n_channels=1, n_classes=1)
        elif self.model_type == 'R2UIternet':
            self.unet = R2UIternet(n_channels=3, n_classes=1)
        elif self.model_type == 'NestedUNet':
            self.unet = NestedUNet(in_ch=1, out_ch=1)
        elif self.model_type == "AG_Net":
            self.unet = AG_Net(n_classes=1, bn=True, BatchNorm=False)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type!r}")

        self.optimizer = optim.Adam(list(self.unet.parameters()),
                                    self.lr,
                                    betas=tuple(self.beta_list))
        self.unet.to(self.device)
Example #4
0
    def build_model(self):
        """Build ``self.unet`` for ``self.model_type`` and move it to ``self.device``.

        All variants take 1 input channel, produce 1 output channel, and
        receive ``self.first_layer_numKernel``. The R2 variants also get
        ``self.t``, and ResAttU_Net also gets ``self.UnetLayer``. Unless
        ``self.initialization`` is 'NA', the weights are initialized with
        ``init_weights`` using that init type. No optimizer is created here.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=1,
                              output_ch=1,
                              first_layer_numKernel=self.first_layer_numKernel)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(
                img_ch=1,
                output_ch=1,
                t=self.t,
                first_layer_numKernel=self.first_layer_numKernel)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(
                img_ch=1,
                output_ch=1,
                first_layer_numKernel=self.first_layer_numKernel)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(
                img_ch=1,
                output_ch=1,
                t=self.t,
                first_layer_numKernel=self.first_layer_numKernel)
        elif self.model_type == 'ResAttU_Net':
            self.unet = ResAttU_Net(
                UnetLayer=self.UnetLayer,
                img_ch=1,
                output_ch=1,
                first_layer_numKernel=self.first_layer_numKernel)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type!r}")

        # 'NA' means "keep the modules' default initialization".
        if self.initialization != 'NA':
            init_weights(self.unet, init_type=self.initialization)
        self.unet.to(self.device)
    def build_model(self):
        """Build a single-channel ``self.unet`` and its Adam optimizer.

        Supports 'U_Net' and 'R2U_Net' (the latter also receives ``self.t``).
        The optimizer uses lr=``self.lr`` and betas
        (``self.beta1``, ``self.beta2``). The network is moved to
        ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=1, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=1, output_ch=1, t=self.t)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type!r}")

        self.optimizer = optim.Adam(list(self.unet.parameters()),
                                    lr=self.lr,
                                    betas=(self.beta1, self.beta2))
        self.unet.to(self.device)
Example #6
0
    def build_model(self):
        """Build a 3-input-channel ``self.unet`` and its Adam optimizer.

        The optimizer uses lr=``self.lr`` and betas
        (``self.beta1``, ``self.beta2``). The network is moved to
        ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        if self.model_type == 'U_Net':
            # NOTE(review): U_Net gets 2 output channels while every other
            # variant gets 1 -- confirm the loss/metrics expect this.
            self.unet = U_Net(img_ch=3, output_ch=2)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type!r}")

        self.optimizer = optim.Adam(list(self.unet.parameters()),
                                    lr=self.lr,
                                    betas=[self.beta1, self.beta2])
        self.unet.to(self.device)
Example #7
0
    def build_model(self):
        """Build a 3-channel-in / 1-channel-out ``self.unet`` and its Adam optimizer.

        The optimizer uses lr=``self.lr`` and betas
        (``self.beta1``, ``self.beta2``). The network is moved to
        ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        # Load required model
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type!r}")

        # Load optimizer
        self.optimizer = optim.Adam(list(self.unet.parameters()),
                                    lr=self.lr,
                                    betas=[self.beta1, self.beta2])
        # Move model to device
        self.unet.to(self.device)
Example #8
0
    def build_model(self, config):
        """Build a 3-channel-in / 1-channel-out ``self.unet`` and its Adam optimizer.

        Args:
            config: accepted for interface compatibility but currently
                unused (loading pretrained weights from ``config.pretrained``
                is disabled).

        The optimizer uses lr=``self.lr`` and betas
        (``self.beta1``, ``self.beta2``). The network is moved to
        ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type!r}")

        self.optimizer = optim.Adam(list(self.unet.parameters()),
                                    lr=self.lr,
                                    betas=[self.beta1, self.beta2])
        self.unet.to(self.device)
Example #9
0
    def build_model(self):
        """Build ``self.unet`` for ``self.model_type`` and, optionally, its optimizer.

        Channel counts come from ``self.img_ch`` / ``self.output_ch``, and
        every variant receives ``self.first_layer_numKernel``. R2 variants
        also get ``self.t``; U_Net and ResAttU_Net also get
        ``self.UnetLayer``.

        ``self.optimizer_choice`` selects the optimizer: 'Adam' (betas
        ``self.beta1``/``self.beta2``) or 'SGD' (``self.momentum``), both with
        lr=``self.initial_lr``. Any other value creates no optimizer. The
        network is moved to ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(UnetLayer=self.UnetLayer,
                              img_ch=self.img_ch,
                              output_ch=self.output_ch,
                              first_layer_numKernel=self.first_layer_numKernel)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(
                img_ch=self.img_ch,
                output_ch=self.output_ch,
                t=self.t,
                first_layer_numKernel=self.first_layer_numKernel)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(
                img_ch=self.img_ch,
                output_ch=self.output_ch,
                first_layer_numKernel=self.first_layer_numKernel)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(
                img_ch=self.img_ch,
                output_ch=self.output_ch,
                t=self.t,
                first_layer_numKernel=self.first_layer_numKernel)
        elif self.model_type == 'ResAttU_Net':
            self.unet = ResAttU_Net(
                UnetLayer=self.UnetLayer,
                img_ch=self.img_ch,
                output_ch=self.output_ch,
                first_layer_numKernel=self.first_layer_numKernel)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type!r}")

        if self.optimizer_choice == 'Adam':
            self.optimizer = optim.Adam(list(self.unet.parameters()),
                                        lr=self.initial_lr,
                                        betas=[self.beta1, self.beta2])
        elif self.optimizer_choice == 'SGD':
            self.optimizer = optim.SGD(list(self.unet.parameters()),
                                       lr=self.initial_lr,
                                       momentum=self.momentum)
        # Any other optimizer_choice leaves self.optimizer untouched.
        # NOTE(review): confirm this is intended (e.g. evaluation-only runs).

        self.unet.to(self.device)
Example #10
0
    def build_model(self):
        """Build ``self.unet`` for ``self.model_type`` and its AdamW optimizer.

        The U-Net variants take 3 input channels and produce 3 output
        channels; 'ABU_Net' (``U_Net_AB``) produces 1. 'Multi_Task' loads
        MobileNetV2 via ``torch.hub`` with pretrained weights, drops the last
        module of its classifier head, and wraps it in
        ``multi_task_model_classification``.

        The optimizer is AdamW with lr=``self.lr`` and betas
        (``self.beta1``, ``self.beta2``). The network is moved to
        ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=3)
        elif self.model_type == 'R2U_Net':
            print("------> using R2U <--------")
            self.unet = R2U_Net(img_ch=3, output_ch=3, t=self.t)
        elif self.model_type == 'AttU_Net':
            print("------> using AttU <--------")
            self.unet = AttU_Net(img_ch=3, output_ch=3)
        elif self.model_type == 'R2AttU_Net':
            print("------> using R2-AttU <--------")
            self.unet = R2AttU_Net(img_ch=3, output_ch=3, t=self.t)
        elif self.model_type == 'ABU_Net':
            print("------> using ABU_Net <--------")
            self.unet = U_Net_AB(img_ch=3, output_ch=1)
        elif self.model_type == 'Multi_Task':
            print("------> using Multi_Task Learning <--------")
            backbone = torch.hub.load('pytorch/vision',
                                      'mobilenet_v2',
                                      pretrained=True)
            # Drop the last module of the classifier head (presumably the
            # final Linear layer) so the wrapper can attach its own heads.
            backbone.classifier = torch.nn.Sequential(
                *list(backbone.classifier.children())[:-1])
            # Every parameter is trainable.
            # NOTE(review): if only features[18] and the classifier were meant
            # to be fine-tuned, freeze everything first (requires_grad=False)
            # and then unfreeze those modules.
            for param in backbone.parameters():
                param.requires_grad = True
            print("All trainable parameters of model are")
            for name, param in backbone.named_parameters():
                if param.requires_grad:
                    print(name, param.shape)
            self.unet = multi_task_model_classification(backbone)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type!r}")

        self.optimizer = optim.AdamW(list(self.unet.parameters()),
                                     lr=self.lr,
                                     betas=[self.beta1, self.beta2])
        self.unet.to(self.device)
    def build_model(self):
        """Build a 1-channel-in / 1-channel-out ``self.unet`` and, optionally, its optimizer.

        ``self.optimizer_choice`` selects the optimizer: 'Adam' (betas
        ``self.beta1``/``self.beta2``) or 'SGD' (``self.momentum``), both with
        lr=``self.initial_lr``. Any other value creates no optimizer. The
        network is moved to ``self.device``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=1, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=1, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=1, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=1, output_ch=1, t=self.t)
        else:
            raise ValueError(f"Unknown model_type: {self.model_type!r}")

        if self.optimizer_choice == 'Adam':
            self.optimizer = optim.Adam(list(self.unet.parameters()),
                                        lr=self.initial_lr,
                                        betas=[self.beta1, self.beta2])
        elif self.optimizer_choice == 'SGD':
            self.optimizer = optim.SGD(list(self.unet.parameters()),
                                       lr=self.initial_lr,
                                       momentum=self.momentum)
        # Any other optimizer_choice leaves self.optimizer untouched.
        # NOTE(review): confirm this is intended (e.g. evaluation-only runs).

        self.unet.to(self.device)
Example #12
0
####______________Loading Dataset____________###
train_loader, train_dataset = get_loader('train')
valid_loader, valid_dataset = get_loader('val')

####______________Hyperparameters____________###
num_classes = 20
epochs = 250
lr = 1e-6
b1 = 0.5
b2 = 0.999
# NOTE(review): the decay start is randomized (anywhere in the first 80% of
# training), so runs are not reproducible unless `random` is seeded first.
decay_ratio = random.random() * 0.8
decay_epoch = int(epochs * decay_ratio)

####______________Model Instance____________###
# GPUs used by DataParallel. The model must live on the first of them, so
# `device` is derived from the same list to keep the two in sync.
gpu_ids = [7, 1, 3]
device = torch.device(f'cuda:{gpu_ids[0]}' if torch.cuda.is_available() else "cpu")
model = R2U_Net(img_ch=3, output_ch=num_classes)
model.to(device)
optimizer = torch.optim.Adam(list(model.parameters()), lr, [b1, b2])
# NLLLoss accepts (N, C, H, W) inputs directly; NLLLoss2d is a deprecated alias.
# NOTE(review): NLLLoss expects log-probabilities -- confirm R2U_Net's output
# goes through log_softmax before it reaches the criterion.
criterion = nn.NLLLoss()

model = torch.nn.DataParallel(model, device_ids=gpu_ids)

print(device)

####______________Training____________###
train_loss_values = []
val_loss_values = []
PATH = "./model-R2U-Net.cpt"
Example #13
0
def _to_cuda_nchw(batch):
    """Return *batch* as a float32 CUDA tensor with its last axis moved to dim 1.

    Assumes a 4-D channels-last array (N, H, W, C) -- TODO confirm against
    DataLoader.generator(). The result is (N, C, H, W).
    """
    return torch.as_tensor(batch, dtype=torch.float32,
                           device='cuda').permute(0, 3, 1, 2)


def train(models_path='./saved_models/', batch_size=2,
          start_epoch=1, epochs=500, n_batches=1000, start_lr=0.0001,
          save_sample=100):
    """Train the R2U_Net alignment/segmentation network on CUDA.

    Args:
        models_path: directory where a checkpoint named ``alignNet_<epoch>``
            is written after every epoch (created if missing).
        batch_size: batch size passed to ``DataLoader``.
        start_epoch: first epoch index.
        epochs: end of the epoch range (exclusive), so epochs
            ``start_epoch .. epochs - 1`` are run.
        n_batches: number of batches per epoch passed to ``DataLoader``.
        start_lr: Adam learning rate.
        save_sample: ``sample_images`` is called every ``save_sample`` batches.

    Returns:
        ``args.output`` from the module-level ``args``.
        NOTE(review): this relies on the global ``args`` defined only under
        ``__main__``; it looks like a leftover and should probably be removed
        or passed in explicitly.
    """
    border = var.BORDER
    window_size = var.WS

    # 3 RGB channels plus one extra input channel (presumably the mod/inj
    # mask that is passed alongside rgb in the forward call below).
    net = R2U_Net(img_ch=3 + 1, t=2)
    if var.LOAD_MODEL_WEIGHTS:
        net.load_state_dict(torch.load(var.MODEL))
    net = net.cuda()

    os.makedirs(models_path, exist_ok=True)

    # One LossBuffer per reported loss, in status-line order:
    # total, align, seg, miss, inj. push() returns the value displayed.
    loss_buffers = [LossBuffer() for _ in range(5)]

    gen_obj = DataLoader(bs=batch_size, nb=n_batches, ws=window_size)
    optimizer_G = optim.Adam(net.parameters(), lr=start_lr)

    align_criterion = AlignLoss(window_size=window_size, border=border).cuda()
    bce_criterion = nn.BCELoss().cuda()

    for epoch in range(start_epoch, epochs):
        loader = gen_obj.generator()
        train_iterator = tqdm(loader, total=n_batches + 1)
        net.train()

        for i, (rgb, gti, miss, mod, inj) in enumerate(train_iterator):
            # Combined masks are built on the raw arrays before conversion.
            mod_inj = np.logical_or(mod, inj)
            gti_miss = np.logical_or(gti, miss)

            (rgb, gti, miss, mod, inj, mod_inj, gti_miss) = map(
                _to_cuda_nchw, (rgb, gti, miss, mod, inj, mod_inj, gti_miss))

            # Train Generators
            optimizer_G.zero_grad()

            trs, rot, sca, seg, seg_miss, seg_inj = net(rgb, mod_inj)

            align_loss, proj = align_criterion(rgb, mod, gti, seg_inj, trs,
                                               rot, sca)
            seg_loss = bce_criterion(seg, gti_miss)
            miss_loss = bce_criterion(seg_miss, miss)
            inj_loss = bce_criterion(seg_inj, inj)

            net_loss = align_loss + seg_loss + miss_loss + inj_loss

            net_loss.backward()
            optimizer_G.step()

            losses = (net_loss, align_loss, seg_loss, miss_loss, inj_loss)
            shown = tuple(buf.push(loss.item())
                          for buf, loss in zip(loss_buffers, losses))
            status = ("[Epoch: %d][loss_net: %2.4f][align: %2.4f, seg: %2.4f, "
                      "miss: %2.4f, inj: %2.4f]") % ((epoch,) + shown)
            train_iterator.set_description(status)

            if i % save_sample == 0:
                sample_images(i, rgb, trs,
                              [gti, mod_inj, proj, seg_miss, seg_inj])

        torch.save(net.state_dict(),
                   os.path.join(models_path, f"alignNet_{epoch}"))

    # Read once after training, rather than at the end of every epoch: the
    # return no longer fails with UnboundLocalError when the epoch range is
    # empty, and training no longer aborts after the first checkpoint when
    # `args` is missing.
    return args.output


def mask_to_image(mask):
    """Convert a mask array to an 8-bit PIL image.

    The mask is multiplied by 255 and cast to ``uint8``, so it is presumably
    boolean or scaled to [0, 1]. Larger values do not fit in ``uint8`` after
    scaling.
    """
    pixel_values = (mask * 255).astype(np.uint8)
    image = Image.fromarray(pixel_values)
    return image


if __name__ == "__main__":
    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    #net = UNet(n_channels=3, n_classes=1)
    net = R2U_Net()
    logging.info("Loading model {}".format(args.model))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info("Model loaded !")

    for i, fn in enumerate(in_files):
        logging.info("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)

        mask = predict_img(net=net,