def set_feature_lock(self, event, idx_feature, set_to=None):
    """Flip or explicitly set one feature's lock flag, then recompute directions.

    Args:
        event: not used by this method; presumably the GUI callback event
            (e.g. a matplotlib widget event) — TODO confirm against the caller.
        idx_feature: index into ``self.feature_lock_status``.
        set_to: value to store for that feature; when ``None`` the current
            value is logically negated instead.
    """
    if set_to is not None:
        new_status = set_to
    else:
        new_status = np.logical_not(self.feature_lock_status[idx_feature])
    self.feature_lock_status[idx_feature] = new_status

    # Indices of every currently locked feature are handed to the
    # disentangling routine as its ``idx_base``.
    locked_idx = np.flatnonzero(self.feature_lock_status)
    # NOTE: "directoion" is misspelled, but this exact attribute name is the
    # one used throughout this file, so it must not be renamed here alone.
    self.feature_directoion_disentangled = \
        feature_axis.disentangle_feature_axis_by_idx(
            self.feature_direction, idx_base=locked_idx)
def __init__(self):
    """Set up state from module-level ``z_sample`` / ``feature_direction``.

    All ``num_feature`` features start unlocked, and the disentangled
    directions are computed once for that initial (empty) lock set.
    """
    self.latents = z_sample
    self.feature_direction = feature_direction
    # One boolean flag per feature; False means unlocked.
    self.feature_lock_status = np.zeros(num_feature, dtype=bool)
    locked_idx = np.flatnonzero(self.feature_lock_status)
    # NOTE: misspelled attribute name kept on purpose; it is the spelling
    # used everywhere else in this file.
    self.feature_directoion_disentangled = \
        feature_axis.disentangle_feature_axis_by_idx(
            self.feature_direction, idx_base=locked_idx)
def __init__(self):
    """Start from a freshly sampled latent and display its generated image.

    Draws one standard-normal latent of shape ``(1, *Gs.input_shapes[0][1:])``,
    resets every feature to unlocked, recomputes the disentangled directions,
    then renders the latent via ``gen_image`` into ``h_img`` and redraws.
    """
    latent_shape = Gs.input_shapes[0][1:]
    self.latents = np.random.randn(1, *latent_shape)
    self.feature_direction = feature_direction
    # One boolean flag per feature; False means unlocked.
    self.feature_lock_status = np.zeros(num_feature, dtype=bool)
    locked_idx = np.flatnonzero(self.feature_lock_status)
    # NOTE: misspelled attribute name kept on purpose; it is the spelling
    # used everywhere else in this file.
    self.feature_directoion_disentangled = \
        feature_axis.disentangle_feature_axis_by_idx(
            self.feature_direction, idx_base=locked_idx)

    # Show the image for the new latent and refresh the figure.
    image = gen_image(self.latents)
    h_img.set_data(image)
    plt.draw()
# Save the estimated feature directions together with their names to a
# timestamped pickle under ``path_feature_direction``.
# makedirs(exist_ok=True) also creates missing parent directories and avoids
# the check-then-create race of ``os.path.exists`` followed by ``os.mkdir``.
os.makedirs(path_feature_direction, exist_ok=True)
pathfile_feature_direction = os.path.join(
    path_feature_direction,
    'feature_direction_{}.pkl'.format(misc.gen_time_str()))
dict_to_save = {'direction': feature_direction, 'name': y_name}
with open(pathfile_feature_direction, 'wb') as f:
    pickle.dump(dict_to_save, f)

##
# disentangle correlated feature axis
# glob.glob returns matches in arbitrary (filesystem) order, so taking [-1]
# of the raw result does not reliably give the newest file. Sort the names
# first. This assumes misc.gen_time_str() yields lexicographically sortable
# timestamps — TODO confirm.
feature_direction_files = sorted(glob.glob(
    os.path.join(path_feature_direction, 'feature_direction_*.pkl')))
if not feature_direction_files:
    raise FileNotFoundError(
        'no feature_direction_*.pkl found in {}'.format(path_feature_direction))
pathfile_feature_direction = feature_direction_files[-1]
# NOTE(review): pickle.load can execute arbitrary code; only load files that
# this pipeline wrote itself.
with open(pathfile_feature_direction, 'rb') as f:
    feature_direction_name = pickle.load(f)

feature_direction = feature_direction_name['direction']
feature_name = np.array(feature_direction_name['name'])

len_z, len_y = feature_direction.shape

# Use the first quarter of the feature columns as the base axes.
feature_direction_disentangled = feature_axis.disentangle_feature_axis_by_idx(
    feature_direction, idx_base=range(len_y // 4), idx_target=None)

feature_axis.plot_feature_cos_sim(feature_direction_disentangled,
                                  feature_name=feature_name)
##
# Interactive sanity checks for the helpers in src.tf_gan.feature_axis.
import importlib
import numpy as np
import src.tf_gan.feature_axis as feature_axis

# Re-import feature_axis so edits to it are picked up in a running session.
importlib.reload(feature_axis)

# 10 x 4 matrix of uniform random values in [0, 1); unseeded, so printed
# values differ between runs.
vectors = np.random.rand(10, 4)
# Column-wise sums of squares, before and after normalize_feature_axis.
# Presumably the second print should be all ~1 if it normalizes columns —
# confirm in feature_axis.
print(np.sum(vectors**2, axis=0))
vectors_normalized = feature_axis.normalize_feature_axis(vectors)
print(np.sum(vectors_normalized**2, axis=0))

# Small hand-checkable case for orthogonalize_one_vector.
print(
    feature_axis.orthogonalize_one_vector(np.array([1, 0, 0]), np.array([1, 1, 1])))

print(vectors_normalized)
vectors_orthogonal = feature_axis.orthogonalize_vectors(vectors_normalized)
# Columns 2 and 3 are disentangled against column 0 (raw, unnormalized input).
vectors_disentangled = feature_axis.disentangle_feature_axis_by_idx(
    vectors, idx_base=[0], idx_target=[2, 3])

# Dot product of the last two columns, before and after orthogonalize_vectors
# (presumably ~0 afterwards — confirm).
print(np.dot(vectors_normalized[:, -2], vectors_normalized[:, -1]))
print(np.dot(vectors_orthogonal[:, -2], vectors_orthogonal[:, -1]))

# Visual comparison of the raw, orthogonalized and disentangled vectors.
feature_axis.plot_feature_cos_sim(vectors)
feature_axis.plot_feature_cos_sim(vectors_orthogonal)
feature_axis.plot_feature_cos_sim(vectors_disentangled)