Beispiel #1
0
 def __init__(self, flags):
     """Resolve the weight/config sources from ``flags`` and parse the config.

     ``get_weight_src`` is called before anything reads ``self.src_cfg``;
     ``create_ops`` then supplies the ``(meta, layers)`` pair stored on the
     instance.
     """
     self.io = SharedFlagIO(subprogram=True)
     self.get_weight_src(flags)
     self.modify = False
     self.io.logger.info(f'Parsing {self.src_cfg}')
     # Unpack the parse result straight onto the instance.
     self.meta, self.layers = self.create_ops()
Beispiel #2
0
def predict(flags, net: Net, framework: Framework):
    """Run ``net`` over every input in ``flags.imgdir``, one batch at a time.

    Each batch is preprocessed on a thread pool, forwarded through ``net``
    input by input, and handed to ``framework.postprocess`` together with the
    path of the input it came from. Throughput is logged after each stage.

    Args:
        flags: Fallback flags, used when the shared flag store returns None.
        net: Callable network; called with one input expanded to a batch of 1.
        framework: Supplies ``is_input``, ``preprocess`` and ``postprocess``.

    Raises:
        FileNotFoundError: If ``flags.imgdir`` holds no file accepted by
            ``framework.is_input``.
    """
    log = get_logger()
    io = SharedFlagIO(flags, subprogram=True)
    # Read the shared flags once: a second read could return something else
    # (or None) between the ``is not None`` check and the use of the value.
    shared = io.read_flags()
    if shared is not None:
        flags = shared
    all_inps = [i for i in os.listdir(flags.imgdir) if framework.is_input(i)]
    if not all_inps:
        raise FileNotFoundError(f'Failed to find any images in {flags.imgdir}')
    batch = min(flags.batch, len(all_inps))
    n_batch = int(math.ceil(len(all_inps) / batch))
    # Loop-invariant: joins a file name onto the input directory.
    img_path = partial(os.path.join, flags.imgdir)
    pool = ThreadPool()
    try:
        for j in range(n_batch):
            start = j * batch
            this_batch = all_inps[start:start + batch]
            # The last batch may be short, so report the real count.
            count = len(this_batch)
            log.info(f'Preprocessing {count} inputs...')
            with Timer() as t:
                x = pool.map(lambda inp: framework.preprocess(img_path(inp)),
                             this_batch)
            log.info(f'Done! ({count/t.elapsed_secs:.2f} inputs/s)')
            log.info(f'Forwarding {count} inputs...')
            with Timer() as t:
                x = [np.concatenate(net(np.expand_dims(i, 0)), 0) for i in x]
            log.info(f'Done! ({count/t.elapsed_secs:.2f} inputs/s)')
            log.info(f'Postprocessing {count} inputs...')
            with Timer() as t:
                # Pair every prediction with the path of its source input.
                pool.starmap(
                    lambda pred, name: framework.postprocess(
                        pred, img_path(name)),
                    zip(x, this_batch))
            log.info(f'Done! ({count/t.elapsed_secs:.2f} inputs/s)')
    finally:
        # Release the worker threads even when a stage raises.
        pool.close()
        pool.join()
Beispiel #3
0
def loss(self, y_pred, y_true):
    """Return the loss selected by ``self.meta['type']`` for one batch.

    The type string is read with surrounding ``[``/``]`` stripped and must be
    a key of ``self.type``. On the first call the hyper-parameters are logged.

    Args:
        y_pred: Predictions (also used as logits for ``softmax``).
        y_true: Targets (also used as labels for ``softmax``).

    Returns:
        The loss value, or None when the type is registered in ``self.type``
        but matches none of the branches below.

    Raises:
        AssertionError: If the loss type is not registered. The message is
            stored in ``self.flags.error`` and the flags are sent first.
    """
    losses = self.type.keys()
    loss_type = self.meta['type'].strip('[]')
    out_size = self.meta['out_size']
    H, W, _ = self.meta['inp_size']
    HW = H * W
    # Explicit check instead of ``assert`` so validation survives ``python -O``;
    # the exception type seen by callers is unchanged.
    if loss_type not in losses:
        message = f'Loss type {loss_type} not implemented'
        self.flags.error = message
        self.logger.error(message)
        SharedFlagIO.send_flags(self)
        raise AssertionError(message)

    if self.first:
        self.logger.info('{} loss hyper-parameters:'.format(
            self.meta['model']))
        self.logger.info('Input Grid Size   = {}'.format(HW))
        self.logger.info('Number of Outputs = {}'.format(out_size))
        self.first = False

    diff = y_true - y_pred
    # '12' was a typo for 'l2'; it is still accepted for old configs.
    if loss_type in ('sse', 'l2', '12'):
        return tf.nn.l2_loss(diff)
    if loss_type == 'mse':
        return tf.keras.losses.MSE(y_true, y_pred)
    # Was compared against the list ['smooth'], which a string never equals.
    if loss_type == 'smooth':
        # NOTE(review): the mask tests the signed difference, not its
        # magnitude -- confirm whether tf.abs(diff) < 1 is intended.
        small = tf.cast(diff < 1, tf.float32)
        large = 1. - small
        return L1L2(tf.multiply(diff, large), tf.multiply(diff, small))
    if loss_type in ('sparse', 'l1'):
        return l1(diff)
    if loss_type == 'softmax':
        # ``labels`` is a required argument of this op.
        _loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true,
                                                        logits=y_pred)
        return tf.reduce_mean(_loss)
    return None
Beispiel #4
0
 def __init__(self, flags, darknet=None):
     """Store flags and the Darknet description the network is built from.

     Args:
         flags: Fallback flags, used when the shared flag store returns None.
             Also used to build the Darknet when ``darknet`` is None.
         darknet: Optional prebuilt Darknet; built from ``flags`` if omitted.
     """
     super(NetBuilder, self).__init__(name=self.__class__.__name__)
     tf.autograph.set_verbosity(0)
     self.io = SharedFlagIO(subprogram=True)
     # Read once: two reads can disagree if the store changes in between.
     shared = self.io.read_flags()
     self.flags = flags if shared is None else shared
     self.io_flags = self.io.io_flags
     self.logger = get_logger()
     self.darknet = Darknet(flags) if darknet is None else darknet
     # len() is already 0 for an empty layer list; no fallback needed.
     self.num_layer = self.ntrain = len(self.darknet.layers)
     self.meta = self.darknet.meta
Beispiel #5
0
    def __init__(self):
        """Load the action table, create one module-level action per entry,
        then run ``self.setup()``.

        Each key of ``actions.json`` names both a callback on this window and
        the global variable the created action is bound to. Each value is a
        ``(shortcut, checkable, enabled)`` triple.
        """
        super(BeaglesMainWindow, self).__init__()
        self.io = SharedFlagIO(subprogram=True)

        def createActions(actions: list):
            # ``self`` and ``action`` must be visible in this function's
            # locals(): the exec'd statement below looks both names up there.
            nonlocal self
            action = partial(newAction, self)
            cmd = 'global {0}; {0} = action("{1}", {2}, {3}, "{4}", "{5}", {6}, {7})'
            for act in actions:
                _str = act
                action_str = getStr(_str)
                shortcut, checkable, enabled = self.actionFile[_str]
                # Test for None BEFORE converting to text; str(None) is the
                # truthy string 'None', which used to be quoted and passed on
                # as a literal "None" shortcut.
                shortcut = '"{}"'.format(
                    shortcut) if shortcut is not None else None
                checkable = str(checkable)
                enabled = str(enabled)
                detail = getStr(_str + "Detail")
                icon = _str
                callback = 'self.' + act
                cmd_string = cmd.format(_str, action_str, callback, shortcut,
                                        icon, detail, checkable, enabled)
                self.io.logger.debug(cmd_string)
                # NOTE(review): exec on text assembled from actions.json and
                # getStr(); safe only while both sources are trusted and
                # contain no double quotes.
                exec(cmd_string)

        with open('beagles/resources/actions/actions.json', 'r') as json_file:
            self.actionFile = json.load(json_file)
        actionList = list(self.actionFile.keys())
        createActions(actionList)
        self.setup()
Beispiel #6
0
    def __init__(self, num_divisions: int, video: os.PathLike,
                 unused_cameras: List[int]):
        """Record the source video and derive the square tile grid.

        Args:
            num_divisions: Requested tile count; the grid side is the ceiling
                of its square root.
            video: Path to the source video.
            unused_cameras: Camera numbers kept on the instance for later use.
        """
        self.logger = SharedFlagIO().logger
        # math.sqrt always yields a float, so the grid side is always the
        # ceiling of the root converted to an int.
        side = int(math.ceil(math.sqrt(num_divisions)))
        self.div = side
        self.video = video
        self.num_videos = side**2
        self.unused_cameras = unused_cameras
        self.folder = os.path.dirname(video)
        resolution = self._get_resolution(video)
        self.width, self.height = resolution
Beispiel #7
0
def train(data, class_weights, flags, net: Net, framework: Framework,
          manager: tf.train.CheckpointManager):
    """Train ``net`` on the batches from ``framework.shuffle``.

    After every step the flags are sent and re-read. A checkpoint is written
    whenever the step count is a multiple of ``flags.save`` (never on the very
    first iteration), and once more at the end if the last step was not saved.

    Args:
        data: Training data; only its length is used here, for progress.
        class_weights: Passed through to ``framework.shuffle``.
        flags: Fallback flags, used when the shared flag store returns None.
        net: Network called as ``net(x_batch, training=True, **loss_feed)``.
        framework: Supplies ``shuffle(data, class_weights)``.
        manager: Checkpoint manager used for saving.
    """
    log = get_logger()
    io = SharedFlagIO(flags, subprogram=True)
    # Read once: two reads can disagree if the store changes in between.
    shared = io.read_flags()
    if shared is not None:
        flags = shared
    log.info('Building {} train op'.format(flags.model))
    goal = len(data) * flags.epoch
    line = 'step: {} loss: {:f} lr: {:.2e} progress: {:.2f}%'
    first = True
    # True while there is no unsaved step. Starting at True means an empty
    # batch stream writes nothing (it used to fail on an unbound name).
    saved = True
    batches = framework.shuffle(data, class_weights)
    for i, (x_batch, loss_feed) in enumerate(batches):
        loss = net(x_batch, training=True, **loss_feed)
        step = net.step.numpy()
        lr = net.optimizer.learning_rate.numpy()
        if first:
            log.info(f"Following gradient from step {step}...")
        else:
            flags.progress = i * flags.batch / goal * 100
            log.info(line.format(step, loss, lr, flags.progress))
        io.send_flags()
        # Keep the current flags if the store has nothing to return.
        shared = io.read_flags()
        if shared is not None:
            flags = shared
        saved = (not first) and step % flags.save == 0
        if saved:
            save = manager.save()
            log.info(f"Saved checkpoint: {save}")
        first = False
    # Tracks whether the last step was really saved, not merely due: a run
    # whose only step landed on a save boundary used to end unsaved.
    if not saved:
        save = manager.save()
        log.info(f"Finished training at checkpoint: {save}")
Beispiel #8
0
 def __init__(self, flags, darknet=None):
     """Build the TF1-style graph for a Darknet model.

     Args:
         flags: Fallback flags, used when the shared flag store returns None.
             Also used directly for the Darknet, the framework and the
             device choice.
         darknet: Optional prebuilt Darknet; built from ``flags`` if omitted.
     """
     self.io = SharedFlagIO(subprogram=True)
     # disable eager mode for TF1-dependent code
     tf.compat.v1.disable_eager_execution()
     # Read once: two reads can disagree if the store changes in between.
     shared = self.io.read_flags()
     self.flags = flags if shared is None else shared
     self.io_flags = self.io.io_flags
     self.logger = get_logger()
     darknet = Darknet(flags) if darknet is None else darknet
     self.darknet = darknet
     # Every layer counts as trainable here, so both counters are equal.
     self.num_layer = self.ntrain = len(darknet.layers)
     self.framework = Framework.create(darknet.meta, flags)
     self.annotation_data = self.framework.parse()
     self.meta = darknet.meta
     self.graph = tf.Graph()
     # No explicit device when the GPU share is zero.
     device_name = flags.gpu_name if flags.gpu > 0.0 else None
     start = time.time()
     with tf.device(device_name):
         with self.graph.as_default():
             self.build_forward()
             self.setup_meta_ops()
     self.logger.info('Finished in {}s'.format(time.time() - start))
Beispiel #9
0
def annotate(flags, net, framework):
    """Write a ``<video>_annotations.csv`` file for each video in ``flags.video``.

    Every frame is resized, forwarded through ``net`` and converted to boxes
    by ``framework``. Each non-empty processed box becomes one CSV row that
    starts with the frame's position in seconds. Progress is published every
    tenth frame.

    Args:
        flags: Fallback flags, used when the shared flag store returns None.
        net: Callable network; called with one frame expanded to a batch of 1.
        framework: Supplies ``resize_input``, ``findboxes`` and ``process_box``.

    Raises:
        SystemExit: With code 1 when ``flags.kill`` is set after a frame.
    """
    log = get_logger()
    io = SharedFlagIO(flags, subprogram=True)
    # Read once: two reads can disagree if the store changes in between.
    shared = io.read_flags()
    if shared is not None:
        flags = shared
    for video in flags.video:
        annotation_file = f'{os.path.splitext(video)[0]}_annotations.csv'
        if os.path.exists(annotation_file):
            log.info("Overwriting existing annotations")
        log.info(f'Annotating {video}')
        capture = cv2.VideoCapture(video)
        try:
            frame_count = 0
            total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            # mode='w' truncates any previous file; newline='' is what the
            # csv module requires so row endings are not translated twice.
            with open(annotation_file, mode='w', newline='') as file:
                file_writer = csv.writer(file,
                                         delimiter=',',
                                         quotechar='"',
                                         quoting=csv.QUOTE_MINIMAL)
                while capture.isOpened():
                    frame_count += 1
                    # Skip the update when the frame count is unknown (0)
                    # instead of dividing by zero.
                    if frame_count % 10 == 0 and total_frames > 0:
                        flags.progress = round(
                            (100 * frame_count / total_frames), 0)
                        io.io_flags()
                    ret, frame = capture.read()
                    if not ret:
                        break
                    frame = np.asarray(frame)
                    h, w, _ = frame.shape
                    im = framework.resize_input(frame)
                    this_inp = np.expand_dims(im, 0)
                    boxes = framework.findboxes(
                        np.concatenate(net(this_inp), 0))
                    pred = [
                        framework.process_box(b, h, w, flags.threshold)
                        for b in boxes
                    ]
                    time_elapsed = capture.get(cv2.CAP_PROP_POS_MSEC) / 1000
                    # Falsy results are dropped, as filter(None, ...) did.
                    file_writer.writerows([time_elapsed, *result]
                                          for result in pred if result)
                    if flags.kill:
                        # Same effect as exit(1) without relying on the
                        # site-provided helper.
                        raise SystemExit(1)
        finally:
            # Runs on normal completion, on errors and on the kill exit.
            capture.release()
Beispiel #10
0
 def setUp(self):
     """Create a non-subprogram ``SharedFlagIO`` and expose its flags."""
     flag_io = SharedFlagIO(subprogram=False)
     self.io = flag_io
     self.flags = flag_io.flags
Beispiel #11
0
class TestBackend(TestCase):
    """End-to-end tests that configure shared flags and run the backend
    entry point as a child process."""

    CKPT_DIR = 'tests/resources/ckpt'
    CKPT_SUFFIXES = ('.data-00000-of-00001', '.index', '.meta', '.profile')

    @classmethod
    def setUpClass(cls) -> None:
        # NOTE(review): this creates tests/resources/checkpoint, while
        # tearDownClass removes tests/resources/ckpt/checkpoint -- confirm
        # which path is intended.
        open('tests/resources/checkpoint', 'w').close()
        with ZipFile('tests/resources/BCCD.v1-resize-416x416.voc.zip',
                     'r') as f:
            f.extractall('tests/resources/BCCD')
        time.sleep(5)

    def setUp(self):
        self.io = SharedFlagIO(subprogram=False)
        self.flags = self.io.flags

    @staticmethod
    def _run_backend():
        """Run the backend entry point to completion; return its exit code."""
        proc = Popen([sys.executable, BACKEND_ENTRYPOINT],
                     stdout=PIPE,
                     shell=False)
        proc.communicate()
        return proc.returncode

    def _set_training_flags(self, lr, max_lr):
        """Set the flags shared by both training scenarios."""
        self.flags.model = 'tests/resources/yolov2-lite-3c.cfg'
        self.flags.dataset = 'tests/resources/BCCD/train'
        self.flags.labels = 'tests/resources/BCCD.classes'
        self.flags.annotation = 'tests/resources/BCCD/train'
        self.flags.backup = 'tests/resources/ckpt'
        self.flags.project_name = '_test'
        self.flags.trainer = 'adam'
        self.flags.lr = lr
        self.flags.max_lr = max_lr
        self.flags.load = 0
        self.flags.batch = 4
        self.flags.epoch = 1
        self.flags.train = True

    def testBackendWrapperYoloV2(self):
        """Train from scratch, resume, predict on images, annotate a video."""
        self._set_training_flags(lr=0.00001, max_lr=0.0001)
        self.flags.step_size_coefficient = 10
        self.io.io_flags()
        self.assertEqual(self._run_backend(), 0)
        # Resume from checkpoint 1.
        # NOTE(review): the exit codes of this run and the next are not
        # asserted -- confirm whether that is intentional.
        self.flags.load = 1
        self.io.io_flags()
        self._run_backend()
        # Prediction over the test image directory.
        self.flags.train = False
        self.flags.imgdir = 'tests/resources/BCCD/test'
        self.io.io_flags()
        self._run_backend()
        # Video annotation.
        self.flags.video = ['tests/resources/test.mp4']
        self.io.io_flags()
        self.assertEqual(self._run_backend(), 0)

    def testBackendGradientExplosion(self):
        """A huge learning rate must make the backend exit non-zero."""
        self._set_training_flags(lr=100000000.0, max_lr=100000000.0)
        self.io.io_flags()
        self.assertNotEqual(self._run_backend(), 0)

    def tearDown(self) -> None:
        self.io.cleanup_flags()

    @classmethod
    def tearDownClass(cls):
        for f in os.listdir(cls.CKPT_DIR):
            if f.endswith(cls.CKPT_SUFFIXES):
                os.remove(os.path.join(cls.CKPT_DIR, f))
        # A run that never wrote a checkpoint leaves no index file; tolerate
        # that so the extracted dataset below is still removed.
        try:
            os.remove(os.path.join(cls.CKPT_DIR, 'checkpoint'))
        except FileNotFoundError:
            pass
        rmtree('tests/resources/BCCD')
        try:
            rmtree('data/summaries/_test')
        except FileNotFoundError:
            pass
Beispiel #12
0
class Darknet(object):
    """Parsed model description: config metadata plus the list of layers."""

    def __init__(self, flags):
        """Resolve the weight/config sources from ``flags`` and parse the
        config into ``self.meta`` and ``self.layers``."""
        self.io = SharedFlagIO(subprogram=True)
        self.get_weight_src(flags)
        self.modify = False
        self.io.logger.info('Parsing {}'.format(self.src_cfg))
        src_parsed = self.create_ops()
        self.meta, self.layers = src_parsed
        # uncomment for v1 behavior
        # self.load_weights()

    def get_weight_src(self, flags):
        """Set ``self.src_bin`` and ``self.src_cfg`` from ``flags``.

        When ``flags.load`` is an int (an empty string counts as 0), the
        config is ``flags.model`` and the binary is
        ``flags.binary + flags.model + WGT_EXT`` -- or None when
        ``flags.load`` is non-zero or that file does not exist.

        Otherwise ``flags.load`` itself is the binary. Its config is looked
        up in ``flags.config`` by model name, falling back to ``flags.model``,
        and ``flags.load`` is reset to 0.
        """
        self.src_bin = os.path.abspath(flags.binary + flags.model + WGT_EXT)
        exist = os.path.isfile(self.src_bin)

        if flags.load == str():
            flags.load = int()
        # isinstance also accepts bool, which previously fell through to the
        # path branch and failed there.
        if isinstance(flags.load, int):
            self.src_cfg = flags.model
            if flags.load or not exist:
                self.src_bin = None
        else:
            self.src_bin = flags.load
            name = self.model_name(flags.load)
            cfg_path = os.path.join(flags.config, name + CFG_EXT)
            if not os.path.isfile(cfg_path):
                # Logger.warn is a deprecated alias removed in Python 3.13.
                self.io.logger.warning(
                    f'{cfg_path} not found, use {flags.model} instead')
                cfg_path = flags.model
            self.src_cfg = cfg_path
            flags.load = int()

    @staticmethod
    def model_name(file_path):
        """Derive the model name from a weights or checkpoint path.

        ``name.weights`` gives ``name``. A path with no extension, or with a
        ``.meta`` extension, is treated as ``<name>-<step>`` and gives
        ``name``. Any other extension gives None.

        Raises:
            ValueError: If a checkpoint path does not end in ``-<integer>``.
        """
        file_name = os.path.basename(file_path)
        ext = str()
        if '.' in file_name:  # exclude extension
            file_name, _, ext = file_name.rpartition('.')
        if ext == str() or ext == 'meta':  # ckpt file
            stem, _, step = file_name.rpartition('-')
            int(step)  # validation only: the suffix must be a step number
            return stem
        if ext == 'weights':
            return file_name
        return None

    def create_ops(self):
        """Parse ``self.src_cfg`` and return ``(meta, layers)``.

        The first item produced by the config parser is the metadata; every
        later item is turned into a layer with ``create_darkop``.

        Raises:
            TypeError: Re-raised after the message has been stored in the
                shared flags and sent.
        """
        cfg_layers = ConfigParser.create(self.src_cfg)

        meta = dict()
        layers = list()
        try:
            for i, info in enumerate(cfg_layers):
                if i == 0:
                    meta = info
                else:
                    layers.append(create_darkop(*info))
        except TypeError as e:
            self.io.flags.error = str(e)
            self.io.logger.error(str(e))
            self.io.send_flags()
            raise
        return meta, layers

    def load_weights(self):
        """Load ``self.src_bin`` into every layer through a ``Loader``."""
        self.io.logger.info(f'Loading {self.src_bin} ...')
        start = time.time()

        wgts_loader = Loader.create(self.src_bin, self.layers)
        for layer in self.layers:
            layer.load(wgts_loader)

        stop = time.time()
        self.io.logger.info('Finished in {}s'.format(stop - start))
Beispiel #13
0
class TiledCaptureArray:
    """Object definition for tiled capture arrays.

    Args:
        num_divisions: number of tiles to process

        video: path to source video

        unused_cameras: camera sources to skip during processing.

    Note:
        Cameras are numbered as follows:
        :math:`\\begin{bmatrix} 1 & 2 & 3 \\\\ 4 & 5 & 6 \\\\ 7 & 8 & 9 \\end{bmatrix}`
    """
    def __init__(self, num_divisions: int, video: os.PathLike,
                 unused_cameras: List[int]):
        self.logger = SharedFlagIO().logger
        # math.sqrt always returns a float, so the grid side is always the
        # ceiling of the root converted to an int.
        self.div = int(math.ceil(math.sqrt(num_divisions)))

        self.video = video

        self.num_videos = self.div**2

        self.unused_cameras = unused_cameras

        self.folder = os.path.dirname(video)

        self.width, self.height = self._get_resolution(video)

    @staticmethod
    def _get_resolution(target):
        """Return ``(width, height)`` of ``target`` as reported by OpenCV."""
        vid = cv2.VideoCapture(target)
        try:
            height = vid.get(cv2.CAP_PROP_FRAME_HEIGHT)
            width = vid.get(cv2.CAP_PROP_FRAME_WIDTH)
        finally:
            # Close the file handle instead of waiting for garbage collection.
            vid.release()
        return width, height

    def crop(self):
        """Start one ffmpeg process per used camera, cropping its tile into
        ``<video name>_camera_<n><ext>``.

        Tiles are numbered row by row starting at 1. The processes are
        started but not waited for.
        """
        h_inc = int(self.height / self.div)
        w_inc = int(self.width / self.div)
        name, ext = os.path.splitext(self.video)
        source = os.fspath(self.video)

        for i in range(1, self.num_videos + 1):
            if i in self.unused_cameras:
                continue
            # Zero-based row/column of tile i in the div x div grid.
            row, col = divmod(i - 1, self.div)
            x = col * w_inc
            y = row * h_inc
            output = f'{name}_camera_{i}{ext}'
            # Argument list, no shell: paths containing quotes or shell
            # metacharacters cannot alter the command.
            args = [
                'ffmpeg', '-hide_banner', '-y', '-i', source,
                '-filter:v', f'crop={w_inc}:{h_inc}:{x}:{y}',
                '-c:a', 'copy', '-map_metadata', '0',
                '-map_metadata:s:v', '0:s:v',
                '-map_metadata:s:a', '0:s:a', output
            ]
            self.logger.debug(' '.join(args))
            # stdout is never read here, so discard it rather than leave an
            # undrained pipe attached to the child.
            proc = subprocess.Popen(args, stdout=subprocess.DEVNULL)
            self.logger.info(
                f'Started ffmpeg PID: {proc.pid} Output: {output}')
Beispiel #14
0
class NetBuilder(tf.Module):
    """Initializes with flags that build a Darknet or with a prebuilt Darknet.
    Constructs the actual :obj:`Net` object upon being called.

    """
    def __init__(self, flags, darknet=None):
        super(NetBuilder, self).__init__(name=self.__class__.__name__)
        tf.autograph.set_verbosity(0)
        self.io = SharedFlagIO(subprogram=True)
        # Read once: two reads can disagree if the store changes in between.
        shared = self.io.read_flags()
        self.flags = flags if shared is None else shared
        self.io_flags = self.io.io_flags
        self.logger = get_logger()
        self.darknet = Darknet(flags) if darknet is None else darknet
        self.num_layer = self.ntrain = len(self.darknet.layers)
        self.meta = self.darknet.meta

    def __call__(self):
        """Build the network and return ``(net, framework, manager)``.

        Also sets ``annotation_data``, ``class_weights``, ``global_step`` and
        ``checkpoint`` on the builder, and restores a checkpoint according to
        ``flags.load`` before compiling.
        """
        self.global_step = tf.Variable(0, trainable=False)
        framework = Framework.create(self.darknet.meta, self.flags)
        self.annotation_data, self.class_weights = framework.parse()
        optimizer = self.build_optimizer()
        layers = self.compile_darknet()
        net = Net(layers, self.global_step, dtype=tf.float32)
        ckpt_kwargs = {'net': net, 'optimizer': optimizer}
        self.checkpoint = tf.train.Checkpoint(**ckpt_kwargs)
        name = f"{self.meta['name']}"
        manager = tf.train.CheckpointManager(self.checkpoint,
                                             self.flags.backup,
                                             self.flags.keep,
                                             checkpoint_name=name)
        # try to load a checkpoint from flags.load
        self.load_checkpoint(manager)
        self.logger.info('Compiling Net...')
        net.compile(loss=framework.loss, optimizer=optimizer)
        return net, framework, manager

    def build_optimizer(self):
        """Return the optimizer named by ``flags.trainer``, driven by a
        cyclic learning rate."""
        # setup kwargs for trainer
        kwargs = dict()
        if self.flags.trainer in MOMENTUM_USERS:
            kwargs.update({MOMENTUM: self.flags.momentum})
        # Compare by value: ``is`` on strings tests object identity, which a
        # flag value read from the shared store need not satisfy.
        if self.flags.trainer == NESTEROV:
            kwargs.update({self.flags.trainer: True})
        if self.flags.trainer == AMSGRAD:
            kwargs.update({AMSGRAD.lower(): True})
        if self.flags.clip:
            kwargs.update({'clipnorm': self.flags.clip_norm})
        ssc = self.flags.step_size_coefficient
        step_size = int(ssc * (len(self.annotation_data) // self.flags.batch))
        clr_kwargs = {
            'global_step': self.global_step,
            'mode': self.flags.clr_mode,
            'step_size': step_size,
            'learning_rate': self.flags.lr,
            'max_lr': self.flags.max_lr,
            'name': self.flags.model
        }
        # setup trainer; the learning rate is a callable, so it is
        # recomputed each time the optimizer asks for it
        return TRAINERS[self.flags.trainer](
            learning_rate=lambda: clr(**clr_kwargs), **kwargs)

    def compile_darknet(self):
        """Create one op per Darknet layer, each given its predecessor."""
        layers = list()
        roof = self.num_layer - self.ntrain
        prev = None
        for i, layer in enumerate(self.darknet.layers):
            layer = op_create(layer, prev, i, roof)
            layers.append(layer)
            prev = layer
        return layers

    def load_checkpoint(self, manager):
        """Restore ``self.checkpoint`` according to ``flags.load``.

        * str: the single managed checkpoint whose path contains the string
        * negative int: the latest checkpoint
        * int >= 1: the ``load``-th managed checkpoint (1-based)
        * 0: nothing is restored

        Raises:
            AssertionError: If a string does not match exactly one checkpoint.
        """
        load = self.flags.load
        if isinstance(load, str):
            matches = [i for i in manager.checkpoints if load in i]
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if len(matches) != 1:
                raise AssertionError(
                    f'Expected exactly one checkpoint matching {load!r}, '
                    f'found {len(matches)}')
            # restore() takes a checkpoint path, not the list of matches.
            checkpoint = matches[0]
            self.checkpoint.restore(checkpoint)
            self.logger.info(f"Restored from {checkpoint}")
        elif load < 0:
            self.checkpoint.restore(manager.latest_checkpoint)
            self.logger.info(f"Restored from {manager.latest_checkpoint}")
        elif load >= 1:
            idx = load - 1
            self.checkpoint.restore(manager.checkpoints[idx])
            self.logger.info(f"Restored from {manager.checkpoints[idx]}")
        else:
            self.logger.info("Initializing network weights from scratch.")
Beispiel #15
0
import os
import sys

sys.path.append(os.getcwd())
from beagles.io.flags import SharedFlagIO
from beagles.backend.net import NetBuilder, train, predict, annotate

if __name__ == '__main__':
    # Backend entry point: build the network from the shared flags, run the
    # mode they select, then mark the job finished.
    io = SharedFlagIO(subprogram=True)
    flags = io.read_flags()
    flags.started = True
    net_builder = NetBuilder(flags=flags)
    net, framework, manager = net_builder()
    # Re-read after building so the mode is chosen from current flags.
    flags = io.read_flags()
    if flags.train:
        train(net_builder.annotation_data, net_builder.class_weights, flags,
              net, framework, manager)
    elif flags.video:
        annotate(flags, net, framework)
    else:
        predict(flags, net, framework)
    flags = io.read_flags()
    flags.progress = 100.0
    flags.done = True
    io.io_flags()
    # sys.exit rather than the site-provided exit(), which is intended for
    # interactive use and is not guaranteed to exist.
    sys.exit(0)
Beispiel #16
0
class TFNet:
    """TF1-style (graph and session) network built from a Darknet model."""

    # Interface Methods:
    def __init__(self, flags, darknet=None):
        """Build the graph, the session and, when training, the train op.

        Args:
            flags: Fallback flags, used when the shared flag store returns
                None. Also used directly for the Darknet, the framework and
                the device choice.
            darknet: Optional prebuilt Darknet; built from ``flags`` if
                omitted.
        """
        self.io = SharedFlagIO(subprogram=True)
        # disable eager mode for TF1-dependent code
        tf.compat.v1.disable_eager_execution()
        # Read once: two reads can disagree if the store changes in between.
        shared = self.io.read_flags()
        self.flags = flags if shared is None else shared
        self.io_flags = self.io.io_flags
        self.logger = get_logger()
        darknet = Darknet(flags) if darknet is None else darknet
        self.darknet = darknet
        self.num_layer = self.ntrain = len(darknet.layers)
        self.framework = Framework.create(darknet.meta, flags)
        self.annotation_data = self.framework.parse()
        self.meta = darknet.meta
        self.graph = tf.Graph()
        device_name = flags.gpu_name if flags.gpu > 0.0 else None
        start = time.time()
        with tf.device(device_name):
            with self.graph.as_default():
                self.build_forward()
                self.setup_meta_ops()
        self.logger.info('Finished in {}s'.format(time.time() - start))

    def raise_error(self, error: Exception, traceback=None):
        """Store ``error`` in the shared flags, send them, then raise it.

        Args:
            error: Exception to raise.
            traceback: Optional object whose ``message`` attribute is
                appended to the stored error text.
        """
        form = "{}\nOriginal Tensorflow Error: {}"
        try:
            raise error
        except Exception as e:
            if traceback:
                oe = traceback.message
                self.flags.error = form.format(str(e), oe)
            else:
                self.flags.error = str(e)
            self.logger.error(str(e))
            self.io.send_flags()
            raise

    def build_forward(self):
        """Chain an op for every Darknet layer, starting from the input."""
        # Placeholders
        inp_size = self.meta['inp_size']
        self.inp = tf.keras.layers.Input(dtype=tf.float32,
                                         shape=tuple(inp_size),
                                         name='input')
        self.feed = dict()  # other placeholders

        # Build the forward pass
        state = identity(self.inp)
        roof = self.num_layer - self.ntrain
        self.logger.info(LINE)
        self.logger.info(HEADER)
        self.logger.info(LINE)
        for i, layer in enumerate(self.darknet.layers):
            state = op_create(layer, state, i, roof, self.feed)
            mess = state.verbalise()
            # An op with nothing to report prints a separator line.
            self.logger.info(mess if mess else LINE)

        self.top = state
        self.out = tf.identity(state.out, name='output')

    def setup_meta_ops(self):
        """Configure device placement, then create the train op, the summary
        writer, the session and the saver as the flags require.

        Raises:
            tf.errors.NotFoundError: Re-raised after its message has been
                stored in the shared flags and sent.
        """
        tf.config.set_soft_device_placement(False)
        tf.debugging.set_log_device_placement(False)
        utility = min(self.flags.gpu, 1.)
        if utility > 0.0:
            tf.config.set_soft_device_placement(True)
        else:
            self.logger.info('Running entirely on CPU')

        if self.flags.train:
            self.build_train_op()

        if self.flags.summary:
            self.summary_op = tf.compat.v1.summary.merge_all()
            self.writer = tf.compat.v1.summary.FileWriter(
                self.flags.summary + self.flags.project_name)

        self.sess = tf.compat.v1.Session()
        self.sess.run(tf.compat.v1.global_variables_initializer())

        if not self.ntrain:
            return

        try:
            self.saver = tf.compat.v1.train.Saver(
                tf.compat.v1.global_variables())

            if self.flags.load != 0:
                self.load_from_ckpt()

        except tf.errors.NotFoundError as e:
            self.flags.error = str(e.message)
            # This class defines no send_flags of its own; sending goes
            # through the SharedFlagIO instance, as in raise_error.
            self.io.send_flags()
            raise

        if self.flags.summary:
            self.writer.add_graph(self.sess.graph)

    def load_from_ckpt(self):
        """Restore the session from checkpoint number ``flags.load``.

        A negative value means "latest": the step number is taken from the
        last line of the ``checkpoint`` file in ``flags.backup``. If the
        saver raises ValueError, the variables are loaded one by one through
        ``load_old_graph``.
        """
        if self.flags.load < 0:  # load latest ckpt

            with open(os.path.join(self.flags.backup, 'checkpoint'), 'r') as f:
                last = f.readlines()[-1].strip()
                # Take the double-quoted value after the first space.
                load_point = last.split(' ')[1]
                load_point = load_point.split('"')[1]
                self.logger.debug(load_point)
                load_point = load_point.split('-')[-1]
                self.flags.load = int(load_point)

        load_point = os.path.join(self.flags.backup, self.meta['name'])
        load_point = '{}-{}'.format(load_point, self.flags.load)
        self.logger.info('Loading from {}'.format(load_point))
        try:
            self.saver.restore(self.sess, load_point)
        except ValueError:
            self.load_old_graph(load_point)

    def load_old_graph(self, ckpt):
        """Assign every global variable from ``ckpt``, looked up by name and
        shape through a ``Loader``."""
        ckpt_loader = Loader.create(ckpt)
        self.logger.info(old_graph_msg.format(ckpt))

        for var in tf.compat.v1.global_variables():
            name = var.name.split(':')[0]
            val = ckpt_loader(name, var.get_shape())
            if val is None:
                self.raise_error(VariableIsNone(var))
            plh = tf.compat.v1.placeholder(tf.float32, val.shape)
            op = tf.compat.v1.assign(var, plh)
            self.sess.run(op, {plh: val})

    def build_train_op(self):
        """Create the optimizer with a cyclic learning rate and the op that
        applies its gradients."""
        self.framework.loss(self.out)
        self.logger.info('Building {} train op'.format(self.meta['model']))
        self.global_step = tf.Variable(0, trainable=False)
        # setup kwargs for trainer
        kwargs = dict()
        if self.flags.trainer in ['momentum', 'rmsprop', 'nesterov']:
            kwargs.update({'momentum': self.flags.momentum})
        if self.flags.trainer == 'nesterov':
            kwargs.update({self.flags.trainer: True})
        if self.flags.trainer == 'AMSGrad':
            kwargs.update({self.flags.trainer.lower(): True})
        if self.flags.clip:
            kwargs.update({'clipnorm': self.flags.clip_norm})

        # setup cyclic_learning_rate args
        ssc = self.flags.step_size_coefficient
        step_size = int(ssc * (len(self.annotation_data) // self.flags.batch))
        clr_kwargs = {
            'global_step': self.global_step,
            'mode': self.flags.clr_mode,
            'step_size': step_size,
            'learning_rate': self.flags.lr,
            'max_lr': self.flags.max_lr,
            'name': 'learning-rate'
        }

        # setup trainer
        self.optimizer = TRAINERS[self.flags.trainer](clr(**clr_kwargs),
                                                      **kwargs)

        # setup gradients for all globals except the global_step
        # NOTE(review): this drops the LAST global variable and so relies on
        # global_step being the most recently created one -- confirm.
        variables = tf.compat.v1.global_variables()[:-1]
        grads = self.optimizer.get_gradients(self.framework.loss, variables)
        self.train_op = self.optimizer.apply_gradients(zip(grads, variables))