示例#1
0
 def forward(self, z, inference=False):
     """Decode latent ``z`` into an Independent Bernoulli over the outputs.

     At inference time ``z`` carries a single batch axis; during training
     it carries two (e.g. sample and batch), hence ``n_batch_axes``.
     """
     if inference:
         axes = 1
     else:
         axes = 2
     hidden = F.tanh(self.linear(z, n_batch_axes=axes))
     logits = self.output(hidden, n_batch_axes=axes)
     bern = D.Bernoulli(logit=logits, binary_check=self.binary_check)
     # Treat the last event axis as part of the event, not the batch.
     return D.Independent(bern, reinterpreted_batch_ndims=1)
示例#2
0
    def decode(self, x, **kwargs):
        """Decode ``x`` via the predictor.

        Returns sigmoid probabilities when ``kwargs['inference']`` is truthy,
        otherwise a Bernoulli distribution over the raw logits.
        """
        logits = self.predictor.decode(x, **kwargs)
        if not kwargs.get('inference'):
            return D.Bernoulli(logit=logits)
        return F.sigmoid(logits)
示例#3
0
 def make_bernoulli_dist(self, is_gpu=False):
     """Build a Bernoulli distribution with uniformly random ``p`` of ``self.shape``.

     ``is_gpu`` is forwarded to ``encode_params`` (presumably to move the
     parameters to device memory — confirm against its definition).
     """
     probs = numpy.random.uniform(0, 1, self.shape).astype(numpy.float32)
     encoded = self.encode_params({"p": probs}, is_gpu)
     return distributions.Bernoulli(**encoded)
# Qualitative comparison plot: show a 20-second excerpt of the input
# features alongside model reconstructions / label estimates.
list_prob_pred = []
list_accr = []

# Frame indexing: C.SR samples/sec with hop size C.H -> C.SR // C.H frames/sec.
start_idx = int(C.SR * 90) // C.H  # start 90 seconds into the track
plot_length = C.SR * 20 // C.H     # window of 20 seconds of frames

# First dataset item: features, labels, and frame-to-label alignment
# (assumed meaning of the tuple — TODO confirm against dset's definition).
feat,labs,aligns = dset[0]
#_,feat_un,_,_ = dset_semi[0]

plt.figure(model_name)

# Panel 1: the first 24 feature dimensions (two chroma octaves, judging by
# the C..B pitch-class tick labels below — verify against feature extraction).
plt.subplot(5,1,1)
specshow(feat[start_idx:start_idx+plot_length,:24].T)
plt.yticks(np.arange(24)+0.5,["C","","","","","F","","G","","","","","C","","","","","F","","G","","","",""],fontname="STIXGeneral")
#plt.text(-25,6,"(a)",fontname="STIXGeneral",fontsize=15)
# Bernoulli over the original (binary) features; used elsewhere for
# log-likelihood scoring of reconstructions (see commented print below).
dist_orig = dist.Bernoulli(feat[start_idx:start_idx+plot_length])
#plt.subplot(8,1,2)
#labs_onehot = [U.encode_onehot(labs[aligns[:512],i],cat) for i,cat in zip(list(range(6)),[C.N_VOCABULARY_TRIADS,13,4,4,3,3])]
# One-hot encode the aligned labels over the triad vocabulary.
labs_onehot = U.encode_onehot(labs[aligns[start_idx:start_idx+plot_length]],C.N_VOCABULARY_TRIADS)
# Reconstruction distribution from the generator, conditioned on the
# excerpt's features and one-hot labels; [0] unwraps the single-item batch.
generated = model.generator.reconstr_dist([feat[start_idx:start_idx+plot_length]],[labs_onehot])[0]
#generated = (generated.a/(generated.a+generated.b)).data
#specshow(generated[:,:24].T)
#plt.yticks(np.arange(24)+0.5,["C","","","","","F","","G","","","","","C","","","","","F","","G","","","",""],fontname="STIXGeneral")
#plt.text(-25,6,"(c)",fontname="STIXGeneral",fontsize=15)

#plt.subplot(8,1,3)
#generated,lab_estimated = model.reconst(feat[start_idx:start_idx+plot_length])
#print("P_proposed= %.5f" % dist_orig.log_prob(generated).data.sum(-1).mean())
# Label estimates over the full track, then sliced to the plotted window.
lab_estimated = model.estimate(feat)[start_idx:start_idx+plot_length]
#generated = (generated.a/(generated.a+generated.b)).data
#specshow(generated[:,:24].T)
示例#5
0
 def __call__(self, z, inference=False):
     """Two-layer decoder: z -> tanh(fc1) -> fc2 logits -> Bernoulli."""
     # A single batch axis at inference time, two during training.
     axes = 1 if inference else 2
     hidden = F.tanh(self.fc1(z, n_batch_axes=axes))
     logits = self.fc2(hidden, n_batch_axes=axes)
     return D.Bernoulli(logit=logits)
示例#6
0
 def forward(self, z, inference=False):
     """Decode latent ``z`` into a Bernoulli distribution over output logits."""
     # Two batch axes during training, collapsed to one at inference time.
     batch_axes = 2
     if inference:
         batch_axes = 1
     activated = F.tanh(self.linear(z, n_batch_axes=batch_axes))
     out_logits = self.output(activated, n_batch_axes=batch_axes)
     return D.Bernoulli(logit=out_logits, binary_check=self.binary_check)
示例#7
0
 def to_dist_fn(h):
     """Wrap logits ``h`` in an Independent Bernoulli.

     ``ndim`` is a free variable from the enclosing scope: the number of
     trailing batch axes reinterpreted as event axes.
     """
     base = D.Bernoulli(logit=h)
     return distributions.Independent(base, reinterpreted_batch_ndims=ndim)