Code Example #1
File: seq_gan.py Project: lijianxue77/NF-GAN
def run_seq_gan():
    config = Config()
    n_samples = config.get('n_samples')
    batch_size = config.get('batch_size')
    gen_embedding_dim = config.get('gen_embedding_dim')
    gen_hidden_dim = config.get('gen_hidden_dim')
    dis_embedding_dim = config.get('dis_embedding_dim')
    dis_hidden_dim = config.get('dis_hidden_dim')
    dataset_features = config.get('dataset_features')
    dataset_dtypes = config.get('dataset_dtypes')
    generated_features = config.get('generated_features')
    service_list = config.get('service_list')
    protocol_service_dict = config.get('protocol_service_dict')
    service_port_dict = config.get('service_port_dict')
    file_path = config.get('file_path')
    CUDA = torch.cuda.is_available()

    dataset = Traffic_Dataset(file_path,
                              dataset_features,
                              dataset_dtypes,
                              generated_features,
                              batch_size=batch_size,
                              transform=build_input_indices)
    vocab_dim = dataset.vocabulary_length
    max_seq_len = dataset.max_seq_length
    train_epochs = 100

    g = Generator(gen_embedding_dim, gen_hidden_dim, vocab_dim, max_seq_len,
                  CUDA)
    d = Discriminator(dis_embedding_dim, dis_hidden_dim, vocab_dim,
                      max_seq_len, CUDA)
    if CUDA:
        g.cuda()
        d.cuda()
    g_opt = optim.Adam(g.parameters())
    d_opt = optim.Adagrad(d.parameters())

    pre_training(g, d, g_opt, d_opt, dataset, n_samples, batch_size, CUDA)
    training(g, d, g_opt, d_opt, dataset, train_epochs, n_samples, batch_size,
             CUDA, service_list, protocol_service_dict, service_port_dict)
    visualize(dataset_features, generated_features)
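The entry point is not shown in the excerpt; presumably the module is executed with the usual guard, along the lines of:

if __name__ == '__main__':
    run_seq_gan()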
Code Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batch_size',
                        type=int,
                        default=50,
                        help='input batch size')
    parser.add_argument('--nz',
                        type=int,
                        default=100,
                        help='size of the latent z vector')
    parser.add_argument('--nch_g', type=int, default=64)
    parser.add_argument('--nch_d', type=int, default=64)
    parser.add_argument('--n_epoch',
                        type=int,
                        default=200,
                        help='number of epochs to train for')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='learning rate, default=0.0002')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.5,
                        help='beta1 for adam. default=0.5')
    parser.add_argument('--outf',
                        default='./result_lsgan',
                        help='folder to output images and model checkpoints')

    opt = parser.parse_args()
    print(opt)

    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    # Fix the random seeds
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    # Load the STL-10 training and test datasets
    trainset = dset.STL10(root='./dataset/stl10_root',
                          download=True,
                          split='train+unlabeled',
                          transform=transforms.Compose([
                              transforms.RandomResizedCrop(64,
                                                           scale=(88 / 96,
                                                                  1.0),
                                                           ratio=(1., 1.)),
                              transforms.RandomHorizontalFlip(),
                              transforms.ColorJitter(brightness=0.05,
                                                     contrast=0.05,
                                                     saturation=0.05,
                                                     hue=0.05),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5),
                                                   (0.5, 0.5, 0.5)),
                          ]))  # labels are not used, so load 'train+unlabeled', which also includes the unlabeled images
    testset = dset.STL10(root='./dataset/stl10_root',
                         download=True,
                         split='test',
                         transform=transforms.Compose([
                             transforms.RandomResizedCrop(64,
                                                          scale=(88 / 96, 1.0),
                                                          ratio=(1., 1.)),
                             transforms.RandomHorizontalFlip(),
                             transforms.ColorJitter(brightness=0.05,
                                                    contrast=0.05,
                                                    saturation=0.05,
                                                    hue=0.05),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5),
                                                  (0.5, 0.5, 0.5)),
                         ]))
    dataset = trainset + testset  # combine the STL-10 training and test sets into one training dataset

    # Create a data loader over the training data
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    # Get the device to train on; use the GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device:', device)

    # Generator G: generates fake images from random vectors
    netG = Generator(nz=opt.nz, nch_g=opt.nch_g).to(device)
    netG.apply(weights_init)  # initialize with the weights_init function
    print(netG)

    # Discriminator D: classifies images as real or fake
    netD = Discriminator(nch_d=opt.nch_d).to(device)
    netD.apply(weights_init)
    print(netD)

    criterion = nn.MSELoss()  # the loss function is mean squared error

    # Set up the optimizers
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)  # for discriminator D
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)  # for generator G

    fixed_noise = torch.randn(opt.batch_size, opt.nz, 1, 1,
                              device=device)  # fixed noise for monitoring generated samples

    # Training loop
    for epoch in range(opt.n_epoch):
        for itr, data in enumerate(dataloader):
            real_image = data[0].to(device)  # real images
            sample_size = real_image.size(0)  # number of images in the batch
            noise = torch.randn(sample_size, opt.nz, 1, 1,
                                device=device)  # sample noise from a normal distribution

            real_target = torch.full((sample_size, ), 1.,
                                     device=device)  # target value 1 for real images
            fake_target = torch.full((sample_size, ), 0.,
                                     device=device)  # target value 0 for fake images

            ############################
            # Update discriminator D
            ###########################
            netD.zero_grad()  # reset gradients

            output = netD(real_image)  # discriminator output for the real images
            errD_real = criterion(output, real_target)  # loss on the real images
            D_x = output.mean().item()

            fake_image = netG(noise)  # generate fake images from noise with generator G

            output = netD(fake_image.detach())  # discriminator output for the fake images
            errD_fake = criterion(output, fake_target)  # loss on the fake images
            D_G_z1 = output.mean().item()

            errD = errD_real + errD_fake  # total discriminator loss
            errD.backward()  # backpropagation
            optimizerD.step()  # update D's parameters

            ############################
            # Update generator G
            ###########################
            netG.zero_grad()  # reset gradients

            output = netD(fake_image)  # run the updated D on the fake images again
            errG = criterion(
                output, real_target)  # generator loss; the target is 1 because G wants D to mistake the fakes for real images
            errG.backward()  # backpropagation
            D_G_z2 = output.mean().item()

            optimizerG.step()  # update G's parameters

            print(
                '[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
                .format(epoch + 1, opt.n_epoch, itr + 1, len(dataloader),
                        errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            if epoch == 0 and itr == 0:  # save the real images once, at the very start
                vutils.save_image(real_image,
                                  '{}/real_samples.png'.format(opt.outf),
                                  normalize=True,
                                  nrow=10)

        ############################
        # Generate monitoring images
        ############################
        fake_image = netG(fixed_noise)  # generate fake images from the fixed noise at the end of every epoch
        vutils.save_image(fake_image.detach(),
                          '{}/fake_samples_epoch_{:03d}.png'.format(
                              opt.outf, epoch + 1),
                          normalize=True,
                          nrow=10)

        ############################
        # Save the models
        ############################
        if (epoch + 1) % 50 == 0:  # save the models every 50 epochs
            torch.save(netG.state_dict(),
                       '{}/netG_epoch_{}.pth'.format(opt.outf, epoch + 1))
            torch.save(netD.state_dict(),
                       '{}/netD_epoch_{}.pth'.format(opt.outf, epoch + 1))
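The snippet defines main() but does not show the entry point; presumably the file ends with the usual guard, roughly:

if __name__ == '__main__':
    main()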
Code Example #3
class Experiment():
    def __init__(self, args):
        self.args = args
        self.writer = SummaryWriter(args.output_dir)
        self.iter_i = 1

        # data
        transform_list = [
            transforms.Resize(args.imsize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        train_dataset = datasets.CIFAR10(
            './data',
            train=True,
            transform=transforms.Compose(transform_list),
            download=True)
        self.train_loader = DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       num_workers=args.n_workers,
                                       shuffle=True)

        # network
        self.G = Generator(args.nz, args.ngf, args.nc).to(args.device)
        self.D = Discriminator(args.nc, args.ndf).to(args.device)
        self.criterion = nn.BCELoss()
        self.optimizer_G = optim.Adam(self.G.parameters(),
                                      lr=args.lr,
                                      betas=(args.beta1, args.beta2))
        self.optimizer_D = optim.Adam(self.D.parameters(),
                                      lr=args.lr,
                                      betas=(args.beta1, args.beta2))

        self.real_label = 1
        self.fake_label = 0
        self.fixed_z = torch.randn((args.batch_size, args.nz, 1, 1),
                                   device=args.device)

    def train(self, epoch):
        self.G.train()
        self.D.train()

        train_loss_G, train_loss_D = 0., 0.
        n_samples = 0
        for data, _ in self.train_loader:
            batch_size = len(data)
            n_samples += batch_size

            real = data.to(self.args.device)

            # train D
            self.optimizer_D.zero_grad()
            # with real
            label = torch.full((batch_size, ),
                               self.real_label,
                               device=self.args.device)
            output_real = self.D(real)
            loss_D_real = self.criterion(output_real, label)
            loss_D_real.backward()

            # with fake
            z = torch.randn((batch_size, self.args.nz, 1, 1),
                            device=self.args.device)
            fake = self.G(z)
            label = label.fill_(self.fake_label)
            output_fake = self.D(fake)
            loss_D_fake = self.criterion(output_fake, label)
            loss_D_fake.backward(retain_graph=True)

            loss_D = loss_D_real + loss_D_fake
            self.optimizer_D.step()

            # train G
            self.optimizer_G.zero_grad()
            label = label.fill_(self.real_label)
            output_fake = self.D(fake)
            loss_G = self.criterion(output_fake, label)
            loss_G.backward()
            self.optimizer_G.step()

            loss_D = loss_D.item()
            loss_G = loss_G.item()
            train_loss_D += loss_D
            train_loss_G += loss_G

            if self.iter_i % self.args.log_freq == 0:
                self.writer.add_scalar('Loss/D', loss_D, self.iter_i)
                self.writer.add_scalar('Loss/G', loss_G, self.iter_i)

                print('Epoch {} Train [{}/{}]:  Loss/D {:.4f} Loss/G {:.4f}'.
                      format(epoch, n_samples, len(self.train_loader.dataset),
                             loss_D / batch_size, loss_G / batch_size))

            self.iter_i += 1

        dataset_size = len(self.train_loader.dataset)
        print('Epoch {} Train: Loss/D {:.4f} Loss/G {:.4f}'.format(
            epoch, train_loss_D / dataset_size, train_loss_G / dataset_size))

    def test(self, epoch):
        self.G.eval()

        with torch.no_grad():
            fake = self.G(self.fixed_z)
            grid = make_grid(fake, normalize=True).cpu()
            self.writer.add_image('Fake', grid, self.iter_i)
            # show(grid)
            fname = osp.join(self.args.output_dir,
                             'fake_epoch_{}.png'.format(epoch))
            save_image(fake, fname, nrow=8)

    def save(self, epoch):
        # save generator and discriminator checkpoints under the output directory
        fname = osp.join(self.args.output_dir, 'checkpoint_{}.pt'.format(epoch))
        torch.save({'G': self.G.state_dict(), 'D': self.D.state_dict()}, fname)

    def run(self):
        for epoch_i in range(1, 1 + self.args.epochs):
            self.train(epoch_i)
            self.test(epoch_i)
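The Experiment class is never instantiated in the excerpt. A minimal driver sketch, assuming an argparse namespace with the attributes the class reads; every default below is illustrative, not taken from the original project:

import argparse

import torch


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument('--output_dir', default='./results')
    p.add_argument('--imsize', type=int, default=64)
    p.add_argument('--batch_size', type=int, default=128)
    p.add_argument('--n_workers', type=int, default=2)
    p.add_argument('--nz', type=int, default=100)   # latent vector size
    p.add_argument('--ngf', type=int, default=64)   # generator feature maps
    p.add_argument('--ndf', type=int, default=64)   # discriminator feature maps
    p.add_argument('--nc', type=int, default=3)     # image channels
    p.add_argument('--lr', type=float, default=2e-4)
    p.add_argument('--beta1', type=float, default=0.5)
    p.add_argument('--beta2', type=float, default=0.999)
    p.add_argument('--epochs', type=int, default=25)
    p.add_argument('--log_freq', type=int, default=100)
    args = p.parse_args()
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return args


if __name__ == '__main__':
    Experiment(parse_args()).run()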
Code Example #4
File: train.py Project: HuviX/gan_bot
        elif classname.find('BatchNorm') != -1:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
            torch.nn.init.zeros_(m.bias)

    netG = Generator(ngpu, nz, ngf).to(device)
    netG.apply(weights_init)

    netD = Discriminator(ngpu, ndf).to(device)
    netD.apply(weights_init)

    criterion = nn.BCELoss()

    fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
    real_label = 1
    fake_label = 0
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    schedulerD = ReduceLROnPlateau(optimizerD,
                                   mode='min',
                                   factor=0.5,
                                   patience=2,
                                   verbose=True)
    schedulerG = ReduceLROnPlateau(optimizerG,
                                   mode='min',
                                   factor=0.5,
                                   patience=2,
                                   verbose=True)
    print(len(dataloader.dataset))
    writer = SummaryWriter('logs/')
    for epoch in range(opt.nepoch):
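The excerpt starts partway through a weights_init helper, showing only the BatchNorm branch. A minimal sketch of what the complete function presumably looks like, assuming the usual DCGAN-style initialization also used in examples #6 and #14 below:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # convolution layers: normal distribution with mean 0 and std 0.02
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # batch-norm layers: scale around 1, bias at 0
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)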
Code Example #5
File: model.py Project: HuayueZhang/StarGAN_pytorch
class StarGAN:
    def __init__(self, opt):
        self.opt = opt
        self.global_step = opt.load_iter
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # define net class instance
        self.G_net = Generator(opt).to(self.device)
        self.D_net = Discriminator(opt).to(self.device)
        if opt.load_model and opt.load_iter > 0:
            self._load_pre_model(self.G_net, 'G')
            self._load_pre_model(self.D_net, 'D')

        # define objectives and optimizers
        self.adv_loss = torch.mean   # the adv loss is simply the mean real/fake score: higher is better for real images, lower for fakes
        # self.cls_loss = torch.nn.BCELoss()  # alternative; unclear what the difference would be
        self.cls_loss = F.binary_cross_entropy_with_logits
        self.rec_loss = torch.mean
        self.G_optimizer = torch.optim.Adam(self.G_net.parameters(), opt.G_lr, [opt.beta1, opt.beta2])
        self.D_optimizer = torch.optim.Adam(self.D_net.parameters(), opt.D_lr, [opt.beta1, opt.beta2])

        self.sample_gotten = False  # set in __init__ because the evaluation sample is created only once and then stays fixed
        self.writer = TBVisualizer(opt)

    def _load_pre_model(self, net, module):
        filename = '%d-%s.ckpt' % (self.opt.load_iter, module)
        loadpath = os.path.join(self.opt.save_dir, self.opt.model_folder, filename)
        net.load_state_dict(torch.load(loadpath))
        print('load model: %s' % loadpath)

    def _set_eval_sample(self):
        # let the sample be the first batch of the whole training process
        self.sample_real = self.img_real
        self.sample_c_trg_list = self._create_fix_trg_label(self.c_org)
        self.sample_gotten = True

    def _create_fix_trg_label(self, c_org):
        # During eval we want the same input images translated to the same fixed targets => use fixed target domain labels
        # At test time, fixed target domains are chosen for the test images in the same way
        hair_color_ids = []
        for id, selected_attr in enumerate(self.opt.selected_attr):
            if selected_attr in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
                hair_color_ids.append(id)

        c_trg_list = []
        for i in range(self.opt.num_attr):
            c_trg = c_org.clone()  # clone per target domain so earlier entries in the list are not overwritten
            # Convert the original domain, described by 5 attributes (hair 0, hair 1, hair 2, gender, age),
            # into 5 target domains, each described by the same 5 attributes:
            # target domain 1: has hair 0, gender and age unchanged
            # target domain 2: has hair 1, gender and age unchanged
            # target domain 3: has hair 2, gender and age unchanged
            # target domain 4: hair unchanged, gender flipped, age unchanged
            # target domain 5: hair unchanged, gender unchanged, age flipped
            if i in hair_color_ids:
                for j in hair_color_ids:
                    c_trg[:, j] = int(i == j)
            else:
                c_trg[:, i] = (c_trg[:, i] == 0)

            c_trg_list.append(c_trg)
        return c_trg_list

    def set_inputs(self, data):
        img_real, labels_org = data
        self.img_real = img_real.to(self.device)
        self.c_org = labels_org.to(self.device)
        # generate target domain labels randomly
        # torch.randperm(n) returns a random permutation of the integers 0 .. n-1
        rand_idx = torch.randperm(labels_org.size(0))
        labels_trg = labels_org[rand_idx]
        self.c_trg = labels_trg.to(self.device)

    def _optimize_D(self, visualize=True):
        # forward
        # using real images
        src_real, cls_real = self.D_net(self.img_real)
        loss_D_real = - self.adv_loss(src_real)  # losses are minimized, so quantities we want to maximize are negated
        loss_D_cls = self.cls_loss(cls_real, self.c_org)

        # using fake image
        img_fake = self.G_net(self.img_real, self.c_trg)
        src_fake, cls_fake = self.D_net(img_fake.detach())
        loss_D_fake = self.adv_loss(src_fake)

        loss_D = loss_D_real + loss_D_fake + self.opt.cls_lambda * loss_D_cls

        # backward
        self.D_optimizer.zero_grad()
        loss_D.backward()
        self.D_optimizer.step()

        # if visualize:
        self.writer.scalar('loss_D_real', loss_D_real, self.global_step)
        self.writer.scalar('loss_D_fake', loss_D_fake, self.global_step)
        self.writer.scalar('loss_D_cls', loss_D_cls, self.global_step)
        self.writer.scalar('loss_D', loss_D, self.global_step)

        return loss_D

    def _optimize_G(self, visualize=True):
        # forward
        img_fake = self.G_net(self.img_real, self.c_trg)

        # fuse discriminator
        src_fake, cls_fake = self.D_net(img_fake)
        loss_G_fake = - self.adv_loss(src_fake)
        loss_G_cls = self.cls_loss(cls_fake, self.c_trg)

        # reconstruct images
        img_rec = self.G_net(img_fake, self.c_org)
        loss_G_rec = self.rec_loss(torch.abs(img_rec-self.img_real))

        loss_G = loss_G_fake + self.opt.cls_lambda * loss_G_cls + self.opt.rec_lambda * loss_G_rec

        # backward
        self.G_optimizer.zero_grad()
        loss_G.backward()
        self.G_optimizer.step()

        # if visualize:
        self.writer.scalar('loss_G_fake', loss_G_fake, self.global_step)
        self.writer.scalar('loss_G_cls', loss_G_cls, self.global_step)
        self.writer.scalar('loss_G_rec', loss_G_rec, self.global_step)
        self.writer.scalar('loss_G', loss_G, self.global_step)
        return loss_G

    def eval_sample(self):
        if not self.sample_gotten:
            self._set_eval_sample()
        sample_fake_list = []
        for sample_c_trg in self.sample_c_trg_list:
            sample_fake = self.G_net(self.sample_real, sample_c_trg)  # (16, 5, 128, 128)
            sample_fake_list.append(sample_fake)

        return sample_fake_list

    def optimizer(self):
        self.global_step += 1
        loss_D = self._optimize_D(visualize=True)
        # loss_G = self._optimize_G(visualize=False)
        loss_G = self._optimize_G(visualize=True)

        if self.global_step % 100 == 0:
            message = 'iter[%6d/%d], loss_D = %.6f, loss_G = %.6f' % \
                      (self.global_step, self.opt.max_iter, loss_D, loss_G)
            self.writer.log_and_print(message)

        if self.global_step % 1000 == 0:
            sample_fake_list = self.eval_sample()
            manifold_img_array = get_manifold_img_array(self.sample_real, sample_fake_list, self.opt)  # (H, W, 3)
            manifold_img_array = (manifold_img_array + 1.) / 2.
            self.writer.image('eval_sample', manifold_img_array, self.global_step, self.opt)

        if self.global_step % 10000 == 0:
            self.save_model()

    # def test(self):

    def save_model(self):
        filename = '%d-G.ckpt' % self.global_step
        savepath = os.path.join(self.opt.save_dir, self.opt.model_folder, filename)
        torch.save(self.G_net.state_dict(), savepath)

        filename = '%d-D.ckpt' % self.global_step
        savepath = os.path.join(self.opt.save_dir, self.opt.model_folder, filename)
        torch.save(self.D_net.state_dict(), savepath)
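Nothing in the excerpt shows how the class is driven. A hypothetical driver loop, assuming an options object opt and a data loader named celeba_loader (both placeholders, not names from the project): feed each batch with set_inputs() and call optimizer() once per training step.

# Hypothetical driver (not part of model.py); opt and celeba_loader are assumed to exist.
star_gan = StarGAN(opt)
data_iter = iter(celeba_loader)
for step in range(opt.load_iter, opt.max_iter):
    try:
        data = next(data_iter)
    except StopIteration:
        data_iter = iter(celeba_loader)
        data = next(data_iter)
    star_gan.set_inputs(data)   # move images/labels to the device and draw random target labels
    star_gan.optimizer()        # one D step and one G step, plus periodic logging and checkpointing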
Code Example #6
#================================optimizer======================================
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


dis = Discriminator().to(device)
dis.apply(weights_init)
gen = Generator(batch_size).to(device)
gen.apply(weights_init)

dis_opt = optim.Adam(dis.parameters(), lr=0.0002, betas=(0.5, 0.999))
gen_opt = optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
print("setted model/loss/optimizer")

#================================train==========================================
print("start training")
iteration_sum = 0
for epo in range(epoch):

    running_loss_dis = 0.0
    running_loss_gen = 0.0
    iterations = 0

    for i, data in enumerate(trainloader, 0):
        iterations = i
        inputs = data
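The per-batch body is cut off right after inputs = data. A sketch of how a typical iteration could continue with the objects defined above; criterion, z_dim, and the noise shape are assumptions (the excerpt does not show them), and gen is assumed to take a latent batch as input:

        # assumptions: criterion = nn.BCELoss(), z_dim = 100, noise of shape (B, z_dim, 1, 1)
        inputs = inputs.to(device)
        b_size = inputs.size(0)
        real_labels = torch.ones(b_size, device=device)
        fake_labels = torch.zeros(b_size, device=device)

        # update the discriminator on a real batch and a fake batch
        dis_opt.zero_grad()
        loss_real = criterion(dis(inputs).view(-1), real_labels)
        noise = torch.randn(b_size, z_dim, 1, 1, device=device)
        fake = gen(noise)
        loss_fake = criterion(dis(fake.detach()).view(-1), fake_labels)
        loss_dis = loss_real + loss_fake
        loss_dis.backward()
        dis_opt.step()

        # update the generator: the discriminator should label the fakes as real
        gen_opt.zero_grad()
        loss_gen = criterion(dis(fake).view(-1), real_labels)
        loss_gen.backward()
        gen_opt.step()

        running_loss_dis += loss_dis.item()
        running_loss_gen += loss_gen.item()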
Code Example #7
File: train_AAE.py Project: swyoon/GPND
def train(folding_id, inliner_classes, ic):
    cfg = get_cfg_defaults()
    cfg.merge_from_file('configs/mnist.yaml')
    cfg.freeze()
    logger = logging.getLogger("logger")

    zsize = cfg.MODEL.LATENT_SIZE
    output_folder = os.path.join('results_' + str(folding_id) + "_" +
                                 "_".join([str(x) for x in inliner_classes]))
    os.makedirs(output_folder, exist_ok=True)
    os.makedirs('models', exist_ok=True)

    train_set, _, _ = make_datasets(cfg, folding_id, inliner_classes)

    logger.info("Train set size: %d" % len(train_set))

    G = Generator(cfg.MODEL.LATENT_SIZE,
                  channels=cfg.MODEL.INPUT_IMAGE_CHANNELS)
    G.weight_init(mean=0, std=0.02)

    D = Discriminator(channels=cfg.MODEL.INPUT_IMAGE_CHANNELS)
    D.weight_init(mean=0, std=0.02)

    E = Encoder(cfg.MODEL.LATENT_SIZE, channels=cfg.MODEL.INPUT_IMAGE_CHANNELS)
    E.weight_init(mean=0, std=0.02)

    if cfg.MODEL.Z_DISCRIMINATOR_CROSS_BATCH:
        ZD = ZDiscriminator_mergebatch(zsize, cfg.TRAIN.BATCH_SIZE)
    else:
        ZD = ZDiscriminator(zsize, cfg.TRAIN.BATCH_SIZE)
    ZD.weight_init(mean=0, std=0.02)

    lr = cfg.TRAIN.BASE_LEARNING_RATE

    G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
    D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
    GE_optimizer = optim.Adam(list(E.parameters()) + list(G.parameters()),
                              lr=lr,
                              betas=(0.5, 0.999))
    ZD_optimizer = optim.Adam(ZD.parameters(), lr=lr, betas=(0.5, 0.999))

    BCE_loss = nn.BCELoss()
    sample = torch.randn(64, zsize).view(-1, zsize, 1, 1)

    tracker = LossTracker(output_folder=output_folder)

    for epoch in range(cfg.TRAIN.EPOCH_COUNT):
        G.train()
        D.train()
        E.train()
        ZD.train()

        epoch_start_time = time.time()

        data_loader = make_dataloader(train_set, cfg.TRAIN.BATCH_SIZE,
                                      torch.cuda.current_device())
        train_set.shuffle()

        if (epoch + 1) % 30 == 0:
            G_optimizer.param_groups[0]['lr'] /= 4
            D_optimizer.param_groups[0]['lr'] /= 4
            GE_optimizer.param_groups[0]['lr'] /= 4
            ZD_optimizer.param_groups[0]['lr'] /= 4
            print("learning rate change!")

        for y, x in data_loader:
            x = x.view(-1, cfg.MODEL.INPUT_IMAGE_CHANNELS,
                       cfg.MODEL.INPUT_IMAGE_SIZE, cfg.MODEL.INPUT_IMAGE_SIZE)

            y_real_ = torch.ones(x.shape[0])
            y_fake_ = torch.zeros(x.shape[0])

            y_real_z = torch.ones(
                1 if cfg.MODEL.Z_DISCRIMINATOR_CROSS_BATCH else x.shape[0])
            y_fake_z = torch.zeros(
                1 if cfg.MODEL.Z_DISCRIMINATOR_CROSS_BATCH else x.shape[0])

            #############################################

            D.zero_grad()

            D_result = D(x).squeeze()
            D_real_loss = BCE_loss(D_result, y_real_)

            z = torch.randn((x.shape[0], zsize)).view(-1, zsize, 1, 1)
            z = Variable(z)

            x_fake = G(z).detach()
            D_result = D(x_fake).squeeze()
            D_fake_loss = BCE_loss(D_result, y_fake_)

            D_train_loss = D_real_loss + D_fake_loss
            D_train_loss.backward()

            D_optimizer.step()

            tracker.update(dict(D=D_train_loss))

            #############################################

            G.zero_grad()

            z = torch.randn((x.shape[0], zsize)).view(-1, zsize, 1, 1)
            z = Variable(z)

            x_fake = G(z)
            D_result = D(x_fake).squeeze()

            G_train_loss = BCE_loss(D_result, y_real_)

            G_train_loss.backward()
            G_optimizer.step()

            tracker.update(dict(G=G_train_loss))

            #############################################

            ZD.zero_grad()

            z = torch.randn((x.shape[0], zsize)).view(-1, zsize)
            z = Variable(z)

            ZD_result = ZD(z).squeeze()
            ZD_real_loss = BCE_loss(ZD_result, y_real_z)

            z = E(x).squeeze().detach()

            ZD_result = ZD(z).squeeze()
            ZD_fake_loss = BCE_loss(ZD_result, y_fake_z)

            ZD_train_loss = ZD_real_loss + ZD_fake_loss
            ZD_train_loss.backward()

            ZD_optimizer.step()

            tracker.update(dict(ZD=ZD_train_loss))

            # #############################################

            E.zero_grad()
            G.zero_grad()

            z = E(x)
            x_d = G(z)

            ZD_result = ZD(z.squeeze()).squeeze()

            E_train_loss = BCE_loss(ZD_result, y_real_z) * 1.0

            Recon_loss = F.binary_cross_entropy(x_d, x.detach()) * 2.0

            (Recon_loss + E_train_loss).backward()

            GE_optimizer.step()

            tracker.update(dict(GE=Recon_loss, E=E_train_loss))

            # #############################################

        comparison = torch.cat([x, x_d])
        save_image(comparison.cpu(),
                   os.path.join(output_folder,
                                'reconstruction_' + str(epoch) + '.png'),
                   nrow=x.shape[0])

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time

        logger.info(
            '[%d/%d] - ptime: %.2f, %s' %
            ((epoch + 1), cfg.TRAIN.EPOCH_COUNT, per_epoch_ptime, tracker))

        tracker.register_means(epoch)
        tracker.plot()

        with torch.no_grad():
            resultsample = G(sample).cpu()
            save_image(
                resultsample.view(64, cfg.MODEL.INPUT_IMAGE_CHANNELS,
                                  cfg.MODEL.INPUT_IMAGE_SIZE,
                                  cfg.MODEL.INPUT_IMAGE_SIZE),
                os.path.join(output_folder, 'sample_' + str(epoch) + '.png'))

    logger.info("Training finish!... save training results")

    os.makedirs("models", exist_ok=True)

    print("Training finish!... save training results")
    torch.save(G.state_dict(), "models/Gmodel_%d_%d.pkl" % (folding_id, ic))
    torch.save(E.state_dict(), "models/Emodel_%d_%d.pkl" % (folding_id, ic))
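train() is presumably called from a driver that iterates over folds and inlier classes; a minimal illustrative invocation (the fold and class values are placeholders, not from the project):

if __name__ == '__main__':
    # e.g. fold 0 with class 3 as the inlier class (illustrative values)
    train(folding_id=0, inliner_classes=[3], ic=3)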
Code Example #8
dataset = MyDataset(X_A_dataset, X_B_dataset)

# Create data loader
train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

D_A = Discriminator().to(device)
G_A = UNet(n_channels=3, n_classes=3).to(device)
D_B = Discriminator().to(device)
G_B = UNet(n_channels=3, n_classes=3).to(device)

criterion = nn.MSELoss()
rec_criterion = nn.L1Loss()
dA_optimizer = torch.optim.Adam(D_A.parameters(), lr=0.0002)
gA_optimizer = torch.optim.Adam(G_A.parameters(), lr=0.0002)
dB_optimizer = torch.optim.Adam(D_B.parameters(), lr=0.0002)
gB_optimizer = torch.optim.Adam(G_B.parameters(), lr=0.0002)


def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)


def reset_grad():
    dA_optimizer.zero_grad()
    gA_optimizer.zero_grad()
    dB_optimizer.zero_grad()
    gB_optimizer.zero_grad()
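The excerpt stops after the helper functions, before any training loop. A sketch of what one CycleGAN-style step could look like with the objects defined above, assuming MyDataset yields (x_A, x_B) pairs and using an illustrative cycle-loss weight of 10.0; the loop itself is not part of the original snippet:

for x_A, x_B in train_loader:
    x_A, x_B = x_A.to(device), x_B.to(device)

    # update both discriminators on real vs. translated images
    reset_grad()
    fake_B = G_A(x_A)   # translate A -> B
    fake_A = G_B(x_B)   # translate B -> A
    out_real_A, out_real_B = D_A(x_A), D_B(x_B)
    out_fake_A, out_fake_B = D_A(fake_A.detach()), D_B(fake_B.detach())
    d_loss = (criterion(out_real_A, torch.ones_like(out_real_A)) +
              criterion(out_fake_A, torch.zeros_like(out_fake_A)) +
              criterion(out_real_B, torch.ones_like(out_real_B)) +
              criterion(out_fake_B, torch.zeros_like(out_fake_B)))
    d_loss.backward()
    dA_optimizer.step()
    dB_optimizer.step()

    # update both generators: adversarial term plus cycle reconstruction
    reset_grad()
    fake_B, fake_A = G_A(x_A), G_B(x_B)
    out_fake_A, out_fake_B = D_A(fake_A), D_B(fake_B)
    g_loss = (criterion(out_fake_A, torch.ones_like(out_fake_A)) +
              criterion(out_fake_B, torch.ones_like(out_fake_B)) +
              10.0 * rec_criterion(G_B(fake_B), x_A) +   # A -> B -> A cycle
              10.0 * rec_criterion(G_A(fake_A), x_B))    # B -> A -> B cycle
    g_loss.backward()
    gA_optimizer.step()
    gB_optimizer.step()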
Code Example #9
class VAECycleGan(nn.Module):

    def __init__(self, args):
        super(VAECycleGan, self).__init__()
        self.x_dim = args["x_dim"]
        self.z_dim = args["z_dim"]
        self.lam0 = args["lam0"]
        self.lam1 = args["lam1"]
        self.lam2 = args["lam2"]
        self.lam3 = args["lam3"]
        self.lam4 = args["lam4"]
        
        self.vae1 = VAE(self.x_dim, h_dim1 = 2048, h_dim2 = 1024, z_dim = self.z_dim).to(device)
        self.vae2 = VAE(self.x_dim, h_dim1 = 2048, h_dim2 = 1024, z_dim = self.z_dim).to(device)
        #self.share_vae_features()
        
        self.D1 = Discriminator(self.x_dim).to(device)
        self.D2 = Discriminator(self.x_dim).to(device)
        self.G1 = self.vae1.Decoder
        self.G2 = self.vae2.Decoder
        
        
        self.young_data_fname = "kowalcyzk_logNorm_young_variableSubset.csv"
        self.old_data_fname = "kowalcyzk_logNorm_old_variableSubset.csv"
        self.young_data = np.genfromtxt(self.young_data_fname, delimiter=",").transpose()[1:,1:]
        self.old_data = np.genfromtxt(self.old_data_fname, delimiter=",").transpose()[1:,1:]
        
        
        """
        self.young_dataloader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           #transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=100, shuffle=True)
        self.old_dataloader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           #transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=100, shuffle=True)
        """
        
        
        self.young_data = torch.from_numpy(self.young_data).to(device).float()
        self.old_data = torch.from_numpy(self.old_data).to(device).float()
        
        
        self.young_train, self.young_test = self.split_data(self.young_data, 0.1)
        self.old_train, self.old_test = self.split_data(self.old_data, 0.1)
        
        
        self.young_ds = utils.TensorDataset(self.young_train)
        self.young_test_ds = utils.TensorDataset(self.young_test)
        self.young_dataloader = utils.DataLoader(self.young_ds, batch_size=10, shuffle=True)
        self.young_test_loader = utils.DataLoader(self.young_test_ds, batch_size=10, shuffle=True)
        self.old_ds = utils.TensorDataset(self.old_train)
        self.old_test_ds = utils.TensorDataset(self.old_test)
        self.old_dataloader = utils.DataLoader(self.old_ds, batch_size=10, shuffle=True)
        self.old_test_loader = utils.DataLoader(self.old_test_ds, batch_size=10, shuffle=True)
        
        
        self.G_optim = optim.Adam(list(self.G1.parameters()) + list(self.G2.parameters()), lr=0.005)
        self.D_optim = optim.Adam(list(self.D1.parameters()) + list(self.D2.parameters()), lr=0.005)
        self.VAE_optim = optim.Adam(list(self.vae1.parameters()) + list(self.vae2.parameters()), lr=0.005)
        #self.VAE_optim = optim.Adam(list(self.vae1.parameters()) + list(self.vae2.parameters()))
        #self.VAE_optim = optim.Adam(self.vae1.parameters(), lr=0.001)
        
        
    def split_data(self, data, p_test):
    
        N = len(data)
        inds = list(range(N))
        random.shuffle(inds)
        
        train_N = int((1-p_test) * N)
        test_N = int((p_test * N))
        
        train = data[:train_N]
        test = data[train_N:]
        
        return train, test
        
    def share_vae_features(self):
        self.vae1.fc31 = self.vae2.fc31
        self.vae1.fc32 = self.vae2.fc32
        self.vae1.fc4 = self.vae2.fc4
        
    def VAELoss(self, x_in, y_in):
        """
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return BCE + KLD
        """
        G1_out, mu1, log_var1 = self.vae1(x_in)
        KLD = -self.lam1 * (torch.mean(1 + log_var1 - mu1.pow(2) - log_var1.exp()))
        BCE = self.lam2 * (F.mse_loss(G1_out, x_in.view(-1, self.x_dim), reduction='mean'))
        L1 = (BCE + KLD)
        #print(G1_out.shape, log_var1.shape, mu1.pow(2).shape)
        
        
        G2_out, mu2, log_var2 = self.vae2(y_in)
        KLD_2 = -self.lam1 * (torch.mean(1 + log_var2 - mu2.pow(2) - log_var2.exp()))
        BCE_2 = self.lam2 * (F.mse_loss(G2_out, y_in.view(-1,self.x_dim), reduction='mean'))
        
        L = L1 + (BCE_2 + KLD_2)
        L.backward()
        
        return L
        
    def Disc_loss(self, x_in, y_in):
        ### Compute first loss term
        targets_real = torch.ones((x_in.shape[0], 1)).float().to(device)
        targets_fake = torch.zeros((x_in.shape[0], 1)).float().to(device)
        
        L1_D = self.lam0 * torch.mean(F.mse_loss(self.D1(x_in), targets_real))
        mu2, logvar2 = self.vae2.encode(y_in)
        G1out = self.G1(self.vae2.sampling(mu2.detach(), logvar2.detach()))
        disc_pred_1 = self.D1(G1out)
        L1_D = L1_D + self.lam0 * torch.mean(F.mse_loss(disc_pred_1, targets_fake))
        
        ### Compute second loss term
        L2_D = self.lam0 * torch.mean(F.mse_loss(self.D2(y_in), targets_real))
        mu1, logvar1 = self.vae1.encode(x_in)
        G2out = self.G2(self.vae1.sampling(mu1.detach(), logvar1.detach()))
        disc_pred_2 = self.D2(G2out)
        L2_D = L2_D + self.lam0 * torch.mean(F.mse_loss(disc_pred_2, targets_fake))
        
        L_D = L1_D + L2_D
        L_D.backward()
        
        return L_D
        
    def Gen_loss(self, x_in, y_in):
    
        targets_real = torch.ones((x_in.shape[0], 1)).float().to(device)
        
        mu2, logvar2 = self.vae2.encode(y_in)
        G1out = self.G1(self.vae2.sampling(mu2.detach(), logvar2.detach()))
        disc_pred_1 = self.D1(G1out)
        L1_G = self.lam0 * torch.mean(F.mse_loss(disc_pred_1, targets_real))
        
        mu1, logvar1 = self.vae1.encode(x_in)
        G2out = self.G2(self.vae1.sampling(mu1.detach(), logvar1.detach()))
        disc_pred_2 = self.D2(G2out)
        L2_G = self.lam0 * torch.mean(F.mse_loss(disc_pred_2, targets_real))
        
        L_G = L1_G + L2_G
        L_G.backward()
        
        return L_G
        
    
    def cycleConsistencyLoss(self, x_in, y_in):
        
        #G1_out, mu1, log_var1 = self.vae1(x_in)
        mu1, log_var1 = self.vae1.encode(x_in)
        G2_reconstr = self.vae2.decode(self.vae1.sampling(mu1, log_var1))
        mu2, log_var2 = self.vae2.encode(G2_reconstr)
        G121_cycle = self.vae1.decode(self.vae2.sampling(mu2, log_var2))
        
        L1 = -self.lam3 * (torch.mean(1 + log_var1 - mu1.pow(2) - log_var1.exp()))
        L1 = L1 - self.lam3 * (torch.mean(1 + log_var2 - mu2.pow(2) - log_var2.exp()))
        L1 = L1 + self.lam4 * (F.mse_loss(G121_cycle, x_in))
        
        mu2, log_var2 = self.vae2.encode(y_in)
        G1_reconstr = self.vae1.decode(self.vae2.sampling(mu2, log_var2))
        mu1, log_var1 = self.vae1.encode(G1_reconstr)
        G212_cycle = self.vae2.decode(self.vae1.sampling(mu1, log_var1))
        
        L2 = -self.lam3 * (torch.mean(1 + log_var2 - mu2.pow(2) - log_var2.exp()))
        L2 = L2 - self.lam3 * (torch.mean(1 + log_var1 - mu1.pow(2) - log_var1.exp()))
        L2 = L2 + self.lam4 * (F.mse_loss(G212_cycle, y_in))
        
        L = L1 + L2
        L.backward()
        
        return L

    def train(self, num_epochs):
    
        self.vae1.train()
        self.vae2.train()
        self.G1.train()
        self.G2.train()
        self.D1.train()
        self.D2.train()
        
    
        losses = []
    
        for i in range(num_epochs):
            epoch_loss = 0.0
            total_vae = 0.0
            total_D = 0.0
            total_G = 0.0
            total_cc = 0.0
            
            if (i == 30):
                self.G_optim = optim.Adam(list(self.G1.parameters()) + list(self.G2.parameters()), lr=0.001)
            self.D_optim = optim.Adam(list(self.D1.parameters()) + list(self.D2.parameters()), lr=0.001)
            self.VAE_optim = optim.Adam(list(self.vae1.parameters()) + list(self.vae2.parameters()), lr=0.001)
        
            train_steps = min(len(self.old_dataloader), len(self.young_dataloader))
            
            self.young_dataloader = utils.DataLoader(self.young_ds, batch_size=10, shuffle=True)
            self.old_dataloader = utils.DataLoader(self.old_ds, batch_size=10, shuffle=True)
            old_data = iter(self.old_dataloader)
            young_data = iter(self.young_dataloader)
            
            """
            self.young_dataloader = torch.utils.data.DataLoader(
                datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           #transforms.Normalize((0.1307,), (0.3081,))
                       ])),
                batch_size=100, shuffle=True)
            self.old_dataloader = torch.utils.data.DataLoader(
                datasets.MNIST('../data', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor()
                           ])),
                batch_size=100, shuffle=True)
            
            young_data = iter(self.young_dataloader)
            old_data = iter(self.old_dataloader)
            """
            # Iterate through each batch
            for j in range(train_steps-1):
            #for j in range(10):   
                # Get batch of data
                [young_cells] = next(young_data)
                [old_cells] = next(old_data)
                
                young_cells = young_cells.to(device).flatten(1,-1)
                old_cells = old_cells.to(device).flatten(1,-1)
                #print(torch.max(young_cells), torch.min(young_cells))
                # Zero out all optimizers
                #self.G_optim.zero_grad()
                #self.D_optim.zero_grad()
                self.VAE_optim.zero_grad()
                
                """
                # Sum loss functions
                # Gradient backpropagation computed inside these loss functions
                
                D_loss = self.Disc_loss(young_cells, old_cells)
                self.D_optim.step()
                self.D_optim.zero_grad()
                self.G_optim.zero_grad()
                
                G_loss = self.Gen_loss(young_cells, old_cells)
                self.G_optim.step()
                self.D_optim.zero_grad()
                self.G_optim.zero_grad()
                """                
                
                G_loss, D_loss = torch.Tensor([0]), torch.Tensor([0])
                #vaeloss = torch.Tensor([0])
                # @TODO: Uncomment this block of code to enable VAE training
                  
                self.VAE_optim.zero_grad()
                vaeloss = self.VAELoss(young_cells, old_cells)
                self.VAE_optim.step()

                """                
                self.VAE_optim.zero_grad()
                ccloss = self.cycleConsistencyLoss(young_cells, old_cells)
                self.VAE_optim.step()
                """
                
                ccloss = torch.Tensor([0])
                # Exclude discriminator loss from total
                loss = vaeloss + G_loss + ccloss

                epoch_loss += loss.item()
                total_vae += vaeloss.item()
                total_D += D_loss.item()
                total_G += G_loss.item()
                total_cc += ccloss.item()

                #self.G_optim.step()
                #self.D_optim.step()

            [epoch_loss, total_vae, total_D, total_G, total_cc] = loss_arr = np.array([epoch_loss, total_vae, total_D, total_G, total_cc]) / train_steps
            
            
            print("Losses at epoch %d\t VAE: %f\tDISC: %f\tGEN: %f\tCC: %f\tTOTAL: %f" % (i+1, total_vae, total_D, total_G, total_cc, epoch_loss))
            
            losses.append(list(loss_arr))
            plt.plot(np.array(losses)[:,0])
            plt.show(block=False)
            plt.pause(0.001)
            
            self.test(save=False)
            
        plt.figure()
        losses = np.array(losses)
        #[v, _, _, c] = plt.plot(losses[:,1:])
        [vae_loss] = plt.plot(losses[:, [1]])
        plt.legend([vae_loss], ["VAE Reconstruction"], loc=1)
        
        plt.show()
        
    def test(self, save=True):
        
        self.vae1.eval()
        self.vae2.eval()
        self.G1.eval()
        self.G2.eval()
        self.D1.eval()
        self.D2.eval()
        
        """
        young_mu, young_logvar = self.vae1.encode(self.young_test)
        young_Z = self.vae1.sampling(young_mu, young_logvar)
        young_output = self.vae1.decode(young_Z)
        young_corr = self.pearson_correlation(self.young_test, young_output)
        
        old_mu, old_logvar = self.vae2.encode(self.old_test)
        old_Z = self.vae2.sampling(old_mu, old_logvar)
        old_output = self.vae2.decode(old_Z)
        old_corr = self.pearson_correlation(self.old_test, old_output)
        
        
        print("old corr: ", old_corr, " young corr: ", young_corr)
        """
        
        
        
        young_mu, young_logvar = self.vae1.encode(self.young_data)
        young_Z = self.vae1.sampling(young_mu, young_logvar)
        young_output = self.vae1.decode(young_Z)
        young_corr = self.pearson_correlation(self.young_data, young_output)
        
        old_mu, old_logvar = self.vae2.encode(self.old_data)
        old_Z = self.vae2.sampling(old_mu, old_logvar)
        old_output = self.vae2.decode(old_Z)
        old_corr = self.pearson_correlation(self.old_data, old_output)
        
        print("old corr: ", old_corr, " young corr: ", young_corr)
        
        if (save):
            
            np.savetxt("old_mu.csv", old_mu.cpu().data.numpy())
            np.savetxt("old_logvar.csv", old_logvar.cpu().data.numpy())
            np.savetxt("old_Z.csv", old_Z.cpu().data.numpy())
            np.savetxt("old_correlation.csv", np.array([old_corr.cpu().data.numpy()]))
            np.savetxt("old_recreated_from_vae.csv", old_output.cpu().data.numpy())
            
            np.savetxt("young_mu.csv", young_mu.cpu().data.numpy())
            np.savetxt("young_logvar.csv", young_logvar.cpu().data.numpy())
            np.savetxt("young_Z.csv", young_Z.cpu().data.numpy())
            np.savetxt("young_correlation.csv", np.array([young_corr.cpu().data.numpy()]))
            np.savetxt("young_recreated_from_vae.csv", young_output.cpu().data.numpy())
        
        
        """
        print(list(self.young_test[0].cpu().data.numpy()), list(young_output[0].cpu().data.numpy()))
        print(list(self.old_test[0].cpu().data.numpy()), list(old_output[0].cpu().data.numpy()))
        """
            
        
        """
        [x,_] = next(iter(self.young_dataloader))
        x = x.flatten(1,-1).to(device)
        trans, mu, log_var = self.vae1(x)
        trans = trans.reshape((-1, 1, 28, 28)).cpu()[0,0].data.numpy()
        stylized = self.vae2.decode(self.vae1.sampling(mu, log_var)).reshape((-1,1,28,28)).cpu()[0,0].data.numpy()
        
        #plt.imshow(x.reshape((-1,1,28,28)).cpu()[0,0].data.numpy())
        #plt.show()
        #plt.imshow(trans)
        #plt.show()
        cv2.imshow("original", x.reshape((-1,1,28,28)).cpu()[0,0].data.numpy())
        cv2.waitKey(0)
        cv2.imshow("reconstructed", trans)
        cv2.waitKey(0)
        cv2.imshow("stylized", stylized)
        cv2.waitKey(0)
        """
            
        
        """
        young_zero_p = (torch.sum((self.young_data <= 0.0).float()) / self.young_data.numel())
        inferred_young_zero_p = (torch.sum(self.vae2(self.old_data)[0] <= 0).float() / self.old_data.numel())
        
        
        old_zero_p = (torch.sum((self.old_data <= 0.0).float()) / self.old_data.numel()).cpu().data.numpy()
        inferred_old_zero_p = (torch.sum(self.vae1(self.young_data)[0] <= 0).float() / self.young_data.numel()).cpu().data.numpy()
        
        print("Ground truth proportion of 0's in young data: %f. Predicted: %f." % (young_zero_p, inferred_young_zero_p))
        print("Ground truth proportion of 0's in old data: %f. Predicted: %f." % (old_zero_p, inferred_old_zero_p))
        
        mu1, log_var1 = self.vae2.encode(self.old_data)
        z1 = self.vae2.sampling(mu1, log_var1)
        inferred_young = self.vae1.decode(z1)
        mu2, log_var2 = self.vae1.encode(self.young_data)
        z2 = self.vae1.sampling(mu2, log_var2)
        inferred_old = self.vae2.decode(z2)
        
        np.savetxt("YoungToOld.csv", torch.clamp(inferred_old, 0.0).cpu().data.numpy())
        np.savetxt("OldToYoung.csv", torch.clamp(inferred_young, 0.0).cpu().data.numpy())
        np.savetxt("Old_latent.txt", z1.cpu().data.numpy())
        np.savetxt("Young_latent.txt", z2.cpu().data.numpy())
        
        
        mu1, log_var1 = self.vae1.encode(self.young_data[0])
        G2_reconstr = self.vae2.decode(self.vae1.sampling(mu1, log_var1))
        mu2, log_var2 = self.vae2.encode(G2_reconstr)
        G121_cycle = self.vae1.decode(self.vae2.sampling(mu2, log_var2))
        print(self.young_data[0], G2_reconstr, G121_cycle)
        """
        
    def pearson_correlation(self, x, y):
        normx = x - torch.mean(x)
        normy = y - torch.mean(y)
        
        return torch.mean( torch.sum(normx * normy, dim=1) / (torch.sqrt(torch.sum(normx ** 2, dim=1)) * torch.sqrt(torch.sum(normy ** 2, dim=1))) )
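The class is configured through a plain dict of args; a minimal usage sketch, assuming the two CSV files referenced in __init__ are present and a global device is defined, with illustrative dimensions and loss weights (none of the values below come from the original code):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

args = {
    "x_dim": 1000,   # number of genes per cell (illustrative)
    "z_dim": 64,     # latent dimension (illustrative)
    "lam0": 1.0,     # adversarial weight
    "lam1": 1.0,     # KL weight
    "lam2": 1.0,     # reconstruction weight
    "lam3": 1.0,     # cycle KL weight
    "lam4": 1.0,     # cycle reconstruction weight
}

model = VAECycleGan(args)
model.train(num_epochs=50)
model.test(save=True)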
Code Example #10
File: train.py Project: cdefga/gan
def train(batch_size=1, latent_size=100, learning_rate=2e-3, num_epochs=100):
    cuda = torch.cuda.is_available()
    device = 'cuda:0' if cuda else 'cpu'
    dataloader = datamaker(batch_size=batch_size)
    fixed_img = np.random.uniform(-1, 1, size=(batch_size, latent_size))
    fixed_img = torch.from_numpy(fixed_img).float()
    gen_imgs = []

    G = Generator(input_size=latent_size)
    D = Discriminator()
    if cuda:
        print('Using CUDA')
        fixed_img = fixed_img.cuda()
        G.cuda()
        D.cuda()
        


    g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)
    d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)

    wandb.watch(G)
    wandb.watch(D)
    for epoch in range(num_epochs):
        D.train()
        G.train()
        for idx, ( real_images, _ ) in enumerate(tqdm(dataloader)):
            if cuda:
                real_images = real_images.cuda()

            batch_size = real_images.size(0)
            real_images = real_images * 2 - 1

            g_loss_value = 0.0
            d_loss_value = 0.0
            for phase in ['discriminator', 'generator']:
                # TRAIN DISCRIMINATOR
                if phase == 'discriminator':
                    # generate fake images from latent vector
                    latent_vector = np.random.uniform(-1, 1, size=(batch_size, latent_size))
                    latent_vector = torch.from_numpy(latent_vector).float()
                    if cuda:
                        latent_vector = latent_vector.cuda()
                    fake_images = G(latent_vector)

                    # compute discriminator loss on real images
                    d_optimizer.zero_grad()
                    d_real = D(real_images)
                    d_real_loss = real_loss(d_real, smooth=True)

                    # compute discriminator loss in fake images
                    d_fake = D(fake_images)
                    d_fake_loss = fake_loss(d_fake)

                    # total loss, backprop, optimize and update weights
                    d_loss = d_real_loss + d_fake_loss
                    d_loss_value = d_loss.item()

                    d_loss.backward()
                    d_optimizer.step()

                # TRAIN GENERATOR
                if phase == 'generator':
                    latent_vector = np.random.uniform(-1, 1, size=(batch_size, latent_size))
                    latent_vector = torch.from_numpy(latent_vector).float()
                    if cuda:
                        latent_vector = latent_vector.cuda()
                    fake_images = G(latent_vector)
                    
                    g_optimizer.zero_grad()
                    d_fake = D(fake_images)
                    g_loss = real_loss(d_fake)
                    g_loss_value = g_loss.item()

                    g_loss.backward()
                    g_optimizer.step()

            if idx % 100 == 0:
                wandb.log({ 'G Loss': g_loss_value, 'D Loss': d_loss_value })
        wandb.log({ 'G Epoch Loss': g_loss_value, 'D Epoch Loss': d_loss_value }, step=epoch)
        
        # test performance
        G.eval()
        gen_img = G(fixed_img)
        gen_imgs.append(gen_img)
    
    # dump generated images
    with open('gen_imgs.pkl', 'wb') as f:
        pkl.dump(gen_imgs, f)
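wandb.watch() needs an active run, which is not created inside the snippet; presumably the driver calls wandb.init() before train(), roughly as below (the project name and hyperparameters are illustrative):

if __name__ == '__main__':
    wandb.init(project='gan-training')  # illustrative project name
    train(batch_size=64, latent_size=100, learning_rate=2e-3, num_epochs=100)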
Code Example #11
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        m.weight.data.normal_(0.0, 0.02)


# create the objects for loss function, two networks and for the two optimizers
batchnorm = True
adversarial_loss = torch.nn.BCELoss()
generator = Generator(batchnorm=batchnorm,
                      latent=opt.latent,
                      img_shape=img_shape)
discriminator = Discriminator(batchnorm=batchnorm, img_shape=img_shape)
optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=opt.learning_rate,
                               betas=(opt.beta_1, opt.beta_2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=opt.learning_rate,
                               betas=(opt.beta_1, opt.beta_2))

# put the nets on device - if a cuda gpu is installed it will use it
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator, discriminator = generator.to(device), discriminator.to(device)

# initialize weights from random distribution with mean 0 and std 0.02
generator.apply(weights_init)
discriminator.apply(weights_init)

if batchnorm:
    if not os.path.isdir(ROOT_DIR + "/images-batchnorm"):
        os.mkdir(ROOT_DIR + "/images-batchnorm")
else:
Code Example #12
    os.mkdir("./model")

batch_size = 64
z_dimension = 100
num_epoch = 1000

D = Discriminator()
G = Generator(z_dimension)

# if torch.cuda.is_available:
D = D.cuda()
G = G.cuda()

criterion = nn.BCELoss()

d_optimizer = torch.optim.Adam(D.parameters(), lr=0.001)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.001)

for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        real_img = Variable(img).cuda()
        real_label = Variable(torch.ones(num_img)).cuda()
        fake_label = Variable(torch.zeros(num_img)).cuda()

        real_out = D(real_img)
        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out

        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
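The loop is cut off right after the fake images are generated. A sketch of how an iteration of this style of GAN loop typically continues, reusing the names from the excerpt; the exact update order is an assumption:

        # discriminator loss on the fake images, then update D
        fake_out = D(fake_img.detach())
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # generator update: D should score the fakes as real
        z = Variable(torch.randn(num_img, z_dimension)).cuda()
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()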
Code Example #13
print('device:', device)

# Generator G: generates fake images from random vectors
netG = Generator(nz=nz, nch_g=nch_g).to(device)
netG.apply(weights_init)  # initialize with the weights_init function
print(netG)

# Discriminator D: classifies images as real or fake
netD = Discriminator(nch_d=nch_d).to(device)
netD.apply(weights_init)
print(netD)

criterion = nn.MSELoss()  # the loss function is mean squared error

# Set up the optimizers
optimizerD = optim.Adam(netD.parameters(),
                        lr=lr,
                        betas=(beta1, 0.999),
                        weight_decay=1e-5)  # for discriminator D
optimizerG = optim.Adam(netG.parameters(),
                        lr=lr,
                        betas=(beta1, 0.999),
                        weight_decay=1e-5)  # for generator G

fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)  # fixed noise for monitoring generated samples

# Training loop
for epoch in range(n_epoch):
    for itr, data in enumerate(dataloader):
        real_image = data[0].to(device)  # real images
        sample_size = real_image.size(0)  # number of images in the batch
Code Example #14
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset',
                        required=True,
                        help='cifar10 | lsun | imagenet | folder | lfw | fake')
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batchSize',
                        type=int,
                        default=50,
                        help='input batch size')
    parser.add_argument(
        '--imageSize',
        type=int,
        default=64,
        help='the height / width of the input image to network')
    parser.add_argument('--nz',
                        type=int,
                        default=100,
                        help='size of the latent z vector')
    parser.add_argument('--nch_gen', type=int, default=512)
    parser.add_argument('--nch_dis', type=int, default=512)
    parser.add_argument('--nepoch',
                        type=int,
                        default=1000,
                        help='number of epochs to train for')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='learning rate, default=0.001')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.9,
                        help='beta1 for adam. default=0.9')
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    parser.add_argument('--ngpu',
                        type=int,
                        default=1,
                        help='number of GPUs to use')
    parser.add_argument('--gen',
                        default='',
                        help="path to gen (to continue training)")
    parser.add_argument('--dis',
                        default='',
                        help="path to dis (to continue training)")
    parser.add_argument('--outf',
                        default='./result',
                        help='folder to output images and model checkpoints')
    parser.add_argument('--manualSeed', type=int, help='manual seed')

    args = parser.parse_args()
    print(args)

    try:
        os.makedirs(args.outf)
    except OSError:
        pass

    if args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", args.manualSeed)
    random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)

    cudnn.benchmark = True

    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )

    if args.dataset in ['imagenet', 'folder', 'lfw']:
        # folder dataset
        dataset = dset.ImageFolder(root=args.dataroot,
                                   transform=transforms.Compose([
                                       transforms.Resize(args.imageSize),
                                       transforms.CenterCrop(args.imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif args.dataset == 'lsun':
        dataset = dset.LSUN(root=args.dataroot,
                            classes=['bedroom_train'],
                            transform=transforms.Compose([
                                transforms.Resize(args.imageSize),
                                transforms.CenterCrop(args.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5)),
                            ]))
    elif args.dataset == 'cifar10':
        dataset = dset.CIFAR10(root=args.dataroot,
                               download=True,
                               transform=transforms.Compose([
                                   transforms.Resize(args.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))  # [0, +1] -> [-1, +1]
    elif args.dataset == 'fake':
        dataset = dset.FakeData(image_size=(3, args.imageSize, args.imageSize),
                                transform=transforms.ToTensor())

    assert dataset
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batchSize,
                                             shuffle=True,
                                             num_workers=int(args.workers))

    device = torch.device("cuda:0" if args.cuda else "cpu")
    nch_img = 3

    # custom weights initialization called on gen and dis
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
            m.bias.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            # m.bias.data.normal_(1.0, 0.02)
            # m.bias.data.fill_(0)

    gen = Generator(args.ngpu, args.nz, args.nch_gen, nch_img).to(device)
    gen.apply(weights_init)
    if args.gen != '':
        gen.load_state_dict(torch.load(args.gen))

    dis = Discriminator(args.ngpu, args.nch_dis, nch_img).to(device)
    dis.apply(weights_init)
    if args.dis != '':
        dis.load_state_dict(torch.load(args.dis))

    # criterion = nn.BCELoss()
    criterion = nn.MSELoss()

    # fixed_z = torch.randn(args.batchSize, args.nz, 1, 1, device=device)
    fixed_z = torch.randn(8 * 8, args.nz, 1, 1, device=device)
    a_label = 0  # target value for fake samples
    b_label = 1  # target value for real samples
    c_label = 1  # target the generator drives its fakes toward

    # setup optimizer
    optim_dis = optim.Adam(dis.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, 0.999))
    optim_gen = optim.Adam(gen.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, 0.999))

    for epoch in range(args.nepoch):
        for itr, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            dis.zero_grad()
            real_img = data[0].to(device)
            batch_size = real_img.size(0)
            label = torch.full((batch_size, ), b_label, dtype=torch.float, device=device)  # float targets for MSELoss

            dis_real = dis(real_img)
            loss_dis_real = criterion(dis_real, label)
            loss_dis_real.backward()

            # train with fake
            z = torch.randn(batch_size, args.nz, 1, 1, device=device)
            fake_img = gen(z)
            label.fill_(a_label)

            dis_fake1 = dis(fake_img.detach())
            loss_dis_fake = criterion(dis_fake1, label)
            loss_dis_fake.backward()

            loss_dis = loss_dis_real + loss_dis_fake
            optim_dis.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            gen.zero_grad()
            label.fill_(c_label)  # fake labels are real for generator cost

            dis_fake2 = dis(fake_img)
            loss_gen = criterion(dis_fake2, label)
            loss_gen.backward()
            optim_gen.step()

            if (itr + 1) % 100 == 0:
                print(
                    '[{}/{}][{}/{}] LossD:{:.4f} LossG:{:.4f} D(x):{:.4f} D(G(z)):{:.4f}/{:.4f}'
                    .format(epoch + 1, args.nepoch, itr + 1, len(dataloader),
                            loss_dis.item(), loss_gen.item(),
                            dis_real.mean().item(),
                            dis_fake1.mean().item(),
                            dis_fake2.mean().item()))
            # loop end iteration

        if epoch == 0:
            vutils.save_image(real_img,
                              '{}/real_samples.png'.format(args.outf),
                              normalize=True)

        fake_img = gen(fixed_z)
        vutils.save_image(fake_img.detach(),
                          '{}/fake_samples_epoch_{:04}.png'.format(
                              args.outf, epoch),
                          normalize=True)

        # do checkpointing
        torch.save(gen.state_dict(),
                   '{}/gen_epoch_{}.pth'.format(args.outf, epoch))
        torch.save(dis.state_dict(),
                   '{}/dis_epoch_{}.pth'.format(args.outf, epoch))
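The script above is configured entirely from the command line. A hypothetical invocation (the file name train_lsgan.py is assumed, not taken from the source) might look like:

python train_lsgan.py --dataset cifar10 --dataroot ./dataset/cifar10_root --cuda --nepoch 25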
Code Example #15
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batch',
                        type=int,
                        default=50,
                        help='input batch size')
    parser.add_argument('--nz',
                        type=int,
                        default=100,
                        help='size of the latent z vector')
    parser.add_argument('--ng_ch', type=int, default=64)
    parser.add_argument('--nd_ch', type=int, default=64)
    parser.add_argument('--epoch',
                        type=int,
                        default=50,
                        help='number of epochs to train for')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='learning rate, default=0.0002')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.5,
                        help='beta1 for adam. default=0.5')
    parser.add_argument('--outf',
                        default='./result',
                        help='folder to output images and model checkpoints')

    opt = parser.parse_args()
    print(opt)

    batch_size = opt.batch
    epoch_size = opt.epoch

    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    random.seed(0)
    torch.manual_seed(0)

    dataset = dset.SVHN(root='../svhn_root',
                        download=True,
                        transform=transforms.Compose([
                            transforms.Resize(64),
                            transforms.ColorJitter(brightness=0,
                                                   contrast=0,
                                                   saturation=0,
                                                   hue=0.5),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5)),
                        ]))

    # dset.SVHN does not support slicing; use a Subset to keep the first 50,000 samples
    subset = torch.utils.data.Subset(dataset, range(50000))
    dataloader = torch.utils.data.DataLoader(subset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device:', device)

    nz = int(opt.nz)

    netG = Generator().to(device)
    netG.apply(weights_init)
    print(netG)

    netD = Discriminator().to(device)
    netD.apply(weights_init)
    print(netD)

    criterion = nn.MSELoss()  # criterion = nn.BCELoss()

    fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)
    real_label = 1
    fake_label = 0

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)

    for epoch in range(epoch_size):
        for itr, data in enumerate(dataloader):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real_image = data[0].to(device)

            sample_size = real_image.size(0)
            label = torch.full((sample_size, ), real_label, dtype=torch.float, device=device)  # float targets for MSELoss

            output = netD(real_image)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # train with fake
            noise = torch.randn(sample_size, nz, 1, 1, device=device)
            fake_image = netG(noise)
            label.fill_(fake_label)
            output = netD(fake_image.detach())

            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake_image)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            print(
                '[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
                .format(epoch + 1, epoch_size, itr + 1, len(dataloader),
                        errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            if epoch == 0 and itr == 0:
                vutils.save_image(real_image,
                                  '{}/real_samples.png'.format(opt.outf),
                                  normalize=True,
                                  nrow=10)

        fake_image = netG(fixed_noise)
        vutils.save_image(fake_image.detach(),
                          '{}/fake_samples_epoch_{:03d}.png'.format(
                              opt.outf, epoch + 1),
                          normalize=True,
                          nrow=10)

        # do checkpointing
        if (epoch + 1) % 100 == 0:
            torch.save(netG.state_dict(),
                       '{}/netG_epoch_{}.pth'.format(opt.outf, epoch + 1))
            torch.save(netD.state_dict(),
                       '{}/netD_epoch_{}.pth'.format(opt.outf, epoch + 1))
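Code Examples #13 through #15 all replace nn.BCELoss with nn.MSELoss, i.e. the least-squares GAN objective. The self-contained toy below (not part of any example above) simply evaluates both criteria on the same hypothetical discriminator scores to make the swap concrete:

import torch
import torch.nn as nn

scores = torch.tensor([0.9, 0.6, 0.1])   # hypothetical D outputs in [0, 1]
real_target = torch.ones_like(scores)

print(nn.MSELoss()(scores, real_target))  # least-squares (LSGAN) loss toward the real label
print(nn.BCELoss()(scores, real_target))  # original GAN cross-entropy loss toward the real label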
Code Example #16
File: train.py  Project: loretoparisi/papers
      discriminator.zero_grad()

      # calculate loss for batch src projected in target
      discProjSrc = discriminator(projectedSrc).squeeze()
      loss = loss_fn(discProjSrc, zeroClass)
      discLoss = discLoss + loss.item()  # .data[0] no longer works on 0-dim tensors; use .item()
      loss.backward()

      # loss for tgt classified with smoothed label
      discTgt = discriminator(batch_tgt).squeeze()
      loss = loss_fn(discTgt, smoothedOneClass)
      discLoss = discLoss + loss.item()
      loss.backward()

      for param in discriminator.parameters():
        param.data -= learningRate * param.grad.data

    bsrcIdx = torch.min((torch.rand(args.batchSize)*N).long(), torch.LongTensor([N-1]))
    batch_src = Variable(torch.index_select(semb, 0, bsrcIdx))
    if args.gpuid>=0:
      with torch.cuda.device(args.gpuid):
        batch_src = batch_src.cuda()

    # calculate loss for batch src projected in target
    projectedSrc = generator(batch_src)
    discProjSrc = discriminator(projectedSrc).squeeze()

    generator.zero_grad()
    loss = loss_fn(discProjSrc, oneClass)
    genLoss = genLoss + loss.item()
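The fragment above updates the discriminator by hand with param.data -= learningRate * param.grad.data. The self-contained check below (all names are made up, not from the fragment) confirms that this manual step is equivalent to a plain torch.optim.SGD step at the same learning rate:

import torch
import torch.nn as nn
import torch.optim as optim

lr = 0.1
layer_a = nn.Linear(4, 1)
layer_b = nn.Linear(4, 1)
layer_b.load_state_dict(layer_a.state_dict())  # identical initial weights

x = torch.randn(8, 4)
layer_a(x).sum().backward()
layer_b(x).sum().backward()

with torch.no_grad():
    for p in layer_a.parameters():
        p.data -= lr * p.grad.data             # manual update, as in the fragment

optim.SGD(layer_b.parameters(), lr=lr).step()  # built-in equivalent

print(torch.allclose(layer_a.weight, layer_b.weight))  # True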
Code Example #17
class GAN(object):
    def __init__(self):
        warnings.filterwarnings('ignore')
        self.start_time = time()

        self.args = get_args()
        if self.args.checkpoint_dir_name:
            dir_name = self.args.checkpoint_dir_name
        else:
            dir_name = datetime.datetime.now().strftime('%y%m%d%H%M%S')
        self.path_to_dir = Path(__file__).resolve().parents[1]
        self.path_to_dir = os.path.join(self.path_to_dir, *['log', dir_name])
        os.makedirs(self.path_to_dir, exist_ok=True)

        # tensorboard
        path_to_tensorboard = os.path.join(self.path_to_dir, 'tensorboard')
        os.makedirs(path_to_tensorboard, exist_ok=True)
        self.writer = SummaryWriter(path_to_tensorboard)

        # model saving
        os.makedirs(os.path.join(self.path_to_dir, 'model'), exist_ok=True)
        path_to_model = os.path.join(self.path_to_dir, *['model', 'model.tar'])

        # csv
        os.makedirs(os.path.join(self.path_to_dir, 'csv'), exist_ok=True)
        self.path_to_results_csv = os.path.join(self.path_to_dir,
                                                *['csv', 'results.csv'])
        path_to_args_csv = os.path.join(self.path_to_dir, *['csv', 'args.csv'])
        if not self.args.checkpoint_dir_name:
            with open(path_to_args_csv, 'a') as f:
                args_dict = vars(self.args)
                param_writer = csv.DictWriter(f, list(args_dict.keys()))
                param_writer.writeheader()
                param_writer.writerow(args_dict)

        # logging by hyperdash
        if not self.args.no_hyperdash:
            from hyperdash import Experiment
            self.exp = Experiment('Generation task on ' + self.args.dataset +
                                  ' dataset with GAN')
            for key in vars(self.args).keys():
                exec("self.args.%s = self.exp.param('%s', self.args.%s)" %
                     (key, key, key))
        else:
            self.exp = None

        self.dataloader = get_dataloader(self.args.dataset,
                                         self.args.image_size,
                                         self.args.batch_size)
        sample_data = self.dataloader.__iter__().__next__()[0]
        image_channels = sample_data.shape[1]

        z = torch.randn(self.args.batch_size, self.args.z_dim)
        self.sample_z = z.view(z.size(0), z.size(1), 1, 1)

        self.Generator = Generator(self.args.z_dim, image_channels,
                                   self.args.image_size)
        self.Generator_optimizer = optim.Adam(self.Generator.parameters(),
                                              lr=self.args.lr_Generator,
                                              betas=(self.args.beta1,
                                                     self.args.beta2))
        self.writer.add_graph(self.Generator, self.sample_z)
        self.Generator.to(self.args.device)

        self.Discriminator = Discriminator(image_channels,
                                           self.args.image_size)
        self.Discriminator_optimizer = optim.Adam(
            self.Discriminator.parameters(),
            lr=self.args.lr_Discriminator,
            betas=(self.args.beta1, self.args.beta2))
        self.writer.add_graph(self.Discriminator, sample_data)
        self.Discriminator.to(self.args.device)

        self.BCELoss = nn.BCELoss()

        self.sample_z = self.sample_z.to(self.args.device)

    def train(self):
        self.train_hist = {}
        self.train_hist['Generator_loss'] = 0.0
        self.train_hist['Discriminator_loss'] = 0.0

        # real ---> y = 1
        # fake ---> y = 0
        self.y_real = torch.ones(self.args.batch_size, 1).to(self.args.device)
        self.y_fake = torch.zeros(self.args.batch_size, 1).to(self.args.device)

        self.Discriminator.train()

        global_step = 0
        #  -----training -----
        for epoch in range(1, self.args.n_epoch + 1):
            self.Generator.train()
            for idx, (x, _) in enumerate(self.dataloader):
                if idx == self.dataloader.dataset.__len__(
                ) // self.args.batch_size:
                    break

                z = torch.randn(self.args.batch_size, self.args.z_dim)
                z = z.view(z.size(0), z.size(1), 1, 1)
                z = z.to(self.args.device)
                x = x.to(self.args.device)

                # ----- update Discriminator -----
                # minimize: -{ log[D(x)] + log[1-D(G(z))] }
                self.Discriminator_optimizer.zero_grad()
                # real
                # ---> log[D(x)]
                Discriminator_real, _ = self.Discriminator(x)
                Discriminator_real_loss = self.BCELoss(Discriminator_real,
                                                       self.y_real)
                # fake
                # ---> log[1-D(G(z))]
                Discriminator_fake, _ = self.Discriminator(self.Generator(z))
                Discriminator_fake_loss = self.BCELoss(Discriminator_fake,
                                                       self.y_fake)

                Discriminator_loss = Discriminator_real_loss + Discriminator_fake_loss
                self.train_hist[
                    'Discriminator_loss'] = Discriminator_loss.item()

                Discriminator_loss.backward()
                self.Discriminator_optimizer.step()

                # ----- update Generator -----
                # As stated in the original paper,
                # we want to train the Generator
                # by minimizing log(1−D(G(z)))
                # in an effort to generate better fakes.
                # As mentioned, this was shown by Goodfellow
                # to not provide sufficient gradients,
                # especially early in the learning process.
                # As a fix, we instead wish to maximize log(D(G(z))).
                # ---> minimize: -log[D(G(z))]

                self.Generator_optimizer.zero_grad()
                Discriminator_fake, _ = self.Discriminator(self.Generator(z))
                Generator_loss = self.BCELoss(Discriminator_fake, self.y_real)
                self.train_hist['Generator_loss'] = Generator_loss.item()
                Generator_loss.backward()
                self.Generator_optimizer.step()

                # ----- logging by tensorboard, csv and hyperdash
                # tensorboard
                self.writer.add_scalar('loss/Generator_loss',
                                       Generator_loss.item(), global_step)
                self.writer.add_scalar('loss/Discriminator_loss',
                                       Discriminator_loss.item(), global_step)
                # csv
                with open(self.path_to_results_csv, 'a') as f:
                    result_writer = csv.DictWriter(
                        f, list(self.train_hist.keys()))
                    if epoch == 1 and idx == 0: result_writer.writeheader()
                    result_writer.writerow(self.train_hist)
                # hyperdash
                if self.exp:
                    self.exp.metric('Generator loss', Generator_loss.item())
                    self.exp.metric('Discriminator loss',
                                    Discriminator_loss.item())

                if (idx % 10) == 0:
                    self._plot_sample(global_step)
                global_step += 1

        elapsed_time = time() - self.start_time
        print('\nTraining finished, elapsed time ---> %f' % elapsed_time)

    def _plot_sample(self, global_step):
        with torch.no_grad():
            total_n_sample = min(self.args.n_sample, self.args.batch_size)
            image_frame_dim = int(np.floor(np.sqrt(total_n_sample)))
            samples = self.Generator(self.sample_z)
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
            samples = (samples + 1) / 2
            fig = plt.figure(figsize=(24, 15))
            for i in range(image_frame_dim * image_frame_dim):
                ax = fig.add_subplot(
                    image_frame_dim,
                    image_frame_dim * 2,
                    (int(i / image_frame_dim) + 1) * image_frame_dim + i + 1,
                    xticks=[],
                    yticks=[])
                if samples[i].shape[2] == 3:
                    ax.imshow(samples[i])
                else:
                    ax.imshow(samples[i][:, :, 0], cmap='gray')
            self.writer.add_figure('sample images generated by GAN', fig,
                                   global_step)
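A minimal driver for the class above would presumably be just the following (assuming get_args supplies every attribute referenced in __init__):

if __name__ == '__main__':
    gan = GAN()   # builds the models, dataloader and loggers from get_args()
    gan.train()   # runs the training loop defined above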
Code Example #18
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--batch_size',
                        type=int,
                        default=50,
                        help='input batch size')
    parser.add_argument('--nz',
                        type=int,
                        default=100,
                        help='size of the latent z vector')
    parser.add_argument('--nch_g', type=int, default=64)
    parser.add_argument('--nch_d', type=int, default=64)
    parser.add_argument('--n_epoch',
                        type=int,
                        default=200,
                        help='number of epochs to train for')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0002,
                        help='learning rate, default=0.0002')
    parser.add_argument('--beta1',
                        type=float,
                        default=0.5,
                        help='beta1 for adam. default=0.5')
    parser.add_argument('--outf',
                        default='./result_cgan',
                        help='folder to output images and model checkpoints')

    opt = parser.parse_args()
    print(opt)

    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    trainset = dset.STL10(root='./dataset/stl10_root',
                          download=True,
                          split='train',
                          transform=transforms.Compose([
                              transforms.RandomResizedCrop(64,
                                                           scale=(88 / 96,
                                                                  1.0),
                                                           ratio=(1., 1.)),
                              transforms.RandomHorizontalFlip(),
                              transforms.ColorJitter(brightness=0.05,
                                                     contrast=0.05,
                                                     saturation=0.05,
                                                     hue=0.05),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5),
                                                   (0.5, 0.5, 0.5)),
                          ]))  # labels are used, so the unlabeled split is not included
    testset = dset.STL10(root='./dataset/stl10_root',
                         download=True,
                         split='test',
                         transform=transforms.Compose([
                             transforms.RandomResizedCrop(64,
                                                          scale=(88 / 96, 1.0),
                                                          ratio=(1., 1.)),
                             transforms.RandomHorizontalFlip(),
                             transforms.ColorJitter(brightness=0.05,
                                                    contrast=0.05,
                                                    saturation=0.05,
                                                    hue=0.05),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5),
                                                  (0.5, 0.5, 0.5)),
                         ]))
    dataset = trainset + testset

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=True,
                                             num_workers=int(opt.workers))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device:', device)

    # Generator G: generates fake images from a random vector concatenated with a label
    netG = Generator(nz=opt.nz + 10, nch_g=opt.nch_g).to(
        device)  # the input vector size is the random-vector size nz plus the 10 classes
    netG.apply(weights_init)
    print(netG)

    # Discriminator D: decides whether an image-label tensor comes from a real or a fake image
    netD = Discriminator(nch=3 + 10, nch_d=opt.nch_d).to(
        device)  # the input tensor has the 3 image channels plus the 10 classes
    netD.apply(weights_init)
    print(netD)

    criterion = nn.MSELoss()

    optimizerD = optim.Adam(netD.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)
    optimizerG = optim.Adam(netG.parameters(),
                            lr=opt.lr,
                            betas=(opt.beta1, 0.999),
                            weight_decay=1e-5)

    fixed_noise = torch.randn(opt.batch_size, opt.nz, 1, 1, device=device)

    fixed_label = [i for i in range(10)] * (opt.batch_size // 10)  # labels for monitoring: 0-9 repeated
    fixed_label = torch.tensor(fixed_label, dtype=torch.long, device=device)

    fixed_noise_label = concat_noise_label(fixed_noise, fixed_label,
                                           device)  # concatenate the monitoring noise and labels

    # training loop
    for epoch in range(opt.n_epoch):
        for itr, data in enumerate(dataloader):
            real_image = data[0].to(device)  # real images
            real_label = data[1].to(device)  # labels corresponding to the real images
            real_image_label = concat_image_label(real_image, real_label,
                                                  device)  # concatenate the real images and their labels

            sample_size = real_image.size(0)
            noise = torch.randn(sample_size, opt.nz, 1, 1, device=device)
            fake_label = torch.randint(10, (sample_size, ),
                                       dtype=torch.long,
                                       device=device)  # labels for generating the fake images
            fake_noise_label = concat_noise_label(noise, fake_label,
                                                  device)  # concatenate the noise and labels

            real_target = torch.full((sample_size, ), 1., device=device)
            fake_target = torch.full((sample_size, ), 0., device=device)

            ############################
            # update the discriminator D
            ###########################
            netD.zero_grad()

            output = netD(real_image_label)  # D scores the real image-label pair
            errD_real = criterion(output, real_target)
            D_x = output.mean().item()

            fake_image = netG(fake_noise_label)  # G generates a fake image matching the label
            fake_image_label = concat_image_label(fake_image, fake_label,
                                                  device)  # concatenate the fake image and its label

            output = netD(
                fake_image_label.detach())  # D scores the fake image-label pair
            errD_fake = criterion(output, fake_target)
            D_G_z1 = output.mean().item()

            errD = errD_real + errD_fake
            errD.backward()
            optimizerD.step()

            ############################
            # update the generator G
            ###########################
            netG.zero_grad()

            output = netD(
                fake_image_label)  # score the fake image-label pair again with the updated D
            errG = criterion(output, real_target)
            errG.backward()
            D_G_z2 = output.mean().item()

            optimizerG.step()

            print(
                '[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
                .format(epoch + 1, opt.n_epoch, itr + 1, len(dataloader),
                        errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            if epoch == 0 and itr == 0:
                vutils.save_image(real_image,
                                  '{}/real_samples.png'.format(opt.outf),
                                  normalize=True,
                                  nrow=10)

        ############################
        # generate monitoring images
        ############################
        fake_image = netG(
            fixed_noise_label)  # after every epoch, generate fakes for the fixed labels
        vutils.save_image(fake_image.detach(),
                          '{}/fake_samples_epoch_{:03d}.png'.format(
                              opt.outf, epoch + 1),
                          normalize=True,
                          nrow=10)

        ############################
        # save the models
        ############################
        if (epoch + 1) % 50 == 0:
            torch.save(netG.state_dict(),
                       '{}/netG_epoch_{}.pth'.format(opt.outf, epoch + 1))
            torch.save(netD.state_dict(),
                       '{}/netD_epoch_{}.pth'.format(opt.outf, epoch + 1))
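Code Example #18 relies on two helpers, concat_noise_label and concat_image_label, that are not shown. A plausible sketch consistent with the comments above (a generator input of nz + 10 dimensions and a discriminator input of 3 + 10 channels), assuming one-hot label encoding:

import torch

def onehot_encode(label, device, n_class=10):
    # integer class labels -> one-hot tensors of shape (batch, n_class, 1, 1)
    eye = torch.eye(n_class, device=device)
    return eye[label].view(-1, n_class, 1, 1)

def concat_noise_label(noise, label, device):
    # append the one-hot label to the noise vector along the channel axis
    oh = onehot_encode(label, device)
    return torch.cat((noise, oh), dim=1)

def concat_image_label(image, label, device):
    # broadcast the one-hot label over the spatial dims and append it to the image
    b, _, h, w = image.shape
    oh = onehot_encode(label, device).expand(b, 10, h, w)
    return torch.cat((image, oh), dim=1)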