Example #1
import random

import torch

# The project-level helpers used below (load_cityscapes_datasets, setup_model,
# load_model_checkpoint, PolynomialLRScheduler, train, evaluate_save_predictions,
# evalPixelLevelSemanticLabeling) are assumed to be imported from the surrounding
# project.


def main(args):
    """Run training or evaluation once the command-line arguments have been parsed."""
    # For reproducibility
    torch.manual_seed(42)
    random.seed(42)

    print("Loading datasets...")
    train_ds, valid_ds = load_cityscapes_datasets(args.dataset_folder)
    print("Done!\n")

    print("Setting up model, optimizer, and loss function...")
    model, optimizer, loss_fn = setup_model(args)
    print("Done!\n")

    start_epoch = 1

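    # Restore model/optimizer state and the training history if a checkpoint was given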
    if args.checkpoint_path:
        print("Loading checkpoint...")
        start_epoch, train_losses, valid_losses = load_model_checkpoint(
            model, optimizer, args.checkpoint_path)
        print("Done!\n")

    if args.is_training_model:
        if args.checkpoint_path is None:
            start_epoch, train_losses, valid_losses = 1, [], []

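        # num_epochs is the (exclusive) end epoch: train for args.num_epochs more
        # epochs starting at start_epoch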
        num_epochs = start_epoch + args.num_epochs

        print("Setting up learning rate scheduler...")
        lr_scheduler = PolynomialLRScheduler(optimizer, args.learning_rate,
                                             len(train_ds), num_epochs - 1)
        print("Done!\n")

        print("Training model...\n")
        train(model,
              optimizer,
              loss_fn,
              train_ds,
              valid_ds,
              start_epoch,
              num_epochs,
              lr_scheduler,
              args,
              train_losses=train_losses,
              valid_losses=valid_losses)
        print("Done!")
    elif args.is_evaluating_model:
        print("Evaluating model on the validation set...")
        evaluate_save_predictions(model, valid_ds, args)
        print("Done!\n")

        print("Calculating mIoU...")
        mIoU = evalPixelLevelSemanticLabeling.main()
        print("mIoU is {:.3f}".format(mIoU))
Example #2
import os
from functools import partial

import numpy as np
import tensorflow as tf

# CitySpaces, FRRN, _arch_type_a, and NUM_CLASSES are assumed to be imported from
# the surrounding project module.


def eval(MODEL, TARGET, BATCH_SIZE, **kwargs):
    if not os.path.exists(os.path.join('results', TARGET)):
        os.makedirs(os.path.join('results', TARGET))
    # >>>>>>> DATASET
    from cityscapesScripts.cityscapesscripts.helpers import labels as L
    from cityscapesScripts.cityscapesscripts.evaluation import evalPixelLevelSemanticLabeling as E
    from PIL import Image

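    # Lookup table mapping the 19 Cityscapes train IDs back to the official label IDs;
    # the extra 20th slot (void/ignore predictions) maps to label ID 0 (unlabeled)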
    trainId2id = np.zeros((20, ), np.uint8)
    for i in range(19):
        trainId2id[i] = L.trainId2label[i].id
    trainId2id[19] = 0

    cityspaces = CitySpaces()
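    # Resize target: scale the full Cityscapes image size by IM_SIZE / CROP_SIZE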
    scale_factor = np.array(kwargs['CROP_SIZE']) / np.array(kwargs['IM_SIZE'])
    size = tuple(
        int(l // s) for (l, s) in zip(cityspaces.image_size, scale_factor))
    imnames, ims, lbs = cityspaces.build_queue(target=TARGET,
                                               crop=cityspaces.image_size,
                                               resize=size,
                                               z_range=None,
                                               batch_size=BATCH_SIZE,
                                               num_threads=2)
    # <<<<<<<

    # >>>>>>> MODEL
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
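        # Build the FRRN segmentation network on the queued images and labels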
        net = FRRN(None, None, kwargs['K'], ims, lbs,
                   partial(_arch_type_a, NUM_CLASSES), params, False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.graph.finalize()
    sess.run(init_op)
    net.load(sess, MODEL)

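    # Start the queue-runner threads that feed image batches to the session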
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)
    try:
        while not coord.should_stop():
            names, preds = sess.run([imnames, net.preds])
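            # Convert predicted train IDs back to label IDs, resize to the full
            # Cityscapes resolution with nearest-neighbour interpolation, and save
            # as PNG for the official evaluation script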
            for name, pred in zip(names, preds):
                name = os.path.basename(str(name, 'utf-8'))
                im = Image.fromarray(trainId2id[pred])
                im = im.resize(
                    (cityspaces.image_size[1], cityspaces.image_size[0]),
                    Image.NEAREST)
                im.save(os.path.join('results', TARGET, name), "PNG")
            print('.', end='', flush=True)
    except tf.errors.OutOfRangeError:
        print('Complete')
    finally:
        # Stop and join the queue threads even if the loop exits with an error
        coord.request_stop()
        coord.join(threads)

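    # Run the official Cityscapes pixel-level evaluation on the saved predictions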
    E.main([TARGET])