Exemplo n.º 1
0
    def test_fusion_parameters(self):
        """
        TEST : Fusion.dempster_shafer_fusion_parameters

        Create placeholder optical/SAR classifications, confidence maps and
        per-model confusion files for each seed in a fresh iota2 output tree,
        then compare the generated fusion parameters (with the local output
        path stripped) against ``self.parameter_ref``.
        """
        from iota2.Classification import Fusion
        from iota2.Common import IOTA2Directory
        from iota2.Common import ServiceConfigFile as SCF

        # define inputs
        cfg = SCF.serviceConfigFile(self.config_test)
        iota2_dir = os.path.join(self.test_working_directory, "fusionTest")
        cfg.setParam('chain', 'outputPath', iota2_dir)

        IOTA2Directory.generate_directories(iota2_dir, check_inputs=False)
        iota2_classif_dir = os.path.join(iota2_dir, "classif")
        iota2_ds_confusions_dir = os.path.join(iota2_dir, "dataAppVal",
                                               "bymodels")
        fut.ensure_dir(iota2_ds_confusions_dir)
        # generate some fake data: only the file names matter here, every
        # file just contains a placeholder string
        nb_seed = 2
        for seed in range(nb_seed):
            fake_files = [
                os.path.join(
                    iota2_classif_dir,
                    "Classif_T31TCJ_model_1_seed_{}.tif".format(seed)),
                os.path.join(
                    iota2_classif_dir,
                    "Classif_T31TCJ_model_1_seed_{}_SAR.tif".format(seed)),
                os.path.join(
                    iota2_classif_dir,
                    "T31TCJ_model_1_confidence_seed_{}.tif".format(seed)),
                os.path.join(
                    iota2_classif_dir,
                    "T31TCJ_model_1_confidence_seed_{}_SAR.tif".format(seed)),
                os.path.join(iota2_ds_confusions_dir,
                             "model_1_seed_{}_SAR.csv".format(seed)),
                os.path.join(iota2_ds_confusions_dir,
                             "model_1_seed_{}.csv".format(seed)),
            ]
            for fake_file in fake_files:
                with open(fake_file, "w") as new_file:
                    new_file.write("TEST")

        parameters_test = Fusion.dempster_shafer_fusion_parameters(iota2_dir)
        # parameters_test depends on the execution environment: remove the
        # local output path so it can be compared with the reference
        for param_group in parameters_test:
            for key, value in list(param_group.items()):
                param_group[key] = value.replace(iota2_dir, "")

        # zip() below stops at the shortest input, so a missing (or extra)
        # parameter group would otherwise go unnoticed; checking the count
        # first also gives a clear failure instead of an IndexError when the
        # result is empty
        self.assertEqual(len(parameters_test), len(self.parameter_ref),
                         msg="unexpected number of parameter groups")

        # put the seed_0 group first to match the reference ordering
        # (this simple reversal assumes the two seeds generated above)
        if "seed_1" in parameters_test[0]["sar_model"]:
            parameters_test = parameters_test[::-1]

        # assert
        self.assertTrue(all(param_group_test == param_group_ref
                            for param_group_test, param_group_ref in zip(
                                parameters_test, self.parameter_ref)),
                        msg="input parameters generation failed")
Exemplo n.º 2
0
def _append_worker_log(log_path, log_content):
    """Append ``log_content`` to ``log_path``, creating its directory if needed.

    Does nothing when ``log_path`` is None, so callers may run without a log
    file.
    """
    if log_path is None:
        return
    log_dir = os.path.dirname(log_path)
    # a bare file name has an empty directory part: it lives in the current
    # directory, so there is nothing to create
    if log_dir:
        fut.ensure_dir(log_dir)
    with open(log_path, "a+") as log_f:
        log_f.write(log_content)


def mpi_schedule(iota2_step,
                 param_array_origin,
                 mpi_service=MPIService(),
                 logPath=None,
                 logger_lvl="INFO",
                 enable_console=False):
    """
    A simple MPI scheduler to execute jobs in parallel.

    Only the rank 0 process does any work. When more than one MPI process is
    available it acts as master: it sends one task to each worker rank, then
    hands the next pending task to whichever worker reports back. With a
    single process every task is run locally, one after the other. Each task
    log is appended to ``logPath``.

    Parameters
    ----------
    iota2_step :
        step object; ``step_execute()`` provides the job sent with every task
        and ``step_clean()`` is called when all tasks succeeded
    param_array_origin : list or callable
        the task parameters, or a callable returning them
    mpi_service :
        object exposing ``rank``, ``size`` and ``comm``
    logPath : str, optional
        file the task logs are appended to (deleted first if it already
        exists); when None, no log file is written
    logger_lvl : str
        logging level forwarded to each task
    enable_console : bool
        console-logging flag forwarded to each task

    Returns
    -------
    None on non-zero ranks, otherwise ``(returned_data_list, step_completed)``
    where ``step_completed`` is True only if every task reported success.

    Raises
    ------
    Exception
        if ``param_array_origin`` is empty
    """
    from collections import deque

    if mpi_service.rank != 0:
        return None

    job = iota2_step.step_execute()

    returned_data_list = []
    parameters_success = []

    if not param_array_origin:
        raise Exception(
            "JobArray must contain a list of parameter as argument")
    try:
        if logPath is not None and os.path.exists(logPath):
            os.remove(logPath)
        if callable(param_array_origin):
            param_array = param_array_origin()
        else:
            param_array = param_array_origin
        # private copy of the pending tasks: consumed from the left in O(1)
        # and the caller's sequence is never mutated
        pending = deque(param_array)
        if mpi_service.size > 1:
            # master
            nb_completed_tasks = 0
            nb_tasks = len(pending)
            # give one task to every worker rank to start with
            for worker in range(1, mpi_service.size):
                if pending:
                    mpi_service.comm.send(
                        [job, pending.popleft(), logger_lvl, enable_console],
                        dest=worker,
                        tag=0)
            while nb_completed_tasks < nb_tasks:
                [
                    worker_rank,
                    [_start, _end, worker_complete_log, returned_data, success]
                ] = mpi_service.comm.recv(source=MPI.ANY_SOURCE, tag=0)
                returned_data_list.append(returned_data)
                parameters_success.append(success)
                # write worker log
                _append_worker_log(logPath, worker_complete_log)
                nb_completed_tasks += 1
                # keep the worker that just answered busy
                if pending:
                    mpi_service.comm.send(
                        [job, pending.popleft(), logger_lvl, enable_console],
                        dest=worker_rank,
                        tag=0)
        else:
            # not launched through mpirun: run every task in this process
            for param in pending:
                worker_log = sLog.Log_task(logger_lvl, enable_console)
                worker_complete_log, _, _, returned_data, success = launchTask(
                    job, param, worker_log)
                _append_worker_log(logPath, worker_complete_log)
                returned_data_list.append(returned_data)
                parameters_success.append(success)
    except KeyboardInterrupt:
        raise
    except Exception as e:
        print(e)
        parameters_success.append(False)
        if mpi_service.rank == 0 and mpi_service.size > 1:
            print("Something went wrong, we should log errors.")
            traceback.print_exc()
            stop_workers(mpi_service)
            sys.exit(1)

    step_completed = all(parameters_success)
    if step_completed:
        iota2_step.step_clean()
    return returned_data_list, step_completed