def forward(self, image, dims, max_r):
    """Compute radial profile.

    Args:
        image (lbann.Layer): Image
        dims (tuple of int): Image dimensions (dim 0 corresponds to channel)
        max_r (int): Maximum radial distance. Pixels outside this
            distance are ignored.

    Returns:
        Layer: num_channels x max_r radial profile

    """

    # Bin spatial positions
    r, r_counts = self._find_radial_bins(dims[1:], max_r)

    # Reciprocal of bin counts
    # Note: If a count is 0, its reciprocal is 0.
    r_counts_recip = [0 if c == 0 else 1 / c for c in r_counts]

    # Get scatter indices and scaling factors
    # Note: Independent binning for each channel (dim 0)
    tile_dims = [dims[0]] + [1] * r.ndim
    inds_vals = np.tile(r, tile_dims)
    inds_vals += np.arange(0, dims[0] * max_r, max_r).reshape(tile_dims)
    inds_vals[:, r >= max_r] = -1
    inds_vals = inds_vals.flatten()
    scales_vals = r_counts_recip * dims[0]

    # Construct LBANN layer graph
    image = lbann.Reshape(image, dims=str_list([np.prod(dims)]))
    inds = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(inds_vals)),
            optimizer=lbann.NoOptimizer(),
        ),
        dims=str_list([len(inds_vals)]),
    )
    r_sums = lbann.Scatter(image, inds, dims=str_list([dims[0] * max_r]))
    scales = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(scales_vals)),
            optimizer=lbann.NoOptimizer(),
        ),
        dims=str_list([len(scales_vals)]),
    )
    r_means = lbann.Multiply(scales, r_sums)
    return lbann.Reshape(r_means, dims=str_list([dims[0], max_r]))
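# Hedged NumPy sketch (illustrative helper, not part of the module) of the
# profile the layer graph above computes, assuming _find_radial_bins assigns
# each pixel the integer distance from the image center; the exact binning
# convention may differ. Reuses the module's numpy import.
def _radial_profile_reference(image, max_r):
    """image: (channels, H, W) array -> (channels, max_r) profile."""
    c, h, w = image.shape
    yy, xx = np.indices((h, w))
    r = np.sqrt((yy - h // 2)**2 + (xx - w // 2)**2).astype(int)
    prof = np.zeros((c, max_r))
    for rad in range(max_r):
        mask = (r == rad)
        if mask.any():
            prof[:, rad] = image[:, mask].mean(axis=1)  # mean over pixels in bin
    return prof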
def forward(self, x, dims):
    """Apply fftshift.

    Args:
        x (lbann.Layer): Input tensor
        dims (tuple of int): Dimensions of x (dim 0 corresponds to channel)

    Returns:
        Layer: Output tensor

    """

    # Get gather indices by applying fftshift to tensor filled with indices
    # Note: Independent fftshift for each channel (dim 0)
    spatial_size = np.prod(dims[1:])
    spatial_inds = np.arange(spatial_size).reshape(dims[1:])
    spatial_inds = np.fft.fftshift(spatial_inds)
    channel_offsets = np.arange(0, dims[0] * spatial_size, spatial_size)
    channel_offsets = channel_offsets.reshape([-1] + [1] * spatial_inds.ndim)
    inds = np.expand_dims(spatial_inds, 0) + channel_offsets

    # Construct LBANN layer graph
    size = np.prod(dims)
    x = lbann.Reshape(x, dims=str_list([size]))
    inds = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(inds.flatten())),
            optimizer=lbann.NoOptimizer(),
        ),
        dims=str_list([size]),
    )
    y = lbann.Gather(x, inds)
    return lbann.Reshape(y, dims=str_list(dims))
def __init__(self, size, statistics_group_size=1, name=None,
             data_layout='data_parallel'):
    super().__init__()
    FcBnRelu.global_count += 1
    self.instance = 0
    self.name = (name
                 if name
                 else 'fcbnrelu{0}'.format(FcBnRelu.global_count))
    self.data_layout = data_layout
    self.fc = lbann.modules.FullyConnectedModule(
        size,
        bias=False,
        name=self.name + '_fc',
        data_layout=self.data_layout)

    # Weights for batchnorm
    scalebias_vals = [1.0] * size + [0.0] * size
    self.bn_weights = [
        lbann.Weights(name='{0}_bn_running_mean'.format(self.name),
                      initializer=lbann.ConstantInitializer(value=0.0)),
        lbann.Weights(name='{0}_bn_running_var'.format(self.name),
                      initializer=lbann.ConstantInitializer(value=1.0)),
        lbann.Weights(name='{0}_bn_scalebias'.format(self.name),
                      initializer=lbann.ValueInitializer(
                          values=' '.join([str(x) for x in scalebias_vals]))),
    ]
def _positional_encoding(self, sequence_length):
    """Positional encodings corresponding to a sequence length.

    PE(pos,2*i)   = sin( pos / 10000**(2*i/hidden_size) )
    PE(pos,2*i+1) = cos( pos / 10000**(2*i/hidden_size) )

    Encodings are memoized.

    """

    # Construct positional encoding if not in cache
    if sequence_length not in self._positional_encoding_cache:
        vals = []
        for pos in range(sequence_length):
            for i in range((self.hidden_size + 1) // 2):
                x = pos / 10000**(2 * i / self.hidden_size)
                vals.append(math.sin(x))
                vals.append(math.cos(x))
            if self.hidden_size % 2 != 0:
                vals.pop()
        weights = lbann.Weights(
            initializer=lbann.ValueInitializer(values=str_list(vals)),
            optimizer=None,
            name=f'{self.name}_positional{sequence_length}_weights',
        )
        self._positional_encoding_cache[sequence_length] = lbann.WeightsLayer(
            dims=str_list([sequence_length, self.hidden_size]),
            weights=weights,
            name=f'{self.name}_positional{sequence_length}',
        )

    # Return cached positional encoding
    return self._positional_encoding_cache[sequence_length]
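# Hedged NumPy sketch (illustrative helper, not part of the class) of the same
# interleaved sin/cos table, useful for checking the flattened `vals` built
# above; produces a (sequence_length, hidden_size) array and reuses the
# module's numpy import.
def _positional_encoding_reference(sequence_length, hidden_size):
    pe = np.zeros((sequence_length, hidden_size))
    for pos in range(sequence_length):
        for i in range((hidden_size + 1) // 2):
            x = pos / 10000**(2 * i / hidden_size)
            pe[pos, 2 * i] = np.sin(x)
            if 2 * i + 1 < hidden_size:  # odd hidden_size drops the last cos
                pe[pos, 2 * i + 1] = np.cos(x)
    return pe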
def _subsequent_mask(self, size):
    """Attention mask to prevent attending to subsequent positions.

    The (i,j) entry is -1e9 if i<j and is 0 otherwise. Masks are
    memoized.

    """

    # Construct mask if not in cache
    if size not in self._subsequent_mask_cache:
        vals = np.triu(np.full((size, size), -1e9), k=1)
        weights = lbann.Weights(
            initializer=lbann.ValueInitializer(
                values=str_list(np.nditer(vals))),
            optimizer=None,
            name=f'{self.name}_mask{size}_weights',
        )
        self._subsequent_mask_cache[size] = lbann.WeightsLayer(
            dims=str_list([size, size]),
            weights=weights,
            name=f'{self.name}_mask{size}',
        )

    # Return cached mask
    return self._subsequent_mask_cache[size]
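# Illustration: for size=3, np.triu(np.full((3, 3), -1e9), k=1) is
#   [[0., -1e9, -1e9],
#    [0.,   0., -1e9],
#    [0.,   0.,   0.]]
# so after the mask is added to the attention logits, position i can only
# attend to positions j <= i.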
def Permute(x, dims, axes=None, name="", return_dims=False):
    global _permute_cache
    key = (dims, axes)
    size = np.prod(dims)
    if key not in _permute_cache:
        # Construct gather indices
        inds = np.arange(size).reshape(dims, order="C").transpose(axes)
        inds = lbann.Weights(
            initializer=lbann.ValueInitializer(
                values=str_list(np.nditer(inds, order="C"))),
            optimizer=lbann.NoOptimizer(),
        )
        inds = lbann.WeightsLayer(dims=str_list([size]), weights=inds)
        _permute_cache[key] = inds

    # Apply transpose with gather
    inds = _permute_cache[key]
    if axes is None:
        new_dims = dims[::-1]
    else:
        new_dims = np.array(dims)[list(axes)]
    x = lbann.Reshape(x, dims=str_list([size]))
    y = lbann.Gather(x, inds)
    y = lbann.Reshape(y, dims=str_list(list(new_dims)), name=name)

    if return_dims:
        return y, tuple(new_dims)
    return y
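# Hedged NumPy check (illustrative helper, reuses the module's numpy import)
# of the gather-based transpose above: gathering a flattened tensor with the
# transposed index values reproduces np.transpose.
def _permute_reference_check(dims=(2, 3, 4), axes=(1, 0, 2)):
    x = np.arange(np.prod(dims)).reshape(dims)
    inds = np.arange(x.size).reshape(dims, order="C").transpose(axes)
    y = x.flatten()[inds.flatten(order="C")].reshape(np.array(dims)[list(axes)])
    assert np.array_equal(y, np.transpose(x, axes))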
def Cumsum(x, dims, axis=0):
    global _cumsum_cache

    if len(dims) != 2:
        raise RuntimeError("dims > 2 not tested/supported for cumsum")
    if (axis < 0) or (axis > 1):
        raise RuntimeError("Unsupported cumsum axis: {}".format(axis))

    shape = (dims[axis], dims[axis])
    if shape not in _cumsum_cache:
        tril_ones = np.tril(np.full(shape, 1, dtype=int), k=0)
        tril_ones = lbann.Weights(
            initializer=lbann.ValueInitializer(
                values=str_list(np.nditer(tril_ones, order="C"))),
            optimizer=lbann.NoOptimizer(),
        )
        tril_ones = lbann.WeightsLayer(dims=str_list(shape), weights=tril_ones)
        _cumsum_cache[shape] = tril_ones

    # Apply cumsum
    tril_ones = _cumsum_cache[shape]
    if axis == 0:
        x = lbann.MatMul(tril_ones, x)
        return x
    if axis == 1:
        x = lbann.MatMul(x, tril_ones, transpose_b=True)
        return x
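# Hedged NumPy check (illustrative helper, reuses the module's numpy import)
# of the lower-triangular matmul trick above: multiplying by tril(ones) on the
# left sums entries 0..i along axis 0, and multiplying by its transpose on the
# right does the same along axis 1.
def _cumsum_reference_check():
    x = np.arange(20, dtype=float).reshape(4, 5)
    assert np.allclose(np.tril(np.ones((4, 4))) @ x, np.cumsum(x, axis=0))
    assert np.allclose(x @ np.tril(np.ones((5, 5))).T, np.cumsum(x, axis=1))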
def random_projection(indices, num_projections, projection_dim):

    # Expand input indices to get an index for each vector entry
    # Note: proj_indices(i) = index*projection_dim + i
    proj_indices = lbann.WeightedSum(
        indices,
        scaling_factors=utils.str_list(projection_dim),
    )
    iota = lbann.WeightsLayer(
        dims=utils.str_list(projection_dim),
        weights=lbann.Weights(
            initializer=lbann.ValueInitializer(
                values=utils.str_list(range(projection_dim))),
            optimizer=lbann.NoOptimizer(),
        ),
    )
    proj_indices = lbann.Sum(
        lbann.Tessellate(
            lbann.Reshape(proj_indices,
                          dims=utils.str_list([num_projections, 1])),
            dims=utils.str_list([num_projections, projection_dim]),
        ),
        lbann.Tessellate(
            lbann.Reshape(iota, dims=utils.str_list([1, projection_dim])),
            dims=utils.str_list([num_projections, projection_dim]),
        ),
    )

    # Apply hash function and convert to Gaussian distribution
    proj = lbann.UniformHash(proj_indices)
    ones = lbann.Constant(
        value=1,
        num_neurons=utils.str_list([num_projections, projection_dim]),
    )
    eps = 0.001
    proj = lbann.ErfInv(
        lbann.WeightedSum(
            proj,
            ones,
            scaling_factors=utils.str_list([2 * (1 - eps), -(1 - eps)]),
        ))
    proj = lbann.InstanceNorm(proj)
    proj = lbann.WeightedSum(
        proj,
        scaling_factors=utils.str_list(1 / projection_dim),
    )
    return proj
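# Hedged NumPy sketch of the uniform-to-Gaussian step above (illustrative
# only; scipy.special.erfinv stands in for lbann.ErfInv, and the per-row
# normalization approximates lbann.InstanceNorm): hash outputs u in [0, 1)
# are rescaled to roughly (-1, 1) and mapped through the inverse error
# function, giving approximately Gaussian projection entries.
def _uniform_to_gaussian_reference(u, eps=0.001):
    from scipy.special import erfinv
    proj = erfinv(2 * (1 - eps) * u - (1 - eps))
    return ((proj - proj.mean(axis=1, keepdims=True))
            / proj.std(axis=1, keepdims=True))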
def create_position_ids_from_inputs_embeds(self, input_embeds):
    sequence_length = self.input_shape[1]
    position_ids = range(self.padding_idx + 1,
                         sequence_length + self.padding_idx + 1)
    position_ids = lbann.WeightsLayer(
        weights=lbann.Weights(
            initializer=lbann.ValueInitializer(
                values=str_list(position_ids)),
            optimizer=lbann.NoOptimizer(),
        ),
        dims=str_list([sequence_length]),
    )
    position_ids = lbann.Reshape(position_ids,
                                 dims=str_list([1, sequence_length]))
    position_ids = lbann.Tessellate(position_ids,
                                    dims=str_list(self.input_shape[:-1]))
    return position_ids
def mean_squared_error(
        data_dim,
        sequence_length,
        source_sequence,
        target_sequence,
        scale_decay=0.8,
):

    # Compute inner product between source and target vectors
    # Note: Inner products are computed for each (x,y) pair and a
    # weighted sum is computed. The scaling factors sum to 1 and decay
    # exponentially as x and y get further apart in the sequence.
    prods = lbann.MatMul(
        source_sequence,
        target_sequence,
        transpose_b=True,
    )
    scale_dims = (sequence_length, sequence_length)
    scales = np.zeros(scale_dims)
    for i in range(sequence_length):
        for j in range(sequence_length):
            if i != j:
                scales[i, j] = ((1 - scale_decay) / (2 * scale_decay)
                                * scale_decay**np.abs(j - i))
    scales = lbann.Weights(
        initializer=lbann.ValueInitializer(
            values=utils.str_list(np.nditer(scales))),
        optimizer=lbann.NoOptimizer(),
    )
    scales = lbann.WeightsLayer(dims=utils.str_list(scale_dims),
                                weights=scales)
    prods = lbann.MatMul(
        lbann.Reshape(prods, dims='1 -1'),
        lbann.Reshape(scales, dims='1 -1'),
        transpose_b=True,
    )
    prods = lbann.Reshape(prods, dims='1')

    # MSE(x,y) = ( norm(x)^2 + norm(y)^2 - 2*prod(x,y) ) / dim(x)
    scale = 1 / (data_dim * sequence_length)
    return lbann.WeightedSum(lbann.L2Norm2(source_sequence),
                             lbann.L2Norm2(target_sequence),
                             prods,
                             scaling_factors=utils.str_list(
                                 [scale, scale, -2 * scale]))
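# Hedged NumPy check of the note above (illustrative helper, reuses the
# module's numpy import): away from the sequence boundaries, each row of the
# scaling matrix sums to roughly 1, approaching 1 exactly as the sequence
# grows long.
def _mse_scales_reference_check(sequence_length=32, scale_decay=0.8):
    i, j = np.indices((sequence_length, sequence_length))
    scales = np.where(
        i != j,
        (1 - scale_decay) / (2 * scale_decay) * scale_decay**np.abs(j - i),
        0.0,
    )
    return scales.sum(axis=1)  # interior entries are close to 1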
def positive_samples_loss(
        sequence_length,
        encoder_embeddings,
        decoder_embeddings,
        scale_decay=0.8,
):

    # Compute similarity scores between encoder and decoder embeddings
    scores = lbann.MatMul(
        encoder_embeddings,
        decoder_embeddings,
        transpose_b=True,
    )
    scores = lbann.LogSigmoid(scores)

    # Scale similarity scores and add together
    # Note: The scaling factor decays exponentially as embeddings get
    # further apart in the sequence.
    # Note: The sum of all the scaling factors is approximately -1.
    scale_dims = (sequence_length, sequence_length)
    scales = np.zeros(scale_dims)
    for i in range(sequence_length):
        for j in range(sequence_length):
            if i != j:
                scales[i, j] = (-(1 - scale_decay)
                                / (2 * scale_decay * sequence_length)
                                * scale_decay**np.abs(j - i))
    scales = lbann.Weights(
        initializer=lbann.ValueInitializer(
            values=utils.str_list(np.nditer(scales))),
        optimizer=lbann.NoOptimizer(),
    )
    scales = lbann.WeightsLayer(dims=utils.str_list(scale_dims),
                                weights=scales)
    loss = lbann.MatMul(
        lbann.Reshape(scores, dims='1 -1'),
        lbann.Reshape(scales, dims='1 -1'),
        transpose_b=True,
    )
    loss = lbann.Reshape(loss, dims='1')
    return loss
input_ = lbann.Input()

# NumPy implementation
dims = [2, 3, 4, 7]
np_x = np.random.uniform(size=dims).astype(np.float32)
np_y = np.zeros_like(np_x)
for i in range(dims[0]):
    np_y[i] = np.fft.fftshift(np_x[i])
np_scales = np.random.uniform(size=np.prod(dims)).astype(np.float32)
np_z = np.inner(np_y.flatten(), np_scales).item()
tol = 8 * np_z * np.finfo(np.float32).eps

# LBANN implementation
lbann_x = lbann.WeightsLayer(
    weights=lbann.Weights(
        lbann.ValueInitializer(values=str_list(np_x.flatten())),
    ),
    dims=str_list(np_x.shape),
)
lbann_y = FFTShift()(lbann_x, dims)
lbann_scales = lbann.WeightsLayer(
    weights=lbann.Weights(
        lbann.ValueInitializer(values=str_list(np_scales)),
        optimizer=lbann.NoOptimizer(),
    ),
    dims=str_list(np_scales.shape),
)
lbann_z = lbann.MatMul(lbann.Reshape(lbann_y, dims=str_list([1, -1])),
                       lbann.Reshape(lbann_scales, dims=str_list([-1, 1])))

# Construct LBANN model with metric checking and gradient checking
metric = lbann.Metric(lbann_z, name='metric')
    _reader = reader.reader.add()
    _reader.name = 'synthetic'
    _reader.role = role
    _reader.num_samples = 1
    _reader.num_labels = 1
    _reader.synth_dimensions = '1'
    _reader.percent_of_data_to_use = 1.0
add_data_reader('train')
add_data_reader('test')
input_ = lbann.Input()

# Radial profile
x = lbann.WeightsLayer(
    weights=lbann.Weights(
        lbann.ValueInitializer(values=str_list(image.flatten())),
    ),
    dims=str_list(image.shape),
)
max_r = image.shape[-1] // 2
rprof = RadialProfile()(x, image.shape, max_r)
rprof_slice = lbann.Slice(rprof, slice_points=str_list([0, 1, 2, 3]))
red = lbann.Identity(rprof_slice, name='red')
green = lbann.Identity(rprof_slice, name='green')
blue = lbann.Identity(rprof_slice, name='blue')

# Construct model
callbacks = [
    lbann.CallbackDumpOutputs(layers=str_list(['red', 'green', 'blue'])),
]
model = lbann.Model(
    epochs=0,
def compute_loss(self, x, y):

    # y[:, :-1]
    y = lbann.Slice(
        y,
        axis=0,
        slice_points=str_list([0, self.input_feature_dims - 1]),
    )
    y = lbann.Identity(y)

    # x[:, 1:]
    x = lbann.Slice(
        x,
        slice_points=str_list([1, self.input_feature_dims]),
    )
    x = lbann.Identity(x)

    # Figure out entries in x to ignore
    ignore_mask = lbann.Equal(
        x,
        self.constant(self.label_to_ignore, hint_layer=x),
    )
    keep_mask = lbann.LogicalNot(ignore_mask)
    length = lbann.Reduction(keep_mask, mode='sum')
    length = lbann.Max(length, self.constant(1, [1]))

    # Convert entries in x to indices in y
    # Note: Ignored entries correspond to an index of -1.
    offsets = [
        row * self.dictionary_size
        for row in range(self.input_feature_dims - 1)
    ]
    offsets = lbann.Weights(
        initializer=lbann.ValueInitializer(values=str_list(offsets)),
        optimizer=lbann.NoOptimizer(),
    )
    offsets = lbann.WeightsLayer(
        dims=str_list([self.input_feature_dims - 1]),
        weights=offsets,
    )
    y_inds = lbann.Add(x, offsets)
    y_inds = lbann.Add(
        lbann.Multiply(keep_mask, y_inds),
        lbann.Multiply(
            ignore_mask,
            self.constant(-1, hint_layer=y_inds),
        ),
    )

    # recon_loss = F.cross_entropy(
    #     y[:, :-1].contiguous().view(-1, y.size(-1)),
    #     x[:, 1:].contiguous().view(-1),
    #     ignore_index=self.pad
    # )

    # Shift y for numerical stability
    # Note: We'd prefer to shift by y.max(-1)
    shifts = lbann.MatMul(
        lbann.Max(y, self.constant(0, hint_layer=y)),
        self.constant(
            1 / math.sqrt(self.dictionary_size),
            [self.dictionary_size, self.dictionary_size],
        ),
    )
    y = lbann.Subtract(y, shifts)

    # Compute log of softmax denominator and sum
    z = lbann.MatMul(
        lbann.Exp(y),
        self.constant(1, [self.dictionary_size, 1]),
    )
    z = lbann.Log(z)
    z = lbann.MatMul(
        lbann.Reshape(keep_mask, dims=str_list([1, -1])),
        z,
    )
    z = lbann.Reshape(z, dims=str_list([1]))

    # Compute cross entropy
    recon_loss = lbann.Gather(
        lbann.Reshape(y, dims=str_list([-1])),
        y_inds,
    )
    recon_loss = lbann.Reduction(recon_loss, mode='sum')
    recon_loss = lbann.Subtract(z, recon_loss)
    recon_loss = lbann.Divide(recon_loss, length)

    return recon_loss
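# Hedged NumPy sketch of the masked cross entropy assembled above (illustrative
# helper, not part of the class, reuses the module's numpy import): y holds
# per-position logits of shape (seq_len-1, dictionary_size), x holds the
# shifted target indices, and positions equal to label_to_ignore are dropped
# from both the sum and the normalizing length.
def _recon_loss_reference(y, x, label_to_ignore):
    keep = (x != label_to_ignore)
    y = y - y.max(axis=1, keepdims=True)      # shift logits for stability
    log_z = np.log(np.exp(y).sum(axis=1))     # log of softmax denominator
    picked = y[np.arange(len(x)), x]          # logit of each target token
    return (keep * (log_z - picked)).sum() / max(keep.sum(), 1)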