# Example #1
# 0
def data_efficient_imitation_across_multiple_tasks(domain_name, window_size, number_demonstrations, adapt_detector_threshold, start_monitoring_at_time_step, detector_c, detector_m, initial_detector_threshold, epochs, number_mini_batches, activation_unit, learning_rate, hidden_units, number_samples_variance_reduction, precision_alpha, weights_prior_mean_1, weights_prior_mean_2, weights_prior_deviation_1, weights_prior_deviation_2, mixture_pie, rho_mean, extra_likelihood_emphasis):
    """Run data-efficient imitation learning across a shuffled sequence of MuJoCo tasks.

    For each of ``TOTAL_SIMULATION_ITERATIONS`` runs, a copy of the module-level
    task list is shuffled with a per-iteration seed, then a BBB
    (Bayes-by-Backprop) controller is trained incrementally: demonstrations are
    gathered and the network retrained for a task only when the previously
    trained model failed to transfer to it (as judged by an uncertainty
    ``Detector``).  All results are emitted as side effects: training logs,
    pickled meta-data, and validation output under ``logs/<domain_name>/...``.

    Args:
        domain_name: MuJoCo domain name; used in log paths and passed to helpers.
        window_size: size of the moving observation window fed to the network.
        number_demonstrations: demonstrations to gather per task.
        adapt_detector_threshold: if truthy, overwrite the detector threshold
            with the average uncertainty measured on the just-trained task.
        start_monitoring_at_time_step, detector_c, detector_m,
            initial_detector_threshold: ``Detector`` configuration.
        epochs, number_mini_batches, activation_unit, learning_rate,
            hidden_units, number_samples_variance_reduction, precision_alpha,
            weights_prior_mean_1, weights_prior_mean_2,
            weights_prior_deviation_1, weights_prior_deviation_2, mixture_pie,
            rho_mean, extra_likelihood_emphasis: BBB training hyper-parameters
            forwarded verbatim to ``train_BBB``.

    Returns:
        None.  Output is written to the filesystem only.

    NOTE(review): this function depends on module-level names not visible in
    this chunk (ALL_MUJOCO_TASK_IDENTITIES, TOTAL_SIMULATION_ITERATIONS, the
    *_KEY constants, and the helper functions it calls); behavior claims below
    are limited to what this code itself shows.
    """
    # Deep-copy so the repeated in-place shuffles never mutate the module-level list.
    COPY_OF_ALL_MUJOCO_TASK_IDENTITIES = copy.deepcopy(ALL_MUJOCO_TASK_IDENTITIES)
    for simulation_iterator in range(TOTAL_SIMULATION_ITERATIONS):
        # Seed with the iteration index so each simulation gets a deterministic,
        # reproducible task ordering.
        random.seed(simulation_iterator)
        random.shuffle(COPY_OF_ALL_MUJOCO_TASK_IDENTITIES)
        
        ###### Naive Controller ######
        # The triple-quoted block below is a deliberately disabled baseline that
        # gathers demonstrations and retrains on EVERY task unconditionally
        # (no detector).  It is kept for reference; do not delete.
        '''
        all_gathered_x, all_gathered_y = None, None
        tasks_trained_on, tasks_encountered = [], []
        
        print(GREEN('Starting runs for the naive controller'))
        for task_iterator, current_task_identity in enumerate(COPY_OF_ALL_MUJOCO_TASK_IDENTITIES):
            print(RED('Simulation iteration is ' + str(simulation_iterator) + ' and task iterator is ' + str(task_iterator)))

            tasks_trained_on.append(current_task_identity)
            tasks_encountered.append(current_task_identity)
            moving_windows_x, moving_windows_y, drift_per_time_step, moving_windows_x_size = getDemonstrationsFromTask(domain_name=domain_name, task_identity=current_task_identity, window_size=window_size, number_demonstrations=number_demonstrations)
            
            if all_gathered_x is None:
                all_gathered_x, all_gathered_y = copy.deepcopy(moving_windows_x), copy.deepcopy(moving_windows_y)
            else:
                all_gathered_x, all_gathered_y = np.append(all_gathered_x, moving_windows_x, axis=0), np.append(all_gathered_y, moving_windows_y, axis=0)

            disposible_training_x, disposible_training_y = copy.deepcopy(all_gathered_x), copy.deepcopy(all_gathered_y)
            mean_x, deviation_x = get_mean_and_deviation(data = disposible_training_x)
            disposible_training_x = NORMALIZE(disposible_training_x, mean_x, deviation_x)
            mean_y, deviation_y = get_mean_and_deviation(data = disposible_training_y)
            disposible_training_y = NORMALIZE(disposible_training_y, mean_y, deviation_y)

            configuration_identity = 'logs/' + domain_name + '/naive_controller/' + str(simulation_iterator) + '/' + str(task_iterator) + '/'
            training_logs_directory = configuration_identity + 'training/'
            if not os.path.exists(training_logs_directory):
                os.makedirs(training_logs_directory)

            file_name_to_save_meta_data = training_logs_directory + 'training_meta_data.pkl'
            meta_data_to_store = {MEAN_KEY_X: mean_x, DEVIATION_KEY_X: deviation_x, MEAN_KEY_Y:mean_y, DEVIATION_KEY_Y:deviation_y,
                                  DRIFT_PER_TIME_STEP_KEY: drift_per_time_step, MOVING_WINDOWS_X_SIZE_KEY: moving_windows_x_size,
                                  TASKS_TRAINED_ON_KEY: tasks_trained_on, TASKS_ENCOUNTERED_KEY: tasks_encountered,
                                  WINDOW_SIZE_KEY: window_size}
            with open(file_name_to_save_meta_data, 'wb') as f:
                pickle.dump(meta_data_to_store, f)

            print(BLUE('Training phase'))
            train_BBB(data_x=copy.deepcopy(disposible_training_x), data_y=copy.deepcopy(disposible_training_y), configuration_identity=configuration_identity, epochs=epochs, number_mini_batches=number_mini_batches, activation_unit=activation_unit,
             learning_rate=learning_rate, hidden_units=hidden_units, number_samples_variance_reduction=number_samples_variance_reduction, precision_alpha=precision_alpha, weights_prior_mean_1=weights_prior_mean_1,
              weights_prior_mean_2=weights_prior_mean_2, weights_prior_deviation_1=weights_prior_deviation_1, weights_prior_deviation_2=weights_prior_deviation_2, mixture_pie=mixture_pie, rho_mean=rho_mean, extra_likelihood_emphasis=extra_likelihood_emphasis)

            print(BLUE('Validation phase'))
            validate_BBB(domain_name=domain_name, task_identity=current_task_identity, configuration_identity=configuration_identity)
        
        '''
        ###### BBB Controller ######
        # did_succeed == False forces demonstration-gathering + training for the
        # first task; afterwards it carries the transfer-test result computed at
        # the bottom of the loop for the NEXT task.
        did_succeed = False
        all_gathered_x, all_gathered_y = None, None
        tasks_trained_on, tasks_encountered, task_iterator_trained_on = [], [], []
        current_task_identity = COPY_OF_ALL_MUJOCO_TASK_IDENTITIES[0]
        detector = Detector(domain_name=domain_name, start_monitoring_at_time_step=start_monitoring_at_time_step, initial_threshold=initial_detector_threshold, detector_m=detector_m, detector_c=detector_c)

        print(GREEN('Starting runs for the BBB controller'))
        for task_iterator in range(len(COPY_OF_ALL_MUJOCO_TASK_IDENTITIES)):
            print(RED('Simulation iteration is ' + str(simulation_iterator) + ', task iterator is ' + str(task_iterator) + ', and current task is ' + str(current_task_identity)))
            # NOTE(review): tasks_encountered is appended here AND again after the
            # look-ahead at the bottom of this loop, so every task after the first
            # appears twice in the pickled meta-data — confirm this double count
            # is intended.
            tasks_encountered.append(current_task_identity)
            # Fresh detector state for each task.
            detector.reset()

            configuration_identity = 'logs/' + domain_name + '/bbb_controller/detector_c_' + str(detector_c) + '_detector_m_' + str(detector_m) + '/' + str(simulation_iterator) + '/' + str(task_iterator) + '/'
            training_logs_directory = configuration_identity + 'training/'
            if not os.path.exists(training_logs_directory):
                os.makedirs(training_logs_directory)

            # Only gather demonstrations and retrain when the previous model
            # failed to transfer to the current task (or on the very first task).
            if not did_succeed:
                tasks_trained_on.append(current_task_identity)
                task_iterator_trained_on.append(task_iterator)

                moving_windows_x, moving_windows_y, drift_per_time_step, moving_windows_x_size = getDemonstrationsFromTask(domain_name=domain_name, task_identity=current_task_identity, window_size=window_size, number_demonstrations=number_demonstrations)              
                # Accumulate demonstrations across all tasks trained on so far;
                # training always uses the full gathered set.
                if all_gathered_x is None:
                    all_gathered_x, all_gathered_y = copy.deepcopy(moving_windows_x), copy.deepcopy(moving_windows_y)
                else:
                    all_gathered_x, all_gathered_y = np.append(all_gathered_x, moving_windows_x, axis=0), np.append(all_gathered_y, moving_windows_y, axis=0)
                # Normalize a disposable copy so the accumulated raw data stays intact.
                disposible_training_x, disposible_training_y = copy.deepcopy(all_gathered_x), copy.deepcopy(all_gathered_y)
                mean_x, deviation_x = get_mean_and_deviation(data = disposible_training_x)
                disposible_training_x = NORMALIZE(disposible_training_x, mean_x, deviation_x)
                mean_y, deviation_y = get_mean_and_deviation(data = disposible_training_y)
                disposible_training_y = NORMALIZE(disposible_training_y, mean_y, deviation_y)

                # Persist everything needed to reproduce preprocessing at
                # validation/deployment time (normalization stats, window info,
                # task history).
                file_name_to_save_meta_data = training_logs_directory + 'training_meta_data.pkl'
                meta_data_to_store = {MEAN_KEY_X: mean_x, DEVIATION_KEY_X: deviation_x, MEAN_KEY_Y:mean_y, DEVIATION_KEY_Y:deviation_y,
                                      DRIFT_PER_TIME_STEP_KEY: drift_per_time_step, MOVING_WINDOWS_X_SIZE_KEY: moving_windows_x_size,
                                      TASKS_TRAINED_ON_KEY: tasks_trained_on, TASKS_ENCOUNTERED_KEY: tasks_encountered,
                                      WINDOW_SIZE_KEY: window_size}
                with open(file_name_to_save_meta_data, 'wb') as f:
                    pickle.dump(meta_data_to_store, f)

                print(BLUE('Training phase'))
                train_BBB(data_x=copy.deepcopy(disposible_training_x), data_y=copy.deepcopy(disposible_training_y), configuration_identity=configuration_identity, epochs=epochs, number_mini_batches=number_mini_batches,
                 activation_unit=activation_unit, learning_rate=learning_rate, hidden_units=hidden_units, number_samples_variance_reduction=number_samples_variance_reduction, precision_alpha=precision_alpha,
                  weights_prior_mean_1=weights_prior_mean_1, weights_prior_mean_2=weights_prior_mean_2, weights_prior_deviation_1=weights_prior_deviation_1, weights_prior_deviation_2=weights_prior_deviation_2,
                   mixture_pie=mixture_pie, rho_mean=rho_mean, extra_likelihood_emphasis=extra_likelihood_emphasis)
                
                # Measure uncertainty of the freshly trained model on its own task.
                _, average_uncertainty = run_on_itself(domain_name=domain_name, task_identity=current_task_identity, configuration_identity=configuration_identity)
                #### Ground the threshold according to the quantitative value of uncertainty on the current task ####
                if adapt_detector_threshold:
                    detector.threshold = average_uncertainty
                
                # This per-run meta-data file is rewritten after every training
                # event, so it always holds the latest list of iterations at
                # which training happened.
                meta_data_file_for_this_run = 'logs/' + domain_name + '/bbb_controller/detector_c_' + str(detector_c) + '_detector_m_' + str(detector_m) + '/' + str(simulation_iterator) + '/meta_data.pkl'
                meta_data_for_this_run = {TRAINING_TASK_ITERATION_KEY: task_iterator_trained_on}
                with open(meta_data_file_for_this_run, 'wb') as f:
                    pickle.dump(meta_data_for_this_run, f)
                #need_training = False

            print(BLUE('Validation phase'))
            validate_BBB(domain_name=domain_name, task_identity=current_task_identity, configuration_identity=configuration_identity)

            # Last task: nothing left to look ahead to.
            if task_iterator == (len(COPY_OF_ALL_MUJOCO_TASK_IDENTITIES) - 1):
                break

            # Look ahead: test the current model (with the detector) on the next
            # task; the outcome decides whether the next loop iteration retrains.
            current_task_identity = COPY_OF_ALL_MUJOCO_TASK_IDENTITIES[task_iterator + 1]
            tasks_encountered.append(current_task_identity)
            did_succeed, average_uncertainty = run_on_itself(domain_name=domain_name, task_identity=current_task_identity, configuration_identity=configuration_identity, detector=detector)
            # run_on_itself apparently returns the success flag as a string here
            # (it is converted with str_to_bool) — TODO confirm against its definition.
            did_succeed = str_to_bool(did_succeed)