Example #1
    def _do_prediction(self,
                       inputs,
                       pred_targets,
                       pcounts,
                       train=False,
                       batch_idx=0,
                       loader=None):
        """
        Do prediction.
        """
        class_predictions = correct_arr = None
        if self.predictor:
            pred_targets = pred_targets.flatten()

            predictor_dist, predictor_logits = self.predictor(inputs.detach())

            # This loss is without inference-time model interpolation
            pred_loss = self.predictor_loss(predictor_logits,
                                            pred_targets)  # cross-entropy loss

            # This loss is for the interpolated model
            interp_loss = self._interpolated_loss(predictor_dist,
                                                  pred_targets,
                                                  loader=loader,
                                                  train=train)

            _, class_predictions = torch.max(predictor_dist, 1)
            pcounts["total_samples"] += pred_targets.size(0)
            correct_arr = class_predictions == pred_targets
            pcounts["correct_samples"] += correct_arr.sum().item()
            pred_loss_ = pred_loss.item()
            pcounts["total_pred_loss"] += pred_loss_
            pcounts["total_interp_loss"] += interp_loss
            if train:
                # Predictor backward + optimize
                pred_loss.backward()
                self.pred_optimizer.step()

        if self.batch_log_interval and batch_idx % self.batch_log_interval == 0:
            print("Finished batch %d" % batch_idx)
            if self.predictor:
                batch_acc = correct_arr.float().mean() * 100
                batch_ppl = lang_util.perpl(pred_loss_ / pred_targets.size(0))
                print("Partial pred acc - "
                      "batch acc: %.3f%%, pred ppl: %.1f" %
                      (batch_acc, batch_ppl))

        return (pcounts, class_predictions, correct_arr)
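Both the per-batch log above and the epoch summaries below report perplexity by passing a cross-entropy value through lang_util.perpl. A minimal sketch of that conversion, assuming perpl is the standard exponential of an average per-sample negative log-likelihood (the helper's actual implementation is not shown in these snippets):

    import math

    def perpl(avg_nll):
        # Perplexity is the exponential of the average per-sample
        # negative log-likelihood (cross-entropy, in nats).
        return math.exp(avg_nll)

    print(perpl(5.6))  # -> ~270.4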
Example #2
    def train_epoch(self, epoch):
        """
        Do one epoch of training and testing.

        Returns:
            A dict that describes progress of this epoch.
            The dict includes the key 'stop'. If set to one, this network
            should be stopped early. Training is not progressing well enough.
        """
        t1 = time.time()
        ret = {}

        self.model.train()  # Needed if using dropout
        if self.predictor:
            self.predictor.train()

        # Performance metrics
        total_loss = 0.0
        pcounts = {
            "total_samples": 0.0,
            "correct_samples": 0.0,
            "total_pred_loss": 0.0,
            "total_interp_loss": 0.0,
        }

        bsz = self.batch_size

        if self.train_hidden_buffer:
            hidden = self.train_hidden_buffer[-1]
        else:
            hidden = self._init_hidden(self.batch_size)

        for batch_idx, (inputs, targets, pred_targets,
                        _input_labels) in enumerate(self.train_loader):
            # Inputs are of shape (batch, input_size)
            if inputs.size(0) > bsz:
                # Crop oversized batches down to the configured batch size
                inputs = inputs[:bsz]
                targets = targets[:bsz]
                pred_targets = pred_targets[:bsz]

            hidden = self._repackage_hidden(hidden)

            self.optimizer.zero_grad()
            if self.pred_optimizer:
                self.pred_optimizer.zero_grad()

            # Forward
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            pred_targets = pred_targets.to(self.device)

            output, hidden = self.model(inputs, hidden)

            x_b, pred_input = self._get_prediction_and_loss_inputs(hidden)

            self.train_hidden_buffer.append(hidden)

            # Loss
            loss_targets = (targets, x_b)
            loss = self._compute_loss(output, loss_targets)
            if loss is not None:
                total_loss += loss.item()
                if not self.model_learning_paused:
                    self._backward_and_optimize(loss)

            # Keep only latest batch states around
            self.train_hidden_buffer = self.train_hidden_buffer[-1:]

            pcounts, class_predictions, correct_arr = self._do_prediction(
                pred_input,
                pred_targets,
                pcounts,
                train=True,
                batch_idx=batch_idx,
                loader=self.train_loader,
            )

            if epoch == 0 and batch_idx >= self.batches_in_first_epoch - 1:
                print("Breaking after %d batches in epoch %d" %
                      (self.batches_in_first_epoch, epoch))
                break

        ret["stop"] = 0
        self.model._post_train_epoch(epoch)  # Update kwinners duty cycles, etc

        if self.eval_interval and (epoch - 1) % self.eval_interval == 0:
            # Evaluate every eval_interval epochs
            ret.update(self.eval_epoch(epoch))

        train_time = time.time() - t1
        self._post_epoch(epoch)

        num_batches = batch_idx + 1
        ret["train_loss"] = total_loss / num_batches
        if self.predictor:
            num_samples = num_batches * self.batch_size
            train_pred_loss = pcounts["total_pred_loss"] / num_samples
            train_interp_loss = pcounts["total_interp_loss"] / num_samples
            ret["train_interp_ppl"] = lang_util.perpl(train_interp_loss)
            ret["train_pred_ppl"] = lang_util.perpl(train_pred_loss)
            ret["train_pred_acc"] = (100 * pcounts["correct_samples"] /
                                     pcounts["total_samples"])

        ret["epoch_time_train"] = train_time
        ret["epoch_time"] = time.time() - t1
        ret["learning_rate"] = self.learning_rate
        print(epoch, print_epoch_values(ret))
        return ret
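As the docstring notes, callers are expected to stop training once the returned dict sets 'stop' to 1. A minimal driver loop under that contract (experiment and n_epochs are illustrative names, not part of the snippet):

    n_epochs = 40  # hypothetical training budget

    for epoch in range(n_epochs):
        ret = experiment.train_epoch(epoch)
        if ret["stop"]:
            print("Early stop requested at epoch %d" % epoch)
            break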
Example #3
    def eval_epoch(self, epoch, loader=None):
        """
        Evaluate the model for one epoch and return a dict of metrics.
        """
        ret = {}
        print("Evaluating...")
        if not loader:
            loader = self.val_loader

        if self.instrumentation:
            # Capture entropy prior to evaluation
            _, train_entropy = binary_entropy(self.model.RSM_1.duty_cycle)
            ret["train_entropy"] = train_entropy.item()
            self.model.RSM_1.duty_cycle.fill_(0.0)  # Clear duty cycle

        self.model.eval()
        if self.predictor:
            self.predictor.eval()

        if self.weight_sparsity is not None:
            # Rezeroing happens before forward pass, so rezero after last
            # training forward.
            self.model._zero_sparse_weights()

        with torch.no_grad():
            total_loss = 0.0
            pcounts = {
                "total_samples": 0.0,
                "correct_samples": 0.0,
                "total_pred_loss": 0.0,
                "total_interp_loss": 0.0,
            }

            hidden = self._init_hidden(self.eval_batch_size)

            read_out_tgt = []
            read_out_pred = []
            metrics = {}

            for _b_idx, (inputs, targets, pred_targets,
                         input_labels) in enumerate(loader):

                # Forward
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                pred_targets = pred_targets.to(self.device)
                input_labels = input_labels.to(self.device)

                self._cache_inputs(input_labels, clear=(_b_idx == 0))

                output, hidden = self.model(inputs, hidden)

                x_b, pred_input = self._get_prediction_and_loss_inputs(hidden)

                # Loss
                loss = self._compute_loss(output, (targets, x_b))
                if loss is not None:
                    total_loss += loss.item()

                pcounts, class_predictions, correct_arr = self._do_prediction(
                    pred_input,
                    pred_targets,
                    pcounts,
                    batch_idx=_b_idx,
                    loader=loader)

                self._read_out_predictions(pred_targets, class_predictions,
                                           read_out_tgt, read_out_pred)

                hidden = self._repackage_hidden(hidden)

                if self.instrumentation:
                    metrics = self._agg_batch_metrics(
                        metrics,
                        pred_images=output[0].unsqueeze(0),
                        targets=targets.unsqueeze(0),
                        correct_arr=correct_arr.unsqueeze(0),
                        pred_targets=pred_targets,
                        class_predictions=class_predictions,
                    )

                    if self.dataset_kind == "mnist" and self.model_kind == "rsm":
                        # Summary of column activation by input & next input
                        self._store_activity_for_viz(x_b, input_labels,
                                                     pred_targets)

            if self.instrumentation:
                # Save some snapshots from last batch of epoch
                # if self.model_kind == "rsm":
                #     metrics['last_hidden_snp'] = x_b
                #     metrics['last_input_snp'] = inputs
                #     metrics['last_output_snp'] = last_output

                # After all eval batches, generate stats & figures
                ret.update(self._generate_instr_charts(metrics))
                ret.update(self._store_instr_hists())
                _, test_entropy = binary_entropy(self.model.RSM_1.duty_cycle)
                ret["test_entropy"] = test_entropy.item()
                self.model.RSM_1.duty_cycle.fill_(0.0)  # Clear duty cycle

            num_batches = _b_idx + 1
            num_samples = pcounts["total_samples"]
            ret["val_loss"] = val_loss = total_loss / num_batches
            if self.predictor:
                test_pred_loss = pcounts["total_pred_loss"] / num_samples
                test_interp_loss = pcounts["total_interp_loss"] / num_samples
                ret["val_interp_ppl"] = lang_util.perpl(test_interp_loss)
                ret["val_pred_ppl"] = lang_util.perpl(test_pred_loss)
                ret["val_pred_acc"] = 100 * pcounts[
                    "correct_samples"] / num_samples

            if not self.best_val_loss or val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
            else:
                # Val loss increased
                if self.learning_rate_gamma:
                    self.do_anneal_learning = True  # Reduce LR during post_epoch
                if self.pause_after_upticks and not self.model_learning_paused:
                    if not self.pause_min_epoch or epoch >= self.pause_min_epoch:
                        self.n_upticks += 1
                        if self.n_upticks >= self.pause_after_upticks:
                            print(
                                ">>> Pausing learning after %d upticks, validation "
                                "loss rose to %.3f, best: %.3f" %
                                (self.n_upticks, val_loss, self.best_val_loss))
                            self._pause_learning(epoch)

        return ret
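Both train_epoch and eval_epoch call _repackage_hidden before carrying the recurrent state into the next batch. A common implementation of that idiom (a sketch; the codebase's own version is not shown here) detaches the state from its computation history so gradients do not flow across batch boundaries:

    import torch

    def repackage_hidden(h):
        # Detach tensors (or nested tuples of tensors) from the graph,
        # truncating backpropagation at batch boundaries (truncated BPTT).
        if isinstance(h, torch.Tensor):
            return h.detach()
        return tuple(repackage_hidden(v) for v in h)

Without this step, the autograd graph would grow across the whole epoch and backward() would try to backpropagate through every previous batch.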
Example #4
    def train_epoch(self, epoch):
        """This should be called to do one epoch of training and testing.

        Returns:
            A dict that describes progress of this epoch.
            The dict includes the key 'stop'. If set to one, this network
            should be stopped early. Training is not progressing well enough.
        """
        t1 = time.time()

        ret = {}

        self.model.train()  # Needed if using dropout
        if self.predictor:
            self.predictor.train()

        # Performance metrics
        total_loss = total_samples = correct_samples = total_pred_loss = 0.0

        bsz = self.batch_size
        if epoch == 0 and self.batch_size_first < self.batch_size:
            bsz = self.batch_size_first

        hidden = self._init_hidden(bsz)
        last_output = None

        for batch_idx, (inputs, targets, pred_targets,
                        _) in enumerate(self.train_loader):
            # Inputs are of shape (batch, input_size)

            if inputs.size(0) > bsz:
                # Crop to smaller first epoch batch size
                inputs = inputs[:bsz]
                targets = targets[:bsz]
                pred_targets = pred_targets[:bsz]

            hidden = self._repackage_hidden(hidden)

            self.optimizer.zero_grad()
            if self.pred_optimizer:
                self.pred_optimizer.zero_grad()

            # Forward
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            pred_targets = pred_targets.to(self.device)

            output, hidden = self.model(inputs, hidden)

            # Loss
            loss = self._compute_loss(
                output, targets, last_output=last_output,
                x_b=hidden[0])  # Kwargs used only for predict_memory

            if self.debug:
                self.model._register_hooks()

            if loss is not None:
                total_loss += loss.item()

                # RSM backward + optimize
                loss.backward()
                if self.model_kind == "lstm":
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   0.25)
                    for p in self.model.parameters():
                        # Manual SGD step: p <- p - lr * grad
                        p.data.add_(p.grad.data, alpha=-self.learning_rate)
                else:
                    self.optimizer.step()

            if self.plot_gradients:
                self._plot_gradient_flow()

            x_b = hidden[0]
            total_samples, correct_samples, class_predictions, correct_arr, \
                batch_loss, total_pred_loss = self._do_prediction(
                    x_b,
                    pred_targets,
                    total_samples,
                    correct_samples,
                    total_pred_loss,
                    train=True,
                )

            last_output = output

            if self.batch_log_interval and batch_idx % self.batch_log_interval == 0:
                print("Finished batch %d" % batch_idx)
                if self.predictor:
                    acc = 100 * correct_samples / total_samples
                    batch_acc = correct_arr.float().mean() * 100
                    batch_ppl = lang_util.perpl(batch_loss)
                    print("Partial train pred acc - epoch: %.3f%%, "
                          "batch acc: %.3f%%, batch ppl: %.1f" %
                          (acc, batch_acc, batch_ppl))

        ret["stop"] = 0

        if self.eval_interval and (epoch == 0 or
                                   (epoch + 1) % self.eval_interval == 0):
            # Evaluate every eval_interval epochs
            ret.update(self._eval())
            if (self.dataset_kind == "ptb" and epoch >= 12
                    and ret["val_pred_ppl"] > 280):
                ret["stop"] = 1

        train_time = time.time() - t1
        self._post_epoch(epoch)

        ret["train_loss"] = total_loss / (batch_idx + 1)
        if self.predictor:
            train_pred_loss = total_pred_loss / (batch_idx + 1)
            ret["train_pred_ppl"] = lang_util.perpl(train_pred_loss)
            ret["train_pred_acc"] = 100 * correct_samples / total_samples

        ret["epoch_time_train"] = train_time
        ret["epoch_time"] = time.time() - t1
        ret["learning_rate"] = self.learning_rate
        print(epoch, print_epoch_values(ret))
        return ret
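For the lstm model kind, the loop above clips the global gradient norm to 0.25 and then applies a plain SGD update by hand rather than calling an optimizer. The same update in isolation (a sketch; manual_sgd_step, model, and lr are illustrative names):

    import torch

    def manual_sgd_step(model, lr, max_norm=0.25):
        # Rescale gradients so their global L2 norm is at most max_norm,
        # then take a vanilla SGD step: p <- p - lr * grad.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        with torch.no_grad():
            for p in model.parameters():
                if p.grad is not None:
                    p.add_(p.grad, alpha=-lr)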
Example #5
    def _eval(self):
        """
        Evaluate the model on the validation set and return a dict of metrics.
        """
        ret = {}
        print("Evaluating...")
        # Disable dropout
        self.model.eval()
        if self.predictor:
            self.predictor.eval()

        if self.weight_sparsity is not None:
            # Rezeroing happens before forward pass, so rezero after last
            # training forward.
            self.model._zero_sparse_weights()

        with torch.no_grad():
            total_loss = 0.0
            total_samples = 0.0
            correct_samples = 0.0
            total_pred_loss = 0.0

            hidden = self._init_hidden(self.batch_size)

            all_x_a_next = all_targets = all_correct_arrs = None
            all_pred_targets = all_cls_preds = None
            last_output = None

            for batch_idx, (inputs, targets, pred_targets,
                            input_labels) in enumerate(self.val_loader):

                # Forward
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                pred_targets = pred_targets.to(self.device)
                x_a_next, hidden = self.model(inputs, hidden)
                x_b = hidden[0]

                # Loss
                loss = self._compute_loss(
                    x_a_next, targets, last_output=last_output,
                    x_b=x_b)  # Kwargs used only for predict_memory
                if loss is not None:
                    total_loss += loss.item()

                total_samples, correct_samples, class_predictions, correct_arr, \
                    batch_loss, total_pred_loss = self._do_prediction(
                        x_b, pred_targets, total_samples, correct_samples,
                        total_pred_loss
                    )

                hidden = self._repackage_hidden(hidden)

                # Save results for image grid & confusion matrix
                x_a_next.unsqueeze_(0)
                targets.unsqueeze_(0)
                correct_arr.unsqueeze_(0)
                def _accum(acc, new):
                    return new if acc is None else torch.cat((acc, new))

                all_x_a_next = _accum(all_x_a_next, x_a_next)
                all_targets = _accum(all_targets, targets)
                all_correct_arrs = _accum(all_correct_arrs, correct_arr)
                all_pred_targets = _accum(all_pred_targets, pred_targets)
                all_cls_preds = _accum(all_cls_preds, class_predictions)

                if self.dataset_kind == "mnist" and self.model_kind == "rsm":
                    # Summary of column activation by input & next input
                    self._store_activity_for_viz(x_b, input_labels,
                                                 pred_targets)

                ret.update(self._track_hists())

                last_output = x_a_next

                if batch_idx >= self.eval_batches_in_epoch:
                    break

            # After all eval batches, generate stats & figures
            if self.dataset_kind == "mnist" and self.model_kind == "rsm":
                if not self.predict_memory:
                    ret["img_preds"] = self._image_grid(
                        all_x_a_next,
                        compare_with=all_targets,
                        compare_correct=all_correct_arrs,
                    ).cpu()
                cm_fig = self._confusion_matrix(all_pred_targets,
                                                all_cls_preds)
                ret["img_confusion"] = fig2img(cm_fig)
                if self.flattened:
                    activity_grid = plot_activity_grid(
                        self.activity_by_inputs,
                        n_labels=self.predictor_output_size)
                else:
                    activity_grid = plot_activity(
                        self.activity_by_inputs,
                        n_labels=self.predictor_output_size,
                        level="cell",
                    )
                img_repr_sim = plot_representation_similarity(
                    self.activity_by_inputs,
                    n_labels=self.predictor_output_size,
                    title=self.boost_strat,
                )
                ret["img_repr_sim"] = fig2img(img_repr_sim)
                ret["img_col_activity"] = fig2img(activity_grid)
                self.activity_by_inputs = {}

            ret["val_loss"] = val_loss = total_loss / (batch_idx + 1)
            if self.predictor:
                test_pred_loss = total_pred_loss / (batch_idx + 1)
                ret["val_pred_ppl"] = lang_util.perpl(test_pred_loss)
                ret["val_pred_acc"] = 100 * correct_samples / total_samples

            if not self.best_val_loss or val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
            else:
                if self.learning_rate_gamma:
                    self.do_anneal_learning = True  # Reduce LR during post_epoch

        return ret
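The tail of _eval keeps the best validation loss seen so far and flags a learning-rate anneal whenever the loss fails to improve. The same plateau check in isolation (a sketch; check_plateau is an illustrative name):

    def check_plateau(val_loss, best_val_loss, gamma):
        # Returns (new_best, anneal). An improvement updates the best loss;
        # a regression requests an LR decay when a gamma factor is set.
        if best_val_loss is None or val_loss < best_val_loss:
            return val_loss, False
        return best_val_loss, bool(gamma)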