Example #1
def greedy_test(args):
    """ Test function """

    # load vocabulary
    vocab = torch.load(args.vocab)

    # build model
    translator = Transformer(args, vocab)
    translator.eval()

    # load parameters
    translator.load_state_dict(torch.load(args.decode_model_path))
    if args.cuda:
        translator = translator.cuda()

    test_data = read_corpus(args.decode_from_file, source="src")
    # one prediction buffer per sentence: ['<BOS>', '<PAD>', '<PAD>', ...]
    # (a list comprehension avoids sharing one inner list across all sentences)
    pred_data = [[
        constants.PAD_WORD if i else constants.BOS_WORD
        for i in range(args.decode_max_steps)
    ] for _ in range(len(test_data))]

    output_file = codecs.open(args.decode_output_file, "w", encoding="utf-8")
    for test, pred in zip(test_data, pred_data):
        pred_output = [constants.PAD_WORD] * args.decode_max_steps
        test_var = to_input_variable([test], vocab.src, cuda=args.cuda)

        # only need one time
        enc_output = translator.encode(test_var[0], test_var[1])
        for i in range(args.decode_max_steps):
            pred_var = to_input_variable([pred[:i + 1]],
                                         vocab.tgt,
                                         cuda=args.cuda)

            scores = translator.translate(enc_output, test_var[0], pred_var)

            _, argmax_idxs = torch.max(scores, dim=-1)
            one_step_idx = argmax_idxs[-1].item()

            pred_output[i] = vocab.tgt.id2word[one_step_idx]
            if (one_step_idx
                    == constants.EOS) or (i == args.decode_max_steps - 1):
                print("[Source] %s" % " ".join(test))
                print("[Predict] %s" % " ".join(pred_output[:i]))
                print()

                output_file.write(" ".join(pred_output[:i]) + "\n")
                output_file.flush()
                break
            pred[i + 1] = vocab.tgt.id2word[one_step_idx]

    output_file.close()
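
The snippet above decodes one token at a time with a plain argmax. Below is a minimal, self-contained sketch of the same greedy strategy, assuming only a callable that returns next-token logits; the names `step_fn`, `bos_id`, and `eos_id` are illustrative and not part of the example above.

import torch


def greedy_decode(step_fn, bos_id, eos_id, max_steps=20):
    """Repeatedly append the argmax token until EOS or max_steps is reached.

    step_fn(prefix) -> 1D tensor of next-token logits for the given prefix.
    """
    prefix = [bos_id]
    for _ in range(max_steps):
        logits = step_fn(torch.tensor(prefix))
        next_id = int(torch.argmax(logits, dim=-1).item())
        if next_id == eos_id:
            break
        prefix.append(next_id)
    return prefix[1:]  # drop the leading <BOS>


# toy check: a fake "model" that emits token 3 three times, then EOS (id 1)
def fake_step(prefix):
    logits = torch.zeros(5)
    logits[3 if len(prefix) < 4 else 1] = 1.0
    return logits


print(greedy_decode(fake_step, bos_id=0, eos_id=1))  # -> [3, 3, 3]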
Example #2
    def __init__(self, model_source, rewrite_len=30, beam_size=4, debug=False):
        self.beam_size = beam_size
        self.rewrite_len = rewrite_len
        self.debug = debug

        model_source = torch.load(model_source,
                                  map_location=lambda storage, loc: storage)
        self.dict = model_source["word2idx"]
        self.idx2word = {v: k for k, v in model_source["word2idx"].items()}
        self.args = args = model_source["settings"]
        torch.manual_seed(args.seed)
        model = Transformer(args)
        model.load_state_dict(model_source['model'])
        self.model = model.eval()
Example #3
def init_training(args):
    """ Initialize training process """

    # load vocabulary
    vocab = torch.load(args.vocab)

    # build model
    transformer = Transformer(args, vocab)

    # if finetune
    if args.finetune:
        print("[Finetune] %s" % args.finetune_model_path)
        transformer.load_state_dict(torch.load(args.finetune_model_path))

    # vocab_mask for masking padding
    vocab_mask = torch.ones(len(vocab.tgt))
    vocab_mask[vocab.tgt[constants.PAD_WORD]] = 0

    # loss object
    cross_entropy_loss = nn.CrossEntropyLoss(weight=vocab_mask,
                                             size_average=False)

    if args.cuda:
        transformer = transformer.cuda()
        cross_entropy_loss = cross_entropy_loss.cuda()

    if args.optimizer == "Warmup_Adam":
        optimizer = ScheduledOptim(
            torch.optim.Adam(transformer.get_trainable_parameters(),
                             betas=(0.9, 0.98),
                             eps=1e-09), args.d_model, args.n_warmup_steps)

    if args.optimizer == "Adam":
        optimizer = torch.optim.Adam(
            params=transformer.get_trainable_parameters(),
            lr=args.lr,
            betas=(0.9, 0.98),
            eps=1e-8)

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(
            params=transformer.get_trainable_parameters(), lr=args.lr)

    # multi gpus
    if torch.cuda.device_count() > 1:
        print("[Multi GPU] using", torch.cuda.device_count(), "GPUs\n")
        transformer = nn.DataParallel(transformer)

    return vocab, transformer, optimizer, cross_entropy_loss
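
The `ScheduledOptim` wrapper used for "Warmup_Adam" is not shown here. Assuming it implements the usual "Noam" warmup schedule, lr proportional to d_model^-0.5 * min(step^-0.5, step * warmup^-1.5), a rough equivalent can be sketched with the stock LambdaLR scheduler:

import torch
import torch.nn as nn


def noam_lambda(d_model, warmup_steps):
    """Multiplier for LambdaLR: d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    def _mult(step):
        step = max(step, 1)  # LambdaLR evaluates the lambda at step 0 on init
        return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)
    return _mult


model = nn.Linear(8, 8)  # stand-in for the Transformer
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=noam_lambda(d_model=512, warmup_steps=4000))

for _ in range(3):  # training-loop skeleton
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]["lr"])  # small early in warmup, peaks at warmup_steps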
Example #4
def load_model(checkpoint, device):
    model_args = checkpoint["settings"]

    model = Transformer(
        model_args["embedding_size"],
        model_args["src_vocab_size"],
        model_args["tgt_vocab_size"],
        model_args["src_pad_idx"],
        model_args["num_heads"],
        model_args["num_encoder_layers"],
        model_args["num_decoder_layers"],
        model_args["forward_expansion"],
        model_args["dropout"],
        model_args["max_len"],
        model_args["device"],
    ).to(device)

    model.load_state_dict(checkpoint["state_dict"])
    print("[Info] Trained model state loaded.")
    return model
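
For `load_model` to find the "settings" and "state_dict" keys, the training side has to write them. A hedged sketch of that counterpart, using a stand-in module instead of the real Transformer:

import torch
import torch.nn as nn

# stand-in for the trained Transformer and its constructor arguments
model = nn.Linear(16, 16)
model_args = {"embedding_size": 16}  # illustrative subset of the real settings

torch.save({"settings": model_args, "state_dict": model.state_dict()},
           "transformer_checkpoint.pth")

# mirror of load_model(): rebuild the module from "settings", then restore weights
ckpt = torch.load("transformer_checkpoint.pth", map_location="cpu")
restored = nn.Linear(ckpt["settings"]["embedding_size"],
                     ckpt["settings"]["embedding_size"])
restored.load_state_dict(ckpt["state_dict"])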
Example #5
    def __init__(self, model_source, cuda=False, beam_size=3):
        self.torch = torch.cuda if cuda else torch
        self.cuda = cuda
        self.beam_size = beam_size

        if self.cuda:
            model_source = torch.load(model_source)
        else:
            model_source = torch.load(
                model_source, map_location=lambda storage, loc: storage)
        self.src_dict = model_source["src_dict"]
        self.tgt_dict = model_source["tgt_dict"]
        self.src_idx2word = {v: k for k, v in model_source["tgt_dict"].items()}
        self.args = args = model_source["settings"]
        model = Transformer(args)
        model.load_state_dict(model_source['model'])

        if self.cuda: model = model.cuda()
        else: model = model.cpu()
        self.model = model.eval()
Example #6
hp_str = (f"{args.dataset}_subword_"
        f"{args.d_model}_{args.d_hidden}_{args.n_layers}_{args.n_heads}_"
        f"{args.drop_ratio:.3f}_{args.warmup}_{'uni_' if args.causal_enc else ''}")
logger.info(f'Starting with HPARAMS: {hp_str}')
model_name = args.models_dir + '/' + args.prefix + hp_str

# build the model
if not args.universal:
    model = Transformer(SRC, TRG, args)
else:
    model = UniversalTransformer(SRC, TRG, args)

# logger.info(str(model))
if args.load_from is not None:
    with torch.cuda.device(args.gpu):
        model.load_state_dict(torch.load(args.models_dir + '/' + args.load_from + '.pt',
        map_location=lambda storage, loc: storage.cuda()))  # load the pretrained models.


# use cuda
if args.gpu > -1:
    model.cuda(args.gpu)

# additional information
args.__dict__.update({'model_name': model_name, 'hp_str': hp_str,  'logger': logger})

# show the arg:
arg_str = "args:\n"
for w in sorted(args.__dict__.keys()):
    if w not in ("U", "V", "Freq"):
        arg_str += "{}:\t{}\n".format(w, args.__dict__[w])
logger.info(arg_str)
Example #7
model = args.model(SRC, TRG, args)
logger.info(str(model))
if args.load_from is not None:
    with torch.cuda.device(args.gpu):
        model.load_state_dict(
            torch.load('./models/' + args.load_from + '.pt',
                       map_location=lambda storage, loc: storage.cuda())
        )  # load the pretrained models.

# if using a teacher
teacher_model = None
if args.teacher is not None:
    teacher_model = Transformer(SRC, TRG, teacher_args)
    with torch.cuda.device(args.gpu):
        teacher_model.load_state_dict(
            torch.load('./models/' + args.teacher + '.pt',
                       map_location=lambda storage, loc: storage.cuda()))
    for params in teacher_model.parameters():
        params.requires_grad = False

    if (args.share_encoder) and (args.load_from is None):
        model.encoder = copy.deepcopy(teacher_model.encoder)
        for params in model.encoder.parameters():
            params.requires_grad = True

# use cuda
if args.gpu > -1:
    model.cuda(args.gpu)
    if align_table is not None:
        align_table = torch.LongTensor(align_table).cuda(args.gpu)
        align_table = Variable(align_table)
Example #8
    train_set = ASRDataset(conf.dataset.path, indices=split_indices["train"])

    model = Transformer(
        conf.dataset.n_vocab,
        conf.model.delta,
        conf.dataset.n_mels,
        conf.model.feature_channel,
        conf.model.dim,
        conf.model.dim_ff,
        conf.model.n_layer,
        conf.model.n_head,
        conf.model.dropout,
    ).to(device)

    ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage)
    model.load_state_dict(ckpt["model"])

    model.eval()

    train_loader = conf.training.dataloader.make(train_set,
                                                 collate_fn=collate_data)

    pbar = tqdm(train_loader)

    show_sample = 0
    db_i = 0

    with torch.no_grad(), lmdb.open(conf.dataset.alignment,
                                    map_size=1024**4,
                                    readahead=False) as env:
        for mels, tokens, mel_lengths, token_lengths, texts, files in pbar:
Example #9
def main(args):
    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build models
    if args.model_type == 'no_attention':
        encoder = Encoder(args.embed_size).to(device)
        decoder = Decoder(args.embed_size, args.hidden_size, len(vocab),
                          args.num_layers).to(device)

    elif args.model_type == 'attention':
        encoder = EncoderAtt(encoded_image_size=9).to(device)
        decoder = DecoderAtt(vocab, args.encoder_dim, args.hidden_size,
                             args.attention_dim, args.embed_size,
                             args.dropout_ratio, args.alpha_c).to(device)

    elif args.model_type == 'transformer':

        model = Transformer(len(vocab), args.embed_size,
                            args.transformer_layers, 8,
                            args.dropout_ratio).eval()

    else:
        print('Select model_type: attention, no_attention or transformer')

    if args.model_type != 'transformer':
        encoder = encoder.to(device)
        decoder = decoder.to(device)

        # Load the trained model parameters
        encoder.load_state_dict(
            torch.load(args.encoder_path, map_location=torch.device('cpu')))
        decoder.load_state_dict(
            torch.load(args.decoder_path, map_location=torch.device('cpu')))
    else:
        model = model.to(device)
        model.load_state_dict(
            torch.load(args.model_path, map_location=torch.device('cpu')))

    filenames = os.listdir(args.image_dir)

    predicted = {}

    for file in tqdm(filenames):
        if file == '.DS_Store':
            continue
        # Prepare an image
        image = load_image(os.path.join(args.image_dir, file), transform)
        image_tensor = image.to(device)

        if args.model_type == 'attention':
            features = encoder(image_tensor)
            sampled_ids, _ = decoder.sample(features)
            sampled_ids = sampled_ids[0].cpu().numpy()
            sampled_caption = ['<start>']
        elif args.model_type == 'no_attention':
            features = encoder(image_tensor)
            sampled_ids = decoder.sample(features)
            sampled_ids = sampled_ids[0].cpu().numpy()
            sampled_caption = ['<start>']

        elif args.model_type == 'transformer':
            e_outputs = model.encoder(image_tensor)
            max_seq_length = 20
            sampled_ids = torch.zeros(max_seq_length, dtype=torch.long)
            sampled_ids[0] = torch.LongTensor([[vocab.word2idx['<start>']]
                                               ]).to(device)

            for i in range(1, max_seq_length):

                trg_mask = np.triu(np.ones((1, i, i)), k=1).astype('uint8')
                trg_mask = Variable(torch.from_numpy(trg_mask) == 0).to(device)

                out = model.decoder(sampled_ids[:i].unsqueeze(0), e_outputs,
                                    trg_mask)

                out = model.out(out)
                out = F.softmax(out, dim=-1)
                val, ix = out[:, -1].data.topk(1)
                sampled_ids[i] = ix[0][0]

            sampled_ids = sampled_ids.cpu().numpy()
            sampled_caption = []

        # Convert word_ids to words
        #sampled_caption = []
        for word_id in sampled_ids:
            word = vocab.idx2word[word_id]
            sampled_caption.append(word)
            if word == '<end>':
                break
        sentence = ' '.join(sampled_caption)
        #print(sentence)
        predicted[file] = sentence
        #print(file, sentence)

    json.dump(predicted, open(args.predict_json, 'w'))
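
The transformer branch above builds its target mask by hand with `np.triu`. The same "no peeking ahead" mask, pulled out into a small standalone helper so the shape is easy to inspect (the `subsequent_mask` name is illustrative, not part of the original script):

import numpy as np
import torch


def subsequent_mask(size):
    """Boolean mask of shape (1, size, size): True where attention is allowed."""
    upper = np.triu(np.ones((1, size, size)), k=1).astype("uint8")
    return torch.from_numpy(upper) == 0


print(subsequent_mask(4))
# row i is True only for columns 0..i, so position i never attends to the future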
Example #10
def test(args):
    """ Decode with beam search """

    # load vocabulary
    vocab = torch.load(args.vocab)

    # build model
    translator = Transformer(args, vocab)
    translator.eval()

    # load parameters
    translator.load_state_dict(torch.load(args.decode_model_path))
    if args.cuda:
        translator = translator.cuda()

    test_data = read_corpus(args.decode_from_file, source="src")
    output_file = codecs.open(args.decode_output_file, "w", encoding="utf-8")
    for test in test_data:
        test_seq, test_pos = to_input_variable([test],
                                               vocab.src,
                                               cuda=args.cuda)
        test_seq_beam = test_seq.expand(args.decode_beam_size,
                                        test_seq.size(1))

        enc_output = translator.encode(test_seq, test_pos)
        enc_output_beam = enc_output.expand(args.decode_beam_size,
                                            enc_output.size(1),
                                            enc_output.size(2))

        beam = Beam_Search_V2(beam_size=args.decode_beam_size,
                              tgt_vocab=vocab.tgt,
                              length_alpha=args.decode_alpha)
        for i in range(args.decode_max_steps):

            # the first time for beam search
            if i == 0:
                # <BOS>
                pred_var = to_input_variable(beam.candidates[:1],
                                             vocab.tgt,
                                             cuda=args.cuda)
                scores = translator.translate(enc_output, test_seq, pred_var)
            else:
                pred_var = to_input_variable(beam.candidates,
                                             vocab.tgt,
                                             cuda=args.cuda)
                scores = translator.translate(enc_output_beam, test_seq_beam,
                                              pred_var)

            log_softmax_scores = F.log_softmax(scores, dim=-1)
            log_softmax_scores = log_softmax_scores.view(
                pred_var[0].size(0), -1, log_softmax_scores.size(-1))
            log_softmax_scores = log_softmax_scores[:, -1, :]

            is_done = beam.advance(log_softmax_scores)
            beam.update_status()

            if is_done:
                break

        print("[Source] %s" % " ".join(test))
        print("[Predict] %s" % beam.get_best_candidate())
        print()

        output_file.write(beam.get_best_candidate() + "\n")
        output_file.flush()

    output_file.close()
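
`Beam_Search_V2` itself is not included in the snippet. Below is a deliberately small, length-unnormalized beam-search sketch of the general idea, with an interface invented for illustration (`step_fn` returns next-token log-probabilities for a given prefix):

import torch


def beam_search(step_fn, bos_id, eos_id, beam_size=3, max_steps=20):
    """Tiny, length-unnormalized beam search over log-probabilities.

    step_fn(prefix) -> 1D tensor of next-token log-probabilities.
    """
    beams = [([bos_id], 0.0)]  # (token ids, cumulative log-prob)
    finished = []
    for _ in range(max_steps):
        candidates = []
        for seq, score in beams:
            log_probs = step_fn(torch.tensor(seq))
            top_lp, top_id = torch.topk(log_probs, beam_size)
            for lp, tok in zip(top_lp.tolist(), top_id.tolist()):
                candidates.append((seq + [tok], score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_size]:
            (finished if seq[-1] == eos_id else beams).append((seq, score))
        if not beams:  # every surviving hypothesis has emitted EOS
            break
    best_seq, _ = max(finished or beams, key=lambda c: c[1])
    return [t for t in best_seq if t not in (bos_id, eos_id)]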
Example #11
def main():
    """Entry point.
    """
    if torch.cuda.is_available():
        device = torch.device(torch.cuda.current_device())
        print(f"Using CUDA device {device}")
    else:
        device = None

    # Load data
    vocab = Vocab(config_data.vocab_file)
    data_hparams = {
        # "batch_size" is ignored for train since we use dynamic batching
        "batch_size": config_data.test_batch_size,
        "bos_id": vocab.bos_token_id,
        "eos_id": vocab.eos_token_id,
    }
    datasets = {
        split: data_utils.Seq2SeqData(os.path.join(
            config_data.input_dir,
            f"{config_data.filename_prefix}{split}.npy"),
                                      hparams=data_hparams,
                                      device=device)
        for split in ["train", "valid", "test"]
    }
    print(f"Training data size: {len(datasets['train'])}")
    beam_width = config_model.beam_width

    # Create logging
    tx.utils.maybe_create_dir(args.output_dir)
    logging_file = os.path.join(args.output_dir, "logging.txt")
    logger = utils.get_logger(logging_file)
    print(f"logging file is saved in: {logging_file}")

    # Create model and optimizer
    model = Transformer(config_model, config_data, vocab).to(device)

    best_results = {"score": 0, "epoch": -1}
    lr_config = config_model.lr_config
    if lr_config["learning_rate_schedule"] == "static":
        init_lr = lr_config["static_lr"]
        scheduler_lambda = lambda x: 1.0
    else:
        init_lr = lr_config["lr_constant"]
        scheduler_lambda = functools.partial(
            utils.get_lr_multiplier, warmup_steps=lr_config["warmup_steps"])
    optim = torch.optim.Adam(model.parameters(),
                             lr=init_lr,
                             betas=(0.9, 0.997),
                             eps=1e-9)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, scheduler_lambda)

    @torch.no_grad()
    def _eval_epoch(epoch, mode, print_fn=None):
        if print_fn is None:
            print_fn = print
            tqdm_leave = True
        else:
            tqdm_leave = False
        model.eval()
        eval_data = datasets[mode]
        eval_iter = tx.data.DataIterator(eval_data)
        references, hypotheses = [], []
        for batch in tqdm.tqdm(eval_iter,
                               ncols=80,
                               leave=tqdm_leave,
                               desc=f"Eval on {mode} set"):
            predictions = model(
                encoder_input=batch.source,
                beam_width=beam_width,
            )
            if beam_width == 1:
                decoded_ids = predictions[0].sample_id
            else:
                decoded_ids = predictions["sample_id"][:, :, 0]

            hypotheses.extend(h.tolist() for h in decoded_ids)
            references.extend(r.tolist() for r in batch.target_output)
        hypotheses = utils.list_strip_eos(hypotheses, vocab.eos_token_id)
        references = utils.list_strip_eos(references, vocab.eos_token_id)

        if mode == "valid":
            # Writes results to files to evaluate BLEU
            # For 'eval' mode, the BLEU is based on token ids (rather than
            # text tokens) and serves only as a surrogate metric to monitor
            # the training process
            fname = os.path.join(args.output_dir, "tmp.eval")
            hwords, rwords = [], []
            for hyp, ref in zip(hypotheses, references):
                hwords.append([str(y) for y in hyp])
                rwords.append([str(y) for y in ref])
            hwords = tx.utils.str_join(hwords)
            rwords = tx.utils.str_join(rwords)
            hyp_file, ref_file = tx.utils.write_paired_text(
                hwords,
                rwords,
                fname,
                mode="s",
                src_fname_suffix="hyp",
                tgt_fname_suffix="ref",
            )
            eval_bleu = tx.evals.file_bleu(ref_file,
                                           hyp_file,
                                           case_sensitive=True)
            logger.info("epoch: %d, eval_bleu %.4f", epoch, eval_bleu)
            print_fn(f"epoch: {epoch:d}, eval_bleu {eval_bleu:.4f}")

            if eval_bleu > best_results["score"]:
                logger.info("epoch: %d, best bleu: %.4f", epoch, eval_bleu)
                best_results["score"] = eval_bleu
                best_results["epoch"] = epoch
                model_path = os.path.join(args.output_dir,
                                          args.output_filename)
                logger.info("Saving model to %s", model_path)
                print_fn(f"Saving model to {model_path}")

                states = {
                    "model": model.state_dict(),
                    "optimizer": optim.state_dict(),
                    "scheduler": scheduler.state_dict(),
                }
                torch.save(states, model_path)

        elif mode == "test":
            # For 'test' mode, together with the commands in README.md, BLEU
            # is evaluated based on text tokens, which is the standard metric.
            fname = os.path.join(args.output_dir, "test.output")
            hwords, rwords = [], []
            for hyp, ref in zip(hypotheses, references):
                hwords.append(vocab.map_ids_to_tokens_py(hyp))
                rwords.append(vocab.map_ids_to_tokens_py(ref))
            hwords = tx.utils.str_join(hwords)
            rwords = tx.utils.str_join(rwords)
            hyp_file, ref_file = tx.utils.write_paired_text(
                hwords,
                rwords,
                fname,
                mode="s",
                src_fname_suffix="hyp",
                tgt_fname_suffix="ref",
            )
            logger.info("Test output written to file: %s", hyp_file)
            print_fn(f"Test output written to file: {hyp_file}")

    def _train_epoch(epoch: int):
        model.train()
        train_iter = tx.data.DataIterator(
            datasets["train"],
            data_utils.CustomBatchingStrategy(config_data.max_batch_tokens))

        progress = tqdm.tqdm(
            train_iter,
            ncols=80,
            desc=f"Training epoch {epoch}",
        )
        for train_batch in progress:
            optim.zero_grad()
            loss = model(
                encoder_input=train_batch.source,
                decoder_input=train_batch.target_input,
                labels=train_batch.target_output,
            )
            loss.backward()

            optim.step()
            scheduler.step()

            step = scheduler.last_epoch
            if step % config_data.display_steps == 0:
                logger.info("step: %d, loss: %.4f", step, loss)
                lr = optim.param_groups[0]["lr"]
                progress.write(f"lr: {lr:.4e} step: {step}, loss: {loss:.4}")
            if step and step % config_data.eval_steps == 0:
                _eval_epoch(epoch, mode="valid", print_fn=progress.write)
        progress.close()

    model_path = os.path.join(args.output_dir, args.output_filename)

    if args.run_mode == "train_and_evaluate":
        logger.info("Begin running with train_and_evaluate mode")
        if os.path.exists(model_path):
            logger.info("Restore latest checkpoint in %s", model_path)
            ckpt = torch.load(model_path)
            model.load_state_dict(ckpt["model"])
            optim.load_state_dict(ckpt["optimizer"])
            scheduler.load_state_dict(ckpt["scheduler"])
            _eval_epoch(0, mode="valid")

        for epoch in range(config_data.max_train_epoch):
            _train_epoch(epoch)
            _eval_epoch(epoch, mode="valid")

    elif args.run_mode in ["evaluate", "test"]:
        logger.info("Begin running with %s mode", args.run_mode)
        logger.info("Restore latest checkpoint in %s", model_path)
        ckpt = torch.load(model_path)
        model.load_state_dict(ckpt["model"])
        _eval_epoch(0, mode=("test" if args.run_mode == "test" else "valid"))

    else:
        raise ValueError(f"Unknown mode: {args.run_mode}")
Example #12
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss
                    }, model_name)

            else:
                print(print_out, end='')
                myfile.write(print_out)

            myfile.close()

    print('Training completed!')
    print('Testing started.......')
    ## Testing
    checkpoint = torch.load(model_name, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    model.eval()
    with torch.no_grad():
        running_acc_test = 0
        running_loss_test = 0
        all_pred = []
        all_labels = []
        for batch in test_iterator:
            batch.codes, batch.label = batch.codes.to(device), batch.label.to(
                device)
            output_test = model(batch.codes).squeeze(1)
            loss_test = criterion(output_test, batch.label)
            acc_test = softmax_accuracy(output_test, batch.label)
            running_acc_test += acc_test.item()
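
The `softmax_accuracy` helper is not shown in the snippet. A plausible minimal version, assuming `output` holds raw class logits and `labels` holds integer class ids:

import torch


def softmax_accuracy(output, labels):
    """Fraction of examples whose argmax class matches the label."""
    preds = torch.argmax(output, dim=-1)
    return (preds == labels).float().mean()


logits = torch.tensor([[2.0, 0.1], [0.3, 1.5]])
print(softmax_accuracy(logits, torch.tensor([0, 1])))  # tensor(1.)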
Example #13
vocab_to_json(SRC.vocab, params.word_json_file, params.src_lang)
print("write data to json finished !")

# optimizer = torch.optim.Adam(model.parameters(), lr=params.lr, betas=(0.9, 0.98), eps=1e-9)
optimizer = ScheduledOptim(
    torch.optim.Adam(model.parameters(),
                     lr=params.lr,
                     betas=(0.9, 0.98),
                     eps=1e-09), params.d_model, params.n_warmup_steps)
performance = Performance(len(TRG.vocab),
                          trg_pad,
                          is_smooth=params.is_label_smooth)
print("\nbegin training model")
if os.path.exists("models/tfs.pkl"):
    print('load previous trained model')
    model.load_state_dict(torch.load("models/tfs.pkl"))

best_loss = None
train_global_steps = 0

writer = SummaryWriter()
for epoch in range(params.epochs):
    start = time.time()
    total_loss = 0.0
    step = 0

    train_n_word_total = 0
    train_n_word_correct = 0

    test_n_word_total = 0
    test_n_word_correct = 0
Example #14
def main() -> None:
    """Entry point.
    """
    print("Start!!!")
    sys.stdout.flush()
    if args.run_mode == "train":
        train_data = MultiAlignedDataMultiFiles(config_data.train_data_params,
                                                device=device)
        #train_data = tx.data.MultiAlignedData(config_data.train_data_params, device=device)
        print("will data_iterator")
        data_iterator = tx.data.DataIterator({"train": train_data})
        print("data_iterator done")

        # Create model and optimizer
        model = Transformer(config_model, config_data, train_data.vocab('src'))
        model.to(device)
        print("device:", device)
        print("vocab src1:", train_data.vocab('src').id_to_token_map_py)
        print("vocab src2:", train_data.vocab('src').token_to_id_map_py)

        model = ModelWrapper(model, config_model.beam_width)
        if torch.cuda.device_count() > 1:
            #model = nn.DataParallel(model.cuda(), device_ids=[0, 1]).to(device)
            #model = MyDataParallel(model.cuda(), device_ids=[0, 1]).to(device)
            model = MyDataParallel(model.cuda()).to(device)

        lr_config = config_model.lr_config
        if lr_config["learning_rate_schedule"] == "static":
            init_lr = lr_config["static_lr"]
            scheduler_lambda = lambda x: 1.0
        else:
            init_lr = lr_config["lr_constant"]
            scheduler_lambda = functools.partial(
                get_lr_multiplier, warmup_steps=lr_config["warmup_steps"])
        optim = torch.optim.Adam(model.parameters(),
                                 lr=init_lr,
                                 betas=(0.9, 0.997),
                                 eps=1e-9)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optim, scheduler_lambda)

        output_dir = Path(args.output_dir)
        if not output_dir.exists():
            output_dir.mkdir()

        def _save_epoch(epoch):

            checkpoint_name = f"checkpoint{epoch}.pt"
            print(f"saveing model...{checkpoint_name}")
            torch.save(model.state_dict(), output_dir / checkpoint_name)

        def _train_epoch(epoch):
            data_iterator.switch_to_dataset('train')
            model.train()
            #model.module.train()
            #print("after model.module.train")
            sys.stdout.flush()
            step = 0
            num_steps = len(data_iterator)
            loss_stats = []
            for batch in data_iterator:
                #print("batch:", batch)
                #batch = batch.to(device)
                return_dict = model(batch)
                #return_dict = model.module.forward(batch)
                loss = return_dict['loss']
                #print("loss:", loss)
                loss = loss.mean()
                #print("loss:", loss)
                #print("loss.item():", loss.item())
                loss_stats.append(loss.item())

                optim.zero_grad()
                loss.backward()
                optim.step()
                scheduler.step()

                config_data.display = 1
                if step % config_data.display == 0:
                    avr_loss = sum(loss_stats) / len(loss_stats)
                    ppl = utils.get_perplexity(avr_loss)
                    print(
                        f"epoch={epoch}, step={step}/{num_steps}, loss={avr_loss:.4f}, ppl={ppl:.4f}, lr={scheduler.get_lr()[0]}"
                    )
                    sys.stdout.flush()
                step += 1

        print("will train")
        for i in range(config_data.num_epochs):
            print("epoch i:", i)
            sys.stdout.flush()
            _train_epoch(i)
            _save_epoch(i)

    elif args.run_mode == "test":
        test_data = tx.data.MultiAlignedData(config_data.test_data_params,
                                             device=device)
        data_iterator = tx.data.DataIterator({"test": test_data})
        print("test_data vocab src1 before load:",
              test_data.vocab('src').id_to_token_map_py)

        # Create model and optimizer
        model = Transformer(config_model, config_data, test_data.vocab('src'))

        model = ModelWrapper(model, config_model.beam_width)
        #print("state_dict:", model.state_dict())
        model_loaded = torch.load(args.load_checkpoint)
        #print("model_loaded state_dict:", model_loaded)
        model_loaded = rm_begin_str_in_keys("module.", model_loaded)
        #print("model_loaded2 state_dict:", model_loaded)

        model.load_state_dict(model_loaded)
        #model.load_state_dict(torch.load(args.load_checkpoint))
        model.to(device)

        data_iterator.switch_to_dataset('test')
        model.eval()
        print("will predict !!!")
        sys.stdout.flush()

        fo = open(args.pred_output_file, "w")
        print("test_data vocab src1:",
              test_data.vocab('src').id_to_token_map_py)
        print("test_data vocab src2:",
              test_data.vocab('src').token_to_id_map_py)
        with torch.no_grad():
            for batch in data_iterator:
                print("batch:", batch)
                return_dict = model.predict(batch)
                preds = return_dict['preds'].cpu()
                print("preds:", preds)
                pred_words = tx.data.map_ids_to_strs(preds,
                                                     test_data.vocab('src'))
                #src_words = tx.data.map_ids_to_strs(batch['src_text'], test_data.vocab('src'))
                src_words = [" ".join(sw) for sw in batch['src_text']]
                for swords, words in zip(src_words, pred_words):
                    print(str(swords) + "\t" + str(words))
                    fo.write(str(words) + "\n")
                #print(" ".join(batch.src_text) + "\t" + pred_words)
                #print(batch.src_text, pred_words)
                #fo.write(str(pred_words) + "\n")
                fo.flush()
        fo.close()

    else:
        raise ValueError(f"Unknown mode: {args.run_mode}")
Example #15
HID_DIM = 512
ENC_LAYERS = 6
DEC_LAYERS = 6
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 2048
DEC_PF_DIM = 2048
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

enc = Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM,
              ENC_DROPOUT, device)
dec = Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM,
              DEC_DROPOUT, device)
model = Transformer(enc, dec, PAD_IDX, device).to(device)
model.load_state_dict(torch.load('model.pt'))
model.eval()

sent = '中新网9月19日电据英国媒体报道,当地时间19日,苏格兰公投结果出炉,55%选民投下反对票,对独立说“不”。在结果公布前,英国广播公司(BBC)预测,苏格兰选民以55%对45%投票反对独立。'
tokens = [tok for tok in jieba.cut(sent)]
tokens = [TEXT.init_token] + tokens + [TEXT.eos_token]

src_indexes = [vocab2id.get(token, UNK_IDX) for token in tokens]
src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
src_mask = model.make_src_mask(src_tensor)

with torch.no_grad():
    enc_src = model.encoder(src_tensor, src_mask)

trg_indexes = [vocab2id[TEXT.init_token]]
Example #16
train_iter, valid_iter = data.BucketIterator.splits(
    datasets=(train_data, valid_data),
    batch_size=Config.batch_size,
    sort_key=lambda x: len(x.text),
    repeat=False,
    shuffle=True)

time_end = time.time()
print(time_end - time_start, 's to finish preprocessing')

print("vocab size", len(TEXT.vocab))
model = Transformer(Config, len(TEXT.vocab))
print(model)
model.cuda()
model.train()
model.load_state_dict(torch.load('epoch_0_0.4479.pt'))
optimizer = optim.Adam(model.parameters(), lr=Config.lr)
loss = nn.CrossEntropyLoss()
model.add_optimizer(optimizer)
model.add_loss_op(loss)
train_losses = []
val_accuracies = []
#    val_accuracy, F1_score = evaluate_model(model, dataset.val_iterator)
#    print("\tVal Accuracy: {:.4f}".format(val_accuracy))

for i in range(Config.max_epochs):
    print("\nEpoch: {}".format(i))
    train_loss, val_accuracy = model.run_epoch(train_iter, valid_iter, i)
    train_losses.append(train_loss)
    val_accuracies.append(val_accuracy)
Example #17
from my_data import MyData

my_data = MyData(params.src_lang, params.trg_lang)

model = Transformer(len(my_data.src_word2idx), len(my_data.trg_word2idx),
                    params.d_model, params.n_layers, params.heads,
                    params.dropout)
if params.is_cuda:
    model = model.cuda()

# print(model)
print('trg_vocab_len: ', len(my_data.trg_word2idx))
print('src_vocab_len: ', len(my_data.src_word2idx))

torch.save(model.state_dict(), "models/tfs.pkl")
model.load_state_dict(torch.load('models/tfs.pkl'))
model.eval()

# print(batch)
# src = batch.src.transpose(0, 1)
# trg = batch.trg.transpose(0, 1)
src_sentence = 'He was brave.'
sentence = [i for i in nltk.word_tokenize(src_sentence)]
src = my_data.turn_to_idx(sentence, params.src_lang)
src = src.unsqueeze(0)

trg_sentence = ['<sos>']
trg_input = my_data.turn_to_idx(trg_sentence, params.trg_lang)
trg_input = trg_input.unsqueeze(0)

print(src.size())
Example #18
def main():
    parser = argparse.ArgumentParser(description='Commonsense Dataset Dev')

    # Experiment params
    parser.add_argument('--mode', type=str, help='train or test mode', required=True, choices=['train', 'test'])
    parser.add_argument('--expt_dir', type=str, help='root directory to save model & summaries')
    parser.add_argument('--expt_name', type=str, help='expt_dir/expt_name: organize experiments')
    parser.add_argument('--run_name', type=str, help='expt_dir/expt_name/run_name: organize training runs')
    parser.add_argument('--test_file', type=str, default='test',
                        help='The file containing test data to evaluate in test mode.')

    # Model params
    parser.add_argument('--model', type=str, help='transformer model (e.g. roberta-base)', required=True)
    parser.add_argument('--num_layers', type=int,
                        help='Number of hidden layers in transformers (default number if not provided)', default=-1)
    parser.add_argument('--seq_len', type=int, help='tokenized input sequence length', default=256)
    parser.add_argument('--num_cls', type=int, help='model number of classes', default=2)
    parser.add_argument('--ckpt', type=str, help='path to model checkpoint .pth file')

    # Data params
    parser.add_argument('--pred_file', type=str, help='address of prediction csv file, for "test" mode',
                        default='results.csv')
    parser.add_argument('--dataset', type=str, default='com2sense')
    # Training params
    parser.add_argument('--lr', type=float, help='learning rate', default=1e-5)
    parser.add_argument('--epochs', type=int, help='number of epochs', default=100)
    parser.add_argument('--batch_size', type=int, help='batch size', default=8)
    parser.add_argument('--acc_step', type=int, help='gradient accumulation steps', default=1)
    parser.add_argument('--log_interval', type=int, help='interval size for logging training summaries', default=100)
    parser.add_argument('--save_interval', type=int, help='save model after `n` weight update steps', default=30000)
    parser.add_argument('--val_size', type=int, help='validation set size for evaluating metrics, '
                                                     'and it need to be even to get pairwise accuracy', default=2048)

    # GPU params
    parser.add_argument('--gpu_ids', type=str, help='GPU IDs (0,1,2,..) seperated by comma', default='0')
    parser.add_argument('-data_parallel',
                        help='Whether to use nn.dataparallel (currently available for BERT-based models)',
                        action='store_true')
    parser.add_argument('--use_amp', type=str2bool, help='Automatic-Mixed Precision (T/F)', default='T')
    parser.add_argument('-cpu', help='use cpu only (for test)', action='store_true')

    # Misc params
    parser.add_argument('--num_workers', type=int, help='number of worker threads for Dataloader', default=1)

    # Parse Args
    args = parser.parse_args()

    # Dataset list
    dataset_names = csv2list(args.dataset)
    print()

    # Multi-GPU
    device_ids = csv2list(args.gpu_ids, int)
    print('Selected GPUs: {}'.format(device_ids))

    # Device for loading dataset (batches)
    device = torch.device(device_ids[0])
    if args.cpu:
        device = torch.device('cpu')

    # Text-to-Text
    text2text = ('t5' in args.model)
    uniqa = ('unified' in args.model)

    assert not (text2text and args.use_amp == 'T'), 'use_amp should be F when using T5-based models.'
    # Train params
    n_epochs = args.epochs
    batch_size = args.batch_size
    lr = args.lr
    accumulation_steps = args.acc_step
    # Todo: Verify the grad-accum code (loss avging seems slightly incorrect)

    # Train
    if args.mode == 'train':
        # Ensure CUDA available for training
        assert torch.cuda.is_available(), 'No CUDA device for training!'

        # Setup train log directory
        log_dir = os.path.join(args.expt_dir, args.expt_name, args.run_name)

        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        # TensorBoard summaries setup  -->  /expt_dir/expt_name/run_name/
        writer = SummaryWriter(log_dir)

        # Train log file
        log_file = setup_logger(parser, log_dir)

        print('Training Log Directory: {}\n'.format(log_dir))

        # Dataset & Dataloader
        dataset = BaseDataset('train', tokenizer=args.model, max_seq_len=args.seq_len, text2text=text2text, uniqa=uniqa)
        train_datasets = ConcatDataset([dataset])

        dataset = BaseDataset('dev', tokenizer=args.model, max_seq_len=args.seq_len, text2text=text2text, uniqa=uniqa)
        val_datasets = ConcatDataset([dataset])

        train_loader = DataLoader(train_datasets, batch_size, shuffle=True, drop_last=True,
                                  num_workers=args.num_workers)
        val_loader = DataLoader(val_datasets, batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers)

        # In multi-dataset setups, also track dataset-specific loaders for validation metrics
        val_dataloaders = []
        if len(dataset_names) > 1:
            for val_dset in val_datasets.datasets:
                loader = DataLoader(val_dset, batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers)

                val_dataloaders.append(loader)

        # Tokenizer
        tokenizer = dataset.get_tokenizer()

        # Split sizes
        train_size = train_datasets.__len__()
        val_size = val_datasets.__len__()
        log_msg = 'Train: {} \nValidation: {}\n\n'.format(train_size, val_size)

        # Min of the total & subset size
        val_used_size = min(val_size, args.val_size)
        log_msg += 'Validation Accuracy is computed using {} samples. See --val_size\n'.format(val_used_size)

        log_msg += 'No. of Classes: {}\n'.format(args.num_cls)
        print_log(log_msg, log_file)

        # Build Model
        model = Transformer(args.model, args.num_cls, text2text, device_ids, num_layers=args.num_layers)
        if args.data_parallel and not args.ckpt:
            model = nn.DataParallel(model, device_ids=device_ids)
            device = torch.device(f'cuda:{model.device_ids[0]}')

        if not text2text:
            model.to(device)

        model.train()

        # Loss & Optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr)
        optimizer.zero_grad()

        scaler = GradScaler(enabled=args.use_amp)

        # Step & Epoch
        start_epoch = 1
        curr_step = 1
        best_val_acc = 0.0

        # Load model checkpoint file (if specified)
        if args.ckpt:
            checkpoint = torch.load(args.ckpt, map_location=device)

            # Load model & optimizer
            model.load_state_dict(checkpoint['model_state_dict'])
            if args.data_parallel:
                model = nn.DataParallel(model, device_ids=device_ids)
                device = torch.device(f'cuda:{model.device_ids[0]}')
            model.to(device)

            curr_step = checkpoint['curr_step']
            start_epoch = checkpoint['epoch']
            prev_loss = checkpoint['loss']

            log_msg = 'Resuming Training...\n'
            log_msg += 'Model successfully loaded from {}\n'.format(args.ckpt)
            log_msg += 'Training loss: {:2f} (from ckpt)\n'.format(prev_loss)

            print_log(log_msg, log_file)

        steps_per_epoch = len(train_loader)
        start_time = time()

        for epoch in range(start_epoch, start_epoch + n_epochs):
            for batch in tqdm(train_loader):
                # Load batch to device
                batch = {k: v.to(device) for k, v in batch.items()}

                with autocast(args.use_amp):
                    if text2text:
                        # Forward + Loss
                        output = model(batch)
                        loss = output[0]

                    else:
                        # Forward Pass
                        label_logits = model(batch)
                        label_gt = batch['label']

                        # Compute Loss
                        loss = criterion(label_logits, label_gt)

                if args.data_parallel:
                    loss = loss.mean()
                # Backward Pass
                loss /= accumulation_steps
                scaler.scale(loss).backward()

                if curr_step % accumulation_steps == 0:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                # Print Results - Loss value & Validation Accuracy
                if curr_step % args.log_interval == 0:
                    # Validation set accuracy
                    if val_datasets:
                        val_metrics = compute_eval_metrics(model, val_loader, device, val_used_size, tokenizer,
                                                           text2text, parallel=args.data_parallel)

                        # Reset the mode to training
                        model.train()

                        log_msg = 'Validation Accuracy: {:.2f} %  || Validation Loss: {:.4f}'.format(
                            val_metrics['accuracy'], val_metrics['loss'])

                        print_log(log_msg, log_file)

                        # Add summaries to TensorBoard
                        writer.add_scalar('Val/Loss', val_metrics['loss'], curr_step)
                        writer.add_scalar('Val/Accuracy', val_metrics['accuracy'], curr_step)

                    # Add summaries to TensorBoard
                    writer.add_scalar('Train/Loss', loss.item(), curr_step)

                    # Compute elapsed & remaining time for training to complete
                    time_elapsed = (time() - start_time) / 3600

                    log_msg = 'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f} | time elapsed: {:.2f}h |'.format(
                        epoch, n_epochs, curr_step, steps_per_epoch, loss.item(), time_elapsed)

                    print_log(log_msg, log_file)

                # Save the model
                if curr_step % args.save_interval == 0:
                    path = os.path.join(log_dir, 'model_' + str(curr_step) + '.pth')

                    state_dict = {'model_state_dict': model.state_dict(),
                                  'curr_step': curr_step, 'loss': loss.item(),
                                  'epoch': epoch, 'val_accuracy': best_val_acc}

                    torch.save(state_dict, path)

                    log_msg = 'Saving the model at the {} step to directory:{}'.format(curr_step, log_dir)
                    print_log(log_msg, log_file)

                curr_step += 1

            # Validation accuracy on the entire set
            if val_datasets:
                log_msg = '-------------------------------------------------------------------------\n'
                val_metrics = compute_eval_metrics(model, val_loader, device, val_size, tokenizer, text2text,
                                                   parallel=args.data_parallel)

                log_msg += '\nAfter {} epoch:\n'.format(epoch)
                log_msg += 'Validation Accuracy: {:.2f} %  || Validation Loss: {:.4f}\n'.format(
                    val_metrics['accuracy'], val_metrics['loss'])

                # For Multi-Dataset setup:
                if len(dataset_names) > 1:
                    # compute validation set metrics on each dataset independently
                    for loader in val_dataloaders:
                        metrics = compute_eval_metrics(model, loader, device, val_size, tokenizer, text2text,
                                                       parallel=args.data_parallel)

                        log_msg += '\n --> {}\n'.format(loader.dataset.get_classname())
                        log_msg += 'Validation Accuracy: {:.2f} %  || Validation Loss: {:.4f}\n'.format(
                            metrics['accuracy'], metrics['loss'])

                # Save best model after every epoch
                if val_metrics["accuracy"] > best_val_acc:
                    best_val_acc = val_metrics["accuracy"]

                    step = '{:.1f}k'.format(curr_step / 1000) if curr_step > 1000 else '{}'.format(curr_step)
                    filename = 'ep_{}_stp_{}_acc_{:.4f}_{}.pth'.format(
                        epoch, step, best_val_acc, args.model.replace('-', '_').replace('/', '_'))

                    path = os.path.join(log_dir, filename)
                    if args.data_parallel:
                        model_state_dict = model.module.state_dict()
                    else:
                        model_state_dict = model.state_dict()
                    state_dict = {'model_state_dict': model_state_dict,
                                  'curr_step': curr_step, 'loss': loss.item(),
                                  'epoch': epoch, 'val_accuracy': best_val_acc}

                    torch.save(state_dict, path)

                    log_msg += "\n** Best Performing Model: {:.2f} ** \nSaving weights at {}\n".format(best_val_acc,
                                                                                                       path)

                log_msg += '-------------------------------------------------------------------------\n\n'
                print_log(log_msg, log_file)

                # Reset the mode to training
                model.train()

        writer.close()
        log_file.close()

    elif args.mode == 'test':

        # Dataloader
        dataset = BaseDataset(args.test_file, tokenizer=args.model, max_seq_len=args.seq_len, text2text=text2text,
                              uniqa=uniqa)

        loader = DataLoader(dataset, batch_size, num_workers=args.num_workers)

        tokenizer = dataset.get_tokenizer()

        model = Transformer(args.model, args.num_cls, text2text, num_layers=args.num_layers)
        model.eval()
        model.to(device)

        # Load model weights
        if args.ckpt:
            checkpoint = torch.load(args.ckpt, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
        data_len = dataset.__len__()
        print('Total Samples: {}'.format(data_len))

        is_pairwise = 'com2sense' in dataset_names

        # Inference
        metrics = compute_eval_metrics(model, loader, device, data_len, tokenizer, text2text, is_pairwise=is_pairwise,
                                       is_test=True, parallel=args.data_parallel)

        df = pd.DataFrame(metrics['meta'])
        df.to_csv(args.pred_file)

        print(f'Results for model {args.model}')
        print(f'Results evaluated on file {args.test_file}')
        print('Sentence Accuracy: {:.4f}'.format(metrics['accuracy']))
        if is_pairwise:
            print('Pairwise Accuracy: {:.4f}'.format(metrics['pair_acc']))
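
A condensed, runnable sketch of the gradient-accumulation-with-AMP pattern used in the training loop above, on a dummy model and random data; dividing the loss by `accumulation_steps` keeps the effective gradient an average over the accumulated micro-batches.

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # the scaler/autocast become no-ops when disabled
model = nn.Linear(10, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=use_amp)
accumulation_steps = 4

optimizer.zero_grad()
for step in range(1, 17):  # 16 toy micro-batches
    x = torch.randn(8, 10, device=device)
    y = torch.randint(0, 2, (8,), device=device)
    with autocast(enabled=use_amp):
        loss = criterion(model(x), y) / accumulation_steps
    scaler.scale(loss).backward()  # gradients accumulate across micro-batches
    if step % accumulation_steps == 0:
        scaler.step(optimizer)  # unscales, then optimizer.step()
        scaler.update()
        optimizer.zero_grad()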
Example #19
def main():

    # 1. argparser
    opts = parse(sys.argv[1:])
    print(opts)

    # 3. visdom
    vis = visdom.Visdom(port=opts.port)
    # 4. data set
    train_set = None
    test_set = None

    # train_set = KorEngDataset(root='./data', split='train')
    train_set = KorEngDataset(root='./data', split='valid')

    # 5. data loader
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=opts.batch_size,
                                               collate_fn=train_set.collate_fn,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)

    # test_loader = torch.utils.data.DataLoader(test_set,
    #                                           batch_size=1,
    #                                           collate_fn=test_set.collate_fn,
    #                                           shuffle=False,
    #                                           num_workers=2,
    #                                           pin_memory=True)

    # 6. network
    model = Transformer(num_vocab=110000,
                        model_dim=512,
                        max_seq_len=64,
                        num_head=8,
                        num_layers=6,
                        dropout=0.1).to(device)
    model = torch.nn.DataParallel(module=model, device_ids=device_ids)

    # 7. loss
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

    # 8. optimizer
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=opts.lr,
                                momentum=opts.momentum,
                                weight_decay=opts.weight_decay)

    # 9. scheduler
    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=[30, 45],
                            gamma=0.1)

    # 10. resume
    if opts.start_epoch != 0:

        checkpoint = torch.load(
            os.path.join(opts.save_path, opts.save_file_name) +
            '.{}.pth.tar'.format(opts.start_epoch - 1),
            map_location=device)  # load the checkpoint from one epoch earlier and resume training
        model.load_state_dict(
            checkpoint['model_state_dict'])  # load model state dict
        optimizer.load_state_dict(
            checkpoint['optimizer_state_dict'])  # load optim state dict
        scheduler.load_state_dict(
            checkpoint['scheduler_state_dict'])  # load sched state dict
        print('\nLoaded checkpoint from epoch %d.\n' %
              (int(opts.start_epoch) - 1))

    else:

        print('\nNo check point to resume.. train from scratch.\n')

    # for statement
    for epoch in range(opts.start_epoch, opts.epoch):

        # 11. train
        train(epoch=epoch,
              vis=vis,
              train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              scheduler=scheduler,
              opts=opts)

        scheduler.step()
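
The resume branch above expects `model_state_dict`, `optimizer_state_dict`, and `scheduler_state_dict` keys. A hedged sketch of the saving side that produces a matching file, with stand-in modules instead of the real Transformer and `train()` loop:

import os
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import MultiStepLR

# stand-ins for the real model / optimizer / scheduler created in main()
model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = MultiStepLR(optimizer, milestones=[30, 45], gamma=0.1)
save_path, save_file_name, epoch = "./ckpt", "transformer", 0

os.makedirs(save_path, exist_ok=True)
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
    },
    os.path.join(save_path, save_file_name) + ".{}.pth.tar".format(epoch))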
Example #20
    }
    idx2word = {i: w for i, w in enumerate(tgt_vocab)}  # needed for decoding
    tgt_vocab_size = len(tgt_vocab)

    # maximum input and output sequence lengths
    src_len = 5  # enc_input max sequence length
    tgt_len = 6  # dec_input(=dec_output) max sequence length

    # convert the data into id sequences
    enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

    loader = DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs),
                        batch_size=2,
                        shuffle=True)

    # load the model
    model = Transformer()
    model.load_state_dict(
        torch.load('./save_model/epoch4_ckpt.bin', map_location='cpu'))
    print("模型加载成功...")
    model.eval()

    # Test
    enc_inputs, _, _ = next(iter(loader))
    greedy_dec_input = greedy_decoder(model,
                                      enc_inputs[0].view(1, -1),
                                      start_symbol=tgt_vocab["S"])
    predict, _, _, _ = model(enc_inputs[0].view(1, -1), greedy_dec_input)
    predict = predict.data.max(1, keepdim=True)[1]
    print(enc_inputs[0], '->', [idx2word[n.item()] for n in predict.squeeze()])
Example #21
def main(args):

    # 0. initial setting

    # set environmet
    cudnn.benchmark = True

    if not os.path.isdir('./ckpt'):
        os.mkdir('./ckpt')
    if not os.path.isdir('./results'):
        os.mkdir('./results')
    if not os.path.isdir(os.path.join('./ckpt', args.name)):
        os.mkdir(os.path.join('./ckpt', args.name))
    if not os.path.isdir(os.path.join('./results', args.name)):
        os.mkdir(os.path.join('./results', args.name))
    if not os.path.isdir(os.path.join('./results', args.name, "log")):
        os.mkdir(os.path.join('./results', args.name, "log"))

    # set logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(message)s')
    handler = logging.FileHandler("results/{}/log/{}.log".format(
        args.name, time.strftime('%c', time.localtime(time.time()))))
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.addHandler(logging.StreamHandler())
    args.logger = logger

    # set cuda
    if torch.cuda.is_available():
        args.logger.info("running on cuda")
        args.device = torch.device("cuda")
        args.use_cuda = True
    else:
        args.logger.info("running on cpu")
        args.device = torch.device("cpu")
        args.use_cuda = False

    args.logger.info("[{}] starts".format(args.name))

    # 1. load data

    args.logger.info("loading data...")
    src, tgt = load_data(args.path)

    src_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    src_vocab.load(os.path.join(args.path, 'vocab.en'))
    tgt_vocab = Vocab(init_token='<sos>',
                      eos_token='<eos>',
                      pad_token='<pad>',
                      unk_token='<unk>')
    tgt_vocab.load(os.path.join(args.path, 'vocab.de'))

    # 2. setup

    args.logger.info("setting up...")

    sos_idx = 0
    eos_idx = 1
    pad_idx = 2
    max_length = 50

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # transformer config
    d_e = 512  # embedding size
    d_q = 64  # query size (= key, value size)
    d_h = 2048  # hidden layer size in feed forward network
    num_heads = 8
    num_layers = 6  # number of encoder/decoder layers in encoder/decoder

    args.sos_idx = sos_idx
    args.eos_idx = eos_idx
    args.pad_idx = pad_idx
    args.max_length = max_length
    args.src_vocab_size = src_vocab_size
    args.tgt_vocab_size = tgt_vocab_size
    args.d_e = d_e
    args.d_q = d_q
    args.d_h = d_h
    args.num_heads = num_heads
    args.num_layers = num_layers

    model = Transformer(args)
    model.to(args.device)
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    if args.load:
        model.load_state_dict(load(args, args.ckpt))

    # 3. train / test

    if not args.test:
        # train
        args.logger.info("starting training")
        acc_val_meter = AverageMeter(name="Acc-Val (%)",
                                     save_all=True,
                                     save_dir=os.path.join(
                                         'results', args.name))
        train_loss_meter = AverageMeter(name="Loss",
                                        save_all=True,
                                        save_dir=os.path.join(
                                            'results', args.name))
        train_loader = get_loader(src['train'],
                                  tgt['train'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size,
                                  shuffle=True)
        valid_loader = get_loader(src['valid'],
                                  tgt['valid'],
                                  src_vocab,
                                  tgt_vocab,
                                  batch_size=args.batch_size)

        for epoch in range(1, 1 + args.epochs):
            spent_time = time.time()
            model.train()
            train_loss_tmp_meter = AverageMeter()
            for src_batch, tgt_batch in tqdm(train_loader):
                # src_batch: (batch x source_length), tgt_batch: (batch x target_length)
                optimizer.zero_grad()
                src_batch, tgt_batch = torch.LongTensor(src_batch).to(
                    args.device), torch.LongTensor(tgt_batch).to(args.device)
                batch = src_batch.shape[0]
                # split target batch into input and output
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]

                pred = model(src_batch.to(args.device),
                             tgt_batch_i.to(args.device))
                loss = loss_fn(pred.contiguous().view(-1, tgt_vocab_size),
                               tgt_batch_o.contiguous().view(-1))
                loss.backward()
                optimizer.step()

                # use .item() so the meter stores a float rather than a graph-holding tensor
                train_loss_tmp_meter.update(loss.item() / batch, weight=batch)

            train_loss_meter.update(train_loss_tmp_meter.avg)
            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] train loss: {:.3f} took {:.1f} seconds".format(
                    epoch, train_loss_tmp_meter.avg, spent_time))

            # validation
            model.eval()
            acc_val_tmp_meter = AverageMeter()
            spent_time = time.time()

            for src_batch, tgt_batch in tqdm(valid_loader):
                src_batch, tgt_batch = torch.LongTensor(
                    src_batch), torch.LongTensor(tgt_batch)
                tgt_batch_i = tgt_batch[:, :-1]
                tgt_batch_o = tgt_batch[:, 1:]

                with torch.no_grad():
                    pred = model(src_batch.to(args.device),
                                 tgt_batch_i.to(args.device))

                corrects, total = val_check(
                    pred.max(dim=-1)[1].cpu(), tgt_batch_o)
                acc_val_tmp_meter.update(100 * corrects / total, total)

            spent_time = time.time() - spent_time
            args.logger.info(
                "[{}] validation accuracy: {:.1f} %, took {} seconds".format(
                    epoch, acc_val_tmp_meter.avg, spent_time))
            acc_val_meter.update(acc_val_tmp_meter.avg)

            if epoch % args.save_period == 0:
                save(args, "epoch_{}".format(epoch), model.state_dict())
                acc_val_meter.save()
                train_loss_meter.save()
    else:
        # test
        args.logger.info("starting test")
        test_loader = get_loader(src['test'],
                                 tgt['test'],
                                 src_vocab,
                                 tgt_vocab,
                                 batch_size=args.batch_size)
        pred_list = []
        model.eval()

        for src_batch, tgt_batch in test_loader:
            #src_batch: (batch x source_length)
            src_batch = torch.Tensor(src_batch).long().to(args.device)
            batch = src_batch.shape[0]
            pred_batch = torch.zeros(batch, 1).long().to(args.device)
            pred_mask = torch.zeros(batch, 1).bool().to(
                args.device)  # marks whether each sentence has already ended

            with torch.no_grad():
                for _ in range(args.max_length):
                    pred = model(
                        src_batch,
                        pred_batch)  # (batch x length x tgt_vocab_size)
                    pred[:, :, pad_idx] = -1  # suppress <pad> predictions
                    pred = pred.max(dim=-1)[1][:, -1].unsqueeze(
                        -1)  # next-word prediction: (batch x 1)
                    pred = pred.masked_fill(
                        pred_mask,
                        pad_idx).long()  # fill in <pad> for sentences that already ended
                    pred_mask = torch.gt(pred.eq(eos_idx) + pred.eq(pad_idx), 0)
                    pred_batch = torch.cat([pred_batch, pred], dim=1)
                    if torch.prod(pred_mask) == 1:  # every sentence has emitted <eos>
                        break

            # close every sentence: append <eos>, or <pad> if it already ended
            pred_batch = torch.cat([
                pred_batch,
                torch.ones(batch, 1).long().to(args.device) + pred_mask.long()
            ], dim=1)
            pred_list += seq2sen(pred_batch.cpu().numpy().tolist(), tgt_vocab)

        with open('results/pred.txt', 'w', encoding='utf-8') as f:
            for line in pred_list:
                f.write('{}\n'.format(line))

        os.system(
            'bash scripts/bleu.sh results/pred.txt multi30k/test.de.atok')
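# Hedged sketch of the val_check helper used in the validation loop above
# (its implementation is not shown in this snippet): count token-level
# matches while ignoring positions where the reference is <pad> (index 2).
def val_check(pred, target, pad_idx=2):
    mask = target.ne(pad_idx)                         # score only non-pad positions
    corrects = (pred.eq(target) & mask).sum().item()
    total = mask.sum().item()
    return corrects, total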
Example #22
0
        val_dataloader_vua,
        loss_criterion,
        using_GPU,
        RNNseq_model,
        Transformer_model,
    )
    val_end_f1 = 2 * precision * recall / (precision + recall)
    if val_end_f1 > bestf1:
        bestf1 = val_end_f1
        best_model_weights[0] = copy.deepcopy(RNNseq_model.state_dict())
        best_model_weights[1] = copy.deepcopy(Transformer_model.state_dict())
        best_optimizer_dict[0] = copy.deepcopy(rnn_optimizer.state_dict())
        best_optimizer_dict[1] = copy.deepcopy(trans_optimizer.state_dict())

RNNseq_model.load_state_dict(best_model_weights[0])
Transformer_model.load_state_dict(best_model_weights[1])
rnn_optimizer.load_state_dict(best_optimizer_dict[0])
trans_optimizer.load_state_dict(best_optimizer_dict[1])

torch.save(
    {
        'epoch': num_epochs,
        'model_state_dict': best_model_weights,
        'optimizer_state_dict': best_optimizer_dict,
    }, './models/vua/model.tar')
"""
3.3
plot the training process: losses for validation and training dataset
"""
plt.figure(0)
plt.title('Loss for VUA dataset')
Example #23
0
def inference(path_to_save_predictions, forecast_window, dataloader, device,
              path_to_save_model, best_model):

    device = torch.device(device)

    model = Transformer().double().to(device)
    model.load_state_dict(torch.load(path_to_save_model + best_model))
    criterion = torch.nn.MSELoss()

    val_loss = 0
    with torch.no_grad():

        model.eval()
        for plot in range(25):

            for index_in, index_tar, _input, target, sensor_number in dataloader:

                # start from index 1 so that src lines up with target while keeping the same length as during training
                src = _input.permute(
                    1, 0,
                    2).double().to(device)[1:, :, :]  # 47, 1, 7: t1 -- t47
                target = target.permute(1, 0,
                                        2).double().to(device)  # t48 - t59

                next_input_model = src
                all_predictions = []

                for i in range(forecast_window - 1):

                    prediction = model(next_input_model,
                                       device)  # 47,1,1: t2' - t48'

                    if isinstance(all_predictions, list):  # first step: nothing predicted yet
                        all_predictions = prediction  # 47,1,1: t2' - t48'
                    else:
                        all_predictions = torch.cat(
                            (all_predictions, prediction[-1, :, :].unsqueeze(0)
                             ))  # 47+,1,1: t2' - t48', t49', t50'

                    pos_encoding_old_vals = src[
                        i + 1:, :,
                        1:]  # 46, 1, 6, pop positional encoding first value: t2 -- t47
                    pos_encoding_new_val = target[i + 1, :, 1:].unsqueeze(
                        1
                    )  # 1, 1, 6, append positional encoding of last predicted value: t48
                    pos_encodings = torch.cat(
                        (pos_encoding_old_vals, pos_encoding_new_val)
                    )  # 47, 1, 6 positional encodings matched with prediction: t2 -- t48

                    next_input_model = torch.cat(
                        (src[i + 1:, :, 0].unsqueeze(-1),
                         prediction[-1, :, :].unsqueeze(0)))  #t2 -- t47, t48'
                    next_input_model = torch.cat(
                        (next_input_model, pos_encodings),
                        dim=2)  # 47, 1, 7 input for next round

                true = torch.cat((src[1:, :, 0], target[:-1, :, 0]))
                loss = criterion(true, all_predictions[:, :, 0])
                val_loss += loss

            val_loss = val_loss / 10
            scaler = load('scalar_item.joblib')
            src_humidity = scaler.inverse_transform(src[:, :, 0].cpu())
            target_humidity = scaler.inverse_transform(target[:, :, 0].cpu())
            prediction_humidity = scaler.inverse_transform(
                all_predictions[:, :, 0].detach().cpu().numpy())
            plot_prediction(plot, path_to_save_predictions, src_humidity,
                            target_humidity, prediction_humidity,
                            sensor_number, index_in, index_tar)

        logger.info(f"Loss On Unseen Dataset: {val_loss.item()}")
Example #24
0
        # np.save(pre + '_valid_accs.npy', valid_accs)
        # np.save(pre + '_train_losses.npy', train_losses)
        # np.save(pre + '_valid_losses.npy', valid_losses)
        test_model_name = str(r) + model_name
        model = Transformer(device=device,
                            d_feature=test_data.sig_len,
                            d_model=d_model,
                            d_inner=d_inner,
                            n_layers=num_layers,
                            n_head=num_heads,
                            d_k=64,
                            d_v=64,
                            dropout=dropout,
                            class_num=class_num)
        chkpoint = torch.load(test_model_name, map_location='cuda:3')
        model.load_state_dict(chkpoint['model'])
        model = model.to(device)
        test_epoch(test_loader, device, model, test_data.__len__())

    # models = []
    # for r in range(10):
    #     test_model_name = str(r) + model_name
    #     model = Transformer(device=device, d_feature=test_data.sig_len, d_model=d_model, d_inner=d_inner,
    #                         n_layers=num_layers, n_head=num_heads, d_k=64, d_v=64, dropout=dropout,
    #                         class_num=class_num)
    #     chkpoint = torch.load(test_model_name)
    #     model.load_state_dict(chkpoint['model'])
    #     model = model.to(device)
    #     models.append(model)
    # voting_epoch(test_loader, device, models, test_data.__len__())
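# Hypothetical sketch of the voting idea in the commented-out block above
# (voting_epoch's real signature and behavior are not shown): average the
# class probabilities of several trained models and take the argmax. It
# assumes each model returns class logits for a batch of signals.
import torch
import torch.nn.functional as F

def ensemble_vote(models, batch, device):
    batch = batch.to(device)
    with torch.no_grad():
        probs = [F.softmax(m(batch), dim=-1) for m in models]
    return torch.stack(probs).mean(dim=0).argmax(dim=-1)  # (batch,) predicted classes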
Example #25
0
        default=10,
        type=int
    )
    args = parser.parse_args()

    cfg = Config.load(args.experiment)

    tokenizer = Tokenizer()
    tokenizer.load_model('tokenizer')
    model = Transformer(
        d_model=cfg.d_emb,
        h=cfg.head_num,
        encoder_N=cfg.encoder_num,
        decoder_N=cfg.decoder_num,
        vocab_size=tokenizer.vocab_len(),
        pad_token_id=tokenizer.pad_id[0],
        dropout=0.1
    )
    model_path = os.path.join('data', str(args.experiment), args.model_name)
    model.load_state_dict(torch.load(f=model_path))
    model = model.to(device)
    x = args.input
    generate(
        x=x,
        beam_width=args.width,
        device=device,
        max_seq_len=args.max_seq_len,
        model=model,
        tokenizer=tokenizer
    )
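# Illustrative sketch of a single beam-search expansion step, under the
# assumption that `generate` above ranks hypotheses by cumulative token
# log-probability; the names here are not the repo's API.
import torch

def expand_beams(log_probs, beam_scores, beam_width):
    # log_probs: (beam, vocab) next-token log-probs; beam_scores: (beam,) running scores
    total = beam_scores.unsqueeze(1) + log_probs        # (beam, vocab)
    scores, flat_idx = total.view(-1).topk(beam_width)  # keep the best beam_width continuations
    beam_idx = flat_idx // log_probs.size(1)            # which beam each continuation extends
    token_idx = flat_idx % log_probs.size(1)            # which token it appends
    return scores, beam_idx, token_idx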
Example #26
0
def main(tokenizer, src_tok_file, tgt_tok_file, train_file, val_file,
         test_file, num_epochs, batch_size, d_model, nhead, num_encoder_layers,
         num_decoder_layers, dim_feedforward, dropout, learning_rate,
         data_path, checkpoint_file, do_train):
    logging.info('Using tokenizer: {}'.format(tokenizer))

    src_tokenizer = TokenizerWrapper(tokenizer, BLANK_WORD, SEP_TOKEN,
                                     CLS_TOKEN, PAD_TOKEN, MASK_TOKEN)
    src_tokenizer.train(src_tok_file, 20000, SPECIAL_TOKENS)

    tgt_tokenizer = TokenizerWrapper(tokenizer, BLANK_WORD, SEP_TOKEN,
                                     CLS_TOKEN, PAD_TOKEN, MASK_TOKEN)
    tgt_tokenizer.train(tgt_tok_file, 20000, SPECIAL_TOKENS)

    SRC = ttdata.Field(tokenize=src_tokenizer.tokenize, pad_token=BLANK_WORD)
    TGT = ttdata.Field(tokenize=tgt_tokenizer.tokenize,
                       init_token=BOS_WORD,
                       eos_token=EOS_WORD,
                       pad_token=BLANK_WORD)

    logging.info('Loading training data...')
    train_ds, val_ds, test_ds = ttdata.TabularDataset.splits(
        path=data_path,
        format='tsv',
        train=train_file,
        validation=val_file,
        test=test_file,
        fields=[('src', SRC), ('tgt', TGT)])

    test_src_sentence = val_ds[0].src
    test_tgt_sentence = val_ds[0].tgt

    MIN_FREQ = 2
    SRC.build_vocab(train_ds.src, min_freq=MIN_FREQ)
    TGT.build_vocab(train_ds.tgt, min_freq=MIN_FREQ)

    logging.info(f'''SRC vocab size: {len(SRC.vocab)}''')
    logging.info(f'''TGT vocab size: {len(TGT.vocab)}''')

    train_iter = ttdata.BucketIterator(train_ds,
                                       batch_size=batch_size,
                                       repeat=False,
                                       sort_key=lambda x: len(x.src))
    val_iter = ttdata.BucketIterator(val_ds,
                                     batch_size=1,
                                     repeat=False,
                                     sort_key=lambda x: len(x.src))
    test_iter = ttdata.BucketIterator(test_ds,
                                      batch_size=1,
                                      repeat=False,
                                      sort_key=lambda x: len(x.src))

    source_vocab_length = len(SRC.vocab)
    target_vocab_length = len(TGT.vocab)

    model = Transformer(d_model=d_model,
                        nhead=nhead,
                        num_encoder_layers=num_encoder_layers,
                        num_decoder_layers=num_decoder_layers,
                        dim_feedforward=dim_feedforward,
                        dropout=dropout,
                        source_vocab_length=source_vocab_length,
                        target_vocab_length=target_vocab_length)
    optim = torch.optim.Adam(model.parameters(),
                             lr=learning_rate,
                             betas=(0.9, 0.98),
                             eps=1e-9)
    model = model.cuda()

    if do_train:
        train_losses, valid_losses = train(train_iter, val_iter, model, optim,
                                           num_epochs, batch_size,
                                           test_src_sentence,
                                           test_tgt_sentence, SRC, TGT,
                                           src_tokenizer, tgt_tokenizer,
                                           checkpoint_file)
    else:
        logging.info('Skipped training.')

    # Load best model and score test set
    logging.info('Loading best model.')
    model.load_state_dict(torch.load(checkpoint_file))
    model.eval()
    logging.info('Scoring the test set...')
    score_start = time.time()
    test_bleu, test_chrf = score(test_iter, model, tgt_tokenizer, SRC, TGT)
    score_time = time.time() - score_start
    logging.info(f'''Scoring complete in {score_time/60:.3f} minutes.''')
    logging.info(f'''BLEU : {test_bleu}''')
    logging.info(f'''CHRF : {test_chrf}''')
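# Hedged sketch of what a scoring helper like `score` above might compute with
# sacrebleu (the snippet does not show its implementation); hypotheses and
# references are plain detokenized strings.
import sacrebleu

def corpus_scores(hypotheses, references):
    bleu = sacrebleu.corpus_bleu(hypotheses, [references])
    chrf = sacrebleu.corpus_chrf(hypotheses, [references])
    return bleu.score, chrf.score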
Example #27
0
SRC_PAD_IDX = SRC.vocab.stoi['<pad>']
TRG_PAD_IDX = TRG.vocab.stoi['<pad>']

model = Transformer(len(SRC.vocab), len(TRG.vocab), MAX_LEN, MODEL_SIZE,
                    FF_SIZE, KEY_SIZE, VALUE_SIZE, NUM_HEADS, NUM_LAYERS,
                    DROPOUT, SRC_PAD_IDX, TRG_PAD_IDX).to(DEVICE)
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
opt = AdamWrapper(model.parameters(), MODEL_SIZE, WARMUP)

if args.train or args.continue_training:
    if args.train:
        best_val_loss = float('inf')
        with open(LOG_PATH, 'w') as f:
            f.write('')
    else:
        model.load_state_dict(torch.load(MODEL_PATH))
        with open(LOG_PATH, 'r') as f:
            val_losses = [float(line.split()[-1]) for line in f]
            best_val_loss = min(val_losses)

    print(f'best_val_loss: {best_val_loss}')

    for epoch in range(NUM_EPOCHS):
        train_loss = train(model, criterion, opt, train_iter, epoch)
        val_loss = evaluate(model, criterion, val_iter)

        with open(LOG_PATH, 'a') as f:
            f.write('{:.10f} {:.10f}\n'.format(train_loss, val_loss))

        if val_loss < best_val_loss:
            best_val_loss = val_loss
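            # The snippet ends here; a typical continuation (an assumption,
            # not part of the original) would checkpoint the improved model:
            torch.save(model.state_dict(), MODEL_PATH)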
Example #28
0
def main(args, hparams):

    # prepare data
    testset = TextMelLoader(hparams.test_files, hparams, shuffle=False)
    collate_fn = TextMelCollate(hparams.n_frames_per_step)
    test_loader = DataLoader(
        testset,
        num_workers=1,
        shuffle=False,
        batch_size=1,
        pin_memory=False,
        collate_fn=collate_fn,
    )

    # prepare model
    model = Transformer(hparams).cuda("cuda:0")
    checkpoint_restore = load_avg_checkpoint(args.checkpoint_path)
    model.load_state_dict(checkpoint_restore)
    model.eval()
    print("# total parameters:", sum(p.numel() for p in model.parameters()))

    # infer
    duration_add = 0
    with torch.no_grad():
        for i, batch in tqdm(enumerate(test_loader)):
            x, y = parse_batch(batch)

            # the start time
            start = time.perf_counter()
            (
                mel_output,
                mel_output_postnet,
                _,
                enc_attn_list,
                dec_attn_list,
                dec_enc_attn_list,
            ) = model.inference(x)

            # the end time
            duration = time.perf_counter() - start
            duration_add += duration

            # denormalize the feats and save the mels and attention plots
            mel_predict = mel_output_postnet[0]
            mel_denorm = denormalize_feats(mel_predict, hparams.dump)
            mel_path = os.path.join(args.output_infer,
                                    "{:0>3d}".format(i) + ".pt")
            torch.save(mel_denorm, mel_path)

            plot_data(
                (
                    mel_output.detach().cpu().numpy()[0],
                    mel_output_postnet.detach().cpu().numpy()[0],
                    mel_denorm.numpy(),
                ),
                i,
                args.output_infer,
            )

            plot_attn(
                enc_attn_list,
                dec_attn_list,
                dec_enc_attn_list,
                i,
                args.output_infer,
            )

        duration_avg = duration_add / (i + 1)
        print("The average inference time is: %f" % duration_avg)
Example #29
0
    transformer_model = Transformer(
        train_dataset.en_vocab_size,
        config.max_output_len,
        train_dataset.cn_vocab_size,
        config.max_output_len,
        num_layers=config.n_layers,
        model_dim=config.model_dim,
        num_heads=config.num_heads,
        ffn_dim=config.ffn_dim,
        dropout=config.dropout,
    ).to(config.device)
    print("使用模型:")
    print(transformer_model)
    total_steps = 0
    if config.load_model:
        transformer_model.load_state_dict(torch.load(config.load_model_path))
        total_steps = int(re.split('[_/.]', config.model_file)[1])

    optimizer = torch.optim.Adam(transformer_model.parameters(),
                                 lr=config.learning_rate)
    loss_function = CrossEntropyLoss(ignore_index=0)

    train_losses, val_losses, bleu_scores = [], [], []

    while total_steps < config.num_steps:
        # train the model
        transformer_model.train()
        transformer_model.zero_grad()
        losses = []
        loss_sum = 0.0
        for step in range(config.summary_steps):
Example #30
0
def main(TEXT, LABEL, train_loader, test_loader):

    # for sentiment analysis: load the pre-trained KoBERT classifier checkpoint (.pt)
    from KoBERT.Bert_model import BERTClassifier
    from kobert.pytorch_kobert import get_pytorch_kobert_model
    bertmodel, vocab = get_pytorch_kobert_model()
    sa_model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)
    sa_model.load_state_dict(torch.load('bert_SA-model.pt'))

    # print argparse
    for idx, (key, value) in enumerate(args.__dict__.items()):
        if idx == 0:
            print("\nargparse{\n", "\t", key, ":", value)
        elif idx == len(args.__dict__) - 1:
            print("\t", key, ":", value, "\n}")
        else:
            print("\t", key, ":", value)

    from model import Transformer, GradualWarmupScheduler

    # Transformer model init
    model = Transformer(args, TEXT, LABEL)
    if args.per_soft:
        sorted_path = 'sorted_model-soft.pth'
    else:
        sorted_path = 'sorted_model-rough.pth'

    # exclude <pad> tokens when computing the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=LABEL.vocab.stoi['<pad>'])

    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
    scheduler = GradualWarmupScheduler(optimizer,
                                       multiplier=8,
                                       total_epoch=args.num_epochs)

    # load the pre-trained embedding vectors
    model.src_embedding.weight.data.copy_(TEXT.vocab.vectors)
    model.trg_embedding.weight.data.copy_(LABEL.vocab.vectors)
    model.to(device)
    criterion.to(device)

    # prevent overfitting: track the best validation loss
    best_valid_loss = float('inf')

    # train
    if args.train:
        for epoch in range(args.num_epochs):
            torch.manual_seed(SEED)
            scheduler.step(epoch)
            start_time = time.time()

            # train, validation
            train_loss, train_acc = train(model, train_loader, optimizer,
                                          criterion)
            valid_loss, valid_acc = test(model, test_loader, criterion)

            # compute elapsed epoch time
            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            # torch.save(model.state_dict(), sorted_path)  # unconditional save, kept for reference
            # save the model when the current validation loss is lower than the best loss seen so far.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': valid_loss
                    }, sorted_path)
                print(
                    f'\t## SAVE valid_loss: {valid_loss:.3f} | valid_acc: {valid_acc:.3f} ##'
                )

            # print loss and acc
            print(
                f'\n\t==Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s=='
            )
            print(
                f'\t==Train Loss: {train_loss:.3f} | Train_acc: {train_acc:.3f}=='
            )
            print(
                f'\t==Valid Loss: {valid_loss:.3f} | Valid_acc: {valid_acc:.3f}==\n'
            )

    # inference
    print("\t----------성능평가----------")
    checkpoint = torch.load(sorted_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    test_loss, test_acc = test(model, test_loader, criterion)
    print(f'==test_loss : {test_loss:.3f} | test_acc: {test_acc:.3f}==')
    print("\t-----------------------------")
    while True:
        inference(device, args, TEXT, LABEL, model, sa_model)
        print("\n")