Example #1
File: main.py  Project: thudzj/ar
    # define optimizer
    optimizer_alpha = torch.optim.SGD([alpha], lr=.3)

    # initialize the statistics of gradients
    grad_stats = GradStats(w, beta2_rmsprop=beta2_rmsprop)

    o_loss = torch.zeros_like(alpha)
    o_grad = torch.zeros_like(alpha)

    for outer_ite in range(outer_T):

        grad_stats.detach()
        ws = [w.detach_().requires_grad_()]
        for inner_ite in range(inner_T):
            i_loss = task.inner_loss(ws[-1],
                                     alpha) + weight_decay_w / 2. * (ws[-1])**2
            i_grad = torch.autograd.grad(outputs=i_loss,
                                         inputs=ws[-1],
                                         grad_outputs=torch.ones_like(i_loss),
                                         create_graph=True,
                                         retain_graph=True,
                                         only_inputs=True)[0]
            grad_stats.update(i_grad)

            # choose optimizer and lr with a policy net
            # policy_o = policy_net(torch.stack([o_loss, o_grad, i_loss, i_grad], 1))
            # opt, lr = policy_o.sample()
            # logp = logp(opt | policy_o) + logp(lr | policy_o)
            # hard-coded action for now: update rule OPTS[5] with learning rate LRS[0]
            opt, lr = 5, 0
            ws.append(OPTS[opt](LRS[lr], ws[-1], i_grad, grad_stats))
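The loop above leans on helpers that live elsewhere in the repository: GradStats tracks running gradient statistics, while OPTS and LRS are the lists of candidate update rules and learning rates the policy chooses from. The sketch below only illustrates the interfaces their call sites imply (constructor, update, detach, and OPTS[opt](lr, w, grad, stats) returning the next iterate); the actual definitions in thudzj/ar may differ.

import torch

# Hypothetical sketch of the assumed helpers; names and behaviour are inferred
# from how they are called above, not taken from the project's real code.

class GradStats:
    """RMSProp-style running estimate of the squared inner gradient."""
    def __init__(self, w, beta2_rmsprop=0.9):
        self.beta2 = beta2_rmsprop
        self.v = torch.zeros_like(w)

    def update(self, grad):
        # exponential moving average of grad**2
        self.v = self.beta2 * self.v + (1. - self.beta2) * grad ** 2

    def detach(self):
        # cut the graph so the statistics do not backprop across outer iterations
        self.v = self.v.detach()

# candidate learning rates (illustrative values)
LRS = [0.1, 0.03, 0.01]

# candidate update rules; each maps (lr, w, grad, stats) to the next iterate.
# The real project defines more entries (index 5 is used above).
OPTS = [
    lambda lr, w, g, s: w - lr * g,                        # plain gradient step
    lambda lr, w, g, s: w - lr * g / (s.v.sqrt() + 1e-8),  # RMSProp-style step
]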
Example #2
File: simple_env.py  Project: thudzj/ar
class SimpleEnv:
    def __init__(self,
                 p=None,
                 q=None,
                 num_tasks_per_batch=None,
                 testing=False,
                 seed=1000):
        # hyper-parameters
        if seed:
            np.random.seed(seed)
            # torch.cuda.set_device(args.gpu)
            # cudnn.benchmark = True
            torch.manual_seed(seed)
            # cudnn.enabled=True
            # torch.cuda.manual_seed(args.seed)

        self.weight_decay_w = 0
        self.beta2_rmsprop = 0.9
        self.outer_T = 40
        self.inner_T = 6
        self.testing = testing
        # NOTE: the num_tasks_per_batch argument is currently ignored; batch size is fixed to 1
        self.num_tasks_per_batch = 1
        self.n_opts = len(OPTS)

        # initialize
        # self.reset(p, q)

    def step(self, action):
        # action: opt, lr
        opt = action[0]
        lr = action[1]
        episode_over = False
        # action index n_opts means: stop the inner loop and take an outer step on alpha
        if opt == self.n_opts:
            self.outer_count = self.outer_count + 1
            self.w = self.ws[-1]

            self.optimizer_alpha.zero_grad()
            self.task.outer_loss(self.w, self.alpha).sum().backward()
            self.optimizer_alpha.step()
            self.lr_scheduler.step()

            self.points_x.append(self.alpha[0].item())
            self.points_y.append(self.w[0].item())
            self.grad_stats.detach()
            self.ws = [self.w.detach_().requires_grad_()]

            self.o_loss = self.task.outer_loss(self.ws[-1], self.alpha)
            self.o_grad = torch.autograd.grad(outputs=self.o_loss,
                                              inputs=self.alpha,
                                              grad_outputs=torch.ones_like(
                                                  self.o_loss))[0]
            if self.outer_count == 1:
                self.o_loss_mv = self.o_loss.data.clone()
                self.o_grad_norm_mv = self.o_grad.data.norm()
            else:
                self.o_loss_mv = self.o_loss_mv * 0.9 + self.o_loss.data * 0.1
                self.o_grad_norm_mv = (self.o_grad_norm_mv * 0.9 +
                                       self.o_grad.data.norm() * 0.1)

            self.i_grad_change = -self.i_grad.data
            self.i_loss = self.task.inner_loss(
                self.ws[-1],
                self.alpha) + self.weight_decay_w / 2. * (self.ws[-1])**2
            self.i_grad = torch.autograd.grad(outputs=self.i_loss,
                                              inputs=self.ws[-1],
                                              grad_outputs=torch.ones_like(
                                                  self.i_loss),
                                              create_graph=True,
                                              retain_graph=True,
                                              only_inputs=True)[0]
            self.i_grad_change += self.i_grad.data

            ob = torch.stack([
                self.i_loss, self.i_loss - self.i_loss_mv,
                self.i_grad.norm().view(1),
                (self.i_grad.norm().add(1e-16).log() -
                 self.i_grad_norm_mv.add(1e-16).log()).view(1),
                self.i_grad_change.norm().view(1), self.o_loss,
                self.o_loss - self.o_loss_mv,
                torch.tensor([0.]),
                torch.tensor([float(self.outer_count) / self.outer_T])
            ], 1).detach()

            if self.inner_count > 0 and self.i_loss.item() <= 100:
                tmp = self.task.get_reward(self.alpha, self.ws[-1]).item()
                reward = tmp  # - self.last_reward
                self.last_reward = tmp
                self.inner_count = 0
            else:
                reward = -1
                episode_over = True
            if self.outer_count == self.outer_T:
                episode_over = True
            self.points.append(reward)
        else:
            # otherwise: take one inner step on w with the chosen update rule and learning rate
            self.grad_stats.update(self.i_grad)
            self.ws.append(OPTS[opt](LRS[lr], self.ws[-1], self.i_grad,
                                     self.grad_stats))
            self.i_loss_mv = self.i_loss_mv * 0.9 + self.i_loss.data * 0.1
            self.i_grad_norm_mv = (self.i_grad_norm_mv * 0.9 +
                                   self.i_grad.data.norm() * 0.1)
            self.inner_count += 1

            self.i_grad_change = -self.i_grad.data
            self.i_loss = self.task.inner_loss(
                self.ws[-1],
                self.alpha) + self.weight_decay_w / 2. * (self.ws[-1])**2
            self.i_grad = torch.autograd.grad(outputs=self.i_loss,
                                              inputs=self.ws[-1],
                                              grad_outputs=torch.ones_like(
                                                  self.i_loss),
                                              create_graph=True,
                                              retain_graph=True,
                                              only_inputs=True)[0]
            self.i_grad_change += self.i_grad.data

            ob = torch.stack([
                self.i_loss, self.i_loss - self.i_loss_mv,
                self.i_grad.norm().view(1),
                (self.i_grad.norm().add(1e-16).log() -
                 self.i_grad_norm_mv.add(1e-16).log()).view(1),
                self.i_grad_change.norm().view(1), self.o_loss,
                self.o_loss - self.o_loss_mv,
                torch.tensor([float(self.inner_count) / self.inner_T]),
                torch.tensor([float(self.outer_count) / self.outer_T])
            ], 1).detach()

            if self.inner_count < self.inner_T and self.i_loss.item() <= 100:
                reward = 0
            else:
                reward = -1
                episode_over = True

        return ob, reward, episode_over, {}

    def reset(self, p=None, q=None):
        # initialize a task
        if p and q:
            self.task = Task(p=np.array([p]),
                             q=np.array([q]),
                             num_tasks_per_batch=1)
        else:
            self.task = random_task(self.num_tasks_per_batch)

        # define variables
        self.w = torch.tensor(np.random.uniform(
            -2, 0, self.num_tasks_per_batch).astype(np.float32),
                              requires_grad=True)
        self.alpha = torch.tensor(np.random.uniform(
            0, 4, self.num_tasks_per_batch).astype(np.float32),
                                  requires_grad=True)
        self.ws = [self.w]

        self.start = (self.alpha.data.clone(), self.w.data.clone())

        # initialize the statistics of gradients for w
        self.grad_stats = GradStats(self.w, beta2_rmsprop=self.beta2_rmsprop)

        # define optimizer for alpha
        self.optimizer_alpha = torch.optim.SGD([self.alpha], lr=.1)
        self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer_alpha, 0.95)

        self.o_loss = torch.zeros_like(self.alpha)
        self.o_grad = torch.zeros_like(self.alpha)

        self.points = []
        self.points_x = [self.alpha[0].item()]
        self.points_y = [self.w[0].item()]

        self.i_loss = self.task.inner_loss(
            self.ws[-1],
            self.alpha) + self.weight_decay_w / 2. * (self.ws[-1])**2
        self.i_grad = torch.autograd.grad(outputs=self.i_loss,
                                          inputs=self.ws[-1],
                                          grad_outputs=torch.ones_like(
                                              self.i_loss),
                                          create_graph=True,
                                          retain_graph=True,
                                          only_inputs=True)[0]

        self.i_loss_mv = torch.zeros_like(self.i_loss)  # alternative init: self.i_loss.clone().detach_()
        self.i_grad_norm_mv = torch.zeros_like(self.i_grad)  # alternative init: self.i_grad.data.norm()
        self.o_loss_mv = self.o_loss.clone().detach_()
        self.o_grad_norm_mv = self.o_grad.data.norm()
        self.i_grad_change = self.i_grad.clone().detach_()

        # counters for the inner/outer optimization steps and the last reward
        self.inner_count = 0
        self.outer_count = 0
        self.last_reward = 0

        #print('reset')
        return torch.stack([
            self.i_loss, self.i_loss - self.i_loss_mv,
            self.i_grad.norm().view(1),
            (self.i_grad.norm().add(1e-16).log() -
             self.i_grad_norm_mv.add(1e-16).log()).view(1),
            self.i_grad_change.norm().view(1), self.o_loss,
            self.o_loss - self.o_loss_mv,
            torch.tensor([0.]),
            torch.tensor([0.])
        ], 1).detach()

    def render(self):
        plt.subplot(1, 2, 1)
        plt.plot(self.points)

        plt.subplot(1, 2, 2)
        plt.plot(self.points_x, self.points_y)
        # tmp = np.max(self.points_x)
        # plt.plot([0, tmp+1], [0, (tmp+1)*self.task.p[0]/2.], 'r--')
        plt.plot([self.task.optimal_alpha[0]], [self.task.optimal_w[0]], 'g^')
        plt.plot([self.start[0]], [self.start[1]], 'r*')
        plt.show()
        #plt.savefig("reward.pdf")

    def close(self):
        print('close')
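For completeness, a minimal driver loop for SimpleEnv might look like the following. This is a sketch assuming the OPTS/LRS lists are importable from the same module; the random action choice merely stands in for a trained policy, and passing opt == env.n_opts is how the caller requests the outer update on alpha, as handled in step().

import numpy as np

env = SimpleEnv(testing=True, seed=1000)
ob = env.reset(p=1.0, q=2.0)   # fixed task; omit p and q for a random task

episode_over = False
total_reward = 0.0
while not episode_over:
    # a trained policy would map `ob` to an action; here we pick one at random
    opt = np.random.randint(env.n_opts + 1)  # index n_opts triggers the outer step
    lr = np.random.randint(len(LRS))         # LRS as assumed in the sketch above
    ob, reward, episode_over, _ = env.step((opt, lr))
    total_reward += reward

env.render()
env.close()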