Example #1
0
def get_csgnet():
    """Load a pretrained CSGNet imitation model and return its CNN encoder."""
    config = read_config.Config("config_synthetic.yml")

    # CNN encoder, moved onto the target device.
    enc = Encoder(config.encoder_drop).to(device)

    # RNN imitation network with a fixed 400-symbol output vocabulary.
    net = ImitateJoint(hd_sz=config.hidden_size,
                       input_size=config.input_size,
                       encoder=enc,
                       mode=config.mode,
                       num_draws=400,
                       canvas_shape=config.canvas_shape).to(device)

    print("pre loading model")
    checkpoint = torch.load(config.pretrain_modelpath,
                            map_location=device)
    state = net.state_dict()
    # Restore only the checkpoint entries whose names match the current
    # architecture, leaving any extra/missing keys at their fresh values.
    state.update({k: v for k, v in checkpoint.items() if k in state})
    net.load_state_dict(state)

    return net.encoder
Example #2
0
def get_blank_csgnet():
    """Construct a freshly initialized CSGNet imitation model (no pretrained weights)."""
    config = read_config.Config("config_synthetic.yml")

    # CNN encoder on the target device.
    enc = Encoder(config.encoder_drop).to(device)

    # RNN imitation network with a fixed 400-symbol output vocabulary.
    net = ImitateJoint(hd_sz=config.hidden_size,
                       input_size=config.input_size,
                       encoder=enc,
                       mode=config.mode,
                       num_draws=400,
                       canvas_shape=config.canvas_shape)
    return net.to(device)
Example #3
0
def get_csgnet():
    """Build a pretrained CSGNet (encoder + imitation network) for fine-tuning.

    Reads hyper-parameters from config_synthetic.yml, loads the grammar
    terminals from terminals.txt to size the output layer, restores the
    matching pretrained weights, and unfreezes all parameters.

    Returns:
        tuple: (encoder_net, imitate_net), both moved to `device`.
    """
    config = read_config.Config("config_synthetic.yml")

    # Encoder
    encoder_net = Encoder(config.encoder_drop)
    encoder_net = encoder_net.to(device)

    # Load the terminals symbols of the grammar
    with open("terminals.txt", "r") as file:
        unique_draw = file.readlines()
    # BUG FIX: was `e[0:-1]`, which truncates the last character of the
    # final terminal whenever terminals.txt lacks a trailing newline.
    # rstrip("\n") removes the newline when present and is a no-op otherwise.
    for index, e in enumerate(unique_draw):
        unique_draw[index] = e.rstrip("\n")

    imitate_net = ImitateJoint(hd_sz=config.hidden_size,
                               input_size=config.input_size,
                               encoder=encoder_net,
                               mode=config.mode,
                               num_draws=len(unique_draw),
                               canvas_shape=config.canvas_shape)
    imitate_net = imitate_net.to(device)

    print("pre loading model")
    pretrained_dict = torch.load(config.pretrain_modelpath,
                                 map_location=device)
    imitate_net_dict = imitate_net.state_dict()
    # Keep only the checkpoint entries that exist in the current model.
    imitate_pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in imitate_net_dict
    }
    imitate_net_dict.update(imitate_pretrained_dict)
    imitate_net.load_state_dict(imitate_net_dict)

    # Make every parameter trainable for fine-tuning.
    for param in imitate_net.parameters():
        param.requires_grad = True

    for param in encoder_net.parameters():
        param.requires_grad = True

    return (encoder_net, imitate_net)
Example #4
0
import os

import numpy as np
import torch
from torch.autograd.variable import Variable

from src.Models.models2 import Encoder
from src.Models.models2 import ImitateJoint, validity
# from src.Models.models_cont2 import ParseModelOutput
from src.Models.models2 import ParseModelOutput
from src.utils import read_config
from src.utils.generators.mixed_len_generator import MixedGenerateData
from src.utils.train_utils import prepare_input_op, chamfer
from cont import labels_to_cont

# Global experiment setup: load the synthetic-data config, derive a model
# name from the pretrained checkpoint path, and build the CNN encoder on GPU.
config = read_config.Config("config_synthetic.yml")
# Checkpoint file name without its directory and without the last 4
# characters (presumably a ".pth" extension -- TODO confirm).
model_name = config.pretrain_modelpath.split("/")[-1][0:-4]
encoder_net = Encoder()
encoder_net.cuda()

# Map program length (token count) -> expression file containing synthetic
# programs with that number of CSG operations.
data_labels_paths = {
    3: "data/synthetic/one_op/expressions.txt",
    5: "data/synthetic/two_ops/expressions.txt",
    7: "data/synthetic/three_ops/expressions.txt",
    9: "data/synthetic/four_ops/expressions.txt",
    11: "data/synthetic/five_ops/expressions.txt",
    13: "data/synthetic/six_ops/expressions.txt"
}
# first element of list is num of training examples, and second is number of
# testing examples.
proportion = config.proportion  # proportion is in percentage. vary from [1, 100].
Example #5
0
def train_inference(imitate_net,
                    path,
                    max_epochs=None,
                    self_training=False,
                    ab=None):
    """Fine-tune `imitate_net` on wake-sleep inferred programs stored under `path`.

    The encoder is frozen; only the remaining parameters are optimized.
    Training stops early once the validation chamfer distance has not
    improved for `patience` consecutive epochs, at which point the best
    checkpoint is restored.

    Parameters:
        imitate_net: ImitateJoint model to fine-tune.
        path: directory containing the generated labels; best weights are
            saved to `{path}/best_dict.pth`.
        max_epochs: epoch cap (defaults to 1000 when None).
        self_training: forwarded to WakeSleepGen.
        ab: optional ablation multiplier on the training-set size.

    Returns:
        The number of epochs actually run.
    """
    epochs = 1000 if max_epochs is None else max_epochs

    config = read_config.Config("config_synthetic.yml")

    if ab is not None:
        train_size = inference_train_size * ab
    else:
        train_size = inference_train_size

    # BUG FIX: `self_training` was hard-coded to True here, silently
    # ignoring the caller's argument.
    generator = WakeSleepGen(f"{path}/",
                             batch_size=config.batch_size,
                             train_size=train_size,
                             canvas_shape=config.canvas_shape,
                             max_len=max_len,
                             self_training=self_training)

    train_gen = generator.get_train_data()

    cad_generator = Generator()
    val_gen = cad_generator.val_gen(batch_size=config.batch_size,
                                    path="data/cad/cad.h5",
                                    if_augment=False)

    # Freeze the encoder: only the decoder/RNN parameters are trained.
    for parameter in imitate_net.encoder.parameters():
        parameter.requires_grad = False

    optimizer = optim.Adam(
        [para for para in imitate_net.parameters() if para.requires_grad],
        weight_decay=config.weight_decay,
        lr=config.lr)

    reduce_plat = LearningRate(optimizer,
                               init_lr=config.lr,
                               lr_dacay_fact=0.2,
                               patience=config.patience)

    best_test_loss = 1e20
    # Seed the checkpoint so it exists even if the first epoch never improves.
    torch.save(imitate_net.state_dict(), f"{path}/best_dict.pth")

    best_test_cd = 1e20

    patience = 20
    num_worse = 0

    for epoch in range(epochs):
        start = time.time()
        train_loss = 0
        imitate_net.train()
        for batch_idx in range(train_size //
                               (config.batch_size * config.num_traj)):
            optimizer.zero_grad()
            loss = 0
            # Accumulate gradients over `num_traj` mini-batches per step.
            for _ in range(config.num_traj):
                data, labels = next(train_gen)
                one_hot_labels = prepare_input_op(labels,
                                                  len(generator.unique_draw))
                one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
                data = data.to(device)
                labels = labels.to(device)
                outputs = imitate_net([data, one_hot_labels, max_len])
                loss_k = (
                    (losses_joint(outputs, labels, time_steps=max_len + 1) /
                     (max_len + 1)) / config.num_traj)
                loss_k.backward()
                loss += float(loss_k)
                del loss_k

            optimizer.step()
            train_loss += loss
            print(f"batch {batch_idx} train loss: {loss}")

        # NOTE(review): the divisor omits num_traj, so this value is
        # num_traj times smaller than the true per-step mean; kept as-is
        # for continuity of logged values.
        mean_train_loss = train_loss / (train_size // (config.batch_size))
        print(f"epoch {epoch} mean train loss: {mean_train_loss}")
        imitate_net.eval()
        loss = 0
        metrics = {"cos": 0, "iou": 0, "cd": 0}
        CD = 0
        # Validation: greedy-decode CAD shapes and score by chamfer distance.
        for batch_idx in range(inference_test_size // config.batch_size):
            parser = ParseModelOutput(generator.unique_draw, max_len // 2 + 1,
                                      max_len, config.canvas_shape)
            with torch.no_grad():
                labels = np.zeros((config.batch_size, max_len), dtype=np.int32)
                data_ = next(val_gen)
                one_hot_labels = prepare_input_op(labels,
                                                  len(generator.unique_draw))
                one_hot_labels = torch.from_numpy(one_hot_labels).cuda()
                data = torch.from_numpy(data_).cuda()
                test_outputs = imitate_net.test(
                    [data[-1, :, 0, :, :], one_hot_labels, max_len])
                pred_images, correct_prog, pred_prog = parser.get_final_canvas(
                    test_outputs,
                    if_just_expressions=False,
                    if_pred_images=True)
                target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
                CD += np.sum(chamfer(target_images, pred_images))

        metrics["cd"] = CD / inference_test_size

        # NOTE(review): the test cross-entropy is never accumulated (the
        # forward pass for it was removed), so test_loss is always 0.
        test_losses = loss
        test_loss = test_losses / (inference_test_size // (config.batch_size))

        if metrics["cd"] >= best_test_cd:
            num_worse += 1
        else:
            num_worse = 0
            best_test_cd = metrics["cd"]
            torch.save(imitate_net.state_dict(), f"{path}/best_dict.pth")
        if num_worse >= patience:
            # load the best model and stop training
            imitate_net.load_state_dict(torch.load(f"{path}/best_dict.pth"))
            return epoch + 1

        print(
            f"Epoch {epoch}/100 =>  train_loss: {mean_train_loss}, iou: {0}, cd: {metrics['cd']}, test_mse: {test_loss}, test_acc: {0}"
        )

        end = time.time()
        print(f"Inference train time {end-start}")

        del test_losses, outputs, test_outputs

    return epochs
Example #6
0
import numpy as np
import torch
import torch.optim as optim
import sys
from src.utils import read_config
from tensorboard_logger import configure, log_value
from torch.autograd.variable import Variable
from src.Models.models import ImitateJoint
from src.Models.models import Encoder
from src.utils.generators.shapenet_generater import Generator
from src.utils.learn_utils import LearningRate
from src.utils.reinforce import Reinforce
from src.utils.train_utils import prepare_input_op

# Pick the config file from the command line, defaulting to the CAD config.
if len(sys.argv) > 1:
    config = read_config.Config(sys.argv[1])
else:
    config = read_config.Config("config_cad.yml")

max_len = 15  # maximum program length for decoding
reward = "chamfer"  # reward signal used by the REINFORCE training
power = 20
DATA_PATH = "data/cad/cad.h5"
model_name = config.model_path.format(config.mode)
# Persist the resolved configuration alongside the logs for reproducibility.
config.write_config("log/configs/{}_config.json".format(model_name))
config.train_size = 10000
config.test_size = 3000
print(config.config)

# Setup Tensorboard logger
configure("log/tensorboard/{}".format(model_name), flush_secs=5)
Example #7
0
def infer_programs(imitate_net, self_training=False, ab=None):
    """Beam-search CSG programs for the CAD test set and return the mean
    chamfer distance (best of `beam_width` candidates per shape).

    Also saves a side-by-side visualization (target + all beam predictions)
    for every test shape under "best_lest/".

    NOTE(review): the `self_training` and `ab` parameters are unused in this
    variant of the function.
    """
    config = read_config.Config("config_cad.yml")

    # Load the terminals symbols of the grammar
    with open("terminals.txt", "r") as file:
        unique_draw = file.readlines()
    # Strip one trailing character (the newline) from each terminal.
    # NOTE(review): truncates the final terminal if the file has no
    # trailing newline -- confirm terminals.txt ends with "\n".
    for index, e in enumerate(unique_draw):
        unique_draw[index] = e[0:-1]

    config.train_size = 10000
    config.test_size = 3000
    imitate_net.eval()
    imitate_net.epsilon = 0  # disable exploration: deterministic decoding
    parser = ParseModelOutput(unique_draw, max_len // 2 + 1, max_len,
                              config.canvas_shape)

    generator = Generator()
    test_gen = generator.test_gen(batch_size=config.batch_size,
                                  path="data/cad/cad.h5",
                                  if_augment=False)

    pred_expressions = []
    Rs = 0
    CDs = 0  # running sum of per-batch mean chamfer distances
    Target_images = []
    for batch_idx in range(config.test_size // config.batch_size):
        with torch.no_grad():
            print(f"Inferring test cad batch: {batch_idx}")
            data_ = next(test_gen)
            # Dummy labels: only used to build the initial one-hot input.
            labels = np.zeros((config.batch_size, max_len), dtype=np.int32)
            one_hot_labels = prepare_input_op(labels, len(unique_draw))
            one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
            data = torch.from_numpy(data_).to(device)

            all_beams, next_beams_prob, all_inputs = imitate_net.beam_search(
                [data[-1, :, 0, :, :], one_hot_labels], beam_width, max_len)

            beam_labels = beams_parser(all_beams,
                                       data_.shape[1],
                                       beam_width=beam_width)

            # Flatten beams: row i*beam_width + b holds beam b of sample i.
            beam_labels_numpy = np.zeros(
                (config.batch_size * beam_width, max_len), dtype=np.int32)
            Target_images.append(data_[-1, :, 0, :, :])
            for i in range(data_.shape[1]):
                beam_labels_numpy[i * beam_width:(i + 1) *
                                  beam_width, :] = beam_labels[i]

            # find expression from these predicted beam labels
            expressions = [""] * config.batch_size * beam_width
            for i in range(config.batch_size * beam_width):
                for j in range(max_len):
                    expressions[i] += unique_draw[beam_labels_numpy[i, j]]
            # "$" is the stop symbol: keep only the program before it.
            for index, prog in enumerate(expressions):
                expressions[index] = prog.split("$")[0]

            pred_expressions += expressions
            predicted_images = image_from_expressions(parser, expressions)
            target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
            # Repeat each target so it aligns with its beam_width predictions.
            target_images_new = np.repeat(target_images,
                                          axis=0,
                                          repeats=beam_width)

            beam_CD = chamfer(target_images_new, predicted_images)

            # Per sample, keep the best (minimum) CD across its beams.
            CD = np.zeros((config.batch_size, 1))
            for r in range(config.batch_size):
                CD[r, 0] = min(beam_CD[r * beam_width:(r + 1) * beam_width])

            CDs += np.mean(CD)

            # Save target vs. all beam predictions for visual inspection.
            for j in range(0, config.batch_size):
                f, a = plt.subplots(1, beam_width + 1, figsize=(30, 3))
                a[0].imshow(data_[-1, j, 0, :, :], cmap="Greys_r")
                a[0].axis("off")
                a[0].set_title("target")
                for i in range(1, beam_width + 1):
                    a[i].imshow(predicted_images[j * beam_width + i - 1],
                                cmap="Greys_r")
                    a[i].set_title("{}".format(i))
                    a[i].axis("off")
                plt.savefig("best_lest/" +
                            "{}.png".format(batch_idx * config.batch_size + j),
                            transparent=0)
                plt.close("all")

    # Mean over batches of the per-batch mean best-of-beam chamfer distance.
    return CDs / (config.test_size // config.batch_size)
Example #8
0
import numpy as np
import torch
from torch.autograd.variable import Variable
import sys
from src.utils import read_config
from src.Models.models import ImitateJoint
from src.Models.models import Encoder
from src.utils.generators.shapenet_generater import Generator
from src.utils.reinforce import Reinforce
from src.utils.train_utils import prepare_input_op

max_len = 13  # maximum program length for decoding
power = 20
reward = "chamfer"  # reward signal used by REINFORCE
# Pick the config file from the command line, defaulting to synthetic data.
if len(sys.argv) > 1:
    config = read_config.Config(sys.argv[1])
else:
    config = read_config.Config("config_synthetic.yml")

encoder_net = Encoder()
encoder_net.cuda()

# CNN encoder
# NOTE(review): this second construction discards the Encoder built just
# above -- looks like a copy/paste leftover.
encoder_net = Encoder(config.encoder_drop)
encoder_net.cuda()

# Load the terminals symbols of the grammar
with open("terminals.txt", "r") as file:
    unique_draw = file.readlines()
# Strip the trailing newline from each terminal.
# NOTE(review): e[0:-1] truncates the last terminal when the file lacks a
# trailing newline -- confirm terminals.txt ends with "\n".
for index, e in enumerate(unique_draw):
    unique_draw[index] = e[0:-1]
Example #9
0
def train_inference(inference_net, iter):
    """Train the imitation network on wake-sleep inferred labels for up to 50
    epochs, with early stopping on the test cross-entropy loss.

    Parameters:
        inference_net: tuple (encoder_net, imitate_net); only imitate_net's
            trainable parameters are optimized.
        iter: wake-sleep iteration index used to locate the label files.
            NOTE(review): shadows the builtin `iter`.
    """
    config = read_config.Config("config_synthetic.yml")

    generator = WakeSleepGen(
        f"wake_sleep_data/inference/{iter}/labels/labels.pt",
        f"wake_sleep_data/inference/{iter}/labels/val/labels.pt",
        batch_size=config.batch_size,
        train_size=inference_train_size,
        test_size=inference_test_size,
        canvas_shape=config.canvas_shape,
        max_len=max_len)

    train_gen = generator.get_train_data()
    test_gen = generator.get_test_data()

    encoder_net, imitate_net = inference_net

    # Optimize only parameters that are not frozen.
    optimizer = optim.Adam(
        [para for para in imitate_net.parameters() if para.requires_grad],
        weight_decay=config.weight_decay,
        lr=config.lr)

    reduce_plat = LearningRate(optimizer,
                               init_lr=config.lr,
                               lr_dacay_fact=0.2,
                               patience=config.patience)

    best_test_loss = 1e20
    best_imitate_dict = imitate_net.state_dict()

    prev_test_cd = 1e20
    prev_test_iou = 0

    # Early stopping: stop after `patience` consecutive non-improving epochs.
    patience = 5
    num_worse = 0

    for epoch in range(50):
        train_loss = 0
        Accuracies = []
        imitate_net.train()
        for batch_idx in range(inference_train_size //
                               (config.batch_size * config.num_traj)):
            optimizer.zero_grad()
            loss = Variable(torch.zeros(1)).to(device).data
            # Accumulate gradients over `num_traj` mini-batches per step.
            for _ in range(config.num_traj):
                batch_data, batch_labels = next(train_gen)
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                # Keep only the first channel of the stacked canvases.
                batch_data = batch_data[:, :, 0:1, :, :]
                one_hot_labels = prepare_input_op(batch_labels, vocab_size)
                one_hot_labels = Variable(
                    torch.from_numpy(one_hot_labels)).to(device)
                outputs = imitate_net([batch_data, one_hot_labels, max_len])

                loss_k = (losses_joint(
                    outputs, batch_labels, time_steps=max_len + 1) /
                          (max_len + 1)) / config.num_traj
                loss_k.backward()
                loss += loss_k.data
                del loss_k

            optimizer.step()
            train_loss += loss
            print(f"batch {batch_idx} train loss: {loss.cpu().numpy()}")

        # NOTE(review): divisor omits num_traj, so this is num_traj times
        # smaller than the true per-step mean.
        mean_train_loss = train_loss / (inference_train_size //
                                        (config.batch_size))
        print(
            f"epoch {epoch} mean train loss: {mean_train_loss.cpu().numpy()}")
        imitate_net.eval()
        loss = Variable(torch.zeros(1)).to(device)
        metrics = {"cos": 0, "iou": 0, "cd": 0}
        IOU = 0
        COS = 0
        CD = 0
        # Evaluation: teacher-forced loss plus decoded-image metrics.
        for batch_idx in range(inference_test_size // config.batch_size):
            with torch.no_grad():
                batch_data, batch_labels = next(test_gen)
                batch_data = batch_data.to(device)
                batch_labels = batch_labels.to(device)
                one_hot_labels = prepare_input_op(batch_labels, vocab_size)
                one_hot_labels = Variable(
                    torch.from_numpy(one_hot_labels)).to(device)
                test_outputs = imitate_net(
                    [batch_data, one_hot_labels, max_len])
                loss += (losses_joint(
                    test_outputs, batch_labels, time_steps=max_len + 1) /
                         (max_len + 1))
                # Greedy decode to images for IOU / cosine / chamfer metrics.
                test_output = imitate_net.test(
                    [batch_data, one_hot_labels, max_len])
                pred_images, correct_prog, pred_prog = generator.parser.get_final_canvas(
                    test_output,
                    if_just_expressions=False,
                    if_pred_images=True)
                target_images = batch_data.cpu().numpy()[-1, :,
                                                         0, :, :].astype(
                                                             dtype=bool)
                iou = np.sum(np.logical_and(target_images, pred_images),
                             (1, 2)) / \
                      np.sum(np.logical_or(target_images, pred_images),
                             (1, 2))
                cos = cosine_similarity(target_images, pred_images)
                CD += np.sum(chamfer(target_images, pred_images))
                IOU += np.sum(iou)
                COS += np.sum(cos)

        metrics["iou"] = IOU / inference_test_size
        metrics["cos"] = COS / inference_test_size
        metrics["cd"] = CD / inference_test_size

        test_losses = loss.data
        test_loss = test_losses.cpu().numpy() / (inference_test_size //
                                                 (config.batch_size))

        if test_loss >= best_test_loss:
            num_worse += 1
        else:
            num_worse = 0
            best_test_loss = test_loss
            best_imitate_dict = imitate_net.state_dict()
        if num_worse >= patience:
            # load the best model and stop training
            imitate_net.load_state_dict(best_imitate_dict)
            break

        # Decay the learning rate when chamfer distance plateaus.
        reduce_plat.reduce_on_plateu(metrics["cd"])
        print("Epoch {}/{}=>  train_loss: {}, iou: {}, cd: {}, test_mse: {}".
              format(
                  epoch,
                  config.epochs,
                  mean_train_loss.cpu().numpy(),
                  metrics["iou"],
                  metrics["cd"],
                  test_loss,
              ))

        print(f"CORRECT PROGRAMS: {len(generator.correct_programs)}")

        del test_losses, test_outputs
Example #10
0
def infer_programs(imitate_net, path, self_training=False, ab=None):
    """Beam-search CSG programs for the CAD train and test sets.

    For each training shape the best (lowest chamfer distance) beam labels
    are saved to `{path}/labels/labels.pt` for a later sleep/training phase;
    summary metrics go to `{path}/results/`. When `self_training` is set the
    target images are saved alongside the labels. `ab` keeps the top-`ab`
    beams per shape instead of just the best one (data-ablation mode).

    NOTE(review): the returned value is implicitly None; results are
    communicated via files and stdout.
    """
    save_viz = False

    config = read_config.Config("config_cad.yml")

    # Load the terminals symbols of the grammar
    with open("terminals.txt", "r") as file:
        unique_draw = file.readlines()
    # Strip one trailing character (the newline) from each terminal.
    # NOTE(review): truncates the final terminal if the file lacks a
    # trailing newline -- confirm terminals.txt ends with "\n".
    for index, e in enumerate(unique_draw):
        unique_draw[index] = e[0:-1]

    config.train_size = 10000
    config.test_size = 3000
    imitate_net.eval()
    imitate_net.epsilon = 0  # disable exploration: deterministic decoding
    parser = ParseModelOutput(unique_draw, max_len // 2 + 1, max_len,
                              config.canvas_shape)
    pred_expressions = []
    # One row of chosen labels per training shape (times `ab` in ablation mode).
    if ab is not None:
        pred_labels = np.zeros((config.train_size * ab, max_len))
    else:
        pred_labels = np.zeros((config.train_size, max_len))
    image_path = f"{path}/images/"
    results_path = f"{path}/results/"
    labels_path = f"{path}/labels/"

    os.makedirs(os.path.dirname(image_path), exist_ok=True)
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    os.makedirs(os.path.dirname(labels_path), exist_ok=True)
    os.makedirs(os.path.dirname(labels_path + "val/"), exist_ok=True)

    generator = Generator()

    train_gen = generator.train_gen(batch_size=config.batch_size,
                                    path="data/cad/cad.h5",
                                    if_augment=False)

    # NOTE(review): val_gen is created but never consumed in this function.
    val_gen = generator.val_gen(batch_size=config.batch_size,
                                path="data/cad/cad.h5",
                                if_augment=False)

    Rs = 0
    CDs = 0  # running sum of per-batch mean chamfer distances
    Target_images = []

    start = time.time()
    pred_images = np.zeros((config.train_size, 64, 64))
    # ---- Phase 1: infer programs for the training set. ----
    for batch_idx in range(config.train_size // config.batch_size):
        with torch.no_grad():
            print(f"Inferring cad batch: {batch_idx}")
            data_ = next(train_gen)
            # Dummy labels: only used to build the initial one-hot input.
            labels = np.zeros((config.batch_size, max_len), dtype=np.int32)
            one_hot_labels = prepare_input_op(labels, len(unique_draw))
            one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
            data = torch.from_numpy(data_).to(device)

            all_beams, next_beams_prob, all_inputs = imitate_net.beam_search(
                [data[-1, :, 0, :, :], one_hot_labels], beam_width, max_len)

            beam_labels = beams_parser(all_beams,
                                       data_.shape[1],
                                       beam_width=beam_width)

            # Flatten beams: row i*beam_width + b holds beam b of sample i.
            beam_labels_numpy = np.zeros(
                (config.batch_size * beam_width, max_len), dtype=np.int32)
            Target_images.append(data_[-1, :, 0, :, :])
            for i in range(data_.shape[1]):
                beam_labels_numpy[i * beam_width:(i + 1) *
                                  beam_width, :] = beam_labels[i]

            # find expression from these predicted beam labels
            expressions = [""] * config.batch_size * beam_width
            for i in range(config.batch_size * beam_width):
                for j in range(max_len):
                    expressions[i] += unique_draw[beam_labels_numpy[i, j]]
            # "$" is the stop symbol: keep only the program before it.
            for index, prog in enumerate(expressions):
                expressions[index] = prog.split("$")[0]

            pred_expressions += expressions
            predicted_images = image_from_expressions(parser, expressions)
            target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
            # Repeat each target so it aligns with its beam_width predictions.
            target_images_new = np.repeat(target_images,
                                          axis=0,
                                          repeats=beam_width)

            # beam_R = np.sum(np.logical_and(target_images_new, predicted_images),
            #                 (1, 2)) / np.sum(np.logical_or(target_images_new, predicted_images), (1, 2))
            #
            # R = np.zeros((config.batch_size, 1))
            # for r in range(config.batch_size):
            #     R[r, 0] = max(beam_R[r * beam_width:(r + 1) * beam_width])
            #
            # Rs += np.mean(R)

            beam_CD = chamfer(target_images_new, predicted_images)

            # select best expression by chamfer distance
            if ab is None:
                best_labels = np.zeros((config.batch_size, max_len))
                for r in range(config.batch_size):
                    idx = np.argmin(beam_CD[r * beam_width:(r + 1) *
                                            beam_width])
                    best_labels[r] = beam_labels[r][idx]
                pred_labels[batch_idx *
                            config.batch_size:batch_idx * config.batch_size +
                            config.batch_size] = best_labels
            else:
                # Ablation mode: keep the `ab` best beams per shape.
                best_labels = np.zeros((config.batch_size * ab, max_len))
                for r in range(config.batch_size):
                    sorted_idx = np.argsort(beam_CD[r * beam_width:(r + 1) *
                                                    beam_width])[:ab]
                    best_labels[r * ab:r * ab +
                                ab] = beam_labels[r][sorted_idx]
                pred_labels[batch_idx * config.batch_size *
                            ab:batch_idx * config.batch_size * ab +
                            config.batch_size * ab] = best_labels

            # Record per-sample best CD and the corresponding predicted image.
            CD = np.zeros((config.batch_size, 1))
            for r in range(config.batch_size):
                CD[r, 0] = min(beam_CD[r * beam_width:(r + 1) * beam_width])
                pred_images[batch_idx * config.batch_size +
                            r] = predicted_images[r * beam_width + np.argmin(
                                beam_CD[r * beam_width:(r + 1) * beam_width])]

            CDs += np.mean(CD)

            if save_viz:
                for j in range(0, config.batch_size):
                    f, a = plt.subplots(1, beam_width + 1, figsize=(30, 3))
                    a[0].imshow(data_[-1, j, 0, :, :], cmap="Greys_r")
                    a[0].axis("off")
                    a[0].set_title("target")
                    for i in range(1, beam_width + 1):
                        a[i].imshow(predicted_images[j * beam_width + i - 1],
                                    cmap="Greys_r")
                        a[i].set_title("{}".format(i))
                        a[i].axis("off")
                    plt.savefig(
                        image_path +
                        "{}.png".format(batch_idx * config.batch_size + j),
                        transparent=0)
                    plt.close("all")

                    # Only visualize the first batch.
                    save_viz = False

    print("Inferring cad average chamfer distance: {}".format(
        CDs / (config.train_size // config.batch_size)),
          flush=True)

    # NOTE(review): Rs is never accumulated (the IOU block above is
    # commented out), so the reported "iou" is always 0.
    Rs = Rs / (config.train_size // config.batch_size)
    CDs = CDs / (config.train_size // config.batch_size)
    print(Rs, CDs)
    results = {"iou": Rs, "chamferdistance": CDs}

    with open(results_path + "results_beam_width_{}.org".format(beam_width),
              'w') as outfile:
        json.dump(results, outfile)

    torch.save(pred_labels, labels_path + "labels.pt")
    # torch.save(pred_images, labels_path + "images.pt")
    if self_training:
        # Save the target images so the sleep phase can train on them.
        if ab is None:
            torch.save(np.concatenate(Target_images, axis=0),
                       labels_path + "images.pt")
        else:
            torch.save(
                np.repeat(np.concatenate(Target_images, axis=0), ab, axis=0),
                labels_path + "images.pt")

    # ---- Phase 2: evaluate on the held-out test set. ----
    test_gen = generator.test_gen(batch_size=config.batch_size,
                                  path="data/cad/cad.h5",
                                  if_augment=False)

    pred_expressions = []
    Rs = 0
    CDs = 0
    Target_images = []
    for batch_idx in range(config.test_size // config.batch_size):
        with torch.no_grad():
            print(f"Inferring test cad batch: {batch_idx}")
            data_ = next(test_gen)
            labels = np.zeros((config.batch_size, max_len), dtype=np.int32)
            one_hot_labels = prepare_input_op(labels, len(unique_draw))
            one_hot_labels = torch.from_numpy(one_hot_labels).to(device)
            data = torch.from_numpy(data_).to(device)

            all_beams, next_beams_prob, all_inputs = imitate_net.beam_search(
                [data[-1, :, 0, :, :], one_hot_labels], beam_width, max_len)

            beam_labels = beams_parser(all_beams,
                                       data_.shape[1],
                                       beam_width=beam_width)

            beam_labels_numpy = np.zeros(
                (config.batch_size * beam_width, max_len), dtype=np.int32)
            Target_images.append(data_[-1, :, 0, :, :])
            for i in range(data_.shape[1]):
                beam_labels_numpy[i * beam_width:(i + 1) *
                                  beam_width, :] = beam_labels[i]

            # find expression from these predicted beam labels
            expressions = [""] * config.batch_size * beam_width
            for i in range(config.batch_size * beam_width):
                for j in range(max_len):
                    expressions[i] += unique_draw[beam_labels_numpy[i, j]]
            for index, prog in enumerate(expressions):
                expressions[index] = prog.split("$")[0]

            pred_expressions += expressions
            predicted_images = image_from_expressions(parser, expressions)
            target_images = data_[-1, :, 0, :, :].astype(dtype=bool)
            target_images_new = np.repeat(target_images,
                                          axis=0,
                                          repeats=beam_width)

            beam_CD = chamfer(target_images_new, predicted_images)

            # Per sample, keep the best (minimum) CD across its beams.
            CD = np.zeros((config.batch_size, 1))
            for r in range(config.batch_size):
                CD[r, 0] = min(beam_CD[r * beam_width:(r + 1) * beam_width])

            CDs += np.mean(CD)

    print(f"TEST CD: {CDs / (config.test_size // config.batch_size)}")

    end = time.time()
    print(f"Inference time: {end-start}")