Example #1

# assumed imports for this excerpt (the original file header is not shown)
import copy
import math

import torch.nn as nn
from torch.nn import Module
from torchvision.transforms import Normalize


class Model(Module):
    def __init__(self,
                 mode: str,
                 restoration: Module,
                 fidelity_input: str,
                 fidelity_output: str,
                 feature_extractor: Module,
                 feature_size: int,
                 classifier: Module,
                 downsample: str = 'bilinear',
                 fidelity=None,
                 num_channel=16,
                 increase=0.5,
                 MEAN: list = [0.485, 0.456, 0.406],
                 STD: list = [0.229, 0.224, 0.225]) -> None:

        super(Model, self).__init__()

        self.mode = mode.lower()
        self.downsample_mode = downsample

        # prepare restoration network, feature extractor, classifier
        self.restoration = restoration
        self.feature_extractor = feature_extractor
        self.classifier = classifier
        self.img_normal = Normalize(MEAN, STD)
        self.feature_size = feature_size

        # prepare fidelity map estimator
        self.fidelity_output = fidelity_output
        if fidelity is not None:
            self.fidelity = fidelity
            self.fidelity_input = fidelity_input
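        # cosine similarity collapses the channel dimension into a single map,
        #   while l1/l2 fidelity maps keep one channel per image channel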
        if 'cos' in fidelity_output.lower():
            self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
            in_channels = 1
        else:
            in_channels = 3

        # basic trainable module
        basic_cnn = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=num_channel,
                      kernel_size=3,
                      padding=1,
                      padding_mode='reflect'), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=num_channel,
                      out_channels=1,
                      kernel_size=3,
                      padding=1,
                      padding_mode='reflect'))
        # spatial multiplication and spatial addition
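        #   (one modulation CNN per feature scale: the loop below appends a
        #   copy of basic_cnn for every pooling layer in the feature extractor)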
        cnn_layers = nn.ModuleList([copy.deepcopy(basic_cnn)])
        for child in feature_extractor.children():
            if isinstance(child, nn.Sequential):
                for subchild in child.children():
                    if isinstance(subchild,
                                  (nn.MaxPool2d, nn.AdaptiveAvgPool2d)):
                        cnn_layers.append(copy.deepcopy(basic_cnn))
            elif isinstance(child, (nn.MaxPool2d, nn.AdaptiveAvgPool2d)):
                cnn_layers.append(copy.deepcopy(basic_cnn))
        self.cnn_layers_weight = cnn_layers
        self.cnn_layers_bias = copy.deepcopy(cnn_layers)

        if 'cos' not in self.fidelity_output.lower():
            sigma = math.sqrt(
                (0.1**2 + 0.2**2 + 0.3**2 + 0.4**2 + 0.5**2) / (6.0**2 * 2.0))
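            # for e ~ N(0, sigma^2), |e| is half-normal with mean
            #   sigma*sqrt(2/pi) and std sigma*sqrt(1 - 2/pi), while e^2 has
            #   mean sigma^2 and std sqrt(2)*sigma^2; these statistics
            #   standardize the l1/l2 fidelity maps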
            if 'l1' in self.fidelity_output.lower():
                mean = sigma * math.sqrt(2.0 / math.pi)
                std = sigma * math.sqrt(1.0 - 2.0 / math.pi)
            elif 'l2' in self.fidelity_output.lower():
                mean = sigma**2.0
                std = math.sqrt(2.0) * sigma**2.0
            self.fidelity_normal = Normalize([mean] * in_channels,
                                             [std] * in_channels)

        # channel multiplication
        self.mul_fc = nn.Sequential(
            nn.Linear(feature_size, int(feature_size * increase)),
            nn.ReLU(inplace=True),
            nn.Linear(int(feature_size * increase), feature_size),
        )

        # channel concatenation
        self.cat_fc = nn.Sequential(
            nn.Linear(feature_size * 2, int(feature_size * increase)),
            nn.ReLU(inplace=True),
            nn.Linear(int(feature_size * increase), feature_size),
        )
        # ensemble module
        self.is_ensemble = False
        self.ensemble = nn.Sequential(
            nn.Linear(1, int(feature_size * increase)), nn.ReLU(inplace=True),
            nn.Linear(int(feature_size * increase), feature_size),
            nn.Sigmoid())
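
A minimal usage sketch for the module above; the restoration network, feature
extractor and classifier below are hypothetical stand-ins for the project's
own networks:

import torch.nn as nn

restoration = nn.Identity()  # stand-in restoration network
feature_extractor = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(2),  # each pooling layer adds one modulation CNN copy
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1))
classifier = nn.Linear(16, 10)

model = Model(mode='oracle',
              restoration=restoration,
              fidelity_input='degraded',
              fidelity_output='l1',
              feature_extractor=feature_extractor,
              feature_size=16,
              classifier=classifier)
print(len(model.cnn_layers_weight))  # 1 initial copy + one per pooling layer = 3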
Example #2

# assumed imports for this excerpt; LabelSmoothing, getPSNR,
# learning_rate_warmup and cosine_learning_rate_decay are project-local
# helpers assumed to be importable from the surrounding codebase
import copy
import pickle
import time
from typing import TextIO

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Module
from torchvision.transforms import Normalize
from tqdm import tqdm
def train(model: Module,
          task: str,
          mode: str,
          device: torch.device,
          dataloaders: dict,
          lr: float,
          num_epochs: int,
          warmup: int = 5,
          logs: TextIO = None,
          model_name: str = None,
          smoothing: float = 0.1,
          restoration=None,
          MEAN: list = [0.485, 0.456, 0.406],
          STD: list = [0.229, 0.224, 0.225],
          fidelity_input: str = 'degraded',
          fidelity_output: str = 'l1') -> Module:
    """ Main training function for classifiers.
        Args:
            model (Module): Classifier model to train.
            task (str): Training task, one of 'classification', 'model',
                'deepcorrect', 'wavecnet', 'restoration' or 'fidelity'.
            mode (str): Mode of the proposed model, e.g. containing 'oracle'
                or 'endtoend'.
            device (torch.device): Device where the classifier model is located.
            dataloaders (dict): Dataloader dictionary with two keys, 'train' and 'valid'.
            lr (float): Initial learning rate.
            num_epochs (int): Total number of epochs for training.
            warmup (int): Number of epochs of per-batch learning rate warmup.
            logs (TextIO): Text file that records all training information: loss and accuracy vs. epochs.
            model_name (str): Name of the classifier, one of ['vgg', 'alexnet', 'resnet', 'googlenet'].
            smoothing (float): Label smoothing factor; 0 disables smoothing.
            restoration (Module): Optional restoration network applied to the inputs.
            MEAN (list): Per-channel means for input normalization.
            STD (list): Per-channel standard deviations for input normalization.
            fidelity_input (str): Input to the fidelity estimator, 'degraded' or 'restored'.
            fidelity_output (str): Type of fidelity map to regress, 'l1' or 'l2'.
    """
    # Gather the parameters to be optimized/updated in this run.
    logs.write("Params to learn:\n")
    print("Params to learn:")
    param_num = 0
    params_to_update = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            param_num += param.nelement()
            logs.write("%s\n" % name)


#             print(name)
    logs.write("Number of Params to learn: {:E}\n".format(param_num))
    print("Number of Params to learn: {:E}".format(param_num))

    # Preparation
    since = time.time()
    valid_metrics, train_metrics, valid_loss, train_loss = [], [], [], []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_metrics = 0.0 if task.lower() != 'fidelity' else -1.0 * float('inf')
    if task.lower() in ('classification', 'model', 'deepcorrect', 'wavecnet'):
        metrics = 'Accuracy'
    elif 'restoration' in task.lower():
        metrics = 'PSNR improvement'
    else:
        metrics = 'Loss'

    # Observe that all parameters are being optimized
    if task.lower() in ('classification', 'model', 'deepcorrect', 'wavecnet'):
        optimizer = optim.SGD(params_to_update,
                              lr=lr,
                              momentum=0.9,
                              nesterov=True)
    else:
        optimizer = optim.Adam(params_to_update, lr=lr)

    # Setup the loss function
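    #   with smoothing > 0 the targets become soft label distributions, so KL
    #   divergence on log-softmax outputs replaces plain cross entropy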
    if smoothing > 0 and task.lower() in ('classification', 'model',
                                          'deepcorrect', 'wavecnet'):
        label_smoothing = LabelSmoothing(smoothing)
        criterion = nn.KLDivLoss(reduction='batchmean')
    elif task.lower() in ('classification', 'model', 'deepcorrect',
                          'wavecnet'):
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.L1Loss()

    # Train for each epoch
    for epoch in range(num_epochs):

        # each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            # set model to train mode
            if phase == 'train':
                # fine-tune classification network
                if 'model' not in task.lower():
                    model.train()
                # train our proposed method
                else:
                    if torch.cuda.device_count() > 1:
                        if model.module.is_ensemble:
                            model.module.ensemble.train()
                        else:
                            model.module.cnn_layers_weight.train()
                            model.module.cnn_layers_bias.train()
                            model.module.cat_fc.train()
                            model.module.mul_fc.train()
                            if 'endtoend' in mode.lower():
                                model.module.fidelity.train()
                    else:
                        if model.is_ensemble:
                            model.ensemble.train()
                        else:
                            model.cnn_layers_weight.train()
                            model.cnn_layers_bias.train()
                            model.cat_fc.train()
                            model.mul_fc.train()
                            if 'endtoend' in mode.lower():
                                model.fidelity.train()
            # set model to evaluate mode
            else:
                model.eval()

            running_loss = 0.0
            running_metrics = 0

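            # learning-rate schedule: per-batch linear warmup for the first
            #   `warmup` epochs (inside the batch loop below), then per-epoch
            #   cosine decay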
            # cosine learning rate decay
            if epoch >= warmup and phase == 'train':
                cosine_learning_rate_decay(optimizer, epoch - warmup, lr,
                                           num_epochs - warmup)

            # iterate over data
            batch = 0
            for inputs, origins, labels in tqdm(dataloaders[phase],
                                                ncols=70,
                                                leave=False,
                                                unit='b',
                                                desc='Epoch {}/{} {}'.format(
                                                    epoch + 1, num_epochs,
                                                    phase)):
                # learning rate warmup
                batch += 1
                if epoch < warmup and phase == 'train':
                    learning_rate_warmup(optimizer, warmup, epoch, lr, batch,
                                         len(dataloaders[phase]))

                # for the restoration network and fidelity map estimator,
                #   reshape inputs from (N, P, C, H, W) to (N*P, C, H, W)
                if task.lower() in ('restoration', 'fidelity'):
                    batch_size, patch_size, num_channels, height, width = inputs.shape
                    inputs = inputs.view(-1, num_channels, height, width)
                    origins = origins.view(-1, num_channels, height, width)

                inputs = inputs.to(device)
                if task.lower() == 'fidelity':
                    degraded = copy.deepcopy(inputs)

                # train network on restored images
                if restoration is not None:
                    inputs = restoration(inputs).clamp_(0.0, 1.0)
                if task.lower() == 'fidelity':
                    restored = copy.deepcopy(inputs)
                    if fidelity_input.lower() == 'degraded':
                        inputs = degraded
                # normalize images to train the classification network;
                #   for the proposed model this step is integrated in the model object
                if task.lower() in ('classification', 'deepcorrect', 'wavecnet'):
                    inputs = Normalize(MEAN, STD)(inputs)

                origins = origins.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                model.zero_grad()
                optimizer.zero_grad()

                # forward, track history if only in train
                with torch.set_grad_enabled(phase == 'train'):

                    # get model outputs and calculate loss
                    if (model_name == 'googlenet' and phase == 'train'
                            and task.lower() == 'classification'):
                        outputs, aux1, aux2 = model(inputs)
                        # set up label smoothing
                        if smoothing > 0:
                            labels_before_smoothing = copy.deepcopy(labels)
                            labels = label_smoothing(outputs, labels)
                            aux1 = F.log_softmax(aux1, dim=-1)
                            aux2 = F.log_softmax(aux2, dim=-1)
                            outputs = F.log_softmax(outputs, dim=-1)
                        loss1 = criterion(aux1, labels)
                        loss2 = criterion(aux2, labels)
                        loss3 = criterion(outputs, labels)
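                        # weight the auxiliary classifier losses by 0.3, as in
                        #   the original GoogLeNet training setup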
                        loss = loss3 + 0.3 * (loss1 + loss2)
                    elif task.lower() in ('classification', 'model',
                                          'deepcorrect', 'wavecnet'):
                        if task.lower() == 'model' and 'oracle' in mode.lower():
                            outputs = model(inputs, origins)
                        else:
                            outputs = model(inputs)
                        # set up label smoothing
                        if smoothing > 0:
                            labels_before_smoothing = copy.deepcopy(labels)
                            labels = label_smoothing(outputs, labels)
                            outputs = F.log_softmax(outputs, dim=-1)
                        loss = criterion(outputs, labels)
                    elif task.lower() == 'fidelity':
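                        # the estimator learns a residual: its prediction is
                        #   subtracted from the input and regressed onto the
                        #   per-pixel error map between restored and clean images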
                        outputs = inputs - model(inputs)
                        if 'l1' in fidelity_output.lower():
                            targets = (restored - origins).abs()
                        elif 'l2' in fidelity_output.lower():
                            targets = (restored - origins).square()
                        loss = criterion(outputs, targets)
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, origins)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                if task.lower() in ('classification', 'model', 'deepcorrect',
                                    'wavecnet'):
                    labels = labels_before_smoothing if smoothing > 0 else labels
                    _, preds = torch.max(outputs, 1)
                    running_metrics += torch.sum(preds == labels.data)
                elif 'restoration' in task.lower():
                    running_metrics += getPSNR(outputs.clamp(0,1), origins).item() * inputs.size(0)\
                                        - getPSNR(inputs, origins).item() * inputs.size(0)

            if task.lower() in ('classification', 'model', 'deepcorrect',
                                'wavecnet'):
                epoch_loss = running_loss / len(dataloaders[phase].dataset)
                epoch_metrics = running_metrics.double() / len(
                    dataloaders[phase].dataset)
            else:
                epoch_loss = running_loss / (len(dataloaders[phase].dataset) *
                                             patch_size)
                epoch_metrics = running_metrics / (len(dataloaders[phase].dataset) * patch_size) \
                                if 'restoration' in task.lower() else -1.0 * epoch_loss

            print('Epoch {}: {} Loss: {:.6f} {}: {:.6f}'.format(
                epoch + 1, phase, epoch_loss, metrics, epoch_metrics))
            logs.write('Epoch {}: {} Loss: {:.6f} {}: {:.6f}\n'.format(
                epoch + 1, phase, epoch_loss, metrics, epoch_metrics))

            # record loss and metrics
            if phase == 'valid':
                valid_metrics.append(epoch_metrics)
                valid_loss.append(epoch_loss)
                # deep copy the model
                if epoch_metrics > best_metrics:
                    best_metrics = epoch_metrics
                    best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'train':
                train_metrics.append(epoch_metrics)
                train_loss.append(epoch_loss)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logs.write('Training complete in {:.0f}m {:.0f}s\n'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val {}: {:6f}'.format(metrics, best_metrics))
    logs.write('Best val {}: {:6f}\n'.format(metrics, best_metrics))

    # load best model weights
    model.load_state_dict(best_model_wts)

    # save model and training data
    if torch.cuda.device_count() > 1:
        torch.save(model.module.state_dict(), 'model.pth')
    else:
        torch.save(model.state_dict(), 'model.pth')
    for data, name in [(valid_metrics, 'valid_metrics'),
                       (valid_loss, 'valid_loss'),
                       (train_metrics, 'train_metrics'),
                       (train_loss, 'train_loss')]:
        with open('{}.pickle'.format(name), 'wb') as file:
            pickle.dump(data, file)

    with open('last_batch.pickle', 'wb') as file:
        pickle.dump(inputs, file)

    return model
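
A minimal sketch of calling train() above on random tensors; the dataset,
classifier and log file below are hypothetical stand-ins, and the
project-local helpers (learning_rate_warmup, cosine_learning_rate_decay) are
assumed to be importable:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# stand-in data: (degraded, clean, label) triples, as the training loop expects
degraded = torch.rand(8, 3, 32, 32)
clean = torch.rand(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))
dataset = TensorDataset(degraded, clean, labels)
dataloaders = {'train': DataLoader(dataset, batch_size=4),
               'valid': DataLoader(dataset, batch_size=4)}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classifier = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)

with open('train.log', 'w') as logs:
    train(classifier, task='classification', mode='plain', device=device,
          dataloaders=dataloaders, lr=0.01, num_epochs=2, warmup=1,
          logs=logs, smoothing=0.0)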