def __init__(self, opt, emb_matrix=None):
    """Set up the student network.

    Depending on ``opt['base_mode']`` the instance encoder is either
    built from scratch (0), restored from a saved BaseModel checkpoint
    and left trainable (1), or restored and frozen as a fixed feature
    extractor (2).  On top of the encoder this network adds positional
    and NER embeddings, a multi-aspect attention layer and the final
    linear classifier.

    Args:
        opt: configuration dict; keys read here include 'base_mode',
            'save_dir', 'base_id', 'pe_dim_attn', 'ner_dim_attn',
            'hidden_dim' and 'num_class'.
        emb_matrix: optional pre-trained word-embedding matrix, only
            forwarded to BaseNetwork when training from scratch.

    Raises:
        ValueError: if ``opt['base_mode']`` is not 0, 1 or 2.
    """
    print(">> Current Model: StudentNetwork")
    super(StudentNetwork, self).__init__()
    if opt['base_mode'] == 0:  # start from scratch
        print(" Do not use base model.")
        self.inst_encoder = BaseNetwork(opt=opt, emb_matrix=emb_matrix)
    else:
        # Validate the mode up front so we fail fast without touching
        # the checkpoint file.  NOTE: the original used `assert False`,
        # which is stripped under `python -O`; raise instead.
        if opt['base_mode'] not in (1, 2):
            raise ValueError('Illegal Parameter (base_mode).')
        # Both remaining modes restore the same checkpointed base model.
        self.base_model_file = opt['save_dir'] + '/' + opt['base_id'] + '/best_model.pt'
        self.base_opt = torch_utils.load_config(self.base_model_file)
        inst_base_model = BaseModel(self.base_opt)
        inst_base_model.load(self.base_model_file)
        if opt['base_mode'] == 1:  # load & fine tune
            print(" Fine-tune base model.")
            self.inst_encoder = inst_base_model.model
        else:  # base_mode == 2: load & fix pre-trained
            print(" Fix pre-trained base model.")
            inst_base_model = inst_base_model.model
            # Freeze every parameter and switch to eval mode so the
            # encoder acts as a fixed feature extractor.
            for param in inst_base_model.parameters():
                param.requires_grad = False
            inst_base_model.eval()
            self.inst_encoder = inst_base_model
    # Relative-position and NER-tag embeddings feed the attention layer.
    self.pe_emb = nn.Embedding(constant.MAX_LEN * 2 + 1, opt['pe_dim_attn'])
    self.ner_emb = nn.Embedding(constant.NER_NUM, opt['ner_dim_attn'])
    self.attn_layer = MultiAspectAttention(opt)
    self.final_linear = nn.Linear(2 * opt['hidden_dim'], opt['num_class'])
    self.opt = opt
    self.init_weights()
def __init__(self, opt, emb_matrix=None):
    """Set up the teacher network.

    Depending on ``opt['base_mode']`` the instance encoder is either
    built from scratch (0), restored from a saved BaseModel checkpoint
    and left trainable (1), or restored and frozen (2).  The network
    additionally owns a learnable relation matrix (``rel_matrix``).

    Args:
        opt: configuration dict; keys read here include 'base_mode',
            'save_dir', 'base_id' and 'num_class'.
        emb_matrix: optional pre-trained word-embedding matrix, only
            forwarded to BaseNetwork when training from scratch.

    Raises:
        ValueError: if ``opt['base_mode']`` is not 0, 1 or 2.
    """
    print(">> Current Model: TeacherNetwork")
    super(TeacherNetwork, self).__init__()
    if opt['base_mode'] == 0:  # start from scratch
        print(" Do not use base model.")
        self.inst_encoder = BaseNetwork(opt=opt, emb_matrix=emb_matrix)
    else:
        # Validate the mode up front so we fail fast without touching
        # the checkpoint file.  NOTE: the original used `assert False`,
        # which is stripped under `python -O`; raise instead.
        if opt['base_mode'] not in (1, 2):
            raise ValueError('Illegal Parameter (base_mode).')
        # Both remaining modes restore the same checkpointed base model.
        self.base_model_file = opt['save_dir'] + '/' + opt['base_id'] + '/best_model.pt'
        self.base_opt = torch_utils.load_config(self.base_model_file)
        inst_base_model = BaseModel(self.base_opt)
        inst_base_model.load(self.base_model_file)
        if opt['base_mode'] == 1:  # load & fine tune
            print(" Fine-tune base model.")
            self.inst_encoder = inst_base_model.model
        else:  # base_mode == 2: load & fix pre-trained
            print(" Fix pre-trained base model.")
            inst_base_model = inst_base_model.model
            # Freeze every parameter and switch to eval mode so the
            # encoder acts as a fixed feature extractor.
            for param in inst_base_model.parameters():
                param.requires_grad = False
            inst_base_model.eval()
            self.inst_encoder = inst_base_model
    # num_class x num_class learnable relation matrix; the row for
    # 'no_relation' is a zeroed, non-updated padding row.
    self.rel_matrix = nn.Embedding(
        opt['num_class'], opt['num_class'],
        padding_idx=constant.LABEL_TO_ID['no_relation'])
    self.opt = opt
    self.init_weights()
file_logger = helper.FileLogger( model_save_dir + '/' + opt['log'], header="# epoch\ttrain_loss\tdev_loss\tdev_f1\ttest_loss\ttest_f1") # print model info helper.print_config(opt) # model base_outputs = ['placeholder' for _ in range(0, len(train_batch))] if opt['base_mode'] == 3: base_outputs = [] base_model_file = opt['save_dir'] + '/' + opt['base_id'] + '/best_model.pt' print("Loading base model from {}".format(base_model_file)) base_opt = torch_utils.load_config(base_model_file) base_model = BaseModel(opt=base_opt) base_model.load(base_model_file) base_model.model.eval() for _, batch in enumerate(train_batch): inputs = [b.cuda() for b in batch[:10] ] if opt['cuda'] else [b for b in batch[:10]] base_logits, _, _ = base_model.model(inputs) base_outputs.append([base_logits.data.cpu().numpy()]) teacher_outputs = [] if opt['use_teacher']: teacher_model_file = opt['save_dir'] + '/' + opt[ 'teacher_id'] + '/best_model.pt' print("Loading teacher model from {}".format(teacher_model_file)) teacher_opt = torch_utils.load_config(teacher_model_file) teacher_model = TeacherModel(opt=teacher_opt) teacher_model.load(teacher_model_file)