Example #1
    def __init__(self,
                 config: BasicConfigSAC,
                 environment: Env,
                 log_path: str = None,
                 logging: bool = True):
        """
        TODO: Write docstring
        """
        log_path = generate_experiment_signature(
            environment) if log_path is None else log_path
        self.config = config
        self.monitor = Monitor(log_path, config, logging=logging)
        self.env = environment

        self.batch_size = config.learner.batch_size
        self.episode_horizon = config.episode_horizon
        self.steps_before_learn = config.steps_before_learn

        self.memory_buffer = MemoryBuffer(max_memory_size=config.memory_size)

        self.agent = AgentSAC(environment, config.policy)

        self.learner = LearnerSAC(config=config.learner,
                                  agent=self.agent,
                                  enviroment=self.env,
                                  monitor=self.monitor)
Example #2
    def __init__(self, action_dim, state_dim, params):
        """ Initialization
        """
        # session = K.get_session()
        # Environment and DDQN parameters
        self.with_per = params["with_per"]
        self.action_dim = action_dim
        self.state_dim = state_dim

        self.lr = 2.5e-4
        self.gamma = 0.95
        self.epsilon = params["epsilon"]
        self.epsilon_decay = params["epsilon_decay"]
        self.epsilon_minimum = 0.05
        self.buffer_size = 10000
        self.tau = 1.0
        self.agent = Agent(self.state_dim, action_dim, self.lr, self.tau,
                           params["dueling"])
        # Memory Buffer for Experience Replay
        self.buffer = MemoryBuffer(self.buffer_size, self.with_per)

        exp_dir = 'test/models/'
        if not os.path.exists(exp_dir):
            os.makedirs(exp_dir)
        self.export_path = os.path.join(exp_dir, 'lala.h5')
        self.save_interval = params["save_interval"]
    def __init__(self, memory, start_address):
        self.memory = memory
        self.start_address = start_address
        self.register_file = RegisterFile()
        self.data_memory_key_fn = lambda: -777
        self.data_memory = defaultdict(self.data_memory_key_fn)

        self.cycle_count = 0
        self.instr_count = 0
        self.PC = 0

        self.fetch_input_buffer = FetchInputBuffer({
            'PC': self.start_address,
            'instr_count': self.instr_count,
        })
        self.fetcher_buffer = FetcherBuffer()
        self.fetch_stage = FetchStage(self.memory, self.fetch_input_buffer,
                                      self.fetcher_buffer)

        self.decoder_buffer = DecoderBuffer()
        self.decode_stage = DecodeStage(self.fetcher_buffer,
                                        self.decoder_buffer,
                                        self.register_file)

        self.executer_buffer = ExecuterBuffer()
        self.execute_stage = ExecuteStage(self.decoder_buffer,
                                          self.executer_buffer)
        self.memory_buffer = MemoryBuffer()
        self.memory_stage = MemoryStage(self.executer_buffer,
                                        self.memory_buffer, self.data_memory)
        self.write_back_stage = WriteBackStage(self.memory_buffer,
                                               self.register_file)
    def __init__(self,
                 load_policy=False,
                 learning_rate=0.001,
                 dim_a=3,
                 fc_layers_neurons=100,
                 loss_function_type='mean_squared',
                 policy_loc='./racing_car_m2/network',
                 image_size=64,
                 action_upper_limits='1,1',
                 action_lower_limits='-1,-1',
                 e='1',
                 show_ae_output=True,
                 show_state=True,
                 resize_observation=True,
                 ae_training_threshold=0.0011,
                 ae_evaluation_frequency=40):

        self.image_size = image_size

        super(Agent, self).__init__(dim_a=dim_a,
                                    policy_loc=policy_loc,
                                    action_upper_limits=action_upper_limits,
                                    action_lower_limits=action_lower_limits,
                                    e=e,
                                    load_policy=load_policy,
                                    loss_function_type=loss_function_type,
                                    learning_rate=learning_rate,
                                    fc_layers_neurons=fc_layers_neurons)

        # High-dimensional state initialization
        self.resize_observation = resize_observation
        self.show_state = show_state
        self.show_ae_output = show_ae_output

        # Autoencoder training control variables
        self.ae_training = True
        self.ae_loss_history = MemoryBuffer(
            min_size=50,
            max_size=50)  # reuse memory buffer for the ae loss history
        self.ae_training_threshold = ae_training_threshold
        self.ae_evaluation_frequency = ae_evaluation_frequency
        self.mean_ae_loss = 1e7

        if self.show_state:
            self.state_plot = FastImagePlot(1,
                                            np.zeros([image_size, image_size]),
                                            image_size,
                                            'Image State',
                                            vmax=0.5)

        if self.show_ae_output:
            self.ae_output_plot = FastImagePlot(2,
                                                np.zeros(
                                                    [image_size, image_size]),
                                                image_size,
                                                'Autoencoder Output',
                                                vmax=0.5)
Example #5
    def test_memory_buffer_size(self):
        info_set_size = 1 + 2 + 5 + 24
        item_size = 64
        max_size = int(1e6)
        mb = MemoryBuffer(info_set_size, item_size, max_size=max_size)
        print(mb._infosets.dtype)
        print(mb._items.dtype)
        print(mb._weights.dtype)
        print("Memory buffer size (max_size={}): {} mb".format(
            max_size, mb.size_mb()))
Example #6
    def __init__(self,
                 n_state,
                 n_action,
                 a_bound,
                 discount=0.99,
                 tau=0.05,
                 actor_lr=0.001,
                 critic_lr=0.001,
                 policy_freq=2,
                 exp_noise_std=0.1,
                 noise_decay=0.9995,
                 noise_decay_steps=1000,
                 smooth_noise_std=0.1,
                 clip=0.2,
                 buffer_size=20000,
                 save_interval=5000,
                 assess_interval=20,
                 logger=None,
                 checkpoint_queen=None):
        #self.__dict__.update(locals())
        self.logger = logger
        self.logger.save_config(locals())
        self.n_action = n_action
        self.n_state = n_state
        self.a_bound = a_bound
        self.noise_std = exp_noise_std
        self.noise_decay = noise_decay
        self.noise_decay_steps = noise_decay_steps
        self.policy_freq = policy_freq
        self.smooth_noise_std = smooth_noise_std
        self.clip = clip
        self.discount = discount

        self.pointer = 0
        self.buffer = MemoryBuffer(buffer_size, with_per=True)
        self.save_interval = save_interval
        self.assess_interval = assess_interval
        self.actor = Actor(self.n_state,
                           self.n_action,
                           gamma=discount,
                           lr=actor_lr,
                           tau=tau)
        self.critic1 = Critic(self.n_state,
                              self.n_action,
                              gamma=discount,
                              lr=critic_lr,
                              tau=tau)
        self.critic2 = Critic(self.n_state,
                              self.n_action,
                              gamma=discount,
                              lr=critic_lr,
                              tau=tau)
        self.merge = self._merge_summary()
        self.ckpt_queen = checkpoint_queen
        self.prefix = self.__class__.__name__
    def __init__(self,
                 state_dim,
                 action_dim,
                 batchSize=64,
                 lr=.0001,
                 tau=.05,
                 gamma=.95,
                 epsilon=1,
                 eps_dec=.99,
                 learnInterval=1,
                 isDual=False,
                 isDueling=False,
                 isPER=False,
                 filename='model',
                 mem_size=1000000,
                 layerCount=2,
                 layerUnits=64,
                 usePruning=False):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.isDueling = isDueling
        self.isDual = isDual
        self.isPER = isPER
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = eps_dec
        self.batchSize = batchSize
        self.filename = filename
        self.learnInterval = learnInterval
        # Initialize Deep Q-Network
        self.model = generateDQN(action_dim, lr, state_dim, isDueling,
                                 layerCount, layerUnits, usePruning)
        # Build target Q-Network
        self.target_model = generateDQN(action_dim, lr, state_dim, isDueling,
                                        layerCount, layerUnits, usePruning)
        self.layerCount = layerCount
        self.layerUnits = layerUnits
        self.target_model.set_weights(self.model.get_weights())
        self.memory = MemoryBuffer(mem_size, isPER)
        self.epsilonInitial = epsilon
        self.minEpsilon = .1
        self.usePruning = usePruning

        if isDual:
            self.tau = tau
        else:
            self.tau = 1.0

        # load memory data from disk if needed
        self.lastLearnIndex = self.memory.totalMemCount
Example #8
    def __init__(self, in_features, n_hidden, out_features):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.f_internal = torch.relu
        self.f_output = lambda x: x

        self.layers = []
        for i in range(len(n_hidden) + 1):
            if i == 0:
                inf = in_features
            else:
                inf = n_hidden[i - 1]
            if i == len(n_hidden):
                outf = out_features
            else:
                outf = n_hidden[i]

            self.layers.append(torch.nn.Linear(inf, outf))
            self.add_module(f'layer{i}', self.layers[i])

        self.memory_buffer = MemoryBuffer()
        self.use_memory_buffer = False
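    # The constructor above only builds the layer stack; a plausible forward
    # pass (a sketch, not taken from the source) would apply f_internal after
    # every hidden layer and f_output after the final one:
    #
    #     def forward(self, x):
    #         for layer in self.layers[:-1]:
    #             x = self.f_internal(layer(x))
    #         return self.f_output(self.layers[-1](x))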
    def __init__ (self, memory, start_address):
        self.memory = memory
        self.start_address = start_address
        self.register_file = RegisterFile ()
        self.data_memory_key_fn = lambda: -777
        self.data_memory = defaultdict (self.data_memory_key_fn)

        self.cycle_count = 0
        self.instr_count = 0
        self.PC = 0

        self.fetch_input_buffer = FetchInputBuffer({
            'PC': self.start_address,
            'instr_count': self.instr_count,
            })
        self.fetcher_buffer = FetcherBuffer()
        self.fetch_stage = FetchStage(self.memory, 
                                      self.fetch_input_buffer, 
                                      self.fetcher_buffer)
        
        self.decoder_buffer = DecoderBuffer()
        self.decode_stage = DecodeStage(self.fetcher_buffer,
                                        self.decoder_buffer,
                                        self.register_file)
        
        self.executer_buffer = ExecuterBuffer()
        self.execute_stage = ExecuteStage(self.decoder_buffer,
                                          self.executer_buffer)
        self.memory_buffer = MemoryBuffer()
        self.memory_stage = MemoryStage(self.executer_buffer,
                                        self.memory_buffer,
                                        self.data_memory)
        self.write_back_stage = WriteBackStage(
            self.memory_buffer,
            self.register_file)
    def test_write_back_R(self):
        self.set_up_write_back_stage('R ADD  R1 R2 R3')
        expected_reg_value = self.memory_buffer.rd[1]
        self.write_back_stage.write_back()
        self.assertEqual(self.write_back_stage.memory_buffer, MemoryBuffer())
        self.assertEqual(self.write_back_stage.register_file[self.instr.rd],
                         expected_reg_value)
        self.assertTrue(self.register_file.isClean(self.instr.rd))
    def test_do_operand_forwarding_MEM(self):
        self.processor.executer_buffer = ExecuterBuffer({'rt': [1, None]})
        self.processor.memory_buffer = MemoryBuffer({'rt': [1, 3]})
        self.processor.do_operand_forwarding()
        self.assertEqual(self.processor.executer_buffer.rt, [1, 3])

        self.processor.executer_buffer = ExecuterBuffer({'rt': [2, None]})
        self.processor.do_operand_forwarding()
        self.assertEqual(self.processor.executer_buffer.rt, [2, None])
    def test_do_operand_forwarding(self):
        self.processor.decoder_buffer = DecoderBuffer({'rs': [2, None]})
        self.processor.executer_buffer = ExecuterBuffer({'rt': [2, 7]})
        self.processor.do_operand_forwarding()
        self.assertEqual(self.processor.decoder_buffer.rs, [2, 7])

        self.processor.decoder_buffer = DecoderBuffer({'rs': [2, None]})
        self.processor.executer_buffer = ExecuterBuffer()
        self.processor.memory_buffer = MemoryBuffer({'rd': [2, 9]})
        self.processor.do_operand_forwarding()
        self.assertEqual(self.processor.decoder_buffer.rs, [2, 9])
Example #13
    @staticmethod
    def get_stage_output(memory, register_file, pc, instr_count, stage_name):
        """Return the output buffer of `stage_name` given the initial conditions.

        All the stages before stage_name will be executed.

        Arguments:
        - `memory`:
        - `register_file`:
        - `pc`:
        - `instr_count`:
        - `stage_name`:

        TODO: Maybe just take the stages as input later.
        """
        fetch_input_buffer = FetchInputBuffer({
            'PC': pc,
            'instr_count': instr_count,
        })
        fetcher_buffer = FetcherBuffer()
        fetch_stage = FetchStage(memory, fetch_input_buffer, fetcher_buffer)
        fetch_stage.fetch_instruction()

        if stage_name == 'fetch':
            return fetch_stage.fetcher_buffer

        decode_stage = DecodeStage(fetch_stage.fetcher_buffer, DecoderBuffer(),
                                   register_file)
        decode_stage.decode_instruction()

        if stage_name == 'decode':
            return decode_stage.decoder_buffer

        execute_stage = ExecuteStage(decode_stage.decoder_buffer,
                                     ExecuterBuffer())
        execute_stage.execute()
        if stage_name == 'execute':
            return execute_stage.executer_buffer

        data_memory_key_fn = lambda: -1
        data_memory = defaultdict(data_memory_key_fn)

        memory_stage = MemoryStage(execute_stage.executer_buffer,
                                   MemoryBuffer(), data_memory)
        memory_stage.do_memory_operation()

        if stage_name == 'memory':
            return memory_stage.memory_buffer
Example #14
    def test_memory_buffer_autosave(self):
        print("\n ================= AUTOSAVE TEST ====================")
        # Make sure the folder doesn't exist so the manifest has to be created.
        if os.path.exists("./memory/memory_buffer_test/"):
            shutil.rmtree("./memory/memory_buffer_test/")
        info_set_size = 1 + 1 + 24
        item_size = 64
        max_size = int(1e3)

        # Add autosave params.
        mb = MemoryBuffer(info_set_size,
                          item_size,
                          max_size=max_size,
                          autosave_params=("./memory/memory_buffer_test/",
                                           "test_buffer"))

        for _ in range(max_size):
            mb.add(make_dummy_ev_infoset(), torch.zeros(item_size), 1234)
        self.assertTrue(mb.full())

        # This should trigger the save and reset.
        mb.add(make_dummy_ev_infoset(), torch.zeros(item_size), 1234)
Example #15
    def test_memory_buffer_save(self):
        # Make sure the folder doesn't exist so the manifest has to be created.
        if os.path.exists("./memory/memory_buffer_test/"):
            shutil.rmtree("./memory/memory_buffer_test/")
        info_set_size = 1 + 2 + 5 + 24
        item_size = 64
        max_size = int(1e6)
        mb = MemoryBuffer(info_set_size, item_size, max_size=max_size)
        mb.save("./memory/memory_buffer_test/", "test_buffer")

        self.assertTrue(
            os.path.exists(
                "./memory/memory_buffer_test/manifest_test_buffer.csv"))
        self.assertTrue(
            os.path.exists(
                "./memory/memory_buffer_test/test_buffer_00000.pth"))

        # Now save again.
        mb.save("./memory/memory_buffer_test/", "test_buffer")
        self.assertTrue(
            os.path.exists(
                "./memory/memory_buffer_test/test_buffer_00001.pth"))
Example #16
class TD3(object):
    """deep deterministic policy gradient
    """
    def __init__(self,
                 n_state,
                 n_action,
                 a_bound,
                 discount=0.99,
                 tau=0.05,
                 actor_lr=0.001,
                 critic_lr=0.001,
                 policy_freq=2,
                 exp_noise_std=0.1,
                 noise_decay=0.9995,
                 noise_decay_steps=1000,
                 smooth_noise_std=0.1,
                 clip=0.2,
                 buffer_size=20000,
                 save_interval=5000,
                 assess_interval=20,
                 logger=None,
                 checkpoint_queen=None):
        #self.__dict__.update(locals())
        self.logger = logger
        self.logger.save_config(locals())
        self.n_action = n_action
        self.n_state = n_state
        self.a_bound = a_bound
        self.noise_std = exp_noise_std
        self.noise_decay = noise_decay
        self.noise_decay_steps = noise_decay_steps
        self.policy_freq = policy_freq
        self.smooth_noise_std = smooth_noise_std
        self.clip = clip
        self.discount = discount

        self.pointer = 0
        self.buffer = MemoryBuffer(buffer_size, with_per=True)
        self.save_interval = save_interval
        self.assess_interval = assess_interval
        self.actor = Actor(self.n_state,
                           self.n_action,
                           gamma=discount,
                           lr=actor_lr,
                           tau=tau)
        self.critic1 = Critic(self.n_state,
                              self.n_action,
                              gamma=discount,
                              lr=critic_lr,
                              tau=tau)
        self.critic2 = Critic(self.n_state,
                              self.n_action,
                              gamma=discount,
                              lr=critic_lr,
                              tau=tau)
        self.merge = self._merge_summary()
        self.ckpt_queen = checkpoint_queen
        self.prefix = self.__class__.__name__

    def _merge_summary(self):
        tf.summary.histogram('critic_output', self.critic1.model.output)
        tf.summary.histogram('actor_output', self.actor.model.output)
        tf.summary.histogram('critic_dense1',
                             self.critic1.model.get_layer('l1').weights[0])
        tf.summary.histogram('actor_dense1',
                             self.actor.model.get_layer('l1').weights[0])
        tf.summary.histogram('critic_dense2',
                             self.critic1.model.get_layer('l2').weights[0])
        tf.summary.histogram('actor_dense2',
                             self.actor.model.get_layer('l2').weights[0])
        return tf.summary.merge_all()

    def select_action(self, state):
        return self.actor.predict(state)

    def bellman_q_value(self, rewards, q_nexts, dones):
        """ Use the Bellman Equation to compute the critic target
        """
        q_target = np.zeros_like(rewards)  # np.asarray(copy=False) vs. np.array(copy=True)
        for i in range(rewards.shape[0]):
            if dones[i]:
                q_target[i] = rewards[i]
            else:
                q_target[i] = rewards[i] + self.discount * q_nexts[i]
        return q_target
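    # Note: assuming `rewards`, `q_nexts` and `dones` are equal-length NumPy
    # arrays, the loop above is equivalent to the vectorized one-liner
    #     q_target = rewards + self.discount * q_nexts * (1 - np.asarray(dones))
    # (shown only as a sketch; the element-wise loop is kept as in the source).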

    def memorize(self, state, action, reward, done, new_state):
        """ Store experience in memory buffer
        """
        if (self.buffer.with_per):
            q_val = reward
            q_val_t = self.critic1.target_predict(state, action)
            td_error = abs(q_val_t - q_val)[0]
            # print(td_error)
        else:
            td_error = 0
        state = state.reshape(-1)
        action = action.reshape(-1)
        self.buffer.memorize(state, action, reward, done, new_state, td_error)

    def sample_batch(self, batch_size):
        return self.buffer.sample_batch(batch_size)

    def update_actor(self, states):
        actions = self.actor.predict(states)
        grad_ys = self.critic1.gradients(states, actions)
        actor_output = self.actor.train(states, actions, grad_ys)
        self.actor.copy_weights()
        self.critic1.copy_weights()
        self.critic2.copy_weights()
        return grad_ys, actor_output
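    # In update_actor above, the critic's action-gradients (grad_ys) serve as the
    # upstream gradients for the actor update (the deterministic policy gradient
    # pattern), after which the target networks are soft-updated via copy_weights().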

    def update_critic(self, states, actions, q_values):
        loss_names, loss_values = self.critic1.train_on_batch(
            states, actions, q_values)
        self.critic2.train_on_batch(states, actions, q_values)
        return loss_names, loss_values

    def save_weights(self, path):
        self.actor.save(path)
        self.critic1.save(path)
        self.critic2.save(path)

    def save_model(self, path, file):
        self.actor.model.save(
            os.path.join(path, self.prefix + '_actor_' + file + '.h5'))
        self.critic1.model.save(
            os.path.join(path, self.prefix + '_critic1_' + file + '.h5'))
        self.critic2.model.save(
            os.path.join(path, self.prefix + '_critic2_' + file + '.h5'))

    def checkpoint(self, path, step, metric_value):
        signature = str(step) + '_' + '{:.4}'.format(metric_value)
        to_delete, need_save = self.ckpt_queen.add((metric_value, signature))
        if to_delete:
            delete_actor = os.path.join(
                path, self.prefix + '_actor_' + to_delete[1] + '.h5')
            delete_critic1 = os.path.join(
                path, self.prefix + '_critic1_' + to_delete[1] + '.h5')
            delete_critic2 = os.path.join(
                path, self.prefix + '_critic2_' + to_delete[1] + '.h5')
            os.remove(delete_actor)
            os.remove(delete_critic1)
            os.remove(delete_critic2)
        if need_save:
            self.save_model(path, signature)

    def train(self,
              args,
              summary_writer,
              train_data=None,
              val_data=None,
              test_data=None):
        results = []
        max_val_rate = 0
        val_data = np.asarray(val_data)  # none will be array(None)
        # First, gather experience
        tqdm_e = tqdm(range(args.batchs),
                      desc='score',
                      leave=True,
                      unit="epoch")
        if train_data is None:
            dataset = CsvBuffer(args.file_dir,
                                args.reg_pattern,
                                chunksize=args.batch_size)  # 100*(20+1)
            assert dataset.is_buffer_available, 'neither train_data nor csv buffer is available'
        # noise = OrnsteinUhlenbeckProcess(size=self.n_action)
        else:
            dataset = Dataset(train_data, 1, shuffle=True)

        warm_up = 20 * args.batch_size
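        # The first `warm_up` iterations only collect experience into the replay
        # buffer (see the `continue` below); gradient updates start afterwards.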
        for e in tqdm_e:
            batch_data = next(dataset)
            states, labels = batch_data[:, :-1], batch_data[:, -1].astype(int)

            a = self.select_action(states)  #(batch, n_action)
            a = np.clip(a + np.random.normal(0, self.noise_std, size=a.shape),
                        self.a_bound[0], self.a_bound[1])
            llr = np.clip(np.log(a / (1 - a) + 1e-6), -5, 5)
            # rewards = np.where(labels==1, llr.ravel(), -llr.ravel())  #(batch,)
            rewards = np.where(labels == 1,
                               np.where(llr > 0, llr.ravel(), 2 * llr.ravel()),
                               np.where(llr < 0, -llr.ravel(),
                                        -2 * llr.ravel()))  #(batch,)
            # print(rewards)

            # a_ = self.actor.target_predict(next_states)
            # noise = np.clip(np.random.normal(0, self.smooth_noise_std), 0, self.clip)
            # a_ = a_ + noise
            # q_next1 = self.critic1.target_predict(new_states, a_)
            # q_next2 = self.critic2.target_predict(new_states,a_)
            # q_nexts = np.where(q_next1<q_next2, q_next1, q_next2)
            self.memorize(states, a, rewards, True, None)
            if e < warm_up:
                continue

            states, a, rewards, _, _, _ = self.sample_batch(args.batch_size)
            # print(states.shape, a.shape, rewards.shape)

            q_ = self.bellman_q_value(rewards=rewards,
                                      q_nexts=0,
                                      dones=[True] *
                                      rewards.shape[0])  #(batch,)

            loss_names, loss_values = self.update_critic(
                states, a, q_.reshape(-1, 1))

            if e % self.policy_freq == 0 or e == warm_up:
                grad_ys, actor_output = self.update_actor(states)

            if ((e + 1) % self.noise_decay_steps - 1) == 0 or e == warm_up:
                self.noise_std *= self.noise_decay
                self.logger.log_tabular('noise', self.noise_std)
            if e % self.assess_interval == 0 or e == args.batchs - 1 or e == warm_up:
                if val_data is not None:
                    val_pred = self.actor.predict(val_data[:, :-1])
                    val_y = val_data[:, -1]
                    # print(val_pred.shape,val_pred[:10])
                    # print(val_y.shape, val_y[:10])
                    val_rate, top_k = top_ratio_hit_rate(
                        val_y.ravel(), val_pred.ravel())
                    self.logger.log_tabular('val_rate', val_rate)
                    self.logger.log_tabular('val_k', int(top_k))
                    self.checkpoint(args.model_path, e, val_rate)
                    max_val_rate = val_rate if val_rate > max_val_rate else max_val_rate
                if test_data is not None:
                    test_pred = self.actor.predict(test_data[:, :-1])
                    test_y = test_data[:, -1]
                    test_rate, top_k = top_ratio_hit_rate(
                        test_y, test_pred.ravel())
                    self.logger.log_tabular('test_rate', test_rate)
                    self.logger.log_tabular('test_k', int(top_k))

            score = rewards.mean()
            summary_writer.add_summary(tf_summary(['mean-reward'], [score]),
                                       global_step=e)
            summary_writer.add_summary(tf_summary(loss_names, [loss_values]),
                                       global_step=e)
            merge = keras.backend.get_session().run(
                self.merge,
                feed_dict={
                    self.critic1.model.input[0]: states,
                    self.critic1.model.input[1]: a,
                    self.actor.model.input: states
                })
            summary_writer.add_summary(merge, global_step=e)

            for name, val in zip(loss_names, [loss_values]):
                self.logger.log_tabular(name, val)

            self.logger.log_tabular(
                'dQ/da', '%.4f+%.4f' %
                (grad_ys.mean(), grad_ys.std()))  # grad_ys (batch,act_dim)
            self.logger.log_tabular(
                'aout',
                '%.4f+%.4f' % (actor_output[0].mean(), actor_output[0].std()))
            self.logger.log_tabular('aloss', '%.4f' % (actor_output[1]))
            self.logger.log_tabular('reward',
                                    '%.4f+%.4f' % (score, rewards.std()))
            self.logger.dump_tabular()
            tqdm_e.set_description("score: " + '{:.4f}'.format(score))
            tqdm_e.set_postfix(noise_std='{:.4}'.format(self.noise_std),
                               max_val_rate='{:.4}'.format(max_val_rate),
                               val_rate='{:.4}'.format(val_rate),
                               top_k=top_k)
            tqdm_e.refresh()

        return results
class Agent(AgentBase):
    def __init__(self,
                 load_policy=False,
                 learning_rate=0.001,
                 dim_a=3,
                 fc_layers_neurons=100,
                 loss_function_type='mean_squared',
                 policy_loc='./racing_car_m2/network',
                 image_size=64,
                 action_upper_limits='1,1',
                 action_lower_limits='-1,-1',
                 e='1',
                 show_ae_output=True,
                 show_state=True,
                 resize_observation=True,
                 ae_training_threshold=0.0011,
                 ae_evaluation_frequency=40):

        self.image_size = image_size

        super(Agent, self).__init__(dim_a=dim_a,
                                    policy_loc=policy_loc,
                                    action_upper_limits=action_upper_limits,
                                    action_lower_limits=action_lower_limits,
                                    e=e,
                                    load_policy=load_policy,
                                    loss_function_type=loss_function_type,
                                    learning_rate=learning_rate,
                                    fc_layers_neurons=fc_layers_neurons)

        # High-dimensional state initialization
        self.resize_observation = resize_observation
        self.show_state = show_state
        self.show_ae_output = show_ae_output

        # Autoencoder training control variables
        self.ae_training = True
        self.ae_loss_history = MemoryBuffer(
            min_size=50,
            max_size=50)  # reuse memory buffer for the ae loss history
        self.ae_training_threshold = ae_training_threshold
        self.ae_evaluation_frequency = ae_evaluation_frequency
        self.mean_ae_loss = 1e7

        if self.show_state:
            self.state_plot = FastImagePlot(1,
                                            np.zeros([image_size, image_size]),
                                            image_size,
                                            'Image State',
                                            vmax=0.5)

        if self.show_ae_output:
            self.ae_output_plot = FastImagePlot(2,
                                                np.zeros(
                                                    [image_size, image_size]),
                                                image_size,
                                                'Autoencoder Output',
                                                vmax=0.5)

    def _build_network(self, dim_a, params):
        # Initialize graph
        with tf.variable_scope('base'):
            # Build autoencoder
            ae_inputs = tf.placeholder(
                tf.float32, (None, self.image_size, self.image_size, 1),
                name='input')
            self.loss_ae, latent_space, self.ae_output = autoencoder(ae_inputs)

            # Build fully connected layers
            self.y, loss_policy = fully_connected_layers(
                tf.contrib.layers.flatten(latent_space), dim_a,
                params['fc_layers_neurons'], params['loss_function_type'])

        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'base')
        self.train_policy = tf.train.GradientDescentOptimizer(
            learning_rate=params['learning_rate']).minimize(loss_policy,
                                                            var_list=variables)

        self.train_ae = tf.train.AdamOptimizer(
            learning_rate=params['learning_rate']).minimize(self.loss_ae)

        # Initialize tensorflow
        init = tf.global_variables_initializer()
        self.sess = tf.Session()
        self.sess.run(init)
        self.saver = tf.train.Saver()

    def _preprocess_observation(self, observation):
        if self.resize_observation:
            observation = cv2.resize(observation,
                                     (self.image_size, self.image_size))
        self.high_dim_observation = observation_to_gray(
            observation, self.image_size)
        self.network_input = self.high_dim_observation

    def _batch_update_extra(self, state_batch, y_label_batch):
        # Calculate autoencoder loss and train if necessary
        if self.ae_training:
            _, loss_ae = self.sess.run([self.train_ae, self.loss_ae],
                                       feed_dict={'base/input:0': state_batch})

        else:
            loss_ae = self.sess.run(self.loss_ae,
                                    feed_dict={'base/input:0': state_batch})

        # Append loss to loss buffer
        self.ae_loss_history.add(loss_ae)

    def _evaluate_ae(self, t):
        # Check autoencoder mean loss in history and update ae_training flag
        if t % self.ae_evaluation_frequency == 0:
            self.mean_ae_loss = np.array(self.ae_loss_history.buffer).mean()
            last_ae_training_state = self.ae_training

            if (self.ae_loss_history.initialized()
                    and self.mean_ae_loss < self.ae_training_threshold):
                self.ae_training = False
            else:
                self.ae_training = True

            # If flag changed, print
            if last_ae_training_state is not self.ae_training:
                print('\nTraining autoencoder:', self.ae_training, '\n')

    def _refresh_image_plots(self, t):
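        # Refresh the state plot and the autoencoder-output plot on alternating
        # schedules (every 4 steps, offset by 2) so both windows are not redrawn
        # in the same iteration.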
        if t % 4 == 0 and self.show_state:
            self.state_plot.refresh(self.high_dim_observation)

        if (t + 2) % 4 == 0 and self.show_ae_output:
            self.ae_output_plot.refresh(
                self.ae_output.eval(
                    session=self.sess,
                    feed_dict={'base/input:0': self.high_dim_observation})[0])

    def time_step(self, t):
        self._evaluate_ae(t)
        self._refresh_image_plots(t)

    def new_episode(self):
        print('\nTraining autoencoder:', self.ae_training)
        print('Last autoencoder mean loss:', self.mean_ae_loss, '\n')
Example #18
class DDQN:
    """ Deep Q-Learning Main Algorithm
    """
    def __init__(self, action_dim, state_dim, params):
        """ Initialization
        """
        # session = K.get_session()
        # Environment and DDQN parameters
        self.with_per = params["with_per"]
        self.action_dim = action_dim
        self.state_dim = state_dim

        self.lr = 2.5e-4
        self.gamma = 0.95
        self.epsilon = params["epsilon"]
        self.epsilon_decay = params["epsilon_decay"]
        self.epsilon_minimum = 0.05
        self.buffer_size = 10000
        self.tau = 1.0
        self.agent = Agent(self.state_dim, action_dim, self.lr, self.tau,
                           params["dueling"])
        # Memory Buffer for Experience Replay
        self.buffer = MemoryBuffer(self.buffer_size, self.with_per)

        exp_dir = 'test/models/'
        if not os.path.exists(exp_dir):
            os.makedirs(exp_dir)
        self.export_path = os.path.join(exp_dir, 'lala.h5')
        self.save_interval = params["save_interval"]

    def policy_action(self, s):
        """ Apply an espilon-greedy policy to pick next action
        """
        if random() <= self.epsilon:
            return randrange(self.action_dim)
        else:
            return np.argmax(self.agent.predict(s)[0])

    def train_agent(self, batch_size):
        """ Train Q-network on batch sampled from the buffer
        """
        # Sample experience from memory buffer (optionally with PER)
        s, a, r, d, new_s, idx = self.buffer.sample_batch(batch_size)

        # Apply Bellman Equation on batch samples to train our DDQN
        q = self.agent.predict(s)
        next_q = self.agent.predict(new_s)
        q_targ = self.agent.target_predict(new_s)
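        # Double DQN target: the online network (next_q) selects the greedy next
        # action, while the target network (q_targ) evaluates it.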

        for i in range(s.shape[0]):
            old_q = q[i, a[i]]
            if d[i]:
                q[i, a[i]] = r[i]
            else:
                next_best_action = np.argmax(next_q[i, :])
                q[i, a[i]] = r[i] + self.gamma * q_targ[i, next_best_action]
            if (self.with_per):
                # Update PER Sum Tree
                self.buffer.update(idx[i], abs(old_q - q[i, a[i]]))
        # Train on batch
        self.agent.fit(s, q)
        # Decay epsilon
        if self.epsilon > self.epsilon_minimum:
            self.epsilon *= self.epsilon_decay
        else:
            self.epsilon = self.epsilon_minimum

    def train(self, env, nb_episodes, batch_size, writer):
        """ Main DDQN Training Algorithm
        """

        results = []
        tqdm_e = tqdm(range(nb_episodes),
                      desc='Score',
                      leave=True,
                      unit=" episodes")

        for e in tqdm_e:
            # Reset episode
            t, cumul_reward, done = 0, 0, False
            old_state = env.reset()

            t0 = time.time()
            while not done:
                # Actor picks an action (following the policy)
                a = self.policy_action(old_state)
                # Retrieve new state, reward, and whether the state is terminal
                new_state, r, done, _ = env.step(a)
                print("Step %s in episode %s, cumul_reward: %s reward: %s" %
                      (t, e, cumul_reward, r))
                # Memorize for experience replay
                self.memorize(old_state, a, r, done, new_state)
                # Update current state
                old_state = new_state
                cumul_reward += r
                t += 1
                # Train DDQN and transfer weights to target network
                if (self.buffer.size() > batch_size):
                    self.train_agent(batch_size)
                    self.agent.transfer_weights()
            print("it took % s at episode %s" % (time.time() - t0, e))

            if e % 10 == 0 and e != 0:
                # Gather stats every 10 episodes for plotting
                mean, stdev, n = self.gather_stats(env)
                results.append([e, mean, stdev, n])

            with writer.as_default():
                tf.summary.scalar('score', cumul_reward, step=e)
            writer.flush()

            # Display score
            tqdm_e.set_description("Score: " + str(cumul_reward))
            tqdm_e.refresh()
            if e % self.save_interval == 0 and e != 0:
                self.save_weights(self.export_path, e)
            t0 = time.time()
        return results

    def memorize(self, state, action, reward, done, new_state):
        """ Store experience in memory buffer
        """
        if (self.with_per):
            q_val = self.agent.predict(state)
            q_val_t = self.agent.target_predict(new_state)
            next_best_action = np.argmax(q_val)
            new_val = reward + self.gamma * q_val_t[0, next_best_action]
            td_error = abs(new_val - q_val)[0]
        else:
            td_error = 0
        self.buffer.memorize(state, action, reward, done, new_state, td_error)

    def save_weights(self, path, ep=10000):
        self.agent.save(path)

    def load_weights(self, path):
        self.agent.load_weights(path)

    def gather_stats(self, env):
        score = []
        n_steps = []
        for k in range(10):
            old_state = env.reset()
            cumul_r, t, done = 0, 0, False
            while not done:
                a = self.policy_action(old_state)
                old_state, r, done, _ = env.step(a)
                cumul_r += r
                t += 1
            score.append(cumul_r)
            n_steps.append(t)
        return np.mean(np.array(score)), np.std(
            np.array(score)), np.mean(n_steps)
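# A minimal usage sketch for the DDQN class above, assuming a discrete-action
# Gym environment and a TF2 summary writer; the environment name and the
# parameter values below are illustrative assumptions, not taken from the source.
#
#     import gym
#     import tensorflow as tf
#
#     env = gym.make("CartPole-v1")
#     params = {"with_per": True, "epsilon": 1.0, "epsilon_decay": 0.995,
#               "dueling": True, "save_interval": 50}
#     ddqn = DDQN(env.action_space.n, env.observation_space.shape, params)
#     writer = tf.summary.create_file_writer("./logs/ddqn")
#     ddqn.train(env, nb_episodes=500, batch_size=32, writer=writer)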
Example #19
    def test_resample(self):
        if os.path.exists("./memory/memory_buffer_test/"):
            shutil.rmtree("./memory/memory_buffer_test/")

        # Make a few saved memory buffers.
        info_set_size = 1 + 1 + 16
        item_size = 6
        max_size = int(1e4)
        mb = MemoryBuffer(info_set_size, item_size, max_size=max_size)

        buf1_size = 100
        for i in range(buf1_size):
            mb.add(make_dummy_ev_infoset(), torch.zeros(item_size), 0)
        mb.save("./memory/memory_buffer_test/", "advt_mem_0")
        mb.clear()

        buf2_size = 200
        for i in range(buf2_size):
            mb.add(make_dummy_ev_infoset(), torch.zeros(item_size), 1)
        mb.save("./memory/memory_buffer_test/", "advt_mem_0")
        mb.clear()

        buf3_size = 300
        for i in range(buf3_size):
            mb.add(make_dummy_ev_infoset(), torch.zeros(item_size), 2)
        mb.save("./memory/memory_buffer_test/", "advt_mem_0")
        mb.clear()

        # Make a dataset using the saved buffers.
        # n = (buf1_size + buf2_size) // 10
        n = 1000
        dataset = MemoryBufferDataset("./memory/memory_buffer_test/",
                                      "advt_mem_0", n)
        # min_size = min(n, buf1_size + buf2_size + buf3_size)
        # print(min_size)

        for _ in range(1):
            dataset.resample()
            self.assertEqual(len(dataset), n)
            self.assertEqual(len(dataset._infosets), n)
            self.assertEqual(len(dataset._items), n)
            self.assertEqual(len(dataset._weights), n)
            # print(dataset._weights)

        # Test iteration over the dataset.
        for inputs in dataset:
            print(inputs.keys())

        print(dataset._weights)
Example #20
class Processor (object):

    def __init__ (self, memory, start_address):
        self.memory = memory
        self.start_address = start_address
        self.register_file = RegisterFile ()
        self.data_memory_key_fn = lambda: -777
        self.data_memory = defaultdict (self.data_memory_key_fn)

        self.cycle_count = 0
        self.instr_count = 0
        self.PC = 0

        self.fetch_input_buffer = FetchInputBuffer({
            'PC': self.start_address,
            'instr_count': self.instr_count,
            })
        self.fetcher_buffer = FetcherBuffer()
        self.fetch_stage = FetchStage(self.memory, 
                                      self.fetch_input_buffer, 
                                      self.fetcher_buffer)
        
        self.decoder_buffer = DecoderBuffer()
        self.decode_stage = DecodeStage(self.fetcher_buffer,
                                        self.decoder_buffer,
                                        self.register_file)
        
        self.executer_buffer = ExecuterBuffer()
        self.execute_stage = ExecuteStage(self.decoder_buffer,
                                          self.executer_buffer)
        self.memory_buffer = MemoryBuffer()
        self.memory_stage = MemoryStage(self.executer_buffer,
                                        self.memory_buffer,
                                        self.data_memory)
        self.write_back_stage = WriteBackStage(
            self.memory_buffer,
            self.register_file)

    def print_buffers(self):
        print("PC:", self.fetch_stage.fetch_input_buffer)

        print('fetch_stage.fetch_input_buffer:')
        print(self.fetch_stage.fetch_input_buffer)
        print('fetch_stage.fetcher_buffer:')
        print(self.fetch_stage.fetcher_buffer)
        print()
        print('decode_stage.fetcher_buffer:')
        print(self.decode_stage.fetcher_buffer)
        print('decode_stage.decoder_buffer:')
        print(self.decode_stage.decoder_buffer)
        print()
        print('execute_stage.decoder_buffer:')
        print(self.execute_stage.decoder_buffer)
        print('execute_stage.executer_buffer:')
        print(self.execute_stage.executer_buffer)
        print()
        print('memory_stage.executer_buffer:')
        print(self.memory_stage.executer_buffer)
        print('memory_stage.memory_buffer:')
        print(self.memory_stage.memory_buffer)
        print()
        print('write_back_stage.memory_buffer:')
        print(self.write_back_stage.memory_buffer)


    # def get_all_curr_data(self):
    #     """Return dict of all data in the Processor at the moment.
    #     """

    #     # TODO: It gives 'Can't pickle instancemethod object' error
    #     # when I have self.data_memory too.

    #     curr_data_dict = {
    #         'fetcher_buffer': self.fetcher_buffer,
    #         'decoder_buffer': self.decoder_buffer,
    #         'executer_buffer': self.executer_buffer,
    #         'memory_buffer': self.memory_buffer,
    #         'decoder_stalled': self.decoder_stalled,
    #         'executer_stalled': self.executer_stalled,
    #         'mem_stalled': self.mem_stalled,
    #         'reg_writer_stalled': self.reg_writer_stalled,
    #         'memory': self.memory,
    #         'start_address': self.start_address,
    #         'register_file': self.register_file,
    #         'PC': self.PC,
    #         'IR': self.IR,
    #         'NPC': self.NPC,
    #         'cycle_count': self.cycle_count,
    #         'instr_count': self.instr_count,
    #         }
    #     return curr_data_dict

    # @staticmethod
    # def save_cycle_data(cycle_data_list, cycle_data_file_name = default_data_file_name):
    #     """Pickle and save cycle_data_list.
        
    #     Arguments:
    #     - `cycle_data_list`:
    #     """

    #     with open(cycle_data_file_name, 'w') as f:
    #         pickle.dump(cycle_data_list, f)

    #     print 'Wrote cycle_data_list to {0}'.format(cycle_data_file_name)

    # @staticmethod
    # def read_saved_data(cycle_data_file_name = default_data_file_name):
    #     """Return cycle data list saved in cycle_data_file_name.
        
    #     Arguments:
    #     - `cycle_data_file_name`:
    #     """
    #     cycle_data_list = []
    #     with open(cycle_data_file_name, 'rb') as f:
    #         cycle_data_list = pickle.load(f)
    #         print 'Read cycle_data_list from {0}'.format(cycle_data_file_name)
    #     return cycle_data_list
    
    # TODO: Be careful. In reality, the stages are executed in reverse
    # order.
    @staticmethod
    def get_stage_output(memory, register_file, pc, instr_count, 
                         stage_name):
        """Return the output buffer of stage given the initial conditions.
        
        All the stages before stage_name will be executed.
        
        Arguments:
        - `memory`:
        - `register_file`:
        - `pc`:
        - `stage_name`:

        TODO: Maybe just take the stages as input later.
        """
        fetch_input_buffer = FetchInputBuffer({
            'PC': pc,
            'instr_count': instr_count,
            })
        fetcher_buffer = FetcherBuffer()
        fetch_stage = FetchStage(memory, fetch_input_buffer, fetcher_buffer)
        fetch_stage.fetch_instruction()

        if stage_name == 'fetch':
            return fetch_stage.fetcher_buffer

        decode_stage = DecodeStage(fetch_stage.fetcher_buffer, 
                                   DecoderBuffer(), 
                                   register_file)
        decode_stage.decode_instruction()

        if stage_name == 'decode':
            return decode_stage.decoder_buffer

        execute_stage = ExecuteStage(decode_stage.decoder_buffer,
                                     ExecuterBuffer())
        execute_stage.execute()
        if stage_name == 'execute':
            return execute_stage.executer_buffer

        data_memory_key_fn = lambda: -1
        data_memory = defaultdict (data_memory_key_fn)

        memory_stage = MemoryStage(execute_stage.executer_buffer,
                                   MemoryBuffer(),
                                   data_memory)
        memory_stage.do_memory_operation()

        if stage_name == 'memory':
            return memory_stage.memory_buffer
    
    def do_operand_forwarding(self, ):
        """Forward operands if possible.
        """
        # TODO: Be careful about this... check if it is an output reg
        # value or an input reg value... (Does that make sense?)

        # MEM stage
        for operand_label in ['rt']:
            if (self.executer_buffer[operand_label] is not None 
                and self.executer_buffer[operand_label][1] is None):

                reg_label = self.executer_buffer[operand_label][0]
                mem_forward_val = self.memory_buffer.get_reg_val(reg_label)
                print('mem_forward_val:', mem_forward_val)
                self.executer_buffer[operand_label][1] = mem_forward_val

        # ALU
        for operand_label in ['rs', 'rt']:
            if (self.decoder_buffer[operand_label] is not None 
                and self.decoder_buffer[operand_label][1] is None):

                reg_label = self.decoder_buffer[operand_label][0]
                exec_forward_val = self.executer_buffer.get_reg_val(reg_label)
                mem_forward_val = self.memory_buffer.get_reg_val(reg_label)

                # Note: exec_forward_val or mem_forward_val won't work
                # cos 0 or None => None
                forward_val = exec_forward_val if exec_forward_val is not None \
                  else mem_forward_val

                self.decoder_buffer[operand_label][1] = forward_val

    def are_instructions_in_flight(self, ):
        """Return True iff there exist instructions in-flight.

        TODO: Check if any registers are dirty.
        """
        any_non_empty_buffers = not all(buff.is_empty() for buff in 
                                        [self.memory_stage.memory_buffer,
                                         self.execute_stage.executer_buffer,
                                         self.decode_stage.decoder_buffer,
                                         self.fetch_stage.fetcher_buffer])
        any_stalls = any(stage.is_stalled for stage in [self.decode_stage,
                                                        self.execute_stage,
                                                        self.memory_stage])
        valid_PC_coming_up = self.fetch_stage.is_valid_PC()
        # valid_PC_coming_up = False

        return any_non_empty_buffers or any_stalls or valid_PC_coming_up
        
    def execute_one_cycle(self, ):
        """Execute one cycle of the Processor.


        TODO: Make it such that the stages SHARE the buffer (ie. have
        a reference to the same buffer) instead of having copies named
        FetcherBuffer, etc.
        """

        self.do_operand_forwarding()

        self.write_back_stage.write_back()
        self.memory_stage.do_memory_operation()
        self.execute_stage.execute(self.memory_stage.is_stalled)
        self.decode_stage.decode_instruction(self.execute_stage.is_stalled)
        self.fetch_stage.fetch_instruction(self.decode_stage.is_stalled)

        # self.write_back_stage.memory_buffer = self.memory_stage.memory_buffer

        # if not self.memory_stage.is_stalled:
        #     self.memory_stage.executer_buffer = self.execute_stage.executer_buffer

        # if (not self.execute_stage.is_stalled 
        #     and not self.decode_stage.has_jumped):
        #     self.execute_stage.decoder_buffer = self.decode_stage.decoder_buffer

        # # TODO: What is this for?
        # if not self.decode_stage.is_stalled and not self.execute_stage.branch_pc is not None:
        #     self.decode_stage.fetcher_buffer = self.fetch_stage.fetcher_buffer

        if (self.memory_stage.is_stalled or
                self.execute_stage.is_stalled or
                self.decode_stage.is_stalled):
            print('STALL')
            print('memory_stage, execute_stage, decode_stage is_stalled:\n',
                  [self.memory_stage.is_stalled,
                   self.execute_stage.is_stalled,
                   self.decode_stage.is_stalled])


        if self.execute_stage.branch_pc is not None:
            self.decode_stage.undo_dirties(self.execute_stage.is_stalled)
            print('self.register_file:', self.register_file)
            self.decoder_buffer.clear()
            self.fetcher_buffer.clear()
            
        # TODO: Can there be a jump PC from Decode and a branch PC
        # from Execute at the end of the same cycle?
        if self.decode_stage.has_jumped:
            # Pass on PC value from decoder_buffer to fetcher_buffer in
            # case of a jump.
            self.fetch_stage.fetch_input_buffer.PC = self.decode_stage.jump_pc
        elif self.execute_stage.branch_pc is not None:
            self.fetch_stage.fetch_input_buffer.PC = self.execute_stage.branch_pc
    
    def execute_cycles(self, num_cycles = None):
        """Execute num_cycles cycles of the Processor (if possible).

        Else, execute till the program terminates.
        """
        self.cycle_count = 0
        print('self.memory:', self.memory)

        while True:
            self.cycle_count += 1
            print('\n')
            print('Beginning of Cycle #' + str(self.cycle_count))
            print('=' * 12)

            print('[decode_stage, memory_stage, execute_stage] is_stalled:',
                  [stage.is_stalled for stage in
                   [self.decode_stage, self.memory_stage, self.execute_stage]])

            self.print_buffers()
            print(self.register_file)

            self.execute_one_cycle()

            if not self.are_instructions_in_flight() or (
                    num_cycles is not None and self.cycle_count == num_cycles):
                break

        print('\nAt the end')
        print('=' * 12)
        self.print_buffers()
        print(self.register_file)

    def start(self, cycle_data_file_name = default_data_file_name):
        """Start execution of instructions from the start_address.
        """
        self.instruction_address = self.start_address
        self.execute_cycles()

    def getCPI (self):
        return (1.0 * self.cycle_count) / self.fetch_stage.fetch_input_buffer.instr_count
Example #21
    tf.compat.v1.keras.backend.set_session(sess)

    # Set up the gym environment and related parameters
    env = make_atari('PongNoFrameskip-v4')
    env = wrap_deepmind(env, scale=False, frame_stack=True)
    num_actions = env.action_space.n

    dqn = DeepQNetwork(input_shape=(WIDTH, HEIGHT, NUM_FRAMES),
                       num_actions=num_actions,
                       name='dqn',
                       learning_rate=LR)
    target_dqn = DeepQNetwork(input_shape=(WIDTH, HEIGHT, NUM_FRAMES),
                              num_actions=num_actions,
                              name='target_dqn',
                              learning_rate=LR)
    buf = MemoryBuffer(memory_size=BUFFER_SIZE)

    total_episode_rewards = []
    step = 0
    for episode in range(MAX_EPISODE + 1):
        frame = env.reset()  # LazyFrames
        state = np.array(frame)  # ndarray (84, 84, 4)
        done = False
        cur_episode_reward = 0
        while not done:  # end the episode when done
            if step % C == 0:
                target_dqn.copy_from(dqn)  # copy parameters into the target network
            if epsilon_greedy(step):
                action = env.action_space.sample()
            else:
                action = dqn.get_action(state / 255.0)
Example #22
class TrainingLoopSAC:
    def __init__(self,
                 config: BasicConfigSAC,
                 environment: Env,
                 log_path: str = None,
                 logging: bool = True):
        """
        TODO: Write docstring
        """
        log_path = generate_experiment_signature(
            environment) if log_path is None else log_path
        self.config = config
        self.monitor = Monitor(log_path, config, logging=logging)
        self.env = environment

        self.batch_size = config.learner.batch_size
        self.episode_horizon = config.episode_horizon
        self.steps_before_learn = config.steps_before_learn

        self.memory_buffer = MemoryBuffer(max_memory_size=config.memory_size)

        self.agent = AgentSAC(environment, config.policy)

        self.learner = LearnerSAC(config=config.learner,
                                  agent=self.agent,
                                  enviroment=self.env,
                                  monitor=self.monitor)

    def train(self,
              num_epochs: int = 30,
              steps_per_epoch: int = 1000,
              step_per_learning_step: int = 1,
              grad_updates_per_learning_step: int = 1):

        # assert self.agent.is_learning
        self.total_steps, self.learning_steps, self.episodes = 0, 0, 0

        observation, trajectory = self.reset_environment()

        with self.monitor.summary_writter.as_default():
            for epoch in range(num_epochs):
                self.monitor.epoch_start_callback()
                for step in range(steps_per_epoch):
                    observation, trajectory = self.register_agent_step(
                        observation, trajectory)
                    if step % step_per_learning_step == 0 and self.agent.is_learning:
                        self.learning_step(grad_updates_per_learning_step)
                        self.learning_steps += 1
                self.monitor.epoch_end_callback(steps_per_epoch)

    def register_agent_step(
            self, observation: Observation,
            trajectory: Trajectory) -> Tuple[Observation, Trajectory]:
        action, action_metadata = self.agent.act(observation)
        next_observation, reward, done, info = self.env.step(action)
        # Append step to trajectory & add step to memory buffer
        step = trajectory.register_step(observation=observation,
                                        action=action,
                                        reward=reward,
                                        next_observation=next_observation,
                                        action_metadata=action_metadata,
                                        done=done)
        if self.agent.is_learning:
            self.memory_buffer.add_step(step)
        start_new_traj = self.episode_horizon == len(
            trajectory) or done  # Handle terminal steps
        if start_new_traj:
            observation, trajectory = self.handle_completed_trajectory(
                trajectory)
        else:
            observation = next_observation
        self.total_steps += 1
        return observation, trajectory

    def handle_completed_trajectory(
            self, trajectory: Trajectory) -> Tuple[Observation, Trajectory]:
        if len(trajectory) > 0:
            self.monitor.trajectory_completed_callback(trajectory)
            self.episodes += 1
        return self.reset_environment()

    def reset_environment(self) -> Tuple[Observation, Trajectory]:
        return self.env.reset(), Trajectory()

    def learning_step(self, grad_updates_per_learning_step: int):
        ready_to_learn = (
            (self.memory_buffer.current_size >= self.batch_size
             and self.memory_buffer.current_size > self.steps_before_learn)
            and self.agent.is_learning)
        if ready_to_learn:
            for _ in range(grad_updates_per_learning_step):
                batch = self.memory_buffer.sample_batch_transitions(
                    self.batch_size)
                self.learner.learn_from_batch(batch)
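
A minimal driver for the class above might look like the following sketch; the environment id and the assumption that BasicConfigSAC can be constructed with defaults are illustrative only, not taken from the original project.

import gym

config = BasicConfigSAC()          # assumed default construction
env = gym.make("Pendulum-v1")      # any continuous-control environment
loop = TrainingLoopSAC(config, env, logging=True)
loop.train(num_epochs=10, steps_per_epoch=1000)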
Example #23
0
comarg.add_argument("output_folder", help="Where to write results to.")
comarg.add_argument("--num_episodes",
                    type=int,
                    default=10,
                    help="Number of episodes to test.")
comarg.add_argument("--random_seed",
                    type=int,
                    help="Random seed for repeatable experiments.")
args = parser.parse_args()

if args.random_seed:
    random.seed(args.random_seed)

env = GymEnvironment(args.env_id, args)
net = DeepQNetwork(env.numActions(), args)
buf = MemoryBuffer(args)

if args.load_weights:
    print "Loading weights from %s" % args.load_weights
    net.load_weights(args.load_weights)

env.gym.monitor.start(args.output_folder, force=True)
avg_reward = 0
num_episodes = args.num_episodes
for i_episode in xrange(num_episodes):
    env.restart()
    observation = env.getScreen()
    buf.reset()
    i_total_reward = 0
    for t in xrange(10000):
        buf.add(observation)
Example #24
0
    version,
    train_ae=config_graph.getboolean('train_autoencoder'),
    load_policy=config_exp_setup.getboolean('load_graph'),
    learning_rate=float(config_graph['learning_rate']),
    dim_a=config_graph.getint('dim_a'),
    fc_layers_neurons=config_graph.getint('fc_layers_neurons'),
    loss_function_type=config_graph['loss_function_type'],
    policy_loc=config_graph['policy_loc'] + exp_num + '_',
    action_upper_limits=config_graph['action_upper_limits'],
    action_lower_limits=config_graph['action_lower_limits'],
    e=config_graph['e'],
    config_graph=config_graph,
    config_general=config_general)

# Create memory buffer
buffer = MemoryBuffer(min_size=config_buffer.getint('min_size'),
                      max_size=config_buffer.getint('max_size'))

# Create feedback object
env.render()
human_feedback = Feedback(env,
                          key_type=config_feedback['key_type'],
                          h_up=config_feedback['h_up'],
                          h_down=config_feedback['h_down'],
                          h_right=config_feedback['h_right'],
                          h_left=config_feedback['h_left'],
                          h_null=config_feedback['h_null'])

# Create saving directory if it does not exist
if save_results:
    if not os.path.exists(eval_save_path + eval_save_folder):
        os.makedirs(eval_save_path + eval_save_folder)
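
The values read above come from an INI-style file parsed with configparser. A sketch of how some of the referenced sections could be assembled in code is shown below; the section names and placeholder values are assumptions for illustration, not the project's actual configuration.

from configparser import ConfigParser

config = ConfigParser()
config['GRAPH'] = {
    'learning_rate': '0.001',
    'dim_a': '3',
    'fc_layers_neurons': '100',
    'loss_function_type': 'mean_squared',
    'policy_loc': './policies/',
    'action_upper_limits': '1,1,1',
    'action_lower_limits': '-1,-1,-1',
    'e': '1',
    'train_autoencoder': 'False',
}
config['BUFFER'] = {'min_size': '1000', 'max_size': '100000'}
config['FEEDBACK'] = {'key_type': 'keyboard', 'h_up': 'w', 'h_down': 's',
                      'h_right': 'd', 'h_left': 'a', 'h_null': 'space'}

config_graph = config['GRAPH']
config_buffer = config['BUFFER']
config_feedback = config['FEEDBACK']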
Example #25
0
class Processor(object):
    def __init__(self, memory, start_address):
        self.memory = memory
        self.start_address = start_address
        self.register_file = RegisterFile()
        self.data_memory_key_fn = lambda: -777
        self.data_memory = defaultdict(self.data_memory_key_fn)

        self.cycle_count = 0
        self.instr_count = 0
        self.PC = 0

        self.fetch_input_buffer = FetchInputBuffer({
            'PC':
            self.start_address,
            'instr_count':
            self.instr_count,
        })
        self.fetcher_buffer = FetcherBuffer()
        self.fetch_stage = FetchStage(self.memory, self.fetch_input_buffer,
                                      self.fetcher_buffer)

        self.decoder_buffer = DecoderBuffer()
        self.decode_stage = DecodeStage(self.fetcher_buffer,
                                        self.decoder_buffer,
                                        self.register_file)

        self.executer_buffer = ExecuterBuffer()
        self.execute_stage = ExecuteStage(self.decoder_buffer,
                                          self.executer_buffer)
        self.memory_buffer = MemoryBuffer()
        self.memory_stage = MemoryStage(self.executer_buffer,
                                        self.memory_buffer, self.data_memory)
        self.write_back_stage = WriteBackStage(self.memory_buffer,
                                               self.register_file)

    def print_buffers(self):
        print "PC:", self.fetch_stage.fetch_input_buffer

        print 'fetch_stage.fetch_input_buffer:'
        print self.fetch_stage.fetch_input_buffer
        print 'fetch_stage.fetcher_buffer:'
        print self.fetch_stage.fetcher_buffer
        print
        print 'decode_stage.fetcher_buffer:'
        print self.decode_stage.fetcher_buffer
        print 'decode_stage.decoder_buffer:'
        print self.decode_stage.decoder_buffer
        print
        print 'execute_stage.decoder_buffer:'
        print self.execute_stage.decoder_buffer
        print 'execute_stage.executer_buffer:'
        print self.execute_stage.executer_buffer
        print
        print 'memory_stage.executer_buffer:'
        print self.memory_stage.executer_buffer
        print 'memory_stage.memory_buffer:'
        print self.memory_stage.memory_buffer
        print
        print 'write_back_stage.memory_buffer:'
        print self.write_back_stage.memory_buffer

    # def get_all_curr_data(self):
    #     """Return dict of all data in the Processor at the moment.
    #     """

    #     # TODO: It gives 'Can't pickle instancemethod object' error
    #     # when I have self.data_memory too.

    #     curr_data_dict = {
    #         'fetcher_buffer': self.fetcher_buffer,
    #         'decoder_buffer': self.decoder_buffer,
    #         'executer_buffer': self.executer_buffer,
    #         'memory_buffer': self.memory_buffer,
    #         'decoder_stalled': self.decoder_stalled,
    #         'executer_stalled': self.executer_stalled,
    #         'mem_stalled': self.mem_stalled,
    #         'reg_writer_stalled': self.reg_writer_stalled,
    #         'memory': self.memory,
    #         'start_address': self.start_address,
    #         'register_file': self.register_file,
    #         'PC': self.PC,
    #         'IR': self.IR,
    #         'NPC': self.NPC,
    #         'cycle_count': self.cycle_count,
    #         'instr_count': self.instr_count,
    #         }
    #     return curr_data_dict

    # @staticmethod
    # def save_cycle_data(cycle_data_list, cycle_data_file_name = default_data_file_name):
    #     """Pickle and save cycle_data_list.

    #     Arguments:
    #     - `cycle_data_list`:
    #     """

    #     with open(cycle_data_file_name, 'w') as f:
    #         pickle.dump(cycle_data_list, f)

    #     print 'Wrote cycle_data_list to {0}'.format(cycle_data_file_name)

    # @staticmethod
    # def read_saved_data(cycle_data_file_name = default_data_file_name):
    #     """Return cycle data list saved in cycle_data_file_name.

    #     Arguments:
    #     - `cycle_data_file_name`:
    #     """
    #     cycle_data_list = []
    #     with open(cycle_data_file_name, 'rb') as f:
    #         cycle_data_list = pickle.load(f)
    #         print 'Read cycle_data_list from {0}'.format(cycle_data_file_name)
    #     return cycle_data_list

    # TODO: Be careful. In reality, the stages are executed in reverse
    # order.
    @staticmethod
    def get_stage_output(memory, register_file, pc, instr_count, stage_name):
        """Return the output buffer of stage given the initial conditions.
        
        All the stages before stage_name will be executed.
        
        Arguments:
        - `memory`:
        - `register_file`:
        - `pc`:
        - `stage_name`:

        TODO: Maybe just take the stages as input later.
        """
        fetch_input_buffer = FetchInputBuffer({
            'PC': pc,
            'instr_count': instr_count,
        })
        fetcher_buffer = FetcherBuffer()
        fetch_stage = FetchStage(memory, fetch_input_buffer, fetcher_buffer)
        fetch_stage.fetch_instruction()

        if stage_name == 'fetch':
            return fetch_stage.fetcher_buffer

        decode_stage = DecodeStage(fetch_stage.fetcher_buffer, DecoderBuffer(),
                                   register_file)
        decode_stage.decode_instruction()

        if stage_name == 'decode':
            return decode_stage.decoder_buffer

        execute_stage = ExecuteStage(decode_stage.decoder_buffer,
                                     ExecuterBuffer())
        execute_stage.execute()
        if stage_name == 'execute':
            return execute_stage.executer_buffer

        data_memory_key_fn = lambda: -1
        data_memory = defaultdict(data_memory_key_fn)

        memory_stage = MemoryStage(execute_stage.executer_buffer,
                                   MemoryBuffer(), data_memory)
        memory_stage.do_memory_operation()

        if stage_name == 'memory':
            return memory_stage.memory_buffer

    def do_operand_forwarding(self):
        """Forward operands if possible.
        """
        # TODO: Be careful about this... check if it is an output reg
        # value or an input reg value... (Does that make sense?)
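
        # Two forwarding paths are handled below:
        #   1. MEM -> MEM: a value produced by the memory stage fills a missing
        #      store operand (rt) of the instruction entering the memory stage.
        #   2. EX/MEM -> ALU: missing rs/rt operands of the instruction entering
        #      execute are taken from the execute-stage output if available,
        #      otherwise from the memory-stage output.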

        # MEM stage
        for operand_label in ['rt']:
            if (self.executer_buffer[operand_label] is not None
                    and self.executer_buffer[operand_label][1] is None):

                reg_label = self.executer_buffer[operand_label][0]
                mem_forward_val = self.memory_buffer.get_reg_val(reg_label)
                print 'mem_forward_val: ', mem_forward_val
                self.executer_buffer[operand_label][1] = mem_forward_val

        # ALU
        for operand_label in ['rs', 'rt']:
            if (self.decoder_buffer[operand_label] is not None
                    and self.decoder_buffer[operand_label][1] is None):

                reg_label = self.decoder_buffer[operand_label][0]
                exec_forward_val = self.executer_buffer.get_reg_val(reg_label)
                mem_forward_val = self.memory_buffer.get_reg_val(reg_label)

                # Note: exec_forward_val or mem_forward_val won't work
                # cos 0 or None => None
                forward_val = exec_forward_val if exec_forward_val is not None \
                  else mem_forward_val

                self.decoder_buffer[operand_label][1] = forward_val

    def are_instructions_in_flight(self):
        """Return True iff there exist instructions in-flight.

        TODO: Check if any registers are dirty.
        """
        any_non_empty_buffers = not all(buff.is_empty() for buff in [
            self.memory_stage.memory_buffer,
            self.execute_stage.executer_buffer,
            self.decode_stage.decoder_buffer, self.fetch_stage.fetcher_buffer
        ])
        any_stalls = any(
            stage.is_stalled for stage in
            [self.decode_stage, self.execute_stage, self.memory_stage])
        valid_PC_coming_up = self.fetch_stage.is_valid_PC()
        # valid_PC_coming_up = False

        return any_non_empty_buffers or any_stalls or valid_PC_coming_up

    def execute_one_cycle(self):
        """Execute one cycle of the Processor.

        TODO: Make it such that the stages SHARE the buffer (ie. have
        a reference to the same buffer) instead of having copies named
        FetcherBuffer, etc.
        """

        self.do_operand_forwarding()
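
        # The stages run in reverse pipeline order (WB -> MEM -> EX -> ID -> IF)
        # so that each stage consumes the buffer filled by its upstream neighbour
        # in the previous cycle before that buffer is overwritten.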

        self.write_back_stage.write_back()
        self.memory_stage.do_memory_operation()
        self.execute_stage.execute(self.memory_stage.is_stalled)
        self.decode_stage.decode_instruction(self.execute_stage.is_stalled)
        self.fetch_stage.fetch_instruction(self.decode_stage.is_stalled)

        # self.write_back_stage.memory_buffer = self.memory_stage.memory_buffer

        # if not self.memory_stage.is_stalled:
        #     self.memory_stage.executer_buffer = self.execute_stage.executer_buffer

        # if (not self.execute_stage.is_stalled
        #     and not self.decode_stage.has_jumped):
        #     self.execute_stage.decoder_buffer = self.decode_stage.decoder_buffer

        # # TODO: What is this for?
        # if not self.decode_stage.is_stalled and not self.execute_stage.branch_pc is not None:
        #     self.decode_stage.fetcher_buffer = self.fetch_stage.fetcher_buffer

        if (self.memory_stage.is_stalled or self.execute_stage.is_stalled
                or self.decode_stage.is_stalled):
            print 'STALL'
            print 'self.memory_stage.is_stalled, self.execute_stage.is_stalled, self.decode_stage.is_stalled: \n', [
                self.memory_stage.is_stalled, self.execute_stage.is_stalled,
                self.decode_stage.is_stalled
            ]

        if self.execute_stage.branch_pc is not None:
            self.decode_stage.undo_dirties(self.execute_stage.is_stalled)
            print 'self.register_file: ', self.register_file
            self.decoder_buffer.clear()
            self.fetcher_buffer.clear()

        # TODO: Can there be a jump PC from Decode and a branch PC
        # from Execute at the end of the same cycle?
        if self.decode_stage.has_jumped:
            # Pass on PC value from decoder_buffer to fetcher_buffer in
            # case of a jump.
            self.fetch_stage.fetch_input_buffer.PC = self.decode_stage.jump_pc
        elif self.execute_stage.branch_pc is not None:
            self.fetch_stage.fetch_input_buffer.PC = self.execute_stage.branch_pc

    def execute_cycles(self, num_cycles=None):
        """Execute num_cycles cycles of the Processor (if possible).

        Else, execute till the program terminates.
        """
        self.cycle_count = 0
        print 'self.memory: ', self.memory

        while True:
            self.cycle_count += 1
            print '\n'
            print 'Beginning of Cycle #' + str(self.cycle_count)
            print '=' * 12

            print 'is_stalled [decode, execute, memory]: ', [
                stage.is_stalled for stage in
                [self.decode_stage, self.execute_stage, self.memory_stage]
            ]

            self.print_buffers()
            print self.register_file

            self.execute_one_cycle()

            if not self.are_instructions_in_flight() or (
                    num_cycles is not None and self.cycle_count == num_cycles):
                break

        print '\nAt the end'
        print '=' * 12
        self.print_buffers()
        print self.register_file

    def start(self, cycle_data_file_name=default_data_file_name):
        """Start execution of instructions from the start_address.
        """
        self.instruction_address = self.start_address
        self.execute_cycles()

    def getCPI(self):
        """Return the average cycles per instruction (total cycles / instructions fetched)."""
        return (1.0 * self.cycle_count) / self.fetch_stage.fetch_input_buffer.instr_count
Example #26
0
    def test_write_back_I(self):
        self.set_up_write_back_stage('I LW  R2 R5 4')
        self.write_back_stage.write_back()
        self.assertEqual(self.write_back_stage.memory_buffer, MemoryBuffer())
        self.assertTrue(self.register_file.isClean(self.instr.rt))
Example #27
0
mainarg = parser.add_argument_group('Main loop')
mainarg.add_argument("--load_weights", help="Load network from file.")
mainarg.add_argument("--save_weights_prefix", help="Save network to given file. Epoch and extension will be appended.")

comarg = parser.add_argument_group('Common')
comarg.add_argument("output_folder", help="Where to write results to.")
comarg.add_argument("--num_episodes", type=int, default=10, help="Number of episodes to test.")
comarg.add_argument("--random_seed", type=int, help="Random seed for repeatable experiments.")
args = parser.parse_args()

if args.random_seed:
  random.seed(args.random_seed)

env = GymEnvironment(args.env_id, args)
net = DeepQNetwork(env.numActions(), args)
buf = MemoryBuffer(args)

if args.load_weights:
  print "Loading weights from %s" % args.load_weights
  net.load_weights(args.load_weights)

env.gym.monitor.start(args.output_folder, force=True)
avg_reward = 0
num_episodes = args.num_episodes
for i_episode in xrange(num_episodes):
    env.restart()
    observation = env.getScreen()
    buf.reset()
    i_total_reward = 0
    for t in xrange(10000):
        buf.add(observation)
Example #28
0
def traverse_worker(worker_id, traverse_player_idx, strategies, save_lock, opt,
                    t, eval_mode, info_queue):
    """
    A worker that traverses the game tree K times, saving things to memory
    buffers. Each worker maintains its own memory buffers and saves them after
    finishing.

    If eval_mode is set to True, no memory buffers are created.
    """
    # assert(strategies[0]._network.device == torch.device("cpu"))
    # assert(strategies[1]._network.device == torch.device("cpu"))
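
    # Memory buffers (not created in eval mode): advt_mem appears to collect
    # per-infoset advantage/regret targets for the traversing player, while
    # strt_mem collects average-strategy targets.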

    advt_mem = MemoryBuffer(
        Constants.INFO_SET_SIZE,
        Constants.NUM_ACTIONS,
        max_size=opt.SINGLE_PROC_MEM_BUFFER_MAX_SIZE,
        autosave_params=(opt.MEMORY_FOLDER,
                         opt.ADVT_BUFFER_FMT.format(traverse_player_idx)),
        save_lock=save_lock) if not eval_mode else None

    strt_mem = MemoryBuffer(
        Constants.INFO_SET_SIZE,
        Constants.NUM_ACTIONS,
        max_size=opt.SINGLE_PROC_MEM_BUFFER_MAX_SIZE,
        autosave_params=(opt.MEMORY_FOLDER, opt.STRT_BUFFER_FMT),
        save_lock=save_lock) if not eval_mode else None

    if eval_mode:
        num_traversals_per_worker = int(opt.NUM_TRAVERSALS_EVAL /
                                        opt.NUM_TRAVERSE_WORKERS)
    else:
        num_traversals_per_worker = int(opt.NUM_TRAVERSALS_PER_ITER /
                                        opt.NUM_TRAVERSE_WORKERS)

    t0 = time.time()
    for k in range(num_traversals_per_worker):
        ctr = [0]

        # Generate a random initialization, alternating the SB player each time.
        sb_player_idx = k % 2
        round_state = create_new_round(sb_player_idx)

        precomputed_ev = make_precomputed_ev(round_state)
        info = traverse(round_state,
                        make_actions,
                        make_infoset,
                        traverse_player_idx,
                        sb_player_idx,
                        strategies,
                        advt_mem,
                        strt_mem,
                        t,
                        precomputed_ev,
                        recursion_ctr=ctr)

        if (k % opt.TRAVERSE_DEBUG_PRINT_HZ) == 0 and not eval_mode:
            elapsed = time.time() - t0
            print(
                "[WORKER #{}] done with {}/{} traversals | recursion depth={} | advt={} strt={} | elapsed={} sec"
                .format(worker_id, k, num_traversals_per_worker, ctr[0],
                        advt_mem.size(), strt_mem.size(), elapsed))

    # Save all the buffers one last time.
    print("[WORKER #{}] Final autosave ...".format(worker_id))
    if advt_mem is not None: advt_mem.autosave()
    if strt_mem is not None: strt_mem.autosave()
class AgentDQN_family:
    """ Agent Class (Network) for DDQN
	"""
    def __init__(self,
                 state_dim,
                 action_dim,
                 batchSize=64,
                 lr=.0001,
                 tau=.05,
                 gamma=.95,
                 epsilon=1,
                 eps_dec=.99,
                 learnInterval=1,
                 isDual=False,
                 isDueling=False,
                 isPER=False,
                 filename='model',
                 mem_size=1000000,
                 layerCount=2,
                 layerUnits=64,
                 usePruning=False):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.isDueling = isDueling
        self.isDual = isDual
        self.isPER = isPER
        self.lr = lr
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = eps_dec
        self.batchSize = batchSize
        self.filename = filename
        self.learnInterval = learnInterval
        # Initialize Deep Q-Network
        self.model = generateDQN(action_dim, lr, state_dim, isDueling,
                                 layerCount, layerUnits, usePruning)
        # Build target Q-Network
        self.target_model = generateDQN(action_dim, lr, state_dim, isDueling,
                                        layerCount, layerUnits, usePruning)
        self.layerCount = layerCount
        self.layerUnits = layerUnits
        self.target_model.set_weights(self.model.get_weights())
        self.memory = MemoryBuffer(mem_size, isPER)
        self.epsilonInitial = epsilon
        self.minEpsilon = .1
        self.usePruning = usePruning

        if isDual:
            self.tau = tau
        else:
            self.tau = 1.0

        # Remember the buffer position of the last learning step (used with learnInterval)
        self.lastLearnIndex = self.memory.totalMemCount

    def chooseAction(self, state):
        """ Apply an espilon-greedy policy to pick next action
		"""
        if random() <= self.epsilon:
            return randrange(self.action_dim)
        else:
            # model.predict returns an array of shape (1, action_dim); argmax
            # over it gives the greedy action.
            return np.argmax(
                self.model.predict(self.reshape(state, debug=False)))

    def learn(self, numLearns=1):
        """ Train Q-network on batch sampled from the buffer
		"""
        if self.memory.getSize() < self.batchSize:
            # Don't train until there are enough stored transitions to fill a batch.
            return
        if self.memory.totalMemCount - self.lastLearnIndex < self.learnInterval:
            # print("THIS SHOULD NEVER HAPPEN ON NORMAL RUNS, UNLESS WE TERMINATED DUE TO COMPLETEING MISSION")
            return
        self.lastLearnIndex = self.memory.totalMemCount

        for localCounter in range(numLearns):
            # Sample experience from memory buffer (optionally with PER)
            s, a, r, d, new_s, idx = self.memory.sample_batch(self.batchSize)

            # Apply Bellman Equation on batch samples to train our DDQN
            #q = self.model.predict(self.reshape(s,debug=True))
            #next_q = self.model.predict(self.reshape(new_s))
            #q_targ = self.target_model.predict(self.reshape(new_s))
            q = self.model.predict(s)
            next_q = self.model.predict(new_s)
            q_targ = self.target_model.predict(new_s)

            # Row indices of the samples in the batch, i.e. [0, 1, 2, ...].
            batch_index = np.arange(self.batchSize, dtype=np.int32)
            #q_target[batch_index, sampleActions] = sampleRewards + self.gamma*np.max(q_next_predictions, axis=1)*sampleTerminations

            #old_q = q[i, a[i]]
            #if d[i]:
            #	q[i, a[i]] = r[i]
            #else:
            #	next_best_action = np.argmax(next_q[i,:])
            #	q[i, a[i]] = r[i] + self.gamma * q_targ[i, next_best_action]
            #if(self.isPER):
            # Update PER Sum Tree
            #	self.buffer.update(idx[i], abs(old_q - q[i, a[i]]))

            # Double-DQN target: the online network selects the best next action,
            # the target network evaluates it; bootstrap only for non-terminal
            # transitions (d is the done flag, matching the per-sample version below).
            q[batch_index, a] = (
                r + self.gamma * (1.0 - d) * q_targ[batch_index,
                                                    np.argmax(next_q, axis=1)])
            """
			for i in range(s.shape[0]):
				old_q = q[i, a[i]]
				if d[i]:
					q[i, a[i]] = r[i]
				else:
					next_best_action = np.argmax(next_q[i,:])
					q[i, a[i]] = r[i] + self.gamma * q_targ[i, next_best_action]
				if(self.isPER):
					# Update PER Sum Tree
					self.buffer.update(idx[i], abs(old_q - q[i, a[i]]))
			"""
            # Train on batch
            self.model.fit(s, q, epochs=1,
                           verbose=0)  # do we really reshape S here???
            # Soft-update the target network after each training batch
            self.updateWeights()
        # Decay epsilon once per learn() call
        if self.epsilon > self.minEpsilon:
            self.epsilon *= self.epsilon_decay

    def huber_loss(self, y_true, y_pred):
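        # Pseudo-Huber loss: sqrt(1 + err^2) - 1, a smooth approximation of |err|
        # that is quadratic near zero and roughly linear for large errors.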
        return K.mean(K.sqrt(1 + K.square(y_pred - y_true)) - 1, axis=-1)

    def updateWeights(self):
        """ Transfer Weights from Model to Target at rate Tau
		"""
        W = self.model.get_weights()
        tgt_W = self.target_model.get_weights()
        for i in range(len(W)):
            tgt_W[i] = self.tau * W[i] + (1 - self.tau) * tgt_W[i]
        self.target_model.set_weights(tgt_W)

    def saveMemory(self, state, action, reward, done, new_state):
        if (self.isPER):
            q_val = self.model.predict(self.reshape(state))
            q_val_t = self.target_model.predict(self.reshape(new_state))
            next_best_action = np.argmax(
                self.model.predict(self.reshape(new_state)))
            new_val = reward + self.gamma * q_val_t[0, next_best_action]
            td_error = abs(new_val - q_val)[0]
        else:
            td_error = 0
        self.memory.saveMemory(state, action, reward, done, new_state,
                               td_error)

    def reshape(self, x, debug=False):
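        # Inserts a leading batch axis when a single observation is passed, so it
        # matches the input shape model.predict expects; inputs that already have
        # a batch axis are returned unchanged.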
        if debug:
            print("RESHAPING: x was: " + str(x.shape) +
                  ", and has len(shape): " + str(len(x.shape)))
        if len(x.shape) < 4 and len(self.state_dim) > 2:
            return np.expand_dims(x, axis=0)
        elif len(x.shape) < 3:
            if debug: y = np.expand_dims(x, axis=0)
            if debug:
                print("A: Now x is: " + str(y.shape) +
                      ", and has len(shape): " + str(len(y.shape)))
            if debug: breakthis = idontexist
            return np.expand_dims(x, axis=0)
        else:
            if debug:
                print("B: Now x is: " + str(x.shape) +
                      ", and has len(shape): " + str(len(x.shape)))
            if debug: breakthis = idontexist
            return x

    def load_model(self):
        self.model = load_model(self.filename + '.h5')
        self.target_model = load_model(self.filename + '.h5')

    def getFilename(self,
                    filename=None,
                    filenameAppendage=None,
                    intelligentFilename=True,
                    directory=None):
        if directory is not None:
            filename = directory + "/"
        else:
            filename = ""
        if intelligentFilename:
            if self.isDueling and self.isDual:
                filename += "D3QN"
            elif self.isDueling:
                filename += "DuelingDQN"
            elif self.isDual:
                filename += "DDQN"
            else:
                filename += "DQN"
            filename += ("_lr" + str(self.lr) + "_LI" +
                         str(self.learnInterval) + '_bs' +
                         str(self.batchSize) + '_g' + str(self.gamma) + '_e' +
                         str(self.epsilonInitial) + '_t' + str(self.tau) +
                         "_network" + str(self.layerCount) + "x" +
                         str(self.layerUnits) + "_")
            filename += self.filename
            if self.usePruning:
                filename += "PRUNED_"
            if filenameAppendage is not None:
                filename += filenameAppendage

        else:
            if filename is None:
                if filenameAppendage is None:
                    filename += self.filename
                else:
                    filename += self.filename + filenameAppendage
            if self.isDueling:
                filename += "_dueling"
        return (filename)

    # are we supposed to save the target model or the training model?
    def save_model(self,
                   filename=None,
                   filenameAppendage=None,
                   intelligentFilename=True,
                   directory=None):
        filename = self.getFilename(filename=filename,
                                    filenameAppendage=filenameAppendage,
                                    intelligentFilename=intelligentFilename,
                                    directory=directory)
        self.model.save(filename + '.h5')