Пример #1
0
 def __init__(self, **kwargs):
     """Configure the model, pick a TensorFlow session and build the graph.

     Attributes from ``self._defaults`` are applied first; keyword arguments
     then override them. ``self.opt`` selects how the session is obtained:

     * ``OPT.XLA``   - a new ``tf.Session`` with XLA JIT level ``ON_1``.
     * ``OPT.MKL``   - a new ``tf.Session`` with 4 intra-op and 4 inter-op
       threads.
     * ``OPT.DEBUG`` - the session from ``tf.get_session()`` wrapped in a
       ``TensorBoardDebugWrapperSession`` that connects to
       ``self.debug_address`` (pass ``debug_address="host:port"`` to
       override; falls back to ``"fangsixie-Inspiron-7572:6064"``).
     * anything else - the session returned by ``tf.get_session()``.

     For the first three options the session is also registered as the Keras
     backend session. The chosen session is stored on ``self.sess``.

     In graph mode the three values returned by ``self.generate()`` are stored
     on ``self.boxes``, ``self.scores`` and ``self.classes``; in eager mode
     ``self.generate()`` is called and its return value is discarded.
     """
     self.__dict__.update(self._defaults)  # set up default values
     self.__dict__.update(kwargs)  # and update with user overrides
     self.class_names = self._get_class()
     self.anchors = self._get_anchors()
     # NOTE(review): assigned after the kwargs update, so an ``alpha`` passed
     # by the caller is overwritten here - confirm this is intended.
     self.alpha = 1.4
     config = tf.ConfigProto()
     tf.keras.backend.set_learning_phase(0)  # 0 = test/inference phase
     if self.opt in (OPT.XLA, OPT.MKL):
         if self.opt == OPT.XLA:
             # Turn on XLA JIT compilation for the session's graph.
             config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
         else:
             config.intra_op_parallelism_threads = 4
             config.inter_op_parallelism_threads = 4
         sess = tf.Session(config=config)
         tf.keras.backend.set_session(sess)
     elif self.opt == OPT.DEBUG:
         # The workstation address is only a fallback; callers can supply
         # their own tfdbg server via the ``debug_address`` attribute/kwarg.
         debug_address = getattr(
             self, 'debug_address', "fangsixie-Inspiron-7572:6064")
         # NOTE(review): stock TF 1.x exposes this as
         # tf.keras.backend.get_session(); presumably ``tf`` is aliased so
         # that tf.get_session() resolves - confirm.
         sess = tf_debug.TensorBoardDebugWrapperSession(
             tf.get_session(), debug_address)
         tf.keras.backend.set_session(sess)
     else:
         sess = tf.get_session()
     self.sess = sess
     if tf.executing_eagerly():
         self.generate()  # return value intentionally not stored in eager mode
     else:
         self.boxes, self.scores, self.classes = self.generate()
Пример #2
0
def set_value(x, val):
    """
    Set the value of the shared variable *x* to *val*.

    The backend is chosen by the first of ``is_theano()``, ``is_cgt()`` and
    ``is_tf()`` that returns true:

    * Theano: ``x.set_value(val)``
    * CGT: ``x.op.set_value(val)``
    * TensorFlow: runs ``tf.assign(x, val)`` in ``tf.get_session()``

    Raises:
        NotImplementedError: if none of the supported backends is detected.
    """
    if is_theano():
        x.set_value(val)
    elif is_cgt():
        x.op.set_value(val)
    elif is_tf():
        # NOTE(review): in TF1 graph mode tf.assign adds a new op to the graph
        # on every call; if this runs in a loop, consider caching the op.
        tf.get_session().run(tf.assign(x, val))
    else:
        # Fail loudly rather than opening an interactive debugger, which
        # requires ipdb to be installed and blocks non-interactive runs.
        raise NotImplementedError(
            "set_value: no supported backend (theano, cgt, tf) is active")
Пример #3
0
def set_value(x, val):
    """
    Set the value of the shared variable *x* to *val*.

    The backend is chosen by the first of ``is_theano()``, ``is_cgt()`` and
    ``is_tf()`` that returns true:

    * Theano: ``x.set_value(val)``
    * CGT: ``x.op.set_value(val)``
    * TensorFlow: runs ``tf.assign(x, val)`` in ``tf.get_session()``

    Raises:
        NotImplementedError: if none of the supported backends is detected.
    """
    if is_theano():
        x.set_value(val)
    elif is_cgt():
        x.op.set_value(val)
    elif is_tf():
        # NOTE(review): in TF1 graph mode tf.assign adds a new op to the graph
        # on every call; if this runs in a loop, consider caching the op.
        tf.get_session().run(tf.assign(x, val))
    else:
        # Fail loudly rather than opening an interactive debugger, which
        # requires ipdb to be installed and blocks non-interactive runs.
        raise NotImplementedError(
            "set_value: no supported backend (theano, cgt, tf) is active")
Пример #4
0
    def __init__(self, **kwargs):
        """Configure the model for ``self.backbone`` and build its graph.

        Attributes from ``self._defaults`` are applied first; keyword
        arguments then override them. ``self.model_config[self.backbone]``
        must provide ``'classes_path'``, ``'anchors_path'`` and
        ``'input_size'``; these populate ``self.class_names``,
        ``self.anchors`` and ``self.input_shape``.

        ``self.opt`` selects how the TensorFlow session is obtained:

        * ``OPT.XLA``   - a new ``tf.Session`` with XLA JIT level ``ON_1``.
        * ``OPT.MKL``   - a new ``tf.Session`` with 4 intra-op and 4 inter-op
          threads.
        * ``OPT.DEBUG`` - DEBUG logging plus a new session with device
          placement logging, wrapped in a ``TensorBoardDebugWrapperSession``
          connecting to ``localhost:6064``.
        * anything else - the session returned by ``tf.get_session()``.

        For the first three options the session is also registered as the
        Keras backend session. The chosen session is stored on ``self.sess``,
        then ``self.generate()`` is called and its return value discarded.
        """
        self.__dict__.update(self._defaults)  # set up default values
        self.__dict__.update(kwargs)  # and update with user overrides
        backbone_config = self.model_config[self.backbone]
        self.class_names = get_classes(backbone_config['classes_path'])
        self.anchors = get_anchors(backbone_config['anchors_path'])
        self.input_shape = backbone_config['input_size']
        config = tf.ConfigProto()

        if self.opt == OPT.XLA:
            # Turn on XLA JIT compilation for the session's graph.
            config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
            sess = tf.Session(config=config)
            tf.keras.backend.set_session(sess)
        elif self.opt == OPT.MKL:
            config.intra_op_parallelism_threads = 4
            config.inter_op_parallelism_threads = 4
            sess = tf.Session(config=config)
            tf.keras.backend.set_session(sess)
        elif self.opt == OPT.DEBUG:
            tf.logging.set_verbosity(tf.logging.DEBUG)
            # Uses its own ConfigProto (device placement logging) rather than
            # the ``config`` built above.
            sess = tf_debug.TensorBoardDebugWrapperSession(
                tf.Session(config=tf.ConfigProto(log_device_placement=True)),
                "localhost:6064")
            tf.keras.backend.set_session(sess)
        else:
            sess = tf.get_session()
        self.sess = sess
        self.generate()
Пример #5
0
    print(X_test.shape)
    print(ecal_train.shape)
    print(ecal_test.shape)
    print('*************************************************************************************')
    train_history = defaultdict(list)
    test_history = defaultdict(list)

    # Broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    #callbacks = [
    #    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
    #]

    bcast_op = hvd.broadcast_global_variables(0)
    tf.get_session().run(bcast_op)

    for epoch in range(nb_epochs):
        print('Epoch {} of {}'.format(epoch + 1, nb_epochs))

        nb_batches = int(X_train.shape[0] / batch_size)
        if verbose:
            progress_bar = Progbar(target=nb_batches)

        epoch_gen_loss = []
        epoch_disc_loss = []
        for index in range(nb_batches):
            if verbose:
                progress_bar.update(index)
            else:
                if index % 100 == 0: