Example 1
 def _find_baseline(self):
     if hasattr(self, 'dynamic_size'):
         x_bl = torch.FloatTensor(
             1, self.dynamic_size[0] * self.dynamic_size[1]).zero_()
     else:
         x_bl = torch.FloatTensor(1, self.model.struct[0]).zero_()
     if self.baseline_goal == 'input':
         self.x_bl = x_bl
         return
     # x_bl = torch.randn(1, self.dynamic_size[0]*self.dynamic_size[1])
     zero = torch.FloatTensor(1, self.model.n_category).zero_()
     x_bl = Variable(x_bl, requires_grad=True).to(self.model.dvc)
     optimizer = RMSprop([x_bl], lr=1e-2, alpha=0.9, eps=1e-10)
     loss_data = 1.0
     epoch = 0
     print('\nFind baseline ...')
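     # Optimise the baseline input with RMSprop until the RMS of the model's
     # output is (almost) zero, giving up after 5000 steps.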
     while loss_data > 1e-6 and epoch <= 5e3:
         epoch += 1
         optimizer.zero_grad()
         output = self.model.forward(x_bl)
         loss = torch.sqrt(torch.mean((output)**2))
         loss_data = loss.data.cpu().numpy()
         loss.backward()
         optimizer.step()
         msg = " | Epoch: {}, Loss = {:.4f}".format(epoch, loss_data)
         sys.stdout.write('\r' + msg)
         sys.stdout.flush()
     view_info('baseline', x_bl)
     print('\nOutput at baseline: {}'.format(output.data.cpu().numpy()))
     self.x_bl = x_bl.data
Example 2
class MDNSineTrainer:
    def __init__(self, dataset, model):
        self.dataset = dataset
        self.model = model
        self.args = get_args()
        self.loss_func = mdn_loss_func
        self.optimizer = RMSprop(self.model.parameters())

    def train(self):
        dataloader = DataLoader(self.dataset, batch_size=self.args.batch_size)

        for ep in range(self.args.epochs):
            for it, data in enumerate(dataloader):
                x_data, y_data = data
                res = self.model(x_data)

                loss = self.loss_func(*res, y_data)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if it % self.args.log_every == 0:
                    print("ep: %d, it: %d, loss: %.4f" % (ep, it, loss))

                if it % self.args.save_every == 0:
                    torch.save(self.model.state_dict(),
                               '../checkpoints/mdn_model_checkpoint.pt')
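For reference, a minimal sketch of how a trainer like this might be driven; the dataset and model constructors below are placeholders for illustration, not part of the example above:

# Hypothetical usage sketch; SineDataset and MDN are assumed names.
dataset = SineDataset()                    # torch Dataset yielding (x, y) pairs
model = MDN(n_hidden=20, n_gaussians=5)    # assumed mixture density network
trainer = MDNSineTrainer(dataset, model)   # hyperparameters come from get_args()
trainer.train()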
Example 3
class GRUTask(Task):
    def __init__(self, params):
        super().__init__(params)
        self.cel = nn.CrossEntropyLoss(reduction="sum")
        self.optim = RMSprop(self.model.parameters(), lr=self.lr)

    def init(self):
        self.model.init(self.batch_size)

    def perbatch(self, xs, ys, bn=-1, istraining=True):
        batch_loss = 0
        total = 0
        correct = 0
        steps = xs.shape[1]
        batch_size = xs.shape[0]
        self.model.init()
        yp = self.model(xs)
        _, yp_index = torch.topk(yp, 1, dim=2)
        total = batch_size * steps
        yp_index = yp_index.view(yp_index.shape[0], yp_index.shape[1])
        correct = torch.sum(yp_index == ys).item()
        yp = yp.view(-1, 2)
        ys = ys.view(-1)

        batch_loss = self.cel(yp, ys)

        if istraining:
            self.optim.zero_grad()
            batch_loss.backward()
            self.optim.step()
        if self.verbose:
            print("Train batch %d Loss: %f Accuracy: %f" %
                  (bn, batch_loss / total, correct / total))
        return batch_loss, correct, total
Example 4
class Agent:
    def __init__(
        self,
        state_dim,
        action_dim,
        n_agents,
        batch_size=BATCH_SIZE,
        grad_clip=GRADIENT_CLIP,
    ):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.n_agents = n_agents
        self.batch_size = batch_size
        self.model = A2CGaussian(state_dim, action_dim)
        self.optimizer = RMSprop(self.model.parameters(), lr=LR)
        self.last_out = None
        self.state_normalizer = Normalizer()
        self.minibatch = MiniBatch()
        self.grad_clip = grad_clip

    def act(self, states):
        """
        Retrieve the actions from the network and save output.
        """
        self.last_out = self.model(self.state_normalizer(states))
        return self.last_out.actions.numpy()

    def step(self, rewards, dones):
        """
        Step through the environment, doing an update if any of the agents reaches a terminal
        state or if we accumulate `batch_size` steps.
        Returns the loss if an update was made, otherwise None.
        """
        _, l, v, e = self.last_out
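        # A minimal sketch of the calling loop this method assumes; the env API
        # and variable names below are hypothetical, for illustration only:
        #
        #   agent = Agent(state_dim, action_dim, n_agents)
        #   states = env.reset()
        #   while True:
        #       actions = agent.act(states)
        #       states, rewards, dones, _ = env.step(actions)
        #       loss = agent.step(rewards, dones)  # None until an update fires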
        if len(self.minibatch) == self.batch_size or (any(dones) and
                                                      len(self.minibatch) > 0):
            # We use v as v_next, and ignore other variables
            return self._learn(v)
        else:
            sample = Sample(rewards=rewards,
                            dones=dones,
                            log_probs=l,
                            v=v,
                            entropy=e)
            self.minibatch.append(sample)
            return None

    def _learn(self, v_next):
        """
        Do a network update, clipping gradients for stability.
        """
        loss = self.minibatch.compute_loss(v_next)
        self.optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(self.model.parameters(), max_norm=self.grad_clip)
        self.optimizer.step()
        return loss.item()
Example 5
 def get_reward(self, action, model_name):
     model = DenseModel(self.num_input, self.num_classes,
                        action).to(self.device)
     if len(self.reg_param) == 3:
         loss_func = TiltedLoss(self.reg_param[0])
     else:
         loss_func = nn.MSELoss()
     optimizer = RMSprop(model.parameters(), lr=self.learning_rate)
     all_mean_val_loss = []
     all_mean_tr_loss = []
     for epoch in range(200):
         model.train()
         model.is_training = True
         mean_tr_loss = []
         for step, (tx, ty) in enumerate(self.train_loader):
             tx = tx.view(-1, self.num_input).to(self.device)
             optimizer.zero_grad()
             to = model(tx)
             loss = loss_func(to, ty.to(self.device))
             loss.backward()
             optimizer.step()
             mean_tr_loss.append(loss.cpu().detach().numpy())
         all_mean_tr_loss.append(np.mean(mean_tr_loss))
     print('train_loss:', -np.mean(all_mean_tr_loss))
     dummy_input = torch.randn(1, 22)
     torch.onnx.export(model, dummy_input,
                       '../saved_model/' + model_name + '.onnx')
     model.eval()
     model.is_training = False
     mean_val_loss = []
     for step, (tx, ty) in enumerate(self.test_loader):
         tx = tx.view(-1, self.num_input).to(self.device)
         to = model(tx)
         l = loss_func(to, ty.to(self.device)).cpu().detach().numpy()
         mean_val_loss.append(l)
     all_mean_val_loss.append(mean_val_loss)
     to = model(self.X_test)
     if len(self.reg_param) == 3:
         q_score, truth_label, colors = self.reg_param
         for idx, (q, t,
                   c) in enumerate(list(zip(q_score, truth_label, colors))):
             plt.plot(self.xvals,
                      to.detach().numpy()[:, idx],
                      label=q,
                      color=c)
             plt.plot(self.xvals, t, label=q, color=c)
     else:
         plt.plot(
             self.xvals,
             to.detach().numpy(),
         )
         plt.plot(self.xvals, self.reg_param)
     plt.legend()
     plt.savefig('../saved_model_graph/' + model_name + '.png')
     plt.show()
     return -np.mean(all_mean_val_loss)
Example 6
def main():
    args = vars(parser.parse_args())
    agent_config = configs.get_agent_config(args)
    game_config = configs.get_game_config(args)
    training_config = configs.get_training_config(args)
    print("Training with config:")
    print(training_config)
    print(game_config)
    print(agent_config)
    agent = AgentModule(agent_config)
    if training_config.use_cuda:
        agent.cuda()
    optimizer = RMSprop(agent.parameters(), lr=training_config.learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, cooldown=5)
    losses = defaultdict(lambda: defaultdict(list))
    dists = defaultdict(lambda: defaultdict(list))
    for epoch in range(training_config.num_epochs):
        num_agents = np.random.randint(game_config.min_agents,
                                       game_config.max_agents + 1)
        num_landmarks = np.random.randint(game_config.min_landmarks,
                                          game_config.max_landmarks + 1)
        agent.reset()
        game = GameModule(game_config, num_agents, num_landmarks)
        if training_config.use_cuda:
            game.cuda()
        optimizer.zero_grad()

        total_loss, _ = agent(game)
        per_agent_loss = total_loss.data[
            0] / num_agents / game_config.batch_size
        losses[num_agents][num_landmarks].append(per_agent_loss)

        dist = game.get_avg_agent_to_goal_distance()
        avg_dist = dist.data[0] / num_agents / game_config.batch_size
        dists[num_agents][num_landmarks].append(avg_dist)

        print_losses(epoch, losses, dists, game_config)

        total_loss.backward()
        optimizer.step()

        if num_agents == game_config.max_agents and num_landmarks == game_config.max_landmarks:
            scheduler.step(
                losses[game_config.max_agents][game_config.max_landmarks][-1])

    if training_config.save_model:
        torch.save(agent, training_config.save_model_file)
        print("Saved agent model weights at %s" %
              training_config.save_model_file)
Example 7
class LSGAN(object):
    def __init__(self, batch_size, adopt_gas=False):
        self.batch_size = batch_size
        self.generator = Generator(batch_size=self.batch_size, base_filter=32)
        self.discriminator = Discriminator(batch_size=self.batch_size,
                                           base_filter=32,
                                           adopt_gas=adopt_gas)
        self.generator.cuda()
        self.discriminator.cuda()
        self.gen_optimizer = RMSprop(self.generator.parameters())
        self.dis_optimizer = RMSprop(self.discriminator.parameters())

    def train(self, epoch, loader):
        self.generator.train()
        self.discriminator.train()
        self.gen_loss_sum = 0.0
        self.dis_loss_sum = 0.0
        for i, (batch_img, batch_tag) in enumerate(loader):
            # Get logits
            batch_img = Variable(batch_img.cuda())
            batch_z = Variable(torch.randn(self.batch_size, 100).cuda())
            self.gen_image = self.generator(batch_z)
            true_logits = self.discriminator(batch_img)
            fake_logits = self.discriminator(self.gen_image)

            # Get loss
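            # Least-squares GAN objectives: the discriminator pushes real logits
            # towards 1 and fake logits towards 0; the generator pushes fake
            # logits towards 1.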
            self.dis_loss = torch.sum((true_logits - 1)**2 + (fake_logits)**2) / 2
            self.gen_loss = torch.sum((fake_logits - 1)**2) / 2

            # Update
            self.dis_optimizer.zero_grad()
            self.dis_loss.backward(retain_graph=True)
            self.dis_loss_sum += self.dis_loss.data.cpu().numpy()[0]
            self.dis_optimizer.step()
            if i % 5 == 0:
                self.gen_optimizer.zero_grad()
                self.gen_loss.backward()
                self.gen_loss_sum += self.gen_loss.data.cpu().numpy()[0]
                self.gen_optimizer.step()

            if i > 300:
                break

    def eval(self):
        self.generator.eval()
        batch_z = Variable(torch.randn(32, 100).cuda())
        return self.generator(batch_z)
Example 8
    def _get_input_for_category(self):

        self._n = self.model.train_Y.shape[1]
        images = []
        self._loss = 0
        for c in range(self._n):
            if type(self.input_dim) == int:
                random_image = np.random.uniform(0, 1, (self.input_dim, ))
            else:
                random_image = np.random.uniform(
                    0, 1,
                    (self.input_dim[1], self.input_dim[2], self.input_dim[0]))
            processed_image = preprocess_image(random_image, self.ImageNet)
            optimizer = RMSprop([processed_image],
                                lr=1e-2,
                                alpha=0.9,
                                eps=1e-10)
            label = np.zeros((self._n, ), dtype=np.float32)
            label[c] = 1
            label = torch.from_numpy(label)
            for i in range(self.epoch):
                optimizer.zero_grad()
                output = self.model.forward(processed_image)
                loss = self.model.L(output, label)
                loss.backward()
                _loss = loss.item()
                optimizer.step()
                _msg = "Visual feature for Category {}/{}".format(
                    c + 1, self._n)
                _str = _msg + " | Epoch: {}/{}, Loss = {:.4f}".format(
                    i + 1, self.epoch, _loss)
                sys.stdout.write('\r' + _str)
                sys.stdout.flush()
            self._loss += _loss
            self.model.zero_grad()
            processed_image.requires_grad_(False)
            created_image = recreate_image(processed_image, self.ImageNet,
                                           self.reshape)
            images.append(created_image)
        self._loss /= self._n
        self._layer_name = 'output'
        self._save(images)
Example 9
def main(args):
    trainloader, testloader = get_sequential_mnist(args.dataroot,
                                                   args.batch_size)
    model = RnnModel1().to(args.device)
    opt = RMSprop(model.parameters(), lr=args.lr)

    print("Model parameters:", count_parameters(model))
    for i, (x, y) in tqdm(zip(range(args.iterations), loop_iter(trainloader)),
                          total=args.iterations):
        if i % args.test_interval == 0:
            test_acc = evaluate(model, testloader, args.device)
            print(f"\niter={i:5d} test_acc={test_acc:.4f}")

        model.train()

        x, y = x.to(args.device), y.to(args.device)
        pred = model(x)
        loss = F.cross_entropy(pred, y)

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        opt.step()
Example 10
class DQNModel(Model):
    def __init__(self,
                 name,
                 network_config,
                 restore=True,
                 learning_rate=0.001):
        logger.info("Building network for %s" % name)
        self.network_config = network_config
        model = _DQNModel(network_config)
        Model.__init__(self, model, name, network_config, restore)
        logger.info("Created network for %s " % self.name)
        self.optimizer = RMSprop(self.model.parameters(), lr=learning_rate)
        self.loss_fn = nn.MSELoss()

    def fit(self, states, target, steps):
        self.optimizer.zero_grad()
        predict = self.model(states)
        loss = self.loss_fn(predict, target)
        loss.backward()
        self.optimizer.step()

    def predict(self, input):
        return self.model(input)
Example 11
    def train(self, model, train_loader, valid_loader):
        criterion = BCELoss()  # binary cross-entropy
        optimizer = RMSprop(model.parameters(), lr=self.config.learning_rate)
        early_stopping = EarlyStopping(
            patience=self.config.early_stopping_patience)

        epochs_finished = 0
        for _ in range(self.config.epochs):

            model.train()
            for data, target in train_loader:
                optimizer.zero_grad()

                output = model(data)

                loss = criterion(output, target)
                loss.backward()

                optimizer.step()

            model.eval()
            valid_losses = []
            for data, target in valid_loader:
                output = model(data)
                loss = criterion(output, target)
                valid_losses.append(loss.item())
            valid_loss = np.average(valid_losses)

            epochs_finished += 1

            if early_stopping.should_early_stop(valid_loss, model):
                break

        model.load_state_dict(early_stopping.best_model_state)

        return model, epochs_finished
Example 12
class QLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = NoiseQMixer(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        discrim_input = np.prod(
            self.args.state_shape) + self.args.n_agents * self.args.n_actions

        if self.args.rnn_discrim:
            self.rnn_agg = RNNAggregator(discrim_input, args)
            self.discrim = Discrim(args.rnn_agg_size, self.args.noise_dim,
                                   args)
            self.params += list(self.discrim.parameters())
            self.params += list(self.rnn_agg.parameters())
        else:
            self.discrim = Discrim(discrim_input, self.args.noise_dim, args)
            self.params += list(self.discrim.parameters())
        self.discrim_loss = th.nn.CrossEntropyLoss(reduction="none")

        self.optimiser = RMSprop(params=self.params,
                                 lr=args.lr,
                                 alpha=args.optim_alpha,
                                 eps=args.optim_eps)

        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]
        noise = batch["noise"][:,
                               0].unsqueeze(1).repeat(1, rewards.shape[1], 1)

        # Calculate estimated Q-Values
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3,
                                        index=actions).squeeze(
                                            3)  # Remove the last dim

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)

        # We don't need the first timestep's Q-Value estimate for calculating targets
        target_mac_out = th.stack(target_mac_out[1:],
                                  dim=1)  # Concat across time

        # Mask out unavailable actions
        #target_mac_out[avail_actions[:, 1:] == 0] = -9999999  # From OG deepmarl

        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
            #mac_out[avail_actions == 0] = -9999999
            cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3,
                                         cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Mix
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals,
                                             batch["state"][:, :-1], noise)
            target_max_qvals = self.target_mixer(target_max_qvals,
                                                 batch["state"][:, 1:], noise)

        # Discriminator
        #mac_out[avail_actions == 0] = -9999999
        q_softmax_actions = th.nn.functional.softmax(mac_out[:, :-1], dim=3)

        if self.args.hard_qs:
            maxs = th.max(mac_out[:, :-1], dim=3, keepdim=True)[1]
            zeros = th.zeros_like(q_softmax_actions)
            zeros.scatter_(dim=3, index=maxs, value=1)
            q_softmax_actions = zeros

        q_softmax_agents = q_softmax_actions.reshape(
            q_softmax_actions.shape[0], q_softmax_actions.shape[1], -1)

        states = batch["state"][:, :-1]
        state_and_softactions = th.cat([q_softmax_agents, states], dim=2)

        if self.args.rnn_discrim:
            h_to_use = th.zeros(size=(batch.batch_size,
                                      self.args.rnn_agg_size)).to(
                                          states.device)
            hs = th.ones_like(h_to_use)
            for t in range(batch.max_seq_length - 1):
                hs = self.rnn_agg(state_and_softactions[:, t], hs)
                for b in range(batch.batch_size):
                    if t == batch.max_seq_length - 2 or (mask[b, t] == 1 and
                                                         mask[b, t + 1] == 0):
                        # This is the last timestep of the sequence
                        h_to_use[b] = hs[b]
            s_and_softa_reshaped = h_to_use
        else:
            s_and_softa_reshaped = state_and_softactions.reshape(
                -1, state_and_softactions.shape[-1])

        if self.args.mi_intrinsic:
            s_and_softa_reshaped = s_and_softa_reshaped.detach()

        discrim_prediction = self.discrim(s_and_softa_reshaped)

        # Cross-Entropy
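        # The discriminator target is the index of the episode's noise vector
        # (assumed to be one-hot), repeated over timesteps when the per-step
        # discriminator is used, so the discriminator learns to recover which
        # noise was active from the states and soft actions.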
        target_repeats = 1
        if not self.args.rnn_discrim:
            target_repeats = q_softmax_actions.shape[1]
        discrim_target = batch["noise"][:, 0].long().detach().max(
            dim=1)[1].unsqueeze(1).repeat(1, target_repeats).reshape(-1)
        discrim_loss = self.discrim_loss(discrim_prediction, discrim_target)

        if self.args.rnn_discrim:
            averaged_discrim_loss = discrim_loss.mean()
        else:
            masked_discrim_loss = discrim_loss * mask.reshape(-1)
            averaged_discrim_loss = masked_discrim_loss.sum() / mask.sum()
        self.logger.log_stat("discrim_loss", averaged_discrim_loss.item(),
                             t_env)

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.args.gamma * (1 -
                                               terminated) * target_max_qvals
        if self.args.mi_intrinsic:
            assert self.args.rnn_discrim is False
            targets = targets + self.args.mi_scaler * discrim_loss.view_as(
                rewards)

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error**2).sum() / mask.sum()

        loss = loss + self.args.mi_loss * averaged_discrim_loss

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params,
                                                self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs",
                (masked_td_error.abs().sum().item() / mask_elems), t_env)
            self.logger.log_stat("q_taken_mean",
                                 (chosen_action_qvals * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        self.discrim.cuda()
        if self.args.rnn_discrim:
            self.rnn_agg.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                th.load("{}/mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            th.load("{}/opt.th".format(path),
                    map_location=lambda storage, loc: storage))
Example 13
class QLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer == "qtran_base":
            self.mixer = QTranBase(args)
        elif args.mixer == "qtran_alt":
            raise Exception("Not implemented here!")

        self.params += list(self.mixer.parameters())
        self.target_mixer = copy.deepcopy(self.mixer)

        self.optimiser = RMSprop(params=self.params,
                                 lr=args.lr,
                                 alpha=args.optim_alpha,
                                 eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

    def train(self,
              batch: EpisodeBatch,
              t_env: int,
              episode_num: int,
              show_demo=False,
              save_data=None):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        mac_out = []
        mac_hidden_states = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
            mac_hidden_states.append(self.mac.hidden_states)
        mac_out = torch.stack(mac_out, dim=1)  # Concat over time
        mac_hidden_states = torch.stack(mac_hidden_states, dim=1)
        mac_hidden_states = mac_hidden_states.reshape(batch.batch_size,
                                                      self.args.n_agents,
                                                      batch.max_seq_length,
                                                      -1).transpose(1,
                                                                    2)  #btav

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = torch.gather(mac_out[:, :-1],
                                           dim=3,
                                           index=actions).squeeze(
                                               3)  # Remove the last dim

        x_mac_out = mac_out.clone().detach()
        x_mac_out[avail_actions == 0] = -9999999
        max_action_qvals, max_action_index = x_mac_out[:, :-1].max(dim=3)

        max_action_index = max_action_index.detach().unsqueeze(3)
        is_max_action = (max_action_index == actions).int().float()

        if show_demo:
            q_i_data = chosen_action_qvals.detach().cpu().numpy()
            q_data = (max_action_qvals -
                      chosen_action_qvals).detach().cpu().numpy()

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        target_mac_hidden_states = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)
            target_mac_hidden_states.append(self.target_mac.hidden_states)

        # We don't need the first timestep's Q-Value estimate for calculating targets
        target_mac_out = torch.stack(target_mac_out[:],
                                     dim=1)  # Concat across time
        target_mac_hidden_states = torch.stack(target_mac_hidden_states, dim=1)
        target_mac_hidden_states = target_mac_hidden_states.reshape(
            batch.batch_size, self.args.n_agents, batch.max_seq_length,
            -1).transpose(1, 2)  #btav

        # Mask out unavailable actions
        target_mac_out[avail_actions[:, :] == 0] = -9999999  # From OG deepmarl
        mac_out_maxs = mac_out.clone()
        mac_out_maxs[avail_actions == 0] = -9999999

        # Best joint action computed by target agents
        target_max_actions = target_mac_out.max(dim=3, keepdim=True)[1]
        # Best joint-action computed by regular agents
        max_actions_qvals, max_actions_current = mac_out_maxs[:, :].max(
            dim=3, keepdim=True)

        if self.args.mixer == "qtran_base":
            # -- TD Loss --
            # Joint-action Q-Value estimates
            joint_qs, vs = self.mixer(batch[:, :-1], mac_hidden_states[:, :-1])

            # Need to argmax across the target agents' actions to compute target joint-action Q-Values
            if self.args.double_q:
                max_actions_current_ = torch.zeros(
                    size=(batch.batch_size, batch.max_seq_length,
                          self.args.n_agents, self.args.n_actions),
                    device=batch.device)
                max_actions_current_onehot = max_actions_current_.scatter(
                    3, max_actions_current[:, :], 1)
                max_actions_onehot = max_actions_current_onehot
            else:
                max_actions = torch.zeros(
                    size=(batch.batch_size, batch.max_seq_length,
                          self.args.n_agents, self.args.n_actions),
                    device=batch.device)
                max_actions_onehot = max_actions.scatter(
                    3, target_max_actions[:, :], 1)
            target_joint_qs, target_vs = self.target_mixer(
                batch[:, 1:],
                hidden_states=target_mac_hidden_states[:, 1:],
                actions=max_actions_onehot[:, 1:])

            # Td loss targets
            td_targets = rewards.reshape(-1, 1) + self.args.gamma * (
                1 - terminated.reshape(-1, 1)) * target_joint_qs
            td_error = (joint_qs - td_targets.detach())
            masked_td_error = td_error * mask.reshape(-1, 1)
            td_loss = (masked_td_error**2).sum() / mask.sum()
            # -- TD Loss --

            # -- Opt Loss --
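            # QTRAN consistency constraint: drive sum_i max_a Q_i + V(s) towards
            # the (detached) joint Q-value of the greedy joint action.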
            # Argmax across the current agents' actions
            if not self.args.double_q:  # Already computed if we're doing double Q-Learning
                max_actions_current_ = torch.zeros(
                    size=(batch.batch_size, batch.max_seq_length,
                          self.args.n_agents, self.args.n_actions),
                    device=batch.device)
                max_actions_current_onehot = max_actions_current_.scatter(
                    3, max_actions_current[:, :], 1)
            max_joint_qs, _ = self.mixer(
                batch[:, :-1],
                mac_hidden_states[:, :-1],
                actions=max_actions_current_onehot[:, :-1]
            )  # Don't use the target network and target agent max actions as per author's email

            # max_actions_qvals = torch.gather(mac_out[:, :-1], dim=3, index=max_actions_current[:,:-1])
            opt_error = max_actions_qvals[:, :-1].sum(dim=2).reshape(
                -1, 1) - max_joint_qs.detach() + vs
            masked_opt_error = opt_error * mask.reshape(-1, 1)
            opt_loss = (masked_opt_error**2).sum() / mask.sum()
            # -- Opt Loss --

            # -- Nopt Loss --
            # target_joint_qs, _ = self.target_mixer(batch[:, :-1])
            nopt_values = chosen_action_qvals.sum(dim=2).reshape(
                -1, 1) - joint_qs.detach(
                ) + vs  # Don't use target networks here either
            nopt_error = nopt_values.clamp(max=0)
            masked_nopt_error = nopt_error * mask.reshape(-1, 1)
            nopt_loss = (masked_nopt_error**2).sum() / mask.sum()
            # -- Nopt loss --

        elif self.args.mixer == "qtran_alt":
            raise Exception("Not supported yet.")

        if show_demo:
            tot_q_data = joint_qs.detach().cpu().numpy()
            tot_target = td_targets.detach().cpu().numpy()
            bs = q_data.shape[0]
            tot_q_data = tot_q_data.reshape(bs, -1)
            tot_target = tot_target.reshape(bs, -1)
            print('action_pair_%d_%d' % (save_data[0], save_data[1]),
                  np.squeeze(q_data[:, 0]), np.squeeze(q_i_data[:, 0]),
                  np.squeeze(tot_q_data[:, 0]), np.squeeze(tot_target[:, 0]))
            self.logger.log_stat(
                'action_pair_%d_%d' % (save_data[0], save_data[1]),
                np.squeeze(tot_q_data[:, 0]), t_env)
            return

        loss = td_loss + self.args.opt_loss * opt_loss + self.args.nopt_min_loss * nopt_loss

        masked_hit_prob = torch.mean(is_max_action, dim=2) * mask
        hit_prob = masked_hit_prob.sum() / mask.sum()

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.params,
                                                   self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("hit_prob", hit_prob.item(), t_env)
            self.logger.log_stat("td_loss", td_loss.item(), t_env)
            self.logger.log_stat("opt_loss", opt_loss.item(), t_env)
            self.logger.log_stat("nopt_loss", nopt_loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            if self.args.mixer == "qtran_base":
                mask_elems = mask.sum().item()
                self.logger.log_stat(
                    "td_error_abs",
                    (masked_td_error.abs().sum().item() / mask_elems), t_env)
                self.logger.log_stat(
                    "td_targets",
                    ((masked_td_error).sum().item() / mask_elems), t_env)
                self.logger.log_stat("td_chosen_qs",
                                     (joint_qs.sum().item() / mask_elems),
                                     t_env)
                self.logger.log_stat("v_mean", (vs.sum().item() / mask_elems),
                                     t_env)
                self.logger.log_stat(
                    "agent_indiv_qs",
                    ((chosen_action_qvals * mask).sum().item() /
                     (mask_elems * self.args.n_agents)), t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            torch.save(self.mixer.state_dict(), "{}/mixer.torch".format(path))
        torch.save(self.optimiser.state_dict(), "{}/opt.torch".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                torch.load("{}/mixer.torch".format(path),
                           map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            torch.load("{}/opt.torch".format(path),
                       map_location=lambda storage, loc: storage))
Example 14
class QMixTorchPolicy(Policy):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
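
    A rough sketch of the grouping setup this refers to; the env and space
    names below are assumptions, not part of this policy:

        grouping = {"group_1": ["agent_0", "agent_1"]}
        obs_space = Tuple([agent_obs_space, agent_obs_space])
        act_space = Tuple([agent_act_space, agent_act_space])
        env = MyMultiAgentEnv().with_agent_groups(
            grouping, obs_space=obs_space, act_space=act_space)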
    """

    def __init__(self, obs_space, action_space, config):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.framework = "torch"
        super().__init__(obs_space, action_space, config)
        self.n_agents = len(obs_space.original_space.spaces)
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]
        self.has_env_global_state = False
        self.has_action_mask = False
        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if "obs" not in space_keys:
                raise ValueError(
                    "Dict obs space must have subspace labeled `obs`")
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            if "action_mask" in space_keys:
                mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
                if mask_shape != (self.n_actions, ):
                    raise ValueError(
                        "Action mask shape must be {}, got {}".format(
                            (self.n_actions, ), mask_shape))
                self.has_action_mask = True
            if ENV_STATE in space_keys:
                self.env_global_state_shape = _get_size(
                    agent_obs_space.spaces[ENV_STATE])
                self.has_env_global_state = True
            else:
                self.env_global_state_shape = (self.obs_size, self.n_agents)
            # The real agent obs space is nested inside the dict
            config["model"]["full_obs_space"] = agent_obs_space
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.obs_size = _get_size(agent_obs_space)

        self.model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="model",
            default_model=RNNModel).to(self.device)

        self.target_model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="target_model",
            default_model=RNNModel).to(self.device)

        self.exploration = self._create_exploration()

        # Setup the mixer network.
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.env_global_state_shape,
                                config["mixing_embed_dim"]).to(self.device)
            self.target_mixer = QMixer(
                self.n_agents, self.env_global_state_shape,
                config["mixing_embed_dim"]).to(self.device)
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer().to(self.device)
            self.target_mixer = VDNMixer().to(self.device)
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        if self.mixer:
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        **kwargs):
        explore = explore if explore is not None else self.config["explore"]
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # We need to ensure we do not use the env global state
        # to compute actions

        # Compute actions
        with torch.no_grad():
            q_values, hiddens = _mac(
                self.model,
                torch.as_tensor(
                    obs_batch, dtype=torch.float, device=self.device), [
                        torch.as_tensor(
                            np.array(s), dtype=torch.float, device=self.device)
                        for s in state_batches
                    ])
            avail = torch.as_tensor(
                action_mask, dtype=torch.float, device=self.device)
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # epsilon-greedy action selector
            random_numbers = torch.rand_like(q_values[:, :, 0])
            pick_random = (random_numbers < (self.cur_epsilon
                                             if explore else 0.0)).long()
            random_actions = Categorical(avail).sample().long()
            actions = (pick_random * random_actions +
                       (1 - pick_random) * masked_q_values.argmax(dim=2))
            actions = actions.cpu().numpy()
            hiddens = [s.cpu().numpy() for s in hiddens]

        return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

    @override(Policy)
    def compute_log_likelihoods(self,
                                actions,
                                obs_batch,
                                state_batches=None,
                                prev_action_batch=None,
                                prev_reward_batch=None):
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        return np.zeros(obs_batch.size()[0])

    @override(Policy)
    def learn_on_batch(self, samples):
        obs_batch, action_mask, env_global_state = self._unpack_observation(
            samples[SampleBatch.CUR_OBS])
        (next_obs_batch, next_action_mask,
         next_env_global_state) = self._unpack_observation(
             samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        input_list = [
            group_rewards, action_mask, next_action_mask,
            samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
            obs_batch, next_obs_batch
        ]
        if self.has_env_global_state:
            input_list.extend([env_global_state, next_env_global_state])

        output_list, _, seq_lens = \
            chop_into_sequences(
                samples[SampleBatch.EPS_ID],
                samples[SampleBatch.UNROLL_ID],
                samples[SampleBatch.AGENT_INDEX],
                input_list,
                [],  # RNN states not used here
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True)
        # These will be padded to shape [B * T, ...]
        if self.has_env_global_state:
            (rew, action_mask, next_action_mask, act, dones, obs, next_obs,
             env_global_state, next_env_global_state) = output_list
        else:
            (rew, action_mask, next_action_mask, act, dones, obs,
             next_obs) = output_list
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr, dtype):
            new_shape = [B, T] + list(arr.shape[1:])
            return torch.as_tensor(
                np.reshape(arr, new_shape), dtype=dtype, device=self.device)

        rewards = to_batches(rew, torch.float)
        actions = to_batches(act, torch.long)
        obs = to_batches(obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        action_mask = to_batches(action_mask, torch.float)
        next_obs = to_batches(next_obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        next_action_mask = to_batches(next_action_mask, torch.float)
        if self.has_env_global_state:
            env_global_state = to_batches(env_global_state, torch.float)
            next_env_global_state = to_batches(next_env_global_state,
                                               torch.float)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones, torch.float).unsqueeze(2).expand(
            B, T, self.n_agents)

        # Create mask for where index is < unpadded sequence length
        filled = np.reshape(
            np.tile(np.arange(T, dtype=np.float32), B),
            [B, T]) < np.expand_dims(seq_lens, 1)
        mask = torch.as_tensor(
            filled, dtype=torch.float, device=self.device).unsqueeze(2).expand(
                B, T, self.n_agents)

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = (
            self.loss(rewards, actions, terminated, mask, obs, next_obs,
                      action_mask, next_action_mask, env_global_state,
                      next_env_global_state))

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {LEARNER_STATS_KEY: stats}

    @override(Policy)
    def get_initial_state(self):  # initial RNN state
        return [
            s.expand([self.n_agents, -1]).cpu().numpy()
            for s in self.model.get_initial_state()
        ]

    @override(Policy)
    def get_weights(self):
        return {
            "model": self._cpu_dict(self.model.state_dict()),
            "target_model": self._cpu_dict(self.target_model.state_dict()),
            "mixer": self._cpu_dict(self.mixer.state_dict())
            if self.mixer else None,
            "target_mixer": self._cpu_dict(self.target_mixer.state_dict())
            if self.mixer else None,
        }

    @override(Policy)
    def set_weights(self, weights):
        self.model.load_state_dict(self._device_dict(weights["model"]))
        self.target_model.load_state_dict(
            self._device_dict(weights["target_model"]))
        if weights["mixer"] is not None:
            self.mixer.load_state_dict(self._device_dict(weights["mixer"]))
            self.target_mixer.load_state_dict(
                self._device_dict(weights["target_mixer"]))

    @override(Policy)
    def get_state(self):
        state = self.get_weights()
        state["cur_epsilon"] = self.cur_epsilon
        return state

    @override(Policy)
    def set_state(self, state):
        self.set_weights(state)
        self.set_epsilon(state["cur_epsilon"])

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _device_dict(self, state_dict):
        return {
            k: torch.as_tensor(v, device=self.device)
            for k, v in state_dict.items()
        }

    @staticmethod
    def _cpu_dict(state_dict):
        return {k: v.cpu().detach().numpy() for k, v in state_dict.items()}

    def _unpack_observation(self, obs_batch):
        """Unpacks the observation, action mask, and state (if present)
        from agent grouping.

        Returns:
            obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size]
            mask (np.ndarray): action mask, if any
            state (np.ndarray or None): state tensor of shape [B, state_size]
                or None if it is not in the batch
        """
        unpacked = _unpack_obs(
            np.array(obs_batch, dtype=np.float32),
            self.observation_space.original_space,
            tensorlib=np)
        if self.has_action_mask:
            obs = np.concatenate(
                [o["obs"] for o in unpacked],
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            if isinstance(unpacked[0], dict):
                unpacked_obs = [u["obs"] for u in unpacked]
            else:
                unpacked_obs = unpacked
            obs = np.concatenate(
                unpacked_obs,
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions],
                dtype=np.float32)

        if self.has_env_global_state:
            state = unpacked[0][ENV_STATE]
        else:
            state = None
        return obs, action_mask, state
Example 15
class Trainer(object):
    def __init__(self,
                 netG,
                 netD,
                 train_loader,
                 test_loader=None,
                 cv_loader=None,
                 gpus=()):
        self.use_gpus = torch.cuda.is_available() & (len(gpus) > 0)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.cv_loader = cv_loader
        self.netG = netG
        self.netD = netD
        self.count = 0
        self.steps = len(train_loader)
        self.losses = {
            'G': [],
            'D': [],
            'GP': [],
            "SGP": [],
            'WD': [],
            'GN': [],
            'JC': []
        }
        self.valid_losses = {'G': [], 'D': [], 'WD': [], 'JC': []}
        self.writer = SummaryWriter(log_dir="log")
        self.lr = 2e-3
        self.lr_decay = 0.94
        self.weight_decay = 2e-5
        self.nepochs = 100
        self.opt_g = Adam(self.netG.parameters(),
                          lr=self.lr,
                          betas=(0.9, 0.99),
                          weight_decay=self.weight_decay)
        self.opt_d = RMSprop(self.netD.parameters(),
                             lr=self.lr,
                             weight_decay=self.weight_decay)

        input = torch.autograd.Variable(torch.Tensor(4, 1, 256, 256),
                                        requires_grad=True)
        net_cpu = copy.deepcopy(self.netG)
        # out = net_cpu.cpu()(input)
        self.writer.add_graph(model=net_cpu.cpu(), input_to_model=input)
        # self.writer.close()
        # exit(1)

    def _dis_train_iteration(self, input, reals):
        """Train the discriminator for one step.

        :param input: [?, 1, 256, 256]
        :param reals: [?, 3, 256, 256]
        :return:
        """
        self.opt_d.zero_grad()
        fakes = self.netG(input)  # [?, 3, 256, 256]
        d_fakes = self.comput_d_input(input, fakes, result="batch")
        d_reals = self.comput_d_input(input, reals, result="batch")

        gp = Branch_gradPenalty(self.netD, reals, fakes, input=input)
        sgp = spgradPenalty(self.netD, input, reals, fakes, type="G") * 0.5
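        # WGAN-style critic objective: maximise E[D(real)] - E[D(fake)]
        # (i.e. minimise its negative), averaged with the two gradient penalties.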

        w_distance = -(d_fakes - d_reals).mean()

        loss_d = (-w_distance + gp.mean() + sgp.mean()) / 3
        # w_distance = (d_real.mean() - d_fake.mean()).detach()
        loss_d.backward()
        self.opt_d.step()

        self.losses["D"].append(loss_d.detach())
        self.losses["WD"].append(w_distance.detach())
        self.losses["GP"].append(gp.detach())
        self.losses["SGP"].append(sgp.detach())
        self._watchLoss(["D", "GP", "WD", "SGP"],
                        loss_dic=self.losses,
                        type="Train")
        d_log = "Loss_D: {:.4f}".format(loss_d.detach())
        return d_log

    def _gen_train_iteration(self, input):
        self.opt_g.zero_grad()
        fakes = self.netG(input)
        fake_NN, fake_NN_NBG_SR, fake_GAUSSIAN = self.split_dic(fakes)
        d_fake_NN, d_fake_NN_NBG_SR, d_fake_GAUSSIAN = self.comput_d_input(
            input, fake_NN, fake_NN_NBG_SR, fake_GAUSSIAN)

        loss_g = (-d_fake_NN.mean() - d_fake_NN_NBG_SR.mean() -
                  d_fake_GAUSSIAN.mean()) / 3

        self.losses["JC"].append(0)
        self.losses["G"].append(loss_g.detach())
        self._watchLoss(["JC"], loss_dic=self.losses, type="Train")
        self._watchLoss(["G"], loss_dic=self.losses, type="Train")
        g_log = "Loss_G: {:.4f}".format(loss_g.detach())

        loss_g.backward()
        self.opt_g.step()
        return g_log

    def _train_epoch(self, input, reals):
        epoch = self.epoch
        for iteration, batch in enumerate(self.train_loader, 1):
            timer = ElapsedTimer()
            self.count += 1

            real_a_cpu, real_b_cpu = batch[0], batch[1]
            input.data.resize_(real_a_cpu.size()).copy_(
                real_a_cpu)  # input data
            reals.data.resize_(real_b_cpu.size()).copy_(
                real_b_cpu
            )  # reals data list[torch(?,1,256,256),torch(?,1,256,256),torch(?,1,256,256)]

            d_log = self._dis_train_iteration(input, reals)
            g_log = self._gen_train_iteration(input)

            print("===> Epoch[{}]({}/{}): {}\t{} ".format(
                epoch, iteration, self.steps, d_log, g_log))
            one_step_cost = time.time() - timer.start_time
            left_time_one_epoch = timer.elapsed(
                (self.steps - iteration) * one_step_cost)
            print("leftTime: %s" % left_time_one_epoch)

            if iteration == 1:
                fake_NN, fake_NN_NBG_SR, fake_GAUSSIAN = self.split_dic(
                    self.netG(input))
                self._watchImg(input,
                               fake_NN,
                               reals,
                               type="Train",
                               name="in-NN-real")
                self._watchImg(input,
                               fake_NN_NBG_SR,
                               reals,
                               type="Train",
                               name="in-NN_NBG_SR-real")
                self._watchImg(input,
                               fake_GAUSSIAN,
                               reals,
                               type="Train",
                               name="in-GAUSSIAN-real")

    def train(self):
        input = Variable(torch.Tensor())
        real = Variable(torch.Tensor())
        if self.use_gpus:
            input = input.cuda()
            real = real.cuda()
        startEpoch = 1
        # netG, netD = loadCheckPoint(netG, netD, startEpoch)
        for epoch in range(startEpoch, self.nepochs + 1):
            self.epoch = epoch
            timer = ElapsedTimer()
            self._train_epoch(input, real)
            # self.valid()
            self._watchNetParams(self.netG, epoch)
            left_time = timer.elapsed(
                (self.nepochs - epoch) * (time.time() - timer.start_time))
            print("leftTime: %s" % left_time)
            if epoch in (10, 20, 40):
                # Decay the learning rate by 10x and rebuild both optimizers.
                self.lr = self.lr / 10
                self.opt_g = Adam(self.netG.parameters(),
                                  lr=self.lr,
                                  betas=(0.9, 0.99),
                                  weight_decay=self.weight_decay)
                self.opt_d = RMSprop(self.netD.parameters(),
                                     lr=self.lr,
                                     weight_decay=self.weight_decay)
                print("change learning rate to %s" % self.lr)
            if epoch % 10 == 0:
                self.predict()
                checkPoint(self.netG, self.netD, epoch)
        self.writer.close()

    def _watchNetParams(self, net, count):
        for name, param in net.named_parameters():
            if "bias" in name:
                continue
            self.writer.add_histogram(name,
                                      param.clone().cpu().data.numpy(),
                                      count,
                                      bins="auto")

    def _watchLoss(self, loss_keys, loss_dic, type="Train"):
        for key in loss_keys:
            self.writer.add_scalars(key, {type: loss_dic[key][-1]}, self.count)

    def _watchImg(self,
                  input,
                  fake,
                  real,
                  type="Train",
                  show_imgs_num=3,
                  name="in-pred-real"):
        out = None
        input_torch = None
        prediction_torch = None
        real_torch = None
        batchSize = input.shape[0]
        show_nums = min(show_imgs_num, batchSize)
        randindex_list = random.sample(list(range(batchSize)), show_nums)
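        # Normalize([-1], [2]) below maps tensors from [-1, 1] back to [0, 1]
        # (i.e. (x - (-1)) / 2) so they display correctly.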
        for randindex in randindex_list:
            input_torch = input[randindex].cpu().detach()
            input_torch = transforms.Normalize([-1], [2])(input_torch)

            prediction_torch = fake[randindex].cpu().detach()
            prediction_torch = transforms.Normalize([-1],
                                                    [2])(prediction_torch)

            real_torch = real[randindex].cpu().detach()
            real_torch = transforms.Normalize([-1], [2])(real_torch)
            out_1 = torch.stack((input_torch, prediction_torch, real_torch))
            if out is None:
                out = out_1
            else:
                out = torch.cat((out_1, out))
        out = make_grid(out, nrow=3)
        self.writer.add_image('%s-%s' % (type, name), out, self.epoch)

        input = transforms.ToPILImage()(input_torch).convert("L")
        prediction = transforms.ToPILImage()(prediction_torch).convert("L")
        real = transforms.ToPILImage()(real_torch).convert("L")

        in_filename = "plots/%s/E%03d_in_.png" % (type, self.epoch)
        real_filename = "plots/%s/E%03d_real_.png" % (type, self.epoch)
        out_filename = "plots/%s/E%03d_out_.png" % (type, self.epoch)
        input.save(in_filename)
        prediction.save(out_filename)
        real.save(real_filename)

    def predict(self):
        for input, real in self.test_loader:
            input = Variable(input)
            real = Variable(real)
            if self.use_gpus:
                input = input.cuda()
                real = real.cuda()
            self.netG.eval()
            fake = self.netG(input).detach()
            self.netG.zero_grad()
            self._watchImg(input, fake, real, type="Test", show_imgs_num=8)

        self.netG.train()

    def comput_d_input(self, input, *fakes, result="split"):
        shape = (input.shape[0], 1, 16, 16)
        if len(fakes) == 1:
            fake_NN, fake_NN_NBG_SR, fake_GAUSSIAN = self.split_dic(fakes[0])
            d_fake_NN = self.netD(fake_NN, input)[:, NN, :, :].view(
                shape[0], 1, shape[2], shape[3])
            d_fake_NN_NBG_SR = self.netD(fake_NN_NBG_SR,
                                         input)[:, NN_NBG_SR, :, :].view(
                                             shape[0], 1, shape[2], shape[3])
            d_fake_GAUSSIAN = self.netD(fake_GAUSSIAN,
                                        input)[:, GAUSSIAN, :, :].view(
                                            shape[0], 1, shape[2], shape[3])
        elif len(fakes) == 3:
            d_fake_NN = self.netD(fakes[NN], input)[:, NN, :, :].view(
                shape[0], 1, shape[2], shape[3])
            d_fake_NN_NBG_SR = self.netD(fakes[NN_NBG_SR],
                                         input)[:, NN_NBG_SR, :, :].view(
                                             shape[0], 1, shape[2], shape[3])
            d_fake_GAUSSIAN = self.netD(fakes[GAUSSIAN],
                                        input)[:, GAUSSIAN, :, :].view(
                                            shape[0], 1, shape[2], shape[3])
        else:
            d_fake_NN, d_fake_NN_NBG_SR, d_fake_GAUSSIAN = None, None, None

        if result == "split":
            return d_fake_NN, d_fake_NN_NBG_SR, d_fake_GAUSSIAN
        elif result == "batch":
            return torch.cat([d_fake_NN, d_fake_NN_NBG_SR, d_fake_GAUSSIAN], 1)

    def split_dic(self, fakes):
        """[?,3,256,256] => [?,1,256,256],[?,1,256,256],[?,1,256,256]

        :param fakes:
        :return:
        """
        shape = fakes.shape
        fake_NN = fakes[:, NN, :, :].view(shape[0], 1, shape[2], shape[3])
        fake_NN_NBG_SR = fakes[:,
                               NN_NBG_SR, :, :].view(shape[0], 1, shape[2],
                                                     shape[3])
        fake_GAUSSIAN = fakes[:, GAUSSIAN, :, :].view(shape[0], 1, shape[2],
                                                      shape[3])
        return fake_NN, fake_NN_NBG_SR, fake_GAUSSIAN

    def valid(self):
        avg_loss_g = 0
        avg_loss_d = 0
        avg_w_distance = 0
        # netG = netG._d
        self.netG.eval()
        self.netD.eval()
        input = Variable(torch.Tensor())
        real = Variable(torch.Tensor())
        if self.use_gpus:
            input = input.cuda()
            real = real.cuda()
        len_test_data = len(self.cv_loader)
        for iteration, batch in enumerate(self.cv_loader, 1):
            input.data.resize_(batch[0].size()).copy_(batch[0])  # input data
            real.data.resize_(batch[1].size()).copy_(batch[1])  # real data
            # compute the generator (G) loss
            fake = self.netG(input).detach()
            d_fake = self.netD(fake, input).detach()
            loss_g = -d_fake.mean()

            # compute the discriminator (D) loss
            d_real = self.netD(real, input).detach()
            gp = gradPenalty(self.netD, real, fake, input=input)
            loss_d = d_fake.mean() - d_real.mean() + gp
            w_distance = d_real.mean() - d_fake.mean()
            # accumulate the running sums
            avg_w_distance += w_distance.detach()
            avg_loss_d += loss_d.detach()
            avg_loss_g += loss_g.detach()
        avg_w_distance = avg_w_distance / len_test_data
        avg_loss_d = avg_loss_d / len_test_data
        avg_loss_g = avg_loss_g / len_test_data
        self.valid_losses["WD"].append(avg_w_distance)
        self.valid_losses["D"].append(avg_loss_d)
        self.valid_losses["G"].append(avg_loss_g)
        # print("===> CV_Loss_D: {:.4f} CV_WD:{:.4f} CV_Loss_G: {:.4f}".format(avg_loss_d, avg_w_distance, avg_loss_g))
        self._watchLoss(["D", "G", "WD"],
                        loss_dic=self.valid_losses,
                        type="Valid")
        self._watchImg(input, fake, real, type="Valid")
        self.netG.train()
        self.netD.train()
        self.netG.zero_grad()
        self.netD.zero_grad()

        return avg_w_distance
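
# The gradPenalty / Branch_gradPenalty / spgradPenalty helpers used above are not
# shown in this example. A minimal sketch of a standard WGAN-GP gradient penalty
# (interpolating between real and fake samples; `lam` is an illustrative weight):
import torch
from torch import autograd

def grad_penalty_sketch(netD, reals, fakes, input=None, lam=10.0):
    # Random interpolation between real and fake samples.
    eps = torch.rand(reals.size(0), 1, 1, 1, device=reals.device)
    x_hat = (eps * reals.detach() + (1.0 - eps) * fakes.detach()).requires_grad_(True)
    d_out = netD(x_hat, input) if input is not None else netD(x_hat)
    # Gradient of the critic output w.r.t. the interpolated samples.
    grads = autograd.grad(outputs=d_out.sum(), inputs=x_hat, create_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    # Penalise deviation of the gradient norm from 1.
    return lam * ((grad_norm - 1.0) ** 2).mean()
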
class ActorCriticLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.logger = logger

        self.mac = mac
        self.agent_params = list(mac.parameters())
        self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps)
        self.params = self.agent_params


        if args.critic_q_fn == "coma":
            self.critic = COMACritic(scheme, args)
        elif args.critic_q_fn == "centralV":
            self.critic = CentralVCritic(scheme, args)
        self.target_critic = copy.deepcopy(self.critic)

        self.critic_params = list(self.critic.parameters())
        self.params += self.critic_params
        self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha,
                                        eps=args.optim_eps)

        self.separate_baseline_critic = False
        if args.critic_q_fn != args.critic_baseline_fn:
            self.separate_baseline_critic = True
            if args.critic_baseline_fn == "coma":
                self.baseline_critic = COMACritic(scheme, args)
            elif args.critic_baseline_fn == "centralV":
                self.baseline_critic = CentralVCritic(scheme, args)
            self.target_baseline_critic = copy.deepcopy(self.baseline_critic)

            self.baseline_critic_params = list(self.baseline_critic.parameters())
            self.params += self.baseline_critic_params
            self.baseline_critic_optimiser = RMSprop(params=self.baseline_critic_params, lr=args.critic_lr,
                                                     alpha=args.optim_alpha,
                                                     eps=args.optim_eps)

        if args.critic_train_mode == "seq":
            self.critic_train_fn = self.train_critic_sequential
        elif args.critic_train_mode == "batch":
            self.critic_train_fn = self.train_critic_batched
        else:
            raise ValueError

        self.last_target_update_step = 0
        self.critic_training_steps = 0
        self.log_stats_t = -self.args.learner_log_interval - 1

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities

        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"][:, :-1]

        # No experiences to train on in this minibatch
        if mask.sum() == 0:
            self.logger.log_stat("Mask_Sum_Zero", 1, t_env)
            self.logger.console_logger.error("Actor Critic Learner: mask.sum() == 0 at t_env {}".format(t_env))
            return

        mask = mask.repeat(1, 1, self.n_agents)

        critic_mask = mask.clone()
        baseline_critic_mask = mask.clone()

        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length - 1):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out[avail_actions == 0] = 0
        mac_out = mac_out/mac_out.sum(dim=-1, keepdim=True)
        mac_out[avail_actions == 0] = 0

        pi = mac_out

        for _ in range(self.args.critic_train_reps):
            q_sa, v_s, critic_train_stats = self.critic_train_fn(self.critic, self.target_critic, self.critic_optimiser, batch,
                                                                 rewards, terminated, actions, avail_actions, critic_mask)

        if self.separate_baseline_critic:
            for _ in range(self.args.critic_train_reps):
                q_sa_baseline, v_s_baseline, critic_train_stats_baseline = \
                    self.critic_train_fn(self.baseline_critic, self.target_baseline_critic, self.baseline_critic_optimiser,
                                         batch, rewards, terminated, actions, avail_actions, baseline_critic_mask)
            if self.args.critic_baseline_fn == "coma":
                baseline = (q_sa_baseline * pi).sum(-1).detach()
            else:
                baseline = v_s_baseline
        else:
            if self.args.critic_baseline_fn == "coma":
                baseline = (q_sa * pi).sum(-1).detach()
            else:
                baseline = v_s

        actions = actions[:,:-1]

        if self.critic.output_type == "q":
            q_sa = th.gather(q_sa, dim=3, index=actions).squeeze(3)
            if self.args.critic_q_fn == "coma" and self.args.coma_mean_q:
                q_sa = q_sa.mean(2, keepdim=True).expand(-1, -1, self.n_agents)
        q_sa = self.nstep_returns(rewards, mask, q_sa, self.args.q_nstep)

        advantages = (q_sa - baseline).detach().squeeze()

        # Calculate policy grad with mask

        pi_taken = th.gather(pi, dim=3, index=actions).squeeze(3)
        pi_taken[mask == 0] = 1.0
        log_pi_taken = th.log(pi_taken)
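        # Masked-out entries were set to 1 above so their log-prob is 0 and they
        # do not contribute to the policy-gradient loss below.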

        pg_loss = - ((advantages * log_pi_taken) * mask).sum() / mask.sum()

        # Optimise agents
        self.agent_optimiser.zero_grad()
        pg_loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip)
        self.agent_optimiser.step()

        if (self.critic_training_steps - self.last_target_update_step) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_step = self.critic_training_steps

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            ts_logged = len(critic_train_stats["critic_loss"])
            for key in ["critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean"]:
                self.logger.log_stat(key, sum(critic_train_stats[key])/ts_logged, t_env)

            self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env)
            self.logger.log_stat("pg_loss", pg_loss.item(), t_env)
            self.logger.log_stat("agent_grad_norm", grad_norm, t_env)
            self.logger.log_stat("pi_max", (pi.max(dim=-1)[0] * mask).sum().item() / mask.sum().item(), t_env)
            self.log_stats_t = t_env

    def train_critic_sequential(self, critic, target_critic, optimiser, batch, rewards, terminated, actions,
                                avail_actions, mask):
        # Optimise critic
        target_vals = target_critic(batch)

        all_vals = th.zeros_like(target_vals)

        if critic.output_type == 'q':
            target_vals = th.gather(target_vals, dim=3, index=actions)
            # target_vals = th.cat([target_vals[:, 1:], th.zeros_like(target_vals[:, 0:1])], dim=1)
        target_vals = target_vals.squeeze(3)

        # Calculate td-lambda targets
        targets = build_td_lambda_targets(rewards, terminated, mask, target_vals, self.n_agents,
                                          self.args.gamma, self.args.td_lambda)

        running_log = {
            "critic_loss": [],
            "critic_grad_norm": [],
            "td_error_abs": [],
            "target_mean": [],
            "q_taken_mean": [],
        }

        for t in reversed(range(rewards.size(1) + 1)):
            vals_t = critic(batch, t)
            all_vals[:, t] = vals_t.squeeze(1)

            if t == rewards.size(1):
                continue

            mask_t = mask[:, t]
            if mask_t.sum() == 0:
                continue

            if critic.output_type == "q":
                vals_t = th.gather(vals_t, dim=3, index=actions[:, t:t+1]).squeeze(3).squeeze(1)
            else:
                vals_t = vals_t.squeeze(3).squeeze(1)
            targets_t = targets[:, t]

            td_error = (vals_t - targets_t.detach())

            # 0-out the targets that came from padded data
            masked_td_error = td_error * mask_t

            # Normal L2 loss, take mean over actual data
            loss = (masked_td_error ** 2).sum() / mask_t.sum()  # Not dividing by number of agents, only # valid timesteps
            optimiser.zero_grad()
            loss.backward()
            grad_norm = th.nn.utils.clip_grad_norm_(optimiser.param_groups[0]["params"], self.args.grad_norm_clip)
            optimiser.step()
            self.critic_training_steps += 1

            running_log["critic_loss"].append(loss.item())
            running_log["critic_grad_norm"].append(grad_norm)
            mask_elems = mask_t.sum().item()
            running_log["td_error_abs"].append((masked_td_error.abs().sum().item() / mask_elems))
            running_log["q_taken_mean"].append((vals_t * mask_t).sum().item() / mask_elems)
            running_log["target_mean"].append((targets_t * mask_t).sum().item() / mask_elems)

        if critic.output_type == 'q':
            q_vals = all_vals[:, :-1]
            v_s = None
        else:
            q_vals = all_vals[:, :-1].squeeze(3)
            v_s = all_vals[:, :-1].squeeze(3)

        return q_vals, v_s, running_log

    def nstep_returns(self, rewards, mask, values, nsteps):
        nstep_values = th.zeros_like(values)
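        # For each start index t, accumulate n discounted rewards and then
        # bootstrap with the (masked) critic value:
        #   G_t^(n) = sum_{k=0}^{n-1} gamma^k * r_{t+k} + gamma^n * V_{t+n}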
        for t_start in range(rewards.size(1)):
            nstep_return_t = th.zeros_like(values[:, 0])
            for step in range(nsteps + 1):
                t = t_start + step
                if t >= rewards.size(1):
                    break
                elif step == nsteps:
                    nstep_return_t += self.args.gamma ** (step) * values[:, t] * mask[:, t]
                elif t == rewards.size(1) - 1:
                    nstep_return_t += self.args.gamma ** (step) * values[:, t] * mask[:, t]
                else:
                    nstep_return_t += self.args.gamma ** (step) * rewards[:, t] * mask[:, t]
            nstep_values[:, t_start, :] = nstep_return_t
        return nstep_values

    def train_critic_batched(self, critic, target_critic, optimiser, batch, rewards, terminated, actions,
                             avail_actions, mask):
        # Optimise critic
        target_vals = target_critic(batch)

        target_vals = target_vals[:, :-1]

        if critic.output_type == 'q':
            target_vals = th.gather(target_vals, dim=3, index=actions)
            target_vals = th.cat([target_vals[:, 1:], th.zeros_like(target_vals[:, 0:1])], dim=1)
        target_vals = target_vals.squeeze(3)

        # Calculate td-lambda targets
        targets = build_td_lambda_targets(rewards, terminated, mask, target_vals, self.n_agents,
                                         self.args.gamma, self.args.td_lambda)

        running_log = {
            "critic_loss": [],
            "critic_grad_norm": [],
            "td_error_abs": [],
            "target_mean": [],
            "q_taken_mean": [],
        }

        all_vals = critic(batch)
        vals = all_vals.clone()[:, :-1]

        if critic.output_type == "q":
            vals = th.gather(vals, dim=3, index=actions)
        vals = vals.squeeze(3)

        td_error = (vals - targets.detach())

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error ** 2).sum() / mask.sum()
        optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(optimiser.param_groups[0]["params"], self.args.grad_norm_clip)
        optimiser.step()
        self.critic_training_steps += 1

        running_log["critic_loss"].append(loss.item())
        running_log["critic_grad_norm"].append(grad_norm)
        mask_elems = mask.sum().item()
        running_log["td_error_abs"].append((masked_td_error.abs().sum().item() / mask_elems))
        running_log["q_taken_mean"].append((vals * mask).sum().item() / mask_elems)
        running_log["target_mean"].append((targets * mask).sum().item() / mask_elems)

        if critic.output_type == 'q':
            q_vals = all_vals[:, :-1]
            v_s = None
        else:
            q_vals = build_td_lambda_targets(rewards, terminated, mask, all_vals.squeeze(3)[:, 1:], self.n_agents,
                                             self.args.gamma, self.args.td_lambda)
            v_s = vals

        return q_vals, v_s, running_log

    def _update_targets(self):
        self.target_critic.load_state_dict(self.critic.state_dict())
        if self.separate_baseline_critic:
            self.target_baseline_critic.load_state_dict(self.baseline_critic.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.critic.cuda()
        self.target_critic.cuda()
        if self.separate_baseline_critic:
            self.baseline_critic.cuda()
            self.target_baseline_critic.cuda()
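
The actor update above is an advantage-weighted, masked policy gradient optimised
with RMSprop. A minimal self-contained sketch of that step (tensor shapes such as
[batch, time, n_agents] and the hyperparameters are illustrative, not taken from
this code):

import torch as th
from torch.optim import RMSprop

def masked_pg_step(optimiser, params, advantages, log_pi_taken, mask, clip=10.0):
    # Advantage-weighted negative log-likelihood, averaged over valid timesteps.
    pg_loss = -((advantages.detach() * log_pi_taken) * mask).sum() / mask.sum()
    optimiser.zero_grad()
    pg_loss.backward()
    th.nn.utils.clip_grad_norm_(params, clip)
    optimiser.step()
    return pg_loss.item()

# e.g. optimiser = RMSprop(params, lr=5e-4, alpha=0.99, eps=1e-5)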
Esempio n. 17
0
class DDQN_Agent:

    def __init__(self, args, exp_model, logging_func):
        self.args = args

        # Exploration Model
        self.exp_model = exp_model

        self.log = logging_func["log"]

        # Experience Replay
        if self.args.set_replay:
            self.replay = ExpReplaySet(10, 10, exp_model, args, priority=False)
        else:
            self.replay = ExpReplay(args.exp_replay_size, args.stale_limit, exp_model, args, priority=self.args.prioritized)

        # DQN and Target DQN
        model = get_models(args.model)
        self.dqn = model(actions=args.actions)
        self.target_dqn = model(actions=args.actions)

        dqn_params = 0
        for weight in self.dqn.parameters():
            weight_params = 1
            for s in weight.size():
                weight_params *= s
            dqn_params += weight_params
        print("DQN has {:,} parameters.".format(dqn_params))

        self.target_dqn.eval()

        if args.gpu:
            print("Moving models to GPU.")
            self.dqn.cuda()
            self.target_dqn.cuda()

        # Optimizer
        # self.optimizer = Adam(self.dqn.parameters(), lr=args.lr)
        self.optimizer = RMSprop(self.dqn.parameters(), lr=args.lr)

        self.T = 0
        self.target_sync_T = -self.args.t_max

    def sync_target_network(self):
        for target, source in zip(self.target_dqn.parameters(), self.dqn.parameters()):
            target.data = source.data

    def act(self, state, epsilon, exp_model, evaluation=False):
        # self.T += 1
        self.dqn.eval()
        orig_state = state[:, :, -1:]
        state = torch.from_numpy(state).float().transpose_(0, 2).unsqueeze(0)
        q_values = self.dqn(Variable(state, volatile=True)).cpu().data[0]
        q_values_numpy = q_values.numpy()

        extra_info = {}

        if self.args.optimistic_init and not evaluation:
            q_values_pre_bonus = np.copy(q_values_numpy)
            if not self.args.ucb:
                for a in range(self.args.actions):
                    _, info = exp_model.bonus(orig_state, a, dont_remember=True)
                    action_pseudo_count = info["Pseudo_Count"]
                    # TODO: Log the optimism bonuses
                    optimism_bonus = self.args.optimistic_scaler / np.power(action_pseudo_count + 0.01, self.args.bandit_p)
                    if self.args.tb and self.T % self.args.tb_interval == 0:
                        self.log("Bandit/Action_{}".format(a), optimism_bonus, step=self.T)
                    q_values[a] += optimism_bonus
            else:
                action_counts = []
                for a in range(self.args.actions):
                    _, info = exp_model.bonus(orig_state, a, dont_remember=True)
                    action_pseudo_count = info["Pseudo_Count"]
                    action_counts.append(action_pseudo_count)
                total_count = sum(action_counts)
                for ai, a in enumerate(action_counts):
                    # TODO: Log the optimism bonuses
                    optimism_bonus = self.args.optimistic_scaler * np.sqrt(2 * np.log(max(1, total_count)) / (a + 0.01))
                    self.log("Bandit/UCB/Action_{}".format(ai), optimism_bonus, step=self.T)
                    q_values[ai] += optimism_bonus

            extra_info["Action_Bonus"] = q_values_numpy - q_values_pre_bonus

        extra_info["Q_Values"] = q_values_numpy

        if np.random.random() < epsilon:
            action = np.random.randint(low=0, high=self.args.actions)
        else:
            action = q_values.max(0)[1][0]  # Torch...

        extra_info["Action"] = action

        return action, extra_info

    def experience(self, state, action, reward, state_next, steps, terminated, pseudo_reward=0, density=1, exploring=False):
        if not exploring:
            self.T += 1
        self.replay.Add_Exp(state, action, reward, state_next, steps, terminated, pseudo_reward, density)

    def end_of_trajectory(self):
        self.replay.end_of_trajectory()

    def train(self):

        if self.T - self.target_sync_T > self.args.target:
            self.sync_target_network()
            self.target_sync_T = self.T

        info = {}

        for _ in range(self.args.iters):
            self.dqn.eval()

            # TODO: Use a named tuple for experience replay
            n_step_sample = 1
            if np.random.random() < self.args.n_step_mixing:
                n_step_sample = self.args.n_step
            batch, indices, is_weights = self.replay.Sample_N(self.args.batch_size, n_step_sample, self.args.gamma)
            columns = list(zip(*batch))

            states = Variable(torch.from_numpy(np.array(columns[0])).float().transpose_(1, 3))
            actions = Variable(torch.LongTensor(columns[1]))
            terminal_states = Variable(torch.FloatTensor(columns[5]))
            rewards = Variable(torch.FloatTensor(columns[2]))
            # Have to clip rewards for DQN
            rewards = torch.clamp(rewards, -1, 1)
            steps = Variable(torch.FloatTensor(columns[4]))
            new_states = Variable(torch.from_numpy(np.array(columns[3])).float().transpose_(1, 3))

            target_dqn_qvals = self.target_dqn(new_states).cpu()
            # Make a new variable with those values so that these are treated as constants
            target_dqn_qvals_data = Variable(target_dqn_qvals.data)
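            # Bootstrapped target below: (1 - done) * gamma**steps * Q_target(s', a*)
            # plus the clipped reward, where a* is the target net's argmax or, when
            # args.double is set, the online net's argmax evaluated by the target net.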

            q_value_targets = (Variable(torch.ones(terminal_states.size()[0])) - terminal_states)
            inter = Variable(torch.ones(terminal_states.size()[0]) * self.args.gamma)
            # print(steps)
            q_value_targets = q_value_targets * torch.pow(inter, steps)
            if self.args.double:
                # Double Q Learning
                new_states_qvals = self.dqn(new_states).cpu()
                new_states_qvals_data = Variable(new_states_qvals.data)
                q_value_targets = q_value_targets * target_dqn_qvals_data.gather(1, new_states_qvals_data.max(1)[1])
            else:
                q_value_targets = q_value_targets * target_dqn_qvals_data.max(1)[0]
            q_value_targets = q_value_targets + rewards

            self.dqn.train()
            if self.args.gpu:
                actions = actions.cuda()
                q_value_targets = q_value_targets.cuda()
            model_predictions = self.dqn(states).gather(1, actions.view(-1, 1))

            # info = {}

            td_error = model_predictions - q_value_targets
            info["TD_Error"] = td_error.mean().data[0]

            # Update the priorities
            if not self.args.density_priority:
                self.replay.Update_Indices(indices, td_error.cpu().data.numpy(), no_pseudo_in_priority=self.args.count_td_priority)

            # If using prioritised we need to weight the td_error
            if self.args.prioritized and self.args.prioritized_is:
                # print(td_error)
                weights_tensor = torch.from_numpy(is_weights).float()
                weights_tensor = Variable(weights_tensor)
                if self.args.gpu:
                    weights_tensor = weights_tensor.cuda()
                # print(weights_tensor)
                td_error = td_error * weights_tensor
            l2_loss = (td_error).pow(2).mean()
            info["Loss"] = l2_loss.data[0]

            # Update
            self.optimizer.zero_grad()
            l2_loss.backward()

            # Taken from pytorch clip_grad_norm
            # Remove once the pip version it up to date with source
            gradient_norm = clip_grad_norm(self.dqn.parameters(), self.args.clip_value)
            if gradient_norm is not None:
                info["Norm"] = gradient_norm

            self.optimizer.step()

            if "States" in info:
                states_trained = info["States"]
                info["States"] = states_trained + columns[0]
            else:
                info["States"] = columns[0]

        # Pad out the states to be of size batch_size
        if len(info["States"]) < self.args.batch_size:
            old_states = info["States"]
            new_states = old_states[0] * (self.args.batch_size - len(old_states))
            info["States"] = new_states

        return info
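
For reference, the Double-Q branch (args.double) of the bootstrapped target
assembled in train() corresponds to this minimal sketch in current PyTorch
(no Variable; names and shapes are illustrative):

import torch

@torch.no_grad()
def double_dqn_targets(dqn, target_dqn, rewards, next_states, terminals, steps, gamma):
    next_q_online = dqn(next_states)                        # [B, n_actions]
    best_actions = next_q_online.argmax(dim=1, keepdim=True)
    next_q_target = target_dqn(next_states).gather(1, best_actions).squeeze(1)
    # Zero the bootstrap for terminal transitions and discount by gamma**steps.
    return rewards + (1.0 - terminals) * (gamma ** steps) * next_q_target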
Esempio n. 18
0
class QMixPolicyGraph(PolicyGraph):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """

    def __init__(self, obs_space, action_space, config):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.config = config
        self.observation_space = obs_space
        self.action_space = action_space
        self.n_agents = len(obs_space.original_space.spaces)
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if space_keys != {"obs", "action_mask"}:
                raise ValueError(
                    "Dict obs space for agent must have keyset "
                    "['obs', 'action_mask'], got {}".format(space_keys))
            mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
            if mask_shape != (self.n_actions, ):
                raise ValueError("Action mask shape must be {}, got {}".format(
                    (self.n_actions, ), mask_shape))
            self.has_action_mask = True
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            # The real agent obs space is nested inside the dict
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.has_action_mask = False
            self.obs_size = _get_size(agent_obs_space)

        self.model = ModelCatalog.get_torch_model(
            agent_obs_space,
            self.n_actions,
            config["model"],
            default_model_cls=RNNModel)
        self.target_model = ModelCatalog.get_torch_model(
            agent_obs_space,
            self.n_actions,
            config["model"],
            default_model_cls=RNNModel)

        # Setup the mixer network.
        # The global state is just the stacked agent observations for now.
        self.state_shape = [self.obs_size, self.n_agents]
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.state_shape,
                                config["mixing_embed_dim"])
            self.target_mixer = QMixer(self.n_agents, self.state_shape,
                                       config["mixing_embed_dim"])
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer()
            self.target_mixer = VDNMixer()
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])

    @override(PolicyGraph)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        obs_batch, action_mask = self._unpack_observation(obs_batch)

        # Compute actions
        with th.no_grad():
            q_values, hiddens = _mac(
                self.model, th.from_numpy(obs_batch),
                [th.from_numpy(np.array(s)) for s in state_batches])
            avail = th.from_numpy(action_mask).float()
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # epsilon-greedy action selector
            random_numbers = th.rand_like(q_values[:, :, 0])
            pick_random = (random_numbers < self.cur_epsilon).long()
            random_actions = Categorical(avail).sample().long()
            actions = (pick_random * random_actions +
                       (1 - pick_random) * masked_q_values.max(dim=2)[1])
            actions = actions.numpy()
            hiddens = [s.numpy() for s in hiddens]

        return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

    @override(PolicyGraph)
    def learn_on_batch(self, samples):
        obs_batch, action_mask = self._unpack_observation(samples["obs"])
        group_rewards = self._get_group_rewards(samples["infos"])

        # These will be padded to shape [B * T, ...]
        [rew, action_mask, act, dones, obs], initial_states, seq_lens = \
            chop_into_sequences(
                samples["eps_id"],
                samples["agent_index"], [
                    group_rewards, action_mask, samples["actions"],
                    samples["dones"], obs_batch
                ],
                [samples["state_in_{}".format(k)]
                 for k in range(len(self.get_initial_state()))],
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True,
                _extra_padding=1)
        # TODO(ekl) adding 1 extra unit of padding here, since otherwise we
        # lose the terminating reward and the Q-values will be unanchored!
        B, T = len(seq_lens), max(seq_lens) + 1

        def to_batches(arr):
            new_shape = [B, T] + list(arr.shape[1:])
            return th.from_numpy(np.reshape(arr, new_shape))

        rewards = to_batches(rew)[:, :-1].float()
        actions = to_batches(act)[:, :-1].long()
        obs = to_batches(obs).reshape([B, T, self.n_agents,
                                       self.obs_size]).float()
        action_mask = to_batches(action_mask)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
            B, T, self.n_agents)[:, :-1]
        filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
                  np.expand_dims(seq_lens, 1)).astype(np.float32)
        mask = th.from_numpy(filled).unsqueeze(2).expand(B, T,
                                                         self.n_agents)[:, :-1]
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
            self.loss(rewards, actions, terminated, mask, obs, action_mask)

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {"stats": stats}, {}

    @override(PolicyGraph)
    def get_initial_state(self):
        return [
            s.expand([self.n_agents, -1]).numpy()
            for s in self.model.state_init()
        ]

    @override(PolicyGraph)
    def get_weights(self):
        return {"model": self.model.state_dict()}

    @override(PolicyGraph)
    def set_weights(self, weights):
        self.model.load_state_dict(weights["model"])

    @override(PolicyGraph)
    def get_state(self):
        return {
            "model": self.model.state_dict(),
            "target_model": self.target_model.state_dict(),
            "mixer": self.mixer.state_dict() if self.mixer else None,
            "target_mixer": self.target_mixer.state_dict()
            if self.mixer else None,
            "cur_epsilon": self.cur_epsilon,
        }

    @override(PolicyGraph)
    def set_state(self, state):
        self.model.load_state_dict(state["model"])
        self.target_model.load_state_dict(state["target_model"])
        if state["mixer"] is not None:
            self.mixer.load_state_dict(state["mixer"])
            self.target_mixer.load_state_dict(state["target_mixer"])
        self.set_epsilon(state["cur_epsilon"])
        self.update_target()

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _unpack_observation(self, obs_batch):
        """Unpacks the action mask / tuple obs from agent grouping.

        Returns:
            obs (Tensor): flattened obs tensor of shape [B, n_agents, obs_size]
            mask (Tensor): action mask, if any
        """
        unpacked = _unpack_obs(
            np.array(obs_batch),
            self.observation_space.original_space,
            tensorlib=np)
        if self.has_action_mask:
            obs = np.concatenate(
                [o["obs"] for o in unpacked],
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            obs = np.concatenate(
                unpacked,
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions])
        return obs, action_mask
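
The masked epsilon-greedy selection in compute_actions() can be isolated as a
small helper (a minimal sketch; q_values and avail_actions are float tensors of
shape [B, n_agents, n_actions]):

import torch as th
from torch.distributions import Categorical

def masked_epsilon_greedy(q_values, avail_actions, epsilon):
    masked_q = q_values.clone()
    masked_q[avail_actions == 0.0] = -float("inf")        # never pick unavailable actions
    pick_random = (th.rand_like(q_values[..., 0]) < epsilon).long()
    random_actions = Categorical(avail_actions).sample()   # uniform over available actions
    greedy_actions = masked_q.max(dim=-1)[1]
    return pick_random * random_actions + (1 - pick_random) * greedy_actions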
Esempio n. 19
0
class CentralVLearner(BasicLearner):

    def __init__(self, multiagent_controller, logging_struct=None, args=None):
        self.args = args
        self.multiagent_controller = multiagent_controller
        self.n_agents = multiagent_controller.n_agents
        self.n_actions = self.multiagent_controller.n_actions
        for _i in range(1, 4):
            setattr(self, "T_policy_level{}".format(_i), 0)
            setattr(self, "T_critic_level{}".format(_i), 0)

        self.stats = {}
        self.logging_struct = logging_struct

        self.critic_class = CentralVCritic

        self.critic_scheme = Scheme([dict(name="actions_onehot",
                                          rename="past_actions",
                                          select_agent_ids=list(range(self.n_agents)),
                                          transforms=[("shift", dict(steps=1)),
                                                     ],
                                          switch=self.args.critic_use_past_actions),
                                     dict(name="state")
                                   ])
        self.target_critic_scheme = self.critic_scheme

        # Set up schemes
        self.scheme = {}
        # level 1

        self.scheme["critic"] = self.critic_scheme
        self.scheme["target_critic"] = self.critic_scheme

        # create joint scheme from the critic scheme
        self.joint_scheme_dict = _join_dicts(self.scheme,
                                             self.multiagent_controller.joint_scheme_dict)

        # construct model-specific input regions
        self.input_columns = {}
        self.input_columns["critic"] = {"vfunction":Scheme([{"name":"state"},
                                                           ])}
        self.input_columns["target_critic"] = self.input_columns["critic"]

        # for _i in range(self.n_agents):
        #     self.input_columns["critic__agent{}".format(_i)] = {"vfunction":Scheme([{"name":"state"},
        #                                                                             #{"name":"past_actions",
        #                                                                             # "select_agent_ids":list(range(self.n_agents))},
        #                                                                             #{"name": "actions",
        #                                                                             # "select_agent_ids": list(
        #                                                                             #     range(self.n_agents))}
        #                                                                             ])}
        #
        # for _i in range(self.n_agents):
        #     self.input_columns["target_critic__agent{}".format(_i)] = self.input_columns["critic__agent{}".format(_i)]


        self.last_target_update_T_critic = 0
        self.T_critic = 0
        self.T_policy = 0

        self.policy_loss_class = CentralVPolicyLoss
        pass


    def create_models(self, transition_scheme):

        self.scheme_shapes = _generate_scheme_shapes(transition_scheme=transition_scheme,
                                                     dict_of_schemes=self.scheme)

        self.input_shapes = _generate_input_shapes(input_columns=self.input_columns,
                                                   scheme_shapes=self.scheme_shapes)

        # set up critic model
        self.critic_model = self.critic_class(input_shapes=self.input_shapes["critic"],
                                              n_agents=self.n_agents,
                                              n_actions=self.n_actions,
                                              args=self.args)
        if self.args.use_cuda:
            self.critic_model = self.critic_model.cuda()
        self.target_critic_model = deepcopy(self.critic_model)


        # set up optimizers
        if self.args.share_agent_params:
            self.agent_parameters = self.multiagent_controller.get_parameters()
        else:
            assert False, "TODO"
        self.agent_optimiser = RMSprop(self.agent_parameters, lr=self.args.lr_agent)

        self.critic_parameters = []
        if not (hasattr(self.args, "critic_share_params") and not self.args.critic_share_params):
            self.critic_parameters.extend(self.critic_model.parameters())
        else:
            assert False, "TODO"
        self.critic_optimiser = RMSprop(self.critic_parameters, lr=self.args.lr_critic)

        # this is used for joint retrieval of data from all schemes
        self.joint_scheme_dict = _join_dicts(self.scheme, self.multiagent_controller.joint_scheme_dict)

        self.args_sanity_check() # conduct MACKRL sanity check on arg parameters
        pass

    def args_sanity_check(self):
        """
        :return:
        """
        pass

    def train(self,
              batch_history,
              T_env=None):


        # Update target if necessary
        if (self.T_critic - self.last_target_update_T_critic) / self.args.target_critic_update_interval > 1.0:
            self.update_target_nets()
            self.last_target_update_T_critic = self.T_critic
            print("updating target net!")

        # Retrieve and view all data that can be retrieved from batch_history in a single step (caching efficient)

        # create one single batch_history view suitable for all
        data_inputs, data_inputs_tformat = batch_history.view(dict_of_schemes=self.joint_scheme_dict,
                                                              to_cuda=self.args.use_cuda,
                                                              to_variable=True,
                                                              bs_ids=None,
                                                              fill_zero=True) # DEBUG: Should be True

        actions, actions_tformat = batch_history.get_col(bs=None,
                                                         col="actions",
                                                         agent_ids=list(range(0, self.n_agents)),
                                                         stack=True)

        # do single forward pass in critic
        coma_model_inputs, coma_model_inputs_tformat = _build_model_inputs(column_dict=self.input_columns,
                                                                           inputs=data_inputs,
                                                                           inputs_tformat=data_inputs_tformat,
                                                                           to_variable=True)

        critic_loss_arr = []
        critic_mean_arr = []
        target_critic_mean_arr = []
        critic_grad_norm_arr = []



        # construct target-critic targets and carry out necessary forward passes
        # same input scheme for both target critic and critic!
        inputs_target_critic = coma_model_inputs["target_critic"]
        hidden_states = None
        if getattr(self.args, "critic_is_recurrent", False):
            hidden_states = Variable(
                th.zeros(inputs_target_critic["vfunction"].shape[0], 1, self.args.agents_hidden_state_size))
            if self.args.use_cuda:
                hidden_states = hidden_states.cuda()
        output_target_critic, output_target_critic_tformat = self.target_critic_model.forward(inputs_target_critic,
                                                                                              tformat="bs*t*v",
                                                                                              hidden_states=hidden_states)



        target_critic_td_targets, \
        target_critic_td_targets_tformat = batch_history.get_stat("td_lambda_targets",
                                                                  bs_ids=None,
                                                                  td_lambda=self.args.td_lambda,
                                                                  gamma=self.args.gamma,
                                                                  value_function_values=output_target_critic[
                                                                      "vvalue"].unsqueeze(0).detach(),
                                                                  to_variable=True,
                                                                  n_agents=1,
                                                                  to_cuda=self.args.use_cuda)

        # targets for terminal state are always NaNs, so mask these out of loss as well!
        mask = _pad_zero(inputs_target_critic["vfunction"][:,:,-1:].clone().fill_(1.0),
                         "bs*t*v",
                         batch_history.seq_lens).byte()
        mask[:, :-1, :] = mask[:, 1:, :] # account for terminal NaNs of targets
        mask[:, -1, :] = 0.0  # handles case of seq_len=limit_len

        output_critic_list = []
        def _optimize_critic(**kwargs):
            inputs_critic= kwargs["coma_model_inputs"]["critic"]
            inputs_target_critic=kwargs["coma_model_inputs"]["target_critic"]
            inputs_critic_tformat=kwargs["tformat"]
            inputs_target_critic_tformat = kwargs["tformat"]
            t = kwargs["t"]
            do_train = kwargs["do_train"]
            _inputs_critic = inputs_critic
            vtargets = target_critic_td_targets.squeeze(0)

            hidden_states = None
            if getattr(self.args, "critic_is_recurrent", False):
                hidden_states = Variable(th.zeros(output_target_critic["vvalue"].shape[0], 1, self.args.agents_hidden_state_size))
                if self.args.use_cuda:
                    hidden_states = hidden_states.cuda()

            output_critic, output_critic_tformat = self.critic_model.forward({_k:_v[:, t:t+1] for _k, _v in _inputs_critic.items()},
                                                                             tformat="bs*t*v",
                                                                             hidden_states=hidden_states)
            output_critic_list.append({_k:_v.clone() for _k, _v in output_critic.items()})

            if not do_train:
                return output_critic
            critic_loss, \
            critic_loss_tformat = CentralVCriticLoss()(inputs=output_critic["vvalue"],
                                                       target=Variable(vtargets[:, t:t+1], requires_grad=False),
                                                       tformat="bs*t*v",
                                                       mask=mask[:, t:t+1])
                                                       # seq_lens=batch_history.seq_lens)

            # optimize critic loss
            self.critic_optimiser.zero_grad()
            critic_loss.backward()

            critic_grad_norm = th.nn.utils.clip_grad_norm_(self.critic_parameters,
                                                           10)
            self.critic_optimiser.step()

            # Calculate critic statistics and update
            target_critic_mean = _naninfmean(output_target_critic["vvalue"])

            critic_mean = _naninfmean(output_critic["vvalue"])

            critic_loss_arr.append(np.asscalar(critic_loss.data.cpu().numpy()))
            critic_mean_arr.append(critic_mean)
            target_critic_mean_arr.append(target_critic_mean)
            critic_grad_norm_arr.append(critic_grad_norm)

            self.T_critic += 1
            return output_critic



        output_critic = None
        # optimize the critic as often as necessary to get the critic loss down reliably
        for _i in reversed(range(batch_history._n_t)): #range(self.args.n_critic_learner_reps):
            _ = _optimize_critic(coma_model_inputs=coma_model_inputs,
                                 tformat=coma_model_inputs_tformat,
                                 actions=actions,
                                 t=_i,
                                 do_train=(_i < max(batch_history.seq_lens) - 1))


        hidden_states = None
        if getattr(self.args, "critic_is_recurrent", False):
            hidden_states = Variable(th.zeros(coma_model_inputs["critic"]["vfunction"].shape[0], 1, self.args.agents_hidden_state_size))
            if self.args.use_cuda:
                hidden_states = hidden_states.cuda()

        # get advantages
        # output_critic, output_critic_tformat = self.critic_model.forward(coma_model_inputs["critic"],
        #                                                                   tformat="bs*t*v",
        #                                                                   hidden_states=hidden_states)

        values = th.cat([ x["vvalue"] for x in reversed(output_critic_list)], dim=1)

        # advantages = output_critic["advantage"]
        advantages = _n_step_return(values=values.unsqueeze(0), #output_critic["vvalue"].unsqueeze(0),
                                    rewards=batch_history["reward"][0],
                                    terminated=batch_history["terminated"][0],
                                    truncated=batch_history["truncated"][0],
                                    seq_lens=batch_history.seq_lens,
                                    horizon=batch_history._n_t-1,
                                    n=1 if not hasattr(self.args, "n_step_return_n") else self.args.n_step_return_n,
                                    gamma=self.args.gamma) - values #output_critic["vvalue"]

        advantages = advantages.squeeze(0)

        # only train the policy once in order to stay on-policy!
        policy_loss_function = partial(self.policy_loss_class(),
                                       actions = actions,
                                       advantages=advantages.detach(),
                                       seq_lens=batch_history.seq_lens,
                                       n_agents=self.n_agents)

        hidden_states, hidden_states_tformat = self.multiagent_controller.generate_initial_hidden_states(
            len(batch_history), caller="learner")

        agent_controller_output, \
        agent_controller_output_tformat = self.multiagent_controller.get_outputs(data_inputs,
                                                                                 hidden_states=hidden_states,
                                                                                 loss_fn=policy_loss_function,
                                                                                 tformat=data_inputs_tformat,
                                                                                 test_mode=False,
                                                                                 actions=actions)
        CentralV_loss = agent_controller_output["losses"]
        CentralV_loss = CentralV_loss.mean()

        if hasattr(self.args, "coma_use_entropy_regularizer") and self.args.coma_use_entropy_regularizer:
            CentralV_loss += self.args.coma_entropy_loss_regularization_factor * \
                         EntropyRegularisationLoss()(policies=agent_controller_output["policies"],
                                                     tformat="a*bs*t*v").sum()

        # carry out optimization for agents
        self.agent_optimiser.zero_grad()
        CentralV_loss.backward()

        policy_grad_norm = th.nn.utils.clip_grad_norm_(self.agent_parameters, 10)
        try:
            _check_nan(self.agent_parameters, silent_fail=False)
            self.agent_optimiser.step()  # DEBUG
            self._add_stat("Agent NaN gradient", 0.0, T_env=T_env)
        except Exception as e:
            self.logging_struct.py_logger.warning("NaN in agent gradients! Gradient not taken. ERROR: {}".format(e))
            self._add_stat("Agent NaN gradient", 1.0, T_env=T_env)

        # increase the policy step counter (T_policy is the fastest-moving counter)
        self.T_policy += len(batch_history) * batch_history._n_t

        # Calculate policy statistics
        advantage_mean = _naninfmean(advantages)
        self._add_stat("advantage_mean", advantage_mean, T_env=T_env)
        self._add_stat("policy_grad_norm", policy_grad_norm, T_env=T_env)
        self._add_stat("policy_loss", CentralV_loss.data.cpu().numpy(), T_env=T_env)
        self._add_stat("critic_loss", np.mean(critic_loss_arr), T_env=T_env)
        self._add_stat("critic_mean", np.mean(critic_mean_arr), T_env=T_env)
        self._add_stat("target_critic_mean", np.mean(target_critic_mean_arr), T_env=T_env)
        self._add_stat("critic_grad_norm", np.mean(critic_grad_norm_arr), T_env=T_env)
        self._add_stat("T_policy", self.T_policy, T_env=T_env)
        self._add_stat("T_critic", self.T_critic, T_env=T_env)

        pass

    def update_target_nets(self):
        self.target_critic_model.load_state_dict(self.critic_model.state_dict())

    def get_stats(self):
        if hasattr(self, "_stats"):
            return self._stats
        else:
            return []

    def log(self, log_directly = True):
        """
        Each learner has its own logging routine, which logs directly to the python-wide logger if log_directly==True,
        and returns a logging string otherwise

        Logging is triggered in run.py
        """
        stats = self.get_stats()
        logging_dict =  dict(advantage_mean = _seq_mean(stats["advantage_mean"]),
                             critic_grad_norm = _seq_mean(stats["critic_grad_norm"]),
                             critic_loss =_seq_mean(stats["critic_loss"]),
                             policy_grad_norm = _seq_mean(stats["policy_grad_norm"]),
                             policy_loss = _seq_mean(stats["policy_loss"]),
                             target_critic_mean = _seq_mean(stats["target_critic_mean"]),
                             T_critic=self.T_critic,
                             T_policy=self.T_policy
                            )
        logging_str = "T_policy={:g}, T_critic={:g}, ".format(logging_dict["T_policy"], logging_dict["T_critic"])
        logging_str += _make_logging_str(_copy_remove_keys(logging_dict, ["T_policy", "T_critic"]))

        if log_directly:
            self.logging_struct.py_logger.info("{} LEARNER INFO: {}".format(self.args.learner.upper(), logging_str))

        return logging_str, logging_dict

    def save_models(self, path, token, T):
        import os
        if not os.path.exists("results/models/{}".format(self.args.learner)):
            os.makedirs("results/models/{}".format(self.args.learner))

        self.multiagent_controller.save_models(path=path, token=token, T=T)
        th.save(self.critic_model.state_dict(),"results/models/{}/{}_critic__{}_T.weights".format(self.args.learner,
                                                                                            token,
                                                                                            T))
        th.save(self.target_critic_model.state_dict(), "results/models/{}/{}_target_critic__{}_T.weights".format(self.args.learner,
                                                                                                           token,
                                                                                                           T))
        pass
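
# --- Illustrative sketch (not part of the original example) ---
# The learner above relies on a _n_step_return helper that is not shown in this
# snippet. Below is a minimal, hedged sketch of an n-step bootstrapped return;
# the [bs, t] tensor shapes and the lack of per-episode sequence-length masking
# are assumptions, so the real helper will differ in detail.
import torch as th


def n_step_return_sketch(values, rewards, terminated, n, gamma):
    """Sketch: R_t = sum_k gamma^k r_{t+k} + gamma^n V(s_{t+n}), truncated at episode end."""
    bs, t = rewards.shape
    returns = th.zeros_like(rewards)
    for start in range(t):
        acc = th.zeros(bs)
        alive = th.ones(bs)       # 1 until the episode terminates
        discount = 1.0
        for k in range(n):
            step = start + k
            if step >= t:
                break
            acc = acc + discount * alive * rewards[:, step]
            alive = alive * (1.0 - terminated[:, step])
            discount *= gamma
        boot = start + n
        if boot < t:
            # bootstrap from the critic value n steps ahead, if still in range
            acc = acc + discount * alive * values[:, boot]
        returns[:, start] = acc
    return returns
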
Example n. 20
class RODELearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger
        self.n_agents = args.n_agents

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = QMixer(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        self.role_mixer = None
        if args.role_mixer is not None:
            if args.role_mixer == "vdn":
                self.role_mixer = VDNMixer()
            elif args.role_mixer == "qmix":
                self.role_mixer = QMixer(args)
            else:
                raise ValueError("Role Mixer {} not recognised.".format(
                    args.role_mixer))
            self.params += list(self.role_mixer.parameters())
            self.target_role_mixer = copy.deepcopy(self.role_mixer)

        self.optimiser = RMSprop(params=self.params,
                                 lr=args.lr,
                                 alpha=args.optim_alpha,
                                 eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.role_interval = args.role_interval
        self.device = self.args.device

        self.role_action_spaces_updated = True

        # action encoder
        self.action_encoder_params = list(self.mac.action_encoder_params())
        self.action_encoder_optimiser = RMSprop(
            params=self.action_encoder_params,
            lr=args.lr,
            alpha=args.optim_alpha,
            eps=args.optim_eps)

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]
        # role_avail_actions = batch["role_avail_actions"]
        roles_shape_o = batch["roles"][:, :-1].shape
        role_at = int(np.ceil(roles_shape_o[1] / self.role_interval))
        role_t = role_at * self.role_interval

        roles_shape = list(roles_shape_o)
        roles_shape[1] = role_t
        roles = th.zeros(roles_shape).to(self.device)
        roles[:, :roles_shape_o[1]] = batch["roles"][:, :-1]
        roles = roles.view(batch.batch_size, role_at, self.role_interval,
                           self.n_agents, -1)[:, :, 0]

        # Calculate estimated Q-Values
        mac_out = []
        role_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs, role_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
            if t % self.role_interval == 0 and t < batch.max_seq_length - 1:
                role_out.append(role_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time
        role_out = th.stack(role_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3,
                                        index=actions).squeeze(
                                            3)  # Remove the last dim
        chosen_role_qvals = th.gather(role_out, dim=3,
                                      index=roles.long()).squeeze(3)

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        target_role_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs, target_role_outs = self.target_mac.forward(
                batch, t=t)
            target_mac_out.append(target_agent_outs)
            if t % self.role_interval == 0 and t < batch.max_seq_length - 1:
                target_role_out.append(target_role_outs)

        target_role_out.append(
            th.zeros(batch.batch_size, self.n_agents,
                     self.mac.n_roles).to(self.device))
        # We don't need the first timestep's Q-Value estimate for calculating targets
        target_mac_out = th.stack(target_mac_out[1:],
                                  dim=1)  # Concat across time
        target_role_out = th.stack(target_role_out[1:], dim=1)

        # Mask out unavailable actions
        target_mac_out[avail_actions[:, 1:] == 0] = -9999999
        # target_mac_out[role_avail_actions[:, 1:] == 0] = -9999999

        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            # mac_out_detach[role_avail_actions == 0] = -9999999
            cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3,
                                         cur_max_actions).squeeze(3)

            role_out_detach = role_out.clone().detach()
            role_out_detach = th.cat(
                [role_out_detach[:, 1:], role_out_detach[:, 0:1]], dim=1)
            cur_max_roles = role_out_detach.max(dim=3, keepdim=True)[1]
            target_role_max_qvals = th.gather(target_role_out, 3,
                                              cur_max_roles).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]
            target_role_max_qvals = target_role_out.max(dim=3)[0]

        # Mix
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals,
                                             batch["state"][:, :-1])
            target_max_qvals = self.target_mixer(target_max_qvals,
                                                 batch["state"][:, 1:])
        if self.role_mixer is not None:
            state_shape_o = batch["state"][:, :-1].shape
            state_shape = list(state_shape_o)
            state_shape[1] = role_t
            role_states = th.zeros(state_shape).to(self.device)
            role_states[:, :state_shape_o[1]] = batch["state"][:, :-1].detach(
            ).clone()
            role_states = role_states.view(batch.batch_size, role_at,
                                           self.role_interval, -1)[:, :, 0]
            chosen_role_qvals = self.role_mixer(chosen_role_qvals, role_states)
            role_states = th.cat([role_states[:, 1:], role_states[:, 0:1]],
                                 dim=1)
            target_role_max_qvals = self.target_role_mixer(
                target_role_max_qvals, role_states)

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.args.gamma * (1 -
                                               terminated) * target_max_qvals
        rewards_shape = list(rewards.shape)
        rewards_shape[1] = role_t
        role_rewards = th.zeros(rewards_shape).to(self.device)
        role_rewards[:, :rewards.shape[1]] = rewards.detach().clone()
        role_rewards = role_rewards.view(batch.batch_size, role_at,
                                         self.role_interval).sum(dim=-1,
                                                                 keepdim=True)
        # role_terminated
        terminated_shape_o = terminated.shape
        terminated_shape = list(terminated_shape_o)
        terminated_shape[1] = role_t
        role_terminated = th.zeros(terminated_shape).to(self.device)
        role_terminated[:, :terminated_shape_o[1]] = terminated.detach().clone(
        )
        role_terminated = role_terminated.view(
            batch.batch_size, role_at, self.role_interval).sum(dim=-1,
                                                               keepdim=True)
        # role_terminated
        role_targets = role_rewards + self.args.gamma * (
            1 - role_terminated) * target_role_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())
        role_td_error = (chosen_role_qvals - role_targets.detach())

        mask = mask.expand_as(td_error)
        mask_shape = list(mask.shape)
        mask_shape[1] = role_t
        role_mask = th.zeros(mask_shape).to(self.device)
        role_mask[:, :mask.shape[1]] = mask.detach().clone()
        role_mask = role_mask.view(batch.batch_size, role_at,
                                   self.role_interval, -1)[:, :, 0]

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask
        masked_role_td_error = role_td_error * role_mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error**2).sum() / mask.sum()
        role_loss = (masked_role_td_error**2).sum() / role_mask.sum()
        loss += role_loss

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params,
                                                self.args.grad_norm_clip)
        self.optimiser.step()

        pred_obs_loss = None
        pred_r_loss = None
        pred_grad_norm = None

        if self.role_action_spaces_updated:
            # train action encoder

            no_pred = []
            r_pred = []
            for t in range(batch.max_seq_length):
                no_preds, r_preds = self.mac.action_repr_forward(batch, t=t)
                no_pred.append(no_preds)
                r_pred.append(r_preds)
            no_pred = th.stack(no_pred, dim=1)[:, :-1]  # Concat over time
            r_pred = th.stack(r_pred, dim=1)[:, :-1]
            no = batch["obs"][:, 1:].detach().clone()
            repeated_rewards = batch["reward"][:, :-1].detach().clone(
            ).unsqueeze(2).repeat(1, 1, self.n_agents, 1)

            pred_obs_loss = th.sqrt(((no_pred - no)**2).sum(dim=-1)).mean()
            pred_r_loss = ((r_pred - repeated_rewards)**2).mean()

            pred_loss = pred_obs_loss + 10 * pred_r_loss
            self.action_encoder_optimiser.zero_grad()
            pred_loss.backward()
            pred_grad_norm = th.nn.utils.clip_grad_norm_(
                self.action_encoder_params, self.args.grad_norm_clip)
            self.action_encoder_optimiser.step()

            if t_env > self.args.role_action_spaces_update_start:
                self.mac.update_role_action_spaces()
                if 'noar' in self.args.mac:
                    self.target_mac.role_selector.update_roles(
                        self.mac.n_roles)
                self.role_action_spaces_updated = False
                self._update_targets()
                self.last_target_update_episode = episode_num

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", (loss - role_loss).item(), t_env)
            self.logger.log_stat("role_loss", role_loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            if pred_obs_loss is not None:
                self.logger.log_stat("pred_obs_loss", pred_obs_loss.item(),
                                     t_env)
                self.logger.log_stat("pred_r_loss", pred_r_loss.item(), t_env)
                self.logger.log_stat("action_encoder_grad_norm",
                                     pred_grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs",
                (masked_td_error.abs().sum().item() / mask_elems), t_env)
            self.logger.log_stat("q_taken_mean",
                                 (chosen_action_qvals * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("role_q_taken_mean",
                                 (chosen_role_qvals * role_mask).sum().item() /
                                 (role_mask.sum().item() * self.args.n_agents),
                                 t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        if self.role_mixer is not None:
            self.target_role_mixer.load_state_dict(
                self.role_mixer.state_dict())
        self.target_mac.role_action_spaces_updated = self.role_action_spaces_updated
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()
        if self.role_mixer is not None:
            self.role_mixer.cuda()
            self.target_role_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        if self.role_mixer is not None:
            th.save(self.role_mixer.state_dict(),
                    "{}/role_mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))
        th.save(self.action_encoder_optimiser.state_dict(),
                "{}/action_repr_opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                th.load("{}/mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        if self.role_mixer is not None:
            self.role_mixer.load_state_dict(
                th.load("{}/role_mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            th.load("{}/opt.th".format(path),
                    map_location=lambda storage, loc: storage))
        self.action_encoder_optimiser.load_state_dict(
            th.load("{}/action_repr_opt.th".format(path),
                    map_location=lambda storage, loc: storage))
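
# --- Illustrative sketch (not part of the original example) ---
# The target computation in RODELearner.train follows standard double
# Q-learning. As a stand-alone illustration with hypothetical tensors
# (mac_out/target_mac_out/avail_actions: [bs, t, n_agents, n_actions];
# rewards/terminated: [bs, t-1, 1], broadcast across agents; these shapes are
# assumptions, not the learner's own variables), the per-agent 1-step target is:
import torch as th


def double_q_target_sketch(mac_out, target_mac_out, avail_actions,
                           rewards, terminated, gamma):
    """Sketch: pick greedy actions with the online net, evaluate them with the target net."""
    online = mac_out.clone().detach()
    online[avail_actions == 0] = -9999999                      # never pick unavailable actions
    greedy = online[:, 1:].max(dim=3, keepdim=True)[1]         # argmax under the online net
    target_q = th.gather(target_mac_out[:, 1:], 3, greedy).squeeze(3)
    # per-agent 1-step targets; the real learner mixes these (VDN/QMIX) before the TD error
    return rewards + gamma * (1 - terminated) * target_q
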
Example n. 21
    loss_tracker = []
    for t in range(opt.n_disc):
        latent_variable = get_latent_variable(opt.batch_size, opt.latent_dim,
                                              device)

        fake_images = generator(latent_variable)
        real_images = data_loader.load_images()

        discriminator_optimizer.zero_grad()
        discriminator_loss = WGANGP_loss(discriminator=discriminator,
                                         from_real=real_images,
                                         from_fake=fake_images,
                                         lamda=opt.lamda)
        discriminator_loss_tracker.append(discriminator_loss.tolist())
        discriminator_loss.backward()
        discriminator_optimizer.step()

    # update the generator
    latent_variable = get_latent_variable(opt.batch_size, opt.latent_dim,
                                          device)

    fake_images = generator(latent_variable)
    generator_optimizer.zero_grad()
    generator_loss = -discriminator(fake_images).mean()
    generator_loss_tracker.append(generator_loss.tolist())
    generator_loss.backward()
    generator_optimizer.step()

    # check the convergence
    gen_conv = generator_loss_tracker.converged()
    dis_conv = discriminator_loss_tracker.converged()
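
# --- Illustrative sketch (not part of the original example) ---
# WGANGP_loss is imported from elsewhere and its body is not shown. A minimal
# sketch of the usual WGAN-GP critic loss (Wasserstein term plus gradient
# penalty) is given below; the argument names follow the call site above, while
# the interpolation scheme and the image-shaped (4-D) inputs are assumptions.
import torch


def wgangp_loss_sketch(discriminator, from_real, from_fake, lamda):
    # Wasserstein critic term: push real scores up, fake scores down
    loss = discriminator(from_fake.detach()).mean() - discriminator(from_real).mean()

    # Gradient penalty on random interpolates between real and fake samples
    eps = torch.rand(from_real.size(0), 1, 1, 1, device=from_real.device)
    interpolates = (eps * from_real + (1 - eps) * from_fake.detach()).requires_grad_(True)
    d_interp = discriminator(interpolates)
    grads = torch.autograd.grad(outputs=d_interp, inputs=interpolates,
                                grad_outputs=torch.ones_like(d_interp),
                                create_graph=True, retain_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    penalty = ((grad_norm - 1) ** 2).mean()
    return loss + lamda * penalty
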
Example n. 22
File: train.py Project: nik-sm/dqn
class Agent:
    def __init__(self,
                 game: str,
                 replay_buffer_capacity: int,
                 replay_start_size: int,
                 batch_size: int,
                 discount_factor: float,
                 lr: float,
                 device: str = 'cuda:0',
                 env_seed: int = 0,
                 frame_buffer_size: int = 4,
                 print_self=True):

        self.device = device
        self.discount_factor = discount_factor
        self.game = game
        self.batch_size = batch_size

        self.replay_buf = ReplayBuffer(capacity=replay_buffer_capacity)

        self.env = FrameStack(
            AtariPreprocessing(
                gym.make(self.game),
                # noop_max=0,
                # terminal_on_life_loss=True,
                scale_obs=False),
            num_stack=frame_buffer_size)
        self.env.seed(env_seed)
        self.reset()

        self.n_action = self.env.action_space.n
        self.policy_net = DQN(self.n_action).to(self.device)
        self.target_net = DQN(self.n_action).to(self.device).eval()
        self.optimizer = RMSprop(
            self.policy_net.parameters(),
            alpha=0.95,
            # momentum=0.95,
            eps=0.01)

        if print_self:
            print(self)
        self._fill_replay_buf(replay_start_size)

    def __repr__(self):
        return '\n'.join([
            'Agent:', f'Game: {self.game}', f'Device: {self.device}',
            f'Policy net: {self.policy_net}', f'Target net: {self.target_net}',
            f'Replay buf: {self.replay_buf}'
        ])

    def _fill_replay_buf(self, replay_start_size):
        for _ in trange(replay_start_size,
                        desc='Fill replay_buf randomly',
                        leave=True):
            self.step(1.0)

    def reset(self):
        """Reset the end, pre-populate self.frame_buf and self.state"""
        self.state = self.env.reset()

    @torch.no_grad()
    def step(self, epsilon, clip_reward=True):
        """
        Choose an action based on current state and epsilon-greedy policy
        """
        # Choose action
        if random.random() <= epsilon:
            q_values = None
            action = self.env.action_space.sample()
        else:
            torch_state = torch.tensor(self.state,
                                       dtype=torch.float32,
                                       device=self.device).unsqueeze(0) / 255.0
            q_values = self.policy_net(torch_state)
            action = int(q_values.argmax(dim=1).item())

        # Apply action
        next_state, reward, done, _ = self.env.step(action)
        if clip_reward:
            reward = max(-1.0, min(reward, 1.0))

        # Store into replay buffer
        self.replay_buf.append(
            (torch.tensor(
                np.array(self.state), dtype=torch.float32, device="cpu") /
             255., action, reward,
             torch.tensor(
                 np.array(next_state), dtype=torch.float32, device="cpu") /
             255., done))

        # Advance to next state
        self.state = next_state
        if done:
            self.reset()

        return reward, q_values, done

    def q_update(self):
        self.optimizer.zero_grad()
        states, actions, rewards, next_states, dones = [
            x.to(self.device) for x in self.replay_buf.sample(self.batch_size)
        ]

        with torch.no_grad():
            y = torch.where(
                dones, rewards, rewards +
                self.discount_factor * self.target_net(next_states).max(1)[0])

        predicted_values = self.policy_net(states).gather(
            1, actions.unsqueeze(-1)).squeeze(-1)
        loss = huber(y, predicted_values, 2.)
        loss.backward()
        self.optimizer.step()
        return (y - predicted_values).abs().mean()
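
# --- Illustrative sketch (not part of the original example) ---
# The huber() used in q_update above is defined elsewhere in the repo. A
# minimal sketch consistent with the call huber(y, predicted_values, 2.) is
# below; the exact reduction and threshold handling of the real helper may differ.
import torch


def huber_sketch(target, prediction, delta):
    """Sketch of a Huber loss: quadratic near zero error, linear beyond delta."""
    err = target - prediction
    abs_err = err.abs()
    quadratic = 0.5 * err ** 2
    linear = delta * (abs_err - 0.5 * delta)
    return torch.where(abs_err <= delta, quadratic, linear).mean()
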
Example n. 23
                gan_dis_eq_mean(0.0)

            if train_dec:
                gan_gen_eq_mean(1.0)
            else:
                gan_gen_eq_mean(0.0)

            # BACKPROP
            # clean grads
            net.zero_grad()
            # encoder
            loss_encoder.backward(retain_graph=True)
            # someone likes to clamp the grad here
            #[p.grad.data.clamp_(-1,1) for p in net.encoder.parameters()]
            # update parameters
            optimizer_encoder.step()
            # clean others, so they are not afflicted by encoder loss
            net.zero_grad()
            #decoder
            if train_dec:
                loss_decoder.backward(retain_graph=True)
                #[p.grad.data.clamp_(-1,1) for p in net.decoder.parameters()]
                optimizer_decoder.step()
                #clean the discriminator
                net.discriminator.zero_grad()
            #discriminator
            if train_dis:
                loss_discriminator.backward()
                #[p.grad.data.clamp_(-1,1) for p in net.discriminator.parameters()]
                optimizer_discriminator.step()
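
# --- Illustrative sketch (not part of the original example) ---
# The retain_graph=True calls above are needed because the encoder, decoder and
# discriminator losses share parts of the same autograd graph. The tiny,
# self-contained example below (unrelated to the VAE/GAN code itself) shows why
# a second backward pass through a shared graph requires retaining it.
import torch

x = torch.randn(4, requires_grad=True)
shared = (x * 2).sum()                # intermediate node used by both losses
loss_a = shared ** 2
loss_b = shared * 3

loss_a.backward(retain_graph=True)    # keep the shared graph alive
loss_b.backward()                     # would raise a RuntimeError without retain_graph above
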
Example n. 24
class NEC_Agent:

    def __init__(self, args, exp_model, logging_func):
        self.args = args

        # Exploration Model
        self.exp_model = exp_model

        self.log = logging_func["log"]

        # Experience Replay
        self.replay = ExpReplay(args.exp_replay_size, args)
        self.dnds = [DND(kernel=kernel, num_neighbors=args.nec_neighbours, max_memory=args.dnd_size, embedding_size=args.nec_embedding) for _ in range(self.args.actions)]

        # DQN and Target DQN
        model = get_models(args.model)
        self.embedding = model(embedding=args.nec_embedding)

        embedding_params = 0
        for weight in self.embedding.parameters():
            weight_params = 1
            for s in weight.size():
                weight_params *= s
            embedding_params += weight_params
        print("Embedding Network has {:,} parameters.".format(embedding_params))

        if args.gpu:
            print("Moving models to GPU.")
            self.embedding.cuda()

        # Optimizer
        self.optimizer = RMSprop(self.embedding.parameters(), lr=args.lr)
        # self.optimizer = Adam(self.embedding.parameters(), lr=args.lr)

        self.T = 0
        self.target_sync_T = -self.args.t_max

        self.experiences = []
        self.keys = []
        self.q_val_estimates = []

        self.table_updates = 0

    def Q_Value_Estimates(self, state):
        # Get state embedding
        state = torch.from_numpy(state).float().transpose_(0, 2).unsqueeze(0)
        key = self.embedding(Variable(state, volatile=True)).cpu()

        if (key != key).sum().data[0] > 0:
            pass
            # print(key)
            # for param in self.embedding.parameters():
                # print(param)
            # print(key != key)
            # print((key != key).sum().data[0])
            # print("Nan key")

        estimate_from_dnds = torch.cat([dnd.lookup(key) for dnd in self.dnds])
        # print(estimate_from_dnds)

        self.keys.append(key.data[0].numpy())
        self.q_val_estimates.append(estimate_from_dnds.data.numpy())

        return estimate_from_dnds, key
        # return np.array(estimate_from_dnds), key

    def act(self, state, epsilon, exp_model):

        q_values, key = self.Q_Value_Estimates(state)
        q_values_numpy = q_values.data.numpy()

        extra_info = {}
        extra_info["Q_Values"] = q_values_numpy

        if np.random.random() < epsilon:
            action = np.random.randint(low=0, high=self.args.actions)
        else:
            action = np.argmax(q_values_numpy)

        extra_info["Action"] = action

        return action, extra_info

    def experience(self, state, action, reward, state_next, steps, terminated, pseudo_reward=0, density=1, exploring=False):

        experience = (state, action, reward, pseudo_reward, state_next, terminated)
        self.experiences.append(experience)

        if len(self.experiences) >= self.args.n_step:
            self.add_experience()

        if not exploring:
            self.T += 1

    def end_of_trajectory(self):
        self.replay.end_of_trajectory()

        # Go through the experiences and add them to the replay using a less than N-step Q-Val estimate
        while len(self.experiences) > 0:
            self.add_experience()

    def add_experience(self):

        # Match the key and q-val estimate sizes to the number of experiences
        N = len(self.experiences)
        self.keys = self.keys[-N:]
        self.q_val_estimates = self.q_val_estimates[-N:]

        first_state = self.experiences[0][0]
        first_action = self.experiences[0][1]
        last_state = self.experiences[-1][4]
        terminated_last_state = self.experiences[-1][5]
        accum_reward = 0
        for ex in reversed(self.experiences):
            r = ex[2]
            pr = ex[3]
            accum_reward = (r + pr) + self.args.gamma * accum_reward
        # if accum_reward > 1000:
            # print(accum_reward)
        if terminated_last_state:
            last_state_max_q_val = 0
        else:
            # last_state_q_val_estimates, last_state_key = self.Q_Value_Estimates(last_state)
            # last_state_max_q_val = last_state_q_val_estimates.data.max(0)[0][0]
            last_state_max_q_val = np.max(self.q_val_estimates[-1])
            # print(last_state_max_q_val)

        # first_state_q_val_estimates, first_state_key = self.Q_Value_Estimates(first_state)
        # first_state_key = first_state_key.data[0].numpy()
        first_state_key = self.keys[0]

        n_step_q_val_estimate = accum_reward + (self.args.gamma ** len(self.experiences)) * last_state_max_q_val
        # print(n_step_q_val_estimate)

        # Add to dnd
        # print(first_state_key)
        # print(tuple(first_state_key.data[0]))
        # if any(np.isnan(first_state_key)):
            # print("NAN")
        if self.dnds[first_action].is_present(key=first_state_key):
            current_q_val = self.dnds[first_action].get_value(key=first_state_key)
            new_q_val = current_q_val + self.args.nec_alpha * (n_step_q_val_estimate - current_q_val)
            self.dnds[first_action].upsert(key=first_state_key, value=new_q_val)
            self.table_updates += 1
            self.log("NEC/Table_Updates", self.table_updates, step=self.T)
        else:
            self.dnds[first_action].upsert(key=first_state_key, value=n_step_q_val_estimate)

        # Add to replay
        self.replay.Add_Exp(first_state, first_action, n_step_q_val_estimate)

        # Remove first experience
        self.experiences = self.experiences[1:]

    def train(self):

        info = {}
        if self.T % self.args.nec_update != 0:
            return info

        # print("Training")

        for _ in range(self.args.iters):

            # TODO: Use a named tuple for experience replay
            batch = self.replay.Sample(self.args.batch_size)
            columns = list(zip(*batch))

            states = Variable(torch.from_numpy(np.array(columns[0])).float().transpose_(1, 3))
            # print(states)
            actions = columns[1]
            # print(actions)
            targets = Variable(torch.FloatTensor(columns[2]))
            # print(targets)
            keys = self.embedding(states).cpu()
            # print(keys)
            # print("Keys", keys.requires_grad)
            # for action in actions:
                # print(action)
            # for action, key in zip(actions, keys):
                # print(action, key)
                # kk = key.unsqueeze(0)
                # print("kk", kk.requires_grad)
                # k = self.dnds[action].lookup(key.unsqueeze(0))
                # print("key", key.requires_grad, key.volatile)
            model_predictions = torch.cat([self.dnds[action].lookup(key.unsqueeze(0)) for action, key in zip(actions, keys)])
            # print(model_predictions)
            # print(targets)

            td_error = model_predictions - targets
            # print(td_error)
            info["TD_Error"] = td_error.mean().data[0]

            l2_loss = (td_error).pow(2).mean()
            info["Loss"] = l2_loss.data[0]

            # Update
            self.optimizer.zero_grad()

            l2_loss.backward()

            # Taken from pytorch clip_grad_norm
            # Remove once the pip version is up to date with source
            gradient_norm = clip_grad_norm(self.embedding.parameters(), self.args.clip_value)
            if gradient_norm is not None:
                info["Norm"] = gradient_norm

            self.optimizer.step()

            if "States" in info:
                states_trained = info["States"]
                info["States"] = states_trained + columns[0]
            else:
                info["States"] = columns[0]

        return info
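
# --- Illustrative sketch (not part of the original example) ---
# The DND above is constructed with a similarity kernel that is not shown. The
# inverse-distance kernel from the Neural Episodic Control paper is a common
# choice; a hedged sketch of it is below (the kernel actually used here may differ).
import torch


def inverse_distance_kernel_sketch(query, keys, delta=1e-3):
    """Sketch: k(h, h_i) = 1 / (||h - h_i||^2 + delta), normalised over neighbours."""
    # query: [1, embedding_size], keys: [num_neighbors, embedding_size]
    dist_sq = ((query - keys) ** 2).sum(dim=1)
    weights = 1.0 / (dist_sq + delta)
    return weights / weights.sum()
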
Example n. 25
class PPO(Agent):
    """
    An agent trained with PPO using the Advantage Actor-Critic framework
    - Actor takes the state as input
    - Critic takes both state and action as input
    - The agent interacts with the environment to collect experience
    - The agent trains on that experience to update its policy
    - Adam seems to work better than RMSprop for PPO
    """
    def __init__(self,
                 env,
                 state_dim,
                 action_dim,
                 memory_capacity=10000,
                 max_steps=None,
                 roll_out_n_steps=1,
                 target_tau=1.,
                 target_update_steps=5,
                 clip_param=0.2,
                 reward_gamma=0.99,
                 reward_scale=1.,
                 done_penalty=None,
                 actor_hidden_size=32,
                 critic_hidden_size=32,
                 actor_output_act=nn.functional.log_softmax,
                 critic_loss="mse",
                 actor_lr=0.001,
                 critic_lr=0.001,
                 optimizer_type="adam",
                 entropy_reg=0.01,
                 max_grad_norm=0.5,
                 batch_size=100,
                 episodes_before_train=100,
                 epsilon_start=0.9,
                 epsilon_end=0.01,
                 epsilon_decay=200,
                 use_cuda=True):
        super(PPO,
              self).__init__(env, state_dim, action_dim, memory_capacity,
                             max_steps, reward_gamma, reward_scale,
                             done_penalty, actor_hidden_size,
                             critic_hidden_size, actor_output_act, critic_loss,
                             actor_lr, critic_lr, optimizer_type, entropy_reg,
                             max_grad_norm, batch_size, episodes_before_train,
                             epsilon_start, epsilon_end, epsilon_decay,
                             use_cuda)

        self.roll_out_n_steps = roll_out_n_steps
        self.target_tau = target_tau
        self.target_update_steps = target_update_steps
        self.clip_param = clip_param

        self.actor = ActorNetwork(self.state_dim, self.actor_hidden_size,
                                  self.action_dim, self.actor_output_act)
        self.critic = CriticNetwork(self.state_dim, self.action_dim,
                                    self.critic_hidden_size, 1)
        # to ensure the target network and the learning network have the same weights
        self.actor_target = deepcopy(self.actor)
        self.critic_target = deepcopy(self.critic)

        if self.optimizer_type == "adam":
            self.actor_optimizer = Adam(self.actor.parameters(),
                                        lr=self.actor_lr)
            self.critic_optimizer = Adam(self.critic.parameters(),
                                         lr=self.critic_lr)
        elif self.optimizer_type == "rmsprop":
            self.actor_optimizer = RMSprop(self.actor.parameters(),
                                           lr=self.actor_lr)
            self.critic_optimizer = RMSprop(self.critic.parameters(),
                                            lr=self.critic_lr)

        if self.use_cuda:
            self.actor.cuda()
            self.critic.cuda()
            self.actor_target.cuda()
            self.critic_target.cuda()

    # the agent interacts with the environment to collect experience
    def interact(self):
        super(PPO, self)._take_n_steps()

    # train on a roll out batch
    def train(self):
        if self.n_episodes <= self.episodes_before_train:
            pass

        batch = self.memory.sample(self.batch_size)
        states_var = to_tensor_var(batch.states,
                                   self.use_cuda).view(-1, self.state_dim)
        one_hot_actions = index_to_one_hot(batch.actions, self.action_dim)
        actions_var = to_tensor_var(one_hot_actions,
                                    self.use_cuda).view(-1, self.action_dim)
        rewards_var = to_tensor_var(batch.rewards, self.use_cuda).view(-1, 1)

        # update actor network
        self.actor_optimizer.zero_grad()
        values = self.critic_target(states_var, actions_var).detach()
        advantages = rewards_var - values
        # # normalizing advantages does not seem to work correctly here
        # advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)
        action_log_probs = self.actor(states_var)
        action_log_probs = th.sum(action_log_probs * actions_var, 1)
        old_action_log_probs = self.actor_target(states_var).detach()
        old_action_log_probs = th.sum(old_action_log_probs * actions_var, 1)
        ratio = th.exp(action_log_probs - old_action_log_probs)
        surr1 = ratio * advantages
        surr2 = th.clamp(ratio, 1.0 - self.clip_param,
                         1.0 + self.clip_param) * advantages
        # PPO's pessimistic surrogate (L^CLIP)
        actor_loss = -th.mean(th.min(surr1, surr2))
        actor_loss.backward()
        if self.max_grad_norm is not None:
            nn.utils.clip_grad_norm(self.actor.parameters(),
                                    self.max_grad_norm)
        self.actor_optimizer.step()

        # update critic network
        self.critic_optimizer.zero_grad()
        target_values = rewards_var
        values = self.critic(states_var, actions_var)
        if self.critic_loss == "huber":
            critic_loss = nn.functional.smooth_l1_loss(values, target_values)
        else:
            critic_loss = nn.MSELoss()(values, target_values)
        critic_loss.backward()
        if self.max_grad_norm is not None:
            nn.utils.clip_grad_norm(self.critic.parameters(),
                                    self.max_grad_norm)
        self.critic_optimizer.step()

        # update actor target network and critic target network
        if self.n_steps % self.target_update_steps == 0 and self.n_steps > 0:
            super(PPO, self)._soft_update_target(self.actor_target, self.actor)
            super(PPO, self)._soft_update_target(self.critic_target,
                                                 self.critic)

    # predict softmax action based on state
    def _softmax_action(self, state):
        state_var = to_tensor_var([state], self.use_cuda)
        softmax_action_var = th.exp(self.actor(state_var))
        if self.use_cuda:
            softmax_action = softmax_action_var.data.cpu().numpy()[0]
        else:
            softmax_action = softmax_action_var.data.numpy()[0]
        return softmax_action

    # choose an action based on state with random noise added for exploration in training
    def exploration_action(self, state):
        softmax_action = self._softmax_action(state)
        epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
                                  np.exp(-1. * self.n_steps / self.epsilon_decay)
        if np.random.rand() < epsilon:
            action = np.random.choice(self.action_dim)
        else:
            action = np.argmax(softmax_action)
        return action

    # choose an action based on state for execution
    def action(self, state):
        softmax_action = self._softmax_action(state)
        action = np.argmax(softmax_action)
        return action

    # evaluate value for a state-action pair
    def value(self, state, action):
        state_var = to_tensor_var([state], self.use_cuda)
        action = index_to_one_hot(action, self.action_dim)
        action_var = to_tensor_var([action], self.use_cuda)
        value_var = self.critic(state_var, action_var)
        if self.use_cuda:
            value = value_var.data.cpu().numpy()[0]
        else:
            value = value_var.data.numpy()[0]
        return value
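
# --- Illustrative sketch (not part of the original example) ---
# index_to_one_hot and to_tensor_var are helpers imported from elsewhere in the
# code base. A minimal sketch of what index_to_one_hot could look like is
# below; the real helper's handling of batched indices may differ.
import numpy as np


def index_to_one_hot_sketch(index, dim):
    """Sketch: turn an action index (or a sequence of indices) into one-hot vectors."""
    indices = np.atleast_1d(np.asarray(index, dtype=np.int64))
    one_hot = np.zeros((indices.size, dim), dtype=np.float32)
    one_hot[np.arange(indices.size), indices] = 1.0
    return one_hot if one_hot.shape[0] > 1 else one_hot[0]
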
Example n. 26
class DQN_Model_Agent:

    def __init__(self, args, exp_model, logging_func):
        self.args = args

        # Exploration Model
        self.exp_model = exp_model

        self.log = logging_func["log"]
        self.log_image = logging_func["image"]
        os.makedirs("{}/transition_model".format(args.log_path))

        # Experience Replay
        self.replay = ExpReplay(args.exp_replay_size, args.stale_limit, exp_model, args, priority=self.args.prioritized)

        # DQN and Target DQN
        model = get_models(args.model)
        print("\n\nDQN")
        self.dqn = model(actions=args.actions)
        print("Target DQN")
        self.target_dqn = model(actions=args.actions)

        dqn_params = 0
        for weight in self.dqn.parameters():
            weight_params = 1
            for s in weight.size():
                weight_params *= s
            dqn_params += weight_params
        print("Model DQN has {:,} parameters.".format(dqn_params))

        self.target_dqn.eval()

        if args.gpu:
            print("Moving models to GPU.")
            self.dqn.cuda()
            self.target_dqn.cuda()

        # Optimizer
        # self.optimizer = Adam(self.dqn.parameters(), lr=args.lr)
        self.optimizer = RMSprop(self.dqn.parameters(), lr=args.lr)

        self.T = 0
        self.target_sync_T = -self.args.t_max

        # Action sequences
        self.actions_to_take = []

    def sync_target_network(self):
        for target, source in zip(self.target_dqn.parameters(), self.dqn.parameters()):
            target.data = source.data


    def get_pc_estimates(self, root_state, depth=0, starts=None):
        state = root_state
        bonuses = []
        for action in range(self.args.actions):

            # Current pc estimates
            if depth == 0 or not self.args.only_leaf:
                numpy_state = state[0].numpy().swapaxes(0, 2)
                _, info = self.exp_model.bonus(numpy_state, action, dont_remember=True)
                action_pseudo_count = info["Pseudo_Count"]
                action_bonus = self.args.optimistic_scaler / np.power(action_pseudo_count + 0.01, self.args.bandit_p)
                if starts is not None:
                    action_bonus += starts[action]

            # If the depth is 0 we don't want to look any further ahead
            if depth == 0:
                bonuses.append(action_bonus)
                continue

            one_hot_action = torch.zeros(1, self.args.actions)
            one_hot_action[0, action] = 1
            _, next_state_prediction = self.dqn(Variable(state, volatile=True), Variable(one_hot_action, volatile=True))
            next_state_prediction = next_state_prediction.cpu().data

            next_state_pc_estimates = self.get_pc_estimates(next_state_prediction, depth=depth - 1)

            if self.args.only_leaf:
                bonuses += next_state_pc_estimates
            else:
                ahead_pc_estimates = [action_bonus + self.args.gamma * n for n in next_state_pc_estimates]
                bonuses += ahead_pc_estimates

        return bonuses

    def act(self, state, epsilon, exp_model, evaluation=False):
        # self.T += 1
        if not evaluation:
            if len(self.actions_to_take) > 0:
                action_to_take = self.actions_to_take[0]
                self.actions_to_take = self.actions_to_take[1:]
                return action_to_take, {"Action": action_to_take, "Q_Values": self.prev_q_vals}

        self.dqn.eval()
        # orig_state = state[:, :, -1:]
        state = torch.from_numpy(state).float().transpose_(0, 2).unsqueeze(0)
        q_values = self.dqn(Variable(state, volatile=True)).cpu().data[0]
        q_values_numpy = q_values.numpy()
        self.prev_q_vals = q_values_numpy

        extra_info = {}

        if self.args.optimistic_init and not evaluation and len(self.actions_to_take) == 0:

            # 2 action lookahead
            action_bonuses = self.get_pc_estimates(state, depth=self.args.lookahead_depth, starts=q_values_numpy)

            # Find the maximum sequence 
            max_so_far = -100000
            best_index = 0
            best_seq = []

            for ii, bonus in enumerate(action_bonuses):
                if bonus > max_so_far:
                    best_index = ii
                    max_so_far = bonus

            for depth in range(self.args.lookahead_depth):
                last_action = best_index % self.args.actions
                best_index = best_index // self.args.actions
                best_seq = best_seq + [last_action]

            # print(best_seq)
            self.actions_to_take = best_seq

        extra_info["Q_Values"] = q_values_numpy

        if np.random.random() < epsilon:
            action = np.random.randint(low=0, high=self.args.actions)
        else:
            action = q_values.max(0)[1][0]  # Torch...

        extra_info["Action"] = action

        return action, extra_info

    def experience(self, state, action, reward, state_next, steps, terminated, pseudo_reward=0, density=1, exploring=False):
        if not exploring:
            self.T += 1
        self.replay.Add_Exp(state, action, reward, state_next, steps, terminated, pseudo_reward, density)

    def end_of_trajectory(self):
        self.replay.end_of_trajectory()

    def train(self):

        if self.T - self.target_sync_T > self.args.target:
            self.sync_target_network()
            self.target_sync_T = self.T

        info = {}

        for _ in range(self.args.iters):
            self.dqn.eval()

            # TODO: Use a named tuple for experience replay
            n_step_sample = self.args.n_step
            batch, indices, is_weights = self.replay.Sample_N(self.args.batch_size, n_step_sample, self.args.gamma)
            columns = list(zip(*batch))

            states = Variable(torch.from_numpy(np.array(columns[0])).float().transpose_(1, 3))
            actions = Variable(torch.LongTensor(columns[1]))
            terminal_states = Variable(torch.FloatTensor(columns[5]))
            rewards = Variable(torch.FloatTensor(columns[2]))
            # Have to clip rewards for DQN
            rewards = torch.clamp(rewards, -1, 1)
            steps = Variable(torch.FloatTensor(columns[4]))
            new_states = Variable(torch.from_numpy(np.array(columns[3])).float().transpose_(1, 3))

            target_dqn_qvals = self.target_dqn(new_states).cpu()
            # Make a new variable with those values so that these are treated as constants
            target_dqn_qvals_data = Variable(target_dqn_qvals.data)

            q_value_targets = (Variable(torch.ones(terminal_states.size()[0])) - terminal_states)
            inter = Variable(torch.ones(terminal_states.size()[0]) * self.args.gamma)
            # print(steps)
            q_value_targets = q_value_targets * torch.pow(inter, steps)
            if self.args.double:
                # Double Q Learning
                new_states_qvals = self.dqn(new_states).cpu()
                new_states_qvals_data = Variable(new_states_qvals.data)
                q_value_targets = q_value_targets * target_dqn_qvals_data.gather(1, new_states_qvals_data.max(1)[1])
            else:
                q_value_targets = q_value_targets * target_dqn_qvals_data.max(1)[0]
            q_value_targets = q_value_targets + rewards

            self.dqn.train()

            one_hot_actions = torch.zeros(self.args.batch_size, self.args.actions)

            for i in range(self.args.batch_size):
                one_hot_actions[i][actions[i].data] = 1

            if self.args.gpu:
                actions = actions.cuda()
                one_hot_actions = one_hot_actions.cuda()
                q_value_targets = q_value_targets.cuda()
                new_states = new_states.cuda()
            model_predictions_q_vals, model_predictions_state = self.dqn(states, Variable(one_hot_actions))
            model_predictions = model_predictions_q_vals.gather(1, actions.view(-1, 1))

            # info = {}

            td_error = model_predictions - q_value_targets
            info["TD_Error"] = td_error.mean().data[0]

            # Update the priorities
            if not self.args.density_priority:
                self.replay.Update_Indices(indices, td_error.cpu().data.numpy(), no_pseudo_in_priority=self.args.count_td_priority)

            # If using prioritised we need to weight the td_error
            if self.args.prioritized and self.args.prioritized_is:
                # print(td_error)
                weights_tensor = torch.from_numpy(is_weights).float()
                weights_tensor = Variable(weights_tensor)
                if self.args.gpu:
                    weights_tensor = weights_tensor.cuda()
                # print(weights_tensor)
                td_error = td_error * weights_tensor

            # Model 1 step state transition error

            # Save them every x steps
            if self.T % self.args.model_save_image == 0:
                os.makedirs("{}/transition_model/{}".format(self.args.log_path, self.T))
                for ii, image, action, next_state, current_state in zip(range(self.args.batch_size), model_predictions_state.cpu().data, actions.data, new_states.cpu().data, states.cpu().data):
                    image = image.numpy()[0]
                    image = np.clip(image, 0, 1)
                    # print(next_state)
                    next_state = next_state.numpy()[0]
                    current_state = current_state.numpy()[0]

                    black_bars = np.zeros_like(next_state[:1, :])
                    # print(black_bars.shape)

                    joined_image = np.concatenate((current_state, black_bars, image, black_bars, next_state), axis=0)
                    joined_image = np.transpose(joined_image)
                    self.log_image("{}/transition_model/{}/{}_____Action_{}".format(self.args.log_path, self.T, ii + 1, action), joined_image * 255)

                    # self.log_image("{}/transition_model/{}/{}_____Action_{}".format(self.args.log_path, self.T, ii + 1, action), image * 255)
                    # self.log_image("{}/transition_model/{}/{}_____Correct".format(self.args.log_path, self.T, ii + 1), next_state * 255)

            # print(model_predictions_state)

            # Cross Entropy Loss
            # TODO

            # Regression loss
            state_error = model_predictions_state - new_states
            # state_error_val = state_error.mean().data[0]

            info["State_Error"] = state_error.mean().data[0]
            self.log("DQN/State_Loss", state_error.mean().data[0], step=self.T)
            self.log("DQN/State_Loss_Squared", state_error.pow(2).mean().data[0], step=self.T)
            self.log("DQN/State_Loss_Max", state_error.abs().max().data[0], step=self.T)
            # self.log("DQN/Action_Matrix_Norm", self.dqn.action_matrix.weight.norm().cpu().data[0], step=self.T)

            combined_loss = (1 - self.args.model_loss) * td_error.pow(2).mean() + (self.args.model_loss) * state_error.pow(2).mean()
            l2_loss = combined_loss
            # l2_loss = (combined_loss).pow(2).mean()
            info["Loss"] = l2_loss.data[0]

            # Update
            self.optimizer.zero_grad()
            l2_loss.backward()

            # Taken from pytorch clip_grad_norm
            # Remove once the pip version is up to date with source
            gradient_norm = clip_grad_norm(self.dqn.parameters(), self.args.clip_value)
            if gradient_norm is not None:
                info["Norm"] = gradient_norm

            self.optimizer.step()

            if "States" in info:
                states_trained = info["States"]
                info["States"] = states_trained + columns[0]
            else:
                info["States"] = columns[0]

        # Pad out the states to be of size batch_size (repeat the first state as filler)
        if len(info["States"]) < self.args.batch_size:
            old_states = list(info["States"])
            padding = [old_states[0]] * (self.args.batch_size - len(old_states))
            info["States"] = old_states + padding

        return info
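
# --- Illustrative sketch (not part of the original example) ---
# The per-parameter copy in sync_target_network above is a hard target update;
# an equivalent formulation via load_state_dict (with stand-in modules, purely
# for illustration) is:
import torch.nn as nn

policy_net_sketch = nn.Linear(4, 2)   # stand-in for the online DQN
target_net_sketch = nn.Linear(4, 2)   # stand-in for the target DQN

# hard update: copy every parameter and buffer from the online net
target_net_sketch.load_state_dict(policy_net_sketch.state_dict())
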
Example n. 27
class SLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.n_actions_levin = args.n_actions

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = QMixer(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.params += list(self.mixer.parameters())

            if not args.SubAVG_Mixer_flag:
                self.target_mixer = copy.deepcopy(self.mixer)

            elif args.mixer == "qmix":
                self.target_mixer_list = []
                for i in range(self.args.SubAVG_Mixer_K):
                    self.target_mixer_list.append(copy.deepcopy(self.mixer))
                self.levin_iter_target_mixer_update = 0

        self.optimiser = RMSprop(params=self.params,
                                 lr=args.lr,
                                 alpha=args.optim_alpha,
                                 eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        if not self.args.SubAVG_Agent_flag:
            self.target_mac = copy.deepcopy(mac)
        else:
            self.target_mac_list = []
            for i in range(self.args.SubAVG_Agent_K):
                self.target_mac_list.append(copy.deepcopy(mac))
            self.levin_iter_target_update = 0

        self.log_stats_t = -self.args.learner_log_interval - 1

        # ====== levin =====
        self.number = 0

    def train(self,
              batch: EpisodeBatch,
              t_env: int,
              episode_num: int,
              epsilon_levin=None):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3,
                                        index=actions).squeeze(3)

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        if not self.args.SubAVG_Agent_flag:
            self.target_mac.init_hidden(batch.batch_size)
        else:
            for i in range(self.args.SubAVG_Agent_K):
                self.target_mac_list[i].init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            if not self.args.SubAVG_Agent_flag:
                target_agent_outs = self.target_mac.forward(batch, t=t)
            # exp: use the average-DQN (SubAVG) target MACs instead of a single target_mac
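            # Note: SubAVG keeps K snapshots of the target MAC. Their outputs are averaged
            # below; optionally only the per-element values above (flag_select > 0) or below
            # the average are kept, with discarded entries replaced by the average ('mean')
            # or dropped and renormalised by the number of kept entries ('zero').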
            else:
                target_agent_outs = 0

                self.target_agent_out_list = []
                for i in range(self.args.SubAVG_Agent_K):
                    target_agent_out = self.target_mac_list[i].forward(batch,
                                                                       t=t)
                    target_agent_outs = target_agent_outs + target_agent_out
                    if self.args.SubAVG_Agent_flag_select:
                        self.target_agent_out_list.append(target_agent_out)
                target_agent_outs = target_agent_outs / self.args.SubAVG_Agent_K
                if self.args.SubAVG_Agent_flag_select:
                    if self.args.SubAVG_Agent_name_select_replacement == 'mean':
                        target_out_select_sum = 0
                        for i in range(self.args.SubAVG_Agent_K):
                            if self.args.SubAVG_Agent_flag_select > 0:
                                target_out_select = th.where(
                                    self.target_agent_out_list[i] <
                                    target_agent_outs, target_agent_outs,
                                    self.target_agent_out_list[i])
                            else:
                                target_out_select = th.where(
                                    self.target_agent_out_list[i] >
                                    target_agent_outs, target_agent_outs,
                                    self.target_agent_out_list[i])
                            target_out_select_sum = target_out_select_sum + target_out_select
                        target_agent_outs = target_out_select_sum / self.args.SubAVG_Agent_K
                    elif self.args.SubAVG_Agent_name_select_replacement == 'zero':
                        target_out_select_sum = 0
                        target_select_bool_sum = 0
                        for i in range(self.args.SubAVG_Agent_K):
                            if self.args.SubAVG_Agent_flag_select > 0:
                                target_select_bool = (
                                    self.target_agent_out_list[i] >
                                    target_agent_outs).float()
                                target_out_select = th.where(
                                    self.target_agent_out_list[i] >
                                    target_agent_outs,
                                    self.target_agent_out_list[i],
                                    th.full_like(target_agent_outs, 0))
                            else:
                                target_select_bool = (
                                    self.target_agent_out_list[i] <
                                    target_agent_outs).float()
                                target_out_select = th.where(
                                    self.target_agent_out_list[i] <
                                    target_agent_outs,
                                    self.target_agent_out_list[i],
                                    th.full_like(target_agent_outs, 0))
                            target_select_bool_sum = target_select_bool_sum + target_select_bool
                            target_out_select_sum = target_out_select_sum + target_out_select
                        if self.levin_iter_target_update < 2:
                            pass  # print("using average directly")
                        else:
                            target_agent_outs = target_out_select_sum / target_select_bool_sum
            target_mac_out.append(target_agent_outs)

        # Keep every timestep's target estimate; the TD(lambda)/n-step targets below use the full sequence
        target_mac_out = th.stack(target_mac_out, dim=1)  # Concat across time

        # Pick the target Q-Values for the actions actually taken by each agent
        target_chosen_action_qvals = th.gather(target_mac_out, 3,
                                               batch['actions']).squeeze(-1)

        # Mix
        if self.mixer is None:
            target_qvals = target_chosen_action_qvals
        else:
            chosen_action_qvals = self.mixer(chosen_action_qvals,
                                             batch["state"][:, :-1])
            if not self.args.SubAVG_Mixer_flag:
                target_qvals = self.target_mixer(target_chosen_action_qvals,
                                                 batch['state'])
            elif self.args.mixer == "qmix":
                target_max_qvals_sum = 0
                self.target_mixer_out_list = []
                for i in range(self.args.SubAVG_Mixer_K):
                    target_mixer_out = self.target_mixer_list[i](
                        target_chosen_action_qvals, batch['state'])
                    target_max_qvals_sum = target_max_qvals_sum + target_mixer_out
                    if self.args.SubAVG_Mixer_flag_select:
                        self.target_mixer_out_list.append(target_mixer_out)
                target_max_qvals = target_max_qvals_sum / self.args.SubAVG_Mixer_K

                # levin: mixer select
                if self.args.SubAVG_Mixer_flag_select:
                    if self.args.SubAVG_Mixer_name_select_replacement == 'mean':
                        target_mixer_select_sum = 0
                        for i in range(self.args.SubAVG_Mixer_K):
                            if self.args.SubAVG_Mixer_flag_select > 0:
                                target_mixer_select = th.where(
                                    self.target_mixer_out_list[i] <
                                    target_max_qvals, target_max_qvals,
                                    self.target_mixer_out_list[i])
                            else:
                                target_mixer_select = th.where(
                                    self.target_mixer_out_list[i] >
                                    target_max_qvals, target_max_qvals,
                                    self.target_mixer_out_list[i])
                            target_mixer_select_sum = target_mixer_select_sum + target_mixer_select
                        target_max_qvals = target_mixer_select_sum / self.args.SubAVG_Mixer_K
                    elif self.args.SubAVG_Mixer_name_select_replacement == 'zero':
                        target_mixer_select_sum = 0
                        target_mixer_select_bool_sum = 0
                        for i in range(self.args.SubAVG_Mixer_K):
                            if self.args.SubAVG_Mixer_flag_select > 0:
                                target_mixer_select_bool = (
                                    self.target_mixer_out_list[i] >
                                    target_max_qvals).float()
                                target_mixer_select = th.where(
                                    self.target_mixer_out_list[i] >
                                    target_max_qvals,
                                    self.target_mixer_out_list[i],
                                    th.full_like(target_max_qvals, 0))
                            else:
                                target_mixer_select_bool = (
                                    self.target_mixer_out_list[i] <
                                    target_max_qvals).float()
                                target_mixer_select = th.where(
                                    self.target_mixer_out_list[i] <
                                    target_max_qvals,
                                    self.target_mixer_out_list[i],
                                    th.full_like(target_max_qvals, 0))
                            target_mixer_select_bool_sum = target_mixer_select_bool_sum + target_mixer_select_bool
                            target_mixer_select_sum = target_mixer_select_sum + target_mixer_select
                        if self.levin_iter_target_mixer_update < 2:
                            pass  # print("using average-mix directly")
                        else:
                            target_max_qvals = target_mixer_select_sum / target_mixer_select_bool_sum
                target_qvals = target_max_qvals

        if self.args.td_lambda <= 1 and self.args.td_lambda > 0:
            targets = build_td_lambda_targets(rewards, terminated, mask,
                                              target_qvals, self.args.n_agents,
                                              self.args.gamma,
                                              self.args.td_lambda)
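            # Note: build_td_lambda_targets is assumed to compute the usual recursive
            # TD(lambda) target,
            #   G_t = r_t + gamma * ((1 - lambda) * Q_target(s_{t+1}, a_{t+1}) + lambda * G_{t+1}),
            # masked so that padded timesteps contribute nothing.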
        else:
            if self.args.td_lambda == 0:
                n = 1  # 1-step TD
            else:
                n = self.args.td_lambda

            targets = th.zeros_like(batch['reward'])
            targets += batch['reward']

            for i in range(1, n):
                targets[:, :-i] += (self.args.gamma**i) * (
                    1 - terminated[:, i - 1:]) * batch['reward'][:, i:]
            targets[:, :-n] += (self.args.gamma**n) * (
                1 - terminated[:, n - 1:]) * target_qvals[:, n:]

            targets = targets[:, :-1]
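            # Note: the loop above builds truncated n-step returns,
            #   target_t ~ r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1}
            #              + gamma^n * Q_target(s_{t+n}, a_{t+n}),
            # with future terms zeroed once `terminated` fires; the final slice drops the
            # last step so the shape matches chosen_action_qvals.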

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask
        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error**2).sum() / mask.sum() * 2

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params,
                                                self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            # self.logger.log_stat("loss_levin", loss_levin.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs",
                (masked_td_error.abs().sum().item() / mask_elems), t_env)
            self.logger.log_stat("q_taken_mean",
                                 (chosen_action_qvals * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        if not self.args.SubAVG_Agent_flag:
            self.target_mac.load_state(self.mac)
        else:
            self.number = self.levin_iter_target_update % self.args.SubAVG_Agent_K
            self.target_mac_list[self.number].load_state(self.mac)
            self.levin_iter_target_update = self.levin_iter_target_update + 1

        if self.mixer is not None:
            if not self.args.SubAVG_Mixer_flag:
                self.target_mixer.load_state_dict(self.mixer.state_dict())
            elif self.args.mixer == "qmix":
                mixer_number = self.levin_iter_target_mixer_update % self.args.SubAVG_Mixer_K
                self.target_mixer_list[mixer_number].load_state_dict(
                    self.mixer.state_dict())
                self.levin_iter_target_mixer_update = self.levin_iter_target_mixer_update + 1
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        if not self.args.SubAVG_Agent_flag:
            self.target_mac.cuda()
        else:
            for i in range(self.args.SubAVG_Agent_K):
                self.target_mac_list[i].cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            if not self.args.SubAVG_Mixer_flag:
                self.target_mixer.cuda()
            elif self.args.mixer == "qmix":
                for i in range(self.args.SubAVG_Mixer_K):
                    self.target_mixer_list[i].cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        if not self.args.SubAVG_Agent_flag:
            self.target_mac.load_models(path)
        else:
            for i in range(self.args.SubAVG_Agent_K):
                self.target_mac_list[i].load_models(path)

        if self.mixer is not None:
            self.mixer.load_state_dict(
                th.load("{}/mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            th.load("{}/opt.th".format(path),
                    map_location=lambda storage, loc: storage))
Esempio n. 28
0
class A2C(Agent):
    """
    An agent learned with Advantage Actor-Critic
    - Actor takes state as input
    - Critic takes both state and action as input
    - agent interact with environment to collect experience
    - agent training with experience to update policy
    """
    def __init__(self,
                 env,
                 state_dim,
                 action_dim,
                 memory_capacity=10000,
                 max_steps=None,
                 roll_out_n_steps=10,
                 reward_gamma=0.99,
                 reward_scale=1.,
                 done_penalty=None,
                 actor_hidden_size=32,
                 critic_hidden_size=32,
                 actor_output_act=nn.functional.log_softmax,
                 critic_loss="mse",
                 actor_lr=0.001,
                 critic_lr=0.001,
                 optimizer_type="rmsprop",
                 entropy_reg=0.01,
                 max_grad_norm=0.5,
                 batch_size=100,
                 episodes_before_train=100,
                 epsilon_start=0.9,
                 epsilon_end=0.01,
                 epsilon_decay=200,
                 use_cuda=True):
        super(A2C,
              self).__init__(env, state_dim, action_dim, memory_capacity,
                             max_steps, reward_gamma, reward_scale,
                             done_penalty, actor_hidden_size,
                             critic_hidden_size, actor_output_act, critic_loss,
                             actor_lr, critic_lr, optimizer_type, entropy_reg,
                             max_grad_norm, batch_size, episodes_before_train,
                             epsilon_start, epsilon_end, epsilon_decay,
                             use_cuda)

        self.roll_out_n_steps = roll_out_n_steps

        self.actor = ActorNetwork(self.state_dim, self.actor_hidden_size,
                                  self.action_dim, self.actor_output_act)
        self.critic = CriticNetwork(self.state_dim, self.action_dim,
                                    self.critic_hidden_size, 1)
        if self.optimizer_type == "adam":
            self.actor_optimizer = Adam(self.actor.parameters(),
                                        lr=self.actor_lr)
            self.critic_optimizer = Adam(self.critic.parameters(),
                                         lr=self.critic_lr)
        elif self.optimizer_type == "rmsprop":
            self.actor_optimizer = RMSprop(self.actor.parameters(),
                                           lr=self.actor_lr)
            self.critic_optimizer = RMSprop(self.critic.parameters(),
                                            lr=self.critic_lr)
        if self.use_cuda:
            self.actor.cuda()
            self.critic.cuda()

    # agent interact with the environment to collect experience
    def interact(self):
        super(A2C, self)._take_n_steps()

    # train on a roll out batch
    def train(self):
        if self.n_episodes <= self.episodes_before_train:
            return  # skip training until enough episodes have been collected

        batch = self.memory.sample(self.batch_size)
        states_var = to_tensor_var(batch.states,
                                   self.use_cuda).view(-1, self.state_dim)
        one_hot_actions = index_to_one_hot(batch.actions, self.action_dim)
        actions_var = to_tensor_var(one_hot_actions,
                                    self.use_cuda).view(-1, self.action_dim)
        rewards_var = to_tensor_var(batch.rewards, self.use_cuda).view(-1, 1)

        # update actor network
        self.actor_optimizer.zero_grad()
        action_log_probs = self.actor(states_var)
        entropy_loss = th.mean(entropy(th.exp(action_log_probs)))
        action_log_probs = th.sum(action_log_probs * actions_var, 1)
        values = self.critic(states_var, actions_var)
        advantages = rewards_var - values.detach()
        pg_loss = -th.mean(action_log_probs * advantages)
        actor_loss = pg_loss - entropy_loss * self.entropy_reg
        actor_loss.backward()
        if self.max_grad_norm is not None:
            nn.utils.clip_grad_norm_(self.actor.parameters(),
                                     self.max_grad_norm)
        self.actor_optimizer.step()

        # update critic network
        self.critic_optimizer.zero_grad()
        target_values = rewards_var
        if self.critic_loss == "huber":
            critic_loss = nn.functional.smooth_l1_loss(values, target_values)
        else:
            critic_loss = nn.MSELoss()(values, target_values)
        critic_loss.backward()
        if self.max_grad_norm is not None:
            nn.utils.clip_grad_norm_(self.critic.parameters(),
                                     self.max_grad_norm)
        self.critic_optimizer.step()

    # predict softmax action based on state
    def _softmax_action(self, state):
        state_var = to_tensor_var([state], self.use_cuda)
        softmax_action_var = th.exp(self.actor(state_var))
        if self.use_cuda:
            softmax_action = softmax_action_var.data.cpu().numpy()[0]
        else:
            softmax_action = softmax_action_var.data.numpy()[0]
        return softmax_action

    # choose an action based on state with random noise added for exploration in training
    def exploration_action(self, state):
        softmax_action = self._softmax_action(state)
        epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
                                  np.exp(-1. * self.n_steps / self.epsilon_decay)
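        # Exponential decay: epsilon equals epsilon_start at n_steps = 0 and decays
        # towards epsilon_end with time constant epsilon_decay.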
        if np.random.rand() < epsilon:
            action = np.random.choice(self.action_dim)
        else:
            action = np.argmax(softmax_action)
        return action

    # choose an action based on state for execution
    def action(self, state):
        softmax_action = self._softmax_action(state)
        action = np.argmax(softmax_action)
        return action

    # evaluate value for a state-action pair
    def value(self, state, action):
        state_var = to_tensor_var([state], self.use_cuda)
        action = index_to_one_hot(action, self.action_dim)
        action_var = to_tensor_var([action], self.use_cuda)
        value_var = self.critic(state_var, action_var)
        if self.use_cuda:
            value = value_var.data.cpu().numpy()[0]
        else:
            value = value_var.data.numpy()[0]
        return value
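
A minimal usage sketch of the loop described in the A2C docstring above, assuming a Gym-style discrete-action environment ("CartPole-v1" is only an illustration) and that the base Agent exposes the n_episodes and episodes_before_train attributes referenced in train():

import gym

env = gym.make("CartPole-v1")                    # assumed discrete-action environment
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = A2C(env, state_dim, action_dim, use_cuda=False)

for _ in range(2000):
    agent.interact()                             # roll out roll_out_n_steps transitions into memory
    if agent.n_episodes > agent.episodes_before_train:
        agent.train()                            # one actor and one critic update from a sampled batch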
Esempio n. 29
0
class OffPGLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.mac = mac
        self.logger = logger

        self.last_target_update_step = 0
        self.critic_training_steps = 0

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.critic = OffPGCritic(scheme, args)
        self.mixer = QMixer(args)
        self.target_critic = copy.deepcopy(self.critic)
        self.target_mixer = copy.deepcopy(self.mixer)

        self.agent_params = list(mac.parameters())
        self.critic_params = list(self.critic.parameters())
        self.mixer_params = list(self.mixer.parameters())
        self.params = self.agent_params + self.critic_params
        self.c_params = self.critic_params + self.mixer_params

        self.agent_optimiser = RMSprop(params=self.agent_params,
                                       lr=args.lr,
                                       alpha=args.optim_alpha,
                                       eps=args.optim_eps)
        self.critic_optimiser = RMSprop(params=self.critic_params,
                                        lr=args.critic_lr,
                                        alpha=args.optim_alpha,
                                        eps=args.optim_eps)
        self.mixer_optimiser = RMSprop(params=self.mixer_params,
                                       lr=args.critic_lr,
                                       alpha=args.optim_alpha,
                                       eps=args.optim_eps)

    def train(self, batch: EpisodeBatch, t_env: int, log):
        # Get the relevant quantities
        bs = batch.batch_size
        max_t = batch.max_seq_length
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        avail_actions = batch["avail_actions"][:, :-1]
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        mask = mask.repeat(1, 1, self.n_agents).view(-1)
        states = batch["state"][:, :-1]

        #build q
        inputs = self.critic._build_inputs(batch, bs, max_t)
        q_vals = self.critic.forward(inputs).detach()[:, :-1]

        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length - 1):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out[avail_actions == 0] = 0
        mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True)
        mac_out[avail_actions == 0] = 0

        # Calculate the baseline
        q_taken = th.gather(q_vals, dim=3, index=actions).squeeze(3)
        pi = mac_out.view(-1, self.n_actions)
        baseline = th.sum(mac_out * q_vals, dim=-1).view(-1).detach()

        # Calculate policy grad with mask
        pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1,
                                                              1)).squeeze(1)
        pi_taken[mask == 0] = 1.0
        log_pi_taken = th.log(pi_taken)
        coe = self.mixer.k(states).view(-1)

        advantages = (q_taken.view(-1) - baseline).detach()

        coma_loss = -(
            (coe * advantages * log_pi_taken) * mask).sum() / mask.sum()
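        # COMA-style counterfactual policy gradient: the baseline is the policy-expected
        # Q-value per agent, the advantage is Q(s, a_taken) - baseline, and coe = mixer.k(states)
        # weights each agent's contribution to the joint value.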

        # Optimise agents
        self.agent_optimiser.zero_grad()
        coma_loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params,
                                                self.args.grad_norm_clip)
        self.agent_optimiser.step()

        #compute parameters sum for debugging
        p_sum = 0.
        for p in self.agent_params:
            p_sum += p.data.abs().sum().item() / 100.0

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            ts_logged = len(log["critic_loss"])
            for key in [
                    "critic_loss", "critic_grad_norm", "td_error_abs",
                    "q_taken_mean", "target_mean", "q_max_mean", "q_min_mean",
                    "q_max_var", "q_min_var"
            ]:
                self.logger.log_stat(key, sum(log[key]) / ts_logged, t_env)
            self.logger.log_stat("q_max_first", log["q_max_first"], t_env)
            self.logger.log_stat("q_min_first", log["q_min_first"], t_env)
            #self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env)
            self.logger.log_stat("coma_loss", coma_loss.item(), t_env)
            self.logger.log_stat("agent_grad_norm", grad_norm, t_env)
            self.logger.log_stat("pi_max",
                                 (pi.max(dim=1)[0] * mask).sum().item() /
                                 mask.sum().item(), t_env)
            self.log_stats_t = t_env

    def train_critic(self, on_batch, best_batch=None, log=None):
        bs = on_batch.batch_size
        max_t = on_batch.max_seq_length
        rewards = on_batch["reward"][:, :-1]
        actions = on_batch["actions"][:, :]
        terminated = on_batch["terminated"][:, :-1].float()
        mask = on_batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = on_batch["avail_actions"][:]
        states = on_batch["state"]

        #build_target_q
        target_inputs = self.target_critic._build_inputs(on_batch, bs, max_t)
        target_q_vals = self.target_critic.forward(target_inputs).detach()
        targets_taken = self.target_mixer(
            th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states)
        target_q = build_td_lambda_targets(rewards, terminated, mask,
                                           targets_taken, self.n_agents,
                                           self.args.gamma,
                                           self.args.td_lambda).detach()

        inputs = self.critic._build_inputs(on_batch, bs, max_t)

        mac_out = []
        self.mac.init_hidden(bs)
        for i in range(max_t):
            agent_outs = self.mac.forward(on_batch, t=i)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1).detach()
        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out[avail_actions == 0] = 0
        mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True)
        mac_out[avail_actions == 0] = 0

        if best_batch is not None:
            best_target_q, best_inputs, best_mask, best_actions, best_mac_out = self.train_critic_best(
                best_batch)
            log["best_reward"] = th.mean(
                best_batch["reward"][:, :-1].squeeze(2).sum(-1), dim=0)
            target_q = th.cat((target_q, best_target_q), dim=0)
            inputs = th.cat((inputs, best_inputs), dim=0)
            mask = th.cat((mask, best_mask), dim=0)
            actions = th.cat((actions, best_actions), dim=0)
            states = th.cat((states, best_batch["state"]), dim=0)
            mac_out = th.cat((mac_out, best_mac_out), dim=0)

        #train critic
        mac_out = mac_out.detach()
        for t in range(max_t - 1):
            mask_t = mask[:, t:t + 1]
            if mask_t.sum() < 0.5:
                continue
            k = self.mixer.k(states[:, t:t + 1]).unsqueeze(3)
            #b = self.mixer.b(states[:, t:t+1])
            q_vals = self.critic.forward(inputs[:, t:t + 1])
            q_ori = q_vals
            q_vals = th.gather(q_vals, 3, index=actions[:, t:t + 1]).squeeze(3)
            q_vals = self.mixer.forward(q_vals, states[:, t:t + 1])
            target_q_t = target_q[:, t:t + 1].detach()
            q_err = (q_vals - target_q_t) * mask_t
            critic_loss = (q_err**2).sum() / mask_t.sum()
            # Auxiliary loss on the individual Q_i values (goal_loss; currently not added to critic_loss)
            v_vals = th.sum(q_ori * mac_out[:, t:t + 1], dim=3, keepdim=True)
            ad_vals = q_ori - v_vals
            goal = th.sum(k * v_vals, dim=2, keepdim=True) + k * ad_vals
            goal_err = (goal - q_ori) * mask_t
            goal_loss = 0.1 * (goal_err**
                               2).sum() / mask_t.sum() / self.args.n_actions
            #critic_loss += goal_loss
            self.critic_optimiser.zero_grad()
            self.mixer_optimiser.zero_grad()
            critic_loss.backward()
            grad_norm = th.nn.utils.clip_grad_norm_(self.c_params,
                                                    self.args.grad_norm_clip)
            self.critic_optimiser.step()
            self.mixer_optimiser.step()
            self.critic_training_steps += 1

            log["critic_loss"].append(critic_loss.item())
            log["critic_grad_norm"].append(grad_norm)
            mask_elems = mask_t.sum().item()
            log["td_error_abs"].append((q_err.abs().sum().item() / mask_elems))
            log["target_mean"].append(
                (target_q_t * mask_t).sum().item() / mask_elems)
            log["q_taken_mean"].append(
                (q_vals * mask_t).sum().item() / mask_elems)
            log["q_max_mean"].append(
                (th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) *
                 mask_t).sum().item() / mask_elems)
            log["q_min_mean"].append(
                (th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) *
                 mask_t).sum().item() / mask_elems)
            log["q_max_var"].append(
                (th.var(q_ori.max(dim=3)[0], dim=2, keepdim=True) *
                 mask_t).sum().item() / mask_elems)
            log["q_min_var"].append(
                (th.var(q_ori.min(dim=3)[0], dim=2, keepdim=True) *
                 mask_t).sum().item() / mask_elems)

            if (t == 0):
                log["q_max_first"] = (
                    th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) *
                    mask_t).sum().item() / mask_elems
                log["q_min_first"] = (
                    th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) *
                    mask_t).sum().item() / mask_elems

        #update target network
        if (self.critic_training_steps - self.last_target_update_step
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_step = self.critic_training_steps

    def train_critic_best(self, batch):
        bs = batch.batch_size
        max_t = batch.max_seq_length
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"][:]
        states = batch["state"]

        # pr for all actions of the episode
        mac_out = []
        self.mac.init_hidden(bs)
        for i in range(max_t):
            agent_outs = self.mac.forward(batch, t=i)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1).detach()
        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out[avail_actions == 0] = 0
        mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True)
        mac_out[avail_actions == 0] = 0
        critic_mac = th.gather(mac_out, 3,
                               actions).squeeze(3).prod(dim=2, keepdim=True)

        #target_q take
        target_inputs = self.target_critic._build_inputs(batch, bs, max_t)
        target_q_vals = self.target_critic.forward(target_inputs).detach()
        targets_taken = self.target_mixer(
            th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states)

        #expected q
        exp_q = self.build_exp_q(target_q_vals, mac_out, states).detach()
        # td-error
        targets_taken[:, -1] = targets_taken[:, -1] * (1 - th.sum(terminated, dim=1))
        exp_q[:, -1] = exp_q[:, -1] * (1 - th.sum(terminated, dim=1))
        targets_taken[:, :-1] = targets_taken[:, :-1] * mask
        exp_q[:, :-1] = exp_q[:, :-1] * mask
        td_q = (rewards + self.args.gamma * exp_q[:, 1:] -
                targets_taken[:, :-1]) * mask

        #compute target
        target_q = build_target_q(td_q, targets_taken[:, :-1], critic_mac,
                                  mask, self.args.gamma, self.args.tb_lambda,
                                  self.args.step).detach()

        inputs = self.critic._build_inputs(batch, bs, max_t)

        return target_q, inputs, mask, actions, mac_out

    def build_exp_q(self, target_q_vals, mac_out, states):
        target_exp_q_vals = th.sum(target_q_vals * mac_out, dim=3)
        target_exp_q_vals = self.target_mixer.forward(target_exp_q_vals,
                                                      states)
        return target_exp_q_vals

    def _update_targets(self):
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.critic.cuda()
        self.mixer.cuda()
        self.target_critic.cuda()
        self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        th.save(self.critic.state_dict(), "{}/critic.th".format(path))
        th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.agent_optimiser.state_dict(),
                "{}/agent_opt.th".format(path))
        th.save(self.critic_optimiser.state_dict(),
                "{}/critic_opt.th".format(path))
        th.save(self.mixer_optimiser.state_dict(),
                "{}/mixer_opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        self.critic.load_state_dict(
            th.load("{}/critic.th".format(path),
                    map_location=lambda storage, loc: storage))
        self.mixer.load_state_dict(
            th.load("{}/mixer.th".format(path),
                    map_location=lambda storage, loc: storage))
        # Not quite right but I don't want to save target networks
        # self.target_critic.load_state_dict(self.critic.agent.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.agent_optimiser.load_state_dict(
            th.load("{}/agent_opt.th".format(path),
                    map_location=lambda storage, loc: storage))
        self.critic_optimiser.load_state_dict(
            th.load("{}/critic_opt.th".format(path),
                    map_location=lambda storage, loc: storage))
        self.mixer_optimiser.load_state_dict(
            th.load("{}/mixer_opt.th".format(path),
                    map_location=lambda storage, loc: storage))
Esempio n. 30
0
def main():
    args = vars(parser.parse_args())
    agent_config = configs.get_agent_config(args)
    game_config = configs.get_game_config(args)
    training_config = configs.get_training_config(args)
    print("Training with config:")
    print(training_config)
    print(game_config)
    print(agent_config)
    agent = AgentModule(agent_config)
    if training_config.use_cuda:
        agent.cuda()
    optimizer = RMSprop(agent.parameters(), lr=training_config.learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, cooldown=5)
    losses = defaultdict(lambda: defaultdict(list))
    dists = defaultdict(lambda: defaultdict(list))
    for epoch in range(training_config.num_epochs):
        num_agents = np.random.randint(game_config.min_agents,
                                       game_config.max_agents + 1)
        num_landmarks = np.random.randint(game_config.min_landmarks,
                                          game_config.max_landmarks + 1)
        agent.reset()
        game = GameModule(game_config, num_agents, num_landmarks)
        if training_config.use_cuda:
            game.cuda()
        optimizer.zero_grad()

        total_loss, timesteps = agent(game)
        per_agent_loss = total_loss.item() / num_agents / game_config.batch_size
        losses[num_agents][num_landmarks].append(per_agent_loss)

        dist = game.get_avg_agent_to_goal_distance()
        avg_dist = dist.data / num_agents / game_config.batch_size
        dists[num_agents][num_landmarks].append(avg_dist)

        print_losses(epoch, losses, dists, game_config)
        # print("total loss:", total_loss.detach().numpy()[0])

        total_loss.backward()
        optimizer.step()

        if num_agents == game_config.max_agents and num_landmarks == game_config.max_landmarks:
            scheduler.step(
                losses[game_config.max_agents][game_config.max_landmarks][-1])
        '''
        This visualizes the trajectories of the agents (circles) and the target locations (crosses).
        It also displays communication symbol usage: the alpha channel of a letter represents
        how much the agent used the i-th symbol during the epoch (on each step, communication
        is a [1, 20] float vector). These vectors are summed over all steps.
        '''
        if epoch < 3 or epoch > training_config.num_epochs - 3:
            import matplotlib.pyplot as plt
            fig, ax = plt.subplots()
            ax.set_xticks([])
            ax.set_yticks([])
            colors = ['red', 'green', 'blue']
            agent_markers = ['o', '^']
            landmark_markers = ['P', '*']
            utterances = np.zeros_like(timesteps[0]['utterances'][0].detach())
            for time, timestep in enumerate(timesteps):
                agent_legends = []
                for idx, point in enumerate(
                        timestep['locations'][0][:num_agents]):
                    agent_legends.append(
                        plt.scatter(
                            *list(point.detach().numpy()),
                            color=colors[int(game.physical[0, idx, 0].item())],
                            marker=agent_markers[int(game.physical[0, idx,
                                                                   1].item())],
                            s=20,
                            alpha=0.75))
                for idx, point in enumerate(
                        timestep['locations'][0][-num_landmarks:]):
                    if time == 0:
                        plt.scatter(
                            *list(point.detach().numpy()),
                            color='dark' +
                            colors[int(game.physical[0, idx, 0].item())],
                            marker=landmark_markers[int(
                                game.physical[0, idx, 1].item())],
                            s=300,
                            alpha=0.75)
                utterances += timestep['utterances'][0].detach().numpy()
            # This controls how much we highlight or suppress infrequent symbols when displaying.
            # pow < 1 helps bring out low-frequency symbols that were emitted once and lost in the sum;
            # pow >= 1 can highlight the important symbols of the epoch when the signal is too noisy.
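            # For example, with pow = 0.5 a symbol used 1/100th as often as the most frequent
            # one still gets alpha 0.1; with the pow = 2 used here it gets alpha 1e-4.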
            utterances = np.power(
                utterances / utterances.max(axis=1)[..., np.newaxis], 2)
            for agent_idx in range(utterances.shape[0]):
                for symbol_idx in range(utterances.shape[1]):
                    plt.text(0,
                             1 + 0.01 + 0.05 * agent_idx,
                             str(agent_idx + 1) + ': ',
                             color=colors[int(game.physical[0, agent_idx,
                                                            0].item())],
                             transform=ax.transAxes)
                    plt.text(0.05 + 0.03 * symbol_idx,
                             1 + 0.01 + 0.05 * agent_idx,
                             'ABCDEFGHIJKLMNOPQRSTUVXYZ1234567890'[symbol_idx],
                             alpha=utterances[agent_idx, symbol_idx],
                             color=colors[int(game.physical[0, agent_idx,
                                                            0].item())],
                             transform=ax.transAxes)
            plt.legend(reversed(agent_legends),
                       reversed(
                           [str(i + 1) for i in range(len(agent_legends))]),
                       bbox_to_anchor=(0, 1.15))
            for a in range(game_config.min_agents, game_config.max_agents + 1):
                for l in range(game_config.min_landmarks,
                               game_config.max_landmarks + 1):
                    loss = losses[a][l][-1] if len(losses[a][l]) > 0 else 0
                    min_loss = min(
                        losses[a][l]) if len(losses[a][l]) > 0 else 0
                    plt.text(
                        0,
                        -0.05 - 0.05 * ((a - game_config.min_agents) +
                                        (l - game_config.min_landmarks)),
                        "[epoch %d][%d as, %d ls][last loss: %s][min loss: %s]"
                        % (epoch, a, l, ("%.7f" % loss)[:7],
                           ("%.7f" % min_loss)[:7]),
                        transform=ax.transAxes)
            plt.show()

    # if training_config.save_model:
    #     torch.save(agent, training_config.save_model_file)
    #     print("Saved agent model weights at %s" % training_config.save_model_file)
    """
Esempio n. 31
0
class QMixTorchPolicy(Policy):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`; a grouping sketch follows this class.
    """
    def __init__(self, obs_space, action_space, config):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.config = config
        self.observation_space = obs_space
        self.action_space = action_space
        self.n_agents = len(obs_space.original_space.spaces)
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if space_keys != {"obs", "action_mask"}:
                raise ValueError(
                    "Dict obs space for agent must have keyset "
                    "['obs', 'action_mask'], got {}".format(space_keys))
            mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
            if mask_shape != (self.n_actions, ):
                raise ValueError("Action mask shape must be {}, got {}".format(
                    (self.n_actions, ), mask_shape))
            self.has_action_mask = True
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            # The real agent obs space is nested inside the dict
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.has_action_mask = False
            self.obs_size = _get_size(agent_obs_space)

        self.model = ModelCatalog.get_torch_model(agent_obs_space,
                                                  self.n_actions,
                                                  config["model"],
                                                  default_model_cls=RNNModel)
        self.target_model = ModelCatalog.get_torch_model(
            agent_obs_space,
            self.n_actions,
            config["model"],
            default_model_cls=RNNModel)

        # Setup the mixer network.
        # The global state is just the stacked agent observations for now.
        self.state_shape = [self.obs_size, self.n_agents]
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.state_shape,
                                config["mixing_embed_dim"])
            self.target_mixer = QMixer(self.n_agents, self.state_shape,
                                       config["mixing_embed_dim"])
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer()
            self.target_mixer = VDNMixer()
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        if self.mixer:
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        self.optimiser = RMSprop(params=self.params,
                                 lr=config["lr"],
                                 alpha=config["optim_alpha"],
                                 eps=config["optim_eps"])

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        obs_batch, action_mask = self._unpack_observation(obs_batch)

        # Compute actions
        with th.no_grad():
            q_values, hiddens = _mac(
                self.model, th.from_numpy(obs_batch),
                [th.from_numpy(np.array(s)) for s in state_batches])
            avail = th.from_numpy(action_mask).float()
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # epsilon-greedy action selector
            random_numbers = th.rand_like(q_values[:, :, 0])
            pick_random = (random_numbers < self.cur_epsilon).long()
            random_actions = Categorical(avail).sample().long()
            actions = (pick_random * random_actions +
                       (1 - pick_random) * masked_q_values.max(dim=2)[1])
            actions = actions.numpy()
            hiddens = [s.numpy() for s in hiddens]

        return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

    @override(Policy)
    def learn_on_batch(self, samples):
        obs_batch, action_mask = self._unpack_observation(
            samples[SampleBatch.CUR_OBS])
        next_obs_batch, next_action_mask = self._unpack_observation(
            samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        # These will be padded to shape [B * T, ...]
        [rew, action_mask, next_action_mask, act, dones, obs, next_obs], \
            initial_states, seq_lens = \
            chop_into_sequences(
                samples[SampleBatch.EPS_ID],
                samples[SampleBatch.UNROLL_ID],
                samples[SampleBatch.AGENT_INDEX], [
                    group_rewards, action_mask, next_action_mask,
                    samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
                    obs_batch, next_obs_batch
                ],
                [samples["state_in_{}".format(k)]
                 for k in range(len(self.get_initial_state()))],
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True)
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr):
            new_shape = [B, T] + list(arr.shape[1:])
            return th.from_numpy(np.reshape(arr, new_shape))

        rewards = to_batches(rew).float()
        actions = to_batches(act).long()
        obs = to_batches(obs).reshape([B, T, self.n_agents,
                                       self.obs_size]).float()
        action_mask = to_batches(action_mask)
        next_obs = to_batches(next_obs).reshape(
            [B, T, self.n_agents, self.obs_size]).float()
        next_action_mask = to_batches(next_action_mask)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
            B, T, self.n_agents)

        # Create mask for where index is < unpadded sequence length
        filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
                  np.expand_dims(seq_lens, 1)).astype(np.float32)
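        # For example, seq_lens = [3, 1] with T = 3 gives filled = [[1, 1, 1], [1, 0, 0]],
        # so padded steps are masked out of the loss below.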
        mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, self.n_agents)

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
            self.loss(rewards, actions, terminated, mask, obs,
                      next_obs, action_mask, next_action_mask)

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss":
            loss_out.item(),
            "grad_norm":
            grad_norm if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs":
            masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean":
            (chosen_action_qvals * mask).sum().item() / mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {LEARNER_STATS_KEY: stats}

    @override(Policy)
    def get_initial_state(self):
        return [
            s.expand([self.n_agents, -1]).numpy()
            for s in self.model.state_init()
        ]

    @override(Policy)
    def get_weights(self):
        return {"model": self.model.state_dict()}

    @override(Policy)
    def set_weights(self, weights):
        self.model.load_state_dict(weights["model"])

    @override(Policy)
    def get_state(self):
        return {
            "model": self.model.state_dict(),
            "target_model": self.target_model.state_dict(),
            "mixer": self.mixer.state_dict() if self.mixer else None,
            "target_mixer":
            self.target_mixer.state_dict() if self.mixer else None,
            "cur_epsilon": self.cur_epsilon,
        }

    @override(Policy)
    def set_state(self, state):
        self.model.load_state_dict(state["model"])
        self.target_model.load_state_dict(state["target_model"])
        if state["mixer"] is not None:
            self.mixer.load_state_dict(state["mixer"])
            self.target_mixer.load_state_dict(state["target_mixer"])
        self.set_epsilon(state["cur_epsilon"])
        self.update_target()

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _unpack_observation(self, obs_batch):
        """Unpacks the action mask / tuple obs from agent grouping.

        Returns:
            obs (Tensor): flattened obs tensor of shape [B, n_agents, obs_size]
            mask (Tensor): action mask, if any
        """
        unpacked = _unpack_obs(np.array(obs_batch),
                               self.observation_space.original_space,
                               tensorlib=np)
        if self.has_action_mask:
            obs = np.concatenate([o["obs"] for o in unpacked], axis=1).reshape(
                [len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.concatenate([o["action_mask"] for o in unpacked],
                                         axis=1).reshape([
                                             len(obs_batch), self.n_agents,
                                             self.n_actions
                                         ])
        else:
            obs = np.concatenate(unpacked, axis=1).reshape(
                [len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions])
        return obs, action_mask
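
A minimal sketch (not part of this snippet) of the agent grouping and masked observation space that the QMixTorchPolicy docstring requires, assuming a hypothetical two-agent MultiAgentEnv subclass MyMultiAgentEnv; the gym space construction is real, and the commented call shows RLlib's with_agent_groups pattern:

import numpy as np
from gym.spaces import Box, Dict, Discrete, Tuple

N_AGENTS, N_ACTIONS, OBS_DIM = 2, 5, 10   # illustrative sizes

# Each agent observes {"obs": ..., "action_mask": ...}; the mask must be Box(0, 1, (n_actions,)).
agent_obs_space = Dict({
    "obs": Box(-1.0, 1.0, shape=(OBS_DIM,), dtype=np.float32),
    "action_mask": Box(0.0, 1.0, shape=(N_ACTIONS,), dtype=np.float32),
})
agent_act_space = Discrete(N_ACTIONS)

# QMix expects a single agent group exposing Tuple obs/action spaces over its members.
grouped_obs_space = Tuple([agent_obs_space] * N_AGENTS)
grouped_act_space = Tuple([agent_act_space] * N_AGENTS)

# MyMultiAgentEnv is an assumed MultiAgentEnv subclass with agent ids "agent_0" and "agent_1":
# grouped_env = MyMultiAgentEnv(...).with_agent_groups(
#     {"group_1": ["agent_0", "agent_1"]},
#     obs_space=grouped_obs_space, act_space=grouped_act_space)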
Esempio n. 32
0
def main():

    parse = argparse.ArgumentParser()

    parse.add_argument("--lr", type=float, default=0.00005, 
                        help="learning rate of generate and discriminator")
    parse.add_argument("--clamp", type=float, default=0.01, 
                        help="clamp discriminator parameters")
    parse.add_argument("--batch_size", type=int, default=10,
                        help="number of dataset in every train or test iteration")
    parse.add_argument("--dataset", type=str, default="faces",
                        help="base path for dataset")
    parse.add_argument("--epochs", type=int, default=500,
                        help="number of training epochs")
    parse.add_argument("--loaders", type=int, default=4,
                        help="number of parallel data loading processing")
    parse.add_argument("--size_per_dataset", type=int, default=30000,
                        help="number of training data")

    args = parse.parse_args()

    if not os.path.exists("saved"):
        os.mkdir("saved")
    if not os.path.exists("saved/img"):
        os.mkdir("saved/img")

    if os.path.exists("faces"):
        pass
    else:
        print("Don't find the dataset directory, please copy the link in website ,download and extract faces.tar.gz .\n \
        https://drive.google.com/drive/folders/1mCsY5LEsgCnc0Txv0rpAUhKVPWVkbw5I \n ")
        exit()

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    generate = Generate().to(device)
    discriminator = Discriminator().to(device)

    generate.apply(weight_init)
    discriminator.apply(weight_init)

    dataset = AnimeDataset(os.getcwd(), args.dataset, args.size_per_dataset)
    dataload = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.loaders)

    optimizer_G = RMSprop(generate.parameters(), lr=args.lr)
    optimizer_D = RMSprop(discriminator.parameters(), lr=args.lr)

    fixed_noise = torch.randn(64, 100, 1, 1).to(device)
    step = 0
    for epoch in range(args.epochs):

        print("Main epoch{}:".format(epoch))
        progress = tqdm(total=len(dataload.dataset))
        
        for i, inp in enumerate(dataload):
            step += 1
            # train discriminator   
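            # WGAN recipe: the critic maximises E[D(real)] - E[D(fake)] (hence the * -1 on
            # `output` below), its weights are clamped to [-clamp, clamp] as a crude Lipschitz
            # constraint, and the generator is updated only once every 5 critic steps.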
            real_data = inp.float().to(device)
            noise = torch.randn(inp.size()[0], 100, 1, 1).to(device)
            fake_data = generate(noise)
            optimizer_D.zero_grad()
            real_output = torch.mean(discriminator(real_data).squeeze())
            fake_output = torch.mean(discriminator(fake_data).squeeze())
            output = (real_output - fake_output)* -1
            output.backward()
            optimizer_D.step()
            
            for param in discriminator.parameters():
                param.data.clamp_(-args.clamp, args.clamp)

            # train the generator
            if step%5 == 0:
                optimizer_G.zero_grad()
                fake_data = generate(noise)
                fake_output = -torch.mean(discriminator(fake_data).squeeze())
                fake_output.backward()
                optimizer_G.step()
            
            progress.update(dataload.batch_size)

        if epoch % 20 == 0:

            torch.save(generate, os.path.join(os.getcwd(), "saved/generate.t7"))
            torch.save(discriminator, os.path.join(os.getcwd(), "saved/discriminator.t7"))

            img = generate(fixed_noise).to("cpu").detach().numpy()

            display_grid = np.zeros((8*96,8*96,3))
            
            for j in range(int(64/8)):
                for k in range(int(64/8)):
                    display_grid[j*96:(j+1)*96,k*96:(k+1)*96,:] = (img[k+8*j].transpose(1, 2, 0)+1)/2

            img_save_path = os.path.join(os.getcwd(),"saved/img/{}.png".format(epoch))
            scipy.misc.imsave(img_save_path, display_grid)

    creat_gif("evolution.gif", os.path.join(os.getcwd(),"saved/img"))
Esempio n. 33
0
class QLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = QMixer(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)

        # We don't need the first timestep's Q-value estimate for calculating targets
        target_mac_out = th.stack(target_mac_out[1:], dim=1)  # Concat across time

        # Mask out unavailable actions
        target_mac_out[avail_actions[:, 1:] == 0] = -9999999

        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Mix
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1])
            target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:])

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error ** 2).sum() / mask.sum()

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env)
            self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item()/(mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
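
The train() method above builds the standard masked 1-step Q-learning update: targets y_t = r_t + gamma * (1 - terminated_t) * max_a Q_target(s_{t+1}, a), a TD error on the chosen actions, and an L2 loss averaged over unpadded timesteps only, optimised with RMSprop under gradient clipping. The snippet below is a minimal sketch of that target-and-loss computation on dummy tensors; it uses a single linear network under torch.no_grad() in place of the separate target MAC and mixer, and all names and shapes are illustrative rather than the EpisodeBatch API.

import torch
from torch.optim import RMSprop

torch.manual_seed(0)
B, T, A = 4, 10, 5                           # batch size, timesteps, actions (dummy sizes)
q_net = torch.nn.Linear(8, A)                # stand-in for the agent network
optimiser = RMSprop(q_net.parameters(), lr=5e-4, alpha=0.99, eps=1e-5)

obs = torch.randn(B, T, 8)
actions = torch.randint(A, (B, T - 1, 1))
rewards = torch.randn(B, T - 1)
terminated = torch.zeros(B, T - 1)
mask = torch.ones(B, T - 1)                  # 1 for real data, 0 for padding
gamma = 0.99

q = q_net(obs)                                                  # (B, T, A)
chosen_q = q[:, :-1].gather(2, actions).squeeze(2)              # Q(s_t, a_t)
with torch.no_grad():                                           # target estimate (same net here, for brevity)
    target_max_q = q_net(obs)[:, 1:].max(dim=2)[0]              # max_a Q(s_{t+1}, a)
targets = rewards + gamma * (1 - terminated) * target_max_q     # 1-step TD target

td_error = (chosen_q - targets) * mask                          # zero out padded timesteps
loss = (td_error ** 2).sum() / mask.sum()                       # mean over real data only

optimiser.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(q_net.parameters(), 10.0)
optimiser.step()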
Example n. 34
0
class Model(object):
    def __init__(self):
        self.data_path = config['data_dir']
        self.model_path = os.path.join(
            config['model_par_dir'],
            str(config['model_name_newest']) + '.pkl')
        self.build_net()
        self.load_par()
        self.model_name_newest = config['model_name_newest']

    def build_net(self):
        self.net = GCNet()
        # print(len(list(self.net.parameters())))
        if config['if_GPU']:
            self.net = self.net.cuda()
        self.dataset = {}
        self.dataloader = {}
        self.dataset['train'] = Kitty2015DataSet(root_dir=config['root_dir'])
        self.dataset['val'] = Kitty2015DataSet(root_dir=config['root_dir'],
                                               is_validation=True)
        self.dataloader['train'] = iter(
            DataLoader(dataset=self.dataset['train'],
                       batch_size=config['batch_size'],
                       shuffle=True))
        self.dataloader['val'] = iter(
            DataLoader(dataset=self.dataset['val'],
                       batch_size=config['batch_size'],
                       shuffle=True))
        self.criterion = RegLoss()
        self.val = Validation()
        self.opti = RMSprop(self.net.parameters(), lr=config['learning_rate'])

    def load_par(self):
        if os.path.exists(self.model_path):
            self.net.load_state_dict(torch.load(self.model_path))

    def save_par(self):
        model_name = str(self.model_name_newest) + '.pkl'
        model_path = os.path.join(config['model_par_dir'], model_name)
        torch.save(self.net.state_dict(), model_path)
        config['model_name_newest'] = self.model_name_newest
        self.model_name_newest += 1
        self.model_name_newest %= 5

    def train(self):
        epoches = config['epoches']
        start_time = time.time()
        for epoch in range(epoches):
            total_loss = 0
            for batch_index in range(config['batches_per_epoch']):
                st = time.time()
                sample = next(self.dataloader['train'])
                loss = self.train_batch(sample, epoch, batch_index)
                et = time.time()
                print('cost {0} seconds'.format(int(et - st + 0.5)))
                total_loss += loss

                if (batch_index + 1) % config['batches_per_validation'] == 0:
                    end_time = time.time()
                    print('......time :{0}'.format(end_time - start_time))
                    start_time = end_time
                    print('...epoch  :  {0:2d}'.format(epoch))
                    print('....validation')
                    v0_t, v1_t, v2_t = 0, 0, 0
                    num = config['number_validation']
                    for i in range(num):
                        sample = next(self.dataloader['val'])
                        v0, v1, v2 = self.validation(sample)
                        v0_t += v0
                        v1_t += v1
                        v2_t += v2
                    print(
                        '.....>2 px : {0}%  >3 px : {1}%  >5 px : {2}%'.format(
                            v0_t / num, v1_t / num, v2_t / num))
                    print('....save parameters')
                    self.save_par()
            print('...average loss : {0}'.format(total_loss /
                                                 config['batches_per_epoch']))

    def train_batch(self, batch_sample, epoch, batch_index):
        left_image = batch_sample['left_image'].float()
        right_image = batch_sample['right_image'].float()
        disp = batch_sample['disp'].float()
        if config['if_GPU']:
            left_image = left_image.cuda()
            right_image = right_image.cuda()
            disp = disp.cuda()
        left_image, right_image, disp = Variable(left_image), Variable(
            right_image), Variable(disp)
        disp_prediction = self.net((left_image, right_image))
        if config['if_GPU']:
            disp_prediction = disp_prediction.cuda()
        loss = self.criterion(disp_prediction, disp)
        if config['if_GPU']:
            loss = loss.cuda()
        print('.....epoch : {1} batch_index : {2} loss : {0}'.format(
            loss.data[0], epoch, batch_index))
        self.opti.zero_grad()
        loss.backward()
        self.opti.step()
        return loss.data[0]

    def validation(self, batch_sample):
        left_image = batch_sample['left_image'].float()
        right_image = batch_sample['right_image'].float()
        disp = batch_sample['disp'].float()
        if config['if_GPU']:
            left_image = left_image.cuda()
            right_image = right_image.cuda()
            disp = disp.cuda()
        left_image, right_image, disp = Variable(left_image), Variable(
            right_image), Variable(disp)
        disp_prediction = self.net((left_image, right_image))
        if config['if_GPU']:
            disp_prediction = disp_prediction.cuda()
        val = self.val(disp_prediction, disp)
        if config['if_GPU']:
            val = [val[0].cuda(), val[1].cuda(), val[2].cuda()]
        # print('.....>2 px : {0:2f}%  >3 px : {1:2f}%  >5 px : {2:2f}%'.format(val[0].data[0]*100,val[1].data[0]*100,val[2].data[0]*100))
        return val[0].data[0] * 100, val[1].data[0] * 100, val[2].data[0] * 100

    def predict_batch(self, batch_sample):
        left_image = batch_sample['left_image'].float()
        right_image = batch_sample['right_image'].float()
        disp = batch_sample['disp'].float()
        if config['if_GPU']:
            left_image = left_image.cuda()
            right_image = right_image.cuda()
            disp = disp.cuda()
        left_image, right_image, disp = Variable(left_image), Variable(
            right_image), Variable(disp)
        disp_prediction = self.net((left_image, right_image))
        loss = self.criterion(disp_prediction, disp)
        if config['if_GPU']:
            loss = loss.cuda()
        print('loss : {0}'.format(loss.data[0]))
        return disp, disp_prediction

    def predict(self):
        for i in range(10):
            print(i)
            dir = os.path.join('./Data', str(i))
            if not os.path.exists(dir):
                os.makedirs(dir)
            batch_sample = next(self.dataloader['val'])
            left_image_pre = batch_sample['pre_left']
            image_pre = left_image_pre[0].numpy()
            disp, disp_prediction = self.predict_batch(batch_sample)
            disp_prediction = disp_prediction.cpu().data.numpy() * 256
            disp_gt = disp.cpu().data.numpy()[0] * 256
            print(disp_prediction.shape)
            np.save(os.path.join(dir, 'image_pre.npy'), image_pre)
            np.save(os.path.join(dir, 'disp_gt.npy'), disp_gt)
            np.save(os.path.join(dir, 'disp_est.npy'), disp_prediction)
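
The validation printout above reports ">2 px / >3 px / >5 px" rates, which in KITTI-style stereo evaluation are the percentages of valid pixels whose absolute disparity error exceeds each threshold. The helper below is a sketch of such a metric under that assumption; it is not the example's actual Validation module, and pixel_error_rates, the max_disp cut-off and the dummy tensors are all illustrative.

import torch

def pixel_error_rates(disp_pred, disp_gt, thresholds=(2, 3, 5), max_disp=192):
    # fraction of valid pixels whose absolute disparity error exceeds each threshold (in %)
    valid = (disp_gt > 0) & (disp_gt < max_disp)     # KITTI marks missing ground truth as 0
    abs_err = (disp_pred - disp_gt).abs()[valid]
    n = abs_err.numel()
    return [float((abs_err > t).sum()) / max(n, 1) * 100 for t in thresholds]

# usage on dummy tensors
pred = torch.rand(1, 256, 512) * 192
gt = torch.rand(1, 256, 512) * 192
p2, p3, p5 = pixel_error_rates(pred, gt)
print('>2 px : {0:.2f}%  >3 px : {1:.2f}%  >5 px : {2:.2f}%'.format(p2, p3, p5))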