import numpy as np
import pytest
from numpy.testing import assert_array_equal

from jax import random
from jax.example_libraries.stax import serial  # jax.experimental.stax on older JAX versions

# MaskedDense and create_mask are the masked layer and MADE-style mask builder
# from the package under test; they are assumed importable from its nn module.


@pytest.mark.parametrize('input_dim', [5, 7])  # illustrative parameter grid
def test_masked_dense(input_dim):
    hidden_dim = input_dim * 3
    output_dim_multiplier = input_dim - 4
    mask, _ = create_mask(input_dim, [hidden_dim], np.random.permutation(input_dim),
                          output_dim_multiplier)
    init_random_params, masked_dense = serial(MaskedDense(mask[0]))

    rng_key = random.PRNGKey(0)
    batch_size = 4
    input_shape = (batch_size, input_dim)
    _, init_params = init_random_params(rng_key, input_shape)
    output = masked_dense(init_params, np.random.rand(*input_shape))
    assert output.shape == (batch_size, hidden_dim)
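
# The test above relies on MaskedDense behaving like a stax Dense layer whose
# weight matrix is multiplied elementwise by a fixed binary mask of shape
# (in_features, out_features).  The layer below is a minimal sketch of that
# idea in stax conventions; `MaskedDenseSketch` is a hypothetical name for
# illustration, not the package's actual implementation.
import jax.numpy as jnp
from jax.nn.initializers import glorot_normal, normal


def MaskedDenseSketch(mask, W_init=glorot_normal(), b_init=normal()):
    out_dim = mask.shape[-1]

    def init_fun(rng, input_shape):
        k1, k2 = random.split(rng)
        # Weights share the mask's (in_features, out_features) shape.
        W = W_init(k1, mask.shape)
        b = b_init(k2, (out_dim,))
        return input_shape[:-1] + (out_dim,), (W, b)

    def apply_fun(params, inputs, **kwargs):
        W, b = params
        # Connections where mask == 0 are zeroed out before the matmul.
        return jnp.dot(inputs, W * mask) + b

    return init_fun, apply_fun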
@pytest.mark.parametrize('input_dim', [5, 7])          # illustrative parameter grids
@pytest.mark.parametrize('n_layers', [1, 3])
@pytest.mark.parametrize('output_dim_multiplier', [1, 2, 3])
def test_masks(input_dim, n_layers, output_dim_multiplier):
    hidden_dim = input_dim * 3
    hidden_dims = [hidden_dim] * n_layers
    permutation = np.random.permutation(input_dim)
    masks, mask_skip = create_mask(input_dim, hidden_dims, permutation, output_dim_multiplier)

    # Transpose to (out_features, in_features) so that rows index outputs below.
    masks = [np.transpose(m) for m in masks]
    mask_skip = np.transpose(mask_skip)

    # First, test that the hidden-layer masks are adequately connected:
    # trace backwards from each output block to work out which inputs it is
    # connected to, one (variable, parameter) pair at a time.
    _permutation = list(permutation)

    # Loop over variables
    for idx in range(input_dim):
        # The inputs allowed to influence output `idx` are exactly those that
        # precede it in the permutation.
        correct = np.array(sorted(_permutation[0:np.where(permutation == idx)[0][0]]))

        # Loop over the parameters emitted for each variable
        for jdx in range(output_dim_multiplier):
            prev_connections = set()
            # Output-to-final-hidden-layer mask (traced backwards)
            for kdx in range(masks[-1].shape[1]):
                if masks[-1][idx + jdx * input_dim, kdx]:
                    prev_connections.add(kdx)

            # Hidden-to-hidden and hidden-to-input layer masks
            for m in reversed(masks[:-1]):
                this_connections = set()
                for kdx in prev_connections:
                    for ldx in range(m.shape[1]):
                        if m[kdx, ldx]:
                            this_connections.add(ldx)
                prev_connections = this_connections

            assert_array_equal(sorted(prev_connections), correct)

            # Test the skip-connections mask
            skip_connections = set()
            for kdx in range(mask_skip.shape[1]):
                if mask_skip[idx + jdx * input_dim, kdx]:
                    skip_connections.add(kdx)
            assert_array_equal(sorted(skip_connections), correct)
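
# A complementary, hedged sketch of the same autoregressive property checked
# end to end: compose the masked layers returned by create_mask into a single
# linear network and confirm that the Jacobian entry for output block
# (idx, jdx) and input kdx is zero whenever kdx does not precede idx in the
# permutation.  `check_autoregressive_jacobian` is a hypothetical helper, and
# composing MaskedDense layers directly with serial (mirroring
# test_masked_dense above) is an assumption for illustration.
from jax import jacfwd


def check_autoregressive_jacobian(input_dim=5, n_layers=2, output_dim_multiplier=2):
    permutation = np.random.permutation(input_dim)
    masks, _ = create_mask(input_dim, [input_dim * 3] * n_layers, permutation,
                           output_dim_multiplier)

    # One masked layer per mask; omitting nonlinearities does not change sparsity.
    init_fun, apply_fun = serial(*[MaskedDense(m) for m in masks])
    _, params = init_fun(random.PRNGKey(0), (1, input_dim))

    x = np.random.rand(input_dim)
    # Jacobian of the flattened output w.r.t. the input:
    # shape (input_dim * output_dim_multiplier, input_dim).
    jac = jacfwd(lambda v: apply_fun(params, v.reshape(1, -1))[0])(x)

    position = {int(v): pos for pos, v in enumerate(permutation)}
    for idx in range(input_dim):
        for jdx in range(output_dim_multiplier):
            for kdx in range(input_dim):
                if position[kdx] >= position[idx]:
                    # Input kdx does not precede idx, so it must have no influence.
                    assert abs(jac[idx + jdx * input_dim, kdx]) < 1e-6


# Usage: check_autoregressive_jacobian() runs the check once with the defaults.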