Example #1
import os
import torch
import torch.optim as optim


def load_checkpoint(file_dir, i_epoch, layer_sizes, input_size, device='cuda'):
    """Rebuild the policy net, both value nets, their optimizers, and the
    SimHash state from the checkpoint file ckpt_eps<i_epoch>.pt in file_dir."""
    checkpoint = torch.load(os.path.join(file_dir, "ckpt_eps%d.pt" % i_epoch),
                            map_location=device)

    policy_net = PolicyNet(layer_sizes).to(device)
    policy_net.load_state_dict(checkpoint["policy_net"])
    policy_net.train()

    value_net_in = ValueNet(input_size).to(device)
    value_net_in.load_state_dict(checkpoint["value_net_in"])
    value_net_in.train()

    value_net_ex = ValueNet(input_size).to(device)
    value_net_ex.load_state_dict(checkpoint["value_net_ex"])
    value_net_ex.train()

    valuenet_in_optim = optim.Adam(value_net_in.parameters())
    valuenet_in_optim.load_state_dict(checkpoint["valuenet_in_optim"])

    valuenet_ex_optim = optim.Adam(value_net_ex.parameters())
    valuenet_ex_optim.load_state_dict(checkpoint["valuenet_ex_optim"])

    # lpl_graph = checkpoint["lpl_graph"]
    simhash = checkpoint["simhash"]

    checkpoint.pop("policy_net")
    checkpoint.pop("value_net_in")
    checkpoint.pop("value_net_ex")
    checkpoint.pop("valuenet_in_optim")
    checkpoint.pop("valuenet_ex_optim")
    checkpoint.pop("i_epoch")

    return policy_net, value_net_in, value_net_ex, valuenet_in_optim, valuenet_ex_optim,\
            simhash, checkpoint
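
A minimal usage sketch for the loader above, assuming PolicyNet, ValueNet, and the stored SimHash object are defined elsewhere in the project; the dimensions, directory, and epoch index below are placeholders, not values from the original code:

obs_dim, act_dim = 8, 2                    # placeholder dimensions
layer_sizes = [obs_dim, 64, 64, act_dim]   # assumed PolicyNet architecture
input_size = obs_dim + 1                   # assumed ValueNet input size

(policy_net, value_net_in, value_net_ex,
 valuenet_in_optim, valuenet_ex_optim,
 simhash, extra_state) = load_checkpoint("checkpoints", i_epoch=100,
                                         layer_sizes=layer_sizes,
                                         input_size=input_size,
                                         device="cpu")
# extra_state holds whatever remains in the checkpoint dict after the known
# keys are popped (e.g. logged statistics), and training can resume from here.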
Example #2
import os
import torch
import torch.optim as optim


def load_checkpoint(file_dir, i_epoch, layer_sizes, input_size, device='cuda'):
    """Rebuild the policy net, value net, and value-net optimizer (plus the
    saved learning rates) from the checkpoint file ckpt_eps<i_epoch>.pt."""
    checkpoint = torch.load(os.path.join(file_dir, "ckpt_eps%d.pt" % i_epoch),
                            map_location=device)

    policy_net = PolicyNet(layer_sizes).to(device)
    value_net = ValueNet(input_size).to(device)
    policy_net.load_state_dict(checkpoint["policy_net"])
    policy_net.train()
    value_net.load_state_dict(checkpoint["value_net"])
    value_net.train()

    policy_lr = checkpoint["policy_lr"]
    valuenet_lr = checkpoint["valuenet_lr"]

    valuenet_optim = optim.Adam(value_net.parameters(), lr=valuenet_lr)
    valuenet_optim.load_state_dict(checkpoint["valuenet_optim"])

    checkpoint.pop("policy_net")
    checkpoint.pop("value_net")
    checkpoint.pop("valuenet_optim")
    checkpoint.pop("i_epoch")
    checkpoint.pop("policy_lr")
    checkpoint.pop("valuenet_lr")

    return policy_net, value_net, valuenet_optim, checkpoint
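
For reference, a hedged sketch of the save-side counterpart implied by the keys this loader reads and pops; the function name save_checkpoint and the use of **extra_state are assumptions, and the original checkpoints may contain additional entries:

import os

import torch


def save_checkpoint(file_dir, i_epoch, policy_net, value_net, valuenet_optim,
                    policy_lr, valuenet_lr, **extra_state):
    # Mirror of the loader: state dicts plus scalars, with any extra entries
    # stored alongside so they come back in the returned `checkpoint` dict.
    checkpoint = {
        "policy_net": policy_net.state_dict(),
        "value_net": value_net.state_dict(),
        "valuenet_optim": valuenet_optim.state_dict(),
        "i_epoch": i_epoch,
        "policy_lr": policy_lr,
        "valuenet_lr": valuenet_lr,
    }
    checkpoint.update(extra_state)
    torch.save(checkpoint, os.path.join(file_dir, "ckpt_eps%d.pt" % i_epoch))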
Example #3
            loss += -torch.sum(torch.min(surr1, surr2))
            num += ratio.shape[0]

        loss /= torch.tensor(num, device=device, dtype=torch.float32)

        policy_candidate_optimizer.zero_grad()
        loss.backward(retain_graph=True)
        nn.utils.clip_grad_norm_(policy_candidate.parameters(),
                                 1.)  # Clip gradient norm in place
        policy_candidate_optimizer.step()

    policy_net = copy.deepcopy(policy_candidate).to(device)

    # Optimize value net for a given number of steps
    # Set value net in training mode
    value_net_in.train()
    value_net_ex.train()
    ex_rtg = memory.extrinsic_discounted_rtg(
        batch_size)  # Use discounted extrinsic reward-to-go to fit the value nets
    in_rtg = memory.intrinsic_rtg(batch_size)
    ex_val_est = []
    in_val_est = []

    print("\n\n\tUpdate Value Net for %d steps" % (num_vn_iter))

    for i in tqdm(range(num_vn_iter)):  # Use tqdm to show progress bar
        for j in range(batch_size):
            in_val_traj = value_net_in(
                torch.cat([
                    states[j],
                    torch.ones((states[j].shape[0], 1),
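
The truncated loop above accumulates a PPO-style clipped surrogate loss from surr1 and surr2. Below is a minimal sketch of how those two terms are typically formed upstream, under the assumption that the snippet follows the standard clipped objective; the helper name and the clip range eps are not from the original code:

import torch


def clipped_surrogates(new_log_prob, old_log_prob, advantage, eps=0.2):
    # Probability ratio between the candidate policy and the data-collecting
    # policy, computed in log space for numerical stability.
    ratio = torch.exp(new_log_prob - old_log_prob)
    surr1 = ratio * advantage
    surr2 = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantage
    # The caller then minimizes -sum(min(surr1, surr2)), as in the loop above.
    return surr1, surr2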
Example #4
    # policy_net.load_state_dict(policy_candidate.state_dict())
    policy_net = copy.deepcopy(policy_candidate).to(device)

    # # Vanilla Policy Gradient
    # for gae, act_log_prob in zip(ex_gae, old_act_log_prob):
    #     loss += - torch.sum(gae * act_log_prob)
    # loss /= torch.tensor(batch_size, device=device, dtype=torch.float32)
    #
    # policynet_optimizer.zero_grad()
    # loss.backward()
    # policynet_optimizer.step()

    # Optimize value net for a given number of steps
    # Set value net in training mode
    value_net.train()
    ex_rtg = memory.extrinsic_discounted_rtg(
        batch_size)  # Use discounted reward-to-go to fit the value net
    val_est = []

    print("\n\n\tUpdate Value Net for %d steps" % (num_vn_iter))

    for i in tqdm(range(num_vn_iter)):  # Use tqdm to show progress bar
        for j in range(batch_size):
            val_est_traj = value_net(states[j]).squeeze()
            val_est.append(val_est_traj)
        value_net_mse = value_net.optimize_model(val_est, ex_rtg,
                                                 valuenet_optimizer)

    # Reset Flags
    if not (render_each_episode):
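
The optimize_model call used above is not shown in these excerpts. As a rough sketch of the fitting step it likely performs, written here as a free function (the name optimize_value_net and the MSE objective are assumptions), the returned scalar would correspond to value_net_mse:

import torch
import torch.nn.functional as F


def optimize_value_net(val_est, rtg, optimizer):
    # One regression step: fit the value estimates to the rewards-to-go.
    pred = torch.cat([v.reshape(-1) for v in val_est])
    target = torch.cat([r.reshape(-1) for r in rtg]).detach()
    loss = F.mse_loss(pred, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()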