# Module-level imports assumed by the methods below (shown for completeness).
# Project helpers such as load_checkpoint, save_checkpoint, savebest_checkpoint,
# save_dict_to_json, plot_all_epoch and plot_xfit are provided elsewhere in the
# repository, and a module logger is assumed to be configured.
import logging
import os

import numpy as np
import torch
from sklearn.metrics import mean_squared_error
from tqdm import trange

logger = logging.getLogger(__name__)


def xfit(self, train_loader, test_loader, restore_file=None):
    # Optionally resume training from a saved checkpoint.
    if restore_file is not None:
        restore_path = os.path.join(self.params.model_dir,
                                    restore_file + '.pth.tar')
        logger.info('Restoring parameters from {}'.format(restore_path))
        load_checkpoint(restore_path, self, self.optimizer)

    logger.info('begin training and evaluation')
    best_test_ND = float('inf')
    train_len = len(train_loader)
    ND_summary = np.zeros(self.params.num_epochs)
    loss_summary = np.zeros((train_len * self.params.num_epochs))

    for epoch in range(self.params.num_epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, self.params.num_epochs))
        # Record the per-batch losses of this epoch.
        loss_summary[epoch * train_len:(epoch + 1) * train_len] = \
            self.fit_epoch(train_loader, test_loader, epoch)
        test_metrics = self.evaluate(test_loader, epoch,
                                     sample=self.params.sampling)
        ND_summary[epoch] = test_metrics['ND']
        is_best = ND_summary[epoch] <= best_test_ND

        # Save weights; keep a separate copy for the best model so far.
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': self.state_dict(),
                'optim_dict': self.optimizer.state_dict()
            },
            epoch=epoch,
            is_best=is_best,
            checkpoint=self.params.model_dir)

        if is_best:
            logger.info('- Found new best ND')
            best_test_ND = ND_summary[epoch]
            best_json_path = os.path.join(
                self.params.model_dir, 'metrics_test_best_weights.json')
            save_dict_to_json(test_metrics, best_json_path)

        logger.info('Current best ND is: %.5f' % best_test_ND)

        plot_all_epoch(ND_summary[:epoch + 1],
                       self.params.dataset + '_ND', self.params.plot_dir)
        plot_all_epoch(loss_summary[:(epoch + 1) * train_len],
                       self.params.dataset + '_loss', self.params.plot_dir)

        last_json_path = os.path.join(self.params.model_dir,
                                      'metrics_test_last_weights.json')
        save_dict_to_json(test_metrics, last_json_path)
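
# Usage sketch (hypothetical): how this training entry point might be driven.
# `Model`, `params`, and the synthetic loaders below are illustrative
# placeholders, not this repository's actual API.
#
#     import torch
#     from torch.utils.data import DataLoader, TensorDataset
#
#     train_set = TensorDataset(torch.randn(256, 24, 1), torch.randn(256, 12))
#     test_set = TensorDataset(torch.randn(64, 24, 1), torch.randn(64, 12))
#     train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
#     test_loader = DataLoader(test_set, batch_size=32)
#
#     model = Model(params)                  # exposes xfit()/evaluate()
#     model.xfit(train_loader, test_loader)  # train from scratch
#     model.xfit(train_loader, test_loader, restore_file='epoch_9')  # resume;
#     # note that '.pth.tar' is appended to restore_file inside xfit()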
def predict(self, x, using_best=True):
    '''
    x: (numpy.ndarray) shape: [sample, full-len, dim]
    return: (numpy.ndarray) shape: [sample, prediction-len]
    '''
    # Optionally restore the best checkpoint before predicting.
    best_pth = os.path.join(self.params.model_dir, 'best.pth.tar')
    if os.path.exists(best_pth) and using_best:
        self.logger.info(
            'Restoring best parameters from {}'.format(best_pth))
        load_checkpoint(best_pth, self, self.optimizer)

    # Single forward pass over the whole input window.
    x = torch.tensor(x).to(torch.float32).to(self.params.device)
    output = self(x)
    pred = output.detach().cpu().numpy()
    return pred
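
# Usage sketch (hypothetical): predict() takes a numpy array shaped
# [sample, full-len, dim] and returns point forecasts in one forward pass.
#
#     import numpy as np
#
#     x = np.random.randn(8, 24, 1).astype(np.float32)  # 8 windows, 24 steps
#     pred = model.predict(x, using_best=True)          # [8, prediction-len]
#     pred_last = model.predict(x, using_best=False)    # keep current weights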
def point_predict(self, x):
    '''
    x: (torch.Tensor) shape: [sample, full-len, dim]
    return: (numpy.ndarray) shape: [sample, prediction-len]
    '''
    best_pth = os.path.join(self.params.model_dir, 'best.pth.tar')
    if os.path.exists(best_pth):
        logger.info('Restoring best parameters from {}'.format(best_pth))
        load_checkpoint(best_pth, self, self.optimizer)

    # Reshape to [full-len, sample, dim] so the network consumes one
    # time step at a time.
    test_batch = x.permute(1, 0, 2).to(torch.float32).to(self.params.device)
    batch_size = test_batch.shape[1]
    input_mu = torch.zeros(batch_size, self.params.predict_start,
                           device=self.params.device)  # scaled
    input_sigma = torch.zeros(batch_size, self.params.predict_start,
                              device=self.params.device)  # scaled
    hidden = self.init_hidden(batch_size)
    cell = self.init_cell(batch_size)
    prediction = torch.zeros(batch_size, self.params.predict_steps,
                             device=self.params.device)

    # Conditioning range: warm up the recurrent state on observed history.
    for t in range(self.params.predict_start):
        # If z_t is missing (encoded as 0), replace it with the mean
        # emitted at the previous time step.
        zero_index = (test_batch[t, :, 0] == 0)
        if t > 0 and torch.sum(zero_index) > 0:
            test_batch[t, zero_index, 0] = mu[zero_index]
        mu, sigma, hidden, cell = self(test_batch[t].unsqueeze(0), hidden,
                                       cell)
        input_mu[:, t] = mu
        input_sigma[:, t] = sigma

    # Prediction range: roll forward, taking the emitted mean as the
    # point forecast at each step.
    for t in range(self.params.predict_steps):
        mu_de, sigma_de, hidden, cell = self(
            test_batch[self.params.predict_start + t].unsqueeze(0), hidden,
            cell)
        prediction[:, t] = mu_de
    return prediction.cpu().detach().numpy()
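
# Usage sketch (hypothetical): point_predict() first conditions the recurrent
# state on the history (predict_start steps, imputing zeros with the previous
# mean), then rolls the decoder forward for predict_steps steps and returns
# the emitted means as the point forecast.
#
#     import torch
#
#     x = torch.randn(8, 36, 1)       # [sample, full-len, dim], where
#                                     # full-len = predict_start + predict_steps
#     y_hat = model.point_predict(x)  # numpy array, [8, predict_steps]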
def xfit(self, train_loader, val_loader, restore_file=None):
    # Optionally resume from a checkpoint when restoring is enabled.
    if restore_file is not None and os.path.exists(
            restore_file) and self.params.restore:
        self.logger.info(
            'Restoring parameters from {}'.format(restore_file))
        load_checkpoint(restore_file, self, self.optimizer)

    min_vmse = float('inf')
    train_len = len(train_loader)
    loss_summary = np.zeros((train_len * self.params.num_epochs))
    loss_avg = np.zeros((self.params.num_epochs))
    vloss_avg = np.zeros_like(loss_avg)

    for epoch in trange(self.params.num_epochs):
        self.logger.info('Epoch {}/{}'.format(epoch + 1,
                                              self.params.num_epochs))
        # Training pass.
        mse_train = 0
        loss_epoch = np.zeros(train_len)
        for i, (batch_x, batch_y) in enumerate(train_loader):
            batch_x = batch_x.to(torch.float32).to(self.params.device)
            batch_y = batch_y.to(torch.float32).to(self.params.device)
            self.optimizer.zero_grad()
            y_pred = self(batch_x)
            y_pred = y_pred.squeeze(1)
            loss = self.loss_fn(y_pred, batch_y)
            loss.backward()
            mse_train += loss.item()
            loss_epoch[i] = loss.item()
            self.optimizer.step()
        mse_train = mse_train / train_len
        loss_summary[epoch * train_len:(epoch + 1) * train_len] = loss_epoch
        loss_avg[epoch] = mse_train

        self.epoch_scheduler.step()

        # Validation pass (no gradients).
        with torch.no_grad():
            mse_val = 0
            preds = []
            true = []
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(torch.float32).to(self.params.device)
                batch_y = batch_y.to(torch.float32).to(self.params.device)
                output = self(batch_x)
                output = output.squeeze(1)
                preds.append(output.detach().cpu().numpy())
                true.append(batch_y.detach().cpu().numpy())
                mse_val += self.loss_fn(output, batch_y).item()
        mse_val = mse_val / len(val_loader)
        vloss_avg[epoch] = mse_val

        preds = np.concatenate(preds)
        true = np.concatenate(true)
        self.logger.info(
            'Current training loss: {:.4f} \t validating loss: {:.4f}'.format(
                mse_train, mse_val))

        # Checkpoint whenever validation MSE improves.
        vmse = mean_squared_error(true, preds)
        self.logger.info('Current vmse: {:.4f}'.format(vmse))
        if vmse < min_vmse:
            min_vmse = vmse
            self.logger.info('Found new best state')
            savebest_checkpoint(
                {
                    'epoch': epoch,
                    'cv': self.params.cv,
                    'state_dict': self.state_dict(),
                    'optim_dict': self.optimizer.state_dict()
                },
                checkpoint=self.params.model_dir)
            self.logger.info('Checkpoint saved to {}'.format(
                self.params.model_dir))
            self.logger.info('Best vmse: {:.4f}'.format(min_vmse))

    plot_all_epoch(loss_summary[:(epoch + 1) * train_len],
                   self.params.dataset + '_loss', self.params.plot_dir)
    plot_xfit(loss_avg, vloss_avg, self.params.dataset + '_loss',
              self.params.plot_dir)
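
# Usage sketch (hypothetical): this variant tracks validation MSE and keeps
# only the best checkpoint; self.loss_fn and self.epoch_scheduler are assumed
# to be set up in the model's constructor (e.g. nn.MSELoss and a torch LR
# scheduler).
#
#     model.xfit(train_loader, val_loader,
#                restore_file='checkpoints/last.pth.tar')
#     # Note: unlike the first xfit(), restore_file here is a full path and
#     # is only honoured when params.restore is true.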