コード例 #1
0
def save(frame_num, img):
	'''
	Down-sample one frame and push it into the replay memory.

	1. resize `img` to 84x84 with cv.resize
	2. if `memory` already holds `replaymem_size` frames, evict the OLDEST one
	3. append the down-sampled frame and record `frame_num` in the global
	   `current_frame_num`

	The time spent is written to the info log.
	'''
	global current_frame_num

	start_time = time.clock()
	dst = cv.resize(img, dsize=(84,84))

	# Evict from the front so the memory behaves as a FIFO window. The old
	# `memory.pop()` removed the frame appended most recently, so the first
	# `replaymem_size - 1` frames were never replaced.
	# `>=` keeps the bound even if memory was ever over-filled, and
	# `del memory[0]` works for both a list and a collections.deque.
	# NOTE(review): code that indexes `memory` by absolute frame number must
	# treat it as a sliding window -- confirm against get_stacked_input.
	if len(memory) >= replaymem_size:
		del memory[0]

	memory.append(dst)

	logger.log_write_info('store one sample needs time ' + str(time.clock() - start_time))
	current_frame_num = frame_num
コード例 #2
0
def fill_memory_buffer():
	'''
	Sample `minibatch_length` transitions into the module-level minibatch buffers.

	For each row i, a state index is drawn uniformly until one whose recorded
	action is not -1 is found. Its action and reward are written in place to
	row i of `minibatch_actions` / `minibatch_rewards`. The stacked inputs for
	that index and the next one are appended to `minibatch_s0` / `minibatch_s1`.

	Rebinds the globals `memory_buffer`, `minibatch_s0` and `minibatch_s1`.
	'''
	global memory_buffer, current_frame_num, minibatch_actions, minibatch_rewards, minibatch_s0, minibatch_s1

	memory_buffer = []
	minibatch_s0 = []
	minibatch_s1 = []

	# `//` keeps the bound an int on Python 3 too, where `/` would yield a float.
	# The bound is also clipped to what has actually been recorded: an index
	# past `actions` or `rewards` used to be logged and then crash with
	# IndexError on the very next line.
	high = current_frame_num // 4
	available = min(len(actions), len(rewards))
	if high > available:
		logger.log_write_info('sample range ' + str(high) + ' clipped to ' + str(available)
			+ ', length of actions = ' + str(len(actions))
			+ ', length of rewards = ' + str(len(rewards)))
		high = available

	for i in range(minibatch_length):

		# Redraw until the sampled state has a recorded action (-1 is rejected).
		while True:
			state_index = np.random.randint(0, high)
			action = actions[state_index]
			if action != -1:
				break

		minibatch_actions[i, 0] = action
		minibatch_rewards[i, 0] = rewards[state_index]

		minibatch_s0.append(get_stacked_input(state_index))
		minibatch_s1.append(get_stacked_input(state_index + 1))
コード例 #3
0
def save(frame_num, img):
    '''
    Down-sample one frame and push it into the replay memory.

    1. resize `img` to 84x84 with cv.resize
    2. if `memory` already holds `replaymem_size` frames, evict the OLDEST one
    3. append the down-sampled frame and record `frame_num` in the global
       `current_frame_num`

    The time spent is written to the info log.
    '''
    global current_frame_num

    start_time = time.clock()
    dst = cv.resize(img, dsize=(84, 84))

    # Evict from the front so the memory behaves as a FIFO window. The old
    # `memory.pop()` removed the frame appended most recently, so the first
    # `replaymem_size - 1` frames were never replaced.
    # `>=` keeps the bound even if memory was ever over-filled, and
    # `del memory[0]` works for both a list and a collections.deque.
    # NOTE(review): code that indexes `memory` by absolute frame number must
    # treat it as a sliding window -- confirm against get_stacked_input.
    if len(memory) >= replaymem_size:
        del memory[0]

    memory.append(dst)

    logger.log_write_info('store one sample needs time ' +
                          str(time.clock() - start_time))
    current_frame_num = frame_num
コード例 #4
0
def fill_memory_buffer():
    '''
    Sample `minibatch_length` transitions into the module-level minibatch buffers.

    For each row i, a state index is drawn uniformly until one whose recorded
    action is not -1 is found. Its action and reward are written in place to
    row i of `minibatch_actions` / `minibatch_rewards`. The stacked inputs for
    that index and the next one are appended to `minibatch_s0` / `minibatch_s1`.

    Rebinds the globals `memory_buffer`, `minibatch_s0` and `minibatch_s1`.
    '''
    global memory_buffer, current_frame_num, minibatch_actions, minibatch_rewards, minibatch_s0, minibatch_s1

    memory_buffer = []
    minibatch_s0 = []
    minibatch_s1 = []

    # `//` keeps the bound an int on Python 3 too, where `/` would yield a float.
    # The bound is also clipped to what has actually been recorded: an index
    # past `actions` or `rewards` used to be logged and then crash with
    # IndexError on the very next line.
    high = current_frame_num // 4
    available = min(len(actions), len(rewards))
    if high > available:
        logger.log_write_info('sample range ' + str(high) +
                              ' clipped to ' + str(available) +
                              ', length of actions = ' + str(len(actions)) +
                              ', length of rewards = ' + str(len(rewards)))
        high = available

    for i in range(minibatch_length):

        # Redraw until the sampled state has a recorded action (-1 is rejected).
        while True:
            state_index = np.random.randint(0, high)
            action = actions[state_index]
            if action != -1:
                break

        minibatch_actions[i, 0] = action
        minibatch_rewards[i, 0] = rewards[state_index]

        minibatch_s0.append(get_stacked_input(state_index))
        minibatch_s1.append(get_stacked_input(state_index + 1))
コード例 #5
0
def save(frame_num, img, action, reward):
	'''
	Down-sample one frame into the current observation and record `frame_num`.

	The 84x84 resized frame is appended to `currnet_ovservation` only while
	that list holds at most `state_dimension` entries. `action` and `reward`
	are accepted for interface compatibility but are not used here.
	The global `current_frame_num` is set to `frame_num`.
	'''
	global current_frame_num
	start_time = time.clock()

	# NOTE(review): `<=` lets the list grow to state_dimension + 1 frames;
	# confirm whether `<` was intended.
	if len(currnet_ovservation) <= state_dimension:
		dst = cv.resize(img, dsize=(84,84))
		currnet_ovservation.append(dst)

	# Every other log_write_info call in this code passes a single pre-built
	# string; the elapsed time used to be passed as a stray second argument.
	logger.log_write_info('store one sample needs time ' + str(time.clock() - start_time))
	current_frame_num = frame_num
コード例 #6
0
def get_minibach_smaple():
	'''
	Refill the module-level minibatch buffers and return them.

	Calls fill_memory_buffer() and logs how long it took. Then returns
	(minibatch_s0, minibatch_actions, minibatch_rewards, minibatch_s1) as the
	module globals stand after the refill.
	'''
	began = time.clock()
	fill_memory_buffer()
	elapsed = time.clock() - began
	logger.log_write_info('get one minibatch states needs time ' + str(elapsed))

	batch = (minibatch_s0, minibatch_actions, minibatch_rewards, minibatch_s1)
	return batch
コード例 #7
0
	def get_minibach_smaple(self):
		"""
		Draw `self.minibatch_length` random transitions from `self.memory`.

		Transitions whose `action` is -1 are rejected and redrawn. For each row,
		the action and reward are written into column 0 of `minibatch_actions`
		and `minibatch_rewards`. These are names resolved outside this method,
		not attributes of self. observation_0 / observation_1 are collected
		into fresh lists.

		Returns (observation_0 list, minibatch_actions, minibatch_rewards,
		observation_1 list).
		"""
		s0_batch, s1_batch = [], []
		for row in range(self.minibatch_length):
			while True:
				picked = self.memory[np.random.randint(0, len(self.memory))]
				if picked.action != -1:
					break

			minibatch_actions[row, 0] = picked.action

			reward = picked.reward
			if reward != 0:
				logger.log_write_info(' picked a none-zero reward state !')
			minibatch_rewards[row, 0] = reward

			s0_batch.append(picked.observation_0)
			s1_batch.append(picked.observation_1)

		return s0_batch, minibatch_actions, minibatch_rewards, s1_batch
コード例 #8
0
def get_minibach_smaple():
    '''
    Refill the module-level minibatch buffers and return them.

    Calls fill_memory_buffer() and logs how long it took. Then returns
    (minibatch_s0, minibatch_actions, minibatch_rewards, minibatch_s1) as the
    module globals stand after the refill.
    '''
    began = time.clock()
    fill_memory_buffer()
    elapsed = time.clock() - began
    logger.log_write_info('get one minibatch states needs time ' +
                          str(elapsed))

    batch = (minibatch_s0, minibatch_actions, minibatch_rewards, minibatch_s1)
    return batch
コード例 #9
0
ファイル: dqn.py プロジェクト: MingyanZhao/dqn_battlecity
	def choose_action(self, current_screen, total_frame):
		"""
		Epsilon-greedy action selection.

		The exploration probability `self.random_action_porb` is first decayed
		by (random_pick_p_start - random_pick_p_end) / random_pick_peiriod per
		call, while it is above `random_pick_p_end`, and never drops below that
		floor. With that probability a random action from np.random.randint(0, 5)
		is returned as `[action]`. Otherwise the argmax over `self.model`'s
		output is evaluated, with `current_screen` fed into
		`self.current_screen`, and the list returned by `self.s.run` is returned.

		`total_frame` is currently unused.
		"""
		start_time = time.clock()

		#p_initial - (n * (p_initial - p_final)) / (total)
		if self.random_action_porb > self.random_pick_p_end:
			step = (self.random_pick_p_start - self.random_pick_p_end) / self.random_pick_peiriod
			# Clamp so the last decay step cannot overshoot below the floor.
			self.random_action_porb = max(self.random_pick_p_end, self.random_action_porb - step)

		logger.log_write_info('random_action_porb = ' + str(self.random_action_porb))

		if np.random.rand() < self.random_action_porb:
			nextaction = np.random.randint(0, 5)
			logger.log_write_debug(' dqn, choose action rondomly, need time ' + str(time.clock() - start_time))
			logger.log_write_info('random action ' + str(nextaction))
			return [nextaction]

		# Build the argmax op once and reuse it. In TF1 graph mode every call to
		# tf.argmax(self.model(...)) adds new nodes to the graph, so creating it
		# per call makes the graph (and each call) grow without bound.
		greedy_op = getattr(self, '_greedy_action_op', None)
		if greedy_op is None:
			greedy_op = tf.argmax(self.model(self.current_screen), 1)
			self._greedy_action_op = greedy_op

		result = self.s.run([greedy_op], {self.current_screen : current_screen})
		# Log the evaluated action rather than the symbolic tensor.
		logger.log_write_info('dqn select action ' + str(result[0]))
		logger.log_write_debug(' dqn, choose action by DQN, need time ' + str(time.clock() - start_time))
		return result
コード例 #10
0
def add_reward(r):
	'''Log that reward `r` is recorded at the current frame, then append it to `rewards`.'''
	parts = ('frame =', str(current_frame_num),
		', record reward ', str(r),
		', rewards recorded = ', str(len(rewards)))
	logger.log_write_info(''.join(parts))
	rewards.append(r)
コード例 #11
0
def add_action(a):
	'''Log that action `a` is recorded at the current frame, then append it to `actions`.'''
	parts = ('frame =', str(current_frame_num),
		', record action ', str(a),
		', actions recorded = ', str(len(actions)))
	logger.log_write_info(''.join(parts))
	actions.append(a)
コード例 #12
0
	def save(self,frame_num, img, action, reward):
		"""
		Feed one frame into the transition builder and store finished
		transitions in `self.memory`.

		Each frame is resized to 84x84 and written into channel
		`frame_num % self.state_length` of an observation buffer:

		* While `self.current_init != 4`, frames fill `self.current_observation`
		  and `self.current_init` is incremented.
		* Otherwise frames fill `self.following_observation`. When no
		  transition is pending, the frame opens a new `self.temp_state`
		  (a `State` whose observation_0 is the current observation).
		* While filling `following_observation`, the first non-zero reward and
		  the most recent action are kept.
		* On the frame where `frame_num % self.state_length ==
		  self.state_length - 1` the pending transition is completed.
		  Rewarded transitions are always stored. Zero-reward ones are stored
		  with probability 0.1 (skipped when np.random.rand() < 0.9). When the
		  memory holds `self.memory_size` items the oldest is evicted first.
		  Whether or not it was stored, the window then advances:
		  `following_observation` becomes `current_observation` and a new zero
		  buffer of shape (84, 84, 4) is started.
		"""
		resized_img = cv.resize(img, dsize=(84,84))
		slot = frame_num % self.state_length

		if self.temp_state is None:
			if self.current_init != 4:
				# Still filling the very first observation stack.
				logger.log_write_info('frame =' + str(frame_num)
				                      + ' recording current_observation no.'
				                      + str(slot))
				self.current_observation[:,:,slot] = resized_img
				self.current_init += 1
				return

			# First frame of a new following observation: open a pending transition.
			self.following_observation[:,:,slot] = resized_img
			if self.current_reward == 0:
				self.current_reward = reward
			self.current_action = action

			self.temp_state = State(observation_0=self.current_observation,
			                        frame=frame_num)
			logger.log_write_info('frame =' + str(frame_num)
			                      + 'current_observation done, NOT record action ' + str(action)
			                      + ', reward = ' + str(reward))
			return

		self.following_observation[:,:,slot] = resized_img
		if self.current_reward == 0:
			self.current_reward = reward
		self.current_action = action

		if slot != self.state_length - 1:
			logger.log_write_info('frame =' + str(frame_num)
			                      + ' recording following_observation no.' + str(slot))
			return

		# Last frame of the following observation: complete the pending transition.
		self.temp_state.observation_1 = self.following_observation
		self.temp_state.action = self.current_action
		self.temp_state.reward = self.current_reward

		# Keep every rewarded transition but only ~10% of zero-reward ones
		# (a zero-reward transition is skipped when np.random.rand() < 0.9).
		if self.temp_state.reward != 0 or np.random.rand() >= 0.9:
			# Evict the oldest transition (FIFO). The old `pop()` removed the
			# newest one, so the earliest transitions were never replaced.
			# `del self.memory[0]` works for both a list and a deque.
			if len(self.memory) >= self.memory_size:
				del self.memory[0]
			self.memory.append(self.temp_state)

			logger.log_write_info('frame = ' + str(frame_num)
			                      + ' State into memory, numbers recorded ' + str(len(self.memory))
			                      + ' action = ' + str(self.temp_state.action)
			                      + ', reward = ' + str(self.temp_state.reward))

		# Advance the window even when the transition was not stored. The old
		# early return left temp_state pending and current_observation stale,
		# so the next stored transition paired an outdated observation_0 with a
		# much later observation_1.
		self.temp_state = None
		self.current_observation = self.following_observation
		self.current_reward = 0
		self.current_action = -1
		self.following_observation = np.array(np.zeros((84,84,4)))
コード例 #13
0
ファイル: dqn.py プロジェクト: MingyanZhao/dqn_battlecity
	def dqn_training(self, db_manager, frame_num):
		"""
		Run one training step on a minibatch sampled from `db_manager`.

		1. Sample (s0, actions, rewards, s1) via db_manager.get_minibach_smaple().
		2. At frame 3200, restore './ckp/dqn' if './ckp' is a non-empty directory.
		3. Evaluate self.q on s0 and copy Q(s0, a) for each sampled action into
		   self.q_result. An action of -1 or None is replaced by a random one
		   from np.random.randint(0, 4).
		4. Run self.optimizer_op, feeding q_result, s1, actions and rewards.
		5. When frame_num > 4800 and frame_num % 20000 == 0, save a checkpoint.
		6. Log q_result, q_max, y and the training error.
		"""
		s0, actions, rewards, s1 = db_manager.get_minibach_smaple()

		# Only touch the filesystem on the frame where a restore can happen.
		# This used to list './ckp' on every step and raise if it was missing.
		if frame_num == 3200 and os.path.isdir('./ckp') and os.listdir('./ckp'):
			self.saver.restore(sess=self.s, save_path='./ckp/dqn')
			print('***************************check point loaded*******************************************')
			logger.log_write_info('***************************check point loaded*******************************************')

		q = self.s.run(self.q,{self.s0 : s0})

		for i, action_row in enumerate(actions):
			a = action_row[0]
			if a is None or a == -1:
				a = np.random.randint(low=0,high=4)
			# int() guards against float-typed action arrays: numpy refuses
			# float indices.
			self.q_result[i][0] = q[i][int(a)]

		# Only the values that are used are fetched; optimizer_op must be in
		# the fetch list so the training step runs.
		error, _, q_max, y = self.s.run(
				[
					self.error,
					self.optimizer_op,
					self.q_max,
					self.y
				],
				{
					self.q_real : self.q_result,
					self.s1     : s1,
					self.actions: actions,
					self.rewards: rewards
				})

		if frame_num > 4800 and frame_num % 20000 == 0:
			self.saver.save(sess=self.s,
			                save_path='./ckp/dqn',
			                latest_filename='latest_ckp')
			# Works on both Python 2 and 3, unlike the old `print` statement.
			print('check point saved ')
			logger.log_write_info('check point saved ')

		# The '%f' previously embedded in these messages was never substituted.
		logger.log_write_info('q_result ' + str(self.q_result))
		logger.log_write_info('q_max ' + str(q_max))
		logger.log_write_info('y ' + str(y))
		logger.log_write_info('training error  = ' + str(error))
コード例 #14
0
def add_reward(r):
    '''Log that reward `r` is recorded at the current frame, then append it to `rewards`.'''
    already_recorded = len(rewards)
    message = ''.join([
        'frame =', str(current_frame_num),
        ', record reward ', str(r),
        ', rewards recorded = ', str(already_recorded),
    ])
    logger.log_write_info(message)
    rewards.append(r)
コード例 #15
0
def add_action(a):
    '''Log that action `a` is recorded at the current frame, then append it to `actions`.'''
    frame_txt = str(current_frame_num)
    action_txt = str(a)
    count_txt = str(len(actions))
    logger.log_write_info(''.join([
        'frame =', frame_txt,
        ', record action ', action_txt,
        ', actions recorded = ', count_txt,
    ]))
    actions.append(a)
コード例 #16
0
	def save(self,frame_num, img, action, reward, terminal):
		"""
		Push one frame into the rolling 4-channel observation and store the
		resulting transition.

		- If `self.load_memory` is set, both memories are first restored from
		  the pickle at `self.pickle_name`, and the flag is cleared.
		- Frames with `terminal == 1` are only processed when
		  `frame_num % 8 == 0`.
		- The 84x84 resized frame is placed in front of the first three channels
		  of `self.current_observation` to form `self.following_observation`.
		- (current_observation, action, reward, following_observation,
		  terminal, frame_num) is appended to `self.memory_zero` when
		  `terminal == 1`, otherwise to `self.memory_nonzero`.
		- When the two memories together reach `self.memory_size`, one entry is
		  evicted: from memory_nonzero with probability 0.05, else from
		  memory_zero. On the first such occasion both memories are pickled to
		  `self.pickle_name`.
		- When `terminal == 0`, the current observation is reset to zeros.

		NOTE(review): despite the name, `terminal == 1` appears to mark ordinary
		frames and `terminal == 0` an episode end (observation reset) -- confirm
		with the caller.
		"""
		if self.load_memory:
			# NOTE(review): pickle.load can execute arbitrary code; only load
			# files this program wrote itself.
			with open(self.pickle_name, 'rb') as f:
				m = pickle.load(f)
			self.memory_nonzero = m['nz']
			# The dump below writes key 'z'; the old lookup of 'zz' always
			# raised KeyError. 'zz' is still accepted as a fallback.
			self.memory_zero = m['z'] if 'z' in m else m['zz']
			self.load_memory = False

		if terminal == 1 and frame_num % 8 != 0:
			return

		resized_img = np.reshape(cv.resize(img, dsize=(84,84)), (84,84,1))

		# Newest frame goes into channel 0; the oldest of the previous four is dropped.
		self.following_observation = np.append(resized_img, self.current_observation[:, :, :3], axis=2)

		transition = (self.current_observation, action, reward, self.following_observation, terminal, frame_num)
		if terminal == 1:
			self.memory_zero.append(transition)
		else:
			logger.log_write_info(' get a non zero state ' + str(terminal))
			self.memory_nonzero.append(transition)

		self.current_observation = self.following_observation

		# `>=` keeps the bound even if a loaded pickle was larger than memory_size.
		if len(self.memory_nonzero) + len(self.memory_zero) >= self.memory_size:
			# random.random must be CALLED. The old `random.random < 0.05`
			# compared the function object itself: always False on Python 2,
			# TypeError on Python 3. Fall back to the other memory if the
			# chosen one is empty.
			evict_nonzero = random.random() < 0.05
			if (evict_nonzero and self.memory_nonzero) or not self.memory_zero:
				self.memory_nonzero.popleft()
			else:
				self.memory_zero.popleft()
			if not self.pickle_saved:
				with open(self.pickle_name, 'wb') as f:
					pickle.dump({'nz':self.memory_nonzero, 'z':self.memory_zero}, f, protocol=2)
				print('pickle saved')
				self.pickle_saved = True

		if terminal == 0:
			# Same float64 (84, 84, 4) zero stack as before, built directly.
			self.current_observation = np.zeros((84,84,4))