Ejemplo n.º 1
0
def run_eval_suite(name,
                   dest_task=tasks.normal,
                   graph_file=None,
                   model_file=None,
                   logger=None,
                   sample=800,
                   show_images=False,
                   old=False):
    """Load an ``rgb -> dest_task`` model and log its validation metrics.

    The model comes from the first matching source:

    1. ``graph_file``: weights of a ``TaskGraph`` over ``[rgb, dest_task]``;
       its ``rgb -> dest_task`` edge is used.
    2. ``old``: a ``UNetOld`` checkpoint loaded from ``model_file``.
    3. ``model_file``: a ``UNet(downsample=6)`` checkpoint.
    4. otherwise: the default ``Transfer(rgb, dest_task)`` model.

    Args:
        name: label prefixed to the logged result.
        dest_task: target task the metrics are computed for.
        graph_file: optional path to TaskGraph weights.
        model_file: optional checkpoint path for the UNet variants.
        logger: object with a ``text(str)`` method; when ``None`` the result
            line is printed instead.
        sample: number of validation samples passed to ``evaluate``.
        show_images: accepted for interface compatibility; unused here.
        old: load ``model_file`` into the legacy ``UNetOld`` architecture.

    Returns:
        Whatever ``ValidationMetrics.evaluate`` returns.
    """
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif model_file is not None:
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        # Same rgb -> dest_task edge the graph branch evaluates (the source
        # task was previously tasks.normal, i.e. normal -> normal by default).
        model = Transfer(src_task=tasks.rgb,
                         dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    dataset = ValidationMetrics("almena", dest_task=dest_task)
    result = dataset.evaluate(model, sample=sample)
    line = name + ": " + str(result)
    if logger is not None:
        logger.text(line)
    else:
        print(line)
    return result
Ejemplo n.º 2
0
def s3BuildOps(conf):
    """
    Compare a local folder with what's already in S3 and build, per file, the
    operation needed to sync them in the configured direction.

    :param conf: configuration dict. Keys read here: 'bucket', 'keyprefix',
        'direction' (an S3Operation.Direction value) and 'localroot'; the
        whole dict is also handed to every S3Operation.
    :return: dict mapping each file's key (relative to 'keyprefix') to its
        S3Operation.
    """
    s3 = Transfer(conf['bucket'])
    opstore = {}
    log = Logger("s3BuildOps")
    prefix = "{0}/".format(conf['keyprefix']).replace("//", "/")

    log.title('The following locations were found:')
    remotestr = 's3://{0}/{1}'.format(conf['bucket'], conf['keyprefix'])
    if conf['direction'] == S3Operation.Direction.UP:
        fromstr, tostr = conf['localroot'], remotestr
    else:
        fromstr, tostr = remotestr, conf['localroot']
    log.info('FROM: {0}'.format(fromstr))
    log.info('TO  : {0}'.format(tostr))

    log.title('The following operations are queued:')

    response = s3.list(prefix)

    # Get all the files we have locally (localProductWalker fills `files`)
    files = {}
    if os.path.isdir(conf['localroot']):
        localProductWalker(conf['localroot'], files)

    # Fill in any files we find on the remote
    if 'Contents' in response:
        for result in response['Contents']:
            key = result['Key']
            # Remove the prefix only from the front of the key; str.replace
            # would also strip any later occurrence inside the relative path.
            dstkey = key[len(prefix):] if key.startswith(prefix) else key
            files.setdefault(dstkey, {})['dst'] = result

    for relname, fileobj in files.items():
        opstore[relname] = S3Operation(relname, fileobj, conf)

    if not opstore:
        log.info("-- NO Operations Queued --")

    return opstore
Ejemplo n.º 3
0
def run_viz_suite(name,
                  data,
                  dest_task=tasks.depth_zbuffer,
                  graph_file=None,
                  model_file=None,
                  logger=None,
                  old=False,
                  multitask=False,
                  percep_mode=None):
    """Predict ``dest_task`` outputs for ``data`` (for visualization).

    The model is chosen from ``graph_file``, ``old`` (UNetOld checkpoint),
    ``multitask`` (6-channel UNet checkpoint), ``model_file`` (UNet with
    downsample=6), or else the default ``Transfer(rgb, dest_task)`` model.

    Args:
        name: accepted for interface compatibility; unused here.
        data: input passed to ``model.predict``.
        dest_task: task the model predicts.
        graph_file: optional path to TaskGraph weights.
        model_file: checkpoint path for the UNet variants.
        logger: accepted for interface compatibility; unused here.
        old: load ``model_file`` into the legacy ``UNetOld``.
        multitask: load ``model_file`` into a 6-output-channel UNet.
        percep_mode: when truthy, the predictions are additionally mapped
            ``dest_task -> normal`` by a pretrained perceptual model.

    Returns:
        The last 3 predicted channels clamped to [0, 1] (single-channel
        output tiled to 3 channels), or the perceptual model's outputs
        concatenated over batches when ``percep_mode`` is set.
    """
    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task], pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    # Keep the last 3 channels in [0, 1]; tile single-channel output to 3.
    results = model.predict(data)[:, -3:].clamp(min=0, max=1)
    if results.shape[1] == 1:
        results = torch.cat([results] * 3, dim=1)

    if percep_mode:
        percep_model = Transfer(src_task=dest_task,
                                dest_task=tasks.normal).load_model()
        percep_model.eval()
        eval_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(results),
            batch_size=16,
            num_workers=16,
            shuffle=False,
            pin_memory=True)
        final_preds = []
        # Inference only: don't build an autograd graph for every batch.
        with torch.no_grad():
            for preds, in eval_loader:
                final_preds.append(percep_model.forward(preds[:, -3:]))
        results = torch.cat(final_preds, dim=0)

    return results
Ejemplo n.º 4
0
def s3GetFolderList(bucket, prefix):
    """
    List the names of the immediate sub-folders under `prefix` in a bucket.

    :param bucket: name of the S3 bucket (passed to Transfer)
    :param prefix: key prefix to list under; for bare folder names it should
        end with '/'
    :return: list of folder names with the prefix and slashes removed
    """
    s3 = Transfer(bucket)
    results = []
    # With Delimiter='/', deeper keys are rolled up into CommonPrefixes,
    # one entry per immediate sub-folder.
    response = s3.list(prefix, Delimiter='/')
    if 'CommonPrefixes' in response:
        for o in response.get('CommonPrefixes'):
            folder = o['Prefix']
            # Strip only the leading prefix: str.replace would also delete it
            # from a sub-folder that repeats the prefix (e.g. 'foo/foo/').
            if folder.startswith(prefix):
                folder = folder[len(prefix):]
            results.append(folder.replace('/', ''))
    return results
Ejemplo n.º 5
0
def s3ProductWalker(bucket, patharr, currpath=None, currlevel=0):
    """
    Given a path array ending in a Product, walk the S3 bucket recursively
    and log every product XML file found.

    :param bucket: name of the S3 bucket
    :param patharr: list of level descriptors; each is a dict whose 'type' is
        'collection', 'group' or 'product' and, for 'group'/'product', whose
        'folder' is the folder name at that level
    :param currpath: list of key segments walked so far (None = bucket root);
        the caller's list is never modified
    :param currlevel: index into patharr of the level being processed
    :return: None (found products are written to the log)
    """
    if currlevel >= len(patharr):
        return

    # Work on a private copy: the old mutable default list was appended to
    # below and so leaked path segments into later top-level calls.
    currpath = [] if currpath is None else list(currpath)
    level = patharr[currlevel]
    s3 = Transfer(bucket)

    # A collection: iterate over its folders and recurse on each one
    if level['type'] == 'collection':
        pref = "/".join(currpath) + "/" if currpath else ""
        result = s3.list(pref, Delimiter='/')
        if 'CommonPrefixes' in result:
            for o in result.get('CommonPrefixes'):
                # Drop the trailing '/' and split into path segments
                s3ProductWalker(bucket, patharr,
                                o.get('Prefix')[:-1].split('/'), currlevel + 1)

    # A group: no iteration necessary, append its folder and recurse
    elif level['type'] == 'group':
        currpath.append(level['folder'])
        s3ProductWalker(bucket, patharr, currpath, currlevel + 1)

    # A product: log every .xml file directly inside its folder
    elif level['type'] == 'product':
        log = Logger('ProductWalk')
        currpath.append(level['folder'])
        result = s3.list("/".join(currpath) + "/", Delimiter='/')
        if 'Contents' in result:
            for c in result['Contents']:
                if os.path.splitext(c['Key'])[1] == '.xml':
                    log.info('Project: {0} (Modified: {1})'.format(
                        c['Key'], c['LastModified']))
Ejemplo n.º 6
0
def run_perceptual_eval_suite(name,
                              intermediate_task=tasks.normal,
                              dest_task=tasks.normal,
                              graph_file=None,
                              model_file=None,
                              logger=None,
                              sample=800,
                              show_images=False,
                              old=False,
                              perceptual_transfer=None,
                              multitask=False):
    """Evaluate ``rgb -> intermediate_task`` followed by a perceptual
    ``intermediate_task -> dest_task`` model, and log the metrics.

    Args:
        name: label prefixed to the logged result.
        intermediate_task: task predicted by the main model.
        dest_task: task predicted by the perceptual model and evaluated.
        graph_file: optional path to TaskGraph weights for the main model.
        model_file: checkpoint path for the UNet variants.
        logger: object with a ``text(str)`` method; when ``None`` the result
            line is printed instead.
        sample: number of validation samples passed to the evaluation.
        show_images: accepted for interface compatibility; unused here.
        old: load ``model_file`` into the legacy ``UNetOld``.
        perceptual_transfer: optional perceptual model to use instead of the
            default ``Transfer(intermediate_task, dest_task)``. Objects with
            a ``load_model`` method (presumably ``Transfer`` instances) are
            loaded first; anything else is used as the model directly.
        multitask: load ``model_file`` into a 6-output-channel UNet.

    Returns:
        Whatever ``ValidationMetrics.evaluate_with_percep`` returns.
    """
    # Previously percep_model was left unbound whenever a
    # perceptual_transfer was supplied.
    if perceptual_transfer is None:
        percep_model = Transfer(src_task=intermediate_task,
                                dest_task=dest_task).load_model()
    elif hasattr(perceptual_transfer, "load_model"):
        percep_model = perceptual_transfer.load_model()
    else:
        percep_model = perceptual_transfer

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, intermediate_task],
                          pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, intermediate_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        print('running multitask')
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        model = DataParallelModel.load(UNet(downsample=6).cuda(), model_file)
    else:
        model = Transfer(src_task=tasks.rgb,
                         dest_task=intermediate_task).load_model()

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    dataset = ValidationMetrics("almena", dest_task=dest_task)
    result = dataset.evaluate_with_percep(model,
                                          sample=sample,
                                          percep_model=percep_model)
    line = name + ": " + str(result)
    if logger is not None:
        logger.text(line)
    else:
        print(line)
    return result
Ejemplo n.º 7
0
def transfers():
    """Flask view for the transfers page.

    GET renders ``transfers.html`` with all stored transfers. On any other
    method a ``Transfer`` is built from the submitted form; it is added when
    the ``submit`` field is "Save", otherwise it replaces the transfer whose
    ``id`` was submitted. The response is then a redirect back to this view.
    """
    if request.method == 'GET':
        return render_template('transfers.html',
                               transfers=app.transfers.get_transfers())

    form = request.form
    is_new = form['submit'] == "Save"
    # Fields are read in the same order as before, id first when updating.
    transfer_id = None if is_new else form['id']
    season_id = form['seasonID']
    old_id = form['oldID']
    new_id = form['newID']
    player_id = form['playerID']
    fee = form['fee']
    transfer = Transfer(season_id, player_id, old_id, new_id, fee)

    if is_new:
        app.transfers.add_transfer(transfer)
    else:
        app.transfers.update_transfer(transfer_id, transfer)

    return redirect(url_for('transfers'))
Ejemplo n.º 8
0
def run_viz_suite(name,
                  data_loader,
                  dest_task=tasks.depth_zbuffer,
                  graph_file=None,
                  model_file=None,
                  old=False,
                  multitask=False,
                  percep_mode=None,
                  downsample=6,
                  out_channels=3,
                  final_task=tasks.normal,
                  oldpercep=False):
    """Run ``rgb -> dest_task`` over ``data_loader`` and collect predictions.

    Args:
        name: accepted for interface compatibility; unused here.
        data_loader: iterable of 1-tuples ``(batch,)``.
        dest_task: task predicted by the main model.
        graph_file: optional TaskGraph weights; also supplies the perceptual
            edge unless ``oldpercep`` is set.
        model_file: checkpoint path for the UNet variants.
        old: load ``model_file`` into the legacy ``UNetOld``.
        multitask: load ``model_file`` into a 6-output-channel UNet.
        percep_mode: when truthy, map predictions ``dest_task -> final_task``
            and return those instead.
        downsample, out_channels: UNet configuration used with ``model_file``.
        final_task: target task of the perceptual model.
        oldpercep: use the default ``Transfer`` perceptual model even when a
            graph is loaded.

    Returns:
        CPU tensor of all batches concatenated along dim 0: either the last
        3 predicted channels clamped to [0, 1], or the perceptual outputs.
    """
    extra_task = [final_task] if percep_mode else []

    if graph_file is not None:
        graph = TaskGraph(tasks=[tasks.rgb, dest_task] + extra_task,
                          pretrained=False)
        graph.load_weights(graph_file)
        model = graph.edge(tasks.rgb, dest_task).load_model()
    elif old:
        model = DataParallelModel.load(UNetOld().cuda(), model_file)
    elif multitask:
        model = DataParallelModel.load(
            UNet(downsample=5, out_channels=6).cuda(), model_file)
    elif model_file is not None:
        # downsample is typically 5 or 6
        print('loading main model')
        model = DataParallelModel.load(
            UNet(downsample=downsample, out_channels=out_channels).cuda(),
            model_file)
    else:
        model = DummyModel(
            Transfer(src_task=tasks.rgb, dest_task=dest_task).load_model())

    model.compile(torch.optim.Adam, lr=3e-4, weight_decay=2e-6, amsgrad=True)

    if percep_mode:
        print('Loading percep model...')
        if graph_file is not None and not oldpercep:
            percep_model = graph.edge(dest_task, final_task).load_model()
            percep_model.compile(torch.optim.Adam,
                                 lr=3e-4,
                                 weight_decay=2e-6,
                                 amsgrad=True)
        else:
            percep_model = Transfer(src_task=dest_task,
                                    dest_task=final_task).load_model()
        percep_model.eval()

    # Only the outputs that are returned are kept; percep mode no longer also
    # stores every raw prediction that was then discarded.
    outputs = []
    print("Converting...")
    for data, in data_loader:
        preds = model.predict_on_batch(data)[:, -3:].clamp(min=0, max=1)
        if not percep_mode:
            outputs.append(preds.detach().cpu())
            continue
        try:
            out = percep_model.forward(preds[:, -3:])
        except RuntimeError:
            # Presumably a channel mismatch from a narrow prediction: retry
            # with the channels tiled. With 3 channels the retry input would
            # be identical, so surface the original error instead.
            if preds.shape[1] >= 3:
                raise
            preds = torch.cat([preds] * 3, dim=1)
            out = percep_model.forward(preds[:, -3:])
        outputs.append(out.detach().cpu())

    return torch.cat(outputs, dim=0)
Ejemplo n.º 9
0
    def __init__(
        self, tasks, tasks_in=None, tasks_out=None,
        pretrained=True,
        freeze_list=None, direct_edges=None, lazy=False,
        model_class='resnet_based', models_dir="./models"
    ):
        """Build a transfer graph whose tasks may also connect through the
        latent-space task ``task_configs.tasks.LS``.

        Args:
            tasks: tasks in the graph; each task's ``base`` (when it has one
                and it is not already listed) is appended. The caller's
                sequence is copied, not modified.
            tasks_in: dict with optional ``"edges"`` (tasks that get an
                ``LS -> task`` edge) and ``"freeze"`` (those not to train).
            tasks_out: same layout as ``tasks_in`` for ``task -> LS`` edges.
            pretrained: forwarded to the direct task-to-task ``Transfer``s.
            freeze_list: ``str((src.name, dest.name))`` keys of direct edges
                that should not be trained.
            direct_edges: ``str((src.name, dest.name))`` keys of the direct
                task-to-task edges to build.
            lazy: when True, models are not loaded here.
            model_class: key into ``model_types`` selecting the latent-space
                ``{task name: {"down"/"up": (model_type, path)}}`` table.
            models_dir: directory the latent-space paths are relative to;
                a missing file leaves the transfer's ``path`` as None.
        """
        super().__init__()
        # Copy so appending base tasks never mutates the caller's list.
        self.tasks = list(tasks)
        for base in [task.base for task in self.tasks if hasattr(task, "base")]:
            # Duplicated tasks would only rebuild (and reload) the same edges.
            if base not in self.tasks:
                self.tasks.append(base)
        self.tasks_in = {} if tasks_in is None else tasks_in
        self.tasks_out = {} if tasks_out is None else tasks_out
        self.pretrained = pretrained
        self.edges_in, self.edges_out = {}, {}
        self.direct_edges = {} if direct_edges is None else direct_edges
        self.freeze_list = [] if freeze_list is None else freeze_list
        self.edge_map = {}
        print('Creating graph with tasks:', self.tasks)
        self.params = {}
        transfer_models = model_types[model_class]
        ls_task = task_configs.tasks.LS

        def add_latent_edge(task, direction, spec, edges):
            """Create, register and (unless lazy) load one latent edge:
            direction "down" is task -> LS, "up" is LS -> task."""
            if direction == "down":
                key = str((task.name, "LS"))
                src, dest = task, ls_task
            else:
                key = str(("LS", task.name))
                src, dest = ls_task, task
            model_type, path = transfer_models.get(task.name, {})[direction]
            path = os.path.join(models_dir, path)
            if not os.path.isfile(path):
                path = None
            transfer = Transfer(src, dest, model_type=model_type, path=path)
            transfer.freezed = task in (spec.get("freeze") or [])
            edges[task.name] = transfer
            self.edge_map[key] = transfer

            try:
                if not lazy:
                    transfer.load_model()
            except Exception as e:
                print(e)
                IPython.embed()

            if transfer.freezed:
                transfer.set_requires_grad(False)
            else:
                self.params[key] = transfer

        # A missing (or None) "edges"/"freeze" entry means "none"; iterating
        # None used to raise TypeError, even for the default arguments.
        for task in self.tasks_out.get("edges") or []:
            add_latent_edge(task, "down", self.tasks_out, self.edges_out)
        for task in self.tasks_in.get("edges") or []:
            add_latent_edge(task, "up", self.tasks_in, self.edges_in)

        # construct the direct task-to-task transfer graph
        for src_task, dest_task in itertools.product(self.tasks, self.tasks):
            key = str((src_task.name, dest_task.name))
            if src_task == dest_task:
                continue
            if isinstance(dest_task, RealityTask):
                continue
            # latent-space edges were built above
            if src_task == ls_task or dest_task == ls_task:
                continue
            if isinstance(src_task, RealityTask):
                self.edge_map[key] = RealityTransfer(src_task, dest_task)
                continue
            if key not in self.direct_edges:
                continue

            transfer = Transfer(src_task, dest_task, pretrained=pretrained)
            transfer.freezed = key in self.freeze_list

            try:
                if not lazy:
                    transfer.load_model()
            except Exception as e:
                print(e)
                IPython.embed()

            if transfer.model_type is None:
                continue
            if not transfer.freezed:
                self.params[key] = transfer
            else:
                print("Setting link: " + str(key) + " not trainable.")
                transfer.set_requires_grad(False)
            self.edge_map[key] = transfer

        self.params = nn.ModuleDict(self.params)
Ejemplo n.º 10
0
class S3Operation:
    """
    A simple class for storing src/dst file information and the operation we need to perform
    """
    class FileOps:
        # Kind of an enumeration
        DELETE_REMOTE = "Delete Remote"
        DELETE_LOCAL = "Delete Local"
        UPLOAD = "Upload"
        DOWNLOAD = "Download"
        IGNORE = "Ignore"

    class Direction:
        # Kind of an enumeration
        UP = "up"
        DOWN = "down"

    class FileState:
        # Kind of an enumeration
        LOCALONLY = "Local-Only"
        REMOTEONLY = "Remote-Only"
        UPDATENEEDED = "Update Needed"
        SAME = "Files Match"

    def __init__(self, key, fileobj, conf):
        """
        :param key: The relative key/path of the file in question
        :param fileobj: dict that may contain 'src' (local file info) and/or
            'dst' (the remote listing entry, whose 'Size' is read)
        :param conf: the configuration dictionary; keys read: 'bucket',
            'delete', 'force', 'localroot', 'direction', 'keyprefix'
        """
        self.log = Logger('S3Ops')
        self.s3 = Transfer(conf['bucket'])
        self.key = key

        self.delete = conf['delete']
        self.force = conf['force']
        self.localroot = conf['localroot']
        self.bucket = conf['bucket']
        self.direction = conf['direction']
        self.keyprefix = conf['keyprefix']

        # And the final paths we use:
        self.abspath = self.getAbsLocalPath()
        self.fullkey = self.getS3Key()

        # The remote size (if it exists) helps us figure out percent done
        self.s3size = fileobj['dst']['Size'] if 'dst' in fileobj else 0

        self.filestate = self._resolveFileState(fileobj)
        self.op = self._resolveOperation()

        self.log.info(str(self))

    def _resolveFileState(self, fileobj):
        """
        Classify a file by which sides ('src' local, 'dst' remote) exist.

        :param fileobj: dict that may contain 'src' and/or 'dst'
        :return: a FileState value; SAME when neither side is present
        """
        haslocal = 'src' in fileobj
        hasremote = 'dst' in fileobj
        if haslocal and hasremote:
            if s3issame(fileobj['src'], fileobj['dst']):
                return self.FileState.SAME
            return self.FileState.UPDATENEEDED
        if haslocal:
            return self.FileState.LOCALONLY
        if hasremote:
            return self.FileState.REMOTEONLY
        return self.FileState.SAME

    def _resolveOperation(self):
        """
        Pick the FileOps value for self.filestate given direction/force/delete.

        'force' only re-transfers files that already match. It must not turn
        a missing source file into a transfer: comparing the bare constant
        `self.FileState.SAME` (always truthy) used to do exactly that.
        """
        state = self.filestate
        if self.direction == self.Direction.UP:
            # Two cases for uploading the file: New file or different file
            if state in (self.FileState.LOCALONLY, self.FileState.UPDATENEEDED):
                return self.FileOps.UPLOAD
            # If we've requested a force, do the upload anyway
            if state == self.FileState.SAME and self.force:
                return self.FileOps.UPLOAD
            # Remote exists but local doesn't: clean up the remote (needs delete flag)
            if state == self.FileState.REMOTEONLY and self.delete:
                return self.FileOps.DELETE_REMOTE
        elif self.direction == self.Direction.DOWN:
            if state in (self.FileState.REMOTEONLY, self.FileState.UPDATENEEDED):
                return self.FileOps.DOWNLOAD
            # If we've requested a force, do the download anyway
            if state == self.FileState.SAME and self.force:
                return self.FileOps.DOWNLOAD
            # Local exists but remote doesn't: clean up the local (needs delete flag)
            if state == self.FileState.LOCALONLY and self.delete:
                return self.FileOps.DELETE_LOCAL
        return self.FileOps.IGNORE

    def getS3Key(self):
        # Built by hand rather than with os.path.join so the key always uses
        # '/' as its separator, whatever the local OS.
        return "{0}/{1}".format(self.keyprefix, self.key)

    def getAbsLocalPath(self):
        # Local paths use the OS-specific separator, hence os.path.join.
        return os.path.join(self.localroot, self.key)

    def execute(self):
        """
        Actually run the command to upload/download/delete the file
        :return:
        """

        if self.op == self.FileOps.IGNORE:
            self.log.info(" [{0}] {1}: Nothing to do. Continuing.".format(
                self.op, self.key))

        elif self.op == self.FileOps.UPLOAD:
            self.upload()

        elif self.op == self.FileOps.DOWNLOAD:
            self.download()

        elif self.op == self.FileOps.DELETE_LOCAL:
            self.delete_local()

        elif self.op == self.FileOps.DELETE_REMOTE:
            self.delete_remote()

    def __repr__(self):
        """
        When we print this class as a string this is what we output
        """
        forcestr = "(FORCE)" if self.force else ""
        opstr = "{0:12s} ={2}=> {1:10s}".format(self.filestate, self.op,
                                                forcestr)
        return "./{1:60s} [ {0:21s} ]".format(opstr.strip(), self.key)

    def delete_remote(self):
        """
        Delete a Remote file
        """
        self.log.info("Deleting: {0} ==> ".format(self.fullkey))
        # This step prints straight to stdout and does not log
        self.s3.delete(self.fullkey)
        self.log.debug("S3 Deletion Completed: {0}".format(self.fullkey))

    def delete_local(self):
        """
        Delete a local file, then remove any parent folders left empty,
        stopping before localroot.
        """
        dirname = os.path.dirname(self.abspath)
        self.log.info("Deleting Local file: {0} ==> ".format(self.abspath))
        os.remove(self.abspath)
        # Walk backwards cleaning up empty folders. Unlike os.removedirs this
        # never removes localroot itself or anything above it.
        root = os.path.abspath(self.localroot)
        folder = os.path.abspath(dirname)
        while folder != root and folder.startswith(os.path.join(root, '')):
            try:
                os.rmdir(folder)
            except OSError:
                self.log.debug(
                    'Folder cleanup stopped since there were still files: {0}'.
                    format(folder))
                break
            self.log.debug('Cleaning up folders: {0}'.format(folder))
            folder = os.path.dirname(folder)
        self.log.debug("Local Deletion Completed: {0}".format(self.abspath))

    def download(self):
        """
        Download this operation's S3 key to its local path, creating the
        local directory first if needed.
        """
        log = Logger('S3FileDownload')

        # Make a directory if that's needed
        dirpath = os.path.dirname(self.abspath)
        if not os.path.exists(dirpath):
            try:
                os.makedirs(dirpath)
            except OSError:
                # Someone else may have created it in the meantime
                if not os.path.isdir(dirpath):
                    raise Exception(
                        "ERROR: Directory `{0}` could not be created.".format(
                            dirpath))

        log.info("Downloading: {0} ==> ".format(self.fullkey))
        # This step prints straight to stdout and does not log
        self.s3.download(self.fullkey, self.abspath, size=self.s3size)
        print("")
        log.debug("Download Completed: {0}".format(self.abspath))

    def upload(self):
        """
        Upload this operation's local file to its S3 key.
        """
        log = Logger('S3FileUpload')

        log.info("Uploading: {0} ==> s3://{1}/{2}".format(
            self.abspath, self.bucket, self.fullkey))
        # This step prints straight to stdout and does not log
        self.s3.upload(self.abspath, self.fullkey)
        print("")
        log.debug("Upload Completed: {0}".format(self.abspath))
Ejemplo n.º 11
0
    def __init__(self, key, fileobj, conf):
        """
        Work out the state of one file and the operation needed to sync it.

        :param key: The relative key/path of the file in question
        :param fileobj: dict that may contain 'src' (local file info) and/or
            'dst' (the remote listing entry, whose 'Size' is read)
        :param conf: the configuration dictionary; keys read: 'bucket',
            'delete', 'force', 'localroot', 'direction', 'keyprefix'
        """
        self.log = Logger('S3Ops')
        self.s3 = Transfer(conf['bucket'])
        self.key = key

        # Set some sensible defaults
        self.filestate = self.FileState.SAME
        self.op = self.FileOps.IGNORE

        self.delete = conf['delete']
        self.force = conf['force']
        self.localroot = conf['localroot']
        self.bucket = conf['bucket']
        self.direction = conf['direction']
        self.keyprefix = conf['keyprefix']
        self.s3size = 0

        # And the final paths we use:
        self.abspath = self.getAbsLocalPath()
        self.fullkey = self.getS3Key()

        # The remote size (if it exists) helps us figure out percent done
        if 'dst' in fileobj:
            self.s3size = fileobj['dst']['Size']

        # Figure out what we have
        if 'src' in fileobj and 'dst' not in fileobj:
            self.filestate = self.FileState.LOCALONLY

        if 'src' not in fileobj and 'dst' in fileobj:
            self.filestate = self.FileState.REMOTEONLY

        if 'src' in fileobj and 'dst' in fileobj:
            if s3issame(fileobj['src'], fileobj['dst']):
                self.filestate = self.FileState.SAME
            else:
                self.filestate = self.FileState.UPDATENEEDED

        # Note: the force branches compare self.filestate explicitly. The bare
        # constant self.FileState.SAME is a non-empty string (always truthy),
        # which made force upload/download a file that only exists on the
        # other side.

        # The Upload Case
        # ------------------------------
        if self.direction == self.Direction.UP:
            # Two cases for uploading the file: New file or different file
            if self.filestate == self.FileState.LOCALONLY or self.filestate == self.FileState.UPDATENEEDED:
                self.op = self.FileOps.UPLOAD

            # If we've requested a force, re-upload a file that already matches
            elif self.filestate == self.FileState.SAME and self.force:
                self.op = self.FileOps.UPLOAD

            # If the remote is there but the local is not and we're uploading then clean up the remote
            # this requires the delete flag be set
            elif self.filestate == self.FileState.REMOTEONLY and self.delete:
                self.op = self.FileOps.DELETE_REMOTE

        # The Download Case
        # ------------------------------
        elif self.direction == self.Direction.DOWN:
            if self.filestate == self.FileState.REMOTEONLY or self.filestate == self.FileState.UPDATENEEDED:
                self.op = self.FileOps.DOWNLOAD

            # If we've requested a force, re-download a file that already matches
            elif self.filestate == self.FileState.SAME and self.force:
                self.op = self.FileOps.DOWNLOAD

            # If the local is there but the remote is not and we're downloading then clean up the local
            # this requires the delete flag be set
            elif self.filestate == self.FileState.LOCALONLY and self.delete:
                self.op = self.FileOps.DELETE_LOCAL

        self.log.info(str(self))
Ejemplo n.º 12
0
    def __init__(
        self,
        tasks=tasks,
        edges=None,
        edges_exclude=None,
        pretrained=True,
        finetuned=False,
        reality=None,
        task_filter=[tasks.segment_semantic],
        freeze_list=None,
        lazy=False,
        initialize_from_transfer=True,
    ):
        """Build a transfer graph with one edge per ordered pair of tasks.

        Args:
            tasks: candidate tasks; duplicates and ``task_filter`` members are
                dropped (input order kept) and then each task's ``base`` is
                appended when not already present.
            edges: if given, only ``(src, dest)`` task pairs in it are built.
            edges_exclude: ``(src, dest)`` task pairs never built.
            pretrained, finetuned: forwarded to each ``Transfer``.
            reality: stored as ``self.reality`` (a fresh list when omitted).
            task_filter: tasks removed from ``tasks``.
            freeze_list: ``str((src.name, dest.name))`` keys of edges kept out
                of ``self.params`` (not trained).
            lazy: when True, models are not loaded here.
            initialize_from_transfer: when False, ``Transfer`` edges get
                ``path = None`` instead of their default weights path.
        """
        super().__init__()
        excluded = set(task_filter)
        # dict.fromkeys de-duplicates while keeping the caller's order; the
        # previous list(set(...)) made task/edge/param order vary between runs.
        self.tasks = [task for task in dict.fromkeys(tasks)
                      if task not in excluded]
        for base in [task.base for task in self.tasks if hasattr(task, "base")]:
            # Duplicated tasks would build (and load) the same edges twice.
            if base not in self.tasks:
                self.tasks.append(base)
        self.edge_list, self.edge_list_exclude = edges, edges_exclude
        self.pretrained, self.finetuned = pretrained, finetuned
        self.edges, self.adj, self.in_adj = [], defaultdict(list), defaultdict(
            list)
        # A fresh list per instance instead of one shared default list.
        self.edge_map = {}
        self.reality = [] if reality is None else reality
        self.initialize_from_transfer = initialize_from_transfer
        freeze_list = [] if freeze_list is None else freeze_list
        print('Creating graph with tasks:', self.tasks)
        self.params = {}

        # construct transfer graph
        for src_task, dest_task in itertools.product(self.tasks, self.tasks):
            key = (src_task, dest_task)
            if edges is not None and key not in edges: continue
            if edges_exclude is not None and key in edges_exclude: continue
            if src_task == dest_task: continue
            if isinstance(dest_task, RealityTask): continue
            if isinstance(src_task, RealityTask):
                if dest_task not in src_task.tasks: continue
                transfer = RealityTransfer(src_task, dest_task)
            else:
                transfer = Transfer(src_task,
                                    dest_task,
                                    pretrained=pretrained,
                                    finetuned=finetuned)
                transfer.name = get_transfer_name(transfer)
                if not self.initialize_from_transfer:
                    transfer.path = None
            if transfer.model_type is None:
                continue
            name_key = str((src_task.name, dest_task.name))
            self.edges += [transfer]
            self.adj[src_task.name] += [transfer]
            self.in_adj[dest_task.name] += [transfer]
            self.edge_map[name_key] = transfer
            if isinstance(transfer, nn.Module):
                if name_key not in freeze_list:
                    self.params[name_key] = transfer
                else:
                    print("Setting link: " + name_key + " not trainable.")
                try:
                    if not lazy: transfer.load_model()
                except Exception as e:
                    print(e)
                    IPython.embed()

        self.params = nn.ModuleDict(self.params)
Ejemplo n.º 13
0
    def __init__(
        self,
        tasks=tasks,
        realities=None,
        edges=None,
        edges_exclude=None,
        pretrained=True,
        finetuned=False,
        reality=None,
        task_filter=[tasks.segment_semantic, tasks.class_scene],
        freeze_list=None,
        lazy=False,
    ):
        """Build the task graph: four fixed edges plus generic reality/task edges.

        Args:
            tasks: Candidate tasks. Tasks in ``task_filter`` are dropped, and the
                ``base`` of every remaining task that has one is appended to
                ``self.tasks``.
            realities: Source nodes crossed with ``self.tasks`` to create the
                generic edges. ``None`` (the default) is treated as empty, so
                only the fixed edges are built.
            edges: Optional whitelist of ``(src_task, dest_task)`` pairs; other
                pairs are skipped when it is given.
            edges_exclude: Optional blacklist of ``(src_task, dest_task)`` pairs.
            pretrained: Forwarded to every ``Transfer`` created here.
            finetuned: Forwarded to every ``Transfer`` created here.
            reality: Stored as ``self.reality``; ``None`` means a fresh empty list.
            task_filter: Tasks removed from ``tasks``.
            freeze_list: Edge keys of the form ``str((src.name, dest.name))``
                whose transfers are NOT registered in ``self.params`` (their
                weights are still loaded). ``None`` means nothing is frozen.
            lazy: If true, ``load_model()`` is not called on any transfer.
        """
        super().__init__(tasks=[])
        self.tasks = list(set(tasks) - set(task_filter))
        self.tasks += [
            task.base for task in self.tasks if hasattr(task, "base")
        ]
        self.edge_list, self.edge_list_exclude = edges, edges_exclude
        self.pretrained, self.finetuned = pretrained, finetuned
        self.edges = []
        self.adj, self.in_adj = defaultdict(list), defaultdict(list)
        # A fresh list per instance: a shared `[]` default would leak between graphs.
        self.edge_map, self.reality = {}, ([] if reality is None else reality)
        freeze_list = () if freeze_list is None else freeze_list
        print('graph tasks!', self.tasks)
        self.params = nn.ModuleDict()
        self.realities = realities

        def _load(transfer, key):
            """Load ``transfer``'s weights unless ``lazy``; on failure, report and open IPython."""
            if lazy:
                return
            try:
                transfer.load_model()
            except Exception as e:  # not bare: let KeyboardInterrupt/SystemExit propagate
                print('Could not load model:', key)
                print(e)
                IPython.embed()

        # RGB -> Normal: registered in params/edge_map and appended to self.edges.
        key = str((TASKS.rgb.name, TASKS.normal.name))
        transfer = Transfer(TASKS.rgb,
                            TASKS.normal,
                            pretrained=pretrained,
                            finetuned=finetuned)
        transfer.name = get_transfer_name(transfer)
        self.params[key] = transfer
        self.edge_map[key] = transfer
        self.edges += [transfer]
        _load(transfer, key)

        # RGB -> Depth: registered in params/edge_map only.
        # NOTE(review): unlike RGB -> Normal this edge is not appended to
        # self.edges (nor adj/in_adj); confirm that is intended.
        key = str((TASKS.rgb.name, TASKS.depth_zbuffer.name))
        transfer = Transfer(TASKS.rgb,
                            TASKS.depth_zbuffer,
                            pretrained=pretrained,
                            finetuned=finetuned)
        transfer.name = get_transfer_name(transfer)
        self.params[key] = transfer
        self.edge_map[key] = transfer
        _load(transfer, key)

        # Composite (depth_zbuffer, FoV, normal) source -> normal, then -> depth_zbuffer.
        # Model type and weight path are looked up in `pretrained_transfers`;
        # these edges are registered in params/edge_map only.
        src_task = (TASKS.depth_zbuffer.name, TASKS.FoV.name,
                    TASKS.normal.name)
        for target_task in (TASKS.normal, TASKS.depth_zbuffer):
            transfer_name = str((src_task, target_task.name))
            model_type, path = pretrained_transfers[(src_task,
                                                     target_task.name)]
            transfer = Transfer(src_task,
                                target_task,
                                pretrained=pretrained,
                                model_type=model_type,
                                path=path,
                                checkpoint=False,
                                finetuned=finetuned,
                                name=f"{src_task}2{target_task.name}")
            transfer.name = transfer_name
            self.params[transfer_name] = transfer
            self.edge_map[transfer_name] = transfer
            _load(transfer, transfer_name)

        # Generic edges: every (source, task) pair that survives the filters.
        sources = self.realities if self.realities is not None else []
        for src_task, dest_task in itertools.product(sources, self.tasks):
            key = (src_task, dest_task)
            if edges is not None and key not in edges: continue
            if edges_exclude is not None and key in edges_exclude: continue
            if src_task == dest_task: continue
            if isinstance(dest_task, RealityTask): continue
            if isinstance(src_task, RealityTask):
                # Reality sources only connect to the tasks they provide.
                if dest_task not in src_task.tasks: continue
                transfer = RealityTransfer(src_task, dest_task)
            else:
                transfer = Transfer(src_task,
                                    dest_task,
                                    pretrained=pretrained,
                                    finetuned=finetuned)
                transfer.name = get_transfer_name(transfer)
            if transfer.model_type is None:
                continue
            edge_key = str((src_task.name, dest_task.name))
            self.edges += [transfer]
            self.adj[src_task.name] += [transfer]
            self.in_adj[dest_task.name] += [transfer]
            self.edge_map[edge_key] = transfer
            if isinstance(transfer, nn.Module):
                # Frozen edges stay out of self.params but are still loaded.
                if edge_key not in freeze_list:
                    self.params[edge_key] = transfer
                else:
                    print("freezing " + edge_key)
                _load(transfer, edge_key)