Example 1
def main(pretrain_checkpoint_dir,
         train_summary_writer,
         vocab: Vocab,
         dataloader: DataLoader,
         batch_size: int = 64,
         embedding_dim: int = 256,
         seq_length: int = 3000,
         gen_seq_len: int = 3000,
         gen_rnn_units: int = 1024,
         disc_rnn_units: int = 1024,
         epochs: int = 40000,
         pretrain_epochs: int = 4000,
         learning_rate: float = 1e-4,
         rollout_num: int = 2,
         gen_pretrain: bool = False,
         disc_pretrain: bool = False,
         load_gen_weights: bool = False,
         load_disc_weights: bool = False,
         save_gen_weights: bool = True,
         save_disc_weights: bool = True,
         disc_steps: int = 3,
         shuffle: bool = True):
    gen = Generator(dataloader=dataloader,
                    vocab=vocab,
                    batch_size=batch_size,
                    embedding_dim=embedding_dim,
                    seq_length=seq_length,
                    checkpoint_dir=pretrain_checkpoint_dir,
                    rnn_units=gen_rnn_units,
                    start_token=0,
                    learning_rate=learning_rate)
    if load_gen_weights:
        gen.load_weights()
    if gen_pretrain:
        gen_pre_trainer = GenPretrainer(gen,
                                        dataloader=dataloader,
                                        vocab=vocab,
                                        pretrain_epochs=pretrain_epochs,
                                        tb_writer=train_summary_writer,
                                        learning_rate=learning_rate)
        print('Start pre-training generator...')
        gen_pre_trainer.pretrain(gen_seq_len=gen_seq_len,
                                 save_weights=save_gen_weights)

    disc = Discriminator(vocab_size=vocab.vocab_size,
                         embedding_dim=embedding_dim,
                         rnn_units=disc_rnn_units,
                         batch_size=batch_size,
                         checkpoint_dir=pretrain_checkpoint_dir,
                         learning_rate=learning_rate)
    if load_disc_weights:
        disc.load_weights()
    if disc_pretrain:
        disc_pre_trainer = DiscPretrainer(disc,
                                          gen,
                                          dataloader=dataloader,
                                          vocab=vocab,
                                          pretrain_epochs=pretrain_epochs,
                                          tb_writer=train_summary_writer,
                                          learning_rate=learning_rate)
        print('Start pre-training discriminator...')
        disc_pre_trainer.pretrain(save_disc_weights)
    rollout = Rollout(generator=gen,
                      discriminator=disc,
                      vocab=vocab,
                      batch_size=batch_size,
                      seq_length=seq_length,
                      rollout_num=rollout_num)

    with tqdm(desc='Epoch: ', total=epochs, dynamic_ncols=True) as pbar:
        for epoch in range(epochs):
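            # SeqGAN-style step: sample sequences from the generator, estimate
            # per-token rewards with discriminator-guided Monte Carlo rollouts,
            # then apply a policy-gradient (REINFORCE) update.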
            fake_samples = gen.generate()
            rewards = rollout.get_reward(samples=fake_samples)
            gen_loss = gen.train_step(fake_samples, rewards)
            real_samples, _ = dataloader.get_batch(shuffle=shuffle,
                                                   seq_length=seq_length,
                                                   batch_size=batch_size,
                                                   training=True)
            disc_loss = 0
            for i in range(disc_steps):
                disc_loss += disc.train_step(fake_samples,
                                             real_samples) / disc_steps

            with train_summary_writer.as_default():
                tf.summary.scalar('gen_train_loss', gen_loss, step=epoch)
                tf.summary.scalar('disc_train_loss', disc_loss, step=epoch)
                tf.summary.scalar('total_train_loss',
                                  disc_loss + gen_loss,
                                  step=epoch)

            pbar.set_postfix(gen_train_loss=tf.reduce_mean(gen_loss),
                             disc_train_loss=tf.reduce_mean(disc_loss),
                             total_train_loss=tf.reduce_mean(gen_loss +
                                                             disc_loss))

            if (epoch + 1) % 5 == 0 or (epoch + 1) == 1:
                print('Saving weights...')
                # save weights
                gen.model.save_weights(gen.checkpoint_prefix)
                disc.model.save_weights(disc.checkpoint_prefix)
                # gen.model.save('gen.h5')
                # disc.model.save('disc.h5')

                # evaluate the discriminator on held-out data
                fake_samples = gen.generate(gen_seq_len)
                real_samples, _ = dataloader.get_batch(shuffle=shuffle,
                                                       seq_length=gen_seq_len,
                                                       batch_size=batch_size,
                                                       training=False)
                disc_loss = disc.test_step(fake_samples, real_samples)

                # evaluate the generator
                gen_loss = gen.test_step()

                # compute the bleu_score
                # bleu_score = get_bleu_score(true_seqs=real_samples, genned_seqs=fake_samples)
                genned_sentences = vocab.extract_seqs(fake_samples)
                # print(genned_sentences)
                # print(vocab.idx2char[fake_samples[0]])

                # log test losses
                with train_summary_writer.as_default():
                    tf.summary.scalar('disc_test_loss',
                                      tf.reduce_mean(disc_loss),
                                      step=epoch)
                    tf.summary.scalar('gen_test_loss',
                                      tf.reduce_mean(gen_loss),
                                      step=epoch)
                    # tf.summary.scalar('bleu_score', tf.reduce_mean(bleu_score), step=epoch + gen_pretrain * pretrain_epochs)

            pbar.update()
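
# Minimal invocation sketch; the Vocab/DataLoader constructors below are
# assumptions, not part of the example above.
if __name__ == '__main__':
    vocab = Vocab.from_corpus('data/corpus.txt')        # hypothetical helper
    dataloader = DataLoader(corpus='data/corpus.txt',   # hypothetical ctor
                            vocab=vocab)
    writer = tf.summary.create_file_writer('logs/train')
    main(pretrain_checkpoint_dir='checkpoints/pretrain',
         train_summary_writer=writer,
         vocab=vocab,
         dataloader=dataloader,
         gen_pretrain=True,
         disc_pretrain=True)
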
Example 2
rollout = Rollout(generator=gen,
                  discriminator=disc,
                  batch_size=batch_size,
                  embedding_size=embedding_dim,
                  sequence_length=seq_len,
                  start_token=start_token,
                  rollout_num=rollout_num)


for epoch in range(EPOCHS):
    fake_samples = gen.generate()
    rewards = rollout.get_reward(samples=fake_samples)
    gen_loss = gen.train_step(fake_samples, rewards)
    real_samples, _ = get_batch(seq_len, batch_size)
    disc_loss = 0
    for i in range(disc_steps):
        disc_loss += disc.train_step(fake_samples, real_samples)/disc_steps

    with train_summary_writer.as_default():
        tf.summary.scalar('gen_train_loss', gen_loss, step=epoch)
        tf.summary.scalar('disc_train_loss', disc_loss, step=epoch)
        tf.summary.scalar('total_train_loss', disc_loss + gen_loss, step=epoch)

    if epoch % 7 == 0:
        disc.model.save_weights(disc.checkpoint_prefix)
        gen.model.save_weights(gen.checkpoint_prefix)
        samples = gen.generate(gen_seq_len)
        genned_songs = extract_songs(samples)
        bleu_score = get_bleu_score(genned_songs)
        # print(idx2char[samples[0]])
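
# extract_songs and get_bleu_score are helpers that aren't shown in this
# snippet. A hedged sketch of a corpus-level BLEU helper built on NLTK
# (the signature and reference handling here are assumptions):
from nltk.translate.bleu_score import corpus_bleu

def get_bleu_score(genned_songs, reference_songs):
    # Score each generated token sequence against the shared reference set.
    references = [reference_songs] * len(genned_songs)
    return corpus_bleu(references, genned_songs)
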
class AVGRunner:
    def __init__(self, restore_path, mode="train"):
        self.notifier = Notifier()
        self.logger = Logger(
            path=c.SAVE_PATH,
            name=f"train{datetime.datetime.now().strftime('%Y%m%d%H%M')}.log")
        self.global_step = 0
        self.num_steps = c.MAX_ITER

        tf_config = tf.ConfigProto(allow_soft_placement=True)
        tf_config.gpu_options.allow_growth = True
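        # Soft placement falls back to CPU for ops without a GPU kernel;
        # allow_growth claims GPU memory incrementally instead of all at once.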
        self.sess = tf.Session(config=tf_config)
        self.summary_writer = tf.summary.FileWriter(c.SAVE_SUMMARY,
                                                    graph=self.sess.graph)

        if mode == "train":
            self._out_seq = c.OUT_SEQ
            self._h = c.H
            self._w = c.W
        else:
            # self._batch = 1
            self._out_seq = c.PREDICT_LENGTH
            self._h = c.PREDICTION_H
            self._w = c.PREDICTION_W

        self._in_seq = c.IN_SEQ
        self._batch = c.BATCH_SIZE

        self.g_model = Generator(self.sess, self.summary_writer, mode=mode)
        if c.ADVERSARIAL and mode == "train":
            self.d_model = Discriminator(self.sess, self.summary_writer)
        else:
            self.d_model = None

        self.saver = tf.train.Saver(max_to_keep=0)
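        # Note: max_to_keep=0 makes the TF1 Saver keep every checkpoint.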

        if restore_path is not None:
            self.saver.restore(self.sess, restore_path)
        else:
            self.sess.run(tf.global_variables_initializer())

    def get_train_batch(self, iterator):
        data, *_ = iterator.sample(batch_size=self._batch)
        in_data = data[:, :self._in_seq, :, :, :]

        # The 1- and 3-channel cases currently use the same target slice.
        if c.IN_CHANEL in (1, 3):
            gt_data = data[:,
                           self._in_seq:self._in_seq + self._out_seq, :, :, :]
        else:
            raise NotImplementedError

        if c.NORMALIZE:
            in_data = normalize_frames(in_data)
            gt_data = normalize_frames(gt_data)
        in_data = crop_img(in_data)
        gt_data = crop_img(gt_data)
        return in_data, gt_data
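
    # normalize_frames and crop_img are assumed helpers defined elsewhere; a
    # plausible normalize_frames, if frames arrive as uint8 in [0, 255], would
    # be frames.astype(np.float32) / 255.0 (an assumption about the data).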

    def train(self):
        train_iter = Iterator(time_interval=c.RAINY_TRAIN,
                              sample_mode="random",
                              seq_len=self._in_seq + self._out_seq,
                              stride=1)
        while self.global_step < c.MAX_ITER:

            if c.ADVERSARIAL and self.global_step > c.ADV_INVOLVE:
                print("start d_model")
                in_data, gt_data = self.get_train_batch(train_iter)
                d_loss, *_ = self.d_model.train_step(in_data, gt_data,
                                                     self.g_model)
            else:
                d_loss = 0

            in_data, gt_data = self.get_train_batch(train_iter)
            g_loss, mse, gd_loss, global_step = self.g_model.train_step(
                in_data, gt_data, self.d_model)

            self.global_step = global_step

            self.logger.info(f"Iter {self.global_step}: \n\t "
                             f"g_loss: {g_loss:.4f} \n\t"
                             f"mse: {mse:.4f} \n\t "
                             f"mse_real: {gd_loss:.4f} \n\t"
                             f"d_loss: {d_loss:.4f}")

            if (self.global_step + 1) % c.SAVE_ITER == 0:
                self.save_model()

            if (self.global_step + 1) % c.VALID_ITER == 0:
                self.run_benchmark(global_step, mode="Valid")

    def valid(self):
        test_iter = Clip_Iterator(c.VALID_DIR_CLIPS)
        evaluator = Evaluator(self.global_step)
        i = 0
        for data in test_iter.sample_valid(self._batch):
            in_data = data[:, :self._in_seq, ...]
            if c.IN_CHANEL == 3:
                gt_data = data[:,
                               self._in_seq:self._in_seq + self._out_seq, :, :,
                               1:-1]
            elif c.IN_CHANEL == 1:
                gt_data = data[:, self._in_seq:self._in_seq + self._out_seq,
                               ...]
            else:
                raise NotImplementedError
            if c.NORMALIZE:
                in_data = normalize_frames(in_data)
                gt_data = normalize_frames(gt_data)

            mse, mae, gdl, pred = self.g_model.valid_step(in_data, gt_data)
            evaluator.evaluate(gt_data, pred)
            self.logger.info(f"Iter {self.global_step} {i}: \n\t "
                             f"mse:{mse:.4f} \n\t "
                             f"mae:{mae:.4f} \n\t "
                             f"gdl:{gdl:.4f}")
            i += 1
        evaluator.done()

    def save_model(self):
        from os.path import join
        save_path = self.saver.save(self.sess,
                                    join(c.SAVE_MODEL, "model.ckpt"),
                                    global_step=self.global_step)
        self.logger.info("Model saved in path: %s" % save_path)

    def run_benchmark(self, iter, mode="Test"):
        if mode == "Valid":
            time_interval = c.RAINY_VALID
            stride = 5
        else:
            time_interval = c.RAINY_TEST
            stride = 1
        test_iter = Iterator(time_interval=time_interval,
                             sample_mode="sequent",
                             seq_len=self._in_seq + self._out_seq,
                             stride=1)
        evaluator = Evaluator(iter, length=self._out_seq, mode=mode)
        i = 1
        while not test_iter.use_up:
            data, date_clip, *_ = test_iter.sample(batch_size=self._batch)
            in_data = np.zeros(shape=(self._batch, self._in_seq, self._h,
                                      self._w, c.IN_CHANEL))
            gt_data = np.zeros(shape=(self._batch, self._out_seq, self._h,
                                      self._w, 1))
            if isinstance(data, list):
                break
            in_data[...] = data[:, :self._in_seq, :, :, :]

            if c.IN_CHANEL == 3:
                # Keep the middle channel so the slice fits gt_data's
                # single-channel last axis (mirrors the handling in valid()).
                gt_data[...] = data[:, self._in_seq:self._in_seq +
                                    self._out_seq, :, :, 1:-1]
            elif c.IN_CHANEL == 1:
                gt_data[...] = data[:, self._in_seq:self._in_seq +
                                    self._out_seq, :, :, :]
            else:
                raise NotImplementedError

            # in_date = date_clip[0][:c.IN_SEQ]

            if c.NORMALIZE:
                in_data = normalize_frames(in_data)
                gt_data = normalize_frames(gt_data)
            in_data = crop_img(in_data)
            gt_data = crop_img(gt_data)
            mse, mae, gdl, pred = self.g_model.valid_step(in_data, gt_data)
            evaluator.evaluate(gt_data, pred)
            self.logger.info(
                f"Iter {iter} {i}: \n\t mse:{mse} \n\t mae:{mae} \n\t gdl:{gdl}"
            )
            i += 1
            if i % stride == 0:
                if c.IN_CHANEL == 3:
                    in_data = in_data[:, :, :, :, 1:-1]

                for b in range(self._batch):
                    predict_date = date_clip[b][self._in_seq - 1]
                    self.logger.info(f"Save {predict_date} results")
                    if mode == "Valid":
                        save_path = os.path.join(
                            c.SAVE_VALID, str(iter),
                            predict_date.strftime("%Y%m%d%H%M"))
                    else:
                        save_path = os.path.join(
                            c.SAVE_TEST, str(iter),
                            predict_date.strftime("%Y%m%d%H%M"))

                    path = os.path.join(save_path, "in")
                    save_png(in_data[b], path)

                    path = os.path.join(save_path, "pred")
                    save_png(pred[b], path)

                    path = os.path.join(save_path, "out")
                    save_png(gt_data[b], path)
        evaluator.done()
        self.notifier.eval(iter, evaluator.result_path)
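
# Minimal usage sketch; assumes the config module `c` points at valid data,
# summary, and checkpoint paths.
if __name__ == "__main__":
    runner = AVGRunner(restore_path=None, mode="train")  # train from scratch
    runner.train()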