Example #1
def process_for_opt(base_parses, images):
    # Flatten the per-image parse lists into one list and stack a copy of
    # each target image per parse, so every candidate parse can be optimized
    # against its own image in a single batch.
    nimg = len(images)
    parse_list = []
    target_imgs = []
    K_per_img = {}
    for i in range(nimg):
        img = torch.from_numpy(images[i]).float()
        if torch.cuda.is_available():
            img = img.cuda()
        K = len(base_parses[i])
        parse_list.extend(base_parses[i])
        # replicate the 105x105 image once per candidate parse
        target_imgs.append(img[None].expand(K, 105, 105))
        K_per_img[i] = K
    # wrap each raw parse in a trainable ParseWithToken module
    parse_list = [opt.ParseWithToken(p) for p in parse_list]
    target_imgs = torch.cat(target_imgs)
    return parse_list, target_imgs, K_per_img
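For reference, a minimal usage sketch (not from the original source). It assumes `base_parses` holds per-image lists of raw parse initializations, produced here by a hypothetical `make_base_parses` helper, and that `torch` and `opt` are imported as in the snippet above:

import numpy as np

# two 105x105 images with 3 and 2 candidate parses each; make_base_parses is
# a hypothetical stand-in for whatever proposes the raw parses
images = [np.random.rand(105, 105).astype(np.float32) for _ in range(2)]
base_parses = [make_base_parses(images[0], k=3),
               make_base_parses(images[1], k=2)]

parse_list, target_imgs, K_per_img = process_for_opt(base_parses, images)
assert target_imgs.shape == (5, 105, 105)  # one stacked target per parse
assert K_per_img == {0: 3, 1: 2}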
def load_tuned_parses(load_dir, ntrain, reverse=False):
    # Reload the tuned parses saved for each training image. With
    # reverse=True the train/test roles are swapped and the 'test'
    # subdirectories are read instead.
    appendix = 'test' if reverse else 'train'
    tuned_parses = []
    K_per_img = {}
    for img_id in range(ntrain):
        savedir = os.path.join(load_dir, appendix + '_%0.2i' % img_id)
        parse_files = [f for f in os.listdir(savedir) if f.startswith('parse')]
        K_per_img[img_id] = len(parse_files)
        for f in sorted(parse_files):
            state_dict = torch.load(os.path.join(savedir, f), map_location='cpu')
            # entries whose keys start with 'x' hold the per-stroke
            # parameters; use them to initialize the parse before restoring
            # the full state
            init_parse = [val for key, val in state_dict.items()
                          if key.startswith('x')]
            parse = opt.ParseWithToken(init_parse)
            parse.load_state_dict(state_dict)
            config_for_refit(parse)
            tuned_parses.append(parse)

    return tuned_parses, K_per_img
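A hedged usage sketch, assuming the directory layout implied by the loop above (one train_%0.2i or test_%0.2i subdirectory per image, each holding parse* state-dict files); the run directory below is illustrative:

tuned, K_per_img = load_tuned_parses('./results/run01/tuned_parses', ntrain=10)
print('loaded %i parses across %i images' % (len(tuned), len(K_per_img)))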
Example #3
def load_parses(img_ids):
    parses = {}
    log_probs = {}
    for img_id in img_ids:
        savedir = os.path.join('./parses/img_%0.2i' % (img_id + 1))
        # parses: one 'parse*' file per candidate, restored in sorted order
        parse_files = [f for f in os.listdir(savedir) if f.startswith('parse')]
        parses[img_id] = []
        for f in sorted(parse_files):
            state_dict = torch.load(os.path.join(savedir, f),
                                    map_location='cpu')
            # keys starting with 'x' hold the per-stroke parameters used to
            # initialize the parse before restoring the full state
            init_parse = [
                val for key, val in state_dict.items() if key.startswith('x')
            ]
            parse = opt.ParseWithToken(init_parse)
            parse.load_state_dict(state_dict)
            parses[img_id].append(parse)
        # log probs: one tensor of parse scores per image
        log_probs[img_id] = torch.load(os.path.join(savedir, 'log_probs.pt'))

    return parses, log_probs
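A short usage sketch (illustrative, not from the source): load the stored parses for the first five images and keep the best-scoring parse per image, assuming log_probs[i][j] scores parses[i][j] in the same sorted file order used above:

parses, log_probs = load_parses(range(5))
best = {i: parses[i][int(log_probs[i].argmax())] for i in parses}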
Example #4
def optimize_parses(run_id, iterations=1500, reverse=False, dry_run=False):
    run_dir = './results/run%0.2i' % (run_id+1)
    load_dir = os.path.join(run_dir, 'base_parses')
    save_dir = os.path.join(run_dir, 'tuned_parses')
    assert os.path.exists(run_dir)
    assert os.path.exists(load_dir)
    if not dry_run:
        mkdir(save_dir)

    print('Loading models...')
    type_model = TypeModel().eval()
    token_model = TokenModel()
    renderer = Renderer()
    # move to GPU if available
    if torch.cuda.is_available():
        torch.backends.cudnn.enabled = False
        type_model = type_model.cuda()
        token_model = token_model.cuda()
        renderer = renderer.cuda()

    # build full model
    model = opt.FullModel(
        renderer=renderer, type_model=type_model, token_model=token_model,
        denormalize=True)

    print('Loading data...')
    # load classification dataset and select run
    dataset = ClassificationDataset(osc_folder='./one-shot-classification')
    run = dataset.runs[run_id]

    # load images and base parses for this run
    base_parses, images, K_per_img = load_parses(run, load_dir, reverse)
    assert len(base_parses) == len(images)
    print('total # parses: %i' % len(images))

    print('Optimizing parses...')
    # initialize Parse modules and optimizer
    parse_list = [opt.ParseWithToken(p) for p in base_parses]
    render_params = [p for parse in parse_list for p in parse.render_params]
    stroke_params = [p for parse in parse_list for p in parse.stroke_params]
    param_groups = [
        {'params': render_params, 'lr': 0.306983},
        {'params': stroke_params, 'lr': 0.044114}
    ]
    optimizer = torch.optim.Adam(param_groups)

    # optimize
    start_time = time.time()
    losses, states = opt.optimize_parselist(
        parse_list, images,
        loss_fn=model.losses_fn,
        iterations=iterations,
        optimizer=optimizer,
        tune_blur=True,
        tune_fn=model.likelihood_losses_fn
    )
    total_time = time.time() - start_time
    time.sleep(0.5)
    print('Took %s' % time_string(total_time))
    if dry_run:
        return

    parse_scores = -losses[-1]
    save_new_parses(parse_list, parse_scores, save_dir, K_per_img, reverse)
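A hedged usage sketch: a short dry run to sanity-check timing (nothing is written to disk), followed by the full optimization for the first classification run. The run directory is assumed to already contain base_parses, as the asserts above require:

optimize_parses(run_id=0, iterations=20, dry_run=True)  # timing check only
optimize_parses(run_id=0, iterations=1500)  # full run; saves tuned parses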