def _run_train_phase():
  """Build the training graph and run one loop of training iterations.

  On TPU the model step returned by `_get_model_fn` is replicated over
  `mesh_context.num_cores` cores and fed by a `SimdMeshImplInputReader`.
  Otherwise it is placed via a `PlacementMeshImplInputReader`. The step is
  then executed `FLAGS.num_train_iterations_per_loop` times inside a
  `tf.train.MonitoredTrainingSession`.
  """
  # Input pipeline pieces shared by both back-ends.
  dataset_fn = unet.get_dataset_creator('train')
  input_shapes = unet.get_input_mtf_shapes('train')

  train_fn, hooks_holder, _ = _get_model_fn('train', mesh_context)

  if not FLAGS.use_tpu:
    gpu_reader = input_reader.PlacementMeshImplInputReader(
        mesh_context.mesh_impl, dataset_fn, input_shapes, is_eval_mode=False)
    step_op = gpu_reader.gpu_placement(train_fn)
  else:
    assert mesh_context.device_assignment
    assert mesh_context.num_cores
    tpu_reader = input_reader.SimdMeshImplInputReader(
        mesh_context.mesh_impl,
        dataset_fn,
        input_shapes,
        external_worker=(not FLAGS.on_gcp),
        is_eval_mode=False)
    # One replica per core; inputs arrive via the infeed queue, not args.
    step_op = tpu.replicate(
        computation=train_fn,
        inputs=[[]] * mesh_context.num_cores,
        infeed_queue=tpu_reader.infeed_queue,
        device_assignment=mesh_context.device_assignment)

  ###########################################################
  # Training.
  master_to_slice_hook, slice_to_master_hook = hooks_holder.get()
  session_hooks = [
      _CkptLoaderHook(),
      master_to_slice_hook,
      slice_to_master_hook,
      tf.train.StepCounterHook(every_n_steps=10),
  ]

  if FLAGS.write_summary:
    flush_op = contrib_summary.flush()

  scaffold = _get_scaffold(additional_initializers=[])
  with tf.train.MonitoredTrainingSession(
      master=master,
      scaffold=scaffold,
      hooks=session_hooks,
      config=config) as sess:
    if FLAGS.write_summary:
      contrib_summary.initialize(session=sess)

    # TPU: start feeding exactly as many batches as steps we will run.
    if not FLAGS.use_tpu:
      gpu_reader.initialize(sess)
    else:
      tpu_reader.start_infeed_thread(
          sess, FLAGS.num_train_iterations_per_loop)

    for step in range(FLAGS.num_train_iterations_per_loop):
      sess.run(step_op)
      if FLAGS.write_summary:
        sess.run(flush_op)
      tf.logging.info('train steps: {}'.format(step))
def _run_train_phase():
  """Build the replicated TPU training step and run one training loop.

  The model step from `_get_model_fn` is replicated over `num_cores` TPU
  cores and fed through a `SimdMeshImplInputReader` infeed queue. It is then
  executed `FLAGS.num_train_iterations_per_loop` times inside a
  `tf.train.MonitoredTrainingSession`.
  """
  # Input pipeline: dataset factory plus the mtf shapes of its outputs.
  dataset_fn = unet.get_dataset_creator('train')
  input_shapes = unet.get_input_mtf_shapes('train')
  infeed_reader = input_reader.SimdMeshImplInputReader(
      mesh_impl, dataset_fn, input_shapes, is_eval_mode=False)

  train_fn, hooks_holder, _ = _get_model_fn(
      'train', cpu_devices, d_assignment, num_hosts)
  # One replica per core; inputs arrive via the infeed queue, not args.
  replicated_step = tpu.replicate(
      computation=train_fn,
      inputs=[[]] * num_cores,
      infeed_queue=infeed_reader.infeed_queue,
      device_assignment=d_assignment)

  ###########################################################
  # Training.
  master_to_slice_hook, slice_to_master_hook = hooks_holder.get()
  session_hooks = [
      _CkptLoaderHook(),
      master_to_slice_hook,
      slice_to_master_hook,
  ]

  if FLAGS.write_summary:
    flush_op = tf.contrib.summary.flush()

  scaffold = _get_scaffold(additional_initializers=[])
  with tf.train.MonitoredTrainingSession(
      master=FLAGS.master,
      scaffold=scaffold,
      hooks=session_hooks,
      config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    if FLAGS.write_summary:
      tf.contrib.summary.initialize(session=sess)
    infeed_reader.start_infeed_thread(
        sess, FLAGS.num_train_iterations_per_loop)

    for step in range(FLAGS.num_train_iterations_per_loop):
      sess.run(replicated_step)
      if FLAGS.write_summary:
        sess.run(flush_op)
      tf.logging.info('train steps: {}'.format(step))
def _run_eval_phase():
  """Build the evaluation graph and run one loop of evaluation iterations.

  On TPU the eval step is replicated over every core. Its outputs are read
  back through per-core outfeed dequeue ops, and only the first op's (core
  0's) full output is kept. Otherwise the step is placed via a
  `PlacementMeshImplInputReader` and returns its outputs directly.

  Each step's results are passed to `unet.PostProcessor.record` together
  with `FLAGS.pred_output_dir`. `finish()` is called once after the loop.
  """
  # Setup input pipeline.
  ds_creator = unet.get_dataset_creator('eval')
  mtf_shapes = unet.get_input_mtf_shapes('eval')

  model_eval_fn, eval_hooks, output_dtypes_shapes = _get_model_fn(
      'eval', mesh_context)

  if FLAGS.use_tpu:
    assert mesh_context.device_assignment
    assert mesh_context.num_cores
    # Derive the worker naming from FLAGS.on_gcp exactly as the train phase
    # does. Previously the eval reader used its default and the outfeed
    # device hard-coded external_worker=True, which disagrees with the
    # train phase whenever FLAGS.on_gcp is set.
    external_worker = not FLAGS.on_gcp
    simd_input_reader = input_reader.SimdMeshImplInputReader(
        mesh_context.mesh_impl,
        ds_creator,
        mtf_shapes,
        external_worker=external_worker,
        is_eval_mode=True)
    # One replica per core; inputs arrive via the infeed queue, not args.
    eval_computation = tpu.replicate(
        computation=model_eval_fn,
        inputs=[[]] * mesh_context.num_cores,
        infeed_queue=simd_input_reader.infeed_queue,
        device_assignment=mesh_context.device_assignment)

    output_dtypes, output_shapes = output_dtypes_shapes.get()
    # Create one outfeed dequeue op per core, placed on that core's host.
    outfeed_dequeue_ops = []
    for host_id in range(mesh_context.num_hosts):
      # pylint: disable=protected-access
      host_device = input_reader._host_id_to_tf_device(
          host_id, external_worker=external_worker)
      with ops.device(host_device):
        for device_ordinal in range(mesh_context.num_cores_per_host):
          outfeed_dequeue_op = tpu_ops.outfeed_dequeue_tuple(
              dtypes=output_dtypes,
              shapes=output_shapes,
              device_ordinal=device_ordinal)
          # We don't need output other than from core 0. Ops for the other
          # cores are still created (presumably so their outfeeds are
          # drained) but reduced to scalar means.
          if outfeed_dequeue_ops:
            outfeed_dequeue_ops.append(
                [tf.reduce_mean(x) for x in outfeed_dequeue_op])
          else:
            outfeed_dequeue_ops.append(outfeed_dequeue_op)
  else:
    # NOTE(review): the TPU branch builds its reader with is_eval_mode=True,
    # but this one uses is_eval_mode=False (same as training). Confirm that
    # is intended for evaluation.
    placement_input_reader = input_reader.PlacementMeshImplInputReader(
        mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=False)
    eval_computation = placement_input_reader.gpu_placement(model_eval_fn)

  ###########################################################
  # Evaluation.
  master_to_slice_hook, _ = eval_hooks.get()
  ckpt_loader_hook = _CkptLoaderHook()
  all_hooks = [ckpt_loader_hook, master_to_slice_hook]

  if FLAGS.write_summary:
    flush_summary = contrib_summary.flush()

  with tf.train.MonitoredSession(
      session_creator=tf.train.ChiefSessionCreator(
          master=FLAGS.master,
          config=tf.ConfigProto(allow_soft_placement=True)),
      hooks=all_hooks) as sess:

    if FLAGS.write_summary:
      contrib_summary.initialize(session=sess)

    if FLAGS.use_tpu:
      simd_input_reader.start_infeed_thread(
          sess, FLAGS.num_eval_iterations_per_loop)
    else:
      placement_input_reader.initialize(sess)

    pprocessor = unet.PostProcessor()
    for step in range(FLAGS.num_eval_iterations_per_loop):
      # Only get results from the 0-th core.
      if FLAGS.use_tpu:
        sess.run(eval_computation)
        results = sess.run(outfeed_dequeue_ops)[0]
      else:
        results = sess.run(eval_computation)
      pprocessor.record(results, FLAGS.pred_output_dir)

      if FLAGS.write_summary:
        sess.run(flush_summary)
      tf.logging.info('eval steps: {}'.format(step))
    pprocessor.finish()
def test_get_laidout_tensors(self, is_eval_mode):
  """Check how SimdMeshImplInputReader distributes sub-batches over 2 cores.

  The dataset creator returns a different 2-row slice of a 4x4 identity
  matrix on each call.
  - Eval mode: both cores are expected to see rows 0-1, i.e. the creator is
    called once.
  - Training mode: core 1 is expected to see rows 2-3, i.e. the creator is
    called once per sub-batch.

  The TPU system is shut down in a `finally` clause, so a failing assertion
  does not leave it initialized for later tests.
  """
  mesh_shape = "mesh_x:2, mesh_y:1"
  layout = "batch:mesh_x, io:mesh_y"
  batch_io_dim = 4
  # Rows of tf.eye(4) expected on each core.
  rows_0_1 = np.array([[1, 0, 0, 0], [0, 1, 0, 0]], dtype=np.float32)
  rows_2_3 = np.array([[0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float32)

  with tf.Session() as sess:
    topology, num_cores = self.initialize_system(sess)
    try:
      # Get a device_assignment object for mtf.
      d_assignment = device_assignment.device_assignment(
          topology,
          computation_shape=[1,] * mtf.utils.topology_rank(topology),
          num_replicas=num_cores)

      # Hacked dataset creator: creates different datasets for the first and
      # second call, in order to test SimdMeshImplInputReader.
      self.sub_batch_created_times = 0
      def stateful_ds_creator():
        whole_batch = tf.eye(batch_io_dim, dtype=tf.float32)
        sub_batch = tf.slice(whole_batch,
                             [self.sub_batch_created_times * 2, 0],
                             [2, 4])
        self.sub_batch_created_times += 1
        return tf.data.Dataset.from_tensors(sub_batch).repeat().unbatch()

      batch_dim = mtf.Dimension("batch", batch_io_dim)
      io_dim = mtf.Dimension("io", batch_io_dim)
      mtf_input_shapes = [mtf.Shape([batch_dim, io_dim])]

      # Get mesh_impl.
      mesh_shape = mtf.convert_to_shape(mesh_shape)
      layout_rules = mtf.convert_to_layout_rules(layout)
      mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
          mesh_shape, layout_rules, None, d_assignment)

      simd_input_reader = input_reader.SimdMeshImplInputReader(
          mesh_impl, stateful_ds_creator, mtf_input_shapes,
          external_worker=False,
          is_eval_mode=is_eval_mode)

      def model_fn(features):
        return features

      replicated_computation = tpu.replicate(
          computation=model_fn,
          inputs=[[]] * num_cores,
          infeed_queue=simd_input_reader.infeed_queue,
          device_assignment=d_assignment)

      simd_input_reader.start_infeed_thread(sess, 1)
      results = sess.run(replicated_computation)
      print("results: {}".format(results))

      core_0_data = results[0][0]
      core_1_data = results[1][0]
      print("core_0_data: {}".format(core_0_data))
      print("core_1_data: {}".format(core_1_data))

      # Core 0 always receives the first sub-batch.
      self.assertAllClose(rows_0_1, core_0_data)
      if is_eval_mode:
        # If there is only one dataset object, then the stateful_ds_creator()
        # should be called only once.
        self.assertAllClose(rows_0_1, core_1_data)
      else:
        # If there are two dataset objects, then the stateful_ds_creator()
        # should be called twice.
        self.assertAllClose(rows_2_3, core_1_data)
    finally:
      # Always release the TPU system, even when an assertion above fails.
      sess.run(tf.tpu.shutdown_system())