def save_images(self, nets, epoch, curr_tr, images_i, images_j):
        """
        Write a few sample inputs/outputs for one epoch plus an HTML index.

        The networks are put into "eval" mode, then
        ``self._num_imgs_to_save`` shuffled pairs are drawn from
        ``images_i`` / ``images_j`` (batch size 1), run through the model and
        passed to ``self.figure_writer`` together with the open
        ``epoch_<epoch>.html`` file located in ``self._output_dir``.

        Parameters
        ----------
        nets
            Networks handed to ``utils.set_mode`` and ``model.get_outputs``.
        epoch
            Used in the HTML file name and forwarded to the figure writer.
        curr_tr
            Transition rate fed to the model; the normalisation flag is set
            only while this value is not greater than zero.
        images_i, images_j
            The two image collections iterated with ``minibatches``.
        """
        nets = utils.set_mode(nets, "eval")
        exists_or_mkdir(self._images_dir)

        # Normalisation is switched off as soon as the transition rate is > 0.
        donorm = not curr_tr > 0

        names = [
            'inputA_', 'inputB_', 'fakeA_', 'fakeB_', 'cycA_', 'cycB_',
            'mask_a', 'mask_b'
        ]

        html_path = os.path.join(self._output_dir,
                                 'epoch_' + str(epoch) + '.html')
        with open(html_path, 'w') as v_html:
            pair_iter = minibatches(images_i,
                                    images_j,
                                    batch_size=1,
                                    shuffle=True)

            for idx in range(0, self._num_imgs_to_save):
                print("Saving image {}/{}...".format(idx,
                                                     self._num_imgs_to_save))
                self.image_a, self.image_b = next(pair_iter)

                # Refresh the fake-image pool entries read by input_converter.
                pool_a = self.get_fake_image_pool(self.num_fake_inputs,
                                                  self.fake_images_A)
                self.fake_pool_A_mask = pool_a["mask"]
                self.fake_pool_A = pool_a["im"]

                pool_b = self.get_fake_image_pool(self.num_fake_inputs,
                                                  self.fake_images_B)
                self.fake_pool_B_mask = pool_b["mask"]
                self.fake_pool_B = pool_b["im"]

                self.transition_rate = np.array([curr_tr], dtype=np.float32)
                self.donorm = np.array([donorm], dtype=np.float32)

                outputs = model.get_outputs(self.input_converter(), nets)
                self.output_converter(outputs)

                # Order must line up one-to-one with ``names`` above.
                figures_to_save = [
                    self.image_a,
                    self.image_b,
                    self.fake_images_b,
                    self.fake_images_a,
                    self.cycle_images_a,
                    self.cycle_images_b,
                    self.masks[0],
                    self.masks[1],
                ]

                self.figure_writer(figures_to_save,
                                   names,
                                   v_html,
                                   epoch,
                                   idx,
                                   html_mode=0)
def main():
    """Parse the command line, load the JSON config and train or test a CycleGAN.

    Values missing from the config fall back to defaults: ``_LAMBDA_A`` and
    ``_LAMBDA_B`` to 10.0, ``pool_size`` to 50, ``base_lr`` to 0.0002 and
    ``max_step`` to 200. ``dataset_name`` and ``do_flipping`` are required
    keys.

    ``args.to_train`` selects the mode: a value greater than 0 trains
    (2 additionally restores from a checkpoint), anything else runs the test
    routine. A checkpoint name is mandatory unless ``to_train == 1``.

    Exits with status 0 when ``parse_args`` returns ``None`` and with status 1
    when a required checkpoint name is missing.
    """
    # Local import: ``exit()`` is an interactive helper injected by ``site``
    # and is not guaranteed to exist; ``sys.exit`` is the supported call.
    import sys

    args = parse_args()
    if args is None:
        sys.exit()

    to_train = args.to_train
    log_dir = args.log_dir
    config_filename = args.config_filename
    checkpoint_dir = args.checkpoint_dir
    skip = args.skip
    switch = args.switch
    threshold_fg = args.threshold

    exists_or_mkdir(log_dir)

    with open(config_filename) as config_file:
        config = json.load(config_file)

    lambda_a = float(config.get('_LAMBDA_A', 10.0))
    lambda_b = float(config.get('_LAMBDA_B', 10.0))
    pool_size = int(config.get('pool_size', 50))

    to_restore = (to_train == 2)
    base_lr = float(config.get('base_lr', 0.0002))
    max_step = int(config.get('max_step', 200))
    dataset_name = str(config['dataset_name'])
    do_flipping = bool(config['do_flipping'])
    checkpoint_name = args.checkpoint_name

    if checkpoint_name == '' and to_train != 1:
        print("Error: please provide the latest checkpoint name.")
        # Non-zero status so calling scripts can detect the failure.
        sys.exit(1)

    cyclegan_model = CycleGAN(pool_size, lambda_a, lambda_b, log_dir,
                              to_restore, checkpoint_name, base_lr, max_step,
                              dataset_name, checkpoint_dir, do_flipping, skip,
                              switch, threshold_fg)

    if to_train > 0:
        cyclegan_model.train()
    else:
        cyclegan_model.test()
Beispiel #3
0
    def find_top_model(self, sess, sort=None, model_name='model', **kwargs):
        """Finds and returns a model architecture and its parameters from the database which matches the requirement.

        Parameters
        ----------
        sess : Session
            TensorFlow session.
        sort : List of tuple
            PyMongo sort argument, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        model_name : str or None
            The name/key of model.
        kwargs : other events
            Other events, such as name, accuracy, loss, step number and etc (optional).

        Examples
        ---------
        - see ``save_model``.

        Returns
        ---------
        network : TensorLayer layer or False
            ``False`` when no document matches or when loading the model fails (the failure is logged).
            Note that, the returned network contains all information of the document (record), e.g. if you saved accuracy in the document, you can get the accuracy by using ``net._accuracy``.
        """
        # Local import: only needed for the best-effort cleanup below.
        import shutil

        kwargs.update({'model_name': model_name})
        self._fill_project_info(kwargs)

        s = time.time()

        d = self.db.Model.find_one(filter=kwargs, sort=sort)
        if d is None:
            print("[Database] FAIL! Cannot find model: {}".format(kwargs))
            return False

        params_id = d['params_id']
        graphs = d['architecture']
        _datetime = d['time']

        # The graph and parameters are staged on disk because
        # load_graph_and_params reads them from a folder.
        _temp_file_name = '_find_one_model_ztemp_file'
        exists_or_mkdir(_temp_file_name, False)
        try:
            with open(os.path.join(_temp_file_name, 'graph.pkl'), 'wb') as file:
                pickle.dump(graphs, file, protocol=pickle.HIGHEST_PROTOCOL)

            try:
                params = self._deserialization(self.model_fs.get(params_id).read())
                np.savez(os.path.join(_temp_file_name, 'params.npz'), params=params)

                network = load_graph_and_params(name=_temp_file_name, sess=sess)
                del_folder(_temp_file_name)

                pc = self.db.Model.find(kwargs)
                print(
                    "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s".
                    format(kwargs, sort, _datetime, round(time.time() - s, 2))
                )

                # put all informations of model into the TL layer
                for key in d:
                    network.__dict__.update({"_%s" % key: d[key]})

                # check whether more parameters match the requirement
                params_id_list = pc.distinct('params_id')
                n_params = len(params_id_list)
                if n_params != 1:
                    print("     Note that there are {} models match the kwargs".format(n_params))
                return network
            except Exception as e:
                exc_type, exc_obj, exc_tb = sys.exc_info()
                fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
                logging.info("{}  {}  {}  {}  {}".format(exc_type, exc_obj, fname, exc_tb.tb_lineno, e))
                return False
        finally:
            # Safety net: never leave the staging folder behind when a step
            # above fails before del_folder runs. No-op if already removed.
            shutil.rmtree(_temp_file_name, ignore_errors=True)
Beispiel #4
0
    def find_top_model(self, sess, sort=None, model_name='model', **kwargs):
        """Finds and returns a model architecture and its parameters from the database which matches the requirement.

        Parameters
        ----------
        sess : Session
            TensorFlow session.
        sort : List of tuple
            PyMongo sort argument, search "PyMongo find one sorting" and `collection level operations <http://api.mongodb.com/python/current/api/pymongo/collection.html>`__ for more details.
        model_name : str or None
            The name/key of model.
        kwargs : other events
            Other events, such as name, accuracy, loss, step number and etc (optional).

        Examples
        ---------
        - see ``save_model``.

        Returns
        ---------
        network : TensorLayer layer or False
            ``False`` when no document matches or when loading the model fails (the failure is logged).
            Note that, the returned network contains all information of the document (record), e.g. if you saved accuracy in the document, you can get the accuracy by using ``net._accuracy``.
        """
        # Local import: only needed for the best-effort cleanup below.
        import shutil

        kwargs.update({'model_name': model_name})
        self._fill_project_info(kwargs)

        s = time.time()

        d = self.db.Model.find_one(filter=kwargs, sort=sort)
        if d is None:
            print("[Database] FAIL! Cannot find model: {}".format(kwargs))
            return False

        params_id = d['params_id']
        graphs = d['architecture']
        _datetime = d['time']

        # The graph and parameters are staged on disk because
        # load_graph_and_params reads them from a folder.
        _temp_file_name = '_find_one_model_ztemp_file'
        exists_or_mkdir(_temp_file_name, False)
        try:
            with open(os.path.join(_temp_file_name, 'graph.pkl'),
                      'wb') as file:
                pickle.dump(graphs, file, protocol=pickle.HIGHEST_PROTOCOL)

            try:
                params = self._deserialization(
                    self.model_fs.get(params_id).read())
                np.savez(os.path.join(_temp_file_name, 'params.npz'),
                         params=params)

                network = load_graph_and_params(name=_temp_file_name,
                                                sess=sess)
                del_folder(_temp_file_name)

                pc = self.db.Model.find(kwargs)
                print(
                    "[Database] Find one model SUCCESS. kwargs:{} sort:{} save time:{} took: {}s"
                    .format(kwargs, sort, _datetime,
                            round(time.time() - s, 2)))

                # put all informations of model into the TL layer
                for key in d:
                    network.__dict__.update({"_%s" % key: d[key]})

                # check whether more parameters match the requirement
                params_id_list = pc.distinct('params_id')
                n_params = len(params_id_list)
                if n_params != 1:
                    print("     Note that there are {} models match the kwargs".
                          format(n_params))
                return network
            except Exception as e:
                exc_type, exc_obj, exc_tb = sys.exc_info()
                fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
                logging.info("{}  {}  {}  {}  {}".format(
                    exc_type, exc_obj, fname, exc_tb.tb_lineno, e))
                return False
        finally:
            # Safety net: never leave the staging folder behind when a step
            # above fails before del_folder runs. No-op if already removed.
            shutil.rmtree(_temp_file_name, ignore_errors=True)
    def train(self):
        """
        Training Function.
        We use batch size = 1 for training.

        Builds the networks with ``self.model_setup()``, optionally restores a
        checkpoint, then for every epoch from ``self.global_step`` up to
        ``self._max_step``: writes a checkpoint, sets the learning rate on all
        six trainers, reloads the dataset, dumps sample images via
        ``self.save_images`` and performs one update of both discriminators
        and both generators per image pair. Losses and the learning rate are
        logged as TensorFlow summaries under ``<output_dir>/log``.

        The learning rate equals ``self._base_lr`` during the first half of
        training and is reduced linearly afterwards. Before epoch
        ``self._switch`` the primary generator trainers are used with
        transition rate 0; from that epoch on the ``*_bis`` trainers are used
        with transition rate ``self._threshold_fg``.

        Raises
        ------
        ValueError
            When restoring and the checkpoint path does not end with the
            epoch number.
        AssertionError
            When the loaded image lists differ in length or do not match the
            size registered for the dataset.
        """

        # Build the network and compute losses
        nets = self.model_setup()

        summary_writer = tf.summary.create_file_writer(
            os.path.join(self._output_dir, "log"))
        summary_writer.set_as_default()

        max_images = cyclegan_datasets.DATASET_TO_SIZES[self._dataset_name]
        half_training_ep = int(self._max_step / 2)

        # Restore the model to run the model from last checkpoint
        if self._to_restore:
            print("Loading the latest checkpoint...")
            checkpoint_name = os.path.join(self._checkpoint_dir,
                                           self._checkpoint_name)
            utils.load(checkpoint_name, nets=nets)
            # Checkpoints are written below as "AGGAN_%02d" % epoch, which
            # yields three or more digits from epoch 100 on, so read the whole
            # trailing digit run instead of a fixed two characters.
            stem = checkpoint_name.rstrip('0123456789')
            epoch_digits = checkpoint_name[len(stem):]
            if not epoch_digits:
                raise ValueError(
                    "checkpoint name {!r} must end with the epoch number".
                    format(checkpoint_name))
            self.global_step = int(epoch_digits)
        else:
            self.global_step = 0

        exists_or_mkdir(self._output_dir)

        self.upd_fake_image_pool(self.num_fake_inputs, self.fake_pool_A,
                                 self.fake_pool_A_mask, self.fake_images_A)
        self.upd_fake_image_pool(self.num_fake_inputs, self.fake_pool_B,
                                 self.fake_pool_B_mask, self.fake_images_B)
        self.num_fake_inputs += 1

        # Training Loop
        for epoch in range(self.global_step, self._max_step):
            print("In the epoch ", epoch)
            print("Saving the latest checkpoint...")
            utils.save(nets,
                       os.path.join(self._output_dir, "AGGAN_%02d" % epoch))

            # Setting lr: constant for the first half, linear decay afterwards.
            # The guard avoids a division by zero when _max_step < 2.
            curr_lr = self._base_lr
            if half_training_ep > 0 and epoch >= half_training_ep:
                curr_lr -= self._base_lr * (
                    epoch - half_training_ep) / half_training_ep
            for trainer in (self.g_A_trainer, self.g_B_trainer,
                            self.g_A_trainer_bis, self.g_B_trainer_bis,
                            self.d_A_trainer, self.d_B_trainer):
                trainer.learning_rate = curr_lr

            if epoch < self._switch:
                curr_tr = 0
                donorm = True
                to_train_A = self.g_A_trainer
                to_train_B = self.g_B_trainer
                to_train_A_vars = self.g_A_vars + self.g_Ae_vars
                to_train_B_vars = self.g_B_vars + self.g_Be_vars
            else:
                curr_tr = self._threshold_fg
                donorm = False
                to_train_A = self.g_A_trainer_bis
                to_train_B = self.g_B_trainer_bis
                to_train_A_vars = self.g_A_vars
                to_train_B_vars = self.g_B_vars

            print("Loading data...")
            tot_inputs = data_loader.load_data(self._dataset_name,
                                               self._size_before_crop, False,
                                               self._do_flipping)
            self.inputs_img_i = tot_inputs['images_i']
            self.inputs_img_j = tot_inputs['images_j']
            assert (len(self.inputs_img_i) == len(self.inputs_img_j)
                    and max_images == len(self.inputs_img_i))

            # save_images switches the nets to eval mode; switch back after.
            self.save_images(nets, epoch, curr_tr, self.inputs_img_i,
                             self.inputs_img_j)
            nets = utils.set_mode(nets, "train")

            input_iter = minibatches(self.inputs_img_i,
                                     self.inputs_img_j,
                                     batch_size=1,
                                     shuffle=True)

            for i in range(max_images):
                print("Processing batch {}/{} in {}th epoch".format(
                    i, max_images, epoch))

                self.image_a, self.image_b = next(input_iter)
                tmp_imgA = self.get_fake_image_pool(self.num_fake_inputs,
                                                    self.fake_images_A)
                self.fake_pool_A_mask = tmp_imgA["mask"]
                self.fake_pool_A = tmp_imgA["im"]

                tmp_imgB = self.get_fake_image_pool(self.num_fake_inputs,
                                                    self.fake_images_B)
                self.fake_pool_B_mask = tmp_imgB["mask"]
                self.fake_pool_B = tmp_imgB["im"]

                self.transition_rate = np.array([curr_tr], dtype=np.float32)
                self.donorm = np.array([donorm], dtype=np.float32)

                # Persistent tape: four separate gradient calls follow.
                with tf.GradientTape(persistent=True) as tape:
                    self.output_converter(
                        model.get_outputs(self.input_converter(), nets))
                    self.upd_fake_image_pool(self.num_fake_inputs,
                                             self.fake_images_b, self.masks[0],
                                             self.fake_images_B)
                    self.upd_fake_image_pool(self.num_fake_inputs,
                                             self.fake_images_a, self.masks[1],
                                             self.fake_images_A)
                    self.compute_losses()

                grad = tape.gradient(self.d_B_loss, self.d_B_vars)
                self.d_B_trainer.apply_gradients(zip(grad, self.d_B_vars))

                grad = tape.gradient(self.d_A_loss, self.d_A_vars)
                self.d_A_trainer.apply_gradients(zip(grad, self.d_A_vars))

                grad = tape.gradient(self.g_A_loss, to_train_A_vars)
                to_train_A.apply_gradients(zip(grad, to_train_A_vars))

                grad = tape.gradient(self.g_B_loss, to_train_B_vars)
                to_train_B.apply_gradients(zip(grad, to_train_B_vars))

                # Release the persistent tape's resources right away.
                del tape

                tot_loss = (self.g_A_loss + self.g_B_loss + self.d_A_loss +
                            self.d_B_loss)

                print("[training_info] g_A_loss = {}, g_B_loss = {}, "
                      "d_A_loss = {}, d_B_loss = {}, "
                      "tot_loss = {}, lr={}, curr_tr={}".format(
                          self.g_A_loss, self.g_B_loss, self.d_A_loss,
                          self.d_B_loss, tot_loss, curr_lr, curr_tr))

                # self.global_step equals the current epoch inside this loop.
                summary_step = self.global_step * max_images + i
                for tag, value in (
                        ('g_A_loss', self.g_A_loss),
                        ('g_B_loss', self.g_B_loss),
                        ('d_A_loss', self.d_A_loss),
                        ('d_B_loss', self.d_B_loss),
                        ('learning_rate', to_train_A.learning_rate),
                        ('total_loss', tot_loss),
                ):
                    tf.summary.scalar(tag, value, step=summary_step)

                self.num_fake_inputs += 1

            self.global_step = epoch + 1
            summary_writer.flush()
    def save_images_bis(self, nets, epoch, images_i, images_j):
        """
        Write result images (inputs, masks, masked inputs, fakes) and an
        HTML overview.

        The networks are put into "eval" mode, then
        ``self._num_imgs_to_save`` shuffled pairs are drawn from
        ``images_i`` / ``images_j`` (batch size 1) and run through the model
        with a fixed transition rate of 0.1 and normalisation enabled. Each
        result set is passed to ``self.figure_writer`` along with the open
        ``results_<epoch>.html`` file located in ``self._output_dir``.
        """
        nets = utils.set_mode(nets, "eval")
        exists_or_mkdir(self._images_dir)

        names = [
            'input_A_', 'mask_A_', 'masked_inputA_', 'fakeB_', 'input_B_',
            'mask_B_', 'masked_inputB_', 'fakeA_'
        ]
        # Horizontal padding between the column titles of the HTML header.
        space = '&nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp ' \
                '&nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp &nbsp ' \
                '&nbsp &nbsp &nbsp &nbsp &nbsp'

        html_path = os.path.join(self._output_dir,
                                 'results_' + str(epoch) + '.html')
        with open(html_path, 'w') as v_html:
            header = ("<b>Inputs" + space + "Masks" + space +
                      "Masked_images" + space + "Generated_images</b>")
            v_html.write(header)
            v_html.write("<br>")

            pair_iter = minibatches(images_i,
                                    images_j,
                                    batch_size=1,
                                    shuffle=True)

            for idx in range(0, self._num_imgs_to_save):
                print("Saving image {}/{}...".format(idx,
                                                     self._num_imgs_to_save))

                self.image_a, self.image_b = next(pair_iter)

                # Refresh the fake-image pool entries read by input_converter.
                pool_a = self.get_fake_image_pool(self.num_fake_inputs,
                                                  self.fake_images_A)
                self.fake_pool_A_mask = pool_a["mask"]
                self.fake_pool_A = pool_a["im"]

                pool_b = self.get_fake_image_pool(self.num_fake_inputs,
                                                  self.fake_images_B)
                self.fake_pool_B_mask = pool_b["mask"]
                self.fake_pool_B = pool_b["im"]

                self.transition_rate = np.array([0.1], dtype=np.float32)
                self.donorm = np.array([True], dtype=np.float32)

                outputs = model.get_outputs(self.input_converter(), nets)
                self.output_converter(outputs)

                # Order must line up one-to-one with ``names`` above.
                figures_to_save = [
                    self.image_a,
                    self.masks[0],
                    self.masked_ims[0],
                    self.fake_images_b,
                    self.image_b,
                    self.masks[1],
                    self.masked_ims[1],
                    self.fake_images_a,
                ]

                self.figure_writer(figures_to_save,
                                   names,
                                   v_html,
                                   epoch,
                                   idx,
                                   html_mode=1)
Beispiel #7
0
def train():
    """Jointly optimise the generator ``G`` and one latent vector per MNIST image.

    The MNIST training images are resized to 32x32 and scaled with
    ``x / 128 - 1``. A trainable latent matrix ``z`` (one row per image) is
    initialised from a Gaussian; for every batch the mean squared error
    between ``G(z_batch)`` and the images is minimised with two SGD
    optimisers (learning rate 1e-3 for ``G``, 0.1 for ``z``). After each
    epoch, rows of ``z`` with norm above 1 are rescaled to norm 1, the
    generator weights are saved, and 25 samples drawn from a Gaussian fitted
    to ``z`` are written as a PNG grid and to the summary writer.

    Settings are read from the module-level ``flags`` object (``z_dim``,
    ``n_epoch``, ``batch_size``, ``checkpoint_dir``, ``sample_dir``).
    """
    writer_path = './exp1/'
    exists_or_mkdir(writer_path)
    writer = SummaryWriter(writer_path)

    try:
        # load mnist (only the training images are used here)
        X_train, *_ = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
        n = len(X_train)

        # convert mnist images to 32*32
        xnew = np.zeros((n, 32, 32, 1), dtype='float32')
        for i in range(n):
            xnew[i] = tl.prepro.imresize(X_train[i],
                                         size=(32, 32),
                                         interp='bicubic',
                                         mode=None)
        # rescale all images at once
        xnew /= 128
        xnew -= 1
        X_train = xnew

        G = get_generator([None, flags.z_dim])

        G.train()

        # 2 different optimizers to train G and z separately.
        # ``learning_rate`` is the supported keyword; ``lr`` is a deprecated
        # alias that newer Keras optimizers reject.
        g_optimizer = tf.optimizers.SGD(learning_rate=1e-3)
        z_optimizer = tf.optimizers.SGD(learning_rate=0.1)

        # initialize z by sampling from a Gaussian distribution
        z = tf.Variable(tf.random.normal([n, flags.z_dim],
                                         stddev=np.sqrt(1.0 / flags.z_dim)),
                        name="z",
                        trainable=True)

        batch_size = flags.batch_size
        n_batches = n // batch_size

        for epoch in trange(
                flags.n_epoch,
                desc='epoch loop'):  ## iterate the dataset n_epoch times
            # iterate over the entire training set once
            for i in range(n_batches):
                lo, hi = i * batch_size, (i + 1) * batch_size

                # get X_batch by indexing, without shuffling (important!):
                # row k of z must always be paired with image k
                X_batch = X_train[lo:hi]

                step_time = time.time()
                G.train()  # enable dropout
                # A single gradient() call follows, so a non-persistent tape
                # is enough and frees its resources automatically.
                with tf.GradientTape() as tape:
                    # compute outputs
                    z_batch = z[lo:hi]
                    fake_X = G(z_batch)
                    # compute loss and update model
                    loss = tl.cost.mean_squared_error(fake_X,
                                                      X_batch,
                                                      name='train_loss')

                # compute gradient to G and z
                weights = G.trainable_weights
                grad = tape.gradient(loss, weights + [z])

                # Back_propagation
                grad_g = grad[:len(weights)]
                grad_z = grad[len(weights):]
                g_optimizer.apply_gradients(zip(grad_g, weights))
                z_optimizer.apply_gradients(zip(grad_z, [z]))

                print("Epoch: [{}/{}] [{}/{}] took: {:.3f}, loss: {:.5f}".format(
                    epoch, flags.n_epoch, i, n_batches,
                    time.time() - step_time, loss))

            # normalize z (remain unchanged for those vectors whose length is shorter than 1)
            z.assign(z / tf.math.maximum(
                tf.math.sqrt(tf.math.reduce_sum(z**2, axis=1))[:, tf.newaxis],
                1))

            # testing
            G.save_weights('{}/G.npz'.format(flags.checkpoint_dir),
                           format='npz')
            G.eval()

            # sampling from the z distribution
            z_values = z.numpy()
            z_mean = np.mean(z_values, axis=0)
            z_cov = np.cov(z_values, rowvar=False)
            sample = np.random.multivariate_normal(z_mean, z_cov, size=25)
            result = G(sample.astype(np.float32))

            G.train()
            tl.visualize.save_images(
                result.numpy(), [5, 5],
                '{}/train_{:02d}.png'.format(flags.sample_dir, epoch))
            # NHWC -> NCHW and repeat the single channel three times for the
            # image grid.
            save_res = np.tile(
                tf.transpose(result, [0, 3, 1, 2]).numpy(), [1, 3, 1, 1])
            save_res = torch.tensor(save_res)
            # NOTE(review): recent torchvision releases renamed ``range`` to
            # ``value_range`` — confirm against the installed version.
            save_grid = make_grid(save_res.cpu(),
                                  nrow=5,
                                  range=(-1, 1),
                                  normalize=True)
            writer.add_image('eval/recon_imgs', save_grid, epoch + 1)
    finally:
        # Flush pending events and release the event file even on failure.
        writer.close()