Code example #1
File: utils.py  Project: 812864539/models
def variables_to_restore(scope=None, strip_scope=False):
  """Returns a list of variables to restore for the specified list of methods.

  It is supposed that variable name starts with the method's scope (a prefix
  returned by _method_scope function).

  Args:
    methods_names: a list of names of configurable methods.
    strip_scope: if True will return variable names without method's scope.
      If methods_names is None will return names unchanged.
    model_scope: a scope for a whole model.

  Returns:
    a dictionary mapping variable names to variables for restore.
  """
  if scope:
    variable_map = {}
    method_variables = slim.get_variables_to_restore(include=[scope])
    for var in method_variables:
      if strip_scope:
        var_name = var.op.name[len(scope) + 1:]
      else:
        var_name = var.op.name
      variable_map[var_name] = var

    return variable_map
  else:
    return {v.op.name: v for v in slim.get_variables_to_restore()}
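A minimal usage sketch (not from the original project): the returned name-to-variable map can be passed straight to tf.train.Saver, so checkpoint entries are looked up by the (possibly scope-stripped) names. The scope name and checkpoint path below are placeholders.

def restore_method_variables(sess, checkpoint_path):
  # Hypothetical helper: restore only the variables under the "my_method"
  # scope, with the prefix stripped so names match the checkpoint keys.
  variable_map = variables_to_restore(scope='my_method', strip_scope=True)
  restorer = tf.train.Saver(var_list=variable_map)
  restorer.restore(sess, checkpoint_path)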
Code example #2
File: nav_utils.py  Project: 812864539/models
def get_repr_from_image(images_reshaped, modalities, data_augment, encoder,
                        freeze_conv, wt_decay, is_training):
  # Pass image through lots of convolutional layers, to obtain pool5
  if modalities == ['rgb']:
    with tf.name_scope('pre_rgb'):
      x = (images_reshaped + 128.) / 255. # Convert to brightness between 0 and 1.
      if data_augment.relight and is_training:
        x = tf_utils.distort_image(x, fast_mode=data_augment.relight_fast)
      x = (x-0.5)*2.0
    scope_name = encoder
  elif modalities == ['depth']:
    with tf.name_scope('pre_d'):
      d_image = images_reshaped
      x = 2*(d_image[...,0] - 80.0)/100.0
      y = d_image[...,1]
      d_image = tf.concat([tf.expand_dims(x, -1), tf.expand_dims(y, -1)], 3)
      x = d_image
    scope_name = 'd_'+encoder

  resnet_is_training = is_training and (not freeze_conv)
  with slim.arg_scope(resnet_v2.resnet_utils.resnet_arg_scope(resnet_is_training)):
    fn = getattr(tf_utils, encoder)
    x, end_points = fn(x, num_classes=None, global_pool=False,
                       output_stride=None, reuse=None,
                       scope=scope_name)
  vars_ = slim.get_variables_to_restore()

  conv_feat = x
  return conv_feat, vars_
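A hedged follow-up sketch: the variable list returned alongside the pooled features could be restored from a slim ResNet checkpoint. The checkpoint path is a placeholder, and the image input pipeline is assumed to be wired up already.

init_fn = slim.assign_from_checkpoint_fn(
    '/path/to/resnet_v2_50.ckpt',      # placeholder checkpoint path
    vars_, ignore_missing_vars=True)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  init_fn(sess)
  features = sess.run(conv_feat)       # assumes the image tensors are fed or queued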
Code example #3
def _create_image_encoder(preprocess_fn, factory_fn, image_shape, batch_size=32,
                         session=None, checkpoint_path=None,
                         loss_mode="cosine"):
    image_var = tf.placeholder(tf.uint8, (None, ) + image_shape)

    preprocessed_image_var = tf.map_fn(
        lambda x: preprocess_fn(x, is_training=False),
        tf.cast(image_var, tf.float32))

    l2_normalize = loss_mode == "cosine"
    feature_var, _ = factory_fn(
        preprocessed_image_var, l2_normalize=l2_normalize, reuse=None)
    feature_dim = feature_var.get_shape().as_list()[-1]

    if session is None:
        session = tf.Session()
    if checkpoint_path is not None:
        slim.get_or_create_global_step()
        init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
            checkpoint_path, slim.get_variables_to_restore())
        session.run(init_assign_op, feed_dict=init_feed_dict)

    def encoder(data_x):
        out = np.zeros((len(data_x), feature_dim), np.float32)
        _run_in_batches(
            lambda x: session.run(feature_var, feed_dict=x),
            {image_var: data_x}, out, batch_size)
        return out

    return encoder
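A short usage sketch, assuming preprocess_fn and factory_fn are the preprocessing function and network factory used elsewhere in this project; the image shape and checkpoint path are placeholders.

encoder = _create_image_encoder(
    preprocess_fn, factory_fn, image_shape=(128, 64, 3),
    batch_size=32, checkpoint_path='/path/to/model.ckpt')
crops = np.random.randint(0, 255, size=(10, 128, 64, 3), dtype=np.uint8)  # dummy crops
features = encoder(crops)  # (10, feature_dim) float32 array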
Code example #4
File: model.py  Project: ALISCIFP/models
def build_inceptionv3_graph(images, endpoint, is_training, checkpoint,
                            reuse=False):
  """Builds an InceptionV3 model graph.

  Args:
    images: A 4-D float32 `Tensor` of batch images.
    endpoint: String, name of the InceptionV3 endpoint.
    is_training: Boolean, whether or not to build a training or inference graph.
    checkpoint: String, path to the pretrained model checkpoint.
    reuse: Boolean, whether or not we are reusing the embedder.
  Returns:
    inception_output: `Tensor` holding the InceptionV3 output.
    inception_variables: List of inception variables.
    init_fn: Function to initialize the weights (if not reusing, then None).
  """
  with slim.arg_scope(inception.inception_v3_arg_scope()):
    _, endpoints = inception.inception_v3(
        images, num_classes=1001, is_training=is_training)
    inception_output = endpoints[endpoint]
    inception_variables = slim.get_variables_to_restore()
    inception_variables = [
        i for i in inception_variables if 'global_step' not in i.name]
    if is_training and not reuse:
      init_saver = tf.train.Saver(inception_variables)
      def init_fn(scaffold, sess):
        del scaffold
        init_saver.restore(sess, checkpoint)
    else:
      init_fn = None
    return inception_output, inception_variables, init_fn
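The returned init_fn has the (scaffold, session) signature that tf.train.Scaffold expects, so one plausible wiring looks like the sketch below; the endpoint name, images tensor, and checkpoint path are placeholders.

output, variables, init_fn = build_inceptionv3_graph(
    images, 'Mixed_7c', is_training=True,
    checkpoint='/path/to/inception_v3.ckpt', reuse=False)
scaffold = tf.train.Scaffold(init_fn=init_fn)
with tf.train.MonitoredTrainingSession(scaffold=scaffold) as sess:
  pool_features = sess.run(output)  # assumes `images` comes from an input pipeline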
Code example #5
File: mobilenet_v2.py  Project: Sanster/tf_ctpn
    def _image_to_head(self, is_training, reuse=None):
        with slim.arg_scope(mobilenet_v2.training_scope(is_training=is_training)):
            net, endpoints = mobilenet_v2.mobilenet_base(self._image, conv_defs=CTPN_DEF)

        self.variables_to_restore = slim.get_variables_to_restore()

        self._act_summaries.append(net)
        self._layers['head'] = net

        return net
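A hedged sketch of how the captured list might be used afterwards, where network stands for the object whose _image_to_head() populated variables_to_restore; the MobileNet-V2 checkpoint path is a placeholder.

restorer = tf.train.Saver(network.variables_to_restore)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    restorer.restore(sess, '/path/to/mobilenet_v2.ckpt')  # placeholder path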
Code example #6
File: utils.py  Project: beacandler/tf-slim-demo
def variabels_to_restore(scope=None, strip_scope=False):
    """Returns a list of variabels to restore for the specified list method.

    It is supposed that variable name starts with the method's scope (a prefix
    returned by _method_scope function.)

    Args:
        scope: a scope for a whole model.
        strip_scope: If True will return variable names without method's scope.
    """
    if scope:
        variable_map = {}
        variables_to_restore = slim.get_variables_to_restore(include=[scope])
        for var in variables_to_restore:
            if strip_scope:
                var_name = var.op.name[len(scope) + 1:]
            else:
                var_name = var.op.name
            variable_map[var_name] = var
        return variable_map
    else:
        return {var.op.name: var for var in slim.get_variables_to_restore()}
Code example #7
    def evaluate(self,
            eval_config):
            """Runs Evaluation ops on examples

               Args:
                eval_config: protobuf config for evaluation
            """
            if not isinstance(eval_config, eval_pb2.EvalConfig):
                raise ValueError('eval_config not of type '
                                 'eval_pb2.EvalConfig.')

            init_local = tf.local_variables_initializer()

            ckpt = tf.train.get_checkpoint_state(
                eval_config.checkpoint_to_restore)
            if not ckpt or not ckpt.model_checkpoint_path:
                raise ValueError("checkpoint to restore does not exist \
                                    or is not a valid checkpoint")

            print("TENSORFLOW INFO: Restoring from %s" % ckpt.model_checkpoint_path)
            with tf.Session() as sess:
                init_local.run()
                exclude_list = ["eval/total_loss"]
                vars_to_restore = slim.get_variables_to_restore(exclude = exclude_list)
                saver = tf.train.Saver(var_list = vars_to_restore) #TODO maybe allow var_list option?
                saver.restore(sess,ckpt.model_checkpoint_path)

                #create and start Queue runners
                for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
                    self._threads.extend(
                                        qr.create_threads(sess = sess,
                                                          coord = self._coord,
                                                          daemon = True,
                                                          start = True))
                try:

                    while not self._coord.should_stop():
                        results = sess.run(self._eval_ops)
                        print(self._eval_fmt_str % tuple(results))

                        if not self._is_batch_evaluation:
                            break

                except Exception as e:
                    print('TENSORFLOW INFO: Evaluation failed with ',e)

                finally:
                    self._coord.request_stop()
                    self._coord.join(self._threads,
                                     stop_grace_period_secs = 10)
Code example #8
File: demo_detect.py  Project: happog/yolo-tf
def main():
    model = config.get('config', 'model')
    cachedir = utils.get_cachedir(config)
    with open(os.path.join(cachedir, 'names'), 'r') as f:
        names = [line.strip() for line in f]
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    yolo = importlib.import_module('model.' + model)
    cell_width, cell_height = utils.calc_cell_width_height(config, width, height)
    tf.logging.info('(width, height)=(%d, %d), (cell_width, cell_height)=(%d, %d)' % (width, height, cell_width, cell_height))
    with tf.Session() as sess:
        paths = [os.path.join(cachedir, profile + '.tfrecord') for profile in args.profile]
        num_examples = sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in paths)
        tf.logging.warn('num_examples=%d' % num_examples)
        image_rgb, labels = utils.data.load_image_labels(paths, len(names), width, height, cell_width, cell_height, config)
        image_std = tf.image.per_image_standardization(image_rgb)
        image_rgb = tf.cast(image_rgb, tf.uint8)
        ph_image = tf.placeholder(image_std.dtype, [1] + image_std.get_shape().as_list(), name='ph_image')
        global_step = tf.contrib.framework.get_or_create_global_step()
        builder = yolo.Builder(args, config)
        builder(ph_image)
        variables_to_restore = slim.get_variables_to_restore()
        ph_labels = [tf.placeholder(l.dtype, [1] + l.get_shape().as_list(), name='ph_' + l.op.name) for l in labels]
        with tf.name_scope('total_loss') as name:
            builder.create_objectives(ph_labels)
            total_loss = tf.losses.get_total_loss(name=name)
        tf.global_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        _image_rgb, _image_std, _labels = sess.run([image_rgb, image_std, labels])
        coord.request_stop()
        coord.join(threads)
        feed_dict = dict([(ph, np.expand_dims(d, 0)) for ph, d in zip(ph_labels, _labels)])
        feed_dict[ph_image] = np.expand_dims(_image_std, 0)
        logdir = utils.get_logdir(config)
        assert os.path.exists(logdir)
        model_path = tf.train.latest_checkpoint(logdir)
        tf.logging.info('load ' + model_path)
        slim.assign_from_checkpoint_fn(model_path, variables_to_restore)(sess)
        tf.logging.info('global_step=%d' % sess.run(global_step))
        tf.logging.info('total_loss=%f' % sess.run(total_loss, feed_dict))
        _ = Drawer(sess, names, builder.model.cell_width, builder.model.cell_height, _image_rgb, _labels, builder.model, feed_dict)
        plt.show()
Code example #9
File: model.py  Project: ALISCIFP/models
  def build_pretrained_graph(
      self, images, resnet_layer, checkpoint, is_training, reuse=False):
    """See baseclass."""
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
      _, endpoints = resnet_v2.resnet_v2_50(
          images, is_training=is_training, reuse=reuse)
      resnet_layer = 'resnet_v2_50/block%d' % resnet_layer
      resnet_output = endpoints[resnet_layer]
      resnet_variables = slim.get_variables_to_restore()
      resnet_variables = [
          i for i in resnet_variables if 'global_step' not in i.name]
      if is_training and not reuse:
        init_saver = tf.train.Saver(resnet_variables)
        def init_fn(scaffold, sess):
          del scaffold
          init_saver.restore(sess, checkpoint)
      else:
        init_fn = None

      return resnet_output, resnet_variables, init_fn
Code example #10
def main():
    args = parse_args()

    with tf.Session(graph=tf.Graph()) as session:
        input_var = tf.placeholder(
            tf.uint8, (None, 128, 64, 3), name="images")
        image_var = tf.map_fn(
            lambda x: _preprocess(x), tf.cast(input_var, tf.float32),
            back_prop=False)

        factory_fn = _network_factory()
        features, _ = factory_fn(image_var, reuse=None)
        features = tf.identity(features, name="features")

        saver = tf.train.Saver(slim.get_variables_to_restore())
        saver.restore(session, args.checkpoint_in)

        output_graph_def = tf.graph_util.convert_variables_to_constants(
            session, tf.get_default_graph().as_graph_def(),
            [features.name.split(":")[0]])
        with tf.gfile.GFile(args.graphdef_out, "wb") as file_handle:
            file_handle.write(output_graph_def.SerializeToString())
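A brief sketch of reading the frozen graph back for inference; the file name stands in for args.graphdef_out, and the tensor names follow the "images" placeholder and "features" identity defined above.

with tf.gfile.GFile('/path/to/frozen_graph.pb', 'rb') as f:   # placeholder path
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='net')
    images = graph.get_tensor_by_name('net/images:0')
    features = graph.get_tensor_by_name('net/features:0')
    with tf.Session(graph=graph) as session:
        batch = np.zeros((4, 128, 64, 3), dtype=np.uint8)     # dummy input batch
        print(session.run(features, feed_dict={images: batch}).shape)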
Code example #11
File: train.py  Project: EkaterinaPogodina/gossipnet
def train(resume, visualize):
    np.random.seed(cfg.random_seed)
    dataset, train_imdb = get_dataset()
    do_val = len(cfg.train.val_imdb) > 0

    class_weights = class_equal_weights(train_imdb)
    (preloaded_batch, enqueue_op, enqueue_placeholders,
     q_size) = setup_preloading(
            Gnet.get_batch_spec(train_imdb['num_classes']))
    reg = tf.contrib.layers.l2_regularizer(cfg.train.weight_decay)
    net = Gnet(num_classes=train_imdb['num_classes'], batch=preloaded_batch,
               weight_reg=reg, class_weights=class_weights)
    lr_gen = LearningRate()
    # reg_ops = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    # reg_op = tf.reduce_mean(reg_ops)
    # optimized_loss = net.loss + reg_op
    optimized_loss = tf.contrib.losses.get_total_loss()
    learning_rate, train_op = get_optimizer(
        optimized_loss, net.trainable_variables)

    val_net = val_imdb = None
    if do_val:
        val_imdb = imdb.get_imdb(cfg.train.val_imdb, is_training=False)
        val_net = Gnet(num_classes=val_imdb['num_classes'], reuse=True)

    with tf.name_scope('summaries'):
        tf.summary.scalar('loss', optimized_loss)
        tf.summary.scalar('data_loss', net.loss)
        tf.summary.scalar('data_loss_normed', net.loss_normed)
        tf.summary.scalar('data_loss_unnormed', net.loss_unnormed)
        tf.summary.scalar('lr', learning_rate)
        tf.summary.scalar('q_size', q_size)
        if cfg.train.histograms:
            tf.summary.histogram('roi_feats', net.roifeats)
            tf.summary.histogram('det_imfeats', net.det_imfeats)
            tf.summary.histogram('pw_feats', net.pw_feats)
            for i, blockout in enumerate(net.block_feats):
                tf.summary.histogram('block{:02d}'.format(i + 1),
                                     blockout)
        merge_summaries_op = tf.summary.merge_all()

    with tf.name_scope('averaging'):
        ema = tf.train.ExponentialMovingAverage(decay=0.7)
        maintain_averages_op = ema.apply(
            [net.loss_normed, net.loss_unnormed, optimized_loss])
        # update moving averages after every loss evaluation
        with tf.control_dependencies([train_op]):
            train_op = tf.group(maintain_averages_op)
        smoothed_loss_normed = ema.average(net.loss_normed)
        smoothed_loss_unnormed = ema.average(net.loss_unnormed)
        smoothed_optimized_loss = ema.average(optimized_loss)

    restorer = ckpt = None
    if resume:
        ckpt = tf.train.get_checkpoint_state('./')
        restorer = tf.train.Saver()
    elif cfg.gnet.imfeats:
        variables_to_restore = slim.get_variables_to_restore(
            include=["resnet_v1"])
        variables_to_exclude = \
            slim.get_variables_by_suffix('Adam_1', scope='resnet_v1') + \
            slim.get_variables_by_suffix('Adam', scope='resnet_v1') + \
            slim.get_variables_by_suffix('Momentum', scope='resnet_v1')
        restorer = tf.train.Saver(
            list(set(variables_to_restore) - set(variables_to_exclude)))

    saver = tf.train.Saver(max_to_keep=None)
    model_manager = ModelManager()
    config = tf.ConfigProto()
    with tf.Session(config=config) as sess:
        train_writer = tf.summary.FileWriter(cfg.log_dir, sess.graph)
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        coord = start_preloading(
            sess, enqueue_op, dataset, enqueue_placeholders)

        start_iter = 1
        if resume:
            restorer.restore(sess, ckpt.model_checkpoint_path)
            tensor = tf.get_default_graph().get_tensor_by_name("global_step:0")
            start_iter = sess.run(tensor + 1)
        elif cfg.gnet.imfeats:
            restorer.restore(sess, cfg.train.pretrained_model)

        for it in range(start_iter, cfg.train.num_iter + 1):
            if coord.should_stop():
                break

            if visualize:
                # don't do actual training, just visualize data
                visualize_detections(sess, it, learning_rate, lr_gen, net,
                                     train_op)
                continue

            (_, val_total_loss, val_loss_normed, val_loss_unnormed,
             summary) = sess.run(
                [train_op, smoothed_optimized_loss, smoothed_loss_normed,
                 smoothed_loss_unnormed, merge_summaries_op],
                feed_dict={learning_rate: lr_gen.get_lr(it)})
            train_writer.add_summary(summary, it)

            if it % cfg.train.display_iter == 0:
                print(('{}  iter {:6d}   lr {:8g}   opt loss {:8g}     '
                       'data loss normalized {:8g}   '
                       'unnormalized {:8g}').format(
                    datetime.now(), it, lr_gen.get_lr(it), val_total_loss,
                    val_loss_normed, val_loss_unnormed))

            if do_val and it % cfg.train.val_iter == 0:
                print('{}  starting validation'.format(datetime.now()))
                val_map, mc_ap, pc_ap = val_run(sess, val_net, val_imdb)
                print(('{}  iter {:6d}   validation pass:   mAP {:5.1f}   '
                       'multiclass AP {:5.1f}').format(
                      datetime.now(), it, val_map, mc_ap))

                #save_path = saver.save(sess, net.name, global_step=it)
                #print('wrote model to {}'.format(save_path))
                # dump_debug_info(sess, net, it)
                #model_manager.add(it, val_map, save_path)
                #model_manager.print_summary()
                #model_manager.write_link_to_best('./gnet_best')

            #elif it % cfg.train.save_iter == 0 or it == cfg.train.num_iter:
            #    save_path = saver.save(sess, net.name, global_step=it)
            #    print('wrote model to {}'.format(save_path))
            #    # dump_debug_info(sess, net, it)

        coord.request_stop()
        coord.join()
    print('training finished')
    if do_val:
        print('summary of validation performance')
        model_manager.print_summary()
Code example #12
if not tf.gfile.Exists(VALIDATION_LOG_DIR):
    tf.gfile.MakeDirs(VALIDATION_LOG_DIR)

batch_train_set, batch_validation_set, images_num = data_processing.pre_processing(data_set='train', batch_size=BATCH_SIZE)

def get_accuracy(logits, labels):
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    return accuracy

with tf.Graph().as_default():
    train_images = tf.placeholder(tf.float32, [BATCH_SIZE, 224, 224, 3])
    train_labels = tf.placeholder(tf.float32, [BATCH_SIZE, len(data_processing.IMG_CLASSES)])
    logits, _ = nets.vgg.vgg_16(inputs=train_images, num_classes=2, is_training=True)
    variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
    restorer = tf.train.Saver(variables_to_restore)

    train_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=train_labels))
    validation_loss = train_loss
    learning_rate = 1e-4
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(train_loss)

    train_accuracy = get_accuracy(logits, train_labels)
    validation_accuracy = train_accuracy
    saver = tf.train.Saver()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
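The snippet stops right after random initialization; a minimal continuation inside the same with tf.Session(...) block, assuming a VGG-16 checkpoint at a placeholder path, would load the pretrained weights before the training loop:

        # Hypothetical continuation: restore pretrained VGG-16 weights,
        # leaving the freshly initialized vgg_16/fc8 classifier untouched.
        restorer.restore(sess, '/path/to/vgg_16.ckpt')  # placeholder checkpoint path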
Code example #13
def model(inputs_image,
          output_grids,
          is_training_mode,
          dropout_keep_prob,
          reuse=False):
    def bottleneck(inputs,
                   out_channels,
                   t=1,
                   stride=1,
                   scope=None,
                   reuse=False):
        net = inputs

        with tf.variable_scope(scope or 'bottleneck', reuse=reuse):
            net = slim.batch_norm(net,
                                  is_training=is_training_mode,
                                  scope='bn-1',
                                  reuse=reuse)
            net = slim.conv2d(net,
                              int(net.shape[-1]) * t, [1, 1],
                              scope='conv-1',
                              reuse=reuse)

            net_ = net
            with tf.variable_scope('depthwise-conv-1', reuse=reuse):
                weights = tf.get_variable(
                    'weights', [3, 1, int(net.shape[-1]), 1],
                    tf.float32,
                    initializer=tf.contrib.layers.xavier_initializer(),
                    trainable=True)
                biases = tf.get_variable('biases', [int(net.shape[-1])],
                                         tf.float32,
                                         initializer=tf.zeros_initializer(),
                                         trainable=True)
                net = tf.nn.depthwise_conv2d(net, weights, [1, 1, 1, 1],
                                             'SAME')
                net = tf.nn.bias_add(net, biases)
                net = tf.nn.relu(net)
            with tf.variable_scope('depthwise-conv-2', reuse=reuse):
                weights = tf.get_variable(
                    'weights', [1, 3, int(net.shape[-1]), 1],
                    tf.float32,
                    initializer=tf.contrib.layers.xavier_initializer(),
                    trainable=True)
                biases = tf.get_variable('biases', [int(net.shape[-1])],
                                         tf.float32,
                                         initializer=tf.zeros_initializer(),
                                         trainable=True)
                net = tf.nn.depthwise_conv2d(net, weights,
                                             [1, stride, stride, 1], 'SAME')
                net = tf.nn.bias_add(net, biases)

                if [int(i) for i in net.shape[1:]
                    ] != [int(i) for i in net_.shape[1:]]:
                    net = tf.nn.relu(net)

            if [int(i) for i in net.shape[1:]
                ] == [int(i) for i in net_.shape[1:]]:
                net = tf.add(net, net_)
                net = tf.nn.relu(net)

            net = slim.conv2d(net,
                              out_channels, [1, 1],
                              scope='conv-out',
                              activation_fn=None,
                              reuse=reuse)
            if [int(i) for i in net.shape[1:]
                ] == [int(i) for i in inputs.shape[1:]]:
                net = tf.add(net, inputs)

        return net

    if not reuse:
        _exclude_params = [v.name for v in slim.get_variables_to_restore()]

    net = tf.image.convert_image_dtype(inputs_image, dtype=tf.float32)
    #net = tf.subtract(net, 0.5)
    #net = tf.multiply(net, 2.0)

    outputs_dict = dict()
    with tf.variable_scope('face-net', reuse=reuse):
        net = slim.batch_norm(net,
                              is_training=is_training_mode,
                              scope='bn-input',
                              reuse=reuse)

        with tf.variable_scope('block-input', reuse=reuse):
            net = slim.conv2d(net, 32, [3, 1], scope='conv-1', reuse=reuse)
            net = slim.conv2d(net,
                              32, [1, 3],
                              stride=2,
                              scope='conv-2',
                              reuse=reuse)
            outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net

        net = bottleneck(net,
                         16,
                         t=1,
                         stride=1,
                         scope='bottleneck-1-1',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net

        net = bottleneck(net,
                         24,
                         t=6,
                         stride=1,
                         scope='bottleneck-2-1',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net
        net = bottleneck(net,
                         24,
                         t=6,
                         stride=2,
                         scope='bottleneck-2-2',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net

        net = bottleneck(net,
                         32,
                         t=6,
                         stride=1,
                         scope='bottleneck-3-1',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net
        net = bottleneck(net,
                         32,
                         t=6,
                         stride=1,
                         scope='bottleneck-3-2',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net
        net = bottleneck(net,
                         32,
                         t=6,
                         stride=2,
                         scope='bottleneck-3-3',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net

        net = bottleneck(net,
                         64,
                         t=6,
                         stride=1,
                         scope='bottleneck-4-1',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net
        net = bottleneck(net,
                         64,
                         t=6,
                         stride=1,
                         scope='bottleneck-4-2',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net
        net = bottleneck(net,
                         64,
                         t=6,
                         stride=1,
                         scope='bottleneck-4-3',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net
        net = bottleneck(net,
                         64,
                         t=6,
                         stride=2,
                         scope='bottleneck-4-4',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net

        net = bottleneck(net,
                         96,
                         t=6,
                         stride=1,
                         scope='bottleneck-5-1',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net
        net = bottleneck(net,
                         96,
                         t=6,
                         stride=1,
                         scope='bottleneck-5-2',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net
        net = bottleneck(net,
                         96,
                         t=6,
                         stride=2,
                         scope='bottleneck-5-3',
                         reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net

        # net = bottleneck(net, 128, t=6, stride=1, scope='bottleneck-6-1', reuse=reuse)
        # net = bottleneck(net, 128, t=6, stride=1, scope='bottleneck-6-2', reuse=reuse)
        # net = bottleneck(net, 128, t=6, stride=2, scope='bottleneck-6-3', reuse=reuse)

        net = bottleneck(net, 96 * 2, t=6, stride=1, scope='bottleneck-6-1')
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net

        net = slim.conv2d(net, 96 * 2 * 4, [1, 1], scope='conv-1', reuse=reuse)
        outputs_dict[(int(net.shape[1]), int(net.shape[2]))] = net

        # classifier
        # avg_pool
        # conv2d 1x1xK
        if not reuse:
            global __model_params
            __model_params = [
                var for var in slim.get_variables_to_restore()
                if var.name not in _exclude_params
            ]

        return [outputs_dict[tuple(grid)] for grid in output_grids]
Code example #14
def train():

    seed = 8964
    tf.set_random_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if not os.path.exists(opt.checkpoint_dir):
        os.makedirs(opt.checkpoint_dir)

    with tf.Graph().as_default():
        global_step = tf.Variable(0, name='global_step', trainable=False)
        incr_global_step = tf.assign(global_step, global_step + 1)
        #optim = tf.train.AdamOptimizer(opt.learning_rate, 0.9)
        lr = tf.placeholder(tf.float32, shape=[], name="lr")

        optim = tf.train.AdamOptimizer(lr, 0.9)

        loader = DataLoader(opt)

        losses = []

        img_losses = []
        rigid_warp_losses = []
        disp_smooth_losses = []

        sem_losses = []
        sem_warp_losses = []
        sem_mask_losses = []
        sem_edge_losses = []

        sem_seg_losses = []
        ins0_seg_losses = []
        ins1_edge_seg_losses = []

        ins_losses = []

        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(opt.num_gpus):
                with tf.device('/gpu:{:d}'.format(i)):
                    with tf.name_scope('gpu{:d}'.format(i)):
                        """
                        Get images batch from data loader
                        'tgt_image' : target (middle frame)
                        'src_image_stack' : consists of 2 source images
                        'intrinsics' : camera intrinsic data
                        'tgt_sem_tuple' : semantic target data
                        'src_sem_stack_tuple' : semantic source data
                        'tgt_ins_tuple' : instance target data
                        'src_ins_stack_tuple' : instance mask source data
                        """
                        tgt_image, src_image_stack, intrinsics, tgt_sem_tuple, src_sem_stack_tuple, tgt_ins_tuple, src_ins_stack_tuple = loader.load_train_batch(
                        )

                        # Build Model
                        model = SIGNetModel(opt, tgt_image, src_image_stack,
                                            intrinsics, tgt_sem_tuple,
                                            src_sem_stack_tuple, tgt_ins_tuple,
                                            src_ins_stack_tuple)

                        # Handle losses
                        losses.append(model.total_loss)
                        tf.get_variable_scope().reuse_variables()

                        img_losses.append(model.img_loss)
                        rigid_warp_losses.append(model.rigid_warp_loss)
                        disp_smooth_losses.append(model.disp_smooth_loss)
                        if opt.sem_as_loss:
                            sem_losses.append(model.sem_loss)
                            if opt.sem_warp_explore:
                                sem_warp_losses.append(model.sem_warp_loss)
                            if opt.sem_mask_explore:
                                sem_mask_losses.append(model.sem_mask_loss)
                            if opt.sem_edge_explore:
                                sem_edge_losses.append(model.sem_edge_loss)
                        if opt.ins_as_loss:
                            ins_losses.append(model.ins_loss)

                        if opt.sem_assist and opt.add_segnet:
                            sem_seg_losses.append(model.sem_seg_loss)
                            ins0_seg_losses.append(model.ins0_seg_loss)
                            ins1_edge_seg_losses.append(
                                model.ins1_edge_seg_loss)

                        #TODO tensorboard
                        tf.summary.image('tgt_image_g%02d' % (i),
                                         tgt_image,
                                         max_outputs=opt.max_outputs)
                        tf.summary.image('src_image_prev_g%02d' % (i),
                                         src_image_stack[:, :, :, :3],
                                         max_outputs=opt.max_outputs)
                        tf.summary.image('src_image_next_g%02d' % (i),
                                         src_image_stack[:, :, :, 3:],
                                         max_outputs=opt.max_outputs)
                        tf.summary.scalar('loss_g%02d' % (i), model.total_loss)
                        tf.summary.scalar('img_loss_g%02d' % (i),
                                          model.img_loss)
                        tf.summary.scalar('rigid_warp_loss_g%02d' % (i),
                                          model.rigid_warp_loss)
                        tf.summary.scalar('disp_smooth_loss_g%02d' % (i),
                                          model.disp_smooth_loss)

                        if opt.sem_as_loss:
                            tf.summary.scalar('sem_loss_g%02d' % (i),
                                              model.sem_loss)
                            if opt.sem_warp_explore:
                                tf.summary.scalar('sem_warp_loss_g%02d' % (i),
                                                  model.sem_warp_loss)
                        if opt.ins_as_loss:
                            tf.summary.scalar('ins_loss_g%02d' % (i),
                                              model.ins_loss)

                        if opt.sem_assist and opt.add_segnet:
                            tf.summary.scalar('sem_seg_loss_g%02d' % (i),
                                              model.sem_seg_loss)
                            tf.summary.scalar('ins0_seg_loss_g%02d' % (i),
                                              model.ins0_seg_loss)
                            tf.summary.scalar('ins1_edge_seg_loss_g%02d' % (i),
                                              model.ins1_edge_seg_loss)

                        #TODO Add bookkeeping ops
                        if i == 0:
                            # Train Op
                            if opt.mode == 'train_flow' and opt.flownet_type == "residual":
                                train_vars = tf.get_collection(
                                    tf.GraphKeys.TRAINABLE_VARIABLES,
                                    "flow_net")
                            else:
                                #TODO try to enable a solution to fix posenet weight in first stage
                                if opt.mode == 'train_rigid' and opt.fixed_posenet:
                                    if opt.new_sem_dispnet:
                                        train_vars = tf.get_collection(
                                            tf.GraphKeys.TRAINABLE_VARIABLES,
                                            "depth_sem_net")
                                    else:
                                        train_vars = tf.get_collection(
                                            tf.GraphKeys.TRAINABLE_VARIABLES,
                                            "depth_net")
                                else:
                                    train_vars = [
                                        var
                                        for var in tf.trainable_variables()
                                    ]

                            loading_net = ["depth_net", "pose_net"]

                            if opt.new_sem_dispnet:
                                loading_net.append("depth_sem_net")
                            if opt.new_sem_posenet:
                                loading_net.append("pose_sem_net")

                            vars_to_restore = slim.get_variables_to_restore(
                                include=loading_net)

                            if opt.init_ckpt_file != None:
                                init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
                                    opt.init_ckpt_file, vars_to_restore)

        #TODO Cal mean losses among gpus, and track the loss in TF Summary.
        loss = tf.stack(axis=0, values=losses)
        loss = tf.reduce_mean(loss, 0)
        tf.summary.scalar('loss', loss)

        rigid_warp_loss = tf.stack(axis=0, values=rigid_warp_losses)
        rigid_warp_loss = tf.reduce_mean(rigid_warp_loss, 0)
        tf.summary.scalar('rigid_warp_loss', rigid_warp_loss)
        tf.summary.scalar(
            'unit_rigid_warp_loss',
            rigid_warp_loss / (opt.rigid_warp_weight +
                               tf.convert_to_tensor(1e-8, dtype=tf.float32)))

        disp_smooth_loss = tf.stack(axis=0, values=disp_smooth_losses)
        disp_smooth_loss = tf.reduce_mean(disp_smooth_loss, 0)
        tf.summary.scalar('disp_smooth_loss', disp_smooth_loss)
        tf.summary.scalar(
            'unit_disp_smooth_loss',
            disp_smooth_loss / (opt.disp_smooth_weight +
                                tf.convert_to_tensor(1e-8, dtype=tf.float32)))

        img_loss = tf.stack(axis=0, values=img_losses)
        img_loss = tf.reduce_mean(img_loss, 0)
        tf.summary.scalar('img_loss', img_loss)

        if opt.sem_as_loss:
            sem_loss = tf.stack(axis=0, values=sem_losses)
            sem_loss = tf.reduce_mean(sem_loss, 0)
            tf.summary.scalar('sem_loss', sem_loss)

            if opt.sem_warp_explore:
                sem_warp_loss = tf.stack(axis=0, values=sem_warp_losses)
                sem_warp_loss = tf.reduce_mean(sem_warp_loss, 0)
                tf.summary.scalar('sem_warp_loss', model.sem_warp_loss)
                tf.summary.scalar(
                    'unit_sem_warp_loss', model.sem_warp_loss /
                    (opt.sem_warp_weight +
                     tf.convert_to_tensor(1e-8, dtype=tf.float32)))
            if opt.sem_mask_explore:
                sem_mask_loss = tf.stack(axis=0, values=sem_mask_losses)
                sem_mask_loss = tf.reduce_mean(sem_mask_loss, 0)
                tf.summary.scalar('sem_mask_loss', model.sem_mask_loss)
                tf.summary.scalar(
                    'unit_sem_mask_loss', model.sem_mask_loss /
                    (opt.sem_mask_weight +
                     tf.convert_to_tensor(1e-8, dtype=tf.float32)))
            if opt.sem_edge_explore:
                sem_edge_loss = tf.stack(axis=0, values=sem_edge_losses)
                sem_edge_loss = tf.reduce_mean(sem_edge_loss, 0)
                tf.summary.scalar('sem_edge_loss', model.sem_edge_loss)
                tf.summary.scalar(
                    'unit_sem_edge_loss', model.sem_edge_loss /
                    (opt.sem_edge_weight +
                     tf.convert_to_tensor(1e-8, dtype=tf.float32)))

        if opt.sem_assist and opt.add_segnet:
            sem_seg_loss = tf.stack(axis=0, values=sem_seg_losses)
            sem_seg_loss = tf.reduce_mean(sem_seg_loss, 0)
            tf.summary.scalar('sem_seg_loss', sem_seg_loss)
            tf.summary.scalar(
                'unit_sem_seg_loss', model.sem_seg_loss /
                (opt.sem_seg_weight +
                 tf.convert_to_tensor(1e-8, dtype=tf.float32)))

            ins0_seg_loss = tf.stack(axis=0, values=ins0_seg_losses)
            ins0_seg_loss = tf.reduce_mean(ins0_seg_loss, 0)
            tf.summary.scalar('ins0_seg_loss', ins0_seg_loss)
            tf.summary.scalar(
                'unit_ins0_seg_loss', model.ins0_seg_loss /
                (opt.ins0_seg_weight +
                 tf.convert_to_tensor(1e-8, dtype=tf.float32)))

            ins1_edge_seg_loss = tf.stack(axis=0, values=ins1_edge_seg_losses)
            ins1_edge_seg_loss = tf.reduce_mean(ins1_edge_seg_loss, 0)
            tf.summary.scalar('ins1_edge_seg_loss', ins1_edge_seg_loss)
            tf.summary.scalar(
                'unit_ins1_edge_seg_loss', model.ins1_edge_seg_loss /
                (opt.ins1_edge_seg_weight +
                 tf.convert_to_tensor(1e-8, dtype=tf.float32)))

        if opt.ins_as_loss:
            ins_loss = tf.stack(axis=0, values=ins_losses)
            ins_loss = tf.reduce_mean(ins_loss, 0)
            tf.summary.scalar('ins_loss', ins_loss)

        train_op = slim.learning.create_train_op(
            loss,
            optim,
            variables_to_train=train_vars,
            colocate_gradients_with_ops=True)

        # Saver
        saver = tf.train.Saver([var for var in tf.model_variables()] + \
                                [global_step],
                                max_to_keep=opt.max_to_keep)

        merged_summary = tf.summary.merge_all()

        # Session
        sv = tf.train.Supervisor(logdir=opt.checkpoint_dir,
                                 save_summaries_secs=0,
                                 saver=None)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True

        with sv.managed_session(config=config) as sess:
            train_writer = tf.summary.FileWriter(opt.summary_dir, sess.graph)

            if opt.init_ckpt_file != None:
                sess.run(init_assign_op, init_feed_dict)
            start_time = time.time()

            for step in range(1, opt.max_steps):

                if step < 50000:
                    lrate = 0.0001
                elif step < 100000:
                    lrate = 0.00005

                fetches = {
                    "train": train_op,
                    "global_step": global_step,
                    "incr_global_step": incr_global_step
                }

                if step % opt.print_interval == 0:
                    fetches["loss"] = loss
                    fetches["img_loss"] = img_loss

                    if opt.sem_as_loss:
                        fetches["sem_loss"] = sem_loss
                    if opt.ins_as_loss:
                        fetches["ins_loss"] = ins_loss
                    if opt.add_segnet:
                        fetches["sem_seg_loss"] = sem_seg_loss
                        fetches["ins0_seg_loss"] = ins0_seg_loss
                        fetches["ins1_edge_seg_loss"] = ins1_edge_seg_loss

                results = sess.run(fetches, feed_dict={lr: lrate})

                #TODO Write TF Summary to file.
                if step % opt.save_summ_freq == 0:
                    step_summary = sess.run(merged_summary)
                    train_writer.add_summary(step_summary, step)

                if step % opt.print_interval == 0:

                    time_per_iter = (time.time() -
                                     start_time) / opt.print_interval
                    start_time = time.time()

                    if opt.sem_as_loss:
                        print('Iteration: [%7d] | Time: %4.4fs/iter | Loss: %.3f ImgLoss: %.3f SemLoss: %.3f' \
                        % (step, time_per_iter, results["loss"], results["img_loss"], results["sem_loss"]))
                    elif opt.ins_as_loss:
                        print('Iteration: [%7d] | Time: %4.4fs/iter | Loss: %.3f ImgLoss: %.3f InsLoss: %.3f' \
                        % (step, time_per_iter, results["loss"], results["img_loss"], results["ins_loss"]))
                    else:
                        print('Iteration: [%7d] | Time: %4.4fs/iter | Loss: %.3f ImgLoss: %.3f' \
                        % (step, time_per_iter, results["loss"], results["img_loss"]))

                if step % opt.save_ckpt_freq == 0:
                    saver.save(sess,
                               os.path.join(opt.checkpoint_dir, 'model'),
                               global_step=step)
Code example #15
def train(config_yaml, displayiters, saveiters, max_to_keep=5):
    start_path = os.getcwd()
    os.chdir(str(Path(config_yaml).parents[0])
             )  #switch to folder of config_yaml (for logging)
    setup_logging()

    cfg = load_config(config_yaml)
    cfg['batch_size'] = 1  #in case this was edited for analysis.

    dataset = create_dataset(cfg)
    batch_spec = get_batch_spec(cfg)
    batch, enqueue_op, placeholders = setup_preloading(batch_spec)
    losses = pose_net(cfg).train(batch)
    total_loss = losses['total_loss']

    for k, t in losses.items():
        tf.summary.scalar(k, t)
    merged_summaries = tf.summary.merge_all()

    variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
    restorer = tf.train.Saver(variables_to_restore)
    saver = tf.train.Saver(
        max_to_keep=max_to_keep
    )  # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835

    sess = tf.Session()
    coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
    train_writer = tf.summary.FileWriter(cfg.log_dir, sess.graph)
    learning_rate, train_op = get_optimizer(total_loss, cfg)

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Restore variables from disk.
    restorer.restore(sess, cfg.init_weights)

    max_iter = int(cfg.multi_step[-1][1])

    if displayiters == None:
        display_iters = max(1, int(cfg.display_iters))
    else:
        display_iters = max(1, int(displayiters))
        print("Display_iters overwritten as", display_iters)

    if saveiters == None:
        save_iters = max(1, int(cfg.save_iters))

    else:
        save_iters = max(1, int(saveiters))
        print("Save_iters overwritten as", save_iters)

    cum_loss = 0.0
    lr_gen = LearningRate(cfg)

    stats_path = Path(config_yaml).with_name('learning_stats.csv')
    lrf = open(str(stats_path), 'w')

    print("Training parameter:")
    print(cfg)
    print("Starting training....")
    for it in range(max_iter + 1):
        current_lr = lr_gen.get_lr(it)
        [_, loss_val,
         summary] = sess.run([train_op, total_loss, merged_summaries],
                             feed_dict={learning_rate: current_lr})
        cum_loss += loss_val
        train_writer.add_summary(summary, it)

        if it % display_iters == 0 and it > 0:
            average_loss = cum_loss / display_iters
            cum_loss = 0.0
            logging.info("iteration: {} loss: {} lr: {}".format(
                it, "{0:.4f}".format(average_loss), current_lr))
            lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
            lrf.flush()

        # Save snapshot
        if (it % save_iters == 0 and it != 0) or it == max_iter:
            model_name = cfg.snapshot_prefix
            saver.save(sess, model_name, global_step=it)

    lrf.close()
    sess.close()
    coord.request_stop()
    coord.join([thread])
    #return to original path.
    os.chdir(str(start_path))
Code example #16
    def train(self,
                train_config,
                loss,
                scalar_updates,
                optimizer,
                eval_ops_dict=None,
                pre_ops = None):
        """Coordinates the Training stage
                Args:
                  train_config: protobuf configuration for training
                  loss: loss defined by FeatureExtractor
                  scalar_updates: any scalar update ops to perform while training
                  optimizer: the optimizer to use for training
                  eval_ops_dict: a dict mapping format strings to evaluation operations
                  pre_ops: any operations to run before training, after restoring
        """
        if not isinstance(train_config, train_pb2.TrainConfig):
            raise ValueError('train_config not of type '
                                'train_pb2.TrainConfig.')

        if train_config.eval_while_training and not eval_ops_dict:
            raise ValueError("Can't eval and train without eval ops")

        fine_tune = False
        """Check whether classification ckpt dir exists. If not create it"""
        if not os.path.exists(
            os.path.join(train_config.from_classification_checkpoint)):
            print("TENSORFLOW INFO: Classification Checkpoint doesn't exist. Created checkpoint dir. ")
            os.mkdir(train_config.from_classification_checkpoint)
            #we must be fine tuning
            fine_tune = True
        else:
            print("TENSORFLOW INFO: Classification Checkpoint exists...restoring from it.")

        #this however will restore all variables if nothing is specified
        vars_to_restore = None

        if fine_tune and train_config.fine_tune_checkpoint:
                #things we always need to exclude
                exclude_list = ["train/total_loss","test_1/total_loss","Logits",'eval','global_step']
                if train_config.exclude_from_fine_tune:
                    exclude_list+=list(train_config.exclude_from_fine_tune)
                vars_to_restore = slim.get_variables_to_restore(exclude = exclude_list)
                #TODO add restore from map

        saver = tf.train.Saver(var_list = vars_to_restore) if vars_to_restore else None

        def flatten(l):
            flat = []
            for sublist in l:
                if isinstance(sublist,list):
                    for item in sublist:
                        flat.append(item)
                else:
                    flat.append(sublist)
            return flat

        #if variables to train is not specified, just train them all
        if not train_config.scopes_or_variables_to_train:
            train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        #if vars to train is specified, then just train those
        else:
          train_vars = [tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope = s)
                            for s in train_config.scopes_or_variables_to_train]

        train_vars = flatten(train_vars)
        update_ops = []  # default when no update-op scopes are configured below
        #user must specify the scope
        if train_config.scopes_or_names_for_update_ops and train_config.scopes_or_names_for_update_ops[0] != "all":
            update_ops = [tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                            scope = s)
                            for s in train_config.scopes_or_names_for_update_ops]
            update_ops = flatten(update_ops)

        elif train_config.scopes_or_names_for_update_ops and train_config.scopes_or_names_for_update_ops[0] == "all":
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        if update_ops:
            with tf.control_dependencies(reduce(operator.concat,[scalar_updates,update_ops])):
                train_op = optimizer.minimize(loss,
                                              var_list = train_vars,
                                              global_step = self._global_step)
        else:
            with tf.control_dependencies(scalar_updates):
                train_op = optimizer.minimize(loss,
                                              var_list = train_vars,
                                              global_step = self._global_step)
        print(update_ops)
        checkpoint_dir = train_config.from_classification_checkpoint

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        file_writer = tf.summary.FileWriter(checkpoint_dir,tf.get_default_graph())
        with tf.train.MonitoredTrainingSession(save_checkpoint_secs = train_config.\
                                                                          keep_checkpoint_every_n_minutes*60,
                                               checkpoint_dir = checkpoint_dir,
                                               hooks = [ tf.train.StopAtStepHook(num_steps = train_config.\
                                                                                                 num_steps),
                                               logger_hook.LoggerHook(eval_ops_dict,
                                               train_config.\
                                                   log_frequency,
                                                   self._global_step)],
                                               config = config) as mon_sess:
            if fine_tune and train_config.fine_tune_checkpoint and saver:
                saver.restore(mon_sess,train_config.fine_tune_checkpoint)
            if pre_ops:
                mon_sess.run(pre_ops)
            for v in train_vars:
                print(v.name)
            print("TENSORFLOW INFO: Proceeding to training stage")
            while not mon_sess.should_stop():
                mon_sess.run(train_op,feed_dict = {'train/is_training:0':train_config.is_training})
Code example #17
def main(_):
    tf.reset_default_graph()

    env = environment.get_game_environment(FLAGS.maps,
                                           multiproc=FLAGS.multiproc,
                                           random_goal=FLAGS.random_goal,
                                           random_spawn=FLAGS.random_spawn,
                                           apple_prob=FLAGS.apple_prob)
    exp = expert.Expert()
    net = CMAP(num_iterations=FLAGS.vin_iterations,
               estimate_scale=FLAGS.estimate_scale,
               unified_fuser=FLAGS.unified_fuser,
               unified_vin=FLAGS.unified_vin,
               biased_fuser=FLAGS.biased_fuser,
               biased_vin=FLAGS.biased_vin,
               regularization=FLAGS.reg)

    estimate_images = [
        estimate[0, -1, :, :, 0]
        for estimate in net.intermediate_tensors['estimate_map_list']
    ]
    goal_images = [
        goal[0, -1, :, :, 0]
        for goal in net.intermediate_tensors['goal_map_list']
    ]
    reward_images = [
        reward[0, -1, :, :, 0]
        for reward in net.intermediate_tensors['reward_map_list']
    ]
    value_images = [
        value[0, -1, :, :, 0]
        for value in net.intermediate_tensors['value_map_list']
    ]
    action_images = [
        action[0, -1, :, :, 0]
        for action in net.intermediate_tensors['action_map_list']
    ]

    step_history = tf.placeholder(tf.string, name='step_history')
    step_history_op = tf.summary.text('game/step_history',
                                      step_history,
                                      collections=['game'])

    global_step = slim.get_or_create_global_step()
    update_global_step_op = tf.assign_add(global_step, 1)

    init_op = tf.variables_initializer([global_step])
    load_op, load_feed_dict = slim.assign_from_checkpoint(
        FLAGS.modeldir,
        slim.get_variables_to_restore(exclude=[global_step.name]))

    init_op = tf.group(init_op, load_op)

    slim.learning.train(train_op=tf.no_op('train'),
                        logdir=FLAGS.logdir,
                        init_op=init_op,
                        init_feed_dict=load_feed_dict,
                        global_step=global_step,
                        train_step_fn=DAGGER_train_step,
                        train_step_kwargs=dict(
                            env=env,
                            exp=exp,
                            net=net,
                            update_global_step_op=update_global_step_op,
                            step_history=step_history,
                            step_history_op=step_history_op,
                            estimate_maps=estimate_images,
                            goal_maps=goal_images,
                            reward_maps=reward_images,
                            value_maps=value_images,
                            action_maps=action_images),
                        number_of_steps=FLAGS.num_games,
                        save_interval_secs=300 if not FLAGS.debug else 60,
                        save_summaries_secs=300 if not FLAGS.debug else 60)
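For comparison, a hedged sketch of the same initialization using slim's callable variant, which wraps the assign op and feed dict into a single function (same FLAGS.modeldir and exclusion as above):

    init_fn = slim.assign_from_checkpoint_fn(
        FLAGS.modeldir,
        slim.get_variables_to_restore(exclude=[global_step.name]))
    # Inside an open session, init_fn(sess) restores the same variables that
    # load_op/load_feed_dict handle in the code above.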
Code example #18
def train(args):
    model = AppearanceNetwork(args)

    save_directory = './save/'
    log_file_path = './training.log'
    log_file = open(log_file_path, 'w')

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Graph().as_default():
        global_step = tf.Variable(0, name='global_step', trainable=False)

        image_patches_placeholder = tf.placeholder(
            tf.float32, shape=[args.batch_size, 7, 128, 64, 3])

        labels_placeholder = tf.placeholder(tf.float32,
                                            shape=[args.batch_size])

        lr = tf.Variable(args.base_learning_rate,
                         trainable=False,
                         name="learning_rate")

        features, logits = model.inference(image_patches_placeholder)

        loss = model.cross_entropy_loss(logits, labels_placeholder)

        train_op = build_graph(args, global_step, lr, loss)

        sess = tf.Session(config=config)

        saver = tf.train.Saver(max_to_keep=100)

        ckpt = tf.train.get_checkpoint_state('./save')
        if ckpt is None:
            init = tf.global_variables_initializer()
            sess.run(init)
            if args.pretrained_ckpt_path is not None:
                # slim.get_or_create_global_step()
                init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
                    args.pretrained_ckpt_path,
                    slim.get_variables_to_restore(exclude=[
                        "lstm", "fc_layer", "loss", "learning_rate", "softmax",
                        "global_step"
                    ]))
                sess.run(init_assign_op, feed_dict=init_feed_dict)
        else:
            print('Loading Model from ' + ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)

        best_epoch = -1
        best_loss_epoch = 0.0
        for curr_epoch in range(args.num_epoches):
            training_loss_epoch = 0.0
            valid_loss_epoch = 0.0

            ############################################# Training process ######################################
            print('Training epoch ' + str(curr_epoch + 1) +
                  '........................')
            training_data_loader = DataLoader(is_valid=False)

            if curr_epoch % 10 == 0:
                sess.run(
                    tf.assign(
                        lr,
                        args.base_learning_rate *
                        (args.decay_rate**(curr_epoch // 10))))

            training_data_loader.shuffle()
            training_data_loader.reset_pointer()

            for step in range(training_data_loader.num_batches):
                start_time = time.time()

                image_patches, labels = training_data_loader.next_batch()

                _, loss_batch = sess.run(
                    [train_op, loss],
                    feed_dict={
                        image_patches_placeholder: image_patches,
                        labels_placeholder: labels
                    })

                end_time = time.time()
                training_loss_epoch += loss_batch
                print(
                    "Training {}/{} (epoch {}), train_loss = {:.8f}, time/batch = {:.3f}"
                    .format(step + 1, training_data_loader.num_batches,
                            curr_epoch + 1, loss_batch, end_time - start_time))

            print('Epoch ' + str(curr_epoch + 1) +
                  ' training is done! Saving model...')
            checkpoint_path = os.path.join(save_directory, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=global_step)

            ############################################# Validating process ######################################
            print('Validating epoch ' + str(curr_epoch + 1) +
                  '...........................')
            valid_data_loader = DataLoader(is_valid=True)

            valid_data_loader.shuffle()
            valid_data_loader.reset_pointer()
            for step in range(valid_data_loader.num_batches):
                start_time = time.time()

                image_patches, labels = valid_data_loader.next_batch()

                loss_batch = sess.run(loss,
                                      feed_dict={
                                          image_patches_placeholder:
                                          image_patches,
                                          labels_placeholder: labels
                                      })

                end_time = time.time()
                valid_loss_epoch += loss_batch
                print(
                    "Validating {}/{} (epoch {}), valid_loss = {:.8f}, time/batch = {:.3f}"
                    .format(step + 1, valid_data_loader.num_batches,
                            curr_epoch + 1, loss_batch, end_time - start_time))

            # Update best valid epoch
            if best_epoch == -1 or best_loss_epoch > valid_loss_epoch:
                best_epoch = curr_epoch + 1
                best_loss_epoch = valid_loss_epoch

            log_file.write('epoch ' + str(curr_epoch + 1) + '\n')
            log_file.write(
                str(curr_epoch + 1) + ',' + str(training_loss_epoch) + '\n')
            log_file.write(
                str(curr_epoch + 1) + ',' + str(valid_loss_epoch) + '\n')
            log_file.write(str(best_epoch) + ',' + str(best_loss_epoch) + '\n')

        log_file.close()
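The warm-start block above can be factored into a small helper that excludes the freshly added scopes by name and restores the rest from a pretrained checkpoint. A sketch; the scope list and checkpoint path are taken from the call above as placeholders rather than verified against the project.

import tensorflow as tf
import tensorflow.contrib.slim as slim

def warm_start(sess, pretrained_ckpt_path, new_scopes):
    """Restore all variables except those created under the new scopes."""
    restore_vars = slim.get_variables_to_restore(exclude=list(new_scopes))
    assign_op, feed_dict = slim.assign_from_checkpoint(pretrained_ckpt_path,
                                                       restore_vars)
    sess.run(assign_op, feed_dict=feed_dict)

# warm_start(sess, args.pretrained_ckpt_path,
#            ['lstm', 'fc_layer', 'loss', 'learning_rate', 'softmax', 'global_step'])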
Code example #19
0
def create_model(config, sess, ensemble_scope=None, train=False):
    logging.info('Building model...')
    model = StandardModel(config)

    # Construct a mapping between saved variable names and names in the current
    # scope. There are two reasons why names might be different:
    #
    #   1. This model is part of an ensemble, in which case a model-specific
    #       name scope will be active.
    #
    #   2. The saved model is from an old version of Nematus (before deep model
    #        support was added) and uses a different variable naming scheme
    #        for the GRUs.
    variables = slim.get_variables_to_restore()
    var_map = {}
    for v in variables:
        name = v.name.split(':')[0]
        if ensemble_scope == None:
            saved_name = name
        elif v.name.startswith(ensemble_scope):
            saved_name = name[len(ensemble_scope):]
        else: # v belongs to a different model in the ensemble.
            continue
        if config.model_version == 0.1:
            # Backwards compatibility with the old variable naming scheme.
            saved_name = compat.revert_variable_name(saved_name, 0.1)
        var_map[saved_name] = v
    saver = tf.train.Saver(var_map, max_to_keep=None)

    # compute reload model filename
    reload_filename = None
    if config.reload == 'latest_checkpoint':
        checkpoint_dir = os.path.dirname(config.saveto)
        reload_filename = tf.train.latest_checkpoint(checkpoint_dir)
        if reload_filename != None:
            if (os.path.basename(reload_filename).rsplit('-', 1)[0] !=
                os.path.basename(config.saveto)):
                logging.error("Mismatching model filename found in the same directory while reloading from the latest checkpoint")
                sys.exit(1)
            logging.info('Latest checkpoint found in directory ' + os.path.abspath(checkpoint_dir))
    elif config.reload != None:
        reload_filename = config.reload
    if (reload_filename == None) and (config.prior_model != None):
        logging.info('Initializing model parameters from prior')
        reload_filename = config.prior_model

    # initialize or reload training progress
    if train:
        progress = training_progress.TrainingProgress()
        progress.bad_counter = 0
        progress.uidx = 0
        progress.eidx = 0
        progress.estop = False
        progress.history_errs = []
        if reload_filename and config.reload_training_progress:
            path = reload_filename + '.progress.json'
            if os.path.exists(path):
                logging.info('Reloading training progress')
                progress.load_from_json(path)
                if (progress.estop == True or
                    progress.eidx > config.max_epochs or
                    progress.uidx >= config.finish_after):
                    logging.warning('Training is already complete. Disable reloading of training progress (--no_reload_training_progress) or remove or modify progress file (%s) to train anyway.' % path)
                    sys.exit(0)

    # load prior model
    if train and config.prior_model != None:
        load_prior(config, sess, saver)
    
    # initialize or restore model
    if reload_filename == None:
        logging.info('Initializing model parameters from scratch...')
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
    else:
        logging.info('Loading model parameters from file ' + os.path.abspath(reload_filename))
        saver.restore(sess, os.path.abspath(reload_filename))
        if train:
            # The global step is currently recorded in two places:
            #   1. model.t, a tf.Variable read and updated by the optimizer
            #   2. progress.uidx, a Python integer updated by train()
            # We reset model.t to the value recorded in progress to allow the
            # value to be controlled by the user (either explicitly by
            # configuring the value in the progress file or implicitly by using
            # --no_reload_training_progress).
            model.reset_global_step(progress.uidx, sess)

    logging.info('Done')

    if train:
        return model, saver, progress
    else:
        return model, saver
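The name-mapping logic above reduces to a small helper: map each in-graph variable to the name it was saved under (here, just stripping the optional ensemble prefix) and hand that dict to tf.train.Saver. The compat.revert_variable_name step for very old checkpoints is omitted in this sketch.

import tensorflow as tf
import tensorflow.contrib.slim as slim

def scope_stripping_saver(ensemble_scope=None):
    """Saver whose checkpoint names omit the current ensemble scope prefix."""
    var_map = {}
    for v in slim.get_variables_to_restore():
        name = v.name.split(':')[0]
        if ensemble_scope is None:
            var_map[name] = v
        elif v.name.startswith(ensemble_scope):
            var_map[name[len(ensemble_scope):]] = v
        # Variables belonging to other models in the ensemble are skipped.
    return tf.train.Saver(var_map, max_to_keep=None)

# saver = scope_stripping_saver('model_0/')  # hypothetical ensemble scope
# saver.restore(sess, reload_filename)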
Code example #20
0
	# Create model
	if options.modelName == "NASNet":
		arg_scope = nasnet.nasnet_large_arg_scope()
		with slim.arg_scope(arg_scope):
			logits, endPoints = nasnet.build_nasnet_large(scaledInputBatchImages, is_training=False, num_classes=options.numClasses)

	elif options.modelName == "IncResV2":
		arg_scope = inception_resnet_v2.inception_resnet_v2_arg_scope()
		with slim.arg_scope(arg_scope):
			# logits, endPoints = inception_resnet_v2.inception_resnet_v2(scaledInputBatchImages, is_training=False)
			with tf.variable_scope('InceptionResnetV2', 'InceptionResnetV2', [scaledInputBatchImages], reuse=None) as scope:
				with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=False):
				  net, endPoints = inception_resnet_v2.inception_resnet_v2_base(scaledInputBatchImages, scope=scope, activation_fn=tf.nn.relu)

		variablesToRestore = slim.get_variables_to_restore(include=["InceptionResnetV2"])

	else:
		print ("Error: Model not found!")
		exit (-1)

# TODO: Attach the decoder to the encoder
print (endPoints.keys())
if options.useSkipConnections:
	print ("Adding skip connections from the encoder to the decoder!")
predictedLogits = attachDecoder(net, endPoints, tf.shape(scaledInputBatchImages))
predictedMask = tf.expand_dims(tf.argmax(predictedLogits, axis=-1), -1, name="predictedMasks")

if options.tensorboardVisualization:
	tf.summary.image('Original Image', inputBatchImages, max_outputs=3)
	tf.summary.image('Desired Mask', tf.to_float(inputBatchMasks), max_outputs=3)
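When only the encoder should be reloaded, restricting the restore set to its scope and passing that list to a Saver is usually all that is needed; the decoder attached afterwards is simply initialized from scratch. A minimal sketch, with a hypothetical checkpoint path.

import tensorflow as tf
import tensorflow.contrib.slim as slim

encoder_vars = slim.get_variables_to_restore(include=["InceptionResnetV2"])
encoder_saver = tf.train.Saver(encoder_vars)

# with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())  # decoder weights start fresh
#     encoder_saver.restore(sess, "/path/to/inception_resnet_v2.ckpt")  # hypothetical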
Code example #21
0
File: train.py Project: cniklaus/text_simplification
def train(model_config=None):
    model_config = (DefaultConfig()
                    if model_config is None else model_config)
    data = TrainData(model_config)

    graph = None
    if model_config.framework == 'transformer':
        graph = TransformerGraph(data, True, model_config)
    elif model_config.framework == 'seq2seq':
        graph = Seq2SeqGraph(data, True, model_config)
    else:
        raise NotImplementedError('Unknown Framework.')
    graph.create_model_multigpu()

    ckpt_path = None
    if model_config.warm_start:
        if model_config.warm_start == 'recent':
            ckpt_path = find_best_ckpt(model_config)
        else:
            ckpt_path = model_config.warm_start
        var_list = slim.get_variables_to_restore()
    if ckpt_path is not None:
        # Handling missing vars by ourselves
        available_vars = {}
        reader = tf.train.NewCheckpointReader(ckpt_path)
        var_dict = {var.op.name: var for var in var_list}
        for var in var_dict:
            if 'global_step' in var:
                continue
            if 'optimization' in var:
                continue
            if reader.has_tensor(var):
                var_ckpt = reader.get_tensor(var)
                var_cur = var_dict[var]
                if any([var_cur.shape[i] != var_ckpt.shape[i] for i in range(len(var_ckpt.shape))]):
                    print('Variable %s missing due to shape.' % var)
                else:
                    available_vars[var] = var_dict[var]
            else:
                print('Variable %s missing.' % var)

        partial_restore_ckpt = slim.assign_from_checkpoint_fn(
            ckpt_path, available_vars,
            ignore_missing_vars=False, reshape_variables=False)

    def init_fn(session):
        # Restore ckpt either from warm start or automatically get when changing optimizer
        ckpt_path = None
        if model_config.warm_start:
            ckpt_path = model_config.warm_start

        if ckpt_path is not None:
            if model_config.use_partial_restore:
                partial_restore_ckpt(session)
            else:
                try:
                    graph.saver.restore(session, ckpt_path)
                except Exception as ex:
                    print('Fully restore failed, use partial restore instead. \n %s' % str(ex))
                    partial_restore_ckpt(session)

            print('Warm start with checkpoint %s' % ckpt_path)

    sv = tf.train.Supervisor(logdir=model_config.logdir,
                             global_step=graph.global_step,
                             saver=graph.saver,
                             init_fn=init_fn,
                             save_model_secs=model_config.save_model_secs)
    sess = sv.PrepareSession(config=session.get_session_config(model_config))
    perplexitys = []
    start_time = datetime.now()
    while True:
        input_feed = get_graph_train_data(
            data,
            graph.objs,
            model_config)

        # fetches = [graph.train_op, graph.loss, graph.global_step,
        #            graph.perplexity, graph.ops, graph.attn_dists, graph.targets, graph.cs]
        # _, loss, step, perplexity, _ops , attn_dists, targets, cs = sess.run(fetches, input_feed)
        fetches = [graph.train_op, graph.loss, graph.global_step,
                   graph.perplexity, graph.ops, graph.logits]
        _, loss, step, perplexity, _, logits = sess.run(fetches, input_feed)
        perplexitys.append(perplexity)

        if step % model_config.model_print_freq == 0:
            end_time = datetime.now()
            time_span = end_time - start_time
            start_time = end_time
            print('Perplexity:\t%f at step %d using %s.' % (perplexity, step, time_span))
            perplexitys.clear()
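The shape-aware partial restore above can be isolated into a reusable helper: read the checkpoint with tf.train.NewCheckpointReader, keep only variables that exist there with a matching shape, and let slim.assign_from_checkpoint_fn build the init function. A sketch under those assumptions; the skip list mirrors the substrings filtered above.

import tensorflow as tf
import tensorflow.contrib.slim as slim

def partial_restore_fn(ckpt_path, skip_substrings=('global_step', 'optimization')):
    """init_fn restoring only variables present in ckpt_path with matching shapes."""
    reader = tf.train.NewCheckpointReader(ckpt_path)
    ckpt_shapes = reader.get_variable_to_shape_map()
    available = {}
    for var in slim.get_variables_to_restore():
        name = var.op.name
        if any(s in name for s in skip_substrings):
            continue
        if name in ckpt_shapes and var.get_shape().as_list() == ckpt_shapes[name]:
            available[name] = var
    return slim.assign_from_checkpoint_fn(ckpt_path, available)

# init_fn = partial_restore_fn(ckpt_path)
# init_fn(session)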
Code example #22
0
def main(unused_argv):
    # Get dataset-dependent information.
    # Prepare for visualization.
    tf.gfile.MakeDirs(FLAGS.vis_logdir)
    save_dir = os.path.join(FLAGS.vis_logdir, _SEMANTIC_PREDICTION_SAVE_FOLDER)
    tf.gfile.MakeDirs(save_dir)
    raw_save_dir = os.path.join(FLAGS.vis_logdir,
                                _RAW_SEMANTIC_PREDICTION_SAVE_FOLDER)
    tf.gfile.MakeDirs(raw_save_dir)
    num_vis_examples = FLAGS.num_vis_examples

    print('Visualizing on set', FLAGS.split)

    g = tf.Graph()
    with g.as_default():
        samples = model_input.get_input_fn(FLAGS)()
        outputs_to_num_classes = model.get_output_to_num_classes(FLAGS)

        # Get model segmentation predictions.
        if tuple(FLAGS.eval_scales) == (1.0, ):
            tf.logging.info('Performing single-scale test.')
            predictions, probs = model.predict_labels(
                samples['image'],
                samples,
                FLAGS,
                outputs_to_num_classes=outputs_to_num_classes,
                image_pyramid=FLAGS.image_pyramid,
                merge_method=FLAGS.merge_method,
                atrous_rates=FLAGS.atrous_rates,
                add_image_level_feature=FLAGS.add_image_level_feature,
                aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
                aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
                multi_grid=FLAGS.multi_grid,
                depth_multiplier=FLAGS.depth_multiplier,
                output_stride=FLAGS.output_stride,
                decoder_output_stride=FLAGS.decoder_output_stride,
                decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
                crop_size=[FLAGS.image_size, FLAGS.image_size],
                logits_kernel_size=FLAGS.logits_kernel_size,
                model_variant=FLAGS.model_variant)
        else:
            tf.logging.info('Performing multi-scale test.')
            predictions, probs = model.predict_labels_multi_scale(
                samples['image'],
                samples,
                FLAGS,
                outputs_to_num_classes=outputs_to_num_classes,
                eval_scales=FLAGS.eval_scales,
                add_flipped_images=FLAGS.add_flipped_images,
                merge_method=FLAGS.merge_method,
                atrous_rates=FLAGS.atrous_rates,
                add_image_level_feature=FLAGS.add_image_level_feature,
                aspp_with_batch_norm=FLAGS.aspp_with_batch_norm,
                aspp_with_separable_conv=FLAGS.aspp_with_separable_conv,
                multi_grid=FLAGS.multi_grid,
                depth_multiplier=FLAGS.depth_multiplier,
                output_stride=FLAGS.output_stride,
                decoder_output_stride=FLAGS.decoder_output_stride,
                decoder_use_separable_conv=FLAGS.decoder_use_separable_conv,
                crop_size=[FLAGS.image_size, FLAGS.image_size],
                logits_kernel_size=FLAGS.logits_kernel_size,
                model_variant=FLAGS.model_variant)

        if FLAGS.output_mode == 'segment':
            predictions = tf.squeeze(
                tf.cast(predictions[FLAGS.output_mode], tf.int32))
            probs = probs[FLAGS.output_mode]

            labels = tf.squeeze(tf.cast(samples['label'], tf.int32))
            weights = tf.cast(
                tf.not_equal(
                    labels, model_input.dataset_descriptors[
                        FLAGS.dataset].ignore_label), tf.int32)

            labels *= weights
            predictions *= weights

            tf.train.get_or_create_global_step()
            saver = tf.train.Saver(contrib_slim.get_variables_to_restore())
            sv = tf.train.Supervisor(graph=g,
                                     logdir=FLAGS.vis_logdir,
                                     init_op=tf.global_variables_initializer(),
                                     summary_op=None,
                                     summary_writer=None,
                                     global_step=None,
                                     saver=saver)
            num_batches = int(
                math.ceil(num_vis_examples / float(FLAGS.batch_size)))
            last_checkpoint = None

            # Infinite loop to visualize the results when new checkpoint is created.
            while True:
                last_checkpoint = contrib_slim.evaluation.wait_for_new_checkpoint(
                    FLAGS.checkpoint_dir, last_checkpoint)
                start = time.time()
                print('Starting visualization at ' +
                      time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
                print('Visualizing with model %s' % last_checkpoint)

                with sv.managed_session(FLAGS.master,
                                        start_standard_services=False) as sess:
                    # sv.start_queue_runners(sess)
                    sv.saver.restore(sess, last_checkpoint)

                    image_id_offset = 0
                    refs = []
                    for batch in range(num_batches):
                        print('Visualizing batch', batch + 1, num_batches)
                        refs.extend(
                            _process_batch(sess=sess,
                                           samples=samples,
                                           semantic_predictions=predictions,
                                           labels=labels,
                                           image_id_offset=image_id_offset,
                                           save_dir=save_dir))
                        image_id_offset += FLAGS.batch_size

                print('Finished visualization at ' +
                      time.strftime('%Y-%m-%d-%H:%M:%S', time.gmtime()))
                time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
                if time_to_next_eval > 0:
                    time.sleep(time_to_next_eval)
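The checkpoint-polling skeleton of the loop above is generic enough to stand on its own: block until a new checkpoint appears, restore it, run the per-checkpoint work, then sleep out the remainder of the interval. A sketch; checkpoint_dir, saver, and process_fn are placeholders supplied by the caller.

import time
import tensorflow as tf
import tensorflow.contrib.slim as slim

def run_on_new_checkpoints(checkpoint_dir, saver, process_fn, interval_secs=300):
    """Call process_fn(sess) once for every new checkpoint in checkpoint_dir."""
    last_checkpoint = None
    while True:
        last_checkpoint = slim.evaluation.wait_for_new_checkpoint(
            checkpoint_dir, last_checkpoint)
        start = time.time()
        with tf.Session() as sess:
            saver.restore(sess, last_checkpoint)
            process_fn(sess)
        time_to_next = start + interval_secs - time.time()
        if time_to_next > 0:
            time.sleep(time_to_next)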
Code example #23
0
File: variables_test.py Project: LevinJ/CodeSamples
                             initializer=tf.truncated_normal_initializer(stddev=0.1),
                             regularizer=slim.l2_regularizer(0.05),
                             device='/CPU:0')

weights_2 = slim.model_variable('weights_2',
                              shape=[10, 10, 3 , 3],
                              initializer=tf.truncated_normal_initializer(stddev=0.1),
                              regularizer=slim.l2_regularizer(0.05),
                              device='/CPU:0')

my_var = slim.variable('my_var',
                       shape=[20, 1],
                       initializer=tf.zeros_initializer())
regular_variables_and_model_variables = slim.get_variables()

variables_to_restore = slim.get_variables_to_restore(exclude=["v1"])
# Launch the graph in a session.
sess = tf.Session()

# Evaluate the tensor `c`.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    print(my_var.eval())
    print(weights.eval())
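For reference, the three slim helpers exercised in this snippet differ only in which collection they draw from; printing the result of each makes the difference visible. A short sketch that assumes the same module scope as the snippet above (slim imported, the variables already defined):

# slim.get_variables()            -> variables in tf.GraphKeys.GLOBAL_VARIABLES
# slim.get_model_variables()      -> only variables in tf.GraphKeys.MODEL_VARIABLES
#                                    (e.g. those created with slim.model_variable)
# slim.get_variables_to_restore() -> get_variables() filtered by the optional
#                                    include=/exclude= name prefixes
for v in slim.get_variables_to_restore(exclude=["my_var"]):
    print(v.op.name)   # weights, weights_2 (my_var is filtered out)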
Code example #24
0
def export(checkpoint_path, modelNo):

    print("Begin exporting %s" % format(datetime.now().isoformat()))

    saved_model_dir = "SavedModel"

    inference_dir = os.environ['MODEL_INFERENCE_PATH']
    export_dir = os.path.join(inference_dir, saved_model_dir, modelNo,
                              "SavedModel")

    print("The path of saved model: %s" % export_dir)

    if tf.gfile.Exists(export_dir):
        print('Saved model folder already exist. Delete it firstly.')
        if (export_dir.endswith(saved_model_dir)):
            tf.gfile.DeleteRecursively(export_dir)

    if (checkpoint_path == None):
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir)

    print("checkpoint_path: %s" % checkpoint_path)

    with tf.Graph().as_default() as graph:
        tf_global_step = slim.get_or_create_global_step()
        labels_to_names = dataset_utils.read_label_file(FLAGS.dataset_dir)
        num_classes = len(labels_to_names)
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            is_training=False)

        input_shape = [None, FLAGS.train_image_size, FLAGS.train_image_size, 3]
        input_tensor = tf.placeholder(name='input_1',
                                      dtype=tf.float32,
                                      shape=input_shape)

        predictions_key = "Predictions"
        if FLAGS.model_name.startswith("resnet"):
            logits, endpoints = network_fn(input_tensor)
            predictions_key = "predictions"
        elif FLAGS.model_name.startswith("inception"):
            logits, endpoints = network_fn(input_tensor,
                                           create_aux_logits=False)
        elif FLAGS.model_name.startswith("nasnet_mobile"):
            logits, endpoints = network_fn(input_tensor, use_aux_head=0)

        predictions = endpoints[predictions_key]

        if FLAGS.moving_average_decay:
            variable_averages = tf.train.ExponentialMovingAverage(
                FLAGS.moving_average_decay, tf_global_step)
            variables_to_restore = variable_averages.variables_to_restore(
                slim.get_model_variables())
            variables_to_restore[tf_global_step.op.name] = tf_global_step
        else:
            variables_to_restore = slim.get_variables_to_restore()

        saver = tf.train.Saver(
            var_list=variables_to_restore)  #Same as slim.get_variables()

        init1 = tf.global_variables_initializer()
        init2 = tf.local_variables_initializer()
        with tf.Session() as sess:
            sess.run(init1)
            sess.run(init2)
            saver.restore(sess, checkpoint_path)

            #uninitialized_variables = [str(v, 'utf-8') for v in set(sess.run(tf.report_uninitialized_variables()))]
            #print(uninitialized_variables)
            #tf.graph_util.convert_variables_to_constants()

            print("Exporting saved model to: %s" % export_dir)

            prediction_signature = predict_signature_def(
                inputs={'input_1': input_tensor},
                outputs={'output': predictions})

            signature_def_map = {
                tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                prediction_signature
            }

            builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

            builder.add_meta_graph_and_variables(
                sess,
                tags=[tf.saved_model.tag_constants.SERVING],
                signature_def_map=signature_def_map,
                clear_devices=True,
                main_op=None,  #Suggest tf.tables_initializer()?
                strip_default_attrs=False)  #Suggest True?
            builder.save()
            sess.close()
            print("Done exporting %s" % format(datetime.now().isoformat()))
Code example #25
0
File: train.py Project: happog/yolo-tf
def main():
    model = config.get('config', 'model')
    logdir = utils.get_logdir(config)
    if args.delete:
        tf.logging.warn('delete logging directory: ' + logdir)
        shutil.rmtree(logdir, ignore_errors=True)
    cachedir = utils.get_cachedir(config)
    with open(os.path.join(cachedir, 'names'), 'r') as f:
        names = [line.strip() for line in f]
    width = config.getint(model, 'width')
    height = config.getint(model, 'height')
    cell_width, cell_height = utils.calc_cell_width_height(config, width, height)
    tf.logging.warn('(width, height)=(%d, %d), (cell_width, cell_height)=(%d, %d)' % (width, height, cell_width, cell_height))
    yolo = importlib.import_module('model.' + model)
    paths = [os.path.join(cachedir, profile + '.tfrecord') for profile in args.profile]
    num_examples = sum(sum(1 for _ in tf.python_io.tf_record_iterator(path)) for path in paths)
    tf.logging.warn('num_examples=%d' % num_examples)
    with tf.name_scope('batch'):
        image_rgb, labels = utils.data.load_image_labels(paths, len(names), width, height, cell_width, cell_height, config)
        with tf.name_scope('per_image_standardization'):
            image_std = tf.image.per_image_standardization(image_rgb)
        batch = tf.train.shuffle_batch((image_std,) + labels, batch_size=args.batch_size,
            capacity=config.getint('queue', 'capacity'), min_after_dequeue=config.getint('queue', 'min_after_dequeue'),
            num_threads=multiprocessing.cpu_count()
        )
    global_step = tf.contrib.framework.get_or_create_global_step()
    builder = yolo.Builder(args, config)
    builder(batch[0], training=True)
    with tf.name_scope('total_loss') as name:
        builder.create_objectives(batch[1:])
        total_loss = tf.losses.get_total_loss(name=name)
    variables_to_restore = slim.get_variables_to_restore(exclude=args.exclude)
    with tf.name_scope('optimizer'):
        try:
            decay_steps = config.getint('exponential_decay', 'decay_steps')
            decay_rate = config.getfloat('exponential_decay', 'decay_rate')
            staircase = config.getboolean('exponential_decay', 'staircase')
            learning_rate = tf.train.exponential_decay(args.learning_rate, global_step, decay_steps, decay_rate, staircase=staircase)
            tf.logging.warn('using a learning rate start from %f with exponential decay (decay_steps=%d, decay_rate=%f, staircase=%d)' % (args.learning_rate, decay_steps, decay_rate, staircase))
        except (configparser.NoSectionError, configparser.NoOptionError):
            learning_rate = args.learning_rate
            tf.logging.warn('using a stationary learning rate %f' % args.learning_rate)
        optimizer = get_optimizer(config, args.optimizer)(learning_rate)
        tf.logging.warn('optimizer=' + args.optimizer)
        train_op = slim.learning.create_train_op(total_loss, optimizer, global_step,
            clip_gradient_norm=args.gradient_clip, summarize_gradients=config.getboolean('summary', 'gradients'),
        )
    if args.transfer:
        path = os.path.expanduser(os.path.expandvars(args.transfer))
        tf.logging.warn('transferring from ' + path)
        init_assign_op, init_feed_dict = slim.assign_from_checkpoint(path, variables_to_restore)
        def init_fn(sess):
            sess.run(init_assign_op, init_feed_dict)
            tf.logging.warn('transferring from global_step=%d, learning_rate=%f' % sess.run((global_step, learning_rate)))
    else:
        init_fn = lambda sess: tf.logging.warn('global_step=%d, learning_rate=%f' % sess.run((global_step, learning_rate)))
    summary(config)
    tf.logging.warn('tensorboard --logdir ' + logdir)
    slim.learning.train(train_op, logdir, master=args.master, is_chief=(args.task == 0),
        global_step=global_step, number_of_steps=args.steps, init_fn=init_fn,
        summary_writer=tf.summary.FileWriter(os.path.join(logdir, args.logname)),
        save_summaries_secs=args.summary_secs, save_interval_secs=args.save_secs
    )
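For reference, tf.train.exponential_decay as used above computes learning_rate * decay_rate ** (global_step / decay_steps), with the exponent floored when staircase is True. A tiny sketch with made-up hyperparameters showing the resulting schedule:

import tensorflow as tf

step_ph = tf.placeholder(tf.int64, shape=[])
lr = tf.train.exponential_decay(0.001, step_ph,
                                decay_steps=10000, decay_rate=0.5,
                                staircase=True)   # hypothetical values

with tf.Session() as sess:
    for step in (0, 5000, 10000, 25000):
        print(step, sess.run(lr, feed_dict={step_ph: step}))
    # staircase=True keeps the rate constant within each 10000-step window:
    # 0.001, 0.001, 0.0005, 0.00025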
Code example #26
0
def tf_train_flow(
        train_once_fn,
        model_dir='./model',
        max_models_keep=1,
        save_interval_seconds=600,
        save_interval_steps=1000,
        num_epochs=None,
        num_steps=None,
        save_model=True,
        save_interval_epochs=1,
        num_steps_per_epoch=0,
        restore_from_latest=True,
        metric_eval_fn=None,
        init_fn=None,
        restore_fn=None,
        restore_scope=None,
        save_all_scope=False,  #TODO save/load from restore scope only, but save everything
        variables_to_restore=None,
        variables_to_save=None,  #by default will be the same as variables_to_restore
        sess=None):
    """
  similary flow as tf_flow, but add model try reload and save
  """
    if sess is None:
        #TODO melt.get_session is global session but may cause non close at last
        sess = melt.get_session()
    logging.info('tf_train_flow start')
    print('max_models_keep:', max_models_keep, file=sys.stderr)
    print('save_interval_seconds:', save_interval_seconds, file=sys.stderr)

    #This is useful when you use another model under another scope and only want to load/restore/save/initialize the variables in your own scope.
    #It is not meant for fine-tuning, but mainly for cases such as prediction, where another model's scope is introduced into the graph and should be ignored here.

    var_list = None if not restore_scope else tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)
    if not variables_to_restore:
        variables_to_restore = var_list
    if not variables_to_save:
        variables_to_save = variables_to_restore
    if save_all_scope:
        variables_to_save = None

    if variables_to_restore is None:
        #Load all variables found in the checkpoint; try to save all variables (possibly more than in the original checkpoint) if variables_to_save is not specified.
        varnames_in_checkpoint = melt.get_checkpoint_varnames(model_dir)
        #print(varnames_in_checkpoint)
        variables_to_restore = slim.get_variables_to_restore(
            include=varnames_in_checkpoint)

    #logging.info('variables_to_restore:{}'.format(variables_to_restore))
    loader = tf.train.Saver(var_list=variables_to_restore)

    saver = tf.train.Saver(
        max_to_keep=max_models_keep,
        keep_checkpoint_every_n_hours=save_interval_seconds / 3600.0,
        var_list=variables_to_save)
    epoch_saver = tf.train.Saver(var_list=variables_to_save, max_to_keep=1000)
    best_epoch_saver = tf.train.Saver(var_list=variables_to_save)

    ##TODO: to be safe, is running the full init before restoring OK?
    #if variables_to_restore is None:
    init_op = tf.group(
        tf.global_variables_initializer(
        ),  #variables_initializer(global_variables())
        tf.local_variables_initializer()
    )  #variables_initializer(local_variables())
    # else:
    #   init_op = tf.group(tf.variables_initializer(variables_to_restore),
    #                      tf.local_variables_initializer())

    ##Mostly this will be fine, except when an assistant predictor is used: initializing again would make the assistant predictor wrong.
    ##So assume everything runs the init op; if using an assistant predictor, make sure it uses another session.

    sess.run(init_op)

    #melt.init_uninitialized_variables(sess)

    #pre_step is the step that was last saved; when training without a pretrained model it is -1
    pre_step = -1
    fixed_pre_step = -1  #fixed_pre_step keeps the epoch number correct if you change the batch size
    model_path = _get_model_path(model_dir, save_model)
    model_dir = gezi.get_dir(
        model_dir)  #incase you pass ./model/model-ckpt1000 -> ./model
    if model_path is not None:
        if not restore_from_latest:
            print('using recent but not latest model', file=sys.stderr)
            model_path = melt.recent_checkpoint(model_dir)
        model_name = os.path.basename(model_path)
        timer = gezi.Timer('Loading and training from existing model [%s]' %
                           model_path)
        if restore_fn is not None:
            restore_fn(sess)
        loader.restore(sess, model_path)
        timer.print()
        pre_step = melt.get_model_step(model_path)
        pre_epoch = melt.get_model_epoch(model_path)
        fixed_pre_step = pre_step
        if pre_epoch is not None:
            #like using batch size 32, then reload train using batch size 64
            if abs(pre_step / num_steps_per_epoch - pre_epoch) > 0.1:
                fixed_pre_step = int(pre_epoch * num_steps_per_epoch)
                logging.info('Warning: epoch differs from pre_step / num_steps_per_epoch: {}, pre_epoch: {}; maybe you changed the batch size, so pre_step will be adjusted to {}'\
                  .format(pre_step / num_steps_per_epoch, pre_epoch, fixed_pre_step))
    else:
        print('Train all start step 0', file=sys.stderr)
        #https://stackoverflow.com/questions/40220201/tensorflow-tf-initialize-all-variables-vs-tf-initialize-local-variables
        #tf.initialize_all_variables() is a shortcut to tf.initialize_variables(tf.all_variables()),
        #tf.initialize_local_variables() is a shortcut to tf.initialize_variables(tf.local_variables()),
        #which initializes variables in GraphKeys.VARIABLES and GraphKeys.LOCAL_VARIABLE collections, respectively.
        #init_op = tf.group(tf.global_variables_initializer(),
        #                   tf.local_variables_initializer())
        #[var for var in tf.all_variables() if var.op.name.startswith(restore_scope)] will be the same as tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=restore_scope)

        #sess.run(init_op)

        #e.g. with an image model: build the image graph and reload it for the first training run; afterwards the same checkpoint contains all variables and a plain restore is enough
        #For fine-tuning from another model, initialize via init_fn below
        if init_fn is not None:
            init_fn(sess)

    if save_interval_epochs and num_steps_per_epoch:
        epoch_dir = os.path.join(model_dir, 'epoch')
        gezi.try_mkdir(epoch_dir)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    checkpoint_path = os.path.join(model_dir, 'model.ckpt')

    tf.train.write_graph(sess.graph_def, model_dir, 'train.pbtxt')
    only_one_step = False
    try:
        step = start = pre_step + 1
        fixed_step = fixed_pre_step + 1
        #hack just for save one model after load
        if num_steps < 0 or (num_steps and num_steps < step):
            print('just load and resave then exit', file=sys.stderr)
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step,
                                            num_steps_per_epoch),
                       global_step=step)
            sess.close()
            exit(0)

        if num_epochs < 0:
            only_one_step = True
            print('just run one step', file=sys.stderr)

        early_stop = True  #TODO allow config
        num_bad_epochs = 0
        pre_epoch_eval_loss = 1e20
        best_epoch_eval_loss = 1e20
        num_allowed_bad_epochs = 4  #allow 5 non decrease eval loss epochs  before stop
        while not coord.should_stop():
            stop = train_once_fn(sess,
                                 step,
                                 is_start=(step == start),
                                 fixed_step=fixed_step)
            if only_one_step:
                stop = True
            if save_model and step:
                #step 0 is also saved! actually train one step and save
                if step % save_interval_steps == 0:
                    timer = gezi.Timer('save model step %d to %s' %
                                       (step, checkpoint_path))
                    saver.save(sess,
                               _get_checkpoint_path(checkpoint_path,
                                                    fixed_step,
                                                    num_steps_per_epoch),
                               global_step=step)
                    timer.print()
                #if save_interval_epochs and num_steps_per_epoch and step % (num_steps_per_epoch * save_interval_epochs) == 0:
                #if save_interval_epochs and num_steps_per_epoch and step % num_steps_per_epoch == 0:
                if save_interval_epochs and num_steps_per_epoch and fixed_step % num_steps_per_epoch == 0:
                    #epoch = step // num_steps_per_epoch
                    epoch = fixed_step // num_steps_per_epoch
                    eval_loss = melt.eval_loss()
                    if eval_loss:
                        #['eval_loss:3.2','eal_accuracy:4.3']
                        eval_loss = float(
                            eval_loss.strip('[]').split(',')[0].strip(
                                "'").split(':')[-1])
                        if os.path.exists(
                                os.path.join(epoch_dir, 'best_eval_loss.txt')):
                            with open(
                                    os.path.join(epoch_dir,
                                                 'best_eval_loss.txt')) as f:
                                best_epoch_eval_loss = float(
                                    f.readline().split()[-1].strip())
                        if eval_loss < best_epoch_eval_loss:
                            best_epoch_eval_loss = eval_loss
                            logging.info(
                                'Now best eval loss is epoch %d eval_loss:%f' %
                                (epoch, eval_loss))
                            with open(
                                    os.path.join(epoch_dir,
                                                 'best_eval_loss.txt'),
                                    'w') as f:
                                f.write('%d %d %f\n' %
                                        (epoch, step, best_epoch_eval_loss))
                            best_epoch_saver.save(
                                sess, os.path.join(epoch_dir,
                                                   'model.ckpt-best'))

                        with open(os.path.join(epoch_dir, 'eval_loss.txt'),
                                  'a') as f:
                            f.write('%d %d %f\n' % (epoch, step, eval_loss))
                        if eval_loss >= pre_epoch_eval_loss:
                            num_bad_epochs += 1
                            if num_bad_epochs > num_allowed_bad_epochs:
                                logging.warning(
                                    'Evaluate loss not decrease for last %d epochs'
                                    % (num_allowed_bad_epochs + 1))
                                if not os.path.exists(
                                        os.path.join(epoch_dir,
                                                     'model.ckpt-noimprove')):
                                    best_epoch_saver.save(
                                        sess,
                                        os.path.join(epoch_dir,
                                                     'model.ckpt-noimprove'))
                                ##-------well remove it since
                                #if early_stop:
                                #  stop = True
                        else:
                            num_bad_epochs = 0
                        pre_epoch_eval_loss = eval_loss
                    if step % (num_steps_per_epoch *
                               save_interval_epochs) == 0:
                        epoch_saver.save(sess,
                                         os.path.join(epoch_dir,
                                                      'model.ckpt-%d' % epoch),
                                         global_step=step)
                    #--------do not add step
                    # epoch_saver.save(sess,
                    #        os.path.join(epoch_dir,'model.ckpt-%d'%epoch))
            if stop is True:
                print('Early stop running %d steps' % (step), file=sys.stderr)
                raise tf.errors.OutOfRangeError(
                    None, None, 'Early stop running %d steps' % (step))
            if num_steps and (step + 1) == start + num_steps:
                raise tf.errors.OutOfRangeError(None, None,
                                                'Reached max num steps')
            #max_num_epochs = 1000
            max_num_epochs = num_epochs
            if max_num_epochs and num_steps_per_epoch and step // num_steps_per_epoch >= max_num_epochs:
                raise tf.errors.OutOfRangeError(
                    None, None,
                    'Reached max num epochs of %d' % max_num_epochs)
            step += 1
            fixed_step += 1
    except tf.errors.OutOfRangeError as e:
        if not (step
                == start) and save_model and step % save_interval_steps != 0:
            saver.save(sess,
                       _get_checkpoint_path(checkpoint_path, step,
                                            num_steps_per_epoch),
                       global_step=step)
        if only_one_step:
            print('Done one step', file=sys.stderr)
            exit(0)
        if metric_eval_fn is not None:
            metric_eval_fn()
        if (num_epochs and step / num_steps_per_epoch >= num_epochs) or (
                num_steps and (step + 1) == start + num_steps):
            print('Done training for %.3f epochs, %d steps.' %
                  (step / num_steps_per_epoch, step + 1),
                  file=sys.stderr)
            #FIXME because coord.join seems not to work: RuntimeError: Coordinator stopped with threads still running: Thread-9
            exit(0)
        else:
            print('Should not stop, but stopped at epoch: %.3f' %
                  (step / num_steps_per_epoch),
                  file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            raise e
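The project-specific melt.get_checkpoint_varnames call above amounts to listing the variable names stored in a checkpoint; with stock TensorFlow the same restore set can be built with tf.train.list_variables. A hedged sketch of that fallback:

import tensorflow as tf
import tensorflow.contrib.slim as slim

def variables_to_restore_from(model_dir):
    """Restore only variables whose names appear in the latest checkpoint of model_dir."""
    ckpt = tf.train.latest_checkpoint(model_dir)
    names_in_ckpt = [name for name, _ in tf.train.list_variables(ckpt)]
    # include= matches by name prefix, which also covers exact names.
    return slim.get_variables_to_restore(include=names_in_ckpt)

# loader = tf.train.Saver(var_list=variables_to_restore_from('./model'))
# loader.restore(sess, tf.train.latest_checkpoint('./model'))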
Code example #27
0
def main(dataset_dir, log_dir, tfrecord_filename, convlayer):
    
    plt.style.use('ggplot')
    image_size=299
    img_size=image_size*image_size*3

    file_pattern=tfrecord_filename + '_%s_*.tfrecord'
    #State the batch_size to evaluate each time, which can be a lot more than the training batch
    batch_size = 36

    #State the number of epochs to evaluate
    num_epochs = 1

    #Get the latest checkpoint file
    checkpoint_file = tf.train.latest_checkpoint(log_dir)

    #Just construct the graph from scratch again
    
    with tf.Graph().as_default() as graph:
        tf.logging.set_verbosity(tf.logging.INFO)
        #Get the dataset first and load one batch of validation images and labels tensors. Set is_training as False so as to use the evaluation preprocessing
        dataset = get_split('train', dataset_dir, file_pattern, tfrecord_filename)
        images, raw_images, labels = load_batch(dataset, batch_size = batch_size, height=image_size, width=image_size, is_training = False)

        #Create some information about the training steps
        x = tf.placeholder(tf.float32, shape=[None, img_size], name='x')
        #Now create the inference model but set is_training=False
        with slim.arg_scope(inception_resnet_v2_arg_scope()):
            logits, end_points = inception_resnet_v2(images, num_classes = dataset.num_classes, is_training = False)

        #Get all the variables to restore from the checkpoint file and create the saver function to restore
        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)
        def restore_fn(sess):
            return saver.restore(sess, checkpoint_file)

        #Just define the metrics to track without the loss or whatsoever
        conv2dx = end_points[convlayer]
        predictions = tf.argmax(end_points['Predictions'], 1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, labels)
        
        def plot_conv_layer(layer_name, images):
            # Create a feed-dict containing just one image.
            # Calculate and retrieve the output values of the layer
            # when inputting that image.
            for j in range(5):
                image = images[j]
                values = sess.run(layer_name, feed_dict={x:np.reshape([image], [1, img_size], order='F')})
            
                # Number of filters used in the conv. layer.
                num_filters = values.shape[3]
            
                # Number of grids to plot.
                # Rounded-up, square-root of the number of filters.
                grids = math.ceil(math.sqrt(num_filters))
                
                # Create figure with a grid of sub-plots.
                fig, axes = plt.subplots(grids, grids)
                
                # Plot the output images of all the filters.
                for i, ax in enumerate(axes.flat):
                    # Only plot the images for valid filters.
                    if i<num_filters:
                        # Get the output image of using the i'th filter.
                        # See new_conv_layer() for details on the format
                        # of this 4-dim tensor.
                        img = values[0, :, :, i]
            
                        # Plot image.
                        ax.imshow(img, interpolation='nearest')
                    
                    # Remove ticks from the plot.
                    ax.set_xticks([])
                    ax.set_yticks([])
                
                # Ensure the plot is shown correctly with multiple plots
                # in a single Notebook cell.
                plt.show()
        def plot_sample_images(images, labels):
            # Plot the batch of images with their ground-truth labels in a square grid.
            
            grids = math.ceil(math.sqrt(batch_size))
            fig, axes = plt.subplots(grids, grids)
            fig.subplots_adjust(hspace=0.50, wspace=0.2, top=0.97, bottom=0.06)
			        
            for i, ax in enumerate(axes.flat):
                label_name = dataset.labels_to_name[labels[i]]
                # Plot image.
                ax.imshow(images[i])
                xlabel = 'GroundTruth: ' + label_name
                # Show the classes as the label on the x-axis.
                ax.set_xlabel(xlabel)
                
                # Remove ticks from the plot.
                ax.set_xticks([])
                ax.set_yticks([])
            
            # Ensure the plot is shown correctly with multiple plots
            # in a single Notebook cell.
            plt.show()

         #Get your supervisor
        sv = tf.train.Supervisor(logdir =  None, summary_op = None, saver = None, init_fn = restore_fn)

        #Now we are ready to run in one session
        with sv.managed_session() as sess:

            #Now we want to visualize the last batch's images just to see what our model has predicted
            raw_images, labels, predictions = sess.run([raw_images, labels, predictions])       
            plot_conv_layer(conv2dx, raw_images)
            
            logging.info('Model Visualisation completed!')
Code example #28
0
def train_fusion():
    with tf.Graph().as_default() as g:
        flow_image = tf.placeholder(
            tf.float32, [None, IMG_HEIGHT, IMG_WIDTH, IMG_FLOW_CHANNEL],
            name='flow_image')
        label = tf.placeholder(tf.int32, [None, args.class_number],
                               name='label')
        is_training = tf.placeholder(tf.bool)
        flow_logits = two_stream_model('None', flow_image, args.network,
                                       args.class_number, args.keep_prob,
                                       args.batch_size, FRAMES_PER_VIDEO,
                                       is_training, 'flow')

        flow_variables = restore_variables()
        flow_restorer = tf.train.Saver(flow_variables)
        #vgg_16 first layer of the flow model
        if args.network == 'vgg_16':
            fiw_variables = slim.get_variables_to_restore(
                include=['flow_model/vgg_16/conv1/conv1_1'])
        elif args.network == 'resnet_v1_50':
            fiw_variables = slim.get_variables_to_restore(
                include=['flow_model/resnet_v1_50/conv1/weights'])
        flow_input_weights_restorer = tf.train.Saver(fiw_variables)
        #Loss
        with tf.name_scope('loss'):
            flow_loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=flow_logits,
                                                        labels=label))
            tf.summary.scalar('flow_loss', flow_loss)
        #Accuracy
        with tf.name_scope('accuracy'):
            flow_accuracy = tf.reduce_mean(
                tf.cast(
                    tf.equal(tf.argmax(flow_logits, 1), tf.argmax(label, 1)),
                    tf.float32))
            tf.summary.scalar('flow_accuracy', flow_accuracy)

        opt = tf.train.AdamOptimizer(args.lr)
        optimizer = slim.learning.create_train_op(flow_loss, opt)
        saver = tf.train.Saver()

        summary_op = tf.summary.merge_all()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config, graph=g) as sess:
            summary_writer = tf.summary.FileWriter(TRAIN_LOG_DIR, sess.graph)
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            g.finalize()
            if args.network == 'vgg_16':
                flow_restorer.restore(sess, VGG_16_MODEL_DIR)
                flow_input_weights_restorer.restore(sess,
                                                    FLOW_INPUT_WEIGHTS_VGG_16)
            elif args.network == 'resnet_v1_50':
                flow_restorer.restore(sess, RES_v1_50_MODEL_DIR)
                flow_input_weights_restorer.restore(
                    sess, FLOW_INPUT_WEIGHTS_RES_v1_50)

            step = 0
            best_acc = 0.0
            best_ls = 10000.0
            best_val_acc = 0.0
            best_val_ls = 10000.0
            best_epoch = 0
            for epoch in range(args.epoches):
                acc_epoch = 0
                ls_epoch = 0
                batch_index = 0
                for i in range(len(train_video_indices) // args.batch_size):
                    step += 1
                    start_time = time.time()
                    flow_batch_data, batch_index = get_batches(
                        'None', TRAIN_FLOW_PATH, args.batch_size,
                        train_video_indices, batch_index, 'flow')

                    _, ls, acc, summary = sess.run(
                        [optimizer, flow_loss, flow_accuracy, summary_op],
                        feed_dict={
                            flow_image: flow_batch_data['images'],
                            label: flow_batch_data['labels'],
                            is_training: True
                        })
                    ls_epoch += ls
                    acc_epoch += acc

                    if i % 10 == 0:
                        end_time = time.time()
                        print('running time: {}'.format(end_time - start_time))
                        print('Epoch {}, step {}, loss {}, acc {}'.format(
                            epoch + 1, step, ls, acc))
                        summary_writer.add_summary(summary, step)
                num = len(train_video_indices) // args.batch_size
                if best_acc < acc_epoch / num:
                    best_acc = acc_epoch / num
                if best_ls > ls_epoch / num:
                    best_ls = ls_epoch / num
                print(
                    '=========\n Epoch {}, best acc {}, best ls {}, loss {}, acc {}======'
                    .format(epoch + 1, best_acc, best_ls, ls_epoch / num,
                            acc_epoch / num))
                #validation
                ls_epoch = 0
                acc_epoch = 0
                batch_index = 0
                v_step = 0
                for i in range(
                        len(validation_video_indices) // args.batch_size):
                    v_step += 1
                    flow_batch_data, batch_index = get_batches(
                        'None', VALIDATION_FLOW_PATH, args.batch_size,
                        validation_video_indices, batch_index, 'flow')
                    ls, acc = sess.run(
                        [flow_loss, flow_accuracy],
                        feed_dict={
                            flow_image: flow_batch_data['images'],
                            label: flow_batch_data['labels'],
                            is_training: False
                        })
                    ls_epoch += ls
                    acc_epoch += acc

                if best_val_acc < acc_epoch / v_step:
                    best_val_acc = acc_epoch / v_step
                    best_epoch = epoch
                    saver.save(sess, TRAIN_CHECK_POINT + 'flow_train.ckpt')
                if best_val_ls > ls_epoch / v_step:
                    best_val_ls = ls_epoch / v_step

                print(
                    'Validation best epoch {}, best acc {}, best ls {}, acc {}, ls {}'
                    .format(best_epoch + 1, best_val_acc, best_val_ls,
                            ls_epoch / v_step, acc_epoch / v_step))
Code example #29
0
def train_model(FLAGS):
    batch_size = FLAGS.batch_size

    tfrecords_list = [os.path.join(FLAGS.input_dir, 'train_tfrecords_5')]
    dataset = tf.data.TFRecordDataset(tfrecords_list)
    # TODO:
    dataset = dataset.map(_parse_ucf_features_train)
    dataset = dataset.shuffle(buffer_size=500)
    dataset = dataset.repeat(-1).batch(batch_size)

    iterator = dataset.make_initializable_iterator()
    next_elem = iterator.get_next()
    img_reshape, img_width, img_height, img_channel, img_label, img_reg = next_elem

    # valid
    # dataset_valid = tf.data.TFRecordDataset([os.path.join(FLAGS.input_dir, 'test_tfrecords_5')])
    # TODO:
    dataset_valid = tf.data.TFRecordDataset(
        [os.path.join(FLAGS.input_dir, 'test_tfrecords_5')])
    dataset_valid = dataset_valid.map(_parse_ucf_features_test).shuffle(
        buffer_size=500)
    dataset_valid = dataset_valid.repeat(-1).batch(batch_size)
    iterator_valid = dataset_valid.make_initializable_iterator()
    next_elem_valid = iterator_valid.get_next()
    img_reshape_valid, img_width_valid, img_height_valid, \
    img_channel_valid, img_label_valid, img_reg_valid = next_elem_valid

    # build model
    x_input = tf.placeholder(tf.float32,
                             shape=(batch_size, image_size, image_size, 3))
    y_input = tf.placeholder(tf.int64, shape=(None))
    reg_input = tf.placeholder(tf.float32, shape=(None))
    end_points_train = build_inception_model(x_input,
                                             y_input,
                                             reg_input,
                                             reuse=False,
                                             is_training=True,
                                             FLAGS=FLAGS)

    x_input_valid = tf.placeholder(tf.float32,
                                   shape=(batch_size, image_size, image_size,
                                          3))
    y_input_valid = tf.placeholder(tf.int64, shape=(None))
    reg_input_valid = tf.placeholder(tf.float32, shape=(None))
    end_point_test = build_inception_model(x_input_valid,
                                           y_input_valid,
                                           reg_input_valid,
                                           reuse=True,
                                           is_training=False,
                                           FLAGS=FLAGS)

    ## TODO: must be defined before the train op is created
    # https://github.com/tensorflow/tensorflow/issues/7244
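    # Collecting the restore list here matters: once the Adam optimizer below
    # creates its slot variables, they would also appear in
    # slim.get_variables_to_restore() and break restoring from the
    # InceptionV4 checkpoint (see the issue linked above).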
    variables = slim.get_variables_to_restore()
    variables_to_restore = [
        v for v in variables if v.name.split('/')[0] == 'InceptionV4'
    ]

    ## train op
    global_step = tf.train.get_or_create_global_step()
    inc_global_step = tf.assign(global_step, global_step + 1)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                               global_step,
                                               FLAGS.decay_step,
                                               FLAGS.decay_rate,
                                               staircase=True)

    trainable_variables = []
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_6a'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_6b'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_6c'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_6d'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_6e'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_6f'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_6g'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_6h'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_7a'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_7b'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_7c'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Mixed_7d'))
    trainable_variables.extend(
        tf.trainable_variables(scope='InceptionV4/Logits'))
    trainable_variables.extend(tf.trainable_variables(scope='Beauty'))

    print('trainable_variables:')
    print(trainable_variables)
    # TODO: * use 'slim.learning.create_train_op' instead of 'optimizer.minimize'
    # and add update_ops
    # https://blog.csdn.net/qq_25737169/article/details/79616671
    # train_op = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-08, use_locking=False).minimize(
    #     end_points_train['cost'])
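    # Note: slim.learning.create_train_op already collects the UPDATE_OPS
    # collection (batch-norm moving-average updates) by default, so fetching
    # update_ops separately below is redundant but harmless.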
    optimizer = tf.train.AdamOptimizer(learning_rate,
                                       beta1=0.9,
                                       beta2=0.999,
                                       epsilon=1e-08,
                                       use_locking=False)
    train_op = slim.learning.create_train_op(end_points_train['cost'],
                                             optimizer, global_step)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    ## summary op
    cost_summary = tf.summary.scalar('cost', end_points_train['cost'])
    acc_summary = tf.summary.scalar('acc', end_points_train['acc'])
    learning_rate_summary = tf.summary.scalar('learning_rate', learning_rate)
    cost_rmse_summary = tf.summary.scalar('cost_mse',
                                          end_points_train['cost_rmse'])
    cost_entropy_summary = tf.summary.scalar('cost_entropy',
                                             end_points_train['cost_entropy'])
    L2_summary = tf.summary.scalar('L2', end_points_train['L2'])
    rmse_valid_summary = tf.summary.scalar('mse_valid',
                                           end_point_test['cost_rmse'])
    image_summary = tf.summary.image('input_img', img_reshape)

    regress_conn_summary = tf.summary.tensor_summary(
        'regress_conn', end_points_train['regress_conn'])
    y_input_summary = tf.summary.tensor_summary('y_input', y_input)
    regress_label_summary = tf.summary.tensor_summary(
        'regress_label', end_points_train['regress_label'])

    ## tf.summary.merge_all is deprecated
    # summary_op = tf.summary.merge_all()
    summary_op = tf.summary.merge([
        cost_summary, learning_rate_summary, cost_entropy_summary, acc_summary,
        cost_rmse_summary, image_summary, L2_summary, regress_conn_summary,
        y_input_summary, regress_label_summary
    ])

    saver = tf.train.Saver(variables_to_restore)
    saver_all = tf.train.Saver(max_to_keep=1)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    if not os.path.exists(FLAGS.save_model_dir):
        os.mkdir(FLAGS.save_model_dir)
    if not os.path.exists(FLAGS.summary_dir):
        os.mkdir(FLAGS.summary_dir)
    log_path = os.path.join(FLAGS.summary_dir, 'result.log')
    with tf.Session(config=config) as sess, open(log_path, 'w') as writer:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(iterator.initializer)
        sess.run(iterator_valid.initializer)
        ## tf.train.SummaryWriter is deprecated
        summary_writer = tf.summary.FileWriter(FLAGS.summary_dir,
                                               graph=sess.graph)

        if FLAGS.checkpoint is not None:
            saver.restore(sess, FLAGS.checkpoint)

        for step in range(FLAGS.max_iter):
            start_time = time.time()
            fetches = {
                'train_op': train_op,
                'global_step': global_step,
                'inc_global_step': inc_global_step,
                'update_ops': update_ops
            }

            if (step + 1) % FLAGS.print_info_freq == 0 or step == 0:
                fetches['cost'] = end_points_train['cost']
                fetches['acc'] = end_points_train['acc']
                fetches['regress_conn'] = end_points_train['regress_conn']
                fetches['predict_conn'] = end_points_train['predict_conn']
                fetches['regress_label'] = end_points_train['regress_label']
                fetches['predict_softmax'] = end_points_train[
                    'predict_softmax']
                fetches['beauty_weight'] = end_points_train['beauty_weight']
                fetches['cost_rmse'] = end_points_train['cost_rmse']
                fetches['cost_entropy'] = end_points_train['cost_entropy']
                fetches['learning_rate'] = learning_rate
                fetches['L2'] = end_points_train['L2']
                fetches['predict'] = end_points_train['predict']

            if (step + 1) % FLAGS.summary_freq == 0:
                fetches['summary_op'] = summary_op

            img_reshape_val, img_width_val, img_height_val, \
            img_channel_val, img_label_val, img_reg_val = \
                sess.run([img_reshape, img_width, img_height,
                          img_channel, img_label, img_reg])
            # print("shape of img_reshape_val:{}, shape of img_label_val:{}".format(
            #     img_reshape_val.shape, img_label_val.shape
            # ))

            result = sess.run(fetches,
                              feed_dict={
                                  x_input: img_reshape_val,
                                  y_input: img_label_val,
                                  reg_input: img_reg_val
                              })

            if (step + 1) % FLAGS.save_model_freq == 0:
                print("save model")
                saver_all.save(sess,
                               os.path.join(FLAGS.save_model_dir, 'model'),
                               global_step=global_step)

            if (step + 1) % FLAGS.summary_freq == 0:
                summary_writer.add_summary(result['summary_op'],
                                           result['global_step'])

            if (step + 1) % FLAGS.print_info_freq == 0 or step == 0:
                epoch = math.ceil(result['global_step'] * 1.0 /
                                  FLAGS.print_info_freq)
                rate = FLAGS.batch_size / (time.time() - start_time)
                print("epoch:{}\t, rate:{:.2f} image/sec".format(epoch, rate))
                print("global step:{}".format(result['global_step']))
                print("cost:{:.6f}".format(result['cost']))
                print("acc:{:.4f}".format(result['acc']))
                print("cost entropy:{:.6f}".format(result['cost_entropy']))
                print("cost rmse:{:.6f}".format(result['cost_rmse']))
                print("np cost rmse:{:.6f}".format(
                    np.mean(np.square(result['regress_label'].flat))))
                test = np.subtract(
                    result['regress_conn'],
                    np.reshape(img_label_val, newshape=(batch_size, 1)))
                # print("regression_conn - img_label_val:{}".format(test.flat[:]))
                # print("rmse sum:{:.6f}, length of img_label_val:{}".
                #       format(np.sum(np.abs(test)), len(img_label_val)))
                print("shape of predict_conn:{}".format(
                    result["predict_conn"].shape))
                print("shape of regression_conn:{}".format(
                    result["regress_conn"].shape))
                print("shape of label:{}".format(img_label_val.shape))
                print("L2:{:.6f}".format(result['L2']))
                print("learning rate:{:.10f}".format(result['learning_rate']))
                print("regress_conn:{}".format(
                    result["regress_conn"].flat[0:32]))
                print("regression_conn - img_label_val tf:{}".format(
                    result['regress_label'].flat[0:32]))

                print("label:{}".format(img_label_val[0:32]))
                print("reg_label:{}".format(img_reg_val[0:32]))
                print("predict:{}".format(result['predict'][0:32]))
                print("predict_softmax:{}".format(
                    result['predict_softmax'].flat[0:20]))
                # print("shape of predict_softmax:{}, shape of beauty weight:{}".format(
                #     result['predict_softmax'].shape, result['beauty_weight'].shape
                # ))
                print("summary:{}".format(FLAGS.summary_dir))
                print("")

            if (step % 500 == 0) or (step >= 3000 and
                                     (step + 1) % FLAGS.valid_freq == 0):
                batch_num = int(1100 / batch_size)
                accuracy_average = 0
                rmse_average = 0
                pearson_average = 0
                for i_valid in range(batch_num):
                    img_reshape_a, img_width_a, img_height_a, \
                    img_channel_a, img_label_a, img_reg_a = \
                        sess.run([img_reshape_valid, img_width_valid,
                                  img_height_valid, img_channel_valid,
                                  img_label_valid, img_reg_valid])

                    accuracy, rmse, predict_valid, regress_conn_valid, \
                    summary_str, global_step_val, predict_conn_valid = sess.run(
                        [end_point_test['acc'], end_point_test['cost_rmse'],
                         end_point_test['predict'], end_point_test['regress_conn'],
                         rmse_valid_summary, global_step, end_point_test['predict_conn']],
                        feed_dict={
                            x_input_valid: img_reshape_a,
                            y_input_valid: img_label_a,
                            reg_input_valid: img_reg_a
                        })
                    summary_writer.add_summary(summary_str,
                                               global_step_val + i_valid)

                    # TODO
                    pearson_val = pearsonr(img_reg_a.flat[:],
                                           regress_conn_valid.flat[:])[0]
                    if i_valid == 0:
                        print('predict:{}'.format(predict_valid))
                        print('label:{}'.format(img_label_a))
                        print('predict_conn_valid:{}'.format(
                            predict_conn_valid.flat[0:20]))
                        # print('reg_label:{}'.format(img_reg_a))
                        print('regress_conn:{}'.format(
                            regress_conn_valid.flat[:]))
                    print(
                        'valid acc:{:.4f}, valid rmse:{:.4f}, pearson cor:{:.3f}'
                        .format(accuracy, rmse, pearson_val))
                    accuracy_average += accuracy
                    rmse_average += rmse
                    pearson_average += pearson_val
                accuracy_average /= batch_num
                rmse_average /= batch_num
                pearson_average /= batch_num
                print('valid av_acc:{:.4f}, av_rmse:{:.4f}, av_pearson:{:.4f}'.
                      format(accuracy_average, rmse_average, pearson_average))
                writer.write(
                    'iter{}, valid av_acc:{:.4f}, av_rmse:{:.4f}, av_pearson:{:.4f}\n'
                    .format(step, accuracy_average, rmse_average,
                            pearson_average))
                writer.flush()

                if pearson_average >= 0.91:
                    shutil.copytree(
                        FLAGS.save_model_dir,
                        "{}_{:.5f}_{}".format(FLAGS.save_model_dir,
                                              pearson_average, step))
        summary_writer.close()
コード例 #30
0
video_size = FLAGS.num_frames
total_size = batch_size * video_size

video_data = tf.placeholder(tf.float32, [batch_size, video_size, 224, 224, 3])

batch_video_data = tf.reshape(video_data, [total_size, 224, 224, 3])

# for i in range(2):

pre_logit, epoints = resnet_v2.resnet_v2_50(
    inputs=batch_video_data,
    num_classes=None,
    # reuse = True,
    scope='resnet_v2_50')

orig_vars = slim.get_variables_to_restore()

with tf.variable_scope('post_conv'):
    # pre_logit = tf.reshape(pre_logit, [total_size, 2048])
    embeddings = layers.fully_connected(pre_logit,
                                        1024 if FLAGS.big_embeddings else 10,
                                        activation_fn=None)
    activations = tf.nn.relu(embeddings)
    scores = layers.fully_connected(activations, 1, activation_fn=None)
    scores = tf.reshape(scores, [batch_size, video_size, 1])

post_conv_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='post_conv')


def pl_kl(scores):
コード例 #31
0
def train(model_config):
    print(list_config(model_config))
    train_dataloader = TrainData(model_config)

    graph = Graph(True, model_config, train_dataloader)
    graph.create_model_multigpu()
    print('Built Model Done!')

    if model_config.warm_start:
        ckpt_path = model_config.warm_start
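        # Partial restore: keep only the variables whose name exists in the
        # warm-start checkpoint with a matching shape, so the restore still
        # works when the current architecture differs slightly.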
        var_list = slim.get_variables_to_restore()
        available_vars = {}
        reader = tf.train.NewCheckpointReader(ckpt_path)
        var_dict = {var.op.name: var for var in var_list}
        for var in var_dict:
            if reader.has_tensor(var):
                var_ckpt = reader.get_tensor(var)
                var_cur = var_dict[var]
                if any([
                        var_cur.shape[i] != var_ckpt.shape[i]
                        for i in range(len(var_ckpt.shape))
                ]):
                    print('Variable %s missing due to shape.' % var)
                else:
                    available_vars[var] = var_dict[var]
            else:
                print('Variable %s missing.' % var)

        partial_restore_ckpt = slim.assign_from_checkpoint_fn(
            ckpt_path,
            available_vars,
            ignore_missing_vars=True,
            reshape_variables=False)
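        # assign_from_checkpoint_fn returns a callable; it is applied to the
        # live session inside the MonitoredTrainingSession block below.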

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=model_config.logdir,
            save_checkpoint_secs=model_config.save_model_secs,
            config=get_session_config(),
            hooks=[
                tf.train.CheckpointSaverHook(
                    model_config.logdir,
                    save_secs=model_config.save_model_secs,
                    saver=graph.saver)
            ],
            save_summaries_steps=None,
            save_summaries_secs=None,  # Disable tf.summary
    ) as sess:

        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        if model_config.warm_start:
            partial_restore_ckpt(sess)
            print('Warm start with ckpt %s' % model_config.warm_start)
        else:
            ckpt = tf.train.get_checkpoint_state(model_config.logdir)
            if ckpt:
                print('Loading previous checkpoint from: %s' %
                      model_config.logdir)
                graph.saver.restore(sess, ckpt.model_checkpoint_path)

        if model_config.init_vocab_emb:
            sess.run(graph.vocab_embs_init_fn)
            print('init vocab embedding from %s' % model_config.init_vocab_emb)

        if model_config.init_abbr_emb and hasattr(graph, 'abbr_embs_init_fn'):
            sess.run(graph.abbr_embs_init_fn)
            print('init abbr embedding from %s' % model_config.init_abbr_emb)

        if model_config.init_cui_emb:
            sess.run(graph.sense_embs_init_fn)
            print('init cui embedding from %s' % model_config.init_cui_emb)

        perplexitys = []
        start_time = datetime.now()
        epoch = 0
        previous_step = 0
        previous_step_cui = 0
        while True:
            epoch += 1
            progbar = Progbar(target=train_dataloader.size)
            # Train task
            for _ in range(model_config.task_iter_steps):
                batch_start_time = time.time()
                input_feed, _, targets = get_feed(graph.data_feeds,
                                                  train_dataloader,
                                                  model_config, True)
                # print('\nLoad data, time=%s' % str(time.time()-batch_start_time))
                fetches = [
                    graph.train_op, graph.increment_global_step_task,
                    graph.increment_global_step, graph.global_step_task,
                    graph.perplexity, graph.loss
                ]

                batch_start_time = time.time()
                _, _, _, step, perplexity, loss = sess.run(
                    fetches,
                    input_feed,
                    options=run_options,
                    run_metadata=run_metadata)

                if model_config.progress_bar:
                    print('\nForward and backward, time=%s' %
                          str(time.time() - batch_start_time))

                    # Create the Timeline object, and write it to a json
                    tl = timeline.Timeline(run_metadata.step_stats)
                    ctf = tl.generate_chrome_trace_format()
                    with open('timeline/timeline.json', 'w') as f:
                        f.write(ctf)

                    # if step == 2:
                    #     exit()

                    perplexitys.append(perplexity)
                    progbar.update(current=targets[-1]['line_id'],
                                   values=[('loss', loss),
                                           ('ppl', perplexity)])

                if (step - previous_step) > model_config.model_print_freq:
                    end_time = datetime.now()
                    time_span = end_time - start_time
                    start_time = end_time
                    print(
                        '\nTASK: Perplexity:\t%f at step=%d using %s with loss=%s.'
                        % (perplexity, step, time_span, np.mean(loss)))
                    perplexitys.clear()
                    previous_step = step

                # evaluate after a few steps
                if step and step % 2000 == 0:
                    test.evaluate_and_write_to_disk(
                        sess,
                        graph,
                        model_config,
                        train_dataloader,
                        output_file_path=model_config.logdir +
                        'test_score.csv',
                        epoch=epoch,
                        step=step,
                        loss=loss,
                        perplexity=perplexity)

            # Fine tune CUI
            if model_config.extra_mode:
                for _ in range(model_config.cui_iter_steps):
                    input_feed = get_feed_cui(graph.obj_cui, train_dataloader,
                                              model_config)

                    fetches = [
                        graph.train_op_cui, graph.increment_global_step_cui,
                        graph.increment_global_step, graph.global_step_cui,
                        graph.perplexity_cui, graph.loss_cui
                    ]

                    _, _, _, step, perplexity, loss = sess.run(
                        fetches, input_feed)

                    if (step -
                            previous_step_cui) > model_config.model_print_freq:
                        end_time = datetime.now()
                        time_span = end_time - start_time
                        start_time = end_time
                        print(
                            'CUI: Perplexity:\t%f at step %d using %s with loss:%s.'
                            % (perplexity, step, time_span, np.mean(loss)))
                        perplexitys.clear()
                        previous_step_cui = step
コード例 #32
0
ファイル: train_discoverer.py プロジェクト: asmadotgh/dissect
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str)
    parser.add_argument('--debug', '-d', action='store_true')
    args = parser.parse_args()

    # ============= Load config =============
    config_path = args.config
    config = yaml.load(open(config_path))
    print(config)

    # ============= Experiment Folder=============
    assets_dir = os.path.join(config['log_dir'], config['name'])
    log_dir = os.path.join(assets_dir, 'log')
    ckpt_dir = os.path.join(assets_dir, 'ckpt_dir')
    sample_dir = os.path.join(assets_dir, 'sample')
    test_dir = os.path.join(assets_dir, 'test')
    # make directories if they do not exist
    try:
        os.makedirs(log_dir)
    except:
        pass
    try:
        os.makedirs(ckpt_dir)
    except:
        pass
    try:
        os.makedirs(sample_dir)
    except:
        pass
    try:
        os.makedirs(test_dir)
    except:
        pass

    # ============= Experiment Parameters =============
    ckpt_dir_cls = config['cls_experiment']
    BATCH_SIZE = config['batch_size']
    EPOCHS = config['epochs']
    channels = config['num_channel']
    input_size = config['input_size']
    NUMS_CLASS_cls = config['num_class']
    NUMS_CLASS = config['num_bins']
    target_class = config['target_class']
    lambda_GAN = config['lambda_GAN']
    lambda_cyc = config['lambda_cyc']
    lambda_cls = config['lambda_cls']
    save_summary = int(config['save_summary'])
    save_ckpt = int(config['save_ckpt'])
    ckpt_dir_continue = config['ckpt_dir_continue']
    k_dim = config['k_dim']
    lambda_r = config['lambda_r']
    disentangle = k_dim > 1
    discriminate_every_nth = config['discriminate_every_nth']
    generate_every_nth = config['generate_every_nth']
    dataset = config['dataset']
    if dataset == 'CelebA':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader()
        Discriminator_Ordinal = Discriminator_Ordinal_128
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_128
        Discriminator_Contrastive = Discriminator_Contrastive_128
    elif dataset == 'shapes':
        pretrained_classifier = shapes_classifier
        if args.debug:
            my_data_loader = ShapesLoader(
                dbg_mode=True,
                dbg_size=config['batch_size'],
                dbg_image_label_dict=config['image_label_dict'])
        else:
            my_data_loader = ShapesLoader()
        Discriminator_Ordinal = Discriminator_Ordinal_64
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_64
        Discriminator_Contrastive = Discriminator_Contrastive_64
    elif dataset == 'CelebA64' or dataset == 'dermatology':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
        Discriminator_Ordinal = Discriminator_Ordinal_64
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_64
        Discriminator_Contrastive = Discriminator_Contrastive_64
    elif dataset == 'synthderm':
        pretrained_classifier = celeba_classifier
        my_data_loader = ImageLabelLoader(input_size=64)
        Discriminator_Ordinal = Discriminator_Ordinal_64
        Generator_Encoder_Decoder = Generator_Encoder_Decoder_64
        Discriminator_Contrastive = Discriminator_Contrastive_64

    if ckpt_dir_continue == '':
        continue_train = False
    else:
        ckpt_dir_continue = os.path.join(ckpt_dir_continue, 'ckpt_dir')
        continue_train = True

    global_step = tf.Variable(0,
                              dtype=tf.int32,
                              trainable=False,
                              name='global_step')

    # ============= Data =============
    try:
        categories, file_names_dict = read_data_file(
            config['image_label_dict'])
    except:
        print("Problem in reading input data file : ",
              config['image_label_dict'])
        sys.exit()
    data = np.asarray(list(file_names_dict.keys()))
    print("The classification categories are: ")
    print(categories)
    print('The size of the training set: ', data.shape[0])
    fp = open(os.path.join(log_dir, 'setting.txt'), 'w')
    fp.write('config_file:' + str(config_path) + '\n')
    fp.close()

    # ============= placeholder =============
    x_source = tf.placeholder(tf.float32,
                              [None, input_size, input_size, channels],
                              name='x_source')
    y_s = tf.placeholder(tf.int32, [None, NUMS_CLASS], name='y_s')
    y_source = y_s[:, 0]
    train_phase = tf.placeholder(tf.bool, name='train_phase')

    y_t = tf.placeholder(tf.int32, [None, NUMS_CLASS], name='y_t')
    y_target = y_t[:, 0]

    if disentangle:
        y_regularizer = tf.placeholder(tf.int32, [None], name='y_regularizer')
        y_r = tf.placeholder(tf.float32, [None, k_dim], name='y_r')
        y_r_0 = tf.zeros_like(y_r, name='y_r_0')

    # ============= G & D =============
    G = Generator_Encoder_Decoder(
        "generator")  # with conditional BN, SAGAN: SN here as well
    D = Discriminator_Ordinal("discriminator")  # with SN and projection

    real_source_logits = D(x_source, y_s, NUMS_CLASS, "NO_OPS")
    if disentangle:
        fake_target_img, fake_target_img_embedding = G(
            x_source, y_regularizer * NUMS_CLASS + y_target,
            NUMS_CLASS * k_dim)
        fake_source_img, fake_source_img_embedding = G(
            fake_target_img, y_regularizer * NUMS_CLASS + y_source,
            NUMS_CLASS * k_dim)
        fake_source_recons_img, x_source_img_embedding = G(
            x_source, y_regularizer * NUMS_CLASS + y_source,
            NUMS_CLASS * k_dim)
    else:
        fake_target_img, fake_target_img_embedding = G(x_source, y_target,
                                                       NUMS_CLASS)
        fake_source_img, fake_source_img_embedding = G(fake_target_img,
                                                       y_source, NUMS_CLASS)
        fake_source_recons_img, x_source_img_embedding = G(
            x_source, y_source, NUMS_CLASS)
    fake_target_logits = D(fake_target_img, y_t, NUMS_CLASS, None)

    # ============= pre-trained classifier =============
    real_img_cls_logit_pretrained, real_img_cls_prediction = pretrained_classifier(
        x_source, NUMS_CLASS_cls, reuse=False, name='classifier')
    fake_img_cls_logit_pretrained, fake_img_cls_prediction = pretrained_classifier(
        fake_target_img, NUMS_CLASS_cls, reuse=True)
    real_img_recons_cls_logit_pretrained, real_img_recons_cls_prediction = pretrained_classifier(
        fake_source_img, NUMS_CLASS_cls, reuse=True)

    # ============= pre-trained classifier loss =============
    real_p = tf.cast(y_target, tf.float32) * 1.0 / float(NUMS_CLASS - 1)
    fake_q = fake_img_cls_prediction[:, target_class]
    fake_evaluation = (real_p * safe_log(fake_q)) + (
        (1 - real_p) * safe_log(1 - fake_q))
    fake_evaluation = -tf.reduce_mean(fake_evaluation)
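    # fake_evaluation is a binary cross-entropy between the desired target bin
    # (mapped to [0, 1]) and the pre-trained classifier's score on the
    # generated image; recons_evaluation below plays the same role for the
    # cyclic reconstruction.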

    recons_evaluation = (real_img_cls_prediction[:, target_class] * safe_log(
        real_img_recons_cls_prediction[:, target_class])) + (
            (1 - real_img_cls_prediction[:, target_class]) *
            safe_log(1 - real_img_recons_cls_prediction[:, target_class]))
    recons_evaluation = -tf.reduce_mean(recons_evaluation)

    # ============= regularizer contrastive discriminator loss =============
    if disentangle:
        R = Discriminator_Contrastive("disentangler")

        regularizer_fake_target_v_source_logits = R(
            tf.concat([x_source, fake_target_img], axis=-1), k_dim)
        regularizer_fake_source_v_target_logits = R(
            tf.concat([fake_target_img, fake_source_img], axis=-1), k_dim)
        regularizer_fake_source_v_source_logits = R(
            tf.concat([x_source, fake_source_img], axis=-1), k_dim)
        regularizer_fake_source_recon_v_source_logits = R(
            tf.concat([x_source, fake_source_recons_img], axis=-1), k_dim)

    # ============= Loss =============
    D_loss_GAN, D_acc, D_precision, D_recall = discriminator_loss(
        'hinge', real_source_logits, fake_target_logits)
    G_loss_GAN = generator_loss('hinge', fake_target_logits)
    G_loss_cyc = l1_loss(x_source, fake_source_img)
    G_loss_rec = l1_loss(
        x_source, fake_source_recons_img
    )  #+l2_loss(x_source_img_embedding, fake_source_img_embedding)
    D_loss = (D_loss_GAN * lambda_GAN)
    D_opt = tf.train.AdamOptimizer(2e-4, beta1=0.,
                                   beta2=0.9).minimize(D_loss,
                                                       var_list=D.var_list(),
                                                       global_step=global_step)

    if disentangle:
        R_fake_target_v_source_loss, R_fake_target_v_source_acc = contrastive_regularizer_loss(
            regularizer_fake_target_v_source_logits, y_r)
        R_fake_source_v_target_loss, R_fake_source_v_target_acc = contrastive_regularizer_loss(
            regularizer_fake_source_v_target_logits, y_r)
        R_fake_source_v_source_loss, R_fake_source_v_source_acc = contrastive_regularizer_loss(
            regularizer_fake_source_v_source_logits, y_r_0)
        R_fake_source_recon_v_source_loss, R_fake_source_recon_v_source_acc = contrastive_regularizer_loss(
            regularizer_fake_source_recon_v_source_logits, y_r_0)
        R_loss = R_fake_target_v_source_loss + R_fake_source_v_target_loss + R_fake_source_v_source_loss + R_fake_source_recon_v_source_loss
        R_opt = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
            R_loss * lambda_r, var_list=R.var_list(), global_step=global_step)
        G_loss = (G_loss_GAN * lambda_GAN) + (G_loss_rec * lambda_cyc) + (
            G_loss_cyc * lambda_cyc) + (fake_evaluation * lambda_cls) + (
                recons_evaluation * lambda_cls) + (R_loss * lambda_r)
        G_opt = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
            G_loss,
            var_list=G.var_list() + R.var_list(),
            global_step=global_step)
    else:
        G_loss = (G_loss_GAN * lambda_GAN) + (G_loss_rec * lambda_cyc) + (
            G_loss_cyc * lambda_cyc) + (fake_evaluation * lambda_cls) + (
                recons_evaluation * lambda_cls)
        G_opt = tf.train.AdamOptimizer(2e-4, beta1=0., beta2=0.9).minimize(
            G_loss, var_list=G.var_list(), global_step=global_step)

    # ============= summary =============
    real_img_sum = tf.summary.image('real_img', x_source)
    fake_img_sum = tf.summary.image('fake_target_img', fake_target_img)
    fake_source_img_sum = tf.summary.image('fake_source_img', fake_source_img)
    fake_source_recons_img_sum = tf.summary.image('fake_source_recons_img',
                                                  fake_source_recons_img)

    acc_d = tf.summary.scalar('discriminator/acc_d', D_acc)
    precision_d = tf.summary.scalar('discriminator/precision_d', D_precision)
    recall_d = tf.summary.scalar('discriminator/recall_d', D_recall)
    loss_d_sum = tf.summary.scalar('discriminator/loss_d', D_loss)
    loss_d_GAN_sum = tf.summary.scalar('discriminator/loss_d_GAN', D_loss_GAN)

    loss_g_sum = tf.summary.scalar('generator/loss_g', G_loss)
    loss_g_GAN_sum = tf.summary.scalar('generator/loss_g_GAN', G_loss_GAN)
    loss_g_cyc_sum = tf.summary.scalar('generator/G_loss_cyc', G_loss_cyc)
    G_loss_rec_sum = tf.summary.scalar('generator/G_loss_rec', G_loss_rec)

    evaluation_fake = tf.summary.scalar('generator/fake_evaluation',
                                        fake_evaluation)
    evaluation_recons = tf.summary.scalar('generator/recons_evaluation',
                                          recons_evaluation)
    g_sum = tf.summary.merge([
        loss_g_sum, loss_g_GAN_sum, loss_g_cyc_sum, real_img_sum,
        G_loss_rec_sum, fake_img_sum, fake_source_img_sum,
        fake_source_recons_img_sum, evaluation_fake, evaluation_recons
    ])
    d_sum = tf.summary.merge(
        [loss_d_sum, loss_d_GAN_sum, acc_d, precision_d, recall_d])
    # Disentangler Contrastive Regularizer losses
    if disentangle:
        loss_r_fake_target_v_source = tf.summary.scalar(
            'disentangler/loss_r_fake_target_v_source',
            R_fake_target_v_source_loss)
        loss_r_fake_source_v_target = tf.summary.scalar(
            'disentangler/loss_r_fake_source_v_target',
            R_fake_source_v_target_loss)
        loss_r_fake_source_v_source = tf.summary.scalar(
            'disentangler/loss_r_fake_source_v_source',
            R_fake_source_v_source_loss)
        loss_r_fake_source_recon_v_source = tf.summary.scalar(
            'disentangler/loss_r_fake_source_recon_v_source',
            R_fake_source_recon_v_source_loss)
        loss_r_sum = tf.summary.scalar('disentangler/loss_r', R_loss)

        acc_r_fake_target_v_source = tf.summary.scalar(
            'disentangler/acc_r_fake_target_v_source',
            R_fake_target_v_source_acc)
        acc_r_fake_source_v_target = tf.summary.scalar(
            'disentangler/acc_r_fake_source_v_target',
            R_fake_source_v_target_acc)
        acc_r_fake_source_v_source = tf.summary.scalar(
            'disentangler/acc_r_fake_source_v_source',
            R_fake_source_v_source_acc)
        acc_r_fake_source_recon_v_source = tf.summary.scalar(
            'disentangler/acc_r_fake_source_recon_v_source',
            R_fake_source_recon_v_source_acc)
        r_sum = tf.summary.merge([
            loss_r_sum, loss_r_fake_target_v_source,
            loss_r_fake_source_v_target, loss_r_fake_source_v_source,
            loss_r_fake_source_recon_v_source, acc_r_fake_target_v_source,
            acc_r_fake_source_v_target, acc_r_fake_source_v_source,
            acc_r_fake_source_recon_v_source
        ])

    # ============= session =============
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    writer = tf.summary.FileWriter(log_dir, sess.graph)

    # ============= Checkpoints =============
    if continue_train:
        print(" [*] before training, Load checkpoint ")
        print(" [*] Reading checkpoint...")

        ckpt = tf.train.get_checkpoint_state(ckpt_dir_continue)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            saver.restore(sess, os.path.join(ckpt_dir_continue, ckpt_name))
            print(ckpt_dir_continue, ckpt_name)
            print("Successful checkpoint upload")
        else:
            print("Failed checkpoint load")
    else:
        print(" [!] before training, no need to Load ")

    # ============= load pre-trained classifier checkpoint =============
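    # Restore only the frozen classifier weights from their own checkpoint;
    # the GAN variables were initialized above (or restored from
    # ckpt_dir_continue when continue_train is set).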
    class_vars = [
        var for var in slim.get_variables_to_restore()
        if 'classifier' in var.name
    ]
    name_to_var_map_local = {var.op.name: var for var in class_vars}
    temp_saver = tf.train.Saver(var_list=name_to_var_map_local)
    ckpt = tf.train.get_checkpoint_state(ckpt_dir_cls)
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    temp_saver.restore(sess, os.path.join(ckpt_dir_cls, ckpt_name))
    print("Classifier checkpoint loaded.................")
    print(ckpt_dir_cls, ckpt_name)

    # ============= Training =============
    for e in range(EPOCHS):
        np.random.shuffle(data)
        for i in range(data.shape[0] // BATCH_SIZE):
            if args.debug:
                image_paths = np.array(
                    [str(ind) for ind in my_data_loader.tmp_list])
            else:
                image_paths = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
            img, labels = my_data_loader.load_images_and_labels(
                image_paths,
                image_dir=config['image_dir'],
                n_class=1,
                file_names_dict=file_names_dict,
                num_channel=channels,
                do_center_crop=True)

            labels = labels.ravel()
            target_labels = np.random.randint(0,
                                              high=NUMS_CLASS,
                                              size=BATCH_SIZE)

            identity_ind = labels == target_labels

            labels = convert_ordinal_to_binary(labels, NUMS_CLASS)
            target_labels = convert_ordinal_to_binary(target_labels,
                                                      NUMS_CLASS)

            if disentangle:
                target_disentangle_ind = np.random.randint(0,
                                                           high=k_dim,
                                                           size=BATCH_SIZE)
                target_disentangle_ind_one_hot = np.eye(
                    k_dim)[target_disentangle_ind]
                target_disentangle_ind_one_hot[identity_ind, :] = 0
                my_feed_dict = {
                    y_t: target_labels,
                    x_source: img,
                    train_phase: True,
                    y_s: labels,
                    y_regularizer: target_disentangle_ind,
                    y_r: target_disentangle_ind_one_hot
                }
            else:
                my_feed_dict = {
                    y_t: target_labels,
                    x_source: img,
                    train_phase: True,
                    y_s: labels
                }

            if (i + 1) % discriminate_every_nth == 0:

                _, d_loss, summary_str, counter = sess.run(
                    [D_opt, D_loss, d_sum, global_step],
                    feed_dict=my_feed_dict)
                writer.add_summary(summary_str, counter)

            if (i + 1) % generate_every_nth == 0:
                if disentangle:
                    _, g_loss, g_summary_str, r_loss, r_summary_str, counter = sess.run(
                        [G_opt, G_loss, g_sum, R_loss, r_sum, global_step],
                        feed_dict=my_feed_dict)
                    # _, r_loss, r_summary_str = sess.run([R_opt, R_loss, r_sum], feed_dict=my_feed_dict)
                    writer.add_summary(r_summary_str, counter)
                else:
                    _, g_loss, g_summary_str, counter = sess.run(
                        [G_opt, G_loss, g_sum, global_step],
                        feed_dict=my_feed_dict)
                writer.add_summary(g_summary_str, counter)

            def save_results(sess, step):
                num_seed_imgs = 8
                img, labels = my_data_loader.load_images_and_labels(
                    image_paths[0:num_seed_imgs],
                    image_dir=config['image_dir'],
                    n_class=1,
                    file_names_dict=file_names_dict,
                    num_channel=channels,
                    do_center_crop=True)
                labels = np.repeat(labels, NUMS_CLASS * k_dim, 0)
                labels = labels.ravel()
                labels = convert_ordinal_to_binary(labels, NUMS_CLASS)
                img_repeat = np.repeat(img, NUMS_CLASS * k_dim, 0)

                target_labels = np.asarray([
                    np.asarray(range(NUMS_CLASS))
                    for j in range(num_seed_imgs * k_dim)
                ])
                target_labels = target_labels.ravel()
                identity_ind = labels == target_labels
                target_labels = convert_ordinal_to_binary(
                    target_labels, NUMS_CLASS)

                if disentangle:
                    target_disentangle_ind = np.asarray([
                        np.repeat(np.asarray(range(k_dim)), NUMS_CLASS)
                        for j in range(num_seed_imgs)
                    ])
                    target_disentangle_ind = target_disentangle_ind.ravel()
                    target_disentangle_ind_one_hot = np.eye(
                        k_dim)[target_disentangle_ind]
                    target_disentangle_ind_one_hot[identity_ind, :] = 0
                    my_feed_dict = {
                        y_t: target_labels,
                        x_source: img_repeat,
                        train_phase: False,
                        y_s: labels,
                        y_regularizer: target_disentangle_ind,
                        y_r: target_disentangle_ind_one_hot
                    }
                else:
                    my_feed_dict = {
                        y_t: target_labels,
                        x_source: img_repeat,
                        train_phase: False,
                        y_s: labels
                    }

                FAKE_IMG, fake_logits_ = sess.run(
                    [fake_target_img, fake_target_logits],
                    feed_dict=my_feed_dict)

                output_fake_img = np.reshape(
                    FAKE_IMG,
                    [-1, k_dim, NUMS_CLASS, input_size, input_size, channels])

                # save samples
                sample_file = os.path.join(sample_dir, '%06d.jpg' % step)
                save_images(output_fake_img,
                            sample_file,
                            num_samples=num_seed_imgs,
                            nums_class=NUMS_CLASS,
                            k_dim=k_dim,
                            image_size=input_size,
                            num_channel=channels)
                np.save(sample_file.split('.jpg')[0] + '_y.npy', labels)

            _approx_num_seen_batches = int(counter / 3)
            if _approx_num_seen_batches % save_summary == 0:
                save_results(sess, _approx_num_seen_batches)

            if _approx_num_seen_batches % save_ckpt == 0:
                saver.save(sess,
                           ckpt_dir +
                           "/model%2d.ckpt" % _approx_num_seen_batches,
                           global_step=global_step)
コード例 #33
0
def run_training():
    sess = tf.Session()  # config=tf.ConfigProto(log_device_placement=True))

    # create input path and labels np.array from csv annotations
    df_annos = pd.read_csv(ANNOS_CSV, index_col=0)
    df_annos = df_annos.sample(frac=1).reset_index(
        drop=True)  # shuffle the whole dataset
    if DATA == 'l8':
        path_col = ['l8_vis_jpg']
    elif DATA == 's1':
        path_col = ['s1_vis_jpg']
    elif DATA == 'l8s1':
        path_col = ['l8_vis_jpg', 's1_vis_jpg']

    input_files_train = JPG_DIR + df_annos.loc[df_annos.partition == 'train',
                                               path_col].values
    input_labels_train = df_annos.loc[df_annos.partition == 'train',
                                      'pop_density_log2'].values
    input_files_val = JPG_DIR + df_annos.loc[df_annos.partition == 'val',
                                             path_col].values
    input_labels_val = df_annos.loc[df_annos.partition == 'val',
                                    'pop_density_log2'].values
    input_id_train = df_annos.loc[df_annos.partition == 'train',
                                  'village_id'].values
    input_id_val = df_annos.loc[df_annos.partition == 'val',
                                'village_id'].values

    print('input_files_train shape:', input_files_train.shape)
    train_set_size = len(input_labels_train)

    # data input
    with tf.device('/cpu:0'):
        train_images_batch, train_labels_batch, _ = \
        dataset.input_batches(FLAGS.batch_size, FLAGS.output_size, input_files_train, input_labels_train, input_id_train,
                              IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL, regression=True, augmentation=True, normalization=True)
        val_images_batch, val_labels_batch, _ = \
        dataset.input_batches(FLAGS.batch_size, FLAGS.output_size, input_files_val, input_labels_val, input_id_val,
                              IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL, regression=True, augmentation=False, normalization=True)

    images_l8_placeholder = tf.placeholder(
        tf.float32, shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 3])
    images_s1_placeholder = tf.placeholder(
        tf.float32, shape=[None, IMAGE_HEIGHT, IMAGE_WIDTH, 3])
    vectors_nl_placeholder = tf.placeholder(tf.float32, shape=[None, 25])
    labels_placeholder = tf.placeholder(tf.float32, shape=[
        None,
    ])
    print('finish data input')

    TRAIN_BATCHES_PER_EPOCH = int(
        train_set_size /
        FLAGS.batch_size)  # number of training batches/steps in each epoch
    MAX_STEPS = TRAIN_BATCHES_PER_EPOCH * FLAGS.max_epoch  # total number of training batches/steps

    # CNN forward reference
    if MODEL == 'vgg':
        with slim.arg_scope(
                vgg.vgg_arg_scope(weight_decay=FLAGS.weight_decay)):
            outputs, _ = vgg.vgg_16(images_l8_placeholder,
                                    images_s1_placeholder,
                                    num_classes=FLAGS.output_size,
                                    dropout_keep_prob=FLAGS.dropout_keep,
                                    is_training=True)
            outputs = tf.squeeze(
                outputs
            )  # change shape from (B,1) to (B,), same as label input
    if MODEL == 'resnet':
        with slim.arg_scope(resnet_v1.resnet_arg_scope()):
            outputs, _ = resnet_v1.resnet_v1_152(images_placeholder,
                                                 num_classes=FLAGS.output_size,
                                                 is_training=True)
            outputs = tf.squeeze(
                outputs
            )  # change shape from (B,1) to (B,), same as label input

    # loss
    labels_real = tf.pow(2.0, labels_placeholder)
    outputs_real = tf.pow(2.0, outputs)

    # only loss_log2_mse is used for the gradient calculation; the model minimizes this value
    loss_log2_mse = tf.reduce_mean(tf.squared_difference(
        labels_placeholder, outputs),
                                   name='loss_log2_mse')
    loss_real_rmse = tf.sqrt(tf.reduce_mean(
        tf.squared_difference(labels_real, outputs_real)),
                             name='loss_real_rmse')
    loss_real_mae = tf.losses.absolute_difference(labels_real, outputs_real)

    tf.summary.scalar('loss_log2_mse', loss_log2_mse)
    tf.summary.scalar('loss_real_rmse', loss_real_rmse)
    tf.summary.scalar('loss_real_mae', loss_real_mae)

    # accuracy (R2)
    def r_squared(labels, outputs):
        sst = tf.reduce_sum(
            tf.squared_difference(labels, tf.reduce_mean(labels)))
        sse = tf.reduce_sum(tf.squared_difference(labels, outputs))
        return (1.0 - tf.div(sse, sst))

    r2_log2 = r_squared(labels_placeholder, outputs)
    r2_real = r_squared(labels_real, outputs_real)

    tf.summary.scalar('r2_log2', r2_log2)
    tf.summary.scalar('r2_real', r2_real)

    # determine the model variables to restore from the pre-trained checkpoint
    if MODEL == 'vgg':
        model_variables = slim.get_variables_to_restore(exclude=[
            'vgg_16/fc8', 'vgg_16/conv1', 'vgg_16/fc7/dim_reduce',
            'vgg_16/combine'
        ])
    if MODEL == 'resnet':
        model_variables = slim.get_variables_to_restore(
            exclude=['resnet_v1_152/logits'])  #, 'resnet_v1_152/conv1'])
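    # The excluded scopes are the layers whose shapes were changed for this
    # task (the two-stream conv1 and the regression head in the VGG variant,
    # the logits layer for ResNet), so they cannot take the pre-trained
    # weights directly and are trained from scratch instead.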

    # training step and learning rate
    global_step = tf.Variable(0, name='global_step',
                              trainable=False)  #, dtype=tf.int64)
    learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,  # initial learning rate
        global_step=global_step,  # current step
        decay_steps=MAX_STEPS,  # total number of steps over which to decay
        decay_rate=FLAGS.lr_decay_rate
    )  # final learning rate = FLAGS.learning_rate * decay_rate
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    #     optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
    #     optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)

    # to only update gradient in first and last layer
    #     vars_update = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'vgg_16/(conv1|fc8)')
    #     print('variables to update in traing: ', vars_update)

    train_op = optimizer.minimize(
        loss_log2_mse, global_step=global_step)  #, var_list = vars_update)

    # summary output in tensorboard
    summary = tf.summary.merge_all()
    summary_writer_train = tf.summary.FileWriter(
        os.path.join(LOG_DIR, 'log_train'), sess.graph)
    summary_writer_val = tf.summary.FileWriter(
        os.path.join(LOG_DIR, 'log_val'), sess.graph)

    # variable initialize
    init = tf.global_variables_initializer()
    sess.run(init)

    if RESUME:
        restorer = tf.train.Saver()
        restorer.restore(sess, PRETRAIN_WEIGHTS)
        print('loaded pre-trained weights: ', PRETRAIN_WEIGHTS)
    else:
        ##### restore the model from the pre-trained checkpoint for the new VGG architecture #####

        # restore the weights for the layers that are not modified in the new arch (except conv1, fc8)
        restorer = tf.train.Saver(model_variables)
        restorer.restore(sess, PRETRAIN_WEIGHTS)
        print('loaded pre-trained weights: ', PRETRAIN_WEIGHTS)

        # a fake layer to hold the new variables to restore
        with tf.variable_scope("vgg_16"):
            fake_net = slim.repeat(images_l8_placeholder,
                                   2,
                                   slim.conv2d,
                                   64, [3, 3],
                                   scope='conv1')

        # print out the variables in the fake layer
        dup_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                        'vgg_16/conv1')
        print('duplicated weights to update: ')
        for ww in dup_weights:
            print(ww)

        # get the variables of the fake layer
        with tf.variable_scope("vgg_16/conv1", reuse=True):
            weights1 = tf.get_variable("conv1_1/weights")
            bias1 = tf.get_variable("conv1_1/biases")
            weights2 = tf.get_variable("conv1_2/weights")
            bias2 = tf.get_variable("conv1_2/biases")

        # restore the variables of the fake layer with checkpoint weights
        restorer = tf.train.Saver([weights1, bias1, weights2, bias2])
        restorer.restore(sess, PRETRAIN_WEIGHTS)
        print('loaded pre-trained weights: ', PRETRAIN_WEIGHTS)

        # assign the weights of the fake layer to the true model variables
        with tf.variable_scope("vgg_16", reuse=True):
            assign_ops = [
                tf.assign(tf.get_variable("conv1_l8/conv1_l8_1/weights"),
                          weights1),
                tf.assign(tf.get_variable("conv1_s1/conv1_s1_1/weights"),
                          weights1),
                tf.assign(tf.get_variable("conv1_l8/conv1_l8_2/weights"),
                          weights2),
                tf.assign(tf.get_variable("conv1_s1/conv1_s1_2/weights"),
                          weights2),
                tf.assign(tf.get_variable("conv1_l8/conv1_l8_1/biases"),
                          bias1),
                tf.assign(tf.get_variable("conv1_s1/conv1_s1_1/biases"),
                          bias1),
                tf.assign(tf.get_variable("conv1_l8/conv1_l8_2/biases"),
                          bias2),
                tf.assign(tf.get_variable("conv1_s1/conv1_s1_2/biases"), bias2)
            ]

        sess.run([assign_ops])
        ###########################################################################

    # saver object to save checkpoint during training
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)

    print('start training...')
    epoch = 0
    best_r2 = -float('inf')
    for step in xrange(MAX_STEPS):
        if step % TRAIN_BATCHES_PER_EPOCH == 0:
            epoch += 1

        start_time = time.time()  # record the time used for each batch

        images_out, labels_out = sess.run(
            [train_images_batch,
             train_labels_batch])  # inputs of this batch, numpy array format

        duration_batch = time.time() - start_time

        if step == 0:
            print("finished reading batch data")
            print("images_out shape:", images_out.shape)
        feed_dict = {
            images_l8_placeholder: images_out[:, :, :, :3],
            images_s1_placeholder: images_out[:, :, :, 3:],
            labels_placeholder: labels_out
        }

        _, train_loss, train_accuracy, train_outputs, lr = \
            sess.run([train_op, loss_log2_mse, r2_log2, outputs, learning_rate], feed_dict=feed_dict)

        duration = time.time() - start_time

        if step % 10 == 0 or (
                step + 1) == MAX_STEPS:  # print training loss every 10 batches
            print('Step %d epoch %d lr %.3e: log2 MSE loss = %.4f log2 R2 = %.4f (%.3f sec, %.3f sec(each batch))' \
                  % (step, epoch, lr, train_loss, train_accuracy, duration*10, duration_batch))
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer_train.add_summary(summary_str, step)
            summary_writer_train.flush()

        if step % 50 == 0 or (
                step + 1
        ) == MAX_STEPS:  # calculate and print validation loss every 50 batches
            images_out, labels_out = sess.run(
                [val_images_batch, val_labels_batch])
            feed_dict = {
                images_l8_placeholder: images_out[:, :, :, :3],
                images_s1_placeholder: images_out[:, :, :, 3:],
                labels_placeholder: labels_out
            }

            val_loss, val_accuracy = sess.run([loss_log2_mse, r2_log2],
                                              feed_dict=feed_dict)
            print('Step %d epoch %d: val log2 MSE = %.4f val log2 R2 = %.4f ' %
                  (step, epoch, val_loss, val_accuracy))

            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer_val.add_summary(summary_str, step)
            summary_writer_val.flush()

            # in each epoch, if the validation R2 is higher than best R2, save the checkpoint
            if step % (TRAIN_BATCHES_PER_EPOCH -
                       TRAIN_BATCHES_PER_EPOCH % 50) == 0:
                if val_accuracy > best_r2:
                    best_r2 = val_accuracy
                    checkpoint_file = os.path.join(LOG_DIR, 'model.ckpt')
                    saver.save(sess,
                               checkpoint_file,
                               global_step=step,
                               write_state=True)
コード例 #34
0
def main(argv=None):
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    else:
        if not FLAGS.restore:
            tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
            tf.gfile.MkDir(FLAGS.checkpoint_path)

    input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
    input_score_maps = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_score_maps')
    if FLAGS.geometry == 'RBOX':
        input_geo_maps = tf.placeholder(tf.float32, shape=[None, None, None, 5], name='input_geo_maps')
    else:
        input_geo_maps = tf.placeholder(tf.float32, shape=[None, None, None, 8], name='input_geo_maps')
    input_training_masks = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_training_masks')
    input_labels = tf.placeholder(tf.float32, shape=[None, None, 4, 2], name='input_labels')

    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, decay_steps=10000, decay_rate=0.94, staircase=True)
    # add summary
    tf.summary.scalar('learning_rate', learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)
    # opt = tf.train.MomentumOptimizer(learning_rate, 0.9)

    # split
    input_images_split = tf.split(input_images, len(gpus))
    input_score_maps_split = tf.split(input_score_maps, len(gpus))
    input_geo_maps_split = tf.split(input_geo_maps, len(gpus))
    input_training_masks_split = tf.split(input_training_masks, len(gpus))
    input_labels_split = tf.split(input_labels, len(gpus))
    #x = tf.placeholder(tf.int16, shape=[None, None, 4, 2])
    #y = tf.split(x, len(gpus))

    tower_grads = []
    reuse_variables = None
    for i, gpu_id in enumerate(gpus):
        with tf.device('/gpu:%d' % gpu_id):
            with tf.name_scope('model_%d' % gpu_id) as scope:
                iis = input_images_split[i]
                isms = input_score_maps_split[i]
                igms = input_geo_maps_split[i]
                itms = input_training_masks_split[i]
                il = input_labels_split[i]
                total_loss, model_loss, f_score, f_geometry, f_dat = tower_loss(iis, isms, igms, itms, il, reuse_variables)
                batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope))
                reuse_variables = True
                train_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='feature_fusion')
                grads = opt.compute_gradients(total_loss, var_list=train_var)
                tower_grads.append(grads)
                #stuff = tf.split(x,len(gpus))[i]

    grads = average_gradients(tower_grads)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

    summary_op = tf.summary.merge_all()
    # save moving average
    variable_averages = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    # batch norm updates
    with tf.control_dependencies([variables_averages_op, apply_gradient_op, batch_norm_updates_op]):
        train_op = tf.no_op(name='train_op')
    
    variables = slim.get_variables_to_restore()
    #print variables[0].name.split('/')
    #print variables
    var_list = []
    for v in variables:
        parts = v.name.split('/')
        if len(parts) == 1:
            var_list.append(v)
        elif parts[1] != "myconv1" or v.name.startswith('custom_filter'):
            var_list.append(v)
    #var_list=[v for v in variables if v.name.split('/')[1] != "conv1"]
    saver = tf.train.Saver(var_list)
    #print var_list
    summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path, tf.get_default_graph())
    
    '''
    training_list = ["D0006-0285025", "D0017-1592006", "D0041-5370006", "D0041-5370026", "D0042-1070001", "D0042-1070002", "D0042-1070003", "D0042-1070004", "D0042-1070005", "D0042-1070006", "D0042-1070007", "D0042-1070008", "D0042-1070009", "D0042-1070010", "D0042-1070015", "D0042-1070012", "D0042-1070013", "D0079-0019007", "D0089-5235001"]
    validation_list = ["D0090-5242001", "D0117-5755018", "D0117-5755024", "D0117-5755025", "D0117-5755033"]

    with open('Data/cropped_annotations0.txt', 'r') as f:
            annotation_file = f.readlines()
    val_data0 = []
    val_data1 = []
    train_data0 = []
    train_data1 = []
    labels = []
    trainValTest = 2
    for line in annotation_file:
    	if len(line)>1 and line[:11] == 'cropped_img':
                if (len(labels) > 0):
		    if trainValTest == 0:
			train_data1.append(labels)
		    elif trainValTest == 1: 	
			val_data1.append(labels)
                    labels = []
		    trainValTest = 2
        	if line[12:25] in training_list:
		    file_name = "Data/cropped_img_train/"+line[12:].split(".tiff",1)[0]+".tiff"
		    im = cv2.imread(file_name)[:, :, ::-1]
                    train_data0.append(im.astype(np.float32))
		    trainValTest = 0
		elif line[12:25] in validation_list:
	            file_name = "Data/cropped_img_val/"+line[12:].split(".tiff",1)[0]+".tiff"
                    im = cv2.imread(file_name)[:, :, ::-1]
		    val_data0.append(im.astype(np.float32))
		    trainValTest = 1
        elif trainValTest != 2:
	 	annotation_data = line.split(" ")
                if (len(annotation_data) > 2):
		    x, y = float(annotation_data[0]), float(annotation_data[1])
                    w, h = float(annotation_data[2]), float(annotation_data[3])
                    labels.append([[int(x),int(y-h)],[int(x+w),int(y-h)],[int(x+w),int(y)],[int(x),int(y)]])
    if trainValTest == 0:
	train_data1.append(labels)
    elif trainValTest == 1:
	val_data1.append(labels)
    '''  
    init = tf.global_variables_initializer()
    
    if FLAGS.pretrained_model_path is not None:
        print('Will restore weights from pretrained model {}'.format(FLAGS.pretrained_model_path))
        variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path, slim.get_trainable_variables(),
                                                             ignore_missing_vars=True)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        #reader = tf.train.NewCheckpointReader("./"+FLAGS.checkpoint_path)
        # Run the initializer first so that variables excluded from the restore
        # list (e.g. the custom first conv) still get values; restoring a
        # checkpoint afterwards then overwrites only the restorable subset.
        sess.run(init)
        if FLAGS.restore:
            ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
            model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
            print('Continuing training from previous checkpoint {}'.format(model_path))
            saver.restore(sess, model_path)
        elif FLAGS.pretrained_model_path is not None:
            variable_restore_op(sess)
        variables_names = [v.name for v in tf.trainable_variables()]
        #print(variables_names)
        training_list = ["D0006-0285025", "D0017-1592006", "D0041-5370006", "D0041-5370026", "D0042-1070001", "D0042-1070002", "D0042-1070003", "D0042-1070004", "D0042-1070005", "D0042-1070006", "D0042-1070007", "D0042-1070008", "D0042-1070009", "D0042-1070010", "D0042-1070015", "D0042-1070012", "D0042-1070013", "D0079-0019007", "D0089-5235001"]


        a = FLAGS.checkpoint_path[-2]
        data_size = 0

        with open('Data/cropped_annotations.txt', 'r') as f:
            annotation_file = f.readlines()
        for line in annotation_file:
            if len(line) > 1 and line[:13] == './cropped_img' and line[14:27] in training_list:
                data_size += 1
        print("Char model: " + a)
        print("Reg constant: " + str(reg_constant))
        print("Data size: " + str(data_size))
        epoche_size = 3  # data_size / 32
        print("This many steps per epoch: " + str(epoche_size))
        data_generator = icdar.get_batch(num_workers=FLAGS.num_readers, q_size=10,
                                         input_size=FLAGS.input_size,
                                         batch_size=FLAGS.batch_size_per_gpu * len(gpus), data_path=a, trainOrVal="train")
        #print "getting the data batches"
        val_data_generator = icdar.get_batch(num_workers=FLAGS.num_readers, q_size=10,
                                             input_size=FLAGS.input_size,
                                             batch_size=FLAGS.batch_size_per_gpu * len(gpus), data_path=a, trainOrVal="val")
        start = time.time()
        epochsA, ml_list, tl_list = [], [], []
        epochsB, train_fscore, val_fscore = [], [], []
        #print("entering model training")
        for step in range(FLAGS.max_steps):
            data = next(data_generator)
            #val_data = next(val_data_generator)

            if (step % epoche_size == 100):
                #print 'Epochs {:.4f}, ml {:.4f}, tl {:.4f}'.format(float(step)/epoche_size, ml, tl)
                '''
		train_size = len(train_data0)
                TP, FP, FN = 0.0, 0.0, 0.0
                for i in range(train_size / 128):
                    score, geometry = sess.run([f_score, f_geometry], feed_dict={input_images: train_data0[128*i: 128*(i+1)]})
                    labels = sess.run(stuff, feed_dict = {x: train_data1[128*i:128*(i+1)]})
                    TP0, FP0, FN0 = evalEAST.evaluate(score, geometry, labels)
                    TP += TP0
                    FP += FP0
                    FN += FN0
                p_train, r_train = TP / (TP + FP), TP / (TP + FN)
                fscore_train = 2 * p_train * r_train / (p_train + r_train)
		'''
                #for i in range(len(data[0])):
		#    count_right_cache = 0
                #score, geometry = sess.run([f_score, f_geometry], feed_dict={input_images: data[0]})
		#p_train, r_train, fscore_train = evalEAST.evaluate(score, geometry, data[5])
		#score, geometry = sess.run([f_score, f_geometry], feed_dict={input_images: val_data[0]})
                #p_val, r_val, fscore_val = evalEAST.evaluate(score, geometry, val_data[1])
                '''
                for i in range(len(score)):
		    count_right_cache = 0
		    print score[i].shape, geometry[i].shape
	            boxes = detect(score_map=score[i], geo_map=geometry[i])
                    if boxes is not None:
                        boxes = boxes[:, :8].reshape((-1, 4, 2))
                        for box in boxes:
                            box = sort_poly(box.astype(np.int32))
                            if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3]-box[0]) < 5:
                                continue
                            count_wrong += 1
                            num_true_pos = len(data[5][i])
                            for i2 in range(num_true_pos):
                                #print box
                                #print label[i][i2]
                                if (checkIOU(box, label[i][i2]) == True):
                                    count_right_cache += 1
                                    count_wrong -= 1
                    count_posNotDetected += num_true_pos - count_right_cache
                    count_right += count_right_cache
                p_train = (float) (count_right) / (float) (count_right + count_wrong)  # TP / TP + FP
                r_train = (float) (count_right) / (float) (count_right + count_posNotDetected)  # TP / TP + FN
                fscore_train = 2 * (p_train * r_train) / (p_train + r_train)
		print "hi"
	
		score, geometry = sess.run([f_score, f_geometry], feed_dict={input_images: val_data[0]})
                for i in range(len(score)):
                    count_right_cache = 0
                    #score, geometry = sess.run([f_score, f_geometry], feed_dict={input_images: val_data[0][i]})
                    boxes = detect(score_map=score[i], geo_map=geometry[i])
                    if boxes is not None:
                        boxes = boxes[:, :8].reshape((-1, 4, 2))
                        for box in boxes:
                            box = sort_poly(box.astype(np.int32))
                            if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3]-box[0]) < 5:
                                continue
                            count_wrong += 1
                            num_true_pos = len(val_data[1][i])
                            for i2 in range(num_true_pos):
                                #print box
                                #print label[i][i2]
                                if (checkIOU(box, label[i][i2]) == True):
                                    count_right_cache += 1
                                    count_wrong -= 1
                    count_posNotDetected += num_true_pos - count_right_cache
                    count_right += count_right_cache
                p_val = (float) (count_right) / (float) (count_right + count_wrong)  # TP / TP + FP
                r_val = (float) (count_right) / (float) (count_right + count_posNotDetected)  # TP / TP + FN
                fscore_val = 2 * (p_val * r_val) / (p_val + r_val)
                #return precision, recall, fscore
		'''
		#    score, geometry = sess.run([f_score, f_geometry], feed_dict={input_images: data[0][i]})
		#    fscore_train, p_train, r_train = evalEAST.evaluate(score, geometry, data[5][i])
		#    score, geometry = sess.run([f_score, f_geometry], feed_dict={input_images: val_data[0]})
                #    fscore_val, p_val, r_val = evalEAST.evaluate(score, geometry, val_data[1])

                print('Epochs {:.4f}, train fscore {:.4f}, train p {:.4f}, train r {:.4f}, val fscore {:.4f}, val p {:.4f}, val r {:.4f}'.format(float(step)/epoche_size, fscore_train, p_train, r_train, fscore_val, p_val, r_val))

            ml, tl, _ = sess.run([model_loss, total_loss, train_op], feed_dict={input_images: data[0],
                                                                                input_score_maps: data[2],
                                                                                input_geo_maps: data[3],
                                                                                input_training_masks: data[4]})
            print('ml {:.4f}, tl {:.4f}'.format(ml, tl))
            if step % epoche_size == 0:
                print('Epochs {:.4f}, ml {:.4f}, tl {:.4f}'.format(float(step)/epoche_size, ml, tl))
	        #score2, geometry2, dat2 = sess.run([f_score, f_geometry, f_dat], feed_dict={input_images: data[0], input_labels: abc})
                #p_train, r_train, fscore_train = evalEAST.evaluate(score2, geometry2, dat2)
		#print ".."
                #score2, geometry2 = sess.run([f_score, f_geometry], feed_dict={input_images: val_data[0]})
                #p_val, r_val, fscore_val = evalEAST.evaluate(score2, geometry2, val_data[5])
		#print 'Train fscore {:.4f}, train p {:.4f}, train r {:.4f}, val fscore {:.4f}, val p {:.4f}, val r {:.4f}'.format(fscore_train, p_train, r_train, fscore_val, p_val, r_val) 
            
            if np.isnan(tl):
                print('Loss diverged, stop training')
                break
                       
            if step % epoche_size == 0:  # FLAGS.save_summary_steps == 0:
                saver.save(sess, FLAGS.checkpoint_path + 'model.ckpt', global_step=global_step)
                _, tl, summary_str = sess.run([train_op, total_loss, summary_op], feed_dict={input_images: data[0],
                                                                                             input_score_maps: data[2],
                                                                                             input_geo_maps: data[3],
                                                                                             input_training_masks: data[4]})
                summary_writer.add_summary(summary_str, global_step=step)
コード例 #35
0
def create_model(config, sess, ensemble_scope=None, train=False):
    logging.info('Building model...')
    model = StandardModel(config)

    # Is this model part of an ensemble?
    if ensemble_scope == None:
        # No: it's a standalone model, so the saved variable names should match
        # model's and we don't need to map them.
        saver = tf.train.Saver(max_to_keep=None)
    else:
        # Yes: there is an active model-specific scope, so tell the Saver to
        # map the saved variables to the scoped variables.
        variables = slim.get_variables_to_restore()
        var_map = {}
        for v in variables:
            if v.name.startswith(ensemble_scope):
                base_name = v.name[len(ensemble_scope):].split(':')[0]
                var_map[base_name] = v
        saver = tf.train.Saver(var_map, max_to_keep=None)

    # compute reload model filename
    reload_filename = None
    if config.reload == 'latest_checkpoint':
        checkpoint_dir = os.path.dirname(config.saveto)
        reload_filename = tf.train.latest_checkpoint(checkpoint_dir)
        if reload_filename != None:
            if (os.path.basename(reload_filename).rsplit('-', 1)[0] !=
                    os.path.basename(config.saveto)):
                logging.error(
                    "Mismatching model filename found in the same directory while reloading from the latest checkpoint"
                )
                sys.exit(1)
            logging.info('Latest checkpoint found in directory ' +
                         os.path.abspath(checkpoint_dir))
    elif config.reload != None:
        reload_filename = config.reload
    if (reload_filename == None) and (config.prior_model != None):
        logging.info('Initializing model parameters from prior')
        reload_filename = config.prior_model

    # initialize or reload training progress
    if train:
        progress = training_progress.TrainingProgress()
        progress.bad_counter = 0
        progress.uidx = 0
        progress.eidx = 0
        progress.estop = False
        progress.history_errs = []
        if reload_filename and config.reload_training_progress:
            path = reload_filename + '.progress.json'
            if os.path.exists(path):
                logging.info('Reloading training progress')
                progress.load_from_json(path)
                if (progress.estop == True or progress.eidx > config.max_epochs
                        or progress.uidx >= config.finish_after):
                    logging.warning(
                        'Training is already complete. Disable reloading of training progress (--no_reload_training_progress) or remove or modify progress file (%s) to train anyway.'
                        % path)
                    sys.exit(0)

    # load prior model
    if train and config.prior_model != None:
        load_prior(config, sess, saver)

    # initialize or restore model
    if reload_filename == None:
        logging.info('Initializing model parameters from scratch...')
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
    else:
        logging.info('Loading model parameters from file ' +
                     os.path.abspath(reload_filename))
        saver.restore(sess, os.path.abspath(reload_filename))
        if train:
            # The global step is currently recorded in two places:
            #   1. model.t, a tf.Variable read and updated by the optimizer
            #   2. progress.uidx, a Python integer updated by train()
            # We reset model.t to the value recorded in progress to allow the
            # value to be controlled by the user (either explicitly by
            # configuring the value in the progress file or implicitly by using
            # --no_reload_training_progress).
            model.reset_global_step(progress.uidx, sess)

    logging.info('Done')

    if train:
        return model, saver, progress
    else:
        return model, saver
コード例 #36
0
ファイル: geonet_main.py プロジェクト: yang330624/GeoNet
def train():

    seed = 8964
    tf.set_random_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    if not os.path.exists(opt.checkpoint_dir):
        os.makedirs(opt.checkpoint_dir)

    with tf.Graph().as_default():
        # Data Loader
        loader = DataLoader(opt)
        tgt_image, src_image_stack, intrinsics = loader.load_train_batch()

        # Build Model
        model = GeoNetModel(opt, tgt_image, src_image_stack, intrinsics)
        loss = model.total_loss

        # Train Op
        if opt.mode == 'train_flow' and opt.flownet_type == "residual":
            # we pretrain DepthNet & PoseNet, then finetune ResFlowNetS
            train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "flow_net")
            vars_to_restore = slim.get_variables_to_restore(include=["depth_net", "pose_net"])
        else:
            train_vars = [var for var in tf.trainable_variables()]
            vars_to_restore = slim.get_model_variables()

        if opt.init_ckpt_file != None:
            init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
                                            opt.init_ckpt_file, vars_to_restore)

        optim = tf.train.AdamOptimizer(opt.learning_rate, 0.9)
        train_op = slim.learning.create_train_op(loss, optim,
                                                 variables_to_train=train_vars)

        # Global Step
        global_step = tf.Variable(0,
                                name='global_step',
                                trainable=False)
        incr_global_step = tf.assign(global_step,
                                     global_step+1)

        # Parameter Count
        parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) \
                                        for v in train_vars])

        # Saver
        saver = tf.train.Saver([var for var in tf.model_variables()] + \
                                [global_step],
                                max_to_keep=opt.max_to_keep)

        # Session
        sv = tf.train.Supervisor(logdir=opt.checkpoint_dir,
                                 save_summaries_secs=0,
                                 saver=None)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with sv.managed_session(config=config) as sess:
            print('Trainable variables: ')
            for var in train_vars:
                print(var.name)
            print("parameter_count =", sess.run(parameter_count))

            if opt.init_ckpt_file != None:
                sess.run(init_assign_op, init_feed_dict)
            start_time = time.time()

            for step in range(1, opt.max_steps):
                fetches = {
                    "train": train_op,
                    "global_step": global_step,
                    "incr_global_step": incr_global_step
                }
                if step % 100 == 0:
                    fetches["loss"] = loss
                results = sess.run(fetches)
                if step % 100 == 0:
                    time_per_iter = (time.time() - start_time) / 100
                    start_time = time.time()
                    print('Iteration: [%7d] | Time: %4.4fs/iter | Loss: %.3f' \
                          % (step, time_per_iter, results["loss"]))
                if step % opt.save_ckpt_freq == 0:
                    saver.save(sess, os.path.join(opt.checkpoint_dir, 'model'), global_step=step)
コード例 #37
0
def main(_):
    if FLAGS.model_name == 'vgg-16':
        net_width = 224
        net_height = 224
        consider_top = 41

        data = np.load(
            'data/cnn_parameters/carlavp_label_to_horvpz_fov_pitch.npz')
        train_dir = 'data/saved_models/vgg16/model.ckpt-20227'

        _R_MEAN = 123.68
        _G_MEAN = 116.78
        _B_MEAN = 103.94
        resnet_average_channels = np.array(np.concatenate(
            (np.tile(_R_MEAN, (net_height, net_width, 1)),
             np.tile(_G_MEAN, (net_height, net_width, 1)),
             np.tile(_B_MEAN, (net_height, net_width, 1))),
            axis=2),
                                           dtype=np.float32)
    elif FLAGS.model_name == 'inception-v4':
        net_width = 299
        net_height = 299
        consider_top = 53

        data = np.load(
            'data/cnn_parameters/carlavp-299x299_label_to_horvpz_fov_pitch.npz'
        )
        train_dir = 'data/saved_models/incp4/model.ckpt-17721'
    else:
        print("Invalid CNN model name specified")
        return

    if FLAGS.train_dir != '':
        train_dir = FLAGS.train_dir

    all_bins = data['all_bins']
    all_sphere_centres = data['all_sphere_centres']
    all_sphere_radii = data['all_sphere_radii']

    no_params_model = 4

    num_bins = 500

    img_path = FLAGS.img_path
    img_cv = cv2.imread(img_path)
    img_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
    orig_height, orig_width, orig_channels = img_cv.shape

    my_img = cv2.resize(img_cv,
                        dsize=(net_width, net_height),
                        interpolation=cv2.INTER_CUBIC)

    if FLAGS.model_name == 'vgg-16':
        my_img = (np.array(my_img, np.float32))
        my_img = my_img - resnet_average_channels
    elif FLAGS.model_name == 'inception-v4':
        my_img = (np.array(my_img, np.float32)) * (1. / 255)
        my_img = (my_img - 0.5) * 2
    else:
        print("Invalid CNN model name specified")
        return

    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        img = tf.reshape(my_img, [1, net_width, net_height, 3])

        if FLAGS.model_name == 'vgg-16':
            with slim.arg_scope(vgg.vgg_arg_scope()):
                logits, _ = vgg.vgg_16(img,
                                       num_classes=num_bins * no_params_model,
                                       is_training=False)
        elif FLAGS.model_name == 'inception-v4':
            with slim.arg_scope(inception_v4.inception_v4_arg_scope()):
                logits, _ = inception_v4.inception_v4(img,
                                                      num_classes=num_bins *
                                                      no_params_model,
                                                      is_training=False)
        else:
            print("Invalid CNN model name specified")
            return

        probabilities = tf.nn.softmax(logits)

        checkpoint_path = train_dir
        init_fn = slim.assign_from_checkpoint_fn(
            checkpoint_path, slim.get_variables_to_restore())

        with tf.Session() as sess:
            with slim.queues.QueueRunners(sess):
                sess.run(tf.local_variables_initializer())
                init_fn(sess)
                start = timer()
                np_probabilities, np_rawvals = sess.run(
                    [probabilities, logits])

                i = 0

                pred_indices = np.zeros(no_params_model, dtype=np.int)
                for ln in range(no_params_model):
                    predsoft = my_softmax(np_rawvals[i, :].reshape(
                        no_params_model, -1)[ln, :][np.newaxis])
                    predsoft = predsoft.squeeze()

                    topindices = predsoft.argsort()[::-1][:consider_top]
                    probsindices = predsoft[topindices] / np.sum(
                        predsoft[topindices])
                    pred_indices[ln] = np.abs(
                        int(np.round(np.sum(probsindices * topindices))))

                estimated_input_points = get_horvpz_from_projected_4indices_modified(
                    pred_indices[:4], all_bins, all_sphere_centres,
                    all_sphere_radii)

                end = timer()

    print("Time taken: {0:.2f}s".format(end - start))
    print("Output of the code")
    print("------------------------------------------------")

    plot_scaled_horizonvector_vpz_picture(img_cv,
                                          estimated_input_points,
                                          net_dims=(net_width, net_height),
                                          color='go',
                                          show_vz=True,
                                          verbose=True)
    plt.show()

    fx, fy, roll_from_horizon, my_tilt = get_intrinisic_extrinsic_params_from_horizonvector_vpz(
        img_dims=(orig_width, orig_height),
        horizonvector_vpz=estimated_input_points,
        net_dims=(net_width, net_height),
        verbose=False)

    overhead_hmatrix, est_range_u, est_range_v = get_overhead_hmatrix_from_4cameraparams(
        fx=fx,
        fy=fy,
        my_tilt=my_tilt,
        my_roll=-radians(roll_from_horizon),
        img_dims=(orig_width, orig_height),
        verbose=False)

    scaled_overhead_hmatrix, target_dim = get_scaled_homography(
        overhead_hmatrix, 1080 * 2, est_range_u, est_range_v)

    warped = cv2.warpPerspective(img_cv,
                                 scaled_overhead_hmatrix,
                                 dsize=target_dim,
                                 flags=cv2.INTER_CUBIC)

    plt.imshow(warped)
    # plt.xticks([])
    # plt.yticks([])
    plt.show()
    os.makedirs("output/", exist_ok=True)
    txt_file = 'output/' + img_path[img_path.rfind('/') + 1:img_path.rfind(
        '.')] + '_homography_matrix_' + FLAGS.model_name + '.txt'
    np.savetxt(txt_file, scaled_overhead_hmatrix)
    print("Homography matrix saved to the text file:", txt_file)
    print("------------------------------------------------")
コード例 #38
0
ファイル: model_loader.py プロジェクト: rsennrich/nematus
def init_or_restore_variables(config, sess, ensemble_scope=None, train=False):
    # Construct a mapping between saved variable names and names in the current
    # scope. There are two reasons why names might be different:
    #
    #   1. This model is part of an ensemble, in which case a model-specific
    #       name scope will be active.
    #
    #   2. The saved model is from an old version of Nematus (before deep model
    #        support was added) and uses a different variable naming scheme
    #        for the GRUs.
    variables = slim.get_variables_to_restore()
    var_map = {}
    for v in variables:
        name = v.name.split(':')[0]
        if ensemble_scope == None:
            saved_name = name
        elif v.name.startswith(ensemble_scope.name + "/"):
            saved_name = name[len(ensemble_scope.name)+1:]
            # The ensemble scope is repeated for Adam variables. See
            # https://github.com/tensorflow/tensorflow/issues/8120
            if saved_name.startswith(ensemble_scope.name + "/"):
                saved_name = saved_name[len(ensemble_scope.name)+1:]
        else: # v belongs to a different model in the ensemble.
            continue
        if config.model_version == 0.1:
            # Backwards compatibility with the old variable naming scheme.
            saved_name = _revert_variable_name(saved_name, 0.1)
        var_map[saved_name] = v
    saver = tf.train.Saver(var_map, max_to_keep=None)

    # compute reload model filename
    reload_filename = None
    if config.reload == 'latest_checkpoint':
        checkpoint_dir = os.path.dirname(config.saveto)
        reload_filename = tf.train.latest_checkpoint(checkpoint_dir)
        if reload_filename != None:
            if (os.path.basename(reload_filename).rsplit('-', 1)[0] !=
                os.path.basename(config.saveto)):
                logging.error("Mismatching model filename found in the same directory while reloading from the latest checkpoint")
                sys.exit(1)
            logging.info('Latest checkpoint found in directory ' + os.path.abspath(checkpoint_dir))
    elif config.reload != None:
        reload_filename = config.reload
    if (reload_filename == None) and (config.prior_model != None):
        logging.info('Initializing model parameters from prior')
        reload_filename = config.prior_model

    # initialize or reload training progress
    if train:
        progress = training_progress.TrainingProgress()
        progress.bad_counter = 0
        progress.uidx = 0
        progress.eidx = 0
        progress.estop = False
        progress.history_errs = []
        progress.valid_script_scores = []
        if reload_filename and config.reload_training_progress:
            path = reload_filename + '.progress.json'
            if os.path.exists(path):
                logging.info('Reloading training progress')
                progress.load_from_json(path)
                if (progress.estop == True or
                    progress.eidx > config.max_epochs or
                    progress.uidx >= config.finish_after):
                    logging.warning('Training is already complete. Disable reloading of training progress (--no_reload_training_progress) or remove or modify progress file (%s) to train anyway.' % path)
                    sys.exit(0)

    # load prior model
    if train and config.prior_model != None:
        load_prior(config, sess, saver)

    # initialize or restore model
    if reload_filename == None:
        logging.info('Initializing model parameters from scratch...')
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
    else:
        logging.info('Loading model parameters from file ' + os.path.abspath(reload_filename))
        saver.restore(sess, os.path.abspath(reload_filename))

    logging.info('Done')

    if train:
        return saver, progress
    else:
        return saver
コード例 #39
0
def train(config_yaml):
    setup_logging()

    config_path = Path(config_yaml).resolve()
    cfg = load_config(config_yaml)
    dataset = create_dataset(cfg)

    batch_spec = get_batch_spec(cfg)
    batch, enqueue_op, placeholders = setup_preloading(batch_spec)

    losses = pose_net(cfg).train(batch)
    total_loss = losses['total_loss']

    for k, t in losses.items():
        tf.summary.scalar(k, t)
    merged_summaries = tf.summary.merge_all()

    variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])

    restorer = tf.train.Saver(variables_to_restore)
    saver = tf.train.Saver(max_to_keep=5)

    sess = tf.Session()

    coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)

    train_writer = tf.summary.FileWriter(cfg.log_dir, sess.graph)

    learning_rate, train_op = get_optimizer(total_loss, cfg)

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Restore variables from disk.
    restorer.restore(sess, cfg.init_weights)

    max_iter = int(cfg.multi_step[-1][1])

    display_iters = cfg.display_iters
    cum_loss = 0.0
    lr_gen = LearningRate(cfg)

    # Continue with training existing network if possible
    print('Looking for latest snapshot (if any)...')
    assert config_path.exists()
    training_path = config_path.parent
    stats_path = Path(config_yaml).with_name('learning_stats.csv')
    snapshots = [
        s.with_suffix('').name for s in training_path.glob('snapshot-*.index')
    ]
    if len(snapshots) > 0:
        latest_snapshot_id = max(
            [int(s[len('snapshot-'):]) for s in snapshots])
        latest_snapshot = 'snapshot-{}'.format(latest_snapshot_id)
        snapshot_path = training_path / latest_snapshot
        start_iter = int(latest_snapshot.rsplit('-')[-1])
        print("Latest snapshot:", start_iter)
        saver.restore(sess, str(snapshot_path))
        lrf = open(stats_path, 'a')  # a for append to old one
    else:
        print(
            "No previous trained models found, training from iteration 1....")
        start_iter = 0
        lrf = open(stats_path, 'w')  # 'w': start a fresh stats file since we're training a new model
        lrf.write("iteration, average_loss, learning_rate\n")

    for it in range(start_iter + 1, max_iter + 1):
        current_lr = lr_gen.get_lr(it)
        [_, loss_val,
         summary] = sess.run([train_op, total_loss, merged_summaries],
                             feed_dict={learning_rate: current_lr})
        cum_loss += loss_val
        train_writer.add_summary(summary, it)

        if it % display_iters == 0:
            average_loss = cum_loss / display_iters
            cum_loss = 0.0
            logging.info("iteration: {} loss: {} lr: {}".format(
                it, "{0:.4f}".format(average_loss), current_lr))
            lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
            lrf.flush()

        # Save snapshot
        if (it % cfg.save_iters == 0 and it != 0) or it == max_iter:
            model_name = cfg.snapshot_prefix
            saver.save(sess, model_name, global_step=it)
            print("Saved latest model...")

    lrf.close()
    sess.close()
    coord.request_stop()
    coord.join([thread])
コード例 #40
0
def train():
    setup_logging()

    cfg = load_config()
    dataset = create_dataset(cfg)

    batch_spec = get_batch_spec(cfg)
    batch, enqueue_op, placeholders = setup_preloading(batch_spec)

    losses = pose_net(cfg).train(batch)
    total_loss = losses['total_loss']

    for k, t in losses.items():
        tf.summary.scalar(k, t)
    merged_summaries = tf.summary.merge_all()

    variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
    restorer = tf.train.Saver(variables_to_restore)
    saver = tf.train.Saver(max_to_keep=5)

    sess = tf.Session()

    coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)

    train_writer = tf.summary.FileWriter(cfg.log_dir, sess.graph)

    learning_rate, train_op = get_optimizer(total_loss, cfg)

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Restore variables from disk.
    restorer.restore(sess, cfg.init_weights)

    max_iter = int(cfg.multi_step[-1][1])

    display_iters = cfg.display_iters
    cum_loss = 0.0
    lr_gen = LearningRate(cfg)

    for it in range(max_iter+1):
        current_lr = lr_gen.get_lr(it)
        [_, loss_val, summary] = sess.run([train_op, total_loss, merged_summaries],
                                          feed_dict={learning_rate: current_lr})
        cum_loss += loss_val
        train_writer.add_summary(summary, it)

        if it % display_iters == 0:
            average_loss = cum_loss / display_iters
            cum_loss = 0.0
            logging.info("iteration: {} loss: {} lr: {}"
                         .format(it, "{0:.4f}".format(average_loss), current_lr))

        # Save snapshot
        if (it % cfg.save_iters == 0 and it != 0) or it == max_iter:
            model_name = cfg.snapshot_prefix
            saver.save(sess, model_name, global_step=it)

    sess.close()
    coord.request_stop()
    coord.join([thread])
コード例 #41
0
            target_log_prob_fn=unnormalized_posterior,
            step_size=np.float32(1.),
            num_leapfrog_steps=3),
        num_adaptation_steps=int(0.8 * burn))

    initial_state = tf.constant(
        np.zeros((PARAMS.batch_size, PARAMS.z_dim)).astype(np.float32))
    samples, [st_size, log_accept_ratio] = tfp.mcmc.sample_chain(
        num_results=N,
        num_burnin_steps=burn,
        current_state=initial_state,
        kernel=adaptive_hmc,
        trace_fn=lambda _, pkr: [
            pkr.inner_results.accepted_results.step_size, pkr.inner_results.
            log_accept_ratio
        ])
    p_accept = tf.reduce_mean(tf.exp(tf.minimum(log_accept_ratio, 0.)))

    zz = tf.placeholder(tf.float32, shape=[N - burn, PARAMS.z_dim])
    gen_out1 = gen(zz, reuse=tf.AUTO_REUSE, training=False)

    variables_to_restore = slim.get_variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

with tf.Session(graph=g) as sess:
    saver.restore(sess, PARAMS.model_path)

    samples_ = sess.run(samples)
    np.save(save_dir + '/samples.npy', samples_)
    print('HMC acceptance ratio = {}'.format(sess.run(p_accept)))
コード例 #42
0
def init_or_restore_variables(config, sess, ensemble_scope=None, train=False):
    # Construct a mapping between saved variable names and names in the current
    # scope. There are two reasons why names might be different:
    #
    #   1. This model is part of an ensemble, in which case a model-specific
    #       name scope will be active.
    #
    #   2. The saved model is from an old version of Nematus (before deep model
    #        support was added) and uses a different variable naming scheme
    #        for the GRUs.
    variables = slim.get_variables_to_restore()
    var_map = {}
    for v in variables:
        name = v.name.split(':')[0]
        if ensemble_scope == None:
            saved_name = name
        elif v.name.startswith(ensemble_scope.name + "/"):
            saved_name = name[len(ensemble_scope.name) + 1:]
            # The ensemble scope is repeated for Adam variables. See
            # https://github.com/tensorflow/tensorflow/issues/8120
            if saved_name.startswith(ensemble_scope.name + "/"):
                saved_name = saved_name[len(ensemble_scope.name) + 1:]
        else:  # v belongs to a different model in the ensemble.
            continue
        if config.model_version == 0.1:
            # Backwards compatibility with the old variable naming scheme.
            saved_name = compat.revert_variable_name(saved_name, 0.1)
        var_map[saved_name] = v
    saver = tf.train.Saver(var_map, max_to_keep=None)

    # compute reload model filename
    reload_filename = None
    if config.reload == 'latest_checkpoint':
        checkpoint_dir = os.path.dirname(config.saveto)
        reload_filename = tf.train.latest_checkpoint(checkpoint_dir)
        if reload_filename != None:
            if (os.path.basename(reload_filename).rsplit('-', 1)[0] !=
                    os.path.basename(config.saveto)):
                logging.error(
                    "Mismatching model filename found in the same directory while reloading from the latest checkpoint"
                )
                sys.exit(1)
            logging.info('Latest checkpoint found in directory ' +
                         os.path.abspath(checkpoint_dir))
    elif config.reload != None:
        reload_filename = config.reload
    if (reload_filename == None) and (config.prior_model != None):
        logging.info('Initializing model parameters from prior')
        reload_filename = config.prior_model

    # initialize or reload training progress
    if train:
        progress = training_progress.TrainingProgress()
        progress.bad_counter = 0
        progress.uidx = 0
        progress.eidx = 0
        progress.estop = False
        progress.history_errs = []
        progress.valid_script_scores = []
        if reload_filename and config.reload_training_progress:
            path = reload_filename + '.progress.json'
            if os.path.exists(path):
                logging.info('Reloading training progress')
                progress.load_from_json(path)
                if (progress.estop == True or progress.eidx > config.max_epochs
                        or progress.uidx >= config.finish_after):
                    logging.warning(
                        'Training is already complete. Disable reloading of training progress (--no_reload_training_progress) or remove or modify progress file (%s) to train anyway.'
                        % path)
                    sys.exit(0)

    # load prior model
    if train and config.prior_model != None:
        load_prior(config, sess, saver)

    # initialize or restore model
    if reload_filename == None:
        logging.info('Initializing model parameters from scratch...')
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
    else:
        logging.info('Loading model parameters from file ' +
                     os.path.abspath(reload_filename))
        saver.restore(sess, os.path.abspath(reload_filename))

    logging.info('Done')

    if train:
        return saver, progress
    else:
        return saver
コード例 #43
0
def trainSegSoftNetwork(arg_dict):
	config = tf.ConfigProto(allow_soft_placement = True)
	# log_device_placement=True
	# allow_soft_placement=True
	sess = tf.Session(config = config)
	##########################################
	with tf.variable_scope('Input_Variables'):
		image_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1])
		label_placeholder = tf.placeholder(tf.float32, [arg_dict['batch_size'], arg_dict['input_size'], arg_dict['input_size'], 1])
		is_training = tf.placeholder(tf.bool, [], name='is_training')
	##########################################
	with tf.variable_scope('Network'):
		print('Constructing model...')
		network_eval_batch = arg_dict['model_construct_function'](image_placeholder, is_training)
	with tf.variable_scope('Loss'):
		print('Adding loss function...')
		loss, errors = gen_loss_seg_nomorph(network_eval_batch, label_placeholder)
	##########################################
	with tf.variable_scope('Input_Decoding'):
		print('Populating input queues...')
		image_valid_batch, label_valid_batch = get_eval_batch_seg(arg_dict['dataset'], arg_dict['n_reader_threads'], arg_dict['batch_size'], arg_dict['input_size'])
		image_train_batch, label_train_batch = get_train_batch_seg(arg_dict['dataset'], arg_dict['n_reader_threads'], arg_dict['batch_size'], arg_dict['input_size'], max_trans = arg_dict['aug_trans_max'], max_rot = arg_dict['aug_rot_max'])
		print('Starting input threads...')
		coord = tf.train.Coordinator()
		threads = tf.train.start_queue_runners(sess=sess, coord=coord)
	##########################################
	global_step = tf.Variable(0, name='global_step', trainable=False)
	with tf.variable_scope('Optimizer'):
		print('Initializing optimizer...')
		learn_rate, train_op = arg_dict['learn_function'](loss, arg_dict['dataset'].train_size, arg_dict['batch_size'], global_step, arg_dict['start_learn_rate'], arg_dict['epocs_per_lr_decay'], const_learn_rate=arg_dict['const_learn_rate'])
	##########################################
	with tf.variable_scope('Saver'):
		print('Generating summaries and savers...')
		training_summary, validation_summary = gen_summary_seg(loss, errors, learn_rate)
		summary_writer = tf.summary.FileWriter(arg_dict['log_dir'], sess.graph)
		saver = tf.train.Saver(slim.get_variables_to_restore(), max_to_keep=2)
	##########################################
	print('Initializing model...')
	sess.run(tf.global_variables_initializer())
	if 'network_to_restore' in arg_dict.keys() and arg_dict['network_to_restore'] is not None:
		saver.restore(sess,arg_dict['network_to_restore'])

	for step in range(0, arg_dict['num_steps']):
		start_time = time.time()
		img_batch, label_batch = sess.run([image_train_batch, label_train_batch])
		_, train_loss, summary_output, cur_step = sess.run(fetches=[train_op, loss, training_summary, global_step], feed_dict={image_placeholder: img_batch, label_placeholder: label_batch, is_training: True})
		duration = time.time() - start_time
		if (step+1) % 50 == 0: # CMDline updates every 50 steps
			examples_per_sec = arg_dict['batch_size'] / duration
			sec_per_batch = float(duration)
			format_str = ('%s: step %d, loss = %.4f (%.1f examples/sec; %.3f sec/batch)')
			print (format_str % (datetime.now(), cur_step, train_loss, examples_per_sec, sec_per_batch))
		if (step+1) % 100 == 0: # Tensorboard updates values every 100 steps
			summary_writer.add_summary(summary_output, cur_step)
			img_batch, label_batch = sess.run([image_valid_batch, label_valid_batch])
			summary_output = sess.run(fetches=[validation_summary], feed_dict={image_placeholder: img_batch, label_placeholder: label_batch, is_training: False})[0]
			summary_writer.add_summary(summary_output, cur_step)
		if (step+1) % 1000 == 0: # Save model every 1k steps
			checkpoint_path = os.path.join(arg_dict['log_dir'], 'model.ckpt')
			saver.save(sess, checkpoint_path, global_step=cur_step)

	# Save model after training is terminated...
	checkpoint_path = os.path.join(arg_dict['log_dir'], 'model.ckpt')
	saver.save(sess, checkpoint_path, global_step=cur_step)
コード例 #44
0
    tf.summary.scalar('obj loss', obj_loss)

    lr_rate = 0.001
    var_list = [
        v for v in tf.trainable_variables()
        if any(v.name in s for s in train_vars)
    ]
    with tf.name_scope('train'):
        optm = tf.train.MomentumOptimizer(learning_rate=lr_rate, momentum=0.9)
        grads_vars = optm.compute_gradients(loss, var_list)
        train_op = optm.apply_gradients(grads_and_vars=grads_vars)

    exclude = [
        'resnet_v1_50/logits/weights', 'resnet_v1_50/logits/biases',
        'resnet_v1_50/fc6/weights', 'resnet_v1_50/fc6/biases'
    ]
    vars_restore = slim.get_variables_to_restore(exclude=exclude)

    restorer = tf.train.Saver(var_list=vars_restore)
    saver = tf.train.Saver(max_to_keep=10)

    feed_dict = OrderedDict.fromkeys([inputs, truths, training])
    group_op = tf.group(train_op)

    data_family = load_data_sets(output_dims, batch_size, train_list, val_list,
                                 base_dir)

    # begin training
    train(group_op, loss, feed_dict, data_family, num_epochs, saver, restorer,
          model_path)
コード例 #45
0
ファイル: train.py プロジェクト: mkabra/poseTF
def train(cfg):
#    setup_logging()

    cfg = edict(cfg.__dict__)
    cfg = config.convert_to_deepcut(cfg)

    dirname = os.path.dirname(__file__)
    init_weights = os.path.join(dirname, 'models/resnet_v1_50.ckpt')

    if not os.path.exists(init_weights):
        # Download and save the pretrained resnet weights.
        logging.info('Downloading pretrained resnet 50 weights ...')
        urllib.urlretrieve('http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz', os.path.join(dirname,'models','resnet_v1_50_2016_08_28.tar.gz'))
        tar = tarfile.open(os.path.join(dirname,'models','resnet_v1_50_2016_08_28.tar.gz'))
        tar.extractall(path=os.path.join(dirname,'models'))
        tar.close()
        logging.info('Done downloading pretrained weights')

    db_file_name = os.path.join(cfg.cachedir, 'train_data.p')
    dataset = PoseDataset(cfg, db_file_name)
    train_info = {'train_dist':[],'train_loss':[],'val_dist':[],'val_loss':[],'step':[]}

    batch_spec = get_batch_spec(cfg)
    batch, enqueue_op, placeholders = setup_preloading(batch_spec)

    net = pose_net(cfg)
    losses = net.train(batch)
    total_loss = losses['total_loss']
    outputs = [net.heads['part_pred'], net.heads['locref']]

    for k, t in losses.items():
        tf.summary.scalar(k, t)

    variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
    restorer = tf.train.Saver(variables_to_restore)
    saver = tf.train.Saver(max_to_keep=50)

    sess = tf.Session()

    coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)

    learning_rate, train_op = get_optimizer(total_loss, cfg)

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Restore variables from disk.
    restorer.restore(sess, init_weights)

    #max_iter = int(cfg.multi_step[-1][1])
    max_iter = int(cfg.dl_steps)
    display_iters = cfg.display_step
    cum_loss = 0.0
    lr_gen = LearningRate(cfg)

    model_name = os.path.join( cfg.cachedir, cfg.expname + '_' + name)
    ckpt_file = os.path.join(cfg.cachedir, cfg.expname + '_' + name + '_ckpt')

    for it in range(max_iter+1):
        current_lr = lr_gen.get_lr(it)
        [_, loss_val] = sess.run([train_op, total_loss], # merged_summaries],
                                          feed_dict={learning_rate: current_lr})
        cum_loss += loss_val
 #       train_writer.add_summary(summary, it)

        if it % display_iters == 0:

            cur_out, batch_out = sess.run([outputs, batch], feed_dict={learning_rate: current_lr})
            scmap, locref = predict.extract_cnn_output(cur_out, cfg)

            # Extract maximum scoring location from the heatmap, assume 1 person
            loc_pred = predict.argmax_pose_predict(scmap, locref, cfg.stride)
            loc_in = batch_out[Batch.locs]
            dd = np.sqrt(np.sum(np.square(loc_pred[:,:,:2]-loc_in),axis=-1))
            dd = dd*cfg.dlc_rescale
            average_loss = cum_loss / display_iters
            cum_loss = 0.0
            print("iteration: {} loss: {} dist: {}  lr: {}"
                         .format(it, "{0:.4f}".format(average_loss),
                                 '{0:.2f}'.format(dd.mean()), current_lr))
            train_info['step'].append(it)
            train_info['train_loss'].append(loss_val)
            train_info['val_loss'].append(loss_val)
            train_info['val_dist'].append(dd.mean())
            train_info['train_dist'].append(dd.mean())

        if it % cfg.save_td_step == 0:
            save_td(cfg, train_info)
        # Save snapshot
        if (it % cfg.save_step == 0 ) or it == max_iter:
            saver.save(sess, model_name, global_step=it,
                       latest_filename=os.path.basename(ckpt_file))

    coord.request_stop()
    coord.join([thread])
    sess.close()
コード例 #46
0
import inception_v1
import input_data
import tensorflow.contrib.slim as slim
import tensorflow as tf 


x = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3])
keep_prob = tf.placeholder(dtype=tf.float32)
logits = inception_v1.inception_v1(x, keep_prob, 5)
logits = tf.reshape(logits, [-1, 5])

exclusions = ['InceptionV1/Logits']
inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)
CKPT_FILE = 'inception_v1.ckpt'
init_fn = slim.assign_from_checkpoint_fn(
	CKPT_FILE,
	inception_except_logits, ignore_missing_vars=True)


y = tf.nn.softmax(logits)
y_ = tf.placeholder(dtype=tf.float32,shape=[None, 5])
output_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='InceptionV1/Logits')
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy, var_list=output_vars)

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

flower_photos = input_data.read_data_sets('flower_photos/')

with tf.Session() as sess:
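    # NOTE: minimal illustrative sketch of the fine-tuning loop; it assumes the
    # custom input_data loader exposes an MNIST-style train.next_batch()
    # interface, and the batch size / step count here are arbitrary.
    sess.run(tf.global_variables_initializer())
    init_fn(sess)  # load pretrained InceptionV1 weights for every layer except the logits
    for i in range(1000):
        batch_xs, batch_ys = flower_photos.train.next_batch(32)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})
        if i % 100 == 0:
            train_acc = sess.run(accuracy, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 1.0})
            print('step %d, training accuracy %g' % (i, train_acc))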
def do_fine_tune(network, optimizer, learning_rate, batch_size, epoch_num,
                 label_type, num_stack, num_skip, social_signal_type,
                 trained_model_path, restore_epoch=None):
    """Run training.
    Args:
        network: network to train
        optimizer: adam or adadelta or rmsprop
        learning_rate: initial learning rate
        batch_size: size of mini batch
        epoch_num: epoch num to train
        label_type: phone or character
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
        social_signal_type: insert or insert2 or insert3 or remove
        trained_model_path: path to the pre-trained model
        restore_epoch: epoch of the model to restore
    """
    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():
        # Read dataset
        train_data = DataSetDialog(data_type='train', label_type=label_type,
                                   social_signal_type=social_signal_type,
                                   num_stack=num_stack, num_skip=num_skip,
                                   is_sorted=True)
        dev_data = DataSetDialog(data_type='dev', label_type=label_type,
                                 social_signal_type=social_signal_type,
                                 num_stack=num_stack, num_skip=num_skip,
                                 is_sorted=False)
        test_data = DataSetDialog(data_type='test', label_type=label_type,
                                  social_signal_type=social_signal_type,
                                  num_stack=num_stack, num_skip=num_skip,
                                  is_sorted=False)
        # TODO: create the evaluation datasets below
        # eval1_data = DataSet(data_type='eval1', label_type=label_type,
        #                      social_signal_type=social_signal_type,
        #                      num_stack=num_stack, num_skip=num_skip,
        #                      is_sorted=False)
        # eval2_data = DataSet(data_type='eval2', label_type=label_type,
        #                      social_signal_type=social_signal_type,
        #                      num_stack=num_stack, num_skip=num_skip,
        #                      is_sorted=False)
        # eval3_data = DataSet(data_type='eval3', label_type=label_type,
        #                      social_signal_type=social_signal_type,
        #                      num_stack=num_stack, num_skip=num_skip,
        #                      is_sorted=False)

        # Add to the graph each operation
        loss_op = network.loss()
        train_op = network.train(optimizer=optimizer,
                                 learning_rate_init=learning_rate,
                                 is_scheduled=False)
        decode_op = network.decoder(decode_type='beam_search',
                                    beam_width=20)
        per_op = network.ler(decode_op)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(network.summaries_train)
        summary_dev = tf.summary.merge(network.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()), "{:,}".format(total_parameters / 1000000)))

        csv_steps = []
        csv_train_loss = []
        csv_dev_loss = []

        # Create a session for running operation on the graph
        with tf.Session() as sess:
            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(
                network.model_dir, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Restore pre-trained model's parameters
            ckpt = tf.train.get_checkpoint_state(trained_model_path)
            if ckpt:
                # Use last saved model
                model_path = ckpt.model_checkpoint_path
                if restore_epoch is not None:
                    model_path = model_path.split('/')[:-1]
                    model_path = '/'.join(model_path) + \
                        '/model.ckpt-' + str(restore_epoch)
            else:
                raise ValueError('There are not any checkpoints.')
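            # The two excluded variables are the output-layer weight and bias; they are
            # skipped so that only the shared layers are restored from the pre-trained
            # model (presumably because the output layer is re-trained for the new task).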
            exclude = ['output/Variable', 'output/Variable_1']
            variables_to_restore = slim.get_variables_to_restore(
                exclude=exclude)
            restorer = tf.train.Saver(variables_to_restore)
            restorer.restore(sess, model_path)
            print("Model restored: " + model_path)

            # Train model
            iter_per_epoch = int(train_data.data_num / batch_size)
            if (train_data.data_num / batch_size) != int(train_data.data_num / batch_size):
                iter_per_epoch += 1
            max_steps = iter_per_epoch * epoch_num
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            fmean_best = 0
            for step in range(max_steps):
                # Create feed dictionary for next mini batch (train)
                inputs, labels, seq_len, _ = train_data.next_batch(
                    batch_size=batch_size)
                indices, values, dense_shape = list2sparsetensor(labels)
                feed_dict_train = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices,
                    network.label_values_pl: values,
                    network.label_shape_pl: dense_shape,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden,
                    network.lr_pl: learning_rate
                }

                # Create feed dictionary for next mini batch (dev)
                inputs, labels, seq_len, _ = dev_data.next_batch(
                    batch_size=batch_size)
                indices, values, dense_shape = list2sparsetensor(labels)
                feed_dict_dev = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices,
                    network.label_values_pl: values,
                    network.label_shape_pl: dense_shape,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden
                }

                # Update parameters & compute loss
                _, loss_train = sess.run(
                    [train_op, loss_op], feed_dict=feed_dict_train)
                loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                csv_steps.append(step)
                csv_train_loss.append(loss_train)
                csv_dev_loss.append(loss_dev)

                if (step + 1) % 10 == 0:
                    # Change feed dict for evaluation
                    feed_dict_train[network.keep_prob_input_pl] = 1.0
                    feed_dict_train[network.keep_prob_hidden_pl] = 1.0
                    feed_dict_dev[network.keep_prob_input_pl] = 1.0
                    feed_dict_dev[network.keep_prob_hidden_pl] = 1.0

                    # Compute accuracy & update event file
                    ler_train, summary_str_train = sess.run([per_op, summary_train],
                                                            feed_dict=feed_dict_train)
                    ler_dev, summary_str_dev, labels_st = sess.run([per_op, summary_dev, decode_op],
                                                                   feed_dict=feed_dict_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    # Decode
                    # try:
                    #     labels_pred = sparsetensor2list(labels_st, batch_size)
                    # except:
                    #     labels_pred = [[0] * batch_size]

                    duration_step = time.time() - start_time_step
                    print('Step %d: loss = %.3f (%.3f) / ler = %.4f (%.4f) (%.3f min)' %
                          (step + 1, loss_train, loss_dev, ler_train, ler_dev, duration_step / 60))

                    # if label_type == 'character':
                    #     if social_signal_type == 'remove':
                    #         map_file_path = '../evaluation/mapping_files/ctc/char2num_remove.txt'
                    #     else:
                    #         map_file_path = '../evaluation/mapping_files/ctc/char2num_' + \
                    #             social_signal_type + '.txt'
                    #     print('True: %s' % num2char(labels[-1], map_file_path))
                    #     print('Pred: %s' % num2char(
                    #         labels_pred[-1], map_file_path))
                    # elif label_type == 'phone':
                    #     if social_signal_type == 'remove':
                    #         map_file_path = '../evaluation/mapping_files/ctc/phone2num_remove.txt'
                    #     else:
                    #         map_file_path = '../evaluation/mapping_files/ctc/phone2num_' + \
                    #             social_signal_type + '.txt'
                    #     print('True: %s' % num2phone(
                    #         labels[-1], map_file_path))
                    #     print('Pred: %s' % num2phone(
                    #         labels_pred[-1], map_file_path))

                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:
                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (check point)
                    checkpoint_file = os.path.join(
                        network.model_dir, 'model.ckpt')
                    save_path = saver.save(
                        sess, checkpoint_file, global_step=epoch)
                    print("Model saved in file: %s" % save_path)

                    start_time_eval = time.time()
                    if label_type == 'character':
                        print('■Dev Evaluation:■')
                        fmean_epoch = do_eval_fmeasure(session=sess, decode_op=decode_op,
                                                       network=network, dataset=dev_data,
                                                       label_type=label_type,
                                                       social_signal_type=social_signal_type)
                        # error_epoch = do_eval_cer(session=sess,
                        #                           decode_op=decode_op,
                        #                           network=network,
                        #                           dataset=dev_data,
                        #                           eval_batch_size=batch_size)

                        if fmean_epoch > fmean_best:
                            fmean_best = fmean_epoch
                            print('■■■ ↑Best Score (F-measure)↑ ■■■')

                            do_eval_fmeasure(session=sess, decode_op=decode_op,
                                             network=network, dataset=test_data,
                                             label_type=label_type,
                                             social_signal_type=social_signal_type)
                            # print('■eval1 Evaluation:■')
                            # do_eval_cer(session=sess, decode_op=decode_op,
                            #             network=network, dataset=eval1_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval2 Evaluation:■')
                            # do_eval_cer(session=sess, decode_op=decode_op,
                            #             network=network, dataset=eval2_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval3 Evaluation:■')
                            # do_eval_cer(session=sess, decode_op=decode_op,
                            #             network=network, dataset=eval3_data,
                            #             eval_batch_size=batch_size)

                    else:
                        print('■Dev Evaluation:■')
                        fmean_epoch = do_eval_fmeasure(session=sess, decode_op=decode_op,
                                                       network=network, dataset=dev_data,
                                                       label_type=label_type,
                                                       social_signal_type=social_signal_type)
                        # error_epoch = do_eval_per(session=sess,
                        #                           per_op=per_op,
                        #                           network=network,
                        #                           dataset=dev_data,
                        #                           eval_batch_size=batch_size)

                        if fmean_epoch < fmean_best:
                            fmean_best = fmean_epoch
                            print('■■■ ↑Best Score (F-measure)↑ ■■■')

                            do_eval_fmeasure(session=sess, decode_op=decode_op,
                                             network=network, dataset=test_data,
                                             label_type=label_type,
                                             social_signal_type=social_signal_type)
                            # print('■eval1 Evaluation:■')
                            # do_eval_per(session=sess, per_op=per_op,
                            #             network=network, dataset=eval1_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval2 Evaluation:■')
                            # do_eval_per(session=sess, per_op=per_op,
                            #             network=network, dataset=eval2_data,
                            #             eval_batch_size=batch_size)
                            # print('■eval3 Evaluation:■')
                            # do_eval_per(session=sess, per_op=per_op,
                            #             network=network, dataset=eval3_data,
                            #             eval_batch_size=batch_size)

                    duration_eval = time.time() - start_time_eval
                    print('Evaluation time: %.3f min' %
                          (duration_eval / 60))

                    start_time_epoch = time.time()
                    start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss
            save_loss(csv_steps, csv_train_loss, csv_dev_loss,
                      save_path=network.model_dir)

            # Training was finished correctly
            with open(os.path.join(network.model_dir, 'complete.txt'), 'w') as f:
                f.write('')
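
# A hypothetical invocation of do_fine_tune; the network object and paths below are
# illustrative placeholders, not values taken from this example:
#
#     do_fine_tune(network=ctc_network, optimizer='adam', learning_rate=1e-3,
#                  batch_size=32, epoch_num=20, label_type='character',
#                  num_stack=3, num_skip=3, social_signal_type='insert',
#                  trained_model_path='/path/to/pretrained_model',
#                  restore_epoch=None)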
コード例 #48
ファイル: model.py プロジェクト: ulzee/megapix-gan
    def build_model(self):

        lerp_size = (1.0 / 161.0 / 32.0)
        lerp_factor = tf.Variable(0.333,
                                  name='lerp_factor',
                                  trainable=False,
                                  dtype=tf.float32)
        lerp_op = tf.assign(lerp_factor, lerp_factor + lerp_size)
        self.clip_lerp = tf.assign(lerp_factor, tf.minimum(lerp_op, 1.0))
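        # Each run of self.clip_lerp increases lerp_factor by lerp_size
        # (1 / (161 * 32)) from its initial value of 0.333 and clips it at 1.0,
        # so the interpolation factor ramps toward 1.0 over training.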

        if self.y_dim:
            self.y = tf.placeholder(tf.float32, [self.batch_size, self.y_dim],
                                    name='y')
        else:
            self.y = None

        image_dims = [self.imsize, self.imsize, 1]

        # input params
        self.inputs = tf.placeholder(tf.float32,
                                     [self.batch_size] + image_dims,
                                     name='real_images')

        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')
        self.z_sum = histogram_summary("z", self.z)

        # Gen and discrim
        self.G = self.generator(self.z, self.y)
        self.D, self.D_logits = self.discriminator(self.inputs,
                                                   self.y,
                                                   reuse=False)
        self.sampler = self.sampler(self.z, self.y)
        self.D_, self.D_logits_ = self.discriminator(self.G,
                                                     self.y,
                                                     reuse=True)

        try:
            print('Defined G and D:')
            input()
        except:
            pass

        # Loss calculations
        self.d_sum = histogram_summary("d", self.D)
        self.d__sum = histogram_summary("d_", self.D_)
        self.G_sum = image_summary("G", self.G)

        def sigmoid_cross_entropy_with_logits(x, y):
            try:
                return tf.nn.sigmoid_cross_entropy_with_logits(logits=x,
                                                               labels=y)
            except:
                return tf.nn.sigmoid_cross_entropy_with_logits(logits=x,
                                                               targets=y)

        self.d_loss_real = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(self.D_logits,
                                              tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(self.D_logits_,
                                              tf.zeros_like(self.D_)))
        self.g_loss = tf.reduce_mean(
            sigmoid_cross_entropy_with_logits(self.D_logits_,
                                              tf.ones_like(self.D_)))

        self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake)

        self.d_loss = self.d_loss_real + self.d_loss_fake

        self.g_loss_sum = scalar_summary("g_loss", self.g_loss)
        self.d_loss_sum = scalar_summary("d_loss", self.d_loss)

        t_vars = tf.trainable_variables()

        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        if self.grow is not None:
            current_vars = slim.get_variables_to_restore()
            prev_vars = []
            ignored_vars = []
            for var in current_vars:
                if '_h' in str(var):
                    parts = str(var).split('/')[1]
                    parts = parts.split('_')[1]
                    hvar = int(parts.replace('h', ''))
                    if 2**(hvar + 1) > self.grow:
                        ignored_vars.append(var)
                        continue
                elif '_b' in str(var):
                    parts = str(var).split('/')[1]
                    parts = parts.split('_')[1]
                    hvar = int(parts.replace('bn', ''))
                    if 2**(hvar + 1) > self.grow:
                        ignored_vars.append(var)
                        continue
                if 'lerp_factor' in str(var):
                    ignored_vars.append(var)
                    continue
                if 'g_image_%d' % self.stacks in str(var):
                    ignored_vars.append(var)
                    continue
                if 'd_image' in str(var):
                    ignored_vars.append(var)
                    continue
                prev_vars.append(var)
            self.load_vars = prev_vars

            print('These nodes will be ignored:')
            for var in ignored_vars:
                print('   *', var.name)
            print('These nodes will be loaded:')
            for var in prev_vars:
                print('   *', var.name)

            try:
                print('Press enter to confirm vars:')
                input()
            except:
                pass
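
            # A minimal sketch (not part of this example) of how self.load_vars could
            # later be used to restore the previously trained, smaller network;
            # `prev_checkpoint` stands in for a hypothetical checkpoint path:
            #
            #     restorer = tf.train.Saver(var_list=self.load_vars)
            #     restorer.restore(sess, prev_checkpoint)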
コード例 #49
def main(argv=None):

    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION)

    config = tf.ConfigProto(
        gpu_options=gpu_options,
        log_device_placement=False,
    )

    classes = load_coco_names(cfg.CLASS_NAME)

    if cfg.FROZEN_MODEL:
        pass
    #
    #     t0 = time.time()
    #     frozenGraph = load_graph(cfg.FROZEN_MODEL)
    #     print("Loaded graph in {:.2f}s".format(time.time()-t0))
    #
    #     boxes, inputs = get_boxes_and_inputs_pb(frozenGraph)
    #
    #     with tf.Session(graph=frozenGraph, config=config) as sess:
    #         t0 = time.time()
    #         detected_boxes = sess.run(
    #             boxes, feed_dict={inputs: [img_resized]})

    else:
        if cfg.TINY:
            model = yolo_v3_tiny.yolo_v3_tiny
        else:
            model = yolo_v3.yolo_v3

        boxes, inputs = get_boxes_and_inputs(model, len(classes),
                                             cfg.IMAGE_SIZE, cfg.DATA_FORMAT)
        # boxes : coordinates of top left and bottom right points.
        saver = tf.train.Saver(var_list=tf.global_variables(scope='detector'))

        #
        # for specific object recognition
        #
        vgg16_image_size = vgg_16.default_image_size

        s_class_names = cfg.S_CLASS_PATH
        s_classes = [l.split(" ") for l in open(s_class_names, "r")]
        if len(s_classes[0]) > 1:  # class file format is "id classname"
            s_labels = {int(l[0]): l[1].replace("\n", "") for l in s_classes}
        else:  # class file format is "classname" only
            s_labels = {
                i: l[0].replace("\n", "")
                for i, l in enumerate(s_classes)
            }

        num_classes_s = len(s_labels.keys())

        num_classes_extractor = cfg.S_EXTRACTOR_NUM_OF_CLASSES
        s_model = cfg.S_CKPT_FILE

        extractor_name = cfg.S_EXTRACTOR_NAME

        specific_pred, [
            cropped_images_placeholder, original_images_placeholder, keep_prob,
            is_training
        ] = specific_object_recognition(vgg16_image_size, num_classes_s,
                                        num_classes_extractor, extractor_name)

        variables_to_restore = slim.get_variables_to_restore(
            include=["shigeNet_v1"])
        restorer = tf.train.Saver(variables_to_restore)
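        # Two separate Savers are used on purpose: `saver` restores the YOLO v3
        # variables under the 'detector' scope from cfg.CKPT_FILE, while `restorer`
        # restores only the shigeNet_v1 variables (selected via
        # slim.get_variables_to_restore(include=["shigeNet_v1"])) from cfg.S_CKPT_FILE.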
        with tf.Session(config=config) as sess:
            t0 = time.time()
            saver.restore(sess, cfg.CKPT_FILE)
            print('YOLO v3 Model restored in {:.2f}s'.format(time.time() - t0),
                  "from:", cfg.CKPT_FILE)

            t0 = time.time()
            restorer.restore(sess, s_model)
            print(
                'Specific object recognition Model restored in {:.2f}s'.format(
                    time.time() - t0), "from:", s_model)

            # prepare test set
            with open(cfg.TEST_FILE_PATH, 'r') as f:
                f_ = [line.rstrip().split() for line in f]

            data = [
                [l, get_annotation(l[0], txtname=cfg.GT_INFO_FILE_NAME)]
                for l in f_
            ]  # data: [[(path_str, label), [frame, center_x, center_y, size_x, size_y]],...]
            data = [l for l in data
                    if l[1] is not None]  # skip images whose annotation could not be loaded

            def is_cropped_file_Exist(orig_filepath):
                d, file = os.path.split(orig_filepath)
                cropped_d = d + "_cropped"
                cropped_file = os.path.join(cropped_d, file)
                return os.path.exists(cropped_file)

            data = [l for l in data
                    if is_cropped_file_Exist(l[0][0])]  # skip images without a matching cropped image

            # log
            f = open(cfg.OUTPUT_LOG_PATH, 'w')
            writer = csv.writer(f, lineterminator='\n')
            writer.writerow([
                'image path', 'class/movie_name', 'IoU', 'Average Precision',
                'TP', 'FP', 'FN', 'is RoI detected?', 'gt label',
                ' highest_conf_label', 'detect time', 'recog time'
            ])

            total_iou = []  # per-image IoU values
            total_ap = []  # per-image average precision values
            total_tp = 0
            total_fp = 0
            total_fn = 0

            # iterative run
            for count, gt in enumerate(
                    data
            ):  # gt: [(path_str, label), [frame, center_x, center_y, size_x, size_y]]
                # for evaluation
                gt_box = [float(i) for i in gt[1][1:]]
                gt_box = [
                    gt_box[0] - (gt_box[2] / 2), gt_box[1] - (gt_box[3] / 2),
                    gt_box[0] + (gt_box[2] / 2), gt_box[1] + (gt_box[3] / 2)
                ]
                gt_label = int(gt[0][1])
                gt_anno = {gt_label: gt_box}

                print(count, ":", gt[0][0])
                img = Image.open(gt[0][0])
                img_resized = letter_box_image(img, cfg.IMAGE_SIZE,
                                               cfg.IMAGE_SIZE, 128)
                img_resized = img_resized.astype(np.float32)

                t0 = time.time()
                detected_boxes = sess.run(boxes,
                                          feed_dict={inputs: [img_resized]})

                filtered_boxes = non_max_suppression(
                    detected_boxes,
                    confidence_threshold=cfg.CONF_THRESHOLD,
                    iou_threshold=cfg.IOU_THRESHOLD)
                detect_time = time.time() - t0

                print("detected boxes in :{:.2f}s ".format(detect_time),
                      filtered_boxes)

                # specific object recognition!
                np_img = np.array(img) / 255
                target_label = 0  # for the "seesaa" class (change this to match the dataset's class index)

                if len(filtered_boxes.keys()) != 0:  # when something was detected
                    is_detected = True

                    # get specific object name
                    for cls, bboxs in filtered_boxes.items():
                        if cls == target_label:  # if this is the target label
                            print("target class detected!")
                            bounding_boxes = []
                            bboxs_ = copy.deepcopy(
                                bboxs
                            )  # because convert_to_original_size() mutates the boxes in place
                            for box, score in bboxs:
                                orig_size_box = convert_to_original_size(
                                    box,
                                    np.array((cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)),
                                    np.array(img.size), True)
                                # print(orig_size_box)
                                cropped_image = np_img[
                                    int(orig_size_box[1]):int(orig_size_box[3]
                                                              ),
                                    int(orig_size_box[0]):int(orig_size_box[2]
                                                              )]
                                bounding_boxes.append(cropped_image)

                                # cv2.imshow('result', cropped_image)
                                # cv2.waitKey(0)

                            input_original = cv2.resize(
                                padding(np_img),
                                (vgg16_image_size, vgg16_image_size))
                            input_original = np.tile(
                                input_original, (len(bounding_boxes), 1, 1,
                                                 1))  # stack the original image once per cropped box

                            cropped_images = []
                            for bbox in bounding_boxes:
                                cropped_images.append(
                                    cv2.resize(
                                        padding(bbox),
                                        (vgg16_image_size, vgg16_image_size)))

                            input_cropped = np.asarray(cropped_images)

                            t0 = time.time()
                            pred = sess.run(specific_pred,
                                            feed_dict={
                                                cropped_images_placeholder:
                                                input_cropped,
                                                original_images_placeholder:
                                                input_original,
                                                keep_prob: 1.0,
                                                is_training: False
                                            })

                            recog_time = time.time() - t0
                            print("Predictions found in {:.2f}s".format(
                                recog_time))

                            # pred_label = [s_labels[i] for i in pred.tolist()]  # map ids to class names

                            classes = [
                                s_labels[i] for i in range(num_classes_s)
                            ]

                            filtered_boxes = {}
                            for i, n in enumerate(pred.tolist()):
                                if n in filtered_boxes.keys():
                                    filtered_boxes[n].extend(
                                        [bboxs_[i]])  # filtered box
                                else:
                                    filtered_boxes[n] = [bboxs_[i]]

                    # evaluation
                    print("specific obj:", filtered_boxes)
                    [tp, fp, fn], iou, ap, highest_conf_label = evaluate(
                        filtered_boxes, gt_anno, img,
                        thresh=0.1)  # evaluate a single image

                else:  # when nothing was detected
                    is_detected = False
                    iou = 0.0
                    ap = 0.0
                    tp = 0
                    fp = 0
                    fn = len(gt_anno.values())
                    highest_conf_label = -1

                total_iou.append(iou)
                total_ap.append(ap)
                print("IoU:", iou)
                print("average Precision:", ap)
                print("mean average IoU:",
                      sum(total_iou) / (len(total_iou) + 1e-05))
                print("mean Average Precision:",
                      sum(total_ap) / (len(total_ap) + 1e-05))

                total_tp += tp
                total_fp += fp
                total_fn += fn

                # draw pred_bbox
                draw_boxes(filtered_boxes, img, classes,
                           (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE), True)
                # draw GT
                draw = ImageDraw.Draw(img)
                color = (0, 0, 0)
                draw.rectangle(gt_box, outline=color)
                draw.text(gt_box[:2], 'GT_' + classes[gt_label], fill=color)

                img.save(
                    os.path.join(
                        cfg.OUTPUT_DIR,
                        '{0:04d}_'.format(count) + os.path.basename(gt[0][0])))

                movie_name = os.path.basename(os.path.dirname(gt[0][0]))
                movie_parant_dir = os.path.basename(
                    os.path.dirname(os.path.dirname(gt[0][0])))
                pred_label = classes[
                    highest_conf_label] if highest_conf_label != -1 else "None"
                writer.writerow([
                    gt[0][0],
                    os.path.join(movie_name, movie_parant_dir), iou, ap, tp,
                    fp, fn, is_detected, classes[gt_label], pred_label,
                    detect_time, recog_time
                ])

            print("total tp :", total_tp)
            print("total fp :", total_fp)
            print("total fn :", total_fn)
            f.close()
            print("proc finished.")
コード例 #50
def train(
    config_yaml,
    displayiters,
    saveiters,
    maxiters,
    max_to_keep=5,
    keepdeconvweights=True,
    allow_growth=False,
):
    start_path = os.getcwd()
    os.chdir(
        str(Path(config_yaml).parents[0])
    )  # switch to folder of config_yaml (for logging)
    setup_logging()

    cfg = load_config(config_yaml)
    net_type = cfg["net_type"]
    if cfg["dataset_type"] in ("scalecrop", "tensorpack", "deterministic"):
        print(
            "Switching batchsize to 1, as tensorpack/scalecrop/deterministic loaders do not support batches >1. Use imgaug/default loader."
        )
        cfg["batch_size"] = 1  # in case this was edited for analysis.-

    dataset = create_dataset(cfg)
    batch_spec = get_batch_spec(cfg)
    batch, enqueue_op, placeholders = setup_preloading(batch_spec)

    losses = pose_net(cfg).train(batch)
    total_loss = losses["total_loss"]

    for k, t in losses.items():
        TF.summary.scalar(k, t)
    merged_summaries = TF.summary.merge_all()

    if "snapshot" in Path(cfg["init_weights"]).stem and keepdeconvweights:
        print("Loading already trained DLC with backbone:", net_type)
        variables_to_restore = slim.get_variables_to_restore()
    else:
        print("Loading ImageNet-pretrained", net_type)
        # loading backbone from ResNet, MobileNet etc.
        if "resnet" in net_type:
            variables_to_restore = slim.get_variables_to_restore(include=["resnet_v1"])
        elif "mobilenet" in net_type:
            variables_to_restore = slim.get_variables_to_restore(
                include=["MobilenetV2"]
            )
        elif "efficientnet" in net_type:
            variables_to_restore = slim.get_variables_to_restore(
                include=["efficientnet"]
            )
            variables_to_restore = {
                var.op.name.replace("efficientnet/", "")
                + "/ExponentialMovingAverage": var
                for var in variables_to_restore
            }
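            # When a dict is passed to TF.train.Saver below, the keys name the tensors
            # as stored in the checkpoint (here the ExponentialMovingAverage shadow
            # values of the EfficientNet weights) and the values are the graph
            # variables they are loaded into.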
        else:
            print("Wait for DLC 2.3.")

    restorer = TF.train.Saver(variables_to_restore)
    saver = TF.train.Saver(
        max_to_keep=max_to_keep
    )  # selects how many snapshots are stored, see https://github.com/AlexEMG/DeepLabCut/issues/8#issuecomment-387404835

    if allow_growth:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = TF.Session(config=config)
    else:
        sess = TF.Session()

    coord, thread = start_preloading(sess, enqueue_op, dataset, placeholders)
    train_writer = TF.summary.FileWriter(cfg["log_dir"], sess.graph)

    if cfg.get("freezeencoder", False):
        if "efficientnet" in net_type:
            print("Freezing ONLY supported MobileNet/ResNet currently!!")
            learning_rate, train_op, tstep = get_optimizer(total_loss, cfg)

        print("Freezing encoder...")
        learning_rate, _, train_op = get_optimizer_with_freeze(total_loss, cfg)
    else:
        learning_rate, train_op, tstep = get_optimizer(total_loss, cfg)

    sess.run(TF.global_variables_initializer())
    sess.run(TF.local_variables_initializer())

    # Restore variables from disk.
    restorer.restore(sess, cfg["init_weights"])
    if maxiters is None:
        max_iter = int(cfg["multi_step"][-1][1])
    else:
        max_iter = min(int(cfg["multi_step"][-1][1]), int(maxiters))
        # display_iters = max(1,int(displayiters))
        print("Max_iters overwritten as", max_iter)

    if displayiters is None:
        display_iters = max(1, int(cfg["display_iters"]))
    else:
        display_iters = max(1, int(displayiters))
        print("Display_iters overwritten as", display_iters)

    if saveiters is None:
        save_iters = max(1, int(cfg["save_iters"]))

    else:
        save_iters = max(1, int(saveiters))
        print("Save_iters overwritten as", save_iters)

    cum_loss = 0.0
    lr_gen = LearningRate(cfg)

    stats_path = Path(config_yaml).with_name("learning_stats.csv")
    lrf = open(str(stats_path), "w")

    print("Training parameter:")
    print(cfg)
    print("Starting training....")
    for it in range(max_iter + 1):
        if "efficientnet" in net_type:
            feed_dict = {tstep: it}
            current_lr = sess.run(learning_rate, feed_dict=feed_dict)
        else:
            current_lr = lr_gen.get_lr(it)
            feed_dict = {learning_rate: current_lr}

        [_, loss_val, summary] = sess.run(
            [train_op, total_loss, merged_summaries], feed_dict=feed_dict
        )
        cum_loss += loss_val
        train_writer.add_summary(summary, it)

        if it % display_iters == 0 and it > 0:
            average_loss = cum_loss / display_iters
            cum_loss = 0.0
            logging.info(
                "iteration: {} loss: {} lr: {}".format(
                    it, "{0:.4f}".format(average_loss), current_lr
                )
            )
            lrf.write("{}, {:.5f}, {}\n".format(it, average_loss, current_lr))
            lrf.flush()

        # Save snapshot
        if (it % save_iters == 0 and it != 0) or it == max_iter:
            model_name = cfg["snapshot_prefix"]
            saver.save(sess, model_name, global_step=it)

    lrf.close()
    sess.close()
    coord.request_stop()
    coord.join([thread])
    # return to original path.
    os.chdir(str(start_path))