def main(batch_ix,
         configs_per=1,
         trials_per=800,
         iterations=1500,
         dry_run=False):
    """Select and optimize top-K parses for one batch of 10 target images.

    Parameters
    ----------
    batch_ix : int
        Index of the batch to process, in [0, 5); each batch covers 10 of
        the 50 'handwritten*.png' targets.
    configs_per : int
        Configurations per image forwarded to `select_parses`.
    trials_per : int
        Trials per image forwarded to `select_parses`.
    iterations : int
        Optimization iterations for `optimize_parses`.
    dry_run : bool
        When True, skip writing results to disk.
    """
    assert 0 <= batch_ix < 5
    idx = np.arange(batch_ix * 10, (batch_ix + 1) * 10)
    target_dir = './targets'
    save_dir = './parses'  # fix: single-arg os.path.join was a no-op
    mkdir(save_dir)

    # load binary target images for this batch (files are 1-indexed)
    # fix: size the array from idx instead of repeating the constant 10
    images = np.zeros((len(idx), 105, 105), dtype=bool)
    for i, ix in enumerate(idx):
        images[i] = load_image(
            os.path.join(target_dir, 'handwritten%i.png' % (ix + 1)))

    # ------------------
    #   Select Parses
    # ------------------

    # load type model (frozen, eval mode) to score candidate parses
    print('loading model...')
    type_model = TypeModel().eval()
    if torch.cuda.is_available():
        type_model = type_model.cuda()

    # fix: def instead of lambda assignment (PEP8 E731)
    def score_fn(parses):
        return parse_score_fn(type_model, parses)

    # get base parses
    print('Collecting top-K parses for each train image...')
    base_parses = select_parses(score_fn, images, configs_per, trials_per)
    parse_list, target_imgs, K_per_img = process_for_opt(base_parses, images)

    # --------------------
    #   Optimize Parses
    # --------------------

    # load full model (type + token + renderer) for gradient-based tuning
    token_model = TokenModel()
    renderer = Renderer()
    if torch.cuda.is_available():
        # NOTE(review): cudnn disabled on GPU, presumably for compatibility
        # with the optimization loop — confirm before changing
        torch.backends.cudnn.enabled = False
        token_model = token_model.cuda()
        renderer = renderer.cuda()
    model = opt.FullModel(renderer=renderer,
                          type_model=type_model,
                          token_model=token_model,
                          denormalize=True)

    print('Optimizing top-K parses...')
    parse_list, parse_scores = optimize_parses(model, parse_list, target_imgs,
                                               iterations)

    if not dry_run:
        save_parses(parse_list, parse_scores, save_dir, K_per_img, idx)
# --- Example #2 ---
    def __init__(self, save_dir=None, downsize=True):
        """Load the frozen sub-models and register normalization buffers.

        Parameters
        ----------
        save_dir : str or None
            Directory holding saved model weights; defaults to
            MODEL_SAVE_PATH when None.
        downsize : bool
            Forwarded to the model loaders.
        """
        super().__init__()

        save_dir = MODEL_SAVE_PATH if save_dir is None else save_dir
        assert os.path.exists(save_dir)

        # sub-models are frozen: no gradients flow into them
        loaders = (('loc', load_location_model),
                   ('stk', load_stroke_model),
                   ('term', load_terminate_model))
        for attr, loader in loaders:
            setattr(self, attr, loader(save_dir, downsize).requires_grad_(False))

        self.renderer = Renderer(blur_sigma=0.5)

        # spatial normalization constants (registered so .cuda()/.to() move them)
        self.register_buffer('space_mean', torch.tensor([50., -50.]))
        self.register_buffer('space_scale', torch.tensor([20., 20.]))
# --- Example #3 ---
def main():
    """Sample new exemplars for each of the 50 targets and show a 7x7 grid."""
    print('Loading target images...')
    n_imgs = 50
    images = np.zeros((n_imgs, 105, 105), dtype=bool)
    for k in range(n_imgs):
        images[k] = load_image('./targets/handwritten%i.png' % (k + 1))

    print('Loading model parses...')
    fits = ImageFits(images, np.arange(n_imgs))

    print('Generating new exemplars...')
    renderer = Renderer()
    token_model = TokenModel()
    torch.manual_seed(4)  # fixed seed for reproducible samples

    # one subgrid per target image, laid out 7x7
    nrow, ncol = 7, 7
    size = 2
    fig = plt.figure(figsize=(size * ncol * 0.75, size * nrow))
    outer = gridspec.GridSpec(nrow, ncol)
    for cell in range(nrow * ncol):
        drawn = sample(token_model, fits, cell, T=8)
        show_subgrid(renderer,
                     fig,
                     outer[cell],
                     img=fits.train_imgs[cell],
                     samples=drawn)
    plt.show()
# --- Example #4 ---
def optimize_parses(run_id, iterations=1500, reverse=False, dry_run=False):
    """Fine-tune the base parses for one classification run with Adam.

    Parameters
    ----------
    run_id : int
        Index of the classification run (directories are 1-indexed on disk).
    iterations : int
        Number of optimization iterations.
    reverse : bool
        Forwarded to `load_parses`; swaps the role of train/test images.
    dry_run : bool
        When True, optimize but do not save results.
    """
    run_dir = './results/run%0.2i' % (run_id+1)
    load_dir = os.path.join(run_dir, 'base_parses')
    save_dir = os.path.join(run_dir, 'tuned_parses')
    for needed in (run_dir, load_dir):
        assert os.path.exists(needed)
    if not dry_run:
        mkdir(save_dir)

    print('Loading model and data...')
    type_model = TypeModel().eval()
    token_model = TokenModel()
    renderer = Renderer()
    if torch.cuda.is_available():
        # NOTE(review): cudnn is disabled on GPU — presumably required by
        # the optimization loop; confirm before removing
        torch.backends.cudnn.enabled = False
        type_model = type_model.cuda()
        token_model = token_model.cuda()
        renderer = renderer.cuda()

    # assemble the full scoring model
    model = opt.FullModel(renderer=renderer,
                          type_model=type_model,
                          token_model=token_model,
                          denormalize=True)

    print('Loading data...')
    # select the requested classification run
    dataset = ClassificationDataset(osc_folder='./one-shot-classification')
    run = dataset.runs[run_id]

    # load images and the base parses produced earlier for this run
    base_parses, images, K_per_img = load_parses(run, load_dir, reverse)
    assert len(base_parses) == len(images)
    print('total # parses: %i' % len(images))

    print('Optimizing parses...')
    # wrap each parse in a trainable module; separate learning rates for
    # rendering vs. stroke parameters
    parse_list = [opt.ParseWithToken(p) for p in base_parses]
    render_params = [q for parse in parse_list for q in parse.render_params]
    stroke_params = [q for parse in parse_list for q in parse.stroke_params]
    optimizer = torch.optim.Adam([
        {'params': render_params, 'lr': 0.306983},
        {'params': stroke_params, 'lr': 0.044114},
    ])

    # run the optimization and time it
    start_time = time.time()
    losses, states = opt.optimize_parselist(
        parse_list, images,
        loss_fn=model.losses_fn,
        iterations=iterations,
        optimizer=optimizer,
        tune_blur=True,
        tune_fn=model.likelihood_losses_fn
    )
    total_time = time.time() - start_time
    time.sleep(0.5)
    print('Took %s' % time_string(total_time))
    if dry_run:
        return

    # final scores are negated final losses
    parse_scores = -losses[-1]
    save_new_parses(parse_list, parse_scores, save_dir, K_per_img, reverse)
def refit_parses_single(run_id, test_id, iterations=1500, reverse=False,
                        run=None, dry_run=False):
    """Re-fit previously tuned parses against a single test image.

    Parameters
    ----------
    run_id : int
        Index of the classification run (directories are 1-indexed on disk).
    test_id : int
        Index of the image (within the run) to refit against.
    iterations : int
        Number of optimization iterations.
    reverse : bool
        Swap the roles of train and test images for this run.
    run : object or None
        Pre-loaded run; when None it is loaded from ClassificationDataset.
    dry_run : bool
        When True, optimize but do not save results.
    """
    run_dir = './run%0.2i' % (run_id+1)
    load_dir = os.path.join(run_dir, 'tuned_parses')
    save_dir = os.path.join(run_dir, 'refitted_parses')
    assert os.path.exists(run_dir)
    assert os.path.exists(load_dir)
    if not dry_run:
        mkdir(save_dir)

    print('Loading model...')
    use_cuda = torch.cuda.is_available()
    token_model = TokenModel()
    renderer = Renderer(blur_fsize=21)
    if use_cuda:
        token_model = token_model.cuda()
        renderer = renderer.cuda()
    model = opt.FullModel(renderer=renderer, token_model=token_model)

    print('Loading parses...')
    # load classification dataset and select run (unless given one)
    if run is None:
        dataset = ClassificationDataset(osc_folder='./one-shot-classification')
        run = dataset.runs[run_id]
    # `reverse` swaps which side provides the target image vs. the parses
    img_source = run.train_imgs if reverse else run.test_imgs
    ntrain = len(run.test_imgs) if reverse else len(run.train_imgs)
    test_img = torch.from_numpy(img_source[test_id]).float()
    if use_cuda:
        test_img = test_img.cuda()

    # load tuned parses; broadcast the single target over all of them
    parse_list, K_per_img = load_tuned_parses(load_dir, ntrain, reverse)
    images = test_img.expand(len(parse_list), 105, 105)
    print('total # parses: %i' % len(images))

    print('Optimizing parses...')
    # separate learning rates for rendering vs. (trainable) stroke parameters
    render_params = [q for parse in parse_list for q in parse.render_params]
    stroke_params = [q for parse in parse_list for q in parse.stroke_params
                     if q.requires_grad]
    optimizer = torch.optim.Adam([
        {'params': render_params, 'lr': 0.087992},
        {'params': stroke_params, 'lr': 0.166810},
    ])

    # run the optimization and time it
    start_time = time.time()
    losses, states = opt.optimize_parselist(
        parse_list, images,
        loss_fn=model.losses_fn,
        iterations=iterations,
        optimizer=optimizer,
        tune_blur=True,
        tune_fn=model.likelihood_losses_fn
    )
    elapsed = time.time() - start_time
    time.sleep(0.5)
    print('Took %s' % time_string(elapsed))
    if dry_run:
        return

    # final scores are negated final losses
    parse_scores = -losses[-1]
    save_new_parses(parse_list, parse_scores, save_dir, K_per_img, test_id, reverse)
# --- Example #6 ---
class TypeModel(nn.Module):
    """Type-level prior over character parses.

    Scores a parse (a list of spline tensors, optionally with the
    corresponding stroke trajectories) under three frozen sub-models:
    a stroke-location model (`loc`), a stroke-trajectory model (`stk`),
    and a termination model (`term`). Depends on external helpers
    (load_*_model, get_input_*, splines_to_strokes, find_keepers) not
    visible in this file.
    """
    def __init__(self, save_dir=None, downsize=True):
        """Load frozen sub-models from `save_dir` (default MODEL_SAVE_PATH).

        Parameters
        ----------
        save_dir : str or None
            Directory containing saved model weights.
        downsize : bool
            Forwarded to the model loaders.
        """
        super().__init__()

        if save_dir is None:
            save_dir = MODEL_SAVE_PATH
        assert os.path.exists(save_dir)

        # sub-models are frozen; no gradients flow into them
        self.loc = load_location_model(save_dir, downsize).requires_grad_(False)
        self.stk = load_stroke_model(save_dir, downsize).requires_grad_(False)
        self.term = load_terminate_model(save_dir, downsize).requires_grad_(False)

        self.renderer = Renderer(blur_sigma=0.5)

        # spatial normalization constants (buffers so .to()/.cuda() move them)
        self.register_buffer('space_mean', torch.tensor([50., -50.]))
        self.register_buffer('space_scale', torch.tensor([20., 20.]))

    def cuda(self, device=None):
        """Move the model to GPU, but keep the renderer's painter on CPU.

        NOTE(review): the painter is deliberately forced back to CPU after
        the module-wide .cuda() — presumably it is CPU-only; confirm before
        changing.
        """
        super().cuda(device)
        self.renderer.painter = self.renderer.painter.cpu()
        return self

    def normalize(self, x, inverse=False):
        """Map coordinates to normalized space (or back when inverse=True)."""
        if inverse:
            return x*self.space_scale + self.space_mean
        else:
            return (x - self.space_mean)/self.space_scale

    def loss_fn(self, splines, drawing=None, filter_small=False,
                denormalize=False):
        """Return the total (scalar) type-model loss for a single parse.

        Parameters
        ----------
        splines : list of tensors
            One spline tensor per stroke.
        drawing : list of tensors or None
            Stroke trajectories; derived from `splines` when None.
        filter_small : bool
            Drop strokes judged too small (via `find_keepers`) first.
        denormalize : bool
            Pass normalization constants into the sub-model losses so they
            are computed in original coordinates.
        """
        if drawing is None: # get strokes
            drawing = splines_to_strokes(splines)
        if filter_small: # filter small strokes
            keep_ix = find_keepers(splines)
            splines = [splines[i] for i in keep_ix]
            drawing = [drawing[i] for i in keep_ix]
        ns = len(splines)
        device = splines[0].device
        pad_val = self.stk.pad_val

        # compute partial canvas renders
        canvases = self.renderer.forward_partial(drawing) # (ns+1,105,105)
        canvases = canvases.unsqueeze(1)

        # normalize spatial coordinates
        splines = [self.normalize(x) for x in splines]
        drawing = [self.normalize(x) for x in drawing]

        # collect model inputs
        prevs = get_input_prev(drawing, device) # (ns+1, 2)
        locs = get_input_loc(drawing, device) # (ns+1, 2)
        trajs = get_input_traj(splines, pad_val=pad_val) # (ns+1, T, 2)
        trajs = rnn.pad_sequence(trajs, batch_first=True, padding_value=pad_val)
        terms = get_input_term(ns, device) # (ns+1,)

        # compute losses; loc/stk use only the first ns steps ([:-1]) while
        # the termination loss covers all ns+1 steps
        if denormalize:
            losses_loc = self.loc.losses_fn(
                x_canv=canvases[:-1], prev=prevs[:-1], start=locs[:-1],
                mean_X=self.space_mean, std_X=self.space_scale)
            losses_stk = self.stk.losses_fn(
                x_canv=canvases[:-1], start=locs[:-1], x=trajs[:-1],
                std_X=self.space_scale)
        else:
            losses_loc = self.loc.losses_fn(x_canv=canvases[:-1], prev=prevs[:-1], start=locs[:-1])
            losses_stk = self.stk.losses_fn(x_canv=canvases[:-1], start=locs[:-1], x=trajs[:-1])
        losses_term = self.term.losses_fn(x=canvases, y=terms)
        loss = losses_loc.sum() + losses_stk.sum() + losses_term.sum()
        return loss

    def _losses_fn(self, splines_list, drawing_list=None, filter_small=False,
                   denormalize=False):
        """Batched version of `loss_fn`: one loss per parse in `splines_list`.

        All parses are concatenated into a single batch for the sub-models,
        then split back per-parse via `torch.split`. Returns a (nsamp,)
        tensor of losses.
        """
        assert isinstance(splines_list[0], list)
        if drawing_list is None:
            drawing_list = [splines_to_strokes(splines) for splines in splines_list]
        if filter_small: # filter small strokes
            keep_ix_list = [find_keepers(splines) for splines in splines_list]
            splines_list = [[splines[i] for i in keep_ix] for
                            splines, keep_ix in zip(splines_list, keep_ix_list)]
            drawing_list = [[drawing[i] for i in keep_ix] for
                            drawing, keep_ix in zip(drawing_list, keep_ix_list)]
        nsamp = len(splines_list)
        ns_list = [len(splines) for splines in splines_list]
        device = splines_list[0][0].device
        pad_val = self.stk.pad_val

        # compute partial canvas renders; returns list of (ns+1,105,105)
        canvases = self.renderer.forward_partial(drawing_list, concat=True)
        canvases = canvases.unsqueeze(1) # (ntot, 1, 105, 105)

        # normalize spatial coordinates
        splines_list = [[self.normalize(x) for x in splines] for splines in splines_list]
        drawing_list = [[self.normalize(x) for x in drawing] for drawing in drawing_list]

        # collect model inputs (concatenated across all parses)
        prevs = torch.cat([get_input_prev(drawing, device) for drawing in drawing_list])
        locs = torch.cat([get_input_loc(drawing, device) for drawing in drawing_list])
        trajs = [elt for splines in splines_list for elt in get_input_traj(splines, pad_val)]
        trajs = rnn.pad_sequence(trajs, batch_first=True, padding_value=pad_val)
        terms = torch.cat([get_input_term(ns, device) for ns in ns_list])

        # compute losses; each parse contributes ns+1 rows to the batch
        sizes = [ns+1 for ns in ns_list]
        if denormalize:
            losses_loc = self.loc.losses_fn(
                x_canv=canvases, prev=prevs, start=locs,
                mean_X=self.space_mean, std_X=self.space_scale)
            losses_stk = self.stk.losses_fn(
                x_canv=canvases, start=locs, x=trajs,
                std_X=self.space_scale
            )
        else:
            losses_loc = self.loc.losses_fn(x_canv=canvases, prev=prevs, start=locs)
            losses_stk = self.stk.losses_fn(x_canv=canvases, start=locs, x=trajs)
        losses_term = self.term.losses_fn(x=canvases, y=terms)
        # split the flat per-step losses back into per-parse chunks; as in
        # loss_fn, loc/stk drop the final step ([:-1]) while term keeps all
        losses_loc = torch.split(losses_loc, sizes, 0)
        losses_stk = torch.split(losses_stk, sizes, 0)
        losses_term = torch.split(losses_term, sizes, 0)
        losses = torch.zeros(nsamp, device=device)
        for i, (ll,ls,lt) in enumerate(zip(losses_loc, losses_stk, losses_term)):
            losses[i] = ll[:-1].sum() + ls[:-1].sum() + lt.sum()
        return losses

    def losses_fn(self, splines_list, drawing_list=None, filter_small=False,
                  denormalize=False, max_size=400):
        """Public batched loss: chunks the input to bound peak memory.

        Delegates to `_losses_fn`, processing at most `max_size` parses per
        chunk. Returns a (len(splines_list),) tensor of losses.
        """
        n = len(splines_list)
        if n <= max_size:
            return self._losses_fn(
                splines_list, drawing_list, filter_small, denormalize)
        losses = torch.zeros(n, device=splines_list[0][0].device)
        nbatch = math.ceil(n/max_size)
        for i in range(nbatch):
            start = i*max_size
            end = (i+1)*max_size
            splines_batch = splines_list[start:end]
            drawings_batch = None if drawing_list is None else drawing_list[start:end]
            losses[start:end] = self._losses_fn(
                splines_batch, drawings_batch, filter_small, denormalize)
        return losses