Example #1
 def test_pluck(self):
     from stuf import stuf
     stooges = [
         stuf(name='moe', age=40),
         stuf(name='larry', age=50),
         stuf(name='curly', age=60)
     ]
     self.assertEqual(
         self.qclass(*stooges).pluck('name').value(),
         ['moe', 'larry', 'curly'],
     )
     self.assertEqual(
         self.qclass(*stooges).pluck('name', 'age').value(),
         [('moe', 40), ('larry', 50), ('curly', 60)],
     )
     stooges = [['moe', 40], ['larry', 50], ['curly', 60]]
     self.assertEqual(
         self.qclass(*stooges).pluck(0).value(),
         ['moe', 'larry', 'curly'],
     )
     self.assertEqual(
         self.qclass(*stooges).pluck(1).value(),
         [40, 50, 60],
     )
     self.assertEqual(
         self.qclass(*stooges).pluck('place').value(), [],
     )
Example #2
 def test_pick(self):
     from stuf import stuf
     stooges = [
         stuf(name='moe', age=40),
         stuf(name='larry', age=50),
         stuf(name='curly', age=60)
     ]
     manq = self.qclass(*stooges).pick('name')
     self.assertTrue(manq.balanced)
     manq.sync()
     self.assertTrue(manq.balanced)
     self.assertEqual(manq.value(), ['moe', 'larry', 'curly'])
     self.assertFalse(manq.balanced)
     manq = self.qclass(*stooges).pick('name', 'age')
     self.assertTrue(manq.balanced)
     manq.sync()
     self.assertTrue(manq.balanced)
     self.assertEqual(
         manq.value(), [('moe', 40), ('larry', 50), ('curly', 60)],
     )
     self.assertFalse(manq.balanced)
     manq = self.qclass(*stooges).pick('place')
     self.assertFalse(manq.balanced)
     manq.sync()
     self.assertTrue(manq.balanced)
     self.assertEqual(manq.value(), [])
     self.assertTrue(manq.balanced)
Example #3
    def _setup_model_config(self, cfg):
        config = stuf({'name': cfg.model})
        config.num_epochs = cfg.num_epochs
        config.batch_size = cfg.batch_size
        # config.num_classes = cfg.data_config.dataset_A.num_classes
        config.frozen_layers = None
        config.base_learning_rate = cfg.learning_rate
        config.regularization = {'l2': 0.01}
        config.model_dir = os.path.join(cfg.user_config.user_model_dir,
                                        '_'.join(cfg.datasets))
        config.log_dir = os.path.join(config.model_dir, 'logs',
                                      cfg.user_config.experiment_start_time)

        if cfg.model == 'resnet_50_v2':
            # from tensorflow.keras.applications.resnet_v2 import preprocess_input
            config.preprocessing_module = 'tensorflow.keras.applications.resnet_v2'
            config.input_shape = (224, 224, 1)

        elif cfg.model == 'vgg16':
            # from tensorflow.keras.applications.vgg16 import preprocess_input
            config.preprocessing_module = 'tensorflow.keras.applications.vgg16'
            config.input_shape = (224, 224, 1)

        # config.preprocess_func = lambda x,y: (preprocess_input(x), y)
        return stuf({'model_config': config})
Example #4
 def test_factory(self):
     from stuf import stuf
     thing = self.mclass(
         [('a', 1), ('b', 2), ('c', 3)], [('a', 1), ('b', 2), ('c', 3)]
     ).worker(stuf).map().get()
     self.assertEqual(
         thing, [stuf(a=1, b=2, c=3), stuf(a=1, b=2, c=3)]
     )
Example #5
 def test_max(self):
     from stuf import stuf
     stooge = [
         stuf(name='moe', age=40),
         stuf(name='larry', age=50),
         stuf(name='curly', age=60),
     ]
     self.assertEqual(self.mclass(1, 2, 4).max().get(), 4)
     self.assertEqual(
         stuf(self.mclass(*stooge).worker(lambda x: x.age).max().get()),
         stuf(name='curly', age=60),
     )
Example #6
 def test_new_child(self):
     from stuf.core import stuf
     stuffed = self.stuf.new_child()
     self.assertEquals(
         stuffed.maps,
         [
             stuf(),
             stuf([('test1', 'test1')]),
             stuf(test2='test2'),
             stuf(test3=stuf(e=1))
         ],
     )
Example #7
 def test_max(self):
     from stuf import stuf
     stooges = [
         stuf(name='moe', age=40),
         stuf(name='larry', age=50),
         stuf(name='curly', age=60)
     ]
     self.assertEquals(
         stuf(self.qclass(*stooges).tap(lambda x: x.age).max().value()),
         stuf(name='curly', age=60),
     )
     self.assertEquals(self.qclass(1, 2, 4).max().value(), 4)
Example #8
 def test_xattrs(self):
     from stuf import stuf
     from blade.xfilter import xattrs
     stooge = [
         stuf(name='moe', age=40),
         stuf(name='larry', age=50),
         stuf(name='curly', age=60)
     ]
     self.assertEqual(list(xattrs(stooge, 'name')),
                      ['moe', 'larry', 'curly'])
     self.assertEqual(
         list(xattrs(stooge, 'name', 'age')),
         [('moe', 40), ('larry', 50), ('curly', 60)],
     )
     self.assertEqual(list(xattrs(stooge, 'place')), [])
Example #9
 def test_xattrs(self):
     from stuf import stuf
     from blade.xfilter import xattrs
     stooge = [
         stuf(name='moe', age=40),
         stuf(name='larry', age=50),
         stuf(name='curly', age=60)
     ]
     self.assertEqual(
         list(xattrs(stooge, 'name')), ['moe', 'larry', 'curly']
     )
     self.assertEqual(
         list(xattrs(stooge, 'name', 'age')),
         [('moe', 40), ('larry', 50), ('curly', 60)],
     )
     self.assertEqual(list(xattrs(stooge, 'place')), [])
Example #10
 def test_max(self):
     self._false_true_false(
         self.qclass(1, 2, 4).max(), self.assertEqual, 4,
     )
     from stuf import stuf
     stooges = [
         stuf(name='moe', age=40),
         stuf(name='larry', age=50),
         stuf(name='curly', age=60),
     ]
     manq = self.qclass(*stooges).tap(lambda x: x.age).max()
     self.assertFalse(manq.balanced)
     manq.sync()
     self.assertTrue(manq.balanced)
     self.assertEqual(stuf(manq.end()), stuf(name='curly', age=60))
     self.assertTrue(manq.balanced)
Example #11
 def get_parser(self, prog_name):
     parser = super(Drain, self).get_parser(prog_name)
     parser.formatter_class = argparse.ArgumentDefaultsHelpFormatter
     parser.add_argument("-c", "--config", type=self.get_config, default=stuf(), help="Where to find the settings for the Drain")
     parser.add_argument('--queue', '-q', default=CLIApp.prefix, help="Queue name")
     self.parser = parser
     return parser
Example #12
 def test_wrap(self):
     from stuf import stuf
     self.assertDictEqual(
         self.qclass(
             ('a', 1), ('b', 2), ('c', 3)
         ).reup().wrap(stuf).map().value(),
         stuf(a=1, b=2, c=3),
     )
Example #13
 def mock_ns_tools(self, ns, ns_name, tools):
     ns._original_tools = tools.copy()
     for name, cbl in tools.items():
         mock_name = "%s.%s" % (ns_name, name)
         cls_s = self.setdefault(ns_name, stuf())
         rmock = cls_s[name] = Mock(name=mock_name)
         tools[name] = partial(cbl, runner=rmock)
     return tools
Example #14
 def _setup_logger_config(self, cfg):
     config = stuf({
         'mlflow_tracking_dir':
         r'/media/data/jacob/Fossil_Project/experiments/mlflow',
         'program_log_file':
         os.path.join(cfg.model_config.log_dir, 'mlflow_log.txt')
     })
     return {'logger': config}
Example #15
 def test_factory(self):
     from stuf import stuf
     self.assertDictEqual(
         self.qclass(
             ('a', 1), ('b', 2), ('c', 3)
         ).reup().factory(stuf).map().end(),
         stuf(a=1, b=2, c=3),
     )
Example #16
 def to_proc_info(proc, **kw):
     proc.get_cpu_percent(0.25)  # sacrifice chicken
     raw = proc.as_dict()
     raw["time"] = time.time()
     raw.update(kw)
     out = stuf.stuf(raw)
     if out.cpu_percent is None or out.cpu_percent == 0.0:
         out.cpu_percent = proc.get_cpu_percent(0.5)
     return out
Example #17
 def catch_exc(self, name=None):
     status = stuf()
     try:
         yield status
     except Exception as e:
         if self.debug:
             import pdb; pdb.post_mortem(sys.exc_info()[2])
         self.log.exception('%s failed: %s\n %s\n', 
                            name, e, pp.pformat(dict(status)))
Example #18
 def to_proc_info(proc, **kw):
     proc.get_cpu_percent(.25)  # sacrifice chicken
     raw = proc.as_dict()
     raw['time'] = time.time()
     raw.update(kw)
     out = stuf.stuf(raw)
     if out.cpu_percent is None or out.cpu_percent == 0.0:
         out.cpu_percent = proc.get_cpu_percent(0.5)
     return out
Example #19
 def test_attributes(self):
     from stuf import stuf
     stooge = [
         stuf(name='moe', age=40),
         stuf(name='larry', age=50),
         stuf(name='curly', age=60)
     ]
     self.assertEqual(
         self.mclass(*stooge).attrs('name').get(),
         ['moe', 'larry', 'curly'],
     )
     self.assertEqual(
         self.mclass(*stooge).attrs('name', 'age').get(),
         [('moe', 40), ('larry', 50), ('curly', 60)],
     )
     self.assertEqual(
         self.mclass(*stooge).attrs('place').get(), [],
     )
Example #20
 def test_pick(self):
     from stuf import stuf
     stooges = [
         stuf(name='moe', age=40),
         stuf(name='larry', age=50),
         stuf(name='curly', age=60)
     ]
     self.assertEqual(
         self.qclass(*stooges).pick('name').end(),
         ['moe', 'larry', 'curly'],
     )
     self.assertEqual(
         self.qclass(*stooges).pick('name', 'age').end(),
         [('moe', 40), ('larry', 50), ('curly', 60)],
     )
     self.assertEqual(
         self.qclass(*stooges).pick('place').end(), [],
     )
Example #21
    def test_bulk_add_pkg_regen_error(self, getreg):
        with patch('cheeseprism.index.IndexManager.regenerate_leaf') as rl:
            rl.side_effect = ValueError('BAAAAAAD')
            from cheeseprism.index import bulk_add_pkgs

            idx = self.make_one()
            pkg = stuf(name='dummypackage', version='0.1',
                       filename=self.dummy.name)
            pkgs = pkg,
            assert bulk_add_pkgs(idx, pkgs)
Example #22
 def catch_exc(self, name=None):
     status = stuf()
     try:
         yield status
     except Exception as e:
         if self.debug:
             import pdb
             pdb.post_mortem(sys.exc_info()[2])
         self.log.exception('%s failed: %s\n %s\n', name, e,
                            pp.pformat(dict(status)))
Example #23
 def mock_ns_tools(self, ns, ns_name, tools):
     ns._original_tools = tools.copy()
     if tools:
         for name, cbl in tools.items():
             mock_name = "%s.%s" % (ns_name, name)
             cls_s = self.setdefault(ns_name, stuf())
             rmock = cls_s[name] = Mock(name=mock_name)
             tools[name] = partial(cbl, runner=rmock)
         ns.init_attrs(ns, ns._tools)
     return tools
Example #24
    def test_bulk_add_pkg_regen_error(self, getreg):
        with patch('cheeseprism.index.IndexManager.regenerate_leaf') as rl:
            rl.side_effect = ValueError('BAAAAAAD')
            from cheeseprism.index import bulk_add_pkgs

            idx = self.make_one()
            pkg = stuf(name='dummypackage',
                       version='0.1',
                       filename=self.dummy.name)
            pkgs = pkg,
            assert bulk_add_pkgs(idx, pkgs)
Example #25
 def parse(self, args=None, namespace=None):
     # import pdb;pdb.set_trace()
     cfg = self.parser.parse_args(args, namespace)
     if len(cfg.__dict__) > 0: cfg = cfg.__dict__
     cfg = stuf(dict(cfg))
     cfg['user_config'] = self._setup_user_config()
     cfg['data_config'] = self._setup_data_config(cfg)
     cfg.update(self._setup_model_config(cfg))
     cfg.update(self._setup_logger_config(cfg))
     # os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu)
     # print('os.environ["CUDA_VISIBLE_DEVICES"] : ',os.environ["CUDA_VISIBLE_DEVICES"])
     return cfg
Example #26
def accept_connections() -> None:
    """Accepts incoming connection requests to the server socket"""
    global server_socket
    global client_sockets

    while True:
        client_socket, address_info = server_socket.accept()
        thread = Thread(target=process_client_events, args=(client_socket, ))

        clients[client_socket] = stuf(thread=thread, user=None)

        thread.start()
Example #27
 def test_max(self):
     from stuf import stuf
     stooges = [
         stuf(name='moe', age=40),
         stuf(name='larry', age=50),
         stuf(name='curly', age=60)
     ]
     manq = self.qclass(*stooges).tap(lambda x: x.age).max()
     self.assertFalse(manq.balanced)
     manq.sync()
     self.assertTrue(manq.balanced)
     self.assertEquals(
         stuf(manq.value()), stuf(name='curly', age=60),
     )
     self.assertFalse(manq.balanced)
     manq = self.qclass(1, 2, 4).max()
     self.assertFalse(manq.balanced)
     manq.sync()
     self.assertTrue(manq.balanced)
     self.assertEquals(manq.value(), 4)
     self.assertFalse(manq.balanced)
Example #28
 def test_items(self):
     from stuf import stuf
     stooge = [
         stuf(name='moe', age=40),
         stuf(name='larry', age=50),
         stuf(name='curly', age=60)
     ]
     self.assertEqual(
         self.mclass(*stooge).items('name').get(),
         ['moe', 'larry', 'curly'],
     )
     self.assertEqual(
         self.mclass(*stooge).items('name', 'age').get(),
         [('moe', 40), ('larry', 50), ('curly', 60)],
     )
     stooge = [['moe', 40], ['larry', 50], ['curly', 60]]
     self.assertEqual(
         self.mclass(*stooge).items(0).get(), ['moe', 'larry', 'curly'],
     )
     self.assertEqual(self.mclass(*stooge).items(1).get(), [40, 50, 60])
     self.assertEqual(self.mclass(*stooge).items('place').get(), [])
Example #29
 def _setup_user_config(self):
     return stuf({
         'experiments_db':
         EXPERIMENTS_DB,
         'user_csv_dir':
         r'/media/data_cifs/jacob/Fossil_Project/data/csv_data',
         'user_tfrecord_dir':
         r'/media/data/jacob/Fossil_Project/data/tfrecord_data',
         'user_model_dir':
         r'/media/data_cifs/jacob/Fossil_Project/models',
         'experiment_start_time':
         datetime.datetime.now().strftime('%a-%m-%d-%Y_%H-%M-%S')
     })
Example #30
 def test_upload(self, wb):
     from cheeseprism.views import upload
     self.setup_event()
     context, request = self.base_cr
     request.method = 'POST'
     request.POST['content'] = FakeFS(path('dummypackage/dist/dummypackage-0.0dev.tar.gz'))
     with patch('cheeseprism.index.IndexManager.pkginfo_from_file',
                return_value=stuf(name='dummycode', version='0.0dev')) as pkif:
         res = upload(context, request)
         assert pkif.called
         assert 'PackageAdded' in self.event_results
         assert self.event_results['PackageAdded'][0].name == pkif.return_value.name
     assert res.headers == {'X-Swalow-Status': 'SUCCESS'}
Example #31
    def testStuf(self):
        try:
            from stuf import stuf
        except ImportError:
            return

        items = [('a', 1), ('c', 'a string'), ('b', 2), ('d', 5.0455)]
        d = dict(items)
        s = stuf(items)

        self.assertEqual(d, s)

        self.assertEqual(sh.superhash(d), sh.superhash(s))
Example #32
    def test_upload(self, wb):
        from cheeseprism.views import upload

        self.setup_event()
        context, request = self.base_cr
        request.method = "POST"
        request.POST["content"] = FakeFS(path("dummypackage/dist/dummypackage-0.0dev.tar.gz"))
        with patch(
            "cheeseprism.index.IndexManager.pkginfo_from_file", return_value=stuf(name="dummycode", version="0.0dev")
        ) as pkif:
            res = upload(context, request)
            assert pkif.called
            assert "PackageAdded" in self.event_results
            assert self.event_results["PackageAdded"][0].name == pkif.return_value.name
        assert res.headers == {"X-Swalow-Status": "SUCCESS"}
Example #33
    def test_bulk_add_pkg(self, getreg):
        from cheeseprism.index import bulk_add_pkgs
        self.im = self.make_one()
        pkg = stuf(name='dummypackage', version='0.1',
                   filename=self.dummy.name)
        pkgs = pkg,
        index = Mock(name='index')
        index.path = self.im.path
        leaves, archs = bulk_add_pkgs(index, pkgs)

        assert len(archs) == 1
        assert len(leaves) == 1
        assert 'dummypackage' in leaves
        assert archs[0].basename() == u'dummypackage-0.0dev.tar.gz'
        assert index.regenerate_leaf.called
Example #34
    def test_bulk_add_pkg(self, getreg):
        from cheeseprism.index import bulk_add_pkgs
        self.im = self.make_one()
        pkg = stuf(name='dummypackage',
                   version='0.1',
                   filename=self.dummy.name)
        pkgs = pkg,
        index = Mock(name='index')
        index.path = self.im.path
        leaves, archs = bulk_add_pkgs(index, pkgs)

        assert len(archs) == 1
        assert len(leaves) == 1
        assert 'dummypackage' in leaves
        assert archs[0].basename() == u'dummypackage-0.0dev.tar.gz'
        assert index.regenerate_leaf.called
Example #35
 def test_upload(self, wb):
     from cheeseprism.views import upload
     self.setup_event()
     context, request = self.base_cr
     request.method = 'POST'
     request.POST['content'] = FakeFS(
         path('dummypackage/dist/dummypackage-0.0dev.tar.gz'))
     with patch('cheeseprism.index.IndexManager.pkginfo_from_file',
                return_value=stuf(name='dummycode',
                                  version='0.0dev')) as pkif:
         res = upload(context, request)
         assert pkif.called
         assert 'PackageAdded' in self.event_results
         assert self.event_results['PackageAdded'][
             0].name == pkif.return_value.name
     assert res.headers == {'X-Swalow-Status': 'SUCCESS'}
Example #36
    def test_notify_packages_added(self, getreg):
        from cheeseprism.index import notify_packages_added
        pkg = stuf(name='dummypackage', version='0.1',
                   filename=self.dummy.name)
        pkgs = pkg,
        index = Mock(name='index')
        index.path = self.im.path
        reg = getreg.return_value = Mock(name='registry')
        out = list(notify_packages_added(index, pkgs))

        assert len(out) == 1
        assert getreg.called
        assert reg.notify.called
        (event,), _ = reg.notify.call_args
        assert event.im is index
        assert event.version == '0.1'
        assert event.name == 'dummypackage'
Example #37
    def test_notify_packages_added(self, getreg):
        from cheeseprism.index import notify_packages_added
        pkg = stuf(name='dummypackage',
                   version='0.1',
                   filename=self.dummy.name)
        pkgs = pkg,
        index = Mock(name='index')
        index.path = self.im.path
        reg = getreg.return_value = Mock(name='registry')
        out = list(notify_packages_added(index, pkgs))

        assert len(out) == 1
        assert getreg.called
        assert reg.notify.called
        (event, ), _ = reg.notify.call_args
        assert event.im is index
        assert event.version == '0.1'
        assert event.name == 'dummypackage'
Example #38
class MetricFormatter(object):
    op = stuf.stuf(mem_perc=attrgetter('memory_percent'),
                   cpu_perc=attrgetter('cpu_percent'),
                   cnxs=attrgetter('connections'),
                   time=attrgetter('time'))

    def amf(self, values, attr, func=None):
        """
        Attribute mean filtered

        For a list of objects, filter out any object whose attr value
        is None, and return the mean if any values are left.
        """
        getter = self.op[attr]
        res = (getter(x) for x in values if not getter(x) is None)
        if func:
            res = (func(x) for x in res if not func(x) is None)
        res = tuple(res)
        logger.debug("%s:%s", attr, res)
        res = res and stats.mean(res) or 0.0
        return res

    def out(self,
            prefix,
            appuser,
            proclist,
            tmplt="{1}.{2}.{3} {4} {0}",
            op=op):
        """
        returns an iterable of metrics for a user and list of procs
        """
        proclist = [x for x in proclist]
        amf = partial(self.amf, proclist)
        fmt = partial(tmplt.format, amf('time'), prefix, appuser)
        yield fmt('procs_running', len(tuple(proclist)))
        cpup = amf('cpu_perc')
        if cpup > 0.0:
            yield fmt('cpu_percent', cpup)
        yield fmt('memory_percent', amf('mem_perc'))
        yield fmt('connections', amf('cnxs', func=len))
Example #39
 def __init__(self, config=None):
     super(Drain, self).__init__(config)
     self._ctor_cache = {}
     self.jobs = {}
     self.result_queue = self.async.queue()
     self.local_cache = stuf()
Example #40
class JobRun(object):
    """
    Represents a single execution of a job
    """
    job_status = stuf(success=False)

    def __init__(self, job, queue, parent):
        self.queue = queue
        self.job = job
        self.parent = parent

    @reify
    def log(self):
        return self.parent.log

    def __call__(self):
        with self.girded_loins(self.job):
            return self.process_job(self.job)

    def prep_callable(self, callable_, job):
        #@@ should this be a hook?!
        if inspect.isclass(callable_):
            return callable_(job.id, self)

        if inspect.isfunction(callable_):
            callable_.uid = job.id
            return callable_

        if inspect.ismethod(callable_):
            self.log.warn('instance method may not be annotated with uid: %s' %
                          job.id)
            return callable_

    def process_job(self, job):
        """
        resolve the job on the path, and run it

        #@@ add hooks 
        """
        candidate = self.parent.resolve(job.path)
        job_callable = self.prep_callable(candidate, job)

        kwargs = job.args.get('kwargs', {}) or {}
        args = job.args.get('args', []) or []

        if isinstance(args, basestring):
            raise TypeError('Job arguments must be a non-string iterable')

        result = job_callable(*args, **kwargs)
        return result

    @contextmanager
    def girded_loins(self, job):
        """
        All the error catching and status reporting
        """
        job.update_state('STARTED')
        job.start = start = time.time()
        try:
            yield
            job.success = True
        except ImportError:
            job.update_state('NOTFOUND')
        except Exception as e:
            if self.parent.debug:
                import pdb
                pdb.post_mortem(sys.exc_info()[2])
            tb = traceback.format_exc(e).encode('zlib')
            job.tb = base64.encodestring(tb)
            job.exc = repr(e)
            job.update_state('FAILED')
            self.log.exception("Job failure by exception:\n%s",
                               pp.pformat(job))
        finally:
Example #41
    PARAMS['transfer_to_Fossil'] = False  #True

    PARAMS['stage-2'] = {}  #'lr':1e-5,
    # 'color_mode':'grayscale',
    # 'num_channels':3,
    # 'BATCH_SIZE':8,
    # 'buffer_size':200,
    # 'num_epochs':150,
    # 'dataset_name':'Fossil',#'PNAS',
    # 'threshold':2,
    # 'frozen_layers':None,
    # 'model_name':'vgg16',#'resnet_50_v2',
    # 'splits':{'train':0.5,'validation':0.5},
    # 'seed':45}

    PARAMS = stuf(PARAMS)
    PARAMS['exclude_classes'] = [
        'notcataloged', 'notcatalogued', 'II. IDs, families uncertain',
        'Unidentified'
    ]
    PARAMS['regularization'] = {'l1': 3e-4}
    PARAMS['METRICS'] = ['accuracy', 'precision', 'recall']
    PARAMS['target_size'] = (512, 512)  #(128,128)#(256,256)# #(768,768)#
    PARAMS['augmentations'] = [{'flip': 1.0}]

    PARAMS = restore_or_initialize_experiment(PARAMS,
                                              restore_last=False,
                                              prefix='log_dir__',
                                              verbose=2)
    # pprint(PARAMS)
    neptune.init(project_qualified_name=PARAMS['neptune_project_name'])
Example #42
 def test_wrap(self):
     from stuf import stuf
     thing = self.qclass(
             [('a', 1), ('b', 2), ('c', 3)]
         ).reup().wrap(stuf).map().shift().value()
     self.assertDictEqual(thing, stuf(a=1, b=2, c=3))
Example #43
 def test_getter(self):
     from stuf import stuf
     from stuf.deep import getter
     foo = stuf(a=1)
     self.assertEqual(getter(foo, 'a'), 1)
     self.assertEqual(getter(stuf, 'items'), stuf.items)
Example #44
import os
from fabric import network as net
from stuf import stuf


ctx = context = stuf()


def setup(ctx=ctx):
    pass

def teardown():
    net.disconnect_all()


Example #45
 def __init__(self, config=None):
     super(Drain, self).__init__(config)
     self._ctor_cache = {}
     self.jobs = {}
     self.result_queue = self.async.queue()
     self.local_cache = stuf()
Example #46
 def get_config(self, candidate):
     # turn it into a dict
     return stuf(candidate)
Example #47
    'experiment_dir': '/media/data/jacob/sandbox_logs',
    'experiment_start_time': arrow.utcnow().format('YYYY-MM-DD_HH-mm-ss'),
    'optimizer': 'Adam',
    'loss': 'categorical_crossentropy',
    'lr': 1e-5,
    'color_mode': 'grayscale',
    'num_channels': 3,
    'BATCH_SIZE': 48,
    'num_epochs': 150,
    'dataset_name': 'Leaves',
    'frozen_layers': None,
    'model_name': 'vgg16',
    'augment_train': True,
    'aug_prob': 0.5,
    'splits': stuf({
        'train': 0.5,
        'validation': 0.5
    })
}

PARAMS = stuf(PARAMS)

PARAMS['experiment_name'] = '_'.join(
    [PARAMS['dataset_name'], PARAMS['model_name']])
PARAMS['regularization'] = {'l1': 1e-4}
PARAMS['METRICS'] = 'accuracy',
PARAMS['target_size'] = (224, 224)

PARAMS['log_dir'] = os.path.join(PARAMS['experiment_dir'],
                                 PARAMS['experiment_name'], 'log_dir',
                                 PARAMS['loss'],
                                 PARAMS['experiment_start_time'])
Example #48
 def test_getcls(self):
     from stuf import stuf
     from stuf.deep import getcls
     foo = stuf(a=1)
     self.assertEqual(getcls(foo), stuf)
Example #49
 def test_pkif(p, moe):
     assert p.basename() == 'mastertest-0.0.tar.gz'
     return stuf(name='mastertest', version='0.0-master')
Example #50
 def test_factory(self):
     from stuf import stuf
     thing = self.qclass(
         ('a', 1), ('b', 2), ('c', 3)
     ).reup().factory(stuf).map().shift().end()
     self.assertDictEqual(thing, stuf(a=1, b=2, c=3), thing)
Example #51
import collections
import contextlib
import copy
import functools
import mido
import os
import pickle
import stuf
import threading
import time
import uuid

controllers = collections.defaultdict(threading.Event)
inverse_controllers = collections.defaultdict(threading.Event)
notes = collections.defaultdict(threading.Event)

state = stuf.stuf({"inp": None, "out": None, "threads": [], "handlers": {}, "tempo": 120, "tsig": 4, "loops": [], "metronome": None})

# Utilities

def make_list(x):
    if isinstance(x, list):
        return x
    elif x is None:
        return []
    return [x]

def program(channel, program):
    state.out.send(mido.Message("program_change", channel=channel, program=program))

def wait_measures(n):
    time.sleep((60 / state.tempo) * state.tsig * n)
Example #52
# ## This looks like a true solution

# Using stuf on a dict allows it to work like class attributes.
# So test['red'] becomes test.red
# 
# easier? I don't know, but I like it.

# In[6]:

from stuf import stuf

#As dict
test = {'red':1}

print("As a dict object red = {}".format(test['red']))

#as Stuf
test1 = stuf(test)

print("As a stuf object red = {}".format(test1.red))

test1.lookout = 5

print("Create a new entry lookout = ",end="")
print(test1.lookout)

test1.lookout = 4

print("Modify the entry lookout = ",end="")
print(test1.lookout)
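
# A short follow-up sketch (assuming the same `test1` object created above):
# a stuf is still a dict underneath, so item access, membership tests, and
# conversion back to a plain dict keep working alongside attribute access.

print(test1['lookout'])   # item-style access -> 4
print('red' in test1)     # membership test, just like a plain dict -> True
print(dict(test1))        # converts back to an ordinary dict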


Example #53
import os
from fabric import network as net
from stuf import stuf

ctx = context = stuf()


def setup(ctx=ctx):
    pass


def teardown():
    net.disconnect_all()
Example #54
    num_epochs = PARAMS['num_epochs']
    shuffle=True
    initial_epoch=0

    src_db = pyleaves.DATABASE_PATH
    datasets = {
                'PNAS': pnas_dataset.PNASDataset(src_db=src_db),
                'Leaves': leaves_dataset.LeavesDataset(src_db=src_db),
                'Fossil': fossil_dataset.FossilDataset(src_db=src_db)
                }
    data = datasets[PARAMS['dataset_name']]
    data_config = stuf(threshold=0,
                       data_splits_meta={
                                         'train':PARAMS['train_size'],
                                         'val':PARAMS['val_size'],
                                         'test':PARAMS['test_size']
                                        }
                       )

    preprocess_input = get_preprocessing_func(PARAMS['model_name'])
    preprocess_input(tf.zeros([4, 32, 32, 3]), tf.zeros([4, 32]))
    load_example = partial(_load_example, img_size=PARAMS['image_size'], num_classes=data.num_classes)



    # class ConfusionMatrixCallback(Callback):
    #
    #     def __init__(self, log_dir, val_imgs, val_labels, classes, freq=1, seed=None):
    #         self.file_writer = tf.contrib.summary.create_file_writer(log_dir)
    #         self.log_dir = log_dir
Example #55
def test_create():
    from doulado import script
    args = stuf.stuf()
    assert script.create(args)
Example #56
 def test_deleter(self):
     from stuf import stuf
     from stuf.deep import deleter
     foo = stuf(a=1)
     self.assertEqual(deleter(foo, 'a'), None)
Example #57
def create_experiment__A_train_val_test(src_db=pyleaves.DATABASE_PATH,
                                        include_runs=['all']):
    start = time.time()
    experiment_view = select_by_col(table=tables['runs'],
                                    column='experiment_type',
                                    value='A_train_val_test')
    datasets = {
        'PNAS': pnas_dataset.PNASDataset(src_db=src_db),
        'Leaves': leaves_dataset.LeavesDataset(src_db=src_db),
        'Fossil': fossil_dataset.FossilDataset(src_db=src_db)
    }

    logger = stuf({})
    for i, run in experiment_view.iterrows():
        run_config = stuf(run.to_dict())

        if include_runs != ['all'] and (run_config['run_id']
                                        not in include_runs):
            print(f"SKIPPING run with run_id={run_config['run_id']}")
            continue

        print('BEGINNING ', run)
        config = BaseConfig().parse(namespace=run_config)
        name = config.dataset_A
        logger[name] = stuf({})

        tfrecords_table = TFRecordsTable(
            db_path=config.user_config.experiments_db)

        data_config_A = config.data_config.dataset_A
        num_shards = data_config_A.num_shards
        resolution = data_config_A.resolution

        csv_logger = CSVLogger(config, csv_dir=data_config_A.csv_dir)

        data = datasets[name]
        encoder = base_dataset.LabelEncoder(data.data.family)
        ensure_dir_exists(data_config_A.csv_dir)
        num_classes = data_config_A.num_classes = encoder.num_classes

        processed = base_dataset.preprocess_data(data, encoder, data_config_A)
        for subset in processed.keys():
            Item = partial(
                TFRecordItem, **{
                    'run_id': run_config['run_id'],
                    'experiment_type': run_config['experiment_type'],
                    'file_group': subset,
                    'dataset_stage': 'dataset_A',
                    'dataset_name': name,
                    'resolution': resolution,
                    'num_channels': 3,
                    'num_classes': num_classes,
                    'num_shards': num_shards
                })

            random.shuffle(processed[subset])

            x, y = [list(i) for i in unzip(processed[subset])]
            save_csv_data(x, y, filepath=data_config_A.csv_data[subset])
            encoder.save_config(data_config_A.label_dir)

            class_counts = base_dataset.calculate_class_counts(y_data=y)
            class_weights = base_dataset.calculate_class_weights(y_data=y)

            class_counts_fpath = csv_logger.log_dict(
                class_counts, filepath=f'{subset}_class_counts.csv')
            class_weights_fpath = csv_logger.log_dict(
                class_weights, filepath=f'{subset}_class_weights.csv')

            print(
                f'logged class counts and weights to {class_counts_fpath} and {class_weights_fpath}, respectively'
            )

            logger[name][subset] = stuf({'start': time.time()})
            file_log = save_tfrecords(data=processed[subset],
                                      output_dir=data_config_A.tfrecord_dir,
                                      file_prefix=subset,
                                      target_size=(resolution, resolution),
                                      num_channels=3,
                                      num_classes=encoder.num_classes,
                                      num_shards=num_shards,
                                      TFRecordItem_factory=Item,
                                      tfrecords_table=tfrecords_table,
                                      verbose=True)

            logger[name][subset].end = time.time()
            logger[name][subset].total = (
                logger[name][subset].end - logger[name][subset].start)
            print(
                name, subset,
                f'took {logger[name][subset].total:.2f} sec to collect/convert to TFRecords'
            )
        print(f'Full experiment took a total of {time.time()-start}')
Example #58
def main(**kwargs):

    import sys

    for k, v in kwargs.items():
        sys.argv += [k, v]

    from pprint import pprint
    import argparse
    import datetime
    import json
    import os


    parser = argparse.ArgumentParser()
    parser.add_argument('--neptune_project_name', default='jacobarose/sandbox', type=str, help='Neptune.ai project name to log under')
    parser.add_argument('--experiment_name', default='pnas_minimal_example', type=str, help='Neptune.ai experiment name to log under')
    parser.add_argument('--config_path', default=r'/home/jacob/projects/pyleaves/pyleaves/configs/example_configs/pnas_resnet_config.json', type=str, help='JSON config file')
    parser.add_argument('-gpu', '--gpu_id', default='1', type=str, help='integer number of gpu to train on', dest='gpu_id')
    parser.add_argument('-tags', '--add-tags', default=[], type=str, nargs='*', help='Add arbitrary list of tags to apply to this run in neptune', dest='tags')
    parser.add_argument('-f', default=None)
    args = parser.parse_args()

    with open(args.config_path, 'r') as config_file:
        PARAMS = json.load(config_file)

    # print(gpu)
    # os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    pprint(PARAMS)
    import tensorflow as tf
    import neptune
    # tf.debugging.set_log_device_placement(True)
    print(tf.__version__)





    import arrow
    import numpy as np
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
    import io
    from stuf import stuf
    from more_itertools import unzip
    from functools import partial
    # import tensorflow as tf
    # tf.compat.v1.enable_eager_execution()
    AUTOTUNE = tf.data.experimental.AUTOTUNE

    from pyleaves.leavesdb.tf_utils.tf_utils import set_random_seed, reset_keras_session
    import pyleaves
    from pyleaves.utils.img_utils import random_pad_image
    from pyleaves.utils.utils import ensure_dir_exists
    from pyleaves.datasets import leaves_dataset, fossil_dataset, pnas_dataset, base_dataset
    from pyleaves.models.vgg16 import VGG16, VGG16GrayScale
    from pyleaves.models import resnet, vgg16
    from tensorflow.compat.v1.keras.callbacks import Callback, ModelCheckpoint, TensorBoard, LearningRateScheduler, EarlyStopping
    from tensorflow.keras import metrics
    from tensorflow.keras.preprocessing.image import load_img, img_to_array
    from tensorflow.keras import layers
    from tensorflow.keras import backend as K
    import tensorflow_datasets as tfds
    import neptune_tensorboard as neptune_tb

    seed = 346
    # set_random_seed(seed)
    # reset_keras_session()
    def get_preprocessing_func(model_name):
        if model_name.startswith('resnet'):
            from tensorflow.keras.applications.resnet_v2 import preprocess_input
        elif model_name == 'vgg16':
            from tensorflow.keras.applications.vgg16 import preprocess_input
        elif model_name=='shallow':
            def preprocess_input(x):
                return x/255.0 # ((x/255.0)-0.5)*2.0

        return preprocess_input #lambda x,y: (preprocess_input(x),y)

    def _load_img(image_path):#, img_size=(224,224)):
        img = tf.io.read_file(image_path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.convert_image_dtype(img, tf.float32)
        return img
        # return tf.compat.v1.image.resize_image_with_pad(img, *img_size)

    def _encode_label(label, num_classes=19):
        label = tf.cast(label, tf.int32)
        label = tf.one_hot(label, depth=num_classes)
        return label

    def _load_example(image_path, label, num_classes=19):
        img = _load_img(image_path)
        one_hot_label = _encode_label(label, num_classes=num_classes)
        return img, one_hot_label

    def _load_uint8_example(image_path, label, num_classes=19):
        img = tf.image.convert_image_dtype(_load_img(image_path)*255.0, dtype=tf.uint8)
        one_hot_label = _encode_label(label, num_classes=num_classes)
        return img, one_hot_label

    def rgb2gray_3channel(img, label):
        '''
        Convert rgb image to grayscale, but keep num_channels=3
        '''
        img = tf.image.rgb_to_grayscale(img)
        img = tf.image.grayscale_to_rgb(img)
        return img, label

    def rgb2gray_1channel(img, label):
        '''
        Convert rgb image to grayscale, num_channels from 3 to 1
        '''
        img = tf.image.rgb_to_grayscale(img)
        return img, label

    def log_data(logs):
        for k, v in logs.items():
            neptune.log_metric(k, v)

    neptune_logger = tf.keras.callbacks.LambdaCallback(on_epoch_end=lambda epoch, logs: log_data(logs))

    def focal_loss(gamma=2.0, alpha=4.0):

        gamma = float(gamma)
        alpha = float(alpha)

        def focal_loss_fixed(y_true, y_pred):
            """Focal loss for multi-classification
            FL(p_t)=-alpha(1-p_t)^{gamma}ln(p_t)
            Notice: y_pred is probability after softmax
            gradient is d(Fl)/d(p_t) not d(Fl)/d(x) as described in paper
            d(Fl)/d(p_t) * [p_t(1-p_t)] = d(Fl)/d(x)
            Focal Loss for Dense Object Detection
            https://arxiv.org/abs/1708.02002

            Arguments:
                y_true {tensor} -- ground truth labels, shape of [batch_size, num_cls]
                y_pred {tensor} -- model's output, shape of [batch_size, num_cls]

            Keyword Arguments:
                gamma {float} -- (default: {2.0})
                alpha {float} -- (default: {4.0})

            Returns:
                [tensor] -- loss.
            """
            epsilon = 1.e-9
            y_true = tf.convert_to_tensor(y_true, tf.float32)
            y_pred = tf.convert_to_tensor(y_pred, tf.float32)

            model_out = tf.add(y_pred, epsilon)
            ce = tf.multiply(y_true, -tf.log(model_out))
            weight = tf.multiply(y_true, tf.pow(tf.subtract(1., model_out), gamma))
            fl = tf.multiply(alpha, tf.multiply(weight, ce))
            reduced_fl = tf.reduce_max(fl, axis=1)
            return tf.reduce_mean(reduced_fl)
        return focal_loss_fixed


    def per_class_accuracy(y_true, y_pred):
        return tf.metrics.mean_per_class_accuracy(y_true, y_pred, num_classes=PARAMS['num_classes'])

    def build_model(model_params,
                    optimizer,
                    loss,
                    METRICS):

        if model_params['name']=='vgg16':
            model_builder = vgg16.VGG16GrayScale(model_params)
        elif model_params['name'].startswith('resnet'):
            model_builder = resnet.ResNet(model_params)

        base = model_builder.build_base()
        model = model_builder.build_head(base)

        model.compile(optimizer=optimizer,
                      loss=loss,
                      metrics=METRICS)

        return model

    def build_shallow(input_shape=(224,224,3),
                      num_classes=10,
                      optimizer=None,
                      loss=None,
                      METRICS=None):

        model = tf.keras.models.Sequential()
        model.add(layers.Conv2D(64, (7, 7), activation='relu', input_shape=input_shape, kernel_initializer=tf.initializers.GlorotNormal()))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(64, (7, 7), activation='relu', kernel_initializer=tf.initializers.GlorotNormal()))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(64, (7, 7), activation='relu', kernel_initializer=tf.initializers.GlorotNormal()))
        model.add(layers.Flatten())
        model.add(layers.Dense(64*2, activation='relu', kernel_initializer=tf.initializers.GlorotNormal()))
        model.add(layers.Dense(num_classes,activation='softmax', kernel_initializer=tf.initializers.GlorotNormal()))

        model.compile(optimizer=optimizer,
                      loss=loss,
                      metrics=METRICS)

        return model


    class ImageLogger:
        '''Tensorflow 2.0 version'''
        def __init__(self, log_dir: str, max_images: int, name: str):
            self.file_writer = tf.summary.create_file_writer(log_dir)
            self.log_dir = log_dir
            self.max_images = max_images
            self.name = name
            self._counter = tf.Variable(0, dtype=tf.int64)

            self.filepaths = []

        def add_log(self, img, counter=None, name=None):
            '''
            Intention is to generalize this to an abstract class for logging to any experiment management platform (e.g. neptune, mlflow, etc)

            Currently takes a filepath pointing to an image file and logs to current neptune experiment.
            '''

            # scaled_images = (img - tf.math.reduce_min(img))/(tf.math.reduce_max(img) - tf.math.reduce_min(img))
            # keep = 0
            # scaled_images = tf.image.convert_image_dtype(tf.squeeze(scaled_images[keep,:,:,:]), dtype=tf.uint8)
            # scaled_images = tf.expand_dims(scaled_images, 0)
            # tf.summary.image(name=self.name, data=scaled_images, step=self._counter, max_outputs=self.max_images)


            scaled_img = (img - np.min(img))/(np.max(img) - np.min(img)) * 255.0
            scaled_img = scaled_img.astype(np.uint32)

            neptune.log_image(log_name= name or self.name,
                              x=counter,
                              y=scaled_img)
            return scaled_img

        def __call__(self, images, labels):

            with self.file_writer.as_default():
                scaled_images = (images - tf.math.reduce_min(images))/(tf.math.reduce_max(images) - tf.math.reduce_min(images))
                keep = 0

                scaled_images = tf.image.convert_image_dtype(tf.squeeze(scaled_images[keep,:,:,:]), dtype=tf.uint8)
                scaled_images = tf.expand_dims(scaled_images, 0)

                labels = tf.argmax(labels[[keep], :],axis=1)
                tf.summary.image(name=self.name, data=scaled_images, step=self._counter, max_outputs=self.max_images)

                filepath = os.path.join(self.log_dir,'sample_images',f'{self.name}-{self._counter}.jpg')

                scaled_images = tf.image.encode_jpeg(tf.squeeze(scaled_images))
                tf.io.write_file(filename=tf.constant(filepath),
                                 contents=scaled_images)

            # self.add_log(scaled_images)
            self._counter.assign_add(1)
            return images, labels

    def _cond_apply(x, y, func, prob):
        """Conditionally apply func to x and y with probability prob.

        Parameters
        ----------
        x : type
            Input to conditionally pass through func
        y : type
            Label
        func : type
            Function to conditionally be applied to x and y
        prob : type
            Probability of applying function, within range [0.0,1.0]

        Returns
        -------
        x, y
        """
        return tf.cond((tf.random.uniform([], 0, 1) >= (1.0 - prob)), lambda: func(x,y), lambda: (x,y))


    class ImageAugmentor:
        """Short summary.

        Parameters
        ----------
        augmentations : dict
            Maps a sequence of named augmentations to a scalar probability,
             according to which they'll be conditionally applied in order.
        resize_w_pad : tuple, default=None
            Description of parameter `resize_w_pad`.
        random_crop :  tuple, default=None
            Description of parameter `random_crop`.
        random_jitter : dict
            First applies resize_w_pad, then random_crop. If user desires only 1 of these, set this to None.
            Should be a dict with 2 keys:
                'resize':(height, width)
                'crop_size':(crop_height,crop_width, channels)

        Only 1 of these 3 kwargs should be provided to any given augmentor:
        {'resize_w_pad', 'random_crop', 'random_jitter'}
        Example values for each:
            resize_w_pad=(224,224)
            random_crop=(224,224,3)
            random_jitter={'resize':(338,338),
                           'crop_size':(224,224, 3)}



        seed : int, default=None
            Random seed to apply to all augmentations

        Examples
        -------
        Examples should be written in doctest format, and
        should illustrate how to use the function/class.
        >>>

        Attributes
        ----------
        augmentations

        """

        def __init__(self,
                     name='',
                     augmentations={'rotate':1.0,
                                    'flip':1.0,
                                    'color':1.0,
                                    'rgb2gray_3channel':1.0},
                     resize_w_pad=None,
                     random_crop=None,
                     random_jitter={'resize':(338,338),
                                    'crop_size':(224,224,3)},
                     log_dir=None,
                     seed=None):

            self.name = name
            self.augmentations = augmentations
            self.seed = seed

            if resize_w_pad:
                self.target_h = resize_w_pad[0]
                self.target_w = resize_w_pad[1]
                # self.resize = self.resize_w_pad
            elif random_crop:
                self.crop_size = random_crop
                self.target_h = self.crop_size[0]
                self.target_w = self.crop_size[1]
                # self.resize = self.random_crop
            elif random_jitter:
                # self.target_h = tf.random.uniform([], random_jitter['crop_size'][0], random_jitter['resize'][0], dtype=tf.int32, seed=self.seed)
                # self.target_w = tf.random.uniform([], random_jitter['crop_size'][1], random_jitter['resize'][1], dtype=tf.int32, seed=self.seed)
                self.crop_size = random_jitter['crop_size']
                # self.resize = self.random_jitter
                self.target_h = random_jitter['crop_size'][0]
                self.target_w = random_jitter['crop_size'][1]
            self.resize = self.resize_w_pad



            self.maps = {'rotate':self.rotate,
                          'flip':self.flip,
                          'color':self.color,
                          'rgb2gray_3channel':self.rgb2gray_3channel,
                          'rgb2gray_1channel':self.rgb2gray_1channel}

            self.log_dir = log_dir

        def rotate(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
            """Rotation augmentation

            Args:
                x,     tf.Tensor: Image
                label, tf.Tensor: arbitrary tensor, passes through unchanged

            Returns:
                Augmented image, label
            """
            # Rotate 0, 90, 180, 270 degrees
            return tf.image.rot90(x, tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32,seed=self.seed)), label

        def flip(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
            """Flip augmentation

            Args:
                x,     tf.Tensor: Image to flip
                label, tf.Tensor: arbitrary tensor, passes through unchanged
            Returns:
                Augmented image, label
            """
            x = tf.image.random_flip_left_right(x, seed=self.seed)
            x = tf.image.random_flip_up_down(x, seed=self.seed)

            return x, label

        def color(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
            """Color augmentation

            Args:
                x,     tf.Tensor: Image
                label, tf.Tensor: arbitrary tensor, passes through unchanged

            Returns:
                Augmented image, label
            """
            x = tf.image.random_hue(x, 0.08, seed=self.seed)
            x = tf.image.random_saturation(x, 0.6, 1.6, seed=self.seed)
            x = tf.image.random_brightness(x, 0.05, seed=self.seed)
            x = tf.image.random_contrast(x, 0.7, 1.3, seed=self.seed)
            return x, label

        def rgb2gray_3channel(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
            """Convert RGB image -> grayscale image, maintain number of channels = 3

            Args:
                x,     tf.Tensor: Image
                label, tf.Tensor: arbitrary tensor, passes through unchanged

            Returns:
                Augmented image, label
            """
            x = tf.image.rgb_to_grayscale(x)
            x = tf.image.grayscale_to_rgb(x)
            return x, label

        def rgb2gray_1channel(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
            """Convert RGB image -> grayscale image, reduce number of channels from 3 -> 1

            Args:
                x,     tf.Tensor: Image
                label, tf.Tensor: arbitrary tensor, passes through unchanged

            Returns:
                Augmented image, label
            """
            x = tf.image.rgb_to_grayscale(x)
            return x, label

        def resize_w_pad(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
            # TODO Finish this
            # random_pad_image(x,min_image_size=None,max_image_size=None,pad_color=None,seed=self.seed)
            return tf.image.resize_with_pad(x, target_height=self.target_h, target_width=self.target_w), label

        def random_crop(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
            return tf.image.random_crop(x, size=self.crop_size), label

        @tf.function
        def random_jitter(self, x: tf.Tensor, label: tf.Tensor) -> tf.Tensor:
            x, label = self.resize_w_pad(x, label)
            x, label = self.random_crop(x, label)
            return x, label

        def apply_augmentations(self, dataset: tf.data.Dataset):
            """
            Call this function to apply all of the augmentation in the order of specification
            provided to the constructor __init__() of ImageAugmentor.

            Args:
                dataset, tf.data.Dataset: must yield individual examples of form (x, y)
            Returns:
                Augmented dataset
            """

            dataset = dataset.map(self.resize, num_parallel_calls=AUTOTUNE)

            for aug_name, aug_p in self.augmentations.items():
                aug = self.maps[aug_name]
                dataset = dataset.map(lambda x,y: _cond_apply(x, y, aug, prob=aug_p), num_parallel_calls=AUTOTUNE)
                # dataset = dataset.map(lambda x,y: _cond_apply(x, y, func=aug, prob=aug_p), num_parallel_calls=AUTOTUNE)

            return dataset


    class ImageLoggerCallback(Callback):
        '''Tensorflow 2.0 version

        Callback that keeps track of a tf.data.Dataset and logs the correct batch to neptune based on the current batch.
        '''
        def __init__(self, data :tf.data.Dataset, freq=1, max_images=-1, name='', encoder=None):

            self.data = data
            self.freq = freq
            self.max_images = max_images
            self.name = name
            self.encoder=encoder
            self.init_iterator()

        def init_iterator(self):
            self.data_iter = iter(self.data)
            self._batch = 0
            self._count = 0
            self.finished = False

        def yield_batch(self):
            batch_data = next(self.data_iter)
            self._batch += 1
            self._count += batch_data[0].shape[0]
            return batch_data

        def add_log(self, img, counter=None, name=None):
            '''
            Intention is to generalize this to an abstract class for logging to any experiment management platform (e.g. neptune, mlflow, etc)

            Currently takes a filepath pointing to an image file and logs to current neptune experiment.
            '''
            scaled_img = (img - np.min(img))/(np.max(img) - np.min(img)) * 255.0
            scaled_img = scaled_img.astype(np.uint32)

            neptune.log_image(log_name= name or self.name,
                              x=counter,
                              y=scaled_img)
            return scaled_img

        def on_train_batch_begin(self, batch, logs=None):
            if batch % self.freq or self.finished:
                return
            while batch >= self._batch:
                x, y = self.yield_batch()

            if self.max_images==-1:
                self.max_images=x.shape[0]

            if x.ndim == 3:
                x = np.expand_dims(x, axis=0)  # np.newaxis is not callable; add a batch axis explicitly
            if x.shape[0]>self.max_images:
                x = x[:self.max_images,...]
                y = y[:self.max_images,...]

            x = x.numpy()
            y = np.argmax(y.numpy(),axis=1)
            if self.encoder:
                y = self.encoder.decode(y)
            for i in range(x.shape[0]):
                # self.add_log(x[i,...], counter=i, name = f'{self.name}-{y[i]}-batch_{str(self._batch).zfill(3)}')
                self.add_log(x[i,...], counter=self._count+i, name = f'{self.name}-{y[i]}')
            print(f'Batch {self._batch}: Logged {np.max([x.shape[0],self.max_images])} {self.name} images to neptune')

        def on_epoch_end(self, epoch, logs={}):
            self.finished = True


    class ConfusionMatrixCallback(Callback):
        '''Tensorflow 2.0 version'''
        def __init__(self, log_dir, imgs : dict, labels : dict, classes, freq=1, include_train=False, seed=None):
            self.file_writer = tf.summary.create_file_writer(log_dir)
            self.log_dir = log_dir
            self.seed = seed
            self._counter = 0
            assert np.all(np.array(imgs.keys()) == np.array(labels.keys()))
            self.imgs = imgs

            for k,v in labels.items():
                if v.ndim==2:
                    labels[k] = tf.argmax(v,axis=-1)
            self.labels = labels
            self.num_samples = {k:l.numpy().shape[0] for k,l in labels.items()}
            self.classes = classes
            self.freq = freq
            self.include_train = include_train

        def log_confusion_matrix(self, model, imgs, labels, epoch, name='', norm_cm=False):

            pred_labels = model.predict_classes(imgs)
            # pred_labels = tf.argmax(pred_labels,axis=-1)
            pred_labels = pred_labels[:,None]

            con_mat = tf.math.confusion_matrix(labels=labels, predictions=pred_labels, num_classes=len(self.classes)).numpy()
            if norm_cm:
                con_mat = np.around(con_mat.astype('float') / con_mat.sum(axis=1)[:, np.newaxis], decimals=2)
            con_mat_df = pd.DataFrame(con_mat,
                             index = self.classes,
                             columns = self.classes)

            figure = plt.figure(figsize=(12, 12))
            sns.heatmap(con_mat_df, annot=True, cmap=plt.cm.Blues)
            plt.tight_layout()
            plt.ylabel('True label')
            plt.xlabel('Predicted label')

            buf = io.BytesIO()
            plt.savefig(buf, format='png')
            buf.seek(0)

            image = tf.image.decode_png(buf.getvalue(), channels=4)
            image = tf.expand_dims(image, 0)
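            # the PNG buffer is decoded back into an RGBA tensor and given a leading batch axis,
            # yielding shape [1, H, W, 4] as expected by tf.summary.image below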

            with self.file_writer.as_default():
                tf.summary.image(name=name+'_confusion_matrix', data=image, step=self._counter)

            neptune.log_image(log_name=name+'_confusion_matrix',
                              x=self._counter,
                              y=figure)
            plt.close(figure)
            self._counter += 1

            return image

        def on_epoch_end(self, epoch, logs={}):

            if (not self.freq) or (epoch%self.freq != 0):
                return

            if self.include_train:
                cm_summary_image = self.log_confusion_matrix(self.model, self.imgs['train'], self.labels['train'], epoch=epoch, name='train')
            cm_summary_image = self.log_confusion_matrix(self.model, self.imgs['val'], self.labels['val'], epoch=epoch, name='val')

####################################################################################
####################################################################################
####################################################################################



    neptune.init(project_qualified_name=args.neptune_project_name)
    # neptune_tb.integrate_with_tensorflow()


    experiment_dir = '/media/data/jacob/sandbox_logs'
    experiment_name = args.experiment_name

    experiment_start_time = arrow.utcnow().format('YYYY-MM-DD_HH-mm-ss')
    log_dir = os.path.join(experiment_dir, experiment_name, 'log_dir', PARAMS['loss'], experiment_start_time)
    ensure_dir_exists(log_dir)
    print('Tensorboard log_dir: ', log_dir)
    # os.system(f'neptune tensorboard {log_dir} --project {args.neptune_project_name}')

    weights_best = os.path.join(log_dir, 'model_ckpt.h5')
    restore_best_weights=False
    histogram_freq=0
    patience=25
    num_epochs = PARAMS['num_epochs']
    initial_epoch=0

    src_db = pyleaves.DATABASE_PATH
    datasets = {
                'PNAS': pnas_dataset.PNASDataset(src_db=src_db),
                'Leaves': leaves_dataset.LeavesDataset(src_db=src_db),
                'Fossil': fossil_dataset.FossilDataset(src_db=src_db)
                }
    # data = datasets[PARAMS['dataset_name']]
    data_config = stuf(threshold=PARAMS['data_threshold'],
                       num_classes=PARAMS['num_classes'],
                       data_splits_meta={
                                         'train':PARAMS['train_size'],
                                         'val':PARAMS['val_size'],
                                         'test':PARAMS['test_size']
                                        }
                       )

    preprocess_input = get_preprocessing_func(PARAMS['model_name'])
    # dummy-batch call (appears to be a warm-up/trace of the preprocessing function)
    preprocess_input(tf.zeros([4, 224, 224, 3]))
    
    load_example = partial(_load_uint8_example, num_classes=data_config.num_classes)
    # load_example = partial(_load_example, num_classes=data_config.num_classes)


    if PARAMS['num_channels']==3:
        color_aug = {'rgb2gray_3channel':1.0}
    elif PARAMS['num_channels']==1:
        color_aug = {'rgb2gray_1channel':1.0}

    resize_w_pad=None
    random_jitter=None
    if not PARAMS['random_jitter']['resize']:
        resize_w_pad = PARAMS['image_size']
    else:
        random_jitter=PARAMS['random_jitter']

    TRAIN_image_augmentor = ImageAugmentor(name='train',
                                           augmentations={**PARAMS["augmentations"],
                                                          **color_aug},#'rotate':1.0,'flip':1.0,**color_aug},
                                           resize_w_pad=resize_w_pad,
                                           random_crop=None,
                                           random_jitter=random_jitter,
                                           log_dir=log_dir,
                                           seed=None)
    VAL_image_augmentor = ImageAugmentor(name='val',
                                         augmentations={**color_aug},
                                         resize_w_pad=PARAMS['image_size'],
                                         random_crop=None,
                                         random_jitter=None,
                                         log_dir=log_dir,
                                         seed=None)
    TEST_image_augmentor = ImageAugmentor(name='test',
                                          augmentations={**color_aug},
                                          resize_w_pad=PARAMS['image_size'],
                                          random_crop=None,
                                          random_jitter=None,
                                          log_dir=log_dir,
                                          seed=None)


    def neptune_log_augmented_images(split_data, num_demo_samples=40, PARAMS=PARAMS):
        cm_data_x = {'train':[],'val':[]}
        cm_data_y = {'train':[],'val':[]}
        cm_data_x['train'], cm_data_y['train'] = next(iter(get_data_loader(data=split_data['train'], data_subset_mode='train', batch_size=num_demo_samples, infinite=True, augment=False,seed=2836)))
        cm_data_x['val'], cm_data_y['val'] = next(iter(get_data_loader(data=split_data['val'], data_subset_mode='val', batch_size=num_demo_samples, infinite=True, augment=False, seed=2836)))

        for (k_x,v_x), (k_y, v_y) in zip(cm_data_x.items(), cm_data_y.items()):
            x = tf.data.Dataset.from_tensor_slices(v_x)
            y = tf.data.Dataset.from_tensor_slices(v_y)
            xy_data = tf.data.Dataset.zip((x, y))
            v = xy_data.map(VAL_image_augmentor.resize, num_parallel_calls=AUTOTUNE)
            v_aug = TRAIN_image_augmentor.apply_augmentations(xy_data)
            v_x, v_y = [i.numpy() for i in next(iter(v.batch(10*num_demo_samples)))]
            v_x_aug, v_y_aug = [i.numpy() for i in next(iter(v_aug.batch(10*num_demo_samples)))]
            k = k_x
            for i in range(num_demo_samples):
                print(f'Neptune: logging {k}_{i}')
                print(f'{v_x[i].shape}, {v_x_aug[i].shape}')
                idx = np.random.randint(0, len(v_x))
                TRAIN_image_augmentor.logger.add_log(v_x[idx], counter=i, name=k)
                TRAIN_image_augmentor.logger.add_log(v_x_aug[idx], counter=i, name=k + '_aug')


    def get_data_loader(data : tuple, data_subset_mode='train', batch_size=32, num_classes=None, infinite=True, augment=True, seed=2836):

        num_samples = len(data[0])
        x = tf.data.Dataset.from_tensor_slices(data[0])
        labels = tf.data.Dataset.from_tensor_slices(data[1])
        data = tf.data.Dataset.zip((x, labels))

        data = data.cache()
        if data_subset_mode == 'train':
            data = data.shuffle(buffer_size=num_samples)

        # data = data.map(lambda x,y: (tf.image.convert_image_dtype(load_img(x)*255.0,dtype=tf.uint8),y), num_parallel_calls=-1)
        data = data.map(load_example, num_parallel_calls=AUTOTUNE)


        data = data.map(lambda x,y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)

        if infinite:
            data = data.repeat()

        if data_subset_mode == 'train':
            data = data.shuffle(buffer_size=200, seed=seed)
            augmentor = TRAIN_image_augmentor
        elif data_subset_mode == 'val':
            augmentor = VAL_image_augmentor
        elif data_subset_mode == 'test':
            augmentor = TEST_image_augmentor

        if augment:
            data = augmentor.apply_augmentations(data)

        data = data.batch(batch_size, drop_remainder=True)

        return data.prefetch(AUTOTUNE)
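    # Usage sketch (hypothetical variable names): `data` is expected to be a (file_paths, labels) pair,
    # e.g. train_ds = get_data_loader(data=(train_paths, train_labels), data_subset_mode='train',
    #                                 batch_size=32, infinite=True, augment=True, seed=2836)
    # which mirrors how the loader is called for the non-tfds datasets further below.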

    def get_tfds_data_loader(data : tf.data.Dataset, data_subset_mode='train', batch_size=32, num_samples=100, num_classes=19, infinite=True, augment=True, seed=2836):


        def encode_example(x, y):
            x = tf.image.convert_image_dtype(x, tf.float32) * 255.0
            y = _encode_label(y, num_classes=num_classes)
            return x, y
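        # Note: convert_image_dtype rescales integer images to [0, 1] floats, so the * 255.0 above
        # restores a [0, 255] float32 range; the min/max prints below are sanity checks of that range.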

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        data = data.shuffle(buffer_size=num_samples) \
                   .cache() \
                   .map(encode_example, num_parallel_calls=AUTOTUNE)

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        # preprocess only the image and pass the label through, matching get_data_loader above
        data = data.map(lambda x, y: (preprocess_input(x), y), num_parallel_calls=AUTOTUNE)

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        if data_subset_mode == 'train':
            data = data.shuffle(buffer_size=100, seed=seed)
            augmentor = TRAIN_image_augmentor
        elif data_subset_mode == 'val':
            augmentor = VAL_image_augmentor
        elif data_subset_mode == 'test':
            augmentor = TEST_image_augmentor

        if augment:
            data = augmentor.apply_augmentations(data)

        test_d = next(iter(data))
        print(test_d[0].numpy().min())
        print(test_d[0].numpy().max())

        data = data.batch(batch_size, drop_remainder=True)
        if infinite:
            data = data.repeat()

        return data.prefetch(AUTOTUNE)






    # y_true = [[0, 1, 0], [0, 0, 1]]
    # y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]

    def accuracy(y_true, y_pred):
        y_pred = tf.argmax(y_pred, axis=-1)
        y_true = tf.argmax(y_true, axis=-1)

        return tf.reduce_mean(tf.cast(tf.equal(y_true, y_pred), tf.float32))
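    # Worked example using the commented y_true / y_pred above:
    #   argmax(y_pred) = [1, 1], argmax(y_true) = [1, 2] -> matches = [True, False] -> accuracy = 0.5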


    def true_pos(y_true, y_pred):
        # y_true = K.ones_like(y_true)
        return K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))

    def false_pos(y_true, y_pred):
        # y_true = K.ones_like(y_true)
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        all_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        return all_positives - true_positives

    def true_neg(y_true, y_pred):
        # y_true = K.ones_like(y_true)
        return K.sum(1-K.round(K.clip(y_true * y_pred, 0, 1)))

    def recall(y_true, y_pred):
        # y_true = K.ones_like(y_true)
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        all_positives = K.sum(K.round(K.clip(y_true, 0, 1)))

        recall = true_positives / (all_positives + K.epsilon())
        return recall

    def precision(y_true, y_pred):
        # y_true = K.ones_like(y_true)  # overwriting y_true with ones forces precision to ~1, so it is
        #                               # left disabled here, as in the other metric functions above

        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))

        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = true_positives / (predicted_positives + K.epsilon())
        # tf.print(y_true, y_pred)
        return precision

    def f1_score(y_true, y_pred):
        m_precision = precision(y_true, y_pred)
        m_recall = recall(y_true, y_pred)
        # pdb.set_trace()
        return 2*((m_precision*m_recall)/(m_precision+m_recall+K.epsilon()))
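    # Worked example with the same y_true / y_pred as above (with the ones_like override in precision disabled):
    #   true_positives = sum(round(y_true * y_pred)) = 1 and all_positives = 2, so recall = 0.5;
    #   predicted_positives = sum(round(y_pred)) = 2, so precision = 0.5 and f1_score = 0.5.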

    # def false_neg(y_true, y_pred):
    #     y_true = K.ones_like(~y_true)
    #     true_neg = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    #     all_negative = K.sum(K.round(K.clip(y_true, 0, 1)))
    #     return all_negatives - true_

        # return K.mean(K.argmax(y_true,axis=1)*K.argmax(y_pred,axis=1))

        # 'accuracy',
        # metrics.TrueNegatives(name='tn'),
        # metrics.FalseNegatives(name='fn'),
    METRICS = [
        f1_score,
        metrics.TruePositives(name='tp'),
        metrics.FalsePositives(name='fp'),
        metrics.CategoricalAccuracy(name='accuracy'),
        metrics.TopKCategoricalAccuracy(name='top_3_categorical_accuracy', k=3),
        metrics.TopKCategoricalAccuracy(name='top_5_categorical_accuracy', k=5)
    ]
    PARAMS['sys.argv'] = ' '.join(sys.argv)

    with neptune.create_experiment(name=experiment_name, params=PARAMS, upload_source_files=[__file__]):


        print('Logging experiment tags:')
        for tag in args.tags:
            print(tag)
            neptune.append_tag(tag)

        neptune.append_tag(PARAMS['dataset_name'])
        neptune.append_tag(PARAMS['model_name'])
        neptune.log_artifact(args.config_path)
        cm_data_x = {'train':[],'val':[]}
        cm_data_y = {'train':[],'val':[]}

        if PARAMS['dataset_name'] in tfds.list_builders():
            num_demo_samples=40

            tfds_builder = tfds.builder(PARAMS['dataset_name'])
            tfds_builder.download_and_prepare()

            num_samples = tfds_builder.info.splits['train'].num_examples
            num_samples_dict = {'train': int(num_samples*PARAMS['train_size']),
                                'val': int(num_samples*PARAMS['val_size']),
                                'test': int(num_samples*PARAMS['test_size'])}

            classes = tfds_builder.info.features['label'].names
            num_classes = len(classes)
            # tfds labels are already integer ids, so no LabelEncoder is needed; define it so the
            # ImageLoggerCallbacks below (which guard on `if self.encoder:`) still work for tfds datasets
            encoder = None

            train_slice = [0,int(PARAMS['train_size']*100)]
            val_slice = [int(PARAMS['train_size']*100), int((PARAMS['train_size']+PARAMS['val_size'])*100)]
            test_slice = [100 - int(PARAMS['test_size']*100), 100]

            tfds_train_data = tfds.load(PARAMS['dataset_name'], split=f"train[{train_slice[0]}%:{train_slice[1]}%]", shuffle_files=True, as_supervised=True)
            tfds_validation_data = tfds.load(PARAMS['dataset_name'], split=f"train[{val_slice[0]}%:{val_slice[1]}%]", shuffle_files=True, as_supervised=True)
            tfds_test_data = tfds.load(PARAMS['dataset_name'], split=f"train[{test_slice[0]}%:{test_slice[1]}%]", shuffle_files=True, as_supervised=True)
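            # Example: with train/val/test sizes of 0.7 / 0.15 / 0.15, the slices above evaluate to
            # "train[0%:70%]", "train[70%:85%]" and "train[85%:100%]" respectively.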

            # PARAMS['batch_size']=1
            train_data = get_tfds_data_loader(data = tfds_train_data, data_subset_mode='train', batch_size=PARAMS['batch_size'], num_samples=num_samples_dict['train'], num_classes=num_classes, infinite=True, augment=True, seed=2836)
            validation_data = get_tfds_data_loader(data = tfds_validation_data, data_subset_mode='val', batch_size=PARAMS['batch_size'], num_samples=num_samples_dict['val'], num_classes=num_classes, infinite=True, augment=True, seed=2837)
            test_data = get_tfds_data_loader(data = tfds_test_data, data_subset_mode='test', batch_size=PARAMS['batch_size'], num_samples=num_samples_dict['test'], num_classes=num_classes, infinite=True, augment=True, seed=2838)

            # tfds_train_data = tfds.load(PARAMS['dataset_name'], split=f"train[{train_slice[0]}%:{train_slice[1]}%]", shuffle_files=True, as_supervised=True)
            # tfds_validation_data = tfds.load(PARAMS['dataset_name'], split=f"train[{val_slice[0]}%:{val_slice[1]}%]", shuffle_files=True, as_supervised=True)
            # tfds_test_data = tfds.load(PARAMS['dataset_name'], split=f"train[{test_slice[0]}%:{test_slice[1]}%]", shuffle_files=True, as_supervised=True)

            split_data = {'train':get_tfds_data_loader(data = tfds_train_data, data_subset_mode='train', batch_size=num_demo_samples, num_samples=num_samples_dict['train'], num_classes=num_classes, infinite=True, augment=True, seed=2836),
                          'val':get_tfds_data_loader(data = tfds_validation_data, data_subset_mode='val', batch_size=num_demo_samples, num_samples=num_samples_dict['val'], num_classes=num_classes, infinite=True, augment=True, seed=2837),
                          'test':get_tfds_data_loader(data = tfds_test_data, data_subset_mode='test', batch_size=num_demo_samples, num_samples=num_samples_dict['test'], num_classes=num_classes, infinite=True, augment=True, seed=2838)
                          }

            steps_per_epoch=num_samples_dict['train']//PARAMS['batch_size']
            validation_steps=num_samples_dict['val']//PARAMS['batch_size']

            cm_data_x['train'], cm_data_y['train'] = next(iter(split_data['train']))
            cm_data_x['val'], cm_data_y['val'] = next(iter(split_data['val']))

        else:
            data = datasets[PARAMS['dataset_name']]
            neptune.set_property('num_classes',data.num_classes)
            neptune.set_property('class_distribution',data.metadata.class_distribution)

            encoder = base_dataset.LabelEncoder(data.data.family)
            split_data = base_dataset.preprocess_data(data, encoder, data_config)
            # import pdb;pdb.set_trace()
            for subset, subset_data in split_data.items():
                split_data[subset] = [list(i) for i in unzip(subset_data)]

            PARAMS['batch_size'] = 32

            steps_per_epoch=len(split_data['train'][0])//PARAMS['batch_size']#//10
            validation_steps=len(split_data['val'][0])//PARAMS['batch_size']#//10

            split_datasets = {
                              k:base_dataset.BaseDataset.from_dataframe(
                                pd.DataFrame({
                                            'path':v[0],
                                            'family':v[1]
                                            })) \
                              for k,v in split_data.items()
                             }

            for k,v in split_datasets.items():
                print(k, v.num_classes)

            classes = split_datasets['train'].classes

            train_data=get_data_loader(data=split_data['train'], data_subset_mode='train', batch_size=PARAMS['batch_size'], infinite=True, augment=True, seed=2836)
            validation_data=get_data_loader(data=split_data['val'], data_subset_mode='val', batch_size=PARAMS['batch_size'], infinite=True, augment=True, seed=2837)
            if 'test' in split_data.keys():
                test_data=get_data_loader(data=split_data['test'], data_subset_mode='test', batch_size=PARAMS['batch_size'], infinite=True, augment=True, seed=2838)

            num_demo_samples=150
            # neptune_log_augmented_images(split_data, num_demo_samples=num_demo_samples, PARAMS=PARAMS)
            cm_data_x['train'], cm_data_y['train'] = next(iter(get_data_loader(data=split_data['train'], data_subset_mode='train', batch_size=num_demo_samples, infinite=True, augment=True, seed=2836)))
            cm_data_x['val'], cm_data_y['val'] = next(iter(get_data_loader(data=split_data['val'], data_subset_mode='val', batch_size=num_demo_samples, infinite=True, augment=True,  seed=2836)))


        ########################################################################################
        train_image_logger_cb = ImageLoggerCallback(data=train_data, freq=20, max_images=-1, name='train', encoder=encoder)
        val_image_logger_cb = ImageLoggerCallback(data=validation_data, freq=20, max_images=-1, name='val', encoder=encoder)
        ########################################################################################

        cm_callback = ConfusionMatrixCallback(log_dir, cm_data_x, cm_data_y, classes=classes, seed=PARAMS['seed'], include_train=True)
        checkpoint = ModelCheckpoint(weights_best, monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=False, mode='min',restore_best_weights=restore_best_weights)
        tfboard = TensorBoard(log_dir=log_dir, histogram_freq=histogram_freq, write_images=True)
        early = EarlyStopping(monitor='val_loss', patience=patience, verbose=1)
        callbacks = [checkpoint,tfboard,early, cm_callback, neptune_logger, train_image_logger_cb, val_image_logger_cb]
    ##########################
        if PARAMS['optimizer'] == 'Adam':
            optimizer = tf.keras.optimizers.Adam(
                learning_rate=PARAMS['lr']
            )
        elif PARAMS['optimizer'] == 'Nadam':
            optimizer = tf.keras.optimizers.Nadam(
                learning_rate=PARAMS['lr']
            )
        elif PARAMS['optimizer'] == 'SGD':
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=PARAMS['lr']
            )
    ##########################
        if PARAMS['loss']=='focal_loss':
            loss = focal_loss(gamma=2.0, alpha=4.0)
        elif PARAMS['loss']=='categorical_crossentropy':
            loss = 'categorical_crossentropy'
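        # Note: focal_loss is defined earlier in the script; it is assumed here to follow the standard
        # focal-loss form FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t) (Lin et al.), with the
        # gamma=2.0 and alpha=4.0 values passed above.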
    ##########################
        model_params = stuf(name=PARAMS['model_name'],
                            model_dir=os.path.join(experiment_dir, experiment_name, 'models'),
                            num_classes=PARAMS['num_classes'],
                            frozen_layers = PARAMS['frozen_layers'],
                            input_shape = (*PARAMS['image_size'],PARAMS['num_channels']),
                            base_learning_rate = PARAMS['lr'],
                            regularization = PARAMS['regularization'])
    ####
        if PARAMS['model_name']=='shallow':
            model = build_shallow(input_shape=model_params.input_shape,
                                  num_classes=PARAMS['num_classes'],
                                  optimizer=optimizer,
                                  loss=loss,
                                  METRICS=METRICS)

        else:
            model = build_model(model_params,
                                optimizer,
                                loss,
                                METRICS)
        print(f"TRAINING {PARAMS['model_name']}")

        model.summary(print_fn=lambda x: neptune.log_text('model_summary', x))

        history = model.fit(train_data,
                            epochs=num_epochs,
                            callbacks=callbacks,
                            validation_data=validation_data,
                            shuffle=True,
                            initial_epoch=initial_epoch,
                            steps_per_epoch=steps_per_epoch,
                            validation_steps=validation_steps)


        if 'test' in split_data:
            # split_data holds tf.data.Datasets in the tfds branch and (paths, labels) lists otherwise,
            # so compute the number of evaluation batches accordingly
            if PARAMS['dataset_name'] in tfds.list_builders():
                test_steps = num_samples_dict['test'] // PARAMS['batch_size']
            else:
                test_steps = len(split_data['test'][0]) // PARAMS['batch_size']
            results = model.evaluate(test_data, steps=test_steps)
        else:
            results = model.evaluate(validation_data,
                                     steps=validation_steps)