def __init__(self, *,
             model,
             bits_per_symbol: Union[float, int],
             optimizer: Optional[str] = 'adam',
             stepsize_mu: float = 0.0,
             stepsize_sigma: float = 0.0,
             initial_std: float = 0.1,
             min_std: float = 1e-5,
             max_std: float = 1e2,
             lambda_baseline: float = 0.0,  # used in update method
             lambda_center: float = 0.0,  # used in update method
             lambda_l1: float = 0.0,  # used in update method
             lambda_l2: float = 0.0,  # used in update method
             **kwargs):
    """Wrap a modulator model and (optionally) build its optimizer.

    Args:
        model: class (not instance) of the underlying modulator; it is
            instantiated here with ``bits_per_symbol`` and ``**kwargs``.
        bits_per_symbol: bits carried by one constellation symbol.
        optimizer: 'adam', 'sgd', or None/'' to skip optimizer creation.
        stepsize_mu: learning rate for the model's mean parameters.
        stepsize_sigma: learning rate for the exploration log-std.
        initial_std / min_std / max_std: initial value and clamp bounds for
            the exploration standard deviation (all stored in log space).
        lambda_baseline / lambda_center / lambda_l1 / lambda_l2:
            regularization weights consumed by the update method.
        **kwargs: forwarded to the model constructor; may contain 'verbose'.
    """
    self.model = model(bits_per_symbol=bits_per_symbol, **kwargs)
    self.name = self.model.name
    self.bits_per_symbol = bits_per_symbol
    # Clamp bounds for the learned log-std (two entries: I and Q).
    self.log_std_min = np.log(min_std) * torch.ones(2)
    self.log_std_max = np.log(max_std) * torch.ones(2)
    self.lambda_l1 = torch.tensor(lambda_l1).float()
    self.lambda_l2 = torch.tensor(lambda_l2).float()
    self.lambda_center = torch.tensor(lambda_center).float()
    self.lambda_baseline = torch.tensor(lambda_baseline).float()
    self.all_symbols = integers_to_symbols(
        np.arange(0, 2**bits_per_symbol), bits_per_symbol)
    # Exploration std is a trainable parameter kept in log space, so the
    # optimizer can never drive the std itself negative.
    self.log_std = nn.Parameter(np.log(initial_std) * torch.ones(2))
    # self.std.register_hook(lambda grad: print(grad))
    optimizers = {
        'adam': torch.optim.Adam,
        'sgd': torch.optim.SGD,
    }
    # FIX: original indexed kwargs['verbose'] directly, raising KeyError
    # whenever the caller omitted verbose; default to False instead.
    verbose = kwargs.get('verbose', False)
    # Build the parameter groups once (the original duplicated this list
    # in both branches below).
    if hasattr(self.model, "mu_parameters"):
        self.param_dicts = [
            {'params': self.model.mu_parameters(), 'lr': stepsize_mu},
            {'params': self.log_std, 'lr': stepsize_sigma},
        ]
    else:
        self.param_dicts = []
    # Models that define their own `update` manage training themselves and
    # therefore get no optimizer here.
    if optimizer and hasattr(self.model, "mu_parameters") and not hasattr(
            self.model, "update"):
        assert optimizer.lower() in optimizers.keys(), \
            "modulator optimizer=%s not supported" % optimizer
        optimizer = optimizers[optimizer.lower()]
        if verbose:
            print("Modulator %s initialized with %s optimizer." %
                  (self.model.name, optimizer.__name__))
        self.optimizer = optimizer(self.param_dicts)
    else:
        if verbose:
            print("Modulator %s initialized WITHOUT an optimizer" %
                  (self.model.name))
        self.optimizer = None
def trainer(*, agents, bits_per_symbol: int, batch_size: int,
            train_SNR_db: float, signal_power: float = 1.0,
            backwards_only: bool = False, **kwargs):
    """Run one round of echo training between two agents over an AWGN channel.

    A (agents[0]) modulates a fresh random preamble forward and, when it has
    a pending echo (``A.to_echo``), also modulates that echo backward.
    B (agents[1]) demodulates both: the echoed signal closes B's previous
    round trip (letting B update its modulator), and the forward preamble
    becomes B's new ``to_echo`` guess for the next round.

    Returns (A, B, batches_sent).  NOTE(review): callers presumably swap the
    roles of A and B between calls — confirm against the call site.
    """
    integers_to_symbols_map = integers_to_symbols(
        np.arange(0, 2**bits_per_symbol), bits_per_symbol)
    A = agents[0]
    B = agents[1]
    batches_sent = 0
    # Draw a fresh random preamble shared by this round.
    integers = np.random.randint(low=0, high=2**bits_per_symbol,
                                 size=[batch_size])
    preamble = integers_to_symbols_map[integers]  # new shared preamble
    # A: echo back what A last received (if anything) and send the new preamble.
    if A.to_echo is not None:
        c_signal_backward = A.mod.modulate(A.to_echo, mode='explore',
                                           dtype='cartesian')
    c_signal_forward = A.mod.modulate(preamble, mode='explore',
                                      dtype='cartesian')
    # Remember what A sent so A's modulator can be updated when the echo
    # of this round eventually comes back.
    A.preamble = preamble
    A.actions = c_signal_forward
    # Channel: AWGN applied independently to both directions.
    if A.to_echo is not None:
        c_signal_backward_noisy = add_awgn(c_signal_backward,
                                           SNR_db=train_SNR_db,
                                           signal_power=signal_power)
    c_signal_forward_noisy = add_awgn(c_signal_forward,
                                      SNR_db=train_SNR_db,
                                      signal_power=signal_power)
    # B
    if A.to_echo is not None:
        # The echoed signal completes B's previous round trip.
        preamble_roundtrip = B.demod.demodulate(c_signal_backward_noisy)
        # Update mod after a roundtrip pass
        B.mod.update(B.preamble, B.actions, preamble_roundtrip)
        batches_sent += 1
    # guess of new preamble; B will echo this back next round
    B.to_echo = B.demod.demodulate(c_signal_forward_noisy)
    # Update demod after a oneway pass (skipped when training backwards only)
    if not backwards_only:
        B.demod.update(c_signal_forward_noisy, preamble)
        batches_sent += 1
    return A, B, batches_sent
def __init__(self, *,
             model,
             bits_per_symbol,
             optimizer: Optional[str] = 'adam',
             stepsize_cross_entropy: float = 1e-3,
             cross_entropy_weight: float = 1.0,
             epochs: int = 5,
             lambda_l1: float = 0.0,
             lambda_l2: float = 0.0,
             **kwargs):
    """Wrap a demodulator model and (optionally) build its optimizer.

    FIX: ``cross_entropy_weight`` and ``epochs`` had been commented out of
    the signature while the body still referenced them, so construction
    raised NameError.  They are restored with the defaults shown in the
    commented-out code (1.0 and 5), which keeps callers backward-compatible.

    Args:
        model: class (not instance) of the underlying demodulator; it is
            instantiated here with ``bits_per_symbol`` and ``**kwargs``.
        bits_per_symbol: bits carried by one constellation symbol.
        optimizer: 'adam', 'sgd', or None/'' to skip optimizer creation.
        stepsize_cross_entropy: learning rate for the model parameters.
        cross_entropy_weight: scale for the cross-entropy loss (stored only
            when an optimizer is built, matching the original flow).
        epochs: stored on self; presumably the passes per update — confirm
            against the update method.
        lambda_l1 / lambda_l2: regularization weights used by update.
        **kwargs: forwarded to the model constructor; may contain 'verbose'.
    """
    self.epochs = epochs
    self.bits_per_symbol = bits_per_symbol
    self.model = model(bits_per_symbol=bits_per_symbol, **kwargs)
    self.name = self.model.name
    self.lambda_l1 = torch.tensor(lambda_l1).float()
    self.lambda_l2 = torch.tensor(lambda_l2).float()
    self.integers_to_symbols_map = integers_to_symbols(
        np.arange(0, 2**bits_per_symbol), bits_per_symbol)
    optimizers = {
        'adam': torch.optim.Adam,
        'sgd': torch.optim.SGD,
    }
    # FIX: kwargs['verbose'] raised KeyError when verbose was not passed.
    verbose = kwargs.get('verbose', False)
    # Models that define their own `update` manage training themselves and
    # therefore get no optimizer here.
    if optimizer and hasattr(self.model, 'parameters') and not hasattr(
            self.model, "update"):
        assert optimizer.lower() in optimizers.keys(), \
            "demodulator optimizer=%s not supported" % optimizer
        optimizer = optimizers[optimizer.lower()]
        if verbose:
            print("Demodulator %s initialized with %s optimizer." %
                  (self.model.name, optimizer.__name__))
        self.cross_entropy_weight = torch.tensor(
            cross_entropy_weight).float()
        self.param_dicts = [
            {'params': self.model.parameters(),
             'lr': stepsize_cross_entropy},
        ]
        self.optimizer = optimizer(self.param_dicts,
                                   lr=stepsize_cross_entropy)
    else:
        if verbose:
            print("Demodulator %s initialized WITHOUT an optimizer" %
                  (self.model.name))
        self.optimizer = None
        if hasattr(self.model, 'parameters'):
            self.param_dicts = [
                {'params': self.model.parameters(),
                 'lr': stepsize_cross_entropy},
            ]
        else:
            self.param_dicts = []
def test(*, agent1, agent2, bits_per_symbol, test_SNR_db,
         signal_power=1.0, test_batch_size=100000):
    """Measure the mean round-trip BER between two agents at one SNR.

    Both directions are tested: A modulates -> B demodulates -> B echoes
    back -> A demodulates (and symmetrically starting from B).  The two
    round-trip BERs are averaged.

    Returns:
        float: mean of the two round-trip bit error rates.
    """
    integers = np.random.randint(low=0, high=2**bits_per_symbol,
                                 size=[test_batch_size])
    preamble = integers_to_symbols(integers, bits_per_symbol=bits_per_symbol)
    A = agent1
    B = agent2
    c_signal_forward = A.mod.modulate(preamble, mode='exploit',
                                      dtype='cartesian')
    _c_signal_forward = B.mod.modulate(preamble, mode='exploit',
                                       dtype='cartesian')
    # Direction A -> B -> A.
    c_signal_forward_noisy = add_awgn(c_signal_forward, SNR_db=test_SNR_db,
                                      signal_power=signal_power)
    preamble_halftrip = B.demod.demodulate(c_signal_forward_noisy)
    # FIX: this modulate call was commented out, leaving c_signal_backward
    # undefined and crashing the add_awgn call below with a NameError.
    c_signal_backward = B.mod.modulate(preamble_halftrip, mode='exploit',
                                       dtype='cartesian')
    c_signal_backward_noisy = add_awgn(c_signal_backward, SNR_db=test_SNR_db,
                                       signal_power=signal_power)
    preamble_roundtrip = A.demod.demodulate(c_signal_backward_noisy)
    # Direction B -> A -> B.
    _c_signal_forward_noisy = add_awgn(_c_signal_forward, SNR_db=test_SNR_db,
                                       signal_power=signal_power)
    _preamble_halftrip = A.demod.demodulate(_c_signal_forward_noisy)
    # FIX: same NameError for _c_signal_backward; restored.
    _c_signal_backward = A.mod.modulate(_preamble_halftrip, mode='exploit',
                                        dtype='cartesian')
    _c_signal_backward_noisy = add_awgn(_c_signal_backward,
                                        SNR_db=test_SNR_db,
                                        signal_power=signal_power)
    _preamble_roundtrip = B.demod.demodulate(_c_signal_backward_noisy)
    return (float(get_ber(preamble, preamble_roundtrip)) +
            float(get_ber(preamble, _preamble_roundtrip))) / 2.0
def test(*, agent1, agent2, bits_per_symbol, test_SNR_dbs,
         signal_power=1.0, test_batch_size=10000):
    """Measure round-trip BER for each SNR in ``test_SNR_dbs``.

    When ``agent2`` is None, a single agent talks to itself and only one
    direction is measured; otherwise both directions (A->B->A and B->A->B)
    are averaged per SNR.

    FIX: the original body computed demodulations inside the loop, discarded
    every result, and returned the placeholder string "TODOTODOTODO"; the
    backward-modulate lines were also commented out, so the loop crashed
    with a NameError on c_signal_backward.  The loop now accumulates BERs.

    Returns:
        list[float]: round-trip BER per entry of test_SNR_dbs, in order.
    """
    integers = np.random.randint(low=0, high=2**bits_per_symbol,
                                 size=[test_batch_size])
    preamble = integers_to_symbols(integers, bits_per_symbol=bits_per_symbol)
    A = agent1
    B = agent2 if agent2 is not None else agent1
    c_signal_forward = A.mod.modulate(preamble, mode='exploit',
                                      dtype='cartesian')
    _c_signal_forward = B.mod.modulate(preamble, mode='exploit',
                                       dtype='cartesian')
    bers = []
    for test_SNR_db in test_SNR_dbs:
        # Direction A -> B -> A.
        c_signal_forward_noisy = add_awgn(c_signal_forward,
                                          SNR_db=test_SNR_db,
                                          signal_power=signal_power)
        preamble_halftrip = B.demod.demodulate(c_signal_forward_noisy)
        c_signal_backward = B.mod.modulate(preamble_halftrip, mode='exploit',
                                           dtype='cartesian')
        c_signal_backward_noisy = add_awgn(c_signal_backward,
                                           SNR_db=test_SNR_db,
                                           signal_power=signal_power)
        preamble_roundtrip = A.demod.demodulate(c_signal_backward_noisy)
        ber = float(get_ber(preamble, preamble_roundtrip))
        if agent2 is not None:
            # Direction B -> A -> B; average the two round trips.
            _c_signal_forward_noisy = add_awgn(_c_signal_forward,
                                               SNR_db=test_SNR_db,
                                               signal_power=signal_power)
            _preamble_halftrip = A.demod.demodulate(_c_signal_forward_noisy)
            _c_signal_backward = A.mod.modulate(_preamble_halftrip,
                                                mode='exploit',
                                                dtype='cartesian')
            _c_signal_backward_noisy = add_awgn(_c_signal_backward,
                                                SNR_db=test_SNR_db,
                                                signal_power=signal_power)
            _preamble_roundtrip = B.demod.demodulate(_c_signal_backward_noisy)
            ber = (ber + float(get_ber(preamble, _preamble_roundtrip))) / 2.0
        bers.append(ber)
    return bers
def get_random_preamble(n, bits_per_symbol):
    """Draw ``n`` symbols uniformly at random from the 2**bits_per_symbol alphabet."""
    alphabet_size = 2 ** bits_per_symbol
    labels = np.random.randint(low=0, high=alphabet_size, size=[n])
    return integers_to_symbols(labels, bits_per_symbol)
def animated_plot(results_file="/Users/caryn/Desktop/echo/experiments/private_"
                               "preamble/QPSK_poly_hyperparam_search/results/0.npy",
                  results=None):
    """Animate constellation + demod decision-grid snapshots from a results file.

    Each result frame draws, per agent pair, the modulator constellation with
    exploration-std ellipses and the demodulator's decisions over a 2D grid.
    Expects results[0] to be a meta dict ('num_agents', 'bits_per_symbol') and
    the rest to be per-iteration dicts with keys 'mod_std_%i',
    'constellation_%i', 'demod_grid_%i'.
    NOTE(review): frames are shown interactively and cleared; nothing is saved.
    """
    # (result):
    plt.ion()
    # plt.show()
    if results is None:
        results = np.load(open(results_file, 'rb'), allow_pickle=True)
    meta, results = results[0], results[1:]
    num_agents = meta['num_agents']
    bits_per_symbol = meta['bits_per_symbol']
    # One color per constellation point.
    num_colors = 2**bits_per_symbol
    # cm = plt.get_cmap('tab20')
    cm = plt.cm.viridis
    colors = [cm(int(x * cm.N / num_colors)) for x in range(num_colors)]
    grid = get_grid_2d()
    if num_agents == 2:
        # (a, b): plot agent a's constellation next to agent b's demod grid.
        pairs = [(1, 2), (2, 1)]
        fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey=True,
                                 figsize=(8, 8))
        axes_list = [item for sublist in axes for item in sublist]
    elif num_agents == 1:
        pairs = [(1, 1)]
        fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True,
                                 figsize=(8, 4))
        axes_list = [item for item in axes]
    for result in results:
        _a_ = 0  # index into axes_list; advances as panels are drawn
        for a, b in pairs:
            m, c = (result['mod_std_%i' % a], result['constellation_%i' % a])
            d = result['demod_grid_%i' % b]
            i_means = c[:, 0]
            q_means = c[:, 1]
            cov = np.eye(2, 2) * m**2
            # assumes m holds (i_std, q_std) — TODO confirm upstream shape
            i_std, q_std = m
            # 9.210 is presumably the chi-square 99% quantile (2 dof), giving
            # a ~99% confidence ellipse — verify against the source of m.
            minor = np.sqrt(9.210) * q_std
            major = np.sqrt(9.210) * i_std
            ax = axes_list[_a_]
            _a_ += 1
            ax.scatter(i_means, q_means, c=colors)
            # Same ellipse size for every point; only the center varies.
            ells = [
                Ellipse(xy=c[i], width=major, height=minor,
                        color=colors[i], alpha=.2)
                for i in range(len(c))
            ]
            # ells = [confidence_ellipse(cov, c[i], ax, facecolor=colors[i], alpha=.2) for i in range(len(c))]
            for i, e in enumerate(ells):
                ax.add_artist(e)
                # e.set_clip_box(ax.bbox)
                # e.set_alpha(.2)
                # e.set_facecolor(colors[i])
            # Label each constellation point with its bit pattern.
            for label, (x, y) in zip(
                    integers_to_symbols(
                        np.array([i for i in range(2 ** bits_per_symbol)]),
                        bits_per_symbol), c):
                ax.annotate(label, xy=(x, y), xytext=(0, 0),
                            textcoords='offset points')
            ax.set_xlim(-1.5, 1.5)
            ax.set_ylim(-1.5, 1.5)
            # Companion panel: demod decision regions over the 2D grid.
            ax1 = axes_list[_a_]
            _a_ += 1
            ax1.scatter(grid[:, 0], grid[:, 1], c=[colors[i] for i in d])
            ax1.set_xlim(-1.5, 1.5)
            ax1.set_ylim(-1.5, 1.5)
        # Tiny pause lets the interactive backend render the frame.
        plt.pause(0.0000000001)
        for ax in axes_list:
            ax.clear()
        fig.canvas.draw()
    plt.ioff()
    plt.close(fig)
    # plt.savefig("%s/%s_demod-%d.png" % (plots_dir, "_".join(agent.name.lower().split(" ")), plot_count))
    return
def train(*, agents, bits_per_symbol: int, batch_size: int,
          num_iterations: int, results_every: int,
          train_SNR_db: float, signal_power: float,
          early_stopping: bool = False,
          early_stopping_db_off: float = 1,
          verbose: bool = False,
          **kwargs):
    """Train two agents with the shared-preamble echo protocol.

    Each iteration, the current 'A' sends a fresh preamble plus an echo of
    what it heard last round; 'B' closes the previous round trip (modulator
    update) and demodulates the new preamble (demodulator update).  Roles
    swap every iteration.  Every ``results_every`` iterations the pair is
    evaluated and optionally early-stopped when within
    ``early_stopping_db_off`` dB of optimal at all test SNRs.

    Returns:
        (info dict, list of per-evaluation result dicts).
    """
    br = BER_lookup_table()
    early_stop = False
    integers_to_symbols_map = integers_to_symbols(
        np.arange(0, 2 ** bits_per_symbol), bits_per_symbol)
    if verbose:
        print("shared_preamble train.py")
    Amod = agents[0].mod
    Ademod = agents[0].demod
    Bmod = agents[1].mod
    Bdemod = agents[1].demod
    # State carried across iterations: what the previous sender transmitted,
    # so its modulator can be updated once the echo returns.
    prev_preamble = None
    prev_actions = None
    batches_sent = 0
    results = []
    # num_iterations + 1: the extra pass completes the final round trip.
    for i in range(num_iterations + 1):
        # A.mod(preamble) | | B.demod(signal forward) |==> B update demod |
        # |--> channel -->| | |--> switch (A,B = B,A)
        # A.mod(pre-half) | | A.demod(signal backward) |==> B update mod |
        integers = np.random.randint(low=0, high=2 ** bits_per_symbol,
                                     size=[batch_size])
        preamble = integers_to_symbols_map[integers]  # new shared preamble
        # A: echo back last round's half-trip guess (if any) + send new preamble.
        if prev_preamble is not None:
            c_signal_backward = Amod.modulate(preamble_halftrip,
                                              mode='explore',
                                              dtype='cartesian')
        c_signal_forward = Amod.modulate(preamble, mode='explore',
                                         dtype='cartesian')
        # Channel: AWGN on both directions.
        if prev_preamble is not None:
            c_signal_backward_noisy = add_awgn(c_signal_backward,
                                               SNR_db=train_SNR_db,
                                               signal_power=signal_power)
        c_signal_forward_noisy = add_awgn(c_signal_forward,
                                          SNR_db=train_SNR_db,
                                          signal_power=signal_power)
        if prev_preamble is not None:
            preamble_roundtrip = Bdemod.demodulate(c_signal_backward_noisy)
            # Update mod after a roundtrip pass
            Bmod.update(prev_preamble, prev_actions, preamble_roundtrip)
            batches_sent += 2
        # guess of new preamble
        preamble_halftrip = Bdemod.demodulate(c_signal_forward_noisy)
        # Update demod after a oneway pass
        if i < num_iterations:
            # the last iteration is just to complete the roundtrip, do not update halftrip
            Bdemod.update(c_signal_forward_noisy, preamble)
        prev_preamble, prev_actions = preamble, c_signal_forward
        # SWITCH roles for the next iteration.
        Amod, Ademod, Bmod, Bdemod = Bmod, Bdemod, Amod, Ademod
        ############### STATS ##########################
        if i % results_every == 0 or i == num_iterations:
            if verbose:
                print("ITER %i: Train SNR_db:% 5.1f" % (i, train_SNR_db))
            result = evaluate(agent1=agents[0], agent2=agents[1],
                              bits_per_symbol=bits_per_symbol,
                              signal_power=signal_power,
                              verbose=verbose or i == num_iterations,
                              total_iterations=num_iterations // results_every,
                              completed_iterations=i // results_every,
                              **kwargs)
            test_SNR_dbs = result['test_SNR_dbs']
            test_bers = result['test_bers']
            # dB gap between each test SNR and the SNR an optimal scheme
            # would need for the same round-trip BER.
            db_off_for_test_snr = [
                testSNR - br.get_optimal_SNR_for_BER_roundtrip(
                    testBER, bits_per_symbol)
                for testSNR, testBER in zip(test_SNR_dbs, test_bers)]
            ###ADD TO RESULT
            result['batches_sent'] = batches_sent
            result['db_off'] = db_off_for_test_snr
            results += [result]
            if early_stopping and all(
                    np.array(db_off_for_test_snr) <= early_stopping_db_off):
                print("STOPPED AT ITERATION: %i" % i)
                print(['0 BER', '1e-5 BER', '1e-4 BER', '1e-3 BER',
                       '1e-2 BER', '1e-1 BER'])
                print("TEST SNR dBs : ", test_SNR_dbs)
                print("dB off Optimal : ", db_off_for_test_snr)
                print("Early Stopping dBs off: %d" % early_stopping_db_off)
                early_stop = True
                break
    info = {
        'bits_per_symbol': bits_per_symbol,
        'train_SNR_db': train_SNR_db,
        'num_results': len(results),
        'test_SNR_dbs': test_SNR_dbs,
        'early_stop': early_stop,
        'early_stop_threshold_db_off': early_stopping_db_off,
        'batch_size': batch_size,
        'num_agents': 2,
    }
    return info, results
def train(*, agents, optimizer, bits_per_symbol: int, batch_size: int,
          num_iterations: int, results_every: int,
          train_SNR_db: float, signal_power: float,
          early_stopping: bool = False,
          early_stopping_db_off: float = .1,
          verbose: bool = False,
          **kwargs):
    """Train a single agent end-to-end with gradients passed through the channel.

    The modulator output plus AWGN feeds the demodulator; cross-entropy loss
    (plus each side's regularization loss) backpropagates through both.
    Every ``results_every`` iterations the agent is evaluated, and training
    optionally early-stops when within ``early_stopping_db_off`` dB of
    optimal at all test SNRs.

    Returns:
        (info dict, list of per-evaluation result dicts).
    """
    br = BER_lookup_table()
    early_stop = False
    if verbose:
        print("gradient_passing train.py")
    A = agents[0]
    optimizers = {
        'adam': torch.optim.Adam,
        'sgd': torch.optim.SGD,
    }
    if optimizer:
        assert optimizer.lower() in optimizers.keys(), \
            "modulator optimizer=%s not supported" % optimizer
        optimizer = optimizers[optimizer.lower()]
        print("gradient_passing initialized with %s optimizer." %
              optimizer.__name__)
        # One optimizer over both modulator and demodulator parameter groups.
        optimizer = optimizer(A.mod.get_param_dicts() +
                              A.demod.get_param_dicts())
    else:
        print("gradient_passing initialized WITHOUT an optimizer")
        optimizer = None
    batches_sent = 0
    results = []
    loss_criterion = torch.nn.CrossEntropyLoss()
    for i in range(num_iterations):
        preamble_labels = np.random.randint(low=0, high=2 ** bits_per_symbol,
                                            size=[batch_size])
        preamble = integers_to_symbols(preamble_labels, bits_per_symbol)
        preamble = torch.from_numpy(preamble).float()
        ##MODULATE/action
        c_signal_forward = A.mod.modulate_tensor(preamble)
        ##CHANNEL
        N0 = get_N0(SNR_db=train_SNR_db, signal_power=signal_power)
        noise = torch.from_numpy(
            get_awgn(N0=N0, n=c_signal_forward.shape[0])).float()
        c_signal_forward_noisy = c_signal_forward + noise
        ##DEMODULATE/update and pass loss to mod
        preamble_halftrip_logits = A.demod.demodulate_tensor(
            c_signal_forward_noisy)
        loss = loss_criterion(input=preamble_halftrip_logits.float(),
                              target=torch.from_numpy(preamble_labels))
        loss += A.mod.get_regularization_loss()
        loss += A.demod.get_regularization_loss()
        if optimizer:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        batches_sent += 1
        ############### STATS ##########################
        if i % results_every == 0 or i == num_iterations - 1:
            if verbose:
                print("ITER %i: Train SNR_db:% 5.1f" % (i, train_SNR_db))
            # FIX: the loop runs over range(num_iterations), so i never
            # equals num_iterations; the original `i == num_iterations`
            # verbose trigger could never fire.  Use num_iterations - 1,
            # matching the results condition above.
            result = evaluate(agent1=agents[0],
                              bits_per_symbol=bits_per_symbol,
                              signal_power=signal_power,
                              verbose=verbose or i == num_iterations - 1,
                              total_iterations=num_iterations // results_every,
                              completed_iterations=i // results_every,
                              **kwargs)
            test_SNR_dbs = result['test_SNR_dbs']
            test_bers = result['test_bers']
            # dB gap between each test SNR and the SNR an optimal scheme
            # would need for the same round-trip BER.
            db_off_for_test_snr = [
                testSNR - br.get_optimal_SNR_for_BER_roundtrip(
                    testBER, bits_per_symbol)
                for testSNR, testBER in zip(test_SNR_dbs, test_bers)]
            ###ADD TO RESULT
            result['batches_sent'] = batches_sent
            result['db_off'] = db_off_for_test_snr
            results += [result]
            if early_stopping and all(
                    np.array(db_off_for_test_snr) <= early_stopping_db_off):
                print("STOPPED AT ITERATION: %i" % i)
                print(['0 BER', '1e-5 BER', '1e-4 BER', '1e-3 BER',
                       '1e-2 BER', '1e-1 BER'])
                print("TEST SNR dBs : ", test_SNR_dbs)
                print("dB off Optimal : ", db_off_for_test_snr)
                print("Early Stopping dBs off: %d" % early_stopping_db_off)
                early_stop = True
                break
    info = {
        'bits_per_symbol': bits_per_symbol,
        'train_SNR_db': train_SNR_db,
        'num_results': len(results),
        'test_SNR_dbs': test_SNR_dbs,
        'early_stop': early_stop,
        'early_stop_threshold_db_off': early_stopping_db_off,
        'batch_size': batch_size,
        'num_agents': 1,
    }
    return info, results