Ejemplo n.º 1
0
def load_agent(path):
    """Build a DQN agent over discrete turn/pitch actions, optionally restoring weights.

    The policy and target Q-networks are ``network.QVisualNetwork`` instances
    built with identical arguments.

    Args:
        path: Checkpoint file to restore from. If it does not exist, the agent
            keeps its freshly initialised weights.

    Returns:
        The ``network.DQN`` agent after ``.to(location)``, where ``location`` is
        ``'cuda'`` when available and ``'cpu'`` otherwise.
    """
    # Full action space exposed by the environment (for reference):
    #   move[-1, 1], strafe[-1, 1], pitch[-1, 1], turn[-1, 1], jump 0/1
    # This agent uses a discrete subset (currently disabled: "move -0.5", "jump_forward").
    # Coarse and fine turn/pitch steps, each in both directions. The previous
    # list repeated "turn 0.01" and lacked "turn -0.01"; checkpoints trained
    # with that list treat output index 3 as a second "turn 0.01" and should be
    # retrained.
    action_names = ["turn 0.15", "turn -0.15", "turn 0.01",
                    "turn -0.01", 'pitch 0.1', 'pitch -0.1',
                    'pitch 0.01', 'pitch -0.01']
    actionSet = [network.CategoricalAction(action_names)]

    def make_q_net():
        # Policy and target networks are built with identical arguments.
        return network.QVisualNetwork(actionSet, 2, 20, n_channels=3,
                                      activation=nn.ReLU(), batchnorm=True)

    policy_net = make_q_net()
    target_net = make_q_net()
    batch_size = 20
    # NOTE(review): 0.9 is presumably the discount factor and 450 some update
    # interval — confirm against network.DQN's signature.
    my_simple_agent = network.DQN(policy_net, target_net, 0.9, batch_size, 450, capacity=2000)
    location = 'cuda' if torch.cuda.is_available() else 'cpu'
    if os.path.exists(path):
        logging.info('loading model from %s', path)
        data = torch.load(path, map_location=location)
        # strict=False: presumably (nn.Module semantics) missing/unexpected keys
        # in the checkpoint are ignored rather than raising.
        my_simple_agent.load_state_dict(data, strict=False)

    return my_simple_agent.to(location)
Ejemplo n.º 2
0
def load_agent(path):
    """Create a DQN agent with a four-action discrete space, restoring weights if saved.

    Environment action space, for reference: move[-1, 1], strafe[-1, 1],
    pitch[-1, 1], turn[-1, 1], jump 0/1. A continuous action set would be
    declared as::

        [network.ContiniousAction('move', -1, 1),
         network.ContiniousAction('strafe', -1, 1),
         network.ContiniousAction('pitch', -1, 1),
         network.ContiniousAction('turn', -1, 1),
         network.BinaryAction('jump')]

    Args:
        path: Checkpoint file; when it is absent the agent keeps fresh weights.

    Returns:
        The ``network.DQN`` agent after ``.to(device)``, where ``device`` is
        ``'cuda'`` if available and ``'cpu'`` otherwise.
    """
    # Discrete action set actually used by this agent.
    actions = [network.CategoricalAction(
        ["turn 0.1", "turn -0.1", "move 0.9", "jump_forward"])]

    def build_net():
        # Both Q-networks share the same constructor arguments.
        return QVisualNetworkV2(3, actions, 0, 41, n_channels=3,
                                activation=nn.LeakyReLU(), batchnorm=False,
                                num=256)

    online = build_net()
    frozen = build_net()
    # Noisy input transforms are created after both networks, as before.
    agent = network.DQN(online, frozen, 0.99, 18, 450, capacity=7000,
                        transform=common.make_noisy_transformers())

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if not os.path.exists(path):
        return agent.to(device)

    logging.info('loading model from %s', path)
    agent.load_state_dict(torch.load(path, map_location=device), strict=False)
    return agent.to(device)
Ejemplo n.º 3
0
def load_agent(path='agent_tree.pth'):
    """Build a tree-structured DQN agent over discrete turn/pitch actions.

    Args:
        path: Checkpoint file to restore from. Previously this argument was
            ignored and ``'agent_tree.pth'`` was always used; that file is now
            the default, so ``load_agent()`` keeps the old behaviour. If the
            file does not exist, the agent keeps freshly initialised weights.

    Returns:
        The ``network.DQN`` agent (not moved to any device here).
    """
    # Full action space exposed by the environment (for reference):
    #   move[-1, 1], strafe[-1, 1], pitch[-1, 1], turn[-1, 1], jump 0/1
    # This agent uses a discrete subset (currently disabled: "move -0.5", "jump_forward").
    # Coarse and fine turn/pitch steps in both directions. The previous list
    # repeated "turn 0.01" and lacked "turn -0.01"; checkpoints trained with
    # that list treat output index 3 as a second "turn 0.01".
    action_names = ["turn 0.15", "turn -0.15", "turn 0.01",
                    "turn -0.01", 'pitch 0.1', 'pitch -0.1',
                    'pitch 0.01', 'pitch -0.01']
    actionSet = [network.CategoricalAction(action_names)]

    transformer = common.make_noisy_transformers()

    def make_q_net():
        # Policy and target networks are built with identical arguments.
        return QVisualNetworkTree(1, actionSet, 0, 34, n_channels=3,
                                  activation=nn.LeakyReLU(), batchnorm=False, num=256)

    policy_net = make_q_net()
    target_net = make_q_net()

    batch_size = 20
    my_simple_agent = network.DQN(policy_net, target_net, 0.9,
                                  batch_size, 450, capacity=2000,
                                  transform=transformer)

    location = 'cuda' if torch.cuda.is_available() else 'cpu'
    if os.path.exists(path):
        logging.info('loading model from %s', path)
        data = torch.load(path, map_location=location)
        my_simple_agent.load_state_dict(data, strict=False)

    # NOTE(review): the agent itself is never moved to `location`; if
    # load_state_dict copies into existing parameters, they stay on their
    # current device — confirm which device callers expect.
    return my_simple_agent
Ejemplo n.º 4
0
def load_agent(path, goodpoint_path='goodpoint.pt'):
    """Build a SearchTree DQN agent on top of a pretrained GoodPoint block network.

    Args:
        path: DQN checkpoint. It is loaded with ``strict=False`` if the file
            exists; otherwise the agent keeps fresh weights.
        goodpoint_path: Checkpoint for the GoodPoint network. Its ``'model'``
            entry is passed to ``GoodPoint.load_checkpoint``. This file is
            loaded unconditionally, so it must exist.

    Returns:
        The ``network.DQN`` agent. Only the GoodPoint net is explicitly moved to
        ``location`` (``'cuda'`` if available, else ``'cpu'``); the agent is
        returned without a ``.to()`` call.
    """
    # Full action space exposed by the environment (for reference):
    #   move[-1, 1], strafe[-1, 1], pitch[-1, 1], turn[-1, 1], jump 0/1
    # This agent uses a discrete subset (currently disabled: "move -0.5", "jump_forward").
    # The previous list repeated "turn 0.01" and lacked "turn -0.01"; checkpoints
    # trained with that list treat output index 3 as a second "turn 0.01".
    action_names = [
        "turn 0.20", "turn -0.20", "turn 0.01", "turn -0.01", 'pitch 0.1',
        'pitch -0.1', 'pitch 0.01', 'pitch -0.01'
    ]
    actionSet = [network.CategoricalAction(action_names)]
    # One output per entry in common.visible_blocks plus one extra
    # (presumably a "none"/background class — confirm).
    n_out = len(common.visible_blocks) + 1

    location = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = GoodPoint(8, n_out, n_channels=3, depth=False)
    model_weights = torch.load(goodpoint_path, map_location=location)['model']
    net.load_checkpoint(model_weights)
    net.to(location)

    def make_search_tree():
        # The same GoodPoint instance `net` is handed to both trees.
        return SearchTree(actionSet,
                          2,
                          n_channels=3,
                          activation=nn.LeakyReLU(),
                          block_net=net)

    policy_net = make_search_tree()
    target_net = make_search_tree()

    batch_size = 20
    my_simple_agent = network.DQN(policy_net,
                                  target_net,
                                  0.9,
                                  batch_size,
                                  450,
                                  capacity=2000)

    if os.path.exists(path):
        # Lazy %-formatting: string concatenation raised TypeError for
        # pathlib.Path arguments even though os.path.exists accepts them.
        logging.info('loading model from %s', path)
        data = torch.load(path, map_location=location)
        my_simple_agent.load_state_dict(data, strict=False)

    return my_simple_agent
Ejemplo n.º 5
0
def load_agent(path):
    """Build a DQN agent whose Q-networks take grid and position encodings.

    Environment action space, for reference: move[-1, 1], strafe[-1, 1],
    pitch[-1, 1], turn[-1, 1], jump 0/1. A continuous action set would be
    declared as::

        [network.ContiniousAction('move', -1, 1),
         network.ContiniousAction('strafe', -1, 1),
         network.ContiniousAction('pitch', -1, 1),
         network.ContiniousAction('turn', -1, 1),
         network.BinaryAction('jump')]

    Args:
        path: Checkpoint file to restore from. If it does not exist, the agent
            keeps freshly initialised weights.

    Returns:
        The ``network.DQN`` agent (not moved to any device here).
    """
    import logging  # local import keeps this loader self-contained

    # Discrete action set actually used by this agent.
    action_names = ["turn 0.15", "turn -0.15", "move 0.5", "jump_forward"]
    actionSet = [network.CategoricalAction(action_names)]

    def make_q_net():
        # Policy and target networks are built with identical arguments.
        return network.QNetwork(actionSet,
                                grid_len=27,
                                grid_w=5,
                                target_enc_len=3,
                                pos_enc_len=5)

    policy_net = make_q_net()
    target_net = make_q_net()

    batch_size = 70
    my_simple_agent = network.DQN(policy_net,
                                  target_net,
                                  0.9,
                                  batch_size,
                                  450,
                                  capacity=2000)
    if os.path.exists(path):
        # Without map_location, a checkpoint saved from a GPU fails to load on
        # a CPU-only host.
        location = 'cuda' if torch.cuda.is_available() else 'cpu'
        logging.info('loading model from %s', path)
        data = torch.load(path, map_location=location)
        my_simple_agent.load_state_dict(data, strict=False)

    return my_simple_agent