Пример #1
0

# Run on the first GPU when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
print(device)

print(os.getcwd())

# Networks are constructed in the same fixed order as before: each constructor
# may draw from the global RNG, so reordering would change initial weights.
classifier = models.ClassifierPro()
# NOTE(review): classifier2 is not moved to `device` in this setup; confirm
# whether (and where) it is used.
classifier2 = models.ClassifierPro()
# ResNet-50 backbone; replacing `fc` with an empty Sequential (an identity)
# makes the encoder return pooled features instead of ImageNet logits.
# Alternatives previously tried: models.Encoder(), and tmodels.inception_v3
# with aux_logits=False.
encoder = tmodels.resnet50(pretrained=True)
encoder.fc = nn.Sequential()
# Alternative previously tried: models.DCDPro(input_features=128).
discriminator = models.DCDPro()
ssnet = models.TaxonNet(64)
discriminator_genus = models.DCDPro()
genusnet = models.TaxonNet(510)
discriminator_family = models.DCDPro()
familynet = models.TaxonNet(151)

# Move every training network onto the selected device. The family-level
# discriminator/net were previously left on the CPU, unlike the genus-level
# pair, which would cause a device mismatch when fed CUDA tensors.
classifier.to(device)
encoder.to(device)
discriminator.to(device)
ssnet.to(device)
discriminator_genus.to(device)
genusnet.to(device)
discriminator_family.to(device)
familynet.to(device)
Пример #2
0
# Previous local dataset location, kept for reference:
#datadir = '/home/villacis/Desktop/villacis/datasets/plantclef_minida_cropped'

# Epoch budgets, presumably one per training stage (1, 2, 3) -- TODO confirm
# against the training loops that consume them.
num_epochs1 = 75
num_epochs2 = 200
num_epochs3 = 250

# Select the first GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')

# Networks are built in the original order, since each constructor may consume
# the global RNG. Alternatives previously tried: models.Encoder(),
# tmodels.inception_v3 (with aux_logits=False), and
# models.DCDPro(input_features=128).
classifier = models.ClassifierPro()
# ResNet-18 backbone; an empty Sequential in place of `fc` acts as an identity,
# so the encoder returns pooled features rather than ImageNet logits.
encoder = tmodels.resnet18(pretrained=True)
encoder.fc = nn.Sequential()
discriminator = models.DCDPro()

classifier.to(device)
encoder.to(device)
discriminator.to(device)

# Loss functions. The author marked the last two "quitar" (Spanish for
# "remove"), i.e. flagged them for removal.
loss_fn = torch.nn.CrossEntropyLoss()
loss_fn2 = ContrastiveLoss()  # flagged for removal
loss_fn3 = SpecLoss()  # flagged for removal

# -----------------------------------------------------------------------------
# Stage 1: train g and h -- the encoder and the classifier, whose parameters
# are exactly the ones handed to the optimizer below.
print("||||| Stage 1 |||||")
optimizer = torch.optim.Adam(
    [*encoder.parameters(), *classifier.parameters()],
    lr=1e-4,
)
Пример #3
0
def _build_data_transforms(image_size):
    """Build the preprocessing pipeline for each dataset split.

    Args:
        image_size: side length, in pixels, of the square center crop applied
            to every split.

    Returns:
        dict mapping 'train', 'val' and 'val_photo' to a ``transforms.Compose``.
        Every pipeline ends with the same center crop, ``ToTensor`` and
        ``Normalize`` steps.
    """
    def shared_tail():
        # Fresh transform objects (and fresh mean/std lists) for each pipeline,
        # exactly as when the three pipelines were written out by hand.
        # The mean/std values are hard-coded; presumably computed on this
        # dataset -- TODO confirm.
        return [
            transforms.CenterCrop((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.2974, 0.3233, 0.2370],
                                 [0.1399, 0.1464, 0.1392]),
        ]

    return {
        'train': transforms.Compose([
            transforms.RandomRotation(15),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(),
            transformations.TileHerb(),
            *shared_tail(),
        ]),
        'val': transforms.Compose([transformations.CropField(), *shared_tail()]),
        'val_photo': transforms.Compose([transformations.CropField(),
                                         *shared_tail()]),
    }


def main(args):
    """Configure a FADA training run and dispatch to the requested stage.

    Logs the hyper-parameters from ``args`` to Weights & Biases, creates the
    result directory, builds the datasets and networks, then calls
    ``stage_1`` or ``stage_2`` depending on ``args.stage``.

    Args:
        args: parsed command-line namespace. Attributes read here: no_cuda,
            seed, lr, batch_size, gamma, epochs, test_batch_size,
            log_interval, image_size, dry_run, num_workers, stage, run_name,
            data_dir, result_path.
    """
    # Training settings
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    # Hand the hyper-parameters to wandb.init() as a plain dict and only read
    # wandb.config afterwards: before init(), wandb.config is a placeholder on
    # which attribute assignment raises in the current wandb SDK.
    run_config = {
        'lr': args.lr,
        'batch_size': args.batch_size,
        'gamma': args.gamma,
        'epochs': args.epochs,
        'test_batch_size': args.test_batch_size,
        'log_interval': args.log_interval,
        'image_size': args.image_size,
        'dry_run': args.dry_run,
        'num_workers': args.num_workers,
        'stage': args.stage,
        'run_name': args.run_name,
        'data_dir': args.data_dir if args.data_dir is not None else '/dev/shm/dataset',
    }

    tags = ['baseline', f"stage {run_config['stage']}"]
    wandb.init(project='FADA', config=run_config, tags=tags)
    config = wandb.config

    if config.run_name is not None:
        wandb.run.name = config.run_name

    result_path = os.path.join(args.result_path, wandb.run.name)
    config.result_path = result_path
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(result_path, exist_ok=True)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': config.batch_size}
    test_kwargs = {'batch_size': config.test_batch_size}
    if use_cuda:
        # NOTE(review): shuffling is only enabled on CUDA, and it is applied to
        # the test kwargs too -- confirm this is intended.
        cuda_kwargs = {'num_workers': config.num_workers,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    data_transforms = _build_data_transforms(config.image_size)

    data_dir = config.data_dir
    herbarium = os.path.join(data_dir, 'herbarium')  # NOTE(review): unused in the visible code
    photo = os.path.join(data_dir, 'photo')

    # Construction order is unchanged: each constructor may draw from the RNG
    # seeded above, so reordering would change the initial weights.
    classifier = models.ClassifierPro().to(device)
    encoder = tmodels.resnet50(pretrained=True).to(device)
    # An empty Sequential is an identity, so the encoder outputs pooled
    # features instead of ImageNet logits.
    encoder.fc = nn.Sequential()
    ssnet = models.TaxonNet(64).to(device)
    genusnet = models.TaxonNet(510).to(device)
    familynet = models.TaxonNet(151).to(device)

    discriminator = models.DCDPro().to(device)
    discriminator_genus = models.DCDPro().to(device)
    discriminator_family = models.DCDPro().to(device)

    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x])
                      for x in ['train', 'val']}

    # NOTE(review): base_mapping and id_to_class_name are unused in the
    # visible code.
    base_mapping = image_datasets['train'].class_to_idx
    class_name_to_id = image_datasets['train'].class_to_idx
    id_to_class_name = {v: k for k, v in class_name_to_id.items()}

    siamese_dataset = data_loader.FADADatasetSS(data_dir,
                                                photo,
                                                'train',
                                                image_datasets['train'].class_to_idx,
                                                class_name_to_id,
                                                config.image_size
                                                )

    if config.stage == 1:
        stage_1(config, device, image_datasets, classifier, encoder, ssnet, genusnet, familynet, train_kwargs, test_kwargs)
    elif config.stage == 2:
        # map_location lets checkpoints saved from a GPU run load when this
        # run is on the CPU (e.g. args.no_cuda set or no GPU available).
        encoder.load_state_dict(torch.load('../best/encoder_fada_extra.pth', map_location=device))
        classifier.load_state_dict(torch.load('../best/classifier_fada_extra.pth', map_location=device))
        stage_2(config, device, siamese_dataset, discriminator, encoder, discriminator_genus, discriminator_family, train_kwargs, test_kwargs)