Example #1
def logprob_undercomponent(x, component):
    B = x.shape[0]
    mean = (component.float()*10.).view(B,1)
    std = (torch.ones([B]) *5.).view(B,1)
    m = Normal(mean.cuda(), std.cuda())
    logpx_given_z = m.log_prob(x)
    return logpx_given_z
Example #2
	def forward(self, inputs, c, z=None):
		inputs = inputs.view(-1, 1, 28, 28)  # reshape flat inputs to (batch, 1, 28, 28)
		mu = self.localization_mu(inputs)
		sigma = self.localization_sigma(inputs)
		dist = Normal(mu, sigma)
		if z is None: 
			z = dist.rsample()
		score = dist.log_prob(z).sum(dim=1).sum(dim=1).sum(dim=1)
		return z, score
Example #3
def sample_true2():
    cat = Categorical(probs=torch.tensor(true_mixture_weights))
    cluster = cat.sample()
    norm = Normal(torch.tensor([cluster*10.]).float(), torch.tensor([5.0]).float())
    samp = norm.sample()
    return samp, cluster
Example #4
def sample_gmm(batch_size, mixture_weights):
    cat = Categorical(probs=mixture_weights)
    cluster = cat.sample([batch_size]) # [B]
    mean = (cluster*10.).float().cuda()
    std = torch.ones([batch_size]).cuda() *5.
    norm = Normal(mean, std)
    samp = norm.sample()
    samp = samp.view(batch_size, 1)
    return samp
Example #5
 def forward(self, inputs, c=None):    
     inputs_permuted = inputs.transpose(0,1) # |D| * batch * ... 
     embeddings = [self.enc(x) for x in inputs_permuted]
     mean_embedding = sum(embeddings)/len(embeddings)
     mu_c = self.mu_c(mean_embedding)
     sigma_c = self.sigma_c(mean_embedding)
     dist = Normal(mu_c, sigma_c)
     if c is None: c = dist.rsample()
     return c, dist.log_prob(c).sum(dim=1) # Return value, score
Example #6
	def forward(self, inputs, c=None):
		# transform the input
		xs = [self.stn(inputs[:,i,:,:,:]) for i in range(inputs.size(1))]

		embs = [self.conv_post_stn(x) for x in xs]
		emb = sum(embs)/len(embs)
		mu = self.conv_mu(emb)
		sigma = self.conv_sigma(emb)
		dist = Normal(mu, sigma)
		if c is None: c = dist.rsample()
		return c, dist.log_prob(c).sum(dim=1).sum(dim=1).sum(dim=1)
Example #7
def logprob_givenmixtureeweights(x, needsoftmax_mixtureweight):

    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    probs_sum = 0.
    for c in range(n_components):
        m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
        component_i = torch.exp(m.log_prob(x)) * mixture_weights[c]
        probs_sum += component_i
    logprob = torch.log(probs_sum)
    return logprob
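The function above exponentiates every component log-density and only takes the log at the end, which can underflow when x is far from all components. A hedged alternative sketch (hypothetical helper name, not from the original source) that computes the same quantity with torch.logsumexp:

import torch
from torch.distributions import Normal

def logprob_givenmixtureweights_lse(x, needsoftmax_mixtureweight, n_components):
    # log p(x) = logsumexp_c [ log pi_c + log N(x; 10*c, 5) ]
    log_mixture_weights = torch.log_softmax(needsoftmax_mixtureweight, dim=0)  # [C]
    means = torch.arange(n_components).float() * 10.                           # [C]
    stds = torch.ones(n_components) * 5.                                       # [C]
    component_logprobs = Normal(means, stds).log_prob(x)   # x of shape [B, 1] broadcasts to [B, C]
    return torch.logsumexp(component_logprobs + log_mixture_weights, dim=-1)   # [B]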
Example #8
def plot_dist2(n_components, mixture_weights, true_mixture_weights, exp_dir, name=''):

    rows = 1
    cols = 1
    fig = plt.figure(figsize=(10+cols, 4+rows), facecolor='white')

    col = 0
    row = 0
    ax = plt.subplot2grid((rows, cols), (row, col), frameon=False, colspan=1, rowspan=1)

    xs = np.linspace(-10, n_components*10 + 5, 300)
    sum_ = np.zeros(len(xs))

    # current (learned) mixture
    for c in range(n_components):
        m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
        ys = []
        for x in xs:
            component_i = (torch.exp(m.log_prob(x)) * mixture_weights[c]).detach().cpu().numpy()
            ys.append(component_i)
        ys = np.reshape(np.array(ys), [-1])
        sum_ += ys
        ax.plot(xs, ys, label='', c='orange')
    ax.plot(xs, sum_, label='current', c='r')

    # true mixture
    sum_ = np.zeros(len(xs))
    for c in range(n_components):
        m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
        ys = []
        for x in xs:
            component_i = (torch.exp(m.log_prob(x)) * true_mixture_weights[c]).detach().cpu().numpy()
            ys.append(component_i)
        ys = np.reshape(np.array(ys), [-1])
        sum_ += ys
        ax.plot(xs, ys, label='', c='c')
    ax.plot(xs, sum_, label='true', c='b')

    ax.legend()

    ax.set_title(str(mixture_weights) + '\n' + str(true_mixture_weights), size=8, family='serif')

    plt_path = exp_dir + 'gmm_plot_dist' + name + '.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
Example #9
File: actor.py  Project: sra4077/Horizon
    def get_log_prob(self, state, squashed_action):
        """
        Action is expected to be squashed with tanh
        """
        with torch.no_grad():
            loc, scale_log = self._get_loc_and_scale_log(state)
            # This is not getting exported; we can use it
            n = Normal(loc, scale_log.exp())
            raw_action = self._atanh(squashed_action)
            log_prob = torch.sum(
                n.log_prob(raw_action) - self._squash_correction(squashed_action), dim=1
            ).reshape(-1, 1)

        return log_prob
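The helpers self._atanh and self._squash_correction are not shown in this snippet. A hedged sketch of what such helpers typically compute for a tanh-squashed Gaussian policy (an assumption for illustration, not Horizon's actual implementation):

import torch

EPS = 1e-6

def _atanh_sketch(y):
    # inverse of the tanh squashing: atanh(y) = 0.5 * log((1 + y) / (1 - y))
    y = y.clamp(-1 + EPS, 1 - EPS)
    return 0.5 * (torch.log1p(y) - torch.log1p(-y))

def _squash_correction_sketch(squashed_action):
    # change-of-variables term log|d tanh(u)/du| = log(1 - tanh(u)^2) = log(1 - a^2),
    # so log pi(a) = log N(u) - log(1 - a^2), matching the subtraction in get_log_prob above
    return torch.log(1 - squashed_action.pow(2) + EPS)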
Example #10
def sample_true(batch_size):
    cat = Categorical(probs=torch.tensor(true_mixture_weights))
    cluster = cat.sample([batch_size])  # [B]
    mean = (cluster*10.).float()
    std = torch.ones([batch_size]) * 5.
    norm = Normal(mean, std)
    samp = norm.sample()
    samp = samp.view(batch_size, 1)
    return samp
Example #11
def plot_dist(x=None):

    if x is None:
        x1 = sample_true(1).cuda()
    else:
        x1 = x[0].cpu().numpy()

    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

    rows = 1
    cols = 1
    fig = plt.figure(figsize=(10+cols, 4+rows), facecolor='white')

    col = 0
    row = 0
    ax = plt.subplot2grid((rows, cols), (row, col), frameon=False, colspan=1, rowspan=1)

    xs = np.linspace(-9, 205, 300)
    sum_ = np.zeros(len(xs))

    C = 20
    for c in range(C):
        m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
        ys = []
        for x in xs:
            component_i = (torch.exp(m.log_prob(x)) * mixture_weights[c]).detach().cpu().numpy()
            ys.append(component_i)
        ys = np.reshape(np.array(ys), [-1])
        sum_ += ys
        ax.plot(xs, ys, label='')

    ax.plot(xs, sum_, label='')

    # mark the sampled point x1
    ax.plot([x1, x1+.001], [0., .002])

    plt_path = exp_dir + 'gmm_plot_dist.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
Example #12
def true_posterior(x, needsoftmax_mixtureweight):

    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    probs_ = []
    for c in range(n_components):
        m = Normal(torch.tensor([c*10.]).float().cuda(), torch.tensor([5.0]).float().cuda())
        component_i = torch.exp(m.log_prob(x)) * mixture_weights[c]
        probs_.append(component_i[0])
    probs_ = torch.stack(probs_)
    probs_ = probs_ / torch.sum(probs_)
    return probs_
Example #13
def synthesize(model):
    global global_step
    model.eval()
    for batch_idx, (x, c) in enumerate(synth_loader):
        if batch_idx == 0:
            x, c = x.to(device), c.to(device)

            q_0 = Normal(x.new_zeros(x.size()), x.new_ones(x.size()))
            z = q_0.sample()

            start_time = time.time()
            with torch.no_grad():
                y_gen = model.module.reverse(z, c).squeeze()
            wav = y_gen.to(torch.device("cpu")).data.numpy()
            wav_name = '{}/{}/generate_{}_{}.wav'.format(args.sample_path, args.model_name, global_step, batch_idx)
            print('{} seconds'.format(time.time() - start_time))
            librosa.output.write_wav(wav_name, wav, sr=22050)
            print('{} Saved!'.format(wav_name))
            del x, c, z, q_0, y_gen, wav
Example #14
def logprob_undercomponent(x, component, needsoftmax_mixtureweight, cuda=False):
    B = x.shape[0]
    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    mean = (component.float()*10.).view(B, 1)
    std = (torch.ones([B]) * 5.).view(B, 1)
    if not cuda:
        m = Normal(mean, std)
    else:
        m = Normal(mean.cuda(), std.cuda())
    logpx_given_z = m.log_prob(x)                              # [B, 1]
    logpz = torch.log(mixture_weights[component]).view(B, 1)   # [B, 1]
    logprob = logpx_given_z + logpz
    return logprob
Example #15
def sample_posterior(mus, sigmas):
    """
    For every training batch, we need to sample the weights from the Variational Posterior.
    This function will be called for any element of the Model that uses Bayesian weights.

    The first time we want to sample from the posterior during training, the variable
    will not exist and it will be sampled from the Prior. The next times it will just be obtained.

    In this case the variables are the parameters of the posterior, the mus and stds.
    """
    # Reparametrization trick: eps ~ N(0, 1), sample = mu + sigma * eps
    eps = Normal(0.0, 1.0).sample(mus.size()).to(dtype=dtype, device=device)
    posterior_samples = eps.mul(sigmas).add(mus)
    return posterior_samples
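A minimal usage sketch (the module-level dtype and device that sample_posterior reads are assumed to be set as below) showing that the reparametrized sample stays differentiable with respect to the variational means:

import torch

dtype, device = torch.float32, torch.device("cpu")

mus = torch.zeros(3, 4, requires_grad=True)
sigmas = torch.ones(3, 4) * 0.5

w = sample_posterior(mus, sigmas)   # eps * sigma + mu
w.sum().backward()                  # gradients flow back to mus through the reparametrization
print(w.shape, mus.grad.shape)      # torch.Size([3, 4]) torch.Size([3, 4])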
Example #16
def logprob_undercomponent(x, component, needsoftmax_mixtureweight, cuda=False):
    c = component
    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    if not cuda:
        m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
    else:
        m = Normal(torch.tensor([c*10.]).float().cuda(), torch.tensor([5.0]).float().cuda())
    logprob = m.log_prob(x) + torch.log(mixture_weights[c])
    return logprob
Example #17
def gmm_loss(batch, mus, sigmas, logpi, reduce=True): # pylint: disable=too-many-arguments
    """ Computes the gmm loss.

    Compute minus the log probability of batch under the GMM model described
    by mus, sigmas, pi. Precisely, with bs1, bs2, ... the sizes of the batch
    dimensions (several batch dimension are useful when you have both a batch
    axis and a time step axis), gs the number of mixtures and fs the number of
    features.

    :args batch: (bs1, bs2, *, fs) torch tensor
    :args mus: (bs1, bs2, *, gs, fs) torch tensor
    :args sigmas: (bs1, bs2, *, gs, fs) torch tensor
    :args logpi: (bs1, bs2, *, gs) torch tensor
    :args reduce: if not reduce, the mean in the following formula is ommited

    :returns:
    loss(batch) = - mean_{i1=0..bs1, i2=0..bs2, ...} log(
        sum_{k=1..gs} pi[i1, i2, ..., k] * N(
            batch[i1, i2, ..., :] | mus[i1, i2, ..., k, :], sigmas[i1, i2, ..., k, :]))

    NOTE: The loss is not reduced along the feature dimension (i.e. it should scale ~linearily
    with fs).
    """
    batch = batch.unsqueeze(-2)
    normal_dist = Normal(mus, sigmas)
    g_log_probs = normal_dist.log_prob(batch)
    g_log_probs = logpi + torch.sum(g_log_probs, dim=-1)
    max_log_probs = torch.max(g_log_probs, dim=-1, keepdim=True)[0]
    g_log_probs = g_log_probs - max_log_probs

    g_probs = torch.exp(g_log_probs)
    probs = torch.sum(g_probs, dim=-1)

    log_prob = max_log_probs.squeeze() + torch.log(probs)
    if reduce:
        return - torch.mean(log_prob)
    return - log_prob
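A minimal usage sketch with synthetic shapes that follow the docstring (bs1=8, bs2=16, gs=5 mixtures, fs=3 features); the tensors here are random placeholders:

import torch

batch  = torch.randn(8, 16, 3)
mus    = torch.randn(8, 16, 5, 3)
sigmas = torch.rand(8, 16, 5, 3) + 0.1                        # strictly positive scales
logpi  = torch.log_softmax(torch.randn(8, 16, 5), dim=-1)     # log mixture weights

loss = gmm_loss(batch, mus, sigmas, logpi)                    # scalar
per_item = gmm_loss(batch, mus, sigmas, logpi, reduce=False)  # shape (8, 16)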
Example #18
 def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
     return Normal(loc=0, scale=1).cdf(samples.squeeze(-1))
Example #19
def KL_standard_normal(mu, sigma):
    p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
    q = Normal(mu, sigma)
    return torch.sum(torch.distributions.kl_divergence(q, p))
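A small sanity-check sketch (not part of the original source): for Normal distributions this KL has the closed form 0.5 * sum(sigma^2 + mu^2 - 1 - 2*log(sigma)), so KL_standard_normal can be checked against it directly:

import torch

mu, sigma = torch.randn(10), torch.rand(10) + 0.1
closed_form = 0.5 * torch.sum(sigma**2 + mu**2 - 1. - 2. * torch.log(sigma))
assert torch.allclose(KL_standard_normal(mu, sigma), closed_form, atol=1e-5)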
Example #20
    def forward(self, X, u):
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # encode
        self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(
            X[0])
        self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(
            X[1])

        # reparametrize
        self.Q_r0 = Normal(self.r0_m, self.r0_v)
        self.P_normal = Normal(torch.zeros_like(self.r0_m),
                               torch.ones_like(self.r0_v))
        self.r0 = self.Q_r0.rsample()

        self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi0 = self.Q_phi0.rsample()
        while torch.isnan(self.phi0).any():
            self.phi0 = self.Q_phi0.rsample()

        # estimate velocity
        self.r_dot0 = (self.r1_m - self.r0_m) / (self.t_eval[1] -
                                                 self.t_eval[0])
        self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n,
                                           self.t_eval[1] - self.t_eval[0])

        # predict
        z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 3)

        # decode
        ones = torch.ones_like(self.qT[:, 0:1])
        self.cart = self.obs_net_1(ones)
        self.pole = self.obs_net_2(ones)

        theta1 = self.get_theta_inv(1, 0, self.qT[:, 0], 0, bs=T * self.bs)
        theta2 = self.get_theta_inv(self.qT[:, 1],
                                    self.qT[:, 2],
                                    self.qT[:, 0],
                                    0,
                                    bs=T * self.bs)

        grid1 = F.affine_grid(theta1,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        grid2 = F.affine_grid(theta2,
                              torch.Size((T * self.bs, 1, self.d, self.d)))

        transf_cart = F.grid_sample(
            self.cart.view(T * self.bs, 1, self.d, self.d), grid1)
        transf_pole = F.grid_sample(
            self.pole.view(T * self.bs, 1, self.d, self.d), grid2)
        self.Xrec = torch.cat(
            [transf_cart, transf_pole,
             torch.zeros_like(transf_cart)], dim=1)
        self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
        return None
Example #21
class AIR(nn.Module):
    def __init__(self, arch=None):
        """
        :param arch: dictionary, for overriding default architecture
        """
        nn.Module.__init__(self)
        self.arch = deepcopy(default_arch)
        if arch is not None:
            self.arch.update(arch)

        self.T = self.arch.max_steps
        self.reinforce_weight = 0.0

        # 4: where + pres
        lstm_input_size = self.arch.input_size + self.arch.z_what_size + 4
        self.lstm_cell = LSTMCell(lstm_input_size, self.arch.lstm_hidden_size)

        # predict z_where, z_pres from h
        self.predict = Predict(self.arch)
        # encode object into what
        self.encoder = Encoder(self.arch)
        # decode what into object
        self.decoder = Decoder(self.arch)

        # spatial transformers
        self.image_to_object = SpatialTransformer(self.arch.input_shape,
                                                  self.arch.object_shape)
        self.object_to_image = SpatialTransformer(self.arch.object_shape,
                                                  self.arch.input_shape)

        # baseline RNN
        self.bl_rnn = LSTMCell(lstm_input_size, self.arch.baseline_hidden_size)
        # predict baseline value
        self.bl_predict = nn.Linear(self.arch.baseline_hidden_size, 1)

        # priors
        self.pres_prior = Bernoulli(probs=self.arch.z_pres_prob_prior)
        self.where_prior = Normal(loc=self.arch.z_where_loc_prior,
                                  scale=self.arch.z_where_scale_prior)
        self.what_prior = Normal(loc=self.arch.z_what_loc_prior,
                                 scale=self.arch.z_what_scale_prior)

        # modules excluding baseline rnn
        self.air_modules = nn.ModuleList(
            [self.predict, self.lstm_cell, self.encoder, self.decoder])

        self.baseline_modules = nn.ModuleList([self.bl_rnn, self.bl_predict])

    def forward(self, x):
        B = x.size(0)
        state = AIRState.get_intial_state(B, self.arch)

        # accumulated KL divergence
        kl = []
        # baseline value for each step
        baseline_value = []
        # z_pres likelihood for each step
        z_pres_likelihood = []
        # learning signal for each step
        learning_signal = torch.zeros(B, self.arch.max_steps, device=x.device)
        # signal_mask (prev.z_pres)
        signal_mask = torch.ones(B, self.arch.max_steps, device=x.device)
        # mask (z_pres)
        mask = torch.ones(B, self.arch.max_steps, device=x.device)
        # canvas
        h, w = self.arch.input_shape
        canvas = torch.zeros(B, 1, h, w, device=x.device)

        if DEBUG:
            vis_logger['image'] = x[0]
            vis_logger['z_pres_p_list'] = []
            vis_logger['z_pres_list'] = []
            vis_logger['canvas_list'] = []
            vis_logger['z_where_list'] = []
            vis_logger['object_enc_list'] = []
            vis_logger['object_dec_list'] = []
            vis_logger['kl_pres_list'] = []
            vis_logger['kl_what_list'] = []
            vis_logger['kl_where_list'] = []

        for t in range(self.T):
            # This is prev.z_pres. The only purpose is for masking learning signal.
            signal_mask[:, t] = state.z_pres.squeeze()

            # all terms are already masked
            state, this_kl, this_baseline_value, this_z_pres_likelihood = self.infer_step(
                state, x)
            baseline_value.append(this_baseline_value.squeeze())
            kl.append(this_kl)
            z_pres_likelihood.append(this_z_pres_likelihood.squeeze())

            # add learning signal to depending terms (1:i-1)
            # NOTE: the kl of z_pres at the current step does not depend on the sample
            # from z_pres, but the kl of z_where and z_what DOES. They cannot be excluded
            # from the learning signal. So here we use t + 1 instead of t. Although
            # this also includes the kl of z_pres of the current step, it should not
            # matter too much
            for j in range(t + 1):
                learning_signal[:, j] += this_kl.squeeze()

            # reconstruct
            object = self.decoder(state.z_what)
            # (B, 1, H, W)
            img = self.object_to_image(object, state.z_where, inverse=False)
            # Masking is crucial here.
            canvas = canvas + img * state.z_pres[:, :, None, None]

            mask[:, t] = state.z_pres.squeeze()

            vis_logger['canvas_list'].append(canvas[0])
            vis_logger['object_dec_list'].append(object[0])

        baseline_value = torch.stack(baseline_value, dim=1)
        kl = torch.stack(kl, dim=1)
        z_pres_likelihood = torch.stack(z_pres_likelihood, dim=1)

        # construct output distribution
        output_dist = Normal(canvas, self.arch.x_scale.expand(canvas.shape))
        likelihood = output_dist.log_prob(x)
        # sum over data dimension
        likelihood = likelihood.view(B, -1).sum(1)

        # Construct surrogate loss
        # Note the MINUS sign here!
        learning_signal = learning_signal - likelihood[:, None]
        learning_signal = learning_signal * signal_mask
        reinforce_term = (learning_signal.detach() -
                          baseline_value.detach()) * z_pres_likelihood
        reinforce_term = reinforce_term.sum(1)
        # reinforce_term = torch.zeros_like(reinforce_term)

        # kl term, sum over the step dimension
        kl = kl.sum(1)

        loss = self.reinforce_weight * reinforce_term + kl - likelihood
        # mean over batch dimension
        loss = loss.mean()

        vis_logger['reinforce_loss'] = (reinforce_term.mean())
        vis_logger['kl_loss'] = (kl.mean())
        vis_logger['neg_likelihood'] = (-likelihood.mean())

        # compute baseline loss
        baseline_loss = F.mse_loss(baseline_value, learning_signal.detach())

        vis_logger['baseline_loss'] = baseline_loss

        # losslist = (reinforce_term.mean(), kl.mean(), likelihood.mean(), baseline_loss)

        return loss + baseline_loss, mask.sum(1)

    def infer_step(self, prev, x):
        """
        Given previous state, predict next state. We assume that z_pres is 1
        :param prev: AIRState
        :return: new_state, KL, baseline value, z_pres_likelihood
        """

        B = x.size(0)

        # Flatten x
        x_flat = x.view(B, -1)

        # First, compute h_t that encodes (x, z[1:i-1])
        lstm_input = torch.cat(
            (x_flat, prev.z_where, prev.z_what, prev.z_pres), dim=1)
        h, c = self.lstm_cell(lstm_input, (prev.h, prev.c))

        # Predict presence and location
        z_pres_p, z_where_loc, z_where_scale = self.predict(h)

        # In theory, if z_pres is 0, we don't need to continue computation. But
        # for batch processing, we will do this anyway.

        # sample z_pres
        z_pres_p = z_pres_p * prev.z_pres

        # NOTE: for numerical stability, if z_pres_p is 0 or 1, we will need to
        # clamp it to within (0, 1), or otherwise the gradient will explode
        eps = 1e-6
        z_pres_p = z_pres_p + eps * (z_pres_p == 0).float() - eps * (
            z_pres_p == 1).float()

        z_pres_post = Bernoulli(z_pres_p)
        z_pres = z_pres_post.sample()
        z_pres = z_pres * prev.z_pres

        # Likelihood. Note we must use prev.z_pres instead of z_pres because
        # p(z_pres[i]=0 | z_pres[i-1]=1) is non-zero.
        z_pres_likelihood = z_pres_post.log_prob(z_pres) * prev.z_pres
        # (B,)
        z_pres_likelihood = z_pres_likelihood.squeeze()

        # sample z_where
        z_where_post = Normal(z_where_loc, z_where_scale)
        z_where = z_where_post.rsample()

        # extract object
        # (B, 1, Hobj, Wobj)
        object = self.image_to_object(x, z_where, inverse=True)

        # predict z_what
        z_what_loc, z_what_scale = self.encoder(object)
        z_what_post = Normal(z_what_loc, z_what_scale)
        z_what = z_what_post.rsample()

        # compute baseline for this z_pres
        bl_h, bl_c = self.bl_rnn(lstm_input.detach(), (prev.bl_h, prev.bl_c))
        # (B,)
        baseline_value = self.bl_predict(bl_h).squeeze()
        # If z_pres[i-1] is 0, the reinforce term will not be dependent on phi.
        # In this case, we don't need the term. So we set it to zero.
        # At the same time, we must set learning signal to zero as this will
        # matter in baseline loss computation.
        baseline_value = baseline_value * prev.z_pres.squeeze()

        # Compute KL as we go, sum over data dimension
        kl_pres = kl_divergence(
            z_pres_post,
            self.pres_prior.expand(z_pres_post.batch_shape)).sum(1)
        kl_where = kl_divergence(
            z_where_post,
            self.where_prior.expand(z_where_post.batch_shape)).sum(1)
        kl_what = kl_divergence(
            z_what_post,
            self.what_prior.expand(z_what_post.batch_shape)).sum(1)

        # For where and what, when z_pres[i] is 0, they are deterministic
        kl_where = kl_where * z_pres.squeeze()
        kl_what = kl_what * z_pres.squeeze()
        # For pres, this is not the case. So we use prev.z_pres.
        kl_pres = kl_pres * prev.z_pres.squeeze()

        kl = (kl_pres + kl_where + kl_what)

        # new state
        new_state = AIRState(z_pres=z_pres,
                             z_where=z_where,
                             z_what=z_what,
                             h=h,
                             c=c,
                             bl_c=bl_c,
                             bl_h=bl_h,
                             z_pres_p=z_pres_p)

        # Logging
        if DEBUG:
            vis_logger['z_pres_p_list'].append(z_pres_p[0])
            vis_logger['z_pres_list'].append(z_pres[0])
            vis_logger['z_where_list'].append(z_where[0])
            vis_logger['object_enc_list'].append(object[0])
            vis_logger['kl_pres_list'].append(kl_pres.mean())
            vis_logger['kl_what_list'].append(kl_what.mean())
            vis_logger['kl_where_list'].append(kl_where.mean())

        return new_state, kl, baseline_value, z_pres_likelihood
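The surrogate loss above combines a REINFORCE (score-function) term with a learned baseline. A minimal stand-alone sketch of that estimator for a single Bernoulli latent (illustration only, not AIR's actual training loop):

import torch
from torch.distributions import Bernoulli

def f(z):                                   # black-box reward; gradients do not flow through it
    return (z - 0.3) ** 2

logit = torch.zeros(1, requires_grad=True)
dist = Bernoulli(logits=logit)

z = dist.sample((1000,))                    # non-differentiable samples, shape (1000, 1)
reward = f(z)
baseline = reward.mean().detach()           # simple constant baseline; AIR learns one with an RNN

# score-function surrogate: grad E[f(z)] = E[(f(z) - b) * grad log p(z)]
surrogate = ((reward - baseline).detach() * dist.log_prob(z)).mean()
surrogate.backward()
print(logit.grad)                           # Monte Carlo estimate of d E[f(z)] / d logit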
Example #22
def plot_both_dists():

    # MAKE PLOT OF DISTRIBUTION
    rows = 1
    cols = 1
    fig = plt.figure(figsize=(10+cols, 4+rows), facecolor='white')

    col = 0
    row = 0
    ax = plt.subplot2grid((rows, cols), (row, col), frameon=False, colspan=1, rowspan=1)

    xs = np.linspace(-9, 205, 300)
    sum_ = np.zeros(len(xs))
    for c in range(n_components):
        m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
        ys = []
        for x in xs:
            component_i = (torch.exp(m.log_prob(x)) * true_mixture_weights[c]).numpy()
            ys.append(component_i)
        ys = np.reshape(np.array(ys), [-1])
        sum_ += ys
        ax.plot(xs, ys, label='', c='c')
    ax.plot(xs, sum_, label='')

    plt_path = exp_dir + 'gmm_distplot.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
Example #23
File: models.py  Project: mukami12/REM
class REM(nn.Module):
    def __init__(self, x_dim, z_dim, h_dim, version):
        super(REM, self).__init__()
        self.x_dim = x_dim
        self.z_dim = z_dim
        self.version = version

        self.encoder = Encoder(x_dim, z_dim, h_dim)
        self.decoder = Decoder(x_dim, z_dim, h_dim)
        self.prior = Normal(
            torch.zeros([z_dim]).to(device),
            torch.ones([z_dim]).to(device))

    def forward(self, x, S):
        x = x.view(-1, self.x_dim)
        bsz = x.size(0)

        ### get w and \alpha and L(\theta)
        mu, logvar = self.encoder(x)
        q_phi = Normal(loc=mu, scale=torch.exp(0.5 * logvar))
        z_q = q_phi.rsample((S, ))
        recon_batch = self.decoder(z_q)
        x_dist = Bernoulli(logits=recon_batch)
        log_lik = x_dist.log_prob(x).sum(-1)
        log_prior = self.prior.log_prob(z_q).sum(-1)
        log_q = q_phi.log_prob(z_q).sum(-1)
        log_w = log_lik + log_prior - log_q
        tmp_alpha = torch.logsumexp(log_w, dim=0).unsqueeze(0)
        alpha = torch.exp(log_w - tmp_alpha).detach()
        if self.version == 'v1':
            p_loss = -alpha * (log_lik + log_prior)

        ### get moment-matched proposal
        mu_r = alpha.unsqueeze(2) * z_q
        mu_r = mu_r.sum(0).detach()
        z_minus_mu_r = z_q - mu_r.unsqueeze(0)
        reshaped_diff = z_minus_mu_r.view(S * bsz, -1, 1)
        reshaped_diff_t = reshaped_diff.permute(0, 2, 1)
        outer = torch.bmm(reshaped_diff, reshaped_diff_t)
        outer = outer.view(S, bsz, self.z_dim, self.z_dim)
        Sigma_r = outer.mean(0) * S / (S - 1)
        Sigma_r = Sigma_r + torch.eye(self.z_dim).to(device) * 1e-6  ## ridging

        ### get v, \beta, and L(\phi)
        L = torch.cholesky(Sigma_r)
        r_phi = MultivariateNormal(loc=mu_r, scale_tril=L)

        z = r_phi.rsample((S, ))
        z_r = z.detach()
        recon_batch_r = self.decoder(z_r)
        x_dist_r = Bernoulli(logits=recon_batch_r)
        log_lik_r = x_dist_r.log_prob(x).sum(-1)
        log_prior_r = self.prior.log_prob(z_r).sum(-1)
        log_r = r_phi.log_prob(z_r)
        log_v = log_lik_r + log_prior_r - log_r
        tmp_beta = torch.logsumexp(log_v, dim=0).unsqueeze(0)
        beta = torch.exp(log_v - tmp_beta).detach()
        log_q = q_phi.log_prob(z_r).sum(-1)
        q_loss = -beta * log_q

        if self.version == 'v2':
            p_loss = -beta * (log_lik_r + log_prior_r)

        rem_loss = torch.sum(q_loss + p_loss, 0).sum()
        return rem_loss

    def log_lik(self, loader, n_samples):
        """Get log marginal estimate via importance sampling
        """
        nll = 0
        for i, (data, _) in enumerate(loader):
            data = data.view(-1, self.x_dim).to(device)
            bsz = data.size(0)
            mu, logvar = self.encoder(data)

            ### get moment-matched proposal
            q_phi = Normal(loc=mu, scale=torch.exp(0.5 * logvar))
            z_q = q_phi.rsample((n_samples, ))
            recon_batch = self.decoder(z_q)
            x_dist = Bernoulli(logits=recon_batch)
            log_lik = x_dist.log_prob(data).sum(-1)
            log_prior = self.prior.log_prob(z_q).sum(-1)
            log_q = q_phi.log_prob(z_q).sum(-1)
            log_w = log_lik + log_prior - log_q
            tmp_alpha = torch.logsumexp(log_w, dim=0).unsqueeze(0)
            alpha = torch.exp(log_w - tmp_alpha).detach()

            mu_r = alpha.unsqueeze(2) * z_q
            mu_r = mu_r.sum(0).detach()

            nll_proposal = Normal(loc=mu_r, scale=torch.exp(0.5 * logvar))

            bsz = data.size(0)

            z = nll_proposal.rsample((n_samples, ))
            recon_batch = self.decoder(z)
            x_dist = Bernoulli(logits=recon_batch)
            log_lik = x_dist.log_prob(data).sum(-1)
            log_prior = self.prior.log_prob(z).sum(-1)
            log_r = nll_proposal.log_prob(z).sum(-1)

            loss = log_lik + log_prior - log_r
            ll = torch.logsumexp(loss, dim=0) - math.log(n_samples)
            ll = ll.sum()
            nll += -ll.item()

            if i > 0 and i % 20000000 == 0:
                print('i: {}/{}'.format(i, len(loader)))

        nll /= len(loader.dataset)
        return nll
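log_lik above estimates the log marginal likelihood by importance sampling with a moment-matched proposal. A minimal stand-alone sketch of the same logsumexp-of-log-weights estimator on a toy model where the exact answer is known (hypothetical toy example, not part of REM):

import math
import torch
from torch.distributions import Normal

# toy model: z ~ N(0, 1), x | z ~ N(z, 1), so p(x) = N(x; 0, sqrt(2)) exactly
x = torch.tensor([1.5])
proposal = Normal(torch.tensor([0.75]), torch.tensor([1.0]))   # any proposal covering the posterior

n_samples = 10000
z = proposal.sample((n_samples,))
log_w = Normal(z, 1.).log_prob(x) + Normal(0., 1.).log_prob(z) - proposal.log_prob(z)
log_px_est = torch.logsumexp(log_w, dim=0) - math.log(n_samples)

log_px_true = Normal(0., math.sqrt(2.)).log_prob(x)
print(log_px_est.item(), log_px_true.item())                   # the two should be close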
Example #24
    def compute_all_losses(self, batch_dict, n_traj_samples=1, kl_coef=1.):
        # Condition on subsampled points
        # Make predictions for all the points
        pred_y, info = self.get_reconstruction(
            batch_dict["tp_to_predict"],
            batch_dict["observed_data"],
            batch_dict["observed_tp"],
            mask=batch_dict["observed_mask"],
            n_traj_samples=n_traj_samples,
            mode=batch_dict["mode"])

        #print("get_reconstruction done -- computing likelihood")
        fp_mu, fp_std, fp_enc = info["first_point"]
        fp_std = fp_std.abs()
        fp_distr = Normal(fp_mu, fp_std)

        assert (torch.sum(fp_std < 0) == 0.)

        kldiv_z0 = kl_divergence(fp_distr, self.z0_prior)

        if torch.isnan(kldiv_z0).any():
            print(fp_mu)
            print(fp_std)
            raise Exception("kldiv_z0 is Nan!")

        # Mean over number of latent dimensions
        # kldiv_z0 shape: [n_traj_samples, n_traj, n_latent_dims] if prior is a mixture of gaussians (KL is estimated)
        # kldiv_z0 shape: [1, n_traj, n_latent_dims] if prior is a standard gaussian (KL is computed exactly)
        # shape after: [n_traj_samples]
        kldiv_z0 = torch.mean(kldiv_z0, (1, 2))

        # Compute likelihood of all the points
        rec_likelihood = self.get_gaussian_likelihood(
            batch_dict["data_to_predict"],
            pred_y,
            mask=batch_dict["mask_predicted_data"])

        mse = self.get_mse(batch_dict["data_to_predict"],
                           pred_y,
                           mask=batch_dict["mask_predicted_data"])

        pois_log_likelihood = torch.Tensor([0.]).to(
            utils.get_device(batch_dict["data_to_predict"]))
        if self.use_poisson_proc:
            pois_log_likelihood = compute_poisson_proc_likelihood(
                batch_dict["data_to_predict"],
                pred_y,
                info,
                mask=batch_dict["mask_predicted_data"])
            # Take mean over n_traj
            pois_log_likelihood = torch.mean(pois_log_likelihood, 1)

        ################################
        # Compute CE loss for binary classification on Physionet
        device = utils.get_device(batch_dict["data_to_predict"])
        ce_loss = torch.Tensor([0.]).to(device)
        if (batch_dict["labels"] is not None) and self.use_binary_classif:

            if (batch_dict["labels"].size(-1) == 1) or (len(
                    batch_dict["labels"].size()) == 1):
                ce_loss = compute_binary_CE_loss(info["label_predictions"],
                                                 batch_dict["labels"])
            else:
                ce_loss = compute_multiclass_CE_loss(
                    info["label_predictions"],
                    batch_dict["labels"],
                    mask=batch_dict["mask_predicted_data"])

        # IWAE loss
        loss = -torch.logsumexp(rec_likelihood - kl_coef * kldiv_z0, 0)
        if torch.isnan(loss):
            loss = -torch.mean(rec_likelihood - kl_coef * kldiv_z0, 0)

        if self.use_poisson_proc:
            loss = loss - 0.1 * pois_log_likelihood

        if self.use_binary_classif:
            if self.train_classif_w_reconstr:
                loss = loss + ce_loss * 100
            else:
                loss = ce_loss

        results = {}
        results["loss"] = torch.mean(loss)
        results["likelihood"] = torch.mean(rec_likelihood).detach()
        results["mse"] = torch.mean(mse).detach()
        results["pois_likelihood"] = torch.mean(pois_log_likelihood).detach()
        results["ce_loss"] = torch.mean(ce_loss).detach()
        results["kl_first_p"] = torch.mean(kldiv_z0).detach()
        results["std_first_p"] = torch.mean(fp_std).detach()

        if batch_dict["labels"] is not None and self.use_binary_classif:
            results["label_predictions"] = info["label_predictions"].detach()

        return results
Example #25
 def compute_loglik(self, data, recon):
     dist = Normal(torch.tensor([0.0]), torch.tensor([0.3]))
     diff = data - recon
     a = dist.log_prob(diff.cpu())
     return torch.sum(a) / self.batch_size
Example #26
class HMC():
    def __init__(self, S, B, N, K, D, hmc_num_steps, leapfrog_step_size,
                 leapfrog_num_steps, CUDA, device):
        self.S, self.B, self.N, self.K, self.D = S, B, N, K, D
        self.Sigma = torch.ones(self.D)
        self.mu = torch.zeros(self.D)
        self.accept_count = 0.0
        if CUDA:
            with torch.cuda.device(device):
                self.Sigma = self.Sigma.cuda()
                self.mu = self.mu.cuda()
                self.uniformer = Uniform(
                    torch.Tensor([0.0]).cuda(),
                    torch.Tensor([1.0]).cuda())
        else:
            self.uniformer = Uniform(torch.Tensor([0.0]), torch.Tensor([1.0]))

        self.gauss_dist = Normal(self.mu, self.Sigma)
        self.hmc_num_steps = hmc_num_steps
        self.leapfrog_step_size = leapfrog_step_size
        self.leapfrog_num_steps = leapfrog_num_steps

    def init_sample(self):
        """
        initialize auxiliary variables from univariate Gaussian
        return r_tau, r_mu
        """
        return self.gauss_dist.sample((
            self.S,
            self.B,
            self.K,
        )), self.gauss_dist.sample((
            self.S,
            self.B,
            self.K,
        ))

    def hmc_sampling(self, generative, x, log_tau, mu, trace):
        for m in range(self.hmc_num_steps):
            log_tau, mu = self.metrioplis(generative,
                                          x,
                                          log_tau=log_tau.detach(),
                                          mu=mu.detach(),
                                          step_size=self.leapfrog_step_size,
                                          num_steps=self.leapfrog_num_steps)
            posterior_logits = posterior_z(x,
                                           tau=log_tau.exp(),
                                           mu=mu,
                                           prior_pi=generative.prior_pi)
            E_z = posterior_logits.exp().mean(0)
            z = cat(logits=posterior_logits).sample()
            log_joint = self.log_joint(generative,
                                       x,
                                       z=z,
                                       tau=log_tau.exp(),
                                       mu=mu)
            trace['density'].append(log_joint.unsqueeze(0))
        return log_tau, mu, trace

    def log_joint(self, generative, x, z, tau, mu):
        ll = generative.log_prob(x, z=z, tau=tau, mu=mu, aggregate=True)
        log_prior_tau = Gamma(
            generative.prior_alpha,
            generative.prior_beta).log_prob(tau).sum(-1).sum(-1)
        log_prior_mu = Normal(
            generative.prior_mu, 1. /
            (generative.prior_nu * tau).sqrt()).log_prob(mu).sum(-1).sum(-1)
        log_prior_z = cat(probs=generative.prior_pi).log_prob(z).sum(-1)
        return (ll + log_prior_tau + log_prior_mu + log_prior_z)

    def metrioplis(self, generative, x, log_tau, mu, step_size, num_steps):
        r_tau, r_mu = self.init_sample()
        ## compute hamiltonian given original position and momentum
        H_orig = self.hamiltonian(generative,
                                  x,
                                  log_tau=log_tau,
                                  mu=mu,
                                  r_tau=r_tau,
                                  r_mu=r_mu)
        new_log_tau, new_mu, new_r_tau, new_r_mu = self.leapfrog(
            generative, x, log_tau, mu, r_tau, r_mu, step_size, num_steps)
        ## compute hamiltonian given new proposals
        H_new = self.hamiltonian(generative,
                                 x,
                                 log_tau=new_log_tau,
                                 mu=new_mu,
                                 r_tau=new_r_tau,
                                 r_mu=new_r_mu)
        accept_ratio = (H_new - H_orig).exp()
        u_samples = self.uniformer.sample((
            self.S,
            self.B,
        )).squeeze(-1)
        accept_index = (u_samples < accept_ratio)
        # assert accept_index.shape == (self.S, self.B), "ERROR! index has unexpected shape."
        accept_index_expand = accept_index.unsqueeze(-1).unsqueeze(-1).repeat(
            1, 1, self.K, self.D)
        assert accept_index_expand.shape == (
            self.S, self.B, self.K,
            self.D), "ERROR! index has unexpected shape."
        filtered_log_tau = new_log_tau * accept_index_expand.float(
        ) + log_tau * (~accept_index_expand).float()
        filtered_mu = new_mu * accept_index_expand.float() + mu * (
            ~accept_index_expand).float()
        self.accept_count = self.accept_count + accept_index_expand.float()
        return filtered_log_tau.detach(), filtered_mu.detach()

    def leapfrog(self, generative, x, log_tau, mu, r_tau, r_mu, step_size,
                 num_steps):
        for step in range(num_steps):
            log_tau.requires_grad, mu.requires_grad = True, True
            log_p = self.log_marginal(generative, x, log_tau, mu)
            log_p.sum().backward(retain_graph=False)
            r_tau = (r_tau + 0.5 * step_size * log_tau.grad).detach()
            r_mu = (r_mu + 0.5 * step_size * mu.grad).detach()
            log_tau = (log_tau + step_size * r_tau).detach()
            mu = (mu + step_size * r_mu).detach()
            log_tau.requires_grad, mu.requires_grad = True, True
            log_p = self.log_marginal(generative, x, log_tau, mu)
            log_p.sum().backward(retain_graph=False)
            r_tau = (r_tau + 0.5 * step_size * log_tau.grad).detach()
            r_mu = (r_mu + 0.5 * step_size * mu.grad).detach()
            log_tau, mu = log_tau.detach(), mu.detach()
        return log_tau, mu, r_tau, r_mu

    def hamiltonian(self, generative, x, log_tau, mu, r_tau, r_mu):
        """
        compute the Hamiltonian given the position and momentum
        """
        Kp = self.kinetic_energy(r_tau=r_tau, r_mu=r_mu)
        Uq = self.log_marginal(generative, x, log_tau=log_tau, mu=mu)
        assert Kp.shape == (self.S, self.B), "ERROR! Kp has unexpected shape."
        assert Uq.shape == (self.S, self.B), 'ERROR! Uq has unexpected shape.'
        return Kp + Uq

    def kinetic_energy(self, r_tau, r_mu):
        """
        r_tau, r_mu : S * B * K * D
        return - 1/2 * ||(r_tau, r_mu)||^2
        """
        return -((r_tau**2).sum(-1).sum(-1) + (r_mu**2).sum(-1).sum(-1)) * 0.5

    def log_marginal(self, generative, x, log_tau, mu):
        """
        compute log density log p(x_1:N, mu_1:N, tau_1:N)
        by marginalizing discrete varaibles :                                   
        = \sum_{n=1}^N [log(\sum_{k=1}^K N(x_n; \mu_k, \Sigma_k)) - log(K)]
          + \sum_{k=1}^K [log p(\mu_k) + log p(\Sigma_k)]
        """
        tau = log_tau.exp()
        sigma = 1. / tau.sqrt()
        logprior_tau = (Gamma(generative.prior_alpha,
                              generative.prior_beta).log_prob(tau) +
                        log_tau).sum(-1).sum(-1)  # S * B
        logprior_mu = Normal(
            generative.prior_mu, 1. /
            (generative.prior_nu * tau).sqrt()).log_prob(mu).sum(-1).sum(-1)
        mu_expand = mu.unsqueeze(2).repeat(1, 1, self.N, 1,
                                           1).permute(3, 0, 1, 2, 4)
        sigma_expand = sigma.unsqueeze(2).repeat(1, 1, self.N, 1, 1).permute(
            3, 0, 1, 2, 4)  #  K * S * B * N * D
        ll = Normal(mu_expand, sigma_expand).log_prob(x).sum(-1).permute(
            1, 2, 3, 0)  # S * B * N * K
        log_density = torch.logsumexp(generative.prior_pi.log() + ll,
                                      dim=-1).sum(-1)
        return log_density + logprior_mu + logprior_tau
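A minimal stand-alone sketch of the same leapfrog + Metropolis accept/reject pattern, targeting a 1-D standard normal with an analytic gradient (illustration only, not the class above):

import torch

def log_target(q):            # log N(q; 0, 1) up to an additive constant
    return -0.5 * q ** 2

def grad_log_target(q):
    return -q

def hmc_step(q, step_size=0.1, num_steps=20):
    r = torch.randn_like(q)                     # resample the momentum
    q_new, r_new = q.clone(), r.clone()
    # leapfrog integration: half momentum step, alternating full steps, half momentum step
    r_new = r_new + 0.5 * step_size * grad_log_target(q_new)
    for _ in range(num_steps - 1):
        q_new = q_new + step_size * r_new
        r_new = r_new + step_size * grad_log_target(q_new)
    q_new = q_new + step_size * r_new
    r_new = r_new + 0.5 * step_size * grad_log_target(q_new)
    # Metropolis accept/reject on H = -log p(q) + 0.5 * r^2
    h_old = -log_target(q) + 0.5 * r ** 2
    h_new = -log_target(q_new) + 0.5 * r_new ** 2
    accept = torch.rand(()) < torch.exp(h_old - h_new)
    return q_new if accept else q

q = torch.tensor(3.0)
samples = []
for _ in range(1000):
    q = hmc_step(q)
    samples.append(q)
samples = torch.stack(samples)
print(samples.mean().item(), samples.std().item())   # roughly 0 and 1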
Example #27
        # the shape of statistical events (8), and that we
        # want "nsmpl" independent events.
        pyro.sample('z', dist.Normal(mu, sd).to_event(1))


if __name__ == "__main__":

    if sys.version_info < (3, 0):
        sys.stderr.write("Requires Python 3\n")

    # Do it with CUDA if possible.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Declare an Adam-based Stochastic Variational Inference engine.
    adam_params = {"lr": 0.01, "betas": (0.90, 0.999)}
    opt = pyro.optim.Adam(adam_params)
    svi = pyro.infer.SVI(model, guide, opt, loss=pyro.infer.Trace_ELBO())

    obs = Normal(lin(torch.tensor([[2., -2.]]).to(device)), .1).sample()

    loss = 0.
    for step in range(1000):
        loss += svi.step(obs)
        if (step + 1) % 100 == 0:
            print(loss)
            loss = 0
    import pdb
    pdb.set_trace()
    mu = pyro.param('mu')
    sd = pyro.param('sd')
Example #28
 def forward(self, x):
     mu = self.mu(x)
     std = self.logstd.exp().expand_as(mu)  # expand to the same shape as mu
     m = Normal(mu, std)
     return m
Example #29
class Model(pl.LightningModule):
    def __init__(self, hparams, data_path=None):
        super(Model, self).__init__()
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        self.recog_net_1 = MLP_Encoder(64 * 64, 300, 2, nonlinearity='elu')
        self.recog_net_2 = MLP_Encoder(64 * 64, 300, 3, nonlinearity='elu')
        self.obs_net_1 = MLP_Decoder(1, 100, 64 * 64, nonlinearity='elu')
        self.obs_net_2 = MLP_Decoder(1, 100, 64 * 64, nonlinearity='elu')

        V_net = MLP(3, 100, 1)
        M_net = PSD(3, 300, 2)
        g_net = MatrixNet(3, 100, 4, shape=(2, 2))

        self.ode = Lag_Net_R1_T1(g_net=g_net, M_net=M_net, V_net=V_net)

        self.train_dataset = None
        self.non_ctrl_ind = 1

    def train_dataloader(self):
        if self.hparams.homo_u:
            # must set trainer flag reload_dataloaders_every_epoch=True
            if self.train_dataset is None:
                self.train_dataset = HomoImageDataset(self.data_path,
                                                      self.hparams.T_pred)
            if self.current_epoch < 1000:
                # feed zero ctrl dataset and ctrl dataset in turns
                if self.current_epoch % 2 == 0:
                    u_idx = 0
                else:
                    u_idx = self.non_ctrl_ind
                    self.non_ctrl_ind += 1
                    if self.non_ctrl_ind == 9:
                        self.non_ctrl_ind = 1
            else:
                u_idx = self.current_epoch % 9
            self.train_dataset.u_idx = u_idx
            self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
            return DataLoader(self.train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)
        else:
            train_dataset = ImageDataset(self.data_path,
                                         self.hparams.T_pred,
                                         ctrl=True)
            self.t_eval = torch.from_numpy(train_dataset.t_eval)
            return DataLoader(train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        delta_cos = q1_m_n[:, 0:1] - q0_m_n[:, 0:1]
        delta_sin = q1_m_n[:, 1:2] - q0_m_n[:, 1:2]
        q_dot0 = -delta_cos * q0_m_n[:, 1:
                                     2] / delta_t + delta_sin * q0_m_n[:, 0:
                                                                       1] / delta_t
        return q_dot0

    def encode(self, batch_image):
        r_m_logv = self.recog_net_1(batch_image[:, 0].reshape(
            self.bs, self.d * self.d))
        r_m, r_logv = r_m_logv.split([1, 1], dim=1)
        r_m = torch.tanh(r_m)
        r_v = torch.exp(r_logv) + 0.0001

        phi_m_logv = self.recog_net_2(batch_image[:, 1].reshape(
            self.bs, self.d * self.d))
        phi_m, phi_logv = phi_m_logv.split([2, 1], dim=1)
        phi_m_n = phi_m / phi_m.norm(dim=-1, keepdim=True)
        phi_v = F.softplus(phi_logv) + 1
        return r_m, r_v, phi_m, phi_v, phi_m_n

    def get_theta(self, cos, sin, x, y, bs=None):
        # x, y should have shape (bs, )
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += sin
        theta[:, 0, 2] += x
        theta[:, 1, 0] += -sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += y
        return theta

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += -sin
        theta[:, 0, 2] += -x * cos + y * sin
        theta[:, 1, 0] += sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += -x * sin - y * cos
        return theta

    def forward(self, X, u):
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # encode
        self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(
            X[0])
        self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(
            X[1])

        # reparametrize
        self.Q_r0 = Normal(self.r0_m, self.r0_v)
        self.P_normal = Normal(torch.zeros_like(self.r0_m),
                               torch.ones_like(self.r0_v))
        self.r0 = self.Q_r0.rsample()

        self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi0 = self.Q_phi0.rsample()
        while torch.isnan(self.phi0).any():
            self.phi0 = self.Q_phi0.rsample()

        # estimate velocity
        self.r_dot0 = (self.r1_m - self.r0_m) / (self.t_eval[1] -
                                                 self.t_eval[0])
        self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n,
                                           self.t_eval[1] - self.t_eval[0])

        # predict
        z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 3)

        # decode
        ones = torch.ones_like(self.qT[:, 0:1])
        self.cart = self.obs_net_1(ones)
        self.pole = self.obs_net_2(ones)

        theta1 = self.get_theta_inv(1, 0, self.qT[:, 0], 0, bs=T * self.bs)
        theta2 = self.get_theta_inv(self.qT[:, 1],
                                    self.qT[:, 2],
                                    self.qT[:, 0],
                                    0,
                                    bs=T * self.bs)

        grid1 = F.affine_grid(theta1,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        grid2 = F.affine_grid(theta2,
                              torch.Size((T * self.bs, 1, self.d, self.d)))

        transf_cart = F.grid_sample(
            self.cart.view(T * self.bs, 1, self.d, self.d), grid1)
        transf_pole = F.grid_sample(
            self.pole.view(T * self.bs, 1, self.d, self.d), grid2)
        self.Xrec = torch.cat(
            [transf_cart, transf_pole,
             torch.zeros_like(transf_cart)], dim=1)
        self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
        return None

    def training_step(self, train_batch, batch_idx):
        X, u = train_batch
        self.forward(X, u)

        lhood = -self.loss_fn(self.Xrec, X)
        lhood = lhood.sum([0, 2, 3, 4]).mean()
        kl_q = torch.distributions.kl.kl_divergence(self.Q_r0, self.P_normal).mean() + \
                torch.distributions.kl.kl_divergence(self.Q_phi0, self.P_hyper_uni).mean()
        norm_penalty = (self.phi0_m.norm(dim=-1).mean() - 1)**2

        loss = -lhood + kl_q + 1 / 100 * norm_penalty

        logs = {'recon_loss': -lhood, 'kl_q_loss': kl_q, 'train_loss': loss}
        return {'loss': loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=1e-4, type=float)
        parser.add_argument('--batch_size', default=1024, type=int)

        return parser
Example #30
def add_noise(images, mean=0, std=0.1):
    normal_dst = Normal(mean, std)
    noise = normal_dst.sample(images.shape)
    noisy_image = noise + images
    return noisy_image
Example #31
        global_step += 1
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        logits, std = pg.forward(obs[step:step + 1])
        values[step] = vf.forward(obs[step:step + 1])

        # ALGO LOGIC: `env.action_space` specific logic
        if isinstance(env.action_space, Discrete):
            probs = Categorical(logits=logits)
            action = probs.sample()
            actions[step], neglogprobs[step], entropys[step] = action.tolist(
            )[0], -probs.log_prob(action), probs.entropy()

        elif isinstance(env.action_space, Box):
            probs = Normal(logits, std)
            action = probs.sample()
            clipped_action = torch.clamp(
                action, torch.min(torch.Tensor(env.action_space.low)),
                torch.max(torch.Tensor(env.action_space.high)))
            actions[step], neglogprobs[step], entropys[
                step] = clipped_action.tolist(
                )[0], -probs.log_prob(action).sum(), probs.entropy().sum()

        elif isinstance(env.action_space, MultiDiscrete):
            logits_categories = torch.split(logits,
                                            env.action_space.nvec.tolist(),
                                            dim=1)
            action = []
            probs_categories = []
            probs_entropies = torch.zeros((logits.shape[0]))
Example #32
def log_prob_Normal(value, loc, scale):
    return Normal(loc=loc, scale=scale).log_prob(value)
Example #33
def fakeDataDIstr(args):
    ''' Plot the empirical and real distributions of an artificial dataset's ground-truth parameters.

    For each of the four parameter vectors (true scores, biases, inconsistencies and difficulties),
    it plots the real density and the empirical distribution on the same figure.

    Args:
        args (namespace): the namespace collected at the beginning of the script containing all arguments about model training and evaluation.
    '''

    dataConf = configparser.ConfigParser()
    dataConf.read("../data/{}.ini".format(args.dataset))
    dataConf = dataConf['default']

    paramKeys = ["trueScores", "diffs", "incons", "bias"]

    cleanNames = ["True Scores", "Difficulties", "Inconsistencies", "Biases"]
    dx = 0.05
    xLim = 0.4

    dist_dic = {"trueScores":lambda x:torch.exp(Uniform(1,5).log_prob(x)),\
                "diffs":lambda x:torch.exp(Beta(float(dataConf['diff_alpha']), float(dataConf["diff_beta"])).log_prob(x)), \
                "incons":lambda x:torch.exp(Beta(float(dataConf["incons_alpha"]), float(dataConf["incons_beta"])).log_prob(x)),\
                "bias":lambda x:torch.exp(Normal(torch.zeros(1), float(dataConf["bias_std"])*torch.eye(1)).log_prob(x))}

    range_dic = {"trueScores":torch.arange(1,5,dx),\
                "diffs":torch.arange(0,1,dx), \
                "incons":torch.arange(0,1,dx),\
                "bias":torch.arange(-3*float(dataConf["bias_std"]),3*float(dataConf["bias_std"]),dx)}

    for i, paramName in enumerate(paramKeys):

        paramValues = np.genfromtxt("../data/{}_{}.csv".format(
            dataConf["dataset_id"], paramName))
        trueCDF = dist_dic[paramName](range_dic[paramName]).numpy().reshape(-1)

        fig, empiCountAx = plt.subplots()

        plt.title(cleanNames[i])
        plt.ylabel("Density")

        plt.xlabel("Value")

        handles = []
        distAx = empiCountAx.twinx()

        if paramName == "incons" or paramName == "diffs":
            empiCountAx.hist(paramValues,
                             10,
                             label="Empirical distribution",
                             color="orange",
                             range=[0, xLim])
            #plt.xlim(0,xLim)
        else:
            empiCountAx.hist(paramValues,
                             10,
                             label="Empirical distribution",
                             color="orange")

        handles += distAx.plot(range_dic[paramName].numpy(),
                               trueCDF,
                               label="True distribution",
                               color="blue")
        print(trueCDF.max())
        #distAx.set_ylim(0,trueCDF.max())

        leg = plt.legend(handles=handles, title="Test")

        plt.gca().add_artist(leg)
        plt.savefig("../vis/{}_{}_dis.png".format(args.dataset, paramName))
Example #34
0
 def forward(self, inputs, c, z=None):    
     mu_z = self.mu_z(inputs[:, 0])
     sigma_z = self.sigma_z(inputs[:, 0])
     dist = Normal(mu_z, sigma_z)
     if z is None: z = dist.rsample()
     return z, dist.log_prob(z).sum(dim=1) # Return value, score
Example #35
0
 def _get_normal(self, logits):
     loc, scale = logits.chunk(2, dim=-1)
     loc = safe_squeeze(loc, -1)
     scale = torch.exp(safe_squeeze(scale, -1))
     return Normal(loc, scale)
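A small illustration of the chunking pattern above (assuming safe_squeeze simply squeezes the given dimension): the last dimension packs a location and a log-scale, and exp keeps the scale positive.

import torch
from torch.distributions import Normal

logits = torch.randn(4, 2)            # last dim packs (loc, log-scale)
loc, scale = logits.chunk(2, dim=-1)  # each chunk is (4, 1)
dist = Normal(loc.squeeze(-1), torch.exp(scale.squeeze(-1)))
print(dist.loc.shape, dist.scale.shape)  # torch.Size([4]) torch.Size([4])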
Example #36
0
    def infer_step(self, prev, x):
        """
        Given previous state, predict next state. We assume that z_pres is 1
        :param prev: AIRState
        :return: new_state, KL, baseline value, z_pres_likelihood
        """

        B = x.size(0)

        # Flatten x
        x_flat = x.view(B, -1)

        # First, compute h_t that encodes (x, z[1:i-1])
        lstm_input = torch.cat(
            (x_flat, prev.z_where, prev.z_what, prev.z_pres), dim=1)
        h, c = self.lstm_cell(lstm_input, (prev.h, prev.c))

        # Predict presence and location
        z_pres_p, z_where_loc, z_where_scale = self.predict(h)

        # In theory, if z_pres is 0, we don't need to continue computation. But
        # for batch processing, we will do this anyway.

        # sample z_pres
        z_pres_p = z_pres_p * prev.z_pres

        # NOTE: for numerical stability, if z_pres_p is 0 or 1, we will need to
        # clamp it to within (0, 1), or otherwise the gradient will explode
        eps = 1e-6
        z_pres_p = z_pres_p + eps * (z_pres_p == 0).float() - eps * (
            z_pres_p == 1).float()

        z_pres_post = Bernoulli(z_pres_p)
        z_pres = z_pres_post.sample()
        z_pres = z_pres * prev.z_pres

        # Likelihood. Note we must use prev.z_pres instead of z_pres because
        # p(z_pres[i]=0 | z_pres[i-1]=1) is non-zero.
        z_pres_likelihood = z_pres_post.log_prob(z_pres) * prev.z_pres
        # (B,)
        z_pres_likelihood = z_pres_likelihood.squeeze()

        # sample z_where
        z_where_post = Normal(z_where_loc, z_where_scale)
        z_where = z_where_post.rsample()

        # extract object
        # (B, 1, Hobj, Wobj)
        object = self.image_to_object(x, z_where, inverse=True)

        # predict z_what
        z_what_loc, z_what_scale = self.encoder(object)
        z_what_post = Normal(z_what_loc, z_what_scale)
        z_what = z_what_post.rsample()

        # compute baseline for this z_pres
        bl_h, bl_c = self.bl_rnn(lstm_input.detach(), (prev.bl_h, prev.bl_c))
        # (B,)
        baseline_value = self.bl_predict(bl_h).squeeze()
        # If z_pres[i-1] is 0, the reinforce term will not be dependent on phi.
        # In this case, we don't need the term. So we set it to zero.
        # At the same time, we must set learning signal to zero as this will
        # matter in baseline loss computation.
        baseline_value = baseline_value * prev.z_pres.squeeze()

        # Compute KL as we go, sum over data dimension
        kl_pres = kl_divergence(
            z_pres_post,
            self.pres_prior.expand(z_pres_post.batch_shape)).sum(1)
        kl_where = kl_divergence(
            z_where_post,
            self.where_prior.expand(z_where_post.batch_shape)).sum(1)
        kl_what = kl_divergence(
            z_what_post,
            self.what_prior.expand(z_what_post.batch_shape)).sum(1)

        # For where and what, when z_pres[i] is 0, they are deterministic
        kl_where = kl_where * z_pres.squeeze()
        kl_what = kl_what * z_pres.squeeze()
        # For pres, this is not the case. So we use prev.z_pres.
        kl_pres = kl_pres * prev.z_pres.squeeze()

        kl = (kl_pres + kl_where + kl_what)

        # new state
        new_state = AIRState(z_pres=z_pres,
                             z_where=z_where,
                             z_what=z_what,
                             h=h,
                             c=c,
                             bl_c=bl_c,
                             bl_h=bl_h,
                             z_pres_p=z_pres_p)

        # Logging
        if DEBUG:
            vis_logger['z_pres_p_list'].append(z_pres_p[0])
            vis_logger['z_pres_list'].append(z_pres[0])
            vis_logger['z_where_list'].append(z_where[0])
            vis_logger['object_enc_list'].append(object[0])
            vis_logger['kl_pres_list'].append(kl_pres.mean())
            vis_logger['kl_what_list'].append(kl_what.mean())
            vis_logger['kl_where_list'].append(kl_where.mean())

        return new_state, kl, baseline_value, z_pres_likelihood
Example #37
0
def log_prob_TruncatedNormal(value, loc, scale, lower, upper):
    # Standardize the truncation bounds so the normalizer matches Normal(loc, scale);
    # Phi is assumed to be the standard-normal CDF.
    F_a = Phi((lower - loc) / scale)
    F_b = Phi((upper - loc) / scale)
    return Normal(loc=loc, scale=scale).log_prob(value) - torch.log(F_b - F_a + 1e-8)
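A numerical sanity check for the standardized version above (my own sketch; the truncated density should integrate to roughly 1 over [lower, upper]):

import torch
from torch.distributions import Normal

def trunc_normal_log_prob(value, loc, scale, lower, upper):
    std_normal = Normal(torch.zeros_like(loc), torch.ones_like(scale))
    F_a = std_normal.cdf((lower - loc) / scale)
    F_b = std_normal.cdf((upper - loc) / scale)
    return Normal(loc=loc, scale=scale).log_prob(value) - torch.log(F_b - F_a + 1e-8)

loc, scale = torch.tensor(0.5), torch.tensor(1.2)
lower, upper = torch.tensor(-1.0), torch.tensor(2.0)
xs = torch.linspace(float(lower), float(upper), 2001)
density = torch.exp(trunc_normal_log_prob(xs, loc, scale, lower, upper))
print(torch.trapz(density, xs))  # ~1.0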
Example #38
0
    def forward(self, x):
        B = x.size(0)
        state = AIRState.get_intial_state(B, self.arch)

        # accumulated KL divergence
        kl = []
        # baseline value for each step
        baseline_value = []
        # z_pres likelihood for each step
        z_pres_likelihood = []
        # learning signal for each step
        learning_signal = torch.zeros(B, self.arch.max_steps, device=x.device)
        # signal_mask (prev.z_pres)
        signal_mask = torch.ones(B, self.arch.max_steps, device=x.device)
        # mask (z_pres)
        mask = torch.ones(B, self.arch.max_steps, device=x.device)
        # canvas
        h, w = self.arch.input_shape
        canvas = torch.zeros(B, 1, h, w, device=x.device)

        if DEBUG:
            vis_logger['image'] = x[0]
            vis_logger['z_pres_p_list'] = []
            vis_logger['z_pres_list'] = []
            vis_logger['canvas_list'] = []
            vis_logger['z_where_list'] = []
            vis_logger['object_enc_list'] = []
            vis_logger['object_dec_list'] = []
            vis_logger['kl_pres_list'] = []
            vis_logger['kl_what_list'] = []
            vis_logger['kl_where_list'] = []

        for t in range(self.T):
            # This is prev.z_pres. The only purpose is for masking learning signal.
            signal_mask[:, t] = state.z_pres.squeeze()

            # all terms are already masked
            state, this_kl, this_baseline_value, this_z_pres_likelihood = self.infer_step(
                state, x)
            baseline_value.append(this_baseline_value.squeeze())
            kl.append(this_kl)
            z_pres_likelihood.append(this_z_pres_likelihood.squeeze())

            # add learning signal to the terms that depend on it (steps 1:i-1)
            # NOTE: the kl of z_pres at the current step does not depend on the
            # sample from z_pres, but the kl of z_where and z_what DOES. They
            # cannot be excluded from the learning signal. So here we use t + 1
            # instead of t. Although this also includes the kl of z_pres at the
            # current step, this will not matter too much.
            for j in range(t + 1):
                learning_signal[:, j] += this_kl.squeeze()

            # reconstruct
            object = self.decoder(state.z_what)
            # (B, 1, H, W)
            img = self.object_to_image(object, state.z_where, inverse=False)
            # Masking is crucial here.
            canvas = canvas + img * state.z_pres[:, :, None, None]

            mask[:, t] = state.z_pres.squeeze()

            vis_logger['canvas_list'].append(canvas[0])
            vis_logger['object_dec_list'].append(object[0])

        baseline_value = torch.stack(baseline_value, dim=1)
        kl = torch.stack(kl, dim=1)
        z_pres_likelihood = torch.stack(z_pres_likelihood, dim=1)

        # construct output distribution
        output_dist = Normal(canvas, self.arch.x_scale.expand(canvas.shape))
        likelihood = output_dist.log_prob(x)
        # sum over data dimension
        likelihood = likelihood.view(B, -1).sum(1)

        # Construct surrogate loss
        # Note the MINUS sign here!
        learning_signal = learning_signal - likelihood[:, None]
        learning_signal = learning_signal * signal_mask
        reinforce_term = (learning_signal.detach() -
                          baseline_value.detach()) * z_pres_likelihood
        reinforce_term = reinforce_term.sum(1)
        # reinforce_term = torch.zeros_like(reinforce_term)

        # kl term, sum over the step dimension
        kl = kl.sum(1)

        loss = self.reinforce_weight * reinforce_term + kl - likelihood
        # mean over batch dimension
        loss = loss.mean()

        vis_logger['reinforce_loss'] = (reinforce_term.mean())
        vis_logger['kl_loss'] = (kl.mean())
        vis_logger['neg_likelihood'] = (-likelihood.mean())

        # compute baseline loss
        baseline_loss = F.mse_loss(baseline_value, learning_signal.detach())

        vis_logger['baseline_loss'] = baseline_loss

        # losslist = (reinforce_term.mean(), kl.mean(), likelihood.mean(), baseline_loss)

        return loss + baseline_loss, mask.sum(1)
Example #39
0
 def forward(self, x):
     value = self.critic(x)
     mu = torch.tanh(self.actor(x))
     std = self.log_std.exp().expand_as(mu)
     dist = Normal(mu, std)
     return dist, value
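A small usage sketch for an actor-critic head like the one above (the module instance model, the state shape, and the action dimensionality are assumptions):

import torch

state = torch.randn(16, 8)                 # assumed batch of states
dist, value = model(state)                 # `model` is assumed to be the module above
action = dist.sample()
log_prob = dist.log_prob(action).sum(-1)   # sum per-dimension log-probs
entropy = dist.entropy().sum(-1)
print(action.shape, log_prob.shape, value.shape)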
Example #40
0
class Gaussian(Parameter):
    """Gaussian

    Gaussian Parameter parametrized by mu and rho (std is softplus of rho,
    which improves stability and forces the value to be positive). The actual value
    of the parameter is computed using the reparametrization trick: a weight is
    obtained by sampling a normal distribution centered on 0 with std of 1; the
    sampled value epsilon is then shifted and scaled using the mu and rho values.

    eps ~ N(0, 1)
    std = log(1 + exp(rho))
    W = mu + eps * std

    Attributes:
        size (Size): size of the parameter, weight
        initialization (Optional[Initialization]): initialization callback,
            responsible for the weight initialization. Initialises mu and rho.
            {default: DEFAULT_UNIFORM}
        mu (nn.Parameter): mu parameter of the gaussian
        rho (nn.Parameter): rho parameter of the gaussian
        normal (Normal): normal distribution to sample epsilon
    """

    def __init__(
        self, size: Size,
        initialization: Optional[Initialization] = DEFAULT_UNIFORM,
        dtype: Optional[torch.dtype] = torch.float32
    ) -> None:
        """Initilization

        Arguments:
            size (Size): size of the parameter, weight
        
        Keyword Arguments:
            initialization (Optional[Initialization]): initialization callback,
                responsible for the weight initialization.
                Initialises mu and rho. {default: DEFAULT_UNIFORM}
            dtype (Optional[torch.dtype]): data type of the parameter
                {default: torch.float32}
        """
        super(Gaussian, self).__init__()
        self.size, self.dtype = size, dtype
        self.initialization = initialization
        self.mu = parameter(self.size, dtype=self.dtype)
        self.rho = parameter(self.size, dtype=self.dtype)
        
        self.register_parameter("zero", nn.Parameter(torch.tensor(0.).float(), requires_grad = False))
        self.register_parameter("one", nn.Parameter(torch.tensor(1.).float(),  requires_grad = False))
        
        self.normal = Normal(self.zero, self.one)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset Parameter

        Reset mu and rho values using the initialization callback.
        """
        self.mu, self.rho = self.initialization(self.mu, self.rho)

    @property
    def sigma(self) -> Tensor:
        """Sigma

        Returns:
            Tensor: sigma, returned as the softplus of rho
        """
        return F.softplus(self.rho)

    def sample(self) -> Tensor:
        """Sample

        Reparametrization trick allowing for differentiable sampling.
        eps ~ N(0, 1)
        W = mu + eps * std

        Returns:
            Tensor: sampled gaussian weight using the reparametrization trick.
        """
        eps = self.normal.sample(self.size).to(self.mu.device)
        return self.mu + eps * self.sigma

    def log_prob(self, input: Tensor) -> Tensor:
        """Gaussian Log Probability

        Arguments:
            input (Tensor): sampled value of the gaussian weight

        Returns:
            Tensor: log probability
        """
        return (
            - np.log(np.sqrt(2 * np.pi))
            - torch.log(self.sigma)
            - ((input - self.mu) ** 2) / (2 * self.sigma ** 2)
        ).sum()
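The docstring above describes the reparametrization trick; here is a minimal standalone sketch (my own, independent of the Parameter/Initialization machinery) showing that gradients reach mu and rho:

import torch
import torch.nn.functional as F
from torch.distributions import Normal

mu = torch.zeros(3, requires_grad=True)
rho = torch.full((3,), -3.0, requires_grad=True)

eps = Normal(0.0, 1.0).sample((3,))   # eps ~ N(0, 1), sampled without gradient
w = mu + eps * F.softplus(rho)        # W = mu + eps * log(1 + exp(rho))
loss = (w ** 2).sum()
loss.backward()
print(mu.grad, rho.grad)              # both populated thanks to the reparametrization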
Example #41
0
 def forward(self, state):
     mu_sigma = self.forward_pass(state)
     m = Normal(mu_sigma[:, :4], 0.2 * (0.5 + 0.5 * mu_sigma[:, 4:]))
     actions = m.sample()
     return actions
Example #42
0
    rows = 1
    cols = 1
    fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

    col =0
    row = 0
    ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)




    xs = np.linspace(-9,205, 300)
    sum_ = np.zeros(len(xs))
    C = 20
    for c in range(C):
        m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
        # xs = torch.tensor(xs)
        # print (m.log_prob(lin))
        ys = []
        for x in xs:
            # print (m.log_prob(x))
            component_i = (torch.exp(m.log_prob(x) )* ((c+5.) / 290.)).numpy()
            ys.append(component_i)
        ys = np.reshape(np.array(ys), [-1])
        sum_ += ys
        ax.plot(xs, ys, label='', c='c')
    ax.plot(xs, sum_, label='')



    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
Example #43
0
File: models.py  Project: npex2020rl/rl
 def log_prob(self, obs, act):
     mu, sigma = self.forward(obs)
     act_distribution = Independent(Normal(mu, sigma), 1)
     log_prob = act_distribution.log_prob(act)
     return log_prob
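A quick check (my own sketch) that Independent(Normal(mu, sigma), 1) matches summing per-dimension Normal log-probs:

import torch
from torch.distributions import Normal, Independent

mu = torch.randn(5, 3)
sigma = torch.rand(5, 3) + 0.1
act = torch.randn(5, 3)

lp_joint = Independent(Normal(mu, sigma), 1).log_prob(act)  # shape (5,)
lp_manual = Normal(mu, sigma).log_prob(act).sum(-1)         # shape (5,)
print(torch.allclose(lp_joint, lp_manual))  # True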
Example #44
0
 def forward(self, x):
     loss = -Normal(self.mu, self.sigma).log_prob(x)
     return torch.sum(loss) / (loss.size(0) * loss.size(1))
Example #45
0
 def _initialize_distributions(self):
     self._tails = [
         Normal(loc=l, scale=s, validate_args=True)
         for l, s in zip(self._loc, self.sigma)
     ]
Example #46
0
File: core.py  Project: kekmodel/rl_pytorch
 def _distribution(self, obs):
     mu = self.mu_net(obs)
     std = torch.exp(self.log_std)
     return Normal(mu, std)
Example #47
0
                checkpoint_['optimizers_state'][key])
    #Loading scheduler state dict
    for key in Schedulers.keys():
        if not key == 'h2l':
            Schedulers[key].load_state_dict(
                checkpoint_['scheduler_state'][key])
    epoch = checkpoint_['epoch']

testIter = cycle(testloader)

#The sampler
m_train = Normal(
    torch.zeros([batch_size['train'], 16, 1, 1],
                device=device,
                dtype=torch.float,
                requires_grad=False),
    torch.ones([batch_size['train'], 16, 1, 1],
               device=device,
               dtype=torch.float,
               requires_grad=False))
m_test = Normal(
    torch.zeros([batch_size['test'], 16, 1, 1],
                device=device,
                dtype=torch.float,
                requires_grad=False),
    torch.ones([batch_size['test'], 16, 1, 1],
               device=device,
               dtype=torch.float,
               requires_grad=False))

while True:
Example #48
0
 def diag_logll(self, param_list, mean_list, var_list):
     logprob = 0.0
     for param, mean, scale in zip(param_list, mean_list, var_list):
         logprob += Normal(mean, scale).log_prob(param).sum()
     return logprob
Example #49
0
 def forward(self, c, z, x=None):
     cz = torch.cat([c,z], dim=1)
     dist = Normal(self.mu(cz), self.sigma(cz))
     if x is None: x = dist.rsample()
     return x, dist.log_prob(x).sum(dim=1) # Return value, score
Example #50
0
 def sample(self, logits):
     logits = safe_squeeze(logits, -1)
     return Normal(logits, 1.0).sample()
Example #51
0
    def test_gmm_loss(self):
        # seq_len x batch_size x gaussian_size x feature_size
        # 1 x 1 x 2 x 2
        mus = torch.Tensor([[[[0.0, 0.0], [6.0, 6.0]]]])
        sigmas = torch.Tensor([[[[2.0, 2.0], [2.0, 2.0]]]])
        # seq_len x batch_size x gaussian_size
        pi = torch.Tensor([[[0.5, 0.5]]])
        logpi = torch.log(pi)

        # seq_len x batch_size x feature_size
        batch = torch.Tensor([[[3.0, 3.0]]])
        gl = gmm_loss(batch, mus, sigmas, logpi)

        # first component, first dimension
        n11 = Normal(mus[0, 0, 0, 0], sigmas[0, 0, 0, 0])
        # first component, second dimension
        n12 = Normal(mus[0, 0, 0, 1], sigmas[0, 0, 0, 1])
        p1 = (
            pi[0, 0, 0]
            * torch.exp(n11.log_prob(batch[0, 0, 0]))
            * torch.exp(n12.log_prob(batch[0, 0, 1]))
        )
        # second component, first dimension
        n21 = Normal(mus[0, 0, 1, 0], sigmas[0, 0, 1, 0])
        # second component, second dimension
        n22 = Normal(mus[0, 0, 1, 1], sigmas[0, 0, 1, 1])
        p2 = (
            pi[0, 0, 1]
            * torch.exp(n21.log_prob(batch[0, 0, 0]))
            * torch.exp(n22.log_prob(batch[0, 0, 1]))
        )

        logger.info(
            "gmm loss={}, p1={}, p2={}, p1+p2={}, -log(p1+p2)={}".format(
                gl, p1, p2, p1 + p2, -torch.log(p1 + p2)
            )
        )
        assert -torch.log(p1 + p2) == gl
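For reference, a minimal gmm_loss consistent with this test (a sketch of one plausible implementation; the project's actual function may differ, e.g. in how it reduces over the batch):

import torch
from torch.distributions import Normal

def gmm_loss(batch, mus, sigmas, logpi):
    # batch: (seq, B, F); mus/sigmas: (seq, B, G, F); logpi: (seq, B, G)
    batch = batch.unsqueeze(-2)                                 # broadcast against the G components
    g_log_probs = Normal(mus, sigmas).log_prob(batch).sum(-1)   # (seq, B, G)
    log_prob = torch.logsumexp(logpi + g_log_probs, dim=-1)     # (seq, B)
    return -log_prob.mean()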