# Example 1: train a small CIFAR-10 classifier with tf.keras.
from datetime import datetime
import os

import tensorflow as tf

import utils  # assumed local module providing prepare_dataset/create_inputs
from model import create_model  # assumed location of the create_model factory


def main(args):
    print('---------------------ARGS---------------------')
    print('--train_batch_size     : %d' % args.train_batch_size)
    print('--train_epoch          : %d' % args.train_epoch)
    print('--model_base_dir       : %s' % args.model_base_dir)
    print('--log_base_dir         : %s' % args.log_base_dir)
    print('--learning_rate        : %f' % args.learning_rate)
    print('--width                : %d' % args.width)
    print('--height               : %d' % args.height)
    print('--save_steps           : %d' % args.save_steps)
    print('--val_steps            : %d' % args.val_steps)
    print('--log_steps            : %d' % args.log_steps)
    print('--train_imgs_dir       : %s' % args.train_imgs_dir)
    print('--test_imgs_dir        : %s' % args.test_imgs_dir)
    print('---------------------END----------------------')

    subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
    log_dir_today = os.path.join(os.path.expanduser(args.log_base_dir), subdir)
    if not os.path.isdir(log_dir_today):
        os.makedirs(log_dir_today)
    model_dir = os.path.join(os.path.expanduser(args.model_base_dir), subdir)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    print('Model directory: %s' % model_dir)
    print('Log directory  : %s' % log_dir_today)

    print('Load train and validation data...')
    # Expand '~' once and reuse the expanded path for both listing and joining,
    # so the returned paths are valid even when the flag value starts with '~'.
    train_imgs_dir = os.path.expanduser(args.train_imgs_dir)
    test_imgs_dir = os.path.expanduser(args.test_imgs_dir)
    train_imgs_dir_list = [os.path.join(train_imgs_dir, l)
                           for l in os.listdir(train_imgs_dir)]
    test_imgs_dir_list = [os.path.join(test_imgs_dir, l)
                          for l in os.listdir(test_imgs_dir)]

    # Both lists are used below, so neither directory may be empty.
    assert len(train_imgs_dir_list) != 0 and len(test_imgs_dir_list) != 0

    train_x, train_y = utils.prepare_dataset(train_imgs_dir_list)
    test_x, test_y = utils.prepare_dataset(test_imgs_dir_list)

    train_ds = utils.create_inputs(train_x, train_y, batch_size=args.train_batch_size)
    test_ds = utils.create_inputs(test_x, test_y, batch_size=args.train_batch_size, train_data=False)

    print('Create model...')
    model = create_model()

    print('Set up callbacks (TensorBoard logging and early stopping)...')
    tb = tf.keras.callbacks.TensorBoard(log_dir=log_dir_today, histogram_freq=1, write_images=True)
    st = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

    print('Compile model...')
    # learning_rate is not a compile() argument; pass it to the optimizer instead.
    model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=args.learning_rate),
                  loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    print('Train model with args...')
    history = model.fit(train_ds, epochs=args.train_epoch, validation_data=test_ds,
                        callbacks=[tb, st])

    model_path = os.path.join(model_dir, 'cifar10.h5')
    print('Done, save model to: %s' % model_path)
    model.save(model_path)
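
# For completeness, a minimal sketch of the argparse entry point that
# main(args) above expects. The flag names mirror the attributes printed in
# main(); the default values are placeholders, not taken from the original
# script.
import argparse
import sys


def parse_arguments(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_batch_size', type=int, default=32)
    parser.add_argument('--train_epoch', type=int, default=10)
    parser.add_argument('--model_base_dir', type=str, default='~/models')
    parser.add_argument('--log_base_dir', type=str, default='~/logs')
    parser.add_argument('--learning_rate', type=float, default=0.01)
    parser.add_argument('--width', type=int, default=32)
    parser.add_argument('--height', type=int, default=32)
    parser.add_argument('--save_steps', type=int, default=1000)
    parser.add_argument('--val_steps', type=int, default=500)
    parser.add_argument('--log_steps', type=int, default=100)
    parser.add_argument('--train_imgs_dir', type=str, required=True)
    parser.add_argument('--test_imgs_dir', type=str, required=True)
    return parser.parse_args(argv)


if __name__ == '__main__':
    main(parse_arguments(sys.argv[1:]))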

# Example 2: evaluate the checkpoint saved after each training epoch on the
# test set (TF1 graph-mode code).
import os

import tensorflow as tf

import net  # assumed local module providing build_arch/test_accuracy
from config import cfg  # assumed source of the cfg hyperparameters
from utils import create_inputs  # assumed input-pipeline helper

def main(_):
    with tf.Graph().as_default():
        # MNIST-sized dataset: 60,000 training and 10,000 test images.
        num_batches_per_epoch_train = int(60000 / cfg.batch_size)
        num_batches_test = int(10000 / cfg.batch_size)

        batch_x, batch_labels = create_inputs(is_train=False)
        output = net.build_arch(batch_x, is_train=False)
        batch_acc = net.test_accuracy(output, batch_labels)
        saver = tf.train.Saver()

        step = 0

        summaries = []
        summaries.append(tf.summary.scalar('accuracy', batch_acc))
        summary_op = tf.summary.merge(summaries)

        with tf.Session() as sess:
            tf.train.start_queue_runners(sess=sess)
            summary_writer = tf.summary.FileWriter(cfg.test_logdir,
                                                   graph=sess.graph)

            for epoch in range(cfg.epoch):
                ckpt = os.path.join(
                    cfg.logdir,
                    'model.ckpt-%d' % (num_batches_per_epoch_train * epoch))
                saver.restore(sess, ckpt)

                for i in range(num_batches_test):
                    summary_str = sess.run(summary_op)
                    print('%d batches tested.' % step)
                    summary_writer.add_summary(summary_str, step)

                    step += 1

# Example 3: train the network end to end, checkpointing once per epoch
# (TF1 graph-mode code).
import os
import time

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim

import net  # assumed local module providing build_arch/cross_ent_loss
from config import cfg  # assumed source of the cfg hyperparameters
from utils import create_inputs  # assumed input-pipeline helper

def main(_):
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        # One training epoch over an MNIST-sized dataset (60,000 images).
        num_batches_per_epoch = int(60000 / cfg.batch_size)
        opt = tf.train.AdamOptimizer()

        batch_x, batch_labels = create_inputs(is_train=True)
        # batch_y = tf.one_hot(batch_labels, depth=10, axis=1, dtype=tf.float32)
        with tf.device('/gpu:0'):
            with slim.arg_scope([slim.variable], device='/cpu:0'):
                output = net.build_arch(batch_x, is_train=True)
                loss = net.cross_ent_loss(output, batch_labels)

            grad = opt.compute_gradients(loss)

        loss_name = 'cross_ent_loss'

        summaries = []
        summaries.append(tf.summary.scalar(loss_name, loss))

        train_op = opt.apply_gradients(grad, global_step=global_step)

        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=cfg.epoch)

        # Optionally restore a snapshot before resuming training, e.g.:
        # latest = os.path.join(cfg.logdir, 'model.ckpt-4680')
        # saver.restore(sess, latest)

        summary_op = tf.summary.merge(summaries)
        tf.train.start_queue_runners(sess=sess)

        summary_writer = tf.summary.FileWriter(cfg.logdir, graph=sess.graph)

        for step in range(cfg.epoch * num_batches_per_epoch):
            tic = time.time()
            _, loss_value = sess.run([train_op, loss])
            print('iteration %d finished in %.3f seconds' % (step, time.time() - tic))
            # Stop immediately if training diverges.
            assert not np.isnan(loss_value), 'loss is nan'

            if step % 10 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)

            if (step % num_batches_per_epoch) == 0:
                ckpt_path = os.path.join(cfg.logdir, 'model.ckpt')
                saver.save(sess, ckpt_path, global_step=step)
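
# Both TF1 examples above define main(_) with the signature that tf.app.run()
# expects; a minimal entry point (assuming TensorFlow 1.x) would be:
if __name__ == '__main__':
    tf.app.run()  # parses command-line flags, then calls main()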