Esempio n. 1
0
def define_G(input_nc,
             output_nc,
             ngf,
             netG,
             norm='batch',
             dropout_rate=0,
             init_type='normal',
             init_gain=0.02,
             gpu_ids=None,
             opt=None):
    """Build the generator named ``netG`` and initialize it with ``init_net``.

    Args:
        input_nc: forwarded to the ResNet-family generators.
        output_nc: forwarded to the ResNet-family generators.
        ngf: forwarded to the ResNet-family generators except the "sub" one,
            which receives a decoded config instead.
        netG: architecture name; one of 'resnet_9blocks',
            'mobile_resnet_9blocks', 'super_mobile_resnet_9blocks',
            'sub_mobile_resnet_9blocks', 'spade', 'mobile_spade',
            'sub_mobile_spade'.
        norm: normalization type passed to ``get_norm_layer``.
        dropout_rate: forwarded to the ResNet-family generators.
        init_type: forwarded to ``init_net``.
        init_gain: forwarded to ``init_net``.
        gpu_ids: forwarded to ``init_net``; ``None`` (the default) means ``[]``.
        opt: option object; the SPADE generators are built from it and the
            "sub" generators require ``opt.config_str`` to be set.

    Returns:
        The value returned by ``init_net`` for the constructed network.

    Raises:
        NotImplementedError: if ``netG`` is not one of the names above.
        AssertionError: if a "sub" generator is requested while
            ``opt.config_str`` is None.
    """
    # A fresh list per call instead of a mutable default shared by all calls.
    if gpu_ids is None:
        gpu_ids = []
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'resnet_9blocks':
        from .modules.resnet_architecture.resnet_generator import ResnetGenerator
        net = ResnetGenerator(input_nc,
                              output_nc,
                              ngf,
                              norm_layer=norm_layer,
                              dropout_rate=dropout_rate,
                              n_blocks=9)
    elif netG == 'mobile_resnet_9blocks':
        from .modules.resnet_architecture.mobile_resnet_generator import MobileResnetGenerator
        net = MobileResnetGenerator(input_nc,
                                    output_nc,
                                    ngf=ngf,
                                    norm_layer=norm_layer,
                                    dropout_rate=dropout_rate,
                                    n_blocks=9)
    elif netG == 'super_mobile_resnet_9blocks':
        from .modules.resnet_architecture.super_mobile_resnet_generator import SuperMobileResnetGenerator
        net = SuperMobileResnetGenerator(input_nc,
                                         output_nc,
                                         ngf=ngf,
                                         norm_layer=norm_layer,
                                         dropout_rate=dropout_rate,
                                         n_blocks=9)
    elif netG == 'sub_mobile_resnet_9blocks':
        from .modules.resnet_architecture.sub_mobile_resnet_generator import SubMobileResnetGenerator
        assert opt.config_str is not None, \
            'config_str is required for [%s]' % netG
        config = decode_config(opt.config_str)
        net = SubMobileResnetGenerator(input_nc,
                                       output_nc,
                                       config,
                                       norm_layer=norm_layer,
                                       dropout_rate=dropout_rate,
                                       n_blocks=9)
    elif netG == 'spade':
        from .modules.spade_architecture.spade_generator import SPADEGenerator
        net = SPADEGenerator(opt)
    elif netG == 'mobile_spade':
        from .modules.spade_architecture.mobile_spade_generator import MobileSPADEGenerator
        net = MobileSPADEGenerator(opt)
    elif netG == 'sub_mobile_spade':
        from .modules.spade_architecture.sub_mobile_spade_generator import SubMobileSPADEGenerator
        assert opt.config_str is not None, \
            'config_str is required for [%s]' % netG
        config = decode_config(opt.config_str)
        net = SubMobileSPADEGenerator(opt, config)
    else:
        raise NotImplementedError(
            'Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)
Esempio n. 2
0
def define_G(netG, **kwargs):
    """Construct the generator ``netG`` from keyword arguments and initialize it.

    The generator class comes from ``get_netG_cls(netG)``. The keyword
    arguments that are required depend on the family ``netG`` belongs to:

    * 'resnet_9blocks', 'mobile_resnet_9blocks',
      'super_mobile_resnet_9blocks': ``input_nc``, ``output_nc``, ``ngf``;
      optional ``dropout_rate`` (default 0) and ``norm`` (default 'batch').
    * 'sub_mobile_resnet_9blocks', 'legacy_sub_mobile_resnet_9blocks':
      ``input_nc``, ``output_nc`` and ``opt`` (``opt.config_str`` must be
      set); optional ``dropout_rate`` and ``norm``.
    * 'spade', 'mobile_spade', 'super_mobile_spade', 'munit', 'super_munit',
      'super_mobile_munit': ``opt``.
    * 'sub_mobile_spade', 'sub_mobile_munit': ``opt`` with ``config_str`` set.

    ``init_type`` (default 'normal'), ``init_gain`` (default 0.02) and
    ``gpu_ids`` (default ``[]``) are forwarded to ``init_net``, whose result
    is returned.

    Raises:
        AssertionError: a required keyword argument is missing, or a "sub"
            generator is requested while ``opt.config_str`` is None.
        NotImplementedError: ``netG`` belongs to none of the families above.
    """
    Generator = get_netG_cls(netG)

    if netG in ('resnet_9blocks', 'mobile_resnet_9blocks',
                'super_mobile_resnet_9blocks'):
        assert 'input_nc' in kwargs and 'output_nc' in kwargs and 'ngf' in kwargs
        norm_layer = get_norm_layer(norm_type=kwargs.get('norm', 'batch'))
        net = Generator(kwargs.get('input_nc'),
                        kwargs.get('output_nc'),
                        ngf=kwargs.get('ngf'),
                        norm_layer=norm_layer,
                        dropout_rate=kwargs.get('dropout_rate', 0),
                        n_blocks=9)
    elif netG in ('sub_mobile_resnet_9blocks',
                  'legacy_sub_mobile_resnet_9blocks'):
        assert 'input_nc' in kwargs and 'output_nc' in kwargs and 'opt' in kwargs
        norm_layer = get_norm_layer(norm_type=kwargs.get('norm', 'batch'))
        sub_opt = kwargs.get('opt')
        assert sub_opt.config_str is not None
        net = Generator(kwargs.get('input_nc'),
                        kwargs.get('output_nc'),
                        decode_config(sub_opt.config_str),
                        norm_layer=norm_layer,
                        dropout_rate=kwargs.get('dropout_rate', 0),
                        n_blocks=9)
    elif netG in ('spade', 'mobile_spade', 'super_mobile_spade', 'munit',
                  'super_munit', 'super_mobile_munit'):
        assert 'opt' in kwargs
        net = Generator(kwargs.get('opt'))
    elif netG in ('sub_mobile_spade', 'sub_mobile_munit'):
        assert 'opt' in kwargs
        sub_opt = kwargs.get('opt')
        assert sub_opt.config_str is not None
        net = Generator(sub_opt, decode_config(sub_opt.config_str))
    else:
        raise NotImplementedError(
            'Generator model name [%s] is not recognized' % netG)

    return init_net(net,
                    kwargs.get('init_type', 'normal'),
                    kwargs.get('init_gain', 0.02),
                    kwargs.get('gpu_ids', []))
Esempio n. 3
0
def main(opt):
    """Extract the subnet described by ``opt.config_str`` from a supernet and save it.

    For ``opt.model`` 'mobile_resnet' or 'mobile_spade', this builds the
    supernet and the matching subnet. It then loads the supernet weights
    from ``opt.input_path``, copies them into the subnet with
    ``transfer_weight``, and writes the subnet ``state_dict`` to
    ``opt.output_path`` with ``torch.save``.

    Raises:
        NotImplementedError: if ``opt.model`` is any other value.
    """
    config = decode_config(opt.config_str)
    if opt.model == 'mobile_resnet':
        from models.modules.resnet_architecture.mobile_resnet_generator import MobileResnetGenerator as SuperModel
        from models.modules.resnet_architecture.sub_mobile_resnet_generator import SubMobileResnetGenerator as SubModel
        input_nc, output_nc = opt.input_nc, opt.output_nc
        super_model = SuperModel(input_nc,
                                 output_nc,
                                 ngf=opt.ngf,
                                 norm_layer=nn.InstanceNorm2d,
                                 n_blocks=9)
        sub_model = SubModel(input_nc,
                             output_nc,
                             config=config,
                             norm_layer=nn.InstanceNorm2d,
                             n_blocks=9)
    elif opt.model == 'mobile_spade':
        from models.modules.spade_architecture.mobile_spade_generator import MobileSPADEGenerator as SuperModel
        from models.modules.spade_architecture.sub_mobile_spade_generator import SubMobileSPADEGenerator as SubModel
        # These option fields are overwritten unconditionally before the SPADE
        # generators are built from ``opt``.
        opt.norm_G = 'spadesyncbatch3x3'
        opt.num_upsampling_layers = 'more'
        # input_nc, plus one channel if contain_dontcare_label is set, plus
        # one more unless no_instance is set.
        opt.semantic_nc = (opt.input_nc
                           + (1 if opt.contain_dontcare_label else 0)
                           + (0 if opt.no_instance else 1))
        super_model = SuperModel(opt)
        sub_model = SubModel(opt, config)
    else:
        raise NotImplementedError('Unknown architecture [%s]!' % opt.model)

    load_network(super_model, opt.input_path)
    transfer_weight(super_model, sub_model)

    output_dir = os.path.dirname(opt.output_path)
    # dirname() is '' for a bare filename and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when one is named.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(sub_model.state_dict(), opt.output_path)
    print('Successfully export the subnet at [%s].' % opt.output_path)
Esempio n. 4
0
def main(opt):
    """Extract a MobileResnet subnet from a supernet checkpoint and save it.

    The supernet is built with ``opt.ngf`` channels and the subnet from the
    config decoded from ``opt.config_str``; both use ``nn.InstanceNorm2d``
    and 9 blocks. The supernet weights are loaded from ``opt.input_path``
    and copied into the subnet with ``transfer_weight``. The subnet
    ``state_dict`` is then written to ``opt.output_path`` with ``torch.save``.

    Raises:
        NotImplementedError: if ``opt.model`` is 'mobile_spade' (not supported
            yet) or any value other than 'mobile_resnet'.
    """
    if opt.model == 'mobile_resnet':
        from models.modules.resnet_architecture.mobile_resnet_generator import MobileResnetGenerator as SuperModel
        from models.modules.resnet_architecture.sub_mobile_resnet_generator import SubMobileResnetGenerator as SubModel
    elif opt.model == 'mobile_spade':
        # TODO: support subnet export for the SPADE architecture.
        raise NotImplementedError(
            'Subnet export is not implemented for [%s] yet!' % opt.model)
    else:
        raise NotImplementedError('Unknown architecture [%s]!' % opt.model)

    config = decode_config(opt.config_str)

    input_nc, output_nc = opt.input_nc, opt.output_nc
    super_model = SuperModel(input_nc,
                             output_nc,
                             ngf=opt.ngf,
                             norm_layer=nn.InstanceNorm2d,
                             n_blocks=9)
    sub_model = SubModel(input_nc,
                         output_nc,
                         config=config,
                         norm_layer=nn.InstanceNorm2d,
                         n_blocks=9)

    load_network(super_model, opt.input_path)
    transfer_weight(super_model, sub_model)

    output_dir = os.path.dirname(opt.output_path)
    # dirname() is '' for a bare filename and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when one is named.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(sub_model.state_dict(), opt.output_path)
    print('Successfully export the subnet at [%s].' % opt.output_path)
Esempio n. 5
0
 def __init__(self, opt):
     """Initialize the MUNIT supernet and choose the configurations to evaluate.

     ``opt.student_netG`` must contain 'super'. Exactly one of
     ``opt.config_set`` and ``opt.config_str`` must be set:

     * with ``config_set``, ``self.configs`` comes from ``get_configs`` and
       ``self.opt.eval_mode`` is 'both';
     * otherwise ``config_str`` is decoded into a ``SingleConfigs`` and
       ``self.opt.eval_mode`` is 'largest'.
     """
     assert 'super' in opt.student_netG
     super(MunitSupernet, self).__init__(opt)

     # Best-FID trackers start at a large sentinel; presumably lower FID is
     # better, so any real score replaces it.
     self.best_fid_largest = self.best_fid_smallest = 1e9
     self.fids_largest = []
     self.fids_smallest = []

     if opt.config_set is None:
         assert opt.config_str is not None
         self.configs = SingleConfigs(decode_config(opt.config_str))
         self.opt.eval_mode = 'largest'
     else:
         assert opt.config_str is None
         self.configs = get_configs(opt.config_set)
         self.opt.eval_mode = 'both'
Esempio n. 6
0
def main(cfgs):
    """Export a MobileResnet subnet from a supernet checkpoint (PaddlePaddle).

    Args:
        cfgs: parsed options. The function reads ``config_str``, ``model``,
            ``input_nc``, ``output_nc``, ``ngf``, ``input_path`` and
            ``output_path``.

    The supernet weights are loaded from ``cfgs.input_path`` and copied into
    the subnet described by ``cfgs.config_str``. The subnet ``state_dict`` is
    saved with ``fluid.save_dygraph``, using ``<output_path>/final_net`` as
    the save path.

    Raises:
        NotImplementedError: if ``cfgs.model`` is not 'mobile_resnet'.
    """
    fluid.enable_imperative()
    # Every option is read from the ``cfgs`` parameter rather than a
    # module-level global, so the function works with any options object.
    config = decode_config(cfgs.config_str)
    if cfgs.model == 'mobile_resnet':
        from model.mobile_generator import MobileResnetGenerator as SuperModel
        from model.sub_mobile_generator import SubMobileResnetGenerator as SubModel
        input_nc, output_nc = cfgs.input_nc, cfgs.output_nc
        super_model = SuperModel(input_nc, output_nc, ngf=cfgs.ngf, norm_layer=InstanceNorm, n_blocks=9)
        sub_model = SubModel(input_nc, output_nc, config=config, norm_layer=InstanceNorm, n_blocks=9)
    else:
        raise NotImplementedError('Unknown architecture [%s]!' % cfgs.model)

    load_network(super_model, cfgs.input_path)
    transfer_weight(super_model, sub_model)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(cfgs.output_path, exist_ok=True)
    save_path = os.path.join(cfgs.output_path, 'final_net')
    fluid.save_dygraph(sub_model.state_dict(), save_path)
    print('Successfully export the subnet at [%s].' % save_path)
Esempio n. 7
0
 def __init__(self, opt):
     """Initialize the SPADE supernet, its metric trackers and a training loader.

     Exactly one of ``opt.config_set`` and ``opt.config_str`` must be set:

     * with ``config_set``, ``self.configs`` comes from ``get_configs``,
       ``self.opt.eval_mode`` is 'both' and ``self.opt.no_calibration`` is
       False;
     * otherwise ``config_str`` is decoded into a ``SingleConfigs``,
       ``eval_mode`` is 'largest' and ``no_calibration`` is True.

     ``self.train_dataloader`` is built from a deep copy of ``opt`` with
     ``load_in_memory=False`` and ``max_dataset_size=256``.
     """
     super(SPADESupernet, self).__init__(opt)

     # Metric trackers start at sentinels: FID at +1e9 and mIoU at -1e9
     # (presumably lower FID and higher mIoU are better).
     self.best_fid_largest = self.best_fid_smallest = 1e9
     self.best_mIoU_largest = self.best_mIoU_smallest = -1e9
     self.fids_largest = []
     self.fids_smallest = []
     self.mIoUs_largest = []
     self.mIoUs_smallest = []

     use_config_set = opt.config_set is not None
     if use_config_set:
         assert opt.config_str is None
         self.configs = get_configs(opt.config_set)
     else:
         assert opt.config_str is not None
         self.configs = SingleConfigs(decode_config(opt.config_str))
     self.opt.eval_mode = 'both' if use_config_set else 'largest'
     self.opt.no_calibration = not use_config_set

     # The overrides below go into a deep copy, so the caller's ``opt`` is
     # not modified. NOTE(review): if ``self.opt`` is the same object as
     # ``opt``, the copy also carries the eval_mode/no_calibration values set
     # above -- keep this after those assignments.
     loader_opt = copy.deepcopy(opt)
     loader_opt.load_in_memory = False
     loader_opt.max_dataset_size = 256
     self.train_dataloader = create_dataloader(loader_opt, verbose=False)
Esempio n. 8
0
    # NOTE(review): the enclosing `def` line of this function is not visible;
    # these lines validate test-time options carried on `opt`.
    # Only the 'resize_and_crop' preprocessing and a batch size of 1 are accepted.
    assert opt.preprocess == 'resize_and_crop'
    assert opt.batch_size == 1

    # Unless FID is disabled, a path to real-image statistics must be given
    # (presumably used as the FID reference -- confirm).
    if not opt.no_fid:
        assert opt.real_stat_path is not None
    # Running on the training split is allowed, but a warning is emitted.
    if opt.phase == 'train':
        warnings.warn('You are using training set for inference.')


if __name__ == '__main__':
    # Parse test options, echo the command line and fix the random seed.
    opt = TestOptions().parse()
    print(' '.join(sys.argv))
    set_seed(opt.seed)
    # With a config string, the generator name must contain 'super' or 'sub'
    # and the decoded config is kept in `config`; without one, `config` is
    # None and the model name must not contain 'super'.
    if opt.config_str is not None:
        assert 'super' in opt.netG or 'sub' in opt.netG
        config = decode_config(opt.config_str)
    else:
        # NOTE(review): this branch checks opt.model while the branch above
        # checks opt.netG -- confirm which field is intended.
        assert 'super' not in opt.model
        config = None
    # NOTE(review): `config` is not used in the lines visible here;
    # presumably it is consumed further down in this script.

    dataloader = create_dataloader(opt)
    model = create_model(opt)
    model.setup(opt, verbose=False)

    web_dir = opt.results_dir  # define the website directory
    # The 'munit_test' model reports two generator checkpoint paths (A and B)
    # in the html.HTML header string; every other model reports restore_G_path.
    if opt.model == 'munit_test':
        webpage = html.HTML(
            web_dir, 'G_A_path: %s\tG_B_path: %s' %
            (opt.restore_G_A_path, opt.restore_G_B_path))
    else:
        webpage = html.HTML(web_dir, 'G_path: %s' % (opt.restore_G_path))