def infer(self, features, **kwargs):
  """Decodes `hparams.max_target_length` positions with a `tf.while_loop`.

  The model body is first run once on all-zero targets so that `cache` is
  populated. Each loop iteration then samples one position and scatters the
  sample and its logits into fixed-size buffers.

  Args:
    features: dict of input features, passed through `self.bottom`. The
      "targets" entry of the resulting dict is overwritten during decoding
      and put back afterwards if it was present.
    **kwargs: accepted for interface compatibility; not read here.

  Returns:
    A dict with keys "outputs" and "logits" (the decoded buffers regrouped
    via `utils.unflatten_blocks_nd` / `utils.put_back_blocks_nd`) and
    "scores" / "losses", which are always None.
  """
  with tf.variable_scope("sparse_transformer", reuse=tf.AUTO_REUSE):
    features = self.bottom(features)

  hparams = self.hparams
  decode_length = hparams.max_target_length
  cache, decoding_stats = {}, {}
  saved_targets = features.get("targets")

  # Buffers that the loop fills in one position at a time.
  sample_buffer = tf.zeros(
      (self.batch_size, decode_length, 1, 1), dtype=tf.int32)
  logit_buffer = tf.zeros((self.batch_size, decode_length, self.vocab_size))

  # Run the body once so the cache holds representations of the inputs.
  features["targets"] = sample_buffer
  if "inputs" in features:
    static_input_shape = [
        self.batch_size, hparams.max_length, 1, hparams.hidden_size
    ]
    features["inputs"].set_shape(static_input_shape)
  with tf.variable_scope("sparse_transformer/body", reuse=tf.AUTO_REUSE):
    self.body(
        features,
        decode_step=None,
        cache=cache,
        decoding_stats=decoding_stats)

  def decode_one(step, samples_so_far, logits_so_far, loop_cache, loop_stats):
    """Samples position `step` and writes it into both buffers."""
    step_features = features.copy()
    step_features["targets"] = samples_so_far
    new_sample, new_logit = self.sample(
        step_features,
        decode_step=step,
        cache=loop_cache,
        decoding_stats=loop_stats)

    # The buffers are zero at `step`, so adding a scatter acts as a write.
    sample_delta = tf.scatter_nd(
        indices=[[b, step, 0, 0] for b in range(self.batch_size)],
        updates=new_sample,
        shape=utils.shape_list(samples_so_far))
    updated_samples = samples_so_far + sample_delta
    logit_delta = tf.scatter_nd(
        indices=[[b, step] for b in range(self.batch_size)],
        updates=new_logit,
        shape=utils.shape_list(logits_so_far))
    updated_logits = logits_so_far + logit_delta
    return step + 1, updated_samples, updated_logits, loop_cache, loop_stats

  def keep_going(step, unused_samples, unused_logits, unused_cache,
                 unused_stats):
    """Continues while `step` is still inside the decode window."""
    return step < decode_length

  _, final_result, final_logits, _, decoding_stats = tf.while_loop(
      keep_going,
      decode_one,
      [0, sample_buffer, logit_buffer, cache, decoding_stats],
      back_prop=False,
      parallel_iterations=1)

  # Regroup the flat sequence into query blocks and then undo the blocking.
  query_shape = hparams.query_shape
  blocks_per_dim = [
      size // query for size, query in zip([decode_length], query_shape)
  ]
  block_size = np.prod(query_shape)

  result_dims = utils.shape_list(final_result)
  final_result = tf.reshape(final_result, [result_dims[0], -1, block_size, 1])
  logit_dims = utils.shape_list(final_logits)
  final_logits = tf.reshape(
      final_logits, [logit_dims[0], -1, block_size, logit_dims[-1]])

  final_result = utils.put_back_blocks_nd(
      utils.unflatten_blocks_nd(final_result, blocks_per_dim), query_shape)
  final_logits = utils.put_back_blocks_nd(
      utils.unflatten_blocks_nd(final_logits, blocks_per_dim), query_shape)

  for stat_name, stat_value in decoding_stats.items():
    tf.summary.scalar("decodes/%s" % stat_name, stat_value / decode_length)

  # Put the caller's targets back if there were any.
  if saved_targets is not None:
    features["targets"] = saved_targets

  return {
      "outputs": final_result,
      "scores": None,
      "logits": final_logits,
      "losses": None,
  }
def lstm_decoder_infer(self, inputs, sequence_length, hparams, clss, train,
                       initial_state=None, bottleneck=None):
  """Runs the LSTM decoder step by step inside a `tf.while_loop`.

  Used in predict mode: each iteration feeds the previous step's output
  (passed through the 'targets' top modality) plus the conditioning
  bottleneck into the LSTM and appends the new step to `logits_so_far`.

  Args:
    inputs: tensor whose leading dimension supplies the batch size.
    sequence_length: sequence lengths; only read when
      `hparams.condition_on_sln` is set, in which case it is one-hot encoded
      and appended to the bottleneck.
    hparams: hyperparameters. Reads `num_hidden_layers`, `use_cls`,
      `condition_on_sln`, `num_categories`, `bottleneck_bits` and
      `hidden_size`.
    clss: class ids; zeroed out when `hparams.use_cls` is false.
    train: forwarded to `self.lstm_cell`.
    initial_state: required. Iterable of LSTM state tuples (with `.c` and
      `.h`), one per layer.
    bottleneck: conditioning tensor concatenated to every step's input.

  Returns:
    A `(logits, final_state)` pair: the accumulated outputs flattened with
    `common_layers.flatten4d3d`, and the final state as a tuple of `(c, h)`
    pairs.

  Raises:
    ValueError: if `initial_state` is None.
  """
  # Fail before any graph ops are created.
  if initial_state is None:
    raise ValueError('initial state should be init from bottleneck!')

  # Single source of truth for the decode length. It is also the depth of the
  # one-hot sequence-length encoding appended to the bottleneck below.
  max_decode_length = 51
  batch_size = common_layers.shape_list(inputs)[0]
  zero_pad, logits_so_far = self.create_initial_input_for_decode(batch_size)

  layers = contrib_rnn.MultiRNNCell([
      self.lstm_cell(hparams, train) for _ in range(hparams.num_hidden_layers)
  ])

  # Append the one-hot class (and optionally the one-hot sequence length) to
  # the bottleneck, which is given to the decoder at every step.
  clss = tf.reshape(clss, [-1])
  if not hparams.use_cls:
    clss = tf.zeros_like(clss)
  if hparams.condition_on_sln:
    sln = tf.reshape(sequence_length, [-1])
    bottleneck = tf.concat(
        (bottleneck, tf.one_hot(clss, hparams.num_categories),
         tf.one_hot(sln, max_decode_length)), -1)
    sln_offset = max_decode_length
  else:
    bottleneck = tf.concat(
        (bottleneck, tf.one_hot(clss, hparams.num_categories)), -1)
    sln_offset = 0
  # Width of the conditioning vector built above.
  conditioning_size = (
      hparams.bottleneck_bits + hparams.num_categories + sln_offset)

  def infer_step(logits_so_far, current_hidden):
    """Inference step of LSTM while loop."""
    # Unflatten hidden: the loop carries plain (c, h) tuples.
    current_hidden = tuple(
        tf.nn.rnn_cell.LSTMStateTuple(c=s[0], h=s[1]) for s in current_hidden)

    # Put logits_so_far through top; the top parameters must be reused, so
    # the variable scope is reset to the root before entering the top scope.
    tm = self._problem_hparams.modality['targets']
    reset_scope = tf.variable_scope(
        tf.VariableScope(tf.AUTO_REUSE, ''),
        reuse=tf.AUTO_REUSE,
        auxiliary_name_scope=False)
    top_scope = tf.variable_scope(
        'svg_decoder/{}_modality'.format(tm), reuse=tf.AUTO_REUSE)
    with reset_scope, top_scope:
      samples_so_far = self.hparams.top['targets'](
          logits_so_far, None, self.hparams, self.problem_hparams.vocab_size)

    # Prepending a zero pad shifts the samples right but, unlike shift_right,
    # keeps the last element, so an empty samples_so_far is non-empty after
    # padding.
    samples_so_far = tf.concat([zero_pad, samples_so_far], axis=1)
    shifted_targets = common_layers.flatten4d3d(samples_so_far)
    # Only the most recent step is the actual input to the RNN.
    shifted_targets = shifted_targets[:, -1:, :]

    # Tile the bottleneck along time and append it to the inputs.
    pre_tile_y = tf.reshape(
        bottleneck,
        [common_layers.shape_list(bottleneck)[0], 1, conditioning_size])
    overlay_x = tf.tile(
        pre_tile_y, [1, common_layers.shape_list(shifted_targets)[1], 1])
    step_inputs = tf.concat([shifted_targets, overlay_x], -1)
    seq_len_batch = tf.ones([common_layers.shape_list(step_inputs)[0]])

    # Pre-LSTM dense layer.
    with tf.variable_scope('pre_decoder', reuse=tf.AUTO_REUSE):
      step_inputs = tf.layers.dense(
          step_inputs, hparams.hidden_size, name='bottom')
      step_inputs = tf.nn.tanh(step_inputs)

    # One LSTM step, continuing from the carried state.
    with tf.variable_scope('lstm_decoder', reuse=tf.AUTO_REUSE):
      next_step, next_state = tf.nn.dynamic_rnn(
          layers,
          step_inputs,
          seq_len_batch,
          initial_state=current_hidden,
          dtype=tf.float32,
          time_major=False)

    next_step = tf.expand_dims(next_step, [1])
    logits_so_far = tf.concat([logits_so_far, next_step], 1)

    # Flatten state back into plain tuples for the loop.
    next_state = tuple((s.c, s.h) for s in next_state)
    return logits_so_far, next_state

  def while_exit_cond(logits_so_far, unused_current_hidden):
    """Keeps looping until `max_decode_length` steps have been produced."""
    length = common_layers.shape_list(logits_so_far)[1]
    return length < max_decode_length

  # The state passed through the loop must be flattened.
  initial_state = tuple((s.c, s.h) for s in initial_state)

  logits, final_state = tf.while_loop(
      while_exit_cond,
      infer_step, [logits_so_far, initial_state],
      shape_invariants=[
          # The time dimension grows by one each iteration.
          tf.TensorShape([None, None, 1, hparams.hidden_size]),
          tuple((s[0].get_shape(), s[1].get_shape()) for s in initial_state),
      ],
      back_prop=False,
      parallel_iterations=1)

  # Logits should be returned in 3d mode.
  logits = common_layers.flatten4d3d(logits)
  return logits, final_state
def sorted_non_max_suppression_padded(scores, boxes, max_output_size,
                                      iou_threshold):
  """A wrapper that handles non-maximum suppression.

  Assumption:
    * The boxes are sorted by scores unless the box is a dot (all coordinates
      are zero).
    * Boxes with higher scores can be used to suppress boxes with lower scores.

  The overall design of the algorithm is to handle boxes tile-by-tile:

  boxes = boxes.pad_to_multiply_of(tile_size)
  num_tiles = len(boxes) // tile_size
  output_boxes = []
  for i in range(num_tiles):
    box_tile = boxes[i*tile_size : (i+1)*tile_size]
    for j in range(i - 1):
      suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
      iou = bbox_overlap(box_tile, suppressing_tile)
      # if the box is suppressed in iou, clear it to a dot
      box_tile *= _update_boxes(iou)
    # Iteratively handle the diagonal tile.
    iou = _box_overlap(box_tile, box_tile)
    iou_changed = True
    while iou_changed:
      # boxes that are not suppressed by anything else
      suppressing_boxes = _get_suppressing_boxes(iou)
      # boxes that are suppressed by suppressing_boxes
      suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
      # clear iou to 0 for boxes that are suppressed, as they cannot be used
      # to suppress other boxes any more
      new_iou = _clear_iou(iou, suppressed_boxes)
      iou_changed = (new_iou != iou)
      iou = new_iou
    # remaining boxes that can still suppress others, are selected boxes.
    output_boxes.append(_get_suppressing_boxes(iou))
    if len(output_boxes) >= max_output_size:
      break

  Args:
    scores: a tensor with a shape of [batch_size, anchors].
    boxes: a tensor with a shape of [batch_size, anchors, 4].
    max_output_size: a scalar integer `Tensor` representing the maximum number
      of boxes to be selected by non max suppression.
    iou_threshold: a float representing the threshold for deciding whether boxes
      overlap too much with respect to IOU.

  Returns:
    nms_scores: a float32 tensor with a shape of
      [batch_size, max_output_size]. Entries past the number of boxes
      selected for a batch element are zero.
    nms_proposals: a float32 tensor with a shape of
      [batch_size, max_output_size, 4]. Entries past the number of boxes
      selected for a batch element are zero.
  """
  batch_size = tf.shape(boxes)[0]
  num_boxes = tf.shape(boxes)[1]
  # Number of zero rows needed to round the box count up to a whole number of
  # tiles.
  pad = tf.cast(
      tf.ceil(tf.cast(num_boxes, tf.float32) / NMS_TILE_SIZE),
      tf.int32) * NMS_TILE_SIZE - num_boxes
  boxes = tf.pad(tf.cast(boxes, tf.float32), [[0, 0], [0, pad], [0, 0]])
  scores = tf.pad(tf.cast(scores, tf.float32), [[0, 0], [0, pad]])
  num_boxes += pad

  def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
    # Continue while some batch element still has fewer than max_output_size
    # selections and there are tiles left to process.
    return tf.logical_and(
        tf.reduce_min(output_size) < max_output_size,
        idx < num_boxes // NMS_TILE_SIZE)

  selected_boxes, _, output_size, _ = tf.while_loop(
      _loop_cond, _suppression_loop_body, [
          boxes, iou_threshold,
          tf.zeros([batch_size], tf.int32),
          tf.constant(0)
      ])
  # A box counts as kept when any of its coordinates is still > 0. Weighting
  # the kept mask by a descending range (num_boxes .. 1) makes top_k return the
  # kept boxes with the smallest positions first; subtracting from num_boxes
  # turns those weights back into positions.
  idx = num_boxes - tf.cast(
      tf.nn.top_k(
          tf.cast(tf.reduce_any(selected_boxes > 0, [2]), tf.int32) *
          tf.expand_dims(tf.range(num_boxes, 0, -1), 0), max_output_size)[0],
      tf.int32)
  # Slots with no kept box have weight 0 and map to num_boxes; clamp them to a
  # valid position (they are masked out below).
  idx = tf.minimum(idx, num_boxes - 1)
  # Offset each row's positions so they index the flattened [batch * boxes]
  # tensors.
  idx = tf.reshape(
      idx + tf.reshape(tf.range(batch_size) * num_boxes, [-1, 1]), [-1])
  # Gather from the padded input boxes (not from selected_boxes).
  boxes = tf.reshape(
      tf.gather(tf.reshape(boxes, [-1, 4]), idx),
      [batch_size, max_output_size, 4])
  # Zero out slots at or beyond each batch element's output_size.
  boxes = boxes * tf.cast(
      tf.reshape(tf.range(max_output_size), [1, -1, 1]) < tf.reshape(
          output_size, [-1, 1, 1]), boxes.dtype)
  scores = tf.reshape(
      tf.gather(tf.reshape(scores, [-1, 1]), idx),
      [batch_size, max_output_size])
  scores = scores * tf.cast(
      tf.reshape(tf.range(max_output_size), [1, -1]) < tf.reshape(
          output_size, [-1, 1]), scores.dtype)
  return scores, boxes
def compute_gradients(self,
                      loss,
                      var_list,
                      gate_gradients=GATE_OP,
                      aggregation_method=None,
                      colocate_gradients_with_ops=False,
                      grad_loss=None,
                      gradient_tape=None):
  """Computes noised, microbatch-averaged gradients.

  The loss is reshaped to `[num_microbatches, -1]`. The gradient of each
  microbatch's mean loss is accumulated into `self._dp_sum_query`, and the
  noised sum is divided by the number of microbatches.

  Args:
    loss: a callable returning the vector loss (eager mode), or the vector
      loss tensor itself (graph mode).
    var_list: variables to differentiate with respect to. In graph mode,
      None selects the trainable (resource) variables.
    gate_gradients: forwarded to the parent optimizer (graph mode only).
    aggregation_method: forwarded to the parent optimizer (graph mode only).
    colocate_gradients_with_ops: forwarded to the parent optimizer (graph
      mode only).
    grad_loss: forwarded to the parent optimizer (graph mode only).
    gradient_tape: the tape used to differentiate in eager mode; must not be
      given in graph mode.

  Returns:
    A list of `(gradient, variable)` pairs.

  Raises:
    ValueError: if `loss` is callable and no tape is passed, or `loss` is not
      callable and a tape is passed.
  """
  if callable(loss):
    # Eager mode: a gradient tape is mandatory.
    if not gradient_tape:
      raise ValueError('When in Eager mode, a tape needs to be passed.')

    per_example_loss = loss()
    if self._num_microbatches is None:
      self._num_microbatches = tf.shape(input=per_example_loss)[0]
    accumulated = self._dp_sum_query.initial_sample_state(var_list)
    losses_by_microbatch = tf.reshape(per_example_loss,
                                      [self._num_microbatches, -1])
    query_params = self._dp_sum_query.derive_sample_params(self._global_state)

    def accumulate_eager(index, state):
      """Adds microbatch `index`'s gradients to the running query state."""
      mean_loss = tf.reduce_mean(
          input_tensor=tf.gather(losses_by_microbatch, [index]))
      tape_grads = gradient_tape.gradient(mean_loss, var_list)
      return self._dp_sum_query.accumulate_record(query_params, state,
                                                  tape_grads)

    for microbatch in range(self._num_microbatches):
      accumulated = accumulate_eager(microbatch, accumulated)

    noised_sums, self._global_state = self._dp_sum_query.get_noised_result(
        accumulated, self._global_state)
    averaged = tf.nest.map_structure(
        lambda total: total / tf.cast(self._num_microbatches, tf.float32),
        noised_sums)
    return list(zip(averaged, var_list))

  # Graph mode: a gradient tape must not be supplied.
  if gradient_tape:
    raise ValueError('When in graph mode, a tape should not be passed.')

  # Note: sampling each microbatch from the appropriate binomial distribution
  # would be closer to correct i.i.d. sampling of records, although still not
  # quite correct since it would sample the dataset without replacement.
  if self._num_microbatches is None:
    self._num_microbatches = tf.shape(input=loss)[0]
  losses_by_microbatch = tf.reshape(loss, [self._num_microbatches, -1])
  query_params = self._dp_sum_query.derive_sample_params(self._global_state)

  if var_list is None:
    var_list = (
        tf.trainable_variables() +
        tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

  def accumulate_graph(index, state):
    """Adds microbatch `index`'s gradients to the running query state."""
    mean_loss = tf.reduce_mean(
        input_tensor=tf.gather(losses_by_microbatch, [index]))
    parent_grads, _ = zip(*super(cls, self).compute_gradients(
        mean_loss, var_list, gate_gradients, aggregation_method,
        colocate_gradients_with_ops, grad_loss))
    # Variables the loss does not depend on contribute a zero gradient.
    dense_grads = [
        tf.zeros_like(variable) if grad is None else grad
        for (grad, variable) in zip(list(parent_grads), var_list)
    ]
    return self._dp_sum_query.accumulate_record(query_params, state,
                                                dense_grads)

  accumulated = self._dp_sum_query.initial_sample_state(var_list)

  if self._unroll_microbatches:
    for microbatch in range(self._num_microbatches):
      accumulated = accumulate_graph(microbatch, accumulated)
  else:
    # Use of while_loop here requires that the query state be a nested
    # structure of tensors; an arbitrary opaque type would be preferable.
    def more_microbatches(index, unused_state):
      return tf.less(index, self._num_microbatches)

    def next_microbatch(index, state):
      following = tf.add(index, 1)
      return [following, accumulate_graph(index, state)]

    first = tf.constant(0)
    _, accumulated = tf.while_loop(
        cond=more_microbatches,
        body=next_microbatch,
        loop_vars=[first, accumulated])

  noised_sums, self._global_state = self._dp_sum_query.get_noised_result(
      accumulated, self._global_state)
  averaged = tf.nest.map_structure(
      lambda total: tf.truediv(
          total, tf.cast(self._num_microbatches, tf.float32)),
      noised_sums)
  return list(zip(averaged, var_list))
def task_metalearn(inp, reuse=True):
  """Run meta learning for one task of the meta-batch.

  Takes one gradient step on the support data `(inputa, labela)` to get fast
  weights and evaluates them on the query data `(inputb, labelb)`. It then
  runs up to `num_updates - 1` further gradient steps inside a
  `tf.while_loop` and evaluates the final fast weights once more.

  Args:
    inp: tuple `(inputa, inputb, labela, labelb)`.
    reuse: forwarded to the first `self.forward` call; every later call
      always reuses variables.

  Returns:
    `[task_outputa, task_outputbs, task_lossa, task_lossesb, task_msesb]`.
    Each of the three lists holds two entries: after the first step and
    after the loop.
  """
  TRAIN = 'train' in prefix  # pylint: disable=invalid-name
  inputa, inputb, labela, labelb = inp
  task_outputbs, task_lossesb, task_msesb = [], [], []

  # Fix the key order once so values and keys are always paired consistently.
  weight_keys = list(weights.keys())

  # Support prediction and loss.
  task_outputa = self.forward(
      inputa, weights, reuse=reuse)  # only not reuse on the first iter
  task_lossa = self.loss_func(task_outputa, labela)

  # First inner step: theta_pi = theta - alpha * grads.
  grads = tf.gradients(task_lossa, [weights[key] for key in weight_keys])
  if FLAGS.stop_grad:
    grads = [tf.stop_gradient(grad) for grad in grads]
  gradients = dict(zip(weight_keys, grads))
  fast_weights = {
      key: weights[key] - self.update_lr * gradients[key]
      for key in weight_keys
  }

  # Use theta_pi to forward meta-test. The same loss feeds both result lists,
  # so build it once.
  output = self.forward(inputb, fast_weights, reuse=True)
  task_outputbs.append(output)
  meta_test_loss = self.loss_func(output, labelb)
  task_msesb.append(meta_test_loss)
  task_lossesb.append(meta_test_loss)

  def while_body(fast_weights_values):
    """Update params."""
    loss = self.loss_func(
        self.forward(
            inputa, dict(zip(weight_keys, fast_weights_values)), reuse=True),
        labela)
    grads = tf.gradients(loss, fast_weights_values)
    return [
        v - self.update_lr * g for v, g in zip(fast_weights_values, grads)
    ]

  # Pass an explicit list: on Python 3, dict.values() is a view object that
  # tf.nest does not treat as a sequence of tensors.
  initial_values = [fast_weights[key] for key in weight_keys]
  fast_weights_values = tf.while_loop(
      lambda _: True,
      while_body,
      loop_vars=[initial_values],
      maximum_iterations=num_updates - 1,
      back_prop=TRAIN)
  fast_weights = dict(zip(weight_keys, fast_weights_values))

  # Meta-test with the final fast weights.
  output = self.forward(inputb, fast_weights, reuse=True)
  task_outputbs.append(output)
  meta_test_loss = self.loss_func(output, labelb)
  task_msesb.append(meta_test_loss)
  task_lossesb.append(meta_test_loss)

  task_output = [
      task_outputa, task_outputbs, task_lossa, task_lossesb, task_msesb
  ]
  return task_output
def sample_sequence(
    hparams,
    length,
    start_token=None,
    batch_size=None,
    context=None,
    temperature=1,
    top_k=0,
    top_p=0.0,
):
  """Samples `length` tokens from the language model with a `tf.while_loop`.

  Exactly one of `start_token` and `context` must be given. With
  `start_token`, the context becomes a `[batch_size, 1]` tensor filled with
  that token.

  Args:
    hparams: model hyperparameters; `n_vocab` bounds the logits that are used.
    length: maximum number of loop iterations (tokens appended).
    start_token: token id used to seed every sequence.
    batch_size: batch size used for the seed tensor and the static shapes.
    context: token tensor to continue from.
    temperature: divisor applied to the last-position logits.
    top_k: forwarded to `top_k_logits` when `top_p` is not positive.
    top_p: if positive, `top_p_logits` is used instead of `top_k_logits`.

  Returns:
    The context with the sampled tokens concatenated along axis 1.
  """
  if start_token is not None:
    assert (context is None), 'Specify exactly one of start_token and context!'
    context = tf.fill([batch_size, 1], start_token)
  else:
    assert (context is
            not None), 'Specify exactly one of start_token and context!'

  def run_model(model_hparams, tokens, past=None):
    """Runs the model and returns its logits and key/value cache."""
    outputs = gpt2_model.model(
        hparams=model_hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
    vocab_logits = outputs['logits'][:, :, :model_hparams.n_vocab]
    cache = outputs['present']
    cache.set_shape(
        gpt2_model.past_shape(hparams=model_hparams, batch_size=batch_size))
    return {'logits': vocab_logits, 'presents': cache}

  with tf.name_scope('sample_sequence'):
    # Prime the cache with everything except the last context token; that
    # token is fed as the first `prev` inside the loop.
    primed = run_model(hparams, context[:, :-1])

    def sample_next(past, prev, output):
      """Samples one token and extends the cache and the output."""
      stepped = run_model(hparams, prev[:, tf.newaxis], past=past)
      last_logits = stepped['logits'][:, -1, :] / tf.cast(
          temperature, tf.float32)
      if top_p > 0.0:
        filtered = top_p_logits(last_logits, p=top_p)
      else:
        filtered = top_k_logits(last_logits, k=top_k)
      drawn = tf.random.categorical(filtered, num_samples=1, dtype=tf.int32)
      grown_past = tf.concat([past, stepped['presents']], axis=-2)
      next_prev = tf.squeeze(drawn, axis=[1])
      grown_output = tf.concat([output, drawn], axis=1)
      return [grown_past, next_prev, grown_output]

    def always(*args):
      """The loop is bounded by `maximum_iterations` only."""
      return True

    initial_vars = [primed['presents'], context[:, -1], context]
    invariants = [
        tf.TensorShape(
            gpt2_model.past_shape(hparams=hparams, batch_size=batch_size)),
        tf.TensorShape([batch_size]),
        tf.TensorShape([batch_size, None]),
    ]
    _, _, tokens = tf.while_loop(
        cond=always,
        body=sample_next,
        maximum_iterations=length,
        loop_vars=initial_vars,
        shape_invariants=invariants,
        back_prop=False,
    )
    return tokens
def _build_adapted_parameters(
    self,
    inputs,
    labels,
    initial_parameters,
    num_steps,
    back_prop=False,
    parallel_iterations=1,
    shuffle=True,
):
    """Adapts model parameters with `num_steps` gradient steps in a while loop.

    Each step gathers one batch of samples, builds a copy of the model on the
    current parameters, and derives new parameters from that model's loss.

    Parameters
    ----------
    inputs : Tensor <float32> [None, ...]
        Inputs of the samples used for adaptation.
    labels : Tensor <float32> [None, num_classes]
        Labels of the samples used for adaptation.
    initial_parameters : dict of Tensors
        Parameters the adaptation starts from.
    num_steps : int or Tensor <int32> []
        Number of gradient steps to take.
    back_prop : bool, optional (default: False)
        Whether backprop is allowed through the adapted parameters.
    parallel_iterations : int, optional (default=1)
        Passed to `tf.while_loop`.
    shuffle : bool, optional (default=True)
        Whether to shuffle the sample indices before batching.

    Returns
    -------
    adapted_parameters : dict of Tensors
        Parameters after the final adaptation step.
    """
    # Fall back to using every input when no batch size is configured.
    batch_size = self.batch_size or tf.shape(inputs)[0]

    # Sample indices, wrapped around the number of inputs.
    # <int32> [batch_size * num_steps].
    flat_indices = tf.math.mod(
        tf.range(batch_size * num_steps, dtype=tf.int32),
        tf.shape(inputs)[0],
    )
    if shuffle:
        flat_indices = tf.random.shuffle(flat_indices)
    # <int32> [num_steps, batch_size].
    indices_per_step = tf.reshape(flat_indices, shape=(num_steps, batch_size))

    def has_steps_left(step, _unused_params):
        return tf.less(step, num_steps)

    def adapt_once(step, parameters):
        step_indices = indices_per_step[step]
        batch_inputs = tf.gather(inputs, step_indices, axis=0)
        batch_labels = tf.gather(labels, step_indices, axis=0)
        # Build a model that runs on the current parameters.
        with utils.custom_make_variable(parameters, self.model.name):
            adapted_model = self.model_builder()
            self.inner_adapted_models.append(adapted_model)
            loss = adapted_model.loss(batch_inputs, batch_labels)
        # Derive the next parameters from the loss.
        updated = utils.build_new_parameters(
            loss,
            parameters,
            optimizer=self.inner_optimizer,
            first_order=self.first_order,
        )
        return [tf.add(step, 1), updated]

    _, adapted_parameters = tf.while_loop(
        cond=has_steps_left,
        body=adapt_once,
        loop_vars=[tf.constant(0), initial_parameters],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        name="adapt",
    )
    return adapted_parameters
def multiply2n_ragged(tensor1, tensor2):
    """Multiplies two ragged tensors row by row with `tf.matmul`.

    For every outer row, the matching slices of both inputs are extracted with
    `innerloop_processing`, multiplied, and written into TensorArrays that are
    concatenated once the loop finishes.

    Args:
        tensor1: first operand. Index 2 is read as the outer row-splits vector;
            presumably a (values, inner_splits, outer_splits) triple -- TODO
            confirm against `innerloop_processing`.
        tensor2: second operand, indexed the same way. Its outer row splits are
            read at the same positions as those of `tensor1`.

    Returns:
        A tuple `(values, inner_splits, outer_splits)`: the flattened products,
        the de-duplicated inner row splits, and the outer row splits.
    """
    # Loop state: row counter and running total of result rows.
    outerloop_counter = tf.constant(0, dtype=tf.int32)
    carry_on = tf.constant(0, dtype=tf.int32)
    # Dynamically sized accumulators for the flattened values and both levels
    # of row splits.
    taValues = tf.TensorArray(tf.float32,
                              size=0,
                              dynamic_size=True,
                              clear_after_read=False,
                              infer_shape=False)
    taL2Splits = tf.TensorArray(tf.int32,
                                size=0,
                                dynamic_size=True,
                                clear_after_read=False,
                                infer_shape=False)
    taL1Splits = tf.TensorArray(tf.int32,
                                size=0,
                                dynamic_size=True,
                                clear_after_read=False,
                                infer_shape=False)
    # Only the outer row splits need a leading 0 written up front.
    taL1Splits = taL1Splits.write(0, [0])
    # Wrap the helpers as tf.functions before using them in the loop body.
    innerloop_processing_graphed = tf.function(innerloop_processing)
    generateL1Tensor_writeback_graphed = tf.function(
        generateL1Tensor_writeback)

    def outerloop_cond(counter, input1, input2, taValues, taL2Splits,
                       taL1Splits, carry_on):
        # The row-splits vector read at input1[2] has one more entry than
        # there are rows to visit.
        value = tf.shape(input1[2])[0] - 1
        return counter < value

    def outloop_body(counter, input1, input2, taValues, taL2Splits,
                     taL1Splits, carry_on):
        # Begin/end positions of the current outer row: entries `counter` and
        # `counter + 1` of each operand's outer row-splits vector.
        l1_comp_begin = input1[2][counter]
        l1_comp_end = input1[2][counter + 1]
        l1_comp2_begin = input2[2][counter]
        l1_comp2_end = input2[2][counter + 1]
        # Retrieve the data of the selected row from each operand.
        comp = innerloop_processing_graphed(l1_comp_begin, l1_comp_end,
                                            input1)
        comp2 = innerloop_processing_graphed(l1_comp2_begin, l1_comp2_end,
                                             input2)
        # The operation applied per row (comp2 is not transposed first).
        multiply = tf.matmul(comp, comp2)
        # Shape of the product, needed to write it back in ragged form.
        myshape = tf.shape(multiply)
        # Number of values written so far. TensorArray.concat raises on an
        # empty array, so the size is checked before calling it.
        offset = tf.cond(taValues.size() > 0,
                         lambda: tf.shape(taValues.concat())[0], lambda: 0)
        # Inner row splits of this row's result.
        l2v = generateL1Tensor_writeback_graphed(offset, myshape[1],
                                                 myshape[0])
        taL2Splits = taL2Splits.write(counter, l2v)
        # Flattened elements of this row's result.
        taValues = taValues.write(counter, tf.reshape(multiply, [-1]))
        # Running total of result rows; becomes the next outer row split.
        carry_on = carry_on + myshape[0]
        taL1Splits = taL1Splits.write(counter + 1, [carry_on])
        # Only advance the counter after this iteration's tensors exist.
        with tf.control_dependencies(
                [comp, comp2, myshape, l2v, carry_on, multiply]):
            counter = counter + 1
        return counter, input1, input2, taValues, taL2Splits, taL1Splits, carry_on

    with tf.name_scope("RaggedMultiply"):
        outerloop_finalcounter, _, _, ta1, ta2, ta3, _ = tf.while_loop(
            outerloop_cond,
            outloop_body, [
                outerloop_counter, tensor1, tensor2, taValues, taL2Splits,
                taL1Splits, carry_on
            ],
            back_prop=True)
    # Consecutive rows can write the same split value, so duplicates are
    # removed from the inner row splits.
    uinquie_ta2, _ = tf.unique(ta2.concat())
    t1 = ta1.concat()
    t3 = ta3.concat()
    final_values = t1, uinquie_ta2, t3
    return final_values
def _create_cross_entropy_action_tensors(self,
                                         num_samples=200,
                                         top_k_portion=0.5):
  """Create tensorflow operations for cross_entropy max_actions.

  Builds three placeholders (`_dynamic_batch_size`, `_action_init_tensor`,
  `_tolerance_tensor`) and a `tf.while_loop` that repeatedly samples actions
  around a per-state mean, scores them with the Q-function network, and
  refits the mean and spread to the top-scoring samples. The loop's outputs
  are stored in `self.cost_optimizer`.

  Args:
    num_samples: number of action samples drawn per state and iteration.
    top_k_portion: fraction of `num_samples` kept as the top-k set.
  """
  top_k_num = int(top_k_portion * num_samples)

  self._dynamic_batch_size = tf.placeholder(
      dtype=tf.int32, name="dynamic_batch_size")
  self._action_init_tensor = tf.placeholder(
      dtype=tf.float32,
      name="action_init_tensor",
      shape=(None, self.action_dim))
  self._tolerance_tensor = tf.placeholder(
      dtype=tf.float32, name="tolerance_tensor", shape=())

  # Initial loop state: mean at the supplied actions, unit spread, infinite
  # top-k value, and zero-filled top-k samples.
  sample_mean_init = self._action_init_tensor
  sample_covariance_diag_init = tf.ones_like(self._action_init_tensor)
  top_k_value_init = tf.constant(
      [np.inf]) * tf.ones(shape=(self._dynamic_batch_size, 1))
  top_k_action_samples_init = tf.tile(
      tf.expand_dims(tf.zeros_like(self._action_init_tensor), axis=1),
      [1, top_k_num, 1])

  # Standard-normal noise source over the action dimensions.
  random_sampler = tfp.distributions.MultivariateNormalDiag(
      loc=np.zeros(self.action_dim), scale_diag=np.ones(self.action_dim))

  def cond_cross_entropy(itr, cond_terminate, sample_mean,
                         sample_covariance_diag, top_k_value,
                         top_k_action_samples):
    """Continue until the iteration cap is hit or the tolerance is met."""
    del sample_mean, sample_covariance_diag, top_k_value, top_k_action_samples
    cond_1 = tf.math.less(itr, self.action_maximization_iterations)
    return tf.math.logical_and(cond_1, tf.logical_not(cond_terminate))

  def body_cross_entropy(itr, cond_terminate, sample_mean,
                         sample_covariance_diag, top_k_value,
                         top_k_action_samples):
    """Function for cross entropy search of actions."""
    del top_k_action_samples
    top_k_value_prev = top_k_value
    # Repeat each state's mean and spread once per sample.
    batch_sample_mean = tf.reshape(
        tf.tile(sample_mean, [1, num_samples]),
        [self._dynamic_batch_size * num_samples, self.action_dim])
    batch_sample_covariance_diag = tf.reshape(
        tf.tile(sample_covariance_diag, [1, num_samples]),
        [self._dynamic_batch_size * num_samples, self.action_dim])

    # NOTE(review): sample_covariance_diag is set from reduce_variance below
    # but multiplies unit-normal noise directly here -- confirm whether a
    # standard deviation (sqrt) was intended.
    action_samples = self._action_projection(
        batch_sample_mean + batch_sample_covariance_diag * tf.cast(
            random_sampler.sample(
                sample_shape=[self._dynamic_batch_size * num_samples]),
            dtype=tf.float32))

    # Repeat each state once per sample so states and actions line up.
    state_samples = tf.reshape(
        tf.tile(self._state_tensor, [1, num_samples]),
        [self._dynamic_batch_size * num_samples, self.state_dim])
    action_samples = tf.reshape(
        action_samples,
        [self._dynamic_batch_size * num_samples, self.action_dim])
    values = tf.reshape(
        self._build_q_function_net(state_samples, action_samples),
        [self._dynamic_batch_size, num_samples])

    # everything is in batch mode
    # Sample indices of the top_k_num highest values for each state.
    top_k_index = tf.argsort(
        values, axis=1, direction="DESCENDING")[:, 0:top_k_num]
    top_k_index_1d = tf.reshape(top_k_index,
                                [self._dynamic_batch_size * top_k_num, 1])
    # Batch index paired with every top-k sample index, for gather_nd.
    counter_tensor_1d = tf.reshape(
        tf.tile(
            tf.reshape(
                tf.range(self._dynamic_batch_size),
                [self._dynamic_batch_size, 1]), [1, top_k_num]),
        [self._dynamic_batch_size * top_k_num, 1])
    top_k_index_2d = tf.concat([counter_tensor_1d, top_k_index_1d], axis=1)

    action_samples = tf.reshape(
        action_samples,
        [self._dynamic_batch_size, num_samples, self.action_dim])
    top_k_action_samples = tf.gather_nd(action_samples, top_k_index_2d)
    top_k_action_samples = tf.reshape(
        top_k_action_samples,
        [self._dynamic_batch_size, top_k_num, self.action_dim])

    top_k_values = tf.gather_nd(values, top_k_index_2d)
    top_k_values = tf.reshape(top_k_values,
                              [self._dynamic_batch_size, top_k_num])

    # it's a batch_size x 1 tensor
    top_k_value = tf.reshape(
        tf.reduce_mean(top_k_values, axis=1), [self._dynamic_batch_size, 1])

    # Refit the sampling distribution to the top-k samples.
    sample_mean = tf.reduce_mean(top_k_action_samples, axis=1)
    sample_covariance_diag = tf.math.reduce_variance(
        top_k_action_samples, axis=1)

    itr = itr + 1
    # Stop once the mean top-k value moved by no more than the tolerance.
    cond_terminate = tf.less_equal(
        tf.reduce_mean(tf.math.abs(top_k_value - top_k_value_prev)),
        self._tolerance_tensor)

    return itr, cond_terminate, sample_mean, sample_covariance_diag, \
        top_k_value, top_k_action_samples

  self.cost_optimizer = tf.while_loop(
      cond_cross_entropy, body_cross_entropy, [
          tf.constant(0),
          tf.constant(False), sample_mean_init, sample_covariance_diag_init,
          top_k_value_init, top_k_action_samples_init
      ])
def run(params, y_data_test, siz_x_data, y_normscale, load_dir):
    """Restores the VICI networks and draws samples for one test input.

    Builds the r1(z|y) and r2(x|z,y) networks in a fresh graph, restores the
    variables whose names start with "VICI" from `load_dir`, and runs a single
    sampling pass over `params['n_samples']` tiled copies of `y_data_test`.
    The session is always closed before returning.

    Args:
        params: run configuration dict. Reads 'by_channel', 'reduce',
            'z_dimension', 'n_weights_r1/r2/q', 'n_modes',
            'n_filters_r1/r2/q', 'filter_size_r1/r2/q', 'maxpool_r1/r2/q',
            'drate', 'inf_pars', 'vonmise_pars', 'gauss_pars', 'sky_pars',
            'ndata' and 'n_samples'.
        y_data_test: test data array; its shape supplies the channel count.
        siz_x_data: number of inferred parameters (network output size).
        y_normscale: divisor applied to the tiled test data.
        load_dir: checkpoint path given to `Saver.restore`.

    Returns:
        `(xs, run_time, mode_weights)`: the parameter samples, the wall-clock
        seconds spent in `session.run`, and the r1 mixture weights.
    """
    # Number of detector channels; only defined when the data is reduced or
    # convolutional filters are configured.
    n_filters_r1 = params['n_filters_r1']
    if params['reduce'] == True or n_filters_r1 is not None:  # noqa: E712
        channel_axis = 2 if params['by_channel'] == True else 1  # noqa: E712
        num_det = np.shape(y_data_test)[channel_axis]
    else:
        num_det = None

    z_dimension = params['z_dimension']
    drate = params['drate']
    n_samples = params['n_samples']  # must be a plain int, not a tensor

    # Identify the indices of the different sets of physical parameters.
    inf_pars = params['inf_pars']
    vonmise_mask, vonmise_idx_mask, vonmise_len = get_param_index(
        inf_pars, params['vonmise_pars'])
    gauss_mask, gauss_idx_mask, gauss_len = get_param_index(
        inf_pars, params['gauss_pars'])
    sky_mask, sky_idx_mask, _ = get_param_index(inf_pars, params['sky_pars'])
    m1_mask, m1_idx_mask, _ = get_param_index(inf_pars, ['mass_1'])
    m2_mask, m2_idx_mask, _ = get_param_index(inf_pars, ['mass_2'])
    # Permutation that puts the concatenated samples back in inf_pars order.
    idx_mask = np.argsort(gauss_idx_mask + vonmise_idx_mask + m1_idx_mask +
                          m2_idx_mask + sky_idx_mask)

    graph = tf.Graph()
    session = tf.Session(graph=graph)
    try:
        with graph.as_default():
            tf.set_random_seed(np.random.randint(0, 10))
            SMALL_CONSTANT = 1e-12

            # PLACEHOLDERS
            bs_ph = tf.placeholder(dtype=tf.int64, name="bs_ph")  # batch size
            y_ph = tf.placeholder(dtype=tf.float32,
                                  shape=[None, params['ndata'], num_det],
                                  name="y_ph")

            # LOAD VICI NEURAL NETWORKS
            r2_xzy = VICI_decoder.VariationalAutoencoder(
                'VICI_decoder', vonmise_mask, gauss_mask, m1_mask, m2_mask,
                sky_mask, n_input1=z_dimension, n_input2=params['ndata'],
                n_output=siz_x_data, n_channels=num_det,
                n_weights=params['n_weights_r2'], drate=drate,
                n_filters=params['n_filters_r2'],
                filter_size=params['filter_size_r2'],
                maxpool=params['maxpool_r2'])
            # Generates the parameters of r1(z|y).
            r1_zy = VICI_encoder.VariationalAutoencoder(
                'VICI_encoder', n_input=params['ndata'],
                n_output=z_dimension, n_channels=num_det,
                n_weights=params['n_weights_r1'],
                n_modes=params['n_modes'], drate=drate,
                n_filters=n_filters_r1,
                filter_size=params['filter_size_r1'],
                maxpool=params['maxpool_r1'])
            # Not sampled from here, but still constructed as in training so
            # that its "VICI" variables exist when the saver is built.
            q_zxy = VICI_VAE_encoder.VariationalAutoencoder(  # noqa: F841
                'VICI_VAE_encoder', n_input1=siz_x_data,
                n_input2=params['ndata'], n_output=z_dimension,
                n_channels=num_det, n_weights=params['n_weights_q'],
                drate=drate, n_filters=params['n_filters_q'],
                filter_size=params['filter_size_q'],
                maxpool=params['maxpool_q'])

            # GET r1(z|y)
            r1_loc, r1_scale, r1_weight = r1_zy._calc_z_mean_and_sigma(y_ph)
            temp_var_r1 = SMALL_CONSTANT + tf.exp(r1_scale)

            # Define the r1(z|y) mixture model and draw from it.
            bimix_gauss = tfd.MixtureSameFamily(
                mixture_distribution=tfd.Categorical(logits=r1_weight),
                components_distribution=tfd.MultivariateNormalDiag(
                    loc=r1_loc, scale_diag=tf.sqrt(temp_var_r1)))
            r1_zy_samp = bimix_gauss.sample()

            # GET r2(x|z,y) from r1(z|y) samples: means and log-variances of
            # each group of physical parameters, in a fixed order.
            reconstruction_xzy = r2_xzy.calc_reconstruction(r1_zy_samp, y_ph)
            r2_xzy_mean_gauss = reconstruction_xzy[0]
            r2_xzy_log_sig_sq_gauss = reconstruction_xzy[1]
            r2_xzy_mean_vonmise = reconstruction_xzy[2]
            r2_xzy_log_sig_sq_vonmise = reconstruction_xzy[3]
            r2_xzy_mean_m1 = reconstruction_xzy[4]
            r2_xzy_log_sig_sq_m1 = reconstruction_xzy[5]
            r2_xzy_mean_m2 = reconstruction_xzy[6]
            r2_xzy_log_sig_sq_m2 = reconstruction_xzy[7]
            r2_xzy_mean_sky = reconstruction_xzy[8]
            r2_xzy_log_sig_sq_sky = reconstruction_xzy[9]

            # Draw from r2(x|z,y) - the masses. m2 is truncated above by the
            # sampled m1.
            temp_var_r2_m1 = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_m1)
            temp_var_r2_m2 = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_m2)
            joint = tfd.JointDistributionSequential([
                tfd.Independent(
                    tfd.TruncatedNormal(r2_xzy_mean_m1,
                                        tf.sqrt(temp_var_r2_m1), 0, 1,
                                        validate_args=True,
                                        allow_nan_stats=True),
                    reinterpreted_batch_ndims=0),  # m1
                lambda b0: tfd.Independent(
                    tfd.TruncatedNormal(r2_xzy_mean_m2,
                                        tf.sqrt(temp_var_r2_m2), 0, b0,
                                        validate_args=True,
                                        allow_nan_stats=True),
                    reinterpreted_batch_ndims=0)],  # m2
                validate_args=True)
            r2_xzy_samp_masses = tf.transpose(
                tf.reshape(joint.sample(), [2, -1]))

            # Draw from r2(x|z,y) - the truncated gaussian parameters, one
            # column per loop iteration.
            temp_var_r2_gauss = (SMALL_CONSTANT +
                                 tf.exp(r2_xzy_log_sig_sq_gauss))

            @tf.function
            def truncnorm(idx, output):
                """Samples parameter `idx` and appends it as a new column."""
                loc = tf.slice(r2_xzy_mean_gauss, [0, idx], [-1, 1])
                std = tf.sqrt(tf.slice(temp_var_r2_gauss, [0, idx], [-1, 1]))
                tn = tfd.TruncatedNormal(loc, std, 0.0, 1.0)
                return [idx + 1,
                        tf.concat([output,
                                   tf.reshape(tn.sample(), [bs_ph, 1])],
                                  axis=1)]

            # The output starts as a column of zeros that is cut off after
            # the loop.
            idx = tf.constant(0)
            output = tf.zeros([n_samples, 1], dtype=tf.float32)
            condition = lambda i, output: i < gauss_len
            _, r2_xzy_samp_gauss = tf.while_loop(
                condition, truncnorm, loop_vars=[idx, output],
                shape_invariants=[idx.get_shape(),
                                  tf.TensorShape([n_samples, None])])
            r2_xzy_samp_gauss = tf.slice(
                tf.reshape(r2_xzy_samp_gauss, [-1, gauss_len + 1]),
                [0, 1], [-1, -1])

            # Draw from r2(x|z,y) - the von Mises part. The network output is
            # modelled as a log variance, so concentration = 1 / variance.
            temp_var_r2_vonmise = (SMALL_CONSTANT +
                                   tf.exp(r2_xzy_log_sig_sq_vonmise))
            con = tf.reshape(tf.math.reciprocal(temp_var_r2_vonmise),
                             [-1, vonmise_len])
            von_mises = tfp.distributions.VonMises(
                loc=2.0*np.pi*(r2_xzy_mean_vonmise-0.5), concentration=con)
            # Shift and scale the samples from [-pi, pi] to [0, 1].
            r2_xzy_samp_vonmise = tf.reshape(
                von_mises.sample()/(2.0*np.pi) + 0.5, [-1, vonmise_len])

            # Draw from r2(x|z,y) - the von Mises-Fisher sky part, with a
            # single concentration parameter for the whole sky.
            temp_var_r2_sky = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_sky)
            con = tf.reshape(tf.math.reciprocal(temp_var_r2_sky), [bs_ph])
            von_mises_fisher = tfp.distributions.VonMisesFisher(
                mean_direction=tf.math.l2_normalize(
                    tf.reshape(r2_xzy_mean_sky, [bs_ph, 3]), axis=1),
                concentration=con)
            xyz = tf.reshape(von_mises_fisher.sample(), [bs_ph, 3])
            # Convert the unit vector to RA and dec rescaled to [0, 1].
            samp_ra = tf.math.floormod(
                tf.atan2(tf.slice(xyz, [0, 1], [-1, 1]),
                         tf.slice(xyz, [0, 0], [-1, 1])),
                2.0*np.pi)/(2.0*np.pi)
            samp_dec = (tf.asin(tf.slice(xyz, [0, 2], [-1, 1])) +
                        0.5*np.pi)/np.pi
            r2_xzy_samp_sky = tf.reshape(
                tf.concat([samp_ra, samp_dec], axis=1), [bs_ph, 2])

            # Combine the samples and restore the original parameter order.
            r2_xzy_samp = tf.concat(
                [r2_xzy_samp_gauss, r2_xzy_samp_vonmise, r2_xzy_samp_masses,
                 r2_xzy_samp_sky], axis=1)
            r2_xzy_samp = tf.gather(r2_xzy_samp, tf.constant(idx_mask),
                                    axis=1)

            # VARIABLES LISTS
            var_list_VICI = [var for var in tf.trainable_variables()
                             if var.name.startswith("VICI")]

            # INITIALISE AND RESTORE
            init = tf.global_variables_initializer()
            session.run(init)
            saver_VICI = tf.train.Saver(var_list_VICI)
            saver_VICI.restore(session, load_dir)

        # Tile the test data once per requested sample and normalise it.
        y_data_test_exp = np.tile(y_data_test, (n_samples, 1))/y_normscale
        y_data_test_exp = y_data_test_exp.reshape(-1, params['ndata'],
                                                  num_det)
        run_startt = time.time()
        xs, mode_weights = session.run(
            [r2_xzy_samp, r1_weight],
            feed_dict={bs_ph: n_samples, y_ph: y_data_test_exp})
        run_endt = time.time()
    finally:
        # Release the session's resources even if building or running fails.
        session.close()

    return xs, (run_endt - run_startt), mode_weights
def _greedy_decode(input_embeddings,
                   output_vocab_size,
                   target_end_id,
                   target_start_id,
                   output_vocab_embeddings_table,
                   source_len,
                   model_config,
                   mode,
                   input_copy_mask=None,
                   clean_output_mask=None):
  """Greedily decodes one output sequence with a `tf.while_loop`.

  At every step the arg-max symbol of the action logits is appended to a
  fixed-length id buffer (slot 0 holds `target_start_id`) and the log of its
  softmax probability is recorded. Decoding stops once `target_end_id` shows
  up anywhere in the buffer or `max_decode_length` steps have been taken.

  Args:
    input_embeddings: Source embeddings; projected by
      `common_layers.linear_transform` to form the encoder output.
    output_vocab_size: Passed through to `decode_utils.get_decode_steps` and
      `_get_action_logits`.
    target_end_id: Symbol id that terminates decoding.
    target_start_id: Symbol id written into slot 0 of the id buffer.
    output_vocab_embeddings_table: Embedding table used for target embeddings
      and action logits.
    source_len: Source length(s) forwarded to the transformer decoder.
    model_config: Config providing `model_parameters.encoder_dims` and
      `data_options.max_decode_length`.
    mode: Forwarded to `_build_transformer_decoder`.
    input_copy_mask: Optional mask forwarded to `_get_action_logits`.
    clean_output_mask: Optional mask forwarded to `_get_action_logits`.

  Returns:
    The predictions produced by `decode_utils.get_predictions`, with
    `constants.SCORES_KEY` set to the summed log-probability (with a leading
    dimension of size 1 added).
  """
  encoder_output = common_layers.linear_transform(
      input_embeddings,
      output_size=model_config.model_parameters.encoder_dims,
      scope="bert_to_transformer")
  max_steps = model_config.data_options.max_decode_length
  # One extra slot at the front of each buffer for the start symbol.
  buffer_len = max_steps + 1

  def _logits_for_step(ids_so_far, step):
    """Maps the ids decoded so far to the logits for position `step`."""
    batched_ids = tf.expand_dims(ids_so_far, 0)
    steps = decode_utils.get_decode_steps(batched_ids, output_vocab_size,
                                          model_config)
    target_embeddings = _get_target_embeddings(
        input_embeddings, output_vocab_embeddings_table, steps, model_config)
    decoder_output = _build_transformer_decoder(
        encoder_output,
        source_len,
        target_embeddings,
        mode,
        model_config,
        single_step_index=step)
    step_logits = _get_action_logits(
        encoder_output,
        decoder_output,
        output_vocab_embeddings_table,
        output_vocab_size,
        model_config,
        input_copy_mask=input_copy_mask,
        clean_output_mask=clean_output_mask)
    # Squeeze batch dimension and length dimension, as both should be 1,
    # leaving a vector over the output vocabulary.
    return tf.squeeze(step_logits, axis=[0, 1])

  def _keep_going(step, ids_so_far, unused_logprobs):
    """True while no end symbol has been emitted and steps remain."""
    no_end_symbol_yet = tf.reduce_all(tf.not_equal(ids_so_far, target_end_id))
    steps_remaining = tf.less(step, max_steps)
    return tf.logical_and(no_end_symbol_yet, steps_remaining)

  def _decode_one(step, ids_so_far, logprobs_so_far):
    """Picks the arg-max symbol for this step and records its log-prob."""
    step_logits = _logits_for_step(ids_so_far, step)
    best_id = tf.argmax(step_logits, axis=0)
    probs = tf.nn.softmax(step_logits)
    num_symbols = tf.shape(probs)[-1]
    best_mask = tf.one_hot(best_id, num_symbols)
    best_logprob = tf.log(tf.reduce_sum(probs * best_mask))
    # Slots past step + 1 are still zero, so adding a one-hot vector writes
    # the new value into slot step + 1 without disturbing earlier slots.
    logprobs_so_far += tf.one_hot(
        step + 1, buffer_len, on_value=best_logprob, dtype=tf.float32)
    ids_so_far += tf.one_hot(
        step + 1, buffer_len, on_value=best_id, dtype=tf.int64)
    return step + 1, ids_so_far, logprobs_so_far

  start_ids = tf.zeros(dtype=tf.int64, shape=[buffer_len])
  start_ids += tf.one_hot(
      0, buffer_len, on_value=tf.cast(target_start_id, tf.int64))
  start_logprobs = tf.zeros(dtype=tf.float32, shape=[buffer_len])
  start_step = tf.constant(0)

  _, decoded_ids, logprobs = tf.while_loop(
      _keep_going, _decode_one, [start_step, start_ids, start_logprobs])

  # Remove <START> symbol.
  decoded_ids = decoded_ids[1:]
  logprobs = logprobs[1:]
  # Sum the per-step log-probs into one score for the whole sequence.
  sequence_score = tf.reduce_sum(logprobs, axis=0)
  # Add a leading dimension of size 1 (beam width of 1).
  decoded_ids = tf.expand_dims(decoded_ids, 0)
  sequence_score = tf.expand_dims(sequence_score, 0)

  # Build the output dict that the function returns.
  output_decode_steps = decode_utils.get_decode_steps(
      decoded_ids, output_vocab_size, model_config)
  predictions = decode_utils.get_predictions(output_decode_steps)
  predictions[constants.SCORES_KEY] = sequence_score
  return predictions
def train(params, x_data, y_data, x_data_test, y_data_test, y_data_test_noisefree, y_normscale, save_dir, truth_test, bounds, fixed_vals, posterior_truth_test,snrs_test=None):
    """Build the VICI training graph and run the optimisation loop.

    The graph contains three networks (r1(z|y), r2(x|z,y) and q(z|x,y)). The
    cost is the negative mean reconstruction log-likelihood plus a KL term
    weighted by a ramp value that is fed in at every iteration. During
    training the function periodically records/plots the losses, saves the
    model and produces corner plots from `VICI_inverse_model.run`.

    Args:
        params: dict of run settings; every key read below must be present.
        x_data: training parameters, indexed as `x_data[indices, :]`.
        y_data: training signals; unit-variance Gaussian noise is added per batch.
        x_data_test: test parameters used for validation loss and plot truths.
        y_data_test: test signals used for validation loss and posterior plots.
        y_data_test_noisefree: not referenced in this function body.
        y_normscale: scalar divisor applied to the y data before feeding it.
        save_dir: checkpoint path passed to `saver.save` / `saver.restore`.
        truth_test: not referenced in this function body.
        bounds: forwarded to `load_chunk` when loading data by chunks.
        fixed_vals: dict; `fixed_vals['det']` sets the number of detectors.
        posterior_truth_test: per-test-sample reference posteriors for the corner plots.
        snrs_test: not referenced in this function body.

    Returns:
        `(5000.0, session, saver, save_dir)` if the loss becomes NaN / too
        large while `params['hyperparam_optim']` is True;
        `(last recorded total training cost, session, saver, save_dir)` when
        iteration `params['hyperparam_optim_stop']` is reached in
        hyper-parameter optimisation mode; otherwise `None` after the final
        iteration. When the loss diverges outside hyper-parameter
        optimisation mode the process exits via `exit()`.
    """

    # if True, do multi-modal
    multi_modal = True

    # USEFUL SIZES
    xsh = np.shape(x_data)
    ysh = np.shape(y_data)[1]
    n_convsteps = params['n_convsteps']
    z_dimension = params['z_dimension']
    bs = params['batch_size']
    n_weights_r1 = params['n_weights_r1']
    n_weights_r2 = params['n_weights_r2']
    n_weights_q = params['n_weights_q']
    n_modes = params['n_modes']
    n_hlayers_r1 = len(params['n_weights_r1'])
    n_hlayers_r2 = len(params['n_weights_r2'])
    n_hlayers_q = len(params['n_weights_q'])
    n_conv_r1 = len(params['n_filters_r1'])
    n_conv_r2 = len(params['n_filters_r2'])
    n_conv_q = len(params['n_filters_q'])
    n_filters_r1 = params['n_filters_r1']
    n_filters_r2 = params['n_filters_r2']
    n_filters_q = params['n_filters_q']
    filter_size_r1 = params['filter_size_r1']
    filter_size_r2 = params['filter_size_r2']
    filter_size_q = params['filter_size_q']
    maxpool_r1 = params['maxpool_r1']
    maxpool_r2 = params['maxpool_r2']
    maxpool_q = params['maxpool_q']
    conv_strides_r1 = params['conv_strides_r1']
    conv_strides_r2 = params['conv_strides_r2']
    conv_strides_q = params['conv_strides_q']
    pool_strides_r1 = params['pool_strides_r1']
    pool_strides_r2 = params['pool_strides_r2']
    pool_strides_q = params['pool_strides_q']
    batch_norm = params['batch_norm']
    red = params['reduce']
    # NOTE(review): the ysh_conv_* values are not used anywhere below, and the
    # else branch reads ysh_r1 / ysh_r2 / ysh_q which are not defined in this
    # function - confirm whether this block can be dropped.
    if n_convsteps != None:
        ysh_conv_r1 = int(ysh*n_filters_r1/2**n_convsteps) if red==True else int(ysh/2**n_convsteps)
        ysh_conv_r2 = int(ysh*n_filters_r2/2**n_convsteps) if red==True else int(ysh/2**n_convsteps)
        ysh_conv_q = int(ysh*n_filters_q/2**n_convsteps) if red==True else int(ysh/2**n_convsteps)
    else:
        ysh_conv_r1 = int(ysh_r1)
        ysh_conv_r2 = int(ysh_r2)
        ysh_conv_q = int(ysh_q)
    drate = params['drate']
    ramp_start = params['ramp_start']
    ramp_end = params['ramp_end']
    num_det = len(fixed_vals['det'])

    # identify the indices of different sets of physical parameters
    vonmise_mask, vonmise_idx_mask, vonmise_len = get_param_index(params['inf_pars'],params['vonmise_pars'])
    gauss_mask, gauss_idx_mask, gauss_len = get_param_index(params['inf_pars'],params['gauss_pars'])
    sky_mask, sky_idx_mask, sky_len = get_param_index(params['inf_pars'],params['sky_pars'])
    ra_mask, ra_idx_mask, ra_len = get_param_index(params['inf_pars'],['ra'])
    dec_mask, dec_idx_mask, dec_len = get_param_index(params['inf_pars'],['dec'])
    m1_mask, m1_idx_mask, m1_len = get_param_index(params['inf_pars'],['mass_1'])
    m2_mask, m2_idx_mask, m2_len = get_param_index(params['inf_pars'],['mass_2'])
    # permutation that puts the grouped outputs back into the inf_pars order
    idx_mask = np.argsort(gauss_idx_mask + vonmise_idx_mask + m1_idx_mask + m2_idx_mask + sky_idx_mask) # + dist_idx_mask)

    graph = tf.Graph()
    session = tf.Session(graph=graph)
    with graph.as_default():

        # PLACE HOLDERS
        bs_ph = tf.placeholder(dtype=tf.int64, name="bs_ph")                       # batch size placeholder
        x_ph = tf.placeholder(dtype=tf.float32, shape=[None, xsh[1]], name="x_ph") # params placeholder
        y_ph = tf.placeholder(dtype=tf.float32, shape=[None, params['ndata'], num_det], name="y_ph")
        ramp = tf.placeholder(dtype=tf.float32)                                    # the ramp to slowly increase the KL contribution

        # LOAD VICI NEURAL NETWORKS
        r1_zy = VICI_encoder.VariationalAutoencoder('VICI_encoder', n_input=params['ndata'], n_output=z_dimension, n_channels=num_det, n_weights=n_weights_r1,   # generates params for r1(z|y)
                                                    n_modes=n_modes, drate=drate, n_filters=n_filters_r1,
                                                    filter_size=filter_size_r1, maxpool=maxpool_r1)
        r2_xzy = VICI_decoder.VariationalAutoencoder('VICI_decoder', vonmise_mask, gauss_mask, m1_mask, m2_mask, sky_mask, n_input1=z_dimension,
                                                     n_input2=params['ndata'], n_output=xsh[1], n_channels=num_det, n_weights=n_weights_r2,
                                                     drate=drate, n_filters=n_filters_r2,
                                                     filter_size=filter_size_r2, maxpool=maxpool_r2)
        q_zxy = VICI_VAE_encoder.VariationalAutoencoder('VICI_VAE_encoder', n_input1=xsh[1], n_input2=params['ndata'], n_output=z_dimension,
                                                        n_channels=num_det, n_weights=n_weights_q, drate=drate,
                                                        n_filters=n_filters_q, filter_size=filter_size_q, maxpool=maxpool_q)
        tf.set_random_seed(np.random.randint(0,10))

        # reduce the y data size
        y_conv = y_ph

        # GET r1(z|y)
        # run inverse autoencoder to generate mean and logvar of z given y data - these are the parameters for r1(z|y)
        r1_loc, r1_scale, r1_weight = r1_zy._calc_z_mean_and_sigma(y_conv)
        temp_var_r1 = SMALL_CONSTANT + tf.exp(r1_scale)

        # define the r1(z|y) mixture model
        bimix_gauss = tfd.MixtureSameFamily(
                          mixture_distribution=tfd.Categorical(logits=r1_weight),
                          components_distribution=tfd.MultivariateNormalDiag(
                          loc=r1_loc,
                          scale_diag=tf.sqrt(temp_var_r1)))

        # DRAW FROM r1(z|y) - given the Gaussian parameters generate z samples
        r1_zy_samp = bimix_gauss.sample()

        # GET q(z|x,y)
        q_zxy_mean, q_zxy_log_sig_sq = q_zxy._calc_z_mean_and_sigma(x_ph,y_conv)

        # DRAW FROM q(z|x,y)
        temp_var_q = SMALL_CONSTANT + tf.exp(q_zxy_log_sig_sq)
        mvn_q = tfp.distributions.MultivariateNormalDiag(
                          loc=q_zxy_mean,
                          scale_diag=tf.sqrt(temp_var_q))
        q_zxy_samp = mvn_q.sample()

        # GET r2(x|z,y)
        # the y data fed to r2 is blended with unit Gaussian noise: all noise at ramp=0, all data at ramp=1
        eps = tf.random.normal([bs_ph, params['ndata'], num_det], 0, 1., dtype=tf.float32)
        y_ph_ramp = tf.add(tf.multiply(ramp,y_conv), tf.multiply((1.0-ramp), eps))
        reconstruction_xzy = r2_xzy.calc_reconstruction(q_zxy_samp,y_ph_ramp)

        # ugly but required for now - unpack the r2 output params
        r2_xzy_mean_gauss = reconstruction_xzy[0]           # truncated gaussian mean
        r2_xzy_log_sig_sq_gauss = reconstruction_xzy[1]     # truncated gaussian log var
        r2_xzy_mean_vonmise = reconstruction_xzy[2]         # vonmises means
        r2_xzy_log_sig_sq_vonmise = reconstruction_xzy[3]   # vonmises log var
        r2_xzy_mean_m1 = reconstruction_xzy[4]              # m1 mean
        r2_xzy_log_sig_sq_m1 = reconstruction_xzy[5]        # m1 var
        r2_xzy_mean_m2 = reconstruction_xzy[6]              # m2 mean (m2 will be conditional on m1)
        r2_xzy_log_sig_sq_m2 = reconstruction_xzy[7]        # m2 log var (m2 will be conditional on m1)
        r2_xzy_mean_sky = reconstruction_xzy[8]             # sky mean unit vector (3D)
        r2_xzy_log_sig_sq_sky = reconstruction_xzy[9]       # sky log var (1D)

        # COST FROM RECONSTRUCTION - the masses
        # this sets up a joint distribution on m1 and m2 with m2 being conditional on m1
        # the ramp evolves the truncation boundaries from far away to 0->1 for m1 and 0->m1 for m2
        if m1_len>0 and m2_len>0:
            temp_var_r2_m1 = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_m1)     # the safe r2 variance
            temp_var_r2_m2 = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_m2)
            joint = tfd.JointDistributionSequential([    # shrink the truncation with the ramp
                       tfd.Independent(tfd.TruncatedNormal(r2_xzy_mean_m1,tf.sqrt(temp_var_r2_m1),-GAUSS_RANGE*(1.0-ramp),GAUSS_RANGE*(1.0-ramp) + 1.0),reinterpreted_batch_ndims=0),  # m1
                lambda b0: tfd.Independent(tfd.TruncatedNormal(r2_xzy_mean_m2,tf.sqrt(temp_var_r2_m2),-GAUSS_RANGE*(1.0-ramp),GAUSS_RANGE*(1.0-ramp) + ramp*b0),reinterpreted_batch_ndims=0)],    # m2
            )
            reconstr_loss_masses = joint.log_prob((tf.boolean_mask(x_ph,m1_mask,axis=1),tf.boolean_mask(x_ph,m2_mask,axis=1)))
        else:
            # no mass parameters are inferred - contribute nothing to the cost
            # (without this cost_R below would hit an unassigned name)
            reconstr_loss_masses = 0.0

        # COST FROM RECONSTRUCTION - Truncated Gaussian parts
        # this sets up a loop over uncorrelated truncated Gaussians
        # the ramp evolves the boundaries from far away to 0->1
        if gauss_len>0:
            temp_var_r2_gauss = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_gauss)
            gauss_x = tf.boolean_mask(x_ph,gauss_mask,axis=1)
            @tf.function
            def truncnorm(i,lp):    # we set up a function that adds the log-likelihoods and also increments the counter
                loc = tf.slice(r2_xzy_mean_gauss,[0,i],[-1,1])
                std = tf.sqrt(tf.slice(temp_var_r2_gauss,[0,i],[-1,1]))
                pos = tf.slice(gauss_x,[0,i],[-1,1])
                tn = tfd.TruncatedNormal(loc,std,-GAUSS_RANGE*(1.0-ramp),GAUSS_RANGE*(1.0-ramp) + 1.0)   # shrink the truncation with the ramp
                return [i+1, lp + tn.log_prob(pos)]
            # we do the loop until we've hit all the truncated gaussian parameters - i starts at 0 and the logprob starts at 0
            _,reconstr_loss_gauss = tf.while_loop(lambda i,reconstr_loss_gauss: i<gauss_len, truncnorm, [0,tf.zeros([bs_ph],dtype=tf.dtypes.float32)])
        else:
            # no truncated-Gaussian parameters are inferred
            reconstr_loss_gauss = 0.0

        # COST FROM RECONSTRUCTION - Von Mises parts for single parameters that wrap over 2pi
        if vonmise_len>0:
            temp_var_r2_vonmise = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_vonmise)
            con = tf.reshape(tf.math.reciprocal(temp_var_r2_vonmise),[-1,vonmise_len])   # modelling wrapped scale output as log variance - convert to concentration
            von_mises = tfp.distributions.VonMises(
                      loc=2.0*np.pi*(tf.reshape(r2_xzy_mean_vonmise,[-1,vonmise_len])-0.5),   # remap 0>1 mean onto -pi->pi range
                      concentration=con)
            reconstr_loss_vonmise = von_mises.log_prob(2.0*np.pi*(tf.reshape(tf.boolean_mask(x_ph,vonmise_mask,axis=1),[-1,vonmise_len]) - 0.5))   # 2pi is the von mises input range
            # sum the independent per-parameter log-probs; for two parameters this is
            # exactly column 0 + column 1, and it also covers any other vonmise_len
            reconstr_loss_vonmise = tf.reduce_sum(reconstr_loss_vonmise, axis=1)

            # computing Gaussian likelihood for von mises parameters to be faded away with the ramp
            gauss_vonmises = tfp.distributions.MultivariateNormalDiag(
                         loc=r2_xzy_mean_vonmise,
                         scale_diag=tf.sqrt(temp_var_r2_vonmise))
            reconstr_loss_gauss_vonmise = gauss_vonmises.log_prob(tf.boolean_mask(x_ph,vonmise_mask,axis=1))
            reconstr_loss_vonmise = ramp*reconstr_loss_vonmise + (1.0-ramp)*reconstr_loss_gauss_vonmise    # start with a Gaussian model and fade in the true vonmises
        else:
            reconstr_loss_vonmise = 0.0

        # COST FROM RECONSTRUCTION - Von Mises Fisher (sky) parts
        if sky_len>0:
            temp_var_r2_sky = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_sky)
            con = tf.reshape(tf.math.reciprocal(temp_var_r2_sky),[bs_ph])   # modelling wrapped scale output as log variance - only 1 concentration parameter for all sky
            loc_xyz = tf.math.l2_normalize(tf.reshape(r2_xzy_mean_sky,[-1,3]),axis=1)    # take the 3 output mean params from r2 and normalise so they are a unit vector
            von_mises_fisher = tfp.distributions.VonMisesFisher(
                      mean_direction=loc_xyz,
                      concentration=con)
            ra_sky = 2.0*np.pi*tf.reshape(tf.boolean_mask(x_ph,ra_mask,axis=1),[-1,1])       # convert the scaled 0->1 true RA value back to radians
            dec_sky = np.pi*(tf.reshape(tf.boolean_mask(x_ph,dec_mask,axis=1),[-1,1]) - 0.5) # convert the scaled 0>1 true dec value back to radians
            xyz_unit = tf.reshape(tf.concat([tf.cos(ra_sky)*tf.cos(dec_sky),tf.sin(ra_sky)*tf.cos(dec_sky),tf.sin(dec_sky)],axis=1),[-1,3])   # construct the true parameter unit vector
            reconstr_loss_sky = von_mises_fisher.log_prob(tf.math.l2_normalize(xyz_unit,axis=1))   # normalise it for safety (should already be normalised) and compute the logprob

            # computing Gaussian likelihood for von mises Fisher (sky) parameters to be faded away with the ramp
            mean_ra = tf.math.floormod(tf.atan2(tf.slice(loc_xyz,[0,1],[-1,1]),tf.slice(loc_xyz,[0,0],[-1,1])),2.0*np.pi)/(2.0*np.pi)    # convert the unit vector to scaled 0->1 RA
            mean_dec = (tf.asin(tf.slice(loc_xyz,[0,2],[-1,1])) + 0.5*np.pi)/np.pi        # convert the unit vector to scaled 0->1 dec
            mean_sky = tf.reshape(tf.concat([mean_ra,mean_dec],axis=1),[bs_ph,2])        # package up the scaled RA and dec
            gauss_sky = tfp.distributions.MultivariateNormalDiag(
                         loc=mean_sky,
                         scale_diag=tf.concat([tf.sqrt(temp_var_r2_sky),tf.sqrt(temp_var_r2_sky)],axis=1))   # use the same 1D concentration parameter for both RA and dec dimensions
            reconstr_loss_gauss_sky = gauss_sky.log_prob(tf.boolean_mask(x_ph,sky_mask,axis=1))     # compute the logprob at the true sky location
            reconstr_loss_sky = ramp*reconstr_loss_sky + (1.0-ramp)*reconstr_loss_gauss_sky   # start with a Gaussian model and fade in the true vonmises Fisher
        else:
            # no sky parameters are inferred
            reconstr_loss_sky = 0.0

        # reconstruction cost: negative batch-mean of the summed log-likelihoods
        cost_R = -1.0*tf.reduce_mean(reconstr_loss_gauss + reconstr_loss_vonmise + reconstr_loss_masses + reconstr_loss_sky)
        r2_xzy_mean = tf.gather(tf.concat([r2_xzy_mean_gauss,r2_xzy_mean_vonmise,r2_xzy_mean_m1,r2_xzy_mean_m2,r2_xzy_mean_sky],axis=1),tf.constant(idx_mask),axis=1)      # put the elements back in order
        r2_xzy_scale = tf.gather(tf.concat([r2_xzy_log_sig_sq_gauss,r2_xzy_log_sig_sq_vonmise,r2_xzy_log_sig_sq_m1,r2_xzy_log_sig_sq_m2,r2_xzy_log_sig_sq_sky],axis=1),tf.constant(idx_mask),axis=1)   # put the elements back in order

        log_q_q = mvn_q.log_prob(q_zxy_samp)
        log_r1_q = bimix_gauss.log_prob(q_zxy_samp)   # evaluate the log prob of r1 at the q samples
        KL = tf.reduce_mean(log_q_q - log_r1_q)      # average over batch

        # THE VICI COST FUNCTION
        COST = cost_R + ramp*KL #+ L1_weight_reg)

        # VARIABLES LISTS
        var_list_VICI = [var for var in tf.trainable_variables() if var.name.startswith("VICI")]

        # DEFINE OPTIMISER (using ADAM here)
        optimizer = tf.train.AdamOptimizer(params['initial_training_rate'])
        # optimizer = tf.train.RMSPropOptimizer(params['initial_training_rate'])
        minimize = optimizer.minimize(COST,var_list = var_list_VICI)

        # INITIALISE AND RUN SESSION
        init = tf.global_variables_initializer()
        session.run(init)
        saver = tf.train.Saver()

    print('Training Inference Model...')
    # START OPTIMISATION OF OELBO
    indices_generator = batch_manager.SequentialIndexer(params['batch_size'], xsh[0])
    plotdata = []

    load_chunk_it = 1
    for i in range(params['num_iterations']):

        next_indices = indices_generator.next_indices()

        # if load chunks true, load in data by chunks
        if params['load_by_chunks'] == True and i == int(params['load_iteration']*load_chunk_it):
            x_data, y_data = load_chunk(params['train_set_dir'],params['inf_pars'],params,bounds,fixed_vals)
            load_chunk_it += 1

        # Make noise realizations and add to training data
        next_x_data = x_data[next_indices,:]
        # NOTE(review): n_conv_r1 is the result of len(...), so `n_conv_r1 != None`
        # is always True and the else branch cannot run - confirm the intended test.
        if params['reduce'] == True or n_conv_r1 != None:
            next_y_data = y_data[next_indices,:] + np.random.normal(0,1,size=(params['batch_size'],int(params['ndata']),len(fixed_vals['det'])))
        else:
            next_y_data = y_data[next_indices,:] + np.random.normal(0,1,size=(params['batch_size'],int(params['ndata']*len(fixed_vals['det']))))
        next_y_data /= y_normscale  # required for fast convergence

        if params['by_channel'] == False:
            next_y_data_new = []
            for sig in next_y_data:
                next_y_data_new.append(sig.T)
            next_y_data = np.array(next_y_data_new)
            del next_y_data_new

        # restore session if wanted
        if params['resume_training'] == True and i == 0:
            print(save_dir)
            saver.restore(session, save_dir)

        # compute the ramp value (log-spaced between ramp_start and ramp_end)
        rmp = 0.0
        if params['ramp'] == True:
            if i>ramp_start:
                rmp = (np.log10(float(i)) - np.log10(ramp_start))/(np.log10(ramp_end) - np.log10(ramp_start))
            if i>ramp_end:
                rmp = 1.0
        else:
            rmp = 1.0

        # train the network
        session.run(minimize, feed_dict={bs_ph:bs, x_ph:next_x_data, y_ph:next_y_data, ramp:rmp})

        # if we are in a report iteration extract cost function values
        if i % params['report_interval'] == 0 and i > 0:

            # get training loss
            cost, kl, AB_batch = session.run([cost_R, KL, r1_weight], feed_dict={bs_ph:bs, x_ph:next_x_data, y_ph:next_y_data, ramp:rmp})

            # get validation loss on test set
            cost_val, kl_val = session.run([cost_R, KL], feed_dict={bs_ph:y_data_test.shape[0], x_ph:x_data_test, y_ph:y_data_test/y_normscale, ramp:rmp})
            plotdata.append([cost,kl,cost+kl,cost_val,kl_val,cost_val+kl_val])

            try:
                # Make loss plot
                plt.figure()
                xvec = params['report_interval']*np.arange(np.array(plotdata).shape[0])
                plt.semilogx(xvec,np.array(plotdata)[:,0],label='recon',color='blue',alpha=0.5)
                plt.semilogx(xvec,np.array(plotdata)[:,1],label='KL',color='orange',alpha=0.5)
                plt.semilogx(xvec,np.array(plotdata)[:,2],label='total',color='green',alpha=0.5)
                plt.semilogx(xvec,np.array(plotdata)[:,3],label='recon_val',color='blue',linestyle='dotted')
                plt.semilogx(xvec,np.array(plotdata)[:,4],label='KL_val',color='orange',linestyle='dotted')
                plt.semilogx(xvec,np.array(plotdata)[:,5],label='total_val',color='green',linestyle='dotted')
                plt.ylim([-25,15])
                plt.xlabel('iteration')
                plt.ylabel('cost')
                plt.legend()
                plt.savefig('%s/latest_%s/cost_%s.png' % (params['plot_dir'],params['run_label'],params['run_label']))
                plt.ylim([np.min(np.array(plotdata)[-int(0.9*np.array(plotdata).shape[0]):,0]), np.max(np.array(plotdata)[-int(0.9*np.array(plotdata).shape[0]):,1])])
                plt.savefig('%s/latest_%s/cost_zoom_%s.png' % (params['plot_dir'],params['run_label'],params['run_label']))
                plt.close('all')
            except Exception as err:
                # plotting is best-effort and must not stop training, but report
                # why it failed and let KeyboardInterrupt/SystemExit propagate
                print('Unable to make loss plot:', err)

            if params['print_values']==True:
                print('--------------------------------------------------------------')
                print('Iteration:',i)
                print('Training -ELBO:',cost)
                print('Validation -ELBO:',cost_val)
                print('Training KL Divergence:',kl)
                print('Validation KL Divergence:',kl_val)
                print('Training Total cost:',kl + cost)
                print('Validation Total cost:',kl_val + cost_val)
                print()

            # terminate training if vanishing gradient
            if np.isnan(kl+cost) == True or np.isnan(kl_val+cost_val) == True or kl+cost > int(1e5):
                print('Network is returning NaN values')
                print('Terminating network training')
                if params['hyperparam_optim'] == True:
                    save_path = saver.save(session,save_dir)
                    # fixed large figure of merit handed back for a failed run
                    return 5000.0, session, saver, save_dir
                else:
                    exit()
            try:
                # Save loss plot data
                np.savetxt(save_dir.split('/')[0] + '/loss_data.txt', np.array(plotdata))
            except FileNotFoundError as err:
                print(err)
                pass

        if i % params['save_interval'] == 0 and i > 0:

            if params['hyperparam_optim'] == False:
                # Save model
                save_path = saver.save(session,save_dir)
            else:
                pass

        # stop hyperparam optim training it and return KL divergence as figure of merit
        if params['hyperparam_optim'] == True and i == params['hyperparam_optim_stop']:
            save_path = saver.save(session,save_dir)
            # column 2 of plotdata is the total training cost (cost + kl)
            return np.array(plotdata)[-1,2], session, saver, save_dir

        if i % params['plot_interval'] == 0 and i>0:

            n_mode_weight_copy = 100 # must be a multiple of 50
            # just run the network on the test data
            for j in range(params['r']*params['r']):

                # The trained inverse model weights can then be used to infer a probability density of solutions given new measurements
                if params['reduce'] == True or params['n_filters_r1'] != None:
                    XS, dt, _ = VICI_inverse_model.run(params, y_data_test[j].reshape([1,y_data_test.shape[1],y_data_test.shape[2]]), np.shape(x_data_test)[1],
                                                       y_normscale,
                                                       "inverse_model_dir_%s/inverse_model.ckpt" % params['run_label'])
                else:
                    XS, dt, _ = VICI_inverse_model.run(params, y_data_test[j].reshape([1,-1]), np.shape(x_data_test)[1],
                                                       y_normscale,
                                                       "inverse_model_dir_%s/inverse_model.ckpt" % params['run_label'])
                print('Runtime to generate {} samples = {} sec'.format(params['n_samples'],dt))

                # Make corner plots
                # Get corner parnames to use in plotting labels
                parnames = []
                for k_idx,k in enumerate(params['rand_pars']):
                    if np.isin(k, params['inf_pars']):
                        parnames.append(params['cornercorner_parnames'][k_idx])

                defaults_kwargs = dict(
                    bins=50, smooth=0.9, label_kwargs=dict(fontsize=16),
                    title_kwargs=dict(fontsize=16),
                    truth_color='tab:orange', quantiles=[0.16, 0.84],
                    levels=(0.68,0.90,0.95), density=True,
                    plot_density=False, plot_datapoints=True,
                    max_n_ticks=3)

                # reference posterior first, then the network samples overlaid on the same figure
                figure = corner.corner(posterior_truth_test[j], **defaults_kwargs,labels=parnames,
                                       color='tab:blue',truths=x_data_test[j,:],
                                       show_titles=True)
                # compute weights, otherwise the 1d histograms will be different scales, could remove this
                corner.corner(XS,**defaults_kwargs,labels=parnames,
                              color='tab:red',
                              fill_contours=True,
                              show_titles=True, fig=figure)

                plt.savefig('%s/corner_plot_%s_%d-%d.png' % (params['plot_dir'],params['run_label'],i,j))
                plt.savefig('%s/latest_%s/corner_plot_%s_%d.png' % (params['plot_dir'],params['run_label'],params['run_label'],j))
                plt.close('all')
                print('Made corner plot %d' % j)

    return
def __call__(self, learner, meta_batches=None, inner_batches=None,
             init_state=None, unroll_n_steps=None):
  """Sets up the state and step functions for unrolling `learner`.

  Arguments left as None fall back to the attribute of the same name on
  `self`; batches that are passed in are unstacked into `tf.TensorArray`s.

  NOTE(review): the body visible here ends after defining the two helper
  closures; nothing in this span consumes them or returns a value. Confirm
  whether the remainder of the method is missing.

  Args:
    learner: Object providing `current_state`, `loss_and_next_state`,
      `meta_loss` and `inner_loss`.
    meta_batches: Optional nested structure of tensors to use instead of
      `self.meta_batches`.
    inner_batches: Optional nested structure of tensors to use instead of
      `self.inner_batches`.
    init_state: Optional starting state; defaults to
      `learner.current_state()`.
    unroll_n_steps: Optional override for `self.unroll_n_steps`.
  """
  if unroll_n_steps is None:
    unroll_n_steps = self.unroll_n_steps
  else:
    print("Using passed in unroll steps")

  if inner_batches is None:
    inner_batches = self.inner_batches
  else:
    # convert the batches object to a tensorarray.
    # NOTE(review): the size comes from self.unroll_n_steps, not from the
    # unroll_n_steps override resolved above - confirm this is intended.
    def to_ta(t):
      return tf.TensorArray(dtype=t.dtype,
                            size=self.unroll_n_steps).unstack(t)

    inner_batches = nest.map_structure(to_ta, inner_batches)

  if meta_batches is None:
    meta_batches = self.meta_batches
  else:
    # convert the batches object to a tensorarray.
    def ml_to_ta(t):
      return tf.TensorArray(
          dtype=t.dtype,
          size=self.meta_loss_evals * self.unroll_n_steps).unstack(t)

    meta_batches = nest.map_structure(ml_to_ta, meta_batches)

  if init_state is None:
    init_state = learner.current_state()
    init_state = tf_utils.force_copy(init_state)

  # Loop state is a (step index, last loss, learner state) triple.
  current_state = (tf.constant(0, dtype=tf.int32),
                   tf.constant(0., dtype=tf.float32), init_state)

  # Tuple parameters in a `def` signature are a syntax error on Python 3, so
  # both helpers take the tuples whole and unpack them in the body; callers
  # still pass exactly the same arguments.
  def loss_and_next_state_fn(loop_state):
    """Takes one learner step on the inner batch for the current index."""
    idx, _, state = loop_state
    batch = self.get_batch(idx, batches=inner_batches)
    l, s = learner.loss_and_next_state(state, loss_state=batch)
    return (idx + 1, l, s)

  def accumulate_fn(loop_state, accumulated):
    """Accumulate loss for fold learning process."""
    idx, _, s = loop_state
    a_meta, a_inner = accumulated
    cond = lambda i, a: tf.less(i, self.meta_loss_evals)

    def body_meta(i, a):
      # minus 1 as this takes the following step.
      batch = self.get_batch((idx - 1) * (self.meta_loss_evals) + i,
                             batches=meta_batches)
      return (i + 1, a + learner.meta_loss(s, loss_state=batch))

    _, extra_losses = tf.while_loop(cond, body_meta, loop_vars=[0, 0.])

    def body_inner(i, a):
      # minus 1 as this takes the following step.
      # NOTE(review): the inner loss is also evaluated on meta_batches -
      # confirm this is intended rather than inner_batches.
      batch = self.get_batch((idx - 1) * (self.meta_loss_evals) + i,
                             batches=meta_batches)
      return (i + 1, a + learner.inner_loss(s, loss_state=batch))

    _, inner_losses = tf.while_loop(cond, body_inner, loop_vars=[0, 0.])

    return a_meta + extra_losses, a_inner + inner_losses
def _slow_greedy_infer_guess_and_check(self, features, decode_length):
  """Greedy block-wise decoding that re-checks its own earlier guesses.

  Each loop iteration runs `self.sample` on the result so far, keeps the
  longest prefix whose tokens are all within the model's top-k predictions
  (and not after an EOS), then appends `block_size` newly sampled tokens.
  The loop ends when an EOS appears in the accepted prefix or the length
  limit is reached.

  Args:
    features: dict containing "inputs" (rank 3 or 4) and no "targets". It is
      modified while the graph is built but returned to its entry state
      ("inputs" restored, "targets" removed) before this method returns.
    decode_length: number of steps to decode beyond the input length.

  Returns:
    A dict with "outputs" (the accepted result, truncated to the final
    length) and "scores" (always None).
  """
  assert self._hparams.block_size > 0
  assert self._hparams.force_full_predict
  assert self._hparams.sampling_method == "argmax"
  assert self._decode_hparams.batch_size == 1
  assert self._decode_hparams.block_size > 0
  assert self._decode_hparams.block_size <= self._hparams.block_size
  assert self._decode_hparams.guess_and_check_top_k > 0

  inputs_old = features["inputs"]
  assert "targets" not in features

  assert len(features["inputs"].shape) in [3, 4]
  if len(features["inputs"].shape) < 4:
    features["inputs"] = tf.expand_dims(features["inputs"], 2)

  block_size = self._decode_hparams.block_size
  # The length budget is relative to the input length.
  decode_length += tf.shape(features["inputs"])[1]

  def while_exit_cond(result, length):  # pylint: disable=unused-argument
    return tf.logical_and(
        length < decode_length,
        tf.reduce_all(
            tf.not_equal(result[:, :length, :, :], text_encoder.EOS_ID)))

  def infer_step(result, length):
    """Inference step."""

    def print_info(result, length, new_length):
      """Logs how many tokens were accepted and the newly accepted suffix."""
      vocab = self.problem_hparams.vocabulary["targets"]
      new_suffix = str([
          vocab._subtoken_id_to_subtoken_string(index)  # pylint: disable=protected-access
          for index in result[0, -block_size:, 0, 0][:new_length - length]
      ])
      # On Python 2, str() gives a byte string whose escapes must be decoded
      # for a readable log line. On Python 3, str has no .decode() (calling
      # it raised AttributeError inside the py_func) and is already text.
      if hasattr(new_suffix, "decode"):
        new_suffix = new_suffix.decode("unicode-escape")
      tf.logging.info(
          "length=%s new_length=%s length_diff=%s new_suffix=%s",
          length,
          new_length,
          new_length - length,
          new_suffix,
      )

    # Pad one position so the model also predicts the token after `result`.
    features["targets"] = tf.pad(result, [[0, 0], [0, 1], [0, 0], [0, 0]])
    samples, logits, losses = self.sample(features)  # pylint: disable=unused-variable

    _, top_k_indices = tf.nn.top_k(
        logits[:, :-1, :1, :, :],
        k=self._decode_hparams.guess_and_check_top_k)
    in_top_k = tf.reduce_any(
        tf.equal(tf.to_int64(top_k_indices), tf.expand_dims(result, 4)),
        axis=4)

    eos_cumsum = tf.cumsum(
        tf.to_int32(tf.equal(result, text_encoder.EOS_ID)), axis=1)
    after_eos = tf.greater(common_layers.shift_right(eos_cumsum), 0)

    correct = tf.logical_and(in_top_k, tf.logical_not(after_eos))
    correct_cumsum = tf.cumsum(tf.to_int32(correct), axis=1)
    # A position belongs to the accepted prefix only if every position up to
    # and including it was correct, i.e. the running count equals 1, 2, 3...
    perfect_cumsum = 1 + tf.range(tf.shape(correct)[1])
    for axis in [0, 2, 3]:
      perfect_cumsum = tf.expand_dims(perfect_cumsum, axis=axis)

    new_length = tf.reduce_sum(
        tf.to_int32(tf.equal(correct_cumsum, perfect_cumsum)), axis=1)
    new_length = tf.squeeze(new_length, axis=[0, 1, 2])
    new_length = tf.minimum(new_length, decode_length)

    # Accepted prefix followed by a fresh block of sampled tokens.
    new_result = tf.concat([
        result[:, :new_length, :, :],
        tf.reshape(
            samples[:, new_length, :block_size, :], [1, block_size, 1, 1])
    ], axis=1)

    # Tie the logging op to the result so it runs on every iteration.
    with tf.control_dependencies(
        [tf.py_func(print_info, [result, length, new_length], [])]):
      new_result = tf.identity(new_result)

    return new_result, new_length

  result = tf.zeros((1, 0, 1, 1), dtype=tf.int64)
  length = tf.squeeze(tf.zeros(1, dtype=tf.int32))

  result, length = tf.while_loop(
      while_exit_cond,
      infer_step,
      [result, length],
      shape_invariants=[
          tf.TensorShape([1, None, 1, 1]),
          tf.TensorShape([]),
      ],
      back_prop=False,
      parallel_iterations=1)

  result = result[:, :length, :, :]

  features["inputs"] = inputs_old
  # infer_step wrote a loop-internal tensor into features["targets"]. Remove
  # it so the dict is left as it was on entry, where "targets" was asserted
  # absent.
  features.pop("targets", None)

  return {
      "outputs": result,
      "scores": None,
  }
def pgd_attack(loss_fn, input_image, epsilon, num_steps,
               optimizer=UnrolledGradientDescent(),
               project_perturbation=_project_perturbation,
               image_bounds=None, random_init=1.):
    """Generates adversarial images with projected gradient descent.

    Args:
      loss_fn: A callable invoked with a single argument, the perturbed image
        batch (`input_image + perturbation`), returning the scalar Tensor loss
        to minimize.
      input_image: Tensor, a batch of images.
      epsilon: float, the L-infinity norm of the maximum allowable
        perturbation.
      num_steps: int, the number of steps of gradient descent.
      optimizer: An `UnrolledOptimizer` object.
      project_perturbation: A function used to enforce a constraint on the
        perturbation after every update. It must have the same signature as
        `_project_perturbation`. An incorrect custom projection will not
        error and will appear to work, so double-check custom implementations.
      image_bounds: A pair of floats: minimum and maximum pixel value. If None
        (default), the bounds are taken to be 0 and 1.
      random_init: Probability, drawn independently per example, of starting
        from a random perturbation instead of the nominal input image.

    Returns:
      The adversarial version of `input_image` (with gradients stopped): the
      input plus the final projected perturbation.
    """
    if not image_bounds:
        image_bounds = (0., 1.)

    # One Bernoulli draw per example, broadcastable over the remaining axes.
    leading_dim = tf.shape(input_image)[0]
    per_example_shape = [leading_dim] + [1] * (len(input_image.shape) - 1)
    starts_random = tf.cast(
        tf.random_uniform(per_example_shape) < float(random_init), tf.float32)
    uniform_noise = tf.random_uniform(
        tf.shape(input_image), minval=-epsilon, maxval=epsilon)
    start_perturbation = project_perturbation(
        starts_random * uniform_noise, epsilon, input_image, image_bounds)
    start_state = optimizer.init_state([start_perturbation])

    def _keep_going(step, *_):
        return tf.less(step, num_steps)

    def _descend(step, delta, flat_state):
        """Takes one optimizer step on the perturbation and re-projects it."""
        # The optimizer state travels through the loop as a flat list.
        state = nest.pack_sequence_as(
            structure=start_state, flat_sequence=flat_state)
        loss = loss_fn(input_image + delta)
        updated_list, next_state = optimizer.minimize(loss, [delta], state)
        projected = project_perturbation(
            updated_list[0], epsilon, input_image, image_bounds)
        return step + 1, projected, nest.flatten(next_state)

    flat_start_state = nest.flatten(start_state)
    loop_outputs = tf.while_loop(
        _keep_going,
        _descend,
        loop_vars=[tf.constant(0.), start_perturbation, flat_start_state],
        parallel_iterations=1,
        back_prop=False)
    final_perturbation = loop_outputs[1]

    return tf.stop_gradient(input_image + final_perturbation)
def _create_gradient_ascent_action_tensors(self, eps=1e-6):
    """Create tensorflow operations for gradient ascent max_actions.

    Builds placeholders for the initial actions and the stopping tolerance, an
    action variable initialized from the placeholder, and a `tf.while_loop`
    that repeatedly applies a normalized gradient step to that variable. The
    cost being minimized is the negated mean Q-value, so each step ascends Q.

    Attributes set: `_action_init_tensor`, `_tolerance_tensor`,
    `_action_variable_tensor`, `cost_now`, `action_gradient`,
    `normalized_action_gradient`, `cost_optimizer`, `action_init_op`.

    Args:
      eps: small constant added to the gradient norm before dividing, so the
        normalization never divides by zero.
    """
    self._action_init_tensor = tf.placeholder(dtype=tf.float32,
                                              name="action_init_tensor",
                                              shape=(None, self.action_dim))
    self._tolerance_tensor = tf.placeholder(dtype=tf.float32,
                                            name="tolerance_tensor",
                                            shape=())

    with tf.variable_scope("{}_{}".format(self.name, "action_variable")):
        self._action_variable_tensor = tf.Variable(
            initial_value=self._action_init_tensor,
            trainable=True,
            name="action_var")

        # Gradient ascent on Q, expressed as descent on the negated mean Q.
        self.cost_now = -tf.reduce_mean(
            self._build_q_function_net(self._state_tensor,
                                       self._action_variable_tensor))
        self.action_gradient = tf.gradients(
            self.cost_now, self._action_variable_tensor)[0]
        # normalize the gradient
        self.normalized_action_gradient = self.action_gradient / (
            eps + tf.linalg.norm(self.action_gradient))

        if self.sufficient_ascent_flag:

            def cond_sufficient_descent(learning_rate_action,
                                        cond_sufficient_descent,
                                        cost_perturbed):
                """Keep searching while the rate is above the floor and unmet."""
                del cost_perturbed
                cond_1 = tf.math.greater(learning_rate_action,
                                         self.learning_rate_action)
                return tf.math.logical_and(
                    cond_1, tf.logical_not(cond_sufficient_descent))

            def body_sufficient_descent(learning_rate_action,
                                        cond_sufficient_descent,
                                        cost_perturbed,
                                        c_armijo=0.01,
                                        c_goldstein=0.25,
                                        lr_decay=0.1):
                """Tests one step size against the Armijo/Goldstein conditions.

                Returns the decayed learning rate, whether both conditions
                held for the rate that was tested, and the cost at the
                perturbed action.
                """
                del cond_sufficient_descent, cost_perturbed
                action_variable_perturbed_tensor = self._action_variable_tensor - \
                    learning_rate_action * self.normalized_action_gradient

                cost_perturbed = -tf.reduce_mean(
                    self._build_q_function_net(
                        self._state_tensor, action_variable_perturbed_tensor))

                # Here the negative gradient corresponds to maximization of Q fun.
                sufficient_descent = tf.reduce_sum(
                    self.action_gradient * -self.normalized_action_gradient)

                goldstein_condition = tf.greater_equal(
                    cost_perturbed, self.cost_now +
                    c_goldstein * learning_rate_action * sufficient_descent)
                armijo_condition = tf.less_equal(
                    cost_perturbed, self.cost_now +
                    c_armijo * learning_rate_action * sufficient_descent)
                cond_sufficient_descent = tf.logical_and(
                    goldstein_condition, armijo_condition)

                # NOTE(review): the rate is decayed after the test, so the
                # value returned is lr_decay times the rate that was actually
                # evaluated above - confirm this is intended.
                with tf.control_dependencies([cond_sufficient_descent]):
                    learning_rate_action = learning_rate_action * lr_decay

                return learning_rate_action, cond_sufficient_descent, cost_perturbed

        # Construct the while loop.
        def cond_gradient_ascent(itr, cond_terminate):
            """Continue until the iteration cap or the tolerance test is met."""
            cond_1 = tf.math.less(itr, self.action_maximization_iterations)
            return tf.math.logical_and(cond_1, tf.logical_not(cond_terminate))

        def body_gradient_ascent(itr, cond_terminate, lr_init=100.0):
            """Applies one normalized gradient step to the action variable.

            `lr_init` is the starting step size handed to the inner
            sufficient-descent search when `sufficient_ascent_flag` is set.
            """
            del cond_terminate
            if self.sufficient_ascent_flag:
                # first calculate sufficient descent
                result_sufficient_descent = tf.while_loop(
                    cond_sufficient_descent, body_sufficient_descent, [
                        tf.constant(lr_init),
                        tf.constant(False),
                        tf.constant(np.inf)
                    ])
                lr_action = result_sufficient_descent[0]
                cost_perturbed = result_sufficient_descent[2]

                # Stop once the perturbed cost is within tolerance of the
                # current cost.
                cond_terminate = tf.less_equal(
                    tf.math.abs(cost_perturbed - self.cost_now),
                    self._tolerance_tensor)
            else:
                # no sufficient descent step
                lr_action = self.learning_rate_ga
                action_variable_perturbed_tensor = self._action_variable_tensor - \
                    lr_action * self.normalized_action_gradient

                cost_perturbed = -tf.reduce_mean(
                    self._build_q_function_net(
                        self._state_tensor, action_variable_perturbed_tensor))
                cond_terminate = tf.less_equal(
                    tf.math.abs(cost_perturbed - self.cost_now),
                    self._tolerance_tensor)

            train_op = tf.train.GradientDescentOptimizer(
                learning_rate=lr_action).apply_gradients(
                    grads_and_vars=[(self.normalized_action_gradient,
                                     self._action_variable_tensor)])
            # Ensure that the update is applied before continuing.
            with tf.control_dependencies([train_op]):
                itr = itr + 1

            return itr, cond_terminate

        self.cost_optimizer = tf.while_loop(
            cond_gradient_ascent, body_gradient_ascent,
            [tf.constant(0), tf.constant(False)])

    # Initializer for the global variables collected under the
    # action-variable scope name.
    self.action_init_op = tf.initializers.variables(
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                          scope="{}_{}".format(self.name, "action_variable")))
def _build(self, inputs, labels):
    """Builds the attack graph that runs one PGD attack per specification.

    When `_max_specifications` is positive and smaller than the number of
    available specifications, only that many are attacked per example: either
    the top-k by bound on the clean inputs, or uniformly random ones when
    `_random_specifications` is set. The loop over specifications exits early
    once every example has been attacked successfully.

    Sets `self._attack`, `self._success` and `self._logits`.

    Args:
      inputs: batch of nominal inputs.
      labels: not used by this method.

    Returns:
      The attack tensor, also stored as `self._attack`.

    Raises:
      ValueError: if the configured optimizer is an SPSA optimizer.
    """
    batch_size, input_shape, duplicated_inputs = self.prepare_inputs(inputs)

    if 0 < self._max_specifications < self._specification.num_specifications:
        num_specs = self._max_specifications
        clean_logits = self._eval_fn(inputs)
        clean_bounds = self._specification.evaluate(clean_logits)
        _, chosen = tf.math.top_k(clean_bounds, k=num_specs, sorted=False)
        if self._random_specifications:
            chosen = tf.random.uniform(
                shape=tf.shape(chosen),
                maxval=self._specification.num_specifications,
                dtype=chosen.dtype)
        # Share the same per-example choice across all restarts.
        chosen = tf.tile(tf.expand_dims(chosen, 0),
                         [self._num_restarts, 1, 1])

        def select_fn(values, position):
            gather_at = tf.expand_dims(chosen[:, :, position], -1)
            picked = tf.gather(values, gather_at,
                               batch_dims=len(chosen.shape) - 1)
            return tf.squeeze(picked, axis=-1)
    else:
        num_specs = self._specification.num_specifications

        def select_fn(values, position):
            return values[:, :, position]

    def objective_fn(candidate):
        # Model output is [restarts * batch_size, output].
        flat_logits = self._eval_fn(candidate)
        per_restart_logits = tf.reshape(
            flat_logits, [self._num_restarts, batch_size, -1])
        # Output has dimension [num_restarts, batch_size, num_specifications].
        return self._specification.evaluate(per_restart_logits)

    def flat_objective_fn(candidate):
        return _maximize_margin(objective_fn(candidate))

    def build_loss_fn(spec_position):
        def _reduced_loss_fn(candidate):
            # Pick worse attack, output has shape [num_restarts, batch_size].
            selected = select_fn(objective_fn(candidate), spec_position)
            return -tf.reduce_sum(selected)
        return _reduced_loss_fn

    if _is_spsa_optimizer(self._optimizer_builder):
        raise ValueError(
            '"UnrolledSPSA*" unsupported in MultiTargetedPGDAttack')
    optimizer = self._optimizer_builder(lr=self._lr, lr_fn=self._lr_fn)

    # Run a separate PGD attack for each specification.
    def cond(spec_idx, unused_attack, success):
        # If we are already successful, we break.
        specs_remaining = spec_idx < num_specs
        not_all_broken = tf.logical_not(tf.reduce_all(success))
        return tf.logical_and(specs_remaining, not_all_broken)

    def body(spec_idx, attack, success):
        """Runs a separate PGD attack for each specification."""
        adversarial_input = pgd_attack(
            build_loss_fn(spec_idx),
            duplicated_inputs,
            epsilon=self._epsilon,
            num_steps=self._num_steps,
            image_bounds=self._input_bounds,
            random_init=self._random_init,
            optimizer=optimizer,
            project_perturbation=self._project_perturbation)
        candidate_attack = self.find_worst_attack(
            flat_objective_fn, adversarial_input, batch_size, input_shape)
        candidate_logits = self._eval_fn(candidate_attack)
        # Count the number of sample that violate any specification.
        candidate_success = _any_greater(
            self._specification.evaluate(candidate_logits))
        # The first iteration always sets the attack and logits.
        take_candidate = tf.logical_or(tf.equal(spec_idx, 0),
                                       candidate_success)
        print_op = tf.print('Processed specification #', spec_idx)
        with tf.control_dependencies([print_op]):
            next_spec_idx = spec_idx + 1
        merged_attack = tf.where(take_candidate, candidate_attack, attack)
        merged_success = tf.logical_or(success, candidate_success)
        return (next_spec_idx, merged_attack, merged_success)

    initial_loop_vars = [
        tf.constant(0, dtype=tf.int32),
        inputs,
        tf.zeros([tf.shape(inputs)[0]], dtype=tf.bool),
    ]
    _, self._attack, self._success = tf.while_loop(
        cond,
        body,
        back_prop=False,
        parallel_iterations=1,
        loop_vars=initial_loop_vars)
    self._logits = self._eval_fn(self._attack, mode='final')
    return self._attack
def ssd_decode_and_crop(image_buffer, boxes, classes, raw_shape):
    """Randomly crops a JPEG and decodes only the cropped region.

    Crop proposals are drawn repeatedly until one is accepted. Each pass
    either keeps the whole image (with probability
    `ssd_constants.P_NO_CROP_PER_PASS`) or draws
    `ssd_constants.NUM_CROP_PASSES` candidate windows and accepts one that
    satisfies all of:
      1. height/width and width/height are both below 2;
      2. the IoU between the window and every box exceeds a threshold picked
         at random from `ssd_constants.CROP_MIN_IOU_CHOICES`;
      3. at least one box center lies inside the window.

    JPEG decoding is deferred until after the crop to avoid wasted work.

    Reference: https://github.com/chauhan-utk/ssd.DomainAdaptation

    Args:
      image_buffer: Tensor tf.string containing the contents of a JPEG file.
      boxes: Tensor tf.float32 of shape [num_boxes, 4], containing
        coordinates of object bounding boxes.
      classes: Tensor tf.int64 of shape [num_boxes, 1], containing class
        labels of objects.
      raw_shape: [height, width, 3].

    Returns:
      resized_image: decoded, cropped, and resized image Tensor tf.float32 of
        shape [ssd_constants.IMAGE_SIZE, ssd_constants.IMAGE_SIZE, 3], value
        range 0--255.
      cropped_boxes: box coordinates for objects in the cropped region,
        clipped to the crop and renormalized to it.
      cropped_classes: class labels for objects in the cropped region.
    """
    num_boxes = tf.shape(boxes)[0]

    def no_crop_check():
        draw = tf.random_uniform(
            shape=(), minval=0, maxval=1, dtype=tf.float32)
        return draw < ssd_constants.P_NO_CROP_PER_PASS

    def no_crop_proposal():
        accepted = tf.ones((), tf.bool)
        whole_image = tf.convert_to_tensor([0, 0, 1, 1], dtype=tf.float32)
        keep_every_box = tf.ones((num_boxes,), tf.bool)
        return accepted, whole_image, keep_every_box

    def crop_proposal():
        def rand_vec(minval, maxval):
            return tf.random_uniform(
                shape=(ssd_constants.NUM_CROP_PASSES, 1),
                minval=minval,
                maxval=maxval,
                dtype=tf.float32)

        def tiled_centers(first_coord):
            # Midpoint of the two box coordinates along one axis, repeated
            # once per candidate window.
            centers = 0.5 * (boxes[:, first_coord] + boxes[:, first_coord + 2])
            return tf.tile(centers[tf.newaxis, :],
                           (ssd_constants.NUM_CROP_PASSES, 1))

        width = rand_vec(0.3, 1)
        height = rand_vec(0.3, 1)
        left = rand_vec(0, 1 - width)
        top = rand_vec(0, 1 - height)

        right = left + width
        bottom = top + height

        ltrb = tf.concat([left, top, right, bottom], axis=1)

        min_iou = tf.random_shuffle(ssd_constants.CROP_MIN_IOU_CHOICES)[0]
        ious = calc_iou_tensor(ltrb, boxes)

        # discard any bboxes whose center not in the cropped image
        xc = tiled_centers(0)
        yc = tiled_centers(1)

        center_tests = tf.stack([
            tf.greater(xc, tf.tile(left, (1, num_boxes))),
            tf.less(xc, tf.tile(right, (1, num_boxes))),
            tf.greater(yc, tf.tile(top, (1, num_boxes))),
            tf.less(yc, tf.tile(bottom, (1, num_boxes))),
        ], axis=2)
        masks = tf.reduce_all(center_tests, axis=2)

        # Checks of whether a crop is valid.
        valid_aspect = tf.logical_and(tf.less(height / width, 2),
                                      tf.less(width / height, 2))
        valid_ious = tf.reduce_all(
            tf.greater(ious, min_iou), axis=1, keepdims=True)
        valid_masks = tf.reduce_any(masks, axis=1, keepdims=True)

        all_checks = tf.concat(
            [valid_aspect, valid_ious, valid_masks], axis=1)
        valid_all = tf.cast(tf.reduce_all(all_checks, axis=1), tf.int32)

        # One indexed, as zero is needed for the case of no matches.
        index = tf.range(
            1, 1 + ssd_constants.NUM_CROP_PASSES, dtype=tf.int32)

        # Either one-hot, or zeros if there is no valid crop. The highest
        # valid index wins.
        selection = tf.equal(tf.reduce_max(index * valid_all), index)

        use_crop = tf.reduce_any(selection)
        selection_weights = tf.tile(
            tf.cast(selection, tf.float32)[:, tf.newaxis], (1, 4))
        output_ltrb = tf.reduce_sum(
            tf.multiply(ltrb, selection_weights), axis=0)
        selected_rows = tf.tile(selection[:, tf.newaxis], (1, num_boxes))
        output_masks = tf.reduce_any(
            tf.logical_and(masks, selected_rows), axis=0)

        return use_crop, output_ltrb, output_masks

    def proposal(*args):
        # The previous loop state is ignored; every pass draws afresh.
        return tf.cond(
            pred=no_crop_check(),
            true_fn=no_crop_proposal,
            false_fn=crop_proposal,
        )

    initial_state = [
        tf.zeros((), tf.bool),
        tf.zeros((4,), tf.float32),
        tf.zeros((num_boxes,), tf.bool),
    ]
    _, crop_bounds, box_masks = tf.while_loop(
        cond=lambda accepted, *_: tf.logical_not(accepted),
        body=proposal,
        loop_vars=initial_state,
    )

    filtered_boxes = tf.boolean_mask(boxes, box_masks, axis=0)

    mlperf.logger.log(key=mlperf.tags.NUM_CROPPING_ITERATIONS,
                      value=ssd_constants.NUM_CROP_PASSES)

    # Clip boxes to the cropped region.
    clipped_columns = [
        tf.maximum(filtered_boxes[:, 0], crop_bounds[0]),
        tf.maximum(filtered_boxes[:, 1], crop_bounds[1]),
        tf.minimum(filtered_boxes[:, 2], crop_bounds[2]),
        tf.minimum(filtered_boxes[:, 3], crop_bounds[3]),
    ]
    filtered_boxes = tf.stack(clipped_columns, axis=1)

    left = crop_bounds[0]
    top = crop_bounds[1]
    width = crop_bounds[2] - left
    height = crop_bounds[3] - top

    # Re-express the clipped boxes relative to the crop window.
    renormalized_columns = [
        (filtered_boxes[:, 0] - left) / width,
        (filtered_boxes[:, 1] - top) / height,
        (filtered_boxes[:, 2] - left) / width,
        (filtered_boxes[:, 3] - top) / height,
    ]
    cropped_boxes = tf.stack(renormalized_columns, axis=1)

    # crop_window containing integer coordinates of cropped region. A
    # normalized coordinate value of y should be mapped to the image
    # coordinate at y * (height - 1).
    # NOTE(review): `left`/`width` are scaled by raw_shape[0] (documented as
    # height) and `top`/`height` by raw_shape[1]; presumably the box
    # coordinates are stored y-first despite the local names - confirm.
    raw_shape = tf.cast(raw_shape, tf.float32)
    window_values = [
        left * (raw_shape[0] - 1),
        top * (raw_shape[1] - 1),
        width * raw_shape[0],
        height * raw_shape[1],
    ]
    crop_window = tf.cast(tf.stack(window_values), tf.int32)

    # Fused op only decodes the cropped portion of an image
    cropped_image = tf.image.decode_and_crop_jpeg(
        image_buffer, crop_window, channels=3)

    # Resize converts image dtype from uint8 to float32, without rescaling
    # values.
    target_size = [ssd_constants.IMAGE_SIZE, ssd_constants.IMAGE_SIZE]
    resized_image = tf.image.resize_images(cropped_image, target_size)
    mlperf.logger.log(key=mlperf.tags.INPUT_SIZE,
                      value=ssd_constants.IMAGE_SIZE)

    cropped_classes = tf.boolean_mask(classes, box_masks, axis=0)

    return resized_image, cropped_boxes, cropped_classes
def _rnn_fn(sample_arc, x, prev_s, w_prev, w_skip, input_mask, layer_mask,
            params):
    """Multi-layer recurrent cell whose wiring is given by `sample_arc`.

    Args:
      sample_arc: [num_layers * 2], sequence of tokens representing the
        architecture, laid out as (prev_idx, func_idx) pairs, one per layer.
      x: [batch_size, num_steps, hidden_size].
      prev_s: [batch_size, hidden_size].
      w_prev: [2 * hidden_size, 2 * hidden_size].
      w_skip: sequence with one entry per layer; entry `layer_id` is indexed
        as `w_skip[layer_id][func_idx, prev_idx]` to obtain that layer's
        weight matrix.
      input_mask: `[batch_size, hidden_size]`, or None. Must not be None when
        `layer_mask` is given.
      layer_mask: `[batch_size, hidden_size]`, or None to disable masking.
      params: hyper-params object; `params.hidden_size` is read.

    Returns:
      next_s: [batch_size, hidden_size], the state after the last step.
      all_s: [batch_size, num_steps, hidden_size], the state at every step.
      var_s: list of the weights in use (`w_prev` plus the selected skip
        weights of every layer but the first), e.g. for L2 regularization.
    """
    batch_size = x.get_shape()[0].value
    num_steps = tf.shape(x)[1]
    num_layers = len(sample_arc) // 2

    all_s = tf.TensorArray(dtype=tf.float32, size=num_steps,
                           infer_shape=False)

    # Decode the architecture once: (prev_idx, func_idx) for each layer.
    arc_pairs = [(sample_arc[2 * layer_id], sample_arc[2 * layer_id + 1])
                 for layer_id in range(num_layers)]

    # extract the relevant variables, so that you only do L2-reg on them.
    w_skip = [w_skip[layer_id][func_idx, prev_idx]
              for layer_id, (prev_idx, func_idx) in enumerate(arc_pairs)]
    var_s = [w_prev] + w_skip[1:]

    def _select_function(h, function_id):
        """Applies the activation chosen by `function_id` to `h`."""
        h = tf.stack([tf.tanh(h), tf.nn.relu(h), tf.sigmoid(h), h], axis=0)
        h = h[function_id]
        return h

    def _condition(step, *unused_args):
        return tf.less(step, num_steps)

    def _body(step, prev_s, all_s):
        """Body function."""
        inp = x[:, step, :]

        # important change: first input uses a tanh()
        if layer_mask is not None:
            assert input_mask is not None
            ht = tf.matmul(
                tf.concat([inp * input_mask, prev_s * layer_mask], axis=1),
                w_prev)
        else:
            ht = tf.matmul(tf.concat([inp, prev_s], axis=1), w_prev)
        h, t = tf.split(ht, 2, axis=1)
        h = tf.tanh(h)
        t = tf.sigmoid(t)
        # Highway-style update: gate `t` interpolates between prev_s and h.
        s = prev_s + t * (h - prev_s)
        layers = [s]

        for layer_id, (prev_idx, func_idx) in enumerate(arc_pairs):
            # Each layer reads from whichever earlier output the
            # architecture selects.
            prev_s = tf.stack(layers, axis=0)[prev_idx]
            if layer_mask is not None:
                ht = tf.matmul(prev_s * layer_mask, w_skip[layer_id])
            else:
                ht = tf.matmul(prev_s, w_skip[layer_id])
            h, t = tf.split(ht, 2, axis=1)
            h = _select_function(h, func_idx)
            t = tf.sigmoid(t)
            s = prev_s + t * (h - prev_s)
            s.set_shape([batch_size, params.hidden_size])
            layers.append(s)

        # The step output averages all layers except the input layer.
        next_s = tf.add_n(layers[1:]) / tf.cast(num_layers, dtype=tf.float32)
        all_s = all_s.write(step, next_s)

        return step + 1, next_s, all_s

    loop_inps = [tf.constant(0, dtype=tf.int32), prev_s, all_s]
    _, next_s, all_s = tf.while_loop(_condition, _body, loop_inps)
    # The stacked array is time-major; return it batch-major.
    all_s = tf.transpose(all_s.stack(), [1, 0, 2])

    return next_s, all_s, var_s
def select_indices_stratified(size,
                              scores,
                              clusters,
                              indices=None,
                              soft=False,
                              parallel_iterations=8) -> tf.Tensor:
    """Selects indices of instances to label, stratified by cluster.

    The labeling budget is first split across clusters with
    `Sampler.stratify_by_cluster`; `Sampler.select_indices` is then applied
    within every cluster that received a non-zero share.

    Parameters
    ----------
    size : int
        Number of samples to label.

    scores : Tensor <float32> [num_samples]
        A vector of scores that are used to select which sample to label.

    clusters : Tensor <int32> [num_samples]
        A vector of cluster indices used for sampling stratification.

    indices : Tensor <int32> [num_instances], optional
        A vector of absolute indices of the samples in a larger collection.
        If not None, the method returns `selected_indices` from `indices`.
        Otherwise, `selected_indices` are relative.

    soft : bool, optional (default=False)
        Whether to select top indices softly by sampling a categorical
        distribution with logits proportional to the scores.

    parallel_iterations : int (default: 8)
        Number of parallel iterations passed to tf.while_loop.

    Returns
    -------
    selected_indices : Tensor <int32> [size]
    """
    # size_per_cluster: <int32> [num_unique_clusters].
    # unique_clusters: <int32> [num_unique_clusters].
    size_per_cluster, unique_clusters = Sampler.stratify_by_cluster(
        size, clusters, parallel_iterations=parallel_iterations)

    def _clusters_remaining(cluster_pos, _unused_selected):
        return tf.less(cluster_pos, tf.size(size_per_cluster))

    def _select_within_cluster(cluster_pos, selected_ta):
        in_cluster = tf.equal(clusters, unique_clusters[cluster_pos])
        member_indices = tf.where(in_cluster)[:, 0]
        member_scores = tf.gather(scores, member_indices, axis=0)

        def _pick():
            return Sampler.select_indices(
                size=size_per_cluster[cluster_pos],
                scores=member_scores,
                indices=member_indices,
                soft=soft,
            )

        def _pick_nothing():
            return tf.constant([], dtype=tf.int32)

        # Clusters with a zero budget contribute an empty selection.
        picked = tf.cond(
            pred=tf.greater(size_per_cluster[cluster_pos], 0),
            true_fn=_pick,
            false_fn=_pick_nothing,
        )
        return [
            tf.add(cluster_pos, 1),
            selected_ta.write(cluster_pos, picked)
        ]

    # Select indices for each cluster; selections vary in length, hence
    # infer_shape=False on the array.
    initial_loop_vars = [
        tf.constant(0),
        tf.TensorArray(dtype=tf.int32,
                       infer_shape=False,
                       size=tf.size(unique_clusters)),
    ]
    _, selected_indices_ta = tf.while_loop(
        cond=_clusters_remaining,
        body=_select_within_cluster,
        loop_vars=initial_loop_vars,
        back_prop=False,
        parallel_iterations=parallel_iterations,
        name="stratified-index-selection",
    )
    flat_selection = tf.reshape(selected_indices_ta.concat(), shape=(size,))

    if indices is None:
        return flat_selection
    return tf.gather(indices, flat_selection, axis=0)
def hmc(energy_fn,
        init_X,
        L=20,
        step_size=1.0,
        burn_in=100,
        num_samples=1000,
        thinning_steps=1,
        max_steps=None):
    """Draws samples with Hamiltonian Monte Carlo inside a `tf.while_loop`.

    Each step samples a fresh Gaussian momentum, runs `L` leapfrog updates and
    accepts or rejects the proposal with a Metropolis test. Accepted states
    after burn-in are recorded until `num_samples * thinning_steps` of them
    have been collected or the step budget runs out; every
    `thinning_steps`-th recorded state is returned.

    Also emits the "acceptance_ratio" and "num_hmc_steps" scalar summaries.

    Args:
      energy_fn: callable mapping a state tensor to its scalar potential
        energy; it must be differentiable with respect to the state.
      init_X: Tensor, the initial state of the chain.
      L: int, number of leapfrog steps per proposal.
      step_size: float, leapfrog step size.
      burn_in: int, number of initial steps whose states are never recorded.
      num_samples: int, number of samples to return.
      thinning_steps: int, keep one out of this many recorded states.
      max_steps: int or None, maximum number of steps after burn-in. Defaults
        to `1000 * num_samples * thinning_steps`.

    Returns:
      Tensor of shape [num_samples, data_dim] holding the retained samples,
      with each state flattened.
    """
    # NOTE(review): if the loop ends on the step budget before enough
    # proposals are accepted, part of this array is never written - confirm
    # how `stack()` below should behave in that case.
    samples = tf.TensorArray(init_X.dtype,
                             size=num_samples * thinning_steps,
                             dynamic_size=False,
                             name='samples_ta')

    X_shape = tf.shape(init_X)

    if max_steps is None:
        max_steps = 1000 * num_samples * thinning_steps

    def hmc_step(i, num_accepted, q, samples):
        """Proposes one leapfrog trajectory and applies the Metropolis test."""
        # Sample momentum variables as standard Gaussians.
        p = tf.random.normal(X_shape, mean=0., stddev=1.)

        init_q = q

        # Compute initial kinetic and potential energies.
        init_K = tf.reduce_sum(tf.square(p)) / 2.
        init_U = energy_fn(q)

        # Do first half-step
        p = p - step_size * tf.gradients(init_U, q)[0] / 2.

        # Run for L steps. The last full momentum step is skipped because a
        # half step follows the loop.
        for step in range(L):
            q = q + step_size * p
            if step != L - 1:
                p = p - step_size * tf.gradients(energy_fn(q), q)[0]

        proposed_U = energy_fn(q)
        p = p - step_size * tf.gradients(proposed_U, q)[0] / 2.

        # Negate the momentum so that the proposal is symmetric.
        p = -p
        proposed_K = tf.reduce_sum(tf.square(p)) / 2.

        p = tf.debugging.check_numerics(p, "Nans in p.")
        q = tf.debugging.check_numerics(q, "Nans in q.")

        accept = tf.random.uniform(
            []) < tf.exp(init_U - proposed_U + init_K - proposed_K)

        # NOTE(review): with `i > burn_in` the step at i == burn_in is also
        # discarded, while the acceptance ratio computed after the loop
        # divides by num_steps - burn_in - confirm this is intended.
        accept_samples = tf.logical_and(accept, i > burn_in)

        samples = tf.cond(accept_samples,
                          lambda: samples.write(num_accepted, q),
                          lambda: samples)
        accept_samples = tf.squeeze(accept_samples)
        # A rejected proposal leaves the chain at its previous state.
        q = tf.cond(accept, lambda: q, lambda: init_q)
        return (i + 1,
                num_accepted + tf.cast(accept_samples, tf.int32),
                q,
                samples)

    def hmc_predicate(i, num_accepted, unused_q, unused_samples):
        """Continue while under the step budget and still short of samples."""
        return tf.logical_and(
            tf.less(i, burn_in + max_steps),
            tf.less(num_accepted, num_samples * thinning_steps))

    results = tf.while_loop(hmc_predicate,
                            hmc_step, (0, 0, init_X, samples),
                            back_prop=False)

    # Keep the last recorded state of every thinning group; the result is
    # [num_samples, data_dim].
    samples = results[-1].stack()
    samples = tf.reshape(samples, [num_samples, thinning_steps, -1])
    samples = samples[:, -1, :]

    num_steps = results[0]
    num_accepted = results[1]
    accept_ratio = num_accepted / (num_steps - burn_in)

    tf.summary.scalar("acceptance_ratio", accept_ratio)
    tf.summary.scalar("num_hmc_steps", num_steps - burn_in)

    return samples
def _slow_greedy_infer(self, features, decode_length):
    """A slow greedy inference method.

    Quadratic time in decode_length.

    Conditioning features ("inputs", or otherwise "melody" and
    "performance") of rank below 4 are temporarily expanded to rank 4 in
    `features`, and each is put back under its own key before returning.

    Args:
      features: an map of string to `Tensor`
      decode_length: an integer.  How many additional timesteps to decode.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": None
          "logits": `Tensor` of shape [batch_size, time, 1, 1, vocab_size].
          "losses": a dictionary: {loss-name (string): floating point `Scalar`}
      }
    """
    if not features:
        features = {}

    # Original tensors of every conditioning feature that gets an extra axis,
    # keyed by feature name so each one is restored under its own key.
    expanded_originals = {}

    def _expand_to_rank4(key):
        if len(features[key].shape) < 4:
            expanded_originals[key] = features[key]
            features[key] = tf.expand_dims(features[key], 2)

    # process all conditioning features
    if "inputs" in features:
        _expand_to_rank4("inputs")
    else:
        # this would be for melody decoding
        for conditioning_key in ("melody", "performance"):
            if conditioning_key in features:
                _expand_to_rank4(conditioning_key)

    if not self.has_input:
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs")
        if partial_targets is None:
            partial_targets = features["targets"]
        features["partial_targets"] = tf.to_int64(partial_targets)

    # Save the targets in a var and reassign it after the tf.while loop to
    # avoid having targets being in a 'while' frame. This ensures targets when
    # used in metric functions stays in the same frame as other vars.
    targets_old = features.get("targets", None)

    target_modality = self._problem_hparams.modality["targets"]

    def infer_step(recent_output, recent_logits, unused_loss):
        """Inference step."""
        if not tf.executing_eagerly():
            if self._target_modality_is_real:
                dim = self._problem_hparams.vocab_size["targets"]
                if dim is not None and hasattr(self._hparams, "vocab_divisor"):
                    dim += (-dim) % self._hparams.vocab_divisor
                recent_output.set_shape([None, None, None, dim])
            else:
                recent_output.set_shape([None, None, None, 1])
        padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
        features["targets"] = padded
        # This is inefficient in that it generates samples at all timesteps,
        # not just the last one, except if target_modality is pointwise.
        samples, logits, losses = self.sample(features)
        # Concatenate the already-generated recent_output with last timestep
        # of the newly-generated samples.
        top = self._hparams.top.get("targets",
                                    modalities.get_top(target_modality))
        if getattr(top, "pointwise", False):
            cur_sample = samples[:, -1, :, :]
        else:
            cur_sample = samples[:,
                                 common_layers.shape_list(recent_output)[1],
                                 :, :]
        if self._target_modality_is_real:
            cur_sample = tf.expand_dims(cur_sample, axis=1)
            samples = tf.concat([recent_output, cur_sample], axis=1)
        else:
            cur_sample = tf.to_int64(tf.expand_dims(cur_sample, axis=1))
            samples = tf.concat([recent_output, cur_sample], axis=1)
            if not tf.executing_eagerly():
                samples.set_shape([None, None, None, 1])

        # Assuming we have one shard for logits.
        logits = tf.concat([recent_logits, logits[:, -1:]], 1)
        loss = sum([l for l in losses.values() if l is not None])
        return samples, logits, loss

    # Create an initial output tensor. This will be passed
    # to the infer_step, which adds one timestep at every iteration.
    if "partial_targets" in features:
        initial_output = tf.to_int64(features["partial_targets"])
        while len(initial_output.get_shape().as_list()) < 4:
            initial_output = tf.expand_dims(initial_output, 2)
        batch_size = common_layers.shape_list(initial_output)[0]
    else:
        batch_size = common_layers.shape_list(features["performance"])[0]
        if self._target_modality_is_real:
            dim = self._problem_hparams.vocab_size["targets"]
            if dim is not None and hasattr(self._hparams, "vocab_divisor"):
                dim += (-dim) % self._hparams.vocab_divisor
            initial_output = tf.zeros((batch_size, 0, 1, dim),
                                      dtype=tf.float32)
        else:
            initial_output = tf.zeros((batch_size, 0, 1, 1), dtype=tf.int64)
    # Hack: foldl complains when the output shape is less specified than the
    # input shape, so we confuse it about the input shape.
    initial_output = tf.slice(initial_output, [0, 0, 0, 0],
                              common_layers.shape_list(initial_output))
    if target_modality == modalities.ModalityType.CLASS_LABEL:
        decode_length = 1
    else:
        if "partial_targets" in features:
            prefix_length = common_layers.shape_list(
                features["partial_targets"])[1]
        else:
            # this code will generate outputs that tend to be long,
            # but this is to avoid the case when the melody is extremely
            # short. this can be changed to features["melody"] for the actual
            # behavior.
            prefix_length = common_layers.shape_list(
                features["performance"])[1]
        decode_length = prefix_length + decode_length

    # Initial values of result, logits and loss.
    result = initial_output
    vocab_size = self._problem_hparams.vocab_size["targets"]
    if vocab_size is not None and hasattr(self._hparams, "vocab_divisor"):
        vocab_size += (-vocab_size) % self._hparams.vocab_divisor
    if self._target_modality_is_real:
        logits = tf.zeros((batch_size, 0, 1, vocab_size))
        logits_shape_inv = [None, None, None, None]
    else:
        # tensor of shape [batch_size, time, 1, 1, vocab_size]
        logits = tf.zeros((batch_size, 0, 1, 1, vocab_size))
        logits_shape_inv = [None, None, None, None, None]
    if not tf.executing_eagerly():
        logits.set_shape(logits_shape_inv)

    loss = 0.0

    def while_exit_cond(result, logits, loss):  # pylint: disable=unused-argument
        """Exit the loop either if reach decode_length or EOS."""
        length = common_layers.shape_list(result)[1]

        not_overflow = length < decode_length

        if self._problem_hparams.stop_at_eos:

            def fn_not_eos():
                return tf.not_equal(  # Check if the last predicted element is a EOS
                    tf.squeeze(result[:, -1, :, :]), text_encoder.EOS_ID)

            not_eos = tf.cond(
                # We only check for early stopping if there is at least 1
                # element (otherwise not_eos will crash).
                tf.not_equal(length, 0),
                fn_not_eos,
                lambda: True,
            )

            return tf.cond(
                tf.equal(batch_size, 1),
                # If batch_size == 1, we check EOS for early stopping.
                lambda: tf.logical_and(not_overflow, not_eos),
                # Else, just wait for max length
                lambda: not_overflow)
        return not_overflow

    result, logits, loss = tf.while_loop(
        while_exit_cond,
        infer_step, [result, logits, loss],
        shape_invariants=[
            tf.TensorShape([None, None, None, None]),
            tf.TensorShape(logits_shape_inv),
            tf.TensorShape([]),
        ],
        back_prop=False,
        parallel_iterations=1)

    # Restore the expanded conditioning features to not confuse Estimator.
    for conditioning_key, original in expanded_originals.items():
        features[conditioning_key] = original
    # Reassign targets back to the previous value.
    if targets_old is not None:
        features["targets"] = targets_old
    losses = {"training": loss}
    if "partial_targets" in features:
        # Strip the forced prefix so only newly decoded steps are returned.
        partial_target_length = common_layers.shape_list(
            features["partial_targets"])[1]
        result = tf.slice(result, [0, partial_target_length, 0, 0],
                          [-1, -1, -1, -1])
    return {
        "outputs": result,
        "scores": None,
        "logits": logits,
        "losses": losses,
    }
def build_sample_graph(self,
                       input_pianorolls=None,
                       outer_masks=None,
                       total_gibbs_steps=None):
    """Builds the tf.while_loop based sampling graph.

    The model is evaluated once to fill in everything selected by the outer
    masks; Gibbs steps then repeatedly re-mask and resample parts of the
    pianorolls until `sample_steps` is reached. The result is also stored as
    `self.samples`.

    Args:
      input_pianorolls: Optional input pianorolls override. If None, uses the
        pianorolls placeholder.
      outer_masks: Optional input outer_masks override. If None, uses the
        outer_masks placeholder.
      total_gibbs_steps: Optional input total_gibbs_steps override. If None,
        uses the total_gibbs_steps placeholder.

    Returns:
      The output op of the graph.
    """
    if input_pianorolls is None:
        input_pianorolls = self.inputs["pianorolls"]
    if outer_masks is None:
        outer_masks = self.inputs["outer_masks"]

    num_timesteps = tf.shape(input_pianorolls)[1]
    sample_steps = tf.to_float(self.inputs["sample_steps"])
    if total_gibbs_steps is None:
        total_gibbs_steps = self.inputs["total_gibbs_steps"]
    temperature = self.inputs["temperature"]

    input_pianorolls = tf.to_float(input_pianorolls)
    outer_masks = self.make_outer_masks(outer_masks, input_pianorolls)

    # Calculate total_gibbs_steps as steps * num_instruments if not given.
    def _default_total_steps():
        return tf.to_float(num_timesteps * self.hparams.num_instruments)

    def _given_total_steps():
        return tf.to_float(total_gibbs_steps)

    total_gibbs_steps = tf.cond(
        tf.equal(total_gibbs_steps, 0),
        _default_total_steps,
        _given_total_steps)

    # sample_steps is set to total_gibbs_steps if not given.
    def _default_sample_steps():
        return total_gibbs_steps

    def _given_sample_steps():
        return tf.to_float(sample_steps)

    sample_steps = tf.cond(
        tf.equal(sample_steps, 0),
        _default_sample_steps,
        _given_sample_steps)

    def infer_step(pianorolls, step_count):
        """Called by tf.while_loop, takes a Gibbs step."""
        mask_prob = compute_mask_prob_from_yao_schedule(
            step_count, total_gibbs_steps)
        # 1 indicates mask out, 0 is not mask.
        masks = make_bernoulli_masks(
            tf.shape(pianorolls), mask_prob, outer_masks)

        logits = self.predict(pianorolls, masks)
        resampled = sample_with_temperature(logits, temperature=temperature)

        # Keep unmasked entries and take the fresh samples where masked.
        outputs = pianorolls * (1 - masks) + resampled * masks

        # Wherever something was masked the new output is checked, elsewhere
        # the previous pianoroll; the max over axis 2 must equal 1 in both.
        was_masked = tf.equal(tf.reduce_max(masks, axis=2), 1.)
        filled_values = tf.where(was_masked,
                                 tf.reduce_max(outputs, axis=2),
                                 tf.reduce_max(pianorolls, axis=2))
        check_completion_op = tf.assert_equal(filled_values, 1.)
        with tf.control_dependencies([check_completion_op]):
            outputs = tf.identity(outputs)

        step_count = step_count + 1
        return outputs, step_count

    def _more_steps(unused_samples, step):
        return step < sample_steps

    current_step = tf.to_float(self.inputs["current_step"])

    # Initializes pianorolls by evaluating the model once to fill in all gaps.
    logits = self.predict(tf.to_float(input_pianorolls), outer_masks)
    samples = sample_with_temperature(logits, temperature=temperature)

    # The loop body calls self.predict again, so its variables must be reused.
    tf.get_variable_scope().reuse_variables()

    loop_invariants = [
        tf.TensorShape([None, None, None, None]),
        tf.TensorShape(None),
    ]
    self.samples, current_step = tf.while_loop(
        _more_steps,
        infer_step,
        [samples, current_step],
        shape_invariants=loop_invariants,
        back_prop=False,
        parallel_iterations=1,
        name="coco_while")
    self.samples.set_shape(input_pianorolls.shape)
    return self.samples
def infer(self, features, **kwargs):
    """Decodes one position per loop step and reassembles the result.

    `features["targets"]` is temporarily replaced by a zero tensor while the
    graph is built and is put back at the end if it was present on entry.

    Args:
      features: dict of input features; mutated temporarily as described.
      **kwargs: accepted for interface compatibility and not used here.

    Returns:
      A dict with keys "outputs" (decoded ids reshaped to
      [-1, frame_height, frame_width, num_channels]), "logits" (same layout
      with a trailing `targets_vocab_size` axis), and "scores" / "losses",
      which are both None.
    """
    decode_length = (self.frame_height * self.frame_width * self.num_channels)
    cache = {}
    decoding_stats = {}
    saved_targets = features.get("targets")

    initial_output = tf.zeros((self.batch_size, decode_length), dtype=tf.int32)
    initial_logits = tf.zeros(
        (self.batch_size, decode_length, self.targets_vocab_size))

    # Run the body once so the cache holds representations of the inputs.
    features["targets"] = initial_output
    with tf.variable_scope("sparse_imagetransformer/body",
                           reuse=tf.AUTO_REUSE,
                           use_resource=True):
        self.body(features,
                  decode_step=None,
                  cache=cache,
                  decoding_stats=decoding_stats)

    def decode_one(i, recent_output, recent_logits, cache, decoding_stats):
        """Samples position `i` and writes it into the running buffers."""
        step_features = features.copy()
        step_features["targets"] = recent_output
        cur_sample, cur_logit = self.sample(step_features,
                                            decode_step=i,
                                            cache=cache,
                                            decoding_stats=decoding_stats)
        # One (batch, position) coordinate per batch element.
        write_at = [[b, i] for b in range(self.batch_size)]
        sample_delta = tf.scatter_nd(
            indices=write_at,
            updates=cur_sample,
            shape=utils.shape_list(recent_output))
        samples = recent_output + sample_delta
        logit_delta = tf.scatter_nd(
            indices=write_at,
            updates=cur_logit,
            shape=utils.shape_list(recent_logits))
        logits = recent_logits + logit_delta
        return i + 1, samples, logits, cache, decoding_stats

    # Keep looping while the step index is below decode_length.
    loop_outputs = tf.while_loop(
        lambda i, *unused: i < decode_length,
        decode_one,
        [
            tf.constant(0), initial_output, initial_logits, cache,
            decoding_stats
        ],
        back_prop=False,
        parallel_iterations=1)
    _, final_result, final_logits, _, decoding_stats = loop_outputs

    query_shape = self.hparams.query_shape
    decoder_shape = self.get_shape_for_decoder()
    blocks_per_dim = [
        full // block for full, block in zip(decoder_shape, query_shape)
    ]
    positions_per_block = np.prod(query_shape)

    result_dims = utils.shape_list(final_result)
    final_result = tf.reshape(
        final_result, [result_dims[0], -1, positions_per_block, 1])
    logit_dims = utils.shape_list(final_logits)
    final_logits = tf.reshape(
        final_logits,
        [logit_dims[0], -1, positions_per_block, logit_dims[-1]])

    # Undo the block layout used during decoding.
    final_result = utils.unflatten_blocks_nd(final_result, blocks_per_dim)
    final_result = utils.put_back_blocks_nd(final_result, query_shape)
    final_logits = utils.unflatten_blocks_nd(final_logits, blocks_per_dim)
    final_logits = utils.put_back_blocks_nd(final_logits, query_shape)

    image_dims = [-1, self.frame_height, self.frame_width, self.num_channels]
    final_result = tf.reshape(final_result, image_dims)
    final_logits = tf.reshape(
        final_logits, image_dims + [self.targets_vocab_size])

    if utils.is_xla_compiled():
        _IMGS["decodes"] = final_result

    for name, value in decoding_stats.items():
        tf.summary.scalar("decodes/%s" % name, value / decode_length)

    # Put the caller's targets back if there were any.
    if saved_targets is not None:
        features["targets"] = saved_targets

    return {
        "outputs": final_result,
        "scores": None,
        "logits": final_logits,
        "losses": None,
    }
def dynamic_decode(decoder,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None):
    """Perform dynamic decoding with `decoder` for a fixed number of steps.

    Calls initialize() once and step() repeatedly on the Decoder object.
    Per-step outputs are written into dense buffers preallocated with
    `maximum_iterations` time steps, and the loop condition is constant, so
    the number of steps is bounded only by `maximum_iterations`.

    Args:
      decoder: A `Decoder` instance.
      impute_finished: Python boolean.  If `True`, then states for batch
        entries which are marked as finished get copied through and the
        corresponding outputs get zeroed out.  This causes some slowdown at
        each time step, but ensures that the final state and outputs have
        the correct values and that backprop ignores time steps that were
        marked as finished.
      maximum_iterations: `int32` scalar, number of decoding steps. The
        default of `None` is kept for signature compatibility only; a value
        is required because it sizes the output buffers and is the only
        thing that stops the loop.
      parallel_iterations: Argument passed to `tf.while_loop`.
      swap_memory: Argument passed to `tf.while_loop`.
      scope: Optional variable scope to use.

    Returns:
      `final_state.pred_ids` transposed with permutation `[2, 0, 1]`, where
      `final_state` is the loop's final state (after `decoder.finalize` when
      the decoder implements it).

    Raises:
      TypeError: if `decoder` is not an instance of `Decoder`.
      ValueError: if `maximum_iterations` is `None`, or is provided but is
        not a scalar.
    """
    if not isinstance(decoder, Decoder):
        raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                        type(decoder))

    with tf.variable_scope(scope, "decoder") as varscope:
        # Determine context types.
        ctxt = tf.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
        is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
        in_while_loop = (
            control_flow_util.GetContainingWhileContext(ctxt) is not None)
        # Properly cache variable values inside the while_loop.
        # Don't set a caching device when running in a loop, since it is
        # possible that train steps could be wrapped in a tf.while_loop. In
        # that scenario caching prevents forward computations in loop
        # iterations from re-reading the updated weights.
        if not tf.executing_eagerly() and not in_while_loop:
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        if maximum_iterations is not None:
            maximum_iterations = tf.convert_to_tensor(
                maximum_iterations, dtype=tf.int32, name="maximum_iterations")
            if maximum_iterations.get_shape().ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")

        initial_finished, initial_inputs, initial_state = decoder.initialize()

        zero_outputs = _create_zero_outputs(decoder.output_size,
                                            decoder.output_dtype,
                                            decoder.batch_size)

        if is_xla and maximum_iterations is None:
            raise ValueError(
                "maximum_iterations is required for XLA compilation.")
        if maximum_iterations is None:
            # Without a bound the output buffers below cannot be allocated
            # and the constant-True loop condition would never terminate, so
            # fail with a clear message instead.
            raise ValueError(
                "maximum_iterations is required: it sizes the preallocated "
                "output buffers and bounds the decoding loop.")
        initial_finished = tf.logical_or(initial_finished,
                                         0 >= maximum_iterations)
        initial_sequence_lengths = tf.zeros_like(initial_finished,
                                                 dtype=tf.int32)
        initial_time = tf.constant(0, dtype=tf.int32)

        def _create_ta(s, d):
            # Despite the name this is a dense [time, batch, size] buffer
            # updated in place by the loop body, not a TensorArray.
            return tf.zeros([maximum_iterations, decoder.batch_size, s],
                            dtype=d)

        initial_outputs_ta = contrib_framework.nest.map_structure(
            _create_ta, decoder.output_size, decoder.output_dtype)

        def condition(unused_time, unused_outputs_ta, unused_state,
                      unused_inputs, finished, unused_sequence_lengths):
            # Always continue; termination comes solely from the
            # `maximum_iterations` argument of tf.while_loop.
            return True

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

            Args:
              time: scalar int32 tensor.
              outputs_ta: structure of output buffers.
              state: (structure of) state tensors and TensorArrays.
              inputs: (structure of) input tensors.
              finished: bool tensor (keeping track of what's finished).
              sequence_lengths: int32 tensor (keeping track of time of
                finish).

            Returns:
              `(time + 1, outputs_ta, next_state, next_inputs,
              next_finished, next_sequence_lengths)`.
            """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = tf.logical_or(decoder_finished, finished)
            # Entries not yet finished at the start of this step advance
            # their recorded length to time + 1.
            next_sequence_lengths = tf.where(
                tf.logical_not(finished),
                tf.fill(tf.shape(sequence_lengths), time + 1),
                sequence_lengths)

            contrib_framework.nest.assert_same_structure(state, decoder_state)
            contrib_framework.nest.assert_same_structure(
                outputs_ta, next_outputs)
            contrib_framework.nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = contrib_framework.nest.map_structure(
                    lambda out, zero: tf.where(finished, zero, out),
                    next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tf.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else tf.where(finished, cur, new)

            if impute_finished:
                next_state = contrib_framework.nest.map_structure(
                    _maybe_copy_state, decoder_state, state)
            else:
                next_state = decoder_state

            # Write this step's output into row `time` of each buffer.
            outputs_ta = contrib_framework.nest.map_structure(
                lambda ta, out: inplace_ops.alias_inplace_update(
                    ta, time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)

        res = tf.while_loop(condition,
                            body,
                            loop_vars=(
                                initial_time,
                                initial_outputs_ta,
                                initial_state,
                                initial_inputs,
                                initial_finished,
                                initial_sequence_lengths,
                            ),
                            parallel_iterations=parallel_iterations,
                            maximum_iterations=maximum_iterations,
                            swap_memory=swap_memory)

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = final_outputs_ta

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths)
        except NotImplementedError:
            # Decoders without a finalize step keep the raw loop results.
            pass

    # NOTE(review): assumes the final state exposes a rank-3 `pred_ids`
    # attribute; confirm against the decoders used with this function.
    pred_ids = tf.transpose(final_state.pred_ids, [2, 0, 1])
    return pred_ids
def beam_search(symbols_to_logits_fn,
                initial_ids,
                beam_size,
                decode_length,
                vocab_size,
                alpha,
                states=None,
                eos_id=EOS_ID,
                stop_early=True,
                use_tpu=False,
                use_top_k_with_unique=True):
    """Beam search with length penalties.

    Requires a function that can take the currently decoded symbols and
    return the logits for the next symbol. The implementation is inspired by
    https://arxiv.org/abs/1609.08144.

    When running, the beam search steps can be visualized by using tfdbg to
    watch the operations generating the output ids for each beam step. These
    operations have the pattern:
      (alive|finished)_topk_(seq,scores)

    Operations marked `alive` represent the new beam sequences that will be
    processed in the next step.  Operations marked `finished` represent the
    completed beam sequences, which may be padded with 0s if no beams
    finished.

    Operations marked `seq` store the full beam sequence for the time step.
    Operations marked `scores` store the sequence's final log scores.

    The beam search steps will be processed sequentially in order, so when
    capturing tensors observed from these operations, clients can make
    assumptions about which step is being recorded.

    WARNING: Assumes the 2nd dimension of tensors in `states` is not
    invariant; this means that the shape of the 2nd dimension of these
    tensors will not be available (i.e. set to None) inside
    symbols_to_logits_fn.

    Args:
      symbols_to_logits_fn: Interface to the model, to provide logits.
          Should take [batch_size, decoded_ids] and return
          [batch_size, vocab_size]. It is additionally passed the loop index
          (and the flattened states) when `states` is non-empty, and the
          loop index alone when `states` is empty and `use_tpu` is set.
      initial_ids: Ids to start off the decoding, this will be the first
          thing handed to symbols_to_logits_fn (after expanding to beam size)
          [batch_size]
      beam_size: Size of the beam.
      decode_length: Number of steps to decode for.
      vocab_size: Size of the vocab, must equal the size of the logits
          returned by symbols_to_logits_fn
      alpha: alpha for length penalty.
      states: dict (possibly nested) of decoding states.
      eos_id: ID for end of sentence.
      stop_early: a boolean - stop once best sequence is provably determined.
      use_tpu: A bool, whether to do beam search on TPU.
      use_top_k_with_unique: bool, whether to use a fast (but decreased
          precision) top_k during TPU beam search.

    Returns:
      Tuple of
        (decoded beams [batch_size, beam_size, decode_length],
         decoding probabilities [batch_size, beam_size],
         the decoding states returned by the loop)
    """
    batch_size = common_layers.shape_list(initial_ids)[0]

    # Assume initial_ids are prob 1.0
    initial_log_probs = tf.constant([[0.] + [-INF] * (beam_size - 1)])
    # Expand to beam_size (batch_size, beam_size)
    alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])

    # Expand each batch and state to beam_size
    alive_seq = _expand_to_beam_size(initial_ids, beam_size)
    alive_seq = tf.expand_dims(alive_seq, axis=2)  # (batch_size, beam_size, 1)
    if use_tpu:
        alive_seq = tf.tile(alive_seq, [1, 1, decode_length + 1])
    if states:
        states = nest.map_structure(
            lambda state: _expand_to_beam_size(state, beam_size), states)
    else:
        states = {}

    # Finished will keep track of all the sequences that have finished so far
    # Finished log probs will be negative infinity in the beginning
    # finished_flags will keep track of booleans
    finished_seq = tf.zeros(common_layers.shape_list(alive_seq), tf.int32)
    # Setting the scores of the initial to negative infinity.
    finished_scores = tf.ones([batch_size, beam_size]) * -INF
    finished_flags = tf.zeros([batch_size, beam_size], tf.bool)

    def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq,
                      curr_scores, curr_finished):
        """Given sequences and scores, will gather the top k=beam size sequences.

        Args:
          finished_seq: Current finished sequences.
            [batch_size, beam_size, current_decoded_length]
          finished_scores: scores for each of these sequences.
            [batch_size, beam_size]
          finished_flags: finished bools for each of these sequences.
            [batch_size, beam_size]
          curr_seq: current topk sequence that has been grown by one position.
            [batch_size, beam_size, current_decoded_length]
          curr_scores: scores for each of these sequences.
            [batch_size, beam_size]
          curr_finished: Finished flags for each of these sequences.
            [batch_size, beam_size]
        Returns:
          Tuple of
            (Topk sequences based on scores,
             log probs of these sequences,
             Finished flags of these sequences)
        """
        if not use_tpu:
            # First append a column of 0 ids to finished so that it has the
            # same length as curr_seq, which has grown by one position.
            finished_seq = tf.concat(
                [finished_seq,
                 tf.zeros([batch_size, beam_size, 1], tf.int32)], axis=2)

        # Set the scores of the unfinished seq in curr_seq to large negative
        # values
        curr_scores += (1. - tf.to_float(curr_finished)) * -INF
        # concatenating the sequences and scores along beam axis
        curr_finished_seq = tf.concat([finished_seq, curr_seq], axis=1)
        curr_finished_scores = tf.concat([finished_scores, curr_scores],
                                         axis=1)
        curr_finished_flags = tf.concat([finished_flags, curr_finished],
                                        axis=1)
        # The scores are passed for both the ranking argument and the
        # gathered-scores argument.
        return compute_topk_scores_and_seq(
            curr_finished_seq,
            curr_finished_scores,
            curr_finished_scores,
            curr_finished_flags,
            beam_size,
            batch_size,
            "grow_finished",
            use_tpu=use_tpu,
            use_top_k_with_unique=use_top_k_with_unique)

    def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
                   states):
        """Given sequences and scores, will gather the top k=beam size sequences.

        Args:
          curr_seq: current topk sequence that has been grown by one position.
            [batch_size, beam_size, i+1]
          curr_scores: scores for each of these sequences.
            [batch_size, beam_size]
          curr_log_probs: log probs for each of these sequences.
            [batch_size, beam_size]
          curr_finished: Finished flags for each of these sequences.
            [batch_size, beam_size]
          states: dict (possibly nested) of decoding states.
        Returns:
          Tuple of
            (Topk sequences based on scores,
             log probs of these sequences,
             Finished flags of these sequences)
        """
        # Set the scores of the finished seq in curr_seq to large negative
        # values
        curr_scores += tf.to_float(curr_finished) * -INF
        return compute_topk_scores_and_seq(curr_seq, curr_scores,
                                           curr_log_probs, curr_finished,
                                           beam_size, batch_size,
                                           "grow_alive", states,
                                           use_tpu=use_tpu)

    def grow_topk(i, alive_seq, alive_log_probs, states):
        r"""Inner beam search loop.

        This function takes the current alive sequences, and grows them to
        topk sequences where k = 2*beam. We use 2*beam because, we could have
        beam_size number of sequences that might hit <EOS> and there will be
        no alive sequences to continue. With 2*beam_size, this will not
        happen. This relies on the assumption the vocab size is > beam size.
        If this is true, we'll have at least beam_size non <EOS> extensions if
        we extract the next top 2*beam words.
        Length penalty is given by = (5+len(decode)/6) ^ -\alpha. Pls refer to
        https://arxiv.org/abs/1609.08144.

        Args:
          i: loop index
          alive_seq: Topk sequences decoded so far
            [batch_size, beam_size, i+1]
          alive_log_probs: probabilities of these sequences.
            [batch_size, beam_size]
          states: dict (possibly nested) of decoding states.
        Returns:
          Tuple of
            (Topk sequences extended by the next word,
             The log probs of these sequences,
             The scores with length penalty of these sequences,
             Flags indicating which of these sequences have finished decoding,
             dict of transformed decoding states)
        """
        # Get the logits for all the possible next symbols
        if use_tpu and states:
            flat_ids = tf.reshape(
                tf.slice(alive_seq, [0, 0, i], [batch_size, beam_size, 1]),
                [batch_size * beam_size, -1])
        else:
            flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

        # (batch_size * beam_size, decoded_length)
        if states:
            flat_states = nest.map_structure(_merge_beam_dim, states)
            flat_logits, flat_states = symbols_to_logits_fn(flat_ids, i,
                                                            flat_states)
            states = nest.map_structure(
                lambda t: _unmerge_beam_dim(t, batch_size, beam_size),
                flat_states)
        elif use_tpu:
            flat_logits = symbols_to_logits_fn(flat_ids, i)
        else:
            flat_logits = symbols_to_logits_fn(flat_ids)

        logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])

        # Convert logits to normalized log probs
        candidate_log_probs = common_layers.log_prob_from_logits(logits)

        # Multiply the probabilities by the current probabilities of the beam.
        # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
        log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs,
                                                         axis=2)

        length_penalty = tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)

        curr_scores = log_probs / length_penalty
        # Flatten out (beam_size, vocab_size) probs in to a list of
        # possibilities
        flat_curr_scores = tf.reshape(curr_scores,
                                      [-1, beam_size * vocab_size])

        if use_tpu and use_top_k_with_unique:
            topk_scores, topk_ids = top_k_with_unique(
                flat_curr_scores, k=beam_size * 2)
        else:
            topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores,
                                                k=beam_size * 2)

        # Recovering the log probs because we will need to send them back
        topk_log_probs = topk_scores * length_penalty

        # Work out what beam the top probs are in.
        topk_beam_index = topk_ids // vocab_size
        topk_ids %= vocab_size  # Unflatten the ids

        if not use_tpu:
            # The next three steps are to create coordinates for tf.gather_nd
            # to pull out the correct sequences from id's that we need to
            # grow. We will also use the coordinates to gather the booleans of
            # the beam items that survived.
            batch_pos = compute_batch_indices(batch_size, beam_size * 2)

            # top beams will give us the actual coordinates to do the gather.
            # stacking will create a tensor of dimension batch * beam * 2,
            # where the last dimension contains the i,j gathering coordinates.
            topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)

            # Gather up the most probable 2*beams both for the ids and
            # finished_in_alive bools
            topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
            if states:
                states = nest.map_structure(
                    lambda state: tf.gather_nd(state, topk_coordinates),
                    states)

            # Append the most probable alive
            topk_seq = tf.concat(
                [topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
        else:
            # Gather up the most probable 2*beams both for the ids and
            # finished_in_alive bools
            topk_seq = fast_tpu_gather(alive_seq, topk_beam_index)

            if states:
                states = nest.map_structure(
                    lambda state: fast_tpu_gather(state, topk_beam_index),
                    states)

            # Update the most probable alive: the sequence buffer has a fixed
            # length here, so position i + 1 is overwritten in place (with
            # the time axis moved to the front for the update).
            topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
            topk_seq = inplace_ops.alias_inplace_update(topk_seq, i + 1,
                                                        topk_ids)
            topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])

        topk_finished = tf.equal(topk_ids, eos_id)

        return topk_seq, topk_log_probs, topk_scores, topk_finished, states

    def inner_loop(i, alive_seq, alive_log_probs, finished_seq,
                   finished_scores, finished_flags, states):
        """Inner beam search loop.

        There are three groups of tensors, alive, finished, and topk.
        The alive group contains information about the current alive
        sequences. The topk group contains information about alive + topk
        current decoded words. The finished group contains information about
        finished sentences, that is, the ones that have decoded to <EOS>.
        These are what we return.
        The general beam search algorithm is as follows:
        While we haven't terminated (pls look at termination condition)
          1. Grow the current alive to get beam*2 topk sequences
          2. Among the topk, keep the top beam_size ones that haven't reached
          EOS into alive
          3. Among the topk, keep the top beam_size ones have reached EOS into
          finished
        Repeat
        To make things simple with using fixed size tensors, we will end
        up inserting unfinished sequences into finished in the beginning. To
        stop that we add -ve INF to the score of the unfinished sequence so
        that when a true finished sequence does appear, it will have a higher
        score than all the unfinished ones.

        Args:
          i: loop index
          alive_seq: Topk sequences decoded so far
            [batch_size, beam_size, i+1]
          alive_log_probs: probabilities of the beams.
            [batch_size, beam_size]
          finished_seq: Current finished sequences.
            [batch_size, beam_size, i+1]
          finished_scores: scores for each of these sequences.
            [batch_size, beam_size]
          finished_flags: finished bools for each of these sequences.
            [batch_size, beam_size]
          states: dict (possibly nested) of decoding states.

        Returns:
          Tuple of
            (Incremented loop index
             New alive sequences,
             Log probs of the alive sequences,
             New finished sequences,
             Scores of the new finished sequences,
             Flags indicating which sequence in finished as reached EOS,
             dict of final decoding states)
        """

        # Each inner loop, we carry out three steps:
        # 1. Get the current topk items.
        # 2. Extract the ones that have finished and haven't finished
        # 3. Recompute the contents of finished based on scores.
        topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
            i, alive_seq, alive_log_probs, states)
        alive_seq, alive_log_probs, _, states = grow_alive(
            topk_seq, topk_scores, topk_log_probs, topk_finished, states)
        finished_seq, finished_scores, finished_flags, _ = grow_finished(
            finished_seq, finished_scores, finished_flags, topk_seq,
            topk_scores, topk_finished)

        return (i + 1, alive_seq, alive_log_probs, finished_seq,
                finished_scores, finished_flags, states)

    def _is_not_finished(i, unused_alive_seq, alive_log_probs,
                         unused_finished_seq, finished_scores,
                         unused_finished_in_finished, unused_states):
        """Checking termination condition.

        We terminate when we decoded up to decode_length or the lowest scoring
        item in finished has a greater score than the highest prob item in
        alive divided by the max length penalty

        Args:
          i: loop index
          alive_log_probs: probabilities of the beams.
            [batch_size, beam_size]
          finished_scores: scores for each of these sequences.
            [batch_size, beam_size]

        Returns:
          Bool.
        """
        max_length_penalty = tf.pow(((5. + tf.to_float(decode_length)) / 6.),
                                    alpha)
        # The best possible score of the most likely alive sequence.
        lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty

        if not stop_early:
            # by considering the min score (in the top N beams) we ensure that
            # the decoder will keep decoding until there is at least one beam
            # (in the top N) that can be improved (w.r.t. the alive beams).
            # any unfinished beam will have score -INF - thus the min
            # will always be -INF if there is at least one unfinished beam -
            # which means the bound_is_met condition cannot be true in this
            # case.
            lowest_score_of_finished_in_finished = tf.reduce_min(
                finished_scores)
        else:
            # by taking the max score we only care about the first beam;
            # as soon as this first beam cannot be beaten from the alive beams
            # the beam decoder can stop.
            # similarly to the above, if the top beam is not completed, its
            # finished_score is -INF, thus it will not activate the
            # bound_is_met condition. (i.e., decoder will keep going on).
            # note we need to find the max for every sequence separately - so,
            # we need to keep the batch dimension (see axis=1)
            lowest_score_of_finished_in_finished = tf.reduce_max(
                finished_scores, axis=1)

        bound_is_met = tf.reduce_all(
            tf.greater(lowest_score_of_finished_in_finished,
                       lower_bound_alive_scores))

        return tf.logical_and(
            tf.less(i, decode_length), tf.logical_not(bound_is_met))

    # Without TPU the sequence tensors grow by one position per step, so their
    # shape invariant is fully unknown; on TPU they are allocated at full
    # length up front and keep a fixed shape.
    inner_shape = tf.TensorShape([None, None, None])
    if use_tpu:
        inner_shape = tf.TensorShape(
            [batch_size, beam_size, decode_length + 1])
    if use_tpu:
        state_struc = nest.map_structure(lambda state: state.get_shape(),
                                         states)
    else:
        state_struc = nest.map_structure(get_state_shape_invariants, states)
    (_, alive_seq, alive_log_probs, finished_seq, finished_scores,
     finished_flags, states) = tf.while_loop(
         _is_not_finished,
         inner_loop, [
             tf.constant(0), alive_seq, alive_log_probs, finished_seq,
             finished_scores, finished_flags, states
         ],
         shape_invariants=[
             tf.TensorShape([]),
             inner_shape,
             alive_log_probs.get_shape(),
             inner_shape,
             finished_scores.get_shape(),
             finished_flags.get_shape(),
             state_struc
         ],
         parallel_iterations=1,
         back_prop=False)

    alive_seq.set_shape((None, beam_size, None))
    finished_seq.set_shape((None, beam_size, None))

    # Accounting for corner case: It's possible that no sequence in alive for
    # a particular batch item ever reached EOS. In that case, we should just
    # copy the contents of alive for that batch item. If
    # tf.reduce_any(finished_flags, 1) is 0, no sequence for that batch index
    # had reached EOS. We need to do the same for the scores as well.
    finished_seq = tf.where(
        tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
    finished_scores = tf.where(
        tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
    return finished_seq, finished_scores, states
def setup(act_fun):
    """Builds the model graph, a session, and optionally restores weights.

    A model class is chosen from FLAGS, placeholders are created and split
    across `FLAGS.num_gpus` towers, and for each tower a sampling
    `tf.while_loop` of `FLAGS.num_steps` steps is built starting from the
    noise placeholder. When `FLAGS.train` is set, a loss selected by
    `FLAGS.objective` and a training op are added as well.

    Args:
      act_fun: forwarded unchanged as the `act_fun` argument of the model
        constructor.

    Returns:
      Tuple `(target_vars, saver, sess, resume_itr)`: a dict of named graph
      tensors/ops, the `tf.train.Saver`, the `tf.Session`, and the iteration
      to resume from (0 unless a checkpoint was restored).
    """
    channel_num = 3

    if FLAGS.mnist_model:
        print("------------------Using MNIST model------------")
        model = MnistNet(
            num_channels=channel_num,
            num_filters=128,
            act_fun=act_fun)
    elif FLAGS.large_model:
        print("------------------Using ResNet32Large model------------")
        model = ResNet32Large(
            num_channels=channel_num,
            num_filters=128,
            train=True,
            act_fun=act_fun)
    elif FLAGS.larger_model:
        print("------------------Using ResNet32Larger model------------")
        model = ResNet32Larger(
            num_channels=channel_num,
            num_filters=128,
            act_fun=act_fun)
    elif FLAGS.wider_model:
        print("------------------Using ResNet32Wider model------------")
        model = ResNet32Wider(
            num_channels=channel_num,
            num_filters=192,
            act_fun=act_fun)
    else:
        print("------------------Using ResNet32 model------------")
        model = ResNet32(
            num_channels=channel_num,
            num_filters=128,
            act_fun=act_fun)

    batch_size = FLAGS.batch_size

    weights = [model.construct_weights('context_0')]
    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # This None is overwritten by the LABEL placeholder created just below.
    LABEL = None
    X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
    X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
    LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
    LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

    # Variables to run in training
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            # Score every input against each of the 10 one-hot labels and
            # sample a label per input from the resulting energies.
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(
                tf.convert_to_tensor(
                    np.reshape(
                        np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                        (FLAGS.batch_size * 10, 10)),
                    dtype=tf.float32),
                trainable=False,
                dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(
                    X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)),
                (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(
                x_split,
                weights[0],
                label=label_tensor,
                stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(
                energy_pos_full, axis=1, keepdims=True)
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) -
                                     energy_partition_est, axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(
                    X_SPLIT[j],
                    weights[0],
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        # Energy of the starting noise, exposed later as 'energy_start'.
        energy_negs.extend([model.forward(tf.stop_gradient(
            x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False,
            reuse=True)])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, x_mod):
            # One sampling step: add Gaussian noise, then either run `hmc`
            # or take a gradient step on the energy, and clip the result to
            # [0, FLAGS.rescale].
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat(
                [model.forward(
                    x_mod,
                    weights[0],
                    label=LABEL_SPLIT[j],
                    reuse=True,
                    stop_at_grad=False,
                    stop_batch=True)],
                axis=0)

            x_grad, label_grad = tf.gradients(
                FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(
                        x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j],
                                      reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
                                    stop_at_grad=False, reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(
                tf.stop_gradient(x_mod),
                weights[0],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        # NOTE(review): this uses the unsplit LABEL placeholder rather than
        # LABEL_SPLIT[j] - confirm this is intended when num_gpus > 1.
        loss_energy = model.forward(
            x_mod,
            weights[0],
            reuse=True,
            label=LABEL,
            stop_grad=True)

        print("Finished processing loop construction ...")

        # Re-created on every tower, so the returned dict reflects the last
        # tower built.
        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(label_prob *
                                       tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

            # NOTE(review): an objective other than the three handled below
            # leaves `loss_ml` unassigned, so the `tf.reduce_mean(loss_ml)`
            # that follows would fail.
            if FLAGS.objective == 'logsumexp':
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            # L2 penalty on the magnitudes of both energies.
            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) +
                                  tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            # Drop pairs whose gradient (first element) is None.
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=30,
                                    keep_checkpoint_every_n_hours=6)

    # Count trainable parameters for the log line below.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    # Only rank 0 restores from disk; the broadcast below then runs on every
    # rank with root rank 0.
    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    return target_vars, saver, sess, resume_itr
def _spsa_gradients(loss_fn, x, delta=0.01, num_samples=16, num_iterations=4): """Compute gradient estimates using SPSA. Args: loss_fn: Callable that takes a single argument of shape [batch_size, ...] and returns the loss contribution of each element of the batch as a tensor of shape [batch_size]. x: List of tensors with a single element. We only support computation of the gradient of the loss with respect to x[0]. We take a list as input to keep the same API call as tf.gradients. delta: The gradients are computed by computing the loss within x - delta and x + delta. num_samples: The total number of random samples used to compute the gradient is `num_samples` times `num_iterations`. `num_samples` contributes to the gradient by tiling `x` `num_samples` times. num_iterations: The total number of random samples used to compute the gradient is `num_samples` times `num_iterations`. `num_iterations` contributes to the gradient by iterating using a `tf.while_loop`. Returns: List of tensors with a single element corresponding to the gradient of loss_fn(x[0]) with respect to x[0]. """ if len(x) != 1: raise NotImplementedError('SPSA gradients with respect to multiple ' 'variables is not supported.') # loss_fn takes a single argument. tensor = x[0] def _get_delta(x): return delta * tf.sign( tf.random_uniform( tf.shape(x), minval=-1., maxval=1., dtype=x.dtype)) # Process batch_size samples at a time. def cond(i, *_): return tf.less(i, num_iterations) def loop_body(i, total_grad): """Compute gradient estimate.""" batch_size = tf.shape(tensor)[0] # The tiled tensor has shape [num_samples, batch_size, ...] tiled_tensor = tf.expand_dims(tensor, axis=0) tiled_tensor = tf.tile(tiled_tensor, [num_samples] + [1] * len(tensor.shape)) # The tiled tensor has now shape [2, num_samples, batch_size, ...]. delta = _get_delta(tiled_tensor) tiled_tensor = tf.stack([tiled_tensor + delta, tiled_tensor - delta], axis=0) # Compute loss with shape [2, num_samples, batch_size]. 
losses = loss_fn( tf.reshape(tiled_tensor, [2 * num_samples, batch_size] + tensor.shape.as_list()[1:])) losses = tf.reshape(losses, [2, num_samples, batch_size]) # Compute approximate gradient using broadcasting. shape = losses.shape.as_list() + [1] * (len(tensor.shape) - 1) shape = [(s or -1) for s in shape] # Remove None. losses = tf.reshape(losses, shape) g = tf.reduce_mean((losses[0] - losses[1]) / (2. * delta), axis=0) return [i + 1, g / num_iterations + total_grad] _, g = tf.while_loop(cond, loop_body, loop_vars=[tf.constant(0.), tf.zeros_like(tensor)], parallel_iterations=1, back_prop=False) return [g]
def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator.

    Builds the word-prediction graph and, outside of training, greedily
    decodes tokens into the query slots of `input_ids`.

    NOTE(review): `bert_config`, `use_one_hot_embeddings`, `init_checkpoint`,
    `use_tpu`, `learning_rate`, `num_train_steps` and `num_warmup_steps` are
    not defined in this function; presumably they come from an enclosing
    builder or module scope that is not visible here -- confirm.

    Args:
      features: Dict of tensors. Reads "unique_ids", "input_ids" and
        "segment_ids"; in PREDICT mode also "start_positions",
        "end_positions" and "answer_types".
      labels: Unused.
      mode: A `tf.estimator.ModeKeys` value. Only TRAIN and PREDICT produce
        an output spec.
      params: Dict; "batch_size" is read from it.

    Returns:
      A `tf.estimator.tpu.TPUEstimatorSpec` with loss and train op (TRAIN) or
      with the predictions dict (PREDICT).

    Raises:
      ValueError: If `mode` is neither TRAIN nor PREDICT. The check happens
        at the end, after the graph has already been built.
    """
    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    segment_ids = features["segment_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # Axis 1 of `input_ids` is taken as the sequence length.
    seq_length = modeling.get_shape_list(input_ids)[1]
    query_length = FLAGS.max_query_length
    batch_size = params["batch_size"]

    # Only the second element of the returned pair is used; it is fed to the
    # model as its input mask.
    _, attention_mask = make_attention_mask(batch_size, query_length, seq_length)

    with tf.variable_scope("bert") as scope:
        word_logits = create_model(
            bert_config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=attention_mask,
            segment_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings,
            scope=scope)

    if not is_training:
        # Re-enter the same scope with reuse=True so the decoding steps below
        # share the variables created by the first `create_model` call.
        with tf.variable_scope("bert", reuse=True) as scope:
            output_ids = input_ids
            word_id = tf.argmax(word_logits, axis=2, output_type=tf.int32)
            # This operation implements: output_ids[:, 2] = word_id[:, 0]
            # The padding shifts the predictions right by two positions and
            # the one-hot mask keeps a single column. Because the value is
            # added rather than assigned, this presumably relies on that slot
            # of `input_ids` being 0 -- TODO confirm.
            word_id = tf.pad(word_id, [[0, 0], [2, seq_length - query_length]])
            output_ids = input_ids + word_id * tf.one_hot(
                2, seq_length, dtype=tf.int32)

            def body(i, ids):
                """A decoding step."""
                # Re-run the model on the ids decoded so far.
                word_logits = create_model(
                    bert_config=bert_config,
                    is_training=is_training,
                    input_ids=ids,
                    input_mask=attention_mask,
                    segment_ids=segment_ids,
                    use_one_hot_embeddings=use_one_hot_embeddings,
                    scope=scope)

                word_id = tf.argmax(word_logits, axis=2, output_type=tf.int32)
                # This operation implements: output_ids[:, 1 + i] = word_id[:, i - 1]
                word_id = tf.pad(word_id, [[0, 0], [2, seq_length - query_length]])
                return [
                    i + 1,
                    ids + word_id * tf.one_hot(i + 1, seq_length, dtype=tf.int32)
                ]

            # Position 2 was filled above; the loop fills positions 3 up to
            # and including query_length - 1.
            i0 = tf.constant(2)
            c = lambda i, _: i < query_length - 1
            _, output_ids = tf.while_loop(c, body, loop_vars=[i0, output_ids])

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
        (assignment_map, initialized_variable_names
        ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        if use_tpu:

            # On TPU the checkpoint initialization is deferred into the
            # scaffold function handed to the estimator spec.
            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
        tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                        init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Computes the loss for word prediction.
        # Labels are the input tokens at positions 2 .. query_length - 1.
        loss = tf.losses.sparse_softmax_cross_entropy(
            input_ids[:, 2:query_length],
            word_logits,
            reduction=tf.losses.Reduction.MEAN)

        train_op = optimization.create_optimizer(loss, learning_rate,
                                                 num_train_steps,
                                                 num_warmup_steps, use_tpu)

        output_spec = tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn)
    elif mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "unique_ids": tf.identity(unique_ids),
            "input_ids": output_ids,
            # Segment ids are capped at 1.
            "segment_ids": tf.minimum(segment_ids, 1),
            # The mask is 1 wherever the decoded id is non-zero.
            "input_mask": tf.to_int32(tf.not_equal(output_ids, 0)),
            "start_positions": tf.identity(features["start_positions"]),
            "end_positions": tf.identity(features["end_positions"]),
            "answer_types": tf.identity(features["answer_types"])
        }
        output_spec = tf.estimator.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    else:
        raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                         (mode))

    return output_spec
def sample_sequence(*, hparams, length, start_token=None, batch_size=None,
                    context=None, temperature=1, top_k=0):
    """Sample `length` tokens from the language model, one at a time.

    Exactly one of `start_token` and `context` must be supplied. The prompt is
    run through the model except for its last token; the sampling loop then
    feeds one token per iteration, reusing the cached `past` state.

    Args:
      hparams: Model hyperparameters, forwarded to `model.model` and
        `model.past_shape`; `hparams.n_vocab` bounds the logits that are kept.
      length: Number of loop iterations, i.e. number of tokens sampled.
      start_token: Optional token id. When given, the prompt is a
        `[batch_size, 1]` tensor filled with it.
      batch_size: Batch size used for the start-token prompt and for the
        static shapes of the loop state.
      context: Optional prompt token tensor, indexed as `[batch, position]`.
      temperature: Divisor applied to the logits before sampling.
      top_k: Forwarded to `top_k_logits` as `k`.

    Returns:
      The prompt tokens with the sampled tokens appended along axis 1.

    Raises:
      AssertionError: If both or neither of `start_token` and `context` are
        given.
    """
    if start_token is not None:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)
    else:
        assert context is not None, 'Specify exactly one of start_token and context!'

    def run_model(tokens, past=None):
        """Return (logits trimmed to the vocabulary, new attention state)."""
        lm_output = model.model(
            hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
        vocab_logits = lm_output['logits'][:, :, :hparams.n_vocab]
        new_state = lm_output['present']
        new_state.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return vocab_logits, new_state

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        _, prompt_state = run_model(context[:, :-1])

        def sample_next(past, prev, output):
            """Feed the previous token, sample the next, extend the state."""
            step_logits, step_state = run_model(prev[:, tf.newaxis], past=past)
            scaled = step_logits[:, -1, :] / tf.to_float(temperature)
            filtered = top_k_logits(scaled, k=top_k)
            drawn = tf.multinomial(
                filtered, num_samples=1, output_dtype=tf.int32)
            grown_past = tf.concat([past, step_state], axis=-2)
            latest = tf.squeeze(drawn, axis=[1])
            grown_output = tf.concat([output, drawn], axis=1)
            return [grown_past, latest, grown_output]

        # The loop is bounded solely by `maximum_iterations`.
        always = lambda *_: True

        initial_state = [prompt_state, context[:, -1], context]
        state_shapes = [
            tf.TensorShape(
                model.past_shape(hparams=hparams, batch_size=batch_size)),
            tf.TensorShape([batch_size]),
            tf.TensorShape([batch_size, None]),
        ]
        _, _, tokens = tf.while_loop(
            cond=always,
            body=sample_next,
            maximum_iterations=length,
            loop_vars=initial_state,
            shape_invariants=state_shapes,
            back_prop=False,
        )

        return tokens