def __init__(self, opt):
    """Set up the evaluator: config space, data, model, metric nets and log.

    Args:
        opt: parsed option namespace. ``netG``, ``config_set``,
            ``real_stat_path`` and ``output_dir`` are read here; the whole
            namespace is also forwarded to the project factory helpers.

    Raises:
        NotImplementedError: if ``opt.netG`` contains none of ``resnet``,
            ``spade`` or ``munit``.
    """
    self.opt = opt

    # Pick the configuration module matching the generator family.
    generator_name = opt.netG
    if 'resnet' in generator_name:
        from configs.resnet_configs import get_configs
    elif 'spade' in generator_name:
        from configs.spade_configs import get_configs
    elif 'munit' in generator_name:
        from configs.munit_configs import get_configs
    else:
        raise NotImplementedError
    self.configs = get_configs(config_name=opt.config_set)
    self.dataloader = create_dataloader(opt)

    model = create_model(opt)
    model.setup(opt)
    # Attach only the first batch (if any) as the model's input.
    for first_batch in self.dataloader:
        model.set_input(first_batch)
        break
    self.model = model
    self.device = model.device

    metric_models = create_metric_models(opt, self.device)
    self.inception_model, self.drn_model, self.deeplabv2_model = metric_models
    # Reference statistics are only loaded when an Inception model exists.
    if self.inception_model is not None:
        self.npz = np.load(opt.real_stat_path)

    self.macs_cache = {}
    self.result_cache = {}

    # Append-mode log; a timestamped separator marks the start of this run.
    log_path = os.path.join(opt.output_dir, 'log.txt')
    self.log_file = open(log_path, 'a')
    timestamp = time.strftime('%c')
    self.log_file.write('================ (%s) ================\n' % timestamp)
    self.log_file.flush()
def main(cfgs):
    """Evaluate every sub-network configuration and pickle the results.

    For each configuration from ``get_configs(config_name=cfgs.config_set)``
    this profiles FLOPs and runs the generator over the whole evaluation set.
    It saves the first ``cfgs.num_test`` images under
    ``<save_dir>/test<config>``, optionally computes FID, and records the
    elapsed wall time in minutes. The list of per-config result dicts is
    written to ``<save_dir>/search_result.pkl``.

    Args:
        cfgs: parsed option namespace (reads netG, config_set, direction,
            real_stat_path, save_dir, num_test, no_fid, inception_model_path,
            batch_size, use_gpu).

    Raises:
        NotImplementedError: if ``cfgs.netG`` is not a resnet generator.
    """
    fluid.enable_imperative()
    if 'resnet' in cfgs.netG:
        from configs.resnet_configs import get_configs
    else:
        raise NotImplementedError
    configs = list(get_configs(config_name=cfgs.config_set).all_configs())

    data_loader, id2name = create_eval_data(cfgs, direction=cfgs.direction)
    model = TestModel(cfgs)
    model.setup()  # load_network

    # This input is used when computing model FLOPs and params.
    for data in data_loader:
        model.set_input(data)
        break

    npz = np.load(cfgs.real_stat_path)
    results = []
    for config in configs:
        fakes = []
        flops, _ = model.profile(config=config)
        s_time = time.time()

        # The output directory depends only on the config, so create it once
        # here rather than re-checking it for every image.
        config_dir = os.path.join(cfgs.save_dir, 'test' + str(config))
        os.makedirs(config_dir, exist_ok=True)

        for i, data in enumerate(data_loader()):
            model.set_input(data)
            model.test(config)
            generated = model.fake_B
            fakes.append(generated.detach().numpy())
            if i < cfgs.num_test:
                image = util.tensor2img(generated)
                util.save_image(image, os.path.join(config_dir, id2name[i]))

        result = {'config_str': encode_config(config), 'flops': flops}

        # The FID section runs with imperative mode switched off; it is
        # switched back on before the next config is evaluated.
        fluid.disable_imperative()
        if not cfgs.no_fid:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
            inception_model = InceptionV3([block_idx])
            fid = get_fid(fakes, inception_model, npz,
                          cfgs.inception_model_path,
                          batch_size=cfgs.batch_size, use_gpu=cfgs.use_gpu)
            result['fid'] = fid
        fluid.enable_imperative()

        result['time'] = (time.time() - s_time) / 60  # minutes
        print(result)
        results.append(result)

    # Create save_dir itself (not just its parent) so the pickle can be
    # written even if no per-config directory was created under it.
    os.makedirs(cfgs.save_dir, exist_ok=True)
    save_file = os.path.join(cfgs.save_dir, 'search_result.pkl')
    with open(save_file, 'wb') as f:
        pickle.dump(results, f)
    print('Successfully finish searching!!!')
def get_config_split(opt):
    """Return this worker's share of the shuffled configuration space.

    All configurations of ``opt.config_set`` are shuffled with the global
    ``random`` state, split with ``np.array_split`` into ``opt.num_splits``
    nearly equal chunks, and chunk number ``opt.split`` is returned (as a
    NumPy array).

    Raises:
        NotImplementedError: if ``opt.netG`` is neither a resnet nor a spade
            generator.
    """
    net_g = opt.netG
    if 'resnet' in net_g:
        from configs.resnet_configs import get_configs
    elif 'spade' in net_g:
        from configs.spade_configs import get_configs
    else:
        raise NotImplementedError

    candidates = get_configs(config_name=opt.config_set).all_configs()
    shuffled = list(candidates)
    random.shuffle(shuffled)

    chunks = np.array_split(np.array(shuffled), opt.num_splits)
    return chunks[opt.split]
def __init__(self, opt):
    """Initialise the supernet trainer and choose the configs it evaluates.

    Args:
        opt: parsed option namespace. ``opt.student_netG`` is asserted to
            contain ``super``. Exactly one of ``opt.config_set`` /
            ``opt.config_str`` is expected, as enforced by asserts:
            - a config set selects ``get_configs(opt.config_set)`` and sets
              ``eval_mode`` to ``'both'``;
            - a config string is decoded into ``SingleConfigs`` and sets
              ``eval_mode`` to ``'largest'``.
    """
    assert 'super' in opt.student_netG
    super(ResnetSupernet, self).__init__(opt)

    # Best-so-far metrics, tracked separately for the "largest" and
    # "smallest" evaluations: FID starts very high, mIoU very low.
    self.best_fid_largest = self.best_fid_smallest = 1e9
    self.best_mIoU_largest = self.best_mIoU_smallest = -1e9

    # Metric history lists (each a distinct list object).
    self.fids_largest = []
    self.fids_smallest = []
    self.mIoUs_largest = []
    self.mIoUs_smallest = []

    if opt.config_set is None:
        assert opt.config_str is not None
        self.configs = SingleConfigs(decode_config(opt.config_str))
        self.opt.eval_mode = 'largest'
    else:
        assert opt.config_str is None
        self.configs = get_configs(opt.config_set)
        self.opt.eval_mode = 'both'
def __init__(self, opt):
    """Prepare configs, data, model and metric networks for evaluation.

    Args:
        opt: parsed option namespace. ``netG``, ``config_set`` and
            ``real_stat_path`` are read here; the namespace is also forwarded
            to the project factory helpers.

    Raises:
        NotImplementedError: if ``opt.netG`` is neither a resnet nor a spade
            generator.
    """
    self.opt = opt

    family = opt.netG
    if 'resnet' in family:
        from configs.resnet_configs import get_configs
    elif 'spade' in family:
        from configs.spade_configs import get_configs
    else:
        raise NotImplementedError
    self.configs = get_configs(config_name=opt.config_set)
    self.dataloader = create_dataloader(opt)

    model = create_model(opt)
    model.setup(opt)
    # Only the first batch (if the loader yields any) is set as model input.
    for batch in self.dataloader:
        model.set_input(batch)
        break
    self.model = model
    self.device = model.device

    (self.inception_model,
     self.drn_model,
     self.deeplabv2_model) = create_metric_models(opt, self.device)
    self.npz = np.load(opt.real_stat_path)

    self.macs_cache = {}
    self.result_cache = {}
if __name__ == '__main__':
    # Entry point: shuffle the config space and build one worker process per
    # GPU, each receiving a contiguous chunk of configs.
    mp.set_start_method('spawn')
    opt = SearchOptions().parse()
    print(' '.join(sys.argv), flush=True)
    check(opt)
    set_seed(opt.seed)

    if 'resnet' in opt.netG:
        from configs.resnet_configs import get_configs
    elif 'spade' in opt.netG:
        # TODO: SPADE search is not implemented in this script yet. Fail
        # explicitly instead of hitting a NameError on get_configs below.
        raise NotImplementedError('Search for SPADE generators is not supported yet.')
    else:
        raise NotImplementedError
    configs = list(get_configs(config_name=opt.config_set).all_configs())
    random.shuffle(configs)

    num_gpus = len(opt.gpu_ids)
    if num_gpus == 0:
        # Configs are split across one process per GPU id; without any the
        # chunk-size computation below would divide by zero.
        raise ValueError('At least one GPU id is required for the search.')
    # Ceiling division so that every config is assigned to some worker.
    chunk_size = (len(configs) + num_gpus - 1) // num_gpus

    processes = []
    queue = mp.Queue()
    for i, gpu_id in enumerate(opt.gpu_ids):
        start = min(i * chunk_size, len(configs))
        end = min((i + 1) * chunk_size, len(configs))
        p = mp.Process(target=main,
                       args=(configs[start:end], copy.deepcopy(opt), gpu_id, queue, i == 0))
        processes.append(p)