예제 #1
0
 def _init_data(self):
     """Build the train/val data loaders from a split of the training set.

     Calls ``get_split_dataset`` and stores its first two results on
     ``self.train_loader`` and ``self.val_loader``. The third result
     (the class count) is not used by this method and is discarded.

     NOTE(review): ``data_type``, ``batch_size``, ``n_data_worker`` and
     ``data_root`` are read as free names rather than ``self`` attributes,
     so they must be resolvable at module scope when this runs -- confirm
     that is intended.
     """
     # Hold out part of the training set for validation:
     # 5k samples when the dataset name contains 'cifar', otherwise 3k
     # (the ImageNet case).
     val_size = 5000 if 'cifar' in data_type else 3000
     self.train_loader, self.val_loader, _ = get_split_dataset(
         data_type, batch_size, n_data_worker, val_size, data_root,
         shuffle=False)  # same sampling (shuffling disabled)
예제 #2
0
 def _init_data(self):
     """Build the train/val data loaders, normally from a split of the train set.

     Reads ``self.data_type``, ``self.batch_size``, ``self.n_data_worker``,
     ``self.data_root`` and ``self.use_real_val``, passes them to
     ``get_split_dataset``, and stores the first two results on
     ``self.train_loader`` and ``self.val_loader``. The third result
     (the class count) is not used by this method and is discarded.

     When ``self.use_real_val`` is truthy, a warning is printed because
     evaluating on the real validation set is considered incorrect here.
     """
     # Hold out part of the training set for validation:
     # 5k samples when the dataset name contains 'cifar', otherwise 3k
     # (the ImageNet case).
     val_size = 5000 if 'cifar' in self.data_type else 3000
     self.train_loader, self.val_loader, _ = get_split_dataset(
         self.data_type, self.batch_size, self.n_data_worker, val_size,
         data_root=self.data_root,
         use_real_val=self.use_real_val,
         shuffle=False)  # same sampling (shuffling disabled)
     if self.use_real_val:  # use the real val set for eval, which is actually wrong
         print('*** USE REAL VALIDATION SET!')