def propose_H(self, dataset):
    config = self.get_base_config(dataset)

    from models import get_ref_model_path
    h_path = get_ref_model_path(self.args, config.model.__class__.__name__, dataset.name)
    best_h_path = path.join(h_path, 'model.best.pth')

    if not path.isfile(best_h_path):
        raise NotImplementedError("Please use model_setup to pretrain the networks first!")
    else:
        print(colored('Loading H1 model from %s' % best_h_path, 'red'))
        config.model.load_state_dict(torch.load(best_h_path))

    # trainer.run_epoch(0, phase='all')
    # test_average_acc = config.logger.get_measure('all_accuracy').mean_epoch(epoch=0)
    # print("All average accuracy %s" % colored('%.4f%%' % (test_average_acc * 100), 'red'))

    self.base_model = MahaModelWrapper(config.model, 2, intermediate_nodes=(11,))
    loader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=True,
                        num_workers=self.args.workers, pin_memory=True)
    self.base_model.collect_states(loader, self.args.device)
    self.base_model.eval()
def propose_H(self, dataset):
    config = self.get_base_config(dataset)

    import models as Models
    if self.default_model == 0:
        config.model.netid = "BCE." + config.model.netid
    else:
        config.model.netid = "MSE." + config.model.netid

    home_path = Models.get_ref_model_path(self.args, config.model.__class__.__name__,
                                          dataset.name, suffix_str=config.model.netid)
    best_h_path = path.join(home_path, 'model.best.pth')

    trainer = IterativeTrainer(config, self.args)

    if not path.isfile(best_h_path):
        raise NotImplementedError("%s not found! Please use setup_model to pretrain the networks first!" % best_h_path)
    else:
        print(colored('Loading H1 model from %s' % best_h_path, 'red'))
        config.model.load_state_dict(torch.load(best_h_path))

    trainer.run_epoch(0, phase='all')
    test_loss = config.logger.get_measure('all_loss').mean_epoch(epoch=0)
    print("All average loss %s" % colored('%.4f' % test_loss, 'red'))

    self.base_model = config.model
    self.base_model.eval()
def propose_H(self, dataset):
    config = self.get_base_config(dataset)

    # Wrap the model in KWayLogisticWrapper.
    original_class_name = config.model.__class__.__name__
    config.model = KWayLogisticWrapper(config.model)
    config.model = config.model.to(self.args.device)

    h_path = Models.get_ref_model_path(self.args, original_class_name, dataset.name,
                                       suffix_str='KLogistic')
    best_h_path = path.join(h_path, 'model.best.pth')

    trainer = IterativeTrainer(config, self.args)

    if not path.isfile(best_h_path):
        raise NotImplementedError("Please use setup_model to pretrain the networks first!")
    else:
        print(colored('Loading H1 model from %s' % best_h_path, 'red'))
        config.model.load_state_dict(torch.load(best_h_path))

    trainer.run_epoch(0, phase='all')
    test_average_acc = config.logger.get_measure('all_accuracy').mean_epoch(epoch=0)
    print("All average accuracy %s" % colored('%.4f%%' % (test_average_acc * 100), 'red'))

    self.base_model = config.model
    self.base_model.eval()
def train_autoencoder(args, model, dataset, BCE_Loss):
    if BCE_Loss:
        model.netid = "BCE." + model.netid
    else:
        model.netid = "MSE." + model.netid

    home_path = Models.get_ref_model_path(args, model.__class__.__name__, dataset.name,
                                          model_setup=True, suffix_str=model.netid)
    hbest_path = os.path.join(home_path, 'model.best.pth')
    hlast_path = os.path.join(home_path, 'model.last.pth')

    if not os.path.isdir(home_path):
        os.makedirs(home_path)

    if not os.path.isfile(hbest_path + ".done"):
        config = get_ae_config(args, model, dataset, BCE_Loss=BCE_Loss)
        trainer = IterativeTrainer(config, args)

        print(colored('Training from scratch', 'green'))
        best_loss = 999999999
        for epoch in range(1, config.max_epoch + 1):
            # Track the learning rates.
            lrs = [float(param_group['lr']) for param_group in config.optim.param_groups]
            config.logger.log('LRs', lrs, epoch)
            config.logger.get_measure('LRs').legend = ['LR%d' % i for i in range(len(lrs))]

            # One epoch of train and test.
            trainer.run_epoch(epoch, phase='train')
            trainer.run_epoch(epoch, phase='test')

            train_loss = config.logger.get_measure('train_loss').mean_epoch()
            test_loss = config.logger.get_measure('test_loss').mean_epoch()
            config.scheduler.step(train_loss)

            if config.visualize:
                # Show the average losses for all the phases in one figure.
                config.logger.visualize_average_keys('.*_loss', 'Average Loss', trainer.visdom)
                config.logger.visualize_average_keys('.*_accuracy', 'Average Accuracy', trainer.visdom)
                config.logger.visualize_average('LRs', trainer.visdom)

            # Save the logger for future reference.
            torch.save(config.logger.measures, os.path.join(home_path, 'logger.pth'))

            # Saving a checkpoint. Enable if needed!
            # if args.save and epoch % 10 == 0:
            #     print('Saving a %s at iter %s' % (colored('snapshot', 'yellow'), colored('%d' % epoch, 'yellow')))
            #     torch.save(config.model.state_dict(), os.path.join(home_path, 'model.%d.pth' % epoch))

            if args.save and test_loss < best_loss:
                print('Updating the on-file model with %s' % (colored('%.4f' % test_loss, 'red')))
                best_loss = test_loss
                torch.save(config.model.state_dict(), hbest_path)
                torch.save({'finished': True}, hbest_path + ".done")
            torch.save(config.model.state_dict(), hlast_path)

        if config.visualize:
            trainer.visdom.save([trainer.visdom.env])
    else:
        print("Skipping %s" % (colored(home_path, 'yellow')))
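# Usage sketch, not part of the original scripts: one plausible way to drive
# train_autoencoder above for both loss variants. `Global.dataset_reference_autoencoders`
# and get_D1_train() appear elsewhere in this section; `args`, the dataset, and the
# helper name setup_autoencoders are assumptions for illustration.
def setup_autoencoders(args, dataset):
    for model_class in Global.dataset_reference_autoencoders[args.dataset]:
        # Use fresh instances each time: train_autoencoder prefixes model.netid
        # in place, yielding the "BCE.<netid>" / "MSE.<netid>" reference paths.
        train_autoencoder(args, model_class(), dataset, BCE_Loss=True)
        train_autoencoder(args, model_class(), dataset, BCE_Loss=False)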
def propose_H(self, dataset):
    config = self.get_base_config(dataset)

    from models import get_ref_model_path
    h_path = get_ref_model_path(self.args, config.model.__class__.__name__, dataset.name)
    best_h_path = path.join(h_path, 'model.best.pth')

    trainer = IterativeTrainer(config, self.args)

    if not path.isfile(best_h_path):
        raise NotImplementedError("Please use model_setup to pretrain the networks first!")
    else:
        print(colored('Loading H1 model from %s' % best_h_path, 'red'))
        config.model.load_state_dict(torch.load(best_h_path))

    trainer.run_epoch(0, phase='all')
    test_average_acc = config.logger.get_measure('all_accuracy').mean_epoch(epoch=0)
    print("All average accuracy %s" % colored('%.4f%%' % (test_average_acc * 100), 'red'))

    self.base_model = config.model
    self.base_model.eval()
def get_base_config(self, dataset):
    print("Preparing training D1 for %s" % (dataset.parent_dataset.__class__.__name__))

    all_loader = DataLoader(dataset, batch_size=self.args.batch_size,
                            num_workers=self.args.workers, pin_memory=True)

    # Set up the criterion
    criterion = nn.NLLLoss().cuda()

    # Set up the model
    model_class = Global.get_ref_classifier(dataset.name)[self.default_model]
    self.add_identifier = model_class.__name__

    # We must create 5 instances of this class.
    from models import get_ref_model_path
    all_models = []
    for mid in range(5):
        model = model_class()
        model = DeepEnsembleWrapper(model)
        model = model.to(self.args.device)
        h_path = get_ref_model_path(self.args, model_class.__name__, dataset.name,
                                    suffix_str='DE.%d' % mid)
        best_h_path = path.join(h_path, 'model.best.pth')
        if not path.isfile(best_h_path):
            raise NotImplementedError("Please use setup_model to pretrain the networks first! Can't find %s" % best_h_path)
        else:
            print(colored('Loading H1 model from %s' % best_h_path, 'red'))
            model.load_state_dict(torch.load(best_h_path))
            model.eval()
        all_models.append(model)
    master_model = DeepEnsembleMasterWrapper(all_models)

    # Set up the config
    config = IterativeTrainerConfig()
    config.name = '%s-CLS' % (self.args.D1)
    config.phases = {
        'all': {'dataset': all_loader, 'backward': False},
    }
    config.criterion = criterion
    config.classification = True
    config.cast_float_label = False
    config.stochastic_gradient = True
    config.model = master_model
    config.optim = None
    config.autoencoder_target = False
    config.visualize = False
    config.logger = Logger()
    return config
def propose_H(self, dataset):
    assert self.default_model > 0, 'KNN needs K>0'

    if self.base_model is not None:
        self.base_model.base_data = None
        self.base_model = None

    # Set up the base0 model.
    base_model = Global.get_ref_classifier(dataset.name)[0]().to(self.args.device)

    from models import get_ref_model_path
    home_path = get_ref_model_path(self.args, base_model.__class__.__name__, dataset.name)
    best_h_path = path.join(home_path, 'model.best.pth')

    print(colored('Loading H1 model from %s' % best_h_path, 'red'))
    base_model.load_state_dict(torch.load(best_h_path))
    base_model.eval()

    if dataset.name in Global.mirror_augment:
        print(colored("Mirror augmenting %s" % dataset.name, 'green'))
        new_train_ds = dataset + MirroredDataset(dataset)
        dataset = new_train_ds

    # Initialize the multi-threaded loaders.
    all_loader = DataLoader(dataset, batch_size=self.args.batch_size,
                            num_workers=1, pin_memory=True)

    n_data = len(dataset)
    n_dim = base_model.partial_forward(dataset[0][0].to(self.args.device).unsqueeze(0)).numel()
    print('nHidden %d' % (n_dim))

    self.base_data = torch.zeros(n_data, n_dim, dtype=torch.float32)
    base_ind = 0
    with torch.set_grad_enabled(False):
        with tqdm(total=len(all_loader),
                  disable=bool(os.environ.get("DISABLE_TQDM", False))) as pbar:
            pbar.set_description('Caching X_train for %d-nn' % self.default_model)
            for i, (x, _) in enumerate(all_loader):
                n_data = x.size(0)
                output = base_model.partial_forward(x.to(self.args.device)).data
                self.base_data[base_ind:base_ind + n_data].copy_(output)
                base_ind = base_ind + n_data
                pbar.update()
    # self.base_data = torch.cat([x.view(1, -1) for x, _ in dataset])
    self.base_model = AEKNNModel(base_model, self.base_data,
                                 k=self.default_model, SV=True).to(self.args.device)
    self.base_model.eval()
def needs_processing(args, dataset_class, models, suffix):
    """
    Check whether every (model, suffix) pair has already been trained;
    returns True if at least one of them still needs processing.
    """
    for model in models:
        for suf in suffix:
            home_path = Models.get_ref_model_path(args, model.__name__, dataset_class.__name__,
                                                  model_setup=True, suffix_str=suf)
            hbest_path = os.path.join(home_path, 'model.best.pth.done')
            if not os.path.isfile(hbest_path):
                return True
    return False
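# Usage sketch, not part of the original scripts: needs_processing acts as a
# cheap gate before any expensive setup. The 'DE.%d' suffixes match the
# deep-ensemble code in this section; setup_ensemble is a hypothetical
# driver name, and train_classifier is the function defined further below.
def setup_ensemble(args, dataset_class, model_classes):
    suffixes = ['DE.%d' % mid for mid in range(5)]  # one suffix per ensemble member
    if not needs_processing(args, dataset_class, model_classes, suffixes):
        print("All ensemble members already trained; skipping.")
        return
    for model_class in model_classes:
        train_classifier(args, model_class(), dataset_class())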
def get_base_config(self, dataset):
    print("Preparing training D1 for %s" % (dataset.parent_dataset.__class__.__name__))

    all_loader = DataLoader(dataset, batch_size=self.args.batch_size,
                            num_workers=self.args.workers, pin_memory=True)

    # Set up the model
    model = Global.get_ref_pixelcnn(dataset.name)[self.default_model]().to(self.args.device)
    self.add_identifier = model.__class__.__name__

    # Load the snapshot
    from models import get_ref_model_path
    h_path = get_ref_model_path(self.args, model.__class__.__name__, dataset.name,
                                suffix_str=model.netid)
    best_h_path = path.join(h_path, 'model.best.pth')
    if not path.isfile(best_h_path):
        raise NotImplementedError("Please use setup_model to pretrain the networks first! Can't find %s" % best_h_path)
    else:
        print(colored('Loading H1 model from %s' % best_h_path, 'red'))
        model.load_state_dict(torch.load(best_h_path))
        model.eval()

    # Set up the criterion
    criterion = PCNN_Loss(one_d=(model.input_channels == 1)).to(self.args.device)

    # Set up the config
    config = IterativeTrainerConfig()
    config.name = '%s-pcnn' % (self.args.D1)
    config.phases = {
        'all': {'dataset': all_loader, 'backward': False},
    }
    config.criterion = criterion
    config.classification = False
    config.cast_float_label = False
    config.autoencoder_target = True
    config.stochastic_gradient = True
    config.model = model
    config.optim = None
    config.visualize = False
    config.logger = Logger()
    return config
def run_epoch(self, epoch, phase='train'):
    # Retrieve the appropriate config.
    config = self.config.phases[phase]
    dataset = config['dataset']
    backward = config['backward']
    phase_name = phase
    print("Doing %s" % colored(phase, 'green'))
    model = self.config.model
    visualize = self.config.visualize
    criterion = self.config.criterion
    optimizer = self.config.optim
    logger = self.config.logger
    stochastic = self.config.stochastic_gradient
    classification = self.config.classification

    home_path = Models.get_ref_model_path(self.args, model.__class__.__name__, self.config.name,
                                          model_setup=True, suffix_str="CCC")
    dump_path = os.path.join(home_path, 'dump')
    if not os.path.isdir(dump_path):
        os.makedirs(dump_path)

    # Set the network to the target mode.
    if backward:
        model.train()
        torch.set_grad_enabled(True)
    else:
        model.eval()
        torch.set_grad_enabled(False)

    start_time = timeit.default_timer()
    last_viz_update = start_time

    # For full-gradient optimization we need to rescale the loss
    # to calculate the gradient correctly.
    loss_scaler = 1
    if not stochastic:
        loss_scaler = 1. / len(dataset.dataset)

    try:
        # TQDM sometimes throws IOError exceptions when you
        # try to close it. We ignore those exceptions.
        with tqdm(total=len(dataset)) as pbar:
            if backward and not stochastic:
                optimizer.zero_grad()

            for i, (image, label) in enumerate(dataset):
                pbar.update()
                if backward and stochastic:
                    optimizer.zero_grad()

                # Get and prepare data.
                input, target, data_indices = image, None, None
                if torch.typename(label) == 'list':
                    assert len(label) == 2, 'There should be two entries in the label'
                    # Need to unpack the label. This is for when the data provider
                    # has the cached flag enabled, therefore the y is now (y, idx).
                    target, data_indices = label
                else:
                    target = label

                if self.config.autoencoder_target:
                    target = input.clone()

                if self.config.cast_float_label:
                    target = target.float().unsqueeze(1)

                input, target = input.to(self.device), target.to(model.get_output_device())

                # Do a forward propagation and get the loss.
                prediction = None
                if data_indices is None:
                    prediction = model(input)
                else:
                    # Run in the cached mode. This is necessary to speed up
                    # some of the underlying optimization procedures. It is not
                    # always used though.
                    prediction = model(input, indices=data_indices, group=phase_name)

                loss = criterion(prediction, target)

                if self.args.dump_images:
                    # Pick one sample from the batch and dump it to disk.
                    if self.config.autoencoder_target:
                        filename = phase_name + str(i) + "_epoch" + str(epoch) + ".png"
                        dump_file = os.path.join(dump_path, filename)
                        self.dump_image(input[0].cpu(), dump_file, True)
                        # If this is an autoencoder run, also dump the reconstruction for comparison.
                        filename = phase_name + str(i) + "_epoch" + str(epoch) + "_target.png"
                        dump_file = os.path.join(dump_path, filename)
                        self.dump_image(prediction[0].cpu(), dump_file, True)

                if backward:
                    if stochastic:
                        loss.backward()
                        optimizer.step()
    except IOError as e:
        if e.errno != errno.EINTR:
            raise
        else:
            print(colored("Problem averted :D", 'green'))
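# Usage sketch, not part of the original scripts: a single evaluation pass
# through run_epoch with a backward=False phase, the same pattern the
# propose_H methods in this section use to report accuracy after loading a
# pretrained model. `args`, `model`, and `dataset` are assumed to be prepared
# by the surrounding script.
config = get_classifier_config(args, model, dataset)
trainer = IterativeTrainer(config, args)
trainer.run_epoch(0, phase='all')  # the 'all' phase is defined with backward=False
acc = config.logger.get_measure('all_accuracy').mean_epoch(epoch=0)
print("All average accuracy %.4f%%" % (acc * 100))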
matplotlib.use('Agg')
import matplotlib.pyplot as plt

if __name__ == "__main__":
    dataset = PCAM(root_path=os.path.join(args.root_path, "pcam"),
                   extract=True, downsample=64).get_D1_train()
    dataloader = torch.utils.data.DataLoader(dataset, args.batch_size, True,
                                             num_workers=args.workers, pin_memory=True)
    model = ALIModel(dims=(3, 64, 64)).cuda()

    home_path = Models.get_ref_model_path(args, model.__class__.__name__, dataset.name,
                                          model_setup=True, suffix_str='base0')
    logger = Logger(home_path)
    hbest_path = os.path.join(home_path, 'model.best.pth')
    if not os.path.isdir(home_path):
        os.makedirs(home_path)
    best_gen_loss = 9999

    if not os.path.isfile(hbest_path + ".done"):
        print(colored('Training from scratch', 'green'))
        best_loss = -1
        optimizerG = optim.Adam([{
            'params': model.GenX.parameters()
def get_classifier_config(args, model, dataset, balanced=False):
    print("Preparing training D1 for %s" % (dataset.name))

    # 80%/20% split for local train+test.
    train_ds, valid_ds = dataset.split_dataset(0.8)

    if dataset.name in Global.mirror_augment:
        print(colored("Mirror augmenting %s" % dataset.name, 'green'))
        new_train_ds = train_ds + MirroredDataset(train_ds)
        train_ds = new_train_ds

    # Initialize the multi-threaded loaders.
    if balanced:
        # Weight each sample by the inverse frequency of its class so that
        # every class is drawn with equal probability in expectation.
        y_train = []
        for x, y in train_ds:
            y_train.append(y.numpy())
        y_train = np.array(y_train)
        class_sample_count = np.array(
            [len(np.where(y_train == t)[0]) for t in np.unique(y_train)])
        print(class_sample_count)
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in y_train])
        samples_weight = torch.from_numpy(samples_weight)
        sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'),
                                        len(samples_weight))
        train_loader = DataLoader(train_ds, batch_size=args.batch_size,
                                  num_workers=args.workers, pin_memory=True, sampler=sampler)

        y_val = []
        for x, y in valid_ds:
            y_val.append(y.numpy())
        y_val = np.array(y_val)
        class_sample_count = np.array(
            [len(np.where(y_val == t)[0]) for t in np.unique(y_val)])
        print(class_sample_count)
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in y_val])
        samples_weight = torch.from_numpy(samples_weight)
        sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'),
                                        len(samples_weight))
        valid_loader = DataLoader(valid_ds, batch_size=args.batch_size,
                                  num_workers=args.workers, pin_memory=True, sampler=sampler)
    else:
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.workers, pin_memory=True)
        valid_loader = DataLoader(valid_ds, batch_size=args.batch_size,
                                  num_workers=args.workers, pin_memory=True)
    all_loader = DataLoader(dataset, batch_size=args.batch_size,
                            num_workers=args.workers, pin_memory=True)

    # Set up the criterion
    criterion = nn.NLLLoss().to(args.device)

    # Set up the model
    model = model.to(args.device)

    # Set up the config
    config = IterativeTrainerConfig()
    config.name = 'classifier_%s_%s' % (dataset.name, model.__class__.__name__)
    config.train_loader = train_loader
    config.valid_loader = valid_loader
    config.phases = {
        'train': {'dataset': train_loader, 'backward': True},
        'test': {'dataset': valid_loader, 'backward': False},
        'all': {'dataset': all_loader, 'backward': False},
    }
    config.criterion = criterion
    config.classification = True
    config.stochastic_gradient = True
    config.visualize = not args.no_visualize
    config.model = model
    home_path = Models.get_ref_model_path(args, config.model.__class__.__name__, dataset.name,
                                          model_setup=True, suffix_str='base0')
    config.logger = Logger(home_path)
    config.optim = optim.Adam(model.parameters(), lr=1e-3)
    config.scheduler = optim.lr_scheduler.ReduceLROnPlateau(config.optim, patience=10,
                                                            threshold=1e-2, min_lr=1e-6,
                                                            factor=0.1, verbose=True)
    config.max_epoch = 120

    if hasattr(model, 'train_config'):
        model_train_config = model.train_config()
        for key, value in model_train_config.items():
            print('Overriding config.%s' % key)
            config.__setattr__(key, value)
    return config
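# Toy illustration, not part of the original scripts: the inverse-frequency
# weighting computed in the balanced branch above. With labels [0, 0, 0, 1],
# class counts are [3, 1], so each class 1 sample gets 3x the draw weight and
# the WeightedRandomSampler draws both classes equally often in expectation
# (total weight per class is 1.0 either way).
import numpy as np
y = np.array([0, 0, 0, 1])
counts = np.array([len(np.where(y == t)[0]) for t in np.unique(y)])  # [3, 1]
weight = 1. / counts                                                 # [0.333..., 1.0]
samples_weight = np.array([weight[t] for t in y])
print(samples_weight)  # approximately [0.333 0.333 0.333 1.0]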
D1 = D164.get_D1_train()

emb = args.embedding_function.lower()
assert emb in ["vae", "ae", "ali"]

dummy_args = EasyDict()
dummy_args.exp = "foo"
dummy_args.experiment_path = args.experiment_path

if args.encoder_loss.lower() == "bce":
    tag = "BCE"
else:
    tag = "MSE"

if emb == "vae":
    model = Global.dataset_reference_vaes[args.dataset][0]()
    home_path = Models.get_ref_model_path(dummy_args, model.__class__.__name__, D164.name,
                                          suffix_str=tag + "." + model.netid)
    model_path = os.path.join(home_path, 'model.best.pth')
elif emb == "ae":
    model = Global.dataset_reference_autoencoders[args.dataset][0]()
    home_path = Models.get_ref_model_path(dummy_args, model.__class__.__name__, D164.name,
                                          suffix_str=tag + "." + model.netid)
    model_path = os.path.join(home_path, 'model.best.pth')
else:
    model = Global.dataset_reference_ALI[args.dataset][0]()
    home_path = Models.get_ref_model_path(dummy_args,
def Train_ALI(args, model, dataset, BCE_Loss=True):
    dataloader = torch.utils.data.DataLoader(dataset, args.batch_size, True,
                                             num_workers=args.workers, pin_memory=True)

    home_path = Models.get_ref_model_path(args, model.__class__.__name__, dataset.name,
                                          model_setup=True, suffix_str='base0')
    logger = Logger(home_path)
    hbest_path = os.path.join(home_path, 'model.best.pth')
    if not os.path.isdir(home_path):
        os.makedirs(home_path)
    best_gen_loss = 9999

    if not os.path.isfile(hbest_path + ".done"):
        print(colored('Training from scratch', 'green'))
        optimizerG = optim.Adam([{
            'params': model.GenX.parameters()
        }, {
            'params': model.GenZ.parameters()
        }], lr=args.lr, betas=(args.beta1, args.beta2))
        optimizerD = optim.Adam([{
            'params': model.DisZ.parameters()
        }, {
            'params': model.DisX.parameters()
        }, {
            'params': model.DisXZ.parameters()
        }], lr=args.lr, betas=(args.beta1, args.beta2))

        if BCE_Loss:
            criterion = nn.BCELoss()
        else:
            criterion = nn.MSELoss()

        for epoch in range(1, 100 + 1):
            model.train()
            with tqdm(total=len(dataloader),
                      disable=bool(os.environ.get("DISABLE_TQDM", False))) as pbar:
                for i, (x, y) in enumerate(dataloader):
                    pbar.update()
                    batchsize = x.shape[0]
                    fakeZ = torch.randn(batchsize, 512, 1, 1).cuda()
                    pred_real, pred_fake = model.forward(x.cuda(), fakeZ)

                    # One-sided label smoothing by default; optionally randomize the labels.
                    truelabel = torch.ones(batchsize) - 0.1
                    fakelabel = torch.zeros(batchsize)
                    if args.random_label:
                        truelabel = torch.randint(low=70, high=110, size=(1, batchsize))[0] / 100
                        fakelabel = torch.randint(low=-10, high=30, size=(1, batchsize))[0] / 100
                    truelabel = truelabel.cuda()
                    fakelabel = fakelabel.cuda()

                    loss_d = criterion(pred_real.view(-1), truelabel) + \
                             criterion(pred_fake.view(-1), fakelabel)
                    loss_g = criterion(pred_fake.view(-1), truelabel) + \
                             criterion(pred_real.view(-1), fakelabel)
                    logger.log('Disc_loss', loss_d.item(), epoch, i)
                    logger.log('Gen_loss', loss_g.item(), epoch, i)

                    # Keep D and G balanced: update only G when it lags far behind,
                    # only D when G is too strong, otherwise update both.
                    if loss_g > args.max_loss_g:
                        optimizerG.zero_grad()
                        loss_g.backward()
                        optimizerG.step()
                        pbar.set_description("Skipped D, Disc_loss %.4f, Gen_loss %.4f" %
                                             (loss_d.item(), loss_g.item()))
                    elif loss_g < args.min_loss_g:
                        optimizerD.zero_grad()
                        loss_d.backward()
                        optimizerD.step()
                        pbar.set_description("Skipped G, Disc_loss %.4f, Gen_loss %.4f" %
                                             (loss_d.item(), loss_g.item()))
                    else:
                        optimizerD.zero_grad()
                        loss_d.backward(retain_graph=True)
                        optimizerD.step()
                        optimizerG.zero_grad()
                        loss_g.backward()
                        optimizerG.step()
                        pbar.set_description("Disc_loss %.4f, Gen_loss %.4f" %
                                             (loss_d.item(), loss_g.item()))

            disc_loss = logger.get_measure('Disc_loss').mean_epoch()
            gen_loss = logger.get_measure('Gen_loss').mean_epoch()
            print("Discriminator loss %.4f, Generator loss %.4f" % (disc_loss, gen_loss))
            logger.writer.add_scalar('disc_loss', disc_loss, epoch)
            logger.writer.add_scalar('gen_loss', gen_loss, epoch)

            # Visualize reconstructions in TensorBoard.
            for (image, label) in dataloader:
                prediction = model(x=image.cuda()).data.cpu().squeeze().numpy()
                N = min(prediction.shape[0], 5)
                fig, ax = plt.subplots(N, 2)
                image = image.data.squeeze().numpy()
                for i in range(N):
                    ax[i, 0].imshow(prediction[i])
                    ax[i, 1].imshow(image[i])
                logger.writer.add_figure('Vis', fig, epoch)
                plt.close(fig)
                break

            torch.save(logger.measures, os.path.join(home_path, 'logger.pth'))
            if args.save and gen_loss < best_gen_loss:
                print('Updating the on-file model with %s' %
                      (colored('%.4f' % gen_loss, 'red')))
                best_gen_loss = gen_loss
                torch.save(model.state_dict(), hbest_path)
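# Usage sketch, not part of the original scripts: Train_ALI wired up the same
# way as the PCAM script earlier in this section. The `args` fields Train_ALI
# reads are batch_size, workers, lr, beta1, beta2, random_label, max_loss_g,
# min_loss_g, and save; `args` itself is assumed to come from the surrounding
# argument parser.
model = ALIModel(dims=(3, 64, 64)).cuda()
dataset = PCAM(root_path=os.path.join(args.root_path, "pcam"),
               extract=True, downsample=64).get_D1_train()
Train_ALI(args, model, dataset, BCE_Loss=True)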
def propose_H(self, dataset):
    assert self.default_model > 0, 'KNN needs K>0'

    if self.base_model is not None:
        self.base_model.base_data = None
        self.base_model = None

    # Set up the base model.
    if isinstance(self, BCEKNNSVM) or isinstance(self, MSEKNNSVM):
        base_model = Global.get_ref_autoencoder(dataset.name)[0]().to(self.args.device)
        if isinstance(self, BCEKNNSVM):
            base_model.netid = "BCE." + base_model.netid
        else:
            base_model.netid = "MSE." + base_model.netid
        home_path = Models.get_ref_model_path(self.args, base_model.__class__.__name__,
                                              dataset.name, suffix_str=base_model.netid)
    elif isinstance(self, VAEKNNSVM):
        base_model = Global.get_ref_vae(dataset.name)[0]().to(self.args.device)
        home_path = Models.get_ref_model_path(self.args, base_model.__class__.__name__,
                                              dataset.name, suffix_str=base_model.netid)
    else:
        raise NotImplementedError()

    best_h_path = path.join(home_path, 'model.best.pth')
    print(colored('Loading H1 model from %s' % best_h_path, 'red'))
    base_model.load_state_dict(torch.load(best_h_path))
    base_model.eval()

    if dataset.name in Global.mirror_augment:
        print(colored("Mirror augmenting %s" % dataset.name, 'green'))
        new_train_ds = dataset + MirroredDataset(dataset)
        dataset = new_train_ds

    # Initialize the multi-threaded loaders.
    all_loader = DataLoader(dataset, batch_size=self.args.batch_size,
                            num_workers=1, pin_memory=True)

    n_data = len(dataset)
    n_dim = base_model.encode(dataset[0][0].to(self.args.device).unsqueeze(0)).numel()
    print('nHidden %d' % (n_dim))

    self.base_data = torch.zeros(n_data, n_dim, dtype=torch.float32)
    base_ind = 0
    with torch.set_grad_enabled(False):
        with tqdm(total=len(all_loader)) as pbar:
            pbar.set_description('Caching X_train for %d-nn' % self.default_model)
            for i, (x, _) in enumerate(all_loader):
                n_data = x.size(0)
                output = base_model.encode(x.to(self.args.device)).data
                self.base_data[base_ind:base_ind + n_data].copy_(output)
                base_ind = base_ind + n_data
                pbar.update()
    # self.base_data = torch.cat([x.view(1, -1) for x, _ in dataset])
    self.base_model = AEKNNModel(base_model, self.base_data,
                                 k=self.default_model).to(self.args.device)
    self.base_model.eval()
def train_classifier(args, model, dataset):
    config = None
    for mid in range(5):
        home_path = Models.get_ref_model_path(args, model.__class__.__name__, dataset.name,
                                              model_setup=True, suffix_str='DE.%d' % mid)
        hbest_path = os.path.join(home_path, 'model.best.pth')

        if not os.path.isdir(home_path):
            os.makedirs(home_path)
        else:
            if os.path.isfile(hbest_path + ".done"):
                print("Skipping %s" % (colored(home_path, 'yellow')))
                continue

        config = get_classifier_config(args, model.__class__(), dataset, mid=mid)
        trainer = IterativeTrainer(config, args)

        if not os.path.isfile(hbest_path + ".done"):
            print(colored('Training from scratch', 'green'))
            best_accuracy = -1
            for epoch in range(1, config.max_epoch + 1):
                # Track the learning rates.
                lrs = [float(param_group['lr']) for param_group in config.optim.param_groups]
                config.logger.log('LRs', lrs, epoch)
                config.logger.get_measure('LRs').legend = ['LR%d' % i for i in range(len(lrs))]

                # One epoch of train and test.
                trainer.run_epoch(epoch, phase='train')
                trainer.run_epoch(epoch, phase='test')

                train_loss = config.logger.get_measure('train_loss').mean_epoch()
                config.scheduler.step(train_loss)

                if config.visualize:
                    # Show the average losses for all the phases in one figure.
                    config.logger.visualize_average_keys('.*_loss', 'Average Loss', trainer.visdom)
                    config.logger.visualize_average_keys('.*_accuracy', 'Average Accuracy', trainer.visdom)
                    config.logger.visualize_average('LRs', trainer.visdom)

                test_average_acc = config.logger.get_measure('test_accuracy').mean_epoch()

                # Save the logger for future reference.
                torch.save(config.logger.measures, os.path.join(home_path, 'logger.pth'))

                # Saving a checkpoint. Enable if needed!
                # if args.save and epoch % 10 == 0:
                #     print('Saving a %s at iter %s' % (colored('snapshot', 'yellow'), colored('%d' % epoch, 'yellow')))
                #     torch.save(config.model.state_dict(), os.path.join(home_path, 'model.%d.pth' % epoch))

                if args.save and best_accuracy < test_average_acc:
                    print('Updating the on-file model with %s' %
                          (colored('%.4f' % test_average_acc, 'red')))
                    best_accuracy = test_average_acc
                    torch.save(config.model.state_dict(), hbest_path)
                    torch.save({'finished': True}, hbest_path + ".done")

            if config.visualize:
                trainer.visdom.save([trainer.visdom.env])
        else:
            print("Skipping %s" % (colored(home_path, 'yellow')))

        print("Loading the best model.")
        config.model.load_state_dict(torch.load(hbest_path))
        config.model.eval()

        trainer.run_epoch(0, phase='all')
        test_average_acc = config.logger.get_measure('all_accuracy').mean_epoch(epoch=0)
        print("All average accuracy %s" % colored('%.4f%%' % (test_average_acc * 100), 'red'))
def train_variational_autoencoder(args, model, dataset, BCE_Loss=True):
    if BCE_Loss:
        model.netid = "BCE." + model.netid
    else:
        model.netid = "MSE." + model.netid

    home_path = Models.get_ref_model_path(args, model.__class__.__name__, dataset.name,
                                          model_setup=True, suffix_str=model.netid)
    hbest_path = os.path.join(home_path, 'model.best.pth')
    hlast_path = os.path.join(home_path, 'model.last.pth')

    if not os.path.isdir(home_path):
        os.makedirs(home_path)

    if not os.path.isfile(hbest_path + ".done"):
        config = get_vae_config(args, model, dataset, home_path, BCE_Loss)
        trainer = IterativeTrainer(config, args)

        print(colored('Training from scratch', 'green'))
        best_loss = 999999999
        for epoch in range(1, config.max_epoch + 1):
            # Track the learning rates.
            lrs = [float(param_group['lr']) for param_group in config.optim.param_groups]
            config.logger.log('LRs', lrs, epoch)
            config.logger.get_measure('LRs').legend = ['LR%d' % i for i in range(len(lrs))]

            # One epoch of train and test.
            trainer.run_epoch(epoch, phase='train')
            trainer.run_epoch(epoch, phase='test')

            train_loss = config.logger.get_measure('train_loss').mean_epoch()
            test_loss = config.logger.get_measure('test_loss').mean_epoch()
            config.logger.writer.add_scalar('train_loss', train_loss, epoch)
            config.logger.writer.add_scalar('test_loss', test_loss, epoch)
            config.scheduler.step(train_loss)

            # Visualize reconstructions in TensorBoard.
            for (image, label) in config.valid_loader:
                prediction = model(image.cuda()).data.cpu().squeeze().numpy()
                prediction = (prediction - prediction.min()) / (prediction.max() - prediction.min())
                if len(prediction.shape) > 3 and prediction.shape[1] == 3:
                    prediction = prediction.transpose((0, 2, 3, 1))  # change to N W H C
                N = min(prediction.shape[0], 5)
                fig, ax = plt.subplots(N, 2)
                image = image.data.squeeze().numpy()
                image = (image - image.min()) / (image.max() - image.min())
                if len(image.shape) > 3 and image.shape[1] == 3:
                    image = image.transpose((0, 2, 3, 1))
                for i in range(N):
                    ax[i, 0].imshow(prediction[i])
                    ax[i, 1].imshow(image[i])
                config.logger.writer.add_figure('Vis', fig, epoch)
                plt.close(fig)
                break

            if config.visualize:
                # Show the average losses for all the phases in one figure.
                config.logger.visualize_average_keys('.*_loss', 'Average Loss', trainer.visdom)
                config.logger.visualize_average_keys('.*_accuracy', 'Average Accuracy', trainer.visdom)
                config.logger.visualize_average('LRs', trainer.visdom)

            # Save the logger for future reference.
            torch.save(config.logger.measures, os.path.join(home_path, 'logger.pth'))

            # Saving a checkpoint. Enable if needed!
            # if args.save and epoch % 10 == 0:
            #     print('Saving a %s at iter %s' % (colored('snapshot', 'yellow'), colored('%d' % epoch, 'yellow')))
            #     torch.save(config.model.state_dict(), os.path.join(home_path, 'model.%d.pth' % epoch))

            if args.save and test_loss < best_loss:
                print('Updating the on-file model with %s' %
                      (colored('%.4f' % test_loss, 'red')))
                best_loss = test_loss
                torch.save(config.model.state_dict(), hbest_path)
                torch.save({'finished': True}, hbest_path + ".done")
            torch.save(config.model.state_dict(), hlast_path)

        if config.visualize:
            trainer.visdom.save([trainer.visdom.env])
    else:
        print("Skipping %s" % (colored(home_path, 'yellow')))