Beispiel #1
0
def test_transfer_commonality_check(bch_limit):
    """Plot test samples and transfer results for ``bch_limit`` batches.

    The first batch drawn from ``testing_batches`` is held back as a fixed
    source; its ``X`` images and its ``XN``/``CN`` inputs are reused for
    every later batch. For each of the next ``bch_limit`` batches this
    writes two images into ``out_dir``:

    * ``bch_<idx>_test.png``: ``model.test(C_batch)['cond']`` concatenated
      (along the first axis) with ``model.test(C_batch)['test_sample']``.
    * ``bch_<idx>_transfer_src_app.png``: ``C_batch``, the held-back ``X``
      images and ``model.transfer(XN_batch_init, CN_batch_init, C_batch)``,
      concatenated along the first axis.

    Relies on module-level ``testing_batches``, ``model``, ``out_dir`` and
    ``plot_batch``.

    Args:
        bch_limit: number of batches to process after the held-back one.
            A value <= 0 processes no batches.
    """
    X_batch_init, _, XN_batch_init, CN_batch_init = next(testing_batches)

    batch_idx = 0
    # Check the limit before fetching, so bch_limit <= 0 processes nothing
    # and no batch is pulled from the iterator only to be discarded.
    while batch_idx < bch_limit:
        print('current batch idx in test transfer commonality check: ',
              batch_idx)
        _, C_batch, _, _ = next(testing_batches)

        # Test: conditioning inputs stacked with the generated samples.
        test_rsts = model.test(C_batch)
        combined_bch_img = np.concatenate(
            [test_rsts['cond'], test_rsts['test_sample']])
        test_save_img_name = os.path.join(
            out_dir, "bch_{}_{}.png".format(batch_idx, 'test'))
        plot_batch(combined_bch_img, test_save_img_name)

        # Transfer: apply the held-back source to this batch's conditions.
        transfer_rts = model.transfer(XN_batch_init, CN_batch_init, C_batch)
        transfer_save_rsts = np.concatenate(
            [C_batch, X_batch_init, transfer_rts])
        transfer_save_img_name = os.path.join(
            out_dir, "bch_{}_{}".format(batch_idx, 'transfer'))
        plot_batch(transfer_save_rsts, transfer_save_img_name + '_src_app.png')

        batch_idx += 1
Beispiel #2
0
def restore_launch(mission_type, bch_limit=None):
    """Run a restored model over ``testing_batches`` and save its outputs.

    The first batch is held back: its ``X`` images are plotted to
    ``<out_dir>/target_appearance.png`` and its ``C`` images to
    ``<out_dir>/target_pose.png``. Each following batch is processed
    according to ``mission_type``:

    * ``'test'``: runs ``model.test(C_batch)``. The ``'test_sample'`` entry is
      saved with ``utilities.save_batch_img_np_txt`` under
      ``<out_dir>/test_sample/``. The ``'cond'`` entry is plotted under
      ``<out_dir>/cond/``. Other keys are ignored.
    * ``'transfer'``: runs ``model.transfer(XN_batch, CN_batch, C_batch_init)``.
      It plots ``X_batch`` to ``bch_<idx>_transfer_src_app.png`` and saves the
      first ``bs`` results with ``utilities.save_batch_img_np_txt``.

    Processing stops when any of these happens:

    * ``bch_limit`` batches have been handled (when it is not None);
    * the iterator is exhausted;
    * the iterator yields a batch whose ``X`` entry is None.

    Relies on module-level ``testing_batches``, ``model``, ``out_dir``,
    ``plot_batch`` and ``utilities``.

    Args:
        mission_type: ``'test'`` or ``'transfer'``.
        bch_limit: maximum number of batches to process after the held-back
            one; None means "until the data runs out".

    Raises:
        ValueError: if ``mission_type`` is neither ``'test'`` nor
            ``'transfer'``.
    """
    print('mission type: ', mission_type)
    if mission_type not in ('test', 'transfer'):
        # Any other value would consume batches without producing output.
        raise ValueError('unknown mission type: {}'.format(mission_type))
    startingDT = datetime.datetime.now()
    print('Starting DT: ', startingDT)
    X_batch_init, C_batch_init, _, _ = next(testing_batches)
    plot_batch(X_batch_init, os.path.join(out_dir, 'target_appearance.png'))
    plot_batch(C_batch_init, os.path.join(out_dir, 'target_pose.png'))

    batch_idx = 0
    # Check the limit before fetching so no batch is pulled and then dropped.
    while bch_limit is None or batch_idx < bch_limit:
        print('current batch idx: ', batch_idx)
        batch = next(testing_batches, None)
        # Stop on an exhausted iterator or on a batch with no X images.
        if batch is None or batch[0] is None:
            break
        X_batch, C_batch, XN_batch, CN_batch = batch

        if mission_type == 'test':
            test_rsts = model.test(C_batch)
            for k, a_value_set in test_rsts.items():
                if k not in ('test_sample', 'cond'):
                    continue
                k_dir = os.path.join(out_dir, k)
                os.makedirs(k_dir, exist_ok=True)
                bch_name = os.path.join(
                    k_dir, 'bch_{}_{}'.format(batch_idx, mission_type))
                if k == 'test_sample':
                    utilities.save_batch_img_np_txt(a_value_set, bch_name)
                else:  # k == 'cond'
                    plot_batch(a_value_set, bch_name + '.png')
        else:  # mission_type == 'transfer'
            transfer_rts = model.transfer(XN_batch, CN_batch, C_batch_init)
            bs = X_batch.shape[0]
            overall_name = os.path.join(
                out_dir, "bch_{}_{}".format(batch_idx, mission_type))
            plot_batch(X_batch, overall_name + '_src_app.png')
            # Save a copy of the first bs transfer results.
            utilities.save_batch_img_np_txt(
                np.array(transfer_rts[:bs]), overall_name)
        batch_idx += 1

    endingDT = datetime.datetime.now()
    print('ending DT: ', endingDT)
    def log_result(self, result, **kwargs):
        """Write out the logging outputs of one training step.

        The step number comes from ``self.log_ops["global_step"]`` evaluated
        in the module-level ``session``.

        * ``result["summary"]`` (if present) is added to ``self.writer``.
        * Each ``result["log"]`` entry (if present) is logged in sorted key
          order.
        * Each ``result["img"]`` batch (if present) is plotted to
          ``<out_dir>/<key>_<step>.png``.

        Validation pass: when images are logged and ``self.valid_batches`` is
        set, one validation batch is run through ``session``.

        * Its summary and images (``valid_<key>_<step>.png``) are written.
        * Its loss is logged.
        * When ``self.checkpoint_best`` is set and the loss is below
          ``self.best_loss``, a ``best_`` checkpoint is made.

        Periodic work:

        * Every ``self.test_frequency`` steps (with validation batches
          available), test samples and a transfer grid are plotted.
        * Every ``self.ckpt_frequency`` steps, a checkpoint is made.

        Args:
            result: dict that may contain the keys "summary", "log", "img".
            **kwargs: accepted but unused.
        """
        step = self.log_ops["global_step"].eval(session)

        if "summary" in result:
            self.writer.add_summary(result["summary"], step)
            self.writer.flush()

        if "log" in result:
            log_entries = result["log"]
            for name in sorted(log_entries):
                self.logger.info("{}: {}".format(name, log_entries[name]))

        if "img" in result:
            for name, batch in result["img"].items():
                img_file = name + "_{:07}.png".format(step)
                plot_batch(batch, os.path.join(self.out_dir, img_file))

            # A validation pass accompanies every image-logging step.
            if self.valid_batches is not None:
                x_val, c_val, xn_val, cn_val = next(self.valid_batches)
                feeds = {
                    self.xn: xn_val,
                    self.cn: cn_val,
                    self.x: x_val,
                    self.c: c_val,
                }
                fetches = {
                    "imgs": self.img_ops,
                    "summary": self.valid_summary_op,
                    "validation_loss": self.log_ops["loss"],
                }
                valid_out = session.run(fetches, feeds)
                self.writer.add_summary(valid_out["summary"], step)
                self.writer.flush()

                for name, batch in valid_out["imgs"].items():
                    img_file = "valid_" + name + "_{:07}.png".format(step)
                    plot_batch(batch, os.path.join(self.out_dir, img_file))

                val_loss = valid_out["validation_loss"]
                self.logger.info("{}: {}".format("validation_loss", val_loss))
                # Keep a separate checkpoint of the best validation model.
                if self.checkpoint_best and val_loss < self.best_loss:
                    self.logger.info(
                        "step {}: Validation loss improved from {:.4e} to {:.4e}"
                        .format(step, self.best_loss, val_loss))
                    self.best_loss = val_loss
                    self.make_checkpoint(step, prefix="best_")

        if step % self.test_frequency == 0 and self.valid_batches is not None:
            x_val, c_val, _, _ = next(self.valid_batches)

            samples = self.test(c_val)
            for name in samples:
                sample_file = "testing_{}_{:07}.png".format(name, step)
                plot_batch(samples[name],
                           os.path.join(self.out_dir, sample_file))

            # Transfer grid. The header row is a blank tile followed by every
            # c_val entry. Row i is x_val[i] followed by the transfer of
            # (x_val[i], c_val[i]), repeated batch_len times, onto all of
            # c_val.
            batch_len = x_val.shape[0]
            grid = [np.zeros_like(x_val[0, ...])]
            grid.extend(c_val[r, ...] for r in range(batch_len))
            for i in range(batch_len):
                grid.append(x_val[i, ...])
                src_x = x_val[i, ...][None, ...].repeat(batch_len, axis=0)
                src_c = c_val[i, ...][None, ...].repeat(batch_len, axis=0)
                transferred = self.transfer(src_x, src_c, c_val)
                grid.extend(transferred[j, ...] for j in range(batch_len))
            plot_batch(
                np.stack(grid, axis=0),
                os.path.join(self.out_dir,
                             "transfer_{:07}.png".format(step)))

        if step % self.ckpt_frequency == 0:
            self.make_checkpoint(step)
Beispiel #4
0
                                    train=False)

        model = Model(config, out_dir, logger)
        assert opt.checkpoint is not None
        model.restore_graph(opt.checkpoint)

        for step in trange(10):
            X_batch, C_batch, XN_batch, CN_batch = next(valid_batches)
            bs = X_batch.shape[0]
            imgs = list()
            imgs.append(np.zeros_like(X_batch[0, ...]))
            for r in range(bs):
                imgs.append(C_batch[r, ...])
            for i in range(bs):
                x_infer = XN_batch[i, ...]
                c_infer = CN_batch[i, ...]
                imgs.append(X_batch[i, ...])

                x_infer_batch = x_infer[None, ...].repeat(bs, axis=0)
                c_infer_batch = c_infer[None, ...].repeat(bs, axis=0)
                c_generate_batch = C_batch
                results = model.transfer(x_infer_batch, c_infer_batch,
                                         c_generate_batch)
                for j in range(bs):
                    imgs.append(results[j, ...])
            imgs = np.stack(imgs, axis=0)
            plot_batch(imgs,
                       os.path.join(out_dir, "transfer_{}.png".format(step)))
    else:
        raise NotImplemented()
Beispiel #5
0
def save_batch_img_np_txt(a_batch, save_path, save_bch_img=True):
    """Store ``a_batch`` as ``<save_path>.npy`` and optionally as a PNG.

    Args:
        a_batch: array-like batch handed to ``np.save`` (and to
            ``plot_batch`` when an image is requested).
        save_path: output path without extension; ``.npy`` / ``.png`` are
            appended by string concatenation, so it must be a ``str``.
        save_bch_img: when true, also render the batch with ``plot_batch``
            to ``<save_path>.png``.
    """
    npy_path = save_path + '.npy'
    np.save(npy_path, a_batch)
    if save_bch_img:
        png_path = save_path + '.png'
        plot_batch(a_batch, png_path)
        init_shape = [config["init_batches"] * batch_size] + img_shape
        box_factor = config["box_factor"]

        data_index = config["data_index"]
        batches = get_batches(data_shape, data_index, train = True, box_factor = box_factor)
        init_batches = get_batches(init_shape, data_index, train = True, box_factor = box_factor)
        valid_batches = get_batches(data_shape, data_index, train = False, box_factor = box_factor)
        logger.info("Number of training samples: {}".format(batches.n))
        logger.info("Number of validation samples: {}".format(valid_batches.n))

        model = Model(config, out_dir, logger)
        if opt.checkpoint is not None:
            model.restore_graph(opt.checkpoint)
        else:
            model.init_graph(next(init_batches))
        if opt.retrain:
            model.reset_global_step()
        model.fit(batches, valid_batches)
        
    elif opt.mode == "transfer":
        if not opt.checkpoint:
            raise Exception("transfer requires --checkpoint")
        config['batch_size'] = 1
        config['box_factor'] = 2
        model = Model(config, out_dir, logger)
        model.restore_graph(opt.checkpoint)
        x_encode, c_encode, c_decode = model.prepare_tranfer(opt.src_img, opt.src_jo, opt.tar_img, opt.tar_jo)        
        x_gen = model.transfer(x_encode, c_encode, c_decode)
        plot_batch(x_gen, os.path.join(out_dir, "testing.png"))
    else:
        raise NotImplemented()