Example #1
    def __init__(self, **kwargs):
        """
        The constructor for DQN_Agent class.

        """
        hyperparams = kwargs['hyperparams']
        self.input_shape = hyperparams['input_shape']
        self.num_actions = hyperparams['num_actions']
        self.gamma = hyperparams['gamma']
        self.update_horizon = hyperparams['update_horizon']
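        # gamma^0 .. gamma^(n-1) weight the next n rewards of the n-step
        # return; gamma^n discounts the bootstrapped target-network value.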
        self.future_discounts = np.power(self.gamma,
                                         range(self.update_horizon))
        self.bootstrap_discount = np.power(self.gamma, self.update_horizon)

        self.cbt_table = kwargs['cbt_table']
        self.gcs_bucket = kwargs['gcs_bucket']
        self.gcs_bucket_id = kwargs['gcs_bucket_id']
        self.prefix = kwargs['prefix']
        self.tmp_weights_filepath = kwargs['tmp_weights_filepath']
        self.output_dir = kwargs['output_dir']

        self.buffer_size = kwargs['buffer_size']
        self.batch_size = kwargs['batch_size']
        self.train_epochs = kwargs['train_epochs']
        self.train_steps = kwargs['train_steps']
        self.period = kwargs['period']

        self.log_time = kwargs['log_time']
        self.num_gpus = kwargs['num_gpus']
        self.tpu_name = kwargs['tpu_name']
        self.wandb = kwargs['wandb']

        self.exp_buff = ExperienceBuffer(self.buffer_size, self.update_horizon)

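        # Choose the distribution strategy: a TPU worker when a TPU address
        # is supplied, otherwise the default strategy over the local GPUs.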
        if self.tpu_name is not None:
            self.distribution_strategy = get_distribution_strategy(
                distribution_strategy='tpu', tpu_address=self.tpu_name)
            self.device = '/job:worker'
        else:
            self.distribution_strategy = get_distribution_strategy(
                distribution_strategy='default', num_gpus=self.num_gpus)
            self.device = None
        with tf.device(self.device), self.distribution_strategy.scope():
            self.model = DQN_Model(
                input_shape=self.input_shape,
                num_actions=self.num_actions,
                conv_layer_params=hyperparams['conv_layer_params'],
                fc_layer_params=hyperparams['fc_layer_params'],
                learning_rate=hyperparams['learning_rate'])
            self.target_model = DQN_Model(
                input_shape=self.input_shape,
                num_actions=self.num_actions,
                conv_layer_params=hyperparams['conv_layer_params'],
                fc_layer_params=hyperparams['fc_layer_params'],
                learning_rate=hyperparams['learning_rate'])
        gcs_load_weights(self.model, self.gcs_bucket, self.prefix,
                         self.tmp_weights_filepath)
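The two discount arrays built in the constructor implement an n-step TD
target. A minimal, self-contained sketch of how they combine (the numbers
below are made up for illustration and are not from the original source):

import numpy as np

gamma, n = 0.99, 5
future_discounts = np.power(gamma, range(n))   # [1, g, g^2, g^3, g^4]
bootstrap_discount = np.power(gamma, n)        # g^5

rewards = np.array([1.0, 0.0, 0.5, 0.0, 1.0])  # n consecutive rewards
bootstrap_value = 2.3                          # e.g. max_a Q_target(s_{t+n}, a)

# n-step target: sum_k g^k * r_{t+k}  +  g^n * max_a Q_target(s_{t+n}, a)
target = np.dot(future_discounts, rewards) + bootstrap_discount * bootstrap_value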
Example #2
    print("global_i = {}".format(global_i))

    if args.log_time:
        time_logger = TimeLogger([
            "0Collect Data", "1Take Action / conv b",
            "2Append new obs / conv b", "3Generate Visual obs keys",
            "4Build pb2 objects traj", "5Write Cells visual", "6Batch visual",
            "7Write cells traj", "8Write cells traj"
        ], num_cycles=args.num_episodes)

    # COLLECT DATA FOR CBT
    print("-> Starting data collection...")
    rows, visual_obs_rows = [], []
    for cycle in range(args.num_cycles):
        gcs_load_weights(model, gcs_bucket, args.prefix,
                         args.tmp_weights_filepath)
        for i in tqdm(range(args.num_episodes), "Cycle {}".format(cycle)):
            if args.log_time:
                time_logger.reset()

            # CREATE ROW_KEY
            row_key_i = i + global_i + (cycle * args.num_episodes)
            row_key = '{}_trajectory_{}'.format(args.prefix,
                                                row_key_i).encode()

            # RL LOOP GENERATES A TRAJECTORY
            observations, actions, rewards = [], [], []

            # Observations are cast to signed bytes ('b' == int8) before
            # being serialized into Bigtable rows.
            obs = np.asarray(env.reset()).astype(np.dtype('b'))

            reward = 0
            done = False
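The snippet is truncated here. For completeness, a typical continuation of
such a rollout loop is sketched below; `model.step_epsilon_greedy` and the
classic 4-tuple Gym `env.step` signature are assumptions, not the original
project's confirmed API:

            # Hypothetical continuation (sketch, not from the source):
            # roll out one episode, storing (obs, action, reward) triples.
            while not done:
                action = model.step_epsilon_greedy(obs, epsilon=0.1)
                observations.append(obs)
                actions.append(action)
                new_obs, reward, done, _ = env.step(action)
                rewards.append(reward)
                obs = np.asarray(new_obs).astype(np.dtype('b'))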