def _run_train_phase():
    """Build the training graph and run one loop of training iterations.

    Two kinds of input pipeline are supported:
      * TPU: a ``SimdMeshImplInputReader`` feeding ``tpu.replicate``.
      * Otherwise: a ``PlacementMeshImplInputReader`` with GPU placement.

    The step is then executed ``FLAGS.num_train_iterations_per_loop`` times
    inside a ``tf.train.MonitoredTrainingSession``. That session carries the
    checkpoint-loader hook, the two hooks returned by the model function, and
    a step counter. When ``FLAGS.write_summary`` is set, summaries are
    initialized once and flushed after every step.
    """
    # ---- Input pipeline ------------------------------------------------
    dataset_creator = unet.get_dataset_creator('train')
    input_shapes = unet.get_input_mtf_shapes('train')
    train_fn, hooks_future, _ = _get_model_fn('train', mesh_context)

    if not FLAGS.use_tpu:
        reader = input_reader.PlacementMeshImplInputReader(
            mesh_context.mesh_impl, dataset_creator, input_shapes,
            is_eval_mode=False)
        step_op = reader.gpu_placement(train_fn)
    else:
        assert mesh_context.device_assignment
        assert mesh_context.num_cores
        reader = input_reader.SimdMeshImplInputReader(
            mesh_context.mesh_impl, dataset_creator, input_shapes,
            external_worker=(not FLAGS.on_gcp), is_eval_mode=False)
        step_op = tpu.replicate(
            computation=train_fn,
            inputs=[[]] * mesh_context.num_cores,
            infeed_queue=reader.infeed_queue,
            device_assignment=mesh_context.device_assignment)

    # ---- Training ------------------------------------------------------
    to_slice_hook, to_master_hook = hooks_future.get()
    # List elements are evaluated left to right, so the checkpoint-loader
    # hook is still constructed before the step counter hook.
    session_hooks = [
        _CkptLoaderHook(),
        to_slice_hook,
        to_master_hook,
        tf.train.StepCounterHook(every_n_steps=10),
    ]

    # The flush op must exist in the graph before the session is created.
    flush_op = contrib_summary.flush() if FLAGS.write_summary else None

    with tf.train.MonitoredTrainingSession(
            master=master,
            scaffold=_get_scaffold(additional_initializers=[]),
            hooks=session_hooks,
            config=config) as sess:
        if FLAGS.write_summary:
            contrib_summary.initialize(session=sess)

        num_steps = FLAGS.num_train_iterations_per_loop
        if not FLAGS.use_tpu:
            reader.initialize(sess)
        else:
            reader.start_infeed_thread(sess, num_steps)

        for step in range(num_steps):
            sess.run(step_op)
            if FLAGS.write_summary:
                sess.run(flush_op)
            tf.logging.info('train steps: {}'.format(step))
def _run_eval_phase():
    """Build the evaluation graph and run one loop of evaluation iterations.

    The input pipeline is either a TPU infeed (``SimdMeshImplInputReader``
    plus ``tpu.replicate``, with explicit outfeed dequeue ops) or a GPU
    placement reader (``PlacementMeshImplInputReader``). Both are built in
    evaluation mode.

    The eval step is run ``FLAGS.num_eval_iterations_per_loop`` times inside
    a ``tf.train.MonitoredSession``. Each step's results are passed, together
    with ``FLAGS.pred_output_dir``, to a ``unet.PostProcessor`` via
    ``record``. ``finish()`` is called once after the loop. When
    ``FLAGS.write_summary`` is set, summaries are initialized once and
    flushed after every step.
    """
    # Setup input pipeline.
    ds_creator = unet.get_dataset_creator('eval')
    mtf_shapes = unet.get_input_mtf_shapes('eval')
    model_eval_fn, eval_hooks, output_dtypes_shapes = _get_model_fn(
        'eval', mesh_context)

    if FLAGS.use_tpu:
        assert mesh_context.device_assignment
        assert mesh_context.num_cores
        simd_input_reader = input_reader.SimdMeshImplInputReader(
            mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=True)
        eval_computation = tpu.replicate(
            computation=model_eval_fn,
            inputs=[[]] * mesh_context.num_cores,
            infeed_queue=simd_input_reader.infeed_queue,
            device_assignment=mesh_context.device_assignment)

        output_dtypes, output_shapes = output_dtypes_shapes.get()
        outfeed_dequeue_ops = []

        # Create outfeed_dequeue_ops: one dequeue per core on every host.
        # NOTE(review): the training phase passes
        # external_worker=(not FLAGS.on_gcp) to its reader, while this phase
        # uses the reader's default and hard-codes external_worker=True
        # below. Confirm this is intended when FLAGS.on_gcp is set.
        for host_id in range(mesh_context.num_hosts):
            # pylint: disable=protected-access
            with ops.device(input_reader._host_id_to_tf_device(
                    host_id, external_worker=True)):
                for device_ordinal in range(mesh_context.num_cores_per_host):
                    outfeed_dequeue_op = tpu_ops.outfeed_dequeue_tuple(
                        dtypes=output_dtypes,
                        shapes=output_shapes,
                        device_ordinal=device_ordinal)
                    # We don't need output other than from core 0. Only the
                    # first dequeue (host 0, ordinal 0) is kept in full. The
                    # others are still dequeued but reduced to scalar means,
                    # presumably to drain every core's outfeed while keeping
                    # the fetched data small.
                    if outfeed_dequeue_ops:
                        outfeed_dequeue_ops.append(
                            [tf.reduce_mean(x) for x in outfeed_dequeue_op])
                    else:
                        outfeed_dequeue_ops.append(outfeed_dequeue_op)
    else:
        # This is the evaluation phase, so build the reader in eval mode,
        # matching the TPU branch above. It previously passed
        # is_eval_mode=False, copied from the training phase.
        placement_input_reader = input_reader.PlacementMeshImplInputReader(
            mesh_context.mesh_impl, ds_creator, mtf_shapes, is_eval_mode=True)
        eval_computation = placement_input_reader.gpu_placement(model_eval_fn)

    ###########################################################
    # Evaluation.
    master_to_slice_hook, _ = eval_hooks.get()
    ckpt_loader_hook = _CkptLoaderHook()
    all_hooks = [ckpt_loader_hook, master_to_slice_hook]

    # The flush op must exist in the graph before the session is created.
    if FLAGS.write_summary:
        flush_summary = contrib_summary.flush()

    with tf.train.MonitoredSession(
            session_creator=tf.train.ChiefSessionCreator(
                master=FLAGS.master,
                config=tf.ConfigProto(allow_soft_placement=True)),
            hooks=all_hooks) as sess:

        if FLAGS.write_summary:
            contrib_summary.initialize(session=sess)

        if FLAGS.use_tpu:
            simd_input_reader.start_infeed_thread(
                sess, FLAGS.num_eval_iterations_per_loop)
        else:
            placement_input_reader.initialize(sess)

        pprocessor = unet.PostProcessor()
        for step in range(FLAGS.num_eval_iterations_per_loop):
            # Only get results from the 0-th core.
            if FLAGS.use_tpu:
                sess.run(eval_computation)
                results = sess.run(outfeed_dequeue_ops)[0]
            else:
                results = sess.run(eval_computation)
            pprocessor.record(results, FLAGS.pred_output_dir)

            if FLAGS.write_summary:
                sess.run(flush_summary)
            tf.logging.info('eval steps: {}'.format(step))
        pprocessor.finish()
def __init__(self):
    """Build a dense autoencoder for 16x16x3 images in a private graph.

    Graph contents:
      * Encoder: flattens the input, then two 500-unit ReLU dense layers,
        then a ``z_dim`` (= 10) unit ReLU code layer.
      * Decoder: mirrors the encoder and ends in a linear
        16 * 16 * 3 layer reshaped back to ``[-1, 16, 16, 3]``.
      * Loss: mean squared error between the input images and the sigmoid
        of the decoder logits, minimized with Adam on a global step.

    Summaries go to ``logs/autoencoder-<z_dim>-<timestamp>`` via
    ``tfsum.create_file_writer``:
      * loss scalar and latent histogram, recorded every 100 global steps;
      * generated-image summary fed through a placeholder, always recorded.

    All variables are initialized, and the summary writer is initialized
    with the session graph, before returning.

    Public attributes: ``session``, ``images``, ``z``, ``generated_logits``,
    ``generated_images``, ``loss``, ``training``, ``summaries``,
    ``generated_images_summary_data``, ``generated_images_summary``.
    """
    parallelism = 8
    graph = tf.Graph()
    session_config = tf.ConfigProto(
        inter_op_parallelism_threads=parallelism,
        intra_op_parallelism_threads=parallelism)
    self.session = tf.Session(graph=graph, config=session_config)

    with graph.as_default():
        self.images = tf.placeholder(
            tf.float32, shape=[None, 16, 16, 3], name="images")
        z_dim = 10

        def encoder(img):
            hidden = tf.layers.flatten(img)
            for _ in range(2):
                hidden = tf.layers.dense(hidden, 500, activation=tf.nn.relu)
            return tf.layers.dense(hidden, z_dim, activation=tf.nn.relu)

        def decoder(code):
            hidden = code
            for _ in range(2):
                hidden = tf.layers.dense(hidden, 500, activation=tf.nn.relu)
            flat = tf.layers.dense(hidden, 16 * 16 * 3, activation=None)
            return tf.reshape(flat, [-1, 16, 16, 3])

        self.z = encoder(self.images)
        self.generated_logits = decoder(self.z)
        self.generated_images = tf.nn.sigmoid(
            self.generated_logits, name="generated_images")
        self.loss = tf.losses.mean_squared_error(
            self.images, self.generated_images)

        step = tf.train.create_global_step()
        optimizer = tf.train.AdamOptimizer()
        self.training = optimizer.minimize(self.loss, global_step=step)

        # Log directory is unique per run thanks to the timestamp suffix.
        stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
        logdir = "logs/autoencoder-{}-{}".format(z_dim, stamp)
        writer = tfsum.create_file_writer(logdir, flush_millis=10 * 1000)

        with writer.as_default():
            with tfsum.record_summaries_every_n_global_steps(100):
                self.summaries = [
                    tfsum.scalar("loss", self.loss),
                    tfsum.histogram("latent", self.z),
                ]

        self.generated_images_summary_data = tf.placeholder(
            tf.float32, [None, None, 3])
        with writer.as_default():
            with tfsum.always_record_summaries():
                # tfsum.image takes a batch, so add a leading axis.
                batched_image = tf.expand_dims(
                    self.generated_images_summary_data, axis=0)
                self.generated_images_summary = tfsum.image(
                    "generated_image", batched_image)

        self.session.run(tf.global_variables_initializer())
        with writer.as_default():
            tfsum.initialize(session=self.session, graph=self.session.graph)
# Register summaries for the model outputs, then run the training loop,
# periodically printing the loss and saving weights to sfm.h5.
# NOTE(review): `summary` presumably is tf.contrib.summary, and obj_p, flow,
# f0, f1, f1_t, depth, cam_p, loss, optimize, model, models_path and S_max
# are defined earlier in the script (not visible here).

# Histograms: first object-mask output and the two flow channels
# (channel 0 tagged "flow_x_hist", channel 1 "flow_y_hist").
summary.histogram("obj masks", obj_p[0])
summary.histogram("flow_x_hist", flow[:, :, :, 0], family="flow")
summary.histogram("flow_y_hist", flow[:, :, :, 1], family="flow")
# Image summaries, at most 3 images each. The cast_* helpers are defined
# elsewhere; presumably they convert each tensor to a displayable image.
summary.image("frame0", cast_im(f0), max_images=3)
summary.image("frame1", cast_im(f1), max_images=3)
summary.image("frame1_t", cast_im(f1_t), max_images=3)
summary.image("depth", cast_depth(depth), max_images=3)
summary.image("optical_flow", cast_flow(flow), max_images=3)
summary.image("object masks", cast_im(obj_p[0]), max_images=3)
obj_summary(obj_p)
cam_summary(cam_p)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # No session is passed here; presumably this relies on `sess` being the
    # default session inside this `with` block.
    summary.initialize(graph=tf.get_default_graph())
    # Weights are loaded after the global initializer ran, so the initializer
    # does not reset them (assuming `model` uses this session's variables;
    # confirm).
    model.load_weights(os.path.join(models_path, "sfm.h5"))
    for s in range(S_max):
        # One optimization step. All summary ops are fetched alongside it,
        # and only the loss value is kept.
        l, *_ = sess.run(
            [loss, optimize, summary.all_summary_ops()])
        # beholder.update(session=sess)
        if s % 50 == 0:
            print("Iteration: {} Loss: {}".format(s, l))
        # Save every 5000 iterations, skipping iteration 0.
        # NOTE(review): progress after the last multiple of 5000 is not saved
        # by this loop; confirm a final save happens elsewhere.
        if s % 5000 == 0 and not s == 0:
            model.save_weights(os.path.join(models_path, "sfm.h5"))