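# NOTE: the import header below is reconstructed; the original headers were lost when
# these files were flattened into one section. The standard-library and third-party
# imports are exact requirements of the code below; the project-level import paths
# follow the public TNC repository layout and are assumptions -- adjust to your layout.
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.utils.data as data
from sklearn.metrics import confusion_matrix, roc_auc_score

# Assumed project-level modules (names as used below):
# from tnc.models import RnnEncoder, WFEncoder, WFClassifier, StateClassifier, E2EStateClassifier
# from tnc.utils import plot_distribution, track_encoding, create_simulated_dataset
# Discriminator, TNCDataset, and the epoch_run / epoch_run_encoder / train helpers are
# defined in their respective source files and assumed to be in scope.

device = 'cuda' if torch.cuda.is_available() else 'cpu'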
def main(is_train, data_type, cv, w, cont):
    if not os.path.exists("./plots"):
        os.mkdir("./plots")
    if not os.path.exists("./ckpt/"):
        os.mkdir("./ckpt/")

    if data_type == 'simulation':
        window_size = 50
        encoder = RnnEncoder(hidden_size=100, in_channel=3, encoding_size=10, device=device)
        path = './data/simulated_data/'

        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            learn_encoder(x, encoder, w=w, lr=1e-3, decay=1e-5, window_size=window_size, n_epochs=100,
                          mc_sample_size=40, path='simulation', device=device, augmentation=5, n_cross_val=cv)
        else:
            # Plot the distribution of the encodings and use the learnt encoders to train a downstream classifier
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            checkpoint = torch.load('./ckpt/%s/checkpoint_0.pth.tar' % data_type)
            encoder.load_state_dict(checkpoint['encoder_state_dict'])
            encoder = encoder.to(device)
            track_encoding(x_test[10, :, 50:650], y_test[10, 50:650], encoder, window_size, 'simulation')
            for cv_ind in range(cv):
                plot_distribution(x_test, y_test, encoder, window_size=window_size, path='simulation',
                                  title='TNC', device=device, cv=cv_ind)
                exp = ClassificationPerformanceExperiment(cv=cv_ind)
                # Run cross validation for classification
                for lr in [0.001, 0.01, 0.1]:
                    print('===> lr: ', lr)
                    tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(data='simulation', n_epochs=150,
                                                                 lr_e2e=lr, lr_cls=lr)
                    print('TNC acc: %.2f \t TNC auc: %.2f \t E2E acc: %.2f \t E2E auc: %.2f'
                          % (tnc_acc, tnc_auc, e2e_acc, e2e_auc))

    if data_type == 'waveform':
        window_size = 2500
        path = './data/waveform_data/processed'
        encoder = WFEncoder(encoding_size=64).to(device)

        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            T = x.shape[-1]
            # Split each long recording into 5 shorter segments so the windows fit in memory
            x_window = np.concatenate(np.split(x[:, :, :T // 5 * 5], 5, -1), 0)
            learn_encoder(torch.Tensor(x_window), encoder, w=w, lr=1e-5, decay=1e-4, n_epochs=150,
                          window_size=window_size, path='waveform', mc_sample_size=10, device=device,
                          augmentation=7, n_cross_val=cv, cont=cont)
        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            checkpoint = torch.load('./ckpt/%s/checkpoint_0.pth.tar' % data_type)
            encoder.load_state_dict(checkpoint['encoder_state_dict'])
            encoder = encoder.to(device)
            track_encoding(x_test[0, :, 80000:130000], y_test[0, 80000:130000], encoder,
                           window_size, 'waveform', sliding_gap=1000)
            for cv_ind in range(cv):
                plot_distribution(x_test, y_test, encoder, window_size=window_size, path='waveform',
                                  device=device, augment=100, cv=cv_ind, title='TNC')
                exp = WFClassificationExperiment(window_size=window_size, cv=cv_ind)
                exp.run(data='waveform', n_epochs=10, lr_e2e=0.0001, lr_cls=0.01)

    if data_type == 'har':
        window_size = 4
        path = './data/HAR_data/'
        encoder = RnnEncoder(hidden_size=100, in_channel=561, encoding_size=10, device=device)

        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            learn_encoder(torch.Tensor(x), encoder, w=w, lr=1e-3, decay=1e-5, n_epochs=150,
                          window_size=window_size, path='har', mc_sample_size=20, device=device,
                          augmentation=5, n_cross_val=cv)
        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            checkpoint = torch.load('./ckpt/%s/checkpoint_0.pth.tar' % data_type)
            encoder.load_state_dict(checkpoint['encoder_state_dict'])
            encoder = encoder.to(device)
            track_encoding(x_test[0, :, :], y_test[0, :], encoder, window_size, 'har')
            for cv_ind in range(cv):
                plot_distribution(x_test, y_test, encoder, window_size=window_size, path='har',
                                  device=device, augment=100, cv=cv_ind, title='TNC')
                exp = ClassificationPerformanceExperiment(n_states=6, encoding_size=10, path='har',
                                                          hidden_size=100, in_channel=561,
                                                          window_size=4, cv=cv_ind)
                # Run cross validation for classification
                for lr in [0.001, 0.01, 0.1]:
                    print('===> lr: ', lr)
                    tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(data='har', n_epochs=50,
                                                                 lr_e2e=lr, lr_cls=lr)
                    print('TNC acc: %.2f \t TNC auc: %.2f \t E2E acc: %.2f \t E2E auc: %.2f'
                          % (tnc_acc, tnc_auc, e2e_acc, e2e_auc))
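# Example driver (a sketch, not part of the original script): how main() above is
# typically invoked. `w` weights the non-neighbor term of the TNC loss (0.05 is an
# illustrative value) and `cont=True` resumes from a saved checkpoint. Note that later
# files in this section redefine main() and learn_encoder(), so run this inside the
# TNC script's own namespace.
def tnc_main_demo():
    main(is_train=True, data_type='simulation', cv=1, w=0.05, cont=False)   # train one fold
    main(is_train=False, data_type='simulation', cv=1, w=0.05, cont=False)  # evaluate it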
def learn_encoder(x, encoder, window_size, w, lr=0.001, decay=0.005, mc_sample_size=20,
                  n_epochs=100, path='simulation', device='cpu', augmentation=1,
                  n_cross_val=1, cont=False):
    accuracies, losses = [], []
    for cv in range(n_cross_val):
        # Re-initialize the encoder for every fold so the folds do not share weights
        if 'waveform' in path:
            encoder = WFEncoder(encoding_size=64).to(device)
            batch_size = 5
        elif 'simulation' in path:
            encoder = RnnEncoder(hidden_size=100, in_channel=3, encoding_size=10, device=device)
            batch_size = 10
        elif 'har' in path:
            encoder = RnnEncoder(hidden_size=100, in_channel=561, encoding_size=10, device=device)
            batch_size = 10
        if not os.path.exists('./ckpt/%s' % path):
            os.mkdir('./ckpt/%s' % path)
        if cont:
            checkpoint = torch.load('./ckpt/%s/checkpoint_%d.pth.tar' % (path, cv))
            encoder.load_state_dict(checkpoint['encoder_state_dict'])

        disc_model = Discriminator(encoder.encoding_size, device)
        params = list(disc_model.parameters()) + list(encoder.parameters())
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=decay)

        inds = list(range(len(x)))
        random.shuffle(inds)
        x = x[inds]
        n_train = int(0.8 * len(x))

        performance = []
        best_acc = 0
        best_loss = np.inf
        for epoch in range(n_epochs + 1):
            trainset = TNCDataset(x=torch.Tensor(x[:n_train]), mc_sample_size=mc_sample_size,
                                  window_size=window_size, augmentation=augmentation, adf=True)
            train_loader = data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=3)
            validset = TNCDataset(x=torch.Tensor(x[n_train:]), mc_sample_size=mc_sample_size,
                                  window_size=window_size, augmentation=augmentation, adf=True)
            valid_loader = data.DataLoader(validset, batch_size=batch_size, shuffle=True)

            epoch_loss, epoch_acc = epoch_run(train_loader, disc_model, encoder, optimizer=optimizer,
                                              w=w, train=True, device=device)
            test_loss, test_acc = epoch_run(valid_loader, disc_model, encoder, train=False, w=w, device=device)
            performance.append((epoch_loss, test_loss, epoch_acc, test_acc))
            if epoch % 10 == 0:
                print('(cv:%s)Epoch %d Loss =====> Training Loss: %.5f \t Training Accuracy: %.5f'
                      ' \t Test Loss: %.5f \t Test Accuracy: %.5f'
                      % (cv, epoch, epoch_loss, epoch_acc, test_loss, test_acc))
            if best_loss > test_loss or path == 'har':
                best_acc = test_acc
                best_loss = test_loss
                state = {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict(),
                    'discriminator_state_dict': disc_model.state_dict(),
                    'best_accuracy': test_acc
                }
                torch.save(state, './ckpt/%s/checkpoint_%d.pth.tar' % (path, cv))
        accuracies.append(best_acc)
        losses.append(best_loss)

        # Save performance plots
        if not os.path.exists('./plots/%s' % path):
            os.mkdir('./plots/%s' % path)
        train_loss = [t[0] for t in performance]
        test_loss = [t[1] for t in performance]
        train_acc = [t[2] for t in performance]
        test_acc = [t[3] for t in performance]
        plt.figure()
        plt.plot(np.arange(n_epochs + 1), train_loss, label="Train")
        plt.plot(np.arange(n_epochs + 1), test_loss, label="Test")
        plt.title("Loss")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s" % path, "loss_%d.pdf" % cv))
        plt.figure()
        plt.plot(np.arange(n_epochs + 1), train_acc, label="Train")
        plt.plot(np.arange(n_epochs + 1), test_acc, label="Test")
        plt.title("Accuracy")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s" % path, "accuracy_%d.pdf" % cv))

    print('=======> Performance Summary:')
    print('Accuracy: %.2f +- %.2f' % (100 * np.mean(accuracies), 100 * np.std(accuracies)))
    print('Loss: %.4f +- %.4f' % (np.mean(losses), np.std(losses)))
    return encoder
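# Usage sketch (illustrative, not from the original file): learn_encoder() expects x
# shaped (n_samples, in_channel, T) and relies on TNCDataset, Discriminator, and
# epoch_run from the TNC codebase being in scope. The toy shapes and the value w=0.05
# below are assumptions.
def tnc_learn_encoder_demo():
    x = np.random.randn(20, 3, 500).astype(np.float32)  # toy multivariate time series
    enc = RnnEncoder(hidden_size=100, in_channel=3, encoding_size=10, device='cpu')
    learn_encoder(x, enc, window_size=50, w=0.05, n_epochs=5, path='simulation', device='cpu')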
def main(is_train, data, cv):
    if not os.path.exists("./plots"):
        os.mkdir("./plots")
    if not os.path.exists("./ckpt/"):
        os.mkdir("./ckpt/")

    if data == 'waveform':
        path = './data/waveform_data/processed'
        window_size = 2500
        encoder = WFEncoder(encoding_size=64).to(device)
        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            T = x.shape[-1]
            x_window = np.concatenate(np.split(x[:, :, :T // 5 * 5], 5, -1), 0)
            learn_encoder(x_window, window_size, n_epochs=150, lr=1e-4, decay=1e-4,
                          data='waveform', n_cross_val=cv)
        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            for cv_ind in range(cv):
                plot_distribution(x_test, y_test, encoder, window_size=window_size,
                                  path='%s_trip' % data, device=device, augment=100,
                                  cv=cv_ind, title='Triplet Loss')
                # exp = WFClassificationExperiment(window_size=window_size, data='waveform_trip')
                # exp.run(data='waveform_trip', n_epochs=15, lr_e2e=0.001, lr_cls=0.001)

    elif data == 'simulation':
        path = './data/simulated_data/'
        window_size = 50
        encoder = RnnEncoder(hidden_size=100, in_channel=3, encoding_size=10, device=device).to(device)
        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            learn_encoder(x, window_size, lr=1e-3, decay=1e-5, data=data, n_epochs=150,
                          device=device, n_cross_val=cv)
        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            for cv_ind in range(cv):
                plot_distribution(x_test, y_test, encoder, window_size=window_size,
                                  path='%s_trip' % data, title='Triplet Loss', device=device, cv=cv_ind)
                exp = ClassificationPerformanceExperiment(path='simulation_trip', cv=cv_ind)
                # Run cross validation for classification
                for lr in [0.001, 0.01, 0.1]:
                    print('===> lr: ', lr)
                    tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(data='simulation_trip', n_epochs=50,
                                                                 lr_e2e=lr, lr_cls=lr)
                    print('TNC acc: %.2f \t TNC auc: %.2f \t E2E acc: %.2f \t E2E auc: %.2f'
                          % (tnc_acc, tnc_auc, e2e_acc, e2e_auc))

    elif data == 'har':
        window_size = 4
        path = './data/HAR_data/'
        encoder = RnnEncoder(hidden_size=100, in_channel=561, encoding_size=10, device=device)
        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            learn_encoder(x, window_size, lr=1e-5, decay=0.001, data=data, n_epochs=300,
                          device=device, n_cross_val=cv)
        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            for cv_ind in range(cv):
                plot_distribution(x_test, y_test, encoder, window_size=window_size, path='har_trip',
                                  device=device, augment=100, cv=cv_ind, title='Triplet Loss')
                exp = ClassificationPerformanceExperiment(n_states=6, encoding_size=10, path='har_trip',
                                                          hidden_size=100, in_channel=561,
                                                          window_size=5, cv=cv_ind)
                # Run cross validation for classification
                for lr in [0.001, 0.01, 0.1]:
                    print('===> lr: ', lr)
                    tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(data='har_trip', n_epochs=100,
                                                                 lr_e2e=lr, lr_cls=lr)
                    print('TNC acc: %.2f \t TNC auc: %.2f \t E2E acc: %.2f \t E2E auc: %.2f'
                          % (tnc_acc, tnc_auc, e2e_acc, e2e_auc))
def learn_encoder(x, window_size, data, lr=0.001, decay=0, n_epochs=100, device='cpu', n_cross_val=1):
    if not os.path.exists("./plots/%s_trip/" % data):
        os.mkdir("./plots/%s_trip/" % data)
    if not os.path.exists("./ckpt/%s_trip/" % data):
        os.mkdir("./ckpt/%s_trip/" % data)
    for cv in range(n_cross_val):
        if 'waveform' in data:
            encoder = WFEncoder(encoding_size=64).to(device)
        elif 'simulation' in data:
            encoder = RnnEncoder(hidden_size=100, in_channel=3, encoding_size=10, device=device).to(device)
        elif 'har' in data:
            encoder = RnnEncoder(hidden_size=100, in_channel=561, encoding_size=10, device=device).to(device)
        params = encoder.parameters()
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=decay)

        inds = list(range(len(x)))
        random.shuffle(inds)
        x = x[inds]
        n_train = int(0.8 * len(x))

        train_loss, test_loss = [], []
        best_loss = np.inf
        for epoch in range(n_epochs):
            epoch_loss, acc = epoch_run(x[:n_train], encoder, device, window_size,
                                        optimizer=optimizer, train=True)
            epoch_loss_test, acc_test = epoch_run(x[n_train:], encoder, device, window_size,
                                                  optimizer=optimizer, train=False)
            print('\nEpoch ', epoch)
            print('Train ===> Loss: ', epoch_loss)
            print('Test ===> Loss: ', epoch_loss_test)
            train_loss.append(epoch_loss)
            test_loss.append(epoch_loss_test)
            if epoch_loss_test < best_loss:
                print('Save new ckpt')
                state = {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict()
                }
                best_loss = epoch_loss_test
                torch.save(state, './ckpt/%s_trip/checkpoint_%d.pth.tar' % (data, cv))

        plt.figure()
        plt.plot(np.arange(n_epochs), train_loss, label="Train")
        plt.plot(np.arange(n_epochs), test_loss, label="Test")
        plt.title("Loss")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s_trip/loss_%d.pdf" % (data, cv)))
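# Sketch of the triplet objective this baseline optimizes (the real epoch_run is defined
# elsewhere in the repo; this only illustrates the sampling idea). An anchor window and a
# nearby "positive" window come from the same series, while the "negative" is a window
# from a random other series. All names and the margin value here are illustrative.
def triplet_loss_sketch(encoder, anchor, positive, negative, margin=1.0):
    z_a, z_p, z_n = encoder(anchor), encoder(positive), encoder(negative)
    d_pos = torch.norm(z_a - z_p, dim=-1)  # distance to the positive window
    d_neg = torch.norm(z_a - z_n, dim=-1)  # distance to the negative window
    # Hinge: pull the positive closer than the negative by at least `margin`
    return torch.clamp(d_pos - d_neg + margin, min=0).mean()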
def learn_encoder(x, window_size, lr=0.001, decay=0, n_size=5, n_epochs=50, data='simulation',
                  device='cpu', n_cross_val=1):
    if not os.path.exists("./plots/%s_cpc/" % data):
        os.mkdir("./plots/%s_cpc/" % data)
    if not os.path.exists("./ckpt/%s_cpc/" % data):
        os.mkdir("./ckpt/%s_cpc/" % data)
    accuracies = []
    for cv in range(n_cross_val):
        if 'waveform' in data:
            encoding_size = 64
            encoder = WFEncoder(encoding_size=64).to(device)
        elif 'simulation' in data:
            encoding_size = 10
            encoder = RnnEncoder(hidden_size=100, in_channel=3, encoding_size=10, device=device)
        elif 'har' in data:
            encoding_size = 10
            encoder = RnnEncoder(hidden_size=100, in_channel=561, encoding_size=10, device=device)

        # Density-ratio estimator and autoregressive context model for the CPC objective
        ds_estimator = torch.nn.Linear(encoder.encoding_size, encoder.encoding_size)
        auto_regressor = torch.nn.GRU(input_size=encoding_size, hidden_size=encoding_size, batch_first=True)
        params = list(ds_estimator.parameters()) + list(encoder.parameters()) + list(auto_regressor.parameters())
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=decay)

        inds = list(range(len(x)))
        random.shuffle(inds)
        x = x[inds]
        n_train = int(0.8 * len(x))

        best_acc = 0
        best_loss = np.inf
        train_loss, test_loss = [], []
        for epoch in range(n_epochs):
            epoch_loss, acc = epoch_run(x[:n_train], ds_estimator, auto_regressor, encoder, device,
                                        window_size, optimizer=optimizer, n_size=n_size, train=True)
            epoch_loss_test, acc_test = epoch_run(x[n_train:], ds_estimator, auto_regressor, encoder,
                                                  device, window_size, n_size=n_size, train=False)
            print('\nEpoch ', epoch)
            print('Train ===> Loss: ', epoch_loss, '\t Accuracy: ', acc)
            print('Test ===> Loss: ', epoch_loss_test, '\t Accuracy: ', acc_test)
            train_loss.append(epoch_loss)
            test_loss.append(epoch_loss_test)
            if epoch_loss_test < best_loss:
                print('Save new ckpt')
                state = {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict()
                }
                best_loss = epoch_loss_test
                best_acc = acc_test
                torch.save(state, './ckpt/%s_cpc/checkpoint_%d.pth.tar' % (data, cv))
        accuracies.append(best_acc)

        plt.figure()
        plt.plot(np.arange(n_epochs), train_loss, label="Train")
        plt.plot(np.arange(n_epochs), test_loss, label="Test")
        plt.title("CPC Loss")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s_cpc/loss_%d.pdf" % (data, cv)))

    print('=======> Performance Summary:')
    print('Accuracy: %.2f +- %.2f' % (100 * np.mean(accuracies), 100 * np.std(accuracies)))
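# Illustrative sketch (not from the original script) of how the pieces wired above fit
# together in a CPC step: the GRU (auto_regressor) summarizes past window encodings into
# a context vector, the linear map (ds_estimator) projects that context into encoding
# space, and InfoNCE scores the true future encoding against negatives. The function
# name, argument shapes, and single-positive setup are assumptions for illustration.
def cpc_infonce_sketch(past_encodings, future_encoding, negatives, ds_estimator, auto_regressor):
    # past_encodings: (1, n_past, d); future_encoding: (d,); negatives: (n_neg, d)
    _, c_t = auto_regressor(past_encodings)       # GRU context; c_t has shape (1, 1, d)
    pred = ds_estimator(c_t.squeeze())            # predicted future encoding, shape (d,)
    candidates = torch.cat([future_encoding.unsqueeze(0), negatives], 0)  # (1 + n_neg, d)
    logits = torch.matmul(candidates, pred)       # density-ratio scores, (1 + n_neg,)
    # The positive sits at index 0, so InfoNCE reduces to cross-entropy with target 0
    return torch.nn.functional.cross_entropy(logits.unsqueeze(0), torch.zeros(1, dtype=torch.long))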
class ClassificationPerformanceExperiment():
    def __init__(self, n_states=4, encoding_size=10, path='simulation', cv=0,
                 hidden_size=100, in_channel=3, window_size=50):
        # Load a pre-trained TNC encoder
        if not os.path.exists("./ckpt/%s/checkpoint_%d.pth.tar" % (path, cv)):
            raise ValueError("No checkpoint for an encoder")
        checkpoint = torch.load('./ckpt/%s/checkpoint_%d.pth.tar' % (path, cv))
        self.encoder = RnnEncoder(hidden_size=hidden_size, in_channel=in_channel, encoding_size=encoding_size)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.classifier = StateClassifier(input_size=encoding_size, output_size=n_states)
        self.n_states = n_states
        self.cv = cv
        # Build a new encoder to train end-to-end with a classifier
        self.e2e_model = E2EStateClassifier(hidden_size=hidden_size, in_channel=in_channel,
                                            encoding_size=encoding_size, output_size=n_states)

        data_path = './data/HAR_data/' if 'har' in path else './data/simulated_data/'
        self.train_loader, self.valid_loader, self.test_loader = create_simulated_dataset(
            window_size=window_size, path=data_path, batch_size=100)

    def _train_end_to_end(self, lr):
        self.e2e_model.train()
        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.e2e_model.parameters(), lr=lr)

        epoch_loss, epoch_auc = 0, 0
        epoch_acc = 0
        batch_count = 0
        y_all, prediction_all = [], []
        for i, (x, y) in enumerate(self.train_loader):
            # if i > 5:
            #     break
            optimizer.zero_grad()
            prediction = self.e2e_model(x)
            state_prediction = torch.argmax(prediction, dim=1)
            loss = loss_fn(prediction, y.long())
            loss.backward()
            optimizer.step()
            y_all.append(y)
            prediction_all.append(prediction.detach().cpu().numpy())

            epoch_acc += torch.eq(state_prediction, y).sum().item() / len(x)
            epoch_loss += loss.item()
            batch_count += 1
        y_all = np.concatenate(y_all, 0)
        prediction_all = np.concatenate(prediction_all, 0)
        prediction_class_all = np.argmax(prediction_all, -1)
        y_onehot_all = np.zeros(prediction_all.shape)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all.astype(int)] = 1
        epoch_auc = roc_auc_score(y_onehot_all, prediction_all)
        c = confusion_matrix(y_all.astype(int), prediction_class_all)
        return epoch_loss / batch_count, epoch_acc / batch_count, epoch_auc, c

    def _train_tnc_classifier(self, lr):
        # The encoder stays frozen (eval mode); only the classifier parameters are optimized
        self.classifier.train()
        self.encoder.eval()
        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=lr)

        epoch_loss, epoch_auc = 0, 0
        epoch_acc = 0
        batch_count = 0
        y_all, prediction_all = [], []
        for i, (x, y) in enumerate(self.train_loader):
            if i > 30:
                break
            optimizer.zero_grad()
            encodings = self.encoder(x)
            prediction = self.classifier(encodings)
            state_prediction = torch.argmax(prediction, dim=1)
            loss = loss_fn(prediction, y.long())
            loss.backward()
            optimizer.step()
            y_all.append(y)
            prediction_all.append(prediction.detach().cpu().numpy())

            epoch_acc += torch.eq(state_prediction, y).sum().item() / len(x)
            epoch_loss += loss.item()
            batch_count += 1
        y_all = np.concatenate(y_all, 0)
        prediction_all = np.concatenate(prediction_all, 0)
        prediction_class_all = np.argmax(prediction_all, -1)
        y_onehot_all = np.zeros(prediction_all.shape)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all.astype(int)] = 1
        epoch_auc = roc_auc_score(y_onehot_all, prediction_all)
        c = confusion_matrix(y_all.astype(int), prediction_class_all)
        return epoch_loss / batch_count, epoch_acc / batch_count, epoch_auc, c

    def _test(self, model):
        model.eval()
        loss_fn = torch.nn.CrossEntropyLoss()
        data_loader = self.valid_loader

        epoch_loss, epoch_auc = 0, 0
        epoch_acc = 0
        batch_count = 0
        y_all, prediction_all = [], []
        for x, y in data_loader:
            prediction = model(x)
            state_prediction = torch.argmax(prediction, -1)
            loss = loss_fn(prediction, y.long())
            y_all.append(y)
            prediction_all.append(prediction.detach().cpu().numpy())

            epoch_acc += torch.eq(state_prediction, y).sum().item() / len(x)
            epoch_loss += loss.item()
            batch_count += 1
        y_all = np.concatenate(y_all, 0)
        prediction_all = np.concatenate(prediction_all, 0)
        y_onehot_all = np.zeros(prediction_all.shape)
        prediction_class_all = np.argmax(prediction_all, -1)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all.astype(int)] = 1
        epoch_auc = roc_auc_score(y_onehot_all, prediction_all)
        c = confusion_matrix(y_all.astype(int), prediction_class_all)
        return epoch_loss / batch_count, epoch_acc / batch_count, epoch_auc, c

    def run(self, data, n_epochs, lr_e2e, lr_cls=0.01):
        tnc_acc, tnc_loss, tnc_auc = [], [], []
        etoe_acc, etoe_loss, etoe_auc = [], [], []
        tnc_acc_test, tnc_loss_test, tnc_auc_test = [], [], []
        etoe_acc_test, etoe_loss_test, etoe_auc_test = [], [], []
        for epoch in range(n_epochs):
            loss, acc, auc, _ = self._train_tnc_classifier(lr_cls)
            tnc_acc.append(acc)
            tnc_loss.append(loss)
            tnc_auc.append(auc)
            # End-to-end training is disabled here; zeros are recorded as placeholders
            # loss, acc, auc, _ = self._train_end_to_end(lr_e2e)
            loss, acc, auc = 0, 0, 0
            etoe_acc.append(acc)
            etoe_loss.append(loss)
            etoe_auc.append(auc)

            # Test
            loss, acc, auc, c_mtx_enc = self._test(model=torch.nn.Sequential(self.encoder, self.classifier))
            tnc_acc_test.append(acc)
            tnc_loss_test.append(loss)
            tnc_auc_test.append(auc)
            # loss, acc, auc, c_mtx_e2e = self._test(model=self.e2e_model)
            loss, acc, auc, c_mtx_e2e = 0, 0, 0, 0
            etoe_acc_test.append(acc)
            etoe_loss_test.append(loss)
            etoe_auc_test.append(auc)

            if epoch % 5 == 0:
                print('***** Epoch %d *****' % epoch)
                print('TNC =====> Training Loss: %.3f \t Training Acc: %.3f \t Training AUC: %.3f '
                      '\t Test Loss: %.3f \t Test Acc: %.3f \t Test AUC: %.3f'
                      % (tnc_loss[-1], tnc_acc[-1], tnc_auc[-1],
                         tnc_loss_test[-1], tnc_acc_test[-1], tnc_auc_test[-1]))
                print('End-to-End =====> Training Loss: %.3f \t Training Acc: %.3f \t Training AUC: %.3f'
                      ' \t Test Loss: %.3f \t Test Acc: %.3f \t Test AUC: %.3f'
                      % (etoe_loss[-1], etoe_acc[-1], etoe_auc[-1],
                         etoe_loss_test[-1], etoe_acc_test[-1], etoe_auc_test[-1]))

        # Save performance plots
        plt.figure()
        plt.plot(np.arange(n_epochs), tnc_loss_test, label="tnc test")
        plt.plot(np.arange(n_epochs), etoe_loss_test, label="e2e test")
        plt.title("Loss trend for the e2e and tnc framework")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s" % data, "classification_loss_comparison.pdf"))

        plt.figure()
        plt.plot(np.arange(n_epochs), tnc_acc, label="tnc train")
        plt.plot(np.arange(n_epochs), etoe_acc, label="e2e train")
        plt.plot(np.arange(n_epochs), tnc_acc_test, label="tnc test")
        plt.plot(np.arange(n_epochs), etoe_acc_test, label="e2e test")
        plt.title("Accuracy trend for the e2e and tnc model")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s" % data, "classification_accuracy_comparison_%d.pdf" % self.cv))

        plt.figure()
        plt.plot(np.arange(n_epochs), tnc_auc, label="tnc train")
        plt.plot(np.arange(n_epochs), etoe_auc, label="e2e train")
        plt.plot(np.arange(n_epochs), tnc_auc_test, label="tnc test")
        plt.plot(np.arange(n_epochs), etoe_auc_test, label="e2e test")
        plt.title("AUC trend for the e2e and TNC model")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s" % data, "classification_auc_comparison_%d.pdf" % self.cv))

        df_cm = pd.DataFrame(c_mtx_enc, index=[''] * self.n_states, columns=[''] * self.n_states)
        plt.figure(figsize=(10, 10))
        sns.heatmap(df_cm, annot=True)
        plt.savefig(os.path.join("./plots/%s" % data, "encoder_cf_matrix.pdf"))

        return tnc_acc_test[-1], tnc_auc_test[-1], etoe_acc_test[-1], etoe_auc_test[-1]
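# Usage sketch (illustrative, not from the original file): evaluate a TNC encoder saved
# at ./ckpt/simulation/checkpoint_0.pth.tar against the end-to-end baseline on the
# simulation task, using the run() signature defined above.
def classification_experiment_demo():
    exp = ClassificationPerformanceExperiment(path='simulation', cv=0)
    tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(data='simulation', n_epochs=50, lr_e2e=0.01, lr_cls=0.01)
    print('TNC acc: %.2f \t TNC auc: %.2f' % (tnc_acc, tnc_auc))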
def run_test(data, e2e_lr, tnc_lr, cpc_lr, trip_lr, data_path, window_size, n_cross_val):
    # Load data
    with open(os.path.join(data_path, 'x_train.pkl'), 'rb') as f:
        x = pickle.load(f)
    with open(os.path.join(data_path, 'state_train.pkl'), 'rb') as f:
        y = pickle.load(f)
    with open(os.path.join(data_path, 'x_test.pkl'), 'rb') as f:
        x_test = pickle.load(f)
    with open(os.path.join(data_path, 'state_test.pkl'), 'rb') as f:
        y_test = pickle.load(f)

    # Split each series into non-overlapping windows; a window's label is its majority state
    T = x.shape[-1]
    x_window = np.split(x[:, :, :window_size * (T // window_size)], (T // window_size), -1)
    y_window = np.concatenate(
        np.split(y[:, :window_size * (T // window_size)], (T // window_size), -1), 0).astype(int)
    x_window = torch.Tensor(np.concatenate(x_window, 0))
    y_window = torch.Tensor(np.array([np.bincount(yy).argmax() for yy in y_window]))

    x_window_test = np.split(x_test[:, :, :window_size * (T // window_size)], (T // window_size), -1)
    y_window_test = np.concatenate(
        np.split(y_test[:, :window_size * (T // window_size)], (T // window_size), -1), 0).astype(int)
    x_window_test = torch.Tensor(np.concatenate(x_window_test, 0))
    y_window_test = torch.Tensor(np.array([np.bincount(yy).argmax() for yy in y_window_test]))

    testset = torch.utils.data.TensorDataset(x_window_test, y_window_test)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True)
    del x, y, x_test, y_test

    e2e_accs, e2e_aucs, e2e_auprcs = [], [], []
    tnc_accs, tnc_aucs, tnc_auprcs = [], [], []
    cpc_accs, cpc_aucs, cpc_auprcs = [], [], []
    trip_accs, trip_aucs, trip_auprcs = [], [], []
    for cv in range(n_cross_val):
        shuffled_inds = list(range(len(x_window)))
        random.shuffle(shuffled_inds)
        x_window = x_window[shuffled_inds]
        y_window = y_window[shuffled_inds]
        n_train = int(0.7 * len(x_window))
        X_train, X_test = x_window[:n_train], x_window[n_train:]
        y_train, y_test = y_window[:n_train], y_window[n_train:]

        trainset = torch.utils.data.TensorDataset(X_train, y_train)
        validset = torch.utils.data.TensorDataset(X_test, y_test)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=200, shuffle=False)
        valid_loader = torch.utils.data.DataLoader(validset, batch_size=200, shuffle=False)

        # Define baseline models
        if data == 'waveform':
            encoding_size = 64
            n_classes = 4
            e2e_model = WFEncoder(encoding_size=encoding_size, classify=True, n_classes=n_classes).to(device)

            tnc_encoder = WFEncoder(encoding_size=encoding_size).to(device)
            if not os.path.exists('./ckpt/waveform/checkpoint_%d.pth.tar' % cv):
                raise RuntimeError('Checkpoint for TNC encoder does not exist!')
            tnc_checkpoint = torch.load('./ckpt/waveform/checkpoint_%d.pth.tar' % cv)
            tnc_encoder.load_state_dict(tnc_checkpoint['encoder_state_dict'])
            tnc_classifier = WFClassifier(encoding_size=encoding_size, output_size=4)
            tnc_model = torch.nn.Sequential(tnc_encoder, tnc_classifier).to(device)

            cpc_encoder = WFEncoder(encoding_size=encoding_size).to(device)
            if not os.path.exists('./ckpt/waveform_cpc/checkpoint_%d.pth.tar' % cv):
                raise RuntimeError('Checkpoint for CPC encoder does not exist!')
            cpc_checkpoint = torch.load('./ckpt/waveform_cpc/checkpoint_%d.pth.tar' % cv)
            cpc_encoder.load_state_dict(cpc_checkpoint['encoder_state_dict'])
            cpc_classifier = WFClassifier(encoding_size=encoding_size, output_size=4)
            cpc_model = torch.nn.Sequential(cpc_encoder, cpc_classifier).to(device)

            trip_encoder = WFEncoder(encoding_size=encoding_size).to(device)
            if not os.path.exists('./ckpt/waveform_trip/checkpoint_%d.pth.tar' % cv):
                raise RuntimeError('Checkpoint for Triplet Loss encoder does not exist!')
            trip_checkpoint = torch.load('./ckpt/waveform_trip/checkpoint_%d.pth.tar' % cv)
            trip_encoder.load_state_dict(trip_checkpoint['encoder_state_dict'])
            trip_classifier = WFClassifier(encoding_size=encoding_size, output_size=4)
            trip_model = torch.nn.Sequential(trip_encoder, trip_classifier).to(device)
            n_epochs = 8
            n_epoch_e2e = 8

        elif data == 'simulation':
            encoding_size = 10
            e2e_model = E2EStateClassifier(hidden_size=100, in_channel=3, encoding_size=encoding_size,
                                           output_size=4, device=device)

            tnc_encoder = RnnEncoder(hidden_size=100, in_channel=3, encoding_size=encoding_size, device=device)
            tnc_checkpoint = torch.load('./ckpt/simulation/checkpoint_%d.pth.tar' % cv)
            tnc_encoder.load_state_dict(tnc_checkpoint['encoder_state_dict'])
            tnc_classifier = StateClassifier(input_size=encoding_size, output_size=4).to(device)
            tnc_model = torch.nn.Sequential(tnc_encoder, tnc_classifier).to(device)

            cpc_encoder = RnnEncoder(hidden_size=100, in_channel=3, encoding_size=encoding_size, device=device)
            cpc_checkpoint = torch.load('./ckpt/simulation_cpc/checkpoint_%d.pth.tar' % cv)
            cpc_encoder.load_state_dict(cpc_checkpoint['encoder_state_dict'])
            cpc_classifier = StateClassifier(input_size=encoding_size, output_size=4).to(device)
            cpc_model = torch.nn.Sequential(cpc_encoder, cpc_classifier).to(device)

            trip_encoder = RnnEncoder(hidden_size=100, in_channel=3, encoding_size=encoding_size, device=device)
            trip_checkpoint = torch.load('./ckpt/simulation_trip/checkpoint_%d.pth.tar' % cv)
            trip_encoder.load_state_dict(trip_checkpoint['encoder_state_dict'])
            trip_classifier = StateClassifier(input_size=encoding_size, output_size=4).to(device)
            trip_model = torch.nn.Sequential(trip_encoder, trip_classifier).to(device)
            n_epochs = 30
            n_epoch_e2e = 100

        elif data == 'har':
            encoding_size = 10
            e2e_model = E2EStateClassifier(hidden_size=100, in_channel=561, encoding_size=encoding_size,
                                           output_size=6, device=device)

            tnc_encoder = RnnEncoder(hidden_size=100, in_channel=561, encoding_size=encoding_size, device=device)
            tnc_checkpoint = torch.load('./ckpt/har/checkpoint_%d.pth.tar' % cv)
            tnc_encoder.load_state_dict(tnc_checkpoint['encoder_state_dict'])
            tnc_classifier = StateClassifier(input_size=encoding_size, output_size=6).to(device)
            tnc_model = torch.nn.Sequential(tnc_encoder, tnc_classifier).to(device)

            cpc_encoder = RnnEncoder(hidden_size=100, in_channel=561, encoding_size=encoding_size, device=device)
            cpc_checkpoint = torch.load('./ckpt/har_cpc/checkpoint_%d.pth.tar' % cv)
            cpc_encoder.load_state_dict(cpc_checkpoint['encoder_state_dict'])
            cpc_classifier = StateClassifier(input_size=encoding_size, output_size=6).to(device)
            cpc_model = torch.nn.Sequential(cpc_encoder, cpc_classifier).to(device)

            trip_encoder = RnnEncoder(hidden_size=100, in_channel=561, encoding_size=encoding_size, device=device)
            trip_checkpoint = torch.load('./ckpt/har_trip/checkpoint_%d.pth.tar' % cv)
            trip_encoder.load_state_dict(trip_checkpoint['encoder_state_dict'])
            trip_classifier = StateClassifier(input_size=encoding_size, output_size=6).to(device)
            trip_model = torch.nn.Sequential(trip_encoder, trip_classifier).to(device)
            n_epochs = 50
            n_epoch_e2e = 100

        # Train the models
        # ***** E2E *****
        best_acc_e2e, best_auc_e2e, best_auprc_e2e = train(train_loader, valid_loader, e2e_model, e2e_lr,
                                                           data_type=data, n_epochs=n_epoch_e2e,
                                                           type='e2e', cv=cv)
        print('E2E: ', best_acc_e2e * 100, best_auc_e2e, best_auprc_e2e)
        # ***** TNC *****
        best_acc_tnc, best_auc_tnc, best_auprc_tnc = train(train_loader, valid_loader, tnc_classifier, tnc_lr,
                                                           encoder=tnc_encoder, data_type=data,
                                                           n_epochs=n_epochs, type='tnc', cv=cv)
        print('TNC: ', best_acc_tnc * 100, best_auc_tnc, best_auprc_tnc)
        # ***** CPC *****
        best_acc_cpc, best_auc_cpc, best_auprc_cpc = train(train_loader, valid_loader, cpc_classifier, cpc_lr,
                                                           encoder=cpc_encoder, data_type=data,
                                                           n_epochs=n_epochs, type='cpc', cv=cv)
        print('CPC: ', best_acc_cpc * 100, best_auc_cpc, best_auprc_cpc)
        # ***** Trip *****
        best_acc_trip, best_auc_trip, best_auprc_trip = train(train_loader, valid_loader, trip_classifier, trip_lr,
                                                              encoder=trip_encoder, data_type=data,
                                                              n_epochs=n_epochs, type='trip', cv=cv)
        print('TRIP: ', best_acc_trip * 100, best_auc_trip, best_auprc_trip)

        if data == 'waveform':
            # The waveform dataset is very small and sparse. If due to class imbalance there are
            # no samples of a particular class in the test set, report the validation performance
            _, test_acc_e2e, test_auc_e2e, test_auprc_e2e, _ = epoch_run(
                e2e_model, dataloader=valid_loader, train=False)
            _, test_acc_tnc, test_auc_tnc, test_auprc_tnc, _ = epoch_run_encoder(
                tnc_encoder, tnc_classifier, dataloader=valid_loader, train=False)
            _, test_acc_cpc, test_auc_cpc, test_auprc_cpc, _ = epoch_run_encoder(
                cpc_encoder, cpc_classifier, dataloader=valid_loader, train=False)
            _, test_acc_trip, test_auc_trip, test_auprc_trip, _ = epoch_run_encoder(
                trip_encoder, trip_classifier, dataloader=valid_loader, train=False)
        else:
            _, test_acc_e2e, test_auc_e2e, test_auprc_e2e, _ = epoch_run(
                e2e_model, dataloader=test_loader, train=False)
            _, test_acc_tnc, test_auc_tnc, test_auprc_tnc, _ = epoch_run_encoder(
                tnc_encoder, tnc_classifier, dataloader=test_loader, train=False)
            _, test_acc_cpc, test_auc_cpc, test_auprc_cpc, _ = epoch_run_encoder(
                cpc_encoder, cpc_classifier, dataloader=test_loader, train=False)
            _, test_acc_trip, test_auc_trip, test_auprc_trip, _ = epoch_run_encoder(
                trip_encoder, trip_classifier, dataloader=test_loader, train=False)

        e2e_accs.append(test_acc_e2e)
        e2e_aucs.append(test_auc_e2e)
        e2e_auprcs.append(test_auprc_e2e)
        tnc_accs.append(test_acc_tnc)
        tnc_aucs.append(test_auc_tnc)
        tnc_auprcs.append(test_auprc_tnc)
        cpc_accs.append(test_acc_cpc)
        cpc_aucs.append(test_auc_cpc)
        cpc_auprcs.append(test_auprc_cpc)
        trip_accs.append(test_acc_trip)
        trip_aucs.append(test_auc_trip)
        trip_auprcs.append(test_auprc_trip)

        with open("./outputs/%s_classifiers.txt" % data, "a") as f:
            f.write("\n\nPerformance result for a fold")
            f.write("End-to-End model: \t AUC: %s\t Accuracy: %s \n\n"
                    % (str(best_auc_e2e), str(100 * best_acc_e2e)))
            f.write("TNC model: \t AUC: %s\t Accuracy: %s \n\n"
                    % (str(best_auc_tnc), str(100 * best_acc_tnc)))
            f.write("CPC model: \t AUC: %s\t Accuracy: %s \n\n"
                    % (str(best_auc_cpc), str(100 * best_acc_cpc)))
            f.write("Triplet Loss model: \t AUC: %s\t Accuracy: %s \n\n"
                    % (str(best_auc_trip), str(100 * best_acc_trip)))

        torch.cuda.empty_cache()

    print('=======> Performance Summary:')
    print('E2E model: \t Accuracy: %.2f +- %.2f \t AUC: %.3f +- %.3f \t AUPRC: %.3f +- %.3f'
          % (100 * np.mean(e2e_accs), 100 * np.std(e2e_accs), np.mean(e2e_aucs), np.std(e2e_aucs),
             np.mean(e2e_auprcs), np.std(e2e_auprcs)))
    print('TNC model: \t Accuracy: %.2f +- %.2f \t AUC: %.3f +- %.3f \t AUPRC: %.3f +- %.3f'
          % (100 * np.mean(tnc_accs), 100 * np.std(tnc_accs), np.mean(tnc_aucs), np.std(tnc_aucs),
             np.mean(tnc_auprcs), np.std(tnc_auprcs)))
    print('CPC model: \t Accuracy: %.2f +- %.2f \t AUC: %.3f +- %.3f \t AUPRC: %.3f +- %.3f'
          % (100 * np.mean(cpc_accs), 100 * np.std(cpc_accs), np.mean(cpc_aucs), np.std(cpc_aucs),
             np.mean(cpc_auprcs), np.std(cpc_auprcs)))
    print('Trip model: \t Accuracy: %.2f +- %.2f \t AUC: %.3f +- %.3f \t AUPRC: %.3f +- %.3f'
          % (100 * np.mean(trip_accs), 100 * np.std(trip_accs), np.mean(trip_aucs), np.std(trip_aucs),
             np.mean(trip_auprcs), np.std(trip_auprcs)))
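# Usage sketch (illustrative values, not from the original file): compare the four
# representations on the simulation dataset over 3 folds. Assumes the per-method
# checkpoints exist under ./ckpt/ and that the ./outputs/ directory has been created.
def run_test_demo():
    run_test(data='simulation', e2e_lr=0.01, tnc_lr=0.01, cpc_lr=0.01, trip_lr=0.01,
             data_path='./data/simulated_data/', window_size=50, n_cross_val=3)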