コード例 #1
0
def semValidGen(pc, prog, encoder, decoder, max_lines, input_dim, device, gt_prog, rejection_sample):
    """Decode one program line by line from a point cloud, keeping only lines that clean_forward accepts.

    Each iteration re-encodes the point cloud together with the partially executed
    program ``P`` and asks ``clean_forward`` for the next line. Invalid lines are
    rejected, and the next iteration retries from the previous input (see
    ``MASK_BAD_OUT`` / ``MAX_REJECT``). Valid lines are interpreted by command
    type and executed on ``P``.

    Args:
        pc: Point-cloud tensor; only ``pc[0, :, :num_samps // 2]`` is used
            (presumably batch-first — TODO confirm layout against the caller).
        prog: Dict that is updated in place. ``"bb_dims"`` is overwritten with
            the bounding-box prediction made on the first iteration, and
            ``"children"`` with the cleaned child list.
        encoder: Passed through to ``reencode_prog``.
        decoder: Supplies ``init_gru`` (initial hidden state) and ``bbdimNet``,
            and is passed to ``clean_forward``.
        max_lines: Upper bound on loop iterations. Rejected (invalid) attempts
            count toward this bound too.
        input_dim: Width of each line vector.
        device: Device on which the start-token tensor is created.
        gt_prog: Unused by this variant; kept for signature compatibility.
        rejection_sample: If true (and ``DO_REJECT`` is set), an uncleanable
            line aborts decoding with ``AssertionError``.

    Returns:
        ``(fc_preds, fc_prog_out)`` as produced by ``cuboid_line_clean``.

    Raises:
        AssertionError: when a line cannot be cleaned while rejection sampling.
    """
    prog_out = []
    preds = []
    children = []

    # Start token: an all-zero line with only the first slot set.
    out = torch.zeros((1, 1, input_dim), dtype=torch.float).to(device)
    out[0][0][0] = 1.0

    P = Program()
    P.mode = 'start'
    P.grounded = {0}
    P.atts = {}
    c = 0  # index handed to getCuboidLine for the next cuboid line

    stop = False
    loops = 0
    num_rejects = 0
    bb_dims = None  # assigned on the first iteration, before clean_forward reads it

    h = decoder.init_gru
    # NOTE(review): num_samps is not defined in this function; presumably a
    # module-level constant — confirm.
    full_pc = pc[0, :, :(num_samps // 2)]

    while not stop and loops < max_lines:
        loops += 1

        prev_out = out.clone().detach()
        reencoding = reencode_prog(full_pc, P, num_samps, encoder)

        if loops == 1:
            # Predict bounding-box dims once and publish them on prog.
            bb_dims = decoder.bbdimNet(reencoding.squeeze())
            prog["bb_dims"] = bb_dims

        out, h, valid = clean_forward(
            decoder,
            out,
            reencoding,
            h,
            bb_dims,
            input_dim,
            device,
            P
        )

        if not valid:
            if rejection_sample and DO_REJECT:
                # Explicit raise instead of `assert False`, which `python -O` strips.
                raise AssertionError("Couldn't clean line")
            num_rejects += 1
            if MASK_BAD_OUT and num_rejects < MAX_REJECT:
                # Retry from the previous input rather than the rejected output.
                out = prev_out
            else:
                # Stop masking: the rejected output becomes the next input.
                num_rejects = 0
            continue

        prog_out.append(out)
        line = out.clone().detach().squeeze()
        preds.append(line)

        # The first 7 entries score the command type.
        command = torch.argmax(line[:7]).item()

        pline = None

        if command == 1:
            P.mode = 'cuboid'
            pline = getCuboidLine(line, c)
            c += 1

        elif command == 2:
            P.mode = 'attach'
            cub1 = torch.argmax(line[7:18]).item()
            cub2 = torch.argmax(line[18:29]).item()

            P.grounded.add(cub1)
            # Record the attachment in both directions.
            P.atts.setdefault(cub2, []).append(cub1)
            P.atts.setdefault(cub1, []).append(cub2)

            pline = getAttachLines(line)

        elif command == 3:
            P.mode = 'sym'
            pline = getReflectLine(line)

        elif command == 4:
            P.mode = 'sym'
            pline = getTranslateLine(line)

        elif command == 5:
            P.mode = 'attach'
            cub1 = torch.argmax(line[7:18]).item()
            cub2 = torch.argmax(line[18:29]).item()
            cub3 = torch.argmax(line[29:40]).item()

            P.grounded.add(cub1)
            # cub1 is linked to both cub2 and cub3, in both directions.
            P.atts.setdefault(cub2, []).append(cub1)
            P.atts.setdefault(cub3, []).append(cub1)
            P.atts.setdefault(cub1, []).extend((cub2, cub3))

            pline = getSqueezeLine(line)

        if pline is not None:
            try:
                P.execute(pline)
            except Exception:
                # Deliberate best-effort: a line that fails to execute is still
                # kept in preds/prog_out; decoding continues.
                if VERBOSE:
                    print("Unexpectedly, failed to execute line")

        if command == 6:
            # End-of-program token.
            stop = True
        elif command == 1:
            # This variant never expands cuboids into child programs: every
            # cuboid line gets an empty (leaf) child entry.
            children.append({})

    fc_preds, fc_prog_out, fchildren = cuboid_line_clean(preds, prog_out, children, P)
    prog["children"] = fchildren
    return fc_preds, fc_prog_out
コード例 #2
0
def semValidGen(prog, rnn, h, hier_ind, max_lines, input_dim, device, gt_prog,
                rejection_sample):
    """Decode one program level line by line, keeping only lines that clean_forward accepts.

    Invalid lines are rejected, and the next iteration retries from the previous
    input (see ``MASK_BAD_OUT`` / ``MAX_REJECT``). Valid lines are interpreted by
    command type and executed on a fresh ``Program``. Cuboid lines whose ``pleaf``
    score is negative are queued so that a child program can be decoded for them
    later.

    Args:
        prog: Dict providing ``"bb_dims"``; its ``"children"`` entry is
            overwritten with the cleaned child list.
        rnn: Passed to ``clean_forward``.
        h: Initial hidden state; a clone is kept as ``h_start`` and passed to
            every ``clean_forward`` call.
        hier_ind: Current hierarchy level; queued children get ``hier_ind + 1``.
        max_lines: Upper bound on loop iterations. Rejected (invalid) attempts
            count toward this bound too.
        input_dim: Width of each line vector.
        device: Device on which the start-token tensor is created.
        gt_prog: Optional ground-truth program dict. When non-empty, its
            ``"children"`` are paired, in order, with predicted cuboid lines.
        rejection_sample: If true (and ``DO_REJECT`` is set), an uncleanable
            line aborts decoding with ``AssertionError``.

    Returns:
        ``(fc_preds, fc_prog_out, q)``. The first two come from
        ``cuboid_line_clean``. ``q`` is a list of
        ``(pnext, child_dict, hier_ind + 1, gt_child)`` tuples, one per cuboid
        queued for child decoding.

    Raises:
        AssertionError: when a line cannot be cleaned while rejection sampling.
    """
    q = []

    prog_out = []
    preds = []
    children = []

    # Start token: an all-zero line with only the first slot set.
    out = torch.zeros((1, 1, input_dim), dtype=torch.float).to(device)
    out[0][0][0] = 1.0
    h_start = h.clone()
    bb_dims = prog["bb_dims"]

    gt_nc_ind = 0  # next unconsumed ground-truth child
    gt_children = []
    if gt_prog is not None and len(gt_prog) > 0:
        gt_children = gt_prog["children"]

    P = Program()
    P.mode = 'start'
    P.grounded = {0}
    P.atts = {}
    c = 0  # index handed to getCuboidLine for the next cuboid line

    stop = False
    loops = 0
    num_rejects = 0

    while not stop and loops < max_lines:
        loops += 1

        prev_out = out.clone().detach()

        out, pnext, pleaf, h, valid = clean_forward(rnn, out, h, h_start,
                                                    bb_dims, hier_ind,
                                                    input_dim, device, P)

        if not valid:
            if rejection_sample and DO_REJECT:
                # Explicit raise instead of `assert False`, which `python -O` strips.
                raise AssertionError("Couldn't clean line")
            num_rejects += 1
            if MASK_BAD_OUT and num_rejects < MAX_REJECT:
                # Retry from the previous input rather than the rejected output.
                out = prev_out
            else:
                # Stop masking: the rejected output becomes the next input.
                num_rejects = 0
            continue

        prog_out.append(out)
        line = out.clone().detach().squeeze()
        preds.append(line)

        # The first 7 entries score the command type.
        command = torch.argmax(line[:7]).item()

        pline = None

        if command == 1:
            P.mode = 'cuboid'
            pline = getCuboidLine(line, c)
            c += 1

        elif command == 2:
            P.mode = 'attach'
            cub1 = torch.argmax(line[7:18]).item()
            cub2 = torch.argmax(line[18:29]).item()

            P.grounded.add(cub1)
            # Record the attachment in both directions.
            P.atts.setdefault(cub2, []).append(cub1)
            P.atts.setdefault(cub1, []).append(cub2)

            pline = getAttachLines(line)

        elif command == 3:
            P.mode = 'sym'
            pline = getReflectLine(line)

        elif command == 4:
            P.mode = 'sym'
            pline = getTranslateLine(line)

        elif command == 5:
            P.mode = 'attach'
            cub1 = torch.argmax(line[7:18]).item()
            cub2 = torch.argmax(line[18:29]).item()
            cub3 = torch.argmax(line[29:40]).item()

            P.grounded.add(cub1)
            # cub1 is linked to both cub2 and cub3, in both directions.
            P.atts.setdefault(cub2, []).append(cub1)
            P.atts.setdefault(cub3, []).append(cub1)
            P.atts.setdefault(cub1, []).extend((cub2, cub3))

            pline = getSqueezeLine(line)

        if pline is not None:
            try:
                P.execute(pline)
            except Exception:
                # Deliberate best-effort: a line that fails to execute is still
                # kept in preds/prog_out; decoding continues.
                if VERBOSE:
                    print("Unexpectedly, failed to execute line")

        if command == 6:
            # End-of-program token.
            stop = True
        elif command == 1:
            # Pair this cuboid with the next ground-truth child, if any remain.
            gt_child = None
            if gt_nc_ind < len(gt_children):
                gt_child = gt_children[gt_nc_ind]
                gt_nc_ind += 1

            # A negative pleaf score queues this cuboid for its own child program.
            # The very first predicted line is excluded (presumably the bounding
            # box — confirm).
            if pleaf.squeeze().item() < 0 and len(preds) > 1:
                d = {"children": [], "bb_dims": line[40:43]}
                children.append(d)
                q.append((pnext, d, hier_ind + 1, gt_child))
            else:
                children.append({})

    fc_preds, fc_prog_out, fchildren = cuboid_line_clean(
        preds, prog_out, children, P)
    prog["children"] = fchildren
    return fc_preds, fc_prog_out, q