Example #1
import os

from torch.utils import data
from torchvision import transforms
from torchvision.datasets import FashionMNIST

# Assumed stand-in; the original defines DEFAULT_TRANSFORM elsewhere.
DEFAULT_TRANSFORM = transforms.ToTensor()


def load_FashionMNIST(dir,
                      use_validation=True,
                      val_ratio=0.2,
                      train_transforms=DEFAULT_TRANSFORM,
                      test_transforms=DEFAULT_TRANSFORM,
                      pin_memory=True,
                      batch_size=128,
                      num_workers=1):
    path = os.path.join(dir, 'data', 'FashionMNIST')
    train_set = FashionMNIST(path,
                             train=True,
                             download=True,
                             transform=train_transforms)
    test_set = FashionMNIST(path,
                            train=False,
                            download=True,
                            transform=test_transforms)

    if use_validation:
        # Hold out the last val_ratio fraction of the training images.
        val_size = int(val_ratio * len(train_set))
        train_set.data = train_set.data[:-val_size]
        train_set.targets = train_set.targets[:-val_size]
        # Reload the training split with the test-time transforms and keep
        # only the held-out tail as the validation set.
        val_set = FashionMNIST(path,
                               train=True,
                               download=True,
                               transform=test_transforms)
        val_set.data = val_set.data[-val_size:]
        val_set.targets = val_set.targets[-val_size:]

    train_loader = data.DataLoader(train_set,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=num_workers,
                                   pin_memory=pin_memory)

    valid_loader = None
    if use_validation:
        valid_loader = data.DataLoader(val_set,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=num_workers,
                                       pin_memory=pin_memory)

    test_loader = data.DataLoader(test_set,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=num_workers)

    return {'train': train_loader, 'valid': valid_loader, 'test': test_loader}
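
A minimal usage sketch for the function above, assuming the current directory
as the data root (the dict keys match the return statement):

loaders = load_FashionMNIST('.', use_validation=True, batch_size=64)
for images, labels in loaders['train']:
    break  # each iteration yields one (images, labels) batch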
Example #2
    def setup(self, stage: str = None):
        # Transforms
        mnist_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([self.x_mean], [self.x_std])
        ])

        # Train
        train_val = FashionMNIST(
            os.path.dirname(__file__),
            download=True,
            train=True,
            transform=mnist_transforms,
        )

        # Keep only the samples whose label is listed in self.digits:
        # build an (N, len(digits)) boolean matrix and reduce it with any().
        idx = torch.cat(
            [train_val.targets[:, None] == digit for digit in self.digits],
            dim=1).any(dim=1)
        train_val.targets = train_val.targets[idx]
        train_val.data = train_val.data[idx]

        # Split the filtered set into train/validation subsets at the
        # configured ratio.
        train_length = int(len(train_val) * self.train_val_split)
        val_length = len(train_val) - train_length
        self.train_dataset, self.val_dataset = random_split(
            train_val, [train_length, val_length])

        # Test
        self.test_dataset = FashionMNIST(
            os.path.dirname(__file__),
            download=True,
            train=False,
            transform=mnist_transforms,
        )
        # Apply the same class filter to the test split.
        idx = torch.cat(
            [
                self.test_dataset.targets[:, None] == digit
                for digit in self.digits
            ],
            dim=1,
        ).any(dim=1)
        self.test_dataset.targets = self.test_dataset.targets[idx]
        self.test_dataset.data = self.test_dataset.data[idx]
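
The setup hook above only builds the datasets; in a LightningDataModule they
are usually exposed through dataloader hooks. A minimal sketch of those
methods, assuming a self.batch_size attribute and DataLoader imported from
torch.utils.data:

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)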
Example #3

from source.fashion_mnist_cnn import LeNet5
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
import argparse

parser = argparse.ArgumentParser(description="Test LeNet5")
parser.add_argument('--ckpt_path')
parser.add_argument('--data_dir')
parser.add_argument('--gpu', type=int)  # int so Trainer(gpus=...) gets a GPU count, not a string
args = parser.parse_args()

classifier = LeNet5.load_from_checkpoint(args.ckpt_path, None, None, 5, 32, 0)

test_data = FashionMNIST(args.data_dir,
                         train=False,
                         download=False,
                         transform=transforms.Compose([
                             transforms.Resize((32, 32)),
                             transforms.ToTensor()
                         ]))
# Evaluate only on the test images after the first 1000.
test_data.data = test_data.data[1000:]
test_data.targets = test_data.targets[1000:]

test_loader = DataLoader(test_data, batch_size=8)

trainer = Trainer(gpus=args.gpu)
trainer.test(classifier, test_dataloaders=test_loader)
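
A typical invocation of this script, with a hypothetical script name and
paths:

python test_lenet5.py --ckpt_path checkpoints/best.ckpt --data_dir ./data --gpu 1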
Example #4

import os

import optuna
import pytorch_lightning as pl
from torchvision import transforms
from torchvision.datasets import FashionMNIST

# Assumed reconstruction: the head of this snippet is truncated, and a
# training split loaded the same way as data_val below is assumed here.
# MODEL_DIR, MetricsCallback, and args are defined elsewhere in the original
# script.
data_train = FashionMNIST(args.data_dir,
                          train=True,
                          download=False,
                          transform=transforms.Compose([
                              transforms.Resize((32, 32)),
                              transforms.ToTensor()
                          ]))

data_val = FashionMNIST(args.data_dir,
                        train=False,
                        download=False,
                        transform=transforms.Compose([
                            transforms.Resize((32, 32)),
                            transforms.ToTensor()
                        ]))

data_val.data = data_val.data[:1000]
data_val.targets = data_val.targets[:1000]


# Define the Optuna objective that evaluates one hyperparameter trial.
def objective(trial: optuna.Trial):
    # Filenames for each trial must be made unique in order to access each checkpoint.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(os.path.join(
        MODEL_DIR, "trial_{}".format(trial.number), "{epoch}"),
                                                       monitor="val_acc")

    # The default logger in PyTorch Lightning writes event files for
    # TensorBoard. We don't use a logger here, as that would require
    # implementing several abstract methods; instead we set up a simple
    # callback that saves the metrics from each validation step.
    metrics_callback = MetricsCallback()
    # The original snippet is truncated here; the Trainer call is closed
    # minimally so the example parses, attaching the callback defined above.
    trainer = pl.Trainer(logger=True,
                         checkpoint_callback=checkpoint_callback,
                         callbacks=[metrics_callback])
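
The objective above is meant to be passed to an Optuna study; a minimal
sketch of that driver, with an assumed trial budget:

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print("Best val_acc: {:.4f}".format(study.best_value))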
Example #5
    avg_loss /= len(data_test)
    acc = float(correct_num) / len(data_test)
    # Track the best accuracy seen across epochs.
    global best_acc
    if acc > best_acc:
        best_acc = acc
    print("Mean validation loss: {:.4f}, Acc: {:.4f} (max {:.4f})".format(avg_loss, acc, best_acc))
    writer.add_scalar('acc', acc, cur_epoch)
    writer.flush()


def main(use_super_loss):
    for it in range(100):
        train(it, use_super_loss)
        validate(it)


if __name__ == '__main__':
    if len(sys.argv) < 3:
        raise ValueError(
            "usage: {} <use_super_loss: 0|1> <noise_rate>".format(sys.argv[0]))
    sl = int(sys.argv[1]) > 0
    noise_rate = float(sys.argv[2])
    if noise_rate > 0.0:
        print("add noise to labels")
        # Corrupt roughly a noise_rate fraction of the training labels by
        # redrawing each selected label uniformly from the 10 classes (the
        # redraw may coincide with the true label).
        data_train.targets = torch.tensor(data_train.targets)
        for i in range(len(data_train)):
            if torch.rand(()) < noise_rate:
                data_train.targets[i] = torch.randint(0, 10, ())
    print("Use SuperLoss: {}".format(sl))
    print("Noise rate: {}".format(noise_rate))
    main(sl)
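
The per-sample corruption loop above can also be written in vectorized form;
a sketch of an equivalent mask-based version (data_train.targets must already
be a tensor, as ensured above):

mask = torch.rand(len(data_train)) < noise_rate
random_labels = torch.randint(0, 10, (len(data_train),))
data_train.targets[mask] = random_labels[mask]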