def _fit_normalization(model):
    """Fit the normalization ``Scaler`` on ``model.train_data`` and install it.

    Does nothing when the module-level ``ts`` or ``avg`` flag is set.
    Otherwise calls ``Scaler.fit`` over the ``temporal_edge_attr``, ``x`` and
    ``y`` attributes with the module-level ``normalize_func``, assigns the
    result to ``model.dataset.transform`` and then calls
    ``model.init_loaders()``.

    The ``cache`` key passed to ``Scaler.fit`` is built from the dataset name
    and the module-level ``batch_hours`` and ``limit`` settings. It is only
    built when normalization actually runs, because the ``%d`` formatting
    requires integer values.
    """
    if ts or avg:
        return
    print("fitting normalization")
    cache = "%s_norm_%d_%d" % (model.dataset.name, batch_hours, limit)
    z_score_norm = Scaler.fit(
        model.train_data,
        normalize=normalize_func,
        attrs=dict(temporal_edge_attr=1, x=1, y=1),
        cache=cache,
    )
    model.dataset.transform = z_score_norm
    # Loaders are rebuilt after swapping the transform -- presumably they
    # capture it when constructed; TODO confirm.
    model.init_loaders()
    print("done fitting normalization")


def process_model(model_cls, params):
    """Build a model, then train / search / load it, and optionally evaluate.

    Behaviour is driven by module-level settings that are read here rather
    than passed in: ``train``, ``search``, ``evaluate``, ``ts``, ``avg``,
    ``epochs``, ``batch_hours``, ``limit``, ``normalize_func`` and
    ``models_base_path``.

    Exactly one of these modes runs first:

    * ``train``: fit normalization (unless ``ts`` or ``avg``), call
      ``model.train(epochs=epochs)`` and ``model.save()``.
    * ``search``: call ``model_cls.hyperparameter_search(**params)``.
    * otherwise: call ``model.load()``. A missing file is reported and
      execution continues.

    If ``evaluate`` is set, normalization is fitted again (unless ``ts`` or
    ``avg``). Then ``model.test()`` is run, its losses are printed, and the
    first 200 ``"ys"`` / ``"xs"`` entries are passed to
    ``model.plot_primitive_prediction``.

    Args:
        model_cls: Model class, instantiated as ``model_cls(**params)``.
        params: Constructor keyword arguments. In search mode they are also
            passed to ``hyperparameter_search``.
    """
    model = model_cls(**params)
    if train:
        _fit_normalization(model)
        train_losses = model.train(epochs=epochs)
        model.save()
        # NOTE: loss-curve plotting is intentionally disabled by ``and False``;
        # remove that term to re-enable it.
        if train_losses and False:
            # Plot loss curve
            plt.plot(train_losses)
            plt.savefig(
                os.path.join(models_base_path, model.name + "_loss.pdf"),
                format="pdf",
                dpi=600,
            )
    elif search:
        model_cls.hyperparameter_search(**params)
    else:
        # Load the model
        try:
            model.load()
        except FileNotFoundError:
            # NOTE(review): execution continues, so with ``evaluate`` set the
            # model is tested without loaded weights. Confirm this is intended
            # (e.g. for baselines that need no saved state).
            print("No trained model to load. Train one first using --train")

    if evaluate:
        print("Testing the model")
        _fit_normalization(model)
        _, val_losses = model.test()
        print(LossCollector.format(val_losses))
        plot_len = 200
        model.plot_primitive_prediction(
            "val", val_losses["ys"][:plot_len], val_losses["xs"][:plot_len]
        )
def train(self, epochs=1, print_interval=1):
    """Fit one model per prediction lookahead and report train/val losses.

    For every step ``1..self.pred_seq_len`` the method does the following:

    * Builds the train and validation sets with ``self.collect_dataset``,
      passing the zero-based step index as ``pred_lookahead``.
    * Runs both sets through ``self.preprocess``.
    * Calls ``self.fit(step, ...)`` and scores ``self.models[step]``
      predictions on both splits with ``LossCollector.collect``.
    * Prints the per-step losses.
    * When ``self.plot`` is truthy, plots the first 200 validation
      predictions.

    Afterwards the reduced collectors are handed to
    ``collect_train_metrics`` / ``collect_val_metrics`` and
    ``print_eval_summary`` is called.

    Args:
        epochs: Accepted for interface compatibility; not used here.
        print_interval: Accepted for interface compatibility; not used here.

    Returns:
        The validation ``LossCollector`` that ``collect`` was called on for
        every step.
    """
    train_collector = LossCollector()
    val_collector = LossCollector()

    for step in range(self.pred_seq_len):
        lookahead = step + 1  # 1-based number used in messages and plots

        print("Collecting dataset for lookahead", lookahead)
        train_x, train_y = self.collect_dataset(self.train_data, pred_lookahead=step)
        val_x, val_y = self.collect_dataset(self.val_data, pred_lookahead=step)
        train_x, train_y = self.preprocess(train_x, train_y)
        val_x, val_y = self.preprocess(val_x, val_y)

        print("Training lookahead", lookahead)
        self.fit(step, train_x, train_y)
        estimator = self.models[step]
        train_pred = estimator.predict(train_x)
        val_pred = estimator.predict(val_x)

        train_err = train_collector.collect(
            torch.from_numpy(train_pred).float(), train_y
        )
        val_err = val_collector.collect(
            torch.from_numpy(val_pred).float(), val_y
        )
        print("lookahead", lookahead, "train:", LossCollector.format(train_err))
        print("lookahead", lookahead, "val:", LossCollector.format(val_err))

        if self.plot:
            self.plot_pred(val_pred[:200], val_y[:200], lookahead=lookahead)

    self.collect_train_metrics(train_collector.reduce())
    self.collect_val_metrics(val_collector.reduce())
    self.print_eval_summary()
    return val_collector