def test_baseline(model, test_ds, test_ds_info, run_paths):
    # Load ckpts and ckpt manager
    # manager automatically handles model reloading if directory contains ckpts
    # First build model, otherwise not all variables will be loaded
    model.build(input_shape=tuple([None] + test_ds._flat_shapes[0][1:].as_list()))
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              directory=run_paths['path_ckpts_train'],
                                              max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        # A bare assert on a non-empty string always passes, so raise instead.
        raise FileNotFoundError('No checkpoint for testing...')

    # Prepare Metrics
    metrics_test = metrics.prep_metrics()

    # Testing
    for images, labels in test_ds:
        eval_step(model, images, labels, metrics_test)

    # fetch & reset metrics
    metrics_res_test = metrics.result(metrics_test, as_numpy=True)
    metrics.reset_states(metrics_test)

    logging.info(f'Result: metrics_test: {metrics_res_test}.')

    return metrics_res_test
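# `eval_step` is called in the test loop above but lives elsewhere in this repo; the
# following is only a minimal sketch of its assumed behaviour (inference-mode forward
# pass, then metric update via the project's `metrics` helpers), not the actual
# implementation. It assumes the baseline model returns class probabilities directly.
@tf.function
def eval_step(model, images, labels, metrics_eval):
    # Forward pass without dropout/batch-norm updates
    predictions = model(images, training=False)
    # Accumulate the classification metrics for this batch
    metrics.update_state(metrics_eval, labels, predictions)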
def train_and_eval_baseline(model, train_ds, train_ds_info, eval_ds, test_ds, run_paths,
                            n_epochs=200, lr_base=0.1, lr_momentum=0.9,
                            lr_drop_boundaries=[1, 80, 120],
                            lr_factors=[0.1, 1, 0.1, 0.01], save_period=1):
    # generate summary writers
    writer_train = tf.summary.create_file_writer(
        os.path.dirname(run_paths['path_logs_train']))
    writer_eval = tf.summary.create_file_writer(
        os.path.dirname(run_paths['path_logs_eval']))
    writer_test = tf.summary.create_file_writer(
        os.path.dirname(run_paths['path_logs_test']))
    logging.info(
        f"saving train log to {os.path.dirname(run_paths['path_logs_train'])}")

    # loss
    loss_obj = ks.losses.CategoricalCrossentropy()

    # define optimizer with learning rate schedule
    steps_per_epoch = 50000 // train_ds._flat_shapes[0][0]
    boundaries = [k * steps_per_epoch for k in lr_drop_boundaries]
    lr_values = [k * lr_base for k in lr_factors]
    learning_rate_schedule = ks.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=boundaries, values=lr_values)
    optimizer = ks.optimizers.SGD(learning_rate=learning_rate_schedule,
                                  momentum=lr_momentum)

    # define ckpts and ckpt manager
    # manager automatically handles model reloading if directory contains ckpts
    # First build model, otherwise not all variables will be loaded
    model.build(input_shape=tuple([None] + train_ds._flat_shapes[0][1:].as_list()))
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, directory=run_paths['path_ckpts_train'], max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(
            os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        logging.info("Initializing from scratch.")
        epoch_start = 0

    # metrics
    metric_loss = tf.keras.metrics.Mean()
    metrics_train = metrics.prep_metrics()
    metrics_eval = metrics.prep_metrics()
    metrics_test = metrics.prep_metrics()

    logging.info(f"Training from epoch {epoch_start + 1} to {n_epochs}.")
    # Pass the epoch via a tf.Variable so no new tf.function trace is triggered:
    # with a plain Python range (instead of tf.range), assigning to epoch_tf avoids
    # retracing the train step every epoch.
    epoch_tf = tf.Variable(1, dtype=tf.int32)

    # Global training time in seconds
    total_time = 0.0

    # Note: using tf.range also seems to create a graph
    for epoch in range(epoch_start, int(n_epochs)):
        # Start epoch timer
        start = time()
        # rough ETA based on the average epoch time measured so far (0 for the first epoch)
        eta = (n_epochs - epoch) * (total_time / (epoch + 1e-12)) / 60
        # assign tf variable, so graph building doesn't get triggered
        epoch_tf.assign(epoch)

        # perform training for one epoch
        logging.info(
            f"Epoch {epoch + 1}/{n_epochs}: starting training, "
            f"LR: {optimizer.learning_rate(optimizer.iterations.numpy()).numpy():.5f}, "
            f"ETA: {eta:.1f} minutes.")
        for images, labels in train_ds:
            train_step(model, images, labels, optimizer, loss_obj, metric_loss,
                       metrics_train, epoch_tf=epoch_tf, b_verbose=False)

        # print model summary once - done after training on the first epoch,
        # so the model is already built
        if epoch <= 0:
            model.summary()

        # save train metrics
        loss_avg = metric_loss.result()
        metrics_res_train = metrics.result(metrics_train, as_numpy=True)
        with writer_train.as_default():
            tf.summary.scalar('loss_average', loss_avg, step=epoch)
            for k, v in metrics_res_train.items():
                tf.summary.scalar(k, v, step=epoch)

        # Reset metrics
        metric_loss.reset_states()
        metrics.reset_states(metrics_train)

        # Eval epoch
        for images, labels in eval_ds:
            eval_step(model, images, labels, metrics_eval)

        # fetch & reset metrics
        metrics_res_eval = metrics.result(metrics_eval, as_numpy=True)
        with writer_eval.as_default():
            for k, v in metrics_res_eval.items():
                tf.summary.scalar(k, v, step=epoch)
        metrics.reset_states(metrics_eval)

        # Test epoch
        for images, labels in test_ds:
            eval_step(model, images, labels, metrics_test)

        # fetch & reset metrics
        metrics_res_test = metrics.result(metrics_test, as_numpy=True)
        with writer_test.as_default():
            for k, v in metrics_res_test.items():
                tf.summary.scalar(k, v, step=epoch)
        metrics.reset_states(metrics_test)

        logging.info(
            f'Epoch {epoch + 1}/{n_epochs}: loss_average: {loss_avg}, '
            f'metrics_train: {metrics_res_train}, metrics_eval: {metrics_res_eval}, '
            f'metrics_test: {metrics_res_test}.')

        # save checkpoints every save_period epochs and after the last epoch
        if ((epoch + 1) % save_period == 0) or (epoch == n_epochs - 1):
            logging.info(f'Saving checkpoint to {run_paths["path_ckpts_train"]}.')
            ckpt_manager.save(checkpoint_number=epoch)

        # write config after everything has been established
        if epoch <= 0:
            gin_string = gin.operative_config_str()
            logging.info(f'Fetched config parameters: {gin_string}.')
            utils_params.save_gin(run_paths['path_gin'], gin_string)

        # Calc total run_time
        total_time += time() - start

    return metrics_res_eval
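# `train_step` (called in the training loop above) is also defined elsewhere in the repo;
# this is a minimal sketch of the assumed implementation, matching the call signature used
# above. The gradient/metric logic and the tf.print guarded by `b_verbose` are assumptions.
@tf.function
def train_step(model, images, labels, optimizer, loss_obj, metric_loss, metrics_train,
               epoch_tf, b_verbose=False):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_obj(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Track the running loss and the classification metrics
    metric_loss.update_state(loss)
    metrics.update_state(metrics_train, labels, predictions)
    if b_verbose:
        tf.print('Epoch:', epoch_tf, 'batch loss:', loss)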
def __init__(self, target_model, ds_train, ds_train_info, ds_val, run_paths,
             inner_repetition, meta_epochs, meta_lr, beta_byol, num_inner_steps,
             inner_lr, use_lr_drop, lr_drop_boundaries, lr_factors,
             use_inner_clipping, use_outer_clipping, clipping_norm, debug=True,
             keep_ckp=2, save_period=5):
    """Init meta trainer."""
    # All parameters
    self.run_paths = run_paths
    self.meta_epochs = meta_epochs
    self.meta_lr = meta_lr
    self.num_inner_steps = num_inner_steps
    self.inner_lr = inner_lr
    self.save_period = save_period
    self.inner_repetition = inner_repetition
    self.lr_drop_boundaries = lr_drop_boundaries
    self.lr_factors = lr_factors
    self.use_lr_drop = use_lr_drop
    self.use_inner_clipping = use_inner_clipping
    self.use_outer_clipping = use_outer_clipping
    self.clipping_norm = clipping_norm
    self.beta_byol = beta_byol
    self.debug = debug
    self.keep_ckp = keep_ckp

    # datasets
    self.ds_train = ds_train
    self.ds_train_info = ds_train_info
    self.ds_val = ds_val

    # get shapes, batch sizes, steps per epoch
    self.meta_batch_size = ds_train._flat_shapes[0][0]
    self.inner_batch_size = ds_train._flat_shapes[0][1]
    self.input_shape = ds_train._flat_shapes[0][2:].as_list()
    self.num_classes = ds_train._flat_shapes[3][-1]
    if self.inner_repetition and ds_train_info.name == 'cifar10':
        self.steps_per_epoch = round(50000 / self.meta_batch_size)
    elif not self.inner_repetition and ds_train_info.name == 'cifar10':
        self.steps_per_epoch = round(
            50000 / (self.meta_batch_size * self.inner_batch_size))

    # init target model and call it once for correct initialization
    logging.info("Building models...")
    self.target_model = target_model(n_classes=self.num_classes)
    # self.target_model.build(input_shape=tuple([None] + self.input_shape))
    self.target_model(tf.zeros(shape=tuple([1] + self.input_shape)))
    self.target_model(tf.zeros(shape=tuple([1] + self.input_shape)),
                      use_predictor=True)

    # init one instance for each inner step (and step 0)
    self.updated_models = list()
    for _ in range(self.num_inner_steps + 1):
        updated_model = target_model(n_classes=self.num_classes)
        # updated_model.build(input_shape=tuple([None] + self.input_shape))
        updated_model(tf.zeros(shape=tuple([1] + self.input_shape)))
        updated_model(tf.zeros(shape=tuple([1] + self.input_shape)),
                      use_predictor=True)
        self.updated_models.append(updated_model)

    # define optimizer
    logging.info("Setup optimizer...")
    if self.use_lr_drop:
        boundaries = [k * self.steps_per_epoch for k in self.lr_drop_boundaries]
        lr_values = [k * self.meta_lr for k in self.lr_factors]
        learning_rate_schedule = ks.optimizers.schedules.PiecewiseConstantDecay(
            boundaries=boundaries, values=lr_values)
        self.meta_optimizer = ks.optimizers.SGD(
            learning_rate=learning_rate_schedule, momentum=0.9)
    else:
        self.meta_optimizer = ks.optimizers.SGD(learning_rate=self.meta_lr,
                                                momentum=0.9)

    # Checkpoint
    self.target_ckpt = tf.train.Checkpoint(model=self.target_model,
                                           optimizer=self.meta_optimizer)
    self.target_ckpt_manager = tf.train.CheckpointManager(
        self.target_ckpt,
        directory=run_paths['path_ckpts_train'],
        max_to_keep=self.keep_ckp)

    # TensorBoard logging: generate summary writers
    self.writer_train = tf.summary.create_file_writer(
        os.path.dirname(run_paths['path_logs_train']))
    self.writer_eval = tf.summary.create_file_writer(
        os.path.dirname(run_paths['path_logs_eval']))
    logging.info(
        f"saving train log to {os.path.dirname(run_paths['path_logs_train'])}")

    # metrics and losses
    self.ce_loss_obj = ks.losses.CategoricalCrossentropy()
    self.metric_ce_loss = tf.keras.metrics.Mean()
    self.metric_byol_loss = tf.keras.metrics.Mean()
    self.metric_loss = tf.keras.metrics.Mean()
    self.metrics_train = metrics.prep_metrics_meta()
    self.metrics_eval = metrics.prep_metrics()
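# Standalone sketch of how the PiecewiseConstantDecay schedule configured above maps the
# optimizer step count to a learning rate. The boundaries/values here are illustrative
# numbers, not the configured ones: with boundaries [b1, b2] and values [v0, v1, v2],
# steps <= b1 use v0, steps in (b1, b2] use v1, and later steps use v2.
import tensorflow as tf
import tensorflow.keras as ks

example_schedule = ks.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[100, 200], values=[0.1, 0.01, 0.001])
print(example_schedule(50).numpy())   # 0.1   -> before the first boundary
print(example_schedule(150).numpy())  # 0.01  -> between the boundaries
print(example_schedule(250).numpy())  # 0.001 -> after the last boundary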
def test_meta(target_model, online_model, test_ds, test_ds_info, run_paths, test_lr,
              num_test_steps):
    # Load ckpts and ckpt manager
    # manager automatically handles model reloading if directory contains ckpts
    # First call model, otherwise not all variables will be loaded
    target_model(tf.ones(shape=tuple([1] + test_ds._flat_shapes[0][1:].as_list())),
                 use_predictor=True)
    online_model(tf.ones(shape=tuple([1] + test_ds._flat_shapes[0][1:].as_list())),
                 use_predictor=True)
    ckpt = tf.train.Checkpoint(model=target_model)
    ckpt_manager = tf.train.CheckpointManager(
        ckpt, directory=run_paths['path_ckpts_train'], max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(
            os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        # A bare assert on a non-empty string always passes, so raise instead.
        raise FileNotFoundError('No checkpoint for testing...')

    # Prepare Metrics (one set per inner step, plus one for step 0)
    metrics_test = [metrics.prep_metrics() for _ in range(num_test_steps + 1)]

    # Get optimizer (similar to the inner loop, so no momentum and so on)
    optimizer = tf.keras.optimizers.SGD(learning_rate=test_lr, momentum=0.0)

    # def byol loss
    def byol_loss_fn(x, y):
        x = tf.math.l2_normalize(x, axis=-1)
        y = tf.math.l2_normalize(y, axis=-1)
        return 2 - 2 * tf.math.reduce_sum(x * y, axis=-1)

    @tf.function
    def inner_loop(images_aug_1, images_aug_2, images, labels):
        # copy weights for each image
        # online_model.set_weights(target_model.get_weights())  # slow
        for k in range(0, len(online_model.weights)):
            if not online_model.weights[k].dtype == tf.bool:
                online_model.weights[k].assign(target_model.weights[k])

        # acc without inner update
        _, _, predictions = online_model(
            images[:1, :, :, :], training=False)  # only one image since repetition
        metrics.update_state(metrics_test[0], labels[:1, :], predictions)

        # inner update and acc
        for k in range(num_test_steps):
            # Get targets from the target network
            _, tar1, _ = target_model(images_aug_1, use_predictor=False, training=True)
            _, tar2, _ = target_model(images_aug_2, use_predictor=False, training=True)

            # Perform inner optimization
            with tf.GradientTape(persistent=False) as test_tape:
                _, prediction1, _ = online_model(images_aug_1, use_predictor=True,
                                                 training=True)
                _, prediction2, _ = online_model(images_aug_2, use_predictor=True,
                                                 training=True)

                # Calc symmetrized byol loss with stopped gradients on the targets
                loss1 = byol_loss_fn(prediction1, tf.stop_gradient(tar2))
                loss2 = byol_loss_fn(prediction2, tf.stop_gradient(tar1))
                loss = tf.reduce_mean(loss1 + loss2)

            gradients = test_tape.gradient(loss, online_model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, online_model.trainable_variables))

            # get predictions for test acc
            _, _, predictions = online_model(
                images[:1, :, :, :], training=False)  # only one image since repetition
            metrics.update_state(metrics_test[k + 1], labels[:1, :], predictions)

        return 0

    k = 1  # batch counter, only used by the commented-out early exit below (debugging)
    for images_aug_1, images_aug_2, images, labels in test_ds:
        inner_loop(images_aug_1, images_aug_2, images, labels)
        k += 1
        # if k == 3:
        #     break

    # fetch & reset metrics
    metrics_res_test = [
        metrics.result(metrics_, as_numpy=True) for metrics_ in metrics_test
    ]
    for metrics_ in metrics_test:
        metrics.reset_states(metrics_)

    logging.info(f'Result: metrics_test: {metrics_res_test}.')

    return metrics_res_test
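# Quick sanity check of the BYOL loss used in `test_meta` (standalone sketch; the helper
# is duplicated here because it is nested inside that function). The loss equals
# 2 - 2 * cosine_similarity, so identical directions give 0 and orthogonal ones give 2.
import tensorflow as tf

def _byol_loss_fn(x, y):
    x = tf.math.l2_normalize(x, axis=-1)
    y = tf.math.l2_normalize(y, axis=-1)
    return 2 - 2 * tf.math.reduce_sum(x * y, axis=-1)

print(_byol_loss_fn(tf.constant([[1., 0.]]), tf.constant([[2., 0.]])).numpy())  # [0.]
print(_byol_loss_fn(tf.constant([[1., 0.]]), tf.constant([[0., 1.]])).numpy())  # [2.]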