コード例 #1
0
    def train(self):
        """Train the source-domain model for ``self.eps`` epochs.

        After initializing all global variables, each epoch runs
        ``self.itr_epoch`` steps, one batch each. The batch comes from
        ``self.getBatchData()``, which returns a training feed dict and an
        evaluation feed dict. Each step:

        * runs ``self.train_op`` on the training feed;
        * measures ``self.accuracy_source`` / ``self.loss`` on the evaluation
          feed;
        * accumulates both values.

        At the end of every epoch it:

        * averages the accumulated training metrics;
        * evaluates on ``self.source_validation_data``;
        * appends the metrics to the ``plt_*`` history lists;
        * redraws the accuracy and loss plots and appends a log line, via
          ``CoGAN_utils``;
        * writes the ``self.merged`` summary for that epoch;
        * saves a checkpoint to ``self.ckptDir + self.model + '-<epoch>'``
          (plain concatenation, so ``ckptDir`` presumably needs a trailing
          separator — confirm);
        * runs the source-domain test procedure.

        Raises:
            ValueError: if at least one epoch is requested but the source
                training set has fewer samples than one batch (``self.bs``),
                so not a single training step could run.
        """
        # Validate before doing any work. With zero steps per epoch the loop
        # below would reference an unbound ``feed_dict_eval`` and divide by
        # zero when averaging the metrics.
        n_samples = len(self.source_training_data[0])
        self.itr_epoch = n_samples // self.bs
        if self.eps >= 1 and self.itr_epoch < 1:
            raise ValueError(
                'source training set has %d samples, fewer than one batch '
                '(batch size %d); cannot train' % (n_samples, self.bs))

        self.sess.run(tf.global_variables_initializer())

        for e in range(1, self.eps + 1):
            acc_sum = 0.0
            loss_sum = 0.0
            for _ in range(self.itr_epoch):
                feed_dict_train, feed_dict_eval = self.getBatchData()
                self.sess.run(self.train_op, feed_dict=feed_dict_train)

                # Metrics are measured on the evaluation feed of the same
                # batch, after the update step above has been applied.
                batch_acc, batch_loss = self.sess.run(
                    [self.accuracy_source, self.loss], feed_dict=feed_dict_eval)
                acc_sum += batch_acc
                loss_sum += batch_loss

            # The summary is computed on the last batch of the epoch only.
            summary = self.sess.run(self.merged, feed_dict=feed_dict_eval)

            source_training_acc = float(acc_sum / self.itr_epoch)
            source_training_loss = float(loss_sum / self.itr_epoch)

            source_validation_acc, source_validation_loss = self.validation_procedure(
                validation_data=self.source_validation_data,
                distribution_op=self.distribution_source,
                loss_op=self.loss,
                inputX=self.x_source,
                inputY=self.y_source)

            log1 = ("Epoch: [%d], Domain: Source, Training Accuracy: [%g], Validation Accuracy: [%g], "
                    "Training Loss: [%g], Validation Loss: [%g], Time: [%s]"
                    % (e, source_training_acc, source_validation_acc,
                       source_training_loss, source_validation_loss,
                       time.ctime()))

            # Per-epoch history, used to redraw the curves from scratch.
            self.plt_epoch.append(e)
            self.plt_training_accuracy.append(source_training_acc)
            self.plt_training_loss.append(source_training_loss)
            self.plt_validation_accuracy.append(source_validation_acc)
            self.plt_validation_loss.append(source_validation_loss)

            CoGAN_utils.plotAccuracy(x=self.plt_epoch,
                                     y1=self.plt_training_accuracy,
                                     y2=self.plt_validation_accuracy,
                                     figName=self.model,
                                     line1Name='training',
                                     line2Name='validation',
                                     savePath=self.ckptDir)

            CoGAN_utils.plotLoss(x=self.plt_epoch,
                                 y1=self.plt_training_loss,
                                 y2=self.plt_validation_loss,
                                 figName=self.model,
                                 line1Name='training',
                                 line2Name='validation',
                                 savePath=self.ckptDir)

            CoGAN_utils.save2file(log1, self.ckptDir, self.model)

            self.writer.add_summary(summary, e)

            self.saver.save(self.sess, self.ckptDir + self.model + '-' + str(e))

            self.test_procedure(self.source_test_data,
                                distribution_op=self.distribution_source,
                                inputX=self.x_source,
                                inputY=self.y_source,
                                mode='source')
コード例 #2
0
    def train(self):
        """Reload pretrained weights, then run adversarial training for ``self.eps`` epochs.

        Setup steps:

        1. Initialize all global variables.
        2. From the checkpoint at ``self.reloadPath``, restore:

           * the source encoder variables (``self.src_var``);
           * the classifier variables (``self.cla_var``);
           * the target encoder variables (``self.tar_var``). These are
             restored under the *source* variable names, so the target
             encoder starts from the source encoder's saved weights.

        Each epoch runs ``self.itr_epoch`` steps, one batch each. The batch
        comes from ``self.getBatchData()``, which returns a training feed
        dict and an evaluation feed dict. Each step:

        * runs ``self.d_train_op`` and then ``self.g_train_op`` on the
          training feed;
        * measures ``self.accuracy_source`` / ``self.supervised_loss`` on
          the evaluation feed.

        At the end of every epoch it:

        * averages the training metrics;
        * evaluates on ``self.source_validation_data``;
        * updates the ``plt_*`` histories, plots and log file (via
          ``CoGAN_utils``);
        * writes the summary and saves a checkpoint to
          ``self.ckptDir + self.model + '-<epoch>'``;
        * runs the test procedure on both the source and the target test
          data.

        Raises:
            ValueError: if at least one epoch is requested but the source
                training set has fewer samples than one batch (``self.bs``),
                so not a single training step could run.
        """
        # Fail fast, before initialization and checkpoint reload. With zero
        # steps per epoch the loop below would reference an unbound
        # ``feed_dict_eval`` and divide by zero when averaging the metrics.
        n_samples = len(self.source_training_data[0])
        self.itr_epoch = n_samples // self.bs
        if self.eps >= 1 and self.itr_epoch < 1:
            raise ValueError(
                'source training set has %d samples, fewer than one batch '
                '(batch size %d); cannot train' % (n_samples, self.bs))

        print('Initialize parameters')
        self.sess.run(tf.global_variables_initializer())
        print('Global variables initialization finished')

        print('Reload parameters')
        # Map checkpoint names (source variable names) to target variables.
        # Names are compared with their leading scope stripped. If a name has
        # no '/', find() returns -1 and the whole name is used. The key drops
        # the trailing 2-char output suffix (presumably ':0').
        # NOTE(review): this is a substring (`in`) match, not equality. One
        # source suffix can match several target variables. When several
        # match, the last matching target variable wins for that key.
        # Confirm that this is intended.
        dict_var = {}
        for src_v in self.src_var:
            src_suffix = src_v.name[src_v.name.find('/') + 1:]
            for tar_v in self.tar_var:
                tar_suffix = tar_v.name[tar_v.name.find('/') + 1:]
                if src_suffix in tar_suffix:
                    dict_var[src_v.name[:-2]] = tar_v

        self.src_encoder_reloadSaver = tf.train.Saver(var_list=self.src_var)
        self.tar_encoder_reloadSaver = tf.train.Saver(var_list=dict_var)
        self.classifier_reloadSaver = tf.train.Saver(var_list=self.cla_var)

        self.src_encoder_reloadSaver.restore(self.sess, self.reloadPath)
        self.tar_encoder_reloadSaver.restore(self.sess, self.reloadPath)
        self.classifier_reloadSaver.restore(self.sess, self.reloadPath)
        print(
            'source encoder, target encoder and classifier have been successfully reloaded !'
        )

        for e in range(1, self.eps + 1):
            acc_sum = 0.0
            loss_sum = 0.0
            for _ in range(self.itr_epoch):
                feed_dict_train, feed_dict_eval = self.getBatchData()
                # Discriminator update first, then generator update, on the
                # same batch.
                self.sess.run(self.d_train_op, feed_dict=feed_dict_train)
                self.sess.run(self.g_train_op, feed_dict=feed_dict_train)

                batch_acc, batch_loss = self.sess.run(
                    [self.accuracy_source, self.supervised_loss],
                    feed_dict=feed_dict_eval)
                acc_sum += batch_acc
                loss_sum += batch_loss

            # The summary is computed on the last batch of the epoch only.
            summary = self.sess.run(self.merged, feed_dict=feed_dict_eval)

            source_training_acc = float(acc_sum / self.itr_epoch)
            source_training_loss = float(loss_sum / self.itr_epoch)

            source_validation_acc, source_validation_loss = self.validation_procedure(
                validation_data=self.source_validation_data,
                distribution_op=self.distribution_source,
                loss_op=self.supervised_loss,
                inputX=self.x_source,
                inputY=self.y_source)

            log1 = ("Epoch: [%d], Domain: Source, Training Accuracy: [%g], Validation Accuracy: [%g], "
                    "Training Loss: [%g], Validation Loss: [%g], Time: [%s]"
                    % (e, source_training_acc, source_validation_acc,
                       source_training_loss, source_validation_loss,
                       time.ctime()))

            # Per-epoch history, used to redraw the curves from scratch.
            self.plt_epoch.append(e)
            self.plt_training_accuracy.append(source_training_acc)
            self.plt_training_loss.append(source_training_loss)
            self.plt_validation_accuracy.append(source_validation_acc)
            self.plt_validation_loss.append(source_validation_loss)

            CoGAN_utils.plotAccuracy(x=self.plt_epoch,
                                     y1=self.plt_training_accuracy,
                                     y2=self.plt_validation_accuracy,
                                     figName=self.model,
                                     line1Name='training',
                                     line2Name='validation',
                                     savePath=self.ckptDir)

            CoGAN_utils.plotLoss(x=self.plt_epoch,
                                 y1=self.plt_training_loss,
                                 y2=self.plt_validation_loss,
                                 figName=self.model,
                                 line1Name='training',
                                 line2Name='validation',
                                 savePath=self.ckptDir)

            CoGAN_utils.save2file(log1, self.ckptDir, self.model)

            self.writer.add_summary(summary, e)

            self.saver.save(self.sess,
                            self.ckptDir + self.model + '-' + str(e))

            self.test_procedure(self.source_test_data,
                                distribution_op=self.distribution_source,
                                inputX=self.x_source,
                                inputY=self.y_source,
                                mode='source')
            self.test_procedure(self.target_test_data,
                                distribution_op=self.distribution_target,
                                inputX=self.x_target,
                                inputY=self.y_target,
                                mode='target')