Esempio n. 1
0
def main():
    """Run a short (2-epoch) training job on the Hymenoptera train/val splits.

    Builds one ``HymenopteraDataset`` per phase with a 224-pixel
    ``ImageTransform`` normalised by the given per-channel mean/std, wraps
    each in a batch-size-32 ``DataLoader`` (only the training split is
    shuffled), and hands the resulting ``{"train": ..., "val": ...}`` mapping
    to ``train_model``.

    NOTE(review): ``net``, ``criterion`` and ``optimizer`` are not created in
    this function, so they must resolve as module-level globals when
    ``main`` runs — confirm they are defined elsewhere in the file.
    """
    size, mean, std = 224, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    train_list, val_list = (make_datapath_list(phase="train"),
                            make_datapath_list(phase="val"))

    datasets = {
        "train": HymenopteraDataset(file_list=train_list,
                                    transform=ImageTransform(size, mean, std),
                                    phase='train'),
        "val": HymenopteraDataset(file_list=val_list,
                                  transform=ImageTransform(size, mean, std),
                                  phase='val'),
    }

    # Shuffle the training split only; validation order stays deterministic.
    dataloaders_dict = {}
    for phase, dataset in datasets.items():
        dataloaders_dict[phase] = torch.utils.data.DataLoader(
            dataset, batch_size=32, shuffle=(phase == "train"))

    train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=2)
Esempio n. 2
0
def main():
    """Build the Efficient GAN networks and train them.

    Creates a generator, discriminator and encoder (all with ``z_dim=20``),
    applies ``weights_init`` to each, wraps 200 training images in a shuffled
    ``DataLoader`` of batch size 64, and runs ``train_model`` for 1500 epochs
    under the save name ``'Efficient_GAN'``.
    """
    z_dim = 20
    G = Generator(z_dim=z_dim)
    D = Discriminator(z_dim=z_dim)
    E = Encoder(z_dim=z_dim)
    for network in (G, D, E):
        network.apply(weights_init)

    # Single-value mean/std: one normalisation channel.
    dataset = GAN_Img_Dataset(
        file_list=make_datapath_list(num=200),
        transform=ImageTransform((0.5,), (0.5,)),
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    # train_model's result is unpacked into three values (currently unused).
    trained_G, trained_D, trained_E = train_model(
        G, D, E,
        dataloader=loader,
        num_epochs=1500,
        save_model_name='Efficient_GAN',
    )
Esempio n. 3
0
    # Wrap G and D in torch.nn.DataParallel and report the switch.
    # The guarding `if` header is not visible in this fragment
    # (presumably a CUDA-availability check — confirm).
    G = torch.nn.DataParallel(G)
    D = torch.nn.DataParallel(D)
    print("parallel mode")

batch_size = 8
z_dim = 20
# Fixed latent batch reshaped to (batch_size, z_dim, 1, 1) before feeding G.
fixed_z = torch.randn(batch_size, z_dim).view(batch_size, z_dim, 1, 1)

fake_images = G(fixed_z.to(device))

train_img_list = make_datapath_list(num=1000)
mean, std = (0.5, ), (0.5, )
train_dataset = GAN_Img_Dataset(file_list=train_img_list,
                                transform=ImageTransform(mean, std))
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)

# Take the first (shuffled) mini-batch of real images.
batch_iterator = iter(train_dataloader)
images = next(batch_iterator)

# Top row of a 2x5 grid: the first five real training images, showing
# index 0 of each sample's second axis (presumably the single grayscale
# channel — confirm) as a gray image.
fig = plt.figure(figsize=(15, 6))
for i in range(5):
    plt.subplot(2, 5, i + 1)
    plt.imshow(images[i][0].cpu().detach().numpy(), 'gray')
Esempio n. 4
0
def main():
    """Fine-tune a pretrained VGG16 for 2-class Hymenoptera classification.

    Builds train/val ``HymenopteraDataset`` instances and batch-size-32
    ``DataLoader``s (only the training loader shuffles), replaces
    ``classifier[6]`` with a 4096 -> 2 linear layer, and trains for 2 epochs
    with SGD (momentum 0.9) using three parameter groups:

    * any parameter whose name contains ``"features"``: lr 1e-4
    * ``classifier.0`` / ``classifier.3`` weights and biases: lr 5e-4
    * ``classifier.6`` weight and bias (the replaced head): lr 1e-3

    Every other parameter is frozen (``requires_grad = False``).
    """
    size = 224
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    batch_size = 32
    phases = ("train", "val")

    file_lists = {phase: make_datapath_list(phase=phase) for phase in phases}
    datasets = {
        phase: HymenopteraDataset(file_list=file_lists[phase],
                                  transform=ImageTransform(size, mean, std),
                                  phase=phase)
        for phase in phases
    }
    # Only the training loader shuffles; validation order stays fixed.
    dataloaders_dict = {
        phase: torch.utils.data.DataLoader(datasets[phase],
                                           batch_size=batch_size,
                                           shuffle=(phase == "train"))
        for phase in phases
    }

    net = models.vgg16(pretrained=True)
    # Replace the final classifier layer with a 2-way output.
    net.classifier[6] = nn.Linear(in_features=4096, out_features=2)
    net.train()

    criterion = nn.CrossEntropyLoss()

    # Exact parameter names for the two classifier learning-rate groups;
    # the "features" group is matched by substring instead.
    middle_names = {"classifier.0.weight", "classifier.0.bias",
                    "classifier.3.weight", "classifier.3.bias"}
    head_names = {"classifier.6.weight", "classifier.6.bias"}

    feature_params, middle_params, head_params = [], [], []
    for param_name, param in net.named_parameters():
        if "features" in param_name:
            target = feature_params
        elif param_name in middle_names:
            target = middle_params
        elif param_name in head_names:
            target = head_params
        else:
            # Not in any group: freeze it.
            param.requires_grad = False
            continue
        param.requires_grad = True
        target.append(param)

    optimizer = optim.SGD(
        [
            {'params': feature_params, 'lr': 1e-4},
            {'params': middle_params, 'lr': 5e-4},
            {'params': head_params, 'lr': 1e-3},
        ],
        momentum=0.9,
    )

    train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=2)
# Move both networks to the target device. Assuming G and D are nn.Module
# instances, Module.to() converts their parameters in place, so the return
# value is not rebound here.
G.to(device)
D.to(device)

# Use GPUs in parallel: wrap both networks in torch.nn.DataParallel.
# NOTE(review): `device` is defined outside this view. If it is a
# torch.device (e.g. torch.device("cuda:0")) rather than the string 'cuda',
# this comparison may never be true — confirm the intended check.
if device == 'cuda':
    G = torch.nn.DataParallel(G)
    D = torch.nn.DataParallel(D)
    print("parallel mode")


batch_size = 5

train_img_list = make_datapath_list(num=1000)
mean, std = (0.5,), (0.5,)
train_dataset = GAN_Img_Dataset(
    file_list=train_img_list,
    transform=ImageTransform(mean, std),
)
# No shuffling, so the first batch is always the first files of the list.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=False
)

# Fetch the first mini-batch.
batch_iterator = iter(train_dataloader)
images = next(batch_iterator)

# First five real images, moved to the compute device.
x = images[0:5].to(device)

# Five random latent vectors reshaped to (5, 20, 1, 1).
z = torch.randn(5, 20).to(device).view(5, 20, 1, 1)