def main(_):
    """Build the autoencoder training graph and step it until the session stops.

    A scalar `batch_size` placeholder feeds `load_mnist`; the resulting batch
    is pushed through an `AutoEncoder`, `minimize_loss` supplies the train op,
    and a `MonitoredTrainingSession` is stepped in a loop that logs the batch
    loss and batch error after every step.

    Args:
      _: unused.
    """
    batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')
    cached_reader, training_batch = load_mnist(batch_size)

    # Autoencoder whose layers are registered with the layer collection.
    training_model = AutoEncoder(_ENCODER_LAYER_SIZE,
                                 weight_decay=FLAGS.weight_decay)
    layer_collection = kfac.LayerCollection()
    registered_forward = training_model.build(layer_collection)
    batch_loss, batch_error = registered_forward((training_batch, ))

    def loss_on_prev_batch(prev_batch):
        # Rebuild the model without a layer collection; keep only the loss.
        return training_model.build()(prev_batch)[0]

    # Op that minimizes the loss.
    train_op = minimize_loss(batch_size,
                             batch_loss,
                             layer_collection,
                             loss_fn=loss_on_prev_batch,
                             cached_reader=cached_reader)

    # Training loop.
    global_step = tf.train.get_or_create_global_step()
    fetches = [train_op, batch_loss, batch_error]
    with tf.train.MonitoredTrainingSession(save_checkpoint_secs=30) as sess:
        while not sess.should_stop():
            step = sess.run(global_step)

            # Run one training step. Original note: this updates the
            # covariance matrices and, every damping adaptation interval, the
            # damping parameter; `cached_reader.cached_batch` is passed to
            # `opt.KfacOptimizer`.
            results = sess.run(fetches, feed_dict={batch_size: _BATCH_SIZE})
            _, batch_loss_, batch_error_ = results

            # Report training statistics for this step.
            tf.logging.info('%d steps/batch_loss = %f, batch_error = %f',
                            step, batch_loss_, batch_error_)
def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False):
    """Train a ConvNet on MNIST on a single machine.

    Args:
      data_dir: string. Directory to read MNIST examples from.
      num_epochs: int. Number of passes to make over the training set.
      use_fake_data: bool. If True, generate a synthetic dataset.

    Returns:
      Whatever `minimize_loss_single_machine` returns (per the original
      documentation: accuracy of the model on the final minibatch of
      training data).
    """
    # Input pipeline: fixed batches of 128 unflattened images.
    tf.logging.info("Loading MNIST into memory.")
    dataset = mnist.load_mnist(data_dir,
                               num_epochs=num_epochs,
                               batch_size=128,
                               use_fake_data=use_fake_data,
                               flatten_images=False)
    examples, labels = dataset

    # Model; its layers are recorded in the layer collection.
    layer_collection = kfac.LayerCollection()
    model_outputs = build_model(examples,
                                labels,
                                num_labels=10,
                                layer_collection=layer_collection)
    loss, accuracy = model_outputs

    # Optimize.
    return minimize_loss_single_machine(loss, accuracy, layer_collection)
def __init__(self):
    """Create a layer collection and a patcher for `tf.layers.Layer.apply`.

    The patcher is only constructed here (stored in `self._patch`); it is not
    started. The replacement `apply` calls the layer as usual and then
    registers Conv2D, Dense and BatchNormalization layers with
    `self._layer_collection`. The output of the most recently applied Dense
    layer is kept in `self.logit`.
    """
    import unittest.mock as mock

    self._layer_collection = kfac.LayerCollection()

    def _registration_params(layer):
        # A tuple when the layer has several trainable variables, otherwise
        # the single variable itself.
        trainable = layer.trainable_variables
        if len(trainable) > 1:
            return tuple(trainable)
        return trainable[0]

    def custom_apply(layer, inputs, *args, **kwargs):
        outs = layer.__call__(inputs, *args, **kwargs)

        if isinstance(layer, tf.layers.Conv2D):
            params = _registration_params(layer)
            strides = [1] + list(layer.strides) + [1]
            self._layer_collection.register_conv2d(
                params, strides, layer.padding.upper(), inputs, outs)
            return outs

        if isinstance(layer, tf.layers.Dense):
            params = _registration_params(layer)
            self._layer_collection.register_fully_connected(
                params, inputs, outs)
            self.logit = outs
            return outs

        if isinstance(layer, tf.layers.BatchNormalization):
            self._layer_collection.register_generic(
                tuple(layer.trainable_variables),
                tf.shape(outs)[0],
                "diagonal")
            return outs

        # Any other layer is skipped; it must not own trainable variables.
        print("ignored layers for kfac", layer)
        assert len(layer.trainable_variables) == 0
        return outs

    self._patch = mock.patch.object(
        tf.layers.Layer, "apply", new=custom_apply)
def testBuildModel(self):
    """`convnet.build_model` registers its blocks and loss and can be run."""
    manual_registration = convnet._USE_MANUAL_REG
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, 6, 6, 3])
        y = tf.placeholder(tf.int64, [None])
        layer_collection = kfac.LayerCollection()
        loss, accuracy = convnet.build_model(
            x,
            y,
            num_labels=5,
            layer_collection=layer_collection,
            register_layers_manually=manual_registration)
        if not manual_registration:
            layer_collection.auto_register_layers()

        # Three fisher blocks and one loss are expected to be registered.
        num_blocks = len(layer_collection.fisher_blocks)
        num_losses = len(layer_collection.losses)
        self.assertEqual(num_blocks, 3)
        self.assertEqual(num_losses, 1)

        # Evaluating loss and accuracy on random data must not raise.
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            images = np.random.randn(10, 6, 6, 3).astype(np.float32)
            targets = np.random.randint(5, size=10).astype(np.int64)
            sess.run([loss, accuracy], feed_dict={x: images, y: targets})
def construct_train_quants():
    """Build the tensors and optimizer needed to train the autoencoder.

    Returns:
      A tuple `(train_op, opt, batch_loss, batch_error, batch_size_schedule,
      batch_size)`, where `batch_size` is the scalar int32 placeholder that
      controls how many examples the cached reader yields.
    """
    # Dataset and batch-size schedule.
    cached_reader, num_examples = load_mnist()
    batch_size_schedule = _get_batch_size_schedule(num_examples)

    batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')
    minibatch = cached_reader(batch_size)

    training_model = AutoEncoder(784)
    layer_collection = kfac.LayerCollection()

    def loss_fn(minibatch, layer_collection=None, return_acc=False):
        # The input doubles as the label for the autoencoder.
        input_ = minibatch[0]
        logits = training_model(input_)
        return compute_loss(logits=logits,
                            labels=input_,
                            layer_collection=layer_collection,
                            return_acc=return_acc)

    loss_and_error = loss_fn(minibatch,
                             layer_collection=layer_collection,
                             return_acc=True)
    (batch_loss, batch_error) = loss_and_error

    # The training op is created under the device given by the flag.
    with tf.device(FLAGS.device):
        train_op, opt = make_train_op(batch_size,
                                      batch_loss,
                                      layer_collection,
                                      loss_fn=loss_fn,
                                      cached_reader=cached_reader)

    return (train_op, opt, batch_loss, batch_error, batch_size_schedule,
            batch_size)
def _build_toy_problem(self):
    """Construct a toy linear regression problem.

    The model is `y_hat = x * w` with `w` initialised to zero, the data is
    `x = [[1], [2]]`, `y = [1, 2]`, and the loss is the mean of
    `0.5 * (y_hat - y)^2` over the batch. With `w = 0` the initial loss is
      1.25 = mean(0.5 * [1^2, 2^2])

    Returns:
      loss: 0-D Tensor representing loss to be minimized.
      accuracy: 0-D Tensor representing model accuracy (here the loss itself).
      layer_collection: LayerCollection instance describing model
        architecture.
    """
    x = np.asarray([[1.], [2.]]).astype(np.float32)
    y = np.asarray([1., 2.]).astype(np.float32)
    x, y = (tf.data.Dataset.from_tensor_slices(
        (x, y)).repeat(100).batch(2).make_one_shot_iterator().get_next())
    w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
    y_hat = tf.matmul(x, w)

    # `y_hat` is [batch, 1] while `y` is [batch]. Subtracting them directly
    # broadcasts to [batch, batch] and compares every prediction with every
    # target, so give the targets a trailing axis first.
    targets = tf.expand_dims(y, axis=-1)
    loss = tf.reduce_mean(0.5 * tf.square(y_hat - targets))
    accuracy = loss

    layer_collection = kfac.LayerCollection()
    layer_collection.register_fully_connected(params=w, inputs=x,
                                              outputs=y_hat)
    layer_collection.register_normal_predictive_distribution(y_hat)

    return loss, accuracy, layer_collection
def train_mnist_distributed_sync_replicas(task_id,
                                          is_chief,
                                          num_worker_tasks,
                                          num_ps_tasks,
                                          master,
                                          num_epochs,
                                          op_strategy,
                                          use_fake_data=False):
    """Train a ConvNet on MNIST using Sync replicas optimizer.

    Args:
      task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
      is_chief: `boolean`, `True` if the worker is chief worker.
      num_worker_tasks: int. Number of workers in this distributed training
        setup.
      num_ps_tasks: int. Number of parameter servers holding variables.
      master: string. IP and port of TensorFlow runtime process.
      num_epochs: int. Number of passes to make over the training set.
      op_strategy: `string`, Strategy to run the covariance and inverse ops.
        If op_strategy == `chief_worker` then covariance and inverse update
        ops are run on chief worker otherwise they are run on dedicated
        workers.
      use_fake_data: bool. If True, generate a synthetic dataset.

    Returns:
      accuracy of model on the final minibatch of training data.

    Raises:
      ValueError: If `op_strategy` not in ["chief_worker",
        "dedicated_workers"].
    """
    # Reject an unsupported strategy before loading data or building the
    # graph, so a misconfigured job fails immediately.
    if op_strategy not in ("chief_worker", "dedicated_workers"):
        raise ValueError("Only supported op strategies are : {}, {}".format(
            "chief_worker", "dedicated_workers"))

    # Load a dataset.
    tf.logging.info("Loading MNIST into memory.")
    (examples, labels) = mnist.load_mnist_as_iterator(
        num_epochs, 128, use_fake_data=use_fake_data, flatten_images=False)

    # Build a ConvNet, placing variables with the replica device setter.
    layer_collection = kfac.LayerCollection()
    with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
        loss, accuracy = build_model(
            examples,
            labels,
            num_labels=10,
            layer_collection=layer_collection,
            register_layers_manually=_USE_MANUAL_REG)
    if not _USE_MANUAL_REG:
        layer_collection.auto_register_layers()

    # Fit model. No checkpoint directory is used.
    checkpoint_dir = None
    if op_strategy == "chief_worker":
        train_fn = distributed_grads_only_and_ops_chief_worker
    else:
        train_fn = distributed_grads_and_ops_dedicated_workers
    return train_fn(task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
                    checkpoint_dir, loss, accuracy, layer_collection)
def train_mnist_single_machine(num_epochs,
                               use_fake_data=False,
                               device=None,
                               manual_op_exec=False):
    """Train a ConvNet on MNIST, feeding data through a string-handle iterator.

    Separate training and validation iterator handles are created; the model
    reads from an iterator built from a string placeholder, so the input can
    be switched by feeding the corresponding handle.

    Args:
      num_epochs: int. Number of passes to make over the training set.
      use_fake_data: bool. If True, generate a synthetic dataset.
      device: string or None. The covariance and inverse update ops are run on
        this device. If empty or None, the default device will be used.
        (Default: None)
      manual_op_exec: bool. Not read by this function;
        `minimize_loss_single_machine` is always called.

    Returns:
      Whatever `minimize_loss_single_machine` returns (per the original
      documentation: accuracy of the model on the final minibatch of
      training data).
    """
    from tensorflow.data import Iterator

    # Load a dataset.
    print ("Loading MNIST into memory.")
    tf.logging.info("Loading MNIST into memory.")
    train_pipeline = mnist.load_mnist_as_iterator(
        num_epochs,
        args.batch_size,
        train=True,
        use_fake_data=use_fake_data,
        flatten_images=False)
    iter_train_handle, output_types, output_shapes = train_pipeline

    # The validation epoch count is inflated so that the validation iterator
    # does not cause early termination.
    val_epochs = 10000*num_epochs
    iter_val_handle, _, _ = mnist.load_mnist_as_iterator(
        val_epochs,
        10000,
        train=False,
        use_fake_data=use_fake_data,
        flatten_images=False)

    # Feedable iterator: the placeholder selects which pipeline to read.
    handle = tf.placeholder(tf.string, shape=[])
    feedable_iterator = Iterator.from_string_handle(
        handle, output_types, output_shapes)
    (examples, labels) = feedable_iterator.get_next()

    # Build a ConvNet.
    layer_collection = kfac.LayerCollection()
    loss, accuracy = build_model(
        examples,
        labels,
        num_labels=10,
        layer_collection=layer_collection,
        register_layers_manually=_USE_MANUAL_REG)
    if not _USE_MANUAL_REG:
        layer_collection.auto_register_layers()

    # Without setting allow_soft_placement=True there will be problems when
    # the optimizer tries to place certain ops like "mod" on the GPU (which
    # isn't supported).
    config = tf.ConfigProto(allow_soft_placement=True)

    # Fit model.
    return minimize_loss_single_machine(handle,
                                        iter_train_handle,
                                        iter_val_handle,
                                        loss,
                                        accuracy,
                                        layer_collection,
                                        device=device,
                                        session_config=config)
def construct_train_quants():
    """Build the training and evaluation quantities for the classifier.

    Everything is constructed under `tf.device(FLAGS.device)`. The returned
    train op applies an exponential moving average to the model variables
    after the optimizer step.

    Returns:
      A tuple `(train_op, opt, batch_loss, batch_error, batch_size_schedule,
      batch_size, eval_loss, eval_error, eval_loss_avg, eval_error_avg)`.
    """
    with tf.device(FLAGS.device):
        # Dataset and batch-size schedule.
        cached_reader, num_examples = load_mnist()
        batch_size_schedule = _get_batch_size_schedule(num_examples)

        batch_size = tf.placeholder(
            shape=(), dtype=tf.int32, name='batch_size')
        minibatch = cached_reader(batch_size)

        training_model = Model()
        layer_collection = kfac.LayerCollection()
        if FLAGS.use_sua_approx:
            layer_collection.set_default_conv2d_approximation('kron_sua')

        ema = tf.train.ExponentialMovingAverage(FLAGS.polyak_decay,
                                                zero_debias=True)

        def loss_fn(minibatch, layer_collection=None, return_error=False):
            features, labels = minibatch
            logits = training_model(features)
            return compute_loss(logits=logits,
                                labels=labels,
                                layer_collection=layer_collection,
                                return_error=return_error)

        loss_and_error = loss_fn(minibatch,
                                 layer_collection=layer_collection,
                                 return_error=True)
        (batch_loss, batch_error) = loss_and_error

        layer_collection.auto_register_layers()

        train_vars = training_model.variables

        # Training op, followed by the moving-average update of the weights.
        train_op, opt = make_train_op(
            minibatch,
            batch_size,
            batch_loss,
            layer_collection,
            loss_fn=loss_fn,
            prev_train_batch=cached_reader.cached_batch)
        with tf.control_dependencies([train_op]):
            train_op = ema.apply(train_vars)

        # We clear out the regularizers collection when creating our
        # evaluation graph (which uses different variables). It is important
        # that we do this only after the train op is constructed, since the
        # minimize() will call into the loss function (which includes the
        # regularizer):
        graph = tf.get_default_graph()
        graph.clear_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

        # These aren't run in the same sess.run call as train_op:
        eval_ops = make_eval_ops(train_vars, ema)
        (eval_loss, eval_error, eval_loss_avg, eval_error_avg) = eval_ops

    return (train_op, opt, batch_loss, batch_error, batch_size_schedule,
            batch_size, eval_loss, eval_error, eval_loss_avg, eval_error_avg)
def testApplyNormWithLayerCollection(self):
    """`apply_norm` with a layer collection registers exactly one block."""
    layer_collection = kfac.LayerCollection()
    inputs = np.random.rand(5, 2, 1, 11)
    common_layers.apply_norm(
        inputs,
        "layer",
        depth=11,
        epsilon=1e-6,
        layer_collection=layer_collection)
    registered_blocks = layer_collection.get_blocks()
    self.assertLen(registered_blocks, 1)
def train_mnist_single_machine(num_epochs,
                               use_fake_data=False,
                               device=None,
                               manual_op_exec=False):
    """Train a ConvNet on MNIST.

    Args:
      num_epochs: int. Number of passes to make over the training set.
      use_fake_data: bool. If True, generate a synthetic dataset.
      device: string or None. The covariance and inverse update ops are run on
        this device. If empty or None, the default device will be used.
        (Default: None)
      manual_op_exec: bool. If `True`, training is delegated to
        `minimize_loss_single_machine_manual` (described in the original
        documentation as handling inverse and covariance computation itself,
        for illustrative purposes); otherwise to
        `minimize_loss_single_machine`.

    Returns:
      The result of the selected training function (per the original
      documentation: accuracy of the model on the final minibatch of
      training data).
    """
    # Load a dataset.
    tf.logging.info("Loading MNIST into memory.")
    batch = mnist.load_mnist_as_iterator(
        num_epochs, 128, use_fake_data=use_fake_data, flatten_images=False)
    (examples, labels) = batch

    # Build a ConvNet.
    layer_collection = kfac.LayerCollection()
    loss, accuracy = build_model(
        examples,
        labels,
        num_labels=10,
        layer_collection=layer_collection,
        register_layers_manually=_USE_MANUAL_REG)
    if not _USE_MANUAL_REG:
        layer_collection.auto_register_layers()

    # Without setting allow_soft_placement=True there will be problems when
    # the optimizer tries to place certain ops like "mod" on the GPU (which
    # isn't supported).
    config = tf.ConfigProto(allow_soft_placement=True)

    # Fit model with whichever training routine was requested.
    if manual_op_exec:
        fit = minimize_loss_single_machine_manual
    else:
        fit = minimize_loss_single_machine
    return fit(loss,
               accuracy,
               layer_collection,
               device=device,
               session_config=config)
def train_mnist_multitower(data_dir, num_epochs, num_towers,
                           use_fake_data=True):
    """Train a ConvNet on MNIST with the minibatch split across CPU towers.

    Args:
      data_dir: string. Directory to read MNIST examples from.
      num_epochs: int. Number of passes to make over the training set.
      num_towers: int. Number of CPUs to split inference across.
      use_fake_data: bool. If True, generate a synthetic dataset.

    Returns:
      Whatever `minimize_loss_single_machine` returns (per the original
      documentation: accuracy of the model on the final minibatch of
      training data).
    """
    # Load a dataset. Each tower gets 128 examples per step.
    tf.logging.info("Loading MNIST into memory.")
    tower_batch_size = 128
    batch_size = tower_batch_size * num_towers
    size_message = (
        "Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
        "tower batch size.") % (batch_size, num_towers, tower_batch_size)
    tf.logging.info(size_message)
    examples, labels = mnist.load_mnist(
        data_dir,
        num_epochs=num_epochs,
        batch_size=batch_size,
        use_fake_data=use_fake_data,
        flatten_images=False)

    # Split minibatch across towers.
    examples = tf.split(examples, num_towers)
    labels = tf.split(labels, num_towers)

    # Build one model replica per tower. Variables are reused after the first
    # tower, and every tower's layers go into the same LayerCollection.
    layer_collection = kfac.LayerCollection()
    tower_results = []
    for tower_id in range(num_towers):
        share_variables = (tower_id > 0)
        with tf.device("/cpu:%d" % tower_id):
            with tf.name_scope("tower%d" % tower_id):
                with tf.variable_scope(tf.get_variable_scope(),
                                       reuse=share_variables):
                    tf.logging.info("Building tower %d." % tower_id)
                    tower_output = build_model(examples[tower_id],
                                               labels[tower_id],
                                               10,
                                               layer_collection)
                    tower_results.append(tower_output)
    losses, accuracies = zip(*tower_results)

    # Average across towers.
    loss = tf.reduce_mean(losses)
    accuracy = tf.reduce_mean(accuracies)

    # Fit model with one CPU device per tower and no soft placement.
    cpu_devices = {"CPU": num_towers}
    session_config = tf.ConfigProto(allow_soft_placement=False,
                                    device_count=cpu_devices)
    return minimize_loss_single_machine(
        loss, accuracy, layer_collection, session_config=session_config)
def testLayerNorm(self):
    """`layer_norm` keeps the input shape and registers one K-FAC block."""
    x = np.random.rand(5, 7, 11)
    inputs = tf.constant(x, dtype=tf.float32)
    normalized = common_layers.layer_norm(inputs, 11)
    self.evaluate(tf.global_variables_initializer())
    result = self.evaluate(normalized)
    self.assertEqual(result.shape, (5, 7, 11))

    # Same layer again, this time with a layer collection attached.
    layer_collection = kfac.LayerCollection()
    common_layers.layer_norm(x, layer_collection=layer_collection)
    registered_blocks = layer_collection.get_blocks()
    self.assertLen(registered_blocks, 1)
def testDenseWithLayerCollection(self):
    """Each `dense` call adds one block to the shared layer collection."""
    with tf.variable_scope("test_layer_collection"):
        layer_collection = kfac.LayerCollection()

        # 2D inputs.
        flat_inputs = tf.zeros([3, 4], tf.float32)
        common_layers.dense(flat_inputs,
                            units=10,
                            layer_collection=layer_collection,
                            name="y1")
        self.assertLen(layer_collection.get_blocks(), 1)

        # 3D inputs.
        sequence_inputs = tf.zeros([3, 4, 5], tf.float32)
        common_layers.dense(sequence_inputs,
                            units=10,
                            layer_collection=layer_collection,
                            name="y2")
        self.assertLen(layer_collection.get_blocks(), 2)
def testMultiheadAttentionWithLayerCollection(self):
    """Multihead attention with a K-FAC layer collection registers 4 blocks."""
    layer_collection = kfac.LayerCollection()
    x = tf.zeros([3, 4, 5], tf.float32)
    common_attention.multihead_attention(
        x, None, None, 10, 10, 10, 2, 0.2,
        layer_collection=layer_collection)
    registered_blocks = layer_collection.get_blocks()
    self.assertLen(registered_blocks, 4)
def train_mnist_single_machine(data_dir,
                               num_epochs,
                               use_fake_data=False,
                               device=None,
                               manual_op_exec=False):
    """Train a ConvNet on MNIST.

    Args:
      data_dir: string. Directory to read MNIST examples from.
      num_epochs: int. Number of passes to make over the training set.
      use_fake_data: bool. If True, generate a synthetic dataset.
      device: string or None. The covariance and inverse update ops are run on
        this device. If empty or None, the default device will be used.
        (Default: None)
      manual_op_exec: bool. If `True`, training is delegated to
        `minimize_loss_single_machine_manual` (described in the original
        documentation as handling inverse and covariance computation itself,
        for illustrative purposes); otherwise to
        `minimize_loss_single_machine`.

    Returns:
      The result of the selected training function (per the original
      documentation: accuracy of the model on the final minibatch of
      training data).
    """
    # Load a dataset.
    tf.logging.info("Loading MNIST into memory.")
    dataset = mnist.load_mnist(data_dir,
                               num_epochs=num_epochs,
                               batch_size=128,
                               use_fake_data=use_fake_data,
                               flatten_images=False)
    examples, labels = dataset

    # Build a ConvNet.
    layer_collection = kfac.LayerCollection()
    loss, accuracy = build_model(examples,
                                 labels,
                                 num_labels=10,
                                 layer_collection=layer_collection)

    # Fit model with whichever training routine was requested.
    if manual_op_exec:
        fit = minimize_loss_single_machine_manual
    else:
        fit = minimize_loss_single_machine
    return fit(loss, accuracy, layer_collection, device=device)
def create_optimizer(acktr, model, learning_rate):
    """Creates an optimizer based on whether `ACKTR` or `A2C` is used.

    `A2C` uses the RMSProp optimizer wrapped in global-norm clipping, `ACKTR`
    uses a K-FAC optimizer that starts with a few momentum-SGD updates. This
    function is not restricted to Atari models and can be used generally.

    Args:
        acktr (:obj:`bool`):
            Whether to use the optimizer of `ACKTR` or `A2C`.

        model (:obj:`~actorcritic.model.ActorCriticModel`):
            A model that is needed for K-FAC to register the neural network
            layers and the predictive distributions. Only used if `acktr` is
            true.

        learning_rate (:obj:`float` or :obj:`tf.Tensor`):
            A learning rate for the optimizer.

    Returns:
        The optimizer: a `ColdStartPeriodicInvUpdateKfacOpt` if `acktr` is
        true, otherwise a `ClipGlobalNormOptimizer` around RMSProp.
    """
    if not acktr:
        rmsprop = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
        # clip the gradients
        return ClipGlobalNormOptimizer(rmsprop, clip_norm=0.5)

    # required for the K-FAC optimizer
    layer_collection = kfac.LayerCollection()
    model.register_layers(layer_collection)
    model.register_predictive_distributions(layer_collection)

    # use SGD optimizer for the first few iterations, to prevent NaN values
    # TODO
    momentum_sgd = tf.train.MomentumOptimizer(learning_rate=0.0003,
                                              momentum=0.9)
    cold_optimizer = ClipGlobalNormOptimizer(momentum_sgd, clip_norm=0.5)

    return ColdStartPeriodicInvUpdateKfacOpt(
        num_cold_updates=30,
        cold_optimizer=cold_optimizer,
        invert_every=10,
        learning_rate=learning_rate,
        cov_ema_decay=0.99,
        damping=0.01,
        layer_collection=layer_collection,
        momentum=0.9,
        norm_constraint=0.0001,  # trust region radius
        cov_devices=['/gpu:0'],
        inv_devices=['/gpu:0'])
def train_mnist_distributed(task_id,
                            num_worker_tasks,
                            num_ps_tasks,
                            master,
                            data_dir,
                            num_epochs,
                            use_fake_data=False):
    """Train a ConvNet on MNIST in a distributed setup.

    Args:
      task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
      num_worker_tasks: int. Number of workers in this distributed training
        setup.
      num_ps_tasks: int. Number of parameter servers holding variables.
      master: string. IP and port of TensorFlow runtime process.
      data_dir: string. Directory to read MNIST examples from. When not None,
        checkpoints go to its "kfac" subdirectory.
      num_epochs: int. Number of passes to make over the training set.
      use_fake_data: bool. If True, generate a synthetic dataset.

    Returns:
      Whatever `minimize_loss_distributed` returns (per the original
      documentation: accuracy of the model on the final minibatch of
      training data).
    """
    # Load a dataset.
    tf.logging.info("Loading MNIST into memory.")
    dataset = mnist.load_mnist(data_dir,
                               num_epochs=num_epochs,
                               batch_size=128,
                               use_fake_data=use_fake_data,
                               flatten_images=False)
    examples, labels = dataset

    # Build a ConvNet, placing variables with the replica device setter.
    layer_collection = kfac.LayerCollection()
    device_setter = tf.train.replica_device_setter(num_ps_tasks)
    with tf.device(device_setter):
        loss, accuracy = build_model(examples,
                                     labels,
                                     num_labels=10,
                                     layer_collection=layer_collection)

    # Fit model.
    if data_dir is None:
        checkpoint_dir = None
    else:
        checkpoint_dir = os.path.join(data_dir, "kfac")
    return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks,
                                     master, checkpoint_dir, loss, accuracy,
                                     layer_collection)
def instantiate_optimizer(model, optimizer_tuple):
    """Instantiate the optimizer described by `optimizer_tuple`.

    The optimizer is looked up, in order, in:
      1. `tf.train`;
      2. a sibling module named after the optimizer in this module's package;
      3. `core.optimizers.<optimizer_name>`;
      4. `core.<optimizer_name>`;
      5. the `kfac` package, with an auto-registered `LayerCollection`.

    From step 2 onwards the model is passed to the optimizer as the `model`
    keyword argument.

    Args:
      model: the model being trained. This function reads
        `model.global_step`, `model.n_batches_per_epoch` and, for the K-FAC
        fallback, `model.logits`.
      optimizer_tuple: `(optimizer_name, optimizer_kwargs)`. The kwargs must
        contain "learning_rate"; the dict is copied, so the caller's dict is
        not modified.

    Returns:
      `(training_optimizer, lr)`, where `lr` is the result of
      `process_learning_rate`.

    Raises:
      Exception: if the K-FAC fallback fails; the original error is chained.
    """
    optimizer_name = optimizer_tuple[0]
    optimizer_kwargs = optimizer_tuple[1]

    lr = process_learning_rate(optimizer_kwargs["learning_rate"],
                               model.global_step,
                               model.n_batches_per_epoch)

    # Work on a copy: the kwargs dict may be shared between task options and
    # must not be modified for the other holders of the reference.
    optimizer_kwargs = optimizer_kwargs.copy()
    optimizer_kwargs.update({"learning_rate": lr})

    # 1. Try to get the optimizer from tf.train.
    try:
        training_optimizer = eval_method_from_tuple(
            tf.train, (optimizer_name, optimizer_kwargs))
    except AttributeError:
        pass
    else:
        return training_optimizer, lr

    # Every optimizer that is not found in tf.train receives the model.
    optimizer_kwargs["model"] = model

    # 2-4. Try the candidate modules in order; the first that imports and
    # instantiates without an ImportError wins.
    parent_package = '.'.join(__name__.split('.')[:-1])
    candidate_modules = (
        "." + optimizer_name,                # this package's optimizers
        "core.optimizers." + optimizer_name,
        "core." + optimizer_name,
    )
    for module_name in candidate_modules:
        try:
            optimizer_module = importlib.import_module(module_name,
                                                       parent_package)
            training_optimizer = eval_method_from_tuple(
                optimizer_module, (optimizer_name, optimizer_kwargs))
        except ImportError:
            continue
        return training_optimizer, lr

    # 5. Fall back to kfac. `training_module` starts as a plain label so the
    # error message below can always be formatted, even if the import fails.
    training_module = "kfac"
    try:
        import kfac
        training_module = kfac

        layer_collection = kfac.LayerCollection()
        layer_collection.register_categorical_predictive_distribution(
            model.logits, name="logits")

        # Register parameters. K-FAC needs to know about the inputs, outputs,
        # and parameters of each conv/fully connected layer and the logits
        # powering the posterior probability over classes.
        tf.logging.info("Building LayerCollection.")
        layer_collection.auto_register_layers()

        kfac_kwargs = {
            **optimizer_kwargs,
            "layer_collection": layer_collection,
            "placement_strategy": "round_robin",
            "cov_devices": ["/gpu:0"],
            "inv_devices": ["/gpu:0"],
        }
        training_optimizer = eval_method_from_tuple(
            training_module, (optimizer_name, kfac_kwargs))
    except Exception as e:
        raise Exception(
            "problem loading training algorithm: %s, kwargs: %s, exception: %s"
            % (training_module, optimizer_kwargs, e)) from e

    return training_optimizer, lr
def _qmc_step_fn(self, optimizer_fn, using_kfac, global_step):
    """Training step for network given the MCMC state.

    Args:
      optimizer_fn: A function which takes as argument a LayerCollection
        object (None) if using_kfac is True (False) and returns the
        optimizer.
      using_kfac: True if optimizer_fn creates a instance of
        kfac.KfacOptimizer and False otherwise.
      global_step: tensorflow op for global step index.

    Returns:
      A dict with keys:
        'loss': per-GPU loss tensor with control dependencies for updating
          network.
        'local_energies': local energy for each walker.
        'features': network output for each walker.
        'features_sign': second output of the network call.

    Raises:
      RuntimeError: If using_kfac is True and optimizer_fn does not create a
        kfac.KfacOptimizer instance or the converse.
    """
    # Note layer_collection cannot be modified after the KFac optimizer has
    # been constructed.
    if using_kfac:
        layer_collection = kfac.LayerCollection()
    else:
        layer_collection = None

    walkers = self.data_gen.walkers_per_gpu
    features, features_sign = self.network(walkers, layer_collection)
    optimizer = optimizer_fn(layer_collection)

    # The flag and the optimizer type must agree; report which way they
    # disagree.
    if bool(using_kfac) != isinstance(optimizer, kfac.KfacOptimizer):
        if using_kfac:
            raise RuntimeError('Not using KFac but using_kfac is True.')
        raise RuntimeError('Using KFac but using_kfac is False.')

    if layer_collection:
        layer_collection.register_squared_error_loss(features, reuse=False)

    with tf.name_scope('local_energy'):
        kinetic_fn, potential_fn = self.hamiltonian
        kinetic = kinetic_fn(features, walkers)
        potential = potential_fn(walkers)
        local_energy = kinetic + potential
        loss = tf.reduce_mean(local_energy)
        # Average the per-replica loss over all replicas.
        replica_context = tf.distribute.get_replica_context()
        mean_op = tf.distribute.ReduceOp.MEAN
        mean_loss = replica_context.merge_call(
            lambda strategy, val: strategy.reduce(mean_op, val),
            args=(loss, ))
        grad_loss = local_energy - mean_loss

        if self._clip_el is not None:
            # clip_el should be much larger than 1, to avoid bias
            median = tfp.stats.percentile(grad_loss, 50.0)
            diff = tf.reduce_mean(tf.abs(grad_loss - median))
            grad_loss_clipped = tf.clip_by_value(
                grad_loss,
                median - self._clip_el * diff,
                median + self._clip_el * diff)
        else:
            grad_loss_clipped = grad_loss

    with tf.name_scope('step'):
        # Create functions which take no arguments and return the ops for
        # applying an optimisation step.
        if not optimizer:
            optimize_step = tf.no_op
        else:
            optimize_step = functools.partial(
                optimizer.minimize,
                features,
                global_step=global_step,
                var_list=self.network.trainable_variables,
                grad_loss=grad_loss_clipped)

        if self._check_loss:
            # Apply optimisation step only if all local energies are
            # well-defined.
            step = tf.cond(tf.reduce_any(tf.math.is_nan(mean_loss)),
                           tf.no_op, optimize_step)
        else:
            # Faster, but less safe: always apply optimisation step. If the
            # gradients are not well-defined (e.g. loss contains a NaN), then
            # the network will also be set to NaN.
            step = optimize_step()

    # A strategy step function must return tensors, not ops so apply a
    # control dependency to a dummy op to ensure they're executed.
    with tf.control_dependencies([step]):
        loss = tf.identity(loss)

    return {
        'loss': loss,
        'local_energies': local_energy,
        'features': features,
        'features_sign': features_sign
    }
def _model_fn(features, labels, mode, params):
    """Estimator model_fn for the `classifier_mnist` model trained with K-FAC.

    Args:
      features: input features passed to the model.
      labels: labels passed to the loss.
      mode: a `tf.estimator.ModeKeys` value. In TRAIN mode the layers are
        registered and a train op is built; in any other mode only the loss
        is returned.
      params: ignored.

    Returns:
      A `TPUEstimatorSpec` for the given mode.
    """
    del params

    training_model = classifier_mnist.Model()
    layer_collection = kfac.LayerCollection()

    def loss_fn(minibatch, logits=None, return_error=False):
        batch_features, batch_labels = minibatch
        if logits is None:
            # Note we do not need to do anything like
            # `with tf.variable_scope(tf.get_variable_scope(),
            #                         reuse=tf.AUTO_REUSE):`
            # here because Sonnet takes care of variable reuse for us as long
            # as we call the same `training_model` module. Otherwise we would
            # need to use variable reusing here.
            logits = training_model(batch_features)
        return classifier_mnist.compute_loss(logits=logits,
                                             labels=batch_labels,
                                             return_error=return_error)

    logits = training_model(features)
    loss_and_error = loss_fn((features, labels),
                             logits=logits,
                             return_error=True)
    pre_update_batch_loss, pre_update_batch_error = loss_and_error

    global_step = tf.train.get_or_create_global_step()

    if mode != tf.estimator.ModeKeys.TRAIN:
        # mode == tf.estimator.ModeKeys.{EVAL, PREDICT}:
        return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                            loss=pre_update_batch_loss,
                                            eval_metrics=None)

    layer_collection.register_softmax_cross_entropy_loss(
        logits, seed=FLAGS.seed + 1)
    layer_collection.auto_register_layers()

    train_op, kfac_optimizer = make_train_op((features, labels),
                                             pre_update_batch_loss,
                                             layer_collection,
                                             loss_fn)

    # Every reported value gets a leading axis of size 1 before being handed
    # to the host call.
    tensors_to_print = {
        'learning_rate': tf.expand_dims(kfac_optimizer.learning_rate, 0),
        'momentum': tf.expand_dims(kfac_optimizer.momentum, 0),
        'damping': tf.expand_dims(kfac_optimizer.damping, 0),
        'global_step': tf.expand_dims(global_step, 0),
        'loss': tf.expand_dims(pre_update_batch_loss, 0),
        'error': tf.expand_dims(pre_update_batch_error, 0),
    }
    if FLAGS.adapt_damping:
        qmodel_change = tf.expand_dims(kfac_optimizer.qmodel_change, 0)
        tensors_to_print['qmodel_change'] = qmodel_change
        rho = tf.expand_dims(kfac_optimizer.rho, 0)
        tensors_to_print['rho'] = rho

    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        loss=pre_update_batch_loss,
        train_op=train_op,
        host_call=(print_tensors, tensors_to_print),
        eval_metrics=None)
def main(_):
    """Train the MNIST classifier with K-FAC, logging statistics every step.

    The mini-batch size is either `FLAGS.batch_size` or, when
    `FLAGS.use_batch_size_schedule` is set, taken from a schedule that grows
    exponentially from 1000 up to the full dataset size.

    Args:
      _: unused.
    """
    # Load dataset.
    cached_reader, num_examples = load_mnist()
    num_classes = 10

    # Exponentially growing batch-size schedule, capped at the dataset size.
    minibatch_maxsize_targetiter = 500
    minibatch_maxsize = num_examples
    minibatch_startsize = 1000

    div = (float(minibatch_maxsize_targetiter - 1) /
           math.log(float(minibatch_maxsize) / minibatch_startsize, 2))
    batch_size_schedule = []
    for k in range(500):
        scheduled = int(2.**(float(k) / div) * minibatch_startsize)
        batch_size_schedule.append(min(scheduled, minibatch_maxsize))

    batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')

    layer_collection = kfac.LayerCollection()

    def loss_fn(minibatch, layer_collection=None):
        return compute_loss(minibatch[0],
                            minibatch[1],
                            num_classes,
                            layer_collection=layer_collection)

    minibatch = cached_reader(batch_size)
    batch_loss = loss_fn(minibatch, layer_collection=layer_collection)

    # Make training op
    with tf.device(FLAGS.device):
        train_op, opt = make_train_op(batch_size,
                                      batch_loss,
                                      layer_collection,
                                      loss_fn=loss_fn,
                                      cached_reader=cached_reader)

    # Optimizer quantities that are reported after each step.
    learning_rate = opt.learning_rate
    momentum = opt.momentum
    damping = opt.damping
    rho = opt.rho
    qmodel_change = opt.qmodel_change
    optimizer_stats = [learning_rate, momentum, damping, rho, qmodel_change]
    global_step = tf.train.get_or_create_global_step()

    # Without setting allow_soft_placement=True there will be problems when
    # the optimizer tries to place certain ops like "mod" on the GPU (which
    # isn't supported).
    config = tf.ConfigProto(allow_soft_placement=True)

    # Train model.
    with tf.train.MonitoredTrainingSession(save_checkpoint_secs=30,
                                           config=config) as sess:
        while not sess.should_stop():
            i = sess.run(global_step)

            if FLAGS.use_batch_size_schedule:
                # Stay on the last schedule entry once it is exhausted.
                schedule_index = min(i, len(batch_size_schedule) - 1)
                batch_size_ = batch_size_schedule[schedule_index]
            else:
                batch_size_ = FLAGS.batch_size

            _, batch_loss_ = sess.run([train_op, batch_loss],
                                      feed_dict={batch_size: batch_size_})

            # We get these things in a separate sess.run() call because they
            # are stored as variables in the optimizer. (So there is no
            # computational cost to getting them, and if we don't get them
            # after the previous call is over they might not be updated.)
            (learning_rate_, momentum_, damping_, rho_,
             qmodel_change_) = sess.run(optimizer_stats)

            # Print training stats.
            tf.logging.info('iteration: %d', i)
            tf.logging.info('mini-batch size: %d | mini-batch loss = %f',
                            batch_size_, batch_loss_)
            tf.logging.info('learning_rate = %f | momentum = %f',
                            learning_rate_, momentum_)
            tf.logging.info('damping = %f | rho = %f | qmodel_change = %f',
                            damping_, rho_, qmodel_change_)
            tf.logging.info('----')
def model_fn(features, labels, mode, params):
    """Estimator model function that trains `build_model`'s network with K-FAC.

    Args:
      features: Tensor of shape [batch_size, input_size]. Input features.
      labels: Tensor of shape [batch_size]. Target labels for training.
      mode: tf.estimator.ModeKey. Must be TRAIN.
      params: ignored.

    Returns:
      EstimatorSpec for training.

    Raises:
      ValueError: If 'mode' is anything other than TRAIN.
    """
    del params  # Unused.

    if mode != tf.estimator.ModeKeys.TRAIN:
        raise ValueError("Only training is supported with this API.")

    # Build the model; its layers are recorded in `layer_collection` either
    # manually inside `build_model` or automatically right after it.
    layer_collection = kfac.LayerCollection()
    loss, accuracy = build_model(
        features,
        labels,
        num_labels=10,
        layer_collection=layer_collection,
        register_layers_manually=_USE_MANUAL_REG)
    if not _USE_MANUAL_REG:
        layer_collection.auto_register_layers()

    # Train with K-FAC, halving the learning rate every 10000 steps.
    global_step = tf.train.get_or_create_global_step()
    decayed_learning_rate = tf.train.exponential_decay(
        0.00002, global_step, 10000, 0.5, staircase=True)
    optimizer = kfac.KfacOptimizer(
        learning_rate=decayed_learning_rate,
        cov_ema_decay=0.95,
        damping=0.001,
        layer_collection=layer_collection,
        momentum=0.9)

    cov_update_thunks, inv_update_thunks = (
        optimizer.make_vars_and_create_op_thunks())

    def run_all(thunks):
        """Group the ops produced by calling every thunk."""
        ops = []
        for thunk in thunks:
            ops.append(thunk())
        return tf.group(*ops)

    def run_in_batches(thunks, batch_size=1):
        """Group the ops `batch_execute` schedules for the current step."""
        scheduled_ops = kfac.utils.batch_execute(
            global_step, thunks, batch_size=batch_size)
        return tf.group(*scheduled_ops)

    def all_inverse_updates():
        return run_all(inv_update_thunks)

    def batched_inverse_updates():
        return run_in_batches(inv_update_thunks)

    # Run cov_update_op every step. Run 1 inv_update_ops per step.
    cov_update_op = run_all(cov_update_thunks)
    with tf.control_dependencies([cov_update_op]):
        # But make sure to execute all the inverse ops on the first step.
        is_first_step = tf.equal(global_step, 0)
        inverse_op = tf.cond(is_first_step, all_inverse_updates,
                             batched_inverse_updates)
        with tf.control_dependencies([inverse_op]):
            train_op = optimizer.minimize(loss, global_step=global_step)

    # Print metrics every 5 sec.
    logged_tensors = {"loss": loss, "accuracy": accuracy}
    hooks = [tf.train.LoggingTensorHook(logged_tensors, every_n_secs=5)]
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      training_hooks=hooks)
def train_mnist_multitower(num_epochs,
                           num_towers,
                           devices,
                           use_fake_data=False,
                           session_config=None):
    """Train a ConvNet on MNIST.

    Training data is split equally among the towers. Each tower computes loss
    on its own batch of data and the loss is aggregated on the CPU. The model
    variables are placed on first tower. The covariance and inverse update ops
    and variables are placed on specified devices in a round robin manner.

    Args:
      num_epochs: int. Number of passes to make over the training set.
      num_towers: int. Number of towers. NOTE: the passed value is overridden;
        the tower count actually used is `len(devices)` (or 1 when `devices`
        is empty).
      devices: list of strings. List of devices to place the towers.
      use_fake_data: bool. If True, generate a synthetic dataset.
      session_config: None or tf.ConfigProto. Configuration for tf.Session().

    Returns:
      accuracy of model on the final minibatch of training data, or None if
      the session stopped before a single training step was run.
    """
    # NOTE(review): with an empty `devices` this yields one tower, but
    # `devices[tower_id]` / `devices[0]` below would still fail -- confirm
    # whether an empty device list is meant to be supported.
    num_towers = 1 if not devices else len(devices)

    # Load a dataset.
    tf.logging.info("Loading MNIST into memory.")
    tower_batch_size = 128
    batch_size = tower_batch_size * num_towers
    tf.logging.info(
        ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
         "tower batch size."), batch_size, num_towers, tower_batch_size)
    (examples, labels) = mnist.load_mnist_as_iterator(
        num_epochs,
        batch_size,
        use_fake_data=use_fake_data,
        flatten_images=False)

    # Split minibatch across towers.
    examples = tf.split(examples, num_towers)
    labels = tf.split(labels, num_towers)

    # Build one model replica per tower. Each tower's layers will be added to
    # the LayerCollection; variables are shared by reusing the variable scope
    # for every tower after the first.
    layer_collection = kfac.LayerCollection()
    tower_results = []
    for tower_id in range(num_towers):
        with tf.device(devices[tower_id]):
            with tf.name_scope("tower%d" % tower_id):
                with tf.variable_scope(tf.get_variable_scope(),
                                       reuse=(tower_id > 0)):
                    tf.logging.info("Building tower %d." % tower_id)
                    tower_results.append(
                        build_model(examples[tower_id],
                                    labels[tower_id],
                                    10,
                                    layer_collection,
                                    register_layers_manually=_USE_MANUAL_REG))
    losses, accuracies = zip(*tower_results)

    # When using multiple towers we only want to perform automatic
    # registration once, after the final tower is made.
    if not _USE_MANUAL_REG:
        layer_collection.auto_register_layers()

    # Average across towers.
    loss = tf.reduce_mean(losses)
    accuracy = tf.reduce_mean(accuracies)

    # Fit model.
    g_step = tf.train.get_or_create_global_step()
    optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
        invert_every=_INVERT_EVERY,
        cov_update_every=_COV_UPDATE_EVERY,
        learning_rate=0.0001,
        cov_ema_decay=0.95,
        damping=0.001,
        layer_collection=layer_collection,
        placement_strategy="round_robin",
        cov_devices=devices,
        inv_devices=devices,
        trans_devices=devices,
        momentum=0.9)

    with tf.device(devices[0]):
        train_op = optimizer.minimize(loss, global_step=g_step)

    # Without setting allow_soft_placement=True there will be problems when
    # the optimizer tries to place certain ops like "mod" on the GPU (which
    # isn't supported).
    if not session_config:
        session_config = tf.ConfigProto(allow_soft_placement=True)

    tf.logging.info("Starting training.")
    # Holds the accuracy of the most recent step so it can be returned once
    # the session signals that training should stop.
    accuracy_ = None
    with tf.train.MonitoredTrainingSession(config=session_config) as sess:
        while not sess.should_stop():
            global_step_, loss_, accuracy_, _ = sess.run(
                [g_step, loss, accuracy, train_op])

            if global_step_ % _REPORT_EVERY == 0:
                tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                                global_step_, loss_, accuracy_)

    return accuracy_
def _model_fn(features, labels, mode, params):
    """Estimator model_fn for an autoencoder with adaptive damping.

    Args:
      features: `Tensor` of input images; also used as the reconstruction
        target.
      labels: `Tensor` of labels. Only forwarded as part of the minibatch; the
        loss itself ignores it.
      mode: `tf.estimator.ModeKeys` value.
      params: ignored.

    Returns:
      A `TPUEstimatorSpec`. In TRAIN mode it carries the K-FAC train op and a
      host call that prints optimizer statistics; in every other mode it only
      carries the loss.
    """
    del params  # Unused.

    layer_collection = kfac.LayerCollection()
    training_model_fn = autoencoder_mnist.AutoEncoder(784)

    def loss_fn(minibatch, logits=None):
        """Compute the model loss given a batch of inputs.

        Args:
          minibatch: `Tuple[Tensor, Tensor]` for the current batch of input
            images and labels.
          logits: `Tensor` for the current batch of logits. If None then
            reuses the AutoEncoder to compute them.

        Returns:
          `Tensor` for the batch loss.
        """
        inputs, _ = minibatch  # Labels are unused: the target is the input.
        if logits is None:
            # No explicit variable-scope reuse is needed here: Sonnet shares
            # the variables as long as the same `training_model_fn` module is
            # called again.
            logits = training_model_fn(inputs)
        return compute_loss(logits=logits, labels=inputs)

    logits = training_model_fn(features)
    pre_update_batch_loss = loss_fn((features, labels), logits=logits)
    pre_update_batch_error = compute_squared_error(logits, features)

    if mode != tf.estimator.ModeKeys.TRAIN:
        # mode == tf.estimator.ModeKeys.{EVAL, PREDICT}:
        return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                            loss=pre_update_batch_loss,
                                            eval_metrics=None)

    # Make sure never to confuse this with
    # register_softmax_cross_entropy_loss!
    layer_collection.register_sigmoid_cross_entropy_loss(
        logits, seed=FLAGS.seed + 1)
    layer_collection.auto_register_layers()

    global_step = tf.train.get_or_create_global_step()
    train_op, kfac_optimizer = make_train_op(
        (features, labels), pre_update_batch_loss, layer_collection, loss_fn)

    def as_vector(tensor):
        """Add a leading axis so the tensor can be passed to the host call."""
        return tf.expand_dims(tensor, 0)

    tensors_to_print = {
        'learning_rate': as_vector(kfac_optimizer.learning_rate),
        'momentum': as_vector(kfac_optimizer.momentum),
        'damping': as_vector(kfac_optimizer.damping),
        'global_step': as_vector(global_step),
        'loss': as_vector(pre_update_batch_loss),
        'error': as_vector(pre_update_batch_error),
    }
    if FLAGS.adapt_damping:
        # These two are only reported when damping adaptation is enabled.
        tensors_to_print['qmodel_change'] = as_vector(
            kfac_optimizer.qmodel_change)
        tensors_to_print['rho'] = as_vector(kfac_optimizer.rho)

    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        loss=pre_update_batch_loss,
        train_op=train_op,
        host_call=(print_tensors, tensors_to_print),
        eval_metrics=None)
def train_a2c_acktr(acktr, env_id, num_envs, num_steps, save_path, model_name):
    """Trains an Atari model using A2C or ACKTR.

    Automatically saves and loads the trained model. Training runs until it
    is interrupted from the keyboard (the model is then saved once more) or
    an error occurs; in every case the environments are closed before this
    function returns or raises.

    Args:
        acktr (:obj:`bool`): Whether the ACKTR or the A2C algorithm should be
            used. ACKTR uses the K-FAC optimizer and uses 32 filters in the
            third convolutional layer of the neural network instead of 64.
        env_id (:obj:`string`): An id passed to :meth:`gym.make` to create the
            environments.
        num_envs (:obj:`int`): The number of environments that will be used
            (so `num_envs` subprocesses will be created). A2C normally uses
            16. ACKTR normally uses 32.
        num_steps (:obj:`int`): The number of steps to take in each iteration.
            A2C normally uses 5. ACKTR normally uses 20.
        save_path (:obj:`string`): A directory to load and save the model.
        model_name (:obj:`string`): A name of the model. The files in the
            `save_path` directory will have this name.
    """
    # creates functions to create environments (binds values to
    # make_atari_env); render first environment to visualize the learning
    # progress
    env_fns = [functools.partial(make_atari_env, env_id, render=i == 0)
               for i in range(num_envs)]
    envs = create_subprocess_envs(env_fns)

    # stacking frames inside the subprocesses would cause the frames to be
    # passed between processes multiple times
    envs = [wrappers.FrameStackWrapper(env, 4) for env in envs]
    multi_env = MultiEnv(envs)

    # Everything below runs under try/finally so that the environments are
    # closed on every exit path, not only on a keyboard interrupt.
    try:
        # acktr uses only 32 filters in the last layer
        model = AtariModel(multi_env.observation_space,
                           multi_env.action_space,
                           32 if acktr else 64)
        objective = A2CObjective(model,
                                 discount_factor=0.99,
                                 entropy_regularization_strength=0.01)

        if acktr:
            # required for the K-FAC optimizer
            layer_collection = kfac.LayerCollection()
            model.register_layers(layer_collection)
            model.register_predictive_distributions(layer_collection)

            # use SGD optimizer for the first few iterations, to prevent NaN
            # values  # TODO
            cold_optimizer = tf.train.MomentumOptimizer(learning_rate=0.001,
                                                        momentum=0.9)
            cold_optimizer = ClipGlobalNormOptimizer(cold_optimizer,
                                                     clip_norm=0.25)

            optimizer = ColdStartPeriodicInvUpdateKfacOpt(
                num_cold_updates=30,
                cold_optimizer=cold_optimizer,
                invert_every=10,
                learning_rate=0.25,
                cov_ema_decay=0.99,
                damping=0.01,
                layer_collection=layer_collection,
                momentum=0.9,
                norm_constraint=0.0001,  # trust region radius
                cov_devices=['/gpu:0'],
                inv_devices=['/gpu:0'])
        else:
            optimizer = tf.train.RMSPropOptimizer(learning_rate=0.0007)
            # clip the gradients
            optimizer = ClipGlobalNormOptimizer(optimizer, clip_norm=0.5)

        global_step = tf.train.get_or_create_global_step()

        # create optimizer operation for shared parameters
        optimize_op = objective.minimize_shared(optimizer,
                                                baseline_loss_weight=0.5,
                                                global_step=global_step)

        agent = MultiEnvAgent(multi_env, model, num_steps)
        checkpoint_prefix = save_path + '/' + model_name

        with tf.Session() as session:
            session.run(tf.global_variables_initializer())

            saver = tf.train.Saver()
            # Resume from the latest checkpoint if there is one; a missing
            # checkpoint is not an error, training simply starts from the
            # freshly initialized variables.
            try:
                latest_checkpoint_path = tf.train.latest_checkpoint(save_path)
                if latest_checkpoint_path is None:
                    raise FileNotFoundError()
                saver.restore(session, latest_checkpoint_path)
                print('Loaded model')
            except (tf.errors.NotFoundError, FileNotFoundError):
                print('No model loaded')

            # Stays None until the first update has completed, so that an
            # early interrupt does not try to save without a step number.
            step = None
            try:
                while True:
                    # sample trajectory batch
                    (observations, actions, rewards, terminals,
                     next_observations, infos) = agent.interact(session)

                    # update policy and baseline
                    step, _ = session.run(
                        [global_step, optimize_op],
                        feed_dict={
                            model.observations_placeholder: observations,
                            model.bootstrap_observations_placeholder:
                                next_observations,
                            model.actions_placeholder: actions,
                            model.rewards_placeholder: rewards,
                            model.terminals_placeholder: terminals
                        })

                    if step % 100 == 0 and step > 0:
                        # save every 100th step
                        saver.save(session, checkpoint_prefix, step)
                        print('Saved model (step {})'.format(step))

            except KeyboardInterrupt:
                # save when interrupted
                if step is not None:
                    saver.save(session, checkpoint_prefix, step)
                    print('Saved model (step {})'.format(step))
    finally:
        multi_env.close()
def construct_train_quants():
    """Returns tensors and optimizer required to run the autoencoder.

    Which model is built depends on flags: with `FLAGS.auto_register_layers`
    either the Keras autoencoder (`FLAGS.use_keras_model`) or `AutoEncoder`
    is used, otherwise `AutoEncoderManualReg`.

    Returns:
      Tuple `(train_op, opt, batch_loss, batch_error, batch_size_schedule,
      batch_size)`, where `batch_size` is the scalar int32 placeholder that
      has to be fed on every step.

    Raises:
      ValueError: If `FLAGS.use_keras_model` is set without
        `FLAGS.auto_register_layers`. The Keras model is only built in the
        auto-registration branch, so that combination cannot work.
    """
    # Fail early and clearly: previously this flag combination only crashed
    # later, with an unbound-local error on `features`, after the dataset and
    # a model had already been built.
    if FLAGS.use_keras_model and not FLAGS.auto_register_layers:
        raise ValueError(
            'use_keras_model=True requires auto_register_layers=True: the '
            'Keras autoencoder is only built when layers are registered '
            'automatically.')

    with tf.device(FLAGS.device):
        # Load dataset.
        cached_reader, num_examples = load_mnist()
        batch_size_schedule = _get_batch_size_schedule(num_examples)
        batch_size = tf.placeholder(shape=(), dtype=tf.int32,
                                    name='batch_size')
        train_minibatch = cached_reader(batch_size)

        if FLAGS.auto_register_layers:
            if FLAGS.use_keras_model:
                # The Keras model is built directly on the input tensor, so
                # its output is already tied to this minibatch.
                features, _ = train_minibatch
                training_model = get_keras_autoencoder(tensor=features)
            else:
                training_model = AutoEncoder(784)
        else:
            training_model = AutoEncoderManualReg(784)

        layer_collection = kfac.LayerCollection()

        def loss_fn(minibatch, layer_collection=None, return_error=False):
            """Run the model on `minibatch` and return `compute_loss`'s result.

            The minibatch labels are ignored; the features are the target.
            """
            features, labels = minibatch
            del labels
            if FLAGS.auto_register_layers:
                logits = training_model(features)
            else:
                # Manual registration: the model records its own layers in
                # the collection it is handed.
                logits = training_model(features,
                                        layer_collection=layer_collection)

            return compute_loss(logits=logits,
                                labels=features,
                                layer_collection=layer_collection,
                                return_error=return_error,
                                model=training_model)

        if FLAGS.use_keras_model:
            # Reuse the output the Keras model already produced for this
            # minibatch instead of calling the model a second time.
            (batch_loss, batch_error) = compute_loss(
                logits=training_model.output,
                labels=features,
                layer_collection=layer_collection,
                return_error=True,
                model=training_model)
        else:
            (batch_loss, batch_error) = loss_fn(
                train_minibatch,
                layer_collection=layer_collection,
                return_error=True)

        if FLAGS.auto_register_layers:
            layer_collection.auto_register_layers()

        # Make training op
        train_op, opt = make_train_op(
            train_minibatch,
            batch_size,
            batch_loss,
            layer_collection,
            loss_fn=loss_fn,
            prev_train_batch=cached_reader.cached_batch)

    return (train_op, opt, batch_loss, batch_error, batch_size_schedule,
            batch_size)
def construct_train_quants():
    """Build the training and evaluation tensors for the classifier.

    The returned `train_op` is the moving-average update of the training
    model's variables, created under a control dependency on the op returned
    by `make_train_op`. Evaluation uses a second `Model()` instance whose
    variables are overwritten through `group_assign`, first from the training
    variables and then from their moving averages.

    Returns:
      Tuple `(train_op, opt, batch_loss, batch_error, batch_size_schedule,
      batch_size, eval_loss, eval_error, eval_loss_avg, eval_error_avg)`.
      `batch_size` is the scalar int32 placeholder to feed on every step.
    """
    with tf.device(FLAGS.device):
        # Load dataset.
        cached_reader, num_examples = load_mnist()
        batch_size_schedule = _get_batch_size_schedule(num_examples)
        batch_size = tf.placeholder(shape=(), dtype=tf.int32, name='batch_size')

        minibatch = cached_reader(batch_size)

        training_model = Model()
        layer_collection = kfac.LayerCollection()

        if FLAGS.use_sua_approx:
            layer_collection.set_default_conv2d_approximation('kron_sua')

        # Moving average of the training variables, evaluated separately below.
        ema = tf.train.ExponentialMovingAverage(FLAGS.polyak_decay,
                                                zero_debias=True)

        def loss_fn(minibatch, layer_collection=None, return_error=False):
            """Run the training model on `minibatch` and return `compute_loss`."""
            features, labels = minibatch
            logits = training_model(features)

            return compute_loss(logits=logits,
                                labels=labels,
                                layer_collection=layer_collection,
                                return_error=return_error)

        (batch_loss, batch_error) = loss_fn(minibatch,
                                            layer_collection=layer_collection,
                                            return_error=True)

        train_vars = training_model.variables

        # Make training op:
        train_op, opt = make_train_op(
            minibatch,
            batch_size,
            batch_loss,
            layer_collection,
            loss_fn=loss_fn,
            prev_train_batch=cached_reader.cached_batch)

        # Rebind `train_op` to the moving-average update; the control
        # dependency makes the averages update only after the optimizer step.
        with tf.control_dependencies([train_op]):
            train_op = ema.apply(train_vars)

        # Make eval ops:
        # NOTE: `num_examples` is re-bound here and not used afterwards.
        images, labels, num_examples = mnist.load_mnist_as_tensors(
            flatten_images=True)

        eval_model = Model()
        eval_model(
            images)  # We need this dummy call because for some reason the
        # variables won't exist otherwise...
        eval_vars = eval_model.variables

        # presumably `group_assign` assigns each value to the matching
        # variable and returns a single op -- verify against its definition.
        update_eval_model = group_assign(eval_vars, train_vars)

        # The control dependencies below are chained so that the four eval
        # tensors are computed in this order: copy training variables ->
        # eval loss/error -> copy averaged variables -> averaged loss/error.
        with tf.control_dependencies([update_eval_model]):
            logits = eval_model(images)
            eval_loss, eval_error = compute_loss(logits=logits,
                                                 labels=labels,
                                                 return_error=True)

            with tf.control_dependencies([eval_loss, eval_error]):
                update_eval_model_avg = group_assign(
                    eval_vars, (ema.average(t) for t in train_vars))

                with tf.control_dependencies([update_eval_model_avg]):
                    logits = eval_model(images)
                    eval_loss_avg, eval_error_avg = compute_loss(
                        logits=logits, labels=labels, return_error=True)

    return (train_op, opt, batch_loss, batch_error, batch_size_schedule,
            batch_size, eval_loss, eval_error, eval_loss_avg, eval_error_avg)
def cifar10_model_fn(features, labels, mode, params):
    """Model function for CIFAR-10.

    Builds the network returned by `get_problem(params)`, the softmax
    cross-entropy loss with L2 weight decay and, in TRAIN mode, the train op
    of the optimizer selected by `params['optimizer']`.

    Args:
      features: Tensor of input images (also written as an image summary).
      labels: One-hot label Tensor.
      mode: `tf.estimator.ModeKeys` value.
      params: dict of hyper-parameters. Always read: 'optimizer'. In TRAIN
        mode 'lr' is read as well, plus the optimizer specific keys:
        'CG_iter', 'x_use', 'y_use', 'd_use', 'damping_type', 'damping',
        'decay', 'meta_ckpt' for 'meta'; 'beta1', 'beta2' for 'adam'; 'decay'
        for 'RMSprop'; 'momentum' for 'momentum'; 'decay', 'damping',
        'damping_type', 'momentum' for 'kfac'.

    Returns:
      `tf.estimator.EstimatorSpec` for the given mode.

    Raises:
      ValueError: In TRAIN mode, if `params['optimizer']` is not one of
        'meta', 'adam', 'RMSprop', 'momentum', 'SGD' or 'kfac'.
    """
    tf.summary.image('images', features, max_outputs=6)

    inputs = features
    _network = get_problem(params)

    def network(*inputs):
        """Apply the model inside the shared 'nn' variable scope."""
        # AUTO_REUSE lets the same weights be used when the network is built
        # a second time (e.g. by the K-FAC damping loss function below).
        with tf.variable_scope('nn', reuse=tf.AUTO_REUSE):
            return _network(*inputs, mode == tf.estimator.ModeKeys.TRAIN)

    logits = network(inputs)
    if params['optimizer'] == 'kfac':
        lc = kfac.LayerCollection()
        lc.register_categorical_predictive_distribution(logits)
        lc.auto_register_layers()

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Calculate loss, which includes softmax cross entropy and L2
    # regularization.
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits,
                                                    onehot_labels=labels)

    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss.
    loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    if mode == tf.estimator.ModeKeys.TRAIN:
        # The learning rate is the constant `params['lr']`; no batch-size
        # scaling or epoch-based decay schedule is applied.
        initial_learning_rate = params['lr']
        global_step = tf.train.get_or_create_global_step()
        learning_rate = initial_learning_rate

        # Create a tensor named learning_rate for logging purposes
        tf.identity(learning_rate, name='learning_rate')
        tf.summary.scalar('learning_rate', learning_rate)

        # Ops that store the current batch in variables; only populated for
        # K-FAC with 'LM_heuristics' damping.
        cache_batch_ops = []

        if params['optimizer'] == 'meta':
            optimizer = co.MetaHessionFreeOptimizer(
                learning_rate=learning_rate,
                iter=params['CG_iter'],
                x_use=params['x_use'],
                y_use=params['y_use'],
                d_use=params['d_use'],
                damping_type=params['damping_type'],
                damping=params['damping'],
                decay=params['decay'])
        elif params['optimizer'] == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                               beta1=params['beta1'],
                                               beta2=params['beta2'])
        elif params['optimizer'] == 'RMSprop':
            optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate,
                                                  decay=params['decay'])
        elif params['optimizer'] == 'momentum':
            optimizer = tf.train.MomentumOptimizer(
                learning_rate=learning_rate, momentum=params['momentum'])
        elif params['optimizer'] == 'SGD':
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        elif params['optimizer'] == 'kfac':
            optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
                learning_rate=learning_rate,
                cov_ema_decay=params['decay'],
                damping=params['damping'],
                layer_collection=lc)
            if params['damping_type'] == 'LM_heuristics':
                # Non-trainable variables holding a copy of a training batch;
                # they are handed to the optimizer as `prev_train_batch`.
                last_inputs = tf.get_variable('last_input',
                                              initializer=tf.zeros_initializer,
                                              shape=inputs.shape,
                                              dtype=inputs.dtype,
                                              trainable=False)
                last_labels = tf.get_variable('last_label',
                                              initializer=tf.zeros_initializer,
                                              shape=labels.shape,
                                              dtype=labels.dtype,
                                              trainable=False)
                # NOTE(review): these assigns are created before `train_op`
                # exists, so they carry no control dependency on it; their
                # execution order relative to the optimizer's reads of
                # `prev_train_batch` is not fixed by this code -- confirm
                # that this matches what the K-FAC version in use expects.
                cache_batch_ops = [
                    tf.assign(last_inputs, inputs),
                    tf.assign(last_labels, labels)
                ]
                # NOTE(review): `params['momentum']` is passed as the damping
                # adaptation decay -- confirm this is intentional.
                optimizer.set_damping_adaptation_params(
                    prev_train_batch=(last_inputs, last_labels),
                    is_chief=True,
                    loss_fn=lambda x: tf.losses.softmax_cross_entropy(
                        logits=network(x[0]), onehot_labels=x[1]),
                    damping_adaptation_decay=params['momentum'],
                )
        else:
            raise ValueError(
                'Unsupported optimizer: %r' % (params['optimizer'],))

        # Batch norm requires update ops to be added as a dependency to the
        # train_op
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            if params['optimizer'] == 'meta':
                train_op = optimizer.minimize(loss_type='cross_entropy',
                                              out=logits,
                                              label=labels,
                                              input_list=[inputs],
                                              global_step=global_step,
                                              network_fn=network)
                train_hooks = [
                    co.MetaParametersLoadingHook(params['meta_ckpt'])
                ]
            else:
                train_op = optimizer.minimize(loss, global_step=global_step)
                if cache_batch_ops:
                    # The final train op is a no-op that depends on both the
                    # optimizer step and the batch-caching assigns, so running
                    # it triggers all of them.
                    with tf.control_dependencies([train_op]):
                        with tf.control_dependencies(cache_batch_ops):
                            train_op = tf.no_op()
                train_hooks = []
    else:
        train_op = None
        train_hooks = []

    accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1),
                                   predictions['classes'])
    metrics = {'accuracy': accuracy}

    # Create a tensor named train_accuracy for logging purposes
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      train_op=train_op,
                                      eval_metric_ops=metrics,
                                      training_hooks=train_hooks)