import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix, roc_auc_score

# NOTE: the project-level imports below are assumptions about the surrounding
# module layout; the snippet uses these names (RnnEncoder, learn_encoder, ...)
# without showing where they are defined.
from tnc.models import E2EStateClassifier, RnnEncoder, StateClassifier, WFEncoder
from tnc.tnc import learn_encoder
from tnc.utils import create_simulated_dataset, plot_distribution, track_encoding
from tnc.evaluations import WFClassificationExperiment

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def main(is_train, data_type, cv, w, cont):
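    """Train a TNC encoder or evaluate it on downstream classification.

    `data_type` selects the dataset ('simulation', 'waveform', or 'har');
    `is_train` switches between encoder training and evaluation; `cv` is the
    number of cross-validation runs; `w` is forwarded to learn_encoder, and
    `cont` is only used for the waveform dataset.
    """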
    if not os.path.exists("./plots"):
        os.mkdir("./plots")
    if not os.path.exists("./ckpt/"):
        os.mkdir("./ckpt/")

    if data_type == 'simulation':
        window_size = 50
        encoder = RnnEncoder(hidden_size=100,
                             in_channel=3,
                             encoding_size=10,
                             device=device)
        path = './data/simulated_data/'

        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            learn_encoder(x,
                          encoder,
                          w=w,
                          lr=1e-3,
                          decay=1e-5,
                          window_size=window_size,
                          n_epochs=100,
                          mc_sample_size=40,
                          path='simulation',
                          device=device,
                          augmentation=5,
                          n_cross_val=cv)
        else:
            # Plot the distribution of the encodings and use the learnt encoders to train a downstream classifier
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            checkpoint = torch.load('./ckpt/%s/checkpoint_0.pth.tar' % data_type,
                                    map_location=device)
            encoder.load_state_dict(checkpoint['encoder_state_dict'])
            encoder = encoder.to(device)
            track_encoding(x_test[10, :, 50:650], y_test[10, 50:650], encoder,
                           window_size, 'simulation')
            for cv_ind in range(cv):
                plot_distribution(x_test,
                                  y_test,
                                  encoder,
                                  window_size=window_size,
                                  path='simulation',
                                  title='TNC',
                                  device=device,
                                  cv=cv_ind)
                exp = ClassificationPerformanceExperiment(cv=cv_ind)
                # Run cross validation for classification
                for lr in [0.001, 0.01, 0.1]:
                    print('===> lr: ', lr)
                    tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(
                        data='simulation', n_epochs=150, lr_e2e=lr, lr_cls=lr)
                    print(
                        'TNC acc: %.2f \t TNC auc: %.2f \t E2E acc: %.2f \t E2E auc: %.2f'
                        % (tnc_acc, tnc_auc, e2e_acc, e2e_auc))

    if data_type == 'waveform':
        window_size = 2500
        path = './data/waveform_data/processed'
        encoder = WFEncoder(encoding_size=64).to(device)

        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            # Truncate T to a multiple of 5, then split each long recording
            # into 5 equal-length segments and stack them as extra samples.
            T = x.shape[-1]
            x_window = np.concatenate(np.split(x[:, :, :T // 5 * 5], 5, -1), 0)
            learn_encoder(torch.Tensor(x_window),
                          encoder,
                          w=w,
                          lr=1e-5,
                          decay=1e-4,
                          n_epochs=150,
                          window_size=window_size,
                          path='waveform',
                          mc_sample_size=10,
                          device=device,
                          augmentation=7,
                          n_cross_val=cv,
                          cont=cont)

        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            checkpoint = torch.load('./ckpt/%s/checkpoint_0.pth.tar' % data_type,
                                    map_location=device)
            encoder.load_state_dict(checkpoint['encoder_state_dict'])
            encoder = encoder.to(device)
            track_encoding(x_test[0, :, 80000:130000],
                           y_test[0, 80000:130000],
                           encoder,
                           window_size,
                           'waveform',
                           sliding_gap=1000)
            for cv_ind in range(cv):
                plot_distribution(x_test,
                                  y_test,
                                  encoder,
                                  window_size=window_size,
                                  path='waveform',
                                  device=device,
                                  augment=100,
                                  cv=cv_ind,
                                  title='TNC')
                # Run the downstream classification experiment for this fold
                exp = WFClassificationExperiment(window_size=window_size,
                                                 cv=cv_ind)
                exp.run(data='waveform', n_epochs=10, lr_e2e=0.0001, lr_cls=0.01)

    if data_type == 'har':
        window_size = 4
        path = './data/HAR_data/'
        encoder = RnnEncoder(hidden_size=100,
                             in_channel=561,
                             encoding_size=10,
                             device=device)

        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            learn_encoder(torch.Tensor(x),
                          encoder,
                          w=w,
                          lr=1e-3,
                          decay=1e-5,
                          n_epochs=150,
                          window_size=window_size,
                          path='har',
                          mc_sample_size=20,
                          device=device,
                          augmentation=5,
                          n_cross_val=cv)

        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            checkpoint = torch.load('./ckpt/%s/checkpoint_0.pth.tar' % data_type,
                                    map_location=device)
            encoder.load_state_dict(checkpoint['encoder_state_dict'])
            encoder = encoder.to(device)
            track_encoding(x_test[0, :, :], y_test[0, :], encoder, window_size,
                           'har')
            for cv_ind in range(cv):
                plot_distribution(x_test,
                                  y_test,
                                  encoder,
                                  window_size=window_size,
                                  path='har',
                                  device=device,
                                  augment=100,
                                  cv=cv_ind,
                                  title='TNC')
                exp = ClassificationPerformanceExperiment(n_states=6,
                                                          encoding_size=10,
                                                          path='har',
                                                          hidden_size=100,
                                                          in_channel=561,
                                                          window_size=4,
                                                          cv=cv_ind)
                # Run cross validation for classification
                for lr in [0.001, 0.01, 0.1]:
                    print('===> lr: ', lr)
                    tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(data='har',
                                                                 n_epochs=50,
                                                                 lr_e2e=lr,
                                                                 lr_cls=lr)
                    print(
                        'TNC acc: %.2f \t TNC auc: %.2f \t E2E acc: %.2f \t E2E auc: %.2f'
                        % (tnc_acc, tnc_auc, e2e_acc, e2e_auc))


class ClassificationPerformanceExperiment:
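    """Compare a classifier trained on frozen TNC encodings against an
    end-to-end baseline trained from scratch on the same data.
    """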
    def __init__(self, n_states=4, encoding_size=10, path='simulation', cv=0, hidden_size=100, in_channel=3, window_size=50):
        # Load a pretrained TNC encoder from its checkpoint
        checkpoint_path = './ckpt/%s/checkpoint_%d.pth.tar' % (path, cv)
        if not os.path.exists(checkpoint_path):
            raise ValueError('No encoder checkpoint found at %s' % checkpoint_path)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        self.encoder = RnnEncoder(hidden_size=hidden_size, in_channel=in_channel, encoding_size=encoding_size)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.classifier = StateClassifier(input_size=encoding_size, output_size=n_states)

        self.n_states = n_states
        self.cv = cv
        # Build a new encoder to train end-to-end with a classifier
        self.e2e_model = E2EStateClassifier(hidden_size=hidden_size, in_channel=in_channel, encoding_size=encoding_size, output_size=n_states)

        data_path = './data/HAR_data/' if 'har' in path else './data/simulated_data/'
        self.train_loader, self.valid_loader, self.test_loader = create_simulated_dataset(
            window_size=window_size, path=data_path, batch_size=100)

    def _train_end_to_end(self, lr):
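        """Train the end-to-end baseline for one epoch.

        Returns the epoch's mean loss, mean accuracy, one-vs-rest AUC, and
        confusion matrix on the training set.
        """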
        self.e2e_model.train()
        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.e2e_model.parameters(), lr=lr)

        epoch_loss, epoch_auc = 0, 0
        epoch_acc = 0
        batch_count = 0
        y_all, prediction_all = [], []
        for i, (x, y) in enumerate(self.train_loader):
            optimizer.zero_grad()
            prediction = self.e2e_model(x)
            state_prediction = torch.argmax(prediction, dim=1)
            loss = loss_fn(prediction, y.long())
            loss.backward()
            optimizer.step()
            y_all.append(y)
            prediction_all.append(prediction.detach().cpu().numpy())
            epoch_acc += torch.eq(state_prediction, y).sum().item()/len(x)
            epoch_loss += loss.item()
            batch_count += 1
        y_all = np.concatenate(y_all, 0)
        prediction_all = np.concatenate(prediction_all, 0)
        prediction_class_all = np.argmax(prediction_all, -1)
        # One-hot encode the labels so roc_auc_score computes a multi-class
        # (one-vs-rest) AUC from the raw class scores.
        y_onehot_all = np.zeros(prediction_all.shape)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all.astype(int)] = 1
        epoch_auc = roc_auc_score(y_onehot_all, prediction_all)
        c = confusion_matrix(y_all.astype(int), prediction_class_all)
        return epoch_loss / batch_count, epoch_acc / batch_count, epoch_auc, c

    def _train_tnc_classifier(self, lr):
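        """Train the downstream classifier on top of the frozen TNC encoder.

        Runs one truncated epoch (at most 31 batches); only the classifier's
        parameters are updated. Returns the same metrics as _train_end_to_end.
        """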
        self.classifier.train()
        self.encoder.eval()
        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=lr)

        epoch_loss, epoch_auc = 0, 0
        epoch_acc = 0
        batch_count = 0
        y_all, prediction_all = [], []
        for i, (x, y) in enumerate(self.train_loader):
            if i > 30:
                break
            optimizer.zero_grad()
            encodings = self.encoder(x)
            prediction = self.classifier(encodings)
            state_prediction = torch.argmax(prediction, dim=1)
            loss = loss_fn(prediction, y.long())
            loss.backward()
            optimizer.step()
            y_all.append(y)
            prediction_all.append(prediction.detach().cpu().numpy())

            epoch_acc += torch.eq(state_prediction, y).sum().item()/len(x)
            epoch_loss += loss.item()
            batch_count += 1
        y_all = np.concatenate(y_all, 0)
        prediction_all = np.concatenate(prediction_all, 0)
        prediction_class_all = np.argmax(prediction_all, -1)
        y_onehot_all = np.zeros(prediction_all.shape)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all.astype(int)] = 1
        epoch_auc = roc_auc_score(y_onehot_all, prediction_all)
        c = confusion_matrix(y_all.astype(int), prediction_class_all)
        return epoch_loss / batch_count, epoch_acc / batch_count, epoch_auc, c

    def _test(self, model):
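        """Evaluate a model on the validation split.

        Returns the mean loss, mean accuracy, one-vs-rest AUC, and confusion
        matrix, matching the training helpers.
        """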
        model.eval()
        loss_fn = torch.nn.CrossEntropyLoss()
        data_loader = self.valid_loader

        epoch_loss, epoch_auc = 0, 0
        epoch_acc = 0
        batch_count = 0
        y_all, prediction_all = [], []
        for x, y in data_loader:
            prediction = model(x)
            state_prediction = torch.argmax(prediction, -1)
            loss = loss_fn(prediction, y.long())
            y_all.append(y)
            prediction_all.append(prediction.detach().cpu().numpy())

            epoch_acc += torch.eq(state_prediction, y).sum().item()/len(x)
            epoch_loss += loss.item()
            batch_count += 1
        y_all = np.concatenate(y_all, 0)
        prediction_all = np.concatenate(prediction_all, 0)
        y_onehot_all = np.zeros(prediction_all.shape)
        prediction_class_all = np.argmax(prediction_all, -1)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all.astype(int)] = 1
        epoch_auc = roc_auc_score(y_onehot_all, prediction_all)
        c = confusion_matrix(y_all.astype(int), prediction_class_all)
        return epoch_loss / batch_count, epoch_acc / batch_count, epoch_auc, c

    def run(self, data, n_epochs, lr_e2e, lr_cls=0.01):
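        """Train and evaluate both frameworks for n_epochs.

        Logs metrics every 5 epochs, saves loss/accuracy/AUC comparison plots
        and the encoder's confusion-matrix heatmap, and returns the final
        held-out accuracy and AUC for the TNC and end-to-end models.
        """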
        tnc_acc, tnc_loss, tnc_auc = [], [], []
        etoe_acc, etoe_loss, etoe_auc = [], [], []
        tnc_acc_test, tnc_loss_test, tnc_auc_test = [], [], []
        etoe_acc_test, etoe_loss_test, etoe_auc_test = [], [], []
        for epoch in range(n_epochs):
            loss, acc, auc, _ = self._train_tnc_classifier(lr_cls)
            tnc_acc.append(acc)
            tnc_loss.append(loss)
            tnc_auc.append(auc)
            # The end-to-end baseline is disabled here; zero placeholders keep
            # the bookkeeping and plotting code below aligned.
            # loss, acc, auc, _ = self._train_end_to_end(lr_e2e)
            loss, acc, auc = 0, 0, 0
            etoe_acc.append(acc)
            etoe_loss.append(loss)
            etoe_auc.append(auc)
            # Test
            loss, acc, auc, c_mtx_enc = self._test(model=torch.nn.Sequential(self.encoder, self.classifier))
            tnc_acc_test.append(acc)
            tnc_loss_test.append(loss)
            tnc_auc_test.append(auc)
            # loss, acc, auc, c_mtx_e2e = self._test(model=self.e2e_model)
            loss, acc, auc, c_mtx_e2e = 0, 0, 0, 0
            etoe_acc_test.append(acc)
            etoe_loss_test.append(loss)
            etoe_auc_test.append(auc)

            if epoch % 5 == 0:
                print('***** Epoch %d *****' % epoch)
                print('TNC =====> Training Loss: %.3f \t Training Acc: %.3f \t Training AUC: %.3f '
                      '\t Test Loss: %.3f \t Test Acc: %.3f \t Test AUC: %.3f'
                      % (tnc_loss[-1], tnc_acc[-1], tnc_auc[-1], tnc_loss_test[-1], tnc_acc_test[-1], tnc_auc_test[-1]))
                print('End-to-End =====> Training Loss: %.3f \t Training Acc: %.3f \t Training AUC: %.3f'
                      ' \t Test Loss: %.3f \t Test Acc: %.3f \t Test AUC: %.3f'
                      % (etoe_loss[-1], etoe_acc[-1], etoe_auc[-1], etoe_loss_test[-1], etoe_acc_test[-1], etoe_auc_test[-1]))

        # Save performance plots (create the dataset's plot directory if needed)
        plot_dir = './plots/%s' % data
        if not os.path.exists(plot_dir):
            os.makedirs(plot_dir)
        plt.figure()
        plt.plot(np.arange(n_epochs), tnc_loss_test, label="tnc test")
        plt.plot(np.arange(n_epochs), etoe_loss_test, label="e2e test")
        plt.title("Loss trend for the e2e and TNC frameworks")
        plt.legend()
        plt.savefig(os.path.join(plot_dir, "classification_loss_comparison.pdf"))

        plt.figure()
        plt.plot(np.arange(n_epochs), tnc_acc, label="tnc train")
        plt.plot(np.arange(n_epochs), etoe_acc, label="e2e train")
        plt.plot(np.arange(n_epochs), tnc_acc_test, label="tnc test")
        plt.plot(np.arange(n_epochs), etoe_acc_test, label="e2e test")
        plt.title("Accuracy trend for the e2e and tnc model")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s"%data, "classification_accuracy_comparison_%d.pdf"%self.cv))

        plt.figure()
        plt.plot(np.arange(n_epochs), tnc_auc, label="tnc train")
        plt.plot(np.arange(n_epochs), etoe_auc, label="e2e train")
        plt.plot(np.arange(n_epochs), tnc_auc_test, label="tnc test")
        plt.plot(np.arange(n_epochs), etoe_auc_test, label="e2e test")
        plt.title("AUC trend for the e2e and TNC model")
        plt.legend()
        plt.savefig(os.path.join(plot_dir, "classification_auc_comparison_%d.pdf" % self.cv))

        # Confusion matrix of the TNC encoder + classifier on the held-out set
        df_cm = pd.DataFrame(c_mtx_enc,
                             index=[''] * self.n_states,
                             columns=[''] * self.n_states)
        plt.figure(figsize=(10, 10))
        sns.heatmap(df_cm, annot=True)
        plt.savefig(os.path.join(plot_dir, "encoder_cf_matrix.pdf"))
        return tnc_acc_test[-1], tnc_auc_test[-1], etoe_acc_test[-1], etoe_auc_test[-1]
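

# A minimal CLI sketch for invoking main(); the flag names and defaults below
# are illustrative guesses (the original snippet does not show how main is
# called), chosen to match the (is_train, data_type, cv, w, cont) signature.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Run TNC experiments')
    parser.add_argument('--data', type=str, default='simulation',
                        help="one of 'simulation', 'waveform', or 'har'")
    parser.add_argument('--train', action='store_true',
                        help='train the encoder instead of evaluating it')
    parser.add_argument('--cv', type=int, default=1,
                        help='number of cross-validation runs')
    parser.add_argument('--w', type=float, default=0.05,
                        help='TNC loss weight passed through to learn_encoder')
    parser.add_argument('--cont', action='store_true',
                        help='continue training from an existing checkpoint')
    args = parser.parse_args()
    main(args.train, args.data, args.cv, args.w, args.cont)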