def cli(train_data_nc, train_step):
    """Replay a stored model's forward pass on training inputs with its debug flag set.

    Reads the ``inputs`` variable from the netCDF file ``train_data_nc``, loads
    the model saved with prefix ``"idx0-"`` from the directory containing that
    file, sets ``model.params.debug`` to True and calls ``forward`` on the
    inputs. The result of ``forward`` is discarded; nothing is returned.

    Args:
        train_data_nc: Path to the training netCDF file (must expose
            ``.parent``, e.g. a ``pathlib.Path``).
        train_step: NOTE(review): accepted but not used by this function —
            confirm whether the model prefix should depend on it.
    """
    with nc.Dataset(train_data_nc, "r") as train_ds:
        # ``[:]`` materialises the variable, so it stays valid after close.
        train_inputs = train_ds["inputs"][:]

    esn = torsk.load_model(train_data_nc.parent, prefix="idx0-")
    esn.params.debug = True
    esn.forward(train_inputs)
def test_numpy_save_load(tmpdir):
    """A NumpyESN must give identical outputs before and after save/load.

    Builds a small dense-reservoir NumpyESN from a params file written to
    ``tmpdir``, runs one forward pass, saves the model to ``tmpdir``, reloads
    it and checks the reloaded model reproduces the same outputs for the same
    inputs and initial state.
    """
    # Kept byte-identical to the original params text.
    params_string = (
        '{ "input_shape": [10, 10], '
        '"input_map_specs": [ {"type":"pixels", "size":[10, 10], "input_scale":3}], '
        '"reservoir_representation": "dense", '
        '"spectral_radius" : 2.0, '
        '"density": 1e-1, '
        '"train_length": 800, '
        '"pred_length": 300, '
        '"transient_length": 100, '
        '"train_method": "pinv_svd", '
        '"dtype": "float64", '
        '"backend": "numpy", '
        '"debug": false, '
        '"imed_loss": true, '
        '"timing_depth": 1 } '
    )
    params_json = tmpdir.join("params.json")
    with open(params_json, "w") as fh:
        fh.write(params_string)

    params = torsk.Params(params_json)
    # Order matters: model construction and the two draws below may all use the RNG.
    saved_model = NumpyESN(params)
    inputs = bh.random.uniform(size=[2, 10, 10]).astype(bh.float64)
    state = bh.random.uniform(size=[100]).astype(bh.float64)

    _, out_before = saved_model.forward(inputs, state)
    torsk.save_model(tmpdir, saved_model)

    restored_model = torsk.load_model(str(tmpdir))
    _, out_after = restored_model.forward(inputs, state)

    assert bh.all(out_before == out_after)
def generate_offline_esn_pred(inputs_batch, labels_batch, pred_labels_batch, outdir,
                              hidden_size=512):
    """Run a stored ESN over several sequences and record its prediction error.

    The model is loaded from ``outdir``. For every ``(inputs, labels,
    pred_labels)`` triple, the reservoir is driven from a zero state over
    ``inputs``. A prediction of ``params.pred_length`` steps is then started
    from ``labels[-1]`` and the final reservoir state. Its absolute error
    against ``pred_labels`` is collected.

    Files written to ``outdir`` (must support ``/``, e.g. ``pathlib.Path``):
        esn_error.npy -- element-wise mean of the absolute errors over all sequences
        esn_pred.npy  -- prediction for the *last* sequence
        esn_lbls.npy  -- target labels for the *last* sequence

    Args:
        inputs_batch, labels_batch, pred_labels_batch: Parallel collections
            of per-sequence inputs, labels and prediction targets. They are
            iterated with ``zip``, so extra items in longer ones are ignored.
        outdir: Directory holding the saved model; outputs are written here.
        hidden_size: Passed to ``esn_params`` to build the prediction params.

    Returns:
        ``(esn_error, pred, pred_labels)``: the mean error array, plus the
        prediction and targets of the *first* sequence.

    Raises:
        ValueError: If no sequences are supplied.
    """
    model = torsk.load_model(outdir)
    params = esn_params(hidden_size)
    # NOTE(review): train_length is not read in this function; confirm it is needed.
    params.train_length = 200

    def _predict(inputs, labels):
        # Drive the reservoir from a zero state, then free-run from the last label.
        zero_state = np.zeros(model.esn_cell.hidden_size)
        _, states = model.forward(inputs, zero_state, states_only=True)
        pred, _ = model.predict(labels[-1], states[-1],
                                nr_predictions=params.pred_length)
        return pred

    esn_error = []
    first = None  # (pred, pred_labels) of the first sequence; reused for the return value
    print("Generating ESN predictions")
    for inputs, labels, pred_labels in zip(inputs_batch, labels_batch, pred_labels_batch):
        pred = _predict(inputs, labels)
        esn_error.append(np.abs(pred - pred_labels).squeeze())
        if first is None:
            # Cache instead of re-running forward/predict on sequence 0 afterwards.
            first = (pred, pred_labels)

    if first is None:
        # Previously this wrote a NaN error file and then failed on an unbound ``pred``.
        raise ValueError("generate_offline_esn_pred needs at least one sequence")

    esn_error = np.mean(esn_error, axis=0)
    np.save(outdir / "esn_error.npy", esn_error)
    # After the loop, ``pred``/``pred_labels`` belong to the last sequence.
    np.save(outdir / "esn_pred.npy", pred)
    np.save(outdir / "esn_lbls.npy", pred_labels)

    first_pred, first_pred_labels = first
    return esn_error, first_pred, first_pred_labels
# Select dataset/model implementations for the configured backend. The imports
# live inside the branches, so only the chosen backend has to be importable.
if params.backend == "numpy":
    logger.info("Running with NUMPY backend")
    from torsk.data.numpy_dataset import NumpyImageDataset as ImageDataset
    from torsk.models.numpy_esn import NumpyESN as ESN
else:
    logger.info("Running with TORCH backend")
    from torsk.data.torch_dataset import TorchImageDataset as ImageDataset
    from torsk.models.torch_esn import TorchESN as ESN

# NOTE(review): hard-coded, machine-specific data location.
npzpath = pathlib.Path(
    "/mnt/data/torsk_experiments/aguhlas_SSH_3daymean_x1050:1550_y700:1000.npz"
)
# np.load on an .npz returns an NpzFile that keeps the file open until closed;
# read the "SSH" array into memory inside the context manager so the handle is released.
with np.load(npzpath) as npzfile:
    images = npzfile["SSH"][:]
# Presumably resamples each frame to params.input_shape — confirm against resample2d_sequence.
images = resample2d_sequence(images, params.input_shape)
dataset = ImageDataset(images, params, scale_images=True)

# Reuse a previously saved model with this prefix if present; otherwise build a new one.
prefix = "idx0"
if outdir.joinpath(f"{prefix}-model.pkl").exists():
    logger.info("Restoring model ...")
    model = torsk.load_model(outdir, prefix=prefix)
else:
    logger.info("Building model ...")
    model = ESN(params)

logger.info("Training + predicting ...")
model, outputs, pred_labels = torsk.train_predict_esn(
    model, dataset, outdir, steps=1000, step_length=1
)