def forward(self, inputs, tformat, loss_fn=None, hidden_states=None, **kwargs):
    """Non-recurrent level-3 policy forward pass.

    Runs a two-layer MLP over inputs["main"], renormalises the result into a
    probability distribution over *available* actions only, and optionally
    mixes in epsilon (softmax) exploration.

    Args:
        inputs: dict with "main", "avail_actions" and (for exploration) "epsilons".
        tformat: tensor-format descriptor consumed by _to_batch/_from_batch.
        loss_fn: optional callable producing (loss, _) on the final policies.
        hidden_states: passed through unchanged (this model is stateless).
        kwargs: must contain "test_mode" (bool); True disables exploration.

    Returns:
        (policies, hidden_states, losses, tformat); losses is None when
        loss_fn is None.
    """
    test_mode = kwargs["test_mode"]
    avail_actions, params_aa, tformat_aa = _to_batch(inputs["avail_actions"], tformat)
    x, params, tformat = _to_batch(inputs["main"], tformat)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)

    # Masked softmax by hand: exponentiate, clamp unavailable actions to a
    # tiny positive value (sqrt of float32 tiny) and renormalise.
    n_available_actions = avail_actions.sum(dim=1, keepdim=True)
    x = th.exp(x)
    x = x.masked_fill(avail_actions == 0, np.sqrt(float(np.finfo(np.float32).tiny)))
    x_sum = x.sum(dim=1, keepdim=True)
    # Guard against a numerically-zero normaliser (all actions masked):
    # divide by 1.0 instead of ~0 to avoid producing NaNs.
    second_mask = (x_sum <= np.sqrt(float(np.finfo(np.float32).tiny)) * avail_actions.shape[1])
    x_sum = x_sum.masked_fill(second_mask, 1.0)
    x = th.div(x, x_sum)

    # throw debug warning if second masking was necessary
    if th.sum(second_mask.data) > 0:
        if self.args.debug_verbose:
            print('Warning in MACKRLNonRecurrentAgentLevel3.forward(): some sum during the softmax has been 0!')

    # add softmax exploration (if switched on)
    if self.args.coma_exploration_mode in ["softmax"] and not test_mode:
        epsilons = inputs["epsilons"].unsqueeze(_tdim(tformat))
        epsilons, _, _ = _to_batch(epsilons, tformat)
        x = avail_actions * epsilons / n_available_actions + x * (1 - epsilons)

    x = _from_batch(x, params, tformat)

    # BUG FIX: `losses` was previously only assigned inside the loss_fn
    # branch, so the return below raised NameError whenever loss_fn was None.
    # Initialise to None, matching the other forward() heads in this file.
    losses = None
    if loss_fn is not None:
        losses, _ = loss_fn(x, tformat=tformat)
    return x, hidden_states, losses, tformat
def forward(self, inputs, tformat):
    """Feed-forward state-value head: fc1 -> ReLU -> fc2.

    Returns the value tensor restored to its original format, plus the
    batched tformat descriptor.
    """
    # _check_inputs_validity(inputs, self.input_shapes, tformat, allow_nonseq=True)
    flat, batch_params, batch_tformat = _to_batch(inputs.get("main"), tformat)
    hidden = F.relu(self.fc1(flat))
    values = self.fc2(hidden)
    return _from_batch(values, batch_params, batch_tformat), batch_tformat
def forward(self, inputs, hidden_states, tformat, loss_fn=None, **kwargs):
    """Recurrent forward pass over a whole sequence (or a single step).

    If data contains whole sequences, a loss_fn can be passed so that all
    per-step losses are generated automatically from the concatenated
    outputs. Can also be operated step-by-step (t_len == 1).

    Returns (outputs, hidden_states, loss, tformat), where outputs and
    hidden_states are concatenated along the time dimension.
    """
    _check_inputs_validity(inputs, self.input_shapes, tformat)

    _inputs = inputs["main"]
    loss = None
    t_dim = _tdim(tformat)
    assert t_dim == 2, "t_dim along unsupported axis"
    t_len = _inputs.shape[t_dim]

    loss_x = []
    output_x = []
    h_list = [hidden_states]  # h_list[0] is the incoming state; outputs start at [1]

    for t in range(t_len):
        # slice out time step t (keep the time axis, length 1)
        x = _inputs[:, :, slice(t, t + 1), :].contiguous()
        x, tformat = self.encoder({"main":x}, tformat)
        x, params_x, tformat_x = _to_batch(x, tformat)
        h, params_h, tformat_h = _to_batch(h_list[-1], tformat)
        h = self.gru(x, h)
        x = self.output(h)
        h = _from_batch(h, params_h, tformat_h)
        x = _from_batch(x, params_x, tformat_x)
        h_list.append(h)
        loss_x.append(x)
        # we will not branch the variables if loss_fn is set - instead return
        # only (cloned) tensor values for x in that case
        output_x.append(x) if loss_fn is None else output_x.append(x.clone())

    if loss_fn is not None:
        _x = th.cat(loss_x, dim=_tdim(tformat))
        loss = loss_fn(_x, tformat=tformat)[0]

    return th.cat(output_x, t_dim), \
           th.cat(h_list[1:], t_dim), \
           loss, \
           tformat
def forward(self, policies, tformat):
    """Batched scalar product <log p, p> per policy row.

    NOTE(review): this computes sum_i p_i * log(p_i), i.e. the NEGATIVE
    entropy; presumably the sign convention is handled by the caller (as in
    COMA) — confirm against usage.
    """
    flat_pi, pi_params, pi_tformat = _to_batch(policies, tformat)
    # (b, 1, n) @ (b, n, 1) -> (b, 1, 1); squeeze the trailing axis -> (b, 1)
    log_pi = th.log(flat_pi)
    dot = th.bmm(log_pi.unsqueeze(1), flat_pi.unsqueeze(2)).squeeze(2)
    return _from_batch(dot, pi_params, pi_tformat)
def forward(self, inputs, tformat, **kwargs):
    """Critic forward pass: recurrent (GRU unroll) or plain MLP.

    The branch is selected by args.critic_is_recurrent (default False).

    Args:
        inputs: dict with key "main".
        tformat: tensor-format descriptor consumed by _to_batch/_from_batch.
        kwargs: the recurrent branch requires "hidden_states".

    Returns:
        (values, tformat) in both branches.
    """
    # _check_inputs_validity(inputs, self.input_shapes, tformat, allow_nonseq=True)
    if getattr(self.args, "critic_is_recurrent", False):
        _inputs = inputs.get("main")
        t_dim = _tdim(tformat)
        assert t_dim == 1, "t_dim along unsupported axis"
        t_len = _inputs.shape[t_dim]

        # BUG FIX: this was `try: ... except: pass`, which silently swallowed
        # a missing "hidden_states" key and then crashed later with an
        # unrelated NameError. Fail fast with a clear message instead.
        if "hidden_states" not in kwargs:
            raise KeyError("recurrent critic forward() requires 'hidden_states' in kwargs")
        hidden_states = kwargs["hidden_states"]

        x_list = []
        h_list = [hidden_states]
        for t in range(t_len):
            # slice out time step t (keep the time axis, length 1)
            x = _inputs[:, slice(t, t + 1), :].contiguous()
            x, params_x, tformat_x = _to_batch(x, tformat)
            h, params_h, tformat_h = _to_batch(h_list[-1], tformat)
            x = F.relu(self.fc1(x))
            h = self.gru(x, h)
            # NOTE(review): fc2 is applied to the fc1 activations, not to the
            # GRU output h — the recurrent state never feeds the value
            # estimate. Possibly fc2(h) was intended; behaviour left
            # unchanged here, confirm with the authors.
            x = self.fc2(x)
            h = _from_batch(h, params_h, tformat_h)
            x = _from_batch(x, params_x, tformat_x)
            h_list.append(h)
            x_list.append(x)
        return th.cat(x_list, t_dim), \
               tformat
    else:
        # non-recurrent path: plain fc1 -> ReLU -> fc2 value head
        main, params, m_tformat = _to_batch(inputs.get("main"), tformat)
        x = F.relu(self.fc1(main))
        vvalue = self.fc2(x)
        return _from_batch(vvalue, params, m_tformat), m_tformat
def forward(self, inputs, tformat, loss_fn=None, hidden_states=None, **kwargs):
    """Two-layer MLP head; applies (log-)softmax when output_type is "policies".

    Returns (output, hidden_states, losses, tformat); losses is None unless a
    loss_fn is supplied. hidden_states is passed through unchanged.
    """
    _check_inputs_validity(inputs, self.input_shapes, tformat)

    # Execute model branch "main"
    batch, batch_params, tformat = _to_batch(inputs["main"], tformat)
    out = self.fc2(F.relu(self.fc1(batch)))
    out = _from_batch(out, batch_params, tformat)

    losses = None
    if self.output_type in ["policies"]:
        # caller may request log-probabilities via kwargs["log_softmax"]
        softmax_fn = F.log_softmax if kwargs.get("log_softmax", False) else F.softmax
        out = softmax_fn(out, dim=_vdim(tformat))

    if loss_fn is not None:
        losses, _ = loss_fn(out, tformat=tformat)

    return out, hidden_states, losses, tformat  # output, hidden state, losses
def forward(self, inputs, tformat):
    """Single linear layer with ReLU over inputs["main"]."""
    flat, seq_params, tformat = _to_batch(inputs["main"], tformat)
    activated = F.relu(self.fc1(flat))
    return _from_batch(activated, seq_params, tformat), tformat
def forward(self, inputs, actions, hidden_states, tformat, loss_fn=None, **kwargs):
    """Joint MACKRL forward pass: composes the level-1/2/3 policies into the
    probability of the selected joint action for each agent grouping.

    For each agent pair (a,b): p_a_b = p(u_a,u_b | tau_a, tau_b, CK_ab)
        = p(delegate | CK_ab) * pi(u_a|tau_a) * pi(u_b|tau_b) + p(u_ab | CK_ab)
    (plus an optional correction term for unavailable actions), and the final
    p_a_b_c marginalises over the level-1 pairing distribution.

    NaN entries in `actions` mark missing/padded selections and are zeroed
    before gathers; masks are kept to re-apply them afterwards.

    Returns (p_a_b_c, hidden_states dict, loss, tformat_level3).
    Note: policies can have NaNs in them!
    NOTE(review): loss_fn is called unconditionally at the end, so callers
    must pass a callable loss_fn — confirm no caller passes None.
    """
    seq_lens = kwargs["seq_lens"]

    # TODO: How do we handle the loss propagation for recurrent layers??

    # --- generate level 1-3 outputs (losses are computed jointly below, so
    # the per-level loss_fn is deliberately None) ---
    out_level1, hidden_states_level1, losses_level1, tformat_level1 = \
        self.model_level1(inputs=inputs["level1"]["agent_input_level1"],
                          hidden_states=hidden_states["level1"],
                          loss_fn=None,  # loss_fn
                          tformat=tformat["level1"],
                          # n_agents=self.n_agents,
                          **kwargs)

    # Prepend an always-available "delegate" pseudo-action to the pairwise
    # availability mask (index 0 of the level-2 output).
    pairwise_avail_actions = inputs["level2"]["agent_input_level2"]["avail_actions_pair"]
    ttype = th.cuda.FloatTensor if pairwise_avail_actions.is_cuda else th.FloatTensor
    delegation_avails = Variable(ttype(pairwise_avail_actions.shape[0],
                                       pairwise_avail_actions.shape[1],
                                       pairwise_avail_actions.shape[2],
                                       1).fill_(1.0), requires_grad=False)
    pairwise_avail_actions = th.cat([delegation_avails, pairwise_avail_actions],
                                    dim=_vdim(tformat["level2"]))

    out_level2, hidden_states_level2, losses_level2, tformat_level2 = \
        self.models["level2_{}".format(0)](inputs=inputs["level2"]["agent_input_level2"],
                                           hidden_states=hidden_states["level2"],
                                           loss_fn=None,  # loss_fn
                                           tformat=tformat["level2"],
                                           pairwise_avail_actions=pairwise_avail_actions,
                                           **kwargs)

    out_level3, hidden_states_level3, losses_level3, tformat_level3 = \
        self.models["level3_{}".format(0)](inputs["level3"]["agent_input_level3"],
                                           hidden_states=hidden_states["level3"],
                                           loss_fn=None,
                                           tformat=tformat["level3"],
                                           # seq_lens=seq_lens,
                                           **kwargs)

    # level-2 output layout along the value dim: [delegate | joint actions | ...]
    p_d = out_level2[:, :, :, 0:1]      # probability of delegating to level 3
    p_ab = out_level2[:, :, :, 1:-1]    # joint-action probabilities

    # Replace NaN action entries (padding) by 0 so gathers are valid.
    _actions = actions.clone()
    _actions[actions != actions] = 0.0

    pi = out_level3
    pi_actions_selected = out_level3.gather(_vdim(tformat_level3), _actions.long()).clone()
    # pi_actions_selected[pi_actions_selected != pi_actions_selected ] = 0.0 #float("nan")
    pi_actions_selected_mask = (actions != actions)  # remembers padded entries

    avail_actions_level3 = inputs["level3"]["agent_input_level3"]["avail_actions"].clone().data
    avail_actions_selected = avail_actions_level3.gather(_vdim(tformat_level3), _actions.long()).clone()

    # accumulators over agent pairings
    pi_a_cross_pi_b_list = []
    pi_ab_list = []
    pi_corr_list = []

    # pairing index convention, e.g. for 3 agents: 0 -> (0,1), 1 -> (0,2), 2 -> (1,2)
    for _i, (_a, _b) in enumerate(self.ordered_agent_pairings):
        # --- calculate pi_a x pi_b: product of the independent level-3
        # probabilities of the actions actually selected by agents a and b ---
        # TODO: Set disallowed joint actions to NaN!
        pi_a = pi[_a:_a+1].clone()
        pi_b = pi[_b:_b+1].clone()
        x, params_x, tformat_x = _to_batch(out_level3[_a:_a+1], tformat["level3"])
        y, params_y, tformat_y = _to_batch(out_level3[_b:_b+1], tformat["level3"])
        actions_masked = actions.clone()
        actions_masked[actions!=actions] = 0.0
        _actions_x, _actions_params, _actions_tformat = _to_batch(actions_masked[_a:_a+1].clone(), tformat["level3"])
        _actions_y, _actions_params, _actions_tformat = _to_batch(actions_masked[_b:_b+1].clone(), tformat["level3"])
        _x = x.gather(1, _actions_x.long())
        _y = y.gather(1, _actions_y.long())
        z = _x * _y
        u = _from_batch(z, params_x, tformat_x)
        pi_a_cross_pi_b_list.append(u)

        # --- calculate p_ab_selected: level-2 probability of the selected
        # joint action of the pair ---
        _p_ab = p_ab[_i:_i+1].clone()
        joint_actions = _action_pair_2_joint_actions((actions_masked[_a:_a+1], actions_masked[_b:_b+1]), self.n_actions)
        _z = _p_ab.gather(_vdim(tformat_level2), joint_actions.long())

        # Set probabilities corresponding to jointly-disallowed actions to 0.0
        # (+1 offset skips the prepended "delegate" column of the avail mask)
        avail_flags = pairwise_avail_actions[_i:_i+1].gather(_vdim(tformat_level2), joint_actions.long() + 1).clone()
        _z[avail_flags==0.0] = 0.0  # TODO: RENORMALIZE?
        pi_ab_list.append(_z)

        if not ((hasattr(self.args, "mackrl_always_delegate") and self.args.mackrl_always_delegate) or \
                (hasattr(self.args, "mackrl_delegate_if_zero_ck") and self.args.mackrl_delegate_if_zero_ck)):
            # --- Calculate corrective delegation (when unavailable actions
            # are selected): the level-2 mass on joint actions that turn out
            # (partially) unavailable gets redistributed via resampling from
            # the independent level-3 policies. ---
            aa_a, _, _ = _to_batch(avail_actions_level3[_a:_a+1].clone(), tformat_level3)
            aa_b, _, _ = _to_batch(avail_actions_level3[_b:_b+1].clone(), tformat_level3)
            paa, params_paa, tformat_paa = _to_batch(pairwise_avail_actions[_i:_i+1].clone(), tformat_level3)
            _pi_a, _, _ = _to_batch(pi_a, tformat_level3)
            _pi_b, _, _ = _to_batch(pi_b, tformat_level3)

            # Joint actions allowed pairwise but with at least one component
            # unavailable individually, weighted by their level-2 probability.
            diff_matrix = th.relu(paa[:, 1:-1].view(-1, self.n_actions, self.n_actions)
                                  - th.bmm(th.unsqueeze(aa_a[:, :-1], 2), th.unsqueeze(aa_b[:, :-1], 1)))
            diff_matrix = diff_matrix * _p_ab.view(-1, self.n_actions, self.n_actions)

            # Joint actions where BOTH components are individually unavailable.
            both_unavailable = th.relu(paa[:, 1:-1].view(-1, self.n_actions, self.n_actions)
                                       - th.add(th.unsqueeze(aa_a[:, :-1], 2), th.unsqueeze(aa_b[:, :-1], 1)))
            both_unavailable = both_unavailable * _p_ab.view(-1, self.n_actions, self.n_actions)
            both_unavailable_weight = th.sum(both_unavailable.view(-1, self.n_actions*self.n_actions), -1, keepdim=True)

            # If neither component of the joint action is available, both get
            # resampled with the probability of the independent actors.
            correction = both_unavailable_weight \
                         * pi_actions_selected[_a:_a + 1].view(-1,1) \
                         * pi_actions_selected[_b:_b + 1].view(-1,1)
            act_a = actions_masked[_a:_a + 1].clone().view(-1, 1,1).long()
            act_b = actions_masked[_b:_b + 1].clone().view(-1, 1,1).long()
            # b resamples: a's selected action fixed along dim 1, sum over b's options
            b_resamples = th.sum(th.gather(diff_matrix, 1, act_a.repeat([1,1, self.n_actions])),-1)
            b_resamples = b_resamples * pi_actions_selected[_b:_b + 1].view(-1,1)
            # a resamples: b's selected action fixed along dim 2, sum over a's options
            a_resamples = th.sum(th.gather(diff_matrix, 2, act_b.repeat([1, self.n_actions,1])), 1)
            a_resamples = a_resamples * pi_actions_selected[_a:_a + 1].view(-1,1)
            correction = correction + b_resamples + a_resamples
            pi_corr_list.append(_from_batch(correction, params_paa, tformat_paa))

    # Combine the per-pair terms: delegate * independent + committed joint.
    pi_a_cross_pi_b = th.cat(pi_a_cross_pi_b_list, dim=0)
    pi_ab_selected = th.cat(pi_ab_list, dim=0)
    p_a_b = p_d * pi_a_cross_pi_b + pi_ab_selected

    agent_pairs = self.ordered_agent_pairings
    if not ( (hasattr(self.args, "mackrl_always_delegate") and self.args.mackrl_always_delegate) or \
             (hasattr(self.args, "mackrl_delegate_if_zero_ck") and self.args.mackrl_delegate_if_zero_ck) ):
        pi_corr = th.cat(pi_corr_list, dim=0)
        p_a_b += pi_corr

    # For each full pairing of agents into pairs, multiply the pair
    # probabilities with the independent probabilities of all unpaired agents.
    pi_c_prod_list = []
    p_a_b_pair_list = []
    for pairs in self.ordered_2_agent_pairings:
        p_a_b_pair_list.append(reduce(operator.mul, [p_a_b[pair] for pair in pairs]))
        agents_in_pairs = list(itertools.chain.from_iterable([agent_pairs[pair] for pair in pairs]))
        # calculate pi_c_prod: product of independent probabilities over
        # agents NOT covered by any pair (paired agents are set to 1.0)
        _pi_actions_selected = pi_actions_selected.clone()
        _pi_actions_selected[pi_actions_selected_mask] = 0.0
        for _agent in agents_in_pairs:
            _pi_actions_selected[_agent:_agent + 1] = 1.0
        _k = th.prod(_pi_actions_selected, dim=_adim(tformat_level3), keepdim=True)
        pi_c_prod_list.append(_k)

    pi_ab_pair = th.stack(p_a_b_pair_list, dim=0)
    pi_c_prod = th.cat(pi_c_prod_list, dim=0)

    # next, calculate p_a_b * prod(p, -a-b)
    p_prod = pi_ab_pair * pi_c_prod

    # now, calculate p_a_b_c: expectation over the level-1 pairing distribution
    # (NaNs in out_level1 are zeroed so they do not poison the sum)
    _tmp = out_level1.transpose(_adim(tformat_level1), _vdim(tformat_level1))
    _tmp[_tmp!=_tmp] = 0.0
    p_a_b_c = (p_prod * _tmp).sum(dim=_adim(tformat_level1), keepdim=True)

    # --- debug-only: compare empirical action frequencies against predicted
    # joint probabilities ---
    if self.args.debug_mode in ["check_probs"]:
        if not hasattr(self, "action_table"):
            self.action_table = {}
        if not hasattr(self, "actions_sampled"):
            self.actions_sampled = 0
        actions_flat = actions.view(self.n_agents, -1)
        for id in range(actions_flat.shape[1]):
            act = tuple(actions_flat[:, id].tolist())
            if act in self.action_table:
                self.action_table[act] += 1
            else:
                self.action_table[act] = 0
        self.actions_sampled += actions_flat.shape[0]

    if self.args.debug_mode in ["check_probs"]:
        actions_flat = actions.view(self.n_agents, -1)
        for id in range(actions_flat.shape[1]):
            print("sampled: ", self.action_table[tuple(actions_flat[:, id].tolist())] / self.actions_sampled,
                  " pred: ", p_a_b_c.view(-1)[id])

    hidden_states = {"level1": hidden_states_level1,
                     "level2": hidden_states_level2,
                     "level3": hidden_states_level3
                     }

    loss = loss_fn(policies=p_a_b_c, tformat=tformat_level3)

    return p_a_b_c, hidden_states, loss, tformat_level3  # note: policies can have NaNs in it!!
def forward(self, inputs, hidden_states, tformat, loss_fn=None, **kwargs):
    """Recurrent level-3 agent forward pass (GRU unroll over the sequence).

    Per step: encode, GRU-update the hidden state, mask unavailable actions
    with a large negative logit offset, softmax, and optionally mix in
    epsilon (softmax) exploration restricted to available actions.

    kwargs must contain "test_mode" and "seq_lens".
    Returns (policies, hidden_states, loss, tformat) concatenated over time.
    """
    seq_lens = kwargs["seq_lens"]

    # Best-effort input sanitation: on validation failure, zero out NaNs and
    # continue rather than aborting the rollout.
    try:
        _check_inputs_validity(inputs, self.input_shapes, tformat)
    except Exception as e:
        print("Exception {} - have replaced NaNs with zeros".format(e))
        for _k, _v in inputs.items():
            inputs[_k][inputs[_k]!=inputs[_k]] = 0.0
        pass

    test_mode = kwargs["test_mode"]

    _inputs = inputs["main"]
    # pad availability to full sequence length with 1.0 (all available)
    _inputs_aa = _pad(inputs["avail_actions"], tformat, seq_lens, 1.0)

    loss = None
    t_dim = _tdim(tformat)
    assert t_dim == 2, "t_dim along unsupported axis"
    t_len = _inputs.shape[t_dim]

    x_list = []
    h_list = [hidden_states]

    for t in range(t_len):
        # slice out time step t (keep the time axis, length 1)
        x = _inputs[:, :, slice(t, t + 1), :].contiguous()
        avail_actions = _inputs_aa[:, :, slice(t, t + 1), :].contiguous()
        x, tformat = self.encoder({"main":x}, tformat)
        x, params_x, tformat_x = _to_batch(x, tformat)
        avail_actions, params_aa, tformat_aa = _to_batch(avail_actions, tformat)
        h, params_h, tformat_h = _to_batch(h_list[-1], tformat)
        h = self.gru(x, h)
        x = self.output(h)

        # mask policy elements corresponding to unavailable actions: push
        # their logits to -1e30 so softmax assigns them ~zero probability
        n_available_actions = avail_actions.sum(dim=1, keepdim=True)
        x = x - (1 - avail_actions) * 1e30
        x = F.softmax(x, 1)

        # add softmax exploration (if switched on)
        if self.args.mackrl_exploration_mode_level3 in ["softmax"] and not test_mode:
            epsilons = inputs["epsilons_central_level3"].unsqueeze(_tdim(tformat)).detach()
            epsilons, _, _ = _to_batch(epsilons, tformat)
            # NOTE(review): unlike the level-2 variant, n_available_actions is
            # not guarded against zero here — division by 0 is possible if a
            # row has no available actions; confirm this cannot occur upstream.
            x = avail_actions.detach() * epsilons / n_available_actions + x * (1 - epsilons)

        h = _from_batch(h, params_h, tformat_h)
        x = _from_batch(x, params_x, tformat_x)
        h_list.append(h)
        x_list.append(x)

    if loss_fn is not None:
        _x = th.cat(x_list, dim=_tdim(tformat))
        loss = loss_fn(_x, tformat=tformat)[0]

    return th.cat(x_list, t_dim), \
           th.cat(h_list[1:], t_dim), \
           loss, \
           tformat
def forward(self, inputs, hidden_states, tformat, loss_fn=None, **kwargs):
    """Recurrent level-2 (pair-controller) forward pass.

    Per step: encode, GRU-update, optionally bias the "delegate" logit
    (index 0), mask unavailable pairwise actions, softmax, and optionally mix
    in epsilon exploration. If args.mackrl_always_delegate is set, the output
    is overwritten to always select delegation.

    kwargs must contain "test_mode", "seq_lens" and "pairwise_avail_actions".
    Returns (policies, hidden_states, loss, tformat) concatenated over time.
    """
    seq_lens = kwargs["seq_lens"]

    # Best-effort input sanitation: on validation failure, zero out NaNs and
    # continue rather than aborting the rollout.
    try:
        _check_inputs_validity(inputs, self.input_shapes, tformat)
    except Exception as e:
        print("Exception {} - have replaced NaNs with zeros".format(e))
        for _k, _v in inputs.items():
            inputs[_k][inputs[_k]!=inputs[_k]] = 0.0
        pass

    test_mode = kwargs["test_mode"]

    # pad pairwise availability to full sequence length with 1.0 and make
    # sure no gradient flows through the mask
    pairwise_avail_actions = _pad(kwargs["pairwise_avail_actions"].detach(), tformat, seq_lens, 1.0)
    pairwise_avail_actions.requires_grad = False

    _inputs = inputs["main"]

    loss = None
    t_dim = _tdim(tformat)
    assert t_dim == 2, "t_dim along unsupported axis"
    t_len = _inputs.shape[t_dim]

    x_list = []
    h_list = [hidden_states]

    for t in range(t_len):
        # slice out time step t (keep the time axis, length 1)
        x = _inputs[:, :, slice(t, t + 1), :].contiguous()
        avail_actions = pairwise_avail_actions[:, :, slice(t, t + 1), :].contiguous().detach()
        x, tformat = self.encoder({"main":x}, tformat)
        x, params_x, tformat_x = _to_batch(x, tformat)
        avail_actions, params_aa, tformat_aa = _to_batch(avail_actions, tformat)
        h, params_h, tformat_h = _to_batch(h_list[-1], tformat)
        h = self.gru(x, h)
        x = self.output(h)

        # optional constant bias on the delegation logit (column 0)
        if getattr(self.args, "mackrl_logit_bias", 0.0) != 0.0:
            x = th.cat([x[:, 0:1] + self.args.mackrl_logit_bias, x[:, 1:]], dim=1)

        # mask unavailable actions via a large negative logit offset, then softmax
        n_available_actions = avail_actions.sum(dim=1, keepdim=True)
        x = x - (1 - avail_actions) * 1e30
        x = F.softmax(x, 1)

        # add softmax exploration (if switched on)
        if self.args.mackrl_exploration_mode_level2 in ["softmax"] and not test_mode:
            epsilons = inputs["epsilons_central_level2"].unsqueeze(_tdim(tformat)).detach()
            epsilons, _, _ = _to_batch(epsilons, tformat)
            # guard against division by zero when no action is available
            n_available_actions[n_available_actions==0.0] = 1.0
            x = avail_actions.detach() * epsilons / n_available_actions + x * (1 - epsilons)

        h = _from_batch(h, params_h, tformat_h)
        x = _from_batch(x, params_x, tformat_x)
        h_list.append(h)
        x_list.append(x)

    x_cat = th.cat(x_list, t_dim)

    # force pure delegation if configured: all mass on action 0
    if hasattr(self.args, "mackrl_always_delegate") and self.args.mackrl_always_delegate:
        x_cat[:, :, :, 0] = 1.0
        x_cat[:, :, :, 1:] = 0.0

    if loss_fn is not None:
        loss = loss_fn(x_cat, tformat=tformat)[0]

    return x_cat, \
           th.cat(h_list[1:], t_dim), \
           loss, \
           tformat
def forward(self, inputs, hidden_states, tformat, loss_fn=None, **kwargs):
    """Recurrent level-1 (pairing-selection) forward pass.

    Per step: encode, GRU-update, softmax over pairing choices (no
    availability masking at this level), optional uniform-mixture
    exploration, and an optional debug override that pins the first pairing.

    kwargs must contain "test_mode" and "n_agents".
    Returns (policies, hidden_states, loss, tformat) concatenated over time.
    """
    # Best-effort input sanitation: on validation failure, zero out NaNs and
    # continue rather than aborting the rollout.
    try:
        _check_inputs_validity(inputs, self.input_shapes, tformat)
    except Exception as e:
        print("Exception {} - have replaced NaNs with zeros".format(e))
        for _k, _v in inputs.items():
            inputs[_k][inputs[_k]!=inputs[_k]] = 0.0
        pass

    test_mode = kwargs["test_mode"]
    n_agents = kwargs["n_agents"]

    if len(inputs["main"].shape) == 3:
        _inputs = inputs["main"].unsqueeze(0) # as agent dimension is lacking
    else:
        _inputs = inputs["main"]

    loss = None
    t_dim = _tdim(tformat)
    assert t_dim == 2, "t_dim along unsupported axis"
    t_len = _inputs.shape[t_dim]

    x_list = []
    h_list = [hidden_states]

    for t in range(t_len):
        # slice out time step t (keep the time axis, length 1)
        x = _inputs[:, :, slice(t, t + 1), :].contiguous()
        x, tformat = self.encoder({"main":x}, tformat)
        x, params_x, tformat_x = _to_batch(x, tformat)
        h, params_h, tformat_h = _to_batch(h_list[-1], tformat)
        h = self.gru(x, h)
        x = self.output(h)

        # plain softmax: level 1 has no per-action availability mask
        x = F.softmax(x, 1)

        # uniform-mixture exploration (epsilons are defined in "bs*t*v" format)
        if self.args.mackrl_exploration_mode_level1 in ["softmax"] and not test_mode:
            epsilons = inputs["epsilons_central_level1"].unsqueeze(_tdim("bs*t*v")).detach()
            epsilons, _, _ = _to_batch(epsilons, "bs*t*v")
            x = epsilons / x.shape[-1] + x * (1 - epsilons)

        # debug override: deterministically pick the first pairing
        if hasattr(self.args, "mackrl_fix_level1_pair") and self.args.mackrl_fix_level1_pair:
            x[:,1:] = 0.0
            x[:, 0] = 1.0

        h = _from_batch(h, params_h, tformat_h)
        x = _from_batch(x, params_x, tformat_x)
        h_list.append(h)
        x_list.append(x)

    if loss_fn is not None:
        _x = th.cat(x_list, dim=_tdim(tformat))
        loss = loss_fn(_x, tformat=tformat)[0]

    return th.cat(x_list, t_dim), \
           th.cat(h_list[1:], t_dim), \
           loss, \
           tformat