Example #1
def train(dataset_train, dataset_test, caffemodel=''):
    print('train() called')
    V = config.num_views
    batch_size = config.batch_size

    dataset_train.shuffle()
    data_size = dataset_train.size()

    print('training size:', data_size)

    with tf.Graph().as_default():
        with tf.device('/gpu:0'):

            tf_config = tf.ConfigProto(log_device_placement=False)
            tf_config.gpu_options.allow_growth = True
            tf_config.allow_soft_placement = True

            global_step = tf.Variable(0, trainable=False)

            # placeholders for graph input
            view_ = tf.placeholder('float32',
                                   shape=(None, V, 227, 227, 3),
                                   name='im0')
            y_ = tf.placeholder('int64', shape=(None,), name='y')
            keep_prob_ = tf.placeholder('float32')

            # graph outputs
            fc8 = model.inference_multiview(view_, config.num_classes,
                                            keep_prob_)
            loss = model.loss(fc8, y_)
            train_op = model.train(loss, global_step, data_size)
            prediction = model.classify(fc8)
            placeholders = [view_, y_, keep_prob_, prediction, loss]
            validation_loss = tf.placeholder('float32',
                                             shape=(),
                                             name='validation_loss')
            validation_acc = tf.placeholder('float32',
                                            shape=(),
                                            name='validation_accuracy')

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)

            init_op = tf.global_variables_initializer()
            sess = tf.Session(config=tf_config)
            weights = config.weights
            if weights == -1:
                startepoch = 0
                if caffemodel:
                    sess.run(init_op)
                    model.load_alexnet_to_mvcnn(sess, caffemodel)
                    print('loaded pretrained caffemodel:', caffemodel)
                else:
                    sess.run(init_op)
                    print('init_op done')
            else:
                ld = config.log_dir
                startepoch = weights + 1
                ckptfile = os.path.join(ld,
                                        config.snapshot_prefix + str(weights))

                saver.restore(sess, ckptfile)
                print('restore variables done')

            total_seen = 0
            total_correct = 0
            total_loss = 0

            step = 0
            begin = startepoch
            end = config.max_epoch + startepoch
            for epoch in range(begin, end + 1):
                acc, eval_loss, predictions, labels = _test(
                    dataset_test, config, sess, placeholders)
                print('epoch %d: step %d, validation loss=%.4f, acc=%f' %
                      (epoch, step, eval_loss, acc * 100.))

                LOSS_LOGGER.log(eval_loss, epoch, "eval_loss")
                ACC_LOGGER.log(acc, epoch, "eval_accuracy")
                ACC_LOGGER.save(config.log_dir)
                LOSS_LOGGER.save(config.log_dir)
                ACC_LOGGER.plot(dest=config.log_dir)
                LOSS_LOGGER.plot(dest=config.log_dir)

                for batch_x, batch_y in dataset_train.batches(batch_size):
                    step += 1

                    feed_dict = {view_: batch_x, y_: batch_y, keep_prob_: 0.5}

                    _, pred, loss_value = sess.run(
                        [train_op, prediction, loss],
                        feed_dict=feed_dict)

                    total_loss += loss_value
                    correct = np.sum(pred == batch_y)
                    total_correct += correct
                    total_seen += batch_size

                    assert not np.isnan(
                        loss_value), 'Model diverged with loss = NaN'

                    if step % max(config.train_log_frq // config.batch_size,
                                  1) == 0:
                        acc_ = total_correct / float(total_seen)
                        ACC_LOGGER.log(acc_, epoch, "train_accuracy")
                        loss_ = total_loss / float(total_seen / batch_size)
                        LOSS_LOGGER.log(loss_, epoch, "train_loss")
                        print('epoch %d step %d, loss=%.2f, acc=%.2f' %
                              (epoch, step, loss_, acc_))
                        total_seen = 0
                        total_correct = 0
                        total_loss = 0

                if epoch % config.save_period == 0 or epoch == end:
                    checkpoint_path = os.path.join(
                        config.log_dir, config.snapshot_prefix + str(epoch))
                    saver.save(sess, checkpoint_path)
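
Example #1 calls a _test() helper that is not shown in the snippet. Below is a minimal sketch of what it might look like, assuming the test dataset exposes the same batches() iterator as dataset_train and that placeholders is the list [view_, y_, keep_prob_, prediction, loss] built above:

import numpy as np

def _test(dataset, config, sess, placeholders):
    view_, y_, keep_prob_, prediction, loss = placeholders
    losses, predictions, labels = [], [], []
    for batch_x, batch_y in dataset.batches(config.batch_size):
        # keep_prob_ = 1.0 disables dropout during evaluation
        feed_dict = {view_: batch_x, y_: batch_y, keep_prob_: 1.0}
        loss_value, pred = sess.run([loss, prediction], feed_dict=feed_dict)
        losses.append(loss_value)
        predictions.extend(pred)
        labels.extend(batch_y)
    acc = np.mean(np.array(predictions) == np.array(labels))
    return acc, np.mean(losses), predictions, labels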
Example #2
def train(dataset_train, dataset_val, ckptfile='', caffemodel=''):
    print('train() called')
    is_finetune = bool(ckptfile)
    V = g_.NUM_VIEWS
    batch_size = FLAGS.batch_size

    dataset_train.shuffle()
    dataset_val.shuffle()
    data_size = dataset_train.size()
    print('training size:', data_size)

    with tf.Graph().as_default():
        startstep = 0 if not is_finetune else int(ckptfile.split('-')[-1])
        global_step = tf.Variable(startstep, trainable=False)

        # placeholders for graph input
        view_ = tf.placeholder('float32',
                               shape=(None, V, 227, 227, 3),
                               name='im0')
        y_ = tf.placeholder('int64', shape=(None,), name='y')
        keep_prob_ = tf.placeholder('float32')

        # graph outputs
        fc8 = model.inference_multiview(view_, g_.NUM_CLASSES, keep_prob_)
        loss = model.loss(fc8, y_)
        train_op = model.train(loss, global_step, data_size)
        prediction = model.classify(fc8)

        # build the summary operation based on the TF collection of summaries
        summary_op = tf.summary.merge_all()

        # must be after merge_all_summaries
        validation_loss = tf.placeholder('float32',
                                         shape=(),
                                         name='validation_loss')
        validation_summary = tf.summary.scalar('validation_loss',
                                               validation_loss)
        validation_acc = tf.placeholder('float32',
                                        shape=(),
                                        name='validation_accuracy')
        validation_acc_summary = tf.summary.scalar('validation_accuracy',
                                                   validation_acc)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=1000)

        init_op = tf.global_variables_initializer()
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
        sess = tf.Session(config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=gpu_options))
        if is_finetune:
            # load checkpoint file
            saver.restore(sess, ckptfile)
            print('restore variables done')
        elif caffemodel:
            # load caffemodel generated with caffe-tensorflow
            sess.run(init_op)
            model.load_alexnet_to_mvcnn(sess, caffemodel)
            print('loaded pretrained caffemodel:', caffemodel)
        else:
            # from scratch
            sess.run(init_op)
            print('init_op done')

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        step = startstep
        for epoch in range(100):
            print('epoch:', epoch)

            for batch_x, batch_y in dataset_train.batches(batch_size):
                step += 1

                start_time = time.time()
                feed_dict = {view_: batch_x, y_: batch_y, keep_prob_: 0.5}

                _, pred, loss_value = sess.run([train_op, prediction, loss],
                                               feed_dict=feed_dict)

                duration = time.time() - start_time

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                # print training information
                if step % 10 == 0 or step - startstep <= 30:
                    sec_per_batch = float(duration)
                    print(
                        '%s: step %d, loss=%.2f (%.1f examples/sec; %.3f sec/batch)'
                        % (datetime.now(), step, loss_value,
                           FLAGS.batch_size / duration, sec_per_batch))

                # validation
                if step % g_.VAL_PERIOD == 0:
                    val_losses = []
                    predictions = np.array([])

                    val_y = []
                    for val_step, (val_batch_x, val_batch_y) in \
                            enumerate(dataset_val.sample_batches(batch_size, g_.VAL_SAMPLE_SIZE)):
                        val_feed_dict = {
                            view_: val_batch_x,
                            y_: val_batch_y,
                            keep_prob_: 1.0
                        }
                        val_loss, pred = sess.run([loss, prediction],
                                                  feed_dict=val_feed_dict)
                        val_losses.append(val_loss)
                        predictions = np.hstack((predictions, pred))
                        val_y.extend(val_batch_y)

                    val_loss = np.mean(val_losses)

                    acc = metrics.accuracy_score(val_y[:predictions.size],
                                                 np.array(predictions))
                    print('%s: step %d, validation loss=%.4f, acc=%f' %
                          (datetime.now(), step, val_loss, acc * 100.))

                    # validation summary
                    val_loss_summ = sess.run(
                        validation_summary,
                        feed_dict={validation_loss: val_loss})
                    val_acc_summ = sess.run(validation_acc_summary,
                                            feed_dict={validation_acc: acc})
                    summary_writer.add_summary(val_loss_summ, step)
                    summary_writer.add_summary(val_acc_summ, step)
                    summary_writer.flush()

                if step % 100 == 0:
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

                if step % g_.SAVE_PERIOD == 0 and step > startstep:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)
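
Example #2 routes Python-side validation scalars into TensorBoard through the validation_loss / validation_acc placeholders and their scalar summary ops. In TF1 the same logging can also be done without adding graph nodes, by building a tf.Summary protobuf directly; a minimal equivalent sketch, reusing val_loss, acc, step, and summary_writer from the snippet:

val_summary = tf.Summary(value=[
    tf.Summary.Value(tag='validation_loss', simple_value=float(val_loss)),
    tf.Summary.Value(tag='validation_accuracy', simple_value=float(acc)),
])
summary_writer.add_summary(val_summary, step)
summary_writer.flush()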
Example #3
def train(cfg, dataset_train, dataset_val, ckptfile='', caffemodel=''):
    print('train() called')
    is_finetune = bool(ckptfile)
    V = g_.NUM_VIEWS
    batch_size = FLAGS.batch_size

    data_size, num_batch = dataset_train.get_len()
    data_size_test, num_batch_test = dataset_val.get_len()
    print('train size:', data_size)
    print('test size:', data_size_test)

    best_eval_acc = 0

    with tf.Graph().as_default():
        startstep = 0
        global_step = tf.Variable(startstep, trainable=False)

        # placeholders for graph input
        view_ = tf.placeholder('float32', shape=(None, V, 224, 224, 3), name='im0')
        y_ = tf.placeholder('int64', shape=(None,), name='y')
        is_training_pl = tf.placeholder(tf.bool, shape=())
        bn_decay = get_bn_decay(startstep)

        # graph outputs
        fc8 = model.inference_multiview(view_, g_.NUM_CLASSES, is_training_pl, bn_decay=bn_decay)
        loss = model.loss(fc8, y_)
        train_op = model.train(loss, global_step, data_size)
        prediction = model.classify(fc8)

        # build the summary operation based on the TF collection of summaries
        summary_op = tf.summary.merge_all()

        # must be after merge_all_summaries
        validation_loss = tf.placeholder('float32', shape=(), name='validation_loss')
        validation_summary = tf.summary.scalar('validation_loss', validation_loss)
        validation_acc = tf.placeholder('float32', shape=(), name='validation_accuracy')
        validation_acc_summary = tf.summary.scalar('validation_accuracy', validation_acc)

        saver = tf.train.Saver()

        init_op = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        config.log_device_placement = False
        sess = tf.Session(config=config)
        
        if is_finetune:
            # load checkpoint file
            sess.run(init_op)
            optimistic_restore(sess, ckptfile)
            print('restore variables done')
        elif caffemodel:
            # load caffemodel generated with caffe-tensorflow
            sess.run(init_op)
            model.load_alexnet_to_mvcnn(sess, caffemodel)
            print('loaded pretrained caffemodel:', caffemodel)
        else:
            # from scratch
            sess.run(init_op)
            print('init_op done')

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph) 

        step = startstep

        for epoch in range(100):
            total_correct_mv = 0
            loss_sum_mv = 0
            total_seen = 0

            val_correct_sum = 0
            val_seen = 0
            loss_val_sum = 0
            print('epoch:', epoch)

            for i in range(num_batch):
                batch_x, batch_y = dataset_train.get_batch(i)
                step += 1

                start_time = time.time()
                feed_dict = {view_: batch_x,
                             y_: batch_y,
                             is_training_pl: True}

                _, pred, loss_value = sess.run(
                    [train_op, prediction, loss],
                    feed_dict=feed_dict)

                duration = time.time() - start_time

                correct_mv = np.sum(pred == batch_y)
                total_correct_mv += correct_mv
                total_seen += g_.BATCH_SIZE
                loss_sum_mv += (loss_value * g_.BATCH_SIZE)

                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                # print training information
                if step % 500 == 0:
                    sec_per_batch = float(duration)
                    print('%s: step %d, loss=%.2f, acc=%.4f (%.1f examples/sec; %.3f sec/batch)'
                          % (datetime.now(), step,
                             loss_sum_mv / float(total_seen),
                             total_correct_mv / float(total_seen),
                             FLAGS.batch_size / duration, sec_per_batch))

                if step % 1000 == 0:
                    summary_str = sess.run(summary_op, feed_dict=feed_dict)
                    summary_writer.add_summary(summary_str, step)
                    summary_writer.flush()

            # validation
            for i in range(num_batch_test):
                val_batch_x, val_batch_y = dataset_val.get_batch(i)
                val_feed_dict = {view_: val_batch_x,
                                 y_: val_batch_y,
                                 is_training_pl: False}
                val_loss, pred = sess.run([loss, prediction],
                                          feed_dict=val_feed_dict)

                correct_mv_val = np.sum(pred == val_batch_y)
                val_correct_sum += correct_mv_val
                val_seen += g_.BATCH_SIZE
                loss_val_sum += (val_loss * g_.BATCH_SIZE)

            val_mean_loss = loss_val_sum / float(val_seen)
            acc = val_correct_sum / float(val_seen)
            if acc > best_eval_acc:
                best_eval_acc = acc
                checkpoint_path = os.path.join(cfg.ckpt_folder, 'best_model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

            print('%s: epoch %d, validation loss=%.4f, acc=%f, best_acc=%f' %
                  (datetime.now(), epoch, val_mean_loss, acc, best_eval_acc))
            # validation summary
            val_loss_summ = sess.run(validation_summary,
                                     feed_dict={validation_loss: val_mean_loss})
            val_acc_summ = sess.run(validation_acc_summary,
                                    feed_dict={validation_acc: acc})
            summary_writer.add_summary(val_loss_summ, step)
            summary_writer.add_summary(val_acc_summ, step)
            summary_writer.flush()
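
Example #3 relies on two helpers that are not part of the snippet. Minimal sketches of each follow, under stated assumptions. optimistic_restore() is commonly a "restore what matches" loader: it restores only those checkpoint variables whose names and shapes match variables in the current graph, so a checkpoint with a different classification head can still initialize the shared trunk:

import tensorflow as tf

def optimistic_restore(session, ckpt_path):
    reader = tf.train.NewCheckpointReader(ckpt_path)
    saved_shapes = reader.get_variable_to_shape_map()
    # keep only graph variables present in the checkpoint with matching shapes
    restorable = [v for v in tf.global_variables()
                  if v.name.split(':')[0] in saved_shapes and
                  v.get_shape().as_list() == saved_shapes[v.name.split(':')[0]]]
    tf.train.Saver(restorable).restore(session, ckpt_path)

get_bn_decay() typically follows the PointNet-style schedule seen in MVCNN/PointNet forks: batch-norm momentum decays exponentially with the step, and the resulting decay rate is clipped at 0.99. The constants below are assumptions, not values taken from this snippet:

def get_bn_decay(step):
    bn_momentum = tf.train.exponential_decay(
        0.5,       # initial BN momentum (assumed)
        step,      # current global step
        200000,    # decay step (assumed)
        0.5,       # decay rate (assumed)
        staircase=True)
    return tf.minimum(0.99, 1 - bn_momentum)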