예제 #1
0
 def act(self, stochastic, input_):
     """Pick an action for ``input_`` and return it with the value estimate.

     Runs ``self.forward(input_)``, which yields ``(value, logits)``.

     Args:
         stochastic: When truthy, the action is sampled from
             ``softmax(logits)``; otherwise the arg-max over the last
             axis of ``logits`` is taken and cast to ``int32``.
         input_: Passed unchanged to ``self.forward``.

     Returns:
         Tuple ``(action, value)``.
     """
     value, logits = self.forward(input_)
     if not stochastic:
         # Greedy choice: highest-scoring entry along the last axis.
         return nd.argmax(logits, axis=-1).astype('int32'), value
     # Exploratory choice: turn the scores into probabilities, then sample.
     return nd.sample_multinomial(nd.softmax(logits)), value
    def store_samples(self, data, y, query_network, store_prob, context):
        """Randomly select samples from a batch and add them to the memory.

        The memory consists of ``self.key_memory`` (keys computed by
        ``query_network``), ``self.value_memory`` (one array per input) and
        ``self.label_memory`` (one array per output). Everything written to
        the memory is moved to ``mx.cpu()``.

        Nothing is stored when the replacement strategy is
        ``"no_replacement"``, a limit is set (``max_stored_samples != -1``)
        and the key memory already holds at least that many rows.

        Args:
            data: Indexed as ``data[j][0][i]`` where ``j`` runs over
                ``len(data)`` partitions and ``i`` over ``len(data[0][0])``
                inputs (presumably one partition per device -- TODO confirm
                against the caller).
            y: Indexed as ``y[i][j]``: output ``i`` for partition ``j``.
            query_network: Callable invoked with the first
                ``self.query_net_num_inputs`` gathered inputs; its result is
                stored as the keys.
            store_prob: Distribution handed to ``nd.sample_multinomial`` to
                draw one value per sample; the draws are used as a boolean
                mask (assumes a two-class distribution where index 1 means
                "store" -- TODO confirm).
            context: Sequence of contexts; only ``context[0]`` is used.
        """
        if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples):
            num_pus = len(data)
            # Number of samples held by each partition (leading dimension of
            # its first input).
            sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)]
            num_inputs = len(data[0][0])
            num_outputs = len(y)
            mx_context = context[0]

            # First call with an empty memory: reset all three stores.
            if len(self.key_memory) == 0:
                self.key_memory = nd.empty(0, ctx=mx.cpu())
                self.value_memory = []
                self.label_memory = []  # filled with one array per output on first store

            # One draw per sample and partition; used below as a selection mask.
            ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]

            # A partition contributes only if its mask has a non-zero entry.
            max_inds = [nd.max(ind[i]) for i in range(num_pus)]
            if any(max_inds):
                # Gather the selected rows of every input across partitions.
                to_store_values = []
                for i in range(num_inputs):
                    # An empty list marks "nothing gathered yet"; it is
                    # replaced by an NDArray on the first contributing
                    # partition and concatenated to afterwards.
                    tmp_values = []
                    for j in range(0, num_pus):
                        if max_inds[j]:
                            if isinstance(tmp_values, list):
                                tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
                            else:
                                tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
                    to_store_values.append(tmp_values)

                # Same gathering for every output, using the same masks.
                to_store_labels = []
                for i in range(num_outputs):
                    tmp_labels = []
                    for j in range(0, num_pus):
                        if max_inds[j]:
                            if isinstance(tmp_labels, list):
                                tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
                            else:
                                tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
                    to_store_labels.append(tmp_labels)

                # Keys are computed from the leading query-network inputs only.
                to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])

                if self.key_memory.shape[0] == 0:
                    # Empty memory: the new samples become the memory.
                    self.key_memory = to_store_keys.as_in_context(mx.cpu())
                    for i in range(num_inputs):
                        self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
                    for i in range(num_outputs):
                        self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
                elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
                    # Memory at or above the limit: drop as many leading rows
                    # as are being added, then append the new rows.
                    num_to_store = to_store_keys.shape[0]
                    self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
                    for i in range(num_inputs):
                        self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
                    for i in range(num_outputs):
                        self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
                else:
                    # Otherwise simply append. NOTE(review): this branch does
                    # not check the limit, so a single call can push the
                    # memory past max_stored_samples -- confirm this is intended.
                    self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
                    for i in range(num_inputs):
                        self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
                    for i in range(num_outputs):
                        self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
예제 #3
0
 def choose_action(self, state):
     """Sample an action from the actor network's output for ``state``.

     ``state`` is wrapped in a list to form a batch of one and placed on
     ``self.ctx`` before being fed to ``self.actor_network``.

     Returns:
         Tuple ``(action, action_prob)``: the sampled index as a Python
         ``int`` and, as a NumPy array, the network output picked at that
         index along axis 1.
     """
     batch = nd.array([state], ctx=self.ctx)
     probs = self.actor_network(batch)
     sampled = nd.sample_multinomial(probs)
     # Look up the network output belonging to the sampled index.
     chosen_prob = nd.pick(probs, sampled, axis=1).asnumpy()
     return int(sampled.asnumpy()), chosen_prob
예제 #4
0
파일: dqn.py 프로젝트: JincanDeng/RL_mx
 def select_action(self, state, is_train=True):
     """Select an action for ``state`` from the Q-network ``self.model``.

     Args:
         state: NDArray fed to ``self.model`` after being moved to
             ``self.ctx``.
         is_train: When true, use epsilon-greedy exploration and anneal
             ``self.epsilon`` by one step. When false, act greedily and
             leave ``self.epsilon`` untouched.

     Returns:
         NDArray holding the chosen action index per row of the Q-values.
     """
     if not is_train:
         # Greedy evaluation. Mirrors the training path: same device
         # placement and a per-row arg-max (without ``axis`` the arg-max
         # is taken over the flattened array, which is only correct for a
         # batch of one).
         Q_value = self.model(state.as_in_context(self.ctx))
         return nd.argmax(Q_value, axis=1)

     # Epsilon-greedy. ``autograd.record()`` is kept from the original
     # because it also runs the model in training mode; removing it could
     # change the forward pass of layers such as dropout.
     with autograd.record():
         Q_value = self.model(state.as_in_context(self.ctx))
         action = nd.argmax(Q_value, axis=1)
         if nd.random.uniform(0, 1)[0] < self.epsilon:
             # Explore: draw uniformly over the action space.
             action = nd.sample_multinomial(nd.ones_like(Q_value) / self.action_space.n)

     # Linear annealing, floored at final_epsilon so the exploration rate
     # never drifts below its configured minimum (or goes negative).
     step = (self.init_epsilon - self.final_epsilon) / self.replay_size
     self.epsilon = max(self.final_epsilon, self.epsilon - step)
     return action
 def choose_action(self, state):
     """Sample an action index from ``self.network``'s output for ``state``.

     ``state`` is wrapped in a list to form a batch of one and placed on
     ``self.ctx``. The sampled index is returned as a Python ``int``.
     """
     batch = nd.array([state], ctx=self.ctx)
     sampled = nd.sample_multinomial(self.network(batch))
     return int(sampled.asnumpy())
예제 #6
0
# Deals with only one random variable

import mxnet as mx
from mxnet import nd
import matplotlib
from matplotlib import pyplot as plt

num = 3000

# A fair six-sided die: each face has probability 1/6.
probabilities = nd.ones(6) / 6
rolls = nd.sample_multinomial(probabilities, shape=num)

# totals is the running tally per face; counts[:, i] is a snapshot of that
# tally taken right after roll i.
totals = nd.zeros(6)
counts = nd.zeros((6, num))

for i, roll in enumerate(rolls):
    face = int(roll.asscalar())
    totals[face] += 1
    counts[:, i] = totals

# Divide each snapshot by the number of rolls seen so far (1..num) to get
# the running relative frequency of every face.
x = nd.arange(num).reshape((1, num)) + 1
estimates = counts / x
# print(estimates[:, 0])
# print(estimates[:, 1])
# print(estimates[:, num - 1])

# One curve per face: the estimate as a function of the number of rolls.
plt.plot(estimates[0, :].asnumpy(), label="Estimated P(die=1)")
plt.plot(estimates[1, :].asnumpy(), label="Estimated P(die=2)")
예제 #7
0
# -*- coding: utf-8 -*-
# !pip install mxnet gluoncv

import mxnet as mx
from mxnet import nd
from matplotlib import pyplot as plt

# Model a fair six-sided die: every face is assumed to have an equal
# probability of 1/6. Samples drawn below are used to estimate those
# probabilities empirically.
probabilities = nd.ones(6) / 6
print(probabilities)
# Draw a single sample from the distribution (the result is discarded).
nd.sample_multinomial(probabilities)

# Several draws in one call: a vector of 10, then a 5 x 10 grid.
print(nd.sample_multinomial(probabilities, shape=10))
print(nd.sample_multinomial(probabilities, shape=(5, 10)))

# Roll the die a thousand times.
rolls = nd.sample_multinomial(probabilities, shape=1000)

# totals is the running tally per face; counts[:, i] records that tally
# right after roll i, so each column shows how often every face has come
# up so far.
totals = nd.zeros(6)
counts = nd.zeros((6, 1000))
for i, roll in enumerate(rolls):
    face = int(roll.asscalar())
    totals[face] += 1
    counts[:, i] = totals

# total probability for each.
예제 #8
0
 def sample(self, logits):
     """Draw category indices by passing ``logits`` to ``nd.sample_multinomial``.

     Args:
         logits: Forwarded unchanged as the sampling distribution.

     Returns:
         The NDArray of sampled indices.
     """
     # NOTE(review): ``nd.sample_multinomial`` is documented to take
     # probabilities that sum to one along the last axis. If callers pass
     # unnormalised scores here, a softmax (or the Gumbel-max form
     # ``argmax(logits - log(-log(u)))`` with ``u`` uniform) is needed
     # first -- confirm what the caller supplies.
     drawn = nd.sample_multinomial(logits)
     return drawn