Example #1
    def __init__(self, args, exp_model, logging_func):
        self.args = args

        # Exploration Model
        self.exp_model = exp_model

        self.log = logging_func["log"]
        self.log_image = logging_func["image"]
        os.makedirs("{}/transition_model".format(args.log_path))

        # Experience Replay
        self.replay = ExpReplay(args.exp_replay_size, args.stale_limit, exp_model, args, priority=self.args.prioritized)

        # DQN and Target DQN
        model = get_models(args.model)
        print("\n\nDQN")
        self.dqn = model(actions=args.actions)
        print("Target DQN")
        self.target_dqn = model(actions=args.actions)

        dqn_params = 0
        for weight in self.dqn.parameters():
            weight_params = 1
            for s in weight.size():
                weight_params *= s
            dqn_params += weight_params
        print("Model DQN has {:,} parameters.".format(dqn_params))

        self.target_dqn.eval()

        if args.gpu:
            print("Moving models to GPU.")
            self.dqn.cuda()
            self.target_dqn.cuda()

        # Optimizer
        # self.optimizer = Adam(self.dqn.parameters(), lr=args.lr)
        self.optimizer = RMSprop(self.dqn.parameters(), lr=args.lr)

        self.T = 0
        self.target_sync_T = -self.args.t_max

        # Action sequences
        self.actions_to_take = []
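
The parameter-counting loop above reappears almost verbatim in the next two examples; a more compact equivalent is sketched below. count_parameters is a hypothetical helper (not part of the original code) that works for any standard torch.nn.Module:

def count_parameters(module):
    # Hypothetical helper: sum the number of elements of every parameter tensor.
    return sum(p.numel() for p in module.parameters())

# e.g. print("Model DQN has {:,} parameters.".format(count_parameters(self.dqn)))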
Example #2
    def __init__(self, args, exp_model, logging_func):
        self.args = args

        # Exploration Model
        self.exp_model = exp_model

        self.log = logging_func["log"]

        # Experience Replay
        if self.args.set_replay:
            self.replay = ExpReplaySet(10, 10, exp_model, args, priority=False)
        else:
            self.replay = ExpReplay(args.exp_replay_size, args.stale_limit, exp_model, args, priority=self.args.prioritized)

        # DQN and Target DQN
        model = get_models(args.model)
        self.dqn = model(actions=args.actions)
        self.target_dqn = model(actions=args.actions)

        dqn_params = 0
        for weight in self.dqn.parameters():
            weight_params = 1
            for s in weight.size():
                weight_params *= s
            dqn_params += weight_params
        print("DQN has {:,} parameters.".format(dqn_params))

        self.target_dqn.eval()

        if args.gpu:
            print("Moving models to GPU.")
            self.dqn.cuda()
            self.target_dqn.cuda()

        # Optimizer
        # self.optimizer = Adam(self.dqn.parameters(), lr=args.lr)
        self.optimizer = RMSprop(self.dqn.parameters(), lr=args.lr)

        self.T = 0
        self.target_sync_T = -self.args.t_max
Example #3
    def __init__(self, args, exp_model, logging_func):
        self.args = args

        # Exploration Model
        self.exp_model = exp_model

        self.log = logging_func["log"]

        # Experience Replay
        self.replay = ExpReplay(args.exp_replay_size, args)
        self.dnds = [DND(kernel=kernel, num_neighbors=args.nec_neighbours, max_memory=args.dnd_size, embedding_size=args.nec_embedding) for _ in range(self.args.actions)]

        # DQN and Target DQN
        model = get_models(args.model)
        self.embedding = model(embedding=args.nec_embedding)

        embedding_params = 0
        for weight in self.embedding.parameters():
            weight_params = 1
            for s in weight.size():
                weight_params *= s
            embedding_params += weight_params
        print("Embedding Network has {:,} parameters.".format(embedding_params))

        if args.gpu:
            print("Moving models to GPU.")
            self.embedding.cuda()

        # Optimizer
        self.optimizer = RMSprop(self.embedding.parameters(), lr=args.lr)
        # self.optimizer = Adam(self.embedding.parameters(), lr=args.lr)

        self.T = 0
        self.target_sync_T = -self.args.t_max

        self.experiences = []
        self.keys = []
        self.q_val_estimates = []

        self.table_updates = 0
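
Example #3 keeps one DND (differentiable neural dictionary) per action; in NEC the Q-value for a query embedding is a kernel-weighted average over its nearest stored keys. The toy sketch below illustrates that lookup without depending on the DND class used above; all names and shapes are illustrative assumptions:

import torch

def dnd_lookup(query, keys, values, num_neighbors=2, delta=1e-3):
    # Inverse-distance kernel k(h, h_i) = 1 / (||h - h_i||^2 + delta) over the
    # nearest stored keys, then a weighted average of their stored Q-values.
    dists = ((keys - query) ** 2).sum(dim=1)
    knn = torch.topk(-dists, k=min(num_neighbors, keys.size(0))).indices
    w = 1.0 / (dists[knn] + delta)
    return (w / w.sum() * values[knn]).sum()

keys = torch.randn(8, 4)     # stored embeddings
values = torch.randn(8)      # stored Q-value estimates
print(dnd_lookup(torch.randn(4), keys, values))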
Example #4
class QMixPolicyGraph(PolicyGraph):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """

    def __init__(self, obs_space, action_space, config):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.config = config
        self.observation_space = obs_space
        self.action_space = action_space
        self.n_agents = len(obs_space.original_space.spaces)
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if space_keys != {"obs", "action_mask"}:
                raise ValueError(
                    "Dict obs space for agent must have keyset "
                    "['obs', 'action_mask'], got {}".format(space_keys))
            mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
            if mask_shape != (self.n_actions, ):
                raise ValueError("Action mask shape must be {}, got {}".format(
                    (self.n_actions, ), mask_shape))
            self.has_action_mask = True
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            # The real agent obs space is nested inside the dict
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.has_action_mask = False
            self.obs_size = _get_size(agent_obs_space)

        self.model = ModelCatalog.get_torch_model(
            agent_obs_space,
            self.n_actions,
            config["model"],
            default_model_cls=RNNModel)
        self.target_model = ModelCatalog.get_torch_model(
            agent_obs_space,
            self.n_actions,
            config["model"],
            default_model_cls=RNNModel)

        # Setup the mixer network.
        # The global state is just the stacked agent observations for now.
        self.state_shape = [self.obs_size, self.n_agents]
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.state_shape,
                                config["mixing_embed_dim"])
            self.target_mixer = QMixer(self.n_agents, self.state_shape,
                                       config["mixing_embed_dim"])
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer()
            self.target_mixer = VDNMixer()
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])

    @override(PolicyGraph)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        obs_batch, action_mask = self._unpack_observation(obs_batch)

        # Compute actions
        with th.no_grad():
            q_values, hiddens = _mac(
                self.model, th.from_numpy(obs_batch),
                [th.from_numpy(np.array(s)) for s in state_batches])
            avail = th.from_numpy(action_mask).float()
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            # epsilon-greedy action selector
            random_numbers = th.rand_like(q_values[:, :, 0])
            pick_random = (random_numbers < self.cur_epsilon).long()
            random_actions = Categorical(avail).sample().long()
            actions = (pick_random * random_actions +
                       (1 - pick_random) * masked_q_values.max(dim=2)[1])
            actions = actions.numpy()
            hiddens = [s.numpy() for s in hiddens]

        return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}

    @override(PolicyGraph)
    def learn_on_batch(self, samples):
        obs_batch, action_mask = self._unpack_observation(samples["obs"])
        group_rewards = self._get_group_rewards(samples["infos"])

        # These will be padded to shape [B * T, ...]
        [rew, action_mask, act, dones, obs], initial_states, seq_lens = \
            chop_into_sequences(
                samples["eps_id"],
                samples["agent_index"], [
                    group_rewards, action_mask, samples["actions"],
                    samples["dones"], obs_batch
                ],
                [samples["state_in_{}".format(k)]
                 for k in range(len(self.get_initial_state()))],
                max_seq_len=self.config["model"]["max_seq_len"],
                dynamic_max=True,
                _extra_padding=1)
        # TODO(ekl) adding 1 extra unit of padding here, since otherwise we
        # lose the terminating reward and the Q-values will be unanchored!
        B, T = len(seq_lens), max(seq_lens) + 1

        def to_batches(arr):
            new_shape = [B, T] + list(arr.shape[1:])
            return th.from_numpy(np.reshape(arr, new_shape))

        rewards = to_batches(rew)[:, :-1].float()
        actions = to_batches(act)[:, :-1].long()
        obs = to_batches(obs).reshape([B, T, self.n_agents,
                                       self.obs_size]).float()
        action_mask = to_batches(action_mask)

        # TODO(ekl) this treats group termination as individual termination
        terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
            B, T, self.n_agents)[:, :-1]
        filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
                  np.expand_dims(seq_lens, 1)).astype(np.float32)
        mask = th.from_numpy(filled).unsqueeze(2).expand(B, T,
                                                         self.n_agents)[:, :-1]
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
            self.loss(rewards, actions, terminated, mask, obs, action_mask)

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss": loss_out.item(),
            "grad_norm": grad_norm
            if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
            mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {"stats": stats}, {}

    @override(PolicyGraph)
    def get_initial_state(self):
        return [
            s.expand([self.n_agents, -1]).numpy()
            for s in self.model.state_init()
        ]

    @override(PolicyGraph)
    def get_weights(self):
        return {"model": self.model.state_dict()}

    @override(PolicyGraph)
    def set_weights(self, weights):
        self.model.load_state_dict(weights["model"])

    @override(PolicyGraph)
    def get_state(self):
        return {
            "model": self.model.state_dict(),
            "target_model": self.target_model.state_dict(),
            "mixer": self.mixer.state_dict() if self.mixer else None,
            "target_mixer": self.target_mixer.state_dict()
            if self.mixer else None,
            "cur_epsilon": self.cur_epsilon,
        }

    @override(PolicyGraph)
    def set_state(self, state):
        self.model.load_state_dict(state["model"])
        self.target_model.load_state_dict(state["target_model"])
        if state["mixer"] is not None:
            self.mixer.load_state_dict(state["mixer"])
            self.target_mixer.load_state_dict(state["target_mixer"])
        self.set_epsilon(state["cur_epsilon"])
        self.update_target()

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _unpack_observation(self, obs_batch):
        """Unpacks the action mask / tuple obs from agent grouping.

        Returns:
            obs (Tensor): flattened obs tensor of shape [B, n_agents, obs_size]
            mask (Tensor): action mask, if any
        """
        unpacked = _unpack_obs(
            np.array(obs_batch),
            self.observation_space.original_space,
            tensorlib=np)
        if self.has_action_mask:
            obs = np.concatenate(
                [o["obs"] for o in unpacked],
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.concatenate(
                [o["action_mask"] for o in unpacked], axis=1).reshape(
                    [len(obs_batch), self.n_agents, self.n_actions])
        else:
            obs = np.concatenate(
                unpacked,
                axis=1).reshape([len(obs_batch), self.n_agents, self.obs_size])
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions])
        return obs, action_mask
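
The class docstring above requires agents to be grouped with MultiAgentEnv.with_agent_groups() and, optionally, a Dict obs space carrying an action_mask. The sketch below shows one way that wiring could look; group_for_qmix, the group/agent names, and the toy sizes are illustrative assumptions, not part of the original code:

from gym.spaces import Box, Dict, Discrete, Tuple


def group_for_qmix(env, n_agents=2, n_actions=5, obs_dim=10):
    # Sketch only: wrap a MultiAgentEnv so QMix sees Tuple obs/action spaces
    # plus a per-agent action mask, as the docstring above describes.
    agent_obs = Dict({
        "obs": Box(-1.0, 1.0, (obs_dim,)),
        # The mask space must be Box(0, 1, (n_actions,)).
        "action_mask": Box(0.0, 1.0, (n_actions,)),
    })
    return env.with_agent_groups(
        groups={"group_1": ["agent_%d" % i for i in range(n_agents)]},
        obs_space=Tuple([agent_obs] * n_agents),
        act_space=Tuple([Discrete(n_actions)] * n_agents))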
Example #5
def main():
    parser = argparse.ArgumentParser(
        description='Tuning with Multitask bi-directional RNN-CNN-CRF')
    parser.add_argument('--config',
                        help='Config file (Python file format)',
                        default="config_multitask.py")
    parser.add_argument('--grid', help='Grid Search Options', default="{}")
    args = parser.parse_args()
    logger = get_logger("Multi-Task")
    use_gpu = torch.cuda.is_available()

    # Config Tensorboard Writer
    log_writer = SummaryWriter()

    # Load from config file
    spec = importlib.util.spec_from_file_location("config", args.config)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    config = config_module.entries

    # Load options from grid search
    options = eval(args.grid)
    for k, v in options.items():
        if isinstance(v, six.string_types):
            cmd = "%s = \"%s\"" % (k, v)
        else:
            cmd = "%s = %s" % (k, v)
            log_writer.add_scalar(k, v, 1)
        exec(cmd)

    # Load embedding dict
    embedding = config.embedding.embedding_type
    embedding_path = config.embedding.embedding_dict
    embedd_dict, embedd_dim = utils.load_embedding_dict(
        embedding, embedding_path)

    # Collect data path
    data_dir = config.data.data_dir
    data_names = config.data.data_names
    train_paths = [
        os.path.join(data_dir, data_name, "train.tsv")
        for data_name in data_names
    ]
    dev_paths = [
        os.path.join(data_dir, data_name, "devel.tsv")
        for data_name in data_names
    ]
    test_paths = [
        os.path.join(data_dir, data_name, "test.tsv")
        for data_name in data_names
    ]

    # Create alphabets
    logger.info("Creating Alphabets")
    if not os.path.exists('tmp'):
        os.mkdir('tmp')
    word_alphabet, char_alphabet, pos_alphabet, chunk_alphabet, ner_alphabet, ner_alphabet_task, label_reflect  = \
            bionlp_data.create_alphabets(os.path.join(Path(data_dir).abspath(), "alphabets", "_".join(data_names)), train_paths,
                    data_paths=dev_paths + test_paths, use_cache=True,
                    embedd_dict=embedd_dict, max_vocabulary_size=50000)

    logger.info("Word Alphabet Size: %d" % word_alphabet.size())
    logger.info("Character Alphabet Size: %d" % char_alphabet.size())
    logger.info("POS Alphabet Size: %d" % pos_alphabet.size())
    logger.info("Chunk Alphabet Size: %d" % chunk_alphabet.size())
    logger.info("NER Alphabet Size: %d" % ner_alphabet.size())
    logger.info(
        "NER Alphabet Size per Task: %s",
        str([task_alphabet.size() for task_alphabet in ner_alphabet_task]))

    #task_reflects = torch.LongTensor(reverse_reflect(label_reflect, ner_alphabet.size()))
    #if use_gpu:
    #    task_reflects = task_reflects.cuda()

    if embedding == 'elmo':
        logger.info("Loading ELMo Embedder")
        ee = ElmoEmbedder(options_file=config.embedding.elmo_option,
                          weight_file=config.embedding.elmo_weight,
                          cuda_device=config.embedding.elmo_cuda)
    else:
        ee = None

    logger.info("Reading Data")

    # Prepare dataset
    data_trains = [
        bionlp_data.read_data_to_variable(train_path,
                                          word_alphabet,
                                          char_alphabet,
                                          pos_alphabet,
                                          chunk_alphabet,
                                          ner_alphabet_task[task_id],
                                          use_gpu=use_gpu,
                                          elmo_ee=ee)
        for task_id, train_path in enumerate(train_paths)
    ]
    num_data = [sum(data_train[1]) for data_train in data_trains]
    num_labels = ner_alphabet.size()
    num_labels_task = [task_item.size() for task_item in ner_alphabet_task]

    data_devs = [
        bionlp_data.read_data_to_variable(dev_path,
                                          word_alphabet,
                                          char_alphabet,
                                          pos_alphabet,
                                          chunk_alphabet,
                                          ner_alphabet_task[task_id],
                                          use_gpu=use_gpu,
                                          volatile=True,
                                          elmo_ee=ee)
        for task_id, dev_path in enumerate(dev_paths)
    ]

    data_tests = [
        bionlp_data.read_data_to_variable(test_path,
                                          word_alphabet,
                                          char_alphabet,
                                          pos_alphabet,
                                          chunk_alphabet,
                                          ner_alphabet_task[task_id],
                                          use_gpu=use_gpu,
                                          volatile=True,
                                          elmo_ee=ee)
        for task_id, test_path in enumerate(test_paths)
    ]

    writer = BioNLPWriter(word_alphabet, char_alphabet, pos_alphabet,
                          chunk_alphabet, ner_alphabet)

    def construct_word_embedding_table():
        scale = np.sqrt(3.0 / embedd_dim)
        table = np.empty([word_alphabet.size(), embedd_dim], dtype=np.float32)
        table[bionlp_data.UNK_ID, :] = np.random.uniform(
            -scale, scale, [1, embedd_dim]).astype(np.float32)
        oov = 0
        for word, index in word_alphabet.items():
            if embedd_dict is not None and word in embedd_dict:
                embedding = embedd_dict[word]
            elif embedd_dict is not None and word.lower() in embedd_dict:
                embedding = embedd_dict[word.lower()]
            else:
                embedding = np.random.uniform(
                    -scale, scale, [1, embedd_dim]).astype(np.float32)
                oov += 1
            table[index, :] = embedding
        print('oov: %d' % oov)
        return torch.from_numpy(table)

    word_table = construct_word_embedding_table()
    logger.info("constructing network...")

    # Construct network
    window = 3
    num_layers = 1
    mode = config.rnn.mode
    hidden_size = config.rnn.hidden_size
    char_dim = config.rnn.char_dim
    num_filters = config.rnn.num_filters
    tag_space = config.rnn.tag_space
    bigram = config.rnn.bigram
    attention_mode = config.rnn.attention
    if config.rnn.dropout == 'std':
        network = FullySharedBiRecurrentCRF(
            len(data_trains),
            embedd_dim,
            word_alphabet.size(),
            char_dim,
            char_alphabet.size(),
            num_filters,
            window,
            mode,
            hidden_size,
            num_layers,
            num_labels,
            num_labels_task=num_labels_task,
            tag_space=tag_space,
            embedd_word=word_table,
            p_in=config.rnn.p,
            p_rnn=config.rnn.p,
            bigram=bigram,
            elmo=(embedding == 'elmo'),
            attention_mode=attention_mode,
            adv_loss_coef=config.multitask.adv_loss_coef,
            diff_loss_coef=config.multitask.diff_loss_coef,
            char_level_rnn=config.rnn.char_level_rnn)
    else:
        raise NotImplementedError

    if use_gpu:
        network.cuda()

    # Prepare training
    unk_replace = config.embedding.unk_replace
    num_epochs = config.training.num_epochs
    batch_size = config.training.batch_size
    lr = config.training.learning_rate
    momentum = config.training.momentum
    alpha = config.training.alpha
    lr_decay = config.training.lr_decay
    schedule = config.training.schedule
    gamma = config.training.gamma

    # optim = SGD(network.parameters(), lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)
    optim = RMSprop(network.parameters(),
                    lr=lr,
                    alpha=alpha,
                    momentum=momentum,
                    weight_decay=gamma)
    logger.info(
        "Network: %s, num_layer=%d, hidden=%d, filter=%d, tag_space=%d, crf=%s"
        % (mode, num_layers, hidden_size, num_filters, tag_space,
           'bigram' if bigram else 'unigram'))
    logger.info(
        "training: l2: %f, (#training data: %s, batch: %d, dropout: %.2f, unk replace: %.2f)"
        % (gamma, num_data, batch_size, config.rnn.p, unk_replace))

    num_batches = [x // batch_size + 1 for x in num_data]
    dev_f1 = [0.0 for x in num_data]
    dev_acc = [0.0 for x in num_data]
    dev_precision = [0.0 for x in num_data]
    dev_recall = [0.0 for x in num_data]
    test_f1 = [0.0 for x in num_data]
    test_acc = [0.0 for x in num_data]
    test_precision = [0.0 for x in num_data]
    test_recall = [0.0 for x in num_data]
    best_epoch = [0 for x in num_data]

    # Training procedure
    for epoch in range(1, num_epochs + 1):
        print(
            'Epoch %d (%s(%s), learning rate=%.4f, decay rate=%.4f (schedule=%d)): '
            % (epoch, mode, config.rnn.dropout, lr, lr_decay, schedule))
        train_err = 0.
        train_total = 0.

        # Gradient decent on training data
        start_time = time.time()
        num_back = 0
        network.train()
        batch_count = 0
        for batch in range(1, 2 * num_batches[0] + 1):
            r = random.random()
            task_id = 0 if r <= 0.5 else random.randint(1, len(num_data) - 1)
            batch_count += 1
            word, char, _, _, labels, masks, lengths, elmo_embedding = bionlp_data.get_batch_variable(
                data_trains[task_id], batch_size, unk_replace=unk_replace)

            optim.zero_grad()
            loss, task_loss, adv_loss, diff_loss = network.loss(
                task_id,
                word,
                char,
                labels,
                mask=masks,
                elmo_word=elmo_embedding)
            #log_writer.add_scalars(
            #        'train_loss_task' + str(task_id),
            #        {'all_loss': loss, 'task_loss': task_loss, 'adv_loss': adv_loss, 'diff_loss': diff_loss},
            #        (epoch - 1) * (num_batches[task_id] + 1) + batch
            #)
            #log_writer.add_scalars(
            #        'train_loss_overview',
            #        {'all_loss': loss, 'task_loss': task_loss, 'adv_loss': adv_loss, 'diff_loss': diff_loss},
            #        (epoch - 1) * (sum(num_batches) + 1) + batch_count
            #)
            loss.backward()
            clip_grad_norm(network.parameters(), 5.0)
            optim.step()

            num_inst = word.size(0)
            train_err += loss.data[0] * num_inst
            train_total += num_inst

            time_ave = (time.time() - start_time) / batch
            time_left = (2 * num_batches[0] - batch) * time_ave

            # update log
            if batch % 100 == 0:
                sys.stdout.write("\b" * num_back)
                sys.stdout.write(" " * num_back)
                sys.stdout.write("\b" * num_back)
                log_info = 'train: %d/%d loss: %.4f, time left (estimated): %.2fs' % (
                    batch, 2 * num_batches[0], train_err / train_total,
                    time_left)
                sys.stdout.write(log_info)
                sys.stdout.flush()
                num_back = len(log_info)

        sys.stdout.write("\b" * num_back)
        sys.stdout.write(" " * num_back)
        sys.stdout.write("\b" * num_back)
        print('train: %d loss: %.4f, time: %.2fs' %
              (2 * num_batches[0], train_err / train_total,
               time.time() - start_time))

        # Evaluate performance on dev data
        network.eval()
        for task_id in range(len(num_batches)):
            tmp_filename = 'tmp/%s_dev%d%d' % (str(uid), epoch, task_id)
            writer.start(tmp_filename)

            for batch in bionlp_data.iterate_batch_variable(
                    data_devs[task_id], batch_size):
                word, char, pos, chunk, labels, masks, lengths, elmo_embedding = batch
                preds, _ = network.decode(
                    task_id,
                    word,
                    char,
                    target=labels,
                    mask=masks,
                    leading_symbolic=bionlp_data.NUM_SYMBOLIC_TAGS,
                    elmo_word=elmo_embedding)
                writer.write(word.data.cpu().numpy(),
                             pos.data.cpu().numpy(),
                             chunk.data.cpu().numpy(),
                             preds.cpu().numpy(),
                             labels.data.cpu().numpy(),
                             lengths.cpu().numpy())
            writer.close()
            acc, precision, recall, f1 = evaluate(tmp_filename)
            log_writer.add_scalars(
                'dev_task' + str(task_id), {
                    'accuracy': acc,
                    'precision': precision,
                    'recall': recall,
                    'f1': f1
                }, epoch)
            print(
                'dev acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%%'
                % (acc, precision, recall, f1))

            if dev_f1[task_id] < f1:
                dev_f1[task_id] = f1
                dev_acc[task_id] = acc
                dev_precision[task_id] = precision
                dev_recall[task_id] = recall
                best_epoch[task_id] = epoch

                # Evaluate on test data when better performance detected
                tmp_filename = 'tmp/%s_test%d%d' % (str(uid), epoch, task_id)
                writer.start(tmp_filename)

                for batch in bionlp_data.iterate_batch_variable(
                        data_tests[task_id], batch_size):
                    word, char, pos, chunk, labels, masks, lengths, elmo_embedding = batch
                    preds, _ = network.decode(
                        task_id,
                        word,
                        char,
                        target=labels,
                        mask=masks,
                        leading_symbolic=bionlp_data.NUM_SYMBOLIC_TAGS,
                        elmo_word=elmo_embedding)
                    writer.write(word.data.cpu().numpy(),
                                 pos.data.cpu().numpy(),
                                 chunk.data.cpu().numpy(),
                                 preds.cpu().numpy(),
                                 labels.data.cpu().numpy(),
                                 lengths.cpu().numpy())
                writer.close()
                test_acc[task_id], test_precision[task_id], test_recall[
                    task_id], test_f1[task_id] = evaluate(tmp_filename)
                log_writer.add_scalars(
                    'test_task' + str(task_id), {
                        'accuracy': test_acc[task_id],
                        'precision': test_precision[task_id],
                        'recall': test_recall[task_id],
                        'f1': test_f1[task_id]
                    }, epoch)

            print(
                "================================================================================"
            )
            print("dataset: %s" % data_names[task_id])
            print(
                "best dev  acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
                % (dev_acc[task_id], dev_precision[task_id],
                   dev_recall[task_id], dev_f1[task_id], best_epoch[task_id]))
            print(
                "best test acc: %.2f%%, precision: %.2f%%, recall: %.2f%%, F1: %.2f%% (epoch: %d)"
                %
                (test_acc[task_id], test_precision[task_id],
                 test_recall[task_id], test_f1[task_id], best_epoch[task_id]))
            print(
                "================================================================================\n"
            )

            if epoch % schedule == 0:
                # lr = learning_rate / (1.0 + epoch * lr_decay)
                # optim = SGD(network.parameters(), lr=lr, momentum=momentum, weight_decay=gamma, nesterov=True)
                lr = lr * lr_decay
                optim.param_groups[0]['lr'] = lr

    # log_writer.export_scalars_to_json("./all_scalars.json")
    log_writer.close()
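
The manual learning-rate decay near the end of the training loop (lr = lr * lr_decay on schedule boundaries, written back into optim.param_groups) can also be expressed with a PyTorch scheduler. A minimal, self-contained sketch with toy values standing in for schedule and lr_decay:

import torch
from torch.optim import RMSprop
from torch.optim.lr_scheduler import StepLR

net = torch.nn.Linear(4, 2)                        # toy stand-in for the network above
optim = RMSprop(net.parameters(), lr=0.01, alpha=0.95, momentum=0.9)
scheduler = StepLR(optim, step_size=3, gamma=0.5)  # step_size ~ schedule, gamma ~ lr_decay
for epoch in range(1, 10):
    # ... one epoch of training would go here ...
    scheduler.step()  # decays the lr of every param group by gamma every 3 epochs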
Example #6
class DMAQ_qattenLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "dmaq":
                self.mixer = DMAQer(args)
            elif args.mixer == 'dmaq_qatten':
                self.mixer = DMAQ_QattenMixer(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        self.optimiser = RMSprop(params=self.params,
                                 lr=args.lr,
                                 alpha=args.optim_alpha,
                                 eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.n_actions = self.args.n_actions

    def sub_train(self,
                  batch: EpisodeBatch,
                  t_env: int,
                  episode_num: int,
                  mac,
                  mixer,
                  optimiser,
                  params,
                  show_demo=False,
                  save_data=None):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]
        actions_onehot = batch["actions_onehot"][:, :-1]

        # Calculate estimated Q-Values
        mac_out = []
        mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3,
                                        index=actions).squeeze(
                                            3)  # Remove the last dim

        x_mac_out = mac_out.clone().detach()
        x_mac_out[avail_actions == 0] = -9999999
        max_action_qvals, max_action_index = x_mac_out[:, :-1].max(dim=3)

        max_action_index = max_action_index.detach().unsqueeze(3)
        is_max_action = (max_action_index == actions).int().float()

        if show_demo:
            q_i_data = chosen_action_qvals.detach().cpu().numpy()
            q_data = (max_action_qvals -
                      chosen_action_qvals).detach().cpu().numpy()
            # self.logger.log_stat('agent_1_%d_q_1' % save_data[0], np.squeeze(q_data)[0], t_env)
            # self.logger.log_stat('agent_2_%d_q_2' % save_data[1], np.squeeze(q_data)[1], t_env)

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)

        # We don't need the first timestep's Q-value estimate for calculating targets
        target_mac_out = th.stack(target_mac_out[1:],
                                  dim=1)  # Concat across time

        # Mask out unavailable actions
        target_mac_out[avail_actions[:, 1:] == 0] = -9999999

        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
            target_chosen_qvals = th.gather(target_mac_out, 3,
                                            cur_max_actions).squeeze(3)
            target_max_qvals = target_mac_out.max(dim=3)[0]
            target_next_actions = cur_max_actions.detach()

            cur_max_actions_onehot = th.zeros(
                cur_max_actions.squeeze(3).shape + (self.n_actions, )).to(
                    self.args.device)  # .cuda() #TODO
            cur_max_actions_onehot = cur_max_actions_onehot.scatter_(
                3, cur_max_actions, 1)
        else:
            # Calculate the Q-Values necessary for the target
            target_mac_out = []
            self.target_mac.init_hidden(batch.batch_size)
            for t in range(batch.max_seq_length):
                target_agent_outs = self.target_mac.forward(batch, t=t)
                target_mac_out.append(target_agent_outs)
            # We don't need the first timestep's Q-value estimate for calculating targets
            target_mac_out = th.stack(target_mac_out[1:],
                                      dim=1)  # Concat across time
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Mix
        if mixer is not None:
            if self.args.mixer == "dmaq_qatten":
                ans_chosen, q_attend_regs, head_entropies = \
                    mixer(chosen_action_qvals, batch["state"][:, :-1], is_v=True)
                ans_adv, _, _ = mixer(chosen_action_qvals,
                                      batch["state"][:, :-1],
                                      actions=actions_onehot,
                                      max_q_i=max_action_qvals,
                                      is_v=False)
                chosen_action_qvals = ans_chosen + ans_adv
            else:
                ans_chosen = mixer(chosen_action_qvals,
                                   batch["state"][:, :-1],
                                   is_v=True)
                ans_adv = mixer(chosen_action_qvals,
                                batch["state"][:, :-1],
                                actions=actions_onehot,
                                max_q_i=max_action_qvals,
                                is_v=False)
                chosen_action_qvals = ans_chosen + ans_adv

            if self.args.double_q:
                if self.args.mixer == "dmaq_qatten":
                    target_chosen, _, _ = self.target_mixer(
                        target_chosen_qvals, batch["state"][:, 1:], is_v=True)
                    target_adv, _, _ = self.target_mixer(
                        target_chosen_qvals,
                        batch["state"][:, 1:],
                        actions=cur_max_actions_onehot,
                        max_q_i=target_max_qvals,
                        is_v=False)
                    target_max_qvals = target_chosen + target_adv
                else:
                    target_chosen = self.target_mixer(target_chosen_qvals,
                                                      batch["state"][:, 1:],
                                                      is_v=True)
                    target_adv = self.target_mixer(
                        target_chosen_qvals,
                        batch["state"][:, 1:],
                        actions=cur_max_actions_onehot,
                        max_q_i=target_max_qvals,
                        is_v=False)
                    target_max_qvals = target_chosen + target_adv
            else:
                target_max_qvals = self.target_mixer(target_max_qvals,
                                                     batch["state"][:, 1:],
                                                     is_v=True)

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.args.gamma * (1 -
                                               terminated) * target_max_qvals

        if show_demo:
            tot_q_data = chosen_action_qvals.detach().cpu().numpy()
            tot_target = targets.detach().cpu().numpy()
            print('action_pair_%d_%d' % (save_data[0], save_data[1]),
                  np.squeeze(q_data[:, 0]), np.squeeze(q_i_data[:, 0]),
                  np.squeeze(tot_q_data[:, 0]), np.squeeze(tot_target[:, 0]))
            self.logger.log_stat(
                'action_pair_%d_%d' % (save_data[0], save_data[1]),
                np.squeeze(tot_q_data[:, 0]), t_env)
            return

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        if self.args.mixer == "dmaq_qatten":
            loss = (masked_td_error**2).sum() / mask.sum() + q_attend_regs
        else:
            loss = (masked_td_error**2).sum() / mask.sum()

        masked_hit_prob = th.mean(is_max_action, dim=2) * mask
        hit_prob = masked_hit_prob.sum() / mask.sum()

        # Optimise
        optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(params,
                                                self.args.grad_norm_clip)
        optimiser.step()

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("hit_prob", hit_prob.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs",
                (masked_td_error.abs().sum().item() / mask_elems), t_env)
            self.logger.log_stat("q_taken_mean",
                                 (chosen_action_qvals * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env

    def train(self,
              batch: EpisodeBatch,
              t_env: int,
              episode_num: int,
              show_demo=False,
              save_data=None):
        self.sub_train(batch,
                       t_env,
                       episode_num,
                       self.mac,
                       self.mixer,
                       self.optimiser,
                       self.params,
                       show_demo=show_demo,
                       save_data=save_data)
        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                th.load("{}/mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
            self.target_mixer.load_state_dict(
                th.load("{}/mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            th.load("{}/opt.th".format(path),
                    map_location=lambda storage, loc: storage))
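
The masked TD-error pattern in sub_train (zero out entries that come from padded timesteps, then average only over real data) in isolation, as a tiny self-contained sketch with toy tensors:

import torch as th

td_error = th.tensor([[0.5, -1.0, 2.0]])            # per-timestep TD errors (toy values)
mask = th.tensor([[1.0, 1.0, 0.0]])                 # 0 marks padded timesteps
masked_td_error = td_error * mask                   # zero out contributions from padding
loss = (masked_td_error ** 2).sum() / mask.sum()    # mean over real timesteps only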
Example #7
    max_epoch = args.max_epoch
    n_update_d = args.n_update_d
    n_update_g = args.n_update_g
    batch_size = args.batch_size
    dim_noise = args.dim_noise
    n_eval_epoch = args.n_eval_epoch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    D = Discriminator(args.loss)
    G = Generator()

    D.to(device)
    G.to(device)

    # train
    optimizer_D = RMSprop(D.parameters(), lr=args.lr)
    optimizer_G = RMSprop(G.parameters(), lr=args.lr)

    face_dataset = FolderDataset(device, face_folder, conditions=False)
    dataloader = DataLoader(face_dataset, batch_size=batch_size, shuffle=True)

    if args.tb:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(logdir='./log/' + args.model_name)

    for epoch in trange(max_epoch):
        loss_epoch_d, loss_epoch_g = 0, 0
        for i, batch_image in enumerate(dataloader):
            if i % n_update_d == 0:
                loss_d, grad_d = update_discriminator(batch_image[0], G, D,
                                                      optimizer_D, args.clip,
Example #8
class inception_v3_agent(BaseAgent):
    def __init__(self, cfg):
        super().__init__(cfg)
        self.cfg = cfg
        self.device = get_device()

        # define models
        # inception_v3 input size = ( N x 3 x 299 x 299 )
        self.model = inception_v3(pretrained=True, num_classes=cfg.num_classes)

        # define data_loader
        self.dataset = inception_data(cfg).get_dataset()
        tr_size = int(cfg.train_test_ratio * len(self.dataset))
        te_size = len(self.dataset) - tr_size
        tr_dataset, te_dataset = random_split(self.dataset, [tr_size, te_size])
        self.tr_loader = DataLoader(tr_dataset,
                                    batch_size=cfg.bs,
                                    shuffle=cfg.data_shuffle,
                                    num_workers=cfg.num_workers)
        self.te_loader = DataLoader(te_dataset,
                                    batch_size=cfg.bs,
                                    shuffle=cfg.data_shuffle,
                                    num_workers=cfg.num_workers)

        # define loss
        self.loss = torch.tensor(0)
        self.criterion = CrossEntropyLoss()

        # define optimizers for both generator and discriminator
        self.optimizer = RMSprop(self.model.parameters(), lr=cfg.lr)

        # initialize counter
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_metric = 0
        self.best_info = ""

        # set cuda flag
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.cfg.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda & self.cfg.cuda

        # set the manual seed for torch
        self.manual_seed = self.cfg.seed
        if self.cuda:
            torch.cuda.manual_seed(self.manual_seed)
            self.model = self.model.to(self.device)
            if self.cfg.data_parallel:
                self.model = nn.DataParallel(self.model)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
        else:
            self.model = self.model.to(self.device)
            torch.manual_seed(self.manual_seed)
            self.logger.info("Program will run on *****CPU*****\n")

        # Model Loading from cfg if not found start from scratch.
        self.exp_dir = os.path.join('./experiments', cfg.exp_name)
        self.load_checkpoint(self.cfg.checkpoint_filename)
        # Summary Writer
        self.summary_writer = SummaryWriter(
            log_dir=os.path.join(self.exp_dir, 'summaries'))

    def load_checkpoint(self, file_name):
        """
        Latest checkpoint loader
        :param file_name: name of the checkpoint file
        :return:
        """
        try:
            self.logger.info("Loading checkpoint '{}'".format(file_name))
            checkpoint = torch.load(file_name, map_location=self.device)

            self.current_epoch = checkpoint['epoch']
            self.current_iteration = checkpoint['iteration']
            self.model.load_state_dict(checkpoint['model'], strict=False)
            self.optimizer.load_state_dict(checkpoint['optimizer'])

            info = "Checkpoint loaded successfully from "
            self.logger.info(
                info + "'{}' at (epoch {}) at (iteration {})\n".format(
                    file_name, checkpoint['epoch'], checkpoint['iteration']))

        except OSError as e:
            self.logger.info("Checkpoint not found in '{}'".format(file_name))
            self.logger.info("**First time to train**")

    def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best=False):
        """
        Checkpoint saver
        :param file_name: name of the checkpoint file
        :param is_best: boolean flag to indicate whether current 
                        checkpoint's accuracy is the best so far
        :return:
        """
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

        # save the state
        checkpoint_dir = os.path.join(self.exp_dir, 'checkpoints')
        torch.save(state, os.path.join(checkpoint_dir, file_name))

        if is_best:
            shutil.copyfile(os.path.join(checkpoint_dir, file_name),
                            os.path.join(checkpoint_dir, 'best.pt'))

    def run(self):
        """
        The main operator
        :return:
        """
        try:
            if self.cfg.mode == 'train':
                self.train()
            elif self.cfg.mode == 'predict':
                self.predict()
            else:
                self.logger.info("Invalid 'mode' value in cfg file")
                raise ValueError("cfg.mode must be 'train' or 'predict'")

        except KeyboardInterrupt:
            self.logger.info("You have entered CTRL+C.. Wait to finalize")

    def train(self):
        """
        Main training loop
        :return:
        """
        self.validate()
        for e in range(1, self.cfg.epochs + 1):
            self.current_epoch = e
            self.train_one_epoch()
            self.validate()
        print(self.best_info)

    def train_one_epoch(self):
        """
        One epoch of training
        :return:
        """
        self.model.train()
        for batch_idx, (imgs, labels) in enumerate(self.tr_loader):
            imgs, labels = imgs.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()

            outputs, aux_outputs = self.model(imgs).values()
            loss1 = self.criterion(outputs, labels)
            loss2 = self.criterion(aux_outputs, labels)
            self.loss = loss1 + 0.3 * loss2

            _, preds = torch.max(outputs, 1)
            acc = preds.eq(labels.view_as(preds)).sum().item() / self.cfg.bs

            self.loss.backward()
            self.optimizer.step()

            self.summary_writer.add_scalars(
                'scalar_group', {
                    'loss_end': loss1.item(),
                    'loss_aux': loss2.item(),
                    'loss_total': self.loss.item(),
                    'accuracy': acc
                }, self.current_iteration)

            if batch_idx % self.cfg.log_interval == 0:
                info_1 = 'Epochs {} [{}/{} ({:.0f}%)] | Loss: {:.6f}'.format(
                    self.current_epoch, batch_idx * len(imgs),
                    len(self.tr_loader.dataset),
                    100. * batch_idx / len(self.tr_loader), self.loss.item())
                info_2 = 'Batch Accuracy : {:.2f}'.format(acc)
                self.logger.info('{} | {}'.format(info_1, info_2))
                self.save_checkpoint('{}_epoch{}_iter{}.pt'.format(
                    self.cfg.exp_name, self.current_epoch,
                    self.current_iteration))
            self.current_iteration += 1

    def validate(self):
        """
        One cycle of model validation
        :return:
        """
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for batch_idx, (imgs, labels) in enumerate(self.te_loader):
                imgs, labels = imgs.to(self.device), labels.to(self.device)

                outputs, aux_outputs = self.model(imgs).values()
                loss1 = self.criterion(outputs, labels)
                loss2 = self.criterion(aux_outputs, labels)
                test_loss += loss1 + 0.3 * loss2

                # get the index of the max log-probability
                _, preds = torch.max(outputs, 1)
                correct += preds.eq(labels.view_as(preds)).sum().item()

        test_loss /= len(self.te_loader)
        acc = correct / (len(self.te_loader) * self.cfg.bs)
        self.logger.info(
            'Test: Avg loss:{:.4f}, Accuracy:{}/{} ({:.2f}%)\n'.format(
                test_loss, correct,
                len(self.te_loader) * self.cfg.bs, 100 * acc))
        if self.best_metric <= acc:
            self.best_metric = acc
            self.best_info = 'Best: {}_epoch{}_iter{}.pt'.format(
                self.cfg.exp_name, self.current_epoch,
                self.current_iteration - 1)

    def predict(self):
        try:
            from tkinter.filedialog import askdirectory
            from glob import glob
            directory = askdirectory(title="select a directory")
            fn_list = glob(directory + '/*.jpg')
        except ImportError:
            from glob import glob
            fn_list = glob(self.cfg.test_img_path + '/*.jpg')

        result = {k: 0 for k in self.dataset.classes}
        t = time.time()
        for fn in fn_list:
            img = read_image(fn, size=(299, 299))
            img = img.to(self.device)
            self.model.eval()
            output = self.model(img)

            _, pred = torch.max(output, 1)
            res = self.dataset.idx_to_class[pred.item()]
            print("{} : {}".format(fn, res))
            result[res] += 1
        print("result : ", result)
        print('process spend : {} sec for {} images'.format(
            time.time() - t, len(fn_list)))

    def finalize(self):
        """
        Finalizes all operations of the two main components of the process:
        the operator and the data loader
        :return:
        """
        pass
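
If the inception_v3 used above is torchvision's implementation, its forward pass in training mode (with aux_logits enabled) returns an InceptionOutputs namedtuple, which unpacks directly rather than via .values(); in eval mode only the main logits are returned. A sketch under that assumption:

import torch
from torchvision.models import inception_v3

model = inception_v3(aux_logits=True)            # assumption: torchvision's inception_v3
model.train()                                    # aux logits are only produced in train mode
imgs = torch.randn(2, 3, 299, 299)               # N x 3 x 299 x 299, as noted in the agent
outputs, aux_outputs = model(imgs)               # InceptionOutputs(logits, aux_logits)
loss = outputs.sum() + 0.3 * aux_outputs.sum()   # same 0.3 aux weighting as train_one_epoch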
Example #9
class MAXQLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.mac_params = list(mac.parameters())
        self.params = list(self.mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        assert args.mixer is not None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = QMixer(args)
            elif args.mixer == "qmix_cnn":
                self.mixer = QMixer_CNN(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.mixer_params = list(self.mixer.parameters())
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        # Central Q
        # TODO: Clean this mess up!
        self.central_mac = None
        if self.args.central_mixer in ["ff", "atten"]:
            if self.args.central_loss == 0:
                self.central_mixer = self.mixer
                self.central_mac = self.mac
                self.target_central_mac = self.target_mac
            else:
                if self.args.central_mixer == "ff":
                    self.central_mixer = QMixerCentralFF(
                        args
                    )  # Feedforward network that takes state and agent utils as input
                elif self.args.central_mixer == "atten":
                    self.central_mixer = QMixerCentralAtten(args)
                else:
                    raise Exception("Error with central_mixer")

                assert args.central_mac == "basic_central_mac"
                self.central_mac = mac_REGISTRY[args.central_mac](
                    scheme, args
                )  # Groups aren't used in the CentralBasicController. Little hacky
                self.target_central_mac = copy.deepcopy(self.central_mac)
                self.params += list(self.central_mac.parameters())
        else:
            raise Exception("Error with qCentral")
        self.params += list(self.central_mixer.parameters())
        self.target_central_mixer = copy.deepcopy(self.central_mixer)

        self.optimiser = RMSprop(params=self.params,
                                 lr=args.lr,
                                 alpha=args.optim_alpha,
                                 eps=args.optim_eps)

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.grad_norm = 1
        self.mixer_norm = 1
        self.mixer_norms = deque([1], maxlen=100)

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals_agents = th.gather(mac_out[:, :-1],
                                               dim=3,
                                               index=actions).squeeze(
                                                   3)  # Remove the last dim
        chosen_action_qvals = chosen_action_qvals_agents

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)

        # All timesteps are kept here (the needed slices are taken when computing targets below)
        target_mac_out = th.stack(target_mac_out, dim=1)  # Concat across time

        # Mask out unavailable actions
        target_mac_out[avail_actions[:, :] == 0] = -9999999  # From OG deepmarl

        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
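            # Double Q-learning: the online network selects the greedy action, the target network
            # evaluates it, which reduces the overestimation bias of a plain max over target Q-values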
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            cur_max_action_targets, cur_max_actions = mac_out_detach[:, :].max(
                dim=3, keepdim=True)
            target_max_agent_qvals = th.gather(
                target_mac_out[:, :], 3, cur_max_actions[:, :]).squeeze(3)
        else:
            raise Exception("Use double q")

        # Central MAC stuff
        central_mac_out = []
        self.central_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.central_mac.forward(batch, t=t)
            central_mac_out.append(agent_outs)
        central_mac_out = th.stack(central_mac_out, dim=1)  # Concat over time
        central_chosen_action_qvals_agents = th.gather(
            central_mac_out[:, :-1],
            dim=3,
            index=actions.unsqueeze(4).repeat(
                1, 1, 1, 1, self.args.central_action_embed)).squeeze(
                    3)  # Remove the last dim

        central_target_mac_out = []
        self.target_central_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_central_mac.forward(batch, t=t)
            central_target_mac_out.append(target_agent_outs)
        central_target_mac_out = th.stack(central_target_mac_out[:],
                                          dim=1)  # Concat across time
        # Mask out unavailable actions
        central_target_mac_out[avail_actions[:, :] ==
                               0] = -9999999  # From OG deepmarl
        # Use the Qmix max actions
        central_target_max_agent_qvals = th.gather(
            central_target_mac_out[:, :], 3,
            cur_max_actions[:, :].unsqueeze(4).repeat(
                1, 1, 1, 1, self.args.central_action_embed)).squeeze(3)
        # ---

        # Mix
        chosen_action_qvals = self.mixer(chosen_action_qvals,
                                         batch["state"][:, :-1])
        target_max_qvals = self.target_central_mixer(
            central_target_max_agent_qvals[:, 1:], batch["state"][:, 1:])

        # Calculate 1-step Q-Learning targets
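        # i.e. y_t = r_t + gamma * (1 - terminated_t) * Q_tot^target(s_{t+1}, a*), with a* the online argmax above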
        targets = rewards + self.args.gamma * (1 -
                                               terminated) * target_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - (targets.detach()))

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Training central Q
        central_chosen_action_qvals = self.central_mixer(
            central_chosen_action_qvals_agents, batch["state"][:, :-1])
        central_td_error = (central_chosen_action_qvals - targets.detach())
        central_mask = mask.expand_as(central_td_error)
        central_masked_td_error = central_td_error * central_mask
        central_loss = (central_masked_td_error**2).sum() / mask.sum()

        # QMIX loss with weighting
        ws = th.ones_like(td_error) * self.args.w
        if self.args.hysteretic_qmix:  # OW-QMIX
            ws = th.where(td_error < 0,
                          th.ones_like(td_error) * 1,
                          ws)  # Target is greater than current max
            w_to_use = ws.mean().item()  # For logging
        else:  # CW-QMIX
            is_max_action = (actions == cur_max_actions[:, :-1]).min(dim=2)[0]
            max_action_qtot = self.target_central_mixer(
                central_target_max_agent_qvals[:, :-1], batch["state"][:, :-1])
            qtot_larger = targets > max_action_qtot
            ws = th.where(is_max_action | qtot_larger,
                          th.ones_like(td_error) * 1,
                          ws)  # Target is greater than current max
            w_to_use = ws.mean().item()  # Average of ws for logging

        qmix_loss = (ws.detach() * (masked_td_error**2)).sum() / mask.sum()

        # The weightings for the different losses aren't used (they are always set to 1)
        loss = self.args.qmix_loss * qmix_loss + self.args.central_loss * central_loss

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()

        # Logging
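        # Track the global L2 gradient norms of the agent and mixer parameters separately (before clipping)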
        agent_norm = 0
        for p in self.mac_params:
            param_norm = p.grad.data.norm(2)
            agent_norm += param_norm.item()**2
        agent_norm = agent_norm**(1. / 2)

        mixer_norm = 0
        for p in self.mixer_params:
            param_norm = p.grad.data.norm(2)
            mixer_norm += param_norm.item()**2
        mixer_norm = mixer_norm**(1. / 2)
        self.mixer_norm = mixer_norm
        self.mixer_norms.append(mixer_norm)

        grad_norm = th.nn.utils.clip_grad_norm_(self.params,
                                                self.args.grad_norm_clip)
        self.grad_norm = grad_norm

        self.optimiser.step()

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("qmix_loss", qmix_loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            self.logger.log_stat("mixer_norm", mixer_norm, t_env)
            self.logger.log_stat("agent_norm", agent_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs",
                (masked_td_error.abs().sum().item() / mask_elems), t_env)
            self.logger.log_stat("q_taken_mean",
                                 (chosen_action_qvals * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("central_loss", central_loss.item(), t_env)
            self.logger.log_stat("w_to_use", w_to_use, t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        if self.central_mac is not None:
            self.target_central_mac.load_state(self.central_mac)
        self.target_central_mixer.load_state_dict(
            self.central_mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()
        if self.central_mac is not None:
            self.central_mac.cuda()
            self.target_central_mac.cuda()
        self.central_mixer.cuda()
        self.target_central_mixer.cuda()

    # TODO: Model saving/loading is out of date!
    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                th.load("{}/mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            th.load("{}/opt.th".format(path),
                    map_location=lambda storage, loc: storage))
Ejemplo n.º 10
0
class QLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = QMixer(args)
            elif args.mixer == "flex_qmix":
                assert args.entity_scheme, "FlexQMixer only available with entity scheme"
                self.mixer = FlexQMixer(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps,
                                 weight_decay=args.weight_decay)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

    def _get_mixer_ins(self, batch, repeat_batch=1):
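        # Returns (mixer inputs for t = 0..T-1, target-mixer inputs for t = 1..T); with the entity
        # scheme these are (entities, entity_mask) tuples rather than flat state vectors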
        if not self.args.entity_scheme:
            return (batch["state"][:, :-1].repeat(repeat_batch, 1, 1, 1),
                    batch["state"][:, 1:])
        else:
            entities = []
            bs, max_t, ne, ed = batch["entities"].shape
            entities.append(batch["entities"])
            if self.args.entity_last_action:
                last_actions = th.zeros(bs, max_t, ne, self.args.n_actions,
                                        device=batch.device,
                                        dtype=batch["entities"].dtype)
                last_actions[:, 1:, :self.args.n_agents] = batch["actions_onehot"][:, :-1]
                entities.append(last_actions)

            entities = th.cat(entities, dim=3)
            return ((entities[:, :-1].repeat(repeat_batch, 1, 1, 1),
                     batch["entity_mask"][:, :-1].repeat(repeat_batch, 1, 1)),
                    (entities[:, 1:],
                     batch["entity_mask"][:, 1:]))

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        self.mac.init_hidden(batch.batch_size)
        # enable things like dropout on mac and mixer, but not target_mac and target_mixer
        self.mac.train()
        self.mixer.train()
        self.target_mac.eval()
        self.target_mixer.eval()
        if 'imagine' in self.args.agent:
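            # The 'imagine' agent returns Q-values for the real batch plus two imagined-group variants,
            # stacked along the batch dimension, hence the repeat(3, ...) / chunk(3, ...) below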
            all_mac_out, groups = self.mac.forward(batch, t=None, imagine=True)
            # Pick the Q-Values for the actions taken by each agent
            rep_actions = actions.repeat(3, 1, 1, 1)
            all_chosen_action_qvals = th.gather(all_mac_out[:, :-1], dim=3, index=rep_actions).squeeze(3)  # Remove the last dim

            mac_out, moW, moI = all_mac_out.chunk(3, dim=0)
            chosen_action_qvals, caqW, caqI = all_chosen_action_qvals.chunk(3, dim=0)
            if not self.args.mix_imagined:
                caq_imagine = caqW + caqI
            else:
                caq_imagine = th.cat([caqW, caqI], dim=2)
        else:
            mac_out = self.mac.forward(batch, t=None)
            # Pick the Q-Values for the actions taken by each agent
            chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim

        self.target_mac.init_hidden(batch.batch_size)

        target_mac_out = self.target_mac.forward(batch, t=None)
        avail_actions_targ = avail_actions
        target_mac_out = target_mac_out[:, 1:]

        # Mask out unavailable actions
        target_mac_out[avail_actions_targ[:, 1:] == 0] = -9999999  # From OG deepmarl

        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions_targ == 0] = -9999999
            cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Mix
        if self.mixer is not None:
            if 'imagine' in self.args.agent:
                if not self.args.mix_imagined:
                    mix_ins, targ_mix_ins = self._get_mixer_ins(batch,
                                                                repeat_batch=2)
                    mixer_qvals = th.cat([chosen_action_qvals, caq_imagine], dim=0)
                    chosen_action_qvals = self.mixer(mixer_qvals, mix_ins)
                    chosen_action_qvals, caq_imagine = chosen_action_qvals.chunk(2, dim=0)
                else:
                    mix_ins, targ_mix_ins = self._get_mixer_ins(batch)
                    chosen_action_qvals = self.mixer(chosen_action_qvals,
                                                     mix_ins)
                    # don't need last timestep
                    groups = [gr[:, :-1] for gr in groups]
                    caq_imagine = self.mixer(caq_imagine, mix_ins,
                                             imagine_groups=groups)
            else:
                mix_ins, targ_mix_ins = self._get_mixer_ins(batch)
                chosen_action_qvals = self.mixer(chosen_action_qvals, mix_ins)
            target_max_qvals = self.target_mixer(target_max_qvals, targ_mix_ins)

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())
        mask = mask.expand_as(td_error)
        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask
        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error ** 2).sum() / mask.sum()

        if 'imagine' in self.args.agent:
            im_prop = self.args.lmbda
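            # Convex combination of the two TD losses: (1 - lmbda) * standard loss + lmbda * imagined-group loss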
            im_td_error = (caq_imagine - targets.detach())
            im_masked_td_error = im_td_error * mask
            im_loss = (im_masked_td_error ** 2).sum() / mask.sum()
            loss = (1 - im_prop) * loss + im_prop * im_loss

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            if 'imagine' in self.args.agent:
                self.logger.log_stat("im_loss", im_loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env)
            self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item()/(mask_elems * self.args.n_agents), t_env)
            if self.args.gated:
                self.logger.log_stat("gate", self.mixer.gate.cpu().item(), t_env)
            if batch.max_seq_length == 2:
                # We are in a 1-step env. Calculate the max Q-Value for logging
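                # (assumes args.double_q, since mac_out_detach is only defined in that branch)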
                max_agent_qvals = mac_out_detach[:,0].max(dim=2, keepdim=True)[0]
                max_qtots = self.mixer(max_agent_qvals, batch["state"][:,0])
                self.logger.log_stat("max_qtot", max_qtots.mean().item(), t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()
        if hasattr(self, "imagine_mixer"):
            self.imagine_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        if hasattr(self, "imagine_mixer"):
            th.save(self.imagine_mixer.state_dict(), "{}/imagine_mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path, evaluate=False):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if not evaluate:
            if self.mixer is not None:
                self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage))
            if hasattr(self, "imagine_mixer"):
                self.imagine_mixer.load_state_dict(th.load("{}/imagine_mixer.th".format(path), map_location=lambda storage, loc: storage))
            self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
Ejemplo n.º 11
0
class RODELearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger
        self.n_agents = args.n_agents

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = QMixer(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        self.role_mixer = None
        if args.role_mixer is not None:
            if args.role_mixer == "vdn":
                self.role_mixer = VDNMixer()
            elif args.role_mixer == "qmix":
                self.role_mixer = QMixer(args)
            else:
                raise ValueError("Role Mixer {} not recognised.".format(
                    args.role_mixer))
            self.params += list(self.role_mixer.parameters())
            self.target_role_mixer = copy.deepcopy(self.role_mixer)

        self.optimiser = RMSprop(params=self.params,
                                 lr=args.lr,
                                 alpha=args.optim_alpha,
                                 eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.role_interval = args.role_interval
        self.device = self.args.device

        self.role_action_spaces_updated = True

        # action encoder
        self.action_encoder_params = list(self.mac.action_encoder_params())
        self.action_encoder_optimiser = RMSprop(
            params=self.action_encoder_params,
            lr=args.lr,
            alpha=args.optim_alpha,
            eps=args.optim_eps)

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]
        # role_avail_actions = batch["role_avail_actions"]
        roles_shape_o = batch["roles"][:, :-1].shape
        role_at = int(np.ceil(roles_shape_o[1] / self.role_interval))
        role_t = role_at * self.role_interval
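        # role_at is the number of whole role intervals in the episode and role_t the padded length;
        # only the role chosen at the first step of each interval is kept, since roles are reselected once per interval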

        roles_shape = list(roles_shape_o)
        roles_shape[1] = role_t
        roles = th.zeros(roles_shape).to(self.device)
        roles[:, :roles_shape_o[1]] = batch["roles"][:, :-1]
        roles = roles.view(batch.batch_size, role_at, self.role_interval,
                           self.n_agents, -1)[:, :, 0]

        # Calculate estimated Q-Values
        mac_out = []
        role_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs, role_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
            if t % self.role_interval == 0 and t < batch.max_seq_length - 1:
                role_out.append(role_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time
        role_out = th.stack(role_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3,
                                        index=actions).squeeze(
                                            3)  # Remove the last dim
        chosen_role_qvals = th.gather(role_out, dim=3,
                                      index=roles.long()).squeeze(3)

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        target_role_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs, target_role_outs = self.target_mac.forward(
                batch, t=t)
            target_mac_out.append(target_agent_outs)
            if t % self.role_interval == 0 and t < batch.max_seq_length - 1:
                target_role_out.append(target_role_outs)

        target_role_out.append(
            th.zeros(batch.batch_size, self.n_agents,
                     self.mac.n_roles).to(self.device))
        # We don't need the first timesteps Q-Value estimate for calculating targets
        target_mac_out = th.stack(target_mac_out[1:],
                                  dim=1)  # Concat across time
        target_role_out = th.stack(target_role_out[1:], dim=1)

        # Mask out unavailable actions
        target_mac_out[avail_actions[:, 1:] == 0] = -9999999
        # target_mac_out[role_avail_actions[:, 1:] == 0] = -9999999

        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            # mac_out_detach[role_avail_actions == 0] = -9999999
            cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3,
                                         cur_max_actions).squeeze(3)

            role_out_detach = role_out.clone().detach()
            role_out_detach = th.cat(
                [role_out_detach[:, 1:], role_out_detach[:, 0:1]], dim=1)
            cur_max_roles = role_out_detach.max(dim=3, keepdim=True)[1]
            target_role_max_qvals = th.gather(target_role_out, 3,
                                              cur_max_roles).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]
            target_role_max_qvals = target_role_out.max(dim=3)[0]

        # Mix
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals,
                                             batch["state"][:, :-1])
            target_max_qvals = self.target_mixer(target_max_qvals,
                                                 batch["state"][:, 1:])
        if self.role_mixer is not None:
            state_shape_o = batch["state"][:, :-1].shape
            state_shape = list(state_shape_o)
            state_shape[1] = role_t
            role_states = th.zeros(state_shape).to(self.device)
            role_states[:, :state_shape_o[1]] = batch["state"][:, :-1].detach(
            ).clone()
            role_states = role_states.view(batch.batch_size, role_at,
                                           self.role_interval, -1)[:, :, 0]
            chosen_role_qvals = self.role_mixer(chosen_role_qvals, role_states)
            role_states = th.cat([role_states[:, 1:], role_states[:, 0:1]],
                                 dim=1)
            target_role_max_qvals = self.target_role_mixer(
                target_role_max_qvals, role_states)

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.args.gamma * (1 -
                                               terminated) * target_max_qvals
        rewards_shape = list(rewards.shape)
        rewards_shape[1] = role_t
        role_rewards = th.zeros(rewards_shape).to(self.device)
        role_rewards[:, :rewards.shape[1]] = rewards.detach().clone()
        role_rewards = role_rewards.view(batch.batch_size, role_at,
                                         self.role_interval).sum(dim=-1,
                                                                 keepdim=True)
        # role_terminated
        terminated_shape_o = terminated.shape
        terminated_shape = list(terminated_shape_o)
        terminated_shape[1] = role_t
        role_terminated = th.zeros(terminated_shape).to(self.device)
        role_terminated[:, :terminated_shape_o[1]] = terminated.detach().clone(
        )
        role_terminated = role_terminated.view(
            batch.batch_size, role_at, self.role_interval).sum(dim=-1,
                                                               keepdim=True)
        # role targets
        role_targets = role_rewards + self.args.gamma * (
            1 - role_terminated) * target_role_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())
        role_td_error = (chosen_role_qvals - role_targets.detach())

        mask = mask.expand_as(td_error)
        mask_shape = list(mask.shape)
        mask_shape[1] = role_t
        role_mask = th.zeros(mask_shape).to(self.device)
        role_mask[:, :mask.shape[1]] = mask.detach().clone()
        role_mask = role_mask.view(batch.batch_size, role_at,
                                   self.role_interval, -1)[:, :, 0]

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask
        masked_role_td_error = role_td_error * role_mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error**2).sum() / mask.sum()
        role_loss = (masked_role_td_error**2).sum() / role_mask.sum()
        loss += role_loss

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params,
                                                self.args.grad_norm_clip)
        self.optimiser.step()

        pred_obs_loss = None
        pred_r_loss = None
        pred_grad_norm = None
        if self.role_action_spaces_updated:
            # train action encoder
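            # The action encoder is trained to predict next observations and (per-agent repeated) rewards;
            # the reward term is up-weighted by a factor of 10 in pred_loss below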
            no_pred = []
            r_pred = []
            for t in range(batch.max_seq_length):
                no_preds, r_preds = self.mac.action_repr_forward(batch, t=t)
                no_pred.append(no_preds)
                r_pred.append(r_preds)
            no_pred = th.stack(no_pred, dim=1)[:, :-1]  # Concat over time
            r_pred = th.stack(r_pred, dim=1)[:, :-1]
            no = batch["obs"][:, 1:].detach().clone()
            repeated_rewards = batch["reward"][:, :-1].detach().clone(
            ).unsqueeze(2).repeat(1, 1, self.n_agents, 1)

            pred_obs_loss = th.sqrt(((no_pred - no)**2).sum(dim=-1)).mean()
            pred_r_loss = ((r_pred - repeated_rewards)**2).mean()

            pred_loss = pred_obs_loss + 10 * pred_r_loss
            self.action_encoder_optimiser.zero_grad()
            pred_loss.backward()
            pred_grad_norm = th.nn.utils.clip_grad_norm_(
                self.action_encoder_params, self.args.grad_norm_clip)
            self.action_encoder_optimiser.step()

            if t_env > self.args.role_action_spaces_update_start:
                self.mac.update_role_action_spaces()
                if 'noar' in self.args.mac:
                    self.target_mac.role_selector.update_roles(
                        self.mac.n_roles)
                self.role_action_spaces_updated = False
                self._update_targets()
                self.last_target_update_episode = episode_num

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", (loss - role_loss).item(), t_env)
            self.logger.log_stat("role_loss", role_loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            if pred_obs_loss is not None:
                self.logger.log_stat("pred_obs_loss", pred_obs_loss.item(),
                                     t_env)
                self.logger.log_stat("pred_r_loss", pred_r_loss.item(), t_env)
                self.logger.log_stat("action_encoder_grad_norm",
                                     pred_grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs",
                (masked_td_error.abs().sum().item() / mask_elems), t_env)
            self.logger.log_stat("q_taken_mean",
                                 (chosen_action_qvals * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("role_q_taken_mean",
                                 (chosen_role_qvals * role_mask).sum().item() /
                                 (role_mask.sum().item() * self.args.n_agents),
                                 t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        if self.role_mixer is not None:
            self.target_role_mixer.load_state_dict(
                self.role_mixer.state_dict())
        self.target_mac.role_action_spaces_updated = self.role_action_spaces_updated
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        to_cuda(self.mac, self.args.device)
        to_cuda(self.target_mac, self.args.device)
        if self.mixer is not None:
            to_cuda(self.mixer, self.args.device)
            to_cuda(self.target_mixer, self.args.device)
        if self.role_mixer is not None:
            to_cuda(self.role_mixer, self.args.device)
            to_cuda(self.target_role_mixer, self.args.device)

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        if self.role_mixer is not None:
            th.save(self.role_mixer.state_dict(),
                    "{}/role_mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))
        th.save(self.action_encoder_optimiser.state_dict(),
                "{}/action_repr_opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                th.load("{}/mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        if self.role_mixer is not None:
            self.role_mixer.load_state_dict(
                th.load("{}/role_mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            th.load("{}/opt.th".format(path),
                    map_location=lambda storage, loc: storage))
        self.action_encoder_optimiser.load_state_dict(
            th.load("{}/action_repr_opt.th".format(path),
                    map_location=lambda storage, loc: storage))
Ejemplo n.º 12
0
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger
        self.scheme = scheme
        self.n_agents = args.n_agents

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = QMixer(args)
            elif args.mixer == "hqmix":
                self.mixer = HQMixer(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)
            self.meta_params = list(self.mac.parameters())

        self.role_mixer = None
        if args.role_mixer is not None:
            if args.role_mixer == "vdn":
                self.role_mixer = VDNMixer()
            elif args.role_mixer == "qmix" or (args.role_mixer == "hqmix"
                                               and not args.meta_role):
                self.role_mixer = QMixer(args)
            elif args.role_mixer == "hqmix":
                self.role_mixer = HQMixer(args)
            else:
                raise ValueError("Role Mixer {} not recognised.".format(
                    args.role_mixer))
            self.params += list(self.role_mixer.parameters())
            self.target_role_mixer = copy.deepcopy(self.role_mixer)

        self.optimiser = RMSprop(params=self.params,
                                 lr=args.lr,
                                 alpha=args.optim_alpha,
                                 eps=args.optim_eps)
        self.meta_optimiser = RMSprop(params=self.meta_params,
                                      lr=args.meta_lr,
                                      alpha=args.optim_alpha,
                                      eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.role_interval = args.role_interval
        self.device = self.args.device

        self.role_action_spaces_updated = True

        # action encoder
        self.action_encoder_params = list(self.mac.action_encoder_params())
        self.action_encoder_optimiser = RMSprop(
            params=self.action_encoder_params,
            lr=args.lr,
            alpha=args.optim_alpha,
            eps=args.optim_eps)
Ejemplo n.º 13
0
class COMALearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.mac = mac
        self.logger = logger

        self.last_target_update_step = 0
        self.critic_training_steps = 0

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.critic = COMACritic(scheme, args)
        self.target_critic = copy.deepcopy(self.critic)

        self.agent_params = list(mac.parameters())
        self.critic_params = list(self.critic.parameters())
        self.params = self.agent_params + self.critic_params

        self.agent_optimiser = RMSprop(params=self.agent_params,
                                       lr=args.lr,
                                       alpha=args.optim_alpha,
                                       eps=args.optim_eps)
        self.critic_optimiser = RMSprop(params=self.critic_params,
                                        lr=args.critic_lr,
                                        alpha=args.optim_alpha,
                                        eps=args.optim_eps)

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        bs = batch.batch_size
        max_t = batch.max_seq_length
        rewards = batch["reward"][:, :-1]
        # print('rewards shape: ', rewards.shape)
        actions = batch["actions"][:, :]
        # print('actions shape: ', actions.shape)
        terminated = batch["terminated"][:, :-1].float()
        # print('terminated shape: ', terminated.shape)
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"][:, :-1]

        critic_mask = mask.clone()

        mask = mask.repeat(1, 1, self.n_agents).view(-1)

        q_vals, critic_train_stats = self._train_critic(
            batch, rewards, terminated, actions, avail_actions, critic_mask,
            bs, max_t)

        actions = actions[:, :-1]

        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length - 1):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out[avail_actions == 0] = 0
        mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True)
        mac_out[avail_actions == 0] = 0

        # Calculate the counterfactual baseline
        q_vals = q_vals.reshape(-1, self.n_actions)
        pi = mac_out.view(-1, self.n_actions)
        baseline = (pi * q_vals).sum(-1).detach()
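        # Counterfactual baseline: b(s) = sum_a pi(a|s) * Q(s, a), detached so it only reduces variance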

        # Calculate policy grad with mask
        q_taken = th.gather(q_vals, dim=1, index=actions.reshape(-1,
                                                                 1)).squeeze(1)
        pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1,
                                                              1)).squeeze(1)
        pi_taken[mask == 0] = 1.0
        log_pi_taken = th.log(pi_taken)

        advantages = (q_taken - baseline).detach()
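        # The COMA loss below maximises the advantage-weighted log-probability of the taken actions,
        # averaged over unmasked agent-timesteps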

        coma_loss = -((advantages * log_pi_taken) * mask).sum() / mask.sum()

        # Optimise agents
        self.agent_optimiser.zero_grad()
        coma_loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params,
                                                self.args.grad_norm_clip)
        self.agent_optimiser.step()

        if (self.critic_training_steps - self.last_target_update_step
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_step = self.critic_training_steps

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            ts_logged = len(critic_train_stats["critic_loss"])
            for key in [
                    "critic_loss", "critic_grad_norm", "td_error_abs",
                    "q_taken_mean", "target_mean"
            ]:
                self.logger.log_stat(key,
                                     sum(critic_train_stats[key]) / ts_logged,
                                     t_env)

            self.logger.log_stat("advantage_mean",
                                 (advantages * mask).sum().item() /
                                 mask.sum().item(), t_env)
            self.logger.log_stat("coma_loss", coma_loss.item(), t_env)
            self.logger.log_stat("agent_grad_norm", grad_norm, t_env)
            self.logger.log_stat("pi_max",
                                 (pi.max(dim=1)[0] * mask).sum().item() /
                                 mask.sum().item(), t_env)
            self.log_stats_t = t_env

    def _train_critic(self, batch, rewards, terminated, actions, avail_actions,
                      mask, bs, max_t):
        # Optimise critic
        target_q_vals = self.target_critic(batch)[:, :]
        # print('target_q_vals shape: ', target_q_vals.shape)
        targets_taken = th.gather(target_q_vals, dim=3,
                                  index=actions).squeeze(3)
        # print('targets_taken shape ', targets_taken.shape)

        # Calculate td-lambda targets
        targets = build_td_lambda_targets(rewards, terminated, mask,
                                          targets_taken, self.n_agents,
                                          self.args.gamma, self.args.td_lambda)
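        # TD(lambda) targets: a lambda-weighted mixture of n-step returns along the masked episode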
        # print('targets shape: ', targets.shape)

        q_vals = th.zeros_like(target_q_vals)[:, :-1]
        # print('q_vals shape: ', q_vals.shape)

        running_log = {
            "critic_loss": [],
            "critic_grad_norm": [],
            "td_error_abs": [],
            "target_mean": [],
            "q_taken_mean": [],
        }

        for t in reversed(range(rewards.size(1))):
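            # The critic is updated one timestep at a time, sweeping backwards from the end of the episode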
            mask_t = mask[:, t].expand(-1, self.n_agents)
            # print('mask_t shape: ', mask_t.shape)
            if mask_t.sum() == 0:
                continue

            q_t = self.critic(batch, t)
            q_vals[:, t] = q_t.view(bs, self.n_agents, self.n_actions)
            q_taken = th.gather(q_t, dim=3,
                                index=actions[:,
                                              t:t + 1]).squeeze(3).squeeze(1)
            targets_t = targets[:, t]

            td_error = (q_taken - targets_t.detach())

            # 0-out the targets that came from padded data
            masked_td_error = td_error * mask_t

            # Normal L2 loss, take mean over actual data
            loss = (masked_td_error**2).sum() / mask_t.sum()
            self.critic_optimiser.zero_grad()
            loss.backward()
            grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params,
                                                    self.args.grad_norm_clip)
            self.critic_optimiser.step()
            self.critic_training_steps += 1

            running_log["critic_loss"].append(loss.item())
            running_log["critic_grad_norm"].append(grad_norm)
            mask_elems = mask_t.sum().item()
            running_log["td_error_abs"].append(
                (masked_td_error.abs().sum().item() / mask_elems))
            running_log["q_taken_mean"].append(
                (q_taken * mask_t).sum().item() / mask_elems)
            running_log["target_mean"].append(
                (targets_t * mask_t).sum().item() / mask_elems)

        return q_vals, running_log

    def _update_targets(self):
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.critic.cuda()
        self.target_critic.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        th.save(self.critic.state_dict(), "{}/critic.th".format(path))
        th.save(self.agent_optimiser.state_dict(),
                "{}/agent_opt.th".format(path))
        th.save(self.critic_optimiser.state_dict(),
                "{}/critic_opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        self.critic.load_state_dict(
            th.load("{}/critic.th".format(path),
                    map_location=lambda storage, loc: storage))
        # Not quite right but I don't want to save target networks
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.agent_optimiser.load_state_dict(
            th.load("{}/agent_opt.th".format(path),
                    map_location=lambda storage, loc: storage))
        self.critic_optimiser.load_state_dict(
            th.load("{}/critic_opt.th".format(path),
                    map_location=lambda storage, loc: storage))
Ejemplo n.º 14
0
def train_deepq(name,
                env,
                nb_actions,
                Q_network,
                preprocess_fn=None,
                batch_size=32,
                replay_start_size=50000,
                replay_memory_size=50000,
                agent_history_length=4,
                target_network_update_frequency=10000,
                discount_factor=0.99,
                learning_rate=1e-5,
                update_frequency=4,
                initial_exploration=1,
                final_exploration=0.1,
                final_exploration_step=int(1e6),
                nb_timesteps=int(1e7),
                tensorboard_freq=50,
                demo_tensorboard=False):

    #SAVE/LOAD MODEL
    DIRECTORY_MODELS = './models/'
    if not os.path.exists(DIRECTORY_MODELS):
        os.makedirs(DIRECTORY_MODELS)
    PATH_SAVE = DIRECTORY_MODELS + name + '_' + time.strftime('%Y%m%d-%H%M')

    #GPU/CPU
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    print('RUNNING ON', device)

    #TENSORBOARDX
    writer = SummaryWriter(comment=name)

    replay_memory = init_replay_memory(env, replay_memory_size,
                                       replay_start_size, preprocess_fn)

    print('#### TRAINING ####')
    print('see more details on tensorboard')

    done = True  #reset environment
    eps_schedule = ScheduleExploration(initial_exploration, final_exploration,
                                       final_exploration_step)
    Q_network = Q_network.to(device)
    Q_hat = copy.deepcopy(Q_network).to(device)
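    # Q_hat is the frozen target network; it is re-synced with Q_network every
    # target_network_update_frequency timesteps (see the copy at the end of the loop)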
    loss = SmoothL1Loss()
    optimizer = RMSprop(Q_network.parameters(),
                        lr=learning_rate,
                        alpha=0.95,
                        eps=0.01,
                        centered=True)

    episode = 1
    rewards_episode, total_reward_per_episode = list(), list()
    for timestep in tqdm(range(nb_timesteps)):  #tqdm
        #if an episode is ended
        if done:
            total_reward_per_episode.append(np.sum(rewards_episode))
            rewards_episode = list()
            phi_t = env.reset()
            if preprocess_fn:
                phi_t = preprocess_fn(phi_t)

            if (episode % tensorboard_freq == 0):
                assert len(total_reward_per_episode) == tensorboard_freq
                #tensorboard
                writer.add_scalar('rewards/train_reward',
                                  np.mean(total_reward_per_episode), episode)
                total_reward_per_episode = list()
                writer.add_scalar('other/replay_memory_size',
                                  len(replay_memory), episode)
                writer.add_scalar('other/eps_exploration',
                                  eps_schedule.get_eps(), episode)
                if demo_tensorboard:
                    demos, demo_rewards = play(env,
                                               Q_network,
                                               preprocess_fn,
                                               nb_episodes=1,
                                               eps=eps_schedule.get_eps())
                    writer.add_scalar('rewards/demo_reward',
                                      np.mean(demo_rewards), episode)
                    for demo in demos:
                        demo = demo.permute([3, 0, 1, 2]).unsqueeze(0)
                        writer.add_video(name, demo, episode, fps=25)

                #save model
                torch.save(Q_network.state_dict(), PATH_SAVE)

            episode += 1

        a_t = get_action(phi_t, env, Q_network, eps_schedule)

        phi_t_1, r_t, done, info = env.step(a_t)
        rewards_episode.append(r_t)
        if preprocess_fn:
            phi_t_1 = preprocess_fn(phi_t_1)
        replay_memory.push([phi_t, a_t, r_t, phi_t_1, done])
        phi_t = phi_t_1

        #training
        if timestep % update_frequency == 0:
            #get training data
            phi_t_training, actions_training, y = get_training_data(
                Q_hat, replay_memory, batch_size, discount_factor)

            #forward
            phi_t_training = phi_t_training.to(device)
            Q_values = Q_network(phi_t_training)
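            # Keep only Q(s_t, a_t) for the action actually taken, via a one-hot action mask summed over actions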
            mask = torch.zeros([batch_size, nb_actions]).to(device)
            for j in range(len(actions_training)):
                mask[j, actions_training[j]] = 1
            Q_values = Q_values * mask
            Q_values = torch.sum(Q_values, dim=1)
            output = loss(Q_values, y)

            #backward and gradient descent
            optimizer.zero_grad()
            output.backward()
            optimizer.step()

        if timestep % target_network_update_frequency == 0:
            Q_hat = copy.deepcopy(Q_network).to(device)
Ejemplo n.º 15
0
class VisualWeight():
    def __init__(self,
                 model,
                 input_dim,
                 layer_name='all',
                 filter_id=None,
                 epoch=30,
                 ImageNet=False,
                 reshape=None):
        self.model = model
        self.model.eval()
        self.model.to('cpu')
        self.input_dim = input_dim
        self.layer_name = layer_name
        self.filter_id = filter_id
        self.epoch = epoch
        self.ImageNet = ImageNet
        if type(reshape) == tuple:
            self.reshape = reshape
        else:
            self.reshape = None

    def hook_layer(self):
        def hook_function(module, _in, _out):
            if isinstance(self._layer, torch.nn.Linear):
                '''
                    in: (N, x_in)
                    out: (N, x_out)
                    weight: (x_out, x_in)
                '''
                self._output = _out
            elif isinstance(self._layer, torch.nn.Conv2d):
                '''
                    in: (N, C_in, H_in, W_in)
                    out: (N, C_out, H_out, W_out)
                    weight: (C_out, C_in, H_kernel, W_kernel)
                '''
                self._output = _out[0, self._filter_id]

        return self._layer.register_forward_hook(hook_function)

    def _weight(self):
        def _train():
            print('{}: {}'.format(self._name, self._layer))
            if isinstance(self._layer, torch.nn.Conv2d)\
            and self.filter_id is None:
                x = []
                _loss = 0
                self._n = self._layer.weight.size(0)
                for k in range(self._n):
                    self._filter_id = k
                    x.append(self._get_input_for_weight())
                    _loss += self._loss
                self._loss = _loss / self._n
                self._save(x)
            else:
                self._filter_id = self.filter_id
                x = self._get_input_for_weight()
                self._save(x)

        if self.layer_name == 'all':
            for (name, layer) in self.model.named_modules():
                if hasattr(layer, 'weight') and not isinstance(
                        layer, torch.nn.BatchNorm2d):
                    self._name, self._layer = name, layer
                    _train()
        else:
            self._name = self.layer_name
            self._layer = dict(self.model.named_modules())[self.layer_name]
            _train()

    def _get_input_for_weight(self):
        handle = self.hook_layer()
        _msg = "Visual '{}'".format(self._name)
        if isinstance(self._layer, torch.nn.Conv2d):
            _msg += " , filter = {}/{}".format(self._filter_id + 1,
                                               self._layer.weight.size(0))
        # Generate a random image
        if type(self.input_dim) == int:
            random_image = np.uint8(np.random.uniform(0, 1,
                                                      (self.input_dim, )))
        else:
            random_image = np.uint8(
                np.random.uniform(
                    0, 1,
                    (self.input_dim[1], self.input_dim[2], self.input_dim[0])))
        # Process image and return variable
        processed_image = preprocess_image(random_image, self.ImageNet)
        # Define optimizer for the image
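        # Activation maximisation: gradients flow into the input image only; the frozen, eval-mode model weights are not updated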
        #        optimizer = Adam([processed_image], lr=1e-1)
        self.optimizer = RMSprop([processed_image],
                                 lr=1e-2,
                                 alpha=0.9,
                                 eps=1e-10)
        for i in range(self.epoch):
            self.optimizer.zero_grad()
            # Assign create image to a variable to move forward in the model
            self.model.forward(processed_image)
            # Loss function is the mean of the output of the selected layer/filter
            # We try to minimize the mean of the output of that specific filter
            loss = -torch.mean(self._output)
            # Backward
            loss.backward()
            self._loss = loss.item()
            # Update image
            self.optimizer.step()
            # Progress display
            _str = _msg + " | Epoch: {}/{}, Loss = {:.4f}".format(
                i + 1, self.epoch, self._loss)
            sys.stdout.write('\r' + _str)
            sys.stdout.flush()

        handle.remove()
        self.model.zero_grad()
        # Save image
        processed_image.requires_grad_(False)
        created_image = recreate_image(processed_image, self.ImageNet,
                                       self.reshape)
        return created_image

    def _get_input_for_category(self):

        self._n = self.model.train_Y.shape[1]
        images = []
        self._loss = 0
        for c in range(self._n):
            if type(self.input_dim) == int:
                random_image = np.random.uniform(0, 1, (self.input_dim, ))
            else:
                random_image = np.random.uniform(
                    0, 1,
                    (self.input_dim[1], self.input_dim[2], self.input_dim[0]))
            processed_image = preprocess_image(random_image, self.ImageNet)
            optimizer = RMSprop([processed_image],
                                lr=1e-2,
                                alpha=0.9,
                                eps=1e-10)
            label = np.zeros((self._n, ), dtype=np.float32)
            label[c] = 1
            label = torch.from_numpy(label)
            for i in range(self.epoch):
                optimizer.zero_grad()
                output = self.model(processed_image)
                loss = self.model.L(output, label)
                loss.backward()
                _loss = loss.item()
                optimizer.step()
                _msg = "Visual feature for Category {}/{}".format(
                    c + 1, self._n)
                _str = _msg + " | Epoch: {}/{}, Loss = {:.4f}".format(
                    i + 1, self.epoch, _loss)
                sys.stdout.write('\r' + _str)
                sys.stdout.flush()
            self._loss += _loss
            self.model.zero_grad()
            processed_image.requires_grad_(False)
            created_image = recreate_image(processed_image, self.ImageNet,
                                           self.reshape)
            images.append(created_image)
        self._loss /= self._n
        self._layer_name = 'output'
        self._save(images)

    def _save(self, x):
        path = '../save/para/[' + self.model.name + ']/'
        os.makedirs(path, exist_ok=True)
        file = self._layer_name + ' (vis), loss = {:.4f}'.format(self._loss)
        if isinstance(x, list):
            _save_multi_img(x, int(np.sqrt(self._n)), path=path + file)
        else:
            _save_img(x, path=path + file)
        sys.stdout.write('\r')
        sys.stdout.flush()
        print("Visual saved in: " + path + file + " " * 25)
Ejemplo n.º 16
0
class Trainer(object):
    def __init__(self, config_path=None, **kwargs):

        # general
        self.run_name = None

        # code parameters
        self.code_length = None
        self.info_length = None
        self.crc_length = None  # only supports 16 bits in this setup
        self.clipping_val = None  # initialization absolute LLR value
        self.n_states = None

        # training hyperparameters
        self.num_of_minibatches = None
        self.train_minibatch_size = None
        self.train_SNR_start = None
        self.train_SNR_end = None
        self.train_num_SNR = None  # how many equally spaced values, including edges
        self.training_words_factor = None
        self.lr = None  # learning rate
        self.load_from_checkpoint = None  # loads last checkpoint, if exists in the run_name folder
        self.validation_minibatches_frequency = None  # validate every number of minibatches
        self.save_checkpoint_minibatches = None  # save checkpoint every

        # validation hyperparameters
        self.val_minibatch_size = None  # the more the merrier :)
        self.val_SNR_start = None
        self.val_SNR_end = None
        self.val_num_SNR = None  # how many equally spaced values
        self.thresh_errors = None  # monte-carlo error threshold per point

        # seed
        self.noise_seed = None
        self.word_seed = None

        # if any kwargs are passed, initialize the dict with them
        self.initialize_by_kwargs(**kwargs)

        # initializes all none parameters above from config
        self.param_parser(config_path)

        # initializes word and noise generator from seed
        self.rand_gen = np.random.RandomState(self.noise_seed)
        self.word_rand_gen = np.random.RandomState(self.word_seed)

        # initialize matrices, datasets and decoder
        self.start_minibatch = 0
        self.code_pcm, self.code_gm_inner, self.code_gm_outer, self.code_h_outer = load_code_parameters(
            self.code_length, self.crc_length, self.info_length)
        self.det_length = self.info_length + self.crc_length
        self.load_decoder()
        self.initialize_dataloaders()

    def initialize_by_kwargs(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

    def param_parser(self, config_path: str):
        """
        Parse the config, load all attributes into the trainer
        :param config_path: path to config
        """
        if config_path is None:
            config_path = CONFIG_PATH

        with open(config_path) as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)

        # set attribute of Trainer with every config item
        for k, v in self.config.items():
            try:
                if getattr(self, k) is None:
                    setattr(self, k, v)
            except AttributeError:
                pass

        self.weights_dir = os.path.join(WEIGHTS_DIR, self.run_name)
        if not os.path.exists(self.weights_dir) and len(self.weights_dir):
            os.makedirs(self.weights_dir)
            # save config in output dir
            copyfile(config_path, os.path.join(self.weights_dir,
                                               "config.yaml"))

    def get_name(self):
        return self.__name__()

    def load_decoder(self):
        """
        Every trainer must have some base decoder model
        """
        self.decoder = None

    # calculate train loss
    def calc_loss(self, soft_estimation: torch.Tensor,
                  labels: torch.Tensor) -> torch.Tensor:
        """
         Every trainer must have some loss calculation
        """
        pass

    # setup the optimization algorithm
    def deep_learning_setup(self):
        """
        Sets up the optimizer and loss criterion
        """
        self.optimizer = RMSprop(filter(lambda p: p.requires_grad,
                                        self.decoder.parameters()),
                                 lr=self.lr)
        self.criterion = CrossEntropyLoss().to(device=device)

    def initialize_dataloaders(self):
        """
        Sets up the data loader - a generator from which we draw batches, in iterations
        """
        self.snr_range = {
            'train':
            np.linspace(self.train_SNR_start,
                        self.train_SNR_end,
                        num=self.train_num_SNR),
            'val':
            np.linspace(self.val_SNR_start,
                        self.val_SNR_end,
                        num=self.val_num_SNR)
        }
        self.batches_size = {
            'train': self.train_minibatch_size,
            'val': self.val_minibatch_size
        }
        self.channel_dataset = {
            phase: ChannelModelDataset(
                code_length=self.code_length,
                det_length=self.det_length,
                info_length=self.info_length,
                size_per_snr=self.batches_size[phase],
                snr_range=self.snr_range[phase],
                random=self.rand_gen,
                word_rand_gen=self.word_rand_gen,
                code_gm_inner=self.code_gm_inner,
                code_gm_outer=self.code_gm_outer,
                phase=phase,
                training_words_factor=self.training_words_factor,
                n_states=self.n_states)
            for phase in ['train', 'val']
        }
        self.dataloaders = {
            phase: torch.utils.data.DataLoader(self.channel_dataset[phase])
            for phase in ['train', 'val']
        }

    def load_last_checkpoint(self):
        """
        Loads decoder's weights from highest checkpoint in run_name
        """
        print(self.run_name)
        folder = os.path.join(os.path.join(WEIGHTS_DIR, self.run_name))
        names = []
        for file in os.listdir(folder):
            if file.startswith("checkpoint_"):
                names.append(int(file.split('.')[0].split('_')[1]))
        names.sort()
        if len(names) == 0:
            print("No checkpoints in run dir!!!")
            return

        self.start_minibatch = int(names[-1])
        if os.path.isfile(
                os.path.join(WEIGHTS_DIR, self.run_name,
                             f'checkpoint_{self.start_minibatch}.pt')):
            print(f'loading model from minibatch {self.start_minibatch}')
            checkpoint = torch.load(
                os.path.join(WEIGHTS_DIR, self.run_name,
                             f'checkpoint_{self.start_minibatch}.pt'))
            try:
                self.decoder.load_state_dict(checkpoint['model_state_dict'])
            except Exception:
                raise ValueError("Wrong run directory!!!")
        else:
            print(
                f'There is no checkpoint {self.start_minibatch} in run "{self.run_name}", starting from scratch'
            )

    def evaluate(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Monte-Carlo simulation over validation SNRs range
        :return: ber, fer, iterations vectors
        """
        ber_total, fer_total = np.zeros(len(self.snr_range['val'])), np.zeros(
            len(self.snr_range['val']))
        iterations_total = np.zeros(len(self.snr_range['val']))
        with torch.no_grad():
            for snr_ind, snr in enumerate(self.snr_range['val']):
                err_count = 0
                runs_num = 0
                print(f'Starting evaluation at SNR {snr}')
                start = time()
                # either stop when simulated enough errors, or reached a maximum number of runs
                while err_count < self.thresh_errors and runs_num < MAX_RUNS:
                    ber, fer, iterations, current_err_count = self.single_eval(
                        snr_ind)
                    ber_total[snr_ind] += ber
                    fer_total[snr_ind] += fer
                    iterations_total[snr_ind] += iterations
                    err_count += current_err_count
                    runs_num += 1.0

                ber_total[snr_ind] /= runs_num
                fer_total[snr_ind] /= runs_num
                iterations_total[snr_ind] /= runs_num
                print(
                    f'Done. time: {time() - start}, ber: {ber_total[snr_ind]}, fer: {fer_total[snr_ind]}, iterations: {iterations_total[snr_ind]}'
                )

        return ber_total, fer_total, iterations_total

    def single_eval(self, snr_ind: int) -> Tuple[float, float, float, int]:
        """
        Evaluation at a single snr.
        :param snr_ind: index of the SNR in the SNR vector
        :return: ber and fer for batch, average iterations per word and number of errors in current batch
        """
        # draw validation data for this SNR
        transmitted_words, received_words = iter(
            self.channel_dataset['val'][snr_ind])
        transmitted_words = transmitted_words.to(device=device)
        received_words = received_words.to(device=device)

        # decode and calculate accuracy
        decoded_words = self.decoder(received_words, 'val')

        ber, fer, err_indices = calculate_error_rates(decoded_words,
                                                      transmitted_words)
        current_err_count = err_indices.shape[0]
        iterations = self.decoder.get_iterations()

        return ber, fer, iterations, current_err_count

    def train(self):
        """
        Main training loop. Runs in minibatches.
        Evaluates performance over validation SNRs.
        Saves weights every save_checkpoint_minibatches minibatches.
        """
        self.deep_learning_setup()
        self.evaluate()

        # batches loop
        for minibatch in range(self.start_minibatch,
                               self.num_of_minibatches + 1):
            print(f"Minibatch number - {str(minibatch)}")

            current_loss = 0

            # run single train loop
            current_loss += self.run_single_train_loop()

            print(f"Loss {current_loss}")

            # save weights
            if self.save_checkpoint_minibatches and minibatch % self.save_checkpoint_minibatches == 0:
                self.save_checkpoint(current_loss, minibatch)

            # evaluate performance
            if (minibatch + 1) % self.validation_minibatches_frequency == 0:
                self.evaluate()

    def run_single_train_loop(self) -> float:
        # draw words
        transmitted_words, received_words = iter(
            self.channel_dataset['train'][:len(self.snr_range['train'])])
        transmitted_words = transmitted_words.to(device=device)
        received_words = received_words.to(device=device)

        # pass through decoder
        soft_estimation = self.decoder(received_words, 'train')

        # calculate loss
        loss = self.calc_loss(soft_estimation=soft_estimation,
                              labels=transmitted_words)
        loss_val = loss.item()

        # if the loss is NaN, inform the user
        if torch.sum(torch.isnan(loss)):
            print('Nan value')

        # backpropagation
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss_val

    def save_checkpoint(self, current_loss: float, minibatch: int):
        torch.save(
            {
                'minibatch': minibatch,
                'model_state_dict': self.decoder.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': current_loss,
                'lr': self.lr
            },
            os.path.join(self.weights_dir,
                         'checkpoint_' + str(minibatch) + '.pt'))
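
The checkpoint written above also stores the optimizer state; a minimal sketch (toy decoder and path are placeholders) of restoring both the model weights and the RMSprop state so its running averages survive a resume:

import torch
import torch.nn as nn
from torch.optim import RMSprop

decoder = nn.Linear(16, 4)   # stand-in for the real decoder
optimizer = RMSprop(filter(lambda p: p.requires_grad, decoder.parameters()), lr=1e-3)

# save, mirroring the checkpoint dict written by save_checkpoint above
torch.save({'minibatch': 100,
            'model_state_dict': decoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()},
           'checkpoint_100.pt')

# restore model weights plus RMSprop buffers (square averages, step counts)
checkpoint = torch.load('checkpoint_100.pt')
decoder.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_minibatch = checkpoint['minibatch'] + 1   # resume from the next minibatch
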
Ejemplo n.º 17
0
class DQNTrainer:
    def __init__(self, params, model_path):
        self.params = params
        self.model_path = model_path
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.current_q_net = DQN(input_shape=1,
                                 num_of_actions=get_action_space())
        self.current_q_net.to(self.device)
        self.target_q_net = DQN(input_shape=1,
                                num_of_actions=get_action_space())
        self.target_q_net.to(self.device)
        self.optimizer = RMSprop(self.current_q_net.parameters(),
                                 lr=self.params.lr)
        self.replay_memory = ReplayMemory(self.params.memory_capacity)
        game = "Breakout-ram-v0"
        env = gym.make(game)
        self.environment = EnvironmentWrapper(env, self.params.skip_steps)

    def run(self):
        state = torch.tensor(self.environment.reset(),
                             device=self.device,
                             dtype=torch.float32)
        self._update_target_q_net()
        total_reward = 0
        for step in range(int(self.params.num_of_steps)):
            q_value = self.current_q_net(torch.stack([state]))
            action_index, action = get_action(q_value,
                                              train=True,
                                              step=step,
                                              params=self.params,
                                              device=self.device)
            next_state, reward, done = self.environment.step(action)
            next_state = torch.tensor(next_state,
                                      device=self.device,
                                      dtype=torch.float32)
            self.replay_memory.add(state, action_index, reward, next_state,
                                   done)
            state = next_state
            total_reward += reward
            if done:
                state = torch.tensor(self.environment.reset(),
                                     device=self.device,
                                     dtype=torch.float32)
            if len(self.replay_memory.memory) > self.params.batch_size:
                loss = self._update_current_q_net()
                print('Update: {}. Loss: {}. Score: {}'.format(
                    step, loss, total_reward))
            if step % self.params.target_update_freq == 0:
                self._update_target_q_net()
        torch.save(self.target_q_net.state_dict(), self.model_path)

    def _update_current_q_net(self):
        batch = self.replay_memory.sample(self.params.batch_size)
        states, actions, rewards, next_states, dones = batch

        states = torch.stack(states)
        next_states = torch.stack(next_states)
        actions = torch.stack(actions).view(-1, 1)
        rewards = torch.tensor(rewards, device=self.device)
        dones = torch.tensor(dones, device=self.device, dtype=torch.float32)

        q_values = self.current_q_net(states).gather(1, actions)
        # Detach so that no gradients flow back into the target network
        next_q_values = self.target_q_net(next_states).max(1)[0].detach()

        expected_q_values = rewards + self.params.discount_factor * next_q_values * (
            1 - dones)
        loss = F.smooth_l1_loss(q_values, expected_q_values.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _update_target_q_net(self):
        self.target_q_net.load_state_dict(self.current_q_net.state_dict())
Ejemplo n.º 18
0
class DDQN_Agent:

    def __init__(self, args, exp_model, logging_func):
        self.args = args

        # Exploration Model
        self.exp_model = exp_model

        self.log = logging_func["log"]

        # Experience Replay
        if self.args.set_replay:
            self.replay = ExpReplaySet(10, 10, exp_model, args, priority=False)
        else:
            self.replay = ExpReplay(args.exp_replay_size, args.stale_limit, exp_model, args, priority=self.args.prioritized)

        # DQN and Target DQN
        model = get_models(args.model)
        self.dqn = model(actions=args.actions)
        self.target_dqn = model(actions=args.actions)

        dqn_params = 0
        for weight in self.dqn.parameters():
            weight_params = 1
            for s in weight.size():
                weight_params *= s
            dqn_params += weight_params
        print("DQN has {:,} parameters.".format(dqn_params))

        self.target_dqn.eval()

        if args.gpu:
            print("Moving models to GPU.")
            self.dqn.cuda()
            self.target_dqn.cuda()

        # Optimizer
        # self.optimizer = Adam(self.dqn.parameters(), lr=args.lr)
        self.optimizer = RMSprop(self.dqn.parameters(), lr=args.lr)

        self.T = 0
        self.target_sync_T = -self.args.t_max

    def sync_target_network(self):
        for target, source in zip(self.target_dqn.parameters(), self.dqn.parameters()):
            target.data = source.data

    def act(self, state, epsilon, exp_model, evaluation=False):
        # self.T += 1
        self.dqn.eval()
        orig_state = state[:, :, -1:]
        state = torch.from_numpy(state).float().transpose_(0, 2).unsqueeze(0)
        q_values = self.dqn(Variable(state, volatile=True)).cpu().data[0]
        q_values_numpy = q_values.numpy()

        extra_info = {}

        if self.args.optimistic_init and not evaluation:
            q_values_pre_bonus = np.copy(q_values_numpy)
            if not self.args.ucb:
                for a in range(self.args.actions):
                    _, info = exp_model.bonus(orig_state, a, dont_remember=True)
                    action_pseudo_count = info["Pseudo_Count"]
                    # TODO: Log the optimism bonuses
                    optimism_bonus = self.args.optimistic_scaler / np.power(action_pseudo_count + 0.01, self.args.bandit_p)
                    if self.args.tb and self.T % self.args.tb_interval == 0:
                        self.log("Bandit/Action_{}".format(a), optimism_bonus, step=self.T)
                    q_values[a] += optimism_bonus
            else:
                action_counts = []
                for a in range(self.args.actions):
                    _, info = exp_model.bonus(orig_state, a, dont_remember=True)
                    action_pseudo_count = info["Pseudo_Count"]
                    action_counts.append(action_pseudo_count)
                total_count = sum(action_counts)
                for ai, a in enumerate(action_counts):
                    # TODO: Log the optimism bonuses
                    optimism_bonus = self.args.optimistic_scaler * np.sqrt(2 * np.log(max(1, total_count)) / (a + 0.01))
                    self.log("Bandit/UCB/Action_{}".format(ai), optimism_bonus, step=self.T)
                    q_values[ai] += optimism_bonus

            extra_info["Action_Bonus"] = q_values_numpy - q_values_pre_bonus

        extra_info["Q_Values"] = q_values_numpy

        if np.random.random() < epsilon:
            action = np.random.randint(low=0, high=self.args.actions)
        else:
            action = q_values.max(0)[1][0]  # Torch...

        extra_info["Action"] = action

        return action, extra_info

    def experience(self, state, action, reward, state_next, steps, terminated, pseudo_reward=0, density=1, exploring=False):
        if not exploring:
            self.T += 1
        self.replay.Add_Exp(state, action, reward, state_next, steps, terminated, pseudo_reward, density)

    def end_of_trajectory(self):
        self.replay.end_of_trajectory()

    def train(self):

        if self.T - self.target_sync_T > self.args.target:
            self.sync_target_network()
            self.target_sync_T = self.T

        info = {}

        for _ in range(self.args.iters):
            self.dqn.eval()

            # TODO: Use a named tuple for experience replay
            n_step_sample = 1
            if np.random.random() < self.args.n_step_mixing:
                n_step_sample = self.args.n_step
            batch, indices, is_weights = self.replay.Sample_N(self.args.batch_size, n_step_sample, self.args.gamma)
            columns = list(zip(*batch))

            states = Variable(torch.from_numpy(np.array(columns[0])).float().transpose_(1, 3))
            actions = Variable(torch.LongTensor(columns[1]))
            terminal_states = Variable(torch.FloatTensor(columns[5]))
            rewards = Variable(torch.FloatTensor(columns[2]))
            # Have to clip rewards for DQN
            rewards = torch.clamp(rewards, -1, 1)
            steps = Variable(torch.FloatTensor(columns[4]))
            new_states = Variable(torch.from_numpy(np.array(columns[3])).float().transpose_(1, 3))

            target_dqn_qvals = self.target_dqn(new_states).cpu()
            # Make a new variable with those values so that these are treated as constants
            target_dqn_qvals_data = Variable(target_dqn_qvals.data)

            q_value_targets = (Variable(torch.ones(terminal_states.size()[0])) - terminal_states)
            inter = Variable(torch.ones(terminal_states.size()[0]) * self.args.gamma)
            # print(steps)
            q_value_targets = q_value_targets * torch.pow(inter, steps)
            if self.args.double:
                # Double Q Learning
                new_states_qvals = self.dqn(new_states).cpu()
                new_states_qvals_data = Variable(new_states_qvals.data)
                q_value_targets = q_value_targets * target_dqn_qvals_data.gather(1, new_states_qvals_data.max(1)[1])
            else:
                q_value_targets = q_value_targets * target_dqn_qvals_data.max(1)[0]
            q_value_targets = q_value_targets + rewards

            self.dqn.train()
            if self.args.gpu:
                actions = actions.cuda()
                q_value_targets = q_value_targets.cuda()
            model_predictions = self.dqn(states).gather(1, actions.view(-1, 1))

            # info = {}

            td_error = model_predictions - q_value_targets
            info["TD_Error"] = td_error.mean().data[0]

            # Update the priorities
            if not self.args.density_priority:
                self.replay.Update_Indices(indices, td_error.cpu().data.numpy(), no_pseudo_in_priority=self.args.count_td_priority)

            # If using prioritised we need to weight the td_error
            if self.args.prioritized and self.args.prioritized_is:
                # print(td_error)
                weights_tensor = torch.from_numpy(is_weights).float()
                weights_tensor = Variable(weights_tensor)
                if self.args.gpu:
                    weights_tensor = weights_tensor.cuda()
                # print(weights_tensor)
                td_error = td_error * weights_tensor
            l2_loss = (td_error).pow(2).mean()
            info["Loss"] = l2_loss.data[0]

            # Update
            self.optimizer.zero_grad()
            l2_loss.backward()

            # Taken from pytorch clip_grad_norm
            # Remove once the pip version it up to date with source
            gradient_norm = clip_grad_norm(self.dqn.parameters(), self.args.clip_value)
            if gradient_norm is not None:
                info["Norm"] = gradient_norm

            self.optimizer.step()

            if "States" in info:
                states_trained = info["States"]
                info["States"] = states_trained + columns[0]
            else:
                info["States"] = columns[0]

        # Pad out the states to be of size batch_size by repeating the first state
        if len(info["States"]) < self.args.batch_size:
            old_states = info["States"]
            padding = [old_states[0]] * (self.args.batch_size - len(old_states))
            info["States"] = old_states + padding

        return info
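
The agent above uses the pre-0.4 Variable API; a minimal sketch of the same double-Q update in current PyTorch, with toy networks and a random batch standing in for the replay sample: the online network selects the next action, the target network evaluates it.

import torch
import torch.nn as nn
from torch.optim import RMSprop

obs_dim, n_actions, gamma = 8, 4, 0.99
dqn = nn.Linear(obs_dim, n_actions)          # toy online network
target_dqn = nn.Linear(obs_dim, n_actions)   # toy target network
target_dqn.load_state_dict(dqn.state_dict())
optimizer = RMSprop(dqn.parameters(), lr=1e-4)

states = torch.randn(32, obs_dim)
actions = torch.randint(0, n_actions, (32, 1))
rewards = torch.randn(32).clamp(-1, 1)       # clipped rewards, as in the agent above
next_states = torch.randn(32, obs_dim)
terminals = torch.randint(0, 2, (32,)).float()

with torch.no_grad():
    # Double Q-learning: online net picks the action, target net evaluates it
    next_actions = dqn(next_states).argmax(dim=1, keepdim=True)
    next_q = target_dqn(next_states).gather(1, next_actions).squeeze(1)
    targets = rewards + gamma * (1.0 - terminals) * next_q

q_taken = dqn(states).gather(1, actions).squeeze(1)
td_error = q_taken - targets
loss = td_error.pow(2).mean()

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(dqn.parameters(), 10.0)
optimizer.step()
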
Ejemplo n.º 19
0
def main(args):
    # TODO checks
    if args.pretrained and args.dataset != 'ImageNet':
        raise NotImplementedError('Pretrained only implemented for ImageNet')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if device.type == 'cpu' and args.verbose >= 1:
        print("[WARNING] No GPU available")
    # fix random seed
    if args.seed:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
    SAMPLING = args.optimizer in MCMC_OPTIMIZERS  # True: MCMC, False: DNN
    # create stats df
    df_metrics = pd.DataFrame(columns=['model_id', 'epoch', 'loss', 'acc_test', 'time'])
    # load data
    if args.dataset == 'MNIST':
        data = MNIST(batch_size=args.batch_size, num_workers=args.num_workers)
    elif args.dataset == 'CIFAR10':
        data = CIFAR10(batch_size=args.batch_size, num_workers=args.num_workers)
    elif args.dataset == 'CIFAR100':
        data = CIFAR100(batch_size=args.batch_size, num_workers=args.num_workers)
    elif args.dataset == 'ImageNet':
        data = ImageNet(batch_size=args.batch_size, num_workers=args.num_workers)
    else:
        raise NotImplementedError('Not supported dataset')
    # compute variables
    if SAMPLING:
        # convert burnin and sampling interval in number of epochs to numbers of mini-batch if specified
        if args.burnin_epochs:
            args.burnin = args.burnin_epochs * len(data.trainloader)
        if args.sampling_interval_epochs:
            args.sampling_interval = args.sampling_interval_epochs * len(data.trainloader)
        # compute number of epochs needed
        num_epochs = math.ceil((args.burnin + args.samples * args.sampling_interval) / len(data.trainloader))
        num_restarts = 1
        if args.verbose >= 2:
            print(f'Training for {num_epochs} epochs...')
    else:
        num_epochs = args.epochs
        num_restarts = args.models
    for i_model in range(num_restarts):
        if num_restarts > 1 and args.verbose >= 1:
            print(f'Training model #{i_model}...')
        # load model
        if args.dataset == 'MNIST' and args.architecture == 'CNN':
            model_class = MnistCnn
        elif args.dataset == 'MNIST' and args.architecture == 'FC':
            model_class = MnistFc
        elif args.dataset in ['CIFAR10', 'CIFAR100'] and args.architecture == 'LeNet':
            model_class = CifarLeNet
        elif args.dataset in ['CIFAR10', 'CIFAR100', 'ImageNet'] and args.architecture == 'resnet50':
            model_class = resnet50
        elif args.dataset in ['CIFAR10'] and args.architecture in PEMODELS_NAMES:
            model_class = getattr(pemodels, args.architecture)
        else:
            raise NotImplementedError('Unsupported architecture for this dataset.')
        if args.architecture in PEMODELS_NAMES:
            model = model_class.base(num_classes=data.num_classes, **model_class.kwargs)
        else:
            model = model_class(num_classes=data.num_classes, pretrained=args.pretrained)
        model.train()
        model.to(device)
        criterion = torch.nn.CrossEntropyLoss(reduction='sum')
        if USE_CUDA:
            criterion.cuda()
        # load optimizer
        weight_decay = 1 / (args.prior_sigma ** 2)
        if args.optimizer == 'SGD':
            optimizer = SGD(params=model.parameters(), lr=args.lr, weight_decay=weight_decay, momentum=0.9)
        elif args.optimizer == 'RMSprop':
            optimizer = RMSprop(params=model.parameters(), lr=args.lr, weight_decay=weight_decay, alpha=0.99)
        elif args.optimizer == 'Adam':
            optimizer = Adam(params=model.parameters(), lr=args.lr, weight_decay=weight_decay)
        elif args.optimizer == 'SGLD':
            optimizer = SGLD(params=model.parameters(), lr=args.lr, prior_sigma=args.prior_sigma)
        elif args.optimizer == 'pSGLD':
            optimizer = pSGLD(params=model.parameters(), lr=args.lr, prior_sigma=args.prior_sigma, alpha=0.99)
        else:
            raise NotImplementedError('Not supported optimizer')
        # set lr decay
        if args.lr_decay_on_plateau:
            scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=args.lr_decay_gamma, verbose=True)
        elif args.lr_decay:
            scheduler = StepLR(optimizer, step_size=args.lr_decay, gamma=args.lr_decay_gamma)
        # print stats if pretrained
        if args.pretrained:
            acc_test = data.compute_accuracy(model=model, train=False)
            if args.verbose >= 2:
                print(
                    f"Epoch 0/{num_epochs}, Loss: None, Accuracy: {acc_test * 100:.3f}, "
                    f"Time: 0 min")
            df_metrics = df_metrics.append(
                {'model_id': i_model, 'epoch': -1, 'loss': None, 'acc_test': acc_test,
                 'time': 0},
                ignore_index=True)
        i = 0  # index current mini-batch
        i_sample = 0  # index sample
        # time code
        if USE_CUDA:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        for epoch in range(num_epochs):
            running_loss = 0.0
            for j, (inputs, labels) in enumerate(data.trainloader):
                model.train()
                inputs, labels = inputs.to(device), labels.to(device)

                # forward + backward + optimize
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                if list(model.parameters())[0].grad is None:
                    print(f'! mini-batch {j} none grad')
                if (list(model.parameters())[0].grad == 0).all():
                    print(f'! mini-batch {j} grad are all 0')

                # verbose
                if args.verbose >= 3 and (j % 1000) == 0:
                    print(f'[ {j} mini-batch / {len(data.trainloader)} ] partial loss: {loss.item():.3f}')

                # save sample
                model.eval()
                if SAMPLING:
                    if i >= args.burnin and (i - args.burnin) % args.sampling_interval == 0:
                        path_model, path_metrics = args2paths(args=args, index_model=i_sample)
                        torch.save(model.state_dict(), path_model)
                        i_sample += 1
                    if i_sample >= args.samples:
                        break  # stop current epochs if all samples are collected
                i += 1

            # print statistics
            running_loss /= len(data.trainloader)
            acc_train = data.compute_accuracy(model=model, train=True)
            acc_test = data.compute_accuracy(model=model, train=False)
            # time code
            if USE_CUDA:
                torch.cuda.synchronize()
            end_time = time.perf_counter()
            if args.verbose >= 2 or (epoch + 1 == num_epochs):
                print(
                    f"Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss:.3f}, Accuracy train: {acc_train * 100:.3f}, Accuracy test: {acc_test * 100:.3f}, "
                    f"Time: {(end_time - start_time) / 60:.3f} min")
            df_metrics = df_metrics.append(
                {'model_id': i_model, 'epoch': epoch, 'loss': running_loss, 'acc_test': acc_test,
                 'time': end_time - start_time},
                ignore_index=True)

            # decay LR if set
            if args.lr_decay_on_plateau:
                scheduler.step(acc_test)
            elif args.lr_decay:
                scheduler.step()
        # save final model and metrics
        if not SAMPLING:
            if num_restarts == 1:
                path_model, path_metrics = args2paths(args=args, index_model=None)
            else:
                path_model, path_metrics = args2paths(args=args, index_model=i_model)
            torch.save(model.state_dict(), path_model)
    df_metrics.to_csv(path_metrics)
    # compute ensemble accuracy
    if SAMPLING or (num_restarts > 1):
        models_dir = os.path.split(path_model)[0]
        list_models = load_list_models(models_dir=models_dir, class_model=model_class, device=device)
        model_ens = TorchEnsemble(models=list_models, ensemble_logits=False)
        model_ens.to(device)
        acc_ens = data.compute_accuracy(model=model_ens, train=False)
        del model_ens
        if USE_CUDA:
            torch.cuda.empty_cache()
        model_ens2 = TorchEnsemble(models=list_models, ensemble_logits=True)
        model_ens2.to(device)
        acc_ens2 = data.compute_accuracy(model=model_ens2, train=False)
        if args.verbose >= 1:
            print(f'Accuracy ensemble probs : {acc_ens * 100:.3f} \n'
                  f'Accuracy ensemble logits: {acc_ens2 * 100:.3f}')
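
The weight_decay passed to the optimizers above is derived from a Gaussian prior on the weights (weight_decay = 1 / prior_sigma**2, i.e. the L2 strength implied by a zero-mean prior). A minimal sketch with placeholder values:

import torch.nn as nn
from torch.optim import RMSprop

prior_sigma = 0.1                        # placeholder prior standard deviation
weight_decay = 1 / (prior_sigma ** 2)    # L2 penalty implied by a N(0, prior_sigma^2) prior

model = nn.Linear(10, 2)                 # placeholder architecture
optimizer = RMSprop(params=model.parameters(), lr=1e-3,
                    weight_decay=weight_decay, alpha=0.99)
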
Ejemplo n.º 20
0
class QMixTorchPolicy(Policy):
    """QMix impl. Assumes homogeneous agents for now.

    You must use MultiAgentEnv.with_agent_groups() to group agents
    together for QMix. This creates the proper Tuple obs/action spaces and
    populates the '_group_rewards' info field.

    Action masking: to specify an action mask for individual agents, use a
    dict space with an action_mask key, e.g. {"obs": ob, "action_mask": mask}.
    The mask space must be `Box(0, 1, (n_actions,))`.
    """
    def __init__(self, obs_space, action_space, config):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.framework = "torch"
        super().__init__(obs_space, action_space, config)
        self.n_agents = len(obs_space.original_space.spaces)
        config["model"]["n_agents"] = self.n_agents
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]
        self.has_env_global_state = False
        self.has_action_mask = False
        self.device = (torch.device("cuda")
                       if torch.cuda.is_available() else torch.device("cpu"))

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if "obs" not in space_keys:
                raise ValueError(
                    "Dict obs space must have subspace labeled `obs`")
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            if "action_mask" in space_keys:
                mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
                if mask_shape != (self.n_actions, ):
                    raise ValueError(
                        "Action mask shape must be {}, got {}".format(
                            (self.n_actions, ), mask_shape))
                self.has_action_mask = True
            if ENV_STATE in space_keys:
                self.env_global_state_shape = _get_size(
                    agent_obs_space.spaces[ENV_STATE])
                self.has_env_global_state = True
            else:
                self.env_global_state_shape = (self.obs_size, self.n_agents)
            # The real agent obs space is nested inside the dict
            config["model"]["full_obs_space"] = agent_obs_space
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.obs_size = _get_size(agent_obs_space)
            self.env_global_state_shape = (self.obs_size, self.n_agents)

        self.model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="model",
            default_model=RNNModel,
        ).to(self.device)

        self.target_model = ModelCatalog.get_model_v2(
            agent_obs_space,
            action_space.spaces[0],
            self.n_actions,
            config["model"],
            framework="torch",
            name="target_model",
            default_model=RNNModel,
        ).to(self.device)

        self.exploration = self._create_exploration()

        # Setup the mixer network.
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.env_global_state_shape,
                                config["mixing_embed_dim"]).to(self.device)
            self.target_mixer = QMixer(
                self.n_agents, self.env_global_state_shape,
                config["mixing_embed_dim"]).to(self.device)
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer().to(self.device)
            self.target_mixer = VDNMixer().to(self.device)
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        if self.mixer:
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(
            self.model,
            self.target_model,
            self.mixer,
            self.target_mixer,
            self.n_agents,
            self.n_actions,
            self.config["double_q"],
            self.config["gamma"],
        )
        from torch.optim import RMSprop

        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"],
        )

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):
        explore = explore if explore is not None else self.config["explore"]
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # We need to ensure we do not use the env global state
        # to compute actions

        # Compute actions
        with torch.no_grad():
            q_values, hiddens = _mac(
                self.model,
                torch.as_tensor(obs_batch,
                                dtype=torch.float,
                                device=self.device),
                [
                    torch.as_tensor(
                        np.array(s), dtype=torch.float, device=self.device)
                    for s in state_batches
                ],
            )
            avail = torch.as_tensor(action_mask,
                                    dtype=torch.float,
                                    device=self.device)
            masked_q_values = q_values.clone()
            masked_q_values[avail == 0.0] = -float("inf")
            masked_q_values_folded = torch.reshape(
                masked_q_values, [-1] + list(masked_q_values.shape)[2:])
            actions, _ = self.exploration.get_exploration_action(
                action_distribution=TorchCategorical(masked_q_values_folded),
                timestep=timestep,
                explore=explore,
            )
            actions = (torch.reshape(
                actions,
                list(masked_q_values.shape)[:-1]).cpu().numpy())
            hiddens = [s.cpu().numpy() for s in hiddens]

        return tuple(actions.transpose([1, 0])), hiddens, {}

    @override(Policy)
    def compute_log_likelihoods(
        self,
        actions,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
    ):
        obs_batch, action_mask, _ = self._unpack_observation(obs_batch)
        # obs_batch is a numpy array after unpacking, so use its shape
        return np.zeros(obs_batch.shape[0])

    @override(Policy)
    def learn_on_batch(self, samples):
        obs_batch, action_mask, env_global_state = self._unpack_observation(
            samples[SampleBatch.CUR_OBS])
        (
            next_obs_batch,
            next_action_mask,
            next_env_global_state,
        ) = self._unpack_observation(samples[SampleBatch.NEXT_OBS])
        group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

        input_list = [
            group_rewards,
            action_mask,
            next_action_mask,
            samples[SampleBatch.ACTIONS],
            samples[SampleBatch.DONES],
            obs_batch,
            next_obs_batch,
        ]
        if self.has_env_global_state:
            input_list.extend([env_global_state, next_env_global_state])

        output_list, _, seq_lens = chop_into_sequences(
            episode_ids=samples[SampleBatch.EPS_ID],
            unroll_ids=samples[SampleBatch.UNROLL_ID],
            agent_indices=samples[SampleBatch.AGENT_INDEX],
            feature_columns=input_list,
            state_columns=[],  # RNN states not used here
            max_seq_len=self.config["model"]["max_seq_len"],
            dynamic_max=True,
        )
        # These will be padded to shape [B * T, ...]
        if self.has_env_global_state:
            (
                rew,
                action_mask,
                next_action_mask,
                act,
                dones,
                obs,
                next_obs,
                env_global_state,
                next_env_global_state,
            ) = output_list
        else:
            (
                rew,
                action_mask,
                next_action_mask,
                act,
                dones,
                obs,
                next_obs,
            ) = output_list
        B, T = len(seq_lens), max(seq_lens)

        def to_batches(arr, dtype):
            new_shape = [B, T] + list(arr.shape[1:])
            return torch.as_tensor(np.reshape(arr, new_shape),
                                   dtype=dtype,
                                   device=self.device)

        rewards = to_batches(rew, torch.float)
        actions = to_batches(act, torch.long)
        obs = to_batches(obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        action_mask = to_batches(action_mask, torch.float)
        next_obs = to_batches(next_obs, torch.float).reshape(
            [B, T, self.n_agents, self.obs_size])
        next_action_mask = to_batches(next_action_mask, torch.float)
        if self.has_env_global_state:
            env_global_state = to_batches(env_global_state, torch.float)
            next_env_global_state = to_batches(next_env_global_state,
                                               torch.float)

        # TODO(ekl) this treats group termination as individual termination
        terminated = (to_batches(dones, torch.float).unsqueeze(2).expand(
            B, T, self.n_agents))

        # Create mask for where index is < unpadded sequence length
        filled = np.reshape(np.tile(np.arange(T, dtype=np.float32), B),
                            [B, T]) < np.expand_dims(seq_lens, 1)
        mask = (torch.as_tensor(filled, dtype=torch.float,
                                device=self.device).unsqueeze(2).expand(
                                    B, T, self.n_agents))

        # Compute loss
        loss_out, mask, masked_td_error, chosen_action_qvals, targets = self.loss(
            rewards,
            actions,
            terminated,
            mask,
            obs,
            next_obs,
            action_mask,
            next_action_mask,
            env_global_state,
            next_env_global_state,
        )

        # Optimise
        self.optimiser.zero_grad()
        loss_out.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.params, self.config["grad_norm_clipping"])
        self.optimiser.step()

        mask_elems = mask.sum().item()
        stats = {
            "loss":
            loss_out.item(),
            "grad_norm":
            grad_norm if isinstance(grad_norm, float) else grad_norm.item(),
            "td_error_abs":
            masked_td_error.abs().sum().item() / mask_elems,
            "q_taken_mean":
            (chosen_action_qvals * mask).sum().item() / mask_elems,
            "target_mean": (targets * mask).sum().item() / mask_elems,
        }
        return {LEARNER_STATS_KEY: stats}

    @override(Policy)
    def get_initial_state(self):  # initial RNN state
        return [
            s.expand([self.n_agents, -1]).cpu().numpy()
            for s in self.model.get_initial_state()
        ]

    @override(Policy)
    def get_weights(self):
        return {
            "model":
            self._cpu_dict(self.model.state_dict()),
            "target_model":
            self._cpu_dict(self.target_model.state_dict()),
            "mixer":
            self._cpu_dict(self.mixer.state_dict()) if self.mixer else None,
            "target_mixer":
            self._cpu_dict(self.target_mixer.state_dict())
            if self.mixer else None,
        }

    @override(Policy)
    def set_weights(self, weights):
        self.model.load_state_dict(self._device_dict(weights["model"]))
        self.target_model.load_state_dict(
            self._device_dict(weights["target_model"]))
        if weights["mixer"] is not None:
            self.mixer.load_state_dict(self._device_dict(weights["mixer"]))
            self.target_mixer.load_state_dict(
                self._device_dict(weights["target_mixer"]))

    @override(Policy)
    def get_state(self):
        state = self.get_weights()
        state["cur_epsilon"] = self.cur_epsilon
        return state

    @override(Policy)
    def set_state(self, state):
        self.set_weights(state)
        self.set_epsilon(state["cur_epsilon"])

    def update_target(self):
        self.target_model.load_state_dict(self.model.state_dict())
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        logger.debug("Updated target networks")

    def set_epsilon(self, epsilon):
        self.cur_epsilon = epsilon

    def _get_group_rewards(self, info_batch):
        group_rewards = np.array([
            info.get(GROUP_REWARDS, [0.0] * self.n_agents)
            for info in info_batch
        ])
        return group_rewards

    def _device_dict(self, state_dict):
        return {
            k: torch.as_tensor(v, device=self.device)
            for k, v in state_dict.items()
        }

    @staticmethod
    def _cpu_dict(state_dict):
        return {k: v.cpu().detach().numpy() for k, v in state_dict.items()}

    def _unpack_observation(self, obs_batch):
        """Unpacks the observation, action mask, and state (if present)
        from agent grouping.

        Returns:
            obs (np.ndarray): obs tensor of shape [B, n_agents, obs_size]
            mask (np.ndarray): action mask, if any
            state (np.ndarray or None): state tensor of shape [B, state_size]
                or None if it is not in the batch
        """

        unpacked = _unpack_obs(
            np.array(obs_batch, dtype=np.float32),
            self.observation_space.original_space,
            tensorlib=np,
        )

        if isinstance(unpacked[0], dict):
            assert "obs" in unpacked[0]
            unpacked_obs = [
                np.concatenate(tree.flatten(u["obs"]), 1) for u in unpacked
            ]
        else:
            unpacked_obs = unpacked

        obs = np.concatenate(unpacked_obs, axis=1).reshape(
            [len(obs_batch), self.n_agents, self.obs_size])

        if self.has_action_mask:
            action_mask = np.concatenate([o["action_mask"] for o in unpacked],
                                         axis=1).reshape([
                                             len(obs_batch), self.n_agents,
                                             self.n_actions
                                         ])
        else:
            action_mask = np.ones(
                [len(obs_batch), self.n_agents, self.n_actions],
                dtype=np.float32)

        if self.has_env_global_state:
            state = np.concatenate(tree.flatten(unpacked[0][ENV_STATE]), 1)
        else:
            state = None
        return obs, action_mask, state
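
A minimal sketch of the optimizer wiring used above, with toy modules standing in for the per-agent RNN model and the mixing network, and illustrative values in place of config["lr"], optim_alpha and optim_eps: one RMSprop instance over the concatenated parameter lists, with gradient-norm clipping before each step.

import torch
import torch.nn as nn
from torch.optim import RMSprop

model = nn.GRUCell(16, 32)    # stand-in for the per-agent RNN model
mixer = nn.Linear(4, 1)       # stand-in for the mixing network

params = list(model.parameters()) + list(mixer.parameters())
optimiser = RMSprop(params=params, lr=5e-4, alpha=0.99, eps=1e-5)

loss = sum(p.pow(2).sum() for p in params)   # dummy loss, just to drive one update
optimiser.zero_grad()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(params, 10.0)
optimiser.step()
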
Ejemplo n.º 21
0
                                                      shuffle=False, num_workers=1)
    else:
        dataloader = torch.utils.data.DataLoader(CELEBA_SLURM(train_folder), batch_size=64,
                                                 shuffle=True, num_workers=1)
        # DATASET for test
        # if you want to split train from test just move some files in another dir
        dataloader_test = torch.utils.data.DataLoader(CELEBA_SLURM(test_folder), batch_size=100,
                                                      shuffle=False, num_workers=1)
    # margin and equilibrium
    margin = 0.35
    equilibrium = 0.68
    #mse_lambda = 1.0
    # OPTIM-LOSS
    # an optimizer for each of the sub-networks, so we can selectively backprop
    #optimizer_encoder = Adam(params=net.encoder.parameters(),lr = lr,betas=(0.9,0.999))
    optimizer_encoder = RMSprop(params=net.encoder.parameters(), lr=lr, alpha=0.9,
                                eps=1e-8, weight_decay=0, momentum=0, centered=False)
    #lr_encoder = MultiStepLR(optimizer_encoder,milestones=[2],gamma=1)
    lr_encoder = ExponentialLR(optimizer_encoder, gamma=decay_lr)
    #optimizer_decoder = Adam(params=net.decoder.parameters(),lr = lr,betas=(0.9,0.999))
    optimizer_decoder = RMSprop(params=net.decoder.parameters(), lr=lr, alpha=0.9,
                                eps=1e-8, weight_decay=0, momentum=0, centered=False)
    lr_decoder = ExponentialLR(optimizer_decoder, gamma=decay_lr)
    #lr_decoder = MultiStepLR(optimizer_decoder,milestones=[2],gamma=1)
    #optimizer_discriminator = Adam(params=net.discriminator.parameters(),lr = lr,betas=(0.9,0.999))
    optimizer_discriminator = RMSprop(params=net.discriminator.parameters(), lr=lr, alpha=0.9,
                                      eps=1e-8, weight_decay=0, momentum=0, centered=False)
    lr_discriminator = ExponentialLR(optimizer_discriminator, gamma=decay_lr)
    #lr_discriminator = MultiStepLR(optimizer_discriminator,milestones=[2],gamma=1)

    batch_number = len(dataloader)
    step_index = 0
    widgets = [
Ejemplo n.º 23
0
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.mac_params = list(mac.parameters())
        self.params = list(self.mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        assert args.mixer is not None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = QMixer(args)
            elif args.mixer == "qmix_cnn":
                self.mixer = QMixer_CNN(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.mixer_params = list(self.mixer.parameters())
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        # Central Q
        # TODO: Clean this mess up!
        self.central_mac = None
        if self.args.central_mixer in ["ff", "atten"]:
            if self.args.central_loss == 0:
                self.central_mixer = self.mixer
                self.central_mac = self.mac
                self.target_central_mac = self.target_mac
            else:
                if self.args.central_mixer == "ff":
                    self.central_mixer = QMixerCentralFF(
                        args
                    )  # Feedforward network that takes state and agent utils as input
                elif self.args.central_mixer == "atten":
                    self.central_mixer = QMixerCentralAtten(args)
                else:
                    raise Exception("Error with central_mixer")

                assert args.central_mac == "basic_central_mac"
                self.central_mac = mac_REGISTRY[args.central_mac](
                    scheme, args
                )  # Groups aren't used in the CentralBasicController. Little hacky
                self.target_central_mac = copy.deepcopy(self.central_mac)
                self.params += list(self.central_mac.parameters())
        else:
            raise Exception("Error with qCentral")
        self.params += list(self.central_mixer.parameters())
        self.target_central_mixer = copy.deepcopy(self.central_mixer)

        self.optimiser = RMSprop(params=self.params,
                                 lr=args.lr,
                                 alpha=args.optim_alpha,
                                 eps=args.optim_eps)

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.grad_norm = 1
        self.mixer_norm = 1
        self.mixer_norms = deque([1], maxlen=100)
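
The learner above keeps mixer_params separate from the joint params list and initialises grad_norm, mixer_norm and a rolling mixer_norms deque. A hedged sketch of how such per-module gradient norms can be recorded after a backward pass (the function and its arguments are illustrative, not taken from the repository):

from collections import deque

import torch

def record_grad_norms(params, mixer_params, mixer_norms: deque, clip):
    # Clip the gradient norm of the full parameter set, then measure the norm
    # of the gradients flowing through the mixer parameters only.
    grad_norm = torch.nn.utils.clip_grad_norm_(params, clip)
    mixer_grads = [p.grad.detach().norm(2) for p in mixer_params if p.grad is not None]
    mixer_norm = torch.norm(torch.stack(mixer_grads), 2).item() if mixer_grads else 0.0
    mixer_norms.append(mixer_norm)
    return grad_norm, mixer_norm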
Ejemplo n.º 24
0
def main():

    # load data
    data_path = '../data/output/records_final.pkl'
    voc_path = '../data/output/voc_final.pkl'

    ddi_adj_path = '../data/output/ddi_A_final.pkl'
    device = torch.device('cuda:{}'.format(args.cuda))

    ddi_adj = dill.load(open(ddi_adj_path, 'rb'))
    data = dill.load(open(data_path, 'rb'))

    voc = dill.load(open(voc_path, 'rb'))
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc[
        'med_voc']

    np.random.seed(1203)
    np.random.shuffle(data)

    split_point = int(len(data) * 3 / 5)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point + eval_len:]

    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word),
                len(med_voc.idx2word))

    model = MICRON(voc_size, ddi_adj, emb_dim=args.dim, device=device)
    # model.load_state_dict(torch.load(open(args.resume_path, 'rb')))

    if args.Test:
        model.load_state_dict(torch.load(open(args.resume_path, 'rb')))
        model.to(device=device)
        tic = time.time()
        label_list, prob_list = eval(model, data_eval, voc_size, 0, 1)

        threshold1, threshold2 = [], []
        for i in range(label_list.shape[1]):
            _, _, boundary = roc_curve(label_list[:, i],
                                       prob_list[:, i],
                                       pos_label=1)
            # boundary1 should be in [0.5, 0.9], boundary2 should be in [0.1, 0.5]
            threshold1.append(
                min(
                    0.9,
                    max(0.5, boundary[max(0,
                                          round(len(boundary) * 0.05) - 1)])))
            threshold2.append(
                max(
                    0.1,
                    min(
                        0.5, boundary[min(round(len(boundary) * 0.95),
                                          len(boundary) - 1)])))
        print(np.mean(threshold1), np.mean(threshold2))
        threshold1 = np.ones(voc_size[2]) * np.mean(threshold1)
        threshold2 = np.ones(voc_size[2]) * np.mean(threshold2)
        eval(model, data_test, voc_size, 0, 0, threshold1, threshold2)
        print('test time: {}'.format(time.time() - tic))

        return

    model.to(device=device)
    print('parameters', get_n_params(model))
    # exit()
    optimizer = RMSprop(list(model.parameters()),
                        lr=args.lr,
                        weight_decay=args.weight_decay)

    # start iterations
    history = defaultdict(list)
    best_epoch, best_ja = 0, 0

    weight_list = [[0.25, 0.25, 0.25, 0.25]]

    EPOCH = 40
    for epoch in range(EPOCH):
        t = 0
        tic = time.time()
        print('\nepoch {} --------------------------'.format(epoch + 1))

        sample_counter = 0
        mean_loss = np.array([0, 0, 0, 0])

        model.train()
        for step, input in enumerate(data_train):
            loss = 0
            if len(input) < 2: continue
            for adm_idx, adm in enumerate(input):
                if adm_idx == 0: continue
                sample_counter += 1
                seq_input = input[:adm_idx + 1]

                loss_bce_target = np.zeros((1, voc_size[2]))
                loss_bce_target[:, adm[2]] = 1

                loss_bce_target_last = np.zeros((1, voc_size[2]))
                loss_bce_target_last[:, input[adm_idx - 1][2]] = 1

                loss_multi_target = np.full((1, voc_size[2]), -1)
                for idx, item in enumerate(adm[2]):
                    loss_multi_target[0][idx] = item

                loss_multi_target_last = np.full((1, voc_size[2]), -1)
                for idx, item in enumerate(input[adm_idx - 1][2]):
                    loss_multi_target_last[0][idx] = item

                result, result_last, _, loss_ddi, loss_rec = model(seq_input)

                loss_bce = 0.75 * F.binary_cross_entropy_with_logits(result, torch.FloatTensor(loss_bce_target).to(device)) + \
                    (1 - 0.75) * F.binary_cross_entropy_with_logits(result_last, torch.FloatTensor(loss_bce_target_last).to(device))
                loss_multi = 5e-2 * (0.75 * F.multilabel_margin_loss(F.sigmoid(result), torch.LongTensor(loss_multi_target).to(device)) + \
                    (1 - 0.75) * F.multilabel_margin_loss(F.sigmoid(result_last), torch.LongTensor(loss_multi_target_last).to(device)))

                y_pred_tmp = F.sigmoid(result).detach().cpu().numpy()[0]
                y_pred_tmp[y_pred_tmp >= 0.5] = 1
                y_pred_tmp[y_pred_tmp < 0.5] = 0
                y_label = np.where(y_pred_tmp == 1)[0]
                current_ddi_rate = ddi_rate_score(
                    [[y_label]], path='../data/output/ddi_A_final.pkl')

                # l2 = 0
                # for p in model.parameters():
                #     l2 = l2 + (p ** 2).sum()

                if sample_counter == 1:  # first sample: keep the previous weights
                    lambda1, lambda2, lambda3, lambda4 = weight_list[-1]
                else:
                    current_loss = np.array([
                        loss_bce.detach().cpu().numpy(),
                        loss_multi.detach().cpu().numpy(),
                        loss_ddi.detach().cpu().numpy(),
                        loss_rec.detach().cpu().numpy()
                    ])
                    current_ratio = (current_loss -
                                     np.array(mean_loss)) / np.array(mean_loss)
                    instant_weight = np.exp(current_ratio) / sum(
                        np.exp(current_ratio))
                    lambda1, lambda2, lambda3, lambda4 = instant_weight * 0.75 + np.array(
                        weight_list[-1]) * 0.25
                    # update weight_list
                    weight_list.append([lambda1, lambda2, lambda3, lambda4])
                # update mean_loss
                mean_loss = (mean_loss * (sample_counter - 1) + np.array([loss_bce.detach().cpu().numpy(), \
                    loss_multi.detach().cpu().numpy(), loss_ddi.detach().cpu().numpy(), loss_rec.detach().cpu().numpy()])) / sample_counter
                # lambda1, lambda2, lambda3, lambda4 = weight_list[-1]
                if current_ddi_rate > 0.08:
                    loss += lambda1 * loss_bce + lambda2 * loss_multi + \
                                 lambda3 * loss_ddi +  lambda4 * loss_rec
                else:
                    loss += lambda1 * loss_bce + lambda2 * loss_multi + \
                                lambda4 * loss_rec

            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

            llprint('\rtraining step: {} / {}'.format(step, len(data_train)))

        tic2 = time.time()
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, add, delete, avg_med = eval(
            model, data_eval, voc_size, epoch)
        print('training time: {}, test time: {}'.format(
            time.time() - tic,
            time.time() - tic2))

        history['ja'].append(ja)
        history['ddi_rate'].append(ddi_rate)
        history['avg_p'].append(avg_p)
        history['avg_r'].append(avg_r)
        history['avg_f1'].append(avg_f1)
        history['prauc'].append(prauc)
        history['add'].append(add)
        history['delete'].append(delete)
        history['med'].append(avg_med)

        if epoch >= 5:
            print(
                'ddi: {}, Med: {}, Ja: {}, F1: {}, Add: {}, Delete: {}'.format(
                    np.mean(history['ddi_rate'][-5:]),
                    np.mean(history['med'][-5:]), np.mean(history['ja'][-5:]),
                    np.mean(history['avg_f1'][-5:]),
                    np.mean(history['add'][-5:]),
                    np.mean(history['delete'][-5:])))

        torch.save(model.state_dict(), open(os.path.join('saved', args.model_name, \
            'Epoch_{}_JA_{:.4}_DDI_{:.4}.model'.format(epoch, ja, ddi_rate)), 'wb'))

        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja

        print('best_epoch: {}'.format(best_epoch))

    dill.dump(
        history,
        open(
            os.path.join('saved', args.model_name,
                         'history_{}.pkl'.format(args.model_name)), 'wb'))
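
The training loop above adapts the four loss weights on the fly: each sample's relative loss change (with respect to the running mean loss) is softmax-normalised and blended with the previous weights. A compact restatement of that rule, assuming current_loss, mean_loss and prev_weights are length-4 arrays as in the loop:

import numpy as np

def update_loss_weights(current_loss, mean_loss, prev_weights, blend=0.75):
    # Relative change of each loss term, softmax-normalised, then mixed with
    # the previous weights (0.75 new / 0.25 old, matching the loop above).
    ratio = (np.asarray(current_loss) - np.asarray(mean_loss)) / np.asarray(mean_loss)
    instant = np.exp(ratio) / np.exp(ratio).sum()
    return blend * instant + (1 - blend) * np.asarray(prev_weights)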
Ejemplo n.º 25
0
    def __init__(self, cfg):
        super().__init__(cfg)
        self.cfg = cfg
        self.device = get_device()

        # define models
        # inception_v3 input size = ( N x 3 x 299 x 299 )
        self.model = inception_v3(pretrained=True, num_classes=cfg.num_classes)

        # define data_loader
        self.dataset = inception_data(cfg).get_dataset()
        tr_size = int(cfg.train_test_ratio * len(self.dataset))
        te_size = len(self.dataset) - tr_size
        tr_dataset, te_dataset = random_split(self.dataset, [tr_size, te_size])
        self.tr_loader = DataLoader(tr_dataset,
                                    batch_size=cfg.bs,
                                    shuffle=cfg.data_shuffle,
                                    num_workers=cfg.num_workers)
        self.te_loader = DataLoader(te_dataset,
                                    batch_size=cfg.bs,
                                    shuffle=cfg.data_shuffle,
                                    num_workers=cfg.num_workers)

        # define loss
        self.loss = torch.tensor(0)
        self.criterion = CrossEntropyLoss()

        # define optimizer
        self.optimizer = RMSprop(self.model.parameters(), lr=cfg.lr)

        # initialize counter
        self.current_epoch = 0
        self.current_iteration = 0
        self.best_metric = 0
        self.best_info = ""

        # set cuda flag
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda and not self.cfg.cuda:
            self.logger.info(
                "WARNING: You have a CUDA device, so you should probably enable CUDA"
            )

        self.cuda = self.is_cuda & self.cfg.cuda

        # set the manual seed for torch
        self.manual_seed = self.cfg.seed
        if self.cuda:
            torch.cuda.manual_seed(self.manual_seed)
            self.model = self.model.to(self.device)
            if self.cfg.data_parallel:
                self.model = nn.DataParallel(self.model)
            self.logger.info("Program will run on *****GPU-CUDA***** ")
        else:
            self.model = self.model.to(self.device)
            torch.manual_seed(self.manual_seed)
            self.logger.info("Program will run on *****CPU*****\n")

        # Model Loading from cfg if not found start from scratch.
        self.exp_dir = os.path.join('./experiments', cfg.exp_name)
        self.load_checkpoint(self.cfg.checkpoint_filename)
        # Summary Writer
        self.summary_writer = SummaryWriter(
            log_dir=os.path.join(self.exp_dir, 'summaries'))
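
The agent above only defines the model, data loaders, loss and RMSprop optimizer; its training loop is not part of this snippet. A hedged sketch of one epoch, assuming the loader yields (image, label) batches, and keeping in mind that torchvision's inception_v3 returns a (logits, aux_logits) tuple while in training mode:

import torch

def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        # inception_v3 yields (logits, aux_logits) in train mode; keep the main logits
        logits = outputs[0] if isinstance(outputs, tuple) else outputs
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / max(1, len(loader))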
Ejemplo n.º 26
0
class A2C(Agent):
    """
    An agent trained with Advantage Actor-Critic (A2C)
    - the actor takes the state as input
    - the critic takes both state and action as input
    - the agent interacts with the environment to collect experience
    - the agent trains on the collected experience to update its policy
    """
    def __init__(self, env, state_dim, action_dim,
                 memory_capacity=10000, max_steps=None,
                 roll_out_n_steps=10,
                 reward_gamma=0.99, reward_scale=1., done_penalty=None,
                 actor_hidden_size=32, critic_hidden_size=32,
                 actor_output_act=nn.functional.log_softmax, critic_loss="mse",
                 actor_lr=0.001, critic_lr=0.001,
                 optimizer_type="rmsprop", entropy_reg=0.01,
                 max_grad_norm=0.5, batch_size=100, episodes_before_train=100,
                 epsilon_start=0.9, epsilon_end=0.01, epsilon_decay=200,
                 use_cuda=True):
        super(A2C, self).__init__(env, state_dim, action_dim,
                 memory_capacity, max_steps,
                 reward_gamma, reward_scale, done_penalty,
                 actor_hidden_size, critic_hidden_size,
                 actor_output_act, critic_loss,
                 actor_lr, critic_lr,
                 optimizer_type, entropy_reg,
                 max_grad_norm, batch_size, episodes_before_train,
                 epsilon_start, epsilon_end, epsilon_decay,
                 use_cuda)

        self.roll_out_n_steps = roll_out_n_steps

        self.actor = ActorNetwork(self.state_dim, self.actor_hidden_size,
                                  self.action_dim, self.actor_output_act)
        self.critic = CriticNetwork(self.state_dim, self.action_dim,
                                    self.critic_hidden_size, 1)
        if self.optimizer_type == "adam":
            self.actor_optimizer = Adam(self.actor.parameters(), lr=self.actor_lr)
            self.critic_optimizer = Adam(self.critic.parameters(), lr=self.critic_lr)
        elif self.optimizer_type == "rmsprop":
            self.actor_optimizer = RMSprop(self.actor.parameters(), lr=self.actor_lr)
            self.critic_optimizer = RMSprop(self.critic.parameters(), lr=self.critic_lr)
        if self.use_cuda:
            self.actor.cuda()
            self.critic.cuda()

    # agent interact with the environment to collect experience
    def interact(self):
        super(A2C, self)._take_n_steps()

    # train on a roll out batch
    def train(self):
        if self.n_episodes <= self.episodes_before_train:
            return  # not enough episodes collected to start training

        batch = self.memory.sample(self.batch_size)
        states_var = to_tensor_var(batch.states, self.use_cuda).view(-1, self.state_dim)
        one_hot_actions = index_to_one_hot(batch.actions, self.action_dim)
        actions_var = to_tensor_var(one_hot_actions, self.use_cuda).view(-1, self.action_dim)
        rewards_var = to_tensor_var(batch.rewards, self.use_cuda).view(-1, 1)

        # update actor network
        self.actor_optimizer.zero_grad()
        action_log_probs = self.actor(states_var)
        entropy_loss = th.mean(entropy(th.exp(action_log_probs)))
        action_log_probs = th.sum(action_log_probs * actions_var, 1)
        values = self.critic(states_var, actions_var).detach()
        advantages = rewards_var - values
        pg_loss = -th.mean(action_log_probs * advantages)
        actor_loss = pg_loss - entropy_loss * self.entropy_reg
        actor_loss.backward()
        if self.max_grad_norm is not None:
            nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
        self.actor_optimizer.step()

        # update critic network
        self.critic_optimizer.zero_grad()
        target_values = rewards_var
        values = self.critic(states_var, actions_var)
        if self.critic_loss == "huber":
            critic_loss = nn.functional.smooth_l1_loss(values, target_values)
        else:
            critic_loss = nn.MSELoss()(values, target_values)
        critic_loss.backward()
        if self.max_grad_norm is not None:
            nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)
        self.critic_optimizer.step()

    # predict softmax action based on state
    def _softmax_action(self, state):
        state_var = to_tensor_var([state], self.use_cuda)
        softmax_action_var = th.exp(self.actor(state_var))
        if self.use_cuda:
            softmax_action = softmax_action_var.data.cpu().numpy()[0]
        else:
            softmax_action = softmax_action_var.data.numpy()[0]
        return softmax_action

    # choose an action based on state, with random noise added for exploration during training
    def exploration_action(self, state):
        softmax_action = self._softmax_action(state)
        epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
                                  np.exp(-1. * self.n_steps / self.epsilon_decay)
        if np.random.rand() < epsilon:
            action = np.random.choice(self.action_dim)
        else:
            action = np.argmax(softmax_action)
        return action

    # choose an action based on state for execution
    def action(self, state):
        softmax_action = self._softmax_action(state)
        action = np.argmax(softmax_action)
        return action

    # evaluate value for a state-action pair
    def value(self, state, action):
        state_var = to_tensor_var([state], self.use_cuda)
        action = index_to_one_hot([action], self.action_dim).flatten()
        action_var = to_tensor_var([action], self.use_cuda)
        value_var = self.critic(state_var, action_var)
        if self.use_cuda:
            value = value_var.data.cpu().numpy()[0]
        else:
            value = value_var.data.numpy()[0]
        return value
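
exploration_action above anneals epsilon exponentially from epsilon_start towards epsilon_end with time constant epsilon_decay. A small standalone check of that schedule, using the defaults from the constructor signature:

import numpy as np

def epsilon_at(n_steps, eps_start=0.9, eps_end=0.01, decay=200):
    return eps_end + (eps_start - eps_end) * np.exp(-1. * n_steps / decay)

# epsilon_at(0) == 0.9, epsilon_at(200) ~= 0.34, epsilon_at(1000) ~= 0.016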
Ejemplo n.º 27
0
    dis.load_state_dict(cp_data['dis_state_dict'])
    start_epoch = cp_data['epoch'] + 1
    print("Loaded saved models")

# Tensors for gradient computation
# one = torch.FloatTensor([1])
one = torch.tensor(1, dtype=torch.float)
mone = one * -1
if opt['use_cuda']:
    one = one.cuda()
    mone = mone.cuda()

#Default optimizers
# optimizerD = Adam(dis.parameters(), lr=float(opt['lr']), betas=(0.5, 0.9))
# optimizerG = Adam(gen.parameters(), lr=float(opt['lr']), betas=(0.5, 0.9))
optimizerD = RMSprop(dis.parameters(), lr=float(opt['lr']))
optimizerG = RMSprop(gen.parameters(), lr=float(opt['lr']))

gen.train()
dis.train()

dis_loss_history = []
gen_loss_history = []
MSE_history = []
correlation_hostory = []
real_loss_history = []

for epoch in range(start_epoch, opt['epochs']):
    start_time = time.time()
    # print("| epoch {} |".format(epoch))
    for p in dis.parameters():
Ejemplo n.º 28
0
class CentralV_Learner:
    def __init__(self, mac, scheme, logger, args):
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents
        self.state_shape = args.state_shape
        self.obs_shape = args.obs_shape
        self.mac = mac
        self.logger = logger
        self.args = args

        self.critic_training_steps = 0
        self.last_target_update_step = 0

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.critic = CentralV_Critic(self.state_shape, args)
        self.target_critic = CentralV_Critic(self.state_shape, args)

        self.target_critic.load_state_dict(self.critic.state_dict())
        self.rnn_parameters = list(self.mac.parameters())
        self.critic_parameters = list(self.critic.parameters())

        self.critic_optimizer = RMSprop(self.critic_parameters,
                                        lr=args.critic_lr)
        self.agent_optimiser = RMSprop(self.rnn_parameters, lr=args.lr)

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"][:, :-1]

        mask = mask.repeat(1, 1, self.n_agents)

        td_error, critic_train_stats = self._train_critic(batch, rewards, mask)

        actions = actions[:, :-1]

        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length - 1):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = torch.stack(mac_out, dim=1)  # Concat over time

        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out[avail_actions == 0] = 0
        mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True)
        mac_out[avail_actions == 0] = 0

        pi_taken = torch.gather(mac_out, dim=3, index=actions).squeeze(3)
        pi_taken[mask == 0] = 1.0
        log_pi_taken = torch.log(pi_taken)

        centralV_loss = -(
            (td_error.detach() * log_pi_taken) * mask).sum() / mask.sum()

        # Optimise agents
        self.agent_optimiser.zero_grad()
        centralV_loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.rnn_parameters,
                                                   self.args.grad_norm_clip)
        self.agent_optimiser.step()

        if (self.critic_training_steps - self.last_target_update_step
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_step = self.critic_training_steps

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            ts_logged = len(critic_train_stats["critic_loss"])
            for key in ["critic_loss", "critic_grad_norm"]:
                self.logger.log_stat(key,
                                     sum(critic_train_stats[key]) / ts_logged,
                                     t_env)

            self.logger.log_stat("coma_loss", centralV_loss.item(), t_env)
            self.logger.log_stat("agent_grad_norm", grad_norm, t_env)
            self.log_stats_t = t_env

    def _train_critic(self, batch, rewards, mask):
        r, terminated = batch['reward'][:, :-1], batch[
            'terminated'][:, :-1].type(torch.FloatTensor)
        v_evals, v_targets = [], []
        for t in range(rewards.size(1)):  # forward in time so values align with rewards
            state = batch["state"][:, t]
            nextState = batch["state"][:, t + 1]
            val = self.critic(state)
            target_val = self.target_critic(nextState)

            v_evals.append(val)
            v_targets.append(target_val)

        v_evals = torch.stack(v_evals,
                              dim=1)  # (episode_num, max_episode_len, 1)
        v_targets = torch.stack(v_targets, dim=1)

        re_mask = mask.repeat(1, 1, self.n_agents)
        targets = r + self.args.gamma * v_targets * (1 - terminated)
        td_error = targets.detach() - v_evals
        masked_td_error = re_mask * td_error

        loss = (masked_td_error**2).sum() / mask.sum()

        self.critic_optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.critic_parameters,
                                                   self.args.grad_norm_clip)
        self.critic_optimizer.step()

        running_log = {"critic_loss": [], "critic_grad_norm": []}

        running_log["critic_loss"].append(loss.item())
        running_log["critic_grad_norm"].append(grad_norm)

        self.critic_training_steps += 1

        return td_error, running_log

    def _update_targets(self):
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.critic.cuda()
        self.target_critic.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        torch.save(self.critic.state_dict(), "{}/critic.th".format(path))
        torch.save(self.agent_optimiser.state_dict(),
                   "{}/agent_opt.th".format(path))
        torch.save(self.critic_optimizer.state_dict(),
                   "{}/critic_opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        self.critic.load_state_dict(
            torch.load("{}/critic.th".format(path),
                       map_location=lambda storage, loc: storage))
        # Not quite right but I don't want to save target networks
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.agent_optimiser.load_state_dict(
            torch.load("{}/agent_opt.th".format(path),
                       map_location=lambda storage, loc: storage))
        self.critic_optimizer.load_state_dict(
            torch.load("{}/critic_opt.th".format(path),
                       map_location=lambda storage, loc: storage))
Ejemplo n.º 29
0
def run(data_dir: str = './env/data',
        vae_dir: str = './vae/model',
        mdnrnn_dir: str = './mdnrnn/model',
        epochs: int = 20) -> None:
    """
    Train mdnrnn using saved environment rollouts.

    Parameters
    ----------
    data_dir
        Directory with train and test data.
    vae_dir
        Directory to load VAE model from.
    mdnrnn_dir
        Directory to optionally load MDNRNN model from and save trained model to.
    epochs
        Number of training epochs.
    """
    # set random seed and deterministic backend
    SEED = 123
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

    # use GPU if available
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    # define input transformations
    transform_train = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((H, W)),
        transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((H, W)),
        transforms.ToTensor(),
    ])

    # define train and test datasets
    dir_train = os.path.join(data_dir, 'train/')
    dir_test = os.path.join(data_dir, 'test/')
    dataset_train = GymDataset(dir_train,
                               seq_len=SEQ_LEN,
                               transform=transform_train)
    dataset_test = GymDataset(dir_test,
                              seq_len=SEQ_LEN,
                              transform=transform_test)
    dataset_test.load_batch(0)  # 1 batch of data used for test set
    dataloader_test = torch.utils.data.DataLoader(dataset_test,
                                                  batch_size=BATCH_SIZE,
                                                  shuffle=False,
                                                  collate_fn=collate_fn)

    # define and load VAE model
    vae = VAE(CHANNELS, LATENT_SIZE)
    load_vae_file = os.path.join(vae_dir, 'best.tar')
    state_vae = torch.load(load_vae_file)
    vae.load_state_dict(state_vae['state_dict'])
    vae.to(device)

    # set save and optional load directories for the MDNRNN model
    load_mdnrnn_file = os.path.join(mdnrnn_dir, 'best.tar')
    try:
        state_mdnrnn = torch.load(load_mdnrnn_file)
    except FileNotFoundError:
        state_mdnrnn = None

    # define and load MDNRNN model
    mdnrnn = MDNRNN(LATENT_SIZE,
                    ACTION_SIZE,
                    HIDDEN_SIZE,
                    N_GAUSS,
                    rewards_terminal=False)
    if state_mdnrnn is not None:
        mdnrnn.load_state_dict(state_mdnrnn['state_dict'])
    mdnrnn.zero_grad()
    mdnrnn.to(device)

    # optimizer
    params = [p for p in mdnrnn.parameters() if p.requires_grad]
    optimizer = RMSprop(params, lr=LR, alpha=.9)
    if state_mdnrnn is not None:
        optimizer.load_state_dict(state_mdnrnn['optimizer'])

    # learning rate scheduling
    lr_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)
    if state_mdnrnn is not None:
        lr_scheduler.load_state_dict(state_mdnrnn['scheduler'])

    # helper function
    def img2latent(obs, batch_size):
        """ Function to go from image to latent space. """
        with torch.no_grad():
            obs = obs.view(-1, CHANNELS, H, W)
            _, mu, logsigma = vae(obs)
            latent = (mu + logsigma.exp() * torch.randn_like(mu)).view(
                batch_size, SEQ_LEN, LATENT_SIZE)
        return latent

    # define test fn
    def test():
        """ One test epoch """
        mdnrnn.eval()
        test_loss = 0
        n_test = len(dataloader_test.dataset)
        with torch.no_grad():
            for (obs, action, next_obs) in generate_obs(dataloader_test):

                batch_size = len(obs)

                # place on device
                try:
                    obs = torch.stack(obs).to(device)
                    next_obs = torch.stack(next_obs).to(device)
                    action = torch.stack(action).to(device)
                except Exception:
                    print(
                        'Did not manage to stack test observations and actions.'
                    )
                    n_test -= batch_size
                    continue

                # convert to latent space
                latent_obs = img2latent(obs, batch_size)
                next_latent_obs = img2latent(next_obs, batch_size)

                # need to flip dims to feed into LSTM from [batch, seq_len, dim] to [seq_len, batch, dim]
                latent_obs, action, next_latent_obs = [
                    arr.transpose(1, 0)
                    for arr in [latent_obs, action, next_latent_obs]
                ]

                # forward pass model
                mus, sigmas, logpi = mdnrnn(action, latent_obs)

                # compute loss
                loss = gmm_loss(next_latent_obs, mus, sigmas, logpi)
                test_loss += loss.item()

        test_loss /= n_test
        return test_loss

    # train
    n_batch_train = len(dataset_train.batch_list)
    optimizer.zero_grad()

    cur_best = None

    tq_episode = tqdm_notebook(range(epochs))
    for epoch in tq_episode:

        mdnrnn.train()
        loss_train = 0
        n_batch = 0

        tq_batch = tqdm_notebook(range(n_batch_train))
        for i in tq_batch:  # loop over training data for each epoch

            dataset_train.load_batch(i)
            dataloader_train = torch.utils.data.DataLoader(
                dataset_train,
                batch_size=BATCH_SIZE,
                shuffle=True,
                collate_fn=collate_fn)

            tq_minibatch = tqdm_notebook(generate_obs(dataloader_train),
                                         total=len(dataloader_train),
                                         leave=False)
            for j, (obs, action, next_obs) in enumerate(tq_minibatch):

                n_batch += 1

                # place on device
                batch_size = len(obs)
                try:
                    obs = torch.stack(obs).to(device)
                    next_obs = torch.stack(next_obs).to(device)
                    action = torch.stack(action).to(device)
                except Exception:
                    print('Did not manage to stack observations and actions.')
                    continue

                # convert to latent space
                latent_obs = img2latent(obs, batch_size)
                next_latent_obs = img2latent(next_obs, batch_size)

                # need to flip dims to feed into LSTM from [batch, seq_len, dim] to [seq_len, batch, dim]
                latent_obs, action, next_latent_obs = [
                    arr.transpose(1, 0)
                    for arr in [latent_obs, action, next_latent_obs]
                ]

                # forward pass model
                mus, sigmas, logpi = mdnrnn(action, latent_obs)

                # compute loss
                loss = gmm_loss(next_latent_obs, mus, sigmas, logpi)

                # backward pass
                loss.backward()

                # store loss value
                loss_train += loss.item()
                loss_train_avg = loss_train / (n_batch * BATCH_SIZE)

                # apply gradients and learning rate scheduling with optional gradient accumulation
                if (j + 1) % GRAD_ACCUMULATION_STEPS == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                tq_minibatch.set_postfix(loss_train=loss_train_avg)

            tq_batch.set_postfix(loss_train=loss_train_avg)

        lr_scheduler.step()

        # evaluate on test set
        loss_test_avg = test()

        # checkpointing
        best_filename = os.path.join(mdnrnn_dir, 'best.tar')
        filename = os.path.join(mdnrnn_dir, 'checkpoint.tar')
        is_best = not cur_best or loss_test_avg < cur_best
        if is_best:
            cur_best = loss_test_avg

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': mdnrnn.state_dict(),
                'precision': loss_test_avg,
                'optimizer': optimizer.state_dict(),
                'scheduler': lr_scheduler.state_dict()
            }, is_best, filename, best_filename)

        tq_episode.set_postfix(loss_train=loss_train_avg,
                               loss_test=loss_test_avg)
Ejemplo n.º 30
0
class NQLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.last_target_update_episode = 0
        self.device = th.device('cuda' if args.use_cuda else 'cpu')
        self.params = list(mac.parameters())

        if args.mixer == "qatten":
            self.mixer = QattenMixer(args)
        elif args.mixer == "vdn":
            self.mixer = VDNMixer()
        elif args.mixer == "qmix":
            self.mixer = Mixer(args)
        else:
            raise "mixer error"
        self.target_mixer = copy.deepcopy(self.mixer)
        self.params += list(self.mixer.parameters())

        print('Mixer Size: ')
        print(get_parameters_num(self.mixer.parameters()))

        if self.args.optimizer == 'adam':
            self.optimiser = Adam(params=self.params, lr=args.lr)
        else:
            self.optimiser = RMSprop(params=self.params,
                                     lr=args.lr,
                                     alpha=args.optim_alpha,
                                     eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.train_t = 0

        # th.autograd.set_detect_anomaly(True)

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3,
                                        index=actions).squeeze(
                                            3)  # Remove the last dim
        chosen_action_qvals_ = chosen_action_qvals

        # Calculate the Q-Values necessary for the target
        with th.no_grad():
            target_mac_out = []
            self.target_mac.init_hidden(batch.batch_size)
            for t in range(batch.max_seq_length):
                target_agent_outs = self.target_mac.forward(batch, t=t)
                target_mac_out.append(target_agent_outs)

            # We don't need the first timesteps Q-Value estimate for calculating targets
            target_mac_out = th.stack(target_mac_out,
                                      dim=1)  # Concat across time

            # Max over target Q-Values/ Double q learning
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            cur_max_actions = mac_out_detach.max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3,
                                         cur_max_actions).squeeze(3)

            # Calculate n-step Q-Learning targets
            target_max_qvals = self.target_mixer(target_max_qvals,
                                                 batch["state"])

            if getattr(self.args, 'q_lambda', False):
                qvals = th.gather(target_mac_out, 3,
                                  batch["actions"]).squeeze(3)
                qvals = self.target_mixer(qvals, batch["state"])

                targets = build_q_lambda_targets(rewards, terminated, mask,
                                                 target_max_qvals, qvals,
                                                 self.args.gamma,
                                                 self.args.td_lambda)
            else:
                targets = build_td_lambda_targets(rewards, terminated, mask,
                                                  target_max_qvals,
                                                  self.args.n_agents,
                                                  self.args.gamma,
                                                  self.args.td_lambda)

        # Mixer
        chosen_action_qvals = self.mixer(chosen_action_qvals,
                                         batch["state"][:, :-1])

        td_error = (chosen_action_qvals - targets.detach())
        td_error = 0.5 * td_error.pow(2)

        mask = mask.expand_as(td_error)
        masked_td_error = td_error * mask
        loss = L_td = masked_td_error.sum() / mask.sum()

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params,
                                                self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss_td", L_td.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs",
                (masked_td_error.abs().sum().item() / mask_elems), t_env)
            self.logger.log_stat("q_taken_mean",
                                 (chosen_action_qvals * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env

            # print estimated matrix
            if self.args.env == "one_step_matrix_game":
                print_matrix_status(batch, self.mixer, mac_out)

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                th.load("{}/mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            th.load("{}/opt.th".format(path),
                    map_location=lambda storage, loc: storage))
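
build_td_lambda_targets is called above but not shown. The recursion it implements mixes the 1-step bootstrapped target with the running lambda-return, G_t = r_t + gamma * (1 - done_t) * ((1 - lambda) * Q'_{t+1} + lambda * G_{t+1}). A simplified sketch of that backward pass, ignoring the episode mask and per-agent broadcasting handled by the real helper:

import torch as th

def td_lambda_targets(rewards, terminated, target_qs, gamma, td_lambda):
    # rewards/terminated: [batch, T]; target_qs: [batch, T + 1] mixed target values.
    T = rewards.shape[1]
    ret = th.zeros_like(target_qs)
    ret[:, -1] = target_qs[:, -1]
    for t in range(T - 1, -1, -1):
        bootstrap = (1 - td_lambda) * target_qs[:, t + 1] + td_lambda * ret[:, t + 1]
        ret[:, t] = rewards[:, t] + gamma * (1 - terminated[:, t]) * bootstrap
    return ret[:, :-1]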
Ejemplo n.º 31
0
class NoiseQLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = NoiseQMixer(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        discrim_input = np.prod(
            self.args.state_shape) + self.args.n_agents * self.args.n_actions

        if self.args.rnn_discrim:
            self.rnn_agg = RNNAggregator(discrim_input, args)
            self.discrim = Discrim(args.rnn_agg_size, self.args.noise_dim,
                                   args)
            self.params += list(self.discrim.parameters())
            self.params += list(self.rnn_agg.parameters())
        else:
            self.discrim = Discrim(discrim_input, self.args.noise_dim, args)
            self.params += list(self.discrim.parameters())
        self.discrim_loss = th.nn.CrossEntropyLoss(reduction="none")

        self.optimiser = RMSprop(params=self.params,
                                 lr=args.lr,
                                 alpha=args.optim_alpha,
                                 eps=args.optim_eps)

        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]
        noise = batch["noise"][:,
                               0].unsqueeze(1).repeat(1, rewards.shape[1], 1)

        # Calculate estimated Q-Values
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3,
                                        index=actions).squeeze(
                                            3)  # Remove the last dim

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)

        # We don't need the first timesteps Q-Value estimate for calculating targets
        target_mac_out = th.stack(target_mac_out[1:],
                                  dim=1)  # Concat across time

        # Mask out unavailable actions
        target_mac_out[avail_actions[:,
                                     1:] == 0] = -9999999  # From OG deepmarl

        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
            mac_out[avail_actions == 0] = -9999999
            cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3,
                                         cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Mix
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals,
                                             batch["state"][:, :-1], noise)
            target_max_qvals = self.target_mixer(target_max_qvals,
                                                 batch["state"][:, 1:], noise)

        # Discriminator
        mac_out[avail_actions == 0] = -9999999
        q_softmax_actions = th.nn.functional.softmax(mac_out[:, :-1], dim=3)

        if self.args.hard_qs:
            maxs = th.max(mac_out[:, :-1], dim=3, keepdim=True)[1]
            zeros = th.zeros_like(q_softmax_actions)
            zeros.scatter_(dim=3, index=maxs, value=1)
            q_softmax_actions = zeros

        q_softmax_agents = q_softmax_actions.reshape(
            q_softmax_actions.shape[0], q_softmax_actions.shape[1], -1)

        states = batch["state"][:, :-1]
        state_and_softactions = th.cat([q_softmax_agents, states], dim=2)

        if self.args.rnn_discrim:
            h_to_use = th.zeros(size=(batch.batch_size,
                                      self.args.rnn_agg_size)).to(
                                          states.device)
            hs = th.ones_like(h_to_use)
            for t in range(batch.max_seq_length - 1):
                hs = self.rnn_agg(state_and_softactions[:, t], hs)
                for b in range(batch.batch_size):
                    if t == batch.max_seq_length - 2 or (mask[b, t] == 1 and
                                                         mask[b, t + 1] == 0):
                        # This is the last timestep of the sequence
                        h_to_use[b] = hs[b]
            s_and_softa_reshaped = h_to_use
        else:
            s_and_softa_reshaped = state_and_softactions.reshape(
                -1, state_and_softactions.shape[-1])

        if self.args.mi_intrinsic:
            s_and_softa_reshaped = s_and_softa_reshaped.detach()

        discrim_prediction = self.discrim(s_and_softa_reshaped)

        # Cross-Entropy
        target_repeats = 1
        if not self.args.rnn_discrim:
            target_repeats = q_softmax_actions.shape[1]
        discrim_target = batch["noise"][:, 0].long().detach().max(
            dim=1)[1].unsqueeze(1).repeat(1, target_repeats).reshape(-1)
        discrim_loss = self.discrim_loss(discrim_prediction, discrim_target)

        if self.args.rnn_discrim:
            averaged_discrim_loss = discrim_loss.mean()
        else:
            masked_discrim_loss = discrim_loss * mask.reshape(-1)
            averaged_discrim_loss = masked_discrim_loss.sum() / mask.sum()
        self.logger.log_stat("discrim_loss", averaged_discrim_loss.item(),
                             t_env)

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.args.gamma * (1 -
                                               terminated) * target_max_qvals
        if self.args.mi_intrinsic:
            assert self.args.rnn_discrim is False
            targets = targets + self.args.mi_scaler * discrim_loss.view_as(
                rewards)

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss = (masked_td_error**2).sum() / mask.sum()

        loss = loss + self.args.mi_loss * averaged_discrim_loss

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params,
                                                self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs",
                (masked_td_error.abs().sum().item() / mask_elems), t_env)
            self.logger.log_stat("q_taken_mean",
                                 (chosen_action_qvals * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        self.discrim.cuda()
        if self.args.rnn_discrim:
            self.rnn_agg.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                th.load("{}/mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            th.load("{}/opt.th".format(path),
                    map_location=lambda storage, loc: storage))
Ejemplo n.º 32
0
class QLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0
        self.device = th.device('cuda' if args.use_cuda else 'cpu')

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = QMixer(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        if self.args.optimizer == 'adam':
            self.optimiser = Adam(params=self.params, lr=args.lr)
        else:
            self.optimiser = RMSprop(params=self.params,
                                     lr=args.lr,
                                     alpha=args.optim_alpha,
                                     eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.train_t = 0

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3,
                                        index=actions).squeeze(
                                            3)  # Remove the last dim
        chosen_action_qvals_back = chosen_action_qvals

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)

        # We don't need the first timesteps Q-Value estimate for calculating targets
        target_mac_out = th.stack(target_mac_out[1:],
                                  dim=1)  # Concat across time

        # Mask out unavailable actions
        target_mac_out[avail_actions[:, 1:] == 0] = -9999999

        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3,
                                         cur_max_actions).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]

        # Mix
        if self.mixer is not None:
            chosen_action_qvals = self.mixer(chosen_action_qvals,
                                             batch["state"][:, :-1])
            target_max_qvals = self.target_mixer(target_max_qvals,
                                                 batch["state"][:, 1:])

        # Calculate 1-step Q-Learning targets
        targets = rewards + self.args.gamma * (1 -
                                               terminated) * target_max_qvals

        # Td-error
        td_error = (chosen_action_qvals - targets.detach())

        mask = mask.expand_as(td_error)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask

        # Normal L2 loss, take mean over actual data
        loss = 0.5 * (masked_td_error**2).sum() / mask.sum()

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params,
                                                self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss_td", loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs",
                (masked_td_error.abs().sum().item() / mask_elems), t_env)
            self.logger.log_stat("q_taken_mean",
                                 (chosen_action_qvals * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.logger.log_stat("target_mean", (targets * mask).sum().item() /
                                 (mask_elems * self.args.n_agents), t_env)
            self.log_stats_t = t_env

            # print estimated matrix
            if self.args.env == "one_step_matrix_game":
                print_matrix_status(batch, self.mixer, mac_out)

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                th.load("{}/mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            th.load("{}/opt.th".format(path),
                    map_location=lambda storage, loc: storage))
Ejemplo n.º 33
0
    def __init__(self, obs_space, action_space, config):
        _validate(obs_space, action_space)
        config = dict(ray.rllib.agents.qmix.qmix.DEFAULT_CONFIG, **config)
        self.config = config
        self.observation_space = obs_space
        self.action_space = action_space
        self.n_agents = len(obs_space.original_space.spaces)
        self.n_actions = action_space.spaces[0].n
        self.h_size = config["model"]["lstm_cell_size"]

        agent_obs_space = obs_space.original_space.spaces[0]
        if isinstance(agent_obs_space, Dict):
            space_keys = set(agent_obs_space.spaces.keys())
            if space_keys != {"obs", "action_mask"}:
                raise ValueError(
                    "Dict obs space for agent must have keyset "
                    "['obs', 'action_mask'], got {}".format(space_keys))
            mask_shape = tuple(agent_obs_space.spaces["action_mask"].shape)
            if mask_shape != (self.n_actions, ):
                raise ValueError("Action mask shape must be {}, got {}".format(
                    (self.n_actions, ), mask_shape))
            self.has_action_mask = True
            self.obs_size = _get_size(agent_obs_space.spaces["obs"])
            # The real agent obs space is nested inside the dict
            agent_obs_space = agent_obs_space.spaces["obs"]
        else:
            self.has_action_mask = False
            self.obs_size = _get_size(agent_obs_space)

        self.model = ModelCatalog.get_torch_model(
            agent_obs_space,
            self.n_actions,
            config["model"],
            default_model_cls=RNNModel)
        self.target_model = ModelCatalog.get_torch_model(
            agent_obs_space,
            self.n_actions,
            config["model"],
            default_model_cls=RNNModel)

        # Setup the mixer network.
        # The global state is just the stacked agent observations for now.
        self.state_shape = [self.obs_size, self.n_agents]
        if config["mixer"] is None:
            self.mixer = None
            self.target_mixer = None
        elif config["mixer"] == "qmix":
            self.mixer = QMixer(self.n_agents, self.state_shape,
                                config["mixing_embed_dim"])
            self.target_mixer = QMixer(self.n_agents, self.state_shape,
                                       config["mixing_embed_dim"])
        elif config["mixer"] == "vdn":
            self.mixer = VDNMixer()
            self.target_mixer = VDNMixer()
        else:
            raise ValueError("Unknown mixer type {}".format(config["mixer"]))

        self.cur_epsilon = 1.0
        self.update_target()  # initial sync

        # Setup optimizer
        self.params = list(self.model.parameters())
        if self.mixer is not None:
            # Include the mixer parameters so the mixing network is trained as well
            self.params += list(self.mixer.parameters())
        self.loss = QMixLoss(self.model, self.target_model, self.mixer,
                             self.target_mixer, self.n_agents, self.n_actions,
                             self.config["double_q"], self.config["gamma"])
        self.optimiser = RMSprop(
            params=self.params,
            lr=config["lr"],
            alpha=config["optim_alpha"],
            eps=config["optim_eps"])
Ejemplo n.º 34
0
class DQN_Model_Agent:

    def __init__(self, args, exp_model, logging_func):
        self.args = args

        # Exploration Model
        self.exp_model = exp_model

        self.log = logging_func["log"]
        self.log_image = logging_func["image"]
        os.makedirs("{}/transition_model".format(args.log_path))

        # Experience Replay
        self.replay = ExpReplay(args.exp_replay_size, args.stale_limit, exp_model, args, priority=self.args.prioritized)

        # DQN and Target DQN
        model = get_models(args.model)
        print("\n\nDQN")
        self.dqn = model(actions=args.actions)
        print("Target DQN")
        self.target_dqn = model(actions=args.actions)

        dqn_params = 0
        for weight in self.dqn.parameters():
            weight_params = 1
            for s in weight.size():
                weight_params *= s
            dqn_params += weight_params
        print("Model DQN has {:,} parameters.".format(dqn_params))

        self.target_dqn.eval()

        if args.gpu:
            print("Moving models to GPU.")
            self.dqn.cuda()
            self.target_dqn.cuda()

        # Optimizer
        # self.optimizer = Adam(self.dqn.parameters(), lr=args.lr)
        self.optimizer = RMSprop(self.dqn.parameters(), lr=args.lr)

        self.T = 0
        self.target_sync_T = -self.args.t_max

        # Action sequences
        self.actions_to_take = []

    def sync_target_network(self):
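        # Hard update: copy every parameter of the online DQN into the target DQN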
        for target, source in zip(self.target_dqn.parameters(), self.dqn.parameters()):
            target.data = source.data


    def get_pc_estimates(self, root_state, depth=0, starts=None):
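        # Recursively enumerate action sequences up to `depth` steps ahead: the exploration model
        # supplies a pseudo-count bonus for each action, and the DQN's transition head predicts the
        # next state used for the deeper lookups. Returns one bonus per enumerated sequence.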
        state = root_state
        bonuses = []
        for action in range(self.args.actions):

            # Current pc estimates
            if depth == 0 or not self.args.only_leaf:
                numpy_state = state[0].numpy().swapaxes(0, 2)
                _, info = self.exp_model.bonus(numpy_state, action, dont_remember=True)
                action_pseudo_count = info["Pseudo_Count"]
                action_bonus = self.args.optimistic_scaler / np.power(action_pseudo_count + 0.01, self.args.bandit_p)
                if starts is not None:
                    action_bonus += starts[action]

            # If the depth is 0 we don't want to look any further ahead
            if depth == 0:
                bonuses.append(action_bonus)
                continue

            one_hot_action = torch.zeros(1, self.args.actions)
            one_hot_action[0, action] = 1
            _, next_state_prediction = self.dqn(Variable(state, volatile=True), Variable(one_hot_action, volatile=True))
            next_state_prediction = next_state_prediction.cpu().data

            next_state_pc_estimates = self.get_pc_estimates(next_state_prediction, depth=depth - 1)

            if self.args.only_leaf:
                bonuses += next_state_pc_estimates
            else:
                ahead_pc_estimates = [action_bonus + self.args.gamma * n for n in next_state_pc_estimates]
                bonuses += ahead_pc_estimates

        return bonuses

    def act(self, state, epsilon, exp_model, evaluation=False):
        # self.T += 1
        if not evaluation:
            if len(self.actions_to_take) > 0:
                action_to_take = self.actions_to_take[0]
                self.actions_to_take = self.actions_to_take[1:]
                return action_to_take, {"Action": action_to_take, "Q_Values": self.prev_q_vals}

        self.dqn.eval()
        # orig_state = state[:, :, -1:]
        state = torch.from_numpy(state).float().transpose_(0, 2).unsqueeze(0)
        q_values = self.dqn(Variable(state, volatile=True)).cpu().data[0]
        q_values_numpy = q_values.numpy()
        self.prev_q_vals = q_values_numpy

        extra_info = {}

        if self.args.optimistic_init and not evaluation and len(self.actions_to_take) == 0:

            # Multi-step action lookahead using pseudo-count bonuses (depth = args.lookahead_depth)
            action_bonuses = self.get_pc_estimates(state, depth=self.args.lookahead_depth, starts=q_values_numpy)

            # Find the maximum sequence 
            max_so_far = -100000
            best_index = 0
            best_seq = []

            for ii, bonus in enumerate(action_bonuses):
                if bonus > max_so_far:
                    best_index = ii
                    max_so_far = bonus

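            # Decode the flat index of the best-scoring sequence into individual actions
            # (one base-`actions` digit per lookahead step)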
            for depth in range(self.args.lookahead_depth):
                last_action = best_index % self.args.actions
                best_index = best_index // self.args.actions
                best_seq = best_seq + [last_action]

            # print(best_seq)
            self.actions_to_take = best_seq

        extra_info["Q_Values"] = q_values_numpy

        if np.random.random() < epsilon:
            action = np.random.randint(low=0, high=self.args.actions)
        else:
            action = q_values.max(0)[1][0]  # Torch...

        extra_info["Action"] = action

        return action, extra_info

    def experience(self, state, action, reward, state_next, steps, terminated, pseudo_reward=0, density=1, exploring=False):
        if not exploring:
            self.T += 1
        self.replay.Add_Exp(state, action, reward, state_next, steps, terminated, pseudo_reward, density)

    def end_of_trajectory(self):
        self.replay.end_of_trajectory()

    def train(self):

        if self.T - self.target_sync_T > self.args.target:
            self.sync_target_network()
            self.target_sync_T = self.T

        info = {}

        for _ in range(self.args.iters):
            self.dqn.eval()

            # TODO: Use a named tuple for experience replay
            n_step_sample = self.args.n_step
            batch, indices, is_weights = self.replay.Sample_N(self.args.batch_size, n_step_sample, self.args.gamma)
            columns = list(zip(*batch))

            states = Variable(torch.from_numpy(np.array(columns[0])).float().transpose_(1, 3))
            actions = Variable(torch.LongTensor(columns[1]))
            terminal_states = Variable(torch.FloatTensor(columns[5]))
            rewards = Variable(torch.FloatTensor(columns[2]))
            # Have to clip rewards for DQN
            rewards = torch.clamp(rewards, -1, 1)
            steps = Variable(torch.FloatTensor(columns[4]))
            new_states = Variable(torch.from_numpy(np.array(columns[3])).float().transpose_(1, 3))

            target_dqn_qvals = self.target_dqn(new_states).cpu()
            # Make a new variable with those values so that these are treated as constants
            target_dqn_qvals_data = Variable(target_dqn_qvals.data)

            q_value_targets = (Variable(torch.ones(terminal_states.size()[0])) - terminal_states)
            inter = Variable(torch.ones(terminal_states.size()[0]) * self.args.gamma)
            # print(steps)
            q_value_targets = q_value_targets * torch.pow(inter, steps)
            if self.args.double:
                # Double Q Learning
                new_states_qvals = self.dqn(new_states).cpu()
                new_states_qvals_data = Variable(new_states_qvals.data)
                q_value_targets = q_value_targets * target_dqn_qvals_data.gather(1, new_states_qvals_data.max(1)[1])
            else:
                q_value_targets = q_value_targets * target_dqn_qvals_data.max(1)[0]
            q_value_targets = q_value_targets + rewards

            self.dqn.train()

            one_hot_actions = torch.zeros(self.args.batch_size, self.args.actions)

            for i in range(self.args.batch_size):
                one_hot_actions[i][actions[i].data] = 1

            if self.args.gpu:
                actions = actions.cuda()
                one_hot_actions = one_hot_actions.cuda()
                q_value_targets = q_value_targets.cuda()
                new_states = new_states.cuda()
            model_predictions_q_vals, model_predictions_state = self.dqn(states, Variable(one_hot_actions))
            model_predictions = model_predictions_q_vals.gather(1, actions.view(-1, 1))

            # info = {}

            td_error = model_predictions - q_value_targets
            info["TD_Error"] = td_error.mean().data[0]

            # Update the priorities
            if not self.args.density_priority:
                self.replay.Update_Indices(indices, td_error.cpu().data.numpy(), no_pseudo_in_priority=self.args.count_td_priority)

            # If using prioritised we need to weight the td_error
            if self.args.prioritized and self.args.prioritized_is:
                # print(td_error)
                weights_tensor = torch.from_numpy(is_weights).float()
                weights_tensor = Variable(weights_tensor)
                if self.args.gpu:
                    weights_tensor = weights_tensor.cuda()
                # print(weights_tensor)
                td_error = td_error * weights_tensor

            # Model 1 step state transition error

            # Save them every x steps
            if self.T % self.args.model_save_image == 0:
                os.makedirs("{}/transition_model/{}".format(self.args.log_path, self.T))
                for ii, image, action, next_state, current_state in zip(range(self.args.batch_size), model_predictions_state.cpu().data, actions.data, new_states.cpu().data, states.cpu().data):
                    image = image.numpy()[0]
                    image = np.clip(image, 0, 1)
                    # print(next_state)
                    next_state = next_state.numpy()[0]
                    current_state = current_state.numpy()[0]

                    black_bars = np.zeros_like(next_state[:1, :])
                    # print(black_bars.shape)

                    joined_image = np.concatenate((current_state, black_bars, image, black_bars, next_state), axis=0)
                    joined_image = np.transpose(joined_image)
                    self.log_image("{}/transition_model/{}/{}_____Action_{}".format(self.args.log_path, self.T, ii + 1, action), joined_image * 255)

                    # self.log_image("{}/transition_model/{}/{}_____Action_{}".format(self.args.log_path, self.T, ii + 1, action), image * 255)
                    # self.log_image("{}/transition_model/{}/{}_____Correct".format(self.args.log_path, self.T, ii + 1), next_state * 255)

            # print(model_predictions_state)

            # Cross Entropy Loss
            # TODO

            # Regression loss
            state_error = model_predictions_state - new_states
            # state_error_val = state_error.mean().data[0]

            info["State_Error"] = state_error.mean().data[0]
            self.log("DQN/State_Loss", state_error.mean().data[0], step=self.T)
            self.log("DQN/State_Loss_Squared", state_error.pow(2).mean().data[0], step=self.T)
            self.log("DQN/State_Loss_Max", state_error.abs().max().data[0], step=self.T)
            # self.log("DQN/Action_Matrix_Norm", self.dqn.action_matrix.weight.norm().cpu().data[0], step=self.T)

            combined_loss = (1 - self.args.model_loss) * td_error.pow(2).mean() + (self.args.model_loss) * state_error.pow(2).mean()
            l2_loss = combined_loss
            # l2_loss = (combined_loss).pow(2).mean()
            info["Loss"] = l2_loss.data[0]

            # Update
            self.optimizer.zero_grad()
            l2_loss.backward()

            # Taken from pytorch clip_grad_norm
            # Remove once the pip version is up to date with source
            gradient_norm = clip_grad_norm(self.dqn.parameters(), self.args.clip_value)
            if gradient_norm is not None:
                info["Norm"] = gradient_norm

            self.optimizer.step()

            if "States" in info:
                states_trained = info["States"]
                info["States"] = states_trained + columns[0]
            else:
                info["States"] = columns[0]

        # Pad out the states to be of size batch_size
        if len(info["States"]) < self.args.batch_size:
            old_states = info["States"]
            new_states = old_states[0] * (self.args.batch_size - len(old_states))
            info["States"] = new_states

        return info
Ejemplo n.º 35
0
class LIIRLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.mac = mac
        self.logger = logger

        self.last_target_update_step = 0
        self.critic_training_steps = 0

        self.log_stats_t = -self.args.learner_log_interval - 1

        self.critic = LIIRCritic(scheme, args)
        self.target_critic = copy.deepcopy(self.critic)

        self.policy_new = copy.deepcopy(self.mac)
        self.policy_old = copy.deepcopy(self.mac)

        if self.args.use_cuda:
            # move the old/new policy copies to the GPU when CUDA is enabled
            self.policy_old.agent = self.policy_old.agent.to("cuda")
            self.policy_new.agent = self.policy_new.agent.to("cuda")
        else:
            # otherwise keep the old/new policy copies on the CPU
            self.policy_old.agent = self.policy_old.agent.to("cpu")
            self.policy_new.agent = self.policy_new.agent.to("cpu")

        self.agent_params = list(mac.parameters())
        self.critic_params = list(self.critic.fc1.parameters()) + list(self.critic.fc2.parameters()) + list(
            self.critic.fc3_v_mix.parameters())
        self.intrinsic_params = list(self.critic.fc3_r_in.parameters()) + list(self.critic.fc4.parameters())  # to do
        self.params = self.agent_params + self.critic_params + self.intrinsic_params

        self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps)
        self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha,
                                        eps=args.optim_eps)
        self.intrinsic_optimiser = RMSprop(params=self.intrinsic_params, lr=args.critic_lr, alpha=args.optim_alpha,
                                           eps=args.optim_eps)  # should distinguish them
        self.update = 0
        self.count = 0

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int, nupdate: int):
        # Get the relevant quantities
        bs = batch.batch_size
        max_t = batch.max_seq_length
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"][:, :-1]

        critic_mask = mask.clone()

        mask_long = mask.repeat(1, 1, self.n_agents).view(-1, 1)
        mask = mask.view(-1, 1)

        avail_actions1 = avail_actions.reshape(-1, self.n_agents, self.n_actions)  # [maskxx,:]
        mask_alive = 1.0 - avail_actions1[:, :, 0]
        mask_alive = mask_alive.float()

        q_vals, critic_train_stats, target_mix, target_ex, v_ex, r_in = self._train_critic(
            batch, rewards, terminated, actions, avail_actions, critic_mask, bs, max_t)

        actions = actions[:, :-1]

        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length - 1):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time
        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out[avail_actions == 0] = 0
        mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True)
        mac_out[avail_actions == 0] = 0

        # Calculated baseline
        q_vals = q_vals.reshape(-1, 1)
        pi = mac_out.view(-1, self.n_actions)
        # Calculate policy grad with mask
        pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1)
        pi_taken[mask_long.squeeze(-1) == 0] = 1.0
        log_pi_taken = th.log(pi_taken)

        advantages = (target_mix.reshape(-1, 1) - q_vals).detach()
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        log_pi_taken = log_pi_taken.reshape(-1, self.n_agents)
        log_pi_taken = log_pi_taken * mask_alive
        log_pi_taken = log_pi_taken.reshape(-1, 1)

        liir_loss = - ((advantages * log_pi_taken) * mask_long).sum() / mask_long.sum()

        # Optimise agents
        self.agent_optimiser.zero_grad()
        liir_loss.backward()
        grad_norm_policy = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip)
        self.agent_optimiser.step()

        # _________Intrinsic loss optimizer --------------------
        # ____value loss
        v_ex_loss = (((v_ex - target_ex.detach()) ** 2).view(-1, 1) * mask).sum() / mask.sum()

        # _____pg1____
        mac_out_old = []
        self.policy_old.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length - 1):
            agent_outs_tmp = self.policy_old.forward(batch, t=t, test_mode=True)
            mac_out_old.append(agent_outs_tmp)
        mac_out_old = th.stack(mac_out_old, dim=1)  # Concat over time

        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out_old[avail_actions == 0] = 0
        mac_out_old = mac_out_old / mac_out_old.sum(dim=-1, keepdim=True)
        mac_out_old[avail_actions == 0] = 0
        pi_old = mac_out_old.view(-1, self.n_actions)

        # Calculate policy grad with mask
        pi_taken_old = th.gather(pi_old, dim=1, index=actions.reshape(-1, 1)).squeeze(1)
        pi_taken_old[mask_long.squeeze(-1) == 0] = 1.0
        log_pi_taken_old = th.log(pi_taken_old)

        log_pi_taken_old = log_pi_taken_old.reshape(-1, self.n_agents)
        log_pi_taken_old = log_pi_taken_old * mask_alive

        # ______pg2___new pi theta
        self._update_policy()  # update policy_new to new params

        mac_out_new = []
        self.policy_new.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length - 1):
            agent_outs_tmp = self.policy_new.forward(batch, t=t, test_mode=True)
            mac_out_new.append(agent_outs_tmp)
        mac_out_new = th.stack(mac_out_new, dim=1)  # Concat over time

        # Mask out unavailable actions, renormalise (as in action selection)
        mac_out_new[avail_actions == 0] = 0
        mac_out_new = mac_out_new / mac_out_new.sum(dim=-1, keepdim=True)
        mac_out_new[avail_actions == 0] = 0

        pi_new = mac_out_new.view(-1, self.n_actions)
        # Calculate policy grad with mask
        pi_taken_new = th.gather(pi_new, dim=1, index=actions.reshape(-1, 1)).squeeze(1)
        pi_taken_new[mask_long.squeeze(-1) == 0] = 1.0
        log_pi_taken_new = th.log(pi_taken_new)

        log_pi_taken_new = log_pi_taken_new.reshape(-1, self.n_agents)
        log_pi_taken_new = log_pi_taken_new * mask_alive
        neglogpac_new = - log_pi_taken_new.sum(-1)

        pi2 = log_pi_taken.reshape(-1, self.n_agents).sum(-1).clone()
        ratio_new = th.exp(- pi2 - neglogpac_new)  
        
        adv_ex = (target_ex - v_ex.detach()).detach()
        adv_ex = (adv_ex - adv_ex.mean()) / (adv_ex.std() + 1e-8)

        # _______ gradient for pg 1 and 2 ---
        mask_tnagt = critic_mask.repeat(1, 1, self.n_agents)

        pg_loss1 = (log_pi_taken_old.view(-1, 1) * mask_long).sum() / mask_long.sum()
        pg_loss2 = ((adv_ex.view(-1) * ratio_new) * mask.squeeze(-1)).sum() / mask.sum()
        self.policy_old.agent.zero_grad()
        pg_loss1_grad = th.autograd.grad(pg_loss1, self.policy_old.parameters())

        self.policy_new.agent.zero_grad()
        pg_loss2_grad = th.autograd.grad(pg_loss2, self.policy_new.parameters())

        grad_total = 0
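        # Accumulate the dot product between the two policy gradients (a scalar that weights the intrinsic targets below)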
        for grad1, grad2 in zip(pg_loss1_grad, pg_loss2_grad):
            grad_total += (grad1 * grad2).sum()

        target_mix = target_mix.reshape(-1, max_t - 1, self.n_agents)
        pg_ex_loss = ((grad_total.detach() * target_mix) * mask_tnagt).sum() / mask_tnagt.sum()  

        intrinsic_loss = pg_ex_loss + vf_coef * v_ex_loss
        self.intrinsic_optimiser.zero_grad()
        intrinsic_loss.backward()

        self.intrinsic_optimiser.step()
            
        self._update_policy_piold()

        # ______config tensorboard
        if (self.critic_training_steps - self.last_target_update_step) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_step = self.critic_training_steps

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            ts_logged = len(critic_train_stats["critic_loss"])
            for key in ["critic_loss", "critic_grad_norm", "td_error_abs", "value_mean", "target_mean"]:
                self.logger.log_stat(key, sum(critic_train_stats[key]) / ts_logged, t_env)

            self.logger.log_stat("advantage_mean", (advantages * mask_long).sum().item() / mask_long.sum().item(),
                                 t_env)
            self.logger.log_stat("liir_loss", liir_loss.item(), t_env)
            self.logger.log_stat("agent_grad_norm", grad_norm_policy, t_env)
            self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask_long.squeeze(-1)).sum().item() / mask_long.sum().item(),
                                 t_env)

            reward1 = rewards.reshape(-1,1)
            self.logger.log_stat('rewards_mean', (reward1 * mask).sum().item() / mask.sum().item(), t_env)
            self.log_stats_t = t_env

    def _train_critic(self, batch, rewards, terminated, actions, avail_actions, mask, bs, max_t):
        # Optimise critic
        r_in, target_vals, target_val_ex = self.target_critic(batch)

        r_in, _, target_val_ex_opt = self.critic(batch)
        r_in_taken = th.gather(r_in, dim=3, index=actions)
        r_in = r_in_taken.squeeze(-1)

        target_vals = target_vals.squeeze(-1)

        targets_mix, targets_ex = build_td_lambda_targets(rewards, terminated, mask, target_vals, self.n_agents,
                                                          self.args.gamma, self.args.td_lambda, r_in, target_val_ex)

        vals_mix = th.zeros_like(target_vals)[:, :-1]
        vals_ex = target_val_ex_opt[:, :-1]

        running_log = {
            "critic_loss": [],
            "critic_grad_norm": [],
            "td_error_abs": [],
            "target_mean": [],
            "value_mean": [],
        }

        for t in reversed(range(rewards.size(1))):
            mask_t = mask[:, t].expand(-1, self.n_agents)
            if mask_t.sum() == 0:
                continue

            _, q_t, _ = self.critic(batch, t)  # 8,1,3,1,
            vals_mix[:, t] = q_t.view(bs, self.n_agents)
            targets_t = targets_mix[:, t]

            td_error = (q_t.view(bs, self.n_agents) - targets_t.detach())

            # 0-out the targets that came from padded data
            masked_td_error = td_error * mask_t

            # Normal L2 loss, take mean over actual data
            loss = (masked_td_error ** 2).sum() / mask_t.sum()
            self.critic_optimiser.zero_grad()
            loss.backward()
            grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params, self.args.grad_norm_clip)
            self.critic_optimiser.step()
            self.critic_training_steps += 1

            running_log["critic_loss"].append(loss.item())
            running_log["critic_grad_norm"].append(grad_norm)
            mask_elems = mask_t.sum().item()
            running_log["td_error_abs"].append((masked_td_error.abs().sum().item() / mask_elems))
            running_log["value_mean"].append((q_t.view(bs, self.n_agents) * mask_t).sum().item() / mask_elems)
            running_log["target_mean"].append((targets_t * mask_t).sum().item() / mask_elems)

        return vals_mix, running_log, targets_mix, targets_ex, vals_ex, r_in

    def _update_targets(self):
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.logger.console_logger.info("Updated target network")

    def _update_policy(self):
        self.policy_new.load_state(self.mac)

    def _update_policy_piold(self):
        self.policy_old.load_state(self.mac)

    def cuda(self):
        self.mac.cuda()
        self.critic.cuda()
        self.target_critic.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        th.save(self.critic.state_dict(), "{}/critic.th".format(path))
        th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path))
        th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        self.critic.load_state_dict(th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage))
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.agent_optimiser.load_state_dict(
            th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage))
        self.critic_optimiser.load_state_dict(
            th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage))
Ejemplo n.º 36
0
def main():
    gpu_num = int(sys.argv[1])
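    # Seed depends on wall-clock time and the worker index so parallel runs on different GPUs differ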
    random_seed = (int(time.time()) * (gpu_num + 1)) % (2 ** 31 - 1)

    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)

    HYPERPARAMETERS = {
        'batch_size': choice([4096, 8192, 16384]),  # [8192//2, 8192*2]
        'nn_encoder_out': choice(list(range(10, 100))),  # [20,40]
        'enc_hidden_layer_k': choice(np.linspace(0.5, 4.0, 8)),
        # [0.5, 4.0]: each feature encoder's hidden layer size is 'nn_encoder_out' divided by this factor,
        # e.g. 'nn_encoder_out' = 30 with k = 2 gives hidden layers of size 15
        'n_splits': 10,  # n folds
        'optimizer': 'adam',  # ['RMSprop','adam']
        'lr': choice(np.linspace(0.001, 0.01, 10)),  # [0.01,0.001]
        'use_dropout': choice([True,False]),
        'use_bn': choice([True,False]),
        'lr_sheduler_factor': choice(np.linspace(0.1, 0.9, 9)),  # [0.1,0.9]
        'lr_sheduler_patience': choice(list(range(3, 15))),  # [3,15]
        'lr_sheduler_min_lr': 0.0001,  # not so important, but doesn't have to be too small
        'max_epoch': 9999,  # we want to use early_stop so just need to be big
        'early_stop_wait': 20,  # bigger is better but slower; 20 is a reasonable compromise
        'upsampling_times': choice(list(range(3, 20))),  # [3,20] more = slower
        'upsampling_class_balancer': choice(list(range(2, 10)))  # [2, 10)
    }

    ans = {}
    for i in range(6):
        with open(f"../output/hpo_logs_{i}.json" ,"r") as f:
            for item in f.readlines():
                d = eval(item)
                ans[d["target"]] = d["params"]

    score = sorted(ans)[-gpu_num - 1]
    params = ans[score]

    params['batch_size'] = int(params['batch_size'])
    params['nn_encoder_out'] = int(params['nn_encoder_out'])
    params['lr_sheduler_patience'] = int(params['lr_sheduler_patience'])
    params['upsampling_times'] = int(params['upsampling_times'])
    params['upsampling_class_balancer'] = int(params['upsampling_class_balancer'])
    params['upsampling_class_balancer'] = min(params['upsampling_class_balancer'],
                                              params['upsampling_times'])
    params['use_bn'] = params['use_bn'] > 0.5
    params['use_dropout'] = params['use_dropout'] > 0.5

    for key in params:
        HYPERPARAMETERS[key] = params[key]

    with open(f"log_{gpu_num}.txt", "a") as f:
        for key in HYPERPARAMETERS:
            f.write(key + " " + str(HYPERPARAMETERS[key]) + "\n")
            print(key, HYPERPARAMETERS[key])

    print(score)
    print("\nSEED:", random_seed)
    print("GPU:", gpu_num, "\n")

    input_path = "../input/"
    output_path = "../output/"

    print("torch:", torch.__version__)
    print("loading data...")

    train_df = pd.read_csv(input_path + 'train.csv.zip')

    label = train_df.target
    train = train_df.drop(['ID_code', 'target'], axis=1)
    cols = train.columns

    test = pd.read_csv(input_path + 'test.csv.zip')
    test = test.drop(['ID_code'], axis=1)

    test_filtered = pd.read_pickle(input_path + 'test_filtered.pkl')
    test_filtered = test_filtered.loc[:, train.columns]

    train_test = pd.concat([train, test_filtered]).reset_index(drop=True)

    vcs_train_test = {}

    for col in tqdm(train.columns):
        vcs_train_test[col] = train_test.loc[:, col].value_counts()

    generate_features(test, vcs_train_test, cols)

    ups = UpsamplingPreprocessor(HYPERPARAMETERS['upsampling_times'], HYPERPARAMETERS['upsampling_class_balancer'])

    loss_f = BCEWithLogitsLoss()

    batch_size = HYPERPARAMETERS['batch_size']
    N_IN = 2

    gpu = torch.device(f'cuda:{gpu_num % 4}')
    cpu = torch.device('cpu')

    folds = StratifiedKFold(n_splits=HYPERPARAMETERS['n_splits'], shuffle=True, random_state=42)
    oof = np.zeros(len(train))
    predictions = np.zeros(len(test))
    for fold_, (trn_idx, val_idx) in enumerate(folds.split(train.values, label)):
        print("Fold {}".format(fold_))

        X_train, Train_label = ups.fit_transform(train.loc[trn_idx], label.loc[trn_idx])
        X_val, Val_label = train.loc[val_idx], label.loc[val_idx]
        generate_features(X_train, vcs_train_test, cols)
        generate_features(X_val, vcs_train_test, cols)
        cols_new = X_train.columns
        scaler = StandardScaler()
        X_train = pd.DataFrame(scaler.fit_transform(X_train), columns=cols_new)
        X_val = pd.DataFrame(scaler.transform(X_val), columns=cols_new)
        test_new = pd.DataFrame(scaler.transform(test), columns=cols_new)

        train_tensors = []
        val_tensors = []
        test_tensors = []

        for fff in range(200):
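            # Each of the 200 variables is fed as a (value, var_*_1_flag) pair, matching N_IN = 2 inputs per variable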
            cols_to_use = [f'var_{fff}', f'var_{fff}_1_flag']
            train_t = X_train.loc[:, cols_to_use].values
            val_t = X_val.loc[:, cols_to_use].values
            test_t = test_new.loc[:, cols_to_use].values
            train_tensors.append(torch.tensor(train_t, requires_grad=False, device=cpu, dtype=torch.float32))
            val_tensors.append(torch.tensor(val_t, requires_grad=False, device=cpu, dtype=torch.float32))

            test_tensors.append(torch.tensor(test_t, requires_grad=False, device=gpu, dtype=torch.float32))

        train_tensors = torch.cat(train_tensors, 1).view((-1, 200, N_IN))
        val_tensors = torch.cat(val_tensors, 1).view((-1, 200, N_IN))
        test_tensors = torch.cat(test_tensors, 1).view((-1, 200, N_IN))
        try:
            y_train_t = torch.tensor(Train_label, requires_grad=False, device=cpu, dtype=torch.float32)
        except:
            y_train_t = torch.tensor(Train_label.values, requires_grad=False, device=cpu, dtype=torch.float32)

        try:
            y_val_t = torch.tensor(Val_label, requires_grad=False, device=cpu, dtype=torch.float32)
        except:
            y_val_t = torch.tensor(Val_label.values, requires_grad=False, device=cpu, dtype=torch.float32)

        nn = NN(D_in=N_IN,
                enc_out=HYPERPARAMETERS['nn_encoder_out'],
                enc_hidden_layer_k=HYPERPARAMETERS['enc_hidden_layer_k'],
                use_dropout=HYPERPARAMETERS['use_dropout'],
                use_BN=HYPERPARAMETERS['use_bn']).to(gpu)

        if HYPERPARAMETERS['optimizer'] == 'adam':
            optimizer = Adam(params=nn.parameters(), lr=HYPERPARAMETERS['lr'])
        elif HYPERPARAMETERS['optimizer'] == 'RMSprop':
            optimizer = RMSprop(params=nn.parameters(), lr=HYPERPARAMETERS['lr'])

        scheduler = ReduceLROnPlateau(optimizer, 'max', factor=HYPERPARAMETERS['lr_sheduler_factor'],
                                      patience=HYPERPARAMETERS['lr_sheduler_patience'],
                                      min_lr=HYPERPARAMETERS['lr_sheduler_min_lr'], verbose=True)

        best_AUC = 0
        early_stop = 0

        for epoch in tqdm(range(HYPERPARAMETERS['max_epoch'])):
            nn.train()
            dl = batch_iter(train_tensors, y_train_t, batch_size=batch_size)
            for data, label_t in dl:
                pred = nn(data.to(gpu))
                loss = loss_f(pred, torch.unsqueeze(label_t.to(gpu), -1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            with torch.no_grad():
                nn.eval()
                blobs = []
                for batch in torch.split(val_tensors, batch_size):
                    blob = nn(batch.to(gpu)).data.cpu().numpy().flatten()
                    blobs.append(blob)
                val_pred = np.concatenate(blobs)
                AUC = roc_auc_score(label[val_idx], val_pred)
                print('EPOCH {}'.format(epoch))
                print('LOSS: ', loss_f(torch.tensor(val_pred), y_val_t))
                print('AUC: ', AUC)
                scheduler.step(AUC)

                if AUC > best_AUC:
                    early_stop = 0
                    best_AUC = AUC
                    torch.save(nn, output_path + f'best_auc_nn_{gpu_num}.pkl')
                else:
                    early_stop += 1
                    print('SCORE IS NOT THE BEST. Early stop counter: {}'.format(early_stop))

                if early_stop == HYPERPARAMETERS['early_stop_wait']:
                    print(f'EARLY_STOPPING NOW, BEST AUC = {best_AUC}')
                    break
                print('=' * 50)

            best_model = torch.load(output_path + f'best_auc_nn_{gpu_num}.pkl')

        with torch.no_grad():
            best_model.eval()
            blobs = []
            for batch in torch.split(val_tensors, batch_size):
                blob = best_model(batch.to(gpu)).data.cpu().numpy().flatten()
                blobs.append(blob)

            oof[val_idx] = np.concatenate(blobs)

            auc = round(roc_auc_score(Val_label, oof[val_idx]), 5)
            with open(f"log_{gpu_num}.txt", "a") as f:
                f.write(str(fold_) + " " + str(auc) + "\n")

            blobs = []
            for batch in torch.split(test_tensors, batch_size):
                blob = best_model(batch).data.cpu().numpy().flatten()
                blobs.append(blob)
        predictions_test = np.concatenate(blobs)

        predictions += predictions_test / folds.n_splits

    auc = round(roc_auc_score(label, oof), 5)
    print("CV score: {:<8.5f}".format(auc))

    with open(f"log_{gpu_num}.txt", "a") as f:
        f.write("OOF " + str(auc) + "\n")

    np.save(output_path + f"nn_{gpu_num}_{auc}_oof.npy", oof)
    np.save(output_path + f"nn_{gpu_num}_{auc}_test.npy", predictions)
Ejemplo n.º 37
0
class QLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer == "qtran_base":
            self.mixer = QTranBase(args)
        elif args.mixer == "qtran_alt":
            self.mixer = QTranAlt(args)
        else:
            raise ValueError("Unknown mixer type {}".format(args.mixer))

        self.params += list(self.mixer.parameters())
        self.target_mixer = copy.deepcopy(self.mixer)

        self.optimiser = RMSprop(params=self.params,
                                 lr=args.lr,
                                 alpha=args.optim_alpha,
                                 eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        mac_out = []
        mac_hidden_states = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
            mac_hidden_states.append(self.mac.hidden_states)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time
        mac_hidden_states = th.stack(mac_hidden_states, dim=1)
        mac_hidden_states = mac_hidden_states.reshape(batch.batch_size,
                                                      self.args.n_agents,
                                                      batch.max_seq_length,
                                                      -1).transpose(1,
                                                                    2)  #btav

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3,
                                        index=actions).squeeze(
                                            3)  # Remove the last dim

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        target_mac_hidden_states = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)
            target_mac_hidden_states.append(self.target_mac.hidden_states)

        # Keep the Q-Value estimates for all timesteps here (QTRAN uses them when computing its losses)
        target_mac_out = th.stack(target_mac_out[:],
                                  dim=1)  # Concat across time
        target_mac_hidden_states = th.stack(target_mac_hidden_states, dim=1)
        target_mac_hidden_states = target_mac_hidden_states.reshape(
            batch.batch_size, self.args.n_agents, batch.max_seq_length,
            -1).transpose(1, 2)  #btav

        # Mask out unavailable actions
        target_mac_out[avail_actions[:, :] == 0] = -9999999  # From OG deepmarl
        mac_out_maxs = mac_out.clone()
        mac_out_maxs[avail_actions == 0] = -9999999

        # Best joint action computed by target agents
        target_max_actions = target_mac_out.max(dim=3, keepdim=True)[1]
        # Best joint-action computed by regular agents
        max_actions_qvals, max_actions_current = mac_out_maxs[:, :].max(
            dim=3, keepdim=True)

        if self.args.mixer == "qtran_base":
            # -- TD Loss --
            # Joint-action Q-Value estimates
            joint_qs, vs = self.mixer(batch[:, :-1], mac_hidden_states[:, :-1])

            # Need to argmax across the target agents' actions to compute target joint-action Q-Values
            if self.args.double_q:
                max_actions_current_ = th.zeros(
                    size=(batch.batch_size, batch.max_seq_length,
                          self.args.n_agents, self.args.n_actions),
                    device=batch.device)
                max_actions_current_onehot = max_actions_current_.scatter(
                    3, max_actions_current[:, :], 1)
                max_actions_onehot = max_actions_current_onehot
            else:
                max_actions = th.zeros(
                    size=(batch.batch_size, batch.max_seq_length,
                          self.args.n_agents, self.args.n_actions),
                    device=batch.device)
                max_actions_onehot = max_actions.scatter(
                    3, target_max_actions[:, :], 1)
            target_joint_qs, target_vs = self.target_mixer(
                batch[:, 1:],
                hidden_states=target_mac_hidden_states[:, 1:],
                actions=max_actions_onehot[:, 1:])

            # Td loss targets
            td_targets = rewards.reshape(-1, 1) + self.args.gamma * (
                1 - terminated.reshape(-1, 1)) * target_joint_qs
            td_error = (joint_qs - td_targets.detach())
            masked_td_error = td_error * mask.reshape(-1, 1)
            td_loss = (masked_td_error**2).sum() / mask.sum()
            # -- TD Loss --

            # -- Opt Loss --
            # Argmax across the current agents' actions
            if not self.args.double_q:  # Already computed if we're doing double Q-Learning
                max_actions_current_ = th.zeros(
                    size=(batch.batch_size, batch.max_seq_length,
                          self.args.n_agents, self.args.n_actions),
                    device=batch.device)
                max_actions_current_onehot = max_actions_current_.scatter(
                    3, max_actions_current[:, :], 1)
            max_joint_qs, _ = self.mixer(
                batch[:, :-1],
                mac_hidden_states[:, :-1],
                actions=max_actions_current_onehot[:, :-1]
            )  # Don't use the target network and target agent max actions as per author's email

            # max_actions_qvals = th.gather(mac_out[:, :-1], dim=3, index=max_actions_current[:,:-1])
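            # Opt loss drives sum_i max Q_i - Q_joint(greedy, detached) + V(s) towards zero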
            opt_error = max_actions_qvals[:, :-1].sum(dim=2).reshape(
                -1, 1) - max_joint_qs.detach() + vs
            masked_opt_error = opt_error * mask.reshape(-1, 1)
            opt_loss = (masked_opt_error**2).sum() / mask.sum()
            # -- Opt Loss --

            # -- Nopt Loss --
            # target_joint_qs, _ = self.target_mixer(batch[:, :-1])
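            # Nopt loss penalises cases where sum_i Q_i(taken) - Q_joint(taken, detached) + V(s) falls below zero (clamp keeps only violations)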
            nopt_values = chosen_action_qvals.sum(dim=2).reshape(
                -1, 1) - joint_qs.detach(
                ) + vs  # Don't use target networks here either
            nopt_error = nopt_values.clamp(max=0)
            masked_nopt_error = nopt_error * mask.reshape(-1, 1)
            nopt_loss = (masked_nopt_error**2).sum() / mask.sum()
            # -- Nopt loss --

        elif self.args.mixer == "qtran_alt":
            raise Exception("Not supported yet.")
            counter_qs, vs = self.mixer(batch[:, :-1])

            # Need to argmax across the target agents' actions
            # Convert cur_max_actions to one hot
            max_actions = th.zeros(
                size=(batch.batch_size, batch.max_seq_length - 1,
                      self.args.n_agents, self.args.n_actions),
                device=batch.device)
            max_actions_onehot = max_actions.scatter(3, target_max_actions, 1)
            max_actions_onehot_repeat = max_actions_onehot.repeat(
                1, 1, self.args.n_agents, 1)
            agent_mask = (1 - th.eye(self.args.n_agents, device=batch.device))
            agent_mask = agent_mask.view(-1, 1).repeat(
                1, self.args.n_actions)  #.view(self.n_agents, -1)
            masked_actions = max_actions_onehot_repeat * agent_mask.unsqueeze(
                0).unsqueeze(0)
            masked_actions = masked_actions.view(
                -1, self.args.n_agents * self.args.n_actions)
            target_counter_qs, target_vs = self.target_mixer(
                batch[:, 1:], masked_actions)

            # Td loss
            td_target_qs = target_counter_qs.gather(
                1, target_max_actions.view(-1, 1))
            td_chosen_qs = counter_qs.gather(1,
                                             actions.contiguous().view(-1, 1))
            td_targets = rewards.repeat(1, 1, self.args.n_agents).view(
                -1, 1) + self.args.gamma * (1 - terminated.repeat(
                    1, 1, self.args.n_agents).view(-1, 1)) * td_target_qs
            td_error = (td_chosen_qs - td_targets.detach())

            td_mask = mask.repeat(1, 1, self.args.n_agents).view(-1, 1)
            masked_td_error = td_error * td_mask

            td_loss = (masked_td_error**2).sum() / td_mask.sum()

            # Opt loss
            # Computing the targets
            opt_max_actions = th.zeros(
                size=(batch.batch_size, batch.max_seq_length - 1,
                      self.args.n_agents, self.args.n_actions),
                device=batch.device)
            opt_max_actions_onehot = opt_max_actions.scatter(
                3, max_actions_current, 1)
            opt_max_actions_onehot_repeat = opt_max_actions_onehot.repeat(
                1, 1, self.args.n_agents, 1)
            agent_mask = (1 - th.eye(self.args.n_agents, device=batch.device))
            agent_mask = agent_mask.view(-1, 1).repeat(1, self.args.n_actions)
            opt_masked_actions = opt_max_actions_onehot_repeat * agent_mask.unsqueeze(
                0).unsqueeze(0)
            opt_masked_actions = opt_masked_actions.view(
                -1, self.args.n_agents * self.args.n_actions)

            opt_target_qs, opt_vs = self.mixer(batch[:, :-1],
                                               opt_masked_actions)

            opt_error = max_actions_qvals.squeeze(3).sum(
                dim=2, keepdim=True).repeat(1, 1, self.args.n_agents).view(
                    -1, 1) - opt_target_qs.gather(
                        1, max_actions_current.view(-1, 1)).detach() + opt_vs
            opt_loss = ((opt_error * td_mask)**2).sum() / td_mask.sum()

            # NOpt loss
            qsums = chosen_action_qvals.clone().unsqueeze(2).repeat(
                1, 1, self.args.n_agents, 1).view(-1, self.args.n_agents)
            ids_to_zero = th.tensor([i for i in range(self.args.n_agents)],
                                    device=batch.device).repeat(
                                        batch.batch_size *
                                        (batch.max_seq_length - 1))
            qsums.scatter(1, ids_to_zero.unsqueeze(1), 0)
            nopt_error = mac_out[:, :-1].contiguous().view(
                -1, self.args.n_actions) + qsums.sum(
                    dim=1, keepdim=True) - counter_qs.detach() + opt_vs
            min_nopt_error = th.min(nopt_error, dim=1, keepdim=True)[0]
            nopt_loss = ((min_nopt_error * td_mask)**2).sum() / td_mask.sum()

        loss = td_loss + self.args.opt_loss * opt_loss + self.args.nopt_min_loss * nopt_loss

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params,
                                                self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode
            ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("td_loss", td_loss.item(), t_env)
            self.logger.log_stat("opt_loss", opt_loss.item(), t_env)
            self.logger.log_stat("nopt_loss", nopt_loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            if self.args.mixer == "qtran_base":
                mask_elems = mask.sum().item()
                self.logger.log_stat(
                    "td_error_abs",
                    (masked_td_error.abs().sum().item() / mask_elems), t_env)
                self.logger.log_stat(
                    "td_targets",
                    ((masked_td_error).sum().item() / mask_elems), t_env)
                self.logger.log_stat("td_chosen_qs",
                                     (joint_qs.sum().item() / mask_elems),
                                     t_env)
                self.logger.log_stat("v_mean", (vs.sum().item() / mask_elems),
                                     t_env)
                self.logger.log_stat(
                    "agent_indiv_qs",
                    ((chosen_action_qvals * mask).sum().item() /
                     (mask_elems * self.args.n_agents)), t_env)
            elif self.args.mixer == "qtran_alt":
                mask_elems = mask.sum().item()
                mask_td_elems = td_mask.sum().item()
                self.logger.log_stat(
                    "td_error_abs",
                    (masked_td_error.abs().sum().item() / mask_td_elems),
                    t_env)
                self.logger.log_stat("q_taken_mean",
                                     (td_chosen_qs * td_mask).sum().item() /
                                     (mask_td_elems), t_env)
                self.logger.log_stat("target_mean",
                                     (td_targets * td_mask).sum().item() /
                                     mask_td_elems, t_env)
                self.logger.log_stat(
                    "agent_qs_mean",
                    (chosen_action_qvals * mask).sum().item() /
                    (mask_elems * self.args.n_agents), t_env)
                self.logger.log_stat("v_mean", (vs * td_mask).sum().item() /
                                     (mask_td_elems), t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                th.load("{}/mixer.th".format(path),
                        map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            th.load("{}/opt.th".format(path),
                    map_location=lambda storage, loc: storage))
Ejemplo n.º 38
0
class NEC_Agent:

    def __init__(self, args, exp_model, logging_func):
        self.args = args

        # Exploration Model
        self.exp_model = exp_model

        self.log = logging_func["log"]

        # Experience Replay
        self.replay = ExpReplay(args.exp_replay_size, args)
        self.dnds = [DND(kernel=kernel, num_neighbors=args.nec_neighbours, max_memory=args.dnd_size, embedding_size=args.nec_embedding) for _ in range(self.args.actions)]

        # DQN and Target DQN
        model = get_models(args.model)
        self.embedding = model(embedding=args.nec_embedding)

        embedding_params = 0
        for weight in self.embedding.parameters():
            weight_params = 1
            for s in weight.size():
                weight_params *= s
            embedding_params += weight_params
        print("Embedding Network has {:,} parameters.".format(embedding_params))

        if args.gpu:
            print("Moving models to GPU.")
            self.embedding.cuda()

        # Optimizer
        self.optimizer = RMSprop(self.embedding.parameters(), lr=args.lr)
        # self.optimizer = Adam(self.embedding.parameters(), lr=args.lr)

        self.T = 0
        self.target_sync_T = -self.args.t_max

        self.experiences = []
        self.keys = []
        self.q_val_estimates = []

        self.table_updates = 0

    def Q_Value_Estimates(self, state):
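        # Embed the observation and query each action's DND for a kernel-weighted Q-value estimate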
        # Get state embedding
        state = torch.from_numpy(state).float().transpose_(0, 2).unsqueeze(0)
        key = self.embedding(Variable(state, volatile=True)).cpu()

        if (key != key).sum().data[0] > 0:
            pass
            # print(key)
            # for param in self.embedding.parameters():
                # print(param)
            # print(key != key)
            # print((key != key).sum().data[0])
            # print("Nan key")

        estimate_from_dnds = torch.cat([dnd.lookup(key) for dnd in self.dnds])
        # print(estimate_from_dnds)

        self.keys.append(key.data[0].numpy())
        self.q_val_estimates.append(estimate_from_dnds.data.numpy())

        return estimate_from_dnds, key
        # return np.array(estimate_from_dnds), key

    def act(self, state, epsilon, exp_model):

        q_values, key = self.Q_Value_Estimates(state)
        q_values_numpy = q_values.data.numpy()

        extra_info = {}
        extra_info["Q_Values"] = q_values_numpy

        if np.random.random() < epsilon:
            action = np.random.randint(low=0, high=self.args.actions)
        else:
            action = np.argmax(q_values_numpy)

        extra_info["Action"] = action

        return action, extra_info

    def experience(self, state, action, reward, state_next, steps, terminated, pseudo_reward=0, density=1, exploring=False):
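        # Buffer the transition; once n_step transitions are queued, the oldest
        # is flushed to the DND and the replay with a full N-step return.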

        experience = (state, action, reward, pseudo_reward, state_next, terminated)
        self.experiences.append(experience)

        if len(self.experiences) >= self.args.n_step:
            self.add_experience()

        if not exploring:
            self.T += 1

    def end_of_trajectory(self):
        self.replay.end_of_trajectory()

        # Flush the remaining buffered experiences into the replay, using
        # shorter-than-N-step Q-value estimates for the tail of the trajectory
        while len(self.experiences) > 0:
            self.add_experience()

    def add_experience(self):

        # Trim the cached keys and Q-value estimates to match the number of buffered experiences
        N = len(self.experiences)
        self.keys = self.keys[-N:]
        self.q_val_estimates = self.q_val_estimates[-N:]

        first_state = self.experiences[0][0]
        first_action = self.experiences[0][1]
        last_state = self.experiences[-1][4]
        terminated_last_state = self.experiences[-1][5]
        accum_reward = 0
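        # Fold backwards over the buffered window to get the discounted return
        # R = sum_i gamma^i * (r_i + pseudo_r_i)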
        for ex in reversed(self.experiences):
            r = ex[2]
            pr = ex[3]
            accum_reward = (r + pr) + self.args.gamma * accum_reward
        if terminated_last_state:
            last_state_max_q_val = 0
        else:
            # Reuse the Q-value estimate cached when Q_Value_Estimates was last
            # called, rather than re-embedding the last state here.
            last_state_max_q_val = np.max(self.q_val_estimates[-1])

        # The key for the first buffered state was cached when it was embedded
        # during acting.
        first_state_key = self.keys[0]

        # Bootstrapped N-step target: R + gamma^N * max_a Q(s_N, a)
        n_step_q_val_estimate = accum_reward + (self.args.gamma ** len(self.experiences)) * last_state_max_q_val

        # Add to dnd
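        # NEC-style tabular update: if the key is already present, move its stored
        # value towards the N-step target with step size nec_alpha; otherwise
        # insert the new (key, target) pair.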
        if self.dnds[first_action].is_present(key=first_state_key):
            current_q_val = self.dnds[first_action].get_value(key=first_state_key)
            new_q_val = current_q_val + self.args.nec_alpha * (n_step_q_val_estimate - current_q_val)
            self.dnds[first_action].upsert(key=first_state_key, value=new_q_val)
            self.table_updates += 1
            self.log("NEC/Table_Updates", self.table_updates, step=self.T)
        else:
            self.dnds[first_action].upsert(key=first_state_key, value=n_step_q_val_estimate)

        # Add to replay
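        # The replay stores (state, action, N-step target) tuples so the embedding
        # can later be trained against these targets through the DND lookup.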
        self.replay.Add_Exp(first_state, first_action, n_step_q_val_estimate)

        # Remove first experience
        self.experiences = self.experiences[1:]

    def train(self):
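        # Every nec_update steps, regress the DND read-outs for the stored
        # (state, action) pairs towards their N-step targets; gradients flow
        # through the lookup into the embedding network.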

        info = {}
        if self.T % self.args.nec_update != 0:
            return info

        # print("Training")

        for _ in range(self.args.iters):

            # TODO: Use a named tuple for experience replay
            batch = self.replay.Sample(self.args.batch_size)
            columns = list(zip(*batch))
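            # columns: (states, actions, N-step Q targets), matching Add_Exp's arguments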

            states = Variable(torch.from_numpy(np.array(columns[0])).float().transpose_(1, 3))
            actions = columns[1]
            targets = Variable(torch.FloatTensor(columns[2]))
            keys = self.embedding(states).cpu()
            # print("Keys", keys.requires_grad)
            # for action in actions:
                # print(action)
            # for action, key in zip(actions, keys):
                # print(action, key)
                # kk = key.unsqueeze(0)
                # print("kk", kk.requires_grad)
                # k = self.dnds[action].lookup(key.unsqueeze(0))
                # print("key", key.requires_grad, key.volatile)
            model_predictions = torch.cat([self.dnds[action].lookup(key.unsqueeze(0)) for action, key in zip(actions, keys)])

            td_error = model_predictions - targets
            info["TD_Error"] = td_error.mean().data[0]

            l2_loss = (td_error).pow(2).mean()
            info["Loss"] = l2_loss.data[0]

            # Update
            self.optimizer.zero_grad()

            l2_loss.backward()

            # Taken from PyTorch's clip_grad_norm
            # Remove once the pip version is up to date with source
            gradient_norm = clip_grad_norm(self.embedding.parameters(), self.args.clip_value)
            if gradient_norm is not None:
                info["Norm"] = gradient_norm

            self.optimizer.step()

            if "States" in info:
                states_trained = info["States"]
                info["States"] = states_trained + columns[0]
            else:
                info["States"] = columns[0]

        return info