Example #1

# The examples below are shown without their imports. A minimal preamble they
# all assume; project-local modules (utils, the dataset_* loaders, the
# experiment_* scripts) and helpers such as get_lr, binary_acc and the plot_*
# functions come from the source repository.
import os
import pickle
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

def main(model,
         train_params,
         data_set_params,
         base_results_store_dir='./results'):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # -----------------------------------------------------------------------------------
    # Sanity Checks
    # -----------------------------------------------------------------------------------
    # Validate Data set parameters
    # ----------------------------
    required_data_set_params = ['data_set_dir']
    for key in required_data_set_params:
        assert key in data_set_params, 'data_set_params does not have required key {}'.format(
            key)
    data_set_dir = data_set_params['data_set_dir']

    # Validate training parameters
    # ----------------------------
    required_training_params = [
        'train_batch_size', 'test_batch_size', 'learning_rate', 'num_epochs',
        # these two are read directly below, so require them here as well
        'gaussian_reg_weight', 'gaussian_reg_sigma'
    ]
    for key in required_training_params:
        assert key in train_params, 'train_params does not have required key {}'.format(
            key)

    train_batch_size = train_params['train_batch_size']
    test_batch_size = train_params['test_batch_size']
    learning_rate = train_params['learning_rate']
    num_epochs = train_params['num_epochs']

    lambda1 = train_params['gaussian_reg_weight']
    gaussian_kernel_sigma = train_params['gaussian_reg_sigma']

    # -----------------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------------
    print("====> Loading Model ")
    print("Name: {}".format(model.__class__.__name__))
    print(model)

    # Get name of contour integration layer
    temp = vars(model)  # Returns a dictionary.
    layers = temp['_modules']  # Returns all top level modules (layers)
    cont_int_layer_type = ''
    if 'contour_integration_layer' in layers:
        cont_int_layer_type = model.contour_integration_layer.__class__.__name__

    results_store_dir = os.path.join(
        base_results_store_dir, model.__class__.__name__ + '_' +
        cont_int_layer_type + datetime.now().strftime("_%Y%m%d_%H%M%S"))
    if not os.path.exists(results_store_dir):
        os.makedirs(results_store_dir)

    # -----------------------------------------------------------------------------------
    # Data Loader
    # -----------------------------------------------------------------------------------
    print("====> Setting up data loaders ")
    data_load_start_time = datetime.now()

    print("Data Source: {}".format(data_set_dir))

    # Imagenet Mean and STD
    ch_mean = [0.485, 0.456, 0.406]
    ch_std = [0.229, 0.224, 0.225]
    # print("Channel mean {}, std {}".format(meta_data['channel_mean'], meta_data['channel_std']))

    pre_process_transforms = transforms.Compose([
        transforms.Normalize(mean=ch_mean, std=ch_std),
    ])

    train_set = dataset_bsds.BSDS(data_dir=data_set_dir,
                                  image_set='train',
                                  transform=pre_process_transforms)
    train_batch_size = min(train_batch_size, len(train_set))

    train_data_loader = DataLoader(dataset=train_set,
                                   num_workers=4,
                                   batch_size=train_batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    val_set = dataset_bsds.BSDS(
        data_dir=data_set_dir,
        image_set='val',
        transform=pre_process_transforms,
    )
    test_batch_size = min(test_batch_size, len(val_set))

    val_data_loader = DataLoader(dataset=val_set,
                                 num_workers=4,
                                 batch_size=test_batch_size,
                                 shuffle=False,
                                 pin_memory=True)

    print("Data loading Took {}. # Train {}, # Test {}".format(
        datetime.now() - data_load_start_time,
        len(train_data_loader) * train_batch_size,
        len(val_data_loader) * test_batch_size))

    # -----------------------------------------------------------------------------------
    # Loss / optimizer
    # -----------------------------------------------------------------------------------
    optimizer = optim.Adam(filter(lambda params: params.requires_grad,
                                  model.parameters()),
                           lr=learning_rate)

    lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                             step_size=30,
                                             gamma=0.1)

    criterion = nn.BCEWithLogitsLoss().to(device)

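    # "Inverse Gaussian" masks for the lateral kernels: (1 - Gaussian) is near
    # zero at the kernel centre and near one at the edges, so the L1 penalty in
    # inverse_gaussian_regularization below suppresses long-range lateral weights.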
    gaussian_mask_e = 1 - utils.get_2d_gaussian_kernel(
        model.contour_integration_layer.lateral_e.weight.shape[2:],
        sigma=gaussian_kernel_sigma)
    gaussian_mask_i = 1 - utils.get_2d_gaussian_kernel(
        model.contour_integration_layer.lateral_i.weight.shape[2:],
        sigma=gaussian_kernel_sigma)

    gaussian_mask_e = torch.from_numpy(gaussian_mask_e).float().to(device)
    gaussian_mask_i = torch.from_numpy(gaussian_mask_i).float().to(device)

    def inverse_gaussian_regularization(weight_e, weight_i):
        loss1 = (gaussian_mask_e * weight_e).abs().sum() + (
            gaussian_mask_i * weight_i).abs().sum()
        # print("Loss1: {:0.4f}".format(loss1))

        return loss1

    # -----------------------------------------------------------------------------------
    #  Training Validation Routines
    # -----------------------------------------------------------------------------------
    def train():
        """ Train for one Epoch over the train data set """
        model.train()
        e_loss = 0
        e_iou = 0

        for iteration, (img, label) in enumerate(train_data_loader, 1):
            optimizer.zero_grad()  # zero the parameter gradients

            img = img.to(device)
            label = label.to(device)

            label_out = model(img)

            batch_loss = criterion(label_out, label.float())

            kernel_loss = \
                inverse_gaussian_regularization(
                    model.contour_integration_layer.lateral_e.weight,
                    model.contour_integration_layer.lateral_i.weight
                )

            total_loss = batch_loss + lambda1 * kernel_loss

            # print("Total Loss: {:0.4f}, cross_entropy_loss {:0.4f}, kernel_loss {:0.4f}".format(
            #     total_loss, batch_loss,  lambda1 * kernel_loss))

            total_loss.backward()
            optimizer.step()

            e_loss += total_loss.item()

            preds = (torch.sigmoid(label_out) > detect_thres)
            e_iou += utils.intersection_over_union(
                preds.float(), label.float()).cpu().detach().numpy()

        e_loss = e_loss / len(train_data_loader)
        e_iou = e_iou / len(train_data_loader)

        # print("Train Epoch {} Loss = {:0.4f}, IoU={:0.4f}".format(epoch, e_loss, e_iou))

        return e_loss, e_iou

    def validate():
        """ Get loss over validation set """
        model.eval()
        e_loss = 0
        e_iou = 0

        with torch.no_grad():
            for iteration, (img, label) in enumerate(val_data_loader, 1):
                img = img.to(device)
                label = label.to(device)

                label_out = model(img)
                batch_loss = criterion(label_out, label.float())

                kernel_loss = \
                    inverse_gaussian_regularization(
                        model.contour_integration_layer.lateral_e.weight,
                        model.contour_integration_layer.lateral_i.weight
                    )

                total_loss = batch_loss + lambda1 * kernel_loss

                e_loss += total_loss.item()
                preds = (torch.sigmoid(label_out) > detect_thres)
                e_iou += utils.intersection_over_union(
                    preds.float(), label.float()).cpu().detach().numpy()

        e_loss = e_loss / len(val_data_loader)
        e_iou = e_iou / len(val_data_loader)

        # print("Val Loss = {:0.4f}, IoU={:0.4f}".format(e_loss, e_iou))

        return e_loss, e_iou

    # -----------------------------------------------------------------------------------
    # Main Loop
    # -----------------------------------------------------------------------------------
    print("====> Starting Training ")
    training_start_time = datetime.now()

    detect_thres = 0.5

    train_history = []
    val_history = []
    lr_history = []

    best_iou = 0

    # Summary file
    summary_file = os.path.join(results_store_dir, 'summary.txt')
    file_handle = open(summary_file, 'w+')

    file_handle.write("Data Set Parameters {}\n".format('-' * 60))
    file_handle.write("Source           : {}\n".format(data_set_dir))
    file_handle.write("Train Set Mean {}, std {}\n".format(
        train_set.data_set_mean, train_set.data_set_std))
    file_handle.write("Validation Set Mean {}, std {}\n".format(
        val_set.data_set_mean, val_set.data_set_std))

    file_handle.write("Training Parameters {}\n".format('-' * 60))
    file_handle.write("Train images     : {}\n".format(len(train_set.images)))
    file_handle.write("Val images       : {}\n".format(len(val_set.images)))
    file_handle.write("Train batch size : {}\n".format(train_batch_size))
    file_handle.write("Val batch size   : {}\n".format(test_batch_size))
    file_handle.write("Epochs           : {}\n".format(num_epochs))
    file_handle.write("Optimizer        : {}\n".format(
        optimizer.__class__.__name__))
    file_handle.write("learning rate    : {}\n".format(learning_rate))
    file_handle.write("Loss Fcn         : {}\n".format(
        criterion.__class__.__name__))
    file_handle.write("Gaussian Regularization sigma        : {}\n".format(
        gaussian_kernel_sigma))
    file_handle.write(
        "Gaussian Regularization weight        : {}\n".format(lambda1))
    file_handle.write("IoU Threshold    : {}\n".format(detect_thres))
    file_handle.write("Image pre-processing :\n")
    print(pre_process_transforms, file=file_handle)

    file_handle.write("Model Parameters {}\n".format('-' * 63))
    file_handle.write("Model Name       : {}\n".format(
        model.__class__.__name__))
    file_handle.write("\n")
    print(model, file=file_handle)

    temp = vars(model)  # Returns a dictionary.
    file_handle.write("Model Parameters:\n")
    p = [item for item in temp if not item.startswith('_')]
    for var in sorted(p):
        file_handle.write("{}: {}\n".format(var, getattr(model, var)))

    layers = temp['_modules']  # Returns all top level modules (layers)
    if 'contour_integration_layer' in layers:

        file_handle.write("Contour Integration Layer: {}\n".format(
            model.contour_integration_layer.__class__.__name__))

        # print fixed hyper parameters
        file_handle.write("Hyper parameters\n")

        cont_int_layer_vars = [
            item for item in vars(model.contour_integration_layer)
            if not item.startswith('_')
        ]
        for var in sorted(cont_int_layer_vars):
            file_handle.write("\t{}: {}\n".format(
                var, getattr(model.contour_integration_layer, var)))

        # print parameter names and whether they are trainable
        file_handle.write("Contour Integration Layer Parameters\n")
        layer_params = vars(model.contour_integration_layer)['_parameters']
        for k, v in sorted(layer_params.items()):
            file_handle.write("\t{}: requires_grad {}\n".format(
                k, v.requires_grad))

    file_handle.write("{}\n".format('-' * 80))
    file_handle.write("Training details\n")
    file_handle.write("Epoch, train_loss, train_iou, val_loss, val_iou, lr\n")

    print("train_batch_size={}, test_batch_size={}, lr={}, epochs={}".format(
        train_batch_size, test_batch_size, learning_rate, num_epochs))

    for epoch in range(0, num_epochs):

        epoch_start_time = datetime.now()

        train_history.append(train())
        val_history.append(validate())

        lr_history.append(get_lr(optimizer))
        lr_scheduler.step()  # one step per epoch; passing epoch is deprecated

        print(
            "Epoch [{}/{}], Train: loss={:0.4f}, IoU={:0.4f}. Val: loss={:0.4f}, "
            "IoU={:0.4f}. Time {}".format(epoch + 1, num_epochs,
                                          train_history[epoch][0],
                                          train_history[epoch][1],
                                          val_history[epoch][0],
                                          val_history[epoch][1],
                                          datetime.now() - epoch_start_time))

        if val_history[epoch][1] > best_iou:
            best_iou = val_history[epoch][1]
            torch.save(model.state_dict(),
                       os.path.join(results_store_dir, 'best_accuracy.pth'))

        file_handle.write(
            "[{}, {:0.4f}, {:0.4f}, {:0.4f}, {:0.4f}, {}],\n".format(
                epoch + 1, train_history[epoch][0], train_history[epoch][1],
                val_history[epoch][0], val_history[epoch][1],
                lr_history[epoch]))

    training_time = datetime.now() - training_start_time
    print('Finished Training. Training took {}'.format(training_time))

    file_handle.write("{}\n".format('-' * 80))
    file_handle.write("Train Duration       : {}\n".format(training_time))
    file_handle.close()

    # -----------------------------------------------------------------------------------
    # Plots
    # -----------------------------------------------------------------------------------
    train_history = np.array(train_history)
    val_history = np.array(val_history)

    f = plt.figure()
    plt.title("Loss")
    plt.plot(train_history[:, 0], label='train')
    plt.plot(val_history[:, 0], label='validation')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()
    f.savefig(os.path.join(results_store_dir, 'loss.jpg'), format='jpg')

    f = plt.figure()
    plt.title("IoU")
    plt.plot(train_history[:, 1], label='train')
    plt.plot(val_history[:, 1], label='validation')
    plt.xlabel('Epoch')
    plt.legend()
    plt.grid(True)
    f.savefig(os.path.join(results_store_dir, 'iou.jpg'), format='jpg')

    # -----------------------------------------------------------------------------------
    # Run Li 2006 experiments
    # -----------------------------------------------------------------------------------
    print("====> Running Experiments")
    experiment_gain_vs_len.main(model,
                                base_results_dir=results_store_dir,
                                iou_results=False)
    experiment_gain_vs_len.main(model,
                                base_results_dir=results_store_dir,
                                iou_results=False,
                                frag_size=np.array([11, 11]))
    experiment_gain_vs_spacing.main(model, base_results_dir=results_store_dir)
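A minimal sketch of the get_lr helper used in the training loops (assuming the
single parameter-group optimizer these examples create), followed by a
hypothetical invocation of Example #1; the model class and paths below are
placeholders, not names from the source repo:

def get_lr(opt):
    # All examples build the optimizer with a single parameter group.
    return opt.param_groups[0]['lr']

if __name__ == '__main__':
    net = ContourIntegrationModel()  # hypothetical model class
    main(net,
         train_params={
             'train_batch_size': 16,
             'test_batch_size': 1,
             'learning_rate': 1e-4,
             'num_epochs': 50,
             'gaussian_reg_weight': 1e-5,  # Example #1 requires these two keys
             'gaussian_reg_sigma': 10,
         },
         data_set_params={'data_set_dir': './data/bsds'})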
Example #2
def main(model,
         train_params,
         data_set_params,
         base_results_store_dir='./results'):
    # -----------------------------------------------------------------------------------
    # Validate Parameters
    # -----------------------------------------------------------------------------------
    print("====> Validating Parameters ")
    # Pathfinder Dataset
    # ------------------
    pathfinder_required_data_set_params = ['pathfinder_data_set_dir']
    for key in pathfinder_required_data_set_params:
        assert key in data_set_params, 'data_set_params does not have required key {}'.format(
            key)
    pathfinder_data_set_dir = data_set_params['pathfinder_data_set_dir']

    # Optional
    pathfinder_train_subset_size = data_set_params.get(
        'pathfinder_train_subset_size', None)
    pathfinder_test_subset_size = data_set_params.get(
        'pathfinder_test_subset_size', None)

    # Contour Dataset
    # ---------------
    contour_required_data_set_params = ['contour_data_set_dir']
    for key in contour_required_data_set_params:
        assert key in data_set_params, 'data_set_params does not have required key {}'.format(
            key)
    contour_data_set_dir = data_set_params['contour_data_set_dir']

    # Optional
    contour_train_subset_size = data_set_params.get(
        'contour_train_subset_size', None)
    contour_test_subset_size = data_set_params.get('contour_test_subset_size',
                                                   None)
    c_len_arr = data_set_params.get('c_len_arr', None)
    beta_arr = data_set_params.get('beta_arr', None)
    alpha_arr = data_set_params.get('alpha_arr', None)
    gabor_set_arr = data_set_params.get('gabor_set_arr', None)

    # Training
    # --------
    required_training_params = \
        ['train_batch_size', 'test_batch_size', 'learning_rate', 'num_epochs']
    for key in required_training_params:
        assert key in train_params, 'train_params does not have required key {}'.format(
            key)
    train_batch_size = train_params['train_batch_size']
    test_batch_size = train_params['test_batch_size']
    learning_rate = train_params['learning_rate']
    num_epochs = train_params['num_epochs']

    # Optional
    lambda1 = train_params.get('gaussian_reg_weight', 0)
    gaussian_kernel_sigma = train_params.get('gaussian_reg_sigma', 0)
    use_gaussian_reg_on_lateral_kernels = False
    if lambda1 != 0 and gaussian_kernel_sigma != 0:
        use_gaussian_reg_on_lateral_kernels = True

    # -----------------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------------
    print("====> Loading Model ")
    print("Name: {}".format(model.__class__.__name__))
    print(model)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Get name of contour integration layer
    temp = vars(model)  # Returns a dictionary.
    layers = temp['_modules']  # Returns all top level modules (layers)
    cont_int_layer_type = ''
    if 'contour_integration_layer' in layers:
        cont_int_layer_type = model.contour_integration_layer.__class__.__name__

    # Actual Results store directory
    results_store_dir = os.path.join(
        base_results_store_dir, model.__class__.__name__ + '_' +
        cont_int_layer_type + datetime.now().strftime("_%Y%m%d_%H%M%S"))

    if not os.path.exists(results_store_dir):
        os.makedirs(results_store_dir)

    # -----------------------------------------------------------------------------------
    # Data Loaders
    # -----------------------------------------------------------------------------------
    print("====> Setting up data loaders ")
    data_load_start_time = datetime.now()

    # Pathfinder
    # --------------------------------------
    print("Setting up Pathfinder Data loaders... ")
    print("Data Source: {}".format(pathfinder_data_set_dir))

    # Pre-processing
    # Imagenet Mean and STD
    ch_mean = [0.485, 0.456, 0.406]
    ch_std = [0.229, 0.224, 0.225]

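    # utils.PunctureImage presumably occludes each image with random bubbles;
    # n_bubbles sets the count and fwhm the range of bubble widths (inferred
    # from the parameter names; the transform comes from the project's utils).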
    pathfinder_transforms_list = [
        transforms.Normalize(mean=ch_mean, std=ch_std),
        utils.PunctureImage(n_bubbles=100, fwhm=np.array([7, 9, 11, 13, 15]))
    ]
    pathfinder_pre_process_transforms = transforms.Compose(
        pathfinder_transforms_list)

    pathfinder_train_set = dataset_pathfinder.PathfinderNaturalImages(
        data_dir=os.path.join(pathfinder_data_set_dir, 'train'),
        transform=pathfinder_pre_process_transforms,
        subset_size=pathfinder_train_subset_size,
    )

    pathfinder_train_data_loader = DataLoader(dataset=pathfinder_train_set,
                                              num_workers=4,
                                              batch_size=train_batch_size,
                                              shuffle=True,
                                              pin_memory=True)

    pathfinder_val_set = dataset_pathfinder.PathfinderNaturalImages(
        data_dir=os.path.join(pathfinder_data_set_dir, 'test'),
        transform=pathfinder_pre_process_transforms,
        subset_size=pathfinder_test_subset_size)

    pathfinder_val_data_loader = DataLoader(dataset=pathfinder_val_set,
                                            num_workers=4,
                                            batch_size=test_batch_size,
                                            shuffle=True,
                                            pin_memory=True)
    print("Pathfinder Data loading Took {}. # Train {}, # Test {}".format(
        datetime.now() - data_load_start_time,
        len(pathfinder_train_data_loader) * train_batch_size,
        len(pathfinder_val_data_loader) * test_batch_size))

    # Contour Dataset
    # ---------------
    print("Setting up Contour Data loaders... ")
    data_load_start_time = datetime.now()
    print("Data Source: {}".format(contour_data_set_dir))
    print(
        "\tRestrictions:\n\t clen={},\n\t beta={},\n\t alpha={},\n\t gabor_sets={},\n\t "
        "train_subset={},\n\ttest subset={}\n".format(
            c_len_arr, beta_arr, alpha_arr, gabor_set_arr,
            contour_train_subset_size, contour_test_subset_size))

    # Get mean/std of dataset
    meta_data_file = os.path.join(contour_data_set_dir,
                                  'dataset_metadata.pickle')
    with open(meta_data_file, 'rb') as file_handle:
        meta_data = pickle.load(file_handle)

    # Pre-processing
    contour_transforms_list = [
        transforms.Normalize(mean=meta_data['channel_mean'],
                             std=meta_data['channel_std']),
        utils.PunctureImage(n_bubbles=100, fwhm=np.array([7, 9, 11, 13, 15]))
    ]

    contour_pre_process_transforms = transforms.Compose(
        contour_transforms_list)

    contour_train_set = dataset_contour.Fields1993(
        data_dir=os.path.join(contour_data_set_dir, 'train'),
        bg_tile_size=meta_data["full_tile_size"],
        transform=contour_pre_process_transforms,
        subset_size=contour_train_subset_size,
        c_len_arr=c_len_arr,
        beta_arr=beta_arr,
        alpha_arr=alpha_arr,
        gabor_set_arr=gabor_set_arr)

    contour_train_data_loader = DataLoader(dataset=contour_train_set,
                                           num_workers=4,
                                           batch_size=train_batch_size,
                                           shuffle=True,
                                           pin_memory=True)

    contour_val_set = dataset_contour.Fields1993(
        data_dir=os.path.join(contour_data_set_dir, 'val'),
        bg_tile_size=meta_data["full_tile_size"],
        transform=contour_pre_process_transforms,
        subset_size=contour_test_subset_size,
        c_len_arr=c_len_arr,
        beta_arr=beta_arr,
        alpha_arr=alpha_arr,
        gabor_set_arr=gabor_set_arr)

    contour_val_data_loader = DataLoader(dataset=contour_val_set,
                                         num_workers=4,
                                         batch_size=test_batch_size,
                                         shuffle=True,
                                         pin_memory=True)

    print("Contour Data loading Took {}. # Train {}, # Test {}".format(
        datetime.now() - data_load_start_time,
        len(contour_train_data_loader) * train_batch_size,
        len(contour_val_data_loader) * test_batch_size))

    # -----------------------------------------------------------------------------------
    # Loss / optimizer
    # -----------------------------------------------------------------------------------
    detect_thres = 0.5

    optimizer = optim.Adam(filter(lambda p1: p1.requires_grad,
                                  model.parameters()),
                           lr=learning_rate)

    lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                             step_size=30,
                                             gamma=0.1)

    criterion = nn.BCEWithLogitsLoss().to(device)

    if use_gaussian_reg_on_lateral_kernels:
        gaussian_mask_e = 1 - utils.get_2d_gaussian_kernel(
            model.contour_integration_layer.lateral_e.weight.shape[2:],
            sigma=gaussian_kernel_sigma)
        gaussian_mask_i = 1 - utils.get_2d_gaussian_kernel(
            model.contour_integration_layer.lateral_i.weight.shape[2:],
            sigma=gaussian_kernel_sigma)

        gaussian_mask_e = torch.from_numpy(gaussian_mask_e).float().to(device)
        gaussian_mask_i = torch.from_numpy(gaussian_mask_i).float().to(device)

        def inverse_gaussian_regularization(weight_e, weight_i):
            loss1 = (gaussian_mask_e * weight_e).abs().sum() + \
                    (gaussian_mask_i * weight_i).abs().sum()
            # print("Loss1: {:0.4f}".format(loss1))

            return loss1

    # -----------------------------------------------------------------------------------
    #  Training Validation Routines
    # -----------------------------------------------------------------------------------
    def train_pathfinder():
        """ Train for one Epoch over the train data set """
        model.train()
        e_loss = 0
        e_acc = 0

        for iteration, data_loader_out in enumerate(
                pathfinder_train_data_loader, 1):
            optimizer.zero_grad()  # zero the parameter gradients

            img, label, _, _, _ = data_loader_out

            img = img.to(device)
            label = label.to(device)

            # Second part is pathfinder out
            _, label_out = model(img)

            bce_loss = criterion(label_out, label.float())
            reg_loss = 0

            if use_gaussian_reg_on_lateral_kernels:
                reg_loss = \
                    inverse_gaussian_regularization(
                        model.contour_integration_layer.lateral_e.weight,
                        model.contour_integration_layer.lateral_i.weight
                    )

            total_loss = bce_loss + lambda1 * reg_loss
            acc = binary_acc(label_out, label)

            # print("Loss: {:0.4f}, bce_loss {:0.4f}, lateral kernels reg_loss {:0.4f}, "
            #       "acc {:0.4f}".format(total_loss, bce_loss,  lambda1 * reg_loss, acc))

            total_loss.backward()
            optimizer.step()

            e_loss += total_loss.item()
            e_acc += acc.item()

        e_loss = e_loss / len(pathfinder_train_data_loader)
        e_acc = e_acc / len(pathfinder_train_data_loader)

        # iou_arr = ["{:0.2f}".format(item) for item in e_iou]
        # print("Train Epoch {} Loss = {:0.4f}, Acc = {}".format(epoch, e_loss, e_acc))

        return e_loss, e_acc

    def validate_pathfinder():
        """ Get loss over validation set """
        model.eval()
        e_loss = 0
        e_acc = 0

        with torch.no_grad():
            for iteration, data_loader_out in enumerate(
                    pathfinder_val_data_loader, 1):

                img, label, _, _, _ = data_loader_out

                img = img.to(device)
                label = label.to(device)

                _, label_out = model(img)

                bce_loss = criterion(label_out, label.float())
                reg_loss = 0

                if use_gaussian_reg_on_lateral_kernels:
                    reg_loss = \
                        inverse_gaussian_regularization(
                            model.contour_integration_layer.lateral_e.weight,
                            model.contour_integration_layer.lateral_i.weight
                        )

                total_loss = bce_loss + lambda1 * reg_loss
                acc = binary_acc(label_out, label)

                e_loss += total_loss.item()
                e_acc += acc.item()

        e_loss = e_loss / len(pathfinder_val_data_loader)
        e_acc = e_acc / len(pathfinder_val_data_loader)

        # print("Val Loss = {:0.4f}, Accuracy={}".format(e_loss, e_acc))

        return e_loss, e_acc

    def train_contour():
        """ Train for one Epoch over the train data set """
        model.train()
        e_loss = 0
        e_iou = 0

        for iteration, (img, label) in enumerate(contour_train_data_loader, 1):
            optimizer.zero_grad()  # zero the parameter gradients

            img = img.to(device)
            label = label.to(device)

            label_out, _ = model(img)
            batch_loss = criterion(label_out, label.float())

            # inverse_gaussian_regularization is only defined when the flag is
            # set, so guard the call as train_pathfinder does.
            reg_loss = 0
            if use_gaussian_reg_on_lateral_kernels:
                reg_loss = \
                    inverse_gaussian_regularization(
                        model.contour_integration_layer.lateral_e.weight,
                        model.contour_integration_layer.lateral_i.weight
                    )

            total_loss = batch_loss + lambda1 * reg_loss

            # print("Total Loss: {:0.4f}, cross_entropy_loss {:0.4f}, reg_loss {:0.4f}".format(
            #     total_loss, batch_loss, lambda1 * reg_loss))

            total_loss.backward()
            optimizer.step()

            e_loss += total_loss.item()

            preds = (torch.sigmoid(label_out) > detect_thres)
            e_iou += utils.intersection_over_union(
                preds.float(), label.float()).cpu().detach().numpy()

        e_loss = e_loss / len(contour_train_data_loader)
        e_iou = e_iou / len(contour_train_data_loader)

        # print("Train Epoch {} Loss = {:0.4f}, IoU={:0.4f}".format(epoch, e_loss, e_iou))

        return e_loss, e_iou

    def validate_contour():
        """ Get loss over validation set """
        model.eval()
        e_loss = 0
        e_iou = 0

        with torch.no_grad():
            for iteration, (img, label) in enumerate(contour_val_data_loader,
                                                     1):
                img = img.to(device)
                label = label.to(device)

                label_out, _ = model(img)
                batch_loss = criterion(label_out, label.float())

                reg_loss = 0
                if use_gaussian_reg_on_lateral_kernels:
                    reg_loss = \
                        inverse_gaussian_regularization(
                            model.contour_integration_layer.lateral_e.weight,
                            model.contour_integration_layer.lateral_i.weight
                        )

                total_loss = batch_loss + lambda1 * reg_loss

                e_loss += total_loss.item()
                preds = (torch.sigmoid(label_out) > detect_thres)
                e_iou += utils.intersection_over_union(
                    preds.float(), label.float()).cpu().detach().numpy()

        e_loss = e_loss / len(contour_val_data_loader)
        e_iou = e_iou / len(contour_val_data_loader)

        # print("Val Loss = {:0.4f}, IoU={:0.4f}".format(e_loss, e_iou))

        return e_loss, e_iou

    def write_training_and_model_details(f_handle):
        # Dataset Parameters:
        f_handle.write("Data Set Parameters {}\n".format('-' * 60))
        f_handle.write("CONTOUR DATASET \n")
        f_handle.write(
            "Source                   : {}\n".format(contour_data_set_dir))
        f_handle.write("Restrictions             :\n")
        f_handle.write("  Lengths                : {}\n".format(c_len_arr))
        f_handle.write("  Beta                   : {}\n".format(beta_arr))
        f_handle.write("  Alpha                  : {}\n".format(alpha_arr))
        f_handle.write("  Gabor Sets             : {}\n".format(gabor_set_arr))
        f_handle.write("  Train subset size      : {}\n".format(
            contour_train_subset_size))
        f_handle.write(
            "  Test subset size       : {}\n".format(contour_test_subset_size))
        f_handle.write("Number of Images         : Train {}, Test {}\n".format(
            len(contour_train_set.images), len(contour_val_set.images)))
        f_handle.write("Train Set Mean {}, std {}\n".format(
            contour_train_set.data_set_mean, contour_train_set.data_set_std))
        f_handle.write("Val   Set Mean {}, std {}\n".format(
            contour_val_set.data_set_mean, contour_val_set.data_set_std))

        f_handle.write("PATHFINDER  DATASET\n")
        f_handle.write(
            "Source                   : {}\n".format(pathfinder_data_set_dir))
        f_handle.write("Restrictions             :\n")
        f_handle.write("  Train subset size      : {}\n".format(
            pathfinder_train_subset_size))
        f_handle.write("  Test subset size       : {}\n".format(
            pathfinder_test_subset_size))
        f_handle.write("Number of Images         : Train {}, Test {}\n".format(
            len(pathfinder_train_set.images), len(pathfinder_val_set.images)))

        # Training Parameters:
        f_handle.write("Training Parameters {}\n".format('-' * 60))
        f_handle.write(
            "Train batch size         : {}\n".format(train_batch_size))
        f_handle.write(
            "Val batch size           : {}\n".format(test_batch_size))
        f_handle.write("Epochs                   : {}\n".format(num_epochs))
        f_handle.write("Optimizer                : {}\n".format(
            optimizer.__class__.__name__))
        f_handle.write("learning rate            : {}\n".format(learning_rate))
        f_handle.write("Loss Fcn                 : {}\n".format(
            criterion.__class__.__name__))
        f_handle.write(
            "Gaussian Regularization on lateral kernels: {}\n".format(
                use_gaussian_reg_on_lateral_kernels))
        if use_gaussian_reg_on_lateral_kernels:
            f_handle.write("  Gaussian Reg. sigma    : {}\n".format(
                gaussian_kernel_sigma))
            f_handle.write("  Gaussian Reg. weight   : {}\n".format(lambda1))
        f_handle.write("IoU Detection Threshold  : {}\n".format(detect_thres))

        f_handle.write("Image pre-processing :\n")
        f_handle.write("Contour Dataset:\n")
        print(contour_pre_process_transforms, file=f_handle)
        f_handle.write("Pathfinder Dataset:\n")
        print(pathfinder_pre_process_transforms, file=f_handle)

        # Model Details
        f_handle.write("Model Parameters {}\n".format('-' * 63))
        f_handle.write("Model Name       : {}\n".format(
            model.__class__.__name__))
        f_handle.write("\n")
        print(model, file=f_handle)

        tmp = vars(model)  # Returns a dictionary.
        p = [item for item in tmp if not item.startswith('_')]
        for var in sorted(p):
            f_handle.write("{}: {}\n".format(var, getattr(model, var)))

        layers1 = tmp['_modules']  # Returns all top level modules (layers)
        if 'contour_integration_layer' in layers1:

            f_handle.write("Contour Integration Layer: {}\n".format(
                model.contour_integration_layer.__class__.__name__))

            # print fixed hyper parameters
            f_handle.write("Hyper parameters\n")

            cont_int_layer_vars = \
                [item for item in vars(model.contour_integration_layer) if not item.startswith('_')]
            for var in sorted(cont_int_layer_vars):
                f_handle.write("\t{}: {}\n".format(
                    var, getattr(model.contour_integration_layer, var)))

            # print parameter names and whether they are trainable
            f_handle.write("Contour Integration Layer Parameters\n")
            layer_params = vars(model.contour_integration_layer)['_parameters']
            for k, v in sorted(layer_params.items()):
                f_handle.write("\t{}: requires_grad {}\n".format(
                    k, v.requires_grad))

        # Headers for columns in training details in summary file
        f_handle.write("{}\n".format('-' * 80))
        f_handle.write("Training details\n")
        f_handle.write(
            "[Epoch,\ncontour train_loss, train_iou, val_loss, val_iou\n")
        f_handle.write(
            "pathfinder train_loss, train_acc, val_loss, val_acc]\n")

    # -----------------------------------------------------------------------------------
    # Main Loop
    # -----------------------------------------------------------------------------------
    print("====> Starting Training ")
    training_start_time = datetime.now()

    pathfinder_train_history = []
    pathfinder_val_history = []
    contour_train_history = []
    contour_val_history = []
    lr_history = []

    # Summary File
    # ------------
    summary_file = os.path.join(results_store_dir, 'summary.txt')
    file_handle = open(summary_file, 'w+')

    write_training_and_model_details(file_handle)

    # Actual main loop start
    # ----------------------
    print("train_batch_size={}, test_batch_size= {}, lr={}, epochs={}".format(
        train_batch_size, test_batch_size, learning_rate, num_epochs))

    for epoch in range(0, num_epochs):

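        # Each epoch interleaves the two tasks on the shared model: one full
        # pass over the contour dataset, then one over the pathfinder dataset.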
        # Contour Dataset First
        epoch_start_time = datetime.now()
        contour_train_history.append(train_contour())
        contour_val_history.append(validate_contour())

        print(
            "Epoch [{}/{}], Contour    Train: loss={:0.4f}, IoU={:0.4f}. Val: "
            "loss={:0.4f}, IoU={:0.4f}. Time {}".format(
                epoch + 1, num_epochs, contour_train_history[epoch][0],
                contour_train_history[epoch][1], contour_val_history[epoch][0],
                contour_val_history[epoch][1],
                datetime.now() - epoch_start_time))

        # Pathfinder Dataset
        epoch_start_time = datetime.now()
        pathfinder_train_history.append(train_pathfinder())
        pathfinder_val_history.append(validate_pathfinder())

        print(
            "Epoch [{}/{}], Pathfinder Train: loss={:0.4f}, Acc={:0.3f}. Val: "
            "loss={:0.4f}, Acc={:0.3f}. Time {}".format(
                epoch + 1, num_epochs, pathfinder_train_history[epoch][0],
                pathfinder_train_history[epoch][1],
                pathfinder_val_history[epoch][0],
                pathfinder_val_history[epoch][1],
                datetime.now() - epoch_start_time))

        lr_history.append(get_lr(optimizer))
        lr_scheduler.step()  # one step per epoch; passing epoch is deprecated

        # Save Last epoch weights
        torch.save(model.state_dict(),
                   os.path.join(results_store_dir, 'last_epoch.pth'))

        file_handle.write("[{}, {:0.4f}, {:0.4f}, {:0.4f}, {:0.4f}, "
                          "{:0.4f}, {:0.3f}, {:0.4f}, {:0.3f}],\n".format(
                              epoch + 1, contour_train_history[epoch][0],
                              contour_train_history[epoch][1],
                              contour_val_history[epoch][0],
                              contour_val_history[epoch][1],
                              pathfinder_train_history[epoch][0],
                              pathfinder_train_history[epoch][1],
                              pathfinder_val_history[epoch][0],
                              pathfinder_val_history[epoch][1]))

    training_time = datetime.now() - training_start_time
    print('Finished Training. Training took {}'.format(training_time))

    file_handle.write("{}\n".format('-' * 80))
    file_handle.write("Train Duration       : {}\n".format(training_time))
    file_handle.close()

    # -----------------------------------------------------------------------------------
    # Plots
    # -----------------------------------------------------------------------------------
    plot_pathfinder_results(pathfinder_train_history, pathfinder_val_history,
                            results_store_dir)
    plot_contour_results(contour_train_history, contour_val_history,
                         results_store_dir)

    # -----------------------------------------------------------------------------------
    # Run Li 2006 experiments
    # -----------------------------------------------------------------------------------
    print("====> Running Experiments")
    experiment_gain_vs_len.main(model,
                                base_results_dir=results_store_dir,
                                iou_results=False)
    experiment_gain_vs_len.main(model,
                                results_store_dir,
                                iou_results=False,
                                frag_size=np.array([14, 14]))

    experiment_gain_vs_spacing.main(model, base_results_dir=results_store_dir)
    experiment_gain_vs_spacing.main(model,
                                    results_store_dir,
                                    frag_size=np.array([14, 14]))
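Example #2 also calls a binary_acc helper that is not shown. A plausible
sketch, assuming it mirrors the sigmoid-threshold logic used for IoU:

def binary_acc(y_pred_logits, y_true):
    # Fraction of predictions on the correct side of the 0.5 sigmoid cut-off,
    # returned as a tensor so the callers' .item() works.
    y_pred = (torch.sigmoid(y_pred_logits) > 0.5).float()
    return (y_pred == y_true.float()).float().mean()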
Example #3
def main(model,
         train_params,
         data_set_params,
         base_results_store_dir='./results'):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # -----------------------------------------------------------------------------------
    # Sanity Checks
    # -----------------------------------------------------------------------------------
    # Validate Data set parameters
    # ----------------------------
    required_data_set_params = ['data_set_dir']
    for key in required_data_set_params:
        assert key in data_set_params, 'data_set_params does not have required key {}'.format(
            key)
    data_set_dir = data_set_params['data_set_dir']

    # Optional
    train_subset_size = data_set_params.get('train_subset_size', None)
    test_subset_size = data_set_params.get('test_subset_size', None)
    resize_size = data_set_params.get('resize_size', None)

    # Validate training parameters
    # ----------------------------
    required_training_params = \
        ['train_batch_size', 'test_batch_size', 'learning_rate', 'num_epochs']
    for key in required_training_params:
        assert key in train_params, 'train_params does not have required key {}'.format(
            key)
    train_batch_size = train_params['train_batch_size']
    test_batch_size = train_params['test_batch_size']
    learning_rate = train_params['learning_rate']
    num_epochs = train_params['num_epochs']

    # Optional
    lambda1 = train_params.get('gaussian_reg_weight', 0)
    gaussian_kernel_sigma = train_params.get('gaussian_reg_sigma', 0)
    use_gaussian_reg_on_lateral_kernels = False
    if lambda1 != 0 and gaussian_kernel_sigma != 0:
        use_gaussian_reg_on_lateral_kernels = True

    # -----------------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------------
    print("====> Loading Model ")
    print("Name: {}".format(model.__class__.__name__))
    print(model)

    # Get name of contour integration layer
    temp = vars(model)  # Returns a dictionary.
    layers = temp['_modules']  # Returns all top level modules (layers)
    cont_int_layer_type = ''
    if 'contour_integration_layer' in layers:
        cont_int_layer_type = model.contour_integration_layer.__class__.__name__

    results_store_dir = os.path.join(
        base_results_store_dir, model.__class__.__name__ + '_' +
        cont_int_layer_type + datetime.now().strftime("_%Y%m%d_%H%M%S"))

    if not os.path.exists(results_store_dir):
        os.makedirs(results_store_dir)

    # -----------------------------------------------------------------------------------
    # Data Loader
    # -----------------------------------------------------------------------------------
    print("====> Setting up data loaders ")
    data_load_start_time = datetime.now()

    print("Data Source: {}".format(data_set_dir))

    # TODO: get mean/std of dataset
    # Imagenet Mean and STD
    ch_mean = [0.485, 0.456, 0.406]
    ch_std = [0.229, 0.224, 0.225]
    # print("Channel mean {}, std {}".format(meta_data['channel_mean'], meta_data['channel_std']))

    # Pre-processing
    transforms_list = [
        transforms.Normalize(mean=ch_mean, std=ch_std),
        # utils.PunctureImage(n_bubbles=100, fwhm=20, peak_bubble_transparency=0)
    ]
    pre_process_transforms = transforms.Compose(transforms_list)

    train_set = dataset_biped.BipedDataSet(
        data_dir=data_set_dir,
        dataset_type='train',
        transform=pre_process_transforms,
        subset_size=train_subset_size,
        resize_size=resize_size,
    )
    train_batch_size = min(train_batch_size, len(train_set))

    train_data_loader = DataLoader(dataset=train_set,
                                   num_workers=4,
                                   batch_size=train_batch_size,
                                   shuffle=True,
                                   pin_memory=True)

    val_set = dataset_biped.BipedDataSet(
        data_dir=data_set_dir,
        dataset_type='test',
        transform=pre_process_transforms,
        subset_size=test_subset_size,
        resize_size=resize_size,
    )
    test_batch_size = min(test_batch_size, len(val_set))

    val_data_loader = DataLoader(dataset=val_set,
                                 num_workers=4,
                                 batch_size=test_batch_size,
                                 shuffle=False,
                                 pin_memory=True)

    print("Data loading Took {}. # Train {}, # Test {}".format(
        datetime.now() - data_load_start_time,
        len(train_data_loader) * train_batch_size,
        len(val_data_loader) * test_batch_size))

    # -----------------------------------------------------------------------------------
    # Loss / optimizer
    # -----------------------------------------------------------------------------------
    optimizer = optim.Adam(filter(lambda params: params.requires_grad,
                                  model.parameters()),
                           lr=learning_rate)

    lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                             step_size=30,
                                             gamma=0.1)

    criterion = nn.BCEWithLogitsLoss().to(device)

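    # Unlike the earlier examples, IoU is tracked at a sweep of detection
    # thresholds so the best operating point can be read off the curves.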
    # detect_thres = 0.5
    detect_thres = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])

    if use_gaussian_reg_on_lateral_kernels:
        gaussian_mask_e = 1 - utils.get_2d_gaussian_kernel(
            model.contour_integration_layer.lateral_e.weight.shape[2:],
            sigma=gaussian_kernel_sigma)
        gaussian_mask_i = 1 - utils.get_2d_gaussian_kernel(
            model.contour_integration_layer.lateral_i.weight.shape[2:],
            sigma=gaussian_kernel_sigma)

        gaussian_mask_e = torch.from_numpy(gaussian_mask_e).float().to(device)
        gaussian_mask_i = torch.from_numpy(gaussian_mask_i).float().to(device)

        def inverse_gaussian_regularization(weight_e, weight_i):
            loss1 = (gaussian_mask_e * weight_e).abs().sum() + \
                    (gaussian_mask_i * weight_i).abs().sum()
            # print("Loss1: {:0.4f}".format(loss1))

            return loss1

    # -----------------------------------------------------------------------------------
    #  Training Validation Routines
    # -----------------------------------------------------------------------------------
    def train():
        """ Train for one Epoch over the train data set """
        model.train()
        e_loss = 0
        e_iou = np.zeros_like(detect_thres)

        for iteration, (img, label) in enumerate(train_data_loader, 1):
            optimizer.zero_grad()  # zero the parameter gradients

            img = img.to(device)
            label = label.to(device)

            label_out = model(img)

            bce_loss = criterion(label_out, label.float())
            reg_loss = 0

            if use_gaussian_reg_on_lateral_kernels:
                reg_loss = \
                    inverse_gaussian_regularization(
                        model.contour_integration_layer.lateral_e.weight,
                        model.contour_integration_layer.lateral_i.weight
                    )

            total_loss = bce_loss + lambda1 * reg_loss

            # print("Loss: {:0.4f}, bce_loss {:0.4f}, lateral kernels reg_loss {:0.4f}".format(
            #     total_loss, bce_loss,  lambda1 * reg_loss))

            total_loss.backward()
            optimizer.step()

            e_loss += total_loss.item()

            sigmoid_label_out = torch.sigmoid(label_out)

            for th_idx, thresh in enumerate(detect_thres):
                preds = sigmoid_label_out > thresh
                e_iou[th_idx] += utils.intersection_over_union(
                    preds.float(), label.float()).cpu().detach().numpy()

        e_loss = e_loss / len(train_data_loader)
        e_iou = e_iou / len(train_data_loader)

        # iou_arr = ["{:0.2f}".format(item) for item in e_iou]
        # print("Train Epoch {} Loss = {:0.4f}, IoU={}".format(epoch, e_loss, iou_arr))

        return e_loss, e_iou

    def validate():
        """ Get loss over validation set """
        model.eval()
        e_loss = 0
        e_iou = np.zeros_like(detect_thres)

        with torch.no_grad():
            for iteration, (img, label) in enumerate(val_data_loader, 1):
                img = img.to(device)
                label = label.to(device)

                label_out = model(img)
                bce_loss = criterion(label_out, label.float())
                reg_loss = 0

                if use_gaussian_reg_on_lateral_kernels:
                    reg_loss = \
                        inverse_gaussian_regularization(
                            model.contour_integration_layer.lateral_e.weight,
                            model.contour_integration_layer.lateral_i.weight
                        )

                total_loss = bce_loss + lambda1 * reg_loss

                e_loss += total_loss.item()

                sigmoid_label_out = torch.sigmoid(label_out)

                for th_idx, thresh in enumerate(detect_thres):
                    preds = sigmoid_label_out > thresh
                    e_iou[th_idx] += utils.intersection_over_union(
                        preds.float(), label.float()).cpu().detach().numpy()

        e_loss = e_loss / len(val_data_loader)
        e_iou = e_iou / len(val_data_loader)

        # print("Val Loss = {:0.4f}, IoU={}".format(e_loss, e_iou))

        return e_loss, e_iou

    # -----------------------------------------------------------------------------------
    # Main Loop
    # -----------------------------------------------------------------------------------
    print("====> Starting Training ")
    training_start_time = datetime.now()

    train_history = []
    val_history = []
    lr_history = []

    best_iou = 0

    # Summary file
    summary_file = os.path.join(results_store_dir, 'summary.txt')
    file_handle = open(summary_file, 'w+')

    file_handle.write("Data Set Parameters {}\n".format('-' * 60))
    file_handle.write("Source           : {}\n".format(data_set_dir))
    file_handle.write("Train Set Mean {}, std {}\n".format(
        train_set.data_set_mean, train_set.data_set_std))
    file_handle.write("Validation Set Mean {}, std {}\n".format(
        val_set.data_set_mean, val_set.data_set_std))

    file_handle.write("Training Parameters {}\n".format('-' * 60))
    file_handle.write("Train images     : {}\n".format(len(train_set.images)))
    file_handle.write("Val images       : {}\n".format(len(val_set.images)))
    file_handle.write("Train batch size : {}\n".format(train_batch_size))
    file_handle.write("Val batch size   : {}\n".format(test_batch_size))
    file_handle.write("Epochs           : {}\n".format(num_epochs))
    file_handle.write("Optimizer        : {}\n".format(
        optimizer.__class__.__name__))
    file_handle.write("learning rate    : {}\n".format(learning_rate))
    file_handle.write("Loss Fcn         : {}\n".format(
        criterion.__class__.__name__))

    file_handle.write(
        "Use Gaussian Regularization on lateral kernels: {}\n".format(
            use_gaussian_reg_on_lateral_kernels))
    if use_gaussian_reg_on_lateral_kernels:
        file_handle.write("Gaussian Regularization sigma        : {}\n".format(
            gaussian_kernel_sigma))
        file_handle.write(
            "Gaussian Regularization weight        : {}\n".format(lambda1))

    file_handle.write("IoU Threshold    : {}\n".format(detect_thres))
    file_handle.write("Image pre-processing :\n")
    print(pre_process_transforms, file=file_handle)

    file_handle.write("Model Parameters {}\n".format('-' * 63))
    file_handle.write("Model Name       : {}\n".format(
        model.__class__.__name__))
    file_handle.write("\n")
    print(model, file=file_handle)

    temp = vars(model)  # Returns a dictionary.
    file_handle.write("Model Parameters:\n")
    p = [item for item in temp if not item.startswith('_')]
    for var in sorted(p):
        file_handle.write("{}: {}\n".format(var, getattr(model, var)))

    layers = temp['_modules']  # Returns all top level modules (layers)
    if 'contour_integration_layer' in layers:

        file_handle.write("Contour Integration Layer: {}\n".format(
            model.contour_integration_layer.__class__.__name__))

        # print fixed hyper parameters
        file_handle.write("Hyper parameters\n")

        cont_int_layer_vars = \
            [item for item in vars(model.contour_integration_layer) if not item.startswith('_')]
        for var in sorted(cont_int_layer_vars):
            file_handle.write("\t{}: {}\n".format(
                var, getattr(model.contour_integration_layer, var)))

        # print parameter names and whether they are trainable
        file_handle.write("Contour Integration Layer Parameters\n")
        layer_params = vars(model.contour_integration_layer)['_parameters']
        for k, v in sorted(layer_params.items()):
            file_handle.write("\t{}: requires_grad {}\n".format(
                k, v.requires_grad))

    file_handle.write("{}\n".format('-' * 80))
    file_handle.write("Training details\n")
    # file_handle.write("Epoch, train_loss, train_iou, val_loss, val_iou, lr\n")
    file_handle.write("Epoch, train_loss, ")
    for thres in detect_thres:
        file_handle.write("train_iou_{}, ".format(thres))
    file_handle.write("val_loss, ")
    for thres in detect_thres:
        file_handle.write("val_iou_{}, ".format(thres))
    file_handle.write("lr\n")

    print("train_batch_size={}, test_batch_size={}, lr={}, epochs={}".format(
        train_batch_size, test_batch_size, learning_rate, num_epochs))

    for epoch in range(0, num_epochs):

        epoch_start_time = datetime.now()

        train_history.append(train())
        val_history.append(validate())

        lr_history.append(get_lr(optimizer))
        lr_scheduler.step()  # one step per epoch; passing epoch is deprecated

        train_iou_arr = [
            "{:0.2f}".format(item) for item in train_history[epoch][1]
        ]
        val_iou_arr = [
            "{:0.2f}".format(item) for item in val_history[epoch][1]
        ]

        print(
            "Epoch [{}/{}], Train: loss={:0.4f}, IoU={}. Val: loss={:0.4f}, IoU={}."
            "Time {}".format(epoch + 1, num_epochs, train_history[epoch][0],
                             train_iou_arr, val_history[epoch][0], val_iou_arr,
                             datetime.now() - epoch_start_time))

        # Save best val accuracy weights
        max_val_iou = max(val_history[epoch][1])
        if max_val_iou > best_iou:
            best_iou = max_val_iou
            torch.save(model.state_dict(),
                       os.path.join(results_store_dir, 'best_accuracy.pth'))

        # Save Last epoch weights
        torch.save(model.state_dict(),
                   os.path.join(results_store_dir, 'last_epoch.pth'))

        file_handle.write("[{}, {:0.4f}, {}, {:0.4f}, {}, {}],\n".format(
            epoch + 1, train_history[epoch][0], train_iou_arr,
            val_history[epoch][0], val_iou_arr, lr_history[epoch]))

    training_time = datetime.now() - training_start_time
    print('Finished Training. Training took {}'.format(training_time))

    file_handle.write("{}\n".format('-' * 80))
    file_handle.write("Train Duration       : {}\n".format(training_time))
    file_handle.close()

    # -----------------------------------------------------------------------------------
    # Plots
    # -----------------------------------------------------------------------------------
    # dtype=object because each entry is (scalar loss, per-threshold IoU array)
    train_history = np.array(train_history, dtype=object)
    val_history = np.array(val_history, dtype=object)

    train_iou_mat = np.zeros((num_epochs, len(detect_thres)))
    val_iou_mat = np.zeros_like(train_iou_mat)
    for e_idx in range(num_epochs):
        train_iou_mat[e_idx, ] = train_history[e_idx, 1]
        val_iou_mat[e_idx, ] = val_history[e_idx, 1]

    f = plt.figure()
    plt.title("Loss")
    plt.plot(train_history[:, 0], label='train')
    plt.plot(val_history[:, 0], label='validation')
    plt.xlabel('Epoch')
    plt.grid(True)
    plt.legend()
    f.savefig(os.path.join(results_store_dir, 'loss.jpg'), format='jpg')

    f1, ax_arr1 = plt.subplots(1, figsize=(12.8, 9.6))
    f2, ax_arr2 = plt.subplots(1, figsize=(12.8, 9.6))

    for thres_idx, thres in enumerate(detect_thres):
        ax_arr1.plot(train_iou_mat[:, thres_idx],
                     label='train_th_{}'.format(thres))
        ax_arr2.plot(val_iou_mat[:, thres_idx],
                     label='val_th_{}'.format(thres))

    # Decorate and save the figures once, after all thresholds are plotted
    ax_arr1.set_title('Train IoU - various thresholds')
    ax_arr2.set_title('Validation IoU - various thresholds')

    ax_arr1.set_xlabel('Epoch')
    ax_arr2.set_xlabel('Epoch')

    ax_arr1.set_ylabel('IoU')
    ax_arr2.set_ylabel('IoU')

    ax_arr1.legend()
    ax_arr2.legend()

    ax_arr1.grid(True)
    ax_arr2.grid(True)

    f1.savefig(os.path.join(results_store_dir, 'iou_train.jpg'), format='jpg')
    f2.savefig(os.path.join(results_store_dir, 'iou_val.jpg'), format='jpg')

    # -----------------------------------------------------------------------------------
    # Run Li 2006 experiments
    # -----------------------------------------------------------------------------------
    print("====> Running Experiments")
    experiment_gain_vs_len.main(model,
                                base_results_dir=results_store_dir,
                                iou_results=False)
    experiment_gain_vs_len.main(model,
                                base_results_dir=results_store_dir,
                                iou_results=False,
                                frag_size=np.array([11, 11]))

    experiment_gain_vs_spacing.main(model, base_results_dir=results_store_dir)
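

# ------------------------------------------------------------------------------------
# Sketch of the get_lr helper called in the epoch loop above. Its definition is not
# part of this listing, so this is an assumption: like train_utils.get_lr used by the
# next example, it presumably reads the current learning rate off the optimizer's
# first parameter group.
# ------------------------------------------------------------------------------------
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
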
def main(model,
         train_params,
         data_set_params,
         base_results_store_dir='./results'):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # -----------------------------------------------------------------------------------
    # Sanity Checks
    # -----------------------------------------------------------------------------------
    # Validate Data set parameters
    # ----------------------------
    required_data_set_params = ['data_set_dir']
    for key in required_data_set_params:
        assert key in data_set_params, 'data_set_params does not have required key {}'.format(
            key)
    data_set_dir = data_set_params['data_set_dir']

    # Optional
    train_subset_size = data_set_params.get('train_subset_size', None)
    test_subset_size = data_set_params.get('test_subset_size', None)
    c_len_arr = data_set_params.get('c_len_arr', None)
    beta_arr = data_set_params.get('beta_arr', None)
    alpha_arr = data_set_params.get('alpha_arr', None)
    gabor_set_arr = data_set_params.get('gabor_set_arr', None)

    # Validate training parameters
    # ----------------------------
    required_training_params = \
        ['train_batch_size', 'test_batch_size', 'learning_rate', 'num_epochs']
    for key in required_training_params:
        assert key in train_params, 'training_params does not have required key {}'.format(
            key)
    train_batch_size = train_params['train_batch_size']
    test_batch_size = train_params['test_batch_size']
    learning_rate = train_params['learning_rate']
    num_epochs = train_params['num_epochs']

    clip_negative_lateral_weights = train_params.get(
        'clip_negative_lateral_weights', False)

    if 'lr_sched_step_size' not in train_params:
        train_params['lr_sched_step_size'] = 30
    if 'lr_sched_gamma' not in train_params:
        train_params['lr_sched_gamma'] = 0.1
    if 'random_seed' not in train_params:
        train_params['random_seed'] = 1

    torch.manual_seed(train_params['random_seed'])
    np.random.seed(train_params['random_seed'])
    # -----------------------------------------------------------------------------------
    # Model
    # -----------------------------------------------------------------------------------
    print("====> Loading Model ")
    print("Name: {}".format(model.__class__.__name__))
    print(model)

    # Get name of contour integration layer
    temp = vars(model)  # Returns a dictionary.
    layers = temp['_modules']  # Returns all top level modules (layers)
    cont_int_layer_type = ''
    if 'contour_integration_layer' in layers:
        cont_int_layer_type = model.contour_integration_layer.__class__.__name__

    results_store_dir = os.path.join(
        base_results_store_dir, model.__class__.__name__ + '_' +
        cont_int_layer_type + datetime.now().strftime("_%Y%m%d_%H%M%S"))
    if not os.path.exists(results_store_dir):
        os.makedirs(results_store_dir)

    # -----------------------------------------------------------------------------------
    # Data Loader
    # -----------------------------------------------------------------------------------
    print("====> Setting up data loaders ")
    data_load_start_time = datetime.now()

    print("Data Source: {}".format(data_set_dir))
    print(
        "Restrictions:\n clen={},\n beta={},\n alpha={},\n gabor_sets={},\n train_subset={},\n "
        "test subset={}\n".format(c_len_arr, beta_arr, alpha_arr,
                                  gabor_set_arr, train_subset_size,
                                  test_subset_size))

    # get mean/std of dataset
    meta_data_file = os.path.join(data_set_dir, 'dataset_metadata.pickle')
    with open(meta_data_file, 'rb') as file_handle:
        meta_data = pickle.load(file_handle)
    # print("Channel mean {}, std {}".format(meta_data['channel_mean'], meta_data['channel_std']))

    # Pre-processing
    normalize = transforms.Normalize(mean=meta_data['channel_mean'],
                                     std=meta_data['channel_std'])

    train_set = dataset.Fields1993(data_dir=os.path.join(
        data_set_dir, 'train'),
                                   bg_tile_size=meta_data["full_tile_size"],
                                   transform=normalize,
                                   subset_size=train_subset_size,
                                   c_len_arr=c_len_arr,
                                   beta_arr=beta_arr,
                                   alpha_arr=alpha_arr,
                                   gabor_set_arr=gabor_set_arr)

    train_data_loader = DataLoader(dataset=train_set,
                                   num_workers=4,
                                   batch_size=train_batch_size,
                                   shuffle=True,
                                   pin_memory=True)

    val_set = dataset.Fields1993(data_dir=os.path.join(data_set_dir, 'val'),
                                 bg_tile_size=meta_data["full_tile_size"],
                                 transform=normalize,
                                 subset_size=test_subset_size,
                                 c_len_arr=c_len_arr,
                                 beta_arr=beta_arr,
                                 alpha_arr=alpha_arr,
                                 gabor_set_arr=gabor_set_arr)

    val_data_loader = DataLoader(dataset=val_set,
                                 num_workers=4,
                                 batch_size=test_batch_size,
                                 shuffle=True,
                                 pin_memory=True)

    print("Data loading Took {}. # Train {}, # Test {}".format(
        datetime.now() - data_load_start_time,
        len(train_data_loader) * train_batch_size,
        len(val_data_loader) * test_batch_size))

    # -----------------------------------------------------------------------------------
    # Optimizer
    # -----------------------------------------------------------------------------------
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=learning_rate)

    lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer,
        step_size=train_params['lr_sched_step_size'],
        gamma=train_params['lr_sched_gamma'])

    detect_thres = 0.5

    # -----------------------------------------------------------------------------------
    # Loss Functions
    # -----------------------------------------------------------------------------------
    criterion = torch.nn.BCEWithLogitsLoss()
    # criterion = train_utils.ClassBalancedCrossEntropy()
    # criterion = train_utils.ClassBalancedCrossEntropyAttentionLoss()
    criterion_loss_sigmoid_outputs = False

    # Lateral Weights sparsity constraint
    lateral_sparsity_loss = train_utils.InvertedGaussianL1Loss(
        model.contour_integration_layer.lateral_e.weight.shape[2:],
        model.contour_integration_layer.lateral_i.weight.shape[2:],
        train_params['lateral_w_reg_gaussian_sigma'])
    # lateral_sparsity_loss = train_utils.WeightNormLoss(norm=1) # vanilla L1 Loss

    # # Penalize Negative Lateral Weights
    # negative_lateral_weights_penalty = train_utils.NegativeWeightsNormLoss()
    # negative_lateral_weights_penalty_weight = 0.05

    loss_function = train_utils.CombinedLoss(
        criterion=criterion,
        sigmoid_predictions=criterion_loss_sigmoid_outputs,
        sparsity_loss_fcn=lateral_sparsity_loss,
        sparsity_loss_weight=train_params['lateral_w_reg_weight'],
        # negative_weights_loss_fcn=negative_lateral_weights_penalty,
        # negative_weights_loss_weight=negative_lateral_weights_penalty_weight
    ).to(device)

    # -----------------------------------------------------------------------------------
    # Main Loop
    # -----------------------------------------------------------------------------------
    print("====> Starting Training ")
    training_start_time = datetime.now()

    train_history = []
    val_history = []
    lr_history = []

    best_iou = 0

    # Summary file
    summary_file = os.path.join(results_store_dir, 'summary.txt')
    file_handle = open(summary_file, 'w+')

    file_handle.write("Data Set Parameters {}\n".format('-' * 60))
    file_handle.write("Source           : {}\n".format(data_set_dir))
    file_handle.write("Restrictions     :\n")
    file_handle.write("  Lengths        : {}\n".format(c_len_arr))
    file_handle.write("  Beta           : {}\n".format(beta_arr))
    file_handle.write("  Alpha          : {}\n".format(alpha_arr))
    file_handle.write("  Gabor Sets     : {}\n".format(gabor_set_arr))
    file_handle.write("  Train Set Size : {}\n".format(train_subset_size))
    file_handle.write("  Test Set Size  : {}\n".format(test_subset_size))
    file_handle.write("Train Set Mean {}, std {}\n".format(
        train_set.data_set_mean, train_set.data_set_std))
    file_handle.write("Validation Set Mean {}, std {}\n".format(
        val_set.data_set_mean, train_set.data_set_std))

    file_handle.write("Training Parameters {}\n".format('-' * 60))
    file_handle.write("Random Seed      : {}\n".format(
        train_params['random_seed']))
    file_handle.write("Train images     : {}\n".format(len(train_set.images)))
    file_handle.write("Val images       : {}\n".format(len(val_set.images)))
    file_handle.write("Train batch size : {}\n".format(train_batch_size))
    file_handle.write("Val batch size   : {}\n".format(test_batch_size))
    file_handle.write("Epochs           : {}\n".format(num_epochs))
    file_handle.write("Optimizer        : {}\n".format(
        optimizer.__class__.__name__))
    file_handle.write("learning rate    : {}\n".format(learning_rate))
    for key in train_params.keys():
        if 'lr_sched' in key:
            print("  {}: {}".format(key, train_params[key]), file=file_handle)

    file_handle.write("Loss Fcn         : {}\n".format(
        loss_function.__class__.__name__))
    print(loss_function, file=file_handle)
    file_handle.write("IoU Threshold    : {}\n".format(detect_thres))
    file_handle.write("clip negative lateral weights: {}\n".format(
        clip_negative_lateral_weights))

    file_handle.write("Model Parameters {}\n".format('-' * 63))
    file_handle.write("Model Name       : {}\n".format(
        model.__class__.__name__))
    p1 = [item for item in vars(model) if not item.startswith('_')]  # public model attributes
    for var in sorted(p1):
        file_handle.write("{}: {}\n".format(var, getattr(model, var)))
    file_handle.write("\n")
    print(model, file=file_handle)

    # 'layers' (all top-level modules) was already populated above, when building
    # the results directory name
    if 'contour_integration_layer' in layers:

        # print fixed hyper parameters
        file_handle.write("Contour Integration Layer:\n")
        file_handle.write("Type : {}\n".format(
            model.contour_integration_layer.__class__.__name__))
        cont_int_layer_vars = \
            [item for item in vars(model.contour_integration_layer) if not item.startswith('_')]
        for var in sorted(cont_int_layer_vars):
            file_handle.write("\t{}: {}\n".format(
                var, getattr(model.contour_integration_layer, var)))

        # print parameter names and whether they are trainable
        file_handle.write("Contour Integration Layer Parameters\n")
        layer_params = vars(model.contour_integration_layer)['_parameters']
        for k, v in sorted(layer_params.items()):
            file_handle.write("\t{}: requires_grad {}\n".format(
                k, v.requires_grad))

    file_handle.write("{}\n".format('-' * 80))
    file_handle.write("Training details\n")
    file_handle.write("Epoch, train_loss, train_iou, val_loss, val_iou, lr\n")

    print("train_batch_size={}, test_batch_size={}, lr={}, epochs={}".format(
        train_batch_size, test_batch_size, learning_rate, num_epochs))

    # Track evolution of these variables during training
    # (Must be a parameter of the contour integration layer)
    track_var_dict = {
        'a': [],
        'b': [],
        'j_xy': [],
        'j_yx': [],
        'i_bias': [],
        'e_bias': []
    }

    for epoch in range(0, num_epochs):
        epoch_start_time = datetime.now()

        train_history.append(
            iterate_epoch(
                model=model,
                data_loader=train_data_loader,
                loss_fcn=loss_function,
                optimizer1=optimizer,
                device=device,
                detect_th=detect_thres,
                is_train=True,
                clip_negative_lateral_weights=clip_negative_lateral_weights))

        val_history.append(
            iterate_epoch(model=model,
                          data_loader=val_data_loader,
                          loss_fcn=loss_function,
                          optimizer1=optimizer,
                          device=device,
                          detect_th=detect_thres,
                          is_train=False))

        lr_history.append(train_utils.get_lr(optimizer))
        lr_scheduler.step()

        # Track parameters
        cont_int_layer_params = model.contour_integration_layer.state_dict()
        for param in track_var_dict:
            if param in cont_int_layer_params:
                track_var_dict[param].append(
                    cont_int_layer_params[param].cpu().detach().numpy())

        print(
            "Epoch [{}/{}], Train: loss={:0.4f}, IoU={:0.4f}. Val: loss={:0.4f}, IoU={:0.4f}. "
            "Time {}".format(epoch + 1, num_epochs, train_history[epoch][0],
                             train_history[epoch][1], val_history[epoch][0],
                             val_history[epoch][1],
                             datetime.now() - epoch_start_time))

        if val_history[epoch][1] > best_iou:
            best_iou = val_history[epoch][1]
            torch.save(model.state_dict(),
                       os.path.join(results_store_dir, 'best_accuracy.pth'))

        file_handle.write(
            "[{}, {:0.4f}, {:0.4f}, {:0.4f}, {:0.4f}, {}],\n".format(
                epoch + 1, train_history[epoch][0], train_history[epoch][1],
                val_history[epoch][0], val_history[epoch][1],
                lr_history[epoch]))

    # -----------------------------------------------------------------------------------
    # Store results & plots
    # -----------------------------------------------------------------------------------
    np.set_printoptions(precision=3,
                        linewidth=120,
                        suppress=True,
                        threshold=np.inf)
    file_handle.write("{}\n".format('-' * 80))

    training_time = datetime.now() - training_start_time
    print('Finished Training. Training took {}'.format(training_time))
    file_handle.write("Train Duration       : {}\n".format(training_time))

    train_utils.store_tracked_variables(
        track_var_dict,
        results_store_dir,
        n_ch=model.contour_integration_layer.edge_out_ch)

    train_history = np.array(train_history)
    val_history = np.array(val_history)
    train_utils.plot_training_history(train_history, val_history,
                                      results_store_dir)

    # Reload the model parameters that resulted in the best accuracy
    # --------------------------------------------------------------
    best_val_model_params = os.path.join(results_store_dir,
                                         'best_accuracy.pth')
    model.load_state_dict(
        torch.load(best_val_model_params, map_location=device))

    # Straight contour performance over validation dataset
    # ---------------------------------------------------------------------------------
    print(
        "====> Getting validation set straight contour performance per length")
    # Note: unlike in the experiments, contours are not centrally located
    c_len_arr = [1, 3, 5, 7, 9]
    c_len_iou_arr, c_len_loss_arr = validate_contour_data_set.get_performance_per_len(
        model, data_set_dir, device, beta_arr=[0], c_len_arr=c_len_arr)
    validate_contour_data_set.plot_iou_per_contour_length(
        c_len_arr,
        c_len_iou_arr,
        f_title='Val Dataset: straight contours',
        file_name=os.path.join(results_store_dir, 'iou_vs_len.png'))

    file_handle.write("{}\n".format('-' * 80))
    file_handle.write(
        "Contour Length vs IoU (Straight Contours in val. dataset) : {}\n".
        format(repr(c_len_iou_arr)))

    # Run Li 2006 experiments
    # -----------------------------------------------------------------------------------
    print("====> Running Experiments")
    optim_stim_dict = experiment_gain_vs_len.main(
        model, base_results_dir=results_store_dir, n_images=100)
    experiment_gain_vs_spacing.main(model,
                                    base_results_dir=results_store_dir,
                                    optimal_stim_dict=optim_stim_dict,
                                    n_images=100)

    file_handle.close()

    # View trained kernels
    # ------------------------------------------------------------------------------------
    trained_kernels_store_dir = os.path.join(results_store_dir,
                                             'trained_kernels')
    if not os.path.exists(trained_kernels_store_dir):
        os.makedirs(trained_kernels_store_dir)

    utils.view_ff_kernels(
        model.edge_extract.weight.data.cpu().detach().numpy(),
        results_store_dir=trained_kernels_store_dir)

    utils.view_spatial_lateral_kernels(
        model.contour_integration_layer.lateral_e.weight.data.cpu().detach(
        ).numpy(),
        model.contour_integration_layer.lateral_i.weight.data.cpu().detach(
        ).numpy(),
        results_store_dir=trained_kernels_store_dir,
        spatial_func=np.mean)
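

# ------------------------------------------------------------------------------------
# Usage sketch (assumption, not part of the source): a minimal driver for the main()
# above. ContourIntegrationModel and the data directory are hypothetical placeholders;
# the dictionary keys are the ones main() actually asserts or reads (sanity checks,
# scheduler defaults and the sparsity-loss setup), with illustrative values.
# ------------------------------------------------------------------------------------
# if __name__ == '__main__':
#     net = ContourIntegrationModel()  # hypothetical model with a contour_integration_layer
#     train_parameters = {
#         'train_batch_size': 32,
#         'test_batch_size': 1,
#         'learning_rate': 3e-4,
#         'num_epochs': 50,
#         'lateral_w_reg_weight': 1e-4,
#         'lateral_w_reg_gaussian_sigma': 10,
#         'clip_negative_lateral_weights': False,  # optional; defaults to False
#     }
#     data_parameters = {'data_set_dir': './data/fields_1993'}  # required key
#     main(net, train_parameters, data_parameters)
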
def main_worker(model, gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    results_store_dir = os.path.join(
        './results/imagenet_classification/',
        model.__class__.__name__ + datetime.now().strftime("_%Y%m%d_%H%M%S"))
    if not os.path.exists(results_store_dir):
        os.makedirs(results_store_dir)

    # -----------------------------------------------------------------------------------
    # Write Summary File
    # -----------------------------------------------------------------------------------
    summary_file = os.path.join(results_store_dir, 'summary.txt')
    f = open(summary_file, 'w+')

    # Write Training Setting
    f.write("Input Arguments:\n")
    for a_idx, arg in enumerate(vars(args)):
        f.write("\t[{}] {}: {}\n".format(a_idx, arg, getattr(args, arg)))
    f.write("{}\n".format('-' * 80))

    # Write the model architecture
    f.write("Model Name       : {}\n".format(model.__class__.__name__))
    f.write("\n")
    print(model, file=f)
    f.write("{}\n".format('-' * 80))

    # Hyper Parameters of Contour Integration Layer
    embedded_cont_int_model = get_embedded_cont_int_layer(model)

    if embedded_cont_int_model is not None:
        # print fixed hyper parameters
        f.write("Contour Integration Layer Hyper parameters\n")
        cont_int_layer_vars = \
            [item for item in vars(embedded_cont_int_model.contour_integration_layer) if not item.startswith('_')]
        for var in sorted(cont_int_layer_vars):
            f.write("\t{}: {}\n".format(
                var,
                getattr(embedded_cont_int_model.contour_integration_layer,
                        var)))

        # print parameter names and whether they are trainable
        f.write("Contour Integration Layer Parameters\n")
        layer_params = vars(
            embedded_cont_int_model.contour_integration_layer)['_parameters']
        for k, v in sorted(layer_params.items()):
            f.write("\t{}: requires_grad {}\n".format(k, v.requires_grad))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    f.write("{}\n".format('-' * 80))

    # -----------------------------------------------------------------------------------
    # create model
    # if args.pretrained:
    #     print("=> using pre-trained model '{}'".format(args.arch))
    #     model = models.__dict__[args.arch](pretrained=True)
    # else:
    #     print("=> creating model '{}'".format(args.arch))
    #     model = models.__dict__[args.arch]()

    # model.cuda()
    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # # DataParallel will divide and allocate batch_size to all available GPUs
        # if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        #     model.features = torch.nn.DataParallel(model.features)
        #     model.cuda()
        # else:
        model = torch.nn.DataParallel(model).cuda()

    # -----------------------------------------------------------------------------------
    # Define loss function (criterion) and optimizer
    # -----------------------------------------------------------------------------------
    f.write("Loss Functions and Optimizers\n")

    criterion1 = nn.CrossEntropyLoss().cuda(args.gpu)

    # ***********************************************************************************
    # Additional loss on contour integration lateral connections
    gaussian_kernel_sigma = 10
    reg_loss_weight = 0.0001

    lateral_sparsity_loss = None
    if embedded_cont_int_model is not None:
        lateral_sparsity_loss = train_utils.InvertedGaussianL1Loss(
            embedded_cont_int_model.contour_integration_layer.lateral_e.weight.
            shape[2:], embedded_cont_int_model.contour_integration_layer.
            lateral_i.weight.shape[2:], gaussian_kernel_sigma)

    loss_function = train_utils.CombinedLoss(
        criterion=criterion1,
        sigmoid_predictions=False,
        sparsity_loss_fcn=lateral_sparsity_loss,
        sparsity_loss_weight=reg_loss_weight,
        # negative_weights_loss_fcn=negative_lateral_weights_penalty,
        # negative_weights_loss_weight=negative_lateral_weights_penalty_weight
    ).cuda(args.gpu)

    f.write("Loss Fcn         : {}\n".format(loss_function.__class__.__name__))
    print(loss_function, file=f)
    # ***********************************************************************************

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    f.write("Optimizer        : {}\n".format(optimizer.__class__.__name__))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    # -----------------------------------------------------------------------------------

    # Data loading code
    print(">>> Setting up Data loaders {}".format('.' * 80))
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion1, args)
        return

    f.write("Training\n")
    f.write(
        "Epoch, train_loss, train_accTop1, train_accTop5, val_loss, val_accTop1, val_accTop5\n"
    )

    # Evaluate performance on the validation set before training - useful for
    # models that start from pre-trained weights
    val_loss, val_acc1, val_acc5 = validate(val_loader, model, loss_function,
                                            args)
    f.write(
        "[{}, np.nan, np.nan, np.nan, {:0.4f}, {:0.4f}, {:0.4f}],\n".format(
            0, val_loss, val_acc1, val_acc5))

    # -----------------------------------------------------------------------------------
    # Main Loop
    # -----------------------------------------------------------------------------------
    print(">>> Start Training {} ".format('.' * 80))
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train_loss, train_acc1, train_acc5 = \
            train(train_loader, model, loss_function, optimizer, epoch, args)

        # evaluate on validation set
        val_loss, val_acc1, val_acc5 = validate(val_loader, model,
                                                loss_function, args)

        f.write(
            "[{}, {:0.4f}, {:0.4f}, {:0.4f}, {:0.4f}, {:0.4f}, {:0.4f}],\n".
            format(epoch, train_loss, train_acc1, train_acc5, val_loss,
                   val_acc1, val_acc5))

        # remember best acc@1 and save checkpoint
        is_best = val_acc1 > best_acc1
        best_acc1 = max(val_acc1, best_acc1)

        # if not args.multiprocessing_distributed or \
        #         (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):

        if is_best:
            save_checkpoint(
                state={
                    'epoch': epoch + 1,
                    # 'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                },
                is_best=is_best,
                filename=os.path.join(results_store_dir, 'best_accuracy.pth'))
    f.close()

    # Run Li 2006 experiments
    # -----------------------------------------------------------------------------------
    print("====> Running Experiments")
    optim_stim_dict = experiment_gain_vs_len.main(
        model,
        base_results_dir=results_store_dir,
        n_images=100,
        embedded_layer_identifier=model.conv1,
        iou_results=False)

    experiment_gain_vs_spacing.main(model,
                                    base_results_dir=results_store_dir,
                                    optimal_stim_dict=optim_stim_dict,
                                    n_images=100,
                                    embedded_layer_identifier=model.conv1)
    # -----------------------------------------------------------------------------------
    # Run experiments for multiple fragment sizes
    # -----------------------------------------------------------------------------------
    frag_size_list = [np.array([7, 7])]
    # frag_size_list = [(7, 7), (9, 9), (11, 11), (13, 13)]

    for frag_size in frag_size_list:
        print("Processing Fragment Size {} {}".format(frag_size, '-' * 50))
        frag_size = np.array(frag_size)

        print("Getting contour length results")
        optim_stim_dict = experiment_gain_vs_len.main(
            net,
            results_dir,
            iou_results=get_iou_results,
            embedded_layer_identifier=replacement_layer,
            frag_size=frag_size,
            n_images=100)

        print("Getting fragment spacing results")
        experiment_gain_vs_spacing.main(
            net,
            results_dir,
            embedded_layer_identifier=replacement_layer,
            frag_size=frag_size,
            optimal_stim_dict=optim_stim_dict,
            n_images=100)

    # -----------------------------------------------------------------------------------
    print("Running script took {}".format(datetime.now() - start_time))