Example #1
def get_cd(imitate_net, data, one_hot_labels, program_len):
    batch_size = data.shape[1]
    beam_width = 10
    all_beams, next_beams_prob, all_inputs = imitate_net.beam_search(
        [data, one_hot_labels], beam_width, program_len)

    beam_labels = beams_parser(all_beams, batch_size, beam_width=beam_width)

    beam_labels_numpy = np.zeros((batch_size * beam_width, program_len),
                                 dtype=np.int32)
    for i in range(batch_size):
        beam_labels_numpy[i * beam_width:(i + 1) *
                          beam_width, :] = beam_labels[i]

    # find expression from these predicted beam labels
    expressions = [""] * batch_size * beam_width
    for i in range(batch_size * beam_width):
        for j in range(program_len):
            expressions[i] += unique_draw[beam_labels_numpy[i, j]]
    for index, prog in enumerate(expressions):
        expressions[index] = prog.split("$")[0]

    predicted_images = image_from_expressions(parser, expressions)
    # `data_` (the raw numpy batch that `data` was built from) is assumed to be
    # available from the enclosing scope; only its final slice is the target.
    target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
    target_images_new = np.repeat(target_images, axis=0, repeats=beam_width)

    beam_CD = chamfer(target_images_new, predicted_images)

    CD = np.zeros((batch_size, 1))
    for r in range(batch_size):
        CD[r, 0] = min(beam_CD[r * beam_width:(r + 1) * beam_width])

    return np.sum(CD)
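The per-target candidates sit in contiguous blocks of beam_width rows, so the final loop is a per-row minimum in disguise. A small numpy sketch of the same reduction, with purely illustrative sizes and random stand-in distances:

import numpy as np

batch_size, beam_width = 4, 10                      # illustrative sizes
beam_CD = np.random.rand(batch_size * beam_width)   # stand-in for the chamfer values above

# Rows r*beam_width:(r+1)*beam_width belong to target r, so the explicit loop
# above is equivalent to a reshape followed by a per-row minimum.
CD = beam_CD.reshape(batch_size, beam_width).min(axis=1, keepdims=True)
print(CD.shape)  # (4, 1)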
Example #2
def get_nn(images1, images2):
    min_cds = []
    for i in range(len(images1)):
        repeated_i = np.repeat(images1[i:i + 1], axis=0, repeats=len(images2))
        cd = chamfer(repeated_i, images2)
        min_cds.append(np.amin(cd))
    return sum(min_cds) / len(min_cds)
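A minimal usage sketch for get_nn, assuming the codebase's chamfer function is already in scope as in the snippet above; the random binary canvases and their shapes are purely illustrative:

import numpy as np

rng = np.random.default_rng(0)
images1 = rng.random((8, 64, 64)) > 0.5    # 8 query canvases (64x64 assumed)
images2 = rng.random((16, 64, 64)) > 0.5   # 16 reference canvases

# For each canvas in images1, take the chamfer distance to its nearest
# neighbour in images2, then average those minima.
print(get_nn(images1, images2))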
Example #3
    def objective(self, x: np.ndarray):
        """
        Objective to minimize.
        :param x: input program parameters in numpy array format
        :return: scalar error to minimize (negative IoU or chamfer distance,
                 depending on self.metric)
        """
        x = x.astype(int)  # np.int was removed in NumPy >= 1.24; plain int is equivalent
        x = np.clip(x, 8, 56)

        query_exp = self.make_expression(x)
        query_image = self.parser.expression2stack([query_exp])[-1, 0, 0, :, :]
        if self.metric == "iou":
            error = -np.sum(np.logical_and(
                self.target_image, query_image)) / np.sum(
                    np.logical_or(self.target_image, query_image))
        elif self.metric == "chamfer":
            error = chamfer(np.expand_dims(self.target_image, 0),
                            np.expand_dims(query_image, 0))
        return error
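Because the parameters are cast to integers and clipped to [8, 56], the objective is piecewise constant and has no useful gradient, so a derivative-free optimizer is the natural driver. A hedged sketch using SciPy's Nelder-Mead; `refiner` is a placeholder for an instance of the class above, and the six-parameter starting point is an assumption:

import numpy as np
from scipy.optimize import minimize

x0 = np.full(6, 32.0)  # hypothetical: six program parameters, started mid-range
res = minimize(refiner.objective, x0, method="Nelder-Mead",
               options={"maxiter": 200, "xatol": 1.0, "fatol": 1e-3})
print(res.x.astype(int), res.fun)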
Example #4
                test_output = imitate_net.test([data, one_hot_labels, max_len])
                # test_output = torch.stack(test_output).permute(1, 0, 2)
                # acc += float((torch.argmax(test_output[:, :, :8], dim=2)[:, :(k)] == labels_cont[:, 1:-1, 0]).float().sum()) \
                #        / (len(labels_cont) * (k)) / len(dataset_sizes) / (dataset_sizes[k][1] // test_batch_size)
                acc += float((torch.argmax(torch.stack(test_output), dim=2)[:k].permute(1, 0) == labels[:, :-1]).float().sum()) \
                      / (len(labels) * (k)) / len(dataset_sizes) / (dataset_sizes[k][1] // test_batch_size)
                pred_images, correct_prog, pred_prog = parser.get_final_canvas(
                    test_output,
                    if_just_expressions=False,
                    if_pred_images=True)
                target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
                targ_prog = parser.labels2exps(labels, k)

                programs_tar[jit] += targ_prog
                programs_pred[jit] += pred_prog
                distance = chamfer(target_images, pred_images)
                total_CD += np.sum(distance)

    over_all_CD[jit] = total_CD / total_size

metrics["chamfer"] = over_all_CD
print(metrics, model_name)
print(over_all_CD)
print(acc / 2)

results_path = "trained_models/results/{}/".format(model_name)
os.makedirs(os.path.dirname(results_path), exist_ok=True)

with open("trained_models/results/{}/{}".format(model_name, "pred_prog.org"),
          'w') as outfile:
    json.dump(programs_pred, outfile)
Example #5
def train_inference(imitate_net,
                    path,
                    max_epochs=None,
                    self_training=False,
                    ab=None):
    if max_epochs is None:
        epochs = 1000
    else:
        epochs = max_epochs

    config = read_config.Config("config_synthetic.yml")

    if ab is not None:
        train_size = inference_train_size * ab
    else:
        train_size = inference_train_size

    generator = WakeSleepGen(f"{path}/",
                             batch_size=config.batch_size,
                             train_size=train_size,
                             canvas_shape=config.canvas_shape,
                             max_len=max_len,
                             self_training=True)  # note: the function's `self_training` argument is not forwarded here

    train_gen = generator.get_train_data()

    cad_generator = Generator()
    val_gen = cad_generator.val_gen(batch_size=config.batch_size,
                                    path="data/cad/cad.h5",
                                    if_augment=False)

    for parameter in imitate_net.encoder.parameters():
        parameter.requires_grad = False

    optimizer = optim.Adam(
        [para for para in imitate_net.parameters() if para.requires_grad],
        weight_decay=config.weight_decay,
        lr=config.lr)

    reduce_plat = LearningRate(optimizer,
                               init_lr=config.lr,
                               lr_dacay_fact=0.2,
                               patience=config.patience)

    best_test_loss = 1e20
    torch.save(imitate_net.state_dict(), f"{path}/best_dict.pth")

    best_test_cd = 1e20

    patience = 20
    num_worse = 0

    for epoch in range(epochs):
        start = time.time()
        train_loss = 0
        imitate_net.train()
        for batch_idx in range(train_size //
                               (config.batch_size * config.num_traj)):
            optimizer.zero_grad()
            loss = 0
            # acc = 0
            for _ in range(config.num_traj):
                data, labels = next(train_gen)
                # data = data[:, :, 0:1, :, :]
                one_hot_labels = prepare_input_op(labels,
                                                  len(generator.unique_draw))
                one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
                data = data.to(device)
                labels = labels.to(device)
                outputs = imitate_net([data, one_hot_labels, max_len])
                # acc += float((torch.argmax(outputs, dim=2).permute(1, 0) == labels).float().sum()) \
                #        / (labels.shape[0] * labels.shape[1]) / config.num_traj
                loss_k = (
                    (losses_joint(outputs, labels, time_steps=max_len + 1) /
                     (max_len + 1)) / config.num_traj)
                loss_k.backward()
                loss += float(loss_k)
                del loss_k

            optimizer.step()
            train_loss += loss
            print(f"batch {batch_idx} train loss: {loss}")
            # print(f"acc: {acc}")

        mean_train_loss = train_loss / (train_size // (config.batch_size))
        print(f"epoch {epoch} mean train loss: {mean_train_loss}")
        imitate_net.eval()
        loss = 0
        # acc = 0
        metrics = {"cos": 0, "iou": 0, "cd": 0}
        # IOU = 0
        # COS = 0
        CD = 0
        # correct_programs = 0
        # pred_programs = 0
        for batch_idx in range(inference_test_size // config.batch_size):
            parser = ParseModelOutput(generator.unique_draw, max_len // 2 + 1,
                                      max_len, config.canvas_shape)
            with torch.no_grad():
                labels = np.zeros((config.batch_size, max_len), dtype=np.int32)
                data_ = next(val_gen)
                one_hot_labels = prepare_input_op(labels,
                                                  len(generator.unique_draw))
                one_hot_labels = torch.from_numpy(one_hot_labels).cuda()
                data = torch.from_numpy(data_).cuda()
                # outputs = imitate_net([data, one_hot_labels, max_len])
                # loss_k = (losses_joint(outputs, labels, time_steps=max_len + 1) /
                #          (max_len + 1))
                # loss += float(loss_k)
                test_outputs = imitate_net.test(
                    [data[-1, :, 0, :, :], one_hot_labels, max_len])
                # acc += float((torch.argmax(torch.stack(test_outputs), dim=2).permute(1, 0) == labels[:, :-1]).float().sum()) \
                #         / (len(labels) * (max_len+1)) / (inference_test_size // config.batch_size)
                pred_images, correct_prog, pred_prog = parser.get_final_canvas(
                    test_outputs,
                    if_just_expressions=False,
                    if_pred_images=True)
                # correct_programs += len(correct_prog)
                # pred_programs += len(pred_prog)
                target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
                # iou = np.sum(np.logical_and(target_images, pred_images),
                #              (1, 2)) / \
                #       np.sum(np.logical_or(target_images, pred_images),
                #              (1, 2))
                # cos = cosine_similarity(target_images, pred_images)
                CD += np.sum(chamfer(target_images, pred_images))
                # IOU += np.sum(iou)
                # COS += np.sum(cos)

        # metrics["iou"] = IOU / inference_test_size
        # metrics["cos"] = COS / inference_test_size
        metrics["cd"] = CD / inference_test_size

        test_losses = loss
        test_loss = test_losses / (inference_test_size // (config.batch_size))

        if metrics["cd"] >= best_test_cd:
            num_worse += 1
        else:
            num_worse = 0
            best_test_cd = metrics["cd"]
            torch.save(imitate_net.state_dict(), f"{path}/best_dict.pth")
        if num_worse >= patience:
            # load the best model and stop training
            imitate_net.load_state_dict(torch.load(f"{path}/best_dict.pth"))
            return epoch + 1

        # reduce_plat.reduce_on_plateu(metrics["cd"])
        print(
            f"Epoch {epoch}/100 =>  train_loss: {mean_train_loss}, iou: {0}, cd: {metrics['cd']}, test_mse: {test_loss}, test_acc: {0}"
        )
        # print(f"CORRECT PROGRAMS: {correct_programs}")
        # print(f"PREDICTED PROGRAMS: {pred_programs}")
        # print(f"RATIO: {correct_programs/pred_programs}")

        end = time.time()
        print(f"Inference train time {end-start}")

        del test_losses, outputs, test_outputs

    return epochs
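Examples #5 and #8 share the same early-stopping pattern: checkpoint whenever the validation metric improves, count consecutive non-improving epochs, and restore the best weights once the count reaches the patience budget. A stripped-down sketch of that pattern with a toy model and toy chamfer values (all names and numbers illustrative):

from torch import nn

model = nn.Linear(8, 8)                  # stand-in for imitate_net
best_cd, num_worse, patience = float("inf"), 0, 3
best_state = {k: v.clone() for k, v in model.state_dict().items()}

for epoch, cd in enumerate([0.9, 0.7, 0.8, 0.8, 0.8]):  # toy validation CDs
    if cd >= best_cd:
        num_worse += 1
    else:
        num_worse = 0
        best_cd = cd
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
    if num_worse >= patience:
        model.load_state_dict(best_state)   # roll back to the best checkpoint
        break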
Example #6
                test_output = imitate_net.test([data, one_hot_labels, max_len])
                #acc += float((torch.argmax(torch.stack(test_output), dim=2)[:k].permute(1, 0) == labels[:, :-1]).float().sum()) \
                #        / (len(labels) * (k+1)) / types_prog / (config.test_size // config.batch_size)
                pred_images, correct_prog, pred_prog = parser.get_final_canvas(
                    test_output,
                    if_just_expressions=False,
                    if_pred_images=True)
                correct_programs += len(correct_prog)
                pred_programs += len(pred_prog)
                target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
                #iou = np.sum(np.logical_and(target_images, pred_images),
                #             (1, 2)) / \
                #      np.sum(np.logical_or(target_images, pred_images),
                #             (1, 2))
                #cos = cosine_similarity(target_images, pred_images)
                CD += np.sum(chamfer(target_images, pred_images))
                #beam_CD += get_cd(imitate_net, data, one_hot_labels, k)
                #IOU += np.sum(iou)
                #COS += np.sum(cos)

    #metrics["iou"] = IOU / config.test_size
    #metrics["cos"] = COS / config.test_size
    metrics["cd"] = CD / config.test_size
    #beam_CD = beam_CD / config.test_size

    test_losses = loss.data
    test_loss = test_losses.cpu().numpy() / (config.test_size //
                                             (config.batch_size))

    reduce_plat.reduce_on_plateu(metrics["cd"])
    #reduce_plat.reduce_on_plateu(beam_CD)
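The repo's LearningRate.reduce_on_plateu helper lowers the learning rate when the chamfer metric stops improving. A minimal stand-in sketch with PyTorch's built-in ReduceLROnPlateau, which plays the same role; the factor matches the 0.2 used above, while the patience value and toy model are assumptions:

from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(4, 4)                              # stand-in for imitate_net
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=5)

for epoch in range(10):
    val_cd = 1.0                                     # stand-in for metrics["cd"]
    scheduler.step(val_cd)  # cuts lr to 0.2x after `patience` flat epochs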
Example #7
def infer_programs(imitate_net, self_training=False, ab=None):
    config = read_config.Config("config_cad.yml")

    # Load the terminals symbols of the grammar
    with open("terminals.txt", "r") as file:
        unique_draw = file.readlines()
    for index, e in enumerate(unique_draw):
        unique_draw[index] = e[0:-1]

    config.train_size = 10000
    config.test_size = 3000
    imitate_net.eval()
    imitate_net.epsilon = 0
    parser = ParseModelOutput(unique_draw, max_len // 2 + 1, max_len,
                              config.canvas_shape)

    generator = Generator()
    test_gen = generator.test_gen(batch_size=config.batch_size,
                                  path="data/cad/cad.h5",
                                  if_augment=False)

    pred_expressions = []
    Rs = 0
    CDs = 0
    Target_images = []
    for batch_idx in range(config.test_size // config.batch_size):
        with torch.no_grad():
            print(f"Inferring test cad batch: {batch_idx}")
            data_ = next(test_gen)
            labels = np.zeros((config.batch_size, max_len), dtype=np.int32)
            one_hot_labels = prepare_input_op(labels, len(unique_draw))
            one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
            data = torch.from_numpy(data_).to(device)

            all_beams, next_beams_prob, all_inputs = imitate_net.beam_search(
                [data[-1, :, 0, :, :], one_hot_labels], beam_width, max_len)

            beam_labels = beams_parser(all_beams,
                                       data_.shape[1],
                                       beam_width=beam_width)

            beam_labels_numpy = np.zeros(
                (config.batch_size * beam_width, max_len), dtype=np.int32)
            Target_images.append(data_[-1, :, 0, :, :])
            for i in range(data_.shape[1]):
                beam_labels_numpy[i * beam_width:(i + 1) *
                                  beam_width, :] = beam_labels[i]

            # find expression from these predicted beam labels
            expressions = [""] * config.batch_size * beam_width
            for i in range(config.batch_size * beam_width):
                for j in range(max_len):
                    expressions[i] += unique_draw[beam_labels_numpy[i, j]]
            for index, prog in enumerate(expressions):
                expressions[index] = prog.split("$")[0]

            pred_expressions += expressions
            predicted_images = image_from_expressions(parser, expressions)
            target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
            target_images_new = np.repeat(target_images,
                                          axis=0,
                                          repeats=beam_width)

            beam_CD = chamfer(target_images_new, predicted_images)

            CD = np.zeros((config.batch_size, 1))
            for r in range(config.batch_size):
                CD[r, 0] = min(beam_CD[r * beam_width:(r + 1) * beam_width])

            CDs += np.mean(CD)

            os.makedirs("best_lest/", exist_ok=True)  # make sure the output directory exists before savefig
            for j in range(0, config.batch_size):
                f, a = plt.subplots(1, beam_width + 1, figsize=(30, 3))
                a[0].imshow(data_[-1, j, 0, :, :], cmap="Greys_r")
                a[0].axis("off")
                a[0].set_title("target")
                for i in range(1, beam_width + 1):
                    a[i].imshow(predicted_images[j * beam_width + i - 1],
                                cmap="Greys_r")
                    a[i].set_title("{}".format(i))
                    a[i].axis("off")
                plt.savefig("best_lest/" +
                            "{}.png".format(batch_idx * config.batch_size + j),
                            transparent=0)
                plt.close("all")
            # with open("best_st_expressions.txt", "w") as file:
            #     for e in pred_expressions:
            #         file.write(f"{e}\n")
            # break

    return CDs / (config.test_size // config.batch_size)
Example #8
def train_inference(inference_net, iter):
    config = read_config.Config("config_synthetic.yml")

    generator = WakeSleepGen(
        f"wake_sleep_data/inference/{iter}/labels/labels.pt",
        f"wake_sleep_data/inference/{iter}/labels/val/labels.pt",
        batch_size=config.batch_size,
        train_size=inference_train_size,
        test_size=inference_test_size,
        canvas_shape=config.canvas_shape,
        max_len=max_len)

    train_gen = generator.get_train_data()
    test_gen = generator.get_test_data()

    encoder_net, imitate_net = inference_net

    optimizer = optim.Adam(
        [para for para in imitate_net.parameters() if para.requires_grad],
        weight_decay=config.weight_decay,
        lr=config.lr)

    reduce_plat = LearningRate(optimizer,
                               init_lr=config.lr,
                               lr_dacay_fact=0.2,
                               patience=config.patience)

    best_test_loss = 1e20
    best_imitate_dict = imitate_net.state_dict()

    prev_test_cd = 1e20
    prev_test_iou = 0

    patience = 5
    num_worse = 0

    for epoch in range(50):
        train_loss = 0
        Accuracies = []
        imitate_net.train()
        for batch_idx in range(inference_train_size //
                               (config.batch_size * config.num_traj)):
            optimizer.zero_grad()
            loss = Variable(torch.zeros(1)).to(device).data
            for _ in range(config.num_traj):
                batch_data, batch_labels = next(train_gen)
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                batch_data = batch_data[:, :, 0:1, :, :]
                one_hot_labels = prepare_input_op(batch_labels, vocab_size)
                one_hot_labels = Variable(
                    torch.from_numpy(one_hot_labels)).to(device)
                outputs = imitate_net([batch_data, one_hot_labels, max_len])

                loss_k = (losses_joint(
                    outputs, batch_labels, time_steps=max_len + 1) /
                          (max_len + 1)) / config.num_traj
                loss_k.backward()
                loss += loss_k.data
                del loss_k

            optimizer.step()
            train_loss += loss
            print(f"batch {batch_idx} train loss: {loss.cpu().numpy()}")

        mean_train_loss = train_loss / (inference_train_size //
                                        (config.batch_size))
        print(
            f"epoch {epoch} mean train loss: {mean_train_loss.cpu().numpy()}")
        imitate_net.eval()
        loss = Variable(torch.zeros(1)).to(device)
        metrics = {"cos": 0, "iou": 0, "cd": 0}
        IOU = 0
        COS = 0
        CD = 0
        for batch_idx in range(inference_test_size // config.batch_size):
            with torch.no_grad():
                batch_data, batch_labels = next(test_gen)
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                one_hot_labels = prepare_input_op(batch_labels, vocab_size)
                one_hot_labels = Variable(
                    torch.from_numpy(one_hot_labels)).to(device)
                test_outputs = imitate_net(
                    [batch_data, one_hot_labels, max_len])
                loss += (losses_joint(
                    test_outputs, batch_labels, time_steps=max_len + 1) /
                         (max_len + 1))
                test_output = imitate_net.test(
                    [batch_data, one_hot_labels, max_len])
                pred_images, correct_prog, pred_prog = generator.parser.get_final_canvas(
                    test_output,
                    if_just_expressions=False,
                    if_pred_images=True)
                target_images = batch_data.cpu().numpy()[-1, :,
                                                         0, :, :].astype(
                                                             dtype=bool)
                iou = np.sum(np.logical_and(target_images, pred_images),
                             (1, 2)) / \
                      np.sum(np.logical_or(target_images, pred_images),
                             (1, 2))
                cos = cosine_similarity(target_images, pred_images)
                CD += np.sum(chamfer(target_images, pred_images))
                IOU += np.sum(iou)
                COS += np.sum(cos)

        metrics["iou"] = IOU / inference_test_size
        metrics["cos"] = COS / inference_test_size
        metrics["cd"] = CD / inference_test_size

        test_losses = loss.data
        test_loss = test_losses.cpu().numpy() / (inference_test_size //
                                                 (config.batch_size))

        if test_loss >= best_test_loss:
            num_worse += 1
        else:
            num_worse = 0
            best_test_loss = test_loss
            best_imitate_dict = imitate_net.state_dict()
        if num_worse >= patience:
            # load the best model and stop training
            imitate_net.load_state_dict(best_imitate_dict)
            break

        reduce_plat.reduce_on_plateu(metrics["cd"])
        print("Epoch {}/{}=>  train_loss: {}, iou: {}, cd: {}, test_mse: {}".
              format(
                  epoch,
                  50,  # the training loop above runs a fixed 50 epochs
                  mean_train_loss.cpu().numpy(),
                  metrics["iou"],
                  metrics["cd"],
                  test_loss,
              ))

        print(f"CORRECT PROGRAMS: {len(generator.correct_programs)}")

        del test_losses, test_outputs
Example #9
            pred_images = []
            for index, exp in enumerate(expressions):
                program = parser.Parser.parse(exp)
                if validity(program, len(program), len(program) - 1):
                    stack = parser.expression2stack([exp])
                    pred_images.append(stack[-1, -1, 0, :, :])
                else:
                    pred_images.append(np.zeros(config.canvas_shape))
            # np.bool was removed in NumPy >= 1.24; the built-in bool is equivalent here.
            pred_images = np.stack(pred_images, 0).astype(dtype=bool)
            target_images = data_[-1, :, 0, :, :].astype(dtype=bool)

            # repeat the target_images beamwidth times
            target_images_new = np.repeat(target_images,
                                          axis=0,
                                          repeats=beam_width)
            beam_CD = chamfer(target_images_new, pred_images)

            CD = np.zeros((test_batch_size, 1))
            for r in range(test_batch_size):
                CD[r, 0] = min(beam_CD[r * beam_width:(r + 1) * beam_width])
            total_CD += np.sum(CD)

    over_all_CD[jit] = total_CD / total_size

metrics["chamfer"] = over_all_CD
results_path = "trained_models/results/{}/".format(model_name)
os.makedirs(os.path.dirname(results_path), exist_ok=True)
print(metrics)
print(config.pretrain_modelpath)
with open(
        "trained_models/results/{}/{}".format(
Example #10
def infer_programs(imitate_net, path, self_training=False, ab=None):
    save_viz = False

    config = read_config.Config("config_cad.yml")

    # Load the terminals symbols of the grammar
    with open("terminals.txt", "r") as file:
        unique_draw = file.readlines()
    for index, e in enumerate(unique_draw):
        unique_draw[index] = e[0:-1]

    config.train_size = 10000
    config.test_size = 3000
    imitate_net.eval()
    imitate_net.epsilon = 0
    parser = ParseModelOutput(unique_draw, max_len // 2 + 1, max_len,
                              config.canvas_shape)
    pred_expressions = []
    if ab is not None:
        pred_labels = np.zeros((config.train_size * ab, max_len))
    else:
        pred_labels = np.zeros((config.train_size, max_len))
    image_path = f"{path}/images/"
    results_path = f"{path}/results/"
    labels_path = f"{path}/labels/"

    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    os.makedirs(os.path.dirname(labels_path), exist_ok=True)
    os.makedirs(os.path.dirname(labels_path + "val/"), exist_ok=True)

    generator = Generator()

    train_gen = generator.train_gen(batch_size=config.batch_size,
                                    path="data/cad/cad.h5",
                                    if_augment=False)

    val_gen = generator.val_gen(batch_size=config.batch_size,
                                path="data/cad/cad.h5",
                                if_augment=False)

    Rs = 0
    CDs = 0
    Target_images = []

    start = time.time()
    pred_images = np.zeros((config.train_size, 64, 64))
    for batch_idx in range(config.train_size // config.batch_size):
        with torch.no_grad():
            print(f"Inferring cad batch: {batch_idx}")
            data_ = next(train_gen)
            labels = np.zeros((config.batch_size, max_len), dtype=np.int32)
            one_hot_labels = prepare_input_op(labels, len(unique_draw))
            one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
            data = torch.from_numpy(data_).to(device)

            all_beams, next_beams_prob, all_inputs = imitate_net.beam_search(
                [data[-1, :, 0, :, :], one_hot_labels], beam_width, max_len)

            beam_labels = beams_parser(all_beams,
                                       data_.shape[1],
                                       beam_width=beam_width)

            beam_labels_numpy = np.zeros(
                (config.batch_size * beam_width, max_len), dtype=np.int32)
            Target_images.append(data_[-1, :, 0, :, :])
            for i in range(data_.shape[1]):
                beam_labels_numpy[i * beam_width:(i + 1) *
                                  beam_width, :] = beam_labels[i]

            # find expression from these predicted beam labels
            expressions = [""] * config.batch_size * beam_width
            for i in range(config.batch_size * beam_width):
                for j in range(max_len):
                    expressions[i] += unique_draw[beam_labels_numpy[i, j]]
            for index, prog in enumerate(expressions):
                expressions[index] = prog.split("$")[0]

            pred_expressions += expressions
            predicted_images = image_from_expressions(parser, expressions)
            target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
            target_images_new = np.repeat(target_images,
                                          axis=0,
                                          repeats=beam_width)

            # beam_R = np.sum(np.logical_and(target_images_new, predicted_images),
            #                 (1, 2)) / np.sum(np.logical_or(target_images_new, predicted_images), (1, 2))
            #
            # R = np.zeros((config.batch_size, 1))
            # for r in range(config.batch_size):
            #     R[r, 0] = max(beam_R[r * beam_width:(r + 1) * beam_width])
            #
            # Rs += np.mean(R)

            beam_CD = chamfer(target_images_new, predicted_images)

            # select best expression by chamfer distance
            if ab is None:
                best_labels = np.zeros((config.batch_size, max_len))
                for r in range(config.batch_size):
                    idx = np.argmin(beam_CD[r * beam_width:(r + 1) *
                                            beam_width])
                    best_labels[r] = beam_labels[r][idx]
                pred_labels[batch_idx *
                            config.batch_size:batch_idx * config.batch_size +
                            config.batch_size] = best_labels
            else:
                best_labels = np.zeros((config.batch_size * ab, max_len))
                for r in range(config.batch_size):
                    sorted_idx = np.argsort(beam_CD[r * beam_width:(r + 1) *
                                                    beam_width])[:ab]
                    best_labels[r * ab:r * ab +
                                ab] = beam_labels[r][sorted_idx]
                pred_labels[batch_idx * config.batch_size *
                            ab:batch_idx * config.batch_size * ab +
                            config.batch_size * ab] = best_labels

            CD = np.zeros((config.batch_size, 1))
            for r in range(config.batch_size):
                CD[r, 0] = min(beam_CD[r * beam_width:(r + 1) * beam_width])
                pred_images[batch_idx * config.batch_size +
                            r] = predicted_images[r * beam_width + np.argmin(
                                beam_CD[r * beam_width:(r + 1) * beam_width])]

            CDs += np.mean(CD)

            if save_viz:
                for j in range(0, config.batch_size):
                    f, a = plt.subplots(1, beam_width + 1, figsize=(30, 3))
                    a[0].imshow(data_[-1, j, 0, :, :], cmap="Greys_r")
                    a[0].axis("off")
                    a[0].set_title("target")
                    for i in range(1, beam_width + 1):
                        a[i].imshow(predicted_images[j * beam_width + i - 1],
                                    cmap="Greys_r")
                        a[i].set_title("{}".format(i))
                        a[i].axis("off")
                    plt.savefig(
                        image_path +
                        "{}.png".format(batch_idx * config.batch_size + j),
                        transparent=0)
                    plt.close("all")

                    save_viz = False

    print("Inferring cad average chamfer distance: {}".format(
        CDs / (config.train_size // config.batch_size)),
          flush=True)

    Rs = Rs / (config.train_size // config.batch_size)
    CDs = CDs / (config.train_size // config.batch_size)
    print(Rs, CDs)
    results = {"iou": Rs, "chamferdistance": CDs}

    with open(results_path + "results_beam_width_{}.org".format(beam_width),
              'w') as outfile:
        json.dump(results, outfile)

    torch.save(pred_labels, labels_path + "labels.pt")
    # torch.save(pred_images, labels_path + "images.pt")
    if self_training:
        if ab is None:
            torch.save(np.concatenate(Target_images, axis=0),
                       labels_path + "images.pt")
        else:
            torch.save(
                np.repeat(np.concatenate(Target_images, axis=0), ab, axis=0),
                labels_path + "images.pt")

    test_gen = generator.test_gen(batch_size=config.batch_size,
                                  path="data/cad/cad.h5",
                                  if_augment=False)

    pred_expressions = []
    Rs = 0
    CDs = 0
    Target_images = []
    for batch_idx in range(config.test_size // config.batch_size):
        with torch.no_grad():
            print(f"Inferring test cad batch: {batch_idx}")
            data_ = next(test_gen)
            labels = np.zeros((config.batch_size, max_len), dtype=np.int32)
            one_hot_labels = prepare_input_op(labels, len(unique_draw))
            one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
            data = torch.from_numpy(data_).to(device)

            all_beams, next_beams_prob, all_inputs = imitate_net.beam_search(
                [data[-1, :, 0, :, :], one_hot_labels], beam_width, max_len)

            beam_labels = beams_parser(all_beams,
                                       data_.shape[1],
                                       beam_width=beam_width)

            beam_labels_numpy = np.zeros(
                (config.batch_size * beam_width, max_len), dtype=np.int32)
            Target_images.append(data_[-1, :, 0, :, :])
            for i in range(data_.shape[1]):
                beam_labels_numpy[i * beam_width:(i + 1) *
                                  beam_width, :] = beam_labels[i]

            # find expression from these predicted beam labels
            expressions = [""] * config.batch_size * beam_width
            for i in range(config.batch_size * beam_width):
                for j in range(max_len):
                    expressions[i] += unique_draw[beam_labels_numpy[i, j]]
            for index, prog in enumerate(expressions):
                expressions[index] = prog.split("$")[0]

            pred_expressions += expressions
            predicted_images = image_from_expressions(parser, expressions)
            target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
            target_images_new = np.repeat(target_images,
                                          axis=0,
                                          repeats=beam_width)

            beam_CD = chamfer(target_images_new, predicted_images)

            CD = np.zeros((config.batch_size, 1))
            for r in range(config.batch_size):
                CD[r, 0] = min(beam_CD[r * beam_width:(r + 1) * beam_width])

            CDs += np.mean(CD)

    print(f"TEST CD: {CDs / (config.test_size // config.batch_size)}")

    end = time.time()
    print(f"Inference time: {end-start}")