Ejemplo n.º 1
0
def main():
    """Run the controller-based architecture search for the shared GAN.

    Steps, in order:

    1. Parse options with ``cfg.parse_args()`` and seed the CUDA RNG with
       ``args.random_seed``.
    2. Set up the Inception graph (``_init_inception`` /
       ``check_or_download_inception`` / ``create_inception_graph``).
    3. Build the shared generator/discriminator and their optimizers with
       ``create_shared_gan``.
    4. Either resume everything from
       ``<args.load_path>/Model/checkpoint.pth`` or start a fresh run whose
       log directory is created from ``args.exp_name``.
    5. For every search iteration up to ``args.max_search_iter``: grow the
       controller at ``args.grow_step1`` / ``args.grow_step2``, call
       ``train_shared`` and ``train_controller``, re-create the shared GAN
       when ``train_shared`` reports a dynamic reset, and save a checkpoint.
    6. Log the final architectures returned by ``get_topk_arch_hidden``.

    Raises:
        AssertionError: if ``args.load_path`` is set but it or the checkpoint
            file does not exist, or if a fresh run has no ``args.exp_name``.
        NotImplementedError: from the weight initialiser when
            ``args.init_type`` is not one of ``normal``, ``orth`` or
            ``xavier_uniform``.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception(MODEL_DIR)
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialise module ``m`` in place, selected by its class name.

        Classes whose name contains ``Conv2d`` are initialised according to
        ``args.init_type``; classes whose name contains ``BatchNorm2d`` get
        N(1.0, 0.02) weights and zero bias. Other modules are left untouched.
        """
        classname = m.__class__.__name__
        if 'Conv2d' in classname:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # Use the in-place variant: the un-suffixed
                # ``nn.init.xavier_uniform`` is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # set grow controller
    grow_ctrler = GrowCtrler(args.grow_step1, args.grow_step2)

    # initial
    start_search_iter = 0

    # resume from a checkpoint, or start a fresh run
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        # Tensors stored on 'cuda:0' are remapped to the CPU while loading.
        checkpoint = torch.load(checkpoint_file,
                                map_location={'cuda:0': 'cpu'})
        # set controller && its optimizer
        # The controller is built for the saved stage before its weights
        # are restored below.
        cur_stage = checkpoint['cur_stage']
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

        start_search_iter = checkpoint['search_iter']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        controller.load_state_dict(checkpoint['ctrl_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        ctrl_optimizer.load_state_dict(checkpoint['ctrl_optimizer'])
        prev_archs = checkpoint['prev_archs']
        prev_hiddens = checkpoint['prev_hiddens']

        # Keep writing into the directories of the resumed run.
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})'
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
        prev_archs = None
        prev_hiddens = None

        # set controller && its optimizer
        cur_stage = 0
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

    # set up data_loader
    # The second argument is 8, 16, 32, ... for stage 0, 1, 2, ...
    # (presumably the image side length -- confirm in datasets.ImageDataset).
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3),
                                    args.dis_batch_size, args.num_workers)
    train_loader = dataset.train

    logger.info(args)
    # set writer
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'controller_steps': start_search_iter * args.ctrl_step
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop

    for search_iter in tqdm(range(int(start_search_iter),
                                  int(args.max_search_iter)),
                            desc='search progress'):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter in (args.grow_step1, args.grow_step2):

            # save
            # Collect the top-k architectures/hidden states from the current
            # controller before it is discarded.
            cur_stage = grow_ctrler.cur_stage(search_iter)
            logger.info(f'=> grow to stage {cur_stage}')
            prev_archs, prev_hiddens = get_topk_arch_hidden(
                args, controller, gen_net, prev_archs, prev_hiddens)

            # grow section
            del controller
            del ctrl_optimizer
            controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                       weights_init)

            dataset = datasets.ImageDataset(args, 2**(cur_stage + 3),
                                            args.dis_batch_size,
                                            args.num_workers)
            train_loader = dataset.train

        dynamic_reset = train_shared(args,
                                     gen_net,
                                     dis_net,
                                     g_loss_history,
                                     d_loss_history,
                                     controller,
                                     gen_optimizer,
                                     dis_optimizer,
                                     train_loader,
                                     prev_hiddens=prev_hiddens,
                                     prev_archs=prev_archs)
        train_controller(args, controller, ctrl_optimizer, gen_net,
                         prev_hiddens, prev_archs, writer_dict)

        if dynamic_reset:
            logger.info('re-initialize share GAN')
            del gen_net, dis_net, gen_optimizer, dis_optimizer
            gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
                args, weights_init)

        # 'search_iter' stores the NEXT iteration so a resumed run does not
        # repeat the one that just finished.
        save_checkpoint(
            {
                'cur_stage': cur_stage,
                'search_iter': search_iter + 1,
                'gen_model': args.gen_model,
                'dis_model': args.dis_model,
                'controller': args.controller,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'ctrl_state_dict': controller.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'ctrl_optimizer': ctrl_optimizer.state_dict(),
                'prev_archs': prev_archs,
                'prev_hiddens': prev_hiddens,
                'path_helper': args.path_helper
            }, False, args.path_helper['ckpt_path'])

    final_archs, _ = get_topk_arch_hidden(args, controller, gen_net,
                                          prev_archs, prev_hiddens)
    logger.info(f"discovered archs: {final_archs}")
Ejemplo n.º 2
0
def main():
    """Search generator architectures layer by layer with a SAC agent.

    After parsing options, setting up the Inception graph and building the
    shared GAN (optionally restored from
    ``<args.load_path>/Model/checkpoint.pth``), the function runs search
    iterations up to the hard-coded bound of 100. Each iteration is one
    trajectory over ``total_layer_num`` (3) layers:

    * the agent selects an action from ``[layer, last_R, 0.01 * last_fid]``
      concatenated with the previous state;
    * the action is turned into an integer architecture description that is
      appended to the architecture so far and installed via
      ``gen_net.set_arch``;
    * the GAN is trained with ``train_qin`` and scored with ``get_is``, which
      returns ``(R, fid, state)`` (printed as "IS" and "FID");
    * the transition is pushed to the replay memory and, once the memory
      holds at least 64 entries, the agent is updated ``update_time`` times.

    At the end of every trajectory the agent gets one more update using the
    whole memory as batch, the shared GAN is re-created from scratch and the
    agent is saved under the name ``"test"``.

    Raises:
        AssertionError: if ``args.load_path`` is set but it or the checkpoint
            file does not exist, or if a fresh run has no ``args.exp_name``.
        NotImplementedError: from the weight initialiser when
            ``args.init_type`` is not one of ``normal``, ``orth`` or
            ``xavier_uniform``.
    """
    args = cfg.parse_args()

    torch.cuda.manual_seed(args.random_seed)
    print(args)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialise module ``m`` in place, selected by its class name.

        Classes whose name contains ``Conv2d`` are initialised according to
        ``args.init_type``; classes whose name contains ``BatchNorm2d`` get
        N(1.0, 0.02) weights and zero bias. Other modules are left untouched.
        """
        classname = m.__class__.__name__
        if 'Conv2d' in classname:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # Use the in-place variant: the un-suffixed
                # ``nn.init.xavier_uniform`` is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # initial
    start_search_iter = 0

    # resume from a checkpoint, or start a fresh run
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        cur_stage = checkpoint['cur_stage']

        start_search_iter = checkpoint['search_iter']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        prev_archs = checkpoint['prev_archs']
        prev_hiddens = checkpoint['prev_hiddens']

        # Keep writing into the directories of the resumed run.
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})'
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
        prev_archs = None
        prev_hiddens = None

        cur_stage = 0

    # set up data_loader
    # NOTE(review): `dataset`, `train_loader`, `prev_archs`, `prev_hiddens`
    # and `writer_dict` are not read again in this function; the calls are
    # kept because their constructors may have side effects.
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
    train_loader = dataset.train
    print(args.rl_num_eval_img, "##############################")
    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'controller_steps': start_search_iter * args.ctrl_step
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    # NOTE(review): 131 is presumably the agent's state dimension and
    # 2560000 the replay capacity -- confirm against SAC / ReplayMemory.
    Agent = SAC(131)
    print(Agent.alpha)

    memory = ReplayMemory(2560000)
    updates = 0
    outinfo = {
        'rewards': [],
        'a_loss': [],
        'critic_error': [],
    }
    Best = False
    Z_NUMPY = None
    WARMUP = True
    update_time = 1
    for search_iter in tqdm(range(int(start_search_iter), 100),
                            desc='search progress'):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter >= 1:
            WARMUP = False

        ### Define number of layers, currently only support 1->3
        total_layer_num = 3
        ### Different image size for different layers
        # One dataset per layer, built with 8, 16 and 32 as second argument.
        ds = [
            datasets.ImageDataset(args, 2**(k + 3))
            for k in range(total_layer_num)
        ]
        train_loaders = [d.train for d in ds]
        last_R = 0.  # Initial reward
        last_fid = 10000  # Initial reward
        last_arch = []

        # Set exploration
        # From iteration 70 on, `Best` is passed as True to select_action
        # and printed as "Exploitation"; the agent is also updated 10 times
        # per step instead of once.
        if search_iter > 69:
            update_time = 10
            Best = True
        else:
            Best = False

        # Score the network before any layer has been chosen; this
        # overwrites the placeholder last_R / last_fid set above.
        gen_net.set_stage(-1)
        last_R, last_fid, last_state = get_is(args,
                                              gen_net,
                                              args.rl_num_eval_img,
                                              get_is_score=True)
        for layer in range(total_layer_num):

            cur_stage = layer  # This defines which layer to use as output, for example, if cur_stage==0, then the output will be the first layer output. Set it to 2 if you want the output of the last layer.
            action = Agent.select_action([layer, last_R, 0.01 * last_fid] +
                                         last_state, Best)
            # Flattened copy of the raw action values; this (not the argmax
            # form below) is what gets stored in the replay memory.
            arch = [
                action[0][0], action[0][1], action[1][0], action[1][1],
                action[1][2], action[2][0], action[2][1], action[2][2],
                action[3][0], action[3][1], action[4][0], action[4][1],
                action[5][0], action[5][1]
            ]
            # argmax to get int description of arch
            cur_arch = [np.argmax(k) for k in action]
            # Pad the skip option 0=False (for only layer 1 and layer2, not layer0, see builing_blocks.py for why)
            if layer == 0:
                cur_arch = cur_arch[0:4]
            elif layer == 1:
                cur_arch = cur_arch[0:5]
            elif layer == 2:
                # Fold the two choices cur_arch[4], cur_arch[5] into a single
                # code: (0, 0) -> 0, (1, 0) -> 1, (1, 1) -> 3, anything
                # else -> 2.
                if cur_arch[4] + cur_arch[5] == 2:
                    cur_arch = cur_arch[0:4] + [3]
                elif cur_arch[4] + cur_arch[5] == 0:
                    cur_arch = cur_arch[0:4] + [0]
                elif cur_arch[4] == 1 and cur_arch[5] == 0:
                    cur_arch = cur_arch[0:4] + [1]
                else:
                    cur_arch = cur_arch[0:4] + [2]

            # Get the network arch with the new architecture attached.
            last_arch += cur_arch
            gen_net.set_arch(last_arch,
                             layer)  # Set the network, given cur_stage
            # Train network
            # NOTE(review): the returned dynamic_reset flag is not used; the
            # GAN is rebuilt unconditionally at the end of every trajectory.
            dynamic_reset = train_qin(args,
                                      gen_net,
                                      dis_net,
                                      g_loss_history,
                                      d_loss_history,
                                      gen_optimizer,
                                      dis_optimizer,
                                      train_loaders[layer],
                                      cur_stage,
                                      smooth=False,
                                      WARMUP=WARMUP)

            # Get reward, use the jth layer output for generation. (layer 0:j), and the proposed progressive state
            R, fid, state = get_is(args,
                                   gen_net,
                                   args.rl_num_eval_img,
                                   z_numpy=Z_NUMPY)
            # Print exploitation mark, for better readability of the log.
            if Best:
                print("arch:", cur_arch, "Exploitation:", Best)
            else:
                print("arch:", cur_arch, "Exploring...")
            # Proxy reward of the up-to-now (0:j) architecture.
            print("update times:", updates, "step:", layer + 1, "IS:", R,
                  "FID:", fid)
            # mask is 0 only for the last layer of the trajectory.
            mask = 0 if layer == total_layer_num - 1 else 1
            # NOTE(review): this guard holds for every non-negative
            # search_iter, so transitions are always stored.
            if search_iter >= 0:  # warm up
                # Reward = gain in R plus 0.01 * reduction in fid.
                memory.push([layer, last_R, 0.01 * last_fid] + last_state,
                            arch, R - last_R + 0.01 * (last_fid - fid),
                            [layer + 1, R, 0.01 * fid] + state,
                            mask)  # Append transition to memory

            if len(memory) >= 64:
                # Number of updates per step in environment
                for i in range(update_time):
                    # Update parameters of all the networks
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = Agent.update_parameters(
                        memory, min(len(memory), 256), updates)

                    updates += 1
                    outinfo['critic_error'] = min(critic_1_loss, critic_2_loss)
                    outinfo['entropy'] = ent_loss
                    outinfo['a_loss'] = policy_loss
                print("full batch", outinfo, alpha)
            last_R = R  # next step
            last_fid = fid
            last_state = state
        outinfo['rewards'] = R
        # One more update at the end of the trajectory, using the whole
        # memory as the batch.
        critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = Agent.update_parameters(
            memory, len(memory), updates)
        updates += 1
        outinfo['critic_error'] = min(critic_1_loss, critic_2_loss)
        outinfo['entropy'] = ent_loss
        outinfo['a_loss'] = policy_loss
        print("full batch", outinfo, alpha)
        # Clean up and start a new trajectory from scratch
        del gen_net, dis_net, gen_optimizer, dis_optimizer
        gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
            args, weights_init)
        print(outinfo, len(memory))
        Agent.save_model("test")
        WARMUP = False
Ejemplo n.º 3
0
def main():
    """Run the controller-based architecture search for a CycleGAN.

    Options come from ``SearchOptions().parse()``. The grow schedule is
    derived from ``opt.grow_step``, ``opt.max_skip_num`` and ``opt.n_resnet``
    and also fixes ``opt.max_search_iter``. The run either resumes from
    ``<opt.load_path>/Model/checkpoint.pth`` or starts fresh with a log
    directory created from ``opt.checkpoints_dir`` and ``opt.name``. Every
    search iteration trains the shared CycleGAN (``cyclgan_train``), trains
    the controller (``controller_train``), re-creates the CycleGAN when a
    dynamic reset is reported, and saves a checkpoint. The architectures
    found for both sides (A and B) are printed at the end.

    Raises:
        AssertionError: if ``opt.load_path`` is set but it or the checkpoint
            file does not exist.
    """
    opt = SearchOptions().parse()
    torch.cuda.manual_seed(12345)

    _init_inception(MODEL_DIR)
    create_inception_graph(check_or_download_inception(None))

    # Iterations spent in each stage: grow_step**1 .. grow_step**(k-1) for
    # the first stages, then grow_step**3 for each remaining one.
    early_deltas = [
        int(opt.grow_step ** power) for power in range(1, opt.max_skip_num)
    ]
    late_deltas = [
        int(opt.grow_step ** 3)
        for _ in range(opt.n_resnet - opt.max_skip_num)
    ]
    delta_grow_steps = early_deltas + late_deltas
    opt.max_search_iter = sum(delta_grow_steps)

    # Iterations at which the controller grows: running totals of the
    # deltas, without the leading 0 and without the grand total.
    grow_steps = [
        sum(delta_grow_steps[:count])
        for count in range(1, len(delta_grow_steps))
    ]
    grow_ctrler = GrowCtrler(opt.grow_step, steps=grow_steps)

    # Defaults for a fresh run; overwritten below when resuming.
    start_search_iter = 0
    cur_stage = 1

    resuming = bool(opt.load_path)
    if resuming:
        print(f"=> resuming from {opt.load_path}")
        assert os.path.exists(opt.load_path)
        checkpoint_file = os.path.join(opt.load_path, "Model",
                                       "checkpoint.pth")
        assert os.path.exists(checkpoint_file)
        # Tensors stored on 'cuda:0' are remapped to the CPU while loading.
        checkpoint = torch.load(checkpoint_file,
                                map_location={"cuda:0": "cpu"})
        cur_stage = checkpoint["cur_stage"]
        start_search_iter = checkpoint["search_iter"]
        opt.path_helper = checkpoint["path_helper"]
    else:
        opt.path_helper = set_log_dir(opt.checkpoints_dir, opt.name)

    # Both paths build the models the same way; opt.path_helper is already
    # in place at this point, exactly as in the per-branch construction.
    cycle_gan = CycleGANModel(opt)
    cycle_gan.setup(opt)

    cycle_controller = CycleControllerModel(opt, cur_stage=cur_stage)
    cycle_controller.setup(opt)
    cycle_controller.set(cycle_gan)

    if resuming:
        # Saved state is restored only after both models are fully set up.
        cycle_gan.load_from_state(checkpoint["cycle_gan"])
        cycle_controller.load_from_state(checkpoint["cycle_controller"])

    # create a dataset given opt.dataset_mode and other options
    dataset = create_dataset(opt)
    print("The number of training images = %d" % len(dataset))

    writer_dict = {
        "writer": SummaryWriter(opt.path_helper["log_path"]),
        "controller_steps": start_search_iter * opt.ctrl_step,
        "train_steps": start_search_iter * opt.shared_epoch,
    }

    g_loss_history = RunningStats(opt.dynamic_reset_window)
    d_loss_history = RunningStats(opt.dynamic_reset_window)

    dynamic_reset = None
    search_range = range(int(start_search_iter), int(opt.max_search_iter))
    for search_iter in tqdm(search_range):
        tqdm.write(f"<start search iteration {search_iter}>")
        cycle_controller.reset()

        if search_iter in grow_steps:
            cur_stage = grow_ctrler.cur_stage(search_iter) + 1
            tqdm.write(f"=> grow to stage {cur_stage}")
            # Take the top-k results from the old controller before it is
            # replaced by one built for the new stage.
            top_a = cycle_controller.get_topk_arch_hidden_A()
            prev_archs_A, prev_hiddens_A = top_a
            top_b = cycle_controller.get_topk_arch_hidden_B()
            prev_archs_B, prev_hiddens_B = top_b

            del cycle_controller

            cycle_controller = CycleControllerModel(opt, cur_stage)
            cycle_controller.setup(opt)
            cycle_controller.set(cycle_gan, prev_hiddens_A, prev_hiddens_B,
                                 prev_archs_A, prev_archs_B)

        dynamic_reset = cyclgan_train(opt, cycle_gan, cycle_controller,
                                      dataset, g_loss_history, d_loss_history,
                                      writer_dict)

        controller_train(opt, cycle_gan, cycle_controller, writer_dict)

        if dynamic_reset:
            tqdm.write("re-initialize share GAN")
            del cycle_gan
            cycle_gan = CycleGANModel(opt)
            cycle_gan.setup(opt)

        # "search_iter" stores the NEXT iteration so a resumed run does not
        # repeat the one that just finished.
        state = {
            "cur_stage": cur_stage,
            "search_iter": search_iter + 1,
            "cycle_gan": cycle_gan.save_networks(epoch=search_iter),
            "cycle_controller":
            cycle_controller.save_networks(epoch=search_iter),
            "path_helper": opt.path_helper,
        }
        save_checkpoint(state, False, opt.path_helper["ckpt_path"])

    final_archs_A, _ = cycle_controller.get_topk_arch_hidden_A()
    final_archs_B, _ = cycle_controller.get_topk_arch_hidden_B()
    print(f"discovered archs: {final_archs_A}")
    print(f"discovered archs: {final_archs_B}")
Ejemplo n.º 4
0
def cyclgan_train(opt, cycle_gan: CycleGANModel,
                  cycle_controller: CycleControllerModel, train_loader,
                  g_loss_history: RunningStats, d_loss_history: RunningStats,
                  writer_dict):
    """Train the shared CycleGAN for up to ``opt.shared_epoch`` epochs.

    The CycleGAN is put in train mode and the controller in eval mode. For
    every batch the controller's ``forward()`` is called before the CycleGAN
    consumes the batch and takes one optimisation step. Generator and
    discriminator losses are pushed into the running histories; once the
    generator history is full, a variance below
    ``opt.dynamic_reset_threshold`` in either history triggers a dynamic
    reset.

    Args:
        opt: options object; this function reads ``shared_epoch``,
            ``print_freq``, ``batch_size``, ``display_freq``,
            ``save_latest_freq``, ``save_epoch_freq`` and
            ``dynamic_reset_threshold``.
        cycle_gan: the shared CycleGAN being trained.
        cycle_controller: controller whose ``forward()`` runs once per batch.
        train_loader: iterable of batches; ``len()`` of it is used in the
            progress message.
        g_loss_history: running statistics of ``loss_G``.
        d_loss_history: running statistics of ``loss_D_A + loss_D_B``.
        writer_dict: dict holding the summary ``'writer'`` and the
            ``'train_steps'`` counter, which is incremented once per epoch.

    Returns:
        ``True`` if a dynamic reset was triggered (both histories are
        cleared and training stops immediately), otherwise ``False``.
    """
    cycle_gan.train()
    cycle_controller.eval()

    dynamic_reset = False
    writer = writer_dict['writer']
    total_iters = 0
    t_data = 0.0

    for epoch in range(opt.shared_epoch):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        train_steps = writer_dict['train_steps']
        for i, data in enumerate(train_loader):
            iter_start_time = time.time()

            # Time between the end of the previous step and now, refreshed
            # whenever the sample counter hits a multiple of print_freq.
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            # Both counters advance by the batch size, not by 1.
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            cycle_controller.forward()

            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()

            g_loss_history.push(cycle_gan.loss_G.item())
            d_loss_history.push(cycle_gan.loss_D_A.item() +
                                cycle_gan.loss_D_B.item())

            if (i + 1) % opt.print_freq == 0:
                losses = cycle_gan.get_current_losses()
                t_comp = (time.time() - iter_start_time)
                message = "GAN: [Ep: %d/%d]" % (epoch, opt.shared_epoch)
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)

            if (total_iters + 1) % opt.display_freq == 0:
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(),
                                     train_steps)

            if g_loss_history.is_full():
                if g_loss_history.get_var() < opt.dynamic_reset_threshold \
                        or d_loss_history.get_var() < opt.dynamic_reset_threshold:
                    # Stop right away and let the caller re-create the GAN.
                    dynamic_reset = True
                    tqdm.write("=> dynamic resetting triggered")
                    g_loss_history.clear()
                    d_loss_history.clear()
                    return dynamic_reset

            if (
                    total_iters + 1
            ) % opt.save_latest_freq == 0:  # cache our latest model every <save_latest_freq> iterations
                # NOTE(review): only the message is emitted here; the save
                # call that used to follow it is disabled.
                tqdm.write(
                    'saving the latest model (epoch %d, total_iters %d)' %
                    (epoch, total_iters))

            iter_data_time = time.time()

        if (epoch + 1) % opt.save_epoch_freq == 0:
            cycle_gan.save_networks('latest')

        # The loop runs opt.shared_epoch epochs, so that is the total to
        # report (the per-batch progress message above uses it as well).
        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, opt.shared_epoch,
                    time.time() - epoch_start_time))

        # Per-epoch scalars: the losses left on the model by the last batch.
        writer.add_scalars('Train/discriminator', {
            "A": float(cycle_gan.loss_D_A),
            "B": float(cycle_gan.loss_D_B),
        }, train_steps)
        writer.add_scalars('Train/generator', {
            "A": float(cycle_gan.loss_G_A),
            "B": float(cycle_gan.loss_G_B),
        }, train_steps)
        writer.add_scalars(
            'Train/cycle', {
                "A": float(cycle_gan.loss_cycle_A),
                "B": float(cycle_gan.loss_cycle_B),
            }, train_steps)
        writer.add_scalars('Train/idt', {
            "A": float(cycle_gan.loss_idt_A),
            "B": float(cycle_gan.loss_idt_B),
        }, train_steps)

        writer_dict['train_steps'] += 1

    return dynamic_reset
Ejemplo n.º 5
0
def main():
    """Run the controller-based architecture search for the shared GAN.

    Steps, in order:

    1. Parse options with ``cfg.parse_args()`` and seed the CUDA RNG with
       ``args.random_seed``.
    2. Set up the Inception graph (``_init_inception`` /
       ``check_or_download_inception`` / ``create_inception_graph``).
    3. Build the shared generator/discriminator and their optimizers with
       ``create_shared_gan``.
    4. Either resume everything from
       ``<args.load_path>/Model/checkpoint.pth`` or start a fresh run whose
       log directory is created from ``args.exp_name``.
    5. For every search iteration up to ``args.max_search_iter``: grow the
       controller at ``args.grow_step1`` / ``args.grow_step2``, call
       ``train_shared`` and ``train_controller``, re-create the shared GAN
       when ``train_shared`` reports a dynamic reset, and save a checkpoint.
    6. Log the final architectures returned by ``get_topk_arch_hidden``.

    Raises:
        AssertionError: if ``args.load_path`` is set but it or the checkpoint
            file does not exist, or if a fresh run has no ``args.exp_name``.
        NotImplementedError: from the weight initialiser when
            ``args.init_type`` is not one of ``normal``, ``orth`` or
            ``xavier_uniform``.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialise module ``m`` in place, selected by its class name.

        Classes whose name contains ``Conv2d`` are initialised according to
        ``args.init_type``; classes whose name contains ``BatchNorm2d`` get
        N(1.0, 0.02) weights and zero bias. Other modules are left untouched.
        """
        classname = m.__class__.__name__
        if "Conv2d" in classname:
            if args.init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == "orth":
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == "xavier_uniform":
                # Use the in-place variant: the un-suffixed
                # ``nn.init.xavier_uniform`` is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.0)
            else:
                raise NotImplementedError("{} unknown inital type".format(
                    args.init_type))
        elif "BatchNorm2d" in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # set grow controller
    grow_ctrler = GrowCtrler(args.grow_step1, args.grow_step2)

    # initial
    start_search_iter = 0

    # resume from a checkpoint, or start a fresh run
    if args.load_path:
        print(f"=> resuming from {args.load_path}")
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, "Model",
                                       "checkpoint.pth")
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        # set controller && its optimizer
        # The controller is built for the saved stage before its weights
        # are restored below.
        cur_stage = checkpoint["cur_stage"]
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

        start_search_iter = checkpoint["search_iter"]
        gen_net.load_state_dict(checkpoint["gen_state_dict"])
        dis_net.load_state_dict(checkpoint["dis_state_dict"])
        controller.load_state_dict(checkpoint["ctrl_state_dict"])
        gen_optimizer.load_state_dict(checkpoint["gen_optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        ctrl_optimizer.load_state_dict(checkpoint["ctrl_optimizer"])
        prev_archs = checkpoint["prev_archs"]
        prev_hiddens = checkpoint["prev_hiddens"]

        # Keep writing into the directories of the resumed run.
        args.path_helper = checkpoint["path_helper"]
        logger = create_logger(args.path_helper["log_path"])
        logger.info(
            f"=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})"
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir("logs", args.exp_name)
        logger = create_logger(args.path_helper["log_path"])
        prev_archs = None
        prev_hiddens = None

        # set controller && its optimizer
        cur_stage = 0
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

    # set up data_loader
    # The second argument is 8, 16, 32, ... for stage 0, 1, 2, ...
    # (presumably the image side length -- confirm in datasets.ImageDataset).
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
    train_loader = dataset.train

    logger.info(args)
    # set writer
    writer_dict = {
        "writer": SummaryWriter(args.path_helper["log_path"]),
        "controller_steps": start_search_iter * args.ctrl_step,
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    for search_iter in tqdm(range(int(start_search_iter),
                                  int(args.max_search_iter)),
                            desc="search progress"):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter in (args.grow_step1, args.grow_step2):

            # save
            # Collect the top-k architectures/hidden states from the current
            # controller before it is discarded.
            cur_stage = grow_ctrler.cur_stage(search_iter)
            logger.info(f"=> grow to stage {cur_stage}")
            prev_archs, prev_hiddens = get_topk_arch_hidden(
                args, controller, gen_net, prev_archs, prev_hiddens)

            # grow section
            del controller
            del ctrl_optimizer
            controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                       weights_init)

            dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
            train_loader = dataset.train

        dynamic_reset = train_shared(
            args,
            gen_net,
            dis_net,
            g_loss_history,
            d_loss_history,
            controller,
            gen_optimizer,
            dis_optimizer,
            train_loader,
            prev_hiddens=prev_hiddens,
            prev_archs=prev_archs,
        )
        train_controller(
            args,
            controller,
            ctrl_optimizer,
            gen_net,
            prev_hiddens,
            prev_archs,
            writer_dict,
        )

        if dynamic_reset:
            logger.info("re-initialize share GAN")
            del gen_net, dis_net, gen_optimizer, dis_optimizer
            gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
                args, weights_init)

        # "search_iter" stores the NEXT iteration so a resumed run does not
        # repeat the one that just finished.
        save_checkpoint(
            {
                "cur_stage": cur_stage,
                "search_iter": search_iter + 1,
                "gen_model": args.gen_model,
                "dis_model": args.dis_model,
                "controller": args.controller,
                "gen_state_dict": gen_net.state_dict(),
                "dis_state_dict": dis_net.state_dict(),
                "ctrl_state_dict": controller.state_dict(),
                "gen_optimizer": gen_optimizer.state_dict(),
                "dis_optimizer": dis_optimizer.state_dict(),
                "ctrl_optimizer": ctrl_optimizer.state_dict(),
                "prev_archs": prev_archs,
                "prev_hiddens": prev_hiddens,
                "path_helper": args.path_helper,
            },
            False,
            args.path_helper["ckpt_path"],
        )

    final_archs, _ = get_topk_arch_hidden(args, controller, gen_net,
                                          prev_archs, prev_hiddens)
    logger.info(f"discovered archs: {final_archs}")