コード例 #1
0
def main(_):
    """Entry point: dump the parsed flags, make sure the checkpoint
    directory exists, then build a CGAN and run training."""
    pp.pprint(flags.FLAGS.__flags)

    # Create the checkpoint directory on first run.
    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)

    with tf.Session() as sess:
        model = CGAN(
            sess,
            image_size=FLAGS.image_size,
            batch_size=FLAGS.batch_size,
            c_dim=FLAGS.c_dim,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir,
        )
        model.train(FLAGS)
コード例 #2
0
def test():
    """Restore a trained CGAN checkpoint and write a generated image
    for every sample in the test split into conf.output_path."""
    if not os.path.exists(conf.output_path):
        os.makedirs(conf.output_path)

    data = load_data()
    model = CGAN()
    saver = tf.train.Saver()

    counter = 0
    start_time = time.time()

    with tf.Session() as sess:
        saver.restore(sess, conf.model_path_test)
        for img, cond, name in data["test"]():
            pimg, pcond = prepocess_test(img, cond)
            feed = {model.image: pimg, model.cond: pcond}
            gen_img = sess.run(model.gen_img, feed_dict=feed)

            # Drop the leading batch axis and the trailing channel axis.
            gen_img = gen_img.reshape(gen_img.shape[1:-1])
            # Map generator output from [-1, 1] back to [0, 255].
            gen_img1 = (gen_img + 1.) * 127.5

            print(gen_img1)
            path_save = conf.output_path + "/" + "%s" % (name)
            print(path_save)
            scipy.misc.imsave(path_save, gen_img1)
コード例 #3
0
ファイル: test.py プロジェクト: Ireneruru/GalaxyGAN_python
def train():
    """Run the generator of a restored CGAN on one hard-coded .npy
    sample and save (generated | condition) as a new .npy file.

    Note: despite the name, this only performs inference; the Adam ops
    below are never executed.
    """
    model = CGAN()

    # NOTE(review): these optimizer ops are never run here — presumably
    # built so the graph's variables match the training checkpoint being
    # restored; confirm before removing.
    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        saver.restore(sess, conf.model_path)
        np_path = "/home/chenyiru/FluxPreservation/demo_out/test/587724648721678356.npy"
        # Renamed from `all` to stop shadowing the builtin.
        sample = np.load(np_path)
        # First img_size columns are the image, the rest the condition.
        img, cond = sample[:, :conf.img_size], sample[:, conf.img_size:]

        pimg, pcond = prepocess_test(img, cond)
        gen_img = sess.run(model.gen_img,
                           feed_dict={
                               model.image: pimg,
                               model.cond: pcond
                           })
        # Drop the leading batch axis.
        gen_img = gen_img.reshape(gen_img.shape[1:])
        image = np.concatenate((gen_img, cond), axis=1)
        np.save(
            "/home/chenyiru/FluxPreservation/demo_out/587724648721678356.npy",
            image)
コード例 #4
0
def test():
    """Restore a CGAN checkpoint and write side-by-side
    (generated | condition) images for the test split into ./test."""
    if not os.path.exists("test"):
        os.makedirs("test")
    data = load_data()
    model = CGAN()

    # NOTE(review): these optimizer ops are never run here — presumably
    # built so the graph's variables match the training checkpoint being
    # restored; confirm before removing.
    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        saver.restore(sess, conf.model_path_test)
        test_data = data["test"]()
        for img, cond, name in test_data:
            pimg, pcond = prepocess_test(img, cond)
            gen_img = sess.run(model.gen_img,
                               feed_dict={
                                   model.image: pimg,
                                   model.cond: pcond
                               })
            # Drop the leading batch axis.
            gen_img = gen_img.reshape(gen_img.shape[1:])
            # Map generator output from [-1, 1] to [0, 255].
            gen_img = (gen_img + 1.) * 127.5
            # `np.int` was removed in NumPy 1.24; the builtin `int` is
            # the exact equivalent.
            image = np.concatenate((gen_img, cond), axis=1).astype(int)
            imsave(image, "./test" + "/%s" % name)
コード例 #5
0
    def __init__(self, flags):
        """Set up the TF session, dataset and CGAN model, create output
        folders, and initialize all graph variables.

        Args:
            flags: parsed command-line flags; must at least provide
                `dataset` plus whatever Dataset/CGAN read from it.
        """
        run_config = tf.ConfigProto()
        # Allocate GPU memory on demand instead of grabbing it all upfront.
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

        self.flags = flags
        self.dataset = Dataset(self.flags.dataset, self.flags)
        self.model = CGAN(self.sess, self.flags, self.dataset.image_size)

        # Best auc_roc + auc_pr seen so far; used to decide when to checkpoint.
        self.best_auc_sum = 0.
        self._make_folders()

        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())

        tf_utils.show_all_variables()
コード例 #6
0
def train(args):
    """Train a CGAN on MNIST; always export a servable at the end,
    even when training is interrupted with Ctrl-C."""
    batch_size = args.batch_size
    num_epochs = args.epochs
    learning_rate = args.lr

    cgan = CGAN(args.name)
    train_dataset, _ = get_mnist_dataset(batch_size)

    with tf.Session() as sess:
        try:
            cgan.train(sess,
                       train_dataset,
                       base_lr=learning_rate,
                       epochs=num_epochs,
                       save_period=10,
                       reset_logs=args.reset_logs,
                       version=args.version)
        except KeyboardInterrupt:
            print_with_time('Interrupted by user.')
        else:
            print_with_time('Training finished.')
        finally:
            # Export runs whether training completed or was interrupted.
            print_with_time('Saving servable..')
            cgan.export(sess,
                        export_dir=f'{args.name}_export',
                        version=args.version)
コード例 #7
0
def train():
    """Alternate one discriminator step and one generator step over the
    training set, logging losses every 50 iterations and checkpointing
    every conf.save_per_epoch epochs."""
    data = load_data()
    model = CGAN()

    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rateD,
                                   beta1=conf.beta1).minimize(
                                       model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rateG,
                                   beta1=conf.beta1).minimize(
                                       model.g_loss, var_list=model.g_vars)

    saver = tf.train.Saver()

    start_time = time.time()
    # Make sure the checkpoint and output directories exist.
    if not os.path.exists(conf.data_path_checkpoint + "/checkpoint"):
        os.makedirs(conf.data_path_checkpoint + "/checkpoint")
    if not os.path.exists(conf.output_path):
        os.makedirs(conf.output_path)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Either start fresh or resume from an existing checkpoint.
        if conf.model_path_train == "":
            sess.run(tf.global_variables_initializer())
        else:
            saver.restore(sess, conf.model_path_train)

        for epoch in range(conf.max_epoch):
            counter = 0
            for img, cond, name in data["train"]():
                pimg, pcond = prepocess_train(img, cond)
                feed = {model.image: pimg, model.cond: pcond}

                # One discriminator update, then one generator update.
                _, d_loss_val = sess.run([d_opt, model.d_loss], feed_dict=feed)
                _, g_loss_val = sess.run([g_opt, model.g_loss], feed_dict=feed)

                counter += 1
                if counter % 50 == 0:
                    print ("Epoch [%s], Iteration [%s]: time: %s, d_loss: %s, g_loss: %s" \
                      % (epoch, counter, time.time() - start_time, d_loss_val, g_loss_val))
            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(
                    sess, conf.data_path_checkpoint + "/checkpoint/" +
                    "model_%d.ckpt" % (epoch + 1))
                print("Model saved in file: %s" % (save_path))
コード例 #8
0
def train():
    """Train the CGAN; every conf.save_per_epoch epochs, checkpoint the
    model and dump generated test images next to their conditions."""
    data = load_data()
    model = CGAN()

    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(model.g_loss, var_list=model.g_vars)

    saver = tf.train.Saver()

    start_time = time.time()
    # Make sure the checkpoint and output directories exist.
    if not os.path.exists(conf.data_path + "/checkpoint"):
        os.makedirs(conf.data_path + "/checkpoint")
    if not os.path.exists(conf.output_path):
        os.makedirs(conf.output_path)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # Start fresh unless a checkpoint path to resume from is given.
        if conf.model_path_train == "":
            sess.run(tf.global_variables_initializer())
        else:
            saver.restore(sess, conf.model_path_train)

        # Plain range replaces np.arange for a simple integer epoch loop.
        for epoch in range(conf.max_epoch):
            counter = 0
            train_data = data["train"]()
            for img, cond, name in train_data:
                img, cond = prepocess_train(img, cond)
                # The discriminator is stepped twice per generator step
                # (presumably deliberate D/G balancing — confirm).
                _, m = sess.run([d_opt, model.d_loss], feed_dict={model.image:img, model.cond:cond})
                _, m = sess.run([d_opt, model.d_loss], feed_dict={model.image:img, model.cond:cond})
                _, M = sess.run([g_opt, model.g_loss], feed_dict={model.image:img, model.cond:cond})
                counter += 1
                if counter % 50 ==0:
                    print("Epoch [%d], Iteration [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f" % (epoch, counter, time.time() - start_time, m, M))
            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(sess, conf.data_path + "/checkpoint/" + "model_%d.ckpt" % (epoch+1))
                print("Model saved in file: %s" % save_path)
                # Generate test images and save them side by side with
                # their conditions.
                test_data = data["test"]()
                for img, cond, name in test_data:
                    pimg, pcond = prepocess_test(img, cond)
                    gen_img = sess.run(model.gen_img, feed_dict={model.image:pimg, model.cond:pcond})
                    gen_img = gen_img.reshape(gen_img.shape[1:])
                    gen_img = (gen_img + 1.) * 127.5
                    # `np.int` was removed in NumPy 1.24; the builtin
                    # `int` is the exact equivalent.
                    image = np.concatenate((gen_img, cond), axis=1).astype(int)
                    imsave(image, conf.output_path + "/%s" % name)
コード例 #9
0
ファイル: test.py プロジェクト: tingard/PSFGAN
def test(mode):
    """Generate FITS outputs for the `mode` data split from a restored
    CGAN checkpoint.

    Args:
        mode: key into the data dict (e.g. "test") selecting the split.
    """
    data = load_data()
    model = CGAN()

    saver = tf.train.Saver()

    out_dir = conf.result_path
    filter_string = conf.filter_
    if not os.path.exists(conf.save_path):
        os.makedirs(conf.save_path)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    start_epoch = 0
    with tf.Session() as sess:
        saver.restore(sess, conf.model_path)
        # `range` replaces the Python-2-only `xrange`; behaviour is the same.
        for epoch in range(start_epoch, conf.max_epoch):
            # Only emit output for epochs that would have been
            # checkpointed during training.
            if (epoch + 1) % conf.save_per_epoch == 0:
                test_data = data[str(mode)]()
                for img, cond, name in test_data:
                    # Strip the "-<filter>.npy" suffix from the sample name.
                    name = name.replace('-' + filter_string + '.npy', '')
                    pimg, pcond = prepocess_test(img, cond)
                    gen_img = sess.run(model.gen_img,
                                       feed_dict={
                                           model.image: pimg,
                                           model.cond: pcond
                                       })
                    # Drop the leading batch axis.
                    gen_img = gen_img.reshape(gen_img.shape[1:])

                    # Undo the intensity stretch and write channel 0 as FITS.
                    fits_recover = conf.unstretch(gen_img[:, :, 0])
                    hdu = fits.PrimaryHDU(fits_recover)
                    save_dir = '%s/epoch_%s/fits_output' % (out_dir, epoch + 1)
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir)
                    filename = '%s/%s-%s.fits' % (save_dir, name,
                                                  filter_string)
                    # hdu.writeto refuses to overwrite, so remove first.
                    if os.path.exists(filename):
                        os.remove(filename)
                    hdu.writeto(filename)
コード例 #10
0
ファイル: main.py プロジェクト: JunYinDM/VCRO
def train(args):
    """Train either a plain regressor or a CGAN on the donut dataset,
    evaluating on the test split after every epoch."""
    def make_loader(is_train):
        # Train batches drop the ragged tail; eval batches keep it.
        return DataLoader(DonutDataset(root=args.data_root, is_train=is_train),
                          batch_size=args.batch_size,
                          shuffle=True,
                          drop_last=is_train,
                          num_workers=3)

    train_loader = make_loader(True)
    test_loader = make_loader(False)

    builders = {'regressor': Model, 'gan': CGAN}
    if args.model not in builders:
        raise Exception('Not implemented')
    model = builders[args.model](args)

    for epoch in range(args.epochs):
        print("EPOCH: ", epoch)
        model.train_one_epoch(train_loader, epoch)
        model.test_one_epoch(test_loader, epoch)
コード例 #11
0
class Solver(object):
    """Training/evaluation driver for a CGAN vessel-segmentation model.

    Builds the session, dataset and model, alternates discriminator and
    generator updates, periodically samples and evaluates (AUC-ROC +
    AUC-PR and related metrics), and checkpoints whenever the AUC sum
    improves.
    """

    def __init__(self, flags):
        """Set up the TF session, dataset and CGAN model, create output
        folders, and initialize all graph variables."""
        run_config = tf.ConfigProto()
        # Allocate GPU memory on demand instead of grabbing it all upfront.
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

        self.flags = flags
        self.dataset = Dataset(self.flags.dataset, self.flags)
        self.model = CGAN(self.sess, self.flags, self.dataset.image_size)

        # Best auc_roc + auc_pr seen so far; train() saves on improvement.
        self.best_auc_sum = 0.
        self._make_folders()

        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())

        tf_utils.show_all_variables()

    def _make_folders(self):
        """Create output directories, named from dataset, discriminator,
        train interval and batch size: the model dir always; seg-result
        and AUC dirs in test mode; a sample dir in train mode."""
        self.model_out_dir = "{}/model_{}_{}_{}".format(
            self.flags.dataset, self.flags.discriminator,
            self.flags.train_interval, self.flags.batch_size)
        if not os.path.isdir(self.model_out_dir):
            os.makedirs(self.model_out_dir)

        if self.flags.is_test:
            self.img_out_dir = "{}/seg_result_{}_{}_{}".format(
                self.flags.dataset, self.flags.discriminator,
                self.flags.train_interval, self.flags.batch_size)
            self.auc_out_dir = "{}/auc_{}_{}_{}".format(
                self.flags.dataset, self.flags.discriminator,
                self.flags.train_interval, self.flags.batch_size)

            if not os.path.isdir(self.img_out_dir):
                os.makedirs(self.img_out_dir)
            if not os.path.isdir(self.auc_out_dir):
                os.makedirs(self.auc_out_dir)

        elif not self.flags.is_test:
            self.sample_out_dir = "{}/sample_{}_{}_{}".format(
                self.flags.dataset, self.flags.discriminator,
                self.flags.train_interval, self.flags.batch_size)
            if not os.path.isdir(self.sample_out_dir):
                os.makedirs(self.sample_out_dir)

    def train(self):
        """Main training loop.

        Every `train_interval` iterations: sample images, run
        `train_interval` discriminator steps then `train_interval`
        generator steps, evaluate, and checkpoint if the validation AUC
        sum improved.
        """
        for iter_time in range(0, self.flags.iters + 1,
                               self.flags.train_interval):
            self.sample(iter_time)  # sampling images and save them

            # train discrminator
            for iter_ in range(1, self.flags.train_interval + 1):
                x_imgs, y_imgs = self.dataset.train_next_batch(
                    batch_size=self.flags.batch_size)
                d_loss = self.model.train_dis(x_imgs, y_imgs)
                self.print_info(iter_time + iter_, 'd_loss', d_loss)

            # train generator
            for iter_ in range(1, self.flags.train_interval + 1):
                x_imgs, y_imgs = self.dataset.train_next_batch(
                    batch_size=self.flags.batch_size)
                g_loss = self.model.train_gen(x_imgs, y_imgs)
                self.print_info(iter_time + iter_, 'g_loss', g_loss)

            auc_sum = self.eval(iter_time, phase='train')

            # Keep only the best model according to auc_roc + auc_pr.
            if self.best_auc_sum < auc_sum:
                self.best_auc_sum = auc_sum
                self.save_model(iter_time)

    def test(self):
        """Restore the best checkpoint and run a full test evaluation."""
        if self.load_model():
            print(' [*] Load Success!\n')
            self.eval(phase='test')
        else:
            print(' [!] Load Failed!\n')

    def sample(self, iter_time):
        """Every `sample_freq` iterations, segment two random validation
        images, mask and crop them, and save a comparison plot."""
        if np.mod(iter_time, self.flags.sample_freq) == 0:
            # Pick two distinct validation examples at random.
            idx = np.random.choice(self.dataset.num_val, 2, replace=False)
            x_imgs, y_imgs = self.dataset.val_imgs[
                idx], self.dataset.val_vessels[idx]
            samples = self.model.sample_imgs(x_imgs)

            # masking
            seg_samples = utils.remain_in_mask(samples,
                                               self.dataset.val_masks[idx])

            # crop to original image shape
            x_imgs_ = utils.crop_to_original(x_imgs, self.dataset.ori_shape)
            seg_samples_ = utils.crop_to_original(seg_samples,
                                                  self.dataset.ori_shape)
            y_imgs_ = utils.crop_to_original(y_imgs, self.dataset.ori_shape)

            # sampling
            self.plot(x_imgs_,
                      seg_samples_,
                      y_imgs_,
                      iter_time,
                      idx=idx,
                      save_file=self.sample_out_dir,
                      phase='train')

    def plot(self,
             x_imgs,
             samples,
             y_imgs,
             iter_time,
             idx=None,
             save_file=None,
             phase='train'):
        """Save a grid plot of (input | segmentation | ground truth).

        In 'train' phase one comparison PNG is written; in 'test' phase
        both the comparison PNG and the bare segmentation are written.
        `idx` indexes into the val/test mean-std records used to undo
        normalization.
        """
        # initialize grid size
        # NOTE(review): /100 presumably converts pixels to inches at
        # ~100 dpi for the matplotlib figsize — confirm.
        cell_size_h, cell_size_w = self.dataset.ori_shape[
            0] / 100, self.dataset.ori_shape[1] / 100
        num_columns, margin = 3, 0.05
        width = cell_size_w * num_columns
        height = cell_size_h * x_imgs.shape[0]
        fig = plt.figure(figsize=(width, height))  # (column, row)
        gs = gridspec.GridSpec(x_imgs.shape[0], num_columns)  # (row, column)
        gs.update(wspace=margin, hspace=margin)

        # convert from normalized to original image
        x_imgs_norm = np.zeros_like(x_imgs)
        std, mean = 0., 0.
        for _ in range(x_imgs.shape[0]):
            if phase == 'train':
                std = self.dataset.val_mean_std[idx[_]]['std']
                mean = self.dataset.val_mean_std[idx[_]]['mean']
            elif phase == 'test':
                std = self.dataset.test_mean_std[idx[_]]['std']
                mean = self.dataset.test_mean_std[idx[_]]['mean']
            x_imgs_norm[_] = np.expand_dims(x_imgs[_], axis=0) * std + mean
        x_imgs_norm = x_imgs_norm.astype(np.uint8)

        # 1 channel to 3 channels
        samples_3 = np.stack((samples, samples, samples), axis=3)
        y_imgs_3 = np.stack((y_imgs, y_imgs, y_imgs), axis=3)

        # Fill the grid column by column: input, segmentation, truth.
        imgs = [x_imgs_norm, samples_3, y_imgs_3]
        for col_index in range(len(imgs)):
            for row_index in range(x_imgs.shape[0]):
                ax = plt.subplot(gs[row_index * num_columns + col_index])
                plt.axis('off')
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                ax.set_aspect('equal')
                plt.imshow(imgs[col_index][row_index].reshape(
                    self.dataset.ori_shape[0], self.dataset.ori_shape[1], 3),
                           cmap='Greys_r')

        if phase == 'train':
            plt.savefig(save_file +
                        '/{}_{}.png'.format(str(iter_time), idx[0]),
                        bbox_inches='tight')
            plt.close(fig)
        else:
            # save compared image
            plt.savefig(os.path.join(
                save_file, 'compared_{}.png'.format(
                    os.path.basename(
                        self.dataset.test_img_files[idx[0]])[:-4])),
                        bbox_inches='tight')
            plt.close(fig)

            # save vessel alone, vessel should be uint8 type
            Image.fromarray(np.squeeze(samples * 255).astype(np.uint8)).save(
                os.path.join(
                    save_file, '{}.png'.format(
                        os.path.basename(
                            self.dataset.test_img_files[idx[0]][:-4]))))

    def print_info(self, iter_time, name, loss):
        """Every `print_freq` iterations, print the given loss along
        with the run configuration."""
        if np.mod(iter_time, self.flags.print_freq) == 0:
            ord_output = collections.OrderedDict([
                (name, loss), ('dataset', self.flags.dataset),
                ('discriminator', self.flags.discriminator),
                ('train_interval', np.float32(self.flags.train_interval)),
                ('gpu_index', self.flags.gpu_index)
            ])
            utils.print_metrics(iter_time, ord_output)

    def eval(self, iter_time=0, phase='train'):
        """Segment the whole val ('train' phase) or test split, compute
        metrics via measure(), and in test phase also save per-image
        plots.

        Returns:
            auc_sum (auc_roc + auc_pr), or 0. when this iteration is
            skipped by `eval_freq`.
        """
        total_time, auc_sum = 0., 0.
        if np.mod(iter_time, self.flags.eval_freq) == 0:
            num_data, imgs, vessels, masks = None, None, None, None
            if phase == 'train':
                num_data = self.dataset.num_val
                imgs = self.dataset.val_imgs
                vessels = self.dataset.val_vessels
                masks = self.dataset.val_masks
            elif phase == 'test':
                num_data = self.dataset.num_test
                imgs = self.dataset.test_imgs
                vessels = self.dataset.test_vessels
                masks = self.dataset.test_masks

            generated = []
            for iter_ in range(num_data):
                x_img = imgs[iter_]
                x_img = np.expand_dims(x_img,
                                       axis=0)  # (H, W, C) to (1, H, W, C)

                # measure inference time
                start_time = time.time()
                generated_vessel = self.model.sample_imgs(x_img)
                total_time += (time.time() - start_time)

                generated.append(np.squeeze(
                    generated_vessel, axis=(0, 3)))  # (1, H, W, 1) to (H, W)

            generated = np.asarray(generated)
            # calculate measurements
            auc_sum = self.measure(generated, vessels, masks, num_data,
                                   iter_time, phase, total_time)

            if phase == 'test':
                # save test images
                segmented_vessel = utils.remain_in_mask(generated, masks)

                # crop to original image shape
                imgs_ = utils.crop_to_original(imgs, self.dataset.ori_shape)
                cropped_vessel = utils.crop_to_original(
                    segmented_vessel, self.dataset.ori_shape)
                vessels_ = utils.crop_to_original(vessels,
                                                  self.dataset.ori_shape)

                # One plot per test image; the literal 'test' stands in
                # for iter_time in the output filename.
                for idx in range(num_data):
                    self.plot(np.expand_dims(imgs_[idx], axis=0),
                              np.expand_dims(cropped_vessel[idx], axis=0),
                              np.expand_dims(vessels_[idx], axis=0),
                              'test',
                              idx=[idx],
                              save_file=self.img_out_dir,
                              phase='test')

        return auc_sum

    def measure(self, generated, vessels, masks, num_data, iter_time, phase,
                total_time):
        """Compute and print segmentation metrics (AUCs, Dice, accuracy,
        sensitivity, specificity, avg inference time).

        Returns:
            auc_sum = auc_roc + auc_pr, the model-selection criterion.
        """
        # masking
        vessels_in_mask, generated_in_mask = utils.pixel_values_in_mask(
            vessels, generated, masks)

        # averaging processing time
        avg_pt = (total_time / num_data) * 1000  # average processing tiem

        # evaluate Area Under the Curve of ROC and Precision-Recall
        auc_roc = utils.AUC_ROC(vessels_in_mask, generated_in_mask)
        auc_pr = utils.AUC_PR(vessels_in_mask, generated_in_mask)

        # binarize to calculate Dice Coeffient
        binarys_in_mask = utils.threshold_by_otsu(generated, masks)
        dice_coeff = utils.dice_coefficient_in_train(vessels_in_mask,
                                                     binarys_in_mask)
        acc, sensitivity, specificity = utils.misc_measures(
            vessels_in_mask, binarys_in_mask)
        score = auc_pr + auc_roc + dice_coeff + acc + sensitivity + specificity

        # auc_sum for saving best model in training
        auc_sum = auc_roc + auc_pr

        # print information
        ord_output = collections.OrderedDict([('auc_pr', auc_pr),
                                              ('auc_roc', auc_roc),
                                              ('dice_coeff', dice_coeff),
                                              ('acc', acc),
                                              ('sensitivity', sensitivity),
                                              ('specificity', specificity),
                                              ('score', score),
                                              ('auc_sum', auc_sum),
                                              ('best_auc_sum',
                                               self.best_auc_sum),
                                              ('avg_pt', avg_pt)])
        utils.print_metrics(iter_time, ord_output)

        # write in tensorboard when in train mode only
        if phase == 'train':
            self.model.measure_assign(auc_pr, auc_roc, dice_coeff, acc,
                                      sensitivity, specificity, score,
                                      iter_time)
        elif phase == 'test':
            # write in npy format for evaluation
            utils.save_obj(vessels_in_mask, generated_in_mask,
                           os.path.join(self.auc_out_dir, "auc_roc.npy"),
                           os.path.join(self.auc_out_dir, "auc_pr.npy"))

        return auc_sum

    def save_model(self, iter_time):
        """Checkpoint the model, recording the current best_auc_sum in
        the graph and in the checkpoint filename."""
        self.model.best_auc_sum_assign(self.best_auc_sum)

        model_name = "iter_{}_auc_sum_{:.3}".format(iter_time,
                                                    self.best_auc_sum)
        self.saver.save(self.sess, os.path.join(self.model_out_dir,
                                                model_name))

        print('===================================================')
        print('                     Model saved!                  ')
        print(' Best auc_sum: {:.3}'.format(self.best_auc_sum))
        print('===================================================\n')

    def load_model(self):
        """Restore the latest checkpoint from model_out_dir.

        Returns:
            True when a checkpoint was found and restored, else False.
        """
        print(' [*] Reading checkpoint...')

        ckpt = tf.train.get_checkpoint_state(self.model_out_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess,
                               os.path.join(self.model_out_dir, ckpt_name))

            # Recover the best score stored in the graph at save time.
            self.best_auc_sum = self.sess.run(self.model.best_auc_sum)
            print('====================================================')
            print('                     Model saved!                   ')
            print(' Best auc_sum: {:.3}'.format(self.best_auc_sum))
            print('====================================================')

            return True
        else:
            return False
コード例 #12
0
def run_gan():
    """Build, optionally smoke-test, and train a conditional GAN on
    MNIST, then plot the generator and discriminator loss curves."""
    (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

    # Reshape to NHWC float32 and normalize pixels to [-1, 1].
    train_images = train_images.reshape(train_images.shape[0], 28, 28,
                                        1).astype('float32')
    train_images = (train_images - 127.5) / 127.5  # Normalize images to [-1,1]
    print(train_images.shape)

    # One-hot encode the digit labels for conditioning.
    train_labels = to_categorical(train_labels)
    print(train_labels.shape)

    # Batch and shuffle the data
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, train_labels)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

    gan = CGAN(gen_lr, disc_lr, noise_dim=NOISE_DIM)
    gan.create_generator()
    gan.create_discriminator()

    if model_test:
        # Smoke-test the generator on a single noise vector ...
        noise = tf.random.normal([1, NOISE_DIM])
        condition = tf.zeros(shape=(1, 10))
        fake_img = gan.generator([noise, condition])
        plt.imshow(fake_img[0, :, :, 0], cmap='gray')
        plt.show()
        # ... and the discriminator on its output.
        prob = gan.discriminator([fake_img, condition])
        print("Probability of image being real: {}".format(sigmoid(prob)))

    gan.set_noise_seed(num_examples_to_generate)
    print(gan.label_seed.shape)
    gan.set_checkpoint(path=save_ckpt_path)
    gen_loss_array, disc_loss_array = gan.train(train_dataset, epochs=EPOCHS)

    # Overlay both loss curves on one figure.
    plt.plot(range(EPOCHS), gen_loss_array)
    plt.plot(range(EPOCHS), disc_loss_array)
    plt.show()
コード例 #13
0
ファイル: train.py プロジェクト: tingard/PSFGAN
def train(evalset):
    """Train the CGAN, checkpointing every conf.save_per_epoch epochs
    and writing FITS outputs for the `evalset` split after each save.

    When resuming from a checkpoint, the starting epoch is read back
    from the log file in conf.save_path.
    """
    data = load_data()
    model = CGAN()

    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(model.g_loss, var_list=model.g_vars)

    saver = tf.train.Saver()

    counter = 0
    start_time = time.time()
    out_dir = conf.result_path
    filter_string = conf.filter_
    if not os.path.exists(conf.save_path):
        os.makedirs(conf.save_path)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    start_epoch = 0
    with tf.Session() as sess:
        if conf.model_path == "":
            sess.run(tf.global_variables_initializer())
        else:
            saver.restore(sess, conf.model_path)
            # Resume from the last logged epoch; a missing or corrupt
            # log just restarts the epoch count. `with` closes the file
            # even when int() raises (the original bare-except leaked
            # the handle), and the narrowed except no longer hides
            # unrelated bugs.
            try:
                with open(conf.save_path + "/log") as log:
                    start_epoch = int(log.readline())
            except (IOError, OSError, ValueError):
                pass
        # `range` replaces the Python-2-only `xrange`.
        for epoch in range(start_epoch, conf.max_epoch):
            train_data = data["train"]()
            for img, cond, _ in train_data:
                img, cond = prepocess_train(img, cond)
                # Two discriminator updates per generator update
                # (presumably deliberate D/G balancing — confirm).
                _, m = sess.run([d_opt, model.d_loss], feed_dict={model.image: img, model.cond: cond})
                _, m = sess.run([d_opt, model.d_loss], feed_dict={model.image: img, model.cond: cond})
                _, M, flux = sess.run([g_opt, model.g_loss, model.delta],
                                      feed_dict={model.image: img, model.cond: cond})
                counter += 1
                print("Iterate [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f, flux: %.8f" \
                      % (counter, time.time() - start_time, m, M, flux))
            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(sess, conf.save_path + "/model.ckpt")
                print("Model at epoch %s saved in file: %s" % (epoch + 1, save_path))

                # Record the completed epoch so a later run can resume.
                with open(conf.save_path + "/log", "w") as log:
                    log.write(str(epoch + 1))

                # Generate FITS outputs for the evaluation split.
                test_data = data[str(evalset)]()
                for img, cond, name in test_data:
                    name = name.replace('-'+filter_string+'.npy', '')
                    pimg, pcond = prepocess_test(img, cond)
                    gen_img = sess.run(model.gen_img, feed_dict={model.image: pimg, model.cond: pcond})
                    gen_img = gen_img.reshape(gen_img.shape[1:])

                    fits_recover = conf.unstretch(gen_img[:, :, 0])
                    hdu = fits.PrimaryHDU(fits_recover)
                    save_dir = '%s/epoch_%s/fits_output' % (out_dir, epoch + 1)
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir)
                    filename = '%s/%s-%s.fits' % (save_dir, name, filter_string)
                    # hdu.writeto refuses to overwrite, so remove first.
                    if os.path.exists(filename):
                        os.remove(filename)
                    hdu.writeto(filename)
コード例 #14
0
def train():
    """Train the CGAN, checkpointing and evaluating PSNR periodically.

    Relies on module-level helpers/config: ``load_data``, ``prepocess_train``,
    ``prepocess_test``, ``imsave`` and the global ``conf`` object.  Every
    ``conf.save_per_epoch`` epochs it saves a numbered checkpoint, writes
    side-by-side (generated | condition) images to ``conf.output_path`` and
    records the mean PSNR of the generated images against both the target
    and the conditioning image.
    """
    data = load_data()
    model = CGAN()

    # Separate Adam optimizers so each step only updates D's or G's variables.
    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)

    saver = tf.train.Saver()

    counter = 0
    start_time = time.time()
    if not os.path.exists(conf.data_path + "/checkpoint"):
        os.makedirs(conf.data_path + "/checkpoint")
    if not os.path.exists(conf.output_path):
        os.makedirs(conf.output_path)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # allocate GPU memory on demand
    mpsnr_img = []   # per-evaluation mean PSNR vs. the target image
    mpsnr_cond = []  # per-evaluation mean PSNR vs. the conditioning image
    with tf.Session(config=config) as sess:
        if conf.model_path == "":
            # Fix: tf.initialize_all_variables() was deprecated and removed in
            # TF >= 1.0; use the replacement (also consistent with the other
            # train() variant in this file).
            sess.run(tf.global_variables_initializer())
        else:
            saver.restore(sess, conf.model_path)
        print(conf.max_epoch)
        for epoch in range(conf.max_epoch):
            train_data = data["train"]()
            for img, cond, name in train_data:
                img, cond = prepocess_train(img, cond)
                # Two discriminator updates per generator update.
                _, d_loss = sess.run([d_opt, model.d_loss],
                                     feed_dict={model.image: img,
                                                model.cond: cond})
                _, d_loss = sess.run([d_opt, model.d_loss],
                                     feed_dict={model.image: img,
                                                model.cond: cond})
                _, g_loss = sess.run([g_opt, model.g_loss],
                                     feed_dict={model.image: img,
                                                model.cond: cond})
                counter += 1
                print("Iterate [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (counter, time.time() - start_time, d_loss, g_loss))
            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(
                    sess, conf.data_path + "/checkpoint/" + "model_%d.ckpt" %
                    (epoch + 1))
                print("Model saved in file: %s" % save_path)
                psnr_img_sum = 0.0
                psnr_cond_sum = 0.0
                n_samples = 0
                test_data = data["test"]()
                for img, cond, name in test_data:
                    pimg, pcond = prepocess_test(img, cond)
                    gen_img = sess.run(model.gen_img,
                                       feed_dict={model.image: pimg,
                                                  model.cond: pcond})
                    gen_img = gen_img.reshape(gen_img.shape[1:])  # drop batch dim
                    gen_img = (gen_img + 1.) * 127.5  # [-1, 1] -> [0, 255]
                    psnr_img_sum += skimage.measure.compare_psnr(
                        img, gen_img.astype(np.uint8))
                    psnr_cond_sum += skimage.measure.compare_psnr(
                        cond, gen_img.astype(np.uint8))
                    image = np.concatenate((gen_img, cond),
                                           axis=1).astype(np.int)
                    n_samples += 1
                    imsave(image, conf.output_path + "/%s" % name)
                # Fix: the original divided by the sample count unconditionally
                # and crashed with ZeroDivisionError on an empty test set.
                if n_samples:
                    mpsnr_img.append(psnr_img_sum / n_samples)
                    mpsnr_cond.append(psnr_cond_sum / n_samples)
        print(mpsnr_cond)
        print(mpsnr_img)
        plt.plot(mpsnr_cond)
        plt.show()
コード例 #15
0
ファイル: train.py プロジェクト: Ireneruru/GalaxyGAN_python
def train():
    """Train the flux-preserving CGAN with crash/resume support.

    Reads ``<conf.save_path>/log`` to resume from the last completed epoch,
    keeps a single rolling checkpoint at ``<conf.save_path>/model.ckpt``, and
    every ``conf.save_per_epoch`` epochs dumps (generated | condition) arrays
    as ``.npy`` files to ``conf.output_path``.  Relies on module-level helpers
    ``load_data``, ``prepocess_train``, ``prepocess_test`` and ``conf``.
    """
    data = load_data()
    model = CGAN()

    # Separate Adam optimizers so each step only updates D's or G's variables.
    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)

    saver = tf.train.Saver()

    counter = 0
    start_time = time.time()
    if not os.path.exists(conf.save_path):
        os.makedirs(conf.save_path)
    if not os.path.exists(conf.output_path):
        os.makedirs(conf.output_path)

    start_epoch = 0
    # Fix: the original bare `except:` also swallowed KeyboardInterrupt and
    # SystemExit and leaked the file handle; catch only the expected failures
    # (file missing/unreadable, unparsable first line) and close via `with`.
    try:
        with open(conf.save_path + "/log") as log:
            start_epoch = int(log.readline())
    except (IOError, OSError, ValueError):
        pass  # no resume log yet: start from epoch 0

    with tf.Session() as sess:
        if conf.model_path == "":
            sess.run(tf.global_variables_initializer())
        else:
            saver.restore(sess, conf.model_path)
        for epoch in range(start_epoch, conf.max_epoch):
            train_data = data["train"]()
            for img, cond, _ in train_data:
                img, cond = prepocess_train(img, cond)
                # Two discriminator updates per generator update; model.delta
                # reports the flux-preservation term alongside the G loss.
                _, d_loss = sess.run([d_opt, model.d_loss],
                                     feed_dict={model.image: img,
                                                model.cond: cond})
                _, d_loss = sess.run([d_opt, model.d_loss],
                                     feed_dict={model.image: img,
                                                model.cond: cond})
                _, g_loss, flux = sess.run([g_opt, model.g_loss, model.delta],
                                           feed_dict={model.image: img,
                                                      model.cond: cond})
                counter += 1
                print("Iterate [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f, flux: %.8f"
                      % (counter, time.time() - start_time, d_loss, g_loss, flux))
            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(sess, conf.save_path + "/model.ckpt")
                print("Model saved in file: %s" % save_path)

                # Record the last finished epoch so a restart resumes here.
                with open(conf.save_path + "/log", "w") as log:
                    log.write(str(epoch + 1))

                test_data = data["test"]()
                test_count = 0
                for img, cond, name in test_data:
                    test_count += 1
                    pimg, pcond = prepocess_test(img, cond)
                    gen_img = sess.run(model.gen_img,
                                       feed_dict={model.image: pimg,
                                                  model.cond: pcond})
                    gen_img = gen_img.reshape(gen_img.shape[1:])  # drop batch dim
                    image = np.concatenate((gen_img, cond), axis=1)
                    np.save(conf.output_path + "/" + name, image)
コード例 #16
0
ファイル: infer.py プロジェクト: hyes92121/MLDS2018SPRING-1

def cvt_output(model_output):
    """Convert a one-image generator output tensor to a channels-last RGB image.

    Takes batch entry 0, moves channels to the last axis, rescales the
    values from [-1, 1] to [0, 1], and swaps BGR -> RGB for display.
    """
    first = model_output.data.numpy()[0]  # (C, H, W), batch entry 0
    hwc = np.transpose(first, (1, 2, 0))  # channels-last, as OpenCV expects
    rescaled = hwc * 0.5 + 0.5            # tanh range [-1, 1] -> [0, 1]
    return cv2.cvtColor(rescaled, cv2.COLOR_BGR2RGB)


# Load the checkpoint onto the CPU regardless of the device it was saved
# from (the map_location lambda discards the saved device).
# NOTE(review): torch.load unpickles arbitrary objects -- only load
# model.tar from a trusted source.
state = torch.load('model.tar', map_location=lambda storage, loc: storage)
config = state['config']  # hyperparameters saved alongside the weights

model = CGAN(config)
model.load_state_dict(state['state_dict'])
model.eval()  # inference mode: fixed dropout / BN statistics

# Vocabulary pickled at training time; used to encode the text conditions.
with open('vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

conditions = []

# Parse the condition file given as argv[1].  Assumes each line looks like
# "<id>,<hair-color> hair <eye-color> eyes": after splitting on ',' and
# whitespace, token 0 is the hair color and token 2 the eye color --
# TODO confirm against the actual input format.
with open(sys.argv[1]) as f:
    for line in f:
        line = line.split(',')[1]
        line = line.split()
        conditions.append({'hair': [line[0]], 'eyes': [line[2]]})

generated_imgs = []
コード例 #17
0
def train():
    """Train the CGAN and dump (generated | condition) JPEGs periodically.

    Relies on module-level helpers/config: ``load_data``, ``prepocess_train``,
    ``prepocess_test``, ``imsave`` and the global ``conf`` object.  Every
    ``conf.save_per_epoch`` epochs it saves a numbered checkpoint and writes
    one side-by-side image per test sample to ``conf.output_path``.
    """
    data = load_data()
    model = CGAN()

    # Separate Adam optimizers so each step only updates D's or G's variables.
    d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.d_loss, var_list=model.d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(
        model.g_loss, var_list=model.g_vars)

    saver = tf.train.Saver()

    counter = 0
    start_time = time.time()
    if not os.path.exists(conf.data_path + "/checkpoint"):
        os.makedirs(conf.data_path + "/checkpoint")
    if not os.path.exists(conf.output_path):
        os.makedirs(conf.output_path)

    with tf.Session() as sess:
        if conf.model_path == "":
            # Fix: tf.initialize_all_variables() was deprecated and removed in
            # TF >= 1.0; global_variables_initializer is the drop-in
            # replacement.
            sess.run(tf.global_variables_initializer())
        else:
            saver.restore(sess, conf.model_path)
        for epoch in range(conf.max_epoch):
            # NOTE(review): unlike the sibling train() variants, data["train"]
            # is not called here and items unpack as 2-tuples -- presumably
            # this version's load_data() returns iterables directly; confirm
            # before "fixing" to data["train"]().
            train_data = data["train"]
            for img, cond in train_data:
                img, cond = prepocess_train(img, cond)
                # Two discriminator updates per generator update.
                _, d_loss = sess.run([d_opt, model.d_loss],
                                     feed_dict={model.image: img,
                                                model.cond: cond})
                _, d_loss = sess.run([d_opt, model.d_loss],
                                     feed_dict={model.image: img,
                                                model.cond: cond})
                _, g_loss = sess.run([g_opt, model.g_loss],
                                     feed_dict={model.image: img,
                                                model.cond: cond})
                counter += 1
                print("Iterate [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (counter, time.time() - start_time, d_loss, g_loss))
            if (epoch + 1) % conf.save_per_epoch == 0:
                save_path = saver.save(
                    sess, conf.data_path + "/checkpoint/" + "model_%d.ckpt" %
                    (epoch + 1))
                print("Model saved in file: %s" % save_path)
                test_data = data["test"]
                test_count = 0
                for img, cond in test_data:
                    test_count += 1
                    pimg, pcond = prepocess_test(img, cond)
                    gen_img = sess.run(model.gen_img,
                                       feed_dict={model.image: pimg,
                                                  model.cond: pcond})
                    gen_img = gen_img.reshape(gen_img.shape[1:])  # drop batch dim
                    gen_img = (gen_img + 1.) * 127.5  # [-1, 1] -> [0, 255]
                    image = np.concatenate((gen_img, cond),
                                           axis=1).astype(np.int)
                    imsave(image, conf.output_path + "/%d.jpg" % test_count)