import os
import time

import numpy as np
import torch

# Repo-local components used below (TypeModel, TokenModel, Renderer,
# ClassificationDataset, opt, load_image, select_parses, process_for_opt,
# get_topK_parses, load_parses, load_tuned_parses, parse_score_fn,
# model_score_fn, time_string) are assumed to be imported from elsewhere
# in this repository; their import paths are not shown here.


def save_new_parses(parses_j, log_probs_j, save_dir, K_per_img, test_id,
                    reverse=False):
    """
    i : train image index
    k : parse index
    """
    appendix_i = 'test' if reverse else 'train'
    appendix_j = 'train' if reverse else 'test'
    curr = 0
    for train_id, K in K_per_img.items():
        # get savedir paths
        save_dir_i = os.path.join(save_dir, appendix_i + '_%0.2i' % train_id)
        mkdir(save_dir_i)
        save_dir_ij = os.path.join(save_dir_i, appendix_j + '_%0.2i' % test_id)
        mkdir(save_dir_ij)
        # get data subset
        parses_ij = parses_j[curr:curr + K]
        log_probs_ij = log_probs_j[curr:curr + K]
        curr += K
        # save log-probs
        lp_file = os.path.join(save_dir_ij, 'log_probs.pt')
        torch.save(log_probs_ij, lp_file)
        # save parses
        for k in range(K):
            parse = parses_ij[k]
            parse_file = os.path.join(save_dir_ij, 'parse_%i.pt' % k)
            torch.save(parse.state_dict(), parse_file)
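# `mkdir` is a repo-local helper used throughout this file. If it is not
# available, the following minimal sketch matches the assumed behavior:
def mkdir(path):
    # create `path` if it does not already exist; exist_ok makes this a
    # no-op for directories that are already present
    os.makedirs(path, exist_ok=True)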
def save_img_results(save_dir, img_id, parses, log_probs, reverse):
    appendix = 'test' if reverse else 'train'
    save_dir_i = os.path.join(save_dir, appendix + '_%0.2i' % img_id)
    mkdir(save_dir_i)
    # save log_probs
    lp_file = os.path.join(save_dir_i, 'log_probs.pt')
    torch.save(log_probs, lp_file)
    # save parses
    K = len(parses)
    for k in range(K):
        parse = parses[k]
        parse_file = os.path.join(save_dir_i, 'parse_%i.pt' % k)
        torch.save(parse, parse_file)
def main(batch_ix, configs_per=1, trials_per=800, iterations=1500,
         dry_run=False):
    assert 0 <= batch_ix < 5
    idx = np.arange(batch_ix * 10, (batch_ix + 1) * 10)
    target_dir = './targets'
    save_dir = './parses'
    mkdir(save_dir)

    # load images
    images = np.zeros((10, 105, 105), dtype=bool)
    for i, ix in enumerate(idx):
        images[i] = load_image(
            os.path.join(target_dir, 'handwritten%i.png' % (ix + 1)))

    # ------------------
    # Select Parses
    # ------------------

    # load type model
    print('loading model...')
    type_model = TypeModel().eval()
    if torch.cuda.is_available():
        type_model = type_model.cuda()
    score_fn = lambda parses: parse_score_fn(type_model, parses)

    # get base parses
    print('Collecting top-K parses for each target image...')
    base_parses = select_parses(score_fn, images, configs_per, trials_per)
    parse_list, target_imgs, K_per_img = process_for_opt(base_parses, images)

    # --------------------
    # Optimize Parses
    # --------------------

    # load full model
    token_model = TokenModel()
    renderer = Renderer()
    if torch.cuda.is_available():
        torch.backends.cudnn.enabled = False
        token_model = token_model.cuda()
        renderer = renderer.cuda()
    model = opt.FullModel(
        renderer=renderer, type_model=type_model,
        token_model=token_model, denormalize=True
    )

    print('Optimizing top-K parses...')
    # note: this `optimize_parses(model, ...)` helper belongs to the script
    # that defines `main`; it differs from the `optimize_parses(run_id, ...)`
    # defined further below
    parse_list, parse_scores = optimize_parses(
        model, parse_list, target_imgs, iterations)
    if not dry_run:
        save_parses(parse_list, parse_scores, save_dir, K_per_img, idx)
def save_parses(parse_list, log_probs, save_dir, K_per_img, idx):
    curr = 0
    for i, K in K_per_img.items():
        parse_list_img = parse_list[curr:curr + K]
        log_probs_img = log_probs[curr:curr + K]
        curr += K
        save_dir_i = os.path.join(save_dir, 'img_%0.2i' % (idx[i] + 1))
        mkdir(save_dir_i)
        # save log_probs
        lp_file = os.path.join(save_dir_i, 'log_probs.pt')
        torch.save(log_probs_img, lp_file)
        # save parses
        for k in range(K):
            parse = parse_list_img[k]
            parse_file = os.path.join(save_dir_i, 'parse_%i.pt' % k)
            torch.save(parse.state_dict(), parse_file)
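# A possible command-line entry point for `main` above (an illustrative
# sketch; the argument names are assumptions, not taken from the original
# script):
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_ix', type=int, required=True,
                        help='which batch of 10 target images to process (0-4)')
    parser.add_argument('--trials_per', type=int, default=800)
    parser.add_argument('--iterations', type=int, default=1500)
    parser.add_argument('--dry_run', action='store_true')
    args = parser.parse_args()
    main(args.batch_ix, trials_per=args.trials_per,
         iterations=args.iterations, dry_run=args.dry_run)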
def save_new_parses(parse_list, log_probs, save_dir, K_per_img, reverse=False):
    appendix = 'test' if reverse else 'train'
    curr = 0
    for i, K in K_per_img.items():
        parse_list_img = parse_list[curr:curr + K]
        log_probs_img = log_probs[curr:curr + K]
        curr += K
        save_dir_i = os.path.join(save_dir, appendix + '_%0.2i' % i)
        mkdir(save_dir_i)
        # save log_probs
        lp_file = os.path.join(save_dir_i, 'log_probs.pt')
        torch.save(log_probs_img, lp_file)
        # save parses
        for k in range(K):
            parse = parse_list_img[k]
            parse_file = os.path.join(save_dir_i, 'parse_%i.pt' % k)
            torch.save(parse.state_dict(), parse_file)
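# For reference, the save helpers above produce an on-disk layout like the
# following (illustrative; directory names depend on `reverse` and the ids
# passed in):
#
#   <save_dir>/
#       train_00/
#           log_probs.pt                    # scores for image 0's K parses
#           parse_0.pt ... parse_<K-1>.pt   # one state_dict per parse
#       train_01/
#           ...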
def get_base_parses(run_id, trials_per=800, reverse=False, dry_run=False):
    print('run_id: %i' % run_id)
    print('Loading model...')
    type_model = TypeModel().eval()
    if torch.cuda.is_available():
        type_model = type_model.cuda()
    score_fn = lambda parses: model_score_fn(type_model, parses)

    print('Loading classification dataset...')
    dataset = ClassificationDataset(osc_folder='./one-shot-classification')
    run = dataset.runs[run_id]
    imgs = run.test_imgs if reverse else run.train_imgs

    results_dir = './results'
    run_dir = os.path.join(results_dir, 'run%0.2i' % (run_id + 1))
    save_dir = os.path.join(run_dir, 'base_parses')
    if not dry_run:
        mkdir(results_dir)
        mkdir(run_dir)
        mkdir(save_dir)

    print('Collecting top-K parses for each image...')
    nimg = len(imgs)
    for i in range(nimg):
        start_time = time.time()
        parses, log_probs = get_topK_parses(
            imgs[i], k=5, score_fn=score_fn, configs_per=1,
            trials_per=trials_per)
        total_time = time.time() - start_time
        print('image %i/%i took %s' % (i + 1, nimg, time_string(total_time)))
        if dry_run:
            continue
        save_img_results(save_dir, i, parses, log_probs, reverse)
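# Example driver (a sketch): collect base parses for every run in the
# benchmark. `run_all_base_parses` and the run count of 20 are assumptions;
# 20 matches the standard Omniglot one-shot classification challenge.
def run_all_base_parses(trials_per=800, reverse=False):
    for run_id in range(20):
        get_base_parses(run_id, trials_per=trials_per, reverse=reverse)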
def optimize_parses(run_id, iterations=1500, reverse=False, dry_run=False):
    run_dir = './results/run%0.2i' % (run_id + 1)
    load_dir = os.path.join(run_dir, 'base_parses')
    save_dir = os.path.join(run_dir, 'tuned_parses')
    assert os.path.exists(run_dir)
    assert os.path.exists(load_dir)
    if not dry_run:
        mkdir(save_dir)

    print('Loading model and data...')
    type_model = TypeModel().eval()
    token_model = TokenModel()
    renderer = Renderer()
    # move to GPU if available
    if torch.cuda.is_available():
        torch.backends.cudnn.enabled = False
        type_model = type_model.cuda()
        token_model = token_model.cuda()
        renderer = renderer.cuda()
    # build full model
    model = opt.FullModel(
        renderer=renderer, type_model=type_model,
        token_model=token_model, denormalize=True)

    print('Loading data...')
    # load classification dataset and select run
    dataset = ClassificationDataset(osc_folder='./one-shot-classification')
    run = dataset.runs[run_id]
    # load images and base parses for this run
    base_parses, images, K_per_img = load_parses(run, load_dir, reverse)
    assert len(base_parses) == len(images)
    print('total # parses: %i' % len(images))

    print('Optimizing parses...')
    # initialize Parse modules and optimizer
    parse_list = [opt.ParseWithToken(p) for p in base_parses]
    render_params = [p for parse in parse_list for p in parse.render_params]
    stroke_params = [p for parse in parse_list for p in parse.stroke_params]
    param_groups = [
        {'params': render_params, 'lr': 0.306983},
        {'params': stroke_params, 'lr': 0.044114}
    ]
    optimizer = torch.optim.Adam(param_groups)

    # optimize
    start_time = time.time()
    losses, states = opt.optimize_parselist(
        parse_list, images,
        loss_fn=model.losses_fn,
        iterations=iterations,
        optimizer=optimizer,
        tune_blur=True,
        tune_fn=model.likelihood_losses_fn
    )
    total_time = time.time() - start_time
    time.sleep(0.5)
    print('Took %s' % time_string(total_time))
    if dry_run:
        return
    parse_scores = -losses[-1]
    save_new_parses(parse_list, parse_scores, save_dir, K_per_img, reverse)
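# Convenience sketch (hypothetical helper, not part of the original code):
# read back the per-parse scores that `save_new_parses` wrote for one image.
def load_tuned_scores(run_id, img_id, appendix='train'):
    path = os.path.join('./results', 'run%0.2i' % (run_id + 1),
                        'tuned_parses', '%s_%0.2i' % (appendix, img_id),
                        'log_probs.pt')
    # higher scores correspond to lower optimization losses
    return torch.load(path)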
def refit_parses_single(run_id, test_id, iterations=1500, reverse=False,
                        run=None, dry_run=False):
    run_dir = './results/run%0.2i' % (run_id + 1)
    load_dir = os.path.join(run_dir, 'tuned_parses')
    save_dir = os.path.join(run_dir, 'refitted_parses')
    assert os.path.exists(run_dir)
    assert os.path.exists(load_dir)
    if not dry_run:
        mkdir(save_dir)

    print('Loading model...')
    token_model = TokenModel()
    renderer = Renderer(blur_fsize=21)
    if torch.cuda.is_available():
        token_model = token_model.cuda()
        renderer = renderer.cuda()
    model = opt.FullModel(renderer=renderer, token_model=token_model)

    print('Loading parses...')
    # load classification dataset and select run
    if run is None:
        dataset = ClassificationDataset(osc_folder='./one-shot-classification')
        run = dataset.runs[run_id]
    if reverse:
        ntrain = len(run.test_imgs)
        test_img = torch.from_numpy(run.train_imgs[test_id]).float()
    else:
        ntrain = len(run.train_imgs)
        test_img = torch.from_numpy(run.test_imgs[test_id]).float()
    if torch.cuda.is_available():
        test_img = test_img.cuda()
    # load tuned parses
    parse_list, K_per_img = load_tuned_parses(load_dir, ntrain, reverse)
    images = test_img.expand(len(parse_list), 105, 105)
    print('total # parses: %i' % len(images))

    print('Optimizing parses...')
    # initialize Parse modules and optimizer
    render_params = [p for parse in parse_list for p in parse.render_params]
    stroke_params = [
        p for parse in parse_list for p in parse.stroke_params
        if p.requires_grad
    ]
    param_groups = [
        {'params': render_params, 'lr': 0.087992},
        {'params': stroke_params, 'lr': 0.166810}
    ]
    optimizer = torch.optim.Adam(param_groups)

    # optimize
    start_time = time.time()
    losses, states = opt.optimize_parselist(
        parse_list, images,
        loss_fn=model.losses_fn,
        iterations=iterations,
        optimizer=optimizer,
        tune_blur=True,
        tune_fn=model.likelihood_losses_fn
    )
    total_time = time.time() - start_time
    time.sleep(0.5)
    print('Took %s' % time_string(total_time))
    if dry_run:
        return
    parse_scores = -losses[-1]
    save_new_parses(parse_list, parse_scores, save_dir, K_per_img, test_id,
                    reverse)
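# Example driver (a sketch): refit one run's tuned parses to each of that
# run's test images. `refit_run` is a hypothetical name; the dataset is
# loaded once and the `run` object passed through to avoid repeated loads.
def refit_run(run_id, iterations=1500, reverse=False, dry_run=False):
    dataset = ClassificationDataset(osc_folder='./one-shot-classification')
    run = dataset.runs[run_id]
    test_imgs = run.train_imgs if reverse else run.test_imgs
    for test_id in range(len(test_imgs)):
        refit_parses_single(run_id, test_id, iterations=iterations,
                            reverse=reverse, run=run, dry_run=dry_run)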