Example #1
def main():
    data_path = './data/chatbot.txt'
    voc, pairs = loadPrepareData(data_path)

    # Drop sentence pairs that contain rare (low-frequency) words
    MIN_COUNT = Config.MIN_COUNT
    pairs = trimRareWords(voc, pairs, MIN_COUNT)

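    # Pre-build one training batch (batch_size randomly sampled pairs) for every step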
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(Config.batch_size)])
                        for _ in range(Config.total_step)]

    # Word-embedding layer
    embedding = nn.Embedding(voc.num_words, Config.hidden_size)

    # Define the encoder and decoder
    encoder = EncoderRNN(Config.hidden_size, embedding, Config.encoder_n_layers, Config.dropout)
    decoder = LuongAttnDecoderRNN(Config.attn_model, embedding, Config.hidden_size,
                                  voc.num_words, Config.decoder_n_layers, Config.dropout)

    # Define the optimizers
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=Config.learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=Config.learning_rate * Config.decoder_learning_ratio)

    start_iteration = 1
    save_every = 4000   # save a checkpoint every save_every steps

    for iteration in range(start_iteration, Config.total_step + 1):
        training_batch = training_batches[iteration - 1]
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        start_time = time.time()
        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, Config.batch_size, Config.clip)

        time_str = datetime.datetime.now().isoformat()
        log_str = ("time: {}, Iteration: {}; Percent complete: {:.1f}%; "
                   "loss: {:.4f}, spend_time: {:.6f}").format(
                       time_str, iteration, iteration / Config.total_step * 100,
                       loss, time.time() - start_time)
        rainbow(log_str)

        # Save checkpoint
        if iteration % save_every == 0:
            save_path = './save_model/'
            if not os.path.exists(save_path):
                os.makedirs(save_path)

            torch.save({
                'iteration': iteration,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(save_path, '{}_{}_model.tar'.format(iteration, 'checkpoint')))
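
A checkpoint written by the loop above can later be restored with torch.load. The following is a minimal sketch; the file name 4000_checkpoint_model.tar only illustrates the naming pattern used above:

checkpoint = torch.load('./save_model/4000_checkpoint_model.tar')
encoder.load_state_dict(checkpoint['encoder'])
decoder.load_state_dict(checkpoint['decoder'])
encoder_optimizer.load_state_dict(checkpoint['en_opt'])
decoder_optimizer.load_state_dict(checkpoint['de_opt'])
embedding.load_state_dict(checkpoint['embedding'])
voc.__dict__ = checkpoint['voc_dict']
start_iteration = checkpoint['iteration'] + 1  # resume from the next step
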
Example #2
def train_iters(*,  # data: Data,
                corpus: Corpus,
                encoder: EncoderRNN,
                decoder: AttnDecoderRNN,
                device: torch.device,
                n_iters: int,
                batch_size: int,
                teacher_forcing_ratio: float,
                print_every: int = 1000,
                learning_rate: float = 0.01
                ) -> None:
    data = torch.utils.data.DataLoader(dataset=corpus, batch_size=batch_size)

    start: float = time.time()
    plot_losses: List[float] = []
    print_loss_total: float = 0  # Reset every print_every
    plot_loss_total: float = 0  # Reset every plot_every

    encoder_optimizer: Optimizer = SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer: Optimizer = SGD(decoder.parameters(), lr=learning_rate)
    #
    # training_pairs: List[ParallelTensor] = [random.choice(data.pairs).tensors(source_vocab=data.source_vocab,
    #                                                                           target_vocab=data.target_vocab,
    #                                                                           device=device)
    #                                         for _ in range(n_iters)]

    criterion: nn.NLLLoss = nn.NLLLoss(reduction='mean')  # ignore_index=corpus.characters.pad_int

    # for pair in parallel_data:
    #    print(f"src={len(pair['data'])}\ttgt={len(pair['labels'])}")

    for iteration in range(1, n_iters + 1):  # type: int

        # training_pair: ParallelTensor = training_pairs[iteration - 1]
        # input_tensor: torch.Tensor = training_pair.source   # shape: [seq_len, batch_size=1]
        # target_tensor: torch.Tensor = training_pair.target  # shape: [seq_len, batch_size=1]

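        # Iterate over mini-batches; tensors arrive as [batch_size, seq_len] and are transposed below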
        for batch in data:
            # print(f"batch['data'].shape={batch['data'].shape}\tbatch['labels'].shape{batch['labels'].shape}")
            # sys.exit()
            input_tensor: torch.Tensor = batch["data"].permute(1, 0)
            target_tensor: torch.Tensor = batch["labels"].permute(1, 0)

            actual_batch_size: int = min(batch_size, input_tensor.shape[1])

            verify_shape(tensor=input_tensor, expected=[corpus.word_tensor_length, actual_batch_size])
            verify_shape(tensor=target_tensor, expected=[corpus.label_tensor_length, actual_batch_size])

            # print(f"input_tensor.shape={input_tensor.shape}\t\ttarget_tensor.shape={target_tensor.shape}")
            # sys.exit()

            loss: float = train(input_tensor=input_tensor,
                                target_tensor=target_tensor,
                                encoder=encoder,
                                decoder=decoder,
                                encoder_optimizer=encoder_optimizer,
                                decoder_optimizer=decoder_optimizer,
                                criterion=criterion,
                                device=device,
                                max_src_length=corpus.word_tensor_length,
                                max_tgt_length=corpus.label_tensor_length,
                                batch_size=actual_batch_size,
                                start_of_sequence_symbol=corpus.characters.start_of_sequence.integer,
                                teacher_forcing_ratio=teacher_forcing_ratio)

            print_loss_total += loss
            plot_loss_total += loss

        if iteration % print_every == 0:
            print_loss_avg: float = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (time_since(since=start, percent=iteration / n_iters),
                                         iteration, iteration / n_iters * 100, print_loss_avg))
            sys.stdout.flush()
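
verify_shape is not defined in this excerpt. A minimal sketch of such a helper, assuming it only asserts that a tensor's shape matches the expected dimensions:

def verify_shape(*, tensor: torch.Tensor, expected: List[int]) -> None:
    # Hypothetical helper: fail fast if the tensor does not have the expected shape.
    if list(tensor.shape) != expected:
        raise ValueError(f"expected shape {expected}, got {list(tensor.shape)}")
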
Example #3
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Move the models to the appropriate device
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

######################################################################
# Switch to training mode, which enables dropout
encoder.train()
decoder.train()

# Initialize the optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(),
                               lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)

# Start training
print("Starting Training!")


def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer,
               decoder_optimizer, embedding, encoder_n_layers,
               decoder_n_layers, save_dir, n_iteration, batch_size,
               print_every, save_every, clip, corpus_name, loadFilename):
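
The excerpt ends at the trainIters signature. Assuming the tutorial-style variables named in that signature (model_name, voc, pairs, save_dir, n_iteration, and so on) exist in the surrounding scope, a call would look like this:

trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer,
           decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers,
           save_dir, n_iteration, batch_size, print_every, save_every,
           clip, corpus_name, loadFilename)
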
Example #4
def trainIters(learning_rate=0.001):
    epochs = 1
    plot_train_losses = []
    plot_val_losses = []
    plot_loss_total = 0  # Reset every plot_every
    hidden_size = 256
    print('------- Hypers --------\n'
          '- epochs: %i\n'
          '- learning rate: %g\n'
          '- hidden size: %i\n'
          '----------------'
          '' % (epochs, learning_rate, hidden_size))

    # set model
    vocab_size_encoder = get_vocab_size(CodeEncoder())
    vocab_size_decoder = get_vocab_size(CommentEncoder())
    print(vocab_size_encoder)
    print(vocab_size_decoder)
    print('----------------')
    # Keep the next line commented out for the first training run;
    # uncomment it to resume from a previously saved model.
    # encoder, decoder = load_model()
    encoder = EncoderRNN(vocab_size_encoder, hidden_size).to(device)
    decoder = AttnDecoderRNN(hidden_size, vocab_size_decoder,
                             dropout_p=0.1).to(device)

    # set training hypers
    criterion = nn.NLLLoss()
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    # set data
    dataLoaders = createLoaders(extras=extras, debug=True)

    # used for initial input of decoder
    # with open('dicts/comment_dict.pkl', 'rb') as pfile:
    # 	SOS_token = pickle.load(pfile)['<SOS>']
    # Since <SOS> is already prepended to the comment, the decoder no longer needs this token here.
    SOS_token = None

    # iteration
    counts = []
    best_val_loss = 100
    for eps in range(1, epochs + 1):
        print('Epoch Number', eps)
        for count, (inputs, targets) in enumerate(dataLoaders['train'], 0):
            inputs = torch.LongTensor(inputs[0])
            targets = torch.LongTensor(targets[0])
            inputs, targets = inputs.to(device), targets.to(device)

            loss = train(inputs,
                         targets,
                         encoder,
                         decoder,
                         encoder_optimizer,
                         decoder_optimizer,
                         criterion,
                         SOS_token=SOS_token)
            plot_loss_total += loss
            # if count != 0 and count % 10 == 0:
            print(count, loss)

        counts.append(eps)
        plot_loss_avg = plot_loss_total / len(dataLoaders['train'])
        plot_train_losses.append(plot_loss_avg)
        val_loss = validate_model(encoder,
                                  decoder,
                                  criterion,
                                  dataLoaders['valid'],
                                  SOS_token=SOS_token,
                                  device=device)
        if val_loss < best_val_loss:
            save_model(encoder, decoder)
            best_val_loss = val_loss
        plot_val_losses.append(val_loss)
        plot_loss_total = 0
        save_loss(plot_train_losses, plot_val_losses)
    showPlot(counts, plot_train_losses, plot_val_losses)
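
save_model and load_model are not shown in this excerpt. A minimal sketch under the assumption that they simply persist and restore the two state_dicts (the file name seq2seq_checkpoint.tar is hypothetical):

def save_model(encoder, decoder, path='seq2seq_checkpoint.tar'):
    # Persist both models' weights in a single file.
    torch.save({'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict()}, path)


def load_model(path='seq2seq_checkpoint.tar'):
    # Rebuild the models with the same hyperparameters as above and restore their weights.
    encoder = EncoderRNN(get_vocab_size(CodeEncoder()), 256).to(device)
    decoder = AttnDecoderRNN(256, get_vocab_size(CommentEncoder()), dropout_p=0.1).to(device)
    checkpoint = torch.load(path)
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder'])
    return encoder, decoder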