import logging
import os
from time import time

import gin
import tensorflow as tf
import tensorflow.keras as ks

# Project-internal helpers (the `metrics` and `utils_params` modules as well as
# `train_step` / `eval_step`) are assumed to be imported from elsewhere in the repository.


def test_baseline(model, test_ds, test_ds_info, run_paths):
    # Load ckpts and ckpt manager.
    # The manager automatically handles model reloading if the directory contains ckpts.
    # Build the model first, otherwise not all variables will be restored.
    model.build(input_shape=tuple([None] + test_ds._flat_shapes[0][1:].as_list()))
    ckpt = tf.train.Checkpoint(model=model)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              directory=run_paths['path_ckpts_train'],
                                              max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint)

    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        raise ValueError('No checkpoint found for testing...')

    # Prepare metrics
    metrics_test = metrics.prep_metrics()

    # Testing
    for images, labels in test_ds:
        eval_step(model, images, labels, metrics_test)

    # Fetch & reset metrics
    metrics_res_test = metrics.result(metrics_test, as_numpy=True)
    metrics.reset_states(metrics_test)

    logging.info(f'Result: metrics_test: {metrics_res_test}.')

    return metrics_res_test
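
# The `metrics` helper module used throughout this file is not shown here. Below is a
# minimal sketch of what it could look like, assuming a single categorical-accuracy
# metric per metric set; only the function names mirror the calls above, the bodies
# are an assumption and not the repository's actual implementation.

def _sketch_prep_metrics():
    """Hypothetical: create a dict of Keras metrics, as `metrics.prep_metrics()` might."""
    return {'accuracy': tf.keras.metrics.CategoricalAccuracy()}


def _sketch_update_state(metric_dict, labels, predictions):
    """Hypothetical: update every metric, as `metrics.update_state()` might."""
    for metric in metric_dict.values():
        metric.update_state(labels, predictions)


def _sketch_result(metric_dict, as_numpy=False):
    """Hypothetical: fetch all results, as `metrics.result()` might."""
    return {name: (metric.result().numpy() if as_numpy else metric.result())
            for name, metric in metric_dict.items()}


def _sketch_reset_states(metric_dict):
    """Hypothetical: reset every metric, as `metrics.reset_states()` might."""
    for metric in metric_dict.values():
        metric.reset_states()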
def train_and_eval_baseline(model, train_ds, train_ds_info, eval_ds, test_ds, run_paths,
                            n_epochs=200, lr_base=0.1, lr_momentum=0.9,
                            lr_drop_boundaries=[1, 80, 120], lr_factors=[0.1, 1, 0.1, 0.01],
                            save_period=1):
    # Generate summary writers
    writer_train = tf.summary.create_file_writer(os.path.dirname(run_paths['path_logs_train']))
    writer_eval = tf.summary.create_file_writer(os.path.dirname(run_paths['path_logs_eval']))
    writer_test = tf.summary.create_file_writer(os.path.dirname(run_paths['path_logs_test']))
    logging.info(f"Saving train log to {os.path.dirname(run_paths['path_logs_train'])}")

    # Loss
    loss_obj = ks.losses.CategoricalCrossentropy()

    # Define optimizer with learning rate schedule
    steps_per_epoch = 50000 // train_ds._flat_shapes[0][0]  # 50 000 training images / batch size
    boundaries = [k * steps_per_epoch for k in lr_drop_boundaries]
    lr_values = [k * lr_base for k in lr_factors]
    learning_rate_schedule = ks.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=boundaries, values=lr_values)
    optimizer = ks.optimizers.SGD(learning_rate=learning_rate_schedule, momentum=lr_momentum)

    # Define ckpts and ckpt manager.
    # The manager automatically handles model reloading if the directory contains ckpts.
    # Build the model first, otherwise not all variables will be restored.
    model.build(input_shape=tuple([None] + train_ds._flat_shapes[0][1:].as_list()))
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              directory=run_paths['path_ckpts_train'],
                                              max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        logging.info("Initializing from scratch.")
        epoch_start = 0

    # Metrics
    metric_loss = tf.keras.metrics.Mean()
    metrics_train = metrics.prep_metrics()
    metrics_eval = metrics.prep_metrics()
    metrics_test = metrics.prep_metrics()

    logging.info(f"Training from epoch {epoch_start + 1} to {n_epochs}.")
    # Pass the epoch via a tf.Variable so no new trace is triggered: with a plain Python
    # int from range() the tf.function would be retraced every epoch.
    epoch_tf = tf.Variable(1, dtype=tf.int32)

    # Global training time in seconds
    total_time = 0.0

    # Note: using tf.range would also build a graph around the loop
    for epoch in range(epoch_start, int(n_epochs)):
        # Start epoch timer
        start = time()
        eta = (n_epochs - epoch) * (total_time / (epoch + 1e-12)) / 60

        # Assign the tf.Variable so graph building doesn't get triggered
        epoch_tf.assign(epoch)

        # Perform training for one epoch
        logging.info(
            f"Epoch {epoch + 1}/{n_epochs}: starting training, "
            f"LR: {optimizer.learning_rate(optimizer.iterations.numpy()).numpy():.5f}, "
            f"ETA: {eta:.1f} minutes.")
        for images, labels in train_ds:
            train_step(model, images, labels, optimizer, loss_obj, metric_loss,
                       metrics_train, epoch_tf=epoch_tf, b_verbose=False)

        # Print the model summary once - after training on the first epoch, so the model is already built
        if epoch <= 0:
            model.summary()

        # Save train metrics
        loss_avg = metric_loss.result()
        metrics_res_train = metrics.result(metrics_train, as_numpy=True)
        with writer_train.as_default():
            tf.summary.scalar('loss_average', loss_avg, step=epoch)
            for k, v in metrics_res_train.items():
                tf.summary.scalar(k, v, step=epoch)

        # Reset metrics
        metric_loss.reset_states()
        metrics.reset_states(metrics_train)

        # Eval epoch
        for images, labels in eval_ds:
            eval_step(model, images, labels, metrics_eval)

        # Fetch & reset metrics
        metrics_res_eval = metrics.result(metrics_eval, as_numpy=True)
        with writer_eval.as_default():
            for k, v in metrics_res_eval.items():
                tf.summary.scalar(k, v, step=epoch)
        metrics.reset_states(metrics_eval)

        # Test epoch
        for images, labels in test_ds:
            eval_step(model, images, labels, metrics_test)

        # Fetch & reset metrics
        metrics_res_test = metrics.result(metrics_test, as_numpy=True)
        with writer_test.as_default():
            for k, v in metrics_res_test.items():
                tf.summary.scalar(k, v, step=epoch)
        metrics.reset_states(metrics_test)

        logging.info(
            f'Epoch {epoch + 1}/{n_epochs}: loss_average: {loss_avg}, '
            f'metrics_train: {metrics_res_train}, metrics_eval: {metrics_res_eval}, '
            f'metrics_test: {metrics_res_test}.')

        # Save checkpoints every save_period epochs and after the last epoch
        if ((epoch + 1) % save_period == 0) or (epoch == n_epochs - 1):
            logging.info(f'Saving checkpoint to {run_paths["path_ckpts_train"]}.')
            ckpt_manager.save(checkpoint_number=epoch)

        # Write the gin config once, after everything has been established
        if epoch <= 0:
            gin_string = gin.operative_config_str()
            logging.info(f'Fetched config parameters: {gin_string}.')
            utils_params.save_gin(run_paths['path_gin'], gin_string)

        # Accumulate total run time
        total_time += time() - start

    return metrics_res_eval
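
# `train_step` and `eval_step` are defined elsewhere in the repository and are not shown
# here. Below is a minimal sketch of what they might look like, assuming a plain supervised
# step over the baseline model; the signatures mirror the calls above, the bodies are an
# assumption rather than the repository's actual implementation.

@tf.function
def _sketch_train_step(model, images, labels, optimizer, loss_obj, metric_loss,
                       metrics_train, epoch_tf, b_verbose=False):
    """Hypothetical single training step: forward pass, loss, backprop, metric update."""
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_obj(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    metric_loss.update_state(loss)
    metrics.update_state(metrics_train, labels, predictions)
    if b_verbose:
        tf.print("Epoch:", epoch_tf, "loss:", loss)


@tf.function
def _sketch_eval_step(model, images, labels, metrics_eval):
    """Hypothetical single evaluation step: inference-mode forward pass, metric update."""
    predictions = model(images, training=False)
    metrics.update_state(metrics_eval, labels, predictions)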
def train(self):
    """
    Training method

    :return: evaluation metrics of the last epoch
    """
    # Restore the latest checkpoint if one exists
    self.target_ckpt.restore(self.target_ckpt_manager.latest_checkpoint)
    if self.target_ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {self.target_ckpt_manager.latest_checkpoint}.")
        epoch_start = int(os.path.basename(self.target_ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        logging.info("Initializing from scratch.")
        epoch_start = 0

    # Pass the epoch via a tf.Variable so no new trace is triggered
    epoch_tf = tf.Variable(1, dtype=tf.int32)

    # Global time counter for ETA estimation
    total_time = 0.0

    for epoch in range(epoch_start, int(self.meta_epochs)):
        # Start epoch timer
        start = time()
        eta = (self.meta_epochs - epoch) * (total_time / (epoch + 1e-12)) / 60

        # Assign the tf.Variable so graph building doesn't get triggered
        epoch_tf.assign(epoch)

        # Log start of epoch and ETA
        if self.use_lr_drop:
            logging.info(
                f"Epoch {epoch + 1}/{self.meta_epochs}: starting training, "
                f"LR: {self.meta_optimizer.learning_rate(self.meta_optimizer.iterations.numpy()).numpy():.5f}, "
                f"ETA: {eta:.1f} minutes.")
        else:
            logging.info(
                f"Epoch {epoch + 1}/{self.meta_epochs}: starting training, ETA: {eta:.1f} minutes.")

        # Iterate over the meta batches
        step_cnt = 1
        for images_aug_1, images_aug_2, images, labels in self.ds_train:
            # Perform one (meta) training step
            start_step = time()
            self.meta_train_step(images_aug_1, images_aug_2, images, labels)
            if self.debug:
                logging.info(
                    f"Step {step_cnt} finished in: {time() - start_step}s, "
                    f"ce_loss: {self.metric_ce_loss.result()}, "
                    f"byol_loss: {self.metric_byol_loss.result()}")
            step_cnt += 1

        # Evaluate the (target) model
        for images, labels in self.ds_val:
            self.eval_step(images, labels)

        # Save checkpoints every save_period epochs and after the last epoch
        if (epoch % self.save_period == 0) or (epoch + 1 == self.meta_epochs):
            logging.info(f'Saving checkpoint to {self.run_paths["path_ckpts_train"]}.')
            self.target_ckpt_manager.save(checkpoint_number=epoch)

        # Fetch metrics and losses
        ce_loss = self.metric_ce_loss.result()
        byol_loss = self.metric_byol_loss.result()
        loss = self.metric_loss.result()
        metrics_res_train = metrics.result_meta(self.metrics_train, as_numpy=True)
        metrics_res_val = metrics.result(self.metrics_eval, as_numpy=True)

        # Log metrics
        logging.info(
            f'Epoch {epoch + 1}/{self.meta_epochs}: loss: {loss}, ce_loss: {ce_loss}, '
            f'byol_loss: {byol_loss}, metrics_train: {metrics_res_train}, '
            f'metrics_eval: {metrics_res_val}')

        # Write results to TensorBoard
        with self.writer_train.as_default():
            tf.summary.scalar('loss', loss, step=epoch)
            tf.summary.scalar('ce_loss', ce_loss, step=epoch)
            tf.summary.scalar('byol_loss', byol_loss, step=epoch)
            for k, v in metrics_res_train.items():
                tf.summary.scalar(k, v, step=epoch)
        with self.writer_eval.as_default():
            for k, v in metrics_res_val.items():
                tf.summary.scalar(k, v, step=epoch)

        # Reset metrics
        self.metric_ce_loss.reset_states()
        self.metric_byol_loss.reset_states()
        self.metric_loss.reset_states()
        metrics.reset_states(self.metrics_train)
        metrics.reset_states(self.metrics_eval)

        # Save the gin config and summarize the model once
        if epoch <= 0:
            gin_string = gin.operative_config_str()
            logging.info(f'Fetched config parameters: {gin_string}.')
            utils_params.save_gin(self.run_paths['path_gin'], gin_string)
            self.target_model.summary()

        # Accumulate epoch time
        total_time += time() - start

    return metrics_res_val
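
# `meta_train_step` and `eval_step` are methods defined elsewhere in this class and are
# not shown here. BYOL-style training typically also keeps one of the two networks as an
# exponential moving average (EMA) of the other; the helper below is a minimal,
# hypothetical sketch of such an update (standalone, not necessarily how this repository
# implements it).

def _sketch_ema_update(target_model, online_model, tau=0.996):
    """Hypothetical Polyak averaging of online weights into the target network."""
    for target_w, online_w in zip(target_model.weights, online_model.weights):
        if target_w.dtype != tf.bool:  # skip non-float bookkeeping variables
            target_w.assign(tau * target_w + (1.0 - tau) * online_w)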
def test_meta(target_model, online_model, test_ds, test_ds_info, run_paths, test_lr, num_test_steps):
    # Load ckpts and ckpt manager.
    # The manager automatically handles model reloading if the directory contains ckpts.
    # Call the models once first, otherwise not all variables will be restored.
    target_model(tf.ones(shape=tuple([1] + test_ds._flat_shapes[0][1:].as_list())), use_predictor=True)
    online_model(tf.ones(shape=tuple([1] + test_ds._flat_shapes[0][1:].as_list())), use_predictor=True)
    ckpt = tf.train.Checkpoint(model=target_model)
    ckpt_manager = tf.train.CheckpointManager(ckpt,
                                              directory=run_paths['path_ckpts_train'],
                                              max_to_keep=2)
    ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
    if ckpt_manager.latest_checkpoint:
        logging.info(f"Restored from {ckpt_manager.latest_checkpoint}.")
        epoch_start = int(os.path.basename(ckpt_manager.latest_checkpoint).split('-')[1])
    else:
        raise ValueError('No checkpoint found for testing...')

    # Prepare metrics: one set per inner step, plus one for the un-adapted model
    metrics_test = [metrics.prep_metrics() for _ in range(num_test_steps + 1)]

    # Get optimizer (similar to the inner loop, so no momentum etc.)
    optimizer = tf.keras.optimizers.SGD(learning_rate=test_lr, momentum=0.0)

    # Define the BYOL loss
    def byol_loss_fn(x, y):
        x = tf.math.l2_normalize(x, axis=-1)
        y = tf.math.l2_normalize(y, axis=-1)
        return 2 - 2 * tf.math.reduce_sum(x * y, axis=-1)

    @tf.function
    def inner_loop(images_aug_1, images_aug_2, images, labels):
        # Copy the target weights into the online model for each test sample
        # online_model.set_weights(target_model.get_weights())  # slow
        for k in range(len(online_model.weights)):
            if online_model.weights[k].dtype != tf.bool:
                online_model.weights[k].assign(target_model.weights[k])

        # Accuracy without any inner update
        _, _, predictions = online_model(images[:1, :, :, :], training=False)  # only one image since repetition
        metrics.update_state(metrics_test[0], labels[:1, :], predictions)

        # Inner updates and accuracy after each step
        for k in range(num_test_steps):
            # Get targets
            _, tar1, _ = target_model(images_aug_1, use_predictor=False, training=True)
            _, tar2, _ = target_model(images_aug_2, use_predictor=False, training=True)

            # Perform inner optimization
            with tf.GradientTape(persistent=False) as test_tape:
                _, prediction1, _ = online_model(images_aug_1, use_predictor=True, training=True)
                _, prediction2, _ = online_model(images_aug_2, use_predictor=True, training=True)

                # Calc symmetric BYOL loss
                loss1 = byol_loss_fn(prediction1, tf.stop_gradient(tar2))
                loss2 = byol_loss_fn(prediction2, tf.stop_gradient(tar1))
                loss = tf.reduce_mean(loss1 + loss2)
            gradients = test_tape.gradient(loss, online_model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, online_model.trainable_variables))

            # Get predictions for the test accuracy after k + 1 inner steps
            _, _, predictions = online_model(images[:1, :, :, :], training=False)  # only one image since repetition
            metrics.update_state(metrics_test[k + 1], labels[:1, :], predictions)

        return 0

    for images_aug_1, images_aug_2, images, labels in test_ds:
        inner_loop(images_aug_1, images_aug_2, images, labels)

    # Fetch & reset metrics
    metrics_res_test = [metrics.result(metrics_, as_numpy=True) for metrics_ in metrics_test]
    for metrics_ in metrics_test:
        metrics.reset_states(metrics_)

    logging.info(f'Result: metrics_test: {metrics_res_test}.')

    return metrics_res_test
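
# Quick sanity check of the BYOL loss used above (a standalone copy, since `byol_loss_fn`
# is defined inside `test_meta`): for L2-normalized inputs the loss equals
# 2 - 2 * cos(x, y), i.e. 0 for aligned vectors and 4 for opposite ones.

def _sketch_byol_loss_check():
    def byol_loss_fn(x, y):
        x = tf.math.l2_normalize(x, axis=-1)
        y = tf.math.l2_normalize(y, axis=-1)
        return 2 - 2 * tf.math.reduce_sum(x * y, axis=-1)

    x = tf.constant([[3.0, 4.0]])
    tf.debugging.assert_near(byol_loss_fn(x, x), tf.constant([0.0]))   # aligned -> 0
    tf.debugging.assert_near(byol_loss_fn(x, -x), tf.constant([4.0]))  # opposite -> 4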