def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
                                  max_number_of_boxes):
    """Batches the groundtruth stored on `detection_model` for evaluation.

    Each per-image groundtruth list exposed by the model is stacked along a
    new leading batch axis and stored under the matching
    `fields.InputDataFields` key. Boxes, classes and the box count are always
    emitted; every other entry is added only when the model reports that it
    holds the corresponding field.

    Args:
      detection_model: A `DetectionModel` object.
      class_agnostic: Whether the detections are class agnostic. When true the
        one-hot class encoding is replaced by a single all-ones column, so
        every box gets class id 1 after the label offset is applied.
      max_number_of_boxes: Max number of groundtruth boxes. It is reported
        verbatim as the box count of every image in the batch.

    Returns:
      A dictionary keyed by `fields.InputDataFields` attributes (shapes as
      stated by the original contract of this function):
        groundtruth_boxes: [batch_size, num_boxes, 4] float32 tensor of boxes,
          in normalized coordinates.
        groundtruth_classes: [batch_size, num_boxes] int64 tensor of 1-indexed
          classes.
        groundtruth_instance_masks: 4D float32 tensor of instance masks (if
          provided in groundtruth).
        groundtruth_is_crowd: [batch_size, num_boxes] bool tensor indicating
          is_crowd annotations (if provided in groundtruth).
        groundtruth_area: [batch_size, num_boxes] float32 tensor indicating
          the area (in the original absolute coordinates) of annotations (if
          provided in groundtruth).
        groundtruth_keypoints: [batch_size, num_boxes, num_keypoints, 2]
          float32 tensor of keypoints (if provided in groundtruth).
        groundtruth_keypoint_depths / groundtruth_keypoint_depth_weights:
          stacked keypoint depths and their weights; both are added together
          whenever the model holds keypoint depths.
        groundtruth_keypoint_visibilities: stacked keypoint visibilities (if
          provided in groundtruth).
        groundtruth_group_of: [batch_size, num_boxes] bool tensor indicating
          group_of annotations (if provided in groundtruth).
        groundtruth_verified_neg_classes: [batch_size, num_classes] float32
          K-hot representation of 1-indexed classes which were verified as not
          present in the image (if provided in groundtruth).
        groundtruth_not_exhaustive_classes: [batch_size, num_classes] K-hot
          representation of 1-indexed classes which don't have all of their
          instances marked exhaustively (if provided in groundtruth).
        groundtruth_dp_num_points: [batch_size, num_boxes] int32 tensor with
          the number of DensePose points for each instance (if provided).
        groundtruth_dp_part_ids: [batch_size, num_boxes, max_sampled_points]
          int32 tensor with the part ids for each DensePose sampled point (if
          provided).
        groundtruth_dp_surface_coords: [batch_size, num_boxes,
          max_sampled_points, 4] tensor with the DensePose surface coordinates
          for each sampled point (if provided).
        groundtruth_track_ids: [batch_size, num_boxes] int32 tensor with track
          ID for each instance (if provided).
        groundtruth_labeled_classes: [batch_size, num_classes] int64 tensor of
          1-indexed classes (if provided).
        num_groundtruth_boxes: [batch_size] tensor holding
          `max_number_of_boxes` for every image.
    """
    data_fields = fields.InputDataFields()
    box_fields = fields.BoxListFields

    def _has(field_name):
        """Reports whether the model holds groundtruth for `field_name`."""
        return detection_model.groundtruth_has_field(field_name)

    def _batched(field_name):
        """Stacks the per-image tensors of `field_name` along a new axis 0."""
        return tf.stack(detection_model.groundtruth_lists(field_name))

    boxes = _batched(box_fields.boxes)
    boxes_shape = tf.shape(boxes)

    # For class-agnostic models, groundtruth one-hot encodings collapse to all
    # ones.
    if class_agnostic:
        one_hot_classes = tf.ones([boxes_shape[0], boxes_shape[1], 1])
    else:
        one_hot_classes = _batched(box_fields.classes)
    label_id_offset = 1  # Applying label id offset (b/63711816)
    class_ids = tf.argmax(one_hot_classes, axis=2) + label_id_offset

    groundtruth = {}
    groundtruth[data_fields.groundtruth_boxes] = boxes
    groundtruth[data_fields.groundtruth_classes] = class_ids

    if _has(box_fields.masks):
        groundtruth[data_fields.groundtruth_instance_masks] = _batched(
            box_fields.masks)

    if _has(box_fields.is_crowd):
        groundtruth[data_fields.groundtruth_is_crowd] = _batched(
            box_fields.is_crowd)

    if _has(data_fields.groundtruth_area):
        groundtruth[data_fields.groundtruth_area] = _batched(
            data_fields.groundtruth_area)

    if _has(box_fields.keypoints):
        groundtruth[data_fields.groundtruth_keypoints] = _batched(
            box_fields.keypoints)

    # Depth weights are not probed separately: they are copied whenever the
    # depths themselves are present.
    if _has(box_fields.keypoint_depths):
        groundtruth[data_fields.groundtruth_keypoint_depths] = _batched(
            box_fields.keypoint_depths)
        groundtruth[data_fields.groundtruth_keypoint_depth_weights] = _batched(
            box_fields.keypoint_depth_weights)

    if _has(box_fields.keypoint_visibilities):
        groundtruth[data_fields.groundtruth_keypoint_visibilities] = _batched(
            box_fields.keypoint_visibilities)

    if _has(box_fields.group_of):
        groundtruth[data_fields.groundtruth_group_of] = _batched(
            box_fields.group_of)

    # Prepends one zero column on axis 1 of the image-level class vectors
    # (the same offset of one that is applied to the box classes above).
    label_id_offset_paddings = tf.constant([[0, 0], [1, 0]])

    def _batched_with_offset(field_name):
        """Stacks `field_name` and applies the label-offset padding."""
        return tf.pad(_batched(field_name), label_id_offset_paddings)

    if _has(data_fields.groundtruth_verified_neg_classes):
        groundtruth[data_fields.groundtruth_verified_neg_classes] = (
            _batched_with_offset(data_fields.groundtruth_verified_neg_classes))

    if _has(data_fields.groundtruth_not_exhaustive_classes):
        groundtruth[data_fields.groundtruth_not_exhaustive_classes] = (
            _batched_with_offset(
                data_fields.groundtruth_not_exhaustive_classes))

    if _has(box_fields.densepose_num_points):
        groundtruth[data_fields.groundtruth_dp_num_points] = _batched(
            box_fields.densepose_num_points)

    if _has(box_fields.densepose_part_ids):
        groundtruth[data_fields.groundtruth_dp_part_ids] = _batched(
            box_fields.densepose_part_ids)

    if _has(box_fields.densepose_surface_coords):
        groundtruth[data_fields.groundtruth_dp_surface_coords] = _batched(
            box_fields.densepose_surface_coords)

    if _has(box_fields.track_ids):
        groundtruth[data_fields.groundtruth_track_ids] = _batched(
            box_fields.track_ids)

    if _has(data_fields.groundtruth_labeled_classes):
        groundtruth[data_fields.groundtruth_labeled_classes] = (
            _batched_with_offset(data_fields.groundtruth_labeled_classes))

    # Every image reports the same (maximum) number of boxes.
    groundtruth[data_fields.num_groundtruth_boxes] = tf.tile(
        [max_number_of_boxes], multiples=[boxes_shape[0]])
    return groundtruth
def parser(record):
    """Parses one serialized tfrecord into permutation-LM training features.

    Reads `seq_len`, `reuse_len`, `perm_size`, `num_predict` and
    `use_bfloat16` from the enclosing scope. The sequence is split into a
    reused prefix of `reuse_len` tokens and a fresh suffix; `_local_perm` is
    applied to each part separately and the results are stitched back
    together.

    Args:
      record: A serialized example holding int64 features "input", "target",
        "seg_id" and "is_masked" of length `seq_len`, and "label" of length 1.

    Returns:
      The parsed example dict. "input", "target" (raw) and "is_masked" are
      consumed and replaced by "perm_mask", "input_k", "input_q", "target"
      and "target_mask" (plus "target_mapping" when `num_predict` is set);
      "seg_id" and "label" are kept. The dict is passed through
      `_convert_example` before being returned.
    """
    feature_spec = {
        "input": tf.FixedLenFeature([seq_len], tf.int64),
        "target": tf.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
        "label": tf.FixedLenFeature([1], tf.int64),
        "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
    }

    # Decode the serialized example.
    example = tf.parse_single_example(serialized=record,
                                      features=feature_spec)

    inputs = example.pop("input")
    target = example.pop("target")
    is_masked = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= non_reuse_len

    # Permute the reused prefix ...
    (head_perm_mask, head_target, head_target_mask,
     head_input_k, head_input_q) = _local_perm(
         inputs[:reuse_len],
         target[:reuse_len],
         is_masked[:reuse_len],
         perm_size,
         reuse_len)

    # ... and the fresh suffix independently.
    (tail_perm_mask, tail_target, tail_target_mask,
     tail_input_k, tail_input_q) = _local_perm(
         inputs[reuse_len:],
         target[reuse_len:],
         is_masked[reuse_len:],
         perm_size,
         non_reuse_len)

    # Assemble the full [seq_len, seq_len] mask: prefix rows are filled with
    # ones over the suffix columns, suffix rows with zeros over the prefix
    # columns.
    head_perm_mask = tf.concat(
        [head_perm_mask, tf.ones([reuse_len, non_reuse_len])], axis=1)
    tail_perm_mask = tf.concat(
        [tf.zeros([non_reuse_len, reuse_len]), tail_perm_mask], axis=1)
    perm_mask = tf.concat([head_perm_mask, tail_perm_mask], axis=0)
    target = tf.concat([head_target, tail_target], axis=0)
    target_mask = tf.concat([head_target_mask, tail_target_mask], axis=0)
    input_k = tf.concat([head_input_k, tail_input_k], axis=0)
    input_q = tf.concat([head_input_q, tail_input_q], axis=0)

    if num_predict is None:
        example["target"] = tf.reshape(target, [seq_len])
        example["target_mask"] = tf.reshape(target_mask, [seq_len])
    else:
        positions = tf.range(seq_len, dtype=tf.int64)
        is_target = tf.cast(target_mask, tf.bool)
        positions = tf.boolean_mask(positions, is_target)

        # Extra padding due to CLS/SEP introduced after prepro: fewer than
        # `num_predict` positions may actually be targets.
        actual_num_predict = tf.shape(positions)[0]
        pad_len = num_predict - actual_num_predict

        # target_mapping: one one-hot row per predicted position, zero rows
        # for the padding.
        target_mapping = tf.one_hot(positions, seq_len, dtype=tf.float32)
        mapping_pad = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
        target_mapping = tf.concat([target_mapping, mapping_pad], axis=0)
        example["target_mapping"] = tf.reshape(target_mapping,
                                               [num_predict, seq_len])

        # target: keep only the predicted tokens, zero-padded.
        target = tf.boolean_mask(target, is_target)
        target_pad = tf.zeros([pad_len], dtype=target.dtype)
        target = tf.concat([target, target_pad], axis=0)
        example["target"] = tf.reshape(target, [num_predict])

        # target mask: ones for real predictions, zeros for the padding.
        target_mask = tf.concat(
            [tf.ones([actual_num_predict], dtype=tf.float32),
             tf.zeros([pad_len], dtype=tf.float32)],
            axis=0)
        example["target_mask"] = tf.reshape(target_mask, [num_predict])

    # Reshape back to fixed (static) shapes.
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    _convert_example(example, use_bfloat16)

    for feature_name, feature in example.items():
        logging.info("%s: %s", feature_name, feature)

    return example
def testIntegratedGradientAttribution(self):
    """Checks that the attributions add up to logits minus baseline logits."""
    # Due to complexity of the indicator we cannot easily extend this test to
    # > 1 lab test.
    observations = tf.constant([[[10000.0], [15000.0], [2.0]],
                                [[0.0], [100.0], [2000.0]]])
    # Baseline to compare against: a linear interpolation between the second
    # to last and the last value of the test.
    baseline_observations = tf.constant([[[10000.0], [15000.0], [15000.0]],
                                         [[0.0], [100.0], [100.0]]])
    # All attributions must be selected for the consistency check to hold.
    indicator = tf.ones(shape=[2, 3, 1], dtype=tf.float32)
    delta_time = tf.constant(
        [[[1000], [999], [2]], [[1001], [500], [20]]], dtype=tf.float32)
    # Selected so that the attribution is only over the third time step in
    # both batch entries.
    attribution_max_delta_time = 100
    num_classes = 1
    diff_delta_time = tf.constant(
        [[[1000], [1], [997]], [[1001], [501], [480]]], dtype=tf.float32)
    # Full-length sequences, so that no time step is lost in the attribution.
    sequence_length = tf.constant([3, 3])

    # TODO(milah): Not clear why this test doesn't work for the RNN.
    def construct_logits_fn(unused_diff_delta_time, obs_values,
                            unused_indicator, unused_sequence_length,
                            unused_seq_mask, unused_hparams, reuse):
        projected = tf.layers.dense(
            obs_values, num_classes, name='test1', reuse=reuse,
            activation=None)
        first_step_scale = tf.expand_dims(obs_values[:, 0, :], axis=1) + 0.5
        return projected * first_step_scale, None

    # The first call creates the weights; later calls reuse them.
    logits, _ = construct_logits_fn(
        diff_delta_time, observations, indicator, sequence_length,
        None, None, False)
    # To verify the correctness of the attribution we compute the prediction
    # at the baseline observations.
    base_logits, _ = construct_logits_fn(
        diff_delta_time, baseline_observations, indicator, sequence_length,
        None, None, True)

    # Set high for increased precision of the approximation.
    num_steps = 100
    hparams = contrib_training.HParams(
        sequence_prediction=True,
        use_rnn_attention=False,
        path_integrated_gradients_num_steps=num_steps,
        attribution_max_delta_time=attribution_max_delta_time)
    gradients = osm.compute_path_integrated_gradient_attribution(
        observations, indicator, diff_delta_time, delta_time,
        sequence_length, None, hparams, construct_logits_fn)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        actual_logits, actual_base_logits, actual_gradients = [
            sess.run(fetch) for fetch in (logits, base_logits, gradients)]
        self.assertAllClose(
            actual_logits - actual_base_logits, actual_gradients, atol=0.001)
def __init__(self,
             bert_config,
             is_training,
             input_ids,
             input_mask=None,
             token_type_ids=None,
             use_one_hot_embeddings=True,
             scope=None,
             embedding_size=None,
             input_embeddings=None,
             input_reprs=None,
             update_embeddings=True,
             untied_embeddings=False,
             ltr=False,
             rtl=False):
    """Constructor for BertModel.

    Args:
      bert_config: `BertConfig` instance. It is deep-copied, so the caller's
        object is never modified.
      is_training: bool. true for training model, false for eval model.
        Controls whether dropout will be applied (both dropout probabilities
        are forced to 0.0 when false).
      input_ids: int32 Tensor of shape [batch_size, seq_length]. Only used for
        the embedding lookup, i.e. when neither `input_embeddings` nor
        `input_reprs` is given.
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
        Defaults to all ones.
      token_type_ids: int32 Tensor of shape [batch_size, seq_length]. Required
        despite the `None` default: the batch size and sequence length are
        read from it.
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings. On the
        TPU, it is much faster if this is True, on the CPU or GPU, it is
        faster if this is False.
      scope: (optional) variable scope. Defaults to "electra" for the encoder.
        NOTE: it must not be None when `untied_embeddings` is true, because
        the embeddings scope name is built by string concatenation.
      embedding_size: (optional) width of the word embeddings. Defaults to
        `bert_config.hidden_size`.
      input_embeddings: (optional) precomputed token embeddings. When given,
        the word-embedding lookup is skipped and `self.embedding_table` is
        not set.
      input_reprs: (optional) precomputed embedding output. When given, both
        the lookup and the embedding post-processing are skipped.
      update_embeddings: bool. When false, gradients are stopped at the
        embedding output.
      untied_embeddings: bool. When true the embedding variables live under
        `scope + "/embeddings"` instead of "electra/embeddings".
      ltr: bool. Restrict attention to the lower triangle (each position sees
        itself and earlier positions). Takes precedence over `rtl`.
      rtl: bool. Restrict attention to the upper triangle (each position sees
        itself and later positions).

    Raises:
      AssertionError: If `token_type_ids` is None.
      ValueError: Per the original contract, when the config is invalid or one
        of the input tensor shapes is invalid (raised by the helpers called
        here, which are not part of this block).
    """
    # Validate before first use. The shape lookup below dereferences
    # `token_type_ids`, so checking afterwards (as was done previously) could
    # never produce this error message.
    assert token_type_ids is not None, (
        "`token_type_ids` is required: the batch size and sequence length "
        "are derived from it.")

    bert_config = copy.deepcopy(bert_config)
    if not is_training:
        bert_config.hidden_dropout_prob = 0.0
        bert_config.attention_probs_dropout_prob = 0.0

    input_shape = get_shape_list(token_type_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    if input_mask is None:
        input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    if input_reprs is None:
        # Tied embeddings always live under "electra/embeddings" so that
        # several models can share them through AUTO_REUSE.
        embeddings_scope = (
            (scope if untied_embeddings else "electra") + "/embeddings")

        if input_embeddings is None:
            with tf.variable_scope(embeddings_scope, reuse=tf.AUTO_REUSE):
                # Perform embedding lookup on the word ids
                if embedding_size is None:
                    embedding_size = bert_config.hidden_size
                (self.token_embeddings, self.embedding_table) = embedding_lookup(
                    input_ids=input_ids,
                    vocab_size=bert_config.vocab_size,
                    embedding_size=embedding_size,
                    initializer_range=bert_config.initializer_range,
                    word_embedding_name="word_embeddings",
                    use_one_hot_embeddings=use_one_hot_embeddings)
        else:
            self.token_embeddings = input_embeddings

        with tf.variable_scope(embeddings_scope, reuse=tf.AUTO_REUSE):
            # Add positional embeddings and token type embeddings, then layer
            # normalize and perform dropout.
            self.embedding_output = embedding_postprocessor(
                input_tensor=self.token_embeddings,
                use_token_type=True,
                token_type_ids=token_type_ids,
                token_type_vocab_size=bert_config.type_vocab_size,
                token_type_embedding_name="token_type_embeddings",
                use_position_embeddings=True,
                position_embedding_name="position_embeddings",
                initializer_range=bert_config.initializer_range,
                max_position_embeddings=bert_config.max_position_embeddings,
                dropout_prob=bert_config.hidden_dropout_prob)
    else:
        self.embedding_output = input_reprs
    if not update_embeddings:
        self.embedding_output = tf.stop_gradient(self.embedding_output)

    with tf.variable_scope(scope, default_name="electra"):
        # Project the embeddings up/down when their width differs from the
        # transformer's hidden size.
        if self.embedding_output.shape[-1] != bert_config.hidden_size:
            self.embedding_output = tf.layers.dense(
                self.embedding_output, bert_config.hidden_size,
                name="embeddings_project")

        with tf.variable_scope("encoder"):
            # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
            # mask of shape [batch_size, seq_length, seq_length] which is used
            # for the attention scores.
            attention_mask = create_attention_mask_from_input_mask(
                token_type_ids, input_mask)

            # Add causal masking to the attention for running the transformer
            # left-to-right or right-to-left
            if ltr or rtl:
                causal_mask = tf.ones((seq_length, seq_length))
                if ltr:
                    causal_mask = tf.matrix_band_part(causal_mask, -1, 0)
                else:
                    causal_mask = tf.matrix_band_part(causal_mask, 0, -1)
                attention_mask *= tf.expand_dims(causal_mask, 0)

            # Run the stacked transformer. Output shapes
            # sequence_output: [batch_size, seq_length, hidden_size]
            # pooled_output: [batch_size, hidden_size]
            # all_encoder_layers: [n_layers, batch_size, seq_length, hidden_size].
            # attn_maps: [n_layers, batch_size, n_heads, seq_length, seq_length]
            (self.all_layer_outputs, self.attn_maps) = transformer_model(
                input_tensor=self.embedding_output,
                attention_mask=attention_mask,
                hidden_size=bert_config.hidden_size,
                num_hidden_layers=bert_config.num_hidden_layers,
                num_attention_heads=bert_config.num_attention_heads,
                intermediate_size=bert_config.intermediate_size,
                intermediate_act_fn=get_activation(bert_config.hidden_act),
                hidden_dropout_prob=bert_config.hidden_dropout_prob,
                attention_probs_dropout_prob=
                bert_config.attention_probs_dropout_prob,
                initializer_range=bert_config.initializer_range,
                do_return_all_layers=True)
            self.sequence_output = self.all_layer_outputs[-1]
            # The first position's vector of the last layer is the pooled output.
            self.pooled_output = self.sequence_output[:, 0]
def mlm_sample_text(params, x, random_documents=False):
    """Builds one masked-language-model training pair from a token sequence.

    Args:
      params: Config dict. Keys read here:
        'n_ctx' (required): length of the returned sequences.
        'mlm_mask_id' (required): id of the reserved mask token.
        'seed': op-level seed handed to every random op (default None).
        'mlm_cls_token_id': if set, this id is prepended and one token less
          is taken from `x`.
        'n_vocab': exclusive upper bound for random replacement tokens.
        'mlm_mask_ignore_ids': ids that must never be masked or used as
          random replacements (the CLS id, when set, is added to them).
        'mlm_mask_prob' (0.15): probability of selecting a token for
          prediction.
        'mlm_same_token_prob' (0.10): probability that a selected token is
          left as-is in the input instead of being replaced by the mask id.
        'mlm_random_token_prob' (0.): probability of replacing a token by a
          random one before masking.
      x: 1-D tensor of token ids.
      random_documents: If true, a random contiguous window of `x` is used;
        otherwise the leading tokens are used.

    Returns:
      A `(masked_features, labels)` tuple of int32 tensors of shape
      [n_ctx]. `labels` holds the original token at every position selected
      for prediction and 0 everywhere else.

    Raises:
      AssertionError: If 'mlm_mask_id' is missing from `params`.
    """
    seed = params.get('seed', None)
    ctx_len = params["n_ctx"]
    assert 'mlm_mask_id' in params, 'the key `mlm_mask_id` must be set on your config to do masked language model training, specifying the id of the reserved mask token'

    mask_id = params['mlm_mask_id']
    cls_token_id = params.get('mlm_cls_token_id', None)
    num_tokens = params.get('n_vocab', None)

    mask_ignore_ids = set(params.get('mlm_mask_ignore_ids', []))
    # Only a real CLS id may be ignored. Adding None would later be compared
    # against the token tensor, which TensorFlow cannot convert.
    if cls_token_id is not None:
        mask_ignore_ids.add(cls_token_id)

    mask_prob = params.get('mlm_mask_prob', 0.15)
    same_token_prob = params.get('mlm_same_token_prob', 0.10)
    random_token_prob = params.get('mlm_random_token_prob', 0.)

    # Leave room for the CLS token when one is prepended.
    seq_len = ctx_len if cls_token_id is None else (ctx_len - 1)

    # NOTE(review): the same `seed` is handed to every random op below; in
    # graph mode ops that share an op-level seed may draw identical values,
    # which would correlate the masks - confirm when a seed is configured.
    if random_documents:
        s = tf.size(x)
        # `maxval` is exclusive, so the last valid start offset is
        # `s - seq_len`; the `+ 1` also keeps the range non-empty when `x`
        # holds exactly `seq_len` tokens.
        r = tf.random.uniform([], maxval=(s - seq_len + 1),
                              dtype=tf.dtypes.int32, seed=seed)
        r1 = tf.range(r, r + seq_len)
        r1 = tf.reshape(r1, [seq_len])
        features = tf.gather(x, r1)
    else:
        features = x[:seq_len]

    # add cls token id if specified by `mlm_cls_token_id`
    if cls_token_id is not None:
        features = tf.pad(features, [[1, 0]], constant_values=cls_token_id)

    features = tf.cast(features, dtype=tf.int32)
    shape = features.shape

    # Keep the uncorrupted tokens: they are the prediction targets even when
    # the input is later overwritten by random tokens.
    original_features = features

    # determine which tokens are mask-able (id 0 is never masked)
    can_mask = tf.not_equal(features, 0)
    for ignore_id in mask_ignore_ids:
        can_mask &= tf.not_equal(features, ignore_id)

    # generate boolean mask for masking ids
    mask_mask = tf.less(
        tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32,
                          seed=seed),
        mask_prob)
    mask_mask &= can_mask

    # generate mask for actually replacing the tokens, for allowing a small
    # number of tokens to stay the same
    replace_mask = tf.less(
        tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32,
                          seed=seed),
        1 - same_token_prob)

    # randomly replace some tokens with random tokens before masking
    if random_token_prob > 0:
        random_token_mask = tf.less(
            tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32,
                              seed=seed),
            random_token_prob)
        random_tokens = tf.random.uniform(shape, minval=1, maxval=num_tokens,
                                          dtype=tf.dtypes.int32, seed=seed)

        # make sure random tokens do not include illegal token ids specified
        # by `mlm_mask_ignore_ids`
        random_can_mask = tf.not_equal(random_tokens, 0)
        for ignore_id in mask_ignore_ids:
            random_can_mask &= tf.not_equal(random_tokens, ignore_id)

        features = tf.where(random_token_mask & random_can_mask,
                            random_tokens, features)

    # mask the tokens
    mask_tokens = tf.ones(shape, dtype=tf.int32) * mask_id
    masked_features = tf.where(mask_mask & replace_mask, mask_tokens, features)

    # labels are the original tokens at the positions selected for prediction
    # and 0 for all non-masked tokens
    labels = tf.where(mask_mask, original_features,
                      tf.zeros(shape, dtype=tf.int32))

    masked_features = tf.reshape(masked_features, [ctx_len])
    labels = tf.reshape(labels, [ctx_len])
    return masked_features, labels
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Builds the classifier graph through `create_model`, optionally
    initializes variables from `init_checkpoint`, and returns the
    `TPUEstimatorSpec` matching `mode`. All configuration other than the
    arguments is read from the enclosing scope.

    Args:
      features: Dict of input tensors with keys "input_ids", "input_mask",
        "segment_ids", "label_ids" and, optionally, "is_real_example".
      labels: Unused.
      mode: A `tf.estimator.ModeKeys` value.
      params: Unused.

    Returns:
      A `TPUEstimatorSpec` for training, evaluation or prediction.
    """
    tf.logging.info("*** Features ***")
    for feature_name in sorted(features.keys()):
        tf.logging.info(" name = %s, shape = %s" %
                        (feature_name, features[feature_name].shape))

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    # Without an explicit flag every example counts as real (weight 1.0).
    if "is_real_example" in features:
        is_real_example = tf.cast(features["is_real_example"],
                                  dtype=tf.float32)
    else:
        is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    model_outputs = create_model(
        albert_config, is_training, input_ids, input_mask, segment_ids,
        label_ids, num_labels, use_one_hot_embeddings, max_seq_length,
        dropout_prob, hub_module)
    (total_loss, per_example_loss, probabilities, logits,
     predictions) = model_outputs

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
        assignment_map, initialized_variable_names = (
            modeling.get_assignment_map_from_checkpoint(tvars,
                                                        init_checkpoint))
        if use_tpu:
            # On TPU the restore has to happen inside the scaffold function.
            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = (", *INIT_FROM_CKPT*"
                       if var.name in initialized_variable_names else "")
        tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                        init_string)

    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = optimization.create_optimizer(
            total_loss, learning_rate, num_train_steps, num_warmup_steps,
            use_tpu)
        return contrib_tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            scaffold_fn=scaffold_fn)

    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(per_example_loss, label_ids, logits, is_real_example):
            """Accuracy and mean loss, both weighted by `is_real_example`."""
            predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
            accuracy = tf.metrics.accuracy(
                labels=label_ids,
                predictions=predicted_ids,
                weights=is_real_example)
            loss = tf.metrics.mean(values=per_example_loss,
                                   weights=is_real_example)
            return {
                "eval_accuracy": accuracy,
                "eval_loss": loss,
            }

        eval_metrics = (metric_fn,
                        [per_example_loss, label_ids, logits,
                         is_real_example])
        return contrib_tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)

    # Any other mode is treated as prediction.
    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        predictions={
            "probabilities": probabilities,
            "predictions": predictions
        },
        scaffold_fn=scaffold_fn)
def rnn(cell, inputs, sequence_length=None, initial_state=None, ff_keep_prob=1., recur_keep_prob=1., enforce_dropout=False, dtype=tf.float32, scope=None):
    """Runs `cell` over a batch-major sequence with a `tf.while_loop`.

    Args:
      cell: Recurrent cell. This function uses `cell.zero_state`,
        `cell.output_size`, `cell.state_size` and calls
        `cell(input_t, state)`, expecting `(output, new_state)` back.
      inputs: Batch-major input tensor, (B,T,D); it is transposed to
        time-major internally.
      sequence_length: (optional) per-example lengths, cast to int32. When
        given, steps at or beyond an example's length emit zeros and carry
        the previous state through unchanged.
      initial_state: (optional) starting state. Defaults to
        `cell.zero_state(batch_size, dtype)`.
      ff_keep_prob: Keep probability for dropout on the inputs; dropout is
        only built when it is < 1.
      recur_keep_prob: Keep probability for the multiplicative mask applied to
        the state before every cell call; only built when it is < 1.
      enforce_dropout: When not None, dropout is built with
        `tf.layers.dropout` and this value is passed as its `training`
        argument (so the default `False` disables it); when None,
        `tf.nn.dropout` is used unconditionally.
      dtype: dtype of the zero state; required when `initial_state` is None.
      scope: (optional) variable scope name. Defaults to 'RNN'.

    Returns:
      A tuple `(outputs, final_state)` where `outputs` is batch-major,
      (B,T,D), and `final_state` is the state after the last time step.

    Raises:
      ValueError: If neither `initial_state` nor `dtype` is provided.
    """
    inputs = tf.transpose(inputs, [1, 0, 2])  # (B,T,D) => (T,B,D)
    parallel_iterations = 32
    if sequence_length is not None:
        sequence_length = tf.to_int32(sequence_length)
    with tf.variable_scope(scope or 'RNN') as varscope:
        #if varscope.caching_device is None:
        #  varscope.set_caching_device(lambda op: op.device)
        # Dynamic dimensions (tensors) for building shapes at run time ...
        input_shape = tf.shape(inputs)
        time_steps, batch_size, _ = tf.unstack(input_shape, 3)
        # ... and the static ones; only the depth is used below.
        const_time_steps, const_batch_size, const_depth = inputs.get_shape(
        ).as_list()
        if initial_state is not None:
            state = initial_state
        else:
            if not dtype:
                raise ValueError(
                    'If no initial_state is provided, dtype must be.')
            state = cell.zero_state(batch_size, dtype)
        # Output emitted for steps that lie beyond a sequence's length.
        zero_output = tf.zeros(tf.stack([batch_size, cell.output_size]),
                               inputs.dtype)
        if sequence_length is not None:
            min_sequence_length = tf.reduce_min(sequence_length)
            max_sequence_length = tf.reduce_max(sequence_length)
        time = tf.constant(0, dtype=tf.int32, name='time')
        output_ta = tf.TensorArray(dtype=inputs.dtype,
                                   size=time_steps,
                                   tensor_array_name='dynamic_rnn_output')
        input_ta = tf.TensorArray(dtype=inputs.dtype,
                                  size=time_steps,
                                  tensor_array_name='dynamic_rnn_input')
        if ff_keep_prob < 1:
            # A noise shape of 1 along the time axis shares one dropout mask
            # across all time steps of an example.
            noise_shape = tf.stack([1, batch_size, const_depth])
            if enforce_dropout is not None:
                inputs = tf.layers.dropout(inputs,
                                           1 - ff_keep_prob,
                                           noise_shape=noise_shape,
                                           training=enforce_dropout)
            else:
                inputs = tf.nn.dropout(inputs,
                                       ff_keep_prob,
                                       noise_shape=noise_shape)
        if recur_keep_prob < 1:
            # The mask is built once, outside the loop, so every time step
            # reuses the same recurrent dropout mask.
            ones = tf.ones(tf.stack([batch_size, cell.output_size]))
            if enforce_dropout is not None:
                state_dropout = tf.layers.dropout(ones,
                                                  1 - recur_keep_prob,
                                                  training=enforce_dropout)
            else:
                state_dropout = tf.nn.dropout(ones, recur_keep_prob)
            # Only the last `output_size`-wide slice of the state is dropped
            # out; the leading slices are multiplied by ones.
            # NOTE(review): presumably the last slice is the hidden/output
            # part of a concatenated state - confirm against the cell used.
            state_dropout = tf.concat(
                [ones] * (cell.state_size // cell.output_size - 1) +
                [state_dropout], 1)
        else:
            state_dropout = 1
        input_ta = input_ta.unstack(inputs)

        #-----------------------------------------------------------
        def _time_step(time, state, output_ta_t):
            """Loop body: runs one step and writes its output at `time`.

            Returns the loop variables for the next iteration:
            `(time + 1, new_state, output_ta_t)`.
            """
            input_t = input_ta.read(time)

            #- - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            def _empty_update():
                # Every sequence has ended: emit zeros, keep the state.
                return zero_output, state

            #- - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            def _call_cell():
                # The recurrent dropout mask is applied to the incoming state.
                return cell(input_t, state * state_dropout)

            #- - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            def _maybe_copy_some_through():
                new_output, new_state = _call_cell()
                # Before the shortest sequence ends every example takes the
                # cell result; afterwards finished examples get zero output
                # and keep their previous state.
                return tf.cond(
                    time < min_sequence_length, lambda:
                    (new_output, new_state), lambda: (tf.where(
                        time >= sequence_length, zero_output, new_output
                    ), tf.where(time >= sequence_length, state, new_state)))

            #- - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            if sequence_length is not None:
                output, new_state = tf.cond(time >= max_sequence_length,
                                            _empty_update,
                                            _maybe_copy_some_through)
            else:
                (output, new_state) = _call_cell()
            output_ta_t = output_ta_t.write(time, output)
            return (time + 1, new_state, output_ta_t)

        #-----------------------------------------------------------
        _, final_state, output_final_ta = tf.while_loop(
            cond=lambda time, _1, _2: time < time_steps,
            body=_time_step,
            loop_vars=(time, state, output_ta),
            parallel_iterations=parallel_iterations)
        final_outputs = output_final_ta.stack()
        outputs = tf.transpose(final_outputs, [1, 0, 2])  # (T,B,D) => (B,T,D)
        return outputs, final_state
def test_glorot_regularizer(self):
    """GlorotNormalRegularizer with weight 1.0 on a (3, 5) ones matrix is 30."""
    regularizer = prior.GlorotNormalRegularizer(weight=1.0)
    all_ones = tf.ones((3, 5))
    penalty = float(regularizer(all_ones))
    self.assertAlmostEqual(
        penalty, 30.,
        msg='GlorotNormalRegularizer regularization wrong.')
def main(_):
    """Builds the evaluation graph and starts the slim evaluation loop.

    Depending on the flags, the graph contains an image summary with a grid
    of stylized images (`--style_grid`) and/or per-style streaming loss
    metrics (`--learning_curves`).

    Raises:
      ValueError: If the number of style coefficients differs from
        `FLAGS.num_styles`.
    """
    with tf.Graph().as_default():
        image_size = FLAGS.image_size
        num_styles = FLAGS.num_styles

        # Create inputs in [0, 1], as expected by vgg_16.
        inputs, _ = image_utils.imagenet_inputs(FLAGS.batch_size, image_size)
        evaluation_images = image_utils.load_evaluation_images(image_size)

        # Process style and weight flags
        if FLAGS.style_coefficients is None:
            style_coefficients = [1.0] * num_styles
        else:
            style_coefficients = ast.literal_eval(FLAGS.style_coefficients)
        if len(style_coefficients) != num_styles:
            raise ValueError(
                'number of style coefficients differs from number of styles')
        content_weights = ast.literal_eval(FLAGS.content_weights)
        style_weights = ast.literal_eval(FLAGS.style_weights)

        # Load style images.
        style_images, labels, style_gram_matrices = (
            image_utils.style_image_inputs(
                os.path.expanduser(FLAGS.style_dataset_file),
                batch_size=num_styles,
                image_size=image_size,
                square_crop=True,
                shuffle=False))
        labels = tf.unstack(labels)

        def _create_normalizer_params(style_label):
            """Creates normalizer parameters from a style label."""
            return {
                'labels': tf.expand_dims(style_label, 0),
                'num_categories': FLAGS.num_styles,
                'center': True,
                'scale': True,
            }

        def _transform(images, style_label, reuse):
            """Applies the style given by `style_label` to `images`."""
            return model.transform(
                images,
                alpha=FLAGS.alpha,
                reuse=reuse,
                normalizer_params=_create_normalizer_params(style_label))

        # Dummy call to simplify the reuse logic: it creates the variables so
        # that every later call can pass reuse=True.
        _transform(inputs, labels[0], False)

        def _style_sweep(inputs):
            """Transfers all styles onto the input one at a time."""
            inputs = tf.expand_dims(inputs, 0)
            stylized_inputs = [
                _transform(inputs, style_label, True)
                for style_label in labels
            ]
            return tf.concat([inputs] + stylized_inputs, 0)

        if FLAGS.style_grid:
            # First row: a blank (all-ones) cell followed by the style images.
            style_row = tf.concat(
                [tf.ones([1, image_size, image_size, 3]), style_images], 0)
            stylized_training_example = _style_sweep(inputs[0])
            stylized_evaluation_images = [
                _style_sweep(image)
                for image in tf.unstack(evaluation_images)
            ]
            stylized_noise = _style_sweep(
                tf.random_uniform([image_size, image_size, 3]))
            stylized_style_images = [
                _style_sweep(image) for image in tf.unstack(style_images)
            ]

            # Three fixed rows (styles, training example, noise) plus one per
            # evaluation image; with crossover, one more per style image.
            grid_rows = ([style_row, stylized_training_example, stylized_noise]
                         + stylized_evaluation_images)
            num_rows = 3 + evaluation_images.get_shape().as_list()[0]
            if FLAGS.style_crossover:
                grid_rows = grid_rows + stylized_style_images
                num_rows = num_rows + num_styles
            grid = tf.concat(grid_rows, 0)
            grid_shape = [num_rows, 1 + num_styles]

            grid_image = image_utils.form_image_grid(
                grid, grid_shape, [image_size, image_size], 3)
            tf.summary.image('Style Grid',
                             tf.cast(grid_image * 255.0, tf.uint8))

        if FLAGS.learning_curves:
            metrics = {}
            for i, label in enumerate(labels):
                # Keep only the Gram matrices of style `i` (batch of one).
                gram_matrices = {
                    key: value[i: i + 1]
                    for key, value in style_gram_matrices.items()
                }
                stylized_inputs = _transform(inputs, label, True)
                _, loss_dict = learning.total_loss(
                    inputs, stylized_inputs, gram_matrices, content_weights,
                    style_weights, reuse=i > 0)
                for key, value in loss_dict.items():
                    metric_name = '{}_style_{}'.format(key, i)
                    metrics[metric_name] = slim.metrics.streaming_mean(value)

            names_values, names_updates = slim.metrics.aggregate_metric_map(
                metrics)
            for name, value in names_values.items():
                summary_op = tf.summary.scalar(name, value, [])
                print_op = tf.Print(summary_op, [value], name)
                tf.add_to_collection(tf.GraphKeys.SUMMARIES, print_op)
            eval_op = list(names_updates.values())
            num_evals = FLAGS.num_evals
        else:
            eval_op = None
            num_evals = 1

        slim.evaluation.evaluation_loop(
            master=FLAGS.master,
            checkpoint_dir=os.path.expanduser(FLAGS.train_dir),
            logdir=os.path.expanduser(FLAGS.eval_dir),
            eval_op=eval_op,
            num_evals=num_evals,
            eval_interval_secs=FLAGS.eval_interval_secs)
def train(dset_name, s_dim, n_dim, factors, batch_size, dec_lr, enc_lr_mul,
          iterations, model_type="gen"):
    """Builds the networks for `model_type` and runs the training loop.

    Args:
        dset_name: Name passed to `datasets.get_dlib_data`. If no dataset is
            returned, the networks are still built (with a [64, 64, 1]
            observation shape) but the function returns before training.
        s_dim: Size of the latent slice the encoder predicts.
        n_dim: Number of extra (nuisance) latent dimensions; the generator
            input size is `s_dim + n_dim`.
        factors: String passed to `datasets.make_masks`; the part before "="
            must be compatible with `model_type` (checked by assertion).
        batch_size: Batch size for real and generated samples.
        dec_lr: Learning rate of the generator optimizer.
        enc_lr_mul: Multiplier applied to `dec_lr` to obtain the learning
            rate of the discriminator and encoder optimizers.
        iterations: Exclusive upper bound of the training step range.
        model_type: One of "gen", "enc" or "van"; selects which networks,
            optimizers and training step are used.
    """
    ut.log("In train")
    masks = datasets.make_masks(factors, s_dim)
    z_dim = s_dim + n_dim
    enc_lr = enc_lr_mul * dec_lr

    # Load data
    dset = datasets.get_dlib_data(dset_name)
    if dset is None:
        x_shape = [64, 64, 1]
    else:
        x_shape = dset.observation_shape

    # Discriminator targets: ones for the real half of a batch, zeros for the
    # generated half (matching the real/fake concatenation order below).
    targets_real = tf.ones((batch_size, 1))
    targets_fake = tf.zeros((batch_size, 1))
    targets = tf.concat((targets_real, targets_fake), axis=0)

    # Networks
    if model_type == "gen":
        assert factors.split("=")[0] in {"c", "s", "cs", "r"}
        y_dim = len(masks)
        dis = networks.Discriminator(x_shape, y_dim)
        gen = networks.Generator(x_shape, z_dim)
        enc = networks.Encoder(x_shape, s_dim)  # Encoder ignores nuisance param
        ut.log(dis.read(dis.WITH_VARS))
        ut.log(gen.read(gen.WITH_VARS))
        ut.log(enc.read(enc.WITH_VARS))
    elif model_type == "enc":
        assert factors.split("=")[0] in {"r"}
        enc = networks.Encoder(x_shape, s_dim)  # Encoder ignores nuisance param
        ut.log(enc.read(enc.WITH_VARS))
    elif model_type == "van":
        assert factors.split("=")[0] in {"l"}
        dis = networks.LabelDiscriminator(x_shape, s_dim)  # Uses s_dim
        gen = networks.Generator(x_shape, z_dim)
        enc = networks.Encoder(x_shape, s_dim)  # Encoder ignores nuisance param
        ut.log(dis.read(dis.WITH_VARS))
        ut.log(gen.read(gen.WITH_VARS))
        ut.log(enc.read(enc.WITH_VARS))

    # Create optimizers
    if model_type in {"gen", "van"}:
        gen_opt = tfk.optimizers.Adam(
            learning_rate=dec_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
        dis_opt = tfk.optimizers.Adam(
            learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
        enc_opt = tfk.optimizers.Adam(
            learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)
    elif model_type == "enc":
        enc_opt = tfk.optimizers.Adam(
            learning_rate=enc_lr, beta_1=0.5, beta_2=0.999, epsilon=1e-8)

    @tf.function
    def train_gen_step(x1_real, x2_real, y_real):
        """One discriminator+encoder update followed by one generator update."""
        gen.train()
        dis.train()
        enc.train()
        # Alternate discriminator step and generator step
        # The tape is persistent because two gradients (discriminator and
        # encoder) are taken from it.
        with tf.GradientTape(persistent=True) as tape:
            # Generate
            z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
            # stop_gradient: this phase must not update the generator.
            x1_fake = tf.stop_gradient(gen(z1))
            x2_fake = tf.stop_gradient(gen(z2))

            # Discriminate
            x1 = tf.concat((x1_real, x1_fake), 0)
            x2 = tf.concat((x2_real, x2_fake), 0)
            y = tf.concat((y_real, y_fake), 0)
            logits = dis(x1, x2, y)

            # Encode
            p_z = enc(x1_fake)

            dis_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits, labels=targets))
            # Encoder ignores nuisance parameters (if they exist)
            enc_loss = -tf.reduce_mean(p_z.log_prob(z1[:, :s_dim]))

        dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
        enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

        dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

        with tf.GradientTape(persistent=False) as tape:
            # Generate
            z1, z2, y_fake = datasets.paired_randn(batch_size, z_dim, masks)
            x1_fake = gen(z1)
            x2_fake = gen(z2)

            # Discriminate
            logits_fake = dis(x1_fake, x2_fake, y_fake)

            # The generator is trained to make the discriminator output
            # "real" on generated pairs.
            gen_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_fake, labels=targets_real))

        gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

        return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

    @tf.function
    def train_van_step(x_real, y_real):
        """One generator update followed by one discriminator+encoder update."""
        gen.train()
        dis.train()
        enc.train()

        # Zero-pad the labels so they can be added to a z_dim-wide latent.
        if n_dim > 0:
            padding = tf.zeros((y_real.shape[0], n_dim))
            y_real_pad = tf.concat((y_real, padding), axis=-1)
        else:
            y_real_pad = y_real

        # Alternate discriminator step and generator step
        with tf.GradientTape(persistent=False) as tape:
            # Generate
            # NOTE(review): train_gen_step unpacks three values from
            # paired_randn; here the result is used as a single tensor --
            # confirm the return type for the masks used with "van".
            z_fake = datasets.paired_randn(batch_size, z_dim, masks)
            z_fake = z_fake + y_real_pad
            x_fake = gen(z_fake)

            # Discriminate
            logits_fake = dis(x_fake, y_real)

            gen_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_fake, labels=targets_real))

        gen_grads = tape.gradient(gen_loss, gen.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

        with tf.GradientTape(persistent=True) as tape:
            # Generate
            z_fake = datasets.paired_randn(batch_size, z_dim, masks)
            z_fake = z_fake + y_real_pad
            x_fake = tf.stop_gradient(gen(z_fake))

            # Discriminate
            x = tf.concat((x_real, x_fake), 0)
            y = tf.concat((y_real, y_real), 0)
            logits = dis(x, y)

            # Encode
            p_z = enc(x_fake)

            dis_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits, labels=targets))
            # Encoder ignores nuisance parameters (if they exist)
            enc_loss = -tf.reduce_mean(p_z.log_prob(z_fake[:, :s_dim]))

        dis_grads = tape.gradient(dis_loss, dis.trainable_variables)
        enc_grads = tape.gradient(enc_loss, enc.trainable_variables)

        dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))

        return dict(gen_loss=gen_loss, dis_loss=dis_loss, enc_loss=enc_loss)

    @tf.function
    def train_enc_step(x1_real, x2_real, y_real):
        """One encoder-only update; logits are masked code differences."""
        with tf.GradientTape() as tape:
            z1 = enc(x1_real).mean()
            z2 = enc(x2_real).mean()
            logits = tf.gather(z1 - z2, masks, axis=-1)
            loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits, labels=y_real))

        enc_grads = tape.gradient(loss, enc.trainable_variables)
        enc_opt.apply_gradients(zip(enc_grads, enc.trainable_variables))
        # gen/dis losses are reported as 0 so the log format string still works.
        return dict(gen_loss=0, dis_loss=0, enc_loss=loss)

    @tf.function
    def gen_eval(z):
        gen.eval()
        return gen(z)

    @tf.function
    def enc_eval(x):
        enc.eval()
        return enc(x).mean()

    enc_np = lambda x: enc_eval(x).numpy()

    # Initial preparation
    if FLAGS.debug:
        iter_log = 100
        iter_save = 2000
        train_range = range(iterations)
        basedir = FLAGS.basedir
        vizdir = FLAGS.basedir
        ckptdir = FLAGS.basedir
        new_run = True
    else:
        iter_log = 5000
        iter_save = 50000
        iter_metric = iter_save * 5  # Make sure this is a factor of 500k
        basedir = os.path.join(FLAGS.basedir, "exp")
        ckptdir = os.path.join(basedir, "ckptdir")
        vizdir = os.path.join(basedir, "vizdir")
        gfile.MakeDirs(basedir)
        gfile.MakeDirs(ckptdir)
        gfile.MakeDirs(vizdir)
        # train_range will be specified below

    ckpt_prefix = os.path.join(ckptdir, "model")
    if model_type in {"gen", "van"}:
        ckpt_root = tf.train.Checkpoint(dis=dis, dis_opt=dis_opt,
                                        gen=gen, gen_opt=gen_opt,
                                        enc=enc, enc_opt=enc_opt)
    elif model_type == "enc":
        ckpt_root = tf.train.Checkpoint(enc=enc, enc_opt=enc_opt)

    # Check if we're resuming training if not in debugging mode
    if not FLAGS.debug:
        latest_ckpt = tf.train.latest_checkpoint(ckptdir)
        if latest_ckpt is None:
            new_run = True
            ut.log("Starting a completely new model")
            train_range = range(iterations)
        else:
            new_run = False
            ut.log("Restarting from {}".format(latest_ckpt))
            ckpt_root.restore(latest_ckpt)
            # A new run saves once at step 0 and then every iter_save steps,
            # so the save counter minus one gives the number of completed
            # iter_save intervals.
            resuming_iteration = iter_save * (int(ckpt_root.save_counter) - 1)
            train_range = range(resuming_iteration, iterations)

    # Training
    if dset is None:
        ut.log("Dataset {} is not available".format(dset_name))
        ut.log("Ending program having checked that the networks can be built.")
        return

    batches = datasets.paired_data_generator(
        dset, masks).repeat().batch(batch_size).prefetch(1000)
    batches = iter(batches)
    start_time = time.time()
    train_time = 0

    if FLAGS.debug:
        train_range = tqdm(train_range)

    for global_step in train_range:
        stopwatch = time.time()
        if model_type == "gen":
            x1, x2, y = next(batches)
            vals = train_gen_step(x1, x2, y)
        elif model_type == "enc":
            x1, x2, y = next(batches)
            vals = train_enc_step(x1, x2, y)
        elif model_type == "van":
            x, y = next(batches)
            vals = train_van_step(x, y)
        # train_time only counts data fetching and the train step, not the
        # logging/evaluation/saving below.
        train_time += time.time() - stopwatch

        # Generic bookkeeping
        if (global_step + 1) % iter_log == 0 or global_step == 0:
            elapsed_time = time.time() - start_time
            string = ", ".join((
                "Iter: {:07d}, Elapsed: {:.3e}, (Elapsed) Iter/s: {:.3e}, (Train Step) Iter/s: {:.3e}"
                .format(global_step, elapsed_time, global_step / elapsed_time,
                        global_step / train_time),
                "Gen: {gen_loss:.4f}, Dis: {dis_loss:.4f}, Enc: {enc_loss:.4f}"
                .format(**vals))) + "."
            ut.log(string)

        # Log visualizations and evaluations
        if (global_step + 1) % iter_save == 0 or global_step == 0:
            if model_type == "gen":
                viz.ablation_visualization(x1, x2, gen_eval, z_dim, vizdir,
                                           global_step + 1)
            elif model_type == "van":
                viz.ablation_visualization(x, x, gen_eval, z_dim, vizdir,
                                           global_step + 1)

            if FLAGS.debug:
                evaluate.evaluate_enc(enc_np, dset, s_dim, FLAGS.gin_file,
                                      FLAGS.gin_bindings,
                                      pida_sample_size=1000,
                                      dlib_metrics=FLAGS.debug_dlib_metrics)
            else:
                # The dlib metrics are only computed every iter_metric steps.
                dlib_metrics = (global_step + 1) % iter_metric == 0
                evaluate.evaluate_enc(enc_np, dset, s_dim, FLAGS.gin_file,
                                      FLAGS.gin_bindings,
                                      pida_sample_size=10000,
                                      dlib_metrics=dlib_metrics)

        # Save model
        if (global_step + 1) % iter_save == 0 or (global_step == 0 and new_run):
            # Save model only after ensuring all measurements are taken.
            # This ensures that restarts always computes the evals
            ut.log("Saved to", ckpt_root.save(ckpt_prefix))
def test_he_regularizer(self):
    """HeNormalRegularizer(weight=1.0) on a 3x5 tensor of ones yields 11.25."""
    regularizer = prior.HeNormalRegularizer(weight=1.0)
    all_ones = tf.ones((3, 5))
    penalty = float(regularizer(all_ones))
    self.assertAlmostEqual(
        penalty, 11.25, msg='HeNormalRegularizer regularization wrong.')
def _create_cross_entropy_action_tensors(self,
                                         num_samples=200,
                                         top_k_portion=0.5):
    """Create tensorflow operations for cross_entropy max_actions.

    Creates the placeholders `self._dynamic_batch_size`,
    `self._action_init_tensor` and `self._tolerance_tensor`, and stores the
    result of the `tf.while_loop` search in `self.cost_optimizer`. Each loop
    iteration samples actions around the current mean, scores them with
    `self._build_q_function_net`, keeps the top-k per batch element and
    refits the mean and spread to those.

    Args:
        num_samples: Number of action samples drawn per batch element in
            each iteration.
        top_k_portion: Fraction of `num_samples` kept as the elite set.
    """
    top_k_num = int(top_k_portion * num_samples)

    self._dynamic_batch_size = tf.placeholder(
        dtype=tf.int32, name="dynamic_batch_size")
    self._action_init_tensor = tf.placeholder(
        dtype=tf.float32,
        name="action_init_tensor",
        shape=(None, self.action_dim))
    self._tolerance_tensor = tf.placeholder(
        dtype=tf.float32, name="tolerance_tensor", shape=())

    # Initial loop state: mean at the supplied actions, unit spread, an
    # infinite previous top-k value (so the first iteration cannot trigger
    # the tolerance test) and a zero placeholder for the elite samples.
    sample_mean_init = self._action_init_tensor
    sample_covariance_diag_init = tf.ones_like(self._action_init_tensor)
    top_k_value_init = tf.constant(
        [np.inf]) * tf.ones(shape=(self._dynamic_batch_size, 1))
    top_k_action_samples_init = tf.tile(
        tf.expand_dims(tf.zeros_like(self._action_init_tensor), axis=1),
        [1, top_k_num, 1])

    # Standard normal noise source; samples are shifted/scaled in the body.
    random_sampler = tfp.distributions.MultivariateNormalDiag(
        loc=np.zeros(self.action_dim), scale_diag=np.ones(self.action_dim))

    def cond_cross_entropy(itr, cond_terminate, sample_mean,
                           sample_covariance_diag, top_k_value,
                           top_k_action_samples):
        # Continue while under the iteration cap and not yet converged.
        del sample_mean, sample_covariance_diag, top_k_value, top_k_action_samples
        cond_1 = tf.math.less(itr, self.action_maximization_iterations)
        return tf.math.logical_and(cond_1, tf.logical_not(cond_terminate))

    def body_cross_entropy(itr, cond_terminate, sample_mean,
                           sample_covariance_diag, top_k_value,
                           top_k_action_samples):
        """Function for cross entropy search of actions."""
        del top_k_action_samples
        top_k_value_prev = top_k_value
        # Repeat each batch row num_samples times (rows stay grouped per
        # batch element after the reshape).
        batch_sample_mean = tf.reshape(
            tf.tile(sample_mean, [1, num_samples]),
            [self._dynamic_batch_size * num_samples, self.action_dim])
        batch_sample_covariance_diag = tf.reshape(
            tf.tile(sample_covariance_diag, [1, num_samples]),
            [self._dynamic_batch_size * num_samples, self.action_dim])

        # NOTE(review): the noise is multiplied by the value produced by
        # tf.math.reduce_variance below (a variance, not a standard
        # deviation) -- confirm this scaling is intended.
        action_samples = self._action_projection(
            batch_sample_mean + batch_sample_covariance_diag * tf.cast(
                random_sampler.sample(
                    sample_shape=[self._dynamic_batch_size * num_samples]),
                dtype=tf.float32))

        state_samples = tf.reshape(
            tf.tile(self._state_tensor, [1, num_samples]),
            [self._dynamic_batch_size * num_samples, self.state_dim])
        action_samples = tf.reshape(
            action_samples,
            [self._dynamic_batch_size * num_samples, self.action_dim])
        values = tf.reshape(
            self._build_q_function_net(state_samples, action_samples),
            [self._dynamic_batch_size, num_samples])

        # everything is in batch mode
        top_k_index = tf.argsort(
            values, axis=1, direction="DESCENDING")[:, 0:top_k_num]
        top_k_index_1d = tf.reshape(
            top_k_index, [self._dynamic_batch_size * top_k_num, 1])
        # Batch-row index for every selected sample, so that
        # (row, sample) pairs can be fed to gather_nd.
        counter_tensor_1d = tf.reshape(
            tf.tile(
                tf.reshape(
                    tf.range(self._dynamic_batch_size),
                    [self._dynamic_batch_size, 1]), [1, top_k_num]),
            [self._dynamic_batch_size * top_k_num, 1])
        top_k_index_2d = tf.concat([counter_tensor_1d, top_k_index_1d], axis=1)

        action_samples = tf.reshape(
            action_samples,
            [self._dynamic_batch_size, num_samples, self.action_dim])
        top_k_action_samples = tf.gather_nd(action_samples, top_k_index_2d)
        top_k_action_samples = tf.reshape(
            top_k_action_samples,
            [self._dynamic_batch_size, top_k_num, self.action_dim])

        top_k_values = tf.gather_nd(values, top_k_index_2d)
        top_k_values = tf.reshape(top_k_values,
                                  [self._dynamic_batch_size, top_k_num])

        # it's a batch_size x 1 tensor
        top_k_value = tf.reshape(
            tf.reduce_mean(top_k_values, axis=1),
            [self._dynamic_batch_size, 1])

        # Refit the sampling distribution to the elite samples.
        sample_mean = tf.reduce_mean(top_k_action_samples, axis=1)
        sample_covariance_diag = tf.math.reduce_variance(
            top_k_action_samples, axis=1)

        itr = itr + 1
        # Converged when the batch-mean change of the top-k value is within
        # the tolerance.
        cond_terminate = tf.less_equal(
            tf.reduce_mean(tf.math.abs(top_k_value - top_k_value_prev)),
            self._tolerance_tensor)

        return itr, cond_terminate, sample_mean, sample_covariance_diag, \
            top_k_value, top_k_action_samples

    self.cost_optimizer = tf.while_loop(
        cond_cross_entropy, body_cross_entropy, [
            tf.constant(0),
            tf.constant(False), sample_mean_init, sample_covariance_diag_init,
            top_k_value_init, top_k_action_samples_init
        ])
def test_irregular_shape(self):
    """Runs an efficientdet-d0 net configured for 896x1600 on two input sizes."""
    config = hparams_config.get_efficientdet_config('efficientdet-d0')
    config.image_size = '896x1600'
    model = efficientdet_keras.EfficientDetNet(config=config)
    # First the configured size, then a size that differs from it.
    for height, width in ((896, 1600), (499, 333)):
        model(tf.ones([1, height, width, 3]), False)
def GetEmbeddingLookupList(signals_list,
                           embedding_vars,
                           sparse_ids,
                           sparse_weights=None,
                           combiners='sqrtn',
                           partition_strategies='mod'):
    """Get a list of embedding lookup tensors.

    For every signal the per-id embeddings are looked up and scaled by the
    per-id weight that the signal's combiner would apply, so the scaled
    embeddings can be used for bag-of-words attribution.

    Args:
        signals_list: A list of strings, representing names of features.
        embedding_vars: Dict mapping feature names to full embedding
            variables.
        sparse_ids: Dict mapping feature names to SparseTensors of their ids.
        sparse_weights: Either None, or a dict mapping feature names to
            SparseTensors of their weights (which can also be None).
        combiners: Either a common combiner type for all features ('mean',
            'sqrtn' or 'sum') or a dict mapping each feature name to a
            combiner type.
        partition_strategies: Either a common partition_strategy for all
            features ('mod' or 'div') or a dict mapping feature names to
            partition strategies.

    Returns:
        embedding_lookup_list: A list of embedding lookup tensors used for
            bag of words attribution, aligned with signals_list.
    """
    assert isinstance(embedding_vars, dict) and isinstance(sparse_ids, dict)
    assert sparse_weights is None or isinstance(sparse_weights, dict)
    assert combiners in ('mean', 'sqrtn', 'sum') or isinstance(combiners, dict)
    assert (partition_strategies in ('mod', 'div') or
            isinstance(partition_strategies, dict))

    combiner_per_signal = isinstance(combiners, dict)
    strategy_per_signal = isinstance(partition_strategies, dict)

    lookups = []
    for signal in signals_list:
        combiner = combiners[signal] if combiner_per_signal else combiners
        if strategy_per_signal:
            strategy = partition_strategies[signal]
        else:
            strategy = partition_strategies

        ids = sparse_ids[signal]
        # Batch dimension should be 1 for attribution.
        ids_batch_is_one = tf.assert_equal(tf.shape(ids)[0], 1)
        with tf.control_dependencies([ids_batch_is_one]):
            lookup = tf.nn.embedding_lookup(
                params=embedding_vars[signal],
                ids=tf.sparse_tensor_to_dense(ids),
                partition_strategy=strategy)

        signal_weights = (
            None if sparse_weights is None else sparse_weights[signal])
        if signal_weights is None:
            # Unweighted ids: every id gets the same combiner weight.
            num_vals = tf.size(ids.values)
            if combiner == 'mean':
                uniform = 1.0 / tf.to_float(num_vals)
                per_id_weights = tf.fill([1, num_vals], uniform)
            elif combiner == 'sqrtn':
                uniform = 1.0 / tf.sqrt(tf.to_float(num_vals))
                per_id_weights = tf.fill([1, num_vals], uniform)
            else:
                per_id_weights = tf.ones([1, num_vals], dtype=tf.float32)
        else:
            # Batch dimension should be 1 for attribution.
            weights_batch_is_one = tf.assert_equal(
                tf.shape(signal_weights)[0], 1)
            with tf.control_dependencies([weights_batch_is_one]):
                dense_weights = tf.sparse_tensor_to_dense(signal_weights)
            if combiner == 'mean':
                normalizer = tf.reduce_sum(dense_weights)
                per_id_weights = dense_weights / normalizer
            elif combiner == 'sqrtn':
                normalizer = tf.sqrt(tf.reduce_sum(tf.pow(dense_weights, 2)))
                per_id_weights = dense_weights / normalizer
            else:
                per_id_weights = dense_weights

        # Broadcast the per-id weight across the embedding dimension.
        lookups.append(lookup * tf.expand_dims(per_id_weights, -1))
    return lookups
def setUp(self):
    """Resets the default graph and creates the all-ones input fixtures."""
    super(ConfigurableOpsTest, self).setUp()
    tf.reset_default_graph()
    image_shape = [2, 4, 4, 3]
    self.inputs_shape = image_shape
    self.inputs = tf.ones(image_shape, dtype=tf.float32)
    self.fc_inputs = tf.ones([3, 12])
def _test_generator_graph_helper(self, shape):
    """Asserts the resnet generator's output has the same shape as its input.

    Used to check that the generator accepts small and non-square inputs.
    """
    dummy_images = tf.ones(shape)
    generated_images, _ = cyclegan.cyclegan_generator_resnet(dummy_images)
    generated_shape = generated_images.shape.as_list()
    self.assertAllEqual(shape, generated_shape)
def build_genie_model(feat_dict,
                      cfg,
                      batch_size,
                      seq_len,
                      is_training=True,
                      seq_varlens=None,
                      dtype=tf.float32):
    """Builds a Piano Genie model.

    The graph has three stages: an optional encoder (built when any
    embedding flag in `cfg` is set), one or more latent embeddings derived
    from the encoder outputs, and a decoder that reconstructs pitches (and
    optionally velocities) from the concatenated latents and auxiliary
    features.

    Args:
        feat_dict: Dictionary containing input tensors.
        cfg: Configuration object.
        batch_size: Number of items in batch.
        seq_len: Length of each batch item.
        is_training: Set to False for evaluation.
        seq_varlens: If not None, a tensor with the batch sequence lengths.
        dtype: Model weight type.

    Returns:
        A dict containing tensors for relevant model config.
    """
    out_dict = {}

    # Parse features
    pitches = util.demidify(feat_dict["midi_pitches"])
    velocities = feat_dict["velocities"]
    # Rescale pitch indices 0..87 to [-1, 1].
    pitches_scalar = ((tf.cast(pitches, tf.float32) / 87.) * 2.) - 1.

    # Create sequence lens
    if is_training and cfg.train_randomize_seq_len:
        seq_lens = tf.random_uniform([batch_size],
                                     minval=cfg.train_seq_len_min,
                                     maxval=seq_len + 1,
                                     dtype=tf.int32)
        stp_varlen_mask = tf.sequence_mask(
            seq_lens, maxlen=seq_len, dtype=tf.float32)
    elif seq_varlens is not None:
        seq_lens = seq_varlens
        stp_varlen_mask = tf.sequence_mask(
            seq_varlens, maxlen=seq_len, dtype=tf.float32)
    else:
        # Fixed-length batches: no per-step mask is needed.
        seq_lens = tf.ones([batch_size], dtype=tf.int32) * seq_len
        stp_varlen_mask = None

    # Encode
    if (cfg.stp_emb_unconstrained or cfg.stp_emb_vq or cfg.stp_emb_iq or
            cfg.seq_emb_unconstrained or cfg.seq_emb_vae or
            cfg.lor_emb_unconstrained):
        # Build encoder features
        enc_feats = []
        if cfg.enc_pitch_scalar:
            enc_feats.append(tf.expand_dims(pitches_scalar, axis=-1))
        else:
            enc_feats.append(tf.one_hot(pitches, 88))
        if "delta_times_int" in cfg.enc_aux_feats:
            enc_feats.append(
                tf.one_hot(feat_dict["delta_times_int"],
                           cfg.data_max_discrete_times + 1))
        if "velocities" in cfg.enc_aux_feats:
            enc_feats.append(
                tf.one_hot(velocities, cfg.data_max_discrete_velocities + 1))
        enc_feats = tf.concat(enc_feats, axis=2)

        with tf.variable_scope("encoder"):
            enc_stp, enc_seq = simple_lstm_encoder(
                enc_feats,
                seq_lens,
                rnn_celltype=cfg.rnn_celltype,
                rnn_nlayers=cfg.rnn_nlayers,
                rnn_nunits=cfg.rnn_nunits,
                rnn_bidirectional=cfg.enc_rnn_bidirectional,
                dtype=dtype)

    # Per-timestep latent tensors that will be fed to the decoder.
    latents = []

    # Step embeddings (single vector per timestep)
    if cfg.stp_emb_unconstrained:
        with tf.variable_scope("stp_emb_unconstrained"):
            stp_emb_unconstrained = tf.layers.dense(
                enc_stp, cfg.stp_emb_unconstrained_embedding_dim)

        out_dict["stp_emb_unconstrained"] = stp_emb_unconstrained
        latents.append(stp_emb_unconstrained)

    # Quantized step embeddings with VQ-VAE
    if cfg.stp_emb_vq:
        import sonnet as snt  # pylint:disable=g-import-not-at-top
        with tf.variable_scope("stp_emb_vq"):
            with tf.variable_scope("pre_vq"):
                # pre_vq_encoding is tf.float32 of [batch_size, seq_len, embedding_dim]
                pre_vq_encoding = tf.layers.dense(enc_stp,
                                                  cfg.stp_emb_vq_embedding_dim)

            with tf.variable_scope("quantizer"):
                # Variable-length masking is not supported on this path.
                assert stp_varlen_mask is None
                vq_vae = snt.nets.VectorQuantizer(
                    embedding_dim=cfg.stp_emb_vq_embedding_dim,
                    num_embeddings=cfg.stp_emb_vq_codebook_size,
                    commitment_cost=cfg.stp_emb_vq_commitment_cost)
                vq_vae_output = vq_vae(pre_vq_encoding, is_training=is_training)

                stp_emb_vq_quantized = vq_vae_output["quantize"]
                stp_emb_vq_discrete = tf.reshape(
                    tf.argmax(vq_vae_output["encodings"], axis=1,
                              output_type=tf.int32),
                    [batch_size, seq_len])
                stp_emb_vq_codebook = tf.transpose(vq_vae.embeddings)

        out_dict["stp_emb_vq_quantized"] = stp_emb_vq_quantized
        out_dict["stp_emb_vq_discrete"] = stp_emb_vq_discrete
        out_dict["stp_emb_vq_loss"] = vq_vae_output["loss"]
        out_dict["stp_emb_vq_codebook"] = stp_emb_vq_codebook
        out_dict["stp_emb_vq_codebook_ppl"] = vq_vae_output["perplexity"]
        latents.append(stp_emb_vq_quantized)

        # This tensor retrieves continuous embeddings from codebook. It should
        # *never* be used during training.
        out_dict["stp_emb_vq_quantized_lookup"] = tf.nn.embedding_lookup(
            stp_emb_vq_codebook, stp_emb_vq_discrete)

    # Integer-quantized step embeddings with straight-through
    if cfg.stp_emb_iq:
        with tf.variable_scope("stp_emb_iq"):
            with tf.variable_scope("pre_iq"):
                # pre_iq_encoding is tf.float32 of [batch_size, seq_len]
                pre_iq_encoding = tf.layers.dense(enc_stp, 1)[:, :, 0]

            def iqst(x, n):
                """Integer quantization with straight-through estimator."""
                eps = 1e-7
                s = float(n - 1)
                # Map [-1, 1] to [0, 1], round to one of n bins, map back.
                xp = tf.clip_by_value((x + 1) / 2.0, -eps, 1 + eps)
                xpp = tf.round(s * xp)
                xppp = 2 * (xpp / s) - 1
                # Forward value is xppp; gradient flows as if identity in x.
                return xpp, x + tf.stop_gradient(xppp - x)

            with tf.variable_scope("quantizer"):
                # Pass rounded vals to decoder w/ straight-through estimator
                stp_emb_iq_discrete_f, stp_emb_iq_discrete_rescaled = iqst(
                    pre_iq_encoding, cfg.stp_emb_iq_nbins)
                # The 1e-4 offset guards the float->int cast against values
                # sitting just below an integer.
                stp_emb_iq_discrete = tf.cast(stp_emb_iq_discrete_f + 1e-4,
                                              tf.int32)
                stp_emb_iq_discrete_f = tf.cast(stp_emb_iq_discrete, tf.float32)
                stp_emb_iq_quantized = tf.expand_dims(
                    stp_emb_iq_discrete_rescaled, axis=2)

                # Determine which elements round to valid indices
                stp_emb_iq_inrange = tf.logical_and(
                    tf.greater_equal(pre_iq_encoding, -1),
                    tf.less_equal(pre_iq_encoding, 1))
                stp_emb_iq_inrange_mask = tf.cast(stp_emb_iq_inrange,
                                                  tf.float32)
                stp_emb_iq_valid_p = weighted_avg(stp_emb_iq_inrange_mask,
                                                  stp_varlen_mask)

                # Regularize to encourage encoder to output in range
                stp_emb_iq_range_penalty = weighted_avg(
                    tf.square(tf.maximum(tf.abs(pre_iq_encoding) - 1, 0)),
                    stp_varlen_mask)

                # Regularize to correlate latent finite differences to input
                stp_emb_iq_dlatents = (
                    pre_iq_encoding[:, 1:] - pre_iq_encoding[:, :-1])
                if cfg.stp_emb_iq_contour_dy_scalar:
                    stp_emb_iq_dnotes = (
                        pitches_scalar[:, 1:] - pitches_scalar[:, :-1])
                else:
                    stp_emb_iq_dnotes = tf.cast(
                        pitches[:, 1:] - pitches[:, :-1], tf.float32)
                if cfg.stp_emb_iq_contour_exp == 1:
                    power_func = tf.identity
                elif cfg.stp_emb_iq_contour_exp == 2:
                    power_func = tf.square
                else:
                    raise NotImplementedError()
                if cfg.stp_emb_iq_contour_comp == "product":
                    comp_func = tf.multiply
                elif cfg.stp_emb_iq_contour_comp == "quotient":

                    def comp_func(x, y):
                        return tf.divide(x, y + 1e-6)
                else:
                    raise NotImplementedError()

                # Hinge on margin - comp(dnotes, dlatents); the mask drops
                # the first step because differences have seq_len - 1 steps.
                stp_emb_iq_contour_penalty = weighted_avg(
                    power_func(
                        tf.maximum(
                            cfg.stp_emb_iq_contour_margin - comp_func(
                                stp_emb_iq_dnotes, stp_emb_iq_dlatents), 0)),
                    None if stp_varlen_mask is None else stp_varlen_mask[:, 1:])

                # Regularize to maintain note consistency
                stp_emb_iq_note_held = tf.cast(
                    tf.equal(pitches[:, 1:] - pitches[:, :-1], 0), tf.float32)
                # NOTE(review): there is no else branch here, so for any
                # other value power_func silently keeps the contour choice
                # made above -- confirm this is intended.
                if cfg.stp_emb_iq_deviate_exp == 1:
                    power_func = tf.abs
                elif cfg.stp_emb_iq_deviate_exp == 2:
                    power_func = tf.square

                if stp_varlen_mask is None:
                    mask = stp_emb_iq_note_held
                else:
                    mask = stp_varlen_mask[:, 1:] * stp_emb_iq_note_held
                stp_emb_iq_deviate_penalty = weighted_avg(
                    power_func(stp_emb_iq_dlatents), mask)

                # Calculate perplexity of discrete encoder posterior
                if stp_varlen_mask is None:
                    mask = stp_emb_iq_inrange_mask
                else:
                    mask = stp_varlen_mask * stp_emb_iq_inrange_mask
                stp_emb_iq_discrete_oh = tf.one_hot(stp_emb_iq_discrete,
                                                    cfg.stp_emb_iq_nbins)
                stp_emb_iq_avg_probs = weighted_avg(
                    stp_emb_iq_discrete_oh,
                    mask,
                    axis=[0, 1],
                    expand_mask=True)
                stp_emb_iq_discrete_ppl = tf.exp(
                    -tf.reduce_sum(stp_emb_iq_avg_probs *
                                   tf.log(stp_emb_iq_avg_probs + 1e-10)))

        out_dict["stp_emb_iq_quantized"] = stp_emb_iq_quantized
        out_dict["stp_emb_iq_discrete"] = stp_emb_iq_discrete
        out_dict["stp_emb_iq_valid_p"] = stp_emb_iq_valid_p
        out_dict["stp_emb_iq_range_penalty"] = stp_emb_iq_range_penalty
        out_dict["stp_emb_iq_contour_penalty"] = stp_emb_iq_contour_penalty
        out_dict["stp_emb_iq_deviate_penalty"] = stp_emb_iq_deviate_penalty
        out_dict["stp_emb_iq_discrete_ppl"] = stp_emb_iq_discrete_ppl
        latents.append(stp_emb_iq_quantized)

        # This tensor converts discrete values to continuous.
        # It should *never* be used during training.
        out_dict["stp_emb_iq_quantized_lookup"] = tf.expand_dims(
            2. * (stp_emb_iq_discrete_f / (cfg.stp_emb_iq_nbins - 1.)) - 1.,
            axis=2)

    # Sequence embedding (single vector per sequence)
    if cfg.seq_emb_unconstrained:
        with tf.variable_scope("seq_emb_unconstrained"):
            seq_emb_unconstrained = tf.layers.dense(
                enc_seq, cfg.seq_emb_unconstrained_embedding_dim)

        out_dict["seq_emb_unconstrained"] = seq_emb_unconstrained

        # Repeat the sequence vector at every timestep for the decoder.
        seq_emb_unconstrained = tf.stack([seq_emb_unconstrained] * seq_len,
                                         axis=1)
        latents.append(seq_emb_unconstrained)

    # Sequence embeddings (variational w/ reparameterization trick)
    if cfg.seq_emb_vae:
        with tf.variable_scope("seq_emb_vae"):
            # One dense layer emits both halves: mean and pre-softplus stddev.
            seq_emb_vae = tf.layers.dense(enc_seq,
                                          cfg.seq_emb_vae_embedding_dim * 2)

            mean = seq_emb_vae[:, :cfg.seq_emb_vae_embedding_dim]
            stddev = 1e-6 + tf.nn.softplus(
                seq_emb_vae[:, cfg.seq_emb_vae_embedding_dim:])
            seq_emb_vae = mean + stddev * tf.random_normal(
                tf.shape(mean), 0, 1, dtype=dtype)

            kl = tf.reduce_mean(
                0.5 * tf.reduce_sum(
                    tf.square(mean) + tf.square(stddev) -
                    tf.log(1e-8 + tf.square(stddev)) - 1,
                    axis=1))

        out_dict["seq_emb_vae"] = seq_emb_vae
        out_dict["seq_emb_vae_kl"] = kl

        seq_emb_vae = tf.stack([seq_emb_vae] * seq_len, axis=1)
        latents.append(seq_emb_vae)

    # Low-rate embeddings
    if cfg.lor_emb_unconstrained:
        assert seq_len % cfg.lor_emb_n == 0

        with tf.variable_scope("lor_emb_unconstrained"):
            # Downsample step embeddings
            rnn_embedding_dim = int(enc_stp.get_shape()[-1])
            enc_lor = tf.reshape(enc_stp, [
                batch_size, seq_len // cfg.lor_emb_n,
                cfg.lor_emb_n * rnn_embedding_dim
            ])
            lor_emb_unconstrained = tf.layers.dense(
                enc_lor, cfg.lor_emb_unconstrained_embedding_dim)

            out_dict["lor_emb_unconstrained"] = lor_emb_unconstrained

            # Upsample lo-rate embeddings for decoding
            lor_emb_unconstrained = tf.expand_dims(lor_emb_unconstrained,
                                                   axis=2)
            lor_emb_unconstrained = tf.tile(lor_emb_unconstrained,
                                            [1, 1, cfg.lor_emb_n, 1])
            lor_emb_unconstrained = tf.reshape(
                lor_emb_unconstrained,
                [batch_size, seq_len, cfg.lor_emb_unconstrained_embedding_dim])

        latents.append(lor_emb_unconstrained)

    # Build decoder features
    # NOTE: dec_feats aliases `latents`; the appends below extend the same
    # list until it is replaced by the concatenated tensor.
    dec_feats = latents

    if cfg.dec_autoregressive:
        # Retrieve pitch numbers
        curr_pitches = pitches
        last_pitches = curr_pitches[:, :-1]
        last_pitches = tf.pad(
            last_pitches, [[0, 0], [1, 0]],
            constant_values=-1)  # Prepend <SOS> token
        out_dict["dec_last_pitches"] = last_pitches
        # +1 shifts the -1 start token to index 0, hence depth 89.
        dec_feats.append(tf.one_hot(last_pitches + 1, 89))

        if cfg.dec_pred_velocity:
            curr_velocities = velocities
            last_velocities = curr_velocities[:, :-1]
            last_velocities = tf.pad(last_velocities, [[0, 0], [1, 0]])
            dec_feats.append(
                tf.one_hot(last_velocities,
                           cfg.data_max_discrete_velocities + 1))

    if "delta_times_int" in cfg.dec_aux_feats:
        dec_feats.append(
            tf.one_hot(feat_dict["delta_times_int"],
                       cfg.data_max_discrete_times + 1))
    if "velocities" in cfg.dec_aux_feats:
        # Feeding the true velocities is incompatible with predicting them.
        assert not cfg.dec_pred_velocity
        dec_feats.append(
            tf.one_hot(feat_dict["velocities"],
                       cfg.data_max_discrete_velocities + 1))

    assert dec_feats
    dec_feats = tf.concat(dec_feats, axis=2)

    # Decode
    with tf.variable_scope("decoder"):
        dec_stp, dec_initial_state, dec_final_state = simple_lstm_decoder(
            dec_feats,
            seq_lens,
            batch_size,
            rnn_celltype=cfg.rnn_celltype,
            rnn_nlayers=cfg.rnn_nlayers,
            rnn_nunits=cfg.rnn_nunits)

        with tf.variable_scope("pitches"):
            dec_recons_logits = tf.layers.dense(dec_stp, 88)

        dec_recons_loss = weighted_avg(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=dec_recons_logits, labels=pitches), stp_varlen_mask)

        out_dict["dec_initial_state"] = dec_initial_state
        out_dict["dec_final_state"] = dec_final_state
        out_dict["dec_recons_logits"] = dec_recons_logits
        out_dict["dec_recons_scores"] = tf.nn.softmax(dec_recons_logits,
                                                      axis=-1)
        out_dict["dec_recons_preds"] = tf.argmax(
            dec_recons_logits, output_type=tf.int32, axis=-1)
        out_dict["dec_recons_midi_preds"] = util.remidify(
            out_dict["dec_recons_preds"])
        out_dict["dec_recons_loss"] = dec_recons_loss

        if cfg.dec_pred_velocity:
            with tf.variable_scope("velocities"):
                dec_recons_velocity_logits = tf.layers.dense(
                    dec_stp, cfg.data_max_discrete_velocities + 1)
            dec_recons_velocity_loss = weighted_avg(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_recons_velocity_logits, labels=velocities),
                stp_varlen_mask)
            out_dict["dec_recons_velocity_logits"] = dec_recons_velocity_logits
            out_dict["dec_recons_velocity_loss"] = dec_recons_velocity_loss

    # Stats
    if cfg.stp_emb_vq or cfg.stp_emb_iq:
        discrete = out_dict["stp_emb_vq_discrete" if cfg.
                            stp_emb_vq else "stp_emb_iq_discrete"]
        dx = pitches[:, 1:] - pitches[:, :-1]
        dy = discrete[:, 1:] - discrete[:, :-1]
        # Fraction of steps where pitch and code move in opposite directions.
        contour_violation = tf.reduce_mean(
            tf.cast(tf.less(dx * dy, 0), tf.float32))

        # Among held notes (dx == 0), fraction whose code changed.
        dx_hold = tf.equal(dx, 0)
        deviate_violation = weighted_avg(
            tf.cast(tf.not_equal(dy, 0), tf.float32),
            tf.cast(dx_hold, tf.float32))

        out_dict["contour_violation"] = contour_violation
        out_dict["deviate_violation"] = deviate_violation

    return out_dict
def train(config):
    """Builds the generator/discriminator graph and runs adversarial training.

    The discriminator is trained with a cross-entropy loss on real versus
    generated sequences; the generator is trained with the REINFORCE loss
    returned by `losses.reinforce_loss`. Metrics are logged and a checkpoint
    is written every `config.export_every` steps and once after training.

    Args:
        config: Configuration object. Attributes read here include data
            settings (`data_dir`, `dataset`, `batch_size`), network sizes,
            optimizer settings, loss settings (`disc_loss_type`, `gamma`,
            `baseline_decay`, `l2_disc`, `l2_gen`) and the training schedule
            (`num_steps`, `num_disc_updates`, `num_gen_updates`,
            `export_every`, `checkpoint_dir`).

    Raises:
        ValueError: If `config.disc_loss_type` is not "ce", the only
            discriminator loss implemented here.
    """
    import os  # Local import: keeps this block self-contained.

    logging.info("Training.")

    tf.reset_default_graph()
    np.set_printoptions(precision=4)

    # Get data.
    raw_data = reader.get_raw_data(
        data_path=config.data_dir, dataset=config.dataset)
    train_data, valid_data, word_to_id = raw_data
    id_to_word = {v: k for k, v in word_to_id.items()}
    vocab_size = len(word_to_id)
    max_length = reader.MAX_TOKENS_SEQUENCE[config.dataset]
    logging.info("Vocabulary size: %d", vocab_size)

    iterator = reader.iterator(
        raw_data=train_data, batch_size=config.batch_size)
    iterator_valid = reader.iterator(
        raw_data=valid_data, batch_size=config.batch_size)

    real_sequence = tf.placeholder(
        dtype=tf.int32,
        shape=[config.batch_size, max_length],
        name="real_sequence")
    real_sequence_length = tf.placeholder(
        dtype=tf.int32, shape=[config.batch_size],
        name="real_sequence_length")
    first_batch_np = next(iterator)
    valid_batch_np = next(iterator_valid)

    # Fixed batches (baked into the graph as constants) used only for the
    # discriminator metrics: one real training batch, one batch of uniformly
    # random tokens, and one validation batch.
    test_real_batch = {k: tf.constant(v) for k, v in first_batch_np.items()}
    test_fake_batch = {
        "sequence":
            tf.constant(
                np.random.choice(
                    vocab_size, size=[config.batch_size,
                                      max_length]).astype(np.int32)),
        "sequence_length":
            tf.constant(
                np.random.choice(max_length,
                                 size=[config.batch_size]).astype(np.int32)),
    }
    valid_batch = {k: tf.constant(v) for k, v in valid_batch_np.items()}

    # Create generator.
    if config.use_pretrained_embedding:
        embedding_source = utils.get_embedding_path(config.data_dir,
                                                    config.dataset)
        # NOTE(review): fixed path in /tmp; concurrent runs on one machine
        # would overwrite each other's vocabulary file.
        vocab_file = "/tmp/vocab.txt"
        with gfile.GFile(vocab_file, "w") as f:
            # One word per line, in id order.
            for i in range(len(id_to_word)):
                f.write(id_to_word[i] + "\n")
        logging.info("Temporary vocab file: %s", vocab_file)
    else:
        embedding_source = None
        vocab_file = None

    gen = generators.LSTMGen(
        vocab_size=vocab_size,
        feature_sizes=[config.gen_feature_size] * config.num_layers_gen,
        max_sequence_length=max_length,
        batch_size=config.batch_size,
        use_layer_norm=config.layer_norm_gen,
        trainable_embedding_size=config.trainable_embedding_size,
        input_dropout=config.gen_input_dropout,
        output_dropout=config.gen_output_dropout,
        pad_token=reader.PAD_INT,
        embedding_source=embedding_source,
        vocab_file=vocab_file,
    )
    gen_outputs = gen()

    # Create discriminator.
    disc = discriminator_nets.LSTMEmbedDiscNet(
        vocab_size=vocab_size,
        feature_sizes=[config.disc_feature_size] * config.num_layers_disc,
        trainable_embedding_size=config.trainable_embedding_size,
        embedding_source=embedding_source,
        use_layer_norm=config.layer_norm_disc,
        pad_token=reader.PAD_INT,
        vocab_file=vocab_file,
        dropout=config.disc_dropout,
    )
    disc_logits_real = disc(
        sequence=real_sequence, sequence_length=real_sequence_length)
    disc_logits_fake = disc(
        sequence=gen_outputs["sequence"],
        sequence_length=gen_outputs["sequence_length"])

    # Loss of the discriminator.
    if config.disc_loss_type == "ce":
        targets_real = tf.ones([config.batch_size, max_length])
        targets_fake = tf.zeros([config.batch_size, max_length])
        loss_real = losses.sequential_cross_entropy_loss(
            disc_logits_real, targets_real)
        loss_fake = losses.sequential_cross_entropy_loss(
            disc_logits_fake, targets_fake)
        disc_loss = 0.5 * loss_real + 0.5 * loss_fake
    else:
        # Without this, disc_loss would be unbound and fail later with a
        # confusing UnboundLocalError.
        raise ValueError(
            "Unsupported disc_loss_type: %r" % (config.disc_loss_type,))

    # Loss of the generator.
    gen_loss, cumulative_rewards, baseline = losses.reinforce_loss(
        disc_logits=disc_logits_fake,
        gen_logprobs=gen_outputs["logprobs"],
        gamma=config.gamma,
        decay=config.baseline_decay)

    # Optimizers
    disc_optimizer = tf.train.AdamOptimizer(
        learning_rate=config.disc_lr, beta1=config.disc_beta1)
    gen_optimizer = tf.train.AdamOptimizer(
        learning_rate=config.gen_lr, beta1=config.gen_beta1)

    # Get losses and variables.
    disc_vars = disc.get_all_variables()
    gen_vars = gen.get_all_variables()
    l2_disc = tf.reduce_sum(tf.add_n([tf.nn.l2_loss(v) for v in disc_vars]))
    l2_gen = tf.reduce_sum(tf.add_n([tf.nn.l2_loss(v) for v in gen_vars]))
    scalar_disc_loss = tf.reduce_mean(disc_loss) + config.l2_disc * l2_disc
    scalar_gen_loss = tf.reduce_mean(gen_loss) + config.l2_gen * l2_gen

    # Update ops.
    # Both updates increment the same global step, so it advances by
    # num_disc_updates + num_gen_updates per training iteration.
    global_step = tf.train.get_or_create_global_step()
    disc_update = disc_optimizer.minimize(
        scalar_disc_loss, var_list=disc_vars, global_step=global_step)
    gen_update = gen_optimizer.minimize(
        scalar_gen_loss, var_list=gen_vars, global_step=global_step)

    # Saver.
    saver = tf.train.Saver()
    # os.path.join places the checkpoint prefix inside checkpoint_dir whether
    # or not the configured path ends with a separator; plain concatenation
    # wrote next to the directory when the separator was missing, so
    # latest_checkpoint(checkpoint_dir) never found the files.
    ckpt_prefix = os.path.join(config.checkpoint_dir, "scratchgan")

    # Metrics
    test_disc_logits_real = disc(**test_real_batch)
    test_disc_logits_fake = disc(**test_fake_batch)
    valid_disc_logits = disc(**valid_batch)
    disc_predictions_real = tf.nn.sigmoid(disc_logits_real)
    disc_predictions_fake = tf.nn.sigmoid(disc_logits_fake)
    valid_disc_predictions = tf.reduce_mean(
        tf.nn.sigmoid(valid_disc_logits), axis=0)
    test_disc_predictions_real = tf.reduce_mean(
        tf.nn.sigmoid(test_disc_logits_real), axis=0)
    test_disc_predictions_fake = tf.reduce_mean(
        tf.nn.sigmoid(test_disc_logits_fake), axis=0)

    # Scalar summaries reported at every export step.
    metrics = {
        "scalar_gen_loss": scalar_gen_loss,
        "scalar_disc_loss": scalar_disc_loss,
        "disc_predictions_real": tf.reduce_mean(disc_predictions_real),
        "disc_predictions_fake": tf.reduce_mean(disc_predictions_fake),
        "test_disc_predictions_real":
            tf.reduce_mean(test_disc_predictions_real),
        "test_disc_predictions_fake":
            tf.reduce_mean(test_disc_predictions_fake),
        "valid_disc_predictions": tf.reduce_mean(valid_disc_predictions),
        "cumulative_rewards": tf.reduce_mean(cumulative_rewards),
        "baseline": tf.reduce_mean(baseline),
    }

    # Training.
    logging.info("Starting training")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Resume from the most recent checkpoint when one exists.
        latest_ckpt = tf.train.latest_checkpoint(config.checkpoint_dir)
        if latest_ckpt:
            saver.restore(sess, latest_ckpt)

        for step in range(config.num_steps):
            real_data_np = next(iterator)
            train_feed = {
                real_sequence: real_data_np["sequence"],
                real_sequence_length: real_data_np["sequence_length"],
            }

            # Update generator and discriminator.
            for _ in range(config.num_disc_updates):
                sess.run(disc_update, feed_dict=train_feed)
            for _ in range(config.num_gen_updates):
                sess.run(gen_update, feed_dict=train_feed)

            # Reporting
            if step % config.export_every == 0:
                gen_sequence_np, metrics_np = sess.run(
                    [gen_outputs["sequence"], metrics], feed_dict=train_feed)
                # Only the first generated sequence of the batch is logged.
                metrics_np["gen_sentence"] = utils.sequence_to_sentence(
                    gen_sequence_np[0, :], id_to_word)
                saver.save(
                    sess, save_path=ckpt_prefix, global_step=global_step)
                metrics_np["model_path"] = tf.train.latest_checkpoint(
                    config.checkpoint_dir)
                logging.info(metrics_np)

        # After training, export models.
        saver.save(sess, save_path=ckpt_prefix, global_step=global_step)
        logging.info("Saved final model at %s.",
                     tf.train.latest_checkpoint(config.checkpoint_dir))
def ones(shape, name=None):
  """Return a float32 `tf.Variable` of the given shape, initialised to ones.

  Args:
    shape: Shape passed through to `tf.ones`.
    name: Optional name for the created variable.

  Returns:
    The newly created `tf.Variable`.
  """
  return tf.Variable(tf.ones(shape, dtype=tf.float32), name=name)
#при примерно равных 2 и 4 стабильно отлает предпочтение 2ому #немного лечится увеличением рандома параметром е (который в начале вообще хотел выпилить) но это вносит бОльший шанс ошибки "не-2" (например выберет 3) num_bandits = len(bandits) def pullBandit(bandit): #Сгенерировать случайное число result = np.random.randn(1) if result > bandit: #Выигрыш return 1 else: #Проигрыш return -1 tf.reset_default_graph() weights = tf.Variable(tf.ones([num_bandits])) chosen_action = tf.argmax(weights, 0) reward_holder = tf.placeholder(shape=[1], dtype=tf.float32) action_holder = tf.placeholder(shape=[1], dtype=tf.int32) responsible_weight = tf.slice(weights, action_holder, [1]) loss = -(tf.log(responsible_weight) * reward_holder) optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001) update = optimizer.minimize(loss) total_episodes = 1000 total_reward = np.zeros(num_bandits) e = 0.3 init = tf.initializers.global_variables()
def rnn_differentiable_plasticity(seqs,
                                  batch_size=32,
                                  hidden_size=50,
                                  fast_steps=1,
                                  fast_lr_fixed=None,
                                  use_oja_rule=None,
                                  update_mem_with_prev_timestep=None,
                                  learn_fast_lr=None,
                                  learn_plasticity_coeffs=None):
  """Implements differentiable plasticity from Miconi et al. 2018.

  Differences with FW:
  (1) learnable plasticity_coeffs
  (2) learnable fast learning rate, eta
  (3a) Oja's rule to update memory matrix, instead of Hebb's rule
  OR
  (3b) use previous timestep to update memory matrix with Hebb's rule

  Args:
    seqs: <tf.float32>[batch_size, seq_len, pattern_size] Sequences of
      patterns, where `seq_len` is the length of a sequence (including the
      query pattern) and `pattern_size` is the dimensionality of each pattern.
    batch_size (int): Batch size.
    hidden_size (int): Number of hidden units
    fast_steps (int): Number of inner loop iterations we apply fast weights.
    fast_lr_fixed (float): Learning rate (eta) for fast weights update if
      fast lr is not learned. Required when `learn_fast_lr` is False.
    use_oja_rule (bool): True if we update the memory matrix with Oja's rule.
    update_mem_with_prev_timestep (bool): True if we update the memory matrix
      by a dot product with the previous hidden state. (only applies if we use
      Hebb's rule to update)
    learn_fast_lr (bool): True if the fast learning rate is learnable.
    learn_plasticity_coeffs (bool): True if the plasticity coefficients are
      learnable.

  Returns:
    preds: <tf.float32>[batch_size, pattern_size] The retrieved pattern for
      the degraded query at the end of the sequence.

  Raises:
    TypeError: If kwargs are not specified, or if `learn_fast_lr` is False
      and no `fast_lr_fixed` is given.
  """
  # Validate boolean args that would otherwise fail silently
  if (use_oja_rule is None or update_mem_with_prev_timestep is None or
      learn_fast_lr is None or learn_plasticity_coeffs is None):
    raise TypeError(
        "Settings must be specified for differentiable plasticity.")
  # Without a learned rate, a fixed one is mandatory; fail here with a clear
  # message instead of deep inside `tf.summary.scalar`.
  if not learn_fast_lr and fast_lr_fixed is None:
    raise TypeError(
        "fast_lr_fixed must be specified when learn_fast_lr is False.")

  _, seq_len, pattern_size = seqs.shape

  # Initialize parameters
  w1 = tf.get_variable("w1", shape=[pattern_size + hidden_size, hidden_size])
  b1 = tf.get_variable("b1", shape=[hidden_size],
                       initializer=tf.zeros_initializer())

  # state: <float32>[batch_size, hidden_size]
  state = tf.zeros([batch_size, hidden_size], name="state")

  # a_memory: <float32>[batch_size, hidden_size, hidden_size]
  a_memory = tf.zeros([batch_size, hidden_size, hidden_size],
                      dtype=tf.float32, name="A")

  if learn_plasticity_coeffs:
    # plasticity_coeffs: <float32>[hidden_size, hidden_size]
    plasticity_coeffs = tf.get_variable("plasticity_coeffs",
                                        shape=[hidden_size, hidden_size])

  if learn_fast_lr:
    # fast_lr_learned: <float32>[]
    fast_lr_learned = tf.get_variable("fast_lr_learned", shape=[])
    fast_lr = tf.nn.sigmoid(fast_lr_learned)
  else:
    fast_lr = fast_lr_fixed
  tf.summary.scalar("fast_lr", fast_lr)

  # Unroll graph manually
  for timestep in range(seq_len):
    # inp: <float32>[batch_size, pattern_size + hidden_size]
    inp = tf.concat([seqs[:, timestep, :], state], 1)
    last_state = state

    # state: <float32>[batch_size, hidden_size]
    state = tf.matmul(inp, w1) + b1
    boundary_state = state  # "sustained boundary condition," pre-nonlinearity
    state = tf.tanh(state)

    # Apply fast weights
    for _ in range(fast_steps):
      # fw_state: <float32>[batch_size, hidden_size]
      # Squeeze only the trailing singleton axis: an axis-less squeeze would
      # also drop the batch/hidden axis when batch_size or hidden_size is 1.
      fw_state = tf.squeeze(
          tf.matmul(a_memory, tf.expand_dims(state, 2)), axis=[2])

      # Apply plasticity coefficient
      if learn_plasticity_coeffs:
        # fw_state: <float32>[batch_size, hidden_size, 1]
        fw_state = tf.expand_dims(fw_state, 2)

        # pc_tiled: <float32>[batch_size, hidden_size, hidden_size]
        pc_tiled = (tf.ones([batch_size, hidden_size, hidden_size]) *
                    plasticity_coeffs)

        # fw_state: <float32>[batch_size, hidden_size]
        fw_state = tf.squeeze(tf.matmul(pc_tiled, fw_state), axis=[2])

      state = boundary_state + fw_state
      state = layer_norm(state, hidden_size)
      state = tf.tanh(state)

    # Update fast weights matrix
    if use_oja_rule:
      a_memory = a_memory + fast_lr * (tf.multiply(
          tf.expand_dims(state, 1),
          tf.expand_dims(last_state, 2) -
          tf.multiply(tf.expand_dims(state, 1), a_memory)))
    elif update_mem_with_prev_timestep:
      a_memory = (fast_lr * tf.matmul(tf.expand_dims(last_state, 2),
                                      tf.expand_dims(state, 1)) +
                  (1 - fast_lr) * a_memory)
    else:
      # Fast weights update, except only fast_lr is parameterized
      a_memory = (1 - fast_lr) * a_memory + fast_lr * tf.matmul(
          tf.expand_dims(state, 2),  # <float32>[batch_size, hidden_size, 1]
          tf.expand_dims(state, 1))  # <float32>[batch_size, 1, hidden_size]

  # preds: <float32>[batch_size, pattern_size]
  preds = tf.layers.dense(state, pattern_size)
  tf.summary.histogram("preds", preds)

  return preds
def _create_variables(self):
    """Create all trainable variables of the model as attributes on `self`.

    Reads `self.num_users`, `self.num_items`, `self.embedding_size`,
    `self.b_num` and `self.layer_num`. Variables with suffix `_3` / name `W3`,
    `b3` are only created when `self.b_num == 3`.

    Creates:
      * `embedding_P_1..3` (user side) and `embedding_Q` (item side),
        truncated-normal initialised with stddev 0.01.
      * `v_1..3`: `[v_size, 1]` vectors, uniform in +/- sqrt(3 / v_size).
      * `W1..3` / `b1..3`: nothing for `layer_num == 0`, single tensors for
        `layer_num == 1`, and Python lists with one entry per layer
        otherwise (layer i maps `2 * embedding_size / 2**i` inputs to half
        as many outputs).
      * `H`: `[num_users, b_num]` variable of ones.
    """
    with tf.name_scope('linear_embedding'):
        self.embedding_P_1 = tf.Variable(tf.truncated_normal(
            shape=[self.num_users, self.embedding_size],
            mean=0.0, stddev=0.01),
            name='embedding_P_1', dtype=tf.float32)
        self.embedding_P_2 = tf.Variable(tf.truncated_normal(
            shape=[self.num_users, self.embedding_size],
            mean=0.0, stddev=0.01),
            name='embedding_P_2', dtype=tf.float32)
        if self.b_num == 3:  # b_num == 2 or 3
            self.embedding_P_3 = tf.Variable(tf.truncated_normal(
                shape=[self.num_users, self.embedding_size],
                mean=0.0, stddev=0.01),
                name='embedding_P_3', dtype=tf.float32)
        self.embedding_Q = tf.Variable(tf.truncated_normal(
            shape=[self.num_items, self.embedding_size],
            mean=0.0, stddev=0.01),
            name='embedding_Q', dtype=tf.float32)

    # Placeholder scope: no attention variables are created here.
    with tf.name_scope('attention_layer'):
        pass

    # with tf.name_scope('shared_bias'):
    #     self.bias = tf.Variable(tf.zeros([self.num_items, 1]), name='bias', dtype=tf.float32)

    with tf.name_scope('NCF'):
        # the h-vector in original paper
        # v_size = embedding_size plus the width produced by the last of the
        # `layer_num` halving layers created below.
        v_size = self.embedding_size + int(self.embedding_size / (2**(self.layer_num - 1)))
        self.v_1 = tf.Variable(tf.random_uniform(
            [v_size, 1],
            minval=-tf.sqrt(3 / v_size), maxval=tf.sqrt(3 / v_size)),
            name='v_1')
        self.v_2 = tf.Variable(tf.random_uniform(
            [v_size, 1],
            minval=-tf.sqrt(3 / v_size), maxval=tf.sqrt(3 / v_size)),
            name='v_2')
        if self.b_num == 3:
            self.v_3 = tf.Variable(tf.random_uniform(
                [v_size, 1],
                minval=-tf.sqrt(3 / v_size), maxval=tf.sqrt(3 / v_size)),
                name='v_3')

        if self.layer_num == 0:
            pass  # no variable
        elif self.layer_num == 1:
            # view specific
            # NOTE(review): W1 uses the bound sqrt(1 / embedding_size) while
            # W2 / W3 below use sqrt(3 / (2 * embedding_size)) -- confirm
            # whether this difference is intentional.
            self.W1 = tf.Variable(tf.random_uniform(
                shape=[2 * self.embedding_size, self.embedding_size],
                minval=-tf.sqrt(1 / self.embedding_size),
                maxval=tf.sqrt(1 / self.embedding_size)),
                name='W1')
            self.b1 = tf.Variable(tf.zeros([1, self.embedding_size]), dtype=tf.float32, name='b1')
            # add cart specific
            self.W2 = tf.Variable(tf.random_uniform(
                shape=[2 * self.embedding_size, self.embedding_size],
                minval=-tf.sqrt(3 / (2 * self.embedding_size)),
                maxval=tf.sqrt(3 / (2 * self.embedding_size))),
                name='W2')
            self.b2 = tf.Variable(tf.zeros([1, self.embedding_size]), dtype=tf.float32, name='b2')
            # buy specific
            if self.b_num == 3:
                self.W3 = tf.Variable(tf.random_uniform(
                    shape=[2 * self.embedding_size, self.embedding_size],
                    minval=-tf.sqrt(3 / (2 * self.embedding_size)),
                    maxval=tf.sqrt(3 / (2 * self.embedding_size))),
                    name='W3')
                self.b3 = tf.Variable(tf.zeros([1, self.embedding_size]), dtype=tf.float32, name='b3')
        else:
            # Multi-layer case: W*/b* become lists indexed by layer.
            self.W1, self.b1 = [], []
            self.W2, self.b2 = [], []
            if self.b_num == 3:
                self.W3, self.b3 = [], []

            for i in range(self.layer_num):
                # Each layer halves the width, starting from 2 * embedding_size.
                input_size = int(2 * self.embedding_size / (2**i))
                output_size = int(2 * self.embedding_size / (2**(i + 1)))
                self.W1.append(
                    tf.Variable(tf.random_uniform(
                        shape=[input_size, output_size],
                        minval=-tf.sqrt(3 / input_size),
                        maxval=tf.sqrt(3 / input_size)),
                        name='W1_%d' % i))
                self.b1.append(
                    tf.Variable(tf.zeros([1, output_size]), dtype=tf.float32, name='b1_%d' % i))
                self.W2.append(
                    tf.Variable(tf.random_uniform(
                        shape=[input_size, output_size],
                        minval=-tf.sqrt(3 / input_size),
                        maxval=tf.sqrt(3 / input_size)),
                        name='W2_%d' % i))
                self.b2.append(
                    tf.Variable(tf.zeros([1, output_size]), dtype=tf.float32, name='b2_%d' % i))
                if self.b_num == 3:
                    self.W3.append(
                        tf.Variable(tf.random_uniform(
                            shape=[input_size, output_size],
                            minval=-tf.sqrt(3 / input_size),
                            maxval=tf.sqrt(3 / input_size)),
                            name='W3_%d' % i))
                    self.b3.append(
                        tf.Variable(tf.zeros([1, output_size]), dtype=tf.float32, name='b3_%d' % i))

    with tf.name_scope('multi_task_learning'):
        self.H = tf.Variable(tf.ones(shape=[self.num_users, self.b_num]),
                             name='co_relation_h', dtype=tf.float32)
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
  """The `model_fn` for TPUEstimator.

  Builds the classification/regression graph via `create_model` and returns
  the spec for the requested mode. Configuration such as `albert_config`,
  `num_labels`, `task_name`, `init_checkpoint`, `use_tpu`, `learning_rate`
  and `optimizer` is read from the enclosing scope.

  Args:
    features: Dict of input tensors; must contain "input_ids", "input_mask",
      "segment_ids" and "label_ids", and may contain "is_real_example".
    labels: Unused.
    mode: A `tf.estimator.ModeKeys` value.
    params: Unused.

  Returns:
    A `contrib_tpu.TPUEstimatorSpec` for TRAIN, EVAL or prediction.
  """
  tf.logging.info("*** Features ***")
  for name in sorted(features.keys()):
    tf.logging.info(" name = %s, shape = %s", name, features[name].shape)

  input_ids = features["input_ids"]
  input_mask = features["input_mask"]
  segment_ids = features["segment_ids"]
  label_ids = features["label_ids"]
  # Weight of each example in the eval metrics; when the feature is absent
  # every example gets weight 1.
  if "is_real_example" in features:
    is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32)
  else:
    is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

  is_training = (mode == tf.estimator.ModeKeys.TRAIN)

  (total_loss, per_example_loss, probabilities, logits, predictions) = \
      create_model(albert_config, is_training, input_ids, input_mask,
                   segment_ids, label_ids, num_labels,
                   use_one_hot_embeddings, task_name, hub_module)

  tvars = tf.trainable_variables()
  initialized_variable_names = {}
  scaffold_fn = None
  if init_checkpoint:
    (assignment_map, initialized_variable_names
    ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    if use_tpu:
      # On TPU the checkpoint restore is deferred into the scaffold function.
      def tpu_scaffold():
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = tpu_scaffold
    else:
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  tf.logging.info("**** Trainable Variables ****")
  for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
      init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                    init_string)

  output_spec = None
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = optimization.create_optimizer(total_loss, learning_rate,
                                             num_train_steps,
                                             num_warmup_steps, use_tpu,
                                             optimizer)

    output_spec = contrib_tpu.TPUEstimatorSpec(mode=mode,
                                               loss=total_loss,
                                               train_op=train_op,
                                               scaffold_fn=scaffold_fn)
  elif mode == tf.estimator.ModeKeys.EVAL:
    if task_name not in ["sts-b", "cola"]:

      def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        """Compute accuracy and mean loss for classification tasks."""
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        accuracy = tf.metrics.accuracy(labels=label_ids,
                                       predictions=predictions,
                                       weights=is_real_example)
        loss = tf.metrics.mean(values=per_example_loss,
                               weights=is_real_example)
        return {
            "eval_accuracy": accuracy,
            "eval_loss": loss,
        }
    elif task_name == "sts-b":

      def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        """Compute Pearson correlations for STS-B."""
        # Display labels and predictions
        concat1 = contrib_metrics.streaming_concat(logits)
        concat2 = contrib_metrics.streaming_concat(label_ids)

        # Compute Pearson correlation
        pearson = contrib_metrics.streaming_pearson_correlation(
            logits, label_ids, weights=is_real_example)

        # Compute MSE
        # mse = tf.metrics.mean(per_example_loss)
        mse = tf.metrics.mean_squared_error(label_ids,
                                            logits,
                                            weights=is_real_example)

        loss = tf.metrics.mean(values=per_example_loss,
                               weights=is_real_example)

        return {
            "pred": concat1,
            "label_ids": concat2,
            "pearson": pearson,
            "MSE": mse,
            "eval_loss": loss,
        }
    elif task_name == "cola":

      def metric_fn(per_example_loss, label_ids, logits, is_real_example):
        """Compute Matthew's correlations for CoLA."""
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        # https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
        # Pass labels/predictions by keyword: the tf.metrics signature is
        # (labels, predictions, ...), and positional (predictions, labels)
        # would swap the false-positive and false-negative counts. The MCC
        # formula is symmetric in fp/fn, so its value is the same either way.
        tp, tp_op = tf.metrics.true_positives(labels=label_ids,
                                              predictions=predictions,
                                              weights=is_real_example)
        tn, tn_op = tf.metrics.true_negatives(labels=label_ids,
                                              predictions=predictions,
                                              weights=is_real_example)
        fp, fp_op = tf.metrics.false_positives(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)
        fn, fn_op = tf.metrics.false_negatives(labels=label_ids,
                                               predictions=predictions,
                                               weights=is_real_example)

        # Compute Matthew's correlation
        mcc = tf.div_no_nan(
            tp * tn - fp * fn,
            tf.pow((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn), 0.5))

        # Compute accuracy
        accuracy = tf.metrics.accuracy(labels=label_ids,
                                       predictions=predictions,
                                       weights=is_real_example)

        loss = tf.metrics.mean(values=per_example_loss,
                               weights=is_real_example)

        return {
            "matthew_corr": (mcc, tf.group(tp_op, tn_op, fp_op, fn_op)),
            "eval_accuracy": accuracy,
            "eval_loss": loss,
        }

    eval_metrics = (metric_fn, [
        per_example_loss, label_ids, logits, is_real_example
    ])
    output_spec = contrib_tpu.TPUEstimatorSpec(mode=mode,
                                               loss=total_loss,
                                               eval_metrics=eval_metrics,
                                               scaffold_fn=scaffold_fn)
  else:
    output_spec = contrib_tpu.TPUEstimatorSpec(mode=mode,
                                               predictions={
                                                   "probabilities":
                                                       probabilities,
                                                   "predictions":
                                                       predictions
                                               },
                                               scaffold_fn=scaffold_fn)
  return output_spec
def cosine_embedding_single_instance_loss(embeddings,
                                          target_instances_mask,
                                          other_instances_mask,
                                          normalized_embeddings=True,
                                          term_1_2_normalization='individual',
                                          term_0_squared=False,
                                          term_1_squared=False,
                                          return_h_norm=False,
                                          use_first_frame_for_mean=False,
                                          is_background=False,
                                          data_format='channels_first'):
    """Cosine embedding loss terms for a single instance.

    Computes a reference embedding `h` for the target instance, then the
    cosine similarity between the normalised `h` and every pixel embedding.

    Args:
        embeddings: Rank-3 or rank-4 embedding tensor. The embedding axis is
            axis 0 for 'channels_first', otherwise the last axis.
        target_instances_mask: Mask of the pixels of the target instance.
        other_instances_mask: Mask of the pixels the target must differ from.
        normalized_embeddings: Pre-normalised embeddings tensor to use for
            the similarity. `None` or a bool means that none was supplied
            and `embeddings` is l2-normalised here.
        term_1_2_normalization: 'individual' (masked mean of each term),
            or 'none' / 'combined' (per-pixel values selected by the masks).
        term_0_squared: If True, term_0 is `1 - cos**2`, else `1 - cos`.
        term_1_squared: If True, term_1 is `cos**2`, else `relu(cos)`.
        return_h_norm: If True, also return the squeezed normalised `h`.
            Only honoured in the 'none' / 'combined' modes.
        use_first_frame_for_mean: Compute `h` from the first frame only.
        is_background: Use a fixed `h` instead of the masked mean embedding.
        data_format: 'channels_first' or anything else for channels last.

    Returns:
        `(term_0, term_1)`, or `(term_0, term_1, h_norm)` when
        `return_h_norm` is set in the 'none' / 'combined' modes.

    Raises:
        ValueError: If `embeddings` is not rank 3 or 4, or if
            `term_1_2_normalization` is not a known mode.
    """
    # Fail fast: the original `assert '<message>'` could never fire, so an
    # unknown mode silently returned None.
    if term_1_2_normalization not in ('individual', 'none', 'combined'):
        raise ValueError('invalid normalization mode')

    ndims = embeddings.shape.ndims
    if ndims not in (3, 4):
        raise ValueError(
            'embeddings must have rank 3 or 4, got rank {}'.format(ndims))

    if data_format == 'channels_first':
        embedding_axis = 0
        frame_axis = 1
        image_axes = list(range(1, ndims))
    else:
        embedding_axis = ndims - 1
        frame_axis = 0
        image_axes = list(range(0, ndims - 1))

    # expand axis, such that embeddings and instances are in different dimensions
    # create target and other instances pixel masks
    # calculate mean embedding for target pixels
    if use_first_frame_for_mean:
        # FIXME: does not support videos/volumes
        slices = [slice(None)] * 4
        slices.insert(frame_axis, slice(0, 1))
        h = reduce_sum_masked(embeddings[slices],
                              target_instances_mask[slices],
                              axis=image_axes,
                              keepdims=True)
    elif not is_background:
        h = reduce_mean_masked(embeddings,
                               target_instances_mask,
                               axis=image_axes,
                               keepdims=True)
    else:
        # Fixed background reference. The embedding dimension is placed on
        # axis 0 here regardless of `data_format` (kept as in the original).
        num_channels = embeddings.shape[embedding_axis]
        singleton_dims = (1,) * (ndims - 1)
        if term_1_squared:
            # Unit vector along the first embedding dimension.
            h = tf.concat([
                tf.ones((1,) + singleton_dims, dtype=embeddings.dtype),
                tf.zeros((num_channels - 1,) + singleton_dims,
                         dtype=embeddings.dtype)
            ], axis=0)
        else:
            # All entries -1; dtype follows `embeddings` so the product with
            # the embeddings below is well-typed for non-float32 inputs.
            h = tf.ones((num_channels,) + singleton_dims,
                        dtype=embeddings.dtype) * (-1)

    h_norm = tf.nn.l2_normalize(h, dim=embedding_axis)  #, epsilon=1e-4)

    # l2_normalize embeddings -> needed for cos_simliarity
    # A bool (including the default `True`) carries no tensor, so normalise
    # here instead of multiplying `h_norm` with a Python bool.
    if normalized_embeddings is None or isinstance(normalized_embeddings,
                                                   bool):
        embeddings = tf.nn.l2_normalize(embeddings, dim=embedding_axis)
    else:
        embeddings = normalized_embeddings

    # calculate cos_similarity with target mean embedding and all embeddings
    cos_similarity = tf.reduce_sum(h_norm * embeddings,
                                   axis=embedding_axis,
                                   keepdims=True)

    # term_0: target mean embedding and target pixel embeddings should be as similar as possible
    if term_0_squared:
        term_0 = 1 - (cos_similarity**2)
    else:
        term_0 = 1 - cos_similarity

    if term_1_squared:
        # term_1: target mean embedding and other pixel embeddings should be orthogonal (== 0)
        term_1 = cos_similarity**2
    else:
        # term_1: target mean embedding and other pixel embeddings should be far apart (>= 0)
        term_1 = tf.nn.relu(cos_similarity)

    if term_1_2_normalization == 'individual':
        term_0 = tf.expand_dims(
            reduce_mean_masked(term_0, target_instances_mask), axis=0)
        term_1 = tf.expand_dims(
            reduce_mean_masked(term_1, other_instances_mask), axis=0)
        return term_0, term_1

    # 'none' or 'combined': keep per-pixel values selected by the masks.
    term_0 = tf.boolean_mask(term_0, target_instances_mask)
    term_1 = tf.boolean_mask(term_1, other_instances_mask)
    if not return_h_norm:
        return term_0, term_1
    return term_0, term_1, tf.squeeze(h_norm)
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
  """Draw a random factorization order and derive the attention masks.

  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    targets: int64 Tensor in shape [seq_len], target ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected
      for partial prediction.
    perm_size: the length of longest permutation. Could be set to be
      reuse_len. Should not be larger than reuse_len or there will be data
      leaks.
    seq_len: int, sequence length.

  Returns:
    A 5-tuple `(perm_mask, new_targets, target_mask, inputs_k, inputs_q)`.
  """
  # Shuffle positions within the sequence, block-wise by `perm_size`.
  positions = tf.range(seq_len, dtype=tf.int64)
  position_grid = tf.transpose(tf.reshape(positions, [-1, perm_size]))
  shuffled_grid = tf.random_shuffle(position_grid)
  index = tf.reshape(tf.transpose(shuffled_grid), [-1])

  # Token categories: functional ([SEP]/[CLS]) vs. ordinary, masked or not.
  is_sep = tf.equal(inputs, SEP_ID)
  is_cls = tf.equal(inputs, CLS_ID)
  non_func_tokens = tf.logical_not(tf.logical_or(is_sep, is_cls))

  not_masked = tf.logical_not(is_masked)
  non_mask_tokens = tf.logical_and(not_masked, non_func_tokens)
  masked_or_func_tokens = tf.logical_not(non_mask_tokens)

  # Non-masked, non-functional tokens get the smallest index (-1), so that
  # (1) every other position can see them, and
  # (2) they cannot see masked positions, which avoids information leaks.
  smallest_index = -tf.ones([seq_len], dtype=tf.int64)
  rev_index = tf.where(non_mask_tokens, smallest_index, index)

  # `target_mask`: tokens that are masked and not functional.
  # 1: use mask as input and have loss
  # 0: use token (or [SEP], [CLS]) as input and do not have loss
  target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
  target_mask = tf.cast(target_tokens, tf.float32)

  # `perm_mask`: target tokens must not attend to themselves.
  self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

  # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
  # 0: can attend if i > j or j is non-masked
  query_order = self_rev_index[:, None]
  key_order = rev_index[None, :]
  blocked = tf.logical_and(query_order <= key_order, masked_or_func_tokens)
  perm_mask = tf.cast(blocked, tf.float32)

  # new target: [next token] for LM and [curr token] (self) for PLM
  new_targets = tf.concat([inputs[0: 1], targets[: -1]], axis=0)

  # The key stream sees the raw inputs; the query stream gets the target mask.
  inputs_k = inputs
  inputs_q = target_mask

  return perm_mask, new_targets, target_mask, inputs_k, inputs_q
def testMaskDtype(self, dtype):
  """Checks that `get_mask_random` returns a mask of the requested dtype."""
  self._setup_session()
  template_mask = tf.ones((3, 2))
  random_mask = sparse_utils.get_mask_random(template_mask, 0.5, dtype)
  self.assertEqual(random_mask.dtype, dtype)
def sample(self, sample_shape=(), seed=None):
  """Return a float32 tensor of `sample_shape` filled with `self.value`.

  Args:
    sample_shape: Shape of the returned tensor.
    seed: Unused.

  Returns:
    `tf.ones(sample_shape)` scaled by `self.value`.
  """
  del seed  # Unused.
  unit = tf.ones(sample_shape, dtype=tf.float32)
  return unit * self.value
def make_prior(num_latents):
  """Build the prior: a `tfd.Normal` with zero mean and unit scale.

  Args:
    num_latents: Shape argument forwarded to `tf.zeros` / `tf.ones`.

  Returns:
    A `tfd.Normal` distribution with float32 parameters.
  """
  # Zero mean, unit variance prior.
  loc = tf.zeros(shape=(num_latents), dtype=tf.float32)
  scale = tf.ones(shape=(num_latents), dtype=tf.float32)
  return tfd.Normal(loc=loc, scale=scale)
def __init__(self,
             config,
             is_training,
             input_ids,
             input_mask=None,
             token_type_ids=None,
             use_one_hot_embeddings=False,
             scope=None):
  """Constructor for BertModel.

  Args:
    config: `BertConfig` instance.
    is_training: bool. true for training model, false for eval model.
      Controls whether dropout will be applied.
    input_ids: int32 Tensor of shape [batch_size, seq_length].
    input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
    token_type_ids: (optional) int32 Tensor of shape
      [batch_size, seq_length].
    use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
      embeddings or tf.embedding_lookup() for the word embeddings.
    scope: (optional) variable scope. Defaults to "bert".

  Raises:
    ValueError: The config is invalid or one of the input tensor shapes
      is invalid.
  """
  # Work on a private copy so the caller's config is never mutated.
  config = copy.deepcopy(config)
  if not is_training:
    # Evaluation: switch off both dropouts.
    for dropout_field in ("hidden_dropout_prob",
                          "attention_probs_dropout_prob"):
      setattr(config, dropout_field, 0.0)

  batch_size, seq_length = get_shape_list(input_ids, expected_rank=2)
  token_grid_shape = [batch_size, seq_length]

  # Missing mask -> attend everywhere; missing type ids -> all type 0.
  if input_mask is None:
    input_mask = tf.ones(shape=token_grid_shape, dtype=tf.int32)

  if token_type_ids is None:
    token_type_ids = tf.zeros(shape=token_grid_shape, dtype=tf.int32)

  with tf.variable_scope(scope, default_name="bert"):
    with tf.variable_scope("embeddings"):
      # Word-id lookup; also keeps a handle on the embedding table.
      word_embeddings, word_table = embedding_lookup(
          input_ids=input_ids,
          vocab_size=config.vocab_size,
          embedding_size=config.hidden_size,
          initializer_range=config.initializer_range,
          word_embedding_name="word_embeddings",
          use_one_hot_embeddings=use_one_hot_embeddings)
      self.word_embedding_output = word_embeddings
      self.embedding_table = word_table

      # Positional and token-type embeddings are added on top, followed by
      # layer normalisation and dropout.
      self.embedding_output = embedding_postprocessor(
          input_tensor=self.word_embedding_output,
          use_token_type=True,
          token_type_ids=token_type_ids,
          token_type_vocab_size=config.type_vocab_size,
          token_type_embedding_name="token_type_embeddings",
          use_position_embeddings=True,
          position_embedding_name="position_embeddings",
          initializer_range=config.initializer_range,
          max_position_embeddings=config.max_position_embeddings,
          dropout_prob=config.hidden_dropout_prob)

    with tf.variable_scope("encoder"):
      # Expand the 2D [batch_size, seq_length] mask into the 3D
      # [batch_size, seq_length, seq_length] mask applied to the attention
      # scores.
      attention_mask = create_attention_mask_from_input_mask(
          input_ids, input_mask)

      # Stacked transformer; every entry of the result has shape
      # [batch_size, seq_length, hidden_size].
      self.all_encoder_layers = transformer_model(
          input_tensor=self.embedding_output,
          attention_mask=attention_mask,
          hidden_size=config.hidden_size,
          num_hidden_layers=config.num_hidden_layers,
          num_attention_heads=config.num_attention_heads,
          intermediate_size=config.intermediate_size,
          intermediate_act_fn=get_activation(config.hidden_act),
          hidden_dropout_prob=config.hidden_dropout_prob,
          attention_probs_dropout_prob=config.attention_probs_dropout_prob,
          initializer_range=config.initializer_range,
          do_return_all_layers=True)

    # Output of the final encoder layer.
    self.sequence_output = self.all_encoder_layers[-1]

    # The "pooler" reduces [batch_size, seq_length, hidden_size] to
    # [batch_size, hidden_size], giving the fixed-size representation that
    # segment-level (or segment-pair-level) classification needs.
    with tf.variable_scope("pooler"):
      # Pooling simply takes the hidden state of the first token. We assume
      # that this has been pre-trained.
      first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :],
                                      axis=1)
      pooler_init = create_initializer(config.initializer_range)
      self.pooled_output = tf.layers.dense(
          first_token_tensor,
          config.hidden_size,
          activation=tf.tanh,
          kernel_initializer=pooler_init)
def infer_step(logits_so_far, current_hidden):
  """Inference step of LSTM while loop.

  Args:
    logits_so_far: Logits produced by all previous steps.
    current_hidden: Flattened LSTM state, one `(c, h)` pair per layer.

  Returns:
    `(logits_so_far, next_state)`: the logits with this step's output
    appended on axis 1, and the new state flattened to `(c, h)` pairs.
  """
  # Rebuild the LSTMStateTuple structure from the flattened pairs.
  lstm_state = tuple(
      tf.nn.rnn_cell.LSTMStateTuple(c=pair[0], h=pair[1])
      for pair in current_hidden)

  # Run the logits so far through the "top" modality; its parameters must
  # be reused, hence the scope reset below.
  tm = self._problem_hparams.modality['targets']
  reset_scope = tf.variable_scope(
      tf.VariableScope(tf.AUTO_REUSE, ''),
      reuse=tf.AUTO_REUSE,
      auxiliary_name_scope=False)
  top_scope = tf.variable_scope(
      'svg_decoder/{}_modality'.format(tm), reuse=tf.AUTO_REUSE)
  with reset_scope, top_scope:
    samples_so_far = self.hparams.top['targets'](
        logits_so_far, None, self.hparams, self.problem_hparams.vocab_size)

  # Prepending a zero pad effectively shifts the samples right but, unlike
  # shift_right, keeps the last element, so an empty samples_so_far is not
  # empty after padding.
  samples_so_far = tf.concat([zero_pad, samples_so_far], axis=1)
  flat_samples = common_layers.flatten4d3d(samples_so_far)

  # Only the very last sample is the actual input to the rnn.
  shifted_targets = flat_samples[:, -1:, :]

  # Tile the bottleneck along time and append it to the inputs.
  sln_offset = 51 if hparams.condition_on_sln else 0
  bottleneck_width = (
      hparams.bottleneck_bits + hparams.num_categories + sln_offset)
  pre_tile_y = tf.reshape(
      bottleneck,
      [common_layers.shape_list(bottleneck)[0], 1, bottleneck_width])
  overlay_x = tf.tile(
      pre_tile_y, [1, common_layers.shape_list(shifted_targets)[1], 1])
  inputs = tf.concat([shifted_targets, overlay_x], -1)

  # One time step per batch element.
  seq_len_batch = tf.ones([common_layers.shape_list(inputs)[0]])

  # RUN PRE-LSTM LAYER
  with tf.variable_scope('pre_decoder', reuse=tf.AUTO_REUSE):
    inputs = tf.layers.dense(inputs, hparams.hidden_size, name='bottom')
    inputs = tf.nn.tanh(inputs)

  # RUN LSTM
  with tf.variable_scope('lstm_decoder', reuse=tf.AUTO_REUSE):
    next_step, next_state = tf.nn.dynamic_rnn(
        layers,
        inputs,
        seq_len_batch,
        initial_state=lstm_state,
        dtype=tf.float32,
        time_major=False)

  next_step = tf.expand_dims(next_step, [1])
  logits_so_far = tf.concat([logits_so_far, next_step], 1)

  # Flatten the state back into plain (c, h) pairs.
  next_state = tuple((layer_state.c, layer_state.h)
                     for layer_state in next_state)

  return logits_so_far, next_state