Exemplo n.º 1
0
    def __init__(
            self,
            *,
            model,
            bits_per_symbol: Union[float, int],
            optimizer: Optional[str] = 'adam',
            stepsize_mu: float = 0.0,
            stepsize_sigma: float = 0.0,
            initial_std: float = 0.1,
            min_std: float = 1e-5,
            max_std: float = 1e2,
            lambda_baseline: float = 0.0,  # used in update method
            lambda_center: float = 0.0,  # used in update method
            lambda_l1: float = 0.0,  # used in update method
            lambda_l2: float = 0.0,  # used in update method
            **kwargs):
        """Wrap a modulator ``model`` with exploration noise and an optional optimizer.

        Args:
            model: Callable invoked as ``model(bits_per_symbol=..., **kwargs)``;
                the returned object is stored as ``self.model`` and must have a
                ``name`` attribute.
            bits_per_symbol: Bits carried by each transmitted symbol.
            optimizer: ``'adam'`` or ``'sgd'`` (case-insensitive), or a falsy
                value for none. An optimizer is only created when the model has
                ``mu_parameters`` and does not define its own ``update``.
            stepsize_mu: Learning rate for the model's ``mu_parameters()``.
            stepsize_sigma: Learning rate for ``self.log_std``.
            initial_std: Initial exploration std, stored in log space as the
                trainable 2-element ``self.log_std``.
            min_std, max_std: Stored in log space as the 2-element tensors
                ``self.log_std_min`` / ``self.log_std_max``.
            lambda_baseline, lambda_center, lambda_l1, lambda_l2: Weights
                stored as float tensors for use by the update method.
            **kwargs: Forwarded to ``model``. An optional ``verbose`` entry
                also enables the initialization message.

        Raises:
            AssertionError: if ``optimizer`` names an unsupported optimizer
                (only checked when an optimizer would be created).
        """
        self.model = model(bits_per_symbol=bits_per_symbol, **kwargs)
        self.name = self.model.name
        self.bits_per_symbol = bits_per_symbol
        self.log_std_min = np.log(min_std) * torch.ones(2)
        self.log_std_max = np.log(max_std) * torch.ones(2)
        self.lambda_l1 = torch.tensor(lambda_l1).float()
        self.lambda_l2 = torch.tensor(lambda_l2).float()
        self.lambda_center = torch.tensor(lambda_center).float()
        self.lambda_baseline = torch.tensor(lambda_baseline).float()
        # Symbols for every integer label 0 .. 2**bits_per_symbol - 1.
        self.all_symbols = integers_to_symbols(
            np.arange(0, 2**bits_per_symbol), bits_per_symbol)
        self.log_std = nn.Parameter(np.log(initial_std) * torch.ones(2))

        optimizers = {
            'adam': torch.optim.Adam,
            'sgd': torch.optim.SGD,
        }
        # `verbose` is optional; indexing kwargs directly raised KeyError
        # whenever the caller did not pass it.
        verbose = kwargs.get('verbose', False)

        def build_param_dicts():
            """Return optimizer parameter groups for the means and the log-std."""
            mu_params = self.model.mu_parameters()
            if not isinstance(mu_params, torch.Tensor):
                # Materialize a possibly one-shot iterator so the stored
                # param_dicts stay usable after construction.
                mu_params = list(mu_params)
            return [
                {'params': mu_params, 'lr': stepsize_mu},
                {'params': self.log_std, 'lr': stepsize_sigma},
            ]

        has_mu_parameters = hasattr(self.model, "mu_parameters")
        if optimizer and has_mu_parameters and not hasattr(
                self.model, "update"):
            assert optimizer.lower() in optimizers.keys(
            ), "modulator optimizer=%s not supported" % optimizer
            optimizer_cls = optimizers[optimizer.lower()]
            if verbose:
                print("Modulator %s initialized with %s optimizer." %
                      (self.model.name, optimizer_cls.__name__))
            self.param_dicts = build_param_dicts()
            self.optimizer = optimizer_cls(self.param_dicts)
        else:
            if verbose:
                print("Modulator %s initialized WITHOUT an optimizer" %
                      (self.model.name))
            self.optimizer = None
            # Models with their own `update` still expose their parameter
            # groups (e.g. for an externally created optimizer).
            self.param_dicts = build_param_dicts() if has_mu_parameters else []
Exemplo n.º 2
0
def trainer(*,
            agents,
            bits_per_symbol: int,
            batch_size: int,
            train_SNR_db: float,
            signal_power: float = 1.0,
            backwards_only: bool = False,
            **kwargs):
    """Run one exchange step in which agents[0] speaks and agents[1] listens.

    The speaker modulates a fresh random preamble and, when ``to_echo`` is set,
    also an echo of the symbols it previously received. Both signals pass
    through an AWGN channel. The listener then:

    * when an echo was sent, demodulates it and updates its modulator using
      its own earlier ``preamble`` / ``actions`` and the round-trip guess;
    * demodulates the new preamble and stores the guess as ``to_echo``;
    * unless ``backwards_only``, updates its demodulator on the new preamble.

    Returns:
        ``(speaker, listener, batches_sent)``; ``batches_sent`` counts the
        updates performed in this step.
    """
    symbol_table = integers_to_symbols(
        np.arange(0, 2**bits_per_symbol), bits_per_symbol)
    speaker, listener = agents[0], agents[1]
    echoing = speaker.to_echo is not None
    updates = 0

    labels = np.random.randint(low=0,
                               high=2**bits_per_symbol,
                               size=[batch_size])
    fresh_preamble = symbol_table[labels]  # new shared preamble

    # Speaker: the echo is modulated before the fresh preamble.
    echo_tx = None
    if echoing:
        echo_tx = speaker.mod.modulate(speaker.to_echo,
                                       mode='explore',
                                       dtype='cartesian')
    forward_tx = speaker.mod.modulate(fresh_preamble,
                                      mode='explore',
                                      dtype='cartesian')
    # Remembered so a later round trip can update this modulator. This must
    # happen before the listener reads its own preamble/actions below.
    speaker.preamble = fresh_preamble
    speaker.actions = forward_tx

    # Channel: echo noise is drawn before forward noise.
    echo_rx = None
    if echoing:
        echo_rx = add_awgn(echo_tx,
                           SNR_db=train_SNR_db,
                           signal_power=signal_power)
    forward_rx = add_awgn(forward_tx,
                          SNR_db=train_SNR_db,
                          signal_power=signal_power)

    # Listener: modulator update from the completed round trip.
    if echoing:
        roundtrip_guess = listener.demod.demodulate(echo_rx)
        listener.mod.update(listener.preamble, listener.actions,
                            roundtrip_guess)
        updates += 1

    # Listener: guess the new preamble (echoed next step), then train demod.
    listener.to_echo = listener.demod.demodulate(forward_rx)
    if not backwards_only:
        listener.demod.update(forward_rx, fresh_preamble)
        updates += 1

    return speaker, listener, updates
Exemplo n.º 3
0
    def __init__(
            self,
            *,
            model,
            bits_per_symbol,
            optimizer: Optional[str] = 'adam',
            stepsize_cross_entropy: float = 1e-3,  # learning rate
            cross_entropy_weight: float = 1.0,  # stored as a float tensor
            epochs: int = 5,
            lambda_l1: float = 0.0,
            lambda_l2: float = 0.0,
            **kwargs):
        """Wrap a demodulator ``model`` and (optionally) create its optimizer.

        Args:
            model: Callable invoked as ``model(bits_per_symbol=..., **kwargs)``;
                the returned object is stored as ``self.model`` and must have a
                ``name`` attribute.
            bits_per_symbol: Bits carried by each received symbol.
            optimizer: ``'adam'`` or ``'sgd'`` (case-insensitive), or a falsy
                value for none. An optimizer is only created when the model has
                ``parameters`` and does not define its own ``update``.
            stepsize_cross_entropy: Learning rate for ``model.parameters()``.
            cross_entropy_weight: Stored as ``self.cross_entropy_weight``.
            epochs: Stored as ``self.epochs``.
            lambda_l1, lambda_l2: Stored as float tensors.
            **kwargs: Forwarded to ``model``. An optional ``verbose`` entry
                also enables the initialization message.

        Raises:
            AssertionError: if ``optimizer`` names an unsupported optimizer
                (only checked when an optimizer would be created).
        """
        self.epochs = epochs
        self.bits_per_symbol = bits_per_symbol
        self.model = model(bits_per_symbol=bits_per_symbol, **kwargs)
        self.name = self.model.name
        self.lambda_l1 = torch.tensor(lambda_l1).float()
        self.lambda_l2 = torch.tensor(lambda_l2).float()
        # Set unconditionally: previously this attribute only existed when an
        # optimizer was created.
        self.cross_entropy_weight = torch.tensor(cross_entropy_weight).float()
        self.integers_to_symbols_map = integers_to_symbols(
            np.arange(0, 2**bits_per_symbol), bits_per_symbol)

        optimizers = {
            'adam': torch.optim.Adam,
            'sgd': torch.optim.SGD,
        }
        # `verbose` is optional; indexing kwargs directly raised KeyError
        # whenever the caller did not pass it.
        verbose = kwargs.get('verbose', False)

        def build_param_dicts():
            """Return the single optimizer parameter group for the model."""
            params = self.model.parameters()
            if not isinstance(params, torch.Tensor):
                # Module.parameters() is a one-shot generator; materialize it
                # so the stored param_dicts stay usable after construction.
                params = list(params)
            return [{'params': params, 'lr': stepsize_cross_entropy}]

        has_parameters = hasattr(self.model, 'parameters')
        if optimizer and has_parameters and not hasattr(
                self.model, "update"):
            assert optimizer.lower() in optimizers.keys(
            ), "demodulator optimizer=%s not supported" % optimizer
            optimizer_cls = optimizers[optimizer.lower()]
            if verbose:
                print("Demodulator %s initialized with %s optimizer." %
                      (self.model.name, optimizer_cls.__name__))
            self.param_dicts = build_param_dicts()
            self.optimizer = optimizer_cls(self.param_dicts,
                                           lr=stepsize_cross_entropy)
        else:
            if verbose:
                print("Demodulator %s initialized WITHOUT an optimizer" %
                      (self.model.name))
            self.optimizer = None
            self.param_dicts = build_param_dicts() if has_parameters else []
Exemplo n.º 4
0
def test(
    *,
    agent1,
    agent2,
    bits_per_symbol,
    test_SNR_db,
    signal_power=1.0,
    test_batch_size=100000,
):
    """Return the mean round-trip bit error rate between two agents.

    A random preamble of ``test_batch_size`` symbols makes two round trips,
    agent1 -> agent2 -> agent1 and agent2 -> agent1 -> agent2. Each trip uses
    'exploit' modulation, and AWGN at ``test_SNR_db`` is added on every hop.
    The result is the average of the two round-trip BERs against the original
    preamble.
    """
    labels = np.random.randint(low=0,
                               high=2**bits_per_symbol,
                               size=[test_batch_size])
    preamble = integers_to_symbols(labels, bits_per_symbol=bits_per_symbol)

    def through_channel(signal):
        # One noisy hop at the test SNR.
        return add_awgn(signal, SNR_db=test_SNR_db, signal_power=signal_power)

    def echo_back(sent, relay, origin):
        # sent -> channel -> relay decodes and re-modulates -> channel -> origin decodes.
        relay_guess = relay.demod.demodulate(through_channel(sent))
        echoed = relay.mod.modulate(relay_guess,
                                    mode='exploit',
                                    dtype='cartesian')
        return origin.demod.demodulate(through_channel(echoed))

    # Both forward transmissions are modulated before any noise is drawn.
    sent_by_first = agent1.mod.modulate(preamble,
                                        mode='exploit',
                                        dtype='cartesian')
    sent_by_second = agent2.mod.modulate(preamble,
                                         mode='exploit',
                                         dtype='cartesian')

    roundtrip_first = echo_back(sent_by_first, relay=agent2, origin=agent1)
    roundtrip_second = echo_back(sent_by_second, relay=agent1, origin=agent2)

    ber_first = float(get_ber(preamble, roundtrip_first))
    ber_second = float(get_ber(preamble, roundtrip_second))
    return (ber_first + ber_second) / 2.0
Exemplo n.º 5
0
def test(
    *,
    agent1,
    agent2,
    bits_per_symbol,
    test_SNR_dbs,
    signal_power=1.0,
    test_batch_size=10000,
):
    """Measure the round-trip bit error rate at each SNR in ``test_SNR_dbs``.

    Agent1 modulates a random preamble of ``test_batch_size`` symbols
    ('exploit' mode) and sends it through AWGN. Agent2 demodulates it,
    re-modulates its guess and sends it back through AWGN for agent1 to
    decode. When ``agent2`` is given, the same round trip is also run in the
    opposite direction. When ``agent2`` is None, agent1 plays both roles.

    Returns:
        list[float]: one BER per entry of ``test_SNR_dbs``, in the same order.
        With two agents, each entry is the mean of the two directions.
    """
    integers = np.random.randint(low=0,
                                 high=2**bits_per_symbol,
                                 size=[test_batch_size])
    preamble = integers_to_symbols(integers, bits_per_symbol=bits_per_symbol)
    A = agent1
    B = agent2 if agent2 is not None else agent1
    # Forward modulations are SNR-independent, so compute them once.
    c_signal_forward = A.mod.modulate(preamble,
                                      mode='exploit',
                                      dtype='cartesian')
    _c_signal_forward = None
    if agent2 is not None:
        # The reverse direction is only used when there are two agents.
        _c_signal_forward = B.mod.modulate(preamble,
                                           mode='exploit',
                                           dtype='cartesian')
    test_bers = []
    for test_SNR_db in test_SNR_dbs:
        c_signal_forward_noisy = add_awgn(c_signal_forward,
                                          SNR_db=test_SNR_db,
                                          signal_power=signal_power)
        preamble_halftrip = B.demod.demodulate(c_signal_forward_noisy)
        c_signal_backward = B.mod.modulate(preamble_halftrip,
                                           mode='exploit',
                                           dtype='cartesian')
        c_signal_backward_noisy = add_awgn(c_signal_backward,
                                           SNR_db=test_SNR_db,
                                           signal_power=signal_power)
        preamble_roundtrip = A.demod.demodulate(c_signal_backward_noisy)
        ber = float(get_ber(preamble, preamble_roundtrip))
        if agent2 is not None:
            _c_signal_forward_noisy = add_awgn(_c_signal_forward,
                                               SNR_db=test_SNR_db,
                                               signal_power=signal_power)
            _preamble_halftrip = A.demod.demodulate(_c_signal_forward_noisy)
            _c_signal_backward = A.mod.modulate(_preamble_halftrip,
                                                mode='exploit',
                                                dtype='cartesian')
            _c_signal_backward_noisy = add_awgn(_c_signal_backward,
                                                SNR_db=test_SNR_db,
                                                signal_power=signal_power)
            _preamble_roundtrip = B.demod.demodulate(_c_signal_backward_noisy)
            ber = (ber + float(get_ber(preamble, _preamble_roundtrip))) / 2.0
        test_bers.append(ber)
    # Previously returned the placeholder "TODOTODOTODO", discarding all results.
    return test_bers
Exemplo n.º 6
0
def get_random_preamble(n, bits_per_symbol):
    """Return ``n`` uniformly random symbols for a ``bits_per_symbol`` alphabet.

    Integer labels are drawn from ``[0, 2**bits_per_symbol)`` with
    ``np.random.randint`` and converted with ``integers_to_symbols``.
    """
    num_symbols = 2 ** bits_per_symbol
    symbol_ids = np.random.randint(low=0, high=num_symbols, size=[n])
    preamble = integers_to_symbols(symbol_ids, bits_per_symbol)
    return preamble
Exemplo n.º 7
0
def animated_plot(results_file="/Users/caryn/Desktop/echo/experiments/private_"
                  "preamble/QPSK_poly_hyperparam_search/results/0.npy",
                  results=None):
    """Animate learned constellations and demodulation regions over training.

    Args:
        results_file: Path to a ``.npy`` results array. It is read only when
            ``results`` is None.
        results: Optional in-memory results sequence. Element 0 is a metadata
            mapping with ``num_agents`` and ``bits_per_symbol``. Each later
            element is a mapping with ``mod_std_<a>``, ``constellation_<a>``
            and ``demod_grid_<b>`` entries.

    Raises:
        ValueError: if the metadata's ``num_agents`` is not 1 or 2.
    """
    plt.ion()
    if results is None:
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file -- only load trusted result files.
        # Context manager so the file handle is closed once loaded.
        with open(results_file, 'rb') as f:
            results = np.load(f, allow_pickle=True)
    meta, results = results[0], results[1:]
    num_agents = meta['num_agents']
    bits_per_symbol = meta['bits_per_symbol']
    num_colors = 2**bits_per_symbol
    cm = plt.cm.viridis
    colors = [cm(int(x * cm.N / num_colors)) for x in range(num_colors)]
    grid = get_grid_2d()
    # Each (modulator a, demodulator b) pair gets two panels: the
    # constellation of a, then b's demodulation over the grid.
    if num_agents == 2:
        pairs = [(1, 2), (2, 1)]
        fig, axes = plt.subplots(nrows=2,
                                 ncols=2,
                                 sharex=True,
                                 sharey=True,
                                 figsize=(8, 8))
        axes_list = list(axes.flat)  # row-major
    elif num_agents == 1:
        pairs = [(1, 1)]
        fig, axes = plt.subplots(nrows=1,
                                 ncols=2,
                                 sharex=True,
                                 sharey=True,
                                 figsize=(8, 4))
        axes_list = list(axes)
    else:
        plt.ioff()
        raise ValueError(
            "animated_plot supports num_agents of 1 or 2, got %r" %
            (num_agents,))

    # Point labels are the same for every frame.
    symbol_labels = integers_to_symbols(np.arange(num_colors), bits_per_symbol)
    # NOTE(review): 9.210 is presumably the chi-square (2 dof) 99% quantile.
    # Matplotlib's Ellipse width/height are full axis lengths, so the drawn
    # region may be half the intended one -- confirm intent.
    ellipse_scale = np.sqrt(9.210)

    for result in results:
        panels = iter(axes_list)
        for a, b in pairs:
            i_std, q_std = result['mod_std_%i' % a]
            constellation = result['constellation_%i' % a]
            demod_labels = result['demod_grid_%i' % b]

            ax = next(panels)
            ax.scatter(constellation[:, 0], constellation[:, 1], c=colors)
            for i, center in enumerate(constellation):
                ax.add_artist(
                    Ellipse(xy=center,
                            width=ellipse_scale * i_std,
                            height=ellipse_scale * q_std,
                            color=colors[i],
                            alpha=.2))
            for label, (x, y) in zip(symbol_labels, constellation):
                ax.annotate(label,
                            xy=(x, y),
                            xytext=(0, 0),
                            textcoords='offset points')
            ax.set_xlim(-1.5, 1.5)
            ax.set_ylim(-1.5, 1.5)

            demod_ax = next(panels)
            demod_ax.scatter(grid[:, 0],
                             grid[:, 1],
                             c=[colors[i] for i in demod_labels])
            demod_ax.set_xlim(-1.5, 1.5)
            demod_ax.set_ylim(-1.5, 1.5)
        plt.pause(0.0000000001)
        for ax in axes_list:
            ax.clear()
    fig.canvas.draw()

    plt.ioff()
    plt.close(fig)
    return
Exemplo n.º 8
0
def train(*,
          agents,
          bits_per_symbol: int,
          batch_size: int,
          num_iterations: int,
          results_every: int,
          train_SNR_db: float,
          signal_power: float,
          early_stopping: bool = False,
          early_stopping_db_off: float = 1,
          verbose: bool = False,
          **kwargs
          ):
    """Train two agents with the shared-preamble echo protocol.

    In each iteration the current speaker (A) modulates a fresh random
    preamble. After the first iteration it also echoes back the preamble it
    decoded in the previous iteration. The listener (B) then updates its
    modulator from the round trip of its own earlier preamble. It also
    updates its demodulator on the new preamble, except on the extra final
    iteration, which only completes the last round trip. Roles are swapped
    at the end of every iteration.

    Every ``results_every`` iterations, and at the end, both agents are
    evaluated. With ``early_stopping``, training stops once every test SNR
    is within ``early_stopping_db_off`` dB of the value returned by
    ``BER_lookup_table().get_optimal_SNR_for_BER_roundtrip``.

    Returns:
        ``(info, results)``. ``info`` summarizes the run. ``results`` is the
        list of ``evaluate`` result dicts, each augmented with
        ``batches_sent`` and ``db_off``.
    """
    br = BER_lookup_table()
    early_stop = False
    integers_to_symbols_map = integers_to_symbols(np.arange(0, 2 ** bits_per_symbol), bits_per_symbol)
    if verbose:
        print("shared_preamble train.py")

    Amod = agents[0].mod
    Ademod = agents[0].demod
    Bmod = agents[1].mod
    Bdemod = agents[1].demod
    prev_preamble = None
    prev_actions = None
    batches_sent = 0
    results = []
    for i in range(num_iterations + 1):
        # A.mod(preamble) |               | B.demod(signal forward)      |==> B update demod     |
        #                 |--> channel -->|                              |                       |--> switch (A,B = B,A)
        # A.mod(pre-half) |               | A.demod(signal backward)     |==> B update mod       |
        integers = np.random.randint(low=0, high=2 ** bits_per_symbol, size=[batch_size])
        preamble = integers_to_symbols_map[integers]  # new shared preamble
        # A
        if prev_preamble is not None:
            # Echo what A (the listener last iteration) decoded then.
            c_signal_backward = Amod.modulate(preamble_halftrip, mode='explore', dtype='cartesian')
        c_signal_forward = Amod.modulate(preamble, mode='explore', dtype='cartesian')

        # Channel
        if prev_preamble is not None:
            c_signal_backward_noisy = add_awgn(c_signal_backward, SNR_db=train_SNR_db, signal_power=signal_power)
        c_signal_forward_noisy = add_awgn(c_signal_forward, SNR_db=train_SNR_db, signal_power=signal_power)

        if prev_preamble is not None:
            preamble_roundtrip = Bdemod.demodulate(c_signal_backward_noisy)
            # Update mod after a roundtrip pass
            Bmod.update(prev_preamble, prev_actions, preamble_roundtrip)
            batches_sent += 2

        # guess of new preamble
        preamble_halftrip = Bdemod.demodulate(c_signal_forward_noisy)
        # Update demod after a oneway pass
        if i < num_iterations:
            # the last iteration is just to complete the roundtrip, do not update halftrip
            Bdemod.update(c_signal_forward_noisy, preamble)

        prev_preamble, prev_actions = preamble, c_signal_forward

        # SWITCH
        Amod, Ademod, Bmod, Bdemod = Bmod, Bdemod, Amod, Ademod

        ############### STATS ##########################
        if i % results_every == 0 or i == num_iterations:
            if verbose:
                print("ITER %i: Train SNR_db:% 5.1f" % (i, train_SNR_db))

            result = evaluate(agent1=agents[0],
                              agent2=agents[1],
                              bits_per_symbol=bits_per_symbol,
                              signal_power=signal_power,
                              verbose=verbose or i == num_iterations,
                              total_iterations=num_iterations // results_every,
                              completed_iterations=i//results_every,
                              **kwargs)

            test_SNR_dbs = result['test_SNR_dbs']
            test_bers = result['test_bers']
            db_off_for_test_snr = [testSNR - br.get_optimal_SNR_for_BER_roundtrip(testBER, bits_per_symbol)
                                   for testSNR, testBER in zip(test_SNR_dbs, test_bers)]
            ###ADD TO RESULT
            result['batches_sent'] = batches_sent
            result['db_off'] = db_off_for_test_snr
            results += [result]
            if early_stopping and all(np.array(db_off_for_test_snr) <= early_stopping_db_off):
                print("STOPPED AT ITERATION: %i" % i)
                print(['0 BER', '1e-5 BER', '1e-4 BER', '1e-3 BER', '1e-2 BER', '1e-1 BER'])
                print("TEST SNR dBs : ", test_SNR_dbs)
                print("dB off Optimal : ", db_off_for_test_snr)
                # %g, not %d: the threshold is a float and %d truncated it (0.5 -> 0).
                print("Early Stopping dBs off: %g" % early_stopping_db_off)
                early_stop = True
                break
    info = {
        'bits_per_symbol': bits_per_symbol,
        'train_SNR_db': train_SNR_db,
        'num_results': len(results),
        'test_SNR_dbs': test_SNR_dbs,
        'early_stop': early_stop,
        'early_stop_threshold_db_off': early_stopping_db_off,
        'batch_size': batch_size,
        'num_agents': 2,
    }
    return info, results
Exemplo n.º 9
0
def train(*,
          agents,
          optimizer,
          bits_per_symbol: int,
          batch_size: int,
          num_iterations: int,
          results_every: int,
          train_SNR_db: float,
          signal_power: float,
          early_stopping: bool = False,
          early_stopping_db_off: float = .1,
          verbose: bool = False,
          **kwargs
          ):
    """Train a single agent end-to-end by backpropagating through the channel.

    Each iteration, ``agents[0]`` modulates a random preamble with
    ``modulate_tensor`` and additive Gaussian noise (from ``get_N0`` /
    ``get_awgn``) is applied. The demodulator's logits are then scored with
    cross-entropy against the true labels, plus both regularization losses.
    The total is backpropagated through demodulator and modulator with a
    single optimizer over both parameter groups.

    Args:
        optimizer: ``'adam'`` or ``'sgd'`` (case-insensitive), or a falsy
            value to skip parameter updates.

    Every ``results_every`` iterations, and on the final one, the agent is
    evaluated. With ``early_stopping``, training stops once every test SNR
    is within ``early_stopping_db_off`` dB of the value returned by
    ``BER_lookup_table().get_optimal_SNR_for_BER_roundtrip``.

    Returns:
        ``(info, results)``. ``info['test_SNR_dbs']`` is None if no evaluation
        ran (``num_iterations <= 0``).
    """
    br = BER_lookup_table()
    early_stop = False
    if verbose:
        print("gradient_passing train.py")

    A = agents[0]
    optimizers = {
        'adam': torch.optim.Adam,
        'sgd': torch.optim.SGD,
    }
    if optimizer:
        assert optimizer.lower() in optimizers.keys(), "modulator optimizer=%s not supported" % optimizer
        optimizer = optimizers[optimizer.lower()]
        print("gradient_passing initialized with %s optimizer." % optimizer.__name__)
        optimizer = optimizer(A.mod.get_param_dicts() + A.demod.get_param_dicts())
    else:
        print("gradient_passing initialized WITHOUT an optimizer")
        optimizer = None

    batches_sent = 0
    results = []
    # Defined up front so `info` can be built even if no evaluation ran.
    test_SNR_dbs = None
    loss_criterion = torch.nn.CrossEntropyLoss()
    # The noise level depends only on run-level settings; compute it once.
    N0 = get_N0(SNR_db=train_SNR_db, signal_power=signal_power)
    for i in range(num_iterations):
        is_last = i == num_iterations - 1
        preamble_labels = np.random.randint(low=0, high=2 ** bits_per_symbol, size=[batch_size])
        preamble = integers_to_symbols(preamble_labels, bits_per_symbol)
        preamble = torch.from_numpy(preamble).float()
        ##MODULATE/action
        c_signal_forward = A.mod.modulate_tensor(preamble)
        ##CHANNEL
        noise = torch.from_numpy(get_awgn(N0=N0, n=c_signal_forward.shape[0])).float()
        c_signal_forward_noisy = c_signal_forward + noise
        ##DEMODULATE/update and pass loss to mod
        preamble_halftrip_logits = A.demod.demodulate_tensor(c_signal_forward_noisy)
        # CrossEntropyLoss needs int64 class targets; randint's default integer
        # dtype is platform-dependent (int32 on some platforms).
        target = torch.from_numpy(preamble_labels).long()
        loss = loss_criterion(input=preamble_halftrip_logits.float(), target=target)
        loss += A.mod.get_regularization_loss()
        loss += A.demod.get_regularization_loss()

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batches_sent += 1

        ############### STATS ##########################
        if i % results_every == 0 or is_last:
            if verbose:
                print("ITER %i: Train SNR_db:% 5.1f" % (i, train_SNR_db))

            result = evaluate(agent1=agents[0],
                              bits_per_symbol=bits_per_symbol,
                              signal_power=signal_power,
                              # i never equals num_iterations inside range(num_iterations);
                              # the final iteration is num_iterations - 1.
                              verbose=verbose or is_last,
                              total_iterations=num_iterations // results_every,
                              completed_iterations=i // results_every,
                              **kwargs)

            test_SNR_dbs = result['test_SNR_dbs']
            test_bers = result['test_bers']
            db_off_for_test_snr = [testSNR - br.get_optimal_SNR_for_BER_roundtrip(testBER, bits_per_symbol)
                                   for testSNR, testBER in zip(test_SNR_dbs, test_bers)]
            ###ADD TO RESULT
            result['batches_sent'] = batches_sent
            result['db_off'] = db_off_for_test_snr
            results += [result]
            if early_stopping and all(np.array(db_off_for_test_snr) <= early_stopping_db_off):
                print("STOPPED AT ITERATION: %i" % i)
                print(['0 BER', '1e-5 BER', '1e-4 BER', '1e-3 BER', '1e-2 BER', '1e-1 BER'])
                print("TEST SNR dBs : ", test_SNR_dbs)
                print("dB off Optimal : ", db_off_for_test_snr)
                # %g, not %d: the default threshold .1 printed as 0 with %d.
                print("Early Stopping dBs off: %g" % early_stopping_db_off)
                early_stop = True
                break
    info = {
        'bits_per_symbol': bits_per_symbol,
        'train_SNR_db': train_SNR_db,
        'num_results': len(results),
        'test_SNR_dbs': test_SNR_dbs,
        'early_stop': early_stop,
        'early_stop_threshold_db_off': early_stopping_db_off,
        'batch_size': batch_size,
        'num_agents': 1,
    }
    return info, results