Example #1
 def set_callbacks(self, callbacks: list, model: k.Model = None):
   """Configures callbacks for use in various training loops.
       """
   callback_list = CallbackList(callbacks)
   callback_list.set_model(self.train_model if model is None else model)
   callback_list.model.stop_training = False
   self.callback_list = callback_list
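
A minimal standalone sketch (not part of the example above) of how a CallbackList configured this way is typically driven by a custom training loop; the toy model and the EarlyStopping callback are illustrative assumptions:

import tensorflow as tf
from tensorflow.keras.callbacks import CallbackList, EarlyStopping

# Toy model standing in for self.train_model in the example above.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse")

callback_list = CallbackList([EarlyStopping(monitor="loss", patience=2)])
callback_list.set_model(model)
callback_list.model.stop_training = False

callback_list.on_train_begin()
for epoch in range(5):
    callback_list.on_epoch_begin(epoch)
    # ... one epoch of custom training would go here ...
    callback_list.on_epoch_end(epoch, logs={"loss": 1.0 / (epoch + 1)})
    if callback_list.model.stop_training:
        break
callback_list.on_train_end()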
Example #2
    def _init_callbacks(self, callbacks_list):
        if callbacks_list is None:
            self._callbacks_list = CallbackList([])
            self._callbacks_list.set_model(tf.keras.Model())
        else:
            self._callbacks_list = callbacks_list

        self._callbacks_list.model.stop_training = False
Example #3
    def fit_generator(self, generator, epochs=1,
                      validation_data=None,
                      callbacks=None,
                      verbose=True):
        method = self._model.optimizer.method
        x0 = self._collect_weights()
        history = History()
        _callbacks = [BaseLogger(stateful_metrics=self._model.metrics_names)]
        _callbacks += (callbacks or []) + [history]
        callback_list = CallbackList(_callbacks)
        callback_list.set_model(self._model)
        callback_list.set_params({
            'epochs': epochs,
            'verbose': False,
            'metrics': list(self._model.metrics_names),
        })
        state = {
            'epoch': 0,
            'verbose': verbose,
            'callbacks': callback_list,
            'in_epoch': False,
            'epoch_logs': {},
        }
        min_options = {
            'maxiter': epochs,
            'maxfun': epochs*10,
            'ftol': 1e-10,
            'gtol': 1e-10,
            'eps': 1e-8,
        }

        val_generator = None
        if validation_data is not None:
            if isinstance(validation_data, keras.utils.Sequence):
                val_generator = validation_data
            elif isinstance(validation_data, tuple) and len(validation_data) == 2:
                val_generator = GeneratorWrapper(*validation_data)

        def on_iteration_end(xk):
            cb = state['callbacks']
            if val_generator is not None:
                self._validate(xk, val_generator, state)
            cb.on_epoch_end(state['epoch'], state['epoch_logs'])
            # if state['verbose']:
            #     epoch_logs = state['epoch_logs']
            #     print('epoch: ', state['epoch'],
            #           ', '.join([' {0}: {1:.3e}'.format(k, v) for k, v in epoch_logs.items()]))
            state['epoch'] += 1
            state['in_epoch'] = False
            state['epoch_logs'] = {}

        callback_list.on_train_begin()
        result = minimize(
            self._fun_generator, x0, method=method, jac=True, options=min_options,
            callback=on_iteration_end, args=(generator, state))
        self._update_weights(result['x'])
        callback_list.on_train_end()
        return history
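
For context on the loop above: scipy.optimize.minimize() invokes its `callback` argument once per optimizer iteration with the current parameter vector, and the example maps each such call to one Keras epoch. A small self-contained illustration of that hook (the objective and names here are stand-ins, not taken from the example):

import numpy as np
from scipy.optimize import minimize

def rosen(x):
    # Rosenbrock function, used only as a stand-in objective.
    return np.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

state = {"epoch": 0}

def on_iteration_end(xk):
    # Analogous to callback_list.on_epoch_end(state['epoch'], ...) above.
    state["epoch"] += 1

result = minimize(rosen, np.array([1.3, 0.7, 0.8]), method="L-BFGS-B",
                  callback=on_iteration_end, options={"maxiter": 5})
print(state["epoch"], result.nit)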
Example #4
def prepare_callbacks(g,
                      d,
                      callbacks,
                      n_epochs=1,
                      n_batches=1,
                      include_d_metrics=False):
    # all the callback stuff is from https://github.com/keras-team/keras/blob/master/keras/engine/training_generator.py
    # NOTE: see if saving the weights of d_on_g is enough
    # Prepare display labels.
    out_labels = g.metrics_names
    out_labels = [_replace_label_first_underscore(l) for l in out_labels]
    # we only want to validate on the output of g
    val_out_labels = ['val_' + n for n in out_labels if g.name in n]
    callback_metrics = out_labels + val_out_labels
    if include_d_metrics:
        d_metrics_names = d.metrics_names
        d_metrics_fake = ['d_training/' + l + '_fake' for l in d_metrics_names]
        d_metrics_real = ['d_training/' + l + '_real' for l in d_metrics_names]
        d_metrics_names = d_metrics_fake + d_metrics_real
        callback_metrics += d_metrics_names
    # prepare callbacks
    g.history = cbks.History()
    _callbacks = [cbks.BaseLogger(stateful_metrics=g.metrics_names[1:])]
    _callbacks += (callbacks or []) + [g.history]
    callbacks = CallbackList(_callbacks)

    # it's possible to callback a different model than self:
    callback_model = g._get_callback_model()

    callbacks.set_model(callback_model)
    callbacks.set_params({
        'epochs': n_epochs,
        'steps': n_batches,
        'verbose': 0,
        # 'do_validation': do_validation, to set when using validation data
        'metrics': callback_metrics,
    })
    if not include_d_metrics:
        d_metrics_fake, d_metrics_real = None, None
    return callbacks, out_labels, val_out_labels, d_metrics_fake, d_metrics_real
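
The helper above wires a BaseLogger and a History into the CallbackList so that per-batch logs are averaged into per-epoch history entries. A hedged, self-contained sketch of just that mechanism (assuming a TF 2.x release that still ships tf.keras.callbacks.BaseLogger; the loss values are made up):

import tensorflow as tf
from tensorflow.keras.callbacks import BaseLogger, CallbackList, History

history = History()
cb_list = CallbackList([BaseLogger(), history])
cb_list.set_model(tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))]))
cb_list.set_params({"epochs": 1, "steps": 2, "verbose": 0, "metrics": ["loss"]})

cb_list.on_train_begin()
cb_list.on_epoch_begin(0)
for batch, loss in enumerate([0.8, 0.4]):
    batch_logs = {"batch": batch, "size": 16, "loss": loss}
    cb_list.on_batch_begin(batch, batch_logs)
    cb_list.on_batch_end(batch, batch_logs)
epoch_logs = {}
cb_list.on_epoch_end(0, epoch_logs)  # BaseLogger writes the averaged "loss" into epoch_logs
cb_list.on_train_end()
print(history.history)  # expected: {'loss': [~0.6]}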
Example #5
    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            epochs=1,
            verbose=1,
            initial_epoch=0,
            validation_split=0.,
            validation_data=None,
            shuffle=True,
            callbacks=None):
        """

        :param x: Numpy array of training data (if the model has a single input), or list of Numpy arrays (if the model has multiple inputs). If input layers in the model are named, you can also pass a
            dictionary mapping input names to Numpy arrays.
        :param y: Numpy array of target (label) data (if the model has a single output), or list of Numpy arrays (if the model has multiple outputs).
        :param batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 256.
        :param epochs: Integer. Number of epochs to train the model. An epoch is an iteration over the entire `x` and `y` data provided. Note that in conjunction with `initial_epoch`, `epochs` is to be understood as "final epoch". The model is not trained for a number of iterations given by `epochs`, but merely until the epoch of index `epochs` is reached.
        :param verbose: Integer. 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
        :param initial_epoch: Integer. Epoch at which to start training (useful for resuming a previous training run).
        :param validation_split: Float between 0 and 1. Fraction of the training data to be used as validation data. The model will set apart this fraction of the training data, will not train on it, and will evaluate the loss and any model metrics on this data at the end of each epoch. The validation data is selected from the last samples in the `x` and `y` data provided, before shuffling.
        :param validation_data: tuple `(x_val, y_val)` or tuple `(x_val, y_val, val_sample_weights)` on which to evaluate the loss and any model metrics at the end of each epoch. The model will not be trained on this data. `validation_data` will override `validation_split`.
        :param shuffle: Boolean. Whether to shuffle the order of the batches at the beginning of each epoch.
        :param callbacks: List of `deepctr_torch.callbacks.Callback` instances to apply during training and validation (if applicable). See [callbacks](https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks). Now available: `EarlyStopping`, `ModelCheckpoint`.

        :return: A `History` object. Its `History.history` attribute is a record of training loss values and metrics values at successive epochs, as well as validation loss values and validation metrics values (if applicable).
        """
        if isinstance(x, dict):
            x = [x[feature] for feature in self.feature_index]

        do_validation = False
        if validation_data:
            do_validation = True
            if len(validation_data) == 2:
                val_x, val_y = validation_data
                val_sample_weight = None
            elif len(validation_data) == 3:
                val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
            else:
                raise ValueError(
                    'When passing a `validation_data` argument, '
                    'it must contain either 2 items (x_val, y_val), '
                    'or 3 items (x_val, y_val, val_sample_weights), '
                    'or alternatively it could be a dataset or a '
                    'dataset iterator. '
                    'However we received `validation_data=%s`' %
                    validation_data)
            if isinstance(val_x, dict):
                val_x = [val_x[feature] for feature in self.feature_index]

        elif validation_split and 0. < validation_split < 1.:
            do_validation = True
            if hasattr(x[0], 'shape'):
                split_at = int(x[0].shape[0] * (1. - validation_split))
            else:
                split_at = int(len(x[0]) * (1. - validation_split))
            x, val_x = (slice_arrays(x, 0,
                                     split_at), slice_arrays(x, split_at))
            y, val_y = (slice_arrays(y, 0,
                                     split_at), slice_arrays(y, split_at))

        else:
            val_x = []
            val_y = []
        for i in range(len(x)):
            if len(x[i].shape) == 1:
                x[i] = np.expand_dims(x[i], axis=1)

        train_tensor_data = Data.TensorDataset(
            torch.from_numpy(np.concatenate(x, axis=-1)), torch.from_numpy(y))
        if batch_size is None:
            batch_size = 256

        model = self.train()
        loss_func = self.loss_func
        optim = self.optim

        if self.gpus:
            print('parallel running on these gpus:', self.gpus)
            model = torch.nn.DataParallel(model, device_ids=self.gpus)
            batch_size *= len(
                self.gpus)  # input `batch_size` is batch_size per gpu
        else:
            print(self.device)

        train_loader = DataLoader(dataset=train_tensor_data,
                                  shuffle=shuffle,
                                  batch_size=batch_size)

        sample_num = len(train_tensor_data)
        steps_per_epoch = (sample_num - 1) // batch_size + 1

        # configure callbacks
        callbacks = (callbacks or []) + [self.history]  # add history callback
        callbacks = CallbackList(callbacks)
        callbacks.on_train_begin()
        callbacks.set_model(self)
        if not hasattr(callbacks, 'model'):
            callbacks.__setattr__('model', self)
        callbacks.model.stop_training = False

        # Train
        print(
            "Train on {0} samples, validate on {1} samples, {2} steps per epoch"
            .format(len(train_tensor_data), len(val_y), steps_per_epoch))
        for epoch in range(initial_epoch, epochs):
            callbacks.on_epoch_begin(epoch)
            epoch_logs = {}
            start_time = time.time()
            loss_epoch = 0
            total_loss_epoch = 0
            train_result = {}
            try:
                with tqdm(enumerate(train_loader), disable=verbose != 1) as t:
                    for _, (x_train, y_train) in t:
                        x = x_train.to(self.device).float()
                        y = y_train.to(self.device).float()

                        y_pred = model(x).squeeze()

                        optim.zero_grad()
                        loss = loss_func(y_pred, y.squeeze(), reduction='sum')
                        reg_loss = self.get_regularization_loss()

                        total_loss = loss + reg_loss + self.aux_loss

                        loss_epoch += loss.item()
                        total_loss_epoch += total_loss.item()
                        total_loss.backward()
                        optim.step()

                        if verbose > 0:
                            for name, metric_fun in self.metrics.items():
                                if name not in train_result:
                                    train_result[name] = []
                                train_result[name].append(
                                    metric_fun(
                                        y.cpu().data.numpy(),
                                        y_pred.cpu().data.numpy().astype(
                                            "float64")))

            except KeyboardInterrupt:
                t.close()
                raise
            t.close()

            # Add epoch_logs
            epoch_logs["loss"] = total_loss_epoch / sample_num
            for name, result in train_result.items():
                epoch_logs[name] = np.sum(result) / steps_per_epoch

            if do_validation:
                eval_result = self.evaluate(val_x, val_y, batch_size)
                for name, result in eval_result.items():
                    epoch_logs["val_" + name] = result
            # verbose
            if verbose > 0:
                epoch_time = int(time.time() - start_time)
                print('Epoch {0}/{1}'.format(epoch + 1, epochs))

                eval_str = "{0}s - loss: {1: .4f}".format(
                    epoch_time, epoch_logs["loss"])

                for name in self.metrics:
                    eval_str += " - " + name + \
                                ": {0: .4f}".format(epoch_logs[name])

                if do_validation:
                    for name in self.metrics:
                        eval_str += " - " + "val_" + name + \
                                    ": {0: .4f}".format(epoch_logs["val_" + name])
                print(eval_str)
            callbacks.on_epoch_end(epoch, epoch_logs)
            if self.stop_training:
                break

        callbacks.on_train_end()

        return self.history
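
The validation_split behaviour described in the docstring (hold out the last fraction of samples, before any shuffling) reduces to a single index split; a minimal sketch with plain numpy, mirroring the split_at formula used above (Keras' slice_arrays is approximated with ordinary slicing):

import numpy as np

x = [np.arange(10).reshape(-1, 1)]   # single-input model: a list holding one feature array
y = np.arange(10)
validation_split = 0.2

split_at = int(x[0].shape[0] * (1. - validation_split))        # 8
x_train, val_x = [a[:split_at] for a in x], [a[split_at:] for a in x]
y_train, val_y = y[:split_at], y[split_at:]
print(split_at, val_y)   # 8 [8 9]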
Example #6
    def _configure_callbacks(self, user_callbacks: Optional[List]) -> None:
        """
        If we pass a callbacks parameter to model.fit() or model.evaluate() which is a
        pre-constructed CallbackList, Keras will not alter it.  We can use this property to
        configure the exact callback order that we want in our system.

        The implementation is based closely on the real
        tf.keras.callbacks.configure_callbacks(), with the following differences:

          - We always assume we have the original Callbacks list.
          - We prepend and append additional Determined and Horovod callbacks.
          - We create a det.keras.CallbackList instead of the normal tf.keras one.
        """

        callbacks = user_callbacks or []
        check.is_instance(
            callbacks,
            list,
            "the callbacks parameter of model.fit() or model.eval() must be a list of Callbacks",
        )

        if self.env.experiment_config.get_records_per_epoch() is None:
            for cb in callbacks:
                if util.is_overridden(
                        cb.on_epoch_end,
                        tf.keras.callbacks.Callback) and not getattr(
                            cb, "_skip_epoch_end_check", False):
                    if isinstance(cb, keras.callbacks.Callback):
                        # New callbacks must obey the rules.
                        raise AssertionError(
                            "it is unsupported to use a Callback that defines on_epoch_end "
                            f"({type(cb).__name__}) without setting the records_per_epoch value "
                            "in the experiment config")
                    else:
                        # Pre-existing callbacks only get a warning.
                        logging.warning(
                            "It is unsupported to use a Callback that defines on_epoch_end "
                            f"({type(cb).__name__})without setting the records_per_epoch value in "
                            "the experiment config. Training will continue but on_epoch_end will "
                            "never be called.")

        # Standard post-callback from the real configure_callbacks().
        # Note that we are not including BaseLogger since it is only for averaging metrics over an
        # entire epoch, and we don't report any metrics in on_epoch_end at all.
        self.model.history = keras.callbacks._DeterminedHistory()
        callbacks = callbacks + [self.model.history]

        if self.context._fit_verbose:
            # Our implementation of verbose=True.
            callbacks = [keras.callbacks._DeterminedProgress()] + callbacks

        # Calculate batches per epoch.  We can only handle batches per epoch, not records per epoch,
        # because we would have to communicate after every batch to know how many records were in
        # each batch on each worker in order to trigger on_epoch_end callbacks correctly.
        batches_per_epoch = None
        records_per_epoch = self.env.experiment_config.get_records_per_epoch()
        if records_per_epoch is not None:
            batches_per_epoch = records_per_epoch // self.context.get_global_batch_size(
            )

        # We wrap all of the callbacks in a single Multiplexer.
        self.multiplexer = TrialControllerMultiplexer(
            self,
            callbacks,
            self.is_chief,
            self.batch_size,
            batches_per_epoch,
            self.multiplexer_load_state,
        )
        callbacks = [self.multiplexer]

        if self.hvd_config.use:
            # Horovod synchronization of initial variables should happen even before we enter our
            # control loop, in case we have an initial validation requested.
            callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)
                         ] + callbacks

        # The remainder of Determined control logic is done with a custom CallbackList
        self.callback_list = CallbackList(callbacks)

        # Disable timing of callbacks in some versions of keras. This can fail in some corner-cases
        # because CallbackList is not designed to allow some callbacks to call other callbacks, and
        # they can interact very poorly.
        if hasattr(self.callback_list, "_timing"):
            self.callback_list._timing["on_train_batch_begin"] = True
            self.callback_list._timing["on_train_batch_end"] = True
            self.callback_list._timing["on_test_batch_begin"] = True
            self.callback_list._timing["on_test_batch_end"] = True
            self.callback_list._timing["on_predict_batch_begin"] = True
            self.callback_list._timing["on_predict_batch_end"] = True

        # callback_model is the model given to callbacks, where we should be checking for
        # stop_training.  In horovod dtrain or non-dtrain, it should always be self.model.
        callback_model = self.model._get_callback_model()
        self.callback_list.set_model(callback_model)

        # Fill in bogus values for most of these... some of them are very complex to calculate.
        set_callback_parameters(
            self.callback_list,
            self.model,
            do_validation=False,
            batch_size=self.batch_size,
            epochs=None,
            steps_per_epoch=None,
            samples=None,
            verbose=False,
            mode=ModeKeys.TRAIN,
        )

        self.callback_list.model.stop_training = False
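
The property this method relies on, namely that Keras leaves a pre-constructed CallbackList untouched when it is passed to fit(), can be seen in isolation with a hedged sketch (observed with TF 2.x; the tiny model and TerminateOnNaN callback are illustrative):

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import CallbackList, History, TerminateOnNaN

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer="sgd", loss="mse")

# Build the exact callback order we want; fit() will use it as-is.
history = History()
cb_list = CallbackList([TerminateOnNaN(), history])
cb_list.set_model(model)

x = np.random.rand(32, 3).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.fit(x, y, epochs=2, verbose=0, callbacks=cb_list)
print(history.history["loss"])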
Example #7
class TFKerasTrialController(det.LoopTrialController):
    @staticmethod
    def supports_averaging_training_metrics() -> bool:
        return True

    @staticmethod
    def pre_execute_hook(env: det.EnvContext,
                         hvd_config: horovod.HorovodContext) -> None:
        # Initialize the correct horovod.
        if hvd_config.use:
            hvd.require_horovod_type("tensorflow.keras",
                                     "TFKerasTrial is in use.")
            hvd.init()

        # Start with a clean graph.
        tf.compat.v1.reset_default_graph()

        TFKerasTrialController._set_random_seeds(env.trial_seed)

        # For the Native API we must configure the Session before running user code.
        if env.experiment_config.native_enabled():
            session_config = tf.compat.v1.ConfigProto(
                allow_soft_placement=True)
            TFKerasTrialController._configure_session(env, hvd_config,
                                                      session_config)

    @staticmethod
    def _set_random_seeds(seed: int) -> None:
        # Set identical random seeds on all training processes. When using horovod, each worker will
        # start at a unique offset in the dataset, ensuring it's processing a unique training batch.
        random.seed(seed)
        np.random.seed(seed)
        tf.compat.v1.set_random_seed(seed)

    @staticmethod
    def _configure_session(
        env: det.EnvContext,
        hvd_config: horovod.HorovodContext,
        session_config: tf.compat.v1.ConfigProto,
    ) -> Optional[tf.compat.v1.Session]:
        if not tf.executing_eagerly():
            session_config.gpu_options.allow_growth = True
            if hvd_config.use:
                # We launch a horovod process per GPU. Each process
                # needs to bind to a unique GPU.
                session_config.gpu_options.visible_device_list = str(
                    hvd.local_rank())

            session = tf.compat.v1.Session(
                graph=tf.compat.v1.get_default_graph(), config=session_config)

            tf.compat.v1.keras.backend.set_session(session)

            return session
        else:
            gpus = tf.config.experimental.list_physical_devices("GPU")

            if len(gpus) > 0:
                local_rank = hvd.local_rank() if hvd_config.use else 0
                gpu = gpus[local_rank]
                tf.config.experimental.set_visible_devices(gpu, "GPU")
                tf.config.experimental.set_memory_growth(gpu, True)

            return None

    @staticmethod
    def compile_model(
        context: keras.TFKerasContext,
        compile_args: inspect.BoundArguments,
        env: det.EnvContext,
        hvd_config: horovod.HorovodContext,
    ) -> None:
        context.model = keras._get_multi_gpu_model_if_using_native_parallel(
            pre_compiled_model=context.model,
            env=env,
            hvd_config=hvd_config,
        )

        if "optimizer" in compile_args.arguments:
            # For backwards compatibility we check if an optimizer is passed as part
            # of the compile call. If `wrap_optimizer()` is used, we will ignore
            # this optimizer.
            compile_args.arguments[
                "optimizer"] = context._process_optimizer_from_compile(
                    compile_args.arguments["optimizer"])

        if hvd_config.use and version.parse("2.0.0") <= version.parse(
                tf.__version__) < version.parse("2.2.0"):
            logging.info(
                "Calling `model.compile(...)` with `experimental_run_tf_function=False` to ensure "
                "TensorFlow calls `optimizer.get_gradients()` to compute gradients."
            )

            context.model.compile(*compile_args.args,
                                  **compile_args.kwargs,
                                  experimental_run_tf_function=False)
        else:
            context.model.compile(*compile_args.args, **compile_args.kwargs)

    @staticmethod
    def from_trial(
        trial_inst: det.Trial,
        context: det.TrialContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        check.is_instance(
            context, keras.TFKerasTrialContext,
            "TFKerasTrialController needs a TFKerasTrialContext")
        context = cast(keras.TFKerasTrialContext, context)

        check.is_instance(trial_inst, TFKerasTrial,
                          "TFKerasTrialController needs a TFKerasTrial")
        trial = cast(TFKerasTrial, trial_inst)

        session = TFKerasTrialController._configure_session(
            env, hvd_config, trial.session_config())

        training_data = keras._adapt_data_from_data_loader(
            input_data=trial.build_training_data_loader(),
            batch_size=context.get_per_slot_batch_size(),
        )

        validation_data = keras._adapt_data_from_data_loader(
            input_data=trial.build_validation_data_loader(),
            batch_size=context.get_per_slot_batch_size(),
        )

        trial.build_model()
        check.is_not_none(context.model, "Please call wrap_model(...).")

        check.is_not_none(context.compile_args,
                          "Please call model.compile(...).")
        compile_args = cast(inspect.BoundArguments, context.compile_args)

        TFKerasTrialController.compile_model(context=context,
                                             compile_args=compile_args,
                                             env=env,
                                             hvd_config=hvd_config)

        tf_keras_callbacks = trial.keras_callbacks()

        return TFKerasTrialController(
            context.model,
            session,
            keras.TFKerasTrainConfig(training_data, validation_data,
                                     tf_keras_callbacks),
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )

    @staticmethod
    def from_native(
        context: det.NativeContext,
        env: det.EnvContext,
        workloads: workload.Stream,
        load_path: Optional[pathlib.Path],
        rendezvous_info: det.RendezvousInfo,
        hvd_config: horovod.HorovodContext,
    ) -> det.TrialController:
        check.is_instance(
            context,
            keras.TFKerasNativeContext,
            "TFKerasTrialController needs a TFKerasSprinkleContext",
        )
        context = cast(keras.TFKerasNativeContext, context)

        check.is_not_none(context.model, "Please call wrap_model(...).")

        check.is_not_none(context.compile_args,
                          "Please call model.compile(...).")
        check.is_not_none(
            context.train_config,
            "Please call model.fit(...) or model.fit_generator(...).")

        # For the Native API, we would break the user's model if we changed the session
        # right now, so we have to trust the user did not modify what we set previously.
        #
        # TODO(ryan): Fix this, probably with a function for configuring the backend session.
        session = tf.compat.v1.keras.backend.get_session()

        compile_args = cast(inspect.BoundArguments, context.compile_args)
        train_config = cast(keras.TFKerasTrainConfig, context.train_config)

        TFKerasTrialController.compile_model(context=context,
                                             compile_args=compile_args,
                                             env=env,
                                             hvd_config=hvd_config)

        return TFKerasTrialController(
            context.model,
            session,
            train_config,
            context,
            env,
            workloads,
            load_path,
            rendezvous_info,
            hvd_config,
        )

    def __init__(
        self,
        model: tf.keras.models.Model,
        session: Optional[tf.compat.v1.Session],
        train_config: keras.TFKerasTrainConfig,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)

        self.model = model
        self.session = session

        # Configure optimizers, done for backwards compatibility.
        self.context._select_optimizers()

        keras._check_if_aggregation_frequency_will_work(
            model=self.model, hvd_config=self.hvd_config)

        self.training_data = train_config.training_data
        self.validation_data = train_config.validation_data

        # Support the deprecated SequenceAdapter API.
        if isinstance(self.training_data, keras.SequenceAdapter):
            self.context._configure_fit(
                workers=self.training_data.workers,
                use_multiprocessing=self.training_data.use_multiprocessing,
                max_queue_size=self.training_data.max_queue_size,
            )
            # Use the provided Sequence directly.
            self.training_data = self.training_data.sequence
        if isinstance(self.validation_data, keras.SequenceAdapter):
            # Ignore these settings and use the same settings as for the fit call.
            self.validation_data = self.validation_data.sequence

        self._check_training_data()
        self._check_validation_data()

        self.enqueuers = []  # type: List[keras._Enqueuer]

        # If a load path is provided, load weights and restore the data location.
        self._load()

        self._configure_callbacks(train_config.callbacks)

        self.train_response_func = None  # type: Optional[workload.ResponseFunc]
        self.train_workload_metrics = []  # type: List[Dict[str, Any]]
        self.train_workload_batches = 0
        self.train_workload_inputs = 0
        self.train_workload_len = 0
        self.test_inputs = 0

    def _check_training_data(self) -> None:
        cacheable_used = self.context.experimental.get_train_cacheable(
        ).is_decorator_used()
        wrap_used = self.context.dataset_initialized

        # Non-tf.data.Datasets should not have used the data layer.
        if not isinstance(self.training_data, tf.data.Dataset):
            if cacheable_used:
                raise det.errors.InvalidExperimentException(
                    "Pass in a tf.data.Dataset object for training data if using "
                    "context.experimental.cache_train_dataset().", )
            return

        # You can't use data layer and the wrap_dataset.
        if cacheable_used and wrap_used:
            raise det.errors.InvalidExperimentException(
                "Please do not use: context.wrap_dataset(dataset) if using "
                "context.experimental.cache_train_dataset() and "
                "context.experimental.cache_validation_dataset().", )

        # You must use either data layer or wrap_dataset.
        if not cacheable_used and not wrap_used:
            raise det.errors.InvalidExperimentException(
                "Please use either context.wrap_dataset(dataset) or "
                "context.experimental.cache_train_dataset() for tf.data.dataset inputs"
            )

    def _check_validation_data(self) -> None:
        cacheable_used = self.context.experimental.get_validation_cacheable(
        ).is_decorator_used()
        wrap_used = self.context.dataset_initialized

        # Non-tf.data.Datasets should not have used the data layer.
        if not isinstance(self.validation_data, tf.data.Dataset):
            if cacheable_used:
                raise det.errors.InvalidExperimentException(
                    "Pass in a tf.data.Dataset object for validation data if using "
                    "context.experimental.cache_validation_dataset().", )
            return

        # You can't use data layer and the wrap_dataset.
        if cacheable_used and wrap_used:
            raise det.errors.InvalidExperimentException(
                "Please do not use: context.wrap_dataset(dataset) if using "
                "context.experimental.cache_train_dataset() and "
                "context.experimental.cache_validation_dataset().", )

        # You must use either data layer or wrap_dataset.
        if not cacheable_used and not wrap_used:
            raise det.errors.InvalidExperimentException(
                "Please use either context.wrap_dataset(dataset) or "
                "context.experimental.cache_validation_dataset() for tf.data.dataset inputs"
            )

    def _configure_callbacks(self, user_callbacks: Optional[List]) -> None:
        """
        If we pass a callbacks parameter to model.fit() or model.evaluate() which is a
        pre-constructed CallbackList, Keras will not alter it.  We can use this property to
        configure the exact callback order that we want in our system.

        The implementation is based closely on the real
        tf.keras.callbacks.configure_callbacks(), with the following differences:

          - We always assume we have the original Callbacks list.
          - We prepend and append additional Determined and Horovod callbacks.
          - We create a det.keras.CallbackList instead of the normal tf.keras one.
        """

        callbacks = user_callbacks or []
        check.is_instance(
            callbacks,
            list,
            "the callbacks parameter of model.fit() or model.eval() must be a list of Callbacks",
        )

        if self.env.experiment_config.get_records_per_epoch() is None:
            for cb in callbacks:
                if util.is_overridden(
                        cb.on_epoch_end,
                        tf.keras.callbacks.Callback) and not getattr(
                            cb, "_skip_epoch_end_check", False):
                    if isinstance(cb, keras.callbacks.Callback):
                        # New callbacks must obey the rules.
                        raise AssertionError(
                            "it is unsupported to use a Callback that defines on_epoch_end "
                            f"({type(cb).__name__}) without setting the records_per_epoch value "
                            "in the experiment config")
                    else:
                        # Pre-existing callbacks only get a warning.
                        logging.warning(
                            "It is unsupported to use a Callback that defines on_epoch_end "
                            f"({type(cb).__name__})without setting the records_per_epoch value in "
                            "the experiment config. Training will continue but on_epoch_end will "
                            "never be called.")

        # Standard post-callback from the real configure_callbacks().
        # Note that we are not including BaseLogger since it is only for averaging metrics over an
        # entire epoch, and we don't report any metrics in on_epoch_end at all.
        self.model.history = keras.callbacks._DeterminedHistory()
        callbacks = callbacks + [self.model.history]

        if self.context._fit_verbose:
            # Our implementation of verbose=True.
            callbacks = [keras.callbacks._DeterminedProgress()] + callbacks

        # Calculate batches per epoch.  We can only handle batches per epoch, not records per epoch,
        # because we would have to communicate after every batch to know how many records were in
        # each batch on each worker in order to trigger on_epoch_end callbacks correctly.
        batches_per_epoch = None
        records_per_epoch = self.env.experiment_config.get_records_per_epoch()
        if records_per_epoch is not None:
            batches_per_epoch = records_per_epoch // self.context.get_global_batch_size(
            )

        # We wrap all of the callbacks in a single Multiplexer.
        self.multiplexer = TrialControllerMultiplexer(
            self,
            callbacks,
            self.is_chief,
            self.batch_size,
            batches_per_epoch,
            self.multiplexer_load_state,
        )
        callbacks = [self.multiplexer]

        if self.hvd_config.use:
            # Horovod synchronization of initial variables should happen even before we enter our
            # control loop, in case we have an initial validation requested.
            callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)
                         ] + callbacks

        # The remainder of Determined control logic is done with a custom CallbackList
        self.callback_list = CallbackList(callbacks)

        # Disable timing of callbacks in some versions of keras. This can fail in some corner-cases
        # because CallbackList is not designed to allow some callbacks to call other callbacks, and
        # they can interact very poorly.
        if hasattr(self.callback_list, "_timing"):
            self.callback_list._timing["on_train_batch_begin"] = True
            self.callback_list._timing["on_train_batch_end"] = True
            self.callback_list._timing["on_test_batch_begin"] = True
            self.callback_list._timing["on_test_batch_end"] = True
            self.callback_list._timing["on_predict_batch_begin"] = True
            self.callback_list._timing["on_predict_batch_end"] = True

        # callback_model is the model given to callbacks, where we should be checking for
        # stop_training.  In horovod dtrain or non-dtrain, it should always be self.model.
        callback_model = self.model._get_callback_model()
        self.callback_list.set_model(callback_model)

        # Fill in bogus values for most of these... some of them are very complex to calculate.
        set_callback_parameters(
            self.callback_list,
            self.model,
            do_validation=False,
            batch_size=self.batch_size,
            epochs=None,
            steps_per_epoch=None,
            samples=None,
            verbose=False,
            mode=ModeKeys.TRAIN,
        )

        self.callback_list.model.stop_training = False

    def _save_checkpoint(self, path: pathlib.Path) -> workload.Response:
        if not self.is_chief:
            return workload.Skipped()

        path.mkdir(parents=True, exist_ok=True)

        # Save model weights. We use `tf` format because `h5` does not support
        # models that subclass `tf.keras.Model` and define custom `call()`
        # and/or `train_step()` functions.
        self.model.save_weights(str(
            path.joinpath("determined-keras-model-weights")),
                                save_format="tf")

        # Save optimizer(s) weights.
        with h5py.File(path.joinpath("determined-keras-optimizer-weights.h5"),
                       "w") as h5file:
            for idx, optimizer in enumerate(self.context._optimizers):
                opt_group = h5file.create_group(f"optimizer-{idx}")
                save_optimizer_weights_to_hdf5_group(opt_group, optimizer)

        # Save RNG state.
        rng_state = get_rng_state()

        with open(path.joinpath("rng_state.pkl"), "wb") as f:
            pickle.dump(rng_state, f)

        # Save user code.
        det.util.write_user_code(path, self.env.on_cluster)

        # Save callback(s) state.
        callbacks_state = self.multiplexer._get_state()
        with path.joinpath("determined-callbacks.v1.pkl").open("wb") as f:
            pickle.dump(callbacks_state, f)

        self.multiplexer._checkpoint_end(path)

        return {
            "framework": f"tensorflow-{tf.__version__}",
            "format": "saved_weights"
        }

    def _load_model_weights(
            self, model_weights_checkpoint_path: pathlib.Path) -> None:
        logging.info(
            f"Restoring model weights from {model_weights_checkpoint_path}.")
        self.model.load_weights(str(model_weights_checkpoint_path))

    def _load_optimizers_weights(
            self, optimizer_weights_checkpoint_path: pathlib.Path) -> None:
        logging.info(
            f"Restoring optimizer weights from {optimizer_weights_checkpoint_path}."
        )
        with h5py.File(optimizer_weights_checkpoint_path, "r") as h5file:
            if "optimizer_weights" in h5file:
                load_optimizer_weights(self.model, h5file["optimizer_weights"],
                                       self.model.optimizer)
                return

            for idx, optimizer in enumerate(self.context._optimizers):
                if f"optimizer-{idx}" in h5file:
                    load_optimizer_weights(self.model,
                                           h5file[f"optimizer-{idx}"],
                                           optimizer)

    def _load_model_and_optimizer_weights_v1(self) -> None:
        self.load_path = cast(pathlib.Path, self.load_path)
        self._load_model_weights(
            self.load_path.joinpath("determined-keras-model"))
        self._load_optimizers_weights(
            self.load_path.joinpath("determined-keras-model"))

    def _load_model_and_optimizer_weights_v2(self) -> None:
        self.load_path = cast(pathlib.Path, self.load_path)
        self._load_model_weights(
            self.load_path.joinpath("determined-keras-model.h5"))
        self._load_optimizers_weights(
            self.load_path.joinpath("determined-keras-model.h5"))

    def _load_model_and_optimizer_weights_v3(self) -> None:
        self.load_path = cast(pathlib.Path, self.load_path)
        self._load_model_weights(
            self.load_path.joinpath("determined-keras-model-weights"))
        self._load_optimizers_weights(
            self.load_path.joinpath("determined-keras-optimizer-weights.h5"))

    def _load(self) -> None:
        self.multiplexer_load_state = None  # type: Optional[Dict]
        if not self.load_path:
            return

        # Find the model checkpoint path; we check multiple naming conventions for backwards compatibility.
        if self.load_path.joinpath("determined-keras-model.h5").exists():
            self._load_model_and_optimizer_weights_v2()
        elif self.load_path.joinpath(
                "determined-keras-optimizer-weights.h5").exists():
            self._load_model_and_optimizer_weights_v3()
        else:
            self._load_model_and_optimizer_weights_v1()

        # Load RNG state.
        try:
            with open(self.load_path.joinpath("rng_state.pkl"), "rb") as f:
                rng_state = pickle.load(f)

            set_rng_state(rng_state)
        except IOError:
            logging.warning("Checkpoint did not include RNG state.")

        # Load callbacks.
        cb_state_path = self.load_path.joinpath("determined-callbacks.v1.pkl")
        if cb_state_path.exists():
            with cb_state_path.open("rb") as f:
                self.multiplexer_load_state = pickle.load(f)

    def run(self) -> None:
        try:
            self._launch_fit()
        except det.errors.WorkerFinishedGracefully:
            pass
        finally:
            self._stop_enqueuers()

    def _launch_fit(self) -> None:
        training_data = self.training_data

        if isinstance(training_data, tf.keras.utils.Sequence):
            # Handle args from fit(): shuffle, workers, use_multiprocessing, and max_queue_size.
            enqueuer = keras._build_enqueuer(
                sequence=training_data,
                workers=self.context._fit_workers,
                use_multiprocessing=self.context._fit_use_multiprocessing,
                max_queue_size=self.context._fit_max_queue_size,
                shard_rank=self.context.distributed.get_rank(),
                num_shards=self.context.distributed.get_size(),
                repeat=True,
                shuffle=self.context._fit_shuffle,
                shuffle_seed=self.context.get_trial_seed(),
                prior_batches_trained=self.env.initial_workload.
                total_batches_processed,
            )
            enqueuer.start()
            self.enqueuers.append(enqueuer)
            training_data = enqueuer.data()

        if isinstance(training_data, tf.data.Dataset):
            training_data = training_data.repeat()
            if self.context._fit_shuffle:
                logging.warning(
                    "You set shuffle=True for a tf.data.Dataset, which will be ignored. "
                    "Please call .shuffle() on your dataset instead.")

        self.model.fit(
            training_data,
            class_weight=self.context._fit_class_weight,
            callbacks=self.callback_list,
            shuffle=False,
            steps_per_epoch=sys.maxsize,
            epochs=IMPOSSIBLY_LARGE_EPOCHS,
            validation_split=0,
            verbose=0,
            workers=0,
        )

    def _launch_evaluate(self) -> Any:
        validation_data = self.validation_data
        steps = None

        if isinstance(validation_data, tf.keras.utils.Sequence):
            # Calculate the length of our validation shard.
            steps = len(validation_data)
            if self.context.distributed.get_size() > 1:
                size = self.context.distributed.get_size()
                rank = self.context.distributed.get_rank()
                steps = steps // size + (1 if steps % size > rank else 0)

            # Handle args from fit(): shuffle, workers, use_multiprocessing, and max_queue_size.
            enqueuer = keras._build_enqueuer(
                sequence=validation_data,
                workers=self.context._fit_workers,
                use_multiprocessing=self.context._fit_use_multiprocessing,
                max_queue_size=self.context._fit_max_queue_size,
                shard_rank=self.context.distributed.get_rank(),
                num_shards=self.context.distributed.get_size(),
                repeat=False,
                shuffle=False,
                shuffle_seed=0,
                prior_batches_trained=0,
            )
            enqueuer.start()
            self.enqueuers.append(enqueuer)
            validation_data = enqueuer.data()

        if isinstance(validation_data, tf.data.Dataset):
            # Handle validation_steps, which in Keras only applies to tf.data.Datasets.
            steps = self.context._fit_validation_steps

        # Starting in TF 2.2, users may define a custom test_step() that does
        # not use the model metrics.
        use_model_metrics = not (
            version.parse(tf.__version__) >= version.parse("2.2.0")
            and is_tf2_enabled() and tf.executing_eagerly())
        evaluate_kwargs = {} if use_model_metrics else {"return_dict": True}

        if self.env.test_mode:
            steps = 1

        metrics_values = self.model.evaluate(
            validation_data,
            callbacks=self.callback_list,
            steps=steps,
            verbose=0,
            workers=0,
            **evaluate_kwargs,
        )
        logging.debug(
            f"Worker finished model.evaluate() with metrics: {metrics_values}."
        )

        # Clean up the enqueuer if we started one.
        if isinstance(self.validation_data, tf.keras.utils.Sequence):
            enqueuer.stop()
            self.enqueuers.remove(enqueuer)

            # A special side-effect of converting the keras sequence to a generator and passing
            # steps explicitly is that keras will exit our generator after N steps and the
            # Sequence.on_epoch_end() that normally runs after the last yield won't run at all
            # because the fit loop will call next() exactly `steps` times.  So we try to match the
            # exact keras behavior by manually calling on_epoch_end() here.
            self.validation_data.on_epoch_end()

        # If the model was compiled with metrics=None, metrics_value will be a single value.
        if not isinstance(metrics_values, (tuple, list, dict)):
            metrics_values = (metrics_values, )

        if use_model_metrics:
            metrics = make_logs(self.model, {},
                                metrics_values,
                                ModeKeys.TEST,
                                prefix="val_")
        else:
            check.is_instance(metrics_values, dict)
            metrics = {f"val_{k}": v for k, v in metrics_values.items()}

        return metrics

    def _control_loop(self) -> None:
        for wkld, args, response_func in self.workloads:
            logging.debug(f"Received wkld {wkld.kind} with args {args}.")
            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Configure the state for a training step.
                self.train_response_func = response_func
                self.train_workload_batches = 0
                self.train_workload_metrics = []
                self.train_workload_len = wkld.num_batches
                self.multiplexer.set_batches_requested(wkld.num_batches)
                break
            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                try:
                    response_func(
                        det.util.wrap_metrics(
                            self._compute_validation_metrics(),
                            self.context.get_stop_requested(),
                            invalid_hp=False,
                        ))
                except det.InvalidHP as e:
                    logging.info(
                        "Invalid hyperparameter exception in trial validation step: {}"
                        .format(e))
                    response_func(
                        util.wrap_metrics(
                            {},
                            self.context.get_stop_requested(),
                            invalid_hp=True,
                        ))
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                check.len_eq(args, 1)
                check.is_instance(args[0], pathlib.Path)
                path = cast(pathlib.Path, args[0])
                response_func(self._save_checkpoint(path))
            elif wkld.kind == workload.Workload.Kind.TERMINATE:
                response_func({} if self.is_chief else workload.Skipped())
                self.multiplexer._corrected_train_end()
                raise det.errors.WorkerFinishedGracefully
            else:
                raise AssertionError(f"Unknown workload kind {wkld.kind}.")

    def _allreduce_logs(self, logs: Dict) -> Dict:
        if not self.hvd_config.use:
            return logs
        # Reduce logs in key-sorted order to be deterministic across workers.
        keys = sorted(logs)
        logging.debug(
            f"all-reducing logs on worker {hvd.rank()} for {len(keys)} keys {keys}."
        )
        return {
            key:
            np.array(self._hvd_allreduce(logs[key], average=True, name=key))
            for key in keys
        }

    def _hvd_allreduce(self, value: Any, average: bool, name: str) -> Any:
        # The signature of our horovod allreduce changed after we rebased onto 0.21.
        hvd_sig = inspect.signature(hvd.allreduce)
        horovod_kwargs = {
            "value": value,
            "name": name,
        }  # type: Dict[str, Any]

        if "op" in hvd_sig.parameters:
            horovod_kwargs["op"] = hvd.Average if average else hvd.Sum

            # `average` is deprecated but has not yet been removed. It defaults
            # to True, and horovod does not support specifying an op while
            # average is not None.
            if "average" in hvd_sig.parameters:
                horovod_kwargs["average"] = None
        else:
            horovod_kwargs["average"] = average

        return hvd.allreduce(**horovod_kwargs)

    def _convert_possible_tensor(self, possible_tensor: Any) -> Any:
        if isinstance(possible_tensor, EagerTensor):
            # Horovod and / or TensorFlow may promote scalars to tensors in eager mode.
            return possible_tensor.numpy()
        return possible_tensor

    def _post_train_batch_end(self, num_inputs: int, logs: Dict) -> None:
        # Remove default keras metrics we aren't interested in like "batch" and "size".
        self.train_workload_metrics.append({
            k: self._convert_possible_tensor(v)
            for k, v in logs.items() if k not in {"batch", "size"}
        })
        self.train_workload_inputs += num_inputs
        self.train_workload_batches += 1
        if self.train_workload_batches != self.train_workload_len:
            return

        if self.train_response_func is None:
            raise AssertionError(
                "Callback should avoid calling model.predict(), "
                "as this will affect Determined training behavior", )

        if self.hvd_config.use:
            num_inputs = self._hvd_allreduce(num_inputs,
                                             average=False,
                                             name="train_num_inputs")
            num_inputs = self._convert_possible_tensor(num_inputs)

        # Return only the latest metrics, which is the running average for all trained batches in
        # the step (Keras does not report individual logs, only running averages at any point).
        final_metrics = self.train_workload_metrics[-1]
        if self.env.experiment_config.averaging_training_metrics_enabled():
            final_metrics = self._allreduce_logs(final_metrics)

        self.multiplexer._train_workload_end(final_metrics)
        self._stop_training_check()

        if self.is_chief:
            # Don't use det.util.make_metrics, because our batch metrics are not raw metrics.

            response = {
                "metrics": {
                    "num_inputs": num_inputs,
                    "batch_metrics": self.train_workload_metrics,
                    "avg_metrics": final_metrics,
                },
                "stop_requested": self.context.get_stop_requested(),
                "invalid_hp": False,
            }
            self.train_response_func(response)
        else:
            self.train_response_func(workload.Skipped())

        self.train_response_func = None

        self._control_loop()

        # Always reset metrics before starting a new training step.
        self.model.reset_metrics()

    def _compute_validation_metrics(self) -> workload.Response:
        metrics = self._launch_evaluate()
        num_inputs = self.multiplexer.get_test_inputs()

        if self.hvd_config.use:
            # Use a global ZMQ barrier here because we have observed cases where hvd.allreduce
            # may hang when called minutes apart by different workers which may happen if
            # workers complete evaluation at different speeds.
            self._global_barrier()

            num_inputs = hvd.allreduce(num_inputs,
                                       average=False,
                                       name="validation_num_inputs")
            if isinstance(num_inputs, EagerTensor):
                # Horovod will promote an int to a tensor in eager mode.
                num_inputs = num_inputs.numpy()

        metrics = self._allreduce_logs(metrics)
        check.gt(len(metrics), 0)

        self.multiplexer._test_end(metrics)

        if not self.is_chief:
            return workload.Skipped()

        return {"num_inputs": num_inputs, "validation_metrics": metrics}

    def _stop_training_check(self) -> None:
        # Detect when users set stop_training and convert it to a set_stop_requested.
        if self.multiplexer.model.stop_training:
            if self.is_chief:
                self.multiplexer.model.stop_training = False
                self.context.set_stop_requested(True)
            else:
                logging.debug(
                    "cancelling model.stop_training on non-chief worker")
                self.multiplexer.model.stop_training = True

    def _stop_enqueuers(self) -> None:
        for enqueuer in self.enqueuers:
            enqueuer.stop()
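
_hvd_allreduce() above adapts to whichever Horovod version is installed by inspecting hvd.allreduce's signature before deciding which keyword arguments to pass. The same version-compatibility pattern in isolation, with stand-in functions rather than Horovod:

import inspect

def call_with_supported_kwargs(fn, **kwargs):
    # Pass only the keyword arguments that fn's signature actually accepts.
    supported = inspect.signature(fn).parameters
    return fn(**{k: v for k, v in kwargs.items() if k in supported})

def allreduce_old(value, average=True):        # stand-in for an older API
    return value, average

def allreduce_new(value, op=None, name=None):  # stand-in for a newer API
    return value, op

print(call_with_supported_kwargs(allreduce_old, value=3, average=False, op="Sum"))
print(call_with_supported_kwargs(allreduce_new, value=3, average=False, op="Sum"))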
Example #8
class _TaskDispatcher(object):
    """Creates and dispatches Tasks. Keep track of a Task's lifecycle."""
    def __init__(
        self,
        training_shards,
        evaluation_shards,
        prediction_shards,
        records_per_task,
        num_epochs,
        callbacks_list=None,
    ):
        """
        Arguments:
            training_shards: A dictionary from RecordIO file name to the
                number of training records.
            evaluation_shards: A dictionary from RecordIO file name to
                the number of evaluation records.
            prediction_shards: A dictionary from RecordIO file name to
                the number of prediction records.
            records_per_task: The number of records per task.
            num_epochs: The total number of epochs for the tasks where
                an epoch is a complete iteration over the shards.
            callbacks_list: The Keras CallbacksList object to contain all
                callback instances.
        """
        self._lock = threading.Lock()

        self._num_epochs = num_epochs
        self._epoch = 0
        self._training_shards = training_shards
        self._evaluation_shards = evaluation_shards
        self._prediction_shards = prediction_shards
        self._records_per_task = records_per_task
        self._init_callbacks(callbacks_list)

        self._todo = []
        # dictionary from task id to Task.
        self._doing = {}
        self._task_id = 0
        self._eval_todo = []
        self._evaluation_service = None

        # Callback list to invoke after all tasks complete.
        self._tasks_done_deferred_callbacks = []

        self._job_counters = {}
        self._task_retry_count = {}

        if self._training_shards:
            logger.info("Starting epoch %d", self._epoch)
            self.create_tasks(elasticdl_pb2.TRAINING)
        elif self._evaluation_shards:
            self.create_tasks(elasticdl_pb2.EVALUATION)
        elif self._prediction_shards:
            self.create_tasks(elasticdl_pb2.PREDICTION)

    def _init_callbacks(self, callbacks_list):
        if callbacks_list is None:
            self._callbacks_list = CallbackList([])
            self._callbacks_list.set_model(tf.keras.Model())
        else:
            self._callbacks_list = callbacks_list

        self._callbacks_list.model.stop_training = False

    def reset_job_counters(self, task_type):
        """Return record number in specific task_type"""
        self._job_counters[task_type] = JobCounter()

    def create_tasks(self, task_type, model_version=-1):
        logger.info(
            "Creating a new set of %s tasks for model version %d",
            elasticdl_pb2._TASKTYPE.values_by_number[task_type].name.lower(),
            model_version,
        )
        self.reset_job_counters(task_type)
        if task_type == elasticdl_pb2.TRAINING:
            shards = self._training_shards
        elif task_type == elasticdl_pb2.EVALUATION:
            shards = self._evaluation_shards
        else:
            shards = self._prediction_shards
        tasks = []
        num_records_before_create = self._job_counters[task_type].total_records
        # Note that a shard may contain records for multiple tasks.
        for shard_name, (start_ind_this_shard,
                         num_records_this_shard) in shards.items():
            max_ind_this_shard = start_ind_this_shard + num_records_this_shard
            self._job_counters[
                task_type].total_records += num_records_this_shard
            for start_ind_this_task in range(
                    start_ind_this_shard,
                    max_ind_this_shard,
                    self._records_per_task,
            ):
                end_ind_this_task = min(
                    start_ind_this_task + self._records_per_task,
                    max_ind_this_shard,
                )

                # Note that only records in [start, end) of this task
                # will be consumed later in the worker that handles
                # this task.
                tasks.append(
                    _Task(
                        shard_name=shard_name,
                        start=start_ind_this_task,
                        end=end_ind_this_task,
                        type=task_type,
                        model_version=model_version,
                    ))
        if task_type == elasticdl_pb2.TRAINING:
            random.shuffle(tasks)
            self._todo.extend(tasks)
        elif task_type == elasticdl_pb2.EVALUATION:
            self._eval_todo.extend(tasks)
        else:
            self._todo.extend(tasks)
        logger.info("%d tasks created with total of %d records." % (
            len(tasks),
            self._job_counters[task_type].total_records -
            num_records_before_create,
        ))

    def get_eval_task(self, worker_id):
        """Return next evaluation (task_id, Task) tuple"""
        with self._lock:
            if not self._eval_todo:
                return -1, None
            self._task_id += 1
            task = self._eval_todo.pop()
            self._doing[self._task_id] = (worker_id, task, time.time())
            return self._task_id, task

    def _create_train_end_callback_task(self):
        """
        Build one training-end task and add it to the todo list.
        Because we need to create a dataset to build the model for
        SavedModelExporter to execute on_train_end, we include
        a shard of data in this task.
        """
        if not self._training_shards:
            return

        self.reset_job_counters(elasticdl_pb2.TRAIN_END_CALLBACK)
        shards = self._training_shards
        assert shards is not None

        (shard_name, (start_ind_this_shard,
                      num_records_this_shard)) = next(iter(shards.items()))
        start_ind_this_task = start_ind_this_shard
        end_ind_this_task = start_ind_this_shard + min(self._records_per_task,
                                                       num_records_this_shard)

        # Use the first shard of data to do the SavedModel work
        train_end_callback_task = _Task(
            shard_name=shard_name,
            start=start_ind_this_task,
            end=end_ind_this_task,
            type=elasticdl_pb2.TRAIN_END_CALLBACK,
        )

        self._todo.append(train_end_callback_task)

    def add_deferred_callback_create_train_end_task(self):
        self._tasks_done_deferred_callbacks.append(
            lambda: self._create_train_end_callback_task())

    def invoke_deferred_callback(self):
        """
        Pop a callback from the list and invoke it.
        If the callback list is empty, return False directly.
        """
        if not self._tasks_done_deferred_callbacks:
            return False

        with self._lock:
            if not self._tasks_done_deferred_callbacks:
                return False

            callback = self._tasks_done_deferred_callbacks.pop()
            callback()
            return True

    def get(self, worker_id):
        """Return next (task_id, Task) tuple"""

        with self._lock:
            # TODO: check whether the task queue has no training tasks,
            #       to avoid it being overwhelmed by evaluation tasks.
            if (not self._todo and not self._callbacks_list.model.stop_training
                    and self._epoch < self._num_epochs - 1):
                # Start a new epoch
                self._epoch += 1
                self.create_tasks(elasticdl_pb2.TRAINING)
                logger.info("Starting epoch %d", self._epoch)

            if not self._todo:
                # No more tasks
                return -1, None

            self._task_id += 1
            task = self._todo.pop()
            # TODO: Handle timeout of tasks.
            self._doing[self._task_id] = (worker_id, task, time.time())

            return self._task_id, task

    def report(self, request, success):
        """Report if the task is successful or not"""

        task_id = request.task_id
        evaluation_task_completed = False
        with self._lock:
            worker_id, task, start_time = self._doing.pop(
                task_id, (-1, None, -1))
            if task:
                self._job_counters[
                    task.type].failed_records += request.exec_counters.get(
                        TaskExecCounterKey.FAIL_COUNT, 0)
            if not task:
                logger.warning("Unknown task_id: %d" % task_id)
            elif not success:
                logger.warning("Task %d of %s failed " % (task_id, task.type))
                if not self.check_exceed_max_task_retries(task):
                    if task.type in [
                            elasticdl_pb2.TRAINING,
                            elasticdl_pb2.TRAIN_END_CALLBACK,
                    ]:
                        self._todo.append(task)
                    else:
                        self._eval_todo.append(task)
            elif (task.type == elasticdl_pb2.EVALUATION
                  and self._evaluation_service is not None):
                evaluation_task_completed = True
            else:
                self._call_on_task_end(task)
                logger.info(
                    "Task:%d completed, %d remaining tasks",
                    task_id,
                    len(self._todo) + len(self._doing),
                )
            if evaluation_task_completed:
                self._evaluation_service.complete_task()

            if success:
                if task in self._task_retry_count:
                    del self._task_retry_count[task]
                if self._callbacks_list.model.stop_training:
                    # Clear todo list to stop training
                    self._todo = []

        return (time.time() - start_time), task, worker_id

    def check_exceed_max_task_retries(self, task):
        self._task_retry_count.setdefault(task, 1)
        self._task_retry_count[task] += 1
        if self._task_retry_count[task] > _MAX_TASK_RETRIES:
            logger.error("A %s task failed with %d retries " %
                         (task.type, _MAX_TASK_RETRIES))
            return True
        return False

    def finished(self):
        """Return if all tasks are done"""
        return all([not self._todo, not self._eval_todo, not self._doing])

    def recover_tasks(self, worker_id):
        """Recover doing tasks for a dead worker"""

        with self._lock:
            ids = [
                id for id, (wid, _, _) in self._doing.items()
                if wid == worker_id
            ]
        request = elasticdl_pb2.ReportTaskResultRequest()
        for id in ids:
            request.task_id = id
            self.report(request, False)

    # TODO: need to re-check after refactoring servicer.py
    def set_evaluation_service(self, evaluation_service):
        with self._lock:
            self._evaluation_service = evaluation_service
            if self._evaluation_shards and not self._training_shards:
                evaluation_service.init_eval_only_job(len(self._eval_todo))

    def _call_on_task_end(self, task):
        # on_task_end is not a method of tf.keras.callbacks.Callback or
        # tf.keras.callbacks.CallbackList, so we check for it before
        # calling it on each callback.
        for callback in self._callbacks_list.callbacks:
            if hasattr(callback, "on_task_end"):
                callback.on_task_end(task)
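# _call_on_task_end above probes each callback for an on_task_end hook, which
# is not part of tf.keras.callbacks.Callback. A minimal sketch of such a
# callback and of handing it to the dispatcher; the LogTaskEnd name and the
# shard dictionary below are illustrative, not part of ElasticDL.
import tensorflow as tf
from tensorflow.keras.callbacks import CallbackList


class LogTaskEnd(tf.keras.callbacks.Callback):
    def on_task_end(self, task):
        print("finished task on shard %s [%d, %d)" %
              (task.shard_name, task.start, task.end))


# dispatcher = _TaskDispatcher(
#     training_shards={"train-0000": (0, 1000)},
#     evaluation_shards={},
#     prediction_shards={},
#     records_per_task=100,
#     num_epochs=2,
#     callbacks_list=CallbackList([LogTaskEnd()], model=tf.keras.Model()),
# )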
Example #9
0
def load_callbacks_from_module(callbacks_def, model_module):
    callbacks_def_name = callbacks_def.split(".")[-1]
    callbacks_fn = _get_spec_value(callbacks_def_name, None, model_module)
    callbacks = [] if callbacks_fn is None else callbacks_fn()
    return CallbackList(callbacks)
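# load_callbacks_from_module above resolves the last component of
# callbacks_def to a function in the user's model module and wraps its return
# value in a CallbackList. A minimal sketch of what that module-level function
# might look like; the specific callbacks chosen here are illustrative.
import tensorflow as tf


def callbacks():
    return [
        tf.keras.callbacks.EarlyStopping(monitor="loss", patience=3),
        tf.keras.callbacks.CSVLogger("training.csv"),
    ]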
Example #10
0
                  optimizer="adam",
                  metrics=["accuracy"])

    model_path = "exp-{epoch:02d}-{val_accuracy:.3f}.h5"
    mc = NEpochModelCheckpoint_wOptimizer(model_path,
                                          nepoch=3,
                                          opt=True,
                                          monitor='val_accuracy',
                                          save_weights_only=True,
                                          mode='max',
                                          verbose=1)
    nbar = NBatchProgBarLogger(display_per_batches=10)
    callback_list = CallbackList([mc, nbar],
                                 add_history=True,
                                 add_progbar=False,
                                 model=model,
                                 verbose=1,
                                 epochs=epochs,
                                 steps=steps_per_epoch)
    model.fit(x_train,
              y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_split=0.1,
              steps_per_epoch=steps_per_epoch,
              callbacks=callback_list)

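# NBatchProgBarLogger and NEpochModelCheckpoint_wOptimizer are custom classes
# from the surrounding project. A rough sketch of what a "log every N train
# batches" callback could look like, assuming it only needs to print the
# running batch logs; the real implementation may differ.
import tensorflow as tf


class EveryNBatchesLogger(tf.keras.callbacks.Callback):
    def __init__(self, display_per_batches=10):
        super().__init__()
        self.display_per_batches = display_per_batches
        self.seen = 0

    def on_train_batch_end(self, batch, logs=None):
        self.seen += 1
        if self.seen % self.display_per_batches == 0:
            logs = logs or {}
            formatted = ", ".join("%s=%.4f" % (k, v) for k, v in logs.items())
            print("batch %d: %s" % (self.seen, formatted))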
Example #11
0
    def fit_dataset(self,
                    dataset,
                    steps_per_epoch=None,
                    batch_size=32,
                    epochs=1,
                    verbose=1,
                    callbacks=None,
                    on_sample=None,
                    on_scores=None):
        """Train the model on the given dataset for a given number of epochs.

        Arguments
        ---------
            dataset: Instance of `BaseDataset` that provides the data
                     to train on.
            steps_per_epoch: int or None, number of gradient updates before
                             considering an epoch has passed. If None it is set
                             to be `len(dataset.train_data) / batch_size`.
            batch_size: int, number of samples per gradient update
            epochs: int, number of times to iterate `steps_per_epoch` times
            verbose: {0, >0}, whether to employ the progress bar Keras
                     callback or not
            callbacks: list of Keras callbacks to be called during training
            on_sample: callable that accepts the sampler, idxs, w, scores
            on_scores: callable that accepts the sampler and scores
        """
        try:
            if len(dataset.train_data) < batch_size:
                raise ValueError(("The model cannot be trained with "
                                  "batch_size > training set"))
        except RuntimeError as e:
            assert "no size" in str(e)

        # Set steps_per_epoch properly
        if steps_per_epoch is None:
            steps_per_epoch = len(dataset.train_data) // batch_size

        # Create the callbacks list
        self.history = History()
        callbacks = [BaseLogger()] + (callbacks or []) + [self.history]
        if verbose > 0:
            callbacks += [ProgbarLogger(count_mode="steps")]
        callbacks = CallbackList(callbacks)
        callbacks.set_model(self.original_model)
        callbacks.set_params({
            "epochs": epochs,
            "steps": steps_per_epoch,
            "verbose": verbose,
            "do_validation": len(dataset.test_data) > 0,
            "metrics": (self.metrics_names +
                        ["val_" + n for n in self.metrics_names]),
        })

        # Create the sampler
        sampler = self.sampler(dataset, batch_size, steps_per_epoch, epochs)

        # Start the training loop
        epoch = 0
        self.original_model.stop_training = False
        callbacks.on_train_begin()
        while epoch < epochs:
            callbacks.on_epoch_begin(epoch)
            for step in range(steps_per_epoch):
                batch_logs = {"batch": step, "size": batch_size}
                callbacks.on_batch_begin(step, batch_logs)

                # Importance sampling is done here
                idxs, (x, y), w = sampler.sample(batch_size)
                # Train on the sampled data
                loss, metrics, scores = self.model.train_batch(x, y, w)
                # Update the sampler
                sampler.update(idxs, scores)

                values = map(lambda x: x.mean(), [loss] + metrics)
                for l, o in zip(self.metrics_names, values):
                    batch_logs[l] = o
                callbacks.on_batch_end(step, batch_logs)

                if on_scores is not None and hasattr(self, "_latest_scores"):
                    on_scores(sampler, self._latest_scores)

                if on_sample is not None:
                    on_sample(sampler, self._latest_sample_event["idxs"],
                              self._latest_sample_event["w"],
                              self._latest_sample_event["predicted_scores"])

                if self.original_model.stop_training:
                    break

            # Evaluate now that an epoch passed
            epoch_logs = {}
            if len(dataset.test_data) > 0:
                val = self.model.evaluate(*dataset.test_data[:],
                                          batch_size=batch_size)
                epoch_logs = {
                    "val_" + l: o
                    for l, o in zip(self.metrics_names, val)
                }
            callbacks.on_epoch_end(epoch, epoch_logs)
            if self.original_model.stop_training:
                break
            epoch += 1
        callbacks.on_train_end()

        return self.history
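# The on_sample and on_scores hooks documented above receive the importance
# sampler together with the per-batch indices, weights, and scores. A minimal
# sketch of two such callables and of passing them to fit_dataset; the
# wrapped_model and dataset names are placeholders for objects built
# elsewhere.
import numpy as np


def log_scores(sampler, scores):
    # Called whenever fresh importance scores become available.
    print("mean score %.4f, max score %.4f" %
          (np.mean(scores), np.max(scores)))


def log_sample(sampler, idxs, w, scores):
    # Called after each batch is sampled; w holds the importance weights.
    print("sampled %d points, mean weight %.3f" %
          (len(idxs), float(np.mean(w))))


# history = wrapped_model.fit_dataset(
#     dataset,
#     batch_size=32,
#     epochs=10,
#     on_sample=log_sample,
#     on_scores=log_scores,
# )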
Example #12
0
    def fit_generator(self,
                      generator,
                      n_steps_per_epoch,
                      n_epochs=1,
                      validation_data=None,
                      n_validation_steps=None):
        """Train the network on batches of data generated from `generator`

        :param generator: a generator yielding batches indefinitely, where each
         batch is a tuple of (inputs, targets)
        :type generator: generator
        :param n_steps_per_epoch: number of batches to train on in one epoch
        :type n_steps_per_epoch: int
        :param n_epochs: number of epochs to train the model
        :type n_epochs: int
        :param validation_data: generator yielding batches to evaluate the loss
         on at the end of each epoch, where each batch is a tuple of (inputs,
         targets)
        :type validation_data: generator
        :param n_validation_steps: number of batches to evaluate on from
         `validation_data`
        :raises RuntimeError: if only one of `validation_data` and
         `n_validation_steps` is passed in
        """

        default_callbacks = self._default_callbacks()
        callbacks = CallbackList(default_callbacks)

        self._assert_compiled()

        invalid_inputs = (
            (validation_data is not None and n_validation_steps is None)
            or (n_validation_steps is not None and validation_data is None))
        if invalid_inputs:
            msg = ('`validation_data` and `n_validation_steps` must both be '
                   'passed, or neither.')
            raise RuntimeError(msg)

        if self.device:
            self.network.to(self.device)

        callbacks.set_params({
            'epochs': n_epochs,
            'metrics': ['loss', 'val_loss'],
            'steps': n_steps_per_epoch,
            'verbose': True
        })
        callbacks.set_model(self)

        callbacks.on_train_begin()
        for idx_epoch in range(n_epochs):
            if self.stop_training:
                break

            epoch_logs = {}
            callbacks.on_epoch_begin(idx_epoch)

            for idx_batch in range(n_steps_per_epoch):
                batch_logs = {'batch': idx_batch, 'size': 1}
                callbacks.on_batch_begin(idx_batch, batch_logs)

                inputs, targets = next(generator)
                loss = self.train_on_batch(inputs, targets)

                batch_logs['loss'] = loss
                callbacks.on_batch_end(idx_batch, batch_logs)

                if self.stop_training:
                    break

            if validation_data:
                val_loss = self.evaluate_generator(validation_data,
                                                   n_validation_steps)
                epoch_logs['val_loss'] = val_loss
            callbacks.on_epoch_end(idx_epoch, epoch_logs)
        callbacks.on_train_end()
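# fit_generator above expects a generator that yields (inputs, targets)
# batches indefinitely, and validation_data / n_validation_steps must be
# passed together or not at all. A minimal sketch with random NumPy batches;
# the trainer variable stands in for an instance of the class above.
import numpy as np


def random_batches(batch_size=16, n_features=8):
    while True:
        inputs = np.random.randn(batch_size, n_features).astype("float32")
        targets = np.random.randn(batch_size, 1).astype("float32")
        yield inputs, targets


# trainer.fit_generator(
#     random_batches(),
#     n_steps_per_epoch=100,
#     n_epochs=5,
#     validation_data=random_batches(),
#     n_validation_steps=10,
# )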
Example #13
0
    def train(self, epochs):
        """
        The functional API's model.fit does not support sparse tensors in the
        current implementation, so we write the training loop ourselves.
        """
        callbacks = CallbackList([
            EvaluateCallback(self.valid_generator, prepend_str='val_'),
            TensorBoard(self.log_dir, profile_batch=0),
            ModelCheckpoint(self.model_save_path / 'best.h5py',
                            monitor='val_kendal',
                            save_best_only=True,
                            verbose=1,
                            mode='max'),
            EarlyStopping(monitor='val_kendal',
                          patience=5,
                          mode='max',
                          restore_best_weights=True),
            ReduceLROnPlateau(
                monitor='val_kendal', patience=2, factor=0.5, mode='max'),
        ],
                                 add_history=True,
                                 add_progbar=True,
                                 verbose=1,
                                 model=self.model,
                                 epochs=epochs,
                                 steps=len(self.train_generator))

        callbacks.on_train_begin()
        for epoch in range(epochs):
            if epoch % 5 == 0:
                self.train_generator.gen_new_graphs()
                self.valid_generator.gen_new_graphs()

            callbacks.on_epoch_begin(epoch)
            for batch, (x, y) in enumerate(self.train_generator):
                callbacks.on_train_batch_begin(batch)
                logs = self.model.train_on_batch(x, y, return_dict=True)
                callbacks.on_train_batch_end(batch, logs)

            epoch_logs = copy.copy(logs)
            callbacks.on_epoch_end(epoch, logs=epoch_logs)
            pd.DataFrame(self.model.history.history).to_csv(self.log_dir /
                                                            'history.csv',
                                                            index=False)
            if self.model.stop_training:
                break

        callbacks.on_train_end(copy.copy(epoch_logs))
        print(self.model.history.history)
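# EvaluateCallback above is a project-specific class. A rough sketch of a
# callback with the same shape: evaluate a generator at the end of each epoch
# and write the results into logs under a prefix such as 'val_' so that
# ModelCheckpoint, EarlyStopping, and ReduceLROnPlateau can monitor them.
# The real implementation (and the 'kendal' metric) may differ.
import tensorflow as tf


class EvaluateGeneratorCallback(tf.keras.callbacks.Callback):
    def __init__(self, generator, prepend_str="val_"):
        super().__init__()
        self.generator = generator
        self.prepend_str = prepend_str

    def on_epoch_end(self, epoch, logs=None):
        results = self.model.evaluate(self.generator, verbose=0,
                                      return_dict=True)
        if logs is not None:
            for name, value in results.items():
                logs[self.prepend_str + name] = value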