def single_step_fn():
    """Build the replicated training computation for a single TPU step.

    Fetches one entry per core from ``multi_device_iterator``, flattens each
    core's features and labels, and appends the result to
    ``per_host_sharded_inputs``. Two model functions are then defined; they
    differ only in whether ``params['dynamic_image_size']`` is reversed
    before the model is built. ``tf.cond`` picks which one is replicated,
    based on a comparison of two dimensions of the feature shape.

    Returns:
        The value produced by ``tf.cond`` over the two ``tpu.replicate``
        branches.
    """
    per_core_batches = multi_device_iterator.get_next()
    for core_index in range(FLAGS.num_cores):
        features_shape, core_features, core_labels = (
            per_core_batches[core_index])
        per_host_sharded_inputs.append(
            inputs_flattener.flatten_features_and_labels(
                core_features, core_labels))

    # `features_shape` here is the one unpacked for the last core above.
    # Which pair of dimensions is compared depends on the input layout.
    if params['transpose_input']:
        height_dim, width_dim = features_shape[0], features_shape[1]
    else:
        height_dim, width_dim = features_shape[1], features_shape[2]
    is_height_short_side = tf.less(height_dim, width_dim)

    def height_short_side_model_fn(*args):
        """Model function for input images with height on the short side."""
        feats, labs = inputs_flattener.unflatten_features_and_labels(args)
        feats, labs = _set_feature_and_label_shapes(feats, labs, params)
        model_spec = mask_rcnn_model.mask_rcnn_model_fn(
            feats, labs, tf.estimator.ModeKeys.TRAIN, params)
        captured_scaffold_fn.capture(model_spec.scaffold_fn)
        return model_spec.train_op

    def height_long_side_model_fn(*args):
        """Model function for input images with height on the long side."""
        feats, labs = inputs_flattener.unflatten_features_and_labels(args)
        # Work on a private copy of params whose dynamic image size is
        # reversed, leaving the shared `params` untouched.
        swapped_params = copy.deepcopy(params)
        original_size = swapped_params['dynamic_image_size']
        swapped_params['dynamic_image_size'] = original_size[::-1]
        feats, labs = _set_feature_and_label_shapes(
            feats, labs, swapped_params)
        model_spec = mask_rcnn_model.mask_rcnn_model_fn(
            feats, labs, tf.estimator.ModeKeys.TRAIN, swapped_params)
        captured_scaffold_fn.capture(model_spec.scaffold_fn)
        return model_spec.train_op

    def _replicate_short_side():
        return tpu.replicate(
            height_short_side_model_fn, per_host_sharded_inputs)

    def _replicate_long_side():
        return tpu.replicate(
            height_long_side_model_fn, per_host_sharded_inputs)

    return tf.cond(
        is_height_short_side, _replicate_short_side, _replicate_long_side)
def _run_train_phase():
    """Run one loop of training steps.

    Builds the input reader and the training computation for either the TPU
    path or the placement path (selected by ``FLAGS.use_tpu``), then executes
    ``FLAGS.num_train_iterations_per_loop`` steps inside a
    ``tf.train.MonitoredTrainingSession``. When ``FLAGS.write_summary`` is
    set, summaries are initialized once and flushed after every step.
    """
    # Setup input pipeline.
    dataset_creator = unet.get_dataset_creator('train')
    input_shapes = unet.get_input_mtf_shapes('train')
    model_train_fn, train_hooks, _ = _get_model_fn('train', mesh_context)

    use_tpu = FLAGS.use_tpu
    write_summary = FLAGS.write_summary
    num_steps = FLAGS.num_train_iterations_per_loop

    if use_tpu:
        assert mesh_context.device_assignment
        assert mesh_context.num_cores
        simd_input_reader = input_reader.SimdMeshImplInputReader(
            mesh_context.mesh_impl,
            dataset_creator,
            input_shapes,
            external_worker=(not FLAGS.on_gcp),
            is_eval_mode=False)
        # Replicas receive no explicit inputs; the reader's infeed queue is
        # passed to replicate() instead.
        train_computation = tpu.replicate(
            computation=model_train_fn,
            inputs=[[]] * mesh_context.num_cores,
            infeed_queue=simd_input_reader.infeed_queue,
            device_assignment=mesh_context.device_assignment)
    else:
        placement_input_reader = input_reader.PlacementMeshImplInputReader(
            mesh_context.mesh_impl,
            dataset_creator,
            input_shapes,
            is_eval_mode=False)
        train_computation = placement_input_reader.gpu_placement(
            model_train_fn)

    ###########################################################
    # Training.
    master_to_slice_hook, slice_to_master_hook = train_hooks.get()
    session_hooks = [
        _CkptLoaderHook(),
        master_to_slice_hook,
        slice_to_master_hook,
        tf.train.StepCounterHook(every_n_steps=10),
    ]

    # The flush op is only built (and only used) when summaries are enabled.
    flush_summary = contrib_summary.flush() if write_summary else None

    with tf.train.MonitoredTrainingSession(
            master=master,
            scaffold=_get_scaffold(additional_initializers=[]),
            hooks=session_hooks,
            config=config) as sess:
        if write_summary:
            contrib_summary.initialize(session=sess)

        if not use_tpu:
            placement_input_reader.initialize(sess)
        else:
            simd_input_reader.start_infeed_thread(sess, num_steps)

        for step in range(num_steps):
            sess.run(train_computation)
            if write_summary:
                sess.run(flush_summary)
            tf.logging.info('train steps: {}'.format(step))
def freeze_graph_tpu(model_path):
    """Custom freeze_graph implementation for Cloud TPU.

    Builds one replica of ``tpu_model_inference_fn`` per TPU core
    (``FLAGS.num_tpu_cores``), restores the checkpoint at ``model_path``,
    converts the graph's variables to constants and writes the frozen graph,
    together with its metadata, to ``model_path + '.minigo'``.

    Args:
        model_path: Checkpoint path to restore from; also used as the prefix
            of the output file.

    Raises:
        AssertionError: If ``model_path`` or ``FLAGS.tpu_name`` is empty.
            These are plain ``assert`` statements, so they are skipped when
            Python runs with ``-O``.
    """
    assert model_path
    assert FLAGS.tpu_name

    # FLAGS.tpu_name is either a ready-to-use gRPC URL or a TPU name that has
    # to be resolved to its master address.
    if FLAGS.tpu_name.startswith('grpc://'):
        tpu_grpc_url = FLAGS.tpu_name
    else:
        tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=None, project=None)
        tpu_grpc_url = tpu_cluster_resolver.get_master()

    sess = tf.Session(tpu_grpc_url)
    # Always release the session, including when graph construction, the
    # checkpoint restore or the freeze step raises.
    try:
        output_names = []
        with sess.graph.as_default():
            # Replicate the inference function for each TPU core.
            replicated_features = []
            feature_type = tf.bool if FLAGS.bool_features else tf.float32
            for i in range(FLAGS.num_tpu_cores):
                name = 'pos_tensor_%d' % i
                features = tf.placeholder(feature_type, [None], name=name)
                replicated_features.append((features,))
            outputs = contrib_tpu.replicate(
                tpu_model_inference_fn, replicated_features)

            # The replicate op assigns names like output_0_shard_0 to the
            # output names. Give them human readable names.
            for i, (policy_output, value_output, _) in enumerate(outputs):
                policy_name = 'policy_output_%d' % i
                value_name = 'value_output_%d' % i
                output_names.extend([policy_name, value_name])
                tf.identity(policy_output, policy_name)
                tf.identity(value_output, value_name)

            tf.train.Saver().restore(sess, model_path)

        out_graph = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), output_names)
    finally:
        sess.close()

    metadata = make_model_metadata({
        'engine': 'tpu',
        'num_replicas': FLAGS.num_tpu_cores,
    })

    minigo_model.write_graph_def(out_graph, metadata, model_path + '.minigo')
def _run_train_phase():
    """Run one loop of training steps on TPU.

    Creates the SIMD input reader, replicates the training model function
    across ``num_cores`` with the reader's infeed queue, and executes
    ``FLAGS.num_train_iterations_per_loop`` steps inside a
    ``tf.train.MonitoredTrainingSession``. When ``FLAGS.write_summary`` is
    set, summaries are initialized once and flushed after every step.
    """
    # Setup input pipeline.
    dataset_creator = unet.get_dataset_creator('train')
    input_shapes = unet.get_input_mtf_shapes('train')
    simd_input_reader = input_reader.SimdMeshImplInputReader(
        mesh_impl, dataset_creator, input_shapes, is_eval_mode=False)

    model_train_fn, train_hooks, _ = _get_model_fn(
        'train', cpu_devices, d_assignment, num_hosts)

    # Replicas receive no explicit inputs; the reader's infeed queue is
    # passed to replicate() instead.
    tpu_train_computation = tpu.replicate(
        computation=model_train_fn,
        inputs=[[]] * num_cores,
        infeed_queue=simd_input_reader.infeed_queue,
        device_assignment=d_assignment)

    ###########################################################
    # Training.
    master_to_slice_hook, slice_to_master_hook = train_hooks.get()
    session_hooks = [
        _CkptLoaderHook(),
        master_to_slice_hook,
        slice_to_master_hook,
    ]

    write_summary = FLAGS.write_summary
    num_steps = FLAGS.num_train_iterations_per_loop

    # The flush op is only built (and only used) when summaries are enabled.
    flush_summary = tf.contrib.summary.flush() if write_summary else None

    with tf.train.MonitoredTrainingSession(
            master=FLAGS.master,
            scaffold=_get_scaffold(additional_initializers=[]),
            hooks=session_hooks,
            config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        if write_summary:
            tf.contrib.summary.initialize(session=sess)

        simd_input_reader.start_infeed_thread(sess, num_steps)

        for step in range(num_steps):
            sess.run(tpu_train_computation)
            if write_summary:
                sess.run(flush_summary)
            tf.logging.info('train steps: {}'.format(step))
def _run_eval_phase():
    """Run one loop of evaluation steps and post-process the results.

    Builds the input reader and the evaluation computation for either the
    TPU path or the placement path (selected by ``FLAGS.use_tpu``). On TPU,
    an outfeed dequeue op is also created for every core of every host.
    ``FLAGS.num_eval_iterations_per_loop`` steps are then executed inside a
    ``tf.train.MonitoredSession``; each step's results are handed to a
    ``unet.PostProcessor``, which is finished after the last step.
    """
    # Setup input pipeline.
    dataset_creator = unet.get_dataset_creator('eval')
    input_shapes = unet.get_input_mtf_shapes('eval')
    model_eval_fn, eval_hooks, output_dtypes_shapes = _get_model_fn(
        'eval', mesh_context)

    use_tpu = FLAGS.use_tpu
    write_summary = FLAGS.write_summary
    num_steps = FLAGS.num_eval_iterations_per_loop

    if use_tpu:
        assert mesh_context.device_assignment
        assert mesh_context.num_cores
        simd_input_reader = input_reader.SimdMeshImplInputReader(
            mesh_context.mesh_impl,
            dataset_creator,
            input_shapes,
            is_eval_mode=True)
        eval_computation = tpu.replicate(
            computation=model_eval_fn,
            inputs=[[]] * mesh_context.num_cores,
            infeed_queue=simd_input_reader.infeed_queue,
            device_assignment=mesh_context.device_assignment)

        output_dtypes, output_shapes = output_dtypes_shapes.get()

        # Create outfeed_dequeue_ops: one entry per (host, core) pair.
        outfeed_dequeue_ops = []
        for host_id in range(mesh_context.num_hosts):
            # pylint: disable=protected-access
            host_device = input_reader._host_id_to_tf_device(
                host_id, external_worker=True)
            with ops.device(host_device):
                for device_ordinal in range(mesh_context.num_cores_per_host):
                    dequeued = tpu_ops.outfeed_dequeue_tuple(
                        dtypes=output_dtypes,
                        shapes=output_shapes,
                        device_ordinal=device_ordinal)
                    # We don't need output other than from core 0: only the
                    # first dequeue op is kept as-is, every later one is
                    # reduced with tf.reduce_mean.
                    if not outfeed_dequeue_ops:
                        outfeed_dequeue_ops.append(dequeued)
                    else:
                        outfeed_dequeue_ops.append(
                            [tf.reduce_mean(tensor) for tensor in dequeued])
    else:
        # NOTE(review): the TPU branch above passes is_eval_mode=True while
        # this branch passes False -- confirm this is intended.
        placement_input_reader = input_reader.PlacementMeshImplInputReader(
            mesh_context.mesh_impl,
            dataset_creator,
            input_shapes,
            is_eval_mode=False)
        eval_computation = placement_input_reader.gpu_placement(model_eval_fn)

    ###########################################################
    # Evaluation.
    master_to_slice_hook, _ = eval_hooks.get()
    session_hooks = [_CkptLoaderHook(), master_to_slice_hook]

    # The flush op is only built (and only used) when summaries are enabled.
    flush_summary = contrib_summary.flush() if write_summary else None

    session_creator = tf.train.ChiefSessionCreator(
        master=FLAGS.master,
        config=tf.ConfigProto(allow_soft_placement=True))

    with tf.train.MonitoredSession(
            session_creator=session_creator, hooks=session_hooks) as sess:
        if write_summary:
            contrib_summary.initialize(session=sess)

        if not use_tpu:
            placement_input_reader.initialize(sess)
        else:
            simd_input_reader.start_infeed_thread(sess, num_steps)

        pprocessor = unet.PostProcessor()
        for step in range(num_steps):
            if use_tpu:
                sess.run(eval_computation)
                # Only get results from the 0-th core.
                results = sess.run(outfeed_dequeue_ops)[0]
            else:
                results = sess.run(eval_computation)
            pprocessor.record(results, FLAGS.pred_output_dir)

            if write_summary:
                sess.run(flush_summary)
            tf.logging.info('eval steps: {}'.format(step))
        pprocessor.finish()