Code example #1
File: trainer.py  Project: Jiahao-Huang/tclane3
    def kFoldTrain(self):
        cfg = self.cfg

        # make folds
        self.train_dataset.shuffle()
        folds = self.make_folds()
        loss_t_mean, acc_t_mean, loss_v_mean, acc_v_mean = 0, 0, 0, 0

        for k in range(self.k_fold):  # fold-k will be valid set
            logger.info("=" * 5 + " Epoch %d - Fold %d " % (self.epoch, k) +
                        "=" * 5)
            train = []
            valid = folds[k]
            for i in range(self.k_fold):  # folds except k will be train set
                if i != k:
                    train = train + folds[i]

            train_dataloader = DataLoader(train,
                                          batch_size=cfg.batch_size,
                                          shuffle=True,
                                          collate_fn=collate_fn(cfg))
            valid_dataloader = DataLoader(valid,
                                          batch_size=cfg.batch_size,
                                          shuffle=True,
                                          collate_fn=collate_fn(cfg))

            loss_t, acc_t = self.train(train_dataloader)
            loss_v, acc_v = self.validate(valid_dataloader)

            loss_t_mean += loss_t / self.k_fold
            acc_t_mean += acc_t / self.k_fold
            loss_v_mean += loss_v / self.k_fold
            acc_v_mean += acc_v / self.k_fold

        return loss_t_mean, acc_t_mean, loss_v_mean, acc_v_mean
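
A note on the pattern above: collate_fn=collate_fn(cfg) calls collate_fn once, when the DataLoader is built, so in this project collate_fn is a factory that takes the config and returns the per-batch callable (examples #9 and #15 below use it the same way, and example #9 shows the batch arriving as a dict of tensors). The project's implementation is not reproduced on this page; the following is only a minimal sketch of that factory pattern, with made-up field names ("token_ids", "lengths") and a made-up cfg.max_seq_len attribute.

import torch

def collate_fn(cfg):
    """Hypothetical factory: returns the callable the DataLoader invokes on each batch."""
    def collate(batch):
        # batch is a list of (features, label) pairs produced by the dataset;
        # the field names below are illustrative, not the project's actual schema
        feats, labels = zip(*batch)
        lengths = [min(len(f["token_ids"]), cfg.max_seq_len) for f in feats]
        max_len = max(lengths)
        token_ids = torch.zeros(len(feats), max_len, dtype=torch.long)
        for i, (f, n) in enumerate(zip(feats, lengths)):
            token_ids[i, :n] = torch.as_tensor(f["token_ids"][:n], dtype=torch.long)
        X = {"token_ids": token_ids, "lengths": torch.as_tensor(lengths)}
        y = torch.as_tensor(labels, dtype=torch.long)
        return X, y
    return collate
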
Code example #2
def get_data_loaders_new(args, tokenizer):
    train_data = get_dataset(tokenizer,
                             args.train_path,
                             args.fea_path,
                             n_history=args.max_history)
    valid_data = get_dataset(tokenizer,
                             args.valid_path,
                             args.fea_path,
                             n_history=args.max_history)
    train_dataset = AVSDDataSet(train_data[0],
                                tokenizer, (train_data[1], valid_data[1]),
                                drop_rate=0,
                                train=True)
    valid_dataset = AVSDDataSet(valid_data[0],
                                tokenizer, (valid_data[1], train_data[1]),
                                drop_rate=0,
                                train=False)
    train_loader = DataLoader(train_dataset,
                              shuffle=(not args.distributed),
                              batch_size=args.train_batch_size,
                              num_workers=4,
                              collate_fn=lambda x: collate_fn(
                                  x, tokenizer.pad_token_id, features=True))
    valid_loader = DataLoader(valid_dataset,
                              shuffle=False,
                              batch_size=args.valid_batch_size,
                              num_workers=4,
                              collate_fn=lambda x: collate_fn(
                                  x, tokenizer.pad_token_id, features=True))
    return train_loader, valid_loader
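
Here the project's collate_fn takes extra arguments, so it is wrapped in a lambda: DataLoader always calls collate_fn with a single argument (the list of samples), and the lambda closes over tokenizer.pad_token_id and features=True. The AVSD collate function itself is not reproduced on this page; below is only a minimal sketch of a pad-to-longest collate with that kind of signature, assuming each dataset item is a dict with hypothetical "input_ids" and "features" entries.

import torch

def collate_fn(batch, pad_token_id, features=False):
    # Pad variable-length token sequences up to the longest one in the batch.
    # The item layout ("input_ids"/"features" keys) is illustrative, not the project's.
    max_len = max(len(item["input_ids"]) for item in batch)
    input_ids = torch.full((len(batch), max_len), pad_token_id, dtype=torch.long)
    for i, item in enumerate(batch):
        ids = torch.as_tensor(item["input_ids"], dtype=torch.long)
        input_ids[i, :len(ids)] = ids
    if features:
        # Stack per-item feature tensors (assumed here to share a fixed shape).
        feats = torch.stack([torch.as_tensor(item["features"], dtype=torch.float)
                             for item in batch])
        return input_ids, feats
    return input_ids

A functools.partial(collate_fn, pad_token_id=tokenizer.pad_token_id, features=True) would work just as well as the lambda.
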
Code example #3
def main(args):
    model = RCNN(vocab_size=args.vocab_size,
                 embedding_dim=args.embedding_dim,
                 hidden_size=args.hidden_size,
                 hidden_size_linear=args.hidden_size_linear,
                 class_num=args.class_num,
                 dropout=args.dropout).to(args.device)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model, dim=0)

    train_texts, train_labels = read_file(args.train_file_path)
    word2idx = build_dictionary(train_texts, vocab_size=args.vocab_size)
    logger.info('Dictionary Finished!')

    full_dataset = CustomTextDataset(train_texts, train_labels, word2idx)
    num_train_data = len(full_dataset) - args.num_val_data
    train_dataset, val_dataset = random_split(
        full_dataset, [num_train_data, args.num_val_data])
    train_dataloader = DataLoader(dataset=train_dataset,
                                  collate_fn=lambda x: collate_fn(x, args),
                                  batch_size=args.batch_size,
                                  shuffle=True)

    valid_dataloader = DataLoader(dataset=val_dataset,
                                  collate_fn=lambda x: collate_fn(x, args),
                                  batch_size=args.batch_size,
                                  shuffle=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    train(model, optimizer, train_dataloader, valid_dataloader, args)
    logger.info('******************** Train Finished ********************')

    # Test
    if args.test_set:
        test_texts, test_labels = read_file(args.test_file_path)
        test_dataset = CustomTextDataset(test_texts, test_labels, word2idx)
        test_dataloader = DataLoader(dataset=test_dataset,
                                     collate_fn=lambda x: collate_fn(x, args),
                                     batch_size=args.batch_size,
                                     shuffle=True)

        model.load_state_dict(
            torch.load(os.path.join(args.model_save_path, "best.pt")))
        _, accuracy, precision, recall, f1, cm = evaluate(
            model, test_dataloader, args)
        logger.info('-' * 50)
        logger.info(
            f'|* TEST SET *| |ACC| {accuracy:>.4f} |PRECISION| {precision:>.4f} |RECALL| {recall:>.4f} |F1| {f1:>.4f}'
        )
        logger.info('-' * 50)
        logger.info('---------------- CONFUSION MATRIX ----------------')
        for i in range(len(cm)):
            logger.info(cm[i])
        logger.info('--------------------------------------------------')
Code example #4
def visualize(model, dataset, doc):
    """
    Predicts one document and visualizes it as an HTML-formatted string.
    :param model: pretrained model
    :param dataset: news20 dataset
    :param doc: document to feed in
    :return: html formatted string for whole document
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    orig_doc = [word_tokenize(sent) for sent in sent_tokenize(doc)]
    doc, num_sents, num_words = dataset.transform(doc)
    label = 0  # dummy label for transformation

    doc, label, doc_length, sent_length = collate_fn([(doc, label, num_sents,
                                                       num_words)])
    score, word_att_weight, sentence_att_weight \
        = model(doc.to(device), doc_length.to(device), sent_length.to(device))

    # predicted = int(torch.max(score, dim=1)[1])
    classes = ['Cryptography', 'Electronics', 'Medical', 'Space']
    result = "<h2>Attention Visualization</h2>"

    bar_chart(classes,
              torch.softmax(score.detach(), dim=1).flatten().cpu(),
              'Prediction')
    result += '<br><img src="prediction_bar_chart.png"><br>'
    for orig_sent, att_weight, sent_weight in zip(
            orig_doc, word_att_weight[0].tolist(),
            sentence_att_weight[0].tolist()):
        result += map_sentence_to_color(orig_sent, att_weight, sent_weight)

    return result
Code example #5
def translate(translator, src_seq, src_pos, domain):
    src_word = Constants.BOS_SRC
    tgt_word = Constants.BOS_TGT
    if domain == Constants.BOS_TGT:
        src_word, tgt_word = tgt_word, src_word

    # s2t by previous model
    tgt_hyp, _ = translator.translate_batch(src_seq, src_pos, domain)
    tgt_hyp = [[tgt_word] + t_hyp[0] + [Constants.EOS] for t_hyp in tgt_hyp]
    tgt_seq_hyp, tgt_pos_hyp = collate_fn(tgt_hyp)

    return tgt_seq_hyp, tgt_pos_hyp
Code example #6
def get_data_loaders(args, tokenizer):
    dev_dataset = InferenceDataset('dev', tokenizer, args)
    train_dataset = InferenceDataset('train', tokenizer, args)

    if args.small_data != -1:
        logger.info('Using small subset of data')
        dev_dataset = Subset(dev_dataset, list(range(args.small_data)))
        train_dataset = Subset(train_dataset, list(range(args.small_data)))

    dev_dataloader = DataLoader(dev_dataset,
                                batch_size=args.batch_size,
                                shuffle=(not args.distributed),
                                num_workers=8,
                                collate_fn=lambda x: collate_fn(x, tokenizer.eos_token_id, args))
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=(not args.distributed),
                                  num_workers=8,
                                  collate_fn=lambda x: collate_fn(x, tokenizer.eos_token_id, args))

    return train_dataloader, dev_dataloader
Code example #7
File: train.py  Project: zhuzhutingru123/DSTC8-AVSD
def get_data_loaders_new(args, tokenizer):
    train_data = get_dataset(tokenizer,
                             args.train_path,
                             args.fea_path,
                             n_history=args.max_history)
    #with open("train_data_gpt2.pkl", "rb") as f:
    #    train_data = pkl.load(f)
    # pkl.dump(train_data, f)
    valid_data = get_dataset(tokenizer,
                             args.valid_path,
                             args.fea_path,
                             n_history=args.max_history)
    #with open("valid_data_gpt2.pkl", "rb") as f:
    #    valid_data = pkl.load(f)
    # pkl.dump(valid_data, f)
    train_dataset = AVSDDataSet(train_data[0],
                                tokenizer, (train_data[1], valid_data[1]),
                                drop_rate=0,
                                train=True)
    valid_dataset = AVSDDataSet(valid_data[0],
                                tokenizer, (valid_data[1], train_data[1]),
                                drop_rate=0,
                                train=False)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.train_batch_size,
                              num_workers=4,
                              shuffle=(not args.distributed),
                              collate_fn=lambda x: collate_fn(
                                  x, tokenizer.pad_token_id, features=True))
    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.valid_batch_size,
                              num_workers=4,
                              shuffle=False,
                              collate_fn=lambda x: collate_fn(
                                  x, tokenizer.pad_token_id, features=True))
    return train_loader, valid_loader
Code example #8
def test(model, tokenizer, test_data, args):
    logger.info("Test starts!")
    model_load(args.model_dir, model)
    model = model.to(device)

    test_dataset = QueryDataset(test_data)
    test_data_loader = DataLoader(
        test_dataset,
        sampler=SequentialSampler(test_dataset),
        batch_size=args.bsz,
        num_workers=args.num_workers,
        collate_fn=lambda x: collate_fn(x, tokenizer, args.sample,
                                        args.max_seq_len))

    test_loss, test_str = evaluate(model, test_data_loader)
    logger.info(f"| test  | {test_str}")
Code example #9
    def test(self):
        cfg = self.cfg

        test_dataloader = DataLoader(self.test_dataset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     collate_fn=collate_fn(cfg))
        results = []

        for (X, _) in test_dataloader:
            with torch.no_grad():
                for (k, v) in X.items():
                    X[k] = v.to(self.device)

                y_pred = self.model(X)

                result = F.softmax(y_pred, dim=-1)[:, 1].to('cpu').tolist()
                results += result

        with open(os.path.join(cfg.cwd, cfg.result_file), 'w') as f:
            f.write('\t'.join('%s' % r for r in results))
Code example #10
File: trainer.py  Project: laohur/ContextTransformer
def decode(model, src_seq, src_pos, ctx_seq, ctx_pos, args, token_len):
    translator = Translator(max_token_seq_len=args.max_token_seq_len,
                            beam_size=10,
                            n_best=1,
                            device=args.device,
                            bad_mask=None,
                            model=model)
    tgt_seq = []
    all_hyp, all_scores = translator.translate_batch(src_seq, src_pos, ctx_seq,
                                                     ctx_pos)
    for idx_seqs in all_hyp:  # batch
        idx_seq = idx_seqs[0]  # n_best=1
        end_pos = len(idx_seq)
        for i in range(len(idx_seq)):
            if idx_seq[i] == Constants.EOS:
                end_pos = i
                break
        # tgt_seq.append([Constants.BOS] + idx_seq[:end_pos][:args.max_word_seq_len] + [Constants.EOS])
        tgt_seq.append(idx_seq[:end_pos][:args.max_word_seq_len])
    batch_seq, batch_pos = collate_fn(tgt_seq, max_len=token_len)
    return batch_seq.to(args.device), batch_pos.to(args.device)
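
Examples #5 above and #10 here use a collate function that turns a list of token-id sequences into a padded batch_seq tensor plus a matching batch_pos position tensor (example #10 additionally caps the width with max_len=token_len). That function is not shown on this page; the sketch below is only one plausible version, under the assumption that the padding index is 0 and positions are 1-based with 0 reserved for padded slots.

import torch

PAD = 0  # assumed padding index (Constants.PAD in the projects above may differ)

def collate_fn(sequences, max_len=None):
    # sequences: list of lists of token ids (BOS/EOS already inserted by the caller)
    width = max_len if max_len is not None else max(len(s) for s in sequences)
    batch_seq = torch.full((len(sequences), width), PAD, dtype=torch.long)
    batch_pos = torch.zeros(len(sequences), width, dtype=torch.long)
    for i, seq in enumerate(sequences):
        seq = seq[:width]
        batch_seq[i, :len(seq)] = torch.as_tensor(seq, dtype=torch.long)
        batch_pos[i, :len(seq)] = torch.arange(1, len(seq) + 1)  # padded slots keep position 0
    return batch_seq, batch_pos
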
Code example #11
def visualize_doc(model, dataset, doc, answer):
    # Feed the input doc into the pretrained model and visualize the attention weights
    """
    Predicts one document and visualizes it as an HTML-formatted string.
    :param model: pretrained model
    :param dataset: news20 dataset
    :param doc: document to feed in
    :param answer: ground-truth class label
    :return: html formatted string for whole document
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # split into sentences, then tokenize each sentence into words
    orig_doc = [word_tokenize(sent) for sent in sent_tokenize(doc)]
    # doc:
    doc, num_sents, num_words = dataset.transform(doc)
    label = 0  # dummy label for transformation

    doc, label, doc_length, sent_length = collate_fn([(doc, label, num_sents,
                                                       num_words)])

    score, word_att_weight, sentence_att_weight \
        = model(doc.to(device), doc_length.to(device), sent_length.to(device))

    predict = torch.argmax(score.detach(), dim=1).flatten().cpu()

    if predict == answer:  # the model predicted correctly
        result = "<p>Examples of correct prediction results:</p>"
        result += '<input type="text" name="serial" value="%s" >' % (answer)

    elif predict != answer:  # the model predicted incorrectly
        result = "<p>Examples of wrong prediction results:</p>"
        result += '<input type="text" name="serial" value="%s" >' % (predict)
        result += '<input type="text" name="serial" value="%s" >' % (answer)

    for orig_sent, att_weight, sent_weight in zip(
            orig_doc, word_att_weight[0].tolist(),
            sentence_att_weight[0].tolist()):
        result += map_sentence_to_color(orig_sent, att_weight, sent_weight)

    return result
Code example #12
def visualize_chart(model, dataset, doc):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # transform(doc): output the result for a single example
    doc, num_sents, num_words = dataset.transform(doc[0])
    label = 0  # dummy label for transformation

    doc, label, doc_length, sent_length = collate_fn([(doc, label, num_sents,
                                                       num_words)])

    score, word_att_weight, sentence_att_weight \
        = model(doc.to(device), doc_length.to(device), sent_length.to(device))

    # predicted = int(torch.max(score, dim=1)[1])
    classes = ['0', '1', '2']
    #     classes = ['Cryptography', 'Electronics', 'Medical', 'Space']
    result = "<h2>Attention Visualization</h2>"

    bar_chart(classes,
              torch.softmax(score.detach(), dim=1).flatten().cpu(),
              'Prediction')
    result += '<br><img src="prediction_bar_chart.png"><br>'

    return result
Code example #13
File: test_dataset.py  Project: cschaefer26/TacoGan
    def test_collate_fn(self):
        mels = (-np.ones((2, 2), dtype=float),
                np.ones((3, 2), dtype=float))
        seqs = ([1, 2], [1, 2, 3])
        ids = ('mel_1', 'mel_2')
        mel_lens = (2, 3)
        batch = tuple(zip(seqs, mels, ids, mel_lens))

        seqs, mels, stops, ids, mel_lens = collate_fn(batch=batch,
                                                      r=3,
                                                      silence_len=0)

        expected_seqs = np.array([[1, 2, 0], [1, 2, 3]])
        np.testing.assert_almost_equal(seqs, expected_seqs, decimal=8)

        expected_mels = np.array([[[-1, -1], [-1, -1], [-1, -1]],
                                  [[1, 1], [1, 1], [1, 1]]])
        np.testing.assert_almost_equal(mels, expected_mels, decimal=8)

        expected_stops = np.array([[0, 1, 0], [0, 0, 1]])
        np.testing.assert_almost_equal(stops, expected_stops, decimal=8)

        expected_lens = np.array([2, 3])
        np.testing.assert_almost_equal(mel_lens, expected_lens, decimal=8)
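
This unit test pins down the expected behavior of the project's collate_fn: text sequences are zero-padded to the longest one, mel spectrograms are padded along the time axis up to a multiple of the reduction factor r, the stop target carries a 1 at each utterance's last real frame, and the ids and mel lengths pass through unchanged. The sketch below satisfies those assertions, but the mel padding value (repeating the last frame) and the point at which silence_len enters are assumptions rather than TacoGan's actual implementation.

import numpy as np

def collate_fn(batch, r, silence_len=0):
    # batch: iterable of (seq, mel, id, mel_len) tuples, as built in the test above
    seqs, mels, ids, mel_lens = zip(*batch)

    # zero-pad the text sequences to the longest one
    max_seq_len = max(len(s) for s in seqs)
    seqs = np.array([list(s) + [0] * (max_seq_len - len(s)) for s in seqs])

    # pad the mel time axis (plus optional silence) up to a multiple of r;
    # edge-padding (repeating the last frame) is one choice that satisfies the test
    max_mel_len = max(m.shape[0] for m in mels) + silence_len
    max_mel_len = int(np.ceil(max_mel_len / r) * r)
    mels = np.stack([np.pad(m, ((0, max_mel_len - m.shape[0]), (0, 0)), mode='edge')
                     for m in mels])

    # stop target: 1 at the last real frame of every utterance
    stops = np.zeros((len(mel_lens), max_mel_len))
    for i, mel_len in enumerate(mel_lens):
        stops[i, mel_len - 1] = 1

    return seqs, mels, stops, ids, np.array(mel_lens)
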
Code example #14
        print("{}~{}".format(len(valid_data[i]['noisy']),
                             len(valid_data[i + 99]['noisy'])))
    except:
        print("last batch: ", i, len(valid_data))
        print("{}~{}".format(len(valid_data[i]['noisy']),
                             len(valid_data[-1]['noisy'])))

valid_dataset = TextDataset(valid_data)
valid_dataloader = DataLoader(
    valid_dataset,
    sampler=SequentialSampler(valid_dataset),
    batch_size=args.eval_batch_size,
    num_workers=args.num_workers,
    collate_fn=lambda x: collate_fn(x,
                                    tokenizer,
                                    args.max_seq_length,
                                    eos=args.eos_setting,
                                    tokenizer_type=args.tokenizer))

(val_loss, val_loss_token), valid_str = evaluate(model, valid_dataloader, args)

valid_noisy = [x['noisy'] for x in valid_data]
valid_clean = [x['clean'] for x in valid_data]
valid_annot = [x['annotation'] for x in valid_data]
prediction = correct_beam(model,
                          tokenizer,
                          valid_noisy,
                          args,
                          eos=args.eos_setting,
                          length_limit=0.15)
Code example #15
def main(cfg):
    cwd = utils.get_original_cwd()
    cfg.cwd = cwd
    cfg.pos_size = 2 * cfg.pos_limit + 2
    logger.info(f'\n{cfg.pretty()}')

    __Model__ = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # device
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')
    logger.info(f'device: {device}')

    # If the preprocessing procedure is unchanged, it is best to comment this out so the data is not preprocessed again on every run
    if cfg.preprocess:
        preprocess(cfg)

    train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
    valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
    test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
    vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')

    if cfg.model_name == 'lm':
        vocab_size = None
    else:
        vocab = load_pkl(vocab_path)
        vocab_size = vocab.count
    cfg.vocab_size = vocab_size

    train_dataset = CustomDataset(train_data_path)
    valid_dataset = CustomDataset(valid_data_path)
    test_dataset = CustomDataset(test_data_path)

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=cfg.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn(cfg))
    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=cfg.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn(cfg))
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=cfg.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn(cfg))

    model = __Model__[cfg.model_name](cfg)
    model.to(device)
    logger.info(f'\n {model}')

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.learning_rate,
                           weight_decay=cfg.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     factor=cfg.lr_factor,
                                                     patience=cfg.lr_patience)
    criterion = nn.CrossEntropyLoss()

    best_f1, best_epoch = -1, 0
    es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
    train_losses, valid_losses = [], []

    if cfg.show_plot and cfg.plot_utils == 'tensorboard':
        writer = SummaryWriter('tensorboard')
    else:
        writer = None

    logger.info('=' * 10 + ' Start training ' + '=' * 10)

    for epoch in range(1, cfg.epoch + 1):
        manual_seed(cfg.seed + epoch)
        train_loss = train(epoch, model, train_dataloader, optimizer,
                           criterion, device, writer, cfg)
        valid_f1, valid_loss = validate(epoch, model, valid_dataloader,
                                        criterion, device, cfg)
        scheduler.step(valid_loss)
        model_path = model.save(epoch, cfg)
        # logger.info(model_path)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        if best_f1 < valid_f1:
            best_f1 = valid_f1
            best_epoch = epoch
        # use the validation loss as the early-stopping criterion
        if es_loss > valid_loss:
            es_loss = valid_loss
            es_f1 = valid_f1
            es_epoch = epoch
            es_patience = 0
            es_path = model_path
        else:
            es_patience += 1
            if es_patience >= cfg.early_stopping_patience:
                best_es_epoch = es_epoch
                best_es_f1 = es_f1
                best_es_path = es_path

    if cfg.show_plot:
        if cfg.plot_utils == 'matplot':
            plt.plot(train_losses, 'x-')
            plt.plot(valid_losses, '+-')
            plt.legend(['train', 'valid'])
            plt.title('train/valid comparison loss')
            plt.show()

        if cfg.plot_utils == 'tensorboard':
            for i in range(len(train_losses)):
                writer.add_scalars('train/valid_comparison_loss', {
                    'train': train_losses[i],
                    'valid': valid_losses[i]
                }, i)
            writer.close()

    logger.info(
        f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
        f'this epoch macro f1: {best_es_f1:0.4f}')
    logger.info(f'this model save path: {best_es_path}')
    logger.info(
        f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
        f'this epoch macro f1: {best_f1:.4f}')

    validate(-1, model, test_dataloader, criterion, device, cfg)
Code example #16
def train(model, tokenizer, train_data, valid_data, args, eos=False):
    model.train()

    train_dataset = TextDataset(train_data)
    train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset),
                                  batch_size=args.train_batch_size, num_workers=args.num_workers,
                                  collate_fn=lambda x: collate_fn(x, tokenizer, args.max_seq_length, eos=eos, tokenizer_type=args.tokenizer))

    valid_dataset = TextDataset(valid_data)
    valid_dataloader = DataLoader(valid_dataset, sampler=SequentialSampler(valid_dataset),
                                  batch_size=args.eval_batch_size, num_workers=args.num_workers,
                                  collate_fn=lambda x: collate_fn(x, tokenizer, args.max_seq_length, eos=eos, tokenizer_type=args.tokenizer))

    valid_noisy = [x['noisy'] for x in valid_data]
    valid_clean = [x['clean'] for x in valid_data]

    epochs = (args.max_steps - 1) // len(train_dataloader) + 1
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 betas=eval(args.adam_betas), eps=args.eps,
                                 weight_decay=args.weight_decay)
    lr_lambda = lambda x: x / args.num_warmup_steps if x <= args.num_warmup_steps else (x / args.num_warmup_steps) ** -0.5
    scheduler = LambdaLR(optimizer, lr_lambda)

    step = 0
    best_val_gleu = -float("inf")
    meter = Meter()
    for epoch in range(1, epochs + 1):
        print("===EPOCH: ", epoch)
        for batch in train_dataloader:
            step += 1
            batch = tuple(t.to(args.device) for t in batch)
            loss, items = calc_loss(model, batch)
            meter.add(*items)

            loss.backward()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            model.zero_grad()
            scheduler.step()

            if step % args.log_interval == 0:
                lr = scheduler.get_lr()[0]
                loss_sent, loss_token = meter.average()

                logger.info(f' [{step:5d}] lr {lr:.6f} | {meter.print_str(True)}')
                nsml.report(step=step, scope=locals(), summary=True,
                            train__lr=lr, train__loss_sent=loss_sent, train__token_ppl=math.exp(loss_token))
                meter.init()

            if step % args.eval_interval == 0:
                start_eval = time.time()
                (val_loss, val_loss_token), valid_str = evaluate(model, valid_dataloader, args)
                prediction = correct(model, tokenizer, valid_noisy, args, eos=eos, length_limit=0.1)
                val_em = em(prediction, valid_clean)
                cnt = 0
                for noisy, pred, clean in zip(valid_noisy, prediction, valid_clean):
                    print(f'[{noisy}], [{pred}], [{clean}]')
                    # print only 10 examples
                    cnt += 1
                    if cnt == 20:
                        break
                val_gleu = gleu(prediction, valid_clean)

                logger.info('-' * 89)
                logger.info(f' [{step:6d}] valid | {valid_str} | em {val_em:5.2f} | gleu {val_gleu:5.2f}')
                logger.info('-' * 89)
                nsml.report(step=step, scope=locals(), summary=True,
                            valid__loss_sent=val_loss, valid__token_ppl=math.exp(val_loss_token),
                            valid__em=val_em, valid__gleu=val_gleu)

                if val_gleu > best_val_gleu:
                    best_val_gleu = val_gleu
                    nsml.save("best")
                meter.start += time.time() - start_eval

            if step >= args.max_steps:
                break
        #nsml.save(epoch)
        if step >= args.max_steps:
            break
Code example #17
    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch,
        sampler=data_sampler(train_set,
                             shuffle=True,
                             distributed=args.distributed),
        num_workers=args.num_workers,
        collate_fn=collate_fn(args),
    )
    valid_loader = DataLoader(
        valid_set,
        batch_size=args.batch,
        sampler=data_sampler(valid_set,
                             shuffle=False,
                             distributed=args.distributed),
        num_workers=args.num_workers,
        collate_fn=collate_fn(args),
    )

    for epoch in range(args.epoch):
        train(args,
              epoch,
              train_loader,
Code example #18
def main(args):
    acc_list = []
    f1_score_list = []
    prec_list = []
    recall_list = []
    for i in range(10):
        setup_data()
        model = RCNN(vocab_size=args.vocab_size,
                     embedding_dim=args.embedding_dim,
                     hidden_size=args.hidden_size,
                     hidden_size_linear=args.hidden_size_linear,
                     class_num=args.class_num,
                     dropout=args.dropout).to(args.device)

        if args.n_gpu > 1:
            model = torch.nn.DataParallel(model, dim=0)

        train_texts, train_labels = read_file(args.train_file_path)
        word2idx, embedding = build_dictionary(train_texts, args.vocab_size,
                                               args.lexical, args.syntactic,
                                               args.semantic)

        logger.info('Dictionary Finished!')

        full_dataset = CustomTextDataset(train_texts, train_labels, word2idx,
                                         args)
        num_train_data = len(full_dataset) - args.num_val_data
        train_dataset, val_dataset = random_split(
            full_dataset, [num_train_data, args.num_val_data])
        train_dataloader = DataLoader(dataset=train_dataset,
                                      collate_fn=lambda x: collate_fn(x, args),
                                      batch_size=args.batch_size,
                                      shuffle=True)

        valid_dataloader = DataLoader(dataset=val_dataset,
                                      collate_fn=lambda x: collate_fn(x, args),
                                      batch_size=args.batch_size,
                                      shuffle=True)

        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        train(model, optimizer, train_dataloader, valid_dataloader, embedding,
              args)
        logger.info('******************** Train Finished ********************')

        # Test
        if args.test_set:
            test_texts, test_labels = read_file(args.test_file_path)
            test_dataset = CustomTextDataset(test_texts, test_labels, word2idx,
                                             args)
            test_dataloader = DataLoader(
                dataset=test_dataset,
                collate_fn=lambda x: collate_fn(x, args),
                batch_size=args.batch_size,
                shuffle=True)

            model.load_state_dict(
                torch.load(os.path.join(args.model_save_path, "best.pt")))
            _, accuracy, precision, recall, f1, cm = evaluate(
                model, test_dataloader, embedding, args)
            logger.info('-' * 50)
            logger.info(
                f'|* TEST SET *| |ACC| {accuracy:>.4f} |PRECISION| {precision:>.4f} |RECALL| {recall:>.4f} |F1| {f1:>.4f}'
            )
            logger.info('-' * 50)
            logger.info('---------------- CONFUSION MATRIX ----------------')
            for i in range(len(cm)):
                logger.info(cm[i])
            logger.info('--------------------------------------------------')
            acc_list.append(accuracy / 100)
            prec_list.append(precision)
            recall_list.append(recall)
            f1_score_list.append(f1)

    avg_acc = sum(acc_list) / len(acc_list)
    avg_prec = sum(prec_list) / len(prec_list)
    avg_recall = sum(recall_list) / len(recall_list)
    avg_f1_score = sum(f1_score_list) / len(f1_score_list)
    logger.info('--------------------------------------------------')
    logger.info(
        f'|* TEST SET *| |Avg ACC| {avg_acc:>.4f} |Avg PRECISION| {avg_prec:>.4f} |Avg RECALL| {avg_recall:>.4f} |Avg F1| {avg_f1_score:>.4f}'
    )
    logger.info('--------------------------------------------------')
    plot_df = pd.DataFrame({
        'x_values': range(10),
        'avg_acc': acc_list,
        'avg_prec': prec_list,
        'avg_recall': recall_list,
        'avg_f1_score': f1_score_list
    })
    plt.plot('x_values',
             'avg_acc',
             data=plot_df,
             marker='o',
             markerfacecolor='blue',
             markersize=12,
             color='skyblue',
             linewidth=4)
    plt.plot('x_values',
             'avg_prec',
             data=plot_df,
             marker='',
             color='olive',
             linewidth=2)
    plt.plot('x_values',
             'avg_recall',
             data=plot_df,
             marker='',
             color='olive',
             linewidth=2,
             linestyle='dashed')
    plt.plot('x_values',
             'avg_f1_score',
             data=plot_df,
             marker='',
             color='olive',
             linewidth=2,
             linestyle='dashed')
    plt.legend()
    fname = 'lexical-semantic-syntactic.png' if args.lexical and args.semantic and args.syntactic \
                            else 'semantic-syntactic.png' if args.semantic and args.syntactic \
                            else 'lexical-semantic.png' if args.lexical and args.semantic \
                            else 'lexical-syntactic.png'if args.lexical and args.syntactic \
                            else 'lexical.png' if args.lexical \
                            else 'syntactic.png' if args.syntactic \
                            else 'semantic.png' if args.semantic \
                            else 'plain.png'
    if not (path.exists('./images')):
        mkdir('./images')
    plt.savefig(path.join('./images', fname))
Code example #19
    def _sampling(self, epoch):
        self.model.eval()
        loader = self.test_loader
        asset_path = os.path.join(self.asset_path)

        indices = random.sample(range(len(loader.dataset)),
                                self.config["num_sample"])
        batch = collate_fn([loader.dataset[i] for i in indices])
        for key in batch.keys():
            batch[key] = batch[key].to(self.device)
        prime = batch['pitch'][:, :self.config["num_prime"]]
        if isinstance(self.model, torch.nn.DataParallel):
            model = self.model.module
        else:
            model = self.model
        prime_rhythm = batch['rhythm'][:, :self.config["num_prime"]]
        result_dict = model.sampling(prime_rhythm, prime, batch['chord'],
                                     self.config["topk"],
                                     self.config['attention_map'])
        result_key = 'pitch'
        pitch_idx = result_dict[result_key].cpu().numpy()

        logger.info("==========sampling result of epoch %03d==========" %
                    epoch)
        os.makedirs(os.path.join(asset_path, 'sampling_results',
                                 'epoch_%03d' % epoch),
                    exist_ok=True)

        for sample_id in range(pitch_idx.shape[0]):
            logger.info(("Sample %02d : " % sample_id) +
                        str(pitch_idx[sample_id][self.config["num_prime"]:self.
                                                 config["num_prime"] + 20]))
            save_path = os.path.join(
                asset_path, 'sampling_results', 'epoch_%03d' % epoch,
                'epoch%03d_sample%02d.mid' % (epoch, sample_id))
            gt_pitch = batch['pitch'].cpu().numpy()
            gt_chord = batch['chord'][:, :-1].cpu().numpy()
            sample_dict = {
                'pitch': pitch_idx[sample_id],
                'rhythm': result_dict['rhythm'][sample_id].cpu().numpy(),
                'chord': csc_matrix(gt_chord[sample_id])
            }

            with open(save_path.replace('.mid', '.pkl'), 'wb') as f_samp:
                pickle.dump(sample_dict, f_samp)
            instruments = pitch_to_midi(pitch_idx[sample_id],
                                        gt_chord[sample_id],
                                        model.frame_per_bar, save_path)
            save_instruments_as_image(save_path.replace('.mid', '.jpg'),
                                      instruments,
                                      frame_per_bar=model.frame_per_bar,
                                      num_bars=(model.max_len //
                                                model.frame_per_bar))

            # save groundtruth
            logger.info(("Groundtruth %02d : " % sample_id) + str(gt_pitch[
                sample_id,
                self.config["num_prime"]:self.config["num_prime"] + 20]))
            gt_path = os.path.join(
                asset_path, 'sampling_results', 'epoch_%03d' % epoch,
                'epoch%03d_groundtruth%02d.mid' % (epoch, sample_id))
            gt_dict = {
                'pitch': gt_pitch[sample_id, :-1],
                'rhythm': batch['rhythm'][sample_id, :-1].cpu().numpy(),
                'chord': csc_matrix(gt_chord[sample_id])
            }
            with open(gt_path.replace('.mid', '.pkl'), 'wb') as f_gt:
                pickle.dump(gt_dict, f_gt)
            gt_instruments = pitch_to_midi(gt_pitch[sample_id, :-1],
                                           gt_chord[sample_id],
                                           model.frame_per_bar, gt_path)
            save_instruments_as_image(gt_path.replace('.mid', '.jpg'),
                                      gt_instruments,
                                      frame_per_bar=model.frame_per_bar,
                                      num_bars=(model.max_len //
                                                model.frame_per_bar))

            if self.config['attention_map']:
                os.makedirs(os.path.join(asset_path, 'attention_map',
                                         'epoch_%03d' % epoch, 'RDec-Chord',
                                         'sample_%02d' % sample_id),
                            exist_ok=True)

                for head_num in range(8):
                    for l, w in enumerate(result_dict['weights_bdec']):
                        fig_w = plt.figure(figsize=(8, 8))
                        ax_w = fig_w.add_subplot(1, 1, 1)
                        heatmap_w = ax_w.pcolor(w[sample_id,
                                                  head_num].cpu().numpy(),
                                                cmap='Reds')
                        ax_w.set_xticks(np.arange(0,
                                                  self.model.module.max_len))
                        ax_w.xaxis.tick_top()
                        ax_w.set_yticks(np.arange(0,
                                                  self.model.module.max_len))
                        ax_w.set_xticklabels(rhythm_to_symbol_list(
                            result_dict['rhythm'][sample_id].cpu().numpy()),
                                             fontdict=x_fontdict)
                        chord_symbol_list = [''] * pitch_idx.shape[1]
                        for t in sorted(
                                chord_array_to_dict(
                                    gt_chord[sample_id]).keys()):
                            chord_symbol_list[t] = chord_array_to_dict(
                                gt_chord[sample_id])[t].tolist()
                        ax_w.set_yticklabels(chord_to_symbol_list(
                            gt_chord[sample_id]),
                                             fontdict=y_fontdict)
                        ax_w.invert_yaxis()
                        plt.savefig(
                            os.path.join(
                                asset_path, 'attention_map',
                                'epoch_%03d' % epoch, 'RDec-Chord',
                                'sample_%02d' % sample_id,
                                'epoch%03d_RDec-Chord_sample%02d_head%02d_layer%02d.jpg'
                                % (epoch, sample_id, head_num, l)))
                        plt.close()
Code example #20
def train(model, optimizer, tokenizer, train_data, valid_data, args):
    logger.info("Training starts!")
    os.makedirs(args.model_dir, exist_ok=True)

    train_dataset = QueryDataset(train_data)
    train_data_loader = DataLoader(
        train_dataset,
        sampler=RandomSampler(train_dataset),
        batch_size=args.bsz,
        num_workers=args.num_workers,
        collate_fn=lambda x: collate_fn(x, tokenizer, args.sample,
                                        args.max_seq_len))

    valid_dataset = QueryDataset(valid_data)
    valid_data_loader = DataLoader(
        valid_dataset,
        sampler=SequentialSampler(valid_dataset),
        batch_size=args.bsz,
        num_workers=args.num_workers,
        collate_fn=lambda x: collate_fn(x, tokenizer, args.sample,
                                        args.max_seq_len))

    n_batch = (len(train_dataset) - 1) // args.bsz + 1
    logger.info(f"  Number of training batch: {n_batch}")
    if args.eval_interval is None:
        args.eval_interval = n_batch

    try:
        best_valid_loss = float('inf')
        model.train()
        params = get_params(model)
        train_logger = TrainLogger()
        train_logger_part = TrainLogger()
        step = 0
        for epoch in range(1, args.n_epochs + 1):
            logger.info(f"Epoch {epoch:2d}")
            for batch in train_data_loader:
                step += 1
                batch = tuple(t.to(device) for t in batch)
                loss, items = calc_loss(model, batch)
                loss.backward()
                nn.utils.clip_grad_norm_(params, args.clip)
                optimizer.step()
                optimizer.zero_grad()
                train_logger.add(*items)
                train_logger_part.add(*items)

                if step % args.log_interval == 0:
                    logger.info(
                        f"  step {step:8d} | {train_logger_part.print_str(True)}"
                    )
                    train_logger_part.init()

                if step % args.eval_interval == 0:
                    start_eval = time.time()
                    logger.info('-' * 90)
                    train_loss, train_str = train_logger.average(
                    ), train_logger.print_str()
                    logger.info(f"| step {step:8d} | train | {train_str}")

                    # evaluate valid loss, ppl
                    with torch.no_grad():
                        valid_loss, valid_str = evaluate(
                            model, valid_data_loader, args.eval_n_steps)
                    logger.info(f"| step {step:8d} | valid | {valid_str}")
                    if valid_loss[0] < best_valid_loss:
                        model_save(args.model_dir, model, optimizer)
                        logger.info(">>>>> Saving model (new best validation)")
                        best_valid_loss = valid_loss[0]
                    logger.info('-' * 90)

                    model.train()
                    train_logger.init()
                    train_logger_part.start += time.time() - start_eval

    except KeyboardInterrupt:
        logger.info('-' * 90)
        logger.info('  Exiting from training early')