async def get_rules(self, msg, arg=None):
    """
    Return the rules stored for the server as a list of embeds.

    The list starts with a header embed, followed by one embed per entry stored under the
    ``'rules'`` key of the server config. When ``arg`` is ``p``, ``perm`` or ``permanent``
    (case-insensitive), the rule embeds are sent directly to the channel through the client
    and only a short header embed is returned.
    :param msg: discord.Message
    :param arg: string
    :return: list of Embed
    """
    key = 'rules'
    header = Embed(self.data, msg, js={
        'title': 'Got server rules:',
        'description': 'The following are the rules.'
    })
    full_embeds = [header]

    config = self.data.servers[msg.server.id].config
    stored_rules = config[key] if key in config else None
    if stored_rules is not None:
        full_embeds = full_embeds + [Embed(self.data, msg, js=rule) for rule in stored_rules]

    # Anything other than the "permanent" switch returns the full list to the caller.
    if arg is None or arg.lower() not in ['p', 'perm', 'permanent']:
        return full_embeds

    # Permanent mode: send every rule embed (header excluded) directly via the client.
    for rule_embed in full_embeds[1:]:
        await self.data.client.send_message(msg.channel, '', embed=rule_embed.embed)
    return [Embed(self.data, msg, js={'title': 'Got server rules:'})]
async def set_rules(self, msg, arg):
    """
    Replace the rules stored for the server.

    ``arg`` is decoded with ``Utils.js_decoder``; a non-list result is wrapped in a list.
    The result overwrites the ``'rules'`` config entry and the config is written out.
    :param msg: discord.Message
    :param arg: value handed to Utils.js_decoder
    :return: list of Embed - a header followed by one embed per stored rule
    """
    key = 'rules'
    server = self.data.servers[msg.server.id]

    # Make sure the key exists before decoding, exactly as the config is left on a decode error.
    if not (key in server.config and server.config[key] is not None):
        server.config[key] = []

    rules = Utils.js_decoder(arg)
    if not isinstance(rules, list):
        rules = [rules]

    server.config[key] = rules
    server.write_config()

    header = Embed(self.data, msg, js={
        'title': 'Set server rules:',
        'description': 'The following are the rules.'
    })
    return [header] + [Embed(self.data, msg, js=rule) for rule in rules]
def embed_splitter(self, embed):
    """
    Split an oversized embed into several smaller embeds.

    An embed whose ``str(to_dict())`` is not longer than 1950 characters (or whose inner
    embed is ``discord.Embed.Empty``) is returned unchanged as a single-item list.
    Otherwise the description is emitted first (split with ``self.text_splitter`` when it
    is itself too long), followed by embeds that each carry a group of fields.
    :param embed: Embed wrapper exposing ``embed``, ``msg``, ``error``, ``developer``, ``help``
    :return: list of Embed
    """
    if embed.embed is not discord.Embed.Empty:
        js = embed.embed.to_dict()
    else:
        return [embed]
    limit = 1950
    if not len(str(js)) > limit:
        return [embed]

    if 'fields' in js:
        fields = copy.deepcopy(js['fields'])
    else:
        fields = []
    if 'description' in js:
        description = copy.deepcopy(js['description'])
    else:
        description = '__ __'

    # ``js`` becomes the template for every part: placeholder description, no fields.
    js['description'] = '__ __'
    js['fields'] = []
    embeds = []

    # Measure the saved description, not the placeholder just written into the template
    # (the placeholder is always short, so long descriptions were never split).
    if len(description) > limit:
        descriptions = self.text_splitter(description)
    else:
        descriptions = [copy.deepcopy(description)]
    for description in descriptions:
        njs = copy.deepcopy(js)
        njs['description'] = description
        embeds.append(
            Embed(self.data, embed.msg, js=njs, error=embed.error,
                  developer=embed.developer, help=embed.help)
        )

    nfields = []
    # Number of upcoming fields already packed into the current group.
    skip = 0
    for i, field in enumerate(fields):
        if skip:
            skip -= 1
            continue
        if len(str(field)) < limit:
            # Greedily pack this field and the following ones while the group stays
            # below ``limit - 500`` characters.
            while len(str(nfields)) < limit - 500:
                if len(fields) > i + skip:
                    if len(str(nfields) + str(fields[i+skip])) < limit-500:
                        nfields.append(fields[i+skip])
                        skip += 1
                    else:
                        skip -= 1
                        break
                else:
                    skip -= 1
                    break
        else:
            # A single field that is too large on its own gets its value replaced.
            field['value'] = 'Field value too long.'
            nfields.append(field)
        if len(nfields) > 0:
            njs = copy.deepcopy(js)
            njs['fields'] = nfields
            embeds.append(
                Embed(self.data, embed.msg, js=njs, error=embed.error,
                      developer=embed.developer, help=embed.help)
            )
            nfields = []
    return embeds
async def remove_join_role(self, msg, arg):
    """
    Remove a role from the server's ``join_roles`` config list.

    The role is resolved with ``Utils.finder``; the first stored entry whose ``'id'``
    matches it is removed, the config is written out and the remaining list is reported.
    :param msg: discord.Message
    :param arg: role id or name
    :return: Embed describing the remaining join roles, or an error embed
    """
    lookup = await Utils(self.data).finder(msg.server, arg, role=True)
    role = lookup[2]
    if role is None:
        return Utils(self.data).error_embed(msg, 'Could not find specified role: ({})'.format(arg))

    server = self.data.servers[msg.server.id]
    if 'join_roles' not in server.config:
        server.config['join_roles'] = []

    matches = [entry for entry in server.config['join_roles'] if entry['id'] == role.id]
    if matches:
        server.config['join_roles'].remove(matches[0])
    server.write_config()

    summary = 'Set bot join role to: ``{}``'.format(server.config['join_roles'])
    return Embed(self.data, msg, js={'title': 'Set bot join role:', 'description': summary})
async def get_custom_permission(self, msg, arg):
    """
    Report the custom permissions stored for a role or a user.

    A resolved role takes precedence over a resolved user. Missing ``custom_permissions``
    containers are created in memory before being read.
    :param msg: discord.Message
    :param arg: string - user or role
    :return: Embed listing the permissions, or an error embed when nothing was found
    """
    user, _channel, role = await Utils(self.data).finder(
        msg.server, arg, role=True, member=True)
    if user is None and role is None:
        return Utils(self.data).error_embed(
            msg, 'Could not find specified user or role: ({})'.format(arg))

    server = self.data.servers[msg.server.id]
    if role is None:
        # User permissions live in the per-user config loaded through the server.
        who = user
        user_class = server.load_user(user)
        if 'custom_permissions' not in user_class.config:
            user_class.config['custom_permissions'] = []
        current_permissions = user_class.config['custom_permissions']
    else:
        # Role permissions live in the server config, keyed by role id.
        who = role
        if 'custom_permissions' not in server.config:
            server.config['custom_permissions'] = {}
        if role.id not in server.config['custom_permissions']:
            server.config['custom_permissions'][role.id] = []
        current_permissions = server.config['custom_permissions'][role.id]

    origin = '**From:** {}:{} - {}\n\n'.format(who.name, who.id, who.mention)
    listing = '**Current permissions:** {}'.format(', '.join(current_permissions))
    return Embed(self.data, msg, js={
        'title': 'Get custom permission:',
        'description': origin + listing
    })
async def member_remove(self, member):
    """
    Discord.Client.on_member_remove().

    Sends a "member leave" alert to the logging channel returned by
    ``self.check_logging`` and stores that channel in the server config.
    :param member: one-element sequence holding the discord.Member
    :return:
    """
    member, = member
    key = 'member_leave_logging'
    go, channel = self.check_logging(key, member)
    if channel is None or not go:
        return

    server = self.data.servers[member.server.id]
    server.config[key]['channel'] = [channel.id, channel.name]
    server.write_config()

    # time.daylight only says whether the local zone defines DST at all, so indexing
    # tzname with it yields the DST name all year; pick by whether DST is in effect now.
    tz_name = time.tzname[1 if time.localtime().tm_isdst > 0 else 0]
    embed = Embed(self.data, member, js={
        'title': 'Detected member leave:',
        'description': 'user: {}:<@{}>\ndate: {}; {}'.format(
            member.id, member.id,
            Utils.get_full_time_string(datetime.datetime.now()),
            tz_name
        )
    })
    await self.data.client.send_message(channel, '', embed=embed.embed)
async def message(self, msg):
    """
    Discord.Client.on_message().

    Logs command use: messages starting with the server prefix are reported to the
    logging channel returned by ``self.check_logging``. The channel is stored in the
    in-memory server config; this handler does not write the config out.
    :param msg: one-element sequence holding the discord.Message
    :return:
    """
    msg, = msg
    # command logging
    if not msg.content.startswith(self.data.servers[msg.server.id].config['prefix']):
        return
    key = 'command_use_logging'
    go, channel = self.check_logging(key, msg.author)
    if channel is None or not go:
        return
    self.data.servers[msg.server.id].config[key]['channel'] = [channel.id, channel.name]

    # time.daylight only says whether the local zone defines DST at all, so indexing
    # tzname with it yields the DST name all year; pick by whether DST is in effect now.
    tz_name = time.tzname[1 if time.localtime().tm_isdst > 0 else 0]
    embed = Embed(self.data, msg, js={
        'title': 'Detected command use by: {}:<@{}>.'.format(msg.author.id, msg.author.id),
        'description': 'user: {}:<@{}>\nchannel: {}:<#{}>\ndate: {}; {}\n\nCommand: {}'.format(
            msg.author.id, msg.author.id,
            msg.channel.id, msg.channel.id,
            Utils.get_full_time_string(datetime.datetime.now()),
            tz_name,
            '``{}``'.format(msg.content)
        )
    })
    # send alert embed
    await self.data.client.send_message(channel, '', embed=embed.embed)
def __init__(self, input_size, hidden_size):
    """
    Build the encoder: an ``Embed`` layer followed by a single ``nn.LSTM``.

    :param input_size: first argument passed to ``Embed``
    :param hidden_size: second argument passed to ``Embed``; also the LSTM input and hidden size
    """
    super().__init__()
    self.hidden_size = hidden_size
    # ``Embed`` is used here in place of a plain ``nn.Embedding(input_size, hidden_size)``.
    self.embedding = Embed(input_size, hidden_size)
    self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size)
def error_embed(self, msg, error):
    """
    Build an embed titled ``Error:`` whose description is ``error``.

    :param msg: message object passed through to Embed
    :param error: text used as the embed description
    :return: Embed created with ``error=True``
    """
    payload = {'title': 'Error:', 'description': error}
    return Embed(self.data, msg, js=payload, error=True)
async def message_edit(self, tup):
    """
    Discord.Client.on_message_edit().

    Ignores edits that changed neither content, attachments nor embeds, and edits of
    messages that start with the server prefix. Otherwise sends an alert embed, the data
    of the message before the edit, a separator line and the data after the edit to the
    logging channel returned by ``self.check_logging``.
    :param tup: (discord.Message, discord.Message) - message before and after the edit
    :return:
    """
    before, after = tup
    if before.content == after.content and before.attachments == after.attachments and \
            before.embeds == after.embeds:
        return
    if before.content.startswith(self.data.servers[before.server.id].config['prefix']):
        return
    key = 'message_edit_logging'
    go, channel = self.check_logging(key, before.author)
    if channel is None or not go:
        return
    self.data.servers[before.server.id].config[key]['channel'] = [channel.id, channel.name]
    self.data.servers[before.server.id].write_config()

    # time.daylight only says whether the local zone defines DST at all, so indexing
    # tzname with it yields the DST name all year; pick by whether DST is in effect now.
    tz_name = time.tzname[1 if time.localtime().tm_isdst > 0 else 0]
    # prepare first embed
    embed = Embed(self.data, before, js={
        'title': 'Detected message edit by: {}:<@{}>.'.format(before.author.id, before.author.id),
        'description': 'user: {}:<@{}>\nchannel: {}:<#{}>\ndate: {}; {}'.format(
            before.author.id, before.author.id,
            before.channel.id, before.channel.id,
            Utils.get_full_time_string(datetime.datetime.now()),
            tz_name
        )
    })

    # download attachments before alerting
    before_attachments = [Utils.download_file(attachment) for attachment in before.attachments]
    after_attachments = [Utils.download_file(attachment) for attachment in after.attachments]

    # send alert embed
    await self.data.client.send_message(channel, '', embed=embed.embed)
    await self.send_message_data(before, before_attachments, channel)
    # send pause
    await self.data.client.send_message(channel, '----====----')
    await self.send_message_data(after, after_attachments, channel)
def reassign_ingredient():
    """
    Re-map an ingredient id inside a layout and optionally delete the old ingredient.

    Reads ``layoutName``, ``oldID``, ``newID`` and ``removeOld`` from the JSON request
    body. Responds with a success payload when the mapping was updated, otherwise with
    the ``Input01`` API error.
    """
    payload = simplejson.loads(request.body.read()) if request.body else {}
    layout_name = payload.get("layoutName")
    old_id = payload.get("oldID")
    new_id = payload.get("newID")
    remove_old = payload.get("removeOld")

    was_updated = False
    if layout_name and old_id and new_id:
        from Embed import Embed
        was_updated = Embed(db).edit_ingredient_mapping(layout_name, old_id, new_id)

    if not was_updated:
        return api_error('Input01', "I'm not angry, just disappointed")

    # The old ingredient row is deleted only after a successful re-mapping.
    if remove_old:
        db(db.ingredients.id == old_id).delete()
    return api_response(update="success")
def empty_embed(self, msg):
    """
    Build a response tuple that carries an embed created without any ``js`` content.

    :param msg: discord.Message
    :return: tuple (channel, '', Embed, datetime.now() advanced by 60 via Utils.add_seconds)
    """
    blank = Embed(self.data, msg)
    stamp = Utils.add_seconds(datetime.datetime.now(), 60)
    return (msg.channel, '', blank, stamp)
def __init__(self, NO, num_of_agents, emb_dim, hiden_dim,action_dim=5):
    """
    Set up hyper-parameters, bookkeeping buffers, networks and optimizer.

    :param NO: number (index) of this agent
    :param num_of_agents: number of agents
    :param emb_dim: first argument of both RL nets
    :param hiden_dim: third argument of both RL nets
    :param action_dim: second argument of both RL nets
    """
    # Identity.
    self.NO = NO
    self.num_of_agents = num_of_agents
    self.action_dim = action_dim

    # Hyper-parameters.
    self.epsilon = 0.9
    self.alpha = 0.1
    self.gamma = 0.9
    self.lr = 0.01

    # Training bookkeeping.
    self.iter_step = 0
    self.TARGET_UPDATE_STEP = 100
    self.loss_list = deque(maxlen=5000)  # show the loss
    self.acc_list = deque(maxlen=5000)  # show the acc

    # Networks are created in the same order as before: embedding, eval, target.
    self.emb_net = Embed(self.num_of_agents, self.NO)
    self.eval_net = RL_net(emb_dim, action_dim, hiden_dim)
    self.target_net = RL_net(emb_dim, action_dim, hiden_dim)

    # Only the eval net's parameters are handed to the optimizer.
    self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.lr)
    self.loss_func = nn.MSELoss(reduction='none')
async def show_access(self, msg):
    """
    Show the server's ``access`` config as embed fields.

    Creates (and writes out) an empty ``access`` entry when none exists yet.
    :param msg: discord.Message
    :return: tuple (channel, '', Embed, -1)
    """
    server = self.data.servers[msg.server.id]
    if 'access' not in server.config:
        server.config['access'] = {}
        server.write_config()

    access = server.config['access']
    if len(access) == 0:
        js = {'title': 'Access data:', 'description': 'Currently there is no access data.'}
    else:
        js = {'title': 'Access data:', 'fields': Utils.convert_json_to_fields(access)}
    return (msg.channel, '', Embed(self.data, msg, js=js), -1)
async def roles(self, msg):
    """
    List every role of the message's server with its id.

    :param msg: discord.Message
    :return: tuple (channel, '', Embed, -1)
    """
    fields = []
    for role in msg.server.roles:
        fields.append({'header': role.name, 'text': 'ID: {}'.format(role.id)})
    listing = Embed(self.data, msg, js={'title': 'Loaded roles:', 'fields': fields})
    return (msg.channel, '', listing, -1)
async def toggle(self, msg, args):
    """
    Toggle the author's access for every configured name in ``args``.

    For a ``channel`` entry the author's ``read_messages`` overwrite on that channel is
    flipped; for a ``role`` entry the role is added or removed. Names not present in the
    ``access`` config are skipped. An entry whose target cannot be resolved aborts with an
    error response; entries processed before it stay toggled.
    :param msg: discord.Message
    :param args: list of access names
    :return: tuple (channel, message, Embed, expiry)
    """
    if 'access' not in self.data.servers[msg.server.id].config:
        self.data.servers[msg.server.id].config['access'] = {}
        self.data.servers[msg.server.id].write_config()
    conf = self.data.servers[msg.server.id].config['access']

    def invalid_configuration(name):
        # Shared error response for an entry whose channel/role cannot be resolved.
        return (
            msg.channel, '',
            Utils(self.data).error_embed(
                msg,
                '**{}**\nInvalid configuration for ``{}`` - could not recognize channel or role.'
                .format(msg.content, name)
            ),
            Utils.add_seconds(datetime.datetime.now(), 60)
        )

    done = []
    for arg in args:
        if arg not in conf:
            continue
        member, channel, role = await Utils(self.data).finder(
            msg.server, conf[arg]['id'], channel=True, role=True
        )
        if conf[arg]['type'] == 'channel':
            if channel is None:
                return invalid_configuration(arg)
            overwrites = channel.overwrites_for(msg.author)
            overwrites.update(read_messages=(not overwrites.read_messages))
            await self.data.client.edit_channel_permissions(channel, msg.author, overwrites)
        if conf[arg]['type'] == 'role':
            # A missing role is always invalid here. Previously the error was raised only
            # when the channel lookup failed too, letting ``None`` reach add_roles().
            if role is None:
                return invalid_configuration(arg)
            if role in msg.author.roles:
                await self.data.client.remove_roles(msg.author, role)
            else:
                await self.data.client.add_roles(msg.author, role)
        done.append(arg)
    return (
        msg.channel, '',
        Embed(self.data, msg, js={'title': 'Toggled access:', 'description': ', '.join(done)}),
        -1
    )
async def channels(self, msg):
    """
    List every channel of the message's server with its id.

    :param msg: discord.Message
    :return: tuple (channel, '', Embed, -1)
    """
    fields = []
    for server_channel in msg.server.channels:
        fields.append({
            'header': server_channel.name,
            'text': 'ID: {}'.format(server_channel.id)
        })
    listing = Embed(self.data, msg, js={'title': 'Loaded channels:', 'fields': fields})
    return (msg.channel, '', listing, -1)
async def command_recognizer(self, msg, args):
    """
    Command recognizing loop.

    Each argument is resolved with ``Utils.find_subcommand``. ``help`` produces the
    default help embed (consuming the following argument as its topic when present); any
    other recognized sub-command is called as a method of this object and its result is
    wrapped in a "Loaded Data" embed. Unrecognized arguments yield ``self.arg_not_found``.
    :param msg: discord.Message
    :param args: list - list of strings
    :return: list of response tuples (channel, message, embed, expiry)
    """
    commands = []
    skip = 0
    for i, arg in enumerate(args):
        if skip:
            skip -= 1
            continue
        found, arg = Utils.find_subcommand(self, arg)
        if not found:
            # Command not found or not available for use.
            commands.append(self.arg_not_found(msg, arg))
            continue
        if arg == 'help':
            if len(args) > i + 1:
                help_embed = await Utils(self.data).default_help(self, msg, arg=args[i + 1])
                skip += 1
            else:
                help_embed = await Utils(self.data).default_help(self, msg)
            commands.append((msg.channel, '', help_embed, -1))
        else:
            # Dispatch by attribute lookup rather than eval() on a formatted string, so
            # the command text can never be executed as arbitrary code.
            loads = await getattr(self, arg)(msg)
            embed = Embed(self.data, msg, js={
                'title': 'Loaded Data:',
                'description': loads
            })
            commands.append((msg.channel, '', embed, -1))
    return commands
def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
    """
    Build the attention decoder layers.

    :param hidden_size: hidden width shared by the embedding, LSTM and linear layers
    :param output_size: first argument of ``Embed`` and width of the output projection
    :param dropout_p: probability passed to ``nn.Dropout``
    :param max_length: output width of the attention layer
    """
    super().__init__()
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.dropout_p = dropout_p
    self.max_length = max_length

    # ``Embed`` is used here in place of a plain ``nn.Embedding(output_size, hidden_size)``.
    self.embedding = Embed(output_size, hidden_size)
    # Both attention layers take an input of width ``hidden_size * 2``.
    doubled = hidden_size * 2
    self.attn = nn.Linear(doubled, max_length)
    self.attn_combine = nn.Linear(doubled, hidden_size)
    self.dropout = nn.Dropout(dropout_p)
    self.lstm = nn.LSTM(hidden_size, hidden_size)
    self.out = nn.Linear(hidden_size, output_size)
async def member_join(self, member):
    """
    Discord.Client.on_member_join().

    Gives the new member every resolvable role from the server's ``join_roles`` config
    and, when join logging is enabled, sends a "member join" alert to the logging channel.
    :param member: one-element sequence holding the discord.Member
    :return:
    """
    member, = member
    print("member join detected")
    key = 'member_join_logging'
    go, channel = self.check_logging(key, member)

    if 'join_roles' not in self.data.servers[member.server.id].config:
        print('no join roles')
        self.data.servers[member.server.id].config['join_roles'] = []
        self.data.servers[member.server.id].write_config()

    roles = []
    # Entries look like {'name': role.name, 'id': role.id}; unresolvable ones are skipped.
    for role in self.data.servers[member.server.id].config['join_roles']:
        role = (await Utils(self.data).finder(member.server, role["id"], role=True))[2]
        if role is None:
            continue
        roles.append(role)
    if len(roles) > 0:
        await self.data.client.add_roles(member, *roles)

    # Roles are assigned regardless of logging; everything below is logging only.
    if channel is None or not go:
        return
    self.data.servers[member.server.id].config[key]['channel'] = [channel.id, channel.name]
    self.data.servers[member.server.id].write_config()

    # time.daylight only says whether the local zone defines DST at all, so indexing
    # tzname with it yields the DST name all year; pick by whether DST is in effect now.
    tz_name = time.tzname[1 if time.localtime().tm_isdst > 0 else 0]
    embed = Embed(self.data, member, js={
        'title': 'Detected member join:',
        'description': 'user: {}:<@{}>\ndate: {}; {}'.format(
            member.id, member.id,
            Utils.get_full_time_string(datetime.datetime.now()),
            tz_name
        )
    })
    await self.data.client.send_message(channel, '', embed=embed.embed)
async def remove_access(self, msg, args):
    """
    Remove the given names from the server's ``access`` config.

    Names not present in the config are ignored. The config is written out afterwards and
    the removed entries are reported as embed fields.
    :param msg: discord.Message
    :param args: list of access names to remove
    :return: tuple (channel, '', Embed, -1)
    """
    config = self.data.servers[msg.server.id].config
    if 'access' not in config:
        # Was ``self.data.server[...]`` (missing "s"), which raised AttributeError
        # instead of creating the entry.
        config['access'] = {}

    values = {}
    for arg in args:
        if arg in config['access']:
            values[arg] = str(config['access'].pop(arg))

    fields = Utils.convert_json_to_fields(values)
    self.data.servers[msg.server.id].write_config()
    return (
        msg.channel, '',
        Embed(
            self.data, msg,
            js={'title': 'Removed access values:', 'fields': fields}
        ),
        -1
    )
async def set_bot_prefix(self, msg, arg):
    """
    Store a new command prefix for the server and reload the bot data.

    :param msg: discord.Message
    :param arg: string - new prefix
    :return: Embed reporting the prefix found in the config after the reload
    """
    server = self.data.servers[msg.server.id]
    server.config['prefix'] = arg
    server.write_config()

    self.data.load(mode=1, _reload=True)
    self.data.loaded = True

    # Look the server up again: the prefix is reported from the data as it is after load().
    active_prefix = self.data.servers[msg.server.id].config['prefix']
    return Embed(self.data, msg, js={
        'title': 'Set bot prefix:',
        'description': 'Set bot prefix to: ``{}``'.format(active_prefix)
    })
async def command_handle(self, content, msg):
    """
    Parse a command line, run the matching command and send its responses.

    The first space-separated word is the command; the rest is tokenized into bracketed,
    braced, quoted or plain arguments. Responses of a known alias are split into
    message-sized pieces and sent; a disabled or unknown alias produces an error embed.
    The triggering message is deleted in every case.
    :param content: command text
    :param msg: discord.Message
    :return:
    """
    command, separator, remainder = content.partition(' ')
    # Without any space the argument list is an empty list; otherwise it starts as the raw
    # remainder string (which may be empty and is then passed on unchanged).
    args = remainder if separator else []

    if len(args):
        tokens = re.findall(r'\[.*\]|{.*}|"[^"]*"|\'[^\']*\'|[^ ]*', args)
        args = [token for token in tokens if token]
    for index, token in enumerate(args):
        # Strip one leading and one trailing quote character, independently of each other.
        if token.startswith('"') or token.startswith("'"):
            token = token[1:]
        if token.endswith('"') or token.endswith("'"):
            token = token[:-1]
        args[index] = token

    # make sure no aliases overlap
    aliases, alias_command, disabled_aliases, disabled_command = self.get_aliases()

    if command not in aliases:
        if command in disabled_aliases:
            failure = Utils(self.data).error_embed(
                msg, "That alias ({}) is currently disabled.".format(command))
        else:
            failure = Utils(self.data).error_embed(
                msg, "Unrecognized command ({}).".format(command))
        await self.send_message(msg.channel, msg.author, '', failure,
                                Utils.add_seconds(datetime.datetime.now(), 60))
        await self.data.client.delete_message(msg)
        return

    handler = self.data.command_calls[alias_command[aliases.index(command)]]
    responses = await handler(self.data).execute(msg, args=args)
    for channel, message, embed, expiry in responses:
        for line in self.text_splitter(message):
            if not line:
                continue
            await self.send_message(channel, msg.author, line, Embed(self.data, msg), expiry)
        for part in self.embed_splitter(embed):
            await self.send_message(channel, msg.author, '', part, expiry)
    await self.data.client.delete_message(msg)
async def add_role_to_everyone(self, msg, arg):
    """
    Give a role to every member of the server who does not have it yet.

    :param msg: discord.Message
    :param arg: role id or name
    :return: tuple (channel, '', Embed, expiry) - an error embed when the role is unknown
    """
    role = (await Utils(self.data).finder(msg.server, arg, role=True))[2]
    if role is None:
        return (msg.channel, '',
                Utils(self.data).error_embed(
                    msg, 'Could not find specified role: ``{}``'.format(arg)),
                Utils.add_seconds(datetime.datetime.now(), 60))

    if msg.server.large:
        # request_offline_members() needs the server whose member cache should be
        # filled; calling it without an argument raises TypeError.
        await self.data.client.request_offline_members(msg.server)

    for member in msg.server.members:
        if role not in member.roles:
            await self.data.client.add_roles(member, role)

    # NOTE(review): '<@{}>' is the user-mention form; a role mention is normally
    # '<@&{}>' - confirm whether the text below should mention the role.
    return (msg.channel, '', Embed(self.data, msg, js={
        'title': 'Added roles:',
        'description': 'Added <@{}>'.format(role.id)
    }), -1)
async def message_delete(self, msg):
    """
    Discord.Client.on_message_delete().

    Ignores messages starting with the server prefix. Otherwise sends an alert embed
    followed by the deleted message's data to the logging channel returned by
    ``self.check_logging``.
    :param msg: one-element sequence holding the discord.Message
    :return:
    """
    msg, = msg
    if msg.content.startswith(self.data.servers[msg.server.id].config['prefix']):
        return
    key = 'message_delete_logging'
    go, channel = self.check_logging(key, msg.author)
    if channel is None or not go:
        return
    self.data.servers[msg.server.id].config[key]['channel'] = [channel.id, channel.name]
    self.data.servers[msg.server.id].write_config()

    # time.daylight only says whether the local zone defines DST at all, so indexing
    # tzname with it yields the DST name all year; pick by whether DST is in effect now.
    tz_name = time.tzname[1 if time.localtime().tm_isdst > 0 else 0]
    embed = Embed(self.data, msg, js={
        'title': 'Detected message deletion:',
        'description': 'user: {}:<@{}>\nchannel: {}:<#{}>\ndate: {}; {}'.format(
            msg.author.id, msg.author.id,
            msg.channel.id, msg.channel.id,
            Utils.get_full_time_string(datetime.datetime.now()),
            tz_name
        )
    })

    # Attachments are downloaded before the alert is sent.
    attachments = [Utils.download_file(attachment) for attachment in msg.attachments]

    await self.data.client.send_message(channel, '', embed=embed.embed)
    await self.send_message_data(msg, attachments, channel)
async def add_rules(self, msg, arg):
    """
    Append rules to the ones already stored for the server.

    ``arg`` is decoded with ``Utils.js_decoder``; a non-list result is wrapped in a list.
    The new rules are appended to the ``'rules'`` config entry and the config is written
    out.
    :param msg: discord.Message
    :param arg: string - json
    :return: list of Embed - a header followed by one embed per added rule
    """
    key = 'rules'
    config = self.data.servers[msg.server.id].config
    if key not in config or config[key] is None:
        self.data.servers[msg.server.id].config[key] = []
        config = self.data.servers[msg.server.id].config

    js = Utils.js_decoder(arg)
    if not isinstance(js, list):
        js = [js]

    full_embeds = [
        Embed(self.data, msg, js={
            'title': 'Added server rules:',
            'description': 'The following are the rules.'
        })
    ]
    n_data = list(js)
    self.data.servers[msg.server.id].config[key] = config[key] + n_data
    # Wrap each rule in an Embed, as get_rules/set_rules do; the raw dicts were
    # previously returned next to the header Embed.
    full_embeds = full_embeds + [Embed(self.data, msg, js=rule) for rule in n_data]
    self.data.servers[msg.server.id].write_config()
    return full_embeds
async def default_help(self, msg, arg=None):
    """
    Build the help embed of a command.

    Without ``arg`` every sub-command the user is permitted to use is listed together with
    a syntax legend. With ``arg`` the sub-command of that name (or having it as an alias)
    is described, provided the user is permitted to use it.
    :param self: command where this was called (in place of self)
    :param msg: discord.message
    :param arg: string or None
    :return: Help embed
    """
    js = {'title': 'Help: ', 'description': ''}
    main_alias = self.aliases[0]

    if arg is not None:
        js['title'] = js['title'] + '--{}[{}]--'.format(main_alias, arg)

        found = arg in self.command_descriptions
        if not found:
            # Not a primary name: look it up among the alias lists. The loop is not cut
            # short, so a later match replaces an earlier one.
            for cmd in self.command_descriptions:
                if arg in self.command_descriptions[cmd][2]:
                    found = True
                    arg = cmd

        if found:
            permitted = Utils(self.data).permissions(
                self.user, self.permits[arg], main_alias, arg)
            if not permitted:
                found = False

        if found:
            entry = self.command_descriptions[arg]
            alias_text = '**[Aliases]:** {}\n\n'.format(', '.join(sorted(entry[2])))
            function_text = "**[Function]:** {}\n\n".format(entry[0])
            usage_text = "**[Usage]:** {}".format(
                entry[1].format(p=self.data.servers[msg.server.id].config['prefix']))
            js['fields'] = [{
                'header': '**{}:**'.format(arg),
                'text': alias_text + function_text + usage_text
            }]
        else:
            js['description'] = '**{}**\nArgument was not found: ({}).'.format(
                msg.content, arg)
        return Embed(self.data, msg, js=js, help=True)

    # No argument: overview of the whole command.
    js['title'] = js['title'] + '--{}--'.format(main_alias)
    js['description'] = '**[Aliases]: {}**\n\n**Descriptions:**\n'.format(
        ', '.join(sorted(self.aliases)))

    command_descriptions = []
    for command in sorted(self.command_descriptions):
        allowed = Utils(self.data).permissions(
            self.user, self.permits[command], main_alias, command)
        if not allowed:
            continue
        command_descriptions.append({
            'header': '**{}:**\n'.format(command),
            'text': 'Permission: {}\n'.format(self.permits[command]) +
                    'Function: {}'.format(self.command_descriptions[command][0])
        })
    if not command_descriptions:
        command_descriptions = [{
            'header': "This command doesn't take any arguments or you can not access them.",
            'text': '__ __'
        }]

    syntax_legend = {
        'header': 'Syntax:',
        'text': '<arg> - non-optional argument\n'
                '<,arg> - optional argument\n'
                '[arg] - non-optional arguments\n'
                '[,arg] - optional arguments'
    }
    js['fields'] = command_descriptions + [syntax_legend]
    return Embed(self.data, msg, js=js, help=True)
# Module-level setup: build the environment wrapper, the environment itself and its
# renderer, then reset both. All of this runs on import, not only when run as a script.
RE = ENV(env_params=env_params, speed_ration_map=speed_ration_map, obs_builder="local")
env = RE.env()
env_renderer = RE.env_renderer(env)
obs, info = env.reset()
env_renderer.reset()
agents = env.agents
# A single agent (NO=1) and a replay memory constructed with (1000, 16).
agent = Agent(NO=1, num_of_agents=1, emb_dim=16, hiden_dim=10)  # emb_dim = 16 + N - 1 = N + 15
memory = Memory(1000, 16)

if __name__ == "__main__":
    # Smoke run: embed the initial observation, take one action, step the environment
    # once and store the resulting transition in the memory.
    emb_net = Embed(env_params["number_of_agents"], 0)
    obs_emb = emb_net.local_obs_to_emb(obs)
    print("agents.position: ", agents[0].position)
    print("obs_emb:", obs_emb)
    obs_emb = torch.Tensor(obs_emb)
    action = agent.act(obs_emb)
    actions = {0: action}  # only one agent
    print("action:", actions)
    obs_next, rewards, done, _ = env.step(actions)
    obs_next_emb = emb_net.local_obs_to_emb(obs_next)
    print("obs_next_emb:", obs_next_emb)
    obs_next_emb = torch.Tensor(obs_next_emb)
    memory.store_memory(obs_emb, actions, rewards, obs_next_emb, done)
    print("memory:", memory)
    # agent.step(obs_emb,actions,rewards,obs_next_emb,done)
class Agent_gym():
    '''
    Double-DQN

    @params:
        NO: the No. of agents
        num_of_agents: the number of agents
        emb_dim: Agent net input dim = 15 + N
        hiden_dim: Agent net hiden dim
        action_dim: Agent net output dim = number of actions
    '''

    def __init__(self, NO, num_of_agents, emb_dim, hiden_dim, action_dim=5):
        self.NO = NO
        self.epsilon = 0.9
        self.alpha = 0.1
        self.gamma = 0.9
        self.lr = 0.01
        self.action_dim = action_dim
        self.iter_step = 0
        self.TARGET_UPDATE_STEP = 100
        self.num_of_agents = num_of_agents
        self.loss_list = deque(maxlen=20000)  # show the loss
        self.acc_list = deque(maxlen=20000)  # show the acc
        self.emb_net = Embed(self.num_of_agents, self.NO)
        self.eval_net = RL_net_test(emb_dim, action_dim, hiden_dim)
        self.target_net = RL_net_test(emb_dim, action_dim, hiden_dim)
        # Only the eval net is optimized; the target net is refreshed from it in step().
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.lr)
        self.loss_func = nn.MSELoss(reduction='none')

    def act(self, obs):
        '''
        Choose an action for ``obs``.

        With probability ``epsilon`` the index of the largest eval-net output is returned
        (as a numpy value); otherwise a uniformly random action index is returned.
        '''
        if random.random() < self.epsilon:
            actions_value = self.eval_net(obs)
            action = torch.max(actions_value, -1)[1].data.numpy()  # max actions value
        else:
            action = random.choice(range(self.action_dim))
        return action

    def step(self, obs_batch, action_batch, reward_batch, obs_next_batch, done):
        '''
        iterate and train by Q-Learning

        Performs one optimizer step per sample of the batch and returns the sum of the
        batch rewards. Every TARGET_UPDATE_STEP calls the target net is overwritten with
        the eval net's weights.
        '''
        self.iter_step = self.iter_step % 5000 + 1
        if self.iter_step % self.TARGET_UPDATE_STEP == 0:
            self.target_net.load_state_dict(
                self.eval_net.state_dict())  # update target net
            print("---------------- Update target net")

        reward = 0
        for i in range(len(obs_batch)):
            obs = torch.Tensor(obs_batch[i])
            action = action_batch[i]  # open question: how to reinforce this action?
            reward += reward_batch[i]
            r = reward_batch[i]
            obs_next = torch.Tensor(obs_next_batch[i])
            done_i = done[i]

            # TODO: not only the loss of the taken action should be updated, but the loss
            # of every Q-value. The taken action's Q-value is moved towards q_target; the
            # others keep their own value (or could be pushed down by 0.1).
            q_pre = self.eval_net(obs)
            q_tar = q_pre.clone().detach()
            q_eval = q_pre[action]
            q_next = self.target_net(obs_next).detach()
            q_target = q_eval + self.alpha * (r + self.gamma * q_next.max() - q_eval)
            q_tar[action] = q_target

            loss = self.loss_func(q_pre, q_tar)

            # NOTE: the original author marked this accuracy computation as incorrect.
            acc = np.abs(q_target.detach().numpy() - \
                q_eval.detach().numpy()) / np.abs(q_target.detach().numpy())
            avr = (q_target + q_eval) / 2
            loss_q = ((q_target - avr)**2 + (q_eval - avr)**2) / 2

            self.loss_list.append(loss_q * 10)
            self.acc_list.append(acc)
            self.optimizer.zero_grad()
            # The loss is element-wise (reduction='none'), so a gradient of ones is passed.
            loss.backward(torch.ones_like(q_pre))
            self.optimizer.step()
        return reward

    def emb(self, obs):
        '''Convert a local observation to its embedding, returned as a torch.Tensor.'''
        obs_emb = self.emb_net.local_obs_to_emb(obs)
        obs_emb = torch.Tensor(obs_emb)
        return obs_emb

    def save_model(self, save_path):
        '''
        save model to pickle

        Saves the target net with torch.save and pickles the loss history next to it.
        '''
        with open(save_path + 'RL_gym_agent_' + str(self.NO) + ".pkl", 'wb') as f:
            torch.save(self.target_net, f)
        with open(save_path + 'loss_20000_2.pkl', 'wb') as fl:
            pickle.dump(self.loss_list, fl)
        print("Model %d Saved!" % self.NO)

    def load_model(self, load_path):
        '''
        load model from pickle file

        The loaded network becomes the target net and an independent copy of it becomes
        the eval net. The optimizer is rebuilt for the new eval net.
        '''
        import copy

        # NOTE: torch.load unpickles the file; only load model files from trusted sources.
        with open(load_path + 'RL_gym_agent_stable' + ".pkl", 'rb') as f:
            self.target_net = torch.load(f)
        # Keep eval and target nets as separate objects: sharing one module would turn
        # every target update into a no-op.
        self.eval_net = copy.deepcopy(self.target_net)
        # The previous optimizer still referenced the parameters of the discarded eval
        # net, so training after a load would not have changed the loaded model.
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.lr)

    def show(self, show_loss=True, show_acc=False):
        '''Plot the recorded loss and/or accuracy, save the figure and display it.'''
        x_loss = range(len(self.loss_list))
        x_acc = range(len(self.acc_list))
        y_loss = self.loss_list
        y_acc = self.acc_list
        if show_loss:
            plt.subplot(2, 1, 1)
            plt.plot(x_loss, y_loss, '-')
            plt.title('Train loss vs. epoches')
            plt.ylabel('Train loss')
        if show_acc:
            plt.subplot(2, 1, 2)
            plt.plot(x_acc, y_acc, '.-')
            plt.xlabel('Test accuracy vs. epoches')
            plt.ylabel('Test accuracy')
        # Save before show(): once the shown window is closed the figure is released and
        # savefig would write an empty image.
        plt.savefig("accuracy_loss.jpg")
        plt.show()
class Agent():
    '''
    Double-DQN

    @params:
        NO: the No. of agents
        num_of_agents: the number of agents
        emb_dim: Agent net input dim = 15 + N
        hiden_dim: Agent net hiden dim
        action_dim: Agent net output dim = number of actions
    '''

    def __init__(self, NO, num_of_agents, emb_dim, hiden_dim,action_dim=5):
        self.NO = NO
        self.epsilon = 0.9
        self.alpha = 0.1
        self.gamma = 0.9
        self.lr = 0.01
        self.action_dim = action_dim
        self.iter_step = 0
        self.TARGET_UPDATE_STEP = 50
        self.num_of_agents = num_of_agents
        self.loss_list = deque(maxlen=5000)  # show the loss
        self.acc_list = deque(maxlen=5000)  # show the acc
        self.emb_net = Embed(self.num_of_agents, self.NO)
        self.eval_net = RL_net(emb_dim, action_dim, hiden_dim)
        self.target_net = RL_net(emb_dim, action_dim, hiden_dim)
        print("T parms:", self.target_net.state_dict())
        # Only the eval net is optimized; the target net is refreshed from it in step().
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=self.lr)
        self.loss_func = nn.MSELoss(reduction='none')

    def act(self, obs, actions_valid, d):
        '''
        Choose an action among the feasible ones.

        ``obs`` only holds the three values [x, y, s] and not the direction, so the
        direction is passed separately.
        parms: obs, actions_valid, dir

        ``actions_valid[i]`` is 1 when action ``i`` is feasible and 0 otherwise. ``d`` is
        accepted for interface compatibility but is currently unused.
        '''
        if random.random() < self.epsilon:
            # The agent outputs 5 actions, [0,1,2,3,4] = [N,E,S,W,Stop].
            # The predicted Q-values cannot be used directly: infeasible actions must be
            # excluded first, otherwise the agent keeps picking the highest-valued action
            # even though it cannot be carried out.
            actions_value = self.eval_net(obs).detach()
            for i, valid in enumerate(actions_valid):
                # Push the Q-value of every infeasible action down to -10.
                if valid == 0:
                    actions_value[i] = -10
            # The unfinished direction remapping that used to sit in this loop read an
            # undefined local (``a = (a+2)%4``) and raised UnboundLocalError whenever it
            # ran; its result was never used, so it has been removed.
            # TODO: action is the index of the largest Q-value; could the Q-value itself
            # be used (rounded) instead?
            action = torch.max(actions_value, -1)[1].data.numpy()  # max actions value
        else:
            # The random choice is also restricted to feasible actions.
            actions_v = [i for i, valid in enumerate(actions_valid) if valid == 1]
            action = random.choice(actions_v)
        return action

    def check_actions(self, action):
        '''
        TODO make sure the actions are legal
        '''
        return action

    def step(self,obs_batch,action_batch,reward_batch,obs_next_batch,done):
        '''
        iterate and train by Q-Learning

        Performs one optimizer step per sample of the batch, using this agent's entry
        (index ``NO``) of every sample, and returns the sum of its rewards.
        '''
        self.iter_step = (self.iter_step % 3000) + 1
        if self.iter_step % self.TARGET_UPDATE_STEP == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())  # update target net
            if self.epsilon < 0.9:
                self.epsilon += 0.01
            print(" - Update t net")
            print("T parms:", self.target_net.state_dict())

        reward = 0
        for i in range(len(obs_batch)):
            obs = torch.Tensor(obs_batch[i][self.NO])
            action = action_batch[i][self.NO]  # open question: how to reinforce this action?
            obs_next = torch.Tensor(obs_next_batch[i][self.NO])
            done_i = done[i][self.NO]
            # The reward returned by the environment is summed up directly.
            r = reward_batch[i][self.NO]
            reward += r

            # TODO: not only the loss of the taken action should be updated, but the loss
            # of every Q-value. The taken action's Q-value is moved towards q_target; the
            # others keep their own value (or could be pushed down by 0.1).
            q_pre = self.eval_net(obs)
            q_tar = q_pre.clone().detach()
            q_eval = q_pre[action]
            q_next = self.target_net(obs_next).detach()
            # TODO (original author): this formula is the correct one, rather than the
            # plain ``r + gamma * q_next.max()``.
            q_target = q_eval + self.alpha * (r + self.gamma * q_next.max() - q_eval)
            q_tar[action] = q_target

            # TODO: the loss must have the same dimension as the network output.
            loss = self.loss_func(q_pre, q_tar)

            # NOTE: the original author marked this accuracy computation as incorrect.
            acc = np.abs(q_target.detach().numpy() - \
                q_eval.detach().numpy()) / np.abs(q_target.detach().numpy())
            avr = (q_target+q_eval)/2
            loss_q = ((q_target - avr)**2 + (q_eval - avr)**2) / 2

            self.loss_list.append(loss_q)
            self.acc_list.append(acc)
            self.optimizer.zero_grad()
            # The loss is element-wise (reduction='none'), so a gradient of ones is passed.
            loss.backward(torch.ones_like(q_pre))
            self.optimizer.step()
        return reward

    def emb(self,obs):
        '''Convert a local observation to its embedding, returned as a torch.Tensor.'''
        obs_emb = self.emb_net.local_obs_to_emb(obs)
        obs_emb = torch.Tensor(obs_emb)
        return obs_emb

    def save_model(self,save_path):
        '''
        save model to pickle
        '''
        with open(save_path+'PSD_RLagent_'+str(self.NO)+".pkl", 'wb') as f:
            torch.save(self.target_net, f)
        print("Model %d Saved!" % self.NO)

    def show(self,show_loss=True,show_acc=False):
        '''Plot the recorded loss and/or accuracy, save the figure and display it.'''
        x_loss = range(len(self.loss_list))
        x_acc = range(len(self.acc_list))
        y_loss = self.loss_list
        y_acc = self.acc_list
        if show_loss:
            plt.subplot(3, 1, 1)
            plt.plot(x_loss, y_loss, '-')
            plt.title('Train loss vs. epoches')
            plt.ylabel('Train loss')
        if show_acc:
            plt.subplot(3, 1, 3)
            plt.plot(x_acc, y_acc, '.-')
            plt.xlabel('Test accuracy vs. epoches')
            plt.ylabel('Test accuracy')
        # Save before show(): once the shown window is closed the figure is released and
        # savefig would write an empty image.
        plt.savefig("PSD_loss.jpg")
        plt.show()