Example #1
    def test_train(self):
        epoch = (1, 1)
        batch_nos = []
        batch_print_callback = ks.callbacks.LambdaCallback(
            on_batch_begin=lambda batch, logs: batch_nos.append(batch))
        self.orga.cfg.callback_train = batch_print_callback

        history = train_model(self.orga, self.model, epoch, batch_logger=False)
        target = {  # TODO why does this sometimes change?
            'loss': 18.236408802816285,
            'mc_A_loss': 9.647336,
            'mc_B_loss': 8.597588874108167,
        }
        print(history, target)
        assert_dict_arrays_equal(history, target, rtol=0.15)
        self.assertSequenceEqual(
            batch_nos,
            list(range(int(self.file_sizes[0] / self.orga.cfg.batchsize))))
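
For context, the LambdaCallback pattern used in this test can be reproduced standalone. Below is a minimal, self-contained sketch; the toy model and data are stand-ins for illustration, not part of the original test:

import numpy as np
import tensorflow as tf
from tensorflow import keras as ks

# Toy data: 8 samples, batch size 2 -> 4 batches per epoch.
x = np.random.rand(8, 3).astype("float32")
y = np.random.rand(8, 1).astype("float32")

model = ks.Sequential([ks.Input(shape=(3,)), ks.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# Collect the index of every batch as training starts on it.
batch_nos = []
batch_print_callback = ks.callbacks.LambdaCallback(
    on_batch_begin=lambda batch, logs: batch_nos.append(batch))

model.fit(x, y, batch_size=2, epochs=1,
          callbacks=[batch_print_callback], verbose=0)

assert batch_nos == list(range(8 // 2))  # [0, 1, 2, 3]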
Example #2
    def test_train(self):
        epoch = (1, 1)
        batch_nos = []
        batch_print_callback = ks.callbacks.LambdaCallback(
            on_batch_begin=lambda batch, logs: batch_nos.append(batch))
        self.orga.cfg.callback_train = batch_print_callback

        history = train_model(self.orga, self.model, epoch, batch_logger=False)
        target = {  # TODO these sometimes change, so I set rtol to be high. But WHY???
            'loss': 18.252519607543945,
            'mc_A_loss': 9.647336959838867,
            'mc_B_loss': 8.605181694030762,
        }
        print(history, target)
        assert_dict_arrays_equal(history, target, rtol=0.8)
        self.assertSequenceEqual(
            batch_nos,
            list(range(int(self.file_sizes[0] / self.orga.cfg.batchsize))))
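
assert_dict_arrays_equal is a project-internal test helper whose implementation is not shown here. A hypothetical stand-in built on numpy.testing.assert_allclose, just to illustrate the rtol-based comparison these tests rely on:

import numpy as np

def assert_dict_arrays_equal(actual, target, rtol=1e-7):
    # Hypothetical stand-in for the project's helper: compare the
    # values of two dicts key by key within a relative tolerance.
    assert set(actual) == set(target), "dicts have different keys"
    for key, target_value in target.items():
        np.testing.assert_allclose(
            actual[key], target_value, rtol=rtol,
            err_msg=f"mismatch for key '{key}'")

# E.g. a loss of 18.25 vs. a target of 18.24 passes with rtol=0.15.
assert_dict_arrays_equal({"loss": 18.25}, {"loss": 18.24}, rtol=0.15)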
Example #3
    def train(self, model=None):
        """
        Trains a model on the next file.

        The progress of the training is also logged and plotted.

        Parameters
        ----------
        model : ks.models.Model or str, optional
            Compiled keras model to use for training. Required for the first
            epoch (the start of training).
            Can also be the path to a saved keras model, which will be loaded.
            If model is None, the most recent saved model will be
            loaded automatically to continue the training.

        Returns
        -------
        history : dict
            The history of the training on this file. A record of training
            loss values and metrics values.

        """
        # Create folder structure
        self.io.get_subfolder(create=True)
        latest_epoch = self.io.get_latest_epoch()

        model = self._get_model(model, logging=True)

        self._set_up(model, logging=True)

        # epoch about to be trained
        next_epoch = self.io.get_next_epoch(latest_epoch)
        next_epoch_float = self.io.get_epoch_float(*next_epoch)

        if latest_epoch is None:
            self.io.check_connections(model)
            logging.log_start_training(self)

        model_path = self.io.get_model_path(*next_epoch)
        model_path_local = self.io.get_model_path(*next_epoch, local=True)
        if os.path.isfile(model_path):
            raise FileExistsError(
                "Can not train model in epoch {} file {}, this model has "
                "already been saved!".format(*next_epoch))

        smry_logger = logging.SummaryLogger(self, model)

        if self.cfg.learning_rate is not None:
            tf.keras.backend.set_value(
                model.optimizer.lr, self.io.get_learning_rate(next_epoch)
            )

        files_dict = self.io.get_file("train", next_epoch[1])

        line = "Training in epoch {} on file {}/{}".format(
            next_epoch[0], next_epoch[1], self.io.get_no_of_files("train"))
        self.io.print_log(line)
        self.io.print_log("-" * len(line))
        self.io.print_log("Learning rate is at {}".format(
            tf.keras.backend.get_value(model.optimizer.lr)))
        self.io.print_log('Inputs and files:')
        for input_name, input_file in files_dict.items():
            self.io.print_log("   {}: \t{}".format(input_name,
                                                   os.path.basename(
                                                       input_file)))

        start_time = time.time()
        history = backend.train_model(self, model, next_epoch, batch_logger=True)
        elapsed_s = int(time.time() - start_time)

        model.save(model_path)
        smry_logger.write_line(
            next_epoch_float,
            tf.keras.backend.get_value(model.optimizer.lr),
            history_train=history,
        )

        self.io.print_log('Training results:')
        for metric_name, loss in history.items():
            self.io.print_log(f"   {metric_name}: \t{loss}")
        self.io.print_log(f"Elapsed time: {timedelta(seconds=elapsed_s)}")
        self.io.print_log(f"Saved model to: {model_path_local}\n")

        update_summary_plot(self)
        if self.cfg.cleanup_models:
            self.cleanup_models()

        return history
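
A usage sketch for this method, assuming it belongs to an orcanet-style Organizer like the orga object in the tests above. The import path, constructor arguments, and file names below are assumptions for illustration, not confirmed by this snippet:

from tensorflow import keras as ks
from orcanet.core import Organizer  # assumed import path

# Hypothetical setup; the output folder and toml list file are placeholders.
orga = Organizer("output_folder", list_file="file_list.toml")

# First epoch: pass a compiled model (or a path to a saved one).
model = ks.models.load_model("initial_model.h5")
history = orga.train(model=model)

# Subsequent calls: model=None resumes from the latest saved model.
history = orga.train()
print(history)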