class BiomeAE:
    def __init__(self, args):
        # The L0-regularized variant uses a dedicated MLP type; all other
        # models fall back to the default architecture.
        if args.model in ["BiomeAEL0"]:
            self.mlp_type = "L0"
        else:
            self.mlp_type = None

        self.model_alias = args.model_alias
        self.model = args.model
        self.snap_loc = os.path.join(args.vis_dir, "snap.pt")

        # tl.configure("runs/ds.{}".format(model_alias))
        # tl.log_value(model_alias, 0)

        """ no stat file needed for now
        stat_alias = 'obj_DataStat+%s_%s' % (args.dataset_name, args.dataset_subset)
        stat_path = os.path.join(output_dir, '%s.pkl' % (stat_alias))
        with open(stat_path, 'rb') as sf:
            data_stats = pickle.load(sf)
        """

        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
        self.predictor = None

        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.seed)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def get_transformation(self):
        return None

    def loss_fn(self, recon_x, x, mean, log_var):
        if self.model in ["BiomeAE", "BiomeAESnip"]:
            # RMSE reconstruction loss
            mseloss = torch.nn.MSELoss()
            return torch.sqrt(mseloss(recon_x, x))
        elif self.model in ["BiomeAEL0"]:
            # RMSE plus the L0 sparsity penalty exposed by the predictor
            mseloss = torch.nn.MSELoss()
            return torch.sqrt(mseloss(recon_x, x)) + self.predictor.regularization()
        elif self.model == "BiomeVAE":
            BCE = torch.nn.functional.binary_cross_entropy(
                recon_x.view(-1, 28 * 28), x.view(-1, 28 * 28), reduction='sum')
            KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
            return (BCE + KLD) / x.size(0)

    def param_l0(self):
        return self.predictor.param_l0()

    def init_fit(self, X1_train, X2_train, y_train, X1_val, X2_val, y_val, args):
        self.train_loader = get_dataloader(X1_train, X2_train, y_train, args.batch_size)
        self.test_loader = get_dataloader(X1_val, X2_val, y_val, args.batch_size)

        self.predictor = AE(
            encoder_layer_sizes=[X1_train.shape[1]],
            latent_size=args.latent_size,
            decoder_layer_sizes=[X2_train.shape[1]],
            activation=args.activation,
            batch_norm=args.batch_norm,
            dropout=args.dropout,
            mlp_type=self.mlp_type,
            conditional=args.conditional,
            num_labels=10 if args.conditional else 0).to(self.device)

        self.optimizer = torch.optim.Adam(self.predictor.parameters(), lr=args.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.8)

    def train(self, args):
        if args.contr:
            print("Loading from ", self.snap_loc)
            loaded_model_para = torch.load(self.snap_loc)
            self.predictor.load_state_dict(loaded_model_para)

        t = 0
        logs = defaultdict(list)
        iterations_per_epoch = len(self.train_loader.dataset) / args.batch_size
        num_iterations = int(iterations_per_epoch * args.epochs)

        for epoch in range(args.epochs):
            tracker_epoch = defaultdict(lambda: defaultdict(dict))

            for iteration, (x1, x2, y) in enumerate(self.train_loader):
                t += 1
                x1, x2, y = x1.to(self.device), x2.to(self.device), y.to(self.device)

                if args.conditional:
                    x2_hat, z, mean, log_var = self.predictor(x1, y)
                else:
                    x2_hat, z, mean, log_var = self.predictor(x1)

                # Track the first two latent coordinates per sample for later visualization.
                for i, yi in enumerate(y):
                    id = len(tracker_epoch)
                    tracker_epoch[id]['x'] = z[i, 0].item()
                    tracker_epoch[id]['y'] = z[i, 1].item()
                    tracker_epoch[id]['label'] = yi.item()

                loss = self.loss_fn(x2_hat, x2, mean, log_var)

                self.optimizer.zero_grad()
                loss.backward()
                # Decay the learning rate roughly ten times over the whole run
                # (guarded so short runs do not divide by zero).
                if (t + 1) % max(int(num_iterations / 10), 1) == 0:
                    self.scheduler.step()
                self.optimizer.step()

                # Enforce non-negative weights if requested.
                if args.nonneg_weight:
                    for layer in self.predictor.modules():
                        if isinstance(layer, nn.Linear):
                            layer.weight.data.clamp_(0.0)

                logs['loss'].append(loss.item())

                if iteration % args.print_every == 0 or iteration == len(self.train_loader) - 1:
                    print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                        epoch, args.epochs, iteration, len(self.train_loader) - 1, loss.item()))

                    if args.model == "VAE":
                        if args.conditional:
                            c = torch.arange(0, 10).long().unsqueeze(1)
                            x = self.predictor.inference(n=c.size(0), c=c)
                        else:
                            x = self.predictor.inference(n=10)

        if not args.contr:
            print("Saving to ", self.snap_loc)
            torch.save(self.predictor.state_dict(), self.snap_loc)

    def fit(self, X1_train, X2_train, y_train, X1_val, X2_val, y_val, args):
        self.init_fit(X1_train, X2_train, y_train, X1_val, X2_val, y_val, args)
        self.train(args)

    def get_graph(self):
        """
        Return nodes and weights of every linear layer.
        :return: (nodes, weights)
        """
        nodes = []
        weights = []
        for l, layer in enumerate(self.predictor.modules()):
            if isinstance(layer, nn.Linear):
                lin_layer = layer
                nodes.append(["%s" % (x) for x in list(range(lin_layer.in_features))])
                weights.append(lin_layer.weight.detach().cpu().numpy().T)
        # Output nodes of the last linear layer.
        nodes.append(["%s" % (x) for x in list(range(lin_layer.out_features))])
        return (nodes, weights)

    def predict(self, X1_val, X2_val, y_val, args):
        # Batch test
        x1, x2, y = (torch.FloatTensor(X1_val).to(self.device),
                     torch.FloatTensor(X2_val).to(self.device),
                     torch.FloatTensor(y_val).to(self.device))
        if args.conditional:
            x2_hat, z, mean, log_var = self.predictor(x1, y)
        else:
            x2_hat, z, mean, log_var = self.predictor(x1)
        val_loss = self.loss_fn(x2_hat, x2, mean, log_var)
        print("val_loss: {:9.4f}".format(val_loss.item()))
        return x2_hat.detach().cpu().numpy()

    def transform(self, X1_val, X2_val, y_val, args):
        x1, x2, y = (torch.FloatTensor(X1_val).to(self.device),
                     torch.FloatTensor(X2_val).to(self.device),
                     torch.FloatTensor(y_val).to(self.device))
        if args.conditional:
            x2_hat, z, mean, log_var = self.predictor(x1, y)
        else:
            x2_hat, z, mean, log_var = self.predictor(x1)
        return z.detach().cpu().numpy()

    def get_influence_matrix(self):
        return self.predictor.get_influence_matrix()
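
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original training code):
# it shows how the BiomeAE wrapper above is typically driven. The Namespace
# fields mirror the attributes this file reads from `args`, and the random
# arrays stand in for the paired feature matrices normally returned by
# load_data(); both are assumptions made for the example.
# ---------------------------------------------------------------------------
def _example_biomeae_usage():
    import numpy as np
    from argparse import Namespace

    args = Namespace(
        model="BiomeAE", model_alias="biomeae_demo", vis_dir=".", gpu_num="0",
        seed=42, latent_size=2, activation="relu", batch_norm=False,
        dropout=0.0, conditional=False, batch_size=16, learning_rate=1e-3,
        epochs=5, print_every=10, contr=False, nonneg_weight=False)

    # Paired views, e.g. 100 microbe features predicting 50 metabolite features.
    X1 = np.random.rand(64, 100).astype("float32")
    X2 = np.random.rand(64, 50).astype("float32")
    y = np.zeros(64, dtype="float32")

    ae = BiomeAE(args)
    ae.fit(X1, X2, y, X1, X2, y, args)      # train and snapshot the model
    x2_hat = ae.predict(X1, X2, y, args)    # reconstructed second view
    z = ae.transform(X1, X2, y, args)       # latent embedding of the inputs
    nodes, weights = ae.get_graph()         # per-layer nodes and edge weights
    return x2_hat, z, nodes, weights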
def main(args):
    model_alias = 'DeepBiome_%s+%s_%s+fea1_%s+fea2_%s+bs_%s+%s' % (
        args.model, args.dataset_name, args.data_type, args.fea1, args.fea2,
        args.batch_size, args.extra)

    tl.configure("runs/ds.{}".format(model_alias))
    tl.log_value(model_alias, 0)

    """ no stat file needed for now
    stat_alias = 'obj_DataStat+%s_%s' % (args.dataset_name, args.dataset_subset)
    stat_path = os.path.join(output_dir, '%s.pkl' % (stat_alias))
    with open(stat_path, 'rb') as sf:
        data_stats = pickle.load(sf)
    """

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_num
    data_path = os.path.join(DATA_ROOT, "ibd_{}.pkl".format(args.data_type))

    logger.info("Initializing train dataset")

    # Load data.
    print('==> loading data')
    print()
    (X1_train, X2_train, y_train), (X1_val, X2_val, y_val) = load_data(data_path)
    train_loader = get_dataloader(X1_train, X2_train, y_train, args.batch_size)
    test_loader = get_dataloader(X1_val, X2_val, y_val, args.batch_size)

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ts = time.time()

    def loss_fn(model, recon_x, x, mean, log_var):
        if model == "AE":
            # RMSE reconstruction loss
            mseloss = torch.nn.MSELoss()
            return torch.sqrt(mseloss(recon_x, x))
        elif model == "VAE":
            BCE = torch.nn.functional.binary_cross_entropy(
                recon_x.view(-1, 28 * 28), x.view(-1, 28 * 28), reduction='sum')
            KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
            return (BCE + KLD) / x.size(0)

    if args.model == "AE":
        predictor = AE(
            encoder_layer_sizes=[X1_train.shape[1]],
            latent_size=args.latent_size,
            decoder_layer_sizes=[X2_train.shape[1]],
            activation=args.activation,
            batch_norm=args.batch_norm,
            dropout=args.dropout,
            conditional=args.conditional,
            num_labels=10 if args.conditional else 0).to(device)
    else:
        predictor = VAE(
            encoder_layer_sizes=args.encoder_layer_sizes,
            latent_size=args.latent_size,
            decoder_layer_sizes=args.decoder_layer_sizes,
            activation=args.activation,
            batch_norm=args.batch_norm,
            dropout=args.dropout,
            conditional=args.conditional,
            num_labels=10 if args.conditional else 0).to(device)

    optimizer = torch.optim.Adam(predictor.parameters(), lr=args.learning_rate)

    logs = defaultdict(list)

    for epoch in range(args.epochs):
        tracker_epoch = defaultdict(lambda: defaultdict(dict))

        for iteration, (x1, x2, y) in enumerate(train_loader):
            x1, x2, y = x1.to(device), x2.to(device), y.to(device)

            if args.conditional:
                x2_hat, z, mean, log_var = predictor(x1, y)
            else:
                x2_hat, z, mean, log_var = predictor(x1)

            # Track the first two latent coordinates per sample for later visualization.
            for i, yi in enumerate(y):
                id = len(tracker_epoch)
                tracker_epoch[id]['x'] = z[i, 0].item()
                tracker_epoch[id]['y'] = z[i, 1].item()
                tracker_epoch[id]['label'] = yi.item()

            loss = loss_fn(args.model, x2_hat, x2, mean, log_var)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            logs['loss'].append(loss.item())

            if iteration % args.print_every == 0 or iteration == len(train_loader) - 1:
                print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:9.4f}".format(
                    epoch, args.epochs, iteration, len(train_loader) - 1, loss.item()))

                if args.model == "VAE":
                    if args.conditional:
                        c = torch.arange(0, 10).long().unsqueeze(1)
                        x = predictor.inference(n=c.size(0), c=c)
                    else:
                        x = predictor.inference(n=10)

                """
                plt.figure()
                plt.figure(figsize=(5, 10))
                for p in range(10):
                    plt.subplot(5, 2, p + 1)
                    if args.conditional:
                        plt.text(
                            0, 0, "c={:d}".format(c[p].item()), color='black',
                            backgroundcolor='white', fontsize=8)
                    plt.imshow(x[p].view(28, 28).cpu().data.numpy())
                    plt.axis('off')

                if not os.path.exists(os.path.join(args.fig_root, str(ts))):
                    if not (os.path.exists(os.path.join(args.fig_root))):
                        os.mkdir(os.path.join(args.fig_root))
                    os.mkdir(os.path.join(args.fig_root, str(ts)))

                plt.savefig(
                    os.path.join(args.fig_root, str(ts),
                                 "E{:d}I{:d}.png".format(epoch, iteration)),
                    dpi=300)
                plt.clf()
                plt.close('all')
                """

    # Batch test
    x1, x2, y = (torch.FloatTensor(X1_val).to(device),
                 torch.FloatTensor(X2_val).to(device),
                 torch.FloatTensor(y_val).to(device))
    if args.conditional:
        x2_hat, z, mean, log_var = predictor(x1, y)
    else:
        x2_hat, z, mean, log_var = predictor(x1)
    val_loss = loss_fn(args.model, x2_hat, x2, mean, log_var)
    print("val_loss: {:9.4f}".format(val_loss.item()))

    """