def trainer(*, agents, bits_per_symbol: int, batch_size: int, train_SNR_db: float,
            signal_power: float = 1.0, backwards_only: bool = False, **kwargs):
    """Run a single step of shared-preamble echo training between two agents.

    ``agents[0]`` acts as the speaker and ``agents[1]`` as the listener for this
    step. The steps are:

    1. The speaker modulates a freshly drawn random preamble.
    2. If the speaker holds a pending ``to_echo`` guess, it also modulates that
       guess as an echo.
    3. Both signals go through an AWGN channel.
    4. The listener demodulates the echo and updates its modulator from it,
       using ``listener.preamble`` / ``listener.actions``. These are presumably
       the values it stored when it last acted as speaker; confirm with the
       calling loop.
    5. The listener demodulates the new preamble into ``listener.to_echo``.
    6. Unless ``backwards_only`` is set, the listener updates its demodulator
       against the true preamble.

    Args:
        agents: indexable; items 0 and 1 are the speaker and listener.
        bits_per_symbol: bits carried per symbol; preamble entries are drawn
            from ``range(2 ** bits_per_symbol)``.
        batch_size: number of symbols in the new preamble.
        train_SNR_db: channel SNR in dB passed to ``add_awgn``.
        signal_power: signal power passed to ``add_awgn``.
        backwards_only: if True, skip the listener's demodulator update.
        **kwargs: accepted and ignored.

    Returns:
        ``(speaker, listener, batches_sent)``. ``batches_sent`` is incremented
        once for the listener's modulator update (when an echo was sent) and
        once for its demodulator update (unless ``backwards_only``).
    """
    n_symbols = 2 ** bits_per_symbol
    symbol_table = integers_to_symbols(np.arange(0, n_symbols), bits_per_symbol)
    speaker, listener = agents[0], agents[1]
    updates = 0

    draws = np.random.randint(low=0, high=n_symbols, size=[batch_size])
    new_preamble = symbol_table[draws]  # new shared preamble

    # The speaker's pending guess is only overwritten at the very end, so it is
    # safe to evaluate this condition once.
    echoing = speaker.to_echo is not None

    # Speaker: the echo is modulated before the new preamble (keeps call order).
    if echoing:
        echo_tx = speaker.mod.modulate(speaker.to_echo, mode='explore', dtype='cartesian')
    fresh_tx = speaker.mod.modulate(new_preamble, mode='explore', dtype='cartesian')
    speaker.preamble = new_preamble
    speaker.actions = fresh_tx

    # Channel: echo first, then the fresh signal.
    if echoing:
        echo_rx = add_awgn(echo_tx, SNR_db=train_SNR_db, signal_power=signal_power)
    fresh_rx = add_awgn(fresh_tx, SNR_db=train_SNR_db, signal_power=signal_power)

    # Listener: update the modulator after a completed round trip ...
    if echoing:
        echoed_back = listener.demod.demodulate(echo_rx)
        listener.mod.update(listener.preamble, listener.actions, echoed_back)
        updates += 1

    # ... then guess the new preamble (to be echoed later) and, optionally,
    # update the demodulator after the one-way pass.
    listener.to_echo = listener.demod.demodulate(fresh_rx)
    if not backwards_only:
        listener.demod.update(fresh_rx, new_preamble)
        updates += 1

    return speaker, listener, updates
def test(
        *,
        agent1,
        agent2,
        bits_per_symbol,
        test_SNR_db,
        signal_power=1.0,
        test_batch_size=100000,
):
    """Return the mean round-trip BER of two agents at a single SNR.

    A random preamble is drawn once. Two round trips are simulated, each
    forward and backward hop passing through ``add_awgn`` at ``test_SNR_db``:

    * agent1 -> agent2 -> agent1
    * agent2 -> agent1 -> agent2

    All modulation uses ``mode='exploit'``.

    Returns:
        float, the average of the two directions' ``get_ber`` values.
    """

    def _channel(signal):
        return add_awgn(signal, SNR_db=test_SNR_db, signal_power=signal_power)

    def _bounce(origin, echoer, outgoing):
        # echoer guesses the preamble, re-modulates its guess, and origin
        # demodulates what comes back.
        guess = echoer.demod.demodulate(_channel(outgoing))
        echoed = echoer.mod.modulate(guess, mode='exploit', dtype='cartesian')
        return origin.demod.demodulate(_channel(echoed))

    draws = np.random.randint(low=0, high=2**bits_per_symbol, size=[test_batch_size])
    preamble = integers_to_symbols(draws, bits_per_symbol=bits_per_symbol)

    outgoing_1 = agent1.mod.modulate(preamble, mode='exploit', dtype='cartesian')
    outgoing_2 = agent2.mod.modulate(preamble, mode='exploit', dtype='cartesian')

    returned_1 = _bounce(agent1, agent2, outgoing_1)
    returned_2 = _bounce(agent2, agent1, outgoing_2)

    ber_1 = float(get_ber(preamble, returned_1))
    ber_2 = float(get_ber(preamble, returned_2))
    return (ber_1 + ber_2) / 2.0
def test(
        *,
        agent1,
        agent2,
        bits_per_symbol,
        test_SNR_dbs,
        signal_power=1.0,
        test_batch_size=10000,
):
    """Return the mean round-trip BER at each SNR in ``test_SNR_dbs``.

    One random preamble is drawn and reused for every SNR. For each SNR,
    ``agent1`` sends the preamble to B, B re-modulates its guess, and
    ``agent1`` demodulates the echo. Here B is ``agent2``, or ``agent1``
    itself when ``agent2`` is None.

    When ``agent2`` is given, the mirrored round trip
    (agent2 -> agent1 -> agent2) is also simulated and the two BERs are
    averaged. All modulation uses ``mode='exploit'``.

    Returns:
        list of float with one entry per element of ``test_SNR_dbs``: the
        round-trip BER, averaged over both directions when ``agent2`` is given.
    """
    integers = np.random.randint(low=0, high=2**bits_per_symbol, size=[test_batch_size])
    preamble = integers_to_symbols(integers, bits_per_symbol=bits_per_symbol)
    two_agents = agent2 is not None
    A = agent1
    B = agent2 if two_agents else agent1

    c_signal_forward = A.mod.modulate(preamble, mode='exploit', dtype='cartesian')
    # The reverse direction is only simulated with two distinct agents, so
    # don't spend a modulation pass on it otherwise.
    if two_agents:
        _c_signal_forward = B.mod.modulate(preamble, mode='exploit', dtype='cartesian')

    mean_bers = []
    for test_SNR_db in test_SNR_dbs:
        # A --> B --> A
        c_signal_forward_noisy = add_awgn(c_signal_forward, SNR_db=test_SNR_db, signal_power=signal_power)
        preamble_halftrip = B.demod.demodulate(c_signal_forward_noisy)
        c_signal_backward = B.mod.modulate(preamble_halftrip, mode='exploit', dtype='cartesian')
        c_signal_backward_noisy = add_awgn(c_signal_backward, SNR_db=test_SNR_db, signal_power=signal_power)
        preamble_roundtrip = A.demod.demodulate(c_signal_backward_noisy)
        bers = [float(get_ber(preamble, preamble_roundtrip))]

        if two_agents:
            # B --> A --> B
            _c_signal_forward_noisy = add_awgn(_c_signal_forward, SNR_db=test_SNR_db, signal_power=signal_power)
            _preamble_halftrip = A.demod.demodulate(_c_signal_forward_noisy)
            _c_signal_backward = A.mod.modulate(_preamble_halftrip, mode='exploit', dtype='cartesian')
            _c_signal_backward_noisy = add_awgn(_c_signal_backward, SNR_db=test_SNR_db, signal_power=signal_power)
            _preamble_roundtrip = B.demod.demodulate(_c_signal_backward_noisy)
            bers.append(float(get_ber(preamble, _preamble_roundtrip)))

        mean_bers.append(sum(bers) / len(bers))

    return mean_bers
def roundtrip_evaluate(*, agent1, agent2=None, bits_per_symbol, test_batch_size: int, signal_power: float,
                       verbose: bool = False, completed_iterations: int = None, total_iterations: int = None,
                       **kwargs):
    """Evaluate round-trip BER/SER over the configured test SNRs.

    The SNRs are ``get_test_SNR_dbs()[bits_per_symbol]['ber_roundtrip']``. A
    single random preamble of ``test_batch_size`` symbols is reused for every
    SNR. For each SNR, ``agent1`` sends it to B, B re-modulates its guess, and
    ``agent1`` demodulates the echo. B is ``agent2``, or ``agent1`` itself when
    ``agent2`` is None.

    With two agents the mirrored trip (agent2 -> agent1 -> agent2) is also
    measured, and the reported ``test_bers`` / ``test_sers`` are the element-wise
    means of both directions (``np.mean``). With one agent they are the
    single-direction lists.

    Printing:
        If ``verbose`` is truthy, a per-SNR table is printed. Otherwise, when
        both ``total_iterations`` and ``completed_iterations`` are given, a
        one-line ``'\\r'`` progress bar is printed.

    Returns:
        dict with keys ``test_SNR_dbs``, ``test_bers``, ``test_sers``,
        ``mod_std_1``, ``constellation_1`` and ``demod_grid_1``. With
        ``agent2`` it also has the per-direction ``test_bers_1``,
        ``test_sers_1``, ``test_bers_2`` and ``test_sers_2``, plus
        ``mod_std_2``, ``constellation_2`` and ``demod_grid_2``.
    """
    grid_2d = get_grid_2d(grid=[-1.5, 1.5], points_per_dim=100)  # For getting demod boundaries
    test_SNR_dbs = get_test_SNR_dbs()[bits_per_symbol]['ber_roundtrip']
    two_agents = agent2 is not None
    A = agent1
    B = agent2 if two_agents else agent1

    # Calculate round-trip testing accuracy on different SNRs.
    preamble = get_random_preamble(n=test_batch_size, bits_per_symbol=bits_per_symbol)
    c_signal_forward = A.mod.modulate(preamble, mode='exploit', dtype='cartesian')
    # The reverse direction is only evaluated with two distinct agents.
    if two_agents:
        _c_signal_forward = B.mod.modulate(preamble, mode='exploit', dtype='cartesian')

    # Index 0: A --> B --> A ; index 1: B --> A --> B (two-agent case only).
    test_bers = [[], []]
    test_sers = [[], []]
    for test_SNR_db in test_SNR_dbs:
        c_signal_forward_noisy = add_awgn(c_signal_forward, SNR_db=test_SNR_db, signal_power=signal_power)
        preamble_halftrip = B.demod.demodulate(c_signal_forward_noisy)
        c_signal_backward = B.mod.modulate(preamble_halftrip, mode='exploit', dtype='cartesian')
        c_signal_backward_noisy = add_awgn(c_signal_backward, SNR_db=test_SNR_db, signal_power=signal_power)
        preamble_roundtrip = A.demod.demodulate(c_signal_backward_noisy)
        test_bers[0].append(float(get_ber(preamble, preamble_roundtrip)))
        test_sers[0].append(float(get_ser(preamble, preamble_roundtrip)))

        if two_agents:
            _c_signal_forward_noisy = add_awgn(_c_signal_forward, SNR_db=test_SNR_db, signal_power=signal_power)
            _preamble_halftrip = A.demod.demodulate(_c_signal_forward_noisy)
            _c_signal_backward = A.mod.modulate(_preamble_halftrip, mode='exploit', dtype='cartesian')
            _c_signal_backward_noisy = add_awgn(_c_signal_backward, SNR_db=test_SNR_db, signal_power=signal_power)
            _preamble_roundtrip = B.demod.demodulate(_c_signal_backward_noisy)
            test_bers[1].append(float(get_ber(preamble, _preamble_roundtrip)))
            test_sers[1].append(float(get_ser(preamble, _preamble_roundtrip)))

    if two_agents:
        avg_test_sers = np.mean([test_sers[0], test_sers[1]], axis=0)
        avg_test_bers = np.mean([test_bers[0], test_bers[1]], axis=0)
    else:
        avg_test_sers = test_sers[0]
        avg_test_bers = test_bers[0]

    if verbose:
        # Human-readable table, useful for manual debugging.
        print(" ")
        if two_agents:
            print("\t\t\t(%s --> %s --> %s), \n\t\t\t[%s --> %s --> %s], \n\t\t\t<Means>" % (
                A.name, B.name, A.name, B.name, A.name, B.name))
            for k, snr_db in enumerate(test_SNR_dbs):
                print(
                    "Test SNR_db :% 5.1f | "
                    "(BER: %7.6f) [BER: %7.6f] <BER: %7.6f> | "
                    "(SER: %7.6f) [SER: %7.6f] <SER: %7.6f>"
                    % (snr_db, test_bers[0][k], test_bers[1][k], avg_test_bers[k],
                       test_sers[0][k], test_sers[1][k], avg_test_sers[k]))
        else:
            print("\t\t\t(%s --> %s --> %s)" % (A.name, B.name, A.name))
            for k, snr_db in enumerate(test_SNR_dbs):
                print("Test SNR_db :% 5.1f | BER: %7.6f | SER: %7.6f"
                      % (snr_db, test_bers[0][k], test_sers[0][k],))
        print(" ")
    elif total_iterations is not None and completed_iterations is not None:
        # In-place progress bar: '.' per completed evaluation, padded with spaces.
        print("[%s]" % ("." * completed_iterations + " " * (total_iterations - completed_iterations)),
              end='\r', flush=True)

    r2 = {}
    if two_agents:
        r2 = {
            'test_bers_1': test_bers[0],
            'test_sers_1': test_sers[0],
            'test_bers_2': test_bers[1],
            'test_sers_2': test_sers[1],
            'mod_std_2': agent2.mod.get_std(),
            'constellation_2': agent2.mod.get_constellation(),
            'demod_grid_2': agent2.demod.get_demod_grid(grid_2d),
        }

    return {
        'test_SNR_dbs': test_SNR_dbs,
        'test_bers': avg_test_bers,  # mean over directions when two agents
        'test_sers': avg_test_sers,  # mean over directions when two agents
        'mod_std_1': agent1.mod.get_std(),
        # NOTE(review): constellation_1 has AWGN added at the last test SNR
        # (previously via the leaked loop variable), whereas constellation_2 is
        # the clean constellation. Confirm this asymmetry is intended.
        'constellation_1': add_awgn(agent1.mod.get_constellation(), SNR_db=test_SNR_dbs[-1],
                                    signal_power=signal_power),
        'demod_grid_1': agent1.demod.get_demod_grid(grid_2d),
        **r2
    }
def train(*, agents, bits_per_symbol: int, batch_size: int, num_iterations: int, results_every: int,
          train_SNR_db: float, signal_power: float, early_stopping: bool = False,
          early_stopping_db_off: float = 1, verbose: bool = False, **kwargs):
    """Train a single agent (``agents[0]``) with loss passing.

    Each iteration does the following:

    1. Draws a random preamble and modulates it in ``'explore'`` mode.
    2. Passes it through AWGN at ``train_SNR_db``.
    3. Updates the agent's demodulator against the true preamble.
    4. Updates the agent's modulator from its own demodulator's guess.

    Every ``results_every`` iterations, and on the final iteration,
    ``evaluate`` is called. Its result is extended with ``batches_sent`` and
    ``db_off`` (the gap in dB to the BER lookup table's optimal SNR). The
    final evaluation is verbose.

    If ``early_stopping`` is set and every test SNR is within
    ``early_stopping_db_off`` dB of optimal, training stops early.

    Returns:
        ``(info, results)``. ``results`` is the list of evaluation dicts and
        ``info`` summarizes the run. ``info['test_SNR_dbs']`` is None if no
        evaluation ran (``num_iterations <= 0``).
    """
    br = BER_lookup_table()
    early_stop = False
    if verbose:
        print("loss_passing train.py")
    A = agents[0]
    batches_sent = 0
    results = []
    # Defined up-front so `info` can be built even if the loop body never runs.
    test_SNR_dbs = None
    last_iteration = num_iterations - 1
    for i in range(num_iterations):
        preamble = get_random_preamble(batch_size, bits_per_symbol)
        # MODULATE / action
        c_signal_forward = A.mod.modulate(preamble, mode='explore', dtype='complex')
        actions = c_signal_forward
        # CHANNEL
        c_signal_forward_noisy = add_awgn(c_signal_forward, SNR_db=train_SNR_db, signal_power=signal_power)
        # DEMODULATE / update, then pass the demodulator's guess back to the modulator
        A.demod.update(c_signal_forward_noisy, preamble)
        preamble_halftrip = A.demod.demodulate(c_signal_forward_noisy)
        A.mod.update(preamble, actions, preamble_halftrip)
        batches_sent += 1

        # STATS
        if i % results_every == 0 or i == last_iteration:
            if verbose:
                print("ITER %i: Train SNR_db:% 5.1f" % (i, train_SNR_db))
            # The loop's last index is num_iterations - 1. Comparing against
            # num_iterations (as before) never matched, so the final
            # evaluation was never verbose.
            result = evaluate(agent1=agents[0], bits_per_symbol=bits_per_symbol, signal_power=signal_power,
                              verbose=verbose or i == last_iteration,
                              total_iterations=num_iterations // results_every,
                              completed_iterations=i // results_every, **kwargs)
            test_SNR_dbs = result['test_SNR_dbs']
            test_bers = result['test_bers']
            db_off_for_test_snr = [
                testSNR - br.get_optimal_SNR_for_BER_roundtrip(testBER, bits_per_symbol)
                for testSNR, testBER in zip(test_SNR_dbs, test_bers)
            ]
            # ADD TO RESULT
            result['batches_sent'] = batches_sent
            result['db_off'] = db_off_for_test_snr
            results.append(result)
            if early_stopping and all(np.array(db_off_for_test_snr) <= early_stopping_db_off):
                print("STOPPED AT ITERATION: %i" % i)
                print(['0 BER', '1e-5 BER', '1e-4 BER', '1e-3 BER', '1e-2 BER', '1e-1 BER'])
                print("TEST SNR dBs : ", test_SNR_dbs)
                print("dB off Optimal : ", db_off_for_test_snr)
                # %g, not %d: the threshold is a float and %d would truncate it (0.5 -> 0).
                print("Early Stopping dBs off: %g" % early_stopping_db_off)
                early_stop = True
                break

    info = {
        'bits_per_symbol': bits_per_symbol,
        'train_SNR_db': train_SNR_db,
        'num_results': len(results),
        'test_SNR_dbs': test_SNR_dbs,
        'early_stop': early_stop,
        'early_stop_threshold_db_off': early_stopping_db_off,
        'batch_size': batch_size,
        'num_agents': 1,
    }
    return info, results
def train(*, agents, bits_per_symbol: int, batch_size: int, num_iterations: int, results_every: int,
          train_SNR_db: float, signal_power: float, early_stopping: bool = False,
          early_stopping_db_off: float = 1, verbose: bool = False, **kwargs):
    """Train two agents with a shared preamble and echoing, swapping roles every step.

    Each iteration, the current speaker A and listener B do the following:

    1. A modulates a new random preamble. If a previous iteration exists, A
       also modulates ``preamble_halftrip`` as an echo. ``preamble_halftrip``
       is A's guess of the preamble it received last iteration, when it was
       the listener.
    2. Both signals pass through AWGN at ``train_SNR_db``.
    3. B demodulates the echo and updates its modulator with the preamble and
       actions it sent in the previous iteration.
    4. B demodulates the new preamble (kept to be echoed next iteration).
    5. B updates its demodulator. This is skipped on the extra final
       iteration (``i == num_iterations``), which only completes the last
       round trip.
    6. The roles of A and B swap.

    ``evaluate`` runs every ``results_every`` iterations and on the final one.
    Its result is extended with ``batches_sent`` and ``db_off``. Early
    stopping works as in the single-agent trainer.

    Returns:
        ``(info, results)``: a run summary and the list of evaluation dicts.
    """
    br = BER_lookup_table()
    early_stop = False
    integers_to_symbols_map = integers_to_symbols(np.arange(0, 2 ** bits_per_symbol), bits_per_symbol)
    if verbose:
        print("shared_preamble train.py")
    Amod = agents[0].mod
    Ademod = agents[0].demod
    Bmod = agents[1].mod
    Bdemod = agents[1].demod
    prev_preamble = None
    prev_actions = None
    batches_sent = 0
    results = []
    for i in range(num_iterations + 1):
        # A.mod(preamble)  --> channel --> B.demod(forward)  ==> B updates demod
        # A.mod(halftrip)  --> channel --> B.demod(backward) ==> B updates mod
        # then switch roles (A, B = B, A)
        integers = np.random.randint(low=0, high=2 ** bits_per_symbol, size=[batch_size])
        preamble = integers_to_symbols_map[integers]  # new shared preamble

        # A
        if prev_preamble is not None:
            c_signal_backward = Amod.modulate(preamble_halftrip, mode='explore', dtype='cartesian')
        c_signal_forward = Amod.modulate(preamble, mode='explore', dtype='cartesian')

        # Channel
        if prev_preamble is not None:
            c_signal_backward_noisy = add_awgn(c_signal_backward, SNR_db=train_SNR_db, signal_power=signal_power)
        c_signal_forward_noisy = add_awgn(c_signal_forward, SNR_db=train_SNR_db, signal_power=signal_power)

        # B
        if prev_preamble is not None:
            preamble_roundtrip = Bdemod.demodulate(c_signal_backward_noisy)
            # Update mod after a roundtrip pass
            Bmod.update(prev_preamble, prev_actions, preamble_roundtrip)
            batches_sent += 2
        # guess of new preamble (echoed back on the next iteration)
        preamble_halftrip = Bdemod.demodulate(c_signal_forward_noisy)
        # Update demod after a one-way pass
        if i < num_iterations:
            # the last iteration only completes the roundtrip; do not update on the halftrip
            Bdemod.update(c_signal_forward_noisy, preamble)
        prev_preamble, prev_actions = preamble, c_signal_forward

        # SWITCH
        Amod, Ademod, Bmod, Bdemod = Bmod, Bdemod, Amod, Ademod

        # STATS
        if i % results_every == 0 or i == num_iterations:
            if verbose:
                print("ITER %i: Train SNR_db:% 5.1f" % (i, train_SNR_db))
            result = evaluate(agent1=agents[0], agent2=agents[1], bits_per_symbol=bits_per_symbol,
                              signal_power=signal_power, verbose=verbose or i == num_iterations,
                              total_iterations=num_iterations // results_every,
                              completed_iterations=i // results_every, **kwargs)
            test_SNR_dbs = result['test_SNR_dbs']
            test_bers = result['test_bers']
            db_off_for_test_snr = [testSNR - br.get_optimal_SNR_for_BER_roundtrip(testBER, bits_per_symbol)
                                   for testSNR, testBER in zip(test_SNR_dbs, test_bers)]
            # ADD TO RESULT
            result['batches_sent'] = batches_sent
            result['db_off'] = db_off_for_test_snr
            results.append(result)
            if early_stopping and all(np.array(db_off_for_test_snr) <= early_stopping_db_off):
                print("STOPPED AT ITERATION: %i" % i)
                print(['0 BER', '1e-5 BER', '1e-4 BER', '1e-3 BER', '1e-2 BER', '1e-1 BER'])
                print("TEST SNR dBs : ", test_SNR_dbs)
                print("dB off Optimal : ", db_off_for_test_snr)
                # %g, not %d: the threshold is a float and %d would truncate it (0.5 -> 0).
                print("Early Stopping dBs off: %g" % early_stopping_db_off)
                early_stop = True
                break

    info = {
        'bits_per_symbol': bits_per_symbol,
        'train_SNR_db': train_SNR_db,
        'num_results': len(results),
        'test_SNR_dbs': test_SNR_dbs,
        'early_stop': early_stop,
        'early_stop_threshold_db_off': early_stopping_db_off,
        'batch_size': batch_size,
        'num_agents': 2,
    }
    return info, results