Exemple #1
0
def main(cf):
    """Train a ``PredictiveCodingNetwork`` on MNIST and plot generated images.

    Loads the MNIST train/test sets through ``mnist_utils``, optionally
    truncates them to ``cf.data_size`` training samples (and
    ``cf.data_size // 5`` test samples), applies the optional scaling and
    inverse-activation preprocessing, then runs ``cf.n_epochs`` training
    epochs. After every epoch the model generates images from the first
    batch of test labels and the plot is written to
    ``cf.img_path.format(epoch)``.

    Args:
        cf: Configuration object. Attributes read here: ``device``,
            ``data_size``, ``apply_scaling``, ``img_scale``, ``label_scale``,
            ``apply_inv``, ``act_fn``, ``n_epochs``, ``batch_size`` and
            ``img_path``. It is also passed to the model constructor.
    """
    print(f"device [{cf.device}]")
    print("loading MNIST data...")
    train_data = mnist_utils.get_mnist_train_set()
    test_data = mnist_utils.get_mnist_test_set()

    train_imgs = mnist_utils.get_imgs(train_data)
    test_imgs = mnist_utils.get_imgs(test_data)
    train_labels = mnist_utils.get_labels(train_data)
    test_labels = mnist_utils.get_labels(test_data)

    # Optional truncation along axis 1 (presumably the sample axis -- the
    # same axis is shuffled at the end of each epoch below).
    n_train = cf.data_size
    if n_train is not None:
        n_test = n_train // 5
        train_imgs = train_imgs[:, :n_train]
        train_labels = train_labels[:, :n_train]
        test_imgs = test_imgs[:, :n_test]
        test_labels = test_labels[:, :n_test]

    print("img_train {} img_test {} label_train {} label_test {}".format(
        train_imgs.shape, test_imgs.shape, train_labels.shape,
        test_labels.shape))

    print("performing preprocessing...")
    if cf.apply_scaling:
        train_imgs = mnist_utils.scale_imgs(train_imgs, cf.img_scale)
        test_imgs = mnist_utils.scale_imgs(test_imgs, cf.img_scale)
        train_labels = mnist_utils.scale_labels(train_labels, cf.label_scale)
        test_labels = mnist_utils.scale_labels(test_labels, cf.label_scale)

    if cf.apply_inv:
        train_imgs = F.f_inv(train_imgs, cf.act_fn)
        test_imgs = F.f_inv(test_imgs, cf.act_fn)

    net = PredictiveCodingNetwork(cf)

    # The whole training loop runs with autograd disabled.
    with torch.no_grad():
        for epoch in range(cf.n_epochs):
            print(f"\nepoch {epoch}")

            img_batches, label_batches = mnist_utils.get_batches(
                train_imgs, train_labels, cf.batch_size)
            print(
                f"training on {len(img_batches)} batches of size {cf.batch_size}"
            )
            # Labels are passed first and images second to train_epoch.
            net.train_epoch(label_batches, img_batches, epoch_num=epoch)

            # Only the first batch of test labels is used, as the input
            # from which images are generated.
            _, test_label_batches = mnist_utils.get_batches(
                test_imgs, test_labels, cf.batch_size)
            print("generating images...")
            generated = net.generate_data(test_label_batches[0])
            mnist_utils.plot_imgs(generated, cf.img_path.format(epoch))

            # Reshuffle the training samples (axis 1) before the next epoch,
            # keeping images and labels aligned.
            order = np.random.permutation(train_imgs.shape[1])
            train_imgs, train_labels = train_imgs[:, order], train_labels[:, order]
Exemple #2
0
def main(cf):
    """Train a ``QCodingNetwork`` on MNIST and periodically evaluate it.

    Loads the MNIST train/test sets through ``mnist_utils``, optionally
    truncates them to ``cf.data_size`` training samples (and
    ``cf.data_size // 5`` test samples), applies the optional scaling and
    inverse-activation preprocessing, then runs ``cf.n_epochs`` training
    epochs.

    Every ``cf.test_every`` epochs the model:
      * generates images from the first batch of test labels and plots them
        to ``cf.img_path.format(epoch)``;
      * if ``cf.amortised`` is set, measures amortised accuracy;
      * measures hybrid accuracy (``model.test_epoch``) and PC accuracy
        (``model.test_pc_epoch``), both with ``itr_max=cf.test_itr_max``;
      * saves the accumulated accuracy histories with ``np.save``.

    The training samples are reshuffled after every epoch.

    Args:
        cf: Configuration object. Attributes read here: ``device``,
            ``data_size``, ``apply_scaling``, ``img_scale``, ``label_scale``,
            ``apply_inv``, ``act_fn``, ``n_epochs``, ``batch_size``,
            ``test_every``, ``img_path``, ``amortised``, ``test_itr_max``,
            ``hybird_path``, ``amortised_path`` and ``pc_path``. It is also
            passed to the model constructor.
    """
    print(f"device [{cf.device}]")
    print("loading MNIST data...")
    train_set = mnist_utils.get_mnist_train_set()
    test_set = mnist_utils.get_mnist_test_set()

    img_train = mnist_utils.get_imgs(train_set)
    img_test = mnist_utils.get_imgs(test_set)
    label_train = mnist_utils.get_labels(train_set)
    label_test = mnist_utils.get_labels(test_set)

    # Optional truncation along axis 1 (presumably the sample axis -- the
    # same axis is shuffled at the end of each epoch below).
    if cf.data_size is not None:
        test_size = cf.data_size // 5
        img_train = img_train[:, 0:cf.data_size]
        label_train = label_train[:, 0:cf.data_size]
        img_test = img_test[:, 0:test_size]
        label_test = label_test[:, 0:test_size]

    msg = "img_train {} img_test {} label_train {} label_test {}"
    print(
        msg.format(img_train.shape, img_test.shape, label_train.shape,
                   label_test.shape))

    print("performing preprocessing...")
    if cf.apply_scaling:
        img_train = mnist_utils.scale_imgs(img_train, cf.img_scale)
        img_test = mnist_utils.scale_imgs(img_test, cf.img_scale)
        label_train = mnist_utils.scale_labels(label_train, cf.label_scale)
        label_test = mnist_utils.scale_labels(label_test, cf.label_scale)

    if cf.apply_inv:
        img_train = F.f_inv(img_train, cf.act_fn)
        img_test = F.f_inv(img_test, cf.act_fn)

    model = QCodingNetwork(cf)

    # Accuracy histories, one entry per evaluation round.
    q_accs = []  # amortised accuracy (only filled when cf.amortised is set)
    h_accs = []  # hybrid accuracy
    p_accs = []  # PC accuracy

    # The whole training/evaluation loop runs with autograd disabled.
    with torch.no_grad():
        for epoch in range(cf.n_epochs):
            print(f"\nepoch {epoch}")

            img_batches, label_batches = mnist_utils.get_batches(
                img_train, label_train, cf.batch_size)
            print(
                f"> training on {len(img_batches)} batches of size {cf.batch_size}"
            )
            end_err, init_err, its = model.train_epoch(img_batches,
                                                       label_batches,
                                                       epoch_num=epoch)
            print("end_err {} / init_err {} / its {}".format(
                end_err, init_err, its))

            if epoch % cf.test_every == 0:
                # Test batches are rebuilt before each evaluation stage, as
                # in the original code; get_batches is opaque here, so the
                # calls are not merged.
                img_batches, label_batches = mnist_utils.get_batches(
                    img_test, label_test, cf.batch_size)
                print("> generating images...")
                pred_imgs = model.generate_data(label_batches[0])
                mnist_utils.plot_imgs(pred_imgs, cf.img_path.format(epoch))

                if cf.amortised:
                    img_batches, label_batches = mnist_utils.get_batches(
                        img_test, label_test, cf.batch_size)
                    print(
                        f"> testing amortised acc {len(img_batches)} batches of size {cf.batch_size}"
                    )
                    accs = model.test_amortised_epoch(img_batches,
                                                      label_batches)
                    mean_q_acc = np.mean(np.array(accs))
                    q_accs.append(mean_q_acc)
                    print(f"average amortised accuracy {mean_q_acc}")

                img_batches, label_batches = mnist_utils.get_batches(
                    img_test, label_test, cf.batch_size)
                print(
                    f"> testing hybrid acc on {len(img_batches)} batches of size {cf.batch_size}"
                )
                accs, its = model.test_epoch(img_batches,
                                             label_batches,
                                             itr_max=cf.test_itr_max)
                mean_h_acc = np.mean(np.array(accs))
                h_accs.append(mean_h_acc)
                print(f"average hybrid accuracy {mean_h_acc} / its {its}")

                img_batches, label_batches = mnist_utils.get_batches(
                    img_test, label_test, cf.batch_size)
                print(
                    f"> testing PC acc on {len(img_batches)} batches of size {cf.batch_size}"
                )
                accs, its = model.test_pc_epoch(img_batches,
                                                label_batches,
                                                itr_max=cf.test_itr_max)
                mean_p_acc = np.mean(np.array(accs))
                p_accs.append(mean_p_acc)
                print(f"average PC accuracy {mean_p_acc} / its {its}")

                # Persist the full histories after every evaluation round.
                # NOTE(review): "hybird_path" is the attribute name read from
                # the config; the spelling presumably matches the config
                # definition -- confirm before renaming.
                np.save(cf.hybird_path, h_accs)
                np.save(cf.amortised_path, q_accs)
                np.save(cf.pc_path, p_accs)

            # Reshuffle the training samples (axis 1) after EVERY epoch,
            # keeping images and labels aligned. This used to sit inside the
            # evaluation branch above, so with cf.test_every > 1 the epochs
            # between evaluations saw the same sample order.
            perm = np.random.permutation(img_train.shape[1])
            img_train = img_train[:, perm]
            label_train = label_train[:, perm]