def computation_fn():
    """Build one AdamWeightDecay update step on a two-way SIMD mesh.

    A three-element variable ``w`` is fitted to a constant ``x`` by minimising
    the mean squared difference between them.  ``device_assignment`` and
    ``self`` are taken from the enclosing scope; the lowering is stored on
    ``self.lowering`` so it stays reachable after this function returns.

    Returns:
      A single TensorFlow op grouping every lowered optimizer update op.
    """
    mtf_graph = mtf.Graph()
    mtf_mesh = mtf.Mesh(mtf_graph, 'my_mesh')

    # One mesh axis named 'all' of size two.  No tensor dimension created
    # below is called 'none', so the layout rule splits nothing here.
    shape = mtf.convert_to_shape('all:2')
    devices = [''] * shape.size
    rules = mtf.convert_to_layout_rules('none:all')
    simd_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        shape, rules, devices, device_assignment)

    # Model: a trainable vector and the constant target it is pulled towards.
    hidden = mtf.Dimension('hidden', 3)
    weights = mtf.get_variable(
        mtf_mesh, 'w',
        shape=[hidden],
        initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
    target = mtf.constant(
        mtf_mesh, [0.4, 0.2, -0.5], [hidden], dtype=tf.float32)
    mse = mtf.reduce_mean(mtf.square(target - weights))

    # Gradients with respect to every trainable variable in the graph.
    trainable_outputs = [var.outputs[0]
                         for var in mtf_graph.trainable_variables]
    grads = mtf.gradients([mse], trainable_outputs)
    adam = mtf_optimize.AdamWeightDecayOptimizer(learning_rate=0.2)
    mtf_updates = adam.apply_grads(grads, mtf_graph.trainable_variables)

    # Lower the mtf graph to TensorFlow and collect the update ops.
    self.lowering = mtf.Lowering(mtf_graph, {mtf_mesh: simd_impl})
    lowered_updates = [
        self.lowering.lowered_operation(update) for update in mtf_updates
    ]
    return tf.group(lowered_updates)
def testRecomputeGrad(self):
    """Check mtf.recompute_grad against hand-computed values of x^2 + x.

    For y = x^2 + x the derivative is 2x + 1.  With x = [5, 10] and an
    upstream gradient of [2, 3], the expected values are y = [30, 110] and
    dx = [22, 63].
    """
    mtf_graph = mtf.Graph()
    mtf_mesh = mtf.Mesh(mtf_graph, "my_mesh")

    def x_squared_plus_x(x):
        return x * x + x

    tf_x = tf.constant([5, 10], dtype=tf.float32)
    tf_dy = tf.constant([2, 3], dtype=tf.float32)
    two_dim = mtf.Dimension("two", 2)
    want_y = tf.constant([30, 110], dtype=tf.float32)
    want_dx = tf.constant([22, 63], dtype=tf.float32)

    # Bring the TensorFlow inputs into the mtf graph.
    vector_shape = mtf.Shape([two_dim])
    x_in = mtf.import_fully_replicated(mtf_mesh, tf_x, shape=vector_shape)
    dy_in = mtf.import_tf_tensor(
        mtf_mesh, tf_dy, shape=mtf.Shape([two_dim]))

    y_out = mtf.recompute_grad(x_squared_plus_x, [x_in])
    [dx_out] = mtf.gradients([y_out], [x_in], [dy_in])

    # Two processors, with the "two" dimension split across them.
    impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape="processors:2", layout="two:processors", devices=["", ""])
    lowered = mtf.Lowering(mtf_graph, {mtf_mesh: impl})
    got_y = lowered.export_to_tf_tensor(y_out)
    got_dx = lowered.export_to_tf_tensor(dx_out)

    self.assertAllEqual(self.evaluate(got_y), self.evaluate(want_y))
    self.assertAllEqual(self.evaluate(got_dx), self.evaluate(want_dx))
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Builds the model through ``model_backbone``, searches for a layout with
    ``layout_optimizer.LayoutOptimizer`` and lowers the graph onto GPU
    devices.  An EstimatorSpec is returned only for TRAIN mode; for any other
    mode the function falls off the end and returns None.

    Args:
      features: passed through to ``model_backbone``.
      labels: passed to ``model_backbone`` and used for the accuracy metric.
      mode: a tf.estimator.ModeKeys value.
      params: not read by this function.

    Returns:
      A tf.estimator.EstimatorSpec in TRAIN mode, otherwise None.
    """
    step = tf.train.get_global_step()

    mtf_graph = mtf.Graph()
    mtf_mesh = mtf.Mesh(mtf_graph, "my_mesh")
    logits, loss = model_backbone(features, labels, mtf_mesh)

    # Dump every variable of the graph at debug level.
    for var in mtf_graph._all_variables:
        logger.debug("[parameter] (name,shape,dtype): ({},{},{})".format(var.name,var.shape,var.dtype.master_dtype))

    # Automatic layout search over the requested mesh shape.
    # (Alternative: mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss]).)
    shape = mtf.convert_to_shape(args_opt.mesh_shape)
    shape = mtf.convert_to_shape(shape)
    mem_estimator = memory_estimator.MemoryEstimator(
        mtf_graph, shape, [logits, loss])
    layout_search = layout_optimizer.LayoutOptimizer(
        mem_estimator, scheduler_alg="NAIVE")
    rules = mtf.convert_to_layout_rules(layout_search.solve())
    # A hand-written alternative would be: rules = [('batch', 'b1')]
    logger.info("[auto mtf search] strategy: {}".format(rules))

    gpu_count = int(args_opt.num_gpus)
    devices = ["gpu:{}".format(i) for i in range(gpu_count)]
    impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape, rules, devices)

    training = mode == tf.estimator.ModeKeys.TRAIN
    if training:
        trainables = mtf_graph.trainable_variables
        grads = mtf.gradients([loss], [var.outputs[0] for var in trainables])
        sgd = mtf.optimize.SgdOptimizer(0.01)
        # Mixed precision graph rewrite is not enabled:
        # sgd = tf.train.experimental.enable_mixed_precision_graph_rewrite(sgd)
        mtf_updates = sgd.apply_grads(grads, mtf_graph.trainable_variables)

    lowered = mtf.Lowering(mtf_graph, {mtf_mesh: impl})
    restore_hook = mtf.MtfRestoreHook(lowered)

    tf_logits = lowered.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowered.export_to_tf_tensor(loss)

    if training:
        lowered_updates = [lowered.lowered_operation(u) for u in mtf_updates]
        lowered_updates.append(tf.assign_add(step, 1))
        train_op = tf.group(lowered_updates)

        predicted_class = tf.argmax(tf_logits, axis=1)
        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=predicted_class)

        # Give the logged tensors the names the LoggingTensorHook looks up.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")
        logging_hook = tf.train.LoggingTensorHook(
            every_n_iter=100,
            tensors={'loss': 'cross_entropy', 'acc': 'train_accuracy'})
        # A ProfilerHook could be added here:
        # tf.estimator.ProfilerHook(save_steps=20, output_dir='./profiling/')

        # restore_hook is listed ahead of logging_hook.
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, logging_hook])
def test_entmax():
    """Build and lower an entmax graph with its gradient and a sample.

    The input is ``mtf.range`` over an 8-element dimension.  The function only
    constructs the graph and exports the two results; it makes no assertions.
    """
    mtf_graph = mtf.Graph()
    mtf_mesh = mtf.Mesh(mtf_graph, "my_mesh")

    length_dim = mtf.Dimension("tensor_length", 8)
    values = mtf.range(mtf_mesh, length_dim, tf.float32)
    probabilities = entmax(values)
    mtf_grad = mtf.gradients([probabilities], [values])[0]
    mtf_sample = sample_categorical(probabilities, length_dim)

    # Single-device mesh: empty shape, empty layout.
    impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowered = mtf.Lowering(mtf_graph, {mtf_mesh: impl})
    sample = lowered.export_to_tf_tensor(mtf_sample)
    grad = lowered.export_to_tf_tensor(mtf_grad)
def testTopK(self):
    """Check mtf.top_k values, indices and gradients on a 6x2 input.

    The top two entries are taken along dimension "a" for each column of
    dimension "b", and a [b, k] upstream gradient is pushed back to the input.
    """
    mtf_graph = mtf.Graph()
    mtf_mesh = mtf.Mesh(mtf_graph, "my_mesh")

    rows = mtf.Dimension("a", 6)
    cols = mtf.Dimension("b", 2)
    tf_inputs = tf.constant(
        [[1, 10], [2, 9], [3, 8], [4, 7], [5, 6], [6, 5]], dtype=tf.float32)
    top = mtf.Dimension("k", 2)
    tf_d_values = tf.constant([[11, 12], [13, 14]], dtype=tf.float32)

    # Hand-computed expectations.
    want_values = tf.constant([[6, 5], [10, 9]], dtype=tf.float32)
    want_indices = tf.constant([[5, 4], [0, 1]])
    want_d_inputs = tf.constant(
        [[0, 13], [0, 14], [0, 0], [0, 0], [12, 0], [11, 0]],
        dtype=tf.float32)

    x = mtf.import_fully_replicated(
        mtf_mesh, tf_inputs, shape=mtf.Shape([rows, cols]))
    d_values = mtf.import_tf_tensor(
        mtf_mesh, tf_d_values, shape=mtf.Shape([cols, top]))
    values, indices = mtf.top_k(
        x, reduced_dim=rows, k_dim=top, name="test_nth_smallest")
    [d_x] = mtf.gradients([values], [x], [d_values])

    # 2x2 mesh with "a" split over rows and "b" split over cols.
    impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape="rows:2,cols:2",
        layout="a:rows,b:cols",
        devices=["", "", "", ""])
    lowered = mtf.Lowering(mtf_graph, {mtf_mesh: impl})
    got_values = lowered.export_to_tf_tensor(values)
    got_indices = lowered.export_to_tf_tensor(indices)
    got_d_inputs = lowered.export_to_tf_tensor(d_x)
    got_inputs = lowered.export_to_tf_tensor(x)

    self.assertAllEqual(self.evaluate(got_inputs),
                        self.evaluate(tf_inputs))
    self.assertAllEqual(self.evaluate(got_values),
                        self.evaluate(want_values))
    self.assertAllEqual(self.evaluate(got_indices),
                        self.evaluate(want_indices))
    self.assertAllEqual(self.evaluate(got_d_inputs),
                        self.evaluate(want_d_inputs))
def my_model_fn(features, labels, mode, params=None, config=None):
    """Estimator model function.

    Several names used below (``use_tpu``, ``mesh_shape``, ``layout_rules``,
    ``transformer_model``, ``model_type``, ``batch_size``,
    ``outer_batch_size``, ``sequence_length``, ``autostack``,
    ``metric_names``, ``learning_rate``, ``model_dir``, ``save_steps``,
    ``checkpoints_to_keep``, ``get_variable_dtype``) are not defined in this
    function and are read from the enclosing scope.

    Args:
      features: input features dictionary
      labels: ignored
      mode: a tf.estimator.ModeKeys
      params: dictionary; ``params["context"]`` is read when running on TPU
      config: ignored

    Returns:
      A TPUEstimatorSpec for PREDICT and EVAL, and for TRAIN on TPU; a
      tf.estimator.EstimatorSpec for TRAIN off TPU.  Any other mode falls
      through and returns None.

    Raises:
      ValueError: if a required feature is missing or ``transformer_model``
        is neither a Unitransformer nor a Bitransformer.
    """
    del labels, config
    global_step = tf.train.get_global_step()
    if use_tpu:
        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [
            host_placement_fn(host_id=t) for t in range(num_hosts)
        ]
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(
            device_list, devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        physical_shape = list(
            params["context"].device_assignment.topology.mesh_shape)
        logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment,
            logical_to_physical=logical_to_physical)
    else:
        var_placer = None
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    def _import_feature(key, allow_missing=False):
        """Import a feature from the features dictionary into a mtf.Tensor.

        Args:
          key: a string
          allow_missing: a boolean

        Returns:
          a mtf.Tensor with dtype int32 and shape
          [outer_batch_dim, batch_dim, length_dim], or None when the key is
          absent and allow_missing is true.

        Raises:
          ValueError: if the key is absent and allow_missing is false.
        """
        outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
        batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
        length_dim = mtf.Dimension("length", sequence_length)
        mtf_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])
        if key not in features:
            if allow_missing:
                return None
            raise ValueError(
                "feature not found %s - features = %s" % (key, features))
        tf.logging.info("Import feature %s: %s" % (key, features[key]))

        x = tf.to_int32(features[key])
        x = tf.reshape(
            x, [outer_batch_size, batch_size // outer_batch_size, -1])
        if not use_tpu:
            x = tf.Print(
                x, [x], "import feature %s" % key, summarize=1000, first_n=1)
        return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

    if mode == tf.estimator.ModeKeys.PREDICT:
        inputs = _import_feature("inputs")
        # Merge outer_batch and batch back into one batch dimension.
        inputs = mtf.reshape(
            inputs,
            mtf.Shape([
                mtf.Dimension("batch", batch_size),
                mtf.Dimension("length", sequence_length)
            ]))
        if isinstance(transformer_model, transformer.Unitransformer):
            mtf_samples = transformer_model.sample_autoregressive(
                inputs, variable_dtype=get_variable_dtype())
        elif isinstance(transformer_model, transformer.Bitransformer):
            mtf_samples = transformer_model.decode(
                inputs, variable_dtype=get_variable_dtype())
        else:
            raise ValueError("unrecognized class")
        mtf_samples = mtf.anonymize(mtf_samples)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        predictions = {"outputs": outputs}
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    targets = _import_feature("targets")
    anon_targets = mtf.anonymize(targets)
    if model_type == "lm":
        # _import_feature returns a three-dimensional tensor, so a two-name
        # unpack of its shape would raise; take the last dimension directly.
        length_dim = targets.shape.dims[-1]
        inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
    else:
        inputs = _import_feature("inputs")

    if mode == tf.estimator.ModeKeys.EVAL:
        if isinstance(transformer_model, transformer.Unitransformer):
            mtf_samples = transformer_model.sample_autoregressive(
                inputs, variable_dtype=get_variable_dtype())
        elif isinstance(transformer_model, transformer.Bitransformer):
            mtf_samples = transformer_model.decode(
                inputs, variable_dtype=get_variable_dtype())
        else:
            raise ValueError("unrecognized class")
        mtf_samples = mtf.anonymize(mtf_samples)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
        outputs = lowering.export_to_tf_tensor(mtf_samples)
        labels = lowering.export_to_tf_tensor(anon_targets)
        restore_hook = mtf.MtfRestoreHook(lowering)

        # metric_names becomes locally scoped if we simply assign
        # ["padded_neg_log_perplexity"] to it conditioned on if it's None.
        local_metric_names = metric_names or ["token_accuracy"]

        def metric_fn(labels, outputs):
            return get_metric_fns(local_metric_names, labels, outputs)

        eval_metrics = (metric_fn, [labels, outputs])
        return tpu_estimator.TPUEstimatorSpec(
            tf.estimator.ModeKeys.EVAL,
            # Unfortunately TPUEstimatorSpec requires us to provide a value
            # for loss when in EVAL mode. Since we are sampling or decoding
            # from the model, we don't have a loss to report.
            loss=tf.constant(0.),
            evaluation_hooks=[restore_hook],
            eval_metrics=eval_metrics)

    # Optional segmentation/position features for packed examples.
    if isinstance(transformer_model, transformer.Unitransformer):
        position_kwargs = dict(
            sequence_id=_import_feature("targets_segmentation", True),
            position=_import_feature("targets_position", True),
        )
    elif isinstance(transformer_model, transformer.Bitransformer):
        position_kwargs = dict(
            encoder_sequence_id=_import_feature("inputs_segmentation", True),
            decoder_sequence_id=_import_feature("targets_segmentation", True),
            encoder_position=_import_feature("inputs_position", True),
            decoder_position=_import_feature("targets_position", True),
        )
    else:
        raise ValueError("unrecognized class")

    logits, loss = transformer_model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=mode,
        variable_dtype=get_variable_dtype(),
        **position_kwargs)

    if use_tpu and logits is not None:
        logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=learning_rate)
        update_ops = optimizer.apply_grads(
            var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if not use_tpu:
        tf_loss = tf.Print(
            tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss")

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [
            lowering.lowered_operation(op) for op in update_ops
        ]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=checkpoints_to_keep,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            model_dir,
            save_steps=save_steps,
            saver=saver,
            listeners=[saver_listener])
        gin_config_saver_hook = gin.tf.GinConfigSaverHook(
            model_dir, summarize_config=True)

        if mode == tf.estimator.ModeKeys.TRAIN:
            if use_tpu:
                return tpu_estimator.TPUEstimatorSpec(
                    mode=tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[
                        restore_hook,
                        saver_hook,
                        gin_config_saver_hook,
                    ])
            else:
                return tf.estimator.EstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_chief_hooks=[
                        restore_hook,
                        saver_hook,
                        gin_config_saver_hook,
                    ])
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Builds the reconstruction graph through ``recon_model``, lowers it onto a
    ``FLAGS.nx`` x ``FLAGS.ny`` placement mesh and returns the EstimatorSpec
    matching ``mode``.

    Args:
      features: dictionary; the keys 'data', 'R0' and 'x0' are always read,
        and 'lr' is read in TRAIN mode as the optimizer learning rate.
      labels: not read by this function.
      mode: a tf.estimator.ModeKeys value.
      params: not read by this function.

    Returns:
      A tf.estimator.EstimatorSpec for PREDICT, TRAIN or EVAL; None for any
      other mode.
    """
    # tf.logging.info("features = %s labels = %s mode = %s params=%s" %
    #                 (features, labels, mode, params))

    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    data = features['data']
    R0 = features['R0'] * 1.
    x0 = features['x0']
    print("\nR0 in the model function : %0.1f\n" % R0)
    fields, metrics, kv = recon_model(mesh, data, R0, x0)
    fieldvar, final_field = fields
    chisq, prior, loss = metrics

    # Mesh layout: every "x-like" dimension goes on "row", every "y-like"
    # dimension on "col".
    layout_rules = [("nx_lr", "row"), ("ny_lr", "col"),
                    ("nx", "row"), ("ny", "col"),
                    ("ty", "row"), ("tz", "col"),
                    ("ty_lr", "row"), ("tz_lr", "col"),
                    ("nx_block", "row"), ("ny_block", "col")]
    mesh_shape = [("row", FLAGS.nx), ("col", FLAGS.ny)]
    mesh_size = FLAGS.nx * FLAGS.ny  # mesh_shape.size
    mesh_devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    # Alternative setup driven by FLAGS.mesh_shape and explicit GPU devices:
    # mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    # mesh_size = mesh_shape.size
    # layout_rules = [("nx_lr", "row"), ("ny_lr", "col"),
    #                 ("nx", "row"), ("ny", "col"),
    #                 ("ty", "row"), ("tz", "col"),
    #                 ("ty_lr", "row"), ("tz_lr", "col"),
    #                 ("nx_block","row"), ("ny_block","col")]
    # devices = ["/job:localhost/replica:0/task:%d/device:GPU:%d"%(i,j) for i in range(0) for j in range(FLAGS.gpus_per_task)]
    # devices = [""] * mesh_size
    # mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(mesh_shape, layout_rules, devices)
    print(mesh_impl)

    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])

        # Alternative update rule: high-pass filter the gradient in Fourier
        # space and apply a plain gradient step.
        # nyq = np.pi*nc/bs
        # def _cwise_highpass(kfield, kx, ky, kz):
        #     kx = tf.reshape(kx, [-1, 1, 1])
        #     ky = tf.reshape(ky, [1, -1, 1])
        #     kz = tf.reshape(kz, [1, 1, -1])
        #     kk = (kx / bs * nc)**2 + (ky/ bs * nc)**2 + (kz/ bs * nc)**2
        #     wts = tf.cast(tf.exp(- kk* (R0*bs/nc + 1/nyq)**2), kfield.dtype)
        #     return kfield * (1-wts)
        #
        # k_dims_pr = [d.shape[0] for d in kv]
        # k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
        # cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=tf.complex64)
        # cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=tf.complex64)
        # var_grads = [mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=tf.float32)]
        # update_ops = [mtf.assign(fieldvar, fieldvar - var_grads[0]*0.2)]

        # optimizer = mtf.optimize.AdafactorOptimizer(1)
        # optimizer = mtf.optimize.SgdOptimizer(0.01)
        # optimizer = mtf.optimize.MomentumOptimizer(0.01, 0.001)
        optimizer = mtf.optimize.AdamWeightDecayOptimizer(features['lr'])
        update_ops = optimizer.apply_grads(
            var_grads, graph.trainable_variables)

    # Lower to TensorFlow, timing how long the lowering takes.
    start = time.time()
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    print("\nTime taken for lowering is : %0.3f" % (time.time() - start))
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_init = lowering.export_to_tf_tensor(fieldvar)
    tf_data = lowering.export_to_tf_tensor(final_field)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_chisq = lowering.export_to_tf_tensor(chisq)
    tf_prior = lowering.export_to_tf_tensor(prior)

    # Predict
    if mode == tf.estimator.ModeKeys.PREDICT:
        # tf_loss was already exported above; it is not exported again here.
        tf.summary.scalar("loss", tf_loss)
        tf.summary.scalar("chisq", tf_chisq)
        tf.summary.scalar("prior", tf_prior)
        predictions = {
            "ic": tf_init,
            "data": tf_data,
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                # TODO: is classify a keyword?
                "data": tf.estimator.export.PredictOutput(predictions)
            })

    # Train
    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=1,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            fpath, save_steps=1000, saver=saver, listeners=[saver_listener])

        logging_hook = tf.train.LoggingTensorHook(
            {
                "loss": tf_loss,
                "chisq": tf_chisq,
                "prior": tf_prior
            },
            every_n_iter=10)

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "loss")
        tf.identity(tf_prior, "prior")
        tf.identity(tf_chisq, "chisq")

        # Save the scalars to Tensorboard output.
        tf.summary.scalar("loss", tf_loss)
        tf.summary.scalar("chisq", tf_chisq)
        tf.summary.scalar("prior", tf_prior)

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook, logging_hook])

    # Eval
    if mode == tf.estimator.ModeKeys.EVAL:
        # eval_metric_ops values must be (value, update_op) pairs, so each
        # exported scalar is wrapped in a streaming mean instead of being
        # passed as a bare tensor.
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "loss": tf.metrics.mean(tf_loss),
                "chisq": tf.metrics.mean(tf_chisq),
            })
def __call__(self, features, labels, mode, params):  # this is the model_fn
    """A model is called by TpuEstimator.

    Args:
      features: input tensor handed to ``self.model`` as ``x``.
      labels: target tensor handed to ``self.model`` as ``y``.
      mode: a tf.estimator.ModeKeys value.
      params: dictionary; "use_tpu" is read here, "context" is read when
        "use_tpu" is true, and the whole dictionary is forwarded to
        ``self.model``.

    Returns:
      A tpu_estimator.TPUEstimatorSpec for TRAIN, EVAL or PREDICT; None for
      any other mode.
    """
    global_step = tf.train.get_global_step()

    # Graph setup
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(self.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(self.layout)
    if params["use_tpu"]:
        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [
            host_placement_fn(host_id=t) for t in range(num_hosts)
        ]
        # Worker 0 caches all the TPU binaries.
        replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
        worker0_mem = replica_cache_size * 8 * num_hosts
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(
            device_list, devices_memory_usage)
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)
        mesh_devices = [""] * mesh_shape.size
        # The fourth argument is the TPU device assignment, not the
        # per-host memory-usage list used by the variable placer.
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
        var_placer = None
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # RUN Model
    with mtf.utils.outside_all_rewrites():
        # model() takes (mesh, x, y, params).
        # NOTE(review): labels is assumed to be the intended target y --
        # confirm against the input_fn.
        logits, loss = self.model(mesh, features, labels, params)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        if self.optimizer == "Adafactor":
            optimizer = mtf.optimize.AdafactorOptimizer()
        else:
            assert self.optimizer == "SGD"
            optimizer = mtf.optimize.SgdOptimizer(
                learning_rate=self.learning_rate)
        update_ops = optimizer.apply_grads(
            var_grads, graph.trainable_variables)
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    # convert back to tensorflow format
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [
            lowering.lowered_operation(op) for op in update_ops
        ]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    # create estimator
    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True,
            )
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                self.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener],
            )
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook],
            )
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                mean_logits = tf.metrics.mean(tf_logits)
                return {"mean_logits": mean_logits}

            eval_metrics = (metric_fn, [tf_logits])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics,
            )
        elif mode == tf.estimator.ModeKeys.PREDICT:
            # eval_metrics is only defined in the EVAL branch; a PREDICT spec
            # carries predictions and prediction hooks instead.
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.PREDICT,
                predictions={"logits": tf_logits},
                prediction_hooks=[restore_hook],
            )


@property
def dense_initializer(self):
    """Initializer for dense weights.

    Returns a truncated normal with stddev ``config.initializer_range`` when
    that value is truthy, otherwise a variance-scaling initializer with
    scale 0.4.
    """
    if self.config.initializer_range:
        return tf.truncated_normal_initializer(
            stddev=self.config.initializer_range)
    else:
        return mtf.layers.VarianceScalingInitializer(scale=0.4)


@property
def embedding_initializer(self):
    """Initializer for embedding tables, derived from dense_initializer."""
    initializer = self.dense_initializer
    if isinstance(initializer, mtf.layers.DenseInitializer):
        # embedding matrix is also used as classifier weight matrix.
        # scale it appropriately.
        # NOTE(review): self.model_dim and self.vocab_dim are not set
        # anywhere in the visible code -- confirm they exist on the class.
        return initializer(reduced_dims=[self.model_dim],
                           new_dims=[self.vocab_dim])
    else:
        return initializer


@property
def num_hidden_layers(self):
    """Number of hidden layers, read from the config."""
    return self.config.num_hidden_layers


def normalize(self, x, reduce_dim):
    """Layer-normalize ``x`` over ``reduce_dim`` via ``nn.layer_norm``.

    Both ``subtract_mean`` and ``use_bias`` follow ``config.use_bias``.
    """
    return nn.layer_norm(
        x,
        reduce_dim,
        subtract_mean=self.config.use_bias,
        use_bias=self.config.use_bias,
    )


def model(self, mesh, x, y, params):
    """Build the toy embedding + dense-stack model and its squared-error loss.

    Args:
      mesh: the mtf.Mesh to build on.
      x: TensorFlow tensor imported with shape [batch, sequence, vocab].
      y: target that the prediction is subtracted from in the loss.
      params: dictionary; reads "precision", "batch_size", "vocab_size",
        "io" and "io_channels".

    Returns:
      A (prediction, loss) pair of mtf tensors.
    """
    # x :: [batch, io, vocab]
    if params["precision"] == "bfloat16":
        # master has type float32, slice and activation have type bfloat16
        variable_dtype = mtf.VariableDType(
            tf.float32, tf.bfloat16, tf.bfloat16)
    else:
        # master, slice and activation are all float32
        variable_dtype = mtf.VariableDType(
            tf.float32, tf.float32, tf.float32)

    # Build the actual model
    batch_dim = mtf.Dimension("batch", params["batch_size"])
    vocab_dim = mtf.Dimension("vocab", params["vocab_size"])
    io_dim = mtf.Dimension("sequence", params["io"])
    io_chan_dim = mtf.Dimension("io", params["io_channels"])

    # from input to mtf
    x = mtf.import_tf_tensor(
        mesh, x, mtf.Shape([batch_dim, io_dim, vocab_dim]))

    # Embeddings
    # tf.variable_scope takes the scope name as its first positional
    # argument; it has no ``scope`` keyword.
    with tf.variable_scope("toy", default_name="seq2seq"):
        with tf.variable_scope("embeddings"):
            # Perform embedding lookup on the word ids.
            embedding_table = mtf.get_variable(
                mesh,
                "word_embeddings",
                mtf.Shape([vocab_dim, io_chan_dim]),
                initializer=self.embedding_initializer,
            )
            word_embedding_output = mtf.gather(
                embedding_table, x, dim=vocab_dim, output_shape=io_chan_dim)

            # Add positional embeddings and token type embeddings, then
            # layer normalize and perform dropout.
            embedding_output = word_embedding_output
            pos_embedding = mtf.get_variable(
                mesh,
                "pos_embeddings",
                mtf.Shape([io_dim, io_chan_dim]),
                initializer=self.embedding_initializer,
            )
            # normalize() requires the dimension to reduce over; the channel
            # dimension is used.
            embedding_output = self.normalize(
                embedding_output, reduce_dim=io_chan_dim)
            embedding_output = mtf.dropout(
                embedding_output,
                keep_prob=1.0 - self.config.layer_output_dropout_prob,
            )

        # shift token by pos embeddings
        # NOTE(review): the normalized/dropped-out embedding_output is not
        # used below; the raw word embeddings are.  Confirm this is intended.
        x = word_embedding_output + pos_embedding
        x = mtf.cast(x, variable_dtype.activation_dtype)

        h = x
        for lnum in range(1, self.num_hidden_layers + 2):
            if lnum + 1 == self.num_hidden_layers + 2:
                # output layer
                dim = io_dim
            elif lnum % 2 == 0:
                # Dimension sizes are integers, so use the channel size
                # rather than the Dimension object itself.
                dim = mtf.Dimension("hidden_even", io_chan_dim.size)
            else:
                dim = mtf.Dimension("hidden_odd", io_chan_dim.size)
            h = mtf.layers.dense(
                h,
                dim,
                use_bias=False,
                master_dtype=variable_dtype.master_dtype,
                slice_dtype=variable_dtype.slice_dtype,
                name="layer_%d" % lnum,
            )

        prediction = h
        # project back to token dimensions

        # compute the mean square loss between the input and the output
        # NOTE(review): y is used as given; confirm it is something that can
        # be combined with an mtf tensor of prediction's shape.
        loss = mtf.reduce_mean(mtf.square(y - prediction))
        return prediction, loss
def _model_fn(input_fea, input_lab):
    """Creates a model, add summary, modes (train or eval), and hooks.

    ``train_or_eval``, ``mesh_context``, ``captured_hooks`` and
    ``captured_output_dtypes_shapes`` are not defined here and come from the
    enclosing scope.

    Args:
      input_fea: features, either a list or a single item that is wrapped
        into a one-element list.
      input_lab: labels, handled the same way as input_fea.

    Returns:
      In 'train': one op grouping the loss with every update op.
      In 'eval' on TPU: the outfeed enqueue op for the predictions, loss and
      global step.  In 'eval' off TPU: that list of tensors itself.
    """
    # input_fea and input_lab should be a list (laid_out_tensors).
    if not isinstance(input_fea, list):
        input_fea = [input_fea]
    if not isinstance(input_lab, list):
        input_lab = [input_lab]

    def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
        """Add all summaries and return the loss tensor to use afterwards."""
        # Export any scalar that is not already a tf.Tensor.
        for k in scalars.keys():
            if not isinstance(scalars[k], tf.Tensor):
                scalars[k] = tf.cast(
                    lowering.export_to_tf_tensor(scalars[k]), tf.float32)

        def _host_loss_summary(global_step, tf_loss, **scalars):
            """Add summary.scalar in host side."""
            gs = tf.cast(global_step, tf.int64)
            sum_loss = contrib_summary.scalar(
                '{}_loss'.format(train_or_eval), tf_loss, step=gs)
            sum_ops = [sum_loss.op]
            # dict.items() works on both Python 2 and 3; iteritems() exists
            # only on Python 2 dicts.
            for description, tf_metric in scalars.items():
                sum_metric = contrib_summary.scalar(
                    '{}_{}'.format(train_or_eval, description),
                    tf_metric,
                    step=gs)
                sum_ops.append(sum_metric)
            with tf.control_dependencies(sum_ops):
                return tf.identity(tf_loss)

        if FLAGS.use_tpu:
            # Cast the global step to tf.int32, since
            # outside_compilation does not support tf.int64.
            tf_loss = tpu.outside_compilation(
                _host_loss_summary,
                tf.cast(global_step, tf.int32),
                tf_loss,
                **scalars)
        else:
            tf_loss = _host_loss_summary(
                tf.cast(global_step, tf.int32),
                tf_loss,
                **scalars)

        return tf_loss

    global_step = tf.train.get_or_create_global_step()
    graph, mesh, mesh_impl = mesh_context.create_graph_mesh_and_mesh_impl()

    with mtf.utils.outside_all_rewrites():
        # Do not tpu_rewrite this part. Inside this unet, if you use
        # TensorFlow instead of Mesh-TensorFlow, it will cause host to tpu
        # send/recv.
        preds, loss, scalars, bn_update_ops = (
            unet.unet_with_spatial_partition(
                mesh, mesh_impl, train_or_eval, input_fea, input_lab))

    if train_or_eval == 'train':
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])

        # Step-wise learning rate decay.
        lr = FLAGS.lr * tf.pow(
            FLAGS.lr_drop_rate,
            tf.floor(tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
        scalars['learning_rate'] = lr

        optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
        update_ops = optimizer.apply_grads(
            var_grads, graph.trainable_variables)

        # This is where the actual tf graph got built.
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf_update_ops.extend(
            [lowering.lowered_operation(op) for op in bn_update_ops])

    else:  # train_or_eval == 'eval':
        preds = [mtf.anonymize(pred) for pred in preds]

        # This is where the actual tf graph got built.
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_preds = [
            tf.cast(lowering.export_to_tf_tensor(pred), tf.float32)
            for pred in preds
        ]

    tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
    if FLAGS.write_summary:
        tf_loss = _add_summary(
            lowering, train_or_eval, tf_loss, scalars, global_step)
    master_to_slice_hook = mtf.MtfRestoreHook(lowering)

    if train_or_eval == 'train':
        with mtf.utils.outside_all_rewrites():
            saver = tf.train.Saver(
                tf.global_variables(), save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            slice_to_master_hook = tf.train.CheckpointSaverHook(
                FLAGS.checkpoint_dir,
                save_steps=FLAGS.save_checkpoints_steps,
                saver=saver,
                listeners=[saver_listener])
            captured_hooks.capture(
                [master_to_slice_hook, slice_to_master_hook])
            return tf.group([tf_loss] + tf_update_ops)

    else:  # train_or_eval == 'eval':
        if FLAGS.use_tpu:
            tf_preds.extend([tf_loss, global_step])
            tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
            tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
            captured_hooks.capture([master_to_slice_hook, None])
            captured_output_dtypes_shapes.capture(
                [tf_preds_dtypes, tf_preds_shapes])
            return tpu_ops.outfeed_enqueue_tuple(tf_preds)

        else:
            tf_preds.extend([tf_loss, global_step])
            captured_hooks.capture([master_to_slice_hook, None])
            return tf_preds
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
    """Estimator model function built on a Mesh-TensorFlow model class.

    Args:
      cls: the model class; instantiated with the (copied) hparams.
      hparams: hyperparameters; deep-copied before being modified.
      features: input features, forwarded to the model.
      labels: forwarded to ``estimator_spec_eval`` in EVAL mode.
      mode: a tf.estimator.ModeKeys value.
      config: optional run config; ``config.data_parallelism`` is read when
        not on TPU.
      params: dictionary; ``params["context"]`` is read when ``use_tpu``.
      decode_hparams: optional decode hyperparameters merged into hparams in
        PREDICT mode.
      use_tpu: whether to build a SIMD mesh for TPU.

    Returns:
      The spec produced by the model for PREDICT and EVAL; for TRAIN a
      TPUEstimatorSpec on TPU or a tf.estimator.EstimatorSpec otherwise.
    """
    hparams = copy.deepcopy(hparams)
    hparams.use_tpu = use_tpu
    # merge decode_hparams into hparams if present
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
        for k, v in six.iteritems(decode_hparams.values()):
            if hasattr(hparams, k) and getattr(hparams, k) != v:
                tf.logging.warning(
                    "Overriding hparams.%s with %s from decode_hparams"
                    % (k, v))
            setattr(hparams, k, v)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
        data_parallelism = config.data_parallelism
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)
    if use_tpu:
        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(
            device_list, devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
        var_placer = None
        if data_parallelism is None or len(data_parallelism.ps_devices) == 1:
            mesh_devices = [""] * mesh_shape.size
        else:
            assert len(data_parallelism.ps_devices) == mesh_shape.size
            mesh_devices = data_parallelism.ps_devices
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
        return model.estimator_spec_predict(
            features, mesh, mesh_impl, use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
        logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        lr = learning_rate.learning_rate_schedule(hparams)
        tf.summary.scalar("learning_rate", lr)
        mtf_lr = mtf.import_tf_tensor(
            mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
        optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
        update_ops = []
        for grad, var in zip(var_grads, graph.trainable_variables):
            update_ops.extend(optimizer.apply_grad(grad, var))

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        # tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            hparams.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
        # Logits are only consumed here, so they are exported once, in this
        # branch, rather than a second time before the hooks are built.
        tf_logits = lowering.export_to_tf_tensor(logits)
        return model.estimator_spec_eval(
            features, tf_logits, labels, tf_loss, restore_hook, use_tpu)

    if use_tpu:
        # TPU host call. Important: need to be called before
        # remove_summaries()
        if hparams.tpu_enable_host_call:
            host_call = t2t_model.create_host_call(hparams.model_dir)
        else:
            host_call = None

        t2t_model.remove_summaries()
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            host_call=host_call,
            training_hooks=[restore_hook, saver_hook])
    else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function for a Bitransformer configured through FLAGS.

  Builds a Mesh TensorFlow graph containing a `transformer.Bitransformer`,
  lowers it to TensorFlow, and returns the spec for the requested mode.

  Args:
    features: input features dictionary. It is handed to `import_feature` for
      the keys "inputs", "targets", "inputs_segmentation",
      "targets_segmentation", "inputs_position" and "targets_position".
    labels: ignored (deleted on entry; targets are read from `features`).
    mode: a tf.estimator.ModeKeys
    params: only read when FLAGS.tpu is set, in which case params["context"]
      must provide num_hosts, num_replicas, tpu_host_placement_function and
      device_assignment.
    config: ignored (deleted on entry).

  Returns:
    PREDICT: a TPUEstimatorSpec whose predictions are {"outputs": ...}.
    TRAIN: a TPUEstimatorSpec when FLAGS.tpu is set, otherwise a
      tf.estimator.EstimatorSpec.
    EVAL: a TPUEstimatorSpec reporting "neg_log_perplexity".
    Any other mode falls off the end of the function and yields None.
  """
  del labels, config
  use_tpu = FLAGS.tpu
  global_step = tf.train.get_global_step()
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  # Choose the mesh implementation: SIMD on TPU, placement-based otherwise.
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    # Only host 0 starts with a non-zero usage figure; the remaining hosts
    # start at 0.
    devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                  devices_memeory_usage)
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  # Encoder and decoder use the same d_model; vocabulary sizes come from the
  # dataset named by FLAGS.dataset, passed through padded_vocab_size.
  model = transformer.Bitransformer(
      encoder_layer_stack=layer_stack(include_encdec_attention=False),
      decoder_layer_stack=layer_stack(include_encdec_attention=True),
      encoder_d_model=FLAGS.d_model,
      decoder_d_model=FLAGS.d_model,
      input_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.inputs_vocab_size(FLAGS.dataset)),
      output_vocab_size=transformer_dataset.padded_vocab_size(
          transformer_dataset.targets_vocab_size(FLAGS.dataset)),
      max_length=FLAGS.max_length,
      shared_embedding=False,
      shared_embedding_and_softmax_weights=True,
      label_smoothing=FLAGS.label_smoothing,
      layout=FLAGS.layout,
      mesh_shape=FLAGS.mesh_shape)

  inputs = import_feature(features, mesh, "inputs")

  # Data-types used for variables and activations
  # See comments in the FLAGS
  master_dtype = tf.as_dtype(FLAGS.master_dtype)
  # An explicit flag wins. Otherwise slices are float32 off-TPU or when
  # FLAGS.mode is "train", and bfloat16 in the remaining case.
  if FLAGS.slice_dtype:
    slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
  elif not FLAGS.tpu or FLAGS.mode == "train":
    slice_dtype = tf.float32
  else:
    slice_dtype = tf.bfloat16
  # An explicit flag wins. Otherwise bfloat16 on TPU and float32 elsewhere.
  if FLAGS.activation_dtype:
    activation_dtype = tf.as_dtype(FLAGS.activation_dtype)
  else:
    activation_dtype = tf.bfloat16 if FLAGS.tpu else tf.float32
  variable_dtype = mtf.VariableDType(master_dtype=master_dtype,
                                     slice_dtype=slice_dtype,
                                     activation_dtype=activation_dtype)

  # PREDICT mode
  # Returns before "targets" is read from `features`.
  if mode == tf.estimator.ModeKeys.PREDICT:
    mtf_samples = model.decode(
        inputs,
        variable_dtype=variable_dtype,
        beam_size=FLAGS.beam_size,
        alpha=FLAGS.alpha,
        temperature=FLAGS.temperature)
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(
        graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {
        "outputs": outputs
    }
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  targets = import_feature(features, mesh, "targets")
  # Anonymized copy of the targets; exported as the labels in EVAL mode.
  anon_targets = mtf.anonymize(targets)
  logits, loss = model.call_simple(
      inputs=inputs,
      targets=targets,
      compute_loss=True,
      mode=mode,
      variable_dtype=variable_dtype,
      encoder_sequence_id=import_feature(features, mesh, "inputs_segmentation"),
      decoder_sequence_id=import_feature(
          features, mesh, "targets_segmentation"),
      encoder_position=import_feature(features, mesh, "inputs_position"),
      decoder_position=import_feature(features, mesh, "targets_position")
  )

  if use_tpu and logits is not None:
    logits = mtf.anonymize(logits)

  # TRAIN mode
  # Gradients and optimizer updates are added to the mtf graph before it is
  # lowered below.
  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=FLAGS.autostack)

  tf_loss = lowering.export_to_tf_tensor(loss)
  tf_loss = tf.to_float(tf_loss)
  if not use_tpu:
    # Off TPU, print the loss and the global step whenever the loss is
    # evaluated.
    tf_loss = tf.Print(
        tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss")
  # tf_logits is bound only here (non-TRAIN mode and truthy logits); the EVAL
  # branch below reads it.
  if logits and mode != tf.estimator.ModeKeys.TRAIN:
    tf_logits = lowering.export_to_tf_tensor(logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    # The train op also advances the global step by one.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

  with mtf.utils.outside_all_rewrites():
    # Copy master variables to slices. Must be called first.
    restore_hook = mtf.MtfRestoreHook(lowering)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False,
        save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    # Checkpoints go to FLAGS.model_dir every 1000 steps.
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

  if mode == tf.estimator.ModeKeys.TRAIN:
    if use_tpu:
      return tpu_estimator.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_hooks=[restore_hook, saver_hook])
    else:
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.TRAIN,
          loss=tf_loss,
          train_op=train_op,
          training_chief_hooks=[restore_hook, saver_hook])
  elif mode == tf.estimator.ModeKeys.EVAL:
    def padded_neg_log_perplexity(logits, labels):
      """Returns the mean negative cross-entropy, skipping labels equal to 0."""
      # Positions whose label is 0 get weight 0 and so do not count.
      weights = tf.to_float(tf.not_equal(labels, 0))
      xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits)
      return {"neg_log_perplexity": tf.metrics.mean(-xent, weights)}

    # Rebinds the `labels` name that was deleted on entry.
    labels = lowering.export_to_tf_tensor(anon_targets)
    eval_metrics = (padded_neg_log_perplexity, [tf_logits, labels])
    return tpu_estimator.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        evaluation_hooks=[restore_hook],
        loss=tf_loss,
        eval_metrics=eval_metrics)
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Builds a Mesh TensorFlow graph around `transformer_model`, lowers it to
  TensorFlow and returns the spec for the requested mode. Besides its
  arguments, this function reads many names that are bound outside of it
  (for example use_tpu, mesh_shape, layout_rules, batch_size,
  outer_batch_size, sequence_length, transformer_model, model_type,
  optimizer, learning_rate_schedule, model_dir and autostack).

  Args:
    features: dictionary where keys are strings like "inputs" and "targets"
      and the values are the actual values of "inputs". See TPUEstimator's
      docs for more information. Every entry is cast to int32, reshaped and
      imported into the mesh; the sequence length of an entry is looked up
      under the part of its key that precedes the first "_".
    labels: ignored argument
    mode: a tf.estimator.ModeKeys
    params: dictionary; when it contains the key "context" and use_tpu is
      set, that context supplies host, replica and device-assignment
      information for the SIMD mesh implementation.
    config: ignored argument

  Returns:
    a TPUEstimatorSpec for PREDICT and EVAL, and for TRAIN when use_tpu is
    set; a tf.estimator.EstimatorSpec for TRAIN otherwise.

  Raises:
    ValueError: if transformer_model is of an unrecognized class, or if
      variable_filter is not None, a string, or a callable.
  """
  del labels, config
  global_step = tf.train.get_global_step()

  # Choose the mesh implementation: SIMD when on TPU with a context,
  # placement-based otherwise.
  if use_tpu and "context" in params:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [
        host_placement_fn(host_id=t) for t in range(num_hosts)
    ]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(
        device_list, devices_memeory_usage)  # deprecated
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
        mesh_shape.to_integer_list, physical_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None  # deprecated
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  # Import every feature into the mesh as a fully replicated int32 tensor of
  # shape [ensemble?, outer_batch, batch, length].
  mtf_features = {}
  for key, x in features.items():
    outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
    batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
    # Some auxiliary features may have been generated in packing.
    # The names of these new features are of the form
    #   "<original_feature_name>_<suffix>", e.g. "inputs_segmentation".
    # We look up the lengths based on the original feature name, without
    # the "_<suffix>".
    feature_length = sequence_length[key.split("_")[0]]
    length_dim = mtf.Dimension("length", feature_length)
    ensemble_dims = ([mtf.Dimension("ensemble", ensemble_inputs)]
                     if ensemble_inputs else [])
    feature_shape = mtf.Shape(
        ensemble_dims + [outer_batch_dim, batch_dim, length_dim])
    x = tf.cast(features[key], tf.int32)
    x = tf.reshape(x, feature_shape.to_integer_list)
    if not use_tpu:
      tf.logging.info("feature %s : %s" % (key, x))
      x = tf.Print(
          x, [x], "import feature %s" % key, summarize=1000, first_n=10)
    mtf_features[key] = mtf.import_fully_replicated(
        mesh, x, feature_shape, name=key)
    # The anonymized copy is exported as the labels in EVAL mode. If several
    # of these keys are present, the last one iterated wins.
    if (key == "targets" or key == "codeprefixedtargets"
        or key == "controlcode"):
      anon_targets = mtf.anonymize(mtf_features[key])

  if mode == tf.estimator.ModeKeys.PREDICT:
    def _feature_shape(key):
      """Returns the [batch, length] shape used for `key` when predicting."""
      feature_length = sequence_length[key.split("_")[0]]
      return mtf.Shape([
          mtf.Dimension("batch", batch_size),
          mtf.Dimension("length", feature_length)
      ])

    # Reshape every feature to [batch, length] (outer_batch is merged away).
    mtf_features = {
        k: mtf.reshape(v, _feature_shape(k))
        for k, v in six.iteritems(mtf_features)
    }
    inputs = mtf_features["inputs"]

    if attribute_embedding:
      attributes = mtf_features["attribute"]
    else:
      attributes = None

    if has_partial_sequences:
      controlcodes = mtf_features["controlcode"]
    else:
      controlcodes = None

    # A user-supplied predict_fn takes precedence over the per-class
    # decoding paths.
    if predict_fn:
      mtf_samples = predict_fn(
          model=transformer_model,
          features=mtf_features,
          variable_dtype=get_variable_dtype())
    elif isinstance(transformer_model, transformer.Unitransformer):
      # pad so that there is enough room for the targets
      inputs = mtf.pad(
          inputs, [0, sequence_length["targets"]], length_dim.name)
      mtf_samples = transformer_model.sample_autoregressive(
          inputs,
          variable_dtype=get_variable_dtype(),
          remove_partial_sequences=True)
    elif isinstance(transformer_model, Bitransformer_ll):
      mtf_samples = transformer_model.decode(
          inputs,
          attributes=attributes,
          controlcodes=controlcodes,
          has_partial_sequences=has_partial_sequences,
          remove_partial_sequences=remove_partial_sequences,
          variable_dtype=get_variable_dtype())
    elif isinstance(
        transformer_model,
        (transformer.Bitransformer, transformer.StudentTeacher)):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    inputs = mtf.anonymize(inputs)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    inputs = lowering.export_to_tf_tensor(inputs)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"inputs": inputs, "outputs": outputs}

    # When exporting a model, we need to communicate to TF-Serving that
    # master variables need to be copied to their slice variables.
    # Estimator uses a Scaffold's "local_init_op" for this purpose, so we
    # augment the default "local_init_op" here.
    #
    # The "ready_op" is also constructed here to ensure the variables
    # initialized by "local_init_op" are the same ones checked by "ready_op".
    #
    # WARNING: Any variables created outside of this model_fn()
    # (e.g. tpu_estimator/iterations_per_loop) will NOT be initialized nor
    # checked by these ops.
    def scaffold_fn():
      return tf.train.Scaffold(
          local_init_op=tf.group(
              tf.train.Scaffold.default_local_init_op(),
              lowering.copy_masters_to_slices(),
              name="mtf_local_init_op"),
          ready_op=tf.concat([
              tf.report_uninitialized_variables(),
              resources.report_uninitialized_resources()
          ], axis=0, name="mtf_ready_op"))

    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        scaffold_fn=scaffold_fn,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  assert (mode == tf.estimator.ModeKeys.TRAIN or
          mode == tf.estimator.ModeKeys.EVAL)

  def logits_and_loss(mtf_features, num_microbatches=1):
    """Compute logits and loss.

    Args:
      mtf_features: a dictionary
      num_microbatches: integer forwarded to `call_simple` on the
        non-Bitransformer_ll path. It is an explicit parameter with a
        default so that EVAL, which never computes a microbatch count, can
        call this function; previously the value was read from the enclosing
        scope, where it is only bound in TRAIN mode.
    Returns:
      logits: a mtf.Tensor
      loss: a mtf.Tensor
    """
    if model_type == "lm":  # TOTRY Adapt that to our case
      if "inputs" in mtf_features:
        mtf_features = _dynamic_text2self(mtf_features)
      _, _, length_dim = mtf_features["targets"].shape
      # Inputs are the targets shifted by one position along length.
      inputs = mtf.shift(
          mtf_features["targets"], offset=1, dim=length_dim, wrap=False)
    else:
      inputs = mtf_features["inputs"]

    if attribute_embedding:
      attributes = mtf_features["attribute"]
    else:
      attributes = None

    if control_codes:
      codeprefixedtargets = mtf_features["codeprefixedtargets"]
    else:
      codeprefixedtargets = None

    # Segmentation/position features are optional; missing ones are passed
    # through as None.
    if isinstance(transformer_model, transformer.Unitransformer):
      position_kwargs = dict(
          sequence_id=mtf_features.get("targets_segmentation", None),
          position=mtf_features.get("targets_position", None),
      )
    elif isinstance(
        transformer_model,
        transformer.Bitransformer) or model_type == "bi_student_teacher":
      if control_codes:
        # With control codes the decoder-side features come from the
        # code-prefixed targets.
        position_kwargs = dict(
            encoder_sequence_id=mtf_features.get(
                "inputs_segmentation", None),
            decoder_sequence_id=mtf_features.get(
                "codeprefixedtargets_segmentation", None),
            decoder_subsequence_id=mtf_features.get(
                "codeprefixedtargets_subsegmentation", None),
            encoder_position=mtf_features.get(
                "inputs_position", None),
            decoder_position=mtf_features.get(
                "codeprefixedtargets_position", None),
        )
      else:
        position_kwargs = dict(
            encoder_sequence_id=mtf_features.get(
                "inputs_segmentation", None),
            decoder_sequence_id=mtf_features.get(
                "targets_segmentation", None),
            decoder_subsequence_id=mtf_features.get(
                "targets_subsegmentation", None),
            encoder_position=mtf_features.get(
                "inputs_position", None),
            decoder_position=mtf_features.get(
                "targets_position", None),
        )
    else:
      raise ValueError("unrecognized class")

    if isinstance(transformer_model, Bitransformer_ll):
      if cycle_consistency_loss:
        # Auto-encoding pass: inputs -> targets.
        logits_ae, l_ae = transformer_model.call_simple(
            inputs=inputs,
            targets=mtf_features["targets"],
            compute_loss=True,
            attributes=attributes,
            codeprefixedtargets=codeprefixedtargets,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)

        if has_partial_sequences:
          controlcodes = mtf_features["controlcode"]
        else:
          controlcodes = None

        # Decode samples under the 'training' gin config scope.
        with gin.config_scope('training'):
          mtf_samples = transformer_model.decode(
              inputs,
              attributes=attributes,
              controlcodes=controlcodes,
              has_partial_sequences=has_partial_sequences,
              remove_partial_sequences=remove_partial_sequences,
              variable_dtype=get_variable_dtype())
          # mtf_samples = mtf.anonymize(mtf_samples)
        outputs = mtf_samples

        # Cycle pass: decoded samples -> targets.
        logits_cycle, l_cycle = transformer_model.call_simple(
            inputs=outputs,
            targets=mtf_features["targets"],
            compute_loss=True,
            attributes=attributes,
            codeprefixedtargets=codeprefixedtargets,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)

        # Weighted sum of the auto-encoding and cycle losses; the logits of
        # the cycle pass are the ones returned.
        loss_ae_cycle = lambda_ae * l_ae + lambda_cycle * l_cycle
        return logits_cycle, loss_ae_cycle
      else:
        return transformer_model.call_simple(
            inputs=inputs,
            targets=mtf_features["targets"],
            compute_loss=True,
            attributes=attributes,
            codeprefixedtargets=codeprefixedtargets,
            mode=mode,
            variable_dtype=get_variable_dtype(),
            **position_kwargs)
    else:
      return transformer_model.call_simple(
          inputs=inputs,
          targets=mtf_features["targets"],
          compute_loss=True,
          mode=mode,
          variable_dtype=get_variable_dtype(),
          num_microbatches=num_microbatches,
          **position_kwargs)

  if mode == tf.estimator.ModeKeys.TRAIN:
    num_microbatches = serialize_num_microbatches(
        batch_dim, sequence_length, mesh_shape, layout_rules)
    if num_microbatches > 1:
      def serialized_fn(mtf_features):
        """Returns the per-microbatch loss, scaled by 1 / num_microbatches."""
        return {
            "loss": (logits_and_loss(mtf_features, num_microbatches)[1]
                     / num_microbatches)
        }
      var_grads, loss_dict = mtf.serialize_training_step(
          mtf_features, serialized_fn, batch_dim, num_microbatches)
      loss = loss_dict["loss"]
    else:
      # Pass the computed value explicitly so TRAIN behaves exactly as it did
      # when the value was captured from this scope.
      loss = logits_and_loss(mtf_features, num_microbatches)[1]
      var_grads = mtf.gradients(
          [loss], [v.outputs[0] for v in graph.trainable_variables])

    if tpu_summaries:
      mtf.scalar_summary("loss", loss)

    if callable(learning_rate_schedule):
      # the following happens on CPU since TPU can't handle summaries.
      with mtf.utils.outside_all_rewrites():
        learning_rate = learning_rate_schedule(
            step=tf.train.get_global_step())
        tf.summary.scalar("learning_rate", learning_rate)
    else:
      learning_rate = learning_rate_schedule

    # variable_filter selects which trainable variables get updated: a
    # string is used as a regex searched in the variable name, None trains
    # everything, and a callable is used as the predicate directly.
    if isinstance(variable_filter, str):
      pattern = re.compile(variable_filter)
      variable_filter_fn = lambda v: pattern.search(v.name)
    elif variable_filter is None:
      variable_filter_fn = lambda v: True
    elif callable(variable_filter):
      variable_filter_fn = variable_filter
    else:
      raise ValueError(
          "variable_filter must be None, a string, or a callable function"
      )
    trainable_vars = [
        v for v in graph.trainable_variables if variable_filter_fn(v)
    ]
    # var_grads is aligned with graph.trainable_variables, so filter both
    # with the same predicate.
    trainable_var_grads = [
        g for g, v in zip(var_grads, graph.trainable_variables)
        if variable_filter_fn(v)
    ]
    if len(trainable_vars) != len(graph.trainable_variables):
      tf.logging.info("Variables being trained:")
      tf.logging.info([v.name for v in trainable_vars])
      tf.logging.info("Variables not being trained:")
      tf.logging.info([
          v.name for v in graph.trainable_variables
          if not variable_filter_fn(v)
      ])

    update_ops = optimizer(learning_rate=learning_rate).apply_grads(
        trainable_var_grads, trainable_vars)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.cast(tf_loss, tf.float32)
    if not use_tpu:
      tf_loss = tf.Print(
          tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss")

    tf_update_ops = [
        lowering.lowered_operation(op) for op in update_ops
    ]
    # The train op also advances the global step by one.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    if hasattr(transformer_model, "initialize"):
      with mtf.utils.outside_all_rewrites():
        transformer_model.initialize()

    if tpu_summaries:
      # has to be outside of
      # with mtf.utils.outside_all_rewrites()
      host_call = mtf.utils.create_host_call(model_dir)
      mtf.utils.remove_summaries()
    else:
      host_call = None

    with mtf.utils.outside_all_rewrites():
      if init_checkpoint:
        # Warm-start every global variable whose name also appears in the
        # checkpoint, and log the names found on only one side.
        ckpt_vars = {
            v for v, _ in tf.train.list_variables(init_checkpoint)
        }
        global_vars = {v.op.name for v in tf.global_variables()}
        restore_vars = ckpt_vars.intersection(global_vars)
        tf.logging.info("Initializing variables from %s:", init_checkpoint)
        tf.logging.debug("\n".join(sorted(restore_vars)))
        tf.logging.info("Variables in %s but not in graph:",
                        init_checkpoint)
        tf.logging.info("\n".join(sorted(ckpt_vars - global_vars)))
        tf.logging.info("Variables in graph but not in %s:",
                        init_checkpoint)
        tf.logging.info("\n".join(sorted(global_vars - ckpt_vars)))
        tf.train.init_from_checkpoint(init_checkpoint,
                                      {v: v for v in restore_vars})

      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(tf.global_variables(),
                             sharded=True,
                             max_to_keep=keep_checkpoint_max,
                             keep_checkpoint_every_n_hours=2,
                             defer_build=False,
                             save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          model_dir,
          save_steps=save_checkpoints_steps,
          saver=saver,
          listeners=[saver_listener])
      gin_config_saver_hook = gin.tf.GinConfigSaverHook(
          model_dir, summarize_config=True, include_step_in_filename=False)

      if use_tpu:
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            host_call=host_call,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
  elif mode == tf.estimator.ModeKeys.EVAL:
    # No microbatching in EVAL: logits_and_loss uses its default of 1.
    logits, loss = logits_and_loss(mtf_features)
    anon_logits = mtf.anonymize(logits)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
    tf_loss = tf.cast(tf_loss, tf.float32)
    tf_logits = tf.cast(lowering.export_to_tf_tensor(anon_logits),
                        tf.float32)

    def simple_metrics(logits, labels):
      """Simple metrics for teacher-forced eval."""
      # Positions whose label is 0 get weight 0 and so do not count.
      weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
      xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits)
      predictions = tf.cast(tf.argmax(logits, axis=-1), labels.dtype)
      token_correct = tf.cast(tf.equal(predictions, labels),
                              tf.float32) * weights
      # A sequence is correct when every weighted token in it is correct.
      sequence_correct = tf.to_float(
          tf.equal(tf.reduce_sum(token_correct, -1),
                   tf.reduce_sum(weights, -1)))
      # Sequences without any weighted token are excluded from the mean.
      sequence_weights = tf.to_float(
          tf.not_equal(tf.reduce_sum(weights, -1), 0))
      return {
          "neg_log_perplexity": tf.metrics.mean(-xent, weights),
          "token_accuracy": tf.metrics.mean(token_correct, weights),
          "sequence_accuracy": tf.metrics.mean(sequence_correct,
                                               sequence_weights)
      }

    # Rebinds the `labels` name that was deleted on entry.
    labels = lowering.export_to_tf_tensor(anon_targets)
    eval_metrics = (simple_metrics, [tf_logits, labels])
    with mtf.utils.outside_all_rewrites():
      restore_hook = mtf.MtfRestoreHook(lowering)
    return tpu_estimator.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        evaluation_hooks=[restore_hook],
        loss=tf_loss,
        eval_metrics=eval_metrics)
def my_model_fn(features, labels, mode, params=None, config=None):
  """Estimator model function.

  Builds a Mesh TensorFlow graph around `transformer_model`, lowers it to
  TensorFlow and returns the spec for the requested mode. Besides its
  arguments, this function reads many names that are bound outside of it
  (for example use_tpu, mesh_shape, layout_rules, batch_size,
  outer_batch_size, sequence_length, transformer_model, model_type,
  optimizer, learning_rate_schedule, model_dir and autostack).

  Args:
    features: input features dictionary. Every entry is converted to int32
      and reshaped to
      [outer_batch_size, batch_size // outer_batch_size, sequence_length].
    labels: ignored (deleted on entry).
    mode: a tf.estimator.ModeKeys
    params: only read when use_tpu is set, in which case params["context"]
      supplies host, replica and device-assignment information.
    config: ignored (deleted on entry).

  Returns:
    PREDICT: a TPUEstimatorSpec whose predictions are {"outputs": ...}.
    TRAIN: a TPUEstimatorSpec when use_tpu is set, otherwise a
      tf.estimator.EstimatorSpec.

  Raises:
    NotImplementedError: if mode is EVAL.
    ValueError: if transformer_model is of an unrecognized class.
  """
  del labels, config
  global_step = tf.train.get_global_step()

  # Choose the mesh implementation: SIMD on TPU, placement-based otherwise.
  if use_tpu:
    ctx = params["context"]
    num_hosts = ctx.num_hosts
    host_placement_fn = ctx.tpu_host_placement_function
    device_list = [
        host_placement_fn(host_id=t) for t in range(num_hosts)
    ]
    # TODO(ylc): Better estimation of replica cache size?
    replica_cache_size = 300 * 1000000  # 300M per replica
    # Worker 0 caches all the TPU binaries.
    worker0_mem = replica_cache_size * ctx.num_replicas
    # Only host 0 starts with a non-zero usage figure; the remaining hosts
    # start at 0.
    devices_memeory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    var_placer = mtf.utils.BalancedVariablePlacer(
        device_list, devices_memeory_usage)
    mesh_devices = [""] * mesh_shape.size
    physical_shape = list(
        params["context"].device_assignment.topology.mesh_shape)
    logical_to_physical = _logical_to_physical(physical_shape, mesh_shape)
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment,
        logical_to_physical=logical_to_physical)
  else:
    var_placer = None
    mesh_devices = [""] * mesh_shape.size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh", var_placer)

  # All features share one [outer_batch, batch, length] shape.
  outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
  batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
  length_dim = mtf.Dimension("length", sequence_length)
  feature_shape = mtf.Shape([outer_batch_dim, batch_dim, length_dim])

  mtf_features = {}
  for key, x in features.items():
    # The loop variable `x` is overwritten immediately; the value is re-read
    # from `features` by key.
    x = tf.to_int32(features[key])
    x = tf.reshape(x, [
        outer_batch_size, batch_size // outer_batch_size, sequence_length
    ])
    if not use_tpu:
      x = tf.Print(x, [x], "import feature %s" % key,
                   summarize=1000, first_n=1)
    mtf_features[key] = mtf.import_fully_replicated(
        mesh, x, feature_shape, name=key)

  if mode == tf.estimator.ModeKeys.PREDICT:
    inputs = mtf_features["inputs"]
    # Reshape to [batch, length] (outer_batch is merged away).
    inputs = mtf.reshape(
        inputs,
        mtf.Shape([
            mtf.Dimension("batch", batch_size),
            mtf.Dimension("length", sequence_length)
        ]))
    if isinstance(transformer_model, transformer.Unitransformer):
      mtf_samples = transformer_model.sample_autoregressive(
          inputs, variable_dtype=get_variable_dtype())
    elif isinstance(
        transformer_model,
        (transformer.Bitransformer, transformer.StudentTeacher)):
      mtf_samples = transformer_model.decode(
          inputs, variable_dtype=get_variable_dtype())
    else:
      raise ValueError("unrecognized class")
    mtf_samples = mtf.anonymize(mtf_samples)
    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
    outputs = lowering.export_to_tf_tensor(mtf_samples)
    predictions = {"outputs": outputs}
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[mtf.MtfRestoreHook(lowering)])

  elif mode == tf.estimator.ModeKeys.EVAL:
    raise NotImplementedError("We don't expect to use mode == eval.")

  else:
    assert mode == tf.estimator.ModeKeys.TRAIN
    num_microbatches = serialize_num_microbatches(
        batch_dim, length_dim, mesh_shape, layout_rules)

    def model_fn(mtf_features):
      """The kind of function we need for mtf.serialize_training_step.

      Computes the training loss for one batch of features; when
      num_microbatches (captured from the enclosing scope) exceeds 1 the loss
      is divided by it.

      Args:
        mtf_features: a dictionary
      Returns:
        a dictionary with the single key "loss"
      """
      targets = mtf_features["targets"]
      if model_type == "lm":
        # This `length_dim` is local to model_fn and shadows the outer one.
        _, _, length_dim = targets.shape
        # Inputs are the targets shifted by one position along length.
        inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
      else:
        inputs = mtf_features["inputs"]

      # Segmentation/position features are optional; missing ones are passed
      # through as None.
      if isinstance(transformer_model, transformer.Unitransformer):
        position_kwargs = dict(
            sequence_id=mtf_features.get("targets_segmentation", None),
            position=mtf_features.get("targets_position", None),
        )
      elif isinstance(
          transformer_model,
          transformer.Bitransformer
      ) or model_type == "bi_student_teacher":
        position_kwargs = dict(
            encoder_sequence_id=mtf_features.get(
                "inputs_segmentation", None),
            decoder_sequence_id=mtf_features.get(
                "targets_segmentation", None),
            encoder_position=mtf_features.get(
                "inputs_position", None),
            decoder_position=mtf_features.get(
                "targets_position", None),
        )
      else:
        raise ValueError("unrecognized class")

      logits, loss = transformer_model.call_simple(
          inputs=inputs,
          targets=targets,
          compute_loss=True,
          mode=mode,
          variable_dtype=get_variable_dtype(),
          **position_kwargs)
      if num_microbatches > 1:
        loss /= float(num_microbatches)
      # Only the loss is needed for training.
      del logits
      return {"loss": loss}

    if num_microbatches > 1:
      var_grads, loss_dict = mtf.serialize_training_step(
          mtf_features, model_fn, batch_dim, num_microbatches)
    else:
      loss_dict = model_fn(mtf_features)
      var_grads = mtf.gradients(
          [loss_dict["loss"]],
          [v.outputs[0] for v in graph.trainable_variables])

    loss = loss_dict["loss"]

    if callable(learning_rate_schedule):
      # the following happens on CPU since TPU can't handle summaries.
      with mtf.utils.outside_all_rewrites():
        learning_rate = learning_rate_schedule(
            step=tf.train.get_global_step())
        tf.summary.scalar("learning_rate", learning_rate)
    else:
      # A non-callable schedule is used as the learning rate itself.
      learning_rate = learning_rate_schedule

    update_ops = optimizer(learning_rate=learning_rate).apply_grads(
        var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)
    if not use_tpu:
      tf_loss = tf.Print(
          tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss")

    tf_update_ops = [
        lowering.lowered_operation(op) for op in update_ops
    ]
    # The train op also advances the global step by one.
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)

    if hasattr(transformer_model, "initialize"):
      with mtf.utils.outside_all_rewrites():
        transformer_model.initialize()

    with mtf.utils.outside_all_rewrites():
      # Copy master variables to slices. Must be called first.
      restore_hook = mtf.MtfRestoreHook(lowering)
      saver = tf.train.Saver(tf.global_variables(),
                             sharded=True,
                             max_to_keep=keep_checkpoint_max,
                             keep_checkpoint_every_n_hours=2,
                             defer_build=False,
                             save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      saver_hook = tf.train.CheckpointSaverHook(
          model_dir,
          save_steps=save_checkpoints_steps,
          saver=saver,
          listeners=[saver_listener])
      gin_config_saver_hook = gin.tf.GinConfigSaverHook(
          model_dir, summarize_config=True)

      if use_tpu:
        if tpu_summaries:
          tf.summary.scalar("loss", tf_loss)
          # The host call is created before the summaries are removed.
          host_call = mtf.utils.create_host_call(model_dir)
          mtf.utils.remove_summaries()
        else:
          host_call = None
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            host_call=host_call,
            training_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
      else:
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[
                restore_hook,
                saver_hook,
                gin_config_saver_hook,
            ])
# result = mtf.sigmoid(result) # result = -(label*mtf.log(result)+(1-label)*mtf.log(1-result)) # result = mtf.reduce_sum(result) result = mtf.layers.sigmoid_cross_entropy_with_logits(result,label) wide_loss = mtf.reduce_sum(result) print("========",global_step) devices = ["gpu:0"] mesh_shape = [("all_processors", 1)] layout_rules = [("dim1", "all_processors")] mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( mesh_shape, layout_rules, devices) var_grads = mtf.gradients( [wide_loss], [v.outputs[0] for v in graph.trainable_variables]) optimizer = mtf.optimize.SgdOptimizer(0.01) update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables) lowering = mtf.Lowering(graph, {mesh:mesh_impl}) restore_hook = mtf.MtfRestoreHook(lowering) tf_update_ops = [lowering.lowered_operation(op) for op in update_ops] tf_update_ops.append(tf.assign_add(global_step, 1)) train_op = tf.group(tf_update_ops) estimator=tf.estimator.EstimatorSpec( tf.estimator.ModeKeys.TRAIN, loss=wide_loss, train_op=train_op, training_chief_hooks=[restore_hook]) WDlaunch = tf.estimator.Estimator( model_fn=estimator,
def Optimize(graph, loss, lr=0.01):
  """Build SGD update operations for the trainable variables of `graph`.

  Args:
    graph: object exposing `trainable_variables`; each variable's first
      output is differentiated against.
    loss: the value to differentiate.
    lr: value handed to `mtf.optimize.SgdOptimizer`.

  Returns:
    Whatever `SgdOptimizer.apply_grads` returns for the computed gradients.

  Raises:
    AssertionError: if any computed gradient is falsy.
  """
  variable_outputs = [variable.outputs[0]
                      for variable in graph.trainable_variables]
  var_grads = mtf.gradients([loss], variable_outputs)
  # Every trainable variable must have received a gradient.
  assert all(var_grads)
  sgd = mtf.optimize.SgdOptimizer(lr)
  return sgd.apply_grads(var_grads, graph.trainable_variables)
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Build the mesh-tensorflow graph for a reconstruction prototype.

    A variable holding the linear field (stored in block-decomposed
    `hr_shape`) is evolved with `mtfpm.lpt_init` (optionally followed by
    `mtfpm.nbody` when `FLAGS.nbody` is set), painted onto a grid and compared
    with `data`. The loss is the sum of a smoothed squared residual and a
    power-spectrum prior; a high-pass-filtered gradient step on the field
    variable is also built.

    Args:
        mesh: mtf.Mesh on which all tensors are created.
        data: observed field; converted with `tf.convert_to_tensor` and
            imported with shape [batch, nx, ny, nz]. Its dtype is also used
            for the input placeholder.
        nc: grid size per side (all grid dimensions are built from it).
        bs: box size, used to scale wavenumbers and the prior.
        batch_size: size of the batch dimension.
        a0: first value of the `stages` array.
        a: last value of the `stages` array.
        nsteps: number of entries in `stages`.
        dtype: tf.float32 or tf.float64.

    Returns:
        Tuple `(initc, final_field, loss, var_grads, update_op, linearop,
        input_field, lr, R0)` where `input_field`, `lr` and `R0` are TF
        placeholders and the remaining entries are mesh-tensorflow
        tensors/operations (`var_grads` is a one-element list).

    Raises:
        ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and crashed on the print below with an
        # unbound `npdtype`; fail with an explicit message instead.
        raise ValueError(
            "Unsupported dtype %s; expected tf.float32 or tf.float64" % dtype)
    print(dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    #graph = mtf.Graph()
    #mesh = mtf.Mesh(graph, "my_mesh")

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo is capped at half of the smallest block extent.
    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    downsampling_factor = 2
    lnc = nc // 2**downsampling_factor

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    # Dimensions of the low resolution grid
    x_dim = mtf.Dimension("nx_lr", lnc)
    y_dim = mtf.Dimension("ny_lr", lnc)
    z_dim = mtf.Dimension("nz_lr", lnc)

    tx_dim = mtf.Dimension("tx_lr", lnc)
    ty_dim = mtf.Dimension("ty_lr", lnc)
    tz_dim = mtf.Dimension("tz_lr", lnc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    # Parse the tabulated power spectrum once (it used to be read twice).
    pk_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    klin = pk_table[0]
    plin = pk_table[1]
    ipklin = iuspline(klin, plin)  # not used further in this function
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False, dtype=npdtype)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype(npdtype),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype(npdtype),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype(npdtype),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc],
                                  symmetric=False,
                                  dtype=npdtype)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype(npdtype) /
                                 2**downsampling_factor,
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    # kvec for high resolution blocks
    padded_sx_dim = mtf.Dimension('padded_sx_block',
                                  nc // n_block_x + 2 * halo_size)
    padded_sy_dim = mtf.Dimension('padded_sy_block',
                                  nc // n_block_y + 2 * halo_size)
    padded_sz_dim = mtf.Dimension('padded_sz_block',
                                  nc // n_block_z + 2 * halo_size)

    kvec_hr = flowpm.kernels.fftk([
        nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size,
        nc // n_block_z + 2 * halo_size
    ],
                                  symmetric=False,
                                  dtype=npdtype)
    kx_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[0].squeeze().astype(npdtype),
                                 shape=[padded_sx_dim])
    ky_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[1].squeeze().astype(npdtype),
                                 shape=[padded_sy_dim])
    kz_hr = mtf.import_tf_tensor(mesh,
                                 kvec_hr[2].squeeze().astype(npdtype),
                                 shape=[padded_sz_dim])
    kv_hr = [ky_hr, kz_hr, kx_hr]

    # kvec for prior blocks
    prior_sx_dim = mtf.Dimension('prior_sx_block', nc // n_block_x)
    prior_sy_dim = mtf.Dimension('prior_sy_block', nc // n_block_y)
    prior_sz_dim = mtf.Dimension('prior_sz_block', nc // n_block_z)

    kvec_pr = flowpm.kernels.fftk(
        [nc // n_block_x, nc // n_block_y, nc // n_block_z],
        symmetric=False,
        dtype=npdtype)
    kx_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[0].squeeze().astype(npdtype),
                                 shape=[prior_sx_dim])
    ky_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[1].squeeze().astype(npdtype),
                                 shape=[prior_sy_dim])
    kz_pr = mtf.import_tf_tensor(mesh,
                                 kvec_pr[2].squeeze().astype(npdtype),
                                 shape=[prior_sz_dim])
    kv_pr = [ky_pr, kz_pr, kx_pr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, x_dim, y_dim, z_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    ## Compute initial initial conditions distributed
    # The field variable lives in block-decomposed form; `linearop` loads the
    # value fed through `input_field` into it.
    fieldvar = mtf.get_variable(mesh, 'linear', hr_shape)
    input_field = tf.placeholder(data.dtype, [
        batch_size, n_block_x, n_block_y, n_block_z, nc // n_block_x,
        nc // n_block_y, nc // n_block_z
    ])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    #
    field = fieldvar
    initc = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field],
                          output_dtype=tf.float32,
                          output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
                          name='my_dumb_reshape',
                          splittable_dims=part_shape[:-1] + hr_shape[:4])

    #
    # Pad each block with a halo, then exchange the halos between blocks.
    for block_size_dim in hr_shape[-3:]:
        field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)

    for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
        field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size)

    # A trailing size-1 dimension is added for the downsampling call and
    # stripped again afterwards.
    field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
    high = field
    low = mesh_utils.downsample(field, downsampling_factor, antialias=True)

    low = mtf.reshape(low, low.shape[:-1])
    high = mtf.reshape(high, high.shape[:-1])

    for block_size_dim in hr_shape[-3:]:
        low = mtf.slice(low, halo_size // 2**downsampling_factor,
                        block_size_dim.size // 2**downsampling_factor,
                        block_size_dim.name)
    # Hack usisng custom reshape because mesh is pretty dumb
    low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low],
                        output_dtype=initc.dtype,
                        output_shape=lr_shape,
                        name='my_dumb_reshape',
                        splittable_dims=lr_shape[:-1] + hr_shape[:4])

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init(low,
                               high,
                               0.1,
                               kv_lr,
                               kv_hr,
                               halo_size,
                               hr_shape,
                               lr_shape,
                               part_shape[1:],
                               downsampling_factor=downsampling_factor,
                               antialias=True)

        final_state = mtfpm.nbody(state,
                                  stages,
                                  lr_shape,
                                  hr_shape,
                                  kv_lr,
                                  kv_hr,
                                  halo_size,
                                  downsampling_factor=downsampling_factor)
    else:
        final_state = mtfpm.lpt_init(low,
                                     high,
                                     stages[-1],
                                     kv_lr,
                                     kv_hr,
                                     halo_size,
                                     hr_shape,
                                     lr_shape,
                                     part_shape[1:],
                                     downsampling_factor=downsampling_factor,
                                     antialias=True)

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    # kv_pr is ordered (ky, kz, kx); reorder its dimensions to (x, y, z).
    k_dims_pr = [d.shape[0] for d in kv_pr]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Divide |kfield| by the square root of the interpolated `pk`."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 +
                     (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv_pr,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3

    # Total loss
    diff = (final_field - mtfdata)
    R0 = tf.placeholder(tf.float32, shape=())

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiply `kfield` by a Gaussian weight whose width is set by R0."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    #return initc, final_field, loss, linearop, input_field
    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        """Multiply `kfield` by one minus a Gaussian weight (R0 and nyq)."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    # The raw gradient is filtered in Fourier space before the update.
    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype)
    var_grads = [
        mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype)
    ]

    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return (initc, final_field, loss, var_grads, update_op, linearop,
            input_field, lr, R0)
def recon_prototype(mesh,
                    data,
                    nc=FLAGS.nc,
                    bs=FLAGS.box_size,
                    batch_size=FLAGS.batch_size,
                    a0=FLAGS.a0,
                    a=FLAGS.af,
                    nsteps=FLAGS.nsteps,
                    dtype=tf.float32):
    """Build the mesh-tensorflow graph for a reconstruction with a NN model.

    A variable holding the linear field (shape [batch, nx, ny, nz]) is evolved
    with `mtfpm.lpt_init_single` (optionally followed by
    `mtfpm.nbody_single` when `FLAGS.nbody` is set) and painted onto a grid.
    The painted field and two smoothed versions of it are passed through two
    small fully-connected networks (weights from `setupfnn()`), whose product
    is the model. The loss is a smoothed squared log-residual against `data`
    plus a power-spectrum prior; a high-pass-filtered gradient step on the
    field variable is also built.

    Args:
        mesh: mtf.Mesh on which all tensors are created.
        data: observed field; converted with `tf.convert_to_tensor` and
            imported with shape [batch, nx, ny, nz]. Its dtype and shape are
            also used for placeholders.
        nc: grid size per side (all grid dimensions are built from it).
        bs: box size, used to scale wavenumbers and the prior.
        batch_size: size of the batch dimension.
        a0: first value of the `stages` array (also passed to
            `lpt_init_single` in the n-body branch).
        a: last value of the `stages` array.
        nsteps: number of entries in `stages`.
        dtype: tf.float32 or tf.float64.

    Returns:
        Tuple `(initc, model, loss, var_grads, update_op, linearop,
        input_field, lr, R0, M0, width, chisq, prior, off, istd)`.
        `input_field`, `lr`, `R0`, `M0`, `width`, `off` and `istd` are TF
        placeholders; the rest are mesh-tensorflow tensors/operations
        (`var_grads` is a one-element list).

    Raises:
        ValueError: if `dtype` is neither tf.float32 nor tf.float64.
    """
    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    else:
        # Previously this fell through and crashed on the print below with an
        # unbound `npdtype`; fail with an explicit message instead.
        raise ValueError(
            "Unsupported dtype %s; expected tf.float32 or tf.float64" % dtype)
    print("Dtype : ", dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    kny = 1 * np.pi * nc / bs
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    #graph = mtf.Graph()
    #mesh = mtf.Mesh(graph, "my_mesh")

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    # The halo is capped at half of the smallest block extent.
    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y,
                              nc // n_block_z):
        new_size = int(0.5 *
                       min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' %
              (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition

    scalar = mtf.Dimension("scalar", 1)

    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    #k_dims = [tx_dim, ty_dim, tz_dim]

    batch_dim = mtf.Dimension("batch", batch_size)

    # Parse the tabulated power spectrum once (it used to be read twice).
    pk_table = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T
    klin = pk_table[0]
    plin = pk_table[1]
    ipklin = iuspline(klin, plin)  # not used further in this function
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh,
                              kvec[0].squeeze().astype('float32'),
                              shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh,
                              kvec[1].squeeze().astype('float32'),
                              shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh,
                              kvec[2].squeeze().astype('float32'),
                              shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[0].squeeze().astype('float32'),
                                 shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[1].squeeze().astype('float32'),
                                 shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh,
                                 kvec_lr[2].squeeze().astype('float32'),
                                 shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]

    #
    # Begin simulation

    ## Compute initial initial conditions distributed
    #initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv)

    # `linearop` loads the value fed through `input_field` into the variable.
    fieldvar = mtf.get_variable(mesh, 'linear', part_shape)
    input_field = tf.placeholder(data.dtype, [batch_size, nc, nc, nc])
    mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=part_shape)
    linearop = mtf.assign(fieldvar, mtfinp)

    #field = fieldvar
    initc = fieldvar

    print("initc : ", initc)

    # Here we can run our nbody
    if FLAGS.nbody:
        state = mtfpm.lpt_init_single(
            fieldvar,
            a0,
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )
        # Here we can run our nbody
        final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape,
                                         kv_lr, halo_size)
    else:
        final_state = mtfpm.lpt_init_single(
            initc,
            stages[-1],
            kv_lr,
            halo_size,
            lr_shape,
            hr_shape,
            part_shape[1:],
            antialias=True,
        )

    # paint the field
    final_field = mtf.zeros(mesh, shape=hr_shape)
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.pad(final_field, [halo_size, halo_size],
                              block_size_dim.name)
    final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size)
    # Halo exchange
    for blocks_dim, block_size_dim in zip(hr_shape[1:4],
                                          final_field.shape[-3:]):
        final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim,
                                      halo_size)
    # Remove borders
    for block_size_dim in hr_shape[-3:]:
        final_field = mtf.slice(final_field, halo_size, block_size_dim.size,
                                block_size_dim.name)

    final_field = mtf.slicewise(
        lambda x: x[:, 0, 0, 0], [final_field],
        output_dtype=dtype,
        output_shape=[batch_dim, fx_dim, fy_dim, fz_dim],
        name='my_dumb_reshape',
        splittable_dims=part_shape[:-1] + hr_shape[:4])
    ##
    x = final_field

    ppars, mpars, kernel = setupfnn()
    pwts, pbias, pmx, psx = ppars
    mwts, mbias, mmx, msx, mmy, msy = mpars
    msy, mmy = msy[0], mmy[0]
    print("mmy : ", mmy)
    size = 3

    # kv is ordered (ky, kz, kx); reorder its dimensions to (x, y, z).
    k_dims = [d.shape[0] for d in kv]
    k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    tfnc, tfbs = float_to_mtf(nc * 1., mesh,
                              scalar), float_to_mtf(bs, mesh, scalar)

    # Apply `cwise_decic` in Fourier space, then subtract one.
    x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs],
                    output_dtype=cdtype)
    x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype)
    x1d = mtf.add(x1d, -1.)

    # Two copies filtered with `cwise_fingauss` at R1 and R2, and their
    # difference.
    x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype)
    x1f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R1, mesh, scalar)] + kv +
                    [tfnc, tfbs],
                    output_dtype=cdtype)
    x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype)
    x2f = mtf.cwise(cwise_fingauss,
                    [x1f0, float_to_mtf(R2, mesh, scalar)] + kv +
                    [tfnc, tfbs],
                    output_dtype=cdtype)
    x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype)
    x12 = x1 - x2

    width = tf.placeholder(tf.float32, shape=())

    def apply_pwts(x, x1, x2):
        """Convolve the inputs with `kernel` and apply the `pwts` network."""
        #y = tf.expand_dims(x, axis=-1)

        y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1],
                         'SAME')
        y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel,
                          [1, 1, 1, 1, 1], 'SAME')
        y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel,
                          [1, 1, 1, 1, 1], 'SAME')
        #y = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x), -1), kernel, [1, 1, 1, 1, 1], 'VALID')
        #y1 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x1), -1), kernel, [1, 1, 1, 1, 1], 'VALID')
        #y2 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x12), -1), kernel, [1, 1, 1, 1, 1], 'VALID')

        yy = tf.concat([y, y1, y2], axis=-1)
        yy = yy - pmx
        yy = yy / psx
        yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0])
        yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1])
        yy3 = tf.matmul(yy2, pwts[2]) + pbias[2]
        pmodel = tf.nn.sigmoid(width * yy3)
        return pmodel[..., 0]

    pmodel = mtf.slicewise(
        apply_pwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_pwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    def apply_mwts(x, x1, x2):
        """Stack the inputs on a last axis and apply the `mwts` network."""
        #y = tf.expand_dims(x, axis=-1)

        zz = tf.concat([
            tf.expand_dims(x, -1),
            tf.expand_dims(x1, -1),
            tf.expand_dims(x2, -1)
        ],
                       axis=-1)
        zz = zz - mmx
        zz = zz / msx
        zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0])
        zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1])
        zz3 = tf.matmul(zz2, mwts[2]) + mbias[2]
        mmodel = zz3 * msy + mmy
        return mmodel[..., 0]

    mmodel = mtf.slicewise(
        apply_mwts,
        [x, x1, x12],
        output_dtype=tf.float32,
        output_shape=part_shape,  # + [mtf.Dimension('c_dim', 81)],
        name='apply_mwts',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    model = pmodel * mmodel

    mtfdata = mtf.import_tf_tensor(mesh,
                                   tf.convert_to_tensor(data),
                                   shape=shape)

    # Get prior
    #k_dims = [d.shape[0] for d in kv]
    #k_dims = [k_dims[2], k_dims[0], k_dims[1]]
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        """Divide |kfield| by the square root of the interpolated `pk`."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 +
                     (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv,
                        output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3

    # Total loss
    #diff = (model - mtfdata)
    modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype)
    modelsmf = mtf.cwise(cwise_fingauss,
                         [modelf, float_to_mtf(R1, mesh, scalar)] + kv +
                         [tfnc, tfbs],
                         output_dtype=cdtype)
    modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype)
    #dataf = mesh_utils.r2c3d(mtfdata, k_dims, dtype=cdtype)
    #datasmf = mtf.cwise(cwise_fingauss, [dataf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype)
    #datasm = mesh_utils.c2r3d(datasmf, x1d.shape[-3:], dtype=dtype)

    ##Anneal
    R0 = tf.placeholder(tf.float32, shape=())
    M0 = tf.placeholder(tf.float32, shape=())
    off, istd = tf.placeholder(tf.float32,
                               shape=data.shape), tf.placeholder(
                                   tf.float32, shape=data.shape)
    mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape)
    mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape)
    diff = mtf.log(modelsm + M0) - mtf.log(mtfdata + M0)
    #diff = diff / 0.25
    #diff = (diff + mtfoff)*mtfistd #For some reason, doing things wrong this one
    diff = (diff + mtfoff) / 0.25

    def _cwise_smooth(kfield, kx, ky, kz):
        """Multiply `kfield` by a Gaussian weight whose width is set by R0."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)
    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)
    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    #return initc, final_field, loss, linearop, input_field
    nyq = np.pi * nc / bs

    def _cwise_highpass(kfield, kx, ky, kz):
        """Multiply `kfield` by one minus a Gaussian weight (R0 and nyq)."""
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype)
        return kfield * (1 - wts)

    # The raw gradient is filtered in Fourier space before the update.
    var_grads = mtf.gradients([loss], [fieldvar])
    cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype)
    cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=cdtype)
    var_grads = [mesh_utils.c2r3d(cgrads, diff.shape[-3:], dtype=dtype)]

    lr = tf.placeholder(tf.float32, shape=())
    update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr)

    return (initc, model, loss, var_grads, update_op, linearop, input_field,
            lr, R0, M0, width, chisq, prior, off, istd)
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Builds the reconstruction graph with `recon_model`, lowers it onto a
    placement mesh and returns the `EstimatorSpec` for the requested mode.

    Args:
        features: dict read with the keys 'datasm', 'bparams', 'ipkerror',
            'R0', 'x0' and (in TRAIN mode) 'lr'.
        labels: unused.
        mode: a `tf.estimator.ModeKeys` value.
        params: unused.

    Returns:
        A `tf.estimator.EstimatorSpec` for PREDICT, TRAIN or EVAL. Any other
        mode falls through and returns None.
    """
    #tf.logging.info("features = %s labels = %s mode = %s params=%s" %
    #                (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    fields, metrics, kv = recon_model(mesh, features['datasm'],
                                      features['bparams'],
                                      features['ipkerror'], features['R0'],
                                      features['x0'])
    fieldvar, final, model = fields
    chisq, prior, loss = metrics

    ##
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    mesh_size = mesh_shape.size
    layout_rules = [("nx_lr", "row"), ("ny_lr", "col"), ("nx", "row"),
                    ("ny", "col"), ("ty", "row"), ("tz", "col"),
                    ("ty_lr", "row"), ("tz_lr", "col"), ("nx_block", "row"),
                    ("ny_block", "col")]
    devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, devices)
    print(mesh_impl)

    ##
    # Gradient and optimizer ops must be added to the mtf graph before it is
    # lowered below.
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdamWeightDecayOptimizer(features['lr'])
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    #
    tf_init = lowering.export_to_tf_tensor(fieldvar)
    tf_final = lowering.export_to_tf_tensor(final)
    tf_model = lowering.export_to_tf_tensor(model)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_chisq = lowering.export_to_tf_tensor(chisq)
    tf_prior = lowering.export_to_tf_tensor(prior)

    ##Predict
    if mode == tf.estimator.ModeKeys.PREDICT:
        # `tf_loss` was already exported above; the second export that used
        # to live here only duplicated ops.
        tf.summary.scalar("loss", tf_loss)
        tf.summary.scalar("chisq", tf_chisq)
        tf.summary.scalar("prior", tf_prior)
        predictions = {
            "ic": tf_init,
            "final": tf_final,
            "model": tf_model,
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "model": tf.estimator.export.PredictOutput(
                    predictions)  #TODO: is classify a keyword?
            })

    ##Train
    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)
        saver = tf.train.Saver(tf.global_variables(),
                               sharded=True,
                               max_to_keep=1,
                               keep_checkpoint_every_n_hours=2,
                               defer_build=False,
                               save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        # `fpath` is a name defined outside this function.
        saver_hook = tf.train.CheckpointSaverHook(fpath,
                                                  save_steps=1000,
                                                  saver=saver,
                                                  listeners=[saver_listener])

        logging_hook = tf.train.LoggingTensorHook(
            {
                "loss": tf_loss,
                "chisq": tf_chisq,
                "prior": tf_prior
            },
            every_n_iter=10)

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "loss")
        tf.identity(tf_prior, "prior")
        tf.identity(tf_chisq, "chisq")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("loss", tf_loss)
        tf.summary.scalar("chisq", tf_chisq)
        tf.summary.scalar("prior", tf_prior)

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook, logging_hook])

    ##Eval
    if mode == tf.estimator.ModeKeys.EVAL:
        # EstimatorSpec requires each eval metric to be a
        # (metric_value, update_op) pair, so wrap the scalars in a streaming
        # mean instead of passing the raw tensors.
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={
                "loss": tf.metrics.mean(tf_loss),
                "chisq": tf.metrics.mean(tf_chisq),
            })
def testPool(self, pooling_method):
    """Compare mtf pooling layers and their gradients with Keras pooling.

    Args:
        pooling_method: one of "MAX_2D", "AVG_2D", "MAX_3D", "AVG_3D".
    """
    batch, depth, height, width, channels = 2, 3, 4, 6, 3
    input_shape = [batch, depth, height, width, channels]
    tf.random.set_random_seed(1234)
    inputs = tf.random_normal(input_shape)

    stride_d, stride_h, stride_w = 3, 2, 3

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    batch_dim = mtf.Dimension("batch", batch)
    depth_dim = mtf.Dimension("depth", depth)
    height_dim = mtf.Dimension("height", height)
    width_dim = mtf.Dimension("width", width)
    channels_dim = mtf.Dimension("channels", channels)

    mtf_inputs = mtf.import_tf_tensor(
        mesh,
        inputs,
        shape=mtf.Shape(
            [batch_dim, depth_dim, height_dim, width_dim, channels_dim]))

    if pooling_method in ("MAX_2D", "AVG_2D"):
        if pooling_method == "MAX_2D":
            mtf_pool = mtf.layers.max_pool2d
            keras_pool = tf.keras.layers.MaxPooling2D
        else:
            mtf_pool = mtf.layers.avg_pool2d
            keras_pool = tf.keras.layers.AveragePooling2D
        strides = (stride_h, stride_w)
        mtf_outputs = mtf_pool(mtf_inputs, ksize=(stride_h, stride_w))
        # Keras 2D pooling takes rank-4 input, so depth is folded into batch
        # for the reference computation and unfolded afterwards.
        inputs = tf.reshape(inputs, [batch * depth, height, width, channels])
        expected_outputs = keras_pool((stride_h, stride_w))(inputs)
        expected_outputs = tf.reshape(expected_outputs, [
            batch, depth,
            int(height / stride_h),
            int(width / stride_w), channels
        ])
    elif pooling_method in ("MAX_3D", "AVG_3D"):
        if pooling_method == "MAX_3D":
            mtf_pool = mtf.layers.max_pool3d
            keras_pool = tf.keras.layers.MaxPooling3D
        else:
            mtf_pool = mtf.layers.avg_pool3d
            keras_pool = tf.keras.layers.AveragePooling3D
        strides = (stride_d, stride_h, stride_w)
        mtf_outputs = mtf_pool(mtf_inputs,
                               ksize=[stride_d, stride_h, stride_w])
        expected_outputs = keras_pool([stride_d, stride_h, stride_w])(inputs)

    mtf_gradient = mtf.gradients([mtf_outputs], [mtf_inputs])[0]

    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                          layout={},
                                                          devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
    actual_gradient = lowering.export_to_tf_tensor(mtf_gradient)

    tf_group = lowering.copy_masters_to_slices()
    init = tf.global_variables_initializer()
    self.evaluate(init)
    self.evaluate(tf_group)
    actual, expected = self.evaluate([actual_outputs, expected_outputs])
    self.assertAllClose(actual, expected)

    actual = self.evaluate(actual_gradient)
    if pooling_method in ("MAX_2D", "MAX_3D"):
        # One non-zero gradient entry is expected per pooling window.
        window_volume = 1
        for stride in strides:
            window_volume = window_volume * stride
        expected_non_zeros = (batch * depth * height * width * channels /
                              window_volume)
        self.assertEqual(np.count_nonzero(actual), expected_non_zeros)
    elif pooling_method in ("AVG_2D", "AVG_3D"):
        # Divide stride by stride, in the same order as the reference
        # expression, so the float32 rounding is unchanged.
        expected = np.ones((batch, depth, height, width, channels),
                           dtype=np.float32)
        for stride in strides:
            expected = expected / stride
        self.assertAllClose(actual, expected)
def create_optimizer(loss,
                     init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     max_optimized_variable_size=None,
                     optimizer="adam",
                     clip_gradients=True):
    """Creates an optimizer training op.

    Args:
        loss: mtf tensor to minimise; its `mesh` attribute is used.
        init_lr: peak learning rate. When falsy no schedule is built, which
            is only accepted for optimizers other than "adam".
        num_train_steps: decay steps of the linear decay to zero.
        num_warmup_steps: number of linear warmup steps; falsy disables
            warmup.
        max_optimized_variable_size: when truthy, only variables whose
            `shape.size` is at most this value are optimized.
        optimizer: "adam" or "adafactor".
        clip_gradients: whether to call `clip_by_global_norm` with a clip
            norm of 1.0.

    Returns:
        A `(learning_rate, update_ops)` pair, where `learning_rate` is the TF
        schedule tensor (or None) and `update_ops` comes from `apply_grads`.

    Raises:
        ValueError: for "adam" without a learning rate, or an unknown
            optimizer name.
    """
    global_step = tf.train.get_or_create_global_step()
    mesh = loss.mesh

    if not init_lr:
        if optimizer == "adam":
            raise ValueError("Adam does not have a default learning rate")
        learning_rate = None
        mtf_learning_rate = None
    else:
        # Implements linear decay of the learning rate.
        base_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
        learning_rate = tf.train.polynomial_decay(base_rate,
                                                  global_step,
                                                  num_train_steps,
                                                  end_learning_rate=0.0,
                                                  power=1.0,
                                                  cycle=False)

        # Implements linear warmup. I.e., if global_step < num_warmup_steps,
        # the learning rate will be `global_step/num_warmup_steps * init_lr`.
        if num_warmup_steps:
            step_int = tf.cast(global_step, tf.int32)
            warmup_int = tf.constant(num_warmup_steps, dtype=tf.int32)

            step_float = tf.cast(step_int, tf.float32)
            warmup_float = tf.cast(warmup_int, tf.float32)

            fraction_done = step_float / warmup_float
            warmup_rate = init_lr * fraction_done

            # 1.0 while warming up, 0.0 afterwards: blends the two rates.
            in_warmup = tf.cast(step_int < warmup_int, tf.float32)
            learning_rate = ((1.0 - in_warmup) * learning_rate +
                             in_warmup * warmup_rate)

        mtf_learning_rate = mtf.import_tf_tensor(mesh, learning_rate, [])

    # It is recommended that you use this optimizer for fine tuning, since
    # this is how the model was trained (note that the Adam m/v variables are
    # NOT loaded from init_checkpoint.)
    if optimizer == "adam":
        opt = mtf_optimize.AdamWeightDecayOptimizer(
            learning_rate=mtf_learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    elif optimizer == "adafactor":
        opt = mtf_optimize.AdafactorOptimizer(learning_rate=learning_rate,
                                              min_dim_size_to_factor=32)
    else:
        raise ValueError("unknown optimizer")

    variables = mesh.graph.trainable_variables
    if max_optimized_variable_size:
        variables = [
            var for var in variables
            if var.shape.size <= max_optimized_variable_size
        ]

    grads = mtf.gradients([loss], [var.outputs[0] for var in variables])

    # This is how the model was pre-trained.
    if clip_gradients:
        unit_norm = mtf.constant(mesh, 1.0, dtype=tf.float32)
        grads, _ = clip_by_global_norm(grads, clip_norm=unit_norm)

    update_ops = opt.apply_grads(grads, variables)

    return learning_rate, update_ops
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator.

    Args:
        features: passed straight to `toy_model`.
        labels: ignored (deleted on entry).
        mode: a `tf.estimator.ModeKeys` value.
        params: dict; `params['context']` is read when `FLAGS.use_tpu` is set.

    Returns:
        A `TPUEstimatorSpec` for TRAIN or EVAL. Any other mode falls through
        and returns None.
    """
    del labels
    train_mode = tf.estimator.ModeKeys.TRAIN
    is_training = mode == train_mode

    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_devices = [''] * mesh_shape.size

    if not FLAGS.use_tpu:
        var_placer = None
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)
    else:
        ctx = params['context']
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [
            host_placement_fn(host_id=host) for host in range(num_hosts)
        ]
        tf.logging.info('device_list = %s' % device_list,)
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        device_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      device_memory_usage)
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    mesh_devices,
                                                    ctx.device_assignment)

    mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
    with mtf.utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    # TRAIN mode
    if is_training:
        trainable = graph.trainable_variables
        var_grads = mtf.gradients([loss], [v.outputs[0] for v in trainable])
        if FLAGS.optimizer == 'Adafactor':
            optimizer = mtf.optimize.AdafactorOptimizer()
        else:
            assert FLAGS.optimizer == 'SGD'
            optimizer = mtf.optimize.SgdOptimizer(lr=FLAGS.lr)
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    exported_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(exported_loss)

    if is_training:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)

        if is_training:
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])
            return tpu_estimator.TPUEstimatorSpec(
                train_mode,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        if mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                """Report the streaming mean of the logits."""
                return {'mean_logits': tf.metrics.mean(tf_logits)}

            eval_metrics = (metric_fn, [tf_logits])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
def my_model_fn(features, labels, mode, params=None, config=None):
    """Estimator model function.

    Reads several names from the enclosing scope (``use_tpu``, ``mesh_shape``,
    ``layout_rules``, ``batch_size``, ``length``, ``text2self``,
    ``transformer_model``, ``variable_dtype``, ``autostack``, ``model_dir``
    and the decoding settings).

    Args:
        features: input features dictionary
        labels: ignored
        mode: a tf.estimator.ModeKeys
        params: mapping; ``params["context"]`` is read when ``use_tpu`` is set
        config: ignored

    Returns:
        A ``TPUEstimatorSpec`` for PREDICT, EVAL and TPU TRAIN, or a
        ``tf.estimator.EstimatorSpec`` for non-TPU TRAIN.
    """
    del labels, config
    global_step = tf.train.get_global_step()

    if use_tpu:
        ctx = params["context"]
        host_count = ctx.num_hosts
        place_on_host = ctx.tpu_host_placement_function
        device_list = [place_on_host(host_id=t) for t in range(host_count)]
        # TODO(ylc): Better estimation of replica cache size?
        per_replica_cache = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        host_memory_usage = (
            [per_replica_cache * ctx.num_replicas] + [0] * (host_count - 1))
        var_placer = mtf.utils.BalancedVariablePlacer(
            device_list, host_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
        var_placer = None
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    def _import_feature(key):
        """Import a feature from the features dictionary into a mtf.Tensor.

        Args:
            key: a string

        Returns:
            a mtf.Tensor with dtype int32 and shape [batch_dim, length_dim],
            or None when ``key`` is absent from ``features``.
        """
        mtf_shape = mtf.Shape([
            mtf.Dimension("batch", batch_size),
            mtf.Dimension("length", length),
        ])
        if key not in features:
            return None
        x = tf.to_int32(features[key])
        if not use_tpu:
            x = tf.Print(
                x, [x], "import feature %s" % key, summarize=1000, first_n=1)
        return mtf.import_fully_replicated(mesh, x, mtf_shape, name=key)

    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
        inputs = _import_feature("inputs")
        if text2self:
            mtf_samples = transformer_model.sample_autoregressive(
                inputs,
                variable_dtype=variable_dtype,
                temperature=temperature)
        else:
            mtf_samples = transformer_model.decode(
                inputs,
                variable_dtype=variable_dtype,
                beam_size=beam_size,
                alpha=alpha,
                temperature=temperature)
        mtf_samples = mtf.anonymize(mtf_samples)
        lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
        predictions = {"outputs": lowering.export_to_tf_tensor(mtf_samples)}
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[mtf.MtfRestoreHook(lowering)])

    targets = _import_feature("targets")
    anon_targets = mtf.anonymize(targets)

    if text2self:
        # The model input is the target sequence shifted by one position.
        _, length_dim = targets.shape
        inputs = mtf.shift(targets, offset=1, dim=length_dim, wrap=False)
        position_kwargs = {
            "sequence_id": _import_feature("targets_segmentation"),
            "position": _import_feature("targets_position"),
        }
    else:
        inputs = _import_feature("inputs")
        position_kwargs = {
            "encoder_sequence_id": _import_feature("inputs_segmentation"),
            "decoder_sequence_id": _import_feature("targets_segmentation"),
            "encoder_position": _import_feature("inputs_position"),
            "decoder_position": _import_feature("targets_position"),
        }

    logits, loss = transformer_model.call_simple(
        inputs=inputs,
        targets=targets,
        compute_loss=True,
        mode=mode,
        variable_dtype=variable_dtype,
        **position_kwargs)

    if use_tpu and logits is not None:
        logits = mtf.anonymize(logits)

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # TRAIN mode
    if is_training:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = optimizer.apply_grads(
            var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
    if not use_tpu:
        tf_loss = tf.Print(
            tf_loss, [tf_loss, tf.train.get_global_step()], "step, tf_loss")

    if logits and not is_training:
        tf_logits = lowering.export_to_tf_tensor(logits)

    if is_training:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_hook = tf.train.CheckpointSaverHook(
            model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[mtf.MtfCheckpointSaverListener(lowering)])

    if is_training:
        if use_tpu:
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    elif mode == tf.estimator.ModeKeys.EVAL:

        def padded_neg_log_perplexity(logits, labels):
            """Mean negative cross-entropy, weighting label 0 as zero."""
            weights = tf.to_float(tf.not_equal(labels, 0))
            xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=logits)
            return {"neg_log_perplexity": tf.metrics.mean(-xent, weights)}

        labels = lowering.export_to_tf_tensor(anon_targets)
        return tpu_estimator.TPUEstimatorSpec(
            tf.estimator.ModeKeys.EVAL,
            evaluation_hooks=[restore_hook],
            loss=tf_loss,
            eval_metrics=(padded_neg_log_perplexity, [tf_logits, labels]))
def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Creates and returns an optimizer training op.

    Args:
        mesh: the ``mtf.Mesh`` whose graph holds the trainable variables.
        loss: the mtf loss tensor to differentiate (only used when
            ``inp_var_grads`` is None).
        params: mapping of hyperparameters. ``"lr"`` and ``"train_steps"`` are
            required; ``"lr_decay_end"``, ``"lr_decay"``, ``"warmup_steps"``,
            ``"gradient_clipping"``, ``"optimizer"``, ``"weight_decay"``,
            ``"beta_1"``, ``"beta_2"``, ``"epsilon"``, ``"epsilon_1"`` and
            ``"epsilon_2"`` are optional.
        variable_dtype: object whose ``slice_dtype`` is used for the learning
            rate, the clip value and the gradient cast.
        inp_var_grads: optional precomputed gradients, one per trainable
            variable; entries that are None are skipped.

    Returns:
        A tuple ``(learning_rate, update_ops, var_grads_fp)``: the learning
        rate imported into the mesh, the optimizer update ops, and the cast
        (and, if enabled, clipped) gradients.

    Raises:
        ValueError: if ``params["optimizer"]`` is neither "adam" nor
            "adafactor" (case-insensitive).
    """
    global_step = tf.train.get_or_create_global_step()

    # set defaults
    end_step = params.get("lr_decay_end", params["train_steps"])
    lr_decay = params.get("lr_decay", "cosine")
    warmup_steps = params.get("warmup_steps", 3000)
    gradient_clipping = params.get("gradient_clipping", 1.0)
    optimizer_name = params.get("optimizer", "adam")

    slice_dtype = variable_dtype.slice_dtype
    learning_rate = tf.constant(value=params["lr"], shape=[], dtype=slice_dtype)

    if inp_var_grads is None:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    # Drop variables that received no gradient so the optimizer never sees None.
    valid_grads_vars = [
        (grad, var)
        for grad, var in zip(var_grads, mesh.graph.trainable_variables)
        if grad is not None
    ]
    valid_vars = [var for _, var in valid_grads_vars]
    valid_grad = [grad for grad, _ in valid_grads_vars]

    tf.logging.info(list(zip(
        var_grads, [v.outputs[0] for v in mesh.graph.trainable_variables])))

    # Cast to full precision
    var_grads_fp = [mtf.cast(v, slice_dtype) for v in valid_grad]

    if lr_decay == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            end_step,
            # Decrease to 10% of initial LR according to GPT-3 paper
            end_learning_rate=params["lr"] * 0.1,
            power=1.0,
            cycle=False)
    elif lr_decay == "cosine":
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            end_step,
            alpha=0.1  # Alpha is min lr value as a fraction of init lr.
        )

    if warmup_steps > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(warmup_steps, dtype=tf.int32)

        dtype = slice_dtype
        global_steps_float = tf.cast(global_steps_int, dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, dtype)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        # Blend: scaled rate while step < warmup_steps, decayed rate afterwards.
        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    learning_rate = mtf.import_fully_replicated(
        mesh, learning_rate, mtf.Shape([]), name="learning_rate")
    scalar_summary("lr", learning_rate)

    if optimizer_name.lower() == "adam":
        optimizer = mtf.optimize.AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params.get("weight_decay", 0.0),
            beta_1=params.get("beta_1", 0.9),
            beta_2=params.get("beta_2", 0.999),
            epsilon=params.get("epsilon", 1e-6),
            exclude_from_weight_decay=["norm", "bias"])
    elif optimizer_name.lower() == "adafactor":
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=learning_rate,
            decay_rate=params.get("weight_decay", 0.0),
            beta1=params.get("beta_1", 0.9),
            epsilon1=params.get("epsilon_1", 1e-30),
            epsilon2=params.get("epsilon_2", 1e-3))
    else:
        raise ValueError(f"{optimizer_name} not recognized")

    if gradient_clipping is not None:
        # Build the clip constant only when clipping is enabled, so that
        # "gradient_clipping": None never reaches mtf.constant.
        clip_value = mtf.constant(mesh, gradient_clipping, dtype=slice_dtype)
        (var_grads_fp, _) = clip_by_global_norm(
            var_grads_fp, clip_norm=clip_value)

    update_ops = optimizer.apply_grads(var_grads_fp, valid_vars)
    return learning_rate, update_ops, var_grads_fp
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Builds ``mnist_model`` on a Mesh-TensorFlow graph, asks
    ``mtf.auto_mtf.layout`` for layout rules, lowers onto four GPUs, and
    returns the spec for the requested mode.

    Args:
        features: input features, forwarded to ``mnist_model``.
        labels: labels, forwarded to ``mnist_model`` and used for accuracy.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: only logged.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for TRAIN, PREDICT or EVAL.
    """
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mesh)

    # Log every variable registered on the graph so far.
    variables = graph._all_variables
    tf.logging.info("[variable num]: {}".format(len(variables)))
    for v in variables:
        tf.logging.info(
            "[variable] (name, shape): ({},{})".format(v.name,v.shape))

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    # Layout rules come from the automatic search. A hand-written set of
    # (dimension name, mesh axis name) pairs could be used here instead.
    layout_rules = mtf.auto_mtf.layout(graph, mesh_shape, [logits, loss])
    print("[auto mtf search] strategy: {}".format(layout_rules))

    mesh_size = mesh_shape.size  # not used below; the device list is fixed
    mesh_devices = ["gpu:0", "gpu:1","gpu:2", "gpu:3"]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    if is_training:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = optimizer.apply_grads(
            var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)

    if is_training:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

        # No CheckpointSaverHook is created in this variant; only the restore
        # and logging hooks are attached below.
        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")
        logged_tensors = {'loss': 'cross_entropy','acc':'train_accuracy'}
        logging_hook = tf.train.LoggingTensorHook(
            every_n_iter=100, tensors=logged_tensors)

        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook,logging_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={"accuracy": eval_accuracy})
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Builds ``mnist_model`` on a fixed 1x6 GPU mesh. In TRAIN mode the
    gradients are lowered once, passed through ``hvd.allreduce``, imported
    back into the mesh graph and then applied; the graph is lowered a second
    time to obtain the final ops.

    Args:
        features: input features, forwarded to ``mnist_model``.
        labels: labels, forwarded to ``mnist_model`` and used for accuracy.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: only logged.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for TRAIN, PREDICT or EVAL.
    """
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mesh)

    mesh_shape = [("gpu_rows", 1), ("gpu_cols", 6)]
    # auto_mtf layout search is not used here: it depends on ortools, which
    # won't build on power.
    layout_rules = [
        ("filters1", "gpu_rows"),
        ("filters2", "gpu_cols"),
        ("filters3", "gpu_rows"),
        ("filters4", "gpu_cols"),
        ("filters5", "gpu_rows"),
        ("filters6", "gpu_cols"),
    ]
    mesh_devices = ["gpu:%d" % itm for itm in range(6)]
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    if is_training:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()

        # First lowering: export each gradient so Horovod can reduce it.
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        tf_grads = [lowering.export_to_tf_tensor(grad) for grad in var_grads]
        tf_avg_grads = [hvd.allreduce(grad) for grad in tf_grads]

        # Bring the reduced gradients back into the mesh graph.
        var_avg_grads = [
            mtf.import_tf_tensor(mesh, tf_avg_grad, shape=grad.shape)
            for tf_avg_grad, grad in zip(tf_avg_grads, var_grads)
        ]
        update_ops = optimizer.apply_grads(
            var_avg_grads, graph.trainable_variables)

    # NOTE(review): the scope is assumed to wrap only this lowering, so that
    # variables from the first lowering are reused -- confirm against the
    # original indentation.
    with tf.variable_scope('', reuse=tf.AUTO_REUSE) as _:
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)

    if is_training:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            "/tmp/%s" % FLAGS.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={"accuracy": eval_accuracy})
def get_optimizer(mesh, loss, params, variable_dtype, inp_var_grads=None):
    """Creates and returns an optimizer training op.

    Args:
        mesh: the ``mtf.Mesh`` whose graph holds the trainable variables.
        loss: the mtf loss tensor to differentiate (only used when
            ``inp_var_grads`` is None).
        params: mapping of hyperparameters. Required keys: ``"lr"``,
            ``"gradient_clipping"``, ``"train_steps"``, ``"lr_decay"``,
            ``"warmup_steps"``, ``"opt_name"``, ``"weight_decay"``,
            ``"beta1"``, ``"use_tpu"``, plus ``"beta2"``/``"epsilon"`` for
            Adam or ``"ada_epsilon1"``/``"ada_epsilon2"`` for Adafactor.
            ``"lr_decay_end"`` is optional.
        variable_dtype: object whose ``slice_dtype`` is used for the learning
            rate, the clip value and the gradient cast.
        inp_var_grads: optional precomputed gradients, one per trainable
            variable.

    Returns:
        A tuple ``(learning_rate, update_ops, var_grads_fp)``: the learning
        rate imported into the mesh, the optimizer update ops, and the cast
        (and, if enabled, clipped) gradients.
    """
    global_step = tf.train.get_or_create_global_step()
    slice_dtype = variable_dtype.slice_dtype
    learning_rate = tf.constant(value=params["lr"], shape=[], dtype=slice_dtype)
    gradient_clipping = params["gradient_clipping"]

    if inp_var_grads is None:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in mesh.graph.trainable_variables])
    else:
        var_grads = inp_var_grads

    # Cast to full precision
    var_grads_fp = [mtf.cast(v, slice_dtype) for v in var_grads]

    # decrease LR to final lr (lr*0.1) by this step - defaults to train_steps
    end_step = params.get("lr_decay_end", params["train_steps"])

    if params["lr_decay"] == "linear":
        learning_rate = tf.train.polynomial_decay(
            learning_rate,
            global_step,
            end_step,
            # Decrease to 10% of initial LR according to GPT-3 paper
            end_learning_rate=params["lr"] * 0.1,
            power=1.0,
            cycle=False)
    elif params["lr_decay"] == "cosine":
        learning_rate = tf.train.cosine_decay(
            learning_rate,
            global_step,
            end_step,
            alpha=0.1  # Alpha is min lr value as a fraction of init lr.
        )

    if params["warmup_steps"] > 0:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(params["warmup_steps"], dtype=tf.int32)

        dtype = slice_dtype
        global_steps_float = tf.cast(global_steps_int, dtype)
        warmup_steps_float = tf.cast(warmup_steps_int, dtype)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = learning_rate * warmup_percent_done

        # Blend: scaled rate while step < warmup_steps, decayed rate afterwards.
        is_warmup = tf.cast(global_steps_int < warmup_steps_int, dtype)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    learning_rate = mtf.import_fully_replicated(
        mesh, learning_rate, mtf.Shape([]), name="learning_rate")
    mtf.scalar_summary("lr", learning_rate)

    if params["opt_name"].lower() == "adam":
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=params["weight_decay"],
            beta_1=params["beta1"],
            beta_2=params["beta2"],
            epsilon=params["epsilon"],
            exclude_from_weight_decay=["norm", "bias"],
            variable_dtype=variable_dtype)
    else:
        # Use the scheduled rate, as the Adam branch does, so the decay and
        # warmup computed above (and the "lr" summary) match what Adafactor
        # actually applies.
        optimizer = mtf.optimize.AdafactorOptimizer(
            learning_rate=learning_rate,
            decay_rate=params["weight_decay"],
            beta1=params["beta1"],
            epsilon1=params["ada_epsilon1"],
            epsilon2=params["ada_epsilon2"])

    if params["use_tpu"]:
        # NOTE(review): this wraps a Mesh-TensorFlow optimizer, while
        # CrossShardOptimizer is documented for tf.train optimizers -- confirm
        # that apply_grads below still reaches the wrapped optimizer on the
        # TensorFlow version in use.
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    if gradient_clipping is not None:
        # Build the clip constant only when clipping is enabled, so that a
        # None setting never reaches mtf.constant.
        clip_value = mtf.constant(mesh, gradient_clipping, dtype=slice_dtype)
        (var_grads_fp, _) = clip_by_global_norm(
            var_grads_fp, clip_norm=clip_value)

    update_ops = optimizer.apply_grads(
        var_grads_fp, mesh.graph.trainable_variables)
    return learning_rate, update_ops, var_grads_fp
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Builds ``cifar_model`` on a Mesh-TensorFlow graph, trains it with
    Adafactor under a staircase exponential learning-rate decay, and returns
    the spec for the requested mode.

    Args:
        features: input features; ``features['label']`` replaces ``labels``
            after the model has been built.
        labels: labels, forwarded to ``cifar_model``.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: only logged.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for TRAIN, PREDICT or EVAL.
    """
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = cifar_model(features, labels, mesh)

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_size = mesh_shape.size
    # Empty device strings are used here. For manual placement (e.g. GPU 1, 2)
    # use ['GPU:' + str(i) for i in range(mesh_size)] instead.
    mesh_devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    # From here on the metrics use the labels carried in the features dict.
    labels = features['label']

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    if is_training:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])

        # Variables that affect learning rate.
        num_batches_per_epoch = (
            NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size)
        decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(
            INITIAL_LEARNING_RATE,
            global_step,
            decay_steps,
            LEARNING_RATE_DECAY_FACTOR,
            staircase=True)
        tf.summary.scalar('learning_rate', lr)

        mtf_lr = mtf.import_tf_tensor(
            mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
        optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=mtf_lr)
        update_ops = optimizer.apply_grads(
            var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
        tf.summary.scalar("loss", tf_loss)

    if is_training:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={"accuracy": eval_accuracy})
def model_fn(features, labels, mode, params):
    """The model_fn argument for creating an Estimator.

    Builds ``mnist_model`` on a Mesh-TensorFlow graph using the mesh shape and
    layout given by flags, and returns the spec for the requested mode.

    Args:
        features: input features, forwarded to ``mnist_model``.
        labels: labels, forwarded to ``mnist_model`` and used for accuracy.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: only logged.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for TRAIN, PREDICT or EVAL.
    """
    tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                    (features, labels, mode, params))
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    # wrapped graph named "my_mesh"
    mesh = mtf.Mesh(graph, "my_mesh")
    logits, loss = mnist_model(features, labels, mesh)

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    mesh_size = mesh_shape.size
    print("mesh_shape.size = ", mesh_shape.size)
    mesh_devices = [""] * mesh_size
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, mesh_devices)

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    if is_training:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = optimizer.apply_grads(
            var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    restore_hook = mtf.MtfRestoreHook(lowering)

    tf_logits = lowering.export_to_tf_tensor(logits)
    if mode != tf.estimator.ModeKeys.PREDICT:
        tf_loss = lowering.export_to_tf_tensor(loss)
        tf.summary.scalar("loss", tf_loss)

    if is_training:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            FLAGS.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

        accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))

        # Name tensors to be logged with LoggingTensorHook.
        tf.identity(tf_loss, "cross_entropy")
        tf.identity(accuracy[1], name="train_accuracy")

        # Save accuracy scalar to Tensorboard output.
        tf.summary.scalar("train_accuracy", accuracy[1])

        # restore_hook must come before saver_hook
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "classes": tf.argmax(tf_logits, axis=1),
            "probabilities": tf.nn.softmax(tf_logits),
        }
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions=predictions,
            prediction_hooks=[restore_hook],
            export_outputs={
                "classify": tf.estimator.export.PredictOutput(predictions)
            })

    if mode == tf.estimator.ModeKeys.EVAL:
        eval_accuracy = tf.metrics.accuracy(
            labels=labels, predictions=tf.argmax(tf_logits, axis=1))
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=tf_loss,
            evaluation_hooks=[restore_hook],
            eval_metric_ops={"accuracy": eval_accuracy})
def _model_fn(input_fea, input_lab):
    """Creates a model, add summary, modes (train or eval), and hooks.

    Reads ``train_or_eval``, ``num_hosts``, ``cpu_devices``, ``d_assignment``,
    ``captured_hooks`` and ``captured_output_dtypes_shapes`` from the
    enclosing scope.

    Args:
        input_fea: input features, forwarded to
            ``unet.unet_with_spatial_partition``.
        input_lab: input labels, forwarded to
            ``unet.unet_with_spatial_partition``.

    Returns:
        In 'train' mode, the grouped update op. Otherwise the op returned by
        ``tpu_ops.outfeed_enqueue_tuple`` for the predictions, loss and
        global step.
    """

    def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
        """Add all summaries.

        Exports every non-``tf.Tensor`` entry of ``scalars`` to TensorFlow (in
        place), then writes the loss and scalars from the host. Returns the
        loss tensor produced by the host call.
        """
        # list(...) snapshots the keys before entries are replaced.
        for k in list(scalars):
            if not isinstance(scalars[k], tf.Tensor):
                scalars[k] = tf.cast(
                    lowering.export_to_tf_tensor(scalars[k]), tf.float32)

        def _host_loss_summary(global_step, tf_loss, **scalars):
            """Add summary.scalar in host side."""
            gs = tf.cast(global_step, tf.int64)
            sum_loss = tf.contrib.summary.scalar(
                '{}_loss'.format(train_or_eval), tf_loss, step=gs)
            sum_ops = [sum_loss.op]
            # .items() works on both Python 2 and 3; .iteritems() is Python 2
            # only and raises AttributeError on Python 3.
            for description, tf_metric in scalars.items():
                sum_metric = tf.contrib.summary.scalar(
                    '{}_{}'.format(train_or_eval, description),
                    tf_metric,
                    step=gs)
                sum_ops.append(sum_metric)
            with tf.control_dependencies(sum_ops):
                return tf.identity(tf_loss)

        # Cast the global step to tf.int32, since
        # outside_compilation does not support tf.int64.
        tf_loss = tpu.outside_compilation(
            _host_loss_summary,
            tf.cast(global_step, tf.int32),
            tf_loss,
            **scalars)

        return tf_loss

    global_step = tf.train.get_or_create_global_step()
    graph = mtf.Graph()

    # Worker 0 caches all the TPU binaries.
    replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
    worker0_mem = replica_cache_size * 8 * num_hosts
    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
    tf.logging.info('cpu_devices: {}, devices_mem: {}'.format(
        cpu_devices, devices_memory_usage))
    var_placer = mtf.utils.BalancedVariablePlacer(
        cpu_devices, devices_memory_usage)

    mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = unet.get_layout()
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, layout_rules, None, d_assignment)

    with mtf.utils.outside_all_rewrites():  # Do not tpu_rewrite this part.
        preds, loss, scalars, bn_update_ops = (
            unet.unet_with_spatial_partition(
                mesh, train_or_eval, input_fea, input_lab))

    if train_or_eval == 'train':
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])

        # Step-wise decay: multiply by lr_drop_rate every lr_drop_steps steps.
        lr = FLAGS.lr * tf.pow(
            FLAGS.lr_drop_rate,
            tf.floor(tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
        scalars['learning_rate'] = lr

        optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
        update_ops = optimizer.apply_grads(
            var_grads, graph.trainable_variables)

        # This is where the actual tf graph got built.
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf_update_ops.extend(
            [lowering.lowered_operation(op) for op in bn_update_ops])
        tf_update_ops_group = tf.group(tf_update_ops)

    else:  # train_or_eval == 'eval':
        preds = [mtf.anonymize(pred) for pred in preds]

        # This is where the actual tf graph got built.
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_preds = [
            tf.cast(lowering.export_to_tf_tensor(pred), tf.float32)
            for pred in preds
        ]

    tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
    if FLAGS.write_summary:
        tf_loss = _add_summary(
            lowering, train_or_eval, tf_loss, scalars, global_step)
    master_to_slice_hook = mtf.MtfRestoreHook(lowering)

    if train_or_eval == 'train':
        with mtf.utils.outside_all_rewrites():
            saver = tf.train.Saver(
                tf.global_variables(), save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            slice_to_master_hook = tf.train.CheckpointSaverHook(
                FLAGS.checkpoint_dir,
                save_steps=FLAGS.save_checkpoints_steps,
                saver=saver,
                listeners=[saver_listener])
            captured_hooks.capture(
                [master_to_slice_hook, slice_to_master_hook])
            # NOTE(review): tf_loss (with its summary dependencies) is not
            # part of the returned group -- confirm the summaries are still
            # executed in train mode.
            return tf_update_ops_group

    else:  # train_or_eval == 'eval':
        tf_preds.extend([tf_loss, global_step])
        tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
        tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
        captured_hooks.capture([master_to_slice_hook, None])
        captured_output_dtypes_shapes.capture(
            [tf_preds_dtypes, tf_preds_shapes])
        return tpu_ops.outfeed_enqueue_tuple(tf_preds)
def estimator_model_fn(cls, hparams, features, labels, mode, config=None,
                       params=None, decode_hparams=None):
    """Build the Estimator spec for a Mesh-TensorFlow model class.

    Args:
        cls: the model class; instantiated with the (copied) hparams.
        hparams: hyperparameters; deep-copied before being modified.
        features: input features, forwarded to the model.
        labels: labels, forwarded to ``estimator_spec_eval``.
        mode: a ``tf.estimator.ModeKeys`` value.
        config: optional run config; ``config.data_parallelism`` is read when
            not running on TPU.
        params: optional mapping; ``"use_tpu"`` and ``"context"`` are read.
        decode_hparams: optional decode hyperparameters, merged into
            ``hparams`` in PREDICT mode.

    Returns:
        Whatever ``estimator_spec_predict`` / ``estimator_spec_eval`` return
        for PREDICT / EVAL; otherwise a ``TPUEstimatorSpec`` (TPU) or
        ``tf.estimator.EstimatorSpec`` built for TRAIN.
    """
    hparams = copy.deepcopy(hparams)
    use_tpu = params and params.get("use_tpu", False)
    hparams.use_tpu = use_tpu

    # merge decode_hparams into hparams if present
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
        for k, v in six.iteritems(decode_hparams.values()):
            if hasattr(hparams, k) and getattr(hparams, k) != v:
                tf.logging.warning(
                    "Overriding hparams.%s with %s from decode_hparams"
                    % (k, v))
            setattr(hparams, k, v)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
        data_parallelism = config.data_parallelism
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)

    if use_tpu:
        ctx = params["context"]
        host_count = ctx.num_hosts
        place_on_host = ctx.tpu_host_placement_function
        device_list = [place_on_host(host_id=t) for t in range(host_count)]
        # TODO(ylc): Better estimation of replica cache size?
        per_replica_cache = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        host_memory_usage = (
            [per_replica_cache * ctx.num_replicas] + [0] * (host_count - 1))
        var_placer = mtf.utils.BalancedVariablePlacer(
            device_list, host_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
    else:
        var_placer = None
        if len(data_parallelism.ps_devices) == 1:
            mesh_devices = [""] * mesh_shape.size
        else:
            # One parameter-server device is expected per mesh processor.
            assert len(data_parallelism.ps_devices) == mesh_shape.size
            mesh_devices = data_parallelism.ps_devices
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
        return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
        logits = mtf.anonymize(logits)

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # TRAIN mode
    if is_training:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        lr = learning_rate.learning_rate_schedule(hparams)
        tf.summary.scalar("learning_rate", lr)
        mtf_lr = mtf.import_tf_tensor(
            mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
        optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
        update_ops = []
        for grad, var in zip(var_grads, graph.trainable_variables):
            update_ops.extend(optimizer.apply_grad(grad, var))

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
    if logits and not is_training:
        tf_logits = lowering.export_to_tf_tensor(logits)

    if is_training:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_hook = tf.train.CheckpointSaverHook(
            hparams.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[mtf.MtfCheckpointSaverListener(lowering)])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
        tf_logits = lowering.export_to_tf_tensor(logits)
        return model.estimator_spec_eval(
            features, tf_logits, labels, tf_loss, restore_hook, use_tpu)

    if use_tpu:
        # TPU host call. Important: need to be called before remove_summaries()
        host_call = None
        if hparams.tpu_enable_host_call:
            host_call = t2t_model.create_host_call(hparams.model_dir)

        t2t_model.remove_summaries()
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            host_call=host_call,
            training_hooks=[restore_hook, saver_hook])

    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        training_chief_hooks=[restore_hook, saver_hook])
def model_fn(features, labels, mode, params):
    """Build the TPUEstimatorSpec for the toy model; called by TPUEstimator.

    Constructs a Mesh TensorFlow graph around `toy_model`, lowers it onto a
    SIMD mesh implementation and wraps the resulting TF tensors and ops in a
    `TPUEstimatorSpec`.

    Args:
      features: input features, forwarded unchanged to `toy_model`.
      labels: unused; deleted immediately.
      mode: a `tf.estimator.ModeKeys` value. Only TRAIN and EVAL are
        supported.
      params: mapping whose `'context'` entry exposes `device_assignment`,
        which is handed to `SimdMeshImpl`.

    Returns:
      A `tpu_estimator.TPUEstimatorSpec` for the requested mode.

    Raises:
      ValueError: if `mode` is neither TRAIN nor EVAL. Previously such modes
        built the whole graph and then silently returned None.
    """
    del labels  # Unused.

    # Fail fast on modes this function cannot produce a spec for, instead of
    # building and lowering a graph only to fall off the end returning None.
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    if not is_training and mode != tf.estimator.ModeKeys.EVAL:
        raise ValueError(
            'model_fn only supports TRAIN and EVAL modes, got: {!r}'.format(
                mode))

    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh')
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    mesh_devices = [''] * mesh_shape.size
    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape, mtf.convert_to_layout_rules(FLAGS.layout),
        mesh_devices, params['context'].device_assignment)
    with mtf.utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    if is_training:
        # One gradient per trainable variable, then the optimizer's update
        # ops for each (gradient, variable) pair.
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        optimizer = mtf.optimize.AdafactorOptimizer()
        update_ops = []
        for grad, var in zip(var_grads, graph.trainable_variables):
            update_ops.extend(optimizer.apply_grad(grad, var))
    else:
        # For now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    # Lower the mtf graph onto the mesh implementation and pull out the TF
    # tensors/ops needed for the spec.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_loss = lowering.export_to_tf_tensor(loss)

    if is_training:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        # Advance the global step together with the variable updates.
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)

        if is_training:
            saver = tf.train.Saver(
                tf.global_variables(),
                sharded=True,
                max_to_keep=10,
                keep_checkpoint_every_n_hours=2,
                defer_build=False,
                save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])

        # EVAL mode (the only other mode accepted by the check above).
        def metric_fn(tf_logits):
            # NOTE(review): the metric key is spelled 'mean_logitss'; kept
            # unchanged because it is a runtime-visible name.
            mean_logitss = tf.metrics.mean(tf_logits)
            return {'mean_logitss': mean_logitss}

        eval_metrics = (metric_fn, [tf_logits])
        return tpu_estimator.TPUEstimatorSpec(
            tf.estimator.ModeKeys.EVAL,
            evaluation_hooks=[restore_hook],
            loss=tf_loss,
            eval_metrics=eval_metrics)