Esempio n. 1
0
def create_initial_model(name):
    """Return the model called *name*.

    If a saved copy exists under MODEL_DIR it is loaded and returned.
    Otherwise a fresh model is built, its graph is written to
    TensorBoard, it is bootstrapped with self-play games, and it is
    saved both under its own name and as ``best_model.h5``.
    """
    full_filename = os.path.join(conf['MODEL_DIR'], name) + ".h5"

    # Reuse a previously trained model when one is already on disk.
    if os.path.isfile(full_filename):
        return load_model(full_filename, custom_objects={'loss': loss})

    model = build_model(name)

    # Save graph in tensorboard. This graph has the name scopes making it look
    # good in tensorboard; the loaded models will not have the scopes.
    tf_callback = TensorBoard(
        log_dir=os.path.join(conf['LOG_DIR'], name),
        histogram_freq=0,
        batch_size=1,
        write_graph=True,
        write_grads=False,
    )
    tf_callback.set_model(model)
    tf_callback.on_epoch_end(0)
    tf_callback.on_train_end(0)

    # Imported here rather than at module top level — presumably to avoid
    # a circular import with the self_play module; confirm before moving.
    from self_play import self_play
    self_play(model,
              n_games=conf['N_GAMES'],
              mcts_simulations=conf['MCTS_SIMULATIONS'])

    model.save(full_filename)
    model.save(os.path.join(conf['MODEL_DIR'], 'best_model.h5'))
    return model
Esempio n. 2
0
    def train(self, epochs, batch_size=128, save_interval=50):
        """Run the adversarial training loop.

        :param int epochs: number of training iterations (one batch each)
        :param int batch_size: samples per batch
        :param int save_interval: save image samples and weights every
            *save_interval* iterations
        """
        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        start_time = datetime.datetime.now()

        tensorboard = TensorBoard(batch_size=batch_size, write_grads=True)
        tensorboard.set_model(self.combined)

        def named_logs(model, logs):
            # Transform a train_on_batch return value into the dict the
            # TensorBoard callback expects ({metric_name: value}).
            return dict(zip(model.metrics_names, logs))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random half of images
            imgs = self.data_loader.load_data(batch_size)

            # Sample noise and generate a batch of new images
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator (real classified as ones and generated as zeros)
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            # NOTE(review): this is the *sum* of the real and fake losses
            # (sibling trainers use 0.5 * np.add) — so the printed
            # "accuracy" below can reach 200%. Kept as-is; confirm intent.
            d_loss = np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Train the generator (wants discriminator to mistake images as real)
            g_loss = self.combined.train_on_batch(noise, valid)

            elapsed_time = datetime.datetime.now() - start_time

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f] time: %s" %
                  (epoch, d_loss[0], 100 * d_loss[1], g_loss, elapsed_time))

            tensorboard.on_epoch_end(epoch, named_logs(self.combined,
                                                       [g_loss]))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)
                self.combined.save_weights(
                    f"saved_model/{self.dataset_name}/{epoch}.h5")

        # Final save. Bug fix: the original reused the leaked loop
        # variable `epoch` here, which is undefined (NameError) when
        # epochs == 0; use the explicit last-epoch index instead.
        self.save_imgs(epochs - 1)
        self.combined.save_weights(
            f"saved_model/{self.dataset_name}/{epochs - 1}.h5")
Esempio n. 3
0
    def train(self, epochs, batch_size, sample_interval):
        """Train the GAN for *epochs* passes over the dataset.

        :param int epochs: number of epochs (1-based in the logs)
        :param int batch_size: samples per batch
        :param int sample_interval: iterations between saved sample
            images; progress is printed roughly ten times per interval
        """
        def named_logs(model, logs):
            # train_on_batch returns a list; TensorBoard wants a dict.
            return dict(zip(model.metrics_names, logs))

        start_time = datetime.datetime.now()

        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        max_iter = int(self.n_data / batch_size)
        # Print ~10x per sample interval. Bug fix: guard against
        # sample_interval < 10, where `sample_interval // 10` is 0 and
        # the modulo below raised ZeroDivisionError.
        print_interval = max(1, sample_interval // 10)
        os.makedirs(f"{self.backup_dir}/logs/{self.time}", exist_ok=True)
        tensorboard = TensorBoard(f"{self.backup_dir}/logs/{self.time}")
        tensorboard.set_model(self.generator)

        os.makedirs(f"{self.backup_dir}/models/{self.time}/", exist_ok=True)
        with open(
                f"{self.backup_dir}/models/{self.time}/generator_architecture.json",
                "w") as f:
            f.write(self.generator.to_json())
        print(
            f"\nbatch size : {batch_size} | num_data : {self.n_data} | max iteration : {max_iter} | time : {self.time} \n"
        )
        for epoch in range(1, epochs + 1):
            # `step` rather than `iter`: don't shadow the builtin.
            for step in range(max_iter):
                # ------------------
                #  Train Generator
                # ------------------
                ref_imgs = self.dl.load_data(batch_size)

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                gen_imgs = self.generator.predict(noise)
                make_trainable(self.discriminator, True)
                d_loss_real = self.discriminator.train_on_batch(
                    ref_imgs, valid * 0.9)  # label smoothing *0.9
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                make_trainable(self.discriminator, False)

                logs = self.combined.train_on_batch([noise], [valid])
                tensorboard.on_epoch_end(step,
                                         named_logs(self.combined, [logs]))

                if step % print_interval == 0:
                    elapsed_time = datetime.datetime.now() - start_time
                    print(
                        f"epoch:{epoch} | iter : {step} / {max_iter} | time : {elapsed_time} | g_loss : {logs} | d_loss : {d_loss} "
                    )

                if (step + 1) % sample_interval == 0:
                    self.sample_images(epoch, step + 1)

            # save weights after every epoch
            self.generator.save_weights(
                f"{self.backup_dir}/models/{self.time}/generator_epoch{epoch}_weights.h5"
            )
Esempio n. 4
0
    def train(self, epochs, batch_size, sample_interval):
        """Train the GAN for *epochs* passes over the dataset.

        :param int epochs: number of epochs (1-based in the logs)
        :param int batch_size: samples per batch
        :param int sample_interval: iterations between saved sample
            images; progress is printed roughly ten times per interval
        """
        def named_logs(model, logs):
            # train_on_batch returns a list; TensorBoard wants a dict.
            return dict(zip(model.metrics_names, logs))

        start_time = datetime.datetime.now()

        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        max_iter = int(self.n_data / batch_size)
        # Print ~10x per sample interval. Bug fix: guard against
        # sample_interval < 10, where `sample_interval // 10` is 0 and
        # the modulo below raised ZeroDivisionError.
        print_interval = max(1, sample_interval // 10)
        os.makedirs('./logs/%s' % self.time, exist_ok=True)
        tensorboard = TensorBoard('./logs/%s' % self.time)
        tensorboard.set_model(self.generator)

        os.makedirs('models/%s' % self.time, exist_ok=True)
        with open('models/%s/%s_architecture.json' % (self.time, 'generator'),
                  'w') as f:
            f.write(self.generator.to_json())
        print(
            "\nbatch size : %d | num_data : %d | max iteration : %d | time : %s \n"
            % (batch_size, self.n_data, max_iter, self.time))
        for epoch in range(1, epochs + 1):
            # `step` rather than `iter`: don't shadow the builtin.
            for step in range(max_iter):
                # ------------------
                #  Train Generator
                # ------------------
                ref_imgs = self.data_loader.load_data(batch_size)

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                gen_imgs = self.generator.predict(noise)
                make_trainable(self.discriminator, True)
                d_loss_real = self.discriminator.train_on_batch(
                    ref_imgs, valid * 0.9)  # label smoothing
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                make_trainable(self.discriminator, False)

                logs = self.combined.train_on_batch([noise], [valid])
                tensorboard.on_epoch_end(step,
                                         named_logs(self.combined, [logs]))

                if step % print_interval == 0:
                    elapsed_time = datetime.datetime.now() - start_time
                    print(
                        "epoch:%d | iter : %d / %d | time : %10s | g_loss : %15s | d_loss : %s "
                        % (epoch, step, max_iter, elapsed_time, logs, d_loss))

                if (step + 1) % sample_interval == 0:
                    self.sample_images(epoch, step + 1)

            # save weights after every epoch
            self.generator.save_weights('models/%s/%s_epoch%d_weights.h5' %
                                        (self.time, 'generator', epoch))
Esempio n. 5
0
    def train(self, epochs, batch_size=BATCH_SIZE, sample_interval=50):
        """Adversarial training loop with adaptive D/G updates.

        The discriminator is only *trained* while its accuracy is below
        80% (otherwise it is merely evaluated), and the generator is only
        trained while discriminator accuracy is above 20% — a crude
        balancing scheme. Saves both networks at the end.

        :param int epochs: number of training iterations
        :param int batch_size: requested batch size (clamped to the
            number of images available in the npy file)
        :param int sample_interval: print/plot every N iterations
        """

        tensorboard = TensorBoard(log_dir=LOG_DIR)
        tensorboard.set_model(self.discriminator)

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Detect batch size in npys

            batch_size = min(self.this_npy_num_imgs, batch_size)

            idx = np.random.randint(0, self.this_npy_num_imgs-1, batch_size)

            # Select a random batch of images
            # NOTE(review): `train_robin` is a global defined outside this
            # excerpt; the full npy file is re-loaded from disk on every
            # iteration, which is likely slow — confirm before changing.
            #self.X_train = os.path.join(OUTPATH, np.random.choice(os.listdir(OUTPATH)))
            self.X_train = np.load(train_robin, allow_pickle=True)
            self.X_train = self.X_train[idx]
            self.X_train = np.expand_dims(self.X_train, axis=3)
            # Scale 8-bit pixel values into [-1, 1].
            self.X_train = self.X_train / (255/2) - 1

            noise = np.random.normal(-1, 1, ((batch_size,) + self.latent_dim))

            # Adversarial ground truths
            # NOTE(review): labels are shaped like the full image batch,
            # not (batch_size, 1) — presumably the discriminator emits a
            # per-pixel map; verify against the model definition.
            valid = np.ones(self.X_train.shape)
            fake = np.zeros(self.X_train.shape)

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # `accuracy` is only defined after the first iteration; the
            # `epoch == 0` test must stay first so short-circuiting avoids
            # a NameError on the initial pass.
            if epoch == 0 or accuracy < 80:
                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch(self.X_train,
                                                                valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            else:
                # Test the discriminator
                d_loss_real = self.discriminator.test_on_batch(self.X_train,
                                                               valid)
                d_loss_fake = self.discriminator.test_on_batch(gen_imgs, fake)

            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            accuracy = 100*d_loss[1]

            # ---------------------
            #  Train Generator
            # ---------------------
            noise = np.random.normal(-1, 1, ((batch_size,) + self.latent_dim))

            if epoch == 0 or accuracy > 20:
                # Train the generator (to have the discriminator label samples
                # as valid)
                g_loss = self.combined.train_on_batch(noise, valid)
            else:
                # Train the generator (to have the discriminator label samples
                # as valid)
                g_loss = self.combined.test_on_batch(noise, valid)

            tensorboard.on_epoch_end(epoch, {'generator loss': g_loss,
                                             'discriminator loss': d_loss[0],
                                             'Accuracy': accuracy,
                                             'Comb. loss': g_loss + d_loss[0]})

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                print(f"@ {epoch:{len(str(EPOCHS))}}:\t"
                      f"Accuracy: {int(accuracy):3}%\t"
                      f"G-Loss: {g_loss:6.3f}\t"
                      f"D-Loss: {d_loss[0]:6.3f}\t"
                      f"Combined: {g_loss+d_loss[0]:6.3f}")
                self.sample_images(epoch)

        # NOTE(review): the callback itself is passed where a `logs`
        # argument is expected; the Keras TensorBoard callback appears to
        # ignore it, but `None` would be the conventional value — confirm.
        tensorboard.on_train_end(tensorboard)
        self.discriminator.save('discriminator.h5')
        self.generator.save('generator.h5')
 def on_epoch_end(self, epoch, logs=None):
     # Fan out to both base classes explicitly (cooperative super() is not
     # used here): the embedding mixin gets only the epoch index, the
     # stock TensorBoard callback also receives the logs dict.
     TensorBoardEmbeddingMixin.on_epoch_end(self, epoch)
     TensorBoard.on_epoch_end(self, epoch, logs)
Esempio n. 7
0
    # NOTE(review): fragment of a larger training function — `model`,
    # `iteration`, `generator`, `train_data`, `test_data`, BATCH_SIZE,
    # PRINT_DATA_EACH, TEST_EACH, TIMESTEP_SIZE and DISTANCE are defined
    # outside this excerpt, and the final `for i in tqdm(...)` statement
    # is cut off mid-line by the extraction.
    res = []
    print_msg = "iteration: {} : loss: {:.6f}, acc: {:.4%}, avg_pred: {:.4f}, avg_y: {:.4f}, left_iter_to_test: {}"
    # Best score so far (lower is better); unused within this excerpt.
    best_score = np.inf
    # Wipe stale TensorBoard event files so each run starts clean.
    for filename in os.listdir("./logs"):
        os.remove("./logs/{}".format(filename))
    tensorboard = TensorBoard(log_dir='./logs',
                              histogram_freq=0,
                              batch_size=BATCH_SIZE,
                              write_graph=True,
                              write_images=False)
    tensorboard.set_model(model)

    for x, y in generator(train_data, BATCH_SIZE):
        # r = [loss, acc] per Keras train_on_batch with an accuracy metric.
        r = model.train_on_batch(x, y)
        tensorboard.on_epoch_end(iteration, {
            'train_loss': r[0],
            'train_acc': r[1]
        })
        # Track the batch's mean label alongside loss/acc for the report.
        r += [np.mean(y)]
        res.append(r)

        iteration += 1
        if iteration % PRINT_DATA_EACH == 0:
            # Report column-wise averages over the window, then reset it.
            print(
                print_msg.format(iteration, *np.mean(res, axis=0),
                                 TEST_EACH - ((iteration - 1) % TEST_EACH)))
            res = []

        if iteration % (TEST_EACH) == 0:
            true = []
            test = []
            for i in tqdm(range(len(test_data) - TIMESTEP_SIZE - DISTANCE),
Esempio n. 8
0
    def train_esrgan(self,
        epochs=None, batch_size=16,
        modelname=None,
        datapath_train=None,
        datapath_validation=None,
        steps_per_validation=1000,
        datapath_test=None,
        workers=4, max_queue_size=10,
        first_epoch=0,
        print_frequency=1,
        crops_per_image=2,
        log_weight_frequency=None,
        log_weight_path='./model/',
        log_tensorboard_path='./data/logs/',
        log_tensorboard_update_freq=10,
        log_test_frequency=500,
        log_test_path="./images/samples/",
        media_type='i'
    ):
        """Train the ESRGAN network

        :param int epochs: how many epochs to train the network for
        :param str modelname: name to use for storing model weights etc.
        :param str datapath_train: path for the image files to use for training
        :param str datapath_test: path for the image files to use for testing / plotting
        :param int print_frequency: how often (in epochs) to print progress to terminal. Warning: will run validation inference!
        :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
        :param int log_weight_path: where should network weights be saved
        :param int log_test_frequency: how often (in epochs) should testing & validation be performed
        :param str log_test_path: where should test results be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent
        """

        # Create data loaders
        train_loader = DataLoader(
            datapath_train, batch_size,
            self.height_hr, self.width_hr,
            self.upscaling_factor,
            crops_per_image,
            media_type,
            self.channels,
            self.colorspace
        )

        # Validation data loader
        validation_loader = None
        if datapath_validation is not None:
            validation_loader = DataLoader(
                datapath_validation, batch_size,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                crops_per_image,
                media_type,
                self.channels,
                self.colorspace
            )

        test_loader = None
        if datapath_test is not None:
            test_loader = DataLoader(
                datapath_test, 1,
                self.height_hr, self.width_hr,
                self.upscaling_factor,
                1,
                media_type,
                self.channels,
                self.colorspace
            )

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(
            train_loader,
            use_multiprocessing=True,
            shuffle=True
        )
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()

        # Callback: tensorboard. Bug fix: `tensorboard` was previously
        # only bound inside the `if`, so a falsy log_tensorboard_path
        # caused a NameError at the on_epoch_end call in the loop below.
        tensorboard = None
        if log_tensorboard_path:
            tensorboard = TensorBoard(
                log_dir=os.path.join(log_tensorboard_path, modelname),
                histogram_freq=0,
                batch_size=batch_size,
                write_graph=True,
                write_grads=True,
                update_freq=log_tensorboard_update_freq
            )
            tensorboard.set_model(self.esrgan)
        else:
            print(">> Not logging to tensorboard since no log_tensorboard_path is set")

        # Learning rate scheduler: halve the LR at fixed iteration marks.
        def lr_scheduler(epoch, lr):
            factor = 0.5
            decay_step = [50000, 100000, 200000, 300000]
            if epoch in decay_step and epoch:
                return lr * factor
            return lr
        lr_scheduler_gan = LearningRateScheduler(lr_scheduler, verbose=1)
        lr_scheduler_gan.set_model(self.esrgan)
        lr_scheduler_gen = LearningRateScheduler(lr_scheduler, verbose=0)
        lr_scheduler_gen.set_model(self.generator)
        lr_scheduler_dis = LearningRateScheduler(lr_scheduler, verbose=0)
        lr_scheduler_dis.set_model(self.discriminator)
        lr_scheduler_ra = LearningRateScheduler(lr_scheduler, verbose=0)
        lr_scheduler_ra.set_model(self.ra_discriminator)

        # Callback: format input value
        def named_logs(model, logs):
            """Transform train_on_batch return value to dict expected by on_batch_end callback"""
            result = {}
            for l in zip(model.metrics_names, logs):
                result[l[0]] = l[1]
            return result

        # Shape of output from discriminator (typo `disciminator` fixed)
        discriminator_output_shape = list(self.ra_discriminator.output_shape)
        discriminator_output_shape[0] = batch_size
        discriminator_output_shape = tuple(discriminator_output_shape)

        # VALID / FAKE targets for discriminator
        real = np.ones(discriminator_output_shape)
        fake = np.zeros(discriminator_output_shape)

        # Each epoch == "update iteration" as defined in the paper
        print_losses = {"GAN": [], "D": []}
        start_epoch = datetime.datetime.now()

        # Loop through epochs / iterations
        for epoch in range(first_epoch, int(epochs) + first_epoch):
            lr_scheduler_gan.on_epoch_begin(epoch)
            lr_scheduler_ra.on_epoch_begin(epoch)
            lr_scheduler_dis.on_epoch_begin(epoch)
            lr_scheduler_gen.on_epoch_begin(epoch)

            # Start epoch time
            if epoch % print_frequency == 0:
                print("\nEpoch {}/{}:".format(epoch+1, epochs+first_epoch))
                start_epoch = datetime.datetime.now()

            # Train discriminator
            self.discriminator.trainable = True
            self.ra_discriminator.trainable = True

            imgs_lr, imgs_hr = next(output_generator)
            generated_hr = self.generator.predict(imgs_lr)

            # Relativistic discriminator: real/fake pairs in both orders.
            real_loss = self.ra_discriminator.train_on_batch([imgs_hr, generated_hr], real)
            fake_loss = self.ra_discriminator.train_on_batch([generated_hr, imgs_hr], fake)
            discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

            # Train generator
            self.discriminator.trainable = False
            self.ra_discriminator.trainable = False

            imgs_lr, imgs_hr = next(output_generator)
            gan_loss = self.esrgan.train_on_batch([imgs_lr, imgs_hr], [imgs_hr, real, imgs_hr])

            # Callbacks — only when tensorboard logging was configured.
            if tensorboard is not None:
                tensorboard.on_epoch_end(epoch, named_logs(self.esrgan, gan_loss))

            # Save losses
            print_losses['GAN'].append(gan_loss)
            print_losses['D'].append(discriminator_loss)

            # Show the progress
            if epoch % print_frequency == 0:
                g_avg_loss = np.array(print_losses['GAN']).mean(axis=0)
                d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                print(">> Time: {}s\n>> GAN: {}\n>> Discriminator: {}".format(
                    (datetime.datetime.now() - start_epoch).seconds,
                    ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.esrgan.metrics_names, g_avg_loss)]),
                    ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.discriminator.metrics_names, d_avg_loss)])
                ))
                print_losses = {"GAN": [], "D": []}

                # Run validation inference if specified
                if datapath_validation:
                    validation_losses = self.generator.evaluate_generator(
                        validation_loader,
                        steps=steps_per_validation,
                        use_multiprocessing=workers > 1,
                        workers=workers
                    )
                    print(">> Validation Losses: {}".format(
                        ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.generator.metrics_names, validation_losses)])
                    ))

            # If test images are supplied, run model on them and save to log_test_path
            if datapath_test and epoch % log_test_frequency == 0:
                plot_test_images(self.generator, test_loader, datapath_test, log_test_path, epoch, modelname,
                                 channels=self.channels, colorspace=self.colorspace)

            # Check if we should save the network weights
            if log_weight_frequency and epoch % log_weight_frequency == 0:
                # Save the network weights
                self.save_weights(os.path.join(log_weight_path, modelname))
Esempio n. 9
0
class cycGAN():
    def __init__(self):
        """Store dataset paths and the CycleGAN training hyper-parameters."""

        self.root_path = FLAGS.root_path
        self.trainA = FLAGS.trainA
        self.trainB = FLAGS.trainB

        # hyperparameter setup
        self.lambda_A = 10.0  # cyclic loss weight A2B
        self.lambda_B = 10.0  # cyclic loss weight B2A

        # Identity-mapping loss weights (10% of the cyclic weights).
        self.lambda_id_A = 0.1 * self.lambda_A
        self.lambda_id_B = 0.1 * self.lambda_B

        self.lambda_D = 1.0  # weight for loss of discriminator guess on synthetic image
        self.lr_D = 2e-4  # discriminator learning rate
        self.lr_G = 2e-4  # generator learning rate

        self.generator_iter = 1
        self.discriminator_iter = 1
        self.beta_1 = 0.5    # Adam beta_1
        self.beta_2 = 0.999  # Adam beta_2
        self.batch_size = 1
        self.batch_num = 0
        self.epochs = 200
        self.fake_pool_size = 50
        self.channels = 3

        self.Real_label = 1  # Use e.g. 0.9 to avoid training the discriminators to zero loss
        self.img_shape = (256, 256, 3)
        self.img_rows = 256
        self.img_columns = 256
        # Bug fix: save_interval used to be assigned twice (1, then 50);
        # only the final value ever took effect, so the dead `= 1`
        # assignment has been removed.
        self.save_interval = 50

        # Per-batch discriminator loss histories, filled during train().
        self.DA_losses = []
        self.DB_losses = []

    def setup_model(self):
        """Build and compile the discriminators, the generators, and the
        combined generator-training model, then configure the TF session.

        Bug fix: ``valid_B`` is now produced by ``D_B_static``. The
        original code ran ``fake_B`` through ``D_A_static``, so the
        B-domain discriminator (``D_B_static``, built below) was never
        used when training the generators.
        """

        # initial image pooling (history of generated fakes for D updates)
        self.image_pooling = ImagePool(self.fake_pool_size)
        # separate Adam optimizers for the discriminators and generators
        self.opt_D = Adam(self.lr_D, self.beta_1, self.beta_2)
        self.opt_G = Adam(self.lr_G, self.beta_1, self.beta_2)

        # setup discriminator model
        self.D_A = model_discriminator(self.img_shape)
        self.D_B = model_discriminator(self.img_shape)

        self.D_A.summary()

        self.loss_weights_D = [0.5]
        self.img_A = Input(shape=self.img_shape)  # real image
        self.img_B = Input(shape=self.img_shape)

        # discriminator build
        self.guess_A = self.D_A(self.img_A)
        self.guess_B = self.D_B(self.img_B)

        self.D_A = Model(inputs=self.img_A,
                         outputs=self.guess_A,
                         name='D_A_Model')  # name for save model
        self.D_B = Model(inputs=self.img_B,
                         outputs=self.guess_B,
                         name='D_B_Model')

        self.D_A.compile(optimizer=self.opt_D,
                         loss='mse',
                         loss_weights=self.loss_weights_D,
                         metrics=['accuracy'])

        self.D_B.compile(optimizer=self.opt_D,
                         loss='mse',
                         loss_weights=self.loss_weights_D,
                         metrics=['accuracy'])

        # for generator model, we do not train discriminator
        self.D_A_static = Network(inputs=self.img_A,
                                  outputs=self.guess_A,
                                  name='D_A_static_model')
        self.D_B_static = Network(inputs=self.img_B,
                                  outputs=self.guess_B,
                                  name='D_B_static_model')

        # generator setup
        self.G_A2B = model_generator(self.channels,
                                     self.img_shape,
                                     name='G_A2B_model')
        self.G_B2A = model_generator(self.channels,
                                     self.img_shape,
                                     name='G_B2A_model')

        self.G_A2B.summary()

        # generate fake images, transfer image from A to B
        self.fake_B = self.G_A2B(self.img_A)
        self.fake_A = self.G_B2A(self.img_B)

        # reconstruction, transfer to original image from fake image
        self.reconstor_A = self.G_B2A(self.fake_B)
        self.reconstor_B = self.G_A2B(self.fake_A)

        self.D_A_static.trainable = False
        self.D_B_static.trainable = False

        # Discriminators determine validity of translated images.
        self.valid_A = self.D_A_static(self.fake_A)
        # Bug fix: use the B discriminator for the B-domain fake (was
        # D_A_static in the original).
        self.valid_B = self.D_B_static(self.fake_B)

        # identity learning
        self.identity_A = self.G_B2A(self.img_A)
        self.identity_B = self.G_A2B(self.img_B)

        # combined model trains the generators to fool the discriminators
        self.combined_model = Model(inputs=[self.img_A, self.img_B],
                                    outputs=[self.valid_A, self.valid_B,
                                             self.reconstor_A, self.reconstor_B,
                                             self.identity_A, self.identity_B],
                                    name='Combined_G_model')

        self.combined_model.compile(loss=['mse', 'mse',
                                          'mae', 'mae',
                                          'mae', 'mae'],
                                    loss_weights=[self.lambda_D, self.lambda_D,
                                                  self.lambda_A, self.lambda_B,
                                                  self.lambda_id_A, self.lambda_id_B],
                                    optimizer=self.opt_G)

        self.config = tf.ConfigProto()
        self.config.gpu_options.allow_growth = True

        # Create a session with the above options specified.
        kerasbackend.tensorflow_backend.set_session(
            tf.Session(config=self.config))

        # patchGAN
        # output shape of D (PatchGAN): one patch per 16x16 input region
        patch = int(self.img_rows / 2**4)
        self.disc_patch = (patch, patch, 1)

    def save_loss_tocsv(self, history, time):
        """Write the tracked losses to images/<time>/loss_output.csv.

        One column per *history* key (columns sorted alphabetically),
        one row per recorded training step.
        """
        out_dir = os.path.join('images/{}'.format(time))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        columns = sorted(history.keys())
        rows = zip(*(history[column] for column in columns))
        with open('images/{}/loss_output.csv'.format(time), 'w') as out_file:
            writer = csv.writer(out_file, delimiter=',')
            writer.writerow(columns)
            writer.writerows(rows)

    def reNamed_logs(self, logs):
        """Map the combined model's raw train_on_batch output onto
        human-readable generator-loss names for the TensorBoard callback.

        Index 0 is the total loss; 1-2 are the adversarial (fake) losses
        for the two directions; 3-4 are the reconstruction losses.
        """
        vals = [float(entry) for entry in logs]
        return {
            "GA_loss": (vals[1] + vals[3]) * 0.5,
            "GB_loss": (vals[4] + vals[2]) * 0.5,
            "GA_Fake_loss": vals[1],
            "GB_Fake_loss": vals[2],
            "Recon_loss": np.mean(vals[3:5]),
            "g_loss": vals[0],
        }

    def saveimage(self,
                  batch_index,
                  epoch_index,
                  path_testA=FLAGS.path_testA,
                  path_testB=FLAGS.path_testB):
        """Render a 2x3 figure per domain row: original, translation to
        the other domain, and reconstruction back — then save it under
        images/ (name depends on FLAGS.is_train)."""

        os.makedirs('images', exist_ok=True)
        rows, columns = 2, 3

        # Load one test image per domain and bring both to 256x256.
        test_A = imread(path_testA, pilmode='RGB').astype(np.float32)
        test_B = imread(path_testB, pilmode='RGB').astype(np.float32)
        test_A = resize(test_A, (256, 256))
        test_B = resize(test_B, (256, 256))

        # Normalize into [-1, 1] and add a leading batch axis.
        imgs_A = np.array([test_A]) / 127.5 - 1
        imgs_B = np.array([test_B]) / 127.5 - 1

        # Translate each image into the other domain...
        fake_B = self.G_A2B.predict(imgs_A)
        fake_A = self.G_B2A.predict(imgs_B)
        # ...and map the fakes back to their original domain.
        recon_A = self.G_B2A.predict(fake_B)
        recon_B = self.G_A2B.predict(fake_A)

        # Row 1: A -> B -> A; row 2: B -> A -> B.
        imgs = np.concatenate(
            [imgs_A, fake_B, recon_A, imgs_B, fake_A, recon_B])
        imgs = 0.5 * imgs + 0.5  # rescale back into [0, 1] for display

        titles = ['original', 'transform', 'goback']
        fig, axs = plt.subplots(rows, columns)
        for count in range(rows * columns):
            row, col = divmod(count, columns)
            axs[row, col].imshow(imgs[count])
            axs[row, col].set_title(titles[col])
            axs[row, col].axis('off')
        if FLAGS.is_train:
            plt.savefig("images/%d_%d.png" % (epoch_index, batch_index))
        else:
            plt.savefig("images/test_{}.png".format(batch_index))
        plt.close()

    def train(self):
        """Run the full CycleGAN training loop.

        For every epoch: decay the learning rates, then for every batch
        train D_A / D_B on real vs. pooled-fake images and train both
        generators through the combined model, logging generator losses
        to TensorBoard and periodically saving sample images.  Loss
        histories are dumped to CSV after each epoch and all five
        sub-models are checkpointed every 50 epochs.
        """

        # TensorBoard callback used only to track the generator
        # (combined model) losses; driven manually via on_epoch_end below.
        self.tb_G_loss_track = TensorBoard(log_dir='./cycGAN_G_loss', histogram_freq=0, \
                                      batch_size= self.batch_size, \
                                      write_graph=True, \
                                      write_grads=False, \
                                      write_images=False, \
                                      embeddings_freq=0, \
                                      embeddings_layer_names=None, \
                                      embeddings_metadata=None, \
                                      embeddings_data=None, \
                                      update_freq='epoch')

        self.tb_G_loss_track.set_model(self.combined_model)
        # Discriminator targets: one PatchGAN grid of ones (real) /
        # zeros (fake) per image in the batch.
        valid = np.ones((self.batch_size, ) + self.disc_patch)
        fake = np.zeros((self.batch_size, ) + self.disc_patch)

        for epoch in range(self.epochs):

            # update learning rate (decay) for each epoch
            self.lr_D, self.lr_G = update_lr(epoch, self.lr_D, self.lr_G)
            kerasbackend.set_value(self.D_A.optimizer.lr, self.lr_D)
            kerasbackend.set_value(self.D_B.optimizer.lr, self.lr_D)
            kerasbackend.set_value(self.combined_model.optimizer.lr, self.lr_G)

            for batch_index, (batch_num, imgs_A, imgs_B) in enumerate(
                    loaddata_batch(self.batch_size, self.root_path,
                                   self.trainA, self.trainB)):

                # Translate each domain into the other.
                fake_B_tmp = self.G_A2B.predict(imgs_A)
                fake_A_tmp = self.G_B2A.predict(imgs_B)

                # Sample from the pool of previously generated images to
                # stabilise discriminator training.
                fake_B = self.image_pooling.fake_image_pooling(fake_B_tmp)
                fake_A = self.image_pooling.fake_image_pooling(fake_A_tmp)

                # Train each discriminator on a real and a fake batch,
                # then average the two losses.
                dA_loss_real = self.D_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.D_A.train_on_batch(fake_A, fake)

                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.D_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.D_B.train_on_batch(fake_B, fake)

                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # Train both generators via the combined model; the six
                # targets are [adversarial A, adversarial B,
                # reconstruction A, reconstruction B, identity A,
                # identity B].
                g_loss = self.combined_model.train_on_batch([imgs_A, imgs_B], \
                                                       [valid, valid, \
                                                        imgs_A, imgs_B, \
                                                        imgs_A, imgs_B])

                # Log generator losses.  NOTE(review): the step passed is
                # the *batch* index, so later epochs re-use the same step
                # numbers -- confirm this per-batch logging is intended.
                self.tb_G_loss_track.on_epoch_end(
                    batch_index, logs=self.reNamed_logs(g_loss))

                # collect losses for plot

                self.DA_losses.append(dA_loss)
                self.DB_losses.append(dB_loss)

                time = datetime.datetime.now().strftime("%Y-%m-%d-%H%M%S")
                print("[epoch_index: %d/%d][batch_index:%d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] [time:%s]" \
                    % (epoch, self.epochs, \
                       batch_index, batch_num, \
                       d_loss[0], 100 * d_loss[1], \
                       g_loss[0], \
                       np.mean(g_loss[1:3]), \
                       np.mean(g_loss[3:5]), \
                       np.mean(g_loss[5:6]), \
                       time))
                if batch_index % self.save_interval == 0:
                    print(
                        "<<=========save image + test image + original image + reconstruct image + return original image============>>"
                    )
                    self.saveimage(batch_index, epoch, FLAGS.path_testA,
                                   FLAGS.path_testB)  #save and test

            # Dump accumulated discriminator loss history to a
            # time-stamped CSV after every epoch.
            training_history = {
                'DA_losses': self.DA_losses,
                'DB_losses': self.DB_losses
            }
            self.save_loss_tocsv(
                training_history,
                datetime.datetime.now().strftime("%Y-%m-%d-%H%M%S"))
            # Epoch-level TensorBoard tick (no logs attached).
            self.tb_G_loss_track.on_epoch_end(epoch)

            # Checkpoint every sub-model every 50 epochs.
            if (epoch % 50 == 0):
                save_model(epoch, self.D_A)
                save_model(epoch, self.D_B)
                save_model(epoch, self.combined_model)
                save_model(epoch, self.G_A2B)
                save_model(epoch, self.G_B2A)

    def test(self, modelname):
        """Restore trained weights from disk and render test images.

        Note: ``modelname`` is currently unused; the checkpoint file
        names are hard-coded below.
        """
        # Restore each sub-model from its fixed checkpoint file, in the
        # same order as the original implementation.
        checkpoints = (
            (self.combined_model, 'weight_combined_G_model.h5'),
            (self.D_A, 'weight_D_A_Model.h5'),
            (self.D_B, 'weight_D_B_Model.h5'),
            (self.G_B2A, 'weight_G_B2A_model.h5'),
            (self.G_A2B, 'weight_G_A2B_model.h5'),
        )
        for network, weight_file in checkpoints:
            network.load_weights(weight_file)

        # batch_index == epoch == -1 flags saveimage's test mode.
        # NOTE(review): train() invokes saveimage with four arguments --
        # confirm the path parameters have defaults for this call.
        self.saveimage(-1, -1)
 def on_epoch_end(self, epoch, logs=None):
     """Forward the epoch-end event to both parent callbacks.

     Invokes the embedding mixin's hook first, then the stock
     TensorBoard handler with the metric ``logs`` dict.
     """
     # Explicit unbound calls on each base class ensure both hooks run
     # exactly once, independent of the MRO of the multiple inheritance.
     TensorBoardEmbeddingMixin.on_epoch_end(self, epoch)
     TensorBoard.on_epoch_end(self, epoch, logs)
Esempio n. 11
0
    def train(self, batch_size=4, epochs=25):
        """Train QANet with a manual batch loop and hand-driven callbacks.

        Per epoch: iterate shuffled mini-batches via ``train_on_batch``,
        feeding the TensorBoard and QANet (EMA) callbacks manually; then
        snapshot the raw weights, apply EMA-averaged weights for the
        saved checkpoint and validation, log EM/F1, append results to a
        CSV, and finally restore the raw weights to continue training.

        :param batch_size: mini-batch size for training
        :param epochs: number of passes over the training data
        """
        cf = self.cf
        self.compile()
        model = self.keras_model
        word_vectors, char_vectors, train_ques_ids, X_train, y_train, val_ques_ids, X_valid, y_valid = self.data_train

        # EMA-weight callback plus a manually driven TensorBoard logger.
        qanet_cb = QANetCallback(decay=cf.EMA_DECAY)
        tb = TensorBoard(log_dir=cf.TENSORBOARD_PATH,
                         histogram_freq=0,
                         write_graph=False,
                         write_images=False,
                         update_freq=cf.TENSORBOARD_UPDATE_FREQ)

        # Call set_model for all callbacks
        qanet_cb.set_model(model)
        tb.set_model(model)

        # Per-epoch bookkeeping for the CSV result log.
        ep_list = []
        avg_train_loss_list = []
        em_score_list = []
        f1_score_list = []

        global_steps = 0
        # assumes y_valid[2:] holds the ground-truth answer start/end
        # index lists -- TODO confirm against the data pipeline
        gt_start_list, gt_end_list = y_valid[2:]
        for ep in range(1, epochs + 1):  # Epoch num start from 1
            print('----------- Training for epoch {}...'.format(ep))
            # Train
            batch = 0
            sum_loss = 0
            num_batches = (len(X_train[0]) - 1) // batch_size + 1
            for X_batch, y_batch in get_batch(X_train,
                                              y_train,
                                              batch_size=batch_size,
                                              shuffle=True):
                batch_logs = {'batch': batch, 'size': len(X_batch[0])}
                tb.on_batch_begin(batch, batch_logs)

                loss, loss_p1, loss_p2, loss_start, loss_end = model.train_on_batch(
                    X_batch, y_batch)
                sum_loss += loss
                # Running average of the loss over this epoch so far.
                avg_loss = sum_loss / (batch + 1)
                print(
                    'Epoch: {}/{}, Batch: {}/{}, Accumulative average loss: {:.4f}, Loss: {:.4f}, Loss_P1: {:.4f}, Loss_P2: {:.4f}, Loss_start: {:.4f}, Loss_end: {:.4f}'
                    .format(ep, epochs, batch, num_batches, avg_loss, loss,
                            loss_p1, loss_p2, loss_start, loss_end))
                batch_logs.update({
                    'loss': loss,
                    'loss_p1': loss_p1,
                    'loss_p2': loss_p2
                })
                # Let the EMA callback update its averaged weights and
                # TensorBoard record the batch metrics.
                qanet_cb.on_batch_end(batch, batch_logs)
                tb.on_batch_end(batch, batch_logs)

                global_steps += 1
                batch += 1

            ep_list.append(ep)
            avg_train_loss_list.append(avg_loss)

            # Snapshot the raw weights, then swap in the EMA weights for
            # the saved per-epoch checkpoint and for validation.
            print('Backing up temp weights...')
            model.save_weights(cf.TEMP_MODEL_PATH)
            qanet_cb.on_epoch_end(ep)  # Apply EMA weights
            model.save_weights(cf.MODEL_PATH % str(ep))

            print('----------- Validating for epoch {}...'.format(ep))
            valid_scores = self.validate(X_valid,
                                         y_valid,
                                         gt_start_list,
                                         gt_end_list,
                                         batch_size=cf.BATCH_SIZE)
            em_score_list.append(valid_scores['exact_match'])
            f1_score_list.append(valid_scores['f1'])
            print(
                '------- Result of epoch: {}/{}, Average_train_loss: {:.6f}, EM: {:.4f}, F1: {:.4f}\n'
                .format(ep, epochs, avg_loss, valid_scores['exact_match'],
                        valid_scores['f1']))

            tb.on_epoch_end(ep, {
                'f1': valid_scores['f1'],
                'em': valid_scores['exact_match']
            })

            # Write result to CSV file
            result = pd.DataFrame({
                'epoch': ep_list,
                'avg_train_loss': avg_train_loss_list,
                'em': em_score_list,
                'f1': f1_score_list
            })
            result.to_csv(cf.RESULT_LOG, index=None)

            # Restore the original weights to continue training
            print('Restoring temp weights...')
            model.load_weights(cf.TEMP_MODEL_PATH)

        tb.on_train_end(None)
Esempio n. 12
0
        for i in range(len(x)):

            # train disc on real
            disc.train_on_batch([x[i], y[i]], real_y)

            # gen fake
            fake = gen.predict(x[i])

            # train disc on fake
            disc.train_on_batch([x[i], fake], fake_y)

            # train combined
            disc.trainable = False
            combined.train_on_batch(x[i], [y[i], real_y])
            disc.trainable = True

            #log.write(str(e) + ", " + str(s) + ", " + str(dr_loss) + ", " + str(df_loss) + ", " + str(g_loss[0]) + ", " + str(g_loss[1]) + ", " + str(opt_dcgan.get_config()["lr"]) + "\n")

    # output random result
    #val_sequence = sequences[train_offset:]
    #generated_y = gen.predict(x[random_index])
    #save_image(strip(x[random_index]) / 2 + 0.5, y[random_index], re_shape(generated_y), "validation/e{}_{}.png".format(e, s))

    # save weights
    gen.save_weights(checkpoint_gen_name, overwrite=True)
    disc.save_weights(checkpoint_disc_name, overwrite=True)

    tensorlog.on_epoch_end(e)

tensorlog.on_train_end()
Esempio n. 13
0
    def train(self, epochs, batch_size=BATCH_SIZE, sample_interval=50):
        """Adversarially train the GAN for ``epochs`` iterations.

        Each iteration draws a random batch from a random ``.npy`` shard
        in ``OUTPATH``, trains (or, when the discriminator is too
        strong, merely evaluates) the discriminator on real vs.
        generated images with optional label noise/flipping, then trains
        (or evaluates) the generator through the combined model.  The
        discriminator is frozen while its accuracy exceeds 80% and the
        generator while accuracy is 52% or below, to keep the two
        networks balanced.

        :param epochs: number of training iterations (one batch each)
        :param batch_size: images per batch (capped at images per shard)
        :param sample_interval: how often (in iterations) to print
            progress and save sample images
        """

        tensorboard = TensorBoard(log_dir=LOG_DIR)
        tensorboard.set_model(self.discriminator)

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Cap the batch size at the number of images per npy shard.
            batch_size = min(self.img_per_npy, batch_size)

            # Select a random shard, then a random batch of images in it.
            self.X_train = os.path.join(OUTPATH,
                                        np.random.choice(os.listdir(OUTPATH)))
            self.X_train = np.load(self.X_train, allow_pickle=True)
            idx = np.random.randint(0, len(self.X_train), batch_size)
            self.X_train = self.X_train[idx]
            self.X_train = np.expand_dims(self.X_train, axis=3)
            # Rescale pixel values from [0, 255] to [1, -1].
            self.X_train = self.X_train / (-255/2) + 1

            noise = np.random.normal(-1, 1, ((batch_size,) + LATENT_SIZE))

            # Adversarial ground truths, optionally softened with label
            # noise and randomly flipped to regularise the discriminator.
            valid = np.ones((batch_size,))
            if ADD_LABEL_NOISE:
                valid -= np.random.uniform(high=LABEL_NOISE,
                                           size=(batch_size,))
            for img in range(batch_size):
                if np.random.rand() < P_FLIP_LABEL:
                    valid[img] = 1 - valid[img]

            fake = np.zeros((batch_size,))
            if ADD_LABEL_NOISE:
                # (removed leftover debug `print(fake)` that dumped the
                # whole label array every iteration)
                fake += np.random.uniform(high=LABEL_NOISE,
                                          size=(batch_size,))
            for img in range(batch_size):
                if np.random.rand() < P_FLIP_LABEL:
                    fake[img] = 1 - fake[img]

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # `accuracy` carries over from the previous iteration; the
            # `epoch == 0` check short-circuits before it is read on the
            # first pass, so it is never unbound.
            if epoch == 0 or accuracy < 80:
                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch(self.X_train,
                                                                valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            else:
                # Discriminator too strong: only evaluate it this round.
                d_loss_real = self.discriminator.test_on_batch(self.X_train,
                                                               valid)
                d_loss_fake = self.discriminator.test_on_batch(gen_imgs, fake)

            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            accuracy = 100*d_loss[1]

            # ---------------------
            #  Train Generator
            # ---------------------
            noise = np.random.normal(-1, 1, ((batch_size,) + LATENT_SIZE))

            if epoch == 0 or accuracy > 52:
                # Train the generator (to have the discriminator label
                # samples as valid)
                g_loss = self.combined.train_on_batch(noise, valid)
            else:
                # Generator currently winning: only evaluate it.
                g_loss = self.combined.test_on_batch(noise, valid)

            tensorboard.on_epoch_end(epoch, {'generator loss': g_loss,
                                             'discriminator loss': d_loss[0],
                                             'Accuracy': accuracy,
                                             'Comb. loss': g_loss + d_loss[0]})

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                print(f"@ {epoch:{len(str(EPOCHS))}}:\t"
                      f"Accuracy: {int(accuracy):3}%\t"
                      f"G-Loss: {g_loss:6.3f}\t"
                      f"D-Loss: {d_loss[0]:6.3f}\t"
                      f"Combined: {g_loss+d_loss[0]:6.3f}")
                self.sample_images(epoch, accuracy, real_imgs=self.X_train)
            if epoch % SAVE_INTERVAL == 0:
                self.discriminator.save('discriminator.h5')
                self.generator.save('generator.h5')

        # on_train_end expects an optional logs dict, not the callback
        # object itself (fixed: was `tensorboard.on_train_end(tensorboard)`).
        tensorboard.on_train_end(None)
Esempio n. 14
0
    def train(self, epochs, batch_size=128, sample_interval=100):
        """Train the GAN on MNIST for ``epochs`` iterations.

        Each iteration trains the discriminator on one real and one
        generated batch, then trains the generator through the combined
        model.  Losses are streamed to TensorBoard; at the end a loss
        curve is saved as a PNG and all three models are saved to disk.

        :param epochs: number of training iterations (one batch each)
        :param batch_size: images per batch
        :param sample_interval: how often (in iterations) to save sample
            images
        """

        (X_train, _), (_, _) = mnist.load_data()

        
        # Normalization to the scale -1 to 1
        # X_train = X_train.astype('float32')
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Create the labels
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        tensorboard = TensorBoard(log_dir='./tmp/logs',
                                  histogram_freq=0,
                                  write_graph=True)
        tensorboard.set_model(self.combined)

        g_loss_list = []
        d_loss_list = []
        for epoch in range(epochs):

            # ----------------------------- #
            #   Randomly pick batch imags
            #   to train the discriminator
            # ----------------------------- #

            # Randomly pick batch imags
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Generate batch-size of random noise with latent dimension size
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            gen_imgs = self.generator.predict(noise)

            # The loss of discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            # Avg loss
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            d_loss_list.append(d_loss[0])

            # # --------------------------- #
            # #  Train the generator
            # # --------------------------- #
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            g_loss_list.append(g_loss)

            print_str = "%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss)
            print(print_str)
            # NOTE(review): two on_epoch_end calls at the same step write
            # two separate summaries -- confirm the TensorBoard backend
            # records both tag sets as intended.
            tensorboard.on_epoch_end(epoch, self._named_logs(['d_loss', 'd_accuracy'], d_loss))
            tensorboard.on_epoch_end(epoch, self._named_logs(['g_loss'], [g_loss]))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)
        tensorboard.on_train_end(None)
        # Save the loss image
        plt.xlabel("Steps")
        plt.ylabel("Loss")
        plt.title("Loss of Generator and Discriminator")
        plt.plot(np.arange(epochs), g_loss_list, label="Generator Loss")
        plt.plot(np.arange(epochs), d_loss_list, label="Discriminator Loss")
        plt.legend()
        plt.savefig('epoch {}.png'.format(epochs))
        print("Traning finished.\n...\n")
        # Model Saving
        self.generator.save("models/generator.tf")
        self.discriminator.save("models/discriminator.tf")
        self.combined.save("models/gan.tf")
        print("Model Saved.\n-------------------------------------------------------------------")
Esempio n. 15
0
class Network:
    """
    A deep convolutional neural network that takes as input the ML representation of a game (n, m, k) and tries to
    return 1 if player 1 (white in chess, red in checkers etc.) is going to win, -1 if player 2 is going to win,
    and 0 if the game is going to be a draw.

    The Network's model will input a binary matrix with a shape of GameClass.STATE_SHAPE and output a tuple consisting
    of the probability distribution over legal moves and the position's evaluation.
    """
    def __init__(self,
                 GameClass,
                 model_path=None,
                 reinforcement_training=False,
                 hyper_params=None):
        """
        :param GameClass: class providing STATE_SHAPE, MOVE_SHAPE,
            get_legal_moves and get_possible_moves.
        :param model_path: optional path to a saved keras model to load.
        :param reinforcement_training: when True, a TensorBoard logger is
            attached for the manual train_step loop.
        :param hyper_params: optional kwargs forwarded to create_model.
        """
        self.GameClass = GameClass
        self.model_path = model_path

        # lazily initialized so Network can be passed between processes before being initialized
        self.model = None

        self.reinforcement_training = reinforcement_training
        self.hyper_params = hyper_params if hyper_params is not None else {}
        # TensorBoard callback, created in initialize() when
        # reinforcement_training is enabled.
        self.tensor_board = None
        # Step counter used as the TensorBoard "epoch" in train_step.
        self.epoch = 0

    def initialize(self):
        """
        Initializes the Network's model, either by loading it from model_path
        (validating its input/output shapes against GameClass) or by building
        a fresh one via create_model.
        """
        # Note: keras imports are within functions to prevent initializing keras in processes that import from this file
        from keras.models import load_model
        from keras.callbacks import TensorBoard

        if self.model is not None:
            return

        if self.model_path is not None:
            input_shape = self.GameClass.STATE_SHAPE
            output_shape = self.GameClass.MOVE_SHAPE
            self.model = load_model(self.model_path)
            if self.model.input_shape != (None, ) + input_shape:
                raise Exception('Input shape of loaded model doesn\'t match!')
            if self.model.output_shape != [(None, ) + output_shape, (None, 1)]:
                raise Exception('Output shape of loaded model doesn\'t match!')
            # TODO: recompile model with loss_weights and learning schedule from config file
        else:
            self.model = self.create_model(**self.hyper_params)

        if self.reinforcement_training:
            self.tensor_board = TensorBoard(
                log_dir=f'{get_training_path(self.GameClass)}/logs/'
                f'model_reinforcement_{time()}',
                histogram_freq=0,
                write_graph=True)
            self.tensor_board.set_model(self.model)

    def create_model(self,
                     kernel_size=(4, 4),
                     convolutional_filters=64,
                     residual_layers=6,
                     value_head_neurons=16,
                     policy_loss_value=1):
        """
        Build an AlphaZero-style residual tower with a softmax policy head
        and a tanh value head, compiled with the Adam optimizer.
        https://www.youtube.com/watch?v=OPgRNY3FaxA

        :param kernel_size: convolution kernel size for the tower.
        :param convolutional_filters: filters per convolutional layer.
        :param residual_layers: number of residual blocks.
        :param value_head_neurons: size of the value head's dense layer.
        :param policy_loss_value: loss weight of the policy head relative
            to the value head.
        :return: the compiled keras Model.
        """
        # Note: keras imports are within functions to prevent initializing keras in processes that import from this file
        from keras.models import Model
        from keras.layers import Input, Conv2D, BatchNormalization, Flatten, Dense, Activation, Add, Reshape

        input_shape = self.GameClass.STATE_SHAPE
        output_shape = self.GameClass.MOVE_SHAPE
        # np.prod (np.product was a deprecated alias, removed in NumPy 2.0)
        output_neurons = np.prod(output_shape)

        input_tensor = Input(input_shape)

        # convolutional layer
        x = Conv2D(convolutional_filters, kernel_size,
                   padding='same')(input_tensor)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

        # residual layers
        for _ in range(residual_layers):
            y = Conv2D(convolutional_filters, kernel_size, padding='same')(x)
            y = BatchNormalization()(y)
            y = Activation('relu')(y)
            y = Conv2D(convolutional_filters, kernel_size, padding='same')(y)
            y = BatchNormalization()(y)
            # noinspection PyTypeChecker
            x = Add()([x, y])
            x = Activation('relu')(x)

        # policy head
        policy = Conv2D(2, (1, 1), padding='same')(x)
        policy = BatchNormalization()(policy)
        policy = Activation('relu')(policy)
        policy = Flatten()(policy)
        policy = Dense(output_neurons, activation='softmax')(policy)
        policy = Reshape(output_shape, name='policy')(policy)

        # value head
        value = Conv2D(1, (1, 1), padding='same')(x)
        value = BatchNormalization()(value)
        value = Activation('relu')(value)
        value = Flatten()(value)
        value = Dense(value_head_neurons, activation='relu')(value)
        value = Dense(1, activation='tanh', name='value')(value)

        model = Model(input_tensor, [policy, value])
        model.compile(optimizer='adam',
                      loss={
                          'policy': 'categorical_crossentropy',
                          'value': 'mean_squared_error'
                      },
                      loss_weights={
                          'policy': policy_loss_value,
                          'value': 1
                      },
                      metrics=['mean_squared_error'])
        return model

    def predict(self, states):
        """Run the raw model on a batch of states; returns [policies, values]."""
        return self.model.predict(states)

    def call(self, states):
        """
        For any of the given states, if no moves are legal, then the corresponding probability distribution will be a
        list with a single 1. This is done to allow for pass moves which are not encapsulated by GameClass.MOVE_SHAPE.

        :param states: The input positions with shape (k,) + GameClass.STATE_SHAPE, where k is the number of positions.
        :return: A list of length k. Each element of the list is a tuple where the 0th element is the probability
                 distribution on legal moves (ordered correspondingly with GameClass.get_possible_moves), and the 1st
                 element is the evaluation (a float in (-1, 1)).
        """
        raw_policies, evaluations = self.predict(states)

        # Mask each raw policy down to the legal moves, then renormalize.
        filtered_policies = [
            raw_policy[self.GameClass.get_legal_moves(state)]
            for state, raw_policy in zip(states, raw_policies)
        ]
        filtered_policies = [
            filtered_policy /
            np.sum(filtered_policy) if len(filtered_policy) > 0 else [1]
            for filtered_policy in filtered_policies
        ]

        evaluations = evaluations.reshape(states.shape[0])
        return [(filtered_policy, evaluation) for filtered_policy, evaluation
                in zip(filtered_policies, evaluations)]

    def choose_move(self, position, return_distribution=False, optimal=False):
        """
        Pick a move for the given position: sampled from the policy
        distribution by default, or deterministically when optimal=True.
        """
        distribution, evaluation = self.call(position[np.newaxis, ...])[0]
        # NOTE(review): `optimal` uses np.argmin, i.e. the *least* probable
        # move; argmax looks intended. Left unchanged because self-play
        # data / callers may depend on the current behavior -- confirm.
        idx = np.argmin(distribution) if optimal else np.random.choice(
            np.arange(len(distribution)), p=distribution)
        move = self.GameClass.get_possible_moves(position)[idx]
        return (move, distribution) if return_distribution else move

    def train(self, data, validation_fraction=0.2):
        """
        Supervised training on self-play data with a train/validation split,
        TensorBoard logging and early stopping on validation loss.

        :param data: a list of (game, outcome) tuples as accepted by
            get_training_data.
        :param validation_fraction: fraction of data held out for validation.
        """
        # Note: keras imports are within functions to prevent initializing keras in processes that import from this file
        from keras.callbacks import TensorBoard, EarlyStopping

        split = int((1 - validation_fraction) * len(data))
        train_input, train_output = self.get_training_data(
            self.GameClass, data[:split])
        test_input, test_output = self.get_training_data(
            self.GameClass, data[split:])
        print('Training Samples:', train_input.shape[0])
        print('Validation Samples:', test_input.shape[0])

        self.model.fit(
            train_input,
            train_output,
            epochs=100,
            validation_data=(test_input, test_output),
            callbacks=[
                TensorBoard(
                    log_dir=
                    f'{get_training_path(self.GameClass)}/logs/model_{time()}'
                ),
                EarlyStopping(monitor='val_loss',
                              patience=3,
                              restore_best_weights=True)
            ])

    def train_step(self, states, policies, values):
        """Run one reinforcement-learning batch and log it to TensorBoard."""
        logs = self.model.train_on_batch(states, [policies, values],
                                         return_dict=True)
        self.tensor_board.on_epoch_end(self.epoch, logs)
        self.epoch += 1

    def finish_training(self):
        """Flush and close the TensorBoard writer after reinforcement training."""
        self.tensor_board.on_train_end()

    def save(self, model_path):
        """Persist the keras model to model_path."""
        self.model.save(model_path)

    def equal_model_architecture(self, network):
        """
        Both networks must be initialized.

        :return: True if this Network's model and the given network's model have the same architecture.
        """
        return self.model.get_config() == network.model.get_config()

    @classmethod
    def get_training_data(cls, GameClass, data, one_hot=False, shuffle=True):
        """
        Convert raw self-play data into model-ready training arrays.

        :param GameClass: class providing get_legal_moves for masking.
        :param data: A list of game, outcome tuples. Each game is a list of position, distribution tuples.
        :param one_hot: when True, collapse each policy to a one-hot of its
            most probable move.
        :param shuffle: when True, shuffle all samples with a shared
            permutation so inputs stay aligned with outputs.
        :return: (input_data, [policy_outputs, value_outputs]) numpy arrays.
        """
        states = []
        policy_outputs = []
        value_outputs = []

        for game, outcome in data:
            for position, distribution in game:
                # Scatter the distribution over the legal-move mask.
                legal_moves = GameClass.get_legal_moves(position)
                policy = np.zeros_like(legal_moves, dtype=float)
                policy[legal_moves] = distribution
                policy /= np.sum(policy)  # rescale so total probability is 1

                if one_hot:
                    idx = np.unravel_index(policy.argmax(), policy.shape)
                    policy = np.zeros_like(policy)
                    policy[idx] = 1

                states.append(position)
                policy_outputs.append(policy)
                value_outputs.append(outcome)

        input_data = np.stack(states, axis=0)
        policy_outputs = np.stack(policy_outputs, axis=0)
        value_outputs = np.array(value_outputs)

        if shuffle:
            shuffle_indices = np.arange(input_data.shape[0])
            np.random.shuffle(shuffle_indices)
            input_data = input_data[shuffle_indices, ...]
            policy_outputs = policy_outputs[shuffle_indices, ...]
            value_outputs = value_outputs[shuffle_indices]

        return input_data, [policy_outputs, value_outputs]
Esempio n. 16
0
        for i, k in enumerate(args.style_layers):
            # log['style_loss'][k].append(out[offset + i])
            key = "style_loss_%s" % k
            style_block_loss = out[offset + i]
            logs[key] = style_block_loss
            style_loss += style_block_loss
        logs['style_loss'] = style_loss

        stop_time = time.time()
        print('Iteration %d/%d: loss = %f. t = %f (%f)' %
              (it + 1, args.num_iterations, out[0], stop_time - start_time,
               stop_time2 - start_time2))

        if not ((it + 1) % args.save_every):
            print('Saving checkpoint in %s...' % (args.checkpoint_path))
            model_checkpoint.on_epoch_end(it)
            print('Checkpoint saved.')

        # tensorboard
        if (log_dir):
            tensorboard.on_epoch_end(it, logs)

        start_time = time.time()

    # close callback
    if (log_dir):
        tensorboard.on_train_end(None)

    # save model
    pastiche_net.save_weights(weights_path)
Esempio n. 17
0
    def train_srgan(
        self,
        epochs,
        batch_size,
        dataname,
        datapath_train,
        datapath_validation=None,
        steps_per_validation=10,
        datapath_test=None,
        workers=40,
        max_queue_size=100,
        first_epoch=0,
        print_frequency=50,
        crops_per_image=3,
        log_weight_frequency=1000,
        log_weight_path='./data/weights/doctor_gan_ct_sn/',
        log_tensorboard_path='./data/logs/',
        log_tensorboard_name='RTC-SR',
        log_tensorboard_update_freq=1000,
        log_test_frequency=1000,
        log_test_path="./images/samples-ct-sn/",
    ):
        """Train the SRGAN network

        :param int epochs: how many epochs to train the network for
        :param int batch_size: images per training batch
        :param str dataname: name to use for storing model weights etc.
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: path for validation images (optional)
        :param int steps_per_validation: validation steps (currently unused;
            validation inference is commented out below)
        :param str datapath_test: path for the image files to use for testing / plotting
        :param int workers: worker count for the batch enqueuer
        :param int max_queue_size: max prepared batches held in memory
        :param int first_epoch: epoch number to resume counting from
        :param int print_frequency: how often (in epochs) to print progress to terminal. Warning: will run validation inference!
        :param int crops_per_image: random crops taken per source image
        :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
        :param int log_weight_path: where should network weights be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent
        :param str log_tensorboard_name: what folder should tf logs be saved under
        :param int log_tensorboard_update_freq: TensorBoard write frequency
        :param int log_test_frequency: how often (in epochs) should testing & validation be performed
        :param str log_test_path: where should test results be saved
        """

        # Create train data loader
        loader = DataLoader(datapath_train, batch_size, self.height_hr,
                            self.width_hr, self.upscaling_factor,
                            crops_per_image)

        # Validation data loader
        if datapath_validation is not None:
            validation_loader = DataLoader(datapath_validation, batch_size,
                                           self.height_hr, self.width_hr,
                                           self.upscaling_factor,
                                           crops_per_image)
        print("Picture Loaders has been ready.")
        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(loader,
                                   use_multiprocessing=False,
                                   shuffle=True)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()
        print("Data Enqueuer has been ready.")
        # Callback: tensorboard.
        # Fixed: keep `tensorboard` bound even when logging is disabled,
        # so the epoch loop below can guard on it instead of raising a
        # NameError on the first epoch.
        tensorboard = None
        if log_tensorboard_path:
            tensorboard = TensorBoard(log_dir=os.path.join(
                log_tensorboard_path, log_tensorboard_name),
                                      histogram_freq=0,
                                      batch_size=batch_size,
                                      write_graph=False,
                                      write_grads=False,
                                      update_freq=log_tensorboard_update_freq)
            tensorboard.set_model(self.srgan)
        else:
            print(
                ">> Not logging to tensorboard since no log_tensorboard_path is set"
            )

        # Callback: format input value
        def named_logs(model, logs):
            """Transform train_on_batch return value to dict expected by on_batch_end callback"""
            result = {}
            for l in zip(model.metrics_names, logs):
                result[l[0]] = l[1]
            return result

        # Shape of output from discriminator
        discriminator_output_shape = list(self.discriminator.output_shape)
        discriminator_output_shape[0] = batch_size
        discriminator_output_shape = tuple(discriminator_output_shape)

        # VALID / FAKE targets for discriminator
        real = np.ones(discriminator_output_shape)
        fake = np.zeros(discriminator_output_shape)

        # Each epoch == "update iteration" as defined in the paper
        print_losses = {"G": [], "D": []}
        start_epoch = datetime.datetime.now()

        # Random images to go through
        idxs = np.random.randint(0, len(loader), epochs)
        if self.use_EMA:
            # Must run after the model has been compiled.
            self.EMAer = ExponentialMovingAverage(self.srgan)
        if self.use_EMA: self.EMAer.inject()

        # Loop through epochs / iterations
        for epoch in range(first_epoch, int(epochs) + first_epoch):
            # Start epoch time
            if epoch % print_frequency == 1:
                start_epoch = datetime.datetime.now()

            # Train discriminator
            imgs_lr, imgs_hr = next(output_generator)

            # Apply the EMA-averaged weights for inference only; restore
            # the raw weights afterwards so EMA never affects the
            # optimization trajectory.
            if self.use_EMA: self.EMAer.apply_ema_weights()
            generated_hr = self.generator.predict(imgs_lr)
            if self.use_EMA:
                self.EMAer.reset_old_weights()

            real_loss = self.discriminator.train_on_batch(imgs_hr, real)
            fake_loss = self.discriminator.train_on_batch(generated_hr, fake)
            discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

            # Train generator (adversarial target + VGG feature matching)
            features_hr = self.vgg.predict(self.preprocess_vgg(imgs_hr))
            generator_loss = self.srgan.train_on_batch(imgs_lr,
                                                       [real, features_hr])

            # Callbacks (only when tensorboard logging is enabled)
            if tensorboard is not None:
                logs = named_logs(self.srgan, generator_loss)
                tensorboard.on_epoch_end(epoch, logs)
            # Save losses
            print_losses['G'].append(generator_loss)
            print_losses['D'].append(discriminator_loss)

            # Show the progress
            if epoch % print_frequency == 0:
                g_avg_loss = np.array(print_losses['G']).mean(axis=0)
                d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                print(
                    "\nEpoch {}/{} | Time: {}s\n>> Generator/GAN: {}\n>> Discriminator: {}"
                    .format(
                        epoch, epochs + first_epoch,
                        (datetime.datetime.now() - start_epoch).seconds,
                        ", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.srgan.metrics_names, g_avg_loss)
                        ]), ", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.discriminator.metrics_names, d_avg_loss)
                        ])))
                print_losses = {"G": [], "D": []}
                # Run validation inference if specified
                # if datapath_validation:
                #     print(">> Running validation inference")
                #     validation_losses = self.generator.evaluate_generator(
                #         validation_loader,
                #         steps=steps_per_validation,
                #         use_multiprocessing=workers>1,
                #         workers=workers
                #     )
                #     print(">> Validation Losses: {}".format(
                #         ", ".join(["{}={:.4f}".format(k, v) for k, v in zip(self.generator.metrics_names, validation_losses)])
                #     ))

            # If test images are supplied, run model on them and save to log_test_path
            if datapath_test and epoch % log_test_frequency == 0:
                print(">> Ploting test images")
                if self.use_EMA:
                    self.EMAer.apply_ema_weights()
                plot_test_images(self,
                                 loader,
                                 datapath_test,
                                 log_test_path,
                                 epoch,
                                 refer_model=self.refer_model)
                if self.use_EMA: self.EMAer.reset_old_weights()

                # Check if we should save the network weights
            if log_weight_frequency and epoch % log_weight_frequency == 0:
                # Save the network weights (with EMA weights applied, then
                # restored)
                print(">> Saving the network weights")
                if self.use_EMA:
                    self.EMAer.apply_ema_weights()
                self.save_weights(os.path.join(log_weight_path, dataname),
                                  epoch)
                if self.use_EMA: self.EMAer.reset_old_weights()
Esempio n. 18
0
class DQNAgent:
    """Deep Q-Network agent with an epsilon-greedy policy and a soft-updated
    target network.

    Transitions are stored in a bounded replay buffer; ``learn`` samples
    minibatches from it, fits the online ``model``, and streams each fit's
    loss to TensorBoard by driving the callback hooks manually.
    ``target_train`` blends the online weights into ``target_model``
    (Polyak averaging controlled by ``tau``).
    """

    def __init__(self, state_size, action_size):
        # state_size: input image shape fed to the conv net;
        # action_size: number of discrete actions (Q-value outputs).
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = 0.99  # discount factor for future rewards
        self.epsilon = 0.2  # exploration rate (epsilon-greedy)
        self.epsilon_min = 0.01  # exploration floor
        self.epsilon_decay = 0.995  # multiplicative decay per learn() call
        self.learning_rate = .0001  #.0001
        self.tau = 0.1  # soft-update mixing coefficient for the target net
        self.buffer_size = 2000  # replay memory capacity
        self.model = self._build_model()
        self.target_model = self._build_model()
        # summary() prints the architecture itself and returns None;
        # calling it directly avoids the spurious "None" that
        # print(self.model.summary()) used to emit.
        self.model.summary()

        # internal memory (deque); maxlen evicts the oldest transition
        # automatically once the buffer is full
        self.memory = deque(maxlen=self.buffer_size)
        self.A_loss = []
        self.experience = namedtuple(
            "Data",
            field_names=["state", "action", "reward", "next_state", "done"])

        # Create the TensorBoard callback,
        # which we will drive manually
        self.tensorboard = TensorBoard(log_dir='/tmp/my_tf_logs',
                                       histogram_freq=0,
                                       batch_size=32,
                                       write_graph=True,
                                       write_grads=True)
        self.tensorboard.set_model(self.target_model)
        self.batch_id = 0

    def _build_model(self):
        """Build the DQN conv net: 3 conv layers -> dense -> linear Q-values."""
        # Define input layer (states)
        states = layers.Input(shape=(self.state_size), name='input')
        c1 = layers.Convolution2D(filters=32,
                                  kernel_size=8,
                                  strides=4,
                                  activation='relu')(states)  # edge detection
        c2 = layers.Convolution2D(filters=64,
                                  kernel_size=4,
                                  strides=2,
                                  activation='relu')(c1)
        c3 = layers.Convolution2D(filters=64,
                                  kernel_size=3,
                                  strides=1,
                                  activation='relu')(c2)
        l1 = layers.Flatten()(c3)
        l2 = layers.Dense(256, activation='relu')(l1)
        Q_val = layers.Dense(units=self.action_size,
                             name='Q_Values',
                             activation='linear')(l2)
        # Create Keras model
        model = models.Model(inputs=[states], outputs=Q_val)  #actions
        model.compile(loss='mse',
                      optimizer=optimizers.Adam(lr=self.learning_rate))
        # Expose the first conv layer's activations for debugging/inspection.
        self.get_conv = K.function(inputs=[model.input],
                                   outputs=model.layers[1].output)
        return model

    def _build_model_old(self):
        """Alternative NVIDIA-style architecture, kept for reference (unused)."""
        frame = layers.Input(shape=(self.state_size), name='input')

        c1 = layers.Convolution2D(filters=24,
                                  kernel_size=5,
                                  strides=2,
                                  activation='elu')(frame)
        c2 = layers.Convolution2D(filters=36,
                                  kernel_size=5,
                                  strides=2,
                                  activation='elu')(c1)
        c3 = layers.Convolution2D(filters=48,
                                  kernel_size=5,
                                  strides=2,
                                  activation='elu')(c2)
        c4 = layers.Convolution2D(filters=64, kernel_size=3,
                                  activation='elu')(c3)
        c5 = layers.Convolution2D(filters=64, kernel_size=3,
                                  activation='elu')(c4)

        l1 = layers.Flatten()(c5)
        l2 = layers.Dense(100, activation='elu')(l1)
        l3 = layers.Dense(50, activation='elu')(l2)

        Q_val = layers.Dense(units=self.action_size,
                             name='Q_Values',
                             activation='linear')(l3)

        model = models.Model(inputs=[frame], outputs=Q_val)

        # we use MSE (Mean Squared Error) as loss function
        model.compile(loss='mse',
                      optimizer=optimizers.Adam(lr=self.learning_rate))
        self.get_conv = K.function(inputs=[model.input],
                                   outputs=model.layers[1].output)
        return model

    def step(self, state, action, reward, next_state, done):
        """Record one transition in replay memory.

        The deque's maxlen already discards the oldest entry when the buffer
        is full, so no manual popleft is required.
        """
        self.memory.append(
            self.experience(state, action, reward, next_state, done))

    def conv_to_tensor(self, img):  ###CHANGE
        """Add batch (and channel) axes so ``img`` matches the model input.

        Grayscale (H, W) arrays become (1, H, W, 1); (H, W, C) arrays become
        (1, H, W, C). Inputs of length < 10 are passed through untouched
        (presumably a guard against non-image input -- TODO confirm).
        """
        if len(img) < 10:  # why is this here??
            return img
        # Black and White Image ex: 1, 244, 244, 1
        if len(img.shape) == 2:
            # axis=2, not 3: a 2-D array has no axis 3 and modern NumPy
            # raises AxisError for an out-of-range expand_dims axis.
            img = np.expand_dims(img, axis=2)
            img = np.expand_dims(img, axis=0)
        # RGB Image or stacked image: 1, 244, 244, 3
        elif len(img.shape) == 3:
            img = np.expand_dims(img, axis=0)
        return img

    def predict(self, state):
        """Return an action index: random with probability epsilon,
        otherwise the argmax of the online model's Q-values."""
        if np.random.rand() <= self.epsilon:
            return rn.randrange(self.action_size)
        act_values = self.model.predict(state)
        return np.argmax(act_values[0])  # returns action

    def learn(self, batch_size=32, target_train=False):
        """Fit the online model on a random minibatch of stored transitions.

        Uses the target network to compute the bootstrap target; logs the
        per-fit loss to TensorBoard, optionally soft-updates the target
        network, and decays epsilon toward its floor.
        """
        if len(self.memory) < batch_size:
            return
        minibatch = rn.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                # Updating value of best action taken at the moment
                target = reward + self.gamma * \
                    np.amax(self.target_model.predict(self.conv_to_tensor(next_state))[0])
            target_f = self.target_model.predict(self.conv_to_tensor(state))
            target_f[0][action] = target
            qloss = self.model.fit(self.conv_to_tensor(state),
                                   target_f,
                                   epochs=1,
                                   verbose=0)
            logs = qloss.history['loss'][0]
            self.tensorboard.on_epoch_end(
                self.batch_id, named_logs(self.target_model, [logs]))
            self.A_loss.append(qloss)
        self.batch_id += 1
        if target_train:
            print("TARGET TRAIN")
            self.target_train()
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def target_train(self):
        """Soft-update (Polyak average) the target network toward the online
        network: target = tau * online + (1 - tau) * target."""
        weights = self.model.get_weights()
        target_weights = self.target_model.get_weights()
        for i in range(len(target_weights)):
            target_weights[i] = weights[i] * self.tau + target_weights[i] * (
                1 - self.tau)
        self.target_model.set_weights(target_weights)
Esempio n. 19
0
class Trainer:
    """Class object to setup and carry the training.

    Takes as input a generator that produces SR images.
    Conditionally, also a discriminator network and a feature extractor
        to build the components of the perceptual loss.
    Compiles the model(s) and trains in a GANS fashion if a discriminator is provided, otherwise
    carries a regular ISR training.

    Args:
        generator: Keras model, the super-scaling, or generator, network.
        discriminator: Keras model, the discriminator network for the adversarial
            component of the perceptual loss.
        feature_extractor: Keras model, feature extractor network for the deep features
            component of perceptual loss function.
        lr_train_dir: path to the directory containing the Low-Res images for training.
        hr_train_dir: path to the directory containing the High-Res images for training.
        lr_valid_dir: path to the directory containing the Low-Res images for validation.
        hr_valid_dir: path to the directory containing the High-Res images for validation.
        learning_rate: float.
        loss_weights: dictionary, use to weigh the components of the loss function.
            Contains 'MSE' for the MSE loss component, and can contain 'discriminator' and 'feat_extr'
            for the discriminator and deep features components respectively.
            Defaults to {'MSE': 1.0}.
        logs_dir: path to the directory where the tensorboard logs are saved.
        weights_dir: path to the directory where the weights are saved.
        dataname: string, used to identify what dataset is used for the training session.
        weights_generator: path to the pre-trained generator's weights, for transfer learning.
        weights_discriminator: path to the pre-trained discriminator's weights, for transfer learning.
        n_validation:integer, number of validation samples used at training from the validation set.
        T: 0 < float <1, determines the 'flatness' threshold level for the training patches.
            See the TrainerHelper class for more details.
        lr_decay_frequency: integer, every how many epochs the learning rate is reduced.
        lr_decay_factor: 0 < float <1, learning rate reduction multiplicative factor.

    Methods:
        train: combines the networks and triggers training with the specified settings.

    """
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        learning_rate=0.0004,
        loss_weights=None,
        logs_dir='logs',
        weights_dir='weights',
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        T=0.01,
        lr_decay_frequency=100,
        lr_decay_factor=0.5,
    ):
        # Build the default per call instead of once at definition time,
        # avoiding the shared-mutable-default-argument pitfall.
        if loss_weights is None:
            loss_weights = {'MSE': 1.0}
        # The discriminator/feature extractor must see patches of the same
        # size the generator outputs (LR patch size * scale).
        if discriminator:
            assert generator.patch_size * generator.scale == discriminator.patch_size
        if feature_extractor:
            assert generator.patch_size * generator.scale == feature_extractor.patch_size
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.best_metrics = {}
        self.pretrained_weights_path = {
            'generator': weights_generator,
            'discriminator': weights_discriminator,
        }
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_frequency = lr_decay_frequency

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=weights_dir,
            logs_dir=logs_dir,
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            pretrained_weights_path=self.pretrained_weights_path,
        )

        self.model = self._combine_networks()

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
            T=T,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
            T=0.01,
        )
        self.logger = get_logger(__name__)

    def _combine_networks(self):
        """
        Constructs the combined model which contains the generator network,
        as well as discriminator and feature extractor, if any are defined.
        """

        lr = Input(shape=(self.lr_patch_size, ) * 2 + (3, ))
        sr = self.generator.model(lr)
        outputs = [sr]
        losses = ['mse']
        loss_weights = [self.loss_weights['MSE']]
        if self.discriminator:
            # Frozen inside the combined model: only the generator trains here.
            self.discriminator.model.trainable = False
            validity = self.discriminator.model(sr)
            outputs.append(validity)
            losses.append('binary_crossentropy')
            loss_weights.append(self.loss_weights['discriminator'])
        if self.feature_extractor:
            self.feature_extractor.model.trainable = False
            sr_feats = self.feature_extractor.model(sr)
            outputs.extend([*sr_feats])
            losses.extend(['mse'] * len(sr_feats))
            # Spread the feature-extractor weight evenly over its outputs.
            loss_weights.extend(
                [self.loss_weights['feat_extr'] / len(sr_feats)] *
                len(sr_feats))
        combined = Model(inputs=lr, outputs=outputs)
        # https://stackoverflow.com/questions/42327543/adam-optimizer-goes-haywire-after-200k-batches-training-loss-grows
        optimizer = Adam(epsilon=0.0000001)
        combined.compile(loss=losses,
                         loss_weights=loss_weights,
                         optimizer=optimizer,
                         metrics={'generator': PSNR})
        return combined

    def _lr_scheduler(self, epoch):
        """ Scheduler for the learning rate updates. """

        n_decays = epoch // self.lr_decay_frequency
        # no lr below minimum control 1e-6
        return max(1e-6, self.learning_rate * (self.lr_decay_factor**n_decays))

    def _load_weights(self):
        """
        Loads the pretrained weights from the given path, if any is provided.
        If a discriminator is defined, does the same.
        """

        gen_w = self.pretrained_weights_path['generator']
        if gen_w:
            self.model.get_layer('generator').load_weights(gen_w)
        if self.discriminator:
            dis_w = self.pretrained_weights_path['discriminator']
            if dis_w:
                # Load into both the layer inside the combined model and the
                # standalone discriminator used for its own training step.
                self.model.get_layer('discriminator').load_weights(dis_w)
                self.discriminator.model.load_weights(dis_w)

    def train(self, epochs, steps_per_epoch, batch_size):
        """
        Carries on the training for the given number of epochs.
        Sends the losses to Tensorboard.
        """

        starting_epoch = self.helper.initialize_training(
            self)  # load_weights, creates folders, creates basename
        self.tensorboard = TensorBoard(
            log_dir=self.helper.callback_paths['logs'])
        self.tensorboard.set_model(self.model)

        # validation data
        validation_set = self.valid_dh.get_validation_set(batch_size)
        y_validation = [validation_set['hr']]
        if self.discriminator:
            discr_out_shape = list(
                self.discriminator.model.outputs[0].shape)[1:4]
            valid = np.ones([batch_size] + discr_out_shape)
            fake = np.zeros([batch_size] + discr_out_shape)
            validation_valid = np.ones([len(validation_set['hr'])] +
                                       discr_out_shape)
            y_validation.append(validation_valid)
        if self.feature_extractor:
            validation_feats = self.feature_extractor.model.predict(
                validation_set['hr'])
            y_validation.extend([*validation_feats])

        for epoch in range(starting_epoch, epochs):
            self.logger.info('Epoch {e}/{tot_eps}'.format(e=epoch,
                                                          tot_eps=epochs))
            K.set_value(self.model.optimizer.lr,
                        self._lr_scheduler(epoch=epoch))
            self.logger.info('Current learning rate: {}'.format(
                K.eval(self.model.optimizer.lr)))
            epoch_start = time()
            for step in tqdm(range(steps_per_epoch)):
                batch = self.train_dh.get_batch(batch_size)
                sr = self.generator.model.predict(batch['lr'])
                y_train = [batch['hr']]
                losses = {}

                ## Discriminator training
                if self.discriminator:
                    d_loss_real = self.discriminator.model.train_on_batch(
                        batch['hr'], valid)
                    d_loss_fake = self.discriminator.model.train_on_batch(
                        sr, fake)
                    d_loss_real = dict(
                        zip(
                            [
                                'train_d_real_' + m
                                for m in self.discriminator.model.metrics_names
                            ],
                            d_loss_real,
                        ))
                    d_loss_fake = dict(
                        zip(
                            [
                                'train_d_fake_' + m
                                for m in self.discriminator.model.metrics_names
                            ],
                            d_loss_fake,
                        ))
                    losses.update(d_loss_real)
                    losses.update(d_loss_fake)
                    y_train.append(valid)

                ## Generator training
                if self.feature_extractor:
                    hr_feats = self.feature_extractor.model.predict(
                        batch['hr'])
                    y_train.extend([*hr_feats])

                training_loss = self.model.train_on_batch(batch['lr'], y_train)
                losses.update(
                    dict(
                        zip(['train_' + m for m in self.model.metrics_names],
                            training_loss)))
                # Use a global step index so per-batch losses line up in TB.
                self.tensorboard.on_epoch_end(epoch * steps_per_epoch + step,
                                              losses)
                self.logger.debug('Losses at step {s}:\n {l}'.format(s=step,
                                                                     l=losses))

            elapsed_time = time() - epoch_start
            self.logger.info('Epoch {} took {:10.1f}s'.format(
                epoch, elapsed_time))

            validation_loss = self.model.evaluate(validation_set['lr'],
                                                  y_validation,
                                                  batch_size=batch_size)
            losses = dict(
                zip(['val_' + m for m in self.model.metrics_names],
                    validation_loss))

            # Track the generator's validation loss when adversarial or
            # perceptual components are present, total loss otherwise.
            monitored_metrics = {}
            if (not self.discriminator) and (not self.feature_extractor):
                monitored_metrics.update({'val_loss': 'min'})
            else:
                monitored_metrics.update({'val_generator_loss': 'min'})

            self.helper.on_epoch_end(
                epoch=epoch,
                losses=losses,
                generator=self.model.get_layer('generator'),
                discriminator=self.discriminator,
                metrics=monitored_metrics,
            )
            self.tensorboard.on_epoch_end(epoch, losses)
        self.tensorboard.on_train_end(None)
Esempio n. 20
0
def train_model(model, data, config, include_tensorboard):
    """Train *model* over all datasets, driving the Keras callbacks by hand.

    Each epoch iterates every dataset twice (a validated pass on the first
    generator pair and a train-only pass on the second), averages the logs
    weighted by the generator sizes, and feeds them to checkpointing, early
    stopping, CSV logging and (optionally) TensorBoard.

    Args:
        model: compiled Keras model; stateful, so states are reset per dataset.
        data: object exposing ``datasets``, each with train/valid generators.
        config: provides ``model_file()``, ``csv_log_file()``, ``min_delta``,
            ``patience`` and ``max_epochs``.
        include_tensorboard: when True, also log epoch metrics to TensorBoard.
    """
    model_history = History()
    model_history.on_train_begin()

    saver = ModelCheckpoint(full_path(config.model_file()),
                            verbose=1,
                            save_best_only=True,
                            period=1)
    saver.set_model(model)

    early_stopping = EarlyStopping(min_delta=config.min_delta,
                                   patience=config.patience,
                                   verbose=1)
    early_stopping.set_model(model)
    early_stopping.on_train_begin()

    csv_logger = CSVLogger(full_path(config.csv_log_file()))
    csv_logger.on_train_begin()

    if include_tensorboard:
        tensorboard = TensorBoard(histogram_freq=10, write_images=True)
        tensorboard.set_model(model)
    else:
        # No-op callback stand-in so the calls below need no guards.
        # (The original mixed tabs and spaces on this branch, which is a
        # TabError under Python 3; indentation is now uniform.)
        tensorboard = Callback()

    epoch = 0
    stop = False
    while epoch <= config.max_epochs and not stop:
        epoch_history = History()
        epoch_history.on_train_begin()
        valid_sizes = []
        train_sizes = []
        print("Epoch:", epoch)
        for dataset in data.datasets:
            print("dataset:", dataset.name)
            model.reset_states()
            dataset.reset_generators()

            valid_sizes.append(dataset.valid_generators[0].size())
            train_sizes.append(dataset.train_generators[0].size())
            fit_history = model.fit_generator(
                dataset.train_generators[0],
                dataset.train_generators[0].size(),
                nb_epoch=1,
                verbose=0,
                validation_data=dataset.valid_generators[0],
                nb_val_samples=dataset.valid_generators[0].size())

            epoch_history.on_epoch_end(epoch, last_logs(fit_history))

            # Second training pass has no validation counterpart.
            train_sizes.append(dataset.train_generators[1].size())
            fit_history = model.fit_generator(
                dataset.train_generators[1],
                dataset.train_generators[1].size(),
                nb_epoch=1,
                verbose=0)

            epoch_history.on_epoch_end(epoch, last_logs(fit_history))

        epoch_logs = average_logs(epoch_history, train_sizes, valid_sizes)
        model_history.on_epoch_end(epoch, logs=epoch_logs)
        saver.on_epoch_end(epoch, logs=epoch_logs)
        early_stopping.on_epoch_end(epoch, epoch_logs)
        csv_logger.on_epoch_end(epoch, epoch_logs)
        tensorboard.on_epoch_end(epoch, epoch_logs)
        epoch += 1

        # EarlyStopping records the epoch it fired on; non-zero means stop.
        if early_stopping.stopped_epoch > 0:
            stop = True

    early_stopping.on_train_end()
    csv_logger.on_train_end()
    tensorboard.on_train_end({})
Esempio n. 21
0
                if np.random.randint(0, 2):
                    sel_samples = random.choice(neg_samples)
                else:
                    sel_samples = random.choice(pos_samples)
            # Classification 모델 학습
            loss_class = model_classifier.train_on_batch(
                [X, X2[:, sel_samples, :]],
                [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])

            def named_logs(model, logs):
                result = {}
                for l in zip(model.metrics_names, logs):
                    result[l[0]] = l[1]
                return result

            tensorboard.on_epoch_end(epoch_num,
                                     named_logs(model_classifier, loss_class))

            losses[iter_num, 0] = loss_rpn[1]
            losses[iter_num, 1] = loss_rpn[2]

            losses[iter_num, 2] = loss_class[1]
            losses[iter_num, 3] = loss_class[2]
            losses[iter_num, 4] = loss_class[3]

            iter_num += 1

            progbar.update(iter_num,
                           [('rpn_cls', np.mean(losses[:iter_num, 0])),
                            ('rpn_regr', np.mean(losses[:iter_num, 1])),
                            ('detector_cls', np.mean(losses[:iter_num, 2])),
                            ('detector_regr', np.mean(losses[:iter_num, 3]))])
Esempio n. 22
0
def main(not_parsed_args):
    """Train the RBPN video super-resolution model batch by batch.

    Resumes from the last saved epoch/step, logs every batch to TensorBoard,
    periodically evaluates on the validation set (when one is configured),
    and saves checkpoints plus an 'info' pointer file naming the latest
    weights file.
    """
    logging.info('Build dataset')
    train_set = get_training_set(FLAGS.dataset_h, FLAGS.dataset_l,
                                 FLAGS.frames, FLAGS.scale, True,
                                 'filelist.txt', True, FLAGS.patch_size,
                                 FLAGS.future_frame)
    if FLAGS.dataset_val:
        val_set = get_eval_set(FLAGS.dataset_val_h, FLAGS.dataset_val_l,
                               FLAGS.frames, FLAGS.scale, True, 'filelist.txt',
                               True, FLAGS.patch_size, FLAGS.future_frame)

    logging.info('Build model')
    model = RBPN()
    model.summary()
    last_epoch, last_step = load_weights(model)
    model.compile(optimizer=optimizers.Adam(FLAGS.lr),
                  loss=losses.mae,
                  metrics=[psnr])

    # checkpoint = ModelCheckpoint('models/model.hdf5', verbose=1)
    tensorboard = TensorBoard(log_dir='./tf_logs',
                              batch_size=FLAGS.batch_size,
                              write_graph=False,
                              write_grads=True,
                              write_images=True,
                              update_freq='batch')
    tensorboard.set_model(model)

    logging.info('Training start')
    steps_per_epoch = len(train_set) // FLAGS.batch_size
    for e in range(last_epoch, FLAGS.epochs):
        tensorboard.on_epoch_begin(e)
        for s in range(last_step + 1, steps_per_epoch):
            tensorboard.on_batch_begin(s)
            x, y = train_set.batch(FLAGS.batch_size)
            loss = model.train_on_batch(x, y)
            print('Epoch %d step %d, loss %f psnr %f' %
                  (e, s, loss[0], loss[1]))
            tensorboard.on_batch_end(s, named_logs(model, loss, s))

            # Validate on the configured interval and on the final step.
            # Guarding the whole condition with FLAGS.dataset_val fixes the
            # original precedence bug ('and' binds tighter than 'or') that
            # reached val_set at the last step even when no validation
            # dataset was configured, raising NameError.
            if FLAGS.dataset_val and (
                    (s > 0 and s % FLAGS.val_interval == 0)
                    or s == steps_per_epoch - 1):
                logging.info('Validation start')
                val_loss = 0
                val_psnr = 0
                for j in range(len(val_set)):
                    x_val, y_val = val_set.batch(1)
                    score = model.test_on_batch(x_val, y_val)
                    val_loss += score[0]
                    val_psnr += score[1]
                val_loss /= len(val_set)
                val_psnr /= len(val_set)
                logging.info('Validation average loss %f psnr %f' %
                             (val_loss, val_psnr))

            # Save weights on the interval and at the final step; 'info'
            # records the most recent checkpoint's filename.
            if s > 0 and s % FLAGS.save_interval == 0 or s == steps_per_epoch - 1:
                logging.info('Saving model')
                filename = 'model_%d_%d.h5' % (e, s)
                path = os.path.join(FLAGS.model_dir, filename)
                path_info = os.path.join(FLAGS.model_dir, 'info')
                model.save_weights(path)
                # Context manager guarantees the handle is closed even if
                # the write fails (original used bare open/close).
                with open(path_info, 'w') as f:
                    f.write(filename)
        tensorboard.on_epoch_end(e)
        # Only the resumed epoch starts mid-way; later epochs start at step 0.
        last_step = -1
Esempio n. 23
0
def main(dataset, batch_size, patch_size, epochs, label_smoothing,
         label_flipping):
    """Train a pix2pix-style DCGAN: a U-Net generator against a patch
    discriminator, alternating a discriminator update and a generator
    (combined model) update on every batch.

    Args:
        dataset: path handed to data_utils.DataGenerator for train/val splits.
        batch_size: samples per training batch.
        patch_size: size of the patches the discriminator scores.
        epochs: number of training epochs.
        label_smoothing: forwarded to data_utils.get_disc_batch.
        label_flipping: forwarded to data_utils.get_disc_batch.
    """
    print(project_dir)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
    sess = tf.Session(config=config)
    K.tensorflow_backend.set_session(
        sess)  # set this TensorFlow session as the default session for Keras

    # Models below assume channels-first tensors.
    image_data_format = "channels_first"
    K.set_image_data_format(image_data_format)

    save_images_every_n_batches = 30
    save_model_every_n_epochs = 0  # 0 disables periodic saves; final epoch still saves

    # configuration parameters
    print("Config params:")
    print("  dataset = {}".format(dataset))
    print("  batch_size = {}".format(batch_size))
    print("  patch_size = {}".format(patch_size))
    print("  epochs = {}".format(epochs))
    print("  label_smoothing = {}".format(label_smoothing))
    print("  label_flipping = {}".format(label_flipping))
    print("  save_images_every_n_batches = {}".format(
        save_images_every_n_batches))
    print("  save_model_every_n_epochs = {}".format(save_model_every_n_epochs))

    # Run artifacts are grouped under a timestamp-derived model name.
    model_name = datetime.strftime(datetime.now(), '%y%m%d-%H%M')
    model_dir = os.path.join(project_dir, "models", model_name)
    fig_dir = os.path.join(project_dir, "reports", "figures")
    logs_dir = os.path.join(project_dir, "reports", "logs", model_name)

    os.makedirs(model_dir)

    # Load and rescale data
    # Separate generators for the generator and discriminator passes so each
    # draws its own shuffled stream of batches.
    ds_train_gen = data_utils.DataGenerator(file_path=dataset,
                                            dataset_type="train",
                                            batch_size=batch_size)
    ds_train_disc = data_utils.DataGenerator(file_path=dataset,
                                             dataset_type="train",
                                             batch_size=batch_size)
    ds_val = data_utils.DataGenerator(file_path=dataset,
                                      dataset_type="val",
                                      batch_size=batch_size)
    enq_train_gen = OrderedEnqueuer(ds_train_gen,
                                    use_multiprocessing=True,
                                    shuffle=True)
    enq_train_disc = OrderedEnqueuer(ds_train_disc,
                                     use_multiprocessing=True,
                                     shuffle=True)
    enq_val = OrderedEnqueuer(ds_val, use_multiprocessing=True, shuffle=False)

    # Last three dims of the first sample: (channels, height, width) given
    # the channels-first setting above.
    img_dim = ds_train_gen[0][0].shape[-3:]

    n_batch_per_epoch = len(ds_train_gen)
    epoch_size = n_batch_per_epoch * batch_size

    print("Derived params:")
    print("  n_batch_per_epoch = {}".format(n_batch_per_epoch))
    print("  epoch_size = {}".format(epoch_size))
    print("  n_batches_val = {}".format(len(ds_val)))

    # Get the number of non overlapping patch and the size of input image to the discriminator
    nb_patch, img_dim_disc = data_utils.get_nb_patch(img_dim, patch_size)

    tensorboard = TensorBoard(log_dir=logs_dir,
                              histogram_freq=0,
                              batch_size=batch_size,
                              write_graph=True,
                              write_grads=True,
                              update_freq='batch')

    # try/except lets Ctrl-C fall through to the enqueuer cleanup below.
    try:
        # Create optimizers
        opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
        # opt_discriminator = SGD(lr=1E-3, momentum=0.9, nesterov=True)
        opt_discriminator = Adam(lr=1E-3,
                                 beta_1=0.9,
                                 beta_2=0.999,
                                 epsilon=1e-08)

        # Load generator model
        generator_model = models.generator_unet_upsampling(img_dim)
        generator_model.summary()
        plot_model(generator_model,
                   to_file=os.path.join(fig_dir, "generator_model.png"),
                   show_shapes=True,
                   show_layer_names=True)

        # Load discriminator model
        # TODO: modify disc to accept real input as well
        discriminator_model = models.DCGAN_discriminator(
            img_dim_disc, nb_patch)
        discriminator_model.summary()
        plot_model(discriminator_model,
                   to_file=os.path.join(fig_dir, "discriminator_model.png"),
                   show_shapes=True,
                   show_layer_names=True)

        # TODO: pretty sure this is unnecessary
        generator_model.compile(loss='mae', optimizer=opt_discriminator)
        # Freeze the discriminator while compiling the combined model so only
        # the generator's weights receive gradients through it.
        discriminator_model.trainable = False

        DCGAN_model = models.DCGAN(generator_model, discriminator_model,
                                   img_dim, patch_size, image_data_format)

        # L1 loss applies to generated image, cross entropy applies to predicted label
        loss = [models.l1_loss, 'binary_crossentropy']
        loss_weights = [1E1, 1]
        DCGAN_model.compile(loss=loss,
                            loss_weights=loss_weights,
                            optimizer=opt_dcgan)

        discriminator_model.trainable = True
        discriminator_model.compile(loss='binary_crossentropy',
                                    optimizer=opt_discriminator)

        tensorboard.set_model(DCGAN_model)

        # Start training
        enq_train_gen.start(workers=1, max_queue_size=20)
        enq_train_disc.start(workers=1, max_queue_size=20)
        enq_val.start(workers=1, max_queue_size=20)
        out_train_gen = enq_train_gen.get()
        out_train_disc = enq_train_disc.get()
        out_val = enq_val.get()

        print("Start training")
        for e in range(1, epochs + 1):
            # Initialize progbar and batch counter
            progbar = generic_utils.Progbar(epoch_size)
            start = time.time()

            for batch_counter in range(1, n_batch_per_epoch + 1):
                X_transformed_batch, X_orig_batch = next(out_train_disc)

                # Create a batch to feed the discriminator model
                X_disc, y_disc = data_utils.get_disc_batch(
                    X_transformed_batch,
                    X_orig_batch,
                    generator_model,
                    batch_counter,
                    patch_size,
                    label_smoothing=label_smoothing,
                    label_flipping=label_flipping)

                # Update the discriminator
                disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

                # Create a batch to feed the generator model
                X_gen_target, X_gen = next(out_train_gen)
                y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
                # Set labels to 1 (real) to maximize the discriminator loss
                y_gen[:, 1] = 1

                # Freeze the discriminator
                discriminator_model.trainable = False
                gen_loss = DCGAN_model.train_on_batch(X_gen,
                                                      [X_gen_target, y_gen])
                # Unfreeze the discriminator
                discriminator_model.trainable = True

                metrics = [("D logloss", disc_loss), ("G tot", gen_loss[0]),
                           ("G L1", gen_loss[1]), ("G logloss", gen_loss[2])]
                progbar.add(batch_size, values=metrics)

                # Dict form for the TensorBoard callback; "size" is the
                # batch-size key Keras callbacks expect in batch logs.
                logs = {k: v for (k, v) in metrics}
                logs["size"] = batch_size

                tensorboard.on_batch_end(batch_counter, logs=logs)

                # Save images for visualization
                if batch_counter % save_images_every_n_batches == 0:
                    # Get new images from validation
                    data_utils.plot_generated_batch(
                        X_transformed_batch, X_orig_batch, generator_model,
                        os.path.join(logs_dir, "current_batch_training.png"))
                    X_transformed_batch, X_orig_batch = next(out_val)
                    data_utils.plot_generated_batch(
                        X_transformed_batch, X_orig_batch, generator_model,
                        os.path.join(logs_dir, "current_batch_validation.png"))

            print("")
            print('Epoch %s/%s, Time: %s' % (e, epochs, time.time() - start))
            # NOTE(review): `logs` carries the *last* batch's metrics and is
            # unbound if n_batch_per_epoch == 0 -- confirm the dataset is
            # never empty.
            tensorboard.on_epoch_end(e, logs=logs)

            if (save_model_every_n_epochs >= 1 and e % save_model_every_n_epochs == 0) or \
                    (e == epochs):
                print("Saving model for epoch {}...".format(e), end="")
                sys.stdout.flush()
                gen_weights_path = os.path.join(
                    model_dir, 'gen_weights_epoch{:03d}.h5'.format(e))
                generator_model.save_weights(gen_weights_path, overwrite=True)

                disc_weights_path = os.path.join(
                    model_dir, 'disc_weights_epoch{:03d}.h5'.format(e))
                discriminator_model.save_weights(disc_weights_path,
                                                 overwrite=True)

                DCGAN_weights_path = os.path.join(
                    model_dir, 'DCGAN_weights_epoch{:03d}.h5'.format(e))
                DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)
                print("done")

    except KeyboardInterrupt:
        pass

    # Always stop the multiprocessing enqueuers, even after interrupt.
    enq_train_gen.stop()
    enq_train_disc.stop()
    enq_val.stop()
Esempio n. 24
0
class ExtendedLogger(Callback):
    """Composite Keras callback: CSV + TensorBoard logging plus custom
    end-of-epoch validation metrics, raw-prediction dumps and diagnostic
    plots (top-down trajectories and absolute-error CDFs).

    The predictions taken from ``prediction_layer`` are treated as 7-dim
    poses (the code reshapes them to ``(-1, 7)`` and reads columns 3..6 as
    an xyzw quaternion).
    """

    # Backward-compatible class-level default. __init__ replaces it with a
    # per-instance copy: mutable class attributes are shared, so metrics
    # registered on one instance used to leak into every other instance.
    val_data_metrics = {}

    def __init__(self,
                 prediction_layer,
                 output_dir='./tmp',
                 stateful=False,
                 stateful_reset_interval=None,
                 starting_indicies=None):
        """
        Args:
            prediction_layer: name of the layer whose output is evaluated
                on the validation data.
            output_dir: root directory; ``csv/``, ``tensorboard/``,
                ``predictions/`` and ``plots/`` are created beneath it.
            stateful: True for stateful models; then
                ``stateful_reset_interval`` is mandatory.
            stateful_reset_interval: state-reset interval used while
                predicting with a stateful model.
            starting_indicies: dict with key ``'valid'`` holding segment
                boundaries for the trajectory plots.

        Raises:
            ValueError: if ``stateful`` is True but no reset interval is
                given.
        """

        if stateful and stateful_reset_interval is None:
            raise ValueError(
                'If model is stateful, then seq-len has to be defined!')

        super(ExtendedLogger, self).__init__()

        # Fix: give each instance its own dict instead of mutating the
        # shared class attribute.
        self.val_data_metrics = dict(self.val_data_metrics)

        self.csv_dir = os.path.join(output_dir, 'csv')
        self.tb_dir = os.path.join(output_dir, 'tensorboard')
        self.pred_dir = os.path.join(output_dir, 'predictions')
        self.plot_dir = os.path.join(output_dir, 'plots')

        make_dir(self.csv_dir)
        make_dir(self.tb_dir)
        make_dir(self.plot_dir)
        make_dir(self.pred_dir)

        self.stateful = stateful
        self.stateful_reset_interval = stateful_reset_interval
        self.starting_indicies = starting_indicies
        self.csv_logger = CSVLogger(os.path.join(self.csv_dir, 'run.csv'))
        self.tensorboard = TensorBoard(log_dir=self.tb_dir, write_graph=True)
        self.prediction_layer = prediction_layer

    # The hooks below simply fan out to the two wrapped loggers so this
    # callback can stand in for both.

    def set_params(self, params):
        super(ExtendedLogger, self).set_params(params)
        self.tensorboard.set_params(params)
        self.tensorboard.batch_size = params['batch_size']
        self.csv_logger.set_params(params)

    def set_model(self, model):
        super(ExtendedLogger, self).set_model(model)
        self.tensorboard.set_model(model)
        self.csv_logger.set_model(model)

    def on_batch_begin(self, batch, logs=None):
        self.csv_logger.on_batch_begin(batch, logs=logs)
        self.tensorboard.on_batch_begin(batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.csv_logger.on_batch_end(batch, logs=logs)
        self.tensorboard.on_batch_end(batch, logs=logs)

    def on_train_begin(self, logs=None):
        self.csv_logger.on_train_begin(logs=logs)
        self.tensorboard.on_train_begin(logs=logs)

    def on_train_end(self, logs=None):
        self.csv_logger.on_train_end(logs=logs)
        self.tensorboard.on_train_end(logs)

    def on_epoch_begin(self, epoch, logs=None):
        self.csv_logger.on_epoch_begin(epoch, logs=logs)
        self.tensorboard.on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        """Run validation predictions, compute/plot extra metrics, then
        forward the augmented ``logs`` to the wrapped loggers."""

        with timeit('metrics'):

            # Sub-model that exposes only the monitored layer's output.
            outputs = self.model.get_layer(self.prediction_layer).output
            self.prediction_model = Model(inputs=self.model.input,
                                          outputs=outputs)

            batch_size = self.params['batch_size']

            # validation_data may end with a learning-phase float; strip it
            # (together with the sample weights) when present.
            if isinstance(self.validation_data[-1], float):
                val_data = self.validation_data[:-2]
            else:
                val_data = self.validation_data[:-1]

            y_true = val_data[1]

            callback = None
            if self.stateful:
                callback = ResetStatesCallback(
                    interval=self.stateful_reset_interval)
                callback.model = self.prediction_model

            # NOTE(review): standard Keras `Model.predict` has no
            # `callback` kwarg — presumably a patched/custom Model here;
            # confirm before upgrading Keras.
            y_pred = self.prediction_model.predict(val_data[:-1],
                                                   batch_size=batch_size,
                                                   verbose=1,
                                                   callback=callback)

            print(y_true.shape, y_pred.shape)

            self.write_prediction(epoch, y_true, y_pred)

            # Flatten any sequence dimension into one pose per row.
            y_true = y_true.reshape((-1, 7))
            y_pred = y_pred.reshape((-1, 7))

            self.save_error_histograms(epoch, y_true, y_pred)
            self.save_topview_trajectories(epoch, y_true, y_pred)

            new_logs = {
                name: np.array(metric(y_true, y_pred))
                for name, metric in self.val_data_metrics.items()
            }
            logs.update(new_logs)

            homo_logs = self.try_add_homoscedastic_params()
            logs.update(homo_logs)

            self.tensorboard.validation_data = self.validation_data
            self.csv_logger.validation_data = self.validation_data

            self.tensorboard.on_epoch_end(epoch, logs=logs)
            self.csv_logger.on_epoch_end(epoch, logs=logs)

    def add_validation_metrics(self, metrics_dict):
        """Register several named validation metrics at once."""
        self.val_data_metrics.update(metrics_dict)

    def add_validation_metric(self, name, metric):
        """Register a single named validation metric callable."""
        self.val_data_metrics[name] = metric

    def try_add_homoscedastic_params(self):
        """Return the homoscedastic log-variance weights if both loss
        layers exist in the model, else an empty dict."""
        homo_pos_loss_layer = search_layer(self.model, 'homo_pos_loss')
        homo_quat_loss_layer = search_layer(self.model, 'homo_quat_loss')

        # Fix: require BOTH layers — the original only checked the position
        # layer and crashed reading the missing quaternion layer.
        if homo_pos_loss_layer and homo_quat_loss_layer:
            homo_pos_log_vars = np.array(homo_pos_loss_layer.get_weights()[0])
            homo_quat_log_vars = np.array(
                homo_quat_loss_layer.get_weights()[0])
            return {
                'pos_log_var': np.array(homo_pos_log_vars),
                'quat_log_var': np.array(homo_quat_log_vars),
            }
        else:
            return {}

    def write_prediction(self, epoch, y_true, y_pred):
        """Dump the raw epoch predictions and targets to an .npy file."""
        filename = '{:04d}_predictions.npy'.format(epoch)
        filename = os.path.join(self.pred_dir, filename)
        arr = {'y_pred': y_pred, 'y_true': y_true}
        np.save(filename, arr)

    def save_topview_trajectories(self,
                                  epoch,
                                  y_true,
                                  y_pred,
                                  max_segment=1000):
        """Plot each validation segment (and sub-segments longer than
        ``max_segment``) as a top-down trajectory."""

        if self.starting_indicies is None:
            # list(...) is required: range objects cannot be concatenated
            # with a list in Python 3 (the original raised TypeError).
            self.starting_indicies = {
                'valid': list(range(0, 4000, 1000)) + [4000]
            }

        for begin, end in pairwise(self.starting_indicies['valid']):

            diff = end - begin
            if diff > max_segment:
                subindicies = list(range(begin, end, max_segment)) + [end]
                for b, e in pairwise(subindicies):
                    self.save_trajectory(epoch, y_true, y_pred, b, e)

            self.save_trajectory(epoch, y_true, y_pred, begin, end)

    def save_trajectory(self, epoch, y_true, y_pred, begin, end):
        """Save a PDF comparing true (green) vs predicted (red) xy paths
        for rows ``begin:end``, with thin links between matched points."""
        true_xy, pred_xy = y_true[begin:end, :2], y_pred[begin:end, :2]

        # Columns are reordered xyzw -> wxyz for the quaternion package.
        true_q = quaternion.as_quat_array(y_true[begin:end, [6, 3, 4, 5]])
        true_q = quaternion.as_euler_angles(true_q)[1]

        pred_q = quaternion.as_quat_array(y_pred[begin:end, [6, 3, 4, 5]])
        pred_q = quaternion.as_euler_angles(pred_q)[1]

        plt.clf()

        plt.plot(true_xy[:, 0], true_xy[:, 1], 'g-')
        plt.plot(pred_xy[:, 0], pred_xy[:, 1], 'r-')

        # Link every ground-truth point to its prediction.
        for ((x1, y1), (x2, y2)) in zip(true_xy, pred_xy):
            plt.plot([x1, x2], [y1, y2],
                     color='k',
                     linestyle='-',
                     linewidth=0.3,
                     alpha=0.2)

        plt.grid(True)
        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.title('Top-down view of trajectory')
        plt.axis('equal')

        x_range = (np.min(true_xy[:, 0]) - .2, np.max(true_xy[:, 0]) + .2)
        y_range = (np.min(true_xy[:, 1]) - .2, np.max(true_xy[:, 1]) + .2)

        plt.xlim(x_range)
        plt.ylim(y_range)

        filename = 'epoch={epoch:04d}_begin={begin:04d}_end={end:04d}_trajectory.pdf' \
          .format(epoch=epoch, begin=begin, end=end)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)

    def save_error_histograms(self, epoch, y_true, y_pred):
        """Save the empirical CDFs of absolute position and angle errors."""
        pos_errors = PoseMetrics.abs_errors_position(y_true, y_pred)
        pos_errors = np.sort(pos_errors)

        angle_errors = PoseMetrics.abs_errors_orienation(y_true, y_pred)
        angle_errors = np.sort(angle_errors)

        size = len(y_true)
        ys = np.arange(size) / float(size)

        plt.clf()

        plt.subplot(2, 1, 1)
        plt.title('Empirical CDF of absolute errors')
        plt.grid(True)
        plt.plot(pos_errors, ys, 'k-')
        plt.xlabel('Absolute Position Error (m)')
        plt.xlim(0, 1.2)

        plt.subplot(2, 1, 2)
        plt.grid(True)
        plt.plot(angle_errors, ys, 'r-')
        plt.xlabel('Absolute Angle Error (deg)')
        plt.xlim(0, 70)

        filename = '{:04d}_cdf.pdf'.format(epoch)
        filename = os.path.join(self.plot_dir, filename)
        plt.savefig(filename)
Esempio n. 25
0
    def train(self, epochs, batch_size=1, sample_interval=50):
        """Run the adversarial training loop for ``epochs`` iterations.

        Each iteration trains the discriminator on one real and one
        generated batch, then trains the combined generator model; losses
        go to TensorBoard and samples/weights are saved every
        ``sample_interval`` epochs.
        """

        tb_logger = TensorBoard(batch_size=batch_size, write_grads=True)
        tb_logger.set_model(self.combined)

        def named_logs(model, logs):
            # Pair each metric name with its value for TensorBoard.
            return dict(zip(model.metrics_names, logs))

        training_started_at = datetime.datetime.now()

        for epoch in range(epochs):

            # ----------------------
            #  Train Discriminator
            # ----------------------

            # Sample images and their conditioning counterparts.
            real_hr, real_lr = self.data_loader.load_data(batch_size)

            # Generate high-res candidates from the low-res inputs.
            generated_hr = self.generator.predict(real_lr)

            real_labels = np.ones((batch_size, ) + self.disc_patch)
            fake_labels = np.zeros((batch_size, ) + self.disc_patch)

            # Real images should be classified real, generated ones fake.
            loss_on_real = self.discriminator.train_on_batch(real_hr,
                                                             real_labels)
            loss_on_fake = self.discriminator.train_on_batch(generated_hr,
                                                             fake_labels)
            d_loss = 0.5 * np.add(loss_on_real, loss_on_fake)

            # ------------------
            #  Train Generator
            # ------------------

            # Fresh batch for the generator step.
            real_hr, real_lr = self.data_loader.load_data(batch_size)

            # The generator wants its outputs labelled as real.
            real_labels = np.ones((batch_size, ) + self.disc_patch)

            # Ground-truth deep features from the pre-trained VGG19.
            target_features = self.vgg.predict(real_hr)

            g_loss = self.combined.train_on_batch(
                [real_lr, real_hr], [real_labels, target_features])

            elapsed = datetime.datetime.now() - training_started_at

            # Report progress for this epoch.
            progress = (
                "[Epoch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f] time: %s "
                % (epoch, epochs, d_loss[0], 100 * d_loss[1], g_loss[0],
                   elapsed))
            print(progress)

            tb_logger.on_epoch_end(epoch, named_logs(self.combined, g_loss))

            # Periodically save image samples and model weights.
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
                self.combined.save_weights(
                    f"saved_model/{self.dataset_name}/{epoch}.h5")
Esempio n. 26
0
class LogWriter():
    """Collects all artefacts of one training run under a timestamped
    directory: TensorBoard scalars, CSV logs, saved models and movies.

    Directory layout: ``<root_dir>/<yymmdd_HHMMSS>/{logs,csv,models,movies}``.
    Paths are normalized to forward slashes for Windows compatibility.
    """

    def __init__(self,
                 root_dir,
                 batch_size,
                 histogram_freq=0,
                 write_graph=True,
                 write_grads=False):
        """Create the run directories and the TensorBoard writer.

        Args:
            root_dir: parent folder for all runs.
            batch_size: batch size forwarded to the TensorBoard callback.
            histogram_freq, write_graph, write_grads: TensorBoard options.
        """

        self.root_dir = root_dir

        # Wall-clock start, used by log_total_time_cost().
        self.start_time = time.time()

        # os.makedirs(..., exist_ok=True) is race-free and creates missing
        # parents, unlike the bare os.mkdir it replaces.
        if not os.path.exists(self.root_dir):
            print('*** Create folder: {} ***'.format(self.root_dir))
        os.makedirs(self.root_dir, exist_ok=True)

        now_time = time.strftime('%y%m%d_%H%M%S', time.localtime())
        self.root_dir_with_datetime = os.path.join(self.root_dir,
                                                   now_time).replace(
                                                       '\\', '/')
        if not os.path.exists(self.root_dir_with_datetime):
            print('*** Create folder: {} ***'.format(
                self.root_dir_with_datetime))
        os.makedirs(self.root_dir_with_datetime, exist_ok=True)

        for subdir in ("logs", "csv", "models", "movies"):
            os.makedirs(
                os.path.join(self.root_dir_with_datetime,
                             subdir).replace('\\', '/'),
                exist_ok=True)

        self.tb = TensorBoard(log_dir=os.path.join(self.root_dir_with_datetime,
                                                   "logs").replace('\\', '/'),
                              histogram_freq=histogram_freq,
                              batch_size=batch_size,
                              write_graph=write_graph,
                              write_grads=write_grads)

        # count batch
        self.batch_id = 0

        # running maximum of episode rewards (starts very low)
        self.max_reward = -1e4

        # count iteration
        self.iteration = 1

    def get_movie_pass(self):
        """Return the movies directory path (name kept for compatibility)."""
        return os.path.join(self.root_dir_with_datetime,
                            "movies").replace('\\', '/')

    def add_loss(self, losses):
        """Log one batch of losses to TensorBoard and to csv/loss.csv.

        Requires set_loss_name() to have been called first, otherwise
        ``self.loss_names`` does not exist.
        """
        # log losses into tensorboard
        for loss, name in zip(losses, self.loss_names):
            summary = tf.Summary()
            summary.value.add(tag=name, simple_value=loss)
            self.tb.writer.add_summary(summary, self.batch_id)
            self.tb.writer.flush()
            self.tb.on_epoch_end(self.batch_id)

        # log losses into csv
        with open(os.path.join(self.root_dir_with_datetime, 'csv',
                               'loss.csv').replace('\\', '/'),
                  'a',
                  newline='') as f:
            writer = csv.writer(f)
            writer.writerow([self.batch_id, *losses])

        self.batch_id += 1

    def set_loss_name(self, names):
        """ set the first row for loss.csv """
        self.loss_names = names
        with open(os.path.join(self.root_dir_with_datetime, 'csv',
                               'loss.csv').replace('\\', '/'),
                  'w',
                  newline='') as f:
            writer = csv.writer(f)
            writer.writerow(self.loss_names)

    def add_reward(self, episode, reward, info=None):
        """Record episode_reward and max_episode_reward.

        ``info`` must contain a 'steps' entry when provided; a mutable
        default dict was replaced with None to avoid cross-call sharing.
        """
        if info is None:
            info = {}

        # Standard output
        print(episode, ":", reward, end="")

        for key in info.keys():
            print(", ", key, ":", info[key], end="")

        print()

        # log the max_episode_reward
        if self.max_reward < reward:
            self.max_reward = reward

        with open(os.path.join(self.root_dir_with_datetime, 'csv',
                               'max_reward.csv').replace('\\', '/'),
                  'a',
                  newline='') as f:
            writer = csv.writer(f)
            summary = tf.Summary()
            summary.value.add(tag="max_episode_reward",
                              simple_value=self.max_reward)
            # Back-fill the max reward for every iteration of this episode.
            for i in range(info['steps']):
                iteration = self.iteration - info['steps'] + i
                self.tb.writer.add_summary(summary, iteration)
                writer.writerow((iteration, self.max_reward))
            self.tb.writer.flush()

        # log episode_reward
        with open(os.path.join(self.root_dir_with_datetime, 'csv',
                               'reward.csv').replace('\\', '/'),
                  'a',
                  newline='') as f:

            # log episode_reward into tensorboard
            summary = tf.Summary()
            summary.value.add(tag="episode_reward",
                              simple_value=reward)  # change
            self.tb.writer.add_summary(summary, episode)
            self.tb.writer.flush()

            # log episode_reward into csv
            writer = csv.writer(f)
            writer.writerow((episode, reward))

    def count_iteration(self):
        """ count the iteration """
        self.iteration += 1

    def save_weights(self, agent, info=''):
        """Delegate weight saving to the agent, into the models folder."""
        agent.save_weights(
            os.path.join(self.root_dir_with_datetime,
                         'models').replace('\\', '/'), info)

    def save_model_arch(self, agent):
        """Delegate architecture saving to the agent."""
        agent.save_model_arch(
            os.path.join(self.root_dir_with_datetime,
                         'models').replace('\\', '/'))

    def save_evaluate_rewards(self, evaluate_rewards):
        """Write evaluation rewards to CSV, average first, then one per row."""
        with open(os.path.join(self.root_dir_with_datetime,
                               'evaluate_rewards.csv').replace('\\', '/'),
                  'w',
                  newline='') as f:
            writer = csv.writer(f)

            # write the avg_eps_reward at the first line
            avg_eps_reward = sum(evaluate_rewards) / len(evaluate_rewards)

            writer.writerow((avg_eps_reward, ))
            for reward in evaluate_rewards:
                writer.writerow((reward, ))

    def set_model(self, model):
        """Attach the Keras model to the TensorBoard callback."""
        self.tb.set_model(model)

    def save_setting(self, args):
        """Dump all argparse settings as key,value rows in setting.csv."""
        with open(os.path.join(self.root_dir_with_datetime,
                               'setting.csv').replace('\\', '/'),
                  'w',
                  newline='') as f:
            writer = csv.writer(f)
            for k, v in vars(args).items():
                writer.writerow((k, v))
                print(k, v)

    def log_total_time_cost(self):
        """ Call it at the end of the code """
        with open(os.path.join(self.root_dir_with_datetime,
                               'setting.csv').replace('\\', '/'),
                  'a',
                  newline='') as f:
            writer = csv.writer(f)
            writer.writerow(('total_time_cost', time.time() - self.start_time))
            print('*** total_time_cost:{} ***'.format(time.time() -
                                                      self.start_time))

    def store_memories(self, agent):
        """Ask the agent to persist replay memories, if it stores any."""
        if agent.memory_storation_size:
            agent.store_memories(self.root_dir_with_datetime)
Esempio n. 27
0
class Trainer:
    """Class object to setup and carry the training.

    Takes as input a generator that produces SR images.
    Conditionally, also a discriminator network and a feature extractor
        to build the components of the perceptual loss.
    Compiles the model(s) and trains in a GANS fashion if a discriminator is provided, otherwise
    carries a regular ISR training.

    Args:
        generator: Keras model, the super-scaling, or generator, network.
        discriminator: Keras model, the discriminator network for the adversarial
            component of the perceptual loss.
        feature_extractor: Keras model, feature extractor network for the deep features
            component of perceptual loss function.
        lr_train_dir: path to the directory containing the Low-Res images for training.
        hr_train_dir: path to the directory containing the High-Res images for training.
        lr_valid_dir: path to the directory containing the Low-Res images for validation.
        hr_valid_dir: path to the directory containing the High-Res images for validation.
        learning_rate: float.
        loss_weights: dictionary, use to weigh the components of the loss function.
            Contains 'generator' for the generator loss component, and can contain 'discriminator' and 'feature_extractor'
            for the discriminator and deep features components respectively.
        logs_dir: path to the directory where the tensorboard logs are saved.
        weights_dir: path to the directory where the weights are saved.
        dataname: string, used to identify what dataset is used for the training session.
        weights_generator: path to the pre-trained generator's weights, for transfer learning.
        weights_discriminator: path to the pre-trained discriminator's weights, for transfer learning.
        n_validation:integer, number of validation samples used at training from the validation set.
        flatness: dictionary. Determines the 'flatness' threshold level for the training patches.
            See the TrainerHelper class for more details.
        lr_decay_frequency: integer, every how many epochs the learning rate is reduced.
        lr_decay_factor: 0 < float <1, learning rate reduction multiplicative factor.

    Methods:
        train: combines the networks and triggers training with the specified settings.

    """
    def __init__(
        self,
        generator,
        discriminator,
        feature_extractor,
        lr_train_dir,
        hr_train_dir,
        lr_valid_dir,
        hr_valid_dir,
        loss_weights={
            'generator': 1.0,
            'discriminator': 0.003,
            'feature_extractor': 1 / 12
        },
        log_dirs={
            'logs': 'logs',
            'weights': 'weights'
        },
        fallback_save_every_n_epochs=2,
        dataname=None,
        weights_generator=None,
        weights_discriminator=None,
        n_validation=None,
        flatness={
            'min': 0.0,
            'increase_frequency': None,
            'increase': 0.0,
            'max': 0.0
        },
        learning_rate={
            'initial_value': 0.0004,
            'decay_frequency': 100,
            'decay_factor': 0.5
        },
        adam_optimizer={
            'beta1': 0.9,
            'beta2': 0.999,
            'epsilon': None
        },
        losses={
            'generator': 'mae',
            'discriminator': 'binary_crossentropy',
            'feature_extractor': 'mse',
        },
        metrics={'generator': 'PSNR_Y'},
    ):
        self.generator = generator
        self.discriminator = discriminator
        self.feature_extractor = feature_extractor
        self.scale = generator.scale
        self.lr_patch_size = generator.patch_size
        self.learning_rate = learning_rate
        self.loss_weights = loss_weights
        self.weights_generator = weights_generator
        self.weights_discriminator = weights_discriminator
        self.adam_optimizer = adam_optimizer
        self.dataname = dataname
        self.flatness = flatness
        self.n_validation = n_validation
        self.losses = losses
        self.log_dirs = log_dirs
        # Copy before mutating: `metrics` is a shared mutable default
        # argument, and replacing the 'generator' string with a function
        # below would otherwise leak into every later Trainer instance.
        self.metrics = dict(metrics)
        if self.metrics['generator'] == 'PSNR_Y':
            self.metrics['generator'] = PSNR_Y
        elif self.metrics['generator'] == 'PSNR':
            self.metrics['generator'] = PSNR
        self._parameters_sanity_check()
        self.model = self._combine_networks()

        # Snapshot of the constructor arguments for the session summary
        # (update_training_config pops the non-serializable entries).
        self.settings = {}
        self.settings['training_parameters'] = locals()
        self.settings['training_parameters'][
            'lr_patch_size'] = self.lr_patch_size
        self.settings = self.update_training_config(self.settings)

        self.logger = get_logger(__name__)

        self.helper = TrainerHelper(
            generator=self.generator,
            weights_dir=log_dirs['weights'],
            logs_dir=log_dirs['logs'],
            lr_train_dir=lr_train_dir,
            feature_extractor=self.feature_extractor,
            discriminator=self.discriminator,
            dataname=dataname,
            weights_generator=self.weights_generator,
            weights_discriminator=self.weights_discriminator,
            fallback_save_every_n_epochs=fallback_save_every_n_epochs,
        )

        self.train_dh = DataHandler(
            lr_dir=lr_train_dir,
            hr_dir=hr_train_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=None,
        )
        self.valid_dh = DataHandler(
            lr_dir=lr_valid_dir,
            hr_dir=hr_valid_dir,
            patch_size=self.lr_patch_size,
            scale=self.scale,
            n_validation_samples=n_validation,
        )

    def _parameters_sanity_check(self):
        """ Parameteres sanity check. """

        if self.discriminator:
            assert self.lr_patch_size * self.scale == self.discriminator.patch_size
        if self.feature_extractor:
            assert self.lr_patch_size * self.scale == self.feature_extractor.patch_size

        check_parameter_keys(
            self.learning_rate,
            needed_keys=['initial_value'],
            optional_keys=['decay_factor', 'decay_frequency'],
            default_value=None,
        )
        check_parameter_keys(
            self.flatness,
            needed_keys=[],
            optional_keys=['min', 'increase_frequency', 'increase', 'max'],
            default_value=0.0,
        )
        check_parameter_keys(
            self.adam_optimizer,
            needed_keys=['beta1', 'beta2'],
            optional_keys=['epsilon'],
            default_value=None,
        )
        check_parameter_keys(self.log_dirs, needed_keys=['logs', 'weights'])

    def _combine_networks(self):
        """
        Constructs the combined model which contains the generator network,
        as well as discriminator and feature extractor, if any are defined.
        """

        lr = Input(shape=(self.lr_patch_size, ) * 2 + (3, ))
        sr = self.generator.model(lr)
        outputs = [sr]
        losses = [self.losses['generator']]
        loss_weights = [self.loss_weights['generator']]

        if self.discriminator:
            # Discriminator is frozen inside the combined model; it is
            # trained separately in train().
            self.discriminator.model.trainable = False
            validity = self.discriminator.model(sr)
            outputs.append(validity)
            losses.append(self.losses['discriminator'])
            loss_weights.append(self.loss_weights['discriminator'])
        if self.feature_extractor:
            self.feature_extractor.model.trainable = False
            sr_feats = self.feature_extractor.model(sr)
            outputs.extend([*sr_feats])
            losses.extend([self.losses['feature_extractor']] * len(sr_feats))
            # Spread the feature-extractor weight evenly over its outputs.
            loss_weights.extend(
                [self.loss_weights['feature_extractor'] / len(sr_feats)] *
                len(sr_feats))
        combined = Model(inputs=lr, outputs=outputs)
        # https://stackoverflow.com/questions/42327543/adam-optimizer-goes-haywire-after-200k-batches-training-loss-grows
        optimizer = Adam(
            beta_1=self.adam_optimizer['beta1'],
            beta_2=self.adam_optimizer['beta2'],
            lr=self.learning_rate['initial_value'],
            epsilon=self.adam_optimizer['epsilon'],
        )
        combined.compile(loss=losses,
                         loss_weights=loss_weights,
                         optimizer=optimizer,
                         metrics=self.metrics)
        return combined

    def _lr_scheduler(self, epoch):
        """ Scheduler for the learning rate updates. """

        n_decays = epoch // self.learning_rate['decay_frequency']
        lr = self.learning_rate['initial_value'] * (
            self.learning_rate['decay_factor']**n_decays)
        # no lr below minimum control 10e-7
        return max(1e-7, lr)

    def _flatness_scheduler(self, epoch):
        """Return the flatness threshold for this epoch (step schedule)."""
        if self.flatness['increase']:
            n_increases = epoch // self.flatness['increase_frequency']
        else:
            return self.flatness['min']

        f = self.flatness['min'] + n_increases * self.flatness['increase']

        return min(self.flatness['max'], f)

    def _load_weights(self):
        """
        Loads the pretrained weights from the given path, if any is provided.
        If a discriminator is defined, does the same.
        """

        if self.weights_generator:
            self.model.get_layer('generator').load_weights(
                self.weights_generator)

        if self.discriminator:
            if self.weights_discriminator:
                self.model.get_layer('discriminator').load_weights(
                    self.weights_discriminator)
                self.discriminator.model.load_weights(
                    self.weights_discriminator)

    def _format_losses(self, prefix, losses, model_metrics):
        """ Creates a dictionary for tensorboard tracking. """

        return dict(zip([prefix + m for m in model_metrics], losses))

    def update_training_config(self, settings):
        """ Summarizes training setting. """

        _ = settings['training_parameters'].pop('weights_generator')
        _ = settings['training_parameters'].pop('self')
        _ = settings['training_parameters'].pop('generator')
        _ = settings['training_parameters'].pop('discriminator')
        _ = settings['training_parameters'].pop('feature_extractor')
        settings['generator'] = {}
        settings['generator']['name'] = self.generator.name
        settings['generator']['parameters'] = self.generator.params
        settings['generator']['weights_generator'] = self.weights_generator

        _ = settings['training_parameters'].pop('weights_discriminator')
        if self.discriminator:
            settings['discriminator'] = {}
            settings['discriminator']['name'] = self.discriminator.name
            settings['discriminator'][
                'weights_discriminator'] = self.weights_discriminator
        else:
            settings['discriminator'] = None

        # Fix: this block must be guarded by feature_extractor, not
        # discriminator — the original crashed when a discriminator was
        # given without a feature extractor, and recorded None otherwise.
        if self.feature_extractor:
            settings['feature_extractor'] = {}
            settings['feature_extractor']['name'] = self.feature_extractor.name
            settings['feature_extractor'][
                'layers'] = self.feature_extractor.layers_to_extract
        else:
            settings['feature_extractor'] = None

        return settings

    def train(self, epochs, steps_per_epoch, batch_size, monitored_metrics):
        """
        Carries on the training for the given number of epochs.
        Sends the losses to Tensorboard.

        Args:
            epochs: how many epochs to train for.
            steps_per_epoch: how many batches per epoch.
            batch_size: amount of images per batch.
            monitored_metrics: dictionary, the keys are the metrics that are monitored for the weights
                saving logic. The values are the mode that trigger the weights saving ('min' vs 'max').
        """

        self.settings['training_parameters'][
            'steps_per_epoch'] = steps_per_epoch
        self.settings['training_parameters']['batch_size'] = batch_size
        starting_epoch = self.helper.initialize_training(
            self)  # load_weights, creates folders, creates basename

        self.tensorboard = TensorBoard(
            log_dir=self.helper.callback_paths['logs'])
        self.tensorboard.set_model(self.model)

        # validation data
        validation_set = self.valid_dh.get_validation_set(batch_size)
        y_validation = [validation_set['hr']]
        if self.discriminator:
            discr_out_shape = list(
                self.discriminator.model.outputs[0].shape)[1:4]
            valid = np.ones([batch_size] + discr_out_shape)
            fake = np.zeros([batch_size] + discr_out_shape)
            validation_valid = np.ones([len(validation_set['hr'])] +
                                       discr_out_shape)
            y_validation.append(validation_valid)
        if self.feature_extractor:
            validation_feats = self.feature_extractor.model.predict(
                validation_set['hr'])
            y_validation.extend([*validation_feats])

        for epoch in range(starting_epoch, epochs):
            self.logger.info('Epoch {e}/{tot_eps}'.format(e=epoch,
                                                          tot_eps=epochs))
            K.set_value(self.model.optimizer.lr,
                        self._lr_scheduler(epoch=epoch))
            self.logger.info('Current learning rate: {}'.format(
                K.eval(self.model.optimizer.lr)))

            flatness = self._flatness_scheduler(epoch)
            if flatness:
                self.logger.info(
                    'Current flatness treshold: {}'.format(flatness))

            epoch_start = time()
            for step in tqdm(range(steps_per_epoch)):
                batch = self.train_dh.get_batch(batch_size, flatness=flatness)
                y_train = [batch['hr']]
                training_losses = {}

                ## Discriminator training
                if self.discriminator:
                    sr = self.generator.model.predict(batch['lr'])
                    d_loss_real = self.discriminator.model.train_on_batch(
                        batch['hr'], valid)
                    d_loss_fake = self.discriminator.model.train_on_batch(
                        sr, fake)
                    d_loss_fake = self._format_losses(
                        'train_d_fake_', d_loss_fake,
                        self.discriminator.model.metrics_names)
                    d_loss_real = self._format_losses(
                        'train_d_real_', d_loss_real,
                        self.discriminator.model.metrics_names)
                    training_losses.update(d_loss_real)
                    training_losses.update(d_loss_fake)
                    y_train.append(valid)

                ## Generator training
                if self.feature_extractor:
                    hr_feats = self.feature_extractor.model.predict(
                        batch['hr'])
                    y_train.extend([*hr_feats])

                model_losses = self.model.train_on_batch(batch['lr'], y_train)
                model_losses = self._format_losses('train_', model_losses,
                                                   self.model.metrics_names)
                training_losses.update(model_losses)

                # Per-step scalars, indexed by global step.
                self.tensorboard.on_epoch_end(epoch * steps_per_epoch + step,
                                              training_losses)
                self.logger.debug('Losses at step {s}:\n {l}'.format(
                    s=step, l=training_losses))

            elapsed_time = time() - epoch_start
            self.logger.info('Epoch {} took {:10.1f}s'.format(
                epoch, elapsed_time))

            validation_losses = self.model.evaluate(validation_set['lr'],
                                                    y_validation,
                                                    batch_size=batch_size)
            validation_losses = self._format_losses('val_', validation_losses,
                                                    self.model.metrics_names)

            # On the first epoch, drop monitored metrics the model does not
            # actually produce, so the saving logic never checks them.
            if epoch == starting_epoch:
                remove_metrics = []
                for metric in monitored_metrics:
                    if (metric not in training_losses) and (
                            metric not in validation_losses):
                        msg = ' '.join([
                            metric,
                            'is NOT among the model metrics, removing it.'
                        ])
                        self.logger.error(msg)
                        remove_metrics.append(metric)
                for metric in remove_metrics:
                    _ = monitored_metrics.pop(metric)

            # should average train metrics
            end_losses = {}
            end_losses.update(validation_losses)
            end_losses.update(training_losses)

            self.helper.on_epoch_end(
                epoch=epoch,
                losses=end_losses,
                generator=self.model.get_layer('generator'),
                discriminator=self.discriminator,
                metrics=monitored_metrics,
            )
            self.tensorboard.on_epoch_end(epoch, validation_losses)

        self.tensorboard.on_train_end(None)
Esempio n. 28
0
class NLITaskTrain(object):
    """Training driver for an NLI (natural language inference) model.

    Holds a Keras-style model together with train/dev/test splits and a
    schedule of optimizers, and drives batched training where the active
    optimizer is switched whenever the dev loss stops improving.
    """

    def __init__(self,
                 model,
                 train_data,
                 test_data,
                 dev_data=None,
                 optimizer=None,
                 log_dir=None,
                 save_dir=None,
                 name=None):
        """Set up data splits, optimizer schedule state and logging dirs.

        :param model: object exposing ``compile`` / ``train_on_batch`` /
            ``evaluate`` / ``save`` / ``summary`` (Keras Model-like)
        :param train_data: sequence whose LAST element is the label array and
            whose preceding elements are the input arrays
        :param test_data: held-out inputs, stored as given
        :param dev_data: optional sequence shaped like ``train_data``; used
            for per-epoch evaluation and optimizer switching
        :param optimizer: for :meth:`train_multi_optimizer`, an indexable
            iterable of ``(optimizer, switch_steps)`` pairs
        :param log_dir: TensorBoard log directory (created if missing)
        :param save_dir: checkpoint directory (created if missing)
        :param name: optional run name (informational only)
        """
        self.model = model
        self.name = name

        # Data: by convention the label array is the last element.
        self.train_label = train_data[-1]
        self.train_data = train_data[:-1]
        self.test_data = test_data
        self.dev_data = dev_data
        if self.dev_data is not None:
            self.dev_label = self.dev_data[-1]
            self.dev_data = self.dev_data[:-1]

        # Optimizer schedule state; populated by init_optimizer().
        self.optimizer = optimizer
        self.current_optimizer = None
        self.current_optimizer_id = -1
        self.current_switch_steps = -1

        # Logging / checkpointing.
        self.log_dir = log_dir
        if self.log_dir is not None and not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)
        self.logger = TensorBoard(log_dir=self.log_dir)

        self.save_dir = save_dir
        if self.save_dir is not None and not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

    def train(self, batch_size=128, eval_interval=512, shuffle=True):
        """Single-optimizer training; intentionally a no-op placeholder."""
        return

    def _shuffle_training_set(self):
        """Apply one random permutation to train inputs and labels in sync."""
        random_index = np.random.permutation(len(self.train_label))
        self.train_data = [data[random_index] for data in self.train_data]
        self.train_label = self.train_label[random_index]

    def _apply_current_optimizer(self):
        """Recompile the model with the active optimizer and rebind logging."""
        self.model.compile(optimizer=self.current_optimizer,
                           loss="binary_crossentropy",
                           metrics=["acc"])
        self.logger.set_model(self.model)
        logger.info("Switch to {} optimizer".format(self.current_optimizer))

    def train_multi_optimizer(self,
                              batch_size=128,
                              eval_interval=512,
                              shuffle=True):
        """Train with a schedule of optimizers.

        Runs ``eval_interval`` training batches per "epoch", evaluates on the
        dev set, checkpoints the model, and switches to the next optimizer
        after ``switch_steps`` epochs without dev-loss improvement.
        Terminates (via SystemExit) when the schedule is exhausted.

        :param int batch_size: samples per training batch
        :param int eval_interval: training batches between dev evaluations
        :param bool shuffle: reshuffle the training set each pass
        """
        assert isinstance(self.optimizer, Iterable)
        assert len(self.optimizer) > 1

        self.current_optimizer = None
        self.current_optimizer_id = -1
        self.current_switch_steps = -1

        self.init_optimizer()
        self.model.summary()

        train_steps, no_progress_steps, epoch = 0, 0, 0
        train_batch_start = 0
        best_loss = np.inf

        while True:
            if shuffle:
                self._shuffle_training_set()

            dev_loss, dev_acc = self.evaluate(batch_size=batch_size)
            self.logger.on_epoch_end(epoch=epoch,
                                     logs={
                                         "dev_loss": dev_loss,
                                         "dev_acc": dev_acc
                                     })
            # FIX: use os.path.join instead of raw string concatenation,
            # which silently wrote "<save_dir>epoch..." without a separator.
            self.model.save(
                os.path.join(
                    self.save_dir,
                    "epoch{}-loss{}-acc{}.model".format(
                        epoch, dev_loss, dev_acc)))
            epoch += 1
            no_progress_steps += 1

            if dev_loss < best_loss:
                best_loss = dev_loss
                no_progress_steps = 0

            # Switch optimizer once the dev loss has stalled long enough.
            if no_progress_steps > self.current_switch_steps:
                self.switch_optimizer()
                no_progress_steps = 0

            for _ in range(eval_interval):
                batch = slice(train_batch_start,
                              train_batch_start + batch_size)
                train_loss, train_acc = self.model.train_on_batch(
                    [data[batch] for data in self.train_data],
                    self.train_label[batch])
                self.logger.on_batch_end(train_steps,
                                         logs={
                                             "train_loss": train_loss,
                                             "train_acc": train_acc
                                         })

                train_steps += 1
                train_batch_start += batch_size
                # FIX: '>=' instead of '>' — with '>', a dataset whose size is
                # an exact multiple of batch_size produced one empty batch.
                if train_batch_start >= len(self.train_label):
                    train_batch_start = 0
                    if shuffle:
                        self._shuffle_training_set()

    def init_optimizer(self):
        """Activate the first (optimizer, switch_steps) pair of the schedule."""
        self.current_optimizer_id = 0
        self.current_optimizer, self.current_switch_steps = self.optimizer[
            self.current_optimizer_id]
        self._apply_current_optimizer()

    def evaluate(self, X=None, y=None, batch_size=None):
        """Evaluate on (X, y), defaulting to the dev split.

        :returns: ``(loss, acc)`` tuple from ``model.evaluate``
        """
        if X is None:
            X, y = self.dev_data, self.dev_label

        loss, acc = self.model.evaluate(X, y, batch_size=batch_size)
        return loss, acc

    def switch_optimizer(self):
        """Advance to the next optimizer; stop training when none remain."""
        self.current_optimizer_id += 1
        if self.current_optimizer_id >= len(self.optimizer):
            logger.info("Training processes finished")
            # FIX: raise SystemExit directly; exit() comes from the `site`
            # module and is not guaranteed to exist (e.g. under python -S).
            raise SystemExit(0)

        self.current_optimizer, self.current_switch_steps = self.optimizer[
            self.current_optimizer_id]
        self._apply_current_optimizer()
Esempio n. 29
0
    def train_srgan(self,
                    epochs,
                    batch_size,
                    dataname,
                    datapath_train,
                    datapath_validation=None,
                    steps_per_epoch=10000,
                    steps_per_validation=100,
                    datapath_test=None,
                    workers=16,
                    max_queue_size=10,
                    first_step=0,
                    print_frequency=1,
                    crops_per_image=4,
                    log_weight_frequency=None,
                    log_weight_path='./data/weights/',
                    log_tensorboard_path='./data/logs/',
                    log_tensorboard_name='SRGAN',
                    log_tensorboard_update_freq=10000,
                    log_test_frequency=1,
                    log_test_path="./images/samples/",
                    job_dir=None):
        """Train the SRGAN network

        :param int epochs: how many epochs to train the network for
        :param int batch_size: images per training batch
        :param str dataname: name to use for storing model weights etc.
        :param str datapath_train: path for the image files to use for training
        :param str datapath_validation: optional path of validation images
        :param int steps_per_epoch: training steps that make up one "epoch"
        :param int steps_per_validation: generator steps per validation run
        :param str datapath_test: path for the image files to use for testing / plotting
        :param int workers: CPU worker processes preparing batches
        :param int max_queue_size: max number of prefetched batches
        :param int first_step: offset added to the tensorboard step counter
        :param int print_frequency: how often (in epochs) to print progress to
            terminal. Warning: will run validation inference!
        :param int crops_per_image: random crops taken from each source image
        :param int log_weight_frequency: how often (in epochs) should network weights be saved. None for never
        :param str log_weight_path: where should network weights be saved
        :param str log_tensorboard_path: where should tensorflow logs be sent
        :param str log_tensorboard_name: what folder should tf logs be saved under
        :param int log_tensorboard_update_freq: tensorboard update frequency
        :param int log_test_frequency: how often (in epochs) should testing & validation be performed
        :param str log_test_path: where should test results be saved
        :param job_dir: unused here; kept for interface compatibility
        """

        # Create train data loader
        loader = DataLoader(datapath_train, batch_size, self.height_hr,
                            self.width_hr, self.upscaling_factor,
                            crops_per_image)

        # Validation data loader (only built when a validation path is given)
        if datapath_validation is not None:
            validation_loader = DataLoader(datapath_validation, batch_size,
                                           self.height_hr, self.width_hr,
                                           self.upscaling_factor,
                                           crops_per_image)

        # Use several workers on CPU for preparing batches
        enqueuer = OrderedEnqueuer(loader,
                                   use_multiprocessing=True,
                                   shuffle=True)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()

        # Callback: tensorboard.
        # FIX: initialize to None so the epoch-end flush below can be safely
        # skipped when logging is disabled; previously a NameError was raised
        # on the first epoch change if log_tensorboard_path was falsy.
        tensorboard = None
        if log_tensorboard_path:
            tensorboard = TensorBoard(log_dir=os.path.join(
                log_tensorboard_path, log_tensorboard_name),
                                      histogram_freq=0,
                                      batch_size=batch_size,
                                      write_graph=False,
                                      write_grads=False,
                                      update_freq=log_tensorboard_update_freq)
            tensorboard.set_model(self.srgan)
        else:
            print(
                ">> Not logging to tensorboard since no log_tensorboard_path is set"
            )

        # Callback: format input value
        def named_logs(model, logs):
            """Transform train_on_batch return value to dict expected by on_batch_end callback"""
            return dict(zip(model.metrics_names, logs))

        # Shape of output from discriminator, with the batch axis pinned to
        # batch_size so the target arrays match exactly.
        discriminator_output_shape = list(self.discriminator.output_shape)
        discriminator_output_shape[0] = batch_size
        discriminator_output_shape = tuple(discriminator_output_shape)

        # VALID / FAKE targets for discriminator
        real = np.ones(discriminator_output_shape)
        fake = np.zeros(discriminator_output_shape)

        # Each epoch == "update iteration" as defined in the paper
        print_losses = {"G": [], "D": []}
        start_epoch = datetime.datetime.now()

        # Some dummy variables to track
        current_epoch = 1
        logs = {}

        # Loop through epochs / iterations
        for step in range(0, steps_per_epoch * int(epochs)):

            # Epoch change e.g. steps_per_epoch steps completed
            epoch_change = False
            if step > current_epoch * steps_per_epoch:
                current_epoch += 1
                epoch_change = True

            # Start epoch time
            if epoch_change:
                start_epoch = datetime.datetime.now()

            # Train discriminator on one real and one generated batch
            imgs_lr, imgs_hr = next(output_generator)
            generated_hr = self.generator.predict(imgs_lr)
            real_loss = self.discriminator.train_on_batch(imgs_hr, real)
            fake_loss = self.discriminator.train_on_batch(generated_hr, fake)
            discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

            # Train generator (adversarial target + VGG feature target)
            features_hr = self.vgg.predict(self.preprocess_vgg(imgs_hr))
            generator_loss = self.srgan.train_on_batch(imgs_lr,
                                                       [real, features_hr])

            # Callbacks: accumulate generator losses within an epoch and
            # flush the running totals to tensorboard on epoch change.
            # NOTE(review): the flushed values are per-epoch SUMS, not
            # averages — confirm this matches the intended dashboard scale.
            if logs and not epoch_change:
                for k, v in named_logs(self.srgan, generator_loss).items():
                    logs[k] += v
            else:
                if logs and epoch_change and tensorboard is not None:
                    tensorboard.on_epoch_end(step + first_step, logs)

                logs = named_logs(self.srgan, generator_loss)

            # Save losses
            print_losses['G'].append(generator_loss)
            print_losses['D'].append(discriminator_loss)

            # Show the progress
            if epoch_change and current_epoch % print_frequency == 0:
                g_avg_loss = np.array(print_losses['G']).mean(axis=0)
                d_avg_loss = np.array(print_losses['D']).mean(axis=0)
                print(
                    "\nEpoch {}/{} | Time: {}s\n>> Generator/GAN: {}\n>> Discriminator: {}"
                    .format(
                        current_epoch, epochs,
                        (datetime.datetime.now() - start_epoch).seconds,
                        ", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.srgan.metrics_names, g_avg_loss)
                        ]), ", ".join([
                            "{}={:.4f}".format(k, v) for k, v in zip(
                                self.discriminator.metrics_names, d_avg_loss)
                        ])))
                print_losses = {"G": [], "D": []}

                # Run validation inference if specified
                if datapath_validation:
                    validation_losses = self.generator.evaluate_generator(
                        validation_loader,
                        steps=steps_per_validation,
                        use_multiprocessing=workers > 1,
                        workers=workers)
                    print(">> Validation Losses: {}".format(", ".join([
                        "{}={:.4f}".format(k, v) for k, v in zip(
                            self.generator.metrics_names, validation_losses)
                    ])))

            # If test images are supplied, run model on them and save to log_test_path
            if datapath_test and epoch_change and current_epoch % log_test_frequency == 0:
                plot_test_images(self, loader, datapath_test, log_test_path,
                                 current_epoch)

            # Check if we should save the network weights
            if log_weight_frequency and epoch_change and current_epoch % log_weight_frequency == 0:
                # Save the network weights
                self.save_weights(log_weight_path, dataname)
        dB_loss_real = D_B.train_on_batch(imgs_B, valid)
        dB_loss_fake = D_B.train_on_batch(fake_B, fake)

        dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

        d_loss = 0.5 * np.add(dA_loss, dB_loss)

        #training generation
        g_loss = combined_model.train_on_batch([imgs_A, imgs_B],\
                                               [valid, valid, \
                                                imgs_A, imgs_B, \
                                                imgs_A, imgs_B])

        #add tensorboard
        tb_G_loss_track.on_epoch_end(batch_index, logs=reNamed_logs(g_loss))

        #collect losses for plot

        DA_losses.append(dA_loss)
        DB_losses.append(dB_loss)

        time = datetime.datetime.now().strftime("%Y-%m-%d-%H%M%S")

        print ("[epoch_index: %d/%d][batch_index:%d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] [time:%s]" \
                                                                                % ( epoch, epochs, \
                                                                                  batch_index, batch_num, \
                                                                                  d_loss[0], 100*d_loss[1],\
                                                                                  g_loss[0],\
                                                                                  np.mean(g_loss[1:3]),\
                                                                                  np.mean(g_loss[3:5]),\