    parser = argparse.ArgumentParser(
        description='Script to evaluate single checkpoint on TFD.')
    parser.add_argument("-s", "--split", default='0',
                        help='Training split of TFD to use. (0-4)')
    parser.add_argument("checkpoint_file",
                        help='Path to single model checkpoint (.pkl) file.')
    args = parser.parse_args()

    checkpoint_file = args.checkpoint_file
    fold = int(args.split)
    dataset_path = os.path.join(data_paths.tfd_data_path,
                                'npy_files', 'TFD_96', 'split_' + str(fold))

    print('Checkpoint: %s' % checkpoint_file)
    print('Testing on split %d\n' % fold)

    # Load model
    model = SupervisedModel('evaluation', './')

    # Load dataset
    supervised_data_loader = SupervisedDataLoader(dataset_path)
    test_data_container = supervised_data_loader.load(2)
    test_data_container.X = numpy.float32(test_data_container.X)
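    # Scale pixel intensities from [0, 255] to [0, 2]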
    test_data_container.X /= 255.0
    test_data_container.X *= 2.0

    # Construct evaluator
    preprocessor = [util.Normer3(filter_size=5, num_channels=1)]

    evaluator = util.Evaluator(model, test_data_container,
                               checkpoint_file, preprocessor)

    # For the given checkpoint, compute the overall test accuracy
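    # The evaluation call itself is not part of this excerpt. A minimal
    # sketch, assuming util.Evaluator exposes a hypothetical run() method
    # that returns the overall accuracy:
    #
    #     accuracy = evaluator.run()
    #     print('Test accuracy: %.4f' % accuracy)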
Example #2

if test_split < 0 or test_split > 9:
    raise ValueError("Testing split must be in the range 0-9.")
print('Using CK+ testing split: {}'.format(test_split))

checkpoint_dir = os.path.join(args.checkpoint_dir,
                              'checkpoints_48_' + str(test_split))
print('Checkpoint dir: {}'.format(checkpoint_dir))

# Record this process's PID so the run can be monitored or killed externally
pid = os.getpid()
print('PID: {}'.format(pid))
with open('pid_' + str(test_split), 'w') as f:
    f.write(str(pid) + '\n')

# Load model
model = SupervisedModel('experiment', './', learning_rate=1e-2)
#util.load_checkpoint(model, "./checkpoints_5/experiment-07m-20d-16h-24m-52s.pkl")
monitor = util.Monitor(model,
                       checkpoint_directory=checkpoint_dir,
                       save_steps=1000)

# Add dropout to the fully-connected layer and recompile so it takes effect
model.fc4.dropout = 0.5
model._compile()

# Load CK+ dataset
print('Loading Data')
#supervised_data_loader = SupervisedDataLoaderCrossVal(
#    data_paths.ck_plus_data_path)
#train_data_container = supervised_data_loader.load('train', train_split)
#test_data_container = supervised_data_loader.load('test', train_split)
Example #3

import logging

import numpy as np
import torch
import torchvision
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

# check_hydra_conf, init_ddp, create_simclr_data_augmentation, SupervisedModel,
# and learning are project-local helpers imported elsewhere in the original file.

def main(cfg: OmegaConf):
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
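    # Suppress the handler's default trailing newline so messages control
    # their own line endings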
    stream_handler.terminator = ""
    logger.addHandler(stream_handler)

    check_hydra_conf(cfg)
    init_ddp(cfg)

    # fix seed
    seed = cfg["parameter"]["seed"]
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels trade speed for reproducibility; benchmark
    # mode is disabled because its autotuning can pick nondeterministic algorithms
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    rank = cfg["distributed"]["local_rank"]
    logger.info("Using {}".format(rank))

    root = "~/pytorch_datasets"
    if cfg["experiment"]["name"].lower() == "cifar10":
        transform = create_simclr_data_augmentation(
            cfg["experiment"]["strength"], size=32)
        training_dataset = torchvision.datasets.CIFAR10(root=root,
                                                        train=True,
                                                        download=True,
                                                        transform=transform)
        validation_dataset = torchvision.datasets.CIFAR10(
            root=root,
            train=False,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
            ]))
        num_classes = 10
    elif cfg["experiment"]["name"].lower() == "cifar100":
        transform = create_simclr_data_augmentation(
            cfg["experiment"]["strength"], size=32)
        training_dataset = torchvision.datasets.CIFAR100(root=root,
                                                         train=True,
                                                         download=True,
                                                         transform=transform)
        validation_dataset = torchvision.datasets.CIFAR100(
            root=root,
            train=False,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
            ]))
        num_classes = 100
    else:
        raise ValueError("experiment.name must be 'cifar10' or 'cifar100', "
                         "got {!r}".format(cfg["experiment"]["name"]))

    sampler = torch.utils.data.distributed.DistributedSampler(training_dataset,
                                                              shuffle=True)
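    # With DistributedSampler, set_epoch() should be called at the start of
    # each epoch so shuffling differs across epochs (presumably handled
    # inside learning())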
    training_data_loader = DataLoader(
        dataset=training_dataset,
        sampler=sampler,
        num_workers=cfg["parameter"]["num_workers"],
        batch_size=cfg["experiment"]["batches"],
        pin_memory=True,
        drop_last=True,
    )

    validation_sampler = torch.utils.data.distributed.DistributedSampler(
        validation_dataset, shuffle=False)
    validation_data_loader = DataLoader(
        dataset=validation_dataset,
        sampler=validation_sampler,
        num_workers=cfg["parameter"]["num_workers"],
        batch_size=cfg["experiment"]["batches"],
        pin_memory=True,
        drop_last=False,
    )

    model = SupervisedModel(base_cnn=cfg["experiment"]["base_cnn"],
                            num_classes=num_classes)
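    # Convert BatchNorm to SyncBatchNorm before wrapping in DDP so batch
    # statistics are synchronized across ranks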
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    learning(cfg, training_data_loader, validation_data_loader, model)
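
# Not shown in this excerpt: the Hydra entry point that builds cfg. A minimal
# sketch, assuming a hypothetical config directory "conf" with a root config
# "config.yaml":
#
#     import hydra
#
#     @hydra.main(config_path="conf", config_name="config")
#     def main(cfg):
#         ...
#
#     if __name__ == "__main__":
#         main()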