def kFoldTrain(self):
    """Run one epoch of k-fold cross-validation and return mean metrics.

    Returns:
        Tuple (train loss, train acc, valid loss, valid acc), each averaged
        over the ``self.k_fold`` folds.
    """
    cfg = self.cfg
    # Reshuffle before slicing so every epoch sees a different fold split.
    self.train_dataset.shuffle()
    folds = self.make_folds()
    # Running averages for (loss_t, acc_t, loss_v, acc_v).
    totals = [0.0, 0.0, 0.0, 0.0]
    for k in range(self.k_fold):
        # Fold k is held out for validation; all other folds form the train set.
        logger.info("=" * 5 + " Epoch %d - Fold %d " % (self.epoch, k) + "=" * 5)
        valid = folds[k]
        train = [sample for i, fold in enumerate(folds) if i != k
                 for sample in fold]
        train_dataloader = DataLoader(train, batch_size=cfg.batch_size,
                                      shuffle=True, collate_fn=collate_fn(cfg))
        valid_dataloader = DataLoader(valid, batch_size=cfg.batch_size,
                                      shuffle=True, collate_fn=collate_fn(cfg))
        fold_metrics = self.train(train_dataloader) + self.validate(valid_dataloader)
        for idx, value in enumerate(fold_metrics):
            totals[idx] += value / self.k_fold
    return tuple(totals)
def get_data_loaders_new(args, tokenizer):
    """Build the AVSD train/valid DataLoaders.

    Each dataset receives its own split's features first and the other
    split's features second, so both can resolve any feature id.
    """
    train_data = get_dataset(tokenizer, args.train_path, args.fea_path,
                             n_history=args.max_history)
    valid_data = get_dataset(tokenizer, args.valid_path, args.fea_path,
                             n_history=args.max_history)
    train_dataset = AVSDDataSet(train_data[0], tokenizer,
                                (train_data[1], valid_data[1]),
                                drop_rate=0, train=True)
    valid_dataset = AVSDDataSet(valid_data[0], tokenizer,
                                (valid_data[1], train_data[1]),
                                drop_rate=0, train=False)

    def pad_batch(batch):
        # Pad with the tokenizer's pad id and keep the video features.
        return collate_fn(batch, tokenizer.pad_token_id, features=True)

    train_loader = DataLoader(train_dataset,
                              shuffle=(not args.distributed),
                              batch_size=args.train_batch_size,
                              num_workers=4,
                              collate_fn=pad_batch)
    valid_loader = DataLoader(valid_dataset,
                              shuffle=False,
                              batch_size=args.valid_batch_size,
                              num_workers=4,
                              collate_fn=pad_batch)
    return train_loader, valid_loader
def main(args):
    """Train an RCNN text classifier and optionally evaluate on the test set."""
    model = RCNN(vocab_size=args.vocab_size,
                 embedding_dim=args.embedding_dim,
                 hidden_size=args.hidden_size,
                 hidden_size_linear=args.hidden_size_linear,
                 class_num=args.class_num,
                 dropout=args.dropout).to(args.device)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model, dim=0)

    train_texts, train_labels = read_file(args.train_file_path)
    word2idx = build_dictionary(train_texts, vocab_size=args.vocab_size)
    logger.info('Dictionary Finished!')

    full_dataset = CustomTextDataset(train_texts, train_labels, word2idx)
    num_train_data = len(full_dataset) - args.num_val_data
    train_dataset, val_dataset = random_split(
        full_dataset, [num_train_data, args.num_val_data])

    def make_loader(ds):
        # All splits share the same batching/collation settings.
        return DataLoader(dataset=ds,
                          collate_fn=lambda x: collate_fn(x, args),
                          batch_size=args.batch_size,
                          shuffle=True)

    train_dataloader = make_loader(train_dataset)
    valid_dataloader = make_loader(val_dataset)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    train(model, optimizer, train_dataloader, valid_dataloader, args)
    logger.info('******************** Train Finished ********************')

    # Test
    if args.test_set:
        test_texts, test_labels = read_file(args.test_file_path)
        test_dataset = CustomTextDataset(test_texts, test_labels, word2idx)
        test_dataloader = make_loader(test_dataset)
        # Reload the checkpoint that scored best during training.
        model.load_state_dict(
            torch.load(os.path.join(args.model_save_path, "best.pt")))
        _, accuracy, precision, recall, f1, cm = evaluate(
            model, test_dataloader, args)
        logger.info('-' * 50)
        logger.info(
            f'|* TEST SET *| |ACC| {accuracy:>.4f} |PRECISION| {precision:>.4f} |RECALL| {recall:>.4f} |F1| {f1:>.4f}'
        )
        logger.info('-' * 50)
        logger.info('---------------- CONFUSION MATRIX ----------------')
        for row in cm:
            logger.info(row)
        logger.info('--------------------------------------------------')
def visualize(model, dataset, doc):
    """Predict one document and visualize attention as an HTML string.

    :param model: pretrained model
    :param dataset: news20 dataset
    :param doc: document to feed in
    :return: html formatted string for whole document
    """
    # NOTE: the docstring above was previously placed after the first
    # statement, where it was a no-op expression rather than a docstring.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Keep the raw tokenization for display; the transformed copy below is
    # only used as model input.
    orig_doc = [word_tokenize(sent) for sent in sent_tokenize(doc)]
    doc, num_sents, num_words = dataset.transform(doc)
    label = 0  # dummy label for transformation
    doc, label, doc_length, sent_length = collate_fn([(doc, label, num_sents,
                                                       num_words)])
    score, word_att_weight, sentence_att_weight \
        = model(doc.to(device), doc_length.to(device), sent_length.to(device))
    classes = ['Cryptography', 'Electronics', 'Medical', 'Space']
    result = "<h2>Attention Visualization</h2>"
    # Side effect: writes prediction_bar_chart.png, embedded below.
    bar_chart(classes, torch.softmax(score.detach(), dim=1).flatten().cpu(),
              'Prediction')
    result += '<br><img src="prediction_bar_chart.png"><br>'
    for orig_sent, att_weight, sent_weight in zip(
            orig_doc, word_att_weight[0].tolist(),
            sentence_att_weight[0].tolist()):
        result += map_sentence_to_color(orig_sent, att_weight, sent_weight)
    return result
def translate(translator, src_seq, src_pos, domain):
    """Translate a batch and wrap each best hypothesis with BOS/EOS markers.

    The BOS token used depends on direction: when ``domain`` equals the
    target-side BOS, the source/target roles are swapped.
    """
    src_word, tgt_word = Constants.BOS_SRC, Constants.BOS_TGT
    if domain == Constants.BOS_TGT:
        src_word, tgt_word = tgt_word, src_word
    # s2t by previous model; keep only the top (first) hypothesis per sample.
    tgt_hyp, _ = translator.translate_batch(src_seq, src_pos, domain)
    wrapped = [[tgt_word] + hyp[0] + [Constants.EOS] for hyp in tgt_hyp]
    return collate_fn(wrapped)
def get_data_loaders(args, tokenizer):
    """Create train and dev DataLoaders, optionally truncated for debugging."""
    dev_dataset = InferenceDataset('dev', tokenizer, args)
    train_dataset = InferenceDataset('train', tokenizer, args)
    if args.small_data != -1:
        # Debug mode: restrict both splits to the first `small_data` examples.
        logger.info('Using small subset of data')
        subset_ids = list(range(args.small_data))
        dev_dataset = Subset(dev_dataset, subset_ids)
        train_dataset = Subset(train_dataset, subset_ids)

    def make_loader(dataset):
        # Shuffling is disabled under distributed training (sampler-driven).
        return DataLoader(dataset,
                          batch_size=args.batch_size,
                          shuffle=(not args.distributed),
                          num_workers=8,
                          collate_fn=lambda x: collate_fn(
                              x, tokenizer.eos_token_id, args))

    return make_loader(train_dataset), make_loader(dev_dataset)
def get_data_loaders_new(args, tokenizer):
    """Build the GPT-2 AVSD train/valid DataLoaders.

    Each AVSD dataset is given its own split's features first and the other
    split's features second, so both splits can resolve any feature id.
    """
    train_data = get_dataset(tokenizer, args.train_path, args.fea_path,
                             n_history=args.max_history)
    valid_data = get_dataset(tokenizer, args.valid_path, args.fea_path,
                             n_history=args.max_history)
    train_dataset = AVSDDataSet(train_data[0], tokenizer,
                                (train_data[1], valid_data[1]),
                                drop_rate=0, train=True)
    valid_dataset = AVSDDataSet(valid_data[0], tokenizer,
                                (valid_data[1], train_data[1]),
                                drop_rate=0, train=False)

    def pad_batch(batch):
        # Pad to the tokenizer's pad id and include the video features.
        return collate_fn(batch, tokenizer.pad_token_id, features=True)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch_size,
                              num_workers=4,
                              shuffle=(not args.distributed),
                              collate_fn=pad_batch)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.valid_batch_size,
                              num_workers=4,
                              shuffle=False,
                              collate_fn=pad_batch)
    return train_loader, valid_loader
def test(model, tokenizer, test_data, args):
    """Load the saved model and report evaluation results on the test split."""
    logger.info("Test starts!")
    model_load(args.model_dir, model)
    model = model.to(device)
    dataset = QueryDataset(test_data)
    # Sequential order so the reported metrics are deterministic.
    loader = DataLoader(
        dataset,
        sampler=SequentialSampler(dataset),
        batch_size=args.bsz,
        num_workers=args.num_workers,
        collate_fn=lambda batch: collate_fn(batch, tokenizer, args.sample,
                                            args.max_seq_len))
    test_loss, test_str = evaluate(model, loader)
    logger.info(f"| test | {test_str}")
def test(self):
    """Run inference on the test set and write positive-class scores to file."""
    cfg = self.cfg
    loader = DataLoader(self.test_dataset,
                        batch_size=cfg.batch_size,
                        shuffle=False,
                        collate_fn=collate_fn(cfg))
    scores = []
    with torch.no_grad():
        for X, _ in loader:
            # Move every tensor in the feature dict onto the target device.
            for key, value in X.items():
                X[key] = value.to(self.device)
            logits = self.model(X)
            # Keep the probability of the positive class (index 1).
            scores.extend(F.softmax(logits, dim=-1)[:, 1].to('cpu').tolist())
    with open(os.path.join(cfg.cwd, cfg.result_file), 'w') as f:
        f.write('\t'.join('%s' % s for s in scores))
def decode(model, src_seq, src_pos, ctx_seq, ctx_pos, args, token_len):
    """Beam-decode a batch and return padded target sequences and positions."""
    translator = Translator(max_token_seq_len=args.max_token_seq_len,
                            beam_size=10,
                            n_best=1,
                            device=args.device,
                            bad_mask=None,
                            model=model)
    all_hyp, all_scores = translator.translate_batch(src_seq, src_pos,
                                                     ctx_seq, ctx_pos)
    tgt_seq = []
    for idx_seqs in all_hyp:  # one entry per batch element
        best = idx_seqs[0]  # n_best=1: keep the top hypothesis only
        # Truncate at the first EOS; keep the full sequence if none is found.
        end_pos = next((i for i, tok in enumerate(best)
                        if tok == Constants.EOS), len(best))
        tgt_seq.append(best[:end_pos][:args.max_word_seq_len])
    batch_seq, batch_pos = collate_fn(tgt_seq, max_len=token_len)
    return batch_seq.to(args.device), batch_pos.to(args.device)
def visualize_doc(model, dataset, doc, answer):
    """Predict one document against its gold label and visualize attention.

    :param model: pretrained model
    :param dataset: news20 dataset
    :param doc: document to feed in
    :param answer: gold label used to mark the prediction correct/wrong
    :return: html formatted string for whole document
    """
    # NOTE: the docstring above was previously placed after the first
    # statement, where it was a no-op expression rather than a docstring.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Split into sentences, then words, keeping the raw tokens for display.
    orig_doc = [word_tokenize(sent) for sent in sent_tokenize(doc)]
    doc, num_sents, num_words = dataset.transform(doc)
    label = 0  # dummy label for transformation
    doc, label, doc_length, sent_length = collate_fn([(doc, label, num_sents,
                                                       num_words)])
    score, word_att_weight, sentence_att_weight \
        = model(doc.to(device), doc_length.to(device), sent_length.to(device))
    predict = torch.argmax(score.detach(), dim=1).flatten().cpu()
    # `if/elif` on complementary conditions replaced with `if/else` so that
    # `result` is always bound before use.
    if predict == answer:
        # Model predicted correctly.
        result = "<p>Examples of correct prediction results:</p>"
        result += '<input type="text" name="serial" value="%s" >' % (answer)
    else:
        # Model predicted wrongly: show both the prediction and the gold answer.
        result = "<p>Examples of wrong prediction results:</p>"
        result += '<input type="text" name="serial" value="%s" >' % (predict)
        result += '<input type="text" name="serial" value="%s" >' % (answer)
    for orig_sent, att_weight, sent_weight in zip(
            orig_doc, word_att_weight[0].tolist(),
            sentence_att_weight[0].tolist()):
        result += map_sentence_to_color(orig_sent, att_weight, sent_weight)
    return result
def visualize_chart(model, dataset, doc):
    """Predict one example and return HTML embedding a prediction bar chart."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Transform a single example for the model.
    doc, num_sents, num_words = dataset.transform(doc[0])
    label = 0  # dummy label for transformation
    doc, label, doc_length, sent_length = collate_fn([(doc, label, num_sents,
                                                       num_words)])
    score, word_att_weight, sentence_att_weight \
        = model(doc.to(device), doc_length.to(device), sent_length.to(device))
    classes = ['0', '1', '2']
    # classes = ['Cryptography', 'Electronics', 'Medical', 'Space']
    probabilities = torch.softmax(score.detach(), dim=1).flatten().cpu()
    # Side effect: writes prediction_bar_chart.png, referenced below.
    bar_chart(classes, probabilities, 'Prediction')
    html = "<h2>Attention Visualization</h2>"
    html += '<br><img src="prediction_bar_chart.png"><br>'
    return html
def test_collate_fn(self):
    """collate_fn should pad sequences, stack mels, and build stop targets."""
    # `np.float` was removed in NumPy 1.24; use the builtin `float`
    # (the exact dtype np.float aliased) instead.
    mels = (-np.ones((2, 2), dtype=float), np.ones((3, 2), dtype=float))
    seqs = ([1, 2], [1, 2, 3])
    ids = ('mel_1', 'mel_2')
    mel_lens = (2, 3)
    batch = tuple(zip(seqs, mels, ids, mel_lens))
    seqs, mels, stops, ids, mel_lens = collate_fn(batch=batch, r=3,
                                                  silence_len=0)
    # The shorter sequence is right-padded with zeros.
    expected_seqs = np.array([[1, 2, 0], [1, 2, 3]])
    np.testing.assert_almost_equal(seqs, expected_seqs, decimal=8)
    # Mels are stacked to (batch, max_len, n_mels).
    expected_mels = np.array([[[-1, -1], [-1, -1], [-1, -1]],
                              [[1, 1], [1, 1], [1, 1]]])
    np.testing.assert_almost_equal(mels, expected_mels, decimal=8)
    # Stop targets flag the final valid frame of each mel.
    expected_stops = np.array([[0, 1, 0], [0, 0, 1]])
    np.testing.assert_almost_equal(stops, expected_stops, decimal=8)
    expected_lens = np.array([2, 3])
    np.testing.assert_almost_equal(mel_lens, expected_lens, decimal=8)
print("{}~{}".format(len(valid_data[i]['noisy']), len(valid_data[i + 99]['noisy']))) except: print("last batch: ", i, len(valid_data)) print("{}~{}".format(len(valid_data[i]['noisy']), len(valid_data[-1]['noisy']))) valid_dataset = TextDataset(valid_data) valid_dataloader = DataLoader( valid_dataset, sampler=SequentialSampler(valid_dataset), batch_size=args.eval_batch_size, num_workers=args.num_workers, collate_fn=lambda x: collate_fn(x, tokenizer, args.max_seq_length, eos=args.eos_setting, tokenizer_type=args.tokenizer)) (val_loss, val_loss_token), valid_str = evaluate(model, valid_dataloader, args) valid_noisy = [x['noisy'] for x in valid_data] valid_clean = [x['clean'] for x in valid_data] valid_annot = [x['annotation'] for x in valid_data] prediction = correct_beam(model, tokenizer, valid_noisy, args, eos=args.eos_setting, length_limit=0.15)
def main(cfg):
    """End-to-end training entry point: load data, train, early-stop, evaluate.

    Selects a model class by ``cfg.model_name``, trains with Adam plus a
    plateau LR scheduler, tracks the best macro-F1 epoch and a
    valid-loss-based early-stopping candidate, optionally plots the loss
    curves, then runs a final pass over the test split.
    """
    cwd = utils.get_original_cwd()
    cfg.cwd = cwd
    # Position-embedding table size: positions in [-pos_limit, pos_limit]
    # plus two extra slots.
    cfg.pos_size = 2 * cfg.pos_limit + 2
    logger.info(f'\n{cfg.pretty()}')

    # Registry mapping config model names to model classes.
    __Model__ = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # device
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')
    logger.info(f'device: {device}')

    # If the preprocessing pipeline is unchanged, this step is best left
    # disabled so the data is not re-preprocessed on every run.
    if cfg.preprocess:
        preprocess(cfg)

    train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
    valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
    test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
    vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')

    if cfg.model_name == 'lm':
        # The pretrained-LM variant supplies its own vocabulary.
        vocab_size = None
    else:
        vocab = load_pkl(vocab_path)
        vocab_size = vocab.count
    cfg.vocab_size = vocab_size

    train_dataset = CustomDataset(train_data_path)
    valid_dataset = CustomDataset(valid_data_path)
    test_dataset = CustomDataset(test_data_path)

    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size,
                                  shuffle=True, collate_fn=collate_fn(cfg))
    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size,
                                  shuffle=True, collate_fn=collate_fn(cfg))
    test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size,
                                 shuffle=True, collate_fn=collate_fn(cfg))

    model = __Model__[cfg.model_name](cfg)
    model.to(device)
    logger.info(f'\n {model}')

    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate,
                           weight_decay=cfg.weight_decay)
    # Reduce the LR when the validation loss plateaus.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     factor=cfg.lr_factor,
                                                     patience=cfg.lr_patience)
    criterion = nn.CrossEntropyLoss()

    best_f1, best_epoch = -1, 0
    # Early-stopping state: loss/f1/epoch/path of the current candidate,
    # patience counter, and the best candidate recorded so far.
    es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
    train_losses, valid_losses = [], []

    if cfg.show_plot and cfg.plot_utils == 'tensorboard':
        writer = SummaryWriter('tensorboard')
    else:
        writer = None

    logger.info('=' * 10 + ' Start training ' + '=' * 10)
    for epoch in range(1, cfg.epoch + 1):
        # Re-seed per epoch for reproducible-but-varying shuffling.
        manual_seed(cfg.seed + epoch)
        train_loss = train(epoch, model, train_dataloader, optimizer,
                           criterion, device, writer, cfg)
        valid_f1, valid_loss = validate(epoch, model, valid_dataloader,
                                        criterion, device, cfg)
        scheduler.step(valid_loss)
        model_path = model.save(epoch, cfg)
        # logger.info(model_path)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        if best_f1 < valid_f1:
            best_f1 = valid_f1
            best_epoch = epoch
        # Use the valid loss as the early-stopping criterion.
        if es_loss > valid_loss:
            es_loss = valid_loss
            es_f1 = valid_f1
            es_epoch = epoch
            es_patience = 0
            es_path = model_path
        else:
            es_patience += 1
            if es_patience >= cfg.early_stopping_patience:
                # NOTE(review): no break here — training continues after the
                # patience is exhausted; only the best candidate is recorded.
                best_es_epoch = es_epoch
                best_es_f1 = es_f1
                best_es_path = es_path

    if cfg.show_plot:
        if cfg.plot_utils == 'matplot':
            plt.plot(train_losses, 'x-')
            plt.plot(valid_losses, '+-')
            plt.legend(['train', 'valid'])
            plt.title('train/valid comparison loss')
            plt.show()

        if cfg.plot_utils == 'tensorboard':
            for i in range(len(train_losses)):
                writer.add_scalars('train/valid_comparison_loss', {
                    'train': train_losses[i],
                    'valid': valid_losses[i]
                }, i)
            writer.close()

    logger.info(
        f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
        f'this epoch macro f1: {best_es_f1:0.4f}')
    logger.info(f'this model save path: {best_es_path}')
    logger.info(
        f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
        f'this epoch macro f1: {best_f1:.4f}')

    # Final evaluation on the test split (epoch=-1 marks the test pass).
    validate(-1, model, test_dataloader, criterion, device, cfg)
def train(model, tokenizer, train_data, valid_data, args, eos=False):
    """Train a text-correction model and periodically evaluate EM/GLEU.

    Runs up to ``args.max_steps`` optimizer steps with warmup +
    inverse-square-root LR scheduling, logging every ``log_interval`` steps
    and evaluating every ``eval_interval`` steps; the best-GLEU model is
    saved through nsml.
    """
    model.train()
    train_dataset = TextDataset(train_data)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=RandomSampler(train_dataset),
        batch_size=args.train_batch_size,
        num_workers=args.num_workers,
        collate_fn=lambda x: collate_fn(x, tokenizer, args.max_seq_length,
                                        eos=eos,
                                        tokenizer_type=args.tokenizer))
    valid_dataset = TextDataset(valid_data)
    valid_dataloader = DataLoader(
        valid_dataset,
        sampler=SequentialSampler(valid_dataset),
        batch_size=args.eval_batch_size,
        num_workers=args.num_workers,
        collate_fn=lambda x: collate_fn(x, tokenizer, args.max_seq_length,
                                        eos=eos,
                                        tokenizer_type=args.tokenizer))
    valid_noisy = [x['noisy'] for x in valid_data]
    valid_clean = [x['clean'] for x in valid_data]

    # Epochs needed to reach max_steps (ceiling division).
    epochs = (args.max_steps - 1) // len(train_dataloader) + 1

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 betas=eval(args.adam_betas), eps=args.eps,
                                 weight_decay=args.weight_decay)
    # Linear warmup followed by inverse square-root decay.
    lr_lambda = lambda x: x / args.num_warmup_steps if x <= args.num_warmup_steps else (x / args.num_warmup_steps) ** -0.5
    scheduler = LambdaLR(optimizer, lr_lambda)

    step = 0
    best_val_gleu = -float("inf")
    meter = Meter()
    for epoch in range(1, epochs + 1):
        print("===EPOCH: ", epoch)
        for batch in train_dataloader:
            step += 1
            batch = tuple(t.to(args.device) for t in batch)
            loss, items = calc_loss(model, batch)
            meter.add(*items)
            loss.backward()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            scheduler.step()
            if step % args.log_interval == 0:
                lr = scheduler.get_lr()[0]
                loss_sent, loss_token = meter.average()
                logger.info(f' [{step:5d}] lr {lr:.6f} | {meter.print_str(True)}')
                nsml.report(step=step, scope=locals(), summary=True,
                            train__lr=lr, train__loss_sent=loss_sent,
                            train__token_ppl=math.exp(loss_token))
                meter.init()
            if step % args.eval_interval == 0:
                start_eval = time.time()
                (val_loss, val_loss_token), valid_str = evaluate(model,
                                                                 valid_dataloader,
                                                                 args)
                prediction = correct(model, tokenizer, valid_noisy, args,
                                     eos=eos, length_limit=0.1)
                val_em = em(prediction, valid_clean)
                cnt = 0
                for noisy, pred, clean in zip(valid_noisy, prediction,
                                              valid_clean):
                    print(f'[{noisy}], [{pred}], [{clean}]')
                    # print only the first few examples
                    # NOTE(review): original comment said 10, but the loop
                    # stops at 20 — confirm intended count.
                    cnt += 1
                    if cnt == 20:
                        break
                val_gleu = gleu(prediction, valid_clean)
                logger.info('-' * 89)
                logger.info(f' [{step:6d}] valid | {valid_str} | em {val_em:5.2f} | gleu {val_gleu:5.2f}')
                logger.info('-' * 89)
                nsml.report(step=step, scope=locals(), summary=True,
                            valid__loss_sent=val_loss,
                            valid__token_ppl=math.exp(val_loss_token),
                            valid__em=val_em, valid__gleu=val_gleu)
                if val_gleu > best_val_gleu:
                    best_val_gleu = val_gleu
                    nsml.save("best")
                # Exclude evaluation time from the meter's timing window.
                meter.start += time.time() - start_eval
            if step >= args.max_steps:
                break
        #nsml.save(epoch)
        if step >= args.max_steps:
            break
if args.distributed: model = nn.parallel.DistributedDataParallel( model, device_ids=[args.local_rank], output_device=args.local_rank, broadcast_buffers=False, ) train_loader = DataLoader( train_set, batch_size=args.batch, sampler=data_sampler(train_set, shuffle=True, distributed=args.distributed), num_workers=args.num_workers, collate_fn=collate_fn(args), ) valid_loader = DataLoader( valid_set, batch_size=args.batch, sampler=data_sampler(valid_set, shuffle=False, distributed=args.distributed), num_workers=args.num_workers, collate_fn=collate_fn(args), ) for epoch in range(args.epoch): train(args, epoch, train_loader,
def main(args):
    """Train/evaluate the RCNN classifier 10 times and plot average metrics.

    Each iteration rebuilds the data split and model, trains, and (when
    ``args.test_set`` is set) evaluates on the test file. Per-run accuracy,
    precision, recall and F1 are collected; their averages are logged and a
    per-run metric plot is saved under ./images, with a filename derived
    from the enabled feature flags.
    """
    acc_list = []
    f1_score_list = []
    prec_list = []
    recall_list = []
    for i in range(10):
        # Fresh data and model for every run.
        setup_data()
        model = RCNN(vocab_size=args.vocab_size,
                     embedding_dim=args.embedding_dim,
                     hidden_size=args.hidden_size,
                     hidden_size_linear=args.hidden_size_linear,
                     class_num=args.class_num,
                     dropout=args.dropout).to(args.device)
        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model, dim=0)

        train_texts, train_labels = read_file(args.train_file_path)
        word2idx, embedding = build_dictionary(train_texts, args.vocab_size,
                                               args.lexical, args.syntactic,
                                               args.semantic)
        logger.info('Dictionary Finished!')

        full_dataset = CustomTextDataset(train_texts, train_labels, word2idx,
                                         args)
        num_train_data = len(full_dataset) - args.num_val_data
        train_dataset, val_dataset = random_split(
            full_dataset, [num_train_data, args.num_val_data])
        train_dataloader = DataLoader(dataset=train_dataset,
                                      collate_fn=lambda x: collate_fn(x, args),
                                      batch_size=args.batch_size,
                                      shuffle=True)
        valid_dataloader = DataLoader(dataset=val_dataset,
                                      collate_fn=lambda x: collate_fn(x, args),
                                      batch_size=args.batch_size,
                                      shuffle=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        train(model, optimizer, train_dataloader, valid_dataloader, embedding,
              args)
        logger.info('******************** Train Finished ********************')

        # Test
        if args.test_set:
            test_texts, test_labels = read_file(args.test_file_path)
            test_dataset = CustomTextDataset(test_texts, test_labels,
                                             word2idx, args)
            test_dataloader = DataLoader(
                dataset=test_dataset,
                collate_fn=lambda x: collate_fn(x, args),
                batch_size=args.batch_size,
                shuffle=True)
            # Reload the checkpoint that scored best during training.
            model.load_state_dict(
                torch.load(os.path.join(args.model_save_path, "best.pt")))
            _, accuracy, precision, recall, f1, cm = evaluate(
                model, test_dataloader, embedding, args)
            logger.info('-' * 50)
            logger.info(
                f'|* TEST SET *| |ACC| {accuracy:>.4f} |PRECISION| {precision:>.4f} |RECALL| {recall:>.4f} |F1| {f1:>.4f}'
            )
            logger.info('-' * 50)
            logger.info('---------------- CONFUSION MATRIX ----------------')
            for i in range(len(cm)):
                logger.info(cm[i])
            logger.info('--------------------------------------------------')
            # Accuracy is divided by 100 here, presumably because evaluate()
            # reports it in percent — confirm against evaluate's return.
            acc_list.append(accuracy / 100)
            prec_list.append(precision)
            recall_list.append(recall)
            f1_score_list.append(f1)

    avg_acc = sum(acc_list) / len(acc_list)
    avg_prec = sum(prec_list) / len(prec_list)
    avg_recall = sum(recall_list) / len(recall_list)
    avg_f1_score = sum(f1_score_list) / len(f1_score_list)
    logger.info('--------------------------------------------------')
    logger.info(
        f'|* TEST SET *| |Avg ACC| {avg_acc:>.4f} |Avg PRECISION| {avg_prec:>.4f} |Avg RECALL| {avg_recall:>.4f} |Avg F1| {avg_f1_score:>.4f}'
    )
    logger.info('--------------------------------------------------')

    plot_df = pd.DataFrame({
        'x_values': range(10),
        'avg_acc': acc_list,
        'avg_prec': prec_list,
        'avg_recall': recall_list,
        'avg_f1_score': f1_score_list
    })
    plt.plot('x_values', 'avg_acc', data=plot_df, marker='o',
             markerfacecolor='blue', markersize=12, color='skyblue',
             linewidth=4)
    plt.plot('x_values', 'avg_prec', data=plot_df, marker='', color='olive',
             linewidth=2)
    plt.plot('x_values', 'avg_recall', data=plot_df, marker='', color='olive',
             linewidth=2, linestyle='dashed')
    plt.plot('x_values', 'avg_f1_score', data=plot_df, marker='',
             color='olive', linewidth=2, linestyle='dashed')
    plt.legend()
    # The filename encodes which feature-type flags were active for this run.
    fname = 'lexical-semantic-syntactic.png' if args.lexical and args.semantic and args.syntactic \
        else 'semantic-syntactic.png' if args.semantic and args.syntactic \
        else 'lexical-semantic.png' if args.lexical and args.semantic \
        else 'lexical-syntactic.png' if args.lexical and args.syntactic \
        else 'lexical.png' if args.lexical \
        else 'syntactic.png' if args.syntactic \
        else 'semantic.png' if args.semantic \
        else 'plain.png'
    if not (path.exists('./images')):
        mkdir('./images')
    plt.savefig(path.join('./images', fname))
def _sampling(self, epoch):
    """Sample sequences from the model and dump MIDI/pickle/attention artifacts.

    Draws ``num_sample`` random test examples, primes the model with the
    first ``num_prime`` frames of pitch and rhythm, then for every sample
    writes the generated and ground-truth results (pickle, midi, image)
    under the asset path, plus per-head attention heatmaps when enabled.
    """
    self.model.eval()
    loader = self.test_loader
    asset_path = os.path.join(self.asset_path)
    # Random subset of the test set to sample from.
    indices = random.sample(range(len(loader.dataset)),
                            self.config["num_sample"])
    batch = collate_fn([loader.dataset[i] for i in indices])
    for key in batch.keys():
        batch[key] = batch[key].to(self.device)
    # Priming context: the first `num_prime` frames.
    prime = batch['pitch'][:, :self.config["num_prime"]]
    # Unwrap DataParallel so custom methods can be called on the module.
    if isinstance(self.model, torch.nn.DataParallel):
        model = self.model.module
    else:
        model = self.model
    prime_rhythm = batch['rhythm'][:, :self.config["num_prime"]]
    result_dict = model.sampling(prime_rhythm, prime, batch['chord'],
                                 self.config["topk"],
                                 self.config['attention_map'])
    result_key = 'pitch'
    pitch_idx = result_dict[result_key].cpu().numpy()
    logger.info("==========sampling result of epoch %03d==========" % epoch)
    os.makedirs(os.path.join(asset_path, 'sampling_results',
                             'epoch_%03d' % epoch), exist_ok=True)
    for sample_id in range(pitch_idx.shape[0]):
        # Log the first 20 generated frames after the priming region.
        logger.info(("Sample %02d : " % sample_id) +
                    str(pitch_idx[sample_id][self.config["num_prime"]:
                                             self.config["num_prime"] + 20]))
        save_path = os.path.join(
            asset_path, 'sampling_results', 'epoch_%03d' % epoch,
            'epoch%03d_sample%02d.mid' % (epoch, sample_id))
        gt_pitch = batch['pitch'].cpu().numpy()
        # Drop the final (shifted) timestep from the chord targets.
        gt_chord = batch['chord'][:, :-1].cpu().numpy()
        sample_dict = {
            'pitch': pitch_idx[sample_id],
            'rhythm': result_dict['rhythm'][sample_id].cpu().numpy(),
            # Stored sparse to keep the pickle small.
            'chord': csc_matrix(gt_chord[sample_id])
        }
        with open(save_path.replace('.mid', '.pkl'), 'wb') as f_samp:
            pickle.dump(sample_dict, f_samp)
        instruments = pitch_to_midi(pitch_idx[sample_id],
                                    gt_chord[sample_id],
                                    model.frame_per_bar, save_path)
        save_instruments_as_image(save_path.replace('.mid', '.jpg'),
                                  instruments,
                                  frame_per_bar=model.frame_per_bar,
                                  num_bars=(model.max_len //
                                            model.frame_per_bar))

        # save groundtruth
        logger.info(("Groundtruth %02d : " % sample_id) + str(gt_pitch[
            sample_id,
            self.config["num_prime"]:self.config["num_prime"] + 20]))
        gt_path = os.path.join(
            asset_path, 'sampling_results', 'epoch_%03d' % epoch,
            'epoch%03d_groundtruth%02d.mid' % (epoch, sample_id))
        gt_dict = {
            'pitch': gt_pitch[sample_id, :-1],
            'rhythm': batch['rhythm'][sample_id, :-1].cpu().numpy(),
            'chord': csc_matrix(gt_chord[sample_id])
        }
        with open(gt_path.replace('.mid', '.pkl'), 'wb') as f_gt:
            pickle.dump(gt_dict, f_gt)
        gt_instruments = pitch_to_midi(gt_pitch[sample_id, :-1],
                                       gt_chord[sample_id],
                                       model.frame_per_bar, gt_path)
        save_instruments_as_image(gt_path.replace('.mid', '.jpg'),
                                  gt_instruments,
                                  frame_per_bar=model.frame_per_bar,
                                  num_bars=(model.max_len //
                                            model.frame_per_bar))

        if self.config['attention_map']:
            os.makedirs(os.path.join(asset_path, 'attention_map',
                                     'epoch_%03d' % epoch, 'RDec-Chord',
                                     'sample_%02d' % sample_id),
                        exist_ok=True)
            # One heatmap per attention head per decoder layer.
            for head_num in range(8):
                for l, w in enumerate(result_dict['weights_bdec']):
                    fig_w = plt.figure(figsize=(8, 8))
                    ax_w = fig_w.add_subplot(1, 1, 1)
                    heatmap_w = ax_w.pcolor(
                        w[sample_id, head_num].cpu().numpy(), cmap='Reds')
                    ax_w.set_xticks(np.arange(0, self.model.module.max_len))
                    ax_w.xaxis.tick_top()
                    ax_w.set_yticks(np.arange(0, self.model.module.max_len))
                    ax_w.set_xticklabels(rhythm_to_symbol_list(
                        result_dict['rhythm'][sample_id].cpu().numpy()),
                        fontdict=x_fontdict)
                    # NOTE(review): chord_symbol_list is built here but the
                    # y labels below come from chord_to_symbol_list — confirm
                    # whether this intermediate list is still needed.
                    chord_symbol_list = [''] * pitch_idx.shape[1]
                    for t in sorted(
                            chord_array_to_dict(
                                gt_chord[sample_id]).keys()):
                        chord_symbol_list[t] = chord_array_to_dict(
                            gt_chord[sample_id])[t].tolist()
                    ax_w.set_yticklabels(chord_to_symbol_list(
                        gt_chord[sample_id]), fontdict=y_fontdict)
                    ax_w.invert_yaxis()
                    plt.savefig(
                        os.path.join(
                            asset_path, 'attention_map',
                            'epoch_%03d' % epoch, 'RDec-Chord',
                            'sample_%02d' % sample_id,
                            'epoch%03d_RDec-Chord_sample%02d_head%02d_layer%02d.jpg'
                            % (epoch, sample_id, head_num, l)))
                    plt.close()
def train(model, optimizer, tokenizer, train_data, valid_data, args):
    """Train with periodic evaluation, saving the best-validation checkpoint.

    Iterates ``args.n_epochs`` epochs, clipping gradients every step, logging
    every ``log_interval`` steps and evaluating every ``eval_interval``
    steps; the model is saved whenever the validation loss improves.
    Training can be interrupted cleanly with Ctrl-C.
    """
    logger.info("Training starts!")
    os.makedirs(args.model_dir, exist_ok=True)
    train_dataset = QueryDataset(train_data)
    train_data_loader = DataLoader(
        train_dataset,
        sampler=RandomSampler(train_dataset),
        batch_size=args.bsz,
        num_workers=args.num_workers,
        collate_fn=lambda x: collate_fn(x, tokenizer, args.sample, args.
                                        max_seq_len))
    valid_dataset = QueryDataset(valid_data)
    valid_data_loader = DataLoader(
        valid_dataset,
        sampler=SequentialSampler(valid_dataset),
        batch_size=args.bsz,
        num_workers=args.num_workers,
        collate_fn=lambda x: collate_fn(x, tokenizer, args.sample, args.
                                        max_seq_len))
    # Batches per epoch (ceiling division).
    n_batch = (len(train_dataset) - 1) // args.bsz + 1
    logger.info(f" Number of training batch: {n_batch}")
    # Default: evaluate once per epoch.
    if args.eval_interval is None:
        args.eval_interval = n_batch
    try:
        best_valid_loss = float('inf')
        model.train()
        params = get_params(model)
        train_logger = TrainLogger()       # running stats for the whole run
        train_logger_part = TrainLogger()  # stats since the last log line
        step = 0
        for epoch in range(1, args.n_epochs + 1):
            logger.info(f"Epoch {epoch:2d}")
            for batch in train_data_loader:
                step += 1
                batch = tuple(t.to(device) for t in batch)
                loss, items = calc_loss(model, batch)
                loss.backward()
                nn.utils.clip_grad_norm_(params, args.clip)
                optimizer.step()
                optimizer.zero_grad()
                train_logger.add(*items)
                train_logger_part.add(*items)
                if step % args.log_interval == 0:
                    logger.info(
                        f" step {step:8d} | {train_logger_part.print_str(True)}"
                    )
                    train_logger_part.init()
                if step % args.eval_interval == 0:
                    start_eval = time.time()
                    logger.info('-' * 90)
                    train_loss, train_str = train_logger.average(
                    ), train_logger.print_str()
                    logger.info(f"| step {step:8d} | train | {train_str}")
                    # evaluate valid loss, ppl
                    with torch.no_grad():
                        valid_loss, valid_str = evaluate(
                            model, valid_data_loader, args.eval_n_steps)
                    logger.info(f"| step {step:8d} | valid | {valid_str}")
                    if valid_loss[0] < best_valid_loss:
                        model_save(args.model_dir, model, optimizer)
                        logger.info(">>>>> Saving model (new best validation)")
                        best_valid_loss = valid_loss[0]
                    logger.info('-' * 90)
                    model.train()
                    train_logger.init()
                    # Exclude evaluation time from the timing window.
                    train_logger_part.start += time.time() - start_eval
    except KeyboardInterrupt:
        logger.info('-' * 90)
        logger.info(' Exiting from training early')