コード例 #1
0
ファイル: train.py プロジェクト: romnn/rail-stgcnn
    def process_model(model_cls, params):
        """Train, hyperparameter-search, or load a model, then optionally evaluate it.

        Behaviour is driven by flags from the enclosing scope: ``train``,
        ``search`` and ``evaluate`` select the phases; when ``ts`` or ``avg``
        is set no z-score normalization is applied; ``epochs``,
        ``batch_hours``, ``limit`` and ``normalize_func`` configure training
        and normalization.

        Args:
            model_cls: Model class. Instantiated as ``model_cls(**params)``; in
                search mode ``model_cls.hyperparameter_search(**params)`` is
                called instead of training or loading.
            params: Keyword arguments forwarded to ``model_cls``.
        """
        model = model_cls(**params)
        # Scaler fitted by _normalize(). Kept so that a run with both ``train``
        # and ``evaluate`` fits it once and reuses it, instead of refitting on
        # ``model.train_data`` after the dataset transform has already been
        # replaced by the first fit.
        z_score_norm = None

        def _normalize():
            """Fit the z-score scaler (at most once) and install it on the dataset.

            Does nothing when ``ts`` or ``avg`` is set.
            """
            nonlocal z_score_norm
            if ts or avg:
                return
            print("fitting normalization")
            if z_score_norm is None:
                cache = "%s_norm_%d_%d" % (model.dataset.name, batch_hours,
                                           limit)
                z_score_norm = Scaler.fit(
                    model.train_data,
                    normalize=normalize_func,
                    attrs=dict(
                        temporal_edge_attr=1,
                        x=1,
                        y=1,
                    ),
                    cache=cache,
                )
            model.dataset.transform = z_score_norm
            # Rebuild the loaders, presumably so they use the new transform.
            model.init_loaders()
            print("done fitting normalization")

        if train:
            _normalize()
            train_losses = model.train(epochs=epochs)
            model.save()
            # NOTE: loss-curve plotting is switched off by the ``and False``
            # guard; remove it to write ``<model.name>_loss.pdf``.
            if train_losses and False:
                plt.plot(train_losses)
                plt.savefig(
                    os.path.join(models_base_path, model.name + "_loss.pdf"),
                    format="pdf",
                    dpi=600,
                )
        elif search:
            model_cls.hyperparameter_search(**params)
        else:
            # Load previously saved weights.
            try:
                model.load()
            except FileNotFoundError:
                # NOTE(review): execution continues, so ``evaluate`` below runs
                # on the freshly constructed model — confirm this is intended.
                print(
                    "No trained model to load. Train one first using --train")

        if evaluate:
            print("Testing the model")
            _normalize()
            _, val_losses = model.test()
            print(LossCollector.format(val_losses))
            plot_len = 200
            model.plot_primitive_prediction("val", val_losses["ys"][:plot_len],
                                            val_losses["xs"][:plot_len])
コード例 #2
0
    def train(self, epochs=1, print_interval=1):
        """Fit one model per prediction lookahead and report train/val losses.

        For each lookahead ``1..self.pred_seq_len`` this collects and
        preprocesses the training and validation splits, calls ``self.fit``
        for index ``lookahead - 1``, and scores the predictions of
        ``self.models[lookahead - 1]`` on both splits with ``LossCollector``.
        Per-lookahead losses are printed. When ``self.plot`` is set, the first
        200 validation predictions are plotted. The reduced metrics are then
        recorded and summarized.

        Args:
            epochs: Accepted for interface compatibility; not used here.
            print_interval: Accepted for interface compatibility; not used here.

        Returns:
            The ``LossCollector`` that accumulated the validation losses.
        """
        train_collector = LossCollector()
        val_collector = LossCollector()

        for step in range(self.pred_seq_len):
            # ``step`` is the zero-based model index; ``lookahead`` is the
            # one-based value shown in log messages and plots.
            lookahead = step + 1

            print("Collecting dataset for lookahead", lookahead)
            x_train, y_train = self.collect_dataset(self.train_data,
                                                    pred_lookahead=step)
            x_val, y_val = self.collect_dataset(self.val_data,
                                                pred_lookahead=step)
            x_train, y_train = self.preprocess(x_train, y_train)
            x_val, y_val = self.preprocess(x_val, y_val)

            print("Training lookahead", lookahead)
            self.fit(step, x_train, y_train)

            # Look the model up only after fit(), which may (re)build it.
            estimator = self.models[step]
            pred_train = estimator.predict(x_train)
            pred_val = estimator.predict(x_val)

            train_err = train_collector.collect(
                torch.from_numpy(pred_train).float(), y_train)
            val_err = val_collector.collect(
                torch.from_numpy(pred_val).float(), y_val)

            print("lookahead", lookahead, "train:",
                  LossCollector.format(train_err))
            print("lookahead", lookahead, "val:",
                  LossCollector.format(val_err))

            if self.plot:
                head = slice(200)
                self.plot_pred(pred_val[head], y_val[head], lookahead=lookahead)

        self.collect_train_metrics(train_collector.reduce())
        self.collect_val_metrics(val_collector.reduce())
        self.print_eval_summary()
        return val_collector