Example #1
import os
import pickle
from typing import Union

import tensorflow as tf
import torch

from fastestimator.backend import set_lr


def load_model(model: Union[tf.keras.Model, torch.nn.Module],
               weights_path: str,
               load_optimizer: bool = False):
    """Load saved weights for a given model.

    This method can be used with TensorFlow models:
    ```python
    m = fe.build(fe.architecture.tensorflow.LeNet, optimizer_fn="adam")
    fe.backend.save_model(m, save_dir="tmp", model_name="test")
    fe.backend.load_model(m, weights_path="tmp/test.h5")
    ```

    This method can be used with PyTorch models:
    ```python
    m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
    fe.backend.save_model(m, save_dir="tmp", model_name="test")
    fe.backend.load_model(m, weights_path="tmp/test.pt")
    ```

    Args:
        model: A neural network instance to load.
        weights_path: Path to the `model` weights.
        load_optimizer: Whether to also load the optimizer state. If True, a companion <weights>_opt file must exist alongside `weights_path`.

    Raises:
        ValueError: If `model` is an unacceptable data type.
    """
    # Models must be created through fe.build() so that optimizer info is attached.
    assert hasattr(model, "fe_compiled") and model.fe_compiled, "model must be built by fe.build"
    if isinstance(model, tf.keras.Model):
        model.load_weights(weights_path)
        if load_optimizer:
            assert model.current_optimizer, "optimizer does not exist"
            # TensorFlow optimizer state is pickled next to the weights as <name>_opt.pkl.
            optimizer_path = "{}_opt.pkl".format(os.path.splitext(weights_path)[0])
            assert os.path.exists(optimizer_path), "cannot find optimizer path: {}".format(optimizer_path)
            with open(optimizer_path, 'rb') as f:
                state_dict = pickle.load(f)
            model.current_optimizer.set_weights(state_dict['weights'])
            set_lr(model, state_dict['lr'])
    elif isinstance(model, torch.nn.Module):
        model.load_state_dict(torch.load(weights_path))
        if load_optimizer:
            assert model.current_optimizer, "optimizer does not exist"
            # PyTorch optimizer state is saved next to the weights as <name>_opt.pt.
            optimizer_path = "{}_opt.pt".format(os.path.splitext(weights_path)[0])
            assert os.path.exists(optimizer_path), "cannot find optimizer path: {}".format(optimizer_path)
            model.current_optimizer.load_state_dict(torch.load(optimizer_path))
    else:
        raise ValueError("Unrecognized model instance {}".format(type(model)))
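
Taken together with save_model, the round trip looks roughly like this. A minimal sketch, assuming fe.backend.save_model accepts a save_optimizer flag that writes the companion <name>_opt file (verify against your installed FastEstimator version):
```python
import fastestimator as fe

# Build, save weights plus optimizer state, then restore both.
m = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
fe.backend.save_model(m, save_dir="tmp", model_name="test", save_optimizer=True)
fe.backend.load_model(m, weights_path="tmp/test.pt", load_optimizer=True)
```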
Example #2
def on_epoch_end(self, data: Data) -> None:
    if self.monitor_op(data[self.inputs[0]], self.best):
        # Metric improved: record the new best and reset the patience counter.
        self.best = data[self.inputs[0]]
        self.wait = 0
    else:
        self.wait += 1
        if self.wait >= self.patience:
            # No improvement for `patience` epochs: scale the LR down by `factor`.
            new_lr = max(self.min_lr, np.float32(self.factor * get_lr(self.model)))
            set_lr(self.model, new_lr)
            self.wait = 0
            data.write_with_log(self.outputs[0], new_lr)
            print("FastEstimator-ReduceLROnPlateau: learning rate reduced to {}".format(new_lr))
Example #3
def on_epoch_begin(self, data: Data) -> None:
    if self.system.mode == "train" and self.schedule_mode == "epoch":
        if isinstance(self.lr_fn, ARC):
            # ARC only adjusts the LR every `frequency` epochs, and never on the first epoch.
            if self.system.epoch_idx > 1 and (self.system.epoch_idx % self.lr_fn.frequency == 1
                                              or self.lr_fn.frequency == 1):
                multiplier = self.lr_fn.predict_next_multiplier()
                new_lr = np.float32(get_lr(model=self.model) * multiplier)
                set_lr(self.model, new_lr)
                print("FastEstimator-ARC: Multiplying LR by {}".format(multiplier))
        else:
            # Plain callable schedule: evaluate it on the current epoch index.
            new_lr = np.float32(self.lr_fn(self.system.epoch_idx))
            set_lr(self.model, new_lr)
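
This is the epoch-mode branch of an LRScheduler trace: the ARC path defers to a learned controller, while the else branch simply evaluates a user-supplied callable on the epoch index. A minimal sketch of the callable form, assuming LRScheduler is importable from fastestimator.trace.adapt:
```python
import fastestimator as fe
from fastestimator.trace.adapt import LRScheduler

model = fe.build(fe.architecture.pytorch.LeNet, optimizer_fn="adam")
# Halve the learning rate every 10 epochs, starting from 1e-3.
scheduler = LRScheduler(model=model, lr_fn=lambda epoch: 1e-3 * 0.5 ** (epoch // 10))
```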
Example #4
def on_batch_begin(self, data: Data) -> None:
    if self.system.mode == "train" and self.schedule_mode == "step":
        # Step-mode schedule: re-evaluate the LR on every training batch.
        new_lr = np.float32(self.lr_fn(self.system.global_step))
        set_lr(self.model, new_lr)
Example #5
def on_epoch_begin(self, data: Data) -> None:
    if self.schedule_mode == "epoch":
        # Epoch-mode schedule: re-evaluate the LR once at the start of each epoch.
        new_lr = np.float32(self.lr_fn(self.system.epoch_idx))
        set_lr(self.model, new_lr)
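
Examples #4 and #5 are the step-mode and epoch-mode hooks of the same scheduler. In FastEstimator, schedule_mode is typically inferred from the parameter name of lr_fn, so the two modes differ only in how the callable is written (assumed behavior; confirm for your version). Reusing the model and import from the sketch above:
```python
# Naming the parameter `step` selects per-batch scheduling (Example #4),
# while `epoch` selects per-epoch scheduling (Example #5).
per_step = LRScheduler(model=model, lr_fn=lambda step: 1e-3 * 0.999 ** step)
per_epoch = LRScheduler(model=model, lr_fn=lambda epoch: 1e-3 * 0.9 ** epoch)
```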