Example #1
class GibbsNet(BaseModel):
    def __init__(self, opt):
        super(GibbsNet, self).__init__(opt)

        self.ali_model = ALI(opt)
        self.opt = opt
        self.sampling_count = opt.sampling_count
        self.z = None
        self.fake_x = None

        self.input = self.Tensor(opt.batch_size, opt.input_channel, opt.height,
                                 opt.width)
        self.x = None
Example #2
    def train_on_batch(self, x_data, y_batch=None, compute_grad_norms=False):
        batchsize = len(x_data)

        # perform label smoothing if applicable
        y_pos, y_neg = ALI.get_labels(batchsize, self.label_smoothing)
        y = np.stack((y_neg, y_pos), axis=1)

        # sample latent variables from the prior distribution
        z_latent_dis = np.random.normal(size=(batchsize, self.z_dims))

        # train both networks
        d_loss = self.dis_trainer.train_on_batch([x_data, z_latent_dis], y)
        g_loss = self.gen_trainer.train_on_batch([x_data, z_latent_dis], y)

        # retrain the generator if its loss is too high relative to the discriminator's
        max_loss, max_g_2_d_loss_ratio = 5., 5.
        retrained_times, max_retrains = 0, 3
        while retrained_times < max_retrains and \
                (g_loss > max_loss or g_loss > self.last_d_loss * max_g_2_d_loss_ratio):
            g_loss = self.gen_trainer.train_on_batch([x_data, z_latent_dis], y)
            retrained_times += 1
        if retrained_times > 0:
            print('Retrained Generator {} time(s)'.format(retrained_times))

        self.last_d_loss = d_loss
        losses = {
            'g_loss': g_loss,
            'd_loss': d_loss
        }

        return losses
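
A minimal way to drive this method from an epoch loop is sketched below; the ALI constructor arguments, the dataset array, and the initial last_d_loss seed are illustrative assumptions, not part of the original API.

import numpy as np

model = ALI(z_dims=64, label_smoothing=0.1)  # hypothetical constructor arguments
model.last_d_loss = 1.0  # the retrain check above reads this before it is first set

for epoch in range(25):
    np.random.shuffle(dataset)  # `dataset`: an assumed (N, H, W, C) numpy array
    for start in range(0, len(dataset), 100):
        losses = model.train_on_batch(dataset[start:start + 100])
    print('g_loss={g_loss:.4f} d_loss={d_loss:.4f}'.format(**losses))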
Example #3
def create_model(opt):
    if opt.model == 'ALI':
        return ALI(opt)
    elif opt.model == 'GibbsNet':
        return GibbsNet(opt)
    else:
        raise Exception("This implementation only supports ALI and GibbsNet.")
Example #4
def create_model(opt):
    if opt.model == 'ALI':
        return ALI(opt)
    elif opt.model == 'GibbsNet':
        return GibbsNet(opt)
    elif opt.model == 'RGibbsNet':
        return RGibbsNet(opt)
    elif opt.model == 'HALI':
        return Hali(opt)
    else:
        raise Exception("This implementation only supports ALI, GibbsNet, RGibbsNet, and HALI.")
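
create_model itself only inspects opt.model, but the constructors it dispatches to read further fields; an argparse-style Namespace is one way to assemble such an object (the extra attribute names mirror the GibbsNet snippet below, and the values are illustrative assumptions).

from argparse import Namespace

opt = Namespace(model='GibbsNet', batch_size=64, sampling_count=20,
                input_channel=3, height=32, width=32)
model = create_model(opt)  # dispatches on opt.model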
Example #5
        self.joint = nn.Sequential(
            nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.1, True),
            nn.Dropout2d(p=0.2),

            nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.1, True),
            nn.Dropout2d(p=0.2),

            nn.Conv2d(1024, 1, kernel_size=1, stride=1, padding=0),
            nn.Dropout2d(p=0.2),
        )

    def forward(self, z, x):
        x = (x * 0.5) + 0.5  # rescale x from [-1, 1] back to [0, 1]
        concat = torch.cat((self.part_x(x), self.part_z(z)), dim=1)
        return self.joint(concat)
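
# part_x and part_z are assumed to produce feature maps with matching batch and
# spatial dimensions, so the channel-wise concatenation above is valid. A
# standalone check of the same cat pattern (shapes are illustrative assumptions):
import torch
fx = torch.randn(8, 1024, 1, 1)   # stand-in for part_x(x)
fz = torch.randn(8, 1024, 1, 1)   # stand-in for part_z(z)
assert torch.cat((fx, fz), dim=1).shape == (8, 2048, 1, 1)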

params = {
    'device': 'cuda',
    'lr_g': 1e-4,
    'lr_d': 1e-5,
    'betas': (.5, 1e-3),
    'size_z': 64,
    'init_normal': True,
    'gx': GeneratorX(),
    'gz': GeneratorZ(),
    'd': Discriminator(),
}

model = ALI(**params)
Example #6
class GibbsNet(BaseModel):
    def __init__(self, opt):
        super(GibbsNet, self).__init__(opt)

        self.ali_model = ALI(opt)
        self.opt = opt
        self.sampling_count = opt.sampling_count
        self.z = None
        self.fake_x = None

        self.input = self.Tensor(
            opt.batch_size,
            opt.input_channel,
            opt.height,
            opt.width
        )
        self.x = None

    def forward(self, volatile=False):  # `volatile` is vestigial (pre-0.4 PyTorch) and unused
        self.sampling()

        # clamped chain : ALI model
        self.ali_model.set_z(var=self.z)
        self.ali_model.set_input(self.x.data, is_z_given=True)

        self.ali_model.forward()

    def test(self):
        self.forward(volatile=True)

    def set_input(self, data):
        # copy the batch into a fresh tensor so earlier references are not resized in place
        self.input = self.input.clone()
        self.input.resize_(data.size()).copy_(data)

    def sampling(self, volatile=True):  # `volatile` is vestigial (pre-0.4 PyTorch) and unused
        batch_size = self.opt.batch_size
        self.x = Variable(self.input)

        # unclamped chain: start from a prior sample z ~ N(0, I)
        if self.gpu_ids:
            self.z = Variable(torch.randn((batch_size, self.ali_model.encoder.k)).cuda())
        else:
            self.z = Variable(torch.randn((batch_size, self.ali_model.encoder.k)))

        # alternate decode/encode steps to run the Gibbs chain
        for i in range(self.sampling_count):
            self.fake_x = self.ali_model.decoder(self.z)
            self.z = self.ali_model.encoder(self.fake_x)

    def save(self, epoch):
        self.ali_model.save(epoch)

    def load(self, epoch):
        self.ali_model.load(epoch)

    def optimize_parameters(self, inferring_count=0):
        self.sampling()
        # clamped chain : ALI model
        self.ali_model.set_z(var=self.z)
        self.ali_model.set_input(self.x.data, is_z_given=True)

        self.ali_model.optimize_parameters()

    def get_losses(self):
        return self.ali_model.get_losses()

    def get_visuals(self, sample_single_image=True):
        return self.ali_model.get_visuals(sample_single_image=sample_single_image)

    def get_infervisuals(self, infernum, sample_single_image=True):
        return self.ali_model.get_infervisuals(infernum, sample_single_image=sample_single_image)

    def get_lv(self):
        return self.ali_model.get_lv()

    def remove(self, epoch):
        self.ali_model.remove(epoch)

    def reconstruction(self):
        return self.ali_model.reconstruction()
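
One full training pass with this class might look like the sketch below; the opt object (which must also satisfy BaseModel and ALI), the epochs field, and the loader of image batches are assumptions for illustration.

net = GibbsNet(opt)               # `opt` assumed to carry batch_size, sampling_count, etc.
for epoch in range(opt.epochs):   # `opt.epochs` is a hypothetical field
    for batch in loader:          # `loader` yields image tensors (assumption)
        net.set_input(batch)
        net.optimize_parameters() # run the unclamped chain, then train the clamped ALI
    print(net.get_losses())
    net.save(epoch)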
Example #7
        self.joint = nn.Sequential(
            nn.Conv2d(nz * 4, nz * 4, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.1, True),
            nn.Dropout2d(p=0.2),
            nn.Conv2d(nz * 4, 1, kernel_size=1, stride=1, padding=0),
            nn.Dropout2d(p=0.2),
        )

    def forward(self, z, x):
        concat = torch.cat((self.part_x(x), self.part_z(z)), dim=1)
        return self.joint(concat).unsqueeze(2).unsqueeze(3)


params = {
    'device': 'cuda',
    'lr_g': 1e-4,
    'lr_d': 1e-4,
    'betas': (.5, 0.999),
    'size_z': 32,
    'init_normal': False,
    'gx': GeneratorX(),
    'gz': GeneratorZ(),
    'd': Discriminator(),
}

model = ALI(**params)
model.scheduler_g = torch.optim.lr_scheduler.StepLR(model.optim_g,
                                                    step_size=10,
                                                    gamma=0.5)
model.scheduler_d = torch.optim.lr_scheduler.StepLR(model.optim_d,
                                                    step_size=10,
                                                    gamma=0.5)
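
The StepLR schedulers halve both learning rates every 10 steps; they only take effect when stepped, typically once per epoch (the per-epoch training call below is an assumption about the wrapper's API).

for epoch in range(50):
    train_one_epoch(model)      # hypothetical stand-in for the wrapper's training loop
    model.scheduler_g.step()    # lr_g: 1e-4 -> 5e-5 -> 2.5e-5 ... every 10 epochs
    model.scheduler_d.step()    # lr_d follows the same schedule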