Example #1
def test_baseline(model, test_ds, test_ds_info, run_paths):
    # Load ckpts and ckpt manager
    # manager automatically handles model reloading if directory contains ckpts
    # Build the model first, otherwise not all variables will be restored
    model.build(input_shape=tuple([None] + test_ds._flat_shapes[0][1:].as_list()))
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, directory=run_paths['path_ckpts_train'], max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint)

    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        raise FileNotFoundError('No checkpoint for testing...')

    # Prepare Metrics
    metrics_test = metrics.prep_metrics()

    # Testing
    for images, labels in test_ds:
        eval_step(model, images, labels, metrics_test)

    # fetch & reset metrics
    metrics_res_test = metrics.result(metrics_test, as_numpy=True)

    metrics.reset_states(metrics_test)

    logging.info(f'Result: metrics_test: {metrics_res_test}.')

    return metrics_res_test
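
The metrics module used throughout these examples is project-local. As a rough sketch of the interface implied by the call sites (the names prep_metrics, update_state, result, and reset_states come from the calls above; the bodies and the choice of tracked metric are assumptions):

import tensorflow as tf

def prep_metrics():
    # Which metrics are tracked is an assumption; categorical accuracy
    # matches the one-hot labels used in these examples.
    return {'accuracy': tf.keras.metrics.CategoricalAccuracy()}

def update_state(metric_dict, labels, predictions):
    for metric in metric_dict.values():
        metric.update_state(labels, predictions)

def result(metric_dict, as_numpy=False):
    return {k: (m.result().numpy() if as_numpy else m.result())
            for k, m in metric_dict.items()}

def reset_states(metric_dict):
    for metric in metric_dict.values():
        metric.reset_states()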
Example #2
def train_and_eval_baseline(model,
                            train_ds,
                            train_ds_info,
                            eval_ds,
                            test_ds,
                            run_paths,
                            n_epochs=200,
                            lr_base=0.1,
                            lr_momentum=0.9,
                            lr_drop_boundaries=[1, 80, 120],
                            lr_factors=[0.1, 1, 0.1, 0.01],
                            save_period=1):
    # generate summary writers
    writer_train = tf.summary.create_file_writer(
        os.path.dirname(run_paths['path_logs_train']))
    writer_eval = tf.summary.create_file_writer(
        os.path.dirname(run_paths['path_logs_eval']))
    writer_test = tf.summary.create_file_writer(
        os.path.dirname(run_paths['path_logs_test']))
    logging.info(
        f"saving train log to {os.path.dirname(run_paths['path_logs_train'])}")

    # loss
    loss_obj = ks.losses.CategoricalCrossentropy()

    # define optimizer with learning rate schedule
    # 50000 = assumed number of training images (CIFAR-sized dataset);
    # _flat_shapes[0][0] is the batch size
    steps_per_epoch = 50000 // train_ds._flat_shapes[0][0]
    boundaries = [k * steps_per_epoch for k in lr_drop_boundaries]
    lr_values = [k * lr_base for k in lr_factors]
    learning_rate_schedule = ks.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=boundaries, values=lr_values)
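    # With the defaults above (lr_base=0.1, lr_factors=[0.1, 1, 0.1, 0.01],
    # lr_drop_boundaries=[1, 80, 120]) this schedule works out to:
    #   epoch 0:          lr = 0.01  (one warm-up epoch at 0.1 * lr_base)
    #   epochs 1 to 79:   lr = 0.1
    #   epochs 80 to 119: lr = 0.01
    #   epochs 120+:      lr = 0.001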
    optimizer = ks.optimizers.SGD(learning_rate=learning_rate_schedule,
                                  momentum=lr_momentum)

    # define ckpts and ckpt manager
    # manager automatically handles model reloading if directory contains ckpts
    # Build the model first, otherwise not all variables will be restored
    model.build(input_shape=tuple([None] +
                                  train_ds._flat_shapes[0][1:].as_list()))
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, directory=run_paths['path_ckpts_train'], max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint)

    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(
            os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        logging.info("Initializing from scratch.")
        epoch_start = 0

    # metrics
    metric_loss = tf.keras.metrics.Mean()
    metrics_train = metrics.prep_metrics()
    metrics_eval = metrics.prep_metrics()
    metrics_test = metrics.prep_metrics()
    logging.info(f"Training from epoch {epoch_start + 1} to {n_epochs}.")

    # use a tf.Variable for passing the epoch so no new trace is triggered:
    # passing a plain Python int into the tf.function would retrace it for
    # every new epoch value, whereas assigning to epoch_tf does not
    epoch_tf = tf.Variable(1, dtype=tf.int32)

    # Global training time in s
    total_time = 0.0
    # Note: using tf.range also seems to trigger graph tracing
    for epoch in range(epoch_start, int(n_epochs)):
        # Start epoch timer
        start = time()
        eta = (n_epochs - epoch) * (total_time /
                                    (epoch - epoch_start + 1e-12)) / 60
        # assign tf variable, so graph building doesn't get triggered
        epoch_tf.assign(epoch)

        # perform training for one epoch
        logging.info(
            f"Epoch {epoch + 1}/{n_epochs}: starting training, LR:  {optimizer.learning_rate(optimizer.iterations.numpy()).numpy():.5f},  ETA: {eta:.1f} minutes."
        )

        for images, labels in train_ds:
            train_step(model,
                       images,
                       labels,
                       optimizer,
                       loss_obj,
                       metric_loss,
                       metrics_train,
                       epoch_tf=epoch_tf,
                       b_verbose=False)
        # print the model summary once, after the first training epoch (the model is built by then)
        if epoch <= 0:
            model.summary()

        # save train metrics
        loss_avg = metric_loss.result()
        metrics_res_train = metrics.result(metrics_train, as_numpy=True)

        with writer_train.as_default():
            tf.summary.scalar('loss_average', loss_avg, step=epoch)
            for k, v in metrics_res_train.items():
                tf.summary.scalar(k, v, step=epoch)

        # Reset metrics
        metric_loss.reset_states()
        metrics.reset_states(metrics_train)

        # Eval epoch
        for images, labels in eval_ds:
            eval_step(model, images, labels, metrics_eval)

        # fetch & reset metrics
        metrics_res_eval = metrics.result(metrics_eval, as_numpy=True)
        with writer_eval.as_default():
            for k, v in metrics_res_eval.items():
                tf.summary.scalar(k, v, step=epoch)

        metrics.reset_states(metrics_eval)

        # Test epoch
        for images, labels in test_ds:
            eval_step(model, images, labels, metrics_test)

        # fetch & reset metrics
        metrics_res_test = metrics.result(metrics_test, as_numpy=True)
        with writer_test.as_default():
            for k, v in metrics_res_test.items():
                tf.summary.scalar(k, v, step=epoch)

        metrics.reset_states(metrics_test)

        logging.info(
            f'Epoch {epoch + 1}/{n_epochs}: loss_average: {loss_avg}, metrics_train: {metrics_res_train}, metrics_eval: {metrics_res_eval}, metrics_test: {metrics_res_test}.'
        )

        # save a checkpoint every save_period epochs and after the last epoch
        if ((epoch + 1) % save_period == 0) or (epoch == n_epochs - 1):
            logging.info(
                f'Saving checkpoint to {run_paths["path_ckpts_train"]}.')
            ckpt_manager.save(checkpoint_number=epoch)

        # write config after everything has been established
        if epoch <= 0:
            gin_string = gin.operative_config_str()
            logging.info(f'Fetched config parameters: {gin_string}.')
            utils_params.save_gin(run_paths['path_gin'], gin_string)
        # Calc total run_time
        total_time += time() - start
    return metrics_res_eval
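
train_step and eval_step are defined elsewhere in the project. A minimal sketch consistent with the call sites above, assuming the metrics helpers sketched earlier; the bodies, including how epoch_tf and b_verbose are used, are assumptions:

import tensorflow as tf

@tf.function
def train_step(model, images, labels, optimizer, loss_obj, metric_loss,
               metrics_train, epoch_tf, b_verbose=False):
    # Standard supervised step: forward pass, loss, gradients, update.
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_obj(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    metric_loss.update_state(loss)
    metrics.update_state(metrics_train, labels, predictions)
    if b_verbose:
        # reading the tf.Variable inside the graph avoids retracing
        tf.print('epoch:', epoch_tf, 'loss:', loss)

@tf.function
def eval_step(model, images, labels, metrics_eval):
    predictions = model(images, training=False)
    metrics.update_state(metrics_eval, labels, predictions)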
Example #3
    def train(self):
        """Run the meta-training loop.

        :return: evaluation metrics of the final epoch
        """

        # checkpoint and maybe restore model
        self.target_ckpt.restore(self.target_ckpt_manager.latest_checkpoint)
        if self.target_ckpt_manager.latest_checkpoint:
            logging.info(
                f"Restored from {self.target_ckpt_manager.latest_checkpoint}.")
            epoch_start = int(
                os.path.basename(
                    self.target_ckpt_manager.latest_checkpoint).split('-')[1])
        else:
            logging.info("Initializing from scratch.")
            epoch_start = 0

        # use a tf.Variable for passing the epoch so no new trace is triggered
        epoch_tf = tf.Variable(1, dtype=tf.int32)

        # global time counter for eta estimation
        total_time = 0.0

        for epoch in range(epoch_start, int(self.meta_epochs)):
            # Start epoch timer
            start = time()
            eta = (self.meta_epochs - epoch) * (
                total_time / (epoch - epoch_start + 1e-12)) / 60
            # assign tf variable, so graph building doesn't get triggered
            epoch_tf.assign(epoch)
            # Log start of epoch and ETA
            if self.use_lr_drop:
                logging.info(
                    f"Epoch {epoch + 1}/{self.meta_epochs}: starting training, LR:  {self.meta_optimizer.learning_rate(self.meta_optimizer.iterations.numpy()).numpy():.5f},  ETA: {eta:.1f} minutes."
                )
            else:
                logging.info(
                    f"Epoch {epoch + 1}/{self.meta_epochs}: starting training,  ETA: {eta:.1f} minutes."
                )

            # Start iteration over meta batches
            for step_cnt, (images_aug_1, images_aug_2, images,
                           labels) in enumerate(self.ds_train, start=1):
                # Update one (meta) step
                start_step = time()
                self.meta_train_step(images_aug_1, images_aug_2, images,
                                     labels)
                if self.debug:
                    logging.info(
                        f"Step {step_cnt} finished in: {time() - start_step}s, ce_loss: {self.metric_ce_loss.result()}, byol_loss: {self.metric_byol_loss.result()}"
                    )

            # Eval (target) model
            for images, labels in self.ds_val:
                self.eval_step(images, labels)

            # maybe save a checkpoint
            if (epoch % self.save_period == 0) or (epoch + 1
                                                   == self.meta_epochs):
                logging.info(
                    f'Saving checkpoint to {self.run_paths["path_ckpts_train"]}.'
                )
                self.target_ckpt_manager.save(checkpoint_number=epoch)

            # get metrics and losses
            ce_loss = self.metric_ce_loss.result()
            byol_loss = self.metric_byol_loss.result()
            loss = self.metric_loss.result()
            metrics_res_train = metrics.result_meta(self.metrics_train,
                                                    as_numpy=True)
            metrics_res_val = metrics.result(self.metrics_eval, as_numpy=True)

            # logging of metrics
            logging.info(
                f'Epoch {epoch + 1}/{self.meta_epochs}: loss: {loss}, ce_loss: {ce_loss}, byol_loss: {byol_loss}, metrics_train: {metrics_res_train}, metrics_eval: {metrics_res_val}'
            )

            # Saving results into tensorboard
            with self.writer_train.as_default():
                tf.summary.scalar('loss', loss, step=epoch)
                tf.summary.scalar('ce_loss', ce_loss, step=epoch)
                tf.summary.scalar('byol_loss', byol_loss, step=epoch)
                for k, v in metrics_res_train.items():
                    tf.summary.scalar(k, v, step=epoch)

            with self.writer_eval.as_default():
                for k, v in metrics_res_val.items():
                    tf.summary.scalar(k, v, step=epoch)

            # reset metrics
            self.metric_ce_loss.reset_states()
            self.metric_byol_loss.reset_states()
            self.metric_loss.reset_states()
            metrics.reset_states(self.metrics_train)
            metrics.reset_states(self.metrics_eval)

            # save gin config and summarize model
            if epoch <= 0:
                gin_string = gin.operative_config_str()
                logging.info(f'Fetched config parameters: {gin_string}.')
                utils_params.save_gin(self.run_paths['path_gin'], gin_string)
                self.target_model.summary()

            # accumulate run time for the ETA estimate
            total_time += time() - start

        return metrics_res_val
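
The checkpoint objects and summary writers that train() relies on are created elsewhere in the class. A minimal sketch of the assumed __init__ wiring (attribute and run_paths key names are taken from the method above; the construction itself is an assumption):

import os
import tensorflow as tf

class Trainer:
    def __init__(self, target_model, meta_optimizer, run_paths):
        self.target_model = target_model
        self.meta_optimizer = meta_optimizer
        self.run_paths = run_paths
        # checkpoint the target model together with the meta optimizer
        self.target_ckpt = tf.train.Checkpoint(model=self.target_model,
                                               optimizer=self.meta_optimizer)
        self.target_ckpt_manager = tf.train.CheckpointManager(
            self.target_ckpt,
            directory=self.run_paths['path_ckpts_train'],
            max_to_keep=2)
        self.writer_train = tf.summary.create_file_writer(
            os.path.dirname(self.run_paths['path_logs_train']))
        self.writer_eval = tf.summary.create_file_writer(
            os.path.dirname(self.run_paths['path_logs_eval']))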
Example #4
def test_meta(target_model, online_model, test_ds, test_ds_info, run_paths,
              test_lr, num_test_steps):
    # Load ckpts and ckpt manager
    # manager automatically handles model reloading if directory contains ckpts
    # Call the models once first, otherwise not all variables will be restored
    target_model(tf.ones(shape=tuple([1] +
                                     test_ds._flat_shapes[0][1:].as_list())),
                 use_predictor=True)
    online_model(tf.ones(shape=tuple([1] +
                                     test_ds._flat_shapes[0][1:].as_list())),
                 use_predictor=True)
    ckpt = tf.train.Checkpoint(model=target_model)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, directory=run_paths['path_ckpts_train'], max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()

    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(
            os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        raise FileNotFoundError('No checkpoint for testing...')

    # Prepare Metrics
    metrics_test = [metrics.prep_metrics() for _ in range(num_test_steps + 1)]

    # Get optimizer (similar to the inner loop: plain SGD, no momentum)
    optimizer = tf.keras.optimizers.SGD(learning_rate=test_lr, momentum=0.0)

    # define the BYOL loss
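    # For L2-normalized x and y, ||x - y||^2 = 2 - 2 * <x, y>, so the value
    # returned below is the squared Euclidean distance between the normalized
    # online prediction and the normalized (gradient-stopped) target.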
    def byol_loss_fn(x, y):
        x = tf.math.l2_normalize(x, axis=-1)
        y = tf.math.l2_normalize(y, axis=-1)
        return 2 - 2 * tf.math.reduce_sum(x * y, axis=-1)

    @tf.function
    def inner_loop(images_aug_1, images_aug_2, images, labels):
        # reset the online model to the target weights for each test image
        # (online_model.set_weights(target_model.get_weights()) is slower,
        # since get_weights() round-trips through host numpy arrays)
        for w_online, w_target in zip(online_model.weights,
                                      target_model.weights):
            if w_online.dtype != tf.bool:
                w_online.assign(w_target)
        # accuracy without any inner update
        _, _, predictions = online_model(
            images[:1, :, :, :],
            training=False)  # only one image, since the batch repeats it
        metrics.update_state(metrics_test[0], labels[:1, :], predictions)

        # inner updates, measuring accuracy after each
        for k in range(num_test_steps):
            # Get targets from the target network
            _, tar1, _ = target_model(images_aug_1,
                                      use_predictor=False,
                                      training=True)
            _, tar2, _ = target_model(images_aug_2,
                                      use_predictor=False,
                                      training=True)

            # Perform inner optimization
            with tf.GradientTape(persistent=False) as test_tape:
                _, prediction1, _ = online_model(images_aug_1,
                                                 use_predictor=True,
                                                 training=True)
                _, prediction2, _ = online_model(images_aug_2,
                                                 use_predictor=True,
                                                 training=True)
                # Calc byol loss
                loss1 = byol_loss_fn(prediction1, tf.stop_gradient(tar2))
                loss2 = byol_loss_fn(prediction2, tf.stop_gradient(tar1))
                loss = tf.reduce_mean(loss1 + loss2)
            gradients = test_tape.gradient(loss,
                                           online_model.trainable_variables)
            optimizer.apply_gradients(
                zip(gradients, online_model.trainable_variables))
            # get predictions for test acc
            _, _, predictions = online_model(
                images[:1, :, :, :],
                training=False)  # only one image since repetition
            metrics.update_state(metrics_test[k + 1], labels[:1, :],
                                 predictions)
        return 0

    for images_aug_1, images_aug_2, images, labels in test_ds:
        inner_loop(images_aug_1, images_aug_2, images, labels)

    # fetch & reset metrics
    metrics_res_test = [
        metrics.result(metrics_, as_numpy=True) for metrics_ in metrics_test
    ]

    for metrics_ in metrics_test:
        metrics.reset_states(metrics_)

    logging.info(f'Result: metrics_test: {metrics_res_test}.')

    return metrics_res_test
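
A possible call site for the function above; the test_lr and num_test_steps values are purely illustrative, and test_ds is expected to yield (aug_1, aug_2, image, label) batches in which each batch repeats a single test image:

metrics_per_step = test_meta(target_model,
                             online_model,
                             test_ds,
                             test_ds_info,
                             run_paths,
                             test_lr=0.01,
                             num_test_steps=5)
# metrics_per_step[k] holds the test metrics after k inner-loop updates.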