Esempio n. 1
0
    def from_trial(
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        """Assemble a ``TFKerasTrialController`` from a trial and its context.

        Steps, in order: type-check ``context`` and ``trial_inst``, configure a
        session, adapt the trial's training and validation data loaders, build
        the model, compile it via ``compile_model``, collect the trial's Keras
        callbacks, and construct the controller.

        Args:
            trial_inst: The user trial; checked to be a ``TFKerasTrial``.
            context: The trial context; checked to be a
                ``keras.TFKerasTrialContext``.
            env: Passed to session configuration, model compilation and the
                controller.
            workloads: Forwarded unchanged to the controller.
            load_path: Forwarded unchanged to the controller.
            rendezvous_info: Forwarded unchanged to the controller.
            hvd_config: Passed to session configuration, model compilation and
                the controller.

        Returns:
            The newly constructed ``TFKerasTrialController``.

        NOTE(review): validation failures are reported through the ``check``
        helpers; the exception type they raise is not visible in this block.
        NOTE(review): no decorator is visible above this definition; the
        signature has no ``self``/``cls`` -- confirm it is a static method.
        """
        # Narrow both inputs to the Keras-specific types the rest of the
        # method relies on.
        check.is_instance(
            context,
            keras.TFKerasTrialContext,
            "TFKerasTrialController needs a TFKerasTrialContext",
        )
        keras_context = cast(keras.TFKerasTrialContext, context)

        check.is_instance(
            trial_inst,
            TFKerasTrial,
            "TFKerasTrialController needs a TFKerasTrial",
        )
        keras_trial = cast(TFKerasTrial, trial_inst)

        user_session_config = keras_trial.session_config()
        session = TFKerasTrialController._configure_session(
            env, hvd_config, user_session_config
        )

        # Both loaders are adapted with the per-slot batch size.
        train_loader = keras_trial.build_training_data_loader()
        training_data = keras._adapt_data_from_data_loader(
            input_data=train_loader,
            batch_size=keras_context.get_per_slot_batch_size(),
        )

        val_loader = keras_trial.build_validation_data_loader()
        validation_data = keras._adapt_data_from_data_loader(
            input_data=val_loader,
            batch_size=keras_context.get_per_slot_batch_size(),
        )

        # build_model() is expected to leave a wrapped model and recorded
        # compile arguments on the context; both are verified below.
        keras_trial.build_model()
        check.is_not_none(keras_context.model, "Please call wrap_model(...).")
        check.is_not_none(
            keras_context.compile_args, "Please call model.compile(...)."
        )
        bound_compile_args = cast(
            inspect.BoundArguments, keras_context.compile_args
        )

        TFKerasTrialController.compile_model(
            context=keras_context,
            compile_args=bound_compile_args,
            env=env,
            hvd_config=hvd_config,
        )

        callbacks = keras_trial.keras_callbacks()
        train_config = keras.TFKerasTrainConfig(
            training_data, validation_data, callbacks
        )

        return TFKerasTrialController(
            keras_context.model,
            session,
            train_config,
            keras_context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
Esempio n. 2
0
    def from_trial(
        cls: Type["TFKerasTrialController"],
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        workloads: Optional[workload.Stream] = None,
    ) -> det.TrialController:
        """Assemble a controller of type ``cls`` from a trial and its context.

        Steps, in order: type-check ``context`` and ``trial_inst``, configure a
        session (with horovod when the distributed size exceeds one), adapt
        the trial's training and validation data loaders, build the model,
        compile it via ``cls.compile_model``, collect the trial's Keras
        callbacks, and construct the controller.

        Args:
            cls: The controller class to instantiate.
            trial_inst: The user trial; checked to be a ``TFKerasTrial``.
            context: The trial context; checked to be a
                ``keras.TFKerasTrialContext``.
            env: Passed to session configuration, model compilation and the
                controller.
            workloads: Forwarded unchanged to the controller; defaults to
                ``None``.

        Returns:
            The newly constructed ``cls`` instance.

        NOTE(review): validation failures are reported through the ``check``
        helpers; the exception type they raise is not visible in this block.
        NOTE(review): no decorator is visible above this definition; the
        ``cls`` parameter suggests a classmethod -- confirm.
        """
        # Narrow both inputs to the Keras-specific types the rest of the
        # method relies on.
        check.is_instance(
            context,
            keras.TFKerasTrialContext,
            "TFKerasTrialController needs a TFKerasTrialContext",
        )
        keras_context = cast(keras.TFKerasTrialContext, context)

        check.is_instance(
            trial_inst,
            TFKerasTrial,
            "TFKerasTrialController needs a TFKerasTrial",
        )
        keras_trial = cast(TFKerasTrial, trial_inst)

        # Keras only supports horovod backend for distributed training
        user_session_config = keras_trial.session_config()
        is_distributed = keras_context.distributed.size > 1
        session = cls._configure_session(
            env, user_session_config, use_horovod=is_distributed
        )

        # Both loaders are adapted with the per-slot batch size.
        train_loader = keras_trial.build_training_data_loader()
        training_data = keras._adapt_data_from_data_loader(
            input_data=train_loader,
            batch_size=keras_context.get_per_slot_batch_size(),
        )

        val_loader = keras_trial.build_validation_data_loader()
        validation_data = keras._adapt_data_from_data_loader(
            input_data=val_loader,
            batch_size=keras_context.get_per_slot_batch_size(),
        )

        # build_model() is expected to leave a wrapped model and recorded
        # compile arguments on the context; both are verified below.
        keras_trial.build_model()
        check.is_not_none(keras_context.model, "Please call wrap_model(...).")
        check.is_not_none(
            keras_context.compile_args, "Please call model.compile(...)."
        )
        bound_compile_args = cast(
            inspect.BoundArguments, keras_context.compile_args
        )

        cls.compile_model(
            context=keras_context, compile_args=bound_compile_args, env=env
        )

        callbacks = keras_trial.keras_callbacks()
        train_config = keras.TFKerasTrainConfig(
            training_data, validation_data, callbacks
        )

        return cls(
            keras_context.model,
            session,
            train_config,
            keras_trial,
            keras_context,
            env,
            workloads,
        )
Esempio n. 3
0
    def from_trial(
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        """Assemble a ``TFKerasTrialController`` from a trial and its context.

        Steps, in order: type-check ``context`` and ``trial_inst``, configure a
        session, split each data loader into ``(x, y, sample_weight)`` and
        adapt it, build the model, replace the model and optimizer with the
        results of ``keras._get_multi_gpu_model_and_optimizer``, compile the
        model, collect the trial's Keras callbacks, and construct the
        controller.

        Training data is adapted with ``drop_leftovers=True`` and validation
        data with ``drop_leftovers=False``.

        Args:
            trial_inst: The user trial; checked to be a ``TFKerasTrial``.
            context: The trial context; checked to be a
                ``keras.TFKerasTrialContext``.
            env: Passed to session configuration, the multi-GPU helper and the
                controller; also supplies the profile frequency.
            workloads: Forwarded unchanged to the controller.
            load_path: Forwarded unchanged to the controller.
            rendezvous_info: Forwarded unchanged to the controller.
            hvd_config: Passed to session configuration, the multi-GPU helper
                and the controller; its ``use`` flag selects the compile mode.

        Returns:
            The newly constructed ``TFKerasTrialController``.

        NOTE(review): validation failures are reported through the ``check``
        helpers; the exception type they raise is not visible in this block.
        NOTE(review): no decorator is visible above this definition; the
        signature has no ``self``/``cls`` -- confirm it is a static method.
        """
        # Narrow both inputs to the Keras-specific types the rest of the
        # method relies on.
        check.is_instance(
            context, keras.TFKerasTrialContext,
            "TFKerasTrialController needs a TFKerasTrialContext")
        keras_context = cast(keras.TFKerasTrialContext, context)

        check.is_instance(
            trial_inst, TFKerasTrial,
            "TFKerasTrialController needs a TFKerasTrial")
        keras_trial = cast(TFKerasTrial, trial_inst)

        user_session_config = keras_trial.session_config()
        session = TFKerasTrialController._configure_session(
            env, hvd_config, user_session_config
        )

        # Training set: leftovers are dropped.
        train_loader = keras_trial.build_training_data_loader()
        train_x, train_y, train_weights = keras._get_x_y_and_sample_weight(
            input_data=train_loader
        )
        training_data = keras._adapt_keras_data(
            x=train_x,
            y=train_y,
            sample_weight=train_weights,
            batch_size=keras_context.get_per_slot_batch_size(),
            drop_leftovers=True,
        )

        # Validation set: leftovers are kept.
        val_loader = keras_trial.build_validation_data_loader()
        eval_x, eval_y, eval_weights = keras._get_x_y_and_sample_weight(
            input_data=val_loader
        )
        validation_data = keras._adapt_keras_data(
            x=eval_x,
            y=eval_y,
            sample_weight=eval_weights,
            batch_size=keras_context.get_per_slot_batch_size(),
            drop_leftovers=False,
        )

        # build_model() is expected to leave a wrapped model and recorded
        # compile arguments on the context; both are verified below.
        keras_trial.build_model()
        check.is_not_none(keras_context.model, "Please call wrap_model(...).")
        check.is_not_none(
            keras_context.compile_args, "Please call model.compile(...)."
        )
        bound_compile_args = cast(
            inspect.BoundArguments, keras_context.compile_args
        )

        # Swap in the model/optimizer pair returned by the multi-GPU helper so
        # that the compile call below uses them.
        new_model, new_optimizer = keras._get_multi_gpu_model_and_optimizer(
            pre_compiled_model=keras_context.model,
            optimizer=bound_compile_args.arguments["optimizer"],
            env=env,
            hvd_config=hvd_config,
            profile_frequency=env.experiment_config.profile_frequency(),
            profile_filename=DeterminedProfiler.OUTPUT_FILENAME,
        )
        keras_context.model = new_model
        bound_compile_args.arguments["optimizer"] = new_optimizer

        # With horovod on TF >= 2.0.0 one extra keyword is passed to compile().
        extra_compile_kwargs = {}
        if hvd_config.use and version.parse(tf.__version__) >= version.parse("2.0.0"):
            logging.info(
                "Calling `model.compile(...)` with `experimental_run_tf_function=False` to ensure "
                "TensorFlow calls `optimizer.get_gradients()` to compute gradients."
            )
            extra_compile_kwargs["experimental_run_tf_function"] = False
        keras_context.model.compile(
            *bound_compile_args.args,
            **bound_compile_args.kwargs,
            **extra_compile_kwargs,
        )

        callbacks = keras_trial.keras_callbacks()
        train_config = keras.TFKerasTrainConfig(
            training_data, validation_data, callbacks
        )

        return TFKerasTrialController(
            keras_context.model,
            session,
            train_config,
            keras_context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )
    def from_trial(
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        """Assemble a ``TFKerasTrialController`` from a trial and its context.

        Steps, in order: type-check ``context`` and ``trial_inst``, configure a
        session, build both data loaders, build the model, split each loader
        into ``(x, y, sample_weight)`` and adapt it using the context's
        ``_fit_*`` settings, compile the model via ``compile_model``, collect
        the trial's Keras callbacks, and construct the controller.

        Training data is adapted with ``drop_leftovers=True`` and validation
        data with ``drop_leftovers=False``.

        Args:
            trial_inst: The user trial; checked to be a ``TFKerasTrial``.
            context: The trial context; checked to be a
                ``keras.TFKerasTrialContext``.
            env: Passed to session configuration, model compilation and the
                controller.
            workloads: Forwarded unchanged to the controller.
            load_path: Forwarded unchanged to the controller.
            rendezvous_info: Forwarded unchanged to the controller.
            hvd_config: Passed to session configuration, model compilation and
                the controller.

        Returns:
            The newly constructed ``TFKerasTrialController``.

        NOTE(review): validation failures are reported through the ``check``
        helpers; the exception type they raise is not visible in this block.
        NOTE(review): no decorator is visible above this definition; the
        signature has no ``self``/``cls`` -- confirm it is a static method.
        """
        # Narrow both inputs to the Keras-specific types the rest of the
        # method relies on.
        check.is_instance(
            context,
            keras.TFKerasTrialContext,
            "TFKerasTrialController needs a TFKerasTrialContext",
        )
        keras_context = cast(keras.TFKerasTrialContext, context)

        check.is_instance(
            trial_inst,
            TFKerasTrial,
            "TFKerasTrialController needs a TFKerasTrial",
        )
        keras_trial = cast(TFKerasTrial, trial_inst)

        user_session_config = keras_trial.session_config()
        session = TFKerasTrialController._configure_session(
            env, hvd_config, user_session_config
        )

        # The loaders are created before the model is built; they are only
        # adapted afterwards, once the context's _fit_* settings are read.
        train_loader = keras_trial.build_training_data_loader()
        val_loader = keras_trial.build_validation_data_loader()

        keras_trial.build_model()
        check.is_not_none(keras_context.model, "Please call wrap_model(...).")

        # Training set: leftovers are dropped.
        train_x, train_y, train_weights = keras._get_x_y_and_sample_weight(
            input_data=train_loader
        )
        training_data = keras._adapt_keras_data(
            x=train_x,
            y=train_y,
            sample_weight=train_weights,
            batch_size=keras_context.get_per_slot_batch_size(),
            use_multiprocessing=keras_context._fit_use_multiprocessing,
            workers=keras_context._fit_workers,
            max_queue_size=keras_context._fit_max_queue_size,
            drop_leftovers=True,
        )

        # Validation set: leftovers are kept.
        eval_x, eval_y, eval_weights = keras._get_x_y_and_sample_weight(
            input_data=val_loader
        )
        validation_data = keras._adapt_keras_data(
            x=eval_x,
            y=eval_y,
            sample_weight=eval_weights,
            batch_size=keras_context.get_per_slot_batch_size(),
            use_multiprocessing=keras_context._fit_use_multiprocessing,
            workers=keras_context._fit_workers,
            max_queue_size=keras_context._fit_max_queue_size,
            drop_leftovers=False,
        )

        check.is_not_none(
            keras_context.compile_args, "Please call model.compile(...)."
        )
        bound_compile_args = cast(
            inspect.BoundArguments, keras_context.compile_args
        )

        TFKerasTrialController.compile_model(
            context=keras_context,
            compile_args=bound_compile_args,
            env=env,
            hvd_config=hvd_config,
        )

        callbacks = keras_trial.keras_callbacks()
        train_config = keras.TFKerasTrainConfig(
            training_data, validation_data, callbacks
        )

        return TFKerasTrialController(
            keras_context.model,
            session,
            train_config,
            keras_context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )