def train(self, epochs=1, print_interval=1):
    """Fit one model per prediction lookahead and report its losses.

    For each step ``1..self.pred_seq_len`` this:

    1. Builds train and validation sets with ``self.collect_dataset``, passing
       the zero-based step as ``pred_lookahead``.
    2. Runs both sets through ``self.preprocess``.
    3. Calls ``self.fit(index, ...)``.
    4. Scores ``self.models[index]`` on both splits.

    ``epochs`` and ``print_interval`` are accepted for signature compatibility
    with the other trainers but are not used by this implementation.

    Returns:
        The ``LossCollector`` accumulating the validation losses of all
        lookahead steps.
    """
    train_losses = LossCollector()
    val_losses = LossCollector()

    for index in range(self.pred_seq_len):
        lookahead = index + 1  # 1-based step number, used for logging/plotting

        print("Collecting dataset for lookahead", lookahead)
        raw_train_x, raw_train_y = self.collect_dataset(self.train_data, pred_lookahead=index)
        raw_val_x, raw_val_y = self.collect_dataset(self.val_data, pred_lookahead=index)
        train_x, train_y = self.preprocess(raw_train_x, raw_train_y)
        val_x, val_y = self.preprocess(raw_val_x, raw_val_y)

        print("Training lookahead", lookahead)
        self.fit(index, train_x, train_y)
        # Looked up only after fit(), in case fit() replaces the stored model.
        model = self.models[index]
        train_pred = model.predict(train_x)
        val_pred = model.predict(val_x)

        # predict() output is converted from numpy to a float tensor for the collectors
        train_err = train_losses.collect(torch.from_numpy(train_pred).float(), train_y)
        val_err = val_losses.collect(torch.from_numpy(val_pred).float(), val_y)
        print("lookahead", lookahead, "train:", LossCollector.format(train_err))
        print("lookahead", lookahead, "val:", LossCollector.format(val_err))

        if self.plot:
            # Only the first 200 validation samples are plotted.
            self.plot_pred(val_pred[:200], val_y[:200], lookahead=lookahead)

    self.collect_train_metrics(train_losses.reduce())
    self.collect_val_metrics(val_losses.reduce())
    self.print_eval_summary()
    return val_losses
def train(self, epochs=1, print_interval=1, val=True, validation_interval=3):
    """Train ``self.model`` for ``epochs`` passes over ``self.train_data``.

    Each sample whose ``x`` is not ``None`` is moved to ``self.device`` and
    passed to ``self.feed``. Entries not selected by
    ``data.current_transports`` are zeroed in both predictions and targets
    before ``self.loss`` is evaluated. An L1 penalty over all model
    parameters, scaled by ``self.l1_reg``, is added to the loss.

    Args:
        epochs: Number of passes over the training data.
        print_interval: Print a progress line every this many epochs.
        val: Whether to run ``self.test()`` on progress epochs.
        validation_interval: Validate only on progress epochs that are a
            multiple of this value.
    """
    log = "Epoch: {:03d}, Train: {:.4f} ({:.4f}) Val: {:.4f} ({:.4f}) Acc: {:.4f} Time: {:.1f}s"
    cur_time = prev_time = time.time()
    for epoch in range(1, epochs + 1):
        self.model.train()
        epoch_loss_collector = LossCollector()
        for data in self.train_data:
            if data.x is None:
                continue
            data = data.to(self.device)
            self.optimizer.zero_grad()
            outputs, expected = self.feed(data, data.x, data.y, data.temporal_edge_attr)

            # Zero positions not flagged in current_transports so they do not
            # contribute to the loss (mask assumed boolean — TODO confirm).
            mask = data.current_transports[: outputs.size(0), -self.pred_seq_len :]
            outputs[~mask] = 0
            expected[~mask] = 0

            l1_regularization = sum((p.abs().sum() for p in self.model.parameters()), 0.0)
            loss = self.loss(outputs, expected) + self.l1_reg * l1_regularization
            epoch_loss_collector.collect(outputs, expected)

            loss.backward()
            self.optimizer.step()

        # The scheduler is stepped once per epoch, after the epoch's optimizer steps.
        if self.lr_scheduler:
            self.lr_scheduler.step()

        loss = epoch_loss_collector.reduce()
        self.collect_train_metrics(loss)
        cur_time = time.time()
        if epoch % print_interval == 0:
            # -1 placeholders are printed on epochs without validation.
            val_loss = dict(mse=-1, mae=-1, acc=-1)
            if val and epoch % validation_interval == 0:
                _, val_loss = self.test()
            print(
                log.format(
                    epoch,
                    loss["mse"],
                    loss["mae"],
                    val_loss["mse"],
                    val_loss["mae"],
                    val_loss["acc"],
                    cur_time - prev_time,
                )
            )
            prev_time = cur_time
    self.print_eval_summary()
def bptt_train(self, epochs=1, print_interval=1, k1=1, k2=3, seq=5, skip=None, debug=False, validation_interval=3):
    """Train the recurrent model with truncated backpropagation through time.

    The losses of ``seq`` consecutive processed batches are summed and
    backpropagated together. Gradients are clipped to norm 1 and one
    optimizer step is taken. The recurrent state is then detached, so the
    next window starts a fresh autograd graph. A trailing partial window at
    the end of an epoch is also applied rather than discarded. The recurrent
    state is re-initialised at the start of every epoch.

    Args:
        epochs: Number of passes over ``self.train_data``.
        print_interval: Print a progress line every this many epochs.
        k1, k2: Accepted for backward compatibility; not used.
        seq: Number of processed batches per truncation window (must be >= 1).
        skip: If truthy, only batches whose index is a multiple of ``skip``
            are used.
        debug: Forwarded to ``torch.set_anomaly_enabled``.
        validation_interval: On progress epochs, run ``self.test()`` only
            when the epoch is a multiple of this value.

    Raises:
        ValueError: If ``seq`` is smaller than 1.
    """
    if seq < 1:
        raise ValueError("seq must be >= 1, got %r" % (seq,))
    log = "Epoch: {:03d}, Train: {:.4f} ({:.4f}) Val: {:.4f} ({:.4f}) Acc: {:.4f}, Time: {:.1f}s"
    cur_time = prev_time = time.time()
    torch.set_anomaly_enabled(debug)

    def repackage_hidden(h):
        """Wraps hidden states in new Tensors, to detach them from their history."""
        if h is None:
            return None
        if isinstance(h, torch.Tensor):
            return h.detach()
        return tuple(repackage_hidden(v) for v in h)

    def apply_window(window_loss):
        # One optimizer update from the summed loss of the current window.
        self.optimizer.zero_grad()
        window_loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()

    for epoch in range(1, epochs + 1):
        self.model.train()
        net_state_hidden, net_cell_states = self.init_rnn_state()
        epoch_loss_collector = LossCollector()
        window_loss = 0
        window_len = 0  # processed (non-skipped) batches in the current window
        for i, data in enumerate(self.train_data):
            if skip and i % skip != 0:
                continue
            if data.x is None or torch.isnan(data.x).any():
                continue
            data = data.to(self.device)
            outputs, expected, net_state_hidden, net_cell_states = self.feed(
                data, net_state_hidden, net_cell_states
            )

            # Zero positions not flagged in current_transports (mask assumed boolean — TODO confirm).
            mask = data.current_transports[:, -self.pred_seq_len :]
            outputs[~mask] = 0
            expected[~mask] = 0

            l1_regularization = sum((p.abs().sum() for p in self.model.parameters()), 0.0)
            loss = self.loss(outputs, expected) + self.l1_reg * l1_regularization
            epoch_loss_collector.collect(outputs, expected)

            window_loss = window_loss + loss
            window_len += 1
            if window_len == seq:
                apply_window(window_loss)
                # Cut the gradient graph so the next window does not backprop into this one.
                net_state_hidden = repackage_hidden(net_state_hidden)
                net_cell_states = repackage_hidden(net_cell_states)
                window_loss, window_len = 0, 0

        # Apply the trailing partial window instead of silently dropping its loss.
        if window_len:
            apply_window(window_loss)

        loss = epoch_loss_collector.reduce()
        self.collect_train_metrics(loss)
        cur_time = time.time()
        if epoch % print_interval == 0:
            val_loss = dict(mse=-1, mae=-1, acc=-1)
            if epoch % validation_interval == 0:
                # Validate, this uses another hidden state for the model
                _, val_loss = self.test()
            print(
                log.format(
                    epoch,
                    loss["mse"],
                    loss["mae"],
                    val_loss["mse"],
                    val_loss["mae"],
                    val_loss["acc"],
                    cur_time - prev_time,
                )
            )
            prev_time = cur_time
    self.print_eval_summary()
def train(self, epochs=1, print_interval=1, val=True, debug=False, validation_interval=3):
    """Train the stateful recurrent model, carrying its state across batches.

    The recurrent state from ``self.init_rnn_state()`` is created once. It is
    threaded through ``self.feed`` for every batch and detached after each
    optimizer step, so ``backward()`` never reaches into an earlier batch's
    (already freed) graph. Entries not selected by
    ``data.current_transports`` are zeroed before the loss is computed. An L1
    penalty scaled by ``self.l1_reg`` is added.

    Args:
        epochs: Number of passes over ``self.train_data``.
        print_interval: Print a progress line every this many epochs.
        val: Whether to run ``self.test()`` on progress epochs.
        debug: Forwarded to ``torch.set_anomaly_enabled``.
        validation_interval: Validate only on progress epochs that are a
            multiple of this value.

    Raises:
        AssertionError: If a target contains NaN or the gradient norm is NaN.
    """
    log = "Epoch: {:03d}, Train: {:.4f} ({:.4f}) Val: {:.4f} ({:.4f}) Acc: {:.4f}, Time: {:.1f}s"
    cur_time = prev_time = time.time()
    torch.set_anomaly_enabled(debug)

    def detach_state(state):
        # Detach tensors (recursively inside lists/tuples); leaves anything else untouched.
        if isinstance(state, torch.Tensor):
            return state.detach()
        if isinstance(state, list):
            return [detach_state(s) for s in state]
        if isinstance(state, tuple):
            return tuple(detach_state(s) for s in state)
        return state

    net_state_hidden, net_cell_states = self.init_rnn_state()
    for epoch in range(1, epochs + 1):
        self.model.train()
        epoch_loss_collector = LossCollector()
        for data in self.train_data:
            if data.x is None or torch.isnan(data.x).any():
                continue
            data = data.to(self.device)
            self.optimizer.zero_grad()
            outputs, expected, net_state_hidden, net_cell_states = self.feed(
                data, net_state_hidden, net_cell_states
            )

            # Zero positions not flagged in current_transports (mask assumed boolean — TODO confirm).
            mask = data.current_transports[:, -self.pred_seq_len :]
            outputs[~mask] = 0
            expected[~mask] = 0

            l1_regularization = sum((p.abs().sum() for p in self.model.parameters()), 0.0)
            assert not torch.isnan(expected).any()
            loss = self.loss(outputs, expected) + self.l1_reg * l1_regularization
            loss.backward()
            epoch_loss_collector.collect(outputs, expected)

            # NaN guard on the global L2 gradient norm (NaN is the only value
            # with x != x). Parameters that received no gradient are skipped.
            total_norm = sum(
                p.grad.norm(2).item() ** 2
                for p in self.model.parameters()
                if p.grad is not None
            ) ** 0.5
            assert total_norm == total_norm

            self.optimizer.step()
            # Cut the recurrent state from this batch's graph before the next batch.
            net_state_hidden = detach_state(net_state_hidden)
            net_cell_states = detach_state(net_cell_states)

        # The scheduler is stepped once per epoch, after the epoch's optimizer steps.
        if self.lr_scheduler:
            self.lr_scheduler.step()

        loss = epoch_loss_collector.reduce()
        self.collect_train_metrics(loss)
        cur_time = time.time()
        if epoch % print_interval == 0:
            val_loss = dict(mse=-1, mae=-1, acc=-1)
            if val and epoch % validation_interval == 0:
                _, val_loss = self.test()
            print(
                log.format(
                    epoch,
                    loss["mse"],
                    loss["mae"],
                    val_loss["mse"],
                    val_loss["mae"],
                    val_loss["acc"],
                    cur_time - prev_time,
                )
            )
            prev_time = cur_time
    self.print_eval_summary()