def main():
    """Entry point: train and/or evaluate the SEM-PCYC model for zero-shot
    sketch-based image retrieval.

    Parses command-line options, builds dataset splits (Sketchy / TU-Berlin,
    optionally the generalized ZS-SBIR setting), constructs the paired train
    loader and the sketch/image test loaders, trains with early stopping on
    validation mAP@all, then reloads the best checkpoint and reports test
    metrics plus optional qualitative results.

    Relies on module-level names (Options, utils, transforms, DataGenerator*,
    SEM_PCYC, Logger, train, validate, DataLoader, WeightedRandomSampler,
    cudnn, os, np, torch) imported elsewhere in this file.
    """
    # Parse options
    args = Options().parse()
    print('Parameters:\t' + str(args))
    # Sanity-check mutually dependent options.
    if args.filter_sketch:
        assert args.dataset == 'Sketchy'
    if args.split_eccv_2018:
        assert args.dataset == 'Sketchy_extended' or args.dataset == 'Sketchy'
    if args.gzs_sbir:
        # Generalized ZS-SBIR is evaluation-only: force test mode.
        args.test = True

    # Read the config file for the dataset root and auxiliary output paths.
    config = utils.read_config()
    path_dataset = config['path_dataset']
    path_aux = config['path_aux']

    # Split a dataset name like 'Sketchy_extended' into base name + variant,
    # used below to pick the photo directory layout.
    ds_var = None
    if '_' in args.dataset:
        token = args.dataset.split('_')
        args.dataset = token[0]
        ds_var = token[1]

    # Sub-directory tag that keeps checkpoints/logs of different evaluation
    # protocols apart.
    str_aux = ''
    if args.split_eccv_2018:
        str_aux = 'split_eccv_2018'
    if args.gzs_sbir:
        str_aux = os.path.join(str_aux, 'generalized')

    # Sort so the same set of semantic models always maps to one path name.
    args.semantic_models = sorted(args.semantic_models)
    model_name = '+'.join(args.semantic_models)

    root_path = os.path.join(path_dataset, args.dataset)
    path_sketch_model = os.path.join(path_aux, 'CheckPoints', args.dataset, 'sketch')
    path_image_model = os.path.join(path_aux, 'CheckPoints', args.dataset, 'image')
    path_cp = os.path.join(path_aux, 'CheckPoints', args.dataset, str_aux, model_name, str(args.dim_out))
    path_log = os.path.join(path_aux, 'LogFiles', args.dataset, str_aux, model_name, str(args.dim_out))
    path_results = os.path.join(path_aux, 'Results', args.dataset, str_aux, model_name, str(args.dim_out))

    # Collect the semantic-embedding files and accumulate the total semantic
    # dimension (each .npy holds a dict class -> embedding vector).
    files_semantic_labels = []
    sem_dim = 0
    for f in args.semantic_models:
        fi = os.path.join(path_aux, 'Semantic', args.dataset, f + '.npy')
        files_semantic_labels.append(fi)
        sem_dim += list(np.load(fi, allow_pickle=True).item().values())[0].shape[0]

    print('Checkpoint path: {}'.format(path_cp))
    print('Logger path: {}'.format(path_log))
    print('Result path: {}'.format(path_results))

    # Parameters for transforming the images
    transform_image = transforms.Compose(
        [transforms.Resize((args.im_sz, args.im_sz)), transforms.ToTensor()])
    transform_sketch = transforms.Compose(
        [transforms.Resize((args.sk_sz, args.sk_sz)), transforms.ToTensor()])

    # Load the dataset
    print('Loading data...', end='')
    if args.dataset == 'Sketchy':
        if ds_var == 'extended':
            photo_dir = 'extended_photo'  # photo or extended_photo
            photo_sd = ''
        else:
            photo_dir = 'photo'
            photo_sd = 'tx_000000000000'
        sketch_dir = 'sketch'
        sketch_sd = 'tx_000000000000'
        splits = utils.load_files_sketchy_zeroshot(
            root_path=root_path, split_eccv_2018=args.split_eccv_2018,
            photo_dir=photo_dir, sketch_dir=sketch_dir, photo_sd=photo_sd,
            sketch_sd=sketch_sd)
    elif args.dataset == 'TU-Berlin':
        photo_dir = 'images'
        sketch_dir = 'sketches'
        photo_sd = ''
        sketch_sd = ''
        splits = utils.load_files_tuberlin_zeroshot(root_path=root_path,
                                                    photo_dir=photo_dir,
                                                    sketch_dir=sketch_dir,
                                                    photo_sd=photo_sd,
                                                    sketch_sd=sketch_sd)
    else:
        raise Exception('Wrong dataset.')

    # Generalized setting: mix a fraction (perc) of *seen*-class training
    # samples into the test set so retrieval is evaluated over seen + unseen.
    if args.gzs_sbir:
        perc = 0.2
        # De-duplicate training file lists before sampling from them.
        _, idx_sk = np.unique(splits['tr_fls_sk'], return_index=True)
        tr_fls_sk_ = splits['tr_fls_sk'][idx_sk]
        tr_clss_sk_ = splits['tr_clss_sk'][idx_sk]
        _, idx_im = np.unique(splits['tr_fls_im'], return_index=True)
        tr_fls_im_ = splits['tr_fls_im'][idx_im]
        tr_clss_im_ = splits['tr_clss_im'][idx_im]
        if args.dataset == 'Sketchy' and args.filter_sketch:
            # Keep a single sketch variant per photo ('name-1', 'name-2', ...).
            _, idx_sk = np.unique([f.split('-')[0] for f in tr_fls_sk_],
                                  return_index=True)
            tr_fls_sk_ = tr_fls_sk_[idx_sk]
            tr_clss_sk_ = tr_clss_sk_[idx_sk]
        idx_sk = np.sort(
            np.random.choice(tr_fls_sk_.shape[0],
                             int(perc * splits['te_fls_sk'].shape[0]),
                             replace=False))
        idx_im = np.sort(
            np.random.choice(tr_fls_im_.shape[0],
                             int(perc * splits['te_fls_im'].shape[0]),
                             replace=False))
        splits['te_fls_sk'] = np.concatenate(
            (tr_fls_sk_[idx_sk], splits['te_fls_sk']), axis=0)
        splits['te_clss_sk'] = np.concatenate(
            (tr_clss_sk_[idx_sk], splits['te_clss_sk']), axis=0)
        splits['te_fls_im'] = np.concatenate(
            (tr_fls_im_[idx_im], splits['te_fls_im']), axis=0)
        splits['te_clss_im'] = np.concatenate(
            (tr_clss_im_[idx_im], splits['te_clss_im']), axis=0)

    # class dictionary
    dict_clss = utils.create_dict_texts(splits['tr_clss_im'])

    data_train = DataGeneratorPaired(args.dataset, root_path, photo_dir,
                                     sketch_dir, photo_sd, sketch_sd,
                                     splits['tr_fls_sk'], splits['tr_fls_im'],
                                     splits['tr_clss_im'],
                                     transforms_sketch=transform_sketch,
                                     transforms_image=transform_image)
    data_test_sketch = DataGeneratorSketch(args.dataset, root_path, sketch_dir,
                                           sketch_sd, splits['te_fls_sk'],
                                           splits['te_clss_sk'],
                                           transforms=transform_sketch)
    data_test_image = DataGeneratorImage(args.dataset, root_path, photo_dir,
                                         photo_sd, splits['te_fls_im'],
                                         splits['te_clss_im'],
                                         transforms=transform_image)
    print('Done')

    # Weighted sampling balances classes within each epoch of fixed size.
    train_sampler = WeightedRandomSampler(data_train.get_weights(),
                                          num_samples=args.epoch_size * args.batch_size,
                                          replacement=True)

    # PyTorch train loader
    train_loader = DataLoader(dataset=data_train, batch_size=args.batch_size,
                              sampler=train_sampler,
                              num_workers=args.num_workers, pin_memory=True)
    # PyTorch test loader for sketch
    test_loader_sketch = DataLoader(dataset=data_test_sketch,
                                    batch_size=args.batch_size, shuffle=False,
                                    num_workers=args.num_workers,
                                    pin_memory=True)
    # PyTorch test loader for image
    test_loader_image = DataLoader(dataset=data_test_image,
                                   batch_size=args.batch_size, shuffle=False,
                                   num_workers=args.num_workers,
                                   pin_memory=True)

    # Model parameters
    params_model = dict()
    # Paths to pre-trained sketch and image models
    params_model['path_sketch_model'] = path_sketch_model
    params_model['path_image_model'] = path_image_model
    # Dimensions
    params_model['dim_out'] = args.dim_out
    params_model['sem_dim'] = sem_dim
    # Number of classes
    params_model['num_clss'] = len(dict_clss)
    # Weight (on losses) parameters
    params_model['lambda_se'] = args.lambda_se
    params_model['lambda_im'] = args.lambda_im
    params_model['lambda_sk'] = args.lambda_sk
    params_model['lambda_gen_cyc'] = args.lambda_gen_cyc
    params_model['lambda_gen_adv'] = args.lambda_gen_adv
    params_model['lambda_gen_cls'] = args.lambda_gen_cls
    params_model['lambda_gen_reg'] = args.lambda_gen_reg
    params_model['lambda_disc_se'] = args.lambda_disc_se
    params_model['lambda_disc_sk'] = args.lambda_disc_sk
    params_model['lambda_disc_im'] = args.lambda_disc_im
    params_model['lambda_regular'] = args.lambda_regular
    # Optimizers' parameters
    params_model['lr'] = args.lr
    params_model['momentum'] = args.momentum
    params_model['milestones'] = args.milestones
    params_model['gamma'] = args.gamma
    # Files with semantic labels
    params_model['files_semantic_labels'] = files_semantic_labels
    # Class dictionary
    params_model['dict_clss'] = dict_clss

    # Model
    sem_pcyc_model = SEM_PCYC(params_model)

    cudnn.benchmark = True

    # Logger
    print('Setting logger...', end='')
    logger = Logger(path_log, force=True)
    print('Done')

    # Check cuda
    print('Checking cuda...', end='')
    # FIX: the original used bitwise `&`, which binds tighter than `>`, so
    # the condition parsed as `args.ngpu > (0 & ...)` == `args.ngpu > 0` and
    # CUDA availability was never actually checked. Use logical `and`.
    if args.ngpu > 0 and torch.cuda.is_available():
        print('*Cuda exists*...', end='')
        sem_pcyc_model = sem_pcyc_model.cuda()
    print('Done')

    best_map = 0
    early_stop_counter = 0

    # Epoch for loop
    if not args.test:
        print('***Train***')
        for epoch in range(args.epochs):
            # NOTE(review): schedulers are stepped *before* training each
            # epoch; recent PyTorch expects scheduler.step() after the
            # optimizer step. Left unchanged because reordering would shift
            # the learning-rate schedule — confirm intent before changing.
            sem_pcyc_model.scheduler_gen.step()
            sem_pcyc_model.scheduler_disc.step()
            sem_pcyc_model.scheduler_ae.step()

            # train on training set
            losses = train(train_loader, sem_pcyc_model, epoch, args)

            # evaluate on validation set, map_ since map is already there
            print('***Validation***')
            valid_data = validate(test_loader_sketch, test_loader_image,
                                  sem_pcyc_model, epoch, args)
            map_ = np.mean(valid_data['aps@all'])
            print(
                'mAP@all on validation set after {0} epochs: {1:.4f} (real), {2:.4f} (binary)'
                .format(epoch + 1, map_, np.mean(valid_data['aps@all_bin'])))
            del valid_data

            # Keep the best-mAP checkpoint; stop after `early_stop` epochs
            # without improvement.
            if map_ > best_map:
                best_map = map_
                early_stop_counter = 0
                utils.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': sem_pcyc_model.state_dict(),
                        'best_map': best_map
                    }, directory=path_cp)
            else:
                if args.early_stop == early_stop_counter:
                    break
                early_stop_counter += 1

            # Logger step
            logger.add_scalar('semantic autoencoder loss',
                              losses['aut_enc'].avg)
            logger.add_scalar('generator adversarial loss',
                              losses['gen_adv'].avg)
            logger.add_scalar('generator cycle consistency loss',
                              losses['gen_cyc'].avg)
            logger.add_scalar('generator classification loss',
                              losses['gen_cls'].avg)
            logger.add_scalar('generator regression loss',
                              losses['gen_reg'].avg)
            logger.add_scalar('generator loss', losses['gen'].avg)
            logger.add_scalar('semantic discriminator loss',
                              losses['disc_se'].avg)
            logger.add_scalar('sketch discriminator loss',
                              losses['disc_sk'].avg)
            logger.add_scalar('image discriminator loss',
                              losses['disc_im'].avg)
            logger.add_scalar('discriminator loss', losses['disc'].avg)
            logger.add_scalar('mean average precision', map_)
            logger.step()

    # load the best model yet
    best_model_file = os.path.join(path_cp, 'model_best.pth')
    if os.path.isfile(best_model_file):
        print("Loading best model from '{}'".format(best_model_file))
        checkpoint = torch.load(best_model_file)
        epoch = checkpoint['epoch']
        best_map = checkpoint['best_map']
        sem_pcyc_model.load_state_dict(checkpoint['state_dict'])
        print("Loaded best model '{0}' (epoch {1}; mAP@all {2:.4f})".format(
            best_model_file, epoch, best_map))

        print('***Test***')
        valid_data = validate(test_loader_sketch, test_loader_image,
                              sem_pcyc_model, epoch, args)
        print(
            'Results on test set: mAP@all = {1:.4f}, Prec@100 = {0:.4f}, mAP@200 = {3:.4f}, Prec@200 = {2:.4f}, '
            'Time = {4:.6f} || mAP@all (binary) = {6:.4f}, Prec@100 (binary) = {5:.4f}, mAP@200 (binary) = {8:.4f}, '
            'Prec@200 (binary) = {7:.4f}, Time (binary) = {9:.6f} '.format(
                valid_data['prec@100'], np.mean(valid_data['aps@all']),
                valid_data['prec@200'], np.mean(valid_data['aps@200']),
                valid_data['time_euc'], valid_data['prec@100_bin'],
                np.mean(valid_data['aps@all_bin']),
                valid_data['prec@200_bin'],
                np.mean(valid_data['aps@200_bin']), valid_data['time_bin']))

        print('Saving qualitative results...', end='')
        path_qualitative_results = os.path.join(path_results,
                                                'qualitative_results')
        utils.save_qualitative_results(root_path, sketch_dir, sketch_sd,
                                       photo_dir, photo_sd,
                                       splits['te_fls_sk'],
                                       splits['te_fls_im'],
                                       path_qualitative_results,
                                       valid_data['aps@all'],
                                       valid_data['sim_euc'],
                                       valid_data['str_sim'],
                                       save_image=args.save_image_results,
                                       nq=args.number_qualit_results,
                                       best=args.save_best_results)
        print('Done')
    else:
        print(
            "No best model found at '{}'. Exiting...".format(best_model_file))
        exit()
import torch import torch.optim as optim import torch.nn as nn from torch.nn.functional import interpolate from models.build_model import build_netG from data.customdataset import CustomDataset from models.losses import gdloss from options import Options """Pre-Training Generator""" opt = Options().parse() opt.phase = 'train' opt.nEpochs = 10 opt.save_fre = 10 opt.dataset = 'iseg' print(opt) data_set = CustomDataset(opt) print('Image numbers:', data_set.img_size) dataloader = torch.utils.data.DataLoader(data_set, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers)) generator = build_netG(opt) if opt.gpu_ids != '-1': num_gpus = len(opt.gpu_ids.split(',')) else:
normal_dataloader=load_case(normal=True) abnormal_dataloader=load_case(normal=False) opt = Options() opt.nc=1 opt.nz=50 opt.isize=320 opt.ndf=32 opt.ngf=32 opt.batchsize=64 opt.ngpu=1 opt.istest=True opt.lr=0.001 opt.beta1=0.5 opt.niter=None opt.dataset=None opt.model = None opt.outf=None model=BeatGAN(opt,None,device) model.G.load_state_dict(torch.load('model/beatgan_folder_0_G.pkl',map_location='cpu')) model.D.load_state_dict(torch.load('model/beatgan_folder_0_D.pkl',map_location='cpu')) model.G.eval() model.D.eval() with torch.no_grad(): abnormal_input=[]