Example No. 1
    def get_statistics(self, epoch, domain, data_loader, col, batch_size=600):
        # Assumes `numpy as np`, `os`, `tqdm.tqdm` and the project helpers
        # `stat` / `update_loss` are imported at module level.
        log_dict = {}

        # Per-class accumulators: within-class scatter, t-SNE inputs/labels, class means.
        dbis = []
        inp = []
        label = []
        embed2 = []
        for clsid in tqdm(range(data_loader.nclass), desc='Get Statistics {}'.format(domain)):
            x = data_loader.sample_from_class(clsid, batch_size)
            x = np.expand_dims(x, 0)
            # The first two support samples double as a dummy query input.
            z = self.sess.run(self.graph['support_z'],
                              feed_dict={self.ph['support']: x,
                                         self.ph['query']: x[:, :2],
                                         self.ph['is_training']: False})
            z = z.squeeze()
            update_loss(stat.gaussian_test(z), log_dict, False)
            update_loss(stat.correlation(z), log_dict, False)

            class_mean = np.mean(z, 0, keepdims=True)  # class centroid
            # RMS distance of the class's samples to its centroid (within-class scatter).
            dbis.append(np.sqrt(np.mean(np.sum(np.square(z - class_mean), 1), 0)))
            embed2.append(class_mean)
            update_loss({'mean_std': np.mean(np.std(z, 0))}, log_dict, False)
            update_loss(stat.norm(z, 'inclass_'), log_dict, False)
            update_loss(stat.pairwise_distance(z[:100], 'inclass_'), log_dict, False)
            # Keep 50 samples per class for the t-SNE plot.
            inp.append(z[:50])
            label += [clsid] * 50
        
        embed2 = np.concatenate(embed2, 0)  # (nclass, dim) matrix of class means
        update_loss(stat.norm(embed2, 'est_'), log_dict, False)
        update_loss(stat.pairwise_distance(embed2, 'est_'), log_dict, False)
        # Davies-Bouldin index from per-class scatter and pairwise centroid distances.
        update_loss(stat.davies_bouldin_index(np.array(dbis), stat.l2_distance(embed2)),
                    log_dict, False)

        inputs = np.concatenate(inp, axis=0)
        labels = np.array(label)
        stat.tsne_visualization(inputs, labels, os.path.join(self.logger.dir,
            'epoch{}_{}.png'.format(epoch, domain)), col)
        self.logger.print(epoch, domain + '-stat', log_dict)
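
Both examples summarise class separability with `stat.davies_bouldin_index`, called on the per-class scatter values (`dbis`) and the pairwise distances between the class means (`stat.l2_distance(embed2)`). The helper itself is not shown on this page; below is a minimal sketch of the classic Davies-Bouldin computation, assuming the second argument is the full nclass x nclass Euclidean distance matrix and that, like the other `stat` helpers used here, it returns a dict that `update_loss` can merge.

import numpy as np

def davies_bouldin_index(scatter, centroid_dist):
    # scatter       : shape (k,)  -- RMS distance of each class's samples to its
    #                 centroid, i.e. the `dbis` values computed in the loop above.
    # centroid_dist : shape (k, k) -- pairwise Euclidean distances between class means.
    # Pairwise similarity ratio R_ij = (s_i + s_j) / d_ij, guarding the zero diagonal.
    safe_dist = np.where(centroid_dist > 0, centroid_dist, np.inf)
    ratios = (scatter[:, None] + scatter[None, :]) / safe_dist
    np.fill_diagonal(ratios, -np.inf)  # exclude i == j from the max
    # DB index = mean over classes of the worst-case ratio against any other class.
    return {'davies_bouldin': float(np.mean(np.max(ratios, axis=1)))}
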
Example No. 2
    def get_statistics(self, epoch, domain, data_loader, col, batch_size=600):
        log_dict, embed = {}, None
        if domain == 'train':
            # Fetch the model's stored `embed` tensor; it is compared against the
            # empirical class means below via the `mean_div` metric.
            embed = self.sess.run(self.graph['embed'])
            update_loss(stat.norm(embed, 'embed_'), log_dict, False)
            update_loss(stat.pairwise_distance(embed, 'embed_'), log_dict, False)

        inp = []
        label = []
        embed2 = []
        dbis = []
        for clsid in tqdm(range(data_loader.nclass),
                          desc='Get Statistics {}'.format(domain)):
            x = data_loader.sample_from_class(clsid, batch_size)
            z = self.sess.run(self.graph['z'],
                              feed_dict={
                                  self.ph['data']: x,
                                  self.ph['is_training']: False
                              })
            update_loss(stat.gaussian_test(z), log_dict, False)
            update_loss(stat.correlation(z), log_dict, False)

            class_mean = np.mean(z, 0, keepdims=True)  # class centroid
            embed2.append(class_mean)
            dbis.append(np.sqrt(np.mean(np.sum(np.square(z - class_mean), 1), 0)))
            update_loss({'mean_std': np.mean(np.std(z, 0))}, log_dict, False)
            update_loss(stat.norm(z[:100], 'inclass_'), log_dict, False)
            update_loss(stat.pairwise_distance(z[:100], 'inclass_'), log_dict,
                        False)
            inp.append(z[:50])
            label += [clsid] * 50
        embed2 = np.concatenate(embed2, 0)  # (nclass, dim) matrix of class means
        if embed is not None:
            # Frobenius-norm gap between the stored `embed` and the empirical class
            # means, normalised by the number of classes.
            mean_div = np.sqrt(np.sum(np.square(embed - embed2))) / embed.shape[0]
            update_loss({'mean_div': mean_div}, log_dict, False)
        update_loss(stat.norm(embed2, 'est_'), log_dict, False)
        update_loss(stat.pairwise_distance(embed2, 'est_'), log_dict, False)
        # Same Davies-Bouldin summary as in Example No. 1.
        update_loss(stat.davies_bouldin_index(np.array(dbis), stat.l2_distance(embed2)),
                    log_dict, False)

        inputs = np.concatenate(inp, axis=0)
        labels = np.array(label)
        #stat.tsne_visualization(inputs, labels, os.path.join(self.logger.dir,
        #    'epoch{}_{}.png'.format(epoch, domain)), col)
        self.logger.print(epoch, domain + '-stat', log_dict)
        return self.terminate()
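
Every metric in both snippets flows through `update_loss(values, log_dict, False)`, whose definition lives elsewhere in the repository. As a rough sketch, assuming it simply accumulates each scalar under its key so the logger can reduce the lists later (the boolean flag is treated here as a hypothetical `verbose` switch, since its real meaning is not visible in these excerpts):

def update_loss(values, log_dict, verbose=False):
    # values   : dict of metric name -> scalar, e.g. the output of stat.norm or
    #            the inline {'mean_std': ...} entries above.
    # log_dict : running accumulator; each key maps to the list of observed values.
    # verbose  : hypothetical flag -- print each metric as it is recorded.
    for key, val in values.items():
        log_dict.setdefault(key, []).append(float(val))
        if verbose:
            print('{}: {:.4f}'.format(key, float(val)))

Under that assumption, `self.logger.print(epoch, domain + '-stat', log_dict)` would then be responsible for averaging or otherwise summarising each list when it writes the epoch entry.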