def test(model, ema, args, data):
    """Evaluate a BiDAF-style QA model on the dev set using EMA weights.

    Temporarily replaces the trainable parameters with their EMA shadows,
    decodes the best (start, end) answer span for every dev example, writes
    a SQuAD-v2-style prediction file, and scores it with the external
    ``evaluate`` script.

    Args:
        model: QA network called with char/word tensors; returns start/end
            logits ``(p1, p2)`` of shape (batch, c_len).
        ema: EMA object holding shadow copies of the trainable parameters.
        args: namespace providing ``gpu``, ``prediction_file``,
            ``dataset_file``.
        data: dataset wrapper exposing ``dev_iter`` and
            ``CONTEXT_WORD.vocab``.

    Returns:
        Tuple ``(loss, exact, f1, HasAns_exact, HasAns_f1, NoAns_exact,
        NoAns_f1)`` where ``loss`` is the summed (not averaged) dev loss.
    """
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    loss = 0
    answers = dict()
    model.eval()
    # Back up the live parameters and load the EMA shadows for evaluation;
    # they are restored after the loop below.
    backup_params = EMA(0)
    for name, param in model.named_parameters():
        if param.requires_grad:
            backup_params.register(name, param.data)
            param.data.copy_(ema.get(name))
    total_time = 0
    previous_time = time.time()
    for batch in iter(data.dev_iter):
        #time1 = time.time()
        with torch.no_grad():
            p1, p2 = model(batch.c_char, batch.q_char, batch.c_word[0], batch.q_word[0], batch.c_word[1], batch.q_word[1])
        #p1, p2 = model(batch)
        #time2 = time.time()
        #total_time = total_time + time2 - time1
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()

        # Joint span score matrix of shape (batch, c_len, c_len): entry
        # [b, s, e] = log p(start=s) + log p(end=e).  The strictly lower
        # triangle (end < start) is masked to -inf so only valid spans
        # survive the max below.
        batch_size, c_len = p1.size()
        ls = nn.LogSoftmax(dim=1)
        mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
        score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
        # Max over start positions, then over end positions; gather the
        # argmax start that pairs with the chosen end.
        score, s_idx = score.max(dim=1)
        score, e_idx = score.max(dim=1)
        s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()

        for i in range(batch_size):
            id = batch.id[i]
            answer = batch.c_word[0][i][s_idx[i]:e_idx[i] + 1]
            answer = ' '.join([data.CONTEXT_WORD.vocab.itos[idx] for idx in answer])
            # A bare <eos> prediction is mapped to the empty string, i.e.
            # "no answer" in SQuAD v2 scoring.
            if answer == "<eos>":
                answer = ""
            answers[id] = answer
    #print(f'one epoch time {time.time()-previous_time}')
    #print(f'total time {total_time}')

    # Restore the original (non-EMA) parameters.
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data.copy_(backup_params.get(name))

    with open(args.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers), file=f)

    # Score with the official evaluation script via its CLI-style argparse.
    opts = evaluate.parse_args(args=[f"{args.dataset_file}", f"{args.prediction_file}"])
    results = evaluate.main(opts)
    return loss, results['exact'], results['f1'], results['HasAns_exact'], results['HasAns_f1'], results['NoAns_exact'], results['NoAns_f1']
def adapt(self, num_classes):
    """Adapt the model to a different dataset that reuses the same
    semantic classifier weights.

    Delegates to the underlying model's own ``adapt`` method (unwrapping
    ``nn.DataParallel`` if present), moves the model back onto the
    configured device, and rebuilds the EMA tracker from the adapted
    parameters.

    Args:
        num_classes: number of classes in the target dataset.

    Returns:
        None
    """
    # Unwrap DataParallel so adapt() is called on the real module.
    core = self.model.module if isinstance(self.model, torch.nn.DataParallel) else self.model
    core.adapt(num_classes)
    self.model.to(self.device)
    # The parameter set may have changed, so the EMA shadow is re-created.
    self.ema = EMA(self.model, self.config.ema_alpha)
def test(model, ema, args, data):
    """Evaluate a BiDAF model on the dev set using EMA weights.

    Same span-decoding scheme as the SQuAD-v2 variant of ``test`` elsewhere
    in this file, but uses ``data.WORD`` as the vocabulary, calls the model
    with the whole ``batch`` object, and scores with
    ``evaluate.main(args, answers, data)``.

    Args:
        model: QA network; ``model(batch)`` returns start/end logits
            ``(p1, p2)`` of shape (batch, c_len).
        ema: EMA object holding shadow copies of the trainable parameters.
        args: namespace providing ``gpu`` and ``prediction_file``.
        data: dataset wrapper exposing ``dev_iter`` and ``WORD.vocab``.

    Returns:
        Tuple ``(mean_dev_loss, exact_match, f1)``.
    """
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()
    loss = 0
    answers = dict()
    model.eval()
    # Back up live parameters and load the EMA shadows for evaluation.
    backup_params = EMA(0)
    for name, param in model.named_parameters():
        if param.requires_grad:
            backup_params.register(name, param.data)
            param.data.copy_(ema.get(name))
    with torch.set_grad_enabled(False):
        for batch in iter(data.dev_iter):
            p1, p2 = model(batch)
            batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
            loss += batch_loss.item()

            # (batch, c_len, c_len) joint span scores; strictly lower
            # triangle (end < start) is masked to -inf.
            batch_size, c_len = p1.size()
            ls = nn.LogSoftmax(dim=1)
            mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
            score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
            # Max over starts, then ends; recover the start paired with
            # the chosen end.
            score, s_idx = score.max(dim=1)
            score, e_idx = score.max(dim=1)
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()

            for i in range(batch_size):
                id = batch.id[i]
                answer = batch.c_word[0][i][s_idx[i]:e_idx[i] + 1]
                answer = ' '.join([data.WORD.vocab.itos[idx] for idx in answer])
                answers[id] = answer

    # Restore the original (non-EMA) parameters.
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data.copy_(backup_params.get(name))
    #print(answers)
    with open(args.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers, indent=4), file=f)
    results = evaluate.main(args, answers, data)
    return loss / len(data.dev_iter), results['exact_match'], results['f1']
def load_model_state(self, chkpt_dict_path):
    '''
    Loads model state based on a checkpoint saved by SemCo _save_checkpoint() function.

    Best-effort loader: on a naive load failure it falls back to loading
    only the parameters whose names and shapes match; the EMA shadow and
    the class list are restored when present in the checkpoint.
    '''
    print("Loading Model State")
    checkpoint_dict = torch.load(chkpt_dict_path, map_location=self.device)
    if 'model_state_dict' in checkpoint_dict:
        state_dict = checkpoint_dict['model_state_dict']
    else:
        print('model_state_dict key is not present in checkpoint, loading pretrained model failed, using original initialization for model')
        return
    # handle state_dictionaries where keys has 'module' in them (if the model was wrapped in nn.DataParallel)
    if all(['module' in key for key in state_dict.keys()]):
        if all(['module' in key for key in self.model.state_dict()]):
            # Both checkpoint and current model are DataParallel-wrapped;
            # the keys already agree.
            pass
        else:
            # Strip the 'module.' prefix so keys match the unwrapped model.
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
            # NOTE(review): the EMA-shadow prefix strip sits inside this
            # else-branch in the recovered formatting — confirm it should
            # not also run in the wrapped/wrapped case.
            if 'ema_shadow' in checkpoint_dict:
                checkpoint_dict['ema_shadow'] = {k.replace('module.', ''): v for k, v in checkpoint_dict['ema_shadow'].items()}
    try:
        self.model.load_state_dict(state_dict)
    except Exception as e:
        print(f'Problem occurred during naive state_dict loading: {e}.\nTrying to only load common params')
        try:
            # Fall back: keep only params that exist in the current model
            # with identical shapes, and load the merged dict.
            model_state = self.model.state_dict()
            pretrained_state = {k: v for k, v in state_dict.items() if k in model_state and v.size() == model_state[k].size()}
            unloaded_state = set(list(state_dict.keys())) - set(list(model_state.keys()))
            model_state.update(pretrained_state)
            self.model.load_state_dict(model_state)
            print(f'Success. Following params in pretrained_state_dict were not loaded: {unloaded_state}')
        except Exception as e:
            print(f'Unable to load model state due to following error. Model will be initialised randomly. \n {e}')
    if 'ema_shadow' in checkpoint_dict:
        try:
            # Rebuild the EMA wrapper, then overwrite its shadow with the
            # checkpointed values that match by name and shape.
            self.ema = EMA(self.model, self.config.ema_alpha)
            similar_params = {k: v for k, v in checkpoint_dict['ema_shadow'].items() if k in self.ema.shadow and v.size() == self.ema.shadow[k].size()}
            self.ema.shadow.update(similar_params)
            print(f'EMA shadow has been loaded successfully. \n{len(similar_params)} out of {len(self.ema.shadow)} params were loaded')
        except Exception as e:
            print(f'Unable to load EMA shadow. EMA will be reinitialised with current model params. {e}')
            self.ema = EMA(self.model, self.config.ema_alpha)
    else:
        print('EMA shadow is not found in checkpoint dictionary. EMA will be reinitialised with current model params.')
        self.ema = EMA(self.model, self.config.ema_alpha)
    try:
        # Sanity-check that the checkpoint's class list matches the
        # current dataset's classes (order-sensitive comparison).
        if 'classes' in checkpoint_dict:
            classes = self.dataset_meta['classes']
            classes_model = checkpoint_dict['classes']
            if all([classes_model[i] == classes[i] for i in range(len(classes))]):
                print(f'classes matched successfully')
            else:
                print("Classes loaded don't match the classes used while training the model, output of softmax can't be trusted")
    except Exception as e:
        print("can't load classes file. Pls check and try again.")
    return
def __init__(self, config, dataset_meta, device, L='dynamic', device_ids=None):
    """Build the trainer: label guessor, model, optimiser, optional AMP
    and DataParallel wrapping, EMA tracker, optional checkpoint load and
    backbone freezing, and logging.

    Args:
        config: run configuration (parallel, no_amp, ema_alpha,
            use_pretrained, checkpoint_path, freeze_backbone, ...).
        dataset_meta: dataset metadata dict; if it lacks 'stats',
            ImageNet normalisation stats are filled in.
        device: torch device for the model.
        L: 'dynamic' by default — semantics depend on downstream use;
            TODO confirm against _get_label_guessor/_set_model.
        device_ids: GPU ids forwarded to model setup / DataParallel.
    """
    self.config = config
    self.dataset_meta = dataset_meta
    if 'stats' not in dataset_meta:
        self.dataset_meta['stats'] = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)  # imagenet_stats
    self.device = device
    self.parallel = config.parallel
    self.device_ids = device_ids
    self.L = L
    self.label_emb_guessor, self.emb_dim = self._get_label_guessor()
    self.model = self._set_model(config.parallel, device_ids)
    self.optim = self._get_optimiser(self.config)
    # AMP must wrap model+optimiser before DataParallel wrapping below.
    if not self.config.no_amp:
        from apex import amp  # third-party NVIDIA apex; only imported when AMP is enabled
        self.model, self.optim = amp.initialize(self.model, self.optim, opt_level="O1")
    if self.config.parallel:
        self.model = nn.DataParallel(self.model)
    # initialise the exponential moving average model
    self.ema = EMA(self.model, self.config.ema_alpha)
    if self.config.use_pretrained:
        self.load_model_state(config.checkpoint_path)
    if self.config.freeze_backbone:
        self._freeze_model_backbone()
    self.logger, self.writer, self.time_stamp = self._setup_default_logging()
def cw_tree_attack_targeted():
    """Targeted Carlini–Wagner attack on a BiDAF QA model via a tree
    autoencoder.

    For each dev batch (effectively batch size 1 — see the trailing
    ``continue``/``tot`` handling), optimises the latent tree embedding of
    an appended sentence so the model predicts a target span, decodes the
    adversarial sentence back to BiDAF token ids, re-runs the model, and
    logs/serialises original vs. adversarial predictions.

    Relies on module-level globals: ``args``, ``model``, ``ema``, ``data``,
    ``options``, ``logger``, ``root_dir`` and the helper functions
    ``write_to_ans``, ``append_input``, ``bidaf_convert_to_idx``,
    ``compare``, ``compare_untargeted``, ``transform``.

    Returns:
        Tuple ``(loss, exact_match, f1)`` of the *original* (pre-attack)
        predictions as scored by ``evaluate.main(options)``.
    """
    cw = CarliniL2_qa(debug=args.debugging)
    criterion = nn.CrossEntropyLoss()
    loss = 0
    tot = 0
    adv_loss = 0
    targeted_success = 0
    untargeted_success = 0
    adv_text = []
    answers = dict()
    adv_answers = dict()
    # model.eval()
    embed = torch.load(args.word_vector)
    device = torch.device("cuda:0" if args.cuda else "cpu")
    vocab = Vocab(filename=args.dictionary, data=[PAD_WORD, UNK_WORD, EOS_WORD, SOS_WORD])
    generator = Generator(args.test_data, vocab=vocab, embed=embed)
    # Embedding that maps autoencoder tokens into BiDAF's embedding space.
    transfered_embedding = torch.load('bidaf_transfered_embedding.pth')
    transfer_emb = torch.nn.Embedding.from_pretrained(transfered_embedding).to(device)
    seqback = WrappedSeqback(embed, device, attack=True, seqback_model=generator.seqback_model, vocab=vocab, transfer_emb=transfer_emb)
    treelstm = generator.tree_model
    generator.load_state_dict(torch.load(args.load_ae))
    # Swap in the EMA weights for the duration of the attack; originals
    # are restored after the main loop.
    backup_params = EMA(0)
    for name, param in model.named_parameters():
        if param.requires_grad:
            backup_params.register(name, param.data)
            param.data.copy_(ema.get(name))

    class TreeModel(nn.Module):
        """Adapter: decodes a latent 'hidden' through seqback and feeds
        the perturbed embedding into the QA model (closes over the
        current ``batch``)."""

        def __init__(self):
            super(TreeModel, self).__init__()
            self.inputs = None

        def forward(self, hidden):
            self.embedding = seqback(hidden)
            return model(batch, perturbed=self.embedding)

        def set_temp(self, temp):
            seqback.temp = temp

        def get_embedding(self):
            return self.embedding

        def get_seqback(self):
            return seqback

    tree_model = TreeModel()
    for batch in tqdm(iter(data.dev_iter), total=1000):
        # Original (clean) prediction and loss.
        p1, p2 = model(batch)
        orig_answer, orig_s_idx, orig_e_idx = write_to_ans(p1, p2, batch, answers)
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()

        # Build the appended adversarial sentence and its target span.
        append_info = append_input(batch, vocab)
        batch_add_start = append_info['add_start']
        batch_add_end = append_info['add_end']
        batch_start_target = torch.LongTensor(append_info['target_start']).to(device)
        batch_end_target = torch.LongTensor(append_info['target_end']).to(device)
        add_sents = append_info['append_sent']

        input_embedding = model.word_emb(batch.c_word[0])
        append_info['tree'] = [generator.get_tree(append_info['tree'])]
        seqback.sentences = input_embedding.clone().detach()
        seqback.batch_trees = append_info['tree']
        seqback.batch_add_sent = append_info['ae_sent']
        seqback.start = append_info['add_start']
        seqback.end = append_info['add_end']
        seqback.adv_sent = []

        # Encode each appended sentence with the TreeLSTM to get the
        # latent the CW attack will optimise.  NOTE(review): target
        # offsets index [0] rather than [bi] — consistent with batch
        # size 1 only; confirm if batching is ever enabled.
        batch_tree_embedding = []
        for bi, append_sent in enumerate(append_info['ae_sent']):
            seqback.target_start = append_info['target_start'][0] - append_info['add_start'][0]
            seqback.target_end = append_info['target_end'][0] - append_info['add_start'][0]
            sentences = [torch.tensor(append_sent, dtype=torch.long, device=device)]
            seqback.target = sentences[0][seqback.target_start:seqback.target_end + 1]
            trees = [append_info['tree'][bi]]
            tree_embedding = treelstm(sentences, trees)[0][0].detach()
            batch_tree_embedding.append(tree_embedding)

        hidden = torch.cat(batch_tree_embedding, dim=0)
        cw.batch_info = append_info
        cw.num_classes = append_info['tot_length']
        adv_hidden = cw.run(tree_model, hidden, (batch_start_target, batch_end_target), input_token=input_embedding)
        seqback.adv_sent = []

        # re-test: splice the decoded adversarial tokens back into the
        # context and re-run the model.
        for bi, (add_start, add_end) in enumerate(zip(batch_add_start, batch_add_end)):
            if bi in cw.o_best_sent:
                ae_words = cw.o_best_sent[bi]
                bidaf_tokens = bidaf_convert_to_idx(ae_words)
                batch.c_word[0].data[bi, add_start:add_end] = torch.LongTensor(bidaf_tokens)
        p1, p2 = model(batch)
        adv_answer, adv_s_idx, adv_e_idx = write_to_ans(p1, p2, batch, adv_answers)
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        adv_loss += batch_loss.item()

        for bi, (start_target, end_target) in enumerate(zip(batch_start_target, batch_end_target)):
            start_output = adv_s_idx
            end_output = adv_e_idx
            targeted_success += int(compare(start_output, start_target.item(), end_output, end_target.item()))
            untargeted_success += int(compare_untargeted(start_output, start_target.item(), end_output, end_target.item()))

        for i in range(len(add_sents)):
            logger.info(("orig:", transform(add_sents[i])))
            # NOTE(review): 'adv_predict' stores the *original* indices and
            # 'orig_predict' the *adversarial* ones — the labels look
            # swapped (same pattern in cw_random_word_attack); confirm
            # before relying on these fields downstream.
            try:
                logger.info(("adv:", cw.o_best_sent[i]))
                adv_text.append({
                    'adv_text': cw.o_best_sent[i],
                    'qas_id': batch.id[i],
                    'adv_predict': (orig_s_idx, orig_e_idx),
                    'orig_predict': (adv_s_idx, adv_e_idx),
                    'Orig answer:': orig_answer,
                    'Adv answer:': adv_answer
                })
                joblib.dump(adv_text, root_dir + '/adv_text.pkl')
            except:
                # NOTE(review): bare except — on any failure (e.g. missing
                # cw.o_best_sent[i]) fall back to the unperturbed sentence.
                adv_text.append({
                    'adv_text': transform(add_sents[i]),
                    'qas_id': batch.id[i],
                    'adv_predict': (orig_s_idx, orig_e_idx),
                    'orig_predict': (adv_s_idx, adv_e_idx),
                    'Orig answer:': orig_answer,
                    'Adv answer:': adv_answer
                })
                joblib.dump(adv_text, root_dir + '/adv_text.pkl')
                continue
        # for batch size = 1
        tot += 1
        logger.info(("orig predict", (orig_s_idx, orig_e_idx)))
        logger.info(("adv append predict", (adv_s_idx, adv_e_idx)))
        logger.info(("targeted successful rate:", targeted_success))
        logger.info(("untargetd successful rate:", untargeted_success))
        logger.info(("Orig answer:", orig_answer))
        logger.info(("Adv answer:", adv_answer))
        logger.info(("tot:", tot))

    # Restore the original (non-EMA) parameters.
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data.copy_(backup_params.get(name))
    with open(options.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers), file=f)
    with open(options.prediction_file + '_adv.json', 'w', encoding='utf-8') as f:
        print(json.dumps(adv_answers), file=f)
    results = evaluate.main(options)
    logger.info(tot)
    logger.info(("adv loss, results['exact_match'], results['f1']", loss, results['exact_match'], results['f1']))
    return loss, results['exact_match'], results['f1']
def cw_random_word_attack():
    """Untargeted Carlini–Wagner attack perturbing randomly chosen word
    positions of the context.

    For each dev batch, optimises the word embeddings at the allowed
    indices (via a 0/1 mask over the embedding tensor), writes the chosen
    replacement token ids back into the context, re-runs the model, and
    logs/serialises original vs. adversarial predictions.

    Relies on module-level globals: ``args``, ``model``, ``ema``, ``data``,
    ``device``, ``options``, ``logger``, ``root_dir`` and helpers
    ``write_to_ans``, ``append_random_input``, ``compare``,
    ``compare_untargeted``, ``transform``, ``to_list``.

    Returns:
        Tuple ``(loss, exact_match, f1)`` of the *original* (pre-attack)
        predictions as scored by ``evaluate.main(options)``.
    """
    cw = CarliniL2_untargeted_qa(debug=args.debugging)
    criterion = nn.CrossEntropyLoss()
    loss = 0
    adv_loss = 0
    targeted_success = 0
    untargeted_success = 0
    adv_text = []
    answers = dict()
    adv_answers = dict()
    # Swap in the EMA weights for the duration of the attack; originals
    # are restored after the main loop.
    backup_params = EMA(0)
    for name, param in model.named_parameters():
        if param.requires_grad:
            backup_params.register(name, param.data)
            param.data.copy_(ema.get(name))
    tot = 0
    for batch in tqdm(iter(data.dev_iter), total=1000):
        # Original (clean) prediction and loss.
        p1, p2 = model(batch)
        orig_answer, orig_s_idx, orig_e_idx = write_to_ans(p1, p2, batch, answers)
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()

        append_info = append_random_input(batch)
        allow_idxs = append_info['allow_idx']
        # Untargeted attack: dummy (0, 0) targets.
        batch_start_target = torch.LongTensor([0]).to(device)
        batch_end_target = torch.LongTensor([0]).to(device)

        input_embedding = model.word_emb(batch.c_word[0])
        # Mask restricting the perturbation to the allowed word positions.
        cw_mask = np.zeros(input_embedding.shape).astype(np.float32)
        cw_mask = torch.from_numpy(cw_mask).float().to(device)
        for bi, allow_idx in enumerate(allow_idxs):
            cw_mask[bi, np.array(allow_idx)] = 1
        cw.wv = model.word_emb.weight
        cw.inputs = batch
        cw.mask = cw_mask
        cw.batch_info = append_info
        cw.num_classes = append_info['tot_length']
        # print(transform(to_list(batch.c_word[0][0])))
        cw.run(model, input_embedding, (batch_start_target, batch_end_target))

        # re-test: write the chosen replacement token ids back into the
        # context and re-run the model.
        for bi, allow_idx in enumerate(allow_idxs):
            if bi in cw.o_best_sent:
                for i, idx in enumerate(allow_idx):
                    batch.c_word[0].data[bi, idx] = cw.o_best_sent[bi][i]
        p1, p2 = model(batch)
        adv_answer, adv_s_idx, adv_e_idx = write_to_ans(p1, p2, batch, adv_answers)
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        adv_loss += batch_loss.item()

        for bi, (start_target, end_target) in enumerate(zip(batch_start_target, batch_end_target)):
            start_output = adv_s_idx
            end_output = adv_e_idx
            targeted_success += int(compare(start_output, start_target.item(), end_output, end_target.item()))
            untargeted_success += int(compare_untargeted(start_output, start_target.item(), end_output, end_target.item()))

        for i in range(len(allow_idxs)):
            # NOTE(review): 'adv_predict' stores the *original* indices and
            # 'orig_predict' the *adversarial* ones — the labels look
            # swapped (same pattern in cw_tree_attack_targeted); confirm
            # before relying on these fields downstream.
            try:
                logger.info(("adv:", transform(cw.o_best_sent[i])))
                adv_text.append({
                    'added_text': transform(cw.o_best_sent[i]),
                    'adv_text': transform(to_list(batch.c_word[0][0])),
                    'qas_id': batch.id[i],
                    'adv_predict': (orig_s_idx, orig_e_idx),
                    'orig_predict': (adv_s_idx, adv_e_idx),
                    'Orig answer:': orig_answer,
                    'Adv answer:': adv_answer
                })
                joblib.dump(adv_text, root_dir + '/adv_text.pkl')
            except:
                # NOTE(review): bare except — on any failure fall back to
                # recording the (possibly unmodified) context only.
                adv_text.append({
                    'adv_text': transform(to_list(batch.c_word[0][0])),
                    'qas_id': batch.id[i],
                    'adv_predict': (orig_s_idx, orig_e_idx),
                    'orig_predict': (adv_s_idx, adv_e_idx),
                    'Orig answer:': orig_answer,
                    'Adv answer:': adv_answer
                })
                joblib.dump(adv_text, root_dir + '/adv_text.pkl')
                continue
        # for batch size = 1
        tot += 1
        logger.info(("orig predict", (orig_s_idx, orig_e_idx)))
        logger.info(("adv append predict", (adv_s_idx, adv_e_idx)))
        logger.info(("targeted successful rate:", targeted_success))
        logger.info(("untargetd successful rate:", untargeted_success))
        logger.info(("Orig answer:", orig_answer))
        logger.info(("Adv answer:", adv_answer))
        logger.info(("tot:", tot))

    # Restore the original (non-EMA) parameters.
    for name, param in model.named_parameters():
        if param.requires_grad:
            param.data.copy_(backup_params.get(name))
    with open(options.prediction_file, 'w', encoding='utf-8') as f:
        print(json.dumps(answers), file=f)
    with open(options.prediction_file + '_adv.json', 'w', encoding='utf-8') as f:
        print(json.dumps(adv_answers), file=f)
    results = evaluate.main(options)
    logger.info(tot)
    logger.info(("adv loss, results['exact_match'], results['f1']", loss, results['exact_match'], results['f1']))
    return loss, results['exact_match'], results['f1']
# --- Script-level setup for the attack entry points ---------------------
# Pre-sampled perturbation sentences used by the attack variants.
answer_append_sentences = joblib.load('sampled_perturb_answer_sentences.pkl')
question_append_sentences = joblib.load('sampled_perturb_question_sentences.pkl')

# Build the victim BiDAF model and optionally restore weights / EMA state.
model = BiDAF(options, data.WORD.vocab.vectors).to(device)
if options.old_model is not None:
    model.load_state_dict(torch.load(options.old_model, map_location="cuda:{}".format(options.gpu)))
if options.old_ema is not None:
    # ema = pickle.load(open(options.old_ema, "rb"))
    ema = torch.load(options.old_ema, map_location=device)
else:
    # No saved EMA: start a fresh one seeded with the current weights.
    # (The register loop is placed inside this else-branch in the
    # recovered formatting — a freshly created EMA needs registration.)
    ema = EMA(options.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

# Seed all RNG sources for reproducibility.
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    else:
        torch.cuda.manual_seed(args.seed)
random.seed(args.seed)

# Dispatch on the selected attack.  NOTE(review): the body of this branch
# lies outside the visible chunk boundary.
if args.model == 'word_attack':
def train(args):
    """Train a BiDAF model on a (Chinese) MRC dataset with optional EMA,
    per-epoch LR decay, ROUGE-based validation and TensorBoard-style
    logging.

    Args:
        args: namespace with cuda, ema, init_lr, batch_step, epoch_num,
            with_valid and saved_model_file.
    """
    db = Data(args)
    # db.build_vocab()  # each build_vocab run may assign different ids to tokens with equal frequency
    db.load_vocab()
    db.build_dataset()  # builds db.train_loader
    model = BiDAF(args)
    if args.cuda:
        model = model.cuda()
    if args.ema:
        ema = EMA(0.999)
        print("Register EMA ...")
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.register(name, param.data)
    init_lr = args.init_lr
    optimizer = torch.optim.Adam(params=model.parameters(), lr=init_lr)
    lr = init_lr
    batch_step = args.batch_step
    loss_fn = nn.CrossEntropyLoss()
    logger = Logger('./logs')
    step = 0
    valid_raw_article_list = db.valid_raw_article_list
    valid_answer_list = db.valid_answer_list
    print('========== Train ==============')
    for epoch in range(args.epoch_num):
        print('---Epoch', epoch, "lr:", lr)
        running_loss = 0.0
        count = 0
        print("len(db.train_loader):", len(db.train_loader))
        for article, question, answer_span, _ in db.train_loader:
            if args.cuda:
                article, question, answer_span = article.cuda(), question.cuda(), answer_span.cuda()
            p1, p2 = model(article, question)
            # answer_span holds (start, end) per example; transpose to get
            # the start column and end column as targets.
            loss_p1 = loss_fn(p1, answer_span.transpose(0, 1)[0])
            loss_p2 = loss_fn(p2, answer_span.transpose(0, 1)[1])
            running_loss += loss_p1.item()
            running_loss += loss_p2.item()
            optimizer.zero_grad()
            (loss_p1 + loss_p2).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
            optimizer.step()
            if args.ema:
                # NOTE(review): ema(name, data) assigns the averaged value
                # back into the live parameter each step — the training
                # weights themselves are smoothed, not just a shadow copy.
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        param.data = ema(name, param.data)
            count += 1
            if count % batch_step == 0:
                rep_str = '[{}] Epoch {}, loss: {:.3f}'
                print(rep_str.format(datetime.datetime.now().strftime('%Y%m%d-%H%M%S'), epoch, running_loss / batch_step))
                info = {'loss': running_loss / batch_step}
                running_loss = 0.0
                count = 0
                # 1. Log scalar values (scalar summary)
                for tag, value in info.items():
                    logger.scalar_summary(tag, value, step + 1)
                # 2. Log values and gradients of the parameters (histogram summary)
                for tag, value in model.named_parameters():
                    tag = tag.replace('.', '/')
                    logger.histo_summary(tag, value.data.cpu().numpy(), step + 1)
                    logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), step + 1)
            step += 1
        # Validation set
        if args.with_valid:
            print('======== Epoch {} result ========'.format(epoch))
            print("len(db.valid_loader):", len(db.valid_loader))
            valid_result = []
            idx = 0
            for article, question, _ in db.valid_loader:
                if args.cuda:
                    article, question = article.cuda(), question.cuda()
                p1, p2 = model(article, question, is_trainning=False)
                _, p1_predicted = torch.max(p1.cpu().data, 1)
                _, p2_predicted = torch.max(p2.cpu().data, 1)
                p1_predicted = p1_predicted.numpy().tolist()
                p2_predicted = p2_predicted.numpy().tolist()
                # Pair each predicted span with the raw article and the
                # reference answer for ROUGE scoring.
                for _p1, _p2, _raw_article, _answer in zip(p1_predicted, p2_predicted, valid_raw_article_list[idx:idx + len(p1_predicted)], valid_answer_list[idx:idx + len(p1_predicted)]):
                    valid_result.append({
                        "ref_answer": _answer,
                        "cand_answer": "".join(_raw_article[_p1:_p2 + 1])
                    })
                idx = idx + len(p1_predicted)
            rouge_score = test_score(valid_result)
            info = {'rouge_score': rouge_score}
            for tag, value in info.items():
                logger.scalar_summary(tag, value, epoch + 1)
        # Exponential LR decay (floored at 1e-5); a new optimizer is built
        # each epoch with the decayed rate.
        lr = max(0.00001, init_lr * 0.9**(epoch + 1))
        print("lr:", lr)
        parameters = filter(lambda param: param.requires_grad, model.parameters())
        optimizer = torch.optim.Adam(params=parameters, lr=lr, weight_decay=1e-7)
        # print(len(db.valid_loader))
        if epoch >= 1 and args.saved_model_file:
            torch.save(model.state_dict(), args.saved_model_file + "_epoch_" + str(epoch))
            print("saved model")
def train(args, data):
    """Train a BiDAF model on SQuAD-style data with EMA weight averaging.

    Every ``args.print_freq`` steps the model is evaluated on the dev set
    (via ``test``, which uses the EMA parameters) and a deep copy of the
    best-by-dev-F1 model is retained.

    Args:
        args: namespace with gpu, exp_decay_rate, learning_rate,
            model_time, epoch and print_freq.
        data: dataset wrapper exposing train_iter and the fields
            ``test`` consumes.

    Returns:
        The model snapshot with the highest dev F1 seen so far, or the
        live model if no dev evaluation ever ran.
    """
    device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)

    # Register EMA shadows for every trainable parameter.
    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()
    writer = SummaryWriter(log_dir='runs/' + args.model_time)

    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1
    # BUGFIX: best_model was previously only assigned inside the periodic
    # dev-evaluation branch, so returning it raised UnboundLocalError when
    # training ended before print_freq steps; fall back to the live model.
    best_model = model

    iterator = data.train_iter
    for i, batch in enumerate(iterator):
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
        last_epoch = present_epoch

        p1, p2 = model(batch)
        optimizer.zero_grad()
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()

        # Fold the updated weights into the EMA shadows after each step.
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.update(name, param.data)

        if (i + 1) % args.print_freq == 0:
            dev_loss, dev_exact, dev_f1 = test(model, ema, args, data)
            c = (i + 1) // args.print_freq
            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('exact_match/dev', dev_exact, c)
            writer.add_scalar('f1/dev', dev_f1, c)
            print('train loss: {} / dev loss: {}'.format(loss, dev_loss)
                  + ' / dev EM: {} / dev F1: {}'.format(dev_exact, dev_f1))
            if dev_f1 > max_dev_f1:
                max_dev_f1 = dev_f1
                max_dev_exact = dev_exact
                best_model = copy.deepcopy(model)
            # Reset the accumulated train loss and return to train mode
            # (test() put the model into eval mode).
            loss = 0
            model.train()

    writer.close()
    print('max dev EM: {} / max dev F1: {}'.format(max_dev_exact, max_dev_f1))
    return best_model
def train(args):
    """Fine-tune an SLQA model (warm-started from a saved epoch-0
    checkpoint) using pre-computed, pickled ELMo features per batch.

    Args:
        args: namespace with cuda, ema, init_lr, batch_step, epoch_num,
            with_valid, saved_model_file and the max-length settings used
            by the (currently disabled) ELMo generation helpers.
    """
    db = Data(args)
    # db.build_vocab()  # each build_vocab run may assign different ids to tokens with equal frequency
    db.load_vocab()
    db.build_dataset()  # builds db.train_loader
    # model = BiDAF(args)
    model = SLQA(args)
    # Warm start from a previously saved first-epoch checkpoint.
    first_model = "./checkpoints/SLQA_elmo_epoch_0"
    model.load_state_dict(torch.load(first_model))
    if args.cuda:
        model = model.cuda()
    if args.ema:
        ema = EMA(0.999)
        print("Register EMA ...")
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.register(name, param.data)
    init_lr = args.init_lr
    parameters = filter(lambda param: param.requires_grad, model.parameters())
    weight_decay = 1e-6
    weight_decay = 0  # overrides the line above: weight decay disabled
    optimizer = torch.optim.Adam(params=parameters, lr=init_lr, weight_decay=weight_decay)
    batch_step = args.batch_step
    loss_fn = nn.CrossEntropyLoss()
    logger = Logger('./logs')
    step = 0
    train_raw_article_list = db.train_raw_article_list
    train_raw_question_list = db.train_raw_question_list
    valid_raw_article_list = db.valid_raw_article_list
    valid_answer_list = db.valid_answer_list
    valid_raw_question_list = db.valid_raw_question_list
    # question_hdf5_f = h5py.File(args.question_hdf5_path, "r")
    # article_hdf5_f = h5py.File(args.article_hdf5_path, "r")
    print('========== Train ==============')
    for epoch in range(args.epoch_num):
        print('---Epoch', epoch)
        running_loss = 0.0
        count = 0
        print("len(db.train_loader):", len(db.train_loader))
        train_idx = 0
        for batch_id, (article, question, answer_span, _) in enumerate(db.train_loader):
            if args.cuda:
                article, question, answer_span = article.cuda(), question.cuda(), answer_span.cuda()
            # Disabled on-the-fly ELMo generation (originally produced the
            # cached pickles loaded just below):
            # tmp_train_raw_article_list = train_raw_article_list[train_idx:train_idx + question.size()[0]]
            # tmp_train_raw_question_list = train_raw_question_list[train_idx:train_idx + question.size()[0]]
            # question_elmo = gen_elmo_by_text(question_hdf5_f, tmp_train_raw_question_list, args.max_question_len)
            # article_elmo = gen_elmo_by_text(article_hdf5_f, tmp_train_raw_article_list, args.max_article_len)
            # pickle.dump((article_elmo, question_elmo), open(elmo_save_path, "wb"))
            # Load the pre-computed ELMo features for this batch index.
            elmo_save_path = "/backup231/lhliu/jszn/elmo/" + str(batch_id) + ".pkl"
            article_elmo, question_elmo = pickle.load(open(elmo_save_path, "rb"))
            # print(elmo_save_path)
            article_elmo = torch.tensor(article_elmo, dtype=torch.float)
            question_elmo = torch.tensor(question_elmo, dtype=torch.float)
            # train_idx += question.size()[0]
            # continue
            if args.cuda:
                question_elmo = question_elmo.cuda()
                article_elmo = article_elmo.cuda()
            p1, p2 = model(article, question, article_elmo=article_elmo, question_elmo=question_elmo)
            # answer_span holds (start, end) per example; transpose to get
            # start and end target columns.
            loss_p1 = loss_fn(p1, answer_span.transpose(0, 1)[0])
            loss_p2 = loss_fn(p2, answer_span.transpose(0, 1)[1])
            running_loss += loss_p1.item()
            running_loss += loss_p2.item()
            optimizer.zero_grad()
            (loss_p1 + loss_p2).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
            optimizer.step()
            if args.ema:
                # NOTE(review): assigns the smoothed value back into the
                # live parameter each step (weights themselves averaged).
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        param.data = ema(name, param.data)
            count += 1
            if count % batch_step == 0:
                rep_str = '[{}] Epoch {}, loss: {:.3f}'
                print(rep_str.format(datetime.datetime.now().strftime('%Y%m%d-%H%M%S'), epoch, running_loss / batch_step))
                # info = {'loss': running_loss / batch_step}
                running_loss = 0.0
                count = 0
                # # 1. Log scalar values (scalar summary)
                # for tag, value in info.items():
                #     logger.scalar_summary(tag, value, step + 1)
                # # 2. Log values and gradients of the parameters (histogram summary)
                # for tag, value in model.named_parameters():
                #     tag = tag.replace('.', '/')
                #     logger.histo_summary(tag, value.data.cpu().numpy(), step + 1)
                #     logger.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), step + 1)
            step += 1
            # break
        # Validation set
        if args.with_valid:
            # NOTE(review): this branch calls gen_elmo_by_text with
            # question_hdf5_f / article_hdf5_f, whose definitions are
            # commented out above — running with with_valid=True raises
            # NameError as written; confirm intended usage.
            print('======== Epoch {} result ========'.format(epoch))
            print("len(db.valid_loader):", len(db.valid_loader))
            valid_result = []
            idx = 0
            for article, question, _ in db.valid_loader:
                if args.cuda:
                    article, question = article.cuda(), question.cuda()
                tmp_valid_raw_article_list = valid_raw_article_list[idx:idx + question.size()[0]]
                tmp_valid_raw_question_list = valid_raw_question_list[idx:idx + question.size()[0]]
                question_elmo = gen_elmo_by_text(question_hdf5_f, tmp_valid_raw_question_list, args.max_question_len)
                article_elmo = gen_elmo_by_text(article_hdf5_f, tmp_valid_raw_article_list, args.max_article_len)
                if args.cuda:
                    question_elmo = question_elmo.cuda()
                    article_elmo = article_elmo.cuda()
                p1, p2 = model(article, question, article_elmo, question_elmo, is_training=False)
                _, p1_predicted = torch.max(p1.cpu().data, 1)
                _, p2_predicted = torch.max(p2.cpu().data, 1)
                p1_predicted = p1_predicted.numpy().tolist()
                p2_predicted = p2_predicted.numpy().tolist()
                assert question.size()[0] == len(p1_predicted)
                for _p1, _p2, _raw_article, _answer in zip(p1_predicted, p2_predicted, valid_raw_article_list[idx:idx + len(p1_predicted)], valid_answer_list[idx:idx + len(p1_predicted)]):
                    valid_result.append({
                        "ref_answer": _answer,
                        "cand_answer": "".join(_raw_article[_p1:_p2 + 1])
                    })
                idx = idx + len(p1_predicted)
            rouge_score = test_score(valid_result)
            info = {'rouge_score': rouge_score}
            for tag, value in info.items():
                logger.scalar_summary(tag, value, epoch + 1)
        #lr = init_lr
        lr = max(0.00001, init_lr * 0.9**(epoch + 1))  # decide whether to keep this decay schedule
        print("lr:", lr)
        parameters = filter(lambda param: param.requires_grad, model.parameters())
        optimizer = torch.optim.Adam(params=parameters, lr=lr, weight_decay=weight_decay)
        # print(len(db.valid_loader))
        if epoch >= 0 and args.saved_model_file:
            torch.save(model.state_dict(), args.saved_model_file + "_epoch_" + str(epoch))
            print("saved model")
def train(args, data):
    """Sequentially train a (possibly pre-loaded) BiDAF model over several
    datasets, evaluating each against its own dev set and finally testing
    the best snapshot on the last test set.

    Args:
        args: namespace with load_model, gpu, exp_decay_rate, model_name,
            and per-dataset lists dev_files, test_files, print_freq,
            learning_rate, epoch.
        data: wrapper exposing parallel lists train_iter, dev_iter,
            test_iter and train.

    Returns:
        The deep-copied model snapshot with the best dev F1 (``None`` if
        no dev evaluation ever improved on -1 — see NOTE below).
    """
    if args.load_model != "":
        model = BiDAF(args, data.WORD.vocab.vectors)
        model.load_state_dict(torch.load(args.load_model))
    else:
        model = BiDAF(args, data.WORD.vocab.vectors)
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)
    # Diagnostic: report any parameters that are not leaf tensors.
    for name, i in model.named_parameters():
        if not i.is_leaf:
            print(name, i)
    writer = SummaryWriter(log_dir='runs/' + args.model_name)
    # NOTE(review): best_model stays None if no dev evaluation ever runs /
    # improves; the final test() call below would then receive None.
    best_model = None
    # One pass per dataset: each gets its own iterator, dev set, print
    # frequency, learning rate and epoch budget.
    for iterator, dev_iter, dev_file_name, index, print_freq, lr in zip(data.train_iter, data.dev_iter, args.dev_files, range(len(data.train)), args.print_freq, args.learning_rate):
        # print
        # (iterator[0]) embed() exit(0)
        optimizer = optim.Adadelta(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        model.train()
        loss, last_epoch = 0, 0
        max_dev_exact, max_dev_f1 = -1, -1
        print(f"Training with {dev_file_name}")
        print()
        for i, batch in tqdm(enumerate(iterator), total=len(iterator) * args.epoch[index], ncols=100):
            present_epoch = int(iterator.epoch)
            eva = False
            if present_epoch == args.epoch[index]:
                break
            if present_epoch > last_epoch:
                # Also force a dev evaluation at each epoch boundary.
                print('epoch:', present_epoch + 1)
                eva = True
            last_epoch = present_epoch
            p1, p2 = model(batch)
            optimizer.zero_grad()
            batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
            loss += batch_loss.item()
            batch_loss.backward()
            optimizer.step()
            # Fold updated weights into the EMA shadows.
            for name, param in model.named_parameters():
                if param.requires_grad:
                    ema.update(name, param.data)
            torch.cuda.empty_cache()
            if (i + 1) % print_freq == 0 or eva:
                dev_loss, dev_exact, dev_f1 = test(model, ema, args, data, dev_iter, dev_file_name)
                c = (i + 1) // print_freq
                writer.add_scalar('loss/train', loss, c)
                writer.add_scalar('loss/dev', dev_loss, c)
                writer.add_scalar('exact_match/dev', dev_exact, c)
                writer.add_scalar('f1/dev', dev_f1, c)
                print()
                print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f}'
                      f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}')
                if dev_f1 > max_dev_f1:
                    max_dev_f1 = dev_f1
                    max_dev_exact = dev_exact
                    best_model = copy.deepcopy(model)
                # Reset accumulated loss; test() switched to eval mode.
                loss = 0
                model.train()
    writer.close()
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')
    print("testing with test batch on best model")
    test_loss, test_exact, test_f1 = test(best_model, ema, args, data, list(data.test_iter)[-1], args.test_files[-1])
    print(f'test loss: {test_loss:.3f}' f' / test EM: {test_exact:.3f} / test F1: {test_f1:.3f}')
    return best_model
def train(args, data):
    """Train a BiDAF model (SQuAD 2.0 variant) with EMA and periodic eval.

    Uses char- and word-level inputs explicitly, optionally wraps the model
    in nn.DataParallel on multi-GPU hosts, and evaluates every
    `args.print_freq` steps via the 7-value `test()` (overall / HasAns /
    NoAns metrics).

    Args:
        args: parsed CLI namespace (gpu, exp_decay_rate, learning_rate,
              model_time, epoch, print_freq, ...); `args.max_f1` is set as a
              side effect before returning.
        data: dataset wrapper exposing CONTEXT_WORD vocab vectors and a
              repeating train_iter.

    Returns:
        Deepcopy of the model snapshot with the best dev F1.
        NOTE(review): if dev F1 never exceeds -1, `best_model` is unbound and
        the final return raises NameError — confirm this cannot happen.
    """
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.CONTEXT_WORD.vocab.vectors).to(device)
    num = count_parameters(model)
    print(f'paramter {num}')  # NOTE(review): typo "paramter" in runtime string, left as-is
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    # EMA shadow over all trainable parameters; consumed by test().
    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()
    writer = SummaryWriter(log_dir='runs/' + args.model_time)
    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1
    print('totally {} epoch'.format(args.epoch))
    sys.stdout.flush()

    iterator = data.train_iter
    iterator.repeat = True  # cycle the iterator; epoch count tracked manually below
    for i, batch in enumerate(iterator):
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            print('present_epoch value:', present_epoch)
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
        last_epoch = present_epoch

        # Forward pass with explicit char/word tensors (word fields carry
        # (tokens, lengths) pairs); p1/p2 are start/end-index logits.
        p1, p2 = model(batch.c_char, batch.q_char,
                       batch.c_word[0], batch.q_word[0],
                       batch.c_word[1], batch.q_word[1])
        optimizer.zero_grad()
        batch_loss = criterion(p1, batch.s_idx) + criterion(p2, batch.e_idx)
        loss += batch_loss.item()  # running sum until next report
        batch_loss.backward()
        optimizer.step()
        # Update the EMA shadow after the optimizer step.
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.update(name, param.data)

        if (i + 1) % args.print_freq == 0:
            dev_loss, dev_exact, dev_f1, dev_hasans_exact, dev_hasans_f1, \
                dev_noans_exact, dev_noans_f1 = test(model, ema, args, data)
            c = (i + 1) // args.print_freq  # TensorBoard step index
            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('exact_match/dev', dev_exact, c)
            writer.add_scalar('f1/dev', dev_f1, c)
            print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f}'
                  f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}'
                  f' / dev hasans EM: {dev_hasans_exact} / dev hasans F1: {dev_hasans_f1}'
                  f' / dev noans EM: {dev_noans_exact} / dev noans F1: {dev_noans_f1}')
            if dev_f1 > max_dev_f1:
                max_dev_f1 = dev_f1
                max_dev_exact = dev_exact
                best_model = copy.deepcopy(model)
            loss = 0
            model.train()  # test() switches to eval mode
            sys.stdout.flush()
    writer.close()
    args.max_f1 = max_dev_f1  # expose the best F1 to the caller via args
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')
    return best_model
def train(args, data):
    """Train a BiDAF model for a fixed number of epochs, evaluating once per
    epoch (at the last batch) instead of every `print_freq` steps.

    Args:
        args: parsed CLI namespace (gpu, exp_decay_rate, learning_rate,
              model_time, epoch, print_freq, ...).
        data: dataset wrapper exposing WORD vocab vectors and train_iter.

    Returns:
        Deepcopy of the best-F1 model snapshot; falls back to the snapshot of
        the most recent training step if dev F1 never improves.
    """
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args, data.WORD.vocab.vectors).to(device)

    # EMA shadow over trainable parameters, used by test() for evaluation.
    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()
    writer = SummaryWriter(log_dir='runs/' + args.model_time)
    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1
    iterator = data.train_iter
    num_batch = len(iterator)
    for present_epoch in range(args.epoch):
        print('epoch', present_epoch + 1)
        for i, batch in enumerate(iterator):
            # Older epoch-tracking logic retained here, disabled:
            # present_epoch = int(iterator.epoch)
            """
            if present_epoch == args.epoch:
                print(present_epoch)
                print()
                print(args.epoch)
                break
            if present_epoch > last_epoch:
                print('epoch:', present_epoch + 1)
            last_epoch = present_epoch
            """
            p1, p2 = model(batch)
            optimizer.zero_grad()
            """
            print(p1)
            print()
            print(batch.s_idx)
            """
            # Guard against batch-size-1 outputs being squeezed to 1-D, which
            # CrossEntropyLoss would reject.
            if len(p1.size()) == 1:
                p1 = p1.reshape(1, -1)
            if len(p2.size()) == 1:
                p2 = p2.reshape(1, -1)
            batch_loss = criterion(p1, batch.s_idx) + criterion(
                p2, batch.e_idx)
            loss += batch_loss.item()
            batch_loss.backward()
            optimizer.step()
            # Update the EMA shadow after each optimizer step.
            for name, param in model.named_parameters():
                if param.requires_grad:
                    ema.update(name, param.data)
            # NOTE(review): deep-copying the whole model on EVERY iteration is
            # very expensive; it also guarantees best_model is bound even if
            # dev F1 never improves. Consider copying only on improvement.
            best_model = copy.deepcopy(model)
            if i + 1 == num_batch:
                # Evaluate at the end of each epoch.
                dev_loss, dev_exact, dev_f1 = test(model, ema, args, data)
                c = (i + 1) // args.print_freq  # TensorBoard step index
                writer.add_scalar('loss/train', loss / num_batch, c)
                writer.add_scalar('loss/dev', dev_loss, c)
                writer.add_scalar('exact_match/dev', dev_exact, c)
                writer.add_scalar('f1/dev', dev_f1, c)
                print(
                    f'train loss: {loss/num_batch:.3f} / dev loss: {dev_loss:.3f}'
                    f' / dev EM: {dev_exact:.3f} / dev F1: {dev_f1:.3f}')
                if dev_f1 > max_dev_f1:
                    max_dev_f1 = dev_f1
                    max_dev_exact = dev_exact
                    best_model = copy.deepcopy(model)
                loss = 0
                model.train()  # test() leaves the model in eval mode
    writer.close()
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')
    return best_model
class SemCo:
    """Semantic co-training (SemCo) semi-supervised trainer.

    Trains a backbone with two heads — a discrete classifier ("lower path")
    and a semantic label-embedding head ("upper path") — co-training the two
    on pseudo-labels from weak/strong augmentations, with an EMA shadow of
    the model used for evaluation and prediction.
    """

    def __init__(self, config, dataset_meta, device, L='dynamic', device_ids=None):
        """Build model, optimiser, EMA, label-embedding guessor and logging.

        config: parsed run configuration (parallel, no_amp, ema_alpha, ...).
        dataset_meta: dict with at least 'classes'; 'stats' (mean, std) is
            defaulted to ImageNet statistics if absent.
        device: torch device string/object for training.
        L: number of labelled instances (or 'dynamic'); used in log paths.
        device_ids: optional GPU ids for DataParallel.
        """
        self.config = config
        self.dataset_meta = dataset_meta
        if 'stats' not in dataset_meta:
            self.dataset_meta['stats'] = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)  # imagenet_stats
        self.device = device
        self.parallel = config.parallel
        self.device_ids = device_ids
        self.L = L
        self.label_emb_guessor, self.emb_dim = self._get_label_guessor()
        self.model = self._set_model(config.parallel, device_ids)
        self.optim = self._get_optimiser(self.config)
        if not self.config.no_amp:
            # apex mixed precision (O1) — must wrap before DataParallel.
            from apex import amp
            self.model, self.optim = amp.initialize(self.model, self.optim, opt_level="O1")
        if self.config.parallel:
            self.model = nn.DataParallel(self.model)
        # initialise the exponential moving average model
        self.ema = EMA(self.model, self.config.ema_alpha)
        if self.config.use_pretrained:
            self.load_model_state(config.checkpoint_path)
        if self.config.freeze_backbone:
            self._freeze_model_backbone()
        self.logger, self.writer, self.time_stamp = self._setup_default_logging()

    def train(self, labelled_data, valid_data=None, training_config=None, save_best_model=False):
        """
        SemCo training function.
        labelled_data: dictionary holding labelled data in the form {'train/img1.png' : 'classA', ...}. This is relative to the dataset directory
        valid_data: dictionary holding validation data in the form {'train/img1.png' : 'classA', ...}. This is relative to the dataset directory
        training_config: to override parser config entirely if needed.
        save_best_model: if true, the best model (model state, ema state, optimizer state, classes) will be saved under './saved_models' directory.
        """
        # to allow overriding training params for different runs of training
        if training_config is None:
            training_config = self.config
        else:
            self.config = training_config
        L = len(labelled_data)
        n_iters_per_epoch, n_iters_all = self._init_training(training_config, L)
        # define criterion for upper(semantic embedding) and lower(discrete label) paths
        crit_lower = lambda inp, targ: F.cross_entropy(inp, targ, reduction='none')
        crit_upper = lambda inp, targ: 1 - F.cosine_similarity(inp, targ)
        optim = self.optim
        if n_iters_all == 0:
            n_iters_all = 1  # to avoid division by zero if we choose to set epochs to zero to skip a round
        lr_schdlr = WarmupCosineLrScheduler(optim, max_iter=n_iters_all, warmup_iter=0)
        num_workers = 0 if self.device == 'cpu' else training_config.num_workers_per_gpu * torch.cuda.device_count() if self.parallel else 4
        dltrain_x, dltrain_u = self._get_train_loaders(labelled_data, n_iters_per_epoch, num_workers,
                                                       pin_memory=True, cache_imgs=training_config.cache_imgs)
        print(f'Num of Labeled Training Data: {len(dltrain_x.dataset)}\nNum of Unlabeled Training Data:{len(dltrain_u.dataset)}')
        if valid_data:
            dlvalid = self._get_val_loader(valid_data, num_workers, pin_memory=True,
                                           cache_imgs=training_config.cache_imgs)
            print(f'Num of Validation Data: {len(dlvalid.dataset)}')
        train_args = dict(n_iters=n_iters_per_epoch, optim=optim, crit_lower=crit_lower,
                          crit_upper=crit_upper, lr_schdlr=lr_schdlr,
                          dltrain_x=dltrain_x, dltrain_u=dltrain_u)
        best_acc = -1
        best_epoch = 0
        best_loss = 1e6
        early_stopping_counter = 0
        # Early-stopping metric: higher-is-better accuracy or lower-is-better loss.
        best_metric = best_acc if training_config.es_metric == 'accuracy' else best_loss
        self.logger.info('-----------start training--------------')
        epochs_iterator = range(training_config.n_epoches) if not self.config.no_progress_bar else \
            tqdm(range(training_config.n_epoches), desc='Epoch')  # so that it displays the bar per epoch not per iteration
        for epoch in epochs_iterator:
            # training starts here
            train_loss, loss_x, loss_u, mask_mean, \
            loss_emb_x, loss_emb_u, mask_emb, mask_combined = \
                self._train_one_epoch(epoch, **train_args)
            if valid_data:
                top1, top5, valid_loss, top1_emb, top5_emb, top1_combined = self._evaluate(dlvalid, crit_lower)
            if valid_data:
                self.writer.add_scalars('train/1.loss', {'train': train_loss, 'test': valid_loss}, epoch)
            else:
                self.writer.add_scalar('train/1.loss', train_loss, epoch)
            self.writer.add_scalar('train/2.train_loss_x', loss_x, epoch)
            self.writer.add_scalar('train/2.train_loss_emb_x', loss_emb_x, epoch)
            self.writer.add_scalar('train/3.train_loss_u', loss_u, epoch)
            self.writer.add_scalar('train/3.train_loss_emb_u', loss_emb_u, epoch)
            self.writer.add_scalar('train/5.mask_mean', mask_mean, epoch)
            self.writer.add_scalar('train/5.mask_emb_mean', mask_emb, epoch)
            self.writer.add_scalar('train/5.mask_combined_mean', mask_combined, epoch)
            if valid_data:
                self.writer.add_scalars('test/1.test_acc', {'top1': top1, 'top5': top5,
                                                            'top1_emb': top1_emb, 'top5_emb': top5_emb,
                                                            'top1_combined': top1_combined}, epoch)
                best_current = top1 if training_config.es_metric == 'accuracy' else valid_loss
                # only start looking for best model after min_wait period has expired
                if epoch >= training_config.min_wait_before_es:
                    # NOTE(review): despite the name, this returns True when the
                    # current value is at least as good as the best so far.
                    isworse = lambda best, current: best <= current if training_config.es_metric == 'accuracy' else best >= current
                    if isworse(best_metric, best_current):
                        best_metric = best_current
                        best_epoch = epoch
                        if training_config.early_stopping_epochs:
                            # Snapshot weights + EMA shadow for possible rollback.
                            best_model_state = self.model.state_dict()
                            best_ema_state = {k: v.clone().detach() for k, v in self.ema.shadow.items()}
                        early_stopping_counter = 0
                        if save_best_model:
                            try:
                                self._save_checkpoint()
                            except Exception as e:
                                print(f'Failed to save checkpoint: {e}')
                    elif training_config.early_stopping_epochs:
                        early_stopping_counter += 1
                else:
                    print('Minimum wait period still not expired. Leaving best epoch and best metric to default values')
                self.logger.info(
                    "Epoch {}. Top1: {:.4f}. Top5: {:.4f}. Top1_emb: {:.4f}. Top5_emb: {:.4f}. Top1_comb: {:.4f}. best_metric: {:.4f} in epoch{}".
                    format(epoch, top1, top5, top1_emb, top5_emb, top1_combined, best_metric, best_epoch))
                # check if early stopping is to be activated
                if training_config.early_stopping_epochs and early_stopping_counter == training_config.early_stopping_epochs:
                    self.logger.info(f"Early stopping activated, loading best models and ending training. "
                                     f"{training_config.early_stopping_epochs} epochs with no improvement.")
                    self.model.load_state_dict(best_model_state)
                    self.ema.shadow = best_ema_state
                    break
            # this will only be activated in last epoch to decide whether best model should be loaded or not before ending training
            if epoch == training_config.n_epoches - 1:
                self.logger.info(f"Break epoch is reached")
                # in case early stopping is configured, load best model before exiting (edge case for early stopping)
                if training_config.early_stopping_epochs and valid_data and epoch >= training_config.min_wait_before_es:
                    self.logger.info(f"Loading best model and ending training (since early stopping is set)")
                    self.model.load_state_dict(best_model_state)
                    self.ema.shadow = best_ema_state
        self.writer.close()

    def predict(self):
        """Predict classes for the test split using the EMA weights.

        Returns a pandas DataFrame with 'id' (file name) and 'class' columns.
        """
        num_work = 0 if self.device == 'cpu' else 4
        dataloader = self._get_test_loader(num_work, pin_memory=True, cache_imgs=self.config.cache_imgs)
        # using EMA params to evaluate performance
        self.ema.apply_shadow()
        self.ema.model.eval()
        self.ema.model.to(self.device)
        predictions = []
        with torch.no_grad():
            for ims in dataloader:
                ims = ims.to(self.device)
                logits, _, _ = self.ema.model(ims)
                probs = torch.softmax(logits, dim=1)
                scores, lbs_guess = torch.max(probs, dim=1)
                predictions.append(lbs_guess)
        predictions = torch.cat(predictions).cpu().detach().numpy()
        predictions = [self.dataset_meta['classes'][elem] for elem in predictions]
        filenames = [name.split('/')[-1] for name in dataloader.dataset.data]
        df = pd.DataFrame({'id': filenames, 'class': predictions})
        # note roll back model current params to continue training
        self.ema.restore()
        return df

    def load_model_state(self, chkpt_dict_path):
        '''
        Loads model state based on a checkpoint saved by SemCo _save_checkpoint() function.
        '''
        print("Loading Model State")
        checkpoint_dict = torch.load(chkpt_dict_path, map_location=self.device)
        if 'model_state_dict' in checkpoint_dict:
            state_dict = checkpoint_dict['model_state_dict']
        else:
            print('model_state_dict key is not present in checkpoint, loading pretrained model failed, using original initialization for model')
            return
        # handle state_dictionaries where keys has 'module' in them (if the model was wrapped in nn.DataParallel)
        if all(['module' in key for key in state_dict.keys()]):
            if all(['module' in key for key in self.model.state_dict()]):
                pass
            else:
                # Strip the DataParallel prefix from both model and EMA keys.
                state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
                if 'ema_shadow' in checkpoint_dict:
                    checkpoint_dict['ema_shadow'] = {k.replace('module.', ''): v
                                                     for k, v in checkpoint_dict['ema_shadow'].items()}
        try:
            self.model.load_state_dict(state_dict)
        except Exception as e:
            # Fall back to loading only the params whose names and shapes match.
            print(f'Problem occurred during naive state_dict loading: {e}.\nTrying to only load common params')
            try:
                model_state = self.model.state_dict()
                pretrained_state = {k: v for k, v in state_dict.items()
                                    if k in model_state and v.size() == model_state[k].size()}
                unloaded_state = set(list(state_dict.keys())) - set(list(model_state.keys()))
                model_state.update(pretrained_state)
                self.model.load_state_dict(model_state)
                print(f'Success. Following params in pretrained_state_dict were not loaded: {unloaded_state}')
            except Exception as e:
                print(f'Unable to load model state due to following error. Model will be initialised randomly. \n {e}')
        if 'ema_shadow' in checkpoint_dict:
            try:
                # Re-create the EMA around the freshly loaded weights, then
                # overwrite matching shadow entries from the checkpoint.
                self.ema = EMA(self.model, self.config.ema_alpha)
                similar_params = {k: v for k, v in checkpoint_dict['ema_shadow'].items()
                                  if k in self.ema.shadow and v.size() == self.ema.shadow[k].size()}
                self.ema.shadow.update(similar_params)
                print(f'EMA shadow has been loaded successfully. {len(similar_params)} out of {len(self.ema.shadow)} params were loaded')
            except Exception as e:
                print(f'Unable to load EMA shadow. EMA will be reinitialised with current model params. {e}')
                self.ema = EMA(self.model, self.config.ema_alpha)
        else:
            print('EMA shadow is not found in checkpoint dictionary. EMA will be reinitialised with current model params.')
            self.ema = EMA(self.model, self.config.ema_alpha)
        try:
            if 'classes' in checkpoint_dict:
                # Sanity check: checkpoint class ordering must match the dataset's.
                classes = self.dataset_meta['classes']
                classes_model = checkpoint_dict['classes']
                if all([classes_model[i] == classes[i] for i in range(len(classes))]):
                    print(f'classes matched successfully')
                else:
                    print(
                        "Classes loaded don't match the classes used while training the model, output of softmax can't be trusted")
        except Exception as e:
            print("can't load classes file. Pls check and try again.")
        return

    def adapt(self, num_classes):
        '''
        To allow adapting the model to a different dataset with the same semantic classifier weights
        num_classes: number of classes in the target dataset
        return: None
        '''
        if isinstance(self.model, torch.nn.DataParallel):
            self.model.module.adapt(num_classes)
        else:
            self.model.adapt(num_classes)
        self.model.to(self.device)
        # EMA must be rebuilt since the parameter set changed.
        self.ema = EMA(self.model, self.config.ema_alpha)

    def _evaluate(self, dataloader, criterion):
        """Evaluate EMA weights on `dataloader`.

        Returns (top1, top5, loss, top1_emb, top5_emb, top1_combined) averages
        for the classifier head, the embedding head, and their score average.
        """
        # using EMA params to evaluate performance
        self.ema.apply_shadow()
        self.ema.model.eval()
        self.ema.model.to(self.device)
        loss_meter = AverageMeter()
        top1_meter = AverageMeter()
        top5_meter = AverageMeter()
        top1_emb_meter = AverageMeter()
        top5_emb_meter = AverageMeter()
        top1_combined_meter = AverageMeter()
        with torch.no_grad():
            for ims, lbs in dataloader:
                ims = ims.to(self.device)
                lbs = lbs.to(self.device)
                logits, logits_emb, _ = self.ema.model(ims)
                # Cosine similarity of embedding-head output against every
                # class embedding, sharpened before softmax.
                sim = F.cosine_similarity(logits_emb.unsqueeze(1),
                                          self.label_emb_guessor.embedding_matrix.unsqueeze(0), dim=-1)
                sim = sim * self.label_emb_guessor.sharpening_factor
                loss = criterion(logits, lbs).mean()
                scores_emb = torch.softmax(sim, -1)
                scores = torch.softmax(logits, dim=1)
                top1, top5 = accuracy(scores, lbs, (1, 5))
                top1_emb, top5_emb = accuracy(scores_emb, lbs, (1, 5))
                scores_combined = torch.mean(torch.stack([scores_emb, scores]), dim=0)
                top1_combined, _ = accuracy(scores_combined, lbs, (1, 5))
                loss_meter.update(loss.item())
                top1_meter.update(top1.item())
                top5_meter.update(top5.item())
                top1_emb_meter.update(top1_emb.item())
                top5_emb_meter.update(top5_emb.item())
                top1_combined_meter.update(top1_combined.item())
        # note roll back model current params to continue training
        self.ema.restore()
        return top1_meter.avg, top5_meter.avg, loss_meter.avg, top1_emb_meter.avg, top5_emb_meter.avg, top1_combined_meter.avg

    def _set_model(self, parallel, device_ids):
        """Instantiate the backbone+heads per config (or infer from im_size)."""
        classes = self.dataset_meta['classes']
        n = len(classes)
        if self.config.model_backbone is not None:
            if self.config.model_backbone == 'wres':
                model = WideResnetWithEmbeddingHead(num_classes=n, k=self.config.wres_k, n=28, emb_dim=self.emb_dim)
            elif self.config.model_backbone == 'resnet18':
                model = ResNet18WithEmbeddingHead(num_classes=n, emb_dim=self.emb_dim,
                                                  pretrained=not self.config.no_imgnet_pretrained)
            elif self.config.model_backbone == 'resnet50':
                model = ResNet50WithEmbeddingHead(num_classes=n, emb_dim=self.emb_dim,
                                                  pretrained=not self.config.no_imgnet_pretrained)
        # if no backbone is passed in args, auto infer based on im size
        elif self.config.im_size <= 64:
            model = WideResnetWithEmbeddingHead(num_classes=n, k=self.config.wres_k, n=28, emb_dim=self.emb_dim)
        else:
            model = ResNet50WithEmbeddingHead(num_classes=n, emb_dim=self.emb_dim,
                                              pretrained=not self.config.no_imgnet_pretrained)
        model.to(self.device)
        return model

    def _freeze_model_backbone(self):
        """Freeze everything except the embedding and classifier heads."""
        for name, param in self.model.named_parameters():
            if 'fc_emb' in name or 'fc_classes' in name:
                param.requires_grad = True
                print(f'{name} parameter is unfrozen')
            else:
                param.requires_grad = False
        print('All remaining parameters are frozen.')

    def _train_one_epoch(self, epoch, n_iters, optim, crit_lower, crit_upper,
                         lr_schdlr, dltrain_x, dltrain_u):
        """Run one co-training epoch; returns averaged loss/mask statistics."""
        # note: _x denotes supervised and _u denotes unsupervised
        # note: when suffix '_emb' is appended to variable, it denotes same variable but for upper path
        # Renaming for consistency
        criteria_x = crit_lower
        criteria_u = crit_lower
        criteria_x_emb = crit_upper
        criteria_u_emb = crit_upper
        if not self.config.no_amp:
            from apex import amp
        self.model.train()
        loss_meter = AverageMeter()
        loss_x_meter = AverageMeter()
        loss_u_meter = AverageMeter()
        loss_emb_x_meter = AverageMeter()
        loss_emb_u_meter = AverageMeter()
        # the number of gradient-considered strong augmentation (logits above threshold) of unlabeled samples
        n_strong_aug_meter = AverageMeter()
        max_score = AverageMeter()
        max_score_emb = AverageMeter()
        mask_meter = AverageMeter()
        mask_emb_meter = AverageMeter()
        mask_combined_meter = AverageMeter()
        epoch_start = time.time()  # start time
        dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
        iterator = range(n_iters) if self.config.no_progress_bar else tqdm(range(n_iters), desc='Epoch {}'.format(epoch))
        for it in iterator:
            ims_x_weak, ims_x_strong, lbs_x = next(dl_x)
            ims_u_weak, ims_u_strong = next(dl_u)
            lbs_x = lbs_x.to(self.device)
            bt = ims_x_weak.size(0)
            mu = int(ims_u_weak.size(0) // bt)  # unlabelled-to-labelled batch ratio
            # Interleave labelled/unlabelled samples so BN statistics mix.
            imgs = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0).to(self.device)
            imgs = interleave(imgs, 2 * mu + 1)
            logits, logits_emb, _ = self.model(imgs)
            del imgs
            logits = de_interleave(logits, 2 * mu + 1)
            logits_x = logits[:bt]
            logits_u_w, logits_u_s = torch.split(logits[bt:], bt * mu)
            del logits
            logits_emb = de_interleave(logits_emb, 2 * mu + 1)
            logits_emb__x = logits_emb[:bt]
            logits_emb_u_w, logits_emb_u_s = torch.split(logits_emb[bt:], bt * mu)
            del logits_emb
            # supervised loss for upper and lower paths
            loss_x = criteria_x(logits_x, lbs_x).mean()
            loss_x_emb = criteria_x_emb(logits_emb__x, self.label_emb_guessor.embedding_matrix[lbs_x]).mean()
            # guessing the labels for upper and lower paths
            with torch.no_grad():
                probs = torch.softmax(logits_u_w, dim=1)
                scores, lbs_u_guess = torch.max(probs, dim=1)
                mask = scores.ge(self.config.thr).float()
                # get label guesses and mask based on embedding predictions (upper path)
                lbs_emb_u_guess, mask_emb, scores_emb, lbs_guess_help = self.label_emb_guessor(logits_emb_u_w)
                # combining the losses via co-training (blind version)
                mask_combined = mask.bool() | mask_emb.bool()
            # each loss path will have two components (co-training implementation)
            loss_u = (criteria_u(logits_u_s, lbs_u_guess) * mask).mean() + \
                     (criteria_u(logits_u_s, lbs_guess_help) * mask_emb).mean() * (self.config.lambda_emb) / 3
            loss_u_emb = (criteria_u_emb(logits_emb_u_s, lbs_emb_u_guess) * mask_emb).mean() + \
                         (criteria_u_emb(logits_emb_u_s, self.label_emb_guessor.embedding_matrix[lbs_u_guess]) * mask).mean()
            loss_lower = loss_x + self.config.lam_u * loss_u
            loss_upper = loss_x_emb + self.config.lam_u * loss_u_emb
            loss = loss_lower + self.config.lambda_emb * loss_upper
            optim.zero_grad()
            if not self.config.no_amp:
                with amp.scale_loss(loss, optim) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optim.step()
            self.ema.update_params()
            lr_schdlr.step()
            loss_meter.update(loss.item())
            loss_x_meter.update(loss_x.item())
            loss_u_meter.update(loss_u.item())
            mask_meter.update(mask.mean().item())
            n_strong_aug_meter.update(mask_emb.sum().item())
            max_score.update(scores.mean())
            max_score_emb.update(scores_emb.mean())
            loss_emb_x_meter.update(loss_x_emb.item())
            loss_emb_u_meter.update(loss_u_emb.item())
            mask_combined_meter.update(mask_combined.float().mean().item())
            mask_emb_meter.update(mask_emb.mean().item())
            if (it + 1) % 512 == 0:
                # Periodic progress log with averaged learning rate.
                t = time.time() - epoch_start
                lr_log = [pg['lr'] for pg in optim.param_groups]
                lr_log = sum(lr_log) / len(lr_log)
                self.logger.info("epoch:{}, iter: {}. loss: {:.4f}. loss_u: {:.4f}. loss_x: {:.4f}. max_score:{:.4f}. "
                                 " Mask:{:.4f} loss_u_emb:{:.4f}. loss_x_emb:{:.4f}. mask_emb:{:.4f}. max_score_emb:{:.4f}. mask_emb_count:{:.4f}. mask_combined:{:.4f}. . LR: {:.4f}. Time: {:.2f}".format(
                    epoch, it + 1, loss_meter.avg, loss_u_meter.avg, loss_x_meter.avg, max_score.avg,
                    mask_meter.avg, loss_emb_u_meter.avg, loss_emb_x_meter.avg, mask_emb_meter.avg,
                    max_score_emb.avg, n_strong_aug_meter.avg, mask_combined_meter.avg, lr_log, t))
                epoch_start = time.time()
        self.ema.update_buffer()
        return loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg, mask_meter.avg, \
               loss_emb_x_meter.avg, loss_emb_u_meter.avg, mask_emb_meter.avg, mask_combined_meter.avg

    def _get_train_loaders(self, labelled_data, n_iters_per_epoch, num_workers, pin_memory, cache_imgs):
        """Build the labelled and unlabelled training dataloaders."""
        mean, std = self.dataset_meta['stats']
        kwargs = dict(dataset_path=self.config.dataset_path, classes=self.dataset_meta['classes'],
                      labelled_data=labelled_data, batch_size=self.config.batch_size, mu=self.config.mu,
                      n_iters_per_epoch=n_iters_per_epoch, size=self.config.im_size,
                      cropsize=self.config.cropsize, mean=mean, std=std, num_workers=num_workers,
                      pin_memory=pin_memory, cache_imgs=cache_imgs)
        return get_train_loaders(**kwargs)

    def _get_val_loader(self, valid_data, num_workers, pin_memory, cache_imgs):
        """Build the validation dataloader (3x train batch size)."""
        mean, std = self.dataset_meta['stats']
        kwargs = dict(dataset_path=self.config.dataset_path, classes=self.dataset_meta['classes'],
                      labelled_data=valid_data, batch_size=3 * self.config.batch_size,
                      size=self.config.im_size, cropsize=self.config.cropsize, mean=mean, std=std,
                      num_workers=num_workers, pin_memory=pin_memory, cache_imgs=cache_imgs)
        return get_val_loader(**kwargs)

    def _get_test_loader(self, num_workers, pin_memory, cache_imgs):
        """Build the unlabelled test dataloader (3x train batch size)."""
        mean, std = self.dataset_meta['stats']
        kwargs = dict(dataset_path=self.config.dataset_path, classes=self.dataset_meta['classes'],
                      batch_size=3 * self.config.batch_size, size=self.config.im_size,
                      cropsize=self.config.cropsize, mean=mean, std=std, num_workers=num_workers,
                      pin_memory=pin_memory, cache_imgs=cache_imgs)
        return get_test_loader(**kwargs)

    def _get_label_guessor(self):
        """Build the label-embedding guessor; returns (guessor, emb_dim)."""
        classes = self.dataset_meta['classes']
        class_2_embeddings_dict = get_labels2wv_dict(classes, self.config.word_vec_path)
        emb_dim = len(list(class_2_embeddings_dict.values())[0])
        if self.config.eps is None:
            # Default DBSCAN-style eps scaled to embedding dimensionality.
            eps = 0.15 if emb_dim < 100 else 0.2 if emb_dim < 256 else 0.28  # for label grouping clustering
        else:
            eps = self.config.eps
        label_group_idx, gr_mapping = get_grouping(class_2_embeddings_dict, eps=eps, return_mapping=True)
        label_guessor = LabelEmbeddingGuessor(classes, label_group_idx, class_2_embeddings_dict,
                                              self.config.thr_emb, self.device)
        return label_guessor, emb_dim

    def _setup_default_logging(self, default_level=logging.INFO):
        """Configure file+console logging and a TensorBoard writer.

        Returns (logger, writer, time_stamp).
        """
        format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
        dataset_name = get_dataset_name(self.config.dataset_path)
        output_dir = os.path.join(dataset_name, f'x{self.L}')
        os.makedirs(output_dir, exist_ok=True)
        writer = SummaryWriter(comment=f'{dataset_name}_{self.L}')
        logger = logging.getLogger('train')
        logger.setLevel(default_level)
        time_stamp = time_str()
        logging.basicConfig(  # unlike the root logger, a custom logger can't be configured using basicConfig()
            filename=os.path.join(output_dir, f'{time_stamp}_{self.L}_labelled_instances.log'),
            format=format, datefmt="%m/%d/%Y %H:%M:%S", level=default_level)
        # to avoid double printing when creating new instances of class
        if not logger.handlers:
            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setLevel(default_level)
            console_handler.setFormatter(logging.Formatter(format))
            logger.addHandler(console_handler)
        # logger.info(dict(self.config._get_kwargs()))
        if self.device != 'cpu':
            logger.info(f'Device used: {self.device}_{torch.cuda.get_device_name(self.device)}')
        logger.info(f'Model: {self.model.module.__class__ if isinstance(self.model, torch.nn.DataParallel) else self.model.__class__}')
        logger.info(f'Num_labels: {self.L}')
        logger.info(f'Image_size: {self.config.im_size}')
        logger.info(f'Cropsize: {self.config.cropsize}')
        logger.info("Total params: {:.2f}M".format(
            sum(p.numel() for p in self.model.parameters()) / 1e6))
        return logger, writer, time_stamp

    def _init_training(self, training_config, L):
        """Seed RNGs and compute (iters_per_epoch, total_iters); logs a summary."""
        n_iters_per_epoch = training_config.n_imgs_per_epoch // training_config.batch_size
        n_iters_all = n_iters_per_epoch * training_config.n_epoches
        if training_config.seed > 0:
            torch.manual_seed(training_config.seed)
            random.seed(training_config.seed)
            np.random.seed(training_config.seed)
        self.logger.info("***** Running training *****")
        self.logger.info(f" Num Epochs = {training_config.n_epoches}")
        self.logger.info(f" Early Stopping Epochs Patience = "
                         f"{training_config.early_stopping_epochs if training_config.early_stopping_epochs else None}")
        self.logger.info(f" Minimum Wait before ES = {training_config.min_wait_before_es} epochs")
        self.logger.info(f" Batch size Labelled = {training_config.batch_size}")
        self.logger.info(f" Total optimization steps = {n_iters_all}")
        return n_iters_per_epoch, n_iters_all

    def _get_optimiser(self, training_config):
        """SGD with nesterov momentum; batch-norm params get no weight decay."""
        # set weight decay to zero for batch-norm layers
        wd_params, non_wd_params = [], []
        for name, param in self.model.named_parameters():
            if 'bn' in name:
                non_wd_params.append(param)  # bn.weight, bn.bias and classifier.bias
            else:
                wd_params.append(param)
        param_list = [{'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
        optim = torch.optim.SGD(param_list, lr=training_config.lr,
                                weight_decay=training_config.weight_decay,
                                momentum=training_config.momentum, nesterov=True)
        return optim

    def _save_checkpoint(self):
        """Save model state, EMA shadow and class list under ./saved_models."""
        save_dir = 'saved_models'  # os.path.abspath(os.path.join(self.config.checkpoint_path, os.pardir))
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        dataset_name = get_dataset_name(self.config.dataset_path)
        model_name = self.model.module._get_name() if isinstance(self.model, torch.nn.DataParallel) else self.model._get_name()
        checkpoint = {'ema_shadow': self.ema.shadow,
                      'model_state_dict': self.model.state_dict(),
                      'classes': self.dataset_meta['classes']}
        fpath = f'{save_dir}/{model_name}_{dataset_name}_{self.time_stamp}_checkpoint_dict.pth'
        torch.save(checkpoint, fpath)
        self.logger.info(f'Model Saved in: {fpath}')
def train(args, data):
    """Train a BiDAF model over `data` in fixed-size index slices.

    Iterates the dataset in windows of `args.train_batch_size`, optimises the
    start/end-index cross-entropy losses with Adadelta, and maintains an EMA
    shadow of every trainable parameter.  Periodic dev evaluation and
    best-model tracking are currently disabled (the `test()` hook was
    commented out), so the function simply returns the trained model.

    Args:
        args: parsed CLI namespace (gpu, train_batch_size, exp_decay_rate,
              learning_rate, ...).
        data: dataset object exposing `data.data` (indexable examples) and
              `get_targ(b_id, e_id)` returning (start_idx, end_idx) targets.
              NOTE(review): the model is called as model(data, b_id, e_id) —
              confirm this BiDAF variant accepts raw slice bounds.

    Returns:
        The trained model (best-model selection is disabled; see above).
    """
    device = torch.device(
        f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = BiDAF(args).to(device)
    D_batch = args.train_batch_size

    # Register an EMA shadow for every trainable parameter.
    ema = EMA(args.exp_decay_rate)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate)
    criterion = nn.CrossEntropyLoss()
    model.train()
    loss, last_epoch = 0, -1
    max_dev_exact, max_dev_f1 = -1, -1

    # Bug fix: `best_model` was never assigned (every assignment lived in the
    # commented-out evaluation block), so `return best_model` raised
    # NameError.  With evaluation disabled, the trained model itself is the
    # best model we have.
    best_model = model

    i = 0
    while i + D_batch < len(data.data):
        b_id = i
        e_id = i + D_batch
        # p1/p2: start- and end-index logits for the current slice.
        p1, p2 = model(data, b_id, e_id)
        optimizer.zero_grad()
        s_idx, e_idx = data.get_targ(b_id, e_id)
        batch_loss = criterion(p1, s_idx) + criterion(p2, e_idx)
        loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
        # Fold the updated weights into the EMA shadow.
        for name, param in model.named_parameters():
            if param.requires_grad:
                ema.update(name, param.data)
        i += D_batch

    # Dev metrics stay at their sentinel values while evaluation is disabled.
    print(f'max dev EM: {max_dev_exact:.3f} / max dev F1: {max_dev_f1:.3f}')
    return best_model