def train_helper(model: torchvision.models.resnet.ResNet,
                 dataloaders: Dict[str, torch.utils.data.DataLoader],
                 dataset_sizes: Dict[str, int],
                 criterion: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 scheduler: torch.optim.lr_scheduler._LRScheduler,
                 num_epochs: int, writer: IO, train_order_writer: IO,
                 device: torch.device, start_epoch: int, batch_size: int,
                 save_interval: int, checkpoints_folder: Path, num_layers: int,
                 classes: List[str], num_classes: int) -> None:
    """
    Function for training ResNet.
    Args:
        model: ResNet model for training.
        dataloaders: Dataloaders for IO pipeline.
        dataset_sizes: Sizes of the training and validation dataset.
        criterion: Metric used for calculating loss.
        optimizer: Optimizer to use for gradient descent.
        scheduler: Scheduler to use for learning rate decay.
        start_epoch: Starting epoch for training.
        writer: Writer to write logging information.
        train_order_writer: Writer to write the order of training examples.
        device: Device to use for running model.
        num_epochs: Total number of epochs to train for.
        batch_size: Mini-batch size to use for training.
        save_interval: Number of epochs between saving checkpoints.
        checkpoints_folder: Directory to save model checkpoints to.
        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
        classes: Names of the classes in the dataset.
        num_classes: Number of classes in the dataset.
    """
    since = time.time()

    # Initialize all the tensors to be used in training and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    global_minibatch_counter = 0

    # Train for specified number of epochs.
    for epoch in range(start_epoch, num_epochs):

        # Training phase.
        model.train(mode=True)

        train_running_loss = 0.0
        train_running_corrects = 0
        epoch_minibatch_counter = 0

        # Train over all training data.
        for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]):

            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                __, train_preds = torch.max(train_outputs, dim=1)
                train_loss = criterion(input=train_outputs,
                                       target=train_labels)
                train_loss.backward()
                optimizer.step()

            # Update training diagnostics.
            train_running_loss += train_loss.item() * train_inputs.size(0)
            train_running_corrects += torch.sum(
                train_preds == train_labels.data, dtype=torch.double)

            # Use the actual size of this batch so a smaller final batch does
            # not cause a shape mismatch in the slice assignment below.
            this_batch_size = train_labels.detach().cpu().shape[0]
            start = idx * batch_size
            end = start + this_batch_size

            train_all_labels[start:end] = train_labels.detach().cpu()
            train_all_predicts[start:end] = train_preds.detach().cpu()

            global_minibatch_counter += 1
            epoch_minibatch_counter += 1

            # for path in paths: #write the order that the model was trained in
            #     train_order_writer.write("/".join(path.split("/")[-2:]) + "\n")

            if global_minibatch_counter % 10 == 0 or global_minibatch_counter == 5:

                calculate_confusion_matrix(
                    all_labels=train_all_labels.numpy(),
                    all_predicts=train_all_predicts.numpy(),
                    classes=classes,
                    num_classes=num_classes)

                # Store training diagnostics.
                train_loss = train_running_loss / (epoch_minibatch_counter *
                                                   batch_size)
                train_acc = train_running_corrects / (epoch_minibatch_counter *
                                                      batch_size)

                # Validation phase.
                model.train(mode=False)

                val_running_loss = 0.0
                val_running_corrects = 0

                # Feed forward over all the validation data.
                for idx, (val_inputs, val_labels,
                          paths) in enumerate(dataloaders["val"]):
                    val_inputs = val_inputs.to(device=device)
                    val_labels = val_labels.to(device=device)

                    # Feed forward.
                    with torch.set_grad_enabled(mode=False):
                        val_outputs = model(val_inputs)
                        _, val_preds = torch.max(val_outputs, dim=1)
                        val_loss = criterion(input=val_outputs,
                                             target=val_labels)

                    # Update validation diagnostics.
                    val_running_loss += val_loss.item() * val_inputs.size(0)
                    val_running_corrects += torch.sum(
                        val_preds == val_labels.data, dtype=torch.double)

                    # As above, use the actual batch size so the last batch fits.
                    this_batch_size = val_labels.detach().cpu().shape[0]
                    start = idx * batch_size
                    end = start + this_batch_size

                    val_all_labels[start:end] = val_labels.detach().cpu()
                    val_all_predicts[start:end] = val_preds.detach().cpu()

                calculate_confusion_matrix(
                    all_labels=val_all_labels.numpy(),
                    all_predicts=val_all_predicts.numpy(),
                    classes=classes,
                    num_classes=num_classes)

                # Store validation diagnostics.
                val_loss = val_running_loss / dataset_sizes["val"]
                val_acc = val_running_corrects / dataset_sizes["val"]

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Remaining things related to training.
                if global_minibatch_counter % 10 == 0 or global_minibatch_counter == 5:
                    epoch_output_path = checkpoints_folder.joinpath(
                        f"resnet{num_layers}_e{epoch}_mb{global_minibatch_counter}_va{val_acc:.5f}.pt"
                    )

                    # Confirm the output directory exists.
                    epoch_output_path.parent.mkdir(parents=True, exist_ok=True)

                    # Save the model as a state dictionary.
                    torch.save(obj={
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "epoch": epoch + 1
                    },
                               f=str(epoch_output_path))

                writer.write(
                    f"{epoch},{global_minibatch_counter},{train_loss:.4f},"
                    f"{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n")

                current_lr = None
                for group in optimizer.param_groups:
                    current_lr = group["lr"]

                # Print the diagnostics for this validation checkpoint.
                print(f"Epoch {epoch} with "
                      f"mb {global_minibatch_counter} "
                      f"lr {current_lr:.15f}: "
                      f"t_loss: {train_loss:.4f} "
                      f"t_acc: {train_acc:.4f} "
                      f"v_loss: {val_loss:.4f} "
                      f"v_acc: {val_acc:.4f}\n")

                # Return the model to training mode before resuming the
                # training loop; otherwise it would stay in evaluation mode
                # after the first validation pass.
                model.train(mode=True)

        scheduler.step()

        current_lr = None
        for group in optimizer.param_groups:
            current_lr = group["lr"]

    # Print training information at the end.
    print(f"\ntraining complete in "
          f"{(time.time() - since) // 60:.2f} minutes")
Example #2
def train_helper_with_gradients_no_update(
        model: torchvision.models.resnet.ResNet,
        dataloaders: Dict[str, torch.utils.data.DataLoader],
        dataset_sizes: Dict[str, int], criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        num_epochs: int, writer: IO, train_order_writer: IO,
        device: torch.device, start_epoch: int, batch_size: int,
        save_interval: int, checkpoints_folder: Path, num_layers: int,
        classes: List[str], num_classes: int, grad_csv: Path) -> None:
    """
    Run forward/backward passes over the training data without applying
    optimizer updates, logging per-example loss, confidence, and per-layer
    gradient magnitudes to grad_csv. Assumes a training batch size of one.
    """
    since = time.time()

    # Initialize all the tensors to be used in training and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    global_minibatch_counter = 0

    mag_writer = open(str(grad_csv), "w")
    mag_writer.write(
        "image_name,train_loss,layers_-1,layer_0,layer_60,layer_1,layer_20,layer_40,layer_59,conf,correct\n"
    )

    # Train for specified number of epochs.
    for epoch in range(0, num_epochs):

        # Training phase.
        model.train(mode=True)

        train_running_loss = 0.0
        train_running_corrects = 0
        epoch_minibatch_counter = 0

        # Train over all training data.
        for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]):
            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                confs, train_preds = torch.max(train_outputs, dim=1)
                train_loss = criterion(input=train_outputs,
                                       target=train_labels)
                train_loss.backward(retain_graph=True)
                # optimizer.step()

                # batch_grads = torch.autograd.grad(train_loss, model.parameters(), retain_graph=True)
                # print(len(batch_grads))
                # for batch_grad in batch_grads:
                #     print(batch_grad.size())

                train_loss_npy = float(train_loss.detach().cpu().numpy())
                layer_num_to_mag = get_grad_magnitude(model)
                image_name = get_image_name(paths[0])
                conf = float(confs.detach().cpu().numpy())
                train_pred = int(train_preds.detach().cpu().numpy()[0])
                gt_label = int(train_labels.detach().cpu().numpy()[0])
                correct = 0
                if train_pred == gt_label:
                    correct = 1

                output_line = f"{image_name},{train_loss_npy:.4f},{layer_num_to_mag[-1]:.4f},{layer_num_to_mag[0]:.4f},{layer_num_to_mag[60]:.4f},{layer_num_to_mag[1]:.4f},{layer_num_to_mag[20]:.4f},{layer_num_to_mag[40]:.4f},{layer_num_to_mag[59]:.4f},{conf:.4f},{correct}\n"
                mag_writer.write(output_line)
                print(idx, output_line)
                # print(idx, image_name, train_loss_npy, conf, train_pred, gt_label)

            # Update training diagnostics.
            train_running_loss += train_loss.item() * train_inputs.size(0)
            train_running_corrects += torch.sum(
                train_preds == train_labels.data, dtype=torch.double)

            start = idx * batch_size
            end = start + batch_size

            train_all_labels[start:end] = train_labels.detach().cpu()
            train_all_predicts[start:end] = train_preds.detach().cpu()

            global_minibatch_counter += 1
            epoch_minibatch_counter += 1

            # if global_minibatch_counter % 1000 == 0:

            #     calculate_confusion_matrix(all_labels=train_all_labels.numpy(),
            #                             all_predicts=train_all_predicts.numpy(),
            #                             classes=classes,
            #                             num_classes=num_classes)

            #     # Store training diagnostics.
            #     train_loss = train_running_loss / (epoch_minibatch_counter * batch_size)
            #     train_acc = train_running_corrects / (epoch_minibatch_counter * batch_size)

            #     # Validation phase.
            #     model.train(mode=False)

            #     val_running_loss = 0.0
            #     val_running_corrects = 0

            #     # Feed forward over all the validation data.
            #     for idx, (val_inputs, val_labels, paths) in enumerate(dataloaders["val"]):
            #         val_inputs = val_inputs.to(device=device)
            #         val_labels = val_labels.to(device=device)

            #         # Feed forward.
            #         with torch.set_grad_enabled(mode=False):
            #             val_outputs = model(val_inputs)
            #             _, val_preds = torch.max(val_outputs, dim=1)
            #             val_loss = criterion(input=val_outputs, target=val_labels)

            #         # Update validation diagnostics.
            #         val_running_loss += val_loss.item() * val_inputs.size(0)
            #         val_running_corrects += torch.sum(val_preds == val_labels.data,
            #                                         dtype=torch.double)

            #         start = idx * batch_size
            #         end = start + batch_size

            #         val_all_labels[start:end] = val_labels.detach().cpu()
            #         val_all_predicts[start:end] = val_preds.detach().cpu()

            #     calculate_confusion_matrix(all_labels=val_all_labels.numpy(),
            #                             all_predicts=val_all_predicts.numpy(),
            #                             classes=classes,
            #                             num_classes=num_classes)

            #     # Store validation diagnostics.
            #     val_loss = val_running_loss / dataset_sizes["val"]
            #     val_acc = val_running_corrects / dataset_sizes["val"]

            #     if torch.cuda.is_available():
            #         torch.cuda.empty_cache()

            # Remaining things related to training.
            # if global_minibatch_counter % 200000 == 0 or global_minibatch_counter == 5:
            #     epoch_output_path = checkpoints_folder.joinpath(
            #         f"resnet{num_layers}_e{epoch}_mb{global_minibatch_counter}_va{val_acc:.5f}.pt")

            #     # Confirm the output directory exists.
            #     epoch_output_path.parent.mkdir(parents=True, exist_ok=True)

            #     # Save the model as a state dictionary.
            #     torch.save(obj={
            #         "model_state_dict": model.state_dict(),
            #         "optimizer_state_dict": optimizer.state_dict(),
            #         "scheduler_state_dict": scheduler.state_dict(),
            #         "epoch": epoch + 1
            #     }, f=str(epoch_output_path))

            # writer.write(f"{epoch},{global_minibatch_counter},{train_loss:.4f},"
            #             f"{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n")

            # current_lr = None
            # for group in optimizer.param_groups:
            #     current_lr = group["lr"]

            # # Print the diagnostics for each epoch.
            # print(f"Epoch {epoch} with "
            #     f"mb {global_minibatch_counter} "
            #     f"lr {current_lr:.15f}: "
            #     f"t_loss: {train_loss:.4f} "
            #     f"t_acc: {train_acc:.4f} "
            #     f"v_loss: {val_loss:.4f} "
            #     f"v_acc: {val_acc:.4f}\n")

        scheduler.step()

        current_lr = None
        for group in optimizer.param_groups:
            current_lr = group["lr"]

    # Close the gradient-magnitude CSV and print training information at the end.
    mag_writer.close()
    print(f"\ntraining complete in "
          f"{(time.time() - since) // 60:.2f} minutes")
Example #3
    def _train_helper(self, model: torchvision.models.resnet.ResNet,
                      dataloaders: Dict[str, torch.utils.data.DataLoader],
                      dataset_sizes: Dict[str, int], loss_fn,
                      optimizer: torch.optim.Optimizer,
                      scheduler: torch.optim.lr_scheduler._LRScheduler,
                      start_epoch: int, writer: IO) -> None:
        """
        Function for learning ResNet.

        Args:
            model: ResNet model for learning.
            dataloaders: Dataloaders for IO pipeline.
            dataset_sizes: Sizes of the learning and validation dataset.
            loss_fn: Metric used for calculating loss.
            optimizer: Optimizer to use for gradient descent.
            scheduler: Scheduler to use for learning rate decay.
            start_epoch: Starting epoch for learning.
            writer: Writer to write logging information.
        """
        learning_init_time = time.time()

        # Initialize all the tensors to be used in training and validation.
        # Do this outside the loop since it will be written over entirely at each
        # epoch and doesn't need to be reallocated each time.
        train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                       dtype=torch.long).cpu()
        train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                         dtype=torch.long).cpu()
        val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                     dtype=torch.long).cpu()
        val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                       dtype=torch.long).cpu()
        early_stopper = EarlyStopper(patience=self._early_stopping_patience,
                                     mode=EarlyStopper.Mode.MAX)

        if self._resume_checkpoint and self._last_val_acc:
            best_val_acc = self._last_val_acc
        else:
            best_val_acc = 0.

        # Train for specified number of epochs.
        for epoch in range(start_epoch, self._num_epochs):
            epoch_init_time = time.time()

            # Training phase.
            model.train(mode=True)

            train_running_loss = 0.0
            train_running_corrects = 0

            # Train over all training data.
            for idx, (train_inputs,
                      true_labels) in enumerate(dataloaders["train"]):
                train_patches = train_inputs["patch"].to(device=self._device)
                train_x_coord = train_inputs["x_coord"].to(device=self._device)
                train_y_coord = train_inputs["y_coord"].to(device=self._device)
                true_labels = true_labels.to(device=self._device)
                optimizer.zero_grad()

                # Forward and backpropagation.
                with torch.set_grad_enabled(mode=True):
                    train_logits = model(train_patches, train_x_coord,
                                         train_y_coord).squeeze(dim=1)
                    train_loss = loss_fn(logits=train_logits,
                                         target=true_labels)
                    train_loss.backward()
                    optimizer.step()

                # Update training diagnostics.
                train_running_loss += train_loss.item() * train_patches.size(0)
                pred_labels = self._extract_pred_labels(train_logits)
                train_running_corrects += torch.sum(
                    pred_labels == true_labels.data, dtype=torch.double)

                start = idx * self._batch_size
                end = start + self._batch_size

                train_all_labels[start:end] = true_labels.detach().cpu()
                train_all_predicts[start:end] = pred_labels.detach().cpu()

            self._calculate_confusion_matrix(
                all_labels=train_all_labels.numpy(),
                all_predicts=train_all_predicts.numpy(),
                classes=self._classes,
                num_classes=self._num_classes)

            # Store training diagnostics.
            train_loss = train_running_loss / dataset_sizes["train"]
            train_acc = train_running_corrects / dataset_sizes["train"]

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Validation phase.
            model.train(mode=False)

            val_running_loss = 0.0
            val_running_corrects = 0

            # Feed forward over all the validation data.
            for idx, (val_inputs, val_labels) in enumerate(dataloaders["val"]):
                val_patches = val_inputs["patch"].to(device=self._device)
                val_x_coord = val_inputs["x_coord"].to(device=self._device)
                val_y_coord = val_inputs["y_coord"].to(device=self._device)
                val_labels = val_labels.to(device=self._device)

                # Feed forward.
                with torch.set_grad_enabled(mode=False):
                    val_logits = model(val_patches, val_x_coord,
                                       val_y_coord).squeeze(dim=1)
                    val_loss = loss_fn(logits=val_logits, target=val_labels)

                # Update validation diagnostics.
                val_running_loss += val_loss.item() * val_patches.size(0)
                pred_labels = self._extract_pred_labels(val_logits)
                val_running_corrects += torch.sum(
                    pred_labels == val_labels.data, dtype=torch.double)

                start = idx * self._batch_size
                end = start + self._batch_size

                val_all_labels[start:end] = val_labels.detach().cpu()
                val_all_predicts[start:end] = pred_labels.detach().cpu()

            self._calculate_confusion_matrix(
                all_labels=val_all_labels.numpy(),
                all_predicts=val_all_predicts.numpy(),
                classes=self._classes,
                num_classes=self._num_classes)

            # Store validation diagnostics.
            val_loss = val_running_loss / dataset_sizes["val"]
            val_acc = val_running_corrects / dataset_sizes["val"]

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            scheduler.step()

            current_lr = None
            for group in optimizer.param_groups:
                current_lr = group["lr"]

            # Remaining things related to training.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_model_ckpt_path = self._checkpoints_folder.joinpath(
                    f"resnet{self._num_layers}_e{epoch}_va{val_acc:.5f}.pt")

                # Confirm the output directory exists.
                best_model_ckpt_path.parent.mkdir(parents=True, exist_ok=True)

                # Save the model as a state dictionary.
                torch.save(obj={
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "epoch": epoch + 1
                },
                           f=str(best_model_ckpt_path))

                self._clean_ckpt_folder(best_model_ckpt_path)

            writer.write(f"{epoch},{train_loss:.4f},"
                         f"{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n")

            # Print the diagnostics for each epoch.
            logging.info(
                f"Epoch {epoch} "
                f"with lr {current_lr:.15f}: "
                f"{self._format_time_period(epoch_init_time, time.time())} "
                f"t_loss: {train_loss:.4f} "
                f"t_acc: {train_acc:.4f} "
                f"v_loss: {val_loss:.4f} "
                f"v_acc: {val_acc:.4f}\n")

            early_stopper.update(val_acc)
            if early_stopper.is_stopping():
                logging.info("Early stopping")
                break

        # Print training information at the end.
        logging.info(
            f"\ntraining complete in "
            f"{self._format_time_period(learning_init_time, time.time())}")
Example #4
def compute_resnet_grad_no_update_helper(
        model: torchvision.models.resnet.ResNet,
        dataloaders: Dict[str, torch.utils.data.DataLoader],
        dataset_sizes: Dict[str, int], criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        num_epochs: int, log_writer: IO, train_order_writer: IO,
        device: torch.device, batch_size: int, checkpoints_folder: Path,
        num_layers: int, classes: List[str], num_classes: int):
    """
    Compute per-example losses and gradient magnitudes over the training data
    without updating the model, and return a list of
    (index, image_name, gradient magnitude) tuples collected over the final
    epoch. Assumes a training batch size of one.
    """

    global_minibatch_counter = 0
    # Initialize all the tensors to be used in training and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    # grad_writer = grad_csv.open(mode="w")
    # grad_writer.write("image_name,train_loss,layers_-1,layer_0,layer_60,layer_1,layer_20,layer_40,layer_59,conf,correct\n")

    for epoch in range(1, num_epochs + 1):

        model.train(mode=True)  # Training phase.
        train_running_loss, train_running_corrects, epoch_minibatch_counter = 0.0, 0, 0

        tup_list = []

        for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]):
            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                confs, train_preds = torch.max(train_outputs, dim=1)
                train_loss = criterion(input=train_outputs,
                                       target=train_labels)
                train_loss.backward(retain_graph=True)
                # optimizer.step()

                train_loss_npy = float(train_loss.detach().cpu().numpy())
                layer_num_to_mag = get_grad_magnitude(model)
                image_name = get_image_name(paths[0])
                conf = float(confs.detach().cpu().numpy())
                train_pred = int(train_preds.detach().cpu().numpy()[0])
                gt_label = int(train_labels.detach().cpu().numpy()[0])
                correct = 0
                if train_pred == gt_label:
                    correct = 1

                output_line = f"{image_name},{train_loss_npy:.4f},{layer_num_to_mag[-1]:.4f},{layer_num_to_mag[0]:.4f},{layer_num_to_mag[60]:.4f},{layer_num_to_mag[1]:.4f},{layer_num_to_mag[20]:.4f},{layer_num_to_mag[40]:.4f},{layer_num_to_mag[59]:.4f},{conf:.4f},{correct}\n"
                # grad_writer.write(output_line)
                tup = (idx, image_name, layer_num_to_mag[-1])
                tup_list.append(tup)

                if idx % 1000 == 0:
                    print(tup)

    return tup_list
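
# ---------------------------------------------------------------------------
# One plausible way to consume the (index, image_name, gradient magnitude)
# tuples returned above: rank the training examples by gradient magnitude.
# This is an illustrative sketch, not part of the original pipeline.
# ---------------------------------------------------------------------------
def _rank_examples_by_grad_magnitude(tup_list):
    """Return the tuples sorted from largest to smallest gradient magnitude."""
    return sorted(tup_list, key=lambda tup: tup[2], reverse=True)
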
Example #5
def train_helper(model: torchvision.models.resnet.ResNet,
                 dataloaders: Dict[str, torch.utils.data.DataLoader],
                 dataset_sizes: Dict[str, int],
                 criterion: torch.nn.Module,
                 optimizer: torch.optim.Optimizer,
                 scheduler: torch.optim.lr_scheduler._LRScheduler,
                 num_epochs: int, log_writer: IO, train_order_writer: IO,
                 device: torch.device, batch_size: int,
                 checkpoints_folder: Path, num_layers: int, classes: List[str],
                 minibatch_counter: int, num_classes: int):
    """
    Train the model for num_epochs epochs, validating and checkpointing once
    per epoch, and return the path of the last checkpoint together with the
    updated global minibatch counter.
    """

    since = time.time()
    global_minibatch_counter = minibatch_counter
    # Initialize all the tensors to be used in training and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    for epoch in range(1, num_epochs + 1):

        model.train(mode=True)  # Training phase.
        train_running_loss, train_running_corrects, epoch_minibatch_counter = 0.0, 0, 0

        for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]):
            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                __, train_preds = torch.max(train_outputs, dim=1)
                train_loss = criterion(input=train_outputs,
                                       target=train_labels)
                train_loss.backward()
                optimizer.step()

            # Update training diagnostics.
            train_running_loss += train_loss.item() * train_inputs.size(0)
            train_running_corrects += torch.sum(
                train_preds == train_labels.data, dtype=torch.double)

            this_batch_size = train_labels.detach().cpu().shape[0]
            start = idx * batch_size
            end = start + this_batch_size
            train_all_labels[start:end] = train_labels.detach().cpu()
            train_all_predicts[start:end] = train_preds.detach().cpu()

            global_minibatch_counter += 1
            epoch_minibatch_counter += 1

        # Calculate training diagnostics
        calculate_confusion_matrix(all_labels=train_all_labels.numpy(),
                                   all_predicts=train_all_predicts.numpy(),
                                   classes=classes,
                                   num_classes=num_classes)
        train_loss = train_running_loss / (epoch_minibatch_counter *
                                           batch_size)
        train_acc = train_running_corrects / (epoch_minibatch_counter *
                                              batch_size)

        # Validation phase.
        model.train(mode=False)
        val_running_loss = 0.0
        val_running_corrects = 0

        # Feed forward over all the validation data.
        for idx, (val_inputs, val_labels,
                  paths) in enumerate(dataloaders["val"]):
            val_inputs = val_inputs.to(device=device)
            val_labels = val_labels.to(device=device)

            # Feed forward.
            with torch.set_grad_enabled(mode=False):
                val_outputs = model(val_inputs)
                _, val_preds = torch.max(val_outputs, dim=1)
                val_loss = criterion(input=val_outputs, target=val_labels)

            # Update validation diagnostics.
            val_running_loss += val_loss.item() * val_inputs.size(0)
            val_running_corrects += torch.sum(val_preds == val_labels.data,
                                              dtype=torch.double)

            this_batch_size = val_labels.detach().cpu().shape[0]
            start = idx * batch_size
            end = start + this_batch_size
            val_all_labels[start:end] = val_labels.detach().cpu()
            val_all_predicts[start:end] = val_preds.detach().cpu()

        # Calculate validation diagnostics
        calculate_confusion_matrix(all_labels=val_all_labels.numpy(),
                                   all_predicts=val_all_predicts.numpy(),
                                   classes=classes,
                                   num_classes=num_classes)
        val_loss = val_running_loss / dataset_sizes["val"]
        val_acc = val_running_corrects / dataset_sizes["val"]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Remaining things related to training.

        epoch_output_path = checkpoints_folder.joinpath(
            f"resnet{num_layers}_e{epoch}_mb{global_minibatch_counter}_va{val_acc:.5f}.pt"
        )
        epoch_output_path.parent.mkdir(parents=True, exist_ok=True)

        # Save the model as a state dictionary.
        torch.save(obj={
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch": epoch + 1
        },
                   f=str(epoch_output_path))

        log_writer.write(
            f"{epoch},{global_minibatch_counter},{train_loss:.4f},{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n"
        )

        current_lr = None
        for group in optimizer.param_groups:
            current_lr = group["lr"]

        # Print the diagnostics for each epoch.
        print(f"Epoch {epoch} with "
              f"mb {global_minibatch_counter} "
              f"lr {current_lr:.15f}: "
              f"t_loss: {train_loss:.4f} "
              f"t_acc: {train_acc:.4f} "
              f"v_loss: {val_loss:.4f} "
              f"v_acc: {val_acc:.4f}\n")

        scheduler.step()

        current_lr = None
        for group in optimizer.param_groups:
            current_lr = group["lr"]

    # Print training information at the end.
    print(f"\ntraining complete in "
          f"{(time.time() - since) // 60:.2f} minutes")

    return epoch_output_path, global_minibatch_counter
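
# ---------------------------------------------------------------------------
# This variant of train_helper is designed to be called repeatedly: the
# minibatch_counter it returns can be fed back into the next call so the
# checkpoint names keep a single running minibatch count. A hypothetical
# sketch of that pattern (the loaders, optimizer and other setup are assumed
# to exist):
#
#     minibatch_counter = 0
#     for stage_loaders in stage_dataloader_dicts:
#         checkpoint_path, minibatch_counter = train_helper(
#             model=model, dataloaders=stage_loaders, dataset_sizes=sizes,
#             criterion=criterion, optimizer=optimizer, scheduler=scheduler,
#             num_epochs=1, log_writer=log_writer, train_order_writer=None,
#             device=device, batch_size=batch_size,
#             checkpoints_folder=checkpoints_folder, num_layers=18,
#             classes=classes, minibatch_counter=minibatch_counter,
#             num_classes=len(classes))
# ---------------------------------------------------------------------------
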
def train_smartgrad_helper(model: torchvision.models.resnet.ResNet,
                           dataloaders: Dict[str, torch.utils.data.DataLoader],
                           dataset_sizes: Dict[str, int],
                           criterion: torch.nn.Module,
                           optimizer: torch.optim.Optimizer,
                           scheduler: torch.optim.lr_scheduler._LRScheduler,
                           num_epochs: int,
                           log_writer: IO,
                           train_order_writer: IO,
                           device: torch.device,
                           train_batch_size: int,
                           val_batch_size: int,
                           fake_minibatch_size: int,
                           annealling_factor: float,
                           save_mb_interval: int,
                           val_mb_interval: int,
                           checkpoints_folder: Path,
                           num_layers: int,
                           classes: List[str],
                           num_classes: int) -> None:
    """
    Train the model with per-example gradient reweighting: gradients are
    collected for each training example (the real dataloader batch size is
    expected to be one), grouped into "fake" minibatches of
    fake_minibatch_size, reweighted via get_idx_to_weight, and applied with a
    single optimizer step per fake minibatch.
    """

    grad_layers = list(range(1, 21))

    since = time.time()
    global_minibatch_counter = 0
    # Initialize all the tensors to be used in training and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    for epoch in range(1, num_epochs+1):

        # Training phase. Note that, unlike the other helpers, the model is
        # left in eval mode here, so BatchNorm running statistics are not
        # updated and Dropout is disabled while per-example gradients are
        # collected.
        model.train(mode=False)
        train_running_loss, train_running_corrects, epoch_minibatch_counter = 0.0, 0, 0
        idx_to_gt = {}
        
        for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]):
            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                __, train_preds = torch.max(train_outputs, dim=1)
                train_loss = criterion(input=train_outputs, target=train_labels)
                train_loss.backward(retain_graph=True)

                gt_label = int(train_labels.detach().cpu().numpy()[0])
                idx_to_gt[idx] = gt_label

                ########################
                #### important code ####
                ########################

                #clear the memory
                fake_minibatch_idx = idx % fake_minibatch_size
                fake_minibatch_num = int(idx / fake_minibatch_size)
                if fake_minibatch_idx == 0:
                    minibatch_grad_dict = {}; gc.collect()
                
                #get the per-example gradient magnitude and add to minibatch_grad_dict
                grad_as_dict, grad_flattened = model_to_grad_as_dict_and_flatten(model, grad_layers)
                minibatch_grad_dict[idx] = (grad_as_dict, grad_flattened)

                #every batch, calculate the best ones
                if fake_minibatch_idx == fake_minibatch_size - 1:
                    idx_to_weight_batch = get_idx_to_weight(minibatch_grad_dict, annealling_factor, idx_to_gt)
                    print(idx_to_weight_batch)
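                    # The loop below replaces each parameter's .grad with the
                    # combination of per-example gradients that
                    # get_new_layer_grad builds from these weights, so the
                    # single optimizer.step() afterwards applies one
                    # reweighted update for the whole fake minibatch.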

                    ##########################
                    # print("\n...............................updating......................................" + str(idx))
                    for layer_num, param in enumerate(model.parameters()):
                        # if layer_num in [0]:#grad_layers:
                        new_grad = get_new_layer_grad(layer_num, idx_to_weight_batch, minibatch_grad_dict)
                        assert param.grad.detach().cpu().numpy().shape == new_grad.detach().cpu().numpy().shape
                        param.grad = new_grad
                            # check_model_weights(idx, model)
                    optimizer.step()
                    # check_model_weights(idx, model)
                    # print("................................done........................................." + str(idx) + '\n\n\n\n')
                    ##########################

            # Update training diagnostics.
            train_running_loss += train_loss.item() * train_inputs.size(0)
            train_running_corrects += torch.sum(train_preds == train_labels.data, dtype=torch.double)

            start = idx * train_batch_size
            end = start + train_batch_size
            train_all_labels[start:end] = train_labels.detach().cpu()
            train_all_predicts[start:end] = train_preds.detach().cpu()

            global_minibatch_counter += 1
            epoch_minibatch_counter += 1

            # Write the path of training order if it exists
            if train_order_writer:
                for path in paths: #write the order that the model was trained in
                    train_order_writer.write("/".join(path.split("/")[-2:]) + "\n")

            # Validate the model
            if global_minibatch_counter % val_mb_interval == 0 or global_minibatch_counter == 1:

                # Calculate training diagnostics
                calculate_confusion_matrix( all_labels=train_all_labels.numpy(), all_predicts=train_all_predicts.numpy(),
                                            classes=classes, num_classes=num_classes)
                train_loss = train_running_loss / (epoch_minibatch_counter * train_batch_size)
                train_acc = train_running_corrects / (epoch_minibatch_counter * train_batch_size)

                # Validation phase.
                model.train(mode=False)
                val_running_loss = 0.0
                val_running_corrects = 0

                # Feed forward over all the validation data.
                for idx, (val_inputs, val_labels, paths) in enumerate(dataloaders["val"]):
                    val_inputs = val_inputs.to(device=device)
                    val_labels = val_labels.to(device=device)

                    # Feed forward.
                    with torch.set_grad_enabled(mode=False):
                        val_outputs = model(val_inputs)
                        _, val_preds = torch.max(val_outputs, dim=1)
                        val_loss = criterion(input=val_outputs, target=val_labels)

                    # Update validation diagnostics.
                    val_running_loss += val_loss.item() * val_inputs.size(0)
                    val_running_corrects += torch.sum(val_preds == val_labels.data,
                                                    dtype=torch.double)

                    start = idx * val_batch_size
                    end = start + val_batch_size
                    val_all_labels[start:end] = val_labels.detach().cpu()
                    val_all_predicts[start:end] = val_preds.detach().cpu()

                # Calculate validation diagnostics
                calculate_confusion_matrix( all_labels=val_all_labels.numpy(), all_predicts=val_all_predicts.numpy(),
                                            classes=classes, num_classes=num_classes)
                val_loss = val_running_loss / dataset_sizes["val"]
                val_acc = val_running_corrects / dataset_sizes["val"]

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    

                # Remaining things related to training.
                if global_minibatch_counter % save_mb_interval == 0 or global_minibatch_counter == 1:

                    epoch_output_path = checkpoints_folder.joinpath(f"resnet{num_layers}_e{epoch}_mb{global_minibatch_counter}_va{val_acc:.5f}.pt")
                    epoch_output_path.parent.mkdir(parents=True, exist_ok=True)

                    # Save the model as a state dictionary.
                    torch.save(obj={
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "epoch": epoch + 1
                    }, f=str(epoch_output_path))

                log_writer.write(f"{epoch},{global_minibatch_counter},{train_loss:.4f},{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n")

                current_lr = None
                for group in optimizer.param_groups:
                    current_lr = group["lr"]

                # Print the diagnostics for each epoch.
                print(f"Epoch {epoch} with "
                    f"mb {global_minibatch_counter} "
                    f"lr {current_lr:.15f}: "
                    f"t_loss: {train_loss:.4f} "
                    f"t_acc: {train_acc:.4f} "
                    f"v_loss: {val_loss:.4f} "
                    f"v_acc: {val_acc:.4f}\n")

        scheduler.step()

        current_lr = None
        for group in optimizer.param_groups:
            current_lr = group["lr"]