def forward(self, x, z):
    """Build the WAE forward graph.

    The first ``self.input_feature_dims`` entries of ``x`` are embedded,
    encoded, perturbed with additive Gaussian noise, and decoded to get a
    reconstruction loss. The embeddings are then paired once with ``z`` and
    once with the noisy encoder output and fed to the discriminators.

    :param x: input layer; sliced to ``[0, self.input_feature_dims]`` and
        used as embedding indices
    :param z: layer reshaped here to ``[1, self.zdim]`` and used for the
        "real" discriminator pass (presumably a prior sample -- confirm
        against the caller)
    :return: tuple ``(recon_loss, d_real, d_fake, d_adv, arg_max)``
    """

    def tile_over_sequence(latent):
        # Reshape to a single row, then repeat it for every input position
        # so it can be concatenated with the embeddings along axis 1.
        row = lbann.Reshape(latent, dims=str_list([1, self.zdim]))
        return lbann.Tessellate(
            row,
            dims=str_list([self.input_feature_dims, self.zdim]),
        )

    tokens = lbann.Slice(
        x,
        slice_points=str_list([0, self.input_feature_dims]),
    )
    tokens = lbann.Identity(tokens)
    x_emb = lbann.Embedding(
        tokens,
        num_embeddings=self.dictionary_size,
        embedding_dim=self.embedding_size,
        name='emb',
        weights=self.emb_weights,
    )

    # Encoder: x -> z, then add Gaussian noise shaped like the encoding
    encoded = self.forward_encoder(x_emb)
    noise = lbann.Gaussian(
        mean=self.gmean,
        stdev=self.gstd,
        hint_layer=encoded,
    )
    z_sample = lbann.Add([encoded, noise])

    # Decoder: x, z -> prediction and reconstruction loss
    pred, arg_max = self.forward_decoder(x_emb, z_sample)
    recon_loss = self.compute_loss(tokens, pred)

    # Hack to remove blocking GPU allreduce in evaluation layer
    recon_loss = lbann.Identity(recon_loss, device='CPU')

    # "Real" pass: embeddings paired with the externally supplied z
    real_pair = lbann.Concatenation(
        [x_emb, tile_over_sequence(z)],
        axis=1,
    )
    d_real = self.discriminator0(real_pair)

    # "Fake" pass: embeddings paired with the noisy encoder output.
    # StopGradient keeps this pass from back-propagating into the encoder.
    fake_pair = lbann.Concatenation(
        [x_emb, tile_over_sequence(z_sample)],
        axis=1,
    )
    d_fake = self.discriminator0(lbann.StopGradient(fake_pair))

    # NOTE(review): the original marks this pass "freeze" -- presumably
    # discriminator1's weights are frozen elsewhere; confirm.
    d_adv = self.discriminator1(fake_pair)

    return recon_loss, d_real, d_fake, d_adv, arg_max
def forward_decoder(self, x_emb, z):
    """Decoder step, emulating x ~ G(z).

    The latent vector is tiled across the sequence and concatenated with
    the embeddings, run through ``self.decoder_rnn`` with an initial state
    derived from ``z``, and projected to ``self.dictionary_size`` channels.
    Afterwards every layer created here (except Slice/Reshape/Tessellate)
    has its ``datatype`` set to ``self.datatype``.

    :param x_emb: embeddings for the input sentence (original note:
        ``(n_batch, len(x), d_z)`` of floats)
    :param z: latent vector (original note: ``(n_batch, d_z)`` of floats);
        reshaped here to ``[1, 128]``
    :return: the channel-wise fully-connected output layer
    """
    # Equivalent of torch:
    #   z_0 = z.unsqueeze(1).repeat(1, x_emb.size(1), 1)
    #   x_input = torch.cat([x_emb, z_0], dim=-1)
    latent_row = lbann.Reshape(z, dims=str_list([1, 128]))
    latent_tiled = lbann.Tessellate(
        latent_row,
        dims=str_list([self.input_feature_dims, 128]),
    )
    rnn_input = lbann.Concatenation(x_emb, latent_tiled, axis=1)

    # Initial hidden state, repeated 3 times (presumably once per RNN
    # layer, cf. torch ``repeat(self.decoder_rnn.num_layers, 1, 1)``).
    hidden = self.decoder_lat(z)
    hidden = lbann.Reshape(hidden, dims=str_list([1, 512]))
    hidden = lbann.Tessellate(hidden, dims=str_list((3, 512)))

    rnn_out = self.decoder_rnn(rnn_input, hidden)

    # Output projection, reusing the name and weights of self.decoder_fc
    logits = lbann.ChannelwiseFullyConnected(
        rnn_out,
        output_channel_dims=self.dictionary_size,
        bias=True,
        name=f'{self.decoder_fc.name}',
        weights=self.decoder_fc.weights,
    )

    # Set datatype of layers.
    # Depth-first walk from the output back towards x_emb and z; the two
    # inputs themselves are never visited.
    untouched_types = (lbann.Slice, lbann.Reshape, lbann.Tessellate)
    graph_inputs = (x_emb, z)
    pending = [logits]
    seen = {logits}
    while pending:
        layer = pending.pop()
        if type(layer) not in untouched_types:
            layer.datatype = self.datatype
        for parent in layer.parents:
            if parent in seen or parent in graph_inputs:
                continue
            pending.append(parent)
            seen.add(parent)

    return logits
def random_projection(indices, num_projections, projection_dim):
    """Derive pseudo-random projection vectors from integer indices.

    Every index is expanded into ``projection_dim`` distinct hash inputs,
    hashed, pushed through the inverse error function, instance-normalized
    and finally scaled by ``1 / projection_dim``.

    :param indices: layer of indices; reshaped here to
        ``[num_projections, 1]``
    :param num_projections: number of indices (rows of the result)
    :param projection_dim: number of entries generated per index
    :return: layer built from ``[num_projections, projection_dim]`` tensors
    """
    grid_dims = utils.str_list([num_projections, projection_dim])

    # Expand input indices to get an index for each vector entry
    # Note: proj_indices(i) = index*projection_dim + i
    scaled_indices = lbann.WeightedSum(
        indices,
        scaling_factors=utils.str_list(projection_dim),
    )
    iota_weights = lbann.Weights(
        initializer=lbann.ValueInitializer(
            values=utils.str_list(range(projection_dim)),
        ),
        optimizer=lbann.NoOptimizer(),
    )
    iota = lbann.WeightsLayer(
        dims=utils.str_list(projection_dim),
        weights=iota_weights,
    )

    # Broadcast both terms to the full grid before adding them
    index_column = lbann.Reshape(
        scaled_indices,
        dims=utils.str_list([num_projections, 1]),
    )
    index_grid = lbann.Tessellate(index_column, dims=grid_dims)
    offset_row = lbann.Reshape(
        iota,
        dims=utils.str_list([1, projection_dim]),
    )
    offset_grid = lbann.Tessellate(offset_row, dims=grid_dims)
    proj_indices = lbann.Sum(index_grid, offset_grid)

    # Apply hash function and convert to Gaussian distribution
    hashed = lbann.UniformHash(proj_indices)
    ones = lbann.Constant(value=1, num_neurons=grid_dims)
    # eps shrinks the affine map 2*(1-eps)*h - (1-eps) slightly
    # (presumably to keep ErfInv away from +/-1 -- confirm hash range).
    eps = 0.001
    centered = lbann.WeightedSum(
        hashed,
        ones,
        scaling_factors=utils.str_list([2 * (1 - eps), -(1 - eps)]),
    )
    gaussian = lbann.ErfInv(centered)
    normalized = lbann.InstanceNorm(gaussian)

    return lbann.WeightedSum(
        normalized,
        scaling_factors=utils.str_list(1 / projection_dim),
    )
def create_position_ids_from_inputs_embeds(self, input_embeds):
    """Build sequential position ids that start just past the padding index.

    The ids ``padding_idx + 1, ..., padding_idx + sequence_length`` are
    stored in a weights layer configured with ``NoOptimizer``, reshaped to
    one row and tiled to ``self.input_shape[:-1]``.

    :param input_embeds: not read by this method
    :return: Tessellate layer holding the position ids
    """
    seq_len = self.input_shape[1]
    first_id = self.padding_idx + 1
    id_values = range(first_id, first_id + seq_len)

    # Constant table of ids: value-initialized and never optimized
    id_weights = lbann.Weights(
        initializer=lbann.ValueInitializer(values=str_list(id_values)),
        optimizer=lbann.NoOptimizer(),
    )
    ids = lbann.WeightsLayer(
        weights=id_weights,
        dims=str_list([seq_len]),
    )

    # One row of ids, repeated to cover all leading input dimensions
    ids_row = lbann.Reshape(ids, dims=str_list([1, seq_len]))
    return lbann.Tessellate(
        ids_row,
        dims=str_list(self.input_shape[:-1]),
    )
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):
    """Construct the LBANN transformer training model.

    Reads ``vocab_size``, ``sequence_length`` and ``pad_index`` from the
    enclosing scope.

    :param num_epochs: passed through to ``lbann.Model``
    :param embed_dim: embedding dimension, also used as the transformer
        hidden size
    :param num_heads: number of attention heads given to the transformer
    :param label_smoothing: weight of the uniform distribution mixed into
        the one-hot labels; smoothing is skipped unless this is > 0
    :return: ``lbann.Model`` whose objective is the cross entropy averaged
        over non-pad label tokens
    """
    num_targets = sequence_length - 1
    num_input_tokens = 2 * sequence_length - 1

    # Embedding weights (Glorot initialization)
    glorot_std = math.sqrt(2 / (embed_dim + vocab_size))
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=glorot_std),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    token_slice = lbann.Slice(
        input_,
        axis=0,
        slice_points=str_list([0, num_input_tokens]),
    )
    tokens = lbann.Identity(token_slice)
    raw_embeddings = lbann.Embedding(
        tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    scaled_embeddings = lbann.WeightedSum(
        raw_embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embedding_split = lbann.Slice(
        scaled_embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, num_input_tokens]),
    )
    encoder_input = lbann.Identity(embedding_split)
    decoder_input = lbann.Identity(embedding_split)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        num_targets,
    )

    # Reconstruct decoder input
    # The output projection reuses the embedding weights, transposed.
    logits = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    probs = lbann.ChannelwiseSoftmax(logits)
    probs_split = lbann.Slice(
        probs,
        axis=0,
        slice_points=str_list(range(sequence_length)),
    )
    preds = [lbann.Identity(probs_split) for _ in range(num_targets)]

    # Count number of non-pad tokens
    label_slice = lbann.Slice(
        input_,
        slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
    )
    label_tokens = lbann.Identity(label_slice)
    pads = lbann.Constant(value=pad_index, num_neurons=str(num_targets))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_split = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    per_token_labels = [
        lbann.Identity(label_split) for _ in range(num_targets)
    ]
    use_smoothing = label_smoothing > 0
    if use_smoothing:
        uniform_label = lbann.Constant(
            value=1 / vocab_size,
            num_neurons=str_list([1, vocab_size]),
        )
    token_losses = []
    for pred, token in zip(preds, per_token_labels):
        label = lbann.OneHot(token, size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if use_smoothing:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        token_losses.append(lbann.CrossEntropy(pred, label))
    loss = lbann.Concatenation(token_losses)

    # Average cross entropy over non-pad tokens
    # Pad positions get weight 0, the rest 1 / (number of non-pad tokens).
    tiled_count = lbann.Tessellate(num_not_pad, hint_layer=is_not_pad)
    loss_scales = lbann.Divide(is_not_pad, tiled_count)
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=[],
        callbacks=callbacks,
    )