def train_helper(model: torchvision.models.resnet.ResNet,
                 dataloaders: Dict[str, torch.utils.data.DataLoader],
                 dataset_sizes: Dict[str, int],
                 criterion: torch.nn.modules.loss,
                 optimizer: torch.optim,
                 scheduler: torch.optim.lr_scheduler,
                 num_epochs: int,
                 writer: IO,
                 train_order_writer: IO,
                 device: torch.device,
                 start_epoch: int,
                 batch_size: int,
                 save_interval: int,
                 checkpoints_folder: Path,
                 num_layers: int,
                 classes: List[str],
                 num_classes: int) -> None:
    """
    Function for training ResNet.

    The model is trained epoch by epoch. At global minibatch 5 and at every
    10th global minibatch, the training predictions collected so far in the
    current epoch are passed to ``calculate_confusion_matrix``, the model is
    evaluated on the whole validation set, a checkpoint is saved and one CSV
    line of diagnostics is written to ``writer``. Training mode is restored
    after each such evaluation. The learning-rate scheduler is stepped once
    per epoch.

    Args:
        model: ResNet model for training.
        dataloaders: Dataloaders for IO pipeline. Each batch is expected to
            be an ``(inputs, labels, paths)`` triple.
        dataset_sizes: Sizes of the training and validation dataset.
        criterion: Metric used for calculating loss.
        optimizer: Optimizer to use for gradient descent.
        scheduler: Scheduler to use for learning rate decay.
        num_epochs: Total number of epochs to train for.
        writer: Writer to write logging information.
        train_order_writer: Writer to write the order of training examples.
            Currently unused.
        device: Device to use for running model.
        start_epoch: Starting epoch for training.
        batch_size: Mini-batch size the dataloaders were built with; used to
            compute where each batch lands in the prediction buffers.
        save_interval: Number of epochs between saving checkpoints. Currently
            unused; checkpoints follow the minibatch schedule described above.
        checkpoints_folder: Directory to save model checkpoints to.
        num_layers: Number of layers to use in the ResNet model from
            [18, 34, 50, 101, 152].
        classes: Names of the classes in the dataset.
        num_classes: Number of classes in the dataset.
    """
    since = time.time()

    # Initialize all the tensors to be used in training and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    global_minibatch_counter = 0

    # Train for specified number of epochs.
    for epoch in range(start_epoch, num_epochs):
        # Training phase.
        model.train(mode=True)
        train_running_loss = 0.0
        train_running_corrects = 0
        # Samples processed so far this epoch; the correct denominator for
        # the running averages even when a batch is smaller than batch_size.
        train_samples_seen = 0

        # Train over all training data.
        for idx, (inputs, labels, _paths) in enumerate(dataloaders["train"]):
            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                __, train_preds = torch.max(train_outputs, dim=1)
                train_loss = criterion(input=train_outputs,
                                       target=train_labels)
                train_loss.backward()
                optimizer.step()

            # Update training diagnostics.
            this_batch_size = train_inputs.size(0)
            train_running_loss += train_loss.item() * this_batch_size
            train_running_corrects += torch.sum(
                train_preds == train_labels.data, dtype=torch.double)
            start = idx * batch_size
            end = start + this_batch_size
            train_all_labels[start:end] = train_labels.detach().cpu()
            train_all_predicts[start:end] = train_preds.detach().cpu()
            train_samples_seen += this_batch_size
            global_minibatch_counter += 1

            if global_minibatch_counter % 10 == 0 or global_minibatch_counter == 5:
                # Only the first `end` entries have been written this epoch;
                # the rest of the buffer is uninitialized (torch.empty) or
                # left over from the previous epoch, so exclude it.
                calculate_confusion_matrix(
                    all_labels=train_all_labels[:end].numpy(),
                    all_predicts=train_all_predicts[:end].numpy(),
                    classes=classes,
                    num_classes=num_classes)

                # Store training diagnostics.
                avg_train_loss = train_running_loss / train_samples_seen
                train_acc = train_running_corrects / train_samples_seen

                # Validation phase.
                model.train(mode=False)
                val_running_loss = 0.0
                val_running_corrects = 0

                # Feed forward over all the validation data.
                for val_idx, (val_inputs, val_labels,
                              _val_paths) in enumerate(dataloaders["val"]):
                    val_inputs = val_inputs.to(device=device)
                    val_labels = val_labels.to(device=device)

                    # Feed forward.
                    with torch.set_grad_enabled(mode=False):
                        val_outputs = model(val_inputs)
                        _, val_preds = torch.max(val_outputs, dim=1)
                        val_loss = criterion(input=val_outputs,
                                             target=val_labels)

                    # Update validation diagnostics.
                    val_running_loss += val_loss.item() * val_inputs.size(0)
                    val_running_corrects += torch.sum(
                        val_preds == val_labels.data, dtype=torch.double)
                    val_start = val_idx * batch_size
                    val_end = val_start + val_inputs.size(0)
                    val_all_labels[val_start:val_end] = val_labels.detach().cpu()
                    val_all_predicts[val_start:val_end] = val_preds.detach().cpu()

                calculate_confusion_matrix(
                    all_labels=val_all_labels.numpy(),
                    all_predicts=val_all_predicts.numpy(),
                    classes=classes,
                    num_classes=num_classes)

                # Store validation diagnostics.
                avg_val_loss = val_running_loss / dataset_sizes["val"]
                val_acc = val_running_corrects / dataset_sizes["val"]

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Remaining things related to training.
                epoch_output_path = checkpoints_folder.joinpath(
                    f"resnet{num_layers}_e{epoch}_mb{global_minibatch_counter}_va{val_acc:.5f}.pt"
                )

                # Confirm the output directory exists.
                epoch_output_path.parent.mkdir(parents=True, exist_ok=True)

                # Save the model as a state dictionary.
                torch.save(obj={
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "epoch": epoch + 1
                }, f=str(epoch_output_path))

                writer.write(
                    f"{epoch},{global_minibatch_counter},{avg_train_loss:.4f},"
                    f"{train_acc:.4f},{avg_val_loss:.4f},{val_acc:.4f}\n")

                # Learning rate of the last parameter group.
                current_lr = optimizer.param_groups[-1]["lr"]

                # Print the diagnostics for this evaluation.
                print(f"Epoch {epoch} with "
                      f"mb {global_minibatch_counter} "
                      f"lr {current_lr:.15f}: "
                      f"t_loss: {avg_train_loss:.4f} "
                      f"t_acc: {train_acc:.4f} "
                      f"v_loss: {avg_val_loss:.4f} "
                      f"v_acc: {val_acc:.4f}\n")

                # Validation switched the model to eval mode; switch back so
                # the remaining minibatches of this epoch are trained in
                # training mode.
                model.train(mode=True)

        scheduler.step()

    # Print training information at the end.
    print(f"\ntraining complete in "
          f"{(time.time() - since) / 60:.2f} minutes")
def train_helper_with_gradients_no_update(
        model: torchvision.models.resnet.ResNet,
        dataloaders: Dict[str, torch.utils.data.DataLoader],
        dataset_sizes: Dict[str, int],
        criterion: torch.nn.modules.loss,
        optimizer: torch.optim,
        scheduler: torch.optim.lr_scheduler,
        num_epochs: int,
        writer: IO,
        train_order_writer: IO,
        device: torch.device,
        start_epoch: int,
        batch_size: int,
        save_interval: int,
        checkpoints_folder: Path,
        num_layers: int,
        classes: List[str],
        num_classes: int,
        grad_csv: Path) -> None:
    """
    Record per-minibatch loss and gradient magnitudes without updating the model.

    For every training minibatch, the loss is back-propagated but
    ``optimizer.step()`` is never called, so the weights are not changed by
    this function. One CSV row per minibatch is written to ``grad_csv`` and
    echoed to stdout. Each row holds:
    - the image name;
    - the loss;
    - the values ``get_grad_magnitude(model)`` returns for keys -1, 0, 60,
      1, 20, 40 and 59;
    - the largest model output of the sample (column ``conf``);
    - whether its argmax prediction matches the label.

    The scheduler is still stepped once per epoch.

    Args:
        model: ResNet model whose gradients are measured.
        dataloaders: Dataloaders; ``dataloaders["train"]`` must yield
            ``(inputs, labels, paths)`` batches of size 1.
        dataset_sizes: Unused; kept for call-site compatibility.
        criterion: Loss used for back-propagation.
        optimizer: Only used to zero gradients before each minibatch.
        scheduler: Stepped once per epoch.
        num_epochs: Number of passes over the training data. Epochs always
            start at 0; ``start_epoch`` is ignored.
        writer: Unused; kept for call-site compatibility.
        train_order_writer: Unused; kept for call-site compatibility.
        device: Device to run the model on.
        start_epoch: Unused; kept for call-site compatibility.
        batch_size: Unused; kept for call-site compatibility.
        save_interval: Unused; kept for call-site compatibility.
        checkpoints_folder: Unused; kept for call-site compatibility.
        num_layers: Unused; kept for call-site compatibility.
        classes: Unused; kept for call-site compatibility.
        num_classes: Unused; kept for call-site compatibility.
        grad_csv: Output CSV path (overwritten).

    Raises:
        ValueError: If a training minibatch does not contain exactly one
            sample (each CSV row describes a single example).
    """
    since = time.time()

    # `with` guarantees the CSV is flushed and closed, even on error.
    with open(str(grad_csv), "w") as mag_writer:
        mag_writer.write(
            "image_name,train_loss,layers_-1,layer_0,layer_60,layer_1,layer_20,layer_40,layer_59,conf,correct\n"
        )

        for epoch in range(0, num_epochs):
            model.train(mode=True)

            for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]):
                train_inputs = inputs.to(device=device)
                train_labels = labels.to(device=device)
                if train_inputs.size(0) != 1:
                    raise ValueError(
                        "train_helper_with_gradients_no_update expects a "
                        f"training batch size of 1, got {train_inputs.size(0)}")

                # Clear old gradients so the magnitudes belong to this
                # minibatch only.
                optimizer.zero_grad()

                # Forward and backpropagation; optimizer.step() is
                # intentionally not called. Only one backward pass is made
                # per graph, so the graph need not be retained.
                with torch.set_grad_enabled(mode=True):
                    train_outputs = model(train_inputs)
                    confs, train_preds = torch.max(train_outputs, dim=1)
                    train_loss = criterion(input=train_outputs,
                                           target=train_labels)
                    train_loss.backward()

                layer_num_to_mag = get_grad_magnitude(model)
                image_name = get_image_name(paths[0])
                train_loss_value = train_loss.item()
                conf = confs[0].item()
                correct = int(train_preds[0].item() == train_labels[0].item())

                output_line = (
                    f"{image_name},{train_loss_value:.4f},"
                    f"{layer_num_to_mag[-1]:.4f},{layer_num_to_mag[0]:.4f},"
                    f"{layer_num_to_mag[60]:.4f},{layer_num_to_mag[1]:.4f},"
                    f"{layer_num_to_mag[20]:.4f},{layer_num_to_mag[40]:.4f},"
                    f"{layer_num_to_mag[59]:.4f},{conf:.4f},{correct}\n")
                mag_writer.write(output_line)
                print(idx, output_line)

            scheduler.step()

    # Print training information at the end.
    print(f"\ntraining complete in "
          f"{(time.time() - since) / 60:.2f} minutes")
def _train_helper(self,
                  model: torchvision.models.resnet.ResNet,
                  dataloaders: Dict[str, torch.utils.data.DataLoader],
                  dataset_sizes: Dict[str, int],
                  loss_fn,
                  optimizer: torch.optim,
                  scheduler: torch.optim.lr_scheduler,
                  start_epoch: int,
                  writer: IO) -> None:
    """
    Function for learning ResNet.

    Trains for ``self._num_epochs`` epochs starting at ``start_epoch``. Each
    epoch it:
    - trains on ``dataloaders["train"]``;
    - evaluates on ``dataloaders["val"]``;
    - steps the scheduler;
    - saves a checkpoint whenever validation accuracy beats the best so far;
    - writes one CSV line to ``writer``;
    - logs a summary.

    Stops early when the ``EarlyStopper`` built from
    ``self._early_stopping_patience`` (maximising validation accuracy)
    reports stopping.

    Args:
        model: ResNet model for learning. It is called as
            ``model(patch, x_coord, y_coord)``.
        dataloaders: Dataloaders for IO pipeline. Each batch is a pair
            ``(inputs, labels)`` where ``inputs`` is a mapping with keys
            ``"patch"``, ``"x_coord"`` and ``"y_coord"``.
        dataset_sizes: Sizes of the learning and validation dataset.
        loss_fn: Metric used for calculating loss; called as
            ``loss_fn(logits=..., target=...)``.
        optimizer: Optimizer to use for gradient descent.
        scheduler: Scheduler to use for learning rate decay.
        start_epoch: Starting epoch for learning.
        writer: Writer to write logging information.
    """
    learning_init_time = time.time()

    # Initialize all the tensors to be used in learning and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    early_stopper = EarlyStopper(patience=self._early_stopping_patience,
                                 mode=EarlyStopper.Mode.MAX)

    # When resuming, only checkpoints that beat the previous run's best
    # validation accuracy are saved.
    if self._resume_checkpoint and self._last_val_acc:
        best_val_acc = self._last_val_acc
    else:
        best_val_acc = 0.

    # Train for specified number of epochs.
    for epoch in range(start_epoch, self._num_epochs):
        epoch_init_time = time.time()

        # Training phase.
        model.train(mode=True)
        train_running_loss = 0.0
        train_running_corrects = 0

        # Train over all learning data.
        for idx, (train_inputs, true_labels) in enumerate(dataloaders["train"]):
            train_patches = train_inputs["patch"].to(device=self._device)
            train_x_coord = train_inputs["x_coord"].to(device=self._device)
            train_y_coord = train_inputs["y_coord"].to(device=self._device)
            true_labels = true_labels.to(device=self._device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_logits = model(train_patches, train_x_coord,
                                     train_y_coord).squeeze(dim=1)
                train_loss = loss_fn(logits=train_logits, target=true_labels)
                train_loss.backward()
                optimizer.step()

            # Update learning diagnostics.
            train_running_loss += train_loss.item() * train_patches.size(0)
            pred_labels = self._extract_pred_labels(train_logits)
            train_running_corrects += torch.sum(
                pred_labels == true_labels.data, dtype=torch.double)
            start = idx * self._batch_size
            end = start + self._batch_size
            train_all_labels[start:end] = true_labels.detach().cpu()
            train_all_predicts[start:end] = pred_labels.detach().cpu()

        self._calculate_confusion_matrix(all_labels=train_all_labels.numpy(),
                                         all_predicts=train_all_predicts.numpy(),
                                         classes=self._classes,
                                         num_classes=self._num_classes)

        # Store learning diagnostics.
        train_loss = train_running_loss / dataset_sizes["train"]
        train_acc = train_running_corrects / dataset_sizes["train"]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Validation phase.
        model.train(mode=False)
        val_running_loss = 0.0
        val_running_corrects = 0

        # Feed forward over all the validation data.
        for idx, (val_inputs, val_labels) in enumerate(dataloaders["val"]):
            val_patches = val_inputs["patch"].to(device=self._device)
            val_x_coord = val_inputs["x_coord"].to(device=self._device)
            val_y_coord = val_inputs["y_coord"].to(device=self._device)
            val_labels = val_labels.to(device=self._device)

            # Feed forward.
            with torch.set_grad_enabled(mode=False):
                val_logits = model(val_patches, val_x_coord,
                                   val_y_coord).squeeze(dim=1)
                val_loss = loss_fn(logits=val_logits, target=val_labels)

            # Update validation diagnostics.
            val_running_loss += val_loss.item() * val_patches.size(0)
            pred_labels = self._extract_pred_labels(val_logits)
            val_running_corrects += torch.sum(pred_labels == val_labels.data,
                                              dtype=torch.double)
            start = idx * self._batch_size
            end = start + self._batch_size
            val_all_labels[start:end] = val_labels.detach().cpu()
            val_all_predicts[start:end] = pred_labels.detach().cpu()

        self._calculate_confusion_matrix(all_labels=val_all_labels.numpy(),
                                         all_predicts=val_all_predicts.numpy(),
                                         classes=self._classes,
                                         num_classes=self._num_classes)

        # Store validation diagnostics.
        val_loss = val_running_loss / dataset_sizes["val"]
        val_acc = val_running_corrects / dataset_sizes["val"]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Read the learning rate used during this epoch *before* stepping
        # the scheduler; reading it afterwards would report the next epoch's
        # rate in this epoch's log line. The last parameter group is reported.
        current_lr = optimizer.param_groups[-1]["lr"]
        # Stepped before saving so the checkpointed scheduler state matches
        # the stored "epoch": epoch + 1.
        scheduler.step()

        # Remaining things related to learning.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_ckpt_path = self._checkpoints_folder.joinpath(
                f"resnet{self._num_layers}_e{epoch}_va{val_acc:.5f}.pt")

            # Confirm the output directory exists.
            best_model_ckpt_path.parent.mkdir(parents=True, exist_ok=True)

            # Save the model as a state dictionary.
            torch.save(obj={
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "epoch": epoch + 1
            }, f=str(best_model_ckpt_path))

            self._clean_ckpt_folder(best_model_ckpt_path)

        writer.write(f"{epoch},{train_loss:.4f},"
                     f"{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n")

        # Print the diagnostics for each epoch.
        logging.info(
            f"Epoch {epoch} "
            f"with lr {current_lr:.15f}: "
            f"{self._format_time_period(epoch_init_time, time.time())} "
            f"t_loss: {train_loss:.4f} "
            f"t_acc: {train_acc:.4f} "
            f"v_loss: {val_loss:.4f} "
            f"v_acc: {val_acc:.4f}\n")

        early_stopper.update(val_acc)
        if early_stopper.is_stopping():
            logging.info("Early stopping")
            break

    # Print learning information at the end.
    logging.info(
        f"\nlearning complete in "
        f"{self._format_time_period(learning_init_time, time.time())}")
def compute_resnet_grad_no_update_helper(
        model: torchvision.models.resnet.ResNet,
        dataloaders: Dict[str, torch.utils.data.DataLoader],
        dataset_sizes: Dict[str, int],
        criterion: torch.nn.modules.loss,
        optimizer: torch.optim,
        scheduler: torch.optim.lr_scheduler,
        num_epochs: int,
        log_writer: IO,
        train_order_writer: IO,
        device: torch.device,
        batch_size: int,
        checkpoints_folder: Path,
        num_layers: int,
        classes: List[str],
        num_classes: int) -> list:
    """
    Measure per-minibatch gradient magnitudes without updating the weights.

    For each training minibatch the loss is back-propagated but
    ``optimizer.step()`` is never called. The model is kept in training mode,
    so any batch-norm layers still update their running statistics during
    the forward passes.

    Args:
        model: ResNet model whose gradients are measured.
        dataloaders: ``dataloaders["train"]`` must yield
            ``(inputs, labels, paths)`` batches.
        dataset_sizes: Unused; kept for call-site compatibility.
        criterion: Loss used for back-propagation.
        optimizer: Only used to zero gradients before each minibatch.
        scheduler: Unused; kept for call-site compatibility.
        num_epochs: Number of passes over the training data.
        log_writer: Unused; kept for call-site compatibility.
        train_order_writer: Unused; kept for call-site compatibility.
        device: Device to run the model on.
        batch_size: Unused; kept for call-site compatibility.
        checkpoints_folder: Unused; kept for call-site compatibility.
        num_layers: Unused; kept for call-site compatibility.
        classes: Unused; kept for call-site compatibility.
        num_classes: Unused; kept for call-site compatibility.

    Returns:
        For the last epoch only, a list of
        ``(minibatch index, image name, get_grad_magnitude(model)[-1])``
        tuples, one per minibatch. The image name comes from the first path
        in the batch, while the gradient is that of the whole batch's loss,
        so batch size 1 is the intended use. Empty if ``num_epochs < 1``.
    """
    # Initialised here so the function returns [] rather than raising
    # UnboundLocalError when no epoch runs.
    tup_list = []

    for epoch in range(1, num_epochs + 1):
        model.train(mode=True)
        # Only the most recent epoch's results are returned.
        tup_list = []

        for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]):
            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            # Clear old gradients so the magnitudes belong to this minibatch.
            optimizer.zero_grad()

            # Forward and backpropagation; no optimizer.step(). Only one
            # backward pass per graph, so the graph need not be retained.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                train_loss = criterion(input=train_outputs,
                                       target=train_labels)
                train_loss.backward()

            layer_num_to_mag = get_grad_magnitude(model)
            image_name = get_image_name(paths[0])
            tup = (idx, image_name, layer_num_to_mag[-1])
            tup_list.append(tup)

            # Periodic progress output.
            if idx % 1000 == 0:
                print(tup)

    return tup_list
def train_helper(model: torchvision.models.resnet.ResNet,
                 dataloaders: Dict[str, torch.utils.data.DataLoader],
                 dataset_sizes: Dict[str, int],
                 criterion: torch.nn.modules.loss,
                 optimizer: torch.optim,
                 scheduler: torch.optim.lr_scheduler,
                 num_epochs: int,
                 log_writer: IO,
                 train_order_writer: IO,
                 device: torch.device,
                 batch_size: int,
                 checkpoints_folder: Path,
                 num_layers: int,
                 classes: List[str],
                 minibatch_counter: int,
                 num_classes: int) -> tuple:
    """
    Train a ResNet for ``num_epochs`` epochs, checkpointing after every epoch.

    Each epoch it:
    - trains over ``dataloaders["train"]``;
    - passes the epoch's training predictions to ``calculate_confusion_matrix``;
    - evaluates on ``dataloaders["val"]`` and passes those predictions to
      ``calculate_confusion_matrix`` too;
    - saves a checkpoint;
    - writes one CSV line to ``log_writer``;
    - prints a summary and steps the scheduler.

    Args:
        model: ResNet model for training.
        dataloaders: Dataloaders; each batch is ``(inputs, labels, paths)``.
        dataset_sizes: Sizes of the training and validation dataset.
        criterion: Metric used for calculating loss.
        optimizer: Optimizer to use for gradient descent.
        scheduler: Scheduler to use for learning rate decay.
        num_epochs: Number of epochs to train for (numbered from 1).
        log_writer: Writer for the per-epoch CSV diagnostics.
        train_order_writer: Unused; kept for call-site compatibility.
        device: Device to use for running model.
        batch_size: Mini-batch size the dataloaders were built with; used to
            locate each batch in the prediction buffers.
        checkpoints_folder: Directory to save model checkpoints to.
        num_layers: ResNet depth, used in checkpoint file names.
        classes: Names of the classes in the dataset.
        minibatch_counter: Global minibatch count to continue from.
        num_classes: Number of classes in the dataset.

    Returns:
        ``(last checkpoint path, updated global minibatch counter)``. The
        path is ``None`` when ``num_epochs < 1``.
    """
    since = time.time()
    global_minibatch_counter = minibatch_counter

    # Initialize all the tensors to be used in training and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    # Stays None only if no epoch runs; avoids an UnboundLocalError on return.
    epoch_output_path = None

    for epoch in range(1, num_epochs + 1):
        # Training phase.
        model.train(mode=True)
        train_running_loss, train_running_corrects = 0.0, 0
        # Actual number of samples seen, so a partial last batch does not
        # skew the averages.
        train_samples_seen = 0

        for idx, (inputs, labels, _paths) in enumerate(dataloaders["train"]):
            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                __, train_preds = torch.max(train_outputs, dim=1)
                train_loss = criterion(input=train_outputs,
                                       target=train_labels)
                train_loss.backward()
                optimizer.step()

            # Update training diagnostics.
            this_batch_size = train_labels.size(0)
            train_running_loss += train_loss.item() * train_inputs.size(0)
            train_running_corrects += torch.sum(
                train_preds == train_labels.data, dtype=torch.double)
            start = idx * batch_size
            end = start + this_batch_size
            train_all_labels[start:end] = train_labels.detach().cpu()
            train_all_predicts[start:end] = train_preds.detach().cpu()
            train_samples_seen += this_batch_size
            global_minibatch_counter += 1

        # Calculate training diagnostics.
        calculate_confusion_matrix(all_labels=train_all_labels.numpy(),
                                   all_predicts=train_all_predicts.numpy(),
                                   classes=classes,
                                   num_classes=num_classes)
        avg_train_loss = train_running_loss / train_samples_seen
        train_acc = train_running_corrects / train_samples_seen

        # Validation phase.
        model.train(mode=False)
        val_running_loss = 0.0
        val_running_corrects = 0

        # Feed forward over all the validation data.
        for idx, (val_inputs, val_labels, _paths) in enumerate(dataloaders["val"]):
            val_inputs = val_inputs.to(device=device)
            val_labels = val_labels.to(device=device)

            # Feed forward.
            with torch.set_grad_enabled(mode=False):
                val_outputs = model(val_inputs)
                _, val_preds = torch.max(val_outputs, dim=1)
                val_loss = criterion(input=val_outputs, target=val_labels)

            # Update validation diagnostics.
            val_running_loss += val_loss.item() * val_inputs.size(0)
            val_running_corrects += torch.sum(val_preds == val_labels.data,
                                              dtype=torch.double)
            this_batch_size = val_labels.size(0)
            start = idx * batch_size
            end = start + this_batch_size
            val_all_labels[start:end] = val_labels.detach().cpu()
            val_all_predicts[start:end] = val_preds.detach().cpu()

        # Calculate validation diagnostics.
        calculate_confusion_matrix(all_labels=val_all_labels.numpy(),
                                   all_predicts=val_all_predicts.numpy(),
                                   classes=classes,
                                   num_classes=num_classes)
        avg_val_loss = val_running_loss / dataset_sizes["val"]
        val_acc = val_running_corrects / dataset_sizes["val"]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Remaining things related to training.
        epoch_output_path = checkpoints_folder.joinpath(
            f"resnet{num_layers}_e{epoch}_mb{global_minibatch_counter}_va{val_acc:.5f}.pt"
        )
        epoch_output_path.parent.mkdir(parents=True, exist_ok=True)

        # Save the model as a state dictionary.
        torch.save(obj={
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "epoch": epoch + 1
        }, f=str(epoch_output_path))

        log_writer.write(
            f"{epoch},{global_minibatch_counter},{avg_train_loss:.4f},{train_acc:.4f},{avg_val_loss:.4f},{val_acc:.4f}\n"
        )

        # Learning rate of the last parameter group, read before stepping.
        current_lr = optimizer.param_groups[-1]["lr"]

        # Print the diagnostics for each epoch.
        print(f"Epoch {epoch} with "
              f"mb {global_minibatch_counter} "
              f"lr {current_lr:.15f}: "
              f"t_loss: {avg_train_loss:.4f} "
              f"t_acc: {train_acc:.4f} "
              f"v_loss: {avg_val_loss:.4f} "
              f"v_acc: {val_acc:.4f}\n")

        scheduler.step()

    # Print training information at the end.
    print(f"\ntraining complete in "
          f"{(time.time() - since) / 60:.2f} minutes")

    return epoch_output_path, global_minibatch_counter
def train_smartgrad_helper(model: torchvision.models.resnet.ResNet, dataloaders: Dict[str, torch.utils.data.DataLoader], dataset_sizes: Dict[str, int], criterion: torch.nn.modules.loss, optimizer: torch.optim, scheduler: torch.optim.lr_scheduler, num_epochs: int, log_writer: IO, train_order_writer: IO, device: torch.device, train_batch_size: int, val_batch_size: int, fake_minibatch_size: int, annealling_factor: float, save_mb_interval: int, val_mb_interval: int, checkpoints_folder: Path, num_layers: int, classes: List[str], num_classes: int) -> None: grad_layers = list(range(1, 21)) since = time.time() global_minibatch_counter = 0 # Initialize all the tensors to be used in training and validation. # Do this outside the loop since it will be written over entirely at each # epoch and doesn't need to be reallocated each time. train_all_labels = torch.empty(size=(dataset_sizes["train"], ), dtype=torch.long).cpu() train_all_predicts = torch.empty(size=(dataset_sizes["train"], ), dtype=torch.long).cpu() val_all_labels = torch.empty(size=(dataset_sizes["val"], ), dtype=torch.long).cpu() val_all_predicts = torch.empty(size=(dataset_sizes["val"], ), dtype=torch.long).cpu() for epoch in range(1, num_epochs+1): model.train(mode=False) # Training phase. train_running_loss, train_running_corrects, epoch_minibatch_counter = 0.0, 0, 0 idx_to_gt = {} for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]): train_inputs = inputs.to(device=device) train_labels = labels.to(device=device) optimizer.zero_grad() # Forward and backpropagation. 
with torch.set_grad_enabled(mode=True): train_outputs = model(train_inputs) __, train_preds = torch.max(train_outputs, dim=1) train_loss = criterion(input=train_outputs, target=train_labels) train_loss.backward(retain_graph=True) gt_label = int(train_labels.detach().cpu().numpy()[0]) idx_to_gt[idx] = gt_label ######################## #### important code #### ######################## #clear the memory fake_minibatch_idx = idx % fake_minibatch_size fake_minibatch_num = int(idx / fake_minibatch_size) if fake_minibatch_idx == 0: minibatch_grad_dict = {}; gc.collect() #get the per-example gradient magnitude and add to minibatch_grad_dict grad_as_dict, grad_flattened = model_to_grad_as_dict_and_flatten(model, grad_layers) minibatch_grad_dict[idx] = (grad_as_dict, grad_flattened) #every batch, calculate the best ones if fake_minibatch_idx == fake_minibatch_size - 1: idx_to_weight_batch = get_idx_to_weight(minibatch_grad_dict, annealling_factor, idx_to_gt) print(idx_to_weight_batch) ########################## # print("\n...............................updating......................................" + str(idx)) for layer_num, param in enumerate(model.parameters()): # if layer_num in [0]:#grad_layers: new_grad = get_new_layer_grad(layer_num, idx_to_weight_batch, minibatch_grad_dict) assert param.grad.detach().cpu().numpy().shape == new_grad.detach().cpu().numpy().shape param.grad = new_grad # check_model_weights(idx, model) optimizer.step() # check_model_weights(idx, model) # print("................................done........................................." + str(idx) + '\n\n\n\n') ########################## # Update training diagnostics. 
train_running_loss += train_loss.item() * train_inputs.size(0) train_running_corrects += torch.sum(train_preds == train_labels.data, dtype=torch.double) start = idx * train_batch_size end = start + train_batch_size train_all_labels[start:end] = train_labels.detach().cpu() train_all_predicts[start:end] = train_preds.detach().cpu() global_minibatch_counter += 1 epoch_minibatch_counter += 1 # Write the path of training order if it exists if train_order_writer: for path in paths: #write the order that the model was trained in train_order_writer.write("/".join(path.split("/")[-2:]) + "\n") # Validate the model if global_minibatch_counter % val_mb_interval == 0 or global_minibatch_counter == 1: # Calculate training diagnostics calculate_confusion_matrix( all_labels=train_all_labels.numpy(), all_predicts=train_all_predicts.numpy(), classes=classes, num_classes=num_classes) train_loss = train_running_loss / (epoch_minibatch_counter * train_batch_size) train_acc = train_running_corrects / (epoch_minibatch_counter * train_batch_size) # Validation phase. model.train(mode=False) val_running_loss = 0.0 val_running_corrects = 0 # Feed forward over all the validation data. for idx, (val_inputs, val_labels, paths) in enumerate(dataloaders["val"]): val_inputs = val_inputs.to(device=device) val_labels = val_labels.to(device=device) # Feed forward. with torch.set_grad_enabled(mode=False): val_outputs = model(val_inputs) _, val_preds = torch.max(val_outputs, dim=1) val_loss = criterion(input=val_outputs, target=val_labels) # Update validation diagnostics. 
val_running_loss += val_loss.item() * val_inputs.size(0) val_running_corrects += torch.sum(val_preds == val_labels.data, dtype=torch.double) start = idx * val_batch_size end = start + val_batch_size val_all_labels[start:end] = val_labels.detach().cpu() val_all_predicts[start:end] = val_preds.detach().cpu() # Calculate validation diagnostics calculate_confusion_matrix( all_labels=val_all_labels.numpy(), all_predicts=val_all_predicts.numpy(), classes=classes, num_classes=num_classes) val_loss = val_running_loss / dataset_sizes["val"] val_acc = val_running_corrects / dataset_sizes["val"] if torch.cuda.is_available(): torch.cuda.empty_cache() # Remaining things related to training. if global_minibatch_counter % save_mb_interval == 0 or global_minibatch_counter == 1: epoch_output_path = checkpoints_folder.joinpath(f"resnet{num_layers}_e{epoch}_mb{global_minibatch_counter}_va{val_acc:.5f}.pt") epoch_output_path.parent.mkdir(parents=True, exist_ok=True) # Save the model as a state dictionary. torch.save(obj={ "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": scheduler.state_dict(), "epoch": epoch + 1 }, f=str(epoch_output_path)) log_writer.write(f"{epoch},{global_minibatch_counter},{train_loss:.4f},{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n") current_lr = None for group in optimizer.param_groups: current_lr = group["lr"] # Print the diagnostics for each epoch. print(f"Epoch {epoch} with " f"mb {global_minibatch_counter} " f"lr {current_lr:.15f}: " f"t_loss: {train_loss:.4f} " f"t_acc: {train_acc:.4f} " f"v_loss: {val_loss:.4f} " f"v_acc: {val_acc:.4f}\n") scheduler.step() current_lr = None for group in optimizer.param_groups: current_lr = group["lr"]