import os

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint


class GSModelCheckpoint(tf.keras.callbacks.Callback):
    """
    Wrapper around ModelCheckpoint that adds support for GCS paths.

    For GCS paths, it has the nested ModelCheckpoint write to a local temp file.
    On epoch end, after ModelCheckpoint.on_epoch_end has been invoked, it checks
    whether the temp file has changed since the last check and, if so, uploads it to GCS.

    For local paths, it simply delegates to the nested ModelCheckpoint callback.
    """

    def __init__(self, filepath: str, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):

        super().__init__()
        self.filepath = filepath
        self.nested_filepath = filepath
        self.is_cloud = False
        self.stat = None  # keep last known stat of the temp file

        if filepath.startswith("gs://"):
            self.nested_filepath = "/tmp/123.hdf5"
            self.is_cloud = True
            self.has_file_changed()  # ignoring return value; just to update self.stat

        self.nested_callback = ModelCheckpoint(self.nested_filepath,
                                               monitor=monitor,
                                               verbose=verbose,
                                               save_best_only=save_best_only,
                                               save_weights_only=save_weights_only,
                                               mode=mode,
                                               period=period)

    def has_file_changed(self):
        try:
            newstat = os.stat(self.nested_filepath)
            if self.stat != newstat:
                self.stat = newstat
                return True
        except FileNotFoundError:
            pass
        return False

    def on_epoch_end(self, epoch, logs=None):
        self.nested_callback.on_epoch_end(epoch, logs)
        if not self.is_cloud:
            return

        if self.has_file_changed():
            # TODO: upload to gcs
            pass
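
The upload step above is left as a TODO. A minimal sketch of one way it could be filled in, assuming tf.io.gfile can reach the bucket; the helper name upload_checkpoint_to_gcs is hypothetical and not part of the original code:

import tensorflow as tf

def upload_checkpoint_to_gcs(local_path: str, gcs_path: str) -> None:
    """Copy a locally written checkpoint file to a gs:// destination."""
    # tf.io.gfile accepts gs:// paths when GCS filesystem support is available.
    tf.io.gfile.copy(local_path, gcs_path, overwrite=True)

With such a helper in scope, the pass under the TODO could be replaced by upload_checkpoint_to_gcs(self.nested_filepath, self.filepath).
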
import os
import subprocess

from tensorflow.keras.callbacks import Callback, ModelCheckpoint


class SmartCheckpoint(Callback):
    """Checkpoint class that automatically handles non-existent paths and S3 synchronization."""

    def __init__(self, destination_path, file_format, **kwargs):
        super().__init__()
        self.local_dir = None
        self.destination_path = destination_path
        self.file_format = file_format
        self.__create_local_folder__()
        self.checkpoint_path = os.path.join(
            self.local_dir if self.local_dir is not None else self.destination_path,
            self.file_format)
        self.checkpoint_callback = ModelCheckpoint(self.checkpoint_path, **kwargs)

    def __create_local_folder__(self):
        # only use a local temp directory if the destination is S3
        if 's3://' in self.destination_path:
            self.local_dir = 'temporary_checkpoints/'
            os.makedirs(self.local_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # can't set this in __init__ because self.model is only assigned once training starts
        if epoch == 0:
            self.checkpoint_callback.set_model(self.model)

        ckpt_path_formatted = self.checkpoint_path.format(epoch=epoch + 1, **logs)
        ckpt_path_directory = os.path.dirname(ckpt_path_formatted)
        os.makedirs(ckpt_path_directory, exist_ok=True)
        self.checkpoint_callback.on_epoch_end(epoch, logs)
        if self.local_dir is not None:
            checkpoint_created = len(os.listdir(ckpt_path_directory))
            if checkpoint_created:
                # move the contents of the local directory, preserving its structure
                files_or_dirs = os.listdir(self.local_dir)
                for file_or_dir in files_or_dirs:
                    command = 'aws s3 mv --recursive --quiet --no-progress {} {}'.format(
                        os.path.join(self.local_dir, file_or_dir),
                        os.path.join(self.destination_path, file_or_dir))
                    # run the upload in the background, discarding the aws cli output
                    subprocess.Popen(command.split(),
                                     stdout=subprocess.DEVNULL,
                                     stderr=subprocess.DEVNULL)
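
A hypothetical usage sketch for SmartCheckpoint, assuming a compiled Keras model and an S3 prefix; the bucket name, file format, and variable names below are illustrative only:

smart_checkpoint = SmartCheckpoint(
    destination_path='s3://my-bucket/experiments/run-01',  # hypothetical S3 prefix
    file_format='epoch_{epoch:02d}/model.h5',              # formatted like a ModelCheckpoint path
    monitor='val_loss',
    save_best_only=True)

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=10,
          callbacks=[smart_checkpoint])

Because the class shells out to the aws CLI, this also assumes the CLI is installed and has credentials for the target bucket.
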
Example #3
    def _gcp_on_epoch_end(self, epoch, logs=None):
        # Call original checkpoint to temporary file
        KerasModelCheckpoint.on_epoch_end(self, epoch, logs=logs)

        logs = logs or {}

        # Check if file exists and not empty
        if not os.path.exists(self.filepath):
            log.warning("Checkpoint file does not seem to exists. Ignoring")
            return

        if os.path.getsize(self.filepath) == 0:
            log.warning("File empty, no checkpoint has been saved")
            return

        final_path = self._original_filepath.format(epoch=epoch + 1, **logs)

        with file_io.FileIO(self.filepath, mode='rb') as input_f:
            with file_io.FileIO(final_path, mode='w+b') as output_f:
                output_f.write(input_f.read())

        # Remove local model
        os.remove(self.filepath)
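
The method above is only a fragment. A rough sketch of the kind of subclass it could belong to, assuming the original gs:// path is kept in _original_filepath and writes are redirected to a local temp file; everything here except _gcp_on_epoch_end itself is hypothetical:

import os
import tempfile
from tensorflow.keras.callbacks import ModelCheckpoint as KerasModelCheckpoint

class GCSAwareModelCheckpoint(KerasModelCheckpoint):  # hypothetical wrapper class
    def __init__(self, filepath, **kwargs):
        self._original_filepath = filepath
        if filepath.startswith('gs://'):
            # Write to a local temp file first; _gcp_on_epoch_end (shown above,
            # assumed to be a method of this class) copies it to gs:// afterwards.
            filepath = os.path.join(tempfile.gettempdir(), 'checkpoint.h5')
            self.on_epoch_end = self._gcp_on_epoch_end
        super().__init__(filepath, **kwargs)
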
Example #4
    for x_batch_test, y_batch_test in get_batch(batch_size, x_test, y_test):

        test_loss, test_accuracy = model.test_on_batch(x_batch_test,
                                                       y_batch_test)
        testing_acc.append(test_accuracy)
        testing_loss.append(test_loss)
    train_logs_dict = get_logs(train_logs_dict, epoch, model, x_train, y_train)
    test_logs_dict = get_logs(test_logs_dict, epoch, model, x_test, y_test)
    logs = {
        'acc': np.mean(training_acc),
        'loss': np.mean(training_loss),
        'val_loss': np.mean(testing_loss),
        'val_acc': np.mean(testing_acc)
    }
    modelcheckpoint.on_epoch_end(epoch, logs)
    earlystop.on_epoch_end(epoch, logs)
    reduce_lr.on_epoch_end(epoch, logs)
    tensorboard.on_epoch_end(epoch, logs)
    print(
        "accuracy: {}, loss: {}, validation accuracy: {}, validation loss: {}".
        format(np.mean(training_acc), np.mean(training_loss),
               np.mean(testing_acc), np.mean(testing_loss)))
    if model.stop_training:
        break
earlystop.on_train_end()
modelcheckpoint.on_train_end()
reduce_lr.on_train_end()
tensorboard.on_train_end()
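
For reference, a sketch of how the manually driven callbacks above might have been constructed and bound to the model before the loop; the hyperparameters, checkpoint path, and log directory are illustrative, not from the original:

from tensorflow.keras.callbacks import (ModelCheckpoint, EarlyStopping,
                                        ReduceLROnPlateau, TensorBoard)

modelcheckpoint = ModelCheckpoint('weights.{epoch:02d}.hdf5',
                                  monitor='val_loss', save_best_only=True)
earlystop = EarlyStopping(monitor='val_loss', patience=5)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2)
tensorboard = TensorBoard(log_dir='./logs')

# When callbacks are driven by hand, each one needs a reference to the model
# before its on_epoch_end can be called.
for cb in (modelcheckpoint, earlystop, reduce_lr, tensorboard):
    cb.set_model(model)
    cb.on_train_begin()
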

# confusion matrix for training