def forward(self):
        start_epoch = self.loading_epoch + 1
        end_epoch = start_epoch + self.total_epochs
        for epoch in range(start_epoch, end_epoch):
            self.train_one_epoch(epoch)
            self.test_one_epoch(epoch)

            # Step the learning-rate scheduler once per epoch
            if self.scheduler:
                self.scheduler.step()

            if self.save:
                makeSubdir(self.model_path)
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                }, self.model_fullpath.format(self.model_name, epoch))

                # Log information
                logInfoWithDot(
                    self.logger, "SAVED MODEL: {}".format(
                        self.model_fullpath.format(self.model_name, epoch)))

        # self.writer.add_graph(self.model)
        self.writer.close()
        logInfoWithDot(
            self.logger, "TRAINING FINISHED, TIME USAGE: {} secs".format(
                timeSince(self.start_time)))
Example #2
File: main.py  Project: mrzhuzhe/pepper
def trainIters(encoder,
               decoder,
               n_iters,
               print_every=1000,
               plot_every=100,
               learning_rate=0.01):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    training_pairs = [
        tensorsFromPair(random.choice(pairs)) for i in range(n_iters)
    ]

    # negative log-likelihood loss
    criterion = nn.NLLLoss()

    for iter in range(1, n_iters + 1):
        training_pair = training_pairs[iter - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(input_tensor, target_tensor, encoder, decoder,
                     encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' %
                  (timeSince(start, iter / n_iters), iter,
                   iter / n_iters * 100, print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    torch.save(encoder.state_dict(), encoderPATH)
    torch.save(decoder.state_dict(), attentionDencoderPATH)
    showPlot(plot_losses)
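All of these examples clock progress with a timeSince helper that none of them defines. A minimal sketch consistent with both call forms used here, timeSince(start) and timeSince(start, fraction_done), in the style of the PyTorch seq2seq tutorial the snippets appear to share:

import math
import time

def asMinutes(s):
    m = math.floor(s / 60)
    return '%dm %ds' % (m, s - m * 60)

def timeSince(since, percent=None):
    s = time.time() - since  # elapsed seconds
    if percent is None:      # one-argument form: elapsed time only
        return asMinutes(s)
    rs = s / percent - s     # estimated seconds remaining
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))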
Example #3
File: main.py  Project: Sanster/notes
def trainIters(input_lang,
               output_lang,
               encoder,
               decoder,
               n_iters,
               print_every=1000,
               plot_every=100,
               learning_rate=0.01):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    training_pairs = [
        tensorsFromPair(input_lang, output_lang, random.choice(pairs), device)
        for i in range(n_iters)
    ]
    criterion = nn.NLLLoss()

    for iter in tqdm(range(1, n_iters + 1), total=n_iters):
        training_pair = training_pairs[iter - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(input_tensor, target_tensor, encoder, decoder,
                     encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' %
                  (timeSince(start, iter / n_iters), iter,
                   iter / n_iters * 100, print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)
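Examples #2 and #3 both assume a tensorsFromPair helper. A plausible sketch matching the four-argument call in Example #3, following the PyTorch translation tutorial (EOS_token and the Lang.word2index mapping are assumptions taken from that tutorial):

import torch

EOS_token = 1  # tutorial convention

def tensorFromSentence(lang, sentence, device):
    indexes = [lang.word2index[word] for word in sentence.split(' ')]
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)

def tensorsFromPair(input_lang, output_lang, pair, device):
    return (tensorFromSentence(input_lang, pair[0], device),
            tensorFromSentence(output_lang, pair[1], device))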
Example #4

    def train_one_epoch(self, ep, log_interval=50):
        self.model.train()

        lr = self.optimizer.param_groups[0]['lr']

        for i, (images, labels) in enumerate(self.trainloader, 0):
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Inference
            preds = self.model(images)
            self.correct_counts += self.count_correct(preds, labels)
            self.total_counts += len(labels)
            # Loss
            loss = self.criterion(preds, labels)
            self.train_loss += loss.item()
            # Backward propagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if i % log_interval == 0:
                self.logger.info(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLearning Rate: {}\tLoss: {:.6f}\tTime Usage:{:.8}'
                    .format(ep, i * len(images), len(self.trainloader.dataset),
                            100. * i / len(self.trainloader), lr, loss.item(),
                            timeSince(self.start_time)))

            if i == len(self.trainloader) - 1:
                accuracy = 100.0 * self.correct_counts / self.total_counts
                self.logger.info(
                    'Loss of the network on the {0} train images: {1:.6f}'.
                    format(self.total_counts, self.train_loss))
                self.logger.info(
                    'Accuracy of the network on the {0} train images: {1:.3f}%'
                    .format(self.total_counts, accuracy))

                # Write to Tensorboard file
                self.writer.add_scalar('Loss/train', self.train_loss, ep)
                self.writer.add_scalar('Accuracy/train', accuracy, ep)
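train_one_epoch keeps accumulating self.correct_counts, self.total_counts and self.train_loss, so they must be zeroed between epochs somewhere outside this method. A sketch of that bookkeeping (reset_epoch_stats is hypothetical; count_correct's body is an assumption consistent with its use above):

    @staticmethod
    def count_correct(preds, labels):
        # Count predictions whose argmax class matches the label.
        return (preds.argmax(dim=1) == labels).sum().item()

    def reset_epoch_stats(self):
        # Hypothetical per-epoch reset, called before train_one_epoch.
        self.correct_counts = 0
        self.total_counts = 0
        self.train_loss = 0.0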
Example #5
def decode_batch(device,
                 pairs,
                 encoder,
                 decoder,
                 input_lang,
                 output_lang,
                 batch_size=64):
    # with open("test/test.ref", "w", encoding='utf-8') as f:
    #     for pair in pairs:
    #         f.write(pair[1] + '\n')
    start = time.time()
    with open("test/test.ref", "w", encoding='utf-8') as f_ref, \
         open("test/test.hyp", "w", encoding='utf-8') as f:
        searcher = GreedySearchDecoderBatch(encoder, decoder, device)
        i = 0
        k = 0
        while i < len(pairs):
            k += 1
            if k % 10 == 0:
                print('%s, batch=%d, i=%d' %
                      (timeSince(start, i / len(pairs)), k, i))
            # sentences = [pairs[j][0] for j in range(i, min(i+batch_size, len(pairs)))]
            all_output_words, ref_sentences = evaluate_batch(
                device,
                searcher,
                input_lang,
                output_lang,
                pairs[i:min(i + batch_size, len(pairs))],
                max_length=MAX_LENGTH)
            for j in range(len(ref_sentences)):
                # for output_words in all_output_words:
                output_words = all_output_words[j]
                output_sentence = ' '.join(output_words).replace('EOS',
                                                                 '').strip()
                f.write(output_sentence + '\n')
                f_ref.write(ref_sentences[j] + '\n')
            i += batch_size
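One small simplification: Python slicing already clamps to the end of a list, so the min() guard in the batch slice above is redundant:

batch = pairs[i:i + batch_size]   # same as pairs[i:min(i + batch_size, len(pairs))]
assert [1, 2, 3][2:5] == [3]      # out-of-range stops are clipped, not an error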
Example #6
def runTrain():

    all_losses = []
    total_loss = 0  # Reset every plot_every iters

    start = time.time()

    for iter in range(1, n_iters + 1):
        #   print(randomTrainingExample())
        output, loss = train(*randomTrainingExample())
        total_loss += loss

        if iter % print_every == 0:
            print('%s (%d %d%%) %.4f' %
                  (timeSince(start), iter, iter / n_iters * 100, loss))

        if iter % plot_every == 0:
            all_losses.append(total_loss / plot_every)
            total_loss = 0

    # Save the training results
    torch.save(rnn.state_dict(), PATH)
    plt.figure()
    plt.plot(all_losses)
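runTrain leans on module-level globals (rnn, n_iters, print_every, plot_every, PATH) and on a sampler it never defines. A sketch of that sampler in the style of the PyTorch character-level name-generation tutorial this snippet mirrors (all_categories, category_lines and the tensor encoders are tutorial globals, assumed here):

import random

def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]

def randomTrainingExample():
    # Sample a (category, line) pair and encode it; categoryTensor,
    # inputTensor and targetTensor are the tutorial's encoders.
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    return categoryTensor(category), inputTensor(line), targetTensor(line)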
Example #7

            '''
            if i==config['n_iters_d']-1:
                break
            batch = train_loader.next_batch()
            if batch is None: # end of epoch
                break
            context, context_lens, utt_lens, floors,_,_,_,response,res_lens,_ = batch
            context, utt_lens = context[:,:,1:], utt_lens-1 # remove the sos token in the context and reduce the context length
            context, context_lens, utt_lens, floors, response, res_lens\
                = gVar(context), gVar(context_lens), gVar(utt_lens), gData(floors), gVar(response), gVar(res_lens)                      
         '''

        if itr % args.log_every == 0:
            elapsed = time.time() - itr_start_time
            log = '%s-%s|%s@gpu%d epo:[%d/%d] iter:[%d/%d] step_time:%ds elapsed:%s \n                      '\
            %(args.model, args.expname, args.dataset, args.gpu_id, epoch, config['epochs'],
                     itr, n_iters, elapsed, timeSince(epoch_start_time,itr/n_iters))
            for loss_name, loss_value in loss_records:
                log = log + loss_name + ':%.4f ' % (loss_value)
                if args.visual:
                    tb_writer.add_scalar(loss_name, loss_value, itr_global)
            logger.info(log)

            itr_start_time = time.time()

        if itr % args.valid_every == 0:
            valid_loader.epoch_init(config['batch_size'],
                                    config['diaglen'],
                                    1,
                                    shuffle=False)
            model.eval()
            loss_records = {}
Example #8
def train(embedder,
          encoder,
          hidvar,
          decoder,
          data_loader,
          vocab,
          n_iters,
          model_dir,
          p_teach_force=0.5,
          save_every=5000,
          sample_every=100,
          print_every=10,
          plot_every=100,
          learning_rate=0.00005):
    start = time.time()
    print_time_start = start
    plot_losses = []
    print_loss_total, print_loss_kl, print_loss_decoder = 0., 0., 0.  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    embedder_optimizer = optim.Adam(embedder.parameters(), lr=learning_rate)
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    hidvar_optimizer = optim.Adam(hidvar.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    criterion = nn.NLLLoss(
        weight=None, reduction='mean'
    )  # average over the batch; ignore_index=EOS_token would skip EOS

    data_iter = iter(data_loader)

    for it in range(1, n_iters + 1):
        q_batch, a_batch, q_lens, a_lens = next(data_iter)

        q_batch, a_batch, q_lens, a_lens = sortbatch(
            q_batch, a_batch, q_lens,
            a_lens)  # !!! important for packed sequences
        # sort sequences according to their lengths in descending order

        kl_anneal_weight = (math.tanh((it - 3500) / 1000) + 1) / 2

        total_loss, kl_loss, decoder_loss = _train_step(
            q_batch, a_batch, q_lens, a_lens, embedder, encoder, hidvar,
            decoder, embedder_optimizer, encoder_optimizer, hidvar_optimizer,
            decoder_optimizer, criterion, kl_anneal_weight, p_teach_force)

        print_loss_total += total_loss
        print_loss_kl += kl_loss
        print_loss_decoder += decoder_loss
        plot_loss_total += total_loss
        if it % save_every == 0:
            if not os.path.exists('%slatentvar_%s/' % (model_dir, str(it))):
                os.makedirs('%slatentvar_%s/' % (model_dir, str(it)))
            torch.save(f='%slatentvar_%s/embedder.pckl' % (model_dir, str(it)),
                       obj=embedder)
            torch.save(f='%slatentvar_%s/encoder.pckl' % (model_dir, str(it)),
                       obj=encoder)
            torch.save(f='%slatentvar_%s/hidvar.pckl' % (model_dir, str(it)),
                       obj=hidvar)
            torch.save(f='%slatentvar_%s/decoder.pckl' % (model_dir, str(it)),
                       obj=decoder)
        if it % sample_every == 0:
            samp_idx = np.random.choice(len(q_batch), 4)  #pick 4 samples
            for i in samp_idx:
                question, target = q_batch[i].view(1,
                                                   -1), a_batch[i].view(1, -1)
                sampled_sentence = sample(embedder, encoder, hidvar, decoder,
                                          question, vocab)
                ivocab = {v: k for k, v in vocab.items()}
                print('question: %s' % (indexes2sent(
                    question.squeeze().numpy(), ivocab, ignore_tok=EOS_token)))
                print('target: %s' % (indexes2sent(
                    target.squeeze().numpy(), ivocab, ignore_tok=EOS_token)))
                print('predicted: %s' % (sampled_sentence))
        if it % print_every == 0:
            print_loss_total = print_loss_total / print_every
            print_loss_kl = print_loss_kl / print_every
            print_loss_decoder = print_loss_decoder / print_every
            print_time = time.time() - print_time_start
            print_time_start = time.time()
            print(
                'iter %d/%d  step_time:%ds  total_time:%s tol_loss: %.4f kl_loss: %.4f dec_loss: %.4f'
                % (it, n_iters, print_time, timeSince(start, it / n_iters),
                   print_loss_total, print_loss_kl, print_loss_decoder))
            print_loss_total, print_loss_kl, print_loss_decoder = 0, 0, 0
        if it % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0
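The kl_anneal_weight above is a tanh ramp: close to 0 early in training, 0.5 at iteration 3500, and close to 1 afterwards, a common way to keep the KL term from collapsing the latent space before the decoder has learned anything. A quick check of its shape:

import math

for it in (1, 2500, 3500, 4500, 7000):
    w = (math.tanh((it - 3500) / 1000) + 1) / 2
    print(it, round(w, 3))  # -> 0.001, 0.119, 0.5, 0.881, 0.999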
Example #9
for iter in range(1, n_iters + 1):
    training_pair = training_pairs[iter - 1]
    input_tensor = training_pair[0]
    target_tensor = training_pair[1]

    loss = train.train(input_tensor, target_tensor, encoder, decoder,
                       encoder_optimizer, decoder_optimizer,
                       teacher_forcing_rate, criterion, MAX_LENGTH)
    print_loss_total += loss
    plot_loss_total += loss

    if iter % print_every == 0:
        print_loss_avg = print_loss_total / print_every
        print_loss_total = 0
        print('%s (%d %d%%) %.4f' %
              (helper.timeSince(start, iter / n_iters), iter,
               iter / n_iters * 100, print_loss_avg))

    if iter % plot_every == 0:
        plot_loss_avg = plot_loss_total / plot_every
        plot_losses.append(plot_loss_avg)
        plot_loss_total = 0

helper.showPlot(plot_losses)

# for i in range(n_evaluations):
#     pair = random.choice(pairs)
#     print('>', pair[0])
#     print('=', pair[1])
#     output_words = train.evaluate(
#         encoder, decoder, lang, pair[0], MAX_LENGTH)
Example #10
File: train.py  Project: jianyiyang5/mt
def trainIters(device,
               pairs,
               input_lang,
               output_lang,
               encoder,
               decoder,
               n_iters,
               print_every=1000,
               plot_every=100,
               save_every=10000,
               learning_rate=0.01,
               save_dir='checkpoints'):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    training_pairs = [
        tensorsFromPair(random.choice(pairs), input_lang, output_lang, device)
        for i in range(n_iters)
    ]
    criterion = nn.NLLLoss()

    for iter in range(1, n_iters + 1):
        training_pair = training_pairs[iter - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(device, input_tensor, target_tensor, encoder, decoder,
                     encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' %
                  (timeSince(start, iter / n_iters), iter,
                   iter / n_iters * 100, print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

        # Save checkpoint
        if iter % save_every == 0:
            directory = save_dir
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save(
                {
                    'iteration': iter,
                    'en': encoder.state_dict(),
                    'de': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': loss,
                    'input_lang': input_lang.__dict__,
                    'output_lang': output_lang.__dict__
                },
                os.path.join(directory, '{}_{}.tar'.format(iter,
                                                           'checkpoint')))

    showPlot(plot_losses)
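A minimal sketch for restoring the .tar checkpoint written above (the filename is illustrative; encoder, decoder, the optimizers and the Lang objects are assumed to be constructed first, as in the snippet):

import os
import torch

checkpoint = torch.load(os.path.join(save_dir, '10000_checkpoint.tar'),
                        map_location='cpu')
encoder.load_state_dict(checkpoint['en'])
decoder.load_state_dict(checkpoint['de'])
encoder_optimizer.load_state_dict(checkpoint['en_opt'])
decoder_optimizer.load_state_dict(checkpoint['de_opt'])
input_lang.__dict__.update(checkpoint['input_lang'])
output_lang.__dict__.update(checkpoint['output_lang'])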
Example #11
def trainIters(encoder, decoder, n_epoch, max_length, learning_rate=0.01):
    start = time.time()
    print_train_loss_total = 0  # Reset every print_every
    print_valid_loss_total = 0

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    training_pairs = [
        p.tensorsFromPair(train_input_lang, train_output_lang,
                          train_pairList[i])
        for i in range(0, len(train_pairList))
    ]
    validation_pairs = [
        p.tensorsFromPair(valid_input_lang, valid_output_lang,
                          valid_pairList[i])
        for i in range(0, len(valid_pairList))
    ]
    test_pairs = [
        p.tensorsFromPair(test_input_lang, test_output_lang, test_pairList[i])
        for i in range(0, len(test_pairList))
    ]

    criterion = nn.NLLLoss()

    for i in range(1, n_epoch + 1):
        # one epoch
        print(i, " Epoch...")

        # Training loop
        print("Training loop")
        for j in tqdm(range(1, len(train_pairList) + 1)):
            encoder.train()
            decoder.train()

            training_pair = training_pairs[j - 1]
            input_tensor = training_pair[0]
            target_tensor = training_pair[1]

            train_loss = train(input_tensor, target_tensor, encoder, decoder,
                               encoder_optimizer, decoder_optimizer,
                               max_length, criterion)
            print_train_loss_total += train_loss

            if j % 10 == 0:
                writer.add_scalar('training loss', train_loss / 10,
                                  (i - 1) * len(train_pairList) + (j - 1))

        # Validation loop
        print("Validation loop")
        for j in tqdm(range(1, len(valid_pairList) + 1)):
            validation_pair = validation_pairs[j - 1]
            input_tensor = validation_pair[0]
            target_tensor = validation_pair[1]

            encoder.eval()
            decoder.eval()

            val_loss = validate(input_tensor, target_tensor, encoder, decoder,
                                max_length, criterion)
            print_valid_loss_total += val_loss

            if j % 10 == 0:
                writer.add_scalar('validation loss', val_loss / 10,
                                  (i - 1) * len(valid_pairList) + (j - 1))

        # Print train loss Status
        print_train_loss_avg = print_train_loss_total / len(train_pairList)
        print_train_loss_total = 0
        print('Time elapsed : %s \tPercentage: %d%% \tml_loss : %.4f' %
              (h.timeSince(start, i / n_epoch), i / n_epoch * 100,
               print_train_loss_avg))

        # [input & target] pair [0][0]-src [0][1]-tgt
        train_hypothesis_list, train_reference_list = make_hypothesis_reference(
            encoder, decoder, train_pairList, train_input_lang, max_length)
        valid_hypothesis_list, valid_reference_list = make_hypothesis_reference(
            encoder, decoder, valid_pairList, valid_input_lang, max_length)

        bleu, rouge_l, _, precision, recall, f1 = e.eval_accuracies(
            train_hypothesis_list, train_reference_list)
        print("\nTraining Data Evaluation")
        print(
            "\nbleu : %f\trouge_l : %f\tprecision : %f\trecall : %f\tf1 : %f\t"
            % (bleu, rouge_l, precision, recall, f1))
        print("\n")
        # Random Evaluation

        evaluateRandomly(encoder, decoder, train_pairList, train_input_lang,
                         max_length)
        bleu, rouge_l, _, precision, recall, f1 = e.eval_accuracies(
            valid_hypothesis_list, valid_reference_list)
        print("\nValidation Data Evaluation")
        print(
            "\nbleu : %f\trouge_l : %f\tprecision : %f\trecall : %f\tf1 : %f\t"
            % (bleu, rouge_l, precision, recall, f1))
        print("\n")

        # Random Evaluation
        evaluateRandomly(encoder, decoder, valid_pairList, valid_input_lang,
                         max_length)
Example #12
def train(embedder, encoder, topic_picker, first_word_picker, decoder,
          learning_rate, data_loader, topic_size, voca_size, save_every,
          sample_every, print_every, plot_every, model_dir, vocab):
    start = time.time()
    print_time_start = start
    plot_losses = []
    print_loss_total, print_loss_topic, print_loss_word, print_loss_decoder = 0., 0., 0., 0.
    plot_loss_total = 0.  # Reset every plot_every

    embedder_optimizer = optim.Adam(embedder.parameters(), lr=learning_rate)
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    topic_picker_optimizer = optim.Adam(topic_picker.parameters(),
                                        lr=learning_rate)
    first_word_picker_optimizer = optim.Adam(first_word_picker.parameters(),
                                             lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    nll_loss = nn.NLLLoss()
    CE_loss = torch.nn.CrossEntropyLoss()

    data_iter = iter(data_loader)

    for it in range(1, n_iters + 1):
        batch_q, batch_ans, batch_topic = next(data_iter)

        #anneal weight?
        topic_loss_weight = 0.2
        word_loss_weight = 0.2

        topic_loss, first_word_loss, decoder_loss, total_loss = train_once(
            embedder, encoder, topic_picker, first_word_picker, decoder,
            batch_q, batch_ans, batch_topic, topic_size, voca_size,
            topic_loss_weight, word_loss_weight, embedder_optimizer,
            encoder_optimizer, topic_picker_optimizer,
            first_word_picker_optimizer, decoder_optimizer, nll_loss, CE_loss)

        print_loss_total += total_loss
        print_loss_decoder += decoder_loss
        print_loss_word += first_word_loss
        print_loss_topic += topic_loss
        plot_loss_total += total_loss

        if it % save_every == 0:
            if not os.path.exists('%sthree_net_%s/' % (model_dir, str(it))):
                os.makedirs('%sthree_net_%s/' % (model_dir, str(it)))
            torch.save(f='%sthree_net_%s/embedder.pckl' % (model_dir, str(it)),
                       obj=embedder)
            torch.save(f='%sthree_net_%s/encoder.pckl' % (model_dir, str(it)),
                       obj=encoder)
            torch.save(f='%sthree_net_%s/topic_picker.pckl' %
                       (model_dir, str(it)),
                       obj=topic_picker)
            torch.save(f='%sthree_net_%s/first_word_picker.pckl' %
                       (model_dir, str(it)),
                       obj=first_word_picker)
            torch.save(f='%sthree_net_%s/decoder.pckl' % (model_dir, str(it)),
                       obj=decoder)
        if it % sample_every == 0:
            samp_idx = np.random.choice(len(batch_q), 4)  #pick 4 samples
            for i in samp_idx:
                question, target = batch_q[i].view(1, -1), batch_ans[i].view(
                    1, -1)
                sampled_sentence = sample(embedder, encoder, topic_picker,
                                          first_word_picker, decoder, question,
                                          vocab)
                ivocab = {v: k for k, v in vocab.items()}
                print('question: %s' % (indexes2sent(
                    question.squeeze().numpy(), ivocab, ignore_tok=EOS_token)))
                print('target: %s' % (indexes2sent(
                    target.squeeze().numpy(), ivocab, ignore_tok=EOS_token)))
                print('predicted: %s' % (sampled_sentence))
        #print and plot
        if it % print_every == 0:
            print_loss_total = print_loss_total / print_every
            print_loss_word = print_loss_word / print_every
            print_loss_topic = print_loss_topic / print_every
            print_loss_decoder = print_loss_decoder / print_every
            print_time = time.time() - print_time_start
            print_time_start = time.time()
            print(
                'iter %d/%d  step_time:%ds  total_time:%s total_loss: %.4f topic_loss: %.4f first_word_loss: %.4f dec_loss: %.4f'
                % (it, n_iters, print_time, timeSince(
                    start, it / n_iters), print_loss_total, print_loss_topic,
                   print_loss_word, print_loss_decoder))

            print_loss_total, print_loss_topic, print_loss_word, print_loss_decoder = 0., 0., 0., 0.

        if it % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0.

    showPlot(plot_losses)
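Unlike the state_dict checkpoints of Examples #10 and #13, Examples #8 and #12 pickle entire nn.Module objects, which ties the saved files to the exact class definitions being importable at load time. Restoring them is one torch.load per file (the iteration number is illustrative):

import torch

it = 5000  # any saved iteration
embedder = torch.load('%sthree_net_%s/embedder.pckl' % (model_dir, str(it)))
decoder = torch.load('%sthree_net_%s/decoder.pckl' % (model_dir, str(it)))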
Example #13
def trainIters(device,
               pairs,
               input_lang,
               output_lang,
               encoder,
               decoder,
               batch_size,
               n_iters,
               print_every=250,
               plot_every=250,
               save_every=2000,
               learning_rate=0.01,
               save_dir='checkpoints'):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    # NOTE: built but never used below; training_batches supersedes it
    training_pairs = [
        tensorsFromPair(random.choice(pairs), input_lang, output_lang, device)
        for i in range(n_iters)
    ]
    training_batches = [
        batch2TrainData(input_lang, output_lang,
                        [random.choice(pairs) for _ in range(batch_size)])
        for _ in range(n_iters)
    ]

    for iter in range(1, n_iters + 1):
        training_batch = training_batches[iter - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        loss = train(device, input_variable, lengths, target_variable, mask,
                     max_target_len, encoder, decoder, encoder_optimizer,
                     decoder_optimizer, batch_size)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' %
                  (timeSince(start, iter / n_iters), iter,
                   iter / n_iters * 100, print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

        # Save checkpoint
        if iter % save_every == 0:
            directory = save_dir
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save(
                {
                    'iteration': iter,
                    'en': encoder.state_dict(),
                    'de': decoder.state_dict(),
                    'en_opt': encoder_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': loss,
                    'input_lang': input_lang.__dict__,
                    'output_lang': output_lang.__dict__
                },
                os.path.join(directory, '{}_{}.tar'.format(iter,
                                                           'checkpoint')))

    showPlot(plot_losses)
Example #14
File: train.py  Project: rxe101/deepAPI
def train(args):
    timestamp = datetime.now().strftime('%Y%m%d%H%M')
    # LOG #
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.DEBUG, format="%(message)s"
    )  # alternative: format="%(asctime)s: %(name)s: %(levelname)s: %(message)s"
    tb_writer = None
    if args.visual:
        # make output directory if it doesn't already exist
        os.makedirs(f'./output/{args.model}/{args.expname}/{timestamp}/models',
                    exist_ok=True)
        os.makedirs(
            f'./output/{args.model}/{args.expname}/{timestamp}/temp_results',
            exist_ok=True)
        fh = logging.FileHandler(
            f"./output/{args.model}/{args.expname}/{timestamp}/logs.txt")
        # create file handler which logs even debug messages
        logger.addHandler(fh)  # add the handlers to the logger
        tb_writer = SummaryWriter(
            f"./output/{args.model}/{args.expname}/{timestamp}/logs/")
        # save arguments
        json.dump(
            vars(args),
            open(f'./output/{args.model}/{args.expname}/{timestamp}/args.json',
                 'w'))

    # Device #
    if args.gpu_id < 0:
        device = torch.device("cuda")
    else:
        device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available(
        ) and args.gpu_id > -1 else "cpu")
    print(device)
    n_gpu = torch.cuda.device_count() if args.gpu_id < 0 else 1
    print(f"num of gpus:{n_gpu}")
    # Set the random seed manually for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    def save_model(model, epoch, timestamp):
        """Save model parameters to checkpoint"""
        os.makedirs(f'./output/{args.model}/{args.expname}/{timestamp}/models',
                    exist_ok=True)
        ckpt_path = f'./output/{args.model}/{args.expname}/{timestamp}/models/model_epo{epoch}.pkl'
        print(f'Saving model parameters to {ckpt_path}')
        torch.save(model.state_dict(), ckpt_path)

    def load_model(model, epoch, timestamp):
        """Load parameters from checkpoint"""
        ckpt_path = f'./output/{args.model}/{args.expname}/{timestamp}/models/model_epo{epoch}.pkl'
        print(f'Loading model parameters from {ckpt_path}')
        model.load_state_dict(torch.load(ckpt_path))

    config = getattr(configs, 'config_' + args.model)()

    ###############################################################################
    # Load data
    ###############################################################################
    train_set = APIDataset(args.data_path + 'train.desc.h5',
                           args.data_path + 'train.apiseq.h5',
                           config['max_sent_len'])
    valid_set = APIDataset(args.data_path + 'test.desc.h5',
                           args.data_path + 'test.apiseq.h5',
                           config['max_sent_len'])
    print("Loaded data!")

    ###############################################################################
    # Define the models
    ###############################################################################
    model = getattr(models, args.model)(config)
    if args.reload_from >= 0:
        load_model(model, args.reload_from, timestamp)  # assumes checkpoints under the current run's timestamp
    model = model.to(device)

    ###############################################################################
    # Training
    ###############################################################################
    logger.info("Training...")
    itr_global = 1
    start_epoch = 1 if args.reload_from == -1 else args.reload_from + 1
    for epoch in range(start_epoch, config['epochs'] + 1):

        epoch_start_time = time.time()
        itr_start_time = time.time()

        # shuffle (re-define) data between epochs
        train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=1)
        train_data_iter = iter(train_loader)
        n_iters = len(train_loader)

        itr = 1
        while True:  # loop through all batches in training data
            model.train()
            try:
                descs, apiseqs, desc_lens, api_lens = next(train_data_iter)
            except StopIteration:  # end of epoch
                break
            batch = [
                tensor.to(device)
                for tensor in [descs, desc_lens, apiseqs, api_lens]
            ]
            loss_AE = model.train_AE(*batch)

            if itr % args.log_every == 0:
                elapsed = time.time() - itr_start_time
                log = '%s-%s|@gpu%d epo:[%d/%d] iter:[%d/%d] step_time:%ds elapsed:%s \n                      '\
                %(args.model, args.expname, args.gpu_id, epoch, config['epochs'],
                         itr, n_iters, elapsed, timeSince(epoch_start_time,itr/n_iters))
                for loss_name, loss_value in loss_AE.items():
                    log = log + loss_name + ':%.4f ' % (loss_value)
                    if args.visual:
                        tb_writer.add_scalar(loss_name, loss_value, itr_global)
                logger.info(log)

                itr_start_time = time.time()

            if itr % args.valid_every == 0:
                valid_loader = torch.utils.data.DataLoader(
                    dataset=valid_set,
                    batch_size=config['batch_size'],
                    shuffle=True,
                    num_workers=1)
                model.eval()
                loss_records = {}

                for descs, apiseqs, desc_lens, api_lens in valid_loader:
                    batch = [
                        tensor.to(device)
                        for tensor in [descs, desc_lens, apiseqs, api_lens]
                    ]
                    valid_loss = model.valid(*batch)
                    for loss_name, loss_value in valid_loss.items():
                        v = loss_records.get(loss_name, [])
                        v.append(loss_value)
                        loss_records[loss_name] = v

                log = 'Validation '
                for loss_name, loss_values in loss_records.items():
                    log = log + loss_name + ':%.4f  ' % (np.mean(loss_values))
                    if args.visual:
                        tb_writer.add_scalar(loss_name, np.mean(loss_values),
                                             itr_global)
                logger.info(log)

            itr += 1
            itr_global += 1

            if itr_global % args.eval_every == 0:  # evaluate the model on the validation set
                model.eval()
                save_model(model, itr_global,
                           timestamp)  # checkpoint before evaluation

                valid_loader = torch.utils.data.DataLoader(dataset=valid_set,
                                                           batch_size=1,
                                                           shuffle=False,
                                                           num_workers=1)
                vocab_api = load_dict(args.data_path + 'vocab.apiseq.json')
                vocab_desc = load_dict(args.data_path + 'vocab.desc.json')
                metrics = Metrics()

                os.makedirs(
                    f'./output/{args.model}/{args.expname}/{timestamp}/temp_results',
                    exist_ok=True)
                f_eval = open(
                    f"./output/{args.model}/{args.expname}/{timestamp}/temp_results/iter{itr_global}.txt",
                    "w")
                repeat = 1
                decode_mode = 'sample'
                recall_bleu, prec_bleu = evaluate(model, metrics, valid_loader,
                                                  vocab_desc, vocab_api,
                                                  repeat, decode_mode, f_eval)

                if args.visual:
                    tb_writer.add_scalar('recall_bleu', recall_bleu,
                                         itr_global)
                    tb_writer.add_scalar('prec_bleu', prec_bleu, itr_global)

        # end of epoch ----------------------------
        model.adjust_lr()
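A minimal invocation sketch for train(args). The flag names are read off the args attribute accesses in the snippet; the defaults and the model name are illustrative placeholders, not the project's values:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='RNNEncDec')      # selects configs.config_<model> and models.<model>
parser.add_argument('--expname', default='basic')
parser.add_argument('--data_path', default='./data/')
parser.add_argument('--gpu_id', type=int, default=0)     # < 0 falls back to the bare "cuda" device
parser.add_argument('--visual', action='store_true')     # enable file logging + TensorBoard
parser.add_argument('--seed', type=int, default=1111)
parser.add_argument('--reload_from', type=int, default=-1)
parser.add_argument('--log_every', type=int, default=100)
parser.add_argument('--valid_every', type=int, default=1000)
parser.add_argument('--eval_every', type=int, default=5000)
train(parser.parse_args())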