def test(args):
    """Evaluate a trained model-aware autoencoder over a sweep of SNR levels.

    Loads the model/data configuration and the weights saved in
    ``args.save_dir``. For each SNR in ``SNRs`` it rebuilds the channel and
    the autoencoder, reloads the weights and runs
    ``data_args.batches_per_epoch`` test batches. It then prints the mean loss
    and the bit error rate (BER = 1 - mean accuracy). The BER curve is plotted
    on a log scale, saved to ``figures/AWGN_modelaware.png`` and shown.

    NOTE(review): another ``def test(args)`` appears later in this file and
    rebinds the name ``test``; at module level this definition is shadowed.

    Args:
        args: parsed command-line namespace; only ``args.save_dir`` is read.

    Raises:
        AssertionError: if the saved model is model-free.
    """
    print(args.save_dir)
    model_args, data_args = load_args(Path(args.save_dir))
    assert not model_args.modelfree, "Code only evaluates on model based models"
    saver = ModelSaver(Path(args.save_dir), None)
    power_constraint = PowerConstraint()
    possible_inputs = get_md_set(model_args.md_len)

    # Evaluation size overrides the saved training configuration:
    # 5000 messages per batch, 1000 batches per SNR level.
    data_args.batch_size = 5000
    data_args.batches_per_epoch = 1000
    dataset_size = data_args.batch_size * data_args.batches_per_epoch

    SNRs = [.5, 1, 2, 3, 4]
    BER = []
    loss = []
    print(f"Channel: {data_args.channel}")
    for SNR in SNRs:
        print(f"Testing {SNR} SNR level")
        data_args.SNR = SNR
        # A fresh generator per SNR level. Each level consumes
        # batches_per_epoch batches, i.e. dataset_size examples. A single
        # generator shared across levels could therefore run out
        # (assumes example_generator yields ~dataset_size examples — confirm).
        loader = InputDataloader(
            data_args.batch_size, data_args.block_length, dataset_size
        ).example_generator()
        # get_channel reads data_args, which now carries the new SNR, so the
        # channel and model are rebuilt and the saved weights reloaded.
        channel = get_channel(data_args.channel, model_args.modelfree, data_args)
        model = AutoEncoder(model_args, data_args, power_constraint, channel, possible_inputs)
        saver.load(model)

        losses = []
        accuracy = []
        for _ in tqdm(range(data_args.batches_per_epoch)):
            msg = next(loader)
            # Autoencoder: the input message is also the target.
            metrics = model.trainable_encoder.test_on_batch(msg, msg)
            losses.append(metrics[0])
            accuracy.append(metrics[1])
        mean_loss = sum(losses) / len(losses)
        mean_BER = 1 - sum(accuracy) / len(accuracy)
        loss.append(mean_loss)
        BER.append(mean_BER)
        print(f"mean BER: {mean_BER}")
        print(f"mean loss: {mean_loss}")

    # create plots for results
    Path("figures").mkdir(parents=True, exist_ok=True)  # savefig does not create dirs
    plt.figure()  # start from a clean figure rather than drawing over the current one
    plt.plot(SNRs, BER)
    plt.ylabel("BER")
    plt.xlabel("SNR")
    plt.yscale('log')
    plt.savefig('figures/AWGN_modelaware.png')
    plt.show()
def test(args):
    """Evaluate a set of model-free autoencoders, one per noise level.

    ``args.model_dict_path`` points to a mapping, loaded with
    ``load_model_dict``, from a noise level to the save directory of a model.
    For each entry the function:

    - loads the saved configuration and weights;
    - checks that the saved SNR (AWGN) or epsilon (other channels) matches
      the key;
    - runs ``data_args.batches_per_epoch`` test batches;
    - prints the mean loss and BER (1 - mean accuracy).

    BER vs. noise is plotted on a log scale in increasing noise order and
    saved to ``figures/figure.png``.

    Args:
        args: parsed command-line namespace; only ``args.model_dict_path``
            is read.

    Raises:
        AssertionError: if a saved model is not model-free, or its saved
            SNR/epsilon does not equal the corresponding dict key.
    """
    model_dict = load_model_dict(Path(args.model_dict_path))
    BER = []
    loss = []
    noises = []
    for noise, save_dir in model_dict.items():
        model_args, data_args = load_args(Path(save_dir))
        assert model_args.modelfree, "Code only evaluates on model free"
        saver = ModelSaver(Path(save_dir), None)
        power_constraint = PowerConstraint()
        possible_inputs = get_md_set(model_args.md_len)
        # TODO: increase batch_size and batches_per_epoch to 1000 for final runs.
        data_args.batch_size = 100
        data_args.batches_per_epoch = 100
        dataset_size = data_args.batch_size * data_args.batches_per_epoch
        loader = InputDataloader(
            data_args.batch_size, data_args.block_length, dataset_size
        ).example_generator()

        # The key may be a string (e.g. if the dict came from JSON), hence
        # float(). The numeric value is also what gets plotted, so the x-axis
        # is numeric rather than categorical.
        noise_level = float(noise)
        if data_args.channel == "AWGN":
            assert noise_level == data_args.SNR
        else:
            assert noise_level == data_args.epsilon
        print(f"Testing {noise} noise level")

        # Build the channel and model once (previously done twice per entry).
        channel = get_channel(data_args.channel, model_args.modelfree, data_args)
        model = AutoEncoder(model_args, data_args, power_constraint, channel, possible_inputs)
        saver.load(model)

        accuracy = []
        losses = []
        for _ in tqdm(range(data_args.batches_per_epoch)):
            msg = next(loader)
            # Autoencoder: the input message is also the target.
            metrics = model.trainable_encoder.test_on_batch(msg, msg)
            losses.append(metrics[0])
            accuracy.append(metrics[1])
        mean_loss = sum(losses) / len(losses)
        mean_BER = 1 - sum(accuracy) / len(accuracy)
        loss.append(mean_loss)
        BER.append(mean_BER)
        noises.append(noise_level)
        print(f"mean BER: {mean_BER}")
        print(f"mean loss: {mean_loss}")

    # create plots for results, in increasing noise order so the dashed line
    # connects neighbouring levels regardless of the dict's iteration order
    points = sorted(zip(noises, BER))
    xs = [n for n, _ in points]
    ys = [b for _, b in points]
    Path("figures").mkdir(parents=True, exist_ok=True)  # savefig does not create dirs
    plt.figure()  # start from a clean figure rather than drawing over the current one
    plt.plot(xs, ys, 'b--')
    plt.plot(xs, ys, 'bx')
    plt.ylabel("BER")
    plt.xlabel("noise")
    plt.yscale('log')
    plt.ylim([1e-6, 1.0])
    plt.savefig("figures/figure.png")