def _get_fmnist_dataset():
    train_set = FashionMNIST(expanduser("~") + "/.avalanche/data/fashionmnist/",
                             train=True, download=True)
    test_set = FashionMNIST(expanduser("~") + "/.avalanche/data/fashionmnist/",
                            train=False, download=True)
    return train_set, test_set
Example #2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import FashionMNIST, MNIST
from tqdm import tqdm

mnist = FashionMNIST('./data/',
                     download=True,
                     transform=torchvision.transforms.Compose([
                         torchvision.transforms.ToTensor(),
                     ]))

dataloader = DataLoader(mnist, batch_size=128, shuffle=True)


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.conv1 = nn.Conv2d(1, 4, 5, padding=1)
        self.conv2 = nn.Conv2d(4, 16, 3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(16, 32, 3)  # 32 x 10 x 10
        self.pool2 = nn.MaxPool2d(2)

        self.lin1 = nn.Linear(800, 400)
        self.lin2 = nn.Linear(400, 200)
        self.lin3_1 = nn.Linear(200, 32)
        self.lin3_2 = nn.Linear(200, 32)
Example #3
def main(
    dataset: str = "fashionmnist",
    initial_batch_size: int = 64,
    epochs: int = 6,
    verbose: Union[int, bool] = False,
    lr: float = 1.0,
    cuda: bool = False,
    random_state: Optional[int] = None,  # seed to pass to BaseDamper
    init_seed: Optional[int] = None,  # seed for initialization
    tuning: bool = True,  # whether to tune; the int value also seeds the validation split
    damper: str = "geodamp",
    batch_growth_rate: float = 0.01,
    dampingfactor: Number = 5.0,
    dampingdelay: int = 5,
    max_batch_size: Optional[int] = None,
    test_freq: float = 1,
    approx_loss: bool = False,
    rho: float = 0.9,
    dwell: int = 1,
    approx_rate: bool = False,
    model: Optional[str] = None,
    momentum: Optional[Union[float, int]] = 0,
    nesterov: bool = False,
    weight_decay: float = 0,
) -> Tuple[List[Dict], List[Dict]]:
    # Validate (tuning, random_state, init_seed)
    assert int(tuning) or isinstance(tuning, bool)
    assert isinstance(random_state, int)
    assert isinstance(init_seed, int)

    if "NUM_THREADS" in os.environ:
        v = os.environ["NUM_THREADS"]
        if v:
            print(f"NUM_THREADS={v} (int(v)={int(v)})")
            torch.set_num_threads(int(v))

    args: Dict[str, Any] = {
        "initial_batch_size": initial_batch_size,
        "max_batch_size": max_batch_size,
        "batch_growth_rate": batch_growth_rate,
        "dampingfactor": dampingfactor,
        "dampingdelay": dampingdelay,
        "epochs": epochs,
        "verbose": verbose,
        "lr": lr,
        "no_cuda": not cuda,
        "random_state": random_state,
        "init_seed": init_seed,
        "damper": damper,
        "dataset": dataset,
        "approx_loss": approx_loss,
        "test_freq": test_freq,
        "rho": rho,
        "dwell": dwell,
        "approx_rate": approx_rate,
        "nesterov": nesterov,
        "momentum": momentum,
        "weight_decay": weight_decay,
    }
    pprint(args)

    no_cuda = not cuda
    args["ident"] = ident(args)
    args["tuning"] = tuning

    use_cuda = not args["no_cuda"] and torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    _device = torch.device(device)
    _set_seed(args["init_seed"])

    transform_train = [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, )),
    ]
    transform_test = [
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, ))
    ]
    assert dataset in ["fashionmnist", "cifar10", "synthetic"]
    if dataset == "fashionmnist":
        _dir = "_traindata/fashionmnist/"
        train_set = FashionMNIST(
            _dir,
            train=True,
            transform=Compose(transform_train),
            download=True,
        )
        test_set = FashionMNIST(_dir,
                                train=False,
                                transform=Compose(transform_test))
        model = Net()
    elif dataset == "cifar10":
        transform_train = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ]

        transform_test = [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ]

        _dir = "_traindata/cifar10/"
        train_set = CIFAR10(
            _dir,
            train=True,
            transform=Compose(transform_train),
            download=True,
        )
        test_set = CIFAR10(_dir,
                           train=False,
                           transform=Compose(transform_test))
        if model == "wideresnet":
            model = WideResNet(16, 4, 0.3, 10)
        else:
            model = _get_resnet18()
    elif dataset == "synthetic":
        data_kwargs = {"n": 10_000, "d": 100}
        args.update(data_kwargs)
        train_set, test_set, data_stats = synth_dataset(**data_kwargs)
        args.update(data_stats)
        model = LinearNet(data_kwargs["d"])
    else:
        raise ValueError(
            f"dataset={dataset} not in ['fashionmnist', 'cifar10', 'synthetic']")
    if tuning:
        train_size = int(0.8 * len(train_set))
        test_size = len(train_set) - train_size

        train_set, test_set = random_split(
            train_set,
            [train_size, test_size],
            random_state=int(tuning),
        )
        train_x = [x.abs().sum().item() for x, _ in train_set]
        train_y = [y for _, y in train_set]
        test_x = [x.abs().sum().item() for x, _ in test_set]
        test_y = [y for _, y in test_set]
        data_stats = {
            "train_x_sum": sum(train_x),
            "train_y_sum": sum(train_y),
            "test_x_sum": sum(test_x),
            "test_y_sum": sum(test_y),
            "len_train_x": len(train_x),
            "len_train_y": len(train_y),
            "len_test_x": len(test_x),
            "len_test_y": len(test_y),
            "tuning": int(tuning),
        }
        args.update(data_stats)
        pprint(data_stats)

    model = model.to(_device)
    _set_seed(args["random_state"])

    if args["damper"] == "adagrad":
        optimizer = optim.Adagrad(model.parameters(), lr=args.get("lr", 0.01))
    elif args["damper"] == "adadelta":
        optimizer = optim.Adadelta(model.parameters(), rho=rho)
    else:
        if not args["nesterov"]:
            assert args["momentum"] == 0
        optimizer = optim.SGD(model.parameters(),
                              lr=args["lr"],
                              nesterov=args["nesterov"],
                              momentum=args["momentum"],
                              weight_decay=args["weight_decay"])
    n_data = len(train_set)

    opt_args = [model, train_set, optimizer]
    opt_kwargs = {
        k: args[k]
        for k in ["initial_batch_size", "max_batch_size", "random_state"]
    }
    opt_kwargs["device"] = device
    if dataset == "synthetic":
        opt_kwargs["loss"] = F.mse_loss
    if dataset == "cifar10":
        opt_kwargs["loss"] = F.cross_entropy
    if args["damper"].lower() == "padadamp":
        if approx_rate:
            assert isinstance(max_batch_size, int)
            BM = max_batch_size
            B0 = initial_batch_size
            e = epochs
            n = n_data
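            # assumed intent: choose the growth rate so the batch-size
            # schedule ramps from B0 up to BM over the e epochs of n examples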
            r_hat = 4 / 3 * (BM - B0) * (B0 + 2 * BM + 3)
            r_hat /= (2 * BM - 2 * B0 + 3 * e * n)
            args["batch_growth_rate"] = r_hat

        opt = PadaDamp(
            *opt_args,
            batch_growth_rate=args["batch_growth_rate"],
            dwell=args["dwell"],
            **opt_kwargs,
        )
    elif args["damper"].lower() == "geodamp":
        opt = GeoDamp(
            *opt_args,
            dampingdelay=args["dampingdelay"],
            dampingfactor=args["dampingfactor"],
            **opt_kwargs,
        )
    elif args["damper"].lower() == "geodamplr":
        opt = GeoDampLR(
            *opt_args,
            dampingdelay=args["dampingdelay"],
            dampingfactor=args["dampingfactor"],
            **opt_kwargs,
        )
    elif args["damper"].lower() == "cntsdamplr":
        opt = CntsDampLR(
            *opt_args,
            dampingfactor=args["dampingfactor"],
            **opt_kwargs,
        )
    elif args["damper"].lower() == "adadamp":
        opt = AdaDamp(*opt_args,
                      approx_loss=approx_loss,
                      dwell=args["dwell"],
                      **opt_kwargs)
    elif args["damper"].lower() == "gd":
        opt = GradientDescent(*opt_args, **opt_kwargs)
    elif (args["damper"].lower() in ["adagrad", "adadelta", "sgd", "gd"]
          or args["damper"] is None):
        opt = BaseDamper(*opt_args, **opt_kwargs)
    else:
        raise ValueError("argument damper not recognized")
    if dataset == "synthetic":
        pprint(data_stats)
        # read from args: data_stats may have been shadowed by the tuning-split
        # stats above, while args kept the synthetic dataset's value
        opt._meta["best_train_loss"] = args["best_train_loss"]

    data, train_data = experiment.run(
        model=model,
        opt=opt,
        train_set=train_set,
        test_set=test_set,
        args=args,
        test_freq=test_freq,
        train_stats=dataset == "synthetic",
        verbose=verbose,
        device="cuda" if use_cuda else "cpu",
    )
    return data, train_data
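
For reference, a minimal invocation sketch for the runner above (the values are
illustrative, not from the source); random_state and init_seed must be plain
ints because of the asserts at the top, and tuning=1 carves a validation split
out of the training set:

if __name__ == "__main__":
    data, train_data = main(
        dataset="fashionmnist",
        damper="adadamp",
        epochs=3,
        random_state=42,
        init_seed=42,
        tuning=1,
    )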
Example #4
def train_model(
    epochs,
    batch_size,
    use_cuda,
    dset_folder,
    disable_tqdm=False,
):
    print("Reading dataset")
    dset = FashionMNIST(dset_folder, download=True)
    imgs = dset.data.unsqueeze(-1).numpy().astype(np.float64)
    labels = dset.targets.numpy()
    train_idx, valid_idx = map(np.array, util.split_dataset(labels))

    print("Processing images into graphs...", end="")
    ptime = time.time()
    with multiprocessing.Pool() as p:
        graphs = np.array(p.map(util.get_graph_from_image, imgs))
    del imgs
    ptime = time.time() - ptime
    print(" Took {ptime}s".format(ptime=ptime))

    model_args = []
    model_kwargs = {}
    model = GAT_MNIST(num_features=util.NUM_FEATURES,
                      num_classes=util.NUM_CLASSES)
    if use_cuda:
        model = model.cuda()

    opt = torch.optim.Adam(model.parameters())

    best_valid_acc = 0.
    best_model = copy.deepcopy(model)

    last_epoch_train_loss = 0.
    last_epoch_train_acc = 0.
    last_epoch_valid_acc = 0.

    interrupted = False
    for e in tqdm(
            range(epochs),
            total=epochs,
            desc="Epoch ",
            disable=disable_tqdm,
    ):
        try:
            train_losses, train_accs = util.train(
                model,
                opt,
                graphs,
                labels,
                train_idx,
                batch_size=batch_size,
                use_cuda=use_cuda,
                disable_tqdm=disable_tqdm,
            )

            last_epoch_train_loss = np.mean(train_losses)
            last_epoch_train_acc = 100 * np.mean(train_accs)
        except KeyboardInterrupt:
            print("Training interrupted!")
            interrupted = True

        valid_accs = util.test(
            model,
            graphs,
            labels,
            valid_idx,
            use_cuda,
            desc="Validation ",
            disable_tqdm=disable_tqdm,
        )

        last_epoch_valid_acc = 100 * np.mean(valid_accs)

        if last_epoch_valid_acc > best_valid_acc:
            best_valid_acc = last_epoch_valid_acc
            best_model = copy.deepcopy(model)

        tqdm.write("EPOCH SUMMARY {loss:.4f} {t_acc:.2f}% {v_acc:.2f}%".format(
            loss=last_epoch_train_loss,
            t_acc=last_epoch_train_acc,
            v_acc=last_epoch_valid_acc))

        if interrupted:
            break

    util.save_model("best", best_model)
    util.save_model("last", model)
Example #5
def get_data(flag=True):
    mnist = FashionMNIST('datasets/fashionmnist/', train=flag, transform=transforms.ToTensor(), download=flag)
    loader = torch.utils.data.DataLoader(mnist, batch_size=config['batch_size'], shuffle=flag, drop_last=False)
    return loader
Example #6
# coding: utf-8

# In[69]:

import torch
import torchvision

from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torchvision import transforms

# transform torchvision dataset images from PILImage to tensor for input to CNN
data_transform = transforms.ToTensor()

train_data = FashionMNIST(root='./data',
                          train=True,
                          download=True,
                          transform=data_transform)

test_data = FashionMNIST(root='./data',
                         train=False,
                         download=True,
                         transform=data_transform)

print('Train data, number of images: ', len(train_data))
print('Test data, number of images: ', len(test_data))

# In[70]:

batch_size = 20

# (the listing cuts the call off here; batch_size/shuffle are assumed)
train_loader = DataLoader(dataset=train_data,
                          batch_size=batch_size,
                          shuffle=True)
Example #7
def get_loader(args):
    train_data_loader = None
    test_data_loader = None

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    train_triplets = []
    test_triplets = []

    dset_obj = None
    loader = BaseLoader
    means = (0.485, 0.456, 0.406)
    stds = (0.229, 0.224, 0.225)

    if args.dataset == 'vggface2':
        dset_obj = vggface2.VGGFace2()
    elif args.dataset == 'custom':
        dset_obj = custom_dset.Custom()
    elif (args.dataset == 'mnist') or (args.dataset == 'fmnist'):
        train_dataset, test_dataset = None, None
        if args.dataset == 'mnist':
            train_dataset = MNIST(os.path.join(args.result_dir, "MNIST"),
                                  train=True,
                                  download=True)
            test_dataset = MNIST(os.path.join(args.result_dir, "MNIST"),
                                 train=False,
                                 download=True)
        if args.dataset == 'fmnist':
            train_dataset = FashionMNIST(os.path.join(args.result_dir,
                                                      "FashionMNIST"),
                                         train=True,
                                         download=True)
            test_dataset = FashionMNIST(os.path.join(args.result_dir,
                                                     "FashionMNIST"),
                                        train=False,
                                        download=True)
        dset_obj = mnist.MNIST_DS(train_dataset, test_dataset)
        loader = TripletMNISTLoader
        means = (0.485, )
        stds = (0.229, )

    dset_obj.load()
    for i in range(args.num_train_samples):
        pos_anchor_img, pos_img, neg_img = dset_obj.getTriplet()
        train_triplets.append([pos_anchor_img, pos_img, neg_img])
    for i in range(args.num_test_samples):
        pos_anchor_img, pos_img, neg_img = dset_obj.getTriplet(split='test')
        test_triplets.append([pos_anchor_img, pos_img, neg_img])

    train_data_loader = torch.utils.data.DataLoader(loader(
        train_triplets,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(means, stds)])),
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    **kwargs)
    test_data_loader = torch.utils.data.DataLoader(loader(
        test_triplets,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(means, stds)])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)

    return train_data_loader, test_data_loader
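
A sketch of the args object get_loader expects, inferred from the attribute
accesses above; all values are illustrative:

from argparse import Namespace

args = Namespace(
    dataset='fmnist',      # or 'mnist', 'vggface2', 'custom'
    cuda=False,
    result_dir='./data',
    num_train_samples=1000,
    num_test_samples=200,
    batch_size=64,
)
train_loader, test_loader = get_loader(args)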
Example #8
    # (fragment: the listing resumes in the tail of a VGG-style forward pass)
    x = F.relu(self.conv5(x))
    x = self.pool(F.relu(self.conv6(x)))

    x = F.relu(self.conv7(x))
    x = self.pool(F.relu(self.conv8(x)))

    x = x.view(-1, 7*7*512)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
        
batch_size = 32
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

fashion_mnist_trainval = FashionMNIST("FashionMNIST", train=True, download=True, transform=transform)
fashion_mnist_test = FashionMNIST("FashionMNIST", train=False, download=True, transform=transform)

n_samples = len(fashion_mnist_trainval) 
train_size = int(len(fashion_mnist_trainval) * 0.8) 
val_size = n_samples - train_size 

train_dataset, val_dataset = torch.utils.data.random_split(fashion_mnist_trainval, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(fashion_mnist_test, batch_size=batch_size, shuffle=True)

net = VGG()
net.to(device)
Example #9
def downloader_construct_datasetsdict(datasets_list: list,
                                      grayscale=False) -> dict:
    """
    This function takes in a list of datasets to be used in the experiments
    """
    print(f"INFO ------ List of datasets being loaded are {datasets_list}")

    datasets_dict = {}
    if "CIFAR10" in datasets_list:
        if not grayscale:
            cifar_train_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
            ])
            cifar_test_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
        else:
            cifar_train_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
                transforms.Grayscale(3),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
            ])
            cifar_test_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
                transforms.Grayscale(3),
            ])
        datasets_dict["CIFAR10_train"] = CIFAR10(
            root=r"./dataset/CHIFAR10/",
            train=True,
            download=True,
            transform=cifar_train_transform,
        )
        datasets_dict["CIFAR10_test"] = CIFAR10(
            root=r"./dataset/CHIFAR10/",
            train=False,
            download=True,
            transform=cifar_test_transform,
        )

        print("INFO ----- Dataset Loaded : CIFAR10")
        datasets_list.remove("CIFAR10")

    if "A_MNIST" in datasets_list:
        mnist_transforms = transforms.Compose([
            transforms.Pad(2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.1307], std=[0.3015]),
            transforms.Lambda(tmp_func),
            transforms.RandomCrop(32, 4),
        ])
        datasets_dict["A_MNIST_train"] = MNIST(
            root=r"./dataset/MNIST",
            train=True,
            download=True,
            transform=mnist_transforms,
        )

        datasets_dict["A_MNIST_test"] = MNIST(
            root=r"./dataset/MNIST",
            train=False,
            download=True,
            transform=mnist_transforms,
        )

        print("INFO ----- Dataset Loaded : MNIST")
        datasets_list.remove("A_MNIST")

    if "A_FashionMNIST" in datasets_list:
        fmnist_transforms = transforms.Compose([
            transforms.Pad(2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.2860], std=[0.3205]),
            transforms.Lambda(tmp_func),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
        ])

        datasets_dict["A_FashionMNIST_train"] = FashionMNIST(
            root="./dataset/FashionMNIST",
            train=True,
            download=True,
            transform=fmnist_transforms,
        )

        datasets_dict["A_FashionMNIST_test"] = FashionMNIST(
            root="./dataset/FashionMNIST",
            train=False,
            download=True,
            transform=fmnist_transforms,
        )

        print("INFO ----- Dataset Loaded : FashionMNIST")
        datasets_list.remove("A_FashionMNIST")

    if "A_SVHN" in datasets_list:
        SVHN_transforms = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Resize(32)])

        datasets_dict["A_SVHN_train"] = SVHN(
            root=r"./dataset/SVHN",
            split="train",
            download=True,
            transform=SVHN_transforms,
        )
        datasets_dict["A_SVHN_train"].targets = datasets_dict[
            "A_SVHN_train"].labels
        datasets_dict["A_SVHN_test"] = SVHN(
            root=r"./dataset/SVHN",
            split="test",
            download=True,
            transform=SVHN_transforms,
        )

        datasets_dict["A_SVHN_test"].targets = datasets_dict[
            "A_SVHN_test"].labels
        print("INFO ----- Dataset Loaded : SVHN")
        datasets_list.remove("A_SVHN")

    if "CIFAR100" in datasets_list:
        datasets_dict["CIFAR100_train"] = CIFAR100(
            root=r"./dataset/CIFAR100",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                     std=[0.2009, 0.1984, 0.2023]),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
            ]),
        )

        datasets_dict["CIFAR100_test"] = CIFAR100(
            root=r"./dataset/CIFAR100",
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                     std=[0.2009, 0.1984, 0.2023]),
            ]),
        )
        print("INFO ----- Dataset Loaded : CIFAR100")
        datasets_list.remove("CIFAR100")

    if "A_CIFAR10_ood" in datasets_list:

        cifar_train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
        ])
        cifar_test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        datasets_dict["A_CIFAR10_ood_train"] = CIFAR10(
            root=r"./dataset/CHIFAR10/",
            train=True,
            download=True,
            transform=cifar_train_transform,
        )
        datasets_dict["A_CIFAR10_ood_test"] = CIFAR10(
            root=r"./dataset/CHIFAR10/",
            train=False,
            download=True,
            transform=cifar_test_transform,
        )

        print("INFO ----- Dataset Loaded : CIFAR10_ood")
        datasets_list.remove("A_CIFAR10_ood")

    if "A_CIFAR100_ood" in datasets_list:
        datasets_dict["A_CIFAR100_ood_train"] = CIFAR100(
            root=r"./dataset/CIFAR100",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
            ]),
        )

        datasets_dict["A_CIFAR100_ood_test"] = CIFAR100(
            root=r"./dataset/CIFAR100",
            train=False,
            download=True,
            transform=transforms.ToTensor(),
        )

        print("INFO ----- Dataset Loaded : CIFAR100_ood")
        datasets_list.remove("A_CIFAR100_ood")

    assert (
        len(datasets_list) == 0
    ), f"Not all datasets have been loaded, datasets left : {datasets_list}"

    return datasets_dict
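
For example, loading CIFAR10 together with the padded FashionMNIST variant
(note that the function empties the list it is given, and that it relies on
module-level helpers such as tmp_func being defined):

names = ["CIFAR10", "A_FashionMNIST"]
datasets_dict = downloader_construct_datasetsdict(names)
# names is now empty; datasets_dict holds "CIFAR10_train", "CIFAR10_test",
# "A_FashionMNIST_train" and "A_FashionMNIST_test" entries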
Example #10
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, ), (0.5, ))])

if args.dataset == 'mnist':
    train_data = MNIST(root='data/mnist',
                       train=True,
                       download=True,
                       transform=transform)
    test_data = MNIST(root='data/mnist',
                      train=False,
                      download=True,
                      transform=transform)

if args.dataset == 'fashion':
    train_data = FashionMNIST(root='data/fashion',
                              train=True,
                              download=True,
                              transform=transform)
    test_data = FashionMNIST(root='data/fashion',
                             train=False,
                             download=True,
                             transform=transform)

# BiGAN params
z_dim = args.z_dim
hid_dim = args.hid_dim

# Train params
use_cuda = args.use_cuda and torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device('cpu')
n_epochs = args.n_epochs
batch_size = args.batch_size
Example #11

if __name__ == '__main__':
    np.random.seed(1234)
    torch.manual_seed(1234)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    '''
    Load the data
    '''

    d_dir = 'gdrive/MyDrive/ProjectExperiment/Datasets/'

    # Left as-is, the Dataset would hold its images in PIL (Python Imaging
    # Library) format, so convert them to tensors with transforms.ToTensor
    fashion_mnist_train = FashionMNIST(d_dir,
                                       train=True,
                                       download=True,
                                       transform=transforms.ToTensor())
    fashion_mnist_test = FashionMNIST(d_dir,
                                      train=False,
                                      download=True,
                                      transform=transforms.ToTensor())

    # Create DataLoaders with a batch size of 128;
    # the DataLoader assembles the mini-batches
    batch_size = 128
    train_dataloader = DataLoader(fashion_mnist_train,
                                  batch_size=batch_size,
                                  shuffle=True)
    test_dataloader = DataLoader(fashion_mnist_test,
                                 batch_size=batch_size,
                                 shuffle=False)
Example #12
    # (fragment: the listing starts at the tail of a truncated training
    #  helper, presumably the train_one_epoch used below)
    return avg_loss.avg


if __name__ == "__main__":

    number_epochs = 100

    device = torch.device(
        'cpu'
    )  # Replace with torch.device("cuda:0") if you want to train on GPU

    model = MLP(10).to(device)

    trans_img = transforms.Compose([transforms.ToTensor()])
    dataset = FashionMNIST("./data/",
                           train=True,
                           transform=trans_img,
                           download=True)
    trainloader = DataLoader(dataset, batch_size=1024, shuffle=True)

    optimizer = optim.Adam(model.parameters(), lr=0.01)

    track_loss = []
    for i in tqdm(range(number_epochs)):
        loss = train_one_epoch(model, trainloader, optimizer, device)
        track_loss.append(loss)

    plt.figure()
    plt.plot(track_loss)
    plt.title("training-loss-MLP")
    plt.savefig("./img/training_mlp.jpg")
Example #13
def get_data_manager(
    indistribution=["Cifar10"],
    ood=["MNIST", "Fashion_MNIST"],
):
    """get_data_manager [Creates a data_manager instance with the In-/Out-of-Distribution Data]

    [List based processing of Datasets. Images are resized / croped on 32x32]

    Args:
        indistribution (list, optional): [description]. Defaults to ["Cifar10"].
        ood (list, optional): [description]. Defaults to ["MNIST", "Fashion_MNIST", "SVHN"].

    Returns:
        [datamager]: [Experiment data_manager for for logging and the active learning cycle]
    """

    # TODO ADD Target transform?
    base_data = np.empty(shape=(1, 3, 32, 32))
    base_data_test = np.empty(shape=(1, 3, 32, 32))
    base_labels = np.empty(shape=(1,))
    base_labels_test = np.empty(shape=(1,))

    OOD_data = np.empty(shape=(1, 3, 32, 32))
    OOD_labels = np.empty(shape=(1,))

    resize = transforms.Resize(32)
    random_crop = transforms.RandomCrop(32)

    standard_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            resize,
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    if debug:
        tracemalloc.start()
        snapshot = tracemalloc.take_snapshot()
        display_top(snapshot)

    for dataset in indistribution:
        if dataset == "Cifar10":
            CIFAR10_train = CIFAR10(
                root=r"/dataset/CHIFAR10/",
                train=True,
                download=True,
                transform=transforms.ToTensor(),
            )
            CIFAR10_test = CIFAR10(
                root=r"/dataset/CHIFAR10/",
                train=False,
                download=True,
                transform=transforms.ToTensor(),
            )
            # CIFAR10_train_data = CIFAR10_train.data.permute(
            #     0, 3, 1, 2
            # )  # .reshape(-1, 3, 32, 32)
            # CIFAR10_test_data = CIFAR10_test.data.permute(
            #     0, 3, 1, 2
            # )  # .reshape(-1, 3, 32, 32)

            CIFAR10_train_data = np.array([i.numpy() for i, _ in CIFAR10_train])
            CIFAR10_test_data = np.array([i.numpy() for i, _ in CIFAR10_test])

            CIFAR10_train_labels = np.array(CIFAR10_train.targets)
            CIFAR10_test_labels = np.array(CIFAR10_test.targets)

            base_data = np.concatenate(
                [base_data.copy(), CIFAR10_train_data.copy()],
                axis=0,
            )

            base_data_test = np.concatenate(
                [base_data_test.copy(), CIFAR10_test_data.copy()]
            )

            base_labels = np.concatenate(
                [
                    base_labels.copy(),
                    CIFAR10_train_labels.copy(),
                ],
                axis=0,
            )
            base_labels_test = np.concatenate(
                [
                    base_labels_test.copy(),
                    CIFAR10_test_labels.copy(),
                ]
            )

            del (
                CIFAR10_train_data,
                CIFAR10_test_data,
                CIFAR10_train_labels,
                CIFAR10_test_labels,
                CIFAR10_train,
                CIFAR10_test,
            )
            gc.collect()
        elif dataset == "MNIST":

            MNIST_train = MNIST(
                root=r"/dataset/MNIST",
                train=True,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )
            MNIST_test = MNIST(
                root=r"/dataset/MNIST",
                train=False,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.ToTensor(),
                        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                    ]
                ),
            )
            MNIST_train_data = np.array([i.numpy() for i, _ in MNIST_train])
            MNIST_test_data = np.array([i.numpy() for i, _ in MNIST_test])
            # offset the labels when several in-distribution datasets are combined
            if len(indistribution) > 1:
                MNIST_train_labels = MNIST_train.targets.numpy() + np.max(base_labels)
                MNIST_test_labels = MNIST_test.targets.numpy() + np.max(base_labels)
            else:
                MNIST_train_labels = MNIST_train.targets.numpy()
                MNIST_test_labels = MNIST_test.targets.numpy()

            base_data = np.concatenate([base_data.copy(), MNIST_train_data.copy()])

            base_data_test = np.concatenate(
                [base_data_test.copy(), MNIST_test_data.copy()]
            )

            base_labels = np.concatenate(
                [
                    base_labels.copy(),
                    MNIST_train_labels.copy(),
                ]
            )

            base_labels_test = np.concatenate(
                [
                    base_labels_test.copy(),
                    MNIST_test_labels.copy(),
                ]
            )
            del (
                MNIST_train,
                MNIST_test,
                MNIST_train_data,
                MNIST_test_data,
                MNIST_train_labels,
                MNIST_test_labels,
            )
            gc.collect()
        elif dataset == "Fashion_MNIST":

            Fashion_MNIST_train = FashionMNIST(
                root="/dataset/FashionMNIST",
                train=True,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.ToTensor(),
                        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                    ]
                ),
            )
            Fashion_MNIST_test = FashionMNIST(
                root="/dataset/FashionMNIST",
                train=False,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )
            Fashion_MNIST_train_data = np.array(
                [i.numpy() for i, _ in Fashion_MNIST_train]
            )
            Fashion_MNIST_test_data = np.array(
                [i.numpy() for i, _ in Fashion_MNIST_test]
            )

            # offset the labels when several in-distribution datasets are combined
            if len(indistribution) > 1:
                Fashion_MNIST_train_labels = (
                    Fashion_MNIST_train.targets.numpy() + np.max(base_labels)
                )
                Fashion_MNIST_test_labels = Fashion_MNIST_test.targets.numpy() + np.max(
                    base_labels
                )
            else:
                Fashion_MNIST_train_labels = Fashion_MNIST_train.targets.numpy()
                Fashion_MNIST_test_labels = Fashion_MNIST_test.targets.numpy()

            base_data = np.concatenate(
                [base_data.copy(), Fashion_MNIST_train_data.copy()]
            )

            base_data_test = np.concatenate(
                [base_data_test.copy(), Fashion_MNIST_test_data.copy()]
            )

            base_labels = np.concatenate(
                [
                    base_labels.copy(),
                    Fashion_MNIST_train_labels.copy(),
                ]
            )
            base_labels_test = np.concatenate(
                [
                    base_labels_test.copy(),
                    Fashion_MNIST_test_labels.copy(),
                ]
            )
            del (
                Fashion_MNIST_train,
                Fashion_MNIST_test,
                Fashion_MNIST_train_data,
                Fashion_MNIST_test_data,
                Fashion_MNIST_train_labels,
                Fashion_MNIST_test_labels,
            )
            gc.collect()
    for ood_dataset in ood:
        if ood_dataset == "MNIST":

            MNIST_train = MNIST(
                root=r"/dataset/MNIST",
                train=True,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )
            MNIST_test = MNIST(
                root=r"/dataset/MNIST",
                train=False,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )

            MNIST_train_data = np.array([i.numpy() for i, _ in MNIST_train])
            MNIST_test_data = np.array([i.numpy() for i, _ in MNIST_test])

            MNIST_train_labels = MNIST_train.targets.numpy()
            MNIST_test_labels = MNIST_test.targets.numpy()
            OOD_data = np.concatenate(
                [OOD_data.copy(), MNIST_train_data.copy(), MNIST_test_data.copy()],
                axis=0,
            )
            OOD_labels = np.concatenate(
                [OOD_labels.copy(), MNIST_train_labels.copy(), MNIST_test_labels.copy()]
            )

            del (
                MNIST_train,
                MNIST_test,
                MNIST_train_data,
                MNIST_test_data,
                MNIST_train_labels,
                MNIST_test_labels,
            )
            gc.collect()
        elif ood_dataset == "Fashion_MNIST":

            Fashion_MNIST_train = FashionMNIST(
                root="/dataset/FashionMNIST",
                train=True,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )
            Fashion_MNIST_test = FashionMNIST(
                root="/dataset/FashionMNIST",
                train=False,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )
            Fashion_MNIST_train_data = np.array(
                [i.numpy() for i, _ in Fashion_MNIST_train]
            )
            Fashion_MNIST_test_data = np.array(
                [i.numpy() for i, _ in Fashion_MNIST_test]
            )
            Fashion_MNIST_train_labels = Fashion_MNIST_train.targets.numpy()
            Fashion_MNIST_test_labels = Fashion_MNIST_test.targets.numpy()

            OOD_data = np.concatenate(
                [
                    OOD_data.copy(),
                    Fashion_MNIST_train_data.copy(),
                    Fashion_MNIST_test_data.copy(),
                ],
                axis=0,
            )
            OOD_labels = np.concatenate(
                [
                    OOD_labels.copy(),
                    Fashion_MNIST_train_labels.copy(),
                    Fashion_MNIST_test_labels.copy(),
                ],
            )
            del (
                Fashion_MNIST_train,
                Fashion_MNIST_test,
                Fashion_MNIST_train_data,
                Fashion_MNIST_test_data,
                Fashion_MNIST_train_labels,
                Fashion_MNIST_test_labels,
            )
            gc.collect()
        elif ood_dataset == "SVHN":
            SVHN_train = SVHN(
                root=r"/dataset/SVHN",
                split="train",
                download=True,
                transform=standard_transform,
            )
            SVHN_test = SVHN(
                root=r"/dataset/SVHN",
                split="test",
                download=True,
                transform=standard_transform,
            )
            SVHN_train_data = SVHN_train.data
            SVHN_test_data = SVHN_test.data
            SVHN_train_labels = SVHN_train.labels
            SVHN_test_labels = SVHN_test.labels

            OOD_data = np.concatenate(
                [OOD_data.copy(), SVHN_train_data.copy(), SVHN_test_data.copy()], axis=0
            )
            OOD_labels = np.concatenate(
                [OOD_labels.copy(), SVHN_train_labels.copy(), SVHN_test_labels.copy()]
            )

            del (
                SVHN_train,
                SVHN_test,
                SVHN_train_data,
                SVHN_test_data,
                SVHN_train_labels,
                SVHN_test_labels,
            )
            gc.collect()
        # elif ood_dataset == "TinyImageNet":
        #     if not os.listdir(os.path.join(r"./dataset/tiny-imagenet-200")):
        #         download_and_unzip()
        #     id_dict = {}
        #     for i, line in enumerate(
        #         open(
        #             os.path.join(
        #                 r"\dataset\tiny-imagenet-200\tiny-imagenet-200\wnids.txt"
        #             ),
        #             "r",
        #         )
        #     ):
        #         id_dict[line.replace("\n", "")] = i
        #     normalize_imagenet = transforms.Normalize(
        #         (122.4786, 114.2755, 101.3963), (70.4924, 68.5679, 71.8127)
        #     )
        #     train_t_imagenet = TrainTinyImageNetDataset(
        #         id=id_dict, transform=transforms.Compose([normalize_imagenet, resize])
        #     )
        #     test_t_imagenet = TestTinyImageNetDataset(
        #         id=id_dict, transform=transforms.Compose([normalize_imagenet, resize])
        #     )

    if debug:
        snapshot = tracemalloc.take_snapshot()
        display_top(snapshot)

    base_data = np.delete(base_data, 0, axis=0)
    base_data_test = np.delete(base_data_test, 0, axis=0)
    base_labels = np.delete(base_labels, 0)
    base_labels_test = np.delete(base_labels_test, 0)
    OOD_data = np.delete(OOD_data, 0, axis=0)
    OOD_labels = np.delete(OOD_labels, 0)

    print(base_data.shape, base_data_test.shape, OOD_data.shape, OOD_labels.shape)

    data_manager = Data_manager(
        base_data=base_data,
        base_labels=base_labels,
        base_data_test=base_data_test,
        base_labels_test=base_labels_test,
        OOD_data=OOD_data,
        OOD_labels=OOD_labels,
    )
    # del (base_data, base_labels, OOD_data, OOD_labels)

    gc.collect()
    if debug:
        snapshot = tracemalloc.take_snapshot()
        display_top(snapshot)
    return data_manager
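
A minimal usage sketch, assuming the surrounding module provides Data_manager
and the debug helpers:

data_manager = get_data_manager(
    indistribution=["Cifar10"],
    ood=["MNIST", "Fashion_MNIST"],
)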