def forward_encoder(self, x_emb):
    """Encoder step, emulating z ~ E(x) = q_E(z|x)

    :param x_emb: (n_batch, len(x), d_z) of floats, embeddings for
        input sentence x
    :return: (n_batch, d_z) of floats, sample of latent vector z
    :return: float, kl term component of loss
    """

    # _, h = self.encoder_rnn(x, None)
    h = self.encoder_rnn(x_emb, None)
    h = lbann.Slice(
        h,
        slice_points=str_list(
            [self.input_feature_dims - 1, self.input_feature_dims]),
        axis=0,
    )
    h = lbann.Identity(h)

    mu, logvar = self.q_mu(h), self.q_logvar(h)

    # Set datatype of previous layers
    # Note: Depth-first search from mu and logvar to x_emb
    stack = [mu, logvar]
    in_stack = {l: True for l in stack}
    while stack:
        l = stack.pop()
        if type(l) not in (lbann.Slice, lbann.Reshape, lbann.Tessellate):
            l.datatype = self.datatype
        for parent in l.parents:
            if parent not in in_stack and parent is not x_emb:
                stack.append(parent)
                in_stack[parent] = True

    # eps = torch.randn_like(mu)
    eps = lbann.Gaussian(mean=0, stdev=1, hint_layer=mu)

    # z = mu + (logvar / 2).exp() * eps
    z = lbann.Add([
        mu,
        (lbann.Multiply([
            lbann.Exp(lbann.WeightedSum(logvar, scaling_factors='0.5')),
            eps
        ]))
    ])

    # kl_loss = 0.5 * (logvar.exp() + mu ** 2 - 1 - logvar).sum(1).mean()
    kl_loss = lbann.Reduction(
        lbann.WeightedSum(
            lbann.Exp(logvar),
            lbann.Square(mu),
            self.constant(1, hint_layer=mu),
            logvar,
            scaling_factors='0.5 0.5 -0.5 -0.5',
        ),
        mode='sum',
    )

    return z, kl_loss
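
# Reference sketch, not part of the LBANN model above: the layer graph in
# forward_encoder emulates the usual reparameterization trick and the
# analytic KL term for a diagonal-Gaussian posterior (see the commented-out
# PyTorch lines). A minimal NumPy check, assuming mu and logvar are
# (n_batch, d_z) float arrays; the function name is just for illustration.
import numpy as np

def reparameterize_and_kl(mu, logvar):
    eps = np.random.standard_normal(mu.shape)     # eps ~ N(0, 1)
    z = mu + np.exp(0.5 * logvar) * eps           # z = mu + sd * eps
    # KL(q(z|x) || N(0, I)), summed over latent dimensions
    kl = 0.5 * np.sum(np.exp(logvar) + mu**2 - 1.0 - logvar, axis=-1)
    return z, kl
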
def forward(self, motif_size, motif_log_embeddings):
    """Predict whether a motif is real.

    @todo Numerically accurate computation of both log(D) and log(1-D).
    """

    # D = 1 - exp(-sum_j(prod_i(d_ij)))
    # log(1-D) = -sum_j(exp(sum_i(log(d_ij))))
    x = lbann.MatMul(
        lbann.Constant(value=1, num_neurons=str_list([1, motif_size])),
        motif_log_embeddings,
    )
    x = lbann.Exp(x)
    x = lbann.Reduction(x, mode='sum')
    x = lbann.Negative(x)
    log_not_prob = x

    # Convert log-probability to linear space
    # Note: D=-expm1(x) is accurate when D~0. When D~1, prefer
    # 1-D=exp(x).
    prob = lbann.Negative(lbann.Expm1(log_not_prob))

    return prob, log_not_prob
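
# Reference sketch, not part of the LBANN model above: the discriminator
# graph computes log(1-D) = -sum_j exp(sum_i log d_ij) and then
# D = -expm1(log(1-D)). Assumes motif_log_embeddings is an (motif_size, n)
# NumPy array of log(d_ij), rows indexed by motif vertex i and columns by j;
# the function name is just for illustration.
import numpy as np

def discriminator_prob(motif_log_embeddings):
    log_prod = motif_log_embeddings.sum(axis=0)   # sum_i log(d_ij), per column j
    log_not_prob = -np.exp(log_prod).sum()        # log(1 - D)
    prob = -np.expm1(log_not_prob)                # D, accurate when D ~ 0
    return prob, log_not_prob
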
def compute_loss(self, x, y):
    """Cross entropy between reconstruction logits y and target indices x.

    Targets are shifted by one position and entries equal to
    `self.label_to_ignore` are masked out; the loss is averaged over
    the remaining positions.
    """

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Convert indices in x to one-hot representation
    # Note: Ignored indices result in zero vectors
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))
    x = lbann.Add(
        lbann.Multiply(keep_mask, x),
        lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)),
    )
    x = lbann.Slice(x, slice_points=str_list(range(self.input_feature_dims)))
    x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)]
    x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x]
    x = [
        lbann.Reshape(xi, dims=str_list([1, self.dictionary_size]))
        for xi in x
    ]
    x = lbann.Concatenation(x, axis=0)

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )
    # Note: Ideally we'd shift y by y.max(-1) for numerical stability
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    recon_loss = lbann.MatMul(
        lbann.Reshape(y, dims=str_list([1, -1])),
        lbann.Reshape(x, dims=str_list([1, -1])),
        transpose_b=True,
    )
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
    recon_loss = lbann.Divide(recon_loss, length)

    return recon_loss
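
# Reference sketch, not part of the LBANN model above: the one-hot graph in
# compute_loss emulates a padding-masked cross entropy, i.e. the masked sum of
# log-partition terms minus the dot product of logits with one-hot targets,
# divided by the number of kept positions. Assumes y is a (seq_len, dict_size)
# float array of logits and x a (seq_len,) integer array of targets; the shift
# here uses the per-row max, which the note in compute_loss says would be
# preferred. The function name is just for illustration.
import numpy as np

def onehot_cross_entropy(y, x, label_to_ignore):
    keep = x != label_to_ignore
    length = max(keep.sum(), 1)
    y = y - y.max(axis=-1, keepdims=True)               # stability shift
    onehot = np.zeros_like(y)
    onehot[np.arange(len(x))[keep], x[keep]] = 1.0      # ignored rows stay zero
    z = (np.log(np.exp(y).sum(axis=-1)) * keep).sum()   # masked log-partition
    picked = (y * onehot).sum()                         # sum of target logits
    return (z - picked) / length
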
encode3neuron = lbann.Relu(encode3, name="encode3neuron")

# Latent space
mu = lbann.FullyConnected(encode3neuron,
                          name="mu",
                          num_neurons=30,
                          has_bias=True)
logsd = lbann.FullyConnected(encode3,
                             name="logsd",
                             num_neurons=30,
                             has_bias=True)

# KL divergence
# KL(N(mu, sd^2) || N(0, 1)) = 0.5*mu^2 + 0.5*sd^2 - log(sd) - 0.5 per dimension
sd = lbann.Exp(logsd, name="sd")
var = lbann.Square(sd, name="var")
meansq = lbann.Square(mu, name="meansq")
kldiv_plus_half = lbann.WeightedSum([meansq, var, logsd],
                                    name="kldiv_plus_half",
                                    scaling_factors='0.5 0.5 -1')
# Subtract the constant 0.5 per latent dimension to get the full KL term
half = lbann.Constant(value=0.5, num_neurons="30", name="half")
kldiv_full = lbann.Subtract(kldiv_plus_half, half, name="kldiv_full")
kldiv = lbann.Reduction(kldiv_full, name="kldiv", mode="sum")

# Generate sample
noise = lbann.Gaussian(name="noise", mean=0, stdev=1, hint_layer=mu)
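
# Reference sketch, not part of the LBANN graph above: closed-form KL
# divergence between N(mu, sd^2) and N(0, 1), matching the kldiv computation
# (0.5*mu^2 + 0.5*sd^2 - log(sd) - 0.5 per latent dimension, summed). Assumes
# mu and logsd are NumPy arrays; the function name is just for illustration.
import numpy as np

def gaussian_kl(mu, logsd):
    sd = np.exp(logsd)
    return np.sum(0.5 * mu**2 + 0.5 * sd**2 - logsd - 0.5)
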
def compute_loss(self, x, y):
    """Cross entropy between reconstruction logits y and target indices x.

    Targets are shifted by one position and padding entries
    (`self.label_to_ignore`) are masked out; target logits are picked
    out with a Gather layer on flattened indices.
    """

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Figure out entries in x to ignore
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))

    # Convert entries in x to indices in y
    # Note: Ignored entries correspond to an index of -1
    offsets = [
        row * self.dictionary_size
        for row in range(self.input_feature_dims - 1)
    ]
    offsets = lbann.Weights(
        initializer=lbann.ValueInitializer(values=str_list(offsets)),
        optimizer=lbann.NoOptimizer(),
    )
    offsets = lbann.WeightsLayer(
        dims=str_list([self.input_feature_dims - 1]),
        weights=offsets,
    )
    y_inds = lbann.Add(x, offsets)
    y_inds = lbann.Add(
        lbann.Multiply(keep_mask, y_inds),
        lbann.Multiply(
            ignore_mask,
            self.constant(-1, hint_layer=y_inds),
        ),
    )

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )

    # Shift y for numerical stability
    # Note: We'd prefer to shift by y.max(-1)
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)

    # Compute log of softmax denominator and sum over kept positions
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    z = lbann.Reshape(z, dims=str_list([1]))

    # Compute cross entropy
    recon_loss = lbann.Gather(
        lbann.Reshape(y, dims=str_list([-1])),
        y_inds,
    )
    recon_loss = lbann.Reduction(recon_loss, mode='sum')
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Divide(recon_loss, length)

    return recon_loss
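
# Reference sketch, not part of the LBANN model above: the Gather-based
# variant picks the target logit for each row via a flat index
# row*dict_size + x[row] instead of a one-hot matrix product (ignored rows
# are mapped to -1, which presumably gathers as zero and drops out of the
# sum). Assumes y is a (seq_len, dict_size) float array of logits and x a
# (seq_len,) integer array of targets; the function name is just for
# illustration.
import numpy as np

def gather_cross_entropy(y, x, label_to_ignore):
    keep = x != label_to_ignore
    length = max(keep.sum(), 1)
    y = y - y.max(axis=-1, keepdims=True)                    # stability shift
    z = (np.log(np.exp(y).sum(axis=-1)) * keep).sum()        # masked log-partition
    flat_inds = np.arange(len(x)) * y.shape[-1] + np.where(keep, x, 0)
    picked = (y.reshape(-1)[flat_inds] * keep).sum()         # gathered target logits
    return (z - picked) / length
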