import os
import time

import torch
import torchvision.utils as vutils
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # or: from tensorboardX import SummaryWriter

# Project-local helpers (the model classes, get_data, get_loss, get_diff,
# save_checkpoint, weights_init_d/g, BCGD2, ...) are assumed importable
# from the surrounding repo.


def get_model(model_name, z_dim):
    if model_name == 'dc':
        D = GoodDiscriminator()
        G = GoodGenerator()
    elif model_name == 'dcBN':
        D = GoodDiscriminatorbn()
        G = GoodGenerator()
    elif model_name == 'dcD':
        D = GoodDiscriminatord()
        G = GoodGenerator()
    elif model_name == 'DCGAN':
        D = DC_discriminator()
        G = DC_generator(z_dim=z_dim)
    elif model_name == 'Resnet':
        D = ResNet32Discriminator(n_in=3, num_filters=128, batchnorm=True)
        G = ResNet32Generator(z_dim=z_dim, num_filters=128, batchnorm=True)
    elif model_name == 'ResnetWBN':
        D = ResNet32Discriminator(n_in=3, num_filters=128, batchnorm=False)
        G = ResNet32Generator(z_dim=z_dim, num_filters=128, batchnorm=True)
    elif model_name == 'DCGAN-WBN':
        D = DC_discriminatorW()
        G = DC_generator(z_dim=z_dim)
    elif model_name == 'dcSN':
        D = GoodSNDiscriminator()
        G = GoodGenerator()
    elif model_name == 'mnist':
        D = dc_D()
        G = dc_G(z_dim=z_dim)
    else:
        raise ValueError('No model matching name: %s' % model_name)
    print(model_name)
    return D, G
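

# A minimal usage sketch for get_model; the model name and z_dim below are
# illustrative values, not repo defaults:
#   D, G = get_model('DCGAN', z_dim=96)
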
def generate_data(model_weight, path, z_dim=96, device='cpu'):
    chk = torch.load(model_weight, map_location=device)
    print('Load checkpoint from %s' % model_weight)
    dataset = get_data(dataname='MNIST', path='../datas/mnist')
    fixed_z = torch.randn((500, z_dim), device=device)
    fixed_D = dc_D().to(device)
    fixed_G = dc_G(z_dim=z_dim).to(device)
    fixed_D.load_state_dict(chk['D'])
    fixed_G.load_state_dict(chk['G'])
    real_loader = DataLoader(dataset=dataset, batch_size=500, shuffle=True,
                             num_workers=4)
    real_set = next(iter(real_loader))
    real_set = real_set[0].to(device)
    with torch.no_grad():
        fake_set = fixed_G(fixed_z)
        fixed_real_d = fixed_D(real_set)
        fixed_fake_d = fixed_D(fake_set)
        fixed_vec = torch.cat([fixed_real_d, fixed_fake_d])
    if not os.path.exists('figs/select'):
        os.makedirs('figs/select')
    torch.save({'real_set': real_set,
                'fake_set': fake_set,
                'real_d': fixed_real_d,
                'fake_d': fixed_fake_d,
                'pred_vec': fixed_vec}, path)
    for i in range(5):
        j = i * 100
        vutils.save_image(real_set[j: j + 100], 'figs/select/real_set_%d.png' % i, nrow=10, normalize=True)
        vutils.save_image(fake_set[j: j + 100], 'figs/select/fake_set_%d.png' % i, nrow=10, normalize=True)
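

# Hypothetical invocation of generate_data (both paths are placeholders):
#   generate_data(model_weight='checkpoints/mnist/SGD-0.010_5000.pth',
#                 path='figs/select/fixed_data.pth',
#                 z_dim=96, device='cuda:0')
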
Example #3
def get_model(model_name, z_dim, configs=None):
    if model_name == 'dc':
        D = GoodDiscriminator()
        G = GoodGenerator()
    elif model_name == 'dcBN':
        D = GoodDiscriminatorbn()
        G = GoodGenerator()
    elif model_name == 'dcD':
        D = GoodDiscriminatord()
        G = GoodGenerator()
    elif model_name == 'DCGAN':
        D = DC_discriminator()
        G = DC_generator(z_dim=z_dim)
    elif model_name == 'Resnet32':
        D = ResNet32Discriminator(n_in=3, num_filters=128, batchnorm=True)
        G = ResNet32Generator(z_dim=z_dim, num_filters=128, batchnorm=True)
    elif model_name == 'Resnet':
        D = ResNetDiscriminator(in_channel=configs['image_channel'],
                                insize=configs['image_size'],
                                num_filters=configs['feature_num'],
                                batchnorm=configs['batchnorm_d'])
        G = ResNetGenerator(z_dim=z_dim,
                            outsize=configs['image_size'],
                            num_filters=configs['feature_num'],
                            batchnorm=configs['batchnorm_g'])
    elif model_name == 'ResnetWBN':
        D = ResNet32Discriminator(n_in=3, num_filters=128, batchnorm=False)
        G = ResNet32Generator(z_dim=z_dim, num_filters=128, batchnorm=True)
    elif model_name == 'DCGAN-WBN':
        D = DC_discriminatorW()
        G = DC_generator(z_dim=z_dim)
    elif model_name == 'dcSN':
        D = GoodSNDiscriminator()
        G = GoodGenerator()
    elif model_name == 'mnist':
        D = dc_D()
        G = dc_G(z_dim=z_dim)
    elif model_name == 'dc32':
        D = dcD32()
        G = dcG32(z_dim=z_dim)
    elif model_name == 'DCGANs':
        D = DCGAN_D(insize=configs['image_size'],
                    channel_num=configs['image_channel'],
                    feature_num=configs['feature_num'],
                    n_extra_layers=configs['n_extra_layers'])
        G = DCGAN_G(outsize=configs['image_size'],
                    z_dim=z_dim,
                    nc=configs['image_channel'],
                    feature_num=configs['feature_num'],
                    n_extra_layers=configs['n_extra_layers'])
    else:
        raise ValueError('No model matching name: %s' % model_name)
    print(model_name)
    return D, G
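

# The 'Resnet' and 'DCGANs' branches above read these keys from configs.
# A sketch of a matching dict; the values are illustrative, not repo defaults:
example_configs = {
    'image_size': 32,
    'image_channel': 3,
    'feature_num': 128,
    'batchnorm_d': True,
    'batchnorm_g': True,
    'n_extra_layers': 0,
}
# D, G = get_model('DCGANs', z_dim=96, configs=example_configs)
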
Example #4
def train_mnist(epoch_num=10,
                show_iter=100,
                logdir='test',
                model_weight=None,
                load_d=False,
                load_g=False,
                compare_path=None,
                info_time=100,
                run_select=None,
                device='cpu'):
    lr_d = 0.01
    lr_g = 0.01
    batchsize = 128
    z_dim = 96
    print('MNIST, discriminator lr: %.3f, generator lr: %.3f' % (lr_d, lr_g))
    dataset = get_data(dataname='MNIST', path='../datas/mnist')
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D = dc_D().to(device)
    G = dc_G(z_dim=z_dim).to(device)
    D.apply(weights_init_d)
    G.apply(weights_init_g)
    if model_weight is not None:
        chk = torch.load(model_weight, map_location=device)
        if load_d:
            D.load_state_dict(chk['D'])
            print('Load D from %s' % model_weight)
        if load_g:
            G.load_state_dict(chk['G'])
            print('Load G from %s' % model_weight)
    if compare_path is not None:
        discriminator = dc_D().to(device)
        compare_chk = torch.load(compare_path, map_location=device)
        discriminator.load_state_dict(compare_chk['D'])
        model_vec = torch.cat(
            [p.contiguous().view(-1) for p in discriminator.parameters()])
        print('Load discriminator from %s' % compare_path)
    if run_select is not None:
        fixed_data = torch.load(run_select, map_location=device)
        real_set = fixed_data['real_set']
        fake_set = fixed_data['fake_set']
        real_d = fixed_data['real_d']
        fake_d = fixed_data['fake_d']
        fixed_vec = fixed_data['pred_vec']
        print('Load fixed data set from %s' % run_select)

    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    writer = SummaryWriter(log_dir='logs/%s/%s_%.3f' %
                           (logdir, current_time, lr_d))
    d_optimizer = SGD(D.parameters(), lr=lr_d)
    g_optimizer = SGD(G.parameters(), lr=lr_g)
    timer = time.time()
    count = 0
    fixed_noise = torch.randn((64, z_dim), device=device)
    for e in range(epoch_num):
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            fake_x_c = fake_x.clone().detach()
            # update generator
            d_fake = D(fake_x)

            writer.add_scalars('Discriminator output', {
                'Generated image': d_fake.mean().item(),
                'Real image': d_real.mean().item()
            },
                               global_step=count)
            G_loss = get_loss(name='JSD', g_loss=True, d_fake=d_fake)
            g_optimizer.zero_grad()
            G_loss.backward()
            g_optimizer.step()
            gg = torch.norm(torch.cat(
                [p.grad.contiguous().view(-1) for p in G.parameters()]),
                            p=2)

            d_fake_c = D(fake_x_c)
            D_loss = get_loss(name='JSD',
                              g_loss=False,
                              d_real=d_real,
                              d_fake=d_fake_c)
            if compare_path is not None and count % info_time == 0:
                diff = get_diff(net=D, model_vec=model_vec)
                writer.add_scalar('Distance from checkpoint',
                                  diff.item(),
                                  global_step=count)
                if run_select is not None:
                    with torch.no_grad():
                        d_real_set = D(real_set)
                        d_fake_set = D(fake_set)
                        diff_real = torch.norm(d_real_set - real_d, p=2)
                        diff_fake = torch.norm(d_fake_set - fake_d, p=2)
                        d_vec = torch.cat([d_real_set, d_fake_set])
                        diff = torch.norm(d_vec.sub_(fixed_vec), p=2)
                        writer.add_scalars('L2 norm of pred difference', {
                            'Total': diff.item(),
                            'real set': diff_real.item(),
                            'fake set': diff_fake.item()
                        },
                                           global_step=count)
            d_optimizer.zero_grad()
            D_loss.backward()
            d_optimizer.step()

            gd = torch.norm(torch.cat(
                [p.grad.contiguous().view(-1) for p in D.parameters()]),
                            p=2)

            writer.add_scalars('Loss', {
                'D_loss': D_loss.item(),
                'G_loss': G_loss.item()
            },
                               global_step=count)
            writer.add_scalars('Grad', {
                'D grad': gd.item(),
                'G grad': gg.item()
            },
                               global_step=count)
            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter: %d, D_loss: %.5f, G_loss: %.5f, time: %.3fs' %
                      (count, D_loss.item(), G_loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s/' % logdir
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % count,
                                      normalize=True)
                save_checkpoint(path=logdir,
                                name='SGD-%.3f_%d.pth' % (lr_d, count),
                                D=D,
                                G=G)
            count += 1
    writer.close()
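

# get_loss is defined elsewhere in the repo. A plausible sketch of its JSD
# branch, assuming D outputs raw logits (consistent with the (d_real > 0)
# accuracy check in train_d below); an assumption, not the repo's actual code:
import torch.nn.functional as F


def jsd_loss_sketch(g_loss, d_real=None, d_fake=None):
    if g_loss:
        # non-saturating generator loss: push D(fake) towards the "real" label
        return F.binary_cross_entropy_with_logits(d_fake,
                                                  torch.ones_like(d_fake))
    # discriminator loss: real -> 1, fake -> 0
    return (F.binary_cross_entropy_with_logits(d_real,
                                               torch.ones_like(d_real)) +
            F.binary_cross_entropy_with_logits(d_fake,
                                               torch.zeros_like(d_fake)))
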
Example #5
def train_d(epoch_num=10,
            logdir='test',
            optim='SGD',
            loss_name='JSD',
            show_iter=500,
            model_weight=None,
            load_d=False,
            load_g=False,
            compare_path=None,
            info_time=100,
            run_select=None,
            device='cpu'):
    lr_d = 0.001
    lr_g = 0.01
    batchsize = 128
    z_dim = 96
    print('discriminator lr: %.3f' % lr_d)
    dataset = get_data(dataname='MNIST', path='../datas/mnist')
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D = dc_D().to(device)
    G = dc_G(z_dim=z_dim).to(device)
    D.apply(weights_init_d)
    G.apply(weights_init_g)
    if model_weight is not None:
        chk = torch.load(model_weight, map_location=device)
        if load_d:
            D.load_state_dict(chk['D'])
            print('Load D from %s' % model_weight)
        if load_g:
            G.load_state_dict(chk['G'])
            print('Load G from %s' % model_weight)
    if compare_path is not None:
        discriminator = dc_D().to(device)
        compare_chk = torch.load(compare_path, map_location=device)
        discriminator.load_state_dict(compare_chk['D'])
        model_vec = torch.cat(
            [p.contiguous().view(-1) for p in discriminator.parameters()])
        print('Load discriminator from %s' % compare_path)
    if run_select is not None:
        fixed_data = torch.load(run_select, map_location=device)
        real_set = fixed_data['real_set']
        fake_set = fixed_data['fake_set']
        real_d = fixed_data['real_d']
        fake_d = fixed_data['fake_d']
        fixed_vec = fixed_data['pred_vec']
        print('Load fixed data set from %s' % run_select)
    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    # writer = SummaryWriter(log_dir='logs/%s/%s_%.3f' % (logdir, current_time, lr_d))
    if optim == 'SGD':
        d_optimizer = SGD(D.parameters(), lr=lr_d)
        print('Optimizer SGD')
    else:
        d_optimizer = BCGD2(max_params=G.parameters(),
                            min_params=D.parameters(),
                            lr_max=lr_g,
                            lr_min=lr_d,
                            update_max=False,
                            device=device,
                            collect_info=True)
        print('Optimizer BCGD2')
    timer = time.time()
    count = 0
    d_losses = []
    g_losses = []
    for e in range(epoch_num):
        tol_correct = 0
        tol_dloss = 0
        tol_gloss = 0
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            z = torch.randn((real_x.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            D_loss = get_loss(name=loss_name,
                              g_loss=False,
                              d_real=d_real,
                              d_fake=d_fake)
            tol_dloss += D_loss.item() * real_x.shape[0]
            G_loss = get_loss(name=loss_name,
                              g_loss=True,
                              d_real=d_real,
                              d_fake=d_fake)
            tol_gloss += G_loss.item() * fake_x.shape[0]
            if compare_path is not None and count % info_time == 0:
                diff = get_diff(net=D, model_vec=model_vec)
                # writer.add_scalar('Distance from checkpoint', diff.item(), global_step=count)
                if run_select is not None:
                    with torch.no_grad():
                        d_real_set = D(real_set)
                        d_fake_set = D(fake_set)
                        diff_real = torch.norm(d_real_set - real_d, p=2)
                        diff_fake = torch.norm(d_fake_set - fake_d, p=2)
                        d_vec = torch.cat([d_real_set, d_fake_set])
                        diff = torch.norm(d_vec.sub_(fixed_vec), p=2)
                        # writer.add_scalars('L2 norm of pred difference',
                        #                    {'Total': diff.item(),
                        #                     'real set': diff_real.item(),
                        #                     'fake set': diff_fake.item()},
                        #                    global_step=count)
            d_optimizer.zero_grad()
            if optim == 'SGD':
                D_loss.backward()
                d_optimizer.step()
                gd = torch.norm(torch.cat(
                    [p.grad.contiguous().view(-1) for p in D.parameters()]),
                                p=2)
                # note: G has no optimizer here and its grads are never
                # zeroed, so this norm accumulates across iterations
                gg = torch.norm(torch.cat(
                    [p.grad.contiguous().view(-1) for p in G.parameters()]),
                                p=2)
            else:
                d_optimizer.step(D_loss)
                cgdInfo = d_optimizer.get_info()
                gd = cgdInfo['grad_y']
                gg = cgdInfo['grad_x']
                # writer.add_scalars('Grad', {'update': cgdInfo['update']}, global_step=count)
            tol_correct += ((d_real > 0).sum().item() +
                            (d_fake < 0).sum().item())
            # writer.add_scalars('Loss', {'D_loss': D_loss.item(),
            #                             'G_loss': G_loss.item()}, global_step=count)
            # writer.add_scalars('Grad', {'D grad': gd,
            #                             'G grad': gg}, global_step=count)
            # writer.add_scalars('Discriminator output', {'Generated image': d_fake.mean().item(),
            #                                             'Real image': d_real.mean().item()},
            #                    global_step=count)
            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter: %d, D_loss: %.5f, G_loss: %.5f, time: %.3fs' %
                      (count, D_loss.item(), G_loss.item(), time_cost))
                timer = time.time()
                save_checkpoint(path=logdir,
                                name='FixG-%.3f_%d.pth' % (lr_d, count),
                                D=D,
                                G=G)
            count += 1
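

# get_diff is another project-local helper. A minimal sketch of what it is
# assumed to compute: the L2 distance between the net's current parameters
# and the flattened checkpoint vector built above (an assumption, not the
# repo's actual code):
def get_diff_sketch(net, model_vec):
    current_vec = torch.cat(
        [p.contiguous().view(-1) for p in net.parameters()])
    return torch.norm(current_vec - model_vec, p=2)
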
Example #6
def train_g(epoch_num=10,
            logdir='test',
            loss_name='JSD',
            show_iter=500,
            model_weight=None,
            load_d=False,
            load_g=False,
            device='cpu'):
    lr_d = 0.01
    lr_g = 0.01
    batchsize = 128
    z_dim = 96
    print('MNIST, discriminator lr: %.3f, generator lr: %.3f' % (lr_d, lr_g))
    dataset = get_data(dataname='MNIST', path='../datas/mnist')
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D = dc_D().to(device)
    G = dc_G(z_dim=z_dim).to(device)
    D.apply(weights_init_d)
    G.apply(weights_init_g)
    if model_weight is not None:
        chk = torch.load(model_weight, map_location=device)
        if load_d:
            D.load_state_dict(chk['D'])
            print('Load D from %s' % model_weight)
        if load_g:
            G.load_state_dict(chk['G'])
            print('Load G from %s' % model_weight)
    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    # writer = SummaryWriter(log_dir='logs/%s/%s_%.3f' % (logdir, current_time, lr_g))
    d_optimizer = SGD(D.parameters(), lr=lr_d)
    g_optimizer = SGD(G.parameters(), lr=lr_g)
    timer = time.time()
    count = 0
    for e in range(epoch_num):
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            z = torch.randn((real_x.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            D_loss = get_loss(name=loss_name,
                              g_loss=False,
                              d_real=d_real,
                              d_fake=d_fake)
            G_loss = get_loss(name=loss_name,
                              g_loss=True,
                              d_real=d_real,
                              d_fake=d_fake)
            d_optimizer.zero_grad()
            g_optimizer.zero_grad()
            G_loss.backward()
            g_optimizer.step()
            print('D_loss: {}, G_loss: {}'.format(D_loss.item(),
                                                  G_loss.item()))
            # writer.add_scalars('Loss', {'D_loss': D_loss.item(),
            #                             'G_loss': G_loss.item()},
            #                    global_step=count)
            # writer.add_scalars('Discriminator output', {'Generated image': d_fake.mean().item(),
            #                                             'Real image': d_real.mean().item()},
            #                    global_step=count)
            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter: %d, D_loss: %.5f, G_loss: %.5f, time: %.3fs' %
                      (count, D_loss.item(), G_loss.item(), time_cost))
                timer = time.time()
                save_checkpoint(path=logdir,
                                name='FixD-%.3f_%d.pth' % (lr_d, count),
                                D=D,
                                G=G)
            count += 1
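

# Hypothetical invocation of train_g: only the generator is updated, so a
# pre-trained discriminator loaded via load_d stays fixed (hence the 'FixD'
# checkpoint names above). The checkpoint path is a placeholder:
#   train_g(epoch_num=5, logdir='fixD', loss_name='JSD',
#           model_weight='checkpoints/mnist/SGD-0.010_5000.pth',
#           load_d=True, load_g=False, device='cuda:0')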