コード例 #1
0
ファイル: yolov3.py プロジェクト: tianyaoZhang/yolov3
    def valid_epoch(self, sess, summary_writer=None):
        """ Evaluate the model on the whole validation set once.

            Repeatedly fetches the classification, bounding-box, objectness
            and total losses until the input pipeline raises
            ``tf.errors.OutOfRangeError``, sums them, and reports the sums
            together with the number of evaluated batches through
            ``viz.display`` under the 'valid' collection.
            ``self.valid_summary_op`` is evaluated only together with the
            first batch, and that summary is what gets reported.

            No ``feed_dict`` is passed, so batches presumably come from an
            initialized dataset iterator (TODO confirm against caller).

            Args:
                sess (tf.Session): session used to evaluate the loss ops.
                summary_writer (tf.summary.FileWriter): forwarded to
                    ``viz.display``; may be None.
        """
        display_name_list = ['cls_loss', 'bbox_loss', 'obj_loss', 'loss']
        loss_fetches = [self.loss_op, self.cls_loss, self.bbox_loss,
                        self.obj_loss]
        cur_summary = None

        cls_loss_sum = 0
        bbox_loss_sum = 0
        obj_loss_sum = 0
        loss_sum = 0
        step = 0

        self.epoch_id += 1

        while True:
            first_batch = step == 0
            # The validation summary is only fetched alongside the first batch.
            if first_batch:
                fetches = loss_fetches + [self.valid_summary_op]
            else:
                fetches = loss_fetches
            try:
                results = sess.run(fetches)
            except tf.errors.OutOfRangeError:
                break

            # Count only batches that were actually evaluated, so that `step`
            # equals the number of batches whose losses are in the sums.
            step += 1
            # NOTE(review): validation also advances the shared global step;
            # kept to preserve existing behavior.
            self.global_step += 1

            if first_batch:
                cur_summary = results.pop()
            loss, cls_loss, bbox_loss, obj_loss = results

            cls_loss_sum += cls_loss
            bbox_loss_sum += bbox_loss
            obj_loss_sum += obj_loss
            loss_sum += loss

        if step == 0:
            # Nothing was evaluated; avoid reporting with a zero batch count.
            print('[valid]: no validation batches')
            return

        # write summary
        print('[valid]:', end='')
        viz.display(self.epoch_id,
                    step,
                    [cls_loss_sum, bbox_loss_sum, obj_loss_sum, loss_sum],
                    display_name_list,
                    'valid',
                    summary_val=cur_summary,
                    summary_writer=summary_writer)
コード例 #2
0
 def generate_batch(self, sess, summary_writer=None):
     """ Advance the global step and report one evaluation of the generate summary.

     Runs ``self.generate_summary_op`` once and passes the result to
     ``viz.display`` under the 'train' collection, with a step count of 1
     and no scalar values.

     Args:
         sess (tf.Session): session used to run the summary op.
         summary_writer (tf.summary.FileWriter): forwarded to ``viz.display``.
     """
     self.global_step += 1
     generated_summary = sess.run(self.generate_summary_op)
     viz.display(self.global_step, 1, [], [], 'train',
                 summary_val=generated_summary,
                 summary_writer=summary_writer)
コード例 #3
0
    def train_steps(self,
                    sess,
                    train_data,
                    init_lr=1e-3,
                    t0=15000,
                    t1=25000,
                    max_step=100,
                    summary_writer=None):
        """ Run up to ``max_step`` training iterations with a decaying learning rate.

        The learning rate is ``init_lr`` while ``self.global_step <= t0``.
        After that it decays geometrically and reaches ``init_lr * 0.001`` at
        ``global_step == t1``. Iteration stops after ``max_step`` steps, or
        once ``self.global_step`` exceeds ``t1``, whichever comes first.

        Args:
            sess (tf.Session): session used to run the ops.
            train_data: object whose ``next_batch_dict()`` returns a dict
                with 'im' and 'label' entries.
            init_lr (float): base learning rate.
            t0 (int): global step after which the decay starts.
            t1 (int): global step at which decay and training end.
            max_step (int): maximum number of iterations in this call.
            summary_writer (tf.summary.FileWriter): receives the per-step
                training summary; per-step summaries are skipped when None.
        """
        def lr_at(global_step):
            # The `t1 <= t0` guard only matters outside the loop (inside it,
            # global_step <= t1 <= t0 already holds). It avoids a division by
            # zero for a degenerate schedule.
            if global_step <= t0 or t1 <= t0:
                return init_lr
            return init_lr * (0.001**((global_step - t0) / (t1 - t0)))

        self.epoch_id += 1
        display_name_list = ['loss']

        cur_summary = None
        step = 0
        loss_sum = 0
        # Defined up front so the report below works even if no step runs.
        lr = lr_at(self.global_step)
        while step < max_step and self.global_step <= t1:
            lr = lr_at(self.global_step)
            step += 1
            self.global_step += 1

            batch_data = train_data.next_batch_dict()
            _, loss, cur_summary = sess.run(
                [self.train_op, self.loss_op, self.train_summary_op],
                feed_dict={
                    self.lr: lr,
                    self.image: batch_data['im'],
                    self.label: batch_data['label']
                })
            loss_sum += loss
            if summary_writer is not None:
                summary_writer.add_summary(cur_summary, self.global_step)

        print('==== lr:{} ===='.format(lr))
        if step == 0:
            # No iteration ran (max_step <= 0 or schedule already finished).
            return
        viz.display(global_step=self.global_step,
                    step=step,
                    scaler_sum_list=[loss_sum],
                    name_list=display_name_list,
                    collection='train',
                    summary_val=cur_summary,
                    summary_writer=summary_writer)
コード例 #4
0
    def train_epoch(self, sess, lr, max_step=None, summary_writer=None):
        """ Train until the input pipeline is exhausted or ``max_step`` steps ran.

        Each iteration runs ``self.train_op`` and fetches the loss and the
        training summary. Progress is reported through ``viz.display`` every
        100 steps and once at the end.

        Args:
            sess (tf.Session): session used to run the ops.
            lr (float): learning rate fed to ``self.lr``.
            max_step (int): optional cap on iterations; no cap when None.
            summary_writer (tf.summary.FileWriter): forwarded to ``viz.display``.
        """
        display_name_list = ['loss']
        cur_summary = None

        loss_sum = 0
        step = 0

        while max_step is None or step < max_step:
            try:
                # NOTE(review): keep_prob is fed as 1. during training.
                _, loss, cur_summary = sess.run(
                    [self.train_op, self.loss_op, self.train_summary_op],
                    feed_dict={
                        self.lr: lr,
                        self.keep_prob: 1.
                    })
            except tf.errors.OutOfRangeError:
                break

            # Count only batches that were actually trained on; the
            # end-of-data call must not bump the step counters.
            step += 1
            self.global_step += 1
            loss_sum += loss

            if step % 100 == 0:
                viz.display(self.global_step,
                            step, [loss_sum],
                            display_name_list,
                            'train',
                            summary_val=cur_summary,
                            summary_writer=summary_writer)

        if step == 0:
            # No batch was available; nothing to report.
            return
        viz.display(self.global_step,
                    step, [loss_sum],
                    display_name_list,
                    'train',
                    summary_val=cur_summary,
                    summary_writer=summary_writer)
コード例 #5
0
    def train_epoch(self, sess, train_data, lr, summary_writer=None):
        """ Train on ``train_data`` until it completes one more epoch.

        Each step feeds a batch from ``train_data.next_batch_dict()`` ('im'
        and 'label' entries) and runs ``self.train_op``. The accumulated loss
        is reported through ``viz.display`` every 100 steps and at the end.
        No summary op is run, so ``summary_val`` is always None.

        Args:
            sess (tf.Session): session used to run the ops.
            train_data: data source exposing ``epochs_completed`` and
                ``next_batch_dict()``.
            lr (float): learning rate fed to ``self.lr``.
            summary_writer (tf.summary.FileWriter): forwarded to ``viz.display``.
        """
        self.epoch_id += 1
        display_name_list = ['loss']
        cur_epoch = train_data.epochs_completed

        cur_summary = None
        step = 0
        # Accumulated over the whole epoch; must not be reset per batch.
        loss_sum = 0
        while train_data.epochs_completed <= cur_epoch:
            step += 1
            self.global_step += 1

            batch_data = train_data.next_batch_dict()
            _, loss = sess.run(
                [self.train_op, self.loss_op],
                feed_dict={
                    self.lr: lr,
                    self.image: batch_data['im'],
                    self.label: batch_data['label']
                })

            loss_sum += loss

            if step % 100 == 0:
                viz.display(global_step=self.global_step,
                            step=step,
                            scaler_sum_list=[loss_sum],
                            name_list=display_name_list,
                            collection='train',
                            summary_val=cur_summary,
                            summary_writer=summary_writer)

        print('==== epoch: {}, lr:{} ===='.format(cur_epoch, lr))
        viz.display(global_step=self.global_step,
                    step=step,
                    scaler_sum_list=[loss_sum],
                    name_list=display_name_list,
                    collection='train',
                    summary_val=cur_summary,
                    summary_writer=summary_writer)
コード例 #6
0
ファイル: infogan.py プロジェクト: conan7882/tf-gans
    def train_epoch(self,
                    sess,
                    train_data,
                    init_lr,
                    n_g_train=1,
                    n_d_train=1,
                    keep_prob=1.0,
                    summary_writer=None):
        """ Train for one epoch of training data

        Each step samples a batch of real images and fresh latent inputs:
        uniform noise, one-hot categorical codes and uniform continuous codes.
        It runs the discriminator update ``n_d_train`` times and the generator
        update ``n_g_train`` times on those same inputs. Every 100 steps and at
        the end of the epoch, the training summary is evaluated on the latest
        batch and reported through ``viz.display``.

        Args:
            sess (tf.Session): tensorflow session
            train_data (DataFlow): DataFlow for training set
            init_lr (float): initial learning rate. NOTE: currently unused;
                the discriminator and generator are trained with fixed rates
                ``lr_D = 2e-4`` and ``lr_G = 1e-3``.
            n_g_train (int): number of times of generator training for each step
            n_d_train (int): number of times of discriminator training for each step
            keep_prob (float): keep probability for dropout
            summary_writer (tf.FileWriter): write for summary. No summary will be
            saved if None.
        """
        assert int(n_g_train) > 0 and int(n_d_train) > 0
        n_g = int(n_g_train)
        n_d = int(n_d_train)
        display_name_list = ['d_loss', 'g_loss', 'LI_G', 'LI_D']

        # Fixed per-network learning rates (init_lr is not used for training).
        lr_D = 2e-4
        lr_G = 1e-3

        def report(step, summary_feed, scalar_sums):
            # Evaluate the training summary on `summary_feed` and display it.
            summary = sess.run(self.train_summary_op, feed_dict=summary_feed)
            viz.display(self.global_step,
                        step,
                        scalar_sums,
                        display_name_list,
                        'train',
                        summary_val=summary,
                        summary_writer=summary_writer)

        cur_epoch = train_data.epochs_completed
        step = 0
        d_loss_sum = 0
        g_loss_sum = 0
        LI_G_sum = 0
        LI_D_sum = 0
        self.epoch_id += 1
        while cur_epoch == train_data.epochs_completed:
            self.global_step += 1
            step += 1

            batch_data = train_data.next_batch_dict()
            im = batch_data['im']
            batch_size = len(im)

            random_vec = distributions.random_vector(
                (batch_size, self.n_code), dist_type='uniform')

            if self.n_discrete <= 0:
                code_discrete = np.zeros((batch_size, 0))
                discrete_label = np.zeros((batch_size, 0))
            else:
                onehot_codes = []
                class_ids = []
                for i in range(self.n_discrete):
                    n_class = self.cat_n_class_list[i]
                    cur_code = np.random.choice(n_class, batch_size)
                    onehot_codes.append(dfutils.vec2onehot(cur_code, n_class))
                    class_ids.append(cur_code)
                # One-hot blocks side by side; one class-index column per code.
                code_discrete = np.concatenate(onehot_codes, axis=-1)
                discrete_label = np.stack(class_ids, axis=-1)

            code_cont = distributions.random_vector(
                (batch_size, self.n_continuous), dist_type='uniform')

            # Inputs shared by every run in this step.
            code_feed = {
                self.keep_prob: keep_prob,
                self.random_vec: random_vec,
                self.code_discrete: code_discrete,
                self.discrete_label: discrete_label,
                self.code_continuous: code_cont,
            }
            summary_feed = dict(code_feed)
            summary_feed[self.real] = im
            d_feed = dict(summary_feed)
            d_feed[self.lr] = lr_D
            g_feed = dict(code_feed)
            g_feed[self.lr] = lr_G

            # train discriminator; every update's loss is accumulated so that
            # dividing by n_d below yields the per-step mean.
            for _ in range(n_d):
                _, d_loss, LI_D = sess.run(
                    [self.train_d_op, self.d_loss_op, self.LI_D],
                    feed_dict=d_feed)
                d_loss_sum += d_loss
            # train generator; same accumulation as above.
            for _ in range(n_g):
                _, g_loss, LI_G = sess.run(
                    [self.train_g_op, self.g_loss_op, self.LI_G],
                    feed_dict=g_feed)
                g_loss_sum += g_loss

            # The mutual-information terms use the last update of the step.
            LI_G_sum += LI_G
            LI_D_sum += LI_D

            if step % 100 == 0:
                report(step, summary_feed,
                       [d_loss_sum / n_d, g_loss_sum / n_g, LI_G_sum, LI_D_sum])

        print('==== epoch: {}, lr_D:{}, lr_G:{} ===='.format(
            cur_epoch, lr_D, lr_G))
        report(step, summary_feed,
               [d_loss_sum / n_d, g_loss_sum / n_g, LI_G_sum, LI_D_sum])
コード例 #7
0
ファイル: yolov3.py プロジェクト: tianyaoZhang/yolov3
    def train_epoch(self, sess, init_lr, summary_writer=None):
        """ Train the model for one epoch

            Runs ``self.train_op`` until the input pipeline raises
            ``tf.errors.OutOfRangeError``, accumulating the classification,
            bounding-box, objectness and total losses. Every 100th step also
            evaluates ``self.train_summary_op`` and reports the running sums
            through ``viz.display``. A final report is made without a summary.

            Args:
                sess (tf.Session): session used to run the ops.
                init_lr (float): learning rate fed to ``self.lr`` for every step.
                summary_writer (tf.summary.FileWriter): forwarded to
                    ``viz.display``.
        """
        display_name_list = ['cls_loss', 'bbox_loss', 'obj_loss', 'loss']
        train_fetches = [
            self.train_op, self.loss_op, self.cls_loss, self.bbox_loss,
            self.obj_loss
        ]
        lr = init_lr

        cls_loss_sum = 0
        bbox_loss_sum = 0
        obj_loss_sum = 0
        loss_sum = 0
        step = 0

        self.epoch_id += 1
        while True:
            log_step = (step + 1) % 100 == 0
            if log_step:
                fetches = train_fetches + [self.train_summary_op]
            else:
                fetches = train_fetches
            try:
                results = sess.run(fetches, feed_dict={self.lr: lr})
            except tf.errors.OutOfRangeError:
                break

            # Count only batches that were actually trained on.
            step += 1
            self.global_step += 1

            cur_summary = results.pop() if log_step else None
            _, loss, cls_loss, bbox_loss, obj_loss = results

            # Accumulate before reporting so the sums include this batch and
            # match the `step` count passed to viz.display.
            cls_loss_sum += cls_loss
            bbox_loss_sum += bbox_loss
            obj_loss_sum += obj_loss
            loss_sum += loss

            if log_step:
                viz.display(
                    self.global_step,
                    step,
                    [cls_loss_sum, bbox_loss_sum, obj_loss_sum, loss_sum],
                    display_name_list,
                    'train',
                    summary_val=cur_summary,
                    summary_writer=summary_writer)

        # write summary
        print('==== epoch: {}, lr:{} ===='.format(self.epoch_id, lr))
        if step == 0:
            # No batch was available; nothing to average.
            return
        viz.display(self.global_step,
                    step,
                    [cls_loss_sum, bbox_loss_sum, obj_loss_sum, loss_sum],
                    display_name_list,
                    'train',
                    summary_val=None,
                    summary_writer=summary_writer)
コード例 #8
0
    def train_epoch(self,
                    sess,
                    train_data,
                    init_lr,
                    n_g_train=1,
                    n_d_train=1,
                    keep_prob=1.0,
                    summary_writer=None):
        """ Train for one epoch of training data.

        The learning rate for this epoch is ``init_lr * 0.9 ** self.epoch_id``,
        evaluated before ``epoch_id`` is incremented. Each step samples a real
        batch and a uniform random vector, then does three things on those
        inputs:

        - runs the discriminator update ``n_d_train`` times;
        - runs the generator update ``n_g_train`` times;
        - runs ``self.update_op`` while fetching ``L_fake`` and ``L_real``.

        Every 100 steps and at the end of the epoch, the training summary is
        evaluated on the latest batch and reported through ``viz.display``.

        Args:
            sess (tf.Session): tensorflow session
            train_data (DataFlow): object exposing ``epochs_completed`` and
                ``next_batch_dict()`` (with an 'im' entry)
            init_lr (float): base learning rate
            n_g_train (int): generator updates per step
            n_d_train (int): discriminator updates per step
            keep_prob (float): value fed to ``self.keep_prob``
            summary_writer (tf.FileWriter): forwarded to ``viz.display``
        """
        assert int(n_g_train) > 0 and int(n_d_train) > 0
        n_g = int(n_g_train)
        n_d = int(n_d_train)
        display_name_list = ['d_loss', 'g_loss', 'L_fake', 'L_real']

        lr = init_lr * (0.9**self.epoch_id)

        def report(step, summary_feed, scalar_sums):
            # Evaluate the training summary on `summary_feed` and display it.
            summary = sess.run(self.train_summary_op, feed_dict=summary_feed)
            viz.display(self.global_step,
                        step,
                        scalar_sums,
                        display_name_list,
                        'train',
                        summary_val=summary,
                        summary_writer=summary_writer)

        cur_epoch = train_data.epochs_completed
        step = 0
        d_loss_sum = 0
        g_loss_sum = 0
        l_fake_sum = 0
        l_real_sum = 0
        self.epoch_id += 1
        while cur_epoch == train_data.epochs_completed:
            self.global_step += 1
            step += 1

            batch_data = train_data.next_batch_dict()
            im = batch_data['im']

            random_vec = distributions.random_vector((len(im), self.n_code),
                                                     dist_type='uniform')

            # Feeds shared by the runs in this step.
            real_feed = {
                self.real: im,
                self.keep_prob: keep_prob,
                self.random_vec: random_vec,
            }
            d_feed = dict(real_feed)
            d_feed[self.lr] = lr
            g_feed = {
                self.lr: lr,
                self.keep_prob: keep_prob,
                self.random_vec: random_vec,
            }

            # train discriminator; accumulate every update's loss so that
            # dividing by n_d below yields the per-step mean.
            for _ in range(n_d):
                _, d_loss = sess.run([self.train_d_op, self.d_loss_op],
                                     feed_dict=d_feed)
                d_loss_sum += d_loss

            # train generator; same accumulation as above.
            for _ in range(n_g):
                _, g_loss = sess.run([self.train_g_op, self.g_loss_op],
                                     feed_dict=g_feed)
                g_loss_sum += g_loss

            # update k
            _, L_fake, L_real = sess.run(
                [self.update_op, self.L_fake, self.L_real],
                feed_dict=real_feed)

            l_fake_sum += L_fake
            l_real_sum += L_real

            if step % 100 == 0:
                report(step, real_feed,
                       [d_loss_sum / n_d, g_loss_sum / n_g,
                        l_fake_sum, l_real_sum])

        print('==== epoch: {}, lr:{} ===='.format(cur_epoch, lr))
        report(step, real_feed,
               [d_loss_sum / n_d, g_loss_sum / n_g, l_fake_sum, l_real_sum])