Code example #1
import logging
import os

import tensorflow as tf

import metrics  # project-local metric helpers; a minimal sketch follows this example
# eval_step is project-local as well (also sketched after this example)


def test_baseline(model, test_ds, test_ds_info, run_paths):
    # Load ckpts and ckpt manager
    # manager automatically handles model reloading if directory contains ckpts
    # First build model, otherwise not all variables will be loaded
    model.build(input_shape=tuple([None] + test_ds._flat_shapes[0][1:].as_list()))
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt, directory=run_paths['path_ckpts_train'], max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint)
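    # restoring from None (no checkpoint written yet) is a no-op, so this is
    # safe before the existence check below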

    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        # 'assert <non-empty string>' always passes; raise explicitly instead
        raise FileNotFoundError('No checkpoint for testing...')

    # Prepare Metrics
    metrics_test = metrics.prep_metrics()

    # Testing
    for images, labels in test_ds:
        eval_step(model, images, labels, metrics_test)

    # fetch & reset metrics
    metrics_res_test = metrics.result(metrics_test, as_numpy=True)

    metrics.reset_states(metrics_test)

    logging.info(f'Result: metrics_test: {metrics_res_test}.')

    return metrics_res_test
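The test loop above relies on project-local helpers that the excerpt does not show. A minimal sketch of what they might look like, assuming one categorical-accuracy metric per set (hypothetical; the project's own metrics module and eval_step may differ):

import tensorflow as tf

def prep_metrics():
    # one named Keras metric per entry; the real module may track more
    return {'acc': tf.keras.metrics.CategoricalAccuracy()}

def update_state(metric_set, labels, predictions):
    for m in metric_set.values():
        m.update_state(labels, predictions)

def result(metric_set, as_numpy=False):
    return {k: (m.result().numpy() if as_numpy else m.result())
            for k, m in metric_set.items()}

def reset_states(metric_set):
    for m in metric_set.values():
        m.reset_states()

@tf.function
def eval_step(model, images, labels, metrics_test):
    # forward pass without gradient tracking, then update the metric set
    predictions = model(images, training=False)
    update_state(metrics_test, labels, predictions)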
Code example #2
import logging
import os
from time import time

import gin
import tensorflow as tf
from tensorflow import keras as ks

import metrics       # project-local metric helpers (sketched after code example #1)
import utils_params  # project-local config/run utilities
# train_step and eval_step are project-local; a train_step sketch follows this example


def train_and_eval_baseline(model,
                            train_ds,
                            train_ds_info,
                            eval_ds,
                            test_ds,
                            run_paths,
                            n_epochs=200,
                            lr_base=0.1,
                            lr_momentum=0.9,
                            lr_drop_boundaries=[1, 80, 120],
                            lr_factors=[0.1, 1, 0.1, 0.01],
                            save_period=1):
    # generate summary writer
    writer_train = tf.summary.create_file_writer(
        os.path.dirname(run_paths['path_logs_train']))
    writer_eval = tf.summary.create_file_writer(
        os.path.dirname(run_paths['path_logs_eval']))
    writer_test = tf.summary.create_file_writer(
        os.path.dirname(run_paths['path_logs_test']))
    logging.info(
        f"saving train log to {os.path.dirname(run_paths['path_logs_train'])}")

    # loss
    loss_obj = ks.losses.CategoricalCrossentropy()

    # define optimizer with learning rate schedule
    # 50000 = CIFAR-10 training-set size; _flat_shapes[0][0] is the batch size
    steps_per_epoch = 50000 // train_ds._flat_shapes[0][0]
    boundaries = [k * steps_per_epoch for k in lr_drop_boundaries]
    lr_values = [k * lr_base for k in lr_factors]
    learning_rate_schedule = ks.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=boundaries, values=lr_values)
    optimizer = ks.optimizers.SGD(learning_rate=learning_rate_schedule,
                                  momentum=lr_momentum)
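    # With the defaults above (lr_base=0.1, lr_drop_boundaries=[1, 80, 120],
    # lr_factors=[0.1, 1, 0.1, 0.01]) and a hypothetical batch size of 128
    # (390 steps per epoch), the schedule evaluates to: 0.01 during warm-up
    # epoch 0, 0.1 for epochs 1-79, 0.01 for epochs 80-119, and 0.001 from
    # epoch 120 onwards.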

    # define ckpts and ckpt manager
    # manager automatically handles model reloading if directory contains ckpts
    # First build model, otherwise not all variables will be loaded
    model.build(input_shape=tuple([None] +
                                  train_ds._flat_shapes[0][1:].as_list()))
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, directory=run_paths['path_ckpts_train'], max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint)

    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(
            os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        logging.info("Initializing from scratch.")
        epoch_start = 0

    # metrics
    metric_loss = tf.keras.metrics.Mean()
    metrics_train = metrics.prep_metrics()
    metrics_eval = metrics.prep_metrics()
    metrics_test = metrics.prep_metrics()
    logging.info(f"Training from epoch {epoch_start + 1} to {n_epochs}.")

    # pass the epoch via a tf.Variable so no new tf.function trace is triggered;
    # assigning to epoch_tf each iteration (instead of passing a plain Python int
    # from range()) avoids retracing the train step every epoch
    epoch_tf = tf.Variable(1, dtype=tf.int32)

    # Global training time in s
    total_time = 0.0
    # Note: using tf.range also seems to create a graph
    for epoch in range(epoch_start, int(n_epochs)):
        # Start epoch timer
        start = time()
        # average over the epochs completed in this run (also correct after resume)
        epochs_done = epoch - epoch_start
        eta = (n_epochs - epoch) * (total_time / max(epochs_done, 1)) / 60
        # assign tf variable, so graph building doesn't get triggered
        epoch_tf.assign(epoch)

        # perform training for one epoch
        logging.info(
            f"Epoch {epoch + 1}/{n_epochs}: starting training, LR:  {optimizer.learning_rate(optimizer.iterations.numpy()).numpy():.5f},  ETA: {eta:.1f} minutes."
        )

        for images, labels in train_ds:
            train_step(model,
                       images,
                       labels,
                       optimizer,
                       loss_obj,
                       metric_loss,
                       metrics_train,
                       epoch_tf=epoch_tf,
                       b_verbose=False)
        # print model summary once - done after training on first epoch, so model is already built.
        if epoch <= 0:
            model.summary()

        # save train metrics
        loss_avg = metric_loss.result()
        metrics_res_train = metrics.result(metrics_train, as_numpy=True)

        with writer_train.as_default():
            tf.summary.scalar('loss_average', loss_avg, step=epoch)
            for k, v in metrics_res_train.items():
                tf.summary.scalar(k, v, step=epoch)

        # Reset metrics
        metric_loss.reset_states()
        metrics.reset_states(metrics_train)

        # Eval epoch
        for images, labels in eval_ds:
            eval_step(model, images, labels, metrics_eval)

        # fetch & reset metrics
        metrics_res_eval = metrics.result(metrics_eval, as_numpy=True)
        with writer_eval.as_default():
            for k, v in metrics_res_eval.items():
                tf.summary.scalar(k, v, step=epoch)

        metrics.reset_states(metrics_eval)

        # Test epoch
        for images, labels in test_ds:
            eval_step(model, images, labels, metrics_test)

        # fetch & reset metrics
        metrics_res_test = metrics.result(metrics_test, as_numpy=True)
        with writer_test.as_default():
            for k, v in metrics_res_test.items():
                tf.summary.scalar(k, v, step=epoch)

        metrics.reset_states(metrics_test)

        logging.info(
            f'Epoch {epoch + 1}/{n_epochs}: loss_average: {loss_avg}, metrics_train: {metrics_res_train}, metrics_eval: {metrics_res_eval}, metrics_test: {metrics_res_test}.'
        )

        # save a checkpoint every save_period epochs and after the last epoch
        if (epoch + 1) % save_period == 0 or epoch == n_epochs - 1:
            logging.info(
                f'Saving checkpoint to {run_paths["path_ckpts_train"]}.')
            ckpt_manager.save(checkpoint_number=epoch)

        # write config after everything has been established
        if epoch <= 0:
            gin_string = gin.operative_config_str()
            logging.info(f'Fetched config parameters: {gin_string}.')
            utils_params.save_gin(run_paths['path_gin'], gin_string)
        # Calc total run_time
        total_time += time() - start
    return metrics_res_eval
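train_step is likewise project-local. A minimal sketch consistent with the call signature above (hypothetical; the project's own implementation may differ), where epoch_tf arrives as a tf.Variable so that updating the epoch does not retrace the tf.function:

import tensorflow as tf

@tf.function
def train_step(model, images, labels, optimizer, loss_obj, metric_loss,
               metrics_train, epoch_tf, b_verbose=False):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_obj(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # track running loss and accuracy metrics (see the metrics sketch above)
    metric_loss.update_state(loss)
    metrics.update_state(metrics_train, labels, predictions)
    if b_verbose:
        tf.print('epoch:', epoch_tf, 'loss:', loss)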
Code example #3
    # Excerpt: __init__ of the project's meta-trainer class; module-level
    # imports (logging, os, tensorflow as tf, tensorflow.keras as ks, and the
    # project-local metrics module) are assumed.
    def __init__(self,
                 target_model,
                 ds_train,
                 ds_train_info,
                 ds_val,
                 run_paths,
                 inner_repetition,
                 meta_epochs,
                 meta_lr,
                 beta_byol,
                 num_inner_steps,
                 inner_lr,
                 use_lr_drop,
                 lr_drop_boundaries,
                 lr_factors,
                 use_inner_clipping,
                 use_outer_clipping,
                 clipping_norm,
                 debug=True,
                 keep_ckp=2,
                 save_period=5):
        """
        Init meta traininer
        """

        # All parameters
        self.run_paths = run_paths
        self.meta_epochs = meta_epochs
        self.meta_lr = meta_lr
        self.num_inner_steps = num_inner_steps
        self.inner_lr = inner_lr
        self.save_period = save_period
        self.inner_repetition = inner_repetition
        self.lr_drop_boundaries = lr_drop_boundaries
        self.lr_factors = lr_factors
        self.use_lr_drop = use_lr_drop
        self.use_inner_clipping = use_inner_clipping
        self.use_outer_clipping = use_outer_clipping
        self.clipping_norm = clipping_norm
        self.beta_byol = beta_byol
        self.debug = debug
        self.keep_ckp = keep_ckp

        # datasets
        self.ds_train = ds_train
        self.ds_train_info = ds_train_info
        self.ds_val = ds_val

        # get shapes, batch sizes, steps per epoch
        self.meta_batch_size = ds_train._flat_shapes[0][0]
        self.inner_batch_size = ds_train._flat_shapes[0][1]
        self.input_shape = ds_train._flat_shapes[0][2:].as_list()
        self.num_classes = ds_train._flat_shapes[3][-1]
        if self.inner_repetition and ds_train_info.name == 'cifar10':
            self.steps_per_epoch = round(50000 / self.meta_batch_size)
        elif not self.inner_repetition and ds_train_info.name == 'cifar10':
            self.steps_per_epoch = round(
                50000 / (self.meta_batch_size * self.inner_batch_size))
        else:
            # guard: self.steps_per_epoch would otherwise stay undefined
            raise NotImplementedError(
                f'Unsupported dataset: {ds_train_info.name}')

        # init target model and call it once so that all variables are created
        logging.info("Building models...")
        self.target_model = target_model(n_classes=self.num_classes)
        # self.target_model.build(input_shape=tuple([None] + self.input_shape))
        self.target_model(tf.zeros(shape=tuple([1] + self.input_shape)))
        self.target_model(tf.zeros(shape=tuple([1] + self.input_shape)),
                          use_predictor=True)

        # init one instance for each inner step (and step 0)
        self.updated_models = list()
        for _ in range(self.num_inner_steps + 1):
            updated_model = target_model(n_classes=self.num_classes)
            # updated_model.build(input_shape=tuple([None] + self.input_shape))
            updated_model(tf.zeros(shape=tuple([1] + self.input_shape)))
            updated_model(tf.zeros(shape=tuple([1] + self.input_shape)),
                          use_predictor=True)
            self.updated_models.append(updated_model)
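        # updated_models[k] now holds the concrete instance used at inner
        # step k (index 0 = before any inner update), so variables are
        # created once here rather than inside the meta-training loop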

        # define optimizer
        logging.info("Setup optimizer...")
        if self.use_lr_drop:
            boundaries = [
                k * self.steps_per_epoch for k in self.lr_drop_boundaries
            ]
            lr_values = [k * self.meta_lr for k in self.lr_factors]
            learning_rate_schedule = ks.optimizers.schedules.PiecewiseConstantDecay(
                boundaries=boundaries, values=lr_values)
            self.meta_optimizer = ks.optimizers.SGD(
                learning_rate=learning_rate_schedule, momentum=0.9)

        else:
            self.meta_optimizer = ks.optimizers.SGD(learning_rate=self.meta_lr,
                                                    momentum=0.9)

        # Checkpoint
        self.target_ckpt = tf.train.Checkpoint(model=self.target_model,
                                               optimizer=self.meta_optimizer)
        self.target_ckpt_manager = tf.train.CheckpointManager(
            self.target_ckpt,
            directory=run_paths['path_ckpts_train'],
            max_to_keep=self.keep_ckp)

        # Logging tb
        # generate summary writer
        self.writer_train = tf.summary.create_file_writer(
            os.path.dirname(run_paths['path_logs_train']))
        self.writer_eval = tf.summary.create_file_writer(
            os.path.dirname(run_paths['path_logs_eval']))
        logging.info(
            f"saving train log to {os.path.dirname(run_paths['path_logs_train'])}"
        )

        # metrics and losses
        self.ce_loss_obj = ks.losses.CategoricalCrossentropy()
        self.metric_ce_loss = tf.keras.metrics.Mean()
        self.metric_byol_loss = tf.keras.metrics.Mean()
        self.metric_loss = tf.keras.metrics.Mean()
        self.metrics_train = metrics.prep_metrics_meta()
        self.metrics_eval = metrics.prep_metrics()
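For concreteness, the shape bookkeeping above with hypothetical numbers: the meta-train dataset yields elements whose first component is stacked as [meta_batch, inner_batch, H, W, C], e.g. [5, 10, 32, 32, 3] on CIFAR-10, giving:

meta_batch_size = 5                        # ds_train._flat_shapes[0][0]
inner_batch_size = 10                      # ds_train._flat_shapes[0][1]
input_shape = [32, 32, 3]                  # ds_train._flat_shapes[0][2:]
steps_per_epoch = round(50000 / (5 * 10))  # 1000 (without inner repetition)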
Code example #4
File: test_meta.py  Project: AlexanderBartler/MT3
import logging
import os

import tensorflow as tf

import metrics  # project-local metric helpers (sketched after code example #1)


def test_meta(target_model, online_model, test_ds, test_ds_info, run_paths,
              test_lr, num_test_steps):
    # Load ckpts and ckpt manager
    # manager automatically handles model reloading if directory contains ckpts
    # First call model, otherwise not all variables will be loaded
    target_model(tf.ones(shape=tuple([1] +
                                     test_ds._flat_shapes[0][1:].as_list())),
                 use_predictor=True)
    online_model(tf.ones(shape=tuple([1] +
                                     test_ds._flat_shapes[0][1:].as_list())),
                 use_predictor=True)
    ckpt = tf.train.Checkpoint(model=target_model)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, directory=run_paths['path_ckpts_train'], max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()

    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(
            os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        # 'assert <non-empty string>' always passes; raise explicitly instead
        raise FileNotFoundError('No checkpoint for testing...')

    # Prepare Metrics
    metrics_test = [metrics.prep_metrics() for _ in range(num_test_steps + 1)]

    # Get optimizer (similar to the inner loop, i.e. no momentum etc.)
    optimizer = tf.keras.optimizers.SGD(learning_rate=test_lr, momentum=0.0)

    # def byol loss
    def byol_loss_fn(x, y):
        x = tf.math.l2_normalize(x, axis=-1)
        y = tf.math.l2_normalize(y, axis=-1)
        return 2 - 2 * tf.math.reduce_sum(x * y, axis=-1)
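    # for L2-normalized x and y, 2 - 2 * <x, y> equals the squared Euclidean
    # distance ||x - y||^2, i.e. the BYOL regression loss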

    @tf.function
    def inner_loop(images_aug_1, images_aug_2, images, labels):
        # copy weights for each image
        # online_model.set_weights(target_model.get_weights()) # slow
        for k in range(0, len(online_model.weights)):
            if not online_model.weights[k].dtype == tf.bool:
                online_model.weights[k].assign(target_model.weights[k])
        # acc without inner update
        _, _, predictions = online_model(
            images[:1, :, :, :],
            training=False)  # only one image since repetition
        metrics.update_state(metrics_test[0], labels[:1, :], predictions)

        # inner update and acc
        for k in range(num_test_steps):
            # get targets from the target model
            _, tar1, _ = target_model(images_aug_1,
                                      use_predictor=False,
                                      training=True)
            _, tar2, _ = target_model(images_aug_2,
                                      use_predictor=False,
                                      training=True)

            # Perform inner optimization
            with tf.GradientTape(persistent=False) as test_tape:
                _, prediction1, _ = online_model(images_aug_1,
                                                 use_predictor=True,
                                                 training=True)
                _, prediction2, _ = online_model(images_aug_2,
                                                 use_predictor=True,
                                                 training=True)
                # Calc byol loss
                loss1 = byol_loss_fn(prediction1, tf.stop_gradient(tar2))
                loss2 = byol_loss_fn(prediction2, tf.stop_gradient(tar1))
                loss = tf.reduce_mean(loss1 + loss2)
            gradients = test_tape.gradient(loss,
                                           online_model.trainable_variables)
            optimizer.apply_gradients(
                zip(gradients, online_model.trainable_variables))
            # get predictions for test acc
            _, _, predictions = online_model(
                images[:1, :, :, :],
                training=False)  # only one image since repetition
            metrics.update_state(metrics_test[k + 1], labels[:1, :],
                                 predictions)
        return 0

    # k only serves the commented-out early-exit debug hook below
    k = 1
    for images_aug_1, images_aug_2, images, labels in test_ds:
        inner_loop(images_aug_1, images_aug_2, images, labels)
        k += 1
        # if k == 3:
        #     break

    # fetch & reset metrics
    metrics_res_test = [
        metrics.result(metrics_, as_numpy=True) for metrics_ in metrics_test
    ]

    for metrics_ in metrics_test:
        metrics.reset_states(metrics_)

    logging.info(f'Result: metrics_test: {metrics_res_test}.')

    return metrics_res_test
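Hypothetical usage of the returned list: metrics_res_test[k] holds the metric results after k inner adaptation steps, so the gain from test-time adaptation can be read off directly (the test_lr and num_test_steps values here are made up):

results = test_meta(target_model, online_model, test_ds, test_ds_info,
                    run_paths, test_lr=0.001, num_test_steps=1)
for k, res in enumerate(results):
    logging.info(f'after {k} inner steps: {res}')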