def __init__(self, opt):
        """Set up the search environment: config space, model, data, metrics, log.

        Imports the config-space helper matching ``opt.netG``, builds the
        dataloader and model, primes the model with one batch, creates the
        metric models, and opens an append-mode log file in ``opt.output_dir``.

        Args:
            opt: parsed options; must provide ``netG``, ``config_set``,
                ``real_stat_path``, and ``output_dir``.
        """
        self.opt = opt
        # Select the config space that matches the generator architecture.
        if 'resnet' in opt.netG:
            from configs.resnet_configs import get_configs
        elif 'spade' in opt.netG:
            from configs.spade_configs import get_configs
        elif 'munit' in opt.netG:
            from configs.munit_configs import get_configs
        else:
            raise NotImplementedError
        self.configs = get_configs(config_name=opt.config_set)

        self.dataloader = create_dataloader(opt)
        model = create_model(opt)
        model.setup(opt)
        # Feed a single batch so the model holds concrete input tensors
        # (presumably needed for later MACs profiling — see macs_cache).
        for data_i in self.dataloader:
            model.set_input(data_i)
            break
        self.model = model
        self.device = model.device
        self.inception_model, self.drn_model, self.deeplabv2_model = create_metric_models(
            opt, self.device)
        # Real-image statistics are only loaded when an inception model exists.
        if self.inception_model is not None:
            self.npz = np.load(opt.real_stat_path)
        self.macs_cache = {}    # cache: config -> measured MACs
        self.result_cache = {}  # cache: config -> evaluation result

        # Append mode so repeated runs accumulate in the same log file.
        self.log_file = open(os.path.join(opt.output_dir, 'log.txt'), 'a')
        now = time.strftime('%c')
        self.log_file.write('================ (%s) ================\n' % now)
        self.log_file.flush()
Пример #2
0
def main(cfgs):
    """Evaluate every sub-network configuration and pickle the search results.

    For each candidate config: profile FLOPs, run the model over the eval
    set collecting generated images, optionally compute FID against the
    precomputed real statistics in ``cfgs.real_stat_path``, and finally dump
    the list of per-config result dicts to ``cfgs.save_dir/search_result.pkl``.
    """
    fluid.enable_imperative()
    # Select the config space that matches the generator architecture.
    if 'resnet' in cfgs.netG:
        from configs.resnet_configs import get_configs
    else:
        raise NotImplementedError
    configs = get_configs(config_name=cfgs.config_set)
    configs = list(configs.all_configs())

    data_loader, id2name = create_eval_data(cfgs, direction=cfgs.direction)
    model = TestModel(cfgs)
    model.setup()  # load_network

    # Prime the model with one batch; used when computing model FLOPs/params.
    for data in data_loader:
        model.set_input(data)
        break

    # Real-image Inception statistics used as the FID reference.
    npz = np.load(cfgs.real_stat_path)
    results = []
    for config in configs:
        fakes, names = [], []
        flops, _ = model.profile(config=config)
        s_time = time.time()
        for i, data in enumerate(data_loader()):
            model.set_input(data)
            model.test(config)
            generated = model.fake_B
            fakes.append(generated.detach().numpy())
            name = id2name[i]
            save_path = os.path.join(cfgs.save_dir, 'test' + str(config))
            # exist_ok avoids the check-then-create race of the original.
            os.makedirs(save_path, exist_ok=True)
            save_path = os.path.join(save_path, name)
            names.append(name)
            # Only the first num_test generated images are written to disk.
            if i < cfgs.num_test:
                image = util.tensor2img(generated)
                util.save_image(image, save_path)

        result = {'config_str': encode_config(config), 'flops': flops}  # compute FLOPs

        # FID computation runs outside imperative (dygraph) mode.
        fluid.disable_imperative()
        if not cfgs.no_fid:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            inception_model = InceptionV3([block_idx])
            fid = get_fid(fakes, inception_model, npz, cfgs.inception_model_path,
                          batch_size=cfgs.batch_size, use_gpu=cfgs.use_gpu)
            result['fid'] = fid
        fluid.enable_imperative()

        e_time = (time.time() - s_time) / 60
        result['time'] = e_time  # minutes spent evaluating this config
        print(result)
        results.append(result)

    # BUG FIX: the original created os.path.dirname(cfgs.save_dir) instead of
    # cfgs.save_dir itself, so opening the pickle below could still fail.
    os.makedirs(cfgs.save_dir, exist_ok=True)
    save_file = os.path.join(cfgs.save_dir, 'search_result.pkl')
    with open(save_file, 'wb') as f:
        pickle.dump(results, f)
    print('Successfully finish searching!!!')
Пример #3
0
def get_config_split(opt):
    """Shuffle the full configuration space and return one worker's shard.

    The candidate configs for ``opt.config_set`` are shuffled, divided into
    ``opt.num_splits`` roughly equal chunks, and the chunk with index
    ``opt.split`` is returned.

    Raises:
        NotImplementedError: if ``opt.netG`` matches no known config space.
    """
    if 'resnet' in opt.netG:
        from configs.resnet_configs import get_configs
    elif 'spade' in opt.netG:
        from configs.spade_configs import get_configs
    else:
        raise NotImplementedError
    all_configs = list(get_configs(config_name=opt.config_set).all_configs())
    random.shuffle(all_configs)
    shards = np.array_split(np.array(all_configs), opt.num_splits)
    return shards[opt.split]
Пример #4
0
 def __init__(self, opt):
     """Initialize supernet training state and the evaluation config space.

     Requires a 'super' student generator. Tracks best-so-far FID / mIoU for
     both the largest and smallest sub-networks, and builds the config space
     either from ``opt.config_set`` (search over many configs) or from a
     single ``opt.config_str``; the two options are mutually exclusive.
     """
     assert 'super' in opt.student_netG
     super(ResnetSupernet, self).__init__(opt)
     # Best-so-far metrics (FID: lower is better; mIoU: higher is better).
     self.best_fid_largest = 1e9
     self.best_fid_smallest = 1e9
     self.best_mIoU_largest = -1e9
     self.best_mIoU_smallest = -1e9
     # Histories of per-evaluation metric values.
     self.fids_largest, self.fids_smallest = [], []
     self.mIoUs_largest, self.mIoUs_smallest = [], []
     if opt.config_set is not None:
         assert opt.config_str is None
         self.configs = get_configs(opt.config_set)
         self.opt.eval_mode = 'both'
     else:
         assert opt.config_str is not None
         self.configs = SingleConfigs(decode_config(opt.config_str))
         self.opt.eval_mode = 'largest'
Пример #5
0
    def __init__(self, opt):
        """Build the evaluation environment: config space, model, data, metrics.

        Note: the real-image statistics are loaded unconditionally here, so
        ``opt.real_stat_path`` must point to a valid .npz file even when the
        inception model is unavailable.
        """
        self.opt = opt
        # Select the config space that matches the generator architecture.
        if 'resnet' in opt.netG:
            from configs.resnet_configs import get_configs
        elif 'spade' in opt.netG:
            from configs.spade_configs import get_configs
        else:
            raise NotImplementedError
        self.configs = get_configs(config_name=opt.config_set)

        self.dataloader = create_dataloader(opt)
        model = create_model(opt)
        model.setup(opt)
        # Feed a single batch so the model holds concrete input tensors
        # (presumably needed for later MACs profiling — see macs_cache).
        for data_i in self.dataloader:
            model.set_input(data_i)
            break
        self.model = model
        self.device = model.device
        self.inception_model, self.drn_model, self.deeplabv2_model = create_metric_models(opt, self.device)
        self.npz = np.load(opt.real_stat_path)  # real-image statistics (FID reference)
        self.macs_cache = {}    # cache: config -> measured MACs
        self.result_cache = {}  # cache: config -> evaluation result
Пример #6
0
if __name__ == '__main__':
    mp.set_start_method('spawn')
    opt = SearchOptions().parse()
    print(' '.join(sys.argv), flush=True)
    check(opt)
    set_seed(opt.seed)

    if 'resnet' in opt.netG:
        from configs.resnet_configs import get_configs
    elif 'spade' in opt.netG:
        # TODO
        pass
    else:
        raise NotImplementedError
    configs = get_configs(config_name=opt.config_set)
    configs = list(configs.all_configs())
    random.shuffle(configs)

    chunk_size = (len(configs) + len(opt.gpu_ids) - 1) // len(opt.gpu_ids)

    processes = []
    queue = mp.Queue()

    for i, gpu_id in enumerate(opt.gpu_ids):
        start = min(i * chunk_size, len(configs))
        end = min((i + 1) * chunk_size, len(configs))
        p = mp.Process(target=main,
                       args=(configs[start:end], copy.deepcopy(opt), gpu_id,
                             queue, i == 0))
        processes.append(p)