Esempio n. 1
0
    def __init__(self,
                 train_percentage,
                 dataset_type='train',
                 random_state=42,
                 noise=0.0):
        """Initialise the MNIST base dataset, then the NeatTestingDataset base.

        Args:
            train_percentage: forwarded unchanged to NeatTestingDataset.
            dataset_type: forwarded to NeatTestingDataset. It would normally
                select the MNIST training split when equal to 'train', but see
                the forced override below.
            random_state: forwarded unchanged to NeatTestingDataset.
            noise: forwarded unchanged to NeatTestingDataset.
        """
        self.x = None
        self.y = None

        # Intended mapping: only dataset_type == 'train' selects the MNIST
        # training split ...
        self.train = dataset_type == 'train'
        # ... but it is currently forced off, so MNIST.__init__ below always
        # receives train=False regardless of dataset_type.
        # TODO: REMOVE THIS SET
        self.train = False

        # Data lives next to this source file, under data/mnist
        # (the path is built with a literal '/' separator).
        module_dir = os.path.dirname(os.path.realpath(__file__))
        mnist_root = module_dir + '/data/mnist'

        # Presumably the usual MNIST pixel mean/std; applied after ToTensor.
        normalize = transforms.Normalize((0.1307, ), (0.3081, ))
        self.transform = transforms.Compose([transforms.ToTensor(), normalize])

        # The MNIST base must be initialised before NeatTestingDataset,
        # mirroring the original call order.
        MNIST.__init__(self,
                       root=mnist_root,
                       train=self.train,
                       download=True,
                       transform=self.transform)
        NeatTestingDataset.__init__(self,
                                    train_percentage=train_percentage,
                                    dataset_type=dataset_type,
                                    random_state=random_state,
                                    noise=noise)
Esempio n. 2
0
 def __init__(self, **kwargs):
     """Construct the Mnist class.

     Runs the generic ``Dataset`` initialiser first (which is expected to
     populate ``self.args``, ``self.train`` and ``self.transforms`` -- TODO
     confirm), then initialises the ``MNIST`` base from those settings.
     """
     Dataset.__init__(self, **kwargs)
     config = self.args
     data_root = config.data_path
     use_train_split = self.train
     # Build the composed transform from the wrapped transform list.
     composed = Compose(self.transforms.__transform__)
     should_download = config.download
     MNIST.__init__(self,
                    root=data_root,
                    train=use_train_split,
                    transform=composed,
                    download=should_download)
Esempio n. 3
0
 def __init__(self, **kwargs):
     """Construct the Mnist class.

     Initialises the generic ``Dataset`` base, lets ``FileOps`` resolve the
     data location, and stores the returned path back into
     ``self.args.data_path``. The ``MNIST`` base is then initialised with
     that path.
     """
     Dataset.__init__(self, **kwargs)
     resolved_path = FileOps.download_dataset(self.args.data_path)
     # Persist the resolved location so later code sees the same path.
     self.args.data_path = resolved_path
     MNIST.__init__(
         self,
         root=self.args.data_path,
         train=self.train,
         transform=self.transforms,
         download=self.args.download,
     )
Esempio n. 4
0
  @staticmethod
  def _first_n_per_class(X, y, n):
    """Return the first ``n`` samples of each class in ``y`` and the matching rows of ``X``.

    Classes are visited in ``np.unique`` (sorted) order, and within a class
    samples keep their original order, so the selection is deterministic.
    """
    idx = []
    for label in np.unique(y):
      idx.extend(np.where(y == label)[0][:n].tolist())
    return X[idx], y[idx]

  def _load_dataset(self, samples_per_class=700):
    """Load class-balanced MNIST and USPS subsets and assign them to source/target.

    Reads ``self.source`` / ``self.target`` (assumed to be set before this is
    called) and fills ``self.XS``, ``self.yS`` (source domain) and
    ``self.XT``, ``self.yT`` (target domain).

    Args:
      samples_per_class: maximum number of samples kept per class from each
        domain. Defaults to 700, the previously hard-coded value.

    Raises:
      ValueError: if (source, target) is neither ('mnist', 'usps') nor
        ('usps', 'mnist'). This is checked before any download or file I/O.
      RuntimeError: if the ``DATA_DIR`` environment variable is not set.
    """
    # Validate the configuration first so a bad combination does not trigger
    # a download or file load that is then thrown away.
    mnist_to_usps = self.source == 'mnist' and self.target == 'usps'
    usps_to_mnist = self.source == 'usps' and self.target == 'mnist'
    if not (mnist_to_usps or usps_to_mnist):
      raise ValueError('Unsupported source and target domains.')

    # os.path.join(None, ...) would otherwise fail with an opaque TypeError.
    data_dir = os.getenv('DATA_DIR')
    if data_dir is None:
      raise RuntimeError('The DATA_DIR environment variable must point to the '
                         'directory containing usps/usps.dat.')

    self.logger.info('Loading the MNIST dataset from root directory %s',
                     self.data_directory)
    MNISTBase.__init__(self, root=self.data_directory, download=True)
    # NOTE(review): train_data/train_labels are deprecated aliases in newer
    # torchvision releases (data/targets); kept for compatibility.
    mnist_X = self.train_data.numpy()
    mnist_y = self.train_labels.numpy()

    # joblib.load unpickles the file. Only point DATA_DIR at trusted data.
    usps_data_path = os.path.join(data_dir, 'usps', 'usps.dat')
    usps_X, usps_y = joblib.load(usps_data_path)

    mnist_X, mnist_y = self._first_n_per_class(mnist_X, mnist_y, samples_per_class)
    usps_X, usps_y = self._first_n_per_class(usps_X, usps_y, samples_per_class)

    if mnist_to_usps:
      self.XS, self.yS = mnist_X, mnist_y
      self.XT, self.yT = usps_X, usps_y
    else:  # usps_to_mnist, guaranteed by the check above
      self.XS, self.yS = usps_X, usps_y
      self.XT, self.yT = mnist_X, mnist_y