def _do_prediction(self, inputs, pred_targets, pcounts, train=False,
                   batch_idx=0, loader=None):
    """
    Run the predictor on the given inputs and update the running prediction
    metrics in `pcounts`. When training, also take a predictor optimizer step.
    """
    class_predictions = correct_arr = None
    if self.predictor:
        pred_targets = pred_targets.flatten()
        predictor_dist, predictor_logits = self.predictor(inputs.detach())

        # This loss is without inference-time model interpolation
        pred_loss = self.predictor_loss(predictor_logits, pred_targets)  # Cross-entropy loss

        # This loss is for the interpolated model
        interp_loss = self._interpolated_loss(predictor_dist, pred_targets,
                                              loader=loader, train=train)

        _, class_predictions = torch.max(predictor_dist, 1)
        pcounts["total_samples"] += pred_targets.size(0)
        correct_arr = class_predictions == pred_targets
        pcounts["correct_samples"] += correct_arr.sum().item()
        pred_loss_ = pred_loss.item()
        pcounts["total_pred_loss"] += pred_loss_
        pcounts["total_interp_loss"] += interp_loss

        if train:
            # Predictor backward + optimize
            pred_loss.backward()
            self.pred_optimizer.step()

    if self.batch_log_interval and batch_idx % self.batch_log_interval == 0:
        print("Finished batch %d" % batch_idx)
        if self.predictor:
            batch_acc = correct_arr.float().mean() * 100
            batch_ppl = lang_util.perpl(pred_loss_ / pred_targets.size(0))
            print("Partial pred acc - "
                  "batch acc: %.3f%%, pred ppl: %.1f" % (batch_acc, batch_ppl))

    return (pcounts, class_predictions, correct_arr)
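# NOTE: `lang_util.perpl` (used above to report perplexity) is imported from a
# utility module that is not shown in this section. The function below is only
# an illustrative sketch of the assumed behavior: perplexity as the exponential
# of a mean cross-entropy loss in nats. It is not the original helper.
import math


def perpl(cross_entropy_loss):
    """Illustrative stand-in for lang_util.perpl (assumed behavior)."""
    # Cap the exponent so a divergent loss cannot overflow a Python float.
    return math.exp(min(cross_entropy_loss, 709.0))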
def train_epoch(self, epoch):
    """
    Do one epoch of training and testing.

    Returns:
        A dict that describes progress of this epoch. The dict includes the
        key 'stop'. If set to one, the network should be stopped early
        because training is not progressing well enough.
    """
    t1 = time.time()
    ret = {}
    self.model.train()  # Needed if using dropout
    if self.predictor:
        self.predictor.train()

    # Performance metrics
    total_loss = 0.0
    pcounts = {
        "total_samples": 0.0,
        "correct_samples": 0.0,
        "total_pred_loss": 0.0,
        "total_interp_loss": 0.0,
    }

    bsz = self.batch_size
    hidden = self.train_hidden_buffer[-1] if self.train_hidden_buffer else None
    if hidden is None:
        hidden = self._init_hidden(self.batch_size)

    for batch_idx, (inputs, targets, pred_targets,
                    _input_labels) in enumerate(self.train_loader):
        # Inputs are of shape (batch, input_size)
        if inputs.size(0) > bsz:
            # Crop to smaller first-epoch batch size
            inputs = inputs[:bsz]
            targets = targets[:bsz]
            pred_targets = pred_targets[:bsz]

        hidden = self._repackage_hidden(hidden)

        self.optimizer.zero_grad()
        if self.pred_optimizer:
            self.pred_optimizer.zero_grad()

        # Forward
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        pred_targets = pred_targets.to(self.device)

        output, hidden = self.model(inputs, hidden)
        x_b, pred_input = self._get_prediction_and_loss_inputs(hidden)
        self.train_hidden_buffer.append(hidden)

        # Loss
        loss_targets = (targets, x_b)
        loss = self._compute_loss(output, loss_targets)

        if loss is not None:
            total_loss += loss.item()
            if not self.model_learning_paused:
                self._backward_and_optimize(loss)

        # Keep only latest batch states around
        self.train_hidden_buffer = self.train_hidden_buffer[-1:]

        pcounts, class_predictions, correct_arr = self._do_prediction(
            pred_input,
            pred_targets,
            pcounts,
            train=True,
            batch_idx=batch_idx,
            loader=self.train_loader,
        )

        if epoch == 0 and batch_idx >= self.batches_in_first_epoch - 1:
            print("Breaking after %d batches in epoch %d"
                  % (self.batches_in_first_epoch, epoch))
            break

    ret["stop"] = 0
    self.model._post_train_epoch(epoch)  # Update k-winners duty cycles, etc.

    if self.eval_interval and (epoch - 1) % self.eval_interval == 0:
        # Evaluate every eval_interval epochs
        ret.update(self.eval_epoch(epoch))

    train_time = time.time() - t1
    self._post_epoch(epoch)

    num_batches = batch_idx + 1
    ret["train_loss"] = total_loss / num_batches
    if self.predictor:
        num_samples = num_batches * self.batch_size
        train_pred_loss = pcounts["total_pred_loss"] / num_samples
        train_interp_loss = pcounts["total_interp_loss"] / num_samples
        ret["train_interp_ppl"] = lang_util.perpl(train_interp_loss)
        ret["train_pred_ppl"] = lang_util.perpl(train_pred_loss)
        ret["train_pred_acc"] = (100 * pcounts["correct_samples"]
                                 / pcounts["total_samples"])
    ret["epoch_time_train"] = train_time
    ret["epoch_time"] = time.time() - t1
    ret["learning_rate"] = self.learning_rate
    print(epoch, print_epoch_values(ret))
    return ret
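# NOTE: `self._repackage_hidden` is defined elsewhere in this class. The
# standalone function below is only a sketch of the standard pattern the call
# sites imply (an assumption, not the original implementation): detach the
# recurrent state from the autograd graph so that backpropagation through time
# is truncated at batch boundaries. It handles a single tensor or nested
# tuples of tensors.
def repackage_hidden(h):
    if h is None:
        return None
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)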
def eval_epoch(self, epoch, loader=None):
    ret = {}
    print("Evaluating...")

    if not loader:
        loader = self.val_loader

    if self.instrumentation:
        # Capture entropy prior to evaluation
        _, train_entropy = binary_entropy(self.model.RSM_1.duty_cycle)
        ret["train_entropy"] = train_entropy.item()
        self.model.RSM_1.duty_cycle.fill_(0.0)  # Clear duty cycle

    self.model.eval()
    if self.predictor:
        self.predictor.eval()

    if self.weight_sparsity is not None:
        # Rezeroing happens before forward pass, so rezero after last
        # training forward.
        self.model._zero_sparse_weights()

    with torch.no_grad():
        total_loss = 0.0
        pcounts = {
            "total_samples": 0.0,
            "correct_samples": 0.0,
            "total_pred_loss": 0.0,
            "total_interp_loss": 0.0,
        }

        hidden = self._init_hidden(self.eval_batch_size)

        read_out_tgt = []
        read_out_pred = []

        metrics = {}

        for _b_idx, (inputs, targets, pred_targets,
                     input_labels) in enumerate(loader):
            # Forward
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            pred_targets = pred_targets.to(self.device)
            input_labels = input_labels.to(self.device)

            self._cache_inputs(input_labels, clear=_b_idx == 0)

            output, hidden = self.model(inputs, hidden)
            x_b, pred_input = self._get_prediction_and_loss_inputs(hidden)

            # Loss
            loss = self._compute_loss(output, (targets, x_b))
            if loss is not None:
                total_loss += loss.item()

            pcounts, class_predictions, correct_arr = self._do_prediction(
                pred_input, pred_targets, pcounts,
                batch_idx=_b_idx, loader=loader)

            self._read_out_predictions(pred_targets, class_predictions,
                                       read_out_tgt, read_out_pred)

            hidden = self._repackage_hidden(hidden)

            if self.instrumentation:
                metrics = self._agg_batch_metrics(
                    metrics,
                    pred_images=output[0].unsqueeze(0),
                    targets=targets.unsqueeze(0),
                    correct_arr=correct_arr.unsqueeze(0),
                    pred_targets=pred_targets,
                    class_predictions=class_predictions,
                )

            if self.dataset_kind == "mnist" and self.model_kind == "rsm":
                # Summary of column activation by input & next input
                self._store_activity_for_viz(x_b, input_labels, pred_targets)

        if self.instrumentation:
            # Save some snapshots from last batch of epoch
            # if self.model_kind == "rsm":
            #     metrics['last_hidden_snp'] = x_b
            #     metrics['last_input_snp'] = inputs
            #     metrics['last_output_snp'] = last_output

            # After all eval batches, generate stats & figures
            ret.update(self._generate_instr_charts(metrics))
            ret.update(self._store_instr_hists())

            _, test_entropy = binary_entropy(self.model.RSM_1.duty_cycle)
            ret["test_entropy"] = test_entropy.item()
            self.model.RSM_1.duty_cycle.fill_(0.0)  # Clear duty cycle

        num_batches = _b_idx + 1
        num_samples = pcounts["total_samples"]
        ret["val_loss"] = val_loss = total_loss / num_batches
        if self.predictor:
            test_pred_loss = pcounts["total_pred_loss"] / num_samples
            test_interp_loss = pcounts["total_interp_loss"] / num_samples
            ret["val_interp_ppl"] = lang_util.perpl(test_interp_loss)
            ret["val_pred_ppl"] = lang_util.perpl(test_pred_loss)
            ret["val_pred_acc"] = 100 * pcounts["correct_samples"] / num_samples

        if not self.best_val_loss or val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
        else:
            # Val loss increased
            if self.learning_rate_gamma:
                self.do_anneal_learning = True  # Reduce LR during post_epoch
            if self.pause_after_upticks and not self.model_learning_paused:
                if not self.pause_min_epoch or (
                        self.pause_min_epoch and epoch >= self.pause_min_epoch):
                    self.n_upticks += 1
                    if self.n_upticks >= self.pause_after_upticks:
                        print(">>> Pausing learning after %d upticks, validation "
                              "loss rose to %.3f, best: %.3f"
                              % (self.n_upticks, val_loss, self.best_val_loss))
                        self._pause_learning(epoch)

    return ret
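# NOTE: `binary_entropy` is imported from a utility module that is not shown
# here. The sketch below is an assumption of its behavior based on the call
# sites above: treat each unit's duty cycle as its probability of being
# active, compute the binary entropy per unit, and return both the per-unit
# values and an aggregate scalar (only the aggregate is used above).
def binary_entropy(duty_cycle, eps=1e-8):
    p = duty_cycle.clamp(eps, 1.0 - eps)
    per_unit = -(p * p.log2() + (1.0 - p) * (1.0 - p).log2())
    return per_unit, per_unit.sum()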
def train_epoch(self, epoch):
    """
    Do one epoch of training and testing.

    Returns:
        A dict that describes progress of this epoch. The dict includes the
        key 'stop'. If set to one, the network should be stopped early
        because training is not progressing well enough.
    """
    t1 = time.time()
    ret = {}
    self.model.train()  # Needed if using dropout
    if self.predictor:
        self.predictor.train()

    # Performance metrics
    total_loss = total_samples = correct_samples = total_pred_loss = 0.0

    bsz = self.batch_size
    if epoch == 0 and self.batch_size_first < self.batch_size:
        bsz = self.batch_size_first
    hidden = self._init_hidden(bsz)
    last_output = None

    for batch_idx, (inputs, targets, pred_targets,
                    _) in enumerate(self.train_loader):
        # Inputs are of shape (batch, input_size)
        if inputs.size(0) > bsz:
            # Crop to smaller first-epoch batch size
            inputs = inputs[:bsz]
            targets = targets[:bsz]
            pred_targets = pred_targets[:bsz]

        hidden = self._repackage_hidden(hidden)

        self.optimizer.zero_grad()
        if self.pred_optimizer:
            self.pred_optimizer.zero_grad()

        # Forward
        inputs = inputs.to(self.device)
        targets = targets.to(self.device)
        pred_targets = pred_targets.to(self.device)

        output, hidden = self.model(inputs, hidden)

        # Loss
        loss = self._compute_loss(
            output, targets, last_output=last_output,
            x_b=hidden[0])  # Kwargs used only for predict_memory

        if self.debug:
            self.model._register_hooks()

        if loss is not None:
            total_loss += loss.item()

            # RSM backward + optimize
            loss.backward()

            if self.model_kind == "lstm":
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.25)
                for p in self.model.parameters():
                    # Plain SGD step (keyword form of the deprecated
                    # add_(-lr, grad) call)
                    p.data.add_(p.grad.data, alpha=-self.learning_rate)
            else:
                self.optimizer.step()

            if self.plot_gradients:
                self._plot_gradient_flow()

        x_b = hidden[0]

        total_samples, correct_samples, class_predictions, correct_arr, \
            batch_loss, total_pred_loss = self._do_prediction(
                x_b, pred_targets, total_samples, correct_samples,
                total_pred_loss, train=True,
            )

        last_output = output

        if self.batch_log_interval and batch_idx % self.batch_log_interval == 0:
            print("Finished batch %d" % batch_idx)
            if self.predictor:
                acc = 100 * correct_samples / total_samples
                batch_acc = correct_arr.float().mean() * 100
                batch_ppl = lang_util.perpl(batch_loss)
                print("Partial train pred acc - epoch: %.3f%%, "
                      "batch acc: %.3f%%, batch ppl: %.1f"
                      % (acc, batch_acc, batch_ppl))

    ret["stop"] = 0

    if self.eval_interval and (epoch == 0
                               or (epoch + 1) % self.eval_interval == 0):
        # Evaluate every eval_interval epochs
        ret.update(self._eval())
        if self.dataset_kind == "ptb" and epoch >= 12 and ret["val_pred_ppl"] > 280:
            ret["stop"] = 1

    train_time = time.time() - t1
    self._post_epoch(epoch)

    ret["train_loss"] = total_loss / (batch_idx + 1)
    if self.predictor:
        train_pred_loss = total_pred_loss / (batch_idx + 1)
        ret["train_pred_ppl"] = lang_util.perpl(train_pred_loss)
        ret["train_pred_acc"] = 100 * correct_samples / total_samples
    ret["epoch_time_train"] = train_time
    ret["epoch_time"] = time.time() - t1
    ret["learning_rate"] = self.learning_rate
    print(epoch, print_epoch_values(ret))
    return ret
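# NOTE: The manual parameter update in the `lstm` branch above (clip the
# gradient norm, then subtract learning_rate * grad from each parameter) is a
# plain SGD step. The hypothetical helper below, with `model`, `loss` and
# `learning_rate` standing in for the corresponding attributes of this class,
# sketches that equivalence under the assumption that no momentum or weight
# decay is wanted. It is not part of the original experiment code.
def clipped_sgd_step(model, loss, learning_rate, max_norm=0.25):
    """Hypothetical illustration of the manual LSTM update above."""
    sgd = torch.optim.SGD(model.parameters(), lr=learning_rate)
    sgd.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    sgd.step()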
def _eval(self):
    ret = {}
    print("Evaluating...")
    # Disable dropout
    self.model.eval()
    if self.predictor:
        self.predictor.eval()

    if self.weight_sparsity is not None:
        # Rezeroing happens before forward pass, so rezero after last
        # training forward.
        self.model._zero_sparse_weights()

    with torch.no_grad():
        total_loss = 0.0
        total_samples = 0.0
        correct_samples = 0.0
        total_pred_loss = 0.0

        hidden = self._init_hidden(self.batch_size)

        all_x_a_next = all_targets = all_correct_arrs = all_pred_targets = None
        all_cls_preds = None
        last_output = None

        for batch_idx, (inputs, targets, pred_targets,
                        input_labels) in enumerate(self.val_loader):
            # Forward
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            pred_targets = pred_targets.to(self.device)

            x_a_next, hidden = self.model(inputs, hidden)
            x_b = hidden[0]

            # Loss
            loss = self._compute_loss(
                x_a_next, targets, last_output=last_output,
                x_b=x_b)  # Kwargs used only for predict_memory
            if loss is not None:
                total_loss += loss.item()

            total_samples, correct_samples, class_predictions, correct_arr, \
                batch_loss, total_pred_loss = self._do_prediction(
                    x_b, pred_targets, total_samples, correct_samples,
                    total_pred_loss
                )

            hidden = self._repackage_hidden(hidden)

            # Save results for image grid & confusion matrix
            x_a_next.unsqueeze_(0)
            targets.unsqueeze_(0)
            correct_arr.unsqueeze_(0)
            all_x_a_next = (x_a_next if all_x_a_next is None
                            else torch.cat((all_x_a_next, x_a_next)))
            all_targets = (targets if all_targets is None
                           else torch.cat((all_targets, targets)))
            all_correct_arrs = (correct_arr if all_correct_arrs is None
                                else torch.cat((all_correct_arrs, correct_arr)))
            all_pred_targets = (pred_targets if all_pred_targets is None
                                else torch.cat((all_pred_targets, pred_targets)))
            all_cls_preds = (class_predictions if all_cls_preds is None
                             else torch.cat((all_cls_preds, class_predictions)))

            if self.dataset_kind == "mnist" and self.model_kind == "rsm":
                # Summary of column activation by input & next input
                self._store_activity_for_viz(x_b, input_labels, pred_targets)

            ret.update(self._track_hists())

            last_output = x_a_next

            if batch_idx >= self.eval_batches_in_epoch:
                break

        # After all eval batches, generate stats & figures
        if self.dataset_kind == "mnist" and self.model_kind == "rsm":
            if not self.predict_memory:
                ret["img_preds"] = self._image_grid(
                    all_x_a_next,
                    compare_with=all_targets,
                    compare_correct=all_correct_arrs,
                ).cpu()
            cm_fig = self._confusion_matrix(all_pred_targets, all_cls_preds)
            ret["img_confusion"] = fig2img(cm_fig)
            if self.flattened:
                activity_grid = plot_activity_grid(
                    self.activity_by_inputs,
                    n_labels=self.predictor_output_size)
            else:
                activity_grid = plot_activity(
                    self.activity_by_inputs,
                    n_labels=self.predictor_output_size,
                    level="cell",
                )
            img_repr_sim = plot_representation_similarity(
                self.activity_by_inputs,
                n_labels=self.predictor_output_size,
                title=self.boost_strat,
            )
            ret["img_repr_sim"] = fig2img(img_repr_sim)
            ret["img_col_activity"] = fig2img(activity_grid)
            self.activity_by_inputs = {}

        ret["val_loss"] = val_loss = total_loss / (batch_idx + 1)
        if self.predictor:
            test_pred_loss = total_pred_loss / (batch_idx + 1)
            ret["val_pred_ppl"] = lang_util.perpl(test_pred_loss)
            ret["val_pred_acc"] = 100 * correct_samples / total_samples

        if not self.best_val_loss or val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
        else:
            if self.learning_rate_gamma:
                self.do_anneal_learning = True  # Reduce LR during post_epoch

    return ret
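# NOTE: `fig2img` converts a matplotlib figure into an image object before it
# is stored in the results dict. It is imported from a utility module that is
# not shown here; the function below is only a sketch of the assumed behavior,
# implemented as a PNG round-trip through Pillow.
import io

from PIL import Image


def fig2img(fig):
    """Illustrative stand-in for the fig2img helper (assumed behavior)."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    return Image.open(buf)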