예제 #1
0
 def __init__(self, *args, **kws):
     """Build the checkpoint callback, staging ``gs://`` targets locally.

     All positional and keyword arguments are forwarded unchanged to
     ``KerasModelCheckpoint``. If the resulting ``self.filepath`` is a
     ``gs://`` URI, then:

     - ``on_epoch_end`` is replaced by ``self._gcp_on_epoch_end``, which is
       defined elsewhere (presumably it uploads the local file to GCS --
       confirm);
     - the URI is remembered in ``self._original_filepath``;
     - Keras is pointed at a local temporary file instead.

     Any other filepath is left exactly as the base class set it.
     """
     KerasModelCheckpoint.__init__(self, *args, **kws)
     target = self.filepath
     if not target.startswith("gs://"):
         return

     self.on_epoch_end = self._gcp_on_epoch_end
     self._original_filepath = target
     # Hold a reference to the temp file object: NamedTemporaryFile deletes
     # the file when it is closed or garbage-collected.
     staging = tempfile.NamedTemporaryFile()
     self._temp_file = staging
     self.filepath = staging.name
    def __init__(self,
                 directory,
                 filename,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 period=1):
        """Checkpoint into a fresh, timestamp-named folder under ``directory``.

        A folder named ``<day>_<month>_<year>_<hour>_<minute>_<second>``
        (local time, components not zero-padded) is joined onto ``directory``.
        The result is assigned to the module-level ``constants.SAVE_DIR`` and
        passed to ``create_folder``. Checkpoints are written to
        ``os.path.join(constants.SAVE_DIR, filename)``.

        Args:
            directory: Parent directory for the timestamped run folder.
            filename: Checkpoint file name inside the run folder.
            monitor, verbose, save_best_only, save_weights_only, mode, period:
                Forwarded unchanged to ``ModelCheckpoint``.
        """
        # Name the run folder after the current local time. Two instances
        # created within the same second share one folder.
        now = datetime.datetime.now()
        current_time = "{}_{}_{}_{}_{}_{}".format(now.day, now.month, now.year,
                                                  now.hour, now.minute,
                                                  now.second)
        constants.SAVE_DIR = os.path.join(directory, current_time)

        create_folder(constants.SAVE_DIR)

        ModelCheckpoint.__init__(self,
                                 os.path.join(constants.SAVE_DIR, filename),
                                 monitor=monitor,
                                 # Forward verbose so the caller's setting
                                 # actually reaches ModelCheckpoint.
                                 verbose=verbose,
                                 save_best_only=save_best_only,
                                 save_weights_only=save_weights_only,
                                 mode=mode,
                                 period=period)
예제 #3
0
 def __init__(self,
              filepath,
              monitor='val_loss',
              verbose=0,
              save_best_only=False,
              mode='auto',
              start_epoch=0):
     """ModelCheckpoint that additionally records a ``start_epoch``.

     ``filepath`` is passed positionally, and ``monitor``, ``verbose``,
     ``save_best_only`` and ``mode`` by keyword, to ``ModelCheckpoint``.
     ``start_epoch`` is only stored on the instance after the base
     initializer runs; how it is consumed is not visible here.
     """
     forwarded = dict(monitor=monitor,
                      verbose=verbose,
                      save_best_only=save_best_only,
                      mode=mode)
     ModelCheckpoint.__init__(self, filepath, **forwarded)
     self.start_epoch = start_epoch
예제 #4
0
    def __init__(self, filepath, save_best_only=True,
                 training_set=(None, None), testing_set=(None, None),
                 folder=None, cost_string="log_loss",
                 save_training_dataset=False, verbose=1):
        """ModelCheckpoint that also carries datasets and a cost function.

        Args:
            filepath: Forwarded to ``ModelCheckpoint``.
            save_best_only: Forwarded to ``ModelCheckpoint``.
            training_set: 2-tuple unpacked into ``training_x`` and ``training_y``.
            testing_set: 2-tuple unpacked into ``testing_x`` and ``testing_id``.
            folder: Stored as ``self.folder``.
            cost_string: ``"log_loss"`` or ``"auc"``; selects ``cost_function``.
            save_training_dataset: Stored as ``self.save_training_dataset``.
            verbose: Forwarded to ``ModelCheckpoint`` (default 1).

        Raises:
            NotImplementedError: if ``cost_string`` is not a supported value.
                The message is logged at ERROR first.
        """
        ModelCheckpoint.__init__(self, filepath=filepath,
                                 save_best_only=save_best_only,
                                 verbose=verbose)

        self.training_x, self.training_y = training_set
        self.testing_x, self.testing_id = testing_set
        self.folder = folder
        self.save_training_dataset = save_training_dataset

        if cost_string == "log_loss":
            # NOTE(review): this stores the string itself, whereas "auc"
            # stores a callable. Confirm the consumer handles both forms.
            self.cost_function = cost_string
        elif cost_string == "auc":
            self.cost_function = roc_auc_score
        else:
            message = "Found undefined cost function - {}".format(cost_string)
            log(message, ERROR)
            # Was `NotImplementError` (undefined name), which surfaced as a
            # NameError instead of the intended exception.
            raise NotImplementedError(message)
예제 #5
0
    def __init__(self,
                 name,
                 directory='',
                 associated_trial=None,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=True,
                 mode='auto'):
        """Checkpoint callback that also restores a JSON training history.

        File locations:

        - With ``associated_trial``: ``associated_trial.get_path()`` +
          ``"weights.h5"`` / ``"history.json"``.
        - Otherwise: ``directory + 'SmartCheckpoint/'`` + ``name`` +
          ``"_weights.h5"`` / ``"_history.json"``.

        If the history file can be opened and parsed:

        - it becomes ``self.histobj.history``;
        - ``self.elapse_time`` is read from its ``"elapse_time"`` entry
          (default 0);
        - ``self.best`` is advanced over any stored values for ``monitor``
          using ``self.monitor_op``.

        A missing, empty or malformed file leaves the history empty and
        prints a failure message.
        """
        self.name = name
        if associated_trial is not None:
            # Paths are built by plain concatenation, so get_path() presumably
            # returns a path ending in a separator -- confirm.
            self.smartDir = associated_trial.get_path()
            self.checkpointFilename = self.smartDir + "weights.h5"
            self.historyFilename = self.smartDir + "history.json"
        else:
            self.smartDir = directory + 'SmartCheckpoint/'
            self.checkpointFilename = self.smartDir + name + "_weights.h5"
            self.historyFilename = self.smartDir + name + "_history.json"
        self.startTime = 0
        self.histobj = History()

        histDict = {}
        try:
            # `with` closes the handle; the original left it open.
            with open(self.historyFilename, "rb") as history_file:
                histDict = json.load(history_file)
        except (IOError, EOFError, ValueError):
            # ValueError covers empty/corrupt JSON (json.JSONDecodeError).
            print('Failed to load history at ' + self.historyFilename)
        else:
            print('Sucessfully loaded history at ' + self.historyFilename)

        self.histobj.history = histDict

        self.elapse_time = histDict.get("elapse_time", 0)

        # NOTE(review): arguments are passed positionally. In Keras 2.x the
        # fifth positional parameter of ModelCheckpoint is save_weights_only,
        # not mode, so `mode` would land there. Confirm the targeted Keras
        # version before switching to keyword arguments.
        ModelCheckpoint.__init__(self, self.checkpointFilename, monitor,
                                 verbose, save_best_only, mode)

        # Fold previously recorded values of the monitored metric into
        # self.best. An empty list is now a no-op; the old unused
        # `metric_history[0]` raised IndexError on it.
        metric_history = histDict.get(monitor)
        if metric_history is not None:
            for metric in metric_history:
                if self.monitor_op(metric, self.best):
                    self.best = metric
    def __init__(self,
                 filepath,
                 base_model,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 period=1):
        """ModelCheckpoint that also keeps a reference to ``base_model``.

        Every argument except ``base_model`` is forwarded to
        ``ModelCheckpoint``. ``filepath`` is passed positionally and the rest
        by keyword. ``base_model`` is stored on the instance after the base
        initializer runs; how it is used is not visible here.
        """
        options = {
            'monitor': monitor,
            'verbose': verbose,
            'save_best_only': save_best_only,
            'save_weights_only': save_weights_only,
            'mode': mode,
            'period': period,
        }
        ModelCheckpoint.__init__(self, filepath, **options)
        self.base_model = base_model
예제 #7
0
 def __init__(self, *args, **kwargs):
     """ModelCheckpoint with two extra keyword-only options.

     ``name`` (default ``'unknown'``) and ``save_every_k_epochs`` (default
     ``100``) are removed from ``kwargs`` and stored as ``model_name_`` and
     ``save_every_k_epochs_``. Everything else goes to ``ModelCheckpoint``.
     """
     # Pop before delegating so the base class never sees the extra keys.
     label = kwargs.pop('name', 'unknown')
     every_k = kwargs.pop('save_every_k_epochs', 100)
     self.model_name_ = label
     self.save_every_k_epochs_ = every_k
     ModelCheckpoint.__init__(self, *args, **kwargs)