Example 1
def main(_):
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    gpu = str(get_available_gpus(FLAGS.gpu_num))
    print('GPU devices: ', gpu)
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    # ============================================================================
    # ============================= TRAIN ========================================
    # ============================================================================
    print(sorted_str_dict(FLAGS.__dict__))
    if FLAGS.resume_step is not None:
        print('Ready to resume from step %d.' % FLAGS.resume_step)

    assert FLAGS.gpu_num is not None, 'the number of GPUs must be specified.'
    assert FLAGS.gpu_num > 0, 'the number of GPUs must be greater than 0.'
    if FLAGS.eval_only:
        logdir = LogDir(FLAGS.database, model_id())
        logdir.print_all_info()
        f_log = open(
            logdir.exp_dir + '/' + str(datetime.datetime.now()) + '.txt', 'w')
        f_log.write('step,loss,precision,wd\n')
        f_log.write(sorted_str_dict(FLAGS.__dict__) + '\n')
    else:
        f_log, logdir, has_nan = train(FLAGS.resume_step)

        if has_nan:
            f_log.write('TEST:0,nan,nan\n')
            f_log.flush()
            return

    # ============================================================================
    # ============================= EVAL =========================================
    # ============================================================================
    f_log.write('TEST:step,loss,precision\n')

    import glob
    i_ckpts = sorted(glob.glob(logdir.snapshot_dir + '/model.ckpt-*.index'),
                     key=os.path.getmtime)

    # ============================================================================
    # ======================== Eval for the last model ===========================
    # ============================================================================
    i_ckpt = i_ckpts[-1].split('.index')[0]
    loss, precision = eval(i_ckpt)
    step = i_ckpt.split('-')[-1]
    print('%s %s] Step %s Test' %
          (str(datetime.datetime.now()), str(os.getpid()), step))
    print('\t loss = %.4f, precision = %.4f' % (loss, precision))
    f_log.write('TEST:%s,%f,%f\n' % (step, loss, precision))
    f_log.flush()

    f_log.close()
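
For reference, main(_) above follows the TensorFlow 1.x tf.app.run() convention: command-line flags are defined at module level and parsed before main is called. A minimal sketch of the surrounding boilerplate this example assumes (the flag names come from the code above; the defaults and help strings here are illustrative assumptions):

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('gpu_num', 1, 'number of GPUs to use')
tf.app.flags.DEFINE_integer('resume_step', None, 'training step to resume from')
tf.app.flags.DEFINE_boolean('eval_only', False, 'skip training and only evaluate')
tf.app.flags.DEFINE_string('database', 'CityScapes', 'dataset name')

if __name__ == '__main__':
    tf.app.run()  # parses the flags, then calls main(argv)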
Example 2
def main(_):
    # ============================================================================
    # ============================= TRAIN ========================================
    # ============================================================================
    print(sorted_str_dict(FLAGS.__dict__))
    if FLAGS.resume_step is not None:
        print('Ready to resume from step %d.' % FLAGS.resume_step)

    assert FLAGS.gpu_num is not None, 'the number of GPUs must be specified.'
    assert FLAGS.gpu_num > 0, 'the number of GPUs must be greater than 0.'
    if FLAGS.eval_only:
        logdir = LogDir(FLAGS.database, FLAGS.log_dir, FLAGS.weight_decay_mode)
        logdir.print_all_info()
        f_log = open(
            logdir.exp_dir + '/' + str(datetime.datetime.now()) + '.txt', 'w')
        f_log.write('step,loss,precision,wd\n')
        f_log.write(sorted_str_dict(FLAGS.__dict__) + '\n')
    else:
        f_log, logdir, has_nan = train(FLAGS.resume_step)

        if has_nan:
            f_log.write('TEST:0,nan,nan\n')
            f_log.flush()
            return

    # ============================================================================
    # ============================= EVAL =========================================
    # ============================================================================
    f_log.write('TEST:step,loss,precision\n')

    import glob
    i_ckpts = sorted(glob.glob(logdir.snapshot_dir + '/model.ckpt-*.index'),
                     key=os.path.getmtime)

    # ============================================================================
    # ======================== Eval for the last model ===========================
    # ============================================================================
    i_ckpt = i_ckpts[-1].split('.index')[0]
    precision = eval(i_ckpt)
    step = i_ckpt.split('-')[-1]
    print('%s %s] Step %s Test' %
          (str(datetime.datetime.now()), str(os.getpid()), step))
    print('\t precision = %.4f' % precision)
    f_log.write('TEST:%s,%f\n' % (step, precision))
    f_log.flush()

    f_log.close()
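
Both examples above resolve the newest checkpoint by globbing the .index files, sorting them by modification time, and stripping the suffix to recover the checkpoint prefix that tf.train.Saver expects. The same step as a small self-contained helper (the function name is ours):

import glob
import os

def latest_checkpoint(snapshot_dir):
    # e.g. '.../model.ckpt-30000.index' -> '.../model.ckpt-30000'
    index_files = sorted(glob.glob(os.path.join(snapshot_dir, 'model.ckpt-*.index')),
                         key=os.path.getmtime)
    return index_files[-1].split('.index')[0] if index_files else None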
Example 3
def train(resume_step=None):
    # < preparing arguments >
    if FLAGS.float_type == 16:
        print('\n< using tf.float16 >\n')
        float_type = tf.float16
    else:
        print('\n< using tf.float32 >\n')
        float_type = tf.float32
    new_layer_names = FLAGS.new_layer_names
    if FLAGS.new_layer_names is not None:
        new_layer_names = new_layer_names.split(',')

    # < data set >
    data_list = FLAGS.subsets_for_training.split(',')
    if len(data_list) < 1:
        data_list = ['train']
    list_images = []
    list_labels = []
    with tf.device('/cpu:0'):
        reader = SegmentationImageReader(
            FLAGS.database,
            data_list, (FLAGS.train_image_size, FLAGS.train_image_size),
            FLAGS.random_scale,
            random_mirror=True,
            random_blur=True,
            random_rotate=FLAGS.random_rotate,
            color_switch=FLAGS.color_switch,
            scale_rate=(FLAGS.scale_min, FLAGS.scale_max))
        for _ in range(FLAGS.gpu_num):
            image_batch, label_batch = reader.dequeue(FLAGS.batch_size)
            list_images.append(image_batch)
            list_labels.append(label_batch)

    # < network >
    model = pspnet_mg.PSPNetMG(
        reader.num_classes,
        mode='train',
        resnet=FLAGS.network,
        bn_mode='frozen' if FLAGS.bn_frozen else 'gather',
        data_format=FLAGS.data_format,
        initializer=FLAGS.initializer,
        fine_tune_filename=FLAGS.fine_tune_filename,
        wd_mode=FLAGS.weight_decay_mode,
        gpu_num=FLAGS.gpu_num,
        float_type=float_type,
        has_aux_loss=FLAGS.has_aux_loss,
        train_like_in_paper=FLAGS.train_like_in_paper,
        structure_in_paper=FLAGS.structure_in_paper,
        new_layer_names=new_layer_names,
        loss_type=FLAGS.loss_type,
        consider_dilated=FLAGS.consider_dilated)
    train_ops = model.build_train_ops(list_images, list_labels)

    # < log dir and model id >
    logdir = LogDir(FLAGS.database, model_id())
    logdir.print_all_info()
    if not os.path.exists(logdir.log_dir):
        print('creating ', logdir.log_dir, '...')
        os.mkdir(logdir.log_dir)
    if not os.path.exists(logdir.database_dir):
        print('creating ', logdir.database_dir, '...')
        os.mkdir(logdir.database_dir)
    if not os.path.exists(logdir.exp_dir):
        print('creating ', logdir.exp_dir, '...')
        os.mkdir(logdir.exp_dir)
    if not os.path.exists(logdir.snapshot_dir):
        print('creating ', logdir.snapshot_dir, '...')
        os.mkdir(logdir.snapshot_dir)

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False,
                            gpu_options=gpu_options,
                            allow_soft_placement=True)
    sess = tf.Session(config=config)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    init = [
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    ]
    sess.run(init)

    # < convert npy to .ckpt >
    step = 0
    if '.npy' in FLAGS.fine_tune_filename:
        # Converts .npy weights to a TF checkpoint, assuming variable names match.
        fine_tune_variables = []
        npy_dict = np.load(FLAGS.fine_tune_filename).item()
        new_layers_names = ['Momentum']
        for v in tf.global_variables():
            if any(elem in v.name for elem in new_layers_names):
                continue

            name = v.name.split(':0')[0]
            if name not in npy_dict:
                continue

            v.load(npy_dict[name], sess)
            fine_tune_variables.append(v)

        saver = tf.train.Saver(var_list=fine_tune_variables)
        saver.save(sess, logdir.snapshot_dir + '/model.ckpt', global_step=0)
        return

    # < load pre-trained model>
    import_variables = tf.trainable_variables()
    if FLAGS.fine_tune_filename is not None and resume_step is None:
        fine_tune_variables = []
        new_layers_names = model.new_layers_names
        new_layers_names.append('Momentum')
        new_layers_names.append('up_sample')
        for v in import_variables:
            if any(elem in v.name for elem in new_layers_names):
                print('< Fine-tuning: not importing %s >' % v.name)
                continue
            fine_tune_variables.append(v)

        loader = tf.train.Saver(var_list=fine_tune_variables, allow_empty=True)
        loader.restore(sess, FLAGS.fine_tune_filename)
        print('< Successfully loaded fine-tune model from %s. >' %
              FLAGS.fine_tune_filename)
    elif resume_step is not None:
        # ./snapshot/model.ckpt-3000
        i_ckpt = logdir.snapshot_dir + '/model.ckpt-%d' % resume_step

        loader = tf.train.Saver(max_to_keep=0)
        loader.restore(sess, i_ckpt)

        step = resume_step
        print('< Successfully loaded model from %s at step=%s. >' %
              (i_ckpt, resume_step))
    else:
        print('< No model imported. >')

    f_log = open(logdir.exp_dir + '/' + str(datetime.datetime.now()) + '.txt',
                 'w')
    f_log.write('step,loss,precision,wd\n')
    f_log.write(sorted_str_dict(FLAGS.__dict__) + '\n')

    print('\n< training process begins >\n')
    average_loss = 0.0
    show_period = 20
    snapshot = FLAGS.snapshot
    max_iter = FLAGS.train_max_iter
    lrn_rate = FLAGS.lrn_rate

    lr_step = []
    if FLAGS.lr_step is not None:
        temps = FLAGS.lr_step.split(',')
        for t in temps:
            lr_step.append(int(t))

    saver = tf.train.Saver(max_to_keep=2)
    t0 = None
    wd_rate = FLAGS.weight_decay_rate
    wd_rate2 = FLAGS.weight_decay_rate2

    if FLAGS.save_first_iteration == 1:
        saver.save(sess, logdir.snapshot_dir + '/model.ckpt', global_step=step)

    has_nan = False
    while step < max_iter + 1:
        if FLAGS.poly_lr == 1:
            lrn_rate = ((1 - 1.0 * step / max_iter)**0.9) * FLAGS.lrn_rate

        step += 1
        if len(lr_step) > 0 and step == lr_step[0]:
            lrn_rate *= FLAGS.step_size
            lr_step.remove(step)

        _, loss, wd, precision = sess.run(
            [train_ops, model.loss, model.wd, model.precision_op],
            feed_dict={
                model.lrn_rate_ph: lrn_rate,
                model.wd_rate_ph: wd_rate,
                model.wd_rate2_ph: wd_rate2
            })

        if math.isnan(loss) or math.isnan(wd):
            print('\nloss or weight norm is NaN. Training stopped!\n')
            has_nan = True
            break

        average_loss += loss

        if step % snapshot == 0:
            saver.save(sess,
                       logdir.snapshot_dir + '/model.ckpt',
                       global_step=step)
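            # re-initializing the local variables resets the streaming metric
            # accumulators, so the reported precision starts fresh after each snapshot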
            sess.run([tf.local_variables_initializer()])

        if step % show_period == 0:
            left_hours = 0

            if t0 is not None:
                delta_t = (datetime.datetime.now() - t0).total_seconds()
                left_time = (max_iter - step) / show_period * delta_t
                left_hours = left_time / 3600.0

            t0 = datetime.datetime.now()
            average_loss /= show_period

            f_log.write('%d,%f,%f,%f\n' % (step, average_loss, precision, wd))
            f_log.flush()

            print('%s %s] Step %s, lr = %f, wd_rate = %f, wd_rate_2 = %f ' \
                  % (str(datetime.datetime.now()), str(os.getpid()), step, lrn_rate, wd_rate, wd_rate2))
            print('\t loss = %.4f, precision = %.4f, wd = %.4f' %
                  (average_loss, precision, wd))
            print('\t estimated time left: %.1f hours. %d/%d' %
                  (left_hours, step, max_iter))

            average_loss = 0.0

    coord.request_stop()
    coord.join(threads)

    return f_log, logdir, has_nan  # f_log and logdir returned for eval.
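
When FLAGS.poly_lr == 1, the loop above anneals the learning rate with the standard polynomial ("poly") schedule. Isolated as a helper for clarity (the function name poly_lr is ours; the numbers in the comment are just a worked example):

def poly_lr(base_lr, step, max_iter, power=0.9):
    # lr = base_lr * (1 - step / max_iter) ** power
    return base_lr * (1.0 - float(step) / max_iter) ** power

# e.g. with base_lr = 0.01 and max_iter = 90000:
#   poly_lr(0.01, 0, 90000)     -> 0.01
#   poly_lr(0.01, 45000, 90000) -> ~0.00536
#   poly_lr(0.01, 90000, 90000) -> 0.0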
Example 4
def train(resume_step=None):
    global_step = tf.get_variable('global_step', [], dtype=tf.int64,
                                  initializer=tf.constant_initializer(0), trainable=False)
    image_size = FLAGS.train_image_size

    print('================', end='')
    if FLAGS.data_type == 16:
        print('using tf.float16 =====================')
        data_type = tf.float16
        print('float16 cannot be used yet: with tf.nn.batch_normalization '
              '(or fused BN) the loss becomes NaN; the cause is unknown.')
    else:
        print('using tf.float32 =====================')
        data_type = tf.float32

    if FLAGS.database == 'CityScapes':
        from database.cityscapes_reader import CityScapesReader as ImageReader
        num_classes = 19
        data_list = FLAGS.subsets_for_training.split(',')
        if len(data_list) < 1:
            data_list = ['train']
    else:
        print("Unknown database %s" % FLAGS.database)
        return
    print(data_list)
    images = []
    labels = []

    with tf.device('/cpu:0'):
        reader = ImageReader(
            FLAGS.server,
            data_list,
            (image_size, image_size),
            FLAGS.random_scale,
            random_mirror=True,
            random_blur=True,
            random_rotate=FLAGS.random_rotate,
            color_switch=FLAGS.color_switch,
            scale_rate=(FLAGS.scale_min, FLAGS.scale_max))

    print('================ Database Info ================')
    for i in range(FLAGS.gpu_num):
        with tf.device('/cpu:0'):
            image_batch, label_batch = reader.dequeue(FLAGS.batch_size)
            images.append(image_batch)
            labels.append(label_batch)

    wd_rate_ph = tf.placeholder(data_type, shape=())
    wd_rate2_ph = tf.placeholder(data_type, shape=())
    lrn_rate_ph = tf.placeholder(data_type, shape=())

    new_layer_names = FLAGS.new_layer_names
    if FLAGS.new_layer_names is not None:
        new_layer_names = new_layer_names.split(',')
    assert 'pspnet' in FLAGS.network

    resnet = 'resnet_v1_101'
    PSPModel = pspnet_mg.PSPNetMG

    with tf.variable_scope(resnet):
        model = PSPModel(num_classes, lrn_rate_ph, wd_rate_ph, wd_rate2_ph,
                         mode='train', bn_epsilon=FLAGS.epsilon, resnet=resnet,
                         norm_only=FLAGS.norm_only,
                         initializer=FLAGS.initializer,
                         fix_blocks=FLAGS.fix_blocks,
                         fine_tune_filename=FLAGS.fine_tune_filename,
                         bn_ema=FLAGS.ema_decay,
                         bn_frozen=FLAGS.bn_frozen,
                         wd_mode=FLAGS.weight_decay_mode,
                         fisher_filename=FLAGS.fisher_filename,
                         gpu_num=FLAGS.gpu_num,
                         float_type=data_type,
                         fisher_epsilon=FLAGS.fisher_epsilon,
                         has_aux_loss=FLAGS.has_aux_loss,
                         train_like_in_paper=FLAGS.train_like_in_paper,
                         structure_in_paper=FLAGS.structure_in_paper,
                         new_layer_names=new_layer_names,
                         loss_type=FLAGS.loss_type)
        model.inference(images)
        model.build_train_op(labels)

    names = []
    num_params = 0
    for v in tf.trainable_variables():
        # print v.name
        names.append(v.name)
        num = 1
        for i in v.get_shape().as_list():
            num *= i
        num_params += num
    print "Trainable parameters' num: %d" % num_params

    print('iou precision shape: ', model.predictions.get_shape(),
          labels[0].get_shape())
    pred = tf.reshape(model.predictions, [-1])
    gt = tf.reshape(labels[0], [-1])
    indices = tf.squeeze(tf.where(tf.less_equal(gt, num_classes - 1)), 1)
    gt = tf.cast(tf.gather(gt, indices), tf.int32)
    pred = tf.gather(pred, indices)
    precision_op, update_op = tf.contrib.metrics.streaming_mean_iou(pred, gt, num_classes=num_classes)
    # ========================= end of building model ================================

    step = 0
    logdir = LogDir(FLAGS.database, FLAGS.log_dir, FLAGS.weight_decay_mode)
    logdir.print_all_info()
    if not os.path.exists(logdir.log_dir):
        print('creating ', logdir.log_dir, '...')
        os.mkdir(logdir.log_dir)
    if not os.path.exists(logdir.database_dir):
        print('creating ', logdir.database_dir, '...')
        os.mkdir(logdir.database_dir)
    if not os.path.exists(logdir.exp_dir):
        print('creating ', logdir.exp_dir, '...')
        os.mkdir(logdir.exp_dir)
    if not os.path.exists(logdir.snapshot_dir):
        print('creating ', logdir.snapshot_dir, '...')
        os.mkdir(logdir.snapshot_dir)

    init = [tf.global_variables_initializer(), tf.local_variables_initializer()]

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options, allow_soft_placement=True)
    sess = tf.Session(config=config)
    sess.run(init)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    import_variables = tf.trainable_variables()
    if FLAGS.fix_blocks > 0 or FLAGS.bn_frozen > 0:
        import_variables = tf.global_variables()

    if '.npy' in FLAGS.fine_tune_filename:
        # Converts .npy weights to a TF checkpoint, assuming variable names match.
        fine_tune_variables = []
        npy_dict = np.load(FLAGS.fine_tune_filename).item()
        new_layers_names = ['Momentum']
        for v in tf.global_variables():
            print('=====Saving initial snapshot process:', end=' ')
            if any(elem in v.name for elem in new_layers_names):
                print('not importing', v.name)
                continue

            name = v.name.split(':0')[0]
            if name not in npy_dict:
                print('not found:', v.name)
                continue

            v.load(npy_dict[name], sess)
            print('saving', v.name)
            fine_tune_variables.append(v)

        saver = tf.train.Saver(var_list=fine_tune_variables)
        saver.save(sess, logdir.snapshot_dir + '/model.ckpt', global_step=0)

        return

    if FLAGS.fine_tune_filename is not None and resume_step is None:
        fine_tune_variables = []
        new_layers_names = model.new_layers_names
        new_layers_names.append('Momentum')
        for v in import_variables:
            if any(elem in v.name for elem in new_layers_names):
                print('=====Fine-tuning: not importing %s' % v.name)
                continue
            fine_tune_variables.append(v)

        loader = tf.train.Saver(var_list=fine_tune_variables)
        loader.restore(sess, FLAGS.fine_tune_filename)
        print('=====Successfully loaded fine-tune model from %s.' % FLAGS.fine_tune_filename)
    elif resume_step is not None:
        # ./snapshot/model.ckpt-3000
        i_ckpt = logdir.snapshot_dir + '/model.ckpt-%d' % resume_step

        loader = tf.train.Saver(max_to_keep=0)
        loader.restore(sess, i_ckpt)

        step = resume_step
        print('=====Successfully loaded model from %s at step=%s.' % (i_ckpt, resume_step))
    else:
        print('=====No model imported.')

    print('=========================== training process begins =================================')
    f_log = open(logdir.exp_dir + '/' + str(datetime.datetime.now()) + '.txt', 'w')
    f_log.write('step,loss,precision,wd\n')
    f_log.write(sorted_str_dict(FLAGS.__dict__) + '\n')

    average_loss = 0.0
    show_period = 20
    snapshot = FLAGS.snapshot
    max_iter = FLAGS.train_max_iter
    lrn_rate = FLAGS.lrn_rate

    lr_step = []
    if FLAGS.lr_step is not None:
        temps = FLAGS.lr_step.split(',')
        for t in temps:
            lr_step.append(int(t))

    # fine_tune_variables = []
    # for v in tf.global_variables():
    #     if 'Momentum' in v.name:
    #         continue
    #     print '=====Saving initial snapshot process: saving %s' % v.name
    #     fine_tune_variables.append(v)
    #
    # saver = tf.train.Saver(var_list=fine_tune_variables)
    # saver.save(sess, logdir.snapshot_dir + '/model.ckpt', global_step=0)

    saver = tf.train.Saver(max_to_keep=2)
    t0 = None
    wd_rate = FLAGS.weight_decay_rate
    wd_rate2 = FLAGS.weight_decay_rate2

    if FLAGS.save_first_iteration == 1:
        saver.save(sess, logdir.snapshot_dir + '/model.ckpt', global_step=step)

    has_nan = False
    while step < max_iter + 1:
        if FLAGS.poly_lr == 1:
            lrn_rate = ((1-1.0*step/max_iter)**0.9) * FLAGS.lrn_rate

        step += 1
        if len(lr_step) > 0 and step == lr_step[0]:
            lrn_rate *= FLAGS.step_size
            lr_step.remove(step)

        _, loss, wd, update, precision = sess.run(
            [model.train_op, model.loss, model.wd, update_op, precision_op],
            feed_dict={
                lrn_rate_ph: lrn_rate,
                wd_rate_ph: wd_rate,
                wd_rate2_ph: wd_rate2
            })

        if math.isnan(loss) or math.isnan(wd):
            print('loss or weight norm is NaN. Training stopped!')
            has_nan = True
            break

        average_loss += loss

        if step % snapshot == 0:
            saver.save(sess, logdir.snapshot_dir + '/model.ckpt', global_step=step)
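            # re-initializing the local variables resets the streaming mean-IoU
            # accumulators, so the reported precision starts fresh after each snapshot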
            sess.run([tf.local_variables_initializer()])

        if step % show_period == 0:
            left_hours = 0

            if t0 is not None:
                delta_t = (datetime.datetime.now() - t0).total_seconds()
                left_time = (max_iter - step) / show_period * delta_t
                left_hours = left_time/3600.0

            t0 = datetime.datetime.now()

            average_loss /= show_period

            if step == 0:
                average_loss *= show_period

            f_log.write('%d,%f,%f,%f\n' % (step, average_loss, precision, wd))
            f_log.flush()

            print('%s %s] Step %s, lr = %f, wd_rate = %f, wd_rate_2 = %f '
                  % (str(datetime.datetime.now()), str(os.getpid()), step,
                     lrn_rate, wd_rate, wd_rate2))
            print('\t loss = %.4f, precision = %.4f, wd = %.4f' %
                  (average_loss, precision, wd))
            print('\t estimated time left: %.1f hours. %d/%d' %
                  (left_hours, step, max_iter))

            average_loss = 0.0

    coord.request_stop()
    coord.join(threads)

    return f_log, logdir, has_nan  # f_log and logdir returned for eval.
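
Example 4 computes a streaming mean IoU only over valid pixels: labels greater than num_classes - 1 (the ignore label) are masked out before the metric update. The same masking extracted into a self-contained sketch (the helper name is ours; tf.metrics.mean_iou is the non-contrib counterpart of tf.contrib.metrics.streaming_mean_iou, with the argument order (labels, predictions)):

import tensorflow as tf

def masked_mean_iou(predictions, labels, num_classes):
    # flatten, then keep only pixels whose label is a valid class id
    pred = tf.reshape(predictions, [-1])
    gt = tf.reshape(labels, [-1])
    valid = tf.squeeze(tf.where(tf.less_equal(gt, num_classes - 1)), 1)
    gt = tf.cast(tf.gather(gt, valid), tf.int32)
    pred = tf.gather(pred, valid)
    # returns (mean_iou, update_op): run update_op every step and read
    # mean_iou to get the score accumulated so far
    return tf.metrics.mean_iou(gt, pred, num_classes)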
Example 5
def train(resume_step=None):
    global_step = tf.get_variable('global_step', [],
                                  dtype=tf.int64,
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    print('================', end='')
    if FLAGS.data_type == 16:
        print('using tf.float16 =====================')
        data_type = tf.float16
    else:
        print('using tf.float32 =====================')
        data_type = tf.float32

    wd_rate_ph = tf.placeholder(data_type, shape=())
    wd_rate2_ph = tf.placeholder(data_type, shape=())
    lrn_rate_ph = tf.placeholder(data_type, shape=())

    with tf.variable_scope(FLAGS.resnet):
        images, labels, num_classes = dataset_reader.build_input(
            FLAGS.batch_size,
            'train',
            examples_per_class=FLAGS.examples_per_class,
            dataset=FLAGS.database,
            resize_image=FLAGS.resize_image,
            color_switch=FLAGS.color_switch,
            blur=FLAGS.blur)
        model = resnet.ResNet(num_classes,
                              lrn_rate_ph,
                              wd_rate_ph,
                              wd_rate2_ph,
                              optimizer=FLAGS.optimizer,
                              mode='train',
                              bn_epsilon=FLAGS.epsilon,
                              resnet=FLAGS.resnet,
                              norm_only=FLAGS.norm_only,
                              initializer=FLAGS.initializer,
                              fix_blocks=FLAGS.fix_blocks,
                              fine_tune_filename=FLAGS.fine_tune_filename,
                              bn_ema=FLAGS.ema_decay,
                              wd_mode=FLAGS.weight_decay_mode,
                              fisher_filename=FLAGS.fisher_filename,
                              gpu_num=FLAGS.gpu_num,
                              fisher_epsilon=FLAGS.fisher_epsilon,
                              float_type=data_type,
                              separate_regularization=FLAGS.separate_reg)
        model.inference(images)
        model.build_train_op(labels)

    names = []
    num_params = 0
    for v in tf.trainable_variables():
        # print v.name
        names.append(v.name)
        num = 1
        for i in v.get_shape().as_list():
            num *= i
        num_params += num
    print("Trainable parameters' num: %d" % num_params)

    precisions = tf.nn.in_top_k(tf.cast(model.predictions, tf.float32),
                                model.labels, 1)
    precision_op = tf.reduce_mean(tf.cast(precisions, tf.float32))
    # ========================= end of building model ================================

    step = 0
    saver = tf.train.Saver(max_to_keep=0)
    logdir = LogDir(FLAGS.database, FLAGS.log_dir, FLAGS.weight_decay_mode)
    logdir.print_all_info()
    if not os.path.exists(logdir.log_dir):
        print('creating ', logdir.log_dir, '...')
        os.mkdir(logdir.log_dir)
    if not os.path.exists(logdir.database_dir):
        print('creating ', logdir.database_dir, '...')
        os.mkdir(logdir.database_dir)
    if not os.path.exists(logdir.exp_dir):
        print('creating ', logdir.exp_dir, '...')
        os.mkdir(logdir.exp_dir)
    if not os.path.exists(logdir.snapshot_dir):
        print('creating ', logdir.snapshot_dir, '...')
        os.mkdir(logdir.snapshot_dir)

    init = [
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    ]

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False,
                            gpu_options=gpu_options)
    sess = tf.Session(config=config)
    sess.run(init)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    import_variables = tf.trainable_variables()
    if FLAGS.fix_blocks > 0:
        import_variables = tf.global_variables()

    if FLAGS.fine_tune_filename is not None and resume_step is None:
        fine_tune_variables = []
        new_layers_names = model.new_layers_names
        new_layers_names.append('Momentum')
        for v in import_variables:
            if any(elem in v.name for elem in new_layers_names):
                print('not loading %s' % v.name)
                continue
            fine_tune_variables.append(v)

        loader = tf.train.Saver(var_list=fine_tune_variables)
        loader.restore(sess, FLAGS.fine_tune_filename)
        print('Successfully loaded fine-tune model from %s.' %
              FLAGS.fine_tune_filename)
    elif resume_step is not None:
        # ./snapshot/model.ckpt-3000
        i_ckpt = logdir.snapshot_dir + '/model.ckpt-%d' % resume_step
        saver.restore(sess, i_ckpt)

        step = resume_step
        print('Successfully loaded model from %s at step=%s.' %
              (i_ckpt, resume_step))
    else:
        print('No model imported.')

    print(
        '=========================== training process begins ================================='
    )
    f_log = open(logdir.exp_dir + '/' + str(datetime.datetime.now()) + '.txt',
                 'w')
    f_log.write('step,loss,precision,wd\n')
    f_log.write(sorted_str_dict(FLAGS.__dict__) + '\n')

    average_loss = 0.0
    average_precision = 0.0
    show_period = 20
    snapshot = FLAGS.snapshot
    max_iter = FLAGS.train_max_iter
    lrn_rate = FLAGS.lrn_rate

    lr_step = []
    if FLAGS.lr_step is not None:
        temps = FLAGS.lr_step.split(',')
        for t in temps:
            lr_step.append(int(t))

    t0 = None
    wd_rate = FLAGS.weight_decay_rate
    wd_rate2 = FLAGS.weight_decay_rate2
    while step < max_iter + 1:
        step += 1

        if FLAGS.lr_policy == 'step':
            if len(lr_step) > 0 and step == lr_step[0]:
                lrn_rate *= FLAGS.step_size
                lr_step.remove(step)
        elif FLAGS.lr_policy == 'poly':
            lrn_rate = ((1 - 1.0 *
                         (step - 1) / max_iter)**0.9) * FLAGS.lrn_rate
        elif FLAGS.lr_policy == 'linear':
            lrn_rate = FLAGS.lrn_rate / step
        else:
            lrn_rate = FLAGS.lrn_rate

        _, loss, wd, precision = sess.run(
            [model.train_op, model.loss, model.wd, precision_op],
            feed_dict={
                lrn_rate_ph: lrn_rate,
                wd_rate_ph: wd_rate,
                wd_rate2_ph: wd_rate2
            })

        average_loss += loss
        average_precision += precision

        if FLAGS.save_first_iteration == 1 or step % snapshot == 0:
            saver.save(sess,
                       logdir.snapshot_dir + '/model.ckpt',
                       global_step=step)

        if step % show_period == 0:
            left_hours = 0

            if t0 is not None:
                delta_t = (datetime.datetime.now() - t0).total_seconds()
                left_time = (max_iter - step) / show_period * delta_t
                left_hours = left_time / 3600.0

            t0 = datetime.datetime.now()

            average_loss /= show_period
            average_precision /= show_period

            if step == 0:
                average_loss *= show_period
                average_precision *= show_period

            f_log.write('%d,%f,%f,%f\n' %
                        (step, average_loss, average_precision, wd))
            f_log.flush()

            print('%s %s] Step %s, lr = %f, wd_rate = %f, wd_rate_2 = %f ' \
                  % (str(datetime.datetime.now()), str(os.getpid()), step, lrn_rate, wd_rate, wd_rate2))
            print('\t loss = %.4f, precision = %.4f, wd = %.4f' %
                  (average_loss, average_precision, wd))
            print('\t estimated time left: %.1f hours. %d/%d' %
                  (left_hours, step, max_iter))

            average_loss = 0.0
            average_precision = 0.0

    coord.request_stop()
    coord.join(threads)

    return f_log, logdir  # f_log and logdir returned for eval.
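
Example 5 picks the learning rate each step from one of four policies. The same selection as a standalone function (the function name is ours; boundaries is the parsed FLAGS.lr_step list and is consumed in place, exactly as in the loop above):

def select_lr(policy, base_lr, current_lr, step, max_iter, boundaries, step_size):
    if policy == 'step':
        # multiply the current rate by step_size at each listed boundary
        if boundaries and step == boundaries[0]:
            boundaries.pop(0)
            return current_lr * step_size
        return current_lr
    elif policy == 'poly':
        return base_lr * (1.0 - float(step - 1) / max_iter) ** 0.9
    elif policy == 'linear':
        return base_lr / step  # 1/t decay
    return base_lr  # constant rate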
Example 6
def train(resume_step=None):
    global_step = tf.get_variable('global_step', [],
                                  dtype=tf.int64,
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)
    image_size = FLAGS.train_image_size

    print('================', end='')
    if FLAGS.data_type == 16:
        print('using tf.float16 =====================')
        data_type = tf.float16
        print('float16 cannot be used yet: with tf.nn.batch_normalization '
              '(or fused BN) the loss becomes NaN; the cause is unknown.')
    else:
        print('using tf.float32 =====================')
        data_type = tf.float32

    data_list = FLAGS.subsets_for_training.split(',')
    if len(data_list) < 1:
        data_list = ['train']
    print(data_list)

    images = []
    labels = []

    with tf.device('/cpu:0'):
        IMG_MEAN = np.array((103.939, 116.779, 123.68), dtype=np.float32)
        coord = tf.train.Coordinator()
        reader = ImageReader('./data/train', 'train.txt', '480,480', 'true',
                             'true', 255, IMG_MEAN, coord)

    print('================ Database Info ================')
    for i in range(FLAGS.gpu_num):
        with tf.device('/cpu:0'):
            image_batch, label_batch = reader.dequeue(FLAGS.batch_size)
            images.append(image_batch)
            labels.append(label_batch)

    wd_rate_ph = tf.placeholder(data_type, shape=())
    wd_rate2_ph = tf.placeholder(data_type, shape=())
    lrn_rate_ph = tf.placeholder(data_type, shape=())

    resnet = 'resnet_v1_50'
    # NOTE: num_classes is used below but never defined in this snippet; set it
    # to match the dataset (e.g. 19 for Cityscapes, 21 for PASCAL VOC). We assume:
    num_classes = 19

    ResnetModel = resnet_v1_50.ResNet
    with tf.variable_scope(resnet):
        model = ResnetModel(num_classes,
                            lrn_rate_ph,
                            wd_rate_ph,
                            wd_rate2_ph,
                            mode='train',
                            bn_epsilon=FLAGS.epsilon,
                            norm_only=FLAGS.norm_only,
                            initializer=FLAGS.initializer,
                            fix_blocks=FLAGS.fix_blocks,
                            fine_tune_filename=FLAGS.fine_tune_filename,
                            wd_mode=FLAGS.weight_decay_mode,
                            fisher_filename=FLAGS.fisher_filename)
        model.inference(images)
        model.build_train_op(labels)

    print('iou precision shape: ', model.predictions.get_shape(),
          labels[0].get_shape())
    pred = tf.reshape(model.predictions, [-1])
    gt = tf.reshape(labels[0], [-1])
    indices = tf.squeeze(tf.where(tf.less_equal(gt, num_classes - 1)), 1)
    gt = tf.cast(tf.gather(gt, indices), tf.int32)
    pred = tf.gather(pred, indices)
    precision_op, update_op = tf.contrib.metrics.streaming_mean_iou(
        pred, gt, num_classes=num_classes)
    # ========================= end of building model ================================

    step = 0
    logdir = LogDir(FLAGS.database, FLAGS.log_dir, FLAGS.weight_decay_mode)
    logdir.print_all_info()
    if not os.path.exists(logdir.log_dir):
        print('creating ', logdir.log_dir, '...')
        os.mkdir(logdir.log_dir)
    if not os.path.exists(logdir.database_dir):
        print('creating ', logdir.database_dir, '...')
        os.mkdir(logdir.database_dir)
    if not os.path.exists(logdir.exp_dir):
        print('creating ', logdir.exp_dir, '...')
        os.mkdir(logdir.exp_dir)
    if not os.path.exists(logdir.snapshot_dir):
        print('creating ', logdir.snapshot_dir, '...')
        os.mkdir(logdir.snapshot_dir)

    init = [
        tf.global_variables_initializer(),
        tf.local_variables_initializer()
    ]

    gpu_options = tf.GPUOptions(allow_growth=False)
    config = tf.ConfigProto(log_device_placement=False,
                            gpu_options=gpu_options,
                            allow_soft_placement=True)
    sess = tf.Session(config=config)
    sess.run(init)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    import_variables = tf.trainable_variables()
    if FLAGS.fix_blocks > 0 or FLAGS.bn_frozen > 0:
        import_variables = tf.global_variables()

    if FLAGS.fine_tune_filename is not None and resume_step is None:
        fine_tune_variables = []
        new_layers_names = model.new_layers_names
        new_layers_names.append('Momentum')
        new_layers_names.append('up_sample')
        for v in import_variables:
            if any(elem in v.name for elem in new_layers_names):
                print('=====Fine-tuning: not importing %s' % v.name)
                continue
            fine_tune_variables.append(v)

        loader = tf.train.Saver(var_list=fine_tune_variables, allow_empty=True)
        loader.restore(sess, FLAGS.fine_tune_filename)
        print('=====Successfully loaded fine-tune model from %s.' %
              FLAGS.fine_tune_filename)
    elif resume_step is not None:
        # e.g. ./model/model.ckpt-3000
        i_ckpt = './model/model.ckpt-%d' % resume_step

        loader = tf.train.Saver(max_to_keep=0)
        loader.restore(sess, i_ckpt)

        step = resume_step
        print('=====Successfully loaded model from %s at step=%s.' %
              (i_ckpt, resume_step))
    else:
        print('=====No model imported.')

    print('=========================== training process begins =================================')
    f_log = open(logdir.exp_dir + '/' + str(datetime.datetime.now()) + '.txt',
                 'w')
    f_log.write('step,loss,precision,wd\n')
    f_log.write(sorted_str_dict(FLAGS.__dict__) + '\n')

    average_loss = 0.0
    show_period = 20
    snapshot = FLAGS.snapshot
    max_iter = FLAGS.train_max_iter
    lrn_rate = FLAGS.lrn_rate

    lr_step = []
    if FLAGS.lr_step is not None:
        temps = FLAGS.lr_step.split(',')
        for t in temps:
            lr_step.append(int(t))

    saver = tf.train.Saver(max_to_keep=2)
    t0 = None
    wd_rate = FLAGS.weight_decay_rate
    wd_rate2 = FLAGS.weight_decay_rate2

    if FLAGS.save_first_iteration == 1:
        saver.save(sess, logdir.snapshot_dir + '/model.ckpt', global_step=step)

    has_nan = False
    while step < max_iter + 1:
        if FLAGS.poly_lr == 1:
            lrn_rate = ((1 - 1.0 * step / max_iter)**0.9) * FLAGS.lrn_rate

        step += 1
        if len(lr_step) > 0 and step == lr_step[0]:
            lrn_rate *= FLAGS.step_size
            lr_step.remove(step)

        _, loss, wd, update, precision = sess.run(
            [model.train_op, model.loss, model.wd, update_op, precision_op],
            feed_dict={
                lrn_rate_ph: lrn_rate,
                wd_rate_ph: wd_rate,
                wd_rate2_ph: wd_rate2
            })

        average_loss += loss

        if step % snapshot == 0:
            saver.save(sess,
                       logdir.snapshot_dir + '/model.ckpt',
                       global_step=step)
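            # re-initializing the local variables resets the streaming mean-IoU
            # accumulators, so the reported precision starts fresh after each snapshot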
            sess.run([tf.local_variables_initializer()])

        if step % show_period == 0:
            left_hours = 0

            if t0 is not None:
                delta_t = (datetime.datetime.now() - t0).total_seconds()
                left_time = (max_iter - step) / show_period * delta_t
                left_hours = left_time / 3600.0

            t0 = datetime.datetime.now()

            average_loss /= show_period

            if step == 0:
                average_loss *= show_period

            f_log.write('%d,%f,%f,%f\n' % (step, average_loss, precision, wd))
            f_log.flush()

            print('%s %s] Step %s, lr = %f, wd_rate = %f, wd_rate_2 = %f '
                  % (str(datetime.datetime.now()), str(os.getpid()), step,
                     lrn_rate, wd_rate, wd_rate2))
            print('\t loss = %.4f, precision = %.4f, wd = %.4f' %
                  (average_loss, precision, wd))
            print('\t estimated time left: %.1f hours. %d/%d' %
                  (left_hours, step, max_iter))

            average_loss = 0.0

    coord.request_stop()
    coord.join(threads)

    return f_log, logdir, has_nan  # f_log and logdir returned for eval.
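
The time-left estimate printed in these training loops scales the wall-clock duration of the last reporting period by the number of periods remaining. As a standalone helper (the function name is ours):

import datetime

def estimate_hours_left(t_prev, step, max_iter, show_period):
    # t_prev is the timestamp of the previous report, or None before the first one
    if t_prev is None:
        return 0.0
    period_seconds = (datetime.datetime.now() - t_prev).total_seconds()
    periods_left = (max_iter - step) / float(show_period)
    return periods_left * period_seconds / 3600.0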