コード例 #1
0
def get_model(text_proc, args):
    """Build a Transformer captioning model and optionally restore weights.

    Args:
        text_proc: text-processing object exposing the sentence vocabulary
            via its ``vocab`` attribute.
        args: parsed command-line options; reads ``image_feat_size``,
            ``teacher_forcing`` and ``start_from`` (checkpoint path, may be
            the empty string).

    Returns:
        The constructed model, moved to GPU when CUDA is available.
    """
    sent_vocab = text_proc.vocab
    model = Transformer(dict_size=len(sent_vocab),
                        image_feature_dim=args.image_feat_size,
                        vocab=sent_vocab,
                        tf_ratio=args.teacher_forcing)

    # Initialize the weights from a checkpoint when one was given;
    # map_location keeps tensors on CPU regardless of where they were saved.
    if len(args.start_from) > 0:
        print("Initializing weights from {}".format(args.start_from))
        model.load_state_dict(torch.load(args.start_from,
                                         map_location=lambda storage, location: storage))

    # Ship the model to GPU, maybe
    if torch.cuda.is_available():
        model.cuda()
    return model
コード例 #2
0
def prediction(text):
    """Translate a Korean sentence ``text`` into English.

    Loads the tokenizer, vocabularies and trained Transformer from disk,
    decodes greedily until ``<eos>``, then re-runs the full model once on
    the generated prefix to obtain the final token sequence.

    Returns:
        The translated sentence as a space-joined string.
    """
    params = Params('config/params.json')

    # load tokenizer and torchtext Fields; context managers ensure the
    # pickle file handles are always closed (the originals were leaked)
    with open('pickles/tokenizer.pickle', 'rb') as f:
        cohesion_scores = pickle.load(f)
    tokenizer = LTokenizer(scores=cohesion_scores)

    with open('pickles/kor.pickle', 'rb') as f:
        kor = pickle.load(f)
    with open('pickles/eng.pickle', 'rb') as f:
        eng = pickle.load(f)
    eos_idx = eng.vocab.stoi['<eos>']

    # select model and load trained model
    model = Transformer(params)
    model.load_state_dict(torch.load(params.save_model))
    model.to(params.device)
    model.eval()

    # convert input into tensor and forward it through selected model
    tokenized = tokenizer.tokenize(text)
    indexed = [kor.vocab.stoi[token] for token in tokenized]

    source = torch.LongTensor(indexed).unsqueeze(0).to(params.device)  # [1, source_len]: unsqueeze to add batch size
    target = torch.zeros(1, params.max_len).type_as(source.data)       # [1, max_len]

    encoder_output = model.encoder(source)
    next_symbol = eng.vocab.stoi['<sos>']

    # Greedy decoding; ``length`` records how many tokens were actually
    # written into ``target`` before <eos> was produced.
    length = params.max_len
    for i in range(0, params.max_len):
        if next_symbol == eos_idx:
            length = i
            break
        target[0][i] = next_symbol
        decoder_output, _ = model.decoder(target, source, encoder_output)  # [1, target length, output dim]
        prob = decoder_output.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_symbol = prob.data[i].item()

    # BUG FIX: the original truncated with ``target[0][:eos_idx]`` -- the
    # <eos> *vocabulary id* -- instead of the position where decoding
    # stopped; slice by the number of generated tokens instead.
    target = target[0][:length].unsqueeze(0)

    # translation_tensor = [target length] filled with word indices
    target, attention_map = model(source, target)
    target = target.squeeze(0).max(dim=-1)[1]

    # token id 3 is filtered out (assumed to be a special token -- TODO confirm)
    reply_token = [eng.vocab.itos[token] for token in target if token != 3]
    reply = ' '.join(reply_token)
    return reply
コード例 #3
0
ファイル: predict.py プロジェクト: lih0905/NLP_Study
def predict(config):
    """Translate the Korean sentence in ``config.input`` to English and print both.

    Greedy decoding: repeatedly feeds the partial target sequence to the
    decoder and takes the argmax token at each step, then re-runs the full
    model on the completed target to read off the translation.
    """
    params = Params('config/params.json')

    # load tokenizer and torchtext Fields; context managers ensure the
    # pickle file handles are closed (the originals were leaked)
    with open('pickles/tokenizer.pickle', 'rb') as f:
        cohesion_scores = pickle.load(f)
    tokenizer = LTokenizer(scores=cohesion_scores)

    with open('pickles/kor.pickle', 'rb') as f:
        kor = pickle.load(f)
    with open('pickles/eng.pickle', 'rb') as f:
        eng = pickle.load(f)

    # select model and load trained model
    model = Transformer(params)
    model.load_state_dict(torch.load(params.save_model))
    model.to(params.device)
    model.eval()

    # renamed from ``input`` so the builtin is not shadowed
    cleaned = clean_text(config.input)

    # convert input into tensor and forward it through selected model
    tokenized = tokenizer.tokenize(cleaned)
    indexed = [kor.vocab.stoi[token] for token in tokenized]

    source = torch.LongTensor(indexed).unsqueeze(0).to(
        params.device)  # [1, source length]: unsqueeze to add batch size
    target = torch.zeros(1, params.max_len).type_as(source.data)

    encoder_output = model.encoder(source)
    next_symbol = eng.vocab.stoi['<sos>']

    # Greedy decoding loop: always runs max_len steps (no early <eos> stop).
    for i in range(0, params.max_len):
        target[0][i] = next_symbol
        dec_output = model.decoder(target, source, encoder_output)
        # dec_output = [1, target length, output dim]
        prob = dec_output.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_symbol = prob.data[i].item()

    # translation_tensor = [target length] filled with word indices
    target = model(source, target)
    target = torch.argmax(target.squeeze(0), -1)
    # [1:] drops the leading <sos> token
    translation = [eng.vocab.itos[token] for token in target][1:]

    translation = ' '.join(translation)
    print(f'kor> {config.input}')
    print(f'eng> {translation.capitalize()}')
コード例 #4
0
ファイル: translate.py プロジェクト: superMC5657/transformer
def load_model(opt, device):
    """Restore a trained Transformer from the checkpoint at ``opt.model``.

    The checkpoint stores the model hyper-parameters under ``'settings'``
    and the weights under ``'model'``; the rebuilt model is placed on
    ``device`` before the state dict is loaded.
    """
    ckpt = torch.load(opt.model, map_location=device)
    cfg = ckpt['settings']

    transformer = Transformer(
        cfg.src_vocab_size,
        cfg.trg_vocab_size,
        cfg.src_pad_idx,
        cfg.trg_pad_idx,
        trg_emb_prj_weight_sharing=cfg.proj_share_weight,
        src_emb_prj_weight_sharing=cfg.embs_share_weight,
        d_k=cfg.d_k,
        d_v=cfg.d_v,
        d_model=cfg.d_model,
        d_word_vec=cfg.d_word_vec,
        d_inner=cfg.d_inner_hid,
        n_layers=cfg.n_layers,
        n_head=cfg.n_head,
        dropout=cfg.dropout,
    ).to(device)

    transformer.load_state_dict(ckpt['model'])
    print('[Info] Trained model state loaded.')
    return transformer
コード例 #5
0
  # Adam over all model parameters; the learning rate comes from the
  # project constants module.
  optimizer = torch.optim.Adam(model.parameters(), lr=const.LEARNING_RATE)
  # Cross-entropy that ignores target padding positions.
  cross_entropy = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

  print(f'The model has {model_utils.count_parameters(model):,} trainable parameters')

  # Bundle optimizer/criterion/device into the project's Trainer helper.
  trainer = Trainer(
    const=const,
    optimizer=optimizer,
    criterion=cross_entropy,
    device=device,
  )

  trainer.train(
    model=model,
    train_iterator=train_iterator,
    valid_iterator=valid_iterator,
  )

  # Reload the best checkpoint before testing.
  # NOTE(review): assumes trainer.train() saved to this exact path -- confirm.
  model.load_state_dict(torch.load('./checkpoints/model.best.pt'))
  
  trainer.test(model=model, test_iterator=test_iterator)

  bleu_score = inference_utils.calculate_bleu(
    data=test_data,
    source_field=source,
    target_field=target,
    model=model,
    device=device,
  )

  print(f'BLEU score = {bleu_score*100:.2f}')
コード例 #6
0
def main(args):
    """Beam-search decode the hanja->korean test set and pickle the results.

    Loads vocabularies and the preprocessed test split, restores a trained
    Transformer from ``args.checkpoint_path``, runs batched beam search
    (beam width ``args.beam_size``) over the test loader, and writes the
    raw and SentencePiece-decoded predictions plus labels to a pickle.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # tqdm-friendly logger so log lines do not break the progress bar
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(logging.Formatter(" %(asctime)s - %(message)s"))
    logger.addHandler(handler)
    logger.propagate = False

    write_log(logger, "Load data")

    def load_data(args):
        """Read word2id dicts and test index pairs from the preprocessed pickles."""
        # gc is disabled around pickle.load purely to speed up loading
        gc.disable()
        with open(f"{args.preprocessed_data_path}/hanja_korean_word2id.pkl",
                  "rb") as f:
            data = pickle.load(f)
            hanja_word2id = data['hanja_word2id']
            korean_word2id = data['korean_word2id']

        with open(f"{args.preprocessed_data_path}/preprocessed_test.pkl",
                  "rb") as f:
            data = pickle.load(f)
            test_hanja_indices = data['hanja_indices']
            test_korean_indices = data['korean_indices']

        gc.enable()
        write_log(logger, "Finished loading data!")
        return hanja_word2id, korean_word2id, test_hanja_indices, test_korean_indices

    hanja_word2id, korean_word2id, test_hanja_indices, test_korean_indices = load_data(
        args)
    hanja_vocab_num = len(hanja_word2id)
    korean_vocab_num = len(korean_word2id)

    hk_dataset = HanjaKoreanDataset(test_hanja_indices,
                                    test_korean_indices,
                                    min_len=args.min_len,
                                    src_max_len=args.src_max_len,
                                    trg_max_len=args.trg_max_len)
    # drop_last=True keeps every batch exactly hk_batch_size, which the
    # fixed-size beam bookkeeping below relies on
    hk_loader = DataLoader(hk_dataset,
                           drop_last=True,
                           batch_size=args.hk_batch_size,
                           num_workers=4,
                           prefetch_factor=4,
                           pin_memory=True)
    write_log(logger, f"hanja-korean: {len(hk_dataset)}, {len(hk_loader)}")
    del test_hanja_indices, test_korean_indices

    write_log(logger, "Build model")
    model = Transformer(hanja_vocab_num,
                        korean_vocab_num,
                        pad_idx=args.pad_idx,
                        bos_idx=args.bos_idx,
                        eos_idx=args.eos_idx,
                        src_max_len=args.src_max_len,
                        trg_max_len=args.trg_max_len,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        num_decoder_layer=args.num_decoder_layer,
                        num_mask_layer=args.num_mask_layer)

    model.load_state_dict(
        torch.load(args.checkpoint_path, map_location=device)['model'])
    # These sub-modules are unused for decoding; drop them to free memory
    # before moving the model to the device.
    model.src_output_linear = None
    model.src_output_linear2 = None
    model.src_output_norm = None
    model.mask_encoders = None
    model = model.to(device)
    model.eval()

    write_log(logger, "Load SentencePiece model")
    parser = spm.SentencePieceProcessor()
    parser.Load(os.path.join(args.preprocessed_data_path, 'm_korean.model'))

    predicted_list = list()
    label_list = list()
    # Row offset of each batch element inside the flattened
    # (batch_size * beam_size) beam dimension.
    every_batch = torch.arange(0,
                               args.beam_size * args.hk_batch_size,
                               args.beam_size,
                               device=device)
    # Pre-build a causal mask for every possible target length.
    tgt_masks = {
        l: model.generate_square_subsequent_mask(l, device)
        for l in range(1, args.trg_max_len + 1)
    }

    with torch.no_grad():
        for src_sequences, trg_sequences in tqdm(hk_loader):
            src_sequences = src_sequences.to(device)
            label_list.extend(trg_sequences.tolist())

            # Encoding
            # encoder_out: (src_seq, batch_size, d_model)
            # src_key_padding_mask: (batch_size, src_seq)
            encoder_out = model.src_embedding(src_sequences).transpose(0, 1)
            src_key_padding_mask = (src_sequences == model.pad_idx)
            for encoder in model.encoders:
                encoder_out = encoder(
                    encoder_out, src_key_padding_mask=src_key_padding_mask)

            # Expanding: replicate encoder state beam_size times per example
            # encoder_out: (src_seq, batch_size * k, d_model)
            # src_key_padding_mask: (batch_size * k, src_seq)
            src_seq_size = encoder_out.size(0)
            src_key_padding_mask = src_key_padding_mask.view(
                args.hk_batch_size, 1, -1).repeat(1, args.beam_size, 1)
            src_key_padding_mask = src_key_padding_mask.view(-1, src_seq_size)
            encoder_out = encoder_out.view(-1, args.hk_batch_size, 1,
                                           args.d_model).repeat(
                                               1, 1, args.beam_size, 1)
            encoder_out = encoder_out.view(src_seq_size, -1, args.d_model)

            # Scores save vector & decoding list setting
            scores_save = torch.zeros(args.beam_size * args.hk_batch_size,
                                      1,
                                      device=device)
            top_k_scores = torch.zeros(args.beam_size * args.hk_batch_size,
                                       1,
                                       device=device)
            complete_seqs = dict()
            complete_ind = set()

            # Decoding start token setting: every beam starts with <bos>
            seqs = torch.tensor([[model.bos_idx]],
                                dtype=torch.long,
                                device=device)
            seqs = seqs.repeat(args.beam_size * args.hk_batch_size,
                               1).contiguous()

            for step in range(model.trg_max_len):
                # Decoder setting
                # tgt_mask: (out_seq)
                # tgt_key_padding_mask: (batch_size * k, out_seq)
                tgt_mask = tgt_masks[seqs.size(1)]
                tgt_key_padding_mask = (seqs == model.pad_idx)

                # Decoding sentence
                # decoder_out: (out_seq, batch_size * k, d_model)
                decoder_out = model.trg_embedding(seqs).transpose(0, 1)
                for decoder in model.decoders:
                    decoder_out = decoder(
                        decoder_out,
                        encoder_out,
                        tgt_mask=tgt_mask,
                        memory_key_padding_mask=src_key_padding_mask,
                        tgt_key_padding_mask=tgt_key_padding_mask)

                # Score calculate (only the last time step is needed)
                # scores: (batch_size * k, vocab_num)
                scores = F.gelu(model.trg_output_linear(decoder_out[-1]))
                scores = model.trg_output_linear2(
                    model.trg_output_norm(scores))
                scores = F.log_softmax(scores, dim=1)

                # Repetition Penalty
                # NOTE(review): scores are log-probs (<= 0), so a penalty > 1
                # makes the previous token *less* likely; confirm the intended
                # direction for penalty values < 1.
                if step > 0 and args.repetition_penalty > 0:
                    prev_ix = next_word_inds.view(-1)
                    for index, prev_token_id in enumerate(prev_ix):
                        scores[index][prev_token_id] *= args.repetition_penalty

                # Add score
                scores = top_k_scores.expand_as(scores) + scores
                if step == 0:
                    # First step: all beams of an example are identical, so
                    # keep only one row per example before the top-k.
                    # scores: (batch_size, vocab_num)
                    # top_k_scores: (batch_size, k)
                    scores = scores[::args.beam_size]
                    scores[:, model.eos_idx] = float(
                        '-inf')  # set eos token probability zero in first step
                    top_k_scores, top_k_words = scores.topk(
                        args.beam_size, 1, True, True)
                else:
                    # top_k_scores: (batch_size * k, out_seq)
                    top_k_scores, top_k_words = scores.view(
                        args.hk_batch_size, -1).topk(args.beam_size, 1, True,
                                                     True)

                # Previous and Next word extract: flat top-k index decomposes
                # into (beam row, vocab id)
                # seqs: (batch_size * k, out_seq + 1)
                prev_word_inds = top_k_words // korean_vocab_num
                next_word_inds = top_k_words % korean_vocab_num
                top_k_scores = top_k_scores.view(
                    args.hk_batch_size * args.beam_size, -1)
                top_k_words = top_k_words.view(
                    args.hk_batch_size * args.beam_size, -1)
                seqs = seqs[prev_word_inds.view(-1) + every_batch.unsqueeze(
                    1).repeat(1, args.beam_size).view(-1)]
                seqs = torch.cat([
                    seqs,
                    next_word_inds.view(args.beam_size * args.hk_batch_size,
                                        -1)
                ],
                                 dim=1)

                # Find and Save Complete Sequences Score (first <eos> wins
                # for each beam row)
                eos_ind = torch.where(
                    next_word_inds.view(-1) == model.eos_idx)[0]
                if len(eos_ind) > 0:
                    eos_ind = eos_ind.tolist()
                    complete_ind_add = set(eos_ind) - complete_ind
                    complete_ind_add = list(complete_ind_add)
                    complete_ind.update(eos_ind)
                    if len(complete_ind_add) > 0:
                        scores_save[complete_ind_add] = top_k_scores[
                            complete_ind_add]
                        for ix in complete_ind_add:
                            complete_seqs[ix] = seqs[ix].tolist()

            # If eos token doesn't exist in sequence, fall back to the final
            # beam state for those rows
            score_save_pos = torch.where(scores_save == 0)
            if len(score_save_pos[0]) > 0:
                for ix in score_save_pos[0].tolist():
                    complete_seqs[ix] = seqs[ix].tolist()
                scores_save[score_save_pos] = top_k_scores[score_save_pos]

            # Beam Length Normalization
            # NOTE(review): resembles the GNMT length penalty but uses
            # beam_size where the usual formulation uses a constant -- confirm.
            lp = torch.tensor([
                len(complete_seqs[i])
                for i in range(args.hk_batch_size * args.beam_size)
            ],
                              device=device)
            lp = (((lp + args.beam_size)**args.beam_alpha) /
                  ((args.beam_size + 1)**args.beam_alpha))
            scores_save = scores_save / lp.unsqueeze(1)

            # Predicted and Label processing: keep the best-scoring beam
            # of each example
            ind = scores_save.view(args.hk_batch_size, args.beam_size,
                                   -1).argmax(dim=1)
            ind_expand = ind.view(-1) + every_batch
            predicted_list.extend(
                [complete_seqs[i] for i in ind_expand.tolist()])

    with open(
            f'./results_beam_{args.beam_size}_{args.beam_alpha}_{args.repetition_penalty}.pkl',
            'wb') as f:
        pickle.dump(
            {
                'prediction':
                predicted_list,
                'label':
                label_list,
                'prediction_decode':
                [parser.DecodeIds(pred) for pred in predicted_list],
                'label_decode':
                [parser.DecodeIds(label) for label in label_list]
            }, f)
コード例 #7
0
def main(args):
    """Evaluate masked-token reconstruction (top-1/5/10 accuracy) on hanja data.

    Restores the trained Transformer, strips the decoder-side modules,
    predicts masked positions over the test loader, prints the averaged
    accuracies and pickles them to ``./mask_result.pkl``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_data(args):
        """Read word2id dicts and hanja test indices from the preprocessed pickles."""
        # gc is disabled around pickle.load purely to speed up loading
        gc.disable()
        with open(f"{args.preprocessed_data_path}/hanja_korean_word2id.pkl",
                  "rb") as f:
            data = pickle.load(f)
            hanja_word2id = data['hanja_word2id']
            korean_word2id = data['korean_word2id']

        with open(f"{args.preprocessed_data_path}/preprocessed_test.pkl",
                  "rb") as f:
            data = pickle.load(f)
            test_hanja_indices = data['hanja_indices']
            test_additional_hanja_indices = data['additional_hanja_indices']

        gc.enable()
        return hanja_word2id, korean_word2id, test_hanja_indices, test_additional_hanja_indices

    hanja_word2id, korean_word2id, hanja_indices, additional_hanja_indices = load_data(
        args)
    hanja_vocab_num = len(hanja_word2id)
    korean_vocab_num = len(korean_word2id)

    print('Loader and Model Setting...')
    h_dataset = HanjaDataset(hanja_indices,
                             additional_hanja_indices,
                             hanja_word2id,
                             min_len=args.min_len,
                             src_max_len=args.src_max_len)
    h_loader = DataLoader(h_dataset,
                          drop_last=True,
                          batch_size=args.batch_size,
                          num_workers=4,
                          prefetch_factor=4)

    model = Transformer(hanja_vocab_num,
                        korean_vocab_num,
                        pad_idx=args.pad_idx,
                        bos_idx=args.bos_idx,
                        eos_idx=args.eos_idx,
                        src_max_len=args.src_max_len,
                        trg_max_len=args.trg_max_len,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        num_decoder_layer=args.num_decoder_layer,
                        num_mask_layer=args.num_mask_layer)

    model.load_state_dict(
        torch.load(args.checkpoint_path, map_location='cpu')['model'])
    # Decoder-side modules are unused for mask reconstruction; drop them
    # to free memory before moving the model to the device.
    model.decoders = None
    model.trg_embedding = None
    model.trg_output_linear = None
    model.trg_output_linear2 = None
    model.trg_output_norm = None
    model = model.to(device)
    model.eval()

    masking_acc = defaultdict(float)

    with torch.no_grad():
        for inputs, labels in h_loader:
            # Setting: non-pad label positions are the masked targets
            inputs = inputs.to(device)
            labels = labels.to(device)
            masked_position = labels != args.pad_idx
            masked_labels = labels[masked_position].contiguous().view(
                -1).unsqueeze(1)
            total_mask_count = masked_labels.size(0)

            # Prediction, output: Batch * Length * Vocab
            pred = model.reconstruct_predict(inputs,
                                             masked_position=masked_position)
            _, pred = pred.topk(10, 1, True, True)

            # Top1, 5, 10: fraction of masked tokens whose label appears
            # among the top-k predictions for that position
            masking_acc[1] += (torch.sum(
                masked_labels == pred[:, :1]).item()) / total_mask_count
            masking_acc[5] += (torch.sum(
                masked_labels == pred[:, :5]).item()) / total_mask_count
            masking_acc[10] += (torch.sum(
                masked_labels == pred).item()) / total_mask_count

    # Average the per-batch accuracies over the number of batches.
    for key in masking_acc.keys():
        masking_acc[key] /= len(h_loader)

    for key, value in masking_acc.items():
        print(f'Top {key} Accuracy: {value:.4f}')

    with open('./mask_result.pkl', 'wb') as f:
        pickle.dump(masking_acc, f)
コード例 #8
0
            # Log validation BCE and accuracy for this evaluation step.
            writer.add_scalars('bce', {'bce_valid': bce_val}, n_iter)
            # NOTE(review): tag 'acc_train' is written with acc_val -- the
            # label looks mismatched; confirm against the evaluation code.
            writer.add_scalars('accuracy', {'acc_train': acc_val}, n_iter)
            model = model.train()
            # Skip early-stopping checks for the "experts" model until 13k iters.
            if (config.model == "experts" and n_iter < 13000):
                continue
            # Early stopping on validation perplexity: save/remember the best
            # weights, stop after 3 evaluations without improvement.
            if (ppl_val <= best_ppl):
                best_ppl = ppl_val
                patient = 0
                model.save_model(best_ppl, n_iter, 0, 0, bleu_score_g,
                                 bleu_score_b)
                weights_best = deepcopy(model.state_dict())
            else:
                patient += 1
            if (patient > 2): break
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

## TESTING
# Restore the best weights seen during training before final evaluation.
model.load_state_dict({name: weights_best[name] for name in weights_best})
model.eval()
model.epoch = 100
loss_test, ppl_test, bce_test, acc_test, bleu_score_g, bleu_score_b = evaluate(
    model, data_loader_tst, ty="test", max_dec_step=50)

# Write a one-line, tab-separated summary of the test metrics.
file_summary = config.save_path + "summary.txt"
with open(file_summary, 'w') as the_file:
    the_file.write("EVAL\tLoss\tPPL\tAccuracy\tBleu_g\tBleu_b\n")
    the_file.write("{}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.2f}\t{:.2f}\n".format(
        "test", loss_test, ppl_test, acc_test, bleu_score_g, bleu_score_b))
コード例 #9
0
def main():
    '''
    Parse training options, build the dataloaders and a Transformer
    initialized from ``checkpoints/pretrained.chkpt``, then train.

    Usage:
    python train.py -data_pkl m30k_deen_shr.pkl -log m30k_deen_shr -embs_share_weight -proj_share_weight -label_smoothing -save_model trained -b 256 -warmup 128000

    Raises:
        ValueError: when neither a log path nor a save path is given, or
            when no dataset source is specified.
    '''

    parser = argparse.ArgumentParser()

    parser.add_argument('-data_pkl',
                        default=None)  # all-in-1 data pickle or bpe field

    parser.add_argument('-train_path', default=None)  # bpe encoded data
    parser.add_argument('-val_path', default=None)  # bpe encoded data

    parser.add_argument('-epoch', type=int, default=10)
    parser.add_argument('-b', '--batch_size', type=int, default=2048)

    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-warmup', '--n_warmup_steps', type=int, default=4000)

    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default=None)
    parser.add_argument('-save_mode',
                        type=str,
                        choices=['all', 'best'],
                        default='best')

    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model  # word-embedding size is tied to d_model

    # FIX: a bare ``raise`` outside an except block raises a confusing
    # "No active exception to re-raise" RuntimeError; raise explicitly.
    if not opt.log and not opt.save_model:
        raise ValueError('No experiment result will be saved: '
                         'give -log and/or -save_model.')

    if opt.batch_size < 2048 and opt.n_warmup_steps <= 4000:
        print('[Warning] The warmup steps may be not enough.\n' \
              '(sz_b, warmup) = (2048, 4000) is the official setting.\n' \
              'Using smaller batch w/o longer warmup may cause ' \
              'the warmup stage ends with only little data trained.')

    device = torch.device('cuda' if opt.cuda else 'cpu')

    # ========= Loading Dataset =========#

    if all((opt.train_path, opt.val_path)):
        training_data, validation_data = prepare_dataloaders_from_bpe_files(
            opt, device)
    elif opt.data_pkl:
        training_data, validation_data = prepare_dataloaders(opt, device)
    else:
        # FIX: explicit exception instead of a bare ``raise``.
        raise ValueError('Specify either -data_pkl or both '
                         '-train_path and -val_path.')

    print(opt)

    transformer = Transformer(opt.src_vocab_size,
                              opt.trg_vocab_size,
                              src_pad_idx=opt.src_pad_idx,
                              trg_pad_idx=opt.trg_pad_idx,
                              trg_emb_prj_weight_sharing=opt.proj_share_weight,
                              src_emb_prj_weight_sharing=opt.embs_share_weight,
                              d_k=opt.d_k,
                              d_v=opt.d_v,
                              d_model=opt.d_model,
                              d_word_vec=opt.d_word_vec,
                              d_inner=opt.d_inner_hid,
                              n_layers=opt.n_layers,
                              n_head=opt.n_head,
                              dropout=opt.dropout).to(device)
    # Warm-start from the pretrained checkpoint before fine-tuning.
    model_path = 'checkpoints/pretrained.chkpt'
    checkpoint = torch.load(model_path, map_location=device)
    transformer.load_state_dict(checkpoint['model'])
    optimizer = ScheduledOptim(
        optim.Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-09),
        2.0, opt.d_model, opt.n_warmup_steps)

    train(transformer, training_data, validation_data, optimizer, device, opt)
コード例 #10
0
ファイル: generate_samples.py プロジェクト: zequnl/PAML
# Meta-test loop: for each test persona, briefly adapt the model on that
# persona's dialogs, then restore the pre-adaptation weights (MAML-style
# evaluation).
print("Test model", config.model)
model = Transformer(p.vocab, model_file_path=config.save_path, is_eval=False)
# get persona map
filename = 'data/ConvAI2/test_persona_map'
with open(filename, 'rb') as f:
    persona_map = pickle.load(f)

# generate
iterations = 11  # adaptation steps per dialog
weights_original = deepcopy(model.state_dict())
tasks = p.get_personas('test')
for per in tqdm(tasks):
    num_of_dialog = p.get_num_of_dialog(persona=per, split='test')
    for val_dial_index in range(num_of_dialog):
        train_iter, val_iter = p.get_data_loader(persona=per,
                                                 batch_size=config.batch_size,
                                                 split='test',
                                                 fold=val_dial_index)
        # Flatten and de-duplicate the persona sentences for this speaker.
        persona = []
        for ppp in persona_map[per]:
            persona += ppp
        persona = list(set(persona))
        do_learning(model,
                    train_iter,
                    val_iter,
                    iterations=iterations,
                    persona=persona)
        # Reset to the original weights before the next dialog.
        model.load_state_dict(
            {name: weights_original[name]
             for name in weights_original})
コード例 #11
0
    # Dataset/dataloader and optimization setup for training.
    dataset = CustomDataset(src_lines, trg_lines, tokenizer, config)
    data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # index 0 is padding (see padding_idx=0 below)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                     mode='min',
                                                     patience=2)

    # Optionally resume from the newest saved weight file.
    train_continue = False
    plus_epoch = 30
    if train_continue:
        weights = glob.glob('./weight/transformer_*')
        # NOTE(review): relies on glob's lexicographic order to find the
        # latest epoch; breaks once epoch numbers mix digit widths -- confirm.
        last_epoch = int(weights[-1].split('_')[-1])
        weight_path = weights[-1].replace('\\', '/')
        print('weight info of last epoch', weight_path)
        model.load_state_dict(torch.load(weight_path))
        total_epoch = last_epoch + plus_epoch
    else:
        last_epoch = 0
        total_epoch = plus_epoch

    model.train()
    # NOTE(review): the loop runs plus_epoch times, not up to total_epoch.
    for epoch in range(plus_epoch):
        epoch_loss = 0
        for iteration, data in enumerate(data_loader):
            encoder_inputs, decoder_inputs, targets = data
            optimizer.zero_grad()
            logits, _ = model(encoder_inputs, decoder_inputs)
            logits = logits.contiguous().view(-1, trg_vocab_size)
            targets = targets.contiguous().view(-1)
            loss = criterion(logits, targets)
    # NOTE(review): no loss.backward()/optimizer.step() follows the loss
    # computation above -- this snippet appears truncated or fused with the
    # separate inference snippet below.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = TransformerConfig(src_vocab_size=src_vocab_size,
                               trg_vocab_size=trg_vocab_size,
                               hidden_size=512,
                               num_hidden_layers=6,
                               num_attn_head=8,
                               hidden_act='gelu',
                               device=device,
                               feed_forward_size=2048,
                               padding_idx=0,
                               share_embeddings=True,
                               enc_max_seq_length=128,
                               dec_max_seq_length=128)

    model = Transformer(config).to(config.device)

    model_path = './weight/transformer_30'
    model.load_state_dict(torch.load(model_path))

    # Korean test sentences (numbers and phone digits written out in words).
    sentences = ['책 한 권을 빌리시게 되면 지금으로부터 사주 동안 빌릴 수 있습니다. 할인해주세요.',
                 '이벤트 할인은 일일 일회 제한이며 십퍼센트 할인이 가능하며 중복 할인은 적용되지 않습니다.',
                 '해당 상품은 만이천팔백원입니다.',
                 '번호는 공일공 다시 구구공공 다시 공구팔구이고 이전에 두번 방문했습니다.',
                 '고객님의 객실은 비동 천삼백이호이고 객실키는 2개 제공됩니다.',
                 '가랑비에 옷 젖는 줄 모른다.']
    for s in sentences:
        result = predict(config, tokenizer, model, s)
        # NOTE(review): 'reuslt' typo in the printed label (runtime string,
        # deliberately left unchanged here).
        print('predict reuslt:', result)
コード例 #13
0
ファイル: MAML.py プロジェクト: cstghitpku/PAML
                persona=tasks_iter.__next__(),
                batch_size=config.batch_size,
                split='train')
        #before first update
        v_loss, v_ppl = do_evaluation(meta_net, val_iter)
        train_loss_before.append(math.exp(v_loss))
        # Update fast nets
        val_loss, v_ppl = do_learning_fix_step(
            meta_net, train_iter, val_iter, iterations=config.meta_iteration)
        train_loss_meta.append(math.exp(val_loss.item()))
        batch_loss += val_loss
        # log

        # reset
        meta_net.load_state_dict(
            {name: weights_original[name]
             for name in weights_original})

    writer.add_scalars('loss_before',
                       {'train_loss_before': np.mean(train_loss_before)},
                       meta_iteration)
    writer.add_scalars('loss_meta',
                       {'train_loss_meta': np.mean(train_loss_meta)},
                       meta_iteration)

    # meta Update
    if (config.meta_optimizer == 'noam'):
        meta_optimizer.optimizer.zero_grad()
    else:
        meta_optimizer.zero_grad()
    batch_loss /= meta_batch_size
コード例 #14
0
def training(args):
    """Train a Transformer translation model with AMP and checkpointing.

    Loads preprocessed source/target index data, builds train/valid
    dataloaders, instantiates the model, optimizer, scheduler and AMP scaler,
    optionally resumes from a saved checkpoint, then runs the train/valid
    loop and keeps the checkpoint with the best validation accuracy.

    Args:
        args: argparse.Namespace carrying data paths, model hyper-parameters
            and training options (batch size, scheduler choice, resume flag,
            print frequency, ...).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(
        logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Data open
    write_log(logger, "Load data...")
    gc.disable()  # pause the GC while the large pickle is deserialized
    with open(os.path.join(args.preprocess_path, 'processed.pkl'), 'rb') as f:
        data_ = pickle.load(f)
        train_src_indices = data_['train_src_indices']
        valid_src_indices = data_['valid_src_indices']
        train_trg_indices = data_['train_trg_indices']
        valid_trg_indices = data_['valid_trg_indices']
        src_word2id = data_['src_word2id']
        trg_word2id = data_['trg_word2id']
        src_vocab_num = len(src_word2id)
        trg_vocab_num = len(trg_word2id)
        del data_
    gc.enable()
    write_log(logger, "Finished loading data!")

    # 2) Dataloader setting
    dataset_dict = {
        'train':
        CustomDataset(train_src_indices,
                      train_trg_indices,
                      min_len=args.min_len,
                      src_max_len=args.src_max_len,
                      trg_max_len=args.trg_max_len),
        'valid':
        CustomDataset(valid_src_indices,
                      valid_trg_indices,
                      min_len=args.min_len,
                      src_max_len=args.src_max_len,
                      trg_max_len=args.trg_max_len),
    }
    dataloader_dict = {
        'train':
        DataLoader(dataset_dict['train'],
                   drop_last=True,
                   batch_size=args.batch_size,
                   shuffle=True,
                   pin_memory=True,
                   num_workers=args.num_workers),
        'valid':
        DataLoader(dataset_dict['valid'],
                   drop_last=False,
                   batch_size=args.batch_size,
                   shuffle=False,
                   pin_memory=True,
                   num_workers=args.num_workers)
    }
    write_log(
        logger,
        f"Total number of trainingsets  iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}"
    )

    #===================================#
    #===========Train setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, 'Instantiating model...')
    model = Transformer(
        src_vocab_num=src_vocab_num,
        trg_vocab_num=trg_vocab_num,
        pad_idx=args.pad_id,
        bos_idx=args.bos_id,
        eos_idx=args.eos_id,
        d_model=args.d_model,
        d_embedding=args.d_embedding,
        n_head=args.n_head,
        dim_feedforward=args.dim_feedforward,
        num_common_layer=args.num_common_layer,
        num_encoder_layer=args.num_encoder_layer,
        num_decoder_layer=args.num_decoder_layer,
        src_max_len=args.src_max_len,
        trg_max_len=args.trg_max_len,
        dropout=args.dropout,
        embedding_dropout=args.embedding_dropout,
        trg_emb_prj_weight_sharing=args.trg_emb_prj_weight_sharing,
        emb_src_trg_weight_sharing=args.emb_src_trg_weight_sharing,
        parallel=args.parallel)
    model.train()
    model = model.to(device)
    # Causal mask for the (shifted) decoder input of length trg_max_len - 1;
    # reused for every batch since decoder inputs are fixed-length here.
    tgt_mask = model.generate_square_subsequent_mask(args.trg_max_len - 1,
                                                     device)

    # 2) Optimizer & Learning rate scheduler setting
    optimizer = optimizer_select(model, args)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    scaler = GradScaler()

    # 3) Model resume
    start_epoch = 0
    if args.resume:
        write_log(logger, 'Resume model...')
        checkpoint = torch.load(
            os.path.join(args.save_path, 'checkpoint.pth.tar'))
        # The epoch loop below starts at start_epoch + 1, so storing the
        # saved epoch directly resumes at the next one. (The old
        # "checkpoint['epoch'] + 1" here, combined with that "+ 1" in the
        # loop, silently skipped one whole epoch on every resume.)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        scaler.load_state_dict(checkpoint['scaler'])
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    best_val_acc = 0
    best_epoch = 0  # guards the "still better" log before any improvement

    write_log(logger, 'Training start!')

    for epoch in range(start_epoch + 1, args.num_epochs + 1):
        start_time_e = time()
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                # Print-frequency counter; explicitly reset each training
                # phase instead of relying on short-circuit evaluation at
                # i == 0 to define it before first use.
                freq = 0
            if phase == 'valid':
                write_log(logger, 'Validation start...')
                val_loss = 0
                val_acc = 0
                model.eval()
            for i, (src, trg) in enumerate(
                    tqdm(dataloader_dict[phase],
                         bar_format='{l_bar}{bar:30}{r_bar}{bar:-2b}')):

                # Optimizer setting
                optimizer.zero_grad(set_to_none=True)

                # Input, output setting
                src = src.to(device, non_blocking=True)
                trg = trg.to(device, non_blocking=True)

                # Targets are the decoder input shifted left by one token;
                # padding positions are masked out of the loss entirely.
                trg_sequences_target = trg[:, 1:]
                non_pad = trg_sequences_target != args.pad_id
                trg_sequences_target = trg_sequences_target[
                    non_pad].contiguous().view(-1)

                # Train
                if phase == 'train':

                    # Loss calculate (forward under autocast for mixed precision)
                    with autocast():
                        predicted = model(src,
                                          trg[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        predicted = predicted.view(-1, predicted.size(-1))
                        loss = label_smoothing_loss(predicted,
                                                    trg_sequences_target,
                                                    args.pad_id)

                    # Unscale before clipping so the norm is computed on the
                    # true gradients, then step through the scaler.
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                    # Print loss value only during training. The last-batch
                    # check uses len(...) - 1 because enumerate never reaches
                    # len(...) itself (the old "== len(...)" never fired).
                    if i == 0 or freq == args.print_freq or i == len(
                            dataloader_dict['train']) - 1:
                        acc = (predicted.max(dim=1)[1] == trg_sequences_target
                               ).sum() / len(trg_sequences_target)
                        iter_log = "[Epoch:%03d][%03d/%03d] train_loss:%03.3f | train_acc:%03.2f%% | learning_rate:%1.6f | spend_time:%02.2fmin" % \
                            (epoch, i, len(dataloader_dict['train']),
                            loss.item(), acc*100, optimizer.param_groups[0]['lr'],
                            (time() - start_time_e) / 60)
                        write_log(logger, iter_log)
                        freq = 0
                    freq += 1

                # Validation
                if phase == 'valid':
                    with torch.no_grad():
                        predicted = model(src,
                                          trg[:, :-1],
                                          tgt_mask,
                                          non_pad_position=non_pad)
                        loss = F.cross_entropy(predicted, trg_sequences_target)
                    val_loss += loss.item()
                    val_acc += (predicted.max(dim=1)[1] == trg_sequences_target
                                ).sum() / len(trg_sequences_target)

            if phase == 'valid':
                val_loss /= len(dataloader_dict[phase])
                val_acc /= len(dataloader_dict[phase])
                # Step validation-driven schedulers once per epoch on the
                # *averaged* loss. (Previously they were stepped on every
                # batch against a running sum, which distorts
                # ReduceLROnPlateau's plateau detection.)
                if args.scheduler == 'reduce_valid':
                    scheduler.step(val_loss)
                if args.scheduler == 'lambda':
                    scheduler.step()
                write_log(logger, 'Validation Loss: %3.3f' % val_loss)
                write_log(logger,
                          'Validation Accuracy: %3.2f%%' % (val_acc * 100))
                if val_acc > best_val_acc:
                    write_log(logger, 'Checkpoint saving...')
                    # Save to the same path the resume branch loads from;
                    # the old code wrote f'checkpoint_{args.parallel}.pth.tar'
                    # into the CWD, so --resume could never find it.
                    torch.save(
                        {
                            'epoch': epoch,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'scaler': scaler.state_dict()
                        }, os.path.join(args.save_path, 'checkpoint.pth.tar'))
                    best_val_acc = val_acc
                    best_epoch = epoch
                else:
                    else_log = f'Still {best_epoch} epoch accuracy({round(best_val_acc.item()*100, 2)})% is better...'
                    write_log(logger, else_log)

    # 3) Print results
    print(f'Best Epoch: {best_epoch}')
    print(f'Best Accuracy: {round(best_val_acc.item(), 2)}')
コード例 #15
0
def main(proc_id, args):
    """Inference worker: repeatedly fetch batches of Hanja sentences from an
    HTTP API, translate them with batched beam search on GPU ``proc_id``, and
    post the predictions back until the server reports the job is finished.

    Args:
        proc_id: CUDA device index this worker binds to.
        args: namespace with API url, SentencePiece/checkpoint paths, model
            hyper-parameters and beam-search options (beam_size, beam_alpha,
            repetition_penalty, batch_size, min/max lengths, ...).
    """
    # Target-side SentencePiece model supplies the vocab size and the
    # special-token ids used throughout decoding.
    trg_sp = spm.SentencePieceProcessor()
    trg_sp.Load(args.spm_trg_path)
    trg_vocab_num = trg_sp.piece_size()
    bos_id = trg_sp.bos_id()
    eos_id = trg_sp.eos_id()
    pad_id = trg_sp.pad_id()
    # Source vocabulary (token -> id mapping) is served by the API.
    src_vocab = requests.get(f'{args.api_url}/getMetaData').json()['src_vocab']
    unk_id = src_vocab['<unk>']

    device = torch.device(f"cuda:{proc_id}")
    model = Transformer(len(src_vocab),
                        trg_vocab_num,
                        pad_idx=pad_id,
                        bos_idx=bos_id,
                        eos_idx=eos_id,
                        src_max_len=args.src_max_len,
                        trg_max_len=args.trg_max_len,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        num_decoder_layer=args.num_decoder_layer,
                        num_mask_layer=args.num_mask_layer)

    model.load_state_dict(
        torch.load(args.checkpoint_path, map_location=device)['model'])
    # Drop source-side output heads and mask encoders -- presumably these
    # sub-modules are unused at translation time and this frees their memory.
    # TODO(review): confirm against the project's Transformer definition.
    model.src_output_linear = None
    model.src_output_linear2 = None
    model.src_output_norm = None
    model.mask_encoders = None
    model = model.to(device)
    model = model.eval()

    # Precompute one causal (square subsequent) mask per possible target
    # length so the decode loop can just index by current sequence length.
    tgt_masks = {
        l: model.generate_square_subsequent_mask(l, device)
        for l in range(1, args.trg_max_len + 1)
    }

    while True:
        # One unit of work: a file's worth of sentences from the server.
        data = requests.get(f'{args.api_url}/getData').json()
        pred_data = {'file': data['file'], 'content': []}
        parsed_ids = []
        for d in data['content']:
            # Character-level encoding of the Hanja string; unknown chars
            # fall back to <unk>.
            parsed_id = [src_vocab.get(c, unk_id) for c in d['hanja']]
            if args.min_len <= len(parsed_id) <= args.src_max_len:
                # Right-pad with zeros to the fixed source length.
                input_id = np.zeros(args.src_max_len, dtype=np.int64)
                input_id[:len(parsed_id)] = parsed_id
                parsed_ids.append(input_id)
                pred_data['content'].append(d)

        num_iter = ceil(len(parsed_ids) / args.batch_size)
        batch_size_ = args.batch_size
        predicted_num = 0

        with torch.no_grad():
            # Offsets of each source sentence's first beam row inside the
            # flattened (batch * beam) dimension.
            batch_indices = torch.arange(0,
                                         args.beam_size * args.batch_size,
                                         args.beam_size,
                                         device=device)
            for iter_ in range(num_iter):
                iter_time = time()
                src_sequences = parsed_ids[iter_ *
                                           args.batch_size:(iter_ + 1) *
                                           args.batch_size]

                # scores_save holds the final score of each finished beam;
                # top_k_scores tracks the running log-prob of live beams.
                scores_save = torch.zeros(args.beam_size * args.batch_size,
                                          1,
                                          device=device)
                top_k_scores = torch.zeros(args.beam_size * args.batch_size,
                                           1,
                                           device=device)
                complete_seqs = dict()
                complete_ind = set()
                # Last batch may be smaller; re-size all per-batch buffers.
                if len(src_sequences) < args.batch_size:
                    batch_size_ = len(src_sequences)
                    batch_indices = torch.arange(0,
                                                 args.beam_size * batch_size_,
                                                 args.beam_size,
                                                 device=device)
                    scores_save = torch.zeros(args.beam_size * batch_size_,
                                              1,
                                              device=device)
                    top_k_scores = torch.zeros(args.beam_size * batch_size_,
                                               1,
                                               device=device)

                src_sequences = torch.cat([
                    torch.cuda.LongTensor(seq, device=device)
                    for seq in src_sequences
                ])
                src_sequences = src_sequences.view(batch_size_,
                                                   args.src_max_len)

                # Encoding
                # encoder_out: (src_seq, batch_size, d_model), src_key_padding_mask: (batch_size, src_seq)
                encoder_out = model.src_embedding(src_sequences).transpose(
                    0, 1)
                src_key_padding_mask = (src_sequences == pad_id)
                for encoder in model.encoders:
                    encoder_out = encoder(
                        encoder_out, src_key_padding_mask=src_key_padding_mask)

                # Expanding: replicate encoder state once per beam so each
                # hypothesis decodes against its own copy of the memory.
                # encoder_out: (src_seq, batch_size*k, d_model), src_key_padding_mask: (batch_size*k, src_seq)
                src_seq_size = encoder_out.size(0)
                src_key_padding_mask = src_key_padding_mask.view(
                    batch_size_, 1, -1).repeat(1, args.beam_size, 1)
                src_key_padding_mask = src_key_padding_mask.view(
                    -1, src_seq_size)
                encoder_out = encoder_out.view(-1, batch_size_, 1,
                                               args.d_model).repeat(
                                                   1, 1, args.beam_size, 1)
                encoder_out = encoder_out.view(src_seq_size, -1, args.d_model)

                # Decoding start token setting: every beam begins with <bos>.
                seqs = torch.tensor([[bos_id]],
                                    dtype=torch.long,
                                    device=device)
                seqs = seqs.repeat(args.beam_size * batch_size_,
                                   1).contiguous()

                for step in range(model.trg_max_len):
                    # Decoder setting
                    # tgt_mask: (out_seq), tgt_key_padding_mask: (batch_size * k, out_seq)
                    tgt_mask = tgt_masks[seqs.size(1)]
                    tgt_key_padding_mask = (seqs == pad_id)

                    # Decoding sentence
                    # decoder_out: (out_seq, batch_size * k, d_model)
                    decoder_out = model.trg_embedding(seqs).transpose(0, 1)
                    for decoder in model.decoders:
                        decoder_out = decoder(
                            decoder_out,
                            encoder_out,
                            tgt_mask=tgt_mask,
                            memory_key_padding_mask=src_key_padding_mask,
                            tgt_key_padding_mask=tgt_key_padding_mask)

                    # Score calculate (project last position to vocab logits,
                    # then log-softmax for additive beam scores)
                    # scores: (batch_size * k, vocab_num)
                    scores = F.gelu(model.trg_output_linear(decoder_out[-1]))
                    scores = model.trg_output_linear2(
                        model.trg_output_norm(scores))
                    scores = F.log_softmax(scores, dim=1)

                    # Repetition Penalty: scale down the log-prob of each
                    # beam's previous token. NOTE(review): scores are negative
                    # log-probs, so this discourages repeats only when
                    # repetition_penalty > 1 -- confirm the configured value.
                    if step > 0 and args.repetition_penalty > 0:
                        prev_ix = next_word_inds.view(-1)
                        for index, prev_token_id in enumerate(prev_ix):
                            scores[index][
                                prev_token_id] *= args.repetition_penalty

                    # Add score (accumulate running beam log-probs)
                    scores = top_k_scores.expand_as(scores) + scores
                    if step == 0:
                        # All beams of a sentence are identical at step 0, so
                        # keep only one row per sentence before the top-k.
                        # scores: (batch_size, vocab_num)
                        # top_k_scores: (batch_size, k)
                        scores = scores[::args.beam_size]
                        # set eos token probability zero in first step
                        scores[:, eos_id] = float('-inf')
                        top_k_scores, top_k_words = scores.topk(
                            args.beam_size, 1, True, True)
                    else:
                        # top_k_scores: (batch_size * k, out_seq)
                        top_k_scores, top_k_words = scores.view(
                            batch_size_, -1).topk(args.beam_size, 1, True,
                                                  True)

                    # Previous and Next word extract: flat indices decompose
                    # into (source beam row, vocabulary token).
                    # seqs: (batch_size * k, out_seq + 1)
                    prev_word_inds = top_k_words // trg_vocab_num
                    next_word_inds = top_k_words % trg_vocab_num
                    top_k_scores = top_k_scores.view(
                        batch_size_ * args.beam_size, -1)
                    top_k_words = top_k_words.view(
                        batch_size_ * args.beam_size, -1)
                    # Reorder live sequences to follow their parent beams,
                    # then append the chosen next tokens.
                    seqs = seqs[prev_word_inds.view(-1) +
                                batch_indices.unsqueeze(1).repeat(
                                    1, args.beam_size).view(-1)]
                    seqs = torch.cat([
                        seqs,
                        next_word_inds.view(args.beam_size * batch_size_, -1)
                    ],
                                     dim=1)

                    # Find and Save Complete Sequences Score
                    eos_ind = torch.where(next_word_inds.view(-1) == eos_id)[0]
                    if len(eos_ind) > 0:
                        eos_ind = eos_ind.tolist()
                        # Only freeze beams finishing for the first time.
                        complete_ind_add = set(eos_ind) - complete_ind
                        complete_ind_add = list(complete_ind_add)
                        complete_ind.update(eos_ind)
                        if len(complete_ind_add) > 0:
                            scores_save[complete_ind_add] = top_k_scores[
                                complete_ind_add]
                            for ix in complete_ind_add:
                                complete_seqs[ix] = seqs[ix].tolist()

                # If eos token doesn't exist in sequence, fall back to the
                # beam's state at the final step.
                score_save_pos = torch.where(scores_save == 0)
                if len(score_save_pos[0]) > 0:
                    for ix in score_save_pos[0].tolist():
                        complete_seqs[ix] = seqs[ix].tolist()
                    scores_save[score_save_pos] = top_k_scores[score_save_pos]

                # Beam Length Normalization: divide scores by a length
                # penalty so longer hypotheses are not unfairly punished.
                # NOTE(review): uses beam_size where GNMT's formula uses the
                # constant 5 -- confirm this is intentional.
                lp = torch.tensor([
                    len(complete_seqs[i])
                    for i in range(batch_size_ * args.beam_size)
                ],
                                  device=device)
                lp = (((lp + args.beam_size)**args.beam_alpha) /
                      ((args.beam_size + 1)**args.beam_alpha))
                scores_save = scores_save / lp.unsqueeze(1)

                # Predicted and Label processing: pick the best beam per
                # sentence and decode it back to text.
                ind = scores_save.view(batch_size_, args.beam_size,
                                       -1).argmax(dim=1)
                ind = (ind.view(-1) + batch_indices).tolist()
                for i in ind:
                    predicted_sequence = trg_sp.decode_ids(complete_seqs[i])
                    pred_data['content'][predicted_num][
                        'predicted_sequence'] = predicted_sequence
                    predicted_num += 1

                iter_time = time() - iter_time
                print(
                    f"{proc_id} - iter: {iter_ + 1}/{num_iter}, {iter_time:.2f}"
                )

        # Push this file's predictions back; stop when the server says done.
        res = requests.post(f'{args.api_url}/commitData',
                            json=pred_data).json()
        print(f"{proc_id} - Progress: {res['progress']}, {pred_data['file']}")
        if res['progress'] == 'finish':
            return
コード例 #16
0
ファイル: train.py プロジェクト: fhzh123/dacon_competition
def training(args):
    """Train the menu-demand Transformer regressor.

    Builds train/valid/test dataloaders from the preprocessed dataset, trains
    with native AMP (autocast + GradScaler), and keeps the checkpoint with
    the best validation RMSE.

    Args:
        args: argparse.Namespace with data paths, model hyper-parameters and
            training options (model_type, batch size, scheduler, resume, ...).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    #===================================#
    #==============Logging==============#
    #===================================#

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    handler = TqdmLoggingHandler()
    handler.setFormatter(
        logging.Formatter(" %(asctime)s - %(message)s", "%Y-%m-%d %H:%M:%S"))
    logger.addHandler(handler)
    logger.propagate = False

    #===================================#
    #============Data Load==============#
    #===================================#

    # 1) Dataloader setting
    write_log(logger, "Load data...")
    gc.disable()  # pause GC while the datasets deserialize their payloads
    dataset_dict = {
        'train': CustomDataset(data_path=args.preprocessed_path,
                               phase='train'),
        'valid': CustomDataset(data_path=args.preprocessed_path,
                               phase='valid'),
        'test': CustomDataset(data_path=args.preprocessed_path, phase='test')
    }
    unique_menu_count = dataset_dict['train'].unique_count()
    dataloader_dict = {
        'train':
        DataLoader(dataset_dict['train'],
                   drop_last=True,
                   batch_size=args.batch_size,
                   shuffle=True,
                   pin_memory=True,
                   num_workers=args.num_workers,
                   collate_fn=PadCollate()),
        'valid':
        DataLoader(dataset_dict['valid'],
                   drop_last=False,
                   batch_size=args.batch_size,
                   shuffle=False,
                   pin_memory=True,
                   num_workers=args.num_workers,
                   collate_fn=PadCollate()),
        'test':
        DataLoader(dataset_dict['test'],
                   drop_last=False,
                   batch_size=args.batch_size,
                   shuffle=False,
                   pin_memory=True,
                   num_workers=args.num_workers,
                   collate_fn=PadCollate())
    }
    gc.enable()
    write_log(
        logger,
        f"Total number of trainingsets  iterations - {len(dataset_dict['train'])}, {len(dataloader_dict['train'])}"
    )

    #===================================#
    #===========Model setting===========#
    #===================================#

    # 1) Model initiating
    write_log(logger, "Instantiating models...")
    model = Transformer(model_type=args.model_type,
                        input_size=unique_menu_count,
                        d_model=args.d_model,
                        d_embedding=args.d_embedding,
                        n_head=args.n_head,
                        dim_feedforward=args.dim_feedforward,
                        num_encoder_layer=args.num_encoder_layer,
                        dropout=args.dropout)
    model = model.train()
    model = model.to(device)

    # 2) Optimizer setting
    optimizer = optimizer_select(model, args)
    scheduler = shceduler_select(optimizer, dataloader_dict, args)
    criterion = nn.MSELoss()
    scaler = GradScaler(enabled=True)
    # NOTE: the apex `amp.initialize(model, optimizer, opt_level='O1')` call
    # was removed here. Mixing apex AMP with native autocast/GradScaler
    # applies loss scaling twice and corrupts gradient magnitudes; the
    # native AMP path used below is sufficient on its own.

    # 2) Model resume
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(os.path.join(args.model_path,
                                             'checkpoint.pth.tar'),
                                map_location='cpu')
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # Also restore the AMP scaler so loss-scale state survives resume;
        # checkpoints written below include it (guarded for older files).
        if 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
        model = model.train()
        model = model.to(device)
        del checkpoint

    #===================================#
    #=========Model Train Start=========#
    #===================================#

    best_val_rmse = 9999999
    best_epoch = start_epoch  # guards the "still better" log path

    write_log(logger, 'Train start!')

    for epoch in range(start_epoch, args.num_epochs):
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
                train_start_time = time.time()
                freq = 0
            elif phase == 'valid':
                model.eval()
                val_loss = 0
                val_rmse = 0

            for i, (src_menu, label_lunch,
                    label_supper) in enumerate(dataloader_dict[phase]):

                # Optimizer setting
                optimizer.zero_grad()

                # Input, output setting
                src_menu = src_menu.to(device, non_blocking=True)
                label_lunch = label_lunch.float().to(device, non_blocking=True)
                label_supper = label_supper.float().to(device,
                                                       non_blocking=True)

                # Model forward (gradients only in the train phase)
                with torch.set_grad_enabled(phase == 'train'):
                    with autocast(enabled=True):
                        if args.model_type == 'sep':
                            logit = model(src_menu)
                            # NOTE(review): both heads read column 0 here --
                            # looks like 'sep' models emit a single output
                            # column; confirm this is intentional.
                            logit_lunch = logit[:, 0]
                            logit_supper = logit[:, 0]
                        elif args.model_type == 'total':
                            logit = model(src_menu)
                            logit_lunch = logit[:, 0]
                            logit_supper = logit[:, 1]

                    # Loss calculate: sum of the two per-meal MSE terms
                    loss_lunch = criterion(logit_lunch, label_lunch)
                    loss_supper = criterion(logit_supper, label_supper)
                    loss = loss_lunch + loss_supper

                # Back-propagation (unscale before clipping so the gradient
                # norm is computed on true values)
                if phase == 'train':
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()

                    # Scheduler setting
                    if args.scheduler in ['constant', 'warmup']:
                        scheduler.step()
                    if args.scheduler == 'reduce_train':
                        scheduler.step(loss)

                # Print loss value. The last-batch check uses len(...) - 1
                # because enumerate never reaches len(...) itself.
                rmse_loss = torch.sqrt(loss)
                if phase == 'train':
                    if i == 0 or freq == args.print_freq or i == len(
                            dataloader_dict['train']) - 1:
                        batch_log = "[Epoch:%d][%d/%d] train_MSE_loss:%2.3f  | train_RMSE_loss:%2.3f | learning_rate:%3.6f | spend_time:%3.2fmin" \
                                % (epoch+1, i, len(dataloader_dict['train']),
                                loss.item(), rmse_loss.item(), optimizer.param_groups[0]['lr'],
                                (time.time() - train_start_time) / 60)
                        write_log(logger, batch_log)
                        freq = 0
                    freq += 1
                elif phase == 'valid':
                    val_loss += loss.item()
                    val_rmse += rmse_loss.item()

            # Epoch-level validation summary and checkpointing. (This block
            # used to sit *outside* the phase loop and only worked because
            # `phase` still held 'valid' after the loop; it now runs
            # explicitly at the end of the validation phase.)
            if phase == 'valid':
                val_loss /= len(dataloader_dict['valid'])
                val_rmse /= len(dataloader_dict['valid'])
                write_log(logger, 'Validation Loss: %3.3f' % val_loss)
                write_log(logger, 'Validation RMSE: %3.3f' % val_rmse)

                if val_rmse < best_val_rmse:
                    write_log(logger, 'Checkpoint saving...')
                    if not os.path.exists(args.save_path):
                        os.mkdir(args.save_path)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'scaler': scaler.state_dict()
                        }, os.path.join(args.save_path, 'checkpoint_cap.pth.tar'))
                    best_val_rmse = val_rmse
                    best_epoch = epoch
                else:
                    else_log = f'Still {best_epoch} epoch RMSE({round(best_val_rmse, 3)}) is better...'
                    write_log(logger, else_log)

    # 3) Report the best epoch and score (label fixed: this is RMSE, not accuracy)
    write_log(logger, f'Best Epoch: {best_epoch+1}')
    write_log(logger, f'Best RMSE: {round(best_val_rmse, 3)}')
コード例 #17
0
class Trainer:
    """Drives training, validation, and test-set evaluation of a Transformer.

    In 'train' mode the trainer holds train/valid iterators; in any other
    mode it holds only the test iterator.  The model, scheduled optimizer,
    and padding-aware cross-entropy criterion are built in either case.
    """

    def __init__(self,
                 params,
                 mode,
                 train_iter=None,
                 valid_iter=None,
                 test_iter=None):
        self.params = params

        # Keep only the data iterators relevant to the requested mode.
        if mode == 'train':
            self.train_iter = train_iter
            self.valid_iter = valid_iter
        else:
            self.test_iter = test_iter

        self.model = Transformer(self.params)
        self.model.to(self.params.device)

        # Adam wrapped in a warmup schedule keyed on hidden_dim / warm_steps.
        base_optimizer = optim.Adam(self.model.parameters(),
                                    betas=(0.9, 0.98),
                                    eps=1e-9)
        self.optimizer = ScheduledAdam(base_optimizer,
                                       hidden_dim=params.hidden_dim,
                                       warm_steps=params.warm_steps)

        # Padding positions are excluded from the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.params.pad_idx)
        self.criterion.to(self.params.device)

    def _compute_loss(self, batch):
        """Forward one (kor, eng) batch and return the cross-entropy loss tensor.

        The decoder is fed the target minus its final <eos> token; the gold
        labels are the target minus its leading <sos> token, flattened so the
        criterion sees [(batch * (target_len - 1)), output_dim] logits.
        """
        source = batch.kor
        target = batch.eng

        logits = self.model(source, target[:, :-1])[0]
        flat_logits = logits.contiguous().view(-1, logits.shape[-1])
        flat_labels = target[:, 1:].contiguous().view(-1)
        return self.criterion(flat_logits, flat_labels)

    def train(self):
        """Run the full training loop, checkpointing on best validation loss."""
        print(self.model)
        print(
            f'The model has {self.model.count_params():,} trainable parameters'
        )
        best_valid_loss = float('inf')

        for epoch in range(self.params.num_epoch):
            self.model.train()
            running_loss = 0
            start_time = time.time()

            for batch in self.train_iter:
                # Zero gradients, forward, backward, clip, step — per batch.
                self.optimizer.zero_grad()
                loss = self._compute_loss(batch)
                loss.backward()

                # Clipping guards against exploding gradients.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.params.clip)
                self.optimizer.step()

                # .item() extracts the Python scalar from a one-element tensor.
                running_loss += loss.item()

            train_loss = running_loss / len(self.train_iter)
            valid_loss = self.evaluate()

            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            # Save weights whenever validation improves.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(self.model.state_dict(), self.params.save_model)

            print(
                f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s'
            )
            print(
                f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}'
            )

    def evaluate(self):
        """Return the mean validation loss; no gradients, no weight updates."""
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in self.valid_iter:
                total_loss += self._compute_loss(batch).item()

        return total_loss / len(self.valid_iter)

    def inference(self):
        """Load the saved checkpoint and report mean loss over the test set."""
        self.model.load_state_dict(torch.load(self.params.save_model))
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in self.test_iter:
                total_loss += self._compute_loss(batch).item()

        test_loss = total_loss / len(self.test_iter)
        print(f'Test Loss: {test_loss:.3f}')