Example #1
    def __init__(self,
                 model_name='default',
                 model_group='default',
                 model_type=None,
                 filter_scale=0,
                 input_shape=None,  # TBD
                 load_args=False,
                 **kwargs
                 ):
        self.model_args = {
            'model_name': model_name,
            'model_group': model_group,
            'model_type': model_type,
            'filter_scale': filter_scale,
            'input_shape': input_shape
        }
        for key in kwargs:
            self.model_args[key] = kwargs[key]

        self.model_dir = gouda.ensure_dir(project_path('results', model_group, model_name))

        args_path = self.model_dir('model_args.json')
        if load_args:
            if args_path.exists():
                self.model_args = gouda.load_json(args_path)
                self.model_args['model_name'] = model_name
                self.model_args['model_group'] = model_group
            else:
                raise ValueError("Cannot load model args. No file found at: {}".format(args_path.path))
        if not args_path.exists():
            # Only models without pre-existing model args will save in order to prevent overwriting
            gouda.save_json(self.model_args, args_path)

        self.compile()
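
A usage sketch for this constructor (the class name CoreModel and the dropout keyword are placeholders; only the arguments shown above come from the source):

    # First run: model_args.json is written under results/<model_group>/<model_name>
    model = CoreModel(model_name='unet_small',
                      model_group='segmentation',
                      model_type='unet',
                      filter_scale=1,
                      input_shape=(512, 512, 1),
                      dropout=0.1)  # extra kwargs are stored in model_args as well

    # Later runs: reuse the saved arguments instead of re-specifying them
    model = CoreModel(model_name='unet_small',
                      model_group='segmentation',
                      load_args=True)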
Example #2
    def __init__(self):
        config_path = os.path.join(os.path.dirname(__file__),
                                   'model_args.json')
        self.model_args = gouda.load_json(config_path)
        self.compile_model()
        self.model.load_weights(
            os.path.join(os.path.dirname(__file__),
                         'model_weights/model_weights.tf'))
Example #3
    def __init__(self,
                 model_name='default',
                 model_group='default',
                 model_type='template',
                 filter_scale=0,
                 num_outputs=2,
                 input_shape=[512, 512, 1],
                 load_args=False,
                 **kwargs):
        """Initialize a model for the network

        Parameters
        ----------
        model_name : str
            The name of the model to use - should define the model level parameters
        model_group : str
            The group of models to use - should define the model structure or data paradigm
        filter_scale : int
            The scaling factor to use for the model layers (scales by powers of 2)
        input_shape : tuple of ints
            The shape of data to be passed to the model (not including batch size)
        load_args : bool
            Whether to use pre-existing arguments for the given model group+name
        """
        if input_shape is None and load_args is False:
            raise ValueError("Input shape cannot be None for model object")
        self.model_dir = gouda.GoudaPath(
            gouda.ensure_dir(RESULTS_DIR, model_group, model_name))
        if load_args:
            if self.model_dir('model_args.json').exists():
                self.model_args = gouda.load_json(
                    self.model_dir('model_args.json'))
            else:
                raise ValueError(
                    "Cannot find model args for model {}/{}".format(
                        model_group, model_name))
        else:
            self.model_args = {
                'model_name': model_name,
                'model_group': model_group,
                'model_type': model_type,
                'filter_scale': filter_scale,
                'input_shape': input_shape,
            }
            for key in kwargs:
                self.model_args[key] = kwargs[key]
            gouda.save_json(self.model_args, self.model_dir('model_args.json'))
        K.clear_session()
        self.model = lookup_model(model_type)(**self.model_args)
Example #4
    def load_args(self, args_path):
        """Load model arguments from a json file.

        NOTE
        ----
        Custom methods/models/etc. will be loaded with a 'CUSTOM: ' placeholder value and should be replaced
        """
        if os.path.exists(args_path):
            to_warn = []
            loaded_args = gouda.load_json(args_path)
            for key in loaded_args:
                self.model_args[key] = loaded_args[key]
                if isinstance(loaded_args[key],
                              str) and loaded_args[key].startswith('CUSTOM: '):
                    to_warn.append(key)
            if len(to_warn) > 0:
                warnings.warn(
                    'Custom arguments for [{}] were found and should be replaced'
                    .format(', '.join(to_warn)))
        else:
            raise ValueError('No file found at {}'.format(args_path))
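
The note above means that non-serializable values are stored with a 'CUSTOM: ' prefix and have to be re-supplied after loading. A minimal sketch of that replacement, assuming a model instance with this load_args method and a hypothetical my_loss_function callable:

    model.load_args('path/to/model_args.json')  # warns that e.g. 'loss_type' is custom
    if str(model.model_args.get('loss_type', '')).startswith('CUSTOM: '):
        model.model_args['loss_type'] = my_loss_function  # replace the placeholder with the real callable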
Example #5
    def __init__(self,
                 model_name='default',
                 model_group='default',
                 model_type='multires',
                 filter_scale=0,
                 out_layers=1,
                 out_classes=2,
                 input_shape=[1024, 1360, 1],
                 patch_out=False,
                 load_args=False,
                 **kwargs):
        K.clear_session()
        self.loaded = False

        self.model_dir = gouda.GoudaPath(
            gouda.ensure_dir(RESULTS_DIR, model_group, model_name))
        args_path = self.model_dir / 'model_args.json'
        if load_args:
            if not args_path.exists():
                raise ValueError("No model arguments found at path: {}".format(
                    args_path.abspath))
            self.model_args = gouda.load_json(args_path)
        else:
            self.model_args = {
                'model_name': model_name,
                'model_group': model_group,
                'model_type': model_type,
                'filter_scale': filter_scale,
                'out_layers': out_layers,
                'out_classes': out_classes,
                'input_shape': input_shape,
                'patch_out': patch_out
            }
            for key in kwargs:
                self.model_args[key] = kwargs[key]
            gouda.save_json(self.model_args, args_path)

        model_func = get_model_func(self.model_args['model_type'])
        self.model = model_func(**self.model_args)
Example #6
    def train(self,
              train_data,
              val_data,
              metrics=None,
              starting_epoch=1,
              lr_type=None,
              loss_type=None,
              epochs=50,
              save_every=10,
              load_args=False,
              reduce_lr_on_plateau=False,
              extra_callbacks=[],
              version='default',
              **kwargs):
        """NOTE: logging_handler from the CoreModel is replaced by metrics in the distributed. This model relies on the keras model.fit methods more than the custom training loop.

        NOTE: extra_callbacks should not include TensorBoard or ModelCheckpoint, since they are already used. ReduceLROnPlateau and LRLogger will be already included as well if reduce_lr_on_plateau is True.
        """
        log_dir = gouda.ensure_dir(self.model_dir(version))
        args_path = log_dir('training_args.json')
        weights_dir = gouda.ensure_dir(log_dir('training_weights'))
        train_args = {
            'epochs': epochs,
            'lr_type': lr_type,
            'loss_type': loss_type
        }
        for key in kwargs:
            train_args[key] = kwargs[key]
        if reduce_lr_on_plateau:
            if 'plateau_factor' not in train_args:
                train_args['plateau_factor'] = 0.1
            if 'plateau_patience' not in train_args:
                train_args['plateau_patience'] = 3
        if load_args:
            if args_path.exists():
                train_args = gouda.load_json(args_path)
            else:
                raise ValueError("No training args file found at `{}`".format(
                    args_path.abspath))
        # TODO - check to see if dataset is distributed, requires manual batchsize if so
        # train_args['batch_size'] = 8
        # train_args['val_batch_size'] = 8
        for item in train_data.take(1):
            train_args['batch_size'] = item[0].numpy().shape[0]
        for item in val_data.take(1):
            train_args['val_batch_size'] = item[0].numpy().shape[0]

        self.compile_model(checking=True)
        if starting_epoch == 1:
            self.save_weights(weights_dir('model_weights_init.tf').abspath)

        # Set learning rate type and optimizer
        optimizer = self._setup_optimizer(train_args)

        # Set loss type
        if train_args['loss_type'] is None:
            if self.model_args['loss_type'] is None:
                raise ValueError("No loss function defined")
            train_args['loss_type'] = self.model_args['loss_type']
        if isinstance(train_args['loss_type'], str):
            loss = get_loss_func(train_args['loss_type'])(**train_args)
        else:
            loss = train_args['loss_type']
            train_args['loss_type'] = str(loss)

        # Currently, just uses the default training/validation steps

        # Save training args as json
        # save_args = train_args.copy()
        # for key in train_args:
        #     key_type = str(type(save_args[key]))
        #     if 'function' in key_type or 'class' in key_type:
        #         save_args[key] = 'custom'
        save_args = clean_for_json(train_args.copy())
        gouda.save_json(save_args, args_path)

        checkpoint_prefix = weights_dir('model_weights_e{epoch}').abspath
        callbacks = [
            tf.keras.callbacks.TensorBoard(log_dir=log_dir.abspath),
            tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                               save_weights_only=True)
        ]
        if reduce_lr_on_plateau:
            callbacks.append(
                tf.keras.callbacks.ReduceLROnPlateau(
                    monitor='val_loss',
                    factor=train_args['plateau_factor'],
                    patience=train_args['plateau_patience']))
            callbacks.append(LRLogger())
        callbacks += extra_callbacks
        with self.strategy.scope():
            self.model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

        try:
            self.model.fit(train_data,
                           validation_data=val_data,
                           epochs=epochs,
                           initial_epoch=starting_epoch - 1,
                           callbacks=callbacks)
        except KeyboardInterrupt:
            print("\nInterrupting model training...")
        self.save_weights(log_dir('model_weights.tf').abspath)
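
A hypothetical call to this distributed train method; the 'constant' lr_type and 'bce' loss_type strings are assumptions about what _setup_optimizer and get_loss_func accept, and the dataset variables are placeholders. TensorBoard and ModelCheckpoint are added internally, so extra_callbacks only carries additional callbacks:

    model.train(train_ds,
                val_ds,
                metrics=['accuracy'],
                lr_type='constant',   # assumption: resolved by _setup_optimizer
                loss_type='bce',      # assumption: resolved by get_loss_func
                epochs=100,
                reduce_lr_on_plateau=True,
                plateau_patience=5,
                extra_callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)])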
Example #7
    def train(self,
              train_data,
              val_data,
              starting_epoch=1,
              epochs=200,
              save_every=-1,
              version_name='default',
              load_args=False,
              **kwargs):
        train_args = kwargs
        for x, y in train_data.take(1):
            batch_size = y.numpy().shape[0]
        for x, y in val_data.take(1):
            val_batch_size = y.numpy().shape[0]

        version_dir = self.model_dir / version_name
        args_path = version_dir / 'training_args.json'
        if load_args:
            if not args_path.exists():
                raise ValueError(
                    "No training arguments found at path: {}".format(
                        args_path.abspath))
            train_args = gouda.load_json(args_path)
        else:
            defaults = {
                'learning_rate': 1e-4,
                'lr_decay_rate': None,
                'lr_decay_steps': None,
                'label_smoothing': 0.05,
            }
            train_args['version_name'] = version_name
            for key in defaults:
                if key not in train_args:
                    train_args[key] = defaults[key]
            train_args['batch_size'] = batch_size
            gouda.save_json(train_args, args_path)

        if train_args['lr_decay_rate'] is not None and train_args[
                'lr_decay_steps'] is not None:
            lr = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=train_args['learning_rate'],
                decay_steps=train_args['lr_decay_steps'],
                decay_rate=train_args['lr_decay_rate'])
        else:
            lr = train_args['learning_rate']
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        loss_func = tf.keras.losses.BinaryCrossentropy(
            from_logits=False, label_smoothing=train_args['label_smoothing'])

        train_writer = tf.summary.create_file_writer(version_dir / 'train')
        train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        train_bal = BalanceMetric('train_balance')
        train_acc = [
            tf.keras.metrics.BinaryAccuracy('train_accuracy_{}'.format(i))
            for i in range(self.model_args['out_classes'])
        ]
        train_mcc = [
            MatthewsCorrelationCoefficient(name='train_mcc_{}'.format(i))
            for i in range(self.model_args['out_classes'])
        ]

        val_writer = tf.summary.create_file_writer(version_dir / 'val')
        val_loss = tf.keras.metrics.Mean('val_loss', dtype=tf.float32)
        val_bal = BalanceMetric('val_balance')
        val_acc = [
            tf.keras.metrics.BinaryAccuracy('val_accuracy_{}'.format(i))
            for i in range(self.model_args['out_classes'])
        ]
        val_mcc = [
            MatthewsCorrelationCoefficient(name='val_mcc_{}'.format(i))
            for i in range(self.model_args['out_classes'])
        ]

        def train_step(model, optimizer, x, y, num_classes):
            with tf.GradientTape() as tape:
                predicted = model(x, training=True)
                loss = loss_func(y, predicted)
            grad = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grad, model.trainable_variables))

            train_loss(loss)
            train_bal(y, predicted)
            y_split = tf.split(y, num_classes, axis=-1)
            pred_split = tf.split(predicted, num_classes, axis=-1)
            for acc, mcc, y, pred in zip(train_acc, train_mcc, y_split,
                                         pred_split):
                acc(y, pred)
                mcc(y, pred)

        def val_step(model, x, y, num_classes):
            predicted = model(x, training=False)
            loss = loss_func(y, predicted)
            val_loss(loss)
            val_bal(y, predicted)
            y_split = tf.split(y, num_classes, axis=-1)
            pred_split = tf.split(predicted, num_classes, axis=-1)
            for acc, mcc, y, pred in zip(val_acc, val_mcc, y_split,
                                         pred_split):
                acc(y, pred)
                mcc(y, pred)

        train_step = tf.function(train_step)
        val_step = tf.function(val_step)

        epoch_pbar = tqdm.tqdm(total=epochs,
                               unit=' epochs',
                               initial=starting_epoch)

        val_steps = StableCounter()
        val_batch_pbar = tqdm.tqdm(total=val_steps(),
                                   unit=' val samples',
                                   leave=False)
        for image, label in val_data:
            val_step(self.model, image, label, self.model_args['out_classes'])
            val_batch_pbar.update(val_batch_size)
            val_steps += val_batch_size
        val_steps.stop()
        val_batch_pbar.close()

        logstring = 'Untrained: '
        with val_writer.as_default():
            tf.summary.scalar('loss', val_loss.result(), step=0)
            logstring += 'Val Loss: {:.4f}'.format(val_loss.result())
            tf.summary.scalar('balance', val_bal.result(), step=0)
            logstring += ', Val Balance: {:.4f}'.format(val_bal.result())
            accuracies = []
            for i, acc in enumerate(val_acc):
                tf.summary.scalar('accuracy_{}'.format(i),
                                  acc.result(),
                                  step=0)
                accuracies.append('{:.2f}'.format(acc.result() * 100))
            logstring += ", Val Accuracy " + '/'.join(accuracies)
            mccs = []
            for i, mcc in enumerate(val_mcc):
                tf.summary.scalar('mcc_{}'.format(i), mcc.result(), step=0)
                mccs.append("{:.4f}".format(mcc.result()))
            logstring += ", Val MCC " + '/'.join(mccs)
        epoch_pbar.write(logstring)
        val_loss.reset_states()
        val_bal.reset_states()
        for acc in val_acc:
            acc.reset_states()
        for mcc in val_mcc:
            mcc.reset_states()

        weights_dir = gouda.ensure_dir(version_dir / 'training_weights')
        self.model.save_weights(weights_dir / 'initial_weights.h5')
        train_steps = StableCounter()

        try:
            for epoch in range(starting_epoch, epochs):
                train_batch_pbar = tqdm.tqdm(total=train_steps(),
                                             unit=' samples',
                                             leave=False)
                for image, label in train_data:
                    train_step(self.model, opt, image, label,
                               self.model_args['out_classes'])
                    train_batch_pbar.update(train_args['batch_size'])
                    train_steps += train_args['batch_size']
                train_steps.stop()
                train_batch_pbar.close()
                logstring = "Epoch {:04d}".format(epoch)
                with train_writer.as_default():
                    if not isinstance(lr, float):
                        tf.summary.scalar('lr', lr(opt.iterations), step=epoch)
                    tf.summary.scalar('loss', train_loss.result(), step=epoch)
                    logstring += ', Loss: {:.4f}'.format(train_loss.result())
                    tf.summary.scalar('balance', train_bal.result(), step=epoch)
                    logstring += ', Balance: {:.4f}'.format(train_bal.result())
                    accuracies = []
                    for i, acc in enumerate(train_acc):
                        tf.summary.scalar('accuracy_{}'.format(i),
                                          acc.result(),
                                          step=epoch)
                        accuracies.append("{:.2f}".format(acc.result() * 100))
                    logstring += ', Accuracy: ' + '/'.join(accuracies)
                    mccs = []
                    for i, mcc in enumerate(train_mcc):
                        tf.summary.scalar('mcc_{}'.format(i),
                                          mcc.result(),
                                          step=epoch)
                        mccs.append("{:.4f}".format(mcc.result()))
                    logstring += ', MCC: ' + '/'.join(mccs)
                train_loss.reset_states()
                train_bal.reset_states()
                for acc in train_acc:
                    acc.reset_states()
                for mcc in train_mcc:
                    mcc.reset_states()

                val_batch_pbar = tqdm.tqdm(total=val_steps(),
                                           unit=' val samples',
                                           leave=False)
                for image, label in val_data:
                    val_step(self.model, image, label,
                             self.model_args['out_classes'])
                    val_batch_pbar.update(val_batch_size)
                val_batch_pbar.close()
                logstring += ' || '
                with val_writer.as_default():
                    tf.summary.scalar('loss', val_loss.result(), step=epoch)
                    logstring += 'Val Loss: {:.4f}'.format(val_loss.result())
                    tf.summary.scalar('balance', val_bal.result(), step=epoch)
                    logstring += 'Val Balance: {:.4f}'.format(val_bal.result())
                    accuracies = []
                    for i, acc in enumerate(val_acc):
                        tf.summary.scalar('accuracy_{}'.format(i),
                                          acc.result(),
                                          step=epoch)
                        accuracies.append("{:.2f}".format(acc.result() * 100))
                    logstring += ', Val Accuracy: ' + '/'.join(accuracies)
                    mccs = []
                    for i, mcc in enumerate(val_mcc):
                        tf.summary.scalar('mcc_{}'.format(i),
                                          mcc.result(),
                                          step=epoch)
                        mccs.append("{:.4f}".format(mcc.result()))
                    logstring += ', Val MCC: ' + '/'.join(mccs)
                val_loss.reset_states()
                val_bal.reset_states()
                for acc in val_acc:
                    acc.reset_states()
                for mcc in val_mcc:
                    mcc.reset_states()
                epoch_pbar.write(logstring)

                if (epoch + 1) % save_every == 0 and save_every != -1:
                    self.model.save_weights(
                        weights_dir / 'model_weights_e{:03d}.h5'.format(epoch))
                epoch_pbar.update(1)
        except KeyboardInterrupt:
            print("KeyboardInterrupt - stopping training...")
        self.model.save_weights(version_dir / 'model_weights.h5')
        epoch_pbar.close()
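
Examples #7-#10 size their tqdm progress bars with a StableCounter helper: it returns None (unknown total) while the first pass over a dataset is being counted, and a fixed total afterwards. Its implementation is not shown in the source; a minimal sketch consistent with how it is used (counter += n, counter(), counter.set(n), counter.stop(), counter.reset()) could look like this:

    class StableCounter:
        """Counts samples on the first pass, then reports a stable total."""

        def __init__(self):
            self._count = 0
            self._stopped = False

        def __iadd__(self, n):
            # Increments after stop() are ignored so later epochs don't inflate the total
            if not self._stopped:
                self._count += n
            return self

        def __call__(self):
            # tqdm treats total=None as "unknown length"
            return self._count if self._stopped else None

        def set(self, n):
            self._count = n
            self._stopped = True

        def stop(self):
            self._stopped = True

        def reset(self):
            self._count = 0
            self._stopped = False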
Example #8
    def train(self,
              train_data,
              val_data,
              starting_epoch=1,
              epochs=200,
              save_every=-1,
              version_name='default',
              load_args=False,
              **kwargs):
        """Train the model

        Parameters
        ----------
        train_data : tf.data.Dataset
            The data to train on
        val_data : tf.data.Dataset
            The data to validate on
        starting_epoch : int
            The epoch to start on - can be set greater than 1 to continue previous training
        epochs : int
            The epoch to end on - if starting epoch is greater than 1, the model will still only train until it reaches this total epoch count
        save_every : int
            The number of epochs between each set of model weights to save
        version_name : str
            The name of the model to train - version name should group training/hyperparameters
        load_args : bool
            Whether to use pre-existing parameters for the given model group+name+version
        """
        version_dir = gouda.ensure_dir(self.model_dir / version_name)
        weights_dir = gouda.ensure_dir(version_dir / 'training_weights')
        if load_args:
            if version_dir('train_args.json').exists():
                train_args = gouda.load_json(version_dir('train_args.json'))
            else:
                raise ValueError("No existing args found for {}/{}/{}".format(
                    self.model_args['model_group'],
                    self.model_args['model_name'], version_name))
        else:
            defaults = {
                'learning_rate': 1e-4,
                'lr_decay_steps': None,
                'lr_decay_rate': None,
                'label_smoothing': 0.05
            }
            train_args = kwargs
            for item in defaults:
                if item not in train_args:
                    train_args[item] = defaults[item]

        for x, y in train_data.take(1):
            batch_size = y.numpy().shape[0]
        for x, y in val_data.take(1):
            val_batch_size = y.numpy().shape[0]

        train_args['batch_size'] = batch_size
        train_args['val_batch_size'] = val_batch_size

        gouda.save_json(train_args, version_dir / 'train_args.json')

        if train_args['lr_decay_rate'] is not None and train_args[
                'lr_decay_steps'] is not None:
            lr = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=train_args['learning_rate'],
                decay_steps=train_args['lr_decay_steps'],
                decay_rate=train_args['lr_decay_rate'])
        else:
            lr = train_args['learning_rate']
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        loss_func = tf.keras.losses.BinaryCrossentropy(
            from_logits=False, label_smoothing=train_args['label_smoothing'])

        # set up tensorboard

        train_writer = tf.summary.create_file_writer(version_dir / 'train')
        train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        train_acc = tf.keras.metrics.BinaryAccuracy('train_accuracy')

        val_writer = tf.summary.create_file_writer(version_dir / 'val')
        val_loss = tf.keras.metrics.Mean('val_loss', dtype=tf.float32)
        val_acc = tf.keras.metrics.BinaryAccuracy('val_accuracy')

        # set up train/val steps
        def train_step(model, optimizer, x, y):
            with tf.GradientTape() as tape:
                predicted = model(x, training=True)
                loss = loss_func(y, predicted)
            grad = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grad, model.trainable_variables))

            train_loss(loss)
            train_acc(y, predicted)

        def val_step(model, x, y):
            predicted = model(x, training=False)
            loss = loss_func(y, predicted)
            val_loss(loss)
            val_acc(y, predicted)

        # training loop
        epoch_pbar = tqdm.tqdm(total=epochs,
                               unit=' epochs',
                               initial=starting_epoch)

        # Baseline Validation
        val_steps = StableCounter()
        val_batch_pbar = tqdm.tqdm(total=None,
                                   unit=' val samples',
                                   leave=False)
        for image, label in val_data:
            val_step(self.model, image, label)
            val_batch_pbar.update(val_batch_size)
            val_steps += val_batch_size
        val_steps.stop()
        with val_writer.as_default():
            tf.summary.scalar('loss', val_loss.result(), step=0)
            tf.summary.scalar('accuracy', val_acc.result(), step=0)
        log_string = 'Untrained: Val Loss: {:.4f}, Val Accuracy: {:6.2f}'.format(
            val_loss.result(),
            val_acc.result() * 100)
        val_batch_pbar.close()
        val_loss.reset_states()
        val_acc.reset_states()
        epoch_pbar.write(log_string)

        train_steps = StableCounter()
        self.model.save_weights(weights_dir('initial_weights.h5').abspath)
        try:
            for epoch in range(starting_epoch, epochs):
                train_batch_pbar = tqdm.tqdm(total=train_steps(),
                                             unit=' samples',
                                             leave=False)
                for image, label in train_data:
                    train_step(self.model, opt, image, label)
                    train_batch_pbar.update(batch_size)
                    train_steps += batch_size
                train_steps.stop()
                with train_writer.as_default():
                    if not isinstance(lr, float):
                        tf.summary.scalar('lr', lr(opt.iterations), step=epoch)
                    tf.summary.scalar('loss', train_loss.result(), step=epoch)
                    tf.summary.scalar('accuracy',
                                      train_acc.result(),
                                      step=epoch)
                train_batch_pbar.close()
                log_string = 'Epoch {:04d}, Loss: {:.4f}, Accuracy: {:6.2f}'.format(
                    epoch, train_loss.result(),
                    train_acc.result() * 100)

                val_batch_pbar = tqdm.tqdm(total=val_steps())
                for image, label in val_data:
                    val_step(self.model, image, label)
                    val_batch_pbar.update(val_batch_size)
                with val_writer.as_default():
                    tf.summary.scalar('loss', val_loss.result(), step=epoch)
                    tf.summary.scalar('accuracy', val_acc.result(), step=epoch)
                val_batch_pbar.close()
                log_string += ' || Val Loss: {:.4f}, Val Accuracy: {:6.2f}'.format(
                    val_loss.result(),
                    val_acc.result() * 100)
                if (epoch + 1) % save_every == 0 and save_every != -1:
                    self.model.save_weights(
                        weights_dir(
                            'model_weights_e{:03d}.h5'.format(epoch)).abspath)

                epoch_pbar.write(log_string)
                train_loss.reset_states()
                train_acc.reset_states()
                val_loss.reset_states()
                val_acc.reset_states()
                epoch_pbar.update(1)
        except KeyboardInterrupt:
            epoch_pbar.write('Stopping model with keypress...')
        epoch_pbar.close()
        self.model.save_weights(version_dir('model_weights.h5').abspath)
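
Based on the docstring above, training can be resumed by loading previously saved weights and passing a starting_epoch greater than 1. A hypothetical resumption (paths and values are placeholders):

    # A previous run with save_every=10 saved weights at epoch 39 (model_weights_e039.h5)
    model.model.load_weights('results/<group>/<name>/default/training_weights/model_weights_e039.h5')
    model.train(train_ds, val_ds,
                starting_epoch=40,
                epochs=200,
                save_every=10,
                version_name='default',
                load_args=True)  # reuse the train_args.json written by the first run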
Example #9
    def train(
        self,
        train_data,
        val_data,
        num_train_samples=None,
        num_val_samples=None,
        starting_epoch=1,
        epochs=50,
        model_version='default',
        load_args=False,
        plot_model=True,
        learning_rate=1e-4,
        label_smoothing=0.1,
        loss_func='mixed',
        save_every=10,
        **kwargs
    ):
        # Set up directories
        log_dir = gouda.ensure_dir(self.model_dir(model_version))
        args_path = log_dir('training_args.json')
        weights_dir = gouda.ensure_dir(log_dir('training_weights'))

        # Set up training args
        train_args = {
            'learning_rate': learning_rate,
            'label_smoothing': label_smoothing,
            'loss_function': loss_func,
            'lr_exp_decay_rate': None,  # Multiplier for exponential decay (ie 0.2 means lr_2 = 0.2 * lr_1)
            'lr_exp_decay_steps': None,  # Steps between exponential lr decay
            'lr_cosine_decay': False,  # Whether to use cosine lr decay
            'lr_cosine_decay_steps': epochs,  # The number of steps to reach the minimum lr
            'lr_cosine_decay_min_lr': 0.0  # The minimum lr when using steps < epochs
        }
        for key in kwargs:
            train_args[key] = kwargs[key]
        if load_args:
            if args_path.exists():
                train_args = gouda.load_json(args_path.abspath)
            else:
                raise ValueError("Cannot load training args. No file found at: {}".format(args_path.path))

        for x, y in train_data.take(1):
            train_args['train_batch_size'] = y.numpy().shape[0]
        for x, y in val_data.take(1):
            train_args['val_batch_size'] = y.numpy().shape[0]

        gouda.save_json(train_args, args_path)

        # Save initial weights and model structure
        if self.model is None:
            self.compile()
        self.model.save_weights(weights_dir('model_weights_init.h5').abspath)
        if plot_model:
            tf.keras.utils.plot_model(self.model, to_file=log_dir('model.png').abspath, show_shapes=True)

        # Set up loss
        if train_args['lr_exp_decay_rate'] is not None and train_args['lr_exp_decay_steps'] is not None:
            lr = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=train_args['learning_rate'],
                decay_steps=train_args['lr_exp_decay_steps'],
                decay_rate=train_args['lr_exp_decay_rate']
            )
        elif train_args['lr_cosine_decay']:
            alpha = train_args['lr_cosine_decay_min_lr'] / train_args['learning_rate']
            lr = tf.keras.experimental.CosineDecay(train_args['learning_rate'], train_args['lr_cosine_decay_steps'], alpha=alpha)
        else:
            lr = train_args['learning_rate']
        opt = tf.keras.optimizers.Adam(learning_rate=lr)
        loss_func = get_loss_func(train_args['loss_function'], **train_args)
        # if train_args['loss_function'] == 'bce':
        #     loss_func = tf.keras.losses.BinaryCrossentropy(from_logits=False, label_smoothing=train_args['label_smoothing'])
        # elif train_args['loss_function'] == 'iou':
        #     loss_func = IOU_loss
        # elif train_args['loss_function'] == 'mixed':
        #     loss_func = mixed_IOU_BCE_loss(train_args['loss_alpha'], train_args['label_smoothing'])
        # elif
        # else:
        #     raise NotImplementedError("Loss function `{}` hasn't been added yet".format(train_args['loss_function']))

        # Set up logging
        train_writer = tf.summary.create_file_writer(log_dir('train').path)
        train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
        train_acc = tf.keras.metrics.BinaryAccuracy('train_accuracy')
        train_bal = BalanceMetric('train_balance')
        train_string = "Loss: {:.4f}, Accuracy: {:6.2f}, Balance: {:.2f}"

        val_writer = tf.summary.create_file_writer(log_dir('val').path)
        val_loss = tf.keras.metrics.Mean('val_loss', dtype=tf.float32)
        val_acc = tf.keras.metrics.BinaryAccuracy('val_accuracy')
        val_bal = BalanceMetric('val_balance')
        val_string = " || Val Loss: {:.4f}, Val Accuracy: {:6.2f}, Val Balance: {:.2f}"

        # Define train/val steps
        def train_step(model, optimizer, x, y):
            with tf.GradientTape() as tape:
                predicted = model(x, training=True)
                loss = loss_func(y, predicted)
            grad = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grad, model.trainable_variables))

            train_loss(loss)
            train_acc(y, predicted)
            train_bal(y, predicted)

        def val_step(model, x, y):
            predicted = model(x, training=False)
            loss = loss_func(y, predicted)
            val_loss(loss)
            val_acc(y, predicted)
            val_bal(y, predicted)

        train_step = tf.function(train_step)
        val_step = tf.function(val_step)

        train_steps = StableCounter()
        if num_train_samples is not None:
            train_steps.set(num_train_samples)
        val_steps = StableCounter()
        if num_val_samples is not None:
            val_steps.set(num_val_samples)

        # Training loop
        epoch_pbar = tqdm.tqdm(total=epochs, unit=' epochs', initial=starting_epoch)
        val_batch_pbar = tqdm.tqdm(total=val_steps(), unit=' val samples', leave=False)
        for image, label in val_data:
            val_step(self.model, image, label)
            val_batch_pbar.update(train_args['val_batch_size'])
            val_steps += train_args['val_batch_size']
        val_steps.stop()
        with val_writer.as_default():
            tf.summary.scalar('loss', val_loss.result(), step=0)
            tf.summary.scalar('accuracy', val_acc.result(), step=0)
            tf.summary.scalar('balance', val_bal.result(), step=0)
        epoch_pbar.write('Pretrained - Val Loss: {:.4f}, Val Accuracy: {:6.2f}, Val Balance: {:.2f}'.format(val_loss.result(), 100 * val_acc.result(), val_bal.result()))
        val_batch_pbar.close()
        val_loss.reset_states()
        val_acc.reset_states()
        val_bal.reset_states()

        try:
            for epoch in range(starting_epoch, epochs):
                log_string = 'Epoch {:3d} - '.format(epoch)
                # Training loop
                train_batch_pbar = tqdm.tqdm(total=train_steps(), unit=' samples', leave=False)
                for image, label in train_data:
                    train_step(self.model, opt, image, label)
                    train_batch_pbar.update(train_args['train_batch_size'])
                    train_steps += train_args['train_batch_size']
                with train_writer.as_default():
                    tf.summary.scalar('loss', train_loss.result(), step=epoch)
                    tf.summary.scalar('accuracy', train_acc.result(), step=epoch)
                    tf.summary.scalar('balance', train_bal.result(), step=epoch)
                log_string += train_string.format(train_loss.result(), train_acc.result() * 100, train_bal.result())
                train_batch_pbar.close()

                # Validation Loop
                val_batch_pbar = tqdm.tqdm(total=val_steps(), unit=' val samples', leave=False)
                for image, label in val_data:
                    val_step(self.model, image, label)
                    val_batch_pbar.update(train_args['val_batch_size'])
                with val_writer.as_default():
                    tf.summary.scalar('loss', val_loss.result(), step=epoch)
                    tf.summary.scalar('accuracy', val_acc.result(), step=epoch)
                    tf.summary.scalar('balance', val_bal.result(), step=epoch)
                log_string += val_string.format(val_loss.result(), val_acc.result() * 100, val_bal.result())
                val_batch_pbar.close()

                if (epoch + 1) % save_every == 0:
                    self.model.save_weights(weights_dir('model_weights_e{:03d}.h5'.format(epoch)).path)

                epoch_pbar.write(log_string)
                train_loss.reset_states()
                train_acc.reset_states()
                train_bal.reset_states()
                train_steps.stop()
                val_loss.reset_states()
                val_acc.reset_states()
                val_bal.reset_states()
                epoch_pbar.update(1)
        except KeyboardInterrupt:
            print("Interrupting training...")
        self.model.save_weights(log_dir('model_weights.h5').path)
        epoch_pbar.close()
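
A hypothetical call showing how the learning-rate keywords above select cosine decay rather than a constant rate or exponential decay (dataset names and values are placeholders):

    model.train(train_ds, val_ds,
                num_train_samples=2000,
                num_val_samples=400,
                epochs=50,
                learning_rate=1e-4,
                lr_cosine_decay=True,           # use CosineDecay instead of a fixed lr
                lr_cosine_decay_steps=50,
                lr_cosine_decay_min_lr=1e-6,    # sets alpha = min_lr / learning_rate
                loss_func='mixed')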
Example #10
    def train(
            self,
            train_data,
            val_data,
            logging_handler=None,
            starting_epoch=1,
            lr_type=None,
            loss_type=None,
            epochs=50,
            save_every=10,
            early_stopping=10,  # implement later - tie into logging
            load_args=False,
            sample_callback=None,
            version='default',
            **kwargs):
        log_dir = gouda.ensure_dir(self.model_dir(version))
        args_path = log_dir('training_args.json')
        weights_dir = gouda.ensure_dir(log_dir('training_weights'))

        if logging_handler is None:
            logging_handler = EmptyLoggingHandler()

        train_args = {
            'epochs': epochs,
            'lr_type': lr_type,
            'loss_type': loss_type
        }
        for key in kwargs:
            train_args[key] = kwargs[key]
        if load_args:
            if args_path.exists():
                train_args = gouda.load_json(args_path)
            else:
                raise ValueError("No training args file found at `{}`".format(
                    args_path.abspath))
        for item in train_data.take(1):
            train_args['batch_size'] = item[0].numpy().shape[0]
        for item in val_data.take(1):
            train_args['val_batch_size'] = item[0].numpy().shape[0]

        self.compile_model(checking=True)
        if starting_epoch == 1:
            self.save_weights(weights_dir('model_weights_init.tf').abspath)

        # Set learning rate type and optimizer
        optimizer = self._setup_optimizer(train_args)

        # Set loss type
        if train_args['loss_type'] is None:
            if 'loss_type' not in self.model_args or self.model_args[
                    'loss_type'] is None:
                raise ValueError(
                    "No loss function defined. Use keyword 'loss_type' to define loss in model or training arguments."
                )
            train_args['loss_type'] = self.model_args['loss_type']
        if isinstance(train_args['loss_type'], str):
            loss = get_loss_func(train_args['loss_type'])(**train_args)
        else:
            loss = train_args['loss_type']
            train_args['loss_type'] = str(loss)

        # Set training step
        if 'train_step' not in train_args:
            train_args['train_step'] = self.model_args['train_step']
        if isinstance(train_args['train_step'], str):
            train_step = get_update_step(train_args['train_step'],
                                         is_training=True)
        else:
            train_step = train_args['train_step']
            train_args['train_step'] = 'custom_train_func'

        # Set validation step
        if 'val_step' not in train_args:
            train_args['val_step'] = self.model_args['val_step']
        if isinstance(train_args['val_step'], str):
            val_step = get_update_step(train_args['val_step'],
                                       is_training=False)
        else:
            val_step = train_args['val_step']
            train_args['val_step'] = 'custom_val_func'

        # Save training args as json
        save_args = clean_for_json(train_args.copy())
        gouda.save_json(save_args, args_path)

        # Start loggers
        logging_handler.start(log_dir, total_epochs=epochs)
        train_counter = StableCounter()
        if 'train_steps' in train_args:
            train_counter.set(train_args['train_steps'])
        val_counter = StableCounter()
        if 'val_steps' in train_args:
            val_counter.set(train_args['val_steps'])

        train_step = train_step(self.model, optimizer, loss, logging_handler)
        val_step = val_step(self.model, loss, logging_handler)

        epoch_pbar = tqdm.tqdm(total=epochs,
                               unit=' epochs',
                               initial=starting_epoch - 1)
        val_pbar = tqdm.tqdm(total=val_counter(),
                             unit=' val samples',
                             leave=False)
        try:
            for item in val_data:
                val_step(item)
                # logging_handler.val_step(val_step(item))
                batch_size = item[0].shape[0]
                val_counter += batch_size
                val_pbar.update(batch_size)
            log_string = logging_handler.write('Pretrained')
            epoch_pbar.write(log_string)
            val_counter.stop()
        except KeyboardInterrupt:
            if 'val_steps' not in train_args:
                val_counter.reset()
            logging_handler.write("Skipping pre-training validation.")
        epoch_pbar.update(1)
        val_pbar.close()

        try:
            epoch_digits = str(gouda.num_digits(epochs))
            for epoch in range(starting_epoch, epochs):
                train_pbar = tqdm.tqdm(total=train_counter(),
                                       unit=' samples',
                                       leave=False)
                for item in train_data:
                    batch_size = item[0].shape[0]
                    train_counter += batch_size
                    train_pbar.update(batch_size)
                    train_step(item)
                train_pbar.close()
                if sample_callback is not None:
                    sample_callback(self.model, epoch)
                val_pbar = tqdm.tqdm(total=val_counter(), leave=False)
                for item in val_data:
                    batch_size = item[0].shape[0]
                    val_counter += batch_size
                    val_pbar.update(batch_size)
                    val_step(item)
                val_pbar.close()
                train_counter.stop()
                val_counter.stop()
                log_string = logging_handler.write(epoch)
                epoch_pbar.write(log_string)
                epoch_pbar.update(1)
                if (epoch + 1) % save_every == 0:
                    weight_string = 'model_weights_e{:0' + epoch_digits + 'd}.tf'
                    self.save_weights(
                        weights_dir(weight_string.format(epoch)).abspath)

        except KeyboardInterrupt:
            print("Interrupting model training...")
            logging_handler.interrupt()
        epoch_pbar.close()
        self.save_weights(log_dir('model_weights.tf').abspath)
        logging_handler.stop()
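
Example #10 delegates metric tracking to a logging_handler object whose implementation is not shown. From the calls made here (start, write, interrupt, stop, plus per-step updates inside the train/val update steps), a minimal stand-in in the spirit of EmptyLoggingHandler might look like the sketch below; the train_step/val_step method names are assumptions:

    class EmptyLoggingHandler:
        """No-op logging handler matching the interface used by the training loop above."""

        def start(self, log_dir, total_epochs=None):
            pass

        def train_step(self, outputs):
            pass

        def val_step(self, outputs):
            pass

        def write(self, label):
            # Must return a string; it is printed next to the epoch progress bar
            return '{}: (no metrics logged)'.format(label)

        def interrupt(self):
            pass

        def stop(self):
            pass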