Example #1
    def validate_model(self, val_iter, epoch, step):
        # draw a single batch from the validation iterator for sampling
        for batch in val_iter:
            cns_code, seq_len, labels, images = batch
            fake_imgs, real_imgs, d_loss, g_loss, l1_loss = self.generate_fake_samples(
                images, labels, cns_code, seq_len)
            print("Sample: d_loss: %.5f, g_loss: %.5f, l1_loss: %.5f" %
                  (d_loss, g_loss, l1_loss))
            break  # one batch is enough for a qualitative sample

        merged_fake_images = merge(scale_back(fake_imgs), [self.batch_size, 1])
        merged_real_images = merge(scale_back(real_imgs), [self.batch_size, 1])
        merged_pair = np.concatenate([merged_real_images, merged_fake_images],
                                     axis=1)

        model_id, _ = self.get_model_id_and_dir()

        model_sample_dir = os.path.join(self.sample_dir, model_id)
        if not os.path.exists(model_sample_dir):
            os.makedirs(model_sample_dir)

        sample_img_path = os.path.join(model_sample_dir,
                                       "sample_%02d_%04d.jpg" % (epoch, step))
        misc.imsave(sample_img_path, merged_pair)
        return l1_loss
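
`scale_back` and `merge` are helpers assumed by this example. A minimal sketch of plausible implementations, assuming the generator emits tanh-range outputs in [-1, 1] and that `merge` tiles an NHWC batch into a grid (names and behaviors are inferred, not taken from the source):

import numpy as np

def scale_back(images):
    # Hypothetical: map tanh-range outputs [-1, 1] back to [0, 1].
    return (images + 1.) / 2.

def merge(images, size):
    # Hypothetical: tile a batch (N, H, W, C) into a size[0] x size[1] grid.
    h, w = images.shape[1], images.shape[2]
    grid = np.zeros((h * size[0], w * size[1], images.shape[3]))
    for idx, image in enumerate(images):
        col = idx % size[1]
        row = idx // size[1]
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image
    return grid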
Example #2
    def sample_model(self, rA, rB, epoch, count, sess):
        fake_B, fake_A = sess.run([self.fake_B, self.fake_A],
                                  feed_dict={
                                      self.input_A: rA,
                                      self.input_B: rB
                                  })

        a1 = scale_back(rA)
        a2 = scale_back(fake_B)

        b1 = scale_back(rB)
        b2 = scale_back(fake_A)

        # lay the four NHWC batches side by side, then drop the singleton batch axis
        merged_pair = np.concatenate([a1, a2, b1, b2], axis=2)
        merged_pair = merged_pair.reshape(
            (merged_pair.shape[1], merged_pair.shape[2], merged_pair.shape[3]))

        s_dir = 'samples/'

        if not os.path.exists(s_dir):
            os.makedirs(s_dir)

        sample_img_path = os.path.join(s_dir,
                                       "sample_%02d_%04d.png" % (epoch, count))
        misc.imsave(sample_img_path, merged_pair)
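
The `axis=2` concatenation lays the four single-image NHWC batches side by side, and the reshape drops the leading batch axis of size 1. A tiny standalone check of that shape arithmetic, using dummy 4x4 RGB arrays:

import numpy as np

# Four dummy NHWC batches of one 4x4 RGB image each, mirroring the shapes above.
a1, a2, b1, b2 = (np.zeros((1, 4, 4, 3)) for _ in range(4))
pair = np.concatenate([a1, a2, b1, b2], axis=2)  # (1, 4, 16, 3): images side by side
pair = pair.reshape(pair.shape[1:])              # (4, 16, 3): batch axis dropped
print(pair.shape)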
Example #3
    def get_loss(self, trip_od, scaled_trip_volume, in_flows, out_flows, g, multitask_weights=[0.5, 0.25, 0.25]):
        '''
        Defines the procedure for evaluating the loss function.

        Inputs:
        ----------------------------------
        trip_od: list of origin-destination pairs
        scaled_trip_volume: scaled ground-truth trip volumes, which serve as our target
        in_flows / out_flows: ground-truth in/out flow per node
        g: DGL graph object
        multitask_weights: weights for the edge, in-flow, and out-flow loss terms

        Outputs:
        ----------------------------------
        loss: value of the loss function
        '''
        # recover the raw trip volume from the scaled target
        trip_volume = utils.scale_back(scaled_trip_volume)
        # get in/out nodes of this batch
        out_nodes, out_flows_idx = torch.unique(trip_od[:, 0], return_inverse=True)
        in_nodes, in_flows_idx = torch.unique(trip_od[:, 1], return_inverse=True)
        # scale the in/out flows of the nodes in this batch
        scaled_out_flows = utils.scale(out_flows[out_nodes])
        scaled_in_flows = utils.scale(in_flows[in_nodes])
        # get embeddings of each node from GNN
        src_embedding = self.forward(g)
        dst_embedding = self.forward2(g)
        # get edge prediction
        edge_prediction = self.predict_edge(src_embedding, dst_embedding, trip_od)
        # get in/out flow prediction
        in_flow_prediction = self.predict_inflow(dst_embedding, in_nodes)
        out_flow_prediction = self.predict_outflow(src_embedding, out_nodes)
        # get edge prediction loss
        edge_predict_loss = MSE(edge_prediction, scaled_trip_volume)
        # get in/out flow prediction loss
        in_predict_loss = MSE(in_flow_prediction, scaled_in_flows)
        out_predict_loss = MSE(out_flow_prediction, scaled_out_flows)
        # get regularization loss
        reg_loss = 0.5 * (self.regularization_loss(src_embedding) + self.regularization_loss(dst_embedding))
        # return the overall loss
        return (multitask_weights[0] * edge_predict_loss
                + multitask_weights[1] * in_predict_loss
                + multitask_weights[2] * out_predict_loss
                + self.reg_param * reg_loss)
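
`MSE`, `utils.scale`, and `utils.scale_back` are external helpers here. One plausible pairing, sketched as a hypothetical `utils` module plus an `MSE` helper; `MAX_VOLUME` and the linear scaling scheme are assumptions, not from the source:

import torch

MAX_VOLUME = 1000.0  # hypothetical dataset-wide maximum trip volume

def scale(x):
    # Compress raw flow counts into [0, 1] so regression targets are bounded.
    return x / MAX_VOLUME

def scale_back(x):
    # Invert the scaling to recover raw counts.
    return x * MAX_VOLUME

def MSE(prediction, target):
    # Mean squared error between predictions and scaled targets.
    return torch.mean((prediction - target) ** 2)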
Example #4
def transform():
    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load Network
    g = Generator(64)
    g = g.to(device)
    g = nn.DataParallel(g)

    # Load Weights
    if MODEL_LEGACY:
        g.load_state_dict(torch.load(MODEL_PATH))
    else:
        g.load_state_dict(torch.load(MODEL_PATH)["model_state_dict"])

    # Load Image and convert to PyTorch tensor
    real_image = Image.open(REAL_IMAGE_PATH).convert("RGB")
    transform = transforms.Compose([
        # transforms.Resize(256),      # enable if the model expects fixed-size inputs
        # transforms.CenterCrop(256),
        transforms.ToTensor()
    ])
    real_tensor = transform(real_image).unsqueeze(0)

    with torch.no_grad():
        # Scale - Forward Pass - Scale Back
        real_tensor = scale(real_tensor)
        fake_tensor = g(real_tensor)
        fake_tensor = scale_back(fake_tensor)
        fake_tensor = fake_tensor.squeeze(0)

        # Tensor to Numpy Array
        fake_image = fake_tensor.cpu().numpy().transpose(1, 2, 0)
        show(fake_image)
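
`scale`, `scale_back`, and `show` are assumed by this script. A minimal sketch, assuming the generator was trained on inputs normalized to [-1, 1] and that `show` simply displays an HWC float image (hypothetical implementations):

import matplotlib.pyplot as plt

def scale(t):
    # Hypothetical: map ToTensor's [0, 1] range into the model's [-1, 1] range.
    return t * 2.0 - 1.0

def scale_back(t):
    # Hypothetical inverse: map [-1, 1] back to [0, 1] for display.
    return (t + 1.0) / 2.0

def show(img):
    # Display an HWC float image in [0, 1].
    plt.imshow(img)
    plt.axis("off")
    plt.show()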
Example #5
    def interpolate(self, source_obj, between, model_dir, save_dir, steps):
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(var_list=self.retrieve_generator_vars())
        self.restore_model(saver, model_dir)
        # new interpolated dimension
        new_x_dim = steps + 1
        alphas = np.linspace(0.0, 1.0, new_x_dim)

        def _interpolate_tensor(_tensor):
            """
            Compute the interpolated tensor here
            """

            x = _tensor[between[0]]
            y = _tensor[between[1]]

            interpolated = list()
            for alpha in alphas:
                interpolated.append(x * (1. - alpha) + alpha * y)

            interpolated = np.asarray(interpolated, dtype=np.float32)
            return interpolated

        def filter_embedding_vars(var):
            var_name = var.name
            if var_name.find("embedding") != -1:
                return True
            if var_name.find("inst_norm/shift") != -1 or var_name.find(
                    "inst_norm/scale") != -1:
                return True
            return False

        embedding_vars = filter(filter_embedding_vars,
                                tf.trainable_variables())
        # here comes the hack: we overwrite the original tensors
        # with interpolated ones. Note that the shapes may differ

        # this is to restore the embedding at the end
        embedding_snapshot = list()
        for e_var in embedding_vars:
            val = e_var.eval(session=self.sess)
            embedding_snapshot.append((e_var, val))
            t = _interpolate_tensor(val)
            op = tf.assign(e_var, t, validate_shape=False)
            print("overwrite %s tensor" % e_var.name, "old_shape ->",
                  e_var.get_shape(), "new shape ->", t.shape)
            self.sess.run(op)

        source_provider = InjectDataProvider(source_obj)
        input_handle, _, eval_handle, _ = self.retrieve_handles()
        for step_idx in range(len(alphas)):
            alpha = alphas[step_idx]
            print("interpolate %d -> %.4f + %d -> %.4f" %
                  (between[0], 1. - alpha, between[1], alpha))
            source_iter = source_provider.get_single_embedding_iter(
                self.batch_size, 0)
            batch_buffer = list()
            count = 0
            for _, source_imgs in source_iter:
                count += 1
                labels = [step_idx] * self.batch_size
                generated, = self.sess.run(
                    [eval_handle.generator],
                    feed_dict={
                        input_handle.real_data: source_imgs,
                        input_handle.embedding_ids: labels
                    })
                merged_fake_images = merge(scale_back(generated),
                                           [self.batch_size, 1])
                batch_buffer.append(merged_fake_images)
            if batch_buffer:
                save_concat_images(
                    batch_buffer,
                    os.path.join(
                        save_dir, "frame_%02d_%02d_step_%02d.jpg" %
                        (between[0], between[1], step_idx)))
        # restore the embedding variables
        print("restore embedding values")
        for var, val in embedding_snapshot:
            op = tf.assign(var, val, validate_shape=False)
            self.sess.run(op)
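
The interpolation hack hinges on `tf.assign(..., validate_shape=False)` accepting a value whose leading dimension differs from the variable's current shape. A standalone check of that behavior (a sketch, assuming TensorFlow 1.x):

import numpy as np
import tensorflow as tf

# TF1-style demo: validate_shape=False lets assign() swap in a value
# with a different leading dimension.
v = tf.Variable(np.zeros((2, 4), dtype=np.float32))
op = tf.assign(v, np.ones((5, 4), dtype=np.float32), validate_shape=False)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(op)
    print(sess.run(tf.shape(v)))  # -> [5 4]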
Example #6
    def infer(self):
        imga = self.read_and_decode(self.A_dir)
        imgA_batch = tf.train.batch([imga],
                                    batch_size=1,
                                    capacity=self.totalnum)
        imgb = self.read_and_decode(self.B_dir)
        imgB_batch = tf.train.batch([imgb],
                                    batch_size=1,
                                    capacity=self.totalnum)
        self.buildmodel()

        result_dir = "images/results/smiles/"
        checkpiont_dir = "checkpoint/smiles/"

        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)

            ckpt = tf.train.latest_checkpoint(checkpoint_dir)
            if ckpt:
                saver.restore(sess, ckpt)
                print("restored model %s" % checkpoint_dir)
            else:
                print("failed to restore model %s" % checkpoint_dir)
                return

            if not os.path.exists(result_dir):
                os.makedirs(result_dir)

            start_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            for i in range(0, self.totalnum):
                rA, rB = sess.run([imgA_batch, imgB_batch])

                # map raw pixels into the model's input range
                rA = normalize_image(np.array(rA))
                rB = normalize_image(np.array(rB))

                fake_B, fake_A = sess.run([self.fake_B, self.fake_A],
                                          feed_dict={
                                              self.input_A: rA,
                                              self.input_B: rB
                                          })

                a1 = scale_back(rA)
                a2 = scale_back(fake_B)

                b1 = scale_back(rB)
                b2 = scale_back(fake_A)

                # lay real/fake pairs side by side, then drop the batch axis
                merged_pair = np.concatenate([a1, a2, b1, b2], axis=2)
                merged_pair = merged_pair.reshape(
                    (merged_pair.shape[1], merged_pair.shape[2],
                     merged_pair.shape[3]))

                sample_img_path = os.path.join(result_dir,
                                               "sample_%04d.png" % i)
                misc.imsave(sample_img_path, merged_pair)
                print(sample_img_path)

            end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            print('start_time: {0}, end_time: {1}'.format(start_time, end_time))
            coord.request_stop()
            coord.join(threads)
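
`normalize_image` is the assumed counterpart of `scale_back` in this pipeline. A minimal sketch, assuming uint8 inputs and a model trained on [-1, 1] pixels (the name and scaling are inferred, not confirmed by the source):

import numpy as np

def normalize_image(img):
    # Hypothetical: map uint8 pixels [0, 255] into the [-1, 1] model range,
    # which scale_back() later inverts to [0, 1] for saving.
    return img.astype(np.float32) / 127.5 - 1.0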