示例#1
0
    def train(self, epochs=1, print_interval=1):
        """Fit one model per prediction lookahead and report its errors.

        For every lookahead step ``1 .. self.pred_seq_len`` this method:

        1. builds training and validation datasets via ``self.collect_dataset``;
        2. runs both through ``self.preprocess``;
        3. fits ``self.models[step - 1]``;
        4. records the model's train/val prediction error.

        When ``self.plot`` is set, the first 200 validation predictions are
        plotted for each step.

        ``epochs`` and ``print_interval`` are accepted for signature
        compatibility with the other trainers but are not used here.

        Returns:
            The ``LossCollector`` that accumulated the validation errors.
        """
        train_losses = LossCollector()
        val_losses = LossCollector()

        for step_idx in range(self.pred_seq_len):
            # ``step_idx`` is the zero-based model index; ``lookahead`` is the
            # one-based value shown in the logs and passed to plotting.
            lookahead = step_idx + 1

            print("Collecting dataset for lookahead", lookahead)
            x_tr, y_tr = self.collect_dataset(self.train_data, pred_lookahead=step_idx)
            x_va, y_va = self.collect_dataset(self.val_data, pred_lookahead=step_idx)

            x_tr, y_tr = self.preprocess(x_tr, y_tr)
            x_va, y_va = self.preprocess(x_va, y_va)

            print("Training lookahead", lookahead)
            self.fit(step_idx, x_tr, y_tr)

            model = self.models[step_idx]
            pred_tr = model.predict(x_tr)
            pred_va = model.predict(x_va)

            # predict() results are converted with torch.from_numpy, so they
            # are expected to be numpy arrays.
            err_tr = train_losses.collect(torch.from_numpy(pred_tr).float(), y_tr)
            err_va = val_losses.collect(torch.from_numpy(pred_va).float(), y_va)

            for split_label, split_err in (("train:", err_tr), ("val:", err_va)):
                print("lookahead", lookahead, split_label, LossCollector.format(split_err))

            if self.plot:
                self.plot_pred(pred_va[:200], y_va[:200], lookahead=lookahead)

        self.collect_train_metrics(train_losses.reduce())
        self.collect_val_metrics(val_losses.reduce())
        self.print_eval_summary()
        return val_losses
示例#2
0
    def train(self, epochs=1, print_interval=1, val=True, validation_interval=3):
        """Train ``self.model`` for ``epochs`` passes over ``self.train_data``.

        Items whose ``x`` is ``None`` are skipped. Every other item is moved to
        ``self.device`` and fed through ``self.feed``. Before the loss is
        computed, prediction/target entries are zeroed wherever
        ``data.current_transports`` (last ``self.pred_seq_len`` columns) is
        False. An L1 penalty on all model parameters, weighted by
        ``self.l1_reg``, is then added.

        Args:
            epochs: Number of passes over the training data.
            print_interval: Print a progress line every ``print_interval`` epochs.
            val: If True, call ``self.test()`` on printed epochs that are also
                multiples of ``validation_interval``.
            validation_interval: Epoch period for validation (see ``val``).
        """

        def zero_unmasked(t, keep):
            """Return a copy of ``t`` with entries where ``keep`` is False set to 0.

            This is out-of-place on purpose. Zeroing ``outputs`` in place can
            break autograd when the op that produced it saved its result for
            backward. Zeroing ``expected`` in place may write into the
            dataset's tensors if ``feed`` returns views of them.

            ``keep`` covers the leading dims of ``t``. Trailing singleton dims
            are appended so it broadcasts over the remaining dims, selecting
            the same entries as ``t[~keep]`` did.
            """
            keep = keep.to(t.device)
            keep = keep.reshape(tuple(keep.shape) + (1,) * (t.dim() - keep.dim()))
            return t.masked_fill(~keep, 0)

        cur_time = prev_time = time.time()
        for epoch in range(1, epochs + 1):
            self.model.train()
            epoch_loss_collector = LossCollector()

            for data in self.train_data:
                if data.x is None:
                    continue

                data = data.to(self.device)
                self.optimizer.zero_grad()
                outputs, expected = self.feed(data, data.x, data.y, data.temporal_edge_attr)

                mask = data.current_transports[:outputs.size(0), -self.pred_seq_len:]
                outputs = zero_unmasked(outputs, mask)
                expected = zero_unmasked(expected, mask)

                loss = self.loss(outputs, expected)
                # A zero weight adds nothing to the loss or its gradient, so
                # skip the full pass over the parameters in that case.
                if self.l1_reg:
                    l1_regularization = sum(p.abs().sum() for p in self.model.parameters())
                    loss = loss + self.l1_reg * l1_regularization

                epoch_loss_collector.collect(outputs, expected)

                loss.backward()
                self.optimizer.step()

            if self.lr_scheduler:
                self.lr_scheduler.step()
            train_metrics = epoch_loss_collector.reduce()
            self.collect_train_metrics(train_metrics)

            cur_time = time.time()
            log = "Epoch: {:03d}, Train: {:.4f} ({:.4f}) Val: {:.4f} ({:.4f}) Acc: {:.4f} Time: {:.1f}s"
            if epoch % print_interval == 0:
                # Placeholders printed when validation does not run this epoch.
                val_loss = dict(mse=-1, mae=-1, acc=-1)
                if val and epoch % validation_interval == 0:
                    # self.test() returns a pair; only the metrics dict is logged.
                    _, val_loss = self.test()
                print(
                    log.format(
                        epoch,
                        train_metrics["mse"], train_metrics["mae"],
                        val_loss["mse"], val_loss["mae"],
                        val_loss["acc"],
                        cur_time - prev_time,
                    )
                )
            prev_time = cur_time

        self.print_eval_summary()
示例#3
0
    def bptt_train(self, epochs=1, print_interval=1, k1=1, k2=3, seq=5, skip=None, debug=False,
                   val_interval=3):
        """Train the recurrent model with truncated back-propagation through time.

        The recurrent state is re-initialised with ``self.init_rnn_state()`` at
        the start of every epoch and threaded through ``self.feed``. Per-batch
        losses (masked by ``data.current_transports`` and including an L1
        penalty weighted by ``self.l1_reg``) are summed. Whenever the item
        index ``i`` is a multiple of ``seq``, the sum is back-propagated,
        gradients are clipped to norm 1 and the optimizer steps. The hidden
        state is then detached so later windows do not back-propagate into
        earlier ones. ``i`` counts every item of ``self.train_data``, including
        skipped ones. Losses accumulated after the last update of an epoch are
        not back-propagated.

        Args:
            epochs: Number of passes over the training data.
            print_interval: Print a progress line every ``print_interval`` epochs.
            k1, k2: Unused; kept for signature compatibility.
            seq: Update period, in item indices.
            skip: If truthy, only items whose index is a multiple of ``skip``
                are used.
            debug: Passed to ``torch.set_anomaly_enabled``. The flag is not
                restored on return.
            val_interval: Run ``self.test()`` on printed epochs that are
                multiples of this value. (The original body referenced this
                name without defining it.)
        """

        def repackage_hidden(h):
            """Wrap hidden state(s) in new tensors detached from their history."""
            if isinstance(h, torch.Tensor):
                return h.detach()
            return tuple(repackage_hidden(v) for v in h)

        def zero_unmasked(t, keep):
            """Return a copy of ``t`` with entries where ``keep`` is False set to 0.

            This is out-of-place so autograd and any tensors shared with the
            dataset are not modified in place. ``keep`` covers the leading
            dims of ``t`` and is padded with trailing singleton dims so it
            selects the same entries as ``t[~keep]``.
            """
            keep = keep.to(t.device)
            keep = keep.reshape(tuple(keep.shape) + (1,) * (t.dim() - keep.dim()))
            return t.masked_fill(~keep, 0)

        cur_time = prev_time = time.time()
        torch.set_anomaly_enabled(debug)

        for epoch in range(1, epochs + 1):
            self.model.train()

            net_state_hidden, net_cell_states = self.init_rnn_state()
            epoch_loss_collector = LossCollector()

            total_loss = 0
            for i, data in enumerate(self.train_data):
                if skip and i % skip != 0:
                    continue
                if data.x is None or torch.isnan(data.x).any():
                    continue

                data = data.to(self.device)
                # Gradients only materialise at total_loss.backward(), so
                # zeroing here each batch does not discard accumulated ones.
                self.optimizer.zero_grad()

                outputs, expected, net_state_hidden, net_cell_states = self.feed(
                    data, net_state_hidden, net_cell_states
                )

                mask = data.current_transports[:, -self.pred_seq_len:]
                outputs = zero_unmasked(outputs, mask)
                expected = zero_unmasked(expected, mask)

                loss = self.loss(outputs, expected)
                if self.l1_reg:
                    l1_regularization = sum(p.abs().sum() for p in self.model.parameters())
                    loss = loss + self.l1_reg * l1_regularization
                total_loss = total_loss + loss

                # Record every processed batch, so the epoch metric is not
                # computed on only the batches that trigger an update.
                epoch_loss_collector.collect(outputs, expected)

                if i % seq == 0:
                    total_loss.backward()
                    nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                    self.optimizer.step()

                    # Cut the gradient graph at the window boundary.
                    net_state_hidden = repackage_hidden(net_state_hidden)
                    net_cell_states = repackage_hidden(net_cell_states)
                    total_loss = 0

            train_metrics = epoch_loss_collector.reduce()
            self.collect_train_metrics(train_metrics)

            cur_time = time.time()
            log = "Epoch: {:03d}, Train: {:.4f} ({:.4f}) Val: {:.4f} ({:.4f}) Acc: {:.4f}, Time: {:.1f}s"
            if epoch % print_interval == 0:
                # Placeholders printed when validation does not run this epoch.
                val_loss = dict(mse=-1, mae=-1, acc=-1)
                if epoch % val_interval == 0:
                    # Validate; this uses another hidden state for the model.
                    _, val_loss = self.test()
                print(
                    log.format(
                        epoch,
                        train_metrics["mse"], train_metrics["mae"],
                        val_loss["mse"], val_loss["mae"],
                        val_loss["acc"],
                        cur_time - prev_time,
                    )
                )
            prev_time = cur_time

        self.print_eval_summary()
示例#4
0
    def train(self, epochs=1, print_interval=1, val=True, debug=False, val_interval=3):
        """Train the recurrent model, carrying its state across batches and epochs.

        The recurrent state comes from ``self.init_rnn_state()`` once, before
        the first epoch, and is threaded through ``self.feed`` for every
        usable item. Items whose ``x`` is ``None`` or contains NaN are
        skipped. Each batch is masked by ``data.current_transports``, gets an
        L1 penalty weighted by ``self.l1_reg``, is back-propagated and is
        followed by one optimizer step.

        Args:
            epochs: Number of passes over the training data.
            print_interval: Print a progress line every ``print_interval`` epochs.
            val: If True, call ``self.test()`` on printed epochs that are also
                multiples of ``val_interval``.
            debug: Passed to ``torch.set_anomaly_enabled``. The flag is not
                restored on return.
            val_interval: Validation period in epochs. (The original body
                referenced this name without defining it.)
        """

        def detach_state(h):
            """Return ``h`` with every tensor detached; ``None`` and lists are preserved."""
            if h is None:
                return None
            if isinstance(h, torch.Tensor):
                return h.detach()
            if isinstance(h, list):
                return [detach_state(v) for v in h]
            return tuple(detach_state(v) for v in h)

        def zero_unmasked(t, keep):
            """Return a copy of ``t`` with entries where ``keep`` is False set to 0.

            This is out-of-place so autograd and any tensors shared with the
            dataset are not modified in place. ``keep`` covers the leading
            dims of ``t`` and is padded with trailing singleton dims so it
            selects the same entries as ``t[~keep]``.
            """
            keep = keep.to(t.device)
            keep = keep.reshape(tuple(keep.shape) + (1,) * (t.dim() - keep.dim()))
            return t.masked_fill(~keep, 0)

        cur_time = prev_time = time.time()
        torch.set_anomaly_enabled(debug)

        net_state_hidden, net_cell_states = self.init_rnn_state()

        for epoch in range(1, epochs + 1):
            self.model.train()
            epoch_loss_collector = LossCollector()

            for data in self.train_data:
                if data.x is None or torch.isnan(data.x).any():
                    continue

                data = data.to(self.device)
                self.optimizer.zero_grad()

                outputs, expected, net_state_hidden, net_cell_states = self.feed(
                    data, net_state_hidden, net_cell_states
                )
                mask = data.current_transports[:, -self.pred_seq_len:]
                outputs = zero_unmasked(outputs, mask)
                expected = zero_unmasked(expected, mask)

                assert not torch.isnan(expected).any()
                loss = self.loss(outputs, expected)
                # A zero weight adds nothing to the loss or its gradient.
                if self.l1_reg:
                    l1_regularization = sum(p.abs().sum() for p in self.model.parameters())
                    loss = loss + self.l1_reg * l1_regularization

                loss.backward()
                epoch_loss_collector.collect(outputs, expected)

                # NaN guard on the global gradient norm. Parameters that did
                # not take part in this graph have grad None and are skipped.
                total_norm = sum(
                    p.grad.detach().norm(2).item() ** 2
                    for p in self.model.parameters()
                    if p.grad is not None
                ) ** 0.5
                assert total_norm == total_norm, "gradient norm is NaN"

                self.optimizer.step()

                # The state is carried into the next batch (and epoch). Detach
                # it so the next backward() does not try to traverse this
                # batch's already-freed graph. This is harmless if feed()
                # already detaches.
                net_state_hidden = detach_state(net_state_hidden)
                net_cell_states = detach_state(net_cell_states)

            if self.lr_scheduler:
                self.lr_scheduler.step()
            train_metrics = epoch_loss_collector.reduce()
            self.collect_train_metrics(train_metrics)

            cur_time = time.time()
            log = "Epoch: {:03d}, Train: {:.4f} ({:.4f}) Val: {:.4f} ({:.4f}) Acc: {:.4f}, Time: {:.1f}s"
            if epoch % print_interval == 0:
                # Placeholders printed when validation does not run this epoch.
                val_loss = dict(mse=-1, mae=-1, acc=-1)
                if val and epoch % val_interval == 0:
                    _, val_loss = self.test()
                print(
                    log.format(
                        epoch,
                        train_metrics["mse"], train_metrics["mae"],
                        val_loss["mse"], val_loss["mae"],
                        val_loss["acc"],
                        cur_time - prev_time,
                    )
                )
            prev_time = cur_time

        self.print_eval_summary()