def train_mine_policy(scenario: Scenario, horizon: int, batch_size: int,
                      epochs: int, ntrvs: int, mine_class: nn.Module,
                      mine_params, q_net: nn.Module, pi_net: nn.Module,
                      tradeoff: float, lr: float, tag: str = None,
                      save_every: int = 100,
                      log_video_every: Union[int, None] = None,
                      minibatch_size: int = 0, opt_iters: int = 1,
                      lowest_mi: float = np.inf, cutoff: float = np.inf,
                      device=pt.device('cpu')):
    """Jointly train an encoder (``q_net``) and a policy (``pi_net``) while
    penalizing the mutual information (MI) between state and the task-relevant
    variables (trvs), estimated per time step with MINE networks.

    Each epoch: roll out the current policy on CPU, re-encode the trvs on
    ``device``, retrain one MINE estimator per time step (on CUDA), evaluate
    the Donsker-Varadhan MI bound, then take ``opt_iters`` REINFORCE-style
    gradient steps on ``pi_net``/``q_net`` using the mean episode cost as a
    baseline plus ``tradeoff * MI`` as a regularizer.

    Args:
        scenario: environment/rollout description; its ``device`` attribute is
            forced to CPU here because rollouts run on CPU.
        horizon: number of time steps per trajectory (one MINE net per step).
        batch_size: trajectories per rollout.
        epochs: training epochs.
        ntrvs: dimensionality of the task-relevant variable vector.
        mine_class: zero-arg constructor for a MINE statistics network.
        mine_params: dict with key ``'epochs'`` controlling MINE inner training
            (x100 on the first epoch to warm-start the estimators).
        q_net: encoder; must be callable and expose ``log_prob``.
        pi_net: policy; must expose ``log_prob``.
        tradeoff: MI penalty weight; MI estimation is skipped if ``<= -1``
            (the penalty term is then identically zero).
        lr: Adam learning rate for the joint pi/q optimizer.
        tag: TensorBoard run name; also used in checkpoint filenames.
            Logging is skipped when None.
        save_every: unused (kept for interface compatibility).
        log_video_every: if set, log best/worst trajectory videos every N
            epochs (assumes outputs reshape to 3x64x64 images — TODO confirm).
        minibatch_size: trajectories per policy-gradient step; 0 means use
            the full batch.
        opt_iters: policy-gradient steps per epoch.
        lowest_mi: running best MI; checkpoints are saved when beaten.
        cutoff: only checkpoint when the mean cost is below this value.
        device: device for encoding and the policy-gradient computation.

    Returns:
        The lowest MI sum achieved (float), after the final epoch.
        (Implicitly returns None when ``epochs == 0``.)
    """
    q_net.to(device=device)
    pi_net.to(device=device)
    opt = pt.optim.Adam(list(pi_net.parameters()) + list(q_net.parameters()),
                        lr=lr)
    # One MINE estimator per time step: MI(state_t ; trv_t).
    mine = [mine_class().to(device=device) for _ in range(horizon)]
    mi = pt.zeros(horizon).to(device=device)
    scenario.device = pt.device('cpu')  # rollouts always happen on CPU
    prev_best_value = np.inf

    if minibatch_size == 0:
        minibatch_size = batch_size

    if tag is not None:
        writer = SummaryWriter(f'runs/{tag}', flush_secs=1)

    for epoch in range(epochs):
        # CUDA events give wall-clock timing of the rollout and full epoch.
        start_epoch_event = pt.cuda.Event(enable_timing=True)
        end_epoch_event = pt.cuda.Event(enable_timing=True)
        end_rollout_event = pt.cuda.Event(enable_timing=True)
        start_epoch_event.record()

        pi_log_probs = pt.zeros((horizon, minibatch_size), device=device)
        q_log_probs = pt.zeros((horizon, minibatch_size), device=device)

        # Rollout on CPU, then move everything (nets included) to `device`.
        q_net.cpu()
        pi_net.cpu()
        states, outputs, samples, trvs, inputs, costs = rollout(
            pi_net, q_net, ntrvs, scenario, horizon, batch_size,
            pt.device('cpu'))
        end_rollout_event.record()
        pt.cuda.synchronize()
        elapsed_rollout_time = start_epoch_event.elapsed_time(
            end_rollout_event) / 1000
        print(f'Rollout Time: {elapsed_rollout_time:.3f}')
        print(f'Mean Abs. Displacement: {pt.abs(states[0, -1, :] - states[1, -1, :]).mean().detach().item()}')

        states = states.to(device)
        outputs = outputs.to(device)
        samples = samples.to(device)
        trvs = trvs.to(device)
        inputs = inputs.to(device)
        costs = costs.to(device)
        q_net.to(device)
        pi_net.to(device)

        # Re-encode the trvs on `device` so they carry gradients w.r.t. q_net;
        # the trv sequence is generated recurrently from a zero initial trv.
        for s in range(batch_size):
            trv = pt.zeros(ntrvs, device=device)
            for t in range(horizon):
                trvs[:, t, s] = q_net(outputs[:, t, s], trv, t,
                                      samples[:, t, s])[0]
                trv = trvs[:, t, s]

        value = costs.sum(axis=0).mean().item()

        if tradeoff > -1:
            # MINE training runs on CUDA regardless of `device`.
            states_mi = states.detach().cuda()
            trvs_mi = trvs.detach().cuda()

            for t in range(horizon):
                mine[t].cuda()
                # Warm-start: train the estimators much longer on epoch 0.
                inner_epochs = (100 * mine_params['epochs'] if epoch == 0
                                else mine_params['epochs'])
                train_mine_network(mine[t],
                                   (states_mi[:, t, :], trvs_mi[:, t, :]),
                                   epochs=inner_epochs)

            # Donsker-Varadhan bound per step:
            #   MI >= E_joint[T] - log E_marginal[exp(T)].
            # Marginal samples pair independently-shuffled states and trvs.
            # NOTE: do NOT overwrite `batch_size` here (the original did);
            # states.shape[2] is the trajectory/batch dimension.
            for t in range(horizon):
                num_datapts = states.shape[2]
                joint_batch_idx = np.random.choice(range(num_datapts),
                                                   size=num_datapts,
                                                   replace=False)
                marginal_batch_idx1 = np.random.choice(range(num_datapts),
                                                       size=num_datapts,
                                                       replace=False)
                marginal_batch_idx2 = np.random.choice(range(num_datapts),
                                                       size=num_datapts,
                                                       replace=False)

                joint_batch = pt.cat((states[:, t, joint_batch_idx],
                                      trvs[:, t, joint_batch_idx]),
                                     axis=0).t()
                marginal_batch = pt.cat((states[:, t, marginal_batch_idx1],
                                         trvs[:, t, marginal_batch_idx2]),
                                        axis=0).t()

                j_T = mine[t](joint_batch)
                m_T = mine[t](marginal_batch)
                mi[t] = j_T.mean() - pt.log(pt.mean(pt.exp(m_T)))

        mi_sum = mi.sum()
        baseline = costs.sum(axis=0).mean()  # variance-reduction baseline

        # Checkpoint whenever cost is acceptable and MI improves.
        if value < cutoff and mi_sum < lowest_mi:
            print('Saving Model...')
            lowest_mi = mi_sum.item()
            pt.save(
                {
                    'pi_net_state_dict': pi_net.state_dict(),
                    'q_net_state_dict': q_net.state_dict()
                }, f'models/{tag}_epoch_{epoch}_mi_{lowest_mi:.3f}')
        else:
            print(f'Current Best: {prev_best_value}')

        for it in range(opt_iters):
            print(f'Computing Iteration {it}')
            minibatch_idx = np.random.choice(range(batch_size),
                                             size=minibatch_size,
                                             replace=False)
            outputs_minibatch = outputs[:, :, minibatch_idx]
            trvs_minibatch = trvs[:, :, minibatch_idx]
            inputs_minibatch = inputs[:, :, minibatch_idx]
            costs_minibatch = costs[:, minibatch_idx]

            for s in range(minibatch_size):
                trv = pt.zeros(ntrvs, device=device)
                for t in range(horizon):
                    # FIX: score the trv realized for THIS minibatch sample.
                    # The original passed trvs[:, t, s] (full-batch index),
                    # mismatching outputs_minibatch/inputs_minibatch below.
                    q_log_probs[t, s] = q_net.log_prob(
                        trvs_minibatch[:, t, s].detach(),
                        outputs_minibatch[:, t, s], trv.detach(), t)
                    pi_log_probs[t, s] = pi_net.log_prob(
                        inputs_minibatch[:, t, s].detach(),
                        trvs_minibatch[:, t, s].detach(), t)
                    trv = trvs_minibatch[:, t, s]

            # REINFORCE-style surrogate for both nets plus the MI penalty.
            opt.zero_grad()
            advantage = costs_minibatch.sum(axis=0) - baseline
            loss = pt.mul(pi_log_probs.sum(axis=0), advantage).mean() + \
                pt.mul(q_log_probs.sum(axis=0), advantage).mean() + \
                tradeoff * mi_sum
            loss.backward()
            opt.step()

            # Drop graph history before the next optimization iteration.
            # FIX: the original detached pi_log_probs into q_log_probs (typo).
            pi_log_probs = pi_log_probs.detach()
            q_log_probs = q_log_probs.detach()

        if tag is not None:
            writer.add_scalar('Loss/Total',
                              value + tradeoff * mi.sum().item(), epoch)
            writer.add_scalar('Loss/MI', mi_sum, epoch)
            writer.add_scalar('Loss/Cost', value, epoch)
            writer.add_histogram('Loss/Cost Dist', costs.sum(axis=0), epoch)

            if log_video_every is not None and epoch % log_video_every == 0:
                print('Saving Video...')
                best_traj_idx = pt.argmin(costs.sum(axis=0))
                worst_traj_idx = pt.argmax(costs.sum(axis=0))
                # assumes outputs[:, t, i] reshapes to a 3x64x64 RGB frame
                best_traj_vid = pt.stack([
                    pt.stack([
                        outputs[:, t, best_traj_idx].view(3, 64, 64)
                        for t in range(horizon)
                    ])
                ])
                worst_traj_vid = pt.stack([
                    pt.stack([
                        outputs[:, t, worst_traj_idx].view(3, 64, 64)
                        for t in range(horizon)
                    ])
                ])
                writer.add_video('Loss/Worst Traj', worst_traj_vid, epoch)
                writer.add_video('Loss/Best Traj', best_traj_vid, epoch)

        mi = mi.detach()
        end_epoch_event.record()
        pt.cuda.synchronize()
        elapsed_epoch_time = start_epoch_event.elapsed_time(
            end_epoch_event) / 1000
        print(f'[{tradeoff}.{epoch}: {elapsed_epoch_time:.3f}]\t\tAvg. Cost: {value:.3f}\t\tEst. MI: {mi_sum.item():.5f}\t\tTotal: {value + tradeoff * mi_sum.item():.3f}\t\t Lowest MI: {lowest_mi:.3f}')

        # Preserve original semantics: return inside the loop on the last
        # epoch (so epochs == 0 still returns None).
        if epoch == epochs - 1:
            return lowest_mi