def forward(self, inputs):
    if len(inputs) != 2:
        raise ValueError('expected two inputs: predictions and labels')
    pred = inputs[0]
    label = inputs[1]
    ones = lbann.Constant(hint_layer=pred, value=1.0)
    # term1 = label * log(1 - pred), term2 = log(pred)
    term1 = lbann.Multiply(
        [label, lbann.Log(lbann.Subtract([ones, pred]))])
    term2 = lbann.Log(pred)
    # full = -term1 - term2
    full = lbann.WeightedSum([term1, term2],
                             scaling_factors='-1.0 -1.0')
    return lbann.Reduction(full)
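# Reference check (illustrative, not part of the model): the layer graph
# above evaluates, per sample,
#     loss = -sum_i( label_i * log(1 - pred_i) + log(pred_i) ),
# with lbann.Reduction performing the sum. A minimal NumPy sketch of the
# same arithmetic, useful for unit-testing the graph; the function name is
# hypothetical.
import numpy as np

def _reference_loss(pred, label):
    """NumPy analogue of the forward() graph above."""
    term1 = label * np.log(1.0 - pred)   # Multiply([label, Log(1-pred)])
    term2 = np.log(pred)                 # Log(pred)
    return -np.sum(term1 + term2)        # WeightedSum(-1,-1) + Reduction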
def forward(self, inputs):
    if len(inputs) != 2:
        raise ValueError('expected two inputs: predictions and labels')
    pred = inputs[0]
    label = inputs[1]   # Assumed to be Boolean
    masked_pred = lbann.Multiply([pred, label])
    pred_sum = lbann.Reduction(masked_pred)
    return lbann.Negative(lbann.Log(pred_sum))
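# Reference check (illustrative): with a one-hot Boolean label, the graph
# above reduces to categorical cross entropy,
#     loss = -log( sum_i pred_i * label_i ) = -log(pred[true_class]).
# A NumPy sketch under that assumption; the function name is hypothetical.
import numpy as np

def _reference_categorical_xent(pred, label):
    """NumPy analogue of forward(); label is a one-hot 0/1 vector."""
    return -np.log(np.sum(pred * label))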
def forward(self, inputs):
    raise NotImplementedError  # Requires log-gamma function
    if len(inputs) != 2:
        raise ValueError('expected two inputs: predictions and labels')
    pred = inputs[0]
    label = inputs[1]
    ones = lbann.Constant(hint_layer=pred, value=1.0)
    term1 = pred
    term2 = lbann.Multiply([label, lbann.Log(pred)])
    term3 = lbann.LogGamma(lbann.Add([label, ones]))
    full = lbann.WeightedSum([term1, term2, term3],
                             scaling_factors='1.0 -1.0 1.0')
    return lbann.Reduction(full)
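# Reference check (illustrative): once a log-gamma layer is available, the
# graph above would compute the Poisson negative log-likelihood,
#     loss = sum_i( pred_i - label_i*log(pred_i) + log(label_i!) ),
# since log(label!) = lgamma(label + 1). A NumPy/SciPy sketch, assuming
# SciPy is available; the function name is hypothetical.
import numpy as np
from scipy.special import gammaln

def _reference_poisson_nll(pred, label):
    """NumPy analogue of the (unimplemented) forward() graph above."""
    return np.sum(pred - label * np.log(pred) + gammaln(label + 1.0))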
def forward(
    self,
    motif_indices,
    motif_size,
    walk_indices,
    walk_length,
):

    # Apply generator
    fake_motif_indices, gen_prob, gen_log_prob = self.generator(
        walk_length,
        walk_indices,
        motif_size,
    )

    # Get discriminator embeddings in log-space
    all_motif_indices = lbann.Concatenation(motif_indices, fake_motif_indices)
    all_motif_log_embeddings = self.discriminator.get_log_embeddings(
        all_motif_indices)
    all_motif_log_embeddings = lbann.Slice(
        all_motif_log_embeddings,
        slice_points=str_list([0, motif_size, 2 * motif_size]),
    )
    real_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)
    fake_motif_log_embeddings = lbann.Identity(all_motif_log_embeddings)

    # Apply discriminator
    real_disc_prob, real_disc_log_not_prob \
        = self.discriminator(motif_size, real_motif_log_embeddings)
    fake_disc_prob, fake_disc_log_not_prob \
        = self.discriminator(motif_size, fake_motif_log_embeddings)

    # Loss function
    #   L_disc = -log(D(real)) - log(1-D(fake))
    #   L_gen = -log(G) * stop_gradient(log(1-D(fake)))
    real_disc_log_prob \
        = lbann.Log(lbann.Clamp(real_disc_prob, min=1e-37, max=1))
    disc_loss = lbann.WeightedSum(
        real_disc_log_prob,
        fake_disc_log_not_prob,
        scaling_factors=str_list([-1, -1]),
    )
    gen_loss = lbann.Multiply(
        gen_log_prob,
        lbann.StopGradient(fake_disc_log_not_prob),
    )
    loss = lbann.Add(disc_loss, gen_loss)

    return loss, real_disc_prob, fake_disc_prob, gen_prob
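# Reference check (illustrative, not LBANN API): per sample, the code above
# assembles
#     disc_loss = -log(clamp(D(real))) - log(1-D(fake))
#     gen_loss  = log(G) * log(1-D(fake))   # second factor gradient-stopped
# A NumPy sketch of the same forward arithmetic; the function name is
# hypothetical.
import numpy as np

def _reference_motif_gan_loss(real_disc_prob, fake_disc_log_not_prob,
                              gen_log_prob):
    """NumPy analogue of the loss assembled in forward() above."""
    eps = 1e-37  # mirrors lbann.Clamp(min=1e-37, max=1)
    disc_loss = -np.log(np.clip(real_disc_prob, eps, 1.0)) \
                - fake_disc_log_not_prob
    # StopGradient only affects backprop; the forward value is a product
    gen_loss = gen_log_prob * fake_disc_log_not_prob
    return disc_loss + gen_loss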
def construct_model(num_epochs, mcr, spectral_loss, save_batch_interval):
    """Construct LBANN model."""
    import lbann

    # Layer graph
    input = lbann.Input(target_mode='N/A', name='inp_img')

    ### Create expected labels for real and fake data (with label flipping = 0.01)
    prob_flip = 0.01
    label_flip_rand = lbann.Uniform(min=0, max=1, neuron_dims='1')
    label_flip_prob = lbann.Constant(value=prob_flip, num_neurons='1')
    ones = lbann.GreaterEqual(label_flip_rand, label_flip_prob, name='is_real')
    zeros = lbann.LogicalNot(ones, name='is_fake')
    gen_ones = lbann.Constant(value=1.0, num_neurons='1')  ## All ones, no flip: input for training the generator

    #==============================================
    ### Implement GAN
    ## Create the noise vector
    z = lbann.Reshape(lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="64",
                                     name='noise_vec'), dims='1 64')
    ## Create the GAN object and run the forward pass for both networks
    d1_real, d1_fake, d_adv, gen_img, img = ExaGAN.CosmoGAN(mcr)(input, z, mcr)

    #==============================================
    ### Compute quantities to add to the loss and metrics
    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, ones], name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zeros], name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, gen_ones], name='d_adv_bce')
    #img_loss = lbann.MeanSquaredError([gen_img, img])
    #l1_loss = lbann.L1Norm(lbann.WeightedSum([gen_img, img], scaling_factors="1 -1"))

    #==============================================
    ### Set up source and destination layers
    layers = list(lbann.traverse_layer_graph(input))
    weights = set()
    src_layers, dst_layers = [], []
    for l in layers:
        if l.weights and "disc1" in l.name and "instance1" in l.name:
            src_layers.append(l.name)
        # Freeze weights in disc2, analogous to discrim.trainable=False in Keras
        if l.weights and "disc2" in l.name:
            dst_layers.append(l.name)
            for idx in range(len(l.weights)):
                l.weights[idx].optimizer = lbann.NoOptimizer()
        weights.update(l.weights)
    #l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)

    #==============================================
    ### Define loss and metrics
    # Define loss (objective function)
    loss_list = [d1_real_bce, d1_fake_bce, d_adv_bce]  ## Usual GAN loss function
    # loss_list = [d1_real_bce, d1_fake_bce]  ## Skip adversarial loss for G when testing spectral loss
    if spectral_loss:
        dft_gen_img = lbann.DFTAbs(gen_img)
        dft_img = lbann.StopGradient(lbann.DFTAbs(img))
        spec_loss = lbann.Log(lbann.MeanSquaredError(dft_gen_img, dft_img))
        loss_list.append(lbann.LayerTerm(spec_loss, scale=8.0))
    loss = lbann.ObjectiveFunction(loss_list)

    # Define metrics
    metrics = [lbann.Metric(d1_real_bce, name='d_real'),
               lbann.Metric(d1_fake_bce, name='d_fake'),
               lbann.Metric(d_adv_bce, name='gen_adv')]
    if spectral_loss:
        metrics.append(lbann.Metric(spec_loss, name='spec_loss'))

    #==============================================
    ### Define callbacks list
    callbacks_list = []
    dump_outputs = True
    save_model = False
    print_model = False

    callbacks_list.append(lbann.CallbackPrint())
    callbacks_list.append(lbann.CallbackTimer())
    callbacks_list.append(lbann.CallbackReplaceWeights(
        source_layers=list2str(src_layers),
        destination_layers=list2str(dst_layers),
        batch_interval=1))
    if dump_outputs:
        #callbacks_list.append(lbann.CallbackDumpOutputs(layers='inp_img gen_img_instance1_activation', execution_modes='train validation', directory='dump_outs', batch_interval=save_batch_interval, format='npy'))
        callbacks_list.append(lbann.CallbackDumpOutputs(
            layers='gen_img_instance1_activation',
            execution_modes='train validation',
            directory='dump_outs',
            batch_interval=save_batch_interval,
            format='npy'))
    if save_model:
        callbacks_list.append(lbann.CallbackSaveModel(dir='models'))
    if print_model:
        callbacks_list.append(lbann.CallbackPrintModelDescription())

    ### Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=loss,
                       callbacks=callbacks_list)
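# Usage sketch (hypothetical driver code, not part of the application): the
# returned model is typically paired with a trainer, optimizer, and data
# reader before launching a run. All argument values below are placeholders,
# and the data reader is application-specific, so it is left elided.
if __name__ == '__main__':
    import lbann
    import lbann.contrib.launcher

    model = construct_model(num_epochs=10, mcr=False,
                            spectral_loss=True, save_batch_interval=100)
    opt = lbann.Adam(learn_rate=0.0002, beta1=0.5, beta2=0.99, eps=1e-8)
    trainer = lbann.Trainer(mini_batch_size=64)
    # data_reader = ...  # application-specific protobuf data reader
    # lbann.contrib.launcher.run(trainer, model, data_reader, opt, nodes=1)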
def compute_loss(self, x, y):

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Convert indices in x to one-hot representation
    # Note: Ignored indices result in zero vectors
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))
    x = lbann.Add(
        lbann.Multiply(keep_mask, x),
        lbann.Multiply(ignore_mask, self.constant(-1, hint_layer=x)),
    )
    x = lbann.Slice(x, slice_points=str_list(range(self.input_feature_dims)))
    x = [lbann.Identity(x) for _ in range(self.input_feature_dims - 1)]
    x = [lbann.OneHot(xi, size=self.dictionary_size) for xi in x]
    x = [
        lbann.Reshape(xi, dims=str_list([1, self.dictionary_size]))
        for xi in x
    ]
    x = lbann.Concatenation(x, axis=0)

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )

    # Shift y for numerical stability
    # Note: Ideally we'd shift by y.max(-1)
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)

    # Compute log of softmax denominator and sum over kept entries
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )

    # Compute cross entropy
    recon_loss = lbann.MatMul(
        lbann.Reshape(y, dims=str_list([1, -1])),
        lbann.Reshape(x, dims=str_list([1, -1])),
        transpose_b=True,
    )
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Reshape(recon_loss, dims=str_list([1]))
    recon_loss = lbann.Divide(recon_loss, length)

    return recon_loss
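# Reference check (illustrative, not LBANN API): per sample, the graph above
# computes a length-normalized masked cross entropy,
#     loss = ( sum_t keep_t * log(sum_k exp(y_tk)) - sum_t y[t, x_t] ) / length,
# using one-hot(x) and matrix products in place of indexing. A NumPy sketch
# under those assumptions; the function name is hypothetical.
import numpy as np

def _reference_masked_xent(y, x, ignore_index):
    """y: (T, V) logits; x: (T,) integer labels, ignore_index marks padding."""
    T, V = y.shape
    keep = (x != ignore_index)
    length = max(keep.sum(), 1)
    # Per-row shift, mirroring MatMul(Max(y, 0), ones(V, V)/sqrt(V))
    y = y - np.maximum(y, 0.0).sum(axis=1, keepdims=True) / np.sqrt(V)
    log_z = np.log(np.exp(y).sum(axis=1))  # log softmax denominators
    picked = y[np.arange(T), np.where(keep, x, 0)] * keep  # y[t, x_t], 0 if padded
    return (log_z[keep].sum() - picked.sum()) / length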
def compute_loss(self, x, y):

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Figure out entries in x to ignore
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))

    # Convert entries in x to indices in y
    # Note: Ignored entries correspond to an index of -1
    offsets = [
        row * self.dictionary_size
        for row in range(self.input_feature_dims - 1)
    ]
    offsets = lbann.Weights(
        initializer=lbann.ValueInitializer(values=str_list(offsets)),
        optimizer=lbann.NoOptimizer(),
    )
    offsets = lbann.WeightsLayer(
        dims=str_list([self.input_feature_dims - 1]),
        weights=offsets,
    )
    y_inds = lbann.Add(x, offsets)
    y_inds = lbann.Add(
        lbann.Multiply(keep_mask, y_inds),
        lbann.Multiply(
            ignore_mask,
            self.constant(-1, hint_layer=y_inds),
        ),
    )

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )

    # Shift y for numerical stability
    # Note: We'd prefer to shift by y.max(-1)
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)

    # Compute log of softmax denominator and sum
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    z = lbann.Reshape(z, dims=str_list([1]))

    # Compute cross entropy
    recon_loss = lbann.Gather(
        lbann.Reshape(y, dims=str_list([-1])),
        y_inds,
    )
    recon_loss = lbann.Reduction(recon_loss, mode='sum')
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Divide(recon_loss, length)

    return recon_loss
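# Reference check (illustrative, not LBANN API): this Gather-based variant
# picks y[t, x_t] by flattening y to length T*V and gathering at index
# t*V + x_t, with ignored entries marked by index -1 so they do not
# contribute. NumPy would wrap negative indices rather than drop them, so
# this sketch masks explicitly; the function name is hypothetical.
import numpy as np

def _reference_gather_pick(y, x, ignore_index):
    """y: (T, V) logits; x: (T,) integer labels. Returns sum_t y[t, x_t]."""
    T, V = y.shape
    keep = (x != ignore_index)
    flat_inds = np.arange(T) * V + x          # row offsets t*V plus x_t
    flat_inds = np.where(keep, flat_inds, 0)  # dummy index for ignored rows
    return (y.reshape(-1)[flat_inds] * keep).sum()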