def write_dataset(self, dataset_type):
    dataset = (tf.data.TextLineDataset(
        transformers.cached_path(DATA_SOURCES[dataset_type]),
        compression_type="GZIP",
    ).map(
        lambda x: tf.py_function(self.process, [x], [tf.int32, tf.float32]),
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    ).interleave(
        lambda x, y: tf.data.Dataset.from_tensor_slices((x, y)),
        cycle_length=1,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    ))

    def serialize_example(input_ids, label):
        def _bytes_feature(value):
            return tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[value.numpy()]))

        feature = {
            "input_ids": _bytes_feature(tf.io.serialize_tensor(input_ids)),
            "label": _bytes_feature(tf.io.serialize_tensor(label)),
        }
        example_proto = tf.train.Example(features=tf.train.Features(
            feature=feature))
        return example_proto.SerializeToString()

    dataset = dataset.map(
        lambda x, y: tf.py_function(serialize_example, [x, y], [tf.string])[0],
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    file_path = f"./{self.get_file_name(dataset_type)}"
    writer = tf.data.experimental.TFRecordWriter(file_path,
                                                 compression_type="GZIP")
    writer.write(dataset)
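# A minimal sketch (not part of the original) of how the GZIP-compressed
# TFRecord file produced by write_dataset above could be read back. The
# feature names "input_ids" and "label" and their dtypes (int32 / float32)
# follow the serialize_example helper; the file path is assumed to come from
# the same get_file_name(dataset_type) convention.
import tensorflow as tf

def read_dataset(file_path):
    feature_spec = {
        "input_ids": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.string),
    }

    def _parse(record):
        parsed = tf.io.parse_single_example(record, feature_spec)
        input_ids = tf.io.parse_tensor(parsed["input_ids"], out_type=tf.int32)
        label = tf.io.parse_tensor(parsed["label"], out_type=tf.float32)
        return input_ids, label

    return (tf.data.TFRecordDataset(file_path, compression_type="GZIP")
            .map(_parse, num_parallel_calls=tf.data.experimental.AUTOTUNE))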
def get_dataset_personalities(tokenizer, dataset_path, dataset_cache=None):
    """ Get personalities from PERSONACHAT """
    dataset_path = dataset_path or PERSONACHAT_URL
    dataset_cache = dataset_cache + '_' + type(tokenizer).__name__  # To avoid using the GPT cache for GPT-2 and vice-versa
    if os.path.isfile(dataset_cache):
        logger.info("Load tokenized dataset from cache at %s", dataset_cache)
        personachat = torch.load(dataset_cache)
    else:
        logger.info("Download PERSONACHAT dataset from %s", dataset_path)
        personachat_file = cached_path(dataset_path)
        with open(personachat_file, "r", encoding="utf-8") as f:
            personachat = json.loads(f.read())

        logger.info("Tokenize and encode the dataset")

        def tokenize(obj):
            if isinstance(obj, str):
                return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
            if isinstance(obj, dict):
                return dict((n, tokenize(o)) for n, o in obj.items())
            return list(tokenize(o) for o in obj)

        personachat = tokenize(personachat)
        torch.save(personachat, dataset_cache)

    logger.info("Filter personalities")
    personality1 = []
    personality2 = []
    for dataset in personachat.values():
        for dialog in dataset:
            personality1.append(dialog["persona_info"])
            personality2.append(dialog["persona_info2"])

    logger.info("Gathered {} personality 1".format(len(personality1)))
    logger.info("Gathered {} personality 2".format(len(personality2)))
    return personality1, personality2
def get_dataset(tokenizer, dataset_path, dataset_cache):
    """ Get tokenized PERSONACHAT dataset from S3 or cache. """
    dataset_path = dataset_path or PERSONACHAT_URL
    dataset_cache = dataset_cache + '_' + type(tokenizer).__name__  # To avoid using GPT cache for GPT-2 and vice-versa
    if dataset_cache and os.path.isfile(dataset_cache):
        logger.info("Load tokenized dataset from cache at %s", dataset_cache)
        dataset = torch.load(dataset_cache)
    else:
        logger.info("Download dataset from %s", dataset_path)
        personachat_file = cached_path(dataset_path)
        with open(personachat_file, "r", encoding="utf-8") as f:
            dataset = json.loads(f.read())

        logger.info("Tokenize and encode the dataset")

        def tokenize(obj):
            if isinstance(obj, str):
                return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
            if isinstance(obj, dict):
                return dict((n, tokenize(o)) for n, o in obj.items())
            return list(tokenize(o) for o in obj)

        dataset = tokenize(dataset)
        torch.save(dataset, dataset_cache)
    return dataset
def get_dataset(tokenizer, dataset_path, dataset_cache=None, as_strings=False):
    """ Get PERSONACHAT from S3 """
    dataset_path = dataset_path or PERSONACHAT_URL
    os.makedirs(dataset_cache, exist_ok=True)
    dataset_cache = dataset_cache + '/' + dataset_path.split('/')[-1].replace('.json', '') + '_' + type(tokenizer).__name__  # To avoid using GPT cache for GPT-2 and vice-versa
    if as_strings:
        dataset_cache += '_STRINGS'
    if dataset_cache and os.path.isfile(dataset_cache):
        logger.info("Load tokenized dataset from cache at %s", dataset_cache)
        dataset = torch.load(dataset_cache)
    else:
        logger.info("Download dataset from %s", dataset_path)
        personachat_file = cached_path(dataset_path)
        with open(personachat_file, "r", encoding="utf-8") as f:
            dataset = json.loads(f.read())

        logger.info("Tokenize and encode the dataset")

        def tokenize(obj):
            if isinstance(obj, str):
                # remove spaces before sentence markers and tokenize
                tokens = tokenizer.tokenize(obj.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','))
                if as_strings:
                    return tokens
                return tokenizer.convert_tokens_to_ids(tokens)
            if isinstance(obj, tuple) and len(obj) == 2 and isinstance(obj[0], str):
                speaker, obj = obj
                if as_strings:
                    return speaker, tokenize(obj)
                return tokenizer.special_tokens[speaker], tokenize(obj)
            if isinstance(obj, dict):
                return dict((n, tokenize(' '.join(o))) if n == 'personality' and isinstance(o, list)
                            else (n, tokenize(o)) for n, o in obj.items())
            return list(tokenize(o) for o in obj)

        dataset = tokenize(dataset)
        if dataset_cache:
            torch.save(dataset, dataset_cache)
    return dataset
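# A minimal usage sketch (assumption, not from the original source): loading
# the tokenized PERSONACHAT data with the get_dataset variant above and a
# GPT-2 tokenizer from transformers. The "train"/"valid" keys follow the
# PERSONACHAT JSON layout; the cache directory name is arbitrary.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
dataset = get_dataset(tokenizer, dataset_path=None, dataset_cache="./dataset_cache")
print(len(dataset["train"]), "training dialogs,", len(dataset["valid"]), "validation dialogs")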
def __init__(self, args, task):
    super().__init__(args, task)
    self.eps = args.label_smoothing
    from fairseq.sequence_generator import SequenceGenerator
    self.gen = SequenceGenerator(task.target_dictionary,
                                 beam_size=args.beam_size)
    if args.reward == "bleurt":
        from fairseq.distributed_utils import get_rank
        sys.argv = sys.argv[:1]
        my_rank = 0 if torch.cuda.device_count() <= 1 else get_rank()
        os.environ["CUDA_VISIBLE_DEVICES"] = str(my_rank % 4)
        from bleurt import score
        from transformers import cached_path
        import tensorflow as tf
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            this_gpu = gpus[my_rank % 4]
            tf.config.set_visible_devices([this_gpu], 'GPU')
            try:
                tf.config.experimental.set_memory_growth(this_gpu, True)
                tf.config.experimental.set_virtual_device_configuration(
                    this_gpu, [
                        tf.config.experimental.VirtualDeviceConfiguration(
                            memory_limit=2048)
                    ])
                logical_devices = tf.config.list_logical_devices('GPU')
                self.logical_device = tf.device(logical_devices[0].name)
                print("num of logical gpus", len(logical_devices))
            except RuntimeError as e:
                print(e)
        with self.logical_device:
            self.bleurt_scorer = score.BleurtScorer(
                os.path.join(
                    cached_path(
                        "https://storage.googleapis.com/bleurt-oss/bleurt-base-128.zip",
                        extract_compressed_file=True),
                    "bleurt-base-128"))
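# Hedged usage sketch (assumption, not from the original): scoring candidate
# translations against references with the BleurtScorer built in __init__
# above. `criterion` stands for an instance of this class constructed with
# args.reward == "bleurt" on a GPU machine, so that logical_device and
# bleurt_scorer exist; the keyword arguments mirror the BleurtScorer.score()
# call used elsewhere in this collection.
references = ["the cat sat on the mat"]
candidates = ["a cat was sitting on the mat"]
with criterion.logical_device:
    rewards = criterion.bleurt_scorer.score(references=references,
                                            candidates=candidates)
print(rewards)  # one float per (reference, candidate) pair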
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
    """
    Instantiate a tokenizer from a pre-trained vocabulary file.
    Download and cache the vocabulary file if needed.
    """
    if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
        vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
    else:
        vocab_file = pretrained_model_name_or_path
    if os.path.isdir(vocab_file):
        vocab_file = os.path.join(vocab_file, VOCAB_NAME)
    # redirect to the cache, if necessary
    try:
        resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
    except EnvironmentError:
        logger.error(
            "Model name '{}' was not found in model name list ({}). "
            "We assumed '{}' was a path or url but couldn't find any file "
            "associated to this path or url.".format(
                pretrained_model_name_or_path,
                ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                vocab_file))
        return None
    if resolved_vocab_file == vocab_file:
        logger.info("loading vocabulary file {}".format(vocab_file))
    else:
        logger.info("loading vocabulary file {} from cache at {}".format(
            vocab_file, resolved_vocab_file))
    if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
        # If we're using a pretrained model, ensure the tokenizer won't index
        # sequences longer than the number of positional embeddings.
        max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
    # Instantiate tokenizer.
    tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
    return tokenizer
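# Hedged usage sketch (assumption): calling the from_pretrained classmethod
# above either with a shortcut name or a local vocabulary path, plus an
# explicit cache directory. BertTokenizer is assumed to be the enclosing
# class (pytorch-pretrained-bert style); substitute the actual class name.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                          cache_dir="/tmp/bert_cache")
tokens = tokenizer.tokenize("Hello, cached_path!")
ids = tokenizer.convert_tokens_to_ids(tokens)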
def convert_all_pt_checkpoints_to_tf(
    args_model_type,
    tf_dump_path,
    model_shortcut_names_or_path=None,
    config_shortcut_names_or_path=None,
    compare_with_pt_model=False,
    use_cached_models=False,
    remove_cached_files=False,
    only_convert_finetuned_models=False,
):
    if args_model_type is None:
        model_types = list(MODEL_CLASSES.keys())
    else:
        model_types = [args_model_type]

    for j, model_type in enumerate(model_types, start=1):
        print("=" * 100)
        print(" Converting model type {}/{}: {}".format(
            j, len(model_types), model_type))
        print("=" * 100)
        if model_type not in MODEL_CLASSES:
            raise ValueError(
                "Unrecognized model type {}, should be one of {}.".format(
                    model_type, list(MODEL_CLASSES.keys())))

        config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]

        if model_shortcut_names_or_path is None:
            model_shortcut_names_or_path = list(aws_model_maps.keys())
        if config_shortcut_names_or_path is None:
            config_shortcut_names_or_path = model_shortcut_names_or_path

        for i, (model_shortcut_name, config_shortcut_name) in enumerate(
                zip(model_shortcut_names_or_path, config_shortcut_names_or_path),
                start=1):
            print("-" * 100)
            if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
                if not only_convert_finetuned_models:
                    print(" Skipping finetuned checkpoint {}".format(model_shortcut_name))
                    continue
                model_type = model_shortcut_name
            elif only_convert_finetuned_models:
                print(" Skipping non-finetuned checkpoint {}".format(model_shortcut_name))
                continue
            print(" Converting checkpoint {}/{}: {} - model_type {}".format(
                i, len(aws_config_map), model_shortcut_name, model_type))
            print("-" * 100)

            if config_shortcut_name in aws_config_map:
                config_file = cached_path(aws_config_map[config_shortcut_name],
                                          force_download=not use_cached_models)
            else:
                config_file = cached_path(config_shortcut_name,
                                          force_download=not use_cached_models)

            if model_shortcut_name in aws_model_maps:
                model_file = cached_path(aws_model_maps[model_shortcut_name],
                                         force_download=not use_cached_models)
            else:
                model_file = cached_path(model_shortcut_name,
                                         force_download=not use_cached_models)

            if os.path.isfile(model_shortcut_name):
                model_shortcut_name = "converted_model"

            convert_pt_checkpoint_to_tf(
                model_type=model_type,
                pytorch_checkpoint_path=model_file,
                config_file=config_file,
                tf_dump_path=os.path.join(tf_dump_path,
                                          model_shortcut_name + "-tf_model.h5"),
                compare_with_pt_model=compare_with_pt_model,
            )
            if remove_cached_files:
                os.remove(config_file)
                os.remove(model_file)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='openai-gpt',
                        help='pretrained model name')
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument('--train_dataset', type=str, default='')
    parser.add_argument('--eval_dataset', type=str, default='')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_train_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--learning_rate', type=float, default=6.25e-5)
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--lm_coef', type=float, default=0.9)
    parser.add_argument('--n_valid', type=int, default=374)
    parser.add_argument('--server_ip', type=str, default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()
    print(args)

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {}, n_gpu {}".format(device, n_gpu))

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Load tokenizer and model
    # This loading function also adds new tokens and embeddings called `special tokens`
    # These new embeddings will be fine-tuned on the RocStories dataset
    special_tokens = ['_start_', '_delimiter_', '_classify_']
    tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name)
    tokenizer.add_tokens(special_tokens)
    special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
    model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name)
    model.resize_token_embeddings(len(tokenizer))
    model.to(device)

    # Load and encode the datasets
    if not args.train_dataset and not args.eval_dataset:
        roc_stories = cached_path(ROCSTORIES_URL)

    def tokenize_and_encode(obj):
        """ Tokenize and encode a nested object """
        if isinstance(obj, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
        elif isinstance(obj, int):
            return obj
        return list(tokenize_and_encode(o) for o in obj)

    logger.info("Encoding dataset...")
    train_dataset = load_rocstories_dataset(args.train_dataset)
    eval_dataset = load_rocstories_dataset(args.eval_dataset)
    datasets = (train_dataset, eval_dataset)
    encoded_datasets = tokenize_and_encode(datasets)

    # Compute the max input length for the Transformer
    max_length = model.config.n_positions // 2 - 2
    input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3
                       for dataset in encoded_datasets
                       for story, cont1, cont2, _ in dataset)
    input_length = min(input_length, model.config.n_positions)  # Max size of input for the pre-trained model

    # Prepare input tensors and dataloaders
    tensor_datasets = pre_process_datasets(encoded_datasets, input_length,
                                           max_length, *special_tokens_ids)
    train_tensor_dataset, eval_tensor_dataset = tensor_datasets[0], tensor_datasets[1]

    train_data = TensorDataset(*train_tensor_dataset)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    eval_data = TensorDataset(*eval_tensor_dataset)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # Prepare optimizer
    if args.do_train:
        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // \
                (len(train_dataloader) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader) \
                // args.gradient_accumulation_steps * args.num_train_epochs

        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': args.weight_decay
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)

    if args.do_train:
        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_steps = 0
            tqdm_bar = tqdm(train_dataloader, desc="Training")
            for step, batch in enumerate(tqdm_bar):
                batch = tuple(t.to(device) for t in batch)
                input_ids, mc_token_ids, lm_labels, mc_labels = batch
                losses = model(input_ids,
                               mc_token_ids=mc_token_ids,
                               lm_labels=lm_labels,
                               mc_labels=mc_labels)
                loss = args.lm_coef * losses[0] + losses[1]
                loss.backward()
                # Update parameters before advancing the LR schedule
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                tr_loss += loss.item()
                exp_average_loss = (loss.item() if exp_average_loss is None
                                    else 0.7 * exp_average_loss + 0.3 * loss.item())
                nb_tr_steps += 1
                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(
                    exp_average_loss, scheduler.get_lr()[0])

    # Save a trained model
    if args.do_train:
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned
        model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.output_dir)
        tokenizer = OpenAIGPTTokenizer.from_pretrained(args.output_dir)
        model.to(device)

    if args.do_eval:
        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(device) for t in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels = batch
            with torch.no_grad():
                _, mc_loss, _, mc_logits = model(input_ids,
                                                 mc_token_ids=mc_token_ids,
                                                 lm_labels=lm_labels,
                                                 mc_labels=mc_labels)

            mc_logits = mc_logits.detach().cpu().numpy()
            mc_labels = mc_labels.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(mc_logits, mc_labels)

            eval_loss += mc_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        train_loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'train_loss': train_loss
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
def get_dataset(
    tokenizer,
    dataset_path,
    dataset_cache,
    process_count,
    proxies,
    evaluate=False,
    interact=False,
    no_cache=False,
    args=None,
):
    """Get tokenized PERSONACHAT dataset from S3 or cache."""
    dataset_path = dataset_path or PERSONACHAT_URL

    mode = "eval" if evaluate else "train"
    if interact:
        mode = "interact"

    dataset_cache = (dataset_cache + "_" + type(tokenizer).__name__ + "_" + mode)  # To avoid using GPT cache for GPT-2 and vice-versa

    if dataset_cache and os.path.isfile(dataset_cache) and not no_cache:
        logger.info("Load tokenized dataset from cache at %s", dataset_cache)
        dataset = torch.load(dataset_cache)
    else:
        logger.info("Download dataset from %s", dataset_path)
        personachat_file = cached_path(dataset_path, proxies=proxies)
        with open(personachat_file, "r", encoding="utf-8") as f:
            dataset = json.loads(f.read())

        logger.info("Tokenize and encode the dataset")

        def tokenize(obj):
            if isinstance(obj, str):
                return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
            if isinstance(obj, dict):
                return dict((n, tokenize(o)) for n, o in obj.items())

            data = [(d, tokenizer) for d in obj]
            if args.multiprocessing_chunksize == -1:
                chunksize = max(len(data) // (args.process_count * 2), 500)
            else:
                chunksize = args.multiprocessing_chunksize

            with Pool(process_count) as p:
                tokenized_data = list(
                    tqdm(
                        p.imap(tokenize_multi, data, chunksize=chunksize),
                        total=len(data),
                    ))
            return tokenized_data

        if not interact and dataset_path == PERSONACHAT_URL:
            if not evaluate:
                dataset = dataset["train"]
            else:
                dataset = dataset["valid"]

        dataset = tokenize(dataset)
        torch.save(dataset, dataset_cache)
    return dataset
def __init__(self, path: str = 'small', device=None, **kwargs):
    if device is not None:
        if isinstance(device, torch.device):
            self.device = device
        elif isinstance(device, str):
            self.device = torch.device(device)
    elif torch.cuda.is_available():
        self.device = torch.device('cuda')
    else:
        self.device = torch.device('cpu')

    if path in model_map or is_remote_url(path) or os.path.isfile(path):
        proxies = kwargs.pop("proxies", None)
        cache_dir = kwargs.pop("cache_dir", LTP_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        local_files_only = kwargs.pop("local_files_only", False)
        path = cached_path(model_map.get(path, path),
                           cache_dir=cache_dir,
                           force_download=force_download,
                           proxies=proxies,
                           resume_download=resume_download,
                           local_files_only=local_files_only,
                           extract_compressed_file=True)
    elif not os.path.isdir(path):
        raise FileNotFoundError()

    try:
        ckpt = torch.load(os.path.join(path, "ltp.model"),
                          map_location=self.device)
    except Exception as e:
        fake_import_pytorch_lightning()
        ckpt = torch.load(os.path.join(path, "ltp.model"),
                          map_location=self.device)

    patch_4_1_3(ckpt)
    self.cache_dir = path
    transformer_config = ckpt['transformer_config']
    transformer_config['torchscript'] = True
    config = AutoConfig.for_model(**transformer_config)

    self.model = Model(ckpt['model_config'], config=config).to(self.device)
    self.model.load_state_dict(ckpt['model'], strict=False)
    self.model.eval()

    self.seg_vocab = ckpt.get('seg', [WORD_MIDDLE, WORD_START])
    self.seg_vocab_dict = {tag: idx for idx, tag in enumerate(self.seg_vocab)}
    self.pos_vocab = ckpt.get('pos', [])
    self.ner_vocab = ckpt.get('ner', [])
    self.dep_vocab = ckpt.get('dep', [])
    self.sdp_vocab = ckpt.get('sdp', [])
    self.srl_vocab = [
        re.sub(r'ARG(\d)', r'A\1', tag.lstrip('ARGM-'))
        for tag in ckpt.get('srl', [])
    ]
    self.tokenizer = AutoTokenizer.from_pretrained(
        path, config=self.model.transformer.config, use_fast=True)
    self.trie = Trie()
    self._model_version = ckpt.get('version', None)
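# Hedged usage sketch (assumption, not from the original file): a typical
# LTP 4.x pipeline call once the model above has been constructed. The
# seg()/pos() method names follow the public LTP API; exact signatures may
# differ between versions.
from ltp import LTP

ltp = LTP(path="small")                        # downloads and caches the model via cached_path
seg, hidden = ltp.seg(["他叫汤姆去拿外衣。"])   # word segmentation
pos = ltp.pos(hidden)                          # POS tagging reuses the cached hidden states
print(seg, pos)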
parser.add_argument('--use_cuda', type=int, default=1,
                    help='True to use cuda')
parser.add_argument('--n_epochs', type=int, default=10,
                    help='no. of epochs to run')
args = parser.parse_args()

# download pretrained model
HF_FINETUNED_MODEL = (
    "https://s3.amazonaws.com/models.huggingface.co/transfer-learning-chatbot/gpt_personachat_cache.tar.gz"  # noqa
)

# Download and extract the finetuned model from S3
resolved_archive_file = cached_path(HF_FINETUNED_MODEL)
tempdir = tempfile.mkdtemp()
print("extracting archive file {} to temp dir {}".format(
    resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, "r:gz") as archive:
    archive.extractall(tempdir)

# get directories
output_dir = './saved_model'  # save models
best_model_dir = './saved_model/best'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
if not os.path.exists(best_model_dir):
    os.makedirs(best_model_dir)

TRAIN_FILE = args.train_file
def __init__(self, path: str = 'small', device=None, **kwargs):
    if device is not None:
        if isinstance(device, torch.device):
            self.device = device
        elif isinstance(device, str):
            self.device = torch.device(device)
    elif torch.cuda.is_available():
        self.device = torch.device('cuda')
    else:
        self.device = torch.device('cpu')

    if path in model_map or is_remote_url(path) or os.path.isfile(path):
        proxies = kwargs.pop("proxies", None)
        cache_dir = kwargs.pop("cache_dir", LTP_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        local_files_only = kwargs.pop("local_files_only", False)
        path = cached_path(model_map.get(path, path),
                           cache_dir=cache_dir,
                           force_download=force_download,
                           proxies=proxies,
                           resume_download=resume_download,
                           local_files_only=local_files_only,
                           extract_compressed_file=True)
    elif not os.path.isdir(path):
        raise FileNotFoundError()

    ckpt = torch.load(os.path.join(path, "ltp.model"),
                      map_location=self.device)
    ckpt['model_config']['init'].pop('pretrained')
    self.cache_dir = path
    self.model = Model.from_params(ckpt['model_config'],
                                   config=ckpt['pretrained_config']).to(self.device)
    self.model.load_state_dict(
        ckpt['model'],
        strict=transformers_version < version.parse("3.1.0"))
    self.model.eval()
    # todo: fp16
    self.max_length = self.model.pretrained.config.max_position_embeddings
    self.seg_vocab = [WORD_START, WORD_MIDDLE]
    self.pos_vocab = ckpt['pos']
    self.ner_vocab = ckpt['ner']
    self.dep_vocab = ckpt['dep']
    self.sdp_vocab = ckpt['sdp']
    self.srl_vocab = [
        re.sub(r'ARG(\d)', r'A\1', tag.lstrip('ARGM-'))
        for tag in ckpt['srl']
    ]
    self.tokenizer = AutoTokenizer.from_pretrained(
        path, config=self.model.pretrained.config, use_fast=True)
    self.trie = Trie()

    if kwargs.pop("need_config", False):
        config = ckpt['model_config']
        config['init']['seg']['vocab'] = self.seg_vocab
        config['init']['pos']['vocab'] = self.pos_vocab
        config['init']['ner']['vocab'] = self.ner_vocab
        config['init']['dep']['vocab'] = self.dep_vocab
        config['init']['sdp']['vocab'] = self.sdp_vocab
        config['init']['srl']['vocab'] = self.srl_vocab
        config['pretrained_config'] = ckpt['pretrained_config']
        self.config = config
def main(args, checkpoint_name="best"): assert args.path is not None, '--path required for generation!' assert not args.sampling or args.nbest == args.beam, \ '--sampling requires --nbest to be equal to --beam' assert args.replace_unk is None or args.raw_text, \ '--replace-unk requires a raw text dataset (--raw-text)' if args.max_tokens is None and args.max_sentences is None: args.max_tokens = 12000 print(args) use_cuda = torch.cuda.is_available() and not args.cpu torch.manual_seed(args.seed) # Load dataset splits task = tasks.setup_task(args) task.load_dataset(args.gen_subset) print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset)))) args.taskobj = task sys.argv = sys.argv[:1] import tensorflow as tf gpus = tf.config.experimental.list_physical_devices('GPU') if gpus: try: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) except RuntimeError as e: print(e) bleurt_scorer = score.BleurtScorer(os.path.join( cached_path( "https://storage.googleapis.com/bleurt-oss/bleurt-base-128.zip", extract_compressed_file=True ), "bleurt-base-128" )) # Set dictionaries #src_dict = task.source_dictionary tgt_dict = task.target_dictionary dict = tgt_dict # Load decoding strategy strategy = strategies.setup_strategy(args) # Load ensemble if args.path.startswith("nsml://"): print("| loading nsml checkpoint", args.path) import nsml session = args.path.replace("nsml://", "") model = task.build_model(args) def load(dir_path): state = torch.load(os.path.join(dir_path, 'best.pt')) state_dict = state["model"] model.load_state_dict(state_dict) print("loaded") nsml.load(args.checkpoint_name, load_fn=load, session=session) models = [model.cuda()] elif args.path == "pretrain": from nsml import DATASET_PATH from fairseq import checkpoint_utils data_token = "en-de" pretrained_path = "{}/train/pretrained_models/maskPredict_{}/checkpoint_best.pt".format(DATASET_PATH, data_token.split(".")[-1].replace("-", "_")) print("| loading", pretrained_path) model = task.build_model(args) state = checkpoint_utils.load_checkpoint_to_cpu(pretrained_path) model.load_state_dict(state["model"], strict=True) models = [model.cuda()] elif args.path.startswith("wb://"): print("| loading wb checkpoint", args.path) import wandb wandb.restore("best.pt", args.path.replace("wb://", ""), root="/tmp/") assert os.path.exists("/tmp/best.pt") state = torch.load("/tmp/best.pt") model = task.build_model(args) model.load_state_dict(state["model"]) models = [model.cuda()] elif args.path.startswith("http://"): print("| loading http checkpoint", args.path) url = "http://trains.deeplearn.org:8081/{}".format(args.path.replace("http://", "")) os.system("curl -o /tmp/model.pt {}".format(url)) state = torch.load("/tmp/model.pt") model = task.build_model(args) model.load_state_dict(state["model"]) models = [model.cuda()] else: print('| loading model(s) from {}'.format(args.path)) models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides)) models = [model.cuda() for model in models] # Optimize ensemble for generation for model in models: model.make_generation_fast_( beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, need_attn=args.print_alignment, ) if args.fp16: model.half() # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) align_dict = utils.load_align_dict(args.replace_unk) # Load dataset (possibly sharded) itr = task.get_batch_iterator( 
dataset=task.dataset(args.gen_subset), max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=utils.resolve_max_positions( task.max_positions(), *[model.max_positions() for model in models] ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=8, num_shards=args.num_shards, shard_id=args.shard_id, ).next_epoch_itr(shuffle=False) results = [] scorer = pybleu.PyBleuScorer() num_sentences = 0 has_target = True timer = TimeMeter() with progress_bar.build_progress_bar(args, itr) as t: translations = generate_batched_itr(t, strategy, models, tgt_dict, length_beam_size=args.length_beam, use_gold_target_len=args.gold_target_len) for sample_id, src_tokens, target_tokens, hypos in translations: has_target = target_tokens is not None target_tokens = target_tokens.int().cpu() if has_target else None # Either retrieve the original sentences or regenerate them from tokens. if align_dict is not None: src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id) else: src_str = dict.string(src_tokens, args.remove_bpe) if args.dehyphenate: src_str = dehyphenate(src_str) if has_target: target_str = dict.string(target_tokens, args.remove_bpe, escape_unk=True) if args.dehyphenate: target_str = dehyphenate(target_str) if not args.quiet or True: # print('S-{}\t{}'.format(sample_id, src_str)) if has_target: # print('T-{}\t{}'.format(sample_id, target_str)) hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypos.int().cpu(), src_str=src_str, alignment= None, align_dict=align_dict, tgt_dict=dict, remove_bpe=args.remove_bpe, ) if args.dehyphenate: hypo_str = dehyphenate(hypo_str) if not args.quiet: print('H-{}\t{}'.format(sample_id, hypo_str)) if args.print_alignment: print('A-{}\t{}'.format( sample_id, ' '.join(map(lambda x: str(utils.item(x)), alignment)) )) # print() # Score only the top hypothesis if has_target: if align_dict is not None or args.remove_bpe is not None: # Convert back to tokens for evaluation with unk replacement and/or without BPE target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True) results.append((target_str, hypo_str)) num_sentences += 1 if has_target: print('Time = {}'.format(timer.elapsed_time)) ref, out = zip(*results) from fairseq.criterions.lib_sbleu import smoothed_bleu sbleu = np.mean([smoothed_bleu(p[0].split(), p[1].split()) for p in results]) print("| SBLEU = {:.2f}".format(sbleu)) bleurt_scores = bleurt_scorer.score([p[0] for p in results], [p[1] for p in results]) print("| BLEURT = {:.4f}".format(np.mean((np.array(bleurt_scores))))) print('| Generate {} with beam={}: BLEU4 = {:2.2f}, '.format(args.gen_subset, args.length_beam, scorer.score(ref, out)))
def main(args): global score assert args.path is not None, '--path required for generation!' assert not args.sampling or args.nbest == args.beam, \ '--sampling requires --nbest to be equal to --beam' assert args.replace_unk is None or args.raw_text, \ '--replace-unk requires a raw text dataset (--raw-text)' utils.import_user_module(args) if args.max_tokens is None and args.max_sentences is None: args.max_tokens = 12000 print(args) use_cuda = torch.cuda.is_available() and not args.cpu if args.reward == "bleurt" or args.eval_bleurt: sys.argv = sys.argv[:1] import tensorflow as tf gpus = tf.config.experimental.list_physical_devices('GPU') if gpus: try: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) except RuntimeError as e: print(e) bleurt_scorer = score.BleurtScorer( os.path.join( cached_path( "https://storage.googleapis.com/bleurt-oss/bleurt-base-128.zip", extract_compressed_file=True), "bleurt-base-128")) # Load dataset splits task = tasks.setup_task(args) task.load_dataset(args.gen_subset) # Set dictionaries try: src_dict = getattr(task, 'source_dictionary', None) except NotImplementedError: src_dict = None tgt_dict = task.target_dictionary # Load ensemble print('| loading model(s) from {}'.format(args.path)) if args.path.startswith("nsml://"): # NSML session = args.path.replace("nsml://", "") model = task.build_model(args) if ".pt" in session: session = session.replace(".pt", "") session, checkpoint_name = session.rsplit("/", 1) else: checkpoint_name = "best" if "-" in checkpoint_name: start, end = checkpoint_name.replace("epoch", "").split("-") checkpoints = [ "epoch{}".format(i) for i in range(int(start), int(end) + 1) ] print("| checkpoint average:", checkpoints) state_dict = None def load(dir_path): nonlocal state_dict, checkpoints state = torch.load(os.path.join(dir_path, 'best.pt')) model_state = state["model"] for k in model_state: model_state[k] = model_state[k] / float(len(checkpoints)) if state_dict is None: state_dict = model_state else: for k in state_dict: state_dict[k] += model_state[k] print("checkpoint loaded") for checkpoint_name in checkpoints: nsml.load(checkpoint_name, load_fn=load, session=session) model.load_state_dict(state_dict) else: def load(dir_path): state = torch.load(os.path.join(dir_path, 'best.pt')) state_dict = state["model"] model.load_state_dict(state_dict) print("loaded") nsml.load(checkpoint_name, load_fn=load, session=session) models = [model.cuda()] elif "-" in args.path: model = task.build_model(args) print("loading model from", args.path) state_dict = None dir_path = os.path.dirname(args.path) fn = os.path.basename(args.path) if "-" in fn: start, end = fn.replace("epoch", "").replace(".pt", "").split("-") checkpoint_fns = [ "epoch{}.pt".format(i) for i in range(int(start), int(end) + 1) ] else: checkpoint_fns = [fn] for fn in checkpoint_fns: state = torch.load(os.path.join(dir_path, fn)) model_state = state["model"] for k in model_state: model_state[k] = model_state[k] / float(len(checkpoint_fns)) if state_dict is None: state_dict = model_state else: for k in state_dict: state_dict[k] += model_state[k] print("checkpoint loaded") model.load_state_dict(state_dict) models = [model.cuda()] else: model = task.build_model(args) state = torch.load(args.path) model_state = state["model"] model.load_state_dict(model_state) models = [model.cuda()] # Optimize ensemble for generation for model in models: model.make_generation_fast_( beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, need_attn=args.print_alignment, ) if 
args.fp16: model.half() if use_cuda: model.cuda() # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) align_dict = utils.load_align_dict(args.replace_unk) # Load dataset (possibly sharded) itr = task.get_batch_iterator( dataset=task.dataset(args.gen_subset), max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=utils.resolve_max_positions( task.max_positions(), *[model.max_positions() for model in models]), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=args.required_batch_size_multiple, num_shards=args.num_shards, shard_id=args.shard_id, num_workers=args.num_workers, ).next_epoch_itr(shuffle=False) # Initialize generator gen_timer = StopwatchMeter() generator = task.build_generator(args) # Generate and compute BLEU score # if args.sacrebleu: # scorer = bleu.SacrebleuScorer() # else: # scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) scorer = pybleu.PyBleuScorer() num_sentences = 0 has_target = True results = [] best_rank_list = [] if args.save_path: outf = open(args.save_path, "w") total_n = 0 with progress_bar.build_progress_bar(args, itr) as t: wps_meter = TimeMeter() for sample in t: sample = utils.move_to_cuda(sample) if use_cuda else sample if 'net_input' not in sample: continue prefix_tokens = None if args.prefix_size > 0: prefix_tokens = sample['target'][:, :args.prefix_size] gen_timer.start() hypos = task.inference_step(generator, models, sample, prefix_tokens) num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos) gen_timer.stop(num_generated_tokens) hypo_target_pairs = [] for i, sample_id in enumerate(sample['id'].tolist()): total_n += 1 has_target = sample['target'] is not None # Remove padding src_tokens = utils.strip_pad( sample['net_input']['src_tokens'][i, :], tgt_dict.pad()) target_tokens = None if has_target: target_tokens = utils.strip_pad( sample['target'][i, :], tgt_dict.pad()).int().cpu() # Either retrieve the original sentences or regenerate them from tokens. 
if align_dict is not None: src_str = task.dataset( args.gen_subset).src.get_original_text(sample_id) target_str = task.dataset( args.gen_subset).tgt.get_original_text(sample_id) else: if src_dict is not None: src_str = src_dict.string(src_tokens, args.remove_bpe) else: src_str = "" if has_target: target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True) if not args.quiet: if src_dict is not None: print('S-{}\t{}'.format(sample_id, src_str)) if has_target: print('T-{}\t{}'.format(sample_id, target_str)) if args.reward_sample or args.reward_check: # Get sample hypo_strs = [] rewards = [] for j, hypo in enumerate(hypos[i]): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo['tokens'].int().cpu(), src_str=src_str, alignment=None, align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=None, ) hypo_strs.append(hypo_str) if args.reward == "sbleu": for hypo_str in hypo_strs: hypo_str_nobpe = hypo_str.replace("@@ ", "") rewards.append( compute_reward(hypo_str_nobpe, target_str)) best_idx = np.array(rewards).argmax() if args.reward_check: best_rank_list.append(best_idx) if args.save_path: if args.output_all: for hypo_i in range(len(hypo_strs)): outf.write("{} | {:.4f} | {}\n".format( sample_id, rewards[hypo_i], hypo_strs[hypo_i])) else: outf.write("{} | {}\n".format( sample_id, hypo_strs[best_idx])) else: if args.output_all: for hypo_i in range(len(hypo_strs)): print("{} | {:.4f} | {}".format( sample_id, rewards[hypo_i], hypo_strs[hypo_i])) else: print("{} | {}".format(sample_id, hypo_strs[best_idx])) sys.stdout.flush() elif args.reward == "bleurt": hypo_target_pairs.append( (sample_id, target_str, hypo_strs)) else: # Normal translation # Process top predictions for j, hypo in enumerate(hypos[i][:args.nbest]): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo['tokens'].int().cpu(), src_str=src_str, alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None, align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=args.remove_bpe, ) if not args.quiet: print('H-{}\t{}\t{}'.format( sample_id, hypo['score'], hypo_str)) print('P-{}\t{}'.format( sample_id, ' '.join( map( lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist(), )))) if args.print_alignment: print('A-{}\t{}'.format( sample_id, ' '.join( map(lambda x: str(utils.item(x)), alignment)))) # Score only the top hypothesis results.append( (sample_id, target_str, hypo_str, float(hypo["positional_scores"].mean()))) if has_target and j == 0 and not args.reward_sample: pass # if align_dict is not None or args.remove_bpe is not None: # Convert back to tokens for evaluation with unk replacement and/or without BPE # target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True) # if args.save_path: # outf.write("{} | {}\n".format(sample_id, hypo_str)) # if j == 0 and not args.no_eval: # results.append((sample_id, target_str, hypo_str)) # if hasattr(scorer, 'add_string'): # scorer.add_string(target_str, hypo_str) # else: # scorer.add(target_tokens, hypo_tokens) if args.save_amount > 0 and total_n > args.save_amount: break if args.reward_sample and bool(hypo_target_pairs): hypo_batch = [] target_batch = [] for _, target, hypo_strs in hypo_target_pairs: hypo_batch.extend( [h.replace("@@ ", "") for h in hypo_strs]) target_batch.extend([target_str] * len(hypo_strs)) rewards = np.array( bleurt_scorer.score(target_batch, hypo_batch)) base_i = 0 for sample_id, _, hypo_strs in hypo_target_pairs: start = base_i end = base_i + len(hypo_strs) 
best_idx = rewards[start:end].argmax() if args.save_path: if args.output_all: for idx in range(start, end): outf.write("{} | {:.4f} | {}\n".format( sample_id, float(rewards[idx]), hypo_strs[idx - start])) else: outf.write("{} | {}\n".format( sample_id, hypo_strs[best_idx])) else: if args.output_all: for idx in range(start, end): print("{} | {:.4f} | {}".format( sample_id, float(rewards[idx]), hypo_strs[idx - start])) else: print("{} | {}".format(sample_id, hypo_strs[best_idx])) sys.stdout.flush() base_i += len(hypo_strs) wps_meter.update(num_generated_tokens) t.log({'wps': round(wps_meter.avg)}) num_sentences += sample['nsentences'] print( '| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)' .format(num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg)) if args.save_path and not args.reward_check and not args.reward_sample: results.sort() for sample_id, tgt, hyp, score in results: outf.write("{}\t{}\t{}\n".format(sample_id, score, hyp)) print("results saved to", args.save_path) if args.reward_check: print("avg ranking of the best sample:", np.array(best_rank_list).mean()) print("ratio of best sample ranked in the top:", (np.array(best_rank_list) == 0).mean()) if has_target and not args.reward_sample and not args.reward_check and not args.no_eval: _, ref, out, _ = zip(*results) from fairseq.criterions.lib_sbleu import smoothed_bleu sbleu = np.mean( [smoothed_bleu(p[1].split(), p[2].split()) for p in results]) print("| SBLEU = {:.2f}".format(sbleu)) if args.eval_bleurt: bleurt_scores = bleurt_scorer.score( references=[p[1] for p in results], candidates=[p[2] for p in results]) print("| BLEURT = {:.4f}".format(np.mean( (np.array(bleurt_scores))))) print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.score(ref, out))) return scorer
def download_pretrained_model():
    resolved_archive_file = cached_path(HUGGINGFACE_MODEL)
    tempdir = tempfile.mkdtemp()
    with tarfile.open(resolved_archive_file, 'r:gz') as archive:
        archive.extractall(tempdir)
    return tempdir
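# Hedged usage sketch (assumption): loading the extracted checkpoint with the
# transformers classes typically used around this chatbot code. The
# OpenAIGPTLMHeadModel class is an assumption; the original repository may put
# a different head on top of the same weights.
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

checkpoint_dir = download_pretrained_model()
tokenizer = OpenAIGPTTokenizer.from_pretrained(checkpoint_dir)
model = OpenAIGPTLMHeadModel.from_pretrained(checkpoint_dir)
model.eval()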
def train(args):
    # Load tokenizer, pretrained model and optimizer
    logger.info("Prepare tokenizer, model and optimizer")
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased',
                                              do_lower_case=False)  # Let's use a pre-defined tokenizer

    logger.info("Create model from class %s and configuration %s",
                args.finetuning_model_class,
                os.path.join(args.model_checkpoint, CONFIG_NAME))
    ModelClass = getattr(importlib.import_module("finetuning_model"),
                         args.finetuning_model_class)
    pretraining_args = torch.load(cached_path(os.path.join(args.model_checkpoint, CONFIG_NAME)))
    model = ModelClass(config=pretraining_args, fine_tuning_config=args).to(args.device)

    logger.info("Load pretrained weights from %s",
                os.path.join(args.model_checkpoint, WEIGHTS_NAME))
    state_dict = torch.load(cached_path(os.path.join(args.model_checkpoint, WEIGHTS_NAME)),
                            map_location='cpu')
    incompatible_keys = model.load_state_dict(state_dict, strict=False)
    logger.info("Parameters discarded from the pretrained model: %s",
                incompatible_keys.unexpected_keys)
    logger.info("Parameters added in the adaptation model: %s",
                incompatible_keys.missing_keys)
    model.tie_weights()

    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    logger.info("Model has %s parameters",
                sum(p.numel() for p in model.parameters() if p.requires_grad))

    logger.info("Prepare datasets")
    loaders = get_data_loaders(args, tokenizer,
                               pretraining_args.num_max_positions,
                               clf_token=tokenizer.vocab['[CLS]'])
    train_loader, val_loader, train_sampler, valid_sampler = loaders

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch, labels = (t.to(args.device) for t in batch)
        inputs = batch.transpose(0, 1).contiguous()  # to shape [seq length, batch]
        _, (clf_loss, lm_loss) = model(inputs,
                                       clf_tokens_mask=(inputs == tokenizer.vocab['[CLS]']),
                                       clf_labels=labels,
                                       lm_labels=inputs,
                                       padding_mask=(batch == tokenizer.vocab['[PAD]']))
        loss = (max(0, args.clf_loss_coef) * clf_loss
                + max(0, args.lm_loss_coef) * lm_loss) / args.gradient_accumulation_steps
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch, labels = (t.to(args.device) for t in batch)
            inputs = batch.transpose(0, 1).contiguous()  # to shape [seq length, batch]
            _, clf_logits = model(inputs,
                                  clf_tokens_mask=(inputs == tokenizer.vocab['[CLS]']),
                                  padding_mask=(batch == tokenizer.vocab['[PAD]']))
        return clf_logits, labels
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate at the end of each epoch and every 'eval_every' iterations if needed
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_every > 0:
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  lambda engine: evaluator.run(val_loader)
                                  if engine.state.iteration % args.eval_every == 0 else None)
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))

    # Learning rate schedule: linearly warm up to lr, then decay to zero
    scheduler = PiecewiseLinear(optimizer, 'lr',
                                [(0, 0.0), (args.n_warmup, args.lr),
                                 (len(train_loader) * args.n_epochs, 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we average distributed metrics using average_distributed_scalar
    metrics = {"accuracy": Accuracy()}
    metrics.update({"average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints
    # and save model and configuration before we start to train
    if args.local_rank in [-1, 0]:
        checkpoint_handler, tb_logger = add_logging_and_checkpoint_saving(
            trainer, evaluator, metrics, model, optimizer, args, prefix="finetune_")

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint for easy re-loading
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))
        tb_logger.close()