# Example 1
    def _get_gt_mask(self, gt_bboxes, class_labels, rescale_shape):
        """ Build the YOLO-style ground-truth target mask for one image.

        Args:
            gt_bboxes (np.array): ground-truth boxes in xyxy format,
                shape [n_gt, 4]
            class_labels: class id for each ground-truth box, length n_gt
            rescale_shape: rescaled image shape; rescale_shape[0] keys the
                pre-computed anchor / mask templates

        Returns:
            Nested structure gt_mask[scale][prior] of feature maps where
            channels 0:4 hold the yolo t-coordinates of the assigned box,
            channel 4 the objectness flag and channels 5: the one-hot
            class label.
        """
        init_anchors = self.init_anchors_dict[rescale_shape[0]]
        ind_list = init_anchors['index']
        sub2ind = init_anchors['sub2ind']

        # start from a zeroed template so cells with no gt stay negative
        gt_mask = copy.deepcopy(self.init_gt_mask_dict[rescale_shape[0]])
        one_hot_label = vec2onehot(class_labels, self._n_class)
        # integer center (cx, cy) of every gt box
        gt_cxy = np.stack([(gt_bboxes[:, 0] + gt_bboxes[:, 2]) // 2,
                           (gt_bboxes[:, 1] + gt_bboxes[:, 3]) // 2],
                          axis=-1)

        # IOU of every gt box against every anchor box (centers aligned)
        iou_mat = bboxtool.bbox_list_IOU(gt_bboxes,
                                         bboxtool.cxywh2xyxy(
                                             self._anchor_boxes),
                                         align=True)
        target_anchor_list = np.argmax(iou_mat, axis=-1)
        for gt_id, (target_anchor_idx,
                    gt_bbox) in enumerate(zip(target_anchor_list, gt_bboxes)):
            if iou_mat[gt_id, target_anchor_idx] == 0:
                # no anchor overlaps this gt box at all; skip it
                continue
            # collect, for every scale, the flat indices of the anchors
            # whose cell contains the gt center; target_anchor_idx then
            # selects the best-IOU anchor among them
            anchor_idx_list = []
            for scale_id, stride in enumerate(self._stride_list):
                anchor_feat_cxy = gt_cxy[gt_id] // stride
                anchor_idx_list += sub2ind[(scale_id, anchor_feat_cxy[1],
                                            anchor_feat_cxy[0])]

            anchor_idx = anchor_idx_list[target_anchor_idx]
            scale_id, prior_id, row_id, col_id, anchor_xyxy, anchor_stride =\
                self._get_anchor_property(anchor_idx, ind_list)

            gt_mask[scale_id][prior_id][row_id, col_id, :4] =\
                bboxtool.xyxy2yolotcoord([gt_bbox], anchor_xyxy,
                                         anchor_stride, [col_id, row_id])
            gt_mask[scale_id][prior_id][row_id, col_id, 4] = 1
            # TODO multi-class: currently a single one-hot label per box
            gt_mask[scale_id][prior_id][row_id, col_id,
                                        5:] = one_hot_label[gt_id]

        return gt_mask
# Example 2
    def random_sampling(self,
                        sess,
                        keep_prob=1.0,
                        plot_size=10,
                        file_id=None,
                        save_path=None):
        """ Generate and save a plot_size x plot_size grid of random samples.

        Args:
            sess (tf.Session): tensorflow session
            keep_prob (float): keep probability for dropout
            plot_size (int): number of samples per side of the square grid
            file_id (int): index for saving image
            save_path (str): directory for saving image
        """
        n_samples = plot_size * plot_size
        random_vec = distributions.random_vector((n_samples, self.n_code),
                                                 dist_type='uniform')
        code_cont = distributions.random_vector((n_samples, self.n_continuous),
                                                dist_type='uniform')

        if self.n_discrete <= 0:
            code_discrete = np.zeros((n_samples, 0))
        else:
            # one one-hot block per discrete code, joined along the feature
            # axis (replaces the exception-driven first-iteration idiom)
            onehot_blocks = []
            for code_id in range(self.n_discrete):
                n_class = self.cat_n_class_list[code_id]
                cur_code = np.random.choice(n_class, n_samples)
                onehot_blocks.append(dfutils.vec2onehot(cur_code, n_class))
            code_discrete = np.concatenate(onehot_blocks, axis=-1)

        self._viz_samples(sess,
                          random_vec,
                          code_discrete,
                          code_cont,
                          keep_prob,
                          plot_size=[plot_size, plot_size],
                          save_path=save_path,
                          file_name='random_sampling',
                          file_id=file_id)
# Example 3
    def interp_cont_sampling(self,
                             sess,
                             n_interpolation,
                             cont_code_id,
                             vary_discrete_id=None,
                             n_col_samples=None,
                             keep_prob=1.,
                             save_path=None,
                             file_id=None):
        """ Sample interpolation of one of continuous codes.

        Args:
            sess (tf.Session): tensorflow session
            n_interpolation (int): number of interpolation samples per row
            cont_code_id (int): index of continuous code for interpolation
            vary_discrete_id (int): Index of discrete code for varying
                while sample interpolation. All the discrete code will be fixed
                if it is None.
            n_col_samples (int): number of rows, each with its own random
                remaining continuous code. Only used when vary_discrete_id
                is None; one row is generated if this is also None.
            keep_prob (float): keep probability for dropout
            save_path (str): directory for saving image
            file_id (int): index for saving image
        """
        if cont_code_id >= self.n_continuous:
            return
        vary_discrete = (vary_discrete_id is not None
                         and vary_discrete_id < self.n_discrete)
        if vary_discrete:
            n_vary_class = self.cat_n_class_list[vary_discrete_id]
        elif n_col_samples is not None:
            n_vary_class = n_col_samples
        else:
            n_vary_class = 1
        n_samples = n_interpolation * n_vary_class

        # one shared noise vector for all samples so only the codes vary
        random_vec = distributions.random_vector((1, self.n_code),
                                                 dist_type='uniform')
        random_vec = np.tile(random_vec, (n_samples, 1))

        if self.n_discrete <= 0:
            code_discrete = np.zeros((n_samples, 0))
        else:
            onehot_blocks = []
            for code_id in range(self.n_discrete):
                n_class = self.cat_n_class_list[code_id]
                if code_id == vary_discrete_id:
                    # one class per row, repeated across the interpolation
                    cur_code = [
                        class_id for class_id in range(n_class)
                        for _ in range(n_interpolation)
                    ]
                else:
                    # a single random class shared by every sample
                    cur_code = np.random.choice(n_class, 1) * np.ones(
                        (n_samples))
                onehot_blocks.append(dfutils.vec2onehot(cur_code, n_class))
            code_discrete = np.concatenate(onehot_blocks, axis=-1)

        if vary_discrete:
            # all rows share the same remaining continuous code
            code_cont = distributions.random_vector((1, self.n_continuous),
                                                    dist_type='uniform')
            code_cont = np.tile(code_cont, (n_samples, 1))
        else:
            # each row gets its own remaining continuous code
            # (bug fix: was random_vector((n_col_samples, ...)), which
            # crashed when n_col_samples was None; n_vary_class is the
            # correct row count in both sub-cases)
            code_cont = distributions.random_vector(
                (n_vary_class, self.n_continuous), dist_type='uniform')
            code_cont = np.repeat(code_cont, n_interpolation, axis=0)
        # overwrite the interpolated code with a [-1, 1] sweep per row
        cont_interp = np.linspace(-1., 1., n_interpolation)
        code_cont[:, cont_code_id] = np.tile(cont_interp, (n_vary_class))

        self._viz_samples(sess,
                          random_vec,
                          code_discrete,
                          code_cont,
                          keep_prob,
                          plot_size=[n_vary_class, n_interpolation],
                          save_path=save_path,
                          file_name='interp_cont_{}'.format(cont_code_id),
                          file_id=file_id)
# Example 4
    def vary_discrete_sampling(self,
                               sess,
                               vary_discrete_id,
                               keep_prob=1.0,
                               sample_per_class=10,
                               file_id=None,
                               save_path=None):
        """ Sampling by varying a discrete code.

        Generates a grid with one row per class of the varied discrete
        code and sample_per_class columns.

        Args:
            sess (tf.Session): tensorflow session
            vary_discrete_id (int): index of discrete code for varying
            keep_prob (float): keep probability for dropout
            sample_per_class (int): number of samples for each class
            file_id (int): index for saving image
            save_path (str): directory for saving image
        """
        if vary_discrete_id >= self.n_discrete:
            return
        n_vary_class = self.cat_n_class_list[vary_discrete_id]
        n_samples = n_vary_class * sample_per_class

        random_vec = distributions.random_vector((n_samples, self.n_code),
                                                 dist_type='uniform')

        if self.n_discrete <= 0:
            code_discrete = np.zeros((n_samples, 0))
        else:
            onehot_blocks = []
            for code_id in range(self.n_discrete):
                n_class = self.cat_n_class_list[code_id]
                if code_id == vary_discrete_id:
                    # row-major: class 0 repeated sample_per_class times,
                    # then class 1, ...
                    cur_code = [
                        class_id for class_id in range(n_class)
                        for _ in range(sample_per_class)
                    ]
                else:
                    # same random codes in every row so only the varied
                    # discrete code changes across rows
                    cur_code = [
                        np.random.choice(n_class)
                        for _ in range(sample_per_class)
                    ]
                    cur_code = np.tile(cur_code, (n_vary_class))
                onehot_blocks.append(dfutils.vec2onehot(cur_code, n_class))
            code_discrete = np.concatenate(onehot_blocks, axis=-1)

        code_cont = distributions.random_vector((n_samples, self.n_continuous),
                                                dist_type='uniform')

        self._viz_samples(
            sess,
            random_vec,
            code_discrete,
            code_cont,
            keep_prob,
            plot_size=[n_vary_class, sample_per_class],
            save_path=save_path,
            file_name='vary_discrete_{}'.format(vary_discrete_id),
            file_id=file_id)
# Example 5
    def train_epoch(self,
                    sess,
                    train_data,
                    init_lr,
                    n_g_train=1,
                    n_d_train=1,
                    keep_prob=1.0,
                    summary_writer=None):
        """ Train for one epoch of training data

        Args:
            sess (tf.Session): tensorflow session
            train_data (DataFlow): DataFlow for training set
            init_lr (float): initial learning rate. NOTE(review): only
                printed in the end-of-epoch banner; the actual rates fed
                to the optimizers are the hard-coded lr_D / lr_G below —
                confirm this is intentional.
            n_g_train (int): number of times of generator training for each step
            n_d_train (int): number of times of discriminator training for each step
            keep_prob (float): keep probability for dropout
            summary_writer (tf.FileWriter): write for summary. No summary will be
            saved if None.
        """

        assert int(n_g_train) > 0 and int(n_d_train) > 0
        display_name_list = ['d_loss', 'g_loss', 'LI_G', 'LI_D']
        cur_summary = None

        lr = init_lr
        # hard-coded per-network learning rates override init_lr
        lr_D = 2e-4
        lr_G = 1e-3
        # lr_G = lr_D * 10

        # running sums over the epoch, for the periodic display below
        cur_epoch = train_data.epochs_completed
        step = 0
        d_loss_sum = 0
        g_loss_sum = 0
        LI_G_sum = 0
        LI_D_sum = 0
        self.epoch_id += 1
        # loop until the DataFlow rolls over to the next epoch
        while cur_epoch == train_data.epochs_completed:
            self.global_step += 1
            step += 1

            batch_data = train_data.next_batch_dict()
            im = batch_data['im']

            # latent noise vector, one row per image in the batch
            random_vec = distributions.random_vector((len(im), self.n_code),
                                                     dist_type='uniform')

            # code_discrete = []
            # discrete_label = []
            if self.n_discrete <= 0:
                code_discrete = np.zeros((len(im), 0))
                discrete_label = np.zeros((len(im), self.n_discrete))
            else:
                # sample a random class per discrete code; keep both the
                # one-hot encoding (network input) and the raw labels
                code_discrete = []
                discrete_label = []
                for i in range(self.n_discrete):
                    n_class = self.cat_n_class_list[i]
                    cur_code = np.random.choice(n_class, (len(im)))
                    cur_onehot_code = dfutils.vec2onehot(cur_code, n_class)
                    try:
                        # append along the feature axis
                        code_discrete = np.concatenate(
                            (code_discrete, cur_onehot_code), axis=-1)
                        discrete_label = np.concatenate(
                            (discrete_label, np.expand_dims(cur_code,
                                                            axis=-1)),
                            axis=-1)
                    except ValueError:
                        # first iteration: concatenating with the empty
                        # python list raises, so just assign
                        code_discrete = cur_onehot_code
                        discrete_label = np.expand_dims(cur_code, axis=-1)

            code_cont = distributions.random_vector(
                (len(im), self.n_continuous), dist_type='uniform')

            # train discriminator (same latent codes reused for every run)
            for i in range(int(n_d_train)):

                _, d_loss, LI_D = sess.run(
                    [self.train_d_op, self.d_loss_op, self.LI_D],
                    feed_dict={
                        self.real: im,
                        self.lr: lr_D,
                        self.keep_prob: keep_prob,
                        self.random_vec: random_vec,
                        self.code_discrete: code_discrete,
                        self.discrete_label: discrete_label,
                        self.code_continuous: code_cont
                    })
            # train generator
            for i in range(int(n_g_train)):
                _, g_loss, LI_G = sess.run(
                    [self.train_g_op, self.g_loss_op, self.LI_G],
                    feed_dict={
                        self.lr: lr_G,
                        self.keep_prob: keep_prob,
                        self.random_vec: random_vec,
                        self.code_discrete: code_discrete,
                        self.discrete_label: discrete_label,
                        self.code_continuous: code_cont
                    })

            # accumulate the last loss of each inner loop
            d_loss_sum += d_loss
            g_loss_sum += g_loss
            LI_G_sum += LI_G
            LI_D_sum += LI_D

            # periodic summary + console display
            if step % 100 == 0:
                cur_summary = sess.run(self.train_summary_op,
                                       feed_dict={
                                           self.real: im,
                                           self.keep_prob: keep_prob,
                                           self.random_vec: random_vec,
                                           self.code_discrete: code_discrete,
                                           self.discrete_label: discrete_label,
                                           self.code_continuous: code_cont
                                       })

                viz.display(self.global_step,
                            step, [
                                d_loss_sum / n_d_train, g_loss_sum / n_g_train,
                                LI_G_sum, LI_D_sum
                            ],
                            display_name_list,
                            'train',
                            summary_val=cur_summary,
                            summary_writer=summary_writer)

        # end-of-epoch summary on the last batch of the epoch
        print('==== epoch: {}, lr:{} ===='.format(cur_epoch, lr))
        cur_summary = sess.run(self.train_summary_op,
                               feed_dict={
                                   self.real: im,
                                   self.keep_prob: keep_prob,
                                   self.random_vec: random_vec,
                                   self.code_discrete: code_discrete,
                                   self.discrete_label: discrete_label,
                                   self.code_continuous: code_cont
                               })
        viz.display(self.global_step,
                    step, [
                        d_loss_sum / n_d_train, g_loss_sum / n_g_train,
                        LI_G_sum, LI_D_sum
                    ],
                    display_name_list,
                    'train',
                    summary_val=cur_summary,
                    summary_writer=summary_writer)