Example #1
    def test_raise_error_with_custom_init_fn_in_eval(self):
        def model_fn(features, labels, mode):
            _, _ = features, labels

            def init_fn(scaffold, session):
                _, _ = scaffold, session

            return estimator_lib.EstimatorSpec(
                mode,
                loss=constant_op.constant(3.),
                scaffold=training.Scaffold(init_fn=init_fn),
                train_op=constant_op.constant(5.),
                eval_metric_ops={
                    'mean_of_features':
                    metrics_lib.mean(constant_op.constant(2.))
                })

        estimator = estimator_lib.Estimator(model_fn=model_fn)

        def input_fn():
            return dataset_ops.Dataset.range(10)

        evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn)
        with self.assertRaisesRegexp(ValueError,
                                     'does not support custom init_fn'):
            evaluator.begin()
Example #2
    def test_raise_error_with_saveables_other_than_global_variables(self):
        def model_fn(features, labels, mode):
            _, _ = features, labels
            w = variables.Variable(
                initial_value=[0.],
                trainable=False,
                collections=[ops.GraphKeys.SAVEABLE_OBJECTS])
            init_op = control_flow_ops.group(
                [w.initializer,
                 training.get_global_step().initializer])
            return estimator_lib.EstimatorSpec(
                mode,
                loss=constant_op.constant(3.),
                scaffold=training.Scaffold(init_op=init_op),
                train_op=constant_op.constant(5.),
                eval_metric_ops={
                    'mean_of_features':
                    metrics_lib.mean(constant_op.constant(2.))
                })

        estimator = estimator_lib.Estimator(model_fn=model_fn)

        def input_fn():
            return dataset_ops.Dataset.range(10)

        evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn)
        with self.assertRaisesRegexp(ValueError, 'does not support saveables'):
            estimator.train(input_fn, hooks=[evaluator])
Example #3
  def testTrainedMnistSavedModel(self):
    """Test mnist SavedModel, trained with dummy data and small steps."""
    # Build classifier
    classifier = estimator.Estimator(
        model_fn=model_fn,
        params={
            "data_format": "channels_last"  # tflite format
        })

    # Train, then build the serving input receiver
    classifier.train(input_fn=dummy_input_fn, steps=2)
    image = array_ops.placeholder(dtypes.float32, [None, 28, 28])
    pred_input_fn = estimator.export.build_raw_serving_input_receiver_fn({
        "image": image,
    })

    # Export SavedModel
    saved_model_dir = os.path.join(self.get_temp_dir(), "mnist_savedmodel")
    classifier.export_savedmodel(saved_model_dir, pred_input_fn)

    # Convert to tflite and test output
    saved_model_name = os.listdir(saved_model_dir)[0]
    saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name)

    # TODO(zhixianyan): no need to limit output_arrays to 'Softmax'
    # once b/74205001 is fixed and argmax is implemented in tflite.
    result = convert_saved_model.freeze_saved_model(
        saved_model_dir=saved_model_final_dir,
        input_arrays=None,
        input_shapes=None,
        output_arrays=["Softmax"],
        tag_set=None,
        signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

    self.assertTrue(result)
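
The test above relies on `model_fn` and `dummy_input_fn` fixtures defined elsewhere in the file. A minimal sketch of what the input fixture could look like, assuming the model consumes 28x28 float images with integer labels (the names, shapes, and dtypes here are assumptions, not the original fixture):

def dummy_input_fn():
    # Hypothetical fixture: one all-zero 28x28 image and one integer label,
    # just enough to drive the two training steps requested above.
    image = array_ops.zeros([1, 28, 28], dtypes.float32)
    label = array_ops.zeros([1], dtypes.int32)
    return image, label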
Example #4
    def test_uses_latest_variable_value(self):
        def model_fn(features, labels, mode):
            _ = labels
            step = training.get_global_step()
            w = variable_scope.get_variable(
                'w',
                shape=[],
                initializer=init_ops.zeros_initializer(),
                dtype=dtypes.int64)
            if estimator_lib.ModeKeys.TRAIN == mode:
                # Use a control dependency on features so they are consumed.
                with ops.control_dependencies([features]):
                    step_inc = state_ops.assign_add(training.get_global_step(),
                                                    1)
                with ops.control_dependencies([step_inc]):
                    assign_w_to_step_plus_2 = w.assign(step + 2)
                return estimator_lib.EstimatorSpec(
                    mode,
                    loss=constant_op.constant(3.),
                    train_op=assign_w_to_step_plus_2)
            if estimator_lib.ModeKeys.EVAL == mode:
                # Use a control dependency on features so they are consumed.
                with ops.control_dependencies([features]):
                    loss = constant_op.constant(5.)
                return estimator_lib.EstimatorSpec(
                    mode,
                    loss=loss,
                    # w is constant within each evaluation, so the mean equals w:
                    # w = 0 if step == 0 else step + 2
                    eval_metric_ops={'mean_of_const': metrics_lib.mean(w)})

        estimator = estimator_lib.Estimator(model_fn=model_fn)

        def input_fn():
            return dataset_ops.Dataset.range(10)

        evaluator = hooks_lib.InMemoryEvaluatorHook(estimator,
                                                    input_fn,
                                                    every_n_iter=4)
        estimator.train(input_fn, hooks=[evaluator])

        self.assertTrue(os.path.isdir(estimator.eval_dir()))
        step_keyword_to_value = summary_step_keyword_to_value_mapping(
            estimator.eval_dir())
        # w = 0 if step==0 else step+2
        self.assertEqual(0, step_keyword_to_value[0]['mean_of_const'])
        self.assertEqual(6, step_keyword_to_value[4]['mean_of_const'])
        self.assertEqual(12, step_keyword_to_value[10]['mean_of_const'])
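
Both this test and Example #6 below read back the evaluation summaries through a `summary_step_keyword_to_value_mapping` helper that is not part of the excerpt. A rough sketch of what such a helper might do, assuming the standard events-file reader `tf.train.summary_iterator` (the implementation below is an assumption, not the original helper):

import glob
import os

import tensorflow as tf


def summary_step_keyword_to_value_mapping(dir_):
    # Assumed helper: build {global_step: {summary tag: scalar value}} from
    # the events file(s) that evaluation writes into `dir_`.
    result = {}
    for events_path in glob.glob(os.path.join(dir_, 'events*')):
        for event in tf.train.summary_iterator(events_path):
            values = {v.tag: v.simple_value for v in event.summary.value}
            if values:  # Skip records (e.g. file_version) with no values.
                result.setdefault(event.step, {}).update(values)
    return result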
Example #5
    def test_should_not_conflict_with_existing_predictions(self):
        def input_fn():
            return {'x': [[3.], [5.]], 'id': [[101], [102]]}

        def model_fn(features, mode):
            del features
            global_step = training.get_global_step()
            return estimator_lib.EstimatorSpec(
                mode,
                loss=constant_op.constant([5.]),
                predictions={'x': constant_op.constant([5.])},
                train_op=global_step.assign_add(1))

        estimator = estimator_lib.Estimator(model_fn=model_fn)
        estimator.train(input_fn=input_fn, steps=1)

        estimator = extenders.forward_features(estimator)
        with self.assertRaisesRegexp(ValueError, 'Cannot forward feature key'):
            next(estimator.predict(input_fn=input_fn))
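
The ValueError is expected because the forwarded feature key 'x' collides with the model's own 'x' prediction. As a side note (not part of the original test), `forward_features` also accepts an explicit `keys` argument, so forwarding only the non-conflicting column should avoid the collision:

# Sketch: forward only the 'id' feature to sidestep the key collision.
estimator = extenders.forward_features(estimator, keys='id')
prediction = next(estimator.predict(input_fn=input_fn))
# `prediction` now holds the model's 'x' output plus the forwarded 'id'.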
Example #6
    def test_runs_eval_metrics(self):
        def model_fn(features, labels, mode):
            _ = labels
            if estimator_lib.ModeKeys.TRAIN == mode:
                with ops.control_dependencies([features]):
                    train_op = state_ops.assign_add(training.get_global_step(),
                                                    1)
                return estimator_lib.EstimatorSpec(
                    mode, loss=constant_op.constant(3.), train_op=train_op)
            if estimator_lib.ModeKeys.EVAL == mode:
                return estimator_lib.EstimatorSpec(
                    mode,
                    loss=constant_op.constant(5.),
                    eval_metric_ops={
                        'mean_of_features': metrics_lib.mean(features)
                    })

        estimator = estimator_lib.Estimator(model_fn=model_fn)

        def input_fn():
            return dataset_ops.Dataset.range(10)

        evaluator = hooks_lib.InMemoryEvaluatorHook(estimator,
                                                    input_fn,
                                                    every_n_iter=4)
        estimator.train(input_fn, hooks=[evaluator])

        self.assertTrue(os.path.isdir(estimator.eval_dir()))
        step_keyword_to_value = summary_step_keyword_to_value_mapping(
            estimator.eval_dir())

        # 4.5 = sum(range(10))/10
        # before training
        self.assertEqual(4.5, step_keyword_to_value[0]['mean_of_features'])
        # intervals (every_n_iter=4)
        self.assertEqual(4.5, step_keyword_to_value[4]['mean_of_features'])
        self.assertEqual(4.5, step_keyword_to_value[8]['mean_of_features'])
        # end
        self.assertEqual(4.5, step_keyword_to_value[10]['mean_of_features'])
        self.assertEqual(set([0, 4, 8, 10]), set(step_keyword_to_value.keys()))
Example #7
params['epochs_to_reduce_at'] = [40, 120]
params['initial_learning_rate'] = 0.1
params['epoch_reduction_factor'] = 0.1
params['mixup_val'] = 0.7
pprint(params)

# get data loader
cifar_data = CIFAR10(batch_size=params['batch_size'],
                     mixup_val=params['mixup_val'])

run_config = estimator.RunConfig(
    save_checkpoints_steps=params['steps_per_epoch'],
    save_summary_steps=500,
    keep_checkpoint_max=5)

fixup_estimator = estimator.Estimator(model_dir=model_dir,
                                      model_fn=model.model_fn,
                                      params=params,
                                      config=run_config)

# training/evaluation specs for run
train_spec = estimator.TrainSpec(input_fn=cifar_data.build_training_data,
                                 max_steps=params['total_steps_train'])
eval_spec = estimator.EvalSpec(input_fn=cifar_data.build_validation_data,
                               steps=None,
                               throttle_secs=params['throttle_eval'],
                               start_delay_secs=0)

# run train and evaluate
estimator.train_and_evaluate(fixup_estimator, train_spec, eval_spec)
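
The script references a `model_dir` and several `params` entries (`batch_size`, `steps_per_epoch`, `total_steps_train`, `throttle_eval`) that are set before the excerpt begins. The values below are placeholders that make the snippet self-contained; they are assumptions, not the original configuration:

model_dir = './fixup_cifar10'  # placeholder path
params = {}
params['batch_size'] = 128                                      # assumed batch size
params['steps_per_epoch'] = 50000 // params['batch_size']       # CIFAR-10 train set size
params['total_steps_train'] = 160 * params['steps_per_epoch']   # e.g. 160 epochs
params['throttle_eval'] = 600                                   # evaluate at most every 10 min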
Example #8
    def compile(self, run_config, envir_config, model_fn):
        import logging
        import sys
        # Config logging
        tf.logging.set_verbosity(logging.INFO)
        handlers = [
            logging.FileHandler(os.path.join(run_config.store_dir,
                                             'main.log')),
            logging.StreamHandler(sys.stdout)
        ]
        logging.getLogger('tensorflow').handlers = handlers
        tf.logging.set_verbosity(tf.logging.INFO)
        self.test_dir = run_config.test_dir
        start = time.time()
        session_config = self._create_session_config(envir_config)
        exe_config = estimator.RunConfig(
            model_dir=run_config.model_dir,
            session_config=session_config,
            save_summary_steps=run_config.save_summary_steps,
            keep_checkpoint_max=run_config.keep_checkpoint_max,
            save_checkpoints_steps=run_config.save_checkpoints_steps,
            keep_checkpoint_every_n_hours=(
                run_config.keep_checkpoint_every_n_hours))

        def _model_fn(features, labels, mode):
            if mode == estimator.ModeKeys.TRAIN:
                loss, accuracy, var_list, hooks = model_fn[mode](features,
                                                                 labels,
                                                                 run_config)
                # Learning rate
                # TODO: organize learning-rate and optimizer configuration
                learning_rate = run_config.learning_rate
                if run_config.scheduler == 'exponential':
                    learning_rate = tf.train.exponential_decay(
                        learning_rate=learning_rate,
                        global_step=tf.train.get_or_create_global_step(),
                        decay_steps=run_config.decay_steps,
                        decay_rate=run_config.decay_rate,
                        staircase=run_config.staircase)
                elif run_config.scheduler == 'step':
                    learning_rate = step_lr(boundaries=run_config.boundaries,
                                            values=run_config.lr_values)
                else:
                    learning_rate = tf.constant(learning_rate,
                                                dtype=tf.float32)
                tf.summary.scalar('lr', learning_rate)
                # Optimizer
                optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
                # Hook
                hooks += [
                    LoggerHook(learning_rate=learning_rate,
                               log_frequency=run_config.log_frequency,
                               batch_size=run_config.batch_size,
                               loss=loss,
                               accuracy=accuracy,
                               metric_names=run_config.class_names)
                ]
                if hasattr(run_config, 'lr_multiplier'):
                    train_op = multi_lr(optimizer, loss, var_list,
                                        run_config.lr_multiplier)
                else:
                    train_op = optimizer.minimize(
                        loss,
                        global_step=tf.train.get_global_step(),
                        var_list=var_list)
                return estimator.EstimatorSpec(estimator.ModeKeys.TRAIN,
                                               loss=loss,
                                               training_hooks=hooks,
                                               train_op=train_op)
            elif mode == estimator.ModeKeys.EVAL:
                loss, metrics = model_fn[mode](features, labels, run_config)
                return estimator.EstimatorSpec(estimator.ModeKeys.EVAL,
                                               loss=loss,
                                               eval_metric_ops=metrics)
            elif mode == estimator.ModeKeys.PREDICT:
                predictions = model_fn[mode](features, run_config)
                return estimator.EstimatorSpec(estimator.ModeKeys.PREDICT,
                                               predictions)
            else:
                raise ValueError("Expected mode in [train, eval, infer], "
                                 "but received {}".format(mode))

        self.executor = estimator.Estimator(model_fn=_model_fn,
                                            model_dir=run_config.model_dir,
                                            config=exe_config)
        self.steps = run_config.steps
        print(">>>>>>>>>>>>Finish Compiling in {:.2}s>>>>>>>>>>>>".format(
            time.time() - start))
        print(envir_config)
        print(run_config)
        flag = input('Is the configuration correct? (yes/no) ')
        if flag not in ['yes', 'y', '1']:
            exit(-1)
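
Two pieces referenced above are not shown: the `step_lr` schedule helper and how `compile` is invoked. A hedged sketch of both, assuming `step_lr` wraps `tf.train.piecewise_constant` and that `model_fn` is a dict keyed by mode (as the `model_fn[mode]` lookups imply); names such as `runner` and the `build_*_fn` callables are placeholders:

def step_lr(boundaries, values):
    # Assumed helper: piecewise-constant learning-rate schedule over the
    # global step, matching the scheduler == 'step' branch above.
    return tf.train.piecewise_constant(
        tf.train.get_or_create_global_step(), boundaries, values)


# Hypothetical invocation: model_fn maps each mode to a graph-building
# callable, mirroring the model_fn[mode](...) calls inside compile().
runner.compile(run_config=run_cfg,
               envir_config=env_cfg,
               model_fn={
                   estimator.ModeKeys.TRAIN: build_train_fn,
                   estimator.ModeKeys.EVAL: build_eval_fn,
                   estimator.ModeKeys.PREDICT: build_predict_fn,
               })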