def train_net(train, val, model, name):
    """Fine-tune ``model`` on the Kaggle Amazon JPG images with staged learning rates.

    Builds an augmented training loader and a plain validation loader. Then
    builds an Adam optimizer with one parameter group per network stage, and
    fits a ``ModuleTrainer`` for 35 epochs on CUDA device 0 with BCE loss.

    Args:
        train: training samples handed to ``KaggleAmazonJPGDataset``
            (format defined by that dataset class -- presumably a
            DataFrame of image names/tags; confirm against caller).
        val: validation samples, same format as ``train``.
        model: network that must expose ``classifier`` and ``layer1`` ..
            ``layer4`` sub-modules (used to build the parameter groups).
        name: run name, used for the checkpoint file name and for the CSV
            log path ``'./logs/' + name``.
    """
    transformations_train = transforms.apply_chain([
        transforms.random_fliplr(),
        transforms.random_flipud(),
        transforms.augment(),
        torchvision.transforms.ToTensor()
    ])

    # Validation images are only converted to tensors, never augmented.
    transformations_val = transforms.apply_chain([
        torchvision.transforms.ToTensor(),
    ])

    dset_train = KaggleAmazonJPGDataset(train, paths.train_jpg, transformations_train, divide=False)
    train_loader = DataLoader(dset_train,
                              batch_size=64,
                              shuffle=True,
                              num_workers=10,
                              pin_memory=True)

    dset_val = KaggleAmazonJPGDataset(val, paths.train_jpg, transformations_val, divide=False)
    val_loader = DataLoader(dset_val,
                            batch_size=64,
                            num_workers=10,
                            pin_memory=True)

    # Parameters that get their own optimizer group. Everything else in the
    # model (e.g. the stem before layer1) falls into the "base" group.
    # A set gives O(1) membership tests; the previous list made this
    # filter quadratic in the number of parameters.
    grouped_ids = {id(p) for p in chain(
        model.classifier.parameters(),
        model.layer1.parameters(),
        model.layer2.parameters(),
        model.layer3.parameters(),
        model.layer4.parameters()
    )}
    base_params = [p for p in model.parameters() if id(p) not in grouped_ids]

    # Group order is part of the contract with `schedule` below:
    #   0 = base, 1 = layer1, 2 = layer2, 3 = layer3, 4 = layer4, 5 = classifier.
    # lr=0 leaves every group without updates until the schedule assigns it
    # a rate (Adam's step is scaled by lr, so lr=0 means no parameter change).
    optimizer = optim.Adam([
        {'params': base_params},
        {'params': model.layer1.parameters()},
        {'params': model.layer2.parameters()},
        {'params': model.layer3.parameters()},
        {'params': model.layer4.parameters()},
        {'params': model.classifier.parameters()}
    ], lr=0, weight_decay=0.0005)

    trainer = ModuleTrainer(model)

    def schedule(current_epoch, current_lrs, **logs):
        """Return per-group learning rates for ``current_epoch``.

        The classifier rate is stepwise: 1e-3 for epochs 0-1, 1e-4 for
        epochs 2-9 and 1e-5 from epoch 10 on. The last matching milestone
        wins because the loop overwrites.

        During epoch 0 only the classifier trains, and the backbone keeps
        whatever rate it already had (0 initially). From epoch 1 on, the
        backbone groups get shrinking fractions of the classifier rate, and
        earlier layers get smaller fractions.
        """
        lrs = [1e-3, 1e-4, 1e-5]
        epochs = [0, 2, 10]

        for lr, epoch in zip(lrs, epochs):
            if current_epoch >= epoch:
                current_lrs[5] = lr
                if current_epoch >= 1:
                    current_lrs[4] = lr * 0.4
                    current_lrs[3] = lr * 0.2
                    current_lrs[2] = lr * 0.1
                    current_lrs[1] = lr * 0.05
                    current_lrs[0] = lr * 0.01

        return current_lrs

    trainer.set_callbacks([
        # save_best_only=False plus an always-true strategy: presumably a
        # checkpoint is written after every epoch (confirm in `callbacks`).
        callbacks.ModelCheckpoint(
            paths.models,
            name,
            save_best_only=False,
            saving_strategy=lambda epoch: True
        ),
        CSVLogger('./logs/' + name),
        LearningRateScheduler(schedule)
    ])

    trainer.compile(loss=nn.BCELoss(),
                    optimizer=optimizer)

    trainer.fit_loader(train_loader,
                       val_loader,
                       nb_epoch=35,
                       verbose=1,
                       cuda_device=0)
# Example No. 2
def train_net(train, val, unsupervised, model, name):
    """Semi-supervised fine-tuning on labelled train images plus pseudo-labelled test images.

    The labelled training set is combined with an "unsupervised" subset of the
    test images. The test images start from the tag predictions in
    ``unsupervised`` and are later refined by ``SemiSupervisedUpdater``.
    Training runs for 16 epochs on CUDA device 0 with BCE loss and staged
    per-group learning rates.

    Args:
        train: labelled training samples for ``KaggleAmazonJPGDataset``
            (format defined by that class -- confirm against caller).
        val: validation samples, same format as ``train``.
        unsupervised: DataFrame with an ``'image_name'`` column and a
            ``'tags'`` column of whitespace-separated tag strings (split
            with ``.str.split()``). Only its first two thirds are used.
        model: network exposing ``classifier`` and ``layer1`` .. ``layer4``.
        name: run name for the checkpoint file and ``paths.logs + name``.
    """
    # Initial soft labels for the unsupervised images. `mlb` is a module-level
    # label binarizer (presumably sklearn's MultiLabelBinarizer -- confirm).
    unsupervised_initialization = mlb.transform(unsupervised['tags'].str.split()).astype(np.float32)
    # `.values` replaces `as_matrix()`, which was removed in pandas 1.0.
    unsupervised_samples = unsupervised['image_name'].values

    # Keep the first two thirds of the unsupervised images. The previous
    # expression `len(x)//2*3` parses as `(len(x)//2)*3`, i.e. ~1.5x the
    # length, so it silently kept everything (and kept nothing for len 1).
    # One shared cutoff keeps samples and their labels aligned.
    keep = len(unsupervised_samples) * 2 // 3
    unsupervised_initialization = unsupervised_initialization[:keep]
    unsupervised_samples = unsupervised_samples[:keep]

    transformations_train = transforms.apply_chain([
        transforms.random_fliplr(),
        transforms.random_flipud(),
        transforms.augment(),
        torchvision.transforms.ToTensor()
    ])

    transformations_val = transforms.apply_chain([
        torchvision.transforms.ToTensor()
    ])

    # Unsupervised images come from the test folder and receive both the
    # augmenting and the plain transform chains (how each is used is decided
    # inside KaggleAmazonUnsupervisedDataset).
    dset_train_unsupervised = KaggleAmazonUnsupervisedDataset(
        unsupervised_samples,
        paths.test_jpg,
        '.jpg',
        transformations_train,
        transformations_val,
        unsupervised_initialization
    )

    dset_train_supervised = KaggleAmazonJPGDataset(train, paths.train_jpg, transformations_train, divide=False)
    dset_train = KaggleAmazonSemiSupervisedDataset(dset_train_supervised, dset_train_unsupervised, None, indices=False)

    train_loader = DataLoader(dset_train,
                              batch_size=64,
                              shuffle=True,
                              num_workers=10,
                              pin_memory=True)

    dset_val = KaggleAmazonJPGDataset(val, paths.train_jpg, transformations_val, divide=False)
    val_loader = DataLoader(dset_val,
                            batch_size=64,
                            num_workers=10,
                            pin_memory=True)

    # Parameters with a dedicated optimizer group. Everything else goes to
    # the "base" group. A set keeps the membership test O(1) per parameter.
    grouped_ids = {id(p) for p in chain(
        model.classifier.parameters(),
        model.layer1.parameters(),
        model.layer2.parameters(),
        model.layer3.parameters(),
        model.layer4.parameters()
    )}
    base_params = [p for p in model.parameters() if id(p) not in grouped_ids]

    # Group indices used by `schedule`:
    #   0 = base, 1..4 = layer1..layer4, 5 = classifier.
    # lr=0 means no updates for a group until the schedule sets its rate.
    optimizer = optim.Adam([
        {'params': base_params},
        {'params': model.layer1.parameters()},
        {'params': model.layer2.parameters()},
        {'params': model.layer3.parameters()},
        {'params': model.layer4.parameters()},
        {'params': model.classifier.parameters()}
    ], lr=0, weight_decay=0.0001)

    trainer = ModuleTrainer(model)

    def schedule(current_epoch, current_lrs, **logs):
        """Return per-group learning rates for ``current_epoch``.

        The classifier rate steps through 1e-3 (epoch 0), 1e-4 (epochs 1-5),
        5e-5 (epochs 6-7), 1e-5 (epochs 8-11) and 5e-6 (epoch 12+). The last
        matching milestone wins because the loop overwrites.

        For epochs 0-1 only the classifier trains, and the other groups keep
        their previous rate (0 initially). From epoch 2 on, layer1..layer4
        share the classifier rate and the base group gets a tenth of it.
        """
        lrs = [1e-3, 1e-4, 0.5e-4, 1e-5, 0.5e-5]
        epochs = [0, 1, 6, 8, 12]

        for lr, epoch in zip(lrs, epochs):
            if current_epoch >= epoch:
                current_lrs[5] = lr
                if current_epoch >= 2:
                    for group in range(1, 5):
                        current_lrs[group] = lr
                    current_lrs[0] = lr * 0.1

        return current_lrs

    trainer.set_callbacks([
        ModelCheckpoint(
            paths.models,
            name,
            save_best_only=False,
            saving_strategy=lambda epoch: True
        ),
        CSVLogger(paths.logs + name),
        LearningRateScheduler(schedule),
        # Presumably starts refreshing the pseudo-labels of
        # dset_train_unsupervised from epoch 6 with momentum 0.25 -- confirm.
        SemiSupervisedUpdater(trainer, dset_train_unsupervised, start_epoch=6, momentum=0.25)
    ])

    trainer.compile(loss=nn.BCELoss(),
                    optimizer=optimizer)

    trainer.fit_loader(train_loader,
                       val_loader,
                       nb_epoch=16,
                       verbose=1,
                       cuda_device=0)