Пример #1
0
    def inference(self, X, dirr, train_status=False, Not_Realtest=True):
        """Run the trained model over every batch of ``X`` and write results.

        Batches come from a queue-runner pipeline built by ``X.next_batch``.
        Each batch is fed through ``self.sesh`` (fetching ``self.pred``,
        ``self.net`` and ``self.w``).

        * ``Not_Realtest=True`` (labelled test set): ``ac.CAM`` is called for
          every batch, then the pooled predictions and labels go to
          ``ac.metrics``.
        * ``Not_Realtest=False`` (unlabelled data): only the pooled
          predictions are written, via ``ac.realout``.

        Args:
            X: input object whose ``next_batch()`` returns
                ``(x_tensor, y_tensor, count)``, or ``(x_tensor, count)``
                when called with ``Not_Realtest=False``.
            dirr: output directory forwarded to the ``ac`` helpers.
            train_status: value fed to ``self.is_train``.
            Not_Realtest: True when ``X`` provides labels.
        """
        now = datetime.now().isoformat()[11:]
        print("------- Testing begin: {} -------\n".format(now), flush=True)
        if Not_Realtest:
            x_list, y_list, _ = X.next_batch()
        else:
            x_list, _ = X.next_batch(Not_Realtest=False)
            y_list = None

        # Collect per-batch outputs and concatenate once at the end; growing
        # an array with np.concatenate on every batch is quadratic.
        pred_batches = []
        label_batches = []
        with tf.Session() as sessb:
            # Initialize all global and local variables
            init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
            sessb.run(init_op)
            # Create a coordinator and run all QueueRunner objects
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord, sess=sessb)
            try:
                rd = 0
                while True:
                    try:
                        if Not_Realtest:
                            x, y = sessb.run([x_list, y_list])
                        else:
                            x, = sessb.run([x_list])
                            y = None
                    except tf.errors.OutOfRangeError:
                        # Input pipeline exhausted: every batch has been seen.
                        break
                    x = x.astype(np.uint8)

                    feed_dict = {self.x_in: x, self.is_train: train_status}
                    fetches = [self.pred, self.net, self.w]
                    pred, net, w = self.sesh.run(fetches, feed_dict)
                    if Not_Realtest:
                        ac.CAM(net, w, pred, x, y, dirr, 'Test', rd)
                        # Labels only exist for the labelled test set.
                        label_batches.append(y)
                    pred_batches.append(pred)
                    rd += 1
            finally:
                # Stop the queue-runner threads on every exit path, not only
                # when the input is exhausted.
                coord.request_stop()
                coord.join(threads)

        pdx = np.concatenate(pred_batches, axis=0) if pred_batches else []
        if Not_Realtest:
            # axis=None flattens the labels, as the original per-batch
            # accumulation did.
            yl = np.concatenate(label_batches, axis=None) if label_batches else []
            ac.metrics(pdx, yl, dirr, 'Test')
        else:
            ac.realout(pdx, dirr, 'Test')
        now = datetime.now().isoformat()[11:]
        print("------- Testing end: {} -------\n".format(now), flush=True)
Пример #2
0
    def inference(self, X, dirr, train_status=False, Not_Realtest=True):
        """Run the trained model over a ``tf.data`` iterator and write results.

        ``X.data()`` returns ``(iterator, feed_value, placeholder)``. The
        iterator is initialised by feeding ``feed_value`` into
        ``placeholder`` and is consumed until ``OutOfRangeError``. Each batch
        is fed through ``self.sesh`` (fetching ``self.pred``, ``self.net``
        and ``self.w``).

        * ``Not_Realtest=True``: each element is an ``(x, y)`` pair.
          ``ac.CAM`` is called per batch, then the pooled predictions and
          labels go to ``ac.metrics``.
        * ``Not_Realtest=False``: each element is ``x`` only. The pooled
          predictions go to ``ac.realout``.

        Args:
            X: input object exposing ``data(...)`` as described above.
            dirr: output directory forwarded to the ``ac`` helpers.
            train_status: value fed to ``self.is_train``.
            Not_Realtest: True when ``X`` provides labels.
        """
        now = datetime.now().isoformat()[11:]
        print("------- Testing begin: {} -------\n".format(now), flush=True)

        # The labelled and unlabelled paths differ only in how the iterator
        # is built and whether an element carries a label.
        if Not_Realtest:
            itr, data_src, ph = X.data()
        else:
            itr, data_src, ph = X.data(Not_Realtest=False)
        next_element = itr.get_next()

        # Collect per-batch outputs and concatenate once at the end instead of
        # re-concatenating on every batch (quadratic).
        pred_batches = []
        label_batches = []
        with tf.Session() as sessa:
            sessa.run(itr.initializer, feed_dict={ph: data_src})
            rd = 0
            while True:
                try:
                    element = sessa.run(next_element)
                except tf.errors.OutOfRangeError:
                    # Iterator exhausted: every batch has been seen.
                    break
                if Not_Realtest:
                    x, y = element
                else:
                    x, y = element, None

                feed_dict = {self.x_in: x, self.is_train: train_status}
                fetches = [self.pred, self.net, self.w]
                pred, net, w = self.sesh.run(fetches, feed_dict)
                if Not_Realtest:
                    ac.CAM(net, w, pred, x, y, dirr, 'Test', rd)
                    label_batches.append(y)
                # else: ac.CAM_R(net, w, pred, x, dirr, 'Test', rd)
                pred_batches.append(pred)
                rd += 1

        pdx = np.concatenate(pred_batches, axis=0) if pred_batches else []
        if Not_Realtest:
            # axis=None flattens the labels, matching the original accumulation.
            yl = np.concatenate(label_batches, axis=None) if label_batches else []
            ac.metrics(pdx, yl, dirr, 'Test')
        else:
            ac.realout(pdx, dirr, 'Test')

        now = datetime.now().isoformat()[11:]
        print("------- Testing end: {} -------\n".format(now), flush=True)
Пример #3
0
    def train(self,
              X,
              dirr,
              max_iter=np.inf,
              max_epochs=np.inf,
              cross_validate=True,
              verbose=True,
              save=True,
              outdir="./out"):
        """Train on batches from a queue-runner input pipeline.

        Each step takes one batch from ``X.next_batch()`` (cast to uint8) and
        runs ``self.train_op`` in ``self.sesh``, logging the training summary.
        Training continues until the input pipeline raises ``OutOfRangeError``.

        ``max_iter`` does not stop training. It only selects the step
        (``i == max_iter - int(i / 1000) - 2``) at which a final validation
        pass with ``ac.CAM`` / ``ac.metrics`` runs.

        When training ends normally, or on ``KeyboardInterrupt``, the method:

        * prints the average cost;
        * saves a checkpoint ``inceptionres1_dropout_<dropout>`` under
          ``outdir`` (when ``save``);
        * flushes and closes the summary writers.

        After a ``KeyboardInterrupt`` the process then exits via
        ``sys.exit(0)``.

        Args:
            X: input object whose ``next_batch()`` returns
                ``(x_tensor, y_tensor, num_examples)``.
            dirr: output directory forwarded to ``ac.CAM`` / ``ac.metrics``.
            max_iter: step used to trigger the final validation pass.
            max_epochs: accepted for interface compatibility; not used.
            cross_validate: run validation passes (requires ``verbose``).
            verbose: print progress every 1000 steps.
            save: write a checkpoint when training ends.
            outdir: checkpoint directory.
        """
        if save:
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

        # Bound before the ``try`` so the KeyboardInterrupt handler can report
        # progress even if the interrupt arrives before the first step
        # (previously this raised NameError and skipped the save).
        err_train = 0
        i = 0
        nums = 0

        def _finish():
            # Print the final summary, save the model and close the loggers.
            epoch = np.around(i / nums * self.batch_size) if nums else 0
            avg_cost = err_train / i if i else float('nan')
            print("final avg cost (@ step {} = epoch {}): {}".format(
                i, epoch, avg_cost),
                  flush=True)

            now = datetime.now().isoformat()[11:]
            print("------- Training end: {} -------\n".format(now), flush=True)

            if save:
                outfile = os.path.join(
                    os.path.abspath(outdir),
                    "inceptionres1_{}".format("_".join(
                        ['dropout', str(self.dropout)])))
                saver.save(self.sesh, outfile, global_step=None)
            try:
                self.train_logger.flush()
                self.train_logger.close()
                self.valid_logger.flush()
                self.valid_logger.close()
            except AttributeError:  # not logging
                print('Not logging', flush=True)

        try:
            now = datetime.now().isoformat()[11:]
            print("------- Training begin: {} -------\n".format(now),
                  flush=True)

            x_list, y_list, nums = X.next_batch()
            with tf.Session() as sessa:
                # Initialize all global and local variables
                init_op = tf.group(tf.global_variables_initializer(),
                                   tf.local_variables_initializer())
                sessa.run(init_op)
                # Create a coordinator and run all QueueRunner objects
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(coord=coord, sess=sessa)
                try:
                    while True:
                        try:
                            x, y = sessa.run([x_list, y_list])
                            x = x.astype(np.uint8)

                            feed_dict = {
                                self.x_in: x,
                                self.y_in: y,
                                self.dropout_: self.dropout
                            }
                            fetches = [
                                self.merged_summary, self.logits, self.pred,
                                self.pred_cost, self.global_step, self.train_op
                            ]
                            summary, _, _, cost, i, _ = self.sesh.run(
                                fetches, feed_dict)

                            self.train_logger.add_summary(summary, i)
                            err_train += cost

                            if verbose and (i % 1000 == 0 or i == max_iter):
                                print("round {} --> cost: ".format(i),
                                      cost,
                                      flush=True)

                            if verbose and cross_validate and i % 1000 == 0:
                                # NOTE(review): this batch comes from the
                                # training pipeline (x_list/y_list), so it
                                # consumes training data and is not held out.
                                xv, yv = sessa.run([x_list, y_list])
                                xv = xv.astype(np.uint8)

                                feed_dict = {self.x_in: xv, self.y_in: yv}
                                fetches = [self.pred_cost, self.merged_summary]
                                valid_cost, valid_summary = self.sesh.run(
                                    fetches, feed_dict)

                                self.valid_logger.add_summary(valid_summary, i)
                                print("round {} --> CV cost: ".format(i),
                                      valid_cost,
                                      flush=True)

                            if (verbose and cross_validate
                                    and i == max_iter - int(i / 1000) - 2):
                                now = datetime.now().isoformat()[11:]
                                print("------- Validation begin: {} -------\n".
                                      format(now),
                                      flush=True)
                                xv, yv = sessa.run([x_list, y_list])
                                xv = xv.astype(np.uint8)

                                feed_dict = {self.x_in: xv, self.y_in: yv}
                                fetches = [
                                    self.pred_cost, self.merged_summary,
                                    self.pred, self.net, self.w
                                ]
                                valid_cost, valid_summary, pred, net, w = self.sesh.run(
                                    fetches, feed_dict)

                                self.valid_logger.add_summary(valid_summary, i)
                                print("round {} --> Last CV cost: ".format(i),
                                      valid_cost,
                                      flush=True)
                                ac.CAM(net, w, pred, xv, yv, dirr,
                                       'Validation')
                                ac.metrics(pred, yv, dirr, 'Validation')
                                now = datetime.now().isoformat()[11:]
                                print("------- Validation end: {} -------\n".
                                      format(now),
                                      flush=True)

                        except tf.errors.OutOfRangeError:
                            # Input pipeline exhausted: training is done.
                            break
                finally:
                    # Stop the queue-runner threads on every exit path
                    # (exhaustion, KeyboardInterrupt or any other error).
                    coord.request_stop()
                    coord.join(threads)

            _finish()

        except KeyboardInterrupt:
            _finish()
            sys.exit(0)
Пример #4
0
    def train(self,
              X,
              ct,
              bs,
              dirr,
              max_iter=np.inf,
              cross_validate=True,
              verbose=True,
              save=True,
              outdir="./out"):
        """Train on batches from a ``tf.data`` iterator.

        ``X.data()`` returns ``(iterator, feed_value, placeholder)``. Each
        ``(x, y)`` element runs ``self.train_op`` in ``self.sesh`` and its
        training summary is logged. Training continues until the iterator
        raises ``OutOfRangeError``.

        ``max_iter`` does not stop training. It only selects the step
        (``i == max_iter - int(i / 1000) - 2``) at which a final validation
        pass with ``ac.CAM`` / ``ac.metrics`` runs.

        When training ends normally, or on ``KeyboardInterrupt``, the method:

        * prints the average cost (epoch reported as ``i / ct * bs``);
        * saves ``inceptionres2_dropout_<dropout>`` under ``outdir`` (when
          ``save``);
        * flushes and closes the summary writers.

        After a ``KeyboardInterrupt`` the process then exits via
        ``sys.exit(0)``.

        Args:
            X: input object exposing ``data()`` as described above.
            ct: number of training examples (used for the epoch estimate).
            bs: batch size (used for the epoch estimate).
            dirr: output directory forwarded to ``ac.CAM`` / ``ac.metrics``.
            max_iter: step used to trigger the final validation pass.
            cross_validate: run validation passes (requires ``verbose``).
            verbose: print progress every 1000 steps.
            save: write a checkpoint when training ends.
            outdir: checkpoint directory.
        """
        start_time = time.time()
        if save:
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

        # Bound before the ``try`` so the KeyboardInterrupt handler can report
        # progress even if the interrupt arrives before the first step
        # (previously this raised NameError and skipped the save).
        err_train = 0
        i = 0

        def _finish(shown_step):
            # Print the final summary, save the model and close the loggers.
            avg_cost = err_train / i if i else float('nan')
            print("final avg cost (@ step {} = epoch {}): {}".format(
                shown_step, np.around(i / ct * bs), avg_cost),
                  flush=True)

            now = datetime.now().isoformat()[11:]
            print("------- Training end: {} -------\n".format(now), flush=True)

            if save:
                outfile = os.path.join(
                    os.path.abspath(outdir),
                    "inceptionres2_{}".format("_".join(
                        ['dropout', str(self.dropout)])))
                saver.save(self.sesh, outfile, global_step=None)
            try:
                self.train_logger.flush()
                self.train_logger.close()
                self.valid_logger.flush()
                self.valid_logger.close()
            except AttributeError:  # not logging
                print('Not logging', flush=True)

        try:
            now = datetime.now().isoformat()[11:]
            print("------- Training begin: {} -------\n".format(now),
                  flush=True)
            itr, data_src, ph = X.data()
            next_element = itr.get_next()
            with tf.Session() as sessa:
                sessa.run(itr.initializer, feed_dict={ph: data_src})
                while True:
                    try:
                        x, y = sessa.run(next_element)

                        feed_dict = {
                            self.x_in: x,
                            self.y_in: y,
                            self.dropout_: self.dropout
                        }
                        fetches = [
                            self.merged_summary, self.logits, self.pred,
                            self.pred_cost, self.global_step, self.train_op
                        ]
                        summary, _, _, cost, i, _ = self.sesh.run(
                            fetches, feed_dict)

                        self.train_logger.add_summary(summary, i)
                        err_train += cost

                        if i % 1000 == 0 and verbose:
                            print("round {} --> cost: ".format(i),
                                  cost,
                                  flush=True)

                            if cross_validate:
                                # NOTE(review): this batch comes from the
                                # training iterator, so it consumes training
                                # data and is not held out.
                                xv, yv = sessa.run(next_element)

                                feed_dict = {self.x_in: xv, self.y_in: yv}
                                fetches = [self.pred_cost, self.merged_summary]
                                valid_cost, valid_summary = self.sesh.run(
                                    fetches, feed_dict)

                                self.valid_logger.add_summary(valid_summary, i)
                                print("round {} --> CV cost: ".format(i),
                                      valid_cost,
                                      flush=True)

                        if (verbose and cross_validate
                                and i == max_iter - int(i / 1000) - 2):
                            now = datetime.now().isoformat()[11:]
                            print("------- Validation begin: {} -------\n".
                                  format(now),
                                  flush=True)
                            xv, yv = sessa.run(next_element)

                            feed_dict = {self.x_in: xv, self.y_in: yv}
                            fetches = [
                                self.pred_cost, self.merged_summary,
                                self.pred, self.net, self.w
                            ]
                            valid_cost, valid_summary, pred, net, w = self.sesh.run(
                                fetches, feed_dict)

                            self.valid_logger.add_summary(valid_summary, i)
                            print("round {} --> Last CV cost: ".format(i),
                                  valid_cost,
                                  flush=True)
                            ac.CAM(net, w, pred, xv, yv, dirr, 'Validation')
                            ac.metrics(pred, yv, dirr, 'Validation')
                            now = datetime.now().isoformat()[11:]
                            print("------- Validation end: {} -------\n".
                                  format(now),
                                  flush=True)

                    except tf.errors.OutOfRangeError:
                        # Iterator exhausted: training is done.
                        break

            _finish(i + 1)
            print("--- %s seconds ---" % (time.time() - start_time))

        except KeyboardInterrupt:
            _finish(i)
            sys.exit(0)
Пример #5
0
    def train(self, X, VAX, ct, bs, dirr, pmd, max_iter=np.inf, cross_validate=True, verbose=True, save=True, outdir="./out"):
        """Train with best-model checkpointing and validation-based early stopping.

        Training batches come from ``X.data()`` and validation batches from
        ``VAX.data(train=False)``. Both return
        ``(iterator, feed_value, placeholder)``.

        A checkpoint ``<self.model>_dropout_<dropout>`` is written under
        ``outdir`` (when ``save``) whenever a new best model is found:

        * after step 29999, when the training cost is <= the smallest cost
          recorded so far and (with ``cross_validate``) the mean loss over 10
          validation batches is <= the best validation loss so far;
        * every 1000 steps (with ``verbose`` and ``cross_validate``), when
          the mean loss over 100 validation batches is <= the best so far.

        After step 79999 the 1000-step check also applies early stopping.
        Training stops when the current mean validation loss exceeds the mean
        of the previous (up to nine) recorded validation losses.

        Training ends when any of these happens:

        * the training data is exhausted;
        * the global step reaches ``max_iter - 2``;
        * early stopping triggers.

        The end-of-training sequence then runs once:

        * print the average loss;
        * save the last model if the most recent checkpoint was written
          before step 15000 (or never);
        * with ``cross_validate``, run one validation batch through
          ``ac.CAM`` / ``ac.metrics``;
        * close the summary writers.

        On ``KeyboardInterrupt`` the model is saved, the writers are closed,
        and the process exits via ``sys.exit(0)``.

        Args:
            X: training input object exposing ``data()``.
            VAX: validation input object exposing ``data(train=False)``.
            ct: number of training examples (used for the epoch estimate).
            bs: batch size (epoch estimate; forwarded to ``ac.CAM``).
            dirr: output directory forwarded to ``ac.CAM`` / ``ac.metrics``.
            pmd: forwarded unchanged to ``ac.CAM`` / ``ac.metrics``.
            max_iter: training stops once the global step reaches
                ``max_iter - 2``.
            cross_validate: enable validation passes.
            verbose: print progress and run the 1000-step validation check.
            save: write checkpoints.
            outdir: checkpoint directory.
        """
        start_time = time.time()
        svs = 0  # global step at which the last best-model checkpoint was saved
        if save:
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

        # Bound before the ``try`` so progress can be reported even if the run
        # is interrupted before the first step (previously a NameError).
        err_train = 0
        i = 0

        def _checkpoint_path():
            return os.path.join(os.path.abspath(outdir),
                                "{}_{}".format(self.model, "_".join(['dropout', str(self.dropout)])))

        def _report_progress(shown_step):
            # Print the average training loss and the end-of-training banner.
            avg_loss = err_train / i if i else float('nan')
            print("final avg loss (@ step {} = epoch {}): {}".format(
                shown_step, np.around(i / ct * bs), avg_loss), flush=True)
            now = datetime.now().isoformat()[11:]
            print("------- Training end: {} -------\n".format(now), flush=True)

        def _close_loggers():
            try:
                self.train_logger.flush()
                self.train_logger.close()
                self.valid_logger.flush()
                self.valid_logger.close()
            except AttributeError:  # not logging
                print('Not logging', flush=True)

        def _mean_validation_loss(sess, batch_op, n_batches, step):
            # Average self.pred_cost over n_batches validation batches and log
            # each batch's summary at ``step``.
            losses = []
            for _ in range(n_batches):
                xv, yv = sess.run(batch_op)
                feed = {self.x_in: xv, self.y_in: yv, self.is_train: False}
                loss, summary = self.sesh.run([self.pred_cost, self.merged_summary], feed)
                self.valid_logger.add_summary(summary, step)
                losses.append(loss)
            return np.mean(losses)

        try:
            now = datetime.now().isoformat()[11:]
            print("------- Training begin: {} -------\n".format(now), flush=True)
            itr, train_src, ph = X.data()
            next_element = itr.get_next()

            vaitr, vafile, vaph = VAX.data(train=False)
            vanext_element = vaitr.get_next()

            with tf.Session() as sessa:
                sessa.run(itr.initializer, feed_dict={ph: train_src})
                sessa.run(vaitr.initializer, feed_dict={vaph: vafile})
                train_cost = []
                validation_cost = []
                while True:
                    try:
                        x, y = sessa.run(next_element)

                        feed_dict = {self.x_in: x, self.y_in: y}
                        fetches = [self.merged_summary, self.logits, self.pred,
                                   self.pred_cost, self.global_step, self.train_op]
                        summary, _, _, cost, i, _ = self.sesh.run(fetches, feed_dict)

                        self.train_logger.add_summary(summary, i)
                        err_train += cost

                        if i < 2:
                            train_cost.append(cost)
                        mintrain = min(train_cost) if train_cost else 0

                        if cost <= mintrain and i > 29999:
                            if cross_validate:
                                mean_valid = _mean_validation_loss(sessa, vanext_element, 10, i)
                                # With no previous validation loss, any loss is
                                # an improvement (the old default of 0 meant the
                                # first one never was).
                                best_valid = min(validation_cost, default=float('inf'))

                                if mean_valid <= best_valid:
                                    train_cost.append(cost)
                                    print("round {} --> loss: ".format(i), cost, flush=True)
                                    print("round {} --> validation loss: ".format(i), mean_valid, flush=True)
                                    print("New Min loss model found!")
                                    validation_cost.append(mean_valid)
                                    if save:
                                        saver.save(self.sesh, _checkpoint_path(), global_step=None)
                                        svs = i
                            else:
                                train_cost.append(cost)
                                print("round {} --> loss: ".format(i), cost, flush=True)
                                print("New Min loss model found!")
                                if save:
                                    saver.save(self.sesh, _checkpoint_path(), global_step=None)
                                    svs = i
                        else:
                            train_cost.append(cost)

                        if i % 1000 == 0 and verbose:
                            print("round {} --> loss: ".format(i), cost, flush=True)
                            if cross_validate:
                                step_valid = _mean_validation_loss(sessa, vanext_element, 100, i)
                                best_valid = min(validation_cost, default=float('inf'))
                                validation_cost.append(step_valid)
                                print("round {} --> Step Average validation loss: ".format(i), step_valid, flush=True)

                                if save and step_valid <= best_valid:
                                    print("New Min loss model found!")
                                    print("round {} --> loss: ".format(i), cost, flush=True)
                                    saver.save(self.sesh, _checkpoint_path(), global_step=None)
                                    svs = i

                                if i > 79999:
                                    valid_mean_cost = np.mean(validation_cost[-10:-1])
                                    print('Mean validation loss: {}'.format(valid_mean_cost))
                                    # Compare this step's averaged validation loss
                                    # (not the loss of a single batch) with the
                                    # mean of the previous step averages.
                                    if step_valid > valid_mean_cost:
                                        print("Early stopped! No improvement for at least 10000 iterations")
                                        break
                                    print("Passed early stopping evaluation. Continue training!")

                        # Stopping at max_iter no longer depends on ``verbose``.
                        if i >= max_iter - 2:
                            break

                    except tf.errors.OutOfRangeError:
                        # NOTE(review): also raised if the validation iterator
                        # runs out during a validation pass, which ends training
                        # too — confirm VAX.data(train=False) repeats.
                        break

                # Single end-of-training sequence shared by every stop
                # condition (previously parts of it ran twice and wrote to
                # already-closed loggers).
                _report_progress(i + 1)

                if svs < 15000 and save:
                    print("Save the last model as the best model.")
                    saver.save(self.sesh, _checkpoint_path(), global_step=None)

                if cross_validate:
                    try:
                        now = datetime.now().isoformat()[11:]
                        print("------- Validation begin: {} -------\n".format(now), flush=True)
                        xv, yv = sessa.run(vanext_element)
                        feed_dict = {self.x_in: xv, self.y_in: yv, self.is_train: False}
                        fetches = [self.pred_cost, self.merged_summary, self.pred, self.net, self.w]
                        valid_cost, valid_summary, pred, net, w = self.sesh.run(fetches, feed_dict)

                        self.valid_logger.add_summary(valid_summary, i)
                        print("round {} --> Last validation loss: ".format(i), valid_cost, flush=True)
                        ac.CAM(net, w, pred, xv, yv, dirr, 'Validation', bs, pmd)
                        ac.metrics(pred, yv, dirr, 'Validation', pmd)
                        now = datetime.now().isoformat()[11:]
                        print("------- Validation end: {} -------\n".format(now), flush=True)
                    except tf.errors.OutOfRangeError:
                        print('No more validation needed!')

                # Always close the writers, even when the final validation
                # batch was unavailable.
                _close_loggers()

            print("--- %s seconds ---" % (time.time() - start_time))

        except KeyboardInterrupt:
            _report_progress(i)
            if save:
                saver.save(self.sesh, _checkpoint_path(), global_step=None)
            _close_loggers()
            sys.exit(0)