-
Notifications
You must be signed in to change notification settings - Fork 0
/
rram.py
122 lines (91 loc) · 3.55 KB
/
rram.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import sys
import warnings

import numpy as np

from adc import ADC
class RRAM:
    """Model of an RRAM array that stores multi-bit integer weights and
    performs a bit-serial vector-matrix multiply read out through an ADC.

    Each weight is split into ``n_bit``-wide digits, most-significant digit
    first, and every digit is stored in one cell as a nominal conductance
    ``1/roff + glsb * digit`` (ranging from ``1/roff`` to ``1/ron``) with a
    log-normal variation of spread ``rvar`` applied in the log10 domain.
    """

    def __init__(self, gp):
        """Build the array from the global parameter object ``gp``.

        Reads ``gp.rram.{ron, roff, rvar, rp, von, voff, n_bit, size_x,
        size_y}`` and passes ``gp.mvm.active_rows`` on to the ADC.
        """
        # Global parameters
        self.gp = gp
        # RRAM device parameters
        self.ron = gp.rram.ron
        self.roff = gp.rram.roff
        self.rvar = gp.rram.rvar
        self.rp = gp.rram.rp
        # Voltage applied to a row whose current input bit is 1.
        self.vdiff = gp.rram.von - gp.rram.voff
        self.n_bit = gp.rram.n_bit
        levels = 2 ** self.n_bit - 1
        self.rlsb = (self.roff - self.ron) / levels
        # Conductance step between adjacent digit levels.
        self.glsb = (1 / self.ron - 1 / self.roff) / levels
        self.x = gp.rram.size_x
        self.y = gp.rram.size_y
        # Cell conductances (reciprocal resistances), shape (y, x). Zeroed
        # rather than left uninitialized so a read before any write is
        # deterministic.
        self.arr = np.zeros([self.y, self.x])
        # Digit stored in each cell, shape (y, x).
        self.dig_arr = np.zeros([self.y, self.x])
        # ADC used to digitize column currents.
        # NOTE(review): the meaning of each constructor argument is defined
        # in the adc module, which is not visible here.
        self.adc = ADC(gp, self.n_bit, gp.mvm.active_rows,
                       self.ron, self.roff, self.rvar, self.vdiff)
        # Energy accumulated over all ADC conversions performed by read().
        self.e_read = 0

    def _cells_per_weight(self, res):
        """Return ceil(res / n_bit): cells needed to hold one res-bit word."""
        # Exact integer ceiling; avoids float division (and Python 2
        # floor division silently breaking np.ceil).
        return int(-(-res // int(self.n_bit)))

    def write(self, weights, res):
        """Program integer weights into the array.

        Weight ``[r][c]`` is split into ``n_cell = ceil(res / n_bit)`` digits
        of ``n_bit`` bits, most-significant digit first, and stored in row
        ``r``, columns ``c*n_cell`` through ``(c+1)*n_cell - 1``. Cells not
        covered by any weight hold digit 0. Weights are cast with
        ``dtype=int``; negative values are stored as their low
        ``n_cell*n_bit`` two's-complement bits.

        Args:
            weights: 2-D array-like of integers. A 1-D input is treated as a
                single row.
            res: weight resolution in bits.

        Raises:
            Exception: if one weight needs more cells than the array has
                columns ("No weight splitting allowed").

        Weights that do not fit (row beyond ``y`` or word extending past
        column ``x``) are skipped, and a single ``UserWarning`` reports how
        many were dropped.
        """
        n_bit = int(self.n_bit)
        n_cell = self._cells_per_weight(res)
        if n_cell > self.x:
            raise Exception("No weight splitting allowed")
        w = np.atleast_2d(np.array(weights, dtype=int))
        digit_mask = 2 ** n_bit - 1

        # Digital representation of every weight, MSB digit first.
        self.dig_arr = np.zeros([self.y, self.x])
        skipped = 0
        for row, row_weights in enumerate(w):
            for col, value in enumerate(row_weights):
                start = col * n_cell
                # Reject exactly the weights the array cannot hold in full.
                if row >= self.y or start + n_cell > self.x:
                    skipped += 1
                    continue
                num = int(value)
                digits = [(num >> (n_bit * k)) & digit_mask
                          for k in range(n_cell)]
                self.dig_arr[row, start:start + n_cell] = digits[::-1]
        if skipped:
            warnings.warn("RRAM.write: %d weight(s) did not fit in the "
                          "%dx%d array and were skipped"
                          % (skipped, self.y, self.x))

        # Nominal conductance per cell, then device variation applied to
        # log10(resistance): 1/10**(log10(1/g) + n) == g / 10**n.
        g_nominal = 1 / self.roff + self.glsb * self.dig_arr
        noise = np.random.normal(0, self.rvar, size=g_nominal.shape)
        self.arr = g_nominal / 10 ** noise

    def read(self, ifmap, res):
        """Bit-serially multiply an input vector by the stored weights.

        For each of the ``res`` input bits (LSB first), the bit-plane scaled
        by ``vdiff`` is applied to the rows, and every column current
        ``np.dot(v, arr)`` is digitized with ``self.adc.convert`` and
        accumulated with weight ``2**bit``. Columns are then regrouped into
        ``n_cell``-wide words, most-significant cell first, matching the
        layout used by ``write``. Each conversion adds ``self.adc.energy`` to
        ``self.e_read``.

        Args:
            ifmap: integer input vector, presumably one value per row
                (length ``self.y``) — TODO confirm against callers.
            res: number of input bits processed; also the weight resolution
                used to group columns into words.

        Returns:
            Float ndarray of shape ``(1, self.x // n_cell)``, one value per
            stored word.
        """
        ifm = np.array(ifmap, dtype=int)
        n_bit = int(self.n_bit)

        # Per-column sum of ADC codes weighted by input-bit significance;
        # integer accumulation keeps the result exact.
        dout = np.zeros(self.x, dtype=np.int64)
        for bit in range(res):
            v = ((ifm >> bit) & 1) * self.vdiff
            i_out = np.dot(v, self.arr)
            for col in range(self.x):
                dout[col] += self.adc.convert(i_out[col]) << bit
                self.e_read += self.adc.energy

        # Concatenate each group of n_cell columns into one word (Horner
        # form of sum(cell << (n_cell - 1 - k) * n_bit)); '+' rather than
        # '|' because column sums may exceed n_bit bits.
        n_cell = self._cells_per_weight(res)
        num_words = self.x // n_cell
        out = np.zeros([1, num_words])
        for word in range(num_words):
            value = 0
            for cell in dout[word * n_cell:(word + 1) * n_cell]:
                value = (value << n_bit) + int(cell)
            out[0, word] = value
        return out
#def adc_old(self, i_in):
# #bits = np.ceil(self.n_bit + np.log2(gp.mvm.active_rows))
# i_lsb = self.vdiff*self.glsb
# return np.array(np.floor((i_in+i_lsb/2)/i_lsb),dtype=int)
#def graph(self):
# # Plot number line
# fig = plt.figure(1)
# ax = fig.add_subplot(111)
# irange = imax-imin
# #ax.set_xlim(imin-irange*0.1,imax+irange*0.1)
# ax.set_ylim(0, 10)
#
# plt.hlines(2, imin, imax)
# for i in i_list:
# plt.vlines(i, 2, 5)
# #plt.text(i, 3, '{:.2f}'.format(i*1e6), horizontalalignment='center')
# for i in adc.i_ref:
# plt.vlines(i, 2, 8)