def test_static_accuracy(self):
        """
        Verify the accuracy computation built into the optimizer on a single
        batch of data.

        _running_eval_acc is called with no accumulated per-class totals
        (n_total/n_correct of None), so only "static" accuracy is exercised
        here, not "running" accuracy; the result must match sklearn's
        balanced accuracy on the batch alone.
        """
        cfg = DefaultOptimizerConfig()
        batch_size = cfg.training_cfg.batch_size
        num_classes = 5

        # sweep the target batch accuracy from 0.0 to 1.0
        step = 0.05
        for target_acc in np.arange(0, 1 + step, step):
            # draw random per-class scores and row-normalize so each row of
            # shape [batch_size, num_classes] sums to 1 (a fake softmax output)
            scores = self.rso.rand(batch_size, num_classes)
            probs = scores / scores.sum(axis=1, keepdims=True)
            preds = probs.argmax(axis=1)  # hard-decision predictions

            # build labels that agree with the predictions everywhere except
            # a fraction (1 - target_acc) of rows, which are cycled to the
            # next class so they are guaranteed to be wrong
            labels = preds.copy()
            n_flip = int(batch_size * (1 - target_acc))
            flip_idx = self.rso.choice(range(batch_size), n_flip,
                                       replace=False)
            labels[flip_idx] = (labels[flip_idx] + 1) % num_classes

            # sklearn's balanced accuracy (scaled to a percentage) is the
            # reference value the optimizer should reproduce
            expected_balanced_acc = balanced_accuracy_score(labels,
                                                            preds) * 100

            # exercise the optimizer's accuracy helper with torch inputs in
            # the dtypes used during real operation
            actual_acc, n_total, n_correct = \
                _running_eval_acc(torch.tensor(probs, dtype=torch.float),
                                  torch.tensor(labels, dtype=torch.long),
                                  n_total=None, n_correct=None)
            self.assertAlmostEqual(actual_acc, expected_balanced_acc)
    def train_epoch(self,   model: nn.Module, train_loader: DataLoader,
                    val_clean_loader: DataLoader, val_triggered_loader: DataLoader,
                    epoch_num: int, use_amp: bool = False):
        """
        Runs one epoch of training on the specified model

        :param model: the model to train for one epoch
        :param train_loader: a DataLoader object pointing to the training dataset
        :param val_clean_loader: a DataLoader object pointing to the validation dataset that is clean
        :param val_triggered_loader: a DataLoader object pointing to the validation dataset that is triggered
        :param epoch_num: the epoch number that is being trained
        :param use_amp: if True, uses automated mixed precision for FP16 training.
        :return: a list of statistics for batches where statistics were computed
        """

        # Probability of Adversarial attack to occur in each iteration
        attack_prob = self.optimizer_cfg.training_cfg.adv_training_ratio
        pid = os.getpid()
        train_dataset_len = len(train_loader.dataset)
        loop = tqdm(train_loader, disable=self.optimizer_cfg.reporting_cfg.disable_progress_bar)

        scaler = None
        if use_amp:
            scaler = torch.cuda.amp.GradScaler()

        train_n_correct, train_n_total = None, None

        # Define parameters of the adversarial attack
        attack_eps = float(self.optimizer_cfg.training_cfg.adv_training_eps)
        attack_iterations = int(self.optimizer_cfg.training_cfg.adv_training_iterations)
        eps_iter = (2.0 * attack_eps) / float(attack_iterations)
        attack = LinfPGDAttack(
            predict=model,
            loss_fn=nn.CrossEntropyLoss(reduction="sum"),
            eps=attack_eps,
            nb_iter=attack_iterations,
            eps_iter=eps_iter)

        sum_batchmean_train_loss = 0
        running_train_acc = 0
        num_batches = len(train_loader)
        model.train()
        for batch_idx, (x, y_truth) in enumerate(loop):
            x = x.to(self.device)
            y_truth = y_truth.to(self.device)

            # put network into training mode & zero out previous gradient computations
            self.optimizer.zero_grad()

            # get predictions based on input & weights learned so far
            if use_amp:
                with torch.cuda.amp.autocast():
                    # add adversarial noise via l_inf PGD attack
                    # only apply attack to attack_prob of the batches
                    if attack_prob and np.random.rand() <= attack_prob:
                        with ctx_noparamgrad_and_eval(model):
                            x = attack.perturb(x, y_truth)
                    y_hat = model(x)
                    # compute metrics
                    batch_train_loss = self._eval_loss_function(y_hat, y_truth)

            else:
                # add adversarial noise vis lin PGD attack
                if attack_prob and np.random.rand() <= attack_prob:
                    with ctx_noparamgrad_and_eval(model):
                        x = attack.perturb(x, y_truth)
                y_hat = model(x)
                batch_train_loss = self._eval_loss_function(y_hat, y_truth)

            sum_batchmean_train_loss += batch_train_loss.item()

            running_train_acc, train_n_total, train_n_correct = default_optimizer._running_eval_acc(y_hat, y_truth,
                                                                                  n_total=train_n_total,
                                                                                  n_correct=train_n_correct,
                                                                                  soft_to_hard_fn=self.soft_to_hard_fn,
                                                                                  soft_to_hard_fn_kwargs=self.soft_to_hard_fn_kwargs)

            # compute gradient
            if use_amp:
                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
                # Backward passes under autocast are not recommended.
                # Backward ops run in the same dtype autocast chose for corresponding forward ops.
                scaler.scale(batch_train_loss).backward()
            else:
                if np.isnan(sum_batchmean_train_loss) or np.isnan(running_train_acc):
                    default_optimizer._save_nandata(x, y_hat, y_truth, batch_train_loss, sum_batchmean_train_loss, running_train_acc,
                                  train_n_total, train_n_correct, model)

                batch_train_loss.backward()

            # perform gradient clipping if configured
            if self.optimizer_cfg.training_cfg.clip_grad:
                if use_amp:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(self.optimizer)

                if self.optimizer_cfg.training_cfg.clip_type == 'norm':
                    # clip_grad_norm_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_norm_(model.parameters(), self.optimizer_cfg.training_cfg.clip_val,
                                                    **self.optimizer_cfg.training_cfg.clip_kwargs)
                elif self.optimizer_cfg.training_cfg.clip_type == 'val':
                    # clip_grad_val_ modifies gradients in place
                    #  see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                    torch_clip_grad.clip_grad_value_(
                        model.parameters(), self.optimizer_cfg.training_cfg.clip_val)
                else:
                    msg = "Unknown clipping type for gradient clipping!"
                    logger.error(msg)
                    raise ValueError(msg)

            if use_amp:
                # scaler.step() first unscales the gradients of the optimizer's assigned params.
                # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(self.optimizer)
                # Updates the scale for next iteration.
                scaler.update()
            else:
                self.optimizer.step()

            # report batch statistics to tensorflow
            if self.tb_writer:
                try:
                    batch_num = int(epoch_num * num_batches + batch_idx)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-train_loss',
                                              batch_train_loss.item(), global_step=batch_num)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-running_train_acc',
                                              running_train_acc, global_step=batch_num)
                except:
                    # TODO: catch specific expcetions
                    pass

            loop.set_description('Epoch {}/{}'.format(epoch_num + 1, self.num_epochs))
            loop.set_postfix(avg_train_loss=batch_train_loss.item())

            if batch_idx % self.num_batches_per_logmsg == 0:
                logger.info('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tTrainLoss: {:.6f}\tTrainAcc: {:.6f}'.format(
                    pid, epoch_num, batch_idx * len(x), train_dataset_len,
                    100. * batch_idx / num_batches, batch_train_loss.item(), running_train_acc))

        train_stats = EpochTrainStatistics(running_train_acc, sum_batchmean_train_loss / float(num_batches))

        # if we have validation data, we compute on the validation dataset
        num_val_batches_clean = len(val_clean_loader)
        if num_val_batches_clean > 0:
            logger.info('Running Validation on Clean Data')
            running_val_clean_acc, _, _, val_clean_loss = \
                default_optimizer._eval_acc(val_clean_loader, model, self.device,
                          self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs, self._eval_loss_function)
        else:
            logger.info("No dataset computed for validation on clean dataset!")
            running_val_clean_acc = None
            val_clean_loss = None

        num_val_batches_triggered = len(val_triggered_loader)
        if num_val_batches_triggered > 0:
            logger.info('Running Validation on Triggered Data')
            running_val_triggered_acc, _, _, val_triggered_loss = \
                default_optimizer._eval_acc(val_triggered_loader, model, self.device,
                          self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs, self._eval_loss_function)
        else:
            logger.info(
                "No dataset computed for validation on triggered dataset!")
            running_val_triggered_acc = None
            val_triggered_loss = None

        validation_stats = EpochValidationStatistics(running_val_clean_acc, val_clean_loss,
                                                     running_val_triggered_acc, val_triggered_loss)
        if num_val_batches_clean > 0:
            logger.info('{}\tTrain Epoch: {} \tCleanValLoss: {:.6f}\tCleanValAcc: {:.6f}'.format(
                pid, epoch_num, val_clean_loss, running_val_clean_acc))
        if num_val_batches_triggered > 0:
            logger.info('{}\tTrain Epoch: {} \tTriggeredValLoss: {:.6f}\tTriggeredValAcc: {:.6f}'.format(
                pid, epoch_num, val_triggered_loss, running_val_triggered_acc))

        if self.tb_writer:
            try:
                batch_num = int((epoch_num + 1) * num_batches)
                if num_val_batches_clean > 0:
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                              '-clean-val-loss', val_clean_loss, global_step=batch_num)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                              '-clean-val_acc', running_val_clean_acc, global_step=batch_num)
                if num_val_batches_triggered > 0:
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                              '-triggered-val-loss', val_triggered_loss, global_step=batch_num)
                    self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                              '-triggered-val_acc', running_val_triggered_acc, global_step=batch_num)
            except:
                pass

        # update the lr-scheduler if necessary
        if self.lr_scheduler is not None:
            if self.optimizer_cfg.training_cfg.lr_scheduler_call_arg is None:
                self.lr_scheduler.step()
            elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_acc':
                val_acc = validation_stats.get_val_acc()
                if val_acc is not None:
                    self.lr_scheduler.step(val_acc)
                else:
                    msg = "val_clean_acc not defined b/c validation dataset is not defined! Ignoring LR step!"
                    logger.warning(msg)
            elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_loss':
                val_loss = validation_stats.get_val_loss()
                if val_loss is not None:
                    self.lr_scheduler.step(val_loss)
                else:
                    msg = "val_clean_loss not defined b/c validation dataset is not defined! Ignoring LR step!"
                    logger.warning(msg)
            else:
                msg = "Unknown mode for calling lr_scheduler!"
                logger.error(msg)
                raise ValueError(msg)

        return train_stats, validation_stats
def test_running_accuracy(self):
        """
        Tests the running accuracy of the classifier, which takes the previous counts
        into account when computing the updated classification.
        """
        batch_size = 32
        num_outputs = 5

        # setup what our accumulated totals are so far
        # note that the _eval_acc function takes two dictionaries.  Each dictionary has the same format:
        #  the key is the classification label, and the value is the number of times that label has
        #  appeared.  In the n_total dictionary, the counts represent the total number of times that
        #  specific label was seen.  In the n_correct dictionary, the counts represent the number of
        #  times that label was classified correctly.
        n_total_prev = defaultdict(int, {0: 13, 1: 13, 2: 13, 3: 13, 4: 12})

        # test over a large variety of the number of correct predictions, per class
        # (one nested loop per class label, so every combination of per-class
        #  correct counts in val_ranges is exercised)
        val_ranges = [3, 7, 11]
        for zero_val in tqdm(val_ranges):
            for one_val in val_ranges:
                for two_val in val_ranges:
                    for three_val in val_ranges:
                        for four_val in val_ranges:
                            n_correct_prev = defaultdict(
                                int, {
                                    0: zero_val,
                                    1: one_val,
                                    2: two_val,
                                    3: three_val,
                                    4: four_val
                                })

                            # in order to compute the actual accuracy we should expect, we update the dictionaries with
                            # the network simulation, convert to a list, and then use sklearn to get the baseline
                            # accuracy that our default optimizer should be reporting.
                            # We then input the same information into the _eval_acc in the
                            # required format (i.e. cur_batch information, and all accumulated prev_information),
                            # and check whether they match

                            step = 0.1
                            batch_acc_vec = np.arange(0, 1 + step, step)
                            for batch_acc in batch_acc_vec:
                                # random per-class scores for the batch
                                random_mat = self.rso.rand(
                                    batch_size, num_outputs)
                                row_sum = random_mat.sum(axis=1)

                                # normalize the random_mat such that every row adds up to 1
                                # broadcast so we can divide every element in matrix by the row's sum
                                fake_network_output = random_mat / row_sum[:,
                                                                           None]
                                # the hard-decision prediction per row
                                network_output = np.argmax(fake_network_output,
                                                           axis=1)

                                # now, modify a subset of the network output and make that the "real" output,
                                # cycling the label so the modified rows are guaranteed misclassified
                                true_output = network_output.copy()
                                num_indices_to_modify = int(batch_size *
                                                            (1 - batch_acc))
                                indices_to_modify = self.rso.choice(
                                    range(batch_size),
                                    num_indices_to_modify,
                                    replace=False)

                                for ii in indices_to_modify:
                                    true_output[ii] = (true_output[ii] +
                                                       1) % num_outputs

                                # convert datatypes to what is expected during operation
                                network_output_pt = torch.tensor(
                                    fake_network_output, dtype=torch.float)
                                true_output_pt = torch.tensor(true_output,
                                                              dtype=torch.long)

                                # update the totals dictionaries to reflect the new batch of data
                                n_total_expected = n_total_prev.copy()
                                n_correct_expected = n_correct_prev.copy()
                                for to in true_output:
                                    n_total_expected[to] += 1
                                # compute how many the network got right, and update the necessary output
                                # (indices_to_modify is a subset of range(batch_size), so the
                                #  symmetric difference is just the unmodified indices)
                                indices_not_modified = set(
                                    range(batch_size)).symmetric_difference(
                                        set(indices_to_modify))
                                for ii in indices_not_modified:
                                    n_correct_expected[network_output[ii]] += 1

                                # update the true & fake_network_output to aggregate both the previous call to _eval_acc
                                # and the current call to _eval_acc
                                true_output_prev_and_cur = list(true_output)
                                network_output_prev_and_cur = list(
                                    network_output)
                                for k, v in n_total_prev.items():
                                    true_output_prev_and_cur.extend([k] * v)
                                # simulate network outputs to keep the correct & total counts according to the
                                # previously defined dictionaries
                                for k, v in n_correct_prev.items():
                                    num_correct = v
                                    num_incorrect = n_total_prev[
                                        k] - num_correct
                                    network_output_prev_and_cur.extend(
                                        [k] * num_correct)
                                    network_output_prev_and_cur.extend(
                                        [((k + 1) % num_outputs)] *
                                        num_incorrect)
                                expected_balanced_acc = balanced_accuracy_score(
                                    true_output_prev_and_cur,
                                    network_output_prev_and_cur) * 100

                                # NOTE(review): n_total_prev/n_correct_prev are passed directly;
                                # if _running_eval_acc updates them in place, later iterations
                                # start from inflated counts.  The expected values above are
                                # recomputed from the same dicts each iteration, so the
                                # assertions stay self-consistent either way — confirm this
                                # mutation behavior is intended.
                                actual_acc, n_total_actual, n_correct_actual = \
                                    _running_eval_acc(network_output_pt, true_output_pt, n_total=n_total_prev,
                                                      n_correct=n_correct_prev)
                                self.assertAlmostEqual(actual_acc,
                                                       expected_balanced_acc)
                                self.assertEqual(n_total_expected,
                                                 n_total_actual)
                                self.assertEqual(n_correct_expected,
                                                 n_correct_actual)
    def test_eval_binary_one_output_accuracy(self):
        batch_size = 32
        num_outputs = 1
        sigmoid_fn = lambda x: 1. / (1. + np.exp(-x))
        soft_to_hard_fn = lambda x: torch.round(torch.sigmoid(x)).int()

        step = 0.05
        batch_acc_vec = np.arange(0, 1 + step, step)
        for batch_acc in batch_acc_vec:
            true_output = (self.rso.rand(batch_size, num_outputs) * 4) - 2
            true_output_binary = np.expand_dims(np.asarray(
                [0 if x < 0 else 1 for x in true_output], dtype=np.int),
                                                axis=1)

            # now, modify a subset of the netowrk output and make that the "real" output
            network_output = true_output.copy()
            num_indices_to_modify = int(batch_size * (1 - batch_acc))
            num_indices_unmodified = batch_size - num_indices_to_modify
            indices_to_modify = self.rso.choice(range(batch_size),
                                                num_indices_to_modify,
                                                replace=False)

            for ii in indices_to_modify:
                # flip pos to neg, neg to pos
                if network_output[ii][0] >= 0:
                    network_output[ii][0] = network_output[ii][0] - 10
                else:
                    network_output[ii][0] = network_output[ii][0] + 10

            # convert datatypes to what is expected during operation
            network_output_pt = torch.tensor(network_output, dtype=torch.float)
            true_output_pt = torch.tensor(true_output_binary, dtype=torch.long)

            actual_acc, n_total, n_correct = \
                _running_eval_acc(network_output_pt, true_output_pt,
                                  n_total=None, n_correct=None,
                                  soft_to_hard_fn=soft_to_hard_fn)

            expected_n_total = defaultdict(int)
            for ii in range(len(true_output_binary)):
                to = true_output_binary[ii][0]
                expected_n_total[to] += 1
            expected_n_correct = defaultdict(int)
            for ii in range(batch_size):
                expected_n_correct[true_output_binary[ii][0]] += int(
                    np.round(sigmoid_fn(network_output[ii][0])) ==
                    true_output_binary[ii][0])
            expected_acc = balanced_accuracy_score(
                true_output_binary, np.round(sigmoid_fn(network_output))) * 100
            self.assertAlmostEqual(actual_acc, expected_acc)
            self.assertEqual(n_total, expected_n_total)
            self.assertEqual(n_correct, expected_n_correct)

        n_total_prev = defaultdict(int, {0: 40, 1: 24})
        val_ranges = [10, 15, 20]

        # test over a large variety of the number of correct predictions, per class
        for zero_val in tqdm(val_ranges):
            for one_val in val_ranges:
                n_correct_prev = defaultdict(int, {0: zero_val, 1: one_val})

                # in order to compute the actual accuracy we should expect, we update the dictionaries with
                # the network simulation, convert to a list, and then use sklearn to get the baseline
                # accuracy that our default optimizer should be reporting.
                # We then input the same information into the _eval_acc in the
                # required format (i.e. cur_batch information, and all accumulated prev_information),
                # and check whether they match

                step = 0.1
                batch_acc_vec = np.arange(0, 1 + step, step)
                for batch_acc in batch_acc_vec:
                    true_output = (self.rso.rand(batch_size, num_outputs) *
                                   4) - 2
                    true_output_binary = np.expand_dims(
                        np.asarray([0 if x < 0 else 1 for x in true_output],
                                   dtype=np.int),
                        axis=1)

                    # now, modify a subset of the netowrk output and make that the "real" output
                    network_output = true_output.copy()
                    num_indices_to_modify = int(batch_size * (1 - batch_acc))
                    num_indices_unmodified = batch_size - num_indices_to_modify
                    indices_to_modify = self.rso.choice(range(batch_size),
                                                        num_indices_to_modify,
                                                        replace=False)

                    for ii in indices_to_modify:
                        # flip pos to neg, neg to pos
                        if network_output[ii][0] >= 0:
                            network_output[ii][0] = network_output[ii][0] - 10
                        else:
                            network_output[ii][0] = network_output[ii][0] + 10

                    # convert datatypes to what is expected during operation
                    network_output_pt = torch.tensor(network_output,
                                                     dtype=torch.float)
                    true_output_pt = torch.tensor(true_output_binary,
                                                  dtype=torch.long)

                    n_total_expected = n_total_prev.copy()
                    n_correct_expected = n_correct_prev.copy()
                    for to in np.squeeze(true_output_binary):
                        n_total_expected[to] += 1
                    # compute how many the network got right, and update the necessary output
                    indices_not_modified = set(
                        range(batch_size)).symmetric_difference(
                            set(indices_to_modify))
                    for ii in indices_not_modified:
                        n_correct_expected[np.round(
                            sigmoid_fn(network_output[ii][0]))] += 1

                    true_output_prev_and_cur = []
                    for k, v, in n_total_expected.items():
                        true_output_prev_and_cur.extend([k] * v)
                    # simulate network outputs to keep the correct & total counts according to the
                    # previously defined dictionaries
                    network_output_prev_and_cur = []
                    for k, v in n_correct_expected.items():
                        num_correct = v
                        num_incorrect = n_total_expected[k] - num_correct
                        network_output_prev_and_cur.extend([k] * num_correct)
                        network_output_prev_and_cur.extend(
                            [((k + 1) % 2)] *
                            num_incorrect)  # hard-code mod to 2,
                        # b/c it is binary output and 1% 1 = 0, 0 % 1 = 0

                    expected_balanced_acc = balanced_accuracy_score(
                        true_output_prev_and_cur,
                        network_output_prev_and_cur) * 100

                    actual_acc, n_total, n_correct = \
                        _running_eval_acc(network_output_pt, true_output_pt,
                                          n_total=n_total_prev, n_correct=n_correct_prev,
                                          soft_to_hard_fn=soft_to_hard_fn)

                    self.assertAlmostEqual(actual_acc, expected_balanced_acc)
                    self.assertEqual(n_total_expected, n_total)
                    self.assertEqual(n_correct_expected, n_correct)