Example #1
    def train(self, x, num_iter):

        for i in range(self.num_epochs):
            print("Epoch no :" + str(i + 1) + "/" + str(self.num_epochs))

            for j in tqdm(range(num_iter)):

                x1, y = self.gen_data(x, self.batch_size // 2)
                # train the discriminator
                self.discriminator.train_on_batch(x1, y)
                # Freeze the discriminator to train the GAN model
                utils.make_trainable(self.discriminator, False)
                # train the gan model
                inp = utils.gen_noise(self.batch_size // 2)
                labels = np.zeros((self.batch_size // 2, 1))
                self.gan_model.train_on_batch(inp, labels)

                # Make the discriminator trainable again for the next iteration
                utils.make_trainable(self.discriminator, True)

            # Save the weights and plot the results every 10 epochs
            if i % 10 == 0:
                self.gan_model.save_weights(self.save_path + str(i + 1) +
                                            ".h5")
                utils.plot(self.generator)
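
Every example on this page toggles discriminator weights through a make_trainable helper (utils.make_trainable or a bare make_trainable import). The helper itself is not shown in any of the snippets; a minimal sketch of what such a function usually looks like, assuming a standard Keras model, is:

def make_trainable(net, val):
    """Minimal sketch (assumed, not taken from any one repository above):
    freeze or unfreeze a Keras model and all of its layers. The change only
    takes effect the next time a model containing these layers is compiled."""
    net.trainable = val
    for layer in net.layers:
        layer.trainable = val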
Example #2
    def __init__(self):
        # Input shape
        self.channels = 3
        self.img_size = 64
        self.latent_dim = 100
        self.time = time()
        self.dataset_name = 'vdsr'
        self.learning_rate = 1e-4

        optimizer = Adam(self.learning_rate, beta_1=0.5, decay=0.00005)

        self.gf = 64  # filter size of generator's last layer
        self.df = 64  # filter size of discriminator's first layer

        # Configure data loader
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_size, self.img_size),
                                      mem_load=True)
        self.n_data = self.data_loader.get_n_data()

        self.generator = self.build_generator()
        print(
            "---------------------generator summary----------------------------"
        )
        self.generator.summary()

        self.generator.compile(loss='mse',
                               optimizer=optimizer,
                               metrics=['mse'])

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        print(
            "\n---------------------discriminator summary----------------------------"
        )
        self.discriminator.summary()

        self.discriminator.compile(loss='mse',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        make_trainable(self.discriminator, False)

        z = Input(shape=(self.latent_dim, ))
        fake_img = self.generator(z)

        # For the combined model, we only train the generator
        self.discriminator.trainable = False

        validity = self.discriminator(fake_img)

        self.combined = Model([z], [validity])
        print(
            "\n---------------------combined summary----------------------------"
        )
        self.combined.summary()
        self.combined.compile(loss=['binary_crossentropy'],
                              optimizer=optimizer)
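
Example #2 compiles the discriminator first and only freezes it afterwards, before compiling the combined model. That ordering works because Keras snapshots trainable flags at compile() time. A self-contained toy illustration of the pattern (layer sizes and names here are arbitrary, not taken from the example above):

import numpy as np
from keras.layers import Dense, Input
from keras.models import Model, Sequential

latent_dim = 8
generator = Sequential([Dense(4, activation='tanh', input_shape=(latent_dim,))])
discriminator = Sequential([Dense(1, activation='sigmoid', input_shape=(4,))])

# Compiled while trainable: the standalone discriminator keeps learning.
discriminator.compile(loss='binary_crossentropy', optimizer='adam')

# Frozen before the stacked model is compiled: inside `combined`,
# the discriminator weights are never updated.
discriminator.trainable = False
z = Input(shape=(latent_dim,))
combined = Model(z, discriminator(generator(z)))
combined.compile(loss='binary_crossentropy', optimizer='adam')

noise = np.random.normal(size=(16, latent_dim))
combined.train_on_batch(noise, np.ones((16, 1)))          # updates the generator only
discriminator.train_on_batch(generator.predict(noise),    # still updates the discriminator
                             np.zeros((16, 1)))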
Example #3
    def train(self, epochs, batch_size, sample_interval):
        def named_logs(model, logs):
            result = {}
            for name, value in zip(model.metrics_names, logs):
                result[name] = value
            return result

        start_time = datetime.datetime.now()

        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        max_iter = int(self.n_data / batch_size)
        os.makedirs('./logs/%s' % self.time, exist_ok=True)
        tensorboard = TensorBoard('./logs/%s' % self.time)
        tensorboard.set_model(self.generator)

        os.makedirs('models/%s' % self.time, exist_ok=True)
        with open('models/%s/%s_architecture.json' % (self.time, 'generator'),
                  'w') as f:
            f.write(self.generator.to_json())
        print(
            "\nbatch size : %d | num_data : %d | max iteration : %d | time : %s \n"
            % (batch_size, self.n_data, max_iter, self.time))
        for epoch in range(1, epochs + 1):
            for iter in range(max_iter):
                # ------------------
                #  Train Generator
                # ------------------
                ref_imgs = self.data_loader.load_data(batch_size)

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                gen_imgs = self.generator.predict(noise)
                make_trainable(self.discriminator, True)
                d_loss_real = self.discriminator.train_on_batch(
                    ref_imgs, valid * 0.9)  # label smoothing
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                make_trainable(self.discriminator, False)

                logs = self.combined.train_on_batch([noise], [valid])
                tensorboard.on_epoch_end(iter,
                                         named_logs(self.combined, [logs]))

                if iter % (sample_interval // 10) == 0:
                    elapsed_time = datetime.datetime.now() - start_time
                    print(
                        "epoch:%d | iter : %d / %d | time : %10s | g_loss : %15s | d_loss : %s "
                        % (epoch, iter, max_iter, elapsed_time, logs, d_loss))

                if (iter + 1) % sample_interval == 0:
                    self.sample_images(epoch, iter + 1)

            # save weights after every epoch
            self.generator.save_weights('models/%s/%s_epoch%d_weights.h5' %
                                        (self.time, 'generator', epoch))
Example #4
    def train(self, data, batch_size=250, num_epochs=25, eval_size=200):
        losses = []
        train, test = train_test_split(data)
        for epoch in range(num_epochs):
            for i in range(len(train) // batch_size):
                # ------------------
                # Train Discriminator
                # ------------------
                make_trainable(self.discriminator, True)
                # Get some real conformations from the train data
                real_confs = train[i * batch_size:(i + 1) * batch_size]
                real_confs = real_confs.reshape(-1, self.n_atoms, 3, 1)

                # Sample high dimensional noise and generate fake conformations
                noise = make_latent_samples(batch_size, self.noise_dim)
                fake_confs = self.generator.predict_on_batch(noise)

                # Label the conformations accordingly
                real_confs_labels, fake_confs_labels = make_labels(batch_size)

                self.discriminator.train_on_batch(real_confs,
                                                  real_confs_labels)
                self.discriminator.train_on_batch(fake_confs,
                                                  fake_confs_labels)

                # --------------------------------------------------
                #  Train Generator via GAN (switch off discriminator)
                # --------------------------------------------------
                noise = make_latent_samples(batch_size, self.noise_dim)
                make_trainable(self.discriminator, False)
                g_loss = self.gan.train_on_batch(noise, real_confs_labels)

            # Evaluate performance after epoch
            conf_eval_real = test[np.random.choice(len(test),
                                                   eval_size,
                                                   replace=False)]
            conf_eval_real = conf_eval_real.reshape(-1, self.n_atoms, 3, 1)
            noise = make_latent_samples(eval_size, self.noise_dim)
            conf_eval_fake = self.generator.predict_on_batch(noise)

            eval_real_labels, eval_fake_labels = make_labels(eval_size)

            d_loss_r = self.discriminator.test_on_batch(
                conf_eval_real, eval_real_labels)
            d_loss_f = self.discriminator.test_on_batch(
                conf_eval_fake, eval_fake_labels)
            d_loss = (d_loss_r + d_loss_f) / 2

            # we want the fake to be realistic!
            g_loss = self.gan.test_on_batch(noise, eval_real_labels)

            print(
                "Epoch: {:>3}/{} Discriminator Loss: {:>6.4f} Generator Loss: {:>6.4f}"
                .format(epoch + 1, num_epochs, d_loss, g_loss))

            losses.append((d_loss, g_loss))
        return losses
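
Example #4 relies on two small helpers that are not shown, make_latent_samples and make_labels. Hypothetical sketches are given below; the actual implementations in that project may differ (for instance, uniform rather than normal noise):

import numpy as np

def make_latent_samples(n_samples, sample_size):
    # Assumed: standard normal noise of shape (n_samples, sample_size).
    return np.random.normal(loc=0.0, scale=1.0, size=(n_samples, sample_size))

def make_labels(size):
    # Assumed: label 1 for real conformations, 0 for generated ones.
    return np.ones((size, 1)), np.zeros((size, 1))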
Example #5
def create_model(**kwargs):
    image_width = kwargs.get('image_width', settings.image_width)
    image_height = kwargs.get('image_height', settings.image_height)
    num_input_channels = kwargs.get('num_input_channels',
                                    settings.num_input_channels)
    loss = kwargs.get('loss', settings.loss)
    metrics = kwargs.get('metrics', settings.metrics)
    optimizer = kwargs.get('optimizer', settings.optimizer)
    learning_rate = kwargs.get('learning_rate', settings.learning_rate)

    # For custom loss and metrics functions
    if loss == 'dice_coef_loss':
        loss = dice_coef_loss

    if metrics == ['dice_coef']:
        metrics = [dice_coef]

    if optimizer == 'adam':
        optimizer = Adam(lr=learning_rate,
                         beta_1=kwargs.get('beta_1', settings.beta_1))

    generator = create_generator(**kwargs)
    generator.compile(loss=loss, optimizer=optimizer)
    discriminator = create_discriminator(**kwargs)
    discriminator.compile(loss=loss, optimizer=optimizer)

    # Make discriminator untrainable for stacked training
    utils.make_trainable(discriminator, False)
    GAN_input = Input(shape=(num_input_channels, image_width, image_height),
                      name='GAN_input')
    H = generator(GAN_input)
    GAN_V = discriminator(H)
    GAN = Model(GAN_input, GAN_V)
    GAN.compile(loss=loss, optimizer=optimizer)

    print('~~~~~~~~~~~GAN~~~~~~~~~~~')
    GAN.summary()
    print('~~~~~~~~GENERATOR~~~~~~~~')
    generator.summary()
    print('~~~~~~DISCRIMINATOR~~~~~~')
    discriminator.summary()

    return GAN
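
A usage sketch for Example #5: every keyword argument falls back to a value from the settings module, so a call might look like the following (all values here are illustrative, not taken from the original project):

# Illustrative call only; the real defaults live in settings.
gan = create_model(image_width=128,
                   image_height=128,
                   num_input_channels=1,
                   loss='dice_coef_loss',
                   metrics=['dice_coef'],
                   optimizer='adam',
                   learning_rate=1e-4,
                   beta_1=0.5)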
Example #6
test_imgs, test_vessels, test_masks = utils.get_imgs(
    test_dir, augmentation=False, img_size=img_size, dataset=dataset, mask=True)

# create networks
g = generator(img_size, n_filters_g)
if FLAGS.discriminator == 'pixel':
    d, d_out_shape = discriminator_pixel(img_size, n_filters_d, init_lr)
elif FLAGS.discriminator == 'patch1':
    d, d_out_shape = discriminator_patch1(img_size, n_filters_d, init_lr)
elif FLAGS.discriminator == 'patch2':
    d, d_out_shape = discriminator_patch2(img_size, n_filters_d, init_lr)
elif FLAGS.discriminator == 'image':
    d, d_out_shape = discriminator_image(img_size, n_filters_d, init_lr)
else:
    d, d_out_shape = discriminator_dummy(img_size, n_filters_d, init_lr)

utils.make_trainable(d, False)
gan = GAN(g, d, img_size, n_filters_g, n_filters_d, alpha_recip, init_lr)
generator = pretrain_g(g, img_size, n_filters_g, init_lr)
g.summary()
d.summary()
gan.summary() 
with open(os.path.join(model_out_dir,"g_{}_{}.json".format(FLAGS.discriminator,FLAGS.ratio_gan2seg)),'w') as f:
    f.write(g.to_json())
        
# start training
scheduler = (utils.Scheduler(n_train_imgs // batch_size, n_train_imgs // batch_size, schedules, init_lr)
             if alpha_recip > 0 else
             utils.Scheduler(0, n_train_imgs // batch_size, schedules, init_lr))
print "training {} images :".format(n_train_imgs)
for n_round in range(n_rounds):
    
    # train D
    utils.make_trainable(d, True)
Example #7
def train_for_n(nb_epoch=5000, BATCH_SIZE=32):
    for e in tqdm(range(nb_epoch)):
        ### Shuffle and Batch the data
        _random = np.random.randint(0, emb_cs.shape[0], size=BATCH_SIZE)
        _random2 = np.random.randint(0, emb_zh.shape[0], size=BATCH_SIZE)
        if not WORD_ONLY:
            pos_seq_cs_batch = pos_seq_cs[_random]
            pos_seq_zh_batch = pos_seq_zh[_random2]
        emb_cs_batch = emb_cs[_random]
        emb_zh_batch = emb_zh[_random2]
        noise_g = np.random.normal(0,
                                   1,
                                   size=(BATCH_SIZE, MAX_SEQUENCE_LENGTH,
                                         NOISE_SIZE))
        reward_batch = np.zeros((BATCH_SIZE, 1))

        #############################################
        ### Train generator
        #############################################
        for ep in range(1):  # G v.s. D training ratio
            if not WORD_ONLY:
                output_g = generator.predict(
                    [emb_zh_batch, pos_seq_zh_batch, noise_g, reward_batch])
            else:
                output_g = generator.predict(
                    [emb_zh_batch, noise_g, reward_batch])
            action_g, action_one_hot_g = get_action(output_g)
            emb_g = translate(emb_zh_batch, action_g)
            text_g = translate_output(emb_zh_batch, action_g)

            # tag POS
            if not WORD_ONLY:
                pos_seq_g = []
                for line in text_g:
                    words = pseg.cut(line)
                    sub_data = []
                    idx = 0
                    for w in words:
                        if w.flag == "x":
                            idx = 0
                        elif idx == 0:
                            sub_data.append(postag[w.flag])
                            idx = 1
                    pos_seq_g.append(sub_data)

                pos_seq_g = pad_sequences(pos_seq_g,
                                          maxlen=MAX_SEQUENCE_LENGTH,
                                          padding='post',
                                          truncating='post',
                                          value=0)

            one_hot_action = action_one_hot_g.reshape(BATCH_SIZE,
                                                      MAX_SEQUENCE_LENGTH, 2)

            make_trainable(generator, True)

            if not WORD_ONLY:
                reward_batch = discriminator.predict([emb_g, pos_seq_g])[:, 0]
                g_loss = generator.train_on_batch(
                    [emb_zh_batch, pos_seq_zh_batch, noise_g, reward_batch],
                    one_hot_action)
            else:
                reward_batch = discriminator.predict([emb_g])[:, 0]
                g_loss = generator.train_on_batch(
                    [emb_zh_batch, noise_g, reward_batch], one_hot_action)

            losses["g"].append(g_loss)
            write_log(callbacks, log_g, g_loss, len(losses["g"]))
            if g_loss < 0.15:  # early stop
                break

        #############################################
        ### Train discriminator on generated sentence
        #############################################
        X_emb = np.concatenate((emb_cs_batch, emb_g))
        if not WORD_ONLY:
            X_pos = np.concatenate((pos_seq_cs_batch, pos_seq_g))
        y = np.zeros([2 * BATCH_SIZE])
        y[0:BATCH_SIZE] = 0.7 + np.random.random([BATCH_SIZE]) * 0.3
        y[BATCH_SIZE:] = 0 + np.random.random([BATCH_SIZE]) * 0.3

        make_trainable(discriminator, True)
        model.embedding_word.trainable = False
        if not WORD_ONLY:
            model.embedding_pos.trainable = False
        model.g_bi.trainable = False

        for ep in range(1):  # G v.s. D training ratio
            if not WORD_ONLY:
                d_loss = discriminator.train_on_batch([X_emb, X_pos], y)
            else:
                d_loss = discriminator.train_on_batch([X_emb], y)
            losses["d"].append(d_loss)
            write_log(callbacks, log_d, d_loss, len(losses["d"]))
            if d_loss < 0.6:  # early stop
                break

    ### Save model
    generator.save_weights(MODEL_PATH + "gen.mdl")
    discriminator.save_weights(MODEL_PATH + "dis.mdl")
Example #8
if not WORD_ONLY:
    pos_seq_g = pad_sequences(pos_seq_g,
                              maxlen=MAX_SEQUENCE_LENGTH,
                              padding='post',
                              truncating='post',
                              value=0)
    X_pos = np.concatenate((XT_pos, pos_seq_g))

X_emb = np.concatenate((XT_emb, emb_g))
n = XT_emb.shape[0]
y = np.zeros([2 * n])
y[:n] = 0.7 + np.random.random([n]) * 0.3
y[n:] = 0 + np.random.random([n]) * 0.3

random_id = np.random.randint(0, X_emb.shape[0], size=BATCH_SIZE * 10)
XX_emb = X_emb[random_id]
y = y[random_id]
if not WORD_ONLY:
    XX_pos = X_pos[random_id]

make_trainable(discriminator, True)
K.set_value(discriminator.optimizer.lr, dopt)
K.set_value(discriminator.optimizer.decay, dopt / 100)
if not WORD_ONLY:
    discriminator.fit([XX_emb, XX_pos],
                      y,
                      epochs=10,
                      batch_size=BATCH_SIZE,
                      validation_split=0.1,
                      callbacks=[earlystopper])
else:
    discriminator.fit([XX_emb],
                      y,
                      epochs=10,
                      batch_size=BATCH_SIZE,
                      validation_split=0.1,
                      callbacks=[earlystopper])
Example #9
def train(model, training_data, validation_data, **kwargs):
    x_train = training_data[0]
    y_train = training_data[1]

    batch_size = kwargs.get('batch_size', settings.batch_size)

    # model is the stacked GAN; its layers are [input, generator, discriminator]
    GAN = model
    generator = GAN.layers[1]
    discriminator = GAN.layers[2]

    # Pre-train the discriminator
    print('pre-training the discriminator')
    num_images = batch_size
    splice = random.sample(range(0, x_train.shape[0]), num_images)
    XT = x_train[splice, :, :, :]
    YT = y_train[splice, :, :, :]
    generated_images = generator.predict(XT)
    y_fake = np.array([0] * num_images).astype('float')
    y_real = np.array([1] * batch_size).astype('float')
    y = np.concatenate((y_fake, y_real))
    x = np.concatenate((generated_images, YT))
    utils.make_trainable(discriminator, True)
    utils.make_trainable(generator, True)
    GAN.fit(x, y, nb_epoch=1, batch_size=batch_size, verbose=1)

    # Train the GAN
    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(x_train.shape[0] / batch_size))

        for index in range(int(x_train.shape[0] / batch_size)):
            x_image_batch = x_train[index * batch_size:(index + 1) *
                                    batch_size]
            y_image_batch = y_train[index * batch_size:(index + 1) *
                                    batch_size]

            # if conditional GAN:
            #   generated_images = concat(generated_images, y_image_batch)

            # Train discriminator, generator weights are frozen
            generated_images = generator.predict(x_image_batch, verbose=1)
            utils.make_trainable(discriminator, True)
            utils.make_trainable(generator, False)
            y_fake = np.array([0] * batch_size).astype('float')
            y_real = np.array([1] * batch_size).astype('float')
            d_loss = GAN.train_on_batch(generated_images, y_fake)
            d_loss += GAN.train_on_batch(y_image_batch, y_real)

            # Train generator, discriminator weights are frozen
            utils.make_trainable(discriminator, False)
            utils.make_trainable(generator, True)
            # g_loss = GAN.train_on_batch(generated_images, y_fake)
            g_loss = GAN.train_on_batch(x_image_batch, y_real)
            # g_loss /= 2
            # g_loss = generator.train_on_batch(x_image_batch, y_image_batch)
            print("batch %d \t d_loss %f \t g_loss : %f" %
                  (index, d_loss, g_loss))

            # Save weights every 10 batches
            if index % 10 == 9:
                generator.save_weights('generator_weights', True)
                discriminator.save_weights('discriminator_weights', True)

        # Save a generated image every epoch
        image = combine_images(generated_images)
        image = deprocess(image)
        Image.fromarray(image.astype(
            np.uint8)).save(str(epoch) + "_" + str(index) + ".png")
Example #10
def main():
    """ Creates and trains the GAN. """

    # Creates directories
    makedirs()

    # Initializes and trains the generator if necessary
    generator = train_generator()

    # Loads discriminator
    if FLAGS.disc_path is None:
        discriminator = create_discriminator((15, 26, 512))
    else:
        discriminator = load_model(FLAGS.disc_path)

    # Compiles discriminator
    discriminator.compile(optimizer="adam",
                          loss="binary_crossentropy",
                          metrics=["acc"])

    # Sets initial discriminator trainability to False
    make_trainable(discriminator, False)

    # Loads VGG
    vgg = keras.applications.VGG16(include_top=False)

    # VGG should never be trained
    vgg.trainable = False

    # Creates GAN Model
    input_layer = Input(shape=(240, 426, 3))
    out_gen = generator(input_layer)
    out_vgg = vgg(out_gen)
    out_disc = discriminator(out_vgg)

    # Compiles the GAN
    # Notice the loss_weights argument: this can be changed to achieve different
    # results. In general, binary cross-entropy controls resolution accuracy
    # and RMSE controls color accuracy.
    model = Model(inputs=input_layer, outputs=[out_disc, out_gen])
    model.compile(optimizer="adam",
                  loss=["binary_crossentropy", root_mean_squared_error],
                  loss_weights=[0.9, 0.1],
                  metrics=["acc"])

    # Gets filepaths of training set
    files = os.listdir(FLAGS.x_dir)

    # Set to False to train the discriminator first
    train_gen = False

    # Trains the model
    for epoch in range(FLAGS.start_epoch, FLAGS.epochs):
        # Shuffles the training set and divides it into batches
        np.random.shuffle(files)
        batches = chunks(files, FLAGS.batch_size)

        # Trains a batch
        for batch in batches:
            # Creates batch tensors
            x_train, y_train = create_batch_tensors(batch)

            # Generator's turn to train
            if train_gen:
                print("Training generator: epoch {0}".format(epoch))

                # Set discriminator trainability to False
                make_trainable(discriminator, False)

                # Trains the generator
                metrics = model.fit(x_train,
                                    [np.ones([len(x_train)]), y_train])

                # If the generator is good enough, switch to discriminator
                if metrics.history["generator_acc"][0] > .9:
                    train_gen = False

            # Discriminator's turn to train
            else:
                print("Training discriminator: epoch {0}".format(epoch))

                # Set discriminator trainability to True
                make_trainable(discriminator, True)

                # Get generated data from inputs
                gen_input = x_train
                gen_output = generator.predict(gen_input)

                # Combines x data with their corresponding y labels
                disc_input = np.concatenate([gen_output, y_train])
                ground_truth = np.concatenate(
                    [np.zeros([len(gen_output)]),
                     np.ones([len(y_train)])])

                # Shuffles generated images and real images
                disc_input, ground_truth = shuffle(disc_input, ground_truth)

                # Run VGG embeddings on generated images
                vgg_output = vgg.predict(disc_input, batch_size=1)

                # Trains the discriminator
                metrics = discriminator.fit(vgg_output, ground_truth)

                # If the discriminator is good enough, switch to generator
                if metrics.history["acc"][0] > .9:
                    train_gen = True

        print("\n \n \n \n COMPLETED EPOCH {0} \n \n \n \n".format(epoch))

        # Get samples for visual verification
        out = generator.predict(x_train)
        base_filepath = FLAGS.out_dir + "/samples/gen/epoch_{0}".format(epoch)

        # Save samples
        for i, sample in enumerate(out):
            resized_train = cv2.resize(x_train[i], (0, 0),
                                       fx=2,
                                       fy=2,
                                       interpolation=cv2.INTER_CUBIC)
            cv2.imwrite(base_filepath + "_img_{0}_input.png".format(i),
                        resized_train)
            cv2.imwrite(base_filepath + "_img_{0}_pred.png".format(i), sample)
            cv2.imwrite(base_filepath + "_img_{0}_true.png".format(i),
                        y_train[i])

        # Save generator and discriminator weights
        savepath = FLAGS.out_dir + "/weights"
        generator.save(savepath + "/gen/epoch_{0}.h5py".format(epoch))
        make_trainable(discriminator, True)
        discriminator.save(savepath + "/disc/epoch_{0}.h5py".format(epoch))