Exemplo n.º 1
0
def load_checkpoint(file_dir, i_epoch, layer_sizes, input_size, device='cuda'):
    """Restore the policy net, both value nets and their optimizers from disk.

    Loads ``ckpt_eps<i_epoch>.pt`` from ``file_dir`` with ``torch.load``
    (storages mapped onto ``device``). It then:

    * rebuilds ``PolicyNet`` and the two ``ValueNet`` instances, moves each to
      ``device``, loads its saved weights and switches it to training mode;
    * recreates the two value-net Adam optimizers and restores their state.

    Args:
        file_dir: Directory that holds the checkpoint file.
        i_epoch: Epoch number embedded in the checkpoint file name.
        layer_sizes: Constructor argument for ``PolicyNet``.
        input_size: Constructor argument for both ``ValueNet`` instances.
        device: Target device for the networks and for ``map_location``.

    Returns:
        ``(policy_net, value_net_in, value_net_ex, valuenet_in_optim,
        valuenet_ex_optim, simhash, checkpoint)``.

        ``checkpoint`` is the loaded dict with the five state-dict entries and
        ``"i_epoch"`` removed. ``"simhash"`` is *not* removed, so it is also
        still present in that dict.

    Raises:
        KeyError: if an expected entry is missing from the checkpoint.

    NOTE(review): ``value_net_in`` is rebuilt as ``ValueNet(input_size)``. If
    the saved network was built with a different input width, the strict
    ``load_state_dict`` will presumably fail on the shape mismatch. Confirm
    that callers pass the matching size.
    """
    ckpt_path = os.path.join(file_dir, "ckpt_eps%d.pt" % i_epoch)
    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # trusted files. Newer PyTorch releases default to weights_only=True,
    # which may reject non-tensor entries such as "simhash". Confirm against
    # the installed version.
    checkpoint = torch.load(ckpt_path, map_location=device)

    def _restore(net, key):
        # Load the saved weights into an already-placed module; enable training mode.
        net.load_state_dict(checkpoint[key])
        net.train()
        return net

    def _adam_for(net, key):
        # Fresh Adam over ``net``'s parameters, then restore its saved state.
        opt = optim.Adam(net.parameters())
        opt.load_state_dict(checkpoint[key])
        return opt

    policy_net = _restore(PolicyNet(layer_sizes).to(device), "policy_net")
    value_net_in = _restore(ValueNet(input_size).to(device), "value_net_in")
    value_net_ex = _restore(ValueNet(input_size).to(device), "value_net_ex")

    valuenet_in_optim = _adam_for(value_net_in, "valuenet_in_optim")
    valuenet_ex_optim = _adam_for(value_net_ex, "valuenet_ex_optim")

    simhash = checkpoint["simhash"]

    # Strip the entries consumed above; whatever else was saved goes back to
    # the caller in the dict itself.
    for consumed in ("policy_net", "value_net_in", "value_net_ex",
                     "valuenet_in_optim", "valuenet_ex_optim", "i_epoch"):
        checkpoint.pop(consumed)

    return (policy_net, value_net_in, value_net_ex,
            valuenet_in_optim, valuenet_ex_optim, simhash, checkpoint)
Exemplo n.º 2
0
def load_checkpoint(file_dir, i_epoch, layer_sizes, input_size, device='cuda'):
    """Restore the policy net, the value net and its optimizer from disk.

    Loads ``ckpt_eps<i_epoch>.pt`` from ``file_dir`` with ``torch.load``
    (storages mapped onto ``device``). It then:

    * rebuilds ``PolicyNet`` and ``ValueNet``, loads their saved weights and
      puts both in training mode;
    * recreates the value-net Adam optimizer and restores its state.

    Args:
        file_dir: Directory that holds the checkpoint file.
        i_epoch: Epoch number embedded in the checkpoint file name.
        layer_sizes: Constructor argument for ``PolicyNet``.
        input_size: Constructor argument for ``ValueNet``.
        device: Target device for the networks and for ``map_location``.

    Returns:
        ``(policy_net, value_net, valuenet_optim, checkpoint)``.

        ``checkpoint`` is the loaded dict with the state-dict entries,
        ``"i_epoch"``, ``"policy_lr"`` and ``"valuenet_lr"`` removed.

    Raises:
        KeyError: if an expected entry is missing from the checkpoint.
    """
    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # trusted files.
    checkpoint = torch.load(os.path.join(file_dir, "ckpt_eps%d.pt" % i_epoch),
                            map_location=device)

    policy_net = PolicyNet(layer_sizes).to(device)
    value_net = ValueNet(input_size).to(device)
    policy_net.load_state_dict(checkpoint["policy_net"])
    policy_net.train()
    value_net.load_state_dict(checkpoint["value_net"])
    value_net.train()

    # The optimizer state dict stores hyperparameters (including lr) in its
    # param_groups. load_state_dict below therefore overrides this initial lr;
    # passing it only keeps the constructor call explicit.
    valuenet_optim = optim.Adam(value_net.parameters(),
                                lr=checkpoint["valuenet_lr"])
    valuenet_optim.load_state_dict(checkpoint["valuenet_optim"])

    # Strip the entries consumed above. The previous version also read
    # "policy_lr" into a local that was never used.
    # NOTE(review): "policy_lr" is dropped here without being returned, so the
    # caller loses the saved policy learning rate. Confirm this is intended.
    for consumed in ("policy_net", "value_net", "valuenet_optim",
                     "i_epoch", "policy_lr", "valuenet_lr"):
        checkpoint.pop(consumed)

    return policy_net, value_net, valuenet_optim, checkpoint
Exemplo n.º 3
0
# Training-script setup: environment, device, one policy network, two value
# networks with their Adam optimizers, and the memory.
# NOTE(review): env_name, is_unwrapped, layer_sizes, input_size, output_size,
# capacity, GAMMA and LAMBDA are not defined in this section. They are
# presumably configured earlier in the script.

# Turn on pyplot's interactive mode.
# VERY IMPORTANT: otherwise the training-stats plot will halt training.
plt.ion()

# Create the OpenAI Gym environment; when is_unwrapped is set, use
# env.unwrapped instead of the wrapped environment.
env = gym.make(env_name)
if is_unwrapped:
    env = env.unwrapped

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Current usable device is: ", device)

# Create the models: one policy network and two value networks.
policy_net = PolicyNet(layer_sizes).to(device)  # Policy network
value_net_ex = ValueNet(input_size).to(
    device)  # Value network for the extrinsic reward
# Presumably the intrinsic-reward value network (named by analogy with
# value_net_ex). Its input is widened by 1 + output_size units. Per the
# original comment, the extra 1 unit indicates the trajectory number; the
# purpose of the extra output_size units is not stated here (TODO confirm).
value_net_in = ValueNet(input_size + 1 + output_size).to(
    device)  # +1 input unit indicates the trajectory number (see note above)

# Optimizers for the two value networks only, with default Adam hyperparameters.
valuenet_in_optimizer = optim.Adam(value_net_in.parameters())
valuenet_ex_optimizer = optim.Adam(value_net_ex.parameters())

# Set up the memory, passing the selected device.
memory = Memory(capacity, GAMMA, LAMBDA, device=device)


# Define observation normalization function. Normalize state vector values to range [-1., 1.]
def state_nomalize(s):
Exemplo n.º 4
0
# Training-script setup: environment, device, policy and value networks,
# memory, and one Adam optimizer per network.
# NOTE(review): env_name, is_unwrapped, layer_sizes, input_size, capacity,
# GAMMA and LAMBDA are not defined in this section. They are presumably
# configured earlier in the script.

# Turn on pyplot's interactive mode.
# VERY IMPORTANT: otherwise the training-stats plot will halt training.
plt.ion()

# Create the OpenAI Gym environment; when is_unwrapped is set, use
# env.unwrapped instead of the wrapped environment.
env = gym.make(env_name)
if is_unwrapped:
    env = env.unwrapped

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Current usable device is: ", device)

# Create the models and move them to the selected device.
policy_net = PolicyNet(layer_sizes).to(device)  # Policy network
value_net = ValueNet(input_size).to(device)  # Value network

# Set up the memory. The device is passed positionally here;
# presumably Memory's fourth parameter is the device (verify the signature).
memory = Memory(capacity, GAMMA, LAMBDA, device)

# One Adam optimizer per network, with default hyperparameters.
policynet_optimizer = optim.Adam(policy_net.parameters())
valuenet_optimizer = optim.Adam(value_net.parameters())

###################################################################
# Start training

# Dictionary for extra training information to save to checkpoints
training_info = {
    "epoch mean durations": [],