コード例 #1
0
    def __init__(self, dataset_type, config):
        """Instantiate every task named in ``config["tasks"]`` and set weights.

        Args:
            dataset_type: Stored on the instance and handed to each task's
                ``load`` call as the ``dataset_type`` keyword.
            config: Mapping read for ``"tasks"`` (comma-separated task
                names), ``"task_attributes"`` (per-task keyword arguments
                for ``load``) and ``"training_parameters"``.
        """
        super(MultiTask, self).__init__()

        self.config = config
        self.dataset_type = dataset_type

        # Materialise a list: on Python 3 ``map`` yields a one-shot iterator,
        # which would leave ``self.task_names`` empty after the loop below.
        self.task_names = [
            name.strip() for name in self.config["tasks"].split(",")
        ]

        self.tasks = []
        self.tasks_lens = []

        for task_name in self.task_names:
            task_class = registry.get_task_class(task_name)
            if task_class is None:
                print("[Error] %s not present in our mapping" % task_name)
                # NOTE: returning here leaves ``task_probabilities`` and
                # ``num_tasks`` unset and skips ``change_task()``.
                return

            if task_name not in self.config["task_attributes"]:
                print("[Error] No attributes present for task %s in config."
                      " Skipping" % task_name)
                # Actually skip, as the message says; falling through would
                # raise KeyError on the lookup below.
                continue

            # The dataset type is written into the config's own entry for
            # this task before the entry is unpacked into ``load``.
            task_attributes = self.config["task_attributes"][task_name]
            task_attributes["dataset_type"] = self.dataset_type

            task = task_class()
            task.load(**task_attributes)

            self.tasks.append(task)
            self.tasks_lens.append(len(task))

        # Default: every loaded task carries the same weight.
        self.task_probabilities = [1 for _ in self.tasks]

        self.num_tasks = len(self.tasks)

        training_parameters = self.config["training_parameters"]
        if training_parameters["task_size_proportional_sampling"]:
            # Weight each task by its share of the total length.
            len_sum = sum(self.tasks_lens)
            self.task_probabilities = [
                task_len / len_sum for task_len in self.tasks_lens
            ]

        self.change_task()
コード例 #2
0
ファイル: flags.py プロジェクト: sameerdharur/sorting-vqa
    def update_task_args(self):
        """Let each task named after ``--tasks`` in ``sys.argv`` add its args.

        Scans ``sys.argv`` for ``--tasks``; the token that follows it is
        split on commas into task names. For each name the task class is
        looked up via ``registry.get_task_class``, instantiated, and its
        ``init_args`` is called with ``self.parser``.

        Returns early (doing nothing further) when ``--tasks`` is absent,
        when it is the last token and so has no value, or as soon as a task
        name has no registered class.
        """
        args = sys.argv
        task_names = None
        for index, item in enumerate(args):
            # Bounds-check the lookahead so a trailing ``--tasks`` without a
            # value does not raise IndexError. The last occurrence wins.
            if item == "--tasks" and index + 1 < len(args):
                task_names = args[index + 1]

        if task_names is None:
            return

        for raw_name in task_names.split(","):
            task_class = registry.get_task_class(raw_name.strip())
            if task_class is None:
                return

            task_object = task_class()
            task_object.init_args(self.parser)