def solve(self, *, arguments, x_init, fn_x=None, **kwargs):
    """
    Iteratively solves an equation/optimization for $x$ involving an expression $f(x)$.

    Calls `self.start` once, then runs `self.step` inside a `tf.while_loop` whose condition
    is `self.next_step` and whose iteration count is capped by `self.max_iterations`, and
    finally hands the loop result to `self.end`.

    Args:
        arguments: Forwarded unchanged to `self.start`.
        x_init: Initial solution guess $x_0$.
        fn_x: A callable returning an expression $f(x)$ given $x$; stored as `self.fn_x`.
        **kwargs: Additional solver-specific arguments, forwarded to `self.start`.

    Returns:
        A solution $x$ to the problem as given by the solver (the result of `self.end`).
    """
    self.fn_x = fn_x
    step_signature = self.input_signature(function='step')

    # Initialization step
    initial_values = self.start(arguments=arguments, x_init=x_init, **kwargs)

    # Iteration loop with termination condition
    iteration_limit = self.max_iterations.value()
    # tf.while_loop works on positional loop variables, so convert kwargs -> args and back
    loop_vars = tuple(
        step_signature.kwargs_to_args(kwargs=initial_values, is_outer_args=True)
    )
    final_args = tf.while_loop(
        cond=self.next_step,
        body=self.step,
        loop_vars=loop_vars,
        maximum_iterations=tf_util.int32(x=iteration_limit),
    )
    final_values = step_signature.args_to_kwargs(args=final_args)

    return self.end(**final_values.to_kwargs())
def apply(self, *, x):
    """
    Looks up the rows of `self.weights` indexed by `x` and passes them to the parent `apply`.

    Args:
        x: Indices to look up; cast to int32 before the lookup.

    Returns:
        The result of the parent class `apply` on the looked-up embeddings.
    """
    indices = tf_util.int32(x=x)
    embedded = tf.nn.embedding_lookup(
        params=self.weights, ids=indices, max_norm=self.max_norm
    )
    return super().apply(x=embedded)
def step(self, *, arguments, variables, fn_loss, **kwargs):
    """
    Performs one update based on randomly sampled perturbations of the variables.

    For each of `self.num_samples` samples, every variable is moved to "original value plus
    a normal perturbation scaled by the learning rate", the loss is re-evaluated, and the
    perturbation is accumulated with the sign of the loss improvement. The accumulated
    deltas are then averaged and applied to the variables.

    Args:
        arguments: Object whose `to_kwargs()` yields the keyword arguments for `fn_loss`.
        variables: Variables to perturb and update in place.
        fn_loss: Callable returning the loss for the current variable values.
        **kwargs: Unused here.

    Returns:
        List of the averaged deltas that were applied, one per variable.
    """
    learning_rate = self.learning_rate.value()
    unperturbed_loss = fn_loss(**arguments.to_kwargs())

    initial_deltas = [tf.zeros_like(input=variable) for variable in variables]
    initial_perturbations = [tf.zeros_like(input=variable) for variable in variables]

    def sample(accumulated, last_perturbations):
        with tf.control_dependencies(control_inputs=accumulated):
            perturbations = [
                learning_rate * tf.random.normal(
                    shape=tf_util.shape(x=variable),
                    dtype=tf_util.get_dtype(type='float'),
                )
                for variable in variables
            ]
            # A single assign_add both removes the previous perturbation and adds the new one
            swaps = [
                current - previous
                for current, previous in zip(perturbations, last_perturbations)
            ]
            assignments = [
                variable.assign_add(delta=swap, read_value=False)
                for variable, swap in zip(variables, swaps)
            ]

        with tf.control_dependencies(control_inputs=assignments):
            perturbed_loss = fn_loss(**arguments.to_kwargs())
            # +1 if the perturbation lowered the loss, -1 if it raised it, 0 if unchanged
            direction = tf.math.sign(x=(unperturbed_loss - perturbed_loss))
            accumulated = [
                total + direction * perturbation
                for total, perturbation in zip(accumulated, perturbations)
            ]

        return accumulated, perturbations

    num_samples = self.num_samples.value()
    summed_deltas, last_perturbations = tf.while_loop(
        cond=tf_util.always_true,
        body=sample,
        loop_vars=(initial_deltas, initial_perturbations),
        maximum_iterations=tf_util.int32(x=num_samples),
    )

    with tf.control_dependencies(control_inputs=summed_deltas):
        sample_count = tf_util.cast(x=num_samples, dtype='float')
        mean_deltas = [total / sample_count for total in summed_deltas]
        # Variables still carry the last perturbation: replace it by the averaged delta
        corrections = [
            mean - perturbation
            for mean, perturbation in zip(mean_deltas, last_perturbations)
        ]
        assignments = [
            variable.assign_add(delta=correction, read_value=False)
            for variable, correction in zip(variables, corrections)
        ]

    with tf.control_dependencies(control_inputs=assignments):
        # Trivial operation to enforce control dependency
        return [tf_util.identity(input=mean) for mean in mean_deltas]
def apply(self, *, x):
    """
    Applies a transposed 2D convolution with `self.weights` and passes the result on to the
    parent `apply`.

    Args:
        x: Input tensor; its leading dimension is reused as the leading dimension of the
            requested output shape.

    Returns:
        The result of the parent class `apply` on the transposed-convolution output.
    """
    # Requested output shape = dynamic leading dimension of x + configured self.output_shape
    leading_dim = tf_util.cast(x=tf.shape(input=x)[:1], dtype='int')
    trailing_dims = tf_util.constant(value=self.output_shape, dtype='int')
    full_output_shape = tf.concat(values=[leading_dim, trailing_dims], axis=0)

    transposed = tf.nn.conv2d_transpose(
        input=x,
        filters=self.weights,
        output_shape=tf_util.int32(x=full_output_shape),
        strides=self.stride,
        padding=self.padding.upper(),
        dilations=self.dilation,
    )
    return super().apply(x=transposed)
def step(self, *, arguments, variables, **kwargs):
    """
    Runs the wrapped optimizer's `step` `self.num_steps` times and sums up the deltas.

    Args:
        arguments: Forwarded to `self.optimizer.step`.
        variables: Variables to optimize; forwarded to `self.optimizer.step`.
        **kwargs: Additional arguments forwarded to `self.optimizer.step`.

    Returns:
        The accumulated deltas over all steps, one per variable.
    """
    initial_deltas = [tf.zeros_like(input=variable) for variable in variables]

    def accumulate(*running_deltas):
        # Each inner step only starts once the previous accumulated deltas are available
        with tf.control_dependencies(control_inputs=running_deltas):
            step_deltas = self.optimizer.step(
                arguments=arguments, variables=variables, **kwargs
            )
            return [
                running + current
                for running, current in zip(running_deltas, step_deltas)
            ]

    num_steps = self.num_steps.value()
    return tf.while_loop(
        cond=tf_util.always_true,
        body=accumulate,
        loop_vars=initial_deltas,
        maximum_iterations=tf_util.int32(x=num_steps),
    )
def apply(self, *, x, horizons, internals):
    """
    Applies the layer over a temporal horizon, either cumulatively or iteratively.

    Depending on `self.temporal_processing`:
      - 'cumulative': the inputs of each horizon are collected along a new axis 1 and
        processed at once by `self.cumulative_apply`.
      - 'iterative': `self.iterative_body` is run step by step inside a `tf.while_loop`,
        threading the internals through.

    Args:
        x: Input tensor from which the per-step inputs are gathered.
        horizons: Tensor indexed as `horizons[:, 0]` (used as start indices) and
            `horizons[:, 1]` (used as lengths).
        internals: Internal states; only used for 'iterative' processing.

    Returns:
        For 'cumulative': the output tensor. For 'iterative': a tuple of the output tensor
        and the updated internals. Any other `temporal_processing` value yields None.
    """
    zero = tf_util.constant(value=0, dtype='int')
    one = tf_util.constant(value=1, dtype='int')
    batch_size = tf_util.cast(x=tf.shape(input=horizons)[0], dtype='int')
    zeros = tf_util.zeros(shape=(batch_size,), dtype='int')
    ones = tf_util.ones(shape=(batch_size,), dtype='int')

    # including 0th step
    horizon = self.horizon.value() + one
    # in case of longer horizon than necessary (e.g. main vs baseline policy)
    starts = horizons[:, 0] + tf.maximum(x=(horizons[:, 1] - horizon), y=zeros)
    lengths = horizons[:, 1] - tf.maximum(x=(horizons[:, 1] - horizon), y=zeros)
    horizon = tf.minimum(x=horizon, y=tf.math.reduce_max(input_tensor=lengths, axis=0))
    output_spec = self.output_spec()

    # Only the while-loop branches produce these; they stay None on the constant
    # zero-horizon shortcuts so that the consistency assertions below can be skipped
    # instead of failing on unbound names.
    final_indices = None
    final_remaining = None

    if self.temporal_processing == 'cumulative':
        if self.horizon.is_constant(value=0):
            # NOTE(review): this passes cumulative-style arguments (xs, lengths) to
            # iterative_apply, whereas the iterative branch calls it with
            # (x, internals) -- possibly cumulative_apply was intended; confirm.
            x = self.iterative_apply(xs=x, lengths=ones)

        else:
            def body(x, indices, remaining, xs):
                # Append the current step's input along the time axis (axis 1)
                current_x = tf.gather(params=x, indices=indices)
                current_x = tf.expand_dims(input=current_x, axis=1)
                xs = tf.concat(values=(xs, current_x), axis=1)
                # Count down remaining steps, and only advance the index while, after
                # the decrement, steps are still remaining
                remaining -= tf.where(
                    condition=tf.math.equal(x=remaining, y=zeros), x=zeros, y=ones
                )
                indices += tf.where(
                    condition=tf.math.equal(x=remaining, y=zeros), x=zeros, y=ones
                )
                return x, indices, remaining, xs

            initial_xs = tf_util.zeros(
                shape=((batch_size, 0) + output_spec.shape), dtype=output_spec.type
            )

            # int32 iteration limit, consistent with the iterative loop below
            _, final_indices, final_remaining, xs = tf.while_loop(
                cond=tf_util.always_true,
                body=body,
                loop_vars=(x, starts, lengths, initial_xs),
                maximum_iterations=tf_util.int32(x=horizon),
            )

            x = self.cumulative_apply(xs=xs, lengths=lengths)

    elif self.temporal_processing == 'iterative':
        if self.horizon.is_constant(value=0):
            # Return the internals produced by this step, not the stale input internals
            x, internals = self.iterative_apply(x=x, internals=internals)

        else:
            initial_x = tf_util.zeros(
                shape=((batch_size,) + output_spec.shape), dtype=output_spec.type
            )

            # tf.while_loop needs positional loop variables: kwargs -> args and back
            signature = self.input_signature(function='iterative_body')
            internals = signature['current_internals'].kwargs_to_args(kwargs=internals)

            _, final_indices, final_remaining, x, final_internals = tf.while_loop(
                cond=tf_util.always_true,
                body=self.iterative_body,
                loop_vars=(x, starts, lengths, initial_x, internals),
                maximum_iterations=tf_util.int32(x=horizon),
            )

            internals = signature['current_internals'].args_to_kwargs(
                args=final_internals
            )

    assertions = list()
    if self.config.create_tf_assertions and final_indices is not None:
        assertions.append(
            tf.debugging.assert_equal(
                x=final_indices, y=(tf.math.cumsum(x=lengths) - ones)
            )
        )
        assertions.append(
            tf.debugging.assert_equal(
                x=tf.math.reduce_sum(input_tensor=final_remaining), y=zero
            )
        )

    with tf.control_dependencies(control_inputs=assertions):
        if self.temporal_processing == 'cumulative':
            return tf_util.identity(input=super().apply(x=x))
        elif self.temporal_processing == 'iterative':
            return tf_util.identity(input=super().apply(x=x)), internals
def successors(self, *, indices, horizon, sequence_values, final_values):
    """
    Collects, for each given buffer index, the indices of up to `horizon` successors
    (stopping after a terminal entry) and gathers the requested buffer values for them.

    Args:
        indices: Start indices into the buffers.
        horizon: Maximum number of successor steps to follow.
        sequence_values: Tuple of buffer names to gather for all successor indices.
        final_values: Tuple of buffer names to gather for the last index of each sequence.

    Returns:
        Depending on which of `sequence_values` / `final_values` are empty:
          - both empty: `lengths`
          - only `sequence_values` empty: `(lengths, final_values)`
          - only `final_values` empty: `(starts_and_lengths, sequence_values)`
          - neither empty: `(starts_and_lengths, sequence_values, final_values)`
        where `starts_and_lengths` stacks the sequence starts and lengths along axis 1.
    """
    assert isinstance(sequence_values, tuple)
    assert isinstance(final_values, tuple)

    zero = tf_util.constant(value=0, dtype='int')
    one = tf_util.constant(value=1, dtype='int')
    capacity = tf_util.constant(value=self.capacity, dtype='int')

    def extend(lengths, successor_indices, mask):
        last_index = successor_indices[:, -1:]
        last_terminal = tf.gather(params=self.buffers['terminal'], indices=last_index)
        # A successor is valid only if the current entry is non-terminal and still unmasked
        non_terminal = tf.math.logical_not(x=tf.math.greater(x=last_terminal, y=zero))
        is_not_terminal = tf.math.logical_and(x=non_terminal, y=mask[:, -1:])
        # Successor index wraps around at the buffer capacity
        next_index = tf.math.mod(x=(last_index + one), y=capacity)
        successor_indices = tf.concat(values=(successor_indices, next_index), axis=1)
        mask = tf.concat(values=(mask, is_not_terminal), axis=1)
        is_not_terminal = tf.squeeze(input=is_not_terminal, axis=1)
        int_dtype_zeros = tf.zeros_like(
            input=is_not_terminal, dtype=tf_util.get_dtype(type='int')
        )
        int_dtype_ones = tf.ones_like(
            input=is_not_terminal, dtype=tf_util.get_dtype(type='int')
        )
        increment = tf.where(
            condition=is_not_terminal, x=int_dtype_ones, y=int_dtype_zeros
        )
        lengths = lengths + increment
        return lengths, successor_indices, mask

    lengths = tf.ones_like(input=indices, dtype=tf_util.get_dtype(type='int'))
    successor_indices = tf.expand_dims(input=indices, axis=1)
    mask = tf.ones_like(input=successor_indices, dtype=tf_util.get_dtype(type='bool'))
    # The second axis grows by one per iteration, hence the relaxed shape invariant
    growing_shape = tf.TensorShape(dims=((None, None)))

    lengths, successor_indices, mask = tf.while_loop(
        cond=tf_util.always_true,
        body=extend,
        loop_vars=(lengths, successor_indices, mask),
        shape_invariants=(lengths.get_shape(), growing_shape, growing_shape),
        maximum_iterations=tf_util.int32(x=horizon),
    )

    # Flatten and keep only the valid successor indices
    successor_indices = tf.reshape(tensor=successor_indices, shape=(-1,))
    mask = tf.reshape(tensor=mask, shape=(-1,))
    successor_indices = tf.boolean_mask(tensor=successor_indices, mask=mask, axis=0)

    assertions = list()
    if self.config.create_tf_assertions:
        distance = tf.math.mod(
            x=(self.buffer_index - one - successor_indices), y=capacity
        )
        assertions.append(
            tf.debugging.assert_greater_equal(
                x=distance, y=zero, message="Successor check."
            )
        )

    with tf.control_dependencies(control_inputs=assertions):
        def gather_successors(buffer):
            return tf.gather(params=buffer, indices=successor_indices)

        values = self.buffers[sequence_values].fmap(
            function=gather_successors, cls=TensorDict
        )
        sequence_values = tuple(values[name] for name in sequence_values)

        starts = tf.math.cumsum(x=lengths, exclusive=True)
        ends = tf.math.cumsum(x=lengths) - one
        final_indices = tf.gather(params=successor_indices, indices=ends)

        def gather_finals(buffer):
            return tf.gather(params=buffer, indices=final_indices)

        values = self.buffers[final_values].fmap(function=gather_finals, cls=TensorDict)
        final_values = tuple(values[name] for name in final_values)

    if not sequence_values:
        if not final_values:
            return lengths
        return lengths, final_values

    starts_and_lengths = tf.stack(values=(starts, lengths), axis=1)
    if not final_values:
        return starts_and_lengths, sequence_values
    return starts_and_lengths, sequence_values, final_values
def step(self, *, arguments, variables, fn_loss, **kwargs):
    """
    Performs one update based on randomly sampled perturbations of the variables.

    With a constant single sample, the variables are perturbed once and the perturbation is
    kept if it lowered the loss, otherwise it is flipped. With multiple samples, each
    perturbation is accumulated with sign +1 (loss lowered) or -1 (otherwise), and the
    average of the accumulated deltas is applied to the variables.

    Variables whose dtype differs from the default float dtype are supported by casting the
    scalar factors (learning rate, direction, sample count) to the variable dtype.

    Args:
        arguments: Object whose `to_kwargs()` yields the keyword arguments for `fn_loss`.
        variables: Variables to perturb and update in place.
        fn_loss: Callable returning the loss for the current variable values.
        **kwargs: Unused here.

    Returns:
        List of the deltas that were applied, one per variable.
    """
    learning_rate = self.learning_rate.value()
    unperturbed_loss = fn_loss(**arguments.to_kwargs())

    if self.num_samples.is_constant(value=1):
        deltas = list()
        for variable in variables:
            delta = tf.random.normal(shape=variable.shape, dtype=variable.dtype)
            if variable.dtype == tf_util.get_dtype(type='float'):
                deltas.append(learning_rate * delta)
            else:
                deltas.append(tf.cast(x=learning_rate, dtype=variable.dtype) * delta)

        assignments = list()
        for variable, delta in zip(variables, deltas):
            assignments.append(variable.assign_add(delta=delta, read_value=False))

        with tf.control_dependencies(control_inputs=assignments):
            perturbed_loss = fn_loss(**arguments.to_kwargs())

            def negate_deltas():
                # Adding -2 * delta turns the already applied +delta into -delta
                neg_two_float = tf_util.constant(value=-2.0, dtype='float')
                assignments = list()
                for variable, delta in zip(variables, deltas):
                    if variable.dtype == tf_util.get_dtype(type='float'):
                        assignments.append(
                            variable.assign_add(
                                delta=(neg_two_float * delta), read_value=False
                            )
                        )
                    else:
                        neg_two = tf.constant(value=-2.0, dtype=variable.dtype)
                        assignments.append(
                            variable.assign_add(
                                delta=(neg_two * delta), read_value=False
                            )
                        )

                with tf.control_dependencies(control_inputs=assignments):
                    return [tf.math.negative(x=delta) for delta in deltas]

            return tf.cond(
                pred=(perturbed_loss < unperturbed_loss),
                true_fn=(lambda: deltas),
                false_fn=negate_deltas,
            )

    else:
        deltas = [tf.zeros_like(input=variable) for variable in variables]
        previous_perturbations = [
            tf.zeros_like(input=variable) for variable in variables
        ]

        def body(deltas, previous_perturbations):
            with tf.control_dependencies(control_inputs=deltas):
                perturbations = list()
                for variable in variables:
                    perturbation = tf.random.normal(
                        shape=variable.shape, dtype=variable.dtype
                    )
                    if variable.dtype == tf_util.get_dtype(type='float'):
                        perturbations.append(learning_rate * perturbation)
                    else:
                        perturbations.append(
                            tf.cast(x=learning_rate, dtype=variable.dtype) * perturbation
                        )

                # A single assign_add both removes the previous perturbation and adds
                # the new one
                perturbation_deltas = [
                    pert - prev_pert
                    for pert, prev_pert in zip(perturbations, previous_perturbations)
                ]
                assignments = list()
                for variable, delta in zip(variables, perturbation_deltas):
                    assignments.append(
                        variable.assign_add(delta=delta, read_value=False)
                    )

            with tf.control_dependencies(control_inputs=assignments):
                perturbed_loss = fn_loss(**arguments.to_kwargs())

                one_float = tf_util.constant(value=1.0, dtype='float')
                neg_one_float = tf_util.constant(value=-1.0, dtype='float')
                direction = tf.where(
                    condition=(perturbed_loss < unperturbed_loss),
                    x=one_float,
                    y=neg_one_float,
                )

                next_deltas = list()
                for variable, delta, perturbation in zip(
                    variables, deltas, perturbations
                ):
                    if variable.dtype == tf_util.get_dtype(type='float'):
                        next_deltas.append(delta + direction * perturbation)
                    else:
                        next_deltas.append(
                            delta
                            + tf.cast(x=direction, dtype=variable.dtype) * perturbation
                        )

            return next_deltas, perturbations

        num_samples = self.num_samples.value()
        deltas, perturbations = tf.while_loop(
            cond=tf_util.always_true,
            body=body,
            loop_vars=(deltas, previous_perturbations),
            maximum_iterations=tf_util.int32(x=num_samples),
        )

        with tf.control_dependencies(control_inputs=deltas):
            num_samples = tf_util.cast(x=num_samples, dtype='float')
            # Average the accumulated deltas. The divisor is cast to the delta dtype
            # where it differs from the default float dtype, matching the dtype handling
            # used for the learning rate and direction above; dividing by the
            # default-float tensor directly would be a dtype mismatch.
            averaged_deltas = list()
            for delta in deltas:
                if delta.dtype == tf_util.get_dtype(type='float'):
                    averaged_deltas.append(delta / num_samples)
                else:
                    averaged_deltas.append(
                        delta / tf.cast(x=num_samples, dtype=delta.dtype)
                    )
            deltas = averaged_deltas

            # Variables still carry the last perturbation: replace it by the averaged delta
            perturbation_deltas = [
                delta - pert for delta, pert in zip(deltas, perturbations)
            ]
            assignments = list()
            for variable, delta in zip(variables, perturbation_deltas):
                assignments.append(variable.assign_add(delta=delta, read_value=False))

        with tf.control_dependencies(control_inputs=assignments):
            # Trivial operation to enforce control dependency
            return [tf_util.identity(input=delta) for delta in deltas]