Example #1
import itertools

import torch

# DQN and ReplayMemory are assumed to be defined elsewhere in the project.


class Sampler:

    def __init__(self, device, actionsize):
        self.samplenet = DQN(actionsize).to(device)
        self.targetnet = DQN(actionsize).to(device)
        self.opt = torch.optim.Adam(itertools.chain(self.samplenet.parameters()),
                                    lr=0.00001, betas=(0.0, 0.9))
        self.device = device
        self.memory = ReplayMemory(1000, device=device)
        self.BATCH_SIZE = 10
        self.GAMMA = 0.99
        self.count = 0

    def select_action(self, model):
        # The DQN input "state" is the flattened set of conv2 filters from the given model
        self.samplenet.eval()
        action = self.samplenet(model.conv2.weight.data.view(-1, 5, 5).unsqueeze(0))
        return torch.max(action, 1)[1]

    def step(self, state, action, next_state, reward, done):
        self.memory.push(state, action, next_state, reward, done)

        #don't bother if you don't have enough in memory
        if len(self.memory) >= self.BATCH_SIZE:
            self.optimize()

    def optimize(self):
        self.samplenet.train()
        self.targetnet.eval()
        s1, actions, r1, s2, d = self.memory.sample(self.BATCH_SIZE)

        # Current Q-values from the online net and bootstrapped targets for the Bellman update
        qvals = self.samplenet(s1)
        state_action_values = qvals.gather(1, actions[:, 0].unsqueeze(1))
        with torch.no_grad():
            qvals_t = self.targetnet(s2)
            q1_t = qvals_t.max(1)[0].unsqueeze(1)

        # Bellman target: reward plus discounted max target-Q, zeroed for terminal states
        expected_state_action_values = (q1_t * self.GAMMA) * (1 - d) + r1

        # Loss is the squared (L2) error of the Bellman equation
        loss = torch.nn.MSELoss()(state_action_values, expected_state_action_values)

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        # Periodically sync the target network with the online network
        if self.count % 20 == 0:
            self.targetnet.load_state_dict(self.samplenet.state_dict())
        self.count += 1

        return loss.item()
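
The Sampler above (and the Generator in Example #2 below) relies on a ReplayMemory class that is not shown on this page. The sketch below only illustrates the interface the code assumes: push stores transitions as (state, action, next_state, reward, done), while sample returns batched tensors in the order state, action, reward, next_state, done, which is how optimize() unpacks them. It is an assumption for illustration, not the original buffer implementation.

import random
from collections import deque

import torch


class ReplayMemory:
    """Fixed-size transition buffer (sketch; assumes tensor states/actions and scalar rewards/dones)."""

    def __init__(self, capacity, device='cpu'):
        self.device = device
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, next_state, reward, done):
        self.buffer.append((state, action, next_state, reward, done))

    def sample(self, batch_size):
        batch = random.sample(list(self.buffer), batch_size)
        states, actions, next_states, rewards, dones = zip(*batch)
        # Returned in the order the training code unpacks: state, action, reward, next_state, done
        return (torch.stack(states).to(self.device),
                torch.stack(actions).to(self.device),
                torch.tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1),
                torch.stack(next_states).to(self.device),
                torch.tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1))

    def __len__(self):
        return len(self.buffer)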
Example #2
import itertools
import os
import random

import torch
import torch.nn as nn

# Actor, Critic, and ReplayMemory are assumed to be defined elsewhere in the project.


class Generator:
    def __init__(self, device, data):
        self.data = data
        self.actor = Actor().to(device)
        self.critic = Critic().to(device)
        #self.ctarget = Critic().to(device)
        self.actor_opt = torch.optim.Adam(itertools.chain(
            self.actor.parameters()),
                                          lr=0.0001,
                                          betas=(0.0, 0.9))
        self.critic_opt = torch.optim.Adam(itertools.chain(
            self.critic.parameters()),
                                           lr=0.001,
                                           betas=(0.0, 0.9))

        def init_weights(m):
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                torch.nn.init.xavier_uniform_(m.weight.data)

        self.actor.apply(init_weights)
        self.critic.apply(init_weights)
        #self.ctarget.apply(init_weights)

        self.device = device
        self.memory = ReplayMemory(1000, device=device)
        self.batch_size = 5
        self.GAMMA = 0.99
        self.count = 0

    def select_action(self, imgs):
        with torch.no_grad():
            self.actor.eval()
            action = self.actor(imgs)
            return action

    def step(self, state, action, next_state, reward, done):
        self.memory.push(state, action, next_state, reward, done)

        if len(self.memory) >= self.batch_size:
            self.optimize()

    def optimize(self):
        self.actor.train()
        self.critic.train()
        #self.ctarget.eval()

        s1, a, r, s2, d = self.memory.sample(self.batch_size)  # only r and a are used below

        # Train the critic per transition: regress Q(action) toward the observed reward (L1 loss)
        for reward, action in zip(r, a):
            qval = self.critic(action)
            avgQ = qval.mean().unsqueeze(0)
            loss = torch.nn.L1Loss()(avgQ, reward)
            self.critic_opt.zero_grad()
            loss.backward()
            self.critic_opt.step()

        # Train the actor: raise the critic's score on a freshly generated batch
        img, target = self.data[random.randint(0, len(self.data) - 1)]
        batch = self.actor(img)
        score = self.critic(batch)
        actor_loss = -score.mean()
        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        #if self.count % 5 == 0:
        #    self.ctarget.load_state_dict(self.critic.state_dict())
        #self.count += 1

    def save(self):
        torch.save(self.actor.state_dict(), os.path.join('model', 'actor.pth'))
        torch.save(self.critic.state_dict(),
                   os.path.join('model', 'critic.pth'))
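
A minimal sketch of how the Generator above might be driven. The environment object env, its reset()/step() interface, and num_episodes are illustrative assumptions, not part of the original code.

def train_generator(gen, env, num_episodes=100):
    # Hypothetical driver loop: env.reset() is assumed to return an initial batch of images and
    # env.step(action) a (next_state, reward, done) tuple matching Generator.step's arguments.
    for _ in range(num_episodes):
        state, done = env.reset(), False
        while not done:
            action = gen.select_action(state)                  # actor forward pass (no grad)
            next_state, reward, done = env.step(action)        # evaluate the generated batch
            gen.step(state, action, next_state, reward, done)  # store transition and optimize
            state = next_state
    gen.save()  # writes model/actor.pth and model/critic.pth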