def train_test(self, x_train, y_train, x_test=None, cuisines=None, k=15):
    """Label a test set with k-NN, or score k-NN on a held-out validation split.

    Test mode (``x_test`` is given): ``self.get_dist`` is applied to the
    training and test features, ``self.predict`` turns the result into labels
    using ``k`` neighbours, and the labels are mapped through
    ``self.dataset.id2cuisine`` and handed to ``self._write2csv`` together
    with the ``id`` of every entry in ``cuisines``.

    Validation mode (``x_test`` is ``None``): the training data is shuffled,
    the first 35000 rows are used for fitting and the rest for validation.
    Features are reduced with ``self.PCA`` and the accuracy is printed for a
    fixed list of neighbour counts; ``k`` is not used in this mode.

    Args:
        x_train: Training features; converted with ``np.array``.
        y_train: Training labels; converted with ``torch.tensor``.
        x_test: Optional test features. Its presence selects test mode.
        cuisines: Iterable of objects exposing ``.id``; required in test mode.
        k: Number of neighbours used in test mode.

    Returns:
        None. Results are written to CSV (test mode) or printed (validation).
    """
    torch.cuda.empty_cache()
    x_train = np.array(x_train)

    if x_test is not None:
        x_tr = torch.tensor(x_train)
        y_tr = torch.tensor(y_train)
        x_val = torch.tensor(x_test)

        dist = self.get_dist(x_tr, x_val)
        y_pred = self.predict(dist, y_tr, k)

        ids = [cuisine.id for cuisine in cuisines]
        pred_cuisines = [self.dataset.id2cuisine[label] for label in y_pred]
        self._write2csv(ids, pred_cuisines)
    else:
        # Shuffle features and labels with the same permutation so that the
        # split below is random but rows stay aligned.
        shuffle_idx = torch.randperm(x_train.shape[0])
        x_train = torch.tensor(x_train).float()[shuffle_idx]
        y_train = torch.tensor(y_train)[shuffle_idx]

        train_size = 35000
        x_tr, x_val = x_train[:train_size], x_train[train_size:]
        y_tr, y_val = y_train[:train_size], y_train[train_size:]

        # Manual switch for the (much slower) metric-learning experiment.
        use_DML = False
        if use_DML:
            # LMNN is fitted on a small subset only. Features and labels are
            # cut with the SAME bounds: the estimator needs one label per
            # row, and predict() needs labels aligned with the rows of dist.
            dml_train_size = 5000
            dml_val_end = 6000
            x_tr = x_train[:dml_train_size]
            y_tr = y_train[:dml_train_size]
            x_val = x_train[dml_train_size:dml_val_end]
            y_val = y_train[dml_train_size:dml_val_end]

            x_tr, x_val = self.PCA(x_tr, x_val, 64)
            lmnn = LMNN(k=15, learn_rate=1e-6, min_iter=50, max_iter=100)
            lmnn.fit(x_tr.numpy(), y_tr.numpy())
            M = torch.tensor(lmnn.get_mahalanobis_matrix()).float()

            # Build every (train, val) difference vector, then evaluate the
            # quadratic form diff^T M diff for each pair.
            n, d = x_val.shape
            m = x_tr.shape[0]
            x0 = x_tr.unsqueeze(1).expand(-1, n, -1).contiguous().view(-1, d)
            x1 = x_val.unsqueeze(0).expand(m, -1, -1).contiguous().view(-1, d)
            x = x0 - x1
            dist0 = torch.mm(M, x.t().contiguous())
            dists = dist0.t().contiguous() * x
            dist = dists.sum(1).view(m, n)
        else:
            x_tr, x_val = self.PCA(x_tr, x_val, 500)
            dist = self.get_dist(x_tr, x_val).cpu()

        # A dedicated loop variable keeps the `k` argument from being
        # silently overwritten by the sweep.
        for n_neighbors in [1, 3, 5, 8, 10, 15, 20, 25, 30]:
            y_pred = self.predict(dist, y_tr, n_neighbors)
            acc = (y_pred == y_val).sum().float().numpy() / y_val.shape[0]
            print("K=", n_neighbors, " acc=", acc)

    torch.cuda.empty_cache()
def test_lmnn(self):
    """LMNN's Mahalanobis matrix must equal ``L.T.dot(L)`` for its components ``L``."""
    model = LMNN(k=5, learn_rate=1e-6, verbose=False)
    model.fit(self.X, self.y)

    components = model.components_
    gram = components.T.dot(components)
    assert_array_almost_equal(gram, model.get_mahalanobis_matrix())
def main(params):
    """Compute a PSD matrix with the configured algorithm and save it to disk.

    The results directory is initialised and the parameters are backed up
    into it, the dataset named by ``params`` is loaded, and the algorithm
    selected by ``params['algorithm']`` (``identity``, ``nca``, ``lmnn``,
    ``itml``, ``lfda`` or ``arml``) is run on the training split. The
    resulting matrix is written to ``psd_matrix.txt`` inside the results
    directory.

    Args:
        params: Configuration object providing ``get``, ``getint``,
            ``getfloat`` and ``getboolean`` (presumably a ``configparser``
            section -- confirm against the caller).

    Raises:
        ValueError: If ``params['algorithm']`` is not one of the supported
            names.
    """
    results_dir = params.get('results_dir')
    initialize_results_dir(results_dir)
    backup_params(params, results_dir)

    print('>>> loading data...')
    X_train, y_train, X_test, y_test = LoaderFactory().create(
        name=params.get('dataset'),
        root=params.get('dataset_dir'),
        random=True,
        seed=params.getint('split_seed'))()
    print('<<< data loaded')

    print('>>> computing psd matrix...')
    algorithm = params.get('algorithm')
    if algorithm == 'identity':
        psd_matrix = np.identity(X_train.shape[1], dtype=X_train.dtype)
    elif algorithm == 'nca':
        nca = NCA(init='auto', verbose=True,
                  random_state=params.getint('algorithm_seed'))
        nca.fit(X_train, y_train)
        psd_matrix = nca.get_mahalanobis_matrix()
    elif algorithm == 'lmnn':
        lmnn = LMNN(init='auto', verbose=True,
                    random_state=params.getint('algorithm_seed'))
        lmnn.fit(X_train, y_train)
        psd_matrix = lmnn.get_mahalanobis_matrix()
    elif algorithm == 'itml':
        itml = ITML_Supervised(verbose=True,
                               random_state=params.getint('algorithm_seed'))
        itml.fit(X_train, y_train)
        psd_matrix = itml.get_mahalanobis_matrix()
    elif algorithm == 'lfda':
        lfda = LFDA()
        lfda.fit(X_train, y_train)
        psd_matrix = lfda.get_mahalanobis_matrix()
    elif algorithm == 'arml':
        learner = TripleLearner(
            optimizer=params.get('optimizer'),
            optimizer_params={
                'lr': params.getfloat('lr'),
                'momentum': params.getfloat('momentum'),
                'weight_decay': params.getfloat('weight_decay'),
            },
            criterion=params.get('criterion'),
            criterion_params={'calibration': params.getfloat('calibration')},
            n_epochs=params.getint('n_epochs'),
            batch_size=params.getint('batch_size'),
            random_initialization=params.getboolean(
                'random_initialization', fallback=False),
            update_triple=params.getboolean('update_triple', fallback=False),
            device=params.get('device'),
            seed=params.getint('learner_seed'))
        psd_matrix = learner(
            X_train, y_train,
            n_candidate_mins=params.getint('n_candidate_mins', fallback=1))
    else:
        # ValueError is still caught by callers using `except Exception`,
        # and naming the bad value makes a config typo easy to spot.
        raise ValueError('unsupported algorithm: {!r}'.format(algorithm))
    print('<<< psd matrix got')

    np.savetxt(os.path.join(results_dir, 'psd_matrix.txt'), psd_matrix)
def test_lmnn(self):
    """LMNN's Mahalanobis matrix must equal ``L.T.dot(L)`` for its learned transform ``L``."""
    lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False)
    lmnn.fit(self.X, self.y)

    # `transformer_` is the legacy metric-learn name for the learned linear
    # map; it was deprecated in favour of `components_` (metric-learn 0.5.0)
    # and dropped afterwards. Prefer the current attribute and fall back to
    # the old one so the check runs on either side of the rename.
    if hasattr(lmnn, 'components_'):
        L = lmnn.components_
    else:
        L = lmnn.transformer_

    assert_array_almost_equal(L.T.dot(L), lmnn.get_mahalanobis_matrix())