示例#1
0
    def get_MCD_model_list():
        """Build the MCD triple: feature generator + two pixel classifiers.

        Reads `net_name`, `n_class`, `res`, `input_ch` and
        `uses_one_classifier` from the enclosing scope.
        NOTE(review): closure variables — confirm they exist in the caller.

        Returns:
            (model_g, model_f1, model_f2)

        Raises:
            NotImplementedError: for unsupported `net_name` values.
        """
        if net_name == "fcn":
            from models.fcn import ResBase, ResClassifier
            model_g = ResBase(n_class, layer=res, input_ch=input_ch)
            model_f1 = ResClassifier(n_class)
            model_f2 = ResClassifier(n_class)
        elif net_name == "fcnvgg":
            from models.vgg_fcn import FCN8sBase, FCN8sClassifier
            # NOTE(review): input_ch is not forwarded here — confirm FCN8sBase
            # only supports the default channel count.
            model_g = FCN8sBase(n_class)
            model_f1 = FCN8sClassifier(n_class)
            model_f2 = FCN8sClassifier(n_class)
        elif "drn" in net_name:
            from models.dilated_fcn import DRNSegBase, DRNSegPixelClassifier_ADR
            if uses_one_classifier:
                # ADR variant of the pixel classifier.
                model_g = DRNSegBase(model_name=net_name,
                                     n_class=n_class,
                                     input_ch=input_ch)
                model_f1 = DRNSegPixelClassifier_ADR(n_class=n_class)
                model_f2 = DRNSegPixelClassifier_ADR(n_class=n_class)
            else:
                from models.dilated_fcn import DRNSegBase, DRNSegPixelClassifier
                model_g = DRNSegBase(model_name=net_name,
                                     n_class=n_class,
                                     input_ch=input_ch)
                model_f1 = DRNSegPixelClassifier(n_class=n_class)
                model_f2 = DRNSegPixelClassifier(n_class=n_class)

        else:
            # Fixed typo in the message ("PSPNetare" -> "PSPNet are").
            raise NotImplementedError(
                "Only FCN (Including Dilated FCN), SegNet, PSPNet are supported!"
            )

        return model_g, model_f1, model_f2
def get_multichannel_model(net_name,
                           input_ch_list,
                           n_class,
                           method="MCD",
                           res="50",
                           is_data_parallel=True):
    """Build one DRN encoder per input-channel group plus a fusion classifier.

    Args:
        net_name: backbone identifier; must contain "drn" (optionally "_ver2").
        input_ch_list: channel count for each front-end encoder.
        n_class: number of segmentation classes.
        method: fusion type is taken from the suffix after the last "-".
        res: unused here; kept for signature compatibility.
        is_data_parallel: wrap every model in ``torch.nn.DataParallel``.

    Returns:
        (front_model_list, end_model)

    Raises:
        NotImplementedError: if `net_name` is not a DRN variant.
    """
    from models.dilated_fcn import FusionDRNSegPixelClassifier

    if "drn" not in net_name:
        raise NotImplementedError("Only DRN are supported!")

    from models.dilated_fcn import DRNSegBase, DRNSegPixelClassifier
    fusion_type = method.split("-")[-1]
    ver = "ver2" if "ver2" in net_name else "ver1"
    drn_name = net_name.replace("_ver2", "")

    frond_model_list = [
        DRNSegBase(model_name=drn_name,
                   n_class=n_class,
                   input_ch=channels,
                   ver=ver)
        for channels in input_ch_list
    ]
    end_model = FusionDRNSegPixelClassifier(fusion_type=fusion_type,
                                            n_class=n_class,
                                            ver=ver)

    if not is_data_parallel:
        return frond_model_list, end_model

    wrapped = [torch.nn.DataParallel(m) for m in frond_model_list]
    return wrapped, torch.nn.DataParallel(end_model)
示例#3
0
    def build_model(self):
        """Build generator, pixel classifier, domain classifier and their SGD optimizers.

        Fixes vs. original:
        - ``self.model_f1`` was never created (only ``self.model_f`` is); the
          classifier optimizer now wraps ``self.model_f`` as intended instead
          of raising AttributeError.
        - the CUDA branch called ``.cuda()`` on ``self.generator`` /
          ``self.discriminator``, which are never assigned in this method;
          it now moves the three models actually built here to the GPU.
        """

        from models.dilated_fcn import DRNSegBase, DRNSegPixelClassifier, DRNSegDomainClassifier

        self.model_g = DRNSegBase(model_name=self.config.net, n_class=self.config.n_class)
        self.model_f = DRNSegPixelClassifier(n_class=self.config.n_class)
        self.model_d = DRNSegDomainClassifier(n_class=self.config.n_class)

        self.optimizer_g = torch.optim.SGD(self.model_g.parameters(), lr=self.config.lr, momentum=self.config.momentum,
                                           weight_decay=self.config.weight_decay)
        self.optimizer_d = torch.optim.SGD(self.model_d.parameters(), lr=self.config.lr, momentum=self.config.momentum,
                                           weight_decay=self.config.weight_decay)
        self.optimizer_f = torch.optim.SGD(list(self.model_f.parameters()), lr=self.config.lr,
                                           momentum=self.config.momentum,
                                           weight_decay=self.config.weight_decay)

        if torch.cuda.is_available():
            self.model_g.cuda()
            self.model_f.cuda()
            self.model_d.cuda()
    def get_mfnet_model_list():
        """Build a multi-modal (RGB + extra channel) model set for fusion training.

        Produces two front-end encoders (3-channel RGB and ``input_ch - 3``
        extra channels) plus two fusion classifiers.

        Reads `input_ch`, `net_name`, `method`, `n_class` and `res` from the
        enclosing scope.  NOTE(review): closure variables — confirm caller.

        Returns:
            (model_g_3ch, model_g_1ch, model_f1, model_f2)

        Raises:
            AssertionError: if `input_ch` is not 4 or 6.
            NotImplementedError: for unsupported `net_name` values.
        """
        from models.unet import MultiUNetClassifier

        assert input_ch in [4, 6]
        if "drn" in net_name:
            ver = "ver2" if "ver2" in net_name else "ver1"

            # Simplified redundant `True if ... else False` ternary.
            use_score_fusion = "score" in method.lower()

            drn_name = net_name.replace("_ver2", "")

            from models.dilated_fcn import DRNSegBase, DRNSegPixelClassifier, ScoreFusionDRNSegPixelClassifier
            fusion_type = method.split("-")[-1]

            print("fusion type: %s" % fusion_type)

            model_g_3ch = DRNSegBase(model_name=drn_name,
                                     n_class=n_class,
                                     input_ch=3,
                                     ver=ver)
            model_g_1ch = DRNSegBase(model_name=drn_name,
                                     n_class=n_class,
                                     input_ch=input_ch - 3,
                                     ver=ver)

            if use_score_fusion:
                # Fuse class-score maps rather than intermediate features.
                print("Score Fusion!!!")
                model_f1 = ScoreFusionDRNSegPixelClassifier(
                    fusion_type=fusion_type, n_class=n_class)
                model_f2 = ScoreFusionDRNSegPixelClassifier(
                    fusion_type=fusion_type, n_class=n_class)
            else:
                from models.dilated_fcn import FusionDRNSegPixelClassifier
                model_f1 = FusionDRNSegPixelClassifier(fusion_type=fusion_type,
                                                       n_class=n_class,
                                                       ver=ver)
                model_f2 = FusionDRNSegPixelClassifier(fusion_type=fusion_type,
                                                       n_class=n_class,
                                                       ver=ver)

        elif net_name == "unet":
            # TODO add "input_ch" argument
            from models.unet import UNetBase, UNetClassifier
            model_g_3ch = UNetBase(input_ch=3)
            model_g_1ch = UNetBase(input_ch=input_ch - 3)
            model_f1 = MultiUNetClassifier(n_classes=n_class)
            model_f2 = MultiUNetClassifier(n_classes=n_class)

        elif net_name == "fcn":
            # TODO add "input_ch" argument
            # NOTE(review): uses keyword `num_classes` while other call sites
            # pass `n_class` positionally to ResBase — confirm its signature.
            from models.fcn import ResBase, MFResClassifier2
            model_g_3ch = ResBase(num_classes=n_class, layer=res, input_ch=3)
            model_g_1ch = ResBase(num_classes=n_class,
                                  layer=res,
                                  input_ch=input_ch - 3)
            model_f1 = MFResClassifier2(n_class=n_class)
            model_f2 = MFResClassifier2(n_class=n_class)

        else:
            raise NotImplementedError("Only Dilated FCN is supported!")

        return model_g_3ch, model_g_1ch, model_f1, model_f2
    def get_MCD_model_list():
        """Build the MCD triple (generator + two classifiers) for a given backbone.

        Supports fcn / fcnvgg / psp / segnet / drn variants / unet / fusenet.
        Reads `net_name`, `n_class`, `res` and `input_ch` from the enclosing
        scope.  NOTE(review): closure variables — confirm caller scope.

        Returns:
            (model_g, model_f1, model_f2)

        Raises:
            NotImplementedError: for unsupported `net_name` values.
        """
        if net_name == "fcn":
            from models.fcn import ResBase, ResClassifier
            model_g = ResBase(n_class, layer=res, input_ch=input_ch)
            model_f1 = ResClassifier(n_class)
            model_f2 = ResClassifier(n_class)
        elif net_name == "fcnvgg":
            from models.vgg_fcn import FCN8sBase, FCN8sClassifier
            # TODO implement input_ch
            model_g = FCN8sBase(n_class)
            model_f1 = FCN8sClassifier(n_class)
            model_f2 = FCN8sClassifier(n_class)
        elif net_name == "psp":
            # TODO add "input_ch" argument
            from models.pspnet import PSPBase, PSPClassifier
            model_g = PSPBase(layer=res, input_ch=input_ch)
            model_f1 = PSPClassifier(num_classes=n_class)
            model_f2 = PSPClassifier(num_classes=n_class)
        elif net_name == "segnet":
            # TODO add "input_ch" argument
            from models.segnet import SegNetBase, SegNetClassifier
            model_g = SegNetBase()
            model_f1 = SegNetClassifier(n_class)
            model_f2 = SegNetClassifier(n_class)

        elif "drn" in net_name:
            if "fusenet" in net_name:
                # FuseNet-style DRN encoder (e.g. "drn_d_38_fusenet").
                drn_name = net_name.replace("_fusenet", "")

                from models.dilated_fcn import FuseDRNSegBase, DRNSegPixelClassifier
                model_g = FuseDRNSegBase(model_name=drn_name,
                                         n_class=n_class,
                                         input_ch=input_ch)
                model_f1 = DRNSegPixelClassifier(n_class=n_class)
                model_f2 = DRNSegPixelClassifier(n_class=n_class)

            else:
                ver = "ver2" if "ver2" in net_name else "ver1"
                drn_name = net_name.replace("_ver2", "")

                from models.dilated_fcn import DRNSegBase, DRNSegPixelClassifier
                model_g = DRNSegBase(model_name=drn_name,
                                     n_class=n_class,
                                     input_ch=input_ch,
                                     ver=ver)
                model_f1 = DRNSegPixelClassifier(n_class=n_class, ver=ver)
                model_f2 = DRNSegPixelClassifier(n_class=n_class, ver=ver)
        elif net_name == "unet":
            # TODO add "input_ch" argument
            from models.unet import UNetBase, UNetClassifier
            model_g = UNetBase(input_ch=input_ch)
            model_f1 = UNetClassifier(n_classes=n_class)
            model_f2 = UNetClassifier(n_classes=n_class)

        elif net_name == "fusenet":
            from models.FuseNet import FuseBase, FuseClassifier
            model_g = FuseBase(input_ch=input_ch)
            model_f1 = FuseClassifier(n_class=n_class)
            model_f2 = FuseClassifier(n_class=n_class)

        else:
            # Fixed punctuation in the message ("PSPNet UNet" -> "PSPNet, UNet").
            raise NotImplementedError(
                "Only FCN (Including Dilated FCN), SegNet, PSPNet, UNet are supported!"
            )

        return model_g, model_f1, model_f2
    # model_g = torch.nn.DataParallel(PSPBase(layer=args.res, input_ch=args.input_ch))
    model_g = PSPBase(layer=args.res, input_ch=args.input_ch)
    model_f = torch.nn.DataParallel(PSPClassifier(num_classes=args.n_class))
    model_d = torch.nn.DataParallel(Discriminator())
elif args.net == "segnet":
    # TODO add "input_ch" argument
    from models.segnet import SegNetBase, SegNetClassifier

    model_g = torch.nn.DataParallel(SegNetBase())
    model_f = torch.nn.DataParallel(SegNetClassifier(args.n_class))
    model_d = torch.nn.DataParallel(Discriminator())
elif "drn" in args.net:
    from models.dilated_fcn import DRNSegBase, DRNSegPixelClassifier, DRNSegDomainClassifier

    model_g = DRNSegBase(model_name=args.net,
                         n_class=args.n_class,
                         input_ch=args.input_ch)
    model_f = DRNSegPixelClassifier(n_class=args.n_class)
    model_d = DRNSegDomainClassifier(n_class=args.n_class)

else:
    raise NotImplementedError("Only FCN, SegNet, PSPNet are supported!")

# if args.opt == 'sgd':
#     optimizer_g = torch.optim.SGD(model_g.parameters(), lr=args.lr, momentum=args.momentum,
#                                   weight_decay=args.weight_decay)
#     optimizer_d = torch.optim.SGD(model_d.parameters(), lr=args.lr, momentum=args.momentum,
#                                   weight_decay=args.weight_decay)
#     optimizer_f = torch.optim.SGD(list(model_f.parameters()), lr=args.lr,
#                                   momentum=args.momentum,
#                                   weight_decay=args.weight_decay)
示例#7
0
class Solver(object):
    def __init__(self, args, data_loader):
        """Initialize the solver and build all sub-networks.

        Bug fix: ``self.config`` must be assigned *before* ``build_model()``
        is called, because ``build_model`` reads ``self.config.net`` etc.;
        the original assigned it afterwards, raising AttributeError.

        NOTE(review): `data_loader` is accepted but never stored — the
        training loop reads a module-level `train_loader`; confirm intent.
        """
        self.generator = None
        self.discriminator = None
        self.g_optimizer = None
        self.d_optimizer = None
        self.config = args
        self.build_model()

    def build_model(self):
        """Build generator, pixel classifier, domain classifier and their SGD optimizers.

        Fixes vs. original:
        - ``self.model_f1`` was never created (only ``self.model_f`` is); the
          classifier optimizer now wraps ``self.model_f`` as intended instead
          of raising AttributeError.
        - the CUDA branch called ``.cuda()`` on ``self.generator`` /
          ``self.discriminator``, which are still ``None`` at this point;
          it now moves the three models actually built here to the GPU.
        """

        from models.dilated_fcn import DRNSegBase, DRNSegPixelClassifier, DRNSegDomainClassifier

        self.model_g = DRNSegBase(model_name=self.config.net, n_class=self.config.n_class)
        self.model_f = DRNSegPixelClassifier(n_class=self.config.n_class)
        self.model_d = DRNSegDomainClassifier(n_class=self.config.n_class)

        self.optimizer_g = torch.optim.SGD(self.model_g.parameters(), lr=self.config.lr, momentum=self.config.momentum,
                                           weight_decay=self.config.weight_decay)
        self.optimizer_d = torch.optim.SGD(self.model_d.parameters(), lr=self.config.lr, momentum=self.config.momentum,
                                           weight_decay=self.config.weight_decay)
        self.optimizer_f = torch.optim.SGD(list(self.model_f.parameters()), lr=self.config.lr,
                                           momentum=self.config.momentum,
                                           weight_decay=self.config.weight_decay)

        if torch.cuda.is_available():
            self.model_g.cuda()
            self.model_f.cuda()
            self.model_d.cuda()

    def to_variable(self, x):
        """Convert tensor to variable."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def to_data(self, x):
        """Convert variable to tensor."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.discriminator.zero_grad()
        self.generator.zero_grad()

    def denorm(self, x):
        """Convert range (-1, 1) to (0, 1)"""
        out = (x + 1) / 2
        return out.clamp(0, 1)

    def train(self):
        """Adversarial training loop: alternately update (G, F) and the domain
        discriminator D.

        NOTE(review): reads module-level globals (`args`, `train_loader`,
        `criterion`, `criterion_d`, `pth_dir`, `log_value`,
        `adjust_learning_rate`, `save_checkpoint`) instead of `self.config`
        — confirm they exist in the enclosing script.
        """

        # Fixed domain labels: 1 = source domain, 0 = target domain.
        src_domain_lbl = Variable(torch.ones(args.batch_size).long())
        tgt_domain_lbl = Variable(torch.zeros(args.batch_size).long())

        for epoch in range(args.start_epoch, args.epochs):
            d_loss_per_epoch = 0
            c_loss_per_epoch = 0

            for ind, (source, target) in tqdm.tqdm(enumerate(train_loader)):
                # source = (images, labels); target supplies images only.
                src_imgs, src_lbls = Variable(source[0]), Variable(source[1])
                tgt_imgs = Variable(target[0])

                if torch.cuda.is_available():
                    src_imgs, src_lbls, tgt_imgs = src_imgs.cuda(), src_lbls.cuda(), tgt_imgs.cuda()
                    src_domain_lbl, tgt_domain_lbl = src_domain_lbl.cuda(), tgt_domain_lbl.cuda()

                # --- Step 1: update generator and classifier on source labels
                # while *fooling* the domain discriminator (negated d-loss).
                self.optimizer_g.zero_grad()
                self.optimizer_f.zero_grad()

                src_fet = self.model_g(src_imgs)
                tgt_fet = self.model_g(tgt_imgs)

                # for k, v in outputs.items():
                #     try:
                #         print ("%s: %s" % (k, v.size()))
                #     except AttributeError:
                #         print ("%s: %s" % (k, v))
                # DRN backbones yield a tensor; other backbones yield a dict of
                # feature maps, of which the discriminator consumes "fm4".
                if "drn" in args.net:
                    src_domain_pred = self.model_d(src_fet)
                    tgt_domain_pred = self.model_d(tgt_fet)
                else:
                    src_domain_pred = self.model_d(src_fet["fm4"])
                    tgt_domain_pred = self.model_d(tgt_fet["fm4"])

                # Negative sign: G is rewarded for confusing D.
                loss_d = - criterion_d(src_domain_pred, src_domain_lbl)
                loss_d -= criterion_d(tgt_domain_pred, tgt_domain_lbl)

                src_out = self.model_f(src_fet)
                loss = criterion(src_out, src_lbls)

                c_loss = loss.data[0]  # pre-0.4 PyTorch scalar extraction
                loss += loss_d
                loss.backward()
                c_loss_per_epoch += c_loss

                self.optimizer_g.step()
                self.optimizer_f.step()

                # --- Step 2: update the discriminator itself with a fresh
                # forward pass and the positive domain-classification loss.
                self.optimizer_g.zero_grad()
                self.optimizer_f.zero_grad()

                src_fet = self.model_g(src_imgs)
                tgt_fet = self.model_g(tgt_imgs)
                if "drn" in args.net:
                    src_domain_pred = self.model_d(src_fet)
                    tgt_domain_pred = self.model_d(tgt_fet)
                else:
                    src_domain_pred = self.model_d(src_fet["fm4"])
                    tgt_domain_pred = self.model_d(tgt_fet["fm4"])
                loss_d = criterion_d(src_domain_pred, src_domain_lbl)
                loss_d += criterion_d(tgt_domain_pred, tgt_domain_lbl)
                loss_d.backward()
                self.optimizer_d.step()
                self.optimizer_d.zero_grad()

                d_loss = 0
                # NOTE(review): divided by args.num_k although this update
                # runs once per iteration — confirm the intended scaling.
                d_loss += loss_d.data[0] / args.num_k
                d_loss_per_epoch += d_loss
                if ind % 100 == 0:
                    print("iter [%d] DLoss: %.6f CLoss: %.4f" % (ind, d_loss, c_loss))

                if ind > args.max_iter:
                    break

            print("Epoch [%d] DLoss: %.4f CLoss: %.4f" % (epoch, d_loss_per_epoch, c_loss_per_epoch))
            # ploter.plot("c_loss", "train", epoch + 1, c_loss_per_epoch)
            # ploter.plot("d_loss", "train", epoch + 1, d_loss_per_epoch)
            log_value('c_loss', c_loss_per_epoch, epoch)
            log_value('d_loss', d_loss_per_epoch, epoch)
            log_value('lr', args.lr, epoch)

            # Optionally decay the learning rate of G and F (D keeps its LR).
            if args.adjust_lr:
                args.lr = adjust_learning_rate(self.optimizer_g, args.lr, args.weight_decay, epoch, args.epochs)
                args.lr = adjust_learning_rate(self.optimizer_f, args.lr, args.weight_decay, epoch, args.epochs)

            # Checkpoint every epoch: model weights + optimizer states + args.
            checkpoint_fn = os.path.join(pth_dir,
                                         "%s-%s-res%s-%s.pth.tar" % (args.savename, args.net, args.res, epoch + 1))
            save_dic = {
                'epoch': epoch + 1,
                'res': args.res,
                'net': args.net,
                'args': args,
                'g_state_dict': self.model_g.state_dict(),
                'f1_state_dict': self.model_f.state_dict(),
                'd_state_dict': self.model_d.state_dict(),
                'self.optimizer_g': self.optimizer_g.state_dict(),
                'self.optimizer_f': self.optimizer_f.state_dict(),
                'self.optimizer_d': self.optimizer_d.state_dict(),
            }

            save_checkpoint(save_dic, is_best=False, filename=checkpoint_fn)

    def sample(self):
        """Load final generator/discriminator checkpoints and save sample images.

        NOTE(review): relies on attributes (`model_path`, `num_epochs`,
        `sample_size`, `z_dim`, `sample_path`) that are not assigned in the
        visible `__init__` — confirm they are set elsewhere.
        """

        # Load trained parameters
        g_path = os.path.join(self.model_path, 'generator-%d.pkl' % (self.num_epochs))
        d_path = os.path.join(self.model_path, 'discriminator-%d.pkl' % (self.num_epochs))
        self.generator.load_state_dict(torch.load(g_path))
        self.discriminator.load_state_dict(torch.load(d_path))
        self.generator.eval()
        self.discriminator.eval()

        # Sample the images from random noise and save a 12-per-row grid,
        # mapping pixel values back from (-1, 1) to (0, 1) via denorm().
        noise = self.to_variable(torch.randn(self.sample_size, self.z_dim))
        fake_images = self.generator(noise)
        sample_path = os.path.join(self.sample_path, 'fake_samples-final.png')
        torchvision.utils.save_image(self.denorm(fake_images.data), sample_path, nrow=12)

        print("Saved sampled images to '%s'" % sample_path)