Пример #1
0
def do_marking_run(overall_marking_percentage,
                   experiment_directory,
                   tensorboard_log_directory,
                   augment=True):
    """Run the marking stage on a sample of CIFAR10 images.

    The function creates ``experiment_directory``, loads the resnet18
    checkpoint used as the marking network, draws one random unit-norm
    carrier per class, marks the images selected by
    ``get_images_for_marking_cifar10`` via ``do_marking``, writes a grid of
    the marked images to Tensorboard and finally drops a
    ``marking.complete`` file into the experiment directory.

    Args:
        overall_marking_percentage: Share of the whole dataset to mark. It
            is multiplied by the number of classes before being handed to
            ``get_images_for_marking_cifar10`` (equal class sizes assumed).
        experiment_directory: Output directory for this run. If it already
            exists the function logs a message and returns without doing
            anything.
        tensorboard_log_directory: Directory passed to the image selection
            helper, to ``do_marking`` and to the ``SummaryWriter``.
        augment: When false, ``augmentation=None`` is passed to
            ``do_marking``. When true, the ``augmentation`` keyword is
            omitted so ``do_marking`` uses its own default.

    Returns:
        None.
    """
    # Setup experiment directory
    if os.path.isdir(experiment_directory):
        error_message = (
            f"Directory {experiment_directory} already exists. "
            "By default we assume you don't want to repeat the marking stage.")
        logger.info(error_message)
        return

    os.makedirs(experiment_directory)

    logfile_path = os.path.join(experiment_directory, 'marking.log')
    setup_logger_tqdm(filepath=logfile_path)

    training_set = torchvision.datasets.CIFAR10(root="experiments/datasets",
                                                download=True)

    # Marking network is the resnet18 we trained on CIFAR10
    marking_network = torchvision.models.resnet18(pretrained=False,
                                                  num_classes=10)
    checkpoint_path = "experiments/table2/step1/checkpoint.pth"
    # Load onto the CPU so a checkpoint saved from a GPU can still be read on
    # a CPU-only host; load_state_dict copies the values into the model.
    marking_network_checkpoint = torch.load(checkpoint_path,
                                            map_location="cpu")
    marking_network.load_state_dict(
        marking_network_checkpoint["model_state_dict"])

    # Carriers: one random direction per class, normalised to unit length
    marking_network_fc_feature_size = 512
    carriers = torch.randn(len(training_set.classes),
                           marking_network_fc_feature_size)
    carriers /= torch.norm(carriers, dim=1, keepdim=True)
    torch.save(carriers, os.path.join(experiment_directory, "carriers.pth"))

    # Load randomly sampled images from random class along with list of
    # original indexes. Assume each class has equal number of images, adjust
    # class_marking_percentage to fit overall marking_percentage
    class_marking_percentage = overall_marking_percentage * len(
        training_set.classes)
    class_id, images, original_indexes = get_images_for_marking_cifar10(
        training_set, tensorboard_log_directory, class_marking_percentage)

    optimizer = lambda x: torch.optim.AdamW(x, lr=0.1)
    output_directory = os.path.join(experiment_directory, "marked_images")

    marking_options = {
        "epochs": 100,
        "batch_size": 32,
        "overwrite": True,
    }
    # Previously `augmentation` was only bound when augment was false, so the
    # default augment=True path raised UnboundLocalError. Only override the
    # keyword when augmentation must be disabled.
    # NOTE(review): assumes do_marking declares a default for `augmentation`
    # - confirm against its signature.
    if not augment:
        marking_options["augmentation"] = None

    marked_images = do_marking(output_directory,
                               marking_network,
                               images,
                               original_indexes,
                               carriers,
                               class_id,
                               NORMALIZE_CIFAR10,
                               optimizer,
                               tensorboard_log_directory,
                               **marking_options)

    # Show marked images in Tensorboard
    tensorboard_summary_writer = SummaryWriter(
        log_dir=tensorboard_log_directory)
    try:
        to_tensor = transforms.ToTensor()
        images_for_tensorboard = [to_tensor(x) for x in marked_images]
        img_grid = torchvision.utils.make_grid(images_for_tensorboard,
                                               nrow=16)
        tensorboard_summary_writer.add_image('marked_images', img_grid)
    finally:
        # Close the writer so pending events are flushed to disk.
        tensorboard_summary_writer.close()

    # Record marking completion
    with open(os.path.join(experiment_directory, "marking.complete"),
              "w") as fh:
        fh.write("1")
def do_marking_run_multiclass(overall_marking_percentage, experiment_directory,
                              tensorboard_log_directory, marking_network,
                              training_set):
    """Run the marking stage for every class of ``training_set``.

    The function creates ``experiment_directory``, draws one random
    unit-norm carrier per class, asks ``get_images_for_marking_multiclass``
    for the images to mark, runs ``do_marking`` once per class that has
    images, writes a grid of all marked images to Tensorboard and finally
    drops a ``marking.complete`` file into the experiment directory.

    Args:
        overall_marking_percentage: Marking percentage forwarded unchanged
            to ``get_images_for_marking_multiclass``.
        experiment_directory: Output directory for this run. If it already
            exists the function logs a message and returns without doing
            anything.
        tensorboard_log_directory: Base Tensorboard directory; each class
            is logged to a ``class_<id>`` sub-directory of it.
        marking_network: Network handed to ``do_marking``.
        training_set: Dataset exposing a ``classes`` attribute; its length
            determines how many carriers are generated.

    Returns:
        None.
    """
    # Setup experiment directory
    if os.path.isdir(experiment_directory):
        error_message = (
            f"Directory {experiment_directory} already exists. "
            "By default we assume you don't want to repeat the marking stage.")
        logger.info(error_message)
        return

    os.makedirs(experiment_directory)

    logfile_path = os.path.join(experiment_directory, 'marking.log')
    setup_logger_tqdm(filepath=logfile_path)

    # Carriers: one random direction per class, normalised to unit length
    marking_network_fc_feature_size = 512
    carriers = torch.randn(len(training_set.classes),
                           marking_network_fc_feature_size)
    carriers /= torch.norm(carriers, dim=1, keepdim=True)
    torch.save(carriers, os.path.join(experiment_directory, "carriers.pth"))

    # { 0 : [(image1, original_index1),(image2, original_index2)...], 1 : [....] }
    image_data = get_images_for_marking_multiclass(training_set,
                                                   tensorboard_log_directory,
                                                   overall_marking_percentage)

    # Settings shared by every class; built once instead of per iteration.
    optimizer = lambda x: torch.optim.AdamW(x)
    epochs = 250
    batch_size = 8
    output_directory = os.path.join(experiment_directory, "marked_images")

    marked_images = []
    for class_id, image_list in image_data.items():
        if not image_list:
            # Nothing was selected for this class.
            continue

        images, original_indexes = map(list, zip(*image_list))
        # A fresh augmentation object per class, as in the original loop.
        augmentation = differentiable_augmentations.CenterCrop(256, 224)
        tensorboard_class_log = os.path.join(tensorboard_log_directory,
                                             f"class_{class_id}")
        marked_images_temp = do_marking(output_directory,
                                        marking_network,
                                        images,
                                        original_indexes,
                                        carriers,
                                        class_id,
                                        NORMALIZE_IMAGENETTE,
                                        optimizer,
                                        tensorboard_class_log,
                                        epochs=epochs,
                                        batch_size=batch_size,
                                        overwrite=False,
                                        augmentation=augmentation)

        # Grow the list in place rather than re-copying it for every class.
        marked_images.extend(marked_images_temp)

    # Show marked images in Tensorboard - centercrop for grid
    from PIL import Image as im
    transform = transforms.Compose(
        [transforms.CenterCrop(256),
         transforms.ToTensor()])
    tensorboard_summary_writer = SummaryWriter(
        log_dir=tensorboard_log_directory)
    try:
        images_for_tensorboard = [
            transform(im.fromarray(x)) for x in marked_images
        ]
        img_grid = torchvision.utils.make_grid(images_for_tensorboard, nrow=3)
        tensorboard_summary_writer.add_image('marked_images', img_grid)
    finally:
        # Close the writer so pending events are flushed to disk.
        tensorboard_summary_writer.close()

    # Record marking completion
    with open(os.path.join(experiment_directory, "marking.complete"),
              "w") as fh:
        fh.write("1")
Пример #3
0
def do_marking_run_multiclass(overall_marking_percentage,
                              experiment_directory,
                              tensorboard_log_directory,
                              augment=True):
    """Run the marking stage for every class of CIFAR100.

    The function creates ``experiment_directory``, loads the depth-164
    bottleneck resnet checkpoint used as the marking network, draws one
    random unit-norm carrier per class, asks
    ``get_images_for_marking_multiclass_cifar10`` for the images to mark,
    runs ``do_marking`` once per class that has images, writes a grid of all
    marked images to Tensorboard and finally drops a ``marking.complete``
    file into the experiment directory.

    Args:
        overall_marking_percentage: Marking percentage forwarded unchanged
            to ``get_images_for_marking_multiclass_cifar10``.
        experiment_directory: Output directory for this run. If it already
            exists the function logs a message and returns without doing
            anything.
        tensorboard_log_directory: Base Tensorboard directory; each class
            is logged to a ``class_<id>`` sub-directory of it.
        augment: When false, ``augmentation=None`` is passed to
            ``do_marking``. When true, the ``augmentation`` keyword is
            omitted so ``do_marking`` uses its own default.

    Returns:
        None.
    """
    # Setup experiment directory
    if os.path.isdir(experiment_directory):
        error_message = (
            f"Directory {experiment_directory} already exists. "
            "By default we assume you don't want to repeat the marking stage.")
        logger.info(error_message)
        return

    os.makedirs(experiment_directory)

    logfile_path = os.path.join(experiment_directory, 'marking.log')
    setup_logger_tqdm(filepath=logfile_path)

    training_set = torchvision.datasets.CIFAR100(root="experiments/datasets",
                                                 download=True)

    # Marking network is a depth-164 bottleneck resnet with 100 outputs,
    # restored from the CIFAR100 step1 checkpoint.
    marking_network = resnet(num_classes=100,
                             depth=164,
                             block_name='bottleneck')
    checkpoint_path = "experiments/cifar100/table1/step1/checkpoint.pth"
    # Load onto the CPU so a checkpoint saved from a GPU can still be read on
    # a CPU-only host; load_state_dict copies the values into the model.
    marking_network_checkpoint = torch.load(checkpoint_path,
                                            map_location="cpu")
    # Drop every "module." fragment from the parameter names before loading
    # (presumably left over from a DataParallel wrapper - verify).
    marking_network.load_state_dict({
        k.replace("module.", ""): v
        for k, v in marking_network_checkpoint["model_state_dict"].items()
    })

    # Carriers: one random direction per class, normalised to unit length
    marking_network_fc_feature_size = 256
    carriers = torch.randn(len(training_set.classes),
                           marking_network_fc_feature_size)
    carriers /= torch.norm(carriers, dim=1, keepdim=True)
    torch.save(carriers, os.path.join(experiment_directory, "carriers.pth"))

    # { 0 : [(image1, original_index1),(image2, original_index2)...], 1 : [....] }
    image_data = get_images_for_marking_multiclass_cifar10(
        training_set, tensorboard_log_directory, overall_marking_percentage)

    # Settings shared by every class; built once instead of per iteration.
    optimizer = lambda x: torch.optim.AdamW(x, lr=0.1)
    output_directory = os.path.join(experiment_directory, "marked_images")
    marking_options = {
        "epochs": 100,
        "batch_size": 32,
        "overwrite": False,
    }
    # Previously `augmentation` was only bound when augment was false, so the
    # default augment=True path raised UnboundLocalError. Only override the
    # keyword when augmentation must be disabled.
    # NOTE(review): assumes do_marking declares a default for `augmentation`
    # - confirm against its signature.
    if not augment:
        marking_options["augmentation"] = None

    marked_images = []
    for class_id, image_list in image_data.items():
        if not image_list:
            # Nothing was selected for this class.
            continue

        images, original_indexes = map(list, zip(*image_list))
        tensorboard_class_log = os.path.join(tensorboard_log_directory,
                                             f"class_{class_id}")
        # NOTE(review): NORMALIZE_CIFAR10 is used although the data is
        # CIFAR100 - confirm this is intended.
        marked_images_temp = do_marking(output_directory,
                                        marking_network,
                                        images,
                                        original_indexes,
                                        carriers,
                                        class_id,
                                        NORMALIZE_CIFAR10,
                                        optimizer,
                                        tensorboard_class_log,
                                        **marking_options)

        # Grow the list in place rather than re-copying it for every class.
        marked_images.extend(marked_images_temp)

    # Show marked images in Tensorboard
    tensorboard_summary_writer = SummaryWriter(
        log_dir=tensorboard_log_directory)
    try:
        to_tensor = transforms.ToTensor()
        images_for_tensorboard = [to_tensor(x) for x in marked_images]
        img_grid = torchvision.utils.make_grid(images_for_tensorboard,
                                               nrow=16)
        tensorboard_summary_writer.add_image('marked_images', img_grid)
    finally:
        # Close the writer so pending events are flushed to disk.
        tensorboard_summary_writer.close()

    # Record marking completion
    with open(os.path.join(experiment_directory, "marking.complete"),
              "w") as fh:
        fh.write("1")