コード例 #1
0
def cli(train_data_nc, train_step):
    """Replay stored inputs through a saved model with debug mode switched on.

    Args:
        train_data_nc: Path object pointing at the data file to open. The
            saved model is looked up in this file's parent directory, so the
            value must expose ``.parent``.
        train_step: Accepted for interface compatibility; not used here.
    """
    # Read the complete "inputs" variable while the file is open.
    # NOTE(review): ``nc`` is presumably the netCDF4 module - confirm.
    with nc.Dataset(train_data_nc, "r") as dataset:
        stored_inputs = dataset["inputs"][:]

    # The model was saved next to the data file under the "idx0-" prefix.
    model_dir = train_data_nc.parent
    esn = torsk.load_model(model_dir, prefix="idx0-")

    esn.params.debug = True
    esn.forward(stored_inputs)
コード例 #2
0
ファイル: test_save_load.py プロジェクト: nmheim/torsk
def test_numpy_save_load(tmpdir):
    """Check that a NumpyESN produces identical outputs after save + load.

    The test writes a JSON parameter file into ``tmpdir``, builds a model
    from it, runs ``forward`` on random inputs, saves the model, loads it
    back and asserts that a second ``forward`` call on the same inputs and
    state returns element-wise equal outputs.
    """
    config_text = """{
      "input_shape": [10, 10],
      "input_map_specs": [
        {"type":"pixels", "size":[10, 10], "input_scale":3}],
      "reservoir_representation": "dense",
      "spectral_radius" : 2.0,
      "density": 1e-1,

      "train_length": 800,
      "pred_length": 300,
      "transient_length": 100,
      "train_method": "pinv_svd",

      "dtype": "float64",
      "backend": "numpy",
      "debug": false,
      "imed_loss": true,
      "timing_depth": 1
    }
    """
    config_path = tmpdir.join("params.json")
    with open(config_path, "w") as handle:
        handle.write(config_text)

    esn = NumpyESN(torsk.Params(config_path))

    # Two random 10x10 frames and a random state vector of length 100.
    # NOTE(review): 100 presumably matches the hidden size implied by the
    # 10x10 "pixels" input map - confirm against NumpyESN.
    frames = bh.random.uniform(size=[2, 10, 10]).astype(bh.float64)
    initial_state = bh.random.uniform(size=[100]).astype(bh.float64)

    _, before = esn.forward(frames, initial_state)

    # Round-trip the model through disk and repeat the same forward pass.
    torsk.save_model(tmpdir, esn)
    restored = torsk.load_model(str(tmpdir))
    _, after = restored.forward(frames, initial_state)

    assert bh.all(before == after)
コード例 #3
0
def _predict_from_zero_state(model, inputs, labels, nr_predictions):
    """Run ``model`` over ``inputs`` from an all-zero state, then predict.

    The prediction is seeded with the last label and the last state returned
    by ``model.forward``; only the prediction itself is returned.
    """
    zero_state = np.zeros(model.esn_cell.hidden_size)
    _, states = model.forward(inputs, zero_state, states_only=True)
    pred, _ = model.predict(labels[-1],
                            states[-1],
                            nr_predictions=nr_predictions)
    return pred


def generate_offline_esn_pred(inputs_batch,
                              labels_batch,
                              pred_labels_batch,
                              outdir,
                              hidden_size=512):
    """Predict with a saved ESN for every sample and store the mean error.

    Args:
        inputs_batch: Sequence of input sequences, one per sample.
        labels_batch: Sequence of label sequences; the last label of each
            sample seeds the prediction.
        pred_labels_batch: Sequence of target sequences the predictions are
            compared against.
        outdir: Directory the model is loaded from and the ``.npy`` files are
            written to. Must support the ``/`` operator (e.g. pathlib.Path).
        hidden_size: Forwarded to ``esn_params`` to build the parameter
            object from which ``pred_length`` is read.

    Returns:
        Tuple ``(esn_error, pred, pred_labels)`` where ``esn_error`` is the
        absolute prediction error averaged over all samples, and ``pred`` /
        ``pred_labels`` are the prediction and targets of the FIRST sample.

    Raises:
        ValueError: If the batches yield no samples. Previously this case
            wrote a NaN error file and then failed with UnboundLocalError.
    """
    model = torsk.load_model(outdir)
    params = esn_params(hidden_size)
    params.train_length = 200
    esn_error = []
    print("Generating ESN predictions")
    for inputs, labels, pred_labels in zip(inputs_batch, labels_batch,
                                           pred_labels_batch):
        pred = _predict_from_zero_state(model, inputs, labels,
                                        params.pred_length)
        err = np.abs(pred - pred_labels)
        esn_error.append(err.squeeze())

    # Fail before touching the disk when there was nothing to predict.
    if not esn_error:
        raise ValueError("cannot generate ESN predictions from empty batches")

    esn_error = np.mean(esn_error, axis=0)
    np.save(outdir / "esn_error.npy", esn_error)
    # NOTE(review): the two files below hold the LAST sample of the loop,
    # whereas the values returned further down belong to the FIRST sample.
    # Kept as in the original - confirm which one is intended.
    np.save(outdir / "esn_pred.npy", pred)
    np.save(outdir / "esn_lbls.npy", pred_labels)

    # Recompute the prediction for the first sample; this is what is returned.
    pred_labels = pred_labels_batch[0]
    pred = _predict_from_zero_state(model, inputs_batch[0], labels_batch[0],
                                    params.pred_length)

    return esn_error, pred, pred_labels
コード例 #4
0
ファイル: conv_run_3daymean.py プロジェクト: nmheim/torsk
# Pick the dataset/model classes for the configured backend; every value other
# than "numpy" falls through to the torch implementations.
if params.backend == "numpy":
    logger.info("Running with NUMPY backend")
    from torsk.data.numpy_dataset import NumpyImageDataset as ImageDataset
    from torsk.models.numpy_esn import NumpyESN as ESN
else:
    logger.info("Running with TORCH backend")
    from torsk.data.torch_dataset import TorchImageDataset as ImageDataset
    from torsk.models.torch_esn import TorchESN as ESN

npzpath = pathlib.Path(
    "/mnt/data/torsk_experiments/aguhlas_SSH_3daymean_x1050:1550_y700:1000.npz"
)
# np.load returns an NpzFile for .npz archives; close it explicitly rather
# than relying on garbage collection. Indexing the archive yields the array
# itself, so ``images`` stays valid after the archive is closed.
with np.load(npzpath) as archive:
    images = archive["SSH"][:]
images = resample2d_sequence(images, params.input_shape)
dataset = ImageDataset(images, params, scale_images=True)

# Resume from a previously saved model in ``outdir`` when one exists,
# otherwise start from a freshly built one.
prefix = "idx0"
if outdir.joinpath(f"{prefix}-model.pkl").exists():
    logger.info("Restoring model ...")
    model = torsk.load_model(outdir, prefix=prefix)
else:
    logger.info("Building model ...")
    model = ESN(params)

logger.info("Training + predicting ...")
model, outputs, pred_labels = torsk.train_predict_esn(model,
                                                      dataset,
                                                      outdir,
                                                      steps=1000,
                                                      step_length=1)