Ejemplo n.º 1
0
def test_crossbar(shape):
    """Exercise Crossbar conductance writing/updating, simulated matmul and naive programming.

    Args:
        shape: Crossbar dimensions passed to ``memtorch.bh.crossbar.Crossbar`` and
            used to size the random weight and input tensors (presumably a
            ``(rows, columns)`` tuple from a pytest parametrization — TODO confirm).
    """
    device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
    memristor_model = memtorch.bh.memristor.LinearIonDrift
    memristor_model_params = {'time_series_resolution': 1e-3}
    time_series_resolution = memristor_model_params['time_series_resolution']
    crossbar = memtorch.bh.crossbar.Crossbar(memristor_model, memristor_model_params, shape)

    # One default-constructed device supplies the r_on / r_off bounds for the mapping.
    reference_device = memristor_model()
    conductance_matrix = naive_map(torch.zeros(shape).uniform_(0, 1),
                                   reference_device.r_on, reference_device.r_off,
                                   memtorch.bh.crossbar.Scheme.SingleColumn)
    crossbar.write_conductance_matrix(conductance_matrix)

    # NOTE(review): (3, 6, x, ...) > (3, 6) is True, so this branch also runs on
    # Python 3.6.x. If the intent was to skip 3.6, this should be `>= (3, 7)` — confirm.
    if sys.version_info > (3, 6):
        crossbar.update(from_devices=False, parallelize=True)
        # The crossbar stores the transpose of the mapped matrix.
        assert torch.all(torch.isclose(conductance_matrix.T, crossbar.conductance_matrix.cpu(), atol=1e-5))
        assert crossbar.devices[0][0].g == crossbar.conductance_matrix[0][0].item()

    crossbar.update(from_devices=False, parallelize=False)
    assert crossbar.devices[0][0].g == crossbar.conductance_matrix[0][0].item()

    # Simulated crossbar matmul should approximate the ideal matmul.
    inputs = torch.zeros(shape).uniform_(0, 1)
    assert torch.all(torch.isclose(simulate_matmul(inputs, crossbar).float(),
                                   torch.matmul(inputs, conductance_matrix.T).float().to(device),
                                   rtol=1e-1))

    programming_signal = gen_programming_signal(1, 1e-2, 1e-2, 1, time_series_resolution)
    assert isinstance(programming_signal, tuple)
    # Pulse durations shorter than the time-series resolution are expected to be rejected.
    with pytest.raises(AssertionError):
        gen_programming_signal(1, 1e-4, 1e-4, 1, time_series_resolution)

    row, column = 0, 0
    target_device = crossbar.devices[row][column]
    conductance_to_program = random.uniform(1 / target_device.r_off, 1 / target_device.r_on)
    crossbar.devices = naive_program(crossbar, (row, column), conductance_to_program, rel_tol=0.01)
    # Re-read through crossbar.devices: naive_program's return value replaces the device grid.
    assert math.isclose(conductance_to_program, crossbar.devices[row][column].g, abs_tol=1e-4)
    # A negative target conductance is expected to be rejected.
    with pytest.raises(AssertionError):
        naive_program(crossbar, (row, column), -1)
Ejemplo n.º 2
0
def _get_submodule(root, dotted_name):
    """Return the submodule of ``root`` addressed by a dotted ``named_modules()`` name.

    ``named_modules()`` yields ``''`` for the root and dotted names such as
    ``'features.0'`` for nested layers; a single ``getattr`` only resolves
    top-level names, so each dotted component is followed in turn.
    """
    module = root
    if dotted_name:
        for attr in dotted_name.split('.'):
            module = getattr(module, attr)
    return module


def update_patched_model(patched_model, model):
    """Re-program the crossbars of ``patched_model`` from the weights of ``model``.

    For every ``memtorch.mn.Conv2d`` / ``memtorch.mn.Linear`` module in
    ``patched_model``, the module with the same qualified name in ``model`` is
    looked up. Its weights are mapped with ``naive_map`` using the DoubleColumn
    scheme. The positive and negative conductance matrices are written to
    ``crossbars[0]`` and ``crossbars[1]`` respectively. The patched module's
    ``weight.data`` is then set to the source module's weight tensor.

    Relies on module-level ``r_on`` / ``r_off`` values defined elsewhere in this
    file (not visible here).

    Args:
        patched_model: Model containing memtorch layers; updated in place.
        model: Reference model whose module names match those of ``patched_model``.

    Returns:
        ``patched_model`` (the same object, mutated in place).
    """
    # Materialize the module list up front so the loop does not depend on the
    # module tree staying unchanged while crossbars are rewritten.
    for name, module in list(patched_model.named_modules()):
        if not isinstance(module, (memtorch.mn.Conv2d, memtorch.mn.Linear)):
            continue
        source_weight = _get_submodule(model, name).weight.data
        pos_conductance_matrix, neg_conductance_matrix = naive_map(
            source_weight, r_on, r_off, scheme=memtorch.bh.Scheme.DoubleColumn)
        module.crossbars[0].write_conductance_matrix(
            pos_conductance_matrix, transistor=True, programming_routine=None)
        module.crossbars[1].write_conductance_matrix(
            neg_conductance_matrix, transistor=True, programming_routine=None)
        module.weight.data = source_weight

    return patched_model