Example #1
0
    def train(self):
        """Train on the source training set, evaluating and checkpointing each epoch.

        The run lasts ``self.eps`` epochs of ``self.itr_epoch`` iterations, where
        ``self.itr_epoch = len(self.source_training_data[0]) // self.bs``. Batches
        come from ``init.next_batch``. Whenever an epoch completes, the method:

        * evaluates ``self.merged`` on the last training batch (with
          ``is_training=False``) and writes it to ``self.writer``;
        * averages the epoch's training accuracy/loss and computes validation
          accuracy/loss with ``eval.validation_procedure``;
        * appends a log line with ``init.save2file``;
        * saves a checkpoint to ``self.ckptDir + self.model + '-<epoch>'``;
        * runs ``eval.test_procedure`` on the source and the target test sets.

        Side effects: sets ``self.itr_epoch`` and ``self.total_iteration`` and
        runs ``tf.global_variables_initializer()`` in ``self.sess``.

        Raises:
            ValueError: if the source training set has fewer than ``self.bs``
                samples. An epoch would then contain zero iterations, so the
                loop would never run and the call would silently train nothing.
        """
        n_train = len(self.source_training_data[0])
        self.itr_epoch = n_train // self.bs
        if self.itr_epoch == 0:
            # Fail fast instead of silently returning without any training.
            raise ValueError(
                "source training set has %d samples, fewer than the batch size "
                "(%d); an epoch would contain no iterations" % (n_train, self.bs))

        self.sess.run(tf.global_variables_initializer())
        self.total_iteration = self.eps * self.itr_epoch

        # Running sums over the current epoch; turned into means and reset at
        # every epoch boundary below.
        training_acc = 0.0
        training_loss = 0.0

        for itr in range(1, self.total_iteration + 1):
            _tr_img_batch, _tr_lab_batch = init.next_batch(image=self.source_training_data[0],
                                                           label=self.source_training_data[1],
                                                           batch_size=self.bs)

            _train_accuracy, _train_loss, _ = self.sess.run([self.accuracy, self.loss, self.train_op],
                                                            feed_dict={self.x: _tr_img_batch,
                                                                       self.y: _tr_lab_batch,
                                                                       self.is_training: True})
            training_acc += _train_accuracy
            training_loss += _train_loss

            # The last iteration of each epoch triggers evaluation and checkpointing.
            if itr % self.itr_epoch == 0:
                # itr is an exact multiple of itr_epoch here, so floor division is exact.
                _current_eps = itr // self.itr_epoch
                # Summaries are computed on the last training batch, in inference mode.
                summary = self.sess.run(self.merged, feed_dict={self.x: _tr_img_batch,
                                                                self.y: _tr_lab_batch,
                                                                self.is_training: False})

                training_acc = float(training_acc / self.itr_epoch)
                training_loss = float(training_loss / self.itr_epoch)

                validation_acc, validation_loss = eval.validation_procedure(validation_data=self.source_validation_data,
                                                                            distribution_op=self.distribution,
                                                                            loss_op=self.loss, inputX=self.x,
                                                                            inputY=self.y, num_class=self.num_class,
                                                                            batch_size=self.bs,
                                                                            is_training=self.is_training,
                                                                            session=self.sess)

                log = "Epoch: [%d], Training Accuracy: [%g], Validation Accuracy: [%g], Loss Training: [%g], " \
                      "Loss_validation: [%g], Time: [%s]" % \
                      (_current_eps, training_acc, validation_acc, training_loss, validation_loss,
                       time.ctime(time.time()))

                init.save2file(log, self.ckptDir, self.model)

                self.writer.add_summary(summary, _current_eps)

                self.saver.save(self.sess, self.ckptDir + self.model + '-' + str(_current_eps))

                eval.test_procedure(test_data=self.source_test_data, distribution_op=self.distribution, inputX=self.x,
                                    inputY=self.y, mode='source', num_class=self.num_class, batch_size=self.bs,
                                    session=self.sess, is_training=self.is_training, ckptDir=self.ckptDir,
                                    model=self.model)

                eval.test_procedure(test_data=self.target_test_data, distribution_op=self.distribution, inputX=self.x,
                                    inputY=self.y, mode='target', num_class=self.num_class, batch_size=self.bs,
                                    session=self.sess, is_training=self.is_training, ckptDir=self.ckptDir,
                                    model=self.model)

                # Start the next epoch's running sums from zero.
                training_acc = 0.0
                training_loss = 0.0
Example #2
0
 def test(self, reload_path):
     """Restore a checkpoint and run the test procedure on the target test set.

     Args:
         reload_path: checkpoint path handed to ``self.saver.restore``.

     The target test data is evaluated through the source-side graph endpoints
     (``self.distribution_source``, ``self.x_source``, ``self.y_source``).

     NOTE(review): this run is labelled ``mode='source'`` even though the data
     is ``self.target_test_data``. If ``eval.test_procedure`` uses ``mode`` to
     name or tag its output, target results may be reported as source. Confirm
     this is intended.
     """
     session = self.sess
     self.saver.restore(session, reload_path)

     # Everything except the dataset itself is passed by keyword.
     procedure_kwargs = {
         'mode': 'source',
         'num_class': self.num_class,
         'batch_size': self.bs,
         'distribution_op': self.distribution_source,
         'inputX': self.x_source,
         'inputY': self.y_source,
         'is_training': self.is_training,
         'session': session,
         'ckptDir': self.ckptDir,
         'model': self.model,
     }
     eval.test_procedure(self.target_test_data, **procedure_kwargs)
Example #3
0
    def train(self):
        """Train on random augmented batches; log, checkpoint and test every epoch.

        Each of the ``self.eps`` epochs runs ``self.itr_epoch`` steps, with
        ``self.itr_epoch = len(self.training_data[0]) // self.bs``. Every step
        draws ``self.bs`` indices independently and uniformly at random (i.e.
        with replacement), so an epoch is not guaranteed to visit every sample.
        The images are passed through ``data_augmentation`` before the update.

        After every epoch the method:

        * writes ``self.merged`` to ``self.writer``. The summary is evaluated on
          the last batch with ``is_training=False``;
        * appends a log line via ``save2file``;
        * saves a checkpoint to ``self.ckptDir + self.model + '-<epoch>'``;
        * runs ``eval.test_procedure`` on ``self.test_data``.

        Side effects: sets ``self.itr_epoch``, runs
        ``tf.global_variables_initializer()`` in ``self.sess``, and multiplies
        ``self.lr`` by 0.1 at epochs 150 and 225.

        Raises:
            ValueError: if the training set has fewer than ``self.bs`` samples.
                An epoch would then contain no steps, leaving no batch to
                summarise and a zero divisor for the epoch means.
        """
        n_train = len(self.training_data[0])
        self.itr_epoch = n_train // self.bs
        if self.itr_epoch == 0:
            raise ValueError(
                "training set has %d samples, fewer than the batch size (%d); "
                "an epoch would contain no iterations" % (n_train, self.bs))

        self.sess.run(tf.global_variables_initializer())

        # Running sums over the current epoch; averaged and reset each epoch.
        training_acc = 0.0
        training_loss = 0.0

        for e in range(1, self.eps + 1):
            # Step decay of the learning rate.
            # NOTE(review): self.lr appears to be a plain number (it is formatted
            # with %f in the log below) and it is not part of any feed_dict in
            # this method. Unless train_op reads it some other way, this decay
            # only changes the logged value, not the optimiser. Confirm.
            if e in (150, 225):
                self.lr *= 0.1
            for itr in range(self.itr_epoch):
                # Independent uniform draws: sampling is with replacement.
                _index = np.random.randint(low=0, high=len(self.training_data[0]), size=self.bs)
                _tr_img_batch = self.training_data[0][_index]
                _tr_lab_batch = self.training_data[1][_index]

                _tr_img_batch = data_augmentation(_tr_img_batch)

                _train_accuracy, _train_loss, _ = self.sess.run([self.accuracy, self.cost, self.train_op],
                                                                feed_dict={self.x: _tr_img_batch,
                                                                           self.y: _tr_lab_batch,
                                                                           self.is_training: True})
                training_acc += _train_accuracy
                training_loss += _train_loss

            # Summaries are computed on the epoch's last (augmented) batch in inference mode.
            summary = self.sess.run(self.merged, feed_dict={self.x: _tr_img_batch,
                                                            self.y: _tr_lab_batch,
                                                            self.is_training: False})

            training_acc = float(training_acc / self.itr_epoch)
            training_loss = float(training_loss / self.itr_epoch)

            log = "Epoch: [%d], Training Accuracy: [%g], Training Loss: [%g], Learning Rate: [%f], Time: [%s]" % \
                  (e, training_acc, training_loss, self.lr, time.ctime(time.time()))

            save2file(log, self.ckptDir, self.model)

            self.writer.add_summary(summary, e)

            self.saver.save(self.sess, self.ckptDir + self.model + '-' + str(e))

            eval.test_procedure(test_data=self.test_data, distribution_op=self.distribution, inputX=self.x,
                                inputY=self.y, mode='test', num_class=self.num_class, batch_size=self.bs,
                                session=self.sess, is_training=self.is_training, ckptDir=self.ckptDir,
                                model=self.model)

            training_acc = 0.0
            training_loss = 0.0
Example #4
0
    def train(self):
        """Train the source-domain model; validate, plot, log and checkpoint each epoch.

        Each of the ``self.eps`` epochs runs ``self.itr_epoch`` steps, with
        ``self.itr_epoch = len(self.source_training_data[0]) // self.bs``. A step
        proceeds as follows:

        * fetches ``(feed_dict_train, feed_dict_eval)`` from ``self.getBatchData()``;
        * runs ``self.train_op`` on ``feed_dict_train``;
        * evaluates ``self.accuracy_source`` and ``self.loss`` on
          ``feed_dict_eval``, so the metrics are measured after the update.

        After every epoch the method:

        * writes ``self.merged`` (evaluated on the last ``feed_dict_eval``) to
          ``self.writer``;
        * averages the epoch's training metrics;
        * computes validation metrics with ``eval.validation_procedure`` on
          ``self.source_validation_data``;
        * appends to the ``self.plt_*`` history lists and calls
          ``utils.plotAccuracy`` / ``utils.plotLoss`` with
          ``savePath=self.ckptDir``;
        * appends a log line with ``init.save2file``;
        * saves a checkpoint to ``self.ckptDir + self.model + '-<epoch>'``;
        * runs ``eval.test_procedure`` on ``self.source_test_data``.

        Raises:
            ValueError: if the source training set has fewer than ``self.bs``
                samples. An epoch would then contain no steps, leaving no
                ``feed_dict_eval`` to summarise and a zero divisor for the means.
        """
        n_train = len(self.source_training_data[0])
        self.itr_epoch = n_train // self.bs
        if self.itr_epoch == 0:
            raise ValueError(
                "source training set has %d samples, fewer than the batch size "
                "(%d); an epoch would contain no iterations" % (n_train, self.bs))

        self.sess.run(tf.global_variables_initializer())

        # Running sums over the current epoch; averaged and reset each epoch.
        source_training_acc = 0.0
        source_training_loss = 0.0

        for e in range(1, self.eps + 1):
            for itr in range(self.itr_epoch):
                feed_dict_train, feed_dict_eval = self.getBatchData()
                _ = self.sess.run(self.train_op, feed_dict=feed_dict_train)

                # Separate run: metrics reflect the weights after this step's update.
                _training_accuracy, _training_loss = self.sess.run(
                    [self.accuracy_source, self.loss],
                    feed_dict=feed_dict_eval)

                source_training_acc += _training_accuracy
                source_training_loss += _training_loss

            summary = self.sess.run(self.merged, feed_dict=feed_dict_eval)

            source_training_acc = float(source_training_acc / self.itr_epoch)
            source_training_loss = float(source_training_loss / self.itr_epoch)

            source_validation_acc, source_validation_loss = eval.validation_procedure(
                validation_data=self.source_validation_data,
                distribution_op=self.distribution_source,
                loss_op=self.loss,
                inputX=self.x_source,
                inputY=self.y_source,
                num_class=self.num_class,
                batch_size=self.bs,
                is_training=self.is_training,
                session=self.sess)

            log1 = "Epoch: [%d], Training Accuracy: [%g], Validation Accuracy: [%g], Training Loss: [%g], " \
                   "Validation Loss: [%g], Time: [%s]" % (
                e, source_training_acc, source_validation_acc, source_training_loss, source_validation_loss,
                time.ctime(time.time()))

            # History lists feed the cumulative accuracy/loss plots below.
            self.plt_epoch.append(e)
            self.plt_training_accuracy.append(source_training_acc)
            self.plt_training_loss.append(source_training_loss)
            self.plt_validation_accuracy.append(source_validation_acc)
            self.plt_validation_loss.append(source_validation_loss)

            utils.plotAccuracy(x=self.plt_epoch,
                               y1=self.plt_training_accuracy,
                               y2=self.plt_validation_accuracy,
                               figName=self.model,
                               line1Name='training',
                               line2Name='validation',
                               savePath=self.ckptDir)

            utils.plotLoss(x=self.plt_epoch,
                           y1=self.plt_training_loss,
                           y2=self.plt_validation_loss,
                           figName=self.model,
                           line1Name='training',
                           line2Name='validation',
                           savePath=self.ckptDir)

            init.save2file(log1, self.ckptDir, self.model)

            self.writer.add_summary(summary, e)

            self.saver.save(self.sess,
                            self.ckptDir + self.model + '-' + str(e))

            eval.test_procedure(self.source_test_data,
                                distribution_op=self.distribution_source,
                                inputX=self.x_source,
                                inputY=self.y_source,
                                mode='source',
                                batch_size=self.bs,
                                session=self.sess,
                                is_training=self.is_training,
                                ckptDir=self.ckptDir,
                                model=self.model,
                                num_class=self.num_class)

            source_training_acc = 0.0
            source_training_loss = 0.0
Example #5
0
    def train(self):
        """Reload pretrained weights, then alternate D/G updates; log, plot, checkpoint and test each epoch.

        Setup:

        1. Runs ``tf.global_variables_initializer()``.
        2. Restores three variable groups from ``self.reloadPath``:
           ``self.src_reload_var``, ``self.tar_reload_var`` and
           ``self.cla_reload_var``, each through its own ``tf.train.Saver``
           (stored on ``self``).

        Each of the ``self.eps`` epochs runs ``self.itr_epoch`` steps, with
        ``self.itr_epoch = len(self.source_training_data[0]) // self.bs``. A step
        proceeds as follows:

        * fetches ``(feed_dict_train, feed_dict_eval)`` from ``self.getBatchData()``;
        * runs ``self.d_train_op`` and then ``self.g_train_op`` on
          ``feed_dict_train``;
        * evaluates ``self.accuracy_source``, ``self.g_loss`` and ``self.d_loss``
          on ``feed_dict_eval``.

        After every epoch the method:

        * writes ``self.merged`` to ``self.writer``;
        * averages the epoch's metrics and appends them to the ``self.plt_*``
          history lists;
        * calls ``utils.plotAccuracy`` / ``utils.plotLoss``;
        * logs via ``init.save2file``;
        * saves a checkpoint to ``self.ckptDir + self.model + '-<epoch>'``;
        * runs ``eval.test_procedure`` on the source test set (source ops) and
          on the target test set (target ops).

        Raises:
            ValueError: if the source training set has fewer than ``self.bs``
                samples. An epoch would then contain no steps, leaving no
                ``feed_dict_eval`` to summarise and a zero divisor for the means.
                This is checked before any initialisation or checkpoint restore.
        """
        n_train = len(self.source_training_data[0])
        self.itr_epoch = n_train // self.bs
        if self.itr_epoch == 0:
            raise ValueError(
                "source training set has %d samples, fewer than the batch size "
                "(%d); an epoch would contain no iterations" % (n_train, self.bs))

        print('Initialize parameters')
        self.sess.run(tf.global_variables_initializer())
        print('Global variables initialization finished')

        print('Reload parameters')
        # Restore after the global initializer so the restored values are not
        # overwritten by fresh initial values.
        self.src_encoder_reloadSaver = tf.train.Saver(
            var_list=self.src_reload_var)
        self.tar_encoder_reloadSaver = tf.train.Saver(
            var_list=self.tar_reload_var)
        self.classifier_reloadSaver = tf.train.Saver(
            var_list=self.cla_reload_var)

        self.src_encoder_reloadSaver.restore(self.sess, self.reloadPath)
        self.tar_encoder_reloadSaver.restore(self.sess, self.reloadPath)
        self.classifier_reloadSaver.restore(self.sess, self.reloadPath)
        print(
            'source encoder, target encoder and classifier have been successfully reloaded !'
        )

        # Start training.
        # Running sums over the current epoch; averaged and reset each epoch.
        source_training_acc = 0.0
        g_loss = 0.0
        d_loss = 0.0

        for e in range(1, self.eps + 1):
            for itr in range(self.itr_epoch):
                feed_dict_train, feed_dict_eval = self.getBatchData()
                # One D update followed by one G update per step.
                _ = self.sess.run(self.d_train_op, feed_dict=feed_dict_train)
                _ = self.sess.run(self.g_train_op, feed_dict=feed_dict_train)

                _training_accuracy, _g_loss, _d_loss = self.sess.run(
                    [self.accuracy_source, self.g_loss, self.d_loss],
                    feed_dict=feed_dict_eval)

                source_training_acc += _training_accuracy
                g_loss += _g_loss
                d_loss += _d_loss

            summary = self.sess.run(self.merged, feed_dict=feed_dict_eval)

            source_training_acc = float(source_training_acc / self.itr_epoch)
            g_loss = float(g_loss / self.itr_epoch)
            d_loss = float(d_loss / self.itr_epoch)

            log1 = "Epoch: [%d], Training Accuracy: [%g], G Loss: [%g], D Loss: [%g], Time: [%s]" % (
                e, source_training_acc, g_loss, d_loss, time.ctime(
                    time.time()))

            self.plt_epoch.append(e)
            self.plt_training_accuracy.append(source_training_acc)
            self.plt_g_loss.append(g_loss)
            self.plt_d_loss.append(d_loss)

            # Accuracy plot has a single series (no validation curve here).
            utils.plotAccuracy(x=self.plt_epoch,
                               y1=self.plt_training_accuracy,
                               y2=None,
                               figName=self.model,
                               line1Name='training',
                               line2Name='',
                               savePath=self.ckptDir)

            utils.plotLoss(x=self.plt_epoch,
                           y1=self.plt_g_loss,
                           y2=self.plt_d_loss,
                           figName=self.model,
                           line1Name='g loss',
                           line2Name='d loss',
                           savePath=self.ckptDir)

            init.save2file(log1, self.ckptDir, self.model)

            self.writer.add_summary(summary, e)

            self.saver.save(self.sess,
                            self.ckptDir + self.model + '-' + str(e))

            eval.test_procedure(self.source_test_data,
                                distribution_op=self.distribution_source,
                                inputX=self.x_source,
                                inputY=self.y_source,
                                mode='source',
                                num_class=self.num_class,
                                batch_size=self.bs,
                                session=self.sess,
                                is_training=self.is_training,
                                ckptDir=self.ckptDir,
                                model=self.model)
            eval.test_procedure(self.target_test_data,
                                distribution_op=self.distribution_target,
                                inputX=self.x_target,
                                inputY=self.y_target,
                                mode='target',
                                num_class=self.num_class,
                                batch_size=self.bs,
                                session=self.sess,
                                is_training=self.is_training,
                                ckptDir=self.ckptDir,
                                model=self.model)

            source_training_acc = 0.0
            g_loss = 0.0
            d_loss = 0.0