Example #1
def save_new_parses(parses_j,
                    log_probs_j,
                    save_dir,
                    K_per_img,
                    test_id,
                    reverse=False):
    """
    i : train image index
    k : parse index
    """
    appendix_i = 'test' if reverse else 'train'
    appendix_j = 'train' if reverse else 'test'
    curr = 0
    for train_id, K in K_per_img.items():
        # get savedir paths
        save_dir_i = os.path.join(save_dir, appendix_i + '_%0.2i' % train_id)
        mkdir(save_dir_i)
        save_dir_ij = os.path.join(save_dir_i, appendix_j + '_%0.2i' % test_id)
        mkdir(save_dir_ij)
        # get data subset
        parses_ij = parses_j[curr:curr + K]
        log_probs_ij = log_probs_j[curr:curr + K]
        curr += K
        # save log-probs
        lp_file = os.path.join(save_dir_ij, 'log_probs.pt')
        torch.save(log_probs_ij, lp_file)
        # save parses
        for k in range(K):
            parse = parses_ij[k]
            parse_file = os.path.join(save_dir_ij, 'parse_%i.pt' % k)
            torch.save(parse.state_dict(), parse_file)
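
Note: all eight examples in this listing come from the same codebase and omit their imports; they also rely on a small mkdir helper that is not shown. A minimal sketch of the assumed common setup (the mkdir body is a guess consistent with how it is called; project-specific names such as TypeModel, TokenModel, Renderer, and opt are left to the source package):

import os
import time

import numpy as np
import torch


def mkdir(path):
    # create the directory if it does not already exist; repeated
    # calls for the same path must be safe, since the save helpers
    # call it once per image
    os.makedirs(path, exist_ok=True)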
Example #2
def save_img_results(save_dir, img_id, parses, log_probs, reverse):
    appendix = 'test' if reverse else 'train'
    save_dir_i = os.path.join(save_dir, appendix + '_%0.2i' % img_id)
    mkdir(save_dir_i)
    # save log_probs
    lp_file = os.path.join(save_dir_i, 'log_probs.pt')
    torch.save(log_probs, lp_file)
    # save parses
    K = len(parses)
    for k in range(K):
        parse = parses[k]
        parse_file = os.path.join(save_dir_i, 'parse_%i.pt' % k)
        torch.save(parse, parse_file)
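
A natural counterpart for reading these results back is sketched below. load_img_results is a hypothetical helper, not part of the source; it mirrors the file layout written by save_img_results and assumes log_probs holds one entry per saved parse:

import os
import torch


def load_img_results(save_dir, img_id, reverse=False):
    # mirror the directory naming used by save_img_results
    appendix = 'test' if reverse else 'train'
    save_dir_i = os.path.join(save_dir, appendix + '_%0.2i' % img_id)
    log_probs = torch.load(os.path.join(save_dir_i, 'log_probs.pt'))
    # parses were saved whole (not as state dicts), so torch.load
    # returns usable parse objects directly
    parses = [
        torch.load(os.path.join(save_dir_i, 'parse_%i.pt' % k))
        for k in range(len(log_probs))
    ]
    return parses, log_probs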
Example #3
def main(batch_ix, configs_per=1, trials_per=800, iterations=1500, dry_run=False):
    assert 0 <= batch_ix < 5
    idx = np.arange(batch_ix*10, (batch_ix+1)*10)
    target_dir = './targets'
    save_dir = './parses'
    mkdir(save_dir)

    # load images
    images = np.zeros((10, 105, 105), dtype=bool)
    for i, ix in enumerate(idx):
        images[i] = load_image(os.path.join(target_dir, 'handwritten%i.png' % (ix+1)))

    # ------------------
    #   Select Parses
    # ------------------

    # load type model
    print('loading model...')
    type_model = TypeModel().eval()
    if torch.cuda.is_available():
        type_model = type_model.cuda()
    score_fn = lambda parses: parse_score_fn(type_model, parses)

    # get base parses
    print('Collecting top-K parses for each target image...')
    base_parses = select_parses(score_fn, images, configs_per, trials_per)
    parse_list, target_imgs, K_per_img = process_for_opt(base_parses, images)

    # --------------------
    #   Optimize Parses
    # --------------------

    # load full model
    token_model = TokenModel()
    renderer = Renderer()
    if torch.cuda.is_available():
        torch.backends.cudnn.enabled = False
        token_model = token_model.cuda()
        renderer = renderer.cuda()
    model = opt.FullModel(
        renderer=renderer, type_model=type_model, token_model=token_model,
        denormalize=True
    )

    print('Optimizing top-K parses...')
    parse_list, parse_scores = optimize_parses(
        model, parse_list, target_imgs, iterations)

    if not dry_run:
        save_parses(parse_list, parse_scores, save_dir, K_per_img, idx)
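
If main is run as a script, a thin CLI wrapper like the following would fit. The flag names are assumptions that simply mirror main()'s signature, not part of the source:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_ix', type=int, required=True,
                        help='which batch of 10 target images to process (0-4)')
    parser.add_argument('--configs_per', type=int, default=1)
    parser.add_argument('--trials_per', type=int, default=800)
    parser.add_argument('--iterations', type=int, default=1500)
    parser.add_argument('--dry_run', action='store_true')
    args = parser.parse_args()
    main(args.batch_ix, configs_per=args.configs_per,
         trials_per=args.trials_per, iterations=args.iterations,
         dry_run=args.dry_run)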
Example #4
def save_parses(parse_list, log_probs, save_dir, K_per_img, idx):
    curr = 0
    for i, K in K_per_img.items():
        parse_list_img = parse_list[curr:curr+K]
        log_probs_img = log_probs[curr:curr+K]
        curr += K
        save_dir_i = os.path.join(save_dir, 'img_%0.2i' % (idx[i]+1))
        mkdir(save_dir_i)
        # save log_probs
        lp_file = os.path.join(save_dir_i, 'log_probs.pt')
        torch.save(log_probs_img, lp_file)
        # save parses
        for k in range(K):
            parse = parse_list_img[k]
            parse_file = os.path.join(save_dir_i, 'parse_%i.pt' % k)
            torch.save(parse.state_dict(), parse_file)
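
The slicing in save_parses relies on a flat-list contract: parse_list and log_probs are ordered image by image, and K_per_img maps each image index to its parse count (insertion order matters, which Python 3.7+ dicts guarantee). A toy illustration with placeholder values:

# toy values for illustration only; real entries are parse modules
K_per_img = {0: 2, 1: 3}
parse_list = ['p0_0', 'p0_1', 'p1_0', 'p1_1', 'p1_2']
curr = 0
for i, K in K_per_img.items():
    print(i, parse_list[curr:curr + K])
    curr += K
# prints: 0 ['p0_0', 'p0_1']
#         1 ['p1_0', 'p1_1', 'p1_2']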
Example #5
def save_new_parses(parse_list, log_probs, save_dir, K_per_img, reverse=False):
    appendix = 'test' if reverse else 'train'
    curr = 0
    for i, K in K_per_img.items():
        parse_list_img = parse_list[curr:curr+K]
        log_probs_img = log_probs[curr:curr+K]
        curr += K
        save_dir_i = os.path.join(save_dir, appendix + '_%0.2i' % i)
        mkdir(save_dir_i)
        # save log_probs
        lp_file = os.path.join(save_dir_i, 'log_probs.pt')
        torch.save(log_probs_img, lp_file)
        # save parses
        for k in range(K):
            parse = parse_list_img[k]
            parse_file = os.path.join(save_dir_i, 'parse_%i.pt' % k)
            torch.save(parse.state_dict(), parse_file)
Example #6
def get_base_parses(run_id, trials_per=800, reverse=False, dry_run=False):
    print('run_id: %i' % run_id)
    print('Loading model...')
    type_model = TypeModel().eval()
    if torch.cuda.is_available():
        type_model = type_model.cuda()
    score_fn = lambda parses: model_score_fn(type_model, parses)

    print('Loading classification dataset...')
    dataset = ClassificationDataset(osc_folder='./one-shot-classification')
    run = dataset.runs[run_id]
    if reverse:
        imgs = run.test_imgs
    else:
        imgs = run.train_imgs
    results_dir = './results'
    run_dir = os.path.join(results_dir, 'run%0.2i' % (run_id+1))
    save_dir = os.path.join(run_dir, 'base_parses')
    if not dry_run:
        mkdir(results_dir)
        mkdir(run_dir)
        mkdir(save_dir)

    print('Collecting top-K parses for each image...')
    nimg = len(imgs)
    for i in range(nimg):
        start_time = time.time()
        parses, log_probs = get_topK_parses(
            imgs[i], k=5, score_fn=score_fn, configs_per=1,
            trials_per=trials_per)
        total_time = time.time() - start_time
        print('image %i/%i took %s' % (i+1, nimg, time_string(total_time)))
        if dry_run:
            continue
        save_img_results(save_dir, i, parses, log_probs, reverse)
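
get_base_parses also calls an undefined time_string helper when logging progress. A plausible minimal implementation, offered as an assumption rather than the source's code:

def time_string(seconds):
    # format an elapsed time in seconds as H:MM:SS
    m, s = divmod(int(seconds), 60)
    h, m = divmod(m, 60)
    return '%i:%02i:%02i' % (h, m, s)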
Example #7
def optimize_parses(run_id, iterations=1500, reverse=False, dry_run=False):
    run_dir = './results/run%0.2i' % (run_id+1)
    load_dir = os.path.join(run_dir, 'base_parses')
    save_dir = os.path.join(run_dir, 'tuned_parses')
    assert os.path.exists(run_dir)
    assert os.path.exists(load_dir)
    if not dry_run:
        mkdir(save_dir)

    print('Loading model and data...')
    type_model = TypeModel().eval()
    token_model = TokenModel()
    renderer = Renderer()
    # move to GPU if available
    if torch.cuda.is_available():
        torch.backends.cudnn.enabled = False
        type_model = type_model.cuda()
        token_model = token_model.cuda()
        renderer = renderer.cuda()

    # build full model
    model = opt.FullModel(
        renderer=renderer, type_model=type_model, token_model=token_model,
        denormalize=True)

    print('Loading data...')
    # load classification dataset and select run
    dataset = ClassificationDataset(osc_folder='./one-shot-classification')
    run = dataset.runs[run_id]

    # load images and base parses for this run
    base_parses, images, K_per_img = load_parses(run, load_dir, reverse)
    assert len(base_parses) == len(images)
    print('total # parses: %i' % len(images))

    print('Optimizing parses...')
    # initialize Parse modules and optimizer
    parse_list = [opt.ParseWithToken(p) for p in base_parses]
    render_params = [p for parse in parse_list for p in parse.render_params]
    stroke_params = [p for parse in parse_list for p in parse.stroke_params]
    param_groups = [
        {'params': render_params, 'lr': 0.306983},
        {'params': stroke_params, 'lr': 0.044114}
    ]
    optimizer = torch.optim.Adam(param_groups)

    # optimize
    start_time = time.time()
    losses, states = opt.optimize_parselist(
        parse_list, images,
        loss_fn=model.losses_fn,
        iterations=iterations,
        optimizer=optimizer,
        tune_blur=True,
        tune_fn=model.likelihood_losses_fn
    )
    total_time = time.time() - start_time
    time.sleep(0.5)
    print('Took %s' % time_string(total_time))
    if dry_run:
        return

    parse_scores = -losses[-1]
    save_new_parses(parse_list, parse_scores, save_dir, K_per_img, reverse)
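
Note the score convention: optimize_parselist returns per-iteration losses, and each parse's score is its negated final loss, so higher scores mean better fits. A toy illustration, assuming losses holds one loss value per parse at each iteration:

import torch

losses = [torch.tensor([2.0, 1.5]), torch.tensor([1.0, 0.8])]  # toy values
parse_scores = -losses[-1]
print(parse_scores)  # tensor([-1.0000, -0.8000]); the second parse fits best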
Example #8
def refit_parses_single(run_id,
                        test_id,
                        iterations=1500,
                        reverse=False,
                        run=None,
                        dry_run=False):
    run_dir = './results/run%0.2i' % (run_id + 1)
    load_dir = os.path.join(run_dir, 'tuned_parses')
    save_dir = os.path.join(run_dir, 'refitted_parses')
    assert os.path.exists(run_dir)
    assert os.path.exists(load_dir)
    if not dry_run:
        mkdir(save_dir)

    print('Loading model...')
    token_model = TokenModel()
    renderer = Renderer(blur_fsize=21)
    if torch.cuda.is_available():
        token_model = token_model.cuda()
        renderer = renderer.cuda()
    model = opt.FullModel(renderer=renderer, token_model=token_model)

    print('Loading parses...')
    # load classification dataset and select run
    if run is None:
        dataset = ClassificationDataset(osc_folder='./one-shot-classification')
        run = dataset.runs[run_id]
    if reverse:
        ntrain = len(run.test_imgs)
        test_img = torch.from_numpy(run.train_imgs[test_id]).float()
    else:
        ntrain = len(run.train_imgs)
        test_img = torch.from_numpy(run.test_imgs[test_id]).float()
    if torch.cuda.is_available():
        test_img = test_img.cuda()

    # load tuned parses
    parse_list, K_per_img = load_tuned_parses(load_dir, ntrain, reverse)
    images = test_img.expand(len(parse_list), 105, 105)
    print('total # parses: %i' % len(images))

    print('Optimizing parses...')
    # initialize Parse modules and optimizer
    render_params = [p for parse in parse_list for p in parse.render_params]
    stroke_params = [
        p for parse in parse_list for p in parse.stroke_params
        if p.requires_grad
    ]
    param_groups = [{
        'params': render_params,
        'lr': 0.087992
    }, {
        'params': stroke_params,
        'lr': 0.166810
    }]
    optimizer = torch.optim.Adam(param_groups)

    # optimize
    start_time = time.time()
    losses, states = opt.optimize_parselist(parse_list,
                                            images,
                                            loss_fn=model.losses_fn,
                                            iterations=iterations,
                                            optimizer=optimizer,
                                            tune_blur=True,
                                            tune_fn=model.likelihood_losses_fn)
    total_time = time.time() - start_time
    time.sleep(0.5)
    print('Took %s' % time_string(total_time))
    if dry_run:
        return

    parse_scores = -losses[-1]
    save_new_parses(parse_list, parse_scores, save_dir, K_per_img, test_id,
                    reverse)
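
A full run would call refit_parses_single once per test image. The sketch below assumes the standard 20-way one-shot classification setup (20 test images per run), which is a guess from context rather than something stated in the source:

# hypothetical driver; the count of 20 test images is an assumption
for test_id in range(20):
    refit_parses_single(run_id=0, test_id=test_id, iterations=1500)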