Example #1
 def train_g(self,
             epoch_num,
             mode='Adam',
             dataname='MNIST',
             logname='MNIST'):
     print(mode)
     if mode == 'SGD':
         g_optimizer = optim.SGD(self.G.parameters(),
                                 lr=self.lr,
                                 weight_decay=self.weight_decay)
         self.writer_init(logname=logname,
                          comments='SGD-%.3f_%.5f' %
                          (self.lr, self.weight_decay))
     elif mode == 'Adam':
         g_optimizer = optim.Adam(self.G.parameters(),
                                  lr=self.lr,
                                  weight_decay=self.weight_decay,
                                  betas=(0.5, 0.999))
         self.writer_init(logname=logname,
                          comments='ADAM-%.3f_%.5f' %
                          (self.lr, self.weight_decay))
     elif mode == 'RMSProp':
         g_optimizer = RMSprop(self.G.parameters(),
                               lr=self.lr,
                               weight_decay=self.weight_decay)
         self.writer_init(logname=logname,
                          comments='RMSProp-%.3f_%.5f' %
                          (self.lr, self.weight_decay))
     timer = time.time()
     for e in range(epoch_num):
         z = torch.randn((self.batchsize, self.z_dim),
                         device=self.device)  ## changed
         fake_x = self.G(z)
         d_fake = self.D(fake_x)
         # G_loss = g_loss(d_fake)
         G_loss = self.criterion(
             d_fake, torch.ones(d_fake.shape, device=self.device))
         g_optimizer.zero_grad()
         zero_grad(self.D.parameters())
         G_loss.backward()
         g_optimizer.step()
         gd = torch.norm(torch.cat(
             [p.grad.contiguous().view(-1) for p in self.D.parameters()]),
                         p=2)
         gg = torch.norm(torch.cat(
             [p.grad.contiguous().view(-1) for p in self.G.parameters()]),
                         p=2)
         self.plot_param(G_loss=G_loss)
         self.plot_grad(gd=gd, gg=gg)
         if self.count % self.show_iter == 0:
             self.show_info(timer=time.time() - timer, D_loss=G_loss)
             timer = time.time()
         self.count += 1
         if self.count % 5000 == 0:
             self.save_checkpoint('fixD_%s-%.5f_%d.pth' %
                                  (mode, self.lr, self.count),
                                  dataset=dataname)
     self.writer.close()
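
The loop above calls a zero_grad(...) helper on the discriminator's parameters that is not shown in the snippet; a minimal sketch of such a helper, assuming it simply clears the gradient buffers of a parameter iterable:

# Minimal sketch (assumption) of the zero_grad helper used above: clearing D's
# .grad buffers before backward() so the logged gradient norm reflects only the
# current iteration.
def zero_grad(params):
    for p in params:
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()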
Example #2
def train():
  """ train model
  """
  try:
    os.makedirs(opt.checkpoints_dir)
  except OSError:
    pass
  ################################################
  #               load train dataset
  ################################################
  dataset = dset.CIFAR10(root=opt.dataroot,
                         download=True,
                         transform=transforms.Compose([
                           transforms.Resize(opt.image_size),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                         ]))

  assert dataset
  dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size,
                                           shuffle=True, num_workers=int(opt.workers))

  ################################################
  #               load model
  ################################################
  if torch.cuda.device_count() > 1:
    netG = torch.nn.DataParallel(Generator(ngpu))
  else:
    netG = Generator(ngpu)
  if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG, map_location=lambda storage, loc: storage))

  if torch.cuda.device_count() > 1:
    netD = torch.nn.DataParallel(Discriminator(ngpu))
  else:
    netD = Discriminator(ngpu)
  if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD, map_location=lambda storage, loc: storage))

  # set train mode
  netG.train()
  netG = netG.to(device)
  netD.train()
  netD = netD.to(device)
  print(netG)
  print(netD)

  ################################################
  #            Use RMSprop optimizer
  ################################################
  optimizerD = RMSprop(netD.parameters(), lr=opt.lr)
  optimizerG = RMSprop(netG.parameters(), lr=opt.lr)

  ################################################
  #               print args
  ################################################
  print("########################################")
  print(f"train dataset path: {opt.dataroot}")
  print(f"work thread: {opt.workers}")
  print(f"batch size: {opt.batch_size}")
  print(f"image size: {opt.image_size}")
  print(f"Epochs: {opt.n_epochs}")
  print(f"Noise size: {opt.nz}")
  print("########################################")
  print("Starting trainning!")
  for epoch in range(opt.n_epochs):
    for i, data in enumerate(dataloader):
      # get data
      real_imgs = data[0].to(device)
      batch_size = real_imgs.size(0)

      # Sample noise as generator input
      noise = torch.randn(batch_size, nz, 1, 1, device=device)

      ##############################################
      # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
      ##############################################

      optimizerD.zero_grad()

      # Generate a batch of images
      fake_imgs = netG(noise).detach()

      # Adversarial loss
      real_output = netD(real_imgs)
      fake_output = netD(fake_imgs)
      loss_D = -torch.mean(real_output) + torch.mean(fake_output)

      loss_D.backward()
      optimizerD.step()

      # Clip weights of discriminator
      for p in netD.parameters():
        p.data.clamp_(-opt.clip_value, opt.clip_value)

      ##############################################
      # (2) Update G network: maximize log(D(G(z)))
      ##############################################
      if i % opt.n_critic == 0:
        optimizerG.zero_grad()

        # Generate a batch of images
        fake_imgs = netG(noise)

        # Adversarial loss
        loss_G = -torch.mean(netD(fake_imgs))

        loss_G.backward()
        optimizerG.step()

        print(f"Epoch->[{epoch + 1:03d}/{opt.n_epochs:03d}] "
              f"Progress->{i / len(dataloader) * 100:4.2f}% "
              f"Loss_D: {loss_D.item():.4f} "
              f"Loss_G: {loss_G.item():.4f} ", end="\r")

      if i % 100 == 0:
        vutils.save_image(real_imgs, f"{opt.out_images}/real_samples.png", normalize=True)
        with torch.no_grad():
          fake = netG(fixed_noise).detach().cpu()
        vutils.save_image(fake, f"{opt.out_images}/fake_samples_epoch_{epoch + 1:03d}.png", normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), f"{opt.checkpoints_dir}/netG_epoch_{epoch + 1:03d}.pth")
    torch.save(netD.state_dict(), f"{opt.checkpoints_dir}/netD_epoch_{epoch + 1:03d}.pth")
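
train() relies on module-level objects (opt, device, ngpu, nz, fixed_noise) that are defined outside the snippet; a minimal sketch of that setup, where the option names follow the code above but every default value is an assumption:

# Minimal sketch (assumptions) of the module-level setup train() expects.
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--dataroot", default="./data")
parser.add_argument("--workers", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--image_size", type=int, default=64)
parser.add_argument("--n_epochs", type=int, default=25)
parser.add_argument("--nz", type=int, default=100)
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument("--clip_value", type=float, default=0.01)
parser.add_argument("--n_critic", type=int, default=5)
parser.add_argument("--netG", default="")
parser.add_argument("--netD", default="")
parser.add_argument("--out_images", default="./images")
parser.add_argument("--checkpoints_dir", default="./checkpoints")
opt = parser.parse_args()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ngpu = torch.cuda.device_count()
nz = opt.nz
fixed_noise = torch.randn(opt.batch_size, nz, 1, 1, device=device)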
Example #3
    def train_gd(self,
                 epoch_num,
                 mode='Adam',
                 dataname='MNIST',
                 logname='MNIST',
                 loss_type='JSD'):
        print(mode)
        if mode == 'SGD':
            d_optimizer = optim.SGD(self.D.parameters(),
                                    lr=self.lr,
                                    weight_decay=self.weight_decay)
            g_optimizer = optim.SGD(self.G.parameters(),
                                    lr=self.lr,
                                    weight_decay=self.weight_decay)
        elif mode == 'Adam':
            d_optimizer = optim.Adam(self.D.parameters(),
                                     lr=self.lr,
                                     weight_decay=self.weight_decay,
                                     betas=(0.5, 0.999))
            g_optimizer = optim.Adam(self.G.parameters(),
                                     lr=self.lr,
                                     weight_decay=self.weight_decay,
                                     betas=(0.5, 0.999))
        elif mode == 'RMSProp':
            d_optimizer = RMSprop(self.D.parameters(),
                                  lr=self.lr,
                                  weight_decay=self.weight_decay)
            g_optimizer = RMSprop(self.G.parameters(),
                                  lr=self.lr,
                                  weight_decay=self.weight_decay)
        self.writer_init(logname=logname,
                         comments='%s-%.3f_%.5f' %
                         (mode, self.lr, self.weight_decay))
        self.iswriter.writeheader()
        timer = time.time()
        start = time.time()
        for e in range(epoch_num):
            for real_x in self.dataloader:
                real_x = real_x[0].to(self.device)
                d_real = self.D(real_x)

                z = torch.randn((self.batchsize, self.z_dim),
                                device=self.device)  ## changed (shape)
                fake_x = self.G(z)
                d_fake = self.D(fake_x.detach())
                if loss_type == 'JSD':
                    loss = self.criterion(d_real, torch.ones(d_real.shape, device=self.device)) + \
                           self.criterion(d_fake, torch.zeros(d_fake.shape, device=self.device))
                else:
                    loss = d_fake.mean() - d_real.mean()

                # D_loss = gan_loss(d_real, d_fake)
                # D_loss = self.criterion(d_real, torch.ones(d_real.shape, device=self.device)) + \
                #          self.criterion(d_fake, torch.zeros(d_fake.shape, device=self.device))
                D_loss = loss + self.l2penalty()
                d_optimizer.zero_grad()
                D_loss.backward()
                d_optimizer.step()

                z = torch.randn((self.batchsize, self.z_dim),
                                device=self.device)  ## changed
                fake_x = self.G(z)
                d_fake = self.D(fake_x)
                # G_loss = g_loss(d_fake)
                if loss_type == 'JSD':
                    G_loss = self.criterion(
                        d_fake, torch.ones(d_fake.shape, device=self.device))
                else:
                    G_loss = -d_fake.mean()
                g_optimizer.zero_grad()
                G_loss.backward()
                g_optimizer.step()
                gd = torch.norm(torch.cat([
                    p.grad.contiguous().view(-1) for p in self.D.parameters()
                ]),
                                p=2)
                gg = torch.norm(torch.cat([
                    p.grad.contiguous().view(-1) for p in self.G.parameters()
                ]),
                                p=2)

                self.plot_param(D_loss=D_loss, G_loss=G_loss)
                self.plot_grad(gd=gd, gg=gg)
                self.plot_d(d_real, d_fake)

                if self.count % self.show_iter == 0:
                    self.show_info(timer=time.time() - timer,
                                   D_loss=D_loss,
                                   G_loss=G_loss)
                    timer = time.time()
                self.count += 1
                if self.count % 2000 == 0:
                    is_mean, is_std = self.get_inception_score(batch_num=500)
                    print(is_mean, is_std)
                    self.iswriter.writerow({
                        'iter': self.count,
                        'is_mean': is_mean,
                        'is_std': is_std,
                        'time': time.time() - start
                    })
                    self.save_checkpoint('%s-%.5f_%d.pth' %
                                         (mode, self.lr, self.count),
                                         dataset=dataname)
        self.writer.close()
        self.save_checkpoint('DIM64%s-%.5f_%d.pth' %
                             (mode, self.lr, self.count),
                             dataset=dataname)
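
The JSD branch above passes all-ones / all-zeros targets to self.criterion, which is defined elsewhere; a minimal sketch assuming the criterion is binary cross-entropy on raw discriminator logits:

# Minimal sketch (assumption): the criterion the JSD branch expects, applied to
# dummy logits. Real samples are labelled 1, generated samples 0.
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
d_real = torch.randn(8, 1)  # raw logits for a batch of real samples
d_fake = torch.randn(8, 1)  # raw logits for a batch of generated samples
d_loss = criterion(d_real, torch.ones_like(d_real)) + \
         criterion(d_fake, torch.zeros_like(d_fake))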
Example #4
    with tqdm(postfix=[{"running_reward": 0.0}]) as pbar:

        for it in range(train_steps):
            with torch.no_grad():
                agent.eval()
                agent.exploration_rate *= exploration_decay
                exp = gather_experience(exp_it, batch_size=batch_size)
                estimated_return = calc_estimated_return(agent, exp)

            agent.train()
            loss_value = calc_loss(agent, estimated_return, exp["obs"],
                                   exp["action"])
            optimizer.zero_grad()
            loss_value.backward()
            optimizer.step()
            update_progess_bar(
                pbar, {"running_reward": float(np.mean(exp["next_reward"]))})


if __name__ == "__main__":

    def get_params(params_json_file="constants.json"):
        with open(params_json_file) as f:
            constants = json.load(f)
        return constants

    params = get_params()
    file_path_dict = params["db_file_paths"]
    DATABASE_FILE_PATH = file_path_dict["database"]
    DICT_FILE_PATH = file_path_dict["dict"]
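
The training loop at the top of this example leans on helpers (gather_experience, calc_estimated_return, calc_loss, update_progess_bar) that are not shown; a hypothetical sketch of calc_estimated_return under the assumption of a one-step bootstrapped target, with batch keys "next_obs" and "done" and a discount factor that do not appear in the snippet:

# Hypothetical sketch: one-step bootstrapped return estimate. The keys
# "next_obs" and "done", the gamma default, and treating the agent as a
# Q-value function are all assumptions.
import torch

def calc_estimated_return(agent, exp, gamma=0.99):
    rewards = torch.as_tensor(exp["next_reward"], dtype=torch.float32)
    dones = torch.as_tensor(exp["done"], dtype=torch.float32)
    next_obs = torch.as_tensor(exp["next_obs"], dtype=torch.float32)
    # bootstrap with the greedy Q-value of the next observation
    next_q = agent(next_obs).max(dim=-1).values
    return rewards + gamma * (1.0 - dones) * next_q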
Example #5
    def model_train(self, epoch_offset=0):
        create_dir(MODEL_SAVE_PATH)
        loss_for_regression = MSELoss()
        img_coors_json = read_json_file(BBOX_XYWH_JSON_PATH)

        optimizer = RMSprop(self.parameters(),
                            lr=LEARNING_RATE,
                            momentum=MOMENTUM)
        # optimizer = Adam(self.parameters(), lr=LEARNING_RATE)
        #         optimizer = SGD(self.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

        scheduler = StepLR(optimizer,
                           step_size=SCHEDULER_STEP,
                           gamma=SCHEDULER_GAMMA)

        for epoch in range(EPOCHS):
            epoch_loss = 0.0
            scheduler.step(epoch)
            LOGGER.debug('Epoch: %s, Current Learning Rate: %s',
                         str(epoch + epoch_offset), str(scheduler.get_lr()))
            for image, coors in img_coors_json.items():
                path_of_image = NORMALISED_IMAGES_PATH + image
                path_of_image = path_of_image.replace('%', '_')
                img = cv2.imread(path_of_image)
                img = torch.tensor(img).float().permute(2, 0, 1).unsqueeze(0)
                img = img.to(self.device)
                predicted_width, predicted_height, predicted_midpoint = self.forward(
                    img)

                # all values are already scaled
                mp_x = coors[0][0]
                mp_y = coors[0][1]
                mp = torch.cat((torch.tensor([[mp_x]]).to(
                    self.device), torch.tensor([[mp_y]]).to(self.device)),
                               dim=1).float()

                w = coors[0][2]
                h = coors[0][3]
                loss1 = loss_for_regression(
                    predicted_height,
                    torch.tensor([[h]]).float().to(self.device))
                loss2 = loss_for_regression(
                    predicted_width,
                    torch.tensor([[w]]).float().to(self.device))
                loss3 = loss_for_regression(predicted_midpoint,
                                            mp.to(self.device))
                # operator precedence: only loss3 is halved here
                loss = loss1 + loss2 + loss3 / 2
                optimizer.zero_grad()
                loss.backward()
                clip_grad_norm(self.parameters(), 0.5)
                optimizer.step()
                epoch_loss = epoch_loss + loss.item()

            if epoch % 5 == 0:
                print('epoch: ' + str(epoch) + ' ' + 'loss: ' +
                      str(epoch_loss))
            if epoch % EPOCH_SAVE_INTERVAL == 0:
                print('saving')
                torch.save(
                    self.state_dict(), MODEL_SAVE_PATH + 'model_epc_' +
                    str(epoch + epoch_offset) + '.pt')
        torch.save(
            self.state_dict(),
            MODEL_SAVE_PATH + 'model_epc_' + str(epoch + epoch_offset) + '.pt')
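
model_train() indexes each entry of the bounding-box JSON as coors[0] = [mid_x, mid_y, w, h]; a sketch of the file layout that implies, with made-up file names and values:

# Illustrative layout assumed for BBOX_XYWH_JSON_PATH: image name -> list whose
# first element is [mid_x, mid_y, width, height], already scaled.
example_bbox_xywh = {
    "image_0001.png": [[0.52, 0.48, 0.30, 0.40]],
    "image_0002.png": [[0.45, 0.55, 0.25, 0.35]],
}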
Example #6
def train():
    """ train model
  """
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5),
                                                        (0.5, 0.5, 0.5)),
                               ]))

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.batch_size,
                                             num_workers=int(opt.workers))

    if torch.cuda.device_count() > 1:
        netG = torch.nn.DataParallel(Generator(ngpu))
        netD = torch.nn.DataParallel(Discriminator(ngpu))
    else:
        netG = Generator(ngpu)
        netD = Discriminator(ngpu)
    netD.apply(weights_init)
    netG.apply(weights_init)
    netD.to(device)
    netG.to(device)
    if opt.netG != "":
        netG.load_state_dict(
            torch.load(opt.netG, map_location=lambda storage, loc: storage))
    if opt.netD != "":
        netD.load_state_dict(
            torch.load(opt.netD, map_location=lambda storage, loc: storage))
    print(netG)
    print(netD)

    optimizerD = RMSprop(netD.parameters(), lr=opt.lr)
    optimizerG = RMSprop(netG.parameters(), lr=opt.lr)

    fixed_noise = torch.randn(opt.batch_size, opt.nz, 1, 1, device=device)

    print("########################################")
    print(f"Train dataset path: {opt.dataroot}")
    print(f"Batch size: {opt.batch_size}")
    print(f"Image size: {opt.img_size}")
    print(f"Epochs: {opt.epochs}")
    print("########################################")
    print("Starting trainning!")
    for epoch in range(opt.epochs):
        for i, data in enumerate(dataloader):
            # get data
            real_imgs = data[0].to(device)
            batch_size = real_imgs.size(0)

            # Sample noise as generator input
            noise = torch.randn(batch_size, 100, 1, 1, device=device)

            ##############################################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ##############################################

            optimizerD.zero_grad()

            # Generate a batch of images
            fake_imgs = netG(noise).detach()

            # Adversarial loss
            errD = -torch.mean(netD(real_imgs)) + torch.mean(netD(fake_imgs))

            errD.backward()
            optimizerD.step()

            # Clip weights of discriminator
            for p in netD.parameters():
                p.data.clamp_(-opt.clip_value, opt.clip_value)

            ##############################################
            # (2) Update G network: maximize log(D(G(z)))
            ##############################################
            if i % opt.n_critic == 0:
                optimizerG.zero_grad()

                # Generate a batch of images
                fake_imgs = netG(noise)

                # Adversarial loss
                errG = -torch.mean(netD(fake_imgs))

                errG.backward()
                optimizerG.step()
                print(
                    f"Epoch->[{epoch + 1:3d}/{opt.epochs}] "
                    f"Progress->[{i}/{len(dataloader)}] "
                    f"Loss_D: {errD.item():.4f} "
                    f"Loss_G: {errG.item():.4f} ",
                    end="\r")

            if i % 100 == 0:
                vutils.save_image(real_imgs,
                                  f"{opt.outf}/simpson_real_samples.png",
                                  normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(
                    fake.detach(),
                    f"{opt.outf}/simpson_fake_samples_epoch_{epoch + 1}.png",
                    normalize=True)

        # do checkpointing
        torch.save(netG.state_dict(), f"{opt.checkpoint_dir}/simpson_G.pth")
        torch.save(netD.state_dict(), f"{opt.checkpoint_dir}/simpson_D.pth")
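
weights_init is applied to both networks above but is not defined in the snippet; a common DCGAN-style initialiser, given here as an assumption rather than the repository's actual code:

# Assumed DCGAN-style initialiser: conv weights ~ N(0, 0.02), batch-norm
# scales ~ N(1, 0.02) with zero bias.
import torch.nn as nn

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)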
Example #7
class DQN(Agent):
    def __init__(self,
                 algo_params,
                 env,
                 transition_tuple=None,
                 path=None,
                 seed=-1):
        # environment
        self.env = env
        self.env.seed(seed)
        obs = self.env.reset()
        self.frame_skip = algo_params['frame_skip']
        self.original_image_shape = obs.shape
        self.image_size = algo_params['image_size']
        algo_params.update({
            'state_shape': (self.frame_skip, self.image_size, self.image_size),
            'action_dim':
            self.env.action_space.n,
            'init_input_means':
            None,
            'init_input_vars':
            None
        })
        # training args
        self.training_epoch = algo_params['training_epoch']
        self.training_frame_per_epoch = algo_params['training_frame_per_epoch']
        self.printing_gap = algo_params['printing_gap']
        self.testing_gap = algo_params['testing_gap']
        self.testing_frame_per_epoch = algo_params['testing_frame_per_epoch']
        self.saving_gap = algo_params['saving_gap']

        # args kept only for compatibility; NOT to be used
        algo_params['actor_learning_rate'] = 0.0
        algo_params['observation_normalization'] = False
        algo_params['tau'] = 1.0
        super(DQN, self).__init__(algo_params,
                                  transition_tuple=transition_tuple,
                                  image_obs=True,
                                  action_type='discrete',
                                  path=path,
                                  seed=seed)
        # torch
        self.network_dict.update({
            'Q':
            DQNetwork(self.state_shape, self.action_dim).to(self.device),
            'Q_target':
            DQNetwork(self.state_shape, self.action_dim).to(self.device)
        })
        self.network_keys_to_save = ['Q', 'Q_target']
        self.Q_optimizer = RMSprop(self.network_dict['Q'].parameters(),
                                   lr=self.critic_learning_rate,
                                   eps=algo_params['RMSprop_epsilon'],
                                   weight_decay=algo_params['Q_weight_decay'],
                                   centered=True)
        self._soft_update(self.network_dict['Q'],
                          self.network_dict['Q_target'],
                          tau=1)
        # behavioural policy args (exploration)
        epsilon_decay_frame = algo_params[
            'epsilon_decay_fraction'] * self.training_epoch * self.training_frame_per_epoch
        self.exploration_strategy = LinearDecayGreedy(
            decay=epsilon_decay_frame, rng=self.rng)
        # training args
        self.warmup_step = algo_params['warmup_step']
        self.Q_target_update_interval = algo_params['Q_target_update_interval']
        self.last_frame = None
        self.frame_buffer = [None, None, None, None]
        self.frame_count = 0
        self.reward_clip = algo_params['reward_clip']
        # statistic dict
        self.statistic_dict.update({
            'epoch_return': [],
            'epoch_test_return': []
        })

    def run(self, test=False, render=False, load_network_ep=None, sleep=0):
        if test:
            num_frames = self.testing_frame_per_epoch
            if load_network_ep is not None:
                print("Loading network parameters...")
                self._load_network(ep=load_network_ep)
            print("Start testing...")
        else:
            num_frames = self.training_frame_per_epoch
            print("Start training...")

        for epo in range(self.training_epoch):
            ep_return = self._interact(render,
                                       test,
                                       epo=epo,
                                       num_frames=num_frames,
                                       sleep=sleep)
            self.statistic_dict['epoch_return'].append(ep_return)
            print("Finished training epoch %i, " % epo,
                  "full return %0.1f" % ep_return)

            if (epo % self.testing_gap == 0) and (epo != 0) and (not test):
                print("Evaluate agent at epoch %i..." % epo)
                ep_test_return = self._interact(
                    render,
                    test=True,
                    epo=epo,
                    num_frames=self.testing_frame_per_epoch)
                self.statistic_dict['epoch_test_return'].append(ep_test_return)
                print("Finished testing epoch %i, " % epo,
                      "test return %0.1f" % ep_test_return)

            if (epo % self.saving_gap == 0) and (epo != 0) and (not test):
                self._save_network(ep=epo)

        if not test:
            print("Finished training")
            print("Saving statistics...")
            self._save_statistics()
            self._plot_statistics()
        else:
            print("Finished testing")

    def _interact(self,
                  render=False,
                  test=False,
                  epo=0,
                  num_frames=0,
                  sleep=0):
        ep_return = 0
        self.frame_count = 0
        while self.frame_count < num_frames:
            done = False
            obs = self.env.reset()
            obs = self._pre_process([obs])
            num_lives = self.env.ale.lives()
            # start a new episode
            while not done:
                if render:
                    self.env.render()
                if self.env_step_count < self.warmup_step:
                    action = self.env.action_space.sample()
                else:
                    action = self._select_action(obs, test=test)

                # action repeat, aggregated reward
                frames = []
                added_reward = 0
                for _ in range(self.frame_skip):
                    new_obs, reward, done, info = self.env.step(action)
                    frames.append(new_obs.copy())
                    added_reward += reward
                time.sleep(sleep)
                # frame gray scale, resize, stack
                new_obs = self._pre_process(frames[-2:])
                # reward clipped into [-1, 1]
                reward = max(min(added_reward, self.reward_clip),
                             -self.reward_clip)

                if num_lives > self.env.ale.lives():
                    # treat the episode as terminated when the agent loses a life
                    num_lives = self.env.ale.lives()
                    done_to_save = True
                    # set the reward to -reward_clip
                    reward = -self.reward_clip
                    # clear the frame buffer when the agent starts a new life
                    self.frame_buffer = [None, None, None, None]
                else:
                    done_to_save = done

                # return to be recorded
                ep_return += reward
                if not test:
                    self._remember(obs, action, new_obs, reward,
                                   1 - int(done_to_save))
                    if (self.env_step_count % self.update_interval
                            == 0) and (self.env_step_count > self.warmup_step):
                        self._learn()
                obs = new_obs
                self.frame_count += 1
                self.env_step_count += 1

                if self.frame_count % self.printing_gap == 0 and self.frame_count != 0:
                    print("Epoch %i" % epo,
                          "passed frames %i" % self.frame_count,
                          "return %0.1f" % ep_return)

            # clear frame buffer at the end of an episode
            self.frame_buffer = [None, None, None, None]
        return ep_return

    def _select_action(self, obs, test=False):
        if test:
            obs = T.tensor([obs], dtype=T.float32).to(self.device)
            with T.no_grad():
                action = self.network_dict['Q_target'].get_action(obs)
            return action
        else:
            if self.exploration_strategy(self.env_step_count):
                action = self.rng.integers(self.action_dim)
            else:
                obs = T.tensor([obs], dtype=T.float32).to(self.device)
                with T.no_grad():
                    action = self.network_dict['Q_target'].get_action(obs)
            return action

    def _learn(self, steps=None):
        if len(self.buffer) < self.batch_size:
            return
        if steps is None:
            steps = self.optimizer_steps

        for i in range(steps):
            if self.prioritised:
                batch, weights, inds = self.buffer.sample(self.batch_size)
                weights = T.tensor(weights).view(self.batch_size,
                                                 1).to(self.device)
            else:
                batch = self.buffer.sample(self.batch_size)
                weights = T.ones(size=(self.batch_size, 1)).to(self.device)
                inds = None

            inputs = T.tensor(batch.state, dtype=T.float32).to(self.device)
            actions = T.tensor(batch.action,
                               dtype=T.long).unsqueeze(1).to(self.device)
            inputs_ = T.tensor(batch.next_state,
                               dtype=T.float32).to(self.device)
            rewards = T.tensor(batch.reward,
                               dtype=T.float32).unsqueeze(1).to(self.device)
            done = T.tensor(batch.done,
                            dtype=T.float32).unsqueeze(1).to(self.device)

            if self.discard_time_limit:
                done = done * 0 + 1

            with T.no_grad():
                maximal_next_values = self.network_dict['Q_target'](
                    inputs_).max(1)[0].view(self.batch_size, 1)
                value_target = rewards + done * self.gamma * maximal_next_values

            self.Q_optimizer.zero_grad()
            value_estimate = self.network_dict['Q'](inputs).gather(1, actions)
            loss = F.smooth_l1_loss(value_estimate,
                                    value_target.detach(),
                                    reduction='none')
            (loss * weights).mean().backward()
            self.Q_optimizer.step()

            if self.prioritised:
                assert inds is not None
                self.buffer.update_priority(
                    inds, np.abs(loss.cpu().detach().numpy()))

            self.statistic_dict['critic_loss'].append(
                loss.detach().mean().cpu().numpy().item())

            if self.optim_step_count % self.Q_target_update_interval == 0:
                self._soft_update(self.network_dict['Q'],
                                  self.network_dict['Q_target'],
                                  tau=1)

            self.optim_step_count += 1

    def _pre_process(self, frames):
        # This method takes 2 frames and does the following things
        # 1. Max-pool two consecutive frames to deal with flickering
        # 2. Convert images to Y channel: Y = 0.299*R + 0.587*G + (1 - (0.299 + 0.587))*B
        # 3. Resize images to 84x84
        # 4. Stack it with previous frames as one observation
        # frame-buffer contents over successive calls (digits are frame indices, 0 = empty slot): 1000, 1200, 1230, 1234, 2345, 3456, ...
        if len(frames) == 1:
            frames.insert(0, np.zeros(self.original_image_shape))
        assert len(frames) == 2

        last_img = frames[0].copy()
        img = frames[1].copy()
        img = np.max([last_img, img], axis=0)
        img = img.transpose((-1, 0, 1))
        img_Y = 0.299 * img[0] + 0.587 * img[1] + (1 -
                                                   (0.299 + 0.587)) * img[2]
        img_Y_resized = np.asarray(
            Image.fromarray(img_Y).resize((self.image_size, self.image_size),
                                          Image.BILINEAR))
        for i in range(len(self.frame_buffer)):
            if self.frame_buffer[i] is None:
                self.frame_buffer[i] = img_Y_resized.copy()
                break

            if i == (len(self.frame_buffer) - 1):
                del self.frame_buffer[0]
                self.frame_buffer.append(img_Y_resized.copy())

        obs = []
        for i in range(len(self.frame_buffer)):
            if self.frame_buffer[i] is not None:
                obs.append(self.frame_buffer[i].copy())
            else:
                obs.append(np.zeros((self.image_size, self.image_size)))
        return np.array(obs, dtype=np.uint8)
Example #8
def train_sim(epoch_num=10,
              optim_type='ACGD',
              startPoint=None,
              start_n=0,
              z_dim=128,
              batchsize=64,
              l2_penalty=0.0,
              momentum=0.0,
              log=False,
              loss_name='WGAN',
              model_name='dc',
              model_config=None,
              data_path='None',
              show_iter=100,
              logdir='test',
              dataname='CIFAR10',
              device='cpu',
              gpu_num=1):
    lr_d = 1e-4
    lr_g = 1e-4
    dataset = get_data(dataname=dataname, path=data_path)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batchsize,
                            shuffle=True,
                            num_workers=4)
    D, G = get_model(model_name=model_name, z_dim=z_dim, configs=model_config)
    D.apply(weights_init_d).to(device)
    G.apply(weights_init_g).to(device)

    optim_d = RMSprop(D.parameters(), lr=lr_d)
    optim_g = RMSprop(G.parameters(), lr=lr_g)

    if startPoint is not None:
        chk = torch.load(startPoint)
        D.load_state_dict(chk['D'])
        G.load_state_dict(chk['G'])
        optim_d.load_state_dict(chk['d_optim'])
        optim_g.load_state_dict(chk['g_optim'])
        print('Start from %s' % startPoint)
    if gpu_num > 1:
        D = nn.DataParallel(D, list(range(gpu_num)))
        G = nn.DataParallel(G, list(range(gpu_num)))
    timer = time.time()
    count = 0
    if 'DCGAN' in model_name:
        fixed_noise = torch.randn((64, z_dim, 1, 1), device=device)
    else:
        fixed_noise = torch.randn((64, z_dim), device=device)
    for e in range(epoch_num):
        print('======Epoch: %d / %d======' % (e, epoch_num))
        for real_x in dataloader:
            real_x = real_x[0].to(device)
            d_real = D(real_x)
            if 'DCGAN' in model_name:
                z = torch.randn((d_real.shape[0], z_dim, 1, 1), device=device)
            else:
                z = torch.randn((d_real.shape[0], z_dim), device=device)
            fake_x = G(z)
            d_fake = D(fake_x)
            loss = get_loss(name=loss_name,
                            g_loss=False,
                            d_real=d_real,
                            d_fake=d_fake,
                            l2_weight=l2_penalty,
                            D=D)
            D.zero_grad()
            G.zero_grad()
            loss.backward()
            optim_d.step()
            optim_g.step()

            if count % show_iter == 0:
                time_cost = time.time() - timer
                print('Iter: %d, Loss: %.5f, time: %.3fs' %
                      (count, loss.item(), time_cost))
                timer = time.time()
                with torch.no_grad():
                    fake_img = G(fixed_noise).detach()
                    path = 'figs/%s_%s/' % (dataname, logdir)
                    if not os.path.exists(path):
                        os.makedirs(path)
                    vutils.save_image(fake_img,
                                      path + 'iter_%d.png' % (count + start_n),
                                      normalize=True)
                save_checkpoint(
                    path=logdir,
                    name='%s-%s%.3f_%d.pth' %
                    (optim_type, model_name, lr_g, count + start_n),
                    D=D,
                    G=G,
                    optimizer=optim_d,
                    g_optimizer=optim_g)
            if wandb and log:
                wandb.log({
                    'Real score': d_real.mean().item(),
                    'Fake score': d_fake.mean().item(),
                    'Loss': loss.item()
                })
            count += 1
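
get_loss is defined elsewhere; a plausible sketch of its WGAN discriminator/generator branches, with an optional L2 penalty on the critic outputs weighted by l2_weight (the penalty form is an assumption):

# Plausible sketch (assumption) of get_loss for the WGAN case used above.
def get_loss(name, g_loss, d_real=None, d_fake=None, l2_weight=0.0, D=None):
    if name == 'WGAN':
        if g_loss:
            return -d_fake.mean()
        loss = d_fake.mean() - d_real.mean()
        if l2_weight > 0.0:
            loss = loss + l2_weight * (d_real ** 2 + d_fake ** 2).mean()
        return loss
    raise NotImplementedError(name)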
Example #9
class DQNAgent(TrainingAgent):
    def __init__(self, input_shape, action_space, seed, device, model, gamma,
                 alpha, tau, batch_size, update, replay, buffer_size, env,
                 decay=200, path='model', num_epochs=0, max_step=50000,
                 learn_interval=20):

        '''Initialise a DQNAgent object.
        buffer_size   : size of the replay buffer to sample from
        gamma         : discount rate
        alpha         : learning rate
        replay        : number of stored transitions before learning starts
        update        : interval (in learning steps) between target-network soft updates
        learn_interval: number of environment steps between learning updates
        '''
        super(DQNAgent, self).__init__(input_shape, action_space, seed, device, model,
                                       gamma, alpha, tau, batch_size, max_step, env, num_epochs, path)
        self.buffer_size = buffer_size
        self.update = update
        self.replay = replay
        self.interval = learn_interval
        # Q-Network
        self.policy_net = self.model(input_shape, action_space).to(self.device)
        self.target_net = self.model(input_shape, action_space).to(self.device)
        self.optimiser = RMSprop(self.policy_net.parameters(), lr=self.alpha)
        # Replay Memory
        self.memory = ReplayMemory(self.buffer_size, self.batch_size, self.seed, self.device)
        # Timestep
        self.t_step = 0
        self.l_step = 0

        self.EPSILON_START = 1.0
        self.EPSILON_FINAL = 0.02
        self.EPS_DECAY = decay
        self.epsilon_delta = lambda frame_idx: self.EPSILON_FINAL + (self.EPSILON_START - self.EPSILON_FINAL) * exp(-1. * frame_idx / self.EPS_DECAY)

    def step(self, state, action, reward, next_state, done):
        '''
        Step of learning and taking environment action.
        '''

        # Save experience into replay buffer
        self.memory.add(state, action, reward, next_state, done)

        # Learn every learn_interval timesteps
        self.t_step = (self.t_step + 1) % self.interval

        if self.t_step == 0:
            # if there are enough samples in the memory, get a random subset and learn
            if len(self.memory) > self.replay:
                experience = self.memory.sample()
                print('learning')
                self.learn(experience)


    def action(self, state, eps=0.):
        ''' Returns action for given state as per current policy'''
        #Unpack the state
        state = torch.from_numpy(state).unsqueeze(0).to(self.device)
        if rand.rand() > eps:
            # Eps Greedy action selections
            action_val = self.policy_net(state)
            return np.argmax(action_val.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_space))

    def learn(self, exp):
        state, action, reward, next_state, done = exp

        # Get expected Q values from Policy Model
        Q_expt_current = self.policy_net(state)
        Q_expt = Q_expt_current.gather(1, action.unsqueeze(1)).squeeze(1)

        # Get max predicted Q values for next state from target model
        Q_target_next = self.target_net(next_state).detach().max(1)[0]
        # Compute Q targets for current states
        Q_target = reward + (self.gamma * Q_target_next * (1 - done))

        # Compute Loss
        loss = torch.nn.functional.mse_loss(Q_expt, Q_target)

        # Minimize loss
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        # Soft-update the target network every self.update learning steps
        self.l_step = (self.l_step + 1) % self.update
        if self.l_step == 0:
            self.soft_update(self.policy_net, self.target_net, self.tau)

    def model_dict(self) -> dict:
        '''To save models'''
        return {'policy_net': self.policy_net.state_dict(), 'target_net': self.target_net.state_dict(),
                'optimizer': self.optimiser.state_dict(), 'num_epoch': self.num_epochs, 'scores': self.scores}

    def load_model(self, state_dict, eval=True):
        '''Load Parameters and Model Information from prior training for continuation of training'''
        self.policy_net.load_state_dict(state_dict['policy_net'])
        self.target_net.load_state_dict(state_dict['target_net'])
        self.optimiser.load_state_dict(state_dict['optimizer'])
        self.scores = state_dict['scores']
        if eval:
            self.policy_net.eval()
            self.target_net.eval()
        else:
            self.policy_net.train()
            self.target_net.train()
        # Restore the epoch counter
        self.num_epochs = state_dict['num_epoch']

    # θ'=θ×τ+θ'×(1−τ)
    def soft_update(self, policy_model, target_model, tau):
        for t_param, p_param in zip(target_model.parameters(), policy_model.parameters()):
            t_param.data.copy_(tau * p_param.data + (1.0 - tau) * t_param.data)

    def train(self, n_episodes=1000, render=False):
        """
        n_episodes: maximum number of training episodes
        Saves Model every 100 Epochs
        """
        filename = get_filename()

        # Toggle rendering on/off
        self.env.render(render)
        for i_episode in range(n_episodes):
            self.num_epochs += 1
            state = self.stack_frames(None, self.reset(), True)
            score = 0
            eps = self.epsilon_delta(self.num_epochs)

            while True:
                action = self.action(state, eps)

                next_state, reward, done, info = self.env.step(action)

                score += reward

                next_state = self.stack_frames(state, next_state, False)

                self.step(state, action, reward, next_state, done)
                state = next_state
                if done:
                    break
            self.scores.append(score)  # save most recent score

            # Every 100 training
            if i_episode % 100 == 0:
                self.save_obj(self.model_dict(), os.path.join(self.path, filename))
                print(f"Creating plot")
                # Plot a figure
                fig = plt.figure()

                # Add a subplot
                # ax = fig.add_subplot(111)

                # Plot the graph
                plt.plot(np.arange(len(self.scores)), self.scores)

                # Add labels
                plt.xlabel('Episode #')
                plt.ylabel('Score')

                # Save the plot
                plt.savefig(f'{i_episode} plot.png')
                print(f"Plot saved")

        # Return the scores.
        return self.scores
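
The exploration schedule defined in __init__ decays epsilon exponentially with the epoch count; a standalone illustration of its values (EPSILON_START=1.0, EPSILON_FINAL=0.02, EPS_DECAY=200), with arbitrary sample indices:

# Standalone illustration of the epsilon schedule above.
from math import exp

epsilon = lambda i: 0.02 + (1.0 - 0.02) * exp(-1.0 * i / 200)
for i in (0, 100, 200, 500, 1000):
    print(i, round(epsilon(i), 3))  # 1.0, 0.614, 0.381, 0.1, 0.027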
Example #10
    def traing(self,
               epoch_num,
               mode='Adam',
               dataname='MNIST',
               logname='MNIST'):
        print(mode)
        if mode == 'SGD':
            d_optimizer = optim.SGD(self.D.parameters(),
                                    lr=self.lr_d,
                                    weight_decay=self.weight_decay)
            g_optimizer = optim.SGD(self.G.parameters(),
                                    lr=self.lr_d,
                                    weight_decay=self.weight_decay)
            self.writer_init(logname=logname,
                             comments='SGD-%.3f_%.5f' %
                             (self.lr_d, self.weight_decay))
        elif mode == 'Adam':
            d_optimizer = optim.Adam(self.D.parameters(),
                                     lr=self.lr_d,
                                     weight_decay=self.weight_decay,
                                     betas=(0.5, 0.999))
            g_optimizer = optim.Adam(self.G.parameters(),
                                     lr=self.lr_d,
                                     weight_decay=self.weight_decay,
                                     betas=(0.5, 0.999))
            self.writer_init(logname=logname,
                             comments='ADAM-%.3f_%.5f' %
                             (self.lr_d, self.weight_decay))
        elif mode == 'RMSProp':
            d_optimizer = RMSprop(self.D.parameters(),
                                  lr=self.lr_d,
                                  weight_decay=self.weight_decay)
            g_optimizer = RMSprop(self.G.parameters(),
                                  lr=self.lr_d,
                                  weight_decay=self.weight_decay)
            self.writer_init(logname=logname,
                             comments='RMSProp-%.3f_%.5f' %
                             (self.lr_d, self.weight_decay))

        timer = time.time()

        for e in range(epoch_num):
            for real_x in self.dataloader:
                real_x = real_x[0].to(self.device)
                d_real = self.D(real_x)

                z = torch.randn((self.batchsize, self.z_dim),
                                device=self.device)  ## changed (shape)
                fake_x = self.G(z)
                d_fake = self.D(fake_x.detach())

                # D_loss = gan_loss(d_real, d_fake)
                D_loss = self.criterion(d_real, torch.ones(d_real.shape, device=self.device)) + \
                         self.criterion(d_fake, torch.zeros(d_fake.shape, device=self.device))
                # D_loss = d_fake.mean() - d_real.mean()
                d_optimizer.zero_grad()
                # D is not updated here (no d_optimizer.step() is called):
                # backward() only populates D's gradients so their norm can be logged
                D_loss.backward()
                gd = torch.norm(torch.cat([
                    p.grad.contiguous().view(-1) for p in self.D.parameters()
                ]),
                                p=2)

                z = torch.randn((self.batchsize, self.z_dim),
                                device=self.device)  ## changed
                fake_x = self.G(z)
                d_fake = self.D(fake_x)
                # G_loss = g_loss(d_fake)
                G_loss = self.criterion(
                    d_fake, torch.ones(d_fake.shape, device=self.device))
                g_optimizer.zero_grad()
                G_loss.backward()
                g_optimizer.step()
                gg = torch.norm(torch.cat([
                    p.grad.contiguous().view(-1) for p in self.G.parameters()
                ]),
                                p=2)

                self.plot_param(D_loss=D_loss, G_loss=G_loss)
                self.plot_grad(gd=gd, gg=gg)
                self.plot_d(d_real, d_fake)

                if self.count % self.show_iter == 0:
                    self.show_info(timer=time.time() - timer,
                                   D_loss=D_loss,
                                   G_loss=G_loss)
                    timer = time.time()
                    self.save_checkpoint('sfixD%s-%.5f_%d.pth' %
                                         (mode, self.lr_d, self.count),
                                         dataset=dataname)
                self.count += 1
        self.writer.close()
Example #11
class DQNAgent():
    def __init__(self, config, n_Feature, n_Action):
        self.device = config["device"]
        self.lr = config["lr"]
        self.lr_step_size = config["lr_decay_step"]
        self.lr_gamma = config["lr_decay_gamma"]
        self.lr_last_epoch = config["lr_last_epoch"]
        self.target_net_update_freq = config["target_net_update_freq"]
        self.experience_replay_size = config["experience_replay_size"]
        self.batch_size = config["batch_size"]
        self.discount = config["discount"]

        self.num_feature = n_Feature
        self.num_action = n_Action

        # initialize the predict net and target net in the agent
        self.predict_net = DNN(input_shape=n_Feature, num_actions=n_Action)
        self.target_net = DNN(input_shape=n_Feature, num_actions=n_Action)
        self.predict_net = self.predict_net.to(self.device)
        self.target_net = self.target_net.to(self.device)
        self.predict_net.apply(weigth_init)
        self.target_net.apply(weigth_init)
        self.target_net.load_state_dict(self.predict_net.state_dict())
        self.optimizer = RMSprop(self.predict_net.parameters(),
                                 lr=self.lr,
                                 momentum=0.95,
                                 eps=0.01)
        self.lr_scheduler = StepLR(optimizer=self.optimizer,
                                   step_size=self.lr_step_size,
                                   gamma=self.lr_gamma)
        # self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        # initialize the experience replay buffer
        self.memory = ReplayMemory(entry_size=n_Feature,
                                   memory_size=self.experience_replay_size,
                                   batch_size=self.batch_size)

    def prepare_minibatch(self):
        batch_state, batch_next_state, batch_action, batch_reward = self.memory.sample(
        )
        batch_state = torch.tensor(batch_state,
                                   device=self.device,
                                   dtype=torch.float)
        batch_action = torch.tensor(batch_action,
                                    device=self.device,
                                    dtype=torch.long).view(-1, 1)
        batch_reward = torch.tensor(batch_reward,
                                    device=self.device,
                                    dtype=torch.float)
        batch_next_state = torch.tensor(batch_next_state,
                                        device=self.device,
                                        dtype=torch.float)
        return batch_state, batch_action, batch_reward, batch_next_state

    def update_dqn(self, episode=0):
        self.predict_net.train()
        self.target_net.train()
        batch_state, batch_action, batch_reward, batch_next_state = self.prepare_minibatch(
        )
        current_state_values = self.predict_net(batch_state).gather(
            1, batch_action).squeeze()
        next_state_values = self.target_net(batch_next_state).max(
            1)[0].detach()
        expected_state_action_values = (next_state_values *
                                        self.discount) + batch_reward
        # compute temporal difference as loss
        loss = (expected_state_action_values - current_state_values)**2
        loss = loss.mean()
        # optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # self.lr_scheduler.step()
        # update target network
        if episode % self.target_net_update_freq == 0:
            self.target_net.load_state_dict(self.predict_net.state_dict())
        return loss.item()

    def update_double_dqn(self, episode=0):
        # switch to train mode so that BN can work properly
        self.predict_net.train()
        # self.target_net.train()
        batch_state, batch_action, batch_reward, batch_next_state = self.prepare_minibatch(
        )
        current_state_values = self.predict_net(batch_state).gather(
            1, batch_action).squeeze()
        pred_action = self.predict_net(batch_next_state).argmax(1).unsqueeze(1)
        next_state_values = self.target_net(batch_next_state).gather(
            1, pred_action).squeeze()
        expected_state_action_values = (next_state_values *
                                        self.discount) + batch_reward
        # compute temporal difference as loss
        loss = (expected_state_action_values - current_state_values)**2
        loss = loss.mean()
        # optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.lr_scheduler.step()
        # update target network
        if episode % self.target_net_update_freq == self.target_net_update_freq - 1:
            self.target_net.load_state_dict(self.predict_net.state_dict())
        return loss.item()

    def get_action(self, state, epsilon, learned_policy=True):
        with torch.no_grad():
            rand = random.random()
            if rand < epsilon or learned_policy == False:
                return np.random.randint(0, self.num_action)
            else:
                # switch to eval mode so that BN can work properly
                self.predict_net.eval()
                X = torch.tensor([state],
                                 device=self.device,
                                 dtype=torch.float)
                a = self.predict_net.forward(X).squeeze().argmax().item()
                return a
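
update_dqn and update_double_dqn above differ only in how the next-state value is chosen; a toy comparison of the two targets on random values, purely illustrative:

# Toy comparison (random numbers) of the DQN and Double-DQN targets.
import torch

q_target_next = torch.rand(4, 3)   # target net Q(s', .)
q_predict_next = torch.rand(4, 3)  # predict net Q(s', .)
reward, discount = torch.rand(4), 0.99

# vanilla DQN: max over the target net's own estimates
dqn_target = reward + discount * q_target_next.max(1)[0]

# double DQN: action chosen by the predict net, value read from the target net
best_action = q_predict_next.argmax(1, keepdim=True)
ddqn_target = reward + discount * q_target_next.gather(1, best_action).squeeze(1)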
Example #12
class OffPGLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.mac = mac
        self.logger = logger

        self.last_target_update_step = 0
        self.critic_training_steps = 0

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.critic = OffPGCritic(scheme, args)
        self.mixer = QMixer(args)
        self.target_critic = copy.deepcopy(self.critic)
        self.target_mixer = copy.deepcopy(self.mixer)

        self.agent_params = list(mac.parameters())
        self.critic_params = list(self.critic.parameters())
        self.mixer_params = list(self.mixer.parameters())
        self.params = self.agent_params + self.critic_params
        self.c_params = self.critic_params + self.mixer_params

        self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr)
        self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.lr)
        self.mixer_optimiser = RMSprop(params=self.mixer_params, lr=args.lr)

        print('Mixer Size: ')
        print(get_parameters_num(list(self.c_params)))

    def train(self, batch: EpisodeBatch, t_env: int, log):
        # Get the relevant quantities
        bs = batch.batch_size
        max_t = batch.max_seq_length
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        avail_actions = batch["avail_actions"][:, :-1]
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        mask = mask.repeat(1, 1, self.n_agents).view(-1)
        states = batch["state"][:, :-1]

        #build q
        inputs = self.critic._build_inputs(batch, bs, max_t)
        q_vals = self.critic.forward(inputs).detach()[:, :-1]

        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length - 1):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out[avail_actions == 0] = 0
        mac_out = mac_out/mac_out.sum(dim=-1, keepdim=True)
        mac_out[avail_actions == 0] = 0

        # Calculate the baseline
        q_taken = th.gather(q_vals, dim=3, index=actions).squeeze(3)
        pi = mac_out.view(-1, self.n_actions)
        baseline = th.sum(mac_out * q_vals, dim=-1).view(-1).detach()

        # Calculate policy grad with mask
        pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1)
        pi_taken[mask == 0] = 1.0
        log_pi_taken = th.log(pi_taken)
        coe = self.mixer.k(states).view(-1)

        advantages = (q_taken.view(-1) - baseline)
        # advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        coma_loss = - ((coe * advantages.detach() * log_pi_taken) * mask).sum() / mask.sum()
        
        # dist_entropy = Categorical(pi).entropy().view(-1)
        # dist_entropy[mask == 0] = 0 # fill nan
        # entropy_loss = (dist_entropy * mask).sum() / mask.sum()
 
        # loss = coma_loss - self.args.ent_coef * entropy_loss / entropy_loss.item()
        loss = coma_loss

        # Optimise agents
        self.agent_optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip)
        self.agent_optimiser.step()

        #compute parameters sum for debugging
        p_sum = 0.
        for p in self.agent_params:
            p_sum += p.data.abs().sum().item() / 100.0


        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            ts_logged = len(log["critic_loss"])
            for key in ["critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean", "q_max_mean", "q_min_mean", "q_max_var", "q_min_var"]:
                self.logger.log_stat(key, sum(log[key])/ts_logged, t_env)
            self.logger.log_stat("q_max_first", log["q_max_first"], t_env)
            self.logger.log_stat("q_min_first", log["q_min_first"], t_env)
            #self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env)
            # self.logger.log_stat("entropy_loss", entropy_loss.item(), t_env)
            self.logger.log_stat("coma_loss", coma_loss.item(), t_env)
            self.logger.log_stat("agent_grad_norm", grad_norm, t_env)
            self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env)
            self.log_stats_t = t_env

    def train_critic(self, on_batch, best_batch=None, log=None):
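        """Critic and mixer update. TD(lambda) targets are built from the target
        networks; if best_batch is given, an off-policy batch with corrected
        targets (from train_critic_best) is appended before training."""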
        bs = on_batch.batch_size
        max_t = on_batch.max_seq_length
        rewards = on_batch["reward"][:, :-1]
        actions = on_batch["actions"][:, :]
        terminated = on_batch["terminated"][:, :-1].float()
        mask = on_batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = on_batch["avail_actions"][:]
        states = on_batch["state"]

        # Build TD targets from the target critic and target mixer
        target_inputs = self.target_critic._build_inputs(on_batch, bs, max_t)
        target_q_vals = self.target_critic.forward(target_inputs).detach()
        targets_taken = self.target_mixer(th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states)
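        # Multi-step TD(lambda) targets from the mixed target Q of the taken actions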
        target_q = build_td_lambda_targets(rewards, terminated, mask, targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda).detach()

        inputs = self.critic._build_inputs(on_batch, bs, max_t)


        if best_batch is not None:
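            # Append the off-policy batch; its targets already carry the correction computed in train_critic_best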
            best_target_q, best_inputs, best_mask, best_actions, best_mac_out= self.train_critic_best(best_batch)
            log["best_reward"] = th.mean(best_batch["reward"][:, :-1].squeeze(2).sum(-1), dim=0)
            target_q = th.cat((target_q, best_target_q), dim=0)
            inputs = th.cat((inputs, best_inputs), dim=0)
            mask = th.cat((mask, best_mask), dim=0)
            actions = th.cat((actions, best_actions), dim=0)
            states = th.cat((states, best_batch["state"]), dim=0)

        # Train critic and mixer
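        # Update one timestep at a time (oldest to newest), skipping fully-masked steps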
        for t in range(max_t - 1):
            mask_t = mask[:, t:t+1]
            if mask_t.sum() < 0.5:
                continue
            q_vals = self.critic.forward(inputs[:, t:t+1])
            q_ori = q_vals
            q_vals = th.gather(q_vals, 3, index=actions[:, t:t+1]).squeeze(3)
            q_vals = self.mixer.forward(q_vals, states[:, t:t+1])
            target_q_t = target_q[:, t:t+1].detach()
            q_err = (q_vals - target_q_t) * mask_t
            critic_loss = (q_err ** 2).sum() / mask_t.sum()

            self.critic_optimiser.zero_grad()
            self.mixer_optimiser.zero_grad()
            critic_loss.backward()
            grad_norm = th.nn.utils.clip_grad_norm_(self.c_params, self.args.grad_norm_clip)
            self.critic_optimiser.step()
            self.mixer_optimiser.step()
            self.critic_training_steps += 1

            log["critic_loss"].append(critic_loss.item())
            log["critic_grad_norm"].append(grad_norm)
            mask_elems = mask_t.sum().item()
            log["td_error_abs"].append((q_err.abs().sum().item() / mask_elems))
            log["target_mean"].append((target_q_t * mask_t).sum().item() / mask_elems)
            log["q_taken_mean"].append((q_vals * mask_t).sum().item() / mask_elems)
            log["q_max_mean"].append((th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_min_mean"].append((th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_max_var"].append((th.var(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)
            log["q_min_var"].append((th.var(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems)

            if t == 0:
                log["q_max_first"] = (th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems
                log["q_min_first"] = (th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems

        # Hard-update the target networks every target_update_interval critic training steps
        if (self.critic_training_steps - self.last_target_update_step) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_step = self.critic_training_steps

    def train_critic_best(self, batch):
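        """Compute targets, inputs, mask and actions for an off-policy ("best") batch,
        correcting the multi-step targets with the current policy's joint-action
        probabilities."""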
        bs = batch.batch_size
        max_t = batch.max_seq_length
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"][:]
        states = batch["state"]

        with th.no_grad():
            # Policy probabilities for all actions over the whole episode
            mac_out = []
            self.mac.init_hidden(bs)
            for i in range(max_t):
                agent_outs = self.mac.forward(batch, t=i)
                mac_out.append(agent_outs)
            mac_out = th.stack(mac_out, dim=1).detach()
            # Mask out unavailable actions, renormalise (as in action selection)
            mac_out[avail_actions == 0] = 0
            mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True)
            mac_out[avail_actions == 0] = 0
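            # Probability of the taken joint action under the current policy (product over agents)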
            critic_mac = th.gather(mac_out, 3, actions).squeeze(3).prod(dim=2, keepdim=True)

            # Q-values of the taken actions under the target critic, mixed by the target mixer
            target_inputs = self.target_critic._build_inputs(batch, bs, max_t)
            target_q_vals = self.target_critic.forward(target_inputs).detach()
            targets_taken = self.target_mixer(th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states)

            # Expected Q under the current policy (mixed by the target mixer)
            exp_q = self.build_exp_q(target_q_vals, mac_out, states).detach()
            # One-step TD error; terminated and padded steps are zeroed out first
            targets_taken[:, -1] = targets_taken[:, -1] * (1 - th.sum(terminated, dim=1))
            exp_q[:, -1] = exp_q[:, -1] * (1 - th.sum(terminated, dim=1))
            targets_taken[:, :-1] = targets_taken[:, :-1] * mask
            exp_q[:, :-1] = exp_q[:, :-1] * mask
            td_q = (rewards + self.args.gamma * exp_q[:, 1:] - targets_taken[:, :-1]) * mask

            # Compute the corrected multi-step target
            target_q = build_target_q(td_q, targets_taken[:, :-1], critic_mac, mask, self.args.gamma, self.args.tb_lambda, self.args.step).detach()

            inputs = self.critic._build_inputs(batch, bs, max_t)

        return target_q, inputs, mask, actions, mac_out


    def build_exp_q(self, target_q_vals, mac_out, states):
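        """Expected joint value under the current policy: per-agent expectation of
        the target critic's Q over mac_out, combined by the target mixer."""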
        target_exp_q_vals = th.sum(target_q_vals * mac_out, dim=3)
        target_exp_q_vals = self.target_mixer.forward(target_exp_q_vals, states)
        return target_exp_q_vals

    def _update_targets(self):
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.critic.cuda()
        self.mixer.cuda()
        self.target_critic.cuda()
        self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        th.save(self.critic.state_dict(), "{}/critic.th".format(path))
        th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path))
        th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path))
        th.save(self.mixer_optimiser.state_dict(), "{}/mixer_opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        self.critic.load_state_dict(th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage))
        self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage))
        # Not quite right but I don't want to save target networks
        # self.target_critic.load_state_dict(self.critic.agent.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.agent_optimiser.load_state_dict(th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage))
        self.critic_optimiser.load_state_dict(th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage))
        self.mixer_optimiser.load_state_dict(th.load("{}/mixer_opt.th".format(path), map_location=lambda storage, loc: storage))