def train(self, train_image_indices, batch_size, num_epochs=50,
          train_method='normal', lambda_1=0, lambda_2=0,
          start_from_pretrained_model=True, learning_rate=0.01,
          optimizer='SGD'):
    """Train a freshly initialised model and checkpoint the best epoch.

    Each epoch trains over ``train_image_indices`` in mini-batches, then
    evaluates on ``self.val_loader``. Whenever validation accuracy improves
    (and always after the first epoch), the model's ``state_dict`` is written
    to ``self.checkpoint_path``.

    Args:
        train_image_indices: Indices handed to ``BatchLoader``; the batch
            indices it yields are used to index ``self.x_train`` and
            ``self.y_train``.
        batch_size: Number of indices requested from the loader per batch.
        num_epochs: Number of passes over ``train_image_indices``.
        train_method: ``'bbox'`` adds the input-gradient box penalties from
            ``self.calculate_penalty_box`` to the loss; any other value trains
            with plain cross-entropy.
        lambda_1: Weight of the inside-box penalty (``'bbox'`` only).
        lambda_2: Weight of the outside-box penalty (``'bbox'`` only).
        start_from_pretrained_model: Forwarded to ``self.initialize_model``.
        learning_rate: Optimizer learning rate.
        optimizer: ``'Adam'`` selects Adam; ``'SGD'`` or any other value
            selects SGD with momentum 0.9. Both use weight decay 5e-4.

    Returns:
        dict with per-epoch lists ``train_acc_list``, ``train_loss_list``,
        ``penalty_inside_list``, ``penalty_outside_list``, ``val_loss_list``
        and ``val_acc_list`` (accuracies in percent, losses and penalties as
        per-sample means), plus ``best_acc``, the best validation accuracy.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    # Remove any checkpoint from a previous run so it is never mistaken for
    # this run's best model.
    if os.path.exists(self.checkpoint_path):
        os.remove(self.checkpoint_path)

    model = self.initialize_model(
        start_from_pretrained_model=start_from_pretrained_model)
    model = model.to(self.device)
    criterion = nn.CrossEntropyLoss()

    # ``optimizer`` is the *name*; keep the optimizer object in its own local.
    if optimizer == 'Adam':
        opt = optim.Adam(model.parameters(), lr=learning_rate,
                         weight_decay=5e-4)
    else:
        # 'SGD' and any unrecognised name fall back to SGD with momentum.
        opt = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9,
                        weight_decay=5e-4)

    train_batch_loader = BatchLoader(self.train_folder_path,
                                     train_image_indices)
    # Ceiling division: a trailing partial batch still counts as a batch.
    num_batches = (len(train_image_indices) + batch_size - 1) // batch_size

    penalty_inside_list = []
    penalty_outside_list = []
    train_acc_list = []
    train_loss_list = []
    val_loss_list = []
    val_acc_list = []
    best_acc = 0.0

    for epoch in range(num_epochs):
        model.train()
        train_batch_loader.reset()
        print('Epoch: {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 50)

        n_seen = 0
        train_correct = 0.0
        train_loss = 0.0
        penalty_inside = 0.0
        penalty_outside = 0.0

        for _ in range(num_batches):
            batch_indices = train_batch_loader.get_batch_indices(batch_size)
            inputs = self.x_train[batch_indices].to(self.device)
            labels = self.y_train[batch_indices].to(self.device)

            if train_method == 'bbox':
                # The penalties are computed from d(loss)/d(inputs), so the
                # inputs must track gradients. create_graph=True keeps that
                # gradient differentiable so the penalties can be trained on.
                inputs.requires_grad_()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                input_gradient = torch.autograd.grad(
                    loss, inputs, create_graph=True)[0]
                penalty_inside_box, penalty_outside_box = \
                    self.calculate_penalty_box(batch_indices, input_gradient)
                total_loss = (loss
                              + lambda_1 * penalty_inside_box
                              + lambda_2 * penalty_outside_box)
                batch_penalty_inside = penalty_inside_box.item()
                batch_penalty_outside = penalty_outside_box.item()
            else:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                total_loss = loss
                batch_penalty_inside = 0.0
                batch_penalty_outside = 0.0

            opt.zero_grad()
            total_loss.backward()
            opt.step()

            preds = torch.argmax(outputs, dim=1)
            batch_n = labels.size(0)
            n_seen += batch_n
            # CrossEntropyLoss returns a batch *mean*; weight by batch size
            # so the epoch value is a true per-sample mean (a short last
            # batch is not over-weighted). The penalties are added to that
            # batch-mean loss, so they are reported on the same scale.
            train_loss += loss.item() * batch_n
            train_correct += torch.sum(preds == labels).item()
            penalty_inside += batch_penalty_inside * lambda_1 * batch_n
            penalty_outside += batch_penalty_outside * lambda_2 * batch_n

        # Normalise by the samples actually trained on this epoch; guard
        # against an empty index list.
        train_denom = max(n_seen, 1)
        train_loss = train_loss / train_denom
        train_acc = (train_correct / train_denom) * 100.0
        penalty_inside = penalty_inside / train_denom
        penalty_outside = penalty_outside / train_denom
        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)
        penalty_inside_list.append(penalty_inside)
        penalty_outside_list.append(penalty_outside)

        print('Train Loss: {:.4f} Acc: {:.4f} % '.format(
            train_loss, train_acc))
        print(f'Penalty Inside Box: {round(penalty_inside, 4)}')
        print(f'Penalty Outside Box: {round(penalty_outside, 4)}')

        # Validate after each epoch.
        model.eval()
        val_seen = 0
        val_correct = 0.0
        val_loss = 0.0
        with torch.no_grad():
            for inputs_val, labels_val in self.val_loader:
                inputs_val = inputs_val.to(self.device)
                labels_val = labels_val.to(self.device)
                outputs_val = model(inputs_val)
                preds_val = torch.argmax(outputs_val, dim=1)
                n_val = labels_val.size(0)
                val_seen += n_val
                val_loss += criterion(outputs_val, labels_val).item() * n_val
                val_correct += torch.sum(preds_val == labels_val).item()

        val_denom = max(val_seen, 1)
        val_loss = val_loss / val_denom
        val_acc = (val_correct / val_denom) * 100.0
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)
        print('Val Loss: {:.4f} Acc: {:.4f} % \n'.format(val_loss, val_acc))

        # Save the best model. Always save after the first epoch so a
        # checkpoint exists even if validation accuracy never exceeds 0.
        # torch.save overwrites an existing file, so no explicit remove is
        # needed.
        if epoch == 0 or val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), self.checkpoint_path)

    return {
        'train_acc_list': train_acc_list,
        'train_loss_list': train_loss_list,
        'penalty_inside_list': penalty_inside_list,
        'penalty_outside_list': penalty_outside_list,
        'val_loss_list': val_loss_list,
        'val_acc_list': val_acc_list,
        'best_acc': best_acc,
    }