Example 1
def train_v1(eps_start, eps_end, eps_decay, n_step, mem_capacity, num_episodes,
             embed_dim, iters, batch_size, gamma):
    graph_generator = GraphGenerator(16, 16)
    memory = ReplayBuffer(mem_capacity)
    steps_done = 0
    gnn = Struc2Vec(embed_dim, iters)
    qnet = QNet(embed_dim)
    optimizer = optim.Adam(list(gnn.parameters()) + list(qnet.parameters()),
                           lr=0.0001,
                           weight_decay=1e-4)
    for e in range(num_episodes):
        node_labels, adj, edge_weights = graph_generator.next()
        vtx_feats = gnn(node_labels, adj, edge_weights)
        remaining_vertices = set(range(len(adj)))
        state = Variable(torch.zeros(embed_dim))
        curr_tour = []
        T = len(adj)
        rewards = []
        states = [state]

        for t in range(T):
            eps_threshold = util.get_eps_threshold(eps_start, eps_end,
                                                   eps_decay, steps_done)
            if random.random() > eps_threshold:
                # arg max action
                curr_vtx = arg_max_action(qnet, vtx_feats,
                                          remaining_vertices)
            else:
                # random action
                curr_vtx = random.choice(tuple(remaining_vertices))

            action = vtx_feats[curr_vtx]
            # reward maintenance
            est_reward = qnet(state, action)  # pass the embedding, matching the batched qnet call below
            reward = get_reward(curr_tour, curr_vtx, edge_weights)
            rewards.append(reward)

            # update states
            curr_tour.append(curr_vtx)
            remaining_vertices.remove(curr_vtx)
            states.append(state + action)
            # wait till after doing the memory stuff to add the state

            # we only do these updates after n steps
            if t >= n_step:
                # greedy vertex for the current state; its embedding becomes
                # the max_action used to bootstrap the n-step Q target
                max_vtx = arg_max_action(qnet, vtx_feats, remaining_vertices)
                max_action = vtx_feats[max_vtx]
                state_tminusn = states[-n_step]  # this is a torch tensor
                action_tminusn = vtx_feats[
                    curr_tour[-n_step]]  # embedding of the vertex chosen n_step ago
                reward_tminusn = sum(rewards[-n_step:])
                memory.push(state_tminusn, action_tminusn, reward_tminusn,
                            state, max_action)

                # only train once the replay buffer can fill a batch
                if len(memory) >= batch_size:
                    transitions = memory.sample(batch_size)
                    # batch.state, batch.action, batch.reward, etc are now tuples
                    # TODO: this looks a bit gross....
                    batch = Transition(*zip(*transitions))
                    state_batch = torch.cat(
                        [s.unsqueeze(0) for s in batch.state], dim=0)
                    action_batch = torch.cat(
                        [a.unsqueeze(0) for a in batch.action], dim=0)
                    # rewards may be plain floats, so build a tensor directly
                    reward_batch = torch.tensor([float(r) for r in batch.reward])
                    newstate_batch = torch.cat(
                        [ns.unsqueeze(0) for ns in batch.new_state], dim=0)
                    max_action_batch = torch.cat(
                        [ma.unsqueeze(0) for ma in batch.max_action], dim=0)

                    # TODO: make qnet allow batch
                    # does the experience replay memory contain state/action/reward/next_state
                    # from only the current episode's graph? Or can any graph seen before be
                    # in the memory?
                    # The argmax action is the thing taken at time t-n_step right?
                    oldstate_action_value = qnet(state_batch, action_batch)
                    newstate_action_value = qnet(newstate_batch, max_action_batch)
                    expected_sa_values = reward_batch + gamma * newstate_action_value
                    loss = F.mse_loss(oldstate_action_value, expected_sa_values)

                    optimizer.zero_grad()
                    loss.backward()
                    # clamp grads?
                    optimizer.step()

            state = state + action  # out-of-place add so the tensors stored in states are not mutated
            steps_done += 1
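
Example 1 leans on several pieces that are not shown in the snippet: the Transition tuple, the ReplayBuffer, and util.get_eps_threshold. The sketch below shows one plausible set of definitions, using the field names accessed above (batch.state, batch.new_state, batch.max_action) and the usual exponential epsilon-decay schedule; the project's actual implementations may differ.

import math
import random
from collections import deque, namedtuple

# Field names match the attribute accesses in train_v1 above.
Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'new_state', 'max_action'))


class ReplayBuffer:
    """Fixed-capacity FIFO memory of Transition tuples."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, *args):
        self.buffer.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)


def get_eps_threshold(eps_start, eps_end, eps_decay, steps_done):
    # anneal epsilon exponentially from eps_start toward eps_end
    return eps_end + (eps_start - eps_end) * math.exp(-steps_done / eps_decay)
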
Example 2
def train(graph_distr, epochs, batch_size, eps, n_step, discount, capacity,
          gcn_params, opt_params):
    '''
    graph_distr: object that wraps the graph generating distr
    epochs: int
    batch_size: int
    eps: float for exploration probability
    n_step: int, num steps for n-step Q-learning
    discount: float, how much to discount future state/action value
    capacity: int, number of episodes to keep in memory
    gcn_params: dictionary of graph conv net parameters
    opt_params: dictionary of params for optimizer
    '''
    qnet = QNetwork(gcn_params)
    memory = ReplayBuffer(capacity)
    opt_params['params'] = qnet.parameters()
    optimizer = get_optimizer(opt_params)

    for e in range(epochs):
        node_labels, edge_weights, adj = graph_distr.next()
        embedding = qnet.embed_graph(node_labels, edge_weights, adj)

        state = []  # s_0
        state_vec = Variable(torch.zeros((1, qnet.embed_dim)))
        state_vec_prev = None
        actions = []
        rewards = []
        s_complement = set(range(len(adj)))
        losses = []
        best_actions = []

        for t in range(len(adj)):
            if t > 0:
                v_best_t = qnet.best_action(state, list(s_complement),
                                            embedding)
            if random.random() < eps or t == 0:
                v_t = random.choice(tuple(s_complement))
            else:
                v_t = v_best_t

            action_vec = embedding[v_t].unsqueeze(0)
            vprev = None if t == 0 else state[-1]
            r_t = 0 if t == 0 else -edge_weights.data[vprev, v_t]
            s_complement.remove(v_t)

            # ideally store: s_0 , a_0, r_0, s_1, v_best_1
            # ideally store: s_1 , a_1, r_1, s_2, v_best_2
            if t >= n_step:
                new_state = state[:]
                # the action prev is what action got taken.
                # v_best_t must be the argmax action of the current state
                v_best_embedding = embedding[v_best_t].unsqueeze(0)
                episode = (state_vec_prev, action_vec_prev, rewards[-1],
                           state_vec, v_best_embedding)
                # should try to add v_best_t so we dont recompute later

                memory.push(*episode)
                if len(memory) > batch_size:
                    batch = memory.sample(batch_size)
                    batch_loss = qnet.backprop_batch(batch, optimizer)
                    losses.append(batch_loss)

            state_vec_prev = state_vec
            action_vec_prev = action_vec
            state.append(v_t)
            state_vec = state_vec + action_vec
            rewards.append(r_t)

        # losses may be empty early on, before the replay buffer can fill a batch
        epoch_loss = torch.mean(torch.cat(losses)) if losses else float('nan')
        print('Epoch {} | avg loss: {:.3f} | Exploration rate: {:.3f}'.format(
            e, float(epoch_loss), eps))
        eps = update_exploration(eps)
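
Example 2 delegates the actual Q-learning update to qnet.backprop_batch, which is not shown. Below is a rough sketch of what that update could look like, mirroring the loss that Example 1 writes out inline. It is written as a free function over a generic qnet callable rather than the repo's real QNetwork method, and the batch layout, the Q-value shapes, and the way the discount factor is threaded through (note that train() above never uses its discount argument) are all assumptions.

import torch
import torch.nn.functional as F


def backprop_batch(qnet, batch, optimizer, discount):
    """One gradient step on a batch of (s, a, r, s', a_best) tuples (sketch).

    Assumes each element of `batch` is the 5-tuple pushed by train() above
    and that qnet(state, action) returns Q(s, a) as a (B, 1) tensor.
    """
    states, actions, rewards, new_states, best_actions = zip(*batch)
    state_b = torch.cat(states, dim=0)            # (B, embed_dim)
    action_b = torch.cat(actions, dim=0)
    new_state_b = torch.cat(new_states, dim=0)
    best_action_b = torch.cat(best_actions, dim=0)
    # float() handles both python floats and 1-element reward tensors
    reward_b = torch.tensor([[float(r)] for r in rewards])  # (B, 1)

    q_sa = qnet(state_b, action_b)                      # Q(s, a)
    q_next = qnet(new_state_b, best_action_b).detach()  # bootstrap target, no grad
    target = reward_b + discount * q_next

    loss = F.mse_loss(q_sa, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # return a 1-element tensor so torch.cat(losses) in train() still works
    return loss.detach().view(1)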