Exemplo n.º 1
0
def test(args):
    """Evaluate a saved model-based autoencoder over a sweep of SNR levels.

    Loads the model/data arguments stored in ``args.save_dir`` and, for each
    SNR in a fixed sweep, sets ``data_args.SNR``, rebuilds the channel and
    autoencoder, restores the saved weights, and measures mean loss and bit
    error rate (BER, computed as ``1 - mean accuracy``) over
    ``data_args.batches_per_epoch`` batches. Finally plots BER against SNR
    on a log scale, saves it to ``figures/AWGN_modelaware.png`` (creating
    the directory if needed) and shows it.

    Args:
        args: Parsed command-line namespace; only ``args.save_dir`` is read.

    Raises:
        AssertionError: If the saved run is model-free
            (``model_args.modelfree`` is truthy).
    """
    print(args.save_dir)

    save_dir = Path(args.save_dir)
    model_args, data_args = load_args(save_dir)
    assert not model_args.modelfree, "Code only evaluates on model based models"

    saver = ModelSaver(save_dir, None)

    power_constraint = PowerConstraint()
    possible_inputs = get_md_set(model_args.md_len)

    # TODO: change batch size and batches per epoch to 1000
    data_args.batch_size = 5000
    data_args.batches_per_epoch = 1000
    dataset_size = data_args.batch_size * data_args.batches_per_epoch

    SNRs = [.5, 1, 2, 3, 4]
    BER = []
    loss = []

    for SNR in SNRs:
        print(f"Testing {SNR} SNR level")
        data_args.SNR = SNR
        accuracy = []
        losses = []
        print(data_args.channel)
        print(model_args.modelfree)

        # Fresh data generator per SNR level: each level consumes
        # batches_per_epoch batches, i.e. exactly dataset_size examples.
        # Sharing one generator across the whole sweep would request
        # len(SNRs) times that. NOTE(review): if example_generator is
        # finite, a shared generator raises StopIteration on the 2nd level.
        loader = InputDataloader(data_args.batch_size,
                                 data_args.block_length,
                                 dataset_size).example_generator()

        # get_channel receives data_args (whose SNR was just updated), so
        # the channel and model are rebuilt for every level before the
        # saved weights are restored.
        channel = get_channel(data_args.channel, model_args.modelfree,
                              data_args)
        model = AutoEncoder(model_args, data_args, power_constraint, channel,
                            possible_inputs)
        saver.load(model)
        for _ in tqdm(range(data_args.batches_per_epoch)):
            msg = next(loader)
            # The message is both the input and the reconstruction target.
            metrics = model.trainable_encoder.test_on_batch(msg, msg)
            losses.append(metrics[0])
            accuracy.append(metrics[1])
        mean_loss = sum(losses) / len(losses)
        mean_BER = 1 - sum(accuracy) / len(accuracy)
        loss.append(mean_loss)
        BER.append(mean_BER)
        print(f"mean BER: {mean_BER}")
        print(f"mean loss: {mean_loss}")

    # create plots for results
    plt.plot(SNRs, BER)
    plt.ylabel("BER")
    plt.xlabel("SNR")
    plt.yscale('log')
    # savefig does not create missing directories.
    Path('figures').mkdir(parents=True, exist_ok=True)
    plt.savefig('figures/AWGN_modelaware.png')
    plt.show()
Exemplo n.º 2
0
def test(args):
    """Evaluate a family of saved model-free autoencoders, one per noise level.

    ``args.model_dict_path`` points to a mapping (read by
    ``load_model_dict``) from a noise level to the save directory of a run.
    For each entry this:

    - loads the stored model/data arguments;
    - checks the key against the run's ``SNR`` (AWGN channel) or
      ``epsilon`` (any other channel);
    - builds the channel and autoencoder and restores the saved weights;
    - measures mean loss and BER (``1 - mean accuracy``) over
      ``data_args.batches_per_epoch`` batches.

    BER is then plotted against the numeric noise level on a log scale and
    saved to ``figures/figure.png``, creating the directory if needed.

    Args:
        args: Parsed command-line namespace; only ``args.model_dict_path``
            is read.

    Raises:
        AssertionError: If a run is not model-free, or its stored noise
            parameter does not match its key in the model dict.
    """
    model_dict = load_model_dict(Path(args.model_dict_path))

    BER = []
    loss = []
    noises = []

    for noise, save_dir in model_dict.items():
        model_args, data_args = load_args(Path(save_dir))
        assert model_args.modelfree, "Code only evaluates on model free"

        saver = ModelSaver(Path(save_dir), None)
        power_constraint = PowerConstraint()
        possible_inputs = get_md_set(model_args.md_len)

        # TODO: change batch size and batches per epoch to 1000
        data_args.batch_size = 100
        data_args.batches_per_epoch = 100
        dataset_size = data_args.batch_size * data_args.batches_per_epoch
        loader = InputDataloader(data_args.batch_size, data_args.block_length,
                                 dataset_size)
        loader = loader.example_generator()

        # The key is compared via float(); presumably keys may be strings
        # (e.g. JSON object keys) -- verify against load_model_dict.
        if data_args.channel == "AWGN":
            assert float(noise) == data_args.SNR
        else:
            assert float(noise) == data_args.epsilon

        print(f"Testing {noise} noise level")

        accuracy = []
        losses = []

        # Build the channel and model once, then restore the saved weights.
        channel = get_channel(data_args.channel, model_args.modelfree,
                              data_args)
        model = AutoEncoder(model_args, data_args, power_constraint, channel,
                            possible_inputs)
        saver.load(model)
        for _ in tqdm(range(data_args.batches_per_epoch)):
            msg = next(loader)
            # The message is both the input and the reconstruction target.
            metrics = model.trainable_encoder.test_on_batch(msg, msg)
            losses.append(metrics[0])
            accuracy.append(metrics[1])
        mean_loss = sum(losses) / len(losses)
        mean_BER = 1 - sum(accuracy) / len(accuracy)
        loss.append(mean_loss)
        BER.append(mean_BER)
        # Store numerically so the plot gets a real numeric x axis.
        noises.append(float(noise))
        print(f"mean BER: {mean_BER}")
        print(f"mean loss: {mean_loss}")

    # create plots for results; sort by noise so the dashed line runs
    # left to right regardless of the model dict's iteration order.
    points = sorted(zip(noises, BER))
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    plt.plot(xs, ys, 'bx--')
    plt.ylabel("BER")
    plt.xlabel("noise")
    plt.yscale('log')
    plt.ylim([1e-6, 1.0])
    # savefig does not create missing directories.
    Path("figures").mkdir(parents=True, exist_ok=True)
    plt.savefig("figures/figure.png")