Example #1
def build_trainer(args, *rest, **kwargs):
    configuration = Configuration(args.config)

    # Update with the config override if passed
    configuration.override_with_cmd_config(args.config_override)

    # Now, update with opts args that were passed
    configuration.override_with_cmd_opts(args.opts)

    # Finally, update with args that were specifically passed
    # as arguments
    configuration.update_with_args(args)
    configuration.freeze()

    config = configuration.get_config()
    registry.register("config", config)
    registry.register("configuration", configuration)

    trainer_type = config.training_parameters.trainer
    trainer_cls = registry.get_trainer_class(trainer_type)
    trainer_obj = trainer_cls(config)

    # Set args as an attribute for future use
    setattr(trainer_obj, 'args', args)

    return trainer_obj
Example #2
    def load(self):
        self._init_process_group()

        self.run_type = self.config.training_parameters.get("run_type", "train")
        self.task_loader = TaskLoader(self.config)

        self.writer = Logger(self.config)
        registry.register("writer", self.writer)

        self.configuration = registry.get("configuration")
        self.configuration.pretty_print()

        self.config_based_setup()

        self.load_task()
        self.load_model()
        self.load_optimizer()
        self.load_extras()

        # Log a summary of the model size
        self.writer.write("----------MODEL SIZE----------")
        total = 0
        for name, param in self.model.named_parameters():
            self.writer.write(name + str(param.shape))
            total += torch.numel(param)
        self.writer.write("total parameters to train: {}".format(total))

        # init a TensorBoard writer
        self.tb_writer = SummaryWriter(
            os.path.join("save/tb", getattr(self.config.model_attributes, self.config.model).code_name))
Example #3
    def get_data_t(self, t, data, batch_size_t, prev_output):
        if self.teacher_forcing:
            # Modify batch_size for timestep t
            batch_size_t = sum([l > t for l in data["decode_lengths"]])
        elif prev_output is not None and self.config["inference"][
                "type"] == "greedy":
            # Adding t-1 output words to data["text"] for greedy decoding
            output_softmax = torch.log_softmax(prev_output, dim=1)
            _, indices = torch.max(output_softmax, dim=1, keepdim=True)
            data["texts"] = torch.cat(
                (data["texts"], indices.view(batch_size_t, 1)), dim=1)

        # Slice data based on batch_size at timestep t
        data["texts"] = data["texts"][:batch_size_t]
        if "state" in data:
            h1 = data["state"]["td_hidden"][0][:batch_size_t]
            c1 = data["state"]["td_hidden"][1][:batch_size_t]
            h2 = data["state"]["lm_hidden"][0][:batch_size_t]
            c2 = data["state"]["lm_hidden"][1][:batch_size_t]
        else:
            h1, c1 = self.init_hidden_state(data["texts"])
            h2, c2 = self.init_hidden_state(data["texts"])
        data["state"] = {"td_hidden": (h1, c1), "lm_hidden": (h2, c2)}
        registry.register("{}_lstm_state".format(h1.device), data["state"])

        return data, batch_size_t
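The [:batch_size_t] slicing in Example #3 (and in Example #16 below) relies on the batch being sorted by decreasing caption length, so the sequences still active at timestep t always form a prefix of the batch. A toy illustration of that invariant, with made-up lengths:

decode_lengths = [5, 4, 2]  # assumed sorted longest-first, as in the example
for t in range(max(decode_lengths)):
    # Number of sequences still decoding at timestep t; everything past
    # this index in the (length-sorted) batch has already finished
    batch_size_t = sum(l > t for l in decode_lengths)
    print(t, batch_size_t)  # prints 0 3, 1 3, 2 2, 3 2, 4 1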
Example #4
    def __call__(self, sample_list, model_output, *args, **kwargs):
        values = {}
        if not hasattr(sample_list, "targets"):
            return values

        dataset_type = sample_list.dataset_type

        with torch.no_grad():
            for metric_name, metric_object in self.metrics.items():
                key = "{}/{}".format(dataset_type, metric_name)
                values[key] = metric_object._calculate_with_checks(
                    sample_list, model_output, *args, **kwargs)

                if not isinstance(values[key], torch.Tensor):
                    values[key] = torch.tensor(values[key], dtype=torch.float)
                else:
                    values[key] = values[key].float()

                if values[key].dim() == 0:
                    values[key] = values[key].view(1)

        registry.register(
            "{}.{}.{}".format("metrics", sample_list.dataset_name,
                              dataset_type), values)

        return values
Example #5
def setup_imports():
    # Automatically load all of the modules, so that
    # they register with registry
    root_folder = registry.get("pythia_root", no_warning=True)

    if root_folder is None:
        root_folder = os.path.dirname(os.path.abspath(__file__))
        root_folder = os.path.join(root_folder, "..")

        environment_pythia_path = os.environ.get("PYTHIA_PATH")

        if environment_pythia_path is not None:
            root_folder = environment_pythia_path

        root_folder = os.path.join(root_folder, "pythia")
        registry.register("pythia_path", root_folder)

    trainer_folder = os.path.join(root_folder, "trainers")
    trainer_pattern = os.path.join(trainer_folder, "**", "*.py")
    tasks_folder = os.path.join(root_folder, "tasks")
    tasks_pattern = os.path.join(tasks_folder, "**", "*.py")
    model_folder = os.path.join(root_folder, "models")
    model_pattern = os.path.join(model_folder, "**", "*.py")

    importlib.import_module("pythia.common.meter")

    files = glob.glob(tasks_pattern, recursive=True) + \
            glob.glob(model_pattern, recursive=True) + \
            glob.glob(trainer_pattern, recursive=True)

    for f in files:
        if f.endswith("task.py"):
            splits = f.split(os.sep)
            task_name = splits[-2]
            if task_name == "tasks":
                continue
            file_name = splits[-1]
            module_name = file_name[:file_name.find(".py")]
            importlib.import_module("pythia.tasks." + task_name + "." +
                                    module_name)
        elif f.find("models") != -1:
            splits = f.split(os.sep)
            file_name = splits[-1]
            module_name = file_name[:file_name.find(".py")]
            importlib.import_module("pythia.models." + module_name)
        elif f.find("trainer") != -1:
            splits = f.split(os.sep)
            file_name = splits[-1]
            module_name = file_name[:file_name.find(".py")]
            importlib.import_module("pythia.trainers." + module_name)
        elif f.endswith("builder.py"):
            splits = f.split(os.sep)
            task_name = splits[-3]
            dataset_name = splits[-2]
            if task_name == "tasks" or dataset_name == "tasks":
                continue
            file_name = splits[-1]
            module_name = file_name[:file_name.find(".py")]
            importlib.import_module("pythia.tasks." + task_name + "." +
                                    dataset_name + "." + module_name)
Example #6
    def forward(self, sample_list, model_output, *args, **kwargs):
        """Takes in the original ``SampleList`` returned from DataLoader
        and `model_output` returned from the model and returned a Dict containing
        loss for each of the losses in `losses`.

        Args:
            sample_list (SampleList): SampleList given be the dataloader.
            model_output (Dict): Dict returned from model as output.

        Returns:
            Dict: Dictionary containing loss value for each of the loss.

        """
        output = {}
        if not hasattr(sample_list, "targets"):
            if not self._evalai_inference:
                warnings.warn("Sample list has not field 'targets', are you "
                              "sure that your ImDB has labels? you may have "
                              "wanted to run with --evalai_inference 1")
            return output

        for loss in self.losses:
            output.update(loss(sample_list, model_output, *args, **kwargs))

        registry_loss_key = "{}.{}.{}".format("losses",
                                              sample_list.dataset_name,
                                              sample_list.dataset_type)
        # Register the losses to registry
        registry.register(registry_loss_key, output)

        return output
Example #7
    def train(self):
        # self.writer.write("===== Model =====")
        # self.writer.write(self.model)
        if self.run_type == "all_in_one":
            self._all_in_one()
        if self.run_type == "train_viz":
            self._inference_run("train")
            return
        if "train" not in self.run_type:
            self.inference()
            return

        should_break = False

        # Only one stopping criterion is configured: if max_epochs is unset,
        # train purely by iterations; otherwise disable the iteration cap
        if self.max_epochs is None:
            self.max_epochs = math.inf
        else:
            self.max_iterations = math.inf

        self.model.train()
        self.train_timer = Timer()
        self.snapshot_timer = Timer()

        self.profile("Setup Time")

        torch.autograd.set_detect_anomaly(True)

        self.writer.write("Starting training...")
        while self.current_iteration < self.max_iterations and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case it is distributed
            self.task_loader.seed_sampler("train", self.current_epoch)

            if self.current_epoch > self.max_epochs:
                break

            for batch in self.train_loader:
                self.profile("Batch load time")
                self.current_iteration += 1
                self.writer.write(self.current_iteration, "debug")

                registry.register("current_iteration", self.current_iteration)

                if self.current_iteration > self.max_iterations:
                    break

                self._run_scheduler()
                report, _ = self._forward_pass(batch)
                self._update_meter(report, self.meter)
                loss = self._extract_loss(report)
                self._backward(loss)
                should_break = self._logistics(report)

                if should_break:
                    break

        self.finalize()
Example #8
def get_pythia_root():
    from pythia.common.registry import registry

    pythia_root = registry.get("pythia_root", no_warning=True)
    if pythia_root is None:
        pythia_root = os.path.dirname(os.path.abspath(__file__))
        pythia_root = os.path.abspath(os.path.join(pythia_root, ".."))
        registry.register("pythia_root", pythia_root)
    return pythia_root
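setup_imports (Example #5) and get_pythia_root above both use the same get-or-register caching idiom: look the key up with no_warning=True, and compute and register the value only on a miss. Below is a minimal sketch of a registry supporting that idiom; it is illustrative only, not Pythia's actual Registry class (which also supports the nested dot-separated keys used in Examples #4 and #6).

class Registry:
    """Minimal illustrative stand-in for pythia.common.registry.registry."""

    def __init__(self):
        self._store = {}

    def register(self, key, value):
        # Last write wins, matching how several examples re-register "config"
        self._store[key] = value

    def get(self, key, default=None, no_warning=False):
        if key not in self._store and not no_warning:
            print("Key {} was not found in registry".format(key))
        return self._store.get(key, default)

registry = Registry()

# The get-or-register idiom from get_pythia_root:
pythia_root = registry.get("pythia_root", no_warning=True)
if pythia_root is None:
    pythia_root = "/path/to/pythia"  # computed once, cached for later callers
    registry.register("pythia_root", pythia_root)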
Example #9
    def update_registry_for_model(self, config):
        registry.register(
            self.dataset_name + "_text_vocab_size",
            self.dataset.text_processor.get_vocab_size(),
        )
        registry.register(
            self.dataset_name + "_num_final_outputs",
            self.dataset.answer_processor.get_vocab_size(),
        )
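Examples #9 and #10 register per-dataset vocabulary sizes under keys such as "<dataset_name>_text_vocab_size" so that models can size their layers without holding a reference to the dataset. A hedged sketch of the consuming side (the function and its parameters are illustrative, not a specific Pythia model):

import torch.nn as nn

def build_text_embedding_and_head(dataset_name, embedding_dim, hidden_dim):
    # Read back the sizes registered by update_registry_for_model above
    vocab_size = registry.get(dataset_name + "_text_vocab_size")
    num_outputs = registry.get(dataset_name + "_num_final_outputs")
    embedding = nn.Embedding(vocab_size, embedding_dim)
    classifier = nn.Linear(hidden_dim, num_outputs)
    return embedding, classifier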
Example #10
    def update_registry_for_model(self, config):
        # Register vocab (question and answer) sizes to the registry for easy access by models.
        registry.register(
            self.dataset_name + "_text_vocab_size",
            self.dataset.text_processor.get_vocab_size(),
        )

        registry.register(
            self.dataset_name + "_num_final_outputs",
            self.dataset.answer_processor.get_vocab_size() - 1,
        )
Example #11
    def _load(self, dataset_type, config, *args, **kwargs):
        self.config = config

        image_features = config["image_features"]["train"][0].split(",")
        self.num_image_features = len(image_features)

        registry.register("num_image_features", self.num_image_features)

        self.dataset = self.prepare_data_set(dataset_type, config)

        return self.dataset
Example #12
    def setUp(self):
        torch.manual_seed(1234)
        config_path = os.path.join(get_pythia_root(), "..", "configs",
                                   "captioning", "coco",
                                   "butd_nucleus_sampling.yml")
        config_path = os.path.abspath(config_path)
        configuration = Configuration(config_path)
        configuration.config["datasets"] = "coco"
        configuration.config["model_attributes"]["butd"]["inference"][
            "params"]["sum_threshold"] = 0.5
        configuration.freeze()
        self.config = configuration.config
        registry.register("config", self.config)
Example #13
    def init_processors(self):
        if not hasattr(self.config, "processors"):
            return
        extra_params = {"data_root_dir": self.config.data_root_dir}
        for processor_key, processor_params in self.config.processors.items():
            reg_key = "{}_{}".format(self._name, processor_key)
            reg_check = registry.get(reg_key, no_warning=True)

            if reg_check is None:
                processor_object = Processor(processor_params, **extra_params)
                setattr(self, processor_key, processor_object)
                registry.register(reg_key, processor_object)
            else:
                setattr(self, processor_key, reg_check)
Example #14
File: tiki.py  Project: psnonis/TikiAI
    def build_processors(self):

        print('Tiki : Initializing : Building - Text Processors')

        with open('/final/data/pythia.yaml') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        config = ConfigNode(config)
        config.training_parameters.evalai_inference = True  # Remove warning
        registry.register('config', config)

        self.config = config
        vqa_config = config.task_attributes.vqa.dataset_attributes.vqa2
        text_processor_config = vqa_config.processors.text_processor
        answer_processor_config = vqa_config.processors.answer_processor

        text_processor_config.params.vocab.vocab_file = '/final/data/vocabulary_100k.txt'
        answer_processor_config.params.vocab_file = '/final/data/answers_vqa.txt'

        self.text_processor = VocabProcessor(text_processor_config.params)
        self.answer_processor = VQAAnswerProcessor(
            answer_processor_config.params)

        registry.register('vqa2_text_processor', self.text_processor)
        registry.register('vqa2_answer_processor', self.answer_processor)
        registry.register('vqa2_num_final_outputs',
                          self.answer_processor.get_vocab_size())
Example #15
    def _init_processors(self):
        with open(os.path.join(BASE_VQA_DIR_PATH, "model_data/pythia.yaml")) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        config = ConfigNode(config)
        # Remove warning
        config.training_parameters.evalai_inference = True
        registry.register("config", config)

        self.config = config

        vqa_config = config.task_attributes.vqa.dataset_attributes.vqa2
        text_processor_config = vqa_config.processors.text_processor
        answer_processor_config = vqa_config.processors.answer_processor

        text_processor_config.params.vocab.vocab_file = os.path.join(
            BASE_VQA_DIR_PATH, "model_data/vocabulary_100k.txt"
        )
        answer_processor_config.params.vocab_file = os.path.join(
            BASE_VQA_DIR_PATH, "model_data/answers_vqa.txt"
        )
        # Add the preprocessor, as it will be needed when we get questions from the user
        self.text_processor = VocabProcessor(text_processor_config.params)
        self.answer_processor = VQAAnswerProcessor(answer_processor_config.params)

        registry.register("vqa2_text_processor", self.text_processor)
        registry.register("vqa2_answer_processor", self.answer_processor)
        registry.register(
            "vqa2_num_final_outputs", self.answer_processor.get_vocab_size()
        )
Example #16
    def get_data_t(self, data, batch_size_t):
        data["texts"] = data["texts"][:batch_size_t]
        if "state" in data:
            h1 = data["state"]["td_hidden"][0][:batch_size_t]
            c1 = data["state"]["td_hidden"][1][:batch_size_t]
            h2 = data["state"]["lm_hidden"][0][:batch_size_t]
            c2 = data["state"]["lm_hidden"][1][:batch_size_t]
        else:
            h1, c1 = self.init_hidden_state(data["texts"])
            h2, c2 = self.init_hidden_state(data["texts"])
        data["state"] = {"td_hidden": (h1, c1), "lm_hidden": (h2, c2)}
        registry.register("{}_lstm_state".format(h1.device), data["state"])

        return data, batch_size_t
Example #17
    def load_model(self):
        attributes = self.config.model_attributes[self.config.model]
        # Easy way to point to config for other model
        if isinstance(attributes, str):
            attributes = self.config.model_attributes[attributes]

        attributes["model"] = self.config.model

        self.task_loader.update_registry_for_model(attributes)
        self.model = build_model(attributes)
        self.task_loader.clean_config(attributes)
        training_parameters = self.config.training_parameters

        data_parallel = training_parameters.data_parallel
        distributed = training_parameters.distributed

        registry.register("data_parallel", data_parallel)
        registry.register("distributed", distributed)

        if "cuda" in str(self.config.training_parameters.device):
            rank = self.local_rank if self.local_rank is not None else 0
            device_info = "CUDA Device {} is: {}".format(
            rank, torch.cuda.get_device_name(self.local_rank)
            )

            self.writer.write(device_info, log_all=True)

        self.model = self.model.to(self.device)

        self.writer.write("Torch version is: " + torch.__version__)

        if (
            "cuda" in str(self.device)
            and torch.cuda.device_count() > 1
            and data_parallel is True
        ):
            print("parallel!")
            self.model = torch.nn.DataParallel(self.model)

        if (
            "cuda" in str(self.device)
            and self.local_rank is not None
            and distributed is True
        ):
            torch.cuda.set_device(self.local_rank)
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[self.local_rank]
            )
Example #18
def build_caption_model(caption_config: Dict, cuda_device: torch.device):
    """

    Parameters
    ----------
    caption_config : Dict
        Dict of BUTD and Detectron model configuration.
    cuda_device : torch.device
        Torch device to load the model to.

    Returns
    -------
    (model, caption_processor, text_processor) : List[object]
        Returns the model, caption and text processor


    """
    with open(caption_config["butd_model"]["config_yaml"]) as f:
        butd_config = yaml.load(f, Loader=yaml.FullLoader)
    butd_config = ConfigNode(butd_config)
    butd_config.training_parameters.evalai_inference = True
    registry.register("config", butd_config)

    caption_processor, text_processor = init_processors(
        caption_config, butd_config)

    if cuda_device == torch.device('cpu'):
        state_dict = torch.load(caption_config["butd_model"]["model_pth"],
                                map_location='cpu')
    else:
        state_dict = torch.load(caption_config["butd_model"]["model_pth"])

    model_config = butd_config.model_attributes.butd
    model_config.model_data_dir = caption_config["model_data_dir"]
    model = BUTD(model_config)
    model.build()
    model.init_losses_and_metrics()

    if list(state_dict.keys())[0].startswith('module') and \
            not hasattr(model, 'module'):
        state_dict = multi_gpu_state_to_single(state_dict)

    model.load_state_dict(state_dict)
    model.to(cuda_device)
    model.eval()

    return model, caption_processor, text_processor
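The startswith('module') check above handles checkpoints saved from a DataParallel or DistributedDataParallel wrapper, whose parameter keys carry a "module." prefix. A plausible implementation of multi_gpu_state_to_single under that assumption (not necessarily the project's exact helper):

def multi_gpu_state_to_single(state_dict):
    # Strip the "module." prefix that (Distributed)DataParallel adds to keys
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }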
Example #19
File: trainer.py  Project: ronghanghu/mmf
    def load(self):
        self.load_config()
        self._init_process_group()

        self.run_type = self.config.training_parameters.get("run_type", "train")
        self.task_loader = TaskLoader(self.config)

        self.writer = Logger(self.config)
        registry.register("writer", self.writer)

        self.configuration.pretty_print()

        self.config_based_setup()

        self.load_task()
        self.load_model()
        self.load_optimizer()
        self.load_extras()
Example #20
    def _init_process_group(self):
        training_parameters = self.config.training_parameters
        self.local_rank = training_parameters.local_rank
        self.device = training_parameters.device

        if self.local_rank is not None and training_parameters.distributed:
            if not torch.distributed.is_nccl_available():
                raise RuntimeError(
                    "Unable to initialize process group: NCCL is not available"
                )
            torch.distributed.init_process_group(backend="nccl")
            synchronize()

        if ("cuda" in self.device and training_parameters.distributed
                and self.local_rank is not None):
            self.device = torch.device("cuda", self.local_rank)

        registry.register("current_device", self.device)
Example #21
File: trainer.py  Project: ronghanghu/mmf
    def load_config(self):
        # TODO: Review configuration update once again
        # (remember clip_gradients case)
        self.configuration = Configuration(self.args.config)

        # Update with the config override if passed
        self.configuration.override_with_cmd_config(self.args.config_override)

        # Now, update with opts args that were passed
        self.configuration.override_with_cmd_opts(self.args.opts)

        # Finally, update with args that were specifically passed
        # as arguments
        self.configuration.update_with_args(self.args)
        self.configuration.freeze()

        self.config = self.configuration.get_config()
        registry.register("config", self.config)
Example #22
def init_processors(caption_config: Dict, butd_config: Dict):
    """Build the caption and text processors.

    """
    captioning_config = butd_config.task_attributes.captioning \
        .dataset_attributes.coco
    text_processor_config = captioning_config.processors.text_processor
    caption_processor_config = captioning_config.processors \
        .caption_processor
    vocab_file_path = caption_config["text_caption_processor_vocab_txt"]
    text_processor_config.params.vocab.vocab_file = vocab_file_path
    caption_processor_config.params.vocab.vocab_file = vocab_file_path
    text_processor = VocabProcessor(text_processor_config.params)
    caption_processor = CaptionProcessor(caption_processor_config.params)

    registry.register("coco_text_processor", text_processor)
    registry.register("coco_caption_processor", caption_processor)

    return caption_processor, text_processor
Example #23
    def __init__(self, use_constrained=False):
        super(PythiaCaptioner, self).__init__()
        # Load the configuration file
        with open(config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        config = ConfigNode(config)

        self.use_constrained = use_constrained

        # The following block reads some configuration
        # parameters in Pythia
        config.training_parameters.evalai_inference = True
        registry.register("config", config)
        self.config = config

        captioning_config = config.task_attributes.captioning.dataset_attributes.coco
        text_processor_config = captioning_config.processors.text_processor
        caption_processor_config = captioning_config.processors.caption_processor
        # text_processor and caption_processor are used to pre-process the text
        text_processor_config.params.vocab.vocab_file = vocab_file
        caption_processor_config.params.vocab.vocab_file = vocab_file
        self.text_processor = VocabProcessor(text_processor_config.params)
        self.caption_processor = CaptionProcessor(
            caption_processor_config.params)

        registry.register("coco_text_processor", self.text_processor)
        registry.register("coco_caption_processor", self.caption_processor)

        self.model = self._build_model()
Example #24
File: main.py  Project: ascott02/pythia
  def _init_processors(self):
    with open(model_yaml) as f:
      config = yaml.load(f, Loader=yaml.FullLoader)

    config = ConfigNode(config)
    # Remove warning
    config.training_parameters.evalai_inference = True
    registry.register("config", config)

    self.config = config

    captioning_config = config.task_attributes.captioning.dataset_attributes.coco
    # captioning_config = config.task_attributes.captioning.dataset_attributes.youcookII
    text_processor_config = captioning_config.processors.text_processor
    caption_processor_config = captioning_config.processors.caption_processor
    # print("DEBUG captioning_config:", captioning_config)
    # print("DEBUG text_processor_config:", text_processor_config)
    # print("DEBUG caption_processor_config:", caption_processor_config)

    text_processor_config.params.vocab.vocab_file = "content/model_data/vocabulary_captioning_thresh5.txt"
    caption_processor_config.params.vocab.vocab_file = "content/model_data/vocabulary_captioning_thresh5.txt"
    self.text_processor = VocabProcessor(text_processor_config.params)
    self.caption_processor = CaptionProcessor(caption_processor_config.params)
    # print("DEBUG text_processor:", self.text_processor)
    # print("DEBUG caption_processor:", self.caption_processor)

    registry.register("coco_text_processor", self.text_processor)
    registry.register("coco_caption_processor", self.caption_processor)
Example #25
    def __init__(self, use_constrained=False):
        super(PythiaCaptioner, self).__init__()
        # Load the configuration file
        with open(config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        config = ConfigNode(config)

        self.use_constrained = use_constrained

        # TODO: not sure what these two lines really mean
        config.training_parameters.evalai_inference = True
        registry.register("config", config)
        self.config = config

        captioning_config = config.task_attributes.captioning.dataset_attributes.coco
        text_processor_config = captioning_config.processors.text_processor
        caption_processor_config = captioning_config.processors.caption_processor

        text_processor_config.params.vocab.vocab_file = vocab_file
        caption_processor_config.params.vocab.vocab_file = vocab_file
        self.text_processor = VocabProcessor(text_processor_config.params)
        self.caption_processor = CaptionProcessor(
            caption_processor_config.params)

        registry.register("coco_text_processor", self.text_processor)
        registry.register("coco_caption_processor", self.caption_processor)

        self.model = self._build_model()
Example #26
    def test_caption_bleu4(self):
        path = os.path.join(
            os.path.abspath(__file__),
            "../../../pythia/common/defaults/configs/datasets/captioning/coco.yml",
        )
        with open(os.path.abspath(path)) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        config = ConfigNode(config)
        captioning_config = config.dataset_attributes.coco
        caption_processor_config = captioning_config.processors.caption_processor
        vocab_path = os.path.join(os.path.abspath(__file__), "..", "..",
                                  "data", "vocab.txt")
        caption_processor_config.params.vocab.vocab_file = os.path.abspath(
            vocab_path)
        caption_processor = CaptionProcessor(caption_processor_config.params)
        registry.register("coco_caption_processor", caption_processor)

        caption_bleu4 = metrics.CaptionBleu4Metric()
        expected = Sample()
        predicted = dict()

        # Test complete match
        expected.answers = torch.empty((5, 5, 10))
        expected.answers.fill_(4)
        predicted["scores"] = torch.zeros((5, 10, 19))
        predicted["scores"][:, :, 4] = 1.0

        self.assertEqual(
            caption_bleu4.calculate(expected, predicted).item(), 1.0)

        # Test partial match
        expected.answers = torch.empty((5, 5, 10))
        expected.answers.fill_(4)
        predicted["scores"] = torch.zeros((5, 10, 19))
        predicted["scores"][:, 0:5, 4] = 1.0

        self.assertAlmostEqual(
            caption_bleu4.calculate(expected, predicted).item(), 0.3928, 4)
Example #27
    def _update_meter(self, report, meter=None, eval_mode=False):
        if meter is None:
            meter = self.meter

        loss_dict = report.losses
        metrics_dict = report.metrics

        reduced_loss_dict = reduce_dict(loss_dict)
        reduced_metrics_dict = reduce_dict(metrics_dict)

        loss_key = report.dataset_type + "/total_loss"

        with torch.no_grad():
            reduced_loss = sum([loss.mean() for loss in reduced_loss_dict.values()])
            if hasattr(reduced_loss, "item"):
                reduced_loss = reduced_loss.item()

            registry.register(loss_key, reduced_loss)

            meter_update_dict = {loss_key: reduced_loss}
            meter_update_dict.update(reduced_loss_dict)
            meter_update_dict.update(reduced_metrics_dict)
            meter.update(meter_update_dict)
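reduce_dict above averages each loss and metric value across distributed workers before the meter update, so every rank logs the same numbers. A sketch of such a helper, assuming scalar tensor values and following the common all_reduce pattern (not necessarily mmf's exact implementation):

import torch
import torch.distributed as dist

def reduce_dict(dictionary):
    # No-op outside of (initialized) distributed training
    if not (dist.is_available() and dist.is_initialized()):
        return dictionary
    world_size = dist.get_world_size()
    if world_size < 2:
        return dictionary
    with torch.no_grad():
        # Sort keys so every worker reduces values in the same order
        keys = sorted(dictionary.keys())
        values = torch.stack([dictionary[k].detach().reshape(1) for k in keys])
        dist.all_reduce(values)  # sums across workers
        values /= world_size     # turn the sum into an average
        return {k: v for k, v in zip(keys, values)}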
Example #28
    def load(self):
        self._init_process_group()

        self.run_type = self.config.training_parameters.get(
            "run_type", "train")
        self.dataset_loader = DatasetLoader(self.config)
        self._datasets = self.config.datasets

        self.writer = Logger(self.config)
        registry.register("writer", self.writer)

        self.configuration = registry.get("configuration")
        self.configuration.pretty_print()

        self.config_based_setup()

        self.load_task()
        self.load_model()
        self.load_optimizer()
        self.load_extras()
        if visualization_flag:
            self.generator = GenerateWord(
                'data/m4c_captioner_vocabs/textcaps/vocab_textcap_threshold_10.txt'
            )
Example #29
    def setUp(self):
        torch.manual_seed(1234)
        registry.register("clevr_text_vocab_size", 80)
        registry.register("clevr_num_final_outputs", 32)
        config_path = os.path.join(get_pythia_root(), "..", "configs", "vqa",
                                   "clevr", "cnn_lstm.yml")
        config_path = os.path.abspath(config_path)
        configuration = Configuration(config_path)
        configuration.config["datasets"] = "clevr"
        configuration.freeze()
        self.config = configuration.config
        registry.register("config", self.config)
Example #30
    def _init_processors(self):
        with open("model_data/butd.yaml") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        config = ConfigNode(config)
        # Remove warning
        config.training_parameters.evalai_inference = True
        registry.register("config", config)

        self.config = config

        captioning_config = config.task_attributes.captioning.dataset_attributes.coco
        text_processor_config = captioning_config.processors.text_processor
        caption_processor_config = captioning_config.processors.caption_processor

        text_processor_config.params.vocab.vocab_file = "model_data/vocabulary_captioning_thresh5.txt"
        caption_processor_config.params.vocab.vocab_file = "model_data/vocabulary_captioning_thresh5.txt"
        self.text_processor = VocabProcessor(text_processor_config.params)
        self.caption_processor = CaptionProcessor(caption_processor_config.params)

        registry.register("coco_text_processor", self.text_processor)
        registry.register("coco_caption_processor", self.caption_processor)