def _get_fmnist_dataset():
    train_set = FashionMNIST(
        expanduser("~") + "/.avalanche/data/fashionmnist/",
        train=True,
        download=True,
    )
    test_set = FashionMNIST(
        expanduser("~") + "/.avalanche/data/fashionmnist/",
        train=False,
        download=True,
    )
    return train_set, test_set
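# A minimal usage sketch for the helper above (the printed sizes are the
# standard FashionMNIST split sizes):
# >>> train_set, test_set = _get_fmnist_dataset()
# >>> len(train_set), len(test_set)
# (60000, 10000)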
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import FashionMNIST, MNIST
from tqdm import tqdm

# Note: despite the variable name, this loads FashionMNIST.
mnist = FashionMNIST(
    './data/',
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ]),
)
dataloader = DataLoader(mnist, batch_size=128, shuffle=True)


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # Encoder: conv/pool stages followed by fully connected layers.
        self.conv1 = nn.Conv2d(1, 4, 5, padding=1)
        self.conv2 = nn.Conv2d(4, 16, 3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(16, 32, 3)  # 32 x 10 x 10
        self.pool2 = nn.MaxPool2d(2)
        self.lin1 = nn.Linear(800, 400)
        self.lin2 = nn.Linear(400, 200)
        # Two heads: mean and log-variance of the 32-dimensional latent code.
        self.lin3_1 = nn.Linear(200, 32)
        self.lin3_2 = nn.Linear(200, 32)
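# The two linear heads above (lin3_1 / lin3_2) are the usual VAE mean and
# log-variance outputs. A minimal sketch of the reparameterization step a
# forward pass would apply (the names mu/logvar and this helper are
# illustrative additions, not from the original file):
def _reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)   # log-variance -> standard deviation
    eps = torch.randn_like(std)     # unit Gaussian noise
    return mu + eps * std           # differentiable sample z ~ N(mu, std^2)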
def main(
    dataset: str = "fashionmnist",
    initial_batch_size: int = 64,
    epochs: int = 6,
    verbose: Union[int, bool] = False,
    lr: float = 1.0,
    cuda: bool = False,
    random_state: Optional[int] = None,  # seed to pass to BaseDamper
    init_seed: Optional[int] = None,  # seed for initialization
    tuning: bool = True,  # tuning seed
    damper: str = "geodamp",
    batch_growth_rate: float = 0.01,
    dampingfactor: Number = 5.0,
    dampingdelay: int = 5,
    max_batch_size: Optional[int] = None,
    test_freq: float = 1,
    approx_loss: bool = False,
    rho: float = 0.9,
    dwell: int = 1,
    approx_rate: bool = False,
    model: Optional[str] = None,
    momentum: Optional[Union[float, int]] = 0,
    nesterov: bool = False,
    weight_decay: float = 0,
) -> Tuple[List[Dict], List[Dict]]:
    # Get (tuning, random_state, init_seed)
    assert int(tuning) or isinstance(tuning, bool)
    assert isinstance(random_state, int)
    assert isinstance(init_seed, int)
    if "NUM_THREADS" in os.environ:
        v = os.environ["NUM_THREADS"]
        if v:
            print(f"NUM_THREADS={v} (int(v)={int(v)})")
            torch.set_num_threads(int(v))
    args: Dict[str, Any] = {
        "initial_batch_size": initial_batch_size,
        "max_batch_size": max_batch_size,
        "batch_growth_rate": batch_growth_rate,
        "dampingfactor": dampingfactor,
        "dampingdelay": dampingdelay,
        "epochs": epochs,
        "verbose": verbose,
        "lr": lr,
        "no_cuda": not cuda,
        "random_state": random_state,
        "init_seed": init_seed,
        "damper": damper,
        "dataset": dataset,
        "approx_loss": approx_loss,
        "test_freq": test_freq,
        "rho": rho,
        "dwell": dwell,
        "approx_rate": approx_rate,
        "nesterov": nesterov,
        "momentum": momentum,
        "weight_decay": weight_decay,
    }
    pprint(args)
    no_cuda = not cuda
    args["ident"] = ident(args)
    args["tuning"] = tuning
    use_cuda = not args["no_cuda"] and torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    _device = torch.device(device)
    _set_seed(args["init_seed"])

    transform_train = [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
    transform_test = [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    assert dataset in ["fashionmnist", "cifar10", "synthetic"]
    if dataset == "fashionmnist":
        _dir = "_traindata/fashionmnist/"
        train_set = FashionMNIST(
            _dir,
            train=True,
            transform=Compose(transform_train),
            download=True,
        )
        test_set = FashionMNIST(_dir, train=False, transform=Compose(transform_test))
        model = Net()
    elif dataset == "cifar10":
        transform_train = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
        transform_test = [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
        _dir = "_traindata/cifar10/"
        train_set = CIFAR10(
            _dir,
            train=True,
            transform=Compose(transform_train),
            download=True,
        )
        test_set = CIFAR10(_dir, train=False, transform=Compose(transform_test))
        if model == "wideresnet":
            model = WideResNet(16, 4, 0.3, 10)
        else:
            model = _get_resnet18()
    elif dataset == "synthetic":
        data_kwargs = {"n": 10_000, "d": 100}
        args.update(data_kwargs)
        train_set, test_set, data_stats = synth_dataset(**data_kwargs)
        args.update(data_stats)
        model = LinearNet(data_kwargs["d"])
    else:
        raise ValueError(
            f"dataset={dataset} not in ['fashionmnist', 'cifar10', 'synthetic']")

    if tuning:
        # Hold out 20% of the training set as the "test" set while tuning.
        train_size = int(0.8 * len(train_set))
        test_size = len(train_set) - train_size
        train_set, test_set = random_split(
            train_set,
            [train_size, test_size],
            random_state=int(tuning),
        )

    train_x = [x.abs().sum().item() for x, _ in train_set]
    train_y = [y for _, y in train_set]
    test_x = [x.abs().sum().item() for x, _ in test_set]
    test_y = [y for _, y in test_set]
    data_stats = {
        "train_x_sum": sum(train_x),
        "train_y_sum": sum(train_y),
        "test_x_sum": sum(test_x),
        "test_y_sum": sum(test_y),
        "len_train_x": len(train_x),
        "len_train_y": len(train_y),
        "len_test_x": len(test_x),
        "len_test_y": len(test_y),
        "tuning": int(tuning),
    }
    args.update(data_stats)
    pprint(data_stats)

    model = model.to(_device)
    _set_seed(args["random_state"])

    if args["damper"] == "adagrad":
        optimizer = optim.Adagrad(model.parameters(), lr=args.get("lr", 0.01))
    elif args["damper"] == "adadelta":
        optimizer = optim.Adadelta(model.parameters(), rho=rho)
    else:
        if not args["nesterov"]:
            assert args["momentum"] == 0
        optimizer = optim.SGD(
            model.parameters(),
            lr=args["lr"],
            nesterov=args["nesterov"],
            momentum=args["momentum"],
            weight_decay=args["weight_decay"],
        )

    n_data = len(train_set)
    opt_args = [model, train_set, optimizer]
    opt_kwargs = {
        k: args[k] for k in ["initial_batch_size", "max_batch_size", "random_state"]
    }
    opt_kwargs["device"] = device
    if dataset == "synthetic":
        opt_kwargs["loss"] = F.mse_loss
    if dataset == "cifar10":
        opt_kwargs["loss"] = F.cross_entropy

    if args["damper"].lower() == "padadamp":
        if approx_rate:
            assert isinstance(max_batch_size, int)
            BM = max_batch_size
            B0 = initial_batch_size
            e = epochs
            n = n_data
            r_hat = 4 / 3 * (BM - B0) * (B0 + 2 * BM + 3)
            r_hat /= (2 * BM - 2 * B0 + 3 * e * n)
            args["batch_growth_rate"] = r_hat
        opt = PadaDamp(
            *opt_args,
            batch_growth_rate=args["batch_growth_rate"],
            dwell=args["dwell"],
            **opt_kwargs,
        )
    elif args["damper"].lower() == "geodamp":
        opt = GeoDamp(
            *opt_args,
            dampingdelay=args["dampingdelay"],
            dampingfactor=args["dampingfactor"],
            **opt_kwargs,
        )
    elif args["damper"].lower() == "geodamplr":
        opt = GeoDampLR(
            *opt_args,
            dampingdelay=args["dampingdelay"],
            dampingfactor=args["dampingfactor"],
            **opt_kwargs,
        )
    elif args["damper"].lower() == "cntsdamplr":
        opt = CntsDampLR(
            *opt_args,
            dampingfactor=args["dampingfactor"],
            **opt_kwargs,
        )
    elif args["damper"].lower() == "adadamp":
        opt = AdaDamp(
            *opt_args,
            approx_loss=approx_loss,
            dwell=args["dwell"],
            **opt_kwargs,
        )
    elif args["damper"].lower() == "gd":
        opt = GradientDescent(*opt_args, **opt_kwargs)
    elif (args["damper"].lower() in ["adagrad", "adadelta", "sgd"]  # "gd" is matched above
          or args["damper"] is None):
        opt = BaseDamper(*opt_args, **opt_kwargs)
    else:
        raise ValueError("argument damper not recognized")

    if dataset == "synthetic":
        pprint(data_stats)
        opt._meta["best_train_loss"] = data_stats["best_train_loss"]

    data, train_data = experiment.run(
        model=model,
        opt=opt,
        train_set=train_set,
        test_set=test_set,
        args=args,
        test_freq=test_freq,
        train_stats=dataset == "synthetic",
        verbose=verbose,
        device="cuda" if use_cuda else "cpu",
    )
    return data, train_data
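# A hedged usage sketch for the experiment entry point above; the argument
# values are illustrative only (random_state and init_seed must be ints
# because of the asserts at the top of main):
if __name__ == "__main__":
    data, train_data = main(
        dataset="fashionmnist",
        damper="geodamp",
        epochs=6,
        random_state=42,  # seed passed to BaseDamper
        init_seed=42,     # seed for model initialization
        tuning=True,      # hold out 20% of the training set
    )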
def train_model(
    epochs,
    batch_size,
    use_cuda,
    dset_folder,
    disable_tqdm=False,
):
    print("Reading dataset")
    dset = FashionMNIST(dset_folder, download=True)
    imgs = dset.data.unsqueeze(-1).numpy().astype(np.float64)
    labels = dset.targets.numpy()
    train_idx, valid_idx = map(np.array, util.split_dataset(labels))

    print("Processing images into graphs...", end="")
    ptime = time.time()
    with multiprocessing.Pool() as p:
        graphs = np.array(p.map(util.get_graph_from_image, imgs))
    del imgs
    ptime = time.time() - ptime
    print(" Took {ptime}s".format(ptime=ptime))

    model_args = []
    model_kwargs = {}
    model = GAT_MNIST(num_features=util.NUM_FEATURES, num_classes=util.NUM_CLASSES)
    if use_cuda:
        model = model.cuda()
    opt = torch.optim.Adam(model.parameters())

    best_valid_acc = 0.
    best_model = copy.deepcopy(model)

    last_epoch_train_loss = 0.
    last_epoch_train_acc = 0.
    last_epoch_valid_acc = 0.

    interrupted = False
    for e in tqdm(
        range(epochs),
        total=epochs,
        desc="Epoch ",
        disable=disable_tqdm,
    ):
        try:
            train_losses, train_accs = util.train(
                model,
                opt,
                graphs,
                labels,
                train_idx,
                batch_size=batch_size,
                use_cuda=use_cuda,
                disable_tqdm=disable_tqdm,
            )
            last_epoch_train_loss = np.mean(train_losses)
            last_epoch_train_acc = 100 * np.mean(train_accs)
        except KeyboardInterrupt:
            print("Training interrupted!")
            interrupted = True

        valid_accs = util.test(
            model,
            graphs,
            labels,
            valid_idx,
            use_cuda,
            desc="Validation ",
            disable_tqdm=disable_tqdm,
        )
        last_epoch_valid_acc = 100 * np.mean(valid_accs)

        if last_epoch_valid_acc > best_valid_acc:
            best_valid_acc = last_epoch_valid_acc
            best_model = copy.deepcopy(model)

        tqdm.write("EPOCH SUMMARY {loss:.4f} {t_acc:.2f}% {v_acc:.2f}%".format(
            loss=last_epoch_train_loss,
            t_acc=last_epoch_train_acc,
            v_acc=last_epoch_valid_acc))

        if interrupted:
            break

    util.save_model("best", best_model)
    util.save_model("last", model)
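# Usage sketch (the folder path and hyperparameters are illustrative):
# train_model(epochs=100, batch_size=32,
#             use_cuda=torch.cuda.is_available(),
#             dset_folder="./data/fashionmnist")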
def get_data(flag=True):
    # `flag` selects the split: True -> training set, False -> test set.
    mnist = FashionMNIST(
        'datasets/fashionmnist/',
        train=flag,
        transform=transforms.ToTensor(),
        download=True,  # always download if missing; tying this to `flag` would skip downloading the test set
    )
    loader = torch.utils.data.DataLoader(
        mnist,
        batch_size=config['batch_size'],
        shuffle=flag,  # shuffle only the training split
        drop_last=False,
    )
    return loader
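# Usage sketch, assuming the module-level `config` dict the function reads:
# config = {'batch_size': 64}
# train_loader = get_data(True)    # shuffled training split
# test_loader = get_data(False)    # unshuffled test split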
# coding: utf-8

# In[69]:

import torch
import torchvision
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
from torchvision import transforms

# Transform torchvision dataset images from PILImage to tensor for input to the CNN.
data_transform = transforms.ToTensor()

train_data = FashionMNIST(root='./data', train=True, download=True, transform=data_transform)
test_data = FashionMNIST(root='./data', train=False, download=True, transform=data_transform)

print('Train data, number of images: ', len(train_data))
print('Test data, number of images: ', len(test_data))

# In[70]:

batch_size = 20

# The original cell was truncated after `dataset=train_data,`; the remaining
# arguments follow the obvious continuation for a shuffled training loader.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
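# In[71]:

# A quick sanity check on one batch: FashionMNIST images are 1x28x28, so with
# batch_size = 20 this prints torch.Size([20, 1, 28, 28]) torch.Size([20]).
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)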
def get_loader(args):
    train_data_loader = None
    test_data_loader = None

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    train_triplets = []
    test_triplets = []
    dset_obj = None
    loader = BaseLoader
    means = (0.485, 0.456, 0.406)
    stds = (0.229, 0.224, 0.225)

    if args.dataset == 'vggface2':
        dset_obj = vggface2.VGGFace2()
    elif args.dataset == 'custom':
        dset_obj = custom_dset.Custom()
    elif (args.dataset == 'mnist') or (args.dataset == 'fmnist'):
        train_dataset, test_dataset = None, None
        if args.dataset == 'mnist':
            train_dataset = MNIST(os.path.join(args.result_dir, "MNIST"),
                                  train=True, download=True)
            test_dataset = MNIST(os.path.join(args.result_dir, "MNIST"),
                                 train=False, download=True)
        if args.dataset == 'fmnist':
            train_dataset = FashionMNIST(os.path.join(args.result_dir, "FashionMNIST"),
                                         train=True, download=True)
            test_dataset = FashionMNIST(os.path.join(args.result_dir, "FashionMNIST"),
                                        train=False, download=True)
        dset_obj = mnist.MNIST_DS(train_dataset, test_dataset)
        loader = TripletMNISTLoader
        # Single-channel normalization stats for the grayscale datasets.
        means = (0.485, )
        stds = (0.229, )

    dset_obj.load()
    for i in range(args.num_train_samples):
        pos_anchor_img, pos_img, neg_img = dset_obj.getTriplet()
        train_triplets.append([pos_anchor_img, pos_img, neg_img])
    for i in range(args.num_test_samples):
        pos_anchor_img, pos_img, neg_img = dset_obj.getTriplet(split='test')
        test_triplets.append([pos_anchor_img, pos_img, neg_img])

    train_data_loader = torch.utils.data.DataLoader(
        loader(train_triplets,
               transform=transforms.Compose(
                   [transforms.ToTensor(),
                    transforms.Normalize(means, stds)])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    test_data_loader = torch.utils.data.DataLoader(
        loader(test_triplets,
               transform=transforms.Compose(
                   [transforms.ToTensor(),
                    transforms.Normalize(means, stds)])),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    return train_data_loader, test_data_loader
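# Usage sketch with a hand-built namespace; the attributes mirror what
# get_loader reads and the values are illustrative.
# from argparse import Namespace
# args = Namespace(cuda=False, dataset='fmnist', result_dir='./results',
#                  num_train_samples=10000, num_test_samples=2000, batch_size=64)
# train_loader, test_loader = get_loader(args)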
        x = F.relu(self.conv5(x))
        x = self.pool(F.relu(self.conv6(x)))
        x = F.relu(self.conv7(x))
        x = self.pool(F.relu(self.conv8(x)))
        x = x.view(-1, 7*7*512)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


batch_size = 32

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Normalize expects sequences; the original passed bare floats.
    transforms.Normalize((0.5,), (0.5,)),
])

fashion_mnist_trainval = FashionMNIST("FashionMNIST", train=True, download=True, transform=transform)
fashion_mnist_test = FashionMNIST("FashionMNIST", train=False, download=True, transform=transform)

# Hold out 20% of the training set for validation.
n_samples = len(fashion_mnist_trainval)
train_size = int(len(fashion_mnist_trainval) * 0.8)
val_size = n_samples - train_size
train_dataset, val_dataset = torch.utils.data.random_split(fashion_mnist_trainval, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(fashion_mnist_test, batch_size=batch_size, shuffle=True)

net = VGG()
net.to(device)
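# Sanity check on the flattened size: a 224x224 input halved by five 2x2
# max-pools (the two visible above plus three in the elided earlier layers of
# a standard VGG stack) gives 224 / 2**5 = 7, hence the 7*7*512 features
# entering fc1. A minimal smoke test, assuming the first conv layer takes the
# single FashionMNIST channel and fc3 outputs 10 logits:
with torch.no_grad():
    out = net(torch.randn(1, 1, 224, 224).to(device))
print(out.shape)  # expected: torch.Size([1, 10])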
def downloader_construct_datasetsdict(datasets_list: list, grayscale=False) -> dict:
    """Takes in a list of dataset names and returns the loaded datasets used in the experiments."""
    print(f"INFO ------ List of datasets being loaded are {datasets_list}")
    datasets_dict = {}
    if "CIFAR10" in datasets_list:
        if not grayscale:
            cifar_train_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
            ])
            cifar_test_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
        else:
            cifar_train_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
                transforms.Grayscale(3),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
            ])
            cifar_test_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
                transforms.Grayscale(3),
            ])
        datasets_dict["CIFAR10_train"] = CIFAR10(
            root=r"./dataset/CHIFAR10/",
            train=True,
            download=True,
            transform=cifar_train_transform,
        )
        datasets_dict["CIFAR10_test"] = CIFAR10(
            root=r"./dataset/CHIFAR10/",
            train=False,
            download=True,
            transform=cifar_test_transform,
        )
        print("INFO ----- Dataset Loaded : CIFAR10")
        datasets_list.remove("CIFAR10")
    if "A_MNIST" in datasets_list:
        mnist_transforms = transforms.Compose([
            transforms.Pad(2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.1307], std=[0.3015]),
            transforms.Lambda(tmp_func),
            transforms.RandomCrop(32, 4),
        ])
        datasets_dict["A_MNIST_train"] = MNIST(
            root=r"./dataset/MNIST",
            train=True,
            download=True,
            transform=mnist_transforms,
        )
        datasets_dict["A_MNIST_test"] = MNIST(
            root=r"./dataset/MNIST",
            train=False,
            download=True,
            transform=mnist_transforms,
        )
        print("INFO ----- Dataset Loaded : MNIST")
        datasets_list.remove("A_MNIST")
    if "A_FashionMNIST" in datasets_list:
        fmnist_transforms = transforms.Compose([
            transforms.Pad(2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.2860], std=[0.3205]),
            transforms.Lambda(tmp_func),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
        ])
        datasets_dict["A_FashionMNIST_train"] = FashionMNIST(
            root="./dataset/FashionMNIST",
            train=True,
            download=True,
            transform=fmnist_transforms,
        )
        datasets_dict["A_FashionMNIST_test"] = FashionMNIST(
            root="./dataset/FashionMNIST",
            train=False,
            download=True,
            transform=fmnist_transforms,
        )
        print("INFO ----- Dataset Loaded : FashionMNIST")
        datasets_list.remove("A_FashionMNIST")
    if "A_SVHN" in datasets_list:
        SVHN_transforms = transforms.Compose(
            [transforms.ToTensor(), transforms.Resize(32)])
        datasets_dict["A_SVHN_train"] = SVHN(
            root=r"./dataset/SVHN",
            split="train",
            download=True,
            transform=SVHN_transforms,
        )
        # SVHN exposes `labels` instead of `targets`; alias it for a uniform API.
        datasets_dict["A_SVHN_train"].targets = datasets_dict["A_SVHN_train"].labels
        datasets_dict["A_SVHN_test"] = SVHN(
            root=r"./dataset/SVHN",
            split="test",
            download=True,
            transform=SVHN_transforms,
        )
        datasets_dict["A_SVHN_test"].targets = datasets_dict["A_SVHN_test"].labels
        print("INFO ----- Dataset Loaded : SVHN")
        datasets_list.remove("A_SVHN")
    if "CIFAR100" in datasets_list:
        datasets_dict["CIFAR100_train"] = CIFAR100(
            root=r"./dataset/CIFAR100",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                     std=[0.2009, 0.1984, 0.2023]),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
            ]),
        )
        datasets_dict["CIFAR100_test"] = CIFAR100(
            root=r"./dataset/CIFAR100",
            train=False,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5071, 0.4865, 0.4409],
                                     std=[0.2009, 0.1984, 0.2023]),
            ]),
        )
        print("INFO ----- Dataset Loaded : CIFAR100")
        datasets_list.remove("CIFAR100")
    if "A_CIFAR10_ood" in datasets_list:
        cifar_train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
        ])
        cifar_test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        datasets_dict["A_CIFAR10_ood_train"] = CIFAR10(
            root=r"./dataset/CHIFAR10/",
            train=True,
            download=True,
            transform=cifar_train_transform,
        )
        datasets_dict["A_CIFAR10_ood_test"] = CIFAR10(
            root=r"./dataset/CHIFAR10/",
            train=False,
            download=True,
            transform=cifar_test_transform,
        )
        print("INFO ----- Dataset Loaded : CIFAR10_ood")
        datasets_list.remove("A_CIFAR10_ood")
    if "A_CIFAR100_ood" in datasets_list:
        datasets_dict["A_CIFAR100_ood_train"] = CIFAR100(
            root=r"./dataset/CIFAR100",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
            ]),
        )
        datasets_dict["A_CIFAR100_ood_test"] = CIFAR100(
            root=r"./dataset/CIFAR100",
            train=False,
            download=True,
            transform=transforms.ToTensor(),
        )
        print("INFO ----- Dataset Loaded : CIFAR100_ood")
        datasets_list.remove("A_CIFAR100_ood")
    assert (
        len(datasets_list) == 0
    ), f"Not all datasets have been loaded, datasets left : {datasets_list}"
    return datasets_dict
# The original snippet began mid-expression; the Compose wrapper is the
# obvious continuation since `transform` is used immediately below.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5, ))])

if args.dataset == 'mnist':
    train_data = MNIST(root='data/mnist', train=True, download=True, transform=transform)
    test_data = MNIST(root='data/mnist', train=False, download=True, transform=transform)
if args.dataset == 'fashion':
    train_data = FashionMNIST(root='data/fashion', train=True, download=True, transform=transform)
    test_data = FashionMNIST(root='data/fashion', train=False, download=True, transform=transform)

# BiGAN params
z_dim = args.z_dim
hid_dim = args.hid_dim

# Train params
use_cuda = args.use_cuda and torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device('cpu')
n_epochs = args.n_epochs
batch_size = args.batch_size
if __name__ == '__main__':
    np.random.seed(1234)
    torch.manual_seed(1234)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    '''
    Load the data
    '''
    d_dir = 'gdrive/MyDrive/ProjectExperiment/Datasets/'
    # Without a transform the Dataset would yield PIL (Python Imaging Library)
    # images, so use transforms.ToTensor to convert them to tensors.
    fashion_mnist_train = FashionMNIST(d_dir, train=True, download=True,
                                       transform=transforms.ToTensor())
    fashion_mnist_test = FashionMNIST(d_dir, train=False, download=True,
                                      transform=transforms.ToTensor())

    # Create DataLoaders with a batch size of 128; the loaders build the mini-batches.
    batch_size = 128
    train_dataloader = DataLoader(fashion_mnist_train, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(fashion_mnist_test, batch_size=batch_size, shuffle=False)
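    # Optional reproducibility sketch (an addition, not in the original):
    # seeding torch globally does not pin the shuffle order if other RNG
    # consumers run first; an explicit generator makes it deterministic.
    g = torch.Generator().manual_seed(1234)
    train_dataloader = DataLoader(fashion_mnist_train, batch_size=batch_size,
                                  shuffle=True, generator=g)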
    return avg_loss.avg


if __name__ == "__main__":
    number_epochs = 100
    # Replace with torch.device("cuda:0") if you want to train on GPU.
    device = torch.device('cpu')
    model = MLP(10).to(device)
    trans_img = transforms.Compose([transforms.ToTensor()])
    dataset = FashionMNIST("./data/", train=True, transform=trans_img, download=True)
    trainloader = DataLoader(dataset, batch_size=1024, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    track_loss = []
    for i in tqdm(range(number_epochs)):
        loss = train_one_epoch(model, trainloader, optimizer, device)
        track_loss.append(loss)
    plt.figure()
    plt.plot(track_loss)
    plt.title("training-loss-MLP")
    plt.savefig("./img/training_mlp.jpg")
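# For context, a minimal sketch of what train_one_epoch above likely does.
# The body is an assumption; only the call signature and the `avg_loss.avg`
# return value are visible in this file.
import torch.nn.functional as F  # assumption: may not be imported above


def _train_one_epoch_sketch(model, trainloader, optimizer, device):
    model.train()
    total, n = 0.0, 0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(images), labels)
        loss.backward()
        optimizer.step()
        total += loss.item() * images.size(0)
        n += images.size(0)
    return total / n  # running average, analogous to avg_loss.avg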
def get_data_manager(
    indistribution=["Cifar10"],
    ood=["MNIST", "Fashion_MNIST"],
):
    """get_data_manager [Creates a data_manager instance with the In-/Out-of-Distribution data]

    [List-based processing of datasets. Images are resized/cropped to 32x32.]

    Args:
        indistribution (list, optional): [description]. Defaults to ["Cifar10"].
        ood (list, optional): [description]. Defaults to ["MNIST", "Fashion_MNIST", "SVHN"].

    Returns:
        [data_manager]: [Experiment data_manager for logging and the active learning cycle]
    """
    # TODO ADD Target transform?
    # Placeholder first rows; they are deleted again before constructing the
    # Data_manager below.
    base_data = np.empty(shape=(1, 3, 32, 32))
    base_data_test = np.empty(shape=(1, 3, 32, 32))
    base_labels = np.empty(shape=(1,))
    base_labels_test = np.empty(shape=(1,))
    OOD_data = np.empty(shape=(1, 3, 32, 32))
    OOD_labels = np.empty(shape=(1,))
    resize = transforms.Resize(32)
    random_crop = transforms.RandomCrop(32)
    standard_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            resize,
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    if debug:
        tracemalloc.start()
        snapshot = tracemalloc.take_snapshot()
        display_top(snapshot)
    for dataset in indistribution:
        if dataset == "Cifar10":
            CIFAR10_train = CIFAR10(
                root=r"/dataset/CHIFAR10/",
                train=True,
                download=True,
                transform=transforms.ToTensor(),
            )
            CIFAR10_test = CIFAR10(
                root=r"/dataset/CHIFAR10/",
                train=False,
                download=True,
                transform=transforms.ToTensor(),
            )
            # CIFAR10_train_data = CIFAR10_train.data.permute(0, 3, 1, 2)
            # CIFAR10_test_data = CIFAR10_test.data.permute(0, 3, 1, 2)
            CIFAR10_train_data = np.array([i.numpy() for i, _ in CIFAR10_train])
            CIFAR10_test_data = np.array([i.numpy() for i, _ in CIFAR10_test])
            CIFAR10_train_labels = np.array(CIFAR10_train.targets)
            CIFAR10_test_labels = np.array(CIFAR10_test.targets)
            base_data = np.concatenate(
                [base_data.copy(), CIFAR10_train_data.copy()], axis=0,
            )
            base_data_test = np.concatenate(
                [base_data_test.copy(), CIFAR10_test_data.copy()]
            )
            base_labels = np.concatenate(
                [base_labels.copy(), CIFAR10_train_labels.copy()], axis=0,
            )
            base_labels_test = np.concatenate(
                [base_labels_test.copy(), CIFAR10_test_labels.copy()]
            )
            del (
                CIFAR10_train_data,
                CIFAR10_test_data,
                CIFAR10_train_labels,
                CIFAR10_test_labels,
                CIFAR10_train,
                CIFAR10_test,
            )
            gc.collect()
        elif dataset == "MNIST":
            MNIST_train = MNIST(
                root=r"/dataset/MNIST",
                train=True,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )
            MNIST_test = MNIST(
                root=r"/dataset/MNIST",
                train=False,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.ToTensor(),
                        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                    ]
                ),
            )
            MNIST_train_data = np.array([i.numpy() for i, _ in MNIST_train])
            MNIST_test_data = np.array([i.numpy() for i, _ in MNIST_test])
            # Offset the labels when several in-distribution datasets are
            # combined so class ids do not collide. (The original tested
            # `len(dataset) > 1`, which is always true for a dataset-name
            # string; `len(indistribution) > 1` is the intended condition.
            # `.numpy()` matches the Fashion_MNIST branch; the raw tensors
            # have no `.copy()` method.)
            if len(indistribution) > 1:
                MNIST_train_labels = MNIST_train.targets.numpy() + np.max(base_labels)
                MNIST_test_labels = MNIST_test.targets.numpy() + np.max(base_labels)
            else:
                MNIST_train_labels = MNIST_train.targets.numpy()
                MNIST_test_labels = MNIST_test.targets.numpy()
            base_data = np.concatenate([base_data.copy(), MNIST_train_data.copy()])
            # The original concatenated MNIST_test_labels here, which would
            # mix label vectors into the test images; MNIST_test_data is what
            # the surrounding code expects.
            base_data_test = np.concatenate(
                [base_data_test.copy(), MNIST_test_data.copy()]
            )
            base_labels = np.concatenate(
                [base_labels.copy(), MNIST_train_labels.copy()]
            )
            base_labels_test = np.concatenate(
                [base_labels_test.copy(), MNIST_test_labels.copy()]
            )
            del (
                MNIST_train,
                MNIST_test,
                MNIST_train_data,
                MNIST_test_data,
                MNIST_train_labels,
                MNIST_test_labels,
            )
            gc.collect()
        elif dataset == "Fashion_MNIST":
            Fashion_MNIST_train = FashionMNIST(
                root="/dataset/FashionMNIST",
                train=True,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.ToTensor(),
                        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
                    ]
                ),
            )
            Fashion_MNIST_test = FashionMNIST(
                root="/dataset/FashionMNIST",
                train=False,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )
            Fashion_MNIST_train_data = np.array(
                [i.numpy() for i, _ in Fashion_MNIST_train]
            )
            Fashion_MNIST_test_data = np.array(
                [i.numpy() for i, _ in Fashion_MNIST_test]
            )
            if len(indistribution) > 1:
                Fashion_MNIST_train_labels = (
                    Fashion_MNIST_train.targets.numpy() + np.max(base_labels)
                )
                Fashion_MNIST_test_labels = (
                    Fashion_MNIST_test.targets.numpy() + np.max(base_labels)
                )
            else:
                Fashion_MNIST_train_labels = Fashion_MNIST_train.targets.numpy()
                Fashion_MNIST_test_labels = Fashion_MNIST_test.targets.numpy()
            base_data = np.concatenate(
                [base_data.copy(), Fashion_MNIST_train_data.copy()]
            )
            base_data_test = np.concatenate(
                [base_data_test.copy(), Fashion_MNIST_test_data.copy()]
            )
            base_labels = np.concatenate(
                [base_labels.copy(), Fashion_MNIST_train_labels.copy()]
            )
            base_labels_test = np.concatenate(
                [base_labels_test.copy(), Fashion_MNIST_test_labels.copy()]
            )
            del (
                Fashion_MNIST_train,
                Fashion_MNIST_test,
                Fashion_MNIST_train_data,
                Fashion_MNIST_test_data,
                Fashion_MNIST_train_labels,
                Fashion_MNIST_test_labels,
            )
            gc.collect()
    for ood_dataset in ood:
        if ood_dataset == "MNIST":
            MNIST_train = MNIST(
                root=r"/dataset/MNIST",
                train=True,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )
            MNIST_test = MNIST(
                root=r"/dataset/MNIST",
                train=False,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )
            MNIST_train_data = np.array([i.numpy() for i, _ in MNIST_train])
            MNIST_test_data = np.array([i.numpy() for i, _ in MNIST_test])
            MNIST_train_labels = MNIST_train.targets.numpy()
            MNIST_test_labels = MNIST_test.targets.numpy()
            OOD_data = np.concatenate(
                [OOD_data.copy(), MNIST_train_data.copy(), MNIST_test_data.copy()],
                axis=0,
            )
            OOD_labels = np.concatenate(
                [OOD_labels.copy(), MNIST_train_labels.copy(), MNIST_test_labels.copy()]
            )
            del (
                MNIST_train,
                MNIST_test,
                MNIST_train_data,
                MNIST_test_data,
                MNIST_train_labels,
                MNIST_test_labels,
            )
            gc.collect()
        elif ood_dataset == "Fashion_MNIST":
            Fashion_MNIST_train = FashionMNIST(
                root="/dataset/FashionMNIST",
                train=True,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )
            Fashion_MNIST_test = FashionMNIST(
                root="/dataset/FashionMNIST",
                train=False,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.Pad(2),
                        transforms.Grayscale(3),
                        transforms.ToTensor(),
                    ]
                ),
            )
            Fashion_MNIST_train_data = np.array(
                [i.numpy() for i, _ in Fashion_MNIST_train]
            )
            Fashion_MNIST_test_data = np.array(
                [i.numpy() for i, _ in Fashion_MNIST_test]
            )
            Fashion_MNIST_train_labels = Fashion_MNIST_train.targets.numpy()
            Fashion_MNIST_test_labels = Fashion_MNIST_test.targets.numpy()
            OOD_data = np.concatenate(
                [
                    OOD_data.copy(),
                    Fashion_MNIST_train_data.copy(),
                    Fashion_MNIST_test_data.copy(),
                ],
                axis=0,
            )
            OOD_labels = np.concatenate(
                [
                    OOD_labels.copy(),
                    Fashion_MNIST_train_labels.copy(),
                    Fashion_MNIST_test_labels.copy(),
                ],
            )
            del (
                Fashion_MNIST_train,
                Fashion_MNIST_test,
                Fashion_MNIST_train_data,
                Fashion_MNIST_test_data,
                Fashion_MNIST_train_labels,
                Fashion_MNIST_test_labels,
            )
            gc.collect()
        elif ood_dataset == "SVHN":
            SVHN_train = SVHN(
                root=r"/dataset/SVHN",
                split="train",
                download=True,
                transform=standard_transform,
            )
            SVHN_test = SVHN(
                root=r"/dataset/SVHN",
                split="test",
                download=True,
                transform=standard_transform,
            )
            SVHN_train_data = SVHN_train.data
            SVHN_test_data = SVHN_test.data
            SVHN_train_labels = SVHN_train.labels
            SVHN_test_labels = SVHN_test.labels
            OOD_data = np.concatenate(
                [OOD_data.copy(), SVHN_train_data.copy(), SVHN_test_data.copy()],
                axis=0,
            )
            OOD_labels = np.concatenate(
                [OOD_labels.copy(), SVHN_train_labels.copy(), SVHN_test_labels.copy()]
            )
            del (
                SVHN_train,
                SVHN_test,
                SVHN_train_data,
                SVHN_test_data,
                SVHN_train_labels,
                SVHN_test_labels,
            )
            gc.collect()
        # elif ood_dataset == "TinyImageNet":
        #     if not os.listdir(os.path.join(r"./dataset/tiny-imagenet-200")):
        #         download_and_unzip()
        #     id_dict = {}
        #     for i, line in enumerate(
        #         open(
        #             os.path.join(
        #                 r"\dataset\tiny-imagenet-200\tiny-imagenet-200\wnids.txt"
        #             ),
        #             "r",
        #         )
        #     ):
        #         id_dict[line.replace("\n", "")] = i
        #     normalize_imagenet = transforms.Normalize(
        #         (122.4786, 114.2755, 101.3963), (70.4924, 68.5679, 71.8127)
        #     )
        #     train_t_imagenet = TrainTinyImageNetDataset(
        #         id=id_dict, transform=transforms.Compose([normalize_imagenet, resize])
        #     )
        #     test_t_imagenet = TestTinyImageNetDataset(
        #         id=id_dict, transform=transforms.Compose([normalize_imagenet, resize])
        #     )
    if debug:
        snapshot = tracemalloc.take_snapshot()
        display_top(snapshot)
    # Drop the placeholder first rows created above.
    base_data = np.delete(base_data, 0, axis=0)
    base_data_test = np.delete(base_data_test, 0, axis=0)
    base_labels = np.delete(base_labels, 0)
    base_labels_test = np.delete(base_labels_test, 0)
    OOD_data = np.delete(OOD_data, 0, axis=0)
    OOD_labels = np.delete(OOD_labels, 0)
    print(base_data.shape, base_data_test.shape, OOD_data.shape, OOD_labels.shape)
    data_manager = Data_manager(
        base_data=base_data,
        base_labels=base_labels,
        base_data_test=base_data_test,
        base_labels_test=base_labels_test,
        OOD_data=OOD_data,
        OOD_labels=OOD_labels,
    )
    # del (base_data, base_labels, OOD_data, OOD_labels)
    gc.collect()
    if debug:
        snapshot = tracemalloc.take_snapshot()
        display_top(snapshot)
    return data_manager