# Reconstruct inputs through a trained autoencoder and append the
# reconstructions to the original frame (AE and AEDataset are project modules).
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


def test(data, folder, e):
    label_col = list(data.columns)
    result = data
    model = AE()
    model.load_state_dict(
        torch.load(f'{folder}/ep{e}data_aug.pth', map_location='cpu'))
    model.eval()

    dataset = AEDataset(data)
    dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=False)
    for inputs in tqdm(dataloader):
        outputs = model(inputs.float(), 'cpu')
        for i in range(len(outputs)):
            tmp = outputs[i].detach().numpy()
            tmp = pd.DataFrame([tmp], columns=label_col)
            result = pd.concat([result, tmp], ignore_index=True)

    result.to_csv(f'{folder}/data_augment.csv', mode='a', header=True, index=False)
    return result
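# --- Hypothetical usage sketch (added; names and paths are illustrative, not
# from the original source). Assumes a checkpoint 'runs/exp1/ep100data_aug.pth'
# and a CSV whose columns match the features the AE was trained on:
# df = pd.read_csv('runs/exp1/train_features.csv')
# augmented = test(df, folder='runs/exp1', e=100)
# print(len(df), '->', len(augmented))  # original rows plus reconstructions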
# Model selection; the first branch is reconstructed from the truncated
# snippet, mirroring the GCAE arguments (AE and GCAE are project modules).
if args.model == 'AE':
    model = AE(n_feat=n_feat, n_hid=args.hidden,
               n_lat=args.latent, dropout=args.dropout)
elif args.model == 'GCAE':
    model = GCAE(n_feat=n_feat, n_hid=args.hidden,
                 n_lat=args.latent, dropout=args.dropout)
else:
    raise ValueError(f"Unknown network model: {args.model}")

optimizer = optim.Adam(model.parameters(), lr=args.lr)
criterion = nn.MSELoss()

# Optionally resume from a saved checkpoint.
if args.checkpoint is not None:
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

n = len(adj)


def train(epoch):
    t = time.time()
    running_loss = 0
    for i in range(n):
        model.train()
        optimizer.zero_grad()
        _, decoded = model(features[i], adj[i], inv_adj[i])
        loss = criterion(decoded, features[i])
        loss.backward()
        optimizer.step()             # completes the truncated loop body
        running_loss += loss.item()
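# --- Hedged sketch (added, not from the original): the resume block above
# expects a dict with 'model_state_dict' and 'optimizer_state_dict' keys, so a
# matching save step would look like this (the path name is illustrative only):
def save_checkpoint(model, optimizer, path):
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
               path)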
import os
from collections import defaultdict

import torch
import torch.nn as nn

# AE and get_dataloader are project modules.


class BiomeAE:
    def __init__(self, args):
        if args.model in ["BiomeAEL0"]:
            self.mlp_type = "L0"
        else:
            self.mlp_type = None

        self.model_alias = args.model_alias
        self.model = args.model
        self.snap_loc = os.path.join(args.vis_dir, "snap.pt")

        #tl.configure("runs/ds.{}".format(model_alias))
        #tl.log_value(model_alias, 0)

        """ no stat file needed for now
        stat_alias = 'obj_DataStat+%s_%s' % (args.dataset_name, args.dataset_subset)
        stat_path = os.path.join(output_dir, '%s.pkl' % (stat_alias))
        with open(stat_path, 'rb') as sf:
            data_stats = pickle.load(sf)
        """

        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
        self.predictor = None

        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.seed)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def get_transformation(self):
        return None

    def loss_fn(self, recon_x, x, mean, log_var):
        if self.model in ["BiomeAE", "BiomeAESnip"]:
            mseloss = torch.nn.MSELoss()
            return torch.sqrt(mseloss(recon_x, x))
        elif self.model in ["BiomeAEL0"]:
            mseloss = torch.nn.MSELoss()
            return torch.sqrt(mseloss(recon_x, x)) + self.predictor.regularization()
        elif self.model == "BiomeVAE":
            BCE = torch.nn.functional.binary_cross_entropy(
                recon_x.view(-1, 28 * 28), x.view(-1, 28 * 28), reduction='sum')
            KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
            return (BCE + KLD) / x.size(0)

    def param_l0(self):
        return self.predictor.param_l0()

    def init_fit(self, X1_train, X2_train, y_train, X1_val, X2_val, y_val, args):
        self.train_loader = get_dataloader(X1_train, X2_train, y_train, args.batch_size)
        self.test_loader = get_dataloader(X1_val, X2_val, y_val, args.batch_size)
        self.predictor = AE(
            encoder_layer_sizes=[X1_train.shape[1]],
            latent_size=args.latent_size,
            decoder_layer_sizes=[X2_train.shape[1]],
            activation=args.activation,
            batch_norm=args.batch_norm,
            dropout=args.dropout,
            mlp_type=self.mlp_type,
            conditional=args.conditional,
            num_labels=10 if args.conditional else 0).to(self.device)
        self.optimizer = torch.optim.Adam(self.predictor.parameters(), lr=args.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.8)

    def train(self, args):
        if args.contr:
            print("Loading from ", self.snap_loc)
            loaded_model_para = torch.load(self.snap_loc)
            self.predictor.load_state_dict(loaded_model_para)

        t = 0
        logs = defaultdict(list)
        iterations_per_epoch = len(self.train_loader.dataset) / args.batch_size
        num_iterations = int(iterations_per_epoch * args.epochs)
        for epoch in range(args.epochs):
            tracker_epoch = defaultdict(lambda: defaultdict(dict))

            for iteration, (x1, x2, y) in enumerate(self.train_loader):
                t += 1
                x1, x2, y = x1.to(self.device), x2.to(self.device), y.to(self.device)

                if args.conditional:
                    x2_hat, z, mean, log_var = self.predictor(x1, y)
                else:
                    x2_hat, z, mean, log_var = self.predictor(x1)

                # track latent coordinates per sample for later visualization
                for i, yi in enumerate(y):
                    id = len(tracker_epoch)
                    tracker_epoch[id]['x'] = z[i, 0].item()
                    tracker_epoch[id]['y'] = z[i, 1].item()
                    tracker_epoch[id]['label'] = yi.item()

                loss = self.loss_fn(x2_hat, x2, mean, log_var)

                self.optimizer.zero_grad()
                loss.backward()
                if (t + 1) % int(num_iterations / 10) == 0:
                    self.scheduler.step()
                self.optimizer.step()

                # enforce non-negative weights
                if args.nonneg_weight:
                    for layer in self.predictor.modules():
                        if isinstance(layer, nn.Linear):
                            layer.weight.data.clamp_(0.0)

                logs['loss'].append(loss.item())

                if iteration % args.print_every == 0 or iteration == len(self.train_loader) - 1:
                    print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                        epoch, args.epochs, iteration, len(self.train_loader) - 1,
                        loss.item()))

            if args.model == "VAE":
                if args.conditional:
                    c = torch.arange(0, 10).long().unsqueeze(1)
                    x = self.predictor.inference(n=c.size(0), c=c)
                else:
                    x = self.predictor.inference(n=10)

        if not args.contr:
            print("Saving to ", self.snap_loc)
            torch.save(self.predictor.state_dict(), self.snap_loc)

    def fit(self, X1_train, X2_train, y_train, X1_val, X2_val, y_val, args):
        self.init_fit(X1_train, X2_train, y_train, X1_val, X2_val, y_val, args)
        self.train(args)

    def get_graph(self):
        """Return the nodes and weights of every linear layer."""
        nodes = []
        weights = []
        for l, layer in enumerate(self.predictor.modules()):
            if isinstance(layer, nn.Linear):
                lin_layer = layer
                nodes.append(["%s" % x for x in range(lin_layer.in_features)])
                weights.append(lin_layer.weight.detach().cpu().numpy().T)
        nodes.append(["%s" % x for x in range(lin_layer.out_features)])  # last linear layer
        return (nodes, weights)

    def predict(self, X1_val, X2_val, y_val, args):
        # Batch test
        x1, x2, y = (torch.FloatTensor(X1_val).to(self.device),
                     torch.FloatTensor(X2_val).to(self.device),
                     torch.FloatTensor(y_val).to(self.device))
        if args.conditional:
            x2_hat, z, mean, log_var = self.predictor(x1, y)
        else:
            x2_hat, z, mean, log_var = self.predictor(x1)
        val_loss = self.loss_fn(x2_hat, x2, mean, log_var)
        print("val_loss: {:9.4f}".format(val_loss.item()))
        return x2_hat.detach().cpu().numpy()

    def transform(self, X1_val, X2_val, y_val, args):
        x1, x2, y = (torch.FloatTensor(X1_val).to(self.device),
                     torch.FloatTensor(X2_val).to(self.device),
                     torch.FloatTensor(y_val).to(self.device))
        if args.conditional:
            x2_hat, z, mean, log_var = self.predictor(x1, y)
        else:
            x2_hat, z, mean, log_var = self.predictor(x1)
        return z.detach().cpu().numpy()

    def get_influence_matrix(self):
        return self.predictor.get_influence_matrix()
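# --- Hypothetical usage sketch (added; not from the original source). The
# Namespace fields mirror the attributes the class reads above; the concrete
# values, array shapes, and the 'relu' activation name are illustrative
# assumptions about the project's AE and get_dataloader.
import numpy as np
from argparse import Namespace

args = Namespace(model='BiomeAE', model_alias='demo', vis_dir='/tmp',
                 gpu_num='0', seed=0, latent_size=2, activation='relu',
                 batch_norm=False, dropout=0.0, conditional=False,
                 learning_rate=1e-3, epochs=5, batch_size=16,
                 print_every=10, nonneg_weight=False, contr=False)

X1 = np.random.rand(64, 20).astype('float32')  # e.g. microbiome features
X2 = np.random.rand(64, 10).astype('float32')  # e.g. paired targets
y = np.zeros(64, dtype='float32')

ae = BiomeAE(args)
ae.fit(X1, X2, y, X1, X2, y, args)
Z = ae.transform(X1, X2, y, args)  # latent codes, shape (64, latent_size)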
    image_matrix = torch.cat(image_matrix)
    vutils.save_image(0.5 * (image_matrix + 1), im_name, nrow=BATCH_SIZE + 1)


if __name__ == "__main__":
    device = 'cuda'

    from models import AE, RefineGenerator_art, RefineGenerator_face

    net_ae = AE()
    net_ae.style_encoder.reset_cls()
    net_ig = RefineGenerator_face()

    ckpt = torch.load('./models/16.pth')
    net_ae.load_state_dict(ckpt['ae'])
    net_ig.load_state_dict(ckpt['ig'])

    net_ae.to(device)
    net_ig.to(device)

    net_ae.eval()
    #net_ig.eval()

    data_root_colorful = './data/rgb/'
    #data_root_colorful = '/media/bingchen/database/images/celebaMask/CelebA_1024'
    data_root_sketch = './data/skt/'
    #data_root_sketch = './data/face_skt/'

    BATCH_SIZE = 3
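# --- Added note (hedged sketch, not from the original): save_image expects
# pixel values in [0, 1], so 0.5 * (x + 1) above rescales what are assumed to
# be tanh outputs in [-1, 1]. A minimal demo of the same convention:
# import torch
# import torchvision.utils as vutils
# fake = torch.rand(4, 3, 64, 64) * 2 - 1      # stand-in for tanh outputs
# vutils.save_image(0.5 * (fake + 1), 'demo_grid.png', nrow=4)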
import glob
import os

import pandas as pd
import torch
import tables
from tierpsy.helper.params import read_microns_per_pixel
from tierpsy.analysis.ske_create.helperIterROI import getROIfromInd

# load model (AE is the project's worm autoencoder module)
model_dir_root = '/data/ajaver/onedrive/classify_strains/logs/worm_autoencoder'
dnames = glob.glob(os.path.join(model_dir_root, 'AE_L64*'))
d = dnames[0]
# the embedding size is encoded in the directory name, e.g. 'AE_L64_...' -> 64
embedding_size = int(d.split('AE_L')[-1].partition('_')[0])
model_path = os.path.join(d, 'checkpoint.pth.tar')
print(embedding_size)

model = AE(embedding_size)
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

#%%
mask_file = '/data/ajaver/onedrive/aggregation/N2_1_Ch1_29062017_182108_comp3.hdf5'
feat_file = mask_file.replace('.hdf5', '_featuresN.hdf5')

w_ind = 264
ini_f = 1947

microns_per_pixel = read_microns_per_pixel(feat_file)

with pd.HDFStore(feat_file, 'r') as fid:
    trajectories_data = fid['/trajectories_data']

skel_data = trajectories_data[trajectories_data['skeleton_id'] >= 0]
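# --- Hedged sketch (added, not from the original): assumes tierpsy's
# trajectories_data table includes 'worm_index_joined' and 'frame_number'
# columns, as in standard tierpsy output. Restrict to the worm and start
# frame selected above:
worm_data = skel_data[skel_data['worm_index_joined'] == w_ind]
worm_data = worm_data[worm_data['frame_number'] >= ini_f]
print(worm_data[['frame_number', 'skeleton_id']].head())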