Example #1
0
    while all_train_data[num_train].get_scene_name().split(
            "_")[1] == all_train_data[num_train -
                                      1].get_scene_name().split("_")[1]:
        num_train += 1
    train_split = all_train_data[:num_train]
    tune_split = all_train_data[num_train:]

    logging.info("Created train dataset of size %d ", len(train_split))
    logging.info("Created tuning dataset of size %d ", len(tune_split))

    # Train on this dataset
    # print("Training Agent...")
    # if supervised:
    #     logging.info("Running supervised")
    #     model.do_supervised_train(agent, train_split, tune_split, experiment)
    # else:
    #     logging.info("Running RL/CB")
    #     model.do_train(agent, train_split, tune_split, experiment)

    # Test agent
    print("Testing Agent...")
    agent.test(tune_split, tensorboard)

    server.kill()

except Exception:
    server.kill()
    exc_info = sys.exc_info()
    traceback.print_exception(*exc_info)
    # raise e
Example #2
0
    logging.log(logging.DEBUG, "MODEL CREATED")

    # Create the agent
    logging.log(logging.DEBUG, "STARTING AGENT")
    agent = Agent(server=server,
                  model=model,
                  test_policy=test_policy,
                  action_space=action_space,
                  meta_data_util=meta_data_util,
                  config=config,
                  constants=constants)

    # create tensorboard
    tensorboard = Tensorboard("dummy")

    # Launch Unity Build
    launch_k_unity_builds([config["port"]],
                          "./simulators/NavDroneLinuxBuild.x86_64")

    test_data = DatasetParser.parse("data/nav_drone/dev_annotations_6000.json",
                                    config)
    agent.test(test_data, tensorboard)

    server.kill()

except Exception:
    server.kill()
    exc_info = sys.exc_info()
    traceback.print_exception(*exc_info)
    # raise e
Example #3
0
    def do_train_(simulator_file,
                  shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  use_pushover=False):
        """Train one worker with the asynchronous contextual-bandit learner.

        Launches a simulator build, initialises the server and then, for
        ``constants["max_epochs"]`` epochs, rolls out the sampled policy on
        every training example, hands the collected episode to the learner,
        saves a checkpoint and (when the tuning set is non-empty) tests the
        agent on it.

        Args:
            simulator_file: Simulator build passed to
                ``launch_k_unity_builds``.
            shared_model: Model whose state dict is copied into the local
                rollout model before every example; also given to the learner.
            config: Configuration mapping; ``config["port"]`` is the port the
                simulator build is launched on.
            action_space: Provides ``num_actions()`` and
                ``get_stop_action_index()``.
            meta_data_util: Helper used to fold simulator metadata into the
                agent state.
            constants: Mapping providing ``max_epochs``, ``horizon`` and
                ``max_extra_horizon``.
            train_dataset: Sized iterable of training data points.
            tune_dataset: Sized collection tested after each epoch when it is
                non-empty.
            experiment: Path prefix under which checkpoints are saved.
            experiment_name: Name given to the tensorboard / pushover loggers.
            rank: Worker index; only rank 0 creates a tensorboard writer. It
                is also part of the checkpoint file name.
            server: Simulator server; initialised here and used by the agent.
            logger: Object exposing ``log(message)``.
            model_type: Callable invoked as ``model_type(config, constants)``
                to build the local model.
            use_pushover: When true, a ``PushoverLogger`` is created and
                passed to ``agent.test``.
        """
        # Start the simulator build before asking the server to initialise.
        launch_k_unity_builds([config["port"]], simulator_file)
        server.initialize_server()

        # The agent is given the argmax action selector as its test policy.
        policy_for_testing = gp.get_argmax_action

        # Worker 0 is the only one that owns a tensorboard writer.
        tensorboard = Tensorboard(experiment_name) if rank == 0 else None
        pushover_logger = (PushoverLogger(experiment_name)
                           if use_pushover else None)

        # Private copy of the model, used for rollouts by this worker.
        rollout_model = model_type(config, constants)

        logger.log("STARTING AGENT")
        agent = Agent(
            server=server,
            model=rollout_model,
            test_policy=policy_for_testing,
            action_space=action_space,
            meta_data_util=meta_data_util,
            config=config,
            constants=constants,
        )
        logger.log("Created Agent...")

        action_counts = [0] * action_space.num_actions()
        epoch_limit = constants["max_epochs"]
        num_train_examples = len(train_dataset)
        num_tune_examples = len(tune_dataset)

        # The learner turns an episode of replay items into a model update.
        learner = AsynchronousContextualBandit(
            shared_model, rollout_model, action_space, meta_data_util, config,
            constants, tensorboard)

        for epoch in range(1, epoch_limit + 1):

            for example_ix, data_point in enumerate(train_dataset):

                # Pull the latest shared parameters into the rollout model.
                shared_state = shared_model.get_state_dict()
                rollout_model.load_from_state_dict(shared_state)

                # Periodic progress report, once every 100 examples.
                if (example_ix + 1) % 100 == 0:
                    progress = (example_ix, num_train_examples)
                    logger.log("Done %d out of %d" % progress)
                    logger.log("Training data action counts %r" %
                               action_counts)

                steps_taken = 0
                step_budget = (constants["horizon"] +
                               constants["max_extra_horizon"])

                # Reset the simulator to this example and build the start state.
                image, metadata = agent.server.reset_receive_feedback(
                    data_point)
                state = AgentObservedState(
                    instruction=data_point.instruction,
                    config=config,
                    constants=constants,
                    start_image=image,
                    previous_action=None,
                    data_point=data_point,
                )
                meta_data_util.start_state_update_metadata(state, metadata)

                model_state = None
                episode_items = []
                episode_reward = 0
                # True only when the policy itself samples the stop action.
                policy_chose_stop = False

                while steps_taken < step_budget:

                    # Query the policy and sample an action from it.
                    log_probabilities, model_state, _image_emb_seq, volatile = \
                        rollout_model.get_probs(state, model_state)
                    exp_rows = list(torch.exp(log_probabilities.data))
                    probabilities = exp_rows[0]

                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    if action == action_space.get_stop_action_index():
                        policy_chose_stop = True
                        break

                    # Execute the action in the simulator.
                    image, reward, metadata = \
                        agent.server.send_action_receive_feedback(action)

                    # Record the transition taken from the pre-action state.
                    episode_items.append(
                        ReplayMemoryItem(state,
                                         action,
                                         reward,
                                         log_prob=log_probabilities,
                                         volatile=volatile))

                    # Advance the agent state with the new observation.
                    state = state.update(image, action, data_point=data_point)
                    meta_data_util.state_update_metadata(state, metadata)

                    steps_taken += 1
                    episode_reward += reward

                # The episode always ends with an explicit halt on the server.
                image, reward, metadata = \
                    agent.server.halt_and_receive_feedback()
                episode_reward += reward

                if tensorboard is not None:
                    # NOTE(review): the tensorboard object is passed where the
                    # per-step call above passes the agent state -- confirm
                    # this is what meta_data_util expects.
                    meta_data_util.state_update_metadata(tensorboard, metadata)

                # A stop chosen by the policy is stored as a transition too;
                # a stop forced by the step budget is not.
                if policy_chose_stop:
                    episode_items.append(
                        ReplayMemoryItem(
                            state,
                            action_space.get_stop_action_index(),
                            reward,
                            log_prob=log_probabilities,
                            volatile=volatile))

                # Nothing to learn from an episode without transitions.
                if not episode_items:
                    continue

                loss_val = learner.do_update(episode_items)

                if tensorboard is None:
                    continue

                # Entropy is reported divided by (steps taken + 1).
                entropy = float(learner.entropy.data[0]) / float(steps_taken +
                                                                 1)
                tensorboard.log_scalar("loss", loss_val)
                tensorboard.log_scalar("entropy", entropy)
                tensorboard.log_scalar("total_reward", episode_reward)

            # End of epoch: checkpoint the rollout model.
            checkpoint_path = (experiment + "/contextual_bandit_" + str(rank) +
                               "_epoch_" + str(epoch))
            rollout_model.save_model(checkpoint_path)
            logger.log("Training data action counts %r" % action_counts)

            # Evaluate on the tuning set, if there is one.
            if num_tune_examples > 0:
                agent.test(tune_dataset,
                           tensorboard=tensorboard,
                           logger=logger,
                           pushover_logger=pushover_logger)
Example #4
0
    def do_train_(shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  use_pushover=False):
        """Train one worker with the asynchronous advantage actor/GAE critic learner.

        Initialises the server, builds a local model and agent, launches the
        simulator build and then, for ``constants["max_epochs"]`` epochs, rolls
        out the sampled policy on every training example, updates through the
        learner, logs per-episode and per-epoch statistics, saves a checkpoint
        and (when the tuning set is non-empty) tests the agent on it.

        Args:
            shared_model: Model whose state dict is copied into the local
                model before every example; also given to the learner.
            config: Configuration mapping; ``config["port"]`` and
                ``config["do_goal_prediction"]`` are read here.
            action_space: Provides ``num_actions()`` and
                ``get_stop_action_index()``.
            meta_data_util: Passed to the agent and the learner.
            constants: Mapping providing ``max_epochs``, ``horizon`` and
                ``max_extra_horizon``.
            train_dataset: Sized iterable of training data points.
            tune_dataset: Sized collection tested after each epoch when it is
                non-empty.
            experiment: Path prefix under which checkpoints are saved.
            experiment_name: Name given to the tensorboard / pushover loggers.
            rank: Worker index; only rank 0 creates a tensorboard writer. It
                is also part of the checkpoint file name.
            server: Simulator server; initialised here and used by the agent.
            logger: Object exposing ``log(message)``.
            model_type: Callable invoked as ``model_type(config, constants)``
                to build the local model.
            use_pushover: When true, a ``PushoverLogger`` is created and
                passed to ``agent.test``.
        """

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(config, constants)
        # local_model.train()

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = Agent(server=server,
                      model=local_model,
                      test_policy=test_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)
        logger.log("Created Agent...")

        action_counts = [0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = AsynchronousAdvantageActorGAECritic(shared_model,
                                                      local_model,
                                                      action_space,
                                                      meta_data_util, config,
                                                      constants, tensorboard)

        # Launch unity
        launch_k_unity_builds([config["port"]],
                              "./simulators/NavDroneLinuxBuild.x86_64")

        for epoch in range(1, max_epochs + 1):

            learner.epoch = epoch
            # Per-epoch accumulators; turned into a mean / percentage below.
            task_completion_accuracy = 0
            mean_stop_dist_error = 0
            stop_dist_errors = []
            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                local_model.load_from_state_dict(shared_model.get_state_dict())

                # Progress report once every 100 examples.
                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %
                               (data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)

                num_actions = 0
                max_num_actions = constants["horizon"] + constants[
                    "max_extra_horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)

                # Bucket the heading in steps of 15 (presumably degrees --
                # confirm against the simulator's metadata contract).
                pose = int(metadata["y_angle"] / 15.0)
                position_orientation = (metadata["x_pos"], metadata["z_pos"],
                                        metadata["y_angle"])
                state = AgentObservedState(
                    instruction=data_point.instruction,
                    config=config,
                    constants=constants,
                    start_image=image,
                    previous_action=None,
                    pose=pose,
                    position_orientation=position_orientation,
                    data_point=data_point)
                state.goal = GoalPrediction.get_goal_location(
                    metadata, data_point, learner.image_height,
                    learner.image_width)

                model_state = None
                batch_replay_items = []
                total_reward = 0
                # Stays True unless the policy itself samples the stop action.
                forced_stop = True

                while num_actions < max_num_actions:

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, volatile = \
                        local_model.get_probs(state, model_state)
                    probabilities = list(torch.exp(log_probabilities.data))[0]

                    # Sample action from the probability
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    # Generate goal (computed before the stop check so the
                    # final stop replay item below can reuse it).
                    if config["do_goal_prediction"]:
                        goal = learner.goal_prediction_calculator.get_goal_location(
                            metadata, data_point, learner.image_height,
                            learner.image_width)
                    else:
                        goal = None

                    if action == action_space.get_stop_action_index():
                        forced_stop = False
                        break

                    # Send the action and get feedback
                    image, reward, metadata = agent.server.send_action_receive_feedback(
                        action)

                    # Store it in the replay memory list
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities,
                                                   volatile=volatile,
                                                   goal=goal)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    pose = int(metadata["y_angle"] / 15.0)
                    position_orientation = (metadata["x_pos"],
                                            metadata["z_pos"],
                                            metadata["y_angle"])
                    state = state.update(
                        image,
                        action,
                        pose=pose,
                        position_orientation=position_orientation,
                        data_point=data_point)
                    state.goal = GoalPrediction.get_goal_location(
                        metadata, data_point, learner.image_height,
                        learner.image_width)

                    num_actions += 1
                    total_reward += reward

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback(
                )
                total_reward += reward

                # An episode counts as completed when it stops within 5.0 of
                # the target according to the simulator's stop_dist_error.
                if metadata["stop_dist_error"] < 5.0:
                    task_completion_accuracy += 1
                mean_stop_dist_error += metadata["stop_dist_error"]
                stop_dist_errors.append(metadata["stop_dist_error"])

                if tensorboard is not None:
                    tensorboard.log_all_train_errors(
                        metadata["edit_dist_error"],
                        metadata["closest_dist_error"],
                        metadata["stop_dist_error"])

                # Store the stop transition only when the policy chose to
                # stop; a stop forced by the step budget is not stored.
                if not forced_stop:
                    replay_item = ReplayMemoryItem(
                        state,
                        action_space.get_stop_action_index(),
                        reward,
                        log_prob=log_probabilities,
                        volatile=volatile,
                        goal=goal)
                    batch_replay_items.append(replay_item)

                # Perform update
                if len(batch_replay_items) > 0:
                    loss_val = learner.do_update(batch_replay_items)

                    if tensorboard is not None:
                        cross_entropy = float(learner.cross_entropy.data[0])
                        tensorboard.log(cross_entropy, loss_val, 0)
                        # Entropy and value loss are reported divided by
                        # (steps taken + 1).
                        entropy = float(
                            learner.entropy.data[0]) / float(num_actions + 1)
                        v_value_loss_per_step = float(
                            learner.value_loss.data[0]) / float(num_actions +
                                                                1)
                        tensorboard.log_scalar("entropy", entropy)
                        tensorboard.log_scalar("total_reward", total_reward)
                        tensorboard.log_scalar("v_value_loss_per_step",
                                               v_value_loss_per_step)
                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar(
                            "Abs_objective_to_entropy_ratio", ratio)

                        # Auxiliary losses are logged only when the learner
                        # exposes them (attribute is not None).
                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(
                                learner.action_prediction_loss.data[0])
                            # NOTE(review): this goes through
                            # learner.tensorboard while every other call uses
                            # the local tensorboard -- confirm they are the
                            # same object.
                            learner.tensorboard.log_action_prediction_loss(
                                action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(
                                learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(
                                learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(
                                learner.symbolic_language_prediction_loss.
                                data[0])
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)

            # Save the model. The "contextual_bandit_" prefix is kept as-is so
            # existing checkpoint paths keep working.
            local_model.save_model(experiment + "/contextual_bandit_" +
                                   str(rank) + "_epoch_" + str(epoch))
            logger.log("Training data action counts %r" % action_counts)

            # Epoch summary. Guarded so an empty training set does not raise
            # ZeroDivisionError and skip the tuning evaluation below.
            if dataset_size > 0:
                mean_stop_dist_error = mean_stop_dist_error / float(
                    dataset_size)
                task_completion_accuracy = (task_completion_accuracy *
                                            100.0) / float(dataset_size)
                logger.log("Training: Mean stop distance error %r" %
                           mean_stop_dist_error)
                logger.log("Training: Task completion accuracy %r " %
                           task_completion_accuracy)
            else:
                logger.log(
                    "Training: empty training set, skipping error statistics")
            bins = range(0, 80, 3)  # range of distance
            histogram, _ = np.histogram(stop_dist_errors, bins)
            logger.log("Histogram of train errors %r " % histogram)

            if tune_dataset_size > 0:
                # Test on tuning data
                agent.test(tune_dataset,
                           tensorboard=tensorboard,
                           logger=logger,
                           pushover_logger=pushover_logger)
Example #5
0
def main():
    """Evaluate a saved blocks-world model on the dev set.

    Sets up the experiment folder and logger, loads the config and constants
    JSON files, builds the vocabulary, loads a saved model, launches the
    simulator and runs ``agent.test`` on the parsed dataset. Any exception
    raised during model/server/agent setup or testing is printed as a
    traceback rather than propagated.
    """

    experiment_name = "blocks_experiments"
    experiment = "./results/" + experiment_name
    print("EXPERIMENT NAME: ", experiment_name)

    # Create the experiment folder
    if not os.path.exists(experiment):
        os.makedirs(experiment)

    # Define log settings
    log_path = experiment + '/test_baseline.log'
    multiprocess_logging_manager = MultiprocessingLoggerManager(
        file_path=log_path, logging_level=logging.INFO)
    master_logger = multiprocess_logging_manager.get_logger("Master")
    master_logger.log(
        "----------------------------------------------------------------")
    master_logger.log(
        "                    STARING NEW EXPERIMENT                      ")
    master_logger.log(
        "----------------------------------------------------------------")

    with open("data/blocks/config.json") as f:
        config = json.load(f)
    with open("data/shared/contextual_bandit_constants.json") as f:
        constants = json.load(f)
    print(json.dumps(config, indent=2))
    setup_validator = BlocksSetupValidator()
    setup_validator.validate(config, constants)

    # log core experiment details
    master_logger.log("CONFIG DETAILS")
    for k, v in sorted(config.items()):
        master_logger.log("    %s --- %r" % (k, v))
    master_logger.log("CONSTANTS DETAILS")
    for k, v in sorted(constants.items()):
        master_logger.log("    %s --- %r" % (k, v))
    # Copy this script into the log so the run can be reproduced later.
    master_logger.log("START SCRIPT CONTENTS")
    with open(__file__) as f:
        for line in f:
            master_logger.log(">>> " + line.strip())
    master_logger.log("END SCRIPT CONTENTS")

    action_space = ActionSpace(config)
    meta_data_util = MetaDataUtil()

    # Create vocabulary: one lower-cased token per line, indexed by line
    # number. The file is read inside a `with` block so the handle is closed
    # (the previous bare open(...).readlines() leaked it).
    vocab = dict()
    with open("./Assets/vocab_both") as vocab_file:
        vocab_list = vocab_file.readlines()
    for i, tk in enumerate(vocab_list):
        token = tk.strip().lower()
        vocab[token] = i
    # The unknown-word token takes the index right after the last line.
    vocab["$UNK$"] = len(vocab_list)
    config["vocab_size"] = len(vocab_list) + 1

    # Test policy
    test_policy = gp.get_argmax_action

    # Create tensorboard
    tensorboard = Tensorboard("Agent Test")

    try:
        # Create the model
        master_logger.log("CREATING MODEL")
        model_type = IncrementalModelEmnlp
        shared_model = model_type(config, constants)
        shared_model.load_saved_model(
            "./results/model-folder-name/model-file-name")

        # Read the dataset
        test_data = DatasetParser.parse("devset.json", config)
        master_logger.log("Created test dataset of size %d " % len(test_data))

        # Create server and launch a client
        simulator_file = "./simulators/blocks/retro_linux_build.x86_64"
        config["port"] = find_k_ports(1)[0]
        server = BlocksServer(config, action_space, vocab=vocab)

        # Launch unity
        launch_k_unity_builds([config["port"]], simulator_file)
        server.initialize_server()

        # Create the agent
        master_logger.log("CREATING AGENT")
        agent = Agent(server=server,
                      model=shared_model,
                      test_policy=test_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)

        agent.test(test_data, tensorboard)

    except Exception:
        # Top-level boundary: report the failure without re-raising.
        exc_info = sys.exc_info()
        traceback.print_exception(*exc_info)
    def do_train_(shared_model,
                  config,
                  action_space,
                  meta_data_util,
                  args,
                  constants,
                  train_dataset,
                  tune_dataset,
                  experiment,
                  experiment_name,
                  rank,
                  server,
                  logger,
                  model_type,
                  use_pushover=False):
        """Train one worker with the asynchronous contextual-bandit learner.

        Initialises the server, builds a local model (moved to CUDA when
        available and put in train mode) and an agent, launches the simulator
        build and then, for ``constants["max_epochs"]`` epochs, first tests on
        the tuning set (when non-empty), then rolls out the sampled policy on
        every training example, updates through the learner and finally saves
        a checkpoint.

        All progress messages go through ``logger``, matching the other
        messages in this function, instead of the root ``logging`` module.

        Args:
            shared_model: Model whose state dict is copied into the local
                model before every example; also given to the learner.
            config: Configuration mapping; ``config["port"]`` is the port the
                simulator build is launched on.
            action_space: Provides ``num_actions()`` and
                ``get_stop_action_index()``.
            meta_data_util: Passed to the agent and the learner.
            args: First positional argument handed to ``model_type``.
            constants: Mapping providing ``max_epochs`` and ``horizon``.
            train_dataset: Sized iterable of training data points.
            tune_dataset: Sized collection tested at the start of each epoch
                when it is non-empty.
            experiment: Path prefix under which checkpoints are saved.
            experiment_name: Name given to the tensorboard / pushover loggers.
            rank: Worker index; only rank 0 creates a tensorboard writer. It
                is also part of the checkpoint file name.
            server: Simulator server; initialised here and used by the agent.
            logger: Object exposing ``log(message)``.
            model_type: Callable invoked as ``model_type(args, config=config)``
                to build the local model.
            use_pushover: When true, a ``PushoverLogger`` is created and
                passed to ``agent.test``.
        """

        server.initialize_server()

        # Test policy
        test_policy = gp.get_argmax_action

        # torch.manual_seed(args.seed + rank)

        if rank == 0:  # client 0 creates a tensorboard server
            tensorboard = Tensorboard(experiment_name)
        else:
            tensorboard = None

        if use_pushover:
            pushover_logger = PushoverLogger(experiment_name)
        else:
            pushover_logger = None

        # Create a local model for rollouts
        local_model = model_type(args, config=config)
        if torch.cuda.is_available():
            local_model.cuda()
        local_model.train()

        # Create the Agent
        logger.log("STARTING AGENT")
        agent = Agent(server=server,
                      model=local_model,
                      test_policy=test_policy,
                      action_space=action_space,
                      meta_data_util=meta_data_util,
                      config=config,
                      constants=constants)
        logger.log("Created Agent...")

        action_counts = [0] * action_space.num_actions()
        max_epochs = constants["max_epochs"]
        dataset_size = len(train_dataset)
        tune_dataset_size = len(tune_dataset)

        # Create the learner to compute the loss
        learner = AsynchronousContextualBandit(shared_model, local_model,
                                               action_space, meta_data_util,
                                               config, constants, tensorboard)

        # Launch unity
        # NOTE(review): the build path is an absolute path on one developer's
        # machine; other variants of this routine use a path under
        # ./simulators/.
        launch_k_unity_builds([
            config["port"]
        ], "/home/dipendra/Downloads/NavDroneLinuxBuild/NavDroneLinuxBuild.x86_64"
                              )

        for epoch in range(1, max_epochs + 1):

            # Evaluation happens at the start of every epoch, i.e. before any
            # training in the first epoch.
            if tune_dataset_size > 0:
                # Test on tuning data
                agent.test(tune_dataset,
                           tensorboard=tensorboard,
                           logger=logger,
                           pushover_logger=pushover_logger)

            for data_point_ix, data_point in enumerate(train_dataset):

                # Sync with the shared model
                local_model.load_from_state_dict(shared_model.get_state_dict())

                # Progress report once every 100 examples, sent to the
                # per-worker logger like every other message here.
                if (data_point_ix + 1) % 100 == 0:
                    logger.log("Done %d out of %d" %
                               (data_point_ix, dataset_size))
                    logger.log("Training data action counts %r" %
                               action_counts)

                num_actions = 0
                max_num_actions = constants["horizon"]

                image, metadata = agent.server.reset_receive_feedback(
                    data_point)

                # Bucket the heading in steps of 15 (presumably degrees --
                # confirm against the simulator's metadata contract).
                pose = int(metadata["y_angle"] / 15.0)
                position_orientation = (metadata["x_pos"], metadata["z_pos"],
                                        metadata["y_angle"])
                state = AgentObservedState(
                    instruction=data_point.instruction,
                    config=config,
                    constants=constants,
                    start_image=image,
                    previous_action=None,
                    pose=pose,
                    position_orientation=position_orientation,
                    data_point=data_point)

                model_state = None
                batch_replay_items = []
                total_reward = 0
                # Stays True unless the policy itself samples the stop action.
                forced_stop = True

                while num_actions < max_num_actions:

                    # Sample action using the policy
                    log_probabilities, model_state, image_emb_seq, state_feature = \
                        local_model.get_probs(state, model_state)
                    probabilities = list(torch.exp(log_probabilities.data))[0]

                    # Sample action from the probability
                    action = gp.sample_action_from_prob(probabilities)
                    action_counts[action] += 1

                    if action == action_space.get_stop_action_index():
                        forced_stop = False
                        break

                    # Send the action and get feedback
                    image, reward, metadata = agent.server.send_action_receive_feedback(
                        action)

                    # Store it in the replay memory list
                    rewards = learner.get_all_rewards(metadata)
                    replay_item = ReplayMemoryItem(state,
                                                   action,
                                                   reward,
                                                   log_prob=log_probabilities,
                                                   all_rewards=rewards)
                    batch_replay_items.append(replay_item)

                    # Update the agent state
                    pose = int(metadata["y_angle"] / 15.0)
                    position_orientation = (metadata["x_pos"],
                                            metadata["z_pos"],
                                            metadata["y_angle"])
                    state = state.update(
                        image,
                        action,
                        pose=pose,
                        position_orientation=position_orientation,
                        data_point=data_point)

                    num_actions += 1
                    total_reward += reward

                # Send final STOP action and get feedback
                image, reward, metadata = agent.server.halt_and_receive_feedback(
                )
                rewards = learner.get_all_rewards(metadata)
                total_reward += reward

                if tensorboard is not None:
                    tensorboard.log_all_train_errors(
                        metadata["edit_dist_error"],
                        metadata["closest_dist_error"],
                        metadata["stop_dist_error"])

                # Store the stop transition only when the policy chose to
                # stop; a stop forced by the step budget is not stored.
                if not forced_stop:
                    replay_item = ReplayMemoryItem(
                        state,
                        action_space.get_stop_action_index(),
                        reward,
                        log_prob=log_probabilities,
                        all_rewards=rewards)
                    batch_replay_items.append(replay_item)

                # Perform update
                if len(batch_replay_items) > 0:
                    loss_val = learner.do_update(batch_replay_items)
                    del batch_replay_items[:]  # in place list clear

                    if tensorboard is not None:
                        cross_entropy = float(learner.cross_entropy.data[0])
                        tensorboard.log(cross_entropy, loss_val, 0)
                        entropy = float(learner.entropy.data[0])
                        tensorboard.log_scalar("entropy", entropy)

                        ratio = float(learner.ratio.data[0])
                        tensorboard.log_scalar(
                            "Abs_objective_to_entropy_ratio", ratio)

                        # Auxiliary losses are logged only when the learner
                        # exposes them (attribute is not None).
                        if learner.action_prediction_loss is not None:
                            action_prediction_loss = float(
                                learner.action_prediction_loss.data[0])
                            # NOTE(review): this goes through
                            # learner.tensorboard while every other call uses
                            # the local tensorboard -- confirm they are the
                            # same object.
                            learner.tensorboard.log_action_prediction_loss(
                                action_prediction_loss)
                        if learner.temporal_autoencoder_loss is not None:
                            temporal_autoencoder_loss = float(
                                learner.temporal_autoencoder_loss.data[0])
                            tensorboard.log_temporal_autoencoder_loss(
                                temporal_autoencoder_loss)
                        if learner.object_detection_loss is not None:
                            object_detection_loss = float(
                                learner.object_detection_loss.data[0])
                            tensorboard.log_object_detection_loss(
                                object_detection_loss)
                        if learner.symbolic_language_prediction_loss is not None:
                            symbolic_language_prediction_loss = float(
                                learner.symbolic_language_prediction_loss.
                                data[0])
                            tensorboard.log_scalar(
                                "sym_language_prediction_loss",
                                symbolic_language_prediction_loss)
                        if learner.goal_prediction_loss is not None:
                            goal_prediction_loss = float(
                                learner.goal_prediction_loss.data[0])
                            tensorboard.log_scalar("goal_prediction_loss",
                                                   goal_prediction_loss)
                        if learner.mean_factor_entropy is not None:
                            mean_factor_entropy = float(
                                learner.mean_factor_entropy.data[0])
                            tensorboard.log_factor_entropy_loss(
                                mean_factor_entropy)

            # Save the model
            local_model.save_model(experiment + "/contextual_bandit_" +
                                   str(rank) + "_epoch_" + str(epoch))

            logger.log("Training data action counts %r" % action_counts)