-
Notifications
You must be signed in to change notification settings - Fork 0
/
stateless_model.py
77 lines (71 loc) · 2.63 KB
/
stateless_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
import chainer_extensions as E
class StatelessModel(chainer.Chain):
    """Stateless convolutional drawing model.

    Maps a 3-channel input (channel 0: canvas, 1: reference image,
    2: previous pen-position map) to a pen-position distribution, stamps a
    pen stroke onto the canvas, and returns a scalar loss that combines the
    squared reconstruction error with a pen-movement cost.
    """

    def __init__(self):
        super().__init__(
            l1=L.Convolution2D(None, 8, 5, pad=2),
            l2=L.Convolution2D(None, 8, 5, pad=2),
            l3=L.Convolution2D(None, 8, 5, pad=2),
            lout=L.Convolution2D(None, 1, 3, pad=1),
            l_str=L.Linear(None, 1),
        )
        # Fixed 3x3 pen footprint: soft edges (0.1) with a heavy centre (0.9).
        pen = np.zeros((3, 3), np.float32)
        pen += 0.1
        pen[1, 1] = 0.9
        self.pen = pen.reshape((1, 1, 3, 3))
        # 5x5 movement-cost kernel: cost rises with distance from the centre
        # cell; the centre itself is penalised hardest so the pen keeps moving.
        cost = np.zeros((5, 5), np.float32)
        for row in range(5):
            for col in range(5):
                cost[row, col] = -5 + np.sqrt((row - 2) ** 2 + (col - 2) ** 2)
        cost[2, 2] = 4.0  # don't stop!
        self.move_cost = cost.reshape((1, 1, 5, 5))
        # Gumbel-softmax temperature used when sampling the pen position.
        self.tau = 1.0

    def calc(self, x):
        """Run the conv trunk; set ``self.strength`` and return position logits.

        Side effect: ``self.strength`` holds the stroke strength in (0, 1),
        reshaped to (batch, 1, 1) for later broadcasting over the canvas.
        """
        h = x
        for conv in (self.l1, self.l2, self.l3):
            h = F.tanh(conv(h))
        self.strength = F.reshape(F.sigmoid(self.l_str(h)), (-1, 1, 1))
        return self.lout(h)

    def __call__(self, x):
        """Compute the loss for one drawing step.

        Input channels: 0 = canvas, 1 = reference, 2 = previous pen map.
        Returns ``self.loss`` (reconstruction error + movement cost).
        """
        logits = self.calc(x)  # (b, 1, h, w)
        shape = logits.shape
        b = shape[0]
        flat = F.reshape(logits, (b, -1))
        # Differentiable (near-)one-hot sample over all pixel positions.
        onehot = E.gumbel_softmax(flat, tau=self.tau)
        self.current_pos = F.reshape(onehot, (b, 1) + shape[2:])  # pen position
        # Movement cost: correlate the previous pen map with the cost kernel,
        # shift by +4 so the minimum is non-negative, weight by the new
        # position, and halve.
        mv_cost = F.sum(
            0.5 * self.current_pos * (
                F.convolution_2d(
                    x[:, 2:3, :, :], self.move_cost, pad=2) + 4))
        # Stamp the pen shape at the sampled position, scaled by strength.
        stamped = F.convolution_2d(self.current_pos, self.pen, pad=1)
        strength, stroke = F.broadcast(self.strength, stamped[:, 0, :, :])
        self.draw = strength * stroke
        canvas = x[:, 0, :, :] + self.draw
        # NOTE(review): only batch element 0 is clipped/stored here, while the
        # loss below uses the full unclipped batch — presumably intentional
        # (self.canvas looks like a visualization hook); confirm with callers.
        self.canvas = E.leaky_clip(canvas[0, :, :], 0., 1., leak=0.001)
        ref = x[:, 1, :, :]
        diff = F.sum((canvas - ref) ** 2)
        self.loss = diff + mv_cost
        return self.loss