Code Example #1
def train_ffn(model_cls, **model_kwargs):
    with tf.Graph().as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            # The constructor might define TF ops/placeholders, so it is important
            # that the FFN is instantiated within the current context.

            model = model_cls(with_membrane=FLAGS.with_membrane,
                              is_training=not FLAGS.validation_mode,
                              adabn=FLAGS.adabn,
                              grad_clip_val=FLAGS.cap_gradient,
                              **model_kwargs)
            eval_shape_zyx = train_eval_size(model).tolist()[::-1]

            eval_tracker = EvalTracker(eval_shape_zyx)
            load_data_ops = define_data_input(model, queue_batch=1)
            prepare_ffn(model)

            merge_summaries_op = tf.summary.merge_all()

            if FLAGS.task == 0:
                save_flags()

            # Start supervisor.
            train_dir = FLAGS.train_dir
            summary_rate_secs = 999999  # effectively disables periodic summaries
            save_model_secs = 30 if FLAGS.topup_mode else 999999
            sv = tf.train.Supervisor(logdir=train_dir,
                                     is_chief=(FLAGS.task == 0),
                                     saver=model.saver,
                                     summary_op=None,
                                     save_summaries_secs=summary_rate_secs,
                                     save_model_secs=save_model_secs,
                                     recovery_wait_secs=5)
            sess = sv.prepare_or_wait_for_session(
                FLAGS.master,
                config=tf.ConfigProto(log_device_placement=False,
                                      allow_soft_placement=True))
            eval_tracker.sess = sess

            # TODO (jk): load from ckpt
            step = int(sess.run(model.global_step))
            step_since_session_start = 0

            if FLAGS.task > 0:
                # To avoid early instabilities when using multiple replicas, we use
                # a launch schedule where new replicas are brought online gradually.
                logging.info('Delaying replica start.')
                while True:
                    if (int(sess.run(model.global_step)) >=
                            FLAGS.replica_step_delay * FLAGS.task):
                        break
                    time.sleep(5.0)
            else:
                summary_writer = tf.summary.FileWriterCache.get(
                    FLAGS.train_dir)
                summary_writer.add_session_log(
                    tf.summary.SessionLog(status=tf.summary.SessionLog.START),
                    step)

            fov_shifts = list(model.shifts)  # x, y, z
            if FLAGS.shuffle_moves:
                random.shuffle(fov_shifts)

            policy_map = {
                'fixed': partial(fixed_offsets, fov_shifts=fov_shifts),
                'max_pred_moves': max_pred_offsets
            }
            batch_it = get_batch(lambda: sess.run(load_data_ops), eval_tracker,
                                 model, FLAGS.batch_size,
                                 policy_map[FLAGS.fov_policy])

            t_last = time.time()

            # TODO (jk): text log of learning curve. refresh file.
            max_steps = FLAGS.eval_steps

            while step_since_session_start < max_steps:
                if step % 20 == 0 and step_since_session_start > 0:
                    # TODO (jk): text log of learning curve. refresh file.
                    logging.info(
                        'Step: %d,   prec: %s,   recall: %s,   acc: %s,'
                        '   #patches: %d', step,
                        np.round(1000 * eval_tracker.tp /
                                 (eval_tracker.tp + eval_tracker.fp + 1e-6)),
                        np.round(1000 * eval_tracker.tp /
                                 (eval_tracker.tp + eval_tracker.fn + 1e-6)),
                        np.round(1000 * (eval_tracker.tp + eval_tracker.tn) /
                                 (eval_tracker.tp + eval_tracker.tn +
                                  eval_tracker.fp + eval_tracker.fn + 1e-6)),
                        eval_tracker.num_patches)

                seed, patches, labels, weights = next(batch_it)
                updated_seed, step, summ, accuracy = run_training_step(
                    sess,
                    model,
                    None,
                    None,
                    feed_dict={
                        model.loss_weights: weights,
                        model.labels: labels,
                        model.offset_label: 'off',
                        model.input_patches: patches,
                        model.input_seed: seed,
                    })

                step += 1
                step_since_session_start += 1
                mask.update_at(seed, (0, 0, 0), updated_seed)

            # RECORD RESULT
            with open(os.path.join(FLAGS.train_dir, 'eval.txt'),
                      'a') as eval_curve_txt:
                eval_curve_txt.write(
                    '\nStep: %d,   prec: %s,   recall: %s,   acc: %s,'
                    '   #patches: %d' % (
                        step,
                        eval_tracker.tp /
                        (eval_tracker.tp + eval_tracker.fp + 1e-6),
                        eval_tracker.tp /
                        (eval_tracker.tp + eval_tracker.fn + 1e-6),
                        (eval_tracker.tp + eval_tracker.tn) /
                        (eval_tracker.tp + eval_tracker.tn +
                         eval_tracker.fp + eval_tracker.fn + 1e-6),
                        eval_tracker.num_patches))
            print(' prec: %s, recall: %s, acc: %s' % (
                np.round(1000 * eval_tracker.tp /
                         (eval_tracker.tp + eval_tracker.fp + 1e-6)),
                np.round(1000 * eval_tracker.tp /
                         (eval_tracker.tp + eval_tracker.fn + 1e-6)),
                np.round(1000 * (eval_tracker.tp + eval_tracker.tn) /
                         (eval_tracker.tp + eval_tracker.tn +
                          eval_tracker.fp + eval_tracker.fn + 1e-6))))
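The precision/recall/accuracy arithmetic above is repeated three times inline. As a reading aid, here is a minimal helper that factors it out; this is a sketch, not part of the original code, and assumes only that the tracker exposes tp/fp/tn/fn counters like the EvalTracker used above (the function name is hypothetical):

def eval_metrics(tracker, eps=1e-6):
    # Mirrors the inline arithmetic above; eps guards against division by zero.
    tp, fp, tn, fn = tracker.tp, tracker.fp, tracker.tn, tracker.fn
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    accuracy = (tp + tn) / (tp + tn + fp + fn + eps)
    return precision, recall, accuracy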
Code Example #2
def train_ffn(model_cls, **model_kwargs):

    with tf.Graph().as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            # The constructor might define TF ops/placeholders, so it is important
            # that the FFN is instantiated within the current context.
            model = model_cls(**model_kwargs)
            eval_shape_zyx = train_eval_size(model).tolist()[::-1]

            eval_tracker = EvalTracker(eval_shape_zyx)
            load_data_ops = define_data_input(model, queue_batch=1)
            prepare_ffn(model)
            merge_summaries_op = tf.summary.merge_all()

            if FLAGS.task == 0:
                save_flags()

            summary_writer = None
            saver = tf.train.Saver(keep_checkpoint_every_n_hours=0.25)
            scaffold = tf.train.Scaffold(saver=saver)
            with tf.train.MonitoredTrainingSession(
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    save_summaries_steps=None,
                    save_checkpoint_secs=300,
                    config=tf.ConfigProto(log_device_placement=False,
                                          allow_soft_placement=True),
                    checkpoint_dir=FLAGS.train_dir,
                    scaffold=scaffold) as sess:

                eval_tracker.sess = sess
                step = int(sess.run(model.global_step))

                if FLAGS.task > 0:
                    # To avoid early instabilities when using multiple replicas, we use
                    # a launch schedule where new replicas are brought online gradually.
                    logging.info('Delaying replica start.')
                    while step < FLAGS.replica_step_delay * FLAGS.task:
                        time.sleep(5.0)
                        step = int(sess.run(model.global_step))
                else:
                    summary_writer = tf.summary.FileWriterCache.get(
                        FLAGS.train_dir)
                    summary_writer.add_session_log(
                        tf.summary.SessionLog(
                            status=tf.summary.SessionLog.START), step)

                fov_shifts = list(model.shifts)  # x, y, z
                if FLAGS.shuffle_moves:
                    random.shuffle(fov_shifts)

                policy_map = {
                    'fixed': partial(fixed_offsets, fov_shifts=fov_shifts),
                    'max_pred_moves': max_pred_offsets
                }
                batch_it = get_batch(lambda: sess.run(load_data_ops),
                                     eval_tracker, model, FLAGS.batch_size,
                                     policy_map[FLAGS.fov_policy])

                t_last = time.time()

                while not sess.should_stop() and step < FLAGS.max_steps:
                    # Run summaries periodically.
                    t_curr = time.time()
                    if t_curr - t_last > FLAGS.summary_rate_secs and FLAGS.task == 0:
                        summ_op = merge_summaries_op
                        t_last = t_curr
                    else:
                        summ_op = None

                    seed, patches, labels, weights = next(batch_it)

                    updated_seed, step, summ = run_training_step(
                        sess,
                        model,
                        summ_op,
                        feed_dict={
                            model.loss_weights: weights,
                            model.labels: labels,
                            model.input_patches: patches,
                            model.input_seed: seed,
                        })

                    # Save prediction results in the original seed array so that
                    # they can be used in subsequent steps.
                    mask.update_at(seed, (0, 0, 0), updated_seed)

                    # Record summaries.
                    if summ is not None:
                        logging.info('Saving summaries.')
                        summ = tf.Summary.FromString(summ)

                        # Compute a loss over the whole training patch (i.e. more than a
                        # single-step field of view of the network). This quantifies the
                        # quality of the final object mask.
                        summ.value.extend(eval_tracker.get_summaries())
                        eval_tracker.reset()

                        assert summary_writer is not None
                        summary_writer.add_summary(summ, step)

            if summary_writer is not None:
                summary_writer.flush()
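For context, here is a sketch of how a train_ffn like the one above is usually invoked. The flag name, module path, and constructor arguments are assumptions modeled on the upstream google/ffn driver, not taken from this snippet:

# Hypothetical driver; the model path, flag, and kwargs are assumptions.
import importlib
from absl import app, flags

flags.DEFINE_string('model_name', 'convstack_3d.ConvStack3DFFNModel',
                    'Dotted path of the FFN model class to train.')

def main(argv):
    module_name, cls_name = flags.FLAGS.model_name.rsplit('.', 1)
    module = importlib.import_module('ffn.training.models.' + module_name)
    train_ffn(getattr(module, cls_name),
              fov_size=[33, 33, 33], deltas=[8, 8, 8], depth=12)

if __name__ == '__main__':
    app.run(main)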
Code Example #3
def train_ffn(model_cls, **model_kwargs):
    with tf.Graph().as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            # The constructor might define TF ops/placeholders, so it is important
            # that the FFN is instantiated within the current context.

            if not FLAGS.validation_mode:
                model = model_cls(with_membrane=FLAGS.with_membrane,
                                  is_training=True,
                                  grad_clip_val=5.0,
                                  **model_kwargs)
            else:
                model = model_cls(with_membrane=FLAGS.with_membrane,
                                  is_training=False,
                                  adabn=FLAGS.adabn,
                                  **model_kwargs)

            eval_shape_zyx = train_eval_size(model).tolist()[::-1]

            eval_tracker = EvalTracker(eval_shape_zyx)
            load_data_ops = define_data_input(model, queue_batch=1)
            prepare_ffn(model)

            merge_summaries_op = tf.summary.merge_all()

            if FLAGS.task == 0:
                save_flags()

            # Start supervisor.
            if not FLAGS.validation_mode:
                save_model_secs = 1800
                summary_rate_secs = FLAGS.summary_rate_secs
                train_dir = FLAGS.train_dir
                sv = tf.train.Supervisor(logdir=train_dir,
                                         is_chief=(FLAGS.task == 0),
                                         saver=model.saver,
                                         summary_op=None,
                                         save_summaries_secs=summary_rate_secs,
                                         save_model_secs=save_model_secs,
                                         recovery_wait_secs=5)
                sess = sv.prepare_or_wait_for_session(
                    FLAGS.master,
                    config=tf.ConfigProto(log_device_placement=False,
                                          allow_soft_placement=True))
            else:
                train_dir = FLAGS.train_dir + 'v'  # separate logdir for validation
                summary_rate_secs = 999999
                save_model_secs = 999999
                sv = tf.train.Supervisor(logdir=train_dir,
                                         is_chief=(FLAGS.task == 0),
                                         saver=None,
                                         summary_op=None,
                                         save_summaries_secs=summary_rate_secs,
                                         save_model_secs=save_model_secs,
                                         recovery_wait_secs=5)
                sess = sv.prepare_or_wait_for_session(
                    FLAGS.master,
                    config=tf.ConfigProto(log_device_placement=False,
                                          allow_soft_placement=True))
            eval_tracker.sess = sess

            # TODO (jk): load from ckpt
            if FLAGS.load_from_ckpt != 'None':
                logging.info('>>>>>>>>>>>>>>>>>>>>> Loading checkpoint.')
                model.saver.restore(eval_tracker.sess, FLAGS.load_from_ckpt)
                logging.info('>>>>>>>>>>>>>>>>>>>>> Checkpoint loaded.')
            step = int(sess.run(model.global_step))
            step_since_session_start = 0

            summary_writer = None  # only set on the chief (task 0)
            if FLAGS.task > 0:
                # To avoid early instabilities when using multiple replicas, we use
                # a launch schedule where new replicas are brought online gradually.
                logging.info('Delaying replica start.')
                while True:
                    if (int(sess.run(model.global_step)) >=
                            FLAGS.replica_step_delay * FLAGS.task):
                        break
                    time.sleep(5.0)
            else:
                summary_writer = tf.summary.FileWriterCache.get(
                    FLAGS.train_dir)
                summary_writer.add_session_log(
                    tf.summary.SessionLog(status=tf.summary.SessionLog.START),
                    step)

            fov_shifts = list(model.shifts)  # x, y, z
            if FLAGS.shuffle_moves:
                random.shuffle(fov_shifts)

            policy_map = {
                'fixed': partial(fixed_offsets, fov_shifts=fov_shifts),
                'max_pred_moves': max_pred_offsets
            }
            batch_it = get_batch(lambda: sess.run(load_data_ops), eval_tracker,
                                 model, FLAGS.batch_size,
                                 policy_map[FLAGS.fov_policy])

            t_last = time.time()

            if not FLAGS.validation_mode:
                # TODO (jk): text log of learning curve. refresh file.
                learning_curve_txt = open(
                    os.path.join(FLAGS.train_dir, 'lc.txt'), "w")
                learning_curve_txt.close()
                max_steps = FLAGS.max_steps
            else:
                max_steps = step + 50000 // FLAGS.batch_size  # ~50k validation patches

            # if FLAGS.adabn:
            #     ####################### BASIC ADABN
            #     sess.run(model.ada_initializer)

            # Only used by the commented-out moment-tracking diagnostics below.
            mean1o = 0.
            mean2o = 0.
            var1o = 1.
            var2o = 1.
            while step < max_steps:
                if step % 20 == 0 and step_since_session_start > 0:
                    # TODO (jk): text log of learning curve. refresh file.
                    logging.info(
                        'Step: %d,   prec: %s,   recall: %s,   acc: %s,'
                        '   #patches: %d', step,
                        np.round(1000 * eval_tracker.tp /
                                 (eval_tracker.tp + eval_tracker.fp + 1e-6)),
                        np.round(1000 * eval_tracker.tp /
                                 (eval_tracker.tp + eval_tracker.fn + 1e-6)),
                        np.round(1000 * (eval_tracker.tp + eval_tracker.tn) /
                                 (eval_tracker.tp + eval_tracker.tn +
                                  eval_tracker.fp + eval_tracker.fn + 1e-6)),
                        eval_tracker.num_patches)

                # Run summaries periodically.
                t_curr = time.time()

                # TIME-BASED SUMMARY
                # if t_curr - t_last > FLAGS.summary_rate_secs and FLAGS.task == 0:
                #   summ_op = merge_summaries_op
                #   t_last = t_curr
                # else:
                #   summ_op = None

                # TODO (jk): ITERATION-BASED SUMMARY
                if step % 100 == 0 and step > 0:
                    summ_op = merge_summaries_op
                else:
                    summ_op = None

                if (FLAGS.validation_mode
                        and step_since_session_start % 500 == 0
                        and step_since_session_start > 0):
                    print('REFRESHING EVAL TRACKER...')
                    eval_tracker.reset()

                # if ((step_since_session_start % 5) == 0) & (model.moment_list is not None):
                #     mean1 = sess.run(model.moment_list[0].name)
                #     mean2 = sess.run(model.moment_list[-2].name)
                #     var1 = sess.run(model.moment_list[1].name)
                #     var2 = sess.run(model.moment_list[-1].name)
                #     diff = np.sum(np.square(mean1-mean1o)) + np.sum(np.square(mean2-mean2o)) + np.sum(np.square(var1-var1o)) + np.sum(np.square(var2-var2o))
                #     mean1o = mean1
                #     mean2o = mean2
                #     var1o = var1
                #     var2o = var2
                #     print('moment displacement = ' + str(diff))
                #     # print(mean1)
                #     # print(var1)

                # if FLAGS.adabn:
                #   ####################### REMOVE THIS TO TURN OFF INSTANCE NORMALIZATION
                #   sess.run(model.ada_initializer)

                # if ((step_since_session_start % 50) == 0) & (step_since_session_start > 0):
                #   sess.run(model.ada_initializer)
                #   # sess.run(model.fgru_ada_initializer)
                #   # sess.run(model.ext_ada_initializer)
                #   eval_tracker.reset()
                #   print('REFRESHING MOMENTS, eval tracker...')

                seed, patches, labels, weights = next(batch_it)
                updated_seed, step, summ, accuracy = run_training_step(
                    sess,
                    model,
                    summ_op,
                    None,
                    feed_dict={
                        model.loss_weights: weights,
                        model.labels: labels,
                        model.offset_label: 'off',
                        model.input_patches: patches,
                        model.input_seed: seed,
                    })
                if FLAGS.validation_mode:
                    step += 1  # global_step does not advance in validation mode
                step_since_session_start += 1

                # Save prediction results in the original seed array so that
                # they can be used in subsequent steps.
                mask.update_at(seed, (0, 0, 0), updated_seed)

                # Record summaries.
                if summ is not None:

                    # TODO (jk): text log of learning curve
                    precision = eval_tracker.tp / (eval_tracker.tp +
                                                   eval_tracker.fp + 0.0001)
                    recall = eval_tracker.tp / (eval_tracker.tp +
                                                eval_tracker.fn + 0.0001)
                    accuracy = (eval_tracker.tp + eval_tracker.tn) / (
                        eval_tracker.tp + eval_tracker.tn + eval_tracker.fp +
                        eval_tracker.fn + 0.0001)
                    if not FLAGS.validation_mode:
                        # Open the log only when writing, so the handle is not
                        # leaked in validation mode.
                        with open(os.path.join(FLAGS.train_dir, 'lc.txt'),
                                  'a') as learning_curve_txt:
                            learning_curve_txt.write(
                                'step_%s_precision_%s_recall_%s_accuracy_%s\n'
                                % (step, precision, recall, accuracy))
                        logging.info('Saving summaries.')
                        summ = tf.Summary.FromString(summ)

                        # Compute a loss over the whole training patch (i.e. more than a
                        # single-step field of view of the network). This quantifies the
                        # quality of the final object mask.
                        summ.value.extend(eval_tracker.get_summaries())
                        eval_tracker.reset()

                        assert summary_writer is not None
                        summary_writer.add_summary(summ, step)

                    if np.min([precision, recall]) > 0.97:
                        logging.info(
                            '>>>>>>>>>>>>>>>>>>>>> Target performance '
                            '(both prec and recall > 0.97) reached.')
                        break

                if summary_writer is not None:
                    summary_writer.flush()
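The lc.txt lines written above have a fixed step_<n>_precision_<p>_recall_<r>_accuracy_<a> shape, so they can be read back for plotting. A hypothetical parser (not part of the snippet) matching that format:

import re

LINE_RE = re.compile(r'step_(\d+)_precision_([^_]+)_recall_([^_]+)'
                     r'_accuracy_(\S+)')

def parse_learning_curve(path):
    # Returns (step, precision, recall, accuracy) tuples from lc.txt.
    rows = []
    with open(path) as f:
        for line in f:
            m = LINE_RE.match(line.strip())
            if m:
                step, prec, rec, acc = m.groups()
                rows.append((int(step), float(prec), float(rec), float(acc)))
    return rows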
Code Example #4
def train_ffn(model_cls, **model_kwargs):
    hvd.init()
    rank = hvd.rank()  # Horovod rank of this process
    logging.info('Rank: %d', rank)
    with tf.Graph().as_default():
        # The constructor might define TF ops/placeholders, so it is important
        # that the FFN is instantiated within the current context.
        model = model_cls(**model_kwargs)
        eval_shape_zyx = train_eval_size(model).tolist()[::-1]

        eval_tracker = EvalTracker(eval_shape_zyx)
        load_data_ops = h5_distributed_dataset(model, queue_batch=1)
        print(load_data_ops)

        prepare_ffn(model)
        merge_summaries_op = tf.compat.v1.summary.merge_all()

        if FLAGS.task == 0:
            save_flags()

        hooks = [
            # Horovod: BroadcastGlobalVariablesHook broadcasts initial variable states
            # from rank 0 to all other processes. This is necessary to ensure consistent
            # initialization of all workers when training is started with random weights
            # or restored from a checkpoint.
            hvd.BroadcastGlobalVariablesHook(0),

            # Horovod: adjust number of steps based on number of GPUs.
            tf.estimator.StopAtStepHook(last_step=FLAGS.max_steps //
                                        hvd.size()),
        ]

        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.visible_device_list = str(hvd.local_rank())

        checkpoint_dir = FLAGS.train_dir if hvd.rank() == 0 else None
        summary_writer = None
        saver = tf.compat.v1.train.Saver(max_to_keep=None,
                                         keep_checkpoint_every_n_hours=24)
        scaffold = tf.compat.v1.train.Scaffold(saver=saver)

        # model.global_step = None
        # logging.warning('GLOBAL STEP %s %s', model.global_step, model.global_step.dtype.base_dtype.is_integer)
        # logging.warning('GLOBAL_STEP assert')
        # tf.compat.v1.train.training_utils.assert_global_step(model.global_step)
        # tf.train.training_utils.assert_global_step(model.global_step)

        # global_step = tf.Variable(0, name='global_step', trainable=False)
        # g = ops.get_default_graph()
        # logging.warning('GRAPH %s', g)
        # glst = g.get_collection(ops.GraphKeys.GLOBAL_STEP)
        # logging.warning('GRAPH GLOBAL%s', glst)

        with tf.compat.v1.train.MonitoredTrainingSession(
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                checkpoint_dir=checkpoint_dir,
                hooks=hooks,
                save_checkpoint_secs=300,
                save_summaries_steps=None,
                config=config,
                scaffold=scaffold) as sess:
            eval_tracker.sess = sess

            step = int(sess.run(model.global_step))

            if FLAGS.task > 0:
                # To avoid early instabilities when using multiple replicas, we use
                # a launch schedule where new replicas are brought online gradually.
                logging.info('Delaying replica start.')
                while step < FLAGS.replica_step_delay * FLAGS.task:
                    time.sleep(5.0)
                    step = int(sess.run(model.global_step))

            if rank == 0:
                summary_writer = tf.compat.v1.summary.FileWriterCache.get(
                    FLAGS.train_dir)
                summary_writer.add_session_log(
                    tf.compat.v1.summary.SessionLog(
                        status=tf.compat.v1.summary.SessionLog.START), step)

            fov_shifts = list(model.shifts)  # x, y, z
            if FLAGS.shuffle_moves:
                random.shuffle(fov_shifts)

            policy_map = {
                'fixed': partial(fixed_offsets, fov_shifts=fov_shifts),
                'max_pred_moves': max_pred_offsets
            }
            batch_it = get_batch(
                lambda: sess.run(load_data_ops),
                eval_tracker,
                model,
                FLAGS.batch_size,
                # eval_tracker, model, 1,
                policy_map[FLAGS.fov_policy])

            t_last = time.time()
            while not sess.should_stop() and step < FLAGS.max_steps:
                # Run summaries periodically.
                t_curr = time.time()
                if t_curr - t_last > FLAGS.summary_rate_secs and FLAGS.task == 0:
                    summ_op = merge_summaries_op
                    t_last = t_curr
                else:
                    summ_op = None

                seed, patches, labels, weights = next(batch_it)

                updated_seed, step, summ = run_training_step(
                    sess,
                    model,
                    summ_op,
                    feed_dict={
                        model.loss_weights: weights,
                        model.labels: labels,
                        model.input_patches: patches,
                        model.input_seed: seed,
                    })

                # Save prediction results in the original seed array so that
                # they can be used in subsequent steps.
                mask.update_at(seed, (0, 0, 0), updated_seed)

                # Record summaries.
                if hvd.rank() == 0 and summ is not None:
                    logging.info('Saving summaries.')
                    summ = tf.compat.v1.Summary.FromString(summ)

                    # Compute a loss over the whole training patch (i.e. more than a
                    # single-step field of view of the network). This quantifies the
                    # quality of the final object mask.
                    summ.value.extend(eval_tracker.get_summaries())
                    eval_tracker.reset()

                    assert summary_writer is not None
                    summary_writer.add_summary(summ, step)

        if summary_writer is not None:
            summary_writer.flush()
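Example #4 assumes the model's training op is already Horovod-aware. The standard pattern from the Horovod TensorFlow API (not shown in this snippet; `loss` and the learning rate here are placeholders) is to wrap the optimizer when the graph is built:

# Standard Horovod usage; `loss` and the learning rate are assumptions.
opt = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
opt = hvd.DistributedOptimizer(opt)  # averages gradients across ranks
train_op = opt.minimize(loss, global_step=tf.compat.v1.train.get_global_step())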
Code Example #5
def train_ffn(model_cls, **model_kwargs):
  with tf.Graph().as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks, merge_devices=True)):
      # The constructor might define TF ops/placeholders, so it is important
      # that the FFN is instantiated within the current context.
      model = model_cls(**model_kwargs)
      eval_shape_zyx = train_eval_size(model).tolist()[::-1]

      eval_tracker = EvalTracker(eval_shape_zyx)
      load_data_ops = define_data_input(model, queue_batch=1)
      prepare_ffn(model)
      merge_summaries_op = tf.summary.merge_all()

      if FLAGS.task == 0:
        save_flags()

      # Start supervisor.
      sv = tf.train.Supervisor(
          logdir=FLAGS.train_dir,
          is_chief=(FLAGS.task == 0),
          saver=model.saver,
          summary_op=None,
          save_summaries_secs=FLAGS.summary_rate_secs,
          save_model_secs=300,
          recovery_wait_secs=5)
      sess = sv.prepare_or_wait_for_session(
          FLAGS.master,
          config=tf.ConfigProto(
              log_device_placement=False, allow_soft_placement=True))
      eval_tracker.sess = sess

      options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
      run_metadata = tf.RunMetadata()
      
      if FLAGS.task > 0:
        # To avoid early instabilities when using multiple replicas, we use
        # a launch schedule where new replicas are brought online gradually.
        logging.info('Delaying replica start.')
        while True:
          if (int(sess.run(model.global_step, options=options,
                           run_metadata=run_metadata)) >=
                  FLAGS.replica_step_delay * FLAGS.task):
            break
          time.sleep(5.0)

      fov_shifts = list(model.shifts)  # x, y, z
      if FLAGS.shuffle_moves:
        random.shuffle(fov_shifts)

      policy_map = {
          'fixed': partial(fixed_offsets, fov_shifts=fov_shifts),
          'max_pred_moves': max_pred_offsets
      }
      batch_it = get_batch(lambda: sess.run(load_data_ops, options=options,
                                            run_metadata=run_metadata),
                           eval_tracker, model, FLAGS.batch_size,
                           policy_map[FLAGS.fov_policy])
      
      # NOTE: run_metadata is only populated by sess.run calls executed with
      # the trace options above; on the chief no such call has run yet here,
      # so the trace below may be empty.
      fetched_timeline = timeline.Timeline(run_metadata.step_stats)
      chrome_trace = fetched_timeline.generate_chrome_trace_format()
      with open('timeline_01.json', 'w') as f:
        f.write(chrome_trace)

      step = 0
      t_last = time.time()

      while step < FLAGS.max_steps:
        # Run summaries periodically.
        logging.info('Iteration %d', step)
        t_curr = time.time()
        if t_curr - t_last > FLAGS.summary_rate_secs and FLAGS.task == 0:
          summ_op = merge_summaries_op
          t_last = t_curr
        else:
          summ_op = None

        seed, patches, labels, weights = next(batch_it)

        summaries = []
        updated_seed, step, summ = run_training_step(
            sess, model, summ_op,
            feed_dict={
                model.loss_weights: weights,
                model.labels: labels,
                model.offset_label: 'off',
                model.input_patches: patches,
                model.input_seed: seed,
            })

        # Save prediction results in the original seed array so that
        # they can be used in subsequent steps.
        mask.update_at(seed, (0, 0, 0), updated_seed)

        if summ is not None:
          summaries.append(tf.Summary.FromString(summ))

        # Record summaries.
        if FLAGS.task == 0 and summ_op is not None:
          # Compute a loss over the whole training patch (i.e. more than a
          # single-step field of view of the network). This quantifies the
          # quality of the final object mask.
          logging.info('Saving summaries.')
          summ = tf.Summary()
          summ.value.extend(eval_tracker.get_summaries())
          eval_tracker.reset()

          for s in summaries:
            summ.value.extend(s.value)
          sv.summary_computed(sess, summ, step)
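Because run_metadata only gains step_stats from sess.run calls executed with the trace options, a more reliable way to capture one step in the Chrome trace is to trace an explicit run first. A sketch using the same tf.RunOptions/timeline API as above (placement is up to the caller):

# Sketch: trace one explicit sess.run so step_stats is populated, then dump.
options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
sess.run(load_data_ops, options=options, run_metadata=run_metadata)
trace = timeline.Timeline(run_metadata.step_stats)
with open('timeline_step.json', 'w') as f:
    f.write(trace.generate_chrome_trace_format())  # open in chrome://tracing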
Code Example #6
File: train_hvd.py Project: keceli/ffn
def train_ffn(model_cls, **model_kwargs):
  with tf.Graph().as_default():
    # The constructor might define TF ops/placeholders, so it is important
    # that the FFN is instantiated within the current context.
    model = model_cls(**model_kwargs)
    eval_shape_zyx = train_eval_size(model).tolist()[::-1]

    eval_tracker = EvalTracker(eval_shape_zyx)
    load_data_ops = define_data_input(model, queue_batch=1)
    prepare_ffn(model)
    merge_summaries_op = tf.summary.merge_all()

    if hvd.rank() == 0:
      save_flags()

    summary_writer = None
    saver = tf.train.Saver(keep_checkpoint_every_n_hours=0.25, max_to_keep=20)
    scaffold = tf.train.Scaffold(saver=saver)
    if horovodworks:  # presumably a module-level flag set at import time
      hooks = [hvd.BroadcastGlobalVariablesHook(0),
               tf.train.StopAtStepHook(last_step=FLAGS.max_steps)]
    else:
      hooks = [tf.train.StopAtStepHook(last_step=FLAGS.max_steps)]

    config = tf.ConfigProto(
        log_device_placement=False,
        allow_soft_placement=True,
        intra_op_parallelism_threads=FLAGS.num_intra_threads,
        inter_op_parallelism_threads=FLAGS.num_inter_threads)
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    # Horovod: save checkpoints only on worker 0 to prevent other workers from
    # corrupting them.
    checkpoint_dir = FLAGS.train_dir if hvd.rank() == 0 else None
    with tf.train.MonitoredTrainingSession(
        master=FLAGS.master,
        save_summaries_steps=None,
        save_checkpoint_secs=FLAGS.summary_rate_secs,
        config=config,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        scaffold=scaffold) as sess:

      eval_tracker.sess = sess
      step = int(sess.run(model.global_step))
      if hvd.rank() == 0:
        summary_writer = tf.summary.FileWriterCache.get(FLAGS.train_dir)
        summary_writer.add_session_log(
            tf.summary.SessionLog(status=tf.summary.SessionLog.START), step)


      fov_shifts = list(model.shifts)  # x, y, z
      if FLAGS.shuffle_moves:
        random.shuffle(fov_shifts)

      policy_map = {
          'fixed': partial(fixed_offsets, fov_shifts=fov_shifts),
          'max_pred_moves': max_pred_offsets
      }
      batch_it = get_batch(lambda: sess.run(load_data_ops),
                           eval_tracker, model, FLAGS.batch_size,
                           policy_map[FLAGS.fov_policy])

      t_last = time.time()

      while not sess.should_stop() and step < FLAGS.max_steps:
        # Run summaries periodically.
        t_curr = time.time()
        if t_curr - t_last > FLAGS.summary_rate_secs and hvd.rank() == 0:
          summ_op = merge_summaries_op
          t_last = t_curr
        else:
          summ_op = None

        seed, patches, labels, weights = next(batch_it)
        target_lr = get_learning_rate(step, FLAGS.batch_size)

        updated_seed, step, summ = run_training_step(
            sess, model, summ_op,
            feed_dict={
                model.loss_weights: weights,
                model.labels: labels,
                model.input_patches: patches,
                model.input_seed: seed,
                model.learning_rate: target_lr # Wushi: update learning rate
            })

        # Save prediction results in the original seed array so that
        # they can be used in subsequent steps.
        mask.update_at(seed, (0, 0, 0), updated_seed)

        # Record summaries.
        if summ is not None and hvd.rank() == 0:
          logging.info('Saving summaries.')
          summ = tf.Summary.FromString(summ)

          # Compute a loss over the whole training patch (i.e. more than a
          # single-step field of view of the network). This quantifies the
          # quality of the final object mask.
          summ.value.extend(eval_tracker.get_summaries())
          eval_tracker.reset()

          assert summary_writer is not None
          summary_writer.add_summary(summ, step)

    if summary_writer is not None:
      summary_writer.flush()
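get_learning_rate is not defined in this snippet. A hypothetical implementation consistent with the call site (step and batch size in, scalar rate out; every constant here is an assumption) might look like:

def get_learning_rate(step, batch_size, base_lr=1e-3,
                      decay=0.5, decay_every=100000):
    # Hypothetical schedule: scale with batch size, step-decay over time.
    return base_lr * (batch_size / 16.0) * decay ** (step // decay_every)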
Code Example #7
def train_ffn(model_cls, **model_kwargs):
    with tf.Graph().as_default():
        model = model_cls(**model_kwargs)  # initialize the model
        # Size of the subvolume within which the FOV moves (z, y, x).
        eval_shape_zyx = train_eval_size(model).tolist()[::-1]
        # Computes summary statistics inside the extended FOV (EFOV).
        eval_tracker = EvalTracker(eval_shape_zyx)
        # Creates a batch of training subvolumes.
        load_data_ops = define_data_input(model, queue_batch=1)
        prepare_ffn(model)  # here the TF graph is defined
        # Merges all summaries defined in the graph.
        merge_summaries_op = tf.summary.merge_all()

        if hvd.rank() == 0:
            save_flags()

        var_to_reduce = tf.placeholder(tf.float32)
        bcast_op = hvd.broadcast_global_variables(0)
        avg_op = hvd.allreduce(var_to_reduce, average=True)

        # Start supervisor.
        sv = tf.train.Supervisor(
            logdir=(FLAGS.train_dir if hvd.rank() == 0 else None),
            is_chief=True,
            saver=(tf.train.Saver(max_to_keep=FLAGS.max_to_keep,
                                  keep_checkpoint_every_n_hours=1)
                   if hvd.rank() == 0 else None),
            save_model_secs=(FLAGS.save_model_secs if hvd.rank() == 0 else 0),
            summary_op=None,
            save_summaries_secs=0,  # will perform custom summaries instead
        )

        sess = sv.prepare_or_wait_for_session(
            FLAGS.master,
            config=tf.ConfigProto(
                log_device_placement=False,
                allow_soft_placement=True,
                intra_op_parallelism_threads=FLAGS.num_intra_threads,
                inter_op_parallelism_threads=FLAGS.num_inter_threads))

        # broadcast initial weights. This ensures that all horovod ranks
        # start at the same point in parameter space
        if hvd.rank() == 0:
            print("broadcasting initial weights")
        sess.run(bcast_op)

        eval_tracker.sess = sess  # connect the eval tracker to the session
        eval_tracker.avg_op = avg_op
        fov_shifts = list(model.shifts)  # x, y, z
        if FLAGS.shuffle_moves:
            # Shuffle the FOV positions that make up an extended FOV (EFOV).
            random.shuffle(fov_shifts)

        policy_map = {
            'fixed': partial(fixed_offsets, fov_shifts=fov_shifts),
            'max_pred_moves': max_pred_offsets
        }
        # batch iterator for getting the next batch
        batch_it = get_batch(lambda: sess.run(load_data_ops), eval_tracker,
                             model, FLAGS.batch_size,
                             policy_map[FLAGS.fov_policy])

        step = 0
        t_last = time.time()

        if hvd.rank() == 0:
            timing = []  # list of times for benchmarking

        steps_since_last_summary = 0
        if hvd.rank() == 0:
            print("starting training")
        while step < FLAGS.max_steps:
            time_step_start = time.time()

            if steps_since_last_summary == FLAGS.summary_every_steps:
                summ_op = merge_summaries_op
                steps_since_last_summary = 1
                if hvd.rank() == 0:
                    print("step ", step, "is a summary step")
            else:
                summ_op = None
                steps_since_last_summary += 1

            # get the next batch - this is reading the data from disk.
            seed, patches, labels, weights = next(batch_it)

            summaries = []

            scaled_lr = FLAGS.learning_rate
            # Scale the learning rate linearly (1) or by sqrt (2).
            if FLAGS.scaling_rule > 0:
                if FLAGS.scaling_rule == 1:
                    scaled_lr *= hvd.size()
                elif FLAGS.scaling_rule == 2:
                    scaled_lr *= np.sqrt(hvd.size())
                if step < FLAGS.warmup_steps:
                    # Warm up linearly from the base rate to the scaled rate.
                    scaled_lr = FLAGS.learning_rate + (
                        step / float(FLAGS.warmup_steps)) * (
                            scaled_lr - FLAGS.learning_rate)

            # Continuously decay the learning rate (exponential decay).
            if FLAGS.decay_learning_rate_fraction > 0:
                scaled_lr *= FLAGS.decay_learning_rate_fraction ** (
                    step / FLAGS.decay_learning_rate_steps)

            # Run a training step on a SINGLE FOV.
            updated_seed, step, summ, my_loss = run_training_step(
                sess, model, summ_op,
                feed_dict={
                    model.loss_weights: weights,
                    model.labels: labels,
                    model.offset_label: 'off',
                    model.input_patches: patches,
                    model.input_seed: seed,
                    model.learning_rate: scaled_lr
                })

            # compute average loss
            avg_loss = sess.run(avg_op, feed_dict={var_to_reduce: my_loss})

            # Save prediction results in the original seed array so that
            # they can be used in subsequent steps.
            # (Updates the mask inside the subvolume batches.)
            mask.update_at(seed, (0, 0, 0), updated_seed)
            if hvd.rank() == 0:
                # How long did this step take?
                this_time = time.time() - time_step_start
                timing.append(this_time)
                print("step %i took %.2f seconds" % (step - 1, this_time))

            if summ is not None:
                # Add the summaries from the single FOV.
                summaries.append(tf.Summary.FromString(summ))

                # Compute a loss over the whole training patch (i.e. more than a
                # single-step field of view of the network). This quantifies the
                # quality of the final object mask.

                tp, fp, tn, fn, num_patches = (
                    eval_tracker.get_summaries_scalar())

                tp_sum = hvd.size() * sess.run(
                    avg_op, feed_dict={var_to_reduce: float(tp)})
                fp_sum = hvd.size() * sess.run(
                    avg_op, feed_dict={var_to_reduce: float(fp)})
                tn_sum = hvd.size() * sess.run(
                    avg_op, feed_dict={var_to_reduce: float(tn)})
                fn_sum = hvd.size() * sess.run(
                    avg_op, feed_dict={var_to_reduce: float(fn)})
                avg_num_patches = sess.run(
                    avg_op, feed_dict={var_to_reduce: float(num_patches)})

                accuracy = (tp_sum + tn_sum) / (tp_sum + fp_sum + tn_sum +
                                                fn_sum)
                precision = (tp_sum) / (tp_sum + fp_sum)
                recall = (tp_sum) / (tp_sum + fn_sum)
                f1 = 2.0 * precision * recall / (precision + recall)

                eval_tracker_summaries = ([
                    tf.Summary.Value(tag='eval/patches',
                                     simple_value=avg_num_patches),
                    tf.Summary.Value(tag='eval/accuracy',
                                     simple_value=accuracy),
                    tf.Summary.Value(tag='eval/precision',
                                     simple_value=precision),
                    tf.Summary.Value(tag='eval/recall', simple_value=recall),
                    tf.Summary.Value(tag='eval/f1', simple_value=f1)
                ])

                if hvd.rank() == 0:
                    logging.info('Saving summaries.')
                    summ = tf.Summary()  # initialize a TensorFlow summary
                    summ.value.extend(eval_tracker_summaries)  # EFOV metrics
                    summ.value.extend(
                        eval_tracker.get_summaries_images())  # image summaries
                    for s in summaries:
                        summ.value.extend(s.value)  # add FOV metrics
                    # other custom summary items:
                    summ.value.extend([
                        tf.Summary.Value(tag='avg_pixel_loss',
                                         simple_value=avg_loss)
                    ])  #avg pixel loss
                    summ.value.extend([
                        tf.Summary.Value(tag='learning_rate',
                                         simple_value=scaled_lr)
                    ])  #(scaled) learning rate
                    summ.value.extend([
                        tf.Summary.Value(tag='avg_time_per_step',
                                         simple_value=np.mean(timing))
                    ])  #avg time per step
                    print("avg time per step: ", np.mean(timing),
                          np.std(timing))
                    print("avg throughput: ",
                          FLAGS.batch_size / np.mean(timing))
                    timing = []  # reset the timing array
                    sv.summary_computed(sess, summ, step)

                # reset eval tracker before the next training step.
                eval_tracker.reset()

        if hvd.rank() == 0:
            print("all steps done!")
            if FLAGS.do_benchmark_test == 1:
                print("benchmark result: ")
                print("steps, batch size, ranks, threads, mean, sigma:")
                string = ",".join(str(v) for v in [
                    FLAGS.max_steps, FLAGS.batch_size, hvd.size(),
                    FLAGS.nthreads, np.mean(timing), np.std(timing)]) + "\n"
                print(string)
                with open(FLAGS.timelog, "a") as myfile:
                    myfile.write(string)
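The four tp/fp/tn/fn reductions above each pay a separate allreduce round-trip. Under the same var_to_reduce pattern, they could be folded into one reduction over a length-4 vector; a minimal sketch (the placeholder name is hypothetical):

# Built once, next to bcast_op/avg_op above:
counts_ph = tf.placeholder(tf.float32, shape=[4])
sum_counts_op = hvd.allreduce(counts_ph, average=False)  # global sums

# Inside the summary branch, replacing the four scalar reductions:
tp_sum, fp_sum, tn_sum, fn_sum = sess.run(
    sum_counts_op,
    feed_dict={counts_ph: [float(tp), float(fp), float(tn), float(fn)]})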
Code Example #8
File: train.py Project: malei-pku/ffn-tracer
def main(argv):
    experiment_uid = uid_from_flags(FLAGS)
    train_dir = os.path.join(FLAGS.train_base_dir, experiment_uid)
    with tf.Graph().as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):
            # The constructor might define TF ops/placeholders, so it is important
            # that the FFN is instantiated within the current context.

            # Note: all inputs to ffn.training.model.FFNModel are in format (x, y, z),
            # except the fov_size, for historical reasons.

            # If fov_size is specified at command line, it will be stored as a list of
            # strings; these need to be coerced to integers.

            model = FFNTracerModel(batch_size=FLAGS.batch_size,
                                   adv_args=FLAGS.adv_args,
                                   **json.loads(FLAGS.model_args))
            eval_shape_zyx = train_eval_size(model).tolist()[::-1]

            eval_tracker = EvalTracker(eval_shape_zyx)
            load_data_ops = define_data_input(model, queue_batch=1)
            prepare_ffn(model)
            merge_summaries_op = tf.summary.merge_all()

            # if FLAGS.task == 0:
            #     save_flags()

            summary_writer = None
            saver = tf.train.Saver(keep_checkpoint_every_n_hours=0.25)
            scaffold = tf.train.Scaffold(saver=saver)
            with tf.train.MonitoredTrainingSession(
                    master=FLAGS.master,
                    is_chief=(FLAGS.task == 0),
                    save_summaries_steps=None,
                    save_checkpoint_secs=300,
                    config=tf.ConfigProto(
                        log_device_placement=False,
                        allow_soft_placement=True,
                        gpu_options=tf.GPUOptions(
                            allow_growth=True,
                            # visible_device_list=FLAGS.visible_gpus
                        )),
                    checkpoint_dir=train_dir,
                    scaffold=scaffold) as sess:

                eval_tracker.sess = sess
                step = int(sess.run(model.global_step))

                if FLAGS.task > 0:
                    # To avoid early instabilities when using multiple replicas, we use
                    # a launch schedule where new replicas are brought online gradually.
                    logging.info('Delaying replica start.')
                    while step < FLAGS.replica_step_delay * FLAGS.task:
                        time.sleep(5.0)
                        step = int(sess.run(model.global_step))
                else:
                    summary_writer = tf.summary.FileWriterCache.get(train_dir)
                    summary_writer.add_session_log(
                        tf.summary.SessionLog(
                            status=tf.summary.SessionLog.START), step)

                fov_shifts_xyz = list(model.shifts)
                if FLAGS.shuffle_moves:
                    random.shuffle(fov_shifts_xyz)

                # Policy_map is a dict mapping a fov_policy to a callable that
                # generates offsets, given a model and a seed as inputs. Note that
                # 'fixed' is the policy used for training, and max_pred_moves is the
                # policy used for inference.

                policy_map = {
                    'fixed': partial(fixed_offsets, fov_shifts=fov_shifts_xyz),
                    'max_pred_moves': max_pred_offsets
                }

                # JG: batch_it contains (seed, image, label, weights), where each is of
                # shape [b, z, y, x, 1]
                batch_it = get_batch(lambda: sess.run(load_data_ops),
                                     eval_tracker, model, FLAGS.batch_size,
                                     policy_map[FLAGS.fov_policy])

                t_last = time.time()

                while not sess.should_stop() and step < FLAGS.max_steps:
                    # Run summaries periodically.
                    t_curr = time.time()
                    if t_curr - t_last > FLAGS.summary_rate_secs and FLAGS.task == 0:
                        summ_op = merge_summaries_op
                        t_last = t_curr
                    else:
                        summ_op = None

                    seed, patches, labels, weights = next(batch_it)
                    # JG: weights, labels, patches, and seed all have
                    # shape [b, z, y, x, 1] at this stage.

                    feed_dict = {
                        model.loss_weights: weights,
                        model.labels: labels,
                        model.input_patches: patches,
                        model.input_seed: seed,
                    }

                    if is_adversary_update_step(step, model):
                        # Update the adversary
                        step, summ = run_adversary_training_step(
                            sess, model, summ_op, feed_dict=feed_dict)
                    else:
                        # Update the FFN model
                        updated_seed, step, summ = run_training_step(
                            sess, model, summ_op, feed_dict=feed_dict)

                        # Save prediction results in the original seed array so that
                        # they can be used in subsequent steps.
                        mask.update_at(seed, (0, 0, 0), updated_seed)

                    # Record summaries.
                    if summ is not None:
                        logging.info('Saving summaries.')
                        summ = tf.Summary.FromString(summ)

                        # Compute a loss over the whole training patch (i.e. more than a
                        # single-step field of view of the network). This quantifies the
                        # quality of the final object mask.
                        summ.value.extend(eval_tracker.get_summaries())
                        eval_tracker.reset()

                        assert summary_writer is not None
                        summary_writer.add_summary(summ, step)

            if summary_writer is not None:
                summary_writer.flush()
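is_adversary_update_step is not included in this snippet. A hypothetical schedule consistent with the call site (alternating adversary and FFN updates on a fixed cadence; both the attribute and the cadence are assumptions) could be:

def is_adversary_update_step(step, model, adv_every=5):
    # Train the adversary once every `adv_every` steps, if the model has one.
    return getattr(model, 'has_adversary', False) and step % adv_every == 0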