# N is batch size; D_in is input dimension; H is hidden dimension; D_out is output dimension
N, D_in, H, D_out = 64, 1000, h, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = TwoLayerNet(D_in, H, D_out)
x, y, model = x.to(device), y.to(device), model.to(device)

# register the layers whose saturation should be tracked
layers = [model.linear1, model.linear2]
stats = CheckLayerSat('regression/h{}'.format(h), layers)

loss_fn = torch.nn.MSELoss(reduction='sum')  # 'size_average=False' is deprecated
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

steps_iter = trange(2000, desc='steps', leave=True, position=0)
steps_iter.write("{:^80}".format(
    "Regression - TwoLayerNet - Hidden layer size {}".format(h)))
for _ in steps_iter:
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    steps_iter.set_description('loss=%g' % loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    stats.saturation()
steps_iter.write('\n')
stats.close()
steps_iter.close()
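# The snippet above assumes context defined elsewhere: the `TwoLayerNet` module, the
# `device`, and the hidden size `h` (typically supplied by an enclosing sweep over
# several hidden-layer widths). A minimal sketch of that assumed context follows; the
# class body and the sweep values are illustrative, not taken from the original.
import torch
from torch import nn
from tqdm import trange
from delve import CheckLayerSat

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class TwoLayerNet(nn.Module):
    """Two fully connected layers with a ReLU in between (illustrative)."""

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = nn.Linear(D_in, H)
        self.linear2 = nn.Linear(H, D_out)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))


# The regression snippet above would then run once per hidden size, e.g.:
# for h in [3, 32, 128]:
#     ...  (snippet above, with H = h)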
def train(network, dataset, test_set, logging_dir, batch_size):
    network.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(network.parameters())
    # stats = CheckLayerSat(logging_dir, network, log_interval=len(dataset) // batch_size)
    stats = CheckLayerSat(logging_dir,
                          network,
                          log_interval=60,
                          sat_method='cumvar99',
                          conv_method='mean')

    epoch_acc = 0
    thresh = 0.95
    epoch = 0
    total = 0
    correct = 0
    value_dict = None

    while epoch <= 20:
        print('Start Training Epoch', epoch, '\n')
        start = t.time()
        epoch_acc = 0
        train_loss = 0
        total = 0
        correct = 0
        network.train()
        for i, data in enumerate(dataset):
            step = epoch * len(dataset) + i
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = network(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # if i % 2000 == 1999:  # print every 2000 mini-batches
            print(i, 'of', len(dataset), 'acc:', correct / total)

        # display layer saturation levels
        end = t.time()
        stats.saturation()
        test_loss, test_acc = test(network, test_set, criterion, stats, epoch)
        epoch_acc = correct / total
        print('Epoch', epoch, 'finished', 'Acc:', epoch_acc,
              'Loss:', train_loss / total, '\n')
        stats.add_scalar('train_loss', train_loss / total, epoch)  # optional
        stats.add_scalar('train_acc', epoch_acc, epoch)  # optional
        value_dict = record_metrics(value_dict, stats.logs, epoch_acc,
                                    train_loss / total, test_acc, test_loss,
                                    epoch, (end - start) / total)
        log_to_csv(value_dict, logging_dir)
        epoch += 1

    stats.close()
    # test_stats.close()
    return criterion
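# Hedged usage sketch for train() above, wiring it to CIFAR-10 loaders. The helpers the
# function relies on (`test`, `record_metrics`, `log_to_csv`) and the `device`/`t` globals
# are assumed to be defined in the surrounding module; the model below is illustrative.
import torch
import time as t
import torchvision
from torch import nn, optim
from torchvision import transforms
from delve import CheckLayerSat

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if __name__ == '__main__':
    batch_size = 128
    transform = transforms.ToTensor()
    cifar_train = torchvision.datasets.CIFAR10(root='./data', train=True,
                                               download=True, transform=transform)
    cifar_test = torchvision.datasets.CIFAR10(root='./data', train=False,
                                              download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(cifar_train, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(cifar_test, batch_size=batch_size, shuffle=False)

    # any nn.Module works here; a small fully connected net keeps the sketch short
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(3 * 32 * 32, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    )

    train(net, train_loader, test_loader,
          logging_dir='./logs/example_run', batch_size=batch_size)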
class Trainer:
    def __init__(
        self,
        model: nn.Module,
        logger: Logger,
        prefix: str = "",
        checkpoint_dir: Union[str, None] = None,
        summary_dir: Union[str, None] = None,
        n_summaries: int = 4,
        input_shape: tuple = None,
        start_scratch: bool = False,
        # model_name: str = "model",
    ):
        """
        Class which implements network training, validation and testing as well as
        writing checkpoints, logs, summaries, and saving the final model.

        :param Union[str, None] checkpoint_dir: the type is either str or None (default: None)
        :param int n_summaries: number of images to sample at the different phases and
            visualize on tensorboard
        """
        # self.model_name = model_name
        self.model = model
        self.logger = logger
        self.prefix = prefix

        self.logger.info("Init summary writer")

        if summary_dir is not None:
            run_name = prefix + "_" if prefix != "" else ""
            run_name += "{time}-{host}".format(
                time=time.strftime("%y-%m-%d-%H-%M", time.localtime()),
                host=os.uname()[1],
            )
            self.summary_dir = os.path.join(summary_dir, run_name)

        self.n_summaries = n_summaries
        self.writer = SummaryWriter(summary_dir)

        if input_shape is not None:
            dummy_input = torch.rand(input_shape)
            self.logger.info("Writing graph to summary")
            self.writer.add_graph(self.model, dummy_input)

        if checkpoint_dir is not None:
            self.cp = CheckpointHandler(checkpoint_dir,
                                        prefix=prefix,
                                        logger=self.logger)
        else:
            self.cp = None

        self.start_scratch = start_scratch

    def fit(
        self,
        train_dataloader,
        val_dataloader,
        train_ds,
        val_ds,
        loss_fn,
        optimizer,
        n_epochs,
        val_interval,
        patience_early_stopping,
        device,
        metrics: Union[list, dict] = [],
        val_metric: Union[int, str] = "loss",
        val_metric_mode: str = "min",
        start_epoch=0,
    ):
        """
        Train and validate the network.

        :param int n_epochs: max_train_epochs (default=500)
        :param int val_interval: run validation every val_interval epochs
            (ARGS.patience_early_stopping)
        :param int patience_early_stopping: after (patience_early_stopping / val_interval)
            validations without improvement, terminate training
        """
        self.logger.info("Init model on device '{}'".format(device))
        self.model = self.model.to(device)

        # initialize delve
        self.tracker = CheckLayerSat(self.summary_dir,
                                     save_to="plotcsv",
                                     modules=self.model,
                                     device=device)

        best_model = copy.deepcopy(self.model.state_dict())
        best_metric = 0.0 if val_metric_mode == "max" else float("inf")

        # as we don't validate after each epoch but every val_interval epochs,
        # convert the patience into a number of validations
        patience_stopping = math.ceil(patience_early_stopping / val_interval)
        patience_stopping = int(max(1, patience_stopping))
        early_stopping = EarlyStoppingCriterion(mode=val_metric_mode,
                                                patience=patience_stopping)

        if not self.start_scratch and self.cp is not None:
            checkpoint = self.cp.read_latest()
            if checkpoint is not None:
                try:
                    try:
                        self.model.load_state_dict(checkpoint["modelState"])
                    except RuntimeError as e:
                        self.logger.error(
                            "Failed to restore checkpoint: "
                            "Checkpoint has different parameters")
                        self.logger.error(e)
                        raise SystemExit

                    optimizer.load_state_dict(
                        checkpoint["trainState"]["optState"])
                    start_epoch = checkpoint["trainState"]["epoch"] + 1
                    best_metric = checkpoint["trainState"]["best_metric"]
                    best_model = checkpoint["trainState"]["best_model"]
                    early_stopping.load_state_dict(
                        checkpoint["trainState"]["earlyStopping"])
                    # scheduler.load_state_dict(checkpoint["trainState"]["scheduler"])
                    self.logger.info(
                        "Resuming with epoch {}".format(start_epoch))
                except KeyError:
                    self.logger.error("Failed to restore checkpoint")
                    raise

        since = time.time()

        self.logger.info("Start training model " + self.prefix)

        try:
            if val_metric_mode == "min":
                val_comp = operator.lt  # use the comparison operator as a function
            else:
                val_comp = operator.gt
            for epoch in range(start_epoch, n_epochs):
                self.train(epoch, train_dataloader, train_ds, loss_fn,
                           optimizer, device)
                if epoch % val_interval == 0 or epoch == n_epochs - 1:
                    # first, get val_loss for further comparison
                    val_loss = self.validate(epoch,
                                             val_dataloader,
                                             val_ds,
                                             loss_fn,
                                             device,
                                             phase="val")
                    if val_metric == "loss":
                        val_result = val_loss
                        # add metrics for delve to keep track of
                        self.tracker.add_scalar("loss", val_loss)
                        # add saturation to the mix
                        self.tracker.add_saturations()
                    else:
                        val_result = metrics[val_metric].get()

                    # compare to see if an improvement occurred
                    if val_comp(val_result, best_metric):
                        # update best_metric with the improved value
                        best_metric = val_result
                        best_model = copy.deepcopy(self.model.state_dict())

                    # Previously a deadlock occurred that seemed to be related to cp;
                    # comment out self.cp.write() to check whether the freezing goes away.
                    # write checkpoint
                    self.cp.write({
                        "modelState": self.model.state_dict(),
                        "trainState": {
                            "epoch": epoch,
                            "best_metric": best_metric,
                            "best_model": best_model,
                            "optState": optimizer.state_dict(),
                            "earlyStopping": early_stopping.state_dict(),
                        },
                    })

                    # stop if the number of accumulated no-improvement epochs exceeds the patience
                    if early_stopping.step(val_result):
                        self.logger.info(
                            "No improvement over the last {} epochs. Training is stopped."
                            .format(patience_early_stopping))
                        break
        except Exception:
            import traceback
            self.logger.warning(traceback.format_exc())
            self.logger.warning("Aborting...")
            self.logger.close()
            raise SystemExit

        # option here: load the best model, run it on a test_dataset and log the final
        # metric (alongside the best metric); for the autoencoder there is only a
        # train/validate split, without a test_dataset

        time_elapsed = time.time() - since
        self.logger.info("Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60))

        self.logger.info("Best val metric: {:4f}".format(best_metric))

        # close delve tracker
        self.tracker.close()

        return self.model

    def train(self, epoch, train_dataloader, train_ds, loss_fn, optimizer,
              device):
        """
        Train one epoch on the training data with the given loss function,
        optimizer, and respective metrics.
        """
        self.logger.debug("train|{}|start".format(epoch))

        self.model.train()

        epoch_start = time.time()
        start_data_loading = epoch_start
        data_loading_time = m.Sum(torch.device("cpu"))

        train_running_loss = 0.0
        for i, (train_specs, label) in enumerate(train_dataloader):
            train_specs = train_specs.to(device)
            call_label = None
            if "call" in label:
                call_label = label["call"].to(
                    device, non_blocking=True, dtype=torch.int64
                )  # e.g. tensor([True, True, True, True, True, True])
            if "ground_truth" in label:
                ground_truth = label["ground_truth"].to(device,
                                                        non_blocking=True)

            data_loading_time.update(
                torch.Tensor([(time.time() - start_data_loading)]))
            optimizer.zero_grad()

            # compute reconstructions
            outputs = self.model(train_specs)

            # compute training reconstruction loss, when augmentation is used
            # loss = loss_fn(outputs, ground_truth)

            # compute training reconstruction loss, when no augmentation is used
            loss = loss_fn(outputs, train_specs)

            # compute accumulated gradients
            loss.backward()

            # perform parameter update based on current gradients
            optimizer.step()

            # add the mini-batch training loss to the epoch loss:
            # loss.item() is averaged across the batch, so
            # loss.item() * train_specs.size(0) is the total (unaveraged) loss of the batch
            train_running_loss += loss.item() * train_specs.size(0)

            prediction = None
            # print("label is ", label, "call_label is ", call_label)

            if i % 2 == 0:
                self.write_summaries(
                    features=train_specs,
                    # labels=call_label,
                    # prediction=prediction,
                    reconstructed=outputs,
                    file_names=label["file_name"],
                    epoch=epoch,
                    phase="train",
                )
            start_data_loading = time.time()

        # compute the epoch training loss
        train_epoch_loss = train_running_loss / len(train_ds)

        self.write_scalar_summaries_logs(
            loss=train_epoch_loss,
            # metrics=metrics,
            lr=optimizer.param_groups[0]["lr"],
            epoch_time=time.time() - epoch_start,
            data_loading_time=data_loading_time.get(),
            epoch=epoch,
            phase="train",
        )

        self.writer.flush()

        return train_epoch_loss

    def validate(self, epoch, val_dataloader, val_ds, loss_fn, device,
                 phase="val"):
        self.logger.debug("{}|{}|start".format(phase, epoch))
        self.model.eval()

        val_running_loss = 0.0
        with torch.no_grad():
            epoch_start = time.time()
            start_data_loading = epoch_start
            data_loading_time = m.Sum(torch.device("cpu"))

            for i, (val_specs, label) in enumerate(val_dataloader):
                val_specs = val_specs.to(device)
                if "call" in label:
                    call_label = label["call"].to(device,
                                                  non_blocking=True,
                                                  dtype=torch.int64)  # bool

                data_loading_time.update(
                    torch.Tensor([(time.time() - start_data_loading)]))

                # instead of converting the spectrogram to a color image, save the
                # 1-channel outputs directly produced by the network
                if i % 2 == 0:
                    # grid = make_grid(val_specs)
                    self.writer.add_images("Original", val_specs, epoch)

                outputs = self.model(val_specs)

                if i % 2 == 0:
                    # grid = make_grid(outputs)
                    self.writer.add_images("Reconstructed", outputs, epoch)

                loss = loss_fn(outputs, val_specs)
                val_running_loss += loss.item() * val_specs.size(0)

                prediction = None

                if i % 2 == 0:
                    self.write_summaries(
                        features=val_specs,  # original
                        # labels=call_label,
                        # prediction=prediction,
                        reconstructed=outputs,
                        file_names=label["file_name"],
                        epoch=epoch,
                        phase=phase,
                    )
                start_data_loading = time.time()

            val_epoch_loss = val_running_loss / len(val_ds)

            self.write_scalar_summaries_logs(
                loss=val_epoch_loss,
                # metrics=metrics,
                epoch_time=time.time() - epoch_start,
                data_loading_time=data_loading_time.get(),
                epoch=epoch,
                phase=phase,
            )

            self.writer.flush()

        return val_epoch_loss

    def write_summaries(
        self,
        features,
        # labels=None,  # tensor([True, True, True, True, True, True])
        # prediction=None,
        reconstructed=None,
        file_names=None,
        epoch=None,
        phase="train",
    ):
        """Writes image summary per partition (spectrograms and their reconstructions)."""
        with torch.no_grad():
            self.write_img_summaries(
                features,
                # labels=labels,
                # prediction=prediction,
                reconstructed=reconstructed,
                file_names=file_names,
                epoch=epoch + 1,
                phase=phase,
            )

    def write_img_summaries(
        self,
        features,
        # labels=None,
        # prediction=None,
        reconstructed=None,
        file_names=None,
        epoch=None,
        phase="train",
    ):
        """
        Writes image summary per partition with respect to the prediction output
        (true predictions - true positive/negative, false predictions -
        false positive/negative).
        """
        with torch.no_grad():
            if file_names is not None:
                if isinstance(file_names, torch.Tensor):
                    file_names = file_names.cpu().numpy()
                elif isinstance(file_names, list):
                    file_names = np.asarray(file_names)

            # if labels is not None and prediction is not None:
            if reconstructed is not None:
                features = features.cpu()
                # labels = labels.cpu()
                # prediction = prediction.cpu()
                reconstructed = reconstructed.cpu()

                self.writer.add_images(
                    tag=phase + "/input",
                    img_tensor=features[:self.n_summaries],
                    # img_tensor=prepare_img(
                    #     features, num_images=self.n_summaries, file_names=file_names
                    # ),
                    global_step=epoch,
                )
                self.writer.add_images(
                    tag=phase + "/reconstructed",
                    img_tensor=reconstructed[:self.n_summaries],
                    # img_tensor=prepare_img(
                    #     features, num_images=self.n_summaries, file_names=file_names
                    # ),
                    global_step=epoch,
                )

                # the disabled block below is only needed to visualize true/false
                # positive and negative examples
                """for label in torch.unique(labels):  # tensor(1, device='cuda:0')
                    label = label.item()  # value of this tensor as a standard Python number: 1
                    l_i = torch.eq(labels, label)

                    t_i = torch.eq(prediction, label) * l_i
                    name_t = "true_{}".format("positive" if label else "negative")
                    try:
                        self.writer.add_image(
                            tag=phase + "/" + name_t,
                            img_tensor=prepare_img(
                                features[t_i],
                                num_images=self.n_summaries,
                                file_names=file_names[t_i.numpy() == 1],
                            ),
                            global_step=epoch,
                        )
                    except ValueError:
                        pass

                    f_i = torch.ne(prediction, label) * l_i
                    name_f = "false_{}".format("negative" if label else "positive")
                    try:
                        self.writer.add_image(
                            tag=phase + "/" + name_f,
                            img_tensor=prepare_img(
                                features[f_i],
                                num_images=self.n_summaries,
                                file_names=file_names[f_i.numpy() == 1],
                            ),
                            global_step=epoch,
                        )
                    except ValueError:
                        pass
                else:
                    self.writer.add_image(
                        tag=phase + "/input",
                        img_tensor=prepare_img(
                            features,
                            num_images=self.n_summaries,
                            file_names=file_names,
                        ),
                        global_step=epoch,
                    )"""

    def write_scalar_summaries_logs(
        self,
        loss: float,
        metrics: Union[list, dict] = [],
        lr: float = None,
        epoch_time: float = None,
        data_loading_time: float = None,
        epoch=None,
        phase="train",
    ):
        """
        Writes scalar summaries per partition, including loss, confusion matrix,
        accuracy, recall, f1-score, true positive rate, false positive rate,
        precision, data_loading_time, and epoch time.
        """
        with torch.no_grad():
            log_str = phase
            if epoch is not None:
                log_str += "|{}".format(epoch)
            self.writer.add_scalar(phase + "/epoch_loss", loss, epoch)
            log_str += "|loss:{:0.3f}".format(loss)

            if isinstance(metrics, dict):
                for name, metric in metrics.items():
                    self.writer.add_scalar(phase + "/" + name, metric.get(),
                                           epoch)
                    log_str += "|{}:{:0.3f}".format(name, metric.get())
            else:
                for i, metric in enumerate(metrics):
                    self.writer.add_scalar(phase + "/metric_" + str(i),
                                           metric.get(), epoch)
                    log_str += "|m_{}:{:0.3f}".format(i, metric.get())

            if lr is not None:
                self.writer.add_scalar("lr", lr, epoch)
                log_str += "|lr:{:0.2e}".format(lr)

            if epoch_time is not None:
                self.writer.add_scalar(phase + "/time", epoch_time, epoch)
                log_str += "|t:{:0.1f}".format(epoch_time)

            if data_loading_time is not None:
                self.writer.add_scalar(phase + "/data_loading_time",
                                       data_loading_time, epoch)

            self.logger.info(log_str)
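# Hedged usage sketch of the Trainer above. The Logger, CheckpointHandler, dataloaders,
# datasets, and the autoencoder model are assumed to come from the surrounding project;
# every concrete name and value below is illustrative only.
autoencoder = build_autoencoder()          # assumed nn.Module factory
logger = Logger("ae_training")             # assumed project Logger

trainer = Trainer(
    model=autoencoder,
    logger=logger,
    prefix="ae",
    checkpoint_dir="./checkpoints",
    summary_dir="./summaries",
    n_summaries=4,
    input_shape=(1, 1, 128, 256),          # illustrative spectrogram batch shape
    start_scratch=False,
)

best_model = trainer.fit(
    train_dataloader, val_dataloader, train_ds, val_ds,
    loss_fn=nn.MSELoss(),
    optimizer=torch.optim.Adam(autoencoder.parameters(), lr=1e-4),
    n_epochs=500,
    val_interval=2,
    patience_early_stopping=20,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    val_metric="loss",
    val_metric_mode="min",
)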
class Trainer:
    """The trainer object handles the actual training and testing of the model.

    Args:
        model:            the PyTorch model
        data_bundle:      the training and test data as a DataBundle
        optimizer_bundle: contains the optimizer and the learning rate scheduler
        run_id:           a string identifying the specific run
        batch_size:       the batch size to train the model with
        epochs:           the (total) number of epochs to train the model
        criterion:        the optimization criterion (loss); default is cross-entropy
        metrics:          a list of metric objects
        device:           the compute device or list of compute devices to place the model(s) on
        logs_dir:         the directory to store the results in
        conv_method:      the strategy for handling convolutional layers in the saturation computation
        device_sat:       the device to compute the saturation on; if None, the model's device is used
        delta:            the delta threshold for computing saturation
        data_parallel:    enable or disable multi-GPU training
        downsampling:     if None, downsampling is disabled, else the feature maps are
                          downsampled to (downsampling x downsampling) resolution
    """

    # private internal variables
    _tracker: CheckLayerSat = attrib(init=False)
    _save_path: str = attrib(init=False)
    _initial_epoch: int = attrib(init=False)
    _trained_epochs: int = attrib(init=False)
    _experiment_done: bool = attrib(init=False)

    # general training setup
    model: Module
    data_bundle: DataBundle
    optimizer_bundle: OptimizerSchedulerBundle
    run_id: str
    batch_size: int = 32
    epochs: int = 30
    criterion: nn.modules.loss._Loss = nn.modules.CrossEntropyLoss()
    metrics: List[Metric] = attrib(factory=list)

    # technical setup
    device: str = 'cpu'
    logs_dir: str = './logs'

    # delve setup
    conv_method = 'channelwise'
    device_sat: Optional[str] = None
    delta: float = 0.99
    data_parallel: bool = False
    downsampling: Optional[int] = None

    def _initialize_tracker(self):
        writer = CSVandPlottingWriter(self._save_path.replace('.csv', ''),
                                      primary_metric='test_accuracy')

        self._tracker = CheckLayerSat(
            self._save_path.replace('.csv', ''), [writer],
            self.model,
            ignore_layer_names='convolution',
            stats=['lsat', 'idim'],
            sat_threshold=self.delta,
            verbose=False,
            conv_method=self.conv_method,
            log_interval=1,
            device=self.device_sat,
            reset_covariance=True,
            max_samples=None,
            initial_epoch=self._initial_epoch,
            interpolation_strategy='nearest'
            if self.downsampling is not None else None,
            interpolation_downsampling=self.downsampling)

    def _initialize_saving_structure(self):
        save_dir: str = build_saving_structure(
            logs_dir=self.logs_dir,
            model_name=self.model.name,
            dataset_name=self.data_bundle.dataset_name,
            output_resolution=self.data_bundle.output_resolution,
            run_id=self.run_id)
        self._save_path = os.path.join(
            save_dir,
            f"{self.model.name}-{self.data_bundle.dataset_name}-r{self.data_bundle.output_resolution}-bs{self.batch_size}-e{self.epochs}.csv"
        )

    def _load_model(self):
        self.model.load_state_dict(
            torch.load(self._save_path.replace('.csv', '.pt'),
                       map_location="cpu")['model_state_dict'])
        self.model = self.model.to(self.device)

    def _load_optimizer_and_scheduler(self):
        self.optimizer_bundle.optimizer.load_state_dict(
            torch.load(self._save_path.replace('.csv', '.pt'))['optimizer'])
        if self.optimizer_bundle.scheduler is not None:
            self.optimizer_bundle.scheduler.load_state_dict(
                torch.load(self._save_path.replace('.csv', '.pt'))['scheduler'])

    def _load_initial_and_trained_epoch(self):
        self._trained_epochs = torch.load(
            self._save_path.replace('.csv', '.pt'))['epoch']
        self._initial_epoch = self._trained_epochs + 1

    def _check_training_done(self):
        if self._initial_epoch >= self.epochs:
            self._experiment_done = True
            print(
                f'Logs for an identical experiment with the same run_id were detected; '
                f'training will be skipped. Consider using another run_id.')

    def _checkpointing(self):
        self._initial_epoch = 0
        self._trained_epochs = 0
        self._experiment_done = False
        if self.data_parallel:
            print("Enabling multi-GPU training")
            self.model = nn.DataParallel(self.model,
                                         device_ids=["cuda:0", "cuda:1"],
                                         output_device=self.device)
        self.model = self.model.to(self.device)
        if os.path.exists(self._save_path):
            self._load_initial_and_trained_epoch()
            self._check_training_done()
            self._load_model()
            self._load_optimizer_and_scheduler()
            print('Resuming existing run, starting at epoch',
                  self._initial_epoch + 1, 'from',
                  self._save_path.replace('.csv', '.pt'))

    def _enable_benchmark_mode_if_cuda(self):
        if "cuda" in self.device:
            from torch.backends import cudnn
            cudnn.benchmark = True

    def __attrs_post_init__(self):
        self.device_sat = self.device if self.device_sat is None else self.device_sat
        self._enable_benchmark_mode_if_cuda()
        self._initialize_saving_structure()
        self._checkpointing()
        self._initialize_tracker()

    def _reset_metrics(self):
        for metric in self.metrics:
            metric.reset()

    def _eval_metrics(self, y_true: torch.Tensor, y_pred: torch.Tensor):
        for metric in self.metrics:
            metric.update(y_true, y_pred)

    def _update_pbar_postfix(self, pbar: tqdm):
        metrics = {
            metric.name: round(metric.value, 3)
            for metric in self.metrics
        }
        pbar.set_postfix(metrics)

    def _print_status(self, batch: int, old_time: int, dataset: DataLoader):
        metrics = [
            f"{metric.name}: {round(metric.value, 3)}"
            for metric in self.metrics
        ]
        print(batch, 'of', len(dataset), 'processing time',
              round(time() - old_time, 3), *metrics)

    def _print_epoch_status(self, epoch: int, old_time: int,
                            metric_dict: Dict[str, float]):
        metrics = [f"{k}: {round(v, 3)}" for (k, v) in metric_dict.items()]
        print(epoch + 1, 'of', self.epochs, 'processing time',
              round(time() - old_time, 3), *metrics)

    def _track_results(self, prefix: str, metric_name: str,
                       metric_value: float) -> Tuple[str, float]:
        self._tracker.add_scalar(f"{prefix}_{metric_name}", metric_value)
        return f"{prefix}_{metric_name}", metric_value

    def _track_metrics(self, prefix: str, loss: float,
                       total: int) -> Dict[str, float]:
        result: Dict[str, float] = dict()
        for metric in self.metrics:
            name, val = self._track_results(prefix, metric.name, metric.value)
            result[name] = val
        name, val = self._track_results(prefix, "loss", loss / total)
        result[name] = val
        return result

    def _save_checkpoint(self, train_metric: Dict[str, float],
                         test_metric: Dict[str, float], epoch: int):
        state_dict = {
            'model_state_dict': self.model.state_dict(),
            'optimizer': self.optimizer_bundle.optimizer.state_dict(),
            'scheduler': None if self.optimizer_bundle.scheduler is None
                         else self.optimizer_bundle.scheduler.state_dict(),
            'epoch': epoch
        }
        state_dict.update(train_metric)
        state_dict.update(test_metric)
        torch.save(state_dict, self._save_path.replace('.csv', '.pt'))

    def train(self):
        """Train the model.

        The model is trained for the total number of epochs given in the
        constructor, including epochs this model was trained for in previous runs.

        Returns:
            The path to the saturation and metric logs.
        """
        if self._experiment_done:
            return
        old_time = time()
        for epoch in range(self._initial_epoch, self.epochs):
            print('Start training epoch', epoch + 1)
            train_metric = self.train_epoch()
            test_metric = self.test()
            train_metric.update(test_metric)
            self._print_epoch_status(epoch=epoch,
                                     old_time=old_time,
                                     metric_dict=train_metric)
            old_time = time()

            if self.optimizer_bundle.scheduler is not None:
                self.optimizer_bundle.scheduler.step()
            self._tracker.add_saturations()
            self._save_checkpoint(train_metric=train_metric,
                                  test_metric=test_metric,
                                  epoch=epoch)
        self._tracker.close()
        return self._save_path + '.csv'

    def train_epoch(self) -> Dict[str, float]:
        """Train a single epoch.

        Returns:
            A dictionary containing all metrics computed incrementally during training.
        """
        self.model.train()
        self._reset_metrics()
        running_loss = 0
        total = 0
        old_time = time()
        pbar = tqdm(self.data_bundle.train_dataset)
        for batch, data in enumerate(pbar):
            if batch % 10 == 0 and batch != 0:
                self._update_pbar_postfix(pbar)

            inputs, labels = data
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            self.optimizer_bundle.optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast():
                outputs = self.model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                self._eval_metrics(labels, outputs)
                loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer_bundle.optimizer.step()

            running_loss += loss.item()
            total += self.batch_size
        return self._track_metrics('training', running_loss, total)

    def test(self):
        """Evaluate the model on the test set.

        Returns:
            The metrics computed on the test set.
        """
        self._reset_metrics()
        self.model.eval()
        total = 0
        test_loss = 0
        with torch.no_grad():
            old_time = time()
            pbar = tqdm(self.data_bundle.test_dataset)
            for batch, data in enumerate(pbar):
                inputs, labels = data
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                test_loss += loss.item()
                self._eval_metrics(labels, outputs)

                if batch % 10 == 0 or batch == (
                        len(self.data_bundle.test_dataset) - 1):
                    # self._print_status(batch, old_time, self.data_bundle.test_dataset)
                    self._update_pbar_postfix(pbar)
                    old_time = time()

        test_metrics = self._track_metrics('test', test_loss, total)
        return test_metrics