Ejemplo n.º 1
0
def main():
    """Restore a generator checkpoint and print its Inception and FID scores.

    The generator is built from the parsed arguments, the Inception graph is
    prepared, and the averaged generator weights stored in ``args.load_path``
    are loaded. The remaining-rate summaries and the checkpoint's best FID
    are printed before ``evaluate`` is run against the CIFAR-10 statistics
    file.
    """
    seed = 12345
    args = parse_args()
    gen_net = Generator(args).cuda()

    # Prepare the Inception graph before any scoring takes place.
    _init_inception()
    create_inception_graph(check_or_download_inception(None))

    ckpt_path = args.load_path
    assert os.path.exists(ckpt_path), \
        "checkpoint file {} is not found".format(ckpt_path)
    ckpt = torch.load(ckpt_path)

    # Fixed seeds for the torch (CPU and CUDA) and NumPy generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    avg_weights = ckpt['avg_gen_state_dict']
    # pruning_generate is called first to create the mask buffers, so the
    # state dict can be loaded right afterwards.
    pruning_generate(gen_net, avg_weights)
    gen_net.load_state_dict(avg_weights)

    see_remain_rate_mask(gen_net)
    see_remain_rate(gen_net)
    print("Best FID:{}".format(ckpt['best_fid']))

    # Latent batch of 25 vectors drawn from a standard normal distribution.
    noise = np.random.normal(0, 1, (25, args.latent_dim))
    fixed_z = torch.cuda.FloatTensor(noise)

    stat_path = 'fid_stat/fid_stats_cifar10_train.npz'
    inception_score, fid_score = evaluate(args, fixed_z, stat_path, gen_net)

    print('Inception score: %.4f, FID score: %.4f' %
          (inception_score, fid_score))
Ejemplo n.º 2
0
def main():
    """Evaluate a generator checkpoint and log its Inception and FID scores.

    Steps: parse arguments and seed the RNGs, set up the log directory and
    the Inception graph, build the generator named by ``args.model``,
    optionally prune it when ``args.percent < 0.9``, load the weights from
    ``args.load_path`` and run ``validate``.

    Raises:
        AssertionError: if the experiment name is empty, the checkpoint path
            does not end in ``.pth`` or does not exist, or the FID statistics
            file is missing.
        NotImplementedError: if ``args.dataset`` is not ``cifar10``.
    """
    args = cfg.parse_args()
    torch.manual_seed(args.random_seed)
    random.seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    assert args.exp_name
    assert args.load_path.endswith('.pth')
    assert os.path.exists(args.load_path)
    args.path_helper = set_log_dir('logs_eval', args.exp_name)
    logger = create_logger(args.path_helper['log_path'], phase='test')

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE(review): eval() on a command-line value; acceptable only while
    # args.model comes from a trusted operator.
    gen_net = eval('models.' + args.model + '.Generator')(args=args).cuda()

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # initial
    np.random.seed(args.random_seed)
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    if args.percent < 0.9:
        pruning_generate(gen_net, (1 - args.percent))
        see_remain_rate(gen_net)

    # load checkpoint
    logger.info(f'=> resuming from {args.load_path}')
    checkpoint_file = args.load_path
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)

    if 'avg_gen_state_dict' in checkpoint:
        gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        epoch = checkpoint['epoch']
        logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {epoch})')
    else:
        gen_net.load_state_dict(checkpoint)
        # A bare state dict carries no epoch. Previously `epoch` stayed
        # unbound here and the validate() call below raised
        # UnboundLocalError.
        # NOTE(review): 0 is a placeholder; confirm validate() accepts it.
        epoch = 0
        logger.info(f'=> loaded checkpoint {checkpoint_file}')

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'valid_global_steps': 0,
    }
    try:
        inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                              gen_net, writer_dict, epoch)
        logger.info(
            f'Inception score: {inception_score}, FID score: {fid_score}.')
    finally:
        # Close the writer even when validation fails.
        writer_dict['writer'].close()
Ejemplo n.º 3
0
def main():
    """Evaluate a generator checkpoint and log its Inception and FID scores.

    Parses the arguments, prepares the evaluation log directory and the
    Inception graph, builds the generator named by ``args.gen_model``, loads
    the weights from ``args.load_path`` and runs ``validate``.

    Raises:
        AssertionError: if the experiment name is empty, the checkpoint path
            does not end in ``.pth`` or does not exist, or the FID statistics
            file is missing.
        NotImplementedError: if ``args.dataset`` is neither ``cifar10`` nor
            ``stl10``.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    assert args.exp_name
    assert args.load_path.endswith(".pth")
    assert os.path.exists(args.load_path)
    args.path_helper = set_log_dir("logs_eval", args.exp_name)
    logger = create_logger(args.path_helper["log_path"], phase="test")

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE(review): eval() on a command-line value; acceptable only while
    # args.gen_model comes from a trusted operator.
    gen_net = eval("models." + args.gen_model + ".Generator")(args=args).cuda()

    # fid stat
    if args.dataset.lower() == "cifar10":
        fid_stat = "fid_stat/fid_stats_cifar10_train.npz"
    elif args.dataset.lower() == "stl10":
        fid_stat = "fid_stat/stl10_train_unlabeled_fid_stats_48.npz"
    else:
        raise NotImplementedError(f"no fid stat for {args.dataset.lower()}")
    assert os.path.exists(fid_stat)

    # initial: 25 latent vectors drawn from a standard normal distribution
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    # load checkpoint
    logger.info(f"=> resuming from {args.load_path}")
    checkpoint_file = args.load_path
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)

    # A checkpoint is either a full training checkpoint (which holds the
    # averaged generator weights and an epoch) or a bare state dict.
    if "avg_gen_state_dict" in checkpoint:
        gen_net.load_state_dict(checkpoint["avg_gen_state_dict"])
        epoch = checkpoint["epoch"]
        logger.info(f"=> loaded checkpoint {checkpoint_file} (epoch {epoch})")
    else:
        gen_net.load_state_dict(checkpoint)
        logger.info(f"=> loaded checkpoint {checkpoint_file}")

    logger.info(args)
    writer_dict = {
        "writer": SummaryWriter(args.path_helper["log_path"]),
        "valid_global_steps": 0,
    }
    try:
        inception_score, fid_score = validate(args,
                                              fixed_z,
                                              fid_stat,
                                              gen_net,
                                              writer_dict,
                                              clean_dir=False)
        logger.info(
            f"Inception score: {inception_score}, FID score: {fid_score}.")
    finally:
        # The writer was previously never closed; release it even when
        # validation fails.
        writer_dict["writer"].close()
Ejemplo n.º 4
0
def calculate_IS_FID(G):
    """Compute and print the Inception score and FID for generator ``G``.

    All settings (seed, buffer directory, image count, batch size, latent
    size and the ``do_IS`` / ``do_FID`` switches) are read from the ``args``
    name in the enclosing module scope. Nothing is returned; the two scores
    are printed.

    Raises:
        AssertionError: if the CIFAR-10 FID statistics file is missing.
    """
    torch.cuda.manual_seed(args.random_seed)

    # Inception graph used by the metric computation.
    _init_inception()
    create_inception_graph(check_or_download_inception(None))

    # Precomputed statistics the FID is measured against.
    stat_path = 'fid_stat/fid_stats_cifar10_train.npz'
    assert os.path.exists(stat_path)

    scores = calculate_metrics(
        args.fid_buffer_dir,
        args.num_eval_imgs,
        args.eval_batch_size,
        args.latent_dim,
        stat_path,
        G,
        do_IS=args.do_IS,
        do_FID=args.do_FID,
    )
    inception_score, fid_score = scores
    print('Inception score: %s, FID score: %s.' % (inception_score, fid_score))
Ejemplo n.º 5
0
def main():
    """Evaluate a generator checkpoint and log its Inception and FID scores.

    Parses the arguments, prepares the evaluation log directory and the
    Inception graph, builds the generator named by ``args.model``, loads the
    weights from ``args.load_path`` and runs ``validate``.

    Raises:
        AssertionError: if the experiment name is empty, or the checkpoint
            path does not end in ``.pth`` or does not exist.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    assert args.exp_name
    assert args.load_path.endswith('.pth')
    assert os.path.exists(args.load_path)
    args.path_helper = set_log_dir('logs_eval', args.exp_name)
    logger = create_logger(args.path_helper['log_path'], phase='test')

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE(review): eval() on a command-line value; acceptable only while
    # args.model comes from a trusted operator.
    gen_net = eval('models.' + args.model + '.Generator')(args=args).cuda()

    # initial: 25 latent vectors drawn from a standard normal distribution
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    # load checkpoint
    logger.info(f'=> resuming from {args.load_path}')
    checkpoint_file = args.load_path
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)

    # A checkpoint is either a full training checkpoint (which holds the
    # averaged generator weights and an epoch) or a bare state dict.
    if 'avg_gen_state_dict' in checkpoint:
        gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        epoch = checkpoint['epoch']
        logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {epoch})')
    else:
        gen_net.load_state_dict(checkpoint)
        logger.info(f'=> loaded checkpoint {checkpoint_file}')

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'valid_global_steps': 0,
    }
    try:
        inception_score, fid_score = validate(args, fixed_z, gen_net,
                                              writer_dict)
        logger.info(
            f'Inception score: {inception_score}, FID score: {fid_score}.')
    finally:
        # The writer was previously never closed; release it even when
        # validation fails.
        writer_dict['writer'].close()
Ejemplo n.º 6
0
def main():
    """Evaluate the best checkpoint of a searched GAN architecture.

    Builds the generator (from the saved genotype) and the discriminator,
    wraps both in ``DataParallel``, restores
    ``exps/<args.checkpoint>/Model/checkpoint_best.pth``, logs parameter
    sizes and FLOPs, optionally draws the generator graph, and finally
    validates the generator with its averaged weights loaded.

    Raises:
        AssertionError: if the FID statistics file, the checkpoint directory
            or the checkpoint file is missing, or the experiment name is
            empty.
        NotImplementedError: if ``args.dataset`` is neither ``cifar10`` nor
            ``stl10``.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set visible GPU ids
    if len(args.gpu_ids) > 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids

    # set TensorFlow environment for evaluation (calculate IS and FID)
    _init_inception()
    inception_path = check_or_download_inception('./tmp/imagenet/')
    create_inception_graph(inception_path)

    # the first GPU in visible GPUs is dedicated for evaluation (running Inception model)
    # Device ids are renumbered 0..n-1, one per comma-separated entry of the
    # original string; with more than one entry, id 0 is dropped.
    device_ids = list(range(len(args.gpu_ids.split(','))))
    if len(device_ids) > 1:
        device_ids = device_ids[1:]
    args.gpu_ids = device_ids

    # genotype G
    genotypes_root = os.path.join('exps', args.genotypes_exp, 'Genotypes')
    genotype_G = np.load(os.path.join(genotypes_root, 'latest_G.npy'))

    # import network from genotype
    # NOTE(review): eval() on a command-line value; acceptable only while
    # args.arch comes from a trusted operator.
    basemodel_gen = eval('archs.' + args.arch + '.Generator')(args, genotype_G)
    gen_net = torch.nn.DataParallel(
        basemodel_gen, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])
    basemodel_dis = eval('archs.' + args.arch + '.Discriminator')(args)
    dis_net = torch.nn.DataParallel(
        basemodel_dis, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # load checkpoint
    print(f'=> resuming from {args.checkpoint}')
    assert os.path.exists(os.path.join('exps', args.checkpoint))
    checkpoint_file = os.path.join('exps', args.checkpoint, 'Model',
                                   'checkpoint_best.pth')
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)
    epoch = checkpoint['epoch'] - 1
    gen_net.load_state_dict(checkpoint['gen_state_dict'])
    dis_net.load_state_dict(checkpoint['dis_state_dict'])
    # Extract the averaged generator parameters through a temporary copy so
    # that gen_net itself keeps the raw weights for now.
    avg_gen_net = deepcopy(gen_net)
    avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
    gen_avg_param = copy_params(avg_gen_net)
    del avg_gen_net
    assert args.exp_name
    args.path_helper = set_log_dir('exps', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])
    logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {epoch})')

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'valid_global_steps': epoch // args.val_freq,
    }

    try:
        # model size
        logger.info('Param size of G = %fMB', count_parameters_in_MB(gen_net))
        logger.info('Param size of D = %fMB', count_parameters_in_MB(dis_net))
        print_FLOPs(basemodel_gen, (1, args.latent_dim), logger)
        print_FLOPs(basemodel_dis, (1, 3, args.img_size, args.img_size),
                    logger)

        # for visualization
        if args.draw_arch:
            from utils.genotype import draw_graph_G
            draw_graph_G(genotype_G,
                         save=True,
                         file_path=os.path.join(
                             args.path_helper['graph_vis_path'], 'latest_G'))
        fixed_z = torch.cuda.FloatTensor(
            np.random.normal(0, 1, (100, args.latent_dim)))

        # test with the averaged generator weights
        load_params(gen_net, gen_avg_param)
        inception_score, std, fid_score = validate(args, fixed_z, fid_stat,
                                                   gen_net, writer_dict)
        logger.info(
            f'Inception score mean: {inception_score}, Inception score std: {std}, '
            f'FID score: {fid_score} || @ epoch {epoch}.')
    finally:
        # The writer was previously never closed; release it even when an
        # evaluation step fails.
        writer_dict['writer'].close()
Ejemplo n.º 7
0
def main():
    """Run the architecture search loop for the shared GAN and controller.

    When ``args.load_path`` is set, all networks, optimizers and search state
    are restored from ``<load_path>/Model/checkpoint.pth``; otherwise a new
    log directory is created under ``logs`` and the search starts at stage 0.
    Each iteration trains the shared GAN, trains the controller, optionally
    re-initializes the shared GAN and saves a checkpoint. At
    ``args.grow_step1`` and ``args.grow_step2`` the top architectures are
    collected and a fresh controller and data loader are created for the next
    stage. The discovered architectures are logged at the end.

    Raises:
        AssertionError: if the resume path or checkpoint file is missing, or
            a fresh run has no experiment name.
        NotImplementedError: from weight initialization when
            ``args.init_type`` is not recognized.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception(MODEL_DIR)
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialize Conv2d / BatchNorm2d modules according to args.init_type."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; the un-suffixed name is deprecated.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # set grow controller
    grow_ctrler = GrowCtrler(args.grow_step1, args.grow_step2)

    # initial
    start_search_iter = 0

    # resume from checkpoint or start a new run
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file,
                                map_location={'cuda:0': 'cpu'})
        # set controller && its optimizer
        cur_stage = checkpoint['cur_stage']
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

        start_search_iter = checkpoint['search_iter']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        controller.load_state_dict(checkpoint['ctrl_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        ctrl_optimizer.load_state_dict(checkpoint['ctrl_optimizer'])
        prev_archs = checkpoint['prev_archs']
        prev_hiddens = checkpoint['prev_hiddens']

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})'
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
        prev_archs = None
        prev_hiddens = None

        # set controller && its optimizer
        cur_stage = 0
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

    # set up data_loader
    # NOTE(review): 2**(cur_stage + 3) is presumably the image size for the
    # current stage - confirm against datasets.ImageDataset.
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3),
                                    args.dis_batch_size, args.num_workers)
    train_loader = dataset.train

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'controller_steps': start_search_iter * args.ctrl_step
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop

    for search_iter in tqdm(range(int(start_search_iter),
                                  int(args.max_search_iter)),
                            desc='search progress'):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter == args.grow_step1 or search_iter == args.grow_step2:

            # save
            cur_stage = grow_ctrler.cur_stage(search_iter)
            logger.info(f'=> grow to stage {cur_stage}')
            prev_archs, prev_hiddens = get_topk_arch_hidden(
                args, controller, gen_net, prev_archs, prev_hiddens)

            # grow section: replace the controller and reload the data for
            # the new stage
            del controller
            del ctrl_optimizer
            controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                       weights_init)

            dataset = datasets.ImageDataset(args, 2**(cur_stage + 3),
                                            args.dis_batch_size,
                                            args.num_workers)
            train_loader = dataset.train

        dynamic_reset = train_shared(args,
                                     gen_net,
                                     dis_net,
                                     g_loss_history,
                                     d_loss_history,
                                     controller,
                                     gen_optimizer,
                                     dis_optimizer,
                                     train_loader,
                                     prev_hiddens=prev_hiddens,
                                     prev_archs=prev_archs)
        train_controller(args, controller, ctrl_optimizer, gen_net,
                         prev_hiddens, prev_archs, writer_dict)

        if dynamic_reset:
            logger.info('re-initialize share GAN')
            del gen_net, dis_net, gen_optimizer, dis_optimizer
            gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
                args, weights_init)

        # 'search_iter' stores the next iteration to run, so a resumed run
        # continues after the one that just finished.
        save_checkpoint(
            {
                'cur_stage': cur_stage,
                'search_iter': search_iter + 1,
                'gen_model': args.gen_model,
                'dis_model': args.dis_model,
                'controller': args.controller,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'ctrl_state_dict': controller.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'ctrl_optimizer': ctrl_optimizer.state_dict(),
                'prev_archs': prev_archs,
                'prev_hiddens': prev_hiddens,
                'path_helper': args.path_helper
            }, False, args.path_helper['ckpt_path'])

    final_archs, _ = get_topk_arch_hidden(args, controller, gen_net,
                                          prev_archs, prev_hiddens)
    logger.info(f"discovered archs: {final_archs}")

    # The writer was previously never closed.
    writer_dict['writer'].close()
Ejemplo n.º 8
0
def main():
    """Train a GAN, validating periodically and checkpointing every epoch.

    Builds the generator and discriminator named by ``args.model``,
    initializes their weights, creates Adam optimizers and linear
    learning-rate schedulers, and optionally resumes from
    ``<args.load_path>/Model/checkpoint.pth``. Every ``args.val_freq`` epochs
    (and on the last epoch) the generator is validated with its averaged
    parameters loaded; a checkpoint is saved after each epoch and flagged as
    best when the FID improved.

    Raises:
        AssertionError: if the FID statistics file, resume path or checkpoint
            file is missing, or a fresh run has no experiment name.
        NotImplementedError: if ``args.dataset`` is neither ``cifar10`` nor
            ``stl10``, or from weight initialization when ``args.init_type``
            is not recognized.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE(review): eval() on a command-line value; acceptable only while
    # args.model comes from a trusted operator.
    gen_net = eval('models.' + args.model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args).cuda()

    # weight init
    def weights_init(m):
        """Initialize Conv2d / BatchNorm2d modules according to args.init_type."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; the un-suffixed name is deprecated.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    # When max_iter is set it overrides max_epoch: the epoch count is derived
    # from the total number of discriminator iterations.
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # resume from checkpoint or start a new run
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        # Extract the averaged generator parameters through a temporary copy
        # so that gen_net keeps the raw training weights.
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    last_epoch = int(args.max_epoch) - 1
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
              gen_avg_param, train_loader, epoch, writer_dict, lr_schedulers)

        is_best = False
        # Validate every val_freq epochs (never at epoch 0) and always on the
        # final epoch.
        if (epoch and epoch % args.val_freq == 0) or epoch == last_epoch:
            # Swap in the averaged weights for validation, then restore the
            # raw training weights.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict)
            logger.info(
                f'Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}.'
            )
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True

        # 'epoch' stores the next epoch to run, so a resumed run continues
        # after the one that just finished.
        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net

    # The writer was previously never closed.
    writer_dict['writer'].close()
Ejemplo n.º 9
0
def main():
    """Run the architecture search loop for the shared GAN and controller.

    When ``args.load_path`` is set, all networks, optimizers and search state
    are restored from ``<load_path>/Model/checkpoint.pth``; otherwise a new
    log directory is created under ``logs`` and the search starts at stage 0.
    Each iteration trains the shared GAN, trains the controller, optionally
    re-initializes the shared GAN and saves a checkpoint. At
    ``args.grow_step1`` and ``args.grow_step2`` the top architectures are
    collected and a fresh controller and data loader are created for the next
    stage. The discovered architectures are logged at the end.

    Raises:
        AssertionError: if the resume path or checkpoint file is missing, or
            a fresh run has no experiment name.
        NotImplementedError: from weight initialization when
            ``args.init_type`` is not recognized.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialize Conv2d / BatchNorm2d modules according to args.init_type."""
        classname = m.__class__.__name__
        if classname.find("Conv2d") != -1:
            if args.init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == "orth":
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == "xavier_uniform":
                # In-place variant; the un-suffixed name is deprecated.
                nn.init.xavier_uniform_(m.weight.data, 1.0)
            else:
                raise NotImplementedError("{} unknown inital type".format(
                    args.init_type))
        elif classname.find("BatchNorm2d") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # set grow controller
    grow_ctrler = GrowCtrler(args.grow_step1, args.grow_step2)

    # initial
    start_search_iter = 0

    # resume from checkpoint or start a new run
    if args.load_path:
        print(f"=> resuming from {args.load_path}")
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, "Model",
                                       "checkpoint.pth")
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        # set controller && its optimizer
        cur_stage = checkpoint["cur_stage"]
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

        start_search_iter = checkpoint["search_iter"]
        gen_net.load_state_dict(checkpoint["gen_state_dict"])
        dis_net.load_state_dict(checkpoint["dis_state_dict"])
        controller.load_state_dict(checkpoint["ctrl_state_dict"])
        gen_optimizer.load_state_dict(checkpoint["gen_optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        ctrl_optimizer.load_state_dict(checkpoint["ctrl_optimizer"])
        prev_archs = checkpoint["prev_archs"]
        prev_hiddens = checkpoint["prev_hiddens"]

        args.path_helper = checkpoint["path_helper"]
        logger = create_logger(args.path_helper["log_path"])
        logger.info(
            f"=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})"
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir("logs", args.exp_name)
        logger = create_logger(args.path_helper["log_path"])
        prev_archs = None
        prev_hiddens = None

        # set controller && its optimizer
        cur_stage = 0
        controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                   weights_init)

    # set up data_loader
    # NOTE(review): 2**(cur_stage + 3) is presumably the image size for the
    # current stage - confirm against datasets.ImageDataset.
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
    train_loader = dataset.train

    logger.info(args)
    writer_dict = {
        "writer": SummaryWriter(args.path_helper["log_path"]),
        "controller_steps": start_search_iter * args.ctrl_step,
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    for search_iter in tqdm(range(int(start_search_iter),
                                  int(args.max_search_iter)),
                            desc="search progress"):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter == args.grow_step1 or search_iter == args.grow_step2:

            # save
            cur_stage = grow_ctrler.cur_stage(search_iter)
            logger.info(f"=> grow to stage {cur_stage}")
            prev_archs, prev_hiddens = get_topk_arch_hidden(
                args, controller, gen_net, prev_archs, prev_hiddens)

            # grow section: replace the controller and reload the data for
            # the new stage
            del controller
            del ctrl_optimizer
            controller, ctrl_optimizer = create_ctrler(args, cur_stage,
                                                       weights_init)

            dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
            train_loader = dataset.train

        dynamic_reset = train_shared(
            args,
            gen_net,
            dis_net,
            g_loss_history,
            d_loss_history,
            controller,
            gen_optimizer,
            dis_optimizer,
            train_loader,
            prev_hiddens=prev_hiddens,
            prev_archs=prev_archs,
        )
        train_controller(
            args,
            controller,
            ctrl_optimizer,
            gen_net,
            prev_hiddens,
            prev_archs,
            writer_dict,
        )

        if dynamic_reset:
            logger.info("re-initialize share GAN")
            del gen_net, dis_net, gen_optimizer, dis_optimizer
            gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
                args, weights_init)

        # "search_iter" stores the next iteration to run, so a resumed run
        # continues after the one that just finished.
        save_checkpoint(
            {
                "cur_stage": cur_stage,
                "search_iter": search_iter + 1,
                "gen_model": args.gen_model,
                "dis_model": args.dis_model,
                "controller": args.controller,
                "gen_state_dict": gen_net.state_dict(),
                "dis_state_dict": dis_net.state_dict(),
                "ctrl_state_dict": controller.state_dict(),
                "gen_optimizer": gen_optimizer.state_dict(),
                "dis_optimizer": dis_optimizer.state_dict(),
                "ctrl_optimizer": ctrl_optimizer.state_dict(),
                "prev_archs": prev_archs,
                "prev_hiddens": prev_hiddens,
                "path_helper": args.path_helper,
            },
            False,
            args.path_helper["ckpt_path"],
        )

    final_archs, _ = get_topk_arch_hidden(args, controller, gen_net,
                                          prev_archs, prev_hiddens)
    logger.info(f"discovered archs: {final_archs}")

    # The writer was previously never closed.
    writer_dict["writer"].close()
Ejemplo n.º 10
0
def main():
    """Fine-tune a pruned generator against a magnitude-masked discriminator.

    Steps, in order:

    1. Seed the RNGs and set up the Inception graph used for evaluation.
    2. Build the generator/discriminator named by ``args.model`` and load the
       initial weights stored under ``args.init_path``.
    3. Restore the checkpoint at ``args.load_path`` and derive one binary mask
       per discriminator ``nn.Conv2d`` from the magnitudes of ``weight_orig``.
    4. Optionally rewind the generator (``not args.finetune_G``) and the
       discriminator (``not args.finetune_D``) to the initial weights, then
       re-apply the discriminator masks.
    5. Train with the masks, validating periodically and saving a checkpoint
       after every epoch.

    Raises:
        NotImplementedError: if ``args.dataset`` has no FID statistics file,
            or (from ``weights_init``) if ``args.init_type`` is unknown.
        AssertionError: if the FID statistics file or the checkpoint file
            does not exist.
    """
    args = cfg.parse_args()
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)
    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialise conv and batch-norm layers according to ``args.init_type``."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; the un-suffixed name is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    # import network
    gen_net = eval('models.' + args.model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args).cuda()
    gen_net.apply(weights_init)
    dis_net.apply(weights_init)
    avg_gen_net = deepcopy(gen_net)
    initial_gen_net_weight = torch.load(os.path.join(args.init_path,
                                                     'initial_gen_net.pth'),
                                        map_location="cpu")
    initial_dis_net_weight = torch.load(os.path.join(args.init_path,
                                                     'initial_dis_net.pth'),
                                        map_location="cpu")
    # sanity check: the loaded snapshots are separate objects from the live
    # state dicts
    assert id(initial_dis_net_weight) != id(dis_net.state_dict())
    assert id(initial_gen_net_weight) != id(gen_net.state_dict())

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/fid_stats_stl10_train.npz'
    else:
        raise NotImplementedError('no fid stat for %s' % args.dataset.lower())
    assert os.path.exists(fid_stat)

    # epoch number for dis_net; an explicit iteration budget takes precedence
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    start_epoch = 0
    best_fid = 1e4

    print('=> resuming from %s' % args.load_path)
    assert os.path.exists(args.load_path)
    checkpoint_file = args.load_path
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)
    pruning_generate(gen_net, checkpoint['gen_state_dict'])
    dis_net.load_state_dict(checkpoint['dis_state_dict'])

    # Count all conv weights of the discriminator and how many are non-zero.
    total = 0
    total_nonzero = 0
    for m in dis_net.modules():
        if isinstance(m, nn.Conv2d):
            total += m.weight_orig.data.numel()
            mask = m.weight_orig.data.abs().clone().gt(0).float().cuda()
            total_nonzero += torch.sum(mask)

    # Flatten every conv weight magnitude into one vector so that a single
    # global threshold can be read off the sorted values.
    conv_weights = torch.zeros(total)
    index = 0
    for m in dis_net.modules():
        if isinstance(m, nn.Conv2d):
            size = m.weight_orig.data.numel()
            conv_weights[index:(
                index + size)] = m.weight_orig.data.view(-1).abs().clone()
            index += size

    y, _ = torch.sort(conv_weights)
    # thre_index = int(total * args.percent)
    # only care about the non zero weights
    # e.g: total = 100, total_nonzero = 80, percent = 0.2, thre_index = 36, that means keep 64
    thre_index = total - total_nonzero
    # NOTE(review): y[thre_index] is the smallest non-zero magnitude and the
    # mask below uses a strict ">", so that weight is masked out as well --
    # confirm this is intended.
    thre = y[int(thre_index)]
    pruned = 0
    print('Pruning threshold: {}'.format(thre))
    # Masks are keyed by the module's position in dis_net.modules() so they
    # can be re-applied after the weights are reloaded further down.
    masks = OrderedDict()
    for k, m in enumerate(dis_net.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m.weight_orig.data.abs().clone()
            mask = weight_copy.gt(thre).float()
            masks[k] = mask
            pruned = pruned + mask.numel() - torch.sum(mask)
            m.weight_orig.data.mul_(mask)
            print(
                'layer index: {:d} \t total params: {:d} \t remaining params: {:d}'
                .format(k, mask.numel(), int(torch.sum(mask))))
    print('Total conv params: {}, Pruned conv params: {}, Pruned ratio: {}'.
          format(total, pruned, pruned / total))

    pruning_generate(avg_gen_net, checkpoint['gen_state_dict'])
    see_remain_rate(gen_net)

    if not args.finetune_G:
        # Rewind the generator to its initial values for the keys selected by
        # rewind_weight.
        gen_weight = gen_net.state_dict()
        gen_orig_weight = rewind_weight(initial_gen_net_weight,
                                        gen_weight.keys())
        gen_weight.update(gen_orig_weight)
        gen_net.load_state_dict(gen_weight)
    gen_avg_param = copy_params(gen_net)

    if args.finetune_D:
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
    else:
        dis_net.load_state_dict(initial_dis_net_weight)

    # Reloading the weights above discards the in-place masking, so apply the
    # masks again.
    for k, m in enumerate(dis_net.modules()):
        if isinstance(m, nn.Conv2d):
            m.weight_orig.data.mul_(masks[k])

    # Checkpoint discriminator in eval mode; passed to train_with_mask_kd
    # when args.use_kd_D is set.
    orig_dis_net = eval('models.' + args.model +
                        '.Discriminator')(args=args).cuda()
    orig_dis_net.load_state_dict(checkpoint['dis_state_dict'])
    orig_dis_net.eval()

    args.path_helper = set_log_dir('logs',
                                   args.exp_name + "_{}".format(args.percent))
    logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        see_remain_rate(gen_net)
        see_remain_rate_orig(dis_net)
        if not args.use_kd_D:
            train_with_mask(args, gen_net, dis_net, gen_optimizer,
                            dis_optimizer, gen_avg_param, train_loader, epoch,
                            writer_dict, masks, lr_schedulers)
        else:
            train_with_mask_kd(args, gen_net, dis_net, orig_dis_net,
                               gen_optimizer, dis_optimizer, gen_avg_param,
                               train_loader, epoch, writer_dict, masks,
                               lr_schedulers)

        # Validate every val_freq epochs (never at epoch 0) and on the last
        # epoch.
        if epoch and epoch % args.val_freq == 0 or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged generator weights, then restore the
            # live training weights.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict, epoch)
            logger.info(
                'Inception score: %.4f, FID score: %.4f || @ epoch %d.' %
                (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        # Materialise the averaged generator so its state dict can be saved.
        avg_gen_net.load_state_dict(gen_net.state_dict())
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
Ejemplo n.º 11
0
def main():
    """Train a generator/discriminator pair wrapped in ``DataParallel``.

    Builds the networks named by ``args.gen_model`` / ``args.dis_model``,
    optionally resumes from the checkpoint file at ``args.load_path``, then
    trains, validates and writes a checkpoint after every epoch.

    Raises:
        NotImplementedError: if ``args.optimizer`` is neither ``"adam"`` nor
            ``"adamw"``, if no FID statistics file can be chosen, or (from
            ``weights_init``) if ``args.init_type`` is unknown.
        AssertionError: if the FID statistics file or the checkpoint file is
            missing, or if a fresh run has no ``args.exp_name``.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.backends.cudnn.deterministic = True

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # epoch number for dis_net
    dataset = datasets.ImageDataset(args, cur_img_size=8)
    train_loader = dataset.train
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter / len(train_loader))
    else:
        args.max_iter = args.max_epoch * len(train_loader)
    args.max_epoch = args.max_epoch * args.n_critic

    # import network
    gen_net = eval('models.' + args.gen_model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.dis_model +
                   '.Discriminator')(args=args).cuda()
    gen_net.set_arch(args.arch, cur_stage=2)

    # weight init
    def weights_init(m):
        """Initialise conv and batch-norm layers according to ``args.init_type``."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    gpu_ids = list(range(torch.cuda.device_count()))
    gen_net = torch.nn.DataParallel(gen_net.to("cuda:0"), device_ids=gpu_ids)
    dis_net = torch.nn.DataParallel(dis_net.to("cuda:0"), device_ids=gpu_ids)

    gen_net.module.cur_stage = 0
    dis_net.module.cur_stage = 0
    gen_net.module.alpha = 1.
    dis_net.module.alpha = 1.

    # set optimizer
    if args.optimizer == "adam":
        gen_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
            (args.beta1, args.beta2))
        dis_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
            (args.beta1, args.beta2))
    elif args.optimizer == "adamw":
        gen_optimizer = AdamW(filter(lambda p: p.requires_grad,
                                     gen_net.parameters()),
                              args.g_lr,
                              weight_decay=args.wd)
        # The discriminator gets its own learning rate, matching the Adam
        # branch above and dis_scheduler below (this used to pass args.g_lr).
        dis_optimizer = AdamW(filter(lambda p: p.requires_grad,
                                     dis_net.parameters()),
                              args.d_lr,
                              weight_decay=args.wd)
    else:
        # Fail here rather than with a NameError on the scheduler lines.
        raise NotImplementedError(
            '{} unknown optimizer'.format(args.optimizer))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.fid_stat is not None:
        fid_stat = args.fid_stat
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (64, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # resume from a checkpoint, or start a fresh log directory
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = args.load_path
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        gen_avg_param = checkpoint['gen_avg_param']
        cur_stage = cur_stages(start_epoch, args)
        gen_net.module.cur_stage = cur_stage
        dis_net.module.cur_stage = cur_stage
        gen_net.module.alpha = 1.
        dis_net.module.alpha = 1.

        args.path_helper = checkpoint['path_helper']

    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)

    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    def return_states():
        """Collect the checkpoint payload for the epoch currently in progress."""
        states = {}
        states['epoch'] = epoch
        states['best_fid'] = best_fid
        states['gen_state_dict'] = gen_net.state_dict()
        states['dis_state_dict'] = dis_net.state_dict()
        states['gen_optimizer'] = gen_optimizer.state_dict()
        states['dis_optimizer'] = dis_optimizer.state_dict()
        states['gen_avg_param'] = gen_avg_param
        states['path_helper'] = args.path_helper
        return states

    # train loop
    # args.max_epoch is a NumPy float when args.max_iter is set (np.ceil
    # above), and range() only accepts integers.
    for epoch in range(start_epoch + 1, int(args.max_epoch)):
        train(
            args,
            gen_net,
            dis_net,
            gen_optimizer,
            dis_optimizer,
            gen_avg_param,
            train_loader,
            epoch,
            writer_dict,
            fixed_z,
        )
        # Evaluate with the averaged generator weights, then restore the live
        # training weights.
        backup_param = copy_params(gen_net)
        load_params(gen_net, gen_avg_param)
        fid_score = validate(
            args,
            fixed_z,
            fid_stat,
            epoch,
            gen_net,
            writer_dict,
        )
        logger.info(f'FID score: {fid_score} || @ epoch {epoch}.')
        load_params(gen_net, backup_param)
        # The first epoch of a fresh run always counts as best. On resume the
        # comparison uses the best_fid restored from the checkpoint, which the
        # old best_fid_score variable never received.
        is_best = False
        if epoch == 1 or fid_score < best_fid:
            best_fid = fid_score
            is_best = True
        # A checkpoint is written after every epoch.
        states = return_states()
        save_checkpoint(states,
                        is_best,
                        args.path_helper['ckpt_path'],
                        filename=f'checkpoint_epoch_{epoch}.pth')
Ejemplo n.º 12
0
def main():
    """Prune a trained generator, rewind it to its initial weights and retrain.

    The freshly initialised generator/discriminator weights are snapshotted
    before anything is loaded. The checkpoint found under
    ``<args.load_path>/Model/checkpoint.pth`` is then restored into the
    generator, ``pruning_generate`` is applied with ``1 - args.percent`` and
    ``args.pruning_method``, the generator is rewound to the snapshot via
    ``rewind_weight``, and training runs with a checkpoint saved every epoch.

    Raises:
        NotImplementedError: if ``args.dataset`` has no FID statistics file,
            or (from ``weights_init``) if ``args.init_type`` is unknown.
        AssertionError: if the FID statistics file or the checkpoint file is
            missing.
    """
    args = cfg.parse_args()
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    gen_net = eval('models.' + args.model + '.Generator')(args=args)
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args)

    # weight init
    def weights_init(m):
        """Initialise conv and batch-norm layers according to ``args.init_type``."""
        if isinstance(m, nn.Conv2d):
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; the un-suffixed name is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    gen_net = gen_net.cuda()
    dis_net = dis_net.cuda()

    avg_gen_net = deepcopy(gen_net)
    # Snapshot the initial weights so the networks can be rewound later.
    initial_gen_net_weight = deepcopy(gen_net.state_dict())
    initial_dis_net_weight = deepcopy(dis_net.state_dict())
    # sanity check: the snapshots are separate objects from the live state
    # dicts
    assert id(initial_dis_net_weight) != id(dis_net.state_dict())
    assert id(initial_gen_net_weight) != id(gen_net.state_dict())
    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/fid_stats_stl10_train.npz'
    else:
        raise NotImplementedError('no fid stat for %s' % args.dataset.lower())
    assert os.path.exists(fid_stat)

    # epoch number for dis_net; an explicit iteration budget takes precedence
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    np.random.seed(args.random_seed)
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    start_epoch = 0
    best_fid = 1e4

    args.path_helper = set_log_dir('logs',
                                   args.exp_name + "_{}".format(args.percent))
    logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    print('=> resuming from %s' % args.load_path)
    assert os.path.exists(args.load_path)
    checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
    assert os.path.exists(checkpoint_file)
    checkpoint = torch.load(checkpoint_file)
    gen_net.load_state_dict(checkpoint['gen_state_dict'])

    # Re-seed before each pruning call so both generators see the same torch
    # RNG state.
    torch.manual_seed(args.random_seed)
    pruning_generate(gen_net, (1 - args.percent), args.pruning_method)
    torch.manual_seed(args.random_seed)
    pruning_generate(avg_gen_net, (1 - args.percent), args.pruning_method)
    see_remain_rate(gen_net)

    # NOTE(review): both branches below overwrite the discriminator weights,
    # so this re-initialisation appears to have no lasting effect -- confirm.
    if args.second_seed:
        dis_net.apply(weights_init)
    if args.finetune_D:
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
    else:
        dis_net.load_state_dict(initial_dis_net_weight)

    # Rewind the generator to its initial values for the keys selected by
    # rewind_weight.
    gen_weight = gen_net.state_dict()
    gen_orig_weight = rewind_weight(initial_gen_net_weight, gen_weight.keys())
    assert id(gen_weight) != id(gen_orig_weight)
    gen_weight.update(gen_orig_weight)
    gen_net.load_state_dict(gen_weight)
    gen_avg_param = copy_params(gen_net)

    if args.use_kd_D:
        # Checkpoint discriminator in eval mode, passed to train_kd below.
        orig_dis_net = eval('models.' + args.model +
                            '.Discriminator')(args=args).cuda()
        # nn.Module restores weights through load_state_dict; there is no
        # Module.load method.
        orig_dis_net.load_state_dict(checkpoint['dis_state_dict'])
        orig_dis_net.eval()
    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        see_remain_rate(gen_net)
        if not args.use_kd_D:
            train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
                  gen_avg_param, train_loader, epoch, writer_dict,
                  lr_schedulers)
        else:
            train_kd(args, gen_net, dis_net, orig_dis_net, gen_optimizer,
                     dis_optimizer, gen_avg_param, train_loader, epoch,
                     writer_dict, lr_schedulers)

        # Validate every val_freq epochs (never at epoch 0) and on the last
        # epoch.
        if epoch and epoch % args.val_freq == 0 or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged generator weights, then restore the
            # live training weights.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict, epoch)
            logger.info(
                'Inception score: %.4f, FID score: %.4f || @ epoch %d.' %
                (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        # Materialise the averaged generator so its state dict can be saved.
        avg_gen_net.load_state_dict(gen_net.state_dict())
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
Ejemplo n.º 13
0
def main():
    """Train a generator sub-network obtained from a saved checkpoint.

    The generator state stored under
    ``output/<args.dir>/pth/epoch<args.load_epoch>.pth`` is combined with the
    initial generator weights from ``args.init_path`` by ``load_subnet``. The
    resulting generator is then trained against a freshly initialised
    discriminator, with a checkpoint saved after every epoch.

    Raises:
        NotImplementedError: if ``args.dataset`` is not ``cifar10``, or (from
            ``weights_init``) if ``args.init_type`` is unknown.
        AssertionError: if the checkpoint or the FID statistics file is
            missing.
    """
    args = cfg.parse_args()
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialise conv and batch-norm layers according to ``args.init_type``."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; the un-suffixed name is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    # This generator is replaced by the load_subnet result further down; only
    # the discriminator built here is the one that gets trained.
    gen_net = Generator(bottom_width=args.bottom_width,
                        gf_dim=args.gf_dim,
                        latent_dim=args.latent_dim).cuda()
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args).cuda()
    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    initial_gen_net_weight = torch.load(os.path.join(args.init_path,
                                                     'initial_gen_net.pth'),
                                        map_location="cpu")
    # NOTE(review): loaded but not referenced again in this function.
    initial_dis_net_weight = torch.load(os.path.join(args.init_path,
                                                     'initial_dis_net.pth'),
                                        map_location="cpu")

    # NOTE(review): this is assigned after the .cuda() calls above, so it
    # probably no longer influences device selection in this process --
    # confirm and move earlier if it is meant to.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    exp_str = args.dir
    args.load_path = os.path.join('output', exp_str, 'pth',
                                  'epoch{}.pth'.format(args.load_epoch))

    # state dict:
    assert os.path.exists(args.load_path)
    checkpoint = torch.load(args.load_path)
    print('=> loaded checkpoint %s' % args.load_path)
    state_dict = checkpoint['generator']
    gen_net = load_subnet(args, state_dict, initial_gen_net_weight).cuda()
    avg_gen_net = deepcopy(gen_net)

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    else:
        raise NotImplementedError('no fid stat for %s' % args.dataset.lower())
    assert os.path.exists(fid_stat)

    # epoch number for dis_net; an explicit iteration budget takes precedence
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    np.random.seed(args.random_seed)
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    start_epoch = 0
    best_fid = 1e4

    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }
    gen_avg_param = copy_params(gen_net)
    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
              gen_avg_param, train_loader, epoch, writer_dict, lr_schedulers)

        # Validate every val_freq epochs (never at epoch 0) and on the last
        # epoch.
        if epoch and epoch % args.val_freq == 0 or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged generator weights, then restore the
            # live training weights.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict)
            logger.info(
                'Inception score: %.4f, FID score: %.4f || @ epoch %d.' %
                (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        # Materialise the averaged generator so its state dict can be saved.
        avg_gen_net.load_state_dict(gen_net.state_dict())
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
Ejemplo n.º 14
0
def main():
    """Train the generator/discriminator taken from ``models_search``.

    Builds the networks named by ``args.gen_model`` / ``args.dis_model``,
    fixes the generator architecture with ``set_arch``, optionally resumes
    from ``<args.load_path>/Model/checkpoint.pth``, then trains with periodic
    validation and a checkpoint saved after every epoch.

    Raises:
        NotImplementedError: if ``args.dataset`` has no FID statistics file,
            or (from ``weights_init``) if ``args.init_type`` is unknown.
        AssertionError: if the FID statistics file or the checkpoint file is
            missing, or if a fresh run has no ``args.exp_name``.
    """
    args = cfg.parse_args()
    # NOTE(review): only the CUDA RNG is seeded here; fixed_z below is drawn
    # from NumPy's unseeded global RNG -- confirm this is intended.
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    gen_net = eval("models_search." + args.gen_model +
                   ".Generator")(args=args).cuda()
    dis_net = eval("models_search." + args.dis_model +
                   ".Discriminator")(args=args).cuda()

    gen_net.set_arch(args.arch, cur_stage=2)
    dis_net.cur_stage = 2

    # weight init
    def weights_init(m):
        """Initialise conv and batch-norm layers according to ``args.init_type``."""
        classname = m.__class__.__name__
        if classname.find("Conv2d") != -1:
            if args.init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == "orth":
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == "xavier_uniform":
                # In-place variant; the un-suffixed name is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.0)
            else:
                raise NotImplementedError("{} unknown inital type".format(
                    args.init_type))
        elif classname.find("BatchNorm2d") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()),
        args.g_lr,
        (args.beta1, args.beta2),
    )
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()),
        args.d_lr,
        (args.beta1, args.beta2),
    )
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == "cifar10":
        fid_stat = "fid_stat/fid_stats_cifar10_train.npz"
    elif args.dataset.lower() == "stl10":
        fid_stat = "fid_stat/stl10_train_unlabeled_fid_stats_48.npz"
    else:
        raise NotImplementedError(f"no fid stat for {args.dataset.lower()}")
    assert os.path.exists(fid_stat)

    # epoch number for dis_net; an explicit iteration budget takes precedence
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # resume from a checkpoint, or start a fresh log directory
    if args.load_path:
        print(f"=> resuming from {args.load_path}")
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, "Model",
                                       "checkpoint.pth")
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint["epoch"]
        best_fid = checkpoint["best_fid"]
        gen_net.load_state_dict(checkpoint["gen_state_dict"])
        dis_net.load_state_dict(checkpoint["dis_state_dict"])
        gen_optimizer.load_state_dict(checkpoint["gen_optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        # Recover the averaged parameters through a temporary copy of the
        # generator.
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint["avg_gen_state_dict"])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint["path_helper"]
        logger = create_logger(args.path_helper["log_path"])
        logger.info(
            f"=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})")
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir("logs", args.exp_name)
        logger = create_logger(args.path_helper["log_path"])

    logger.info(args)
    writer_dict = {
        "writer": SummaryWriter(args.path_helper["log_path"]),
        "train_global_steps": start_epoch * len(train_loader),
        "valid_global_steps": start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc="total progress"):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        train(
            args,
            gen_net,
            dis_net,
            gen_optimizer,
            dis_optimizer,
            gen_avg_param,
            train_loader,
            epoch,
            writer_dict,
            lr_schedulers,
        )

        # Validate every val_freq epochs (never at epoch 0) and on the last
        # epoch.
        if epoch and epoch % args.val_freq == 0 or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged generator weights, then restore the
            # live training weights.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict)
            logger.info(
                f"Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}."
            )
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        # Build a temporary averaged generator just to export its state dict.
        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "gen_model": args.gen_model,
                "dis_model": args.dis_model,
                "gen_state_dict": gen_net.state_dict(),
                "dis_state_dict": dis_net.state_dict(),
                "avg_gen_state_dict": avg_gen_net.state_dict(),
                "gen_optimizer": gen_optimizer.state_dict(),
                "dis_optimizer": dis_optimizer.state_dict(),
                "best_fid": best_fid,
                "path_helper": args.path_helper,
            },
            is_best,
            args.path_helper["ckpt_path"],
        )
        del avg_gen_net
Ejemplo n.º 15
0
def main():
    """Train a generator/discriminator pair and track the best IS and FID.

    Steps, in order:
      1. Parse arguments, seed CUDA and set up the Inception graph.
      2. Build both networks (from genotypes where applicable) and apply
         ``args.init_type`` weight initialisation.
      3. Create Adam optimizers and ``LinearLrDecay`` schedulers.
      4. Optionally resume from ``<args.load_path>/Model/checkpoint.pth``.
      5. For every epoch: train, periodically validate using the averaged
         generator parameters, periodically dump a sample image grid, and
         save a checkpoint.
      6. Log the best Inception score and best FID observed in this run.

    Raises:
        NotImplementedError: for an unknown ``args.init_type`` or a dataset
            without an FID statistics file.
        AssertionError: if the FID statistics file, the resume path, or
            ``args.exp_name`` (for a fresh run) is missing.
    """
    args = cfg_train.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE(review): eval() executes whatever expression the CLI strings spell
    # out; this is only acceptable while the arguments come from a trusted
    # operator.
    genotype_gen = eval('genotypes.%s' % args.arch_gen)
    gen_net = eval('models.' + args.gen_model + '.' + args.gen)(
        args, genotype_gen).cuda()
    if 'Discriminator' not in args.dis:
        # Discriminator classes without 'Discriminator' in their name are
        # built from a genotype as well.
        genotype_dis = eval('genotypes.%s' % args.arch_dis)
        dis_net = eval('models.' + args.dis_model + '.' + args.dis)(
            args, genotype_dis).cuda()
    else:
        dis_net = eval('models.' + args.dis_model + '.' +
                       args.dis)(args=args).cuda()

    # weight init
    def weights_init(m):
        """Initialise Conv2d / BatchNorm2d modules according to args.init_type."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # Use the in-place `_` variant like the sibling calls; the
                # un-suffixed name is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train
    val_loader = dataset.valid

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, args.g_lr * 0.01,
                                  260 * len(train_loader),
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, args.d_lr * 0.01,
                                  260 * len(train_loader),
                                  args.max_iter * args.n_critic)

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.dataset.lower() == 'mnist':
        # NOTE(review): MNIST reuses the STL-10 statistics file - confirm
        # this is intentional and not a copy/paste placeholder.
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        # np.ceil returns a float; the loop below converts with int().
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    fixed_z_sample = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4
    best_fid_epoch = 0
    is_with_fid = 0
    std_with_fid = 0.
    best_is = 0
    # Must be bound up front: the final summary reads it even when no
    # validation ever improved on best_is (or the loop never ran).
    best_std = 0.
    best_is_epoch = 0
    fid_with_is = 0

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        # Recover the averaged parameters through a throw-away copy of the
        # generator so the live generator keeps its own weights.
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # calculate the FLOPs and param count of G
    profile_input = torch.randn(args.gen_batch_size, args.latent_dim).cuda()
    flops, params = profile(gen_net, inputs=(profile_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info('FLOPs is {}, param count is {}'.format(flops, params))

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None

        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
              gen_avg_param, train_loader, epoch, writer_dict, args.consistent,
              lr_schedulers)

        # Validate every val_freq epochs (never at epoch 0) and always on
        # the final epoch.
        if (epoch and epoch % args.val_freq == 0) or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged weights, then restore the live ones.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, std, fid_score = validate(args,
                                                       fixed_z,
                                                       fid_stat,
                                                       gen_net,
                                                       writer_dict,
                                                       args.path_helper,
                                                       search=False)
            # `std` is reported next to the Inception score, matching the
            # final summary below.
            logger.info(
                f'Inception score: {inception_score}+-{std}, FID score: {fid_score} || @ epoch {epoch}.'
            )
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                best_fid_epoch = epoch
                is_with_fid = inception_score
                std_with_fid = std
                is_best = True
            else:
                is_best = False
            if inception_score > best_is:
                best_is = inception_score
                best_std = std
                fid_with_is = fid_score
                best_is_epoch = epoch
        else:
            is_best = False

        # save generated images
        if epoch % args.image_every == 0:
            # No gradients are needed for a sample dump.
            with torch.no_grad():
                gen_images = gen_net(fixed_z_sample).mul_(127.5).add_(
                    127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to(
                        'cpu', torch.uint8).numpy()
            fig = plt.figure()
            # NOTE(review): the grid is fixed at 10x10, so this assumes
            # args.eval_batch_size <= 100 - confirm.
            grid = ImageGrid(fig, 111, nrows_ncols=(10, 10), axes_pad=0)
            for x in range(args.eval_batch_size):
                grid[x].imshow(gen_images[x])
                grid[x].set_xticks([])
                grid[x].set_yticks([])
            plt.savefig(
                os.path.join(args.path_helper['sample_path'],
                             "epoch_{}.png".format(epoch)))
            plt.close()

        # Materialise the averaged generator only to serialise its state.
        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'gen_model': args.gen_model,
                'dis_model': args.dis_model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net
    logger.info(
        'best_is is {}+-{}@{} epoch, fid is {}, best_fid is {}@{}, is is {}+-{}'
        .format(best_is, best_std, best_is_epoch, fid_with_is, best_fid,
                best_fid_epoch, is_with_fid, std_with_fid))
Ejemplo n.º 16
0
def main():
    """Run the GAN architecture search loop and log the discovered genotypes.

    Reads the module-level ``args`` and ``path_helper``. Per epoch the loop
    trains the network weights, trains the architecture parameters once
    ``epoch > args.fix_alphas_epochs``, optionally evaluates IS / FID, and
    saves the latest generator and discriminator weights.

    Raises:
        NotImplementedError: if ``args.dataset`` has no FID statistics file
            or no training-set loader in this function.
        AssertionError: if the FID statistics file does not exist.
        SystemExit: with status 1 when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # Create tensorboard logger
    writer_dict = {
        'writer': SummaryWriter(path_helper['log']),
        'inner_steps': 0,
        'val_steps': 0,
        'valid_global_steps': 0
    }

    # set tf env
    if args.eval:
        _init_inception()
        inception_path = check_or_download_inception(None)
        create_inception_graph(inception_path)

    # fid_stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.dataset.lower() == 'mnist':
        # NOTE(review): MNIST reuses the STL-10 statistics file - confirm
        # this is intentional.
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    FID_best = 1e+4
    IS_best = 0.
    FID_best_epoch = 0
    IS_best_epoch = 0

    # build gen and dis
    # NOTE(review): eval() executes whatever expression the configured
    # strings spell out; only acceptable for trusted configuration.
    gen = eval('model_search_gan.' + args.gen)(args)
    gen = gen.cuda()
    dis = eval('model_search_gan.' + args.dis)(args)
    dis = dis.cuda()
    logging.info("generator param size = %fMB",
                 utils.count_parameters_in_MB(gen))
    logging.info("discriminator param size = %fMB",
                 utils.count_parameters_in_MB(dis))
    if args.parallel:
        gen = nn.DataParallel(gen)
        dis = nn.DataParallel(dis)

    # resume training
    if args.load_path != '':
        gen.load_state_dict(
            torch.load(
                os.path.join(args.load_path, 'model',
                             'weights_gen_' + 'last' + '.pt')))
        dis.load_state_dict(
            torch.load(
                os.path.join(args.load_path, 'model',
                             'weights_dis_' + 'last' + '.pt')))

    # set optimizer for parameters W of gen and dis
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis.parameters()), args.d_lr,
        (args.beta1, args.beta2))

    # set moving average parameters for generator
    gen_avg_param = copy_params(gen)

    def build_queues(data):
        """Return (train, valid) loaders over a train_portion split of data."""
        num_train = len(data)
        indices = list(range(num_train))
        split = int(np.floor(args.train_portion * num_train))
        train_q = torch.utils.data.DataLoader(
            data,
            batch_size=args.gen_batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[:split]),
            pin_memory=True,
            num_workers=2)
        valid_q = torch.utils.data.DataLoader(
            data,
            batch_size=args.gen_batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[split:num_train]),
            pin_memory=True,
            num_workers=2)
        return train_q, valid_q

    # Progressive growing starts at 8x8; otherwise use the full resolution.
    img_size = 8 if args.grow else args.img_size
    train_transform, valid_transform = eval('utils.' + '_data_transforms_' +
                                            args.dataset + '_resize')(args,
                                                                      img_size)
    if args.dataset == 'cifar10':
        train_data = eval('dset.' + dataset[args.dataset])(
            root=args.data,
            train=True,
            download=True,
            transform=train_transform)
    elif args.dataset == 'stl10':
        train_data = eval('dset.' + dataset[args.dataset])(
            root=args.data, download=True, transform=train_transform)
    else:
        # Fail with a clear message instead of an unbound `train_data` below.
        raise NotImplementedError(
            f'no training data loader for {args.dataset}')

    train_queue, valid_queue = build_queues(train_data)
    logging.info('length of train_queue is {}'.format(len(train_queue)))
    logging.info('length of valid_queue is {}'.format(len(valid_queue)))

    max_iter = len(train_queue) * args.epochs

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        gen_optimizer, float(args.epochs), eta_min=args.learning_rate_min)
    # NOTE(review): these two schedulers are not passed to anything in this
    # function - confirm whether they are still needed.
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  max_iter * args.n_critic)

    architect = Architect_gen(gen, dis, args, 'duality_gap_with_mm', logging)

    gen.set_gumbel(args.use_gumbel)
    dis.set_gumbel(args.use_gumbel)
    for epoch in range(args.start_epoch + 1, args.epochs):
        scheduler.step()
        # NOTE(review): newer PyTorch recommends get_last_lr() for reading
        # the current rate - confirm against the pinned torch version.
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)
        logging.info('epoch %d gen_lr %e', epoch, args.g_lr)
        logging.info('epoch %d dis_lr %e', epoch, args.d_lr)

        genotype_gen = gen.genotype()
        logging.info('gen_genotype = %s', genotype_gen)

        if 'Discriminator' not in args.dis:
            genotype_dis = dis.genotype()
            logging.info('dis_genotype = %s', genotype_dis)

        print('up_1: ', F.softmax(gen.alphas_up_1, dim=-1))
        print('up_2: ', F.softmax(gen.alphas_up_2, dim=-1))
        print('up_3: ', F.softmax(gen.alphas_up_3, dim=-1))

        # determine whether use gumbel or not
        if epoch == args.fix_alphas_epochs + 1:
            gen.set_gumbel(args.use_gumbel)
            dis.set_gumbel(args.use_gumbel)

        # grow discriminator and generator
        if args.grow:
            dis.cur_stage = grow_ctrl(epoch, args.grow_epoch)
            gen.cur_stage = grow_ctrl(epoch, args.grow_epoch)
            if args.restrict_dis_grow and dis.cur_stage > 1:
                dis.cur_stage = 1
                print('debug: dis.cur_stage is {}'.format(dis.cur_stage))
            if epoch in args.grow_epoch:
                # Rebuild the loaders at the resolution of the new stage.
                train_transform, valid_transform = utils._data_transforms_cifar10_resize(
                    args, 2**(gen.cur_stage + 3))
                train_data = dset.CIFAR10(root=args.data,
                                          train=True,
                                          download=True,
                                          transform=train_transform)
                train_queue, valid_queue = build_queues(train_data)
        else:
            gen.cur_stage = 2
            dis.cur_stage = 2

        # training parameters
        train_gan_parameter(args, train_queue, gen, dis, gen_optimizer,
                            dis_optimizer, gen_avg_param, logging, writer_dict)

        # training alphas
        if epoch > args.fix_alphas_epochs:
            train_gan_alpha(args, train_queue, valid_queue, gen, dis,
                            architect, gen_optimizer, gen_avg_param, epoch, lr,
                            writer_dict, logging)

        # evaluate the IS and FID
        if args.eval and epoch % args.eval_every == 0:
            inception_score, std, fid_score = validate(args, fixed_z, fid_stat,
                                                       gen, writer_dict,
                                                       path_helper)
            logging.info('epoch {}: IS is {}+-{}, FID is {}'.format(
                epoch, inception_score, std, fid_score))
            # Record the epoch in the same variables that are logged below;
            # previously differently named variables were written, so the
            # reported best epoch stayed at 0.
            if inception_score > IS_best:
                IS_best = inception_score
                IS_best_epoch = epoch
            if fid_score < FID_best:
                FID_best = fid_score
                FID_best_epoch = epoch
            logging.info('best epoch {}: IS is {}'.format(
                IS_best_epoch, IS_best))
            logging.info('best epoch {}: FID is {}'.format(
                FID_best_epoch, FID_best))

        utils.save(
            gen,
            os.path.join(path_helper['model'],
                         'weights_gen_{}.pt'.format('last')))
        utils.save(
            dis,
            os.path.join(path_helper['model'],
                         'weights_dis_{}.pt'.format('last')))

    genotype_gen = gen.genotype()
    if 'Discriminator' not in args.dis:
        genotype_dis = dis.genotype()
    logging.info('best epoch {}: IS is {}'.format(IS_best_epoch, IS_best))
    logging.info('best epoch {}: FID is {}'.format(FID_best_epoch, FID_best))
    logging.info('final discovered gen_arch is {}'.format(genotype_gen))
    if 'Discriminator' not in args.dis:
        logging.info('final discovered dis_arch is {}'.format(genotype_dis))
Ejemplo n.º 17
0
def main():
    """Drive the CycleGAN / controller search loop.

    Each search iteration resets the controller, rebuilds it when the
    iteration is one of the grow steps, runs ``cyclgan_train`` followed by
    ``controller_train``, re-creates the shared GAN when ``cyclgan_train``
    returns a truthy value, and writes a checkpoint. After the loop the
    top-k architectures reported by the controller are printed.
    """
    opt = SearchOptions().parse()
    torch.cuda.manual_seed(12345)

    _init_inception(MODEL_DIR)
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    start_search_iter = 0
    cur_stage = 1

    # Iterations per stage: grow_step**i for the first stages, then a
    # constant grow_step**3 for the remaining ones.
    leading_steps = [
        int(opt.grow_step ** i) for i in range(1, opt.max_skip_num)
    ]
    trailing_steps = [
        int(opt.grow_step ** 3)
        for _ in range(1, opt.n_resnet - opt.max_skip_num + 1)
    ]
    delta_grow_steps = leading_steps + trailing_steps

    opt.max_search_iter = sum(delta_grow_steps)
    # Cumulative iteration counts at which the controller grows.
    grow_steps = [
        sum(delta_grow_steps[:i]) for i in range(1, len(delta_grow_steps))
    ]

    grow_ctrler = GrowCtrler(opt.grow_step, steps=grow_steps)

    # Resolve resume state (if any) before the models are constructed.
    checkpoint = None
    if opt.load_path:
        print(f'=> resuming from {opt.load_path}')
        assert os.path.exists(opt.load_path)
        checkpoint_file = os.path.join(opt.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file,
                                map_location={'cuda:0': 'cpu'})
        cur_stage = checkpoint['cur_stage']
        start_search_iter = checkpoint["search_iter"]
        opt.path_helper = checkpoint['path_helper']
    else:
        opt.path_helper = set_log_dir(opt.checkpoints_dir, opt.name)

    cycle_gan = CycleGANModel(opt)
    cycle_gan.setup(opt)

    cycle_controller = CycleControllerModel(opt, cur_stage=cur_stage)
    cycle_controller.setup(opt)
    cycle_controller.set(cycle_gan)

    if checkpoint is not None:
        # Restore saved state only after both models are wired together.
        cycle_gan.load_from_state(checkpoint["cycle_gan"])
        cycle_controller.load_from_state(checkpoint["cycle_controller"])

    # create a dataset given opt.dataset_mode and other options
    dataset = create_dataset(opt)
    print('The number of training images = %d' % len(dataset))

    writer_dict = {
        "writer": SummaryWriter(opt.path_helper['log_path']),
        'controller_steps': start_search_iter * opt.ctrl_step,
        'train_steps': start_search_iter * opt.shared_epoch
    }

    g_loss_history = RunningStats(opt.dynamic_reset_window)
    d_loss_history = RunningStats(opt.dynamic_reset_window)

    dynamic_reset = None
    iteration_range = range(int(start_search_iter), int(opt.max_search_iter))
    for search_iter in tqdm(iteration_range):
        tqdm.write(f"<start search iteration {search_iter}>")
        cycle_controller.reset()

        if search_iter in grow_steps:
            cur_stage = grow_ctrler.cur_stage(search_iter) + 1
            tqdm.write(f'=> grow to stage {cur_stage}')
            # Carry the previous controller's top-k results into the new one.
            top_a = cycle_controller.get_topk_arch_hidden_A()
            prev_archs_A, prev_hiddens_A = top_a
            top_b = cycle_controller.get_topk_arch_hidden_B()
            prev_archs_B, prev_hiddens_B = top_b

            del cycle_controller

            cycle_controller = CycleControllerModel(opt, cur_stage)
            cycle_controller.setup(opt)
            cycle_controller.set(cycle_gan, prev_hiddens_A, prev_hiddens_B,
                                 prev_archs_A, prev_archs_B)

        dynamic_reset = cyclgan_train(opt, cycle_gan, cycle_controller,
                                      dataset, g_loss_history, d_loss_history,
                                      writer_dict)

        controller_train(opt, cycle_gan, cycle_controller, writer_dict)

        if dynamic_reset:
            tqdm.write('re-initialize share GAN')
            del cycle_gan
            cycle_gan = CycleGANModel(opt)
            cycle_gan.setup(opt)

        search_state = {
            'cur_stage': cur_stage,
            'search_iter': search_iter + 1,
            'cycle_gan': cycle_gan.save_networks(epoch=search_iter),
            'cycle_controller':
            cycle_controller.save_networks(epoch=search_iter),
            'path_helper': opt.path_helper,
        }
        save_checkpoint(search_state, False, opt.path_helper['ckpt_path'])

    final_archs_A, _ = cycle_controller.get_topk_arch_hidden_A()
    final_archs_B, _ = cycle_controller.get_topk_arch_hidden_B()
    print(f"discovered archs: {final_archs_A}")
    print(f"discovered archs: {final_archs_B}")
Ejemplo n.º 18
0
def main():
    """Train a GAN starting from stored initial weights.

    Seeds the RNGs, builds the generator and discriminator named by
    ``args.model``, loads ``initial_gen_net.pth`` / ``initial_dis_net.pth``
    from ``args.init_path``, optionally resumes from a checkpoint under
    ``args.load_path``, then trains. Validation runs with the averaged
    generator parameters and a checkpoint is written after every epoch.

    Raises:
        NotImplementedError: if ``args.dataset`` has no FID statistics file.
        AssertionError: if the FID statistics file, the resume path, or
            ``args.exp_name`` (for a fresh run) is missing.
    """
    args = cfg.parse_args()

    # Seed every RNG from the single configured seed.
    seed = args.random_seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    os.environ['PYTHONHASHSEED'] = str(seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    model_namespace = 'models.' + args.model
    gen_net = eval(model_namespace + '.Generator')(args=args)
    dis_net = eval(model_namespace + '.Discriminator')(args=args)

    # Read the stored initial weights onto the CPU first.
    gen_init_file = os.path.join(args.init_path, 'initial_gen_net.pth')
    dis_init_file = os.path.join(args.init_path, 'initial_dis_net.pth')
    initial_gen_net_weight = torch.load(gen_init_file, map_location="cpu")
    initial_dis_net_weight = torch.load(dis_init_file, map_location="cpu")

    gen_net = gen_net.cuda()
    dis_net = dis_net.cuda()

    gen_net.load_state_dict(initial_gen_net_weight)
    dis_net.load_state_dict(initial_dis_net_weight)

    # set optimizer
    betas = (args.beta1, args.beta2)
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()),
        args.g_lr, betas)
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()),
        args.d_lr, betas)
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    dataset_name = args.dataset.lower()
    fid_stat_by_dataset = {
        'cifar10': 'fid_stat/fid_stats_cifar10_train.npz',
        'stl10': 'fid_stat/fid_stats_stl10_train.npz',
    }
    if dataset_name not in fid_stat_by_dataset:
        raise NotImplementedError('no fid stat for %s' % dataset_name)
    fid_stat = fid_stat_by_dataset[dataset_name]
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch *= args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(
            args.max_iter * args.n_critic / len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if not args.load_path:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
    else:
        print('=> resuming from %s' % args.load_path)
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        # Pull the averaged parameters out through a temporary copy.
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info('=> loaded checkpoint %s (epoch %d)' % (checkpoint_file, start_epoch))

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    switch = False  # not read anywhere in this function
    for epoch in range(int(start_epoch), int(args.max_epoch)):
        if args.lr_decay:
            lr_schedulers = (gen_scheduler, dis_scheduler)
        else:
            lr_schedulers = None
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
              gen_avg_param, train_loader, epoch, writer_dict, lr_schedulers)

        # Only a validation that improves on best_fid marks this epoch's
        # checkpoint as the best one.
        is_best = False
        if (epoch and epoch % args.val_freq == 0) or epoch == int(args.max_epoch) - 1:
            # Score the averaged weights, then put the live ones back.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(
                args, fixed_z, fid_stat, gen_net, writer_dict, epoch)
            logger.info('Inception score: %.4f, FID score: %.4f || @ epoch %d.' % (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True

        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        train_state = {
            'epoch': epoch + 1,
            'model': args.model,
            'gen_state_dict': gen_net.state_dict(),
            'dis_state_dict': dis_net.state_dict(),
            'avg_gen_state_dict': avg_gen_net.state_dict(),
            'gen_optimizer': gen_optimizer.state_dict(),
            'dis_optimizer': dis_optimizer.state_dict(),
            'best_fid': best_fid,
            'path_helper': args.path_helper,
            'seed': args.random_seed
        }
        save_checkpoint(train_state, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net
Ejemplo n.º 19
0
# Script fragment: prepares output directories, seeds torch, and loads the
# averaged generator weights from the best checkpoint into G0.
# `args` is defined earlier in the script (outside this excerpt).

# Directory for FID buffers lives under the run's output directory.
args.fid_buffer_dir = os.path.join(args.output_dir, 'fid_buffer')
# IS / FID evaluation is switched off here, so the Inception setup below is
# skipped unless these flags are changed.
args.do_IS = False
args.do_FID = False
# Calls create_dir for each output directory, left to right.
create_dir(args.pth_dir), create_dir(args.img_dir), create_dir(
    args.gamma_dir), create_dir(args.fid_buffer_dir)

# gpu:
torch.manual_seed(args.random_seed)
torch.cuda.manual_seed(args.random_seed)
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is
# initialised - confirm nothing earlier in the script touches the GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

# set tf env
if args.do_IS:
    _init_inception()
if args.do_FID:
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

# G0: generator restored from the best checkpoint
with torch.no_grad():
    # define model
    G0 = Generator(bottom_width=args.bottom_width,
                   gf_dim=args.gf_dim,
                   latent_dim=args.latent_dim).cuda()
    # load ckpt
    pth_path = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth')
    checkpoint = torch.load(pth_path)
    # Use the averaged generator weights stored in the checkpoint.
    G0.load_state_dict(checkpoint['avg_gen_state_dict'])
    # Value of the checkpoint's 'epoch' entry.
    best_dense_epoch = checkpoint['epoch']
Ejemplo n.º 20
0
def main():
    """Train a GAN whose generator is built from a saved genotype.

    The function parses the command-line configuration, prepares the
    TensorFlow Inception graph used for IS/FID evaluation, builds the
    generator and discriminator (wrapped in ``DataParallel``), optionally
    resumes from ``exps/<args.checkpoint>/Model/checkpoint_best.pth``, and
    then runs the training loop.  Every ``args.val_freq`` epochs (and on the
    last epoch) the averaged generator weights are validated; a checkpoint is
    written after every epoch and flagged as best when the FID improves.

    Raises:
        NotImplementedError: if ``args.init_type`` or ``args.dataset`` is not
            one of the supported values.
        AssertionError: if the FID statistics file or the checkpoint to resume
            from does not exist, or if ``args.exp_name`` is empty for a fresh
            run.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set visible GPU ids
    if len(args.gpu_ids) > 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids

    # set TensorFlow environment for evaluation (calculate IS and FID)
    _init_inception()
    inception_path = check_or_download_inception('./tmp/imagenet/')
    create_inception_graph(inception_path)

    # Re-index the requested GPUs as 0..n-1 (one index per comma-separated
    # entry).  When more than one is listed, the first is left out of the
    # training device list: it is dedicated to evaluation (running the
    # Inception model).
    str_ids = args.gpu_ids.split(',')
    args.gpu_ids = list(range(len(str_ids)))
    if len(args.gpu_ids) > 1:
        args.gpu_ids = args.gpu_ids[1:]

    # genotype G
    genotypes_root = os.path.join('exps', args.genotypes_exp, 'Genotypes')
    genotype_G = np.load(os.path.join(genotypes_root, 'latest_G.npy'))

    # import network from genotype
    # NOTE(review): eval() on a command-line value evaluates arbitrary
    # expressions; only run this with a trusted ``args.arch``.
    basemodel_gen = eval('archs.' + args.arch + '.Generator')(args, genotype_G)
    gen_net = torch.nn.DataParallel(
        basemodel_gen, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])
    basemodel_dis = eval('archs.' + args.arch + '.Discriminator')(args)
    dis_net = torch.nn.DataParallel(
        basemodel_dis, device_ids=args.gpu_ids).cuda(args.gpu_ids[0])

    # weight init
    def weights_init(m):
        """Initialise Conv2d / BatchNorm2d modules according to ``args.init_type``."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; the un-suffixed name is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # epoch number for dis_net
    args.max_epoch_D = args.max_epoch_G * args.n_critic
    if args.max_iter_G:
        # An explicit generator iteration budget overrides the epoch count.
        args.max_epoch_D = np.ceil(args.max_iter_G * args.n_critic /
                                   len(train_loader))
    max_iter_D = args.max_epoch_D * len(train_loader)

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, max_iter_D)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, max_iter_D)

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # initial
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.checkpoint:
        # resuming
        print(f'=> resuming from {args.checkpoint}')
        assert os.path.exists(os.path.join('exps', args.checkpoint))
        checkpoint_file = os.path.join('exps', args.checkpoint, 'Model',
                                       'checkpoint_best.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        # Recover the averaged parameters through a temporary copy of the
        # generator so that gen_net keeps its own (non-averaged) weights.
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('exps', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # model size
    logger.info('Param size of G = %fMB', count_parameters_in_MB(gen_net))
    logger.info('Param size of D = %fMB', count_parameters_in_MB(dis_net))
    print_FLOPs(basemodel_gen, (1, args.latent_dim), logger)
    print_FLOPs(basemodel_dis, (1, 3, args.img_size, args.img_size), logger)

    # for visualization
    if args.draw_arch:
        from utils.genotype import draw_graph_G
        draw_graph_G(genotype_G,
                     save=True,
                     file_path=os.path.join(args.path_helper['graph_vis_path'],
                                            'latest_G'))
    # Latent batch passed to every validate() call.
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (100, args.latent_dim)))

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch_D)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
              gen_avg_param, train_loader, epoch, writer_dict, lr_schedulers)

        if epoch % args.val_freq == 0 or epoch == int(args.max_epoch_D) - 1:
            # Validate with the averaged weights, then restore the live ones.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, std, fid_score = validate(args, fixed_z, fid_stat,
                                                       gen_net, writer_dict)
            logger.info(
                f'Inception score mean: {inception_score}, Inception score std: {std}, '
                f'FID score: {fid_score} || @ epoch {epoch}.')
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        # save model
        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.arch,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net
Ejemplo n.º 21
0
def main():
    """Run the agent-driven architecture search over a shared GAN.

    For each search iteration an ``SAC`` agent proposes, layer by layer, an
    architecture for the shared generator.  After each layer is attached the
    GAN is trained with ``train_qin`` and scored with ``get_is``; the
    transition (state, action, reward, next state, mask) is pushed to a
    replay memory and the agent's parameters are updated.  At the end of each
    iteration the GAN and its optimizers are re-created from scratch and the
    agent is saved.

    Raises:
        NotImplementedError: if ``args.init_type`` is not a supported value.
        AssertionError: if the resume path or its checkpoint file does not
            exist, or if ``args.exp_name`` is empty for a fresh run.
    """
    args = cfg.parse_args()

    torch.cuda.manual_seed(args.random_seed)
    print(args)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialise Conv2d / BatchNorm2d modules according to ``args.init_type``."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; the un-suffixed name is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
        args, weights_init)

    # initial
    start_search_iter = 0

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        cur_stage = checkpoint['cur_stage']

        start_search_iter = checkpoint['search_iter']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        prev_archs = checkpoint['prev_archs']
        prev_hiddens = checkpoint['prev_hiddens']

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})'
        )
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])
        prev_archs = None
        prev_hiddens = None

        # set controller && its optimizer
        cur_stage = 0

    # set up data_loader
    dataset = datasets.ImageDataset(args, 2**(cur_stage + 3))
    train_loader = dataset.train
    print(args.rl_num_eval_img, "##############################")
    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'controller_steps': start_search_iter * args.ctrl_step
    }

    g_loss_history = RunningStats(args.dynamic_reset_window)
    d_loss_history = RunningStats(args.dynamic_reset_window)

    # train loop
    Agent = SAC(131)
    print(Agent.alpha)

    memory = ReplayMemory(2560000)
    updates = 0  # running count of agent parameter updates
    # Latest statistics; each entry is overwritten (not appended to) below.
    outinfo = {
        'rewards': [],
        'a_loss': [],
        'critic_error': [],
    }
    Best = False
    Z_NUMPY = None  # never reassigned in the visible code
    WARMUP = True
    update_time = 1
    for search_iter in tqdm(range(int(start_search_iter), 100),
                            desc='search progress'):
        logger.info(f"<start search iteration {search_iter}>")
        if search_iter >= 1:
            WARMUP = False

        ### Define number of layers, currently only support 1->3
        total_layer_num = 3
        ### Different image size for different layers
        ds = [
            datasets.ImageDataset(args, 2**(k + 3))
            for k in range(total_layer_num)
        ]
        train_loaders = [d.train for d in ds]
        last_R = 0.  # Initial reward
        last_fid = 10000  # Initial reward
        last_arch = []

        # Set exploration: from iteration 70 on, pass Best=True to the agent
        # and run more updates per environment step.
        if search_iter > 69:
            update_time = 10
            Best = True
        else:
            Best = False

        # Score the network before any layer is selected to obtain the
        # starting reward, FID and state for this trajectory.
        gen_net.set_stage(-1)
        last_R, last_fid, last_state = get_is(args,
                                              gen_net,
                                              args.rl_num_eval_img,
                                              get_is_score=True)
        for layer in range(total_layer_num):

            # This defines which layer to use as output, for example, if
            # cur_stage==0, then the output will be the first layer output.
            # Set it to 2 if you want the output of the last layer.
            cur_stage = layer
            action = Agent.select_action([layer, last_R, 0.01 * last_fid] +
                                         last_state, Best)
            # Flattened copy of the action entries, stored in replay memory.
            arch = [
                action[0][0], action[0][1], action[1][0], action[1][1],
                action[1][2], action[2][0], action[2][1], action[2][2],
                action[3][0], action[3][1], action[4][0], action[4][1],
                action[5][0], action[5][1]
            ]
            # print(arch)
            # argmax to get int description of arch
            cur_arch = [np.argmax(k) for k in action]
            # Pad the skip option 0=False (for only layer 1 and layer2, not layer0, see builing_blocks.py for why)
            if layer == 0:
                cur_arch = cur_arch[0:4]
            elif layer == 1:
                cur_arch = cur_arch[0:5]
            elif layer == 2:
                # Fold the last two choices into a single code in 0..3.
                if cur_arch[4] + cur_arch[5] == 2:
                    cur_arch = cur_arch[0:4] + [3]
                elif cur_arch[4] + cur_arch[5] == 0:
                    cur_arch = cur_arch[0:4] + [0]
                elif cur_arch[4] == 1 and cur_arch[5] == 0:
                    cur_arch = cur_arch[0:4] + [1]
                else:
                    cur_arch = cur_arch[0:4] + [2]

            # Get the network arch with the new architecture attached.
            last_arch += cur_arch
            gen_net.set_arch(last_arch,
                             layer)  # Set the network, given cur_stage
            # Train network
            dynamic_reset = train_qin(args,
                                      gen_net,
                                      dis_net,
                                      g_loss_history,
                                      d_loss_history,
                                      gen_optimizer,
                                      dis_optimizer,
                                      train_loaders[layer],
                                      cur_stage,
                                      smooth=False,
                                      WARMUP=WARMUP)

            # Get reward, use the jth layer output for generation. (layer 0:j), and the proposed progressive state
            R, fid, state = get_is(args,
                                   gen_net,
                                   args.rl_num_eval_img,
                                   z_numpy=Z_NUMPY)
            # Print exploitation mark, for better readability of the log.
            if Best:
                print("arch:", cur_arch, "Exploitation:", Best)
            else:
                print("arch:", cur_arch, "Exploring...")
            # Proxy reward of the up-to-now (0:j) architecture.
            print("update times:", updates, "step:", layer + 1, "IS:", R,
                  "FID:", fid)
            # mask is 0 on the last layer of the trajectory, 1 otherwise.
            mask = 0 if layer == total_layer_num - 1 else 1
            # NOTE(review): this condition is always true for the loop range
            # above, so every transition is stored.
            if search_iter >= 0:  # warm up
                # Reward = gain in R plus the scaled reduction in FID.
                memory.push([layer, last_R, 0.01 * last_fid] + last_state,
                            arch, R - last_R + 0.01 * (last_fid - fid),
                            [layer + 1, R, 0.01 * fid] + state,
                            mask)  # Append transition to memory

            if len(memory) >= 64:
                # Number of updates per step in environment
                for i in range(update_time):
                    # Update parameters of all the networks
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = Agent.update_parameters(
                        memory, min(len(memory), 256), updates)

                    updates += 1
                    outinfo['critic_error'] = min(critic_1_loss, critic_2_loss)
                    outinfo['entropy'] = ent_loss
                    outinfo['a_loss'] = policy_loss
                print("full batch", outinfo, alpha)
            last_R = R  # next step
            last_fid = fid
            last_state = state
        outinfo['rewards'] = R
        # One more update per search iteration, using the whole memory as the
        # batch size.
        critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = Agent.update_parameters(
            memory, len(memory), updates)
        updates += 1
        outinfo['critic_error'] = min(critic_1_loss, critic_2_loss)
        outinfo['entropy'] = ent_loss
        outinfo['a_loss'] = policy_loss
        print("full batch", outinfo, alpha)
        # Clean up and start a new trajectory from scratch
        del gen_net, dis_net, gen_optimizer, dis_optimizer
        gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
            args, weights_init)
        print(outinfo, len(memory))
        Agent.save_model("test")
        WARMUP = False