def __init__(self, n_states=4, encoding_size=10, path='simulation', cv=0, hidden_size=100, in_channel=3, window_size=50):
        # Load a pretrained TNC encoder from its checkpoint
        if not os.path.exists("./ckpt/%s/checkpoint_%d.pth.tar"%(path,cv)):
            raise ValueError("No encoder checkpoint found at "
                             "'./ckpt/%s/checkpoint_%d.pth.tar'" % (path, cv))
        checkpoint = torch.load('./ckpt/%s/checkpoint_%d.pth.tar'%(path, cv))
        self.encoder = RnnEncoder(hidden_size=hidden_size, in_channel=in_channel, encoding_size=encoding_size)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.classifier = StateClassifier(input_size=encoding_size, output_size=n_states)

        self.n_states = n_states
        self.cv = cv
        # Build a new encoder to train end-to-end with a classifier
        self.e2e_model = E2EStateClassifier(hidden_size=hidden_size, in_channel=in_channel, encoding_size=encoding_size, output_size=n_states)

        data_path = './data/HAR_data/' if 'har' in path else './data/simulated_data/'
        self.train_loader, self.valid_loader, self.test_loader = create_simulated_dataset(
            window_size=window_size, path=data_path, batch_size=100)
Example #2
def main(is_train, data_type, cv, w, cont):
    if not os.path.exists("./plots"):
        os.mkdir("./plots")
    if not os.path.exists("./ckpt/"):
        os.mkdir("./ckpt/")

    if data_type == 'simulation':
        window_size = 50
        encoder = RnnEncoder(hidden_size=100,
                             in_channel=3,
                             encoding_size=10,
                             device=device)
        path = './data/simulated_data/'

        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            learn_encoder(x,
                          encoder,
                          w=w,
                          lr=1e-3,
                          decay=1e-5,
                          window_size=window_size,
                          n_epochs=100,
                          mc_sample_size=40,
                          path='simulation',
                          device=device,
                          augmentation=5,
                          n_cross_val=cv)
        else:
            # Plot the distribution of the encodings and use the learnt encoders to train a downstream classifier
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            checkpoint = torch.load('./ckpt/%s/checkpoint_0.pth.tar' %
                                    (data_type))
            encoder.load_state_dict(checkpoint['encoder_state_dict'])
            encoder = encoder.to(device)
            track_encoding(x_test[10, :, 50:650], y_test[10, 50:650], encoder,
                           window_size, 'simulation')
            for cv_ind in range(cv):
                plot_distribution(x_test,
                                  y_test,
                                  encoder,
                                  window_size=window_size,
                                  path='simulation',
                                  title='TNC',
                                  device=device,
                                  cv=cv_ind)
                exp = ClassificationPerformanceExperiment(cv=cv_ind)
                # Run cross validation for classification
                for lr in [0.001, 0.01, 0.1]:
                    print('===> lr: ', lr)
                    tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(
                        data='simulation', n_epochs=150, lr_e2e=lr, lr_cls=lr)
                    print(
                        'TNC acc: %.2f \t TNC auc: %.2f \t E2E acc: %.2f \t E2E auc: %.2f'
                        % (tnc_acc, tnc_auc, e2e_acc, e2e_auc))

    if data_type == 'waveform':
        window_size = 2500
        path = './data/waveform_data/processed'
        encoder = WFEncoder(encoding_size=64).to(device)

        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
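            # Split each record along time into five equal, non-overlapping
            # segments (dropping the remainder of T) and stack them as samples.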
            T = x.shape[-1]
            x_window = np.concatenate(np.split(x[:, :, :T // 5 * 5], 5, -1), 0)
            learn_encoder(torch.Tensor(x_window),
                          encoder,
                          w=w,
                          lr=1e-5,
                          decay=1e-4,
                          n_epochs=150,
                          window_size=window_size,
                          path='waveform',
                          mc_sample_size=10,
                          device=device,
                          augmentation=7,
                          n_cross_val=cv,
                          cont=cont)

        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            checkpoint = torch.load('./ckpt/%s/checkpoint_0.pth.tar' %
                                    (data_type))
            encoder.load_state_dict(checkpoint['encoder_state_dict'])
            encoder = encoder.to(device)
            track_encoding(x_test[0, :, 80000:130000],
                           y_test[0, 80000:130000],
                           encoder,
                           window_size,
                           'waveform',
                           sliding_gap=1000)
            for cv_ind in range(cv):
                plot_distribution(x_test,
                                  y_test,
                                  encoder,
                                  window_size=window_size,
                                  path='waveform',
                                  device=device,
                                  augment=100,
                                  cv=cv_ind,
                                  title='TNC')
                exp = WFClassificationExperiment(window_size=window_size,
                                                 cv=cv_ind)
                exp.run(data='waveform', n_epochs=10, lr_e2e=0.0001,
                        lr_cls=0.01)

    if data_type == 'har':
        window_size = 4
        path = './data/HAR_data/'
        encoder = RnnEncoder(hidden_size=100,
                             in_channel=561,
                             encoding_size=10,
                             device=device)

        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            learn_encoder(torch.Tensor(x),
                          encoder,
                          w=w,
                          lr=1e-3,
                          decay=1e-5,
                          n_epochs=150,
                          window_size=window_size,
                          path='har',
                          mc_sample_size=20,
                          device=device,
                          augmentation=5,
                          n_cross_val=cv)

        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            checkpoint = torch.load('./ckpt/%s/checkpoint_0.pth.tar' %
                                    (data_type))
            encoder.load_state_dict(checkpoint['encoder_state_dict'])
            encoder = encoder.to(device)
            track_encoding(x_test[0, :, :], y_test[0, :], encoder, window_size,
                           'har')
            for cv_ind in range(cv):
                plot_distribution(x_test,
                                  y_test,
                                  encoder,
                                  window_size=window_size,
                                  path='har',
                                  device=device,
                                  augment=100,
                                  cv=cv_ind,
                                  title='TNC')
                exp = ClassificationPerformanceExperiment(n_states=6,
                                                          encoding_size=10,
                                                          path='har',
                                                          hidden_size=100,
                                                          in_channel=561,
                                                          window_size=4,
                                                          cv=cv_ind)
                # Run cross validation for classification
                for lr in [0.001, 0.01, 0.1]:
                    print('===> lr: ', lr)
                    tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(data='har',
                                                                 n_epochs=50,
                                                                 lr_e2e=lr,
                                                                 lr_cls=lr)
                    print(
                        'TNC acc: %.2f \t TNC auc: %.2f \t E2E acc: %.2f \t E2E auc: %.2f'
                        % (tnc_acc, tnc_auc, e2e_acc, e2e_auc))
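
For reference, a minimal sketch of a command-line entry point around main(). The flag names and the module-level `device` setup are illustrative assumptions, not the repository's actual parser:

import argparse
import torch

if __name__ == '__main__':
    # Hypothetical CLI wrapper around main(); flag names are illustrative only.
    parser = argparse.ArgumentParser(description='Train or evaluate a TNC encoder')
    parser.add_argument('--data', type=str, default='simulation',
                        choices=['simulation', 'waveform', 'har'])
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--cv', type=int, default=1, help='number of CV folds')
    parser.add_argument('--w', type=float, default=0.05,
                        help='weight of the negative-sample term in the TNC loss')
    parser.add_argument('--cont', action='store_true',
                        help='continue training from an existing checkpoint')
    args = parser.parse_args()

    # main() reads `device` as a module-level global.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    main(args.train, args.data, args.cv, args.w, args.cont)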
Example #3
def learn_encoder(x,
                  encoder,
                  window_size,
                  w,
                  lr=0.001,
                  decay=0.005,
                  mc_sample_size=20,
                  n_epochs=100,
                  path='simulation',
                  device='cpu',
                  augmentation=1,
                  n_cross_val=1,
                  cont=False):
    accuracies, losses = [], []
    for cv in range(n_cross_val):
        if 'waveform' in path:
            encoder = WFEncoder(encoding_size=64).to(device)
            batch_size = 5
        elif 'simulation' in path:
            encoder = RnnEncoder(hidden_size=100,
                                 in_channel=3,
                                 encoding_size=10,
                                 device=device)
            batch_size = 10
        elif 'har' in path:
            encoder = RnnEncoder(hidden_size=100,
                                 in_channel=561,
                                 encoding_size=10,
                                 device=device)
            batch_size = 10
        if not os.path.exists('./ckpt/%s' % path):
            os.mkdir('./ckpt/%s' % path)
        if cont:
            checkpoint = torch.load('./ckpt/%s/checkpoint_%d.pth.tar' %
                                    (path, cv))
            encoder.load_state_dict(checkpoint['encoder_state_dict'])

        disc_model = Discriminator(encoder.encoding_size, device)
        params = list(disc_model.parameters()) + list(encoder.parameters())
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=decay)
        inds = list(range(len(x)))
        random.shuffle(inds)
        x = x[inds]
        n_train = int(0.8 * len(x))
        performance = []
        best_acc = 0
        best_loss = np.inf

        for epoch in range(n_epochs + 1):
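            # TNCDataset samples window positions and temporal neighborhoods, so
            # rebuilding the datasets each epoch trains on freshly drawn windows.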
            trainset = TNCDataset(x=torch.Tensor(x[:n_train]),
                                  mc_sample_size=mc_sample_size,
                                  window_size=window_size,
                                  augmentation=augmentation,
                                  adf=True)
            train_loader = data.DataLoader(trainset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=3)
            validset = TNCDataset(x=torch.Tensor(x[n_train:]),
                                  mc_sample_size=mc_sample_size,
                                  window_size=window_size,
                                  augmentation=augmentation,
                                  adf=True)
            valid_loader = data.DataLoader(validset,
                                           batch_size=batch_size,
                                           shuffle=True)

            epoch_loss, epoch_acc = epoch_run(train_loader,
                                              disc_model,
                                              encoder,
                                              optimizer=optimizer,
                                              w=w,
                                              train=True,
                                              device=device)
            test_loss, test_acc = epoch_run(valid_loader,
                                            disc_model,
                                            encoder,
                                            train=False,
                                            w=w,
                                            device=device)
            performance.append((epoch_loss, test_loss, epoch_acc, test_acc))
            if epoch % 10 == 0:
                print(
                    '(cv:%s)Epoch %d Loss =====> Training Loss: %.5f \t Training Accuracy: %.5f \t Test Loss: %.5f \t Test Accuracy: %.5f'
                    % (cv, epoch, epoch_loss, epoch_acc, test_loss, test_acc))
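            # Keep the checkpoint with the best validation loss; for 'har' the
            # condition is always true, so the latest epoch's weights are saved.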
            if best_loss > test_loss or path == 'har':
                best_acc = test_acc
                best_loss = test_loss
                state = {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict(),
                    'discriminator_state_dict': disc_model.state_dict(),
                    'best_accuracy': test_acc
                }
                torch.save(state,
                           './ckpt/%s/checkpoint_%d.pth.tar' % (path, cv))
        accuracies.append(best_acc)
        losses.append(best_loss)
        # Save performance plots
        if not os.path.exists('./plots/%s' % path):
            os.mkdir('./plots/%s' % path)
        train_loss = [t[0] for t in performance]
        test_loss = [t[1] for t in performance]
        train_acc = [t[2] for t in performance]
        test_acc = [t[3] for t in performance]
        plt.figure()
        plt.plot(np.arange(n_epochs + 1), train_loss, label="Train")
        plt.plot(np.arange(n_epochs + 1), test_loss, label="Test")
        plt.title("Loss")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s" % path, "loss_%d.pdf" % cv))
        plt.figure()
        plt.plot(np.arange(n_epochs + 1), train_acc, label="Train")
        plt.plot(np.arange(n_epochs + 1), test_acc, label="Test")
        plt.title("Accuracy")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s" % path, "accuracy_%d.pdf" % cv))

    print('=======> Performance Summary:')
    print('Accuracy: %.2f +- %.2f' %
          (100 * np.mean(accuracies), 100 * np.std(accuracies)))
    print('Loss: %.4f +- %.4f' % (np.mean(losses), np.std(losses)))
    return encoder
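
A minimal sketch of calling learn_encoder directly on the simulated data; the value of w (the weight on non-neighboring samples in the TNC loss) is illustrative:

import pickle
import torch

with open('./data/simulated_data/x_train.pkl', 'rb') as f:
    x = pickle.load(f)  # expected shape: (n_samples, in_channel, T)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = RnnEncoder(hidden_size=100, in_channel=3, encoding_size=10, device=device)
# learn_encoder re-instantiates the encoder per fold based on `path`.
encoder = learn_encoder(torch.Tensor(x), encoder, window_size=50, w=0.05,
                        lr=1e-3, decay=1e-5, mc_sample_size=40, n_epochs=100,
                        path='simulation', device=device, augmentation=5,
                        n_cross_val=1)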
Example #4

def main(is_train, data, cv):
    if not os.path.exists("./plots"):
        os.mkdir("./plots")
    if not os.path.exists("./ckpt/"):
        os.mkdir("./ckpt/")

    if data == 'waveform':
        path = './data/waveform_data/processed'
        window_size = 2500
        encoder = WFEncoder(encoding_size=64).to(device)
        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            T = x.shape[-1]
            x_window = np.concatenate(np.split(x[:, :, :T // 5 * 5], 5, -1), 0)
            learn_encoder(x_window,
                          window_size,
                          n_epochs=150,
                          lr=1e-4,
                          decay=1e-4,
                          data='waveform',
                          n_cross_val=cv)
        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            for cv_ind in range(cv):
                plot_distribution(x_test,
                                  y_test,
                                  encoder,
                                  window_size=window_size,
                                  path='%s_trip' % data,
                                  device=device,
                                  augment=100,
                                  cv=cv_ind,
                                  title='Triplet Loss')
            # exp = WFClassificationExperiment(window_size=window_size, data='waveform_trip')
            # exp.run(data='waveform_trip', n_epochs=15, lr_e2e=0.001, lr_cls=0.001)

    elif data == 'simulation':
        path = './data/simulated_data/'
        window_size = 50
        encoder = RnnEncoder(hidden_size=100,
                             in_channel=3,
                             encoding_size=10,
                             device=device).to(device)
        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            learn_encoder(x,
                          window_size,
                          lr=1e-3,
                          decay=1e-5,
                          data=data,
                          n_epochs=150,
                          device=device,
                          n_cross_val=cv)
        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            for cv_ind in range(cv):
                plot_distribution(x_test,
                                  y_test,
                                  encoder,
                                  window_size=window_size,
                                  path='%s_trip' % data,
                                  title='Triplet Loss',
                                  device=device,
                                  cv=cv_ind)
                exp = ClassificationPerformanceExperiment(
                    path='simulation_trip', cv=cv_ind)
                # Run cross validation for classification
                for lr in [0.001, 0.01, 0.1]:
                    print('===> lr: ', lr)
                    tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(
                        data='simulation_trip',
                        n_epochs=50,
                        lr_e2e=lr,
                        lr_cls=lr)
                    print(
                        'TNC acc: %.2f \t TNC auc: %.2f \t E2E acc: %.2f \t E2E auc: %.2f'
                        % (tnc_acc, tnc_auc, e2e_acc, e2e_auc))

    elif data == 'har':
        window_size = 4
        path = './data/HAR_data/'
        encoder = RnnEncoder(hidden_size=100,
                             in_channel=561,
                             encoding_size=10,
                             device=device)

        if is_train:
            with open(os.path.join(path, 'x_train.pkl'), 'rb') as f:
                x = pickle.load(f)
            learn_encoder(x,
                          window_size,
                          lr=1e-5,
                          decay=0.001,
                          data=data,
                          n_epochs=300,
                          device=device,
                          n_cross_val=cv)
        else:
            with open(os.path.join(path, 'x_test.pkl'), 'rb') as f:
                x_test = pickle.load(f)
            with open(os.path.join(path, 'state_test.pkl'), 'rb') as f:
                y_test = pickle.load(f)
            for cv_ind in range(cv):
                plot_distribution(x_test,
                                  y_test,
                                  encoder,
                                  window_size=window_size,
                                  path='har_trip',
                                  device=device,
                                  augment=100,
                                  cv=cv_ind,
                                  title='Triplet Loss')
                exp = ClassificationPerformanceExperiment(n_states=6,
                                                          encoding_size=10,
                                                          path='har_trip',
                                                          hidden_size=100,
                                                          in_channel=561,
                                                          window_size=5,
                                                          cv=cv_ind)
                # Run cross validation for classification
                for lr in [0.001, 0.01, 0.1]:
                    print('===> lr: ', lr)
                    tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(
                        data='har_trip', n_epochs=100, lr_e2e=lr, lr_cls=lr)
                    print(
                        'TNC acc: %.2f \t TNC auc: %.2f \t E2E acc: %.2f \t E2E auc: %.2f'
                        % (tnc_acc, tnc_auc, e2e_acc, e2e_auc))
Example #5

def learn_encoder(x,
                  window_size,
                  data,
                  lr=0.001,
                  decay=0,
                  n_epochs=100,
                  device='cpu',
                  n_cross_val=1):
    if not os.path.exists("./plots/%s_trip/" % data):
        os.mkdir("./plots/%s_trip/" % data)
    if not os.path.exists("./ckpt/%s_trip/" % data):
        os.mkdir("./ckpt/%s_trip/" % data)
    for cv in range(n_cross_val):
        if 'waveform' in data:
            encoder = WFEncoder(encoding_size=64).to(device)
        elif 'simulation' in data:
            encoder = RnnEncoder(hidden_size=100,
                                 in_channel=3,
                                 encoding_size=10,
                                 device=device).to(device)
        elif 'har' in data:
            encoder = RnnEncoder(hidden_size=100,
                                 in_channel=561,
                                 encoding_size=10,
                                 device=device).to(device)

        params = encoder.parameters()
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=decay)
        inds = list(range(len(x)))
        random.shuffle(inds)
        x = x[inds]
        n_train = int(0.8 * len(x))
        train_loss, test_loss = [], []
        best_loss = np.inf
        for epoch in range(n_epochs):
            epoch_loss, acc = epoch_run(x[:n_train],
                                        encoder,
                                        device,
                                        window_size,
                                        optimizer=optimizer,
                                        train=True)
            epoch_loss_test, acc_test = epoch_run(x[n_train:],
                                                  encoder,
                                                  device,
                                                  window_size,
                                                  optimizer=optimizer,
                                                  train=False)
            print('\nEpoch ', epoch)
            print('Train ===> Loss: ', epoch_loss)
            print('Test ===> Loss: ', epoch_loss_test)
            train_loss.append(epoch_loss)
            test_loss.append(epoch_loss_test)
            if epoch_loss_test < best_loss:
                print('Save new ckpt')
                state = {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict()
                }
                best_loss = epoch_loss_test
                torch.save(state,
                           './ckpt/%s_trip/checkpoint_%d.pth.tar' % (data, cv))
        plt.figure()
        plt.plot(np.arange(n_epochs), train_loss, label="Train")
        plt.plot(np.arange(n_epochs), test_loss, label="Test")
        plt.title("Loss")
        plt.legend()
        plt.savefig("./plots/%s_trip/loss_%d.pdf" % (data, cv))
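
A hedged invocation sketch for this triplet-loss variant; note that ./plots and ./ckpt must already exist, since only the *_trip subdirectories are created here:

import pickle
import torch

with open('./data/simulated_data/x_train.pkl', 'rb') as f:
    x = pickle.load(f)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
learn_encoder(x, window_size=50, data='simulation', lr=1e-3, decay=1e-5,
              n_epochs=150, device=device, n_cross_val=1)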
Example #6
def learn_encoder(x,
                  window_size,
                  lr=0.001,
                  decay=0,
                  n_size=5,
                  n_epochs=50,
                  data='simulation',
                  device='cpu',
                  n_cross_val=1):
    if not os.path.exists("./plots/%s_cpc/" % data):
        os.mkdir("./plots/%s_cpc/" % data)
    if not os.path.exists("./ckpt/%s_cpc/" % data):
        os.mkdir("./ckpt/%s_cpc/" % data)
    accuracies = []
    for cv in range(n_cross_val):
        if 'waveform' in data:
            encoding_size = 64
            encoder = WFEncoder(encoding_size=64).to(device)
        elif 'simulation' in data:
            encoding_size = 10
            encoder = RnnEncoder(hidden_size=100,
                                 in_channel=3,
                                 encoding_size=10,
                                 device=device)
        elif 'har' in data:
            encoding_size = 10
            encoder = RnnEncoder(hidden_size=100,
                                 in_channel=561,
                                 encoding_size=10,
                                 device=device)
        ds_estimator = torch.nn.Linear(encoder.encoding_size,
                                       encoder.encoding_size)
        auto_regressor = torch.nn.GRU(input_size=encoding_size,
                                      hidden_size=encoding_size,
                                      batch_first=True)
        params = list(ds_estimator.parameters()) + list(
            encoder.parameters()) + list(auto_regressor.parameters())
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=decay)
        inds = list(range(len(x)))
        random.shuffle(inds)
        x = x[inds]
        n_train = int(0.8 * len(x))
        best_acc = 0
        best_loss = np.inf
        train_loss, test_loss = [], []
        for epoch in range(n_epochs):
            epoch_loss, acc = epoch_run(x[:n_train],
                                        ds_estimator,
                                        auto_regressor,
                                        encoder,
                                        device,
                                        window_size,
                                        optimizer=optimizer,
                                        n_size=n_size,
                                        train=True)
            epoch_loss_test, acc_test = epoch_run(x[n_train:],
                                                  ds_estimator,
                                                  auto_regressor,
                                                  encoder,
                                                  device,
                                                  window_size,
                                                  n_size=n_size,
                                                  train=False)
            print('\nEpoch ', epoch)
            print('Train ===> Loss: ', epoch_loss, '\t Accuracy: ', acc)
            print('Test ===> Loss: ', epoch_loss_test, '\t Accuracy: ',
                  acc_test)
            train_loss.append(epoch_loss)
            test_loss.append(epoch_loss_test)
            if epoch_loss_test < best_loss:
                print('Save new ckpt')
                state = {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict()
                }
                best_loss = epoch_loss_test
                best_acc = acc_test
                torch.save(state,
                           './ckpt/%s_cpc/checkpoint_%d.pth.tar' % (data, cv))
        accuracies.append(best_acc)
        plt.figure()
        plt.plot(np.arange(n_epochs), train_loss, label="Train")
        plt.plot(np.arange(n_epochs), test_loss, label="Test")
        plt.title("CPC Loss")
        plt.legend()
        plt.savefig("./plots/%s_cpc/loss_%d.pdf" % (data, cv))
    print('=======> Performance Summary:')
    print('Accuracy: %.2f +- %.2f' %
          (100 * np.mean(accuracies), 100 * np.std(accuracies)))
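
The epoch_run used here is not shown in this example. Below is a minimal sketch of the CPC step these components imply: encode consecutive windows, summarize the past with the GRU, predict the next encoding with the linear ds_estimator, and score it against in-sequence negatives with an InfoNCE loss. cpc_step is a hypothetical helper name, and the repository's actual window sampling and batching may differ:

import torch
import torch.nn.functional as F

def cpc_step(seq, encoder, auto_regressor, ds_estimator, window_size, n_size, device):
    # seq: one series of shape (in_channel, T).
    # Encode n_size consecutive, non-overlapping windows.
    windows = torch.stack([seq[:, i * window_size:(i + 1) * window_size]
                           for i in range(n_size)]).to(device)  # (n_size, C, W)
    z = encoder(windows)                                         # (n_size, encoding_size)
    # Summarize all but the last encoding into a context vector.
    _, h = auto_regressor(z[:-1].unsqueeze(0))                   # h: (1, 1, encoding_size)
    pred = ds_estimator(h.squeeze())                             # predicted next encoding
    # InfoNCE: the true next encoding (index n_size - 1) should score
    # highest among all encodings of this sequence.
    scores = torch.matmul(z, pred).unsqueeze(0)                  # (1, n_size)
    target = torch.tensor([n_size - 1], device=device)
    return F.cross_entropy(scores, target)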
Example #7

class ClassificationPerformanceExperiment:
    def __init__(self, n_states=4, encoding_size=10, path='simulation', cv=0, hidden_size=100, in_channel=3, window_size=50):
        # Load a pretrained TNC encoder from its checkpoint
        if not os.path.exists("./ckpt/%s/checkpoint_%d.pth.tar"%(path,cv)):
            raise ValueError("No encoder checkpoint found at "
                             "'./ckpt/%s/checkpoint_%d.pth.tar'" % (path, cv))
        checkpoint = torch.load('./ckpt/%s/checkpoint_%d.pth.tar'%(path, cv))
        self.encoder = RnnEncoder(hidden_size=hidden_size, in_channel=in_channel, encoding_size=encoding_size)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.classifier = StateClassifier(input_size=encoding_size, output_size=n_states)

        self.n_states = n_states
        self.cv = cv
        # Build a new encoder to train end-to-end with a classifier
        self.e2e_model = E2EStateClassifier(hidden_size=hidden_size, in_channel=in_channel, encoding_size=encoding_size, output_size=n_states)

        data_path = './data/HAR_data/' if 'har' in path else './data/simulated_data/'
        self.train_loader, self.valid_loader, self.test_loader = create_simulated_dataset(
            window_size=window_size, path=data_path, batch_size=100)

    def _train_end_to_end(self, lr):
        self.e2e_model.train()
        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.e2e_model.parameters(), lr=lr)

        epoch_loss, epoch_auc = 0, 0
        epoch_acc = 0
        batch_count = 0
        y_all, prediction_all = [], []
        for i, (x, y) in enumerate(self.train_loader):
            # if i>5:
            #     break
            optimizer.zero_grad()
            prediction = self.e2e_model(x)
            state_prediction = torch.argmax(prediction, dim=1)
            loss = loss_fn(prediction, y.long())
            loss.backward()
            optimizer.step()
            y_all.append(y)
            prediction_all.append(prediction.detach().cpu().numpy())
            epoch_acc += torch.eq(state_prediction, y).sum().item()/len(x)
            epoch_loss += loss.item()
            batch_count += 1
        y_all = np.concatenate(y_all, 0)
        prediction_all = np.concatenate(prediction_all, 0)
        prediction_class_all = np.argmax(prediction_all, -1)
        y_onehot_all = np.zeros(prediction_all.shape)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all.astype(int)] = 1
        epoch_auc = roc_auc_score(y_onehot_all, prediction_all)
        c = confusion_matrix(y_all.astype(int), prediction_class_all)
        return epoch_loss / batch_count, epoch_acc / batch_count, epoch_auc, c

    def _train_tnc_classifier(self, lr):
        self.classifier.train()
        self.encoder.eval()
        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.classifier.parameters(), lr=lr)

        epoch_loss, epoch_auc = 0, 0
        epoch_acc = 0
        batch_count = 0
        y_all, prediction_all = [], []
        for i, (x, y) in enumerate(self.train_loader):
            if i > 30:
                break
            optimizer.zero_grad()
            encodings = self.encoder(x)
            prediction = self.classifier(encodings)
            state_prediction = torch.argmax(prediction, dim=1)
            loss = loss_fn(prediction, y.long())
            loss.backward()
            optimizer.step()
            y_all.append(y)
            prediction_all.append(prediction.detach().cpu().numpy())

            epoch_acc += torch.eq(state_prediction, y).sum().item()/len(x)
            epoch_loss += loss.item()
            batch_count += 1
        y_all = np.concatenate(y_all, 0)
        prediction_all = np.concatenate(prediction_all, 0)
        prediction_class_all = np.argmax(prediction_all, -1)
        y_onehot_all = np.zeros(prediction_all.shape)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all.astype(int)] = 1
        epoch_auc = roc_auc_score(y_onehot_all, prediction_all)
        c = confusion_matrix(y_all.astype(int), prediction_class_all)
        return epoch_loss / batch_count, epoch_acc / batch_count, epoch_auc, c

    def _test(self, model):
        model.eval()
        loss_fn = torch.nn.CrossEntropyLoss()
        data_loader = self.valid_loader

        epoch_loss, epoch_auc = 0, 0
        epoch_acc = 0
        batch_count = 0
        y_all, prediction_all = [], []
        for x, y in data_loader:
            prediction = model(x)
            state_prediction = torch.argmax(prediction, -1)
            loss = loss_fn(prediction, y.long())
            y_all.append(y)
            prediction_all.append(prediction.detach().cpu().numpy())

            epoch_acc += torch.eq(state_prediction, y).sum().item()/len(x)
            epoch_loss += loss.item()
            batch_count += 1
        y_all = np.concatenate(y_all, 0)
        prediction_all = np.concatenate(prediction_all, 0)
        y_onehot_all = np.zeros(prediction_all.shape)
        prediction_class_all = np.argmax(prediction_all, -1)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all.astype(int)] = 1
        epoch_auc = roc_auc_score(y_onehot_all, prediction_all)
        c = confusion_matrix(y_all.astype(int), prediction_class_all)
        return epoch_loss / batch_count, epoch_acc / batch_count, epoch_auc, c

    def run(self, data, n_epochs, lr_e2e, lr_cls=0.01):
        tnc_acc, tnc_loss, tnc_auc = [], [], []
        etoe_acc, etoe_loss, etoe_auc = [], [], []
        tnc_acc_test, tnc_loss_test, tnc_auc_test = [], [], []
        etoe_acc_test, etoe_loss_test, etoe_auc_test = [], [], []
        for epoch in range(n_epochs):
            loss, acc, auc, _ = self._train_tnc_classifier(lr_cls)
            tnc_acc.append(acc)
            tnc_loss.append(loss)
            tnc_auc.append(auc)
            # loss, acc, auc, _ = self._train_end_to_end(lr_e2e)
            loss, acc, auc = 0, 0, 0
            etoe_acc.append(acc)
            etoe_loss.append(loss)
            etoe_auc.append(auc)
            # Test
            loss, acc, auc, c_mtx_enc = self._test(model=torch.nn.Sequential(self.encoder, self.classifier))
            tnc_acc_test.append(acc)
            tnc_loss_test.append(loss)
            tnc_auc_test.append(auc)
            # loss, acc, auc, c_mtx_e2e = self._test(model=self.e2e_model) #torch.nn.Sequential(self.encoder, self.classifier))
            loss, acc, auc, c_mtx_e2e = 0, 0, 0, 0
            etoe_acc_test.append(acc)
            etoe_loss_test.append(loss)
            etoe_auc_test.append(auc)

            if epoch % 5 == 0:
                print('***** Epoch %d *****' % epoch)
                print('TNC =====> Training Loss: %.3f \t Training Acc: %.3f \t Training AUC: %.3f '
                      '\t Test Loss: %.3f \t Test Acc: %.3f \t Test AUC: %.3f'
                      % (tnc_loss[-1], tnc_acc[-1], tnc_auc[-1], tnc_loss_test[-1], tnc_acc_test[-1], tnc_auc_test[-1]))
                print('End-to-End =====> Training Loss: %.3f \t Training Acc: %.3f \t Training AUC: %.3f'
                      ' \t Test Loss: %.3f \t Test Acc: %.3f \t Test AUC: %.3f'
                      % (etoe_loss[-1], etoe_acc[-1], etoe_auc[-1], etoe_loss_test[-1], etoe_acc_test[-1], etoe_auc_test[-1]))

        # Save performance plots
        plt.figure()
        plt.plot(np.arange(n_epochs), tnc_loss_test, label="tnc test")
        plt.plot(np.arange(n_epochs), etoe_loss_test, label="e2e test")
        plt.title("Loss trend for the e2e and tnc framework")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s"%data, "classification_loss_comparison.pdf"))

        plt.figure()
        plt.plot(np.arange(n_epochs), tnc_acc, label="tnc train")
        plt.plot(np.arange(n_epochs), etoe_acc, label="e2e train")
        plt.plot(np.arange(n_epochs), tnc_acc_test, label="tnc test")
        plt.plot(np.arange(n_epochs), etoe_acc_test, label="e2e test")
        plt.title("Accuracy trend for the e2e and tnc model")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s"%data, "classification_accuracy_comparison_%d.pdf"%self.cv))

        plt.figure()
        plt.plot(np.arange(n_epochs), tnc_auc, label="tnc train")
        plt.plot(np.arange(n_epochs), etoe_auc, label="e2e train")
        plt.plot(np.arange(n_epochs), tnc_auc_test, label="tnc test")
        plt.plot(np.arange(n_epochs), etoe_auc_test, label="e2e test")
        plt.title("AUC trend for the e2e and TNC model")
        plt.legend()
        plt.savefig(os.path.join("./plots/%s" % data, "classification_auc_comparison_%d.pdf"%self.cv))

        df_cm = pd.DataFrame(c_mtx_enc, index=[''] * self.n_states,
                             columns=[''] * self.n_states)
        plt.figure(figsize=(10, 10))
        sns.heatmap(df_cm, annot=True)
        plt.savefig(os.path.join("./plots/%s"%data, "encoder_cf_matrix.pdf"))
        return tnc_acc_test[-1], tnc_auc_test[-1], etoe_acc_test[-1], etoe_auc_test[-1]
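
A minimal usage sketch for this class, assuming a TNC encoder checkpoint already exists at ./ckpt/simulation/checkpoint_0.pth.tar. Note that the end-to-end branch is commented out in run() above, so the returned e2e metrics are zeros:

exp = ClassificationPerformanceExperiment(n_states=4, encoding_size=10,
                                          path='simulation', cv=0)
tnc_acc, tnc_auc, e2e_acc, e2e_auc = exp.run(data='simulation', n_epochs=50,
                                             lr_e2e=0.01, lr_cls=0.01)
print('TNC test acc: %.2f \t TNC test auc: %.2f' % (tnc_acc, tnc_auc))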
Example #8
def run_test(data, e2e_lr, tnc_lr, cpc_lr, trip_lr, data_path, window_size,
             n_cross_val):
    # Load data
    with open(os.path.join(data_path, 'x_train.pkl'), 'rb') as f:
        x = pickle.load(f)
    with open(os.path.join(data_path, 'state_train.pkl'), 'rb') as f:
        y = pickle.load(f)
    with open(os.path.join(data_path, 'x_test.pkl'), 'rb') as f:
        x_test = pickle.load(f)
    with open(os.path.join(data_path, 'state_test.pkl'), 'rb') as f:
        y_test = pickle.load(f)
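    # Segment every series into non-overlapping windows of length window_size
    # (dropping any remainder) and label each window with its majority state.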
    T = x.shape[-1]
    x_window = np.split(x[:, :, :window_size * (T // window_size)],
                        (T // window_size), -1)
    y_window = np.concatenate(
        np.split(y[:, :window_size * (T // window_size)], (T // window_size),
                 -1), 0).astype(int)
    x_window = torch.Tensor(np.concatenate(x_window, 0))
    y_window = torch.Tensor(
        np.array([np.bincount(yy).argmax() for yy in y_window]))

    x_window_test = np.split(x_test[:, :, :window_size * (T // window_size)],
                             (T // window_size), -1)
    y_window_test = np.concatenate(
        np.split(y_test[:, :window_size * (T // window_size)],
                 (T // window_size), -1), 0).astype(int)
    x_window_test = torch.Tensor(np.concatenate(x_window_test, 0))
    y_window_test = torch.Tensor(
        np.array([np.bincount(yy).argmax() for yy in y_window_test]))

    testset = torch.utils.data.TensorDataset(x_window_test, y_window_test)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=100,
                                              shuffle=True)

    del x, y, x_test, y_test
    e2e_accs, e2e_aucs, e2e_auprcs = [], [], []
    tnc_accs, tnc_aucs, tnc_auprcs = [], [], []
    cpc_accs, cpc_aucs, cpc_auprcs = [], [], []
    trip_accs, trip_aucs, trip_auprcs = [], [], []
    for cv in range(n_cross_val):
        shuffled_inds = list(range(len(x_window)))
        random.shuffle(shuffled_inds)
        x_window = x_window[shuffled_inds]
        y_window = y_window[shuffled_inds]
        n_train = int(0.7 * len(x_window))
        X_train, X_test = x_window[:n_train], x_window[n_train:]
        y_train, y_test = y_window[:n_train], y_window[n_train:]

        trainset = torch.utils.data.TensorDataset(X_train, y_train)
        validset = torch.utils.data.TensorDataset(X_test, y_test)

        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=200,
                                                   shuffle=False)
        valid_loader = torch.utils.data.DataLoader(validset,
                                                   batch_size=200,
                                                   shuffle=False)

        # Define baseline models
        if data == 'waveform':
            encoding_size = 64
            n_classes = 4

            e2e_model = WFEncoder(encoding_size=encoding_size,
                                  classify=True,
                                  n_classes=n_classes).to(device)

            tnc_encoder = WFEncoder(encoding_size=encoding_size).to(device)
            if not os.path.exists(
                    './ckpt/waveform/checkpoint_%d.pth.tar' % cv):
                raise RuntimeError(
                    'Checkpoint for TNC encoder does not exist!')
            tnc_checkpoint = torch.load(
                './ckpt/waveform/checkpoint_%d.pth.tar' % cv)
            tnc_encoder.load_state_dict(tnc_checkpoint['encoder_state_dict'])
            tnc_classifier = WFClassifier(encoding_size=encoding_size,
                                          output_size=4)
            tnc_model = torch.nn.Sequential(tnc_encoder,
                                            tnc_classifier).to(device)

            cpc_encoder = WFEncoder(encoding_size=encoding_size).to(device)
            if not os.path.exists(
                    './ckpt/waveform_cpc/checkpoint_%d.pth.tar' % cv):
                raise RuntimeError(
                    'Checkpoint for CPC encoder does not exist!')
            cpc_checkpoint = torch.load(
                './ckpt/waveform_cpc/checkpoint_%d.pth.tar' % cv)
            cpc_encoder.load_state_dict(cpc_checkpoint['encoder_state_dict'])
            cpc_classifier = WFClassifier(encoding_size=encoding_size,
                                          output_size=4)
            cpc_model = torch.nn.Sequential(cpc_encoder,
                                            cpc_classifier).to(device)

            trip_encoder = WFEncoder(encoding_size=encoding_size).to(device)
            if not os.path.exists(
                    './ckpt/waveform_trip/checkpoint_%d.pth.tar' % cv):
                raise RuntimeError(
                    'Checkpoint for Triplet Loss encoder does not exist!')
            trip_checkpoint = torch.load(
                './ckpt/waveform_trip/checkpoint_%d.pth.tar' % cv)
            trip_encoder.load_state_dict(trip_checkpoint['encoder_state_dict'])
            trip_classifier = WFClassifier(encoding_size=encoding_size,
                                           output_size=4)
            trip_model = torch.nn.Sequential(trip_encoder,
                                             trip_classifier).to(device)
            n_epochs = 8
            n_epoch_e2e = 8

        elif data == 'simulation':
            encoding_size = 10
            e2e_model = E2EStateClassifier(hidden_size=100,
                                           in_channel=3,
                                           encoding_size=encoding_size,
                                           output_size=4,
                                           device=device)

            tnc_encoder = RnnEncoder(hidden_size=100,
                                     in_channel=3,
                                     encoding_size=encoding_size,
                                     device=device)
            tnc_checkpoint = torch.load(
                './ckpt/simulation/checkpoint_%d.pth.tar' % cv)
            tnc_encoder.load_state_dict(tnc_checkpoint['encoder_state_dict'])
            tnc_classifier = StateClassifier(input_size=encoding_size,
                                             output_size=4).to(device)
            tnc_model = torch.nn.Sequential(tnc_encoder,
                                            tnc_classifier).to(device)

            cpc_encoder = RnnEncoder(hidden_size=100,
                                     in_channel=3,
                                     encoding_size=encoding_size,
                                     device=device)
            cpc_checkpoint = torch.load(
                './ckpt/simulation_cpc/checkpoint_%d.pth.tar' % cv)
            cpc_encoder.load_state_dict(cpc_checkpoint['encoder_state_dict'])
            cpc_classifier = StateClassifier(input_size=encoding_size,
                                             output_size=4).to(device)
            cpc_model = torch.nn.Sequential(cpc_encoder,
                                            cpc_classifier).to(device)

            trip_encoder = RnnEncoder(hidden_size=100,
                                      in_channel=3,
                                      encoding_size=encoding_size,
                                      device=device)
            trip_checkpoint = torch.load(
                './ckpt/simulation_trip/checkpoint_%d.pth.tar' % cv)
            trip_encoder.load_state_dict(trip_checkpoint['encoder_state_dict'])
            trip_classifier = StateClassifier(input_size=encoding_size,
                                              output_size=4).to(device)
            trip_model = torch.nn.Sequential(trip_encoder,
                                             trip_classifier).to(device)
            n_epochs = 30
            n_epoch_e2e = 100

        elif data == 'har':
            encoding_size = 10
            e2e_model = E2EStateClassifier(hidden_size=100,
                                           in_channel=561,
                                           encoding_size=encoding_size,
                                           output_size=6,
                                           device=device)

            tnc_encoder = RnnEncoder(hidden_size=100,
                                     in_channel=561,
                                     encoding_size=encoding_size,
                                     device=device)
            tnc_checkpoint = torch.load('./ckpt/har/checkpoint_%d.pth.tar' %
                                        cv)
            tnc_encoder.load_state_dict(tnc_checkpoint['encoder_state_dict'])
            tnc_classifier = StateClassifier(input_size=encoding_size,
                                             output_size=6).to(device)
            tnc_model = torch.nn.Sequential(tnc_encoder,
                                            tnc_classifier).to(device)

            cpc_encoder = RnnEncoder(hidden_size=100,
                                     in_channel=561,
                                     encoding_size=encoding_size,
                                     device=device)
            cpc_checkpoint = torch.load(
                './ckpt/har_cpc/checkpoint_%d.pth.tar' % cv)
            cpc_encoder.load_state_dict(cpc_checkpoint['encoder_state_dict'])
            cpc_classifier = StateClassifier(input_size=encoding_size,
                                             output_size=6).to(device)
            cpc_model = torch.nn.Sequential(cpc_encoder,
                                            cpc_classifier).to(device)

            trip_encoder = RnnEncoder(hidden_size=100,
                                      in_channel=561,
                                      encoding_size=encoding_size,
                                      device=device)
            trip_checkpoint = torch.load(
                './ckpt/har_trip/checkpoint_%d.pth.tar' % cv)
            trip_encoder.load_state_dict(trip_checkpoint['encoder_state_dict'])
            trip_classifier = StateClassifier(input_size=encoding_size,
                                              output_size=6).to(device)
            trip_model = torch.nn.Sequential(trip_encoder,
                                             trip_classifier).to(device)
            n_epochs = 50
            n_epoch_e2e = 100

        # Train the model
        # ***** E2E *****
        best_acc_e2e, best_auc_e2e, best_auprc_e2e = train(
            train_loader,
            valid_loader,
            e2e_model,
            e2e_lr,
            data_type=data,
            n_epochs=n_epoch_e2e,
            type='e2e',
            cv=cv)
        print('E2E: ', best_acc_e2e * 100, best_auc_e2e, best_auprc_e2e)
        # ***** TNC *****
        best_acc_tnc, best_auc_tnc, best_auprc_tnc = train(train_loader,
                                                           valid_loader,
                                                           tnc_classifier,
                                                           tnc_lr,
                                                           encoder=tnc_encoder,
                                                           data_type=data,
                                                           n_epochs=n_epochs,
                                                           type='tnc',
                                                           cv=cv)
        print('TNC: ', best_acc_tnc * 100, best_auc_tnc, best_auprc_tnc)
        # ***** CPC *****
        best_acc_cpc, best_auc_cpc, best_auprc_cpc = train(train_loader,
                                                           valid_loader,
                                                           cpc_classifier,
                                                           cpc_lr,
                                                           encoder=cpc_encoder,
                                                           data_type=data,
                                                           n_epochs=n_epochs,
                                                           type='cpc',
                                                           cv=cv)
        print('CPC: ', best_acc_cpc * 100, best_auc_cpc, best_auprc_cpc)
        # ***** Trip *****
        best_acc_trip, best_auc_trip, best_auprc_trip = train(
            train_loader,
            valid_loader,
            trip_classifier,
            trip_lr,
            encoder=trip_encoder,
            data_type=data,
            n_epochs=n_epochs,
            type='trip',
            cv=cv)
        print('TRIP: ', best_acc_trip * 100, best_auc_trip, best_auprc_trip)

        if data == 'waveform':
            # The waveform dataset is small and highly imbalanced; if a class has
            # no samples in the held-out test set, report validation performance instead.
            _, test_acc_e2e, test_auc_e2e, test_auprc_e2e, _ = epoch_run(
                e2e_model, dataloader=valid_loader, train=False)
            _, test_acc_tnc, test_auc_tnc, test_auprc_tnc, _ = epoch_run_encoder(
                tnc_encoder,
                tnc_classifier,
                dataloader=valid_loader,
                train=False)
            _, test_acc_cpc, test_auc_cpc, test_auprc_cpc, _ = epoch_run_encoder(
                cpc_encoder,
                cpc_classifier,
                dataloader=valid_loader,
                train=False)
            _, test_acc_trip, test_auc_trip, test_auprc_trip, _ = epoch_run_encoder(
                trip_encoder,
                trip_classifier,
                dataloader=valid_loader,
                train=False)
        else:
            _, test_acc_e2e, test_auc_e2e, test_auprc_e2e, _ = epoch_run(
                e2e_model, dataloader=test_loader, train=False)
            _, test_acc_tnc, test_auc_tnc, test_auprc_tnc, _ = epoch_run_encoder(
                tnc_encoder,
                tnc_classifier,
                dataloader=test_loader,
                train=False)
            _, test_acc_cpc, test_auc_cpc, test_auprc_cpc, _ = epoch_run_encoder(
                cpc_encoder,
                cpc_classifier,
                dataloader=test_loader,
                train=False)
            _, test_acc_trip, test_auc_trip, test_auprc_trip, _ = epoch_run_encoder(
                trip_encoder,
                trip_classifier,
                dataloader=test_loader,
                train=False)

        e2e_accs.append(test_acc_e2e)
        e2e_aucs.append(test_auc_e2e)
        e2e_auprcs.append(test_auprc_e2e)
        tnc_accs.append(test_acc_tnc)
        tnc_aucs.append(test_auc_tnc)
        tnc_auprcs.append(test_auprc_tnc)
        cpc_accs.append(test_acc_cpc)
        cpc_aucs.append(test_auc_cpc)
        cpc_auprcs.append(test_auprc_cpc)
        trip_accs.append(test_acc_trip)
        trip_aucs.append(test_auc_trip)
        trip_auprcs.append(test_auprc_trip)

        with open("./outputs/%s_classifiers.txt" % data, "a") as f:
            f.write("\n\nPerformance result for a fold")
            f.write("End-to-End model: \t AUC: %s\t Accuracy: %s \n\n" %
                    (str(best_auc_e2e), str(100 * best_acc_e2e)))
            f.write("TNC model: \t AUC: %s\t Accuracy: %s \n\n" %
                    (str(best_auc_tnc), str(100 * best_acc_tnc)))
            f.write("CPC model: \t AUC: %s\t Accuracy: %s \n\n" %
                    (str(best_auc_cpc), str(100 * best_acc_cpc)))
            f.write("Triplet Loss model: \t AUC: %s\t Accuracy: %s \n\n" %
                    (str(best_auc_trip), str(100 * best_acc_trip)))

        torch.cuda.empty_cache()

    print('=======> Performance Summary:')
    print(
        'E2E model: \t Accuracy: %.2f +- %.2f \t AUC: %.3f +- %.3f \t AUPRC: %.3f +- %.3f'
        % (100 * np.mean(e2e_accs), 100 * np.std(e2e_accs), np.mean(e2e_aucs),
           np.std(e2e_aucs), np.mean(e2e_auprcs), np.std(e2e_auprcs)))
    print(
        'TNC model: \t Accuracy: %.2f +- %.2f \t AUC: %.3f +- %.3f \t AUPRC: %.3f +- %.3f'
        % (100 * np.mean(tnc_accs), 100 * np.std(tnc_accs), np.mean(tnc_aucs),
           np.std(tnc_aucs), np.mean(tnc_auprcs), np.std(tnc_auprcs)))
    print(
        'CPC model: \t Accuracy: %.2f +- %.2f \t AUC: %.3f +- %.3f \t AUPRC: %.3f +- %.3f'
        % (100 * np.mean(cpc_accs), 100 * np.std(cpc_accs), np.mean(cpc_aucs),
           np.std(cpc_aucs), np.mean(cpc_auprcs), np.std(cpc_auprcs)))
    print(
        'Trip model: \t Accuracy: %.2f +- %.2f \t AUC: %.3f +- %.3f \t AUPRC: %.3f +- %.3f'
        %
        (100 * np.mean(trip_accs), 100 * np.std(trip_accs), np.mean(trip_aucs),
         np.std(trip_aucs), np.mean(trip_auprcs), np.std(trip_auprcs)))
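
A hedged invocation sketch for the simulated dataset; the learning rates are illustrative, and the pretrained TNC, CPC, and triplet checkpoints must already exist under ./ckpt/:

run_test(data='simulation', e2e_lr=0.01, tnc_lr=0.01, cpc_lr=0.01,
         trip_lr=0.01, data_path='./data/simulated_data/',
         window_size=50, n_cross_val=1)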