Beispiel #1
0
import torch
import ubelt as ub
import torchvision  # NOQA
import torch.nn as nn
import math
import torch  # NOQA
import torch.nn.functional as F
from clab import util
from clab.models import mixin
from clab.models.output_shape_for import OutputShapeFor
import numpy as np

from clab import util  # NOQA
from clab import getLogger
# Module-level logger obtained from clab's getLogger, keyed on this module's name.
logger = getLogger(__name__)
# Deliberately shadows the builtin ``print`` for this module: the replacement is
# built by util.protect_print from ``logger.info``.
# NOTE(review): presumably this routes module output through the logger — confirm
# the exact behavior of protect_print in clab.util.
print = util.protect_print(logger.info)


def default_nonlinearity():
    """
    Construct the activation module used as this file's default nonlinearity.

    Returns:
        nn.LeakyReLU: a new LeakyReLU module created with ``inplace=True``
        (the default negative slope is used). A fresh instance is returned on
        every call.
    """
    activation_cls = nn.LeakyReLU
    activation = activation_cls(inplace=True)
    return activation


class DenseLayer(nn.Sequential):
    """
    Layer module subclassing ``nn.Sequential``.

    Example:
        self = DenseLayer(32, 32, 4)

    NOTE(review): in the constructor shown here, only ``bn_size`` is stored;
    ``num_input_features``, ``growth_rate`` and ``drop_rate`` are not used.
    The rest of ``__init__`` appears to be cut off — confirm against the full
    source before relying on this description.
    """
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate=0):
        # util.super2 is a clab helper; presumably equivalent to
        # super(DenseLayer, self) — confirm in clab.util.
        util.super2(DenseLayer, self).__init__()
        # Keep the bottleneck-size argument on the instance.
        self.bn_size = bn_size
Beispiel #2
0
from clab.live import unet2
from clab.live import unet3
from clab.live.urban_metrics import instance_fscore
from clab.live.urban_pred import seeded_instance_label_from_probs
from clab.tasks.urban_mapper_3d import UrbanMapper3D
from clab.torch import criterions
from clab.torch import hyperparams
from clab.torch import im_loaders
from clab.torch import metrics
from clab.torch import models
from clab.torch import transforms
from clab.torch import xpu_device
from clab.torch.transforms import (ImageCenterScale, DTMCenterScale, ZipTransforms)
from clab.torch.transforms import (RandomWarpAffine, RandomGamma, RandomBlur,)

# Rebind this module's ``print`` to the result of util.protect_print applied to
# whatever ``print`` currently refers to.
# NOTE(review): if ``print`` was already rebound earlier in the module, this wraps
# the existing wrapper — confirm protect_print tolerates that.
print = util.protect_print(print)

# Module-level flag evaluated once at import time via ub.argflag('--debug');
# presumably True when ``--debug`` appears on the command line — confirm in ubelt.
DEBUG = ub.argflag('--debug')


def package_pretrained_submission():
    """
    Gather the models trained during phase 1 and output them in a format
    useable by the phase 2 solution. Note: remember to put the output folder
    into docker.
    """
    # model1 = '/home/local/KHQ/jon.crall/data/work/urban_mapper2/test/input_26400-sotwptrx/solver_52200-fqljkqlk_unet2_ybypbjtw_smvuzfkv_a=1,c=RGB,n_ch=6,n_cl=4/_epoch_00000000/stitched'

    # model2 = '/home/local/KHQ/jon.crall/data/work/urban_mapper4/test/input_26400-fgetszbh/solver_25800-phpjjsqu_dense_unet_mmavmuou_zeosddyf_a=1,c=RGB,n_ch=6,n_cl=4/_epoch_00000026/stitched'

    # Localize the data