def __init__(self, opt, emb_matrix=None):
    """Build the student network: an instance encoder plus attention/output layers.

    Args:
        opt: configuration dict. Keys read here: 'base_mode', and depending on
            the mode 'save_dir' / 'base_id'; always 'pe_dim_attn',
            'ner_dim_attn', 'hidden_dim', 'num_class'. The whole dict is also
            passed to MultiAspectAttention and (mode 0) BaseNetwork.
        emb_matrix: forwarded to BaseNetwork; only used when base_mode == 0
            (ignored when a saved base model is loaded).

    base_mode selects how ``self.inst_encoder`` is obtained:
        0 -- build a fresh BaseNetwork (train from scratch);
        1 -- load '<save_dir>/<base_id>/best_model.pt' and fine-tune it;
        2 -- load the same checkpoint and freeze it (requires_grad=False, eval()).

    Raises:
        ValueError: if opt['base_mode'] is not 0, 1 or 2.
    """
    print(">> Current Model: StudentNetwork")
    super(StudentNetwork, self).__init__()

    base_mode = opt['base_mode']
    if base_mode == 0:  # start from scratch
        print(" Do not use base model.")
        self.inst_encoder = BaseNetwork(opt=opt, emb_matrix=emb_matrix)
    elif base_mode in (1, 2):
        self.base_model_file = opt['save_dir'] + '/' + opt['base_id'] + '/best_model.pt'
        # The base model is rebuilt from the config stored alongside its checkpoint,
        # not from `opt`.
        self.base_opt = torch_utils.load_config(self.base_model_file)
        if base_mode == 1:  # load & fine tune
            print(" Fine-tune base model.")
        else:  # load & fix pre-trained
            print(" Fix pre-trained base model.")
        inst_base_model = BaseModel(self.base_opt)
        inst_base_model.load(self.base_model_file)
        encoder = inst_base_model.model
        if base_mode == 2:
            for param in encoder.parameters():
                param.requires_grad = False
            # NOTE(review): a later `self.train()` on this module recurses into
            # children and switches the encoder back to train mode (e.g. dropout);
            # confirm whether callers re-apply eval() to the frozen encoder.
            encoder.eval()
        self.inst_encoder = encoder
    else:
        # Raise instead of `assert False`: asserts are stripped under `python -O`.
        raise ValueError('Illegal Parameter (base_mode): {!r}'.format(base_mode))

    # Embedding table of MAX_LEN * 2 + 1 rows -- presumably signed relative
    # positions in [-MAX_LEN, MAX_LEN] shifted to be non-negative; TODO confirm.
    self.pe_emb = nn.Embedding(constant.MAX_LEN * 2 + 1, opt['pe_dim_attn'])
    self.ner_emb = nn.Embedding(constant.NER_NUM, opt['ner_dim_attn'])
    self.attn_layer = MultiAspectAttention(opt)
    # Input width 2 * hidden_dim -- presumably a concatenation of two
    # hidden_dim-sized vectors (e.g. bidirectional states); verify in forward().
    self.final_linear = nn.Linear(2 * opt['hidden_dim'], opt['num_class'])
    self.opt = opt
    self.init_weights()
def __init__(self, opt, emb_matrix=None):
    """Build the teacher network: an instance encoder plus a relation matrix.

    Args:
        opt: configuration dict. Keys read here: 'base_mode', and depending on
            the mode 'save_dir' / 'base_id'; always 'num_class'. The whole dict
            is also passed to BaseNetwork when base_mode == 0.
        emb_matrix: forwarded to BaseNetwork; only used when base_mode == 0
            (ignored when a saved base model is loaded).

    base_mode selects how ``self.inst_encoder`` is obtained:
        0 -- build a fresh BaseNetwork (train from scratch);
        1 -- load '<save_dir>/<base_id>/best_model.pt' and fine-tune it;
        2 -- load the same checkpoint and freeze it (requires_grad=False, eval()).

    Raises:
        ValueError: if opt['base_mode'] is not 0, 1 or 2.
    """
    print(">> Current Model: TeacherNetwork")
    super(TeacherNetwork, self).__init__()

    base_mode = opt['base_mode']
    if base_mode == 0:  # start from scratch
        print(" Do not use base model.")
        self.inst_encoder = BaseNetwork(opt=opt, emb_matrix=emb_matrix)
    elif base_mode in (1, 2):
        self.base_model_file = opt['save_dir'] + '/' + opt['base_id'] + '/best_model.pt'
        # The base model is rebuilt from the config stored alongside its checkpoint,
        # not from `opt`.
        self.base_opt = torch_utils.load_config(self.base_model_file)
        if base_mode == 1:  # load & fine tune
            print(" Fine-tune base model.")
        else:  # load & fix pre-trained
            print(" Fix pre-trained base model.")
        inst_base_model = BaseModel(self.base_opt)
        inst_base_model.load(self.base_model_file)
        encoder = inst_base_model.model
        if base_mode == 2:
            for param in encoder.parameters():
                param.requires_grad = False
            # NOTE(review): a later `self.train()` on this module recurses into
            # children and switches the encoder back to train mode (e.g. dropout);
            # confirm whether callers re-apply eval() to the frozen encoder.
            encoder.eval()
        self.inst_encoder = encoder
    else:
        # Raise instead of `assert False`: asserts are stripped under `python -O`.
        raise ValueError('Illegal Parameter (base_mode): {!r}'.format(base_mode))

    # num_class x num_class table (one row per relation label). The
    # 'no_relation' row is the padding_idx, so it receives no gradient updates.
    # NOTE(review): init_weights() may overwrite its initial zero values; confirm.
    self.rel_matrix = nn.Embedding(
        opt['num_class'], opt['num_class'],
        padding_idx=constant.LABEL_TO_ID['no_relation'])
    self.opt = opt
    self.init_weights()