Example #1
    def train(self):
        """
        Method that trains a model with tf.Estimator using the parameters defined in the constructor.

        """
        if self.log_time is True:
            self.time_logger = TimeLogger([
                "Fetch Data    ", "Parse Data    ", "To Dataset    ",
                "Build Model   ", "Compute Loss  ", "Estimator     ",
                "Save Model    "
            ])
        print("-> Starting training...")
        for epoch in range(self.train_epochs):
            self.estimator.train(input_fn=self.train_input_fn,
                                 steps=self.train_steps)

            if self.log_time is True: self.time_logger.log(5)

            if epoch > 0 and epoch % self.period == 0:
                self.export_model()

            if self.log_time is True: self.time_logger.log(6)

            if self.log_time is True: self.time_logger.print_logs()
        print("-> Done!")
Example #2
    #LOAD MODEL
    model = DQN_Model(input_shape=env.observation_space.shape,
                      num_actions=env.action_space.n,
                      conv_layer_params=CONV_LAYER_PARAMS,
                      fc_layer_params=FC_LAYER_PARAMS,
                      learning_rate=LEARNING_RATE)

    #GLOBAL ITERATOR
    global_i = cbt_global_iterator(cbt_table)
    print("global_i = {}".format(global_i))

    if args.log_time is True:
        time_logger = TimeLogger([
            "0Collect Data", "1Take Action / conv b",
            "2Append new obs / conv b", "3Generate Visual obs keys",
            "4Build pb2 objects traj", "5Write Cells visual", "6Batch visual",
            "7Write cells traj", "8Write cells traj"
        ],
                                 num_cycles=args.num_episodes)

    #COLLECT DATA FOR CBT
    print("-> Starting data collection...")
    rows, visual_obs_rows = [], []
    for cycle in range(args.num_cycles):
        gcs_load_weights(model, gcs_bucket, args.prefix,
                         args.tmp_weights_filepath)
        for i in tqdm(range(args.num_episodes), "Cycle {}".format(cycle)):
            if args.log_time is True: time_logger.reset()

            #CREATE ROW_KEY
            row_key_i = i + global_i + (cycle * args.num_episodes)
    #LOAD MODEL
    model = DQN_Model(input_shape=env.observation_space.shape,
                      num_actions=env.action_space.n,
                      conv_layer_params=CONV_LAYER_PARAMS,
                      fc_layer_params=FC_LAYER_PARAMS,
                      learning_rate=LEARNING_RATE)

    #GLOBAL ITERATOR
    global_i = cbt_global_iterator(cbt_table)
    print("global_i = {}".format(global_i))

    #INITIALIZE EXECUTION TIME LOGGER
    if args.log_time is True:
        time_logger = TimeLogger([
            "Load Weights     ", "Run Environment  ", "Data To Bytes    ",
            "Write Cells      ", "Mutate Rows      "
        ])

    #COLLECT DATA FOR CBT
    print("-> Starting data collection...")
    rows = []
    for cycle in range(args.num_cycles):
        if args.log_time is True: time_logger.reset()

        gcs_load_weights(model, gcs_bucket, args.prefix,
                         args.tmp_weights_filepath)

        if args.log_time is True: time_logger.log("Load Weights     ")

        for i in tqdm(range(args.num_episodes), "Cycle {}".format(cycle)):
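The row keys generated in this example (row_key_i = i + global_i + cycle * num_episodes) continue from the stored global iterator, so every episode in every cycle receives a unique, contiguous key. A small worked illustration with placeholder values:

# Placeholder values for illustration; global_i would come from cbt_global_iterator().
global_i = 100
num_episodes = 4
for cycle in range(2):
    for i in range(num_episodes):
        row_key_i = i + global_i + (cycle * num_episodes)
        print("cycle", cycle, "episode", i, "-> row key", row_key_i)
# cycle 0 yields keys 100-103, cycle 1 yields keys 104-107.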
Example #4
                    }
                }
                data = json.dumps(items_to_json,
                                  separators=(',', ':'),
                                  cls=NumpyEncoder)
                json.dump(data, fd)
                fd.write("\n")
                i += 1

    #GLOBAL ITERATOR
    global_i = cbt_global_iterator(cbt_table)
    print("global_i = {}".format(global_i))

    if args.log_time is True:
        time_logger = TimeLogger(
            ["Collect Data", "Serialize Data", "Write Cells", "Mutate Rows"],
            num_cycles=args.num_episodes)

    #COLLECT DATA FOR CBT
    print("-> Starting data collection...")
    rows = []
    for cycle in range(args.num_cycles):
        #gcs_load_weights(model, gcs_bucket, args.prefix, args.tmp_weights_filepath)
        for i in tqdm(range(args.num_episodes), "Cycle {}".format(cycle)):
            if args.log_time is True: time_logger.reset()

            #RL LOOP GENERATES A TRAJECTORY
            data = collect_data(env, random_policy, steps=args.max_steps)
            print("data: ", data[1].action_step.action.numpy())
            write_data(data, "Rab-Agent_: " + (str(cycle)) + "_" + (str(i)))
    env.close()
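The NumpyEncoder passed as cls to json.dumps above is not shown in this snippet. A minimal sketch of the kind of encoder assumed here, converting NumPy scalars and arrays into JSON-serializable Python types:

import json
import numpy as np

class NumpyEncoder(json.JSONEncoder):
    """Hypothetical encoder matching the cls=NumpyEncoder usage above."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)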
    #LOAD MODEL
    model = DQN_Model(input_shape=VISUAL_OBS_SPEC,
                      num_actions=NUM_ACTIONS,
                      conv_layer_params=CONV_LAYER_PARAMS,
                      fc_layer_params=FC_LAYER_PARAMS,
                      learning_rate=LEARNING_RATE)
    gcs_load_weights(model, gcs_bucket, args.prefix, args.tmp_weights_filepath)

    #SETUP TENSORBOARD/LOGGING
    train_log_dir = os.path.join(args.output_dir, 'logs/')
    os.makedirs(os.path.dirname(train_log_dir), exist_ok=True)
    loss_metrics = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    if args.log_time is True:
        time_logger = TimeLogger(
            ["Fetch Data", "Parse Data", "Compute Loss", "Generate Grads"],
            num_cycles=args.train_steps)

    #TRAINING LOOP
    train_step = 0
    exp_buff = ExperienceBuffer(args.buffer_size)
    print("-> Starting training...")
    for epoch in range(args.train_epochs):
        if args.log_time is True: time_logger.reset()

        #FETCH DATA
        global_i = cbt_global_iterator(cbt_table)
        rows = cbt_read_rows(cbt_table, args.prefix, args.train_steps,
                             global_i)

        if args.log_time is True: time_logger.log(0)
    #LOAD MODEL
    model = DQN_Model(input_shape=VECTOR_OBS_SPEC,
                      num_actions=NUM_ACTIONS,
                      fc_layer_params=FC_LAYER_PARAMS,
                      learning_rate=LEARNING_RATE)
    gcs_load_weights(model, gcs_bucket, args.prefix, args.tmp_weights_filepath)

    #SETUP TENSORBOARD/LOGGING
    train_log_dir = os.path.join(args.output_dir, 'logs/')
    os.makedirs(os.path.dirname(train_log_dir), exist_ok=True)
    loss_metrics = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    if args.log_time is True:
        time_logger = TimeLogger(
            ["Parse Data", "Format Data", "Compute Loss", "Generate Grads"],
            num_cycles=args.train_steps)

    #TRAINING LOOP
    train_step = 0
    exp_buff = ExperienceBuffer(args.buffer_size)
    print("-> Starting training...")
    for epoch in range(args.train_epochs):
        if args.log_time is True:
            time_logger.reset()

        #FETCH DATA
        global_i = cbt_global_iterator(cbt_table)
        rows = cbt_read_rows(cbt_table, args.prefix, args.train_steps,
                             global_i)
        for row in tqdm(
Example #7
class DQN_Agent():
    """
    Class for controlling and managing training from a bigtable database.
      
    Attributes:
        cbt_table (google.cloud.bigtable.Table): Bigtable table object returned from [util.gcp_io.cbt_load_table].
        gcs_bucket (google.cloud.storage.Bucket): GCS bucket object returned from [util.gcp_io.gcs_load_bucket].
        gcs_bucket_id (str): Global name of the GCS bucket where the model will be saved/loaded.
        prefix (str): Prefix used for model and trajectory names.
        tmp_weights_filepath (str): Temporary local path for saving model before copying to GCS.
        buffer_size (int): Max size of the experience buffer.
        batch_size (int): Batch size for estimator.
        train_epochs (int): Number of cycles of querying bigtable and training.
        train_steps (int): Number of train steps per epoch.
        period (int): Interval for saving models.
        output_dir (str): Output directory for logs and models.
        log_time (bool): Flag for time logging.
        num_gpus (int): Number of gpu devices for estimator.
    """
    def __init__(self,
                 cbt_table,
                 gcs_bucket,
                 gcs_bucket_id,
                 prefix,
                 tmp_weights_filepath,
                 buffer_size,
                 batch_size,
                 train_epochs,
                 train_steps,
                 period,
                 output_dir=None,
                 log_time=False,
                 num_gpus=0):
        """
        The constructor for the DQN_Agent class.

        """
        self.cbt_table = cbt_table
        self.gcs_bucket = gcs_bucket
        self.gcs_bucket_id = gcs_bucket_id
        self.prefix = prefix
        self.tmp_weights_filepath = tmp_weights_filepath
        self.exp_buff = ExperienceBuffer(buffer_size)
        self.batch_size = batch_size
        self.train_epochs = train_epochs
        self.train_steps = train_steps
        self.period = period
        self.output_dir = output_dir
        self.log_time = log_time

        distribution_strategy = get_distribution_strategy(
            distribution_strategy="default", num_gpus=num_gpus)
        run_config = tf.estimator.RunConfig(
            train_distribute=distribution_strategy)
        data_format = ('channels_first'
                       if tf.test.is_built_with_cuda() else 'channels_last')
        model_dir = os.path.join(self.output_dir, 'models/')
        self.estimator = tf.estimator.Estimator(
            model_fn=self.model_fn,
            model_dir=model_dir,
            config=run_config,
            params={'data_format': data_format})

    def model_fn(self, features, labels, mode, params):
        """
        Model function to be passed as the model_fn argument to tf.estimator.Estimator.
  
        Parameters: 
           features (tuple): (obs, next_obs) S and S' of a (S,A,R,S') transition.
           labels (tuple): (actions, rewards, next_mask) A and R of a transition, plus a mask for bootstrapping.
           mode (tf.estimator.ModeKeys): Mode key that determines which op is returned (currently always TRAIN).
           params (dict): Optional dictionary of parameters for building custom models. (not currently implemented)
        """
        #BUILD MODEL
        model = DQN_Model(input_shape=VISUAL_OBS_SPEC,
                          num_actions=NUM_ACTIONS,
                          conv_layer_params=CONV_LAYER_PARAMS,
                          fc_layer_params=FC_LAYER_PARAMS,
                          learning_rate=LEARNING_RATE)

        ckpt = tf.train.Checkpoint(step=tf.compat.v1.train.get_global_step(),
                                   optimizer=model.opt,
                                   net=model)

        if self.log_time is True: self.time_logger.log(3)

        (obs, next_obs) = features
        (actions, rewards, next_mask) = labels

        #COMPUTE LOSS
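        # One-step TD target: q_target = r + GAMMA * max_a' Q(s', a'); next_mask zeroes
        # the bootstrapped term for terminal transitions.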
        with tf.GradientTape() as tape:
            q_pred, q_next = model(obs), model(next_obs)
            one_hot_actions = tf.one_hot(actions, NUM_ACTIONS)
            q_pred = tf.reduce_sum(q_pred * one_hot_actions, axis=-1)
            q_next = tf.reduce_max(q_next, axis=-1)
            q_next = tf.cast(q_next, dtype=tf.float64) * next_mask
            q_target = rewards + tf.multiply(
                tf.constant(GAMMA, dtype=tf.float64), q_next)
            loss = tf.reduce_sum(model.loss(q_target, q_pred))

        #GENERATE GRADIENTS
        total_grads = tape.gradient(loss, model.trainable_variables)
        grads_op = model.opt.apply_gradients(
            zip(total_grads, model.trainable_variables),
            tf.compat.v1.train.get_global_step())
        train_op = grads_op

        if self.log_time is True: self.time_logger.log(4)

        #RUN ESTIMATOR IN TRAIN MODE
        return tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            predictions=q_pred,
            loss=loss,
            train_op=train_op,
            scaffold=tf.compat.v1.train.Scaffold(saver=ckpt))

    def train_input_fn(self):
        """
        Input function to be passed to estimator.train().

        Reads a single row from bigtable at a time until an experience buffer is filled to a specified buffer_size.

        """
        if self.log_time is True: self.time_logger.reset()

        global_i = cbt_global_iterator(self.cbt_table) - 1
        i = 0
        self.exp_buff.reset()
        while True:
            #FETCH ROW FROM CBT
            row_i = global_i - i
            row = cbt_read_row(self.cbt_table, self.prefix, row_i)

            if self.log_time is True: self.time_logger.log(0)

            #DESERIALIZE DATA
            bytes_traj = row.cells['trajectory']['traj'.encode()][0].value
            bytes_info = row.cells['trajectory']['info'.encode()][0].value
            traj, info = Trajectory(), Info()
            traj.ParseFromString(bytes_traj)
            info.ParseFromString(bytes_info)

            #FORMAT DATA
            obs_shape = np.append(info.num_steps,
                                  info.visual_obs_spec).astype(int)
            obs = np.asarray(traj.visual_obs).reshape(obs_shape)

            self.exp_buff.add_trajectory(obs, traj.actions, traj.rewards,
                                         info.num_steps)

            if self.log_time is True: self.time_logger.log(1)

            if self.exp_buff.size >= self.exp_buff.max_size: break
            i += 1
        print("-> Fetched trajectories {} - {}".format(global_i - i, global_i))

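        # Each dataset element is ((obs, next_obs), (action, reward, next_mask)),
        # matching the (features, labels) tuples unpacked in model_fn.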
        dataset = tf.data.Dataset.from_tensor_slices(
            ((self.exp_buff.obs, self.exp_buff.next_obs),
             (self.exp_buff.actions, self.exp_buff.rewards,
              self.exp_buff.next_mask)))
        dataset = dataset.shuffle(self.exp_buff.max_size).repeat().batch(
            self.batch_size)

        if self.log_time is True: self.time_logger.log(2)

        return dataset

    def export_model(self):
        """
        Method that saves the latest checkpoint to gcs_bucket.

        """
        model_path = 'gs://' + self.gcs_bucket_id + '/' + self.prefix + '_model'
        latest_checkpoint = self.estimator.latest_checkpoint()
        print(latest_checkpoint)
        all_checkpoint_files = tf.io.gfile.glob(latest_checkpoint + '*')
        for filename in all_checkpoint_files:
            suffix = filename.partition(latest_checkpoint)[2]
            destination_path = model_path + suffix
            print('Copying {} to {}'.format(filename, destination_path))
            tf.io.gfile.copy(filename, destination_path, overwrite=True)

    def train(self):
        """
        Method that trains a model with tf.Estimator using the parameters defined in the constructor.

        """
        if self.log_time is True:
            self.time_logger = TimeLogger([
                "Fetch Data    ", "Parse Data    ", "To Dataset    ",
                "Build Model   ", "Compute Loss  ", "Estimator     ",
                "Save Model    "
            ])
        print("-> Starting training...")
        for epoch in range(self.train_epochs):
            self.estimator.train(input_fn=self.train_input_fn,
                                 steps=self.train_steps)

            if self.log_time is True: self.time_logger.log(5)

            if epoch > 0 and epoch % self.period == 0:
                self.export_model()

            if self.log_time is True: self.time_logger.log(6)

            if self.log_time is True: self.time_logger.print_logs()
        print("-> Done!")
Example #8
class DQN_Agent():
    """
    Class for controlling and managing training from a bigtable database.
      
    Attributes:
        cbt_table (google.cloud.bigtable.Table): Bigtable table object returned from [util.gcp_io.cbt_load_table].
        gcs_bucket (google.cloud.storage.Bucket): GCS bucket object returned from [util.gcp_io.gcs_load_bucket].
        gcs_bucket_id (str): Global name of the GCS bucket where the model will be saved/loaded.
        prefix (str): Prefix used for model and trajectory names.
        tmp_weights_filepath (str): Temporary local path for saving model before copying to GCS.
        buffer_size (int): Max size of the experience buffer.
        batch_size (int): Batch size for estimator.
        train_epochs (int): Number of cycles of querying bigtable and training.
        train_steps (int): Number of train steps per epoch.
        period (int): Interval for saving models.
        output_dir (str): Output directory for logs and models.
        log_time (bool): Flag for time logging.
        num_gpus (int): Number of gpu devices for estimator.
    """
    def __init__(self, **kwargs):
        """
        The constructor for the DQN_Agent class.

        """
        hyperparams = kwargs['hyperparams']
        self.input_shape = hyperparams['input_shape']
        self.num_actions = hyperparams['num_actions']
        self.gamma = hyperparams['gamma']
        self.cbt_table = kwargs['cbt_table']
        self.gcs_bucket = kwargs['gcs_bucket']
        self.gcs_bucket_id = kwargs['gcs_bucket_id']
        self.prefix = kwargs['prefix']
        self.tmp_weights_filepath = kwargs['tmp_weights_filepath']
        self.batch_size = kwargs['batch_size']
        self.num_trajectories = kwargs['num_trajectories']
        self.train_epochs = kwargs['train_epochs']
        self.train_steps = kwargs['train_steps']
        self.period = kwargs['period']
        self.output_dir = kwargs['output_dir']
        self.log_time = kwargs['log_time']
        self.num_gpus = kwargs['num_gpus']
        self.tpu_name = kwargs['tpu_name']
        self.wandb = kwargs['wandb']
        self.exp_buff = ExperienceBuffer(kwargs['buffer_size'])

        if self.tpu_name is not None:
            self.distribution_strategy = get_distribution_strategy(
                distribution_strategy='tpu', tpu_address=self.tpu_name)
            self.device = '/job:worker'
        else:
            self.distribution_strategy = get_distribution_strategy(
                distribution_strategy='default', num_gpus=self.num_gpus)
            self.device = None
        with tf.device(self.device), self.distribution_strategy.scope():
            self.model = DQN_Model(
                input_shape=self.input_shape,
                num_actions=self.num_actions,
                conv_layer_params=hyperparams['conv_layer_params'],
                fc_layer_params=hyperparams['fc_layer_params'],
                learning_rate=hyperparams['learning_rate'])
        gcs_load_weights(self.model, self.gcs_bucket, self.prefix,
                         self.tmp_weights_filepath)

    def fill_experience_buffer(self):
        """
        Method that fills the experience buffer object from CBT.

        Reads a batch of rows and parses them until the experience buffer reaches buffer_size.

        """
        self.exp_buff.reset()

        if self.log_time is True: self.time_logger.reset()

        #FETCH DATA
        global_i = cbt_global_iterator(self.cbt_table)
        rows = cbt_read_rows(self.cbt_table, self.prefix,
                             self.num_trajectories, global_i)

        if self.log_time is True: self.time_logger.log("Fetch Data      ")

        for row in tqdm(
                rows, "Parsing trajectories {} - {}".format(
                    global_i - self.num_trajectories, global_i - 1)):
            #DESERIALIZE DATA
            bytes_obs = row.cells['trajectory']['obs'.encode()][0].value
            bytes_actions = row.cells['trajectory'][
                'actions'.encode()][0].value
            bytes_rewards = row.cells['trajectory'][
                'rewards'.encode()][0].value

            if self.log_time is True: self.time_logger.log("Parse Bytes     ")

            #FORMAT DATA
            actions = np.frombuffer(bytes_actions,
                                    dtype=np.uint8).astype(np.int32)
            rewards = np.frombuffer(bytes_rewards, dtype=np.float32)
            num_steps = actions.size
            obs_shape = np.append(num_steps, self.input_shape).astype(np.int32)
            obs = np.frombuffer(bytes_obs, dtype=np.float32).reshape(obs_shape)

            if self.log_time is True: self.time_logger.log("Format Data     ")

            self.exp_buff.add_trajectory(obs, actions, rewards, num_steps)

            if self.log_time is True: self.time_logger.log("Add To Exp_Buff ")
        self.exp_buff.preprocess()

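        # Each dataset element is ((obs, next_obs), (action, reward, next_mask)),
        # the structure unpacked by step_fn in train() below.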
        dataset = tf.data.Dataset.from_tensor_slices(
            ((self.exp_buff.obs, self.exp_buff.next_obs),
             (self.exp_buff.actions, self.exp_buff.rewards,
              self.exp_buff.next_mask)))
        dataset = dataset.shuffle(self.exp_buff.max_size).repeat().batch(
            self.batch_size)

        dist_dataset = self.distribution_strategy.experimental_distribute_dataset(
            dataset)

        if self.log_time is True: self.time_logger.log("To Dataset      ")

        return dist_dataset

    def train(self):
        """
        Method that trains a model using the parameters defined in the constructor.

        """
        @tf.function
        def train_step(dist_inputs):
            def step_fn(inputs):
                ((b_obs, b_next_obs), (b_actions, b_rewards,
                                       b_next_mask)) = inputs

                with tf.GradientTape() as tape:
                    q_pred, q_next = self.model(b_obs), self.model(b_next_obs)
                    one_hot_actions = tf.one_hot(b_actions, self.num_actions)
                    q_pred = tf.reduce_sum(q_pred * one_hot_actions, axis=-1)
                    q_next = tf.reduce_max(q_next, axis=-1)
                    # Zero the bootstrapped term on terminal transitions using the mask.
                    q_next = q_next * tf.cast(b_next_mask, dtype=tf.float32)
                    q_target = b_rewards + (
                        tf.constant(self.gamma, dtype=tf.float32) * q_next)
                    mse = self.model.loss(q_target, q_pred)
                    loss = tf.reduce_sum(mse)

                total_grads = tape.gradient(loss, self.model.trainable_weights)
                self.model.opt.apply_gradients(
                    list(zip(total_grads, self.model.trainable_weights)))
                return mse

            per_example_losses = self.distribution_strategy.experimental_run_v2(
                step_fn, args=(dist_inputs, ))
            mean_loss = self.distribution_strategy.reduce(
                tf.distribute.ReduceOp.MEAN, per_example_losses, axis=None)
            return mean_loss

        if self.log_time is True:
            self.time_logger = TimeLogger([
                "Fetch Data      ", "Parse Bytes     ", "Format Data     ",
                "Add To Exp_Buff ", "To Dataset      ", "Train Step      ",
                "Save Model      "
            ])
        print("-> Starting training...")
        for epoch in range(self.train_epochs):
            with tf.device(self.device), self.distribution_strategy.scope():
                dataset = self.fill_experience_buffer()
                exp_buff = iter(dataset)

                losses = []
                for step in tqdm(range(self.train_steps),
                                 "Training epoch {}".format(epoch)):
                    loss = train_step(next(exp_buff))
                    losses.append(loss)

                    if self.log_time is True:
                        self.time_logger.log("Train Step      ")

                if self.wandb is not None:
                    mean_loss = np.mean(losses)
                    tf.summary.scalar("Mean Loss", mean_loss, epoch)
                    self.wandb.log({"Epoch": epoch, "Mean Loss": mean_loss})

            if epoch > 0 and epoch % self.period == 0:
                model_filename = self.prefix + '_model.h5'
                gcs_save_weights(self.model, self.gcs_bucket,
                                 self.tmp_weights_filepath, model_filename)

            if self.log_time is True: self.time_logger.log("Save Model      ")

            if self.log_time is True: self.time_logger.print_totaltime_logs()
        print("-> Done!")
Example #9
    def train(self):
        """
        Method that trains a model using the parameters defined in the constructor.

        """
        @tf.function
        def train_step(dist_inputs):
            def step_fn(inputs):
                ((b_obs, b_next_obs), (b_actions, b_rewards,
                                       b_next_mask)) = inputs

                with tf.GradientTape() as tape:
                    q_pred, q_next = self.model(b_obs), self.model(b_next_obs)
                    one_hot_actions = tf.one_hot(b_actions, self.num_actions)
                    q_pred = tf.reduce_sum(q_pred * one_hot_actions, axis=-1)
                    q_next = tf.reduce_max(q_next, axis=-1)
                    # Zero the bootstrapped term on terminal transitions using the mask.
                    q_next = q_next * tf.cast(b_next_mask, dtype=tf.float32)
                    q_target = b_rewards + (
                        tf.constant(self.gamma, dtype=tf.float32) * q_next)
                    mse = self.model.loss(q_target, q_pred)
                    loss = tf.reduce_sum(mse)

                total_grads = tape.gradient(loss, self.model.trainable_weights)
                self.model.opt.apply_gradients(
                    list(zip(total_grads, self.model.trainable_weights)))
                return mse

            per_example_losses = self.distribution_strategy.experimental_run_v2(
                step_fn, args=(dist_inputs, ))
            mean_loss = self.distribution_strategy.reduce(
                tf.distribute.ReduceOp.MEAN, per_example_losses, axis=None)
            return mean_loss

        if self.log_time is True:
            self.time_logger = TimeLogger([
                "Fetch Data      ", "Parse Bytes     ", "Format Data     ",
                "Add To Exp_Buff ", "To Dataset      ", "Train Step      ",
                "Save Model      "
            ])
        print("-> Starting training...")
        for epoch in range(self.train_epochs):
            with tf.device(self.device), self.distribution_strategy.scope():
                dataset = self.fill_experience_buffer()
                exp_buff = iter(dataset)

                losses = []
                for step in tqdm(range(self.train_steps),
                                 "Training epoch {}".format(epoch)):
                    loss = train_step(next(exp_buff))
                    losses.append(loss)

                    if self.log_time is True:
                        self.time_logger.log("Train Step      ")

                if self.wandb is not None:
                    mean_loss = np.mean(losses)
                    tf.summary.scalar("Mean Loss", mean_loss, epoch)
                    self.wandb.log({"Epoch": epoch, "Mean Loss": mean_loss})

            if epoch > 0 and epoch % self.period == 0:
                model_filename = self.prefix + '_model.h5'
                gcs_save_weights(self.model, self.gcs_bucket,
                                 self.tmp_weights_filepath, model_filename)

            if self.log_time is True: self.time_logger.log("Save Model      ")

            if self.log_time is True: self.time_logger.print_totaltime_logs()
        print("-> Done!")
Example #10
                                              args.bucket_id, credentials)

    #LOAD MODEL
    model = DQN_Model(input_shape=VECTOR_OBS_SPEC,
                      num_actions=NUM_ACTIONS,
                      fc_layer_params=FC_LAYER_PARAMS,
                      learning_rate=LEARNING_RATE)
    # gcs_load_weights(model, gcs_bucket, args.prefix, args.tmp_weights_filepath)

    #SETUP TENSORBOARD/LOGGING
    train_log_dir = os.path.join(args.output_dir, 'logs/')
    os.makedirs(os.path.dirname(train_log_dir), exist_ok=True)
    loss_metrics = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    if args.log_time is True:
        time_logger = TimeLogger(
            ['Fetch Data', 'Parse Data', 'Compute Loss', 'Generate Grads'])

    #TRAINING LOOP
    train_step = 0
    print("-> Starting training...")
    for epoch in range(args.train_epochs):
        if args.log_time is True:
            time_logger.reset()

        #FETCH DATA
        global_i = cbt_global_iterator(cbt_table)
        rows = cbt_read_rows(cbt_table, args.prefix, args.train_steps,
                             global_i)

        if args.log_time is True:
            time_logger.log('Fetch Data')