def mean_squared_error(
        data_dim,
        sequence_length,
        source_sequence,
        target_sequence,
        scale_decay=0.8,
):
    """Build LBANN layers for a distance-weighted MSE between two sequences.

    Inner products are taken between every (source, target) position
    pair and combined with fixed, non-trainable weights. The weight
    for positions ``i`` and ``j`` is
    ``(1 - scale_decay) / (2 * scale_decay) * scale_decay**|j - i|``
    for ``i != j`` and zero on the diagonal, so it decays
    geometrically as the two positions get further apart.

    The result follows
    ``MSE(x,y) = ( norm(x)^2 + norm(y)^2 - 2*prod(x,y) ) / dim(x)``
    with ``dim(x) = data_dim * sequence_length``.

    Args:
        data_dim: Used only in the final normalization factor.
        sequence_length: Side length of the square weight matrix and
            second term of the normalization factor.
        source_sequence: LBANN layer holding the source vectors
            (presumably shaped ``sequence_length x data_dim`` --
            confirm against the caller).
        target_sequence: LBANN layer holding the target vectors
            (presumably the same shape as ``source_sequence``).
        scale_decay: Geometric decay rate of the pair weights.

    Returns:
        An ``lbann.WeightedSum`` layer combining both squared norms
        and the weighted inner-product term.
    """

    # Inner products for every (source, target) position pair
    pair_prods = lbann.MatMul(
        source_sequence,
        target_sequence,
        transpose_b=True,
    )

    # Constant weight matrix; the diagonal stays at zero
    scale_dims = (sequence_length, sequence_length)
    scale_values = np.zeros(scale_dims)
    for row in range(sequence_length):
        for col in range(sequence_length):
            if row == col:
                continue
            scale_values[row, col] = ((1 - scale_decay) / (2 * scale_decay) *
                                      scale_decay**np.abs(col - row))

    # Freeze the weights so that they are never trained
    scale_weights = lbann.Weights(
        initializer=lbann.ValueInitializer(
            values=utils.str_list(np.nditer(scale_values))),
        optimizer=lbann.NoOptimizer(),
    )
    scale_layer = lbann.WeightsLayer(
        dims=utils.str_list(scale_dims),
        weights=scale_weights,
    )

    # Weighted sum of the pair products, computed as a dot product of
    # the two flattened matrices
    flat_prods = lbann.Reshape(pair_prods, dims='1 -1')
    flat_scales = lbann.Reshape(scale_layer, dims='1 -1')
    weighted_prod = lbann.MatMul(flat_prods, flat_scales, transpose_b=True)
    weighted_prod = lbann.Reshape(weighted_prod, dims='1')

    # MSE(x,y) = ( norm(x)^2 + norm(y)^2 - 2*prod(x,y) ) / dim(x)
    scale = 1 / (data_dim * sequence_length)
    source_norm2 = lbann.L2Norm2(source_sequence)
    target_norm2 = lbann.L2Norm2(target_sequence)
    return lbann.WeightedSum(
        source_norm2,
        target_norm2,
        weighted_prod,
        scaling_factors=utils.str_list([scale, scale, -2 * scale]),
    )
def forward(self, _):
    """Sum the L2 norms of the axis-0 slices of this module's weights.

    The weights are exposed as a ``self.width x self.height`` layer,
    sliced along axis 0 at every integer from 0 to ``self.width``, and
    the square root of ``L2Norm2`` of each slice is summed.

    Args:
        _: Ignored.

    Returns:
        An ``lbann.Sum`` layer over the ``self.width`` per-slice norms.
    """
    # str.format needs ``{}`` placeholders; '%d %d'.format(...) returns
    # the literal text '%d %d' and silently drops both dimensions.
    weights_layer = lbann.WeightsLayer(
        weights=self.weights,
        dims='{} {}'.format(self.width, self.height),
    )

    # str.join only accepts strings, so each slice point is converted
    # explicitly (joining a bare range raises TypeError).
    slice_points = ' '.join(str(point) for point in range(self.width + 1))
    sliced = lbann.Slice(weights_layer, axis=0, slice_points=slice_points)

    # The same Slice layer is passed once per output.
    # NOTE(review): presumably each child receives the next slice output
    # -- confirm against LBANN's Slice semantics.
    cols = [lbann.Sqrt(lbann.L2Norm2(sliced)) for _ in range(self.width)]
    return lbann.Sum(cols)
def construct_jag_wae_model(ydim, zdim, mcf, useCNN, dump_models,
                            ltfb_batch_interval, num_epochs):
    """Construct LBANN model.

    JAG Wasserstein autoencoder model

    Args:
        ydim: Size of the first input slice; the input is sliced at
            ``[0, ydim, ydim + 5]``.
        zdim: Latent space dimension. Passed to ``MACCWAE`` and used as
            the size of the Gaussian latent sample.
        mcf: Forwarded to ``MACCWAE`` as ``cf``.
        useCNN: Forwarded to ``MACCWAE`` as ``use_CNN``.
        dump_models: Directory given to ``CallbackSaveModel``.
        ltfb_batch_interval: Batch interval for the LTFB callback; the
            callback is only added when this is positive.
        num_epochs: Number of epochs given to ``lbann.Model``.

    Returns:
        The assembled ``lbann.Model``.
    """

    # Layer graph
    input_layer = lbann.Input(data_field='samples', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    inp_slice = lbann.Slice(input_layer,
                            axis=0,
                            slice_points=str_list([0, ydim, ydim + 5]),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    # param not used, but the layer still consumes the second slice output
    gt_x = lbann.Identity(inp_slice, name='gt_x')

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    # The latent sample follows zdim instead of a hard-coded "20" so it
    # stays the same size as the latent dimension given to MACCWAE.
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims=str(zdim))
    model = macc_network_architectures.MACCWAE(zdim,
                                               ydim,
                                               cf=mcf,
                                               use_CNN=useCNN)
    d1_real, d1_fake, d_adv, pred_y = model(z, gt_y)

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one],
                                                name='d_adv_bce')
    img_loss = lbann.MeanSquaredError([pred_y, gt_y])
    rec_error = lbann.L2Norm2(
        lbann.WeightedSum([pred_y, gt_y], scaling_factors="1 -1"))

    layers = list(lbann.traverse_layer_graph(input_layer))

    # Setup objective function
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        if (layer.weights and "disc0" in layer.name
                and "instance1" in layer.name):
            src_layers.append(layer.name)
        # freeze weights in disc1; they are overwritten from the disc0
        # layers by CallbackReplaceWeights below
        if (layer.weights and "disc1" in layer.name):
            dst_layers.append(layer.name)
            for frozen in layer.weights:
                frozen.optimizer = lbann.NoOptimizer()
        # NOTE(review): nesting reconstructed from a whitespace-mangled
        # source; every layer's weights are collected -- confirm.
        weights.update(layer.weights)
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    d_adv_bce = lbann.LayerTerm(d_adv_bce, scale=0.01)
    obj = lbann.ObjectiveFunction(
        [d1_real_bce, d1_fake_bce, d_adv_bce, img_loss, rec_error, l2_reg])

    # Initialize check metric callback
    metrics = [lbann.Metric(img_loss, name='recon_error')]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackPrintModelDescription(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackReplaceWeights(
            source_layers=list2str(src_layers),
            destination_layers=list2str(dst_layers),
            batch_interval=2),
    ]

    if ltfb_batch_interval > 0:
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                               metric='recon_error',
                               low_score_wins=True,
                               exchange_hyperparameters=True))

    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
def construct_model():
    """Build the JAG Wasserstein autoencoder (WAE) model for LBANN.

    All sizes are fixed inside the function: the input is sliced at
    ``"0 16399 16404"``, the latent sample has 20 entries, and the
    model runs for 100 epochs.

    Returns:
        The assembled ``lbann.Model``.
    """
    import lbann

    # Layer graph
    data = lbann.Input(target_mode='N/A', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    inp_slice = lbann.Slice(
        data,
        axis=0,
        slice_points="0 16399 16404",
        name='inp_slice',
    )
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')  # param not used

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    y_dim = 16399  # image+scalar shape
    z_dim = 20  # Latent space dim

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    wae = jag_models.WAE(z_dim, y_dim)
    d1_real, d1_fake, d_adv, pred_y = wae(z, gt_y)

    # Discriminator losses
    d1_real_bce = lbann.SigmoidBinaryCrossEntropy(
        [d1_real, one], name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy(
        [d1_fake, zero], name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy(
        [d_adv, one], name='d_adv_bce')

    # Reconstruction losses
    img_loss = lbann.MeanSquaredError([pred_y, gt_y])
    residual = lbann.WeightedSum([pred_y, gt_y], scaling_factors="1 -1")
    rec_error = lbann.L2Norm2(residual)

    layers = list(lbann.traverse_layer_graph(data))

    # Setup objective function
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        has_weights = bool(layer.weights)
        if has_weights and "disc0" in layer.name and "instance1" in layer.name:
            src_layers.append(layer.name)
        # Weights of layers matching "disc1" are frozen
        if has_weights and "disc1" in layer.name:
            dst_layers.append(layer.name)
            for frozen in layer.weights:
                frozen.optimizer = lbann.NoOptimizer()
        # NOTE(review): nesting reconstructed from a whitespace-mangled
        # source; every layer's weights are collected -- confirm.
        weights.update(layer.weights)

    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    d_adv_bce = lbann.LayerTerm(d_adv_bce, scale=0.01)
    objective_terms = [
        d1_real_bce,
        d1_fake_bce,
        d_adv_bce,
        img_loss,
        rec_error,
        l2_reg,
    ]
    obj = lbann.ObjectiveFunction(objective_terms)

    # Initialize check metric callback
    metrics = [lbann.Metric(img_loss, name='recon_error')]

    replace_weights = lbann.CallbackReplaceWeights(
        source_layers=list2str(src_layers),
        destination_layers=list2str(dst_layers),
        batch_interval=2,
    )
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer(), replace_weights]

    # Construct model
    num_epochs = 100
    return lbann.Model(
        num_epochs,
        weights=weights,
        layers=layers,
        metrics=metrics,
        objective_function=obj,
        callbacks=callbacks,
    )