import os

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint


class GSModelCheckpoint(tf.keras.callbacks.Callback):
    """
    Wrapper around ModelCheckpoint that adds support for GCS paths.

    For GCS paths, it has the nested ModelCheckpoint write to a local temp
    file. On epoch end, after ModelCheckpoint.on_epoch_end has been invoked,
    it checks whether the temp file changed since the last check and, if so,
    uploads it to GCS. For local paths, it simply delegates to the nested
    ModelCheckpoint callback.
    """

    def __init__(self, filepath: str, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False, mode='auto',
                 period=1):
        super().__init__()
        self.filepath = filepath
        self.nested_filepath = filepath
        self.is_cloud = False
        self.stat = None  # keep last known stat of the temp file
        # startswith avoids the ValueError that filepath.index("gs://")
        # raises for local paths
        if filepath.startswith("gs://"):
            self.nested_filepath = "/tmp/123.hdf5"
            self.is_cloud = True
            self.has_file_changed()  # ignoring return value; just to update self.stat
        self.nested_callback = ModelCheckpoint(
            self.nested_filepath, monitor=monitor, verbose=verbose,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only, mode=mode, period=period)

    def has_file_changed(self):
        try:
            newstat = os.stat(self.nested_filepath)
            if self.stat != newstat:
                self.stat = newstat
                return True
        except FileNotFoundError:
            pass
        return False

    def on_epoch_end(self, epoch, logs=None):
        self.nested_callback.on_epoch_end(epoch, logs)
        if not self.is_cloud:
            return
        if self.has_file_changed():
            # TODO: upload to gcs
            pass
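# A minimal sketch of the missing upload step, assuming TensorFlow's GCS
# filesystem support is available; this is one possible way to complete the
# TODO above, not the author's confirmed implementation. The helper name is
# hypothetical.
import tensorflow as tf


def _upload_checkpoint_to_gcs(local_path, gcs_path):
    # tf.io.gfile.copy understands gs:// destinations when GCS support is
    # installed, so the temp file can be copied straight into the bucket.
    tf.io.gfile.copy(local_path, gcs_path, overwrite=True)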
import os
import subprocess

from tensorflow.keras.callbacks import Callback, ModelCheckpoint


class SmartCheckpoint(Callback):
    """Checkpoint class that automatically handles non-existing paths and S3 synchronization."""

    def __init__(self, destination_path, file_format, **kwargs):
        super().__init__()
        self.local_dir = None
        self.destination_path = destination_path
        self.file_format = file_format
        self._create_local_folder()
        self.checkpoint_path = os.path.join(
            self.local_dir if self.local_dir is not None else self.destination_path,
            self.file_format)
        self.checkpoint_callback = ModelCheckpoint(self.checkpoint_path, **kwargs)

    def _create_local_folder(self):
        # Only use a local temp directory if the destination is S3.
        if 's3://' in self.destination_path:
            self.local_dir = 'temporary_checkpoints/'
            os.makedirs(self.local_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Can't move this to __init__ due to how self.model is assigned.
        if epoch == 0:
            self.checkpoint_callback.model = self.model
        ckpt_path_formatted = self.checkpoint_path.format(epoch=epoch + 1, **logs)
        ckpt_path_directory = os.path.dirname(ckpt_path_formatted)
        os.makedirs(ckpt_path_directory, exist_ok=True)
        self.checkpoint_callback.on_epoch_end(epoch, logs)
        if self.local_dir is not None:
            checkpoint_created = len(os.listdir(ckpt_path_directory))
            if checkpoint_created:
                # Move all of the contents of the local directory so the
                # directory structure is preserved on S3.
                files_or_dirs = os.listdir(self.local_dir)
                for file_or_dir in files_or_dirs:
                    command = 'aws s3 mv --recursive --quiet --no-progress {} {}'.format(
                        os.path.join(self.local_dir, file_or_dir),
                        os.path.join(self.destination_path, file_or_dir))
                    # Passing an argument list with shell=True only executes
                    # the first element on POSIX, so run the list directly.
                    subprocess.Popen(command.split(),
                                     stdout=subprocess.DEVNULL,
                                     stderr=subprocess.DEVNULL)
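# Hypothetical usage of SmartCheckpoint: the bucket path, file pattern, and
# the `model`/data names below are illustrative assumptions, not part of the
# original snippet.
smart_ckpt = SmartCheckpoint(
    destination_path='s3://my-bucket/checkpoints/',
    file_format='weights.{epoch:02d}-{val_loss:.2f}.hdf5',
    monitor='val_loss',
    save_best_only=True)

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=10,
          callbacks=[smart_ckpt])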
import logging
import os

from tensorflow.keras.callbacks import ModelCheckpoint as KerasModelCheckpoint
from tensorflow.python.lib.io import file_io

log = logging.getLogger(__name__)


def _gcp_on_epoch_end(self, epoch, logs=None):
    # Call the original checkpoint logic, which writes to a temporary local file.
    KerasModelCheckpoint.on_epoch_end(self, epoch, logs=logs)
    logs = logs or {}
    # Check that the file exists and is not empty.
    if not os.path.exists(self.filepath):
        log.warning("Checkpoint file does not seem to exist. Ignoring")
        return
    if os.path.getsize(self.filepath) == 0:
        log.warning("File empty, no checkpoint has been saved")
        return
    final_path = self._original_filepath.format(epoch=epoch + 1, **logs)
    with file_io.FileIO(self.filepath, mode='rb') as input_f:
        with file_io.FileIO(final_path, mode='w+b') as output_f:
            output_f.write(input_f.read())
    # Remove the local copy of the model.
    os.remove(self.filepath)
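# A sketch of how _gcp_on_epoch_end might be wired up, assuming the pattern
# the snippet implies: a ModelCheckpoint subclass that saves to a local temp
# file and keeps the real (e.g. gs://) target in _original_filepath. The
# class name and temp-file path here are illustrative, not from the source.
class GCPModelCheckpoint(KerasModelCheckpoint):
    def __init__(self, filepath, **kwargs):
        self._original_filepath = filepath
        # Checkpoint locally first; on_epoch_end then copies the file to the
        # final destination via file_io, which understands GCS paths.
        super().__init__('/tmp/gcp_checkpoint.hdf5', **kwargs)

    on_epoch_end = _gcp_on_epoch_end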
# Inside the per-epoch training loop:
    for x_batch_test, y_batch_test in get_batch(batch_size, x_test, y_test):
        test_loss, test_accuracy = model.test_on_batch(x_batch_test, y_batch_test)
        testing_acc.append(test_accuracy)
        testing_loss.append(test_loss)

    train_logs_dict = get_logs(train_logs_dict, epoch, model, x_train, y_train)
    test_logs_dict = get_logs(test_logs_dict, epoch, model, x_test, y_test)

    logs = {
        'acc': np.mean(training_acc),
        'loss': np.mean(training_loss),
        'val_loss': np.mean(testing_loss),
        'val_acc': np.mean(testing_acc)
    }

    # Drive the Keras callbacks by hand, since there is no model.fit loop.
    modelcheckpoint.on_epoch_end(epoch, logs)
    earlystop.on_epoch_end(epoch, logs)
    reduce_lr.on_epoch_end(epoch, logs)
    tensorboard.on_epoch_end(epoch, logs)

    print("accuracy: {}, loss: {}, validation accuracy: {}, validation loss: {}".format(
        np.mean(training_acc), np.mean(training_loss),
        np.mean(testing_acc), np.mean(testing_loss)))

    if model.stop_training:
        break

# After the epoch loop:
earlystop.on_train_end()
modelcheckpoint.on_train_end()
reduce_lr.on_train_end()
tensorboard.on_train_end()

# confusion matrix for training
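# The loop above drives Keras callbacks manually, which only works if each
# callback was attached to the model first. A minimal setup sketch, assuming
# the callback names used above; the paths and patience values are
# illustrative assumptions.
from tensorflow.keras.callbacks import (ModelCheckpoint, EarlyStopping,
                                        ReduceLROnPlateau, TensorBoard)

modelcheckpoint = ModelCheckpoint('checkpoints/model.hdf5', monitor='val_loss')
earlystop = EarlyStopping(monitor='val_loss', patience=5)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=2)
tensorboard = TensorBoard(log_dir='logs')

for cb in (modelcheckpoint, earlystop, reduce_lr, tensorboard):
    cb.set_model(model)   # gives each callback access to self.model
    cb.on_train_begin()   # lets stateful callbacks (e.g. EarlyStopping) reset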