def _threaded_read(self):
    """Warm the in-memory cache by reading every imdb entry on a worker thread.

    The progress bar is shown only on the main process. Fixes the original
    progress accounting (it added 100 on the very first item and dropped the
    final partial batch) and joins the pool instead of leaking its worker.
    """
    # NOTE(review): index 0 is skipped — presumably the first imdb record is
    # header/metadata, not a sample. Confirm against the imdb format.
    elements = list(range(1, len(self.imdb)))
    pool = ThreadPool(processes=1)
    try:
        with tqdm.tqdm(total=len(elements),
                       disable=not is_main_process()) as pbar:
            completed = 0
            for completed, _ in enumerate(
                    pool.imap_unordered(self._fill_cache, elements), start=1):
                # Batch bar updates to avoid a redraw per element.
                if completed % 100 == 0:
                    pbar.update(100)
            # Flush the final partial batch so the bar reaches 100%.
            pbar.update(completed % 100)
    finally:
        pool.close()
        pool.join()
def load(self, **opts):
    """Build and load every requested dataset, then activate the first one.

    Populates ``self.datasets``/``self.builders`` plus length bookkeeping and
    per-dataset sampling probabilities, and exits the process when a requested
    dataset has no attributes in the config.
    """
    self.opts = opts
    self._process_datasets()

    self.datasets = []
    self.builders = []
    self.total_length = 0
    self.per_dataset_lengths = []
    self.num_datasets = 0

    available_datasets = self._get_available_datasets()

    for dataset in self.given_datasets:
        if dataset not in available_datasets:
            print(
                "Dataset %s is not a valid dataset for task %s. Skipping"
                % (dataset, self.task_name)
            )
            continue

        builder_class = registry.get_builder_class(dataset)
        if builder_class is None:
            print("No builder class found for %s." % dataset)
            continue
        builder_instance = builder_class()

        if dataset not in self.opts["dataset_attributes"]:
            self.writer.write(
                "Dataset %s is missing from "
                "dataset_attributes in config." % dataset,
                "error",
            )
            sys.exit(1)
        attributes = self.opts["dataset_attributes"][dataset]

        dataset_type = self.opts.get("dataset_type", "train")

        # Build artifacts on the main process only; everyone else waits at
        # the barrier, then loads the already-built dataset.
        if is_main_process():
            builder_instance.build(dataset_type, attributes)
        synchronize()
        dataset_instance = builder_instance.load(dataset_type, attributes)

        self.builders.append(builder_instance)
        self.datasets.append(dataset_instance)
        self.per_dataset_lengths.append(len(dataset_instance))
        self.total_length += len(dataset_instance)

    self.num_datasets = len(self.datasets)

    # Default: uniform sampling weights. (Attribute name kept as-is for
    # backward compatibility, typo and all.)
    self.dataset_probablities = [1] * self.num_datasets
    if self.opts.get("dataset_size_proportional_sampling", None) is True:
        self.dataset_probablities = [
            length / self.total_length for length in self.per_dataset_lengths
        ]

    self.change_dataset()
def flush_report(self):
    """Write accumulated evalai predictions to a timestamped JSON file.

    Only the main process writes; the in-memory report is cleared afterwards.
    """
    if not is_main_process():
        return

    name = self.current_dataset._name
    timestamp = self.timer.get_time_hhmmss(None, format="%Y-%m-%dT%H:%M:%S")

    # <name>[_<experiment>]_<task_type>_<timestamp>.json
    parts = [name]
    if len(self.experiment_name) > 0:
        parts.append(self.experiment_name)
    parts.append(self.task_type)
    filename = "_".join(parts) + "_" + timestamp + ".json"

    filepath = os.path.join(self.report_folder, filename)
    with open(filepath, "w") as f:
        json.dump(self.report, f)

    self.writer.write(
        "Wrote evalai predictions for %s to %s"
        % (name, os.path.abspath(filepath))
    )
    self.report = []
def _update_specific(self, args):
    """Apply command-line overrides that need special handling to the config.

    Handles seeding, learning-rate override, cuda-to-cpu fallback, and the
    distributed/data_parallel mutual exclusion.
    """
    self.writer = registry.get("writer")
    tp = self.config["training_parameters"]

    # Warn whenever seeding is requested via CLI or config.
    if args["seed"] is not None or tp["seed"] is not None:
        print(
            "You have chosen to seed the training. This will turn on CUDNN deterministic "
            "setting which can slow down your training considerably! You may see unexpected "
            "behavior when restarting from checkpoints.")

    # -1 is the sentinel for "pick a random seed".
    if args["seed"] == -1:
        tp["seed"] = random.randint(1, 1000000)

    if "learning_rate" in args:
        has_optimizer_params = (
            "optimizer" in self.config and "params" in self.config["optimizer"]
        )
        if has_optimizer_params:
            self.config["optimizer_attributes"]["params"]["lr"] = args[
                "learning_rate"]

    # Fall back to CPU when cuda was requested but is unavailable.
    if "cuda" in tp["device"] and not torch.cuda.is_available():
        if is_main_process():
            print("WARNING: Device specified is 'cuda' but cuda is "
                  "not present. Switching to CPU version")
        tp["device"] = "cpu"

    if tp["distributed"] is True and tp["data_parallel"] is True:
        print("training_parameters.distributed and "
              "training_parameters.data_parallel are "
              "mutually exclusive. Setting "
              "training_parameters.distributed to False")
        tp["distributed"] = False
def _try_download(self):
    """Ensure the fastText model is available locally, downloading if needed.

    Fixes a crash in the original: ``self.config.model_file`` was read
    unconditionally right after ``hasattr`` reported the key missing, which
    raised AttributeError instead of falling through to the download path.
    """
    _is_main_process = is_main_process()

    if self._already_downloaded:
        return

    if _is_main_process:
        self.writer.write("Fetching fastText model for OCR processing")

    needs_download = False
    model_file = None

    if not hasattr(self.config, "model_file"):
        if _is_main_process:
            warnings.warn("'model_file' key is required but missing "
                          "from FastTextProcessor's config.")
        needs_download = True
    else:
        # Resolve the configured path relative to the pythia root.
        model_file = os.path.join(get_pythia_root(), self.config.model_file)
        if not os.path.exists(model_file):
            if _is_main_process:
                warnings.warn(
                    "No model file present at {}.".format(model_file))
            needs_download = True

    if needs_download:
        if _is_main_process:
            self.writer.write("Downloading FastText bin", "info")
        # Non-main processes get the expected path back immediately and
        # wait at the barrier below for the main process to finish.
        model_file = self._download_model()
        synchronize()

    self._load_fasttext_model(model_file)
    self._already_downloaded = True
def save(self, iteration, update_best=False):
    """Checkpoint model/optimizer state; optionally refresh the best checkpoint.

    Args:
        iteration (int): Current iteration, embedded in the checkpoint name.
        update_best (bool): When True, also overwrite ``best.ckpt``.
    """
    # Only save in main process
    if not is_main_process():
        return

    ckpt_filepath = os.path.join(
        self.models_foldername, "model_%d.ckpt" % iteration)
    best_ckpt_filepath = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "best.ckpt")

    early_stopping = self.trainer.early_stopping

    model = self.trainer.model
    # Unwrap the (Distributed)DataParallel wrapper so the saved state_dict
    # keys are portable.
    if (registry.get("data_parallel") or registry.get("distributed")) is True:
        model = model.module

    ckpt = {
        "model": model.state_dict(),
        "optimizer": self.trainer.optimizer.state_dict(),
        "best_iteration": early_stopping.best_monitored_iteration,
        "best_metric_value": early_stopping.best_monitored_value,
        "config": self.config,
    }
    # Record VCS metadata (branch/commit) alongside the weights.
    ckpt.update(self._get_vcs_fields())

    torch.save(ckpt, ckpt_filepath)
    if update_best:
        torch.save(ckpt, best_ckpt_filepath)
def save(self, iteration, update_best=False):
    """Synchronize all workers, then checkpoint from the main process only.

    Args:
        iteration (int): Current iteration, embedded in the checkpoint name.
        update_best (bool): When True, also overwrite ``best.ckpt``.
    """
    # Sync all models before we start the save process
    synchronize()

    # Every non-main process is done once past the barrier.
    if not is_main_process():
        return

    models_path = os.path.join(
        self.models_foldername, "model_%d.ckpt" % iteration)
    best_path = os.path.join(
        self.ckpt_foldername, self.ckpt_prefix + "best.ckpt")

    early_stopping = self.trainer.early_stopping
    ckpt = {
        "model": self.trainer.model.state_dict(),
        "optimizer": self.trainer.optimizer.state_dict(),
        "best_iteration": early_stopping.best_monitored_iteration,
        "best_metric_value": early_stopping.best_monitored_value,
        "config": self.config,
    }
    # Attach VCS metadata (branch/commit) for reproducibility.
    ckpt.update(self._get_vcs_fields())

    torch.save(ckpt, models_path)
    if update_best:
        torch.save(ckpt, best_path)
def _merge_from_list(self, opts):
    """Merge dotted-path override pairs ``[key, value, key, value, ...]`` into config.

    Args:
        opts (list): Flat list of alternating option paths ("a.b.c") and values.

    Raises:
        AttributeError: When a path component is missing, or a non-terminal
            component is not a mapping.

    Fixes two defects in the original:
    - An override whose existing value was itself a mapping was silently
      ignored: the loop descended into it and never assigned the new value.
    - The non-terminal error passed two separate args to AttributeError
      instead of one formatted message.
    """
    if opts is None:
        opts = []

    assert len(opts) % 2 == 0, "Number of opts should be multiple of 2"

    for opt, value in zip(opts[0::2], opts[1::2]):
        splits = opt.split(".")
        current = self.config
        for idx, field in enumerate(splits):
            if field not in current:
                raise AttributeError("While updating configuration"
                                     " option {} is missing from"
                                     " configuration at field {}".format(
                                         opt, field))

            if idx == len(splits) - 1:
                # Terminal field: always assign, even over a mapping value.
                if is_main_process():
                    print("Overriding option {} to {}".format(opt, value))
                current[field] = self._decode_value(value)
            elif isinstance(current[field], collections.abc.Mapping):
                current = current[field]
            else:
                raise AttributeError(
                    "While updating configuration, "
                    "option {} is not present "
                    "after field {}".format(opt, field))
def _update_specific(self, args):
    """Apply command-line overrides that need special handling to the config.

    Fixes a crash in the original: ``args["seed"]`` may be None when the flag
    is not passed (the sibling implementation checks ``is not None``), and
    ``None <= 0`` raises TypeError.
    """
    seed = args["seed"]
    # Non-positive seed is the sentinel for "pick a random seed".
    if seed is not None and seed <= 0:
        self.config["training_parameters"]["seed"] = random.randint(
            1, 1000000)

    if "learning_rate" in args:
        if "optimizer" in self.config and "params" in self.config[
                "optimizer"]:
            lr = args["learning_rate"]
            self.config["optimizer_attributes"]["params"]["lr"] = lr

    # Fall back to CPU when cuda was requested but is unavailable.
    if (not torch.cuda.is_available()
            and "cuda" in self.config["training_parameters"]["device"]):
        if is_main_process():
            print("WARNING: Device specified is 'cuda' but cuda is "
                  "not present. Switching to CPU version")
        self.config["training_parameters"]["device"] = "cpu"

    tp = self.config["training_parameters"]
    if tp["distributed"] is True and tp["data_parallel"] is True:
        self.writer.write("training_parameters.distributed and "
                          "training_parameters.data_parallel are "
                          "mutually exclusive. Setting "
                          "training_parameters.distributed to False")
        tp["distributed"] = False
def add_to_report(self, report):
    """Gather prediction tensors across processes; main process keeps results."""
    # TODO: Later gather whole report for no opinions
    last_dim = report.scores.size(-1)
    report.scores = gather_tensor(report.scores).view(-1, last_dim)
    report.question_id = gather_tensor(report.question_id).view(-1)

    # Only the main process accumulates formatted results.
    if not is_main_process():
        return

    self.report = self.report + self.current_dataset.format_for_evalai(report)
def try_fast_read(self):
    """Pre-load every sample into an in-memory cache when fast read is enabled."""
    # Don't fast read in case of test set.
    if self._dataset_type == "test":
        return

    # Only proceed when the flag exists and is explicitly True.
    if getattr(self, "_should_fast_read", None) is not True:
        return

    self.writer.write("Starting to fast read {} {} dataset".format(
        self._name, self._dataset_type))
    progress = tqdm.tqdm(range(len(self.imdb)), miniters=100,
                         disable=not is_main_process())
    self.cache = {idx: self.load_item(idx) for idx in progress}
def _load_fasttext_model(self, model_file):
    """Load the fastText binary model and wrap it for word-to-vector lookup.

    Args:
        model_file (str): Path to the fastText ``.bin`` model on disk.
    """
    from fasttext import load_model

    on_main = is_main_process()
    if on_main:
        self.writer.write("Loading fasttext model now from %s" % model_file)

    self.model = load_model(model_file)
    # String to Vector
    self.stov = WordToVectorDict(self.model)

    if on_main:
        self.writer.write("Finished loading fasttext model")
def run():
    """Entry point: build a trainer from CLI flags, load it, and train."""
    setup_imports()
    args = flags.get_parser().parse_args()
    trainer = build_trainer(args)

    # Log any errors that occur to log file
    try:
        trainer.load()
        trainer.train()
    except Exception as err:
        log_writer = getattr(trainer, "writer", None)
        if log_writer is not None:
            log_writer.write(err, "error", donot_print=True)
        # Re-raise only on the main process.
        if is_main_process():
            raise
def add_to_report(self, report):
    """Gather per-process prediction tensors; main process accumulates results.

    Captioning datasets ("coco") gather captions/image ids; everything else
    gathers scores/question ids.
    """
    # TODO: Later gather whole report for no opinions
    if self.current_dataset._name == "coco":
        report.captions = gather_tensor(report.captions)
        if isinstance(report.image_id, torch.Tensor):
            report.image_id = gather_tensor(report.image_id).view(-1)
    else:
        last_dim = report.scores.size(-1)
        report.scores = gather_tensor(report.scores).view(-1, last_dim)
        report.question_id = gather_tensor(report.question_id).view(-1)

    # Only the main process keeps formatted results.
    if not is_main_process():
        return

    self.report = self.report + self.current_dataset.format_for_evalai(report)
def _load(self):
    """Load the questions JSON for this split and build vocabs on main process."""
    self.image_path = os.path.join(
        self._data_folder, _CONSTANTS["images_folder"], self._dataset_type)

    questions_path = os.path.join(
        self._data_folder,
        _CONSTANTS["questions_folder"],
        _TEMPLATES["question_json_file"].format(self._dataset_type),
    )
    with open(questions_path) as f:
        self.questions = json.load(f)[_CONSTANTS["questions_key"]]

    # Vocab should only be built in the main process, so the same work is
    # not repeated by every worker; everyone else waits at the barrier.
    if is_main_process():
        self._build_vocab(self.questions, _CONSTANTS["question_key"])
        self._build_vocab(self.questions, _CONSTANTS["answer_key"])
    synchronize()
def _download_model(self):
    """Download the pretrained fastText ``wiki.en.bin`` if not already cached.

    Returns:
        str: Local path to the model file. Non-main processes return the
        expected path immediately and rely on the main process to download.

    Fixes the original progress bar, whose ``total`` was a chunk count while
    updates were byte counts — the two units never matched, so the bar was
    meaningless. Also drops the ``_is_main_process`` checks that were dead
    code after the early return above.
    """
    model_file_path = os.path.join(
        get_pythia_root(), ".vector_cache", "wiki.en.bin")

    # Everything below runs on the main process only.
    if not is_main_process():
        return model_file_path

    if os.path.exists(model_file_path):
        self.writer.write(
            "Vectors already present at {}.".format(model_file_path), "info")
        return model_file_path

    import requests
    from pythia.common.constants import FASTTEXT_WIKI_URL
    from tqdm import tqdm

    os.makedirs(os.path.dirname(model_file_path), exist_ok=True)
    response = requests.get(FASTTEXT_WIKI_URL, stream=True)

    with open(model_file_path, "wb") as f:
        # Total and updates are both in bytes now; miniters throttles redraws.
        pbar = tqdm(
            total=int(response.headers["Content-Length"]),
            unit="B",
            unit_scale=True,
            miniters=50,
        )
        for data in response.iter_content(chunk_size=4096):
            if data:
                pbar.update(len(data))
                f.write(data)
        pbar.close()

    self.writer.write(
        "fastText bin downloaded at {}.".format(model_file_path), "info")

    return model_file_path
def build(self, dataset_type, config, *args, **kwargs): """ Similar to load function, used by Pythia to build a dataset for first time when it is not available. This internally calls '_build' function. Override that function in your child class. Args: dataset_type (str): Type of dataset, train|val|test config (ConfigNode): Configuration of this dataset loaded from config. .. warning:: DO NOT OVERRIDE in child class. Instead override ``_build``. """ # Only build in main process, so none of the others have to build if is_main_process(): self._build(dataset_type, config, *args, **kwargs) synchronize()
def __call__(self, iteration, meter):
    """
    Method to be called everytime you need to check whether to
    early stop or not
    Arguments:
        iteration {number}: Current iteration number
    Returns:
        bool -- Tells whether early stopping occurred or not
    """
    # Non-main processes never drive the stopping decision.
    if not is_main_process():
        return False

    value = meter.meters.get(self.monitored_metric, None)
    if value is None:
        raise ValueError(
            "Metric used for early stopping ({}) is not "
            "present in meter.".format(self.monitored_metric)
        )

    value = value.global_avg
    if isinstance(value, torch.Tensor):
        value = value.item()

    improved = (
        value < self.best_monitored_value
        if self.minimize
        else value > self.best_monitored_value
    )

    if improved:
        # New best: remember it and checkpoint as the best model.
        self.best_monitored_value = value
        self.best_monitored_iteration = iteration
        self.checkpoint.save(iteration, update_best=True)
    elif self.best_monitored_iteration + self.patience < iteration:
        # Patience exhausted.
        self.activated = True
        if self.should_stop is True:
            self.checkpoint.restore()
            self.checkpoint.finalize()
            return True
        return False
    else:
        # Not improved but still within patience: regular checkpoint.
        self.checkpoint.save(iteration, update_best=False)
        return False
def _summarize_report(self, meter, prefix="", should_print=True, extra=None):
    """Write meter scalars to tensorboard and optionally print a summary line.

    Args:
        meter: Meter holding the scalar metrics to report.
        prefix (str): Optional label prepended to the printed line.
        should_print (bool): When False, only tensorboard scalars are written.
        extra (dict): Extra key/value pairs appended to the printed line.
            Defaults to an empty dict. (Fix: the original declared a mutable
            default argument ``extra={}``.)
    """
    if not is_main_process():
        return

    if extra is None:
        extra = {}

    scalar_dict = meter.get_scalar_dict()
    self.writer.add_scalars(scalar_dict, registry.get("current_iteration"))

    if not should_print:
        return

    print_str = []
    if len(prefix):
        print_str += [prefix + ":"]
    print_str += ["{}/{}".format(self.current_iteration, self.max_iterations)]
    print_str += [str(meter)]
    print_str += ["{}: {}".format(key, value) for key, value in extra.items()]

    self.writer.write(meter.delimiter.join(print_str))
def run():
    """Entry point: parse flags, build the trainer, train, and log failures.

    (Fix: removed dead commented-out debugging code — pid print,
    remote.install, sleeps.)
    """
    setup_imports()
    parser = flags.get_parser()
    args = parser.parse_args()
    pprint(args)
    trainer = build_trainer(args)

    # Log any errors that occur to log file
    try:
        trainer.load()
        trainer.train()
    except Exception as e:
        writer = getattr(trainer, "writer", None)
        if writer is not None:
            writer.write(e, "error", donot_print=True)
        # Only the main process re-raises.
        if is_main_process():
            raise
def __init__(self, config):
    """Set up file + stdout logging and a tensorboard SummaryWriter.

    Non-main processes get ``logger``/``summary_writer`` left as None.

    Args:
        config: Global config; reads ``training_parameters.save_dir``,
            ``training_parameters.logger_level``,
            ``training_parameters.should_not_log`` and optional ``log_dir``.
    """
    self.logger = None
    self.summary_writer = None
    # Workers keep the no-op (None) logger and return immediately.
    if not is_main_process():
        return
    self.timer = Timer()
    self.config = config
    self.save_dir = config.training_parameters.save_dir
    # Derived run folder name: core-args name plus any config overrides.
    self.log_folder = ckpt_name_from_core_args(config)
    self.log_folder += foldername_from_config_override(config)
    time_format = "%Y-%m-%dT%H:%M:%S"
    # Log file name embeds the run name and a start timestamp.
    self.log_filename = ckpt_name_from_core_args(config) + "_"
    self.log_filename += self.timer.get_time_hhmmss(None, format=time_format)
    self.log_filename += ".log"
    self.log_folder = os.path.join(self.save_dir, self.log_folder, "logs")
    # An explicit log_dir in the config overrides the derived folder.
    arg_log_dir = self.config.get("log_dir", None)
    if arg_log_dir:
        self.log_folder = arg_log_dir
    if not os.path.exists(self.log_folder):
        os.makedirs(self.log_folder)
    tensorboard_folder = os.path.join(self.log_folder, "tensorboard")
    self.summary_writer = SummaryWriter(tensorboard_folder)
    self.log_filename = os.path.join(self.log_folder, self.log_filename)
    print("Logging to:", self.log_filename)
    # Route warnings through the logging machinery too.
    logging.captureWarnings(True)
    # NOTE(review): logging.getLogger(__name__) returns a singleton, so
    # self.logger and self._file_only_logger are the SAME logger object —
    # the stdout handler added to self.logger below also fires for
    # _file_only_logger. Confirm whether a distinct file-only logger
    # (e.g. a different name) was intended.
    self.logger = logging.getLogger(__name__)
    self._file_only_logger = logging.getLogger(__name__)
    warnings_logger = logging.getLogger("py.warnings")
    # Set level
    level = config["training_parameters"].get("logger_level", "info")
    self.logger.setLevel(getattr(logging, level.upper()))
    self._file_only_logger.setLevel(getattr(logging, level.upper()))
    formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s",
                                  datefmt="%Y-%m-%dT%H:%M:%S")
    # Add handler to file
    channel = logging.FileHandler(filename=self.log_filename, mode="a")
    channel.setFormatter(formatter)
    self.logger.addHandler(channel)
    self._file_only_logger.addHandler(channel)
    warnings_logger.addHandler(channel)
    # Add handler to stdout
    channel = logging.StreamHandler(sys.stdout)
    channel.setFormatter(formatter)
    self.logger.addHandler(channel)
    warnings_logger.addHandler(channel)
    should_not_log = self.config["training_parameters"]["should_not_log"]
    self.should_log = not should_not_log
    # Single log wrapper map: remembers messages already emitted once.
    self._single_log_map = set()
def __init__(self, dataset_type="train"):
    """Base initialization: record the split and grab shared registry objects.

    Args:
        dataset_type (str): Which split this dataset represents
            (e.g. train|val|test). Defaults to "train".
    """
    # Split this dataset instance serves.
    self._dataset_type = dataset_type
    # Shared writer registered globally elsewhere in the framework.
    self.writer = registry.get("writer")
    # Cached so hot paths don't re-query the distributed helper.
    self._is_main_process = is_main_process()
    # Global config registered at startup.
    self._global_config = registry.get("config")