def main():
    """Run progressive NAS for a GAN: alternately train a shared GAN and an
    RNN controller, growing the generator to a new stage at fixed iterations.

    Reads all configuration from ``cfg.parse_args()``; resumes from
    ``args.load_path`` when given, otherwise starts a fresh log directory.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env: TF Inception graph is used later for IS/FID evaluation
    _init_inception(MODEL_DIR)
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialize Conv2d / BatchNorm2d modules according to args.init_type."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # fix: use the in-place initializer; the un-suffixed
                # `nn.init.xavier_uniform` is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # set grow controller: maps a search iteration to the current growth stage
    grow_ctrler = GrowCtrler(args.grow_step1, args.grow_step2)

    # initial
    start_search_iter = 0

    # set writer
    if args.load_path:
        # resume: restore GAN, controller, optimizers and bookkeeping state
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file,
                                map_location={'cuda:0': 'cpu'})
        # set controller && its optimizer; controller shape depends on stage
        cur_stage = checkpoint['cur_stage']
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)
        start_search_iter = checkpoint['search_iter']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        controller.load_state_dict(checkpoint['ctrl_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        ctrl_optimizer.load_state_dict(checkpoint['ctrl_optimizer'])
        prev_archs = checkpoint['prev_archs']
        prev_hiddens = checkpoint['prev_hiddens']
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})'
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
        prev_archs = None
        prev_hiddens = None
        # set controller && its optimizer
        cur_stage = 0
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

    # set up data_loader: image side length doubles with each stage (8, 16, 32, ...)
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3),
                                    args.dis_batch_size, args.num_workers)
    train_loader = dataset.train

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'controller_steps': start_search_iter * args.ctrl_step
    }

    # running variance of G/D losses drives the dynamic-reset heuristic
    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    for search_iter in tqdm(range(int(start_search_iter),
                                  int(args.max_search_iter)),
                            desc='search progress'):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter == args.grow_step1 or search_iter == args.grow_step2:
            # grow to the next stage: keep the top-k archs found so far,
            # rebuild the controller for the new stage, and switch to the
            # higher-resolution data loader
            cur_stage = grow_ctrler.cur_stage(search_iter)
            logger.info(f'=> grow to stage {cur_stage}')
            prev_archs, prev_hiddens = get_topk_arch_hidden(
                args, controller, gen_net, prev_archs, prev_hiddens)

            del controller
            del ctrl_optimizer
            controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                       weights_init)

            dataset = datasets.ImageDataset(args, 2**(cur_stage + 3),
                                            args.dis_batch_size,
                                            args.num_workers)
            train_loader = dataset.train

        dynamic_reset = train_shared(args,
                                     gen_net,
                                     dis_net,
                                     g_loss_history,
                                     d_loss_history,
                                     controller,
                                     gen_optimizer,
                                     dis_optimizer,
                                     train_loader,
                                     prev_hiddens=prev_hiddens,
                                     prev_archs=prev_archs)
        train_controller(args, controller, ctrl_optimizer, gen_net,
                         prev_hiddens, prev_archs, writer_dict)

        if dynamic_reset:
            # loss variance collapsed -> re-initialize the shared GAN from scratch
            logger.info('re-initialize share GAN')
            del gen_net, dis_net, gen_optimizer, dis_optimizer
            gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
                args, weights_init)

        save_checkpoint(
            {
                'cur_stage': cur_stage,
                'search_iter': search_iter + 1,
                'gen_model': args.gen_model,
                'dis_model': args.dis_model,
                'controller': args.controller,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'ctrl_state_dict': controller.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'ctrl_optimizer': ctrl_optimizer.state_dict(),
                'prev_archs': prev_archs,
                'prev_hiddens': prev_hiddens,
                'path_helper': args.path_helper
            }, False, args.path_helper['ckpt_path'])

    final_archs, _ = get_topk_arch_hidden(args, controller, gen_net,
                                          prev_archs, prev_hiddens)
    logger.info(f"discovered archs: {final_archs}")
def main():
    """Run RL-based architecture search: a SAC agent proposes per-layer cell
    architectures for a shared GAN and is rewarded by IS/FID improvements.

    Each search iteration builds a 3-layer generator layer by layer, trains
    the shared GAN briefly per layer, pushes the transition into a replay
    buffer, and periodically updates the SAC agent.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    print(args)

    # set tf env: TF Inception graph is used later for IS/FID evaluation
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialize Conv2d / BatchNorm2d modules according to args.init_type."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # fix: use the in-place initializer; the un-suffixed
                # `nn.init.xavier_uniform` is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # initial
    start_search_iter = 0

    # set writer
    if args.load_path:
        # resume shared-GAN weights and bookkeeping from a checkpoint
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        cur_stage = checkpoint['cur_stage']
        start_search_iter = checkpoint['search_iter']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        prev_archs = checkpoint['prev_archs']
        prev_hiddens = checkpoint['prev_hiddens']
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})'
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
        prev_archs = None
        prev_hiddens = None
        # set controller && its optimizer
        cur_stage = 0

    # set up data_loader: image side length is 2**(stage + 3)
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
    train_loader = dataset.train
    print(args.rl_num_eval_img, "##############################")
    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'controller_steps': start_search_iter * args.ctrl_step
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    Agent = SAC(131)  # state dimension fixed by the search space encoding
    print(Agent.alpha)
    memory = ReplayMemory(2560000)
    updates = 0
    outinfo = {
        'rewards': [],
        'a_loss': [],
        'critic_error': [],
    }
    Best = False
    Z_NUMPY = None
    WARMUP = True
    update_time = 1
    for search_iter in tqdm(range(int(start_search_iter), 100),
                            desc='search progress'):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter >= 1:
            WARMUP = False

        ### Define number of layers, currently only support 1->3
        total_layer_num = 3
        ### Different image size for different layers
        ds = [
            datasets.ImageDataset(args, 2**(k + 3))
            for k in range(total_layer_num)
        ]
        train_loaders = [d.train for d in ds]
        last_R = 0.  # Initial reward
        last_fid = 10000  # Inital reward
        last_arch = []

        # Set exploration: late iterations exploit (argmax) and update more often
        if search_iter > 69:
            update_time = 10
            Best = True
        else:
            Best = False

        gen_net.set_stage(-1)
        last_R, last_fid, last_state = get_is(args,
                                              gen_net,
                                              args.rl_num_eval_img,
                                              get_is_score=True)
        for layer in range(total_layer_num):
            cur_stage = layer  # This defines which layer to use as output, for example, if cur_stage==0, then the output will be the first layer output. Set it to 2 if you want the output of the last layer.
            # state = [layer index, last IS, scaled last FID] + image features
            action = Agent.select_action([layer, last_R, 0.01 * last_fid] +
                                         last_state, Best)
            # flatten the raw (one-hot-ish) action vectors for the replay buffer
            arch = [
                action[0][0], action[0][1], action[1][0], action[1][1],
                action[1][2], action[2][0], action[2][1], action[2][2],
                action[3][0], action[3][1], action[4][0], action[4][1],
                action[5][0], action[5][1]
            ]
            # argmax to get int description of arch
            cur_arch = [np.argmax(k) for k in action]
            # Pad the skip option 0=False (for only layer 1 and layer2, not
            # layer0, see builing_blocks.py for why)
            if layer == 0:
                cur_arch = cur_arch[0:4]
            elif layer == 1:
                cur_arch = cur_arch[0:5]
            elif layer == 2:
                # encode the two binary skip choices as a single 0-3 code
                if cur_arch[4] + cur_arch[5] == 2:
                    cur_arch = cur_arch[0:4] + [3]
                elif cur_arch[4] + cur_arch[5] == 0:
                    cur_arch = cur_arch[0:4] + [0]
                elif cur_arch[4] == 1 and cur_arch[5] == 0:
                    cur_arch = cur_arch[0:4] + [1]
                else:
                    cur_arch = cur_arch[0:4] + [2]

            # Get the network arch with the new architecture attached.
            last_arch += cur_arch
            gen_net.set_arch(last_arch, layer)  # Set the network, given cur_stage

            # Train network
            dynamic_reset = train_qin(args,
                                      gen_net,
                                      dis_net,
                                      g_loss_history,
                                      d_loss_history,
                                      gen_optimizer,
                                      dis_optimizer,
                                      train_loaders[layer],
                                      cur_stage,
                                      smooth=False,
                                      WARMUP=WARMUP)

            # Get reward, use the jth layer output for generation. (layer 0:j),
            # and the proposed progressive state
            R, fid, state = get_is(args,
                                   gen_net,
                                   args.rl_num_eval_img,
                                   z_numpy=Z_NUMPY)
            # Print exploitation mark, for better readability of the log.
            if Best:
                print("arch:", cur_arch, "Exploitation:", Best)
            else:
                print("arch:", cur_arch, "Exploring...")
            # Proxy reward of the up-to-now (0:j) architecture.
            print("update times:", updates, "step:", layer + 1, "IS:", R,
                  "FID:", fid)
            # terminal flag: 0 on the last layer of the trajectory
            mask = 0 if layer == total_layer_num - 1 else 1
            if search_iter >= 0:  # warm up
                # reward shaping: IS gain plus scaled FID reduction
                memory.push([layer, last_R, 0.01 * last_fid] + last_state,
                            arch, R - last_R + 0.01 * (last_fid - fid),
                            [layer + 1, R, 0.01 * fid] + state,
                            mask)  # Append transition to memory

            if len(memory) >= 64:
                # Number of updates per step in environment
                for _ in range(update_time):
                    # Update parameters of all the networks
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = Agent.update_parameters(
                        memory, min(len(memory), 256), updates)
                    updates += 1
                    outinfo['critic_error'] = min(critic_1_loss, critic_2_loss)
                    outinfo['entropy'] = ent_loss
                    outinfo['a_loss'] = policy_loss
                print("full batch", outinfo, alpha)
            last_R = R  # next step
            last_fid = fid
            last_state = state
        outinfo['rewards'] = R
        # one final full-buffer update at the end of the trajectory
        critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = Agent.update_parameters(
            memory, len(memory), updates)
        updates += 1
        outinfo['critic_error'] = min(critic_1_loss, critic_2_loss)
        outinfo['entropy'] = ent_loss
        outinfo['a_loss'] = policy_loss
        print("full batch", outinfo, alpha)
        # Clean up and start a new trajectory from scratch
        del gen_net, dis_net, gen_optimizer, dis_optimizer
        gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
            args, weights_init)
        print(outinfo, len(memory))
        Agent.save_model("test")
        WARMUP = False
def main():
    """Run NAS for CycleGAN: alternately train the shared CycleGAN and its
    two architecture controllers (A->B and B->A), growing the search space
    through stages at precomputed iterations.
    """
    opt = SearchOptions().parse()
    torch.cuda.manual_seed(12345)
    # TF Inception graph is used for scoring generated images
    _init_inception(MODEL_DIR)
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    start_search_iter = 0
    cur_stage = 1

    # per-stage iteration budgets: geometric growth for the first
    # max_skip_num-1 stages, then a constant grow_step**3 budget for the rest
    delta_grow_steps = [int(opt.grow_step ** i) for i in range(1, opt.max_skip_num)] + \
                       [int(opt.grow_step ** 3) for _ in range(1, opt.n_resnet - opt.max_skip_num + 1)]
    opt.max_search_iter = sum(delta_grow_steps)
    # cumulative sums (excluding 0) give the absolute iterations to grow at
    grow_steps = [
        sum(delta_grow_steps[:i]) for i in range(len(delta_grow_steps))
    ][1:]

    grow_ctrler = GrowCtrler(opt.grow_step, steps=grow_steps)

    if opt.load_path:
        # resume: restore CycleGAN + controller state from the checkpoint
        print(f'=> resuming from {opt.load_path}')
        assert os.path.exists(opt.load_path)
        checkpoint_file = os.path.join(opt.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file,
                                map_location={'cuda:0': 'cpu'})
        # set controller && its optimizer
        cur_stage = checkpoint['cur_stage']
        start_search_iter = checkpoint["search_iter"]
        opt.path_helper = checkpoint['path_helper']

        cycle_gan = CycleGANModel(opt)
        cycle_gan.setup(opt)

        cycle_controller = CycleControllerModel(opt, cur_stage=cur_stage)
        cycle_controller.setup(opt)
        cycle_controller.set(cycle_gan)

        cycle_gan.load_from_state(checkpoint["cycle_gan"])
        cycle_controller.load_from_state(checkpoint["cycle_controller"])
    else:
        # fresh run: new log dir, fresh CycleGAN and controller
        opt.path_helper = set_log_dir(opt.checkpoints_dir, opt.name)

        cycle_gan = CycleGANModel(opt)
        cycle_gan.setup(opt)

        cycle_controller = CycleControllerModel(opt, cur_stage=cur_stage)
        cycle_controller.setup(opt)
        cycle_controller.set(cycle_gan)

    dataset = create_dataset(
        opt)  # create a dataset given opt.dataset_mode and other options
    print('The number of training images = %d' % len(dataset))

    writer_dict = {
        "writer": SummaryWriter(opt.path_helper['log_path']),
        'controller_steps': start_search_iter * opt.ctrl_step,
        'train_steps': start_search_iter * opt.shared_epoch
    }

    # running variance of G/D losses drives the dynamic-reset heuristic
    g_loss_history = RunningStats(opt.dynamic_reset_window)
    d_loss_history = RunningStats(opt.dynamic_reset_window)

    dynamic_reset = None
    for search_iter in tqdm(
            range(int(start_search_iter), int(opt.max_search_iter))):
        tqdm.write(f"<start search iteration {search_iter}>")
        cycle_controller.reset()

        if search_iter in grow_steps:
            # grow: carry the top-k archs/hiddens of both directions into a
            # freshly built controller for the next stage
            cur_stage = grow_ctrler.cur_stage(search_iter) + 1
            tqdm.write(f'=> grow to stage {cur_stage}')
            prev_archs_A, prev_hiddens_A = cycle_controller.get_topk_arch_hidden_A(
            )
            prev_archs_B, prev_hiddens_B = cycle_controller.get_topk_arch_hidden_B(
            )
            del cycle_controller
            cycle_controller = CycleControllerModel(opt, cur_stage)
            cycle_controller.setup(opt)
            cycle_controller.set(cycle_gan, prev_hiddens_A, prev_hiddens_B,
                                 prev_archs_A, prev_archs_B)

        dynamic_reset = cyclgan_train(opt, cycle_gan, cycle_controller,
                                      dataset, g_loss_history, d_loss_history,
                                      writer_dict)
        controller_train(opt, cycle_gan, cycle_controller, writer_dict)

        if dynamic_reset:
            # loss variance collapsed -> rebuild the shared CycleGAN
            # NOTE(review): cycle_controller.set(...) is not re-invoked here,
            # so the controller may still reference the deleted cycle_gan —
            # verify against CycleControllerModel.set semantics.
            tqdm.write('re-initialize share GAN')
            del cycle_gan
            cycle_gan = CycleGANModel(opt)
            cycle_gan.setup(opt)

        save_checkpoint(
            {
                'cur_stage': cur_stage,
                'search_iter': search_iter + 1,
                'cycle_gan': cycle_gan.save_networks(epoch=search_iter),
                'cycle_controller':
                cycle_controller.save_networks(epoch=search_iter),
                'path_helper': opt.path_helper
            }, False, opt.path_helper['ckpt_path'])

    final_archs_A, _ = cycle_controller.get_topk_arch_hidden_A()
    final_archs_B, _ = cycle_controller.get_topk_arch_hidden_B()
    print(f"discovered archs: {final_archs_A}")
    print(f"discovered archs: {final_archs_B}")
def cyclgan_train(opt, cycle_gan: CycleGANModel,
                  cycle_controller: CycleControllerModel, train_loader,
                  g_loss_history: RunningStats, d_loss_history: RunningStats,
                  writer_dict):
    """Train the shared CycleGAN for ``opt.shared_epoch`` epochs, sampling an
    architecture from the (frozen) controller before every batch.

    Returns True as soon as the running variance of the G or D loss drops
    below ``opt.dynamic_reset_threshold`` (a sign of mode collapse), in which
    case the caller is expected to re-initialize the shared GAN; otherwise
    returns False after all epochs complete.
    """
    cycle_gan.train()
    cycle_controller.eval()  # controller only samples archs; it is not trained here
    dynamic_reset = False
    writer = writer_dict['writer']
    total_iters = 0
    t_data = 0.0

    for epoch in range(opt.shared_epoch):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0  # images seen this epoch

        train_steps = writer_dict['train_steps']
        for i, data in enumerate(train_loader):
            iter_start_time = time.time()
            if total_iters % opt.print_freq == 0:
                # time spent waiting on the data loader since the last batch
                t_data = iter_start_time - iter_data_time
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            # sample a fresh architecture, then take one GAN optimization step
            cycle_controller.forward()
            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()

            g_loss_history.push(cycle_gan.loss_G.item())
            d_loss_history.push(cycle_gan.loss_D_A.item() +
                                cycle_gan.loss_D_B.item())

            if (i + 1) % opt.print_freq == 0:
                losses = cycle_gan.get_current_losses()
                t_comp = (time.time() - iter_start_time)
                message = "GAN: [Ep: %d/%d]" % (epoch, opt.shared_epoch)
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)

            if (total_iters + 1) % opt.display_freq == 0:
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(),
                                     train_steps)

            if g_loss_history.is_full():
                # low variance in either loss indicates a degenerate GAN;
                # bail out early and ask the caller for a reset
                if g_loss_history.get_var() < opt.dynamic_reset_threshold \
                        or d_loss_history.get_var() < opt.dynamic_reset_threshold:
                    dynamic_reset = True
                    tqdm.write("=> dynamic resetting triggered")
                    g_loss_history.clear()
                    d_loss_history.clear()
                    return dynamic_reset

            if (
                    total_iters + 1
            ) % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
                # fix: dropped the unused `save_suffix` local and the dead
                # commented-out save call; only the log line remains.
                tqdm.write(
                    'saving the latest model (epoch %d, total_iters %d)' %
                    (epoch, total_iters))

            iter_data_time = time.time()

        if (epoch + 1) % opt.save_epoch_freq == 0:
            cycle_gan.save_networks('latest')

        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, opt.n_epochs + opt.n_epochs_decay,
                    time.time() - epoch_start_time))

        # one scalar group per loss family, logged once per epoch
        writer.add_scalars('Train/discriminator', {
            "A": float(cycle_gan.loss_D_A),
            "B": float(cycle_gan.loss_D_B),
        }, train_steps)
        writer.add_scalars('Train/generator', {
            "A": float(cycle_gan.loss_G_A),
            "B": float(cycle_gan.loss_G_B),
        }, train_steps)
        writer.add_scalars(
            'Train/cycle', {
                "A": float(cycle_gan.loss_cycle_A),
                "B": float(cycle_gan.loss_cycle_B),
            }, train_steps)
        writer.add_scalars('Train/idt', {
            "A": float(cycle_gan.loss_idt_A),
            "B": float(cycle_gan.loss_idt_B),
        }, train_steps)

        writer_dict['train_steps'] += 1

    return dynamic_reset
def main():
    """Run progressive NAS for a GAN (double-quoted variant of the controller
    search driver): alternately train a shared GAN and an RNN controller,
    growing the generator to a new stage at fixed iterations."""
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env: TF Inception graph is used later for IS/FID evaluation
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialize Conv2d / BatchNorm2d modules according to args.init_type."""
        classname = m.__class__.__name__
        if classname.find("Conv2d") != -1:
            if args.init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == "orth":
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == "xavier_uniform":
                # fix: use the in-place initializer; the un-suffixed
                # `nn.init.xavier_uniform` is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.0)
            else:
                raise NotImplementedError("{} unknown inital type".format(
                    args.init_type))
        elif classname.find("BatchNorm2d") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # set grow controller: maps a search iteration to the current growth stage
    grow_ctrler = GrowCtrler(args.grow_step1, args.grow_step2)

    # initial
    start_search_iter = 0

    # set writer
    if args.load_path:
        # resume: restore GAN, controller, optimizers and bookkeeping state
        print(f"=> resuming from {args.load_path}")
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, "Model",
                                       "checkpoint.pth")
        assert os.path.exists(checkpoint_file)
        # NOTE(review): the sibling search driver passes
        # map_location={'cuda:0': 'cpu'} here — confirm whether this variant
        # is meant to require a GPU at resume time.
        checkpoint = torch.load(checkpoint_file)
        # set controller && its optimizer; controller shape depends on stage
        cur_stage = checkpoint["cur_stage"]
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)
        start_search_iter = checkpoint["search_iter"]
        gen_net.load_state_dict(checkpoint["gen_state_dict"])
        dis_net.load_state_dict(checkpoint["dis_state_dict"])
        controller.load_state_dict(checkpoint["ctrl_state_dict"])
        gen_optimizer.load_state_dict(checkpoint["gen_optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        ctrl_optimizer.load_state_dict(checkpoint["ctrl_optimizer"])
        prev_archs = checkpoint["prev_archs"]
        prev_hiddens = checkpoint["prev_hiddens"]
        args.path_helper = checkpoint["path_helper"]
        logger = create_logger(args.path_helper["log_path"])
        logger.info(
            f"=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})"
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir("logs", args.exp_name)
        logger = create_logger(args.path_helper["log_path"])
        prev_archs = None
        prev_hiddens = None
        # set controller && its optimizer
        cur_stage = 0
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

    # set up data_loader: image side length doubles with each stage
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
    train_loader = dataset.train

    logger.info(args)
    writer_dict = {
        "writer": SummaryWriter(args.path_helper["log_path"]),
        "controller_steps": start_search_iter * args.ctrl_step,
    }

    # running variance of G/D losses drives the dynamic-reset heuristic
    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    for search_iter in tqdm(range(int(start_search_iter),
                                  int(args.max_search_iter)),
                            desc="search progress"):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter == args.grow_step1 or search_iter == args.grow_step2:
            # grow to the next stage: keep the top-k archs found so far,
            # rebuild the controller, and switch to higher-resolution data
            cur_stage = grow_ctrler.cur_stage(search_iter)
            logger.info(f"=> grow to stage {cur_stage}")
            prev_archs, prev_hiddens = get_topk_arch_hidden(
                args, controller, gen_net, prev_archs, prev_hiddens)

            del controller
            del ctrl_optimizer
            controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                       weights_init)

            dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
            train_loader = dataset.train

        dynamic_reset = train_shared(
            args,
            gen_net,
            dis_net,
            g_loss_history,
            d_loss_history,
            controller,
            gen_optimizer,
            dis_optimizer,
            train_loader,
            prev_hiddens=prev_hiddens,
            prev_archs=prev_archs,
        )
        train_controller(
            args,
            controller,
            ctrl_optimizer,
            gen_net,
            prev_hiddens,
            prev_archs,
            writer_dict,
        )

        if dynamic_reset:
            # loss variance collapsed -> re-initialize the shared GAN
            logger.info("re-initialize share GAN")
            del gen_net, dis_net, gen_optimizer, dis_optimizer
            gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
                args, weights_init)

        save_checkpoint(
            {
                "cur_stage": cur_stage,
                "search_iter": search_iter + 1,
                "gen_model": args.gen_model,
                "dis_model": args.dis_model,
                "controller": args.controller,
                "gen_state_dict": gen_net.state_dict(),
                "dis_state_dict": dis_net.state_dict(),
                "ctrl_state_dict": controller.state_dict(),
                "gen_optimizer": gen_optimizer.state_dict(),
                "dis_optimizer": dis_optimizer.state_dict(),
                "ctrl_optimizer": ctrl_optimizer.state_dict(),
                "prev_archs": prev_archs,
                "prev_hiddens": prev_hiddens,
                "path_helper": args.path_helper,
            },
            False,
            args.path_helper["ckpt_path"],
        )

    final_archs, _ = get_topk_arch_hidden(args, controller, gen_net,
                                          prev_archs, prev_hiddens)
    logger.info(f"discovered archs: {final_archs}")