def forward(self, inputs, tformat, loss_fn=None, hidden_states=None, **kwargs):
    """Compute masked action probabilities for the non-recurrent level-3 agent.

    Args:
        inputs: dict providing "main" (agent features), "avail_actions"
            (availability mask, 0 = unavailable) and, when softmax exploration
            is active outside test mode, "epsilons".
        tformat: tensor-format string understood by ``_to_batch``/``_from_batch``.
        loss_fn: optional callable applied to the output; its first return
            value is returned as the loss.
        hidden_states: passed through unchanged (this agent has no recurrence).
        **kwargs: must contain "test_mode".

    Returns:
        ``(probs, hidden_states, losses, tformat)``. ``losses`` is ``None``
        when ``loss_fn`` is not given.
    """
    test_mode = kwargs["test_mode"]
    avail_actions, _, _ = _to_batch(inputs["avail_actions"], tformat)
    x, params, tformat = _to_batch(inputs["main"], tformat)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)

    # Mask policy elements corresponding to unavailable actions: exponentiate,
    # overwrite unavailable entries with a tiny positive constant, normalise.
    tiny = np.sqrt(float(np.finfo(np.float32).tiny))
    n_available_actions = avail_actions.sum(dim=1, keepdim=True)
    x = th.exp(x)
    x = x.masked_fill(avail_actions == 0, tiny)
    x_sum = x.sum(dim=1, keepdim=True)
    # Rows whose total mass is no larger than "every entry tiny" would be
    # normalised by a (near-)zero sum; divide those rows by 1.0 instead.
    second_mask = x_sum <= tiny * avail_actions.shape[1]
    x_sum = x_sum.masked_fill(second_mask, 1.0)
    x = th.div(x, x_sum)

    # Throw a debug warning if the second masking was necessary.
    if th.sum(second_mask.data) > 0 and self.args.debug_verbose:
        print('Warning in MACKRLNonRecurrentAgentLevel3.forward(): some sum during the softmax has been 0!')

    # Add softmax exploration (if switched on).
    if self.args.coma_exploration_mode in ["softmax"] and not test_mode:
        epsilons = inputs["epsilons"].unsqueeze(_tdim(tformat))
        epsilons, _, _ = _to_batch(epsilons, tformat)
        # Rows with no available action have a zero numerator; using a
        # denominator of 1 gives 0 instead of 0/0 = NaN.
        n_available_actions = n_available_actions.masked_fill(n_available_actions == 0, 1.0)
        x = avail_actions * epsilons / n_available_actions + x * (1 - epsilons)

    x = _from_batch(x, params, tformat)

    # Previously `losses` was unbound when loss_fn was None.
    losses = None
    if loss_fn is not None:
        losses, _ = loss_fn(x, tformat=tformat)
    return x, hidden_states, losses, tformat
def forward(self, inputs, hidden_states, tformat, loss_fn=None, **kwargs):
    """Unroll encoder -> GRU cell -> output head over every time step.

    If the data contains whole sequences, ``loss_fn`` can be passed so that
    all losses are generated in the same pass. The method can be used either
    in sequence mode or step by step.

    Returns:
        ``(outputs, hidden_states, loss, tformat)``. Outputs and hidden states
        are concatenated along the time axis (dim 2). ``loss`` is ``None``
        unless ``loss_fn`` is given.
    """
    _check_inputs_validity(inputs, self.input_shapes, tformat)
    seq = inputs["main"]
    time_axis = _tdim(tformat)
    assert time_axis == 2, "t_dim along unsupported axis"

    outputs_for_loss = []
    returned_outputs = []
    hiddens = [hidden_states]
    for step in range(seq.shape[time_axis]):
        step_in = seq[:, :, slice(step, step + 1), :].contiguous()
        # The encoder may rewrite tformat; later steps use the updated value.
        step_in, tformat = self.encoder({"main": step_in}, tformat)
        step_in, params_x, tformat_x = _to_batch(step_in, tformat)
        h_prev, params_h, tformat_h = _to_batch(hiddens[-1], tformat)

        h_next = self.gru(step_in, h_prev)
        out = self.output(h_next)

        h_next = _from_batch(h_next, params_h, tformat_h)
        out = _from_batch(out, params_x, tformat_x)
        hiddens.append(h_next)
        outputs_for_loss.append(out)
        # When a loss_fn is supplied, hand back a clone rather than the very
        # tensor fed to the loss (the clone still participates in autograd).
        returned_outputs.append(out if loss_fn is None else out.clone())

    loss = None
    if loss_fn is not None:
        loss = loss_fn(th.cat(outputs_for_loss, dim=_tdim(tformat)), tformat=tformat)[0]

    return (
        th.cat(returned_outputs, time_axis),
        th.cat(hiddens[1:], time_axis),
        loss,
        tformat,
    )
def forward(self, inputs, tformat, **kwargs):
    """Evaluate the critic, optionally unrolling a GRU cell over time.

    Args:
        inputs: dict-like providing "main" (critic input features).
        tformat: tensor-format string understood by ``_to_batch``/``_from_batch``.
        **kwargs: in recurrent mode (``args.critic_is_recurrent``) must
            contain "hidden_states", the initial GRU state.

    Returns:
        ``(values, tformat)``. In recurrent mode the per-step values are
        concatenated along the time axis (dim 1).

    Raises:
        KeyError: recurrent mode without a "hidden_states" keyword argument.
    """
    if not getattr(self.args, "critic_is_recurrent", False):
        main, params, m_tformat = _to_batch(inputs.get("main"), tformat)
        x = F.relu(self.fc1(main))
        vvalue = self.fc2(x)
        return _from_batch(vvalue, params, m_tformat), m_tformat

    # Previously a bare `except: pass` hid a missing key here, which then
    # surfaced later as an UnboundLocalError.
    if "hidden_states" not in kwargs:
        raise KeyError("recurrent critic forward() requires a 'hidden_states' keyword argument")
    h = kwargs["hidden_states"]

    _inputs = inputs.get("main")
    t_dim = _tdim(tformat)
    assert t_dim == 1, "t_dim along unsupported axis"
    t_len = _inputs.shape[t_dim]

    x_list = []
    for t in range(t_len):
        x = _inputs[:, slice(t, t + 1), :].contiguous()
        x, params_x, tformat_x = _to_batch(x, tformat)
        h, params_h, tformat_h = _to_batch(h, tformat)
        x = F.relu(self.fc1(x))
        h = self.gru(x, h)
        # Read the value from the updated GRU state. The previous fc2(x)
        # bypassed the GRU entirely, so the recurrence had no effect.
        # NOTE(review): assumes fc2's input size equals the GRU hidden size;
        # confirm against __init__.
        x = self.fc2(h)
        h = _from_batch(h, params_h, tformat_h)
        x = _from_batch(x, params_x, tformat_x)
        x_list.append(x)
    return th.cat(x_list, t_dim), tformat
def forward(self, inputs, hidden_states, tformat, loss_fn=None, **kwargs):
    """Recurrent level-3 agent: per-step masked action probabilities.

    Args:
        inputs: dict providing "main" (agent features), "avail_actions"
            (availability mask, 1 = available) and, for softmax exploration
            outside test mode, "epsilons_central_level3".
        hidden_states: initial GRU state.
        tformat: tensor-format string; its time axis must be dim 2.
        loss_fn: optional callable; its first return value is the loss.
        **kwargs: must contain "seq_lens" and "test_mode".

    Returns:
        ``(probs, hidden_states, loss, tformat)``. Probabilities and hidden
        states are concatenated along the time axis. ``loss`` is ``None``
        unless ``loss_fn`` is given.
    """
    seq_lens = kwargs["seq_lens"]
    try:
        _check_inputs_validity(inputs, self.input_shapes, tformat)
    except Exception as e:
        # Best effort: zero out NaNs in place (this mutates the caller's
        # tensors) and carry on.
        print("Exception {} - have replaced NaNs with zeros".format(e))
        for _v in inputs.values():
            _v[_v != _v] = 0.0

    test_mode = kwargs["test_mode"]
    _inputs = inputs["main"]
    # Padded with fill value 1.0 (presumably marking padded steps as
    # "all actions available" so the masked softmax stays well-defined).
    _inputs_aa = _pad(inputs["avail_actions"], tformat, seq_lens, 1.0)
    t_dim = _tdim(tformat)
    assert t_dim == 2, "t_dim along unsupported axis"
    t_len = _inputs.shape[t_dim]

    x_list = []
    h_list = [hidden_states]
    for t in range(t_len):
        x = _inputs[:, :, slice(t, t + 1), :].contiguous()
        avail_actions = _inputs_aa[:, :, slice(t, t + 1), :].contiguous()
        x, tformat = self.encoder({"main": x}, tformat)
        x, params_x, tformat_x = _to_batch(x, tformat)
        avail_actions, _, _ = _to_batch(avail_actions, tformat)
        h, params_h, tformat_h = _to_batch(h_list[-1], tformat)
        h = self.gru(x, h)
        x = self.output(h)

        # Push logits of unavailable actions down by 1e30 before the softmax
        # so that, when any action is available, unavailable ones get ~0 mass.
        n_available_actions = avail_actions.sum(dim=1, keepdim=True)
        x = x - (1 - avail_actions) * 1e30
        x = F.softmax(x, 1)

        # Add softmax exploration (if switched on).
        if self.args.mackrl_exploration_mode_level3 in ["softmax"] and not test_mode:
            epsilons = inputs["epsilons_central_level3"].unsqueeze(_tdim(tformat)).detach()
            epsilons, _, _ = _to_batch(epsilons, tformat)
            # Rows with no available action have a zero numerator; dividing
            # by 1 instead of 0 avoids 0/0 = NaN.
            n_available_actions = n_available_actions.masked_fill(n_available_actions == 0, 1.0)
            x = avail_actions.detach() * epsilons / n_available_actions + x * (1 - epsilons)

        h = _from_batch(h, params_h, tformat_h)
        x = _from_batch(x, params_x, tformat_x)
        h_list.append(h)
        x_list.append(x)

    loss = None
    if loss_fn is not None:
        _x = th.cat(x_list, dim=_tdim(tformat))
        loss = loss_fn(_x, tformat=tformat)[0]
    return th.cat(x_list, t_dim), th.cat(h_list[1:], t_dim), loss, tformat
def forward(self, inputs, hidden_states, tformat, loss_fn=None, **kwargs):
    """Recurrent level-2 agent: per-step masked pairwise-action probabilities.

    Args:
        inputs: dict providing "main" (agent features) and, for softmax
            exploration outside test mode, "epsilons_central_level2".
        hidden_states: initial GRU state.
        tformat: tensor-format string; its time axis must be dim 2.
        loss_fn: optional callable; its first return value is the loss.
        **kwargs: must contain "seq_lens", "test_mode" and
            "pairwise_avail_actions" (availability mask, 1 = available).

    Returns:
        ``(probs, hidden_states, loss, tformat)``. Probabilities and hidden
        states are concatenated along the time axis. ``loss`` is ``None``
        unless ``loss_fn`` is given.
    """
    seq_lens = kwargs["seq_lens"]
    try:
        _check_inputs_validity(inputs, self.input_shapes, tformat)
    except Exception as err:
        # Best effort: zero out NaNs in place in the caller's tensors.
        print("Exception {} - have replaced NaNs with zeros".format(err))
        for tensor in inputs.values():
            tensor[tensor != tensor] = 0.0

    test_mode = kwargs["test_mode"]
    pair_avail = _pad(kwargs["pairwise_avail_actions"].detach(), tformat, seq_lens, 1.0)
    pair_avail.requires_grad = False
    seq = inputs["main"]
    time_axis = _tdim(tformat)
    assert time_axis == 2, "t_dim along unsupported axis"

    step_probs = []
    hiddens = [hidden_states]
    for step in range(seq.shape[time_axis]):
        step_in = seq[:, :, slice(step, step + 1), :].contiguous()
        step_avail = pair_avail[:, :, slice(step, step + 1), :].contiguous().detach()
        step_in, tformat = self.encoder({"main": step_in}, tformat)
        step_in, params_x, tformat_x = _to_batch(step_in, tformat)
        step_avail, _, _ = _to_batch(step_avail, tformat)
        h_prev, params_h, tformat_h = _to_batch(hiddens[-1], tformat)

        h_next = self.gru(step_in, h_prev)
        logits = self.output(h_next)

        bias = getattr(self.args, "mackrl_logit_bias", 0.0)
        if bias != 0.0:
            # Shift only the first (index 0) logit by the configured bias.
            logits = th.cat([logits[:, 0:1] + bias, logits[:, 1:]], dim=1)

        n_avail = step_avail.sum(dim=1, keepdim=True)
        # Large negative offset on unavailable logits, then softmax.
        probs = F.softmax(logits - (1 - step_avail) * 1e30, 1)

        explore = self.args.mackrl_exploration_mode_level2 in ["softmax"] and not test_mode
        if explore:
            eps = inputs["epsilons_central_level2"].unsqueeze(_tdim(tformat)).detach()
            eps, _, _ = _to_batch(eps, tformat)
            # Avoid dividing by zero when no pairwise action is available.
            n_avail[n_avail == 0.0] = 1.0
            probs = step_avail.detach() * eps / n_avail + probs * (1 - eps)

        hiddens.append(_from_batch(h_next, params_h, tformat_h))
        step_probs.append(_from_batch(probs, params_x, tformat_x))

    probs_all = th.cat(step_probs, time_axis)
    if getattr(self.args, "mackrl_always_delegate", False):
        # Put all probability mass on index 0 for every step.
        probs_all[:, :, :, 0] = 1.0
        probs_all[:, :, :, 1:] = 0.0

    loss = loss_fn(probs_all, tformat=tformat)[0] if loss_fn is not None else None
    return probs_all, th.cat(hiddens[1:], time_axis), loss, tformat
def forward(self, inputs, hidden_states, tformat, loss_fn=None, **kwargs):
    """Recurrent level-1 agent: per-step (unmasked) softmax probabilities.

    Args:
        inputs: dict providing "main" (features; a leading dim is added when
            it is 3-d) and, for softmax exploration outside test mode,
            "epsilons_central_level1".
        hidden_states: initial GRU state.
        tformat: tensor-format string; its time axis must be dim 2.
        loss_fn: optional callable; its first return value is the loss.
        **kwargs: must contain "test_mode".

    Returns:
        ``(probs, hidden_states, loss, tformat)``. Probabilities and hidden
        states are concatenated along the time axis. ``loss`` is ``None``
        unless ``loss_fn`` is given.
    """
    try:
        _check_inputs_validity(inputs, self.input_shapes, tformat)
    except Exception as e:
        # Best effort: zero out NaNs in place (this mutates the caller's
        # tensors) and carry on.
        print("Exception {} - have replaced NaNs with zeros".format(e))
        for _v in inputs.values():
            _v[_v != _v] = 0.0

    test_mode = kwargs["test_mode"]
    _inputs = inputs["main"]
    if len(_inputs.shape) == 3:
        _inputs = _inputs.unsqueeze(0)  # the agent dimension is lacking

    t_dim = _tdim(tformat)
    assert t_dim == 2, "t_dim along unsupported axis"
    t_len = _inputs.shape[t_dim]

    x_list = []
    h_list = [hidden_states]
    for t in range(t_len):
        x = _inputs[:, :, slice(t, t + 1), :].contiguous()
        x, tformat = self.encoder({"main": x}, tformat)
        x, params_x, tformat_x = _to_batch(x, tformat)
        h, params_h, tformat_h = _to_batch(h_list[-1], tformat)
        h = self.gru(x, h)
        x = F.softmax(self.output(h), 1)

        if self.args.mackrl_exploration_mode_level1 in ["softmax"] and not test_mode:
            # Epsilons use the hard-coded "bs*t*v" format, independent of tformat.
            epsilons = inputs["epsilons_central_level1"].unsqueeze(_tdim("bs*t*v")).detach()
            epsilons, _, _ = _to_batch(epsilons, "bs*t*v")
            x = epsilons / x.shape[-1] + x * (1 - epsilons)

        if getattr(self.args, "mackrl_fix_level1_pair", False):
            # Put all mass on index 0. Overwrite a copy: without exploration
            # mixing, x is softmax's own output, which its backward reuses;
            # an in-place write would make backward() raise.
            x = x.clone()
            x[:, 1:] = 0.0
            x[:, 0] = 1.0

        h = _from_batch(h, params_h, tformat_h)
        x = _from_batch(x, params_x, tformat_x)
        h_list.append(h)
        x_list.append(x)

    loss = None
    if loss_fn is not None:
        _x = th.cat(x_list, dim=_tdim(tformat))
        loss = loss_fn(_x, tformat=tformat)[0]
    return th.cat(x_list, t_dim), th.cat(h_list[1:], t_dim), loss, tformat