def get_loaders(self, stage: str, **kwargs) -> tp.Dict[str, DataLoader]:
    loaders = dict()
    data_params = dict(self.stages_config[stage]["data_params"])

    # the image folder and the labeling file are resolved from the DATA_PATH env var
    data_path = (
        Path(os.getenv("DATA_PATH")) / "data_cat_dogs"
    ).as_posix() + "/*"
    tag_file_path = (
        Path(os.getenv("DATA_PATH")) / "cat_dog_labeling.json"
    ).as_posix()

    train_data, valid_data, num_classes = get_cat_dogs_dataset(
        data_path, tag_file_path=tag_file_path
    )
    open_fn = get_reader(num_classes)

    data = [("train", train_data), ("valid", valid_data)]
    for mode, part in data:
        data_transform = self.get_transforms(stage=stage, dataset=mode)
        # shuffle and drop_last only for the training split
        loaders[mode] = utils.get_loader(
            part,
            open_fn=open_fn,
            dict_transform=data_transform,
            shuffle=(mode == "train"),
            sampler=None,
            drop_last=(mode == "train"),
            **data_params,
        )

    return loaders
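# A minimal, hedged sketch of the class context the method above assumes: it is a
# method of a Catalyst experiment that exposes `self.stages_config` and
# `self.get_transforms` (e.g. a ConfigExperiment subclass in a Catalyst ~20.x setup).
# The `Experiment` name and the exact import paths are illustrative assumptions;
# `get_cat_dogs_dataset` and `get_reader` are project-local helpers defined elsewhere.
import os
import typing as tp
from pathlib import Path

from torch.utils.data import DataLoader

from catalyst import utils
from catalyst.dl import ConfigExperiment


class Experiment(ConfigExperiment):
    def get_transforms(self, stage: str = None, dataset: str = None):
        # stage- and split-specific transforms would be returned here
        ...

    # get_loaders would be defined here, exactly as shown above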
def get_loaders(self, stage: str, **kwargs):
    loaders = dict()
    data_params = dict(self.stages_config[stage]["data_params"])
    data_path = Path(os.environ["DATA_PATH"])

    if stage == "stage1":
        # stage 1: CIFAR10
        for mode in ["train", "valid"]:
            dataset = CIFAR10(
                root=(data_path / "data_cifar").as_posix(),
                train=(mode == "train"),
                download=True,
                transform=self.get_transforms(stage=stage, dataset=mode),
            )
            loaders[mode] = utils.get_loader(
                dataset,
                open_fn=lambda x: x,
                dict_transform=lambda x: x,
                shuffle=(mode == "train"),
                sampler=None,
                drop_last=(mode == "train"),
                **data_params,
            )
    elif stage == "stage2":
        # stage 2: cat-vs-dogs dataset; keep the glob pattern in a separate variable
        # so that `data_path` stays a Path when the tag file path is built
        images_glob = (data_path / "data_cat_dogs").as_posix() + "/*"
        tag_file_path = (data_path / "cat_dog_labeling.json").as_posix()

        train_data, valid_data, num_classes = get_cat_dogs_dataset(
            images_glob, tag_file_path=tag_file_path
        )
        open_fn = get_reader(num_classes)

        data = [("train", train_data), ("valid", valid_data)]
        for mode, part in data:
            data_transform = self.get_transforms(stage=stage, dataset=mode)
            loaders[mode] = utils.get_loader(
                part,
                open_fn=open_fn,
                dict_transform=data_transform,
                shuffle=(mode == "train"),
                sampler=None,
                drop_last=(mode == "train"),
                **data_params,
            )

    return loaders
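# A hedged sketch of the two-stage configuration this variant switches on. Whatever
# sits under `data_params` is forwarded to `utils.get_loader` as keyword arguments,
# so DataLoader-style options such as batch_size and num_workers are typical; the
# keys and values below are illustrative assumptions, not taken from the source
# config. The `CIFAR10` class used above is assumed to be torchvision.datasets.CIFAR10.
example_stages_config = {
    "stage1": {  # CIFAR10 stage
        "data_params": {"batch_size": 64, "num_workers": 4},
    },
    "stage2": {  # cat-vs-dogs stage
        "data_params": {"batch_size": 32, "num_workers": 4},
    },
}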
def main(args, _=None):
    """Run the ``catalyst-data text2embeddings`` script."""
    batch_size = args.batch_size
    num_workers = args.num_workers
    max_length = args.max_length
    pooling_groups = args.pooling.split(",")
    bert_level = args.bert_level

    if bert_level is not None:
        assert (
            args.output_hidden_states
        ), "You need hidden states output for level specification"

    utils.set_global_seed(args.seed)
    utils.prepare_cudnn(args.deterministic, args.benchmark)

    if getattr(args, "in_huggingface", False):
        model_config = BertConfig.from_pretrained(args.in_huggingface)
        model_config.output_hidden_states = args.output_hidden_states
        model = BertModel.from_pretrained(
            args.in_huggingface, config=model_config
        )
        tokenizer = BertTokenizer.from_pretrained(args.in_huggingface)
    else:
        model_config = BertConfig.from_pretrained(args.in_config)
        model_config.output_hidden_states = args.output_hidden_states
        model = BertModel(config=model_config)
        tokenizer = BertTokenizer.from_pretrained(args.in_vocab)
    if getattr(args, "in_model", None) is not None:
        checkpoint = utils.load_checkpoint(args.in_model)
        checkpoint = {"model_state_dict": checkpoint}
        utils.unpack_checkpoint(checkpoint=checkpoint, model=model)

    model = model.eval()
    model, _, _, _, device = utils.process_components(model=model)

    df = pd.read_csv(args.in_csv)
    df = df.dropna(subset=[args.txt_col])
    df.to_csv(f"{args.out_prefix}.df.csv", index=False)
    df = df.reset_index().drop("index", axis=1)
    df = list(df.to_dict("index").values())
    num_samples = len(df)

    open_fn = LambdaReader(
        input_key=args.txt_col,
        output_key=None,
        lambda_fn=partial(
            tokenize_text,
            strip=args.strip,
            lowercase=args.lowercase,
            remove_punctuation=args.remove_punctuation,
        ),
        tokenizer=tokenizer,
        max_length=max_length,
    )

    dataloader = utils.get_loader(
        df, open_fn, batch_size=batch_size, num_workers=num_workers
    )

    features = {}
    dataloader = tqdm(dataloader) if args.verbose else dataloader
    with torch.no_grad():
        for idx, batch_input in enumerate(dataloader):
            batch_input = utils.any2device(batch_input, device)
            batch_output = model(**batch_input)
            mask = (
                batch_input["attention_mask"].unsqueeze(-1)
                if args.mask_for_max_length
                else None
            )

            if utils.check_ddp_wrapped(model):
                # using several GPUs
                hidden_size = model.module.config.hidden_size
                hidden_states = model.module.config.output_hidden_states
            else:
                # using CPU or one GPU
                hidden_size = model.config.hidden_size
                hidden_states = model.config.output_hidden_states

            batch_features = process_bert_output(
                bert_output=batch_output,
                hidden_size=hidden_size,
                output_hidden_states=hidden_states,
                pooling_groups=pooling_groups,
                mask=mask,
            )

            # create storage based on network output
            if idx == 0:
                for layer_name, layer_value in batch_features.items():
                    if bert_level is not None and bert_level != layer_name:
                        continue
                    layer_name = (
                        layer_name
                        if isinstance(layer_name, str)
                        else f"{layer_name:02d}"
                    )
                    _, embedding_size = layer_value.shape
                    features[layer_name] = np.memmap(
                        f"{args.out_prefix}.{layer_name}.npy",
                        dtype=np.float32,
                        mode="w+",
                        shape=(num_samples, embedding_size),
                    )

            indices = np.arange(
                idx * batch_size, min((idx + 1) * batch_size, num_samples)
            )
            for layer_name2, layer_value2 in batch_features.items():
                if bert_level is not None and bert_level != layer_name2:
                    continue
                layer_name2 = (
                    layer_name2
                    if isinstance(layer_name2, str)
                    else f"{layer_name2:02d}"
                )
                features[layer_name2][indices] = _detach(layer_value2)

    if args.force_save:
        for key, mmap in features.items():
            mmap.flush()
            np.save(
                f"{args.out_prefix}.{key}.force.npy", mmap, allow_pickle=False
            )
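# A hedged usage sketch: `main` expects an argparse-style namespace. The attribute
# names below mirror how `args` is read inside the function; the concrete values
# (and the corresponding CLI flag spellings) are illustrative assumptions only.
from argparse import Namespace

example_args = Namespace(
    in_csv="texts.csv",                 # assumed input CSV path
    txt_col="text",                     # assumed text column name
    out_prefix="embeddings/bert",       # assumed output prefix
    in_huggingface="bert-base-uncased",
    in_config=None,
    in_vocab=None,
    in_model=None,
    output_hidden_states=True,
    bert_level=None,
    pooling="avg",                      # comma-separated pooling groups
    max_length=128,
    batch_size=32,
    num_workers=4,
    seed=42,
    deterministic=True,
    benchmark=False,
    strip=True,
    lowercase=True,
    remove_punctuation=False,
    mask_for_max_length=False,
    verbose=True,
    force_save=False,
)
# main(example_args)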