예제 #1
0
 def evaluate(self, parameters, config):
     """Evaluate the fitted k-means model on this client's test split.

     Work is only done when ``self.step == 'k-means'`` and test labels
     (``self.y_test``) are available; in every other case the placeholder
     ``(0.0, 1, {})`` is returned unchanged.

     Side effects (k-means step only): prints the metrics line, writes a
     confusion matrix every 10th federated round, plots lifelines when
     ``self.outcomes_test`` is set, dumps per-sample predictions when
     ``self.ids_test`` is set, and dumps the metrics to ``self.out_dir``.

     Args:
         parameters: Not used by this method.
         config: Not used by this method.

     Returns:
         ``(loss, num_examples, metrics)`` where ``loss`` is always 0.0,
         ``num_examples`` is ``len(self.x_test)`` and ``metrics`` maps
         metric names to the scores of the predicted cluster labels.
     """
     loss = 0.0  # k-means has no loss to report
     result = (0.0, 1, {})
     if self.step == 'k-means' and self.y_test is not None:
         # predicting labels
         y_pred_kmeans = self.kmeans.predict(self.x_test)
         # computing metrics
         acc = my_metrics.acc(self.y_test, y_pred_kmeans)
         nmi = my_metrics.nmi(self.y_test, y_pred_kmeans)
         ami = my_metrics.ami(self.y_test, y_pred_kmeans)
         ari = my_metrics.ari(self.y_test, y_pred_kmeans)
         ran = my_metrics.ran(self.y_test, y_pred_kmeans)
         homo = my_metrics.homo(self.y_test, y_pred_kmeans)  # homogeneity
         print(
             out_1 %
             (self.client_id, self.f_round, acc, nmi, ami, ari, ran, homo))
         if self.f_round % 10 == 0:  # print confusion matrix
             my_fn.print_confusion_matrix(self.y_test,
                                          y_pred_kmeans,
                                          client_id=self.client_id,
                                          path_to_out=self.out_dir)
         # plotting outcomes on the labels
         if self.outcomes_test is not None:
             # column 0 is plotted as time, column 1 as event
             times = self.outcomes_test[:, 0]
             events = self.outcomes_test[:, 1]
             my_fn.plot_lifelines_pred(times,
                                       events,
                                       y_pred_kmeans,
                                       client_id=self.client_id,
                                       path_to_out=self.out_dir)
         if self.ids_test is not None:
             pred = {'ID': self.ids_test, 'label': y_pred_kmeans}
             my_fn.dump_pred_dict('pred_client_' + str(self.client_id),
                                  pred,
                                  path_to_out=self.out_dir)
         # dumping and retrieving the results
         metrics = {
             "accuracy": acc,
             "normalized_mutual_info_score": nmi,
             "adjusted_mutual_info_score": ami,
             "adjusted_rand_score": ari,
             "rand_score": ran,
             "homogeneity_score": homo
         }
         # the dumped record carries client/round bookkeeping on top of the
         # metrics; the returned metrics dict stays metric-only
         result = metrics.copy()
         result['client'] = self.client_id
         result['round'] = self.f_round
         my_fn.dump_result_dict('client_' + str(self.client_id),
                                result,
                                path_to_out=self.out_dir)
         result = (loss, len(self.x_test), metrics)
     return result
예제 #2
0
 def evaluate(self, parameters, config):
     """Evaluate the local model according to the current ``self.step``.

     * ``'pretrain_ae'``: reconstruction loss of ``self.autoencoder`` on
       ``self.x_test``; the loss is dumped and printed, returned metrics
       are empty.
     * ``'k-means'``: cluster labels are predicted by ``self.kmeans`` on the
       ``self.encoder`` embedding of ``self.x_test``; when ``self.y_test``
       is available the clustering metrics are printed and returned.
     * ``'clustering'``: the clustering model is evaluated against the
       target distribution derived from its own soft assignments; labels
       are the argmax of those assignments. When ``self.y_test`` is
       available the clustering metrics are printed, dumped and returned.

     Args:
         parameters: Not used by this method.
         config: Not used by this method.

     Returns:
         ``(loss, len(self.x_test), metrics)`` for the three known steps;
         an empty tuple for any other value of ``self.step``.
     """
     loss = 0.0
     result = ()
     metrics = {}
     if self.step == 'pretrain_ae':
         # the autoencoder is scored on reconstructing its own input
         loss = self.autoencoder.evaluate(self.x_test,
                                          self.x_test,
                                          verbose=0)
         metrics = {"loss": loss}
         result = metrics.copy()
         result['client'] = self.client_id
         result['round'] = self.f_round
         my_fn.dump_result_dict('client_' + str(self.client_id) + '_ae',
                                result,
                                path_to_out=self.out_dir)
         print(out_2 % (self.client_id, self.f_round, loss))
         result = (loss, len(self.x_test), {})
     elif self.step == 'k-means':
         # predicting labels in the encoder's embedding space
         y_pred_kmeans = self.kmeans.predict(
             self.encoder.predict(self.x_test))
         # computing metrics (only possible when ground truth is available,
         # mirroring the guard used by the 'clustering' branch)
         if self.y_test is not None:
             acc = my_metrics.acc(self.y_test, y_pred_kmeans)
             nmi = my_metrics.nmi(self.y_test, y_pred_kmeans)
             ami = my_metrics.ami(self.y_test, y_pred_kmeans)
             ari = my_metrics.ari(self.y_test, y_pred_kmeans)
             ran = my_metrics.ran(self.y_test, y_pred_kmeans)
             homo = my_metrics.homo(self.y_test, y_pred_kmeans)
             print(out_1 % (self.client_id, self.f_round, acc, nmi, ami,
                            ari, ran, homo))
             if self.f_round % 10 == 0:  # print confusion matrix
                 my_fn.print_confusion_matrix(self.y_test,
                                              y_pred_kmeans,
                                              client_id=self.client_id,
                                              path_to_out=self.out_dir)
             # return the computed scores instead of discarding them
             metrics = {
                 "accuracy": acc,
                 "normalized_mutual_info_score": nmi,
                 "adjusted_mutual_info_score": ami,
                 "adjusted_rand_score": ari,
                 "rand_score": ran,
                 "homogeneity_score": homo
             }
         # retrieving the results
         result = (loss, len(self.x_test), metrics)
     elif self.step == 'clustering':
         # soft cluster assignments
         q = self.clustering_model.predict(self.x_test, verbose=0)
         # update the auxiliary target distribution p
         p = target_distribution(q)
         # retrieving loss of the assignments against their target
         loss = self.clustering_model.evaluate(self.x_test, p, verbose=0)
         # hard labels = most likely cluster per sample
         y_pred = q.argmax(1)
         # plotting outcomes on the labels
         if self.outcomes_test is not None:
             # column 0 is plotted as time, column 1 as event
             times = self.outcomes_test[:, 0]
             events = self.outcomes_test[:, 1]
             my_fn.plot_lifelines_pred(times,
                                       events,
                                       y_pred,
                                       client_id=self.client_id,
                                       path_to_out=self.out_dir)
         # evaluating metrics
         if self.y_test is not None:
             acc = my_metrics.acc(self.y_test, y_pred)
             nmi = my_metrics.nmi(self.y_test, y_pred)
             ami = my_metrics.ami(self.y_test, y_pred)
             ari = my_metrics.ari(self.y_test, y_pred)
             ran = my_metrics.ran(self.y_test, y_pred)
             homo = my_metrics.homo(self.y_test, y_pred)
             if self.f_round % 10 == 0:  # print confusion matrix
                 my_fn.print_confusion_matrix(self.y_test,
                                              y_pred,
                                              client_id=self.client_id,
                                              path_to_out=self.out_dir)
             print(out_1 % (self.client_id, self.f_round, acc, nmi, ami,
                            ari, ran, homo))
             # dumping and retrieving the results
             metrics = {
                 "accuracy": acc,
                 "normalized_mutual_info_score": nmi,
                 "adjusted_mutual_info_score": ami,
                 "adjusted_rand_score": ari,
                 "rand_score": ran,
                 "homogeneity_score": homo
             }
             result = metrics.copy()
             result['loss'] = loss
             result['client'] = self.client_id
             # NOTE(review): the other branches record self.f_round here;
             # confirm local_iter is the intended round counter.
             result['round'] = self.local_iter
             my_fn.dump_result_dict('client_' + str(self.client_id),
                                    result,
                                    path_to_out=self.out_dir)
         # NOTE(review): the k-means-only client spells this attribute
         # `ids_test`; confirm `id_test` is the name set on this class.
         if self.id_test is not None:
             pred = {'ID': self.id_test, 'label': y_pred}
             my_fn.dump_pred_dict('pred_client_' + str(self.client_id),
                                  pred,
                                  path_to_out=self.out_dir)
         result = (loss, len(self.x_test), metrics)
     return result
예제 #3
0
    def test(self, config):
        """Evaluate the encoder/generator pair on one batch of test data.

        1. Encode the first batch of ``self.testloader`` and take the argmax
           of each sample's categorical code as its predicted cluster.
        2. Score the predictions against the ground-truth labels, print the
           scores, plot lifelines per predicted cluster and, every 10th
           federated epoch, write a confusion matrix.
        3. Compute cycle losses (real -> encoder -> generator image MSE;
           sampled latent -> generator -> encoder MSE and cross-entropy) and
           append them to ``self.c_i``, ``self.c_zn`` and ``self.c_zc``.
        4. If ``self.save_images`` is set, save reconstructed, generated and
           per-class generated image grids under ``self.img_dir``.
        5. Dump the scores/losses and the per-sample predictions to
           ``self.out_dir``.

        Args:
            config: Mapping; ``config['total_epochs']`` is read for the
                progress message.

        Note:
            ``self.generator`` and ``self.encoder`` are left in eval mode.
        """
        print('Begin evaluation session...\n')
        # Generator in eval mode
        self.generator.eval()
        self.encoder.eval()

        # Set number of examples for cycle calcs
        n_sqrt_samp = 5
        n_samp = n_sqrt_samp * n_sqrt_samp

        test_imgs, test_labels, test_ids, test_outcomes = next(
            iter(self.testloader))
        # column 0 is plotted as time, column 1 as event
        times = test_outcomes[:, 0]
        events = test_outcomes[:, 1]
        test_imgs = Variable(test_imgs.type(self.TENSOR))

        # Cycle through test real -> enc -> gen
        t_imgs, t_label = test_imgs.data, test_labels
        # Encode sample real instances
        e_tzn, e_tzc, e_tzc_logits = self.encoder(t_imgs)

        # Predicted cluster = argmax of each sample's categorical code
        computed_labels = np.array(
            [code.argmax() for code in e_tzc.detach().cpu().numpy()])
        # Ground truth as a NumPy array, converted once for every metric
        y_true = t_label.detach().cpu().numpy()

        # computing metrics
        acc = my_metrics.acc(y_true, computed_labels)
        nmi = my_metrics.nmi(y_true, computed_labels)
        ami = my_metrics.ami(y_true, computed_labels)
        ari = my_metrics.ari(y_true, computed_labels)
        ran = my_metrics.ran(y_true, computed_labels)
        homo = my_metrics.homo(y_true, computed_labels)  # homogeneity
        print(out_1 %
              (self.client_id, self.f_epoch, acc, nmi, ami, ari, ran, homo))
        # plotting outcomes on the labels
        # NOTE(review): unconditional; assumes the test loader always yields
        # outcome columns for every client.
        my_fn.plot_lifelines_pred(times,
                                  events,
                                  computed_labels,
                                  client_id=self.client_id,
                                  path_to_out=self.out_dir)
        if self.f_epoch % 10 == 0:  # print confusion matrix
            my_fn.print_confusion_matrix(y_true,
                                         computed_labels,
                                         client_id=self.client_id,
                                         path_to_out=self.out_dir)
        # dumping and retrieving the results
        metrics = {
            "accuracy": acc,
            "normalized_mutual_info_score": nmi,
            "adjusted_mutual_info_score": ami,
            "adjusted_rand_score": ari,
            "rand_score": ran,
            "homogeneity_score": homo
        }
        result = metrics.copy()

        # Generate sample instances from encoding
        teg_imgs = self.generator(e_tzn, e_tzc)
        # Calculate cycle reconstruction loss
        self.img_mse_loss = self.mse_loss(t_imgs, teg_imgs)
        # Save img reco cycle loss
        self.c_i.append(self.img_mse_loss.item())

        # Cycle through randomly sampled encoding -> generator -> encoder
        zn_samp, zc_samp, zc_samp_idx = sample_z(shape=n_samp,
                                                 latent_dim=self.latent_dim,
                                                 n_c=self.n_c,
                                                 cuda=self.cuda)
        # Generate sample instances
        gen_imgs_samp = self.generator(zn_samp, zc_samp)

        # Encode sample instances
        zn_e, zc_e, zc_e_logits = self.encoder(gen_imgs_samp)

        # Calculate cycle latent losses
        self.lat_mse_loss = self.mse_loss(zn_e, zn_samp)
        self.lat_xe_loss = self.xe_loss(zc_e_logits, zc_samp_idx)

        # Save latent space cycle losses
        self.c_zn.append(self.lat_mse_loss.item())
        self.c_zc.append(self.lat_xe_loss.item())

        # Save cycled and generated examples!
        if self.save_images:
            r_imgs, i_label = (self.real_imgs.data[:n_samp],
                               self.itruth_label[:n_samp])
            e_zn, e_zc, e_zc_logits = self.encoder(r_imgs)
            reg_imgs = self.generator(e_zn, e_zc)
            # The file name must be formatted BEFORE joining: `/` and `%`
            # share precedence, so `dir / 'x_%06i.png' % n` would apply `%`
            # to the joined path object and raise TypeError.
            save_image(reg_imgs.data[:n_samp],
                       self.img_dir / ('cycle_reg_%06i.png' % (self.f_epoch)),
                       nrow=n_sqrt_samp,
                       normalize=True)
            save_image(gen_imgs_samp.data[:n_samp],
                       self.img_dir / ('gen_%06i.png' % (self.f_epoch)),
                       nrow=n_sqrt_samp,
                       normalize=True)
            # Generate self.n_c samples for each class, one class at a time
            class_imgs = []
            for idx in range(self.n_c):
                # Sample specific class
                zn_samp, zc_samp, zc_samp_idx = sample_z(
                    shape=self.n_c,
                    latent_dim=self.latent_dim,
                    n_c=self.n_c,
                    fix_class=idx,
                    cuda=self.cuda)
                # Generate sample instances
                class_imgs.append(self.generator(zn_samp, zc_samp))
            # Concatenate once along the batch dimension
            stack_imgs = torch.cat(class_imgs, 0)
            # Save class-specified generated examples!
            save_image(stack_imgs,
                       self.img_dir / ('gen_classes_%06i.png' % (self.f_epoch)),
                       nrow=self.n_c,
                       normalize=True)

        print("[Federated Epoch %d/%d] [Client ID %d] \n"
              "\tCycle Losses: [x: %f] [z_n: %f] [z_c: %f]\n" %
              (self.f_epoch, config['total_epochs'], self.client_id,
               self.img_mse_loss.item(), self.lat_mse_loss.item(),
               self.lat_xe_loss.item()))

        result['img_mse_loss'] = self.img_mse_loss.item()
        result['lat_mse_loss'] = self.lat_mse_loss.item()
        result['lat_xe_loss'] = self.lat_xe_loss.item()
        result['client'] = self.client_id
        result['round'] = self.f_epoch
        my_fn.dump_result_dict('client_' + str(self.client_id),
                               result,
                               path_to_out=self.out_dir)
        pred = {'ID': test_ids, 'label': computed_labels}
        my_fn.dump_pred_dict('pred_client_' + str(self.client_id),
                             pred,
                             path_to_out=self.out_dir)
예제 #4
0
        t_imgs, t_label = test_imgs.data, test_labels
        # Encode sample real instances
        e_tzn, e_tzc, e_tzc_logits = encoder(t_imgs)

        computed_labels = []
        for pred in e_tzc.detach().cpu().numpy():
            computed_labels.append(pred.argmax())
        computed_labels = np.array(computed_labels)

        # computing metrics
        acc = my_metrics.acc(t_label.detach().cpu().numpy(), computed_labels)
        nmi = my_metrics.nmi(t_label.detach().cpu().numpy(), computed_labels)
        ami = my_metrics.ami(t_label.detach().cpu().numpy(), computed_labels)
        ari = my_metrics.ari(t_label.detach().cpu().numpy(), computed_labels)
        ran = my_metrics.ran(t_label.detach().cpu().numpy(), computed_labels)
        h**o = my_metrics.h**o(t_label.detach().cpu().numpy(), computed_labels)
        if args.dataset == 'euromds':
            # plotting outcomes on the labels
            my_fn.plot_lifelines_pred(time=times,
                                      event=events,
                                      labels=computed_labels,
                                      path_to_out=path_to_out)
        if epoch % 10 == 0:  # print confusion matrix
            my_fn.print_confusion_matrix(y=t_label.detach().cpu().numpy(),
                                         y_pred=computed_labels,
                                         path_to_out=path_to_out)
        # dumping and retrieving the results
        metrics = {
            "accuracy": acc,
            "normalized_mutual_info_score": nmi,
            "adjusted_mutual_info_score": ami,