Example #1
    def train(self, nb_epoch=CONFIG['nb_epoch']):
        print("Start training.")

        for epoch in range(nb_epoch):

            print("Epoch : " + str(epoch))
            g_loss = []
            d_loss = []

            for batch_id, (x, target) in enumerate(self.train_loader):
                real_batch_data = x.to(self.device)
                current_batch_size = x.shape[0]

                packed_real_data = pack(real_batch_data, self.packing)
                packed_batch_size = packed_real_data.shape[0]

                # labels (float targets, as BCE-style losses expect; an integer
                # fill value would give an int64 tensor on recent PyTorch)
                label_real = torch.full((packed_batch_size,), 1.0, device=self.device)
                label_fake = torch.full((packed_batch_size,), 0.0, device=self.device)
                # smoothed labels: real in [0.7, 1.0), fake in [0.0, 0.3)
                label_real_smooth = torch.rand((packed_batch_size,), device=self.device) * 0.3 + 0.7
                label_fake_smooth = torch.rand((packed_batch_size,), device=self.device) * 0.3

                temp_discriminator_loss = []
                temp_generator_loss = []

                ### Train discriminator multiple times
                for i in range(self.nb_discriminator_step):
                    loss_discriminator_total = self.train_discriminator(packed_real_data,
                                                                        current_batch_size,
                                                                        label_real_smooth if self.real_label_smoothing else label_real,
                                                                        label_fake_smooth if self.fake_label_smoothing else label_fake)

                    temp_discriminator_loss.append(loss_discriminator_total.item())
                    # print("Discriminator step ", str(i), " with loss : ", loss_discriminator_total.item())

                ### Train generator multiple times
                for i in range(self.nb_generator_step):
                    loss_generator_total = self.train_generator(current_batch_size, label_real)
                    temp_generator_loss.append(loss_generator_total.item())

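                # save a grid of real samples once per epoch, taken from the
                # penultimate batch (the last batch can be smaller than the rest)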
                if batch_id == len(self.train_loader) - 2:
                    save_images(real_batch_data, self.save_path + "real/", self.image_size, self.image_channels,
                                self.nb_image_to_gen, epoch)

                ### Keep track of losses
                d_loss.append(torch.mean(torch.tensor(temp_discriminator_loss)))
                g_loss.append(torch.mean(torch.tensor(temp_generator_loss)))

            self.discriminator_losses.append(torch.mean(torch.tensor(d_loss)))
            self.generator_losses.append(torch.mean(torch.tensor(g_loss)))

            save_images(self.generator(self.saved_latent_input), self.save_path + "gen_", self.image_size,
                        self.image_channels, self.nb_image_to_gen, epoch)

            write_loss_plot(self.generator_losses, "G loss", self.save_path, clear_plot=False)
            write_loss_plot(self.discriminator_losses, "D loss", self.save_path, clear_plot=True)

        print("Training finished.")
Example #2
def train():
    time1 = time.time()
    input_path_S = [i.strip() for i in open(a.input_dir+'style.txt', 'r').readlines()]
    input_path_C = [i.strip() for i in open(a.input_dir+'content.txt', 'r').readlines()]
    target_path = [i.strip() for i in open(a.input_dir+'target.txt', 'r').readlines()]
    print("file lists loaded in %.2fs" % (time.time() - time1))

    # ###################### network ################
    batch_inputsS_holder = tf.placeholder(tf.float32, [a.style_num * a.style_sample_n, 80, 80, 1], name='inputsS')
    batch_inputsC_holder = tf.placeholder(tf.float32, [a.content_num * a.content_sample_n, 80, 80, 1], name='inputsC')
    batch_targets_holder = tf.placeholder(tf.float32, [a.target_batch_size, 80, 80, 1], name='targets')

    # count the pixels above the 0.5 threshold (the "ink" pixels);
    # the +1 guards against division by zero in the mean below
    black = tf.greater(batch_targets_holder, 0.5)
    as_ints = tf.cast(black, tf.int32)
    zero_n = tf.reduce_sum(as_ints, [1, 2, 3]) + 1

    # mean value over those pixels (tf.where zeroes out everything below the threshold)
    zeros = tf.zeros_like(batch_targets_holder)
    new_tensor = tf.where(black, batch_targets_holder, zeros)
    mean_pixel_value = tf.reduce_sum(new_tensor, [1, 2, 3])/tf.to_float(zero_n)

    # zero_n = tf.placeholder(tf.float32,[a.target_batch_size,1],name='zero_n')
    # mean_pixel_value = tf.placeholder(tf.float32,[a.target_batch_size,1],name='mean_pixel_value')

    with tf.variable_scope("generator"):
        pictures_decode, model_loss, model_mse = create_generator(batch_inputsS_holder, batch_inputsC_holder,
                                                                  batch_targets_holder, zero_n, mean_pixel_value)

    # ########prepare data ###################################
    input_path_S_holder = tf.placeholder(tf.string)
    input_path_C_holder = tf.placeholder(tf.string)
    target_path_holder = tf.placeholder(tf.string)

    dataset1 = tf.data.Dataset.from_tensor_slices(input_path_S_holder)
    dataset1 = dataset1.map(process, num_parallel_calls=a.num_parallel_prefetch)
    dataset1 = dataset1.prefetch(a.style_sample_n*a.style_num * a.num_parallel_prefetch)
    dataset1 = dataset1.batch(a.style_sample_n*a.style_num).repeat(a.max_epochs)

    dataset2 = tf.data.Dataset.from_tensor_slices(input_path_C_holder)
    dataset2 = dataset2.map(process, num_parallel_calls=a.num_parallel_prefetch)
    dataset2 = dataset2.prefetch(a.content_sample_n*a.content_num * a.num_parallel_prefetch)
    dataset2 = dataset2.batch(a.content_sample_n*a.content_num).repeat(a.max_epochs)

    dataset3 = tf.data.Dataset.from_tensor_slices(target_path_holder)
    dataset3 = dataset3.map(process, num_parallel_calls=a.num_parallel_prefetch)
    dataset3 = dataset3.prefetch(a.target_batch_size * a.num_parallel_prefetch)
    dataset3 = dataset3.batch(a.target_batch_size).repeat(a.max_epochs)

    iterator1 = dataset1.make_initializable_iterator()
    one_element1 = iterator1.get_next()  # get_next() already returns tensors

    iterator2 = dataset2.make_initializable_iterator()
    one_element2 = iterator2.get_next()

    iterator3 = dataset3.make_initializable_iterator()
    one_element3 = iterator3.get_next()

    ############################################################################

    # model_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
    # optim_d = tf.train.AdamOptimizer(learning_rate=a.adam_lr).minimize(model_loss, var_list=model_tvars)

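    # run the ops in UPDATE_OPS (e.g. batch-norm moving-average updates)
    # as a dependency of every optimizer step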
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        model_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
        # model_optim = tf.train.RMSPropOptimizer(a.rmsprop_lr)
        # learning_rate = tf.train.exponential_decay(a.adam_lr, global_step, a.decay_steps, a.decay_rate)
        model_optim = tf.train.AdamOptimizer(a.adam_lr)
        model_grads_and_vars = model_optim.compute_gradients(model_loss, var_list=model_tvars)
        model_train = model_optim.apply_gradients(model_grads_and_vars)

    saver = tf.train.Saver(max_to_keep=2)
    init = tf.global_variables_initializer()

    logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None
    sv = tf.train.Supervisor(logdir=logdir, saver=None, summary_op=None)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with sv.managed_session(config=config) as sess:
        sess.run(init)

        if a.checkpoint is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(a.checkpoint)
            saver.restore(sess, checkpoint)
            print('ok')

        start = time.time()
        steps_per_epoch = len(target_path) // a.target_batch_size
        max_steps = a.max_epochs*steps_per_epoch

        sess.run(iterator1.initializer, feed_dict={input_path_S_holder: input_path_S})
        sess.run(iterator2.initializer, feed_dict={input_path_C_holder: input_path_C})
        sess.run(iterator3.initializer, feed_dict={target_path_holder: target_path})

        for step in range(max_steps):
            def should(freq):
                return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)

            batch_inputsS = sess.run(one_element1)
            batch_inputsC = sess.run(one_element2)
            batch_targets = sess.run(one_element3)

            _, loss, mse, outputs = sess.run([model_train, model_loss, model_mse, pictures_decode],
                                             feed_dict={batch_inputsS_holder: batch_inputsS,
                                             batch_inputsC_holder: batch_inputsC,
                                             batch_targets_holder: batch_targets})

            if should(a.display_freq):
                print("saving display images")
                save_images(outputs, step, [4, 13], 'output')
                save_images(batch_targets, step, [4, 13], 'target')

            if should(a.progress_freq):
                # step counts from 0 here, so shift by one for readable progress numbers
                train_epoch = step // steps_per_epoch + 1
                train_step = step % steps_per_epoch + 1
                rate = (step + 1) * a.target_batch_size / (time.time() - start)
                remaining = (max_steps - step) * a.target_batch_size / rate
                print("progress  epoch %d  step %d  image/sec %0.1f  remaining %dm" %
                      (train_epoch, train_step, rate, remaining / 60))
                print("model_loss", loss)
                print("mse", mse)

            if should(a.save_freq):
                print("saving model")
                saver.save(sess, os.path.join(a.output_dir, "model"), global_step=step)

            if sv.should_stop():
                break
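
Each dataset above maps file paths through a process() function that is not included. A minimal sketch, assuming each path points to a single-channel image file that should come out as an 80x80x1 float tensor in [0, 1] to match the placeholder shapes (the PNG format is an assumption):

def process(path):
    # assumed helper: read one file path, decode a single-channel image,
    # resize to the 80x80 placeholder shape, and scale to [0, 1]
    raw = tf.read_file(path)
    image = tf.image.decode_png(raw, channels=1)
    image = tf.image.resize_images(image, [80, 80]) / 255.0
    return image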
Example #3
def train(self, config):
    """Train DCGAN"""

    d_optim = self.d_optim
    g_optim = self.g_optim

    tf.compat.v1.global_variables_initializer().run()

    self.saver = tf.compat.v1.train.Saver()
    #self.g_sum = tf.merge_summary([#self.z_sum,
    #    self.d__sum,
    #    self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
    # self.d_sum = tf.merge_summary([#self.z_sum,
    #     self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
    self.writer = tf.compat.v1.summary.FileWriter("./logs", self.sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.compat.v1.train.start_queue_runners(coord=coord)

    # Hang onto a copy of z so we can feed the same one every time we store
    # samples to disk for visualization
    assert self.sample_size > self.batch_size
    assert self.sample_size % self.batch_size == 0
    steps = self.sample_size // self.batch_size
    assert steps > 0
    sample_zs = []
    for i in range(steps):
        cur_zs = self.sess.run(self.zses[0])
        assert all(z.shape[0] == self.batch_size for z in cur_zs)
        sample_zs.append(cur_zs)
    sample_zs = [
        np.concatenate([batch[i] for batch in sample_zs], axis=0)
        for i in range(len(sample_zs[0]))
    ]
    assert all(sample_z.shape[0] == self.sample_size for sample_z in sample_zs)

    counter = 1

    if self.load(self.checkpoint_dir):
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    start_time = time.time()
    print_time = time.time()
    sample_time = time.time()
    save_time = time.time()
    idx = 0
    try:
        while not coord.should_stop():
            idx += 1
            batch_start_time = time.time()
            """
            batch_images = self.images.eval()
            from pylearn2.utils.image import save
            for i in xrange(self.batch_size):
                save("train_image_%d.png" % i, batch_images[i, :, :, :] / 2. + 0.5)
            """

            #for i in xrange(3):
            #    self.sess.run([d_optim], feed_dict=feed_dict)

            _d_optim, _d_sum, \
            _g_optim,  \
            errD_fake, errD_real, errD_class, \
            errG = self.sess.run([d_optim, self.d_sum,
                                            g_optim, # self.g_sum,
                                            self.d_loss_fakes[0],
                                            self.d_loss_reals[0],
                                            self.d_loss_classes[0],
                                            self.g_losses[0]])

            counter += 1
            if time.time() - print_time > 15.:
                print_time = time.time()
                total_time = print_time - start_time
                d_loss = errD_fake + errD_real + errD_class
                sec_per_batch = (print_time - start_time) / (idx + 1.)
                sec_this_batch = print_time - batch_start_time
                print(
                    "[Batch %(idx)d] time: %(total_time)4.4f, d_loss: %(d_loss).8f, g_loss: %(errG).8f, d_loss_real: %(errD_real).8f, d_loss_fake: %(errD_fake).8f, d_loss_class: %(errD_class).8f, sec/batch: %(sec_per_batch)4.4f, sec/this batch: %(sec_this_batch)4.4f"
                    % locals()
                )

            if (idx < 300
                    and idx % 10 == 0) or time.time() - sample_time > 300:
                sample_time = time.time()
                samples = []
                # generator hard codes the batch size
                for i in range(self.sample_size // self.batch_size):
                    feed_dict = {}
                    for z, zv in zip(self.zses[0], sample_zs):
                        if zv.ndim == 2:
                            feed_dict[z] = zv[i * self.batch_size:(i + 1) *
                                              self.batch_size, :]
                        elif zv.ndim == 4:
                            feed_dict[z] = zv[i * self.batch_size:(i + 1) *
                                              self.batch_size, :, :, :]
                        else:
                            assert False
                    cur_samples, = self.sess.run([self.Gs[0]],
                                                 feed_dict=feed_dict)
                    samples.append(cur_samples)
                samples = np.concatenate(samples, axis=0)
                assert samples.shape[0] == self.sample_size
                save_images(samples, [8, 8],
                            self.sample_dir + '/train_%s.png' % (idx))

            if time.time() - save_time > 3600:
                save_time = time.time()
                self.save(config.checkpoint_dir, counter)
    except tf.errors.OutOfRangeError:
        print("Done training; epoch limit reached.")
    finally:
        coord.request_stop()

    coord.join(threads)
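
The sample_zs bookkeeping above is easy to misread: it collects several batches, each a list of per-input z arrays, then flips the nesting so there is one full sample_size-sized array per z input. A small NumPy illustration of the same rearrangement, with hypothetical shapes:

import numpy as np

batch_size, steps = 4, 3  # sample_size = batch_size * steps = 12
# three batches, each a list of two z inputs, as self.zses[0] might yield
batches = [[np.random.randn(batch_size, 10),
            np.random.randn(batch_size, 2, 2, 1)] for _ in range(steps)]

# same rearrangement as above: concatenate position i across all batches
rebatched = [np.concatenate([batch[i] for batch in batches], axis=0)
             for i in range(len(batches[0]))]

print([z.shape for z in rebatched])  # [(12, 10), (12, 2, 2, 1)]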
Example #4
def train(self, config):
    """Train DCGAN"""

    d_optim = self.d_optim
    g_optim = self.g_optim

    tf.initialize_all_variables().run()

    self.saver = tf.train.Saver()
    #self.g_sum = tf.merge_summary([#self.z_sum,
    #    self.d__sum,
    #    self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
    # self.d_sum = tf.merge_summary([#self.z_sum,
    #     self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
    self.writer = tf.train.SummaryWriter("./logs", self.sess.graph_def)


    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    # Hang onto a copy of z so we can feed the same one every time we store
    # samples to disk for visualization
    assert self.sample_size > self.batch_size
    assert self.sample_size % self.batch_size == 0
    steps = self.sample_size // self.batch_size
    assert steps > 0
    sample_zs = []
    for i in xrange(steps):
        cur_zs = self.sess.run(self.zses[0])
        assert all(z.shape[0] == self.batch_size for z in cur_zs)
        sample_zs.append(cur_zs)
    sample_zs = [np.concatenate([batch[i] for batch in sample_zs], axis=0) for i in xrange(len(sample_zs[0]))]
    assert all(sample_z.shape[0] == self.sample_size for sample_z in sample_zs)

    counter = 1

    if self.load(self.checkpoint_dir):
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    start_time = time.time()
    print_time = time.time()
    sample_time = time.time()
    save_time = time.time()
    idx = 0
    try:
        while not coord.should_stop():
            idx += 1
            batch_start_time = time.time()

            """
            batch_images = self.images.eval()
            from pylearn2.utils.image import save
            for i in xrange(self.batch_size):
                save("train_image_%d.png" % i, batch_images[i, :, :, :] / 2. + 0.5)
            """


            #for i in xrange(3):
            #    self.sess.run([d_optim], feed_dict=feed_dict)

            _d_optim, _d_sum, \
            _g_optim,  \
            errD_fake, errD_real, errD_class, \
            errG = self.sess.run([d_optim, self.d_sum,
                                            g_optim, # self.g_sum,
                                            self.d_loss_fakes[0],
                                            self.d_loss_reals[0],
                                            self.d_loss_classes[0],
                                            self.g_losses[0]])

            counter += 1
            if time.time() - print_time > 15.:
                print_time = time.time()
                total_time = print_time - start_time
                d_loss = errD_fake + errD_real + errD_class
                sec_per_batch = (print_time - start_time) / (idx + 1.)
                sec_this_batch = print_time - batch_start_time
                print "[Batch %(idx)d] time: %(total_time)4.4f, d_loss: %(d_loss).8f, g_loss: %(errG).8f, d_loss_real: %(errD_real).8f, d_loss_fake: %(errD_fake).8f, d_loss_class: %(errD_class).8f, sec/batch: %(sec_per_batch)4.4f, sec/this batch: %(sec_this_batch)4.4f" \
                    % locals()

            if (idx < 300 and idx % 10 == 0) or time.time() - sample_time > 300:
                sample_time = time.time()
                samples = []
                # generator hard codes the batch size
                for i in xrange(self.sample_size // self.batch_size):
                    feed_dict = {}
                    for z, zv in zip(self.zses[0], sample_zs):
                        if zv.ndim == 2:
                            feed_dict[z] = zv[i*self.batch_size:(i+1)*self.batch_size, :]
                        elif zv.ndim == 4:
                            feed_dict[z] = zv[i*self.batch_size:(i+1)*self.batch_size, :, :, :]
                        else:
                            assert False
                    cur_samples, = self.sess.run(
                        [self.Gs[0]],
                        feed_dict=feed_dict
                    )
                    samples.append(cur_samples)
                samples = np.concatenate(samples, axis=0)
                assert samples.shape[0] == self.sample_size
                save_images(samples, [8, 8],
                            self.sample_dir + '/train_%s.png' % idx)


            if time.time() - save_time > 3600:
                save_time = time.time()
                self.save(config.checkpoint_dir, counter)
    except tf.errors.OutOfRangeError:
        print "Done training; epoch limit reached."
    finally:
        coord.request_stop()

    coord.join(threads)
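
Both DCGAN examples call a save_images(images, grid, path) helper that is not shown. A minimal sketch, assuming the samples arrive as an (N, H, W, C) float array in [-1, 1] (the normalization is an assumption; the real helper may differ):

import numpy as np
from PIL import Image

def save_images(images, grid, path):
    # assumed helper: tile N images into a rows x cols grid and save one PNG;
    # pixel values are assumed to lie in [-1, 1]
    rows, cols = grid
    n, h, w, c = images.shape
    canvas = np.zeros((rows * h, cols * w, c), dtype=np.float32)
    for idx in range(min(n, rows * cols)):
        r, col = divmod(idx, cols)
        canvas[r * h:(r + 1) * h, col * w:(col + 1) * w, :] = images[idx]
    canvas = ((canvas + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8)
    Image.fromarray(canvas.squeeze()).save(path)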