Example #1

# NOTE: only standard/third-party imports are spelled out below; the
# project-specific modules used in this fragment (Loader, Resnet18,
# CUB_Dataset, utils, FLAGS) are assumed to come from the asker's own code.
import os
import time

import matplotlib.pyplot as plt
import progressbar
import tensorflow as tf


def main(unused_argv):
    loader = Loader(base_path=None, path="/data")
    # `total_ratio` is undefined in this fragment; 1.0 (use the whole
    # dataset) is assumed, matching the demo block below.
    datasets = loader.CUB(ratio=0.2, total_ratio=1.0)
    model = Resnet18(batch_size=FLAGS.batch_size)
    with model.graph.as_default():
        model.preload()

        # Pretrained convolutional weights to restore from the init checkpoint.
        conv_vars = [
            var for var in tf.global_variables() if var.name.startswith("conv")
        ]

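        # Step decay: cut the learning rate by 10x every 5 epochs' worth
        # of batches (staircase keeps it piecewise constant).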
        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate = tf.train.exponential_decay(
            1e-3,
            global_step=global_step,
            decay_steps=5 * int(len(datasets["train"]) / FLAGS.batch_size),
            decay_rate=0.1,
            staircase=True)

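        # Make every train step depend on UPDATE_OPS so batch-norm moving
        # statistics are refreshed alongside the gradient update.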
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
            grad_and_vars = opt.compute_gradients(loss=model.loss)

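            # Per-layer gradient scaling: fine-tuning boosts the head
            # (dense/conv5) gradients 10x, while freezing damps the early
            # conv1/conv2 gradients by 1e-3.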
            for index, (grad, var) in enumerate(grad_and_vars):
                if FLAGS.fine_tune:
                    if var.op.name.startswith(
                            "dense") or var.op.name.startswith("conv5"):
                        grad_and_vars[index] = (grad * 10.0, var)
                elif FLAGS.freeze:
                    if var.op.name.startswith(
                            "conv1") or var.op.name.startswith("conv2"):
                        grad_and_vars[index] = (grad * 1e-3, var)

            train_op = opt.apply_gradients(grad_and_vars,
                                           global_step=global_step)
            # train_op = tf.train.AdamOptimizer(learning_rate=learning_rate)\
            #     .minimize(loss=model.loss, global_step=global_step)

        # Everything not covered by the pretrained conv checkpoint (dense
        # head, optimizer slots, global_step, ...) needs a fresh initializer.
        rest_vars = list(set(tf.global_variables()) - set(conv_vars))
        init_rest_vars = tf.variables_initializer(rest_vars)

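    # Optional: write the graph out for TensorBoard inspection.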
    # writer = tf.summary.FileWriter("logs/", model.graph)
    # writer.flush()
    # writer.close()

    # Debugging aid: list the variables that will be restored, then exit.
    # conv_var_names = [var.name for var in conv_vars]
    # print('\n'.join(conv_var_names))
    # import sys
    # sys.exit(0)

    with tf.Session(graph=model.graph) as sess:
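        # Resume from the latest checkpoint if a previous run exists;
        # otherwise restore the pretrained conv weights and freshly
        # initialize everything else.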
        if os.path.exists(utils.path("models/trained")):
            tf.train.Saver().restore(
                sess,
                tf.train.latest_checkpoint(utils.path("models/trained/")))
        else:
            init_rest_vars.run()
            tf.train.Saver(conv_vars).restore(
                sess, utils.path("models/init/models.ckpt"))

        from BatchLoader import BatchLoader
        LOG = utils.Log()

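        # Each epoch runs a full training pass followed by a full test pass.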
        for epoch in range(FLAGS.num_epochs):
            for phase in ('train', 'test'):
                dataset = datasets[phase]

                accs = utils.AverageMeter()
                losses = utils.AverageMeter()
                start_time = time.time()
                bar = progressbar.ProgressBar()

                for features, boxes, im_sizes in bar(
                        BatchLoader(dataset,
                                    batch_size=FLAGS.batch_size,
                                    pre_fetch=FLAGS.pre_fetch,
                                    shuffle=(phase == 'train'),
                                    op_fn=CUB_Dataset.list_to_tuple)):
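                    # Fit the ground-truth boxes to the image bounds and
                    # normalize their coordinates (project-specific helpers).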
                    boxes = utils.crop_boxes(boxes, im_sizes)
                    boxes = utils.box_transform(boxes, im_sizes)

                    if phase == 'train':
                        _, loss, outputs = sess.run(
                            [train_op, model.loss, model.fc],
                            feed_dict={
                                'features:0': features,
                                'boxes:0': boxes,
                                'training:0': phase == 'train',
                            })
                    else:
                        loss, outputs = sess.run(
                            [model.loss, model.fc],
                            feed_dict={
                                'features:0': features,
                                'boxes:0': boxes,
                                'training:0': phase == 'train',
                            })

                    acc = utils.compute_acc(outputs, boxes, im_sizes)

                    nsample = model.batch_size
                    accs.update(acc, nsample)
                    losses.update(loss, nsample)

                    LOG.add(phase, {"accu": float(acc), "loss": float(loss)})

                elapsed_time = time.time() - start_time
                print(
                    '[{}]\tEpoch: {}/{}\tLoss: {:.4f}\tAcc: {:.2%}\tTime: {:.3f}'
                    .format(phase, epoch + 1, FLAGS.num_epochs, losses.avg,
                            accs.avg, elapsed_time))

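        # Save the fine-tuned weights; Saver appends global_step to the path.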
        tf.train.Saver().save(sess,
                              utils.path("models/trained/resnet18.ckpt"),
                              global_step=global_step)
        if FLAGS.log_path is not None:
            LOG.dump(FLAGS.log_path)
    # Plot the dumped training log: one subplot per (phase, tag) pair.
    # (Fragment: `log`, `phases`, `tags`, `fig`, and `sub_lines` are assumed
    # to be set up above this excerpt.)
    sub_id = 0
    for phase in phases:
        for tag in tags:
            sub_id += 1
            sub = fig.add_subplot(sub_lines, 1, sub_id)
            sub.plot(range(len(log[phase][tag])),
                     log[phase][tag],
                     label=(phase + '_' + tag))
            sub.set_title(phase + ': ' + tag)
    fig.tight_layout()
    plt.show()

# Visualize prediction results
if FLAGS.demo_only:
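    # Restore a trained model and sample a figs_x-by-figs_y grid of test images.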
    figs_x, figs_y = (5, 5)
    loader = Loader(base_path=None, path="/data")
    datasets = loader.CUB(ratio=0.2, total_ratio=1.0)
    model = Resnet18(batch_size=figs_x * figs_y)
    with model.graph.as_default():
        model.preload()

    with tf.Session(graph=model.graph) as sess:
        tf.train.Saver().restore(
            sess,
            tf.train.latest_checkpoint(
                utils.path("models/" + FLAGS.model + "/")))
        data_loader = BatchLoader(datasets['test'],
                                  batch_size=figs_x * figs_y,
                                  pre_fetch=1,
                                  shuffle=True,
                                  op_fn=CUB_Dataset.list_to_tuple)
        fig = plt.figure(figsize=(6 * figs_x, 2 * figs_y))