Example 1
    def get_nodes_features(self, combination):
        # combination: a (selected_key, values) pair: the index of the selected key
        # and the list of per-leaf values.
        # Each leaf holds a one-hot encoding of its key concatenated with a one-hot
        # encoding of its value; every other node is left empty for now.
        selected_key, values = combination

        # The root is [one-hot selected key] + [0 ... 0]
        nodes = [
            common.one_hot(selected_key, len(self.leaf_indices)) + [0] * len(self.leaf_indices)
        ]
        for i in range(1, self.num_nodes):
            if i in self.leaf_indices:
                leaf_num = self.leaf_indices.index(i)
                node = (common.one_hot(leaf_num, len(self.leaf_indices)) +
                        common.one_hot(values[leaf_num], len(self.leaf_indices)))
            else:
                node = [0] * (2 * len(self.leaf_indices))
            nodes.append(node)
        return nodes
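
In this snippet the return value of common.one_hot is concatenated with plain Python lists such as [0] * len(self.leaf_indices), so the helper is used as a list-based encoder. A minimal sketch of such a helper, assuming that list-based signature (the real common.one_hot may differ):

def one_hot(index, length):
    # Hypothetical list-based helper: a list of `length` zeros with a 1 at `index`.
    vec = [0] * length
    vec[index] = 1
    return vec

# e.g. one_hot(2, 4) + [0] * 4  ->  [0, 0, 1, 0, 0, 0, 0, 0]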
Example 2
 def train_discriminator(self):
     alg = self.algorithm
     # sample transitions from the agent and expert replay buffers
     state_a, action_a, rewards_a, state_a_, terminals_a = alg.er_agent.sample()[:5]
     state_e, action_e, rewards_e, state_e_, terminals_e = alg.er_expert.sample()[:5]
     states = np.concatenate([state_a, state_e])
     dones = np.concatenate([terminals_a, terminals_e])
     if not self.env.continuous_actions:
         action_e = common.one_hot(action_e,
                                   num_classes=self.env.action_size)
     actions = np.concatenate([action_a, action_e])
     # labels: 0 for policy (agent) samples, 1 for expert samples, shaped (N, 1)
     labels_a = np.zeros(shape=(state_a.shape[0], ))
     labels_e = np.ones(shape=(state_e.shape[0], ))
     labels = np.expand_dims(np.concatenate([labels_a, labels_e]), axis=1)
     fetches = [
         alg.discriminator.minimize, alg.discriminator.loss,
         alg.discriminator.acc
     ]
     if self.env.use_airl:
         states_ = np.concatenate([state_a_, state_e_])
         feed_dict = {
             alg.states: states,
             alg.states_: states_,
             alg.done_ph: dones,
             alg.actions: actions,
             alg.label: labels,
             alg.do_keep_prob: self.env.do_keep_prob
         }
     else:
         feed_dict = {
             alg.states: states,
             alg.actions: actions,
             alg.label: labels,
             alg.do_keep_prob: self.env.do_keep_prob
         }
     run_vals = self.sess.run(fetches, feed_dict)
     self.update_stats('discriminator', 'loss', run_vals[1])
     self.update_stats('discriminator', 'accuracy', run_vals[2])
     if self.itr % self.env.discr_policy_itrvl == 0:
         self.writer.add_scalar('train/discriminator/loss', run_vals[1],
                                self.itr)
         self.writer.add_scalar('train/discriminator/accuracy', run_vals[2],
                                self.itr)
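
Here common.one_hot(action_e, num_classes=self.env.action_size) encodes a whole batch of discrete expert actions before they are concatenated with the agent actions. A rough NumPy sketch of such a batch encoder, assuming integer action indices (the real common.one_hot may differ):

import numpy as np

def one_hot(indices, num_classes):
    # Hypothetical batch helper: map an (N,) integer array to an (N, num_classes) matrix.
    indices = np.asarray(indices, dtype=np.int64).reshape(-1)
    return np.eye(num_classes, dtype=np.float32)[indices]

# e.g. one_hot([1, 0, 2], num_classes=3) -> [[0, 1, 0], [1, 0, 0], [0, 0, 1]]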
Example 3
 def generate(self, labels, sess=None):
     '''
     labels: 1-D array of class labels, one per sample to generate
     sess: optional session to run the generator with (defaults to self.sess)
     '''
     no_samples = labels.shape[0]
     labels = one_hot(labels, depth=self.no_classes)
     z_ = np.random.normal(0, 1, (no_samples, 100))
     z_ = np.concatenate((z_, labels), axis=1)
     z_ = z_.reshape((no_samples, 1, 1, -1))
     if sess is None:
         sess = self.sess
     G_Images_c = sess.run(self.g_net, {
         self.z: z_,
         self.istraining: False
     })
     return G_Images_c
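
The core of generate is how the conditional latent code is built: Gaussian noise is concatenated with the one-hot label and reshaped to the NHWC input the generator expects. A standalone sketch of that step, assuming a 100-dimensional noise vector and an np.eye-style one-hot (the function name build_latent is illustrative, not from the original class):

import numpy as np

def build_latent(labels, no_classes, z_dim=100):
    # labels: (N,) integer class ids; returns an (N, 1, 1, z_dim + no_classes) array.
    labels = np.asarray(labels, dtype=np.int64).reshape(-1)
    onehot = np.eye(no_classes, dtype=np.float32)[labels]
    z = np.random.normal(0, 1, (labels.shape[0], z_dim))
    z = np.concatenate((z, onehot), axis=1)
    return z.reshape((labels.shape[0], 1, 1, -1))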
Example 4
    def train(self, data, labels, no_epochs=1000):
        batch_size = self.batch_size
        no_data = data.shape[0]
        bat_num = no_data // batch_size
        index_data = np.arange(no_data)
        print("Initiate...")

        if not self.is_restored:
            self.sess.run(tf.global_variables_initializer())

        self.writer_train = tf.summary.FileWriter(self.logs_path)
        if not os.path.isdir(self.save_path_models):
            os.makedirs(self.save_path_models)

        if not os.path.isdir(self.save_path_imgs):
            os.makedirs(self.save_path_imgs)

        if not os.path.isdir(self.save_path_generator):
            os.makedirs(self.save_path_generator)

        self.writer_train.add_graph(self.sess.graph)

        # Fixed test labels: one row per class (capped at 10 classes), 100 samples each
        n_test_classes = min(self.no_classes, 10)
        labels_test = np.ones((n_test_classes, 100), dtype=np.int16)
        for i in range(n_test_classes):
            labels_test[i] = labels_test[i] * i

        z_test = np.random.normal(0, 1, (100, 100))
        generate_imgs_time = 0
        print("Start training ACGAN...")
        for epoch in range(no_epochs):
            print("")
            print('epoch {}:'.format(epoch + 1))
            np.random.shuffle(index_data)
            start = time.time()
            # keep references to the last batch for the end-of-epoch loss summary below
            x_ = []
            y_ = []
            z_ = []
            for ite in range(bat_num):
                batch_idx = index_data[ite * batch_size:(ite + 1) * batch_size]
                x_ = data[batch_idx]
                y_ = labels[batch_idx]
                y_onehot = one_hot(y_, self.no_classes)
                z_ = np.random.normal(0, 1, (batch_size, 100))
                z_ = np.concatenate((z_, y_onehot), axis=1)
                z_ = z_.reshape((batch_size, 1, 1, -1))

                if epoch == 0:
                    # first epoch: clip the discriminator weights and train only D
                    self.sess.run(self.d_clip)
                    _ = self.sess.run(self.D_optim, {
                        self.X: x_,
                        self.y: y_,
                        self.z: z_,
                        self.istraining: True
                    })
                    continue

                if (ite + 1) % 5 == 0:
                    # train G on every 5th mini-batch
                    _ = self.sess.run(self.G_optim, {
                        self.X: x_,
                        self.y: y_,
                        self.z: z_,
                        self.istraining: True
                    })
                else:
                    # otherwise clip the discriminator weights and train D
                    self.sess.run(self.d_clip)
                    _ = self.sess.run(self.D_optim, {
                        self.X: x_,
                        self.y: y_,
                        self.z: z_,
                        self.istraining: True
                    })

                if ite + 1 == bat_num:  # on the final batch of each epoch, generate and save sample images
                    for i in range(labels_test.shape[0]):
                        # c means class
                        labels_test_c = one_hot(labels_test[i],
                                                self.no_classes)
                        no_test_sample = len(labels_test_c)
                        z_test_c = np.concatenate((z_test, labels_test_c),
                                                  axis=1)
                        z_test_c = z_test_c.reshape(
                            (no_test_sample, 1, 1, self.z_dim))

                        G_Images_c = self.sess.run(self.g_net, {
                            self.z: z_test_c,
                            self.istraining: False
                        })
                        G_Images_c = (G_Images_c + 1.0) / 2.0
                        G_Images_c = make_image_from_batch(G_Images_c)
                        G_Images_c = (G_Images_c * 255).astype(np.uint8)
                        if self.no_channels == 3:
                            G_Images_c = G_Images_c[:, :, ::-1]
                        cv2.imwrite(
                            '{}/epoch_{}_class_{}.png'.format(
                                self.save_path_imgs, epoch, i), G_Images_c)
                    generate_imgs_time = generate_imgs_time + 1
                    labels_test_c = []
                progress(ite + 1, bat_num)
            # log the losses of the final batch of each epoch only
            # self.list_loss = [
            #     self.loss_Class_fake, self.loss_Class_real, self.D_loss_w,
            #     self.g_loss_w, self.D_loss_all,  self.G_loss_all
            # ]
            loss_C_fake, loss_C_real, D_loss_w, g_loss_w, D_loss_all, G_loss_all = self.sess.run(
                self.list_loss,
                feed_dict={
                    self.X: x_,
                    self.y: y_,
                    self.z: z_,
                    self.istraining: False
                })
            D_loss_all = D_loss_all / 2.0
            summary_for_loss = self.sess.run(
                self.summary_loss,
                feed_dict={
                    self.loss_c_fake_ph: loss_C_fake,
                    self.loss_c_real_ph: loss_C_real,
                    self.loss_D_w_ph: D_loss_w,
                    self.loss_D_total_ph: D_loss_all,
                    self.loss_G_w_ph: g_loss_w,
                    self.loss_G_total_ph: G_loss_all,
                    self.istraining: False
                })
            self.writer_train.add_summary(summary_for_loss, epoch)
            save_path_models = self.saver_models.save(
                self.sess,
                "{}/model_{}/model.ckpt".format(self.save_path_models, epoch))

            save_path_G = self.saver_G.save(
                self.sess, "{}/model.ckpt".format(self.save_path_generator))

            stop = time.time()
            print("")
            print('time: {}'.format(stop - start))
            print('loss D: {}, loss G: {}'.format(D_loss_all, G_loss_all))
            print('saved model in: {}'.format(save_path_models))
            print('saved G in: {}'.format(save_path_G))
            print("")
            print("=======================================")
        self.writer_train.close()
        self.sess.close()
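
The epoch loop above shuffles an index array and slices data and labels with it to form mini-batches, dropping the final partial batch. A self-contained sketch of that batching pattern, assuming data of shape (N, H, W, C) and labels as an (N,) array of integer class ids (the helper name iterate_minibatches is illustrative):

import numpy as np

def iterate_minibatches(data, labels, batch_size):
    # Yield shuffled (x_, y_) mini-batches; the last partial batch is dropped,
    # matching the `no_data // batch_size` loop above.
    index_data = np.arange(data.shape[0])
    np.random.shuffle(index_data)
    for ite in range(data.shape[0] // batch_size):
        batch_idx = index_data[ite * batch_size:(ite + 1) * batch_size]
        yield data[batch_idx], labels[batch_idx]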