Ejemplo n.º 1
0
    transforms.Compose([
        #transforms.Resize(256),
        #transforms.CenterCrop(224),
        transforms.Resize((900, 600)),
        transforms.CenterCrop((850, 550)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Evaluation-time pipeline. It deliberately keeps random rotation, flip and
# colour jitter, so repeated passes over the same image yield different
# views (test-time augmentation).
test_transforms = transforms.Compose(
    [
        transforms.RandomRotation(15),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(),
        transformations.CropField(),  # project-specific crop; see transformations module
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
        # Channel mean/std that differ from the ImageNet statistics
        # ([0.485, 0.456, 0.406] / [0.229, 0.224, 0.225]); presumably
        # computed on this dataset -- confirm before changing.
        transforms.Normalize(
            [0.2974, 0.3233, 0.2370],
            [0.1399, 0.1464, 0.1392],
        ),
    ]
)
# NOTE(review): presumably the number of augmented passes per test image;
# verify against the code that consumes it.
test_data_augmentation = 5
test_transforms_grande = transforms.Compose([
    transforms.Resize((900, 600)),
    transforms.CenterCrop((850, 550)),
    #transforms.ColorJitter(),
    #transforms.RandomRotation(30),
    #p.torch_transform(),
Ejemplo n.º 2
0
def main(args):
    """Set up a W&B run, build data and networks, then dispatch on ``config.stage``.

    Args:
        args: Parsed command-line namespace. Fields read here: ``no_cuda``,
            ``seed``, ``lr``, ``batch_size``, ``gamma``, ``epochs``,
            ``test_batch_size``, ``log_interval``, ``image_size``,
            ``dry_run``, ``num_workers``, ``stage``, ``run_name``,
            ``data_dir`` (may be ``None``; defaults to ``/dev/shm/dataset``)
            and ``result_path``.

    Stage 1 calls ``stage_1`` with the ImageFolder datasets and freshly built
    networks. Stage 2 first loads encoder/classifier weights from
    ``../best`` and then calls ``stage_2`` with the siamese dataset and the
    discriminators.
    """
    # Training settings
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)

    # Hyper-parameters go to wandb.init() as a plain dict. ``wandb.config``
    # is only a placeholder until init() has run (current wandb raises on
    # attribute assignment before init), so it is fetched afterwards. This
    # also means later writes such as ``config.result_path`` land in the
    # run's config.
    run_config = {
        'lr': args.lr,
        'batch_size': args.batch_size,
        'gamma': args.gamma,
        'epochs': args.epochs,
        'test_batch_size': args.test_batch_size,
        'log_interval': args.log_interval,
        'image_size': args.image_size,
        'dry_run': args.dry_run,
        'num_workers': args.num_workers,
        'stage': args.stage,
        'run_name': args.run_name,
        'data_dir': args.data_dir if args.data_dir is not None else '/dev/shm/dataset',
    }
    tags = ['baseline', f'stage {args.stage}']
    wandb.init(project='FADA', config=run_config, tags=tags)
    config = wandb.config

    if config.run_name is not None:
        wandb.run.name = config.run_name

    # Outputs are grouped per run name (user-supplied or wandb-generated).
    result_path = os.path.join(args.result_path, wandb.run.name)
    config.result_path = result_path
    os.makedirs(result_path, exist_ok=True)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Shuffle the training data on every device. If the flag were added only
    # on the CUDA path, CPU runs would train in a fixed order.
    train_kwargs = {'batch_size': config.batch_size, 'shuffle': True}
    test_kwargs = {'batch_size': config.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': config.num_workers,
                       'pin_memory': True}
        train_kwargs.update(cuda_kwargs)
        # NOTE(review): evaluation batches are shuffled on CUDA only, as
        # before; confirm whether the consumers of test_kwargs need this.
        test_kwargs.update(cuda_kwargs, shuffle=True)

    # Normalisation statistics and square crop shared by every split.
    norm_mean = [0.2974, 0.3233, 0.2370]
    norm_std = [0.1399, 0.1464, 0.1392]
    crop = (config.image_size, config.image_size)

    def eval_pipeline():
        # Deterministic (non-random) pipeline used for both evaluation splits.
        return transforms.Compose([
            transformations.CropField(),
            transforms.CenterCrop(crop),
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std),
        ])

    data_transforms = {
        # Training adds random rotation/flip/colour jitter before tiling.
        'train': transforms.Compose([
            transforms.RandomRotation(15),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(),
            transformations.TileHerb(),
            transforms.CenterCrop(crop),
            transforms.ToTensor(),
            transforms.Normalize(norm_mean, norm_std),
        ]),
        'val': eval_pipeline(),
        'val_photo': eval_pipeline(),
    }

    data_dir = config.data_dir
    herbarium = os.path.join(data_dir, 'herbarium')  # NOTE(review): unused in the code shown
    photo = os.path.join(data_dir, 'photo')

    classifier = models.ClassifierPro().to(device)
    encoder = tmodels.resnet50(pretrained=True).to(device)
    # An empty Sequential passes its input through unchanged, so the encoder
    # returns the backbone features instead of ImageNet class logits.
    encoder.fc = nn.Sequential()
    # NOTE(review): 64/510/151 look like species/genus/family output sizes;
    # confirm against models.TaxonNet and the dataset.
    ssnet = models.TaxonNet(64).to(device)
    genusnet = models.TaxonNet(510).to(device)
    familynet = models.TaxonNet(151).to(device)

    discriminator = models.DCDPro().to(device)
    discriminator_genus = models.DCDPro().to(device)
    discriminator_family = models.DCDPro().to(device)

    image_datasets = {
        split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
        for split in ['train', 'val']
    }

    base_mapping = image_datasets['train'].class_to_idx
    class_name_to_id = image_datasets['train'].class_to_idx
    id_to_class_name = {v: k for k, v in class_name_to_id.items()}

    siamese_dataset = data_loader.FADADatasetSS(data_dir,
                                                photo,
                                                'train',
                                                image_datasets['train'].class_to_idx,
                                                class_name_to_id,
                                                config.image_size
                                                )

    if config.stage == 1:
        stage_1(config, device, image_datasets, classifier, encoder, ssnet, genusnet, familynet, train_kwargs, test_kwargs)
    elif config.stage == 2:
        # map_location lets checkpoints saved from CUDA tensors load on a
        # CPU-only run. NOTE(review): '../best' is resolved relative to the CWD.
        encoder.load_state_dict(torch.load('../best/encoder_fada_extra.pth', map_location=device))
        classifier.load_state_dict(torch.load('../best/classifier_fada_extra.pth', map_location=device))
        stage_2(config, device, siamese_dataset, discriminator, encoder, discriminator_genus, discriminator_family, train_kwargs, test_kwargs)