def test_device_faults(debug_patched_networks):
    """Check that DeviceFaults pins enough crossbar cells to their stuck values.

    Each network yielded by ``debug_patched_networks`` is deep-copied twice and
    perturbed with ``NonIdeality.DeviceFaults``: once with
    ``lrs_proportion=0.5`` and once with ``hrs_proportion=0.25`` together with
    ``electroform_proportion=0.25``.  For crossbar 0 of each copy, the fraction
    of cells whose conductance is close to ``1 / r_on`` (first copy) or
    ``1 / r_off`` (second copy) of the corresponding device must be >= 0.25.

    NOTE(review): another ``test_device_faults`` is defined later in this
    module; as the later binding it replaces this one when the module is
    imported.
    """
    target_device = torch.device('cpu' if 'cpu' in memtorch.__version__ else 'cuda')
    for network in debug_patched_networks:
        # Copy perturbed with LRS faults only.
        faulty_lrs = apply_nonidealities(
            copy.deepcopy(network),
            non_idealities=[memtorch.bh.nonideality.NonIdeality.DeviceFaults],
            lrs_proportion=0.5,
            hrs_proportion=0,
            electroform_proportion=0)
        lrs_crossbar = faulty_lrs.layer.crossbars[0]
        lrs_actual = lrs_crossbar.conductance_matrix.float().to(target_device)
        r_on = np.vectorize(lambda cell: cell.r_on)(lrs_crossbar.devices)
        lrs_expected = torch.tensor(1 / r_on).float().to(target_device)
        lrs_hits = torch.isclose(lrs_actual, lrs_expected).view(-1)
        lrs_percentage = sum(lrs_hits).item() / lrs_actual.numel()

        # Copy perturbed with HRS and electroform faults.
        faulty_hrs = apply_nonidealities(
            copy.deepcopy(network),
            non_idealities=[memtorch.bh.nonideality.NonIdeality.DeviceFaults],
            lrs_proportion=0,
            hrs_proportion=0.25,
            electroform_proportion=0.25)
        hrs_crossbar = faulty_hrs.layer.crossbars[0]
        hrs_actual = hrs_crossbar.conductance_matrix.float().to(target_device)
        r_off = np.vectorize(lambda cell: cell.r_off)(hrs_crossbar.devices)
        hrs_expected = torch.tensor(1 / r_off).float().to(target_device)
        hrs_hits = torch.isclose(hrs_actual, hrs_expected).view(-1)
        hrs_percentage = sum(hrs_hits).item() / hrs_actual.numel()

        # To account for some stochasticity
        assert lrs_percentage >= 0.25
        assert hrs_percentage >= 0.25
def _stuck_fraction(patched_network, resistance_attr, device):
    """Return the fraction of crossbar-0 cells sitting at a stuck conductance.

    Args:
        patched_network: Network whose ``layer.crossbars[0]`` is inspected.
        resistance_attr: Name of the per-device resistance attribute to read
            (``"r_on"`` or ``"r_off"`` in this module).
        device: ``torch.device`` both tensors are moved to before comparing.

    Returns:
        Number of cells whose conductance is close (``torch.isclose``) to
        ``1 / <resistance_attr>`` of the matching device, divided by the total
        number of cells.
    """
    crossbar = patched_network.layer.crossbars[0]
    conductance = crossbar.conductance_matrix.float().to(device)
    resistance = np.vectorize(lambda dev: getattr(dev, resistance_attr))(
        crossbar.devices
    )
    expected = torch.tensor(1 / resistance).float().to(device)
    # Tensor.sum() counts the matches in one vectorized reduction; the builtin
    # sum() would iterate the flattened tensor one element at a time.
    return torch.isclose(conductance, expected).sum().item() / conductance.numel()


def test_device_faults(debug_patched_networks, tile_shape, quant_method):
    """Check that DeviceFaults pins enough crossbar cells to their stuck values.

    Networks are obtained from ``debug_patched_networks(tile_shape,
    quant_method)``.  Each one is deep-copied twice and perturbed with
    ``NonIdeality.DeviceFaults``: once with ``lrs_proportion=0.5`` and once
    with ``hrs_proportion=0.25`` plus ``electroform_proportion=0.25``.  In both
    copies at least 25% of crossbar 0 must sit at the corresponding stuck
    conductance (``1 / r_on`` and ``1 / r_off`` respectively).
    """
    device = torch.device("cpu" if "cpu" in memtorch.__version__ else "cuda")
    patched_networks = debug_patched_networks(tile_shape, quant_method)
    for patched_network in patched_networks:
        patched_network_lrs = apply_nonidealities(
            copy.deepcopy(patched_network),
            non_idealities=[memtorch.bh.nonideality.NonIdeality.DeviceFaults],
            lrs_proportion=0.5,
            hrs_proportion=0,
            electroform_proportion=0,
        )
        lrs_percentage = _stuck_fraction(patched_network_lrs, "r_on", device)

        patched_network_hrs = apply_nonidealities(
            copy.deepcopy(patched_network),
            non_idealities=[memtorch.bh.nonideality.NonIdeality.DeviceFaults],
            lrs_proportion=0,
            hrs_proportion=0.25,
            electroform_proportion=0.25,
        )
        hrs_percentage = _stuck_fraction(patched_network_hrs, "r_off", device)

        # Thresholds sit below the requested proportions to account for some
        # degree of stochasticity; separate asserts show which side failed.
        assert lrs_percentage >= 0.25
        assert hrs_percentage >= 0.25
def test_non_linear(debug_patched_networks, tile_shape, quant_method):
    """Smoke-test the NonLinear non-ideality followed by ``tune_``.

    For every network from ``debug_patched_networks(tile_shape, quant_method)``
    the NonLinear non-ideality is applied to a fresh deep copy twice -- first
    with explicit sweep parameters, then with ``simulate=True`` -- and the
    result is tuned with per-layer-class ``tune_kwargs``.  There are no
    assertions; the test passes when nothing raises.

    NOTE(review): another ``test_non_linear`` is defined later in this module;
    as the later binding it replaces this one when the module is imported.
    """

    def build_tune_kwargs():
        # Built anew for every tune_ call so no dict is shared between calls.
        return {
            "<class 'memtorch.mn.Conv1d.Conv1d'>": {
                "input_batch_size": 1,
                "input_shape": 2,
            },
            "<class 'memtorch.mn.Conv2d.Conv2d'>": {
                "input_batch_size": 1,
                "input_shape": 2,
            },
            "<class 'memtorch.mn.Conv3d.Conv3d'>": {
                "input_batch_size": 1,
                "input_shape": 2,
            },
            "<class 'memtorch.mn.Linear.Linear'>": {"input_shape": 2},
        }

    # Keyword sets for the two NonLinear configurations, in execution order.
    variants = (
        {
            "sweep_duration": 2,
            "sweep_voltage_signal_amplitude": 1,
            "sweep_voltage_signal_frequency": 0.5,
        },
        {"simulate": True},
    )
    for network in debug_patched_networks(tile_shape, quant_method):
        for options in variants:
            non_linear_network = apply_nonidealities(
                copy.deepcopy(network),
                non_idealities=[memtorch.bh.nonideality.NonIdeality.NonLinear],
                **options,
            )
            non_linear_network.tune_(tune_kwargs=build_tune_kwargs())
def test_non_linear(debug_patched_networks):
    """Smoke-test the NonLinear non-ideality followed by a default ``tune_()``.

    Each network in ``debug_patched_networks`` is deep-copied and perturbed
    with ``NonIdeality.NonLinear`` twice -- first with explicit sweep
    parameters, then with ``simulate=True`` -- and ``tune_()`` is called on
    each result.  There are no assertions; the test passes when nothing raises.

    NOTE(review): an earlier ``test_non_linear`` taking ``tile_shape`` and
    ``quant_method`` exists in this module; this later definition is the one
    that stays bound.  Other tests here call
    ``debug_patched_networks(tile_shape, quant_method)`` whereas this one
    iterates the fixture value directly -- confirm which form is intended.
    """
    # Keyword sets for the two NonLinear configurations, in execution order.
    variants = (
        {
            "sweep_duration": 2,
            "sweep_voltage_signal_amplitude": 1,
            "sweep_voltage_signal_frequency": 0.5,
        },
        {"simulate": True},
    )
    for network in debug_patched_networks:
        for options in variants:
            non_linear_network = apply_nonidealities(
                copy.deepcopy(network),
                non_idealities=[memtorch.bh.nonideality.NonIdeality.NonLinear],
                **options)
            non_linear_network.tune_()
def test_finite_conductance_states(
    debug_patched_networks, tile_shape, quant_method, conductance_states=5
):
    """Check FiniteConductanceStates against an evenly spaced reference grid.

    For every network from ``debug_patched_networks(tile_shape, quant_method)``
    the FiniteConductanceStates non-ideality is applied to a deep copy with
    ``conductance_states`` states.  The test then asserts that:

    * at least one unique value of the resulting crossbar-0 conductance matrix
      is close to one of ``conductance_states`` evenly spaced values between
      the minimum and maximum of the original crossbar-0 conductance matrix;
    * the conductance matrix keeps its shape.

    Args:
        debug_patched_networks: Factory called with ``(tile_shape,
            quant_method)`` that returns the networks to test.
        tile_shape: Forwarded to ``debug_patched_networks``.
        quant_method: Forwarded to ``debug_patched_networks``.
        conductance_states: Number of states, used both for the non-ideality
            and for the reference grid.
    """
    patched_networks = debug_patched_networks(tile_shape, quant_method)
    for patched_network in patched_networks:
        patched_network_finite_states = apply_nonidealities(
            copy.deepcopy(patched_network),
            non_idealities=[
                memtorch.bh.nonideality.NonIdeality.FiniteConductanceStates
            ],
            # Pass the parameter (previously a literal 5) so the applied state
            # count always matches the reference grid built below.
            conductance_states=conductance_states,
        )
        conductance_matrix = patched_network.layer.crossbars[0].conductance_matrix
        quantized_conductance_matrix = (
            patched_network_finite_states.layer.crossbars[0].conductance_matrix
        )
        quantized_conductance_matrix_unique = quantized_conductance_matrix.unique()
        valid_values = torch.linspace(
            conductance_matrix.min(),
            conductance_matrix.max(),
            conductance_states,
        ).float()
        # NOTE(review): `any` only requires a single quantized value to land on
        # the grid -- confirm whether every value was meant to be checked.
        assert any(
            bool(torch.isclose(quantized_conductance_matrix_unique, valid_value).any())
            for valid_value in valid_values
        )
        assert conductance_matrix.shape == quantized_conductance_matrix.shape
def test_model_endurance_retention_endurance(
    debug_patched_networks,
    tile_shape,
    operation_mode,
    temperature,
    x=1e4,
    p_lrs=None,
    stable_resistance_lrs=100,
    p_hrs=None,
    stable_resistance_hrs=1000,
    cell_size=None,
):
    """Smoke-test the Endurance non-ideality with ``model_endurance_retention``.

    For every network from ``debug_patched_networks(tile_shape, None)`` the
    Endurance non-ideality is applied to a deep copy using the given model
    parameters.  There are no assertions; the test passes when nothing raises.

    Args:
        debug_patched_networks: Factory called with ``(tile_shape, None)`` that
            returns the networks to test.
        tile_shape: Forwarded to ``debug_patched_networks``.
        operation_mode: Forwarded as ``operation_mode`` in the model kwargs.
        temperature: Forwarded as ``temperature`` in the model kwargs.
        x: Passed to ``apply_nonidealities`` as ``float(x)``.
        p_lrs: Forwarded as ``p_lrs``; defaults to ``[1, 0, 0, 0]``.
        stable_resistance_lrs: Forwarded as ``stable_resistance_lrs``.
        p_hrs: Forwarded as ``p_hrs``; defaults to ``[1, 0, 0, 0]``.
        stable_resistance_hrs: Forwarded as ``stable_resistance_hrs``.
        cell_size: Forwarded as ``cell_size``.
    """
    # Build the list defaults per call: a list literal in the signature would
    # be one shared object across all invocations of this test.
    if p_lrs is None:
        p_lrs = [1, 0, 0, 0]
    if p_hrs is None:
        p_hrs = [1, 0, 0, 0]
    patched_networks = debug_patched_networks(tile_shape, None)
    for patched_network in patched_networks:
        # The return value is not inspected.
        apply_nonidealities(
            copy.deepcopy(patched_network),
            non_idealities=[memtorch.bh.nonideality.NonIdeality.Endurance],
            x=float(x),
            endurance_model=memtorch.bh.nonideality.endurance_retention_models.model_endurance_retention,
            endurance_model_kwargs={
                "operation_mode": operation_mode,
                "p_lrs": p_lrs,
                "stable_resistance_lrs": stable_resistance_lrs,
                "p_hrs": p_hrs,
                "stable_resistance_hrs": stable_resistance_hrs,
                "cell_size": cell_size,
                "temperature": temperature,
            },
        )
def test_model_conductance_drift(
    debug_patched_networks,
    time,
    drift_coefficient,
    tile_shape=(128, 128),
    initial_time=1e-12,
):
    """Smoke-test the Retention non-ideality with ``model_conductance_drift``.

    For every network from ``debug_patched_networks(tile_shape, None)`` the
    Retention non-ideality is applied to a deep copy using the given drift
    parameters.  There are no assertions; the test passes when nothing raises.

    Args:
        debug_patched_networks: Factory called with ``(tile_shape, None)`` that
            returns the networks to test.
        time: Passed to ``apply_nonidealities`` as ``float(time)``.
        drift_coefficient: Forwarded as ``drift_coefficient`` in the model
            kwargs.
        tile_shape: Forwarded to ``debug_patched_networks``.
        initial_time: Forwarded as ``initial_time`` in the model kwargs.
    """
    patched_networks = debug_patched_networks(tile_shape, None)
    for patched_network in patched_networks:
        # The return value is not inspected.
        apply_nonidealities(
            copy.deepcopy(patched_network),
            non_idealities=[memtorch.bh.nonideality.NonIdeality.Retention],
            time=float(time),
            retention_model=memtorch.bh.nonideality.endurance_retention_models.model_conductance_drift,
            retention_model_kwargs={
                "initial_time": initial_time,
                "drift_coefficient": drift_coefficient,
            },
        )