Example #1
0
def main():
    """Train a GAN (optionally resuming a previous run) and track IS / FID.

    All settings come from ``cfg.parse_args()``. The generator and
    discriminator classes are looked up under ``models.<args.model>``. When
    ``args.load_path`` is set, training resumes from
    ``<load_path>/Model/checkpoint.pth``; otherwise a new log directory named
    after ``args.exp_name`` is created under ``logs``. A checkpoint is written
    with ``save_checkpoint`` after every epoch, and every ``args.val_freq``
    epochs (and on the last epoch) the averaged generator parameters are
    evaluated with ``validate``.

    Raises:
        NotImplementedError: unknown ``args.init_type`` or a dataset without
            FID statistics.
        FileNotFoundError: the FID statistics file or the resume checkpoint
            does not exist.
        ValueError: a new run is started without ``args.exp_name``.
    """
    args = cfg.parse_args()
    # Seed every RNG used below: weight init draws from torch, and the fixed
    # evaluation noise ``fixed_z`` is sampled with NumPy.
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    # set tf env (Inception graph used for IS / FID)
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE: eval() on a command-line model name -- only run with trusted args.
    gen_net = eval('models.' + args.model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args).cuda()

    # weight init
    def weights_init(m):
        """Initialise Conv2d / BatchNorm2d weights of module ``m`` in place."""
        classname = m.__class__.__name__
        # Substring match: any class whose name contains 'Conv2d' qualifies.
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; ``xavier_uniform`` is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set optimizer (only parameters that require gradients are optimised)
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    # Schedulers are only passed to train() when lr decay is enabled.
    lr_schedulers = (gen_scheduler,
                     dis_scheduler) if args.lr_decay else None

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    if not os.path.exists(fid_stat):
        raise FileNotFoundError(f'fid stat file not found: {fid_stat}')

    # epoch number for dis_net: scaled by n_critic; max_iter, when set,
    # overrides max_epoch.
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        if not os.path.exists(checkpoint_file):
            raise FileNotFoundError(
                f'checkpoint not found: {checkpoint_file}')
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        # Recover the averaged generator parameters through a throwaway copy.
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        if not args.exp_name:
            raise ValueError('args.exp_name is required for a new run')
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
              gen_avg_param, train_loader, epoch, writer_dict, lr_schedulers)

        # Evaluate every val_freq epochs (never at epoch 0) and on the last.
        is_best = False
        if (epoch and epoch % args.val_freq == 0) or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged parameters, then restore the live ones.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict)
            logger.info(
                f'Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}.'
            )
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True

        # Materialise the averaged parameters in a network so they can be saved.
        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net
Example #2
0
def main():
    """Retrain a generator subnet rebuilt from a saved checkpoint.

    Settings come from ``cfg.parse_args()``. The generator is rebuilt by
    ``load_subnet`` from the ``'generator'`` entry of
    ``output/<args.dir>/pth/epoch<args.load_epoch>.pth`` together with the
    weights in ``<args.init_path>/initial_gen_net.pth``. The discriminator is
    a freshly initialised ``models.<args.model>`` Discriminator. Logs and
    checkpoints go to a new directory under ``logs``.

    Raises:
        NotImplementedError: unknown ``args.init_type`` or a dataset other
            than CIFAR-10.
        FileNotFoundError: the subnet checkpoint or the FID statistics file
            does not exist.
    """
    args = cfg.parse_args()
    # Restrict visible GPUs before anything touches CUDA: once the CUDA
    # context exists (e.g. after the first .cuda() call below) changing this
    # variable no longer affects the running process.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)

    # set tf env (Inception graph used for IS / FID)
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialise Conv2d / BatchNorm2d weights of module ``m`` in place."""
        classname = m.__class__.__name__
        # Substring match: any class whose name contains 'Conv2d' qualifies.
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; ``xavier_uniform`` is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    # This generator is replaced by load_subnet() below. It is still built and
    # initialised first because its weight init draws from the RNG before
    # dis_net's init does.
    gen_net = Generator(bottom_width=args.bottom_width,
                        gf_dim=args.gf_dim,
                        latent_dim=args.latent_dim).cuda()
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args).cuda()
    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    initial_gen_net_weight = torch.load(os.path.join(args.init_path,
                                                     'initial_gen_net.pth'),
                                        map_location="cpu")

    exp_str = args.dir
    args.load_path = os.path.join('output', exp_str, 'pth',
                                  'epoch{}.pth'.format(args.load_epoch))

    # state dict:
    if not os.path.exists(args.load_path):
        raise FileNotFoundError(f'checkpoint not found: {args.load_path}')
    checkpoint = torch.load(args.load_path)
    print('=> loaded checkpoint %s' % args.load_path)
    state_dict = checkpoint['generator']
    gen_net = load_subnet(args, state_dict, initial_gen_net_weight).cuda()
    avg_gen_net = deepcopy(gen_net)

    # set optimizer (only parameters that require gradients are optimised)
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    # Schedulers are only passed to train() when lr decay is enabled.
    lr_schedulers = (gen_scheduler,
                     dis_scheduler) if args.lr_decay else None

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    else:
        raise NotImplementedError('no fid stat for %s' % args.dataset.lower())
    if not os.path.exists(fid_stat):
        raise FileNotFoundError(f'fid stat file not found: {fid_stat}')

    # epoch number for dis_net: scaled by n_critic; max_iter, when set,
    # overrides max_epoch.
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial: reseed NumPy so the fixed evaluation noise is reproducible.
    np.random.seed(args.random_seed)
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    start_epoch = 0
    best_fid = 1e4

    args.path_helper = set_log_dir('logs', args.exp_name)
    logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }
    gen_avg_param = copy_params(gen_net)
    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
              gen_avg_param, train_loader, epoch, writer_dict, lr_schedulers)

        # Evaluate every val_freq epochs (never at epoch 0) and on the last.
        is_best = False
        if (epoch and epoch % args.val_freq == 0) or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged parameters, then restore the live ones.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict)
            logger.info(
                'Inception score: %.4f, FID score: %.4f || @ epoch %d.' %
                (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True

        # Reuse one network to hold the averaged parameters for saving.
        avg_gen_net.load_state_dict(gen_net.state_dict())
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
Example #3
0
def main(index, args):
    """Train the encoder / generator / discriminator on one XLA device.

    Args:
        index: process index (presumably supplied by an ``xmp.spawn``-style
            launcher -- confirm); not used here.
        args: parsed configuration namespace.

    When ``args.load_path`` is set, training resumes from
    ``<load_path>/Model/checkpoint.pth``; otherwise a new log directory named
    after ``args.exp_name`` is created in ``<repo>/logs``. A checkpoint is
    saved after every epoch and FID is computed every ``args.val_freq``
    epochs and on the last epoch.

    Raises:
        NotImplementedError: unknown ``args.init_type``.
        FileNotFoundError: the resume checkpoint does not exist.
        ValueError: a new run is started without ``args.exp_name``.
    """
    device = xm.xla_device()

    gen_net = Generator(args).to(device)
    dis_net = Discriminator(args).to(device)
    enc_net = Encoder(args).to(device)

    def weights_init(m):
        """Initialise Conv2d / BatchNorm2d weights of module ``m`` in place."""
        classname = m.__class__.__name__
        # Substring match: any class whose name contains 'Conv2d' qualifies.
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; ``xavier_uniform`` is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)
    enc_net.apply(weights_init)

    # Both autoencoder optimizers update encoder and generator parameters.
    ae_recon_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_recon_lr, (args.beta1, args.beta2))
    ae_reg_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_reg_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(dis_net.parameters(), args.d_lr,
                                     (args.beta1, args.beta2))
    gen_optimizer = torch.optim.Adam(gen_net.parameters(), args.g_lr,
                                     (args.beta1, args.beta2))

    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train
    valid_loader = dataset.valid

    fid_stat = str(pathlib.Path(
        __file__).parent.absolute()) + '/fid_stat/fid_stat_cifar10_test.npz'
    if not os.path.exists(fid_stat):
        download_stat_cifar10_test()

    args.num_epochs = np.ceil(args.num_iter / len(train_loader))

    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    ae_recon_scheduler = LinearLrDecay(ae_recon_optimizer, args.ae_recon_lr, 0,
                                       args.num_iter / 2, args.num_iter)
    ae_reg_scheduler = LinearLrDecay(ae_reg_optimizer, args.ae_reg_lr, 0,
                                     args.num_iter / 2, args.num_iter)
    lr_schedulers = (gen_scheduler, dis_scheduler, ae_recon_scheduler,
                     ae_reg_scheduler)

    # initial
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        if not os.path.exists(checkpoint_file):
            raise FileNotFoundError(
                f'checkpoint not found: {checkpoint_file}')
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        enc_net.load_state_dict(checkpoint['enc_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        ae_recon_optimizer.load_state_dict(checkpoint['ae_recon_optimizer'])
        ae_reg_optimizer.load_state_dict(checkpoint['ae_reg_optimizer'])
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        if not args.exp_name:
            raise ValueError('args.exp_name is required for a new run')
        logs_dir = str(pathlib.Path(__file__).parent.parent) + '/logs'
        args.path_helper = set_log_dir(logs_dir, args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.num_epochs)),
                      desc='total progress'):
        # A ParallelLoader iterates its underlying loader only once, so a
        # fresh one is needed every epoch or later epochs receive no data.
        para_loader = pl.ParallelLoader(train_loader, [device])
        train(device, args, gen_net, dis_net, enc_net, gen_optimizer,
              dis_optimizer, ae_recon_optimizer, ae_reg_optimizer, para_loader,
              epoch, writer_dict, lr_schedulers)

        # Evaluate every val_freq epochs (never at epoch 0) and on the last.
        is_best = False
        if (epoch and epoch % args.val_freq == 0) or epoch == args.num_epochs - 1:
            fid_score = validate(args, fid_stat, gen_net, writer_dict,
                                 valid_loader)
            logger.info(f'FID score: {fid_score} || @ epoch {epoch}.')
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'enc_state_dict': enc_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'ae_recon_optimizer': ae_recon_optimizer.state_dict(),
                'ae_reg_optimizer': ae_reg_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
Example #4
0
def main():
    """Retrain a pruned GAN from a checkpoint, training with D conv masks.

    Settings come from ``cfg.parse_args()``. The checkpoint at
    ``args.load_path`` is applied to the generator with ``pruning_generate``
    and loaded into the discriminator, whose zero pattern is turned into one
    mask per Conv2d layer. Unless ``args.finetune_G`` is set, the generator's
    state is overwritten with the entries ``rewind_weight`` returns from
    ``<args.init_path>/initial_gen_net.pth``; unless ``args.finetune_D`` is
    set, the discriminator is reset to ``initial_dis_net.pth`` (then
    re-masked). Training uses ``train_with_mask``, or ``train_with_mask_kd``
    with a copy of the checkpoint discriminator in eval mode when
    ``args.use_kd_D`` is set.

    Raises:
        NotImplementedError: unknown ``args.init_type`` or a dataset without
            FID statistics.
        FileNotFoundError: the checkpoint or the FID statistics file does not
            exist.
    """
    args = cfg.parse_args()
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)
    # set tf env (Inception graph used for IS / FID)
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # weight init
    def weights_init(m):
        """Initialise Conv2d / BatchNorm2d weights of module ``m`` in place."""
        classname = m.__class__.__name__
        # Substring match: any class whose name contains 'Conv2d' qualifies.
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; ``xavier_uniform`` is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    # import network
    gen_net = eval('models.' + args.model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args).cuda()
    gen_net.apply(weights_init)
    dis_net.apply(weights_init)
    avg_gen_net = deepcopy(gen_net)
    initial_gen_net_weight = torch.load(os.path.join(args.init_path,
                                                     'initial_gen_net.pth'),
                                        map_location="cpu")
    initial_dis_net_weight = torch.load(os.path.join(args.init_path,
                                                     'initial_dis_net.pth'),
                                        map_location="cpu")

    # set optimizer (only parameters that require gradients are optimised)
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    # Schedulers are only passed to the train functions when lr decay is on.
    lr_schedulers = (gen_scheduler,
                     dis_scheduler) if args.lr_decay else None

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/fid_stats_stl10_train.npz'
    else:
        raise NotImplementedError('no fid stat for %s' % args.dataset.lower())
    if not os.path.exists(fid_stat):
        raise FileNotFoundError(f'fid stat file not found: {fid_stat}')

    # epoch number for dis_net: scaled by n_critic; max_iter, when set,
    # overrides max_epoch.
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    start_epoch = 0
    best_fid = 1e4

    print('=> resuming from %s' % args.load_path)
    checkpoint_file = args.load_path
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f'checkpoint not found: {checkpoint_file}')
    checkpoint = torch.load(checkpoint_file)
    pruning_generate(gen_net, checkpoint['gen_state_dict'])
    dis_net.load_state_dict(checkpoint['dis_state_dict'])

    # Recover one mask per discriminator Conv2d (keyed by module index) from
    # the checkpoint: a weight is kept iff it is non-zero. The threshold must
    # be exactly 0 -- thresholding at the smallest non-zero magnitude with
    # `gt` would also drop that non-zero weight.
    # Assumes every Conv2d in dis_net exposes `weight_orig` (e.g. spectral
    # norm) -- TODO confirm for all supported models.
    thre = 0.0
    total = 0
    pruned = 0
    masks = OrderedDict()
    print('Pruning threshold: {}'.format(thre))
    for k, m in enumerate(dis_net.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m.weight_orig.data.abs().clone()
            mask = weight_copy.gt(thre).float()
            masks[k] = mask
            total += mask.numel()
            pruned = pruned + mask.numel() - torch.sum(mask)
            m.weight_orig.data.mul_(mask)
            print(
                'layer index: {:d} \t total params: {:d} \t remaining params: {:d}'
                .format(k, mask.numel(), int(torch.sum(mask))))
    print('Total conv params: {}, Pruned conv params: {}, Pruned ratio: {}'.
          format(total, pruned, pruned / total))

    pruning_generate(avg_gen_net, checkpoint['gen_state_dict'])
    see_remain_rate(gen_net)

    if not args.finetune_G:
        gen_weight = gen_net.state_dict()
        gen_orig_weight = rewind_weight(initial_gen_net_weight,
                                        gen_weight.keys())
        gen_weight.update(gen_orig_weight)
        gen_net.load_state_dict(gen_weight)
    gen_avg_param = copy_params(gen_net)

    if args.finetune_D:
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
    else:
        dis_net.load_state_dict(initial_dis_net_weight)

    # Re-apply the recovered masks to whichever D weights were just loaded.
    for k, m in enumerate(dis_net.modules()):
        if isinstance(m, nn.Conv2d):
            m.weight_orig.data.mul_(masks[k])

    orig_dis_net = eval('models.' + args.model +
                        '.Discriminator')(args=args).cuda()
    orig_dis_net.load_state_dict(checkpoint['dis_state_dict'])
    orig_dis_net.eval()

    args.path_helper = set_log_dir('logs',
                                   args.exp_name + "_{}".format(args.percent))
    logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        see_remain_rate(gen_net)
        see_remain_rate_orig(dis_net)
        if not args.use_kd_D:
            train_with_mask(args, gen_net, dis_net, gen_optimizer,
                            dis_optimizer, gen_avg_param, train_loader, epoch,
                            writer_dict, masks, lr_schedulers)
        else:
            train_with_mask_kd(args, gen_net, dis_net, orig_dis_net,
                               gen_optimizer, dis_optimizer, gen_avg_param,
                               train_loader, epoch, writer_dict, masks,
                               lr_schedulers)

        # Evaluate every val_freq epochs (never at epoch 0) and on the last.
        is_best = False
        if (epoch and epoch % args.val_freq == 0) or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged parameters, then restore the live ones.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict, epoch)
            logger.info(
                'Inception score: %.4f, FID score: %.4f || @ epoch %d.' %
                (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True

        # Reuse one network to hold the averaged parameters for saving.
        avg_gen_net.load_state_dict(gen_net.state_dict())
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
Example #5
0
def main():
    """Prune a trained generator, reset it to its initial weights, retrain.

    Settings come from ``cfg.parse_args()``. The generator from
    ``<args.load_path>/Model/checkpoint.pth`` is pruned with
    ``pruning_generate`` (keep ratio ``1 - args.percent``, method
    ``args.pruning_method``), and its state is then overwritten with the
    entries ``rewind_weight`` returns from the freshly initialised weights.
    The discriminator starts from the checkpoint when ``args.finetune_D`` is
    set, otherwise from its own initial weights. With ``args.use_kd_D``,
    ``train_kd`` also receives the checkpoint discriminator in eval mode.

    Raises:
        NotImplementedError: unknown ``args.init_type`` or a dataset without
            FID statistics.
        FileNotFoundError: the checkpoint or the FID statistics file does not
            exist.
    """
    args = cfg.parse_args()
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)

    # set tf env (Inception graph used for IS / FID)
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    gen_net = eval('models.' + args.model + '.Generator')(args=args)
    dis_net = eval('models.' + args.model + '.Discriminator')(args=args)

    # weight init
    def weights_init(m):
        """Initialise Conv2d / BatchNorm2d weights of module ``m`` in place."""
        if isinstance(m, nn.Conv2d):
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; ``xavier_uniform`` is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    gen_net = gen_net.cuda()
    dis_net = dis_net.cuda()

    avg_gen_net = deepcopy(gen_net)
    # Snapshot the initial weights; deepcopy keeps them independent of the
    # live networks.
    initial_gen_net_weight = deepcopy(gen_net.state_dict())
    initial_dis_net_weight = deepcopy(dis_net.state_dict())
    # set optimizer (only parameters that require gradients are optimised)
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    # Schedulers are only passed to the train functions when lr decay is on.
    lr_schedulers = (gen_scheduler,
                     dis_scheduler) if args.lr_decay else None

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/fid_stats_stl10_train.npz'
    else:
        raise NotImplementedError('no fid stat for %s' % args.dataset.lower())
    if not os.path.exists(fid_stat):
        raise FileNotFoundError(f'fid stat file not found: {fid_stat}')

    # epoch number for dis_net: scaled by n_critic; max_iter, when set,
    # overrides max_epoch.
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial: reseed NumPy so the fixed evaluation noise is reproducible.
    np.random.seed(args.random_seed)
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))

    start_epoch = 0
    best_fid = 1e4

    args.path_helper = set_log_dir('logs',
                                   args.exp_name + "_{}".format(args.percent))
    logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    print('=> resuming from %s' % args.load_path)
    checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f'checkpoint not found: {checkpoint_file}')
    checkpoint = torch.load(checkpoint_file)
    gen_net.load_state_dict(checkpoint['gen_state_dict'])

    # Reseed before each call so both networks get the same pruning draw.
    torch.manual_seed(args.random_seed)
    pruning_generate(gen_net, (1 - args.percent), args.pruning_method)
    torch.manual_seed(args.random_seed)
    pruning_generate(avg_gen_net, (1 - args.percent), args.pruning_method)
    see_remain_rate(gen_net)

    # NOTE(review): both branches below overwrite dis_net's weights with
    # load_state_dict, so this re-initialisation only advances the RNG --
    # confirm that is the intended effect of second_seed.
    if args.second_seed:
        dis_net.apply(weights_init)
    if args.finetune_D:
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
    else:
        dis_net.load_state_dict(initial_dis_net_weight)

    gen_weight = gen_net.state_dict()
    gen_orig_weight = rewind_weight(initial_gen_net_weight, gen_weight.keys())
    assert id(gen_weight) != id(gen_orig_weight)
    gen_weight.update(gen_orig_weight)
    gen_net.load_state_dict(gen_weight)
    gen_avg_param = copy_params(gen_net)

    orig_dis_net = None
    if args.use_kd_D:
        orig_dis_net = eval('models.' + args.model +
                            '.Discriminator')(args=args).cuda()
        # nn.Module has no `load`; state dicts are restored via load_state_dict.
        orig_dis_net.load_state_dict(checkpoint['dis_state_dict'])
        orig_dis_net.eval()
    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        see_remain_rate(gen_net)
        if not args.use_kd_D:
            train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
                  gen_avg_param, train_loader, epoch, writer_dict,
                  lr_schedulers)
        else:
            train_kd(args, gen_net, dis_net, orig_dis_net, gen_optimizer,
                     dis_optimizer, gen_avg_param, train_loader, epoch,
                     writer_dict, lr_schedulers)

        # Evaluate every val_freq epochs (never at epoch 0) and on the last.
        is_best = False
        if (epoch and epoch % args.val_freq == 0) or epoch == int(
                args.max_epoch) - 1:
            # Evaluate with the averaged parameters, then restore the live ones.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict, epoch)
            logger.info(
                'Inception score: %.4f, FID score: %.4f || @ epoch %d.' %
                (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True

        # Reuse one network to hold the averaged parameters for saving.
        avg_gen_net.load_state_dict(gen_net.state_dict())
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
Example #6
0
def main():
    """Train a GAN whose generator uses the fixed searched architecture ``args.arch``.

    Seeds the RNGs, loads the dataset at ``cur_img_size=8``, derives
    ``args.max_epoch`` / ``args.max_iter``, builds DataParallel-wrapped
    generator and discriminator, creates the optimizers, optionally resumes
    from the checkpoint file ``args.load_path``, then alternates training and
    FID validation (with the averaged generator weights) every epoch, writing
    a checkpoint each epoch and flagging the best-FID one.

    Raises:
        NotImplementedError: unknown ``args.init_type`` / ``args.optimizer``,
            or a dataset with no FID statistics.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.backends.cudnn.deterministic = True

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # epoch number for dis_net
    dataset = datasets.ImageDataset(args, cur_img_size=8)
    train_loader = dataset.train
    if args.max_iter:
        # np.ceil yields a float; the training loop casts it back to int.
        args.max_epoch = np.ceil(args.max_iter / len(train_loader))
    else:
        args.max_iter = args.max_epoch * len(train_loader)
    args.max_epoch = args.max_epoch * args.n_critic

    # import network
    # NOTE: model names come from the command-line config and are resolved
    # with eval; never pass untrusted input here.
    gen_net = eval('models.' + args.gen_model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.dis_model +
                   '.Discriminator')(args=args).cuda()
    gen_net.set_arch(args.arch, cur_stage=2)

    # weight init
    def weights_init(m):
        """Initialise Conv2d / BatchNorm2d layers according to args.init_type."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    gpu_ids = list(range(torch.cuda.device_count()))
    gen_net = torch.nn.DataParallel(gen_net.to("cuda:0"), device_ids=gpu_ids)
    dis_net = torch.nn.DataParallel(dis_net.to("cuda:0"), device_ids=gpu_ids)

    gen_net.module.cur_stage = 0
    dis_net.module.cur_stage = 0
    gen_net.module.alpha = 1.
    dis_net.module.alpha = 1.

    # set optimizer
    gen_params = filter(lambda p: p.requires_grad, gen_net.parameters())
    dis_params = filter(lambda p: p.requires_grad, dis_net.parameters())
    if args.optimizer == "adam":
        gen_optimizer = torch.optim.Adam(gen_params, args.g_lr,
                                         (args.beta1, args.beta2))
        dis_optimizer = torch.optim.Adam(dis_params, args.d_lr,
                                         (args.beta1, args.beta2))
    elif args.optimizer == "adamw":
        gen_optimizer = AdamW(gen_params, args.g_lr, weight_decay=args.wd)
        # The discriminator uses its own learning rate, consistent with the
        # Adam branch and with dis_scheduler below.
        dis_optimizer = AdamW(dis_params, args.d_lr, weight_decay=args.wd)
    else:
        raise NotImplementedError(f'unknown optimizer {args.optimizer}')
    # NOTE(review): these schedulers are never passed to train() in this
    # function; confirm whether lr decay is intended here.
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.fid_stat is not None:
        fid_stat = args.fid_stat
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (64, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # resume or create a new log dir
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        checkpoint_file = args.load_path
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        gen_avg_param = checkpoint['gen_avg_param']
        cur_stage = cur_stages(start_epoch, args)
        gen_net.module.cur_stage = cur_stage
        dis_net.module.cur_stage = cur_stage
        gen_net.module.alpha = 1.
        dis_net.module.alpha = 1.

        args.path_helper = checkpoint['path_helper']

    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)

    logger = create_logger(args.path_helper['log_path'])
    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    def return_states():
        """Snapshot everything needed to resume after the current ``epoch``."""
        return {
            'epoch': epoch,
            'best_fid': best_fid,
            'gen_state_dict': gen_net.state_dict(),
            'dis_state_dict': dis_net.state_dict(),
            'gen_optimizer': gen_optimizer.state_dict(),
            'dis_optimizer': dis_optimizer.state_dict(),
            'gen_avg_param': gen_avg_param,
            'path_helper': args.path_helper,
        }

    # train loop
    # int(): args.max_epoch may be a numpy float (see np.ceil above), which
    # range() rejects.
    for epoch in range(start_epoch + 1, int(args.max_epoch)):
        train(
            args,
            gen_net,
            dis_net,
            gen_optimizer,
            dis_optimizer,
            gen_avg_param,
            train_loader,
            epoch,
            writer_dict,
            fixed_z,
        )

        # Validate with the averaged generator weights, then restore.
        backup_param = copy_params(gen_net)
        load_params(gen_net, gen_avg_param)
        fid_score = validate(
            args,
            fixed_z,
            fid_stat,
            epoch,
            gen_net,
            writer_dict,
        )
        logger.info(f'FID score: {fid_score} || @ epoch {epoch}.')
        load_params(gen_net, backup_param)

        # Epoch 1 always seeds the best score; otherwise keep the lowest FID.
        # best_fid starts at 1e4 or at the value restored from the checkpoint.
        is_best = epoch == 1 or fid_score < best_fid
        if is_best:
            best_fid = fid_score

        # A checkpoint is written every epoch.
        save_checkpoint(return_states(),
                        is_best,
                        args.path_helper['ckpt_path'],
                        filename=f'checkpoint_epoch_{epoch}.pth')
def main():
    """Train a GAN built from searched genotypes and report the best IS/FID.

    The generator is built from ``genotypes.<args.arch_gen>``. The
    discriminator is built from ``genotypes.<args.arch_dis>``, or from ``args``
    alone when ``args.dis`` names a plain ``Discriminator`` class. Training
    runs for ``args.max_epoch`` epochs. Validation with the averaged generator
    weights happens every ``args.val_freq`` epochs and on the last epoch. A
    sample grid is saved every ``args.image_every`` epochs and a checkpoint
    every epoch.

    Raises:
        NotImplementedError: unknown ``args.init_type`` or a dataset with no
            FID statistics.
    """
    args = cfg_train.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE: class/genotype names come from the command-line config and are
    # resolved with eval; never pass untrusted input here.
    genotype_gen = eval('genotypes.%s' % args.arch_gen)
    gen_net = eval('models.' + args.gen_model + '.' + args.gen)(
        args, genotype_gen).cuda()
    if 'Discriminator' not in args.dis:
        genotype_dis = eval('genotypes.%s' % args.arch_dis)
        dis_net = eval('models.' + args.dis_model + '.' + args.dis)(
            args, genotype_dis).cuda()
    else:
        dis_net = eval('models.' + args.dis_model + '.' +
                       args.dis)(args=args).cuda()

    # weight init
    def weights_init(m):
        """Initialise Conv2d / BatchNorm2d layers according to args.init_type."""
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # In-place variant; the un-suffixed name is deprecated.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()), args.d_lr,
        (args.beta1, args.beta2))
    # Decay towards 1% of the base lr. The positional arguments are presumably
    # (start_lr, end_lr, decay_start_step, decay_end_step); the start step is
    # a hard-coded 260 epochs' worth of iterations. TODO confirm.
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, args.g_lr * 0.01,
                                  260 * len(train_loader),
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, args.d_lr * 0.01,
                                  260 * len(train_loader),
                                  args.max_iter * args.n_critic)

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.dataset.lower() == 'mnist':
        # NOTE(review): MNIST reuses the STL-10 statistics file, so FID values
        # for MNIST are not comparable to a true MNIST reference; confirm.
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    fixed_z_sample = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4
    best_fid_epoch = 0
    is_with_fid = 0
    std_with_fid = 0.
    best_is = 0
    # Initialised so the final summary log works even if IS never improves.
    best_std = 0.
    best_is_epoch = 0
    fid_with_is = 0

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # calculate the FLOPs and param count of G
    dummy_z = torch.randn(args.gen_batch_size, args.latent_dim).cuda()
    flops, params = profile(gen_net, inputs=(dummy_z, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info('FLOPs is {}, param count is {}'.format(flops, params))

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None

        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer,
              gen_avg_param, train_loader, epoch, writer_dict, args.consistent,
              lr_schedulers)

        if (epoch and epoch % args.val_freq == 0) or epoch == int(
                args.max_epoch) - 1:
            # Validate with the averaged generator weights, then restore.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, std, fid_score = validate(args,
                                                       fixed_z,
                                                       fid_stat,
                                                       gen_net,
                                                       writer_dict,
                                                       args.path_helper,
                                                       search=False)
            logger.info(
                f'Inception score: {inception_score}, FID score: {fid_score}+-{std} || @ epoch {epoch}.'
            )
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                best_fid_epoch = epoch
                is_with_fid = inception_score
                std_with_fid = std
                is_best = True
            else:
                is_best = False
            if inception_score > best_is:
                best_is = inception_score
                best_std = std
                fid_with_is = fid_score
                best_is_epoch = epoch
        else:
            is_best = False

        # save generated images
        if epoch % args.image_every == 0:
            # Sampling only: no autograd graph is needed.
            with torch.no_grad():
                gen_images = gen_net(fixed_z_sample).mul_(127.5).add_(
                    127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to(
                        'cpu', torch.uint8).numpy()
            fig = plt.figure()
            # NOTE: the grid is 10x10, so args.eval_batch_size must be <= 100.
            grid = ImageGrid(fig, 111, nrows_ncols=(10, 10), axes_pad=0)
            for x in range(args.eval_batch_size):
                grid[x].imshow(gen_images[x])
                grid[x].set_xticks([])
                grid[x].set_yticks([])
            plt.savefig(
                os.path.join(args.path_helper['sample_path'],
                             "epoch_{}.png".format(epoch)))
            plt.close()

        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'gen_model': args.gen_model,
                'dis_model': args.dis_model,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'avg_gen_state_dict': avg_gen_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net
    logger.info(
        'best_is is {}+-{}@{} epoch, fid is {}, best_fid is {}@{}, is is {}+-{}'
        .format(best_is, best_std, best_is_epoch, fid_with_is, best_fid,
                best_fid_epoch, is_with_fid, std_with_fid))
Exemple #8
0
def main():
    """Run the GAN architecture search (network weights and architecture alphas).

    Reads the module-level ``args``, ``path_helper`` and ``dataset`` (a mapping
    from dataset name to a ``dset`` class name). Each epoch does the following:
    steps the cosine lr scheduler, optionally grows the networks, trains the
    weights, and trains the architecture parameters via ``Architect_gen`` once
    past ``args.fix_alphas_epochs``. It also optionally evaluates IS/FID and
    saves the latest weights. The discovered genotypes are logged at the end.

    Raises:
        NotImplementedError: a dataset with no FID statistics or no
            data-loading branch below.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # Create tensorboard logger
    writer_dict = {
        'writer': SummaryWriter(path_helper['log']),
        'inner_steps': 0,
        'val_steps': 0,
        'valid_global_steps': 0
    }

    # set tf env
    if args.eval:
        _init_inception()
        inception_path = check_or_download_inception(None)
        create_inception_graph(inception_path)

    # fid_stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.dataset.lower() == 'mnist':
        # NOTE(review): MNIST reuses the STL-10 statistics file; confirm.
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    else:
        raise NotImplementedError(f'no fid stat for {args.dataset.lower()}')
    assert os.path.exists(fid_stat)

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    FID_best = 1e+4
    IS_best = 0.
    FID_best_epoch = 0
    IS_best_epoch = 0

    # build gen and dis
    # NOTE: class names come from the command-line config and are resolved
    # with eval; never pass untrusted input here.
    gen = eval('model_search_gan.' + args.gen)(args)
    gen = gen.cuda()
    dis = eval('model_search_gan.' + args.dis)(args)
    dis = dis.cuda()
    logging.info("generator param size = %fMB",
                 utils.count_parameters_in_MB(gen))
    logging.info("discriminator param size = %fMB",
                 utils.count_parameters_in_MB(dis))
    if args.parallel:
        gen = nn.DataParallel(gen)
        dis = nn.DataParallel(dis)
    # nn.DataParallel does not expose the wrapped module's custom attributes
    # and methods (alphas, genotype(), set_gumbel(), cur_stage). Those are
    # accessed through the underlying module.
    gen_core = gen.module if isinstance(gen, nn.DataParallel) else gen
    dis_core = dis.module if isinstance(dis, nn.DataParallel) else dis

    # resume training
    if args.load_path != '':
        gen.load_state_dict(
            torch.load(
                os.path.join(args.load_path, 'model',
                             'weights_gen_' + 'last' + '.pt')))
        dis.load_state_dict(
            torch.load(
                os.path.join(args.load_path, 'model',
                             'weights_dis_' + 'last' + '.pt')))

    # set optimizer for parameters W of gen and dis
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen.parameters()), args.g_lr,
        (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis.parameters()), args.d_lr,
        (args.beta1, args.beta2))

    # set moving average parameters for generator
    gen_avg_param = copy_params(gen)

    def build_queues(img_size):
        """Return (train_queue, valid_queue) over the training set resized to img_size.

        The first ``args.train_portion`` fraction of the training set feeds
        train_queue and the remainder feeds valid_queue.
        """
        train_transform, _ = eval('utils.' + '_data_transforms_' +
                                  args.dataset + '_resize')(args, img_size)
        if args.dataset == 'cifar10':
            train_data = eval('dset.' + dataset[args.dataset])(
                root=args.data,
                train=True,
                download=True,
                transform=train_transform)
        elif args.dataset == 'stl10':
            train_data = eval('dset.' + dataset[args.dataset])(
                root=args.data, download=True, transform=train_transform)
        else:
            raise NotImplementedError(
                f'no data loading for {args.dataset}')

        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(args.train_portion * num_train))
        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.gen_batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[:split]),
            pin_memory=True,
            num_workers=2)
        valid_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.gen_batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                indices[split:num_train]),
            pin_memory=True,
            num_workers=2)
        return train_queue, valid_queue

    img_size = 8 if args.grow else args.img_size
    train_queue, valid_queue = build_queues(img_size)
    logging.info('length of train_queue is {}'.format(len(train_queue)))
    logging.info('length of valid_queue is {}'.format(len(valid_queue)))

    max_iter = len(train_queue) * args.epochs

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        gen_optimizer, float(args.epochs), eta_min=args.learning_rate_min)
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  max_iter * args.n_critic)

    architect = Architect_gen(gen, dis, args, 'duality_gap_with_mm', logging)

    gen_core.set_gumbel(args.use_gumbel)
    dis_core.set_gumbel(args.use_gumbel)
    for epoch in range(args.start_epoch + 1, args.epochs):
        scheduler.step()
        # get_last_lr() returns the lr the scheduler just set; get_lr() called
        # outside step() re-applies the update formula and is not that value.
        lr = scheduler.get_last_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)
        logging.info('epoch %d gen_lr %e', epoch, args.g_lr)
        logging.info('epoch %d dis_lr %e', epoch, args.d_lr)

        genotype_gen = gen_core.genotype()
        logging.info('gen_genotype = %s', genotype_gen)

        if 'Discriminator' not in args.dis:
            genotype_dis = dis_core.genotype()
            logging.info('dis_genotype = %s', genotype_dis)

        print('up_1: ', F.softmax(gen_core.alphas_up_1, dim=-1))
        print('up_2: ', F.softmax(gen_core.alphas_up_2, dim=-1))
        print('up_3: ', F.softmax(gen_core.alphas_up_3, dim=-1))

        # determine whether use gumbel or not
        # NOTE(review): gumbel is already set to args.use_gumbel before the
        # loop, so this re-set has no visible effect; confirm whether the
        # pre-loop call was meant to disable it.
        if epoch == args.fix_alphas_epochs + 1:
            gen_core.set_gumbel(args.use_gumbel)
            dis_core.set_gumbel(args.use_gumbel)

        # grow discriminator and generator
        if args.grow:
            dis_core.cur_stage = grow_ctrl(epoch, args.grow_epoch)
            gen_core.cur_stage = grow_ctrl(epoch, args.grow_epoch)
            if args.restrict_dis_grow and dis_core.cur_stage > 1:
                dis_core.cur_stage = 1
                print('debug: dis.cur_stage is {}'.format(dis_core.cur_stage))
            if epoch in args.grow_epoch:
                # Rebuild the loaders at the new resolution for the
                # configured dataset (not hard-coded to CIFAR-10).
                train_queue, valid_queue = build_queues(
                    2**(gen_core.cur_stage + 3))
        else:
            gen_core.cur_stage = 2
            dis_core.cur_stage = 2

        # training parameters
        train_gan_parameter(args, train_queue, gen, dis, gen_optimizer,
                            dis_optimizer, gen_avg_param, logging, writer_dict)

        # training alphas
        if epoch > args.fix_alphas_epochs:
            train_gan_alpha(args, train_queue, valid_queue, gen, dis,
                            architect, gen_optimizer, gen_avg_param, epoch, lr,
                            writer_dict, logging)

        # evaluate the IS and FID
        if args.eval and epoch % args.eval_every == 0:
            inception_score, std, fid_score = validate(args, fixed_z, fid_stat,
                                                       gen, writer_dict,
                                                       path_helper)
            logging.info('epoch {}: IS is {}+-{}, FID is {}'.format(
                epoch, inception_score, std, fid_score))
            if inception_score > IS_best:
                IS_best = inception_score
                IS_best_epoch = epoch
            if fid_score < FID_best:
                FID_best = fid_score
                FID_best_epoch = epoch
            logging.info('best epoch {}: IS is {}'.format(
                IS_best_epoch, IS_best))
            logging.info('best epoch {}: FID is {}'.format(
                FID_best_epoch, FID_best))

        utils.save(
            gen,
            os.path.join(path_helper['model'],
                         'weights_gen_{}.pt'.format('last')))
        utils.save(
            dis,
            os.path.join(path_helper['model'],
                         'weights_dis_{}.pt'.format('last')))

    genotype_gen = gen_core.genotype()
    if 'Discriminator' not in args.dis:
        genotype_dis = dis_core.genotype()
    logging.info('best epoch {}: IS is {}'.format(IS_best_epoch, IS_best))
    logging.info('best epoch {}: FID is {}'.format(FID_best_epoch, FID_best))
    logging.info('final discovered gen_arch is {}'.format(genotype_gen))
    if 'Discriminator' not in args.dis:
        logging.info('final discovered dis_arch is {}'.format(genotype_dis))
Exemple #9
0
def main():
    """Train a GAN whose networks start from weights saved under ``args.init_path``.

    Loads ``initial_gen_net.pth`` / ``initial_dis_net.pth`` (no weight init is
    applied here). It optionally resumes from
    ``<args.load_path>/Model/checkpoint.pth``, then trains. Validation with the
    averaged generator weights runs every ``args.val_freq`` epochs and on the
    last epoch. A checkpoint is written every epoch.

    Raises:
        NotImplementedError: a dataset with no FID statistics.
    """
    args = cfg.parse_args()
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    # NOTE: the running interpreter's hash seed is fixed at startup. Setting
    # this only influences Python interpreters launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network
    # NOTE: the model name comes from the command-line config and is resolved
    # with eval; never pass untrusted input here.
    gen_net = eval('models.'+args.model+'.Generator')(args=args)
    dis_net = eval('models.'+args.model+'.Discriminator')(args=args)

    initial_gen_net_weight = torch.load(os.path.join(args.init_path, 'initial_gen_net.pth'), map_location="cpu")
    initial_dis_net_weight = torch.load(os.path.join(args.init_path, 'initial_dis_net.pth'), map_location="cpu")

    gen_net = gen_net.cuda()
    dis_net = dis_net.cuda()

    gen_net.load_state_dict(initial_gen_net_weight)
    dis_net.load_state_dict(initial_dis_net_weight)

    # set optimizer
    gen_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gen_net.parameters()),
                                     args.g_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, dis_net.parameters()),
                                     args.d_lr, (args.beta1, args.beta2))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/fid_stats_stl10_train.npz'
    else:
        raise NotImplementedError('no fid stat for %s' % args.dataset.lower())
    assert os.path.exists(fid_stat)

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (25, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print('=> resuming from %s' % args.load_path)
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info('=> loaded checkpoint %s (epoch %d)' % (checkpoint_file, start_epoch))
    else:
        # create new log dir
        assert args.exp_name
        args.path_helper = set_log_dir('logs', args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    # Loop invariants: computed once instead of on every epoch.
    lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None
    last_epoch = int(args.max_epoch) - 1
    for epoch in range(int(start_epoch), int(args.max_epoch)):
        train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, epoch, writer_dict,
              lr_schedulers)

        is_best = False
        if (epoch and epoch % args.val_freq == 0) or epoch == last_epoch:
            # Validate with the averaged generator weights, then restore.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat, gen_net, writer_dict, epoch)
            logger.info('Inception score: %.4f, FID score: %.4f || @ epoch %d.' % (inception_score, fid_score, epoch))
            load_params(gen_net, backup_param)
            is_best = fid_score < best_fid
            if is_best:
                best_fid = fid_score

        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'gen_state_dict': gen_net.state_dict(),
            'dis_state_dict': dis_net.state_dict(),
            'avg_gen_state_dict': avg_gen_net.state_dict(),
            'gen_optimizer': gen_optimizer.state_dict(),
            'dis_optimizer': dis_optimizer.state_dict(),
            'best_fid': best_fid,
            'path_helper': args.path_helper,
            'seed': args.random_seed
        }, is_best, args.path_helper['ckpt_path'])
        del avg_gen_net
Exemple #10
0
def main():
    """Train a searched-architecture GAN (from ``models_search``) at stage 2.

    Parses options with ``cfg.parse_args()``, builds the generator and
    discriminator, optionally resumes from ``args.load_path``, then runs the
    epoch loop: one ``train`` call per epoch, periodic Inception-score / FID
    validation on the averaged generator weights (``gen_avg_param``), and a
    checkpoint written every epoch.

    Raises:
        NotImplementedError: unknown ``args.init_type`` or a dataset without
            FID statistics.
        FileNotFoundError: the FID statistics file, ``args.load_path`` or the
            resume checkpoint does not exist.
        ValueError: a fresh run (no ``args.load_path``) without ``args.exp_name``.
    """
    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)

    # set tf env
    _init_inception()
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    # import network; the class is resolved by name inside models_search.
    # NOTE(review): eval on CLI-provided names -- fine for a local trusted CLI,
    # but getattr-based lookup would be safer if args ever come from elsewhere.
    gen_net = eval("models_search." + args.gen_model +
                   ".Generator")(args=args).cuda()
    dis_net = eval("models_search." + args.dis_model +
                   ".Discriminator")(args=args).cuda()

    # Pin both networks to stage 2 of the searched architecture.
    gen_net.set_arch(args.arch, cur_stage=2)
    dis_net.cur_stage = 2

    # weight init
    def weights_init(m):
        """Initialise Conv2d weights per ``args.init_type``; BatchNorm2d ~ N(1, 0.02)."""
        classname = m.__class__.__name__
        if "Conv2d" in classname:
            if args.init_type == "normal":
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == "orth":
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == "xavier_uniform":
                # xavier_uniform (no underscore) is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.0)
            else:
                raise NotImplementedError("{} unknown inital type".format(
                    args.init_type))
        elif "BatchNorm2d" in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # set optimizer
    gen_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, gen_net.parameters()),
        args.g_lr,
        (args.beta1, args.beta2),
    )
    dis_optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, dis_net.parameters()),
        args.d_lr,
        (args.beta1, args.beta2),
    )
    # Linear decay over the total number of discriminator iterations.
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0,
                                  args.max_iter * args.n_critic)

    # set up data_loader
    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train

    # fid stat
    if args.dataset.lower() == "cifar10":
        fid_stat = "fid_stat/fid_stats_cifar10_train.npz"
    elif args.dataset.lower() == "stl10":
        fid_stat = "fid_stat/stl10_train_unlabeled_fid_stats_48.npz"
    else:
        raise NotImplementedError(f"no fid stat for {args.dataset.lower()}")
    # Explicit check instead of assert: asserts are stripped under `python -O`.
    if not os.path.exists(fid_stat):
        raise FileNotFoundError(f"fid stat file not found: {fid_stat}")

    # epoch number for dis_net
    args.max_epoch = args.max_epoch * args.n_critic
    if args.max_iter:
        args.max_epoch = np.ceil(args.max_iter * args.n_critic /
                                 len(train_loader))

    # initial
    fixed_z = torch.cuda.FloatTensor(
        np.random.normal(0, 1, (25, args.latent_dim)))
    gen_avg_param = copy_params(gen_net)
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print(f"=> resuming from {args.load_path}")
        if not os.path.exists(args.load_path):
            raise FileNotFoundError(f"load path not found: {args.load_path}")
        checkpoint_file = os.path.join(args.load_path, "Model",
                                       "checkpoint.pth")
        if not os.path.exists(checkpoint_file):
            raise FileNotFoundError(
                f"checkpoint not found: {checkpoint_file}")
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint["epoch"]
        best_fid = checkpoint["best_fid"]
        gen_net.load_state_dict(checkpoint["gen_state_dict"])
        dis_net.load_state_dict(checkpoint["dis_state_dict"])
        gen_optimizer.load_state_dict(checkpoint["gen_optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        # Recover the averaged generator weights via a throwaway network copy.
        avg_gen_net = deepcopy(gen_net)
        avg_gen_net.load_state_dict(checkpoint["avg_gen_state_dict"])
        gen_avg_param = copy_params(avg_gen_net)
        del avg_gen_net

        args.path_helper = checkpoint["path_helper"]
        logger = create_logger(args.path_helper["log_path"])
        logger.info(
            f"=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})")
    else:
        # create new log dir
        if not args.exp_name:
            raise ValueError("args.exp_name is required for a new run")
        args.path_helper = set_log_dir("logs", args.exp_name)
        logger = create_logger(args.path_helper["log_path"])

    logger.info(args)
    writer_dict = {
        "writer": SummaryWriter(args.path_helper["log_path"]),
        "train_global_steps": start_epoch * len(train_loader),
        "valid_global_steps": start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)),
                      desc="total progress"):
        lr_schedulers = (gen_scheduler,
                         dis_scheduler) if args.lr_decay else None
        train(
            args,
            gen_net,
            dis_net,
            gen_optimizer,
            dis_optimizer,
            gen_avg_param,
            train_loader,
            epoch,
            writer_dict,
            lr_schedulers,
        )

        # Validate every val_freq epochs (skipping epoch 0) and on the last one.
        if epoch and epoch % args.val_freq == 0 or epoch == int(
                args.max_epoch) - 1:
            # Evaluate the averaged weights, then restore the live weights.
            backup_param = copy_params(gen_net)
            load_params(gen_net, gen_avg_param)
            inception_score, fid_score = validate(args, fixed_z, fid_stat,
                                                  gen_net, writer_dict)
            logger.info(
                f"Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}."
            )
            load_params(gen_net, backup_param)
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        avg_gen_net = deepcopy(gen_net)
        load_params(avg_gen_net, gen_avg_param)
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "gen_model": args.gen_model,
                "dis_model": args.dis_model,
                "gen_state_dict": gen_net.state_dict(),
                "dis_state_dict": dis_net.state_dict(),
                "avg_gen_state_dict": avg_gen_net.state_dict(),
                "gen_optimizer": gen_optimizer.state_dict(),
                "dis_optimizer": dis_optimizer.state_dict(),
                "best_fid": best_fid,
                "path_helper": args.path_helper,
            },
            is_best,
            args.path_helper["ckpt_path"],
        )
        del avg_gen_net
Exemple #11
0
def main():
    """Train a generator/discriminator pair (e.g. TransGAN) on all visible GPUs.

    Parses options with ``cfg.parse_args()``, seeds the RNGs, builds both
    networks from ``models``, wraps them in ``DataParallel``, then for each
    epoch trains, computes a FID score with ``validate`` and saves a "best"
    checkpoint when the score improves (only on epoch 0 or after epoch 30).
    A final checkpoint is written after the loop.

    Raises:
        NotImplementedError: unknown ``args.init_type``, unknown
            ``args.optimizer`` or no FID statistics for the dataset.
        FileNotFoundError: the FID statistics file does not exist.
    """
    args = cfg.parse_args()
    # Seed every RNG in play: torch CPU (used during module construction),
    # all CUDA devices, numpy and random.
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.backends.cudnn.deterministic = True

    # import network; args.gen_model is e.g. "TransGAN_8_8_1".
    # NOTE(review): eval on CLI-provided names -- acceptable for a trusted CLI.
    gen_net = eval('models.' + args.gen_model + '.Generator')(args=args).cuda()
    dis_net = eval('models.' + args.dis_model + '.Discriminator')(args=args).cuda()
    gen_net.set_arch(args.arch, cur_stage=2)

    def weights_init(m):
        """Initialise Conv2d weights per ``args.init_type``; BatchNorm2d ~ N(1, 0.02)."""
        classname = m.__class__.__name__
        if 'Conv2d' in classname:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # xavier_uniform (no underscore) is a deprecated alias.
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(args.init_type))
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)

    # Replicate both networks over every visible GPU. The wrapped model lives
    # in `.module`, so state_dict() keys carry a "module." prefix.
    gpu_ids = list(range(torch.cuda.device_count()))
    gen_net = torch.nn.DataParallel(gen_net.to("cuda:0"), device_ids=gpu_ids)
    dis_net = torch.nn.DataParallel(dis_net.to("cuda:0"), device_ids=gpu_ids)

    if args.optimizer == "adam":
        gen_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gen_net.parameters()),
                                         args.g_lr, (args.beta1, args.beta2))
        dis_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, dis_net.parameters()),
                                         args.d_lr, (args.beta1, args.beta2))
    elif args.optimizer == "adamw":
        gen_optimizer = AdamW(filter(lambda p: p.requires_grad, gen_net.parameters()),
                              args.g_lr, weight_decay=args.wd)
        # The discriminator uses its own learning rate, matching the Adam
        # branch and dis_scheduler below.
        dis_optimizer = AdamW(filter(lambda p: p.requires_grad, dis_net.parameters()),
                              args.d_lr, weight_decay=args.wd)
    else:
        # Without this, the schedulers below would hit an unbound name.
        raise NotImplementedError('{} unknown optimizer type'.format(args.optimizer))
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, args.max_iter * args.n_critic)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, args.max_iter * args.n_critic)

    # fid stat
    if args.dataset.lower() == 'cifar10':
        fid_stat = 'fid_stat/fid_stats_cifar10_train.npz'
    elif args.dataset.lower() == 'stl10':
        fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz'
    elif args.fid_stat is not None:
        fid_stat = args.fid_stat
    else:
        raise NotImplementedError(f"no fid stat for {args.dataset.lower()}")
    # Explicit check instead of assert: asserts are stripped under `python -O`.
    if not os.path.exists(fid_stat):
        raise FileNotFoundError(f"fid stat file not found: {fid_stat}")

    dataset = datasets.ImageDataset(args, cur_img_size=8)
    train_loader = dataset.train

    writer_dict = {
        'writer': SummaryWriter(),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    def make_checkpoint(ckpt_epoch, best_fid):
        """Snapshot the epoch, best FID so far and both networks' weights."""
        return {
            'epoch': ckpt_epoch,
            'best_fid': best_fid,
            'gen_state_dict': gen_net.state_dict(),
            'dis_state_dict': dis_net.state_dict(),
        }

    best = 1e4
    epoch = 0  # keeps the final checkpoint well-defined if the loop never runs
    for epoch in range(args.max_epoch):
        train(args, gen_net=gen_net, dis_net=dis_net,
              gen_optimizer=gen_optimizer, dis_optimizer=dis_optimizer,
              gen_avg_param=None, train_loader=train_loader,
              epoch=epoch, writer_dict=writer_dict, fixed_z=None,
              schedulers=[gen_scheduler, dis_scheduler])

        score = validate(args, None, fid_stat, epoch, gen_net, writer_dict, clean_dir=True)
        print(f'FID score: {score} - best FID score: {best} || @ epoch {epoch}.')
        # Only epoch 0 and epochs after 30 are eligible for a "best" save.
        if (epoch == 0 or epoch > 30) and score < best:
            # Update best first so the saved checkpoint records its own score.
            best = score
            save_checkpoint(make_checkpoint(epoch, best), is_best=True,
                            output_dir=args.output_dir)
            print("Saved Latest Model!")

    # Final snapshot after training, whether or not it improved on the best.
    score = validate(args, None, fid_stat, epoch, gen_net, writer_dict, clean_dir=True)
    is_best = score < best
    if is_best:
        best = score
    save_checkpoint(make_checkpoint(epoch, best), is_best=is_best, output_dir=args.output_dir)