-
Notifications
You must be signed in to change notification settings - Fork 0
/
agent.py
71 lines (55 loc) · 2.57 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from numpy import zeros, ones, array
from explauto.sensorimotor_model.nearest_neighbor import NearestNeighbor
from explauto.utils.config import make_configuration
from explauto.models.dmp import DmpPrimitive
from explauto.utils import bounds_min_max
from explauto import InterestModel
from explauto.agent import Agent
class DmpAgent(Agent):
    """Explauto agent whose motor primitive is a DMP trajectory.

    A motor command is bounded to the configured motor limits, turned into a
    trajectory by a ``DmpPrimitive``, and that trajectory is returned as the
    executed motor primitive.  The sensory primitive keeps only the last
    element of the observed sensory sequence.
    """

    def __init__(self, n_dmps, n_bfs, used, default, conf, sm, im,
                 dmp_type='discrete', ay=None):
        """Build the agent and its underlying ``DmpPrimitive``.

        :param n_dmps: number of DMPs, forwarded to ``DmpPrimitive``.
        :param n_bfs: number of basis functions per DMP, forwarded to
            ``DmpPrimitive``.
        :param used: mask forwarded unchanged to ``DmpPrimitive``
            (presumably marks which DMP parameters come from the motor
            command -- confirm against ``DmpPrimitive``).
        :param default: default DMP parameter vector, forwarded unchanged to
            ``DmpPrimitive``.
        :param conf: agent configuration, forwarded to ``Agent.__init__``.
        :param sm: sensorimotor model, forwarded to ``Agent.__init__``.
        :param im: interest model, forwarded to ``Agent.__init__``.
        :param dmp_type: forwarded to ``DmpPrimitive`` as ``type=``.
        :param ay: optional value forwarded to ``DmpPrimitive`` as ``ay=``;
            a scalar is broadcast to one value per DMP and a length-``n_dmps``
            sequence is used element-wise.  When ``None``, ``ay`` is not
            passed and ``DmpPrimitive`` uses its own default.
        """
        Agent.__init__(self, conf, sm, im)
        self.n_dmps, self.n_bfs = n_dmps, n_bfs
        # Final point of the most recent trajectory; zeros until the first
        # call to motor_primitive.
        self.current_m = zeros(self.conf.m_ndims)
        if ay is None:
            self.dmp = DmpPrimitive(n_dmps, n_bfs, used, default, type=dmp_type)
        else:
            # Honour the caller's ay: broadcasting through ones(n_dmps) turns
            # a scalar into one value per DMP and leaves a per-DMP array as is.
            self.dmp = DmpPrimitive(n_dmps, n_bfs, used, default,
                                    type=dmp_type, ay=ones(n_dmps) * ay)

    @classmethod
    def from_settings(cls, n_bfs, starting_position, babbling_name, sm_name,
                      im_name, **kwargs):
        """Alternate constructor building the agent from high-level settings.

        The core constructor arguments come from ``get_params``; any extra
        keyword arguments (e.g. ``dmp_type`` or ``ay``) are passed straight
        to the constructor.
        """
        params = get_params(n_bfs, starting_position, babbling_name, sm_name, im_name)
        params.update(kwargs)
        return cls(**params)

    def motor_primitive(self, m):
        """Turn motor command ``m`` into a DMP trajectory and return it.

        ``m`` is bounded to ``conf.m_mins``/``conf.m_maxs`` via
        ``bounds_min_max`` and stored in ``self.m``; the last row of the
        resulting trajectory is stored in ``self.current_m``.
        """
        self.m = bounds_min_max(m, self.conf.m_mins, self.conf.m_maxs)
        y = self.dmp.trajectory(self.m)
        self.current_m = y[-1, :]
        return y

    def sensory_primitive(self, s):
        """Reduce the observed sensory sequence ``s`` to its last element."""
        return s[-1]
# Registry of sensorimotor models selectable by name in get_params:
# name -> (model class, keyword arguments for its constructor).
sms = {}
sms['knn'] = (NearestNeighbor, {'sigma_ratio': 1.0 / 38})
def get_params(n_bfs, starting_position, babbling_name, sm_name, im_name,
               weight_bound=600, goal_range=180., s_mins=None, s_maxs=None):
    """Build the keyword arguments for ``DmpAgent`` from high-level settings.

    The DMP default vector has ``n_dmps * (n_bfs + 2)`` entries.  The first
    and last ``n_dmps`` entries are set to ``starting_position`` and the
    middle ``n_dmps * n_bfs`` entries are zero (presumably initial position |
    basis-function weights | goal -- confirm against ``DmpPrimitive``).  The
    motor space covers the middle block followed by the last block, matching
    the ``used`` mask.

    :param n_bfs: number of basis functions per DMP.
    :param starting_position: one value per DMP; its length sets ``n_dmps``.
    :param babbling_name: ``'motor'`` makes the interest model work on the
        motor dimensions; any other value selects the sensory dimensions.
    :param sm_name: key into the module-level ``sms`` registry.
    :param im_name: interest-model name passed to
        ``InterestModel.from_configuration``.
    :param weight_bound: each of the ``n_dmps * n_bfs`` leading motor
        dimensions is bounded to ``[-weight_bound, weight_bound]``.
    :param goal_range: the last ``n_dmps`` motor dimensions are bounded to
        ``starting_position -/+ goal_range``.
    :param s_mins: sensory lower bounds; default ``[-1., -0.7, -0.1]``.
    :param s_maxs: sensory upper bounds; default ``[1., 0.7, 0.7]``.
    :returns: dict with keys ``n_dmps``, ``n_bfs``, ``used``, ``default``,
        ``conf``, ``sm`` and ``im``, suitable for ``DmpAgent(**params)``.
    :raises KeyError: if ``sm_name`` is not registered in ``sms``.
    """
    n_dmps = len(starting_position)
    n_weights = n_dmps * n_bfs

    default = zeros(n_dmps * (n_bfs + 2))
    default[:n_dmps] = starting_position
    default[-n_dmps:] = starting_position

    # None defaults avoid sharing mutable lists between calls.
    if s_mins is None:
        s_mins = [-1., -0.7, -0.1]
    if s_maxs is None:
        s_maxs = [1., 0.7, 0.7]

    start = default[:n_dmps]
    poppy_ag = {'m_mins': [-weight_bound] * n_weights + list(start - goal_range),
                'm_maxs': [weight_bound] * n_weights + list(start + goal_range),
                's_mins': s_mins,
                's_maxs': s_maxs,
                }
    poppy_ag_conf = make_configuration(**poppy_ag)

    # Any value other than 'motor' puts the interest model on sensory dims.
    im_dims = poppy_ag_conf.m_dims if babbling_name == 'motor' else poppy_ag_conf.s_dims
    im = InterestModel.from_configuration(poppy_ag_conf, im_dims, im_name)

    try:
        sm_cls, kwargs = sms[sm_name]
    except KeyError:
        raise KeyError('Unknown sensorimotor model %r; available: %s'
                       % (sm_name, ', '.join(sorted(sms))))
    sm = sm_cls(poppy_ag_conf, **kwargs)

    # Mask over `default`: the first n_dmps entries are fixed, the rest
    # (the n_weights middle entries and the last n_dmps) are marked used.
    used = array([False] * n_dmps + [True] * n_weights + [True] * n_dmps)
    return {'n_dmps': n_dmps,
            'n_bfs': n_bfs,
            'used': used,
            'default': default,
            'conf': poppy_ag_conf,
            'sm': sm,
            'im': im
            }