def test_inconsistent_rank_inputs_for_importance_weights(self):
    """Check that a mis-ranked `bootstrap_value` is rejected.

    This covers just one of the many possible input-shape errors: every
    other input is declared with rank 3 while `bootstrap_value` is declared
    with rank 1, and building the V-trace graph is expected to raise.
    """

    def float_input(*dims):
        # Float32 placeholder whose static shape is the given dimensions.
        return tf.placeholder(dtype=tf.float32, shape=list(dims))

    inputs = dict(
        log_rhos=float_input(None, None, 1),
        discounts=float_input(None, None, 1),
        rewards=float_input(None, None, 42),
        values=float_input(None, None, 42),
        # Should be [None, 42].
        bootstrap_value=float_input(None),
    )

    with self.assertRaisesRegexp(ValueError, 'must have rank 2'):
        vtrace.from_importance_weights(**inputs)
def test_vtrace(self, batch_size):
    """Compare V-trace outputs with the reference values computed in Python."""
    seq_len = 5
    num_cells = batch_size * seq_len

    # Spread log_rhos over [-2.5, 2.5) so that rho covers roughly
    # [0.08, 12.2): from near zero up to above both clipping thresholds.
    unit_ramp = _shaped_arange(seq_len, batch_size) / num_cells  # [0.0, 1.0)
    log_rhos = 5 * (unit_ramp - 0.5)  # [0.0, 1.0) -> [-2.5, 2.5).

    # Shape [T, B]: column i holds 0.9 / (i + 1) at every time step.
    discount_row = [0.9 / (b + 1) for b in range(batch_size)]
    discounts = np.array([discount_row for _ in range(seq_len)])

    values = {
        "log_rhos": log_rhos,
        "discounts": discounts,
        "rewards": _shaped_arange(seq_len, batch_size),
        "values": _shaped_arange(seq_len, batch_size) / batch_size,
        "bootstrap_value": _shaped_arange(batch_size) + 1.0,
        "clip_rho_threshold": 3.7,
        "clip_pg_rho_threshold": 2.2,
    }

    vtrace_op = vtrace.from_importance_weights(**values)
    with self.test_session() as session:
        output_v = session.run(vtrace_op)

    ground_truth_v = _ground_truth_calculation(**values)
    for expected, actual in zip(ground_truth_v, output_v):
        self.assertAllClose(expected, actual)
def test_vtrace_from_logits(self, batch_size=2):
    """Tests V-trace calculated from logits.

    `vtrace.from_logits` must agree with `vtrace.from_importance_weights`
    when the latter is fed log-rhos derived from the same logits and
    actions via `vtrace.action_log_probs`.

    Args:
        batch_size: Number of columns (B) in the generated [T, B] inputs.
    """
    seq_len = 5
    num_actions = 3
    clip_rho_threshold = None  # No clipping.
    clip_pg_rho_threshold = None  # No clipping.

    values = {
        "behavior_policy_logits": _shaped_arange(
            seq_len, batch_size, num_actions),
        "target_policy_logits": _shaped_arange(
            seq_len, batch_size, num_actions),
        # `high` is exclusive in np.random.randint, so pass `num_actions`
        # (not `num_actions - 1`) to be able to sample every action index.
        "actions": np.random.randint(
            0, num_actions, size=(seq_len, batch_size)),
        "discounts": np.array(  # T, B where B_i: [0.9 / (i+1)] * T
            [[0.9 / (b + 1) for b in range(batch_size)]
             for _ in range(seq_len)],
            dtype=np.float32,
        ),
        "rewards": _shaped_arange(seq_len, batch_size),
        "values": _shaped_arange(seq_len, batch_size) / batch_size,
        "bootstrap_value": _shaped_arange(batch_size) + 1.0,  # B
    }
    values = {k: torch.from_numpy(v) for k, v in values.items()}

    from_logits_output = vtrace.from_logits(
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
        **values,
    )

    target_log_probs = vtrace.action_log_probs(
        values["target_policy_logits"], values["actions"])
    behavior_log_probs = vtrace.action_log_probs(
        values["behavior_policy_logits"], values["actions"])
    log_rhos = target_log_probs - behavior_log_probs

    # Calculate V-trace using the ground truth logits.
    from_iw = vtrace.from_importance_weights(
        log_rhos=log_rhos,
        discounts=values["discounts"],
        rewards=values["rewards"],
        values=values["values"],
        bootstrap_value=values["bootstrap_value"],
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
    )

    assert_allclose(from_iw.vs, from_logits_output.vs)
    assert_allclose(from_iw.pg_advantages, from_logits_output.pg_advantages)
    assert_allclose(behavior_log_probs,
                    from_logits_output.behavior_action_log_probs)
    assert_allclose(target_log_probs,
                    from_logits_output.target_action_log_probs)
    assert_allclose(log_rhos, from_logits_output.log_rhos)
def test_vtrace_from_iw():
    """Test V-trace from importance weights against the reference values."""
    batch_size = 1  # 2
    seq_len = 5

    unit_ramp = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len)
    log_rhos = 5 * (unit_ramp - 0.5)  # [0.0, 1.0) -> [-2.5, 2.5).

    # Shape [T, B]: column i holds 0.9 / (i + 1) at every time step.
    discount_row = [0.9 / (b + 1) for b in range(batch_size)]
    discounts = np.array([discount_row for _ in range(seq_len)])

    values = {
        'log_rhos': log_rhos,
        'discounts': discounts,
        'rewards': _shaped_arange(seq_len, batch_size),
        'values': _shaped_arange(seq_len, batch_size) / batch_size,
        'bootstrap_value': _shaped_arange(batch_size) + 1.0,
        'clip_rho_threshold': 3.7,
        'clip_pg_rho_threshold': 2.2,
    }

    output = vtrace.from_importance_weights(**values)
    ground_truth = _ground_truth_calculation(**values)

    for expected, actual in zip(ground_truth, output):
        assert np.allclose(expected, actual.data.tolist())
def test_inconsistent_rank_inputs_for_importance_weights(self):
    """Check that a mis-ranked `bootstrap_value` raises.

    Only one of the many possible shape errors is exercised: all other
    inputs are rank 3 while `bootstrap_value` is rank 1.
    """
    T = 3  # pylint: disable=invalid-name
    B = 2  # pylint: disable=invalid-name

    inputs = dict(
        log_rhos=torch.zeros(T, B, 1),
        discounts=torch.zeros(T, B, 1),
        rewards=torch.zeros(T, B, 42),
        values=torch.zeros(T, B, 42),
        # Should be [B, 42].
        bootstrap_value=torch.zeros(B),
    )

    expected_message = 'same number of dimensions: got 3 and 2'
    with self.assertRaisesRegex(RuntimeError, expected_message):
        vtrace.from_importance_weights(**inputs)
def test_higher_rank_inputs_for_importance_weights(self):
    """Check that inputs with an extra trailing dimension are accepted.

    The trailing dimension of the resulting `vs` must be the 42 declared on
    `rewards`, `values` and `bootstrap_value`.
    """

    def float_input(*dims):
        # Float32 placeholder whose static shape is the given dimensions.
        return tf.placeholder(dtype=tf.float32, shape=list(dims))

    inputs = dict(
        log_rhos=float_input(None, None, 1),
        discounts=float_input(None, None, 1),
        rewards=float_input(None, None, 42),
        values=float_input(None, None, 42),
        bootstrap_value=float_input(None, 42),
    )

    output = vtrace.from_importance_weights(**inputs)
    last_dim = output.vs.shape.as_list()[-1]
    self.assertEqual(last_dim, 42)
def __init__(self,
             behaviour_actions_log_probs,
             target_actions_log_probs,
             policy_entropy,
             dones,
             discount,
             rewards,
             values,
             bootstrap_value,
             entropy_coeff=-0.01,
             vf_loss_coeff=0.5,
             clip_rho_threshold=1.0,
             clip_pg_rho_threshold=1.0):
    """Build the V-trace importance-weighted policy gradient loss.

    VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
    batch_size. `B` has to be known so that V-trace can properly handle
    episode cut boundaries.

    The constructor stores `vtrace_returns`, `pi_loss`, `vf_loss`,
    `entropy` and `total_loss` on the instance.

    Args:
        behaviour_actions_log_probs: A float32 tensor of shape [T, B].
        target_actions_log_probs: A float32 tensor of shape [T, B].
        policy_entropy: A float32 tensor of shape [T, B].
        dones: A float32 tensor of shape [T, B].
        discount: A float32 scalar.
        rewards: A float32 tensor of shape [T, B].
        values: A float32 tensor of shape [T, B].
        bootstrap_value: A float32 tensor of shape [B].
        entropy_coeff: Weight of the summed entropy in `total_loss`.
            Entropy is to be maximised, hence the negative default.
        vf_loss_coeff: Weight of the baseline loss in `total_loss`.
        clip_rho_threshold: Forwarded unchanged to
            `from_importance_weights`.
        clip_pg_rho_threshold: Forwarded unchanged to
            `from_importance_weights`.
    """
    # Per-step discounts: `inverse(dones)` scaled by the scalar discount.
    step_discounts = inverse(dones) * discount

    self.vtrace_returns = from_importance_weights(
        behaviour_actions_log_probs=behaviour_actions_log_probs,
        target_actions_log_probs=target_actions_log_probs,
        discounts=step_discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    # Policy gradient loss: negated sum of log-prob * advantage.
    weighted_log_probs = (
        target_actions_log_probs * self.vtrace_returns.pg_advantages)
    self.pi_loss = -1.0 * layers.reduce_sum(weighted_log_probs)

    # Baseline loss: half the summed squared error against the V-trace
    # value targets.
    value_error = values - self.vtrace_returns.vs
    self.vf_loss = 0.5 * layers.reduce_sum(layers.square(value_error))

    # Entropy term (entropy is maximised, so entropy_coeff < 0).
    self.entropy = layers.reduce_sum(policy_entropy)

    # Weighted sum of the three terms.
    self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff +
                       self.entropy * entropy_coeff)
def test_higher_rank_inputs_for_importance_weights(self):
    """Check that inputs with an extra trailing dimension are accepted.

    The resulting `vs` must keep the full (T, B, 42) shape.
    """
    T = 3  # pylint: disable=invalid-name
    B = 2  # pylint: disable=invalid-name

    inputs = dict(
        log_rhos=torch.zeros(T, B, 1),
        discounts=torch.zeros(T, B, 1),
        rewards=torch.zeros(T, B, 42),
        values=torch.zeros(T, B, 42),
        bootstrap_value=torch.zeros(B, 42),
    )

    output = vtrace.from_importance_weights(**inputs)
    expected_shape = (T, B, 42)
    self.assertSequenceEqual(output.vs.shape, expected_shape)
def test_vtrace(self, batch_size=5):
    """Compare V-trace outputs with the reference values computed in Python."""
    seq_len = 5

    # Spread log_rhos over [-2.5, 2.5) so that rho covers roughly
    # [0.08, 12.2): from near zero up to above both clipping thresholds.
    unit_ramp = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len)
    log_rhos = 5 * (unit_ramp - 0.5)  # [0.0, 1.0) -> [-2.5, 2.5).

    # Shape [T, B]: column i holds 0.9 / (i + 1) at every time step.
    discount_row = [0.9 / (b + 1) for b in range(batch_size)]
    discounts = np.array(
        [discount_row for _ in range(seq_len)], dtype=np.float32)

    values = {
        'log_rhos': log_rhos,
        'discounts': discounts,
        'rewards': _shaped_arange(seq_len, batch_size),
        'values': _shaped_arange(seq_len, batch_size) / batch_size,
        'bootstrap_value': _shaped_arange(batch_size) + 1.0,
        'clip_rho_threshold': 3.7,
        'clip_pg_rho_threshold': 2.2,
    }

    # The reference is computed from the raw (non-tensor) inputs.
    ground_truth = _ground_truth_calculation(**values)

    tensor_inputs = {}
    for key, value in values.items():
        tensor_inputs[key] = torch.tensor(value)
    output = vtrace.from_importance_weights(**tensor_inputs)

    for expected, actual in zip(ground_truth, output):
        assert_allclose(expected, actual)
def test_vtrace(self, batch_size):
    """Check V-trace against ground truth data calculated in Python."""
    seq_len = 5

    def ramp():
        # Fresh [seq_len, batch_size] array from the shared arange helper.
        return _shaped_arange(seq_len, batch_size)

    # log_rhos is laid out over [-2.5, 2.5), which puts rho in approximately
    # [0.08, 12.2): near zero at one end, above the clipping thresholds at
    # the other.
    log_rhos = ramp() / (batch_size * seq_len)
    log_rhos = 5 * (log_rhos - 0.5)  # [0.0, 1.0) -> [-2.5, 2.5).

    # T, B where B_i: [0.9 / (i+1)] * T
    discounts = np.array(
        [[0.9 / (col + 1) for col in range(batch_size)]
         for _ in range(seq_len)])

    values = dict(
        log_rhos=log_rhos,
        discounts=discounts,
        rewards=ramp(),
        values=ramp() / batch_size,
        bootstrap_value=_shaped_arange(batch_size) + 1.0,
        clip_rho_threshold=3.7,
        clip_pg_rho_threshold=2.2,
    )

    output = vtrace.from_importance_weights(**values)
    with self.test_session() as session:
        output_v = session.run(output)

    ground_truth_v = _ground_truth_calculation(**values)
    for want, got in zip(ground_truth_v, output_v):
        self.assertAllClose(want, got)
def test_vtrace_from_logits(self, batch_size):
    """Tests V-trace calculated from logits.

    `vtrace.from_logits` must agree with `vtrace.from_importance_weights`
    when the latter is fed log-rhos derived from the same logits and
    actions via `vtrace.log_probs_from_logits_and_actions`.

    Args:
        batch_size: Number of columns (B) in the generated [T, B] inputs.
    """
    seq_len = 5
    num_actions = 3
    clip_rho_threshold = None  # No clipping.
    clip_pg_rho_threshold = None  # No clipping.
    dummy_config = {"model": None}

    # Intentionally leaving shapes unspecified to test if V-trace can
    # deal with that.
    placeholders = {
        # T, B, NUM_ACTIONS
        "behaviour_policy_logits": tf.placeholder(
            dtype=tf.float32, shape=[None, None, None]),
        # T, B, NUM_ACTIONS
        "target_policy_logits": tf.placeholder(
            dtype=tf.float32, shape=[None, None, None]),
        "actions": tf.placeholder(dtype=tf.int32, shape=[None, None]),
        "discounts": tf.placeholder(dtype=tf.float32, shape=[None, None]),
        "rewards": tf.placeholder(dtype=tf.float32, shape=[None, None]),
        "values": tf.placeholder(dtype=tf.float32, shape=[None, None]),
        "bootstrap_value": tf.placeholder(dtype=tf.float32, shape=[None]),
    }

    from_logits_output = vtrace.from_logits(
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
        config=dummy_config,
        **placeholders)

    target_log_probs = vtrace.log_probs_from_logits_and_actions(
        placeholders["target_policy_logits"], placeholders["actions"],
        dummy_config)
    behaviour_log_probs = vtrace.log_probs_from_logits_and_actions(
        placeholders["behaviour_policy_logits"], placeholders["actions"],
        dummy_config)
    log_rhos = target_log_probs - behaviour_log_probs
    ground_truth = (log_rhos, behaviour_log_probs, target_log_probs)

    values = {
        "behaviour_policy_logits": _shaped_arange(
            seq_len, batch_size, num_actions),
        "target_policy_logits": _shaped_arange(
            seq_len, batch_size, num_actions),
        # `high` is exclusive in np.random.randint, so pass `num_actions`
        # (not `num_actions - 1`) to be able to sample every action index.
        "actions": np.random.randint(
            0, num_actions, size=(seq_len, batch_size)),
        "discounts": np.array(  # T, B where B_i: [0.9 / (i+1)] * T
            [[0.9 / (b + 1) for b in range(batch_size)]
             for _ in range(seq_len)]),
        "rewards": _shaped_arange(seq_len, batch_size),
        "values": _shaped_arange(seq_len, batch_size) / batch_size,
        "bootstrap_value": _shaped_arange(batch_size) + 1.0,  # B
    }

    feed_dict = {placeholders[k]: v for k, v in values.items()}
    with self.test_session() as session:
        from_logits_output_v = session.run(
            from_logits_output, feed_dict=feed_dict)
        (ground_truth_log_rhos, ground_truth_behaviour_action_log_probs,
         ground_truth_target_action_log_probs) = session.run(
             ground_truth, feed_dict=feed_dict)

    # Calculate V-trace using the ground truth logits.
    from_iw = vtrace.from_importance_weights(
        log_rhos=ground_truth_log_rhos,
        discounts=values["discounts"],
        rewards=values["rewards"],
        values=values["values"],
        bootstrap_value=values["bootstrap_value"],
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    with self.test_session() as session:
        from_iw_v = session.run(from_iw)

    self.assertAllClose(from_iw_v.vs, from_logits_output_v.vs)
    self.assertAllClose(from_iw_v.pg_advantages,
                        from_logits_output_v.pg_advantages)
    self.assertAllClose(ground_truth_behaviour_action_log_probs,
                        from_logits_output_v.behaviour_action_log_probs)
    self.assertAllClose(ground_truth_target_action_log_probs,
                        from_logits_output_v.target_action_log_probs)
    self.assertAllClose(ground_truth_log_rhos,
                        from_logits_output_v.log_rhos)
def test_vtrace_from_logits(self, batch_size):
    """Tests V-trace calculated from logits.

    `vtrace.from_logits` must agree with `vtrace.from_importance_weights`
    when the latter is fed log-rhos derived from the same logits and
    actions via `vtrace.log_probs_from_logits_and_actions`.

    Args:
        batch_size: Number of columns (B) in the generated [T, B] inputs.
    """
    seq_len = 5
    num_actions = 3
    clip_rho_threshold = None  # No clipping.
    clip_pg_rho_threshold = None  # No clipping.

    # Intentionally leaving shapes unspecified to test if V-trace can
    # deal with that.
    placeholders = {
        # T, B, NUM_ACTIONS
        'behaviour_policy_logits':
            tf.placeholder(dtype=tf.float32, shape=[None, None, None]),
        # T, B, NUM_ACTIONS
        'target_policy_logits':
            tf.placeholder(dtype=tf.float32, shape=[None, None, None]),
        'actions': tf.placeholder(dtype=tf.int32, shape=[None, None]),
        'discounts': tf.placeholder(dtype=tf.float32, shape=[None, None]),
        'rewards': tf.placeholder(dtype=tf.float32, shape=[None, None]),
        'values': tf.placeholder(dtype=tf.float32, shape=[None, None]),
        'bootstrap_value': tf.placeholder(dtype=tf.float32, shape=[None]),
    }

    from_logits_output = vtrace.from_logits(
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
        **placeholders)

    target_log_probs = vtrace.log_probs_from_logits_and_actions(
        placeholders['target_policy_logits'], placeholders['actions'])
    behaviour_log_probs = vtrace.log_probs_from_logits_and_actions(
        placeholders['behaviour_policy_logits'], placeholders['actions'])
    log_rhos = target_log_probs - behaviour_log_probs
    ground_truth = (log_rhos, behaviour_log_probs, target_log_probs)

    values = {
        'behaviour_policy_logits':
            _shaped_arange(seq_len, batch_size, num_actions),
        'target_policy_logits':
            _shaped_arange(seq_len, batch_size, num_actions),
        # `high` is exclusive in np.random.randint, so pass `num_actions`
        # (not `num_actions - 1`) to be able to sample every action index.
        'actions':
            np.random.randint(0, num_actions, size=(seq_len, batch_size)),
        'discounts':
            np.array(  # T, B where B_i: [0.9 / (i+1)] * T
                [[0.9 / (b + 1) for b in range(batch_size)]
                 for _ in range(seq_len)]),
        'rewards': _shaped_arange(seq_len, batch_size),
        'values': _shaped_arange(seq_len, batch_size) / batch_size,
        'bootstrap_value': _shaped_arange(batch_size) + 1.0,  # B
    }

    feed_dict = {placeholders[k]: v for k, v in values.items()}
    with self.test_session() as session:
        from_logits_output_v = session.run(
            from_logits_output, feed_dict=feed_dict)
        (ground_truth_log_rhos, ground_truth_behaviour_action_log_probs,
         ground_truth_target_action_log_probs) = session.run(
             ground_truth, feed_dict=feed_dict)

    # Calculate V-trace using the ground truth logits.
    from_iw = vtrace.from_importance_weights(
        log_rhos=ground_truth_log_rhos,
        discounts=values['discounts'],
        rewards=values['rewards'],
        values=values['values'],
        bootstrap_value=values['bootstrap_value'],
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    with self.test_session() as session:
        from_iw_v = session.run(from_iw)

    self.assertAllClose(from_iw_v.vs, from_logits_output_v.vs)
    self.assertAllClose(from_iw_v.pg_advantages,
                        from_logits_output_v.pg_advantages)
    self.assertAllClose(ground_truth_behaviour_action_log_probs,
                        from_logits_output_v.behaviour_action_log_probs)
    self.assertAllClose(ground_truth_target_action_log_probs,
                        from_logits_output_v.target_action_log_probs)
    self.assertAllClose(ground_truth_log_rhos,
                        from_logits_output_v.log_rhos)
def main():
    """Run the learner: fetch batches, apply V-trace updates, publish weights.

    Builds the environment and the A2C network, waits for the operator to
    confirm that the actors are up, publishes the initial weights and then
    loops forever: request a batch from the buffer socket, compute the
    V-trace loss, step the optimizer, and periodically write TensorBoard
    summaries, save the best-scoring model and re-publish the weights.

    The loop has no exit condition of its own; the summary writer is closed
    when the loop is left through an exception (e.g. KeyboardInterrupt).
    """
    # Create the environment and the learner network.
    env = make_env(ENV_NAME)
    set_random_seed()
    device = get_device()
    net = A2C(env.observation_space.shape, env.action_space.n).to(device)
    net.apply(weights_init)
    writer = SummaryWriter(comment="-" + ENV_NAME)
    log(net)

    # Initialise ZMQ. `context` is unused below but stays bound for the
    # lifetime of the sockets.
    context, act_sock, buf_sock = init_zmq()

    # Wait for the operator before starting.
    log("Press Enter when the actors are ready: ")
    input()

    # Publish the initial model so that the actors can start.
    log("sending parameters to actors…")
    publish_model(net, act_sock)

    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    # optimizer = optim.RMSprop(net.parameters(),
    #                           lr=RMS_LR,
    #                           eps=RMS_EPS,
    #                           momentum=RMS_MOMENTUM)

    p_time = None
    step_idx = 1
    max_reward = -1000

    # Discount constants, shape [NUM_UNROLL, NUM_BATCH]; row i holds
    # GAMMA ** i.
    # NOTE(review): V-trace is usually fed a per-step discount rather than
    # cumulative powers -- confirm what from_importance_weights expects.
    discounts = np.array([pow(GAMMA, i) for i in range(NUM_UNROLL)])
    discounts = np.repeat(discounts, NUM_BATCH).reshape(NUM_UNROLL, NUM_BATCH)
    discounts_v = torch.Tensor(discounts).to(device)

    try:
        while True:
            # Ask the buffer for a training batch.
            log("request new batch {}.".format(step_idx))
            st = time.time()
            buf_sock.send(b'')
            payload = buf_sock.recv()
            log("receive batch elapse {:.2f}".format(time.time() - st))

            if payload == b'not enough':
                # The buffer cannot fill a batch yet.
                log("not enough data to batch.")
                time.sleep(1)
                continue

            # Train on the batch.
            step_idx += 1
            optimizer.zero_grad()

            # NOTE(security): pickle.loads can execute arbitrary code; the
            # buffer socket must only ever be connected to trusted peers.
            batch, ainfos, binfo = pickle.loads(payload)
            # The actor logits get their own name so that the learner's
            # prediction list below cannot shadow them.
            states, actor_logits_batch, actions, rewards, last_states = batch
            states_v = torch.Tensor(states).to(device)

            # One forward pass per batch entry.
            logits = []
            values = []
            bsvalues = []
            last_state_idx = []
            for bi in range(NUM_BATCH):
                # Predict with the learner's model.
                logit, value = net(states_v[bi])
                logits.append(logit)
                values.append(value.squeeze(1))
                if last_states[bi] is not None:
                    # Collect the last state's value for bootstrapping.
                    _, bsvalue = net(
                        torch.Tensor([last_states[bi]]).to(device))
                    bsvalues.append(bsvalue.squeeze(1))
                    last_state_idx.append(bi)

            # Importance-sampling ratios come from the log-probabilities of
            # the taken actions under the learner and the actor policies.
            learner_logits = torch.stack(logits).permute(1, 0, 2)
            learner_values = torch.stack(values).permute(1, 0)
            # Use the logits recorded by the actors, not the learner's own
            # predictions (which would make every log_rho exactly zero).
            # Assumes the batch logits are laid out [B, T, A] like the
            # stacked learner logits -- TODO confirm against the buffer.
            actor_logits = torch.Tensor(actor_logits_batch).to(device)
            actor_logits = actor_logits.permute(1, 0, 2)
            actor_actions = torch.LongTensor(actions).to(device).permute(1, 0)
            actor_rewards = torch.Tensor(rewards).to(device).permute(1, 0)
            bootstrap_value = torch.Tensor(bsvalues).to(device)

            learner_log_probs = log_probs_from_logits_and_actions(
                learner_logits, actor_actions)
            actor_log_probs = log_probs_from_logits_and_actions(
                actor_logits, actor_actions)
            log_rhos = learner_log_probs - actor_log_probs

            # V-trace targets from the importance weights.
            vtrace_ret = from_importance_weights(
                log_rhos=log_rhos,
                discounts=discounts_v,
                rewards=actor_rewards,
                values=learner_values,
                bootstrap_value=bootstrap_value,
                last_state_idx=last_state_idx)

            # Compute the losses and back-propagate.
            pg_loss, entropy_loss, baseline_loss, total_loss = \
                calc_loss_and_backprop(learner_logits, learner_values,
                                       actor_actions, vtrace_ret)
            # Flattened gradients, captured before clipping, for the
            # TensorBoard summary.
            grads = np.concatenate([
                p.grad.data.cpu().numpy().flatten()
                for p in net.parameters() if p.grad is not None
            ])

            # Gradient clipping.
            nn_utils.clip_grad_norm_(net.parameters(), CLIP_GRAD)
            optimizer.step()

            if step_idx % SHOW_FREQ == 0:
                # Write TensorBoard summaries (per frame).
                # frame_idx = step_idx * NUM_BATCH * NUM_UNROLL
                write_tb(writer, step_idx, vtrace_ret, learner_values,
                         entropy_loss, pg_loss, baseline_loss, total_loss,
                         grads, ainfos, binfo)

            # Save the model whenever the best actor reward improves.
            _max_reward = np.max([ainfo.reward for ainfo in ainfos.values()])
            if _max_reward > max_reward and step_idx % SAVE_FREQ == 0:
                log("save best model - reward {:.2f}".format(_max_reward))
                torch.save(net, ENV_NAME + "-best.dat")
                max_reward = _max_reward

            # Publish the model to the actors.
            if step_idx % PUBLISH_FREQ == 0:
                publish_model(net, act_sock)

            # Report the wall-clock time between consecutive training steps.
            if p_time is not None:
                elapsed = time.time() - p_time
                fps = 1.0 / elapsed
                log("train elapsed {:.2f} speed {:.2f} f/s".format(
                    elapsed, fps))
            p_time = time.time()
    finally:
        # Reached only when the loop is left through an exception; make sure
        # the summary writer is closed.
        writer.close()
def test_vtrace_from_logit():
    """Test that V-trace from logits matches V-trace from importance weights.

    `vtrace.from_logits` must agree with `vtrace.from_importance_weights`
    when the latter is fed log-rhos derived from the same logits and
    actions via `vtrace.log_probs_from_logits_and_actions`.
    """
    seq_len = 5  # n-step
    num_actions = 3
    batch_size = 2
    clip_rho_threshold = None  # No clipping.
    clip_pg_rho_threshold = None  # No clipping.

    np.random.seed(0)
    values = {
        'behavior_policy_logits':
            _shaped_arange(seq_len, batch_size, num_actions),
        'target_policy_logits':
            _shaped_arange(seq_len, batch_size, num_actions),
        # `high` is exclusive in np.random.randint, so pass `num_actions`
        # (not `num_actions - 1`) to be able to sample every action index.
        'actions':
            np.random.randint(0, num_actions, size=(seq_len, batch_size)),
        'discounts':
            np.array(  # T, B where B_i: [0.9 / (i+1)] * T
                [[0.9 / (b + 1) for b in range(batch_size)]
                 for _ in range(seq_len)]),
        'rewards': _shaped_arange(seq_len, batch_size),
        'values': _shaped_arange(seq_len, batch_size) / batch_size,
        'bootstrap_value': _shaped_arange(batch_size) + 1.0,  # B
    }

    from_logit_output = vtrace.from_logits(
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
        **values)

    ground_truth_target_log_probs = vtrace.log_probs_from_logits_and_actions(
        values['target_policy_logits'], values['actions'])
    ground_truth_behavior_log_probs = vtrace.log_probs_from_logits_and_actions(
        values['behavior_policy_logits'], values['actions'])
    ground_truth_log_rhos = (ground_truth_target_log_probs -
                             ground_truth_behavior_log_probs)

    from_iw = vtrace.from_importance_weights(
        log_rhos=ground_truth_log_rhos,
        discounts=values['discounts'],
        rewards=values['rewards'],
        values=values['values'],
        bootstrap_value=values['bootstrap_value'],
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    # Importance-weight result == logits result == ground truth.
    expected_and_actual = (
        (from_iw.vs, from_logit_output.vs),
        (from_iw.pg_advantages, from_logit_output.pg_advantages),
        (ground_truth_behavior_log_probs,
         from_logit_output.behavior_action_log_probs),
        (ground_truth_target_log_probs,
         from_logit_output.target_action_log_probs),
        (ground_truth_log_rhos, from_logit_output.log_rhos),
    )
    for expected_seq, actual_seq in expected_and_actual:
        for g, o in zip(expected_seq, actual_seq):
            assert np.allclose(g, o.data.tolist())

    # Smoke check: the loss must be computable from the V-trace advantages.
    # The interactive pdb breakpoint that used to sit here was removed; it
    # blocked every non-interactive test run.
    logits = torch.Tensor(values['behavior_policy_logits'])
    actions = torch.LongTensor(values['actions'])
    advantages = from_iw.pg_advantages
    calc_loss(logits, actions, advantages)