コード例 #1
0
def check_record(args):
    """Walk a TFRecord of image/mask pairs once and report how many were read.

    A ``TFRecordImageMask`` dataset is built from ``args`` with
    ``repeat=False``, and image/mask pairs are fetched until TensorFlow
    raises ``tf.errors.OutOfRangeError``. The number of successful fetches
    is then printed. The fetched arrays themselves are discarded.

    Args:
        args: Object providing ``record_path``, ``crop_size``, ``ratio``,
            ``classes`` and ``threads`` attributes.
    """
    with tf.Session(config=config) as sess:
        dataset_kwargs = dict(
            training_record=args.record_path,
            sess=sess,
            crop_size=args.crop_size,
            ratio=args.ratio,
            batch_size=None,
            prefetch=None,
            shuffle_buffer=128,
            n_classes=args.classes,
            preprocess=[],
            repeat=False,
            n_threads=args.threads,
        )
        dataset = tfmodels.TFRecordImageMask(**dataset_kwargs)
        dataset.print_info()

        n_seen = 0
        try:
            # Keep pulling until the non-repeating dataset is exhausted.
            while True:
                _image, _mask = sess.run([dataset.image_op, dataset.mask_op])
                n_seen += 1
        except tf.errors.OutOfRangeError:
            print('Reached end of {} examples'.format(n_seen))
コード例 #2
0
def check_record(args):
  """Tally per-class totals over at most 100 fetches from a TFRecord.

  Builds a non-repeating ``TFRecordImageMask`` dataset from ``args``, fetches
  image/mask pairs, and passes each mask to ``check_y`` (defined elsewhere;
  its result is indexed by class id). The per-class totals and each class's
  share of the overall total are printed at the end.

  Fetching stops when 100 fetches have completed, when the dataset raises
  ``tf.errors.OutOfRangeError``, or when any other error occurs. In the last
  case the error is reported and the totals gathered so far are still
  printed.

  Args:
    args: Object providing ``record_path``, ``crop_size``, ``ratio``,
      ``classes`` and ``threads`` attributes.
  """
  max_fetches = 100  # hard cap on how many fetches are inspected

  with tf.Session(config=config) as sess:
    dataset = tfmodels.TFRecordImageMask(
      training_record = args.record_path,
      sess = sess,
      crop_size = args.crop_size,
      ratio = args.ratio,
      batch_size = None,
      prefetch = None,
      shuffle_buffer = 64,
      n_classes = args.classes,
      preprocess = [],
      repeat = False,
      n_threads = args.threads)

    dataset.print_info()

    idx = 0
    class_ids = np.arange(args.classes)
    classtotals = {k: 0 for k in class_ids}
    while idx < max_fetches:
      try:
        # Only the mask is needed for the class tally.
        _, y = sess.run([dataset.image_op, dataset.mask_op])
        idx += 1
        batchtotals = check_y(y, args.classes)
        for k in class_ids:
          classtotals[k] += batchtotals[k]

      except tf.errors.OutOfRangeError:
        print('Reached end of {} examples'.format(idx))
        break

      except KeyboardInterrupt:
        # Allow a manual stop while still printing the partial totals.
        print('Interrupted after {} examples'.format(idx))
        break

      except Exception as e:
        # Best effort: stop and report what was gathered, but say why,
        # so a real failure is not mistaken for a normal early stop.
        print('Stopping after {} examples: {}: {}'.format(
          idx, type(e).__name__, e))
        break

    total = np.sum(list(classtotals.values()))
    for k in class_ids:
      c = classtotals[k]
      # Avoid a divide-by-zero warning when nothing was counted.
      fraction = c / total if total else float('nan')
      print('Class {}: {} {}'.format(k, c, fraction))
コード例 #3
0
def main(batch_size, image_ratio, crop_size, n_epochs, lr_0, basedir, restore_path):
    """Train a ``Training`` model on the record at ``train_record_path``.

    Creates the experiment directories under ``basedir`` (with
    ``remove_old=True``), builds a ``TFRecordImageMask`` dataset and a
    ``Training`` model, optionally restores a checkpoint, and then runs the
    optimisation loop. The learning rate is ``lr_0 * exp(-gamma * step)``
    with ``gamma = 1e-5``, re-evaluated once at the start of each epoch. A
    snapshot is written every 5 epochs and once more when the loop exits,
    whether it finished normally or raised.

    Args:
        batch_size: Batch size passed to the dataset.
        image_ratio: Scale factor passed to the dataset as ``ratio``; also
            used with ``crop_size`` to compute the model's input dimensions.
        crop_size: Crop size passed to the dataset.
        n_epochs: Upper bound of the epoch loop, which iterates over
            ``range(1, n_epochs)``.
        lr_0: Initial learning rate.
        basedir: Base directory passed to ``tfmodels.make_experiment``.
        restore_path: Checkpoint to restore before training, or ``None``.
    """
    n_classes = N_CLASSES
    side = int(crop_size * image_ratio)
    x_dims = [side, side, 3]

    iterations = 5000  # training steps per epoch
    epochs = n_epochs
    snapshot_epochs = 5
    step_start = 0

    prefetch = 512
    threads = 8

    log_dir, save_dir, _debug_dir, _infer_dir = tfmodels.make_experiment(
        basedir=basedir, remove_old=True)

    gamma = 1e-5

    def learning_rate(lr_0, gamma, step):
        """Exponentially decayed learning rate at ``step``."""
        return lr_0 * np.exp(-gamma * step)

    with tf.Session(config=config) as sess:
        dataset = tfmodels.TFRecordImageMask(
            training_record=train_record_path,
            sess=sess,
            crop_size=crop_size,
            ratio=image_ratio,
            batch_size=batch_size,
            prefetch=prefetch,
            shuffle_buffer=256,
            n_classes=n_classes,
            as_onehot=True,
            mask_dtype=tf.uint8,
            img_channels=3,
            n_threads=threads)
        dataset.print_info()

        model = Training(
            sess=sess,
            dataset=dataset,
            global_step=step_start,
            learning_rate=lr_0,
            log_dir=log_dir,
            save_dir=save_dir,
            summary_iters=200,
            summary_image_iters=iterations,
            summary_image_n=4,
            max_to_keep=25,
            n_classes=n_classes,
            x_dims=x_dims)
        model.print_info()

        if restore_path is not None:
            model.restore(restore_path)

        ## --------------------- Optimizing Loop -------------------- ##
        print('Start')

        try:
            ## Re-initialize training step to have a clean learning rate curve
            training_step = 0
            print('Starting with model at step {}'.format(model.global_step))
            # NOTE(review): range(1, epochs) executes epochs - 1 epochs;
            # confirm whether a full `epochs` count was intended.
            for epx in range(1, epochs):
                epoch_start = time.time()
                # The rate is held constant within an epoch.
                epoch_lr = learning_rate(lr_0, gamma, training_step)
                for _ in range(iterations):
                    training_step += 1
                    model.train_step(lr=epoch_lr)

                print('Epoch [{}] step [{}] time elapsed [{}]s'.format(
                    epx, model.global_step, time.time()-epoch_start))

                if epx % snapshot_epochs == 0:
                    model.snapshot()

        except Exception as e:
            print('Caught exception')
            print(e.__doc__)
            # Python 3 exceptions have no `.message`; reading it would raise
            # AttributeError here and hide the real error.
            print(getattr(e, 'message', e))
        finally:
            # Always keep the latest weights, even after a failure.
            model.snapshot()
            print('Stopping threads')
            print('Done')
コード例 #4
0
ファイル: train.py プロジェクト: slkarkar/gleason_grade
def main(args):
    """Train the model selected by ``args.model_type`` on a TFRecord dataset.

    Creates the experiment directories under ``args.basedir`` (with
    ``remove_old=True``), records the arguments via ``write_arguments``,
    builds a ``TFRecordImageMask`` dataset and the model class returned by
    ``get_model``, optionally restores a checkpoint, and then runs the
    optimisation loop. The learning rate is ``args.lr * exp(-gamma * step)``
    with ``gamma = 1e-5``, re-evaluated once at the start of each epoch. A
    snapshot is written every 5 epochs and once more when the loop exits,
    whether it finished normally or raised.

    Args:
        args: Object providing ``crop_size``, ``image_ratio``, ``basedir``,
            ``train_record``, ``val_record``, ``batch_size``, ``n_classes``,
            ``model_type``, ``lr``, ``iterations``, ``epochs`` and
            ``restore_path`` attributes.
    """
    side = int(args.crop_size * args.image_ratio)
    x_dims = [side, side, 3]

    snapshot_epochs = 5
    step_start = 0

    prefetch = 512
    threads = 8

    log_dir, save_dir, _debug_dir, _infer_dir = tfmodels.make_experiment(
        basedir=args.basedir, remove_old=True)
    write_arguments(args)

    gamma = 1e-5

    def learning_rate(lr_0, gamma, step):
        """Exponentially decayed learning rate at ``step``."""
        return lr_0 * np.exp(-gamma * step)

    with tf.Session(config=config) as sess:
        dataset = tfmodels.TFRecordImageMask(
            training_record=args.train_record,
            testing_record=args.val_record,
            sess=sess,
            crop_size=args.crop_size,
            ratio=args.image_ratio,
            batch_size=args.batch_size,
            prefetch=prefetch,
            shuffle_buffer=prefetch,
            n_classes=args.n_classes,
            as_onehot=True,
            mask_dtype=tf.uint8,
            img_channels=3,
            n_threads=threads)
        dataset.print_info()

        model_class = get_model(args.model_type,
                                sess,
                                None,
                                None,
                                training=True)
        model = model_class(
            sess=sess,
            dataset=dataset,
            global_step=step_start,
            learning_rate=args.lr,
            log_dir=log_dir,
            save_dir=save_dir,
            summary_iters=200,
            summary_image_iters=args.iterations,
            summary_image_n=4,
            max_to_keep=25,
            n_classes=args.n_classes,
            x_dims=x_dims)
        model.print_info()

        if args.restore_path is not None:
            model.restore(args.restore_path)

        ## --------------------- Optimizing Loop -------------------- ##
        print('Start')

        try:
            ## Re-initialize training step to have a clean learning rate curve
            training_step = 0
            print('Starting with model at step {}'.format(model.global_step))
            # NOTE(review): range(1, args.epochs) executes args.epochs - 1
            # epochs; confirm whether a full `args.epochs` count was intended.
            for epx in range(1, args.epochs):
                epoch_start = time.time()
                # The rate is held constant within an epoch.
                epoch_lr = learning_rate(args.lr, gamma, training_step)
                for _ in range(args.iterations):
                    training_step += 1
                    model.train_step(lr=epoch_lr)

                print('Epoch [{}] step [{}] time elapsed [{}]s'.format(
                    epx, model.global_step,
                    time.time() - epoch_start))

                if epx % snapshot_epochs == 0:
                    model.snapshot()

        except Exception as e:
            print('Caught exception')
            print(e.__doc__)
            # Python 3 exceptions have no `.message`; reading it would raise
            # AttributeError here and hide the real error.
            print(getattr(e, 'message', e))
        finally:
            # Always keep the latest weights, even after a failure.
            model.snapshot()
            print('Stopping threads')
            print('Done')