Ejemplo n.º 1
0
def train(ob, max_steps=30001, batch_size=3):
    """Assemble the joint training graph for four segmentation models.

    For every model in ``ob`` the train/validation file lists are loaded and,
    where a model has no input pipeline yet (``images_tr is None``), one is
    created with ``dataset_inputs``. The ``latenv`` outputs of the first four
    models are concatenated along axis 0 and passed through four fully
    connected layers. One ``segloss`` per model plus a ``multiloss`` on the
    last fully connected layer are then handed to ``train_op``.

    The function only builds ops; nothing is run or returned here.

    Note carried over from the original comment: for training the Bayesian
    variant the optimizer flag should be SGD, for the normal SegNet it should
    be ADAM.

    Args:
        ob: Indexable collection of model objects. Indices 0-3 are accessed
            directly, so at least four entries are required.
        max_steps: Not used by the code in this function.
        batch_size: Batch size forwarded to ``dataset_inputs``.
    """
    # The lists start out empty, so they must be grown with append();
    # assigning to an index of an empty list raises IndexError.
    image_filename = []
    label_filename = []
    val_image_filename = []
    val_label_filename = []
    for o in ob:
        tr_images, tr_labels = get_filename_list(o.train_file, o.config)
        image_filename.append(tr_images)
        label_filename.append(tr_labels)
        va_images, va_labels = get_filename_list(o.val_file, o.config)
        val_image_filename.append(va_images)
        val_label_filename.append(va_labels)

    # NOTE(review): as_default() contexts nest, so while the ops below are
    # created the innermost graph (ob[3].graph) is the default one. Confirm
    # that all four models really share a graph, otherwise their tensors
    # cannot be combined.
    with ob[0].graph.as_default(), ob[1].graph.as_default(), \
            ob[2].graph.as_default(), ob[3].graph.as_default():
        # enumerate() keeps the index in step with the model, so each model
        # receives its own file lists.
        for i, o in enumerate(ob):
            if o.images_tr is None:
                o.images_tr, o.labels_tr = dataset_inputs(
                    image_filename[i], label_filename[i], batch_size,
                    o.config)
                o.images_val, o.labels_val = dataset_inputs(
                    val_image_filename[i], val_label_filename[i], batch_size,
                    o.config)

        latent = tf.concat(
            [latenv(ob[0]), latenv(ob[1]), latenv(ob[2]), latenv(ob[3])],
            axis=0)
        fc1 = _fully_connected(latent, 786432, tf.nn.relu)
        fc2 = _fully_connected(fc1, 540000, tf.nn.relu)
        fc3 = _fully_connected(fc2, 270000, tf.nn.relu)
        fc4 = _fully_connected(fc3, 270000, None)

        # Separate loss for each segmentation model, plus one loss for the
        # fully connected head (computed against the labels of ob[0]).
        loss1 = segloss(logits=ob[0].logits, labels=ob[0].labels_pl)
        loss2 = segloss(logits=ob[1].logits, labels=ob[1].labels_pl)
        loss3 = segloss(logits=ob[2].logits, labels=ob[2].labels_pl)
        loss4 = segloss(logits=ob[3].logits, labels=ob[3].labels_pl)
        fcloss = multiloss(fc4, ob[0].labels_pl)
        init = tf.global_variables_initializer()

        # NOTE(review): the original passed seg3=loss3 twice (a SyntaxError)
        # and never used loss4; seg4=loss4 is the presumed intent - confirm
        # against the signature of train_op.
        train1, train2, train3, train4, trainf, global_step = train_op(
            seg1=loss1, seg2=loss2, seg3=loss3, seg4=loss4, lossf=fcloss,
            opt='ADAM')


def _fully_connected(inputs, num_outputs, activation_fn):
    """Add one fully connected layer with the settings shared by all four."""
    return tf.contrib.layers.fully_connected(
        inputs, num_outputs, activation_fn=activation_fn, normalizer_fn=None,
        weights_initializer=initializers.xavier_initializer(),
        biases_initializer=tf.zeros_initializer(), trainable=True)
Ejemplo n.º 2
0
    def train(self, max_steps=30000, batch_size=3):
        """Run the training loop with periodic validation and checkpointing.

        Builds the input pipelines on first use, creates the loss and train
        ops, then runs ``max_steps`` iterations in ``self.sess``. Per-step
        loss/accuracy are appended to ``self.train_loss`` and
        ``self.train_accuracy``. Every 100 steps a summary is written to
        ``self.tb_logs``. Every 1000 steps nine validation batches are
        evaluated, their mean loss/accuracy appended to ``self.val_loss`` /
        ``self.val_acc``, and a checkpoint is saved under ``self.saved_dir``
        with ``self.model_version`` as its step suffix.

        Args:
            max_steps: Number of training iterations to run.
            batch_size: Batch size for the input pipelines; also fed to
                ``self.batch_size_pl``.
        """
        image_filename, label_filename = get_filename_list(
            self.train_file, self.config)
        val_image_filename, val_label_filename = get_filename_list(
            self.val_file, self.config)

        with self.graph.as_default():
            # Only create the input pipelines once per instance.
            if self.images_tr is None:
                self.images_tr, self.labels_tr = dataset_inputs(
                    image_filename, label_filename, batch_size, self.config)
                self.images_val, self.labels_val = dataset_inputs(
                    val_image_filename, val_label_filename, batch_size,
                    self.config)

            loss, accuracy, prediction = cal_loss(logits=self.logits,
                                                  labels=self.labels_pl)
            train, global_step = train_op(total_loss=loss, opt=self.opt)

            summary_op = tf.summary.merge_all()

            with self.sess.as_default():
                self.sess.run(tf.local_variables_initializer())
                self.sess.run(tf.global_variables_initializer())

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(coord=coord)

                train_writer = tf.summary.FileWriter(self.tb_logs,
                                                     self.sess.graph)
                self.saver = tf.train.Saver()

                # try/finally guarantees that the summary writer is closed
                # and the queue-runner threads are stopped even when a step
                # raises.
                try:
                    for step in range(max_steps):
                        image_batch, label_batch = self.sess.run(
                            [self.images_tr, self.labels_tr])
                        feed_dict = {
                            self.inputs_pl: image_batch,
                            self.labels_pl: label_batch,
                            self.batch_size_pl: batch_size
                        }

                        _, _loss, _accuracy, summary = self.sess.run(
                            [train, loss, accuracy, summary_op],
                            feed_dict=feed_dict)
                        self.train_loss.append(_loss)
                        self.train_accuracy.append(_accuracy)
                        print(
                            "Iteration {}: Train Loss{:6.3f}, Train Accu {:6.3f}".
                            format(step, self.train_loss[-1],
                                   self.train_accuracy[-1]))

                        if step % 100 == 0:
                            train_writer.add_summary(summary, step)

                        if step % 1000 == 0:
                            print("start validating..")
                            _val_loss = []
                            _val_acc = []
                            for test_step in range(9):
                                image_batch_val, label_batch_val = self.sess.run(
                                    [self.images_val, self.labels_val])
                                feed_dict_valid = {
                                    self.inputs_pl: image_batch_val,
                                    self.labels_pl: label_batch_val,
                                    self.batch_size_pl: batch_size
                                }
                                # The train op is not fetched here, so the
                                # weights are not updated during validation.
                                # NOTE(review): the original comment also
                                # mentioned feeding a batch-norm training
                                # flag, but no such placeholder is fed in
                                # this method - confirm.
                                _loss, _acc, _val_pred = self.sess.run(
                                    [loss, accuracy, self.logits],
                                    feed_dict_valid)
                                _val_loss.append(_loss)
                                _val_acc.append(_acc)

                            self.val_loss.append(np.mean(_val_loss))
                            self.val_acc.append(np.mean(_val_acc))

                            print("Val Loss {:6.3f}, Val Acc {:6.3f}".format(
                                self.val_loss[-1], self.val_acc[-1]))

                            self.saver.save(self.sess,
                                            os.path.join(self.saved_dir,
                                                         'model.ckpt'),
                                            global_step=self.model_version)
                            self.model_version = self.model_version + 1
                finally:
                    train_writer.close()
                    coord.request_stop()
                    coord.join(threads)
Ejemplo n.º 3
0
    def train(self, max_steps=30001, batch_size=3):
        """Run the training loop with periodic per-class stats and validation.

        Builds the input pipelines on first use, creates the loss and train
        ops, then runs ``max_steps`` iterations in ``self.sess`` with dropout
        enabled (keep probability 0.5). Per-step loss/accuracy are appended
        to ``self.train_loss`` and ``self.train_accuracy``. Every 100 steps
        the per-class accuracy of the current batch is printed and a summary
        is written to ``self.tb_logs``. Every 1000 steps 20 validation
        batches are evaluated with dropout disabled and their mean
        loss/accuracy appended to ``self.val_loss`` / ``self.val_acc``.

        Note carried over from the original comment: for training the
        Bayesian variant the optimizer flag should be SGD, for the normal
        SegNet it should be ADAM.

        Args:
            max_steps: Number of training iterations to run.
            batch_size: Batch size for the input pipelines; also fed to
                ``self.batch_size_pl``.
        """
        image_filename, label_filename = get_filename_list(self.train_file, self.config)
        val_image_filename, val_label_filename = get_filename_list(self.val_file, self.config)

        with self.graph.as_default():
            # Only create the input pipelines once per instance.
            if self.images_tr is None:
                self.images_tr, self.labels_tr = dataset_inputs(image_filename, label_filename, batch_size, self.config)
                self.images_val, self.labels_val = dataset_inputs(val_image_filename, val_label_filename, batch_size,
                                                                  self.config)

            loss, accuracy, prediction = cal_loss(logits=self.logits, labels=self.labels_pl,
                                                  number_class=self.num_classes)
            train, global_step = train_op(total_loss=loss, opt=self.opt)

            summary_op = tf.summary.merge_all()

            with self.sess.as_default():
                self.sess.run(tf.local_variables_initializer())
                self.sess.run(tf.global_variables_initializer())

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(coord=coord)
                # The queue runners basic reference:
                # https://www.tensorflow.org/versions/r0.12/how_tos/threading_and_queues
                train_writer = tf.summary.FileWriter(self.tb_logs, self.sess.graph)

                # try/finally guarantees that the summary writer is closed
                # and the queue-runner threads are stopped even when a step
                # raises.
                try:
                    for step in range(max_steps):
                        image_batch, label_batch = self.sess.run([self.images_tr, self.labels_tr])
                        feed_dict = {self.inputs_pl: image_batch,
                                     self.labels_pl: label_batch,
                                     self.is_training_pl: True,
                                     self.keep_prob_pl: 0.5,
                                     self.with_dropout_pl: True,
                                     self.batch_size_pl: batch_size}

                        _, _loss, _accuracy, summary = self.sess.run([train, loss, accuracy, summary_op],
                                                                     feed_dict=feed_dict)
                        self.train_loss.append(_loss)
                        self.train_accuracy.append(_accuracy)
                        print("Iteration {}: Train Loss{:6.3f}, Train Accu {:6.3f}".format(step, self.train_loss[-1],
                                                                                           self.train_accuracy[-1]))

                        if step % 100 == 0:
                            # Separate forward pass on the same batch to get
                            # the logits for the per-class report.
                            conv_classifier = self.sess.run(self.logits, feed_dict=feed_dict)
                            print('per_class accuracy by logits in training time',
                                  per_class_acc(conv_classifier, label_batch, self.num_classes))
                            # per_class_acc is a function from utils
                            train_writer.add_summary(summary, step)

                        if step % 1000 == 0:
                            print("start validating.......")
                            _val_loss = []
                            _val_acc = []
                            hist = np.zeros((self.num_classes, self.num_classes))
                            for test_step in range(20):
                                fetches_valid = [loss, accuracy, self.logits]
                                image_batch_val, label_batch_val = self.sess.run([self.images_val, self.labels_val])
                                feed_dict_valid = {self.inputs_pl: image_batch_val,
                                                   self.labels_pl: label_batch_val,
                                                   self.is_training_pl: True,
                                                   self.keep_prob_pl: 1.0,
                                                   self.with_dropout_pl: False,
                                                   self.batch_size_pl: batch_size}
                                # Validation still runs on mini-batches, so the
                                # batch-norm training flag stays True; the train
                                # op is not fetched, so the weights are not
                                # updated.
                                _loss, _acc, _val_pred = self.sess.run(fetches_valid, feed_dict_valid)
                                _val_loss.append(_loss)
                                _val_acc.append(_acc)
                                hist += get_hist(_val_pred, label_batch_val)

                            print_hist_summary(hist)

                            self.val_loss.append(np.mean(_val_loss))
                            self.val_acc.append(np.mean(_val_acc))

                            print(
                                "Iteration {}: Train Loss {:6.3f}, Train Acc {:6.3f}, Val Loss {:6.3f}, Val Acc {:6.3f}".format(
                                    step, self.train_loss[-1], self.train_accuracy[-1], self.val_loss[-1],
                                    self.val_acc[-1]))
                finally:
                    train_writer.close()
                    coord.request_stop()
                    coord.join(threads)
Ejemplo n.º 4
0
    def train(self):
        """Train inside a monitored session with periodic validation.

        Builds the input pipelines on first use, creates the loss and train
        ops, prints the number of trainable parameters, and then loops until
        the monitored session asks to stop. Checkpoints and summaries are
        written to ``FLAGS.runtime_dir`` by the session hooks. Whenever the
        global step is a multiple of ``FLAGS.validate_frequency``, 20
        validation batches are evaluated and their average loss/accuracy
        printed.

        NOTE(review): ``StopAtStepHook`` receives ``FLAGS.n_epochs`` as
        ``last_step``, so despite its name the flag is used as a step count -
        confirm this is intended.
        """
        image_filename, label_filename = get_filename_list(self.train_file)
        val_image_filename, val_label_filename = get_filename_list(
            self.val_file)

        # Only create the input pipelines once per instance.
        if self.images_tr is None:
            self.images_tr, self.labels_tr = dataset_inputs(
                image_filename, label_filename, FLAGS.batch_size, self.input_w,
                self.input_h, self.input_c)
            self.images_val, self.labels_val = dataset_inputs(
                val_image_filename, val_label_filename, FLAGS.batch_size,
                self.input_w, self.input_h, self.input_c)

        loss, accuracy, predictions = cal_loss(logits=self.logits,
                                               labels=self.labels_pl,
                                               n_classes=self.n_classes)
        train, global_step = train_op(loss, FLAGS.learning_rate)

        tf.summary.scalar("global_step", global_step)
        tf.summary.scalar("total loss", loss)

        # Calculate total number of trainable parameters
        total_parameters = 0
        for variable in tf.trainable_variables():
            shape = variable.get_shape()
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print('Total Trainable Parameters: ', total_parameters)

        with tf.train.SingularMonitoredSession(
                # save/load model state
                checkpoint_dir=FLAGS.runtime_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.n_epochs),
                    tf.train.CheckpointSaverHook(
                        checkpoint_dir=FLAGS.runtime_dir,
                        save_steps=FLAGS.checkpoint_frequency,
                        saver=tf.train.Saver()),
                    tf.train.SummarySaverHook(
                        save_steps=FLAGS.summary_frequency,
                        output_dir=FLAGS.runtime_dir,
                        scaffold=tf.train.Scaffold(
                            summary_op=tf.summary.merge_all()),
                    )
                ],
                config=tf.ConfigProto(log_device_placement=True)) as mon_sess:

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=mon_sess)

            # The input threads must stay alive for the whole loop, so they
            # are stopped only once, after the loop ends (or raises).
            # Stopping them inside the loop body would shut the input queues
            # down after the first iteration.
            try:
                while not mon_sess.should_stop():

                    # Input batches are fetched through the raw session so
                    # that the hooks only fire on the actual training step.
                    image_batch, label_batch = mon_sess.raw_session().run(
                        [self.images_tr, self.labels_tr])
                    feed_dict = {
                        self.inputs_pl: image_batch,
                        self.labels_pl: label_batch,
                        self.is_training_pl: True,
                        self.keep_prob_pl: 0.5,
                        self.with_dropout_pl: True,
                        self.batch_size_pl: FLAGS.batch_size
                    }

                    step, _, training_loss, training_acc = mon_sess.run(
                        [global_step, train, loss, accuracy],
                        feed_dict=feed_dict)

                    print("Iteration {}: Train Loss{:9.6f}, Train Accu {:9.6f}".
                          format(step, training_loss, training_acc))

                    # Check against validation set
                    if step % FLAGS.validate_frequency == 0:
                        sampled_losses = []
                        sampled_accuracies = []

                        hist = np.zeros((self.n_classes, self.n_classes))

                        fetches_valid = [loss, accuracy, self.logits]
                        for test_step in range(20):
                            image_batch_val, label_batch_val = (
                                mon_sess.raw_session().run(
                                    [self.images_val, self.labels_val]))

                            # Dropout is disabled for validation; the
                            # training flag is still fed as True.
                            feed_dict_valid = {
                                self.inputs_pl: image_batch_val,
                                self.labels_pl: label_batch_val,
                                self.is_training_pl: True,
                                self.keep_prob_pl: 1.0,
                                self.with_dropout_pl: False,
                                self.batch_size_pl: FLAGS.batch_size
                            }

                            validate_loss, validate_acc, val_logits = (
                                mon_sess.raw_session().run(
                                    fetches_valid, feed_dict_valid))
                            sampled_losses.append(validate_loss)
                            sampled_accuracies.append(validate_acc)
                            hist += get_hist(val_logits, label_batch_val)

                        print_hist_summary(hist)

                        # Average loss and accuracy over n samples from validation set
                        avg_loss = np.mean(sampled_losses)
                        avg_acc = np.mean(sampled_accuracies)

                        print(
                            "Iteration {}: Avg Val Loss {:9.6f}, Avg Val Acc {:9.6f}"
                            .format(step, avg_loss, avg_acc))
            finally:
                coord.request_stop()
                coord.join(threads)