Example #1
def main():
    # Load dataset
    print('Loading dataset ...\n')
    dataset_train = Dataset(train=True)
    dataset_val = Dataset(train=False)
    loader_train = DataLoader(dataset=dataset_train,
                              num_workers=4,
                              batch_size=opt.batchSize,
                              shuffle=True)
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # Build modelG
    resume_epoch = 0
    modelG = Gen(channels=opt.channels)
    modelG.apply(weights_init_kaiming)
    Gparam = sum(param.numel() for param in modelG.parameters())
    print('# modelG parameters:', Gparam)

    if opt.modelG != '':
        checkpoint = torch.load(
            opt.modelG, map_location=lambda storage, location: storage)
        modelG.load_state_dict(checkpoint['state_dict'])
        resume_epoch = checkpoint['epoch']

    # Build modelD
    modelD = Dis(channels=opt.channels)
    modelD.apply(weights_init_kaiming)
    Dparam = sum(param.numel() for param in modelD.parameters())
    print('# modelD parameters:', Dparam)

    if opt.modelD != '':
        checkpoint = torch.load(
            opt.modelD, map_location=lambda storage, location: storage)
        modelD.load_state_dict(checkpoint['state_dict'])
        resume_epoch = checkpoint['epoch']

    criterionBCE = nn.BCELoss()
    criterionMSE = nn.MSELoss()

    modelG.cuda()
    modelD.cuda()
    criterionBCE.cuda()
    criterionMSE.cuda()

    label = torch.FloatTensor(opt.batchSize)
    real_label = 1
    fake_label = 0
    label = Variable(label)

    # Optimizer
    optimizerG = optim.Adam(modelG.parameters(), lr=opt.lr)
    optimizerD = optim.Adam(modelD.parameters(), lr=opt.lr)

    # training
    step = 0
    for epoch in range(opt.epochs):
        if epoch < opt.milestone:
            current_lr = opt.lr
        else:
            current_lr = opt.lr / 10.
        # set learning rate
        for param_group in optimizerG.param_groups:
            param_group["lr"] = current_lr

        print('learning rate %f' % current_lr)
        # train
        for i, data in enumerate(loader_train, 0):
            # data
            img_train = data
            batch_size = img_train.size(0)

            noise = torch.zeros(img_train.size())
            stdN = np.random.uniform(opt.train_noise[0],
                                     opt.train_noise[1],
                                     size=noise.size()[0])
            for n in range(noise.size()[0]):
                sizeN = noise[0, :, :, :].size()
                noise[n, :, :, :] = torch.FloatTensor(sizeN).normal_(
                    mean=0, std=stdN[n] / 255.)
            imgn_train = img_train + noise
            img_train, imgn_train = Variable(img_train.cuda()), Variable(
                imgn_train.cuda())
            noise = Variable(noise.cuda())

            # train D on 128x128 patches (a 2x2 grid of crops)
            fake = modelG(imgn_train)
            label.data.resize_(batch_size)
            for index1 in range(0, 255, 128):
                for index2 in range(0, 255, 128):
                    modelD.zero_grad()

                    # real (noisy) patch
                    label.data.fill_(real_label)
                    img_trainT = imgn_train[:, :, index1:index1 + 128,
                                            index2:index2 + 128]
                    output = modelD(img_trainT)
                    errD_real = criterionBCE(output, label)
                    errD_real.backward()

                    # generated patch, detached so only D is updated
                    fakeT = fake[:, :, index1:index1 + 128,
                                 index2:index2 + 128]
                    label.data.fill_(fake_label)
                    output = modelD(fakeT.detach())
                    errD_fake = criterionBCE(output, label)
                    errD_fake.backward()

                    errD = errD_real + errD_fake
                    optimizerD.step()

            # train G
            modelG.train()
            modelG.zero_grad()
            optimizerG.zero_grad()
            label.data.fill_(real_label)
            errG_D = 0
            for index1 in range(0, 255, 128):
                for index2 in range(0, 255, 128):
                    fakeT = fake[:, :, index1:index1 + 128,
                                 index2:index2 + 128]
                    output = modelD(fakeT)
                    errG_D += criterionBCE(output, label) / 4.

            out_train = modelG(imgn_train)
            loss = criterionMSE(out_train, noise) + 0.01 * errG_D
            loss.backward()
            optimizerG.step()

            # results
            modelG.eval()
            denoise_image = torch.clamp(imgn_train - modelG(imgn_train), 0.,
                                        1.)
            psnr_train = batch_PSNR(denoise_image, img_train, 1.)

            print(
                "[epoch %d][%d/%d] Loss_G: %.4f PSNR_train: %.4f" %
                (epoch + 1, i + 1, len(loader_train), loss.item(), psnr_train))
            step += 1

        # save the generator checkpoint
        torch.save({
            'epoch': epoch + 1,
            'state_dict': modelG.state_dict()
        }, 'model/modelG.pth')
Example #2
dataset = Data(nb=1000)

assert dataset
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=opt.batchSize,
                                         shuffle=True,
                                         num_workers=int(opt.workers))

device = torch.device("cuda:0" if opt.cuda else "cpu")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3  # channel count; assumed, since nc is used below but never defined in this fragment

netG = Gen(nz=nz, nc=nc, w=opt.imageSize).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)
netD = Discr(nc=nc, w=opt.imageSize).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

criterion = nn.BCELoss()

fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
real_label = 1
fake_label = 0
Example #3
def main():
    # Build model

    print('Loading model ...\n')
    model = Gen(channels=1)
    model.cuda()
    model.load_state_dict(
        torch.load(
            opt.modelG,
            map_location=lambda storage, location: storage)['state_dict'])
    model.eval()
    # load data info
    print('Loading data info ...\n')
    types = ('*.bmp', '*.png', '*.jpg')
    files = []
    for im in types:
        files.extend(glob.glob(os.path.join(opt.dataroot, im)))
    files.sort()
    # process data
    psnr_test = 0
    ssim_test = 0
    results_psnr = []
    results_ssim = []
    it = 0
    for f in files:
        # image
        if opt.channels == 3:
            Img = cv2.imread(f)
            Img = (cv2.cvtColor(Img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
        else:
            Img = cv2.imread(f, cv2.IMREAD_GRAYSCALE)
            Img = np.expand_dims(Img, 0)

        Img = np.expand_dims(Img, 0)

        if Img.shape[2] % 32 != 0:
            Img = Img[:, :, :Img.shape[2] - Img.shape[2] % 32, :]

        if Img.shape[3] % 32 != 0:
            Img = Img[:, :, :, :Img.shape[3] - Img.shape[3] % 32]

        Img = normalize(Img)
        ISource = torch.Tensor(Img)
        N, C, H, W = ISource.size()

        dtype = torch.cuda.FloatTensor
        noise = torch.FloatTensor(ISource.size()).normal_(mean=0,
                                                          std=opt.noise_ratio /
                                                          255.)

        INoisy = ISource + noise
        ISource, INoisy = Variable(ISource.type(dtype)), Variable(
            INoisy.type(dtype))
        starttime = datetime.datetime.now()
        Out = torch.clamp(INoisy - model(INoisy), 0., 1.)
        endtime = datetime.datetime.now()

        psnr = batch_PSNR(Out, ISource, 1.)
        ssim = batch_SSIM(Out, ISource, 1.)
        psnr_noise = batch_PSNR(INoisy, ISource, 1.)  # quality of the noisy input, for reference
        ssim_noise = batch_SSIM(INoisy, ISource, 1.)
        results_psnr.append(psnr)
        results_ssim.append(ssim)
        psnr_test += psnr
        ssim_test += ssim
        print("%s PSNR %f SSIM %f" % (f, psnr, ssim))
        real = variable_to_cv2_image(ISource)
        denoise = variable_to_cv2_image(Out)
        noise = variable_to_cv2_image(INoisy)
        cv2.imwrite("./test-results/noise/%d.png" % it, noise)
        cv2.imwrite("./test-results/denoise/%d.png" % it, denoise)
        cv2.imwrite("./test-results/real/%d.png" % it, real)
        it += 1
    psnr_test /= len(files)
    ssim_test /= len(files)
    print("\nPSNR on test data %f SSIM on test data %f" %
          (psnr_test, ssim_test))
    std_psnr = np.std(results_psnr)
    std_ssim = np.std(results_ssim)
    print("\nPSNRstd on test data %f SSIMstd on test data %f" %
          (std_psnr, std_ssim))
    print(endtime - starttime)  # forward-pass time of the last image only
Example #4
#un-normalize output
mean, std = [0.0063, 0.0063, 0.9791], [0.0613, 0.0612, 0.1262]
unNormalize = transforms.Normalize(mean=[-m / d for m, d in zip(mean, std)],
                                   std=[1.0 / d for d in std])

#assign correct device
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

#Networks
saveFileGen = "gen.pth"
saveFileDiscrim = "discrim.pth"
gen = Gen().to(device)
try:
    gen.load_state_dict(torch.load(saveFileGen))
except FileNotFoundError:
    print("No weights found, generator initialized")
#discrim = Discrim().to(device)
#try:
#    discrim.load_state_dict(torch.load(saveFileDiscrim))
#except FileNotFoundError:
#    print("No weights found, discriminator initialized")

dirin = "database/validSetLines/input/"
dirout = "database/validSetLines/result/"
imMax = 100
for i in range(imMax):
    img = f'{i:05}.png'
Example #5
class Trainer:
    def __init__(self, opt, layers):

        # ==== model ====
        self.gen = Gen(n_features=opt.n_features,
                       layers=layers,
                       temperature=opt.temperature)
        self.dis = Dis(n_features=opt.n_features, layers=layers)

        # ==== optimizers
        # Adam
        if opt.opt == 'adam':
            self.optimizer_gen = torch.optim.Adam(
                self.gen.parameters(),
                lr=opt.g_lr,
                weight_decay=opt.weight_decay)
            self.optimizer_dis = torch.optim.Adam(
                self.dis.parameters(),
                lr=opt.d_lr,
                weight_decay=opt.weight_decay)
        # SGD
        else:
            self.optimizer_gen = torch.optim.SGD(self.gen.parameters(),
                                                 lr=opt.g_lr,
                                                 weight_decay=opt.weight_decay,
                                                 momentum=opt.momentum)
            self.optimizer_dis = torch.optim.SGD(self.dis.parameters(),
                                                 lr=opt.d_lr,
                                                 weight_decay=opt.weight_decay,
                                                 momentum=opt.momentum)
        # ==== cuda ====

        self.cuda = bool(torch.cuda.is_available() and opt.cuda)
        self.FloatTensor = torch.cuda.FloatTensor if self.cuda else torch.FloatTensor
        self.LongTensor = torch.cuda.LongTensor if self.cuda else torch.LongTensor
        if self.cuda:
            self.gen.cuda()
            self.dis.cuda()
        print('use cuda : {}'.format(self.cuda))

        # ==== opt settings ====
        self.n_positions = opt.n_positions
        self.n = opt.n
        self.tools_dir = opt.tools_dir
        self.p = opt.p
        self.norm = opt.norm

        # bias_i, bias_j
        self.t_plus = [1 for _ in range(self.n_positions)]
        self.t_minus = [1 for _ in range(self.n_positions)]
        self.exam = [1 for _ in range(self.n_positions)]

    """
    choose real click docs as pos and GEN output as neg
    """

    def get_sample_pairs(self, features, positions, clicks):

        # == click index ==
        all_f = Variable(features.type(self.FloatTensor))
        clicks = clicks.detach().cpu().numpy()
        index = np.where(clicks == 1)
        unique_group_index, counts_clicks = np.unique(index[0],
                                                      return_counts=True)

        # == positive and negative features ==
        pos_f = features[index]
        neg_f = []
        pos_p = positions[index]
        neg_p = []

        # == GEN output ==
        count = 0
        g_output = self.gen(all_f).detach().cpu().numpy()
        for i in range(len(unique_group_index)):
            g_index = unique_group_index[i]
            exp_rating = np.exp(g_output[g_index] - np.max(g_output[g_index]))
            # zero out already-clicked positions so they cannot be re-sampled
            for k in range(counts_clicks[i]):
                exp_rating[pos_p[count + k] - 1] = 0
            count += counts_clicks[i]
            prob = exp_rating / np.sum(exp_rating, axis=-1)
            try:
                neg_index = np.random.choice(self.n_positions,
                                             size=[counts_clicks[i]],
                                             p=prob)
            except ValueError:
                # fall back to uniform sampling if prob is degenerate
                neg_index = np.random.choice(self.n_positions,
                                             size=[counts_clicks[i]])

            choose_index = positions[g_index][neg_index].tolist()

            # resample once if a padding position (0) was drawn
            if 0 in choose_index:
                neg_index = np.random.choice(self.n_positions,
                                             size=[counts_clicks[i]],
                                             p=prob)
                choose_index = positions[g_index][neg_index].tolist()

            neg_f.extend(features[g_index][neg_index].tolist())
            neg_p.extend(choose_index)

        # == output ==
        pos_f = Variable((tensor(pos_f)).type(self.FloatTensor))
        neg_f = Variable((tensor(neg_f)).type(self.FloatTensor))
        pos_p = Variable((tensor(pos_p)).type(self.LongTensor))
        neg_p = Variable((tensor(neg_p)).type(self.LongTensor))

        try:
            pred_valid = self.dis(pos_f).view(-1)
        except RuntimeError:
            print(pos_f.size())
            raise
        pred_fake = self.dis(neg_f).view(-1)
        true_diffs = Variable(self.FloatTensor(len(pos_p)).fill_(1),
                              requires_grad=False)

        return pred_valid, pred_fake, true_diffs, pos_p, neg_p

    """
    train the discriminator
    """

    def train_dis(self,
                  pred_valid,
                  pred_fake,
                  true_diffs,
                  position_i,
                  position_j,
                  forward=True):

        pred_diffs = pred_valid - pred_fake

        # calculate pairwise propensity given (i,j)
        propensity = []
        pos_propensity = []  # computed below but not used in the loss
        neg_propensity = []

        for index in range(len(position_i)):
            i = position_i[index]
            j = position_j[index]
            if i != 0 and j != 0:
                prop = self.t_plus[i - 1] * self.t_minus[j - 1]
                if prop != 0:
                    propensity.append(prop)
                else:
                    propensity.append(1)
            else:
                propensity.append(1)
            if i != 0:
                pos_propensity.append(1 / self.t_plus[i - 1])
            else:
                pos_propensity.append(1)

            if j != 0:
                neg_propensity.append(1 / self.t_minus[j - 1])
            else:
                neg_propensity.append(0)

        propensity = Variable((tensor(propensity)).type(self.FloatTensor))
        true_diffs = true_diffs / propensity

        loss = pairwise_loss(pred_diffs, true_diffs)

        # ==== optimize ====
        if forward:
            self.optimizer_dis.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.dis.parameters(), self.norm)
            self.optimizer_dis.step()
        return loss.item()

    """
    update t_plus and t_minus
    """

    def estimate_bias(self, pred_valid, pred_fake, true_diffs, position_i,
                      position_j):
        pred_diffs = pred_valid - pred_fake
        pred_diffs = pred_diffs.detach().cpu().numpy()
        true_diffs = true_diffs.detach().cpu().numpy()
        position_i = position_i.detach().cpu().numpy()
        position_j = position_j.detach().cpu().numpy()
        # prepare t_j
        t_j = []
        for index in range(len(position_j)):
            j = position_j[index]
            t_j.append(self.t_minus[j - 1])

        # prepare t_i
        t_i = []
        for index in range(len(position_i)):
            i = position_i[index]
            t_i.append(self.t_plus[i - 1])

        # prepare loss
        loss = []
        for i in range(len(self.t_plus)):
            i = i + 1
            index = np.where(position_i == i)
            pred_diffs_i_j = pred_diffs[index]
            true_diffs_i_j = true_diffs[index]
            t_j_index = np.array(t_j)[index]
            loss_i_j = estimate_loss(pred_diffs=pred_diffs_i_j,
                                     true_diffs=true_diffs_i_j) / t_j_index
            loss.append(np.sum(loss_i_j))

        # update t_plus
        if loss[0] == 0:
            loss[0] = np.mean(loss)
        for i in range(len(self.t_plus)):
            if loss[i] == 0:
                loss[i] = np.mean(loss)
            # Eq.(5)
            self.t_plus[i] = np.power(loss[i] / loss[0], 1 / (self.p + 1))

        # prepare loss
        loss = []
        for j in range(len(self.t_minus)):
            j = j + 1
            index = np.where(position_j == j)
            pred_diffs_i_j = pred_diffs[index]
            true_diffs_i_j = true_diffs[index]
            t_i_index = np.array(t_i)[index]
            loss_i_j = estimate_loss(pred_diffs=pred_diffs_i_j,
                                     true_diffs=true_diffs_i_j) / t_i_index
            loss.append(np.sum(loss_i_j))

        # update t_minus
        if loss[0] == 0:
            loss[0] = np.mean(loss)
        for i in range(len(self.t_minus)):
            if loss[i] == 0:
                loss[i] = np.mean(loss)
            # Eq.(6)
            self.t_minus[i] = np.power(loss[i] / loss[0], 1 / (self.p + 1))

    """
    main function for training the discriminator and position bias
    """

    def train_dis_prop(self,
                       features,
                       positions,
                       clicks,
                       forward=True,
                       prop=True):

        if torch.sum(clicks) == 0:
            return 0
        # prepare dataset S
        pred_valid, pred_fake, true_diffs, position_i, position_j = self.get_sample_pairs(
            features, positions, clicks)
        # train dis Eq.(3)
        loss = self.train_dis(pred_valid,
                              pred_fake,
                              true_diffs,
                              position_i,
                              position_j,
                              forward=forward)

        # update position ratios if prop with Eq.(5)(6)
        if prop:
            self.estimate_bias(pred_valid, pred_fake, true_diffs, position_i,
                               position_j)

        return loss

    """
    train the generator
    """

    def train_gen(self, features, forward=True):

        # ==== sample and get reward from DIS====
        features = Variable(features.type(self.FloatTensor))
        d_output = self.dis(features).view(len(features),
                                           -1).detach().cpu().numpy()

        # == select n documents for each q ==
        choose_features = []
        choose_reward = []
        for index in range(len(d_output)):
            exp_rating = np.exp(d_output[index] - np.max(d_output[index]))
            prob = exp_rating / np.sum(exp_rating)
            try:
                choose_index = np.random.choice(self.n_positions,
                                                size=[self.n],
                                                p=prob)
            except ValueError:
                # fall back to uniform sampling if prob is degenerate
                choose_index = np.random.choice(self.n_positions,
                                                size=[self.n])
            reward = self.dis.reward(features[index][choose_index])
            choose_reward.append(reward.tolist())  # 5 x 10 x 1
            choose_features.append(features[index][choose_index].tolist())

        choose_features = tensor(choose_features).type(self.FloatTensor)
        choose_reward = tensor(choose_reward).type(self.FloatTensor).view(
            len(choose_features), -1)

        # ==== loss ====
        # update generator using Eq.(7)
        loss = self.gen.score(choose_features, choose_reward)

        # ==== optimize ====
        if forward:
            self.optimizer_gen.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.gen.parameters(), self.norm)
            self.optimizer_gen.step()
        return loss.item()

    """
    NDCG
    """

    def ndcg(self, l, dis, label, feature, label_i=5):
        label_index = l.index(label_i)
        if dis:
            res = ndcg_at_k(self.dis, label, feature, k=l, use_cuda=self.cuda)
            label = res[label_index]
            for i in range(len(l)):
                print('ndcg@{}:{:.4f} '.format(l[i], res[i]))
        else:
            res = ndcg_at_k(self.gen, label, feature, k=l, use_cuda=self.cuda)
            label = res[label_index]
            for i in range(len(l)):
                print('ndcg@{}:{:.4f} '.format(l[i], res[i]))
        return label

    """
    evaluate
    """

    def evaluate(self, model_path, since, opt, data_dir, data):
        print('')
        print('==== eval {} ===='.format(model_path))

        if model_path == "":
            print('You have to specify the eval model')
        else:
            self.dis.load_state_dict(torch.load(model_path))
            yahoo_output(model=self.dis,
                         data=data,
                         data_dir=data_dir,
                         tools_dir=opt.tools_dir,
                         opt=opt,
                         model_path=model_path,
                         since=since,
                         use_cuda=self.cuda)
Example #6
def train(
    *,
    folder="out",
    dataset="mnist",
    image_size=None,
    resume_folder=None,
    wasserstein=False,
    log_interval=1,
    device="cpu",
    batch_size=64,
    mask_size=None,
    nz=100,
    parent_model=None,
    freeze_parent=True,
    num_workers=1,
    nb_filters=64,
    nb_epochs=200,
    nb_extra_layers=0,
    nb_draw_layers=1
):

    os.makedirs(folder, exist_ok=True)
    lr = 0.0002
    if mask_size:
        mask_size = int(mask_size)
    dataset = load_dataset(dataset, split="train", image_size=image_size, mask_size=mask_size)
    x0, _ = dataset[0]
    nc = x0.size(0)
    w = x0.size(1)
    h = x0.size(2)
    _save_weights = partial(save_weights, folder=folder, prefix="gan")
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    act = "sigmoid" if nc == 1 else "tanh"
    if resume_folder:
        gen = torch.load("{}/gen.th".format(resume_folder))
        discr = torch.load("{}/discr.th".format(resume_folder))
    else:
        gen = Gen(
            latent_size=nz,
            nb_colors=nc,
            image_size=w,
            act=act,
            nb_gen_filters=nb_filters,
            nb_extra_layers=nb_extra_layers,
        )
        discr = Discr(
            nb_colors=nc,
            image_size=w,
            nb_discr_filters=nb_filters,
            nb_extra_layers=nb_extra_layers,
        )
    print(gen)
    print(discr)
    if wasserstein:
        gen_opt = optim.RMSprop(gen.parameters(), lr=lr)
        discr_opt = optim.RMSprop(discr.parameters(), lr=lr)
    else:
        gen_opt = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
        discr_opt = optim.Adam(discr.parameters(), lr=lr, betas=(0.5, 0.999))

    input = torch.FloatTensor(batch_size, nc, w, h)
    noise = torch.FloatTensor(batch_size, nz, 1, 1)
    label = torch.FloatTensor(batch_size)

    if wasserstein:
        real_label = 1
        fake_label = -1

        def criterion(output, label):
            return (output * label).mean()

    else:
        real_label = 1
        fake_label = 0
        criterion = nn.BCELoss()

    gen = gen.to(device)
    discr = discr.to(device)
    input, label = input.to(device), label.to(device)
    noise = noise.to(device)

    giter = 0
    diter = 0

    dreal_list = []
    dfake_list = []
    pred_error_list = []

    for epoch in range(nb_epochs):
        for i, (X, _) in enumerate(dataloader):
            if wasserstein:
                # clamp parameters to a cube
                for p in discr.parameters():
                    p.data.clamp_(-0.01, 0.01)
            # Update discriminator
            discr.zero_grad()
            batch_size = X.size(0)
            X = X.to(device)
            input.resize_as_(X).copy_(X)
            label.resize_(batch_size).fill_(real_label)
            inputv = Variable(input)
            labelv = Variable(label)
            output = discr(inputv)
            labelpred = output[:, 0:1] if wasserstein else nn.Sigmoid()(output[:, 0:1])
            errD_real = criterion(labelpred, labelv)
            errD_real.backward()
            D_x = labelpred.data.mean().item()
            dreal_list.append(D_x)
            noise.resize_(batch_size, nz, 1, 1).uniform_(-1, 1)
            noisev = Variable(noise)
            fake = gen(noisev)
            labelv = Variable(label.fill_(fake_label))
            output = discr(fake.detach())

            labelpred = output[:, 0:1] if wasserstein else nn.Sigmoid()(output[:, 0:1])
            errD_fake = criterion(labelpred, labelv)
            errD_fake.backward()
            D_G_z1 = labelpred.data.mean().item()
            dfake_list.append(D_G_z1)
            discr_opt.step()
            diter += 1

            # Update generator
            gen.zero_grad()
            fake = gen(noisev)
            labelv = Variable(label.fill_(real_label))
            output = discr(fake)
            labelpred = output[:, 0:1] if wasserstein else nn.Sigmoid()(output[:, 0:1])
            errG = criterion(labelpred, labelv)
            errG.backward()
            gen_opt.step()
            if diter % log_interval == 0:
                print(
                    "{}/{} dreal : {:.6f} dfake : {:.6f}".format(
                        epoch, nb_epochs, D_x, D_G_z1
                    )
                )
            if giter % 100 == 0:
                x = 0.5 * (X + 1) if act == "tanh" else X
                f = 0.5 * (fake.data + 1) if act == "tanh" else fake.data
                vutils.save_image(
                    x, "{}/real_samples_last.png".format(folder), normalize=True
                )
                vutils.save_image(
                    f,
                    "{}/fake_samples_iter_{:03d}.png".format(folder, giter),
                    normalize=True,
                )
                vutils.save_image(
                    f, "{}/fake_samples_last.png".format(folder), normalize=True
                )
                torch.save(gen, "{}/gen.th".format(folder))
                torch.save(discr, "{}/discr.th".format(folder))
                gen.apply(_save_weights)
            giter += 1
Example #7
if args.model_name == 'H_CNN':
    from model.H_CNN import *
    model = H_CNN(args)

elif args.model_name == 'H_LSTM_ATT':
    from model.H_LSTM_ATT import *
    model = H_LSTM_ATT(args)

elif args.model_name == 'H_LSTM_ATT_ext':
    from model.H_LSTM_ATT_ext import *
    model = H_LSTM_ATT_ext(args)

elif args.model_name == 'Gen':
    from model.Gen import *
    model = Gen(args)

elif args.model_name == 'Gen_GE':
    from model.Gen_GE import *
    model = Gen_GE(args)

elif args.model_name == 'CNN':
    from model.CNN import *
    model = CNN(args)

elif args.model_name == 'LSTM':
    from model.LSTM import *
    model = LSTM(args)

elif args.model_name == 'H_MLP_ATT':
    from model.H_MLP_ATT import *
    model = H_MLP_ATT(args)