Example #1
def cross_session(data, label, session_id, subject_id, category_number, batch_size, iteration, lr, momentum, log_interval):
    # leave-one-session-out (LOSO): hold out one session as the target
    train_idxs = list(range(3))
    del train_idxs[session_id]
    test_idx = session_id

    target_data, target_label = copy.deepcopy(data[test_idx][subject_id]), copy.deepcopy(label[test_idx][subject_id])
    source_data, source_label = copy.deepcopy(data[train_idxs][:, subject_id]), copy.deepcopy(label[train_idxs][:, subject_id])

    # stack the remaining source sessions onto the first
    source_data_comb = source_data[0]
    source_label_comb = source_label[0]
    for j in range(1, len(source_data)):
        source_data_comb = np.vstack((source_data_comb, source_data[j]))
        source_label_comb = np.vstack((source_label_comb, source_label[j]))
    source_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(source_data_comb, source_label_comb),
                                                            batch_size=batch_size,
                                                            shuffle=True,
                                                            drop_last=True)
    target_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(target_data, target_label),
                                                            batch_size=batch_size, 
                                                            shuffle=True, 
                                                            drop_last=True)
    model = DCORAL(model=models.DeepCoral(pretrained=False, number_of_category=category_number),
                source_loader=source_loader,
                target_loader=target_loader,
                batch_size=batch_size,
                iteration=iteration,
                lr=lr,
                momentum=momentum,
                log_interval=log_interval)
    # print(model.__getModel__())
    acc = model.train()
    print('Target_session_id: {}, current_subject_id: {}, acc: {}'.format(test_idx, subject_id, acc))
    return acc
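
Every example in this listing wraps paired (data, label) arrays in utils.CustomDataset before handing them to a DataLoader. The project's own implementation is not shown here; the following is a minimal sketch of what such a wrapper could look like (the tensor dtypes and the squeeze on the labels are assumptions):

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Hypothetical stand-in for utils.CustomDataset: pairs samples with labels."""

    def __init__(self, data, label):
        # assumed dtypes: float features, integer class labels
        self.data = torch.as_tensor(data, dtype=torch.float32)
        self.label = torch.as_tensor(label, dtype=torch.long).squeeze()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]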
Example #2
def cross_subject(data, label, session_id, category_number, batch_size,
                  iteration, lr, momentum, log_interval):
    one_session_data, one_session_label = copy.deepcopy(
        data[session_id]), copy.deepcopy(label[session_id])
    target_data, target_label = one_session_data.pop(), one_session_label.pop()
    source_data, source_label = copy.deepcopy(one_session_data), copy.deepcopy(
        one_session_label)
    # print(len(source_data))
    source_loaders = []
    for j in range(len(source_data)):
        source_loaders.append(torch.utils.data.DataLoader(dataset=utils.CustomDataset(source_data[j], source_label[j]),
                                                          batch_size=batch_size,
                                                          shuffle=True,
                                                          drop_last=True))
    target_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(target_data, target_label),
                                                batch_size=batch_size,
                                                shuffle=True,
                                                drop_last=True)
    model = MSMDAER(model=models.MSMDAERNet(pretrained=False,
                                            number_of_source=len(source_loaders),
                                            number_of_category=category_number),
                    source_loaders=source_loaders,
                    target_loader=target_loader,
                    batch_size=batch_size,
                    iteration=iteration,
                    lr=lr,
                    momentum=momentum,
                    log_interval=log_interval)
    # print(model.__getModel__())
    acc = model.train()
    return acc
Example #3
def cross_session(data, label, subject_id, category_number, batch_size, iteration, lr, momentum, log_interval):
    target_data, target_label = copy.deepcopy(data[2][subject_id]), copy.deepcopy(label[2][subject_id])
    source_data, source_label = [copy.deepcopy(data[0][subject_id]), copy.deepcopy(data[1][subject_id])], [copy.deepcopy(label[0][subject_id]), copy.deepcopy(label[1][subject_id])]
    # one_sub_data, one_sub_label = data[i], label[i]
    # target_data, target_label = one_session_data.pop(), one_session_label.pop()
    # source_data, source_label = one_session_data.copy(), one_session_label.copy()
    # print(len(source_data))
    # exactly two source sessions here, so one vstack combines them
    source_data_comb = np.vstack((source_data[0], source_data[1]))
    source_label_comb = np.vstack((source_label[0], source_label[1]))
    source_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(source_data_comb, source_label_comb),
                                                            batch_size=batch_size,
                                                            shuffle=True,
                                                            drop_last=True)
    target_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(target_data, target_label),
                                                            batch_size=batch_size, 
                                                            shuffle=True, 
                                                            drop_last=True)
    model = DANNet(model=models.DAN(pretrained=False, number_of_category=category_number),
                source_loader=source_loader,
                target_loader=target_loader,
                batch_size=batch_size,
                iteration=iteration,
                lr=lr,
                momentum=momentum,
                log_interval=log_interval)
    # print(model.__getModel__())
    acc = model.train()
    return acc
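
A driver for the snippet above is not included in the listing; the sketch below shows how cross_session might be called for all 15 subjects. The hyperparameter values are placeholders, not the original settings:

subject_accs = []
for subject_id in range(15):
    acc = cross_session(data, label, subject_id, category_number=3,
                        batch_size=16, iteration=15000, lr=0.01,
                        momentum=0.9, log_interval=10)
    subject_accs.append(acc)
print('mean cross-session accuracy: {:.4f}'.format(
    sum(subject_accs) / len(subject_accs)))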
Example #4
    def apply_opt(self):
        # dataset
        if self._opt.dataset == "MNIST":
            train_data, test_data = utils.get_mnist()
            self._train_set = torch.utils.data.DataLoader(
                train_data,
                batch_size=self._opt.batch_size,
                shuffle=True,
                num_workers=self._opt.num_workers)
            self._test_set = torch.utils.data.DataLoader(
                test_data,
                batch_size=self._opt.batch_size,
                shuffle=True,
                num_workers=self._opt.num_workers)
            self._initialize_model(dims=self._opt.layer_dims)
            print("MNIST experiment")

        elif self._opt.dataset == "IBNet":
            train_data = utils.CustomDataset('2017_12_21_16_51_3_275766',
                                             train=True)
            test_data = utils.CustomDataset('2017_12_21_16_51_3_275766',
                                            train=False)
            self._train_set = torch.utils.data.DataLoader(
                train_data,
                batch_size=self._opt.batch_size,
                shuffle=True,
                num_workers=self._opt.num_workers)
            self._test_set = torch.utils.data.DataLoader(
                test_data,
                batch_size=self._opt.batch_size,
                shuffle=True,
                num_workers=self._opt.num_workers)
            self._initialize_model(dims=self._opt.layer_dims)
            print("IBnet experiment")
        else:
            raise RuntimeError(
                'Unknown dataset {name}; please use one of the supported datasets'
                .format(name=self._opt.dataset))

        # construct saving directory
        save_root_dir = self._opt.save_root_dir
        dataset = self._opt.dataset
        time = datetime.datetime.today().strftime('%m_%d_%H_%M')
        model = '_'.join(map(str, self._model.layer_dims)) + '_'
        folder_name = dataset + '_' + self._opt.experiment_name + '_Time_' + time + '_Model_' + model
        self._path_to_dir = save_root_dir + '/' + folder_name + '/'
        print(self._path_to_dir)
        if not os.path.exists(self._path_to_dir):
            os.makedirs(self._path_to_dir)

        self._logger = Logger(opt=self._opt, plot_name=folder_name)
        self._json = JsonParser()
Example #5
def cross_subject(data, label, session_id, category_number, batch_size, iteration, lr, momentum, log_interval, bn='none'):
    # cross-subject: for each of the 3 sessions, subjects 1-14 are sources, subject 15 is the target
    # bn selects the normalization mode: 'ele' | 'sample' | 'global' | 'none' (default assumed)
    one_session_data, one_session_label = copy.deepcopy(data[session_id]), copy.deepcopy(label[session_id])
    target_data, target_label = one_session_data.pop(), one_session_label.pop()
    source_data, source_label = copy.deepcopy(one_session_data), copy.deepcopy(one_session_label)
    # print(len(source_data))
    source_data_comb = source_data[0]
    source_label_comb = source_label[0]
    for j in range(1, len(source_data)):
        source_data_comb = np.vstack((source_data_comb, source_data[j]))
        source_label_comb = np.vstack((source_label_comb, source_label[j]))
    if bn == 'ele':
        # element-wise (per-feature) normalization
        source_data_comb = utils.norminy(source_data_comb)
        target_data = utils.norminy(target_data)
    elif bn == 'sample':
        # per-sample normalization
        source_data_comb = utils.norminx(source_data_comb)
        target_data = utils.norminx(target_data)
    elif bn == 'global':
        # global normalization
        source_data_comb = utils.normalization(source_data_comb)
        target_data = utils.normalization(target_data)
    else:
        # bn == 'none' or unrecognized: leave the data unnormalized
        pass
    source_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(source_data_comb, source_label_comb),
                                                            batch_size=batch_size,
                                                            shuffle=True,
                                                            drop_last=True)
    # source_loaders = []
    # for j in range(len(source_data)):
    #     source_loaders.append(torch.utils.data.DataLoader(dataset=utils.CustomDataset(source_data[j], source_label[j]),
    #                                                         batch_size=batch_size,
    #                                                         shuffle=True,
    #                                                         drop_last=True))
    target_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(target_data, target_label),
                                                            batch_size=batch_size, 
                                                            shuffle=True, 
                                                            drop_last=True)
    model = DANNet(model=models.DAN(pretrained=False, number_of_category=category_number),
                source_loader=source_loader,
                target_loader=target_loader,
                batch_size=batch_size,
                iteration=iteration,
                lr=lr,
                momentum=momentum,
                log_interval=log_interval)
    # print(model.__getModel__())
    acc = model.train()
    return acc
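
The utils normalization helpers selected by bn are not part of this listing. Judging only from the branch keys ('ele', 'sample', 'global'), they could look like the hypothetical stand-ins below; the real norminy, norminx, and normalization may differ:

import numpy as np

def norminy(x):
    # 'ele': hypothetical per-feature z-score across samples
    return (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-8)

def norminx(x):
    # 'sample': hypothetical per-sample z-score across features
    return (x - x.mean(axis=1, keepdims=True)) / (x.std(axis=1, keepdims=True) + 1e-8)

def normalization(x):
    # 'global': hypothetical min-max scaling of the whole array to [0, 1]
    return (x - x.min()) / (x.max() - x.min() + 1e-8)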
Example #6
    def __init__(self):
        self._name = 'train_testIBnet'
        self._fdir = "./results/testIBnet"
        self._save_step = 100

        if not os.path.exists(self._fdir):
            os.makedirs(self._fdir)  # makedirs so a missing ./results parent is also created

        self._layer_dims = [12, 12, 10, 7, 5, 4, 3, 2, 2]
        self._acttype = 'tanh'
        self._is_train = True
        self._num_epoch = 1000
        self._batch_size = 256
        self._lr = 0.0004

        self._build_model()
        self._initialize_model()

        # set training dataset
        train_data = utils.CustomDataset('2017_12_21_16_51_3_275766', train=True)
        self._train_set = torch.utils.data.DataLoader(train_data, 
                                                      batch_size=self._batch_size, 
                                                      shuffle=True, 
                                                      num_workers=1)
        print("\nIBnet experiment:\n")
Example #7
    def __init__(self):
        self.progress_bar = 0
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # device setup
        load_config = JsonParser()  # training args
        self.model_name = 'IBNet_test_save_Time_05_27_20_09_Model_12_12_10_7_5_4_3_2_2_'
        self.path = os.path.join('./results', self.model_name)  # info-plane directory
        self._opt = load_config.read_json_as_argparse(self.path)  # load training args

        # force the batch size to 1 for calculation convenience
        self._opt.batch_size = 1
        # dataset
        if self._opt.dataset == "MNIST":
            train_data, test_data = utils.get_mnist()

            if not self._opt.full_mi:
                # self._train_set = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0)
                self._test_set = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, num_workers=0)
            else:
                dataset = torch.utils.data.ConcatDataset([train_data, test_data])
                self._test_set = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
            print("MNIST experiment")

        elif self._opt.dataset == "IBNet":
            train_data = utils.CustomDataset('2017_12_21_16_51_3_275766', train=True)
            test_data = utils.CustomDataset('2017_12_21_16_51_3_275766', train=False)
            # self._train_set = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0)
            if not self._opt.full_mi:
                self._test_set = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, num_workers=0)
            else:
                dataset = torch.utils.data.ConcatDataset([train_data, test_data])
                self._test_set = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
            print("IBnet experiment")
        else:
            raise RuntimeError('Unknown dataset {name}; please use one of the supported datasets'.format(name=self._opt.dataset))

        # get model
        self._model = Model(activation=self._opt.activation, dims=self._opt.layer_dims, train=False)
        
        # get measure
        # self.measure = measure.kde()
        self.measure = measure.EVKL() # our new measure
Example #8
    def __init__(self):
        self._name = 'testDebug_IBnet'
        self._fdir = "./results/testIBnet"

        self._layer_dims = [12, 12, 10, 7, 5, 4, 3, 2, 2]
        self._acttype = 'tanh'
        self._is_train = False
        self._batch_size = 256

        self._model = IBnetModel(dims=self._layer_dims,
                                 acttype=self._acttype,
                                 is_train=self._is_train)

        test_data = utils.CustomDataset('2017_12_21_16_51_3_275766', train=False)
        self._test_set = torch.utils.data.DataLoader(test_data,
                                                     batch_size=self._batch_size,
                                                     shuffle=True,
                                                     num_workers=1)
        print("\nIBnet debug:\n")

        self.measure = measure.kde()
Example #9
    #     source_data_comb = utils.norminy(source_data_comb)
    #     target_data = utils.norminy(target_data)
    # elif bn == 'sample':
    #     source_data_comb = utils.norminx(source_data_comb)
    #     target_data = utils.norminx(target_data)
    # elif bn == 'global':
    #     source_data_comb = utils.normalization(source_data_comb)
    #     target_data = utils.normalization(target_data)
    # elif bn == 'none':
    #     pass
    # else:
    #     pass
    # source_data_comb = utils.norminy(source_data_comb)
    # target_data = utils.norminy(target_data)
    source_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(source_data_comb, source_label_comb),
                                                batch_size=batch_size,
                                                shuffle=True,
                                                drop_last=True)
    # source_loaders = []
    # for j in range(len(source_data)):
    #     source_loaders.append(torch.utils.data.DataLoader(dataset=utils.CustomDataset(source_data[j], source_label[j]),
    #                                                       batch_size=batch_size,
    #                                                       shuffle=True,
    #                                                       drop_last=True))
    target_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(target_data, target_label),
                                                batch_size=batch_size,
                                                shuffle=True,
                                                drop_last=True)
    # the snippet cuts off mid-statement here; the tail below follows the
    # identical DANNet construction in Example #5
    model = DANNet(model=models.DAN(pretrained=False, number_of_category=category_number),
                   source_loader=source_loader,
                   target_loader=target_loader,
                   batch_size=batch_size,
                   iteration=iteration,
                   lr=lr,
                   momentum=momentum,
                   log_interval=log_interval)
    acc = model.train()
    return acc
Example #10
def cross_subject(data, label, session_id, subject_id, category_number, batch_size, iteration, lr, momentum, log_interval):
    # leave-one-subject-out: subject_id is held out of the chosen session as the target
    # (signature inferred from the variables used below; the snippet begins mid-function)
    one_session_data, one_session_label = copy.deepcopy(data[session_id]), copy.deepcopy(label[session_id])
    train_idxs = list(range(15))
    del train_idxs[subject_id]
    test_idx = subject_id
    target_data, target_label = copy.deepcopy(one_session_data[test_idx]), copy.deepcopy(one_session_label[test_idx])
    source_data, source_label = copy.deepcopy(one_session_data[train_idxs]), copy.deepcopy(one_session_label[train_idxs])
    # print('Target_subject_id: ', test_idx)
    # print('Source_subject_id: ', train_idxs)

    del one_session_label
    del one_session_data

    source_loaders = []
    for j in range(len(source_data)):
        source_loaders.append(torch.utils.data.DataLoader(dataset=utils.CustomDataset(source_data[j], source_label[j]),
                                                          batch_size=batch_size,
                                                          shuffle=True,
                                                          drop_last=True))
    target_loader = torch.utils.data.DataLoader(dataset=utils.CustomDataset(target_data, target_label),
                                                batch_size=batch_size,
                                                shuffle=True,
                                                drop_last=True)
    model = MSMDAER(model=models.MSMDAERNet(pretrained=False, number_of_source=len(source_loaders), number_of_category=category_number),
                    source_loaders=source_loaders,
                    target_loader=target_loader,
                    batch_size=batch_size,
                    iteration=iteration,
                    lr=lr,
                    momentum=momentum,
                    log_interval=log_interval)
    acc = model.train()
    return acc