def reshape_like(a, b):
  """Reshapes a to match the shape of b in all but the last dimension."""
  ret = tf.reshape(a, tf.concat([tf.shape(b)[:-1], tf.shape(a)[-1:]], 0))
  if not contrib_eager.in_eager_mode():
    ret.set_shape(b.get_shape().as_list()[:-1] + a.get_shape().as_list()[-1:])
  return ret
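
# Illustrative sketch (not part of the original module): how reshape_like can
# restore a flattened [batch * length, depth] tensor to the padded
# [batch, length, depth] layout of a reference tensor. The helper name
# `_example_reshape_like_usage` and the shapes below are hypothetical.
def _example_reshape_like_usage():
  targets = tf.zeros([4, 7, 1])             # [batch, length, channels]
  flat_embeddings = tf.zeros([4 * 7, 512])  # [batch * length, depth]
  # Result has shape [4, 7, 512]: all but the last dimension follow
  # `targets`, the last dimension follows `flat_embeddings`.
  return reshape_like(flat_embeddings, targets)
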
def _get_weights(model_hparams, vocab_size, hidden_dim=None): """Copied from tensor2tensor/layers/modalities.py but uses total vocab.""" if hidden_dim is None: hidden_dim = model_hparams.hidden_size num_shards = model_hparams.symbol_modality_num_shards shards = [] for i in range(num_shards): shard_size = (sum(vocab_size) // num_shards) + ( 1 if i < sum(vocab_size) % num_shards else 0) var_name = 'weights_%d' % i shards.append( tf.get_variable( var_name, [shard_size, hidden_dim], initializer=tf.random_normal_initializer(0.0, hidden_dim**-0.5))) if num_shards == 1: ret = shards[0] else: ret = tf.concat(shards, 0) # Convert ret to tensor. if not contrib_eager.in_eager_mode(): ret = common_layers.convert_gradient_to_tensor(ret) return ret
def _get_weights(model_hparams, vocab_size, hidden_dim=None): """Create or get concatenated embedding or softmax variable. Args: model_hparams: tf.HParams, model hyperparmeters. vocab_size: int, vocabulary size. hidden_dim: dim of the variable. Defaults to model_hparams.hidden_size Returns: a list of num_shards Tensors. """ if hidden_dim is None: hidden_dim = model_hparams.hidden_size num_shards = model_hparams.symbol_modality_num_shards shards = [] sparsity_technique = model_hparams.get("sparsity_technique") aux_params_shards = [] for i in range(num_shards): shard_size = (vocab_size // num_shards) + ( 1 if i < vocab_size % num_shards else 0) var_name = "weights_%d" % i weight_init_stddev = hidden_dim**-0.5 if (model_hparams.get("load_masks_from") and model_hparams.get("initial_sparsity")): # If we are loading constant masks for scratch-e or scratch-b # experiments, we optionally rescale the variance of the weight # initialization. initial_sparsity = model_hparams.get("initial_sparsity") weight_init_stddev = (hidden_dim * (1 - initial_sparsity))**-0.5 tf.logging.info("Using sparse initialization with sparsity {} for symbol " .format(initial_sparsity)) shards.append( tf.get_variable( var_name, [shard_size, hidden_dim], initializer=tf.random_normal_initializer(0.0, weight_init_stddev))) if sparsity_technique == "variational_dropout": aux_params_shards.append( tf.get_variable( var_name + "_aux", [shard_size, hidden_dim], initializer=tf.constant_initializer(value=-10.0))) elif sparsity_technique == "l0_regularization": initializer = tf.random_normal_initializer(mean=2.197, stddev=0.01) aux_params_shards.append( tf.get_variable( var_name + "_aux", [shard_size, hidden_dim], initializer=initializer)) if num_shards == 1: ret = shards[0] else: ret = tf.concat(shards, 0) if not aux_params_shards: # Convert ret to tensor. if not contrib_eager.in_eager_mode(): ret = common_layers.convert_gradient_to_tensor(ret) return ret # Handle the auxiliary parameters if num_shards == 1: aux_ret = aux_params_shards[0] else: aux_ret = tf.concat(aux_params_shards, 0) global COLLECTED_VARIABLES if not COLLECTED_VARIABLES: if sparsity_technique == "variational_dropout": tf.add_to_collection( common_sparse.VARIATIONAL_DROPOUT_PARAMETERS, (ret, aux_ret)) elif sparsity_technique == "l0_regularization": tf.add_to_collection( common_sparse.L0_REGULARIZATION_PARAMETERS, (ret, aux_ret)) COLLECTED_VARIABLES = True # Convert aux ret to tensor. if not contrib_eager.in_eager_mode(): ret = common_layers.convert_gradient_to_tensor(ret) aux_ret = common_layers.convert_gradient_to_tensor(aux_ret) return (ret, aux_ret)