def __init__(self, opt, bert_config=None, initial_from_local=False):
    """Build the shared transformer encoder and one output head per task.

    Args:
        opt: flat option/config dict. It must contain 'encoder_type' plus the
            encoder config fields consumed by ``config_class.from_dict``.
            Unless 'dump_feature' is truthy it must also contain
            'update_bert_opt', 'task_def_list', 'answer_opt', 'dropout_p',
            'vb_dropout' and, when ``initial_from_local`` is False,
            'init_checkpoint'.
        bert_config: not used by this constructor; kept for call
            compatibility.
        initial_from_local: when False, ``self.bert`` is replaced by
            ``model_class.from_pretrained(opt['init_checkpoint'], ...)``
            after all heads are built and ``self._my_init()`` has run.

    Raises:
        ValueError: if ``opt['encoder_type']`` is not a value of
            ``EncoderModelType``.
    """
    super(SANBertNetwork, self).__init__()
    # Created up front so the attribute exists even on the early
    # 'dump_feature' return below.
    self.dropout_list = nn.ModuleList()
    if opt['encoder_type'] not in EncoderModelType._value2member_map_:
        raise ValueError("encoder_type is out of pre-defined types")
    self.encoder_type = opt['encoder_type']
    literal_encoder_type = EncoderModelType(self.encoder_type).name.lower()
    config_class, model_class, _ = MODEL_CLASSES[literal_encoder_type]
    self.preloaded_config = config_class.from_dict(opt)  # load config from opt
    self.preloaded_config.output_hidden_states = True  # return all hidden states
    self.bert = model_class(self.preloaded_config)
    hidden_size = self.bert.config.hidden_size

    if opt.get('dump_feature', False):
        # Feature-dump mode: encoder only, no task heads, no weight reload.
        self.opt = opt
        return

    freeze_encoder = opt['update_bert_opt'] > 0
    if freeze_encoder:
        for p in self.bert.parameters():
            p.requires_grad = False

    task_def_list = opt['task_def_list']
    self.task_def_list = task_def_list
    self.decoder_opt = [
        generate_decoder_opt(task_def.enable_san, opt['answer_opt'])
        for task_def in task_def_list
    ]
    self.task_types = [task_def.task_type for task_def in task_def_list]

    # create output header
    self.scoring_list = nn.ModuleList()
    for task_id in range(len(task_def_list)):
        task_def: TaskDef = task_def_list[task_id]
        lab = task_def.n_class
        decoder_opt = self.decoder_opt[task_id]
        task_type = self.task_types[task_id]
        # A per-task dropout_p overrides the global opt['dropout_p'].
        task_dropout_p = (opt['dropout_p']
                          if task_def.dropout_p is None else task_def.dropout_p)
        dropout = DropoutWrapper(task_dropout_p, opt['vb_dropout'])
        self.dropout_list.append(dropout)
        task_obj = tasks.get_task_obj(task_def)
        if task_obj is not None:
            out_proj = task_obj.train_build_task_layer(decoder_opt,
                                                       hidden_size,
                                                       lab,
                                                       opt,
                                                       prefix='answer',
                                                       dropout=dropout)
        elif task_type == TaskType.Span:
            assert decoder_opt != 1
            out_proj = nn.Linear(hidden_size, 2)
        elif task_type == TaskType.SeqenceLabeling:
            out_proj = nn.Linear(hidden_size, lab)
        elif task_type == TaskType.MaskLM:
            # TODO: xiaodl -- RoBERTa previously had its own (identical)
            # branch here; all encoders currently share this header.
            # NOTE(review): the header is tied to the *current* encoder's
            # embedding weight; if self.bert is replaced by from_pretrained
            # below, the header presumably keeps the old tensor -- confirm
            # against MaskLmHeader.
            out_proj = MaskLmHeader(
                self.bert.embeddings.word_embeddings.weight)
        elif decoder_opt == 1:
            out_proj = SANClassifier(hidden_size,
                                     hidden_size,
                                     lab,
                                     opt,
                                     prefix='answer',
                                     dropout=dropout)
        else:
            out_proj = nn.Linear(hidden_size, lab)
        self.scoring_list.append(out_proj)

    self.opt = opt
    self._my_init()

    # If not loading from local, load the pre-trained encoder weights after
    # initialization (presumably so _my_init() does not overwrite them).
    if not initial_from_local:
        self.bert = model_class.from_pretrained(
            opt['init_checkpoint'], config=self.preloaded_config)
        if freeze_encoder:
            # from_pretrained returns a brand-new module whose parameters have
            # requires_grad=True, so the freeze applied above does not carry
            # over to it; re-apply it.
            for p in self.bert.parameters():
                p.requires_grad = False
def main():
    """Run MT-DNN multi-task training, optionally under torch.distributed.

    Takes no parameters: configuration comes from module-level names
    (``args``, ``data_dir``, ``output_dir``, ``encoder_type``, ``task_defs``,
    ``logger``). The function:

    - builds one training dataset per distinct task prefix in
      ``args.train_datasets`` and dev/test loaders for ``args.test_datasets``;
    - constructs an ``MTDNNModel`` and writes ``config.json`` to
      ``output_dir``;
    - then either saves encodings (``args.encode_mode``) or trains for
      ``args.epochs`` epochs, evaluating dev/test after every epoch.

    Per-epoch checkpoints are written only by the non-distributed process or
    rank 0.
    """
    # set up dist
    # NOTE(review): this initial value is always overwritten by one of the
    # branches below.
    device = torch.device("cuda")
    if args.local_rank > -1:
        device = initialize_distributed(args)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    opt = vars(args)
    # update data dir
    opt['data_dir'] = data_dir
    batch_size = args.batch_size  # NOTE(review): not used below
    print_message(logger, 'Launching the MT-DNN training')
    tasks = {}  # task prefix -> task id, in first-seen order
    task_def_list = []
    dropout_list = []  # NOTE(review): not used below
    # True for the non-distributed process and for rank 0; passed to the
    # datasets (presumably to limit logging to one process).
    printable = args.local_rank in [-1, 0]

    train_datasets = []
    for dataset in args.train_datasets:
        # The task is identified by the text before the first '_'; only the
        # first dataset seen for a given prefix is loaded.
        prefix = dataset.split('_')[0]
        if prefix in tasks:
            continue
        task_id = len(tasks)
        tasks[prefix] = task_id
        task_def = task_defs.get_task_def(prefix)
        task_def_list.append(task_def)
        train_path = os.path.join(data_dir, '{}_train.json'.format(dataset))
        print_message(logger,
                      'Loading {} as task {}'.format(train_path, task_id))
        train_data_set = SingleTaskDataset(train_path,
                                           True,
                                           maxlen=args.max_seq_len,
                                           task_id=task_id,
                                           task_def=task_def,
                                           printable=printable)
        train_datasets.append(train_data_set)
    train_collater = Collater(dropout_w=args.dropout_w,
                              encoder_type=encoder_type,
                              soft_label=args.mkd_opt > 0,
                              max_seq_len=args.max_seq_len,
                              do_padding=args.do_padding)
    multi_task_train_dataset = MultiTaskDataset(train_datasets)
    # Distributed runs use a rank-aware sampler; otherwise the binning options
    # are forwarded to the single-process sampler.
    if args.local_rank != -1:
        multi_task_batch_sampler = DistMultiTaskBatchSampler(
            train_datasets,
            args.batch_size,
            args.mix_opt,
            args.ratio,
            rank=args.local_rank,
            world_size=args.world_size)
    else:
        multi_task_batch_sampler = MultiTaskBatchSampler(
            train_datasets,
            args.batch_size,
            args.mix_opt,
            args.ratio,
            bin_on=args.bin_on,
            bin_size=args.bin_size,
            bin_grow_ratio=args.bin_grow_ratio)
    multi_task_train_data = DataLoader(multi_task_train_dataset,
                                       batch_sampler=multi_task_batch_sampler,
                                       collate_fn=train_collater.collate_fn,
                                       pin_memory=args.cuda)

    opt['task_def_list'] = task_def_list

    # One entry per test dataset in each list; None when the file is absent.
    dev_data_list = []
    test_data_list = []
    test_collater = Collater(is_train=False,
                             encoder_type=encoder_type,
                             max_seq_len=args.max_seq_len,
                             do_padding=args.do_padding)
    for dataset in args.test_datasets:
        prefix = dataset.split('_')[0]
        task_def = task_defs.get_task_def(prefix)
        # NOTE(review): raises KeyError if this prefix was not among the
        # training datasets.
        task_id = tasks[prefix]
        task_type = task_def.task_type  # NOTE(review): not used below
        data_type = task_def.data_type  # NOTE(review): not used below

        dev_path = os.path.join(data_dir, '{}_dev.json'.format(dataset))
        dev_data = None
        if os.path.exists(dev_path):
            dev_data_set = SingleTaskDataset(dev_path,
                                             False,
                                             maxlen=args.max_seq_len,
                                             task_id=task_id,
                                             task_def=task_def,
                                             printable=printable)
            if args.local_rank != -1:
                dev_data_set = DistTaskDataset(dev_data_set, task_id)
                single_task_batch_sampler = DistSingleTaskBatchSampler(
                    dev_data_set,
                    args.batch_size_eval,
                    rank=args.local_rank,
                    world_size=args.world_size)
                dev_data = DataLoader(dev_data_set,
                                      batch_sampler=single_task_batch_sampler,
                                      collate_fn=test_collater.collate_fn,
                                      pin_memory=args.cuda)
            else:
                dev_data = DataLoader(dev_data_set,
                                      batch_size=args.batch_size_eval,
                                      collate_fn=test_collater.collate_fn,
                                      pin_memory=args.cuda)
        dev_data_list.append(dev_data)

        test_path = os.path.join(data_dir, '{}_test.json'.format(dataset))
        test_data = None
        if os.path.exists(test_path):
            test_data_set = SingleTaskDataset(test_path,
                                              False,
                                              maxlen=args.max_seq_len,
                                              task_id=task_id,
                                              task_def=task_def,
                                              printable=printable)
            if args.local_rank != -1:
                test_data_set = DistTaskDataset(test_data_set, task_id)
                single_task_batch_sampler = DistSingleTaskBatchSampler(
                    test_data_set,
                    args.batch_size_eval,
                    rank=args.local_rank,
                    world_size=args.world_size)
                test_data = DataLoader(test_data_set,
                                       batch_sampler=single_task_batch_sampler,
                                       collate_fn=test_collater.collate_fn,
                                       pin_memory=args.cuda)
            else:
                test_data = DataLoader(test_data_set,
                                       batch_size=args.batch_size_eval,
                                       collate_fn=test_collater.collate_fn,
                                       pin_memory=args.cuda)
        test_data_list.append(test_data)

    print_message(logger, '#' * 20)
    print_message(logger, opt)
    print_message(logger, '#' * 20)

    # div number of grad accumulation.
    # Total optimizer steps over all epochs (integer division).
    num_all_batches = args.epochs * len(
        multi_task_train_data) // args.grad_accumulation_step
    print_message(logger,
                  '############# Gradient Accumulation Info #############')
    print_message(
        logger,
        'number of step: {}'.format(args.epochs * len(multi_task_train_data)))
    print_message(
        logger, 'number of grad grad_accumulation step: {}'.format(
            args.grad_accumulation_step))
    print_message(logger,
                  'adjusted number of step: {}'.format(num_all_batches))
    print_message(logger,
                  '############# Gradient Accumulation Info #############')

    init_model = args.init_checkpoint
    state_dict = None

    # An existing path is treated as a local checkpoint; otherwise init_model
    # is passed to config_class.from_pretrained.
    # NOTE(review): if the path exists but encoder_type matches none of the
    # branches below, `config` is never bound and the
    # config[...] assignments after this block raise NameError.
    if os.path.exists(init_model):
        if encoder_type == EncoderModelType.BERT or \
            encoder_type == EncoderModelType.DEBERTA or \
            encoder_type == EncoderModelType.ELECTRA:
            state_dict = torch.load(init_model, map_location=device)
            config = state_dict['config']
        elif encoder_type == EncoderModelType.ROBERTA or encoder_type == EncoderModelType.XLM:
            # init_model is a directory holding model.pt, whose 'args' entry
            # supplies the arch name and layer count.
            model_path = '{}/model.pt'.format(init_model)
            state_dict = torch.load(model_path, map_location=device)
            arch = state_dict['args'].arch
            arch = arch.replace('_', '-')
            if encoder_type == EncoderModelType.XLM:
                arch = "xlm-{}".format(arch)
            # convert model arch
            from data_utils.roberta_utils import update_roberta_keys
            from data_utils.roberta_utils import patch_name_dict
            state = update_roberta_keys(
                state_dict['model'], nlayer=state_dict['args'].encoder_layers)
            state = patch_name_dict(state)
            literal_encoder_type = EncoderModelType(
                opt['encoder_type']).name.lower()
            config_class, model_class, tokenizer_class = MODEL_CLASSES[
                literal_encoder_type]
            config = config_class.from_pretrained(arch).to_dict()
            state_dict = {'state': state}
    else:
        if opt['encoder_type'] not in EncoderModelType._value2member_map_:
            raise ValueError("encoder_type is out of pre-defined types")
        literal_encoder_type = EncoderModelType(
            opt['encoder_type']).name.lower()
        config_class, model_class, tokenizer_class = MODEL_CLASSES[
            literal_encoder_type]
        config = config_class.from_pretrained(init_model).to_dict()

    # Command-line dropout / layer-count settings override the loaded config.
    config['attention_probs_dropout_prob'] = args.bert_dropout_p
    config['hidden_dropout_prob'] = args.bert_dropout_p
    config['multi_gpu_on'] = opt["multi_gpu_on"]
    if args.num_hidden_layers > 0:
        config['num_hidden_layers'] = args.num_hidden_layers

    # Merge the encoder config into opt; config keys overwrite opt keys.
    opt.update(config)

    model = MTDNNModel(opt,
                       device=device,
                       state_dict=state_dict,
                       num_train_step=num_all_batches)
    if args.resume and args.model_ckpt:
        print_message(logger, 'loading model from {}'.format(args.model_ckpt))
        model.load(args.model_ckpt)

    #### model meta str
    headline = '############# Model Arch of MT-DNN #############'
    ### print network
    print_message(logger, '\n{}\n{}\n'.format(headline, model.network))

    # dump config (JSON options followed by the printed network)
    config_file = os.path.join(output_dir, 'config.json')
    with open(config_file, 'w', encoding='utf-8') as writer:
        writer.write('{}\n'.format(json.dumps(opt)))
        writer.write('\n{}\n{}\n'.format(headline, model.network))

    print_message(logger,
                  "Total number of params: {}".format(model.total_param))

    # tensorboard
    tensorboard = None
    if args.tensorboard:
        args.tensorboard_logdir = os.path.join(args.output_dir,
                                               args.tensorboard_logdir)
        tensorboard = SummaryWriter(log_dir=args.tensorboard_logdir)

    # Encode-only mode: save encodings for each test dataset and exit.
    if args.encode_mode:
        for idx, dataset in enumerate(args.test_datasets):
            prefix = dataset.split('_')[0]
            test_data = test_data_list[idx]
            with torch.no_grad():
                encoding = extract_encoding(model,
                                            test_data,
                                            use_cuda=args.cuda)
            torch.save(
                encoding,
                os.path.join(output_dir, '{}_encoding.pt'.format(dataset)))
        return

    for epoch in range(0, args.epochs):
        print_message(logger, 'At epoch {}'.format(epoch), level=1)
        start = datetime.now()

        for i, (batch_meta, batch_data) in enumerate(multi_task_train_data):
            batch_meta, batch_data = Collater.patch_data(
                device, batch_meta, batch_data)
            task_id = batch_meta['task_id']
            model.update(batch_meta, batch_data)

            # NOTE(review): if model.updates only advances once per
            # grad_accumulation_step batches, this can fire on several
            # consecutive batches; the save check below uses
            # model.local_updates instead.
            if (model.updates) % (
                    args.log_per_updates) == 0 or model.updates == 1:
                # ETA for the rest of this epoch: mean time per batch so far
                # times batches left, with fractional seconds dropped.
                ramaining_time = str(
                    (datetime.now() - start) / (i + 1) *
                    (len(multi_task_train_data) - i - 1)).split('.')[0]
                if args.adv_train and args.debug:
                    debug_info = ' basic loss[%.5f] adv loss[%.5f] emb val[%.8f] noise val[%.8f] noise grad val[%.8f] no proj noise[%.8f] ' % (
                        model.basic_loss.avg, model.adv_loss.avg,
                        model.emb_val.avg, model.noise_val.avg,
                        model.noise_grad_val.avg, model.no_proj_noise_val.avg)
                else:
                    debug_info = ' '
                print_message(
                    logger,
                    'Task [{0:2}] updates[{1:6}] train loss[{2:.5f}]{3}remaining[{4}]'
                    .format(task_id, model.updates, model.train_loss.avg,
                            debug_info, ramaining_time))
                if args.tensorboard:
                    tensorboard.add_scalar('train/loss',
                                           model.train_loss.avg,
                                           global_step=model.updates)

            # Mid-epoch evaluation + checkpoint, non-distributed / rank 0 only.
            if args.save_per_updates_on and (
                (model.local_updates) %
                (args.save_per_updates * args.grad_accumulation_step)
                    == 0) and args.local_rank in [-1, 0]:
                model_file = os.path.join(
                    output_dir, 'model_{}_{}.pt'.format(epoch, model.updates))
                evaluation(model,
                           args.test_datasets,
                           dev_data_list,
                           task_defs,
                           output_dir,
                           epoch,
                           n_updates=args.save_per_updates,
                           with_label=True,
                           tensorboard=tensorboard,
                           glue_format_on=args.glue_format_on,
                           test_on=False,
                           device=device,
                           logger=logger)
                evaluation(model,
                           args.test_datasets,
                           test_data_list,
                           task_defs,
                           output_dir,
                           epoch,
                           n_updates=args.save_per_updates,
                           with_label=False,
                           tensorboard=tensorboard,
                           glue_format_on=args.glue_format_on,
                           test_on=True,
                           device=device,
                           logger=logger)
                print_message(logger,
                              'Saving mt-dnn model to {}'.format(model_file))
                model.save(model_file)

        # End-of-epoch evaluation runs on every process; only the
        # non-distributed process or rank 0 saves the checkpoint.
        evaluation(model,
                   args.test_datasets,
                   dev_data_list,
                   task_defs,
                   output_dir,
                   epoch,
                   with_label=True,
                   tensorboard=tensorboard,
                   glue_format_on=args.glue_format_on,
                   test_on=False,
                   device=device,
                   logger=logger)
        evaluation(model,
                   args.test_datasets,
                   test_data_list,
                   task_defs,
                   output_dir,
                   epoch,
                   with_label=False,
                   tensorboard=tensorboard,
                   glue_format_on=args.glue_format_on,
                   test_on=True,
                   device=device,
                   logger=logger)
        print_message(logger, '[new test scores at {} saved.]'.format(epoch))
        if args.local_rank in [-1, 0]:
            model_file = os.path.join(output_dir, 'model_{}.pt'.format(epoch))
            model.save(model_file)
    if args.tensorboard:
        tensorboard.close()
def main():
    """Run single-process MT-DNN multi-task training.

    Takes no parameters: configuration comes from module-level names
    (``args``, ``data_dir``, ``output_dir``, ``encoder_type``, ``task_defs``,
    ``logger``). The function:

    - builds one training dataset per distinct task prefix in
      ``args.train_datasets`` and dev/test loaders for ``args.test_datasets``;
    - constructs an ``MTDNNModel`` and writes ``config.json`` to
      ``output_dir``;
    - then either saves encodings (``args.encode_mode``) or trains for
      ``args.epochs`` epochs.

    After each epoch it writes dev/test score files (JSON, plus TSV when
    ``args.glue_format_on``) and saves a checkpoint.
    """
    logger.info('Launching the MT-DNN training')
    opt = vars(args)
    # update data dir
    opt['data_dir'] = data_dir
    batch_size = args.batch_size  # NOTE(review): not used below
    tasks = {}  # task prefix -> task id, in first-seen order
    task_def_list = []
    dropout_list = []  # NOTE(review): not used below

    train_datasets = []
    for dataset in args.train_datasets:
        # The task is identified by the text before the first '_'; only the
        # first dataset seen for a given prefix is loaded.
        prefix = dataset.split('_')[0]
        if prefix in tasks:
            continue
        task_id = len(tasks)
        tasks[prefix] = task_id
        task_def = task_defs.get_task_def(prefix)
        task_def_list.append(task_def)
        train_path = os.path.join(data_dir, '{}_train.json'.format(dataset))
        logger.info('Loading {} as task {}'.format(train_path, task_id))
        train_data_set = SingleTaskDataset(train_path,
                                           True,
                                           maxlen=args.max_seq_len,
                                           task_id=task_id,
                                           task_def=task_def)
        train_datasets.append(train_data_set)
    train_collater = Collater(dropout_w=args.dropout_w,
                              encoder_type=encoder_type,
                              soft_label=args.mkd_opt > 0)
    multi_task_train_dataset = MultiTaskDataset(train_datasets)
    multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets,
                                                     args.batch_size,
                                                     args.mix_opt, args.ratio)
    multi_task_train_data = DataLoader(multi_task_train_dataset,
                                       batch_sampler=multi_task_batch_sampler,
                                       collate_fn=train_collater.collate_fn,
                                       pin_memory=args.cuda)

    opt['task_def_list'] = task_def_list

    # One entry per test dataset in each list; None when the file is absent.
    dev_data_list = []
    test_data_list = []
    test_collater = Collater(is_train=False, encoder_type=encoder_type)
    for dataset in args.test_datasets:
        prefix = dataset.split('_')[0]
        task_def = task_defs.get_task_def(prefix)
        # NOTE(review): raises KeyError if this prefix was not among the
        # training datasets.
        task_id = tasks[prefix]
        task_type = task_def.task_type  # NOTE(review): not used below
        data_type = task_def.data_type  # NOTE(review): not used below

        dev_path = os.path.join(data_dir, '{}_dev.json'.format(dataset))
        dev_data = None
        if os.path.exists(dev_path):
            dev_data_set = SingleTaskDataset(dev_path,
                                             False,
                                             maxlen=args.max_seq_len,
                                             task_id=task_id,
                                             task_def=task_def)
            dev_data = DataLoader(dev_data_set,
                                  batch_size=args.batch_size_eval,
                                  collate_fn=test_collater.collate_fn,
                                  pin_memory=args.cuda)
        dev_data_list.append(dev_data)

        test_path = os.path.join(data_dir, '{}_test.json'.format(dataset))
        test_data = None
        if os.path.exists(test_path):
            test_data_set = SingleTaskDataset(test_path,
                                              False,
                                              maxlen=args.max_seq_len,
                                              task_id=task_id,
                                              task_def=task_def)
            test_data = DataLoader(test_data_set,
                                   batch_size=args.batch_size_eval,
                                   collate_fn=test_collater.collate_fn,
                                   pin_memory=args.cuda)
        test_data_list.append(test_data)

    logger.info('#' * 20)
    logger.info(opt)
    logger.info('#' * 20)

    # div number of grad accumulation.
    # Total optimizer steps over all epochs (integer division).
    num_all_batches = args.epochs * len(
        multi_task_train_data) // args.grad_accumulation_step
    logger.info('############# Gradient Accumulation Info #############')
    logger.info('number of step: {}'.format(args.epochs *
                                            len(multi_task_train_data)))
    logger.info('number of grad grad_accumulation step: {}'.format(
        args.grad_accumulation_step))
    logger.info('adjusted number of step: {}'.format(num_all_batches))
    logger.info('############# Gradient Accumulation Info #############')

    init_model = args.init_checkpoint
    state_dict = None

    # An existing path is treated as a local checkpoint; otherwise init_model
    # is passed to config_class.from_pretrained.
    if os.path.exists(init_model):
        # NOTE(review): no map_location, so a checkpoint saved from GPU
        # tensors cannot be loaded on a CPU-only host.
        state_dict = torch.load(init_model)
        config = state_dict['config']
    else:
        if opt['encoder_type'] not in EncoderModelType._value2member_map_:
            raise ValueError("encoder_type is out of pre-defined types")
        literal_encoder_type = EncoderModelType(
            opt['encoder_type']).name.lower()
        config_class, model_class, tokenizer_class = MODEL_CLASSES[
            literal_encoder_type]
        config = config_class.from_pretrained(
            init_model, output_hidden_states=True).to_dict(
            )  # change here to enable multi-layer output
        # NOTE(review): nesting reconstructed from a flattened source --
        # confirm whether this line should also apply to the
        # loaded-checkpoint branch above.
        config['output_hidden_states'] = True

    # Command-line dropout / layer-count settings override the loaded config.
    config['attention_probs_dropout_prob'] = args.bert_dropout_p
    config['hidden_dropout_prob'] = args.bert_dropout_p
    config['multi_gpu_on'] = opt["multi_gpu_on"]
    if args.num_hidden_layers != -1:
        config['num_hidden_layers'] = args.num_hidden_layers

    # Merge the encoder config into opt; config keys overwrite opt keys.
    opt.update(config)

    model = MTDNNModel(opt,
                       state_dict=state_dict,
                       num_train_step=num_all_batches)
    if args.resume and args.model_ckpt:
        logger.info('loading model from {}'.format(args.model_ckpt))
        model.load(args.model_ckpt)

    #### model meta str
    headline = '############# Model Arch of MT-DNN #############'
    ### print network
    logger.info('\n{}\n{}\n'.format(headline, model.network))

    # dump config (JSON options followed by the printed network)
    config_file = os.path.join(output_dir, 'config.json')
    with open(config_file, 'w', encoding='utf-8') as writer:
        writer.write('{}\n'.format(json.dumps(opt)))
        writer.write('\n{}\n{}\n'.format(headline, model.network))

    logger.info("Total number of params: {}".format(model.total_param))

    # tensorboard
    # NOTE(review): `tensorboard` is bound only when args.tensorboard is set;
    # every later use is guarded by the same flag.
    if args.tensorboard:
        args.tensorboard_logdir = os.path.join(args.output_dir,
                                               args.tensorboard_logdir)
        tensorboard = SummaryWriter(log_dir=args.tensorboard_logdir)

    # Encode-only mode: save encodings for each test dataset and exit.
    if args.encode_mode:
        for idx, dataset in enumerate(args.test_datasets):
            prefix = dataset.split('_')[0]
            test_data = test_data_list[idx]
            with torch.no_grad():
                encoding = extract_encoding(model,
                                            test_data,
                                            use_cuda=args.cuda)
            torch.save(
                encoding,
                os.path.join(output_dir, '{}_encoding.pt'.format(dataset)))
        return

    for epoch in range(0, args.epochs):
        logger.warning('At epoch {}'.format(epoch))
        start = datetime.now()

        for i, (batch_meta, batch_data) in enumerate(multi_task_train_data):
            batch_meta, batch_data = Collater.patch_data(
                args.cuda, batch_meta, batch_data)
            task_id = batch_meta['task_id']
            model.update(batch_meta, batch_data)
            # Log every log_per_updates * grad_accumulation_step local
            # batches, and on the very first one.
            if (model.local_updates) % (args.log_per_updates *
                                        args.grad_accumulation_step
                                        ) == 0 or model.local_updates == 1:
                # ETA for the rest of this epoch: mean time per batch so far
                # times batches left, with fractional seconds dropped.
                ramaining_time = str(
                    (datetime.now() - start) / (i + 1) *
                    (len(multi_task_train_data) - i - 1)).split('.')[0]
                logger.info(
                    'Task [{0:2}] updates[{1:6}] train loss[{2:.5f}] remaining[{3}]'
                    .format(task_id, model.updates, model.train_loss.avg,
                            ramaining_time))
                if args.tensorboard:
                    tensorboard.add_scalar('train/loss',
                                           model.train_loss.avg,
                                           global_step=model.updates)

            if args.save_per_updates_on and (
                (model.local_updates) %
                (args.save_per_updates * args.grad_accumulation_step) == 0):
                model_file = os.path.join(
                    output_dir, 'model_{}_{}.pt'.format(epoch, model.updates))
                logger.info('Saving mt-dnn model to {}'.format(model_file))
                model.save(model_file)

        # End-of-epoch evaluation; dev_data_list / test_data_list are indexed
        # in the same order as args.test_datasets.
        for idx, dataset in enumerate(args.test_datasets):
            prefix = dataset.split('_')[0]
            task_def = task_defs.get_task_def(prefix)
            label_dict = task_def.label_vocab
            dev_data = dev_data_list[idx]
            if dev_data is not None:
                with torch.no_grad():
                    dev_metrics, dev_predictions, scores, golds, dev_ids = eval_model(
                        model,
                        dev_data,
                        metric_meta=task_def.metric_meta,
                        use_cuda=args.cuda,
                        label_mapper=label_dict,
                        task_type=task_def.task_type)
                for key, val in dev_metrics.items():
                    if args.tensorboard:
                        tensorboard.add_scalar('dev/{}/{}'.format(
                            dataset, key),
                                               val,
                                               global_step=epoch)
                    # String-valued metrics are logged verbatim; numeric ones
                    # with 3 decimals.
                    if isinstance(val, str):
                        logger.warning(
                            'Task {0} -- epoch {1} -- Dev {2}:\n {3}'.format(
                                dataset, epoch, key, val))
                    else:
                        logger.warning(
                            'Task {0} -- epoch {1} -- Dev {2}: {3:.3f}'.format(
                                dataset, epoch, key, val))
                score_file = os.path.join(
                    output_dir, '{}_dev_scores_{}.json'.format(dataset, epoch))
                results = {
                    'metrics': dev_metrics,
                    'predictions': dev_predictions,
                    'uids': dev_ids,
                    'scores': scores
                }
                dump(score_file, results)
                if args.glue_format_on:
                    from experiments.glue.glue_utils import submit
                    official_score_file = os.path.join(
                        output_dir,
                        '{}_dev_scores_{}.tsv'.format(dataset, epoch))
                    submit(official_score_file, results, label_dict)

            # test eval
            test_data = test_data_list[idx]
            if test_data is not None:
                with torch.no_grad():
                    test_metrics, test_predictions, scores, golds, test_ids = eval_model(
                        model,
                        test_data,
                        metric_meta=task_def.metric_meta,
                        use_cuda=args.cuda,
                        with_label=False,
                        label_mapper=label_dict,
                        task_type=task_def.task_type)
                score_file = os.path.join(
                    output_dir, '{}_test_scores_{}.json'.format(dataset, epoch))
                results = {
                    'metrics': test_metrics,
                    'predictions': test_predictions,
                    'uids': test_ids,
                    'scores': scores
                }
                dump(score_file, results)
                if args.glue_format_on:
                    from experiments.glue.glue_utils import submit
                    official_score_file = os.path.join(
                        output_dir,
                        '{}_test_scores_{}.tsv'.format(dataset, epoch))
                    submit(official_score_file, results, label_dict)
                logger.info('[new test scores saved.]')

        model_file = os.path.join(output_dir, 'model_{}.pt'.format(epoch))
        model.save(model_file)
    if args.tensorboard:
        tensorboard.close()
def __init__(self, opt, bert_config=None, initial_from_local=False):
    """Create the transformer encoder plus one scoring head per task.

    Args:
        opt: flat option dict; stored as ``self.config``. 'encoder_type'
            selects the encoder family via ``MODEL_CLASSES``.
        bert_config: not referenced by this constructor.
        initial_from_local: when False, encoder weights come from
            ``model_class.from_pretrained(opt["init_checkpoint"], ...)``;
            when True, the encoder is built from the config fields in
            ``opt`` with hidden-state output enabled.

    Raises:
        ValueError: if ``opt["encoder_type"]`` is not an ``EncoderModelType``
            value.
    """
    super(SANBertNetwork, self).__init__()
    # Registered before any early return so the attribute always exists.
    self.dropout_list = nn.ModuleList()

    encoder_kind = opt["encoder_type"]
    if encoder_kind not in EncoderModelType._value2member_map_:
        raise ValueError("encoder_type is out of pre-defined types")
    self.encoder_type = encoder_kind
    self.preloaded_config = None

    encoder_name = EncoderModelType(self.encoder_type).name.lower()
    config_cls, encoder_cls, _ = MODEL_CLASSES[encoder_name]
    if initial_from_local:
        # Build the encoder from the config fields carried in opt.
        self.preloaded_config = config_cls.from_dict(opt)
        self.preloaded_config.output_hidden_states = True  # return all hidden states
        self.bert = encoder_cls(self.preloaded_config)
    else:
        self.bert = encoder_cls.from_pretrained(
            opt["init_checkpoint"], cache_dir=opt["transformer_cache"])

    hidden_size = self.bert.config.hidden_size

    if opt.get("dump_feature", False):
        # Feature-dump mode: encoder only, no task heads.
        self.config = opt
        return

    if opt["update_bert_opt"] > 0:
        for param in self.bert.parameters():
            param.requires_grad = False

    task_def_list = opt["task_def_list"]
    self.task_def_list = task_def_list
    self.decoder_opt = [
        generate_decoder_opt(td.enable_san, opt["answer_opt"])
        for td in task_def_list
    ]
    self.task_types = [td.task_type for td in task_def_list]

    # create output header
    self.scoring_list = nn.ModuleList()
    per_task = zip(task_def_list, self.decoder_opt, self.task_types)
    for task_def, decoder_opt, task_type in per_task:
        n_class = task_def.n_class
        # A per-task dropout_p overrides the global opt["dropout_p"].
        if task_def.dropout_p is None:
            head_dropout_p = opt["dropout_p"]
        else:
            head_dropout_p = task_def.dropout_p
        dropout = DropoutWrapper(head_dropout_p, opt["vb_dropout"])
        self.dropout_list.append(dropout)

        task_obj = tasks.get_task_obj(task_def)
        if task_obj is not None:
            # TODO: move the pooler onto task_obj. Each task that reaches this
            # branch rebuilds (and overwrites) the single shared self.pooler.
            self.pooler = Pooler(hidden_size,
                                 dropout_p=opt["dropout_p"],
                                 actf=opt["pooler_actf"])
            head = task_obj.train_build_task_layer(decoder_opt,
                                                   hidden_size,
                                                   n_class,
                                                   opt,
                                                   prefix="answer",
                                                   dropout=dropout)
        elif task_type in (TaskType.Span, TaskType.SpanYN):
            assert decoder_opt != 1
            head = nn.Linear(hidden_size, 2)
        elif task_type == TaskType.SeqenceLabeling:
            head = nn.Linear(hidden_size, n_class)
        elif task_type == TaskType.SeqenceGeneration:
            # Generation keeps the encoder's original header; no extra head.
            head = None
        elif task_type == TaskType.ClozeChoice:
            self.pooler = Pooler(hidden_size,
                                 dropout_p=opt["dropout_p"],
                                 actf=opt["pooler_actf"])
            head = nn.Linear(hidden_size, n_class)
        elif decoder_opt == 1:
            head = SANClassifier(
                hidden_size,
                hidden_size,
                n_class,
                opt,
                prefix="answer",
                dropout=dropout,
            )
        else:
            head = nn.Linear(hidden_size, n_class)
        self.scoring_list.append(head)
    self.config = opt