def __init__(self,
             mode: str,
             restoration: Module,
             fidelity_input: str,
             fidelity_output: str,
             feature_extractor: Module,
             feature_size: int,
             classifier: Module,
             downsample: str = 'bilinear',
             fidelity=None,
             num_channel=16,
             increase=0.5,
             MEAN: list = [0.485, 0.456, 0.406],
             STD: list = [0.229, 0.224, 0.225]) -> None:
    super(Model, self).__init__()
    self.mode = mode.lower()
    self.downsample_mode = downsample

    # prepare restoration network, feature extractor, classifier
    self.restoration = restoration
    self.feature_extractor = feature_extractor
    self.classifier = classifier
    self.img_normal = Normalize(MEAN, STD)
    self.feature_size = feature_size

    # prepare fidelity map estimator
    self.fidelity_output = fidelity_output
    if fidelity is not None:
        self.fidelity = fidelity
        self.fidelity_input = fidelity_input
    if 'cos' in fidelity_output.lower():
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        in_channels = 1
    else:
        in_channels = 3

    # basic trainable module
    basic_cnn = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=num_channel, kernel_size=3,
                  padding=1, padding_mode='reflect'),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=num_channel, out_channels=1, kernel_size=3,
                  padding=1, padding_mode='reflect'))

    # spatial multiplication and spatial addition
    cnn_layers = nn.ModuleList([copy.deepcopy(basic_cnn)])
    for child in feature_extractor.children():
        if isinstance(child, nn.Sequential):
            for subchild in child.children():
                if isinstance(subchild, (nn.MaxPool2d, nn.AdaptiveAvgPool2d)):
                    cnn_layers.append(copy.deepcopy(basic_cnn))
        elif isinstance(child, (nn.MaxPool2d, nn.AdaptiveAvgPool2d)):
            cnn_layers.append(copy.deepcopy(basic_cnn))
    self.cnn_layers_weight = cnn_layers
    self.cnn_layers_bias = copy.deepcopy(cnn_layers)

    if 'cos' not in self.fidelity_output.lower():
        sigma = math.sqrt(
            (0.1**2 + 0.2**2 + 0.3**2 + 0.4**2 + 0.5**2) / (6.0**2 * 2.0))
        if 'l1' in self.fidelity_output.lower():
            mean = sigma * math.sqrt(2.0 / math.pi)
            std = sigma * math.sqrt(1.0 - 2.0 / math.pi)
        elif 'l2' in self.fidelity_output.lower():
            mean = sigma**2.0
            std = math.sqrt(2.0) * sigma**2.0
        self.fidelity_normal = Normalize([mean] * in_channels,
                                         [std] * in_channels)

    # channel multiplication
    self.mul_fc = nn.Sequential(
        nn.Linear(feature_size, int(feature_size * increase)),
        nn.ReLU(inplace=True),
        nn.Linear(int(feature_size * increase), feature_size),
    )

    # channel concatenation
    self.cat_fc = nn.Sequential(
        nn.Linear(feature_size * 2, int(feature_size * increase)),
        nn.ReLU(inplace=True),
        nn.Linear(int(feature_size * increase), feature_size),
    )

    # ensemble module
    self.is_ensemble = False
    self.ensemble = nn.Sequential(
        nn.Linear(1, int(feature_size * increase)),
        nn.ReLU(inplace=True),
        nn.Linear(int(feature_size * increase), feature_size),
        nn.Sigmoid())
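# Usage sketch (not from the original code): one plausible way to assemble the
# Model above from a torchvision backbone. The resnet18 split into
# `feature_extractor`/`classifier`, the `feature_size` of 512 and the
# nn.Identity() placeholders for the restoration and fidelity networks are
# assumptions made for illustration; the project supplies its own trained
# sub-networks.
#
#   import torch.nn as nn
#   from torchvision import models
#
#   backbone = models.resnet18(pretrained=True)
#   feature_extractor = nn.Sequential(*list(backbone.children())[:-1])  # conv trunk + avgpool
#   classifier = backbone.fc                                            # final linear layer
#
#   model = Model(mode='endtoend',
#                 restoration=nn.Identity(),       # placeholder restoration network
#                 fidelity_input='degraded',
#                 fidelity_output='l1',
#                 feature_extractor=feature_extractor,
#                 feature_size=512,                # channel width of the resnet18 trunk
#                 classifier=classifier,
#                 fidelity=nn.Identity())          # placeholder fidelity map estimator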
def train(model: Module,
          task: str,
          mode: str,
          device: torch.device,
          dataloaders: dict,
          lr: float,
          num_epochs: int,
          warmup: int = 5,
          logs: TextIO = None,
          model_name: str = None,
          smoothing: float = 0.1,
          restoration=None,
          MEAN: list = [0.485, 0.456, 0.406],
          STD: list = [0.229, 0.224, 0.225],
          fidelity_input: str = 'degraded',
          fidelity_output: str = 'l1') -> Module:
    """
    Main training function for classifiers, restoration networks and fidelity map estimators.

    Args:
        model (Module): Model to train.
        task (str): Training task, one of 'classification', 'model', 'deepcorrect',
            'wavecnet', 'restoration' or 'fidelity'.
        mode (str): Mode of the proposed model (e.g. 'oracle', 'endtoend').
        device (torch.device): Device where the model is located.
        dataloaders (dict): Dataloader dictionary with two keys, 'train' and 'valid'.
        lr (float): Initial learning rate.
        num_epochs (int): Total number of epochs for training.
        warmup (int): Number of epochs for learning rate warmup.
        logs (TextIO): Text file to record all training information: loss and accuracy vs. epochs.
        model_name (str): Name of classifier in ['vgg', 'alexnet', 'resnet', 'googlenet'].
        smoothing (float): Label smoothing factor; 0 disables label smoothing.
        restoration (Module): Optional restoration network applied to the inputs.
        MEAN (list): Per-channel mean used for image normalization.
        STD (list): Per-channel std used for image normalization.
        fidelity_input (str): Input of the fidelity map estimator, 'degraded' or 'restored'.
        fidelity_output (str): Target of the fidelity map estimator, 'l1', 'l2' or 'cos'.
    """
    # Gather the parameters to be optimized/updated in this run.
    logs.write("Params to learn:\n")
    print("Params to learn:")
    param_num = 0
    params_to_update = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            param_num += param.nelement()
            logs.write("%s\n" % name)
    logs.write("Number of Params to learn: {:E}\n".format(param_num))
    print("Number of Params to learn: {:E}".format(param_num))

    # Preparation
    since = time.time()
    valid_metrics, train_metrics, valid_loss, train_loss = [], [], [], []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_metrics = 0.0 if task.lower() != 'fidelity' else -1.0 * float('inf')
    if task.lower() in ('classification', 'model', 'deepcorrect', 'wavecnet'):
        metrics = 'Accuracy'
    elif 'restoration' in task.lower():
        metrics = 'PSNR improvement'
    else:
        metrics = 'Loss'

    # Observe that all parameters are being optimized
    if task.lower() in ('classification', 'model', 'deepcorrect', 'wavecnet'):
        optimizer = optim.SGD(params_to_update, lr=lr, momentum=0.9, nesterov=True)
    else:
        optimizer = optim.Adam(params_to_update, lr=lr)

    # Setup the loss function
    if smoothing > 0 and task.lower() in ('classification', 'model',
                                          'deepcorrect', 'wavecnet'):
        label_smoothing = LabelSmoothing(smoothing)
        criterion = nn.KLDivLoss(reduction='batchmean')
    elif task.lower() in ('classification', 'model', 'deepcorrect', 'wavecnet'):
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.L1Loss()

    # Train for each epoch
    for epoch in range(num_epochs):
        # each epoch has a training and a validation phase
        for phase in ['train', 'valid']:
            # set model to train mode
            if phase == 'train':
                # fine-tune classification network
                if 'model' not in task.lower():
                    model.train()
                # train our proposed method: only its trainable sub-modules
                # are switched to train mode
                else:
                    if torch.cuda.device_count() > 1:
                        if model.module.is_ensemble:
                            model.module.ensemble.train()
                        else:
                            model.module.cnn_layers_weight.train()
                            model.module.cnn_layers_bias.train()
                            model.module.cat_fc.train()
                            model.module.mul_fc.train()
                            if 'endtoend' in mode.lower():
                                model.module.fidelity.train()
                    else:
                        if model.is_ensemble:
                            model.ensemble.train()
                        else:
                            model.cnn_layers_weight.train()
                            model.cnn_layers_bias.train()
                            model.cat_fc.train()
                            model.mul_fc.train()
                            if 'endtoend' in mode.lower():
                                model.fidelity.train()
            # set model to evaluate mode
            else:
                model.eval()

            running_loss = 0.0
            running_metrics = 0

            # cosine learning rate decay
            if epoch >= warmup and phase == 'train':
                cosine_learning_rate_decay(optimizer, epoch - warmup, lr,
                                           num_epochs - warmup)

            # iterate over data
            batch = 0
            for inputs, origins, labels in tqdm(dataloaders[phase],
                                                ncols=70,
                                                leave=False,
                                                unit='b',
                                                desc='Epoch {}/{} {}'.format(
                                                    epoch + 1, num_epochs, phase)):
                # learning rate warmup
                batch += 1
                if epoch < warmup and phase == 'train':
                    learning_rate_warmup(optimizer, warmup, epoch, lr, batch,
                                         len(dataloaders[phase]))

                # For the restoration network and the fidelity map estimator,
                # reshape inputs from (N, P, C, H, W) to (N*P, C, H, W)
                if task.lower() in ('restoration', 'fidelity'):
                    batch_size, patch_size, num_channels, height, width = inputs.shape
                    inputs = inputs.view(-1, num_channels, height, width)
                    origins = origins.view(-1, num_channels, height, width)

                inputs = inputs.to(device)
                if task.lower() == 'fidelity':
                    degraded = copy.deepcopy(inputs)

                # train network on restored images
                inputs = restoration(inputs).clamp_(
                    0.0, 1.0) if restoration is not None else inputs
                if task.lower() == 'fidelity':
                    restored = copy.deepcopy(inputs)
                    if fidelity_input.lower() == 'degraded':
                        inputs = degraded

                # normalize images to train the classification network;
                # for the proposed model, this step is integrated in the model object.
                inputs = Normalize(MEAN, STD)(inputs) if task.lower() in (
                    'classification', 'deepcorrect', 'wavecnet') else inputs
                origins = origins.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                model.zero_grad()
                optimizer.zero_grad()

                # forward, track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    # get model outputs and calculate loss
                    if model_name == 'googlenet' and phase == 'train' \
                            and task.lower() == 'classification':
                        outputs, aux1, aux2 = model(inputs)
                        # set up label smoothing
                        if smoothing > 0:
                            labels_before_smoothing = copy.deepcopy(labels)
                            labels = label_smoothing(outputs, labels)
                            aux1 = F.log_softmax(aux1, dim=-1)
                            aux2 = F.log_softmax(aux2, dim=-1)
                            outputs = F.log_softmax(outputs, dim=-1)
                        loss1 = criterion(aux1, labels)
                        loss2 = criterion(aux2, labels)
                        loss3 = criterion(outputs, labels)
                        loss = loss3 + 0.3 * (loss1 + loss2)
                    elif task.lower() in ('classification', 'model',
                                          'deepcorrect', 'wavecnet'):
                        outputs = model(inputs, origins) if task.lower(
                        ) == 'model' and 'oracle' in mode.lower() else model(inputs)
                        # set up label smoothing
                        if smoothing > 0:
                            labels_before_smoothing = copy.deepcopy(labels)
                            labels = label_smoothing(outputs, labels)
                            outputs = F.log_softmax(outputs, dim=-1)
                        loss = criterion(outputs, labels)
                    elif task.lower() == 'fidelity':
                        outputs = inputs - model(inputs)
                        if 'l1' in fidelity_output.lower():
                            targets = (restored - origins).abs()
                        elif 'l2' in fidelity_output.lower():
                            targets = (restored - origins).square()
                        loss = criterion(outputs, targets)
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, origins)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                if task.lower() in ('classification', 'model', 'deepcorrect',
                                    'wavecnet'):
                    labels = labels_before_smoothing if smoothing > 0 else labels
                    _, preds = torch.max(outputs, 1)
                    running_metrics += torch.sum(preds == labels.data)
                elif 'restoration' in task.lower():
                    running_metrics += getPSNR(outputs.clamp(0, 1), origins).item() * inputs.size(0) \
                        - getPSNR(inputs, origins).item() * inputs.size(0)

            if task.lower() in ('classification', 'model', 'deepcorrect',
                                'wavecnet'):
                epoch_loss = running_loss / len(dataloaders[phase].dataset)
                epoch_metrics = running_metrics.double() / len(
                    dataloaders[phase].dataset)
            else:
                epoch_loss = running_loss / (len(dataloaders[phase].dataset) *
                                             patch_size)
                epoch_metrics = running_metrics / (len(dataloaders[phase].dataset) * patch_size) \
                    if 'restoration' in task.lower() else -1.0 * epoch_loss

            print('Epoch {}: {} Loss: {:.6f} {}: {:.6f}'.format(
                epoch + 1, phase, epoch_loss, metrics, epoch_metrics))
            logs.write('Epoch {}: {} Loss: {:.6f} {}: {:.6f}\n'.format(
                epoch + 1, phase, epoch_loss, metrics, epoch_metrics))

            # record loss and metrics
            if phase == 'valid':
                valid_metrics.append(epoch_metrics)
                valid_loss.append(epoch_loss)
                # deep copy the model
                if epoch_metrics > best_metrics:
                    best_metrics = epoch_metrics
                    best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'train':
                train_metrics.append(epoch_metrics)
                train_loss.append(epoch_loss)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logs.write('Training complete in {:.0f}m {:.0f}s\n'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val {}: {:6f}'.format(metrics, best_metrics))
    logs.write('Best val {}: {:6f}\n'.format(metrics, best_metrics))

    # load best model weights
    model.load_state_dict(best_model_wts)

    # save model and training data
    if torch.cuda.device_count() > 1:
        torch.save(model.module.state_dict(), 'model.pth')
    else:
        torch.save(model.state_dict(), 'model.pth')
    for data, name in [(valid_metrics, 'valid_metrics'),
                       (valid_loss, 'valid_loss'),
                       (train_metrics, 'train_metrics'),
                       (train_loss, 'train_loss')]:
        with open('{}.pickle'.format(name), 'wb') as file:
            pickle.dump(data, file)
    with open('last_batch.pickle', 'wb') as file:
        pickle.dump(inputs, file)
    return model
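# Usage sketch (not from the original code): one plausible way to invoke train()
# for plain classifier fine-tuning. The dataset objects, batch size and file
# names are assumptions for illustration; each batch must yield a
# (degraded, original, label) triple, as consumed by the loop above.
#
#   from torch.utils.data import DataLoader
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   dataloaders = {'train': DataLoader(train_set, batch_size=64, shuffle=True),
#                  'valid': DataLoader(valid_set, batch_size=64, shuffle=False)}
#
#   with open('train_logs.txt', 'w') as logs:
#       model = train(model.to(device),
#                     task='classification',
#                     mode='none',
#                     device=device,
#                     dataloaders=dataloaders,
#                     lr=1e-3,
#                     num_epochs=90,
#                     warmup=5,
#                     logs=logs,
#                     model_name='resnet',
#                     smoothing=0.1)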