def train_and_eval():
  """Trains and evaluates MeshTensorflow model without TPUEstimator.

  TODO(lehou): Pack everything nicely as a set of APIs.
  """
  mesh_context = None
  tf.logging.info('FLAGS.master: {}'.format(FLAGS.master))
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.master)
  tf.config.experimental_connect_to_cluster(resolver)
  config = tf.ConfigProto()
  config.allow_soft_placement = True
  cluster_spec = resolver.cluster_spec()
  if cluster_spec:
    config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
  with tf.Session(target=resolver.master(), config=config) as sess:
    tf.tpu.experimental.initialize_tpu_system(resolver)
    mesh_context = MeshContext(
        sess, FLAGS.use_tpu, FLAGS.mesh_shape, unet.get_layout())

  for _ in range(FLAGS.num_training_loops):
    _train_phase(mesh_context)
    _eval_phase(mesh_context)

  if FLAGS.use_tpu:
    _shutdown()

  tf.logging.info('finished.')
def train_and_eval():
  """Trains and evaluates MeshTensorflow model without TPUEstimator.

  TODO(lehou): Pack everything nicely as a set of APIs.
  """
  mesh_context = None
  tf.logging.info('FLAGS.master: {}'.format(FLAGS.master))
  with tf.Session(target=FLAGS.master,
                  config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    mesh_context = MeshContext(
        sess, FLAGS.use_tpu, FLAGS.mesh_shape, unet.get_layout())

  for _ in range(FLAGS.num_training_loops):
    _train_phase(mesh_context)
    _eval_phase(mesh_context)

  if FLAGS.use_tpu:
    _shutdown()

  tf.logging.info('finished.')
def train_and_eval():
  """Trains and evaluates MeshTensorflow model without TPUEstimator.

  TODO(lehou): Pack everything nicely as a set of APIs.
  """
  tf.logging.info('FLAGS.master: {}'.format(FLAGS.master))

  # Open a session to get the list of CPU devices to hold master variables.
  with tf.Session(target=FLAGS.master,
                  config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    topology = sess.run(tpu.initialize_system())
    cpu_devices = _list_cpu_devices(sess)

  topo_object = tf.contrib.tpu.Topology(serialized=topology)
  num_cores = int(np.prod(topo_object.mesh_shape))
  num_hosts = int(topo_object.num_tasks)
  num_cores_per_host = int(num_cores // num_hosts)
  assert num_cores_per_host == int(topo_object.num_tpus_per_task)

  # Get a device_assignment object for mtf.
  d_assignment = device_assignment.device_assignment(
      topology, computation_shape=[1, 1, 1], num_replicas=num_cores)

  # Get mesh_impl.
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = unet.get_layout()
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, layout_rules, None, d_assignment)

  for _ in range(FLAGS.num_training_loops):
    _train_phase(mesh_impl, cpu_devices, d_assignment, num_hosts, num_cores)
    _eval_phase(mesh_impl, cpu_devices, d_assignment, num_hosts, num_cores)

  _shutdown()

  tf.logging.info('finished.')
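# _list_cpu_devices is called above but not defined in this excerpt. A minimal
# sketch, assuming the helper only needs to enumerate the CPU device names
# visible to the session (the real implementation may also normalize or sort
# the names before returning them):
def _list_cpu_devices(sess):
  """Returns the names of the CPU devices visible to `sess`."""
  return [d.name for d in sess.list_devices() if d.device_type == 'CPU']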
def _model_fn(input_fea, input_lab):
  """Creates a model, adds summaries, handles train or eval modes, and hooks."""

  def _add_summary(lowering, train_or_eval, tf_loss, scalars, global_step):
    """Add all summaries."""
    for k in scalars.keys():
      if not isinstance(scalars[k], tf.Tensor):
        scalars[k] = tf.cast(
            lowering.export_to_tf_tensor(scalars[k]), tf.float32)

    def _host_loss_summary(global_step, tf_loss, **scalars):
      """Add summary.scalar in host side."""
      gs = tf.cast(global_step, tf.int64)
      sum_loss = tf.contrib.summary.scalar(
          '{}_loss'.format(train_or_eval), tf_loss, step=gs)
      sum_ops = [sum_loss.op]
      for description, tf_metric in scalars.items():
        sum_metric = tf.contrib.summary.scalar(
            '{}_{}'.format(train_or_eval, description), tf_metric, step=gs)
        sum_ops.append(sum_metric)
      with tf.control_dependencies(sum_ops):
        return tf.identity(tf_loss)

    # Cast the global step to tf.int32, since
    # outside_compilation does not support tf.int64.
    tf_loss = tpu.outside_compilation(
        _host_loss_summary,
        tf.cast(global_step, tf.int32),
        tf_loss,
        **scalars)

    return tf_loss

  global_step = tf.train.get_or_create_global_step()
  graph = mtf.Graph()

  # Worker 0 caches all the TPU binaries.
  replica_cache_size = 300 * 1024 * 1024  # 300M per replica.
  worker0_mem = replica_cache_size * 8 * num_hosts
  devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)

  tf.logging.info('cpu_devices: {}, devices_mem: {}'.format(
      cpu_devices, devices_memory_usage))
  var_placer = mtf.utils.BalancedVariablePlacer(
      cpu_devices, devices_memory_usage)

  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = unet.get_layout()
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
      mesh_shape, layout_rules, None, d_assignment)

  with mtf.utils.outside_all_rewrites():  # Do not tpu_rewrite this part.
    preds, loss, scalars, bn_update_ops = (
        unet.unet_with_spatial_partition(
            mesh, train_or_eval, input_fea, input_lab))

  if train_or_eval == 'train':
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])

    lr = FLAGS.lr * tf.pow(
        FLAGS.lr_drop_rate,
        tf.floor(tf.cast(global_step, tf.float32) / FLAGS.lr_drop_steps))
    scalars['learning_rate'] = lr

    optimizer = mtf.optimize.AdafactorOptimizer(learning_rate=lr)
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

    # This is where the actual tf graph got built.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    tf_update_ops.extend(
        [lowering.lowered_operation(op) for op in bn_update_ops])
    tf_update_ops_group = tf.group(tf_update_ops)

  else:  # train_or_eval == 'eval':
    preds = [mtf.anonymize(pred) for pred in preds]

    # This is where the actual tf graph got built.
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_preds = [
        tf.cast(lowering.export_to_tf_tensor(pred), tf.float32)
        for pred in preds
    ]

  tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
  if FLAGS.write_summary:
    tf_loss = _add_summary(
        lowering, train_or_eval, tf_loss, scalars, global_step)
  master_to_slice_hook = mtf.MtfRestoreHook(lowering)

  if train_or_eval == 'train':
    with mtf.utils.outside_all_rewrites():
      saver = tf.train.Saver(
          tf.global_variables(), save_relative_paths=True)
      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
      slice_to_master_hook = tf.train.CheckpointSaverHook(
          FLAGS.checkpoint_dir,
          save_steps=FLAGS.save_checkpoints_steps,
          saver=saver,
          listeners=[saver_listener])
      captured_hooks.capture([master_to_slice_hook, slice_to_master_hook])

    return tf_update_ops_group

  else:  # train_or_eval == 'eval':
    tf_preds.extend([tf_loss, global_step])
    tf_preds_dtypes = [tf_pred.dtype for tf_pred in tf_preds]
    tf_preds_shapes = [tf_pred.shape for tf_pred in tf_preds]
    captured_hooks.capture([master_to_slice_hook, None])
    captured_output_dtypes_shapes.capture(
        [tf_preds_dtypes, tf_preds_shapes])

    return tpu_ops.outfeed_enqueue_tuple(tf_preds)
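# The eval branch above only enqueues tf_preds on the TPU outfeed; a host-side
# graph has to dequeue them using the dtypes and shapes saved in
# captured_output_dtypes_shapes. The helper below is a minimal sketch of such
# a consumer (the helper name and argument list are assumptions, not part of
# the original module); it reuses the tpu_ops module referenced above.
def _build_outfeed_dequeue(output_dtypes, output_shapes, host_device,
                           device_ordinal):
  """Dequeues one tuple of eval outputs from a single core's outfeed."""
  with tf.device(host_device):
    return tpu_ops.outfeed_dequeue_tuple(
        dtypes=output_dtypes,
        shapes=output_shapes,
        device_ordinal=device_ordinal)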