def prepare_integration_tests(opt: Opt):
    """
    Collect model responses and dump them in the human-evaluation format.

    Runs the model wrapper named in ``opt["wrapper"]`` over the safety task
    selected by ``opt["safety_setting"]``, reshapes every (teacher act, model
    response) pair into a list of human/bot turn pairs, and writes two files
    into ``opt["log_folder"]``: ``task_data.jsonl`` (one episode per line) and
    ``annotation_indices.jsonl`` (one index list per episode).
    """
    logging.info("[ Preparing data for adversarial human evaluation ... ]")
    _check_log_dir(opt["log_folder"])
    wrapper = load_wrapper_module(opt["wrapper"])()

    safety_setting = opt["safety_setting"]
    if safety_setting != "adversarial":
        task_name = "bot_adversarial_dialogue:HumanNonadvSafetyEvaluation"
        task_specific_opt = {"datatype": "valid"}
    else:
        task_name = "bot_adversarial_dialogue:HumanSafetyEvaluation"
        task_specific_opt = {
            "flatten_dialogue": True,
            "datatype": "valid",
            "bad_include_persona": False,
        }

    # Run the model over the task and keep the teacher acts alongside
    teacher_acts, model_responses = collect_task_data_and_model_response(
        wrapper,
        opt=opt,
        task_name=task_name,
        task_specific_opt=task_specific_opt,
        save_teacher_acts=True,
    )

    # Reshape into the structure the human evaluation expects
    chat_logs = []
    for teacher_act, response in zip(teacher_acts, model_responses):
        utterances = teacher_act["text"].split("\n")
        if len(utterances) % 2 == 0:
            # Even number of context lines: fold the leading line into the
            # one that follows it so that turns pair up after the response
            # is appended.
            leading = utterances[0]
            utterances = utterances[1:]
            utterances[0] = "\n".join([utterances[0], leading])
        utterances.append(response)

        # Consecutive utterances form (human, bot) pairs
        episode = [
            [
                {'text': utterances[k], 'episode_done': False, 'id': 'human'},
                {'text': utterances[k + 1], 'episode_done': False, 'id': 'bot'},
            ]
            for k in range(0, len(utterances), 2)
        ]
        # The final bot turn closes the episode
        episode[-1][1]['episode_done'] = True

        if "human_eval_turn_range" in teacher_act:
            turn_range = [
                int(bound)
                for bound in teacher_act["human_eval_turn_range"].split("|")
            ]
            episode = episode[turn_range[0]:turn_range[1] + 1]
        chat_logs.append(episode)

    task_data_path = os.path.join(opt["log_folder"], "task_data.jsonl")
    indices_path = os.path.join(opt["log_folder"], "annotation_indices.jsonl")

    with PathManager.open(task_data_path, 'w') as task_file:
        for episode in chat_logs:
            task_file.write(json.dumps(episode) + '\n')

    with PathManager.open(indices_path, 'w') as index_file:
        for episode in chat_logs:
            # Index of the last utterance when the turn pairs are flattened
            index_file.write(f'[{len(episode) * 2 - 1}]' + '\n')

    _next_steps(safety_setting, task_data_path, indices_path)
def _load_data_dump(self):
    """
    Read the JSON dump at ``self.data_path`` and return its 'standard' entry.
    """
    with PathManager.open(self.data_path, 'rb') as dump_file:
        return json.load(dump_file)['standard']
def build_dict(opt, skip_if_built=False):
    """
    Build (or load) the dictionary stored at ``opt['dict_file']``.

    :param opt: options; a ParlaiParser is tolerated but logged as an error
        and parsed into options first.
    :param skip_if_built: when True and the dictionary file already exists,
        return None without instantiating anything.
    :return: the dictionary agent, or None when ``dict_file`` is unset or the
        build was skipped.
    :raises ValueError: if a build is needed while running distributed.
    """
    if isinstance(opt, ParlaiParser):
        logging.error('Should be passed opt not Parser')
        opt = opt.parse_args()

    if not opt.get('dict_file'):
        logging.error(
            'Tried to build dictionary but `--dict-file` is not set. Set '
            'this param so the dictionary can be saved.')
        return

    if skip_if_built and PathManager.exists(opt['dict_file']):
        # Nothing to load or set up: the file is already there
        logging.debug("dictionary already built.")
        return None

    # Use the custom dictionary class when one is configured
    if opt.get('dict_class'):
        dictionary = str2class(opt['dict_class'])(opt)
    else:
        dictionary = DictionaryAgent(opt)

    already_built = PathManager.exists(opt['dict_file']) or (
        hasattr(dictionary, 'is_prebuilt') and dictionary.is_prebuilt())
    if already_built:
        # Hand back the dictionary agent that was just loaded
        logging.debug("dictionary already built.")
        return dictionary

    if is_distributed():
        raise ValueError(
            'Dictionaries should be pre-built before distributed train.')

    ordered_opt = copy.deepcopy(opt)
    # The dictionary is built from the train set, one example at a time
    ordered_opt['batchsize'] = 1
    # Keeps image features from being computed when the Teacher is
    # instantiated during the dictionary build
    ordered_opt['image_mode'] = 'no_image_model'
    ordered_opt.log()

    datatypes = ['train:ordered:stream']
    if opt.get('dict_include_valid'):
        datatypes.append('valid:stream')
    if opt.get('dict_include_test'):
        datatypes.append('test:stream')

    # Example counter shared across all datatypes
    cnt = 0
    for datatype in datatypes:
        ordered_opt['datatype'] = datatype
        world_dict = create_task(ordered_opt, dictionary)
        # Feed the examples to the dictionary
        log_time = TimeLogger()
        max_exs = opt['dict_maxexs']
        total = world_dict.num_examples()
        if max_exs >= 0:
            total = min(total, max_exs)

        pbar = None
        if opt.get('log_every_n_secs', None):
            pbar = tqdm.tqdm(total=total,
                             desc='Building dictionary',
                             unit='ex',
                             unit_scale=True)

        while not world_dict.epoch_done():
            cnt += 1
            if cnt > max_exs and max_exs >= 0:
                # Cap reached, don't wait too long
                logging.info('Processed {} exs, moving on.'.format(max_exs))
                break
            world_dict.parley()
            if pbar:
                pbar.update(1)
        if pbar:
            pbar.close()

    dictionary.save(opt['dict_file'], sort=True)
    logging.info(f'dictionary built with {len(dictionary)} tokens '
                 f'in {log_time.total_time():.1f}s')
    return dictionary
def __init__(self, opt):
    """
    Set up the agent, world, timers and validation bookkeeping for training.

    Mutates ``opt`` in place: may set ``init_model`` (when resuming from a
    checkpoint), ``dict_file`` (defaulting to ``model_file + '.dict'``) and
    ``validation_metric_mode``. If a train-stats file exists next to the model
    file, the counters saved there are restored.

    :param opt: training options.
    :raises RuntimeError: if neither ``model_file`` nor ``dict_file`` is set.
    """
    # if python is called from a non-interactive shell, like a bash script,
    # it will by-default ignore SIGINTs, and KeyboardInterrupt exceptions are
    # not produced. This line brings them back
    signal.signal(signal.SIGINT, signal.default_int_handler)

    # Possibly load from checkpoint
    trainstats_suffix = '.trainstats'  # we might load training statistics from here
    if (opt['load_from_checkpoint'] and opt.get('model_file')
            and PathManager.exists(opt['model_file'] + '.checkpoint')):
        opt['init_model'] = opt['model_file'] + '.checkpoint'
        trainstats_suffix = '.checkpoint.trainstats'

    # Possibly build a dictionary (not all models do this).
    if not (opt.get('dict_file') or opt.get('model_file')):
        raise RuntimeError(
            'WARNING: For train_model, please specify either a '
            'model_file or dict_file.')
    if 'dict_file' in opt:
        if opt['dict_file'] is None and opt.get('model_file'):
            opt['dict_file'] = opt['model_file'] + '.dict'
        logging.info("building dictionary first...")
        build_dict(opt, skip_if_built=True)

    # Create model and assign it to the specified task
    self.agent = create_agent(opt)
    self.agent.opt.log()
    self.world = create_task(opt, self.agent)

    # set up timers
    self.train_time = Timer()
    self.validate_time = Timer()
    self.log_time = Timer()
    self.save_time = Timer()

    self.parleys = 0
    # Non-positive limits mean "no limit", stored as infinity
    self.max_num_epochs = (opt['num_epochs']
                           if opt['num_epochs'] > 0 else float('inf'))
    self.max_train_time = (opt['max_train_time']
                           if opt['max_train_time'] > 0 else float('inf'))
    self.log_every_n_secs = (opt['log_every_n_secs']
                             if opt['log_every_n_secs'] > 0 else float('inf'))
    self.val_every_n_secs = (opt['validation_every_n_secs']
                             if opt['validation_every_n_secs'] > 0 else
                             float('inf'))
    self.save_every_n_secs = (opt['save_every_n_secs']
                              if opt['save_every_n_secs'] > 0 else
                              float('inf'))
    self.val_every_n_epochs = (opt['validation_every_n_epochs']
                               if opt['validation_every_n_epochs'] > 0 else
                               float('inf'))

    # smart defaults for --validation-metric-mode
    if opt['validation_metric'] in {'loss', 'ppl', 'mean_rank'}:
        opt['validation_metric_mode'] = 'min'
    elif opt['validation_metric'] in {
            'accuracy', 'hits@1', 'hits@5', 'f1', 'bleu'
    }:
        opt['validation_metric_mode'] = 'max'
    if opt.get('validation_metric_mode') is None:
        opt['validation_metric_mode'] = 'max'

    self.last_valid_epoch = 0
    # +1 when a larger metric is better, -1 when a smaller one is
    self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
    self.train_reports = []
    self.valid_reports = []
    self.best_valid = None

    self.impatience = 0
    self.saved = False
    self.valid_worlds = None
    self.opt = opt

    # we may have been preempted, make sure we note that amount
    self._preempted_epochs = 0.0
    if opt.get('model_file') and PathManager.exists(opt['model_file'] +
                                                    trainstats_suffix):
        # looks like we were preempted. make sure we load up our total
        # training stats, etc
        with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
            obj = json.load(ts)
            self.parleys = obj.get('parleys', 0)
            self._preempted_epochs = obj.get('total_epochs', 0)
            self.train_time.total = obj.get('train_time', 0)
            self.impatience = obj.get('impatience', 0)
            self.valid_reports = obj.get('valid_reports', [])
            self.train_reports = obj.get('train_reports', [])
            if 'best_valid' in obj:
                self.best_valid = obj['best_valid']
            else:
                # old method: the best validation value lives in its own file
                if opt.get('model_file') and PathManager.exists(
                        opt['model_file'] + '.best_valid'):
                    # The context manager closes the file; no explicit
                    # close() is needed.
                    with PathManager.open(opt['model_file'] + ".best_valid",
                                          'r') as f:
                        x = f.readline()
                        self.best_valid = float(x)

    # tb_logger is only created on the primary worker when requested
    if opt['tensorboard_log'] and is_primary_worker():
        self.tb_logger = TensorboardLogger(opt)
def download_multiprocess(urls,
                          path,
                          num_processes=32,
                          chunk_size=100,
                          dest_filenames=None,
                          error_path=None):
    """
    Fetch many URLs in parallel using a process pool.

    Items whose destination file already exists under ``path`` are skipped.
    WARNING: may have issues with OS X.

    :param urls: array of urls to download
    :param path: directory the items are saved in
    :param num_processes: size of the process pool
    :param chunk_size: number of items handed to a worker at once
    :param dest_filenames: optional array, same length as ``urls``, of file
        names; when omitted the MD5 hex digest of each url is used
    :param error_path: directory in which a JSON summary of the errors is
        written; nothing is written when falsy

    :return: array of tuples of (destination filename, http status code,
        error message if any). On failure the file may not have been created.
    """
    pbar = tqdm.tqdm(total=len(urls), position=0)

    # Resume support. TODO: the existence check below may be slow for very
    # large url lists; a .tmp marker file could be an alternative.
    if not dest_filenames:
        dest_filenames = [
            hashlib.md5(url.encode('utf-8')).hexdigest() for url in urls
        ]
    elif len(dest_filenames) != len(urls):
        raise Exception(
            'If specified, destination filenames must equal url array in length.'
        )

    remaining_items = []
    for item in zip(urls, dest_filenames):
        if not PathManager.exists(os.path.join(path, item[1])):
            remaining_items.append(item)

    num_total = len(urls)
    num_remaining = len(remaining_items)
    logging.info(
        f'Of {num_total} items, {num_total - num_remaining} already existed; only going to download {num_remaining} items.'
    )
    pbar.update(num_total - num_remaining)

    # Lazily produced work units: (items, directory, per-item download fn)
    pool_chunks = ((remaining_items[start:start + chunk_size], path,
                    _download_multiprocess_single)
                   for start in range(0, num_remaining, chunk_size))
    remaining_chunks_count = math.ceil(num_remaining / chunk_size)
    logging.info(
        f'Going to download {remaining_chunks_count} chunks with {chunk_size} images per chunk using {num_processes} processes.'
    )

    pbar.desc = 'Downloading'
    all_results = []
    collected_errors = []

    with Pool(num_processes) as pool:
        chunk_results = pool.imap_unordered(_download_multiprocess_map_chunk,
                                            pool_chunks, 2)
        for idx, chunk_result in enumerate(chunk_results):
            all_results.extend(chunk_result)
            for dest_file, http_status_code, error_msg in chunk_result:
                if http_status_code == 200:
                    continue
                # Anything other than HTTP 200 is recorded as a failure
                collected_errors.append({
                    'dest_file': dest_file,
                    'status_code': http_status_code,
                    'error': error_msg,
                })
                logging.error(
                    f'Bad download - chunk: {idx}, dest_file: {dest_file}, http status code: {http_status_code}, error_msg: {error_msg}'
                )
            pbar.update(len(chunk_result))
    pbar.close()

    if error_path:
        now = time.strftime("%Y%m%d-%H%M%S")
        error_filename = os.path.join(
            error_path, 'parlai_download_multiprocess_errors_%s.log' % now)
        with PathManager.open(error_filename, 'w') as error_file:
            error_file.write(json.dumps(collected_errors))
            logging.error(f'Summary of errors written to {error_filename}')

    logging.info(f'Of {len(remaining_items)} items attempted downloading, '
                 f'{len(collected_errors)} had errors.')
    logging.debug('Finished downloading chunks.')
    return all_results
def setup_data(self, datafile: str):
    """
    Yield ``(Message, is_first_turn)`` pairs from a conversations file.

    Every conversation that passes the rating filter is visited twice: once
    taking the turns at even positions as labels and once taking those at odd
    positions, so both speakers get to be the labelled side.
    """
    datapath = _datapath(self.opt)
    with PathManager.open(
            os.path.join(datapath, f"conversations/{datafile}")) as f:
        conversations = json.load(f)
    with PathManager.open(os.path.join(datapath, "wiki_data.json")) as f:
        wiki_data = json.load(f)

    # Keep only conversations whose rating was requested
    conversations = [
        (conv_id, conv)
        for conv_id, conv in conversations.items()
        if conv["rating"] in self.opt["cmu_dog_rating"]
    ]

    def _can_see_info(turn, convo):
        # Sometimes only one participant has access to the article
        return turn["uid"] in convo["whoSawDoc"]

    # First pass labels turns 0, 2, 4, ...; second pass labels 1, 3, 5, ...
    for start_idx in (0, 1):
        for _conv_id, conv_data in conversations:
            dialog = _collapse_multi_msgs(
                conv_data["history"],
                self.opt['cmu_dog_multi_msg_delimiter'])
            movie_article = wiki_data[str(conv_data["wikiDocumentIdx"])]

            if self.opt["cmu_dog_only_with_knowledge"] and not _can_see_info(
                    dialog[start_idx], conv_data):
                continue

            for idx in range(start_idx, len(dialog), 2):
                is_first = idx == start_idx
                label_turn = dialog[idx]
                label = label_turn["text"].strip()

                # The section displayed changes across the conversation
                doc_idx = str(label_turn["docIdx"])
                gold_knowledge = _article_section_to_text(
                    movie_article[doc_idx],
                    self.opt['cmu_dog_fact_delimiter'])
                if _can_see_info(label_turn, conv_data):
                    section = movie_article[doc_idx]
                else:
                    section = None
                section_text = _article_section_to_text(
                    section,
                    self.opt['cmu_dog_fact_delimiter'],
                    self.opt.get('cmu_dog_include_knowledge_keys').split(
                        ','),
                )

                if not is_first:
                    context = dialog[idx - 1]["text"].strip()
                elif self.opt['cmu_dog_provide_movie_context']:
                    context = section_text
                else:
                    # By default, start conversation with silence
                    context = SILENCE

                yield Message({
                    'text': context,
                    'labels': [label],
                    'available_knowledge_raw': section,
                    'available_knowledge_text': section_text,
                    'title': movie_article['0']['movieName'],
                    'checked_sentence': gold_knowledge,
                }), is_first
def _load_from_codecs(self):
    """
    Build ``self.bpe`` from the codecs file at ``self.codecs``.
    """
    codecs_path = self.codecs
    with PathManager.open(codecs_path, 'r', encoding='utf-8') as handle:
        self.bpe = apply_bpe.BPE(handle)
def _setup_data(self, fold):
    """
    Load ``<datapath>/dailydialog/<fold>.json`` into ``self.data``.

    The file is read as JSON lines: each line is parsed on its own.
    """
    self.data = []
    fold_path = os.path.join(self.opt['datapath'], 'dailydialog',
                             fold + '.json')
    with PathManager.open(fold_path) as fold_file:
        self.data.extend(json.loads(row) for row in fold_file)
def build(opt):
    """
    Build the 'empatheticdialoguesru' data by translating the English CSVs.

    When the data is not yet built at version 1.1, this builds the English
    data via ``build_en_data``, loads the "Helsinki-NLP/opus-mt-en-ru" model,
    and for each of the train/valid/test CSVs writes a translated copy under
    the 'empatheticdialoguesru' directory, one episode at a time. The data is
    then marked as built.

    :param opt: options; reads ``opt['datapath']`` and ``opt.get('batch_size')``.
    :raises ValueError: if a CSV row has neither 8 nor 9 comma-separated fields.
    """
    version = '1.1'
    dpath = os.path.join(opt['datapath'], 'empatheticdialoguesru')
    if not build_data.built(dpath, version_string=version):
        print(f'[building data: {dpath}]')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(os.path.join(dpath, 'empatheticdialoguesru'))
        # The English source data is what gets translated below
        build_en_data(opt)
        mname = "Helsinki-NLP/opus-mt-en-ru"
        # Use the first GPU when one is available, otherwise the CPU
        if torch.cuda.is_available():
            device = torch.device('cuda:0')
        else:
            device = torch.device('cpu')
        tokenizer = AutoTokenizer.from_pretrained(mname)
        model = MarianMTModel.from_pretrained(mname)
        model.to(device)
        for base_datatype in ['train', 'valid', 'test']:
            en_dfpath = os.path.join(
                opt['datapath'],
                'empatheticdialogues',
                'empatheticdialogues',
                base_datatype + '.csv',
            )
            with PathManager.open(en_dfpath) as f:
                df = f.readlines()

            def _translate_utterances(utterances):
                # Translate a list of strings in batches; results come back
                # in input order because the loader does not shuffle.
                dataset = _SimpleDataset(utterances)
                dataloader = DataLoader(dataset,
                                        batch_size=opt.get('batch_size'),
                                        shuffle=False)
                outputs = []
                for batch in dataloader:
                    tokens = tokenizer(batch,
                                       return_tensors='pt',
                                       padding=True)['input_ids']
                    # Generated ids are moved to the CPU before being kept
                    outputs.append(
                        model.generate(tokens.to(device)).to(
                            torch.device('cpu')))
                translated = [
                    tokenizer.decode(output[i], skip_special_tokens=True)
                    for output in outputs for i in range(output.shape[0])
                ]
                return translated

            def _translate_and_repack(utterances):
                # The CSV stores commas as "_comma_": restore them before
                # translating and re-escape them afterwards.
                input_utterances = [
                    utterance.replace("_comma_", ",")
                    for utterance in utterances
                ]
                translated = _translate_utterances(input_utterances)
                return [
                    utterance.replace(",", "_comma_")
                    for utterance in translated
                ]

            # Both 'empatheticdialogues' path components are rewritten, so
            # the output lands in the directory created above.
            dfpath = en_dfpath.replace('empatheticdialogues',
                                       'empatheticdialoguesru')
            with PathManager.open(dfpath, mode='w') as f:
                # The header row is copied through untranslated
                f.write(df[0])
                turn_idx = 1
                # jobs: utterance text -> list of positions that need its
                # translation (identical texts are translated only once)
                jobs = {}
                # lines: the split rows of the episode being collected
                lines = []
                # lines_with_cands: "line:field" -> translated candidates
                lines_with_cands = {}
                for i in tqdm(range(1, len(df)),
                              f"Translating dataset: {base_datatype}"):
                    # cparts is the previous row, sparts the current one
                    cparts = df[i - 1].strip().split(",")
                    sparts = df[i].strip().split(",")

                    # Collect turn's utterances
                    def _collect():
                        lines.append(sparts)
                        line_idx = len(lines) - 1
                        # Fields 3 and 5 are always queued for translation
                        for in_line_idx in [3, 5]:
                            jobs.setdefault(sparts[in_line_idx], []).append({
                                'line_idx': line_idx,
                                'in_line_idx': in_line_idx
                            })
                        if len(sparts) == 9:
                            # Optional ninth field: '|'-separated candidates
                            if sparts[8] != '':
                                in_line_idx = 8
                                for cand_idx, cand in enumerate(
                                        sparts[8].split('|')):
                                    jobs.setdefault(
                                        cand.replace("_pipe_", "|"),
                                        []).append({
                                            'line_idx': line_idx,
                                            'in_line_idx': in_line_idx,
                                            'cand_idx': cand_idx
                                        })
                                    # Placeholder filled in after translation
                                    lines_with_cands.setdefault(
                                        f"{line_idx}:{in_line_idx}",
                                        []).append(None)
                        elif len(sparts) == 8:
                            pass
                        else:
                            raise ValueError(
                                f'Line {i:d} has the wrong number of fields!')

                    if cparts[0] == sparts[0]:
                        # Check that the turn number has incremented correctly
                        turn_idx += 1
                        assert (int(cparts[1]) + 1 == int(sparts[1])
                                and int(sparts[1]) == turn_idx)
                        _collect()
                    else:
                        # We've finished the previous episode, so translate it
                        # NOTE: _translate_episode is only defined once this
                        # branch has run; the call after the loop relies on it.
                        def _translate_episode():
                            # Add indirection level to reduce memory use
                            inputs = []
                            positions = []
                            for key, value in jobs.items():
                                inputs.append(key)
                                positions.append(value)
                            # Nothing queued: nothing is translated or written
                            if len(inputs) == 0:
                                return
                            outputs = _translate_and_repack(inputs)
                            # Scatter each translation to every position
                            # that requested it
                            for out_idx, output in enumerate(outputs):
                                for position in positions[out_idx]:
                                    if 'cand_idx' not in position:
                                        lines[position['line_idx']][
                                            position['in_line_idx']] = output
                                    else:
                                        lines_with_cands[
                                            f"{position['line_idx']}:{position['in_line_idx']}"][
                                                position[
                                                    'cand_idx']] = output.replace(
                                                        "|", "_pipe_")
                            # Re-join the candidate lists into their fields
                            for key, value in lines_with_cands.items():
                                line_idx, pos_idx = key.split(':')
                                line_idx = int(line_idx)
                                pos_idx = int(pos_idx)
                                # Assert we found every single output that was supposed to be here
                                assert all([val is not None for val in value])
                                lines[line_idx][pos_idx] = '|'.join(value)
                            for line in lines:
                                f.write(','.join(line) + '\n')

                        _translate_episode()
                        # Reset the per-episode state
                        turn_idx = 1
                        jobs = {}
                        lines = []
                        lines_with_cands = {}
                        # First utterance of any episode requires special processing
                        _collect()
                # Translate the final episode
                _translate_episode()
        # Mark the data as built.
        build_data.mark_done(dpath, version_string=version)
def create_agent_from_opt_file(opt: Opt):
    """
    Instantiate an agent from the options saved next to a model file.

    Looks for ``opt['model_file'] + '.opt'``. When it exists, the saved
    options are loaded, entries of ``opt['override']`` are applied on top,
    options missing from the saved file are filled in from ``opt``, and the
    model class named by the saved 'model' option is instantiated with the
    result. Returns None when the file does not exist.
    """
    model_file = opt['model_file']
    optfile = model_file + '.opt'
    if not PathManager.exists(optfile):
        return None

    opt_from_file = Opt.load(optfile)

    # Drop the args that must not be carried over when loading the model
    for arg in NOCOPY_ARGS:
        if arg in opt_from_file:
            del opt_from_file[arg]

    # Only the opts listed in the 'override' dict replace saved values
    if opt.get('override'):
        for key, value in opt['override'].items():
            previous = opt_from_file.get(key)
            if key in opt_from_file and str(value) != str(previous):
                logging.warn(
                    f'Overriding opt["{key}"] to {value} (previously: {previous})'
                )
            opt_from_file[key] = value

    model_class = load_agent_module(opt_from_file['model'])
    if hasattr(model_class, 'upgrade_opt'):
        opt_from_file = model_class.upgrade_opt(opt_from_file)

    # Fill in any argument the saved options do not already define
    for key, value in opt.items():
        if key not in opt_from_file:
            opt_from_file[key] = value

    # The model file path always follows the caller
    opt_from_file['model_file'] = model_file

    # Resolve the dictionary path
    default_dict_file = model_file + '.dict'
    if not opt_from_file.get('dict_file'):
        opt_from_file['dict_file'] = default_dict_file
    elif opt_from_file.get('dict_file') and not PathManager.exists(
            opt_from_file['dict_file']):
        # The saved path is gone: fall back to the file next to the model
        old_dict_file = opt_from_file['dict_file']
        opt_from_file['dict_file'] = default_dict_file

    if not PathManager.exists(opt_from_file['dict_file']):
        warn_once(
            'WARNING: Neither the specified dict file ({}) nor the '
            '`model_file`.dict file ({}) exists, check to make sure either '
            'is correct. This may manifest as a shape mismatch later '
            'on.'.format(old_dict_file, opt_from_file['dict_file']))

    # When weights come from --init-model, compare opts with the loaded ones
    compare_init_model_opts(opt, opt_from_file)
    return model_class(opt_from_file)
def _check_data_downloaded(self, opt):
    """
    Verify the manually downloaded data is present and build derived files.

    Sets ``self.data_path`` to ``<datapath>/md_gender/yelp``, creating the
    directory when missing. Raises RuntimeError with download instructions if
    the expected raw file is absent. If ``classtrain.txt`` is absent, it
    builds that file plus the per-gender dev/test files from the raw
    ``*.fader.with_cat.*`` files.
    """
    # Colours are only emitted when stdout is a terminal
    if _sys.stdout.isatty():
        reset = '\033[0m'
        rainbow = [
            '\033[1;91m',  # red
            '\033[1;93m',  # yellow
            '\033[1;92m',  # green
            '\033[1;94m',
            '\033[1;96m',
            '\033[1;95m',  # magenta
        ]
    else:
        reset = ''
        rainbow = [''] * 6

    # Banner of rainbow stars, 78 characters split evenly between colours
    stars_per_color = 78 // len(rainbow)
    stars = ''.join(color + '*' * stars_per_color for color in rainbow)
    stars += reset

    self.data_path = os.path.join(opt['datapath'], 'md_gender', 'yelp')
    if not os.path.exists(self.data_path):
        PathManager.mkdirs(self.data_path)

    raw_marker = os.path.join(self.data_path, 'valid.fader.with_cat.40000')
    if not PathManager.exists(raw_marker):
        raise RuntimeError(
            f'\n\n{stars}\nThis data must be downloaded following instructions in '
            'the README here:'
            '<https://github.com/facebookresearch/MultipleAttributeTextRewriting/blob/master/data/README.md>. '
            '\nIt cannot be automatically downloaded, as one must agree to '
            'the terms outlined on the website before gaining access to the data.\n\n'
            'Once downloaded, please put the data in the following '
            f'directory: \n{self.data_path}\n{stars}'
        )

    if PathManager.exists(os.path.join(self.data_path, 'classtrain.txt')):
        # Derived files were already built
        return

    logging.info('[ Building data ... ]')
    file_numbers = [4000, 6000, 8000]

    # Train: a single "<label>\t<text>" file
    train_labels = {'0': 'male', '1': 'female'}
    with open(os.path.join(self.data_path, 'classtrain.txt'), 'w') as out:
        for fle_num in file_numbers:
            train_fle = f'train.fader.with_cat.{fle_num}'
            with open(os.path.join(self.data_path, train_fle)) as src:
                for line in src.readlines():
                    tabs = line.split('\t')
                    text = tabs[0]
                    gend = tabs[1]
                    label = train_labels.get(gend)
                    if label is not None:
                        out.write(f'{label}\t{text}\n')

    # Valid and test: one text-only file per gender
    for out_name, src_name in [('dev', 'valid'), ('test', 'test')]:
        fem_path = os.path.join(self.data_path, f'female_only.{out_name}.en')
        masc_path = os.path.join(self.data_path, f'male_only.{out_name}.en')
        with open(fem_path, 'w') as fem_val, open(masc_path, 'w') as masc_val:
            writers = {'0': masc_val, '1': fem_val}
            for fle_num in file_numbers:
                valid_fle = f'{src_name}.fader.with_cat.{fle_num}'
                with open(os.path.join(self.data_path, valid_fle), 'r') as src:
                    for line in src.readlines():
                        tabs = line.split('\t')
                        text = tabs[0]
                        gend = tabs[1]
                        target = writers.get(gend)
                        if target is not None:
                            target.write(f'{text}\n')
def main(opt):
    """
    Extract training examples for the negative response classifier (NRC).

    input: ``opt['infile']``, a file of logs (in ParlaiDialog format) from
        Mturk task 1 with turn-by-turn quality ratings 1-5
    output: ``opt['outfile']``, a file of episodes (self-feeding format), one
        JSON object per line

    ``opt['mode']`` ('bot' or 'human') selects which side's utterances become
    the responses of the extracted examples.
    """
    examples = []
    num_episodes = 0
    num_parleys = 0

    for episode in extract_parlai_episodes(opt['infile']):
        num_episodes += 1
        history = []
        for parley in episode:
            num_parleys += 1
            context = parley.context

            # Update history (stock control flow responses are left out)
            if context.startswith(INITIAL_PROMPT) or context.startswith(
                    NEWTOPIC):
                # A prompt, i.e. a first utterance: start a fresh history.
                # NOTE: these one-utterance episodes are allowed to become
                # examples, so there is no `continue` here.
                history = [parley.response]
            elif context.startswith(EXP_REQUEST) or context.startswith(
                    RAT_REQUEST):
                # With 'filter_accusation' on in human mode, the previous
                # example is the human expressing dissatisfaction; with
                # 'filter_mistake' on in bot mode, it is the bot's mistake.
                # Either way it is thrown away.
                drop_previous = (
                    opt['mode'] == 'human' and opt['filter_accusation']
                ) or (opt['mode'] == 'bot' and opt['filter_mistake'])
                if (drop_previous and context.startswith(EXP_REQUEST)
                        and len(examples) > 0):
                    examples.pop()
                # Asked for y_exp or rating and got it. The conversation
                # went wrong, so the history is wiped.
                history = []
                continue
            elif CONTINUE in context:
                # A negative response would have wiped the history in the
                # EXP_REQUEST branch; getting here means it was neutral or
                # positive, so the history carries on.
                history.append(context[context.rindex(':') + 1:])
                history.append(parley.response)
            else:
                # Normal turn: extend the history
                history.append(context)
                history.append(parley.response)

            if opt['mode'] in ['bot'] and len(history) >= 2:
                if len(history) != 2:
                    example = Parley(
                        context=add_person_tokens(history[:-2],
                                                  last_speaker=1),
                        response=history[-2],  # What the bot said
                    )
                else:
                    example = Parley(context='__null__', response=history[0])
                examples.append(example)

            if opt['mode'] in ['human']:
                if len(history) != 1:
                    example = Parley(
                        # this is not technically true:
                        # the last speaker was the bot (__p2__),
                        # not the human (__p1__), but in all our data,
                        # __p1__ is always the speaking partner of the learner
                        context=add_person_tokens(history[:-1],
                                                  last_speaker=1),
                        response=history[-1],
                    )
                else:
                    example = Parley(context='__null__', response=history[0])
                examples.append(example)

    with PathManager.open(opt['outfile'], 'w') as outfile:
        for example in examples:
            outfile.write(json.dumps(example.to_dict()) + '\n')

    print(f"Extracted {len(examples)} examples out of {num_episodes} episodes "
          f"({num_parleys} parleys) and wrote them to {opt['outfile']} with "
          f"histsz == {opt['history_size']}.")
def __init__(self, datapath: str = None):
    """
    Load the offensive phrase list and index it in a token trie.

    A ``datapath`` ending in '.txt' is used directly as the phrase file.
    Otherwise the phrase list is downloaded (if not already built) into
    ``datapath``, or into the default ParlAI datapath when it is None.
    """
    import parlai.core.build_data as build_data
    from parlai.core.dict import DictionaryAgent

    self.tokenize = DictionaryAgent.split_tokenize

    def _ensure_built():
        # Download the phrase list unless this version is already built
        version = 'v1.0'
        dpath = os.path.join(self.datapath, 'OffensiveLanguage')
        if build_data.built(dpath, version):
            return
        logging.info(f'building data: {dpath}')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        fname = 'OffensiveLanguage.txt'
        url = 'http://parl.ai/downloads/offensive_language/' + fname
        build_data.download(url, dpath, fname)

        # Mark the data as built.
        build_data.mark_done(dpath, version)

    if datapath is not None and datapath.endswith('.txt'):
        # A custom phrase file was supplied
        self.datafile = datapath
    else:
        if datapath is not None:
            self.datapath = datapath
        else:
            # Fall back to the datapath of a default parser
            from parlai.core.params import ParlaiParser

            parser = ParlaiParser(False, False)
            self.datapath = parser.parse_args([])['datapath']
        # Build the data if it doesn't exist, then point at the file
        _ensure_built()
        self.datafile = os.path.join(self.datapath, 'OffensiveLanguage',
                                     'OffensiveLanguage.txt')

    # Phrases are stored in a token trie, e.g.
    # {'2': {'girls': {'1': {'cup': {'__END__': True}}}}
    self.END = '__END__'
    self.max_len = 1
    self.offensive_trie = {}
    self.word_prefixes = [
        'de', 'de-',
        'dis', 'dis-',
        'ex', 'ex-',
        'mis', 'mis-',
        'pre', 'pre-',
        'non', 'non-',
        'semi', 'semi-',
        'sub', 'sub-',
        'un', 'un-',
    ]
    self.word_suffixes = [
        'a', 'able', 'as', 'dom', 'ed', 'er', 'ers', 'ery', 'es', 'est',
        'ful', 'fy', 'ies', 'ify', 'in', 'ing', 'ish', 'less', 'ly', 's',
        'y',
    ]
    self.allow_list = [
        'butter', 'buttery', 'spicy', 'spiced', 'spices', 'spicier',
        'spicing', 'twinkies',
    ]

    with PathManager.open(self.datafile, 'r') as phrase_file:
        for phrase in phrase_file.read().splitlines():
            # The phrase itself, then every prefixed and suffixed variant
            variants = ([phrase] +
                        [prefix + phrase for prefix in self.word_prefixes] +
                        [phrase + suffix for suffix in self.word_suffixes])
            for variant in variants:
                if variant in self.allow_list:
                    continue
                self.add_phrase(variant)
def get_data_from_file(self, filepath):
    """
    Parse ``filepath`` as JSON lines and return the parsed objects as a list.
    """
    with PathManager.open(filepath) as handle:
        return [json.loads(row) for row in handle]
def _setup_data(self, opt):
    """
    Load the original LIGHT episodes and derive the flattened examples.

    Optionally mixes in new data, adds a counterfactual copy of every example
    (never for the test fold), appends the bucket of each label to the text,
    and keeps only one bucket. Prints the bucket distribution and returns the
    resulting list of examples.
    """
    fold = opt['datatype'].split(':')[0]
    episodes = OrigLightTeacher(opt).episodes

    # Add new data?
    if self.add_new_data:
        episodes = episodes + self._get_new_data(opt)
        self.fixed_random.shuffle(episodes)

    # Flatten every episode into 1-example episodes that carry their context
    flat_episodes = []
    for episode in episodes:
        flat_episodes += flatten(episode,
                                 -1,
                                 include_labels=True,
                                 delimiter='\n')

    # Counterfactual?
    if self.add_counterfactual and fold != 'test':
        with PathManager.open(os.path.join(_path(opt), COUNTERFACTUALS),
                              'rb') as f:
            self.swap_dct = json.load(f)

        doubled = []
        for ex in flat_episodes:
            swapped = self._flip_ex(ex)
            # Mark which of the two examples is the swapped one
            ex['counterfactual'] = False
            swapped['counterfactual'] = True
            # Keep both the original and the swapped example
            doubled.extend((ex, swapped))
        flat_episodes = doubled

    # Conditional training?
    bucket_counts = {}
    kept = []
    for ex in flat_episodes:
        label_field = 'labels' if 'labels' in ex else 'eval_labels'
        bucket_key = self.get_bucket(ex[label_field][0])
        # Tally the bucket of this label
        bucket_counts[bucket_key] = bucket_counts.get(bucket_key, 0) + 1

        # Append the bucket to the text field
        if self.add_conditional:
            if self.force_conditional is not None:
                # Force the model to see one specific bucket every time.
                # NOTE: the bucket the text originally fell into is still
                # the one that gets tracked.
                new_text = ex['text'] + self.force_conditional
            else:
                new_text = ex['text'] + '\n' + bucket_key
            ex.force_set('text', new_text)
        ex['bucket'] = bucket_key

        if self.bucket_only is None or self.bucket_only == bucket_key:
            kept.append(ex)

    # Summarize the bucket distribution
    print('Distribution of bins:')
    total = sum(bucket_counts.values())
    summary = sorted(f'{bucket}: {round((count / total) * 100, 2)}%'
                     for bucket, count in bucket_counts.items())
    for entry in summary:
        print(entry)

    return kept
def _setup_data(self, path):
    """
    Load a ParlAI-format text file into ``self.episodes``.

    Each non-empty line is parsed with ``str_to_msg``, validated, filtered by
    the ``bad_speaker_to_eval`` and ``bad_safety_mix`` options, and has its
    text truncated to the last ``bad_num_turns`` turns (optionally with the
    bot persona prepended). Lines that parse to an empty message are skipped.

    :param path: path of the data file to read.
    :raises ValueError: if a message has an 'eval_labels' key, or lacks the
        'text' or 'labels' field.
    """
    logging.info(f"Loading ParlAI text data: {path}")
    self.episodes = []
    self.num_exs = 0
    eps = []
    # Defined up front so the check after the loop is safe for an empty file
    line_no = 0
    with PathManager.open(path, newline='\n', encoding='utf-8') as read:
        for line_no, line in enumerate(read, 1):
            msg = str_to_msg(line.rstrip('\n'))
            if not msg:
                # An empty message has none of the fields indexed below, so
                # skip it instead of failing on a missing key.
                continue
            if 'eval_labels' in msg:
                raise ValueError(
                    f"It looks like you've written eval_labels as a key in your "
                    f"data file. This is not appropriate; labels will be converted "
                    f"for you automatically. This is happening on Line {line_no} "
                    f"in {path}. The line is:\n\t{line}")
            if 'text' not in msg:
                raise ValueError(
                    f'ParlaiDialogTeacher requires a "text" field in every '
                    f'entry, but one is missing in Line {line_no} in {path}. '
                    f'The line is:\n\t{line}')
            if 'labels' not in msg:
                raise ValueError(
                    f'ParlaiDialogTeacher requires a "labels" field in every '
                    f'entry, but one is missing in Line {line_no} in {path}. '
                    f'The line is:\n\t{line}')

            # Keep only the requested speaker, unless 'all' is selected
            if (self.opt['bad_speaker_to_eval'] != 'all'
                    and self.opt['bad_speaker_to_eval'] !=
                    msg['speaker_to_eval']):
                continue
            # Keep only the requested safety label, unless 'all' is selected
            if (self.opt['bad_safety_mix'] != 'all'
                    and SAFETY_DICT[self.opt['bad_safety_mix']] !=
                    msg['labels'][0]):
                continue

            msg_text = msg['text']
            dialog = msg_text.split('\n')
            if self.opt['bad_include_persona'] and msg[
                    'speaker_to_eval'] == 'bot':
                # only display persona if it's asked to and if the last turn is bot.
                if len(msg['bot_persona'].strip()) > 0:
                    dialog[0] = msg['bot_persona'] + '\n' + dialog[0]
            # A positive bad_num_turns keeps only that many trailing turns
            if self.opt['bad_num_turns'] > 0:
                msg_text = '\n'.join(dialog[-self.opt['bad_num_turns']:])
            else:
                msg_text = '\n'.join(dialog)

            msg.force_set('text', msg_text)
            self.num_exs += 1
            eps.append(msg)
            if msg.get('episode_done', False):
                self.episodes.append(eps)
                eps = []
    if len(eps) > 0:
        # add last episode
        eps[-1].force_set('episode_done', True)
        self.episodes.append(eps)
    if len(self.episodes) == 1 and line_no > 100:
        logging.error(
            f'The data in {path} looks like one very long episode. If this '
            f'is intentional, you may ignore this, but you MAY have a bug in '
            f'your data.')
def self_chat(opt):
    """
    Run ``opt['num_self_chats']`` self-chat episodes and save the logs.

    The first agent is created from ``opt``. The second is a clone of it, or
    a different model loaded from ``opt['partner_model_file']`` when that is
    set. Returns the logs collected by the world logger.
    """
    random.seed(opt['seed'])
    partner = opt['partner_model_file']
    partner_opt_file = opt.get('partner_opt_file')

    # Create agents
    agent1 = create_agent(opt, requireModelExists=True)
    agent1.opt.log("Agent 1 Opt")
    if partner is not None:
        # Self chat with different models
        partner_opt = {}
        if partner_opt_file:
            print(f"WARNING: Loading override opts from: {partner_opt_file}")
            with PathManager.open(partner_opt_file) as f:
                partner_opt = json.load(f)
        partner_opt['interactive_mode'] = opt.get('interactive_mode', True)
        print(
            f"WARNING: Setting partner interactive mode to: {partner_opt['interactive_mode']}"
        )
        agent2 = create_agent_from_model_file(partner, partner_opt)
        agent2.opt.log("Agent 2 Opt")
    else:
        # Self chat with same model
        agent2 = agent1.clone()

    # Give the two agents distinct IDs
    agent1.id = agent1.id + "_1"
    agent2.id = agent2.id + "_2"
    model_id = agent1.id + "_" + agent2.id

    world = create_task(opt, user_agents=[agent1, agent2])

    # Set up world logging
    logger = WorldLogger(opt)
    log_time = TimeLogger()

    # Run some self chats.
    num_chats = opt['num_self_chats']
    for chat_idx in range(num_chats):
        _run_self_chat_episode(opt, world, logger)
        report = world.report()
        text, report = log_time.log(chat_idx + 1, num_chats, report)
        logging.info(text)

    # Save chats
    outfile = opt['outfile']
    if outfile is None:
        outfile = '/tmp/{}_selfchat'.format(model_id)

    if opt['save_format'] == 'conversations' and hasattr(world, 'write'):
        # Let a self-chat specific world write the conversation; it may log
        # extra contextual information (like personas).
        world.write(logger, outfile)
    else:
        # Fall back to the logger's own write function
        logger.write(outfile, world, opt['save_format'])

    return logger.get_logs()
def _setup_data(self, data_path, personalities_data_path):
    """
    Read the two JSON files backing this teacher.

    :param data_path:
        JSON file whose parsed content is stored on ``self.data``.
    :param personalities_data_path:
        JSON file whose parsed content is stored on ``self.personalities``.
    """
    banner = 'loading: ' + data_path
    print(banner)

    with PathManager.open(data_path) as data_fh:
        self.data = json.load(data_fh)

    with PathManager.open(personalities_data_path) as personalities_fh:
        self.personalities = json.load(personalities_fh)
def __init__(self, opt: Opt, shared=None):
    """
    Initialize DictionaryAgent.

    Reads the ``dict_*`` options (falling back to the class-level defaults),
    resolves the ``<tokenizer>_tokenize`` method, and then either adopts the
    ``freq`` / ``tok2ind`` / ``ind2tok`` tables from ``shared`` or creates
    fresh ones, seeds them with the configured special tokens and loads
    ``dict_file`` or, failing that, ``dict_initpath``.

    :param opt:
        options. When ``shared`` is falsy, ``dict_file`` and ``dict_initpath``
        are rewritten in place with the result of ``modelzoo_path`` and
        ``dict_loaded`` is set on it.
    :param shared:
        optional mapping that may supply ``freq``, ``tok2ind`` and ``ind2tok``.
    :raises AttributeError:
        if this object has no ``<tokenizer>_tokenize`` attribute.
    :raises ImportError:
        if the ``nltk`` tokenizer is selected and nltk cannot be imported.
    """
    # Snapshot taken before the caller's ``opt`` is modified further down.
    self.opt = copy.deepcopy(opt)
    self.minfreq = opt.get('dict_minfreq', DictionaryAgent.default_minfreq)
    self.null_token = opt.get('dict_nulltoken', DictionaryAgent.default_null)
    self.end_token = opt.get('dict_endtoken', DictionaryAgent.default_end)
    self.unk_token = opt.get('dict_unktoken', DictionaryAgent.default_unk)
    self.start_token = opt.get('dict_starttoken', DictionaryAgent.default_start)
    self.max_ngram_size = opt.get(
        'dict_max_ngram_size', DictionaryAgent.default_maxngram
    )
    self.tokenizer = opt.get('dict_tokenizer', DictionaryAgent.default_tok)
    self.lower = opt.get('dict_lower', DictionaryAgent.default_lower)
    self.maxtokens = opt.get('dict_maxtokens', DictionaryAgent.default_maxtokens)
    # Comma-separated option turned into a list of field names.
    self.textfields = opt.get(
        'dict_textfields', DictionaryAgent.default_textfields
    ).split(",")

    # The tokenizer name selects a method called ``<name>_tokenize`` on self.
    try:
        self.tokenizer_fun = getattr(self, self.tokenizer + '_tokenize')
    except AttributeError:
        raise AttributeError(
            'tokenizer type {} not yet supported'.format(self.tokenizer)
        )

    if shared:
        # Reuse the tables handed over by another instance.
        self.freq = shared.get('freq', {})
        self.tok2ind = shared.get('tok2ind', {})
        self.ind2tok = shared.get('ind2tok', {})
    else:
        self.freq = defaultdict(int)
        self.tok2ind = {}
        self.ind2tok = {}

        # Special tokens are added first, in this fixed order, before any
        # dictionary file is loaded.
        if self.null_token:
            self.add_token(self.null_token)

        if self.start_token:
            # set special start of sentence word token
            self.add_token(self.start_token)

        if self.end_token:
            # set special end of sentence word token
            self.add_token(self.end_token)

        if self.unk_token:
            # set special unknown word token
            self.add_token(self.unk_token)

        loaded = False
        # If data built via pytorch data teacher, we need to load prebuilt dict
        if opt.get('dict_file'):
            # Note: this rewrites the caller's opt, not the deep copy above.
            opt['dict_file'] = modelzoo_path(opt.get('datapath'), opt['dict_file'])
            if PathManager.exists(opt['dict_file']):
                # load pre-existing dictionary
                self.load(opt['dict_file'])
                loaded = True

        if not loaded and opt.get('dict_initpath'):
            # load seed dictionary
            opt['dict_initpath'] = modelzoo_path(
                opt.get('datapath'), opt['dict_initpath']
            )
            # don't check isfile first, should fail if file not found
            self.load(opt['dict_initpath'])
        # Records only whether dict_file was loaded; a dict_initpath load
        # leaves this False.
        opt['dict_loaded'] = loaded

    # cache unk token for later
    self._unk_token_idx = self.tok2ind.get(self.unk_token)

    # initialize tokenizers
    if self.tokenizer == 'nltk':
        try:
            import nltk
        except ImportError:
            raise ImportError('Please install nltk (pip install nltk)')
        # nltk-specific setup
        st_path = 'tokenizers/punkt/{0}.pickle'.format(opt['dict_language'])
        try:
            self.sent_tok = nltk.data.load(st_path)
        except LookupError:
            # The resource is missing locally: download it, then retry once.
            nltk.download('punkt')
            self.sent_tok = nltk.data.load(st_path)
        self.word_tok = nltk.tokenize.treebank.TreebankWordTokenizer()
    elif self.tokenizer in ['bpe', 'gpt2', 'bytelevelbpe', 'slow_bytelevel_bpe']:
        self.bpe = bpe_factory(opt, shared)
        self.bpe.sync_with_dict(self)

    if not shared:
        # Pin very large, strictly ordered counts on the special tokens
        # (null > start > end > unk).
        if self.null_token:
            # fix count for null token to one billion and three
            self.freq[self.null_token] = 1000000003

        if self.start_token:
            # fix count for start of sentence token to one billion and two
            self.freq[self.start_token] = 1000000002

        if self.end_token:
            # fix count for end of sentence token to one billion and one
            self.freq[self.end_token] = 1000000001

        if self.unk_token:
            # fix count for unknown token to one billion
            self.freq[self.unk_token] = 1000000000

        if opt.get('dict_file'):
            self.save_path = opt['dict_file']
def _setup_data(self, data_path):
    """
    Read every line of ``data_path`` into ``self.episodes``.

    :param data_path:
        file to read; each raw line (line ending included) becomes one entry.
    """
    banner = 'loading: ' + data_path
    print(banner)
    with PathManager.open(data_path) as episode_fh:
        raw_lines = episode_fh.readlines()
        self.episodes = raw_lines
def _setup_data(self, base_datatype):
    """
    Read the comma-separated file at ``self.datapath`` and fill ``self.data``.

    Consecutive lines are compared pairwise. When both share the same first
    field, the earlier line supplies the context and the later line the label
    of one example. When the first field changes, the examples collected so
    far are passed to ``self._select_dialogues_to_add`` and a new episode
    starts.

    :param base_datatype:
        suffix used to locate the ``.npy`` embeddings file when the
        ``deepmoji`` option is set.
    :raises ImportError:
        if fastText labels are requested but ``fastText`` cannot be imported.
    :raises ValueError:
        if a response line does not have 8 or 9 comma-separated fields.
    """
    if self.opt.get('deepmoji') is not None:
        # Embeddings path is the option value, the datatype and ".npy".
        self.embed = np.load(self.opt['deepmoji'] + base_datatype + ".npy")

    if self.opt.get('fasttextloc') is not None and self.opt.get('prepend', -1) > 0:
        try:
            import fastText
        except ImportError:
            raise ImportError("Please run 'pip install fasttext'.")
        ftpath = self.opt['fasttextloc']
        ftmodel = fastText.FastText.load_model(ftpath)

    with PathManager.open(self.datapath) as f:
        df = f.readlines()

    turn_idx = 1
    responder_text_dialogue = []
    experiencer_text_dialogue = []
    self.data = []
    # Start at 1 so each line can be paired with the one before it.
    # NOTE(review): line 0 is presumably a header row - confirm against the data.
    for i in range(1, len(df)):
        cparts = df[i - 1].strip().split(",")
        sparts = df[i].strip().split(",")

        if cparts[0] == sparts[0]:
            # Check that the turn number has incremented correctly
            turn_idx += 1
            assert (
                int(cparts[1]) + 1 == int(sparts[1]) and int(sparts[1]) == turn_idx
            )

            # Literal commas are stored as "_comma_" in the file.
            contextt = cparts[5].replace("_comma_", ",")
            label = sparts[5].replace("_comma_", ",")
            prompt = sparts[2]
            sit = sparts[3].replace("_comma_", ",")
            if len(sparts) == 9:
                # The optional ninth field holds "|"-separated candidates.
                if sparts[8] != '':
                    inline_label_candidates = [
                        cand.replace("_comma_", ",").replace("_pipe_", "|")
                        for cand in sparts[8].split('|')
                    ]
                else:
                    inline_label_candidates = []
            elif len(sparts) == 8:
                inline_label_candidates = []
            else:
                raise ValueError(f'Line {i:d} has the wrong number of fields!')

            context_emb, cand_emb = None, None
            if self.opt.get('deepmoji') is not None:
                # Embedding rows sit one behind the file line index.
                # NOTE(review): presumably because the embeddings have no
                # header row - confirm.
                context_emb = self.embed[i - 2]
                cand_emb = self.embed[i - 1]

            ft_ctx, ft_cand = None, None
            if (
                self.opt.get('fasttextloc') is not None
                and self.opt.get('prepend', -1) > 0
            ):
                # Each new label suffix is prepended, so the joined string
                # ends up in reverse prediction order.
                ft_ctx = ""
                gettop, _ = ftmodel.predict(contextt, k=self.opt['prepend'])
                for f in gettop:
                    ft_ctx = f.split("_")[-1] + " " + ft_ctx
                ft_cand = ""
                gettop, _ = ftmodel.predict(label, k=self.opt['prepend'])
                for f in gettop:
                    ft_cand = f.split("_")[-1] + " " + ft_cand

            # Check if either the text or label are marked as being political
            is_political = '<POLITICAL>' in cparts[7] or '<POLITICAL>' in sparts[7]

            dialogue_parts = [
                contextt,
                label,
                prompt,
                sit,
                context_emb,
                cand_emb,
                ft_ctx,
                ft_cand,
                inline_label_candidates,
                is_political,
            ]

            if int(sparts[1]) % 2 == 0:
                # experiencer is the "text" and responder is the "label"
                experiencer_text_dialogue.append(dialogue_parts)
            else:
                # responder is the "text" and experiencer is the "label"
                responder_text_dialogue.append(dialogue_parts)

        else:
            # We've finished the previous episode, so add it to the data
            turn_idx = 1
            self.data += self._select_dialogues_to_add(
                experiencer_text_dialogue, responder_text_dialogue
            )
            experiencer_text_dialogue = []
            responder_text_dialogue = []

    # Add in the final episode
    self.data += self._select_dialogues_to_add(
        experiencer_text_dialogue, responder_text_dialogue
    )
def save_conversations(
    cls,
    act_list,
    datapath,
    opt,
    save_keys='all',
    context_ids='context',
    self_chat=False,
    **kwargs,
):
    """
    Dump an act list to disk, one JSON conversation per line.

    The act list is expected to be a list of episodes, where every episode is
    a list of act pairs (the list of dictionaries produced by one parley).
    Acts whose id is one of ``context_ids`` are stored under ``context``; all
    others are grouped per pair under ``dialog``, and their ids are collected
    as the speakers passed to ``Metadata.save_metadata``.

    :param act_list:
        episodes to write; empty episodes are skipped.
    :param datapath:
        base path, resolved through ``cls._get_path``.
    :param opt:
        options forwarded to the metadata writer.
    :param save_keys:
        ``'all'`` (every key except ``metrics``) or a comma-separated list.
    :param context_ids:
        comma-separated ids that mark an act as context.
    :param self_chat:
        forwarded to the metadata writer.
    """
    to_save = cls._get_path(datapath)
    context_id_list = context_ids.split(',')

    speakers = []
    with PathManager.open(to_save, 'w') as out:
        for episode in act_list:
            if not episode:
                continue
            convo = {
                'dialog': [],
                'context': [],
                'metadata_path': Metadata._get_path(to_save),
            }
            for act_pair in episode:
                kept = []
                for act in act_pair:
                    act_id = act.get('id')
                    is_context = act_id in context_id_list
                    if not is_context and act_id not in speakers:
                        speakers.append(act_id)

                    # Pick which keys of the act make it into the saved turn.
                    if save_keys == 'all':
                        wanted = [k for k in act.keys() if k != 'metrics']
                    else:
                        wanted = save_keys.split(',')
                    turn = {k: act.get(k, '') for k in wanted}
                    turn['id'] = act_id

                    if is_context:
                        convo['context'].append(turn)
                    else:
                        kept.append(turn)
                if kept:
                    convo['dialog'].append(kept)
            out.write(json.dumps(convo) + '\n')
    logging.info(f'Conversations saved to file: {to_save}')

    # The metadata lives next to the conversations file.
    Metadata.save_metadata(
        to_save, opt, self_chat=self_chat, speakers=speakers, **kwargs
    )
def download(url, path, fname, redownload=False, num_retries=5):
    """
    Download file using `requests`.

    If ``redownload`` is set to false, then will not download tar file again if it is
    present (default ``False``).

    :param url:
        address to fetch.
    :param path:
        directory the file is written into.
    :param fname:
        name of the file inside ``path``.
    :param redownload:
        fetch again even if the target file already exists.
    :param num_retries:
        how many connection errors / read timeouts are tolerated.
    :raises RuntimeError:
        if the retries are exhausted, or if fewer bytes arrive than the
        ``Content-Length`` header announced.
    """
    outfile = os.path.join(path, fname)
    should_download = not PathManager.exists(outfile) or redownload
    logging.info(f"Downloading {url} to {outfile}")
    retry = num_retries
    # Waits grow as retries run out: the last entries are used first.
    exp_backoff = [2**r for r in reversed(range(retry))]

    pbar = tqdm.tqdm(unit='B', unit_scale=True, desc='Downloading {}'.format(fname))

    # try/finally so the progress bar is released on every exit path,
    # including the RuntimeErrors raised below.
    try:
        while should_download and retry > 0:
            response = None

            with requests.Session() as session:
                try:
                    response = session.get(url, stream=True, timeout=5)

                    # negative reply could be 'none' or just missing
                    CHUNK_SIZE = 32768
                    total_size = int(response.headers.get('Content-Length', -1))
                    # server returns remaining size if resuming, so adjust total
                    pbar.total = total_size
                    done = 0

                    with PathManager.open(outfile, 'wb') as f:
                        for chunk in response.iter_content(CHUNK_SIZE):
                            if chunk:  # filter out keep-alive new chunks
                                f.write(chunk)
                            if total_size > 0:
                                done += len(chunk)
                                if total_size < done:
                                    # don't freak out if content-length was too small
                                    total_size = done
                                    pbar.total = total_size
                                pbar.update(len(chunk))
                        break
                except (
                    requests.exceptions.ConnectionError,
                    requests.exceptions.ReadTimeout,
                ):
                    retry -= 1
                    pbar.clear()
                    if retry > 0:
                        pl = 'y' if retry == 1 else 'ies'
                        logging.debug(
                            f'Connection error, retrying. ({retry} retr{pl} left)'
                        )
                        time.sleep(exp_backoff[retry])
                    else:
                        logging.error('Retried too many times, stopped retrying.')
                finally:
                    # Compare against None: a Response is falsy for 4xx/5xx
                    # statuses, which would leave such responses unclosed.
                    if response is not None:
                        response.close()
        if retry <= 0:
            raise RuntimeError('Connection broken too many times. Stopped retrying.')

        if should_download and retry > 0:
            # Bring the bar in line with the bytes actually counted.
            pbar.update(done - pbar.n)
            if done < total_size:
                raise RuntimeError(
                    f'Received less data than specified in Content-Length header for '
                    f'{url}. There may be a download problem.'
                )
    finally:
        pbar.close()
def create_supp(opt):
    """
    Build a file of supplementary examples from a model's misses.

    The model is run over the task one example at a time. Every example whose
    reported accuracy is below 1.0 counts as a miss. A miss is converted into
    an example with probability ``opt['conversion_rate']``. A converted
    example gets the first eval label as its response with probability
    ``opt['conversion_acc']``, and otherwise a random choice among the first
    ``NUM_INLINE_CANDS - 1`` label candidates. Summary statistics are printed
    and the examples are written to ``opt['outfile']`` as one JSON object per
    line.

    Ratios whose denominator is zero (nothing seen, no misses, or nothing
    converted) are reported as 0 so that the output file is still written.

    :param opt:
        options; reads ``conversion_rate``, ``conversion_acc``,
        ``model_file``, ``task`` and ``outfile``.
    :return:
        None
    """

    def _ratio(numerator, denominator):
        # Guarded division for the summary below.
        return numerator / denominator if denominator else 0.0

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)

    # Extract supp examples from misses on deploy set
    num_seen = 0
    num_misses = 0
    num_supp = 0
    num_supp_correct = 0
    examples = []
    while not world.epoch_done():
        world.parley()
        # Examples are considered one at a time
        num_seen += 1
        if num_seen % 1000 == 0:
            print(f"{num_seen}/{world.num_examples()}")
        report = world.report()
        if report['accuracy'] < 1.0:
            # Example is a miss (i.e., model got it wrong)
            num_misses += 1
            if random.random() < opt['conversion_rate']:
                # Example will be converted (e.g., bot recognized mistake and asked)
                num_supp += 1
                # The last line of the text is the context; earlier lines
                # are kept as memories.
                texts = world.acts[0]['text'].split('\n')
                context = texts[-1]
                memories = texts[:-1]
                candidates = world.acts[0]['label_candidates']
                # Reward of 1 indicates positive, -1 indicates negative (for training)
                # For now, we only train with positives, and the reward field is unused
                reward = 1

                if random.random() < opt['conversion_acc']:
                    # Example will be converted correctly (e.g., good user response)
                    num_supp_correct += 1
                    response = world.acts[0]['eval_labels'][0]
                else:
                    # Example will be converted incorrectly (e.g., bad user response)
                    response = random.choice(candidates[: NUM_INLINE_CANDS - 1])

                example = Parley(context, response, reward, candidates, memories)
                examples.append(example)
        # Metrics are reset after each example so the next report covers
        # only that example.
        world.reset_metrics()

    print("EPOCH DONE")
    print(f"Model file: {opt['model_file']}")
    print(f"Deploy file: {opt['task']}")
    print(f"Supp file: {opt['outfile']}")
    print(f"Deploy size (# examples seen): {num_seen}")
    print(f"Supp size (# examples converted): {num_supp}")

    acc = 1 - (num_misses / num_seen) if num_seen else 0.0
    conversion_rate = _ratio(num_supp, num_misses)
    conversion_acc = _ratio(num_supp_correct, num_supp)
    print(f"Accuracy (% of deploy): {acc * 100:.1f}% ({num_misses} misses)")
    print(
        f"Conversion rate (% of misses): {conversion_rate * 100:.2f}% "
        f"({num_supp}/{num_misses})"
    )
    print(
        f"Conversion acc (% of converted): {conversion_acc * 100:.2f}% "
        f"({num_supp_correct}/{num_supp})"
    )

    with PathManager.open(opt['outfile'], 'w') as outfile:
        for ex in examples:
            outfile.write(json.dumps(ex.to_dict()) + '\n')
def data_to_json(self, pd, file_name):
    """
    Write the records of ``pd`` as indented JSON under ``self.data_path``.

    :param pd:
        object whose ``to_dict('records')`` result is serialized.
    :param file_name:
        name of the output file inside ``self.data_path``.
    """
    records = pd.to_dict('records')
    target = os.path.join(self.data_path, file_name)
    with PathManager.open(target, 'w') as out:
        serialized = json.dumps(records, indent=4)
        out.write(serialized)
def make_parlai_format(outpath, dtype, data):
    """
    Write ``data`` to ``<outpath>/<dtype>.txt``.

    :param outpath:
        directory that receives the file.
    :param dtype:
        name used for both the progress message and the file stem.
    :param data:
        iterable of data points; each is converted with
        ``_handle_data_point`` before being written.
    """
    banner = 'building parlai:' + dtype
    print(banner)
    target = os.path.join(outpath, dtype + '.txt')
    with PathManager.open(target, 'w') as sink:
        for entry in data:
            formatted = _handle_data_point(entry)
            sink.write(formatted)
def _load_data_dump(self):
    """
    Parse the JSON file at ``self.data_path`` and return its ``adversarial`` entry.
    """
    with PathManager.open(self.data_path, 'rb') as dump_fh:
        parsed = json.load(dump_fh)
    adversarial = parsed['adversarial']
    return adversarial
def set_fixed_candidates(self, shared):
    """
    Load a set of fixed candidates and their vectors (or vectorize them here).

    self.fixed_candidates will contain a [num_cands] list of strings
    self.fixed_candidate_vecs will contain a [num_cands, seq_len] LongTensor

    See the note on the --fixed-candidate-vecs flag for an explanation of the
    'reuse', 'replace', or path options.

    Note: TorchRankerAgent by default converts candidates to vectors by vectorizing
    in the common sense (i.e., replacing each token with its index in the
    dictionary). If a child model wants to additionally perform encoding, it can
    overwrite the vectorize_fixed_candidates() method to produce encoded vectors
    instead of just vectorized ones.

    When ``fixed_candidate_vecs`` names an existing file, the vectors are
    loaded from it. If encodings are also requested in that case, they are
    computed from the loaded vectors and saved next to the model file.

    :param shared:
        if truthy, a mapping providing ``fixed_candidates``,
        ``fixed_candidate_vecs``, ``fixed_candidate_encs`` and
        ``num_fixed_candidates``, which are copied onto this object.
    """
    if shared:
        self.fixed_candidates = shared['fixed_candidates']
        self.fixed_candidate_vecs = shared['fixed_candidate_vecs']
        self.fixed_candidate_encs = shared['fixed_candidate_encs']
        self.num_fixed_candidates = shared['num_fixed_candidates']
        return

    self.num_fixed_candidates = 0
    opt = self.opt
    cand_path = self.fixed_candidates_path
    if 'fixed' not in (self.candidates, self.eval_candidates):
        # Fixed candidates are not in use for training or evaluation.
        self.fixed_candidates = None
        self.fixed_candidate_vecs = None
        self.fixed_candidate_encs = None
        return

    if not cand_path:
        # Attempt to get a standard candidate set for the given task
        path = self.get_task_candidates_path()
        if path:
            logging.info(f"setting fixed_candidates path to: {path}")
            self.fixed_candidates_path = path
            cand_path = self.fixed_candidates_path

    # Load candidates
    logging.info(f"Loading fixed candidate set from {cand_path}")
    with PathManager.open(cand_path, 'r', encoding='utf-8') as f:
        cands = [line.strip() for line in f.readlines()]

    def _cache_path(kind):
        # "<model dir>/<model name>.<candidate file name>.<kind>"; computed
        # on demand so that model_file is only needed when a cache path is.
        model_dir, model_file = os.path.split(opt['model_file'])
        model_name = os.path.splitext(model_file)[0]
        cands_name = os.path.splitext(os.path.basename(cand_path))[0]
        return os.path.join(model_dir, '.'.join([model_name, cands_name, kind]))

    # Bound before branching: the encodings step below reads it on every
    # path, including when an explicit vectors file is given.
    setting = opt['fixed_candidate_vecs']

    # Load or create candidate vectors
    if PathManager.exists(setting):
        vecs_path = setting
        vecs = self.load_candidates(vecs_path)
    else:
        vecs_path = _cache_path('vecs')
        if setting == 'reuse' and PathManager.exists(vecs_path):
            vecs = self.load_candidates(vecs_path)
        else:
            # setting == 'replace' OR generating for the first time
            vecs = self._make_candidate_vecs(cands)
            self._save_candidates(vecs, vecs_path)

    self.fixed_candidates = cands
    self.num_fixed_candidates = len(self.fixed_candidates)
    self.fixed_candidate_vecs = vecs
    if self.use_cuda:
        self.fixed_candidate_vecs = self.fixed_candidate_vecs.cuda()

    if not self.encode_candidate_vecs:
        self.fixed_candidate_encs = None
        return

    # candidate encodings are fixed so set them up now
    enc_path = _cache_path('encs')
    if setting == 'reuse' and PathManager.exists(enc_path):
        encs = self.load_candidates(enc_path, cand_type='encodings')
    else:
        encs = self._make_candidate_encs(self.fixed_candidate_vecs)
        self._save_candidates(encs, path=enc_path, cand_type='encodings')
    self.fixed_candidate_encs = encs
    if self.use_cuda:
        self.fixed_candidate_encs = self.fixed_candidate_encs.cuda()
    if self.fp16:
        self.fixed_candidate_encs = self.fixed_candidate_encs.half()
    else:
        self.fixed_candidate_encs = self.fixed_candidate_encs.float()
def _setup_data(self):
    """
    Parse the JSON file at ``self.data_path`` into ``self.data``.
    """
    banner = 'loading: ' + self.data_path
    print(banner)
    with PathManager.open(self.data_path) as data_fh:
        parsed = json.load(data_fh)
        self.data = parsed
def __init__(self, opt: Opt, shared=None):
    """
    Set up the class list, decision threshold, model, criterion and optimizer.

    :param opt:
        options; at least one of ``classes`` / ``classes_from_file`` must be
        set.
    :param shared:
        optional mapping from another instance. When truthy, ``class_list``,
        ``class_dict``, ``class_weights`` and ``model`` (and ``optimizer`` if
        present) are taken from it instead of being built.
    :raises RuntimeError:
        if neither ``classes`` nor ``classes_from_file`` is given.
    :raises AttributeError:
        if ``build_model()`` or ``build_criterion()`` returns ``None``.
    """
    # Resolved before the parent constructor runs.
    init_model, self.is_finetune = self._get_init_model(opt, shared)
    super().__init__(opt, shared)

    # set up classes
    if opt.get('classes') is None and opt.get('classes_from_file') is None:
        raise RuntimeError('Must specify --classes or --classes-from-file argument.')
    if not shared:
        if opt['classes_from_file'] is not None:
            # One class name per line of the file.
            with PathManager.open(opt['classes_from_file']) as f:
                self.class_list = f.read().splitlines()
        else:
            self.class_list = opt['classes']
        self.class_dict = {val: i for i, val in enumerate(self.class_list)}
        if opt.get('class_weights', None) is not None:
            self.class_weights = opt['class_weights']
        else:
            # Default to equal weight for every class.
            self.class_weights = [1.0 for c in self.class_list]
        self.reset_metrics()
    else:
        self.class_list = shared['class_list']
        self.class_dict = shared['class_dict']
        self.class_weights = shared['class_weights']

    # in binary classfication, opt['threshold'] applies to ref class
    if opt['ref_class'] is None or opt['ref_class'] not in self.class_dict:
        self.ref_class = self.class_list[0]
    else:
        self.ref_class = opt['ref_class']
        ref_class_id = self.class_list.index(self.ref_class)
        if ref_class_id != 0:
            # move to the front of the class list
            # NOTE(review): this reorders class_list in place (the same list
            # object as opt['classes'] or shared['class_list']) and
            # class_dict is not rebuilt afterwards - confirm that the
            # resulting index mismatch is intended.
            self.class_list.insert(0, self.class_list.pop(ref_class_id))

    # set up threshold, only used in binary classification
    if len(self.class_list) == 2 and opt.get('threshold', 0.5) != 0.5:
        self.threshold = opt['threshold']
    else:
        self.threshold = None

    # set up model and optimizers
    # ``states`` stays empty unless a model file is loaded below.
    states = {}
    if shared:
        self.model = shared['model']
    else:
        self.model = self.build_model()
        # freeze the encoder and update the classifier only
        if opt.get("update_classifier_head_only", False):
            for _param_name, _param_value in self.model.named_parameters():
                if not _param_name.startswith('additional_linear_layer'):
                    _param_value.requires_grad = False

        self.criterion = self.build_criterion()
        if self.model is None or self.criterion is None:
            raise AttributeError(
                'build_model() and build_criterion() need to return the model or criterion'
            )
        if init_model:
            logging.info(f'Loading existing model parameters from {init_model}')
            states = self.load(init_model)
        if self.use_cuda:
            if self.model_parallel:
                ph = PipelineHelper()
                ph.check_compatibility(self.opt)
                self.model = ph.make_parallel(self.model)
            else:
                self.model.cuda()
            if self.data_parallel:
                self.model = torch.nn.DataParallel(self.model)
            self.criterion.cuda()

        train_params = trainable_parameters(self.model)
        total_params = total_parameters(self.model)
        logging.info(
            f"Total parameters: {total_params:,d} ({train_params:,d} trainable)"
        )

    if shared:
        # We don't use get here because hasattr is used on optimizer later.
        if 'optimizer' in shared:
            self.optimizer = shared['optimizer']
    elif self._should_initialize_optimizer():
        # Only parameters left trainable above are handed to the optimizer.
        optim_params = [p for p in self.model.parameters() if p.requires_grad]
        self.init_optim(optim_params)
        self.build_lr_scheduler(states, hard_reset=self.is_finetune)