示例#1
0
    def __init__(self, param):
        """Build the generator, its SGD optimizer, LR scheduler and loss.

        Args:
            param: configuration object. This constructor reads ``lr_g`` and
                ``momentum`` from it via attribute access and forwards the
                whole object to ``get_generator`` and ``get_scheduler``.
        """
        super(SemanticTrainer, self).__init__()

        # Initiate the networks
        self.generator = get_generator(param)

        # Setup the optimizer: two parameter groups, where the '10x' group is
        # trained with ten times the base generator learning rate.
        param_groups = [
            {
                'params': self.get_params(self.generator, key='1x'),
                'lr': param.lr_g,
            },
            {
                'params': self.get_params(self.generator, key='10x'),
                'lr': 10 * param.lr_g,
            },
        ]
        self.gen_opt = torch.optim.SGD(params=param_groups,
                                       momentum=param.momentum)
        self.gen_scheduler = get_scheduler(self.gen_opt, param)

        self.best_result = 0

        # Label value 255 is excluded from the cross-entropy loss.
        self.semantic_criterion = nn.CrossEntropyLoss(ignore_index=255)
示例#2
0
    def __init__(self, opts):
        """Build the generator network and, when training, its optimizer.

        Args:
            opts: options namespace. ``opts.net_G`` selects the generator.
                Training mode is enabled when ``opts`` has an ``lr``
                attribute, in which case ``lr``, ``beta1``, ``beta2`` and
                ``weight_decay`` are read to configure Adam.
        """
        super(CNNModel, self).__init__()

        self.loss_names = []
        self.networks = []
        self.optimizers = []

        # set default loss flags
        # The trailing comma is required: without it the parentheses hold a
        # plain string and the loop would iterate over its characters.
        loss_flags = ("w_img_BCE",)
        for flag in loss_flags:
            if not hasattr(opts, flag):
                setattr(opts, flag, 0)

        # Presence of a learning rate is what marks a training run.
        self.is_train = hasattr(opts, 'lr')

        self.net_G = get_generator(opts.net_G, opts)
        self.networks.append(self.net_G)

        if self.is_train:
            self.loss_names += ['loss_G_BCE']
            self.optimizer_G = torch.optim.Adam(self.net_G.parameters(),
                                                lr=opts.lr,
                                                betas=(opts.beta1, opts.beta2),
                                                weight_decay=opts.weight_decay)
            self.optimizers.append(self.optimizer_G)

        self.criterion = nn.BCEWithLogitsLoss()

        self.opts = opts
示例#3
0
    def __init__(self, param):
        """Set up the generator, two discriminators, optimizers and losses.

        Args:
            param: configuration object supporting both item access
                (``param['lr_d']``, ``param['beta1']``, ``param['beta2']``,
                ``param['weight_decay']``) and attribute access
                (``param.lr_g``, ``param.momentum``). It is also forwarded to
                ``get_generator``, ``get_discriminator`` and ``get_scheduler``.
        """
        super(Trainer, self).__init__()
        lr_d = param['lr_d']
        # Initiate the networks
        self.generator = get_generator(param)
        self.discriminator_bg = get_discriminator(param)
        self.discriminator_rf = get_discriminator(param)

        # Setup the optimizers
        beta1 = param['beta1']
        beta2 = param['beta2']
        # Both discriminators are updated by one shared Adam optimizer.
        dis_params = list(self.discriminator_bg.parameters()) + list(
            self.discriminator_rf.parameters())
        self.dis_opt = torch.optim.Adam(dis_params,
                                        lr=lr_d,
                                        betas=(beta1, beta2),
                                        weight_decay=param['weight_decay'])
        # Generator: the '10x' parameter group trains with ten times the
        # base generator learning rate.
        self.gen_opt = torch.optim.SGD(params=[{
            'params':
            self.get_params(self.generator, key='1x'),
            'lr':
            param.lr_g
        }, {
            'params':
            self.get_params(self.generator, key='10x'),
            'lr':
            10 * param.lr_g
        }],
                                       momentum=param.momentum)
        self.dis_scheduler = get_scheduler(self.dis_opt, param)
        self.gen_scheduler = get_scheduler(self.gen_opt, param)

        # Network weight initialization (applied to the discriminators only)
        self.discriminator_bg.apply(weights_init('gaussian'))
        self.discriminator_rf.apply(weights_init('gaussian'))

        self.perceptual_criterion = PerceptualLoss()
        self.retina_criterion = RetinaLoss()
        # Label value 255 is excluded from the cross-entropy loss.
        self.semantic_criterion = nn.CrossEntropyLoss(ignore_index=255)

        # NOTE(review): this attribute used to be set to float('inf') and then
        # overwritten with 0 a few lines later. Only the effective value (0)
        # is kept; confirm that 0 is the intended starting point.
        self.best_result = 0
示例#4
0
    def build_model(self):
        """Create the generator and place/cast it as the test config requests.

        The network is moved to CUDA unless ``test_config.cpu`` is set, and
        converted to half precision when ``test_config.fp16`` is set.
        """
        arch = self.model_config.arch
        self.G = networks.get_generator(encoder=arch.encoder,
                                        decoder=arch.decoder)

        run_on_gpu = not self.test_config.cpu
        if run_on_gpu:
            self.G.cuda()

        use_half = self.test_config.fp16
        if use_half:
            self.G = self.G.half()
示例#5
0
    def __init__(self, opts):
        """Build the two generator networks, optimizer and consistency layers.

        Args:
            opts: options namespace. ``opts.n_recurrent`` sets how many
                data-consistency layers of each kind are created and
                ``opts.net_G`` selects the generator. Training mode is
                enabled when ``opts`` has an ``lr`` attribute, in which case
                ``lr``, ``beta1``, ``beta2`` and ``weight_decay`` configure
                Adam.
        """
        super(RecurrentModel, self).__init__()

        self.loss_names = []
        self.networks = []
        self.optimizers = []

        self.n_recurrent = opts.n_recurrent

        # set default loss flags
        # The trailing comma is required: without it the parentheses hold a
        # plain string and the loop would iterate over its characters.
        loss_flags = ("w_img_L1",)
        for flag in loss_flags:
            if not hasattr(opts, flag):
                setattr(opts, flag, 0)

        # Presence of a learning rate is what marks a training run.
        self.is_train = hasattr(opts, 'lr')

        self.net_G_I = get_generator(opts.net_G, opts)
        self.net_G_K = get_generator(opts.net_G, opts)
        self.networks.append(self.net_G_I)
        self.networks.append(self.net_G_K)

        if self.is_train:
            self.loss_names += ['loss_G_L1']
            # A single Adam optimizer updates both generators.
            param = list(self.net_G_I.parameters()) + list(self.net_G_K.parameters())
            self.optimizer_G = torch.optim.Adam(param,
                                                lr=opts.lr,
                                                betas=(opts.beta1, opts.beta2),
                                                weight_decay=opts.weight_decay)
            self.optimizers.append(self.optimizer_G)

        self.criterion = nn.L1Loss()

        self.opts = opts

        # data consistency layers in image space & k-space:
        # one layer of each kind per recurrent step, kept in plain lists.
        self.dcs_I = [
            DataConsistencyInKspace_I(noise_lvl=None)
            for _ in range(self.n_recurrent)
        ]
        self.dcs_K = [
            DataConsistencyInKspace_K(noise_lvl=None)
            for _ in range(self.n_recurrent)
        ]
示例#6
0
    def _build_model(self):
        """Instantiate the networks in eval mode and move them to the GPU.

        The GCA generator is only created when ``self.alpha_dir`` is falsy;
        the spectral RIM module is always created. Nothing is moved to CUDA
        when ``self.cpu`` is set.
        """
        needs_gca = not self.alpha_dir

        if needs_gca:
            gca_net = networks.get_generator(encoder='resnet_gca_encoder_29',
                                             decoder='res_gca_decoder_22')
            self.gca = gca_net
            self.gca = self.gca.eval()

        rim_net = networks.SpectralRIM(sigma=0.1)
        self.rim = rim_net
        self.rim = self.rim.eval()

        # CPU-only run: leave every module where it was created.
        if self.cpu:
            return

        if needs_gca:
            self.gca.cuda()
        self.rim.cuda()
def create_generator(opt):
    """Build the generator and populate its weights.

    When ``opt.pre_train`` is truthy the generator is initialised through
    ``networks.init_generator`` and only its ``backbone`` receives weights,
    loaded from the path returned by ``networks.get_generator``. Otherwise
    the whole generator is loaded from ``opt.load_name``.

    Args:
        opt: options object forwarded to the ``networks`` helpers.

    Returns:
        The generator with weights loaded.
    """
    # Initialize the network
    fpn_load_name, generator = networks.get_generator(opt)

    if not opt.pre_train:
        # Restore the complete generator from a saved checkpoint
        full_weights = torch.load(opt.load_name)
        load_dict(generator, full_weights)
        print('Generator is loaded!')
        return generator

    # Fresh initialisation, then load weights into the backbone only
    generator = networks.init_generator(generator, opt)
    print('Initialize network with %s type' % opt.init_type)
    backbone_weights = torch.load(fpn_load_name)
    load_dict(generator.backbone, backbone_weights)
    print('Generator is created!')
    return generator
示例#8
0
    def __init__(self, model_dir, model_config):
        """Load the config and checkpoint, then build the network in eval mode.

        Args:
            model_dir: path passed to ``torch.load`` to read the checkpoint.
            model_config: path of the toml configuration file.
        """
        print("Loading GCA-Matting...")
        with open(model_config) as config_file:
            utils.load_config(toml.load(config_file))

        self.net = networks.get_generator(encoder=CONFIG.model.arch.encoder,
                                          decoder=CONFIG.model.arch.decoder)

        use_cuda = torch.cuda.is_available()
        # Without CUDA the checkpoint tensors are remapped onto the CPU.
        load_kwargs = {} if use_cuda else {'map_location': torch.device('cpu')}
        checkpoint = torch.load(model_dir, **load_kwargs)

        state_dict = utils.remove_prefix_state_dict(checkpoint['state_dict'])
        self.net.load_state_dict(state_dict, strict=True)

        if use_cuda:
            self.net.cuda()
        self.net.eval()
    def __init__(self, opts):
        """Create the generator, discriminator, their optimizers and losses.

        Args:
            opts: options namespace providing the network selectors
                (``net_G``, ``net_D``), the Adam settings (``lr``, ``beta1``,
                ``beta2``, ``weight_decay``), ``gpu_ids`` and ``wr_recon``.
        """
        super(GANModel, self).__init__()

        def build_adam(network):
            # Generator and discriminator use identical Adam settings.
            return torch.optim.Adam(network.parameters(),
                                    lr=opts.lr,
                                    betas=(opts.beta1, opts.beta2),
                                    weight_decay=opts.weight_decay)

        self.netG = get_generator(opts.net_G, opts)
        self.netD = get_discriminator(opts.net_D, opts)
        self.Sobelx, self.Sobely = get_gradoperator(opts)

        self.optimizer_G = build_adam(self.netG)
        self.optimizer_D = build_adam(self.netD)

        # Both criteria are placed on the first listed GPU.
        gan_loss = LSGANLoss()
        self.criterion_GAN = gan_loss.cuda(opts.gpu_ids[0])
        recon_loss = nn.L1Loss()
        self.criterion_recon = recon_loss.cuda(opts.gpu_ids[0])
        self.wr_recon = opts.wr_recon

        self.loss_names = ['loss_D', 'loss_G_GAN', 'loss_G_recon']
        self.opts = opts
示例#10
0
    def build_model(self):
        """Create the generator, its Adam optimizer and the parallel wrapper.

        With ``CONFIG.dist`` set, batch-norm layers are converted to synced
        batch norm and the network is wrapped in ``DistributedDataParallel``;
        otherwise it is wrapped in ``nn.DataParallel``. The LR scheduler is
        built last.
        """
        arch = self.model_config.arch
        self.G = networks.get_generator(encoder=arch.encoder,
                                        decoder=arch.decoder)
        self.G.cuda()

        distributed = CONFIG.dist
        if distributed:
            self.logger.info("Using pytorch synced BN")
            self.G = SyncBatchNorm.convert_sync_batchnorm(self.G)

        # The optimizer is created before the network is wrapped.
        train_cfg = self.train_config
        self.G_optimizer = torch.optim.Adam(
            self.G.parameters(),
            lr=train_cfg.G_lr,
            betas=[train_cfg.beta1, train_cfg.beta2])

        if not distributed:
            self.G = nn.DataParallel(self.G)
        else:
            # SyncBatchNorm only supports DistributedDataParallel with single GPU per process
            self.G = DistributedDataParallel(self.G,
                                             device_ids=[CONFIG.local_rank],
                                             output_device=CONFIG.local_rank)

        self.build_lr_scheduler()
示例#11
0
    def __init__(self):
        """Load the bundled config and checkpoint and build the model.

        Both files are resolved relative to the directory of this module.
        The model is left in eval mode.

        Raises:
            ValueError: if ``CONFIG.is_default`` is still truthy after the
                toml file has been loaded.
        """
        base_dir = os.path.dirname(__file__)

        with open(os.path.join(base_dir,
                               'config/gca-dist-all-data.toml')) as f:
            utils.load_config(toml.load(f))

        # Check if toml config file is loaded
        if CONFIG.is_default:
            raise ValueError("No .toml config loaded.")

        # build model
        self.model = networks.get_generator(encoder=CONFIG.model.arch.encoder,
                                            decoder=CONFIG.model.arch.decoder)

        # load checkpoint
        # The model is not moved to a GPU in this constructor, so the
        # checkpoint tensors are mapped to the CPU; this also lets a
        # checkpoint saved from a GPU load on a machine without CUDA.
        checkpoint = torch.load(os.path.join(base_dir,
                                             'gca-dist-all-data.pth'),
                                map_location='cpu')
        self.model.load_state_dict(utils.remove_prefix_state_dict(
            checkpoint['state_dict']),
                                   strict=True)

        # inference
        self.model.eval()
示例#12
0
    # NOTE(review): argparse applies `bool()` to the raw string, so any
    # non-empty value (including "False") yields True; only an empty string
    # gives False. Consider a store_true/store_false flag instead.
    parser.add_argument('--TTA', type=bool, default=True, help="testing time augmentation")

    # Parse configuration
    args = parser.parse_args()
    with open(args.config) as f:
        utils.load_config(toml.load(f))

    # Check if toml config file is loaded
    if CONFIG.is_default:
        raise ValueError("No .toml config loaded.")

    # Output goes into a sub-directory named "<CONFIG.version>_<checkpoint file name>".
    args.output = os.path.join(args.output, CONFIG.version+'_'+args.checkpoint.split('/')[-1])
    utils.make_dir(args.output)

    # build model
    model = networks.get_generator(encoder=CONFIG.model.arch.encoder, decoder=CONFIG.model.arch.decoder)
    model.cuda()

    # load checkpoint
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(utils.remove_prefix_state_dict(checkpoint['state_dict']), strict=True)
    
    # inference
    model = model.eval()

    for image_name in os.listdir(args.image_dir):
        # assume image and trimap have the same file name
        # (the trimap is looked up with a ".png" extension)
        image_path = os.path.join(args.image_dir, image_name)
        trimap_path = os.path.join(args.trimap_dir, os.path.splitext(image_name)[0]+".png")
        # trimap_path = os.path.join(args.trimap_dir, os.path.splitext(image_name)[0][:-5]+"trimap.png")
        print('Image: ', image_path, ' Tirmap: ', trimap_path)
示例#13
0
    for key, tensor in tensor_dict.items():
        if tensor is not None:
            tensor_dict[key] = reduce_tensor(tensor, mode)
    return tensor_dict


def reduce_tensor(tensor, mode='mean'):
    """
    Combine a tensor across all processes with a SUM all-reduce.

    The input is cloned first, so the caller's tensor is not modified.

    Args:
        tensor: tensor to reduce.
        mode: 'mean' divides the summed result by ``CONFIG.world_size``;
            'sum' returns the summed result unchanged.

    Returns:
        The reduced copy of ``tensor``.

    Raises:
        NotImplementedError: for any other ``mode`` (raised after the
            all-reduce has already run).
    """
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)

    if mode not in ('mean', 'sum'):
        raise NotImplementedError("reduce mode can only be 'mean' or 'sum'")

    if mode == 'mean':
        reduced /= CONFIG.world_size
    return reduced


if __name__ == "__main__":
    # Manual smoke test: build the generator on the GPU, load the ImageNet
    # pre-trained weights and run one forward pass on random inputs.
    import networks

    logging.basicConfig(
        level=logging.DEBUG,
        format='[%(asctime)s] %(levelname)s: %(message)s',
        datefmt='%m-%d %H:%M:%S')

    G = networks.get_generator()
    G = G.cuda()
    load_imagenet_pretrain(G, CONFIG.model.imagenet_pretrain_path)

    # Two random tensors of identical shape are fed to the network.
    x = torch.randn(4, 3, 512, 512)
    x = x.cuda()
    y = torch.randn(4, 3, 512, 512)
    y = y.cuda()
    z = G(x, y)