示例#1
0
    )
    val_loader = data.DataLoader(
        valset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=config["workers"],
    )
    if config["dataset_info"]:
        utils.dataloader_info(val_loader)

    # Initialize ship segmentation network
    num_classes = 1
    model_str = config["model"].lower()
    print("Loading ship segmentation model ({})...".format(model_str))
    if model_str == "enet":
        net = models.ENet(num_classes)
    elif model_str == "linknet":
        net = models.LinkNet(num_classes)
    elif model_str == "linknet34":
        net = models.LinkNet(num_classes, 34)
    elif model_str == "dilatedunet":
        net = models.DilatedUNet(classes=num_classes)
    else:
        raise ValueError(
            "requested unknown model {}, expect one of "
            "(ENet, LinkNet, LinkNet34, DilatedUNet)".format(config["model"])
        )
    print(net)

    # Loss function
    loss_name = config["loss"].lower()
def main(args):
    """Run ENet inference over a dataset's test split and write PNG results.

    Builds a TF1 graph that decodes the test TFRecords, runs the network,
    maps predicted train ids either back to dataset ids (default) or to
    colormap colors (``args.color``), and writes one PNG per example into
    ``args.output``. The written file paths are appended to a manager list
    shared with a ``PlotThread`` that runs in a separate process.

    Args:
        args: Parsed CLI namespace. Fields read here: ``dataset``,
            ``data_dir``, ``output``, ``size``, ``color`` and ``ckpt``.

    Returns:
        0 once every test example has been written.

    Raises:
        NotImplementedError: If ``args.dataset`` is not supported.
    """
    # Retrieve dataset
    if args.dataset == "cityscapes":
        dataset = datasets.Cityscapes()
    elif args.dataset == "freiburg":
        dataset = datasets.Freiburg()
    else:
        raise NotImplementedError("Dataset \"%s\" not yet supported." %
                                  args.dataset)
    # Configure directories
    data_dir = dataset.get_test_paths(args.data_dir)[0]
    if not os.path.exists(args.output):
        os.makedirs(args.output)
    # Parse first record and retrieve image dimensions.
    # NOTE(review): os.listdir order is arbitrary; this assumes every entry
    # in data_dir is a TFRecord with identical image dimensions.
    first_record = os.path.join(data_dir, os.listdir(data_dir)[0])
    example = tt.tfrecord.tfrecord2example_dict(first_record)
    example = example["features"]["feature"]
    height = example["height"]["int64List"]["value"][0]
    width = example["width"]["int64List"]["value"][0]
    channels = example["image/channels"]["int64List"]["value"][0]

    def decode_fn(serialized):
        # Decode one record using the dimensions read from the first record.
        return decode_tfrecord(serialized, [height, width, channels])

    # Create network and input stage
    net = models.ENet(dataset.num_classes)
    input_stage = tt.input.InputStage(input_shape=[height, width, channels])
    # Add test set to input stage
    input_stage.add_dataset("test",
                            data_dir,
                            batch_size=1,
                            decode_fn=decode_fn)

    input_image, file_id = input_stage.get_output()
    # Prepend a batch dimension for the network call.
    input_image = tf.expand_dims(input_image, axis=0)

    logits = net(input_image, training=False)
    p_class = tf.nn.softmax(logits)
    if args.size is not None:
        # Resize the class probabilities. Previously the raw logits were
        # resized here, which silently discarded the softmax above.
        p_class = tf.image.resize_bilinear(p_class, args.size)
    pred = tf.math.argmax(p_class, axis=-1)
    if not args.color:
        # Do the reverse embedding from trainId to dataset id
        pred = tf.expand_dims(pred, axis=-1)
        embedding = tf.constant(dataset.embedding_reversed, dtype=tf.uint8)
        pred_embed = tf.gather_nd(embedding, pred)
        # Restore the trailing dimension consumed by gather_nd
        pred_embed = tf.expand_dims(pred_embed, axis=-1)
    else:
        pred_embed = tf.gather(dataset.colormap, tf.cast(pred, tf.int32))
        pred_embed = tf.cast(pred_embed, tf.uint8)
    # Encode the first (and only) image of the batch as PNG
    pred_encoding = tf.image.encode_png(pred_embed[0])

    # Write encoded file to args.output
    output_dir = args.output
    if output_dir.endswith("/"):
        output_dir = output_dir[:-1]
    filename = tf.string_join([file_id, ".png"])
    filepath = tf.string_join([output_dir, filename], separator="/")
    write_file = tf.io.write_file(filepath, pred_encoding)

    print("Loading checkpoint")
    # Restore model from checkpoint (args.ckpt)
    ckpt = tf.train.Checkpoint(model=net)
    status = ckpt.restore(args.ckpt)
    # ``expect_partial`` exists only from TF 1.14 on. Detect the feature
    # instead of comparing version strings lexicographically, which
    # misorders e.g. "1.9.0" and "1.14.0".
    if hasattr(status, "expect_partial"):
        status.expect_partial()
    else:
        status.assert_existing_objects_matched()
    print("Checkpoint loaded")

    manager = multiprocessing.Manager()
    filepaths = manager.list()
    # Create session (closed automatically) and restore model
    with tf.Session() as sess:
        status.initialize_or_restore(sess)
        # Initialize input stage
        input_stage.init_iterator("test", sess)

        # Create visualization process
        pt = PlotThread(filepaths)
        p = multiprocessing.Process(target=pt)
        p.start()
        # Loop over all images until the test iterator is exhausted
        while True:
            try:
                _, _file_id, path = sess.run((write_file, file_id, filepath))
            except tf.errors.OutOfRangeError:
                break
            filepaths.append(path.decode("ascii"))
            logger.info("Written processed sample %s", _file_id)
    logger.info("Inference successfully finished.")
    p.join()
    return 0
示例#3
0
def main(args):
    """Train ENet on GPU:0 while running a validation copy on GPU:1.

    Builds a TF1 graph with a training pipeline (augmented input, ENet,
    masked softmax cross-entropy, Adam) and a validation pipeline (ENet in
    inference mode on the validation split). It then trains for
    ``args["epochs"]`` epochs. Each epoch writes per-iteration and per-epoch
    summaries to ``args["log_dir"]`` and ends with a checkpoint.

    Args:
        args: Mapping with keys ``"dataset"`` ("cityscapes" or "freiburg"),
            ``"data_dir"``, ``"coarse"`` (cityscapes only),
            ``"batch_size"``, ``"size"``, ``"learning_rate"``,
            ``"log_dir"`` and ``"epochs"``.

    Returns:
        0 once all epochs have completed.

    Raises:
        NotImplementedError: If ``args["dataset"]`` is not supported.
    """
    # Handle dataset specific number of classes and paths to the training
    # and validation sets.
    train_paths = []
    val_paths = []
    if args["dataset"] == "cityscapes":
        classes = 19
        train_paths.append(os.path.join(args["data_dir"], "train"))
        if args["coarse"]:
            train_paths.append(os.path.join(args["data_dir"], "train_extra"))
        val_paths.append(os.path.join(args["data_dir"], "val"))
    elif args["dataset"] == "freiburg":
        classes = 6
        train_paths.append(os.path.join(args["data_dir"], "train"))
        val_path = os.path.join(args["data_dir"], "val")
        if os.path.exists(val_path):
            # Feed the validation pipeline. This used to be appended to
            # train_paths, which left the validation pipeline without data.
            val_paths.append(val_path)
    else:
        raise NotImplementedError(
            "Dataset \"%s\" not supported" % args["dataset"])

    ##### Create step variables #####
    # Created before the training graph: optimizer.minimize() below needs
    # global_step (previously it was referenced before assignment).
    with tf.variable_scope("StepCounters"):
        global_step = tf.Variable(0, dtype=tf.int64,
                                  trainable=False, name="GlobalStep")
        epoch_step = tf.Variable(0, trainable=False, name="EpochStep")
        epoch_step_inc = tf.assign_add(epoch_step, 1, name="EpochStepInc")

    ##### Setup training pipeline #####
    train_input = tt.input.InputStage(args["batch_size"], args["size"])
    with tf.device("/device:GPU:0"):  # FIXME make more dynamic
        # Add training dataset
        num_batches = train_input.add_dataset("train",
                                              train_paths,
                                              augment=True)
        # Get iterator output
        image, label, mask = train_input.get_output("train")
        # Setup network and build network graph
        train_net = models.ENet(classes, training=True)
        logits = train_net.build(image)
        pred = train_net.get_predictions()

        # Build cost function and add optimizer
        with tf.name_scope("Loss"):
            # Establish loss function
            loss = tt.losses.masked_softmax_cross_entropy(
                label, logits, mask, classes, scope="XEntropy")

            # FIXME: insert parameters here:
            optimizer = tf.train.AdamOptimizer(args["learning_rate"])

            train_op = optimizer.minimize(loss, global_step=global_step,
                                          name="TrainOp")
        # Create metric evaluation and summaries
        train_metrics = tt.metrics.Metrics(pred, label, classes, mask)
        with tf.variable_scope("TrainSummary"):
            metric_summaries = train_metrics.get_summaries()
            batch_metric_summaries = train_metrics.get_batch_summaries()

            summary_iter = tf.summary.merge([
                batch_metric_summaries["Global"],
                tf.summary.scalar("Loss", loss),
            ], name="IterationSummaries")

            summary_epoch = tf.summary.merge([
                metric_summaries["Global"],
                metric_summaries["Class"],
                metric_summaries["ConfusionMat"],
                # TODO: move image summaries to validation thread
            ], name="EpochSummaries")
            train_metric_update = train_metrics.get_update_op()

    ##### Setup validation pipeline #####
    val_input = tt.input.InputStage(args["batch_size"], args["size"])
    with tf.device("/device:GPU:1"):  # FIXME make more dynamic
        # Endlessly repeating validation stream used alongside training
        val_input.add_dataset("val_iter",
                              val_paths,
                              epochs=-1,  # repeat indefinitely
                              augment=True)
        # Single pass over the validation set
        val_input.add_dataset("val",
                              val_paths,
                              epochs=1,
                              augment=True)
        # Get iterator output
        val_image, val_label, val_mask = val_input.get_output()
        # NOTE(review): val_net is a separate ENet instance; confirm it shares
        # weights with train_net (e.g. through variable scoping), otherwise
        # validation evaluates untrained weights.
        val_net = models.ENet(classes, training=False)
        # Build on the *validation* images (this previously used the training
        # `image` tensor).
        val_logits = val_net.build(val_image)
        val_pred = val_net.get_predictions()
        with tf.name_scope("Loss"):
            # Establish loss function
            val_loss = tt.losses.masked_softmax_cross_entropy(
                val_label, val_logits, val_mask, classes, scope="ValXEntropy")
        with tf.variable_scope("ValidationSummary"):
            val_metrics = tt.metrics.Metrics(val_pred, val_label,
                                             classes, val_mask)
            val_metric_summaries = val_metrics.get_summaries()
            val_batch_metric_summaries = val_metrics.get_batch_summaries()

            val_summary_iter = tf.summary.merge([
                tf.summary.scalar("Loss", val_loss),
                val_batch_metric_summaries["Global"],
            ])
            val_summary_epoch = tf.summary.merge([
                val_metric_summaries["Global"],
                val_metric_summaries["ConfusionMat"],
                val_metric_summaries["Class"],
                # FIXME fix the image summaries (colormap?)
                tf.summary.image("Input", val_image),
                tf.summary.image("Label", tf.expand_dims(val_label, axis=-1)
                                 * (255 // classes)),
                tf.summary.image(
                    "Predictions",
                    tf.expand_dims(
                        tf.cast(val_pred, dtype=tf.uint8), axis=-1)
                    * (255 // classes)),
            ])
            val_metric_update = val_metrics.get_update_op()

    with tf.Session() as sess:
        # Ops run on every training iteration
        iteration_fetches = {
            "train": train_op,
            "step": global_step,
            "train/summary": summary_iter,
            "train/metric": train_metric_update,
            "val/summary": val_summary_iter,
            "val/metric": val_metric_update,
        }
        # Ops run once per epoch, in a separate session run *after* the last
        # iteration. Fetching them together with the metric update ops gives
        # no ordering guarantee within a single run.
        epoch_fetches = {
            "summary": summary_epoch,
            "val/summary": val_summary_epoch,
            "step": epoch_step_inc,
        }

        # Initialize variables
        logger.debug("Initializing variables")
        sess.run(tf.global_variables_initializer())

        # Create checkpoint saver object and resume if a checkpoint exists
        vars_to_store = train_net.get_vars() + [epoch_step, global_step]
        saver = tf.train.Saver(var_list=vars_to_store, max_to_keep=50)
        savedir = os.path.join(args["log_dir"], "model")
        if not os.path.exists(savedir):
            os.makedirs(savedir)
        else:
            ckpt = tf.train.latest_checkpoint(args["log_dir"])
            if ckpt is not None:
                logger.info("Resuming from checkpoint \"%s\"" % ckpt)
                saver.restore(sess, ckpt)
        # Create summary writer object
        summary_writer = tf.summary.FileWriter(args["log_dir"],
                                               graph=sess.graph)
        # The validation stream repeats forever, so initialize it only once
        val_input.init_iterator("val_iter", sess)

        logger.info("Starting training loop...")
        ##### Training Loop #####
        for epoch in range(1, args["epochs"] + 1):
            # Create iterator counter to track progress
            batch_iter = range(num_batches)
            if show_progress:
                batch_iter = tqdm.tqdm(batch_iter, desc="train[%3d/%3d]"
                                       % (epoch, args["epochs"]))
            # Initialize input stage
            train_input.init_iterator("train", sess)
            # Reset metrics for another round
            train_metrics.reset_metrics(sess)
            val_metrics.reset_metrics(sess)

            for _ in batch_iter:
                try:
                    results = sess.run(iteration_fetches)
                except tf.errors.OutOfRangeError:
                    # Training data exhausted early: end this epoch instead
                    # of re-logging stale results.
                    break
                summary_writer.add_summary(results["train/summary"],
                                           results["step"])
                summary_writer.add_summary(results["val/summary"],
                                           results["step"])
            # END for iter
            epoch_results = sess.run(epoch_fetches)
            summary_writer.add_summary(epoch_results["summary"],
                                       epoch_results["step"])
            summary_writer.add_summary(epoch_results["val/summary"],
                                       epoch_results["step"])
            summary_writer.flush()
            saver.save(sess, savedir, global_step=epoch_results["step"])
        # END for epoch
    return 0
示例#4
0
    def __init__(self, args):
        """Set up data loaders, model, loss, optimizer and LR scheduler.

        Args:
            args: Parsed options. Fields read here: ``mode``, ``epochs``,
                ``dataset``, ``data_path``, ``train_crop_size``,
                ``eval_crop_size``, ``stride``, ``train_batch_size``,
                ``model``, ``loss``, ``schedule_mode``, ``lr``, ``cuda``,
                ``resume``, ``finetune`` and ``init_eval``.

        Raises:
            ValueError: If both ``args.resume`` and ``args.finetune`` are set.
            NotImplementedError: For an unsupported dataset, model, loss or
                schedule mode.
        """
        self.args = args
        self.mode = args.mode
        self.epochs = args.epochs
        self.dataset = args.dataset
        self.data_path = args.data_path
        self.train_crop_size = args.train_crop_size
        self.eval_crop_size = args.eval_crop_size
        self.stride = args.stride
        self.batch_size = args.train_batch_size

        self.resume = args.resume
        self.finetune = args.finetune
        # Explicit raise (not assert) so the check survives `python -O`;
        # done first to fail before any data loading.
        if self.resume is not None and self.finetune is not None:
            raise ValueError("'resume' and 'finetune' are mutually exclusive")

        self.train_data = AerialDataset(crop_size=self.train_crop_size,
                                        dataset=self.dataset,
                                        data_path=self.data_path,
                                        mode='train')
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=self.batch_size,
                                       shuffle=True,
                                       num_workers=2)
        self.eval_data = AerialDataset(dataset=self.dataset,
                                       data_path=self.data_path,
                                       mode='val')
        self.eval_loader = DataLoader(self.eval_data,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=2)

        # Dataset-specific class count and get_test_times() arguments
        if self.dataset == 'Potsdam':
            self.num_of_class = 6
            self.epoch_repeat = get_test_times(6000, 6000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        elif self.dataset == 'UDD5':
            self.num_of_class = 5
            self.epoch_repeat = get_test_times(4000, 3000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        elif self.dataset == 'UDD6':
            self.num_of_class = 6
            self.epoch_repeat = get_test_times(4000, 3000,
                                               self.train_crop_size,
                                               self.train_crop_size)
        else:
            raise NotImplementedError(
                "Unsupported dataset: {}".format(self.dataset))

        if args.model == 'FCN':
            model = models.FCN8(num_classes=self.num_of_class)
        elif args.model == 'DeepLabV3+':
            model = models.DeepLab(num_classes=self.num_of_class,
                                   backbone='resnet')
        elif args.model == 'GCN':
            model = models.GCN(num_classes=self.num_of_class)
        elif args.model == 'UNet':
            model = models.UNet(num_classes=self.num_of_class)
        elif args.model == 'ENet':
            model = models.ENet(num_classes=self.num_of_class)
        elif args.model == 'D-LinkNet':
            model = models.DinkNet34(num_classes=self.num_of_class)
        else:
            raise NotImplementedError(
                "Unsupported model: {}".format(args.model))

        # Wrap and move the model to the GPU *before* building the optimizer,
        # as recommended by the torch.optim documentation.
        self.model = nn.DataParallel(model)
        self.cuda = args.cuda
        if self.cuda is True:
            self.model = self.model.cuda()

        if args.loss == 'CE':
            self.criterion = CrossEntropyLoss2d()
        elif args.loss == 'LS':
            self.criterion = LovaszSoftmax()
        elif args.loss == 'F':
            self.criterion = FocalLoss()
        elif args.loss == 'CE+D':
            self.criterion = CE_DiceLoss()
        else:
            raise NotImplementedError(
                "Unsupported loss: {}".format(args.loss))

        self.schedule_mode = args.schedule_mode
        self.optimizer = opt.AdamW(self.model.parameters(), lr=args.lr)
        if self.schedule_mode == 'step':
            self.scheduler = opt.lr_scheduler.StepLR(self.optimizer,
                                                     step_size=30,
                                                     gamma=0.1)
        elif self.schedule_mode in ('miou', 'acc'):
            self.scheduler = opt.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                                mode='max',
                                                                patience=10,
                                                                factor=0.1)
        elif self.schedule_mode == 'poly':
            iters_per_epoch = len(self.train_loader)
            self.scheduler = Poly(self.optimizer,
                                  num_epochs=args.epochs,
                                  iters_per_epoch=iters_per_epoch)
        else:
            raise NotImplementedError(
                "Unsupported schedule mode: {}".format(self.schedule_mode))

        self.evaluator = Evaluator(self.num_of_class)

        # Resume restores model + optimizer + scheduler; finetune restores
        # only the model parameters. Both continue from the saved epoch.
        ckpt_path = self.resume if self.resume is not None else self.finetune
        if ckpt_path is not None:
            print("Loading existing model...")
            if self.cuda:
                checkpoint = torch.load(ckpt_path)
            else:
                checkpoint = torch.load(ckpt_path, map_location='cpu')
            self.model.load_state_dict(checkpoint['parameters'])
            if self.resume is not None:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
                self.scheduler.load_state_dict(checkpoint['scheduler'])
            # start from next epoch
            self.start_epoch = checkpoint['epoch'] + 1
        else:
            self.start_epoch = 1

        if self.mode == 'train':
            # Name the run after the wrapped network: self.model is always a
            # DataParallel here, so its own class name is uninformative.
            self.writer = SummaryWriter(comment='-' + self.dataset + '_' +
                                        type(self.model.module).__name__ +
                                        '_' + args.loss)
        self.init_eval = args.init_eval
def main(args, logger):
    # Retrieve training parameters for convenience
    params   = args.params               # All parameters
    hparams  = params["hyperparams"]     # Hyperparamters
    alparams = params["active_learning"] # Active learning parameters
    state = None # State dict
    # Define state and config filenames
    state_filename  = os.path.join(args.log_dir, "state.json")
    config_filename = os.path.join(args.log_dir, "config.json")
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
        # Dump parameter config
        with open(config_filename, "w+") as f:
            json.dump(params, f, indent=4)

    # Retrieve dataset specific object
    if args.dataset == "cityscapes":
        dataset = datasets.Cityscapes(coarse=args.coarse)
        test_examples_glob = os.path.join(args.data_dir, "val", "*.tfrecord")
    elif args.dataset == "freiburg":
        dataset = datasets.Freiburg()
        test_examples_glob = os.path.join(args.data_dir, "test", "*.tfrecord")
    elif args.dataset == "vistas":
        dataset = datasets.Vistas()
        test_examples_glob = os.path.join(args.data_dir, "val", "*.tfrecord")
    else:
        raise NotImplementedError("Dataset \"%s\" not supported" % args.dataset)

    # Prepare dataset example file paths.
    train_examples_glob = os.path.join(args.data_dir, "train", "*.tfrecord")

    if not os.path.exists(state_filename):
        # Initialize state
        # Resolve example filenames
        train_val_examples = np.sort(np.array(glob.glob(train_examples_glob)))
        # Pick examples from training set to use for validation
        val_examples   = train_val_examples[:alparams["num_validation"]]
        # Use the rest as training examples
        train_examples = train_val_examples[alparams["num_validation"]:]

        # Use annotated test set, NOTE: cityscapes validation set
        test_examples  = np.array(glob.glob(test_examples_glob))

        # Draw random train examples and mark as annotated
        train_indices  = np.arange(len(train_examples), dtype=np.int32)
        np.random.shuffle(train_indices)

        initially_labelled = alparams["num_initially_labelled"]
        if initially_labelled < 0:
            # Use rest of labelled examples
            initially_labelled = len(train_examples)

        # Possibly add actually unlabelled examples
        no_label_indices = np.empty(0, dtype=str)
        if args.unlabelled is not None:
            no_label_glob     = os.path.join(args.unlabelled, "*.tfrecord")
            no_label_examples = glob.glob(no_label_glob)
            no_label_indices  = np.arange(
                len(train_indices), len(train_indices)+len(no_label_examples)
            )
            train_examples = np.concatenate(train_examples,
                                            no_label_examples)
            train_indices = np.concatenate((train_indices, no_label_indices))

        labelled = train_indices[:initially_labelled]
        unlabelled = train_indices[initially_labelled:]
        del train_indices

        # Setup initial state
        state = {
            "checkpoint" : None, # Keep track of latest checkpoint.
            "iteration"  : 0,
            "dataset" : {
                "train" : {
                    "filenames"  : list(train_examples),
                    "labelled"   : labelled.tolist(),
                    "unlabelled" : unlabelled.tolist(),
                    "no_label"   : no_label_indices.tolist()
                },
                "val"   : {
                    "filenames" : list(val_examples)
                },
                "test"  : {
                    "filenames" : list(test_examples)
                }
            }
        }
        with open(state_filename, "w+") as f:
            json.dump(state, f, indent=2)

    else:
        # Load state
        with open(state_filename, "r") as f:
            state = json.load(f)
        # Extract filename properties
        train_examples   = np.array(state["dataset"]["train"]["filenames"])
        val_examples     = np.array(state["dataset"]["val"]["filenames"])
        test_examples    = np.array(state["dataset"]["test"]["filenames"])
        labelled         = np.array(state["dataset"]["train"]["labelled"])
        unlabelled       = np.array(state["dataset"]["train"]["unlabelled"])
        no_label_indices = np.array(state["dataset"]["train"]["no_label"])

    train_input_labelled = np.full_like(train_examples, False, dtype=bool)
    train_input_labelled[labelled] = True
    train_input_indices = np.arange(len(train_examples))

    with tf.device("/device:CPU:0"):
        with tf.name_scope("Datasets"):
            # Create input placeholders
            train_input = tt.input.NumpyCapsule()
            train_input.filenames = train_examples
            train_input.labelled = train_input_labelled
            train_input.indices   = train_input_indices

            val_input = tt.input.NumpyCapsule()
            val_input.filenames = val_examples
            test_input = tt.input.NumpyCapsule()
            test_input.filenames = test_examples

            # Setup input pipelines
            train_input_stage = tt.input.InputStage(
                input_shape=[params["network"]["input"]["height"],
                             params["network"]["input"]["width"]])
            # Validation AND Test input stage
            val_input_stage  = tt.input.InputStage(
                input_shape=[params["network"]["input"]["height"],
                             params["network"]["input"]["width"]])

            # Add datasets
            train_input_stage.add_dataset_from_placeholders(
                "train", train_input.filenames,
                train_input.labelled, train_input.indices,
                batch_size=params["batch_size"],
                augment=True)
            # Validation set
            val_input_stage.add_dataset_from_placeholders(
                "val", val_input.filenames,
                batch_size=params["batch_size"])
            # Test set
            val_input_stage.add_dataset_from_placeholders(
                "test", test_input.filenames,
                batch_size=params["batch_size"])
            # Calculate number of batches in each iterator
            val_batches   = (len(val_examples) - 1)//params["batch_size"] + 1
            test_batches  = (len(test_examples) - 1)//params["batch_size"] + 1

            # Get iterator outputs
            train_image_raw, train_image, train_label, train_mask, \
                train_labelled, train_index = train_input_stage.get_output()
            val_image, val_label, val_mask = val_input_stage.get_output()

        # Create step variables
        with tf.variable_scope("StepCounters"):
            global_step = tf.Variable(0, dtype=tf.int64,
                                      trainable=False, name="GlobalStep")
            local_step  = tf.Variable(0, dtype=tf.int64,
                                      trainable=False, name="LocalStep")
            global_step_op = tf.assign_add(global_step, local_step)
            epoch_step  = tf.Variable(0, trainable=False, name="EpochStep")
            epoch_step_inc = tf.assign_add(epoch_step, 1)

    # Build training- and validation network
    regularization = {"drop_rates": hparams["dropout_rates"]}
    if hparams["weight_reg"]["L2"] > 0.0 \
       or hparams["weight_reg"]["L1"] > 0.0:
        regularization = {
            "weight_regularization" : tf.keras.regularizers.l1_l2(
                                          l1=hparams["weight_reg"]["L1"],
                                          l2=hparams["weight_reg"]["L2"]),
            "regularization_scaling" : hparams["weight_reg"]["glorot_scaling"],
        }

    # Initialize networks
    train_net = models.ENet(
        dataset.num_classes,
        **regularization
    )
    val_net = models.ENet(dataset.num_classes)

    with tf.device("/device:GPU:0"):
        # Build graph for training
        train_logits  = train_net(train_image, training=True)
        # Compute predictions: use @train_pred for metrics and
        # @pseudo_label for pseudo_annotation process.
        train_pred    = tf.math.argmax(train_logits, axis=-1,
                                       name="TrainPredictions")

        with tf.name_scope("PseudoAnnotation"):
            # Build ops one more time without dropout.
            pseudo_logits = train_net(train_image_raw, training=False)
            # Just make sure not to propagate gradients a second time.
            pseudo_logits = tf.stop_gradient(pseudo_logits)
            pseudo_label  = tf.math.argmax(pseudo_logits, axis=-1,
                                           name="TrainPredictions")
            pseudo_label = tf.cast(pseudo_label, tf.uint8)

            # Configure on-line high confidence pseudo labeling.
            pseudo_prob   = tf.nn.softmax(pseudo_logits, axis=-1, name="TrainProb")
            if alparams["measure"] == "entropy":
                # Reduce entropy over last dimension.
                # Compute prediction entropy
                entropy = - pseudo_prob * tf.math.log(pseudo_prob+EPSILON)
                entropy = tf.math.reduce_sum(entropy, axis=-1)
                # Convert logarithm base to units of number of classes
                # NOTE this will make the metric independent of number of
                #      classes as well the range in [0,1]
                log_base = tf.math.log(np.float32(dataset.num_classes))
                entropy = entropy / log_base
                # Convert entropy to confidence
                pseudo_confidence = 1.0 - entropy
            elif alparams["measure"] == "margin":
                # Difference between the two largest entries in last dimension.
                values, indices = tf.math.top_k(pseudo_prob, k=2)
                pseudo_confidence = values[:,:,:,0] - values[:,:,:,1]
            elif alparams["measure"] == "confidence":
                # Reduce max over last dimension.
                pseudo_confidence = tf.math.reduce_max(pseudo_prob, axis=-1)
            else:
                raise NotImplementedError("Uncertainty function not implemented.")
            pseudo_mean_confidence = tf.reduce_mean(
                tf.cast(pseudo_confidence, tf.float64),
                axis=(1,2))
            # Pseudo annotate high-confidence unlabeled example pixels
            pseudo_mask = tf.where(tf.math.less(pseudo_confidence, alparams["threshold"]),
                                   tf.zeros_like(pseudo_label,
                                                 dtype=train_label.dtype),
                                   tf.ones_like(pseudo_label,
                                                dtype=train_label.dtype))
            # Pseudo annotation logic (think of it as @tf.cond maped 
            # over batch dimension)
            train_label = tf.where(train_labelled, train_label,
                                   pseudo_label, name="MaybeGenLabel")
            train_mask  = tf.where(train_labelled, train_mask,
                                   pseudo_mask, name="MaybeGenMask")

    with tf.device("/device:GPU:1"):
        # Build validation network.
        val_logits = val_net(val_image, training=False)
        val_pred   = tf.math.argmax(val_logits, axis=-1,
                                    name="ValidationPredictions")

    # Build cost function
    with tf.name_scope("Cost"):
        with tf.device("/device:GPU:0"):
            # Establish loss function
            if hparams["softmax"]["multiscale"]:
                loss, loss_weights = \
                    tt.losses.multiscale_masked_softmax_cross_entropy(
                        train_label,
                        train_net.endpoint_outputs[0],
                        train_mask, dataset.num_classes,
                        weight=hparams["softmax"]["loginverse_scaling"],
                        label_smoothing=hparams["softmax"]["label_smoothing"],
                        scope="XEntropy")
                # NOTE: this will make @loss_weights checkpointed
                train_net.loss_scale_weights = loss_weights
            else:
                loss = tt.losses.masked_softmax_cross_entropy(
                    train_label,
                    train_logits,
                    train_mask, dataset.num_classes,
                    weight=hparams["softmax"]["loginverse_scaling"],
                    label_smoothing=hparams["softmax"]["label_smoothing"],
                    scope="XEntropy")
            cost = loss
            # Add regularization to cost function
            if len(train_net.losses) > 0:
                regularization_loss = tf.math.add_n(train_net.losses, name="Regularization")
                cost += tf.cast(regularization_loss, dtype=tf.float64)

            # Setup learning rate
            learning_rate = hparams["learning_rate"]
            if hparams["learning_rate_decay"] > 0.0:
                # Inverse time learning_rate if lr_decay specified
                learning_rate = tf.train.inverse_time_decay(
                    learning_rate, local_step,
                    decay_steps=train_batches,
                    decay_rate=hparams["learning_rate_decay"])

            # Create optimization procedure
            optimizer = tf.train.AdamOptimizer(learning_rate, **hparams["optimizer"]["kwargs"])

            # Create training op
            train_op  = optimizer.minimize(cost, global_step=local_step,
                                           name="TrainOp")
        # END tf.device("/device:GPU:0")
    # END tf.name_scope("Cost")

    # Create summary operations for training and validation network
    with tf.name_scope("Summary"):
        # Create colormap for image summaries
        colormap = tf.constant(dataset.colormap, dtype=tf.uint8,
                               name="Colormap")
        # Create metric evaluation and summaries
        with tf.device("/device:GPU:0"):
            with tf.name_scope("TrainMetrics"):
                # Create metrics object for training network.
                train_metrics = tt.metrics.Metrics(train_pred, train_label,
                                                   dataset.num_classes, train_mask)
                # Get Tensorflow update op.
                metric_update_op = train_metrics.get_update_op()
                # Get Tensorflow summary operations.
                metric_summaries = train_metrics.get_summaries()

            train_summary_iter = tf.summary.merge(
                [
                    # Summaries run at each iteration.
                    tf.summary.scalar("CrossEntropyLoss", loss,
                                      family="Losses"),
                    tf.summary.scalar("TotalCost", cost,
                                      family="Losses"),
                    tf.summary.scalar("LearningRate", learning_rate,
                                      family="Losses")
                ], name="IterationSummaries"
               )

            with tf.control_dependencies([metric_update_op]):
                train_summary_epoch = tf.summary.merge(
                    [
                        # Summaries run at epoch boundaries.
                        metric_summaries["Metrics"],
                        metric_summaries["ConfusionMat"]
                    ], name="EpochSummaries"
                   )

            train_image_summary = tf.summary.merge(
                [
                    tf.summary.image(
                        "PseudoLabel/input",
                        train_image_raw,
                        family="PseudoLabel"
                    ),
                    tf.summary.image(
                        "PseudoLabel/confidence",
                        tf.expand_dims(pseudo_confidence, axis=-1),
                        family="PseudoLabel"
                    ),
                    tf.summary.image(
                        "PseudoLabel", 
                        tf.gather(dataset.colormap,
                                  tf.cast(pseudo_label*pseudo_mask \
                                  + (1 - pseudo_mask)*255,
                                  tf.int32)),
                        family="PseudoLabel"
                    )
                ]
            )
        # Create metric evaluation and summaries
        with tf.device("/device:GPU:1"):
            with tf.name_scope("ValidationTestMetrics"):
                # Create metrics object
                val_metrics = tt.metrics.Metrics(val_pred, val_label,
                                                 dataset.num_classes, val_mask)
                # Get update tensorflow ops
                val_metric_update_op = val_metrics.get_update_op()
                # Get metric sumaries
                val_metric_summaries = val_metrics.get_summaries()

                with tf.control_dependencies([val_metric_update_op]):
                    val_metric_summary = tf.summary.merge(
                        [
                            # "Expensive" summaries run at epoch boundaries.
                            val_metric_summaries["Metrics"],
                            val_metric_summaries["ClassMetrics"],
                            val_metric_summaries["ConfusionMat"]
                        ], name="EpochSummaries"
                    )
                    val_image_summary = tf.summary.merge(
                        [
                            tf.summary.image("Input", val_image),
                            tf.summary.image("Label", tf.gather(
                                colormap, tf.cast(val_label + 255*(1-val_mask),
                                                  tf.int32))),
                            tf.summary.image("Predictions", tf.gather(
                                colormap, tf.cast(val_pred, tf.int32)))
                        ]
                    )
                    val_summary_epoch = val_metric_summary
                    test_summary_epoch = tf.summary.merge([
                        val_metric_summary,
                        val_image_summary
                        ]
                    )
        conf_summary_ph = tf.placeholder(tf.float64, shape=[None])
        conf_summary = tf.summary.histogram("ConfidenceDistribution",
                                            conf_summary_ph)
    # END name_scope("Summary")

    # Create session with soft device placement
    #     - some ops neet to run on the CPU
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        logger.debug("Initializing variables...")
        sess.run(tf.global_variables_initializer())


        # Create checkpoint object
        with tf.name_scope("Checkpoint"):
            checkpoint = tf.train.Checkpoint(model=train_net,
                                             epoch=epoch_step,
                                             step=global_step,
                                             optimizer=optimizer)
            checkpoint_name = os.path.join(args.log_dir, "model")
            if args.checkpoint is not None:
                # CMDline checkpoint given
                ckpt = args.checkpoint
                if os.path.isdir(ckpt):
                    ckpt = tf.train.latest_checkpoint(ckpt)
                if ckpt is None:
                    logger.error("Checkpoint path \"%s\" is invalid.")
                    return 1
                logger.info("Resuming from checkpoint \"%s\"" % ckpt)
                status = checkpoint.restore(ckpt)
                if tf.__version__ < "1.14.0":
                    status.assert_existing_objects_matched()
                else:
                    status.expect_partial()
                status.initialize_or_restore(sess)
                if args.reinitialize_output:
                    sess.run(train_net.Final.kernel.initializer)

            elif state["checkpoint"] != None:
                # Try to restore from checkpoint in logdir
                ckpt = state["checkpoint"]
                logger.info("Resuming from checkpoint \"%s\"" % ckpt)
                status = checkpoint.restore(ckpt)
                if tf.__version__ < "1.14.0":
                    status.assert_existing_objects_matched()
                else:
                    status.expect_partial()
                status.initialize_or_restore(sess)

            with tf.name_scope("UpdateValidationWeights"):
                update_val_op = []
                for i in range(len(val_net.layers)):
                    for j in range(len(val_net.layers[i].variables)):
                        update_val_op.append(
                            tf.assign(val_net.layers[i].variables[j],
                                      train_net.layers[i].variables[j]))
                update_val_op = tf.group(update_val_op)

        ckpt_manager = tt.checkpoint_manager.CheckpointManager(checkpoint,
                                                           args.log_dir)
        # END scope Checkpoint
        # Prepare global fetches dict
        fetches = {
            "train" : {
                "iteration" : {
                    "step"     : global_step_op,
                    "summary"  : train_summary_iter,
                    "train_op" : train_op,
                    "update"   : metric_update_op,
                    "updates"  : train_net.updates
                },
                "epoch"     : {
                    "step"     : epoch_step,
                    "summary"  : train_summary_epoch,
                    "summary/image" : train_image_summary
                }
            },
            "val"   : { # Validation and test fetches
                "iteration" : {
                    "update"   : val_metric_update_op
                },
                "epoch"     : {
                    "step"     : epoch_step,
                    "MeanIoU"  : val_metrics.metrics["MeanIoU"],
                    "summary"  : val_summary_epoch,
                    # Also add image summary, however only added to
                    # writer every N epochs.
                    "summary/image" : val_image_summary
                }
            },
            "test" : {
                "iteration" : {"update"  : val_metric_update_op},
                "epoch"     : {"summary" : test_summary_epoch}
            }
        }

        # Train loop (until convergence) -> Pick unlabeled examples -> test_loop
        def train_loop(summary_writer):
            """
            Train loop closure.

            Runs training epochs until the validation MeanIoU has not improved
            for @params["epochs"] consecutive validation rounds. The loop never
            stops before the warm-up period of @alparams["epochs/warm_up"]
            epochs has elapsed. Every improvement of the validation MeanIoU
            commits a checkpoint through @ckpt_manager into the log directory
            of @summary_writer.

            :param summary_writer: tf.summary.FileWriter receiving the
                                   training and validation summaries.
            :returns: path of the best checkpoint committed by this call, or
                      @state["checkpoint"] if none was committed.
            """
            # Number of epochs to run before @no_improvement_count may end
            # training.
            _initial_grace_period = alparams["epochs/warm_up"]
            best_ckpt             = state["checkpoint"]
            best_mean_iou         = 0.0
            log_subdir            = summary_writer.get_logdir()
            run_name              = os.path.basename(log_subdir)
            checkpoint_prefix     = os.path.join(log_subdir, "model")
            # Training and validation batches are fetched in the same session
            # run, so one epoch spans the larger of the two sets.
            num_iter_per_epoch    = np.maximum(train_input.size,
                                               val_input.size)
            no_improvement_count = 0
            while no_improvement_count < params["epochs"] \
                or _initial_grace_period >= 0:
                _initial_grace_period -= 1
                # Increment in-graph epoch counter.
                epoch = sess.run(epoch_step_inc)

                # Prepare inner loop iterator
                _iter = range(0, num_iter_per_epoch, params["batch_size"])
                if show_progress:
                    _iter = tqdm.tqdm(_iter, desc="%s[%d]" % (run_name, epoch),
                                      dynamic_ncols=True,
                                      ascii=True,
                                      postfix={"NIC": no_improvement_count})

                # Initialize iterators
                train_input_stage.init_iterator(
                    "train", sess, train_input.feed_dict)
                val_input_stage.init_iterator(
                    "val", sess, val_input.feed_dict)

                # Reset confusion matrices
                train_metrics.reset_metrics(sess)
                val_metrics.reset_metrics(sess)

                # Prepare iteration fetches
                _fetches = {
                    "train" : {"iteration" : fetches["train"]["iteration"]},
                    "val"   : {"iteration" : fetches["val"]["iteration"]}
                }
                # Update validation network weights
                sess.run(update_val_op)

                try:
                    for i in _iter:
                        if train_input.size-params["batch_size"] <= i < train_input.size:
                            # Fetches for last training iteration.
                            _fetches["train"]["epoch"] = fetches["train"]["epoch"]
                        if val_input.size-params["batch_size"] <= i < val_input.size:
                            # Fetches for last validation iteration.
                            _fetches["val"]["epoch"] = fetches["val"]["epoch"]

                        # Run fetches
                        results = sess.run(_fetches)

                        if "train" in results:
                            # Add iteration summary
                            summary_writer.add_summary(
                                results["train"]["iteration"]["summary"],
                                results["train"]["iteration"]["step"])

                            # Maybe add epoch summary
                            if "epoch" in results["train"]:
                                summary_writer.add_summary(
                                    results["train"]["epoch"]["summary"],
                                    results["train"]["epoch"]["step"]
                                )
                                # Only report image summary every 100th epoch.
                                if results["train"]["epoch"]["step"] % 100 == 0:
                                    summary_writer.add_summary(
                                        results["train"]["epoch"]["summary/image"],
                                        results["train"]["epoch"]["step"]
                                    )
                                # Pop fetches to prohibit OutOfRangeError due to
                                # asymmetric train-/val- input size.
                                _fetches.pop("train")

                        if "val" in results and "epoch" in results["val"]:
                            # Add summaries to event log.
                            summary_writer.add_summary(
                                results["val"]["epoch"]["summary"],
                                results["val"]["epoch"]["step"]
                            )
                            if results["val"]["epoch"]["step"] % 100 == 0:
                                # Only report image summary every 100th epoch.
                                summary_writer.add_summary(
                                    results["val"]["epoch"]["summary/image"],
                                    results["val"]["epoch"]["step"]
                                )
                            # Check if MeanIoU improved and
                            # update counter and best
                            if results["val"]["epoch"]["MeanIoU"] > best_mean_iou:
                                best_mean_iou = results["val"]["epoch"]["MeanIoU"]
                                # Update checkpoint file used for
                                # @tf.train.latest_checkpoint to point at
                                # current best.
                                _ckpt_name = ckpt_manager.commit(
                                    checkpoint_prefix, sess)
                                if _ckpt_name != "":
                                    best_ckpt = _ckpt_name
                                # Reset counter
                                no_improvement_count = 0
                            else:
                                # Result has not improved, increment counter.
                                no_improvement_count += 1
                                if no_improvement_count >= params["epochs"] and \
                                   _initial_grace_period < 0:
                                    # Converged; the progress bar (if any) is
                                    # closed in the ``finally`` clause below.
                                    break
                            if show_progress:
                                _iter.set_postfix(NIC=no_improvement_count)
                            # Pop fetches to prohibit OutOfRangeError due to
                            # asymmetric train-/val- input size.
                            _fetches.pop("val")
                        # END "maybe add epoch summary"
                except tf.errors.OutOfRangeError:
                    logger.error("Out of range error. Attempting to continue.")
                finally:
                    # @_iter is a plain range unless @show_progress is set;
                    # only a tqdm bar has (and needs) a close() method.
                    # tqdm.close() is a no-op on an already closed bar.
                    if show_progress:
                        _iter.close()

                summary_writer.flush()
                ckpt_manager.cache(sess)
            # END while no_improvement_count < params["epochs"]
            return best_ckpt

        def test_loop(summary_writer):
            """
            Test loop closure.

            Evaluates the validation network on the test set and writes the
            metric/image summary to @summary_writer. The current number of
            labelled examples is used as the summary step.

            :param summary_writer: tf.summary.FileWriter for test summaries.
            """
            _step = len(labelled)
            # Initialize validation input stage with test set
            val_input_stage.init_iterator("test", sess, test_input.feed_dict)
            _iter = range(0, test_input.size, params["batch_size"])
            if show_progress:
                _iter = tqdm.tqdm(_iter, desc="test[%d]" % (_step),
                                  ascii=True,
                                  dynamic_ncols=True)
            summary_proto = None
            val_metrics.reset_metrics(sess)
            try:
                for i in _iter:
                    # Accumulate confusion matrix
                    if i < test_input.size - params["batch_size"]:
                        sess.run(fetches["test"]["iteration"]["update"])
                    else:
                        # Run summary operation last iteration
                        _, summary_proto = sess.run(
                            [fetches["test"]["iteration"]["update"],
                             fetches["test"]["epoch"]["summary"]])
            except tf.errors.OutOfRangeError:
                # Iterator ran dry; the summary op may not have run yet.
                pass
            finally:
                # Only a tqdm bar needs closing; closing twice is harmless.
                if show_progress:
                    _iter.close()
            if summary_proto is None:
                # add_summary(None, ...) would raise, so skip writing instead.
                logger.warning("Test set exhausted before the summary "
                               "operation ran; no test summary written.")
                return
            # Add summary with number of labelled examples as step.
            # NOTE this only runs on each major iteration.
            summary_writer.add_summary(summary_proto, _step)

        def select_lowest_confidence(scores, count):
            """
            Return the positions of the @count smallest entries of @scores.

            The order of the returned positions is unspecified. @count is
            clamped to [0, len(@scores)], so that requesting every entry (or
            more) returns all positions. A plain
            np.argpartition(scores, len(scores)) would raise ValueError there.
            """
            scores = np.asarray(scores)
            count = int(np.clip(count, 0, scores.size))
            if count >= scores.size:
                return np.arange(scores.size)
            return np.argpartition(scores, count)[:count]

        def rank_confidence():
            """
            Rank the unlabelled training examples by pseudo-label confidence.

            Runs the training input pipeline once, scattering the
            @pseudo_mean_confidence score of each fetched example into a
            per-example array. It then selects up to
            @alparams["selection_size"] unlabelled examples with the lowest
            score.

            :returns: tuple (low_conf_examples, unlabelled_confidence), where
                      low_conf_examples are indices into the full training
                      filename list, and unlabelled_confidence holds the
                      scores of @unlabelled in the same order as @unlabelled.
            """
            # Allocate array to store all confidence scores.
            # NOTE(review): examples never fetched by the iterator keep a
            # score of 0 and thus rank as least confident.
            num_examples = len(state["dataset"]["train"]["filenames"])
            confidence = np.zeros(num_examples, dtype=np.float32)
            # Initialize input stage
            train_input_stage.init_iterator("train", sess,
                                            train_input.feed_dict)
            _iter = range(0, train_input.size, params["batch_size"])
            if show_progress:
                _iter = tqdm.tqdm(_iter, desc="ranking[%d]" % len(labelled),
                                  ascii=True,
                                  dynamic_ncols=True)
            try:
                for _ in _iter:
                    # Loop over all examples and compute confidence
                    batch_confidence, batch_indices = sess.run(
                        [pseudo_mean_confidence, train_index])
                    # Scatter batch scores to their dataset positions.
                    confidence[batch_indices] = batch_confidence
            except tf.errors.OutOfRangeError:
                pass
            finally:
                # Only a tqdm bar needs closing; closing twice is harmless.
                if show_progress:
                    _iter.close()

            # Filter out labelled examples
            unlabelled_confidence = confidence[unlabelled]

            selection_size = np.minimum(len(unlabelled),
                                        alparams["selection_size"])
            # Lowest-confidence positions within the unlabelled subset.
            # The helper also handles selection_size == len(unlabelled),
            # including the empty case.
            example_indices = select_lowest_confidence(unlabelled_confidence,
                                                       selection_size)
            # Convert to indices into all filenames list
            low_conf_examples = unlabelled[example_indices]
            return low_conf_examples, unlabelled_confidence

        checkpoint_path = state["checkpoint"]
        # Only add graph to first event file
        _graph = sess.graph if checkpoint_path == None else None
        with tf.summary.FileWriter(args.log_dir, graph=_graph) as test_writer:
            iterations = alparams["iterations"]
            if iterations < 0:
                # Iterate untill all data is consumed
                iterations = np.ceil(len(unlabelled)
                                     / float(alparams["selection_size"]))
                logger.info("Iteration count: %d" % iterations)

            while state["iteration"] < iterations:
                # Step 1: train_loop
                train_input.set_indices(labelled)

                if state["iteration"] == 0:
                    # Pretrain
                    log_subdir = os.path.join(args.log_dir, "pretrain")
                    # Only use labelled subset
                else:
                    # Any other iteration
                    log_subdir = os.path.join(args.log_dir, "iter-%d" %
                                              state["iteration"])
                    # Sample from the unlabelled set
                    p = alparams["pseudo_labelling_proportion"]
                    sample_size = int(len(labelled)*p/(1-p))
                    sample_size = np.minimum(sample_size, len(unlabelled))
                    train_input.set_sample_size(sample_size)

                # Create subdir if it doesn't exist
                if not os.path.exists(log_subdir):
                    os.mkdir(log_subdir)

                # Change checkpoint manager directory
                ckpt_manager.chdir(log_subdir)
                with tf.summary.FileWriter(log_subdir) as train_val_writer:
                    # Enter train loop
                    try:
                        checkpoint_path = train_loop(train_val_writer)
                    except KeyboardInterrupt as exception:
                        # Quickly store state
                        if ckpt_manager.latest_checkpoint != "":
                            state["checkpoint"] = ckpt_manager.latest_checkpoint
                        with open(state_filename, "w") as f:
                            json.dump(state, f, indent=2)
                            f.truncate()
                        raise exception


                # Reload best checkpoint
                status = checkpoint.restore(checkpoint_path)
                status.run_restore_ops(sess)
                sess.run(update_val_op)

                # Step 2: test_loop
                if test_input.size > 0:
                    # This step may be omitted on deployment
                    test_loop(test_writer)

                # Step 3: Find low confidence examples
                # Reset train_input to use all examples for ranking
                train_input.set_indices()
                if alparams["selection_size"] > 0:
                    low_conf_examples, unlabelled_conf = rank_confidence()
                    _hist_summary = sess.run(conf_summary,
                                             {conf_summary_ph: 
                                              unlabelled_conf})
                    test_writer.add_summary(_hist_summary, state["iteration"])
                else:
                    # Draw examples randomly
                    selection_size = np.minimum(alparams["selection_size"],
                                                len(unlabelled.tolist()))
                    if selection_size != 0:
                        low_conf_examples = np.random.choice(
                            unlabelled, np.abs(alparams["selection_size"]))
                    else:
                        low_conf_examples = []

                # (maybe) Pause for user to annotate
                to_annotate_indices = no_label_indices[np.isin(
                    no_label_indices, low_conf_examples)]

                while len(to_annotate_indices) > 0:
                    to_annotate = train_examples[to_annotate_indices]
                    # Poll user for filenames of annotated examples
                    logger.info("Please annotate the following examples:\n%s" %
                                "\n".join(to_annotate_basename.tolist()))
                    filenames = tkinter.filedialog.askopenfilename(
                        multiple=1,
                        filetypes=(("TFRecord", "*.tfrecord"),))

                    hit = [] # List of matching filename indices
                    for filename in filenames:
                        basename = os.path.basename(filename)
                        idx = -1
                        for i in range(len(to_annotate)):
                            if to_annotate[i].endswith(basename):
                                idx = i
                                break
                        if idx != -1:
                            # Update state filenames
                            train_examples[to_annotate_indices[idx]] = filename
                            hit.append(idx)
                        else:
                            logger.info("Unrecognized filepath: %s" % filename)
                    # Remove matched paths
                    to_annotate_indices = np.delete(to_annotate_indices, hit)


                # Remove annotated examples from unlabelled set
                no_label_indices = no_label_indices[np.isin(no_label_indices,
                                                             low_conf_examples,
                                                             invert=True)]


                logger.info(
                    "Moving following examples to labelled set:\n%s" %
                    "\n".join(train_examples[low_conf_examples].tolist())
                )
                # First make the update to input stage before
                # commiting state change
                train_input_labelled[low_conf_examples] = True
                train_input.labelled = train_input_labelled


                # Step 4: Update state information
                labelled = np.append(labelled, low_conf_examples)
                unlabelled = unlabelled[np.isin(unlabelled, low_conf_examples,
                                            assume_unique=True, invert=True)]
                state["dataset"]["train"]["filenames"] = train_examples.tolist()
                state["dataset"]["train"]["labelled"] = labelled.tolist()
                state["dataset"]["train"]["unlabelled"] = unlabelled.tolist()
                state["iteration"] += 1
                state["checkpoint"] = checkpoint_path
                # Dump updated state
                with open(state_filename, "w") as f:
                    json.dump(state, f, indent=2)
                    f.truncate()
    return 0
示例#6
0
def _restore_training_checkpoint(checkpoint, ckpt, sess):
    """Restore ``checkpoint`` from the checkpoint prefix ``ckpt`` into ``sess``.

    Args:
        checkpoint: ``tf.train.Checkpoint`` object to restore into.
        ckpt: checkpoint file prefix (e.g. as returned by
            ``tf.train.latest_checkpoint``).
        sess: session in which the restore/initialization ops are run.
    """
    logger.info("Resuming from checkpoint \"%s\"" % ckpt)
    status = checkpoint.restore(ckpt)
    # Feature-detect ``expect_partial`` (the TF >= 1.14 API) instead of
    # comparing version strings: ``tf.__version__ < "1.14.0"`` compares
    # lexicographically and wrongly treats e.g. "1.9.0" as newer than
    # "1.14.0".
    if hasattr(status, "expect_partial"):
        status.expect_partial()
    else:
        status.assert_existing_objects_matched()
    status.initialize_or_restore(sess)


def main(args):
    """Train ENet on the selected dataset, logging summaries and checkpoints.

    Builds train/validation input pipelines, a training ENet (optionally
    weight-regularized) and a separate validation ENet whose weights are
    copied from the training network at the start of every epoch. Summaries
    and checkpoints are written to ``args.log_dir``.

    Args:
        args: parsed command line namespace. Fields read here: ``dataset``,
            ``coarse`` (cityscapes only), ``data_dir``, ``params``,
            ``log_dir`` and ``checkpoint``.

    Returns:
        0 on success, 1 if an explicitly given checkpoint path cannot be
        resolved to a checkpoint.

    Raises:
        NotImplementedError: if ``args.dataset`` is not one of
            "cityscapes", "freiburg" or "vistas".
    """
    # Retrieve dataset specific object
    if args.dataset == "cityscapes":
        dataset = datasets.Cityscapes(coarse=args.coarse)
    elif args.dataset == "freiburg":
        dataset = datasets.Freiburg()
    elif args.dataset == "vistas":
        dataset = datasets.Vistas()
    else:
        raise NotImplementedError("Dataset \"%s\" not supported" %
                                  args.dataset)
    # Gather train and validation paths
    train_paths = os.path.join(args.data_dir, "train")
    val_paths = os.path.join(args.data_dir, "val")
    # Retrieve training parameters
    params = args.params
    hparams = params["hyperparams"]

    with tf.device("/device:CPU:0"):
        with tf.name_scope("Datasets"):
            # Setup input pipelines
            train_input = tt.input.InputStage(input_shape=[
                params["network"]["input"]["height"],
                params["network"]["input"]["width"]
            ])
            val_input = tt.input.InputStage(input_shape=[
                params["network"]["input"]["height"],
                params["network"]["input"]["width"]
            ])

            # Add datasets
            train_examples = train_input.add_dataset(
                "train",
                train_paths,
                batch_size=params["batch_size"],
                epochs=1,
                augment=True)
            val_examples = val_input.add_dataset(
                "val", val_paths, batch_size=params["batch_size"], epochs=1)
            # Calculate number of batches (ceiling division)
            train_batches = (train_examples - 1) // params["batch_size"] + 1
            val_batches = (val_examples - 1) // params["batch_size"] + 1

            # Get iterator outputs
            # NOTE(review): the training stage is unpacked into four outputs
            # (first discarded) and the validation stage into three -- confirm
            # this asymmetry against InputStage.get_output.
            _, train_image, train_label, train_mask = train_input.get_output()
            val_image, val_label, val_mask = val_input.get_output()

        # Create step variables
        with tf.variable_scope("StepCounters"):
            # I'll use one local (to this run) and a global step that
            # will be checkpointed in order to run various schedules on
            # the learning rate decay policy.
            global_step = tf.Variable(0,
                                      dtype=tf.int64,
                                      trainable=False,
                                      name="GlobalStep")
            local_step = tf.Variable(0,
                                     dtype=tf.int64,
                                     trainable=False,
                                     name="LocalStep")
            global_step_op = global_step + local_step
            epoch_step = tf.Variable(0, trainable=False, name="EpochStep")
            epoch_step_inc = tf.assign_add(epoch_step, 1, name="EpochStepInc")

    regularization = {}
    if (hparams["weight_reg"]["L2"] > 0.0
            or hparams["weight_reg"]["L1"] > 0.0):
        regularization = {
            "weight_regularization":
            tf.keras.regularizers.l1_l2(l1=hparams["weight_reg"]["L1"],
                                        l2=hparams["weight_reg"]["L2"]),
            "regularization_scaling":
            hparams["weight_reg"]["glorot_scaling"]
        }
    # Build training and validation network and get prediction output
    train_net = models.ENet(dataset.num_classes, **regularization)
    val_net = models.ENet(dataset.num_classes)
    with tf.device("/device:GPU:0"):
        train_logits = train_net(train_image, training=True)
        train_pred = tf.math.argmax(train_logits,
                                    axis=-1,
                                    name="TrainPredictions")

    with tf.device("/device:GPU:1"):
        val_logits = val_net(val_image, training=False)
        val_pred = tf.math.argmax(val_logits,
                                  axis=-1,
                                  name="ValidationPredictions")

    # Build cost function
    with tf.name_scope("Cost"):
        with tf.device("/device:GPU:0"):
            # Establish loss function
            if hparams["softmax"]["multiscale"]:
                loss, loss_weights = \
                    tt.losses.multiscale_masked_softmax_cross_entropy(
                        train_label,
                        train_net.endpoint_outputs[0],
                        train_mask, dataset.num_classes,
                        weight=hparams["softmax"]["loginverse_scaling"],
                        label_smoothing=hparams["softmax"]["label_smoothing"],
                        scope="XEntropy")
                # NOTE: this will make @loss_weights checkpointed
                train_net.loss_scale_weights = loss_weights
            else:
                loss = tt.losses.masked_softmax_cross_entropy(
                    train_label,
                    train_logits,
                    train_mask,
                    dataset.num_classes,
                    weight=hparams["softmax"]["loginverse_scaling"],
                    label_smoothing=hparams["softmax"]["label_smoothing"],
                    scope="XEntropy")
            cost = loss
            # Add regularization to cost function
            if len(train_net.losses) > 0:
                regularization_loss = tf.math.add_n(train_net.losses,
                                                    name="Regularization")
                cost += tf.cast(regularization_loss, dtype=tf.float64)

            # Setup learning rate
            learning_rate = hparams["learning_rate"]
            if hparams["learning_rate_decay"] > 0.0:
                # Inverse time learning_rate if lr_decay specified
                learning_rate = tf.train.inverse_time_decay(
                    learning_rate,
                    local_step,
                    decay_steps=train_batches,
                    decay_rate=hparams["learning_rate_decay"])

            # Create optimization procedure
            optimizer = tf.train.AdamOptimizer(
                learning_rate, **hparams["optimizer"]["kwargs"])

            # Create training op
            train_op = optimizer.minimize(cost,
                                          global_step=local_step,
                                          name="TrainOp")
            # NOTE: Make sure to update batchnorm params and metrics for
            # each training iteration.

    # Create summary operations for training and validation network
    with tf.name_scope("Summary"):
        # Create colormap for image summaries
        colormap = tf.constant(dataset.colormap,
                               dtype=tf.uint8,
                               name="Colormap")
        # Create metric evaluation and summaries
        with tf.device("/device:GPU:0"):
            with tf.name_scope("TrainMetrics"):
                train_metrics = tt.metrics.Metrics(train_pred, train_label,
                                                   dataset.num_classes,
                                                   train_mask)
                metric_update_op = train_metrics.get_update_op()
                metric_summaries = train_metrics.get_summaries()

            train_summary_iter = tf.summary.merge([
                tf.summary.scalar("CrossEntropyLoss", loss, family="Losses"),
                tf.summary.scalar("TotalCost", cost, family="Losses"),
                tf.summary.scalar(
                    "LearningRate", learning_rate, family="Losses")
            ], name="IterationSummaries")
            with tf.control_dependencies([metric_update_op]):
                train_summary_epoch = tf.summary.merge([
                    metric_summaries["Metrics"],
                    metric_summaries["ConfusionMat"],
                ], name="EpochSummaries")

        # Create metric evaluation and summaries
        with tf.device("/device:GPU:1"):
            with tf.name_scope("ValidationMetrics"):
                val_metrics = tt.metrics.Metrics(val_pred, val_label,
                                                 dataset.num_classes, val_mask)
                val_metric_update_op = val_metrics.get_update_op()
                val_metric_summaries = val_metrics.get_summaries()

                with tf.control_dependencies([val_metric_update_op]):
                    val_summary_epoch = tf.summary.merge([
                        val_metric_summaries["Metrics"],
                        val_metric_summaries["ClassMetrics"],
                        val_metric_summaries["ConfusionMat"],
                        tf.summary.image("Input", val_image),
                        tf.summary.image(
                            "Label",
                            tf.gather(
                                colormap,
                                tf.cast(val_label + 255 *
                                        (1 - val_mask), tf.int32))),
                        tf.summary.image(
                            "Predictions",
                            tf.gather(colormap, tf.cast(val_pred, tf.int32)))
                    ], name="EpochSummaries")
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
        # Dump parameter configuration (args)
        with open(os.path.join(args.log_dir, "config.json"), "w+") as f:
            json.dump(params, f, indent=4, sort_keys=True)

    # Create session with soft device placement
    #     - some ops need to run on the CPU
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=sess_config) as sess:
        # Initialize/restore model variables
        logger.debug("Initializing model...")
        sess.run(tf.global_variables_initializer())
        # Create summary writer objects
        summary_writer = tf.summary.FileWriter(args.log_dir, graph=sess.graph)

        # Create checkpoint object
        with tf.name_scope("Checkpoint"):
            checkpoint = tf.train.Checkpoint(model=train_net,
                                             epoch=epoch_step,
                                             step=global_step,
                                             optimizer=optimizer)
            checkpoint_name = os.path.join(args.log_dir, "model")

            if args.checkpoint is not None:
                # CMDline checkpoint given (file prefix or directory)
                ckpt = args.checkpoint
                if os.path.isdir(ckpt):
                    ckpt = tf.train.latest_checkpoint(ckpt)
                if ckpt is None:
                    logger.error("Checkpoint path \"%s\" is invalid." %
                                 args.checkpoint)
                    summary_writer.close()
                    return 1
            else:
                # Try to restore from checkpoint in logdir (None if absent)
                ckpt = tf.train.latest_checkpoint(args.log_dir)
            if ckpt is not None:
                _restore_training_checkpoint(checkpoint, ckpt, sess)

            with tf.name_scope("UpdateValidationWeights"):
                # Copy every training-network variable onto the matching
                # validation-network variable (same layer/variable order).
                update_val_op = []
                for val_layer, train_layer in zip(val_net.layers,
                                                  train_net.layers):
                    for j in range(len(val_layer.variables)):
                        update_val_op.append(
                            tf.assign(val_layer.variables[j],
                                      train_layer.variables[j]))
                update_val_op = tf.group(update_val_op)
        # END scope Checkpoint

        # Prepare fetches
        fetches = {
            "train": {
                "iteration": {
                    "step": global_step_op,
                    "summary": train_summary_iter,
                    "train_op": train_op,
                    "update": metric_update_op,
                    "updates": train_net.updates
                },
                "epoch": {
                    "step": epoch_step,
                    "summary": train_summary_epoch
                }
            },
            "val": {
                "iteration": {
                    "update": val_metric_update_op
                },
                "epoch": {
                    "step": epoch_step,
                    "summary": val_summary_epoch
                }
            }
        }
        logger.info("Starting training loop...")
        results = {}
        for epoch in range(1, params["epochs"] + 1):
            # Create iterator counter to track progress
            # NOTE(review): ``show_progress`` is not defined in this function;
            # presumably a module-level flag -- confirm.
            _iter = range(0, train_batches)
            if show_progress:
                _iter = tqdm.tqdm(_iter,
                                  desc="train[%3d/%3d]" %
                                  (epoch, params["epochs"]),
                                  ascii=True,
                                  dynamic_ncols=True)
            # Initialize input stage
            train_input.init_iterator("train", sess)
            val_input.init_iterator("val", sess)
            # Initialize or update validation network
            sess.run(update_val_op)
            # Reset for another round
            train_metrics.reset_metrics(sess)
            val_metrics.reset_metrics(sess)
            # Prepare initial fetches
            _fetches = {
                "train": {
                    "iteration": fetches["train"]["iteration"]
                },
                "val": {
                    "iteration": fetches["val"]["iteration"]
                }
            }

            for i in _iter:
                try:
                    # Dynamically update fetches
                    if i == train_batches - 1:
                        _fetches["train"]["epoch"] = fetches["train"]["epoch"]
                    if i == val_batches - 1:
                        _fetches["val"]["epoch"] = fetches["val"]["epoch"]
                    elif i == val_batches:
                        # Validation finished on the previous iteration:
                        # write its summary and stop running val ops.
                        summary_writer.add_summary(
                            results["val"]["epoch"]["summary"],
                            results["val"]["epoch"]["step"])
                        _fetches.pop("val")
                    # Execute fetches
                    results = sess.run(_fetches)
                except tf.errors.OutOfRangeError:
                    pass
                # Update summaries
                summary_writer.add_summary(
                    results["train"]["iteration"]["summary"],
                    results["train"]["iteration"]["step"])

            # When validation has exactly as many batches as training, its
            # epoch summary is fetched on the last iteration but the
            # ``elif i == val_batches`` branch is never reached; write it
            # here so it is not silently dropped.
            if "val" in _fetches and "epoch" in results.get("val", {}):
                summary_writer.add_summary(results["val"]["epoch"]["summary"],
                                           results["val"]["epoch"]["step"])

            # Update epoch counter
            sess.run(epoch_step_inc)

            # Update epoch summaries
            summary_writer.add_summary(results["train"]["epoch"]["summary"],
                                       results["train"]["epoch"]["step"])
            summary_writer.flush()
            # Save checkpoint
            checkpoint.save(checkpoint_name, sess)

        ### FINAL VALIDATION ###
        _fetches = {"val": {"iteration": fetches["val"]["iteration"]}}
        _iter = range(0, val_batches)
        if show_progress:
            _iter = tqdm.tqdm(_iter,
                              desc="val[%3d/%3d]" %
                              (params["epochs"], params["epochs"]))
        # Re initialize network
        val_input.init_iterator("val", sess)
        sess.run(update_val_op)
        for i in _iter:
            try:
                if i >= val_batches - 1:
                    _fetches["val"]["epoch"] = fetches["val"]["epoch"]
                results = sess.run(_fetches)
            except tf.errors.OutOfRangeError:
                pass
        # Add final validation summary update
        summary_writer.add_summary(results["val"]["epoch"]["summary"],
                                   results["val"]["epoch"]["step"])
        # Close summary file
        summary_writer.close()
        logger.info("Training successfully finished %d epochs" %
                    params["epochs"])
    return 0