示例#1
0
    def train(self):
        """
        Run the Estimator training loop for `self.train_epochs` epochs.

        Every epoch calls `self.estimator.train` with `self.train_input_fn`
        for `self.train_steps` steps. On each epoch after the first whose
        index is a multiple of `self.period`, `self.export_model()` is called.

        When `self.log_time` is True, a `TimeLogger` is created before the
        loop; entry 5 is logged right after the estimator call, entry 6 after
        the export check, and the logs are printed at the end of each epoch.

        """
        timing_enabled = self.log_time is True
        if timing_enabled:
            stage_names = [
                "Fetch Data    ", "Parse Data    ", "To Dataset    ",
                "Build Model   ", "Compute Loss  ", "Estimator     ",
                "Save Model    "
            ]
            self.time_logger = TimeLogger(stage_names)

        print("-> Starting training...")
        for epoch_idx in range(self.train_epochs):
            self.estimator.train(input_fn=self.train_input_fn,
                                 steps=self.train_steps)
            if timing_enabled:
                self.time_logger.log(5)

            # Epoch 0 never exports; afterwards export once every `period` epochs.
            export_due = epoch_idx > 0 and epoch_idx % self.period == 0
            if export_due:
                self.export_model()

            if timing_enabled:
                # Logged whether or not an export happened this epoch.
                self.time_logger.log(6)
                self.time_logger.print_logs()
        print("-> Done!")
示例#2
0
    #LOAD MODEL
    model = DQN_Model(input_shape=env.observation_space.shape,
                      num_actions=env.action_space.n,
                      conv_layer_params=CONV_LAYER_PARAMS,
                      fc_layer_params=FC_LAYER_PARAMS,
                      learning_rate=LEARNING_RATE)

    #GLOBAL ITERATOR
    global_i = cbt_global_iterator(cbt_table)
    print("global_i = {}".format(global_i))

    if args.log_time is True:
        time_logger = TimeLogger([
            "0Collect Data", "1Take Action / conv b",
            "2Append new obs / conv b", "3Generate Visual obs keys",
            "4Build pb2 objects traj", "5Write Cells visual", "6Batch visual",
            "7Write cells traj", "8Write cells traj"
        ],
                                 num_cycles=args.num_episodes)

#COLLECT DATA FOR CBT
    print("-> Starting data collection...")
    rows, visual_obs_rows = [], []
    for cycle in range(args.num_cycles):
        gcs_load_weights(model, gcs_bucket, args.prefix,
                         args.tmp_weights_filepath)
        for i in tqdm(range(args.num_episodes), "Cycle {}".format(cycle)):
            if args.log_time is True: time_logger.reset()

            #CREATE ROW_KEY
            row_key_i = i + global_i + (cycle * args.num_episodes)
    #LOAD MODEL
    model = DQN_Model(input_shape=env.observation_space.shape,
                      num_actions=env.action_space.n,
                      conv_layer_params=CONV_LAYER_PARAMS,
                      fc_layer_params=FC_LAYER_PARAMS,
                      learning_rate=LEARNING_RATE)

    #GLOBAL ITERATOR
    global_i = cbt_global_iterator(cbt_table)
    print("global_i = {}".format(global_i))

    #INITIALIZE EXECUTION TIME LOGGER
    if args.log_time is True:
        time_logger = TimeLogger([
            "Load Weights     ", "Run Environment  ", "Data To Bytes    ",
            "Write Cells      ", "Mutate Rows      "
        ])

    #COLLECT DATA FOR CBT
    print("-> Starting data collection...")
    rows = []
    for cycle in range(args.num_cycles):
        if args.log_time is True: time_logger.reset()

        gcs_load_weights(model, gcs_bucket, args.prefix,
                         args.tmp_weights_filepath)

        if args.log_time is True: time_logger.log("Load Weights     ")

        for i in tqdm(range(args.num_episodes), "Cycle {}".format(cycle)):
示例#4
0
                    }
                }
                data = json.dumps(items_to_json,
                                  separators=(',', ':'),
                                  cls=NumpyEncoder)
                json.dump(data, fd)
                fd.write("\n")
                i = +1

    #GLOBAL ITERATOR
    global_i = cbt_global_iterator(cbt_table)
    print("global_i = {}".format(global_i))

    if args.log_time is True:
        time_logger = TimeLogger(
            ["Collect Data", "Serialize Data", "Write Cells", "Mutate Rows"],
            num_cycles=args.num_episodes)

    #COLLECT DATA FOR CBT
    print("-> Starting data collection...")
    rows = []
    for cycle in range(args.num_cycles):
        #gcs_load_weights(model, gcs_bucket, args.prefix, args.tmp_weights_filepath)
        for i in tqdm(range(args.num_episodes), "Cycle {}".format(cycle)):
            if args.log_time is True: time_logger.reset()

            #RL LOOP GENERATES A TRAJECTORY
            data = collect_data(env, random_policy, steps=args.max_steps)
            print("data: ", data[1].action_step.action.numpy())
            write_data(data, "Rab-Agent_: " + (str(cycle)) + "_" + (str(i)))
    env.close()
    #LOAD MODEL
    model = DQN_Model(input_shape=VISUAL_OBS_SPEC,
                      num_actions=NUM_ACTIONS,
                      conv_layer_params=CONV_LAYER_PARAMS,
                      fc_layer_params=FC_LAYER_PARAMS,
                      learning_rate=LEARNING_RATE)
    gcs_load_weights(model, gcs_bucket, args.prefix, args.tmp_weights_filepath)

    #SETUP TENSORBOARD/LOGGING
    train_log_dir = os.path.join(args.output_dir, 'logs/')
    os.makedirs(os.path.dirname(train_log_dir), exist_ok=True)
    loss_metrics = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    if args.log_time is True:
        time_logger = TimeLogger(
            ["Fetch Data", "Parse Data", "Compute Loss", "Generate Grads"],
            num_cycles=args.train_steps)

    #TRAINING LOOP
    train_step = 0
    exp_buff = ExperienceBuffer(args.buffer_size)
    print("-> Starting training...")
    for epoch in range(args.train_epochs):
        if args.log_time is True: time_logger.reset()

        #FETCH DATA
        global_i = cbt_global_iterator(cbt_table)
        rows = cbt_read_rows(cbt_table, args.prefix, args.train_steps,
                             global_i)

        if args.log_time is True: time_logger.log(0)
示例#6
0
    def train(self):
        """
        Train `self.model` for `self.train_epochs` epochs of `self.train_steps` steps.

        Each epoch, inside `tf.device(self.device)` and the distribution
        strategy scope, a dataset is obtained from
        `self.fill_experience_buffer()` and one batch per step is pulled from
        it and passed to a `tf.function`-compiled training step. That step runs
        the per-replica update through `self.distribution_strategy` and returns
        the MEAN-reduced loss.

        If `self.wandb` is not None, the mean of the epoch's losses is written
        with `tf.summary.scalar` and `self.wandb.log`. On each epoch after the
        first whose index is a multiple of `self.period`, the model weights are
        saved with `gcs_save_weights`. When `self.log_time` is True, timings
        are recorded in a `TimeLogger` and printed at the end of each epoch.

        """
        @tf.function
        def train_step(dist_inputs):
            def step_fn(inputs):
                observations, transition_info = inputs
                b_obs, b_next_obs = observations
                # NOTE(review): b_next_mask is unpacked but never used below;
                # if it is a terminal-state mask, the bootstrap term presumably
                # should be multiplied by it -- confirm against the data format.
                b_actions, b_rewards, b_next_mask = transition_info

                with tf.GradientTape() as tape:
                    q_all = self.model(b_obs)
                    q_next_all = self.model(b_next_obs)
                    # Keep only the Q-value of the action that was taken.
                    action_selector = tf.one_hot(b_actions, self.num_actions)
                    q_taken = tf.reduce_sum(q_all * action_selector, axis=-1)
                    q_next_best = tf.reduce_max(q_next_all, axis=-1)
                    discount = tf.constant(self.gamma, dtype=tf.float32)
                    q_target = b_rewards + (discount * q_next_best)
                    mse = self.model.loss(q_target, q_taken)
                    loss = tf.reduce_sum(mse)

                weights = self.model.trainable_weights
                grads = tape.gradient(loss, weights)
                self.model.opt.apply_gradients(list(zip(grads, weights)))
                return mse

            strategy = self.distribution_strategy
            # NOTE(review): experimental_run_v2 appears to have been renamed to
            # Strategy.run in later TensorFlow releases -- confirm the pinned
            # TF version.
            per_example_losses = strategy.experimental_run_v2(
                step_fn, args=(dist_inputs, ))
            return strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                   per_example_losses,
                                   axis=None)

        timing_enabled = self.log_time is True
        if timing_enabled:
            stage_names = [
                "Fetch Data      ", "Parse Bytes     ", "Format Data     ",
                "Add To Exp_Buff ", "To Dataset      ", "Train Step      ",
                "Save Model      "
            ]
            self.time_logger = TimeLogger(stage_names)

        print("-> Starting training...")
        for epoch in range(self.train_epochs):
            with tf.device(self.device), self.distribution_strategy.scope():
                batches = iter(self.fill_experience_buffer())

                losses = []
                progress = tqdm(range(self.train_steps),
                                "Training epoch {}".format(epoch))
                for _ in progress:
                    losses.append(train_step(next(batches)))

                    if timing_enabled:
                        self.time_logger.log("Train Step      ")

                if self.wandb is not None:
                    mean_loss = np.mean(losses)
                    tf.summary.scalar("Mean Loss", mean_loss, epoch)
                    self.wandb.log({"Epoch": epoch, "Mean Loss": mean_loss})

            # Epoch 0 never saves; afterwards save once every `period` epochs.
            save_due = epoch > 0 and epoch % self.period == 0
            if save_due:
                gcs_save_weights(self.model, self.gcs_bucket,
                                 self.tmp_weights_filepath,
                                 self.prefix + '_model.h5')

            if timing_enabled:
                # Logged whether or not weights were saved this epoch.
                self.time_logger.log("Save Model      ")
                self.time_logger.print_totaltime_logs()
        print("-> Done!")
示例#7
0
                                              args.bucket_id, credentials)

    #LOAD MODEL
    model = DQN_Model(input_shape=VECTOR_OBS_SPEC,
                      num_actions=NUM_ACTIONS,
                      fc_layer_params=FC_LAYER_PARAMS,
                      learning_rate=LEARNING_RATE)
    # gcs_load_weights(model, gcs_bucket, args.prefix, args.tmp_weights_filepath)

    #SETUP TENSORBOARD/LOGGING
    train_log_dir = os.path.join(args.output_dir, 'logs/')
    os.makedirs(os.path.dirname(train_log_dir), exist_ok=True)
    loss_metrics = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    if args.log_time is True:
        time_logger = TimeLogger(
            ['Fetch Data', 'Parse Data', 'Compute Loss', 'Generate Grads'])

    #TRAINING LOOP
    train_step = 0
    print("-> Starting training...")
    for epoch in range(args.train_epochs):
        if args.log_time is True:
            time_logger.reset()

        #FETCH DATA
        global_i = cbt_global_iterator(cbt_table)
        rows = cbt_read_rows(cbt_table, args.prefix, args.train_steps,
                             global_i)

        if args.log_time is True:
            time_logger.log('Fetch Data')