Example No. 1
    def forward(self, zb, a):
        h1 = F.elu(self.lin1(zb))
        h2 = F.elu(self.lin2(h1))

        # finish the forward pass for the binary and categorical covariates
        bin_out_dict = dict()

        # for each binary / categorical output head
        for i in range(len(self.headnames)):
            # calculate the probability parameter under a=0 and a=1
            p_a0 = self.binheads_a0[i](h2)
            p_a1 = self.binheads_a1[i](h2)
            dist_p_a0 = torch.sigmoid(p_a0)
            dist_p_a1 = torch.sigmoid(p_a1)
            # create distribution in dict
            if self.headnames[i] == 'BINARY':
                bin_out_dict[self.headnames[i]] = bernoulli.Bernoulli((1-a)*dist_p_a0 + a*dist_p_a1)
            else:
                bin_out_dict[self.headnames[i]] = OneHotCategorical((1-a)*dist_p_a0 + a*dist_p_a1)

        # finish the forward pass for the continuous vars using the right TAR head
        mu_a0 = self.mu_a0(h2)
        mu_a1 = self.mu_a1(h2)
        sigma_a0 = self.softplus(self.sigma_a0(h2))
        sigma_a1 = self.softplus(self.sigma_a1(h2))
        # floor sigma to prevent the variance of the continuous vars from collapsing to 0
        sigma_a0 = torch.clamp(sigma_a0, min=0.1)
        sigma_a1 = torch.clamp(sigma_a1, min=0.1)
        con_out = normal.Normal((1 - a) * mu_a0 + a * mu_a1, (1 - a) * sigma_a0 + a * sigma_a1)

        return con_out, bin_out_dict
Example No. 2
def sampler(P_op, P_skip, num_samples):
    cat_op = categorical.Categorical(P_op)
    cat_sk = bernoulli.Bernoulli(P_skip)
    ops, sks = cat_op.sample([num_samples]), cat_sk.sample([num_samples])
    return CM.ChildModelBatch(ops, sks)
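The sampler above draws num_samples candidate architectures at once: P_op parameterizes a Categorical over operation choices, P_skip a Bernoulli over skip connections, and the draws are wrapped in the project-specific CM.ChildModelBatch. A minimal, self-contained sketch of the same sampling step with toy shapes and without the wrapper (sizes are assumptions for illustration only):

import torch
from torch.distributions import bernoulli, categorical

# hypothetical sizes: 4 layers, 5 candidate operations per layer
P_op = torch.softmax(torch.randn(4, 5), dim=-1)   # per-layer operation probabilities
P_skip = torch.sigmoid(torch.randn(4, 4))         # per-pair skip-connection probabilities

cat_op = categorical.Categorical(P_op)
cat_sk = bernoulli.Bernoulli(P_skip)
ops = cat_op.sample([8])   # (8, 4) operation indices for 8 sampled child models
sks = cat_sk.sample([8])   # (8, 4, 4) binary skip masks for the same 8 models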
Example No. 3
    def forward(self, x):
        x = F.elu(self.input(x))
        for i in range(self.nh):
            x = F.elu(self.hidden[i](x))
        # for binary outputs:
        out_p = torch.sigmoid(self.output(x))

        out = bernoulli.Bernoulli(out_p)
        return out
Example No. 4
    def __init__(self, noise1, noise2, method, prob=0.5):
        self.noise1 = noise1
        self.noise2 = noise2
        self.bernoulli = bernoulli.Bernoulli(probs=prob)

        if method == 'rand':
            self.sample = self.sample_rand
        elif method == 'sum':
            self.sample = self.sample_sum
        else:
            raise ValueError("method must be 'rand' or 'sum'")
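The sample_rand and sample_sum methods bound above are not part of this excerpt. A purely hypothetical sketch of what such strategies could look like (an assumption for illustration, not the source's implementation), with the Bernoulli acting as a coin flip between the two noise sources:

    # hypothetical sketch, not taken from the source
    def sample_rand(self, shape):
        # coin flip decides which of the two noise sources to draw from
        if self.bernoulli.sample().item() == 1:
            return self.noise1.sample(shape)
        return self.noise2.sample(shape)

    def sample_sum(self, shape):
        # draw from both sources and add the samples
        return self.noise1.sample(shape) + self.noise2.sample(shape)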
Example No. 5
 def forward(self, z, a):
     rep = F.elu(self.rep(z))
     # a separate mapping from the shared representation for each value of a
     rep0 = F.elu(self.lin0(rep))
     rep1 = F.elu(self.lin1(rep))
     p_a0 = torch.sigmoid(self.output_a0(rep0))
     p_a1 = torch.sigmoid(self.output_a1(rep1))
     # combine TAR net into single output
     out = bernoulli.Bernoulli((1 - a) * p_a0 + a * p_a1)
     return out
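The convex combination (1 - a) * p_a0 + a * p_a1 selects, row by row, the head that matches the binary treatment indicator a, so a single Bernoulli covers both treatment arms. A small, self-contained illustration with made-up numbers (not from the source):

import torch
from torch.distributions import bernoulli

p_a0 = torch.tensor([[0.2], [0.2], [0.2]])   # head output for a = 0
p_a1 = torch.tensor([[0.9], [0.9], [0.9]])   # head output for a = 1
a = torch.tensor([[0.], [1.], [0.]])         # per-row treatment indicator
out = bernoulli.Bernoulli((1 - a) * p_a0 + a * p_a1)
print(out.probs.squeeze())                       # tensor([0.2000, 0.9000, 0.2000])
print(out.log_prob(torch.ones(3, 1)).squeeze())  # log-likelihood of y = 1 per row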
Example No. 6
    def forward(self, zb, a):
        h1 = F.elu(self.lin1(zb))
        h2_a0 = F.elu(self.lin2_a0(h1))
        h2_a1 = F.elu(self.lin2_a1(h1))
        bern_p_a0 = torch.sigmoid(h2_a0)
        bern_p_a1 = torch.sigmoid(h2_a1)

        bern_out = bernoulli.Bernoulli((1-a)*bern_p_a0 + a*bern_p_a1)

        return bern_out
Example No. 7
    def __init__(self, data_loader, config):
        torch.manual_seed(config.seed)
        torch.cuda.manual_seed(config.seed)

        self.data_loader = data_loader
        self.model = config.model
        self.adv_loss = config.adv_loss

        self.imsize = config.imsize
        self.g_num = config.g_num
        self.z_dim = config.z_dim
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.parallel = config.parallel
        self.extra = config.extra

        self.lambda_gp = config.lambda_gp
        self.total_step = config.total_step
        self.d_iters = config.d_iters
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.lr_scheduler = config.lr_scheduler
        self.g_beta1 = config.g_beta1
        self.d_beta1 = config.d_beta1
        self.beta2 = config.beta2

        self.dataset = config.dataset
        self.log_path = config.log_path
        self.model_save_path = config.model_save_path
        self.sample_path = config.sample_path
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.version = config.version
        self.backup_freq = config.backup_freq
        self.bup_path = config.bup_path

        # Optimization
        self.optim = config.optim
        self.svrg = config.svrg
        self.avg_start = config.avg_start
        self.build_model()

        if self.svrg:
            self.mu_g = []
            self.mu_d = []
            self.g_snapshot = copy.deepcopy(self.G)
            self.d_snapshot = copy.deepcopy(self.D)
            self.svrg_freq_sampler = bernoulli.Bernoulli(torch.tensor([1 / len(self.data_loader)]))

        self.info_logger = setup_logger(self.log_path)
        self.info_logger.info(config)
        self.cont = config.cont
Example No. 8
 def forward(self, xz, a):
     xz_embed = F.elu(self.input(xz))
     h1 = F.elu(self.h1(xz_embed))
     # Separate TAR heads for a values
     h2_a0 = F.elu(self.h2_a0(h1))
     h2_a1 = F.elu(self.h2_a1(h1))
     h3_a0 = F.elu(self.h3_a0(h2_a0))
     h3_a1 = F.elu(self.h3_a1(h2_a1))
     p_y_a0 = torch.sigmoid(self.p_y_a0(h3_a0))
     p_y_a1 = torch.sigmoid(self.p_y_a1(h3_a1))
     y = bernoulli.Bernoulli((1 - a) * p_y_a0 + a * p_y_a1)
     return y
Example No. 9
    def forward(self, z_input):
        z = F.elu(self.input(z_input))
        for i in range(self.nh-1):
            z = F.elu(self.hidden[i](z))
        # for binary outputs:
        x_bin_p = torch.sigmoid(self.output_bin(z))
        x_bin = bernoulli.Bernoulli(x_bin_p)
        # for continuous outputs
        mu, sigma = self.output_con_mu(z), self.softplus(self.output_con_sigma(z))
        x_con = normal.Normal(mu, sigma)

        # guard against NaN values in the hidden representation
        if torch.isnan(z).any():
            raise ValueError('p(x|z) forward contains NaN')

        return x_bin, x_con
Example No. 10
def sample_batch_from_out_dist(y_hat, bias):
    """
    Can be removed
    """
    batch_size = y_hat.shape[0]
    split_sizes = [1] + [20] * 6
    y = torch.split(y_hat, split_sizes, dim=1)

    eos_prob = torch.sigmoid(y[0])
    mixture_weights = stable_softmax(y[1] * (1 + bias), dim=1)
    mu_1 = y[2]
    mu_2 = y[3]
    std_1 = torch.exp(y[4] - bias)
    std_2 = torch.exp(y[5] - bias)
    correlations = torch.tanh(y[6])

    bernoulli_dist = bernoulli.Bernoulli(probs=eos_prob)
    eos_sample = bernoulli_dist.sample()

    K = torch.multinomial(mixture_weights, 1).squeeze()

    mu_k = y_hat.new_zeros((y_hat.shape[0], 2))

    mu_k[:, 0] = mu_1[torch.arange(batch_size), K]
    mu_k[:, 1] = mu_2[torch.arange(batch_size), K]
    cov = y_hat.new_zeros(y_hat.shape[0], 2, 2)
    cov[:, 0, 0] = std_1[torch.arange(batch_size), K].pow(2)
    cov[:, 1, 1] = std_2[torch.arange(batch_size), K].pow(2)
    # the two off-diagonal entries are identical, so compute the term once
    off_diag = (
        correlations[torch.arange(batch_size), K]
        * std_1[torch.arange(batch_size), K]
        * std_2[torch.arange(batch_size), K]
    )
    cov[:, 0, 1] = off_diag
    cov[:, 1, 0] = off_diag

    X = torch.normal(
        mean=torch.zeros(batch_size, 2, 1), std=torch.ones(batch_size, 2, 1)
    ).to(y_hat.device)
    Z = mu_k + torch.matmul(cov, X).squeeze()

    sample = y_hat.new_zeros(batch_size, 1, 3)
    sample[:, 0, 0:1] = eos_sample
    sample[:, 0, 1:] = Z.squeeze()
    return sample
Example No. 11
    def forward(self, za_input):
        z = F.elu(self.input(za_input))
        for i in range(self.nh - 1):
            z = F.elu(self.hidden[i](z))
        # for binary outputs:
        x_bin_p = torch.sigmoid(self.output_bin(z))
        x_bin = bernoulli.Bernoulli(x_bin_p)
        # for continuous outputs
        mu, sigma = self.output_con_mu(z), self.softplus(
            self.output_con_sigma(z))
        # sigma is overruled by the simplicity assumption of Madras et al., justified by the standardization of the data
        sigma = torch.exp(2 * torch.ones(mu.shape).cuda())
        x_con = normal.Normal(mu, sigma)

        # guard against NaN values in the hidden representation
        if torch.isnan(z).any():
            print('Forward contains NaN')

        return x_bin, x_con
Example No. 12
def sample_from_out_dist(y_hat, bias):
    """
    Draw a single (eos, x, y) sample from the mixture-density output y_hat
    (Equations 16 to 25). Original author note: can the softmax be replaced?
    """
    split_sizes = [1] + [20] * 6
    y = torch.split(y_hat, split_sizes, dim=0)
    eos_prob = torch.sigmoid(y[0])
    mixture_weights = stable_softmax(y[1] * (1 + bias), dim=0)
    mu_1 = y[2]
    mu_2 = y[3]
    std_1 = torch.exp(y[4] - bias)
    std_2 = torch.exp(y[5] - bias)
    correlations = torch.tanh(y[6])

    bernoulli_dist = bernoulli.Bernoulli(probs=eos_prob)
    eos_sample = bernoulli_dist.sample()

    K = torch.multinomial(mixture_weights, 1)

    mu_k = y_hat.new_zeros(2)

    mu_k[0] = mu_1[K]
    mu_k[1] = mu_2[K]
    cov = y_hat.new_zeros(2, 2)
    cov[0, 0] = std_1[K].pow(2)
    cov[1, 1] = std_2[K].pow(2)
    # the two off-diagonal entries are identical, so compute the term once
    off_diag = correlations[K] * std_1[K] * std_2[K]
    cov[0, 1] = off_diag
    cov[1, 0] = off_diag

    x = torch.normal(mean=torch.Tensor([0.0, 0.0]),
                     std=torch.Tensor([1.0, 1.0])).to(y_hat.device)
    Z = mu_k + torch.mv(cov, x)

    sample = y_hat.new_zeros(1, 1, 3)
    sample[0, 0, 0] = eos_sample.item()
    sample[0, 0, 1:] = Z
    return sample
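A hedged usage sketch for the function above: the 121-dimensional layout of y_hat follows from the split sizes 1 + 6 * 20, and it assumes sample_from_out_dist and its imports are in scope. The stand-in for stable_softmax is an assumption made only so the sketch runs:

import torch

def stable_softmax(x, dim):
    # stand-in: a numerically stable softmax, assumed to match the project's helper
    return torch.softmax(x - x.max(dim=dim, keepdim=True).values, dim=dim)

y_hat = torch.randn(121)                         # 1 EOS logit + 6 blocks of 20 mixture parameters
sample = sample_from_out_dist(y_hat, bias=0.0)   # tensor of shape (1, 1, 3): (eos, x, y)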
Example No. 13
 def sampling(self, step):
     sampler = bernoulli.Bernoulli(self.eps).sample().item()
     if sampler == 0:
         return torch.argmax(self.est_values).item()
     else:
         return torch.randint(high=self.num_actions, size=(1, )).item()
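This is epsilon-greedy action selection with the exploration coin modeled as a Bernoulli: with probability 1 - eps the greedy action is taken, otherwise a uniformly random one. A self-contained sketch with toy values (the names eps, est_values and num_actions mirror the attributes used above):

import torch
from torch.distributions import bernoulli

eps = torch.tensor(0.1)                      # exploration probability
est_values = torch.tensor([0.3, 1.2, 0.7])   # current action-value estimates
num_actions = est_values.numel()

if bernoulli.Bernoulli(eps).sample().item() == 0:
    action = torch.argmax(est_values).item()                     # exploit
else:
    action = torch.randint(high=num_actions, size=(1,)).item()   # explore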
Example No. 14
    def build_model(self):
        # Models                    ###################################################################
        self.G = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        # Todo: do not allocate unnecessary GPU mem for G_extra and D_extra if self.extra == False
        self.G_extra = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        self.D_extra = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        if self.avg_start >= 0:
            self.avg_g = copy.deepcopy(self.G)
            self.avg_d = copy.deepcopy(self.D)
            self._requires_grad(self.avg_g, False)
            self._requires_grad(self.avg_d, False)
            self.avg_g.eval()
            self.avg_d.eval()
            self.avg_step = 1
            self.avg_freq_restart_sampler = bernoulli.Bernoulli(.1)

        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)
            self.G_extra = nn.DataParallel(self.G_extra)
            self.D_extra = nn.DataParallel(self.D_extra)
            if self.avg_start >= 0:
                self.avg_g = nn.DataParallel(self.avg_g)
                self.avg_d = nn.DataParallel(self.avg_d)
        self.G_extra.train()
        self.D_extra.train()

        self.G_avg = copy.deepcopy(self.G)
        self.G_ema = copy.deepcopy(self.G)
        self._requires_grad(self.G_avg, False)
        self._requires_grad(self.G_ema, False)

        # Logs, Loss & optimizers   ###################################################################
        grad_var_logger_g = setup_logger(self.log_path, 'log_grad_var_g.log')
        grad_var_logger_d = setup_logger(self.log_path, 'log_grad_var_d.log')
        grad_mean_logger_g = setup_logger(self.log_path, 'log_grad_mean_g.log')
        grad_mean_logger_d = setup_logger(self.log_path, 'log_grad_mean_d.log')

        if self.optim == 'sgd':
            self.g_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.G.parameters()),
                                               self.g_lr,
                                               logger_mean=grad_mean_logger_g,
                                               logger_var=grad_var_logger_g)
            self.d_optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.D.parameters()),
                                               self.d_lr,
                                               logger_mean=grad_mean_logger_d,
                                               logger_var=grad_var_logger_d)
            self.g_optimizer_extra = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                            self.G_extra.parameters()),
                                                     self.g_lr)
            self.d_optimizer_extra = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                                            self.D_extra.parameters()),
                                                     self.d_lr)
        elif self.optim == 'adam':
            self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()),
                                                self.g_lr, [self.g_beta1, self.beta2],
                                                logger_mean=grad_mean_logger_g,
                                                logger_var=grad_var_logger_g)
            self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()),
                                                self.d_lr, [self.d_beta1, self.beta2],
                                                logger_mean=grad_mean_logger_d,
                                                logger_var=grad_var_logger_d)
            self.g_optimizer_extra = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                             self.G_extra.parameters()),
                                                      self.g_lr, [self.g_beta1, self.beta2])
            self.d_optimizer_extra = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                             self.D_extra.parameters()),
                                                      self.d_lr, [self.d_beta1, self.beta2])
        elif self.optim == 'svrgadam':
            self.g_optimizer = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad, self.G.parameters()),
                                                    self.g_lr, [self.g_beta1, self.beta2],
                                                    logger_mean=grad_mean_logger_g,
                                                    logger_var=grad_var_logger_g)
            self.d_optimizer = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad, self.D.parameters()),
                                                    self.d_lr, [self.d_beta1, self.beta2],
                                                    logger_mean=grad_mean_logger_d,
                                                    logger_var=grad_var_logger_d)
            self.g_optimizer_extra = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad,
                                                          self.G_extra.parameters()),
                                                          self.g_lr, [self.g_beta1, self.beta2])
            self.d_optimizer_extra = torch.optim.SvrgAdam(filter(lambda p: p.requires_grad,
                                                          self.D_extra.parameters()),
                                                          self.d_lr, [self.d_beta1, self.beta2])
        else:
            raise NotImplementedError('Supported optimizers: SGD, Adam, SvrgAdam')

        if self.lr_scheduler > 0:  # Exponentially decaying learning rate
            self.scheduler_g = torch.optim.lr_scheduler.ExponentialLR(self.g_optimizer,
                                                                      gamma=self.lr_scheduler)
            self.scheduler_d = torch.optim.lr_scheduler.ExponentialLR(self.d_optimizer,
                                                                      gamma=self.lr_scheduler)
            self.scheduler_g_extra = torch.optim.lr_scheduler.ExponentialLR(self.g_optimizer_extra,
                                                                            gamma=self.lr_scheduler)
            self.scheduler_d_extra = torch.optim.lr_scheduler.ExponentialLR(self.d_optimizer_extra,
                                                                            gamma=self.lr_scheduler)

        print(self.G)
        print(self.D)
Example No. 15
    # - As SVRG requires a GD step, an additional data loader is instantiated
    #   which uses a larger batch size (opt.large_batch_size). The same holds for
    #   the noise data.
    # - To ensure that the noise vanishes in expectation (which reduces SVRG to
    #   SGD [*]), svrg_noise_sampler & noise_sampler use the same noise tensor.
    #   This noise tensor is re-sampled from p_z by noise_sampler after its full
    #   traverse.
    #
    # [*] Accelerating stochastic gradient descent using predictive variance
    #     reduction, Johnson & Zhang, Advances in Neural Information Processing
    #     Systems, 2013.

    dataset = load_dataset(opt.dataset, opt.dataroot, opt.verbose)
    data_sampler = dataset_generator(dataset,
                                     opt.batch_size,
                                     num_workers=opt.n_workers,
                                     drop_last=True)
    _n_batches = len(dataset) // opt.batch_size
    svrg_freq_sampler = bernoulli.Bernoulli(torch.tensor([1 / _n_batches]))
    noise_dataset = torch.FloatTensor(2 * len(dataset),
                                      _NOISE_DIM).normal_(0, 1)
    noise_sampler = noise_generator(noise_dataset,
                                    opt.batch_size,
                                    drop_last=True,
                                    resample=True)
    logger.info(
        "{} loaded. Found {} samples, resulting in {} mini-batches.".format(
            opt.dataset, len(dataset), _n_batches))
    avg_rst_freq_sampler = bernoulli.Bernoulli(opt.rst_freq)
    avg_step = 1
    avg_g = copy.deepcopy(generator['net'])
    avg_d = copy.deepcopy(discriminator['net'])
    _requires_grad(avg_g, False)
    _requires_grad(avg_d, False)
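Both Bernoulli objects above act as frequency gates: svrg_freq_sampler fires with probability 1 / _n_batches, so the SVRG snapshot is refreshed roughly once per epoch in expectation, while avg_rst_freq_sampler restarts the averaged models with probability opt.rst_freq. A standalone toy illustration of the expected firing rate (numbers are made up):

import torch
from torch.distributions import bernoulli

_n_batches = 50
freq_sampler = bernoulli.Bernoulli(torch.tensor([1.0 / _n_batches]))
# over one simulated epoch of 50 mini-batches the gate fires about once on average
fires = int(sum(freq_sampler.sample().item() for _ in range(_n_batches)))
print(fires)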
Example No. 16
def recon_loss_bernoulli(x, logits, *args, **kwargs):
    # *args / **kwargs absorb extra arguments so the loss fits the existing trainer setup
    # NOTE: x must contain only 0s and 1s, otherwise the resulting probabilities are meaningless
    rv = bernoulli.Bernoulli(logits=logits.reshape(logits.shape[0], -1))
    return -rv.log_prob(x.reshape(logits.shape[0], -1)).sum(1).mean()
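A hedged usage sketch for the loss above (shapes and data are made up; it assumes recon_loss_bernoulli and its bernoulli import are in scope). Because the distribution is built from raw logits, no sigmoid should be applied to the decoder output beforehand:

import torch

x = torch.randint(0, 2, (8, 1, 28, 28)).float()   # strictly binary targets
logits = torch.randn(8, 1, 28, 28)                # raw (pre-sigmoid) decoder outputs
loss = recon_loss_bernoulli(x, logits)            # NLL summed over pixels, averaged over the batch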
Example No. 17
    def __init__(self, input_size, hidden_size, output_size, model_prob, mlp_layers=[], config_size=None, mlp_config=[], activation=nn.Sequential()):
        '''
        create a variational LSTM model.
        '''

        super(LSTM, self).__init__()

        self.hidden_size = hidden_size
        self.output_size = output_size
        self.z_mask = []
        self.ft = nn.ModuleList()
        self.it = nn.ModuleList()
        self.Ctild_t = nn.ModuleList()
        self.ot = nn.ModuleList()

        for ii in range(len(hidden_size)):
            # the first layer sees the input x_t; deeper layers see the previous
            # layer's hidden state instead
            in_dim = input_size if ii == 0 else hidden_size[ii - 1]
            self.ft.append(nn.Linear(in_dim + hidden_size[ii], hidden_size[ii]))
            self.it.append(nn.Linear(in_dim + hidden_size[ii], hidden_size[ii]))
            self.Ctild_t.append(nn.Linear(in_dim + hidden_size[ii], hidden_size[ii]))
            self.ot.append(nn.Linear(in_dim + hidden_size[ii], hidden_size[ii]))

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

        # copy before extending so the (mutable) default list argument is not modified in place
        mlp_layers = [hidden_size[-1]] + list(mlp_layers) + [output_size]
        self.mlp_layers = mlp_layers
        self.activation = activation
        self.model_bernoulli = Bernoulli.Bernoulli(probs=model_prob)
        # self.dropout = nn.Dropout(1 - model_prob)
        # dropout is removed from the MLP layers to make training easier
        self.dropout = nn.Dropout(0)
        mlp_list = []

        for ii in range(len(self.mlp_layers)-1):
            mlp_list.append(self.dropout)
            mlp_list.append(nn.Linear(self.mlp_layers[ii], self.mlp_layers[ii+1]))
            if ii+2 < len(self.mlp_layers):
                mlp_list.append(self.activation)
        self.mlp = nn.Sequential(*mlp_list)

        self.mlp_config = nn.Sequential()

        if config_size is not None:

            # copy before extending so the (mutable) default list argument is not modified in place
            mlp_config = [config_size] + list(mlp_config)
            count = 0
            for ii in range(len(hidden_size)):
                count += 1
                mlp_config_ii = []

                for jj in range(len(mlp_config)-1):
                    mlp_config_ii.append(nn.Linear(mlp_config[jj], mlp_config[jj+1]))
                    mlp_config_ii.append(self.activation)

                mlp_config_ii.append(nn.Linear(mlp_config[-1], hidden_size[ii]))

                self.mlp_config.add_module('mlp {}'.format(count), nn.Sequential(*mlp_config_ii))