def encode(self, inputs, training=None, mask=None):
  X, y, mask = prepare_ssl_inputs(inputs, mask=mask, n_unsupervised_inputs=1)
  X = X[0]  # only accept single inputs now
  ## encode normally
  h_x = self.encoder(X, training=training, mask=mask)
  h_x = bk.flatten(h_x, n_outdim=2)
  ## prepare the auxiliary
  qa_x = self.qa_dist(self.encoder_a(X, training=training),
                      training=training,
                      mask=mask)
  ## prepare the label embedding
  qy_ax = self.classify(X, training=training, qa_x=qa_x)
  ## combine into q(z|axy)
  h_x = self.x_to_qz(h_x, training=training)
  h_a = self.a_to_qz(qa_x, training=training)
  h_y = self.y_to_qz(qy_ax, training=training)
  h_axy = h_x + h_y + h_a
  if self.batchnorm:
    h_axy = self.qz_xy_norm(h_axy, training=training)
  if 0.0 < self.dropout < 1.0:
    h_axy = self.qz_xy_drop(h_axy, training=training)
  # conditional embedding y
  h_axy = self.xy_to_qz_net(h_axy, training=training, mask=mask)
  qz_axy = self.latents(h_axy, training=training, mask=mask)
  return (qz_axy, qa_x, qy_ax)
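# A minimal, self-contained sketch (assumed names and shapes, not part of the
# model above) of the fusion pattern used in `encode`: three representations
# are projected to a common dimension by separate dense layers and summed
# element-wise before being mapped to the latent posterior.
import tensorflow as tf

units = 64
x_to_qz = tf.keras.layers.Dense(units)  # hypothetical stand-ins for the
a_to_qz = tf.keras.layers.Dense(units)  # x_to_qz / a_to_qz / y_to_qz layers
y_to_qz = tf.keras.layers.Dense(units)

h_x = tf.random.normal([8, 128])  # flattened encoder features
h_a = tf.random.normal([8, 16])   # auxiliary representation, e.g. from q(a|x)
h_y = tf.random.normal([8, 10])   # label representation, e.g. class probabilities

# the element-wise sum requires all projections to share the same output dimension
h_axy = x_to_qz(h_x) + a_to_qz(h_a) + y_to_qz(h_y)
print(h_axy.shape)  # (8, 64)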
def encode(self,
           inputs: Union[TensorTypes, List[TensorTypes]],
           training: Optional[bool] = None,
           mask: Optional[TensorTypes] = None,
           **kwargs) -> JointDistributionSequential:
  X, y, mask = prepare_ssl_inputs(inputs,
                                  mask=mask,
                                  n_unsupervised_inputs=self.n_observation)
  if len(y) == 0:
    py = self.classify(X, training=training)
  else:  # only support single labels model
    py = coercible_tensor(VectorDeterministic(loc=y[0]))
  # encode normally
  h_e = [
      bk.flatten(fe(X[0], training=training, mask=mask), n_outdim=2)
      for fe in self.encoder
  ]
  if len(h_e) > 1:
    h_e = tf.concat(h_e, axis=-1)
  else:
    h_e = h_e[0]
  # conditional embedding y
  y_embedded = self.labels_embedder[0](py)
  h_e = tf.concat([h_e, y_embedded], axis=-1)
  qz_x = [fz(h_e, training=training) for fz in self.latents]
  qz_x.append(py)
  return qz_x
def _apply(self, X):
  ndims = X.get_shape().ndims
  if ndims is not None:
    if ndims == self.outdim:
      return X
    elif ndims < self.outdim:
      raise RuntimeError("Input shape: %s cannot be flattened to %d-D" %
                         (str(X.get_shape()), self.outdim))
  return K.flatten(X, outdim=self.outdim)
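# A small illustrative sketch (assumed shapes, not library code) of what
# flattening to `outdim=N` means here: the first N-1 axes are kept and all
# remaining axes are collapsed into the last one.
import numpy as np

x = np.random.rand(16, 8, 12, 25)              # 4-D input
flat2 = x.reshape(x.shape[0], -1)              # outdim=2 -> (16, 2400)
flat3 = x.reshape(x.shape[0], x.shape[1], -1)  # outdim=3 -> (16, 8, 300)
print(flat2.shape, flat3.shape)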
def test_flatten(self):
  x = K.placeholder(shape=(None, 8, 12, 25, 18))
  for i in range(1, 5):
    y = K.flatten(x, outdim=i)
    f = K.function(x, y)
    shape1 = K.get_shape(y)
    shape2 = f(np.random.rand(16, 8, 12, 25, 18)).shape
    self.assertEqual(len(shape1), len(shape2))
    self.assertTrue(
        all(i == j for i, j in zip(shape1, shape2) if i is not None))
def _apply(self, X, h0=None, c0=None, mask=None):
  batch_size = K.get_shape(X, native=True)[0]
  is_bidirectional = self.direction_mode == 'bidirectional'
  input_mode = ('skip' if self.input_mode == 'skip' or self.input_mode == 'norm'
                else 'linear')
  # ====== precompute input ====== #
  # linear or norm input mode
  if self.input_mode == 'norm':
    X = K.dot(X, self.W_in)
    # normalize all axes except the time dimension
    bn = BatchNorm(axes=(0, 1),
                   activation=K.linear,
                   gamma_init=self.gamma,
                   beta_init=self.beta,
                   mean_init=self.mean,
                   inv_std_init=self.inv_std)
    X = bn(X)
    # cudnnRNN doesn't support multiple inputs
    shapeX = K.get_shape(X, native=True)
    ndims = K.ndim(X)
    if 'rnn' in self.rnn_mode:
      N = 1
    elif self.rnn_mode == 'gru':
      N = 3
    else:
      N = 4
    newshape = [shapeX[i] for i in range(ndims - 1)] + [self.num_units, N]
    X = K.mean(K.reshape(X, newshape), axis=-1)
  # ====== hidden state ====== #
  num_layers = self.num_layers * 2 if is_bidirectional else self.num_layers
  require_shape = (num_layers, batch_size, self.num_units)
  h0 = _check_cudnn_hidden_init(h0, require_shape, self, 'h0')
  c0 = _check_cudnn_hidden_init(c0, require_shape, self, 'c0')
  # ====== parameters ====== #
  if self.params_split:
    parameters = K.concatenate([
        K.flatten(i, outdim=1)
        for i in self.parameters
        if not has_roles(i, INITIAL_STATE)
    ])
  else:
    parameters = self.params
  # ====== return CuDNN RNN ====== #
  results = K.rnn_dnn(X,
                      hidden_size=self.num_units,
                      rnn_mode=self.rnn_mode,
                      num_layers=self.num_layers,
                      parameters=parameters,
                      h0=h0,
                      c0=c0,
                      input_mode=input_mode,
                      direction_mode=self.direction_mode,
                      dropout=self.dropout,
                      name=self.name)
  if not self.return_states:
    results = results[0]  # only get the output
  return results
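# Illustrative numpy sketch (toy shapes, not library code) of the
# reshape-and-average step in the 'norm' input mode above: a projection of
# size num_units * N (N gates) is folded into a trailing (num_units, N) pair
# of axes and averaged over the gates, yielding a num_units-sized input.
import numpy as np

time_steps, batch, num_units, N = 5, 3, 8, 4  # N = 4 gates for an LSTM
X = np.random.rand(time_steps, batch, num_units * N)
X = X.reshape(time_steps, batch, num_units, N).mean(axis=-1)
print(X.shape)  # (5, 3, 8)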
def test_norm(self):
  for p in [1, 2, 'fro', np.inf]:
    for axis in [None, 0, 1, (0, 1)]:
      a = bk.norm(bk.flatten(x, 2), p=p, axis=axis, keepdims=True)
      b = bk.norm(bk.flatten(y, 2), p=p, axis=axis, keepdims=True)
      c = bk.norm(bk.flatten(z, 2), p=p, axis=axis, keepdims=True)
      assert_equal(self, (p, axis), a, b, c)
      a = bk.norm(bk.flatten(x, 2), p=p, axis=axis, keepdims=False)
      b = bk.norm(bk.flatten(y, 2), p=p, axis=axis, keepdims=False)
      c = bk.norm(bk.flatten(z, 2), p=p, axis=axis, keepdims=False)
      assert_equal(self, (p, axis), a, b, c)
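# Reference sketch in plain numpy (assumed array, not part of the test suite)
# for what the flatten-then-norm combination computes: flatten to 2-D first,
# then take the norm along the requested axis.
import numpy as np

x = np.random.rand(4, 3, 2)
flat = x.reshape(4, -1)                       # flatten to 2-D -> (4, 6)
row_l2 = np.linalg.norm(flat, ord=2, axis=1)  # one L2 norm per row -> (4,)
total = np.linalg.norm(flat)                  # Frobenius norm over all entries
print(row_l2.shape, float(total))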
def classify(self,
             inputs,
             training=False,
             qa_x: Optional[Distribution] = None) -> Distribution:
  """Return the prediction of labels"""
  # prepare x
  if isinstance(inputs, (tuple, list)):
    inputs = inputs[0]  # only support a single inputs Tensor
  h_x = self.x_to_qy(bk.flatten(inputs, n_outdim=2), training=training)
  # prepare a
  if qa_x is None:
    qa_x = self.qa_dist(self.encoder_a(inputs, training=training),
                        training=training)
  h_a = self.a_to_qy(qa_x, training=training)
  # final combination
  h_ax = h_a + h_x
  if self.batchnorm:
    h_ax = self.qy_ax_norm(h_ax, training=training)
  if 0.0 < self.dropout < 1.0:
    h_ax = self.qy_ax_drop(h_ax, training=training)
  h = self.classifier(h_ax, training=training)
  return self.labels(h, training=training)
                        name='Decoder')
# ===========================================================================
# Create statistical model
# ===========================================================================
# ====== encoder ====== #
E = f_encoder(X)
# ====== latent ====== #
q_Z_given_X = B.parse_distribution(args.zdist, E, int(args.zdim), name='Z')
# [n_sample, n_batch, zdim]
q_Z_given_X_samples = q_Z_given_X.sample(nsample)
Z = [
    q_Z_given_X.mean(),
    tf.concat([q_Z_given_X.mean(), tf.sqrt(q_Z_given_X.variance())], axis=-1),
    K.flatten(tf.transpose(q_Z_given_X_samples, perm=(1, 0, 2)), outdim=2)
]
Z_names = ["posterior mean", "statistic pooling", "all samples flatten"]
# ====== Z prior ====== #
p_Z = B.parse_distribution(dist_name=args.zprior)
# ====== decoder ====== #
D = f_decoder(q_Z_given_X_samples)
# ====== reconstruction ====== #
p_X_given_Z = B.parse_distribution(args.xdist,
                                   D,
                                   int(np.prod(input_shape[1:])),
                                   n_eventdim=1,
                                   name='W')
# [n_sample, n_batch, feat_dim]
p_X_given_Z_mean = p_X_given_Z.mean()  # [n_batch, feat_dim]
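# A small sketch (toy distribution, not the script's variables) of the three
# latent summaries built above: posterior mean, "statistic pooling" (mean and
# standard deviation concatenated), and all posterior samples flattened into
# one vector per example.
import tensorflow as tf
import tensorflow_probability as tfp

nsample, n_batch, zdim = 4, 8, 16
qz = tfp.distributions.Normal(loc=tf.zeros([n_batch, zdim]),
                              scale=tf.ones([n_batch, zdim]))
samples = qz.sample(nsample)                                 # [nsample, n_batch, zdim]
z_mean = qz.mean()                                           # [n_batch, zdim]
z_pool = tf.concat([qz.mean(), tf.sqrt(qz.variance())], -1)  # [n_batch, 2 * zdim]
z_flat = tf.reshape(tf.transpose(samples, perm=(1, 0, 2)),
                    [n_batch, -1])                           # [n_batch, nsample * zdim]
print(z_mean.shape, z_pool.shape, z_flat.shape)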
def flatten_and_test(n):
  a = bk.flatten(x, n)
  b = bk.flatten(y, n)
  c = bk.flatten(z, n)
  assert_equal(self, n, a, b, c)
def _apply(self, x):
  input_shape = K.get_shape(x)
  _validate_input_shape(input_shape)
  return K.flatten(x, outdim=self.outdim)
def _apply(self, X):
  return K.flatten(X, outdim=self.outdim)