def main():
    """Evaluate a pruned generator checkpoint and print IS / FID.

    Builds the generator, prepares the TensorFlow Inception graph, loads
    ``args.load_path``, recreates the pruning mask buffers from the averaged
    generator weights, reports remaining-weight statistics and the stored best
    FID, then evaluates IS and FID against the CIFAR-10 training statistics.
    """
    args = parse_args()
    gen_net = Generator(args).cuda()

    # TensorFlow Inception graph used by the IS / FID evaluation.
    _init_inception()
    create_inception_graph(check_or_download_inception(None))

    ckpt_path = args.load_path
    assert os.path.exists(ckpt_path), "checkpoint file {} is not found".format(ckpt_path)
    checkpoint = torch.load(ckpt_path)

    # Seed every RNG before any random draw so the evaluation noise is reproducible.
    seed = 12345
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    avg_state = checkpoint['avg_gen_state_dict']
    # Create the mask buffers first so the masked state dict can be loaded.
    pruning_generate(gen_net, avg_state)
    gen_net.load_state_dict(avg_state)
    see_remain_rate_mask(gen_net)
    see_remain_rate(gen_net)
    print("Best FID:{}".format(checkpoint['best_fid']))

    noise = np.random.normal(0, 1, (25, args.latent_dim))
    fixed_z = torch.cuda.FloatTensor(noise)
    fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    inception_score, fid_score = evaluate(args, fixed_z, fid_stat, gen_net)
    print('Inception score: %.4f, FID score: %.4f' % (inception_score, fid_score))
def main():
    """Evaluate a (optionally pruned) generator checkpoint and log IS / FID.

    Seeds the RNGs, creates a ``logs_eval`` log directory, builds the generator
    named by ``args.model``, optionally prunes it to ``args.percent`` remaining
    weights, loads ``args.load_path`` and runs ``validate`` once.

    The checkpoint may be either a training checkpoint (containing
    ``avg_gen_state_dict`` and ``epoch``) or a bare generator state dict. In the
    latter case no epoch is stored, so ``epoch`` defaults to 0. Previously this
    case raised NameError at the ``validate`` call.
    """
    args = cfg.parse_args()
    torch.manual_seed(args.random_seed)
    random.seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)

    assert args.exp_name
    assert args.load_path.endswith('.pth')
    assert os.path.exists(args.load_path)
    args.path_helper = set_log_dir('logs_eval', args.exp_name)
    logger = create_logger(args.path_helper['log_path'], phase='test')

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE(review): eval() on a CLI-supplied model name; only run with trusted arguments.
    gen_net = eval('models.' + args.model + '.Generator')(args=args).cuda()

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # initial
    np.random.seed(args.random_seed)
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    if args.percent < 0.9:
        # Prune before loading so the masked checkpoint layout matches.
        pruning_generate(gen_net, (1 - args.percent))
        see_remain_rate(gen_net)

    # set writer
    logger.info(f'=> resuming from {args.load_path}')
    checkpoint_file = args.load_path
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)
    if 'avg_gen_state_dict' in checkpoint:
        gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        epoch = checkpoint['epoch']
        logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {epoch})')
    else:
        gen_net.load_state_dict(checkpoint)
        # A bare state dict carries no epoch; 0 is assumed to be an acceptable
        # value for validate's epoch argument (TODO confirm against validate).
        epoch = 0
        logger.info(f'=> loaded checkpoint {checkpoint_file}')

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'valid_global_steps': 0,
    }
    try:
        inception_score, fid_score = validate(args, fixed_z, fid_stat, gen_net,
                                              writer_dict, epoch)
        logger.info(f'Inception score: {inception_score}, FID score: {fid_score}.')
    finally:
        # Flush and release the TensorBoard event file even if validation fails.
        writer_dict['writer'].close()
def main():
    """Evaluate a generator checkpoint on CIFAR-10 or STL-10 and log IS / FID.

    Creates a ``logs_eval`` log directory, builds the generator named by
    ``args.gen_model``, loads ``args.load_path`` (either a training checkpoint
    with ``avg_gen_state_dict`` or a bare state dict) and runs ``validate``
    once with ``clean_dir=False``. The TensorBoard writer is always closed
    afterwards; previously it was left open.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    assert args.exp_name
    assert args.load_path.endswith(".pth")
    assert os.path.exists(args.load_path)
    args.path_helper = set_log_dir("logs_eval", args.exp_name)
    logger = create_logger(args.path_helper["log_path"], phase="test")

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE(review): eval() on a CLI-supplied model name; only run with trusted arguments.
    gen_net = eval("models." + args.gen_model + ".Generator")(args=args).cuda()

    # fid stat
    if args.dataset.lower() == "cifar10":
        fid_stat = "fid_stat/fid_stats_cifar10_train.npz"
    elif args.dataset.lower() == "stl10":
        fid_stat = "fid_stat/stl10_train_unlabeled_fid_stats_48.npz"
    else:
        raise NotImplementedError(f"no fid stat for {args.dataset.lower()}")
    assert os.path.exists(fid_stat)

    # initial
    # NOTE(review): NumPy's RNG is not seeded in this function, so fixed_z
    # differs between runs.
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    # set writer
    logger.info(f"=> resuming from {args.load_path}")
    checkpoint_file = args.load_path
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)
    if "avg_gen_state_dict" in checkpoint:
        gen_net.load_state_dict(checkpoint["avg_gen_state_dict"])
        epoch = checkpoint["epoch"]
        logger.info(f"=> loaded checkpoint {checkpoint_file} (epoch {epoch})")
    else:
        gen_net.load_state_dict(checkpoint)
        logger.info(f"=> loaded checkpoint {checkpoint_file}")

    logger.info(args)
    writer_dict = {
        "writer": SummaryWriter(args.path_helper["log_path"]),
        "valid_global_steps": 0,
    }
    try:
        inception_score, fid_score = validate(args, fixed_z, fid_stat, gen_net,
                                              writer_dict, clean_dir=False)
        logger.info(f"Inception score: {inception_score}, FID score: {fid_score}.")
    finally:
        # Flush and release the TensorBoard event file even if validation fails.
        writer_dict["writer"].close()
def calculate_IS_FID(G, fid_stat='fid_stat/fid_stats_cifar10_train.npz'):
    """Compute, print and return the Inception Score and FID of generator ``G``.

    Configuration is read from a module-level ``args`` object rather than a
    parameter. The fields used are ``random_seed``, ``fid_buffer_dir``,
    ``num_eval_imgs``, ``eval_batch_size``, ``latent_dim``, ``do_IS`` and
    ``do_FID``.

    Args:
        G: generator passed through to ``calculate_metrics``.
        fid_stat: path to the reference FID statistics ``.npz``. Defaults to the
            CIFAR-10 training statistics that were previously hard-coded.

    Returns:
        ``(inception_score, fid_score)`` as produced by ``calculate_metrics``.
        Earlier versions only printed them and returned ``None``.

    Raises:
        AssertionError: if ``fid_stat`` does not exist.
    """
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # fid stat
    assert os.path.exists(fid_stat)

    # IS and FID:
    inception_score, fid_score = calculate_metrics(
        args.fid_buffer_dir, args.num_eval_imgs, args.eval_batch_size,
        args.latent_dim, fid_stat, G, do_IS=args.do_IS, do_FID=args.do_FID)
    print('Inception score: %s, FID score: %s.' % (inception_score, fid_score))
    return inception_score, fid_score
def main():
    """Evaluate a generator checkpoint and log IS / FID.

    Creates a ``logs_eval`` log directory, builds the generator named by
    ``args.model``, loads ``args.load_path`` (either a training checkpoint
    with ``avg_gen_state_dict`` or a bare state dict) and runs ``validate``
    once. The TensorBoard writer is always closed afterwards; previously it
    was left open.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    assert args.exp_name
    assert args.load_path.endswith('.pth')
    assert os.path.exists(args.load_path)
    args.path_helper = set_log_dir('logs_eval', args.exp_name)
    logger = create_logger(args.path_helper['log_path'], phase='test')

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE(review): eval() on a CLI-supplied model name; only run with trusted arguments.
    gen_net = eval('models.' + args.model + '.Generator')(args=args).cuda()

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    # set writer
    logger.info(f'=> resuming from {args.load_path}')
    checkpoint_file = args.load_path
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)
    if 'avg_gen_state_dict' in checkpoint:
        gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        epoch = checkpoint['epoch']
        logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {epoch})')
    else:
        gen_net.load_state_dict(checkpoint)
        logger.info(f'=> loaded checkpoint {checkpoint_file}')

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'valid_global_steps': 0,
    }
    try:
        inception_score, fid_score = validate(args, fixed_z, gen_net, writer_dict)
        logger.info(f'Inception score: {inception_score}, FID score: {fid_score}.')
    finally:
        # Flush and release the TensorBoard event file even if validation fails.
        writer_dict['writer'].close()
def main():
    """Evaluate a searched (genotype-based) GAN from its best checkpoint.

    Restricts CUDA to ``args.gpu_ids`` and builds the generator and
    discriminator from ``archs.<args.arch>`` and the stored generator genotype.
    It then loads ``exps/<args.checkpoint>/Model/checkpoint_best.pth``, logs
    model sizes and FLOPs, optionally draws the architecture, and validates the
    averaged generator weights. When more than one GPU is listed, the first
    visible GPU is reserved for the Inception evaluation and the networks run
    on the rest.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set visible GPU ids
    if len(args.gpu_ids) > 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids

    # set TensorFlow environment for evaluation (calculate IS and FID)
    _init_inception()
    inception_path = check_or_download_inception('./tmp/imagenet/')
    create_inception_graph(inception_path)

    # the first GPU in visible GPUs is dedicated for evaluation (running Inception model)
    # Once CUDA_VISIBLE_DEVICES is set, the visible devices are renumbered from 0,
    # so the logical ids are just the positions in the comma-separated list.
    num_listed = len(args.gpu_ids.split(','))
    args.gpu_ids = list(range(num_listed))
    if len(args.gpu_ids) > 1:
        args.gpu_ids = args.gpu_ids[1:]

    # genotype G
    genotypes_root = os.path.join('exps', args.genotypes_exp, 'Genotypes')
    genotype_G = np.load(os.path.join(genotypes_root, 'latest_G.npy'))

    # import network from genotype
    # NOTE(review): eval() on a CLI-supplied arch name; only run with trusted arguments.
    basemodel_gen = eval('archs.' + args.arch + '.Generator')(args, genotype_G)
    gen_net = torch.nn.DataParallel(
        basemodel_gen, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])
    basemodel_dis = eval('archs.' + args.arch + '.Discriminator')(args)
    dis_net = torch.nn.DataParallel(
        basemodel_dis, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # set writer
    print(f'=> resuming from {args.checkpoint}')
    assert os.path.exists(os.path.join('exps', args.checkpoint))
    checkpoint_file = os.path.join('exps', args.checkpoint, 'Model', 'checkpoint_best.pth')
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)
    epoch = checkpoint['epoch'] - 1
    gen_net.load_state_dict(checkpoint['gen_state_dict'])
    dis_net.load_state_dict(checkpoint['dis_state_dict'])
    # Extract the averaged generator weights via a temporary copy.
    avg_gen_net = deepcopy(gen_net)
    avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
    gen_avg_param = copy_params(avg_gen_net)
    del avg_gen_net
    assert args.exp_name
    args.path_helper = set_log_dir('exps', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {epoch})')

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'valid_global_steps': epoch // args.val_freq,
    }

    # model size
    logger.info('Param size of G = %fMB', count_parameters_in_MB(gen_net))
    logger.info('Param size of D = %fMB', count_parameters_in_MB(dis_net))
    print_FLOPs(basemodel_gen, (1, args.latent_dim), logger)
    print_FLOPs(basemodel_dis, (1, 3, args.img_size, args.img_size), logger)

    # for visualization
    if args.draw_arch:
        from utils.genotype import draw_graph_G
        draw_graph_G(genotype_G, save=True,
                     file_path=os.path.join(args.path_helper['graph_vis_path'], 'latest_G'))
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (100, args.latent_dim)))

    # test
    load_params(gen_net, gen_avg_param)
    try:
        inception_score, std, fid_score = validate(args, fixed_z, fid_stat, gen_net, writer_dict)
        logger.info(
            f'Inception score mean: {inception_score}, Inception score std: {std}, '
            f'FID score: {fid_score} || @ epoch {epoch}.')
    finally:
        # Flush and release the TensorBoard event file even if validation fails.
        writer_dict['writer'].close()
def main():
    """Run the controller-based architecture search for a growing shared GAN.

    Builds the shared generator/discriminator and a controller for the current
    stage, optionally resuming everything from
    ``<args.load_path>/Model/checkpoint.pth``. For each search iteration it:

    * grows to the next stage at ``args.grow_step1`` / ``args.grow_step2``,
      recording the top-k architectures and rebuilding the controller and the
      data loader for the larger image size;
    * trains the shared GAN, then the controller;
    * re-initializes the shared GAN when ``train_shared`` requests a dynamic
      reset;
    * saves a checkpoint.

    Finally it logs the discovered architectures.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception(MODEL_DIR)
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; the non-underscore alias is deprecated.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # set grow controller
    grow_ctrler = GrowCtrler(args.grow_step1, args.grow_step2)

    # initial
    start_search_iter = 0

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file, map_location={'cuda:0': 'cpu'})
        # set controller && its optimizer
        cur_stage = checkpoint['cur_stage']
        controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)

        start_search_iter = checkpoint['search_iter']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        controller.load_state_dict(checkpoint['ctrl_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        ctrl_optimizer.load_state_dict(checkpoint['ctrl_optimizer'])
        prev_archs = checkpoint['prev_archs']
        prev_hiddens = checkpoint['prev_hiddens']

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})'
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
        prev_archs = None
        prev_hiddens = None

        # set controller && its optimizer
        cur_stage = 0
        controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)

    # set up data_loader (image side length grows with the stage: 8, 16, 32, ...)
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3),
                                    args.dis_batch_size, args.num_workers)
    train_loader = dataset.train

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'controller_steps': start_search_iter * args.ctrl_step
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    for search_iter in tqdm(range(int(start_search_iter), int(args.max_search_iter)),
                            desc='search progress'):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter == args.grow_step1 or search_iter == args.grow_step2:

            # save
            cur_stage = grow_ctrler.cur_stage(search_iter)
            logger.info(f'=> grow to stage {cur_stage}')
            prev_archs, prev_hiddens = get_topk_arch_hidden(
                args, controller, gen_net, prev_archs, prev_hiddens)

            # grow section
            del controller
            del ctrl_optimizer
            controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)

            dataset = datasets.ImageDataset(args, 2**(cur_stage + 3),
                                            args.dis_batch_size, args.num_workers)
            train_loader = dataset.train

        dynamic_reset = train_shared(args, gen_net, dis_net, g_loss_history,
                                     d_loss_history, controller, gen_optimizer,
                                     dis_optimizer, train_loader,
                                     prev_hiddens=prev_hiddens,
                                     prev_archs=prev_archs)
        train_controller(args, controller, ctrl_optimizer, gen_net,
                         prev_hiddens, prev_archs, writer_dict)

        if dynamic_reset:
            logger.info('re-initialize share GAN')
            del gen_net, dis_net, gen_optimizer, dis_optimizer
            gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
                args, weights_init)

        save_checkpoint(
            {
                'cur_stage': cur_stage,
                'search_iter': search_iter + 1,
                'gen_model': args.gen_model,
                'dis_model': args.dis_model,
                'controller': args.controller,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'ctrl_state_dict': controller.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'ctrl_optimizer': ctrl_optimizer.state_dict(),
                'prev_archs': prev_archs,
                'prev_hiddens': prev_hiddens,
                'path_helper': args.path_helper
            }, False, args.path_helper['ckpt_path'])

    final_archs, _ = get_topk_arch_hidden(args, controller, gen_net,
                                          prev_archs, prev_hiddens)
    logger.info(f"discovered archs: {final_archs}")
    # Flush and release the TensorBoard event file.
    writer_dict['writer'].close()
def main():
    """Train a GAN, periodically validating IS / FID with averaged generator weights.

    Builds the generator and discriminator named by ``args.model``, initializes
    their weights and sets up Adam optimizers with linear learning-rate decay.
    Training can resume from ``<args.load_path>/Model/checkpoint.pth``.

    Every ``args.val_freq`` epochs, and on the final epoch, the averaged
    generator parameters are swapped in for validation. The best-FID checkpoint
    is flagged via ``save_checkpoint``'s ``is_best`` argument. A checkpoint is
    saved every epoch.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE(review): eval() on a CLI-supplied model name; only run with trusted arguments.
    gen_net = eval('models.' + args.model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args).cuda()

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; the non-underscore alias is deprecated.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()),
        args.g_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()),
        args.d_lr, (args.beta1, args.beta2))
    # NOTE(review): these require args.max_iter to be numeric; a None max_iter
    # would fail here before the max_epoch fallback below is computed.
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        # Extract the averaged generator weights via a temporary copy.
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)), desc='total progress'):
        lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param,
              train_loader, epoch, writer_dict, lr_schedulers)

        # Validate every val_freq epochs (skipping epoch 0) and on the last epoch.
        if (epoch and epoch % args.val_freq == 0) or epoch == int(args.max_epoch) - 1:
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat, gen_net, writer_dict)
            logger.info(
                f'Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}.'
            )
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net

    # Flush and release the TensorBoard event file.
    writer_dict['writer'].close()
def main():
    """Run the controller-based architecture search for a growing shared GAN.

    Builds the shared generator/discriminator and a controller for the current
    stage, optionally resuming everything from
    ``<args.load_path>/Model/checkpoint.pth``. For each search iteration it:

    * grows to the next stage at ``args.grow_step1`` / ``args.grow_step2``,
      recording the top-k architectures and rebuilding the controller and the
      data loader for the larger image size;
    * trains the shared GAN, then the controller;
    * re-initializes the shared GAN when ``train_shared`` requests a dynamic
      reset;
    * saves a checkpoint.

    Finally it logs the discovered architectures.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find("Conv2d") != -1:
            if args.init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == "orth":
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == "xavier_uniform":
                # In-place variant; the non-underscore alias is deprecated.
                nn.init.xavier_uniform_(m.weight.data, 1.0)
            else:
                raise NotImplementedError("{} unknown inital type".format(
                    args.init_type))
        elif classname.find("BatchNorm2d") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # set grow controller
    grow_ctrler = GrowCtrler(args.grow_step1, args.grow_step2)

    # initial
    start_search_iter = 0

    # set writer
    if args.load_path:
        print(f"=> resuming from {args.load_path}")
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, "Model", "checkpoint.pth")
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        # set controller && its optimizer
        cur_stage = checkpoint["cur_stage"]
        controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)

        start_search_iter = checkpoint["search_iter"]
        gen_net.load_state_dict(checkpoint["gen_state_dict"])
        dis_net.load_state_dict(checkpoint["dis_state_dict"])
        controller.load_state_dict(checkpoint["ctrl_state_dict"])
        gen_optimizer.load_state_dict(checkpoint["gen_optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        ctrl_optimizer.load_state_dict(checkpoint["ctrl_optimizer"])
        prev_archs = checkpoint["prev_archs"]
        prev_hiddens = checkpoint["prev_hiddens"]

        args.path_helper = checkpoint["path_helper"]
        logger = create_logger(args.path_helper["log_path"])
        logger.info(
            f"=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})"
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir("logs", args.exp_name)
        logger = create_logger(args.path_helper["log_path"])
        prev_archs = None
        prev_hiddens = None

        # set controller && its optimizer
        cur_stage = 0
        controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)

    # set up data_loader (image side length grows with the stage: 8, 16, 32, ...)
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
    train_loader = dataset.train

    logger.info(args)
    writer_dict = {
        "writer": SummaryWriter(args.path_helper["log_path"]),
        "controller_steps": start_search_iter * args.ctrl_step,
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    for search_iter in tqdm(range(int(start_search_iter), int(args.max_search_iter)),
                            desc="search progress"):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter == args.grow_step1 or search_iter == args.grow_step2:

            # save
            cur_stage = grow_ctrler.cur_stage(search_iter)
            logger.info(f"=> grow to stage {cur_stage}")
            prev_archs, prev_hiddens = get_topk_arch_hidden(
                args, controller, gen_net, prev_archs, prev_hiddens)

            # grow section
            del controller
            del ctrl_optimizer
            controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)

            dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
            train_loader = dataset.train

        dynamic_reset = train_shared(
            args,
            gen_net,
            dis_net,
            g_loss_history,
            d_loss_history,
            controller,
            gen_optimizer,
            dis_optimizer,
            train_loader,
            prev_hiddens=prev_hiddens,
            prev_archs=prev_archs,
        )
        train_controller(
            args,
            controller,
            ctrl_optimizer,
            gen_net,
            prev_hiddens,
            prev_archs,
            writer_dict,
        )

        if dynamic_reset:
            logger.info("re-initialize share GAN")
            del gen_net, dis_net, gen_optimizer, dis_optimizer
            gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
                args, weights_init)

        save_checkpoint(
            {
                "cur_stage": cur_stage,
                "search_iter": search_iter + 1,
                "gen_model": args.gen_model,
                "dis_model": args.dis_model,
                "controller": args.controller,
                "gen_state_dict": gen_net.state_dict(),
                "dis_state_dict": dis_net.state_dict(),
                "ctrl_state_dict": controller.state_dict(),
                "gen_optimizer": gen_optimizer.state_dict(),
                "dis_optimizer": dis_optimizer.state_dict(),
                "ctrl_optimizer": ctrl_optimizer.state_dict(),
                "prev_archs": prev_archs,
                "prev_hiddens": prev_hiddens,
                "path_helper": args.path_helper,
            },
            False,
            args.path_helper["ckpt_path"],
        )

    final_archs, _ = get_topk_arch_hidden(args, controller, gen_net,
                                          prev_archs, prev_hiddens)
    logger.info(f"discovered archs: {final_archs}")
    # Flush and release the TensorBoard event file.
    writer_dict["writer"].close()
def main():
    """Fine-tune a pruned GAN with fixed pruning masks (lottery-ticket style).

    Steps:
    1. Seed all RNGs and build the generator/discriminator named by ``args.model``.
    2. Load initial weights from ``args.init_path`` (``initial_gen_net.pth`` /
       ``initial_dis_net.pth``), presumably the weights saved at initialization.
    3. Load ``args.load_path``. Apply the generator's pruning with
       ``pruning_generate``, and derive discriminator conv masks from the
       nonzero entries of each conv layer's ``weight_orig``.
    4. Unless ``args.finetune_G`` / ``args.finetune_D`` is set, rewind the
       surviving weights to the initial values. Re-apply the discriminator masks.
    5. Train with masks (optionally with discriminator knowledge distillation
       when ``args.use_kd_D``), validating periodically and checkpointing every
       epoch.
    """
    args = cfg.parse_args()
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # NOTE(review): nn.init.xavier_uniform is the deprecated alias of
                # nn.init.xavier_uniform_.
                nn.init.xavier_uniform(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    # import network
    # NOTE(review): eval() on a CLI-supplied model name; only run with trusted arguments.
    gen_net = eval('models.' + args.model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args).cuda()
    gen_net.apply(weights_init)
    dis_net.apply(weights_init)
    avg_gen_net = deepcopy(gen_net)
    # Initial weights that the pruned networks may be rewound to.
    initial_gen_net_weight = torch.load(os.path.join(args.init_path, 'initial_gen_net.pth'), map_location="cpu")
    initial_dis_net_weight = torch.load(os.path.join(args.init_path, 'initial_dis_net.pth'), map_location="cpu")
    assert id(initial_dis_net_weight) != id(dis_net.state_dict())
    assert id(initial_gen_net_weight) != id(gen_net.state_dict())

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()),
        args.g_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()),
        args.d_lr, (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/fid_stats_stl10_train.npz'
    else:
        raise NotImplementedError('no fid stat for %s' % args.dataset.lower())
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    start_epoch = 0
    best_fid = 1e4

    print('=> resuming from %s' % args.load_path)
    assert os.path.exists(args.load_path)
    checkpoint_file = args.load_path
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)

    pruning_generate(gen_net, checkpoint['gen_state_dict'])
    dis_net.load_state_dict(checkpoint['dis_state_dict'])

    # Count all discriminator conv weights and how many are nonzero.
    total = 0
    total_nonzero = 0
    for m in dis_net.modules():
        if isinstance(m, nn.Conv2d):
            total += m.weight_orig.data.numel()
            mask = m.weight_orig.data.abs().clone().gt(0).float().cuda()
            total_nonzero += torch.sum(mask)
    # Gather the absolute values of all conv weights into one flat tensor.
    conv_weights = torch.zeros(total)
    index = 0
    for m in dis_net.modules():
        if isinstance(m, nn.Conv2d):
            size = m.weight_orig.data.numel()
            conv_weights[index:(index + size)] = m.weight_orig.data.view(-1).abs().clone()
            index += size

    y, i = torch.sort(conv_weights)
    # Previous percentage-based rule: thre_index = int(total * args.percent)
    # Current rule: only care about the non zero weights. After an ascending
    # sort, the first (total - total_nonzero) entries are the exact zeros, so
    # y[thre_index] is the smallest nonzero magnitude.
    # NOTE(review): because the mask below uses a strict gt(thre), weights equal
    # to that smallest nonzero magnitude are pruned as well. If total_nonzero ==
    # total, the global minimum weight is pruned; if every weight is zero,
    # y[total] is out of range. Confirm this is intended.
    thre_index = total - total_nonzero
    thre = y[int(thre_index)]
    pruned = 0
    print('Pruning threshold: {}'.format(thre))
    zero_flag = False
    # Conv-layer masks keyed by the module's position in dis_net.modules().
    masks = OrderedDict()
    for k, m in enumerate(dis_net.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m.weight_orig.data.abs().clone()
            mask = weight_copy.gt(thre).float()
            masks[k] = mask
            pruned = pruned + mask.numel() - torch.sum(mask)
            m.weight_orig.data.mul_(mask)
            if int(torch.sum(mask)) == 0:
                zero_flag = True
            print('layer index: {:d} \t total params: {:d} \t remaining params: {:d}'
                  .format(k, mask.numel(), int(torch.sum(mask))))
    print('Total conv params: {}, Pruned conv params: {}, Pruned ratio: {}'.format(total, pruned, pruned / total))

    pruning_generate(avg_gen_net, checkpoint['gen_state_dict'])
    see_remain_rate(gen_net)

    if not args.finetune_G:
        # Rewind the generator to its initial weights; rewind_weight selects the
        # entries matching gen_weight's keys (semantics defined elsewhere).
        gen_weight = gen_net.state_dict()
        gen_orig_weight = rewind_weight(initial_gen_net_weight, gen_weight.keys())
        gen_weight.update(gen_orig_weight)
        gen_net.load_state_dict(gen_weight)
    gen_avg_param = copy_params(gen_net)

    if args.finetune_D:
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
    else:
        dis_net.load_state_dict(initial_dis_net_weight)

    # Re-apply the discriminator masks after (possibly) reloading its weights.
    for k, m in enumerate(dis_net.modules()):
        if isinstance(m, nn.Conv2d):
            m.weight_orig.data.mul_(masks[k])

    # Unpruned discriminator from the checkpoint, used as the KD teacher when use_kd_D.
    orig_dis_net = eval('models.' + args.model + '.Discriminator')(args=args).cuda()
    orig_dis_net.load_state_dict(checkpoint['dis_state_dict'])
    orig_dis_net.eval()

    args.path_helper = set_log_dir('logs', args.exp_name + "_{}".format(args.percent))
    logger = create_logger(args.path_helper['log_path'])
    #logger.info('=> loaded checkpoint %s (epoch %d)' % (checkpoint_file, start_epoch))

    logger.info(args)
    # NOTE(review): this SummaryWriter is never closed in this function.
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)), desc='total progress'):
        lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None
        see_remain_rate(gen_net)
        see_remain_rate_orig(dis_net)
        if not args.use_kd_D:
            train_with_mask(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param,
                            train_loader, epoch, writer_dict, masks, lr_schedulers)
        else:
            train_with_mask_kd(args, gen_net, dis_net, orig_dis_net, gen_optimizer, dis_optimizer, gen_avg_param,
                               train_loader, epoch, writer_dict, masks, lr_schedulers)

        # Validate every val_freq epochs (skipping epoch 0) and on the last epoch.
        if epoch and epoch % args.val_freq == 0 or epoch == int(args.max_epoch) - 1:
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat, gen_net, writer_dict, epoch)
            logger.info(
                'Inception score: %.4f, FID score: %.4f || @ epoch %d.' % (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        # Refresh the persistent averaged-generator copy for checkpointing.
        avg_gen_net.load_state_dict(gen_net.state_dict())
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'gen_state_dict': gen_net.state_dict(),
            'dis_state_dict': dis_net.state_dict(),
            'avg_gen_state_dict': avg_gen_net.state_dict(),
            'gen_optimizer': gen_optimizer.state_dict(),
            'dis_optimizer': dis_optimizer.state_dict(),
            'best_fid': best_fid,
            'path_helper': args.path_helper
        }, is_best, args.path_helper['ckpt_path'])
def main():
    """Train a DataParallel GAN, validating FID with averaged generator weights each epoch.

    Seeds Python, NumPy and CUDA RNGs and builds the data loader. Derives
    ``args.max_epoch`` / ``args.max_iter`` from each other and builds the
    generator/discriminator from ``args.gen_model`` / ``args.dis_model``. It
    then creates Adam or AdamW optimizers and, optionally, resumes from
    ``args.load_path``.

    Each epoch it trains, computes FID with the averaged generator parameters
    and saves ``checkpoint_epoch_<epoch>.pth``, flagging the best-FID epoch.

    Raises:
        NotImplementedError: for an unsupported ``args.optimizer`` or a dataset
            with no FID statistics.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.backends.cudnn.deterministic = True

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # epoch number for dis_net
    dataset = datasets.ImageDataset(args, cur_img_size=8)
    train_loader = dataset.train
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter / len(train_loader))
    else:
        args.max_iter = args.max_epoch * len(train_loader)
    args.max_epoch = args.max_epoch * args.n_critic

    # import network
    # NOTE(review): eval() on CLI-supplied model names; only run with trusted arguments.
    gen_net = eval('models.' + args.gen_model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.dis_model + '.Discriminator')(args=args).cuda()
    gen_net.set_arch(args.arch, cur_stage=2)

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    gpu_ids = list(range(torch.cuda.device_count()))
    gen_net = torch.nn.DataParallel(gen_net.to("cuda:0"), device_ids=gpu_ids)
    dis_net = torch.nn.DataParallel(dis_net.to("cuda:0"), device_ids=gpu_ids)

    gen_net.module.cur_stage = 0
    dis_net.module.cur_stage = 0
    gen_net.module.alpha = 1.
    dis_net.module.alpha = 1.

    # set optimizer
    if args.optimizer == "adam":
        gen_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, gen_net.parameters()),
            args.g_lr, (args.beta1, args.beta2))
        dis_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, dis_net.parameters()),
            args.d_lr, (args.beta1, args.beta2))
    elif args.optimizer == "adamw":
        gen_optimizer = AdamW(filter(lambda p: p.requires_grad, gen_net.parameters()),
                              args.g_lr, weight_decay=args.wd)
        # The discriminator uses its own learning rate, as in the Adam branch
        # (this previously passed args.g_lr).
        dis_optimizer = AdamW(filter(lambda p: p.requires_grad, dis_net.parameters()),
                              args.d_lr, weight_decay=args.wd)
    else:
        raise NotImplementedError(f'unknown optimizer {args.optimizer}')
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.fid_stat is not None:
        fid_stat = args.fid_stat
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (64, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = args.load_path
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        gen_avg_param = checkpoint['gen_avg_param']
        cur_stage = cur_stages(start_epoch, args)
        gen_net.module.cur_stage = cur_stage
        dis_net.module.cur_stage = cur_stage
        gen_net.module.alpha = 1.
        dis_net.module.alpha = 1.

        args.path_helper = checkpoint['path_helper']
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)

    # Created after both branches so a resumed run also has a logger.
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    def return_states():
        # Snapshot of everything needed to resume; reads the loop's current
        # `epoch` and `best_fid_score` through the closure.
        states = {}
        states['epoch'] = epoch
        states['best_fid'] = best_fid_score
        states['gen_state_dict'] = gen_net.state_dict()
        states['dis_state_dict'] = dis_net.state_dict()
        states['gen_optimizer'] = gen_optimizer.state_dict()
        states['dis_optimizer'] = dis_optimizer.state_dict()
        states['gen_avg_param'] = gen_avg_param
        states['path_helper'] = args.path_helper
        return states

    # Start from the restored best FID (1e4 on a fresh run); previously this
    # name was read before assignment when resuming.
    best_fid_score = best_fid

    # train loop
    # int(): np.ceil above yields a NumPy float, which range() rejects.
    for epoch in range(start_epoch + 1, int(args.max_epoch)):
        train(
            args,
            gen_net,
            dis_net,
            gen_optimizer,
            dis_optimizer,
            gen_avg_param,
            train_loader,
            epoch,
            writer_dict,
            fixed_z,
        )

        # Evaluate with the averaged generator weights, then restore.
        backup_param = copy_params(gen_net)
        load_params(gen_net, gen_avg_param)
        fid_score = validate(
            args,
            fixed_z,
            fid_stat,
            epoch,
            gen_net,
            writer_dict,
        )
        logger.info(f'FID score: {fid_score} || @ epoch {epoch}.')
        load_params(gen_net, backup_param)

        is_best = False
        if epoch == 1 or fid_score < best_fid_score:
            best_fid_score = fid_score
            is_best = True

        # A per-epoch checkpoint is written unconditionally.
        states = return_states()
        save_checkpoint(states, is_best, args.path_helper['ckpt_path'],
                        filename=f'checkpoint_epoch_{epoch}.pth')

    # Flush and release the TensorBoard event file.
    writer_dict['writer'].close()
def main():
    """Retrain a pruned generator starting from its rewound initial weights.

    Steps performed:
      * build ``models.<args.model>.Generator`` / ``Discriminator`` and
        initialise them with ``weights_init``; snapshot both state dicts;
      * load ``gen_state_dict`` from ``<args.load_path>/Model/checkpoint.pth``
        and call ``pruning_generate(..., 1 - args.percent, args.pruning_method)``
        on the generator and on its averaged copy (same torch seed for both);
      * overwrite the generator's tensors with
        ``rewind_weight(initial_gen_net_weight, keys)`` (presumably the
        post-init values mapped onto the pruned parameter names -- TODO
        confirm against ``rewind_weight``);
      * restore the discriminator from the checkpoint (``args.finetune_D``)
        or reset it to its initial snapshot;
      * train (optionally with a frozen teacher discriminator when
        ``args.use_kd_D``), evaluate IS/FID with the averaged generator every
        ``args.val_freq`` epochs and on the last epoch, and save a checkpoint
        after every epoch.
    """
    args = cfg.parse_args()
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # build networks
    gen_net = eval('models.' + args.model + '.Generator')(args=args)
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args)

    # weight init
    def weights_init(m):
        if isinstance(m, nn.Conv2d):
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # ``xavier_uniform`` (no underscore) is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)
    gen_net = gen_net.cuda()
    dis_net = dis_net.cuda()
    avg_gen_net = deepcopy(gen_net)
    # Snapshots of the freshly initialised weights, used for rewinding below.
    initial_gen_net_weight = deepcopy(gen_net.state_dict())
    initial_dis_net_weight = deepcopy(dis_net.state_dict())
    assert id(initial_dis_net_weight) != id(dis_net.state_dict())
    assert id(initial_gen_net_weight) != id(gen_net.state_dict())

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/fid_stats_stl10_train.npz'
    else:
        raise NotImplementedError('no fid stat for %s' % args.dataset.lower())
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    np.random.seed(args.random_seed)
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    start_epoch = 0
    best_fid = 1e4

    args.path_helper = set_log_dir('logs',
                                   args.exp_name + "_{}".format(args.percent))
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    print('=> resuming from %s' % args.load_path)
    assert os.path.exists(args.load_path)
    checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)
    gen_net.load_state_dict(checkpoint['gen_state_dict'])

    # Re-seed before each call so both nets receive the same pruning pattern
    # if pruning_generate is stochastic.
    torch.manual_seed(args.random_seed)
    pruning_generate(gen_net, (1 - args.percent), args.pruning_method)
    torch.manual_seed(args.random_seed)
    pruning_generate(avg_gen_net, (1 - args.percent), args.pruning_method)
    see_remain_rate(gen_net)

    # NOTE(review): this re-initialisation is overwritten by either
    # load_state_dict call below -- confirm the intended nesting.
    if args.second_seed:
        dis_net.apply(weights_init)
    if args.finetune_D:
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
    else:
        dis_net.load_state_dict(initial_dis_net_weight)

    # Rewind the (pruned) generator to its initial values.
    gen_weight = gen_net.state_dict()
    gen_orig_weight = rewind_weight(initial_gen_net_weight, gen_weight.keys())
    assert id(gen_weight) != id(gen_orig_weight)
    gen_weight.update(gen_orig_weight)
    gen_net.load_state_dict(gen_weight)
    gen_avg_param = copy_params(gen_net)

    if args.use_kd_D:
        # Frozen teacher discriminator restored from the checkpoint.
        # nn.Module has no ``load``; ``load_state_dict`` is the correct API.
        orig_dis_net = eval('models.' + args.model +
                            '.Discriminator')(args=args).cuda()
        orig_dis_net.load_state_dict(checkpoint['dis_state_dict'])
        orig_dis_net.eval()

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        see_remain_rate(gen_net)
        if not args.use_kd_D:
            train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
                  gen_avg_param, train_loader, epoch, writer_dict,
                  lr_schedulers)
        else:
            train_kd(args, gen_net, dis_net, orig_dis_net, gen_optimizer,
                     dis_optimizer, gen_avg_param, train_loader, epoch,
                     writer_dict, lr_schedulers)

        if epoch and epoch % args.val_freq == 0 or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged parameters, then restore.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict, epoch)
            logger.info(
                'Inception score: %.4f, FID score: %.4f || @ epoch %d.' %
                (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        avg_gen_net.load_state_dict(gen_net.state_dict())
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
def main():
    """Retrain a sub-network generator extracted from a saved checkpoint.

    Loads ``output/<args.dir>/pth/epoch<args.load_epoch>.pth``, builds the
    generator via ``load_subnet(args, checkpoint['generator'], initial
    weights)``, trains it together with a freshly initialised
    ``models.<args.model>.Discriminator``, evaluates IS/FID with the
    averaged generator every ``args.val_freq`` epochs (and on the last
    epoch) and saves a checkpoint after every epoch.
    """
    args = cfg.parse_args()
    # CUDA_VISIBLE_DEVICES is only honoured if set before the first CUDA
    # context is created; the original assignment came after several
    # ``.cuda()`` calls and therefore had no effect.
    if args.gpu is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # ``xavier_uniform`` (no underscore) is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    # This first generator is replaced by load_subnet below; it is kept so the
    # RNG consumption (and hence dis_net's init) is unchanged.
    gen_net = Generator(bottom_width=args.bottom_width,
                        gf_dim=args.gf_dim,
                        latent_dim=args.latent_dim).cuda()
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args).cuda()
    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    initial_gen_net_weight = torch.load(os.path.join(args.init_path,
                                                     'initial_gen_net.pth'),
                                        map_location="cpu")
    # NOTE(review): loaded but never applied to dis_net -- confirm whether the
    # discriminator should be reset to these weights.
    initial_dis_net_weight = torch.load(os.path.join(args.init_path,
                                                     'initial_dis_net.pth'),
                                        map_location="cpu")

    exp_str = args.dir
    args.load_path = os.path.join('output', exp_str, 'pth',
                                  'epoch{}.pth'.format(args.load_epoch))

    # state dict:
    assert os.path.exists(args.load_path)
    checkpoint = torch.load(args.load_path)
    print('=> loaded checkpoint %s' % args.load_path)
    state_dict = checkpoint['generator']
    gen_net = load_subnet(args, state_dict, initial_gen_net_weight).cuda()
    avg_gen_net = deepcopy(gen_net)

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    else:
        raise NotImplementedError('no fid stat for %s' % args.dataset.lower())
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    np.random.seed(args.random_seed)
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    start_epoch = 0
    best_fid = 1e4

    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    gen_avg_param = copy_params(gen_net)

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
              gen_avg_param, train_loader, epoch, writer_dict, lr_schedulers)

        if epoch and epoch % args.val_freq == 0 or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged parameters, then restore.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict)
            logger.info(
                'Inception score: %.4f, FID score: %.4f || @ epoch %d.' %
                (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        avg_gen_net.load_state_dict(gen_net.state_dict())
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
def main():
    """Train a fixed searched architecture (``models_search``) from scratch or resume.

    The generator architecture is fixed with ``gen_net.set_arch(args.arch,
    cur_stage=2)``. When ``args.load_path`` is set, nets, optimizers, the
    averaged generator parameters, ``best_fid`` and the log directory are
    restored from ``<args.load_path>/Model/checkpoint.pth``. Every epoch the
    script trains, evaluates IS/FID with the averaged generator every
    ``args.val_freq`` epochs (and on the last epoch) and saves a checkpoint.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    gen_net = eval("models_search." + args.gen_model +
                   ".Generator")(args=args).cuda()
    dis_net = eval("models_search." + args.dis_model +
                   ".Discriminator")(args=args).cuda()
    gen_net.set_arch(args.arch, cur_stage=2)
    dis_net.cur_stage = 2

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find("Conv2d") != -1:
            if args.init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == "orth":
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == "xavier_uniform":
                # ``xavier_uniform`` (no underscore) is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.0)
            else:
                raise NotImplementedError("{} unknown inital type".format(
                    args.init_type))
        elif classname.find("BatchNorm2d") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()),
        args.g_lr,
        (args.beta1, args.beta2),
    )
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()),
        args.d_lr,
        (args.beta1, args.beta2),
    )
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == "cifar10":
        fid_stat = "fid_stat/fid_stats_cifar10_train.npz"
    elif args.dataset.lower() == "stl10":
        fid_stat = "fid_stat/stl10_train_unlabeled_fid_stats_48.npz"
    else:
        raise NotImplementedError(f"no fid stat for {args.dataset.lower()}")
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print(f"=> resuming from {args.load_path}")
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, "Model",
                                       "checkpoint.pth")
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint["epoch"]
        best_fid = checkpoint["best_fid"]
        gen_net.load_state_dict(checkpoint["gen_state_dict"])
        dis_net.load_state_dict(checkpoint["dis_state_dict"])
        gen_optimizer.load_state_dict(checkpoint["gen_optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        # Recover the averaged parameters through a temporary network copy.
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint["avg_gen_state_dict"])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint["path_helper"]
        logger = create_logger(args.path_helper["log_path"])
        logger.info(
            f"=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})")
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir("logs", args.exp_name)
        logger = create_logger(args.path_helper["log_path"])

    logger.info(args)
    writer_dict = {
        "writer": SummaryWriter(args.path_helper["log_path"]),
        "train_global_steps": start_epoch * len(train_loader),
        "valid_global_steps": start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc="total progress"):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        train(
            args,
            gen_net,
            dis_net,
            gen_optimizer,
            dis_optimizer,
            gen_avg_param,
            train_loader,
            epoch,
            writer_dict,
            lr_schedulers,
        )

        if epoch and epoch % args.val_freq == 0 or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged parameters, then restore.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict)
            logger.info(
                f"Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}."
            )
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "gen_model": args.gen_model,
                "dis_model": args.dis_model,
                "gen_state_dict": gen_net.state_dict(),
                "dis_state_dict": dis_net.state_dict(),
                "avg_gen_state_dict": avg_gen_net.state_dict(),
                "gen_optimizer": gen_optimizer.state_dict(),
                "dis_optimizer": dis_optimizer.state_dict(),
                "best_fid": best_fid,
                "path_helper": args.path_helper,
            },
            is_best,
            args.path_helper["ckpt_path"],
        )
        del avg_gen_net
def main():
    """Train a GAN built from searched genotypes, tracking best IS and FID.

    The generator is built from ``genotypes.<args.arch_gen>``. The
    discriminator is built from ``genotypes.<args.arch_dis>`` unless
    ``args.dis`` names a plain ``Discriminator``. After optional resumption
    from ``<args.load_path>/Model/checkpoint.pth``, the script logs the
    generator's FLOPs and parameter count, then trains. It validates every
    ``args.val_freq`` epochs (and on the last epoch) and saves a sample grid
    every ``args.image_every`` epochs. A checkpoint is written every epoch,
    and the best IS/FID summary is logged at the end.
    """
    args = cfg_train.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    genotype_gen = eval('genotypes.%s' % args.arch_gen)
    gen_net = eval('models.' + args.gen_model + '.' + args.gen)(
        args, genotype_gen).cuda()
    if 'Discriminator' not in args.dis:
        genotype_dis = eval('genotypes.%s' % args.arch_dis)
        dis_net = eval('models.' + args.dis_model + '.' + args.dis)(
            args, genotype_dis).cuda()
    else:
        dis_net = eval('models.' + args.dis_model + '.' +
                       args.dis)(args=args).cuda()

    # weight init
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # ``xavier_uniform`` (no underscore) is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, args.g_lr * 0.01,
                                  260 * len(train_loader),
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, args.d_lr * 0.01,
                                  260 * len(train_loader),
                                  args.max_iter * args.n_critic)

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.dataset.lower() == 'mnist':
        # NOTE(review): MNIST reuses the STL-10 statistics, so FID here is not
        # measured against MNIST -- confirm or provide MNIST stats.
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    fixed_z_sample = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4
    best_fid_epoch = 0
    is_with_fid = 0
    std_with_fid = 0.
    best_is = 0
    # Bound up front: the final summary reads it even if no validation ran.
    best_std = 0.
    best_is_epoch = 0
    fid_with_is = 0

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # calculate the FLOPs and param count of G
    dummy_z = torch.randn(args.gen_batch_size, args.latent_dim).cuda()
    flops, params = profile(gen_net, inputs=(dummy_z, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info('FLOPs is {}, param count is {}'.format(flops, params))

    # The sample grid has a fixed number of cells.
    grid_rows, grid_cols = 10, 10

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
              gen_avg_param, train_loader, epoch, writer_dict,
              args.consistent, lr_schedulers)

        if epoch and epoch % args.val_freq == 0 or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged parameters, then restore.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, std, fid_score = validate(args,
                                                       fixed_z,
                                                       fid_stat,
                                                       gen_net,
                                                       writer_dict,
                                                       args.path_helper,
                                                       search=False)
            logger.info(
                f'Inception score: {inception_score}, FID score: {fid_score}+-{std} || @ epoch {epoch}.'
            )
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                best_fid_epoch = epoch
                is_with_fid = inception_score
                std_with_fid = std
                is_best = True
            else:
                is_best = False
            if inception_score > best_is:
                best_is = inception_score
                best_std = std
                fid_with_is = fid_score
                best_is_epoch = epoch
        else:
            is_best = False

        # save generated images
        if epoch % args.image_every == 0:
            # Sampling only: no autograd graph is needed.
            with torch.no_grad():
                gen_images = gen_net(fixed_z_sample).mul_(127.5).add_(
                    127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to(
                        'cpu', torch.uint8).numpy()
            fig = plt.figure()
            grid = ImageGrid(fig,
                             111,
                             nrows_ncols=(grid_rows, grid_cols),
                             axes_pad=0)
            # Never index past the grid's last cell.
            n_shown = min(args.eval_batch_size, grid_rows * grid_cols)
            for idx in range(n_shown):
                grid[idx].imshow(gen_images[idx])  # cmap="gray")
                grid[idx].set_xticks([])
                grid[idx].set_yticks([])
            plt.savefig(
                os.path.join(args.path_helper['sample_path'],
                             "epoch_{}.png".format(epoch)))
            plt.close(fig)

        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'gen_model': args.gen_model,
                'dis_model': args.dis_model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net

    logger.info(
        'best_is is {}+-{}@{} epoch, fid is {}, best_fid is {}@{}, is is {}+-{}'
        .format(best_is, best_std, best_is_epoch, fid_with_is, best_fid,
                best_fid_epoch, is_with_fid, std_with_fid))
def main():
    """Run the differentiable GAN architecture search loop.

    Relies on module-level ``args``, ``path_helper`` and ``dataset`` (a
    mapping from dataset name to a ``torchvision.datasets`` class name). Each
    epoch does the following:
      * steps a cosine LR scheduler;
      * logs the current genotypes;
      * optionally grows the networks (``args.grow``) and rebuilds the CIFAR-10
        loaders at the new resolution;
      * trains the weights, then the architecture parameters once
        ``epoch > args.fix_alphas_epochs``;
      * optionally evaluates IS/FID;
      * saves the latest weights.
    The final genotypes and best IS/FID epochs are logged at the end.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # Create tensorboard logger
    writer_dict = {
        'writer': SummaryWriter(path_helper['log']),
        'inner_steps': 0,
        'val_steps': 0,
        'valid_global_steps': 0
    }

    # set tf env
    if args.eval:
        _init_inception()
        inception_path = check_or_download_inception(None)
        create_inception_graph(inception_path)

    # fid_stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.dataset.lower() == 'mnist':
        # NOTE(review): MNIST reuses the STL-10 statistics -- confirm.
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    FID_best = 1e+4
    IS_best = 0.
    FID_best_epoch = 0
    IS_best_epoch = 0

    # build gen and dis
    gen = eval('model_search_gan.' + args.gen)(args)
    gen = gen.cuda()
    dis = eval('model_search_gan.' + args.dis)(args)
    dis = dis.cuda()
    logging.info("generator param size = %fMB",
                 utils.count_parameters_in_MB(gen))
    logging.info("discriminator param size = %fMB",
                 utils.count_parameters_in_MB(dis))
    if args.parallel:
        gen = nn.DataParallel(gen)
        dis = nn.DataParallel(dis)

    # resume training
    if args.load_path != '':
        gen.load_state_dict(
            torch.load(
                os.path.join(args.load_path, 'model',
                             'weights_gen_' + 'last' + '.pt')))
        dis.load_state_dict(
            torch.load(
                os.path.join(args.load_path, 'model',
                             'weights_dis_' + 'last' + '.pt')))

    # set optimizer for parameters W of gen and dis
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis.parameters()), args.d_lr,
        (args.beta1, args.beta2))

    # set moving average parameters for generator
    gen_avg_param = copy_params(gen)

    img_size = 8 if args.grow else args.img_size
    train_transform, valid_transform = eval('utils.' + '_data_transforms_' +
                                            args.dataset + '_resize')(
                                                args, img_size)
    if args.dataset == 'cifar10':
        train_data = eval('dset.' + dataset[args.dataset])(
            root=args.data, train=True, download=True,
            transform=train_transform)
    elif args.dataset == 'stl10':
        train_data = eval('dset.' + dataset[args.dataset])(
            root=args.data, download=True, transform=train_transform)
    else:
        # Previously fell through and failed later with an unbound name.
        raise NotImplementedError(f'no dataset loader for {args.dataset}')

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.gen_batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)
    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.gen_batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=2)
    logging.info('length of train_queue is {}'.format(len(train_queue)))
    logging.info('length of valid_queue is {}'.format(len(valid_queue)))

    max_iter = len(train_queue) * args.epochs

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        gen_optimizer, float(args.epochs), eta_min=args.learning_rate_min)
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  max_iter * args.n_critic)

    architect = Architect_gen(gen, dis, args, 'duality_gap_with_mm', logging)

    gen.set_gumbel(args.use_gumbel)
    dis.set_gumbel(args.use_gumbel)
    for epoch in range(args.start_epoch + 1, args.epochs):
        scheduler.step()
        # get_last_lr() returns the rate set by the last step(). On newer
        # PyTorch, get_lr() outside step() warns, and for CosineAnnealingLR it
        # re-applies the update to the already-updated rate.
        if hasattr(scheduler, 'get_last_lr'):
            lr = scheduler.get_last_lr()[0]
        else:
            lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)
        logging.info('epoch %d gen_lr %e', epoch, args.g_lr)
        logging.info('epoch %d dis_lr %e', epoch, args.d_lr)

        genotype_gen = gen.genotype()
        logging.info('gen_genotype = %s', genotype_gen)

        if 'Discriminator' not in args.dis:
            genotype_dis = dis.genotype()
            logging.info('dis_genotype = %s', genotype_dis)

        print('up_1: ', F.softmax(gen.alphas_up_1, dim=-1))
        print('up_2: ', F.softmax(gen.alphas_up_2, dim=-1))
        print('up_3: ', F.softmax(gen.alphas_up_3, dim=-1))

        # determine whether use gumbel or not
        if epoch == args.fix_alphas_epochs + 1:
            gen.set_gumbel(args.use_gumbel)
            dis.set_gumbel(args.use_gumbel)

        # grow discriminator and generator
        if args.grow:
            dis.cur_stage = grow_ctrl(epoch, args.grow_epoch)
            gen.cur_stage = grow_ctrl(epoch, args.grow_epoch)
            if args.restrict_dis_grow and dis.cur_stage > 1:
                dis.cur_stage = 1
                print('debug: dis.cur_stage is {}'.format(dis.cur_stage))
            if epoch in args.grow_epoch:
                # Rebuild loaders at the grown resolution (CIFAR-10 only).
                train_transform, valid_transform = utils._data_transforms_cifar10_resize(
                    args, 2**(gen.cur_stage + 3))
                train_data = dset.CIFAR10(root=args.data,
                                          train=True,
                                          download=True,
                                          transform=train_transform)
                num_train = len(train_data)
                indices = list(range(num_train))
                split = int(np.floor(args.train_portion * num_train))
                train_queue = torch.utils.data.DataLoader(
                    train_data,
                    batch_size=args.gen_batch_size,
                    sampler=torch.utils.data.sampler.SubsetRandomSampler(
                        indices[:split]),
                    pin_memory=True,
                    num_workers=2)
                valid_queue = torch.utils.data.DataLoader(
                    train_data,
                    batch_size=args.gen_batch_size,
                    sampler=torch.utils.data.sampler.SubsetRandomSampler(
                        indices[split:num_train]),
                    pin_memory=True,
                    num_workers=2)
        else:
            gen.cur_stage = 2
            dis.cur_stage = 2

        # training parameters
        train_gan_parameter(args, train_queue, gen, dis, gen_optimizer,
                            dis_optimizer, gen_avg_param, logging, writer_dict)

        # training alphas
        if epoch > args.fix_alphas_epochs:
            train_gan_alpha(args, train_queue, valid_queue, gen, dis,
                            architect, gen_optimizer, gen_avg_param, epoch, lr,
                            writer_dict, logging)

        # evaluate the IS and FID
        if args.eval and epoch % args.eval_every == 0:
            inception_score, std, fid_score = validate(args, fixed_z, fid_stat,
                                                       gen, writer_dict,
                                                       path_helper)
            logging.info('epoch {}: IS is {}+-{}, FID is {}'.format(
                epoch, inception_score, std, fid_score))
            # Update the same names that the log lines below read. The
            # original wrote IS_epoch_best/FID_epoch_best, so the logged best
            # epochs stayed 0.
            if inception_score > IS_best:
                IS_best = inception_score
                IS_best_epoch = epoch
            if fid_score < FID_best:
                FID_best = fid_score
                FID_best_epoch = epoch
            logging.info('best epoch {}: IS is {}'.format(
                IS_best_epoch, IS_best))
            logging.info('best epoch {}: FID is {}'.format(
                FID_best_epoch, FID_best))

        utils.save(
            gen,
            os.path.join(path_helper['model'],
                         'weights_gen_{}.pt'.format('last')))
        utils.save(
            dis,
            os.path.join(path_helper['model'],
                         'weights_dis_{}.pt'.format('last')))

    genotype_gen = gen.genotype()
    if 'Discriminator' not in args.dis:
        genotype_dis = dis.genotype()
    logging.info('best epoch {}: IS is {}'.format(IS_best_epoch, IS_best))
    logging.info('best epoch {}: FID is {}'.format(FID_best_epoch, FID_best))
    logging.info('final discovered gen_arch is {}'.format(genotype_gen))
    if 'Discriminator' not in args.dis:
        logging.info('final discovered dis_arch is {}'.format(genotype_dis))
def main():
    """Search CycleGAN generator architectures with a growing RL controller.

    The search runs for ``sum(delta_grow_steps)`` iterations. At each
    prefix-sum boundary of those step sizes, the controller is rebuilt for
    the next stage and seeded with the previous stage's top-k architectures
    and hidden states. Every iteration trains the shared CycleGAN, then the
    controller, optionally re-creates the shared GAN when ``cyclgan_train``
    requests a dynamic reset, and saves a resumable checkpoint. The final
    top-k architectures for both directions are printed at the end.
    """
    opt = SearchOptions().parse()
    torch.cuda.manual_seed(12345)

    _init_inception(MODEL_DIR)
    create_inception_graph(check_or_download_inception(None))

    first_iter = 0
    cur_stage = 1

    # Step budget per stage: geometric growth up to max_skip_num, then a
    # constant grow_step**3 for each remaining resnet stage.
    geometric_steps = [
        int(opt.grow_step ** power) for power in range(1, opt.max_skip_num)
    ]
    constant_steps = [int(opt.grow_step ** 3)] * (opt.n_resnet -
                                                   opt.max_skip_num)
    delta_grow_steps = geometric_steps + constant_steps

    opt.max_search_iter = sum(delta_grow_steps)

    # Iterations at which a new stage begins: running totals excluding 0 and
    # the final total.
    grow_steps = []
    running_total = 0
    for delta in delta_grow_steps[:-1]:
        running_total += delta
        grow_steps.append(running_total)

    grow_ctrler = GrowCtrler(opt.grow_step, steps=grow_steps)

    def _build_models(stage):
        """Create and set up a shared CycleGAN plus a controller bound to it."""
        gan = CycleGANModel(opt)
        gan.setup(opt)
        ctrl = CycleControllerModel(opt, cur_stage=stage)
        ctrl.setup(opt)
        ctrl.set(gan)
        return gan, ctrl

    if not opt.load_path:
        opt.path_helper = set_log_dir(opt.checkpoints_dir, opt.name)
        cycle_gan, cycle_controller = _build_models(cur_stage)
    else:
        print(f'=> resuming from {opt.load_path}')
        assert os.path.exists(opt.load_path)
        checkpoint_file = os.path.join(opt.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file,
                                map_location={'cuda:0': 'cpu'})
        # restore controller stage, iteration counter and log locations
        cur_stage = checkpoint['cur_stage']
        first_iter = checkpoint["search_iter"]
        opt.path_helper = checkpoint['path_helper']

        cycle_gan, cycle_controller = _build_models(cur_stage)
        cycle_gan.load_from_state(checkpoint["cycle_gan"])
        cycle_controller.load_from_state(checkpoint["cycle_controller"])

    # create a dataset given opt.dataset_mode and other options
    dataset = create_dataset(opt)
    print('The number of training images = %d' % len(dataset))

    writer_dict = {
        "writer": SummaryWriter(opt.path_helper['log_path']),
        'controller_steps': first_iter * opt.ctrl_step,
        'train_steps': first_iter * opt.shared_epoch
    }

    g_loss_history = RunningStats(opt.dynamic_reset_window)
    d_loss_history = RunningStats(opt.dynamic_reset_window)

    dynamic_reset = None
    for search_iter in tqdm(range(int(first_iter),
                                  int(opt.max_search_iter))):
        tqdm.write(f"<start search iteration {search_iter}>")
        cycle_controller.reset()

        if search_iter in grow_steps:
            cur_stage = grow_ctrler.cur_stage(search_iter) + 1
            tqdm.write(f'=> grow to stage {cur_stage}')
            archs_a, hiddens_a = cycle_controller.get_topk_arch_hidden_A()
            archs_b, hiddens_b = cycle_controller.get_topk_arch_hidden_B()

            del cycle_controller

            cycle_controller = CycleControllerModel(opt, cur_stage)
            cycle_controller.setup(opt)
            cycle_controller.set(cycle_gan, hiddens_a, hiddens_b, archs_a,
                                 archs_b)

        dynamic_reset = cyclgan_train(opt, cycle_gan, cycle_controller,
                                      dataset, g_loss_history, d_loss_history,
                                      writer_dict)

        controller_train(opt, cycle_gan, cycle_controller, writer_dict)

        if dynamic_reset:
            tqdm.write('re-initialize share GAN')
            del cycle_gan
            cycle_gan = CycleGANModel(opt)
            cycle_gan.setup(opt)
            # NOTE(review): cycle_controller was bound to the previous GAN via
            # set(); confirm whether it needs to be re-bound here.

        save_checkpoint(
            {
                'cur_stage': cur_stage,
                'search_iter': search_iter + 1,
                'cycle_gan': cycle_gan.save_networks(epoch=search_iter),
                'cycle_controller':
                cycle_controller.save_networks(epoch=search_iter),
                'path_helper': opt.path_helper
            }, False, opt.path_helper['ckpt_path'])

    final_archs_a, _ = cycle_controller.get_topk_arch_hidden_A()
    final_archs_b, _ = cycle_controller.get_topk_arch_hidden_B()
    print(f"discovered archs: {final_archs_a}")
    print(f"discovered archs: {final_archs_b}")
def main():
    """Train a GAN starting from the initial weights stored under ``args.init_path``.

    Seeds all RNGs and loads ``initial_gen_net.pth`` / ``initial_dis_net.pth``
    into freshly built networks. It then trains for ``args.max_epoch``
    discriminator epochs, optionally resuming from ``args.load_path``.

    Validation uses the averaged generator parameters (``gen_avg_param``, which
    is passed to ``train()``). It runs every ``args.val_freq`` epochs (never at
    epoch 0) and on the final epoch. A checkpoint, flagged ``is_best`` when the
    FID improves, is written after every epoch.
    """
    args = cfg.parse_args()
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)
    # cuDNN autotuning is enabled, so results are not bit-for-bit reproducible
    # despite the seeds above.
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects child processes.
    os.environ['PYTHONHASHSEED'] = str(args.random_seed)
    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE(review): eval() of a CLI-provided model name; args.model must be trusted.
    gen_net = eval('models.'+args.model+'.Generator')(args=args)
    dis_net = eval('models.'+args.model+'.Discriminator')(args=args)
    # Start from the stored initial weights rather than a fresh random initialisation.
    initial_gen_net_weight = torch.load(os.path.join(args.init_path, 'initial_gen_net.pth'), map_location="cpu")
    initial_dis_net_weight = torch.load(os.path.join(args.init_path, 'initial_dis_net.pth'), map_location="cpu")
    gen_net = gen_net.cuda()
    dis_net = dis_net.cuda()
    gen_net.load_state_dict(initial_gen_net_weight)
    dis_net.load_state_dict(initial_dis_net_weight)

    # set optimizer
    gen_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gen_net.parameters()),
                                     args.g_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, dis_net.parameters()),
                                     args.d_lr, (args.beta1, args.beta2))

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/fid_stats_stl10_train.npz'
    else:
        raise NotImplementedError('no fid stat for %s' % args.dataset.lower())
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader))

    # End step of the linear LR decay, in discriminator iterations.
    # - With max_iter set, the horizon is max_iter * n_critic (as before).
    # - Otherwise it is derived from the epoch budget. Previously
    #   `args.max_iter * args.n_critic` was evaluated unconditionally, which
    #   raises TypeError when max_iter is None.
    if args.max_iter:
        decay_end_step = args.max_iter * args.n_critic
    else:
        decay_end_step = args.max_epoch * len(train_loader)
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, decay_end_step)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, decay_end_step)

    # initial
    fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (25, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print('=> resuming from %s' % args.load_path)
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        # Restore the averaged generator parameters through a temporary network copy.
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info('=> loaded checkpoint %s (epoch %d)' % (checkpoint_file, start_epoch))
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    last_epoch = int(args.max_epoch) - 1
    for epoch in range(int(start_epoch), int(args.max_epoch)):
        lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, epoch, writer_dict,
              lr_schedulers)

        # Validate every val_freq epochs (skipping epoch 0) and always on the final epoch.
        if (epoch and epoch % args.val_freq == 0) or epoch == last_epoch:
            # Evaluate with the averaged weights, then restore the live ones.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat, gen_net, writer_dict, epoch)
            logger.info('Inception score: %.4f, FID score: %.4f || @ epoch %d.' % (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'gen_state_dict': gen_net.state_dict(),
            'dis_state_dict': dis_net.state_dict(),
            'avg_gen_state_dict': avg_gen_net.state_dict(),
            'gen_optimizer': gen_optimizer.state_dict(),
            'dis_optimizer': dis_optimizer.state_dict(),
            'best_fid': best_fid,
            'path_helper': args.path_helper,
            'seed': args.random_seed
        }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net

    # Flush pending TensorBoard events and release the event file.
    writer_dict['writer'].close()
# NOTE(review): the enclosing function header is not visible in this chunk;
# `args` and the helpers used below are defined elsewhere.

# Output sub-directory, presumably for images buffered while computing FID.
args.fid_buffer_dir = os.path.join(args.output_dir, 'fid_buffer')
# IS / FID evaluation is switched off unconditionally, so neither TF setup
# branch under "set tf env" below ever runs.
args.do_IS = False
args.do_FID = False
# Create every output directory (tuple expression statement; its value is discarded).
create_dir(args.pth_dir), create_dir(args.img_dir), create_dir(
    args.gamma_dir), create_dir(args.fid_buffer_dir)

# gpu: seed torch and restrict the visible devices.
torch.manual_seed(args.random_seed)
torch.cuda.manual_seed(args.random_seed)
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialised yet in this process — confirm nothing above touched the GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# cuDNN autotuning is enabled, so runs are not bit-for-bit deterministic.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# set tf env (Inception graph for IS / FID); disabled by the flags above.
if args.do_IS:
    _init_inception()
if args.do_FID:
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

# G0: the generator restored from the best checkpoint under args.load_path.
with torch.no_grad():
    # define model
    G0 = Generator(bottom_width=args.bottom_width, gf_dim=args.gf_dim, latent_dim=args.latent_dim).cuda()
    # load ckpt: the averaged generator weights and the epoch stored with them.
    pth_path = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth')
    checkpoint = torch.load(pth_path)
    G0.load_state_dict(checkpoint['avg_gen_state_dict'])
    best_dense_epoch = checkpoint['epoch']
def main():
    """Train a GAN whose generator is built from a searched genotype.

    Steps:

    1. Load ``latest_G.npy`` from ``exps/<genotypes_exp>/Genotypes``.
    2. Build the generator and discriminator from ``archs.<arch>``, wrapped in
       ``torch.nn.DataParallel``, and initialise their weights according to
       ``args.init_type``.
    3. Train for ``args.max_epoch_D`` discriminator epochs, optionally resuming
       from ``exps/<checkpoint>``.

    Validation (IS mean/std and FID) uses the averaged generator parameters.
    It runs every ``args.val_freq`` epochs and on the final epoch. A checkpoint
    is saved after every epoch. The first visible GPU is reserved for the
    Inception evaluation model when more than one GPU is given.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set visible GPU ids (args.gpu_ids is still the comma-separated CLI string here)
    if args.gpu_ids:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids

    # set TensorFlow environment for evaluation (calculate IS and FID)
    _init_inception()
    inception_path = check_or_download_inception('./tmp/imagenet/')
    create_inception_graph(inception_path)

    # The first GPU in the visible set is dedicated to evaluation (running the
    # Inception model). Once CUDA_VISIBLE_DEVICES is set, the visible devices
    # are addressed as 0..n-1, so only the number of requested ids matters.
    visible_gpus = list(range(len(args.gpu_ids.split(','))))
    args.gpu_ids = visible_gpus[1:] if len(visible_gpus) > 1 else visible_gpus

    # genotype G
    genotypes_root = os.path.join('exps', args.genotypes_exp, 'Genotypes')
    genotype_G = np.load(os.path.join(genotypes_root, 'latest_G.npy'))

    # import network from genotype
    # NOTE(review): eval() of the CLI-provided args.arch; it must come from a trusted source.
    basemodel_gen = eval('archs.' + args.arch + '.Generator')(args, genotype_G)
    gen_net = torch.nn.DataParallel(basemodel_gen, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])
    basemodel_dis = eval('archs.' + args.arch + '.Discriminator')(args)
    dis_net = torch.nn.DataParallel(basemodel_dis, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])

    # weight init
    def weights_init(m):
        """Initialise Conv2d weights per args.init_type and BatchNorm2d affine params."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; nn.init.xavier_uniform (no underscore) is deprecated.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # epoch number for dis_net
    args.max_epoch_D = args.max_epoch_G * args.n_critic
    if args.max_iter_G:
        args.max_epoch_D = np.ceil(args.max_iter_G * args.n_critic / len(train_loader))
    # Total discriminator iterations; used as the end step of the linear LR decay.
    max_iter_D = args.max_epoch_D * len(train_loader)

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()),
        args.g_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()),
        args.d_lr, (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, max_iter_D)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, max_iter_D)

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # initial
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.checkpoint:
        # resuming
        print(f'=> resuming from {args.checkpoint}')
        assert os.path.exists(os.path.join('exps', args.checkpoint))
        # NOTE(review): resumes from checkpoint_best.pth, presumably the best-FID
        # snapshot rather than the latest one — confirm this is intended.
        checkpoint_file = os.path.join('exps', args.checkpoint, 'Model', 'checkpoint_best.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        # Restore the averaged generator parameters through a temporary network copy.
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('exps', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # model size
    logger.info('Param size of G = %fMB', count_parameters_in_MB(gen_net))
    logger.info('Param size of D = %fMB', count_parameters_in_MB(dis_net))
    print_FLOPs(basemodel_gen, (1, args.latent_dim), logger)
    print_FLOPs(basemodel_dis, (1, 3, args.img_size, args.img_size), logger)

    # for visualization
    if args.draw_arch:
        from utils.genotype import draw_graph_G
        draw_graph_G(genotype_G, save=True,
                     file_path=os.path.join(args.path_helper['graph_vis_path'], 'latest_G'))

    # NOTE(review): numpy is not seeded in this function, so unless it is seeded
    # elsewhere the validation noise fixed_z differs between runs and resumes.
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (100, args.latent_dim)))

    # train loop
    last_epoch = int(args.max_epoch_D) - 1
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch_D)), desc='total progress'):
        lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, epoch,
              writer_dict, lr_schedulers)

        if epoch % args.val_freq == 0 or epoch == last_epoch:
            # Evaluate with the averaged weights, then restore the live ones.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, std, fid_score = validate(args, fixed_z, fid_stat, gen_net, writer_dict)
            logger.info(
                f'Inception score mean: {inception_score}, Inception score std: {std}, '
                f'FID score: {fid_score} || @ epoch {epoch}.')
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        # save model
        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.arch,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net

    # Flush pending TensorBoard events and release the event file.
    writer_dict['writer'].close()
def main():
    """Search a 3-layer generator architecture layer by layer with a ``SAC`` agent.

    Each of the 100 search iterations is one trajectory over the layers:

    - For each layer, the agent chooses that layer's operations from the state
      ``[layer, IS, 0.01 * FID] + state``.
    - The shared GAN is trained on images of size ``2**(layer + 3)``, then
      re-evaluated.
    - The transition is stored with reward ``IS_gain + 0.01 * FID_reduction``.

    After each trajectory, the agent does one extra update over the whole
    replay memory. The shared GAN is then re-created from scratch and the agent
    is saved as ``"test"``. From iteration 70 on, the agent runs in
    exploitation mode (``Best``) with 10 updates per step.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    print(args)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialise Conv2d weights per args.init_type and BatchNorm2d affine params."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; nn.init.xavier_uniform (no underscore) is deprecated.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # initial
    start_search_iter = 0

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        cur_stage = checkpoint['cur_stage']
        start_search_iter = checkpoint['search_iter']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        # NOTE(review): prev_archs / prev_hiddens are read but never used below,
        # and the SAC agent and replay memory are not restored on resume.
        prev_archs = checkpoint['prev_archs']
        prev_hiddens = checkpoint['prev_hiddens']

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})'
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
        prev_archs = None
        prev_hiddens = None

        # set controller && its optimizer
        cur_stage = 0

    ### Define number of layers, currently only support 1->3
    total_layer_num = 3
    ### Different image size for different layers: layer k trains on 2**(k + 3) px images.
    # Built once here: they depend only on args and the layer index, so
    # rebuilding them on every search iteration was redundant work.
    ds = [
        datasets.ImageDataset(args, 2**(k + 3))
        for k in range(total_layer_num)
    ]
    train_loaders = [d.train for d in ds]

    print(args.rl_num_eval_img, "##############################")
    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'controller_steps': start_search_iter * args.ctrl_step
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    # NOTE(review): 131 is presumably the state size, i.e. 3 scalars
    # ([layer, IS, 0.01 * FID]) plus len(state) returned by get_is — confirm.
    Agent = SAC(131)
    print(Agent.alpha)

    memory = ReplayMemory(2560000)
    updates = 0
    outinfo = {
        'rewards': [],
        'a_loss': [],
        'critic_error': [],
    }
    Best = False
    Z_NUMPY = None
    WARMUP = True
    update_time = 1
    for search_iter in tqdm(range(int(start_search_iter), 100), desc='search progress'):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter >= 1:
            WARMUP = False

        last_R = 0.  # Initial reward
        last_fid = 10000  # Inital reward
        last_arch = []

        # Set exploration: from iteration 70 on, exploit and update more often.
        if search_iter > 69:
            update_time = 10
            Best = True
        else:
            Best = False

        # Initial state: evaluate the whole shared GAN (stage -1).
        gen_net.set_stage(-1)
        last_R, last_fid, last_state = get_is(args, gen_net, args.rl_num_eval_img, get_is_score=True)
        for layer in range(total_layer_num):
            # This defines which layer to use as output, for example, if cur_stage==0,
            # then the output will be the first layer output. Set it to 2 if you want
            # the output of the last layer.
            cur_stage = layer
            action = Agent.select_action([layer, last_R, 0.01 * last_fid] + last_state, Best)
            arch = [
                action[0][0], action[0][1], action[1][0], action[1][1],
                action[1][2], action[2][0], action[2][1], action[2][2],
                action[3][0], action[3][1], action[4][0], action[4][1],
                action[5][0], action[5][1]
            ]
            # argmax to get int description of arch
            cur_arch = [np.argmax(k) for k in action]
            # Pad the skip option 0=False (for only layer 1 and layer2, not layer0,
            # see builing_blocks.py for why)
            if layer == 0:
                cur_arch = cur_arch[0:4]
            elif layer == 1:
                cur_arch = cur_arch[0:5]
            elif layer == 2:
                # Encode the two skip choices (cur_arch[4], cur_arch[5]) as one value in 0..3.
                if cur_arch[4] + cur_arch[5] == 2:
                    cur_arch = cur_arch[0:4] + [3]
                elif cur_arch[4] + cur_arch[5] == 0:
                    cur_arch = cur_arch[0:4] + [0]
                elif cur_arch[4] == 1 and cur_arch[5] == 0:
                    cur_arch = cur_arch[0:4] + [1]
                else:
                    cur_arch = cur_arch[0:4] + [2]

            # Get the network arch with the new architecture attached.
            last_arch += cur_arch
            gen_net.set_arch(last_arch, layer)  # Set the network, given cur_stage
            # Train network
            dynamic_reset = train_qin(args, gen_net, dis_net, g_loss_history, d_loss_history,
                                      gen_optimizer, dis_optimizer, train_loaders[layer], cur_stage,
                                      smooth=False, WARMUP=WARMUP)

            # Get reward, use the jth layer output for generation. (layer 0:j),
            # and the proposed progressive state
            R, fid, state = get_is(args, gen_net, args.rl_num_eval_img, z_numpy=Z_NUMPY)
            # Print exploitation mark, for better readability of the log.
            if Best:
                print("arch:", cur_arch, "Exploitation:", Best)
            else:
                print("arch:", cur_arch, "Exploring...")
            # Proxy reward of the up-to-now (0:j) architecture.
            print("update times:", updates, "step:", layer + 1, "IS:", R, "FID:", fid)

            # The last layer ends the trajectory.
            mask = 0 if layer == total_layer_num - 1 else 1
            # Append transition to memory. The reward is the IS gain plus 0.01 x the
            # FID reduction. (The original wrapped this in `if search_iter >= 0`,
            # which is always true.)
            memory.push([layer, last_R, 0.01 * last_fid] + last_state, arch,
                        R - last_R + 0.01 * (last_fid - fid),
                        [layer + 1, R, 0.01 * fid] + state, mask)

            if len(memory) >= 64:
                # Number of updates per step in environment
                for _ in range(update_time):
                    # Update parameters of all the networks
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = Agent.update_parameters(
                        memory, min(len(memory), 256), updates)
                    updates += 1
                    outinfo['critic_error'] = min(critic_1_loss, critic_2_loss)
                    outinfo['entropy'] = ent_loss
                    outinfo['a_loss'] = policy_loss
                print("full batch", outinfo, alpha)
            last_R = R  # next step
            last_fid = fid
            last_state = state
        outinfo['rewards'] = R
        # One extra update over the whole replay memory at the end of the trajectory.
        critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = Agent.update_parameters(
            memory, len(memory), updates)
        updates += 1
        outinfo['critic_error'] = min(critic_1_loss, critic_2_loss)
        outinfo['entropy'] = ent_loss
        outinfo['a_loss'] = policy_loss
        print("full batch", outinfo, alpha)
        # Clean up and start a new trajectory from scratch
        del gen_net, dis_net, gen_optimizer, dis_optimizer
        gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
            args, weights_init)
        print(outinfo, len(memory))
        Agent.save_model("test")
        WARMUP = False

    # Flush pending TensorBoard events and release the event file.
    writer_dict['writer'].close()