def __init__(self, args, data_id, model_id, optim_id,
             train_loader, eval_loader,
             model, optimizer, scheduler_iter, scheduler_epoch):
    """Set up a flow-training experiment.

    Missing entries of ``args`` are filled in place:
    ``eval_every`` / ``check_every`` fall back to ``args.epochs``, ``name``
    falls back to a ``%Y-%m-%d_%H-%M-%S`` timestamp and ``project`` to
    ``"<data_id>_<model_id>"``.  The model is moved to ``args.device`` and,
    when ``args.parallel == 'dp'``, wrapped in ``DataParallelDistribution``.
    The parent experiment is initialised with log directory
    ``<log_base>/<data_id>/<model_id>/<optim_id>/<name>``; folders are
    created, ``args`` is saved, and TensorBoard (``args.log_tb``) and/or
    wandb (``args.log_wandb``) logging is started.
    """
    # Frequencies left unset default to "once per full run".
    for option in ('eval_every', 'check_every'):
        if getattr(args, option) is None:
            setattr(args, option, args.epochs)
    if args.name is None:
        args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
    if args.project is None:
        args.project = '_'.join([data_id, model_id])

    # Device placement first, then the optional data-parallel wrapper.
    model = model.to(args.device)
    if args.parallel == 'dp':
        model = DataParallelDistribution(model)

    run_dir = os.path.join(self.log_base, data_id, model_id, optim_id, args.name)
    super(FlowExperiment, self).__init__(
        model=model,
        optimizer=optimizer,
        scheduler_iter=scheduler_iter,
        scheduler_epoch=scheduler_epoch,
        log_path=run_dir,
        eval_every=args.eval_every,
        check_every=args.check_every,
    )

    # Persist the completed arguments with the run.
    self.create_folders()
    self.save_args(args)
    self.args = args

    # Identifiers and loaders kept for later use by the experiment.
    self.data_id, self.model_id, self.optim_id = data_id, model_id, optim_id
    self.train_loader, self.eval_loader = train_loader, eval_loader

    # args are passed through clean_dict (with self.no_log_keys) before logging.
    loggable = clean_dict(vars(args), keys=self.no_log_keys)
    if args.log_tb:
        self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
        args_html = get_args_table(loggable).get_html_string()
        self.writer.add_text("args", args_html, global_step=0)
    if args.log_wandb:
        wandb.init(config=loggable, project=args.project, id=args.name, dir=self.log_path)
# Restore the arguments the model was trained with.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point path_args at experiment directories you trust.
with open(path_args, 'rb') as f:
    args = pickle.load(f)

##################
## Specify data ##
##################

_, _, data_shape = get_data(args)

###################
## Specify model ##
###################

# Choose the device before loading so the checkpoint can be mapped onto it:
# without map_location, a checkpoint saved on GPU cannot be loaded on a
# CPU-only machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = get_model(args, data_shape=data_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)
checkpoint = torch.load(path_check, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

############
## Sample ##
############

path_samples = '{}/samples/sample_ep{}_s{}.png'.format(
    eval_args.model, checkpoint['current_epoch'], eval_args.seed)
# makedirs(exist_ok=True) replaces the racy exists()/mkdir() pair and also
# creates any missing parent directories.
os.makedirs(os.path.dirname(path_samples), exist_ok=True)
def __init__(self, args, data_id, model_id, optim_id, train_loader, eval_loader, model, optimizer, scheduler_iter, scheduler_epoch):
    """Set up a flow-training experiment.

    Completes missing entries of ``args`` in place, moves ``model`` to
    ``args.device`` (wrapping it in ``DataParallelDistribution`` when
    ``args.parallel == 'dp'``), initialises the parent experiment, stores
    args/IDs/loaders, configures TensorBoard and/or wandb logging, sets up
    automatic mixed precision, and saves the model architecture.

    Log directory:
        ``<log_base>/<data_id>/<model_id>/<optim_id>/seed<seed>/<name>``, or
        ``<log_base>/debug/<data_id>/<model_id>/<optim_id>/seed<seed>/<timestamp>``
        when ``args.name == "debug"``.
    """
    # Edit args: a missing (None) or zero frequency falls back to args.epochs.
    if args.eval_every is None or args.eval_every == 0:
        args.eval_every = args.epochs
    if args.check_every is None or args.check_every == 0:
        args.check_every = args.epochs
    if args.name is None:
        args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
    if args.project is None:
        args.project = '_'.join([data_id, model_id])
    if args.name == "debug":
        # Each debug run is written to its own timestamped directory.
        log_path = os.path.join(self.log_base, "debug", data_id, model_id, optim_id,
                                f"seed{args.seed}", time.strftime("%Y-%m-%d_%H-%M-%S"))
    else:
        log_path = os.path.join(self.log_base, data_id, model_id, optim_id,
                                f"seed{args.seed}", args.name)

    # Move model
    model = model.to(args.device)
    if args.parallel == 'dp':
        model = DataParallelDistribution(model)

    # Init parent
    super(FlowExperiment, self).__init__(model=model,
                                         optimizer=optimizer,
                                         scheduler_iter=scheduler_iter,
                                         scheduler_epoch=scheduler_epoch,
                                         log_path=log_path,
                                         eval_every=args.eval_every,
                                         check_every=args.check_every,
                                         save_samples=args.save_samples)

    # Store args
    self.create_folders()
    self.save_args(args)
    self.args = args

    # Store IDs
    self.data_id = data_id
    self.model_id = model_id
    self.optim_id = optim_id

    # Store data loaders
    self.train_loader = train_loader
    self.eval_loader = eval_loader

    # Init logging (args are passed through clean_dict with self.no_log_keys)
    args_dict = clean_dict(vars(args), keys=self.no_log_keys)
    if args.log_tb:
        self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
        self.writer.add_text("args", get_args_table(args_dict).get_html_string(), global_step=0)
    if args.log_wandb:
        wandb.init(config=args_dict, project=args.project, id=args.name, dir=self.log_path)

    # training params
    self.max_grad_norm = args.max_grad_norm

    # Automatic mixed precision: torch.cuda.amp needs PyTorch >= 1.7, and it is
    # disabled under DataParallel, which would need bigger changes
    # (@autocast() decoration on each forward).
    # The version is compared numerically as (major, minor). The previous check,
    # int(torch.__version__[2]) >= 7, read a single character and so reported
    # 1.10+ and 2.x as older than 1.7, silently turning AMP off.
    import re  # local import: only needed for the version parse below
    version_match = re.match(r"(\d+)\.(\d+)", str(torch.__version__))
    pytorch_170 = (version_match is not None
                   and tuple(int(part) for part in version_match.groups()) >= (1, 7))
    self.amp = args.amp and args.parallel != 'dp' and pytorch_170
    if self.amp:
        self.scaler = torch.cuda.amp.GradScaler()
    else:
        self.scaler = None

    # save model architecture for reference
    self.save_architecture()
################## ## Specify data ## ################## eval_loader, data_shape, cond_shape = get_data(args, eval_only=True) ################### ## Specify model ## ################### device = 'cuda' if torch.cuda.is_available() else 'cpu' model = get_model(args, data_shape=data_shape, cond_shape=cond_shape) if args.parallel == 'dp': model = DataParallelDistribution(model) checkpoint = torch.load(path_check, map_location=torch.device(device)) model.load_state_dict(checkpoint['model']) model = model.to(device) model = model.eval() print('Loaded weights for model at {}/{} epochs'.format( checkpoint['current_epoch'], args.epochs)) ############ ## Sample ## ############ def save_images(imgs, file_path, num_bits=args.num_bits, nrow=eval_args.nrow): if not os.path.exists(os.path.dirname(file_path)): os.mkdir(os.path.dirname(file_path))
##################
## Specify data ##
##################

# Only the evaluation split is requested (eval_only=True).
eval_loader, data_shape, cond_shape = get_data(args, eval_only=True)

###################
## Specify model ##
###################

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = get_model(args, data_shape=data_shape, cond_shape=cond_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)
# map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
checkpoint = torch.load(path_check, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model = model.eval()
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

############
## Sample ##
############

# "likelihoods/" directory under the evaluated model's directory
# (presumably where likelihood results are written later — confirm).
base_dir = os.path.join(f"{eval_args.model}", "likelihoods/")
# makedirs(exist_ok=True) replaces the racy exists()/mkdir() pair and also
# creates any missing parent directories.
os.makedirs(base_dir, exist_ok=True)
# ---------------------------------------------------------------------------
# Data: only the evaluation split is requested.
# ---------------------------------------------------------------------------
eval_loader, data_shape, cond_shape = get_data(args, eval_only=True)

# ---------------------------------------------------------------------------
# Model: rebuild from the saved args and restore the checkpointed weights,
# mapping stored tensors onto the chosen device.
# ---------------------------------------------------------------------------
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = get_model(args, data_shape=data_shape, cond_shape=cond_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)

checkpoint = torch.load(path_check, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device).eval()
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

# ---------------------------------------------------------------------------
# Evaluate: PerceptualQuality helper (defined elsewhere), configured from the
# saved args.
# ---------------------------------------------------------------------------
pq = PerceptualQuality(
    device=device,
    num_bits=args.num_bits,
    sr_scale_factor=args.sr_scale_factor,
)
def __init__(self, args, data_id, model_id, optim_id, model, teacher, optimizer, scheduler_iter, scheduler_epoch):
    """Set up a student/teacher training experiment.

    Completes missing entries of ``args`` in place, derives run identifiers
    (conditioning id, architecture id, seed id) to build the log directory,
    moves ``model`` to ``args.device`` (wrapping it in
    ``DataParallelDistribution`` when ``args.parallel == 'dp'``), initialises
    the parent experiment, moves ``teacher`` to ``args.device`` and puts it in
    eval mode, configures TensorBoard and/or wandb logging, sets up automatic
    mixed precision, and saves the model architecture.

    Log directory:
        ``<log_base>/<model_id>/<data_id>/<cond_id>/<arch_id>/seed<seed>/<name>``,
        or under ``<log_base>/debug/...`` with a timestamp leaf when
        ``args.name == "debug"``.
    """
    # Edit args: a missing (None) or zero eval frequency falls back to args.epochs.
    if args.eval_every is None or args.eval_every == 0:
        args.eval_every = args.epochs
    if args.name is None:
        args.name = time.strftime("%Y-%m-%d_%H-%M-%S")
    if args.project is None:
        args.project = '_'.join([data_id, model_id])

    # Run identifiers used in the log path.
    aug_or_abs = 'abs' if args.augment_size == 0 else f"aug{args.augment_size}"
    hidden = '_'.join(str(u) for u in args.hidden_units)
    cond_id = args.cond_trans.lower()
    arch_id = f"flow_{aug_or_abs}_k{args.num_flows}_h{hidden}_{'affine' if args.affine else 'additive'}{'_actnorm' if args.actnorm else ''}"
    seed_id = f"seed{args.seed}"
    if args.name == "debug":
        # Each debug run is written to its own timestamped directory.
        log_path = os.path.join(self.log_base, "debug", model_id, data_id, cond_id, arch_id, seed_id,
                                time.strftime("%Y-%m-%d_%H-%M-%S"))
    else:
        log_path = os.path.join(self.log_base, model_id, data_id, cond_id, arch_id, seed_id, args.name)

    # Move models
    model = model.to(args.device)
    if args.parallel == 'dp':
        model = DataParallelDistribution(model)

    # Init parent
    super(StudentExperiment, self).__init__(model=model,
                                            optimizer=optimizer,
                                            scheduler_iter=scheduler_iter,
                                            scheduler_epoch=scheduler_epoch,
                                            log_path=log_path,
                                            eval_every=args.eval_every)

    # student teacher args
    teacher = teacher.to(args.device)
    self.teacher = teacher
    self.teacher.eval()
    # 1 for 'split*' / 'multiply*' conditioning transforms, otherwise 2.
    self.cond_size = 1 if cond_id.startswith(('split', 'multiply')) else 2

    # Store args
    self.create_folders()
    self.save_args(args)
    self.args = args

    # Store IDs
    self.model_id = model_id
    self.data_id = data_id
    self.cond_id = cond_id
    self.optim_id = optim_id
    self.arch_id = arch_id
    self.seed_id = seed_id

    # Init logging (args are passed through clean_dict with self.no_log_keys)
    args_dict = clean_dict(vars(args), keys=self.no_log_keys)
    if args.log_tb:
        self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
        self.writer.add_text("args", get_args_table(args_dict).get_html_string(), global_step=0)
    if args.log_wandb:
        wandb.init(config=args_dict, project=args.project, id=args.name, dir=self.log_path)

    # training params
    self.max_grad_norm = args.max_grad_norm

    # Automatic mixed precision: torch.cuda.amp needs PyTorch >= 1.7, and it is
    # disabled under DataParallel, which would need bigger changes
    # (@autocast() decoration on each forward).
    # The version is compared numerically as (major, minor). The previous check,
    # int(torch.__version__[2]) >= 7, read a single character and so reported
    # 1.10+ and 2.x as older than 1.7, silently turning AMP off.
    import re  # local import: only needed for the version parse below
    version_match = re.match(r"(\d+)\.(\d+)", str(torch.__version__))
    pytorch_170 = (version_match is not None
                   and tuple(int(part) for part in version_match.groups()) >= (1, 7))
    self.amp = args.amp and args.parallel != 'dp' and pytorch_170
    if self.amp:
        self.scaler = torch.cuda.amp.GradScaler()
    else:
        self.scaler = None

    # save model architecture for reference
    self.save_architecture()
################## ## Specify data ## ################## eval_loader, data_shape, cond_shape = get_data(args, eval_only=True) #################### ## Specify models ## #################### device = 'cuda' if torch.cuda.is_available() else 'cpu' # conditional model model = get_model(args, data_shape=data_shape, cond_shape=cond_shape) if args.parallel == 'dp': model = DataParallelDistribution(model) checkpoint = torch.load(path_check, map_location=torch.device(device)) model.load_state_dict(checkpoint['model']) model = model.to(device) model = model.eval() print('Loaded weights for conditional model at {}/{} epochs'.format( checkpoint['current_epoch'], args.epochs)) # prior model prior_model = get_model(prior_args, data_shape=(data_shape[0], data_shape[1] // args.sr_scale_factor, data_shape[2] // args.sr_scale_factor)) if prior_args.parallel == 'dp': prior_model = DataParallelDistribution(prior_model) prior_checkpoint = torch.load(path_prior_check,
args = pickle.load(f) ################## ## Specify data ## ################## eval_loader, data_shape, cond_shape = get_data(args, eval_only=True) # Adjust args args.batch_size = eval_args.batch_size ################### ## Specify model ## ################### model = get_model(args, data_shape=data_shape, cond_shape=cond_shape) if args.parallel == 'dp': model = DataParallelDistribution(model) checkpoint = torch.load(path_check) model.load_state_dict(checkpoint['model']) print('Loaded weights for model at {}/{} epochs'.format(checkpoint['current_epoch'], args.epochs)) # Load checkpoint exp.checkpoint_load('{}/check/'.format(more_args.model), device=more_args.new_device) # modify model if more_args.new_device is not None: exp.model.to(torch.device(more_args.new_device)) exp.eval_fn()
args.device = torch.device(device) ################## ## Specify data ## ################## args.batch_size = 1 eval_loader, data_shape, cond_shape = get_data(args, eval_only=True) ################### ## Specify model ## ################### model = get_model(args, data_shape=data_shape, cond_shape=cond_shape) if args.parallel == 'dp': model = DataParallelDistribution(model) checkpoint = torch.load(path_check, map_location=device) model.load_state_dict(checkpoint['model']) model = model.to(device) model = model.eval() print('Loaded weights for model at {}/{} epochs'.format( checkpoint['current_epoch'], args.epochs)) ############ ## Sample ## ############ def save_images(imgs, file_path, num_bits=args.num_bits, nrow=1): if not os.path.exists(os.path.dirname(file_path)):